策略模式
大约 11 分钟
策略模式
什么是策略模式
策略模式(Strategy Pattern)是一种行为型设计模式,它定义了一系列算法,并将每个算法封装起来,使它们可以相互替换。策略模式让算法的变化独立于使用算法的客户,使得算法可以独立于客户端而变化。
策略模式的核心思想是:
- 定义一系列可互换的算法
- 将每个算法封装在独立的类中
- 使算法可以独立于使用它的客户端而变化
- 客户端可以在运行时选择不同的算法
为什么需要策略模式
在实际开发中,我们经常会遇到需要在运行时根据不同的条件选择不同算法的情况。如果使用条件语句(if-else或switch-case)来实现,会导致代码复杂、难以维护和扩展。策略模式通过将算法封装成独立的类,使得算法的增加、修改和删除变得更加容易,同时也提高了代码的可读性和可维护性。
策略模式的结构
策略模式包含以下几个角色:
- 策略接口(Strategy):定义所有支持的算法的公共接口
- 具体策略类(ConcreteStrategy):实现了策略接口的具体算法
- 上下文类(Context):使用某个策略的类,维护一个对策略对象的引用
策略模式的实现
基本实现
// 策略接口
interface Strategy {
int doOperation(int num1, int num2);
}
// 具体策略类 - 加法
class AddStrategy implements Strategy {
@Override
public int doOperation(int num1, int num2) {
return num1 + num2;
}
}
// 具体策略类 - 减法
class SubtractStrategy implements Strategy {
@Override
public int doOperation(int num1, int num2) {
return num1 - num2;
}
}
// 具体策略类 - 乘法
class MultiplyStrategy implements Strategy {
@Override
public int doOperation(int num1, int num2) {
return num1 * num2;
}
}
// 具体策略类 - 除法
class DivideStrategy implements Strategy {
@Override
public int doOperation(int num1, int num2) {
if (num2 != 0) {
return num1 / num2;
}
throw new IllegalArgumentException("除数不能为零");
}
}
// 上下文类
class Context {
private Strategy strategy;
public Context(Strategy strategy) {
this.strategy = strategy;
}
public void setStrategy(Strategy strategy) {
this.strategy = strategy;
}
public int executeStrategy(int num1, int num2) {
return strategy.doOperation(num1, num2);
}
}
// 使用示例
public class StrategyDemo {
public static void main(String[] args) {
// 创建上下文对象
Context context = new Context(new AddStrategy());
// 执行不同的策略
System.out.println("=== 策略模式演示 ===");
System.out.println("10 + 5 = " + context.executeStrategy(10, 5));
context.setStrategy(new SubtractStrategy());
System.out.println("10 - 5 = " + context.executeStrategy(10, 5));
context.setStrategy(new MultiplyStrategy());
System.out.println("10 * 5 = " + context.executeStrategy(10, 5));
context.setStrategy(new DivideStrategy());
System.out.println("10 / 5 = " + context.executeStrategy(10, 5));
}
}
改进的实现(枚举策略)
// 改进的策略枚举
enum Operation {
ADD {
@Override
public double apply(double x, double y) {
return x + y;
}
},
SUBTRACT {
@Override
public double apply(double x, double y) {
return x - y;
}
},
MULTIPLY {
@Override
public double apply(double x, double y) {
return x * y;
}
},
DIVIDE {
@Override
public double apply(double x, double y) {
if (y != 0) {
return x / y;
}
throw new IllegalArgumentException("除数不能为零");
}
};
public abstract double apply(double x, double y);
}
// 改进的上下文类
class Calculator {
public double calculate(double x, double y, Operation operation) {
return operation.apply(x, y);
}
}
// 使用示例
public class EnumStrategyDemo {
public static void main(String[] args) {
Calculator calculator = new Calculator();
System.out.println("=== 枚举策略模式演示 ===");
System.out.println("10 + 5 = " + calculator.calculate(10, 5, Operation.ADD));
System.out.println("10 - 5 = " + calculator.calculate(10, 5, Operation.SUBTRACT));
System.out.println("10 * 5 = " + calculator.calculate(10, 5, Operation.MULTIPLY));
System.out.println("10 / 5 = " + calculator.calculate(10, 5, Operation.DIVIDE));
}
}
策略模式的应用场景
- 排序算法:根据数据特征选择不同的排序算法
- 支付方式:支持多种支付方式(信用卡、支付宝、微信等)
- 压缩算法:支持不同的压缩算法(ZIP、RAR、7Z等)
- 路由算法:网络路由选择不同的路径算法
- 缓存策略:不同的缓存淘汰策略(LRU、FIFO、LFU等)
- 验证规则:不同的数据验证规则
策略模式的优缺点
优点
- 算法可以自由切换:客户端可以在运行时选择不同的算法
- 避免使用多重条件判断:将算法封装成类,消除复杂的条件语句
- 扩展性良好:增加新的算法无须修改原有代码
- 符合开闭原则:对扩展开放,对修改关闭
缺点
- 增加类的数量:每个算法都需要一个类,可能会增加系统中类的数量
- 客户端必须知道所有的策略类:客户端需要知道所有策略类的区别
- 可能产生对象膨胀:如果策略过多,会产生大量策略类对象
策略模式与其他模式的比较
与状态模式的区别
- 策略模式:算法之间不能相互转换,客户端主动选择算法
- 状态模式:状态之间可以相互转换,对象的行为随着状态的改变而改变
与命令模式的区别
- 策略模式:关注算法的封装和替换
- 命令模式:关注将请求封装成对象
与模板方法模式的区别
- 策略模式:使用组合来实现算法的替换
- 模板方法模式:使用继承来实现算法的固定流程
实际项目中的应用
// 电商支付系统示例
// 支付策略接口
interface PaymentStrategy {
void pay(double amount);
String getPaymentMethod();
}
// 信用卡支付策略
class CreditCardPayment implements PaymentStrategy {
private String cardNumber;
private String cardHolderName;
private String cvv;
private String expirationDate;
public CreditCardPayment(String cardNumber, String cardHolderName, String cvv, String expirationDate) {
this.cardNumber = cardNumber;
this.cardHolderName = cardHolderName;
this.cvv = cvv;
this.expirationDate = expirationDate;
}
@Override
public void pay(double amount) {
System.out.println("使用信用卡支付 $" + amount);
System.out.println("卡号: " + maskCardNumber(cardNumber));
System.out.println("持卡人: " + cardHolderName);
System.out.println("支付完成");
}
@Override
public String getPaymentMethod() {
return "信用卡";
}
private String maskCardNumber(String cardNumber) {
if (cardNumber.length() > 4) {
return "**** **** **** " + cardNumber.substring(cardNumber.length() - 4);
}
return cardNumber;
}
}
// 支付宝支付策略
class AlipayPayment implements PaymentStrategy {
private String alipayAccount;
public AlipayPayment(String alipayAccount) {
this.alipayAccount = alipayAccount;
}
@Override
public void pay(double amount) {
System.out.println("使用支付宝支付 $" + amount);
System.out.println("支付宝账号: " + maskAccount(alipayAccount));
System.out.println("跳转到支付宝页面...");
System.out.println("支付完成");
}
@Override
public String getPaymentMethod() {
return "支付宝";
}
private String maskAccount(String account) {
if (account.contains("@")) {
String[] parts = account.split("@");
if (parts[0].length() > 2) {
return parts[0].substring(0, 2) + "****@" + parts[1];
}
} else if (account.length() > 4) {
return account.substring(0, 3) + "****" + account.substring(account.length() - 4);
}
return account;
}
}
// 微信支付策略
class WeChatPayment implements PaymentStrategy {
private String wechatAccount;
public WeChatPayment(String wechatAccount) {
this.wechatAccount = wechatAccount;
}
@Override
public void pay(double amount) {
System.out.println("使用微信支付 $" + amount);
System.out.println("微信账号: " + maskAccount(wechatAccount));
System.out.println("打开微信扫描二维码...");
System.out.println("支付完成");
}
@Override
public String getPaymentMethod() {
return "微信支付";
}
private String maskAccount(String account) {
if (account.length() > 4) {
return account.substring(0, 2) + "****" + account.substring(account.length() - 2);
}
return account;
}
}
// 支付上下文类
class PaymentContext {
private PaymentStrategy paymentStrategy;
public void setPaymentStrategy(PaymentStrategy paymentStrategy) {
this.paymentStrategy = paymentStrategy;
}
public void executePayment(double amount) {
if (paymentStrategy != null) {
System.out.println("开始支付...");
paymentStrategy.pay(amount);
System.out.println("支付方式: " + paymentStrategy.getPaymentMethod());
System.out.println("支付成功!");
} else {
System.out.println("未选择支付方式");
}
}
}
// 使用示例
public class ECommercePaymentDemo {
public static void main(String[] args) {
// 创建支付上下文
PaymentContext paymentContext = new PaymentContext();
// 模拟购物支付
System.out.println("=== 电商支付系统演示 ===");
double amount = 299.99;
System.out.println("订单金额: $" + amount);
// 使用信用卡支付
System.out.println("\n1. 使用信用卡支付:");
PaymentStrategy creditCard = new CreditCardPayment("1234567890123456", "张三", "123", "12/25");
paymentContext.setPaymentStrategy(creditCard);
paymentContext.executePayment(amount);
// 使用支付宝支付
System.out.println("\n2. 使用支付宝支付:");
PaymentStrategy alipay = new AlipayPayment("zhangsan@example.com");
paymentContext.setPaymentStrategy(alipay);
paymentContext.executePayment(amount);
// 使用微信支付
System.out.println("\n3. 使用微信支付:");
PaymentStrategy wechat = new WeChatPayment("zhangsan123");
paymentContext.setPaymentStrategy(wechat);
paymentContext.executePayment(amount);
}
}
// 排序算法示例
// 排序策略接口
interface SortStrategy {
void sort(int[] array);
String getSortName();
}
// 冒泡排序策略
class BubbleSortStrategy implements SortStrategy {
@Override
public void sort(int[] array) {
int n = array.length;
for (int i = 0; i < n - 1; i++) {
for (int j = 0; j < n - i - 1; j++) {
if (array[j] > array[j + 1]) {
// 交换元素
int temp = array[j];
array[j] = array[j + 1];
array[j + 1] = temp;
}
}
}
}
@Override
public String getSortName() {
return "冒泡排序";
}
}
// 快速排序策略
class QuickSortStrategy implements SortStrategy {
@Override
public void sort(int[] array) {
quickSort(array, 0, array.length - 1);
}
private void quickSort(int[] array, int low, int high) {
if (low < high) {
int pi = partition(array, low, high);
quickSort(array, low, pi - 1);
quickSort(array, pi + 1, high);
}
}
private int partition(int[] array, int low, int high) {
int pivot = array[high];
int i = (low - 1);
for (int j = low; j < high; j++) {
if (array[j] <= pivot) {
i++;
int temp = array[i];
array[i] = array[j];
array[j] = temp;
}
}
int temp = array[i + 1];
array[i + 1] = array[high];
array[high] = temp;
return i + 1;
}
@Override
public String getSortName() {
return "快速排序";
}
}
// 归并排序策略
class MergeSortStrategy implements SortStrategy {
@Override
public void sort(int[] array) {
mergeSort(array, 0, array.length - 1);
}
private void mergeSort(int[] array, int left, int right) {
if (left < right) {
int middle = (left + right) / 2;
mergeSort(array, left, middle);
mergeSort(array, middle + 1, right);
merge(array, left, middle, right);
}
}
private void merge(int[] array, int left, int middle, int right) {
int n1 = middle - left + 1;
int n2 = right - middle;
int[] leftArray = new int[n1];
int[] rightArray = new int[n2];
for (int i = 0; i < n1; ++i)
leftArray[i] = array[left + i];
for (int j = 0; j < n2; ++j)
rightArray[j] = array[middle + 1 + j];
int i = 0, j = 0;
int k = left;
while (i < n1 && j < n2) {
if (leftArray[i] <= rightArray[j]) {
array[k] = leftArray[i];
i++;
} else {
array[k] = rightArray[j];
j++;
}
k++;
}
while (i < n1) {
array[k] = leftArray[i];
i++;
k++;
}
while (j < n2) {
array[k] = rightArray[j];
j++;
k++;
}
}
@Override
public String getSortName() {
return "归并排序";
}
}
// 排序上下文类
class SortContext {
private SortStrategy sortStrategy;
public void setSortStrategy(SortStrategy sortStrategy) {
this.sortStrategy = sortStrategy;
}
public void executeSort(int[] array) {
if (sortStrategy != null) {
System.out.println("使用 " + sortStrategy.getSortName() + " 排序:");
System.out.println("排序前: " + Arrays.toString(array));
sortStrategy.sort(array);
System.out.println("排序后: " + Arrays.toString(array));
} else {
System.out.println("未选择排序策略");
}
}
}
// 排序策略选择器
class SortStrategySelector {
public static SortStrategy selectStrategy(int[] array) {
// 根据数组大小选择不同的排序策略
if (array.length <= 10) {
System.out.println("数组较小,选择冒泡排序");
return new BubbleSortStrategy();
} else if (array.length <= 1000) {
System.out.println("数组中等,选择快速排序");
return new QuickSortStrategy();
} else {
System.out.println("数组较大,选择归并排序");
return new MergeSortStrategy();
}
}
}
// 使用示例
public class SortingStrategyDemo {
public static void main(String[] args) {
// 创建排序上下文
SortContext sortContext = new SortContext();
// 测试小数组
System.out.println("=== 小数组排序演示 ===");
int[] smallArray = {64, 34, 25, 12, 22, 11, 90};
SortStrategy smallStrategy = SortStrategySelector.selectStrategy(smallArray);
sortContext.setSortStrategy(smallStrategy);
sortContext.executeSort(smallArray.clone());
// 测试中等数组
System.out.println("\n=== 中等数组排序演示 ===");
int[] mediumArray = new int[100];
Random random = new Random();
for (int i = 0; i < mediumArray.length; i++) {
mediumArray[i] = random.nextInt(1000);
}
SortStrategy mediumStrategy = SortStrategySelector.selectStrategy(mediumArray);
sortContext.setSortStrategy(mediumStrategy);
sortContext.executeSort(mediumArray.clone());
// 测试大数组
System.out.println("\n=== 大数组排序演示 ===");
int[] largeArray = new int[5000];
for (int i = 0; i < largeArray.length; i++) {
largeArray[i] = random.nextInt(10000);
}
SortStrategy largeStrategy = SortStrategySelector.selectStrategy(largeArray);
sortContext.setSortStrategy(largeStrategy);
long startTime = System.currentTimeMillis();
sortContext.executeSort(largeArray.clone());
long endTime = System.currentTimeMillis();
System.out.println("排序耗时: " + (endTime - startTime) + " 毫秒");
}
}
Java中的策略模式应用
// Java Collections框架中的策略模式示例
import java.util.*;
// 自定义比较策略
class NameComparator implements Comparator<String> {
@Override
public int compare(String s1, String s2) {
return s1.compareTo(s2);
}
}
class LengthComparator implements Comparator<String> {
@Override
public int compare(String s1, String s2) {
return Integer.compare(s1.length(), s2.length());
}
}
class ReverseComparator implements Comparator<String> {
@Override
public int compare(String s1, String s2) {
return s2.compareTo(s1);
}
}
// 使用示例
public class JavaCollectionsStrategyDemo {
public static void main(String[] args) {
List<String> names = new ArrayList<>();
names.add("Alice");
names.add("Bob");
names.add("Charlie");
names.add("David");
names.add("Eve");
System.out.println("=== Java Collections策略模式演示 ===");
System.out.println("原始列表: " + names);
// 使用不同的比较策略排序
System.out.println("\n1. 按字母顺序排序:");
Collections.sort(names, new NameComparator());
System.out.println(names);
System.out.println("\n2. 按长度排序:");
Collections.sort(names, new LengthComparator());
System.out.println(names);
System.out.println("\n3. 按字母逆序排序:");
Collections.sort(names, new ReverseComparator());
System.out.println(names);
// 使用Lambda表达式(Java 8+)
System.out.println("\n4. 使用Lambda表达式按长度排序:");
Collections.sort(names, (s1, s2) -> Integer.compare(s1.length(), s2.length()));
System.out.println(names);
// 使用方法引用
System.out.println("\n5. 使用方法引用按字母顺序排序:");
Collections.sort(names, String::compareTo);
System.out.println(names);
}
}
// 线程池执行策略示例
import java.util.concurrent.*;
// 自定义拒绝策略
class CustomRejectedExecutionHandler implements RejectedExecutionHandler {
@Override
public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
System.out.println("任务被拒绝: " + r.toString());
// 可以选择其他处理方式,如记录日志、放入队列等
try {
// 尝试将任务放入队列
executor.getQueue().put(r);
System.out.println("任务已放入队列");
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
System.out.println("任务放入队列时被中断");
}
}
}
// 使用示例
public class ThreadPoolStrategyDemo {
public static void main(String[] args) throws InterruptedException {
System.out.println("=== 线程池策略模式演示 ===");
// 创建不同策略的线程池
// 1. 默认策略(AbortPolicy)
ThreadPoolExecutor executor1 = new ThreadPoolExecutor(
2, 4, 60L, TimeUnit.SECONDS,
new LinkedBlockingQueue<>(2),
Executors.defaultThreadFactory(),
new ThreadPoolExecutor.AbortPolicy()
);
// 2. 自定义策略
ThreadPoolExecutor executor2 = new ThreadPoolExecutor(
2, 4, 60L, TimeUnit.SECONDS,
new LinkedBlockingQueue<>(2),
Executors.defaultThreadFactory(),
new CustomRejectedExecutionHandler()
);
// 提交任务
System.out.println("1. 使用默认拒绝策略:");
for (int i = 0; i < 10; i++) {
final int taskId = i;
try {
executor1.submit(() -> {
System.out.println("任务 " + taskId + " 正在执行");
try {
Thread.sleep(2000);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
System.out.println("任务 " + taskId + " 执行完成");
});
} catch (RejectedExecutionException e) {
System.out.println("任务 " + taskId + " 被拒绝: " + e.getMessage());
}
}
Thread.sleep(5000);
executor1.shutdown();
System.out.println("\n2. 使用自定义拒绝策略:");
for (int i = 0; i < 10; i++) {
final int taskId = i;
executor2.submit(() -> {
System.out.println("任务 " + taskId + " 正在执行");
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
System.out.println("任务 " + taskId + " 执行完成");
});
}
Thread.sleep(10000);
executor2.shutdown();
}
}
总结
策略模式是一种非常实用的行为型设计模式,它定义了一系列算法,并将每个算法封装起来,使它们可以相互替换。策略模式让算法的变化独立于使用算法的客户,使得算法可以独立于客户端而变化。
使用策略模式的关键点:
- 识别需要不同实现的算法族
- 定义策略接口和具体策略类
- 在上下文类中维护对策略对象的引用
- 客户端可以在运行时选择不同的策略
策略模式的优点是算法可以自由切换,避免使用多重条件判断,扩展性良好,符合开闭原则。但需要注意的是,会增加系统中类的数量,客户端必须知道所有的策略类。在现代Java开发中,策略模式广泛应用于支付系统、排序算法、缓存策略、验证规则等需要算法切换的场景。Java标准库中的Collections框架、线程池等都大量使用了策略模式。