基于SVM的数字手势识别模型

1. 通过数据增强将一张图片扩展为n张图片

import os

from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img

# Source photo of one hand gesture, and the digit it shows.
SOURCE_IMAGE = r'E:\1.jpg'
LABEL = 1
# Directory that receives the augmented copies (read by step 2).
AUGMENT_DIR = r'E:\Pictures_ML'
# How many augmented images to write (the original counter loop wrote 11).
NUM_AUGMENTED = 11

# Random geometric transforms used to synthesise extra training samples.
datagen = ImageDataGenerator(
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest')

img = load_img(SOURCE_IMAGE)
x = img_to_array(img)
# flow() expects a leading batch axis, so add one of size 1.
x = x.reshape((1,) + x.shape)

# Keras writes into save_to_dir but does not create it.
os.makedirs(AUGMENT_DIR, exist_ok=True)

# Keras names saved files '<prefix>_<index>_<hash>.<format>'. Step 3 parses the
# label as the integer before the first '_', so the prefix must be the digit
# label (the original prefix 'num' made int('num') fail there).
augmented = datagen.flow(x, batch_size=1,
                         save_to_dir=AUGMENT_DIR,
                         save_prefix=str(LABEL),
                         save_format='jpeg')

# The iterator never ends on its own; each next() saves one transformed image.
for _ in range(NUM_AUGMENTED):
    next(augmented)

2. 将图片缩放为64×48像素（宽64，高48）

import os
from glob import glob

from PIL import Image

# Augmented images from step 1 are read from source_dir; resized copies go to target_dir.
source_dir = r'E:\Pictures_ML'
target_dir = r'E:\Pictures_64'
# Output size (width, height) in pixels; step 3 flattens these into fixed-length vectors.
new_width = 64
new_height = 48
# Only files of at least this many bytes are resized; smaller files are skipped.
threshold = 2 * 64 * 64

filenames = glob(os.path.join(source_dir, '*'))
print(filenames)

# Report pixel dimensions and byte size of every input file, once each.
for filename in filenames:
    with Image.open(filename) as im:
        width, height = im.size
    print(filename, width, height, os.path.getsize(filename))

# List the files that pass the size threshold (printed once, not once per file).
for filename in filenames:
    if os.path.getsize(filename) >= threshold:
        print(filename)

os.makedirs(target_dir, exist_ok=True)

for filename in filenames:
    if os.path.getsize(filename) < threshold:
        # NOTE(review): skipped files are not copied to target_dir, so step 3
        # never sees them — confirm this filtering is intended.
        continue
    print(filename)
    with Image.open(filename) as im:
        print('adjusted size:', new_width, new_height)
        resized_im = im.resize((new_width, new_height))
        # Build the output path from the base name instead of string-replacing
        # the directory prefix, which breaks if separators differ.
        output_filename = os.path.join(target_dir, os.path.basename(filename))
        resized_im.save(output_filename)

3. 将文件夹内的图片转换为HDF5（.h5）格式

from PIL import Image
import os
import numpy
import h5py
from sklearn.model_selection import train_test_split

image_dir = r"E:\Pictures_64"
output_path = r"E:\filename.h5"

Y = []
X = []
# Sorted so the sample order (and hence the seeded split in step 4) is reproducible.
for filename in sorted(os.listdir(image_dir)):
    print(filename)
    # The label is the integer before the first '_' in the file name.
    label = int(filename.split('_')[0])
    print(label)
    Y.append(label)
    # Close each image file after reading its pixels.
    with Image.open(os.path.join(image_dir, filename)) as im:
        mat = numpy.asarray(im.convert('RGB'))
    # One row per image: all pixel channels flattened into a 1-D vector.
    X.append(mat.flatten())
print('-------------------------分割线------------------------------------')

# Mode "w" recreates the file. The original "a" mode made create_dataset
# raise on every re-run because datasets 'X' and 'Y' already existed.
with h5py.File(output_path, "w") as file:
    file.create_dataset('X', data=numpy.array(X))
    file.create_dataset('Y', data=numpy.array(Y))

4. 读取h5文件内容并训练模型

from PIL import Image
import os
import numpy
import h5py
import time
from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn import datasets
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC
try:
    import joblib
except ImportError:
    # Older scikit-learn releases vendored joblib here; removed in 0.23.
    from sklearn.externals import joblib

# Load the whole dataset into memory, then close the HDF5 file.
with h5py.File(r"E:\filename.h5", "r") as data:
    X_data = data['X']
    Y_data = data['Y']
    print(X_data.shape)
    print(Y_data.shape)
    X_arr = numpy.array(X_data)
    Y_arr = numpy.array(Y_data)

print('-------------------------分割线------------------------------------')

X_train, X_test, y_train, y_test = train_test_split(X_arr, Y_arr, test_size=0.25, random_state=10)
print(y_train.shape)
print(y_test.shape)

print('-------------------------分割线------------------------------------')
# Scaler and classifier are bundled so every predict() call applies the
# standardization learned on the training split. Previously the classifier was
# fit on scaled data but later applied to raw pixels, and the scaler was never
# saved.
lsvc = make_pipeline(StandardScaler(), LinearSVC())
lsvc.fit(X_train, y_train)
print ('The Accuracy of Linear SVC is', lsvc.score(X_test, y_test))

print('-------------------------分割线------------------------------------')
z_predict = lsvc.predict(X_arr)
# Spot-check one sample. Guarded so a small dataset cannot raise IndexError
# before the model is saved below.
sample_index = 200
if sample_index < len(z_predict):
    print(z_predict[sample_index])

# File name of the saved model.
file = r'E:\svm.joblib'
# Save the model (scaler + classifier).
joblib.dump(lsvc, file)
# Read it back to confirm the file round-trips.
svm_model = joblib.load(file)

5.读取保存的模型并测试

from PIL import Image
import os
import numpy
import h5py
import time
from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn import datasets
try:
    import joblib
except ImportError:
    # Older scikit-learn releases vendored joblib here; removed in 0.23.
    from sklearn.externals import joblib

# File name of the saved model.
file = r"E:\svm.joblib"
# Load the model.
lsvc = joblib.load(file)

# Load the dataset into memory, then close the HDF5 file.
with h5py.File(r"E:\filename.h5", "r") as data:
    X_arr = numpy.array(data['X'])
    Y_arr = numpy.array(data['Y'])

# The model is applied directly to raw flattened pixel vectors, so it must
# include whatever feature scaling it was trained with.
Z_predict = lsvc.predict(X_arr)
# Spot-check one sample, guarded against datasets with fewer rows.
sample_index = 209
if sample_index < len(Z_predict):
    print(Z_predict[sample_index])

# Unrelated to the model: opens C:\Program Files in Windows Explorer.
# Raw string avoids the invalid '\p' escape; quotes keep the spaced path intact.
os.system(r'explorer "c:\program files"')

猜你喜欢

转载自 https://blog.csdn.net/qq_43182766/article/details/89281433