python—将array格式图片保存至文件夹中

将array格式的图像保存至路径中

灰度数字图像是每个像素只有一个采样颜色的图像。这类图像通常显示为从最暗的黑色到最亮的白色之间的灰度。

def save_image(im, i):
    """Invert a grayscale image array and save it as ``HandWritten/<i>.png``.

    Args:
        im: NumPy array of pixel values. It is subtracted from 255 and cast
            to ``uint8``, so values are presumably in the 0-255 range.
        i: Integer index used as the PNG file name.
    """
    # Build the path with os.path.join so it works on every OS; a hard-coded
    # '\\' separator is only a directory separator on Windows.
    output_path = os.path.join('.', 'HandWritten')
    # Create the directory if it is missing (no error when it already exists).
    os.makedirs(output_path, exist_ok=True)
    # Invert the image, then convert the array to 8-bit unsigned integers.
    a = (255 - im).astype(np.uint8)
    # Image.fromarray() is the inverse of np.array(image); Image.save() writes
    # the picture to the given path.
    Image.fromarray(a).save(os.path.join(output_path, '%d.png' % i))

案例

链接—提取码:1234

#!/usr/bin/python
# -*- coding:utf-8 -*-

import numpy as np
from sklearn import svm
import matplotlib.colors
import matplotlib.pyplot as plt
from PIL import Image
import warnings
from sklearn.metrics import accuracy_score
import pandas as pd
import os
import csv
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from time import time
from pprint import pprint


def save_image(im, i):
    """Invert a grayscale image array and save it as ``HandWritten/<i>.png``.

    Args:
        im: NumPy array of pixel values. It is subtracted from 255 and cast
            to ``uint8``, so values are presumably in the 0-255 range.
        i: Integer index used as the PNG file name.
    """
    # os.path.join keeps the path portable; a literal '\\' separator is only
    # treated as a directory separator on Windows.
    output_path = os.path.join('.', 'HandWritten')
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(output_path, exist_ok=True)
    inverted = (255 - im).astype(np.uint8)
    Image.fromarray(inverted).save(os.path.join(output_path, '%d.png' % i))


def save_result(model, data=None, output_path='Prediction.csv'):
    """Predict labels with ``model`` and write them to a two-column CSV file.

    Args:
        model: Fitted estimator exposing ``predict(data)``.
        data: Samples to predict. When ``None`` the module-level
            ``data_test`` loaded by the script is used.
        output_path: Destination CSV file; defaults to ``Prediction.csv`` in
            the current working directory.

    The file has a header row ``ImageId,Label`` followed by one row per
    prediction.
    """
    if data is None:
        data = data_test
    predictions = model.predict(data)
    # The csv module needs a text-mode file opened with newline='' on
    # Python 3; binary mode ('wb') makes writerow() raise TypeError.
    with open(output_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['ImageId', 'Label'])
        # NOTE(review): ids start at 0 here; confirm whether the consumer of
        # this file expects 1-based ImageId values.
        writer.writerows(enumerate(predictions))


if __name__ == "__main__":
    # Script flow: load the MNIST CSV files, preview a few digits, train the
    # selected classifier, report accuracy, then plot misclassified digits.

    # Silence library warnings so the console output stays readable.
    warnings.filterwarnings(action='ignore')
    # Model to train: 'SVM' (grid-searched RBF SVC) or 'RF' (random forest).
    classifier_type = 'RF'

    print('载入训练数据...')
    t = time()
    # The np.int alias was removed in NumPy 1.24; the builtin int is the
    # drop-in replacement.
    data = pd.read_csv(os.path.join('.', 'MNIST.train.csv'), header=0, dtype=int)
    print('载入完成,耗时%f秒' % (time() - t))
    y = data['label'].values.ravel()
    # Every column after the first is treated as a pixel feature.
    x = data.values[:, 1:]
    print('图片个数:%d,图片像素数目:%d' % x.shape)

    print('载入测试数据...')
    t = time()
    # data_test stays a module-level name because save_result() reads it.
    data_test = pd.read_csv(os.path.join('.', 'MNIST.test.csv'), header=0, dtype=int)
    data_test = data_test.values
    images_test_result = data_test.reshape(-1, 28, 28)
    print('载入完成,耗时%f秒' % (time() - t))

    np.random.seed(0)
    # Hold out 20% of the labelled data to measure generalisation.
    x, x_test, y, y_test = train_test_split(x, y, train_size=0.8, random_state=1)
    images = x.reshape(-1, 28, 28)
    images_test = x_test.reshape(-1, 28, 28)
    print(x.shape, x_test.shape)

    # SimHei lets matplotlib render the Chinese titles.
    matplotlib.rcParams['font.sans-serif'] = [u'SimHei']
    matplotlib.rcParams['axes.unicode_minus'] = False
    plt.figure(figsize=(15, 9), facecolor='w')
    # Top two rows: 16 training digits with their labels.
    for index, image in enumerate(images[:16]):
        plt.subplot(4, 8, index + 1)
        plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        plt.title(u'训练图片: %i' % y[index])
    # Bottom two rows: 16 unlabelled test digits, each also saved as a PNG.
    for index, image in enumerate(images_test_result[:16]):
        plt.subplot(4, 8, index + 17)
        plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        save_image(image.copy(), index)
        plt.title(u'测试图片')
    plt.tight_layout()
    plt.show()

    if classifier_type == 'SVM':
        # Grid-search C and gamma of an RBF-kernel SVC with 3-fold CV.
        params = {
            'C': np.logspace(1, 4, 4, base=10),
            'gamma': np.logspace(-10, -2, 9, base=10),
        }
        clf = svm.SVC(kernel='rbf')
        model = GridSearchCV(clf, param_grid=params, cv=3)
        print('SVM开始训练...')
        t = time()
        model.fit(x, y)
        minutes, seconds = divmod(time() - t, 60)
        print('SVM训练结束,耗时%d分钟%.3f秒' % (minutes, seconds))
        print('最优分类器:', model.best_estimator_)
        print('最优参数:\t', model.best_params_)
        print('model.cv_results_ =', model.cv_results_)

        t = time()
        y_hat = model.predict(x)
        minutes, seconds = divmod(time() - t, 60)
        print('SVM训练集准确率:%.3f%%,耗时%d分钟%.3f秒' % (accuracy_score(y, y_hat)*100, minutes, seconds))
        t = time()
        y_test_hat = model.predict(x_test)
        minutes, seconds = divmod(time() - t, 60)
        print('SVM测试集准确率:%.3f%%,耗时%d分钟%.3f秒' % (accuracy_score(y_test, y_test_hat)*100, minutes, seconds))
        save_result(model)
    elif classifier_type == 'RF':
        # min_impurity_split was removed in scikit-learn 1.0, so it is no
        # longer passed; the remaining defaults are used instead.
        rfc = RandomForestClassifier(n_estimators=100, criterion='gini', min_samples_split=2,
                                     bootstrap=True, oob_score=True)
        print('随机森林开始训练...')
        t = time()
        rfc.fit(x, y)
        minutes, seconds = divmod(time() - t, 60)
        print('随机森林训练结束,耗时%d分钟%.3f秒' % (minutes, seconds))
        print('OOB准确率:%.3f%%' % (rfc.oob_score_*100))
        t = time()
        y_hat = rfc.predict(x)
        t = time() - t
        print('随机森林训练集准确率:%.3f%%,预测耗时:%d秒' % (accuracy_score(y, y_hat)*100, t))
        t = time()
        y_test_hat = rfc.predict(x_test)
        t = time() - t
        print('随机森林测试集准确率:%.3f%%,预测耗时:%d秒' % (accuracy_score(y_test, y_test_hat)*100, t))
    else:
        # Fail with a clear message instead of a NameError on y_test_hat below.
        raise ValueError('unknown classifier_type: %r' % classifier_type)

    # Collect the held-out samples the model got wrong and show the first 12.
    err = (y_test != y_test_hat)
    err_images = images_test[err]
    err_y_hat = y_test_hat[err]
    err_y = y_test[err]
    print(err_y_hat)
    print(err_y)
    plt.figure(figsize=(10, 8), facecolor='w')
    for index, image in enumerate(err_images[:12]):
        plt.subplot(3, 4, index + 1)
        plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
        plt.title(u'错分为:%i,真实值:%i' % (err_y_hat[index], err_y[index]))
    plt.suptitle(u'数字图片手写体识别:分类器%s' % classifier_type, fontsize=18)
    # Leave the top 5% of the figure free for the suptitle.
    plt.tight_layout(rect=(0, 0, 1, 0.95))
    plt.show()

载入训练数据...
载入完成,耗时1.746520秒
图片个数:42000,图片像素数目:784
载入测试数据...
载入完成,耗时1.179636秒
(33600, 784) (8400, 784)
随机森林开始训练...
随机森林训练结束,耗时0分钟18.557秒
OOB准确率:95.821%
随机森林训练集准确率:100.000%,预测耗时:1秒
随机森林测试集准确率:96.512%,预测耗时:0秒
[2 6 0 7 8 7 0 7 5 0 2 2 9 3 9 8 2 8 7 0 4 9 3 9 2 8 9 7 4 4 1 5 8 7 5 3 2
 4 0 1 8 7 3 2 6 5 8 4 9 2 7 8 5 2 4 9 9 6 8 2 5 3 9 2 5 1 4 6 2 8 8 0 8 2
 2 1 9 8 4 9 3 3 2 8 7 6 8 9 7 3 5 3 1 2 3 9 9 3 9 8 4 4 7 8 3 3 7 3 4 4 0
 9 9 1 7 4 9 5 2 8 8 3 5 5 8 5 1 3 6 2 7 7 3 6 4 3 4 5 0 7 4 9 5 1 4 3 3 5
 8 7 9 2 0 8 3 3 2 6 5 9 9 9 8 1 3 7 1 5 5 3 4 1 2 9 5 2 8 3 1 4 4 3 2 8 3
 4 4 3 9 2 5 7 1 7 6 8 0 5 9 5 6 5 0 8 8 7 0 6 4 8 7 9 4 3 2 4 4 2 8 6 3 1
 9 2 9 6 2 8 9 5 8 4 4 0 0 2 2 6 9 7 9 4 5 0 6 2 3 6 5 9 9 9 2 5 2 9 9 8 8
 5 4 2 3 9 3 7 9 0 5 0 9 5 3 2 4 3 9 8 8 3 9 5 2 7 9 2 8 5 8 5 4 5 8]
[4 5 6 4 2 2 5 2 9 7 7 3 7 9 4 3 7 5 3 5 6 4 5 3 7 2 4 4 6 8 4 8 9 8 3 5 3
 0 5 8 2 3 8 9 0 6 9 8 5 3 3 2 9 1 9 8 5 5 3 7 3 5 4 3 6 8 2 4 3 1 3 3 5 3
 4 8 8 2 7 4 9 5 4 9 3 2 3 8 2 8 3 5 8 3 9 7 3 2 7 9 2 8 9 4 9 1 0 9 6 9 9
 4 4 7 9 9 4 8 7 3 5 5 9 3 6 9 8 5 4 7 5 3 5 5 9 9 9 3 2 5 8 4 6 8 9 5 9 0
 3 3 2 7 8 4 9 5 1 4 6 3 4 8 9 5 1 9 7 8 3 1 9 8 3 7 3 7 3 8 3 9 9 1 3 3 2
 9 7 5 7 0 3 2 9 9 8 6 6 8 4 3 4 3 9 2 0 9 5 5 7 9 2 3 9 9 1 9 7 1 9 8 9 8
 4 7 7 0 1 7 7 3 9 3 9 4 9 3 3 2 4 9 4 8 1 5 5 3 8 4 1 7 7 7 3 8 3 7 5 3 0
 3 7 7 2 7 8 4 4 2 9 3 4 8 1 8 9 5 7 5 3 1 7 8 3 9 8 3 2 6 5 6 8 8 3]

在这里插入图片描述

载入训练数据...
载入完成,耗时1.661703秒
图片个数:42000,图片像素数目:784
载入测试数据...
载入完成,耗时1.106137秒
(33600, 784) (8400, 784)
SVM开始训练...
SVM训练结束,耗时2分钟21.074秒
SVM训练集准确率:94.676%,耗时4分钟13.307秒
SVM测试集准确率:93.810%,耗时1分钟3.327秒

猜你喜欢

转载自blog.csdn.net/weixin_46649052/article/details/112758955