# Author's note (original: "亲测,很准"): tested personally; the predictions are quite accurate.
from sklearn import svm, datasets
class Dataset:
    """Load one of scikit-learn's bundled datasets and split it by position.

    The dataset is selected by ``name``; the names handled here are
    ``'iris'`` and ``'digits'``.
    """

    def __init__(self, name):
        """Remember which dataset to load; nothing is loaded yet."""
        self.name = name

    def download_data(self):
        """Load the dataset selected by ``self.name`` into ``self.downloaded_data``.

        Raises:
            ValueError: if ``self.name`` is not ``'iris'`` or ``'digits'``.
        """
        if self.name == 'iris':
            self.downloaded_data = datasets.load_iris()
        elif self.name == 'digits':
            self.downloaded_data = datasets.load_digits()
        else:
            # Fail here with a clear message instead of printing and letting
            # a later attribute access on ``downloaded_data`` blow up.
            raise ValueError(
                "Dataset error: no dataset named {!r} "
                "(expected 'iris' or 'digits')".format(self.name)
            )

    def generate_xy(self):
        """Load the dataset, print its contents, and return ``(x, y)``.

        ``x`` is the loaded object's ``data`` attribute and ``y`` its
        ``target`` attribute. The dataset is (re)loaded on every call.
        """
        self.download_data()
        x = self.downloaded_data.data
        y = self.downloaded_data.target
        print('\nOrginal data looks like this: \n', x)
        print('\nLabels looks like this: \n', y)
        return x, y

    def get_train_test_set(self, ratio):
        """Split the dataset into leading training and trailing test portions.

        Args:
            ratio: fraction of the samples (between 0 and 1 inclusive) that
                goes into the training portion; the count is truncated to an
                integer.

        Returns:
            The tuple ``(x_train, y_train, x_test, y_test)``.

        Raises:
            ValueError: if ``ratio`` is outside ``[0, 1]`` or the dataset name
                is not recognised.
        """
        # Validate before loading so a bad ratio fails fast; a negative ratio
        # would otherwise turn into a negative slice index.
        if not 0 <= ratio <= 1:
            raise ValueError(
                "ratio must be between 0 and 1, got {!r}".format(ratio)
            )

        x, y = self.generate_xy()
        n_train = int(len(x) * ratio)
        # NOTE(review): the split is purely positional (no shuffling). If the
        # samples are stored grouped by label, the test portion may not be
        # representative - confirm this is acceptable for the chosen dataset.
        return x[:n_train], y[:n_train], x[n_train:], y[n_train:]
# Load the digits dataset; the first 70% of the samples become the training
# portion and the remainder the test portion.
data = Dataset('digits')
x_train, y_train, x_test, y_test = data.get_train_test_set(0.7)

# Pick one sample from the test portion and show its features and true label.
test_point = x_test[12]
y_true = y_test[12]
print('*****')
print(test_point)
print(y_true)
print('*****')

# Fit a support-vector classifier with its default settings.
clf = svm.SVC()
clf.fit(x_train, y_train)

# predict() expects a 2-D input of shape (n_samples, n_features), so the
# single sample is wrapped in a list to form a one-row batch.
print(clf.predict([test_point]))