c#调用python方式完成训练和预测

在前面系列博客基础上,来继续完善博客C#和python通过socket方法进行通信_jiugeshao的专栏-CSDN博客中的程序,使得c#可以通过调用python来完成模型的训练,训练完毕后可以实时预测一张图像的结果。

1.界面改为:

2. 代码中的引用类如下:

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Data;
using System.Drawing;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Windows.Forms;
using System.IO;
using System.Diagnostics;
using System.Collections;
using System.Net.Sockets;
using System.Threading;
using System.Runtime.InteropServices;
using System.Drawing.Imaging;

3. 成员变量如下:

[DllImport("user32.dll", CharSet = CharSet.Auto)]
        public static extern IntPtr SendMessage(IntPtr hWnd, int Msg, int wParam, IntPtr lParam);

        [DllImport("Kernel32.dll", CharSet = CharSet.Auto)]
        public static extern IntPtr CreateFileMapping(int hFile, IntPtr lpAttributes, uint flProtect, uint dwMaxSizeHi, uint dwMaxSizeLow, string lpName);

        [DllImport("Kernel32.dll", CharSet = CharSet.Auto)]
        public static extern IntPtr OpenFileMapping(int dwDesiredAccess, [MarshalAs(UnmanagedType.Bool)] bool bInheritHandle, string lpName);

        [DllImport("Kernel32.dll", CharSet = CharSet.Auto)]
        public static extern IntPtr MapViewOfFile(IntPtr hFileMapping, uint dwDesiredAccess, uint dwFileOffsetHigh, uint dwFileOffsetLow, uint dwNumberOfBytesToMap);

        [DllImport("Kernel32.dll", CharSet = CharSet.Auto)]
        public static extern bool UnmapViewOfFile(IntPtr pvBaseAddress);

        [DllImport("Kernel32.dll", CharSet = CharSet.Auto)]
        public static extern bool CloseHandle(IntPtr handle);

        [DllImport("kernel32", EntryPoint = "GetLastError")]
        public static extern int GetLastError();

        const int INVALID_HANDLE_VALUE = -1;
        const int ERROR_ALREADY_EXISTS = 183;
        const int FILE_MAP_WRITE = 0x0002;
        int m_MemSize = 0;

        IntPtr m_hSharedMemoryFile = IntPtr.Zero;
        IntPtr m_pwData = IntPtr.Zero;
        const int PAGE_READWRITE = 0x04;
        string sharememoryName = "img";


        private string modelPath;
        private string dataFolderPath;
        private delegate void UpdateString(string text);
        private TcpClient _client;     
        private Thread _connectionThread;   // Thread that is responsible for identifying client connection requests.
        private long _totalBytes; // record the total number of bytes received
        string predictImgPath;

下面只说明博客5中有改动的按钮的回调函数

3.Init NetWork按钮回调函数代码:

  private void button1_InitNetWork(object sender, EventArgs e)
        {
            try
            {
                // 获取图片大小从而知道创建多大的共享内存
                byte[] vs = GetPictureData(@"D:\CNN\train\0.0.jpg");
                int lngSize = vs.Length;
                m_MemSize = lngSize;
                Console.WriteLine("m_MemSize: " + m_MemSize);
                if (lngSize <= 0 || lngSize > 100000000)
                {
                    MessageBox.Show("error,check the size!");
                    return;
                }

                //创建内存共享体(INVALID_HANDLE_VALUE)

                m_hSharedMemoryFile = CreateFileMapping(INVALID_HANDLE_VALUE, IntPtr.Zero, (uint)PAGE_READWRITE, 0, (uint)lngSize, sharememoryName);
                if (m_hSharedMemoryFile == IntPtr.Zero)
                {
                    MessageBox.Show("创建失败!");
                    return;
                }
                else
                {
                    if (GetLastError() == ERROR_ALREADY_EXISTS)  //已经创建
                    {
                        MessageBox.Show("error, 已经创建过!");
                        return;
                    }
                    else                                         //新创建
                    {
                        MessageBox.Show("新创建OK");
                    }
                }

                m_pwData = MapViewOfFile(m_hSharedMemoryFile, FILE_MAP_WRITE, 0, 0, (uint)lngSize);
                if (m_pwData == IntPtr.Zero)
                {
                    CloseHandle(m_hSharedMemoryFile);
                    MessageBox.Show("create fail"); //创建内存映射失败
                }
                else
                {
                    MessageBox.Show("sharememory create ok");
                }

                Process p = new Process();
                string path = @"D:\CNN\CsharpCallCNN.py";//待处理python文件的路径,本例中放在debug文件夹下
                string sArguments = path;
                ArrayList arrayList = new ArrayList();
                arrayList.Add(txtIP.Text);
                arrayList.Add(txtPort.Text);
                arrayList.Add(txt_modepath.Text);
                arrayList.Add(Convert.ToInt32(txtrows.Text));
                arrayList.Add(Convert.ToInt32(txtcols.Text));
                arrayList.Add(Convert.ToInt32(txtChannel.Text));
                arrayList.Add(txtTrainDataPath.Text);

                foreach (var param in arrayList)//添加参数
                {
                    sArguments += " " + param;
                }

                p.StartInfo.FileName = @"C:\Anaconda3\python.exe";
                p.StartInfo.Arguments = sArguments;//python命令的参数
                p.StartInfo.UseShellExecute = false;
                p.StartInfo.RedirectStandardOutput = true;
                p.StartInfo.RedirectStandardInput = true;
                p.StartInfo.RedirectStandardError = true;
                p.StartInfo.CreateNoWindow = true;
                p.Start();//启动进程
                p.BeginOutputReadLine();
                p.OutputDataReceived += new DataReceivedEventHandler(p_OutputDataReceived);
                //p.Close();
                // p.WaitForExit();
            }
            catch (Exception ec)
            {
                Console.WriteLine(ec);
            }
        }

GetPictureData函数代码如下:(这里训练和测试的图片是单通道,如果是三通道图,可以按我上一篇博客来设)

public byte[] GetPictureData(string imagepath)
        {
            Bitmap srcBmp = new Bitmap(imagepath);
            BitmapData bmdata = srcBmp.LockBits(new Rectangle(0, 0, srcBmp.Width, srcBmp.Height), ImageLockMode.ReadWrite,
                    PixelFormat.Format8bppIndexed);
            IntPtr pSrc = bmdata.Scan0;
            int iBytes = srcBmp.Width * srcBmp.Height;
            byte[] s = new byte[iBytes];

            System.Runtime.InteropServices.Marshal.Copy(pSrc, s, 0, iBytes);
            return s;
        }

4.TestDataPath后面的choose按钮回调函数代码如下:

  private void button1_Click(object sender, EventArgs e)
        {

            listBox2.Items.Clear();
            FolderBrowserDialog path = new FolderBrowserDialog();
            path.SelectedPath = "D:\\CNN\\test";
            path.ShowDialog();
            Console.WriteLine("from c#: TestData Path-> " + path.SelectedPath);
            DirectoryInfo info = new DirectoryInfo(path.SelectedPath);
            textBox1.Text = path.SelectedPath;
            dataFolderPath = path.SelectedPath;
            foreach (FileInfo fi in info.GetFiles())
            {
                if (fi.Extension == ".jpg" || fi.Extension == ".bmp" || fi.Extension == ".PNG" || fi.Extension == ".png" || fi.Extension == ".gif" || fi.Extension == ".brw" || fi.Extension == ".JPG" || fi.Extension == ".BMP" || fi.Extension == ".GIF" || fi.Extension == ".GIF" || fi.Extension == "BRW")
                {
                    //Console.WriteLine(fi.ToString());
                    listBox2.Items.Add(fi.ToString());
                }
            }

            MessageBox.Show("load test img over");
        }

5. listbox2选项选中回调函数代码如下:

 private void listBox2_SelectedIndexChanged(object sender, EventArgs e)
        {
            predictImgPath = dataFolderPath + "\\" + listBox2.SelectedItem.ToString();
            Console.WriteLine(predictImgPath);
        }

6. predict按钮回调函数代码如下:

/ private void btnPredict_Click(object sender, EventArgs e)
        {
            //图片写入共享内存
            //把图片转换为byte数组
            byte[] vs = GetPictureData(predictImgPath);
            int lngAddr = 0;
            int lngSize = vs.Length;
            if (lngAddr + lngSize > m_MemSize)
                MessageBox.Show("超出数据区");

            DateTime SingleStartTime;
            TimeSpan SingleTimeSpan;
            string SingleCT;
            SingleStartTime = DateTime.Now;
            Marshal.Copy(vs, lngAddr, m_pwData, lngSize);
            SingleTimeSpan = DateTime.Now - SingleStartTime;
            SingleCT = (SingleTimeSpan.TotalMilliseconds / 1000).ToString("0.000");
            Console.WriteLine("写耗时" + SingleCT + "ms");

            MessageBox.Show("write img to sharememory success");


            NetworkStream netStream = null;
            netStream = _client.GetStream();
            byte[] message = Encoding.ASCII.GetBytes("predict" + "\r\n");
            netStream.Write(message, 0, message.Length);
            netStream.Flush();
        }

D:\\CNN目录下的文件如下:

_example.cp36-win_amd64.pcd和example.py是上一篇博客中所生成的文件,可以直接拿来用

CsharpCallCNN.py中的代码如下:

import sys

import socket
import time
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='3'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import cv2
import numpy as np
import pydot
import graphviz

import  tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Activation, PReLU
from tensorflow.keras.optimizers import SGD, Adadelta, Adagrad
from tensorflow.keras.utils import plot_model
from tensorflow.keras.models import model_from_json, load_model
import h5py

import example
from PIL import Image, ImageStat

from tensorflow.python.keras.utils import np_utils


def cal(num1, num2, op):
    if op == 1:
        result = num1 + num2
    elif op == 2:
        result = num1 - num2
    elif op == 3:
        result = num1 * num2
    else:
        result = num1 - num2
    return str(result)

def loadModel():
    # 建立一个Sequential模型
    model = Sequential()

    # model.add(Conv2D(4, 5, 5, border_mode='valid',input_shape=(28,28,1)))
    # 第一个卷积层,4个卷积核,每个卷积核5*5,卷积后24*24,第一个卷积核要申明input_shape(通道,大小) ,激活函数采用“tanh”
    model.add(Conv2D(filters=4, kernel_size=(5, 5), padding='valid', input_shape=(28, 28, 1), activation='tanh'))
    # model.add(Conv2D(8, 3, 3, subsample=(2,2), border_mode='valid'))
    # 第二个卷积层,8个卷积核,不需要申明上一个卷积留下来的特征map,会自动识别,下采样层为2*2,卷完且采样后是11*11
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(filters=8, kernel_size=(3, 3), padding='valid', activation='tanh'))
    # model.add(Activation('tanh'))
    # model.add(Conv2D(16, 3, 3, subsample=(2,2), border_mode='valid'))
    # 第三个卷积层,16个卷积核,下采样层为2*2,卷完采样后是4*4
    model.add(Conv2D(filters=16, kernel_size=(3, 3), padding='valid', activation='tanh'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    # model.add(Activation('tanh'))
    model.add(Flatten())
    # 把多维的模型压平为一维的,用在卷积层到全连接层的过度
    # model.add(Dense(128, input_dim=(16*4*4), init='normal'))
    # 全连接层,首层的需要指定输入维度16*4*4,128是输出维度,默认放第一位
    model.add(Dense(128, activation='tanh'))
    # model.add(Activation('tanh'))
    # model.add(Dense(10, input_dim= 128, init='normal'))
    # 第二层全连接层,其实不需要指定输入维度,输出为10维,因为是10类
    model.add(Dense(10, activation='softmax'))
    # model.add(Activation('softmax'))
    # 激活函数“softmax”,用于分类
    sys.stdout.flush()
    sgd = SGD(lr=0.05, momentum=0.9, decay=1e-6, nesterov=True)
    # 采用随机梯度下降法,学习率初始值0.05,动量参数为0.9,学习率衰减值为1e-6,确定使用Nesterov动量
    model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
    # 配置模型学习过程,目标函数为categorical_crossentropy:亦称作多类的对数损失,注意使用该目标函数时,需要将标签转化为形如(nb_samples, nb_classes)的二值序列,第18行已转化,优化器为sgd
    return model

def loadData(path, number):
    data = np.empty((number, 1, 28, 28), dtype="float32")  # empty与ones差不多原理,但是数值随机,类型随后面设定
    labels = np.empty((number,), dtype="uint8")
    listImg = os.listdir(path)
    count = 0
    for img in listImg:
        imgData = cv2.imread(path + '/' + img, 0)  # 数据
        l = int(img.split('.')[0])  # 答案
        arr = np.asarray(imgData, dtype="float32")  # 将img数据转化为数组形式
        data[count, :, :, :] = arr  # 将每个三维数组赋给data
        labels[count] = l  # 取该图像的数值属性作为标签
        count = count + 1
        path, " loaded ", count
        if count >= number:
            break
    return data, labels

def train(trainDatapath, modelpath):
    # 从图片文件加载数据
    trainData, trainLabels = loadData(trainDatapath, 42000)
    trainLabels = np_utils.to_categorical(trainLabels, 10)
    # label为0~9共10个类别,keras要求格式为binary class matrices,转化一下,直接调用keras提供的这个函数
    print(trainData.shape)
    trainData = trainData.reshape(trainData.shape[0], 28, 28, 1)
    print(trainData.shape)
    sys.stdout.flush()

    model = loadModel()
    model.fit(trainData, trainLabels, batch_size=100, epochs=10, shuffle=True, verbose=1, validation_split=0.2)
    model.save_weights(modelpath)
    print("train finished!")
    sys.stdout.flush()

def predict(modelpath, height, width, channel):
    model = loadModel()
    model.load_weights(modelpath)
    print("load mode ok")
    sys.stdout.flush()
    im = example.GetImageFromSM("img", height, width, channel)
    print(im.shape)
    sys.stdout.flush()

    print("get img ok")
    sys.stdout.flush()

    #im = im[..., ::-1]
    x = Image.fromarray(im,'L')
    x.show()

    im = im[np.newaxis, ..., np.newaxis]
    x_test1_pred = model.predict(im, batch_size=1, verbose=1)
    print("the image predict result: ", np.argmax(x_test1_pred))
    sys.stdout.flush()

if __name__ == '__main__':
    #train("D:\\CNN\\train","D:\\CNN\\CNN.h5")
    #predict("D:\\CNN\\CNN.h5", 28, 28, 1)
    print("python code has been called")
    IP = sys.argv[1]
    Port = sys.argv[2]
    modelPath = sys.argv[3]
    imgRows = int(sys.argv[4])
    imgCols = int(sys.argv[5])
    imgChannel = int(sys.argv[6])
    imgTrainDataPath = sys.argv[7]

    print("IP: ", IP)
    print("Port: ", Port)
    print("modePath: ", modelPath)
    print("imgRows: ", imgRows)
    print("imgCols: ", imgCols)
    print("imgChannel: ", imgChannel)
    print("imgTrainDataPath: ", imgTrainDataPath)
    print("param passed successfully")
    print("python code has been called, after 5s, the server will be created")
    sys.stdout.flush()

    s = socket.socket()
    s.bind(("127.0.0.1", int(10086)))
    print("the server has been created, waite the client to connect")
    s.listen(5)
    client, address = s.accept()
    print("Connect has been built successfully")
    sys.stdout.flush()
    client.send(bytes("Hi, Weclome!", 'utf-8'))
    while True:
        data = client.recv(1024)
        recv_str = data.decode()
        data_str = recv_str[0:len(recv_str) - 1]
        client.send(data_str.encode('utf-8'))

        if data_str.find("train")>=0:
            client.send("enter train function".encode('utf-8'))
            train(imgTrainDataPath, modelPath)

        if data_str.find("predict")>=0:
            client.send("enter predict function".encode('utf-8'))
            predict(modelPath,imgRows,imgCols,imgChannel)

这里也上传下该文件夹内的文件

链接:https://pan.baidu.com/s/1mmzITYGhU3yAHN8LKfFaFQ 
提取码:5s2b 
 

运行程序出现如下界面:

选择好model path和TrainDataPath路径后

点击Init NetWork按钮

可以看到python被调用起来,但其并没有以一个新窗口形式出现,输出信息直接显示在c#的控制台窗口里

完毕后点击create client to connect server按钮

再点击Train NetWork按钮

点击TestDataPath后面的Choose按钮,选择测试图片路径

选中上面4.1684.jpg图片,再点击predict按钮

可以看到图片预测结果正确。

到此c#调用python来实现网络模型的训练和预测demo结束。各位可以继续拓展。

接下来PyQt5来实现2D和3D及深度学习平台博客想等一等再写,后续博主精力还是先放在深度学习分类,检测,分割算法上面,之前也有博客总结,但想继续研究的深点。

Guess you like

Origin blog.csdn.net/jiugeshao/article/details/112093981