-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSVM.py
More file actions
97 lines (82 loc) · 3.49 KB
/
SVM.py
File metadata and controls
97 lines (82 loc) · 3.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler
# 1. Load the Iris data, keeping only the first two features (3 classes).
iris = datasets.load_iris()
feature_names = iris.feature_names[:2]
X_original = iris.data[:, :2]  # sepal length, sepal width (raw, unscaled)
y = iris.target  # 0: setosa, 1: versicolor, 2: virginica

# 2. Split into train/test sets BEFORE any scaling (still in raw units).
X_train_original, X_test_original, y_train, y_test = train_test_split(
    X_original,
    y,
    test_size=0.3,
    random_state=42,
)

# 3. Standardize the features. The scaler is fitted on the training set only,
#    so no test-set statistics leak into the model.
scaler = StandardScaler()
scaler.fit(X_train_original)
X_train = scaler.transform(X_train_original)  # scaled with train mean/std
X_test = scaler.transform(X_test_original)  # reuses the train mean/std
# 4. Grid-search the RBF-SVM hyperparameters C and gamma.
param_grid = dict(
    C=[0.1, 1, 10, 100],  # penalty parameter (larger C = weaker regularization)
    gamma=[0.01, 0.1, 1, 10],  # RBF kernel coefficient (reach of each sample)
)
svm_clf = SVC(kernel='rbf')  # SVM with an RBF kernel
grid_search = GridSearchCV(
    estimator=svm_clf,
    param_grid=param_grid,
    scoring='accuracy',
    cv=5,
    n_jobs=-1,
)
grid_search.fit(X_train, y_train)

# Best parameter combination, and the estimator GridSearchCV has already
# refit with it on the whole training set (refit=True is the default).
best_params = grid_search.best_params_
best_svm = grid_search.best_estimator_
print(f"最佳参数:C={best_params['C']}, gamma={best_params['gamma']}")

# 5. Predict on the held-out test set and report accuracy.
y_pred = best_svm.predict(X_test)
print(f"测试集准确率:{accuracy_score(y_test, y_pred):.2f}")
# 6. Visualize the decision boundary (everything is drawn in original units).
def plot_decision_boundary(model, X_original, scaler, feature_names,
                           labels=None, params=None, step=0.02):
    """Plot a fitted classifier's decision regions over two raw features.

    The mesh is built in the original (unscaled) feature space so the axes
    stay readable; every point is passed through ``scaler`` before being
    handed to ``model``, so predictions are made on the scale the model was
    trained on.

    Args:
        model: Fitted classifier exposing ``predict``.
        X_original: Array of shape (n_samples, 2) with the raw features;
            used both for the axis range and for the scatter points.
        scaler: Transformer fitted on the training set; its ``transform``
            is applied to the mesh and to ``X_original``.
        feature_names: Two axis labels (x first, then y).
        labels: True class label per row of ``X_original``. When omitted,
            the module-level ``y`` is used, so existing four-argument
            calls behave as before.
        params: Mapping with ``'C'`` and ``'gamma'`` shown in the title.
            When omitted, the module-level ``best_params`` is used.
        step: Mesh resolution in original feature units.

    Raises:
        ValueError: If ``step`` is not positive.
    """
    if step <= 0:
        raise ValueError("step must be positive")
    # Fall back to the script's globals to stay backward-compatible with
    # callers that pass only the first four arguments.
    if labels is None:
        labels = y
    if params is None:
        params = best_params
    labels = np.asarray(labels)

    # Mesh range comes from the raw data, padded by 1 unit on each side.
    x_min, x_max = X_original[:, 0].min() - 1, X_original[:, 0].max() + 1
    y_min, y_max = X_original[:, 1].min() - 1, X_original[:, 1].max() + 1
    xx, yy = np.meshgrid(
        np.arange(x_min, x_max, step),
        np.arange(y_min, y_max, step),
    )

    # Scale the mesh with the training-set scaler, predict, and fold the flat
    # predictions back into the mesh shape.
    grid_data = np.c_[xx.ravel(), yy.ravel()]
    grid_data_scaled = scaler.transform(grid_data)
    Z = model.predict(grid_data_scaled).reshape(xx.shape)

    plt.figure(figsize=(10, 8))
    # Filled decision regions.
    plt.contourf(xx, yy, Z, alpha=0.8, cmap=plt.cm.coolwarm)
    # Samples at their raw coordinates, colored by true label.
    plt.scatter(
        X_original[:, 0], X_original[:, 1],
        c=labels, edgecolors='k', marker='o', s=80,
        cmap=plt.cm.coolwarm, label='数据点'
    )

    # Ring every misclassified sample. This covers ALL rows of X_original
    # (train and test alike), not only the test set.
    y_pred_all = model.predict(scaler.transform(X_original))
    misclassified = X_original[y_pred_all != labels]
    plt.scatter(
        misclassified[:, 0], misclassified[:, 1],
        facecolors='none', edgecolors='yellow', marker='o', s=150,
        label='错误分类'
    )

    plt.xlabel(feature_names[0], fontsize=12)
    plt.ylabel(feature_names[1], fontsize=12)
    plt.title(f"RBF核SVM决策边界 (C={params['C']}, gamma={params['gamma']})", fontsize=14)
    plt.legend()
    plt.show()
# Run the visualization on the full dataset (all samples, original units).
plot_decision_boundary(best_svm, X_original, scaler, feature_names)