pyqt 搭建深度学习训练界面(二)

炼丹软件
在这里插入图片描述

github链接:

有需要联系我

requirements:

测试在ubuntu18.04和Windows均可运行

ubuntu18.04

OS: Ubuntu 18.04.6 LTS
Python version: 3.7 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 11.1.74
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3090 Ti
Nvidia driver version: 510.108.03

安装可能存在的问题:

No module named ‘kornia’

pip install kornia==0.5

注: 不带版本号会默认下载新的torch

No module named ‘PyQt5.QtChart’
需单独安装

pip install pyqtchart

ModuleNotFoundError: No module named ‘qt_material’

pip install qt-material

ModuleNotFoundError: No module named ‘jinja2’

pip install jinja2

主函数

if __name__ == '__main__':
    # multiprocessing.freeze_support()
    # Create the single QApplication instance that owns the event loop.
    app = QApplication(sys.argv)
    # Install the application-wide style sheet before any window is built.
    app.setStyleSheet(StyleSheet)
    # Show the animated splash screen while the main window is constructed.
    # ``splash`` must stay a module-level name: MainCode.__init__ reads it to
    # post progress messages and to dismiss it. Names assigned here are
    # already module globals, so no ``global`` statement is needed.
    splash = GifSplashScreen()
    splash.show()

    m_window = MainCode()
    # Theme the application with the 'dark_blue.xml' theme.
    apply_stylesheet(app, theme='dark_blue.xml')  # , invert_secondary=True)
    m_window.show()
    # Run the event loop and propagate its exit code to the shell.
    sys.exit(app.exec_())

启动界面

通常在主界面加载完成之前提供一个启动界面,减少主程序加载过程用户的等待

# 启动界面
class GifSplashScreen(QSplashScreen):
    """Splash screen that plays ``./Lib/splash.gif`` until it is finished."""

    def __init__(self, *args, **kwargs):
        """Create the splash screen and start the GIF animation."""
        super().__init__(*args, **kwargs)
        gif = QMovie('./Lib/splash.gif')
        self.movie = gif
        # Every new frame of the movie is pushed onto the splash screen.
        gif.frameChanged.connect(self.onFrameChanged)
        gif.start()

    def onFrameChanged(self, _):
        """Display the movie's current frame; the signal argument is unused."""
        frame = self.movie.currentPixmap()
        self.setPixmap(frame)

    def finish(self, widget):
        """Stop the animation, then delegate to ``QSplashScreen.finish``."""
        self.movie.stop()
        super().finish(widget)

主界面设计

利用QtDesigner来设计界面,通过Pycharm外部工具PyUIC转化成py文件


class Ui_MainWindow(object):
    """Widget tree of the training main window.

    ``setupUi`` builds a central widget holding a five-page ``QToolBox`` next
    to a six-page ``QTabWidget``, plus a menu bar, status bar and tool bar.
    ``retranslateUi`` assigns every user-visible caption.
    """

    def setupUi(self, MainWindow):
        """Create all widgets and layouts and attach them to ``MainWindow``.

        :param MainWindow: the window object that receives the central
            widget, menu bar, status bar and tool bar.
        """
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(1000, 800)
        icon = QtGui.QIcon()
        icon.addPixmap(QtGui.QPixmap("./icon.ico"), QtGui.QIcon.Normal, QtGui.QIcon.Off)
        MainWindow.setWindowIcon(icon)
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")
        self.horizontalLayout = QtWidgets.QHBoxLayout(self.centralwidget)
        self.horizontalLayout.setObjectName("horizontalLayout")
        # --- Tool box: one page per workflow step, at most 200 px wide ---
        self.toolBox = QtWidgets.QToolBox(self.centralwidget)
        self.toolBox.setMaximumSize(QtCore.QSize(200, 16777215))
        self.toolBox.setObjectName("toolBox")
        # Tool box page 1: button for choosing the image directory.
        self.m_First = QtWidgets.QWidget()
        self.m_First.setGeometry(QtCore.QRect(0, 0, 200, 597))
        self.m_First.setObjectName("m_First")
        self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.m_First)
        self.horizontalLayout_2.setObjectName("horizontalLayout_2")
        self.m_ImageDirBtn = QtWidgets.QPushButton(self.m_First)
        self.m_ImageDirBtn.setObjectName("m_ImageDirBtn")
        self.horizontalLayout_2.addWidget(self.m_ImageDirBtn)
        self.toolBox.addItem(self.m_First, "")
        # Tool box page 2: button for the training-parameter step.
        self.m_Second = QtWidgets.QWidget()
        self.m_Second.setGeometry(QtCore.QRect(0, 0, 200, 597))
        self.m_Second.setObjectName("m_Second")
        self.horizontalLayout_3 = QtWidgets.QHBoxLayout(self.m_Second)
        self.horizontalLayout_3.setObjectName("horizontalLayout_3")
        self.m_SetTrainPareBtn = QtWidgets.QPushButton(self.m_Second)
        self.m_SetTrainPareBtn.setObjectName("m_SetTrainPareBtn")
        self.horizontalLayout_3.addWidget(self.m_SetTrainPareBtn)
        self.toolBox.addItem(self.m_Second, "")
        # Tool box page 3: button for the training step.
        self.m_Third = QtWidgets.QWidget()
        self.m_Third.setGeometry(QtCore.QRect(0, 0, 200, 597))
        self.m_Third.setObjectName("m_Third")
        self.verticalLayout = QtWidgets.QVBoxLayout(self.m_Third)
        self.verticalLayout.setObjectName("verticalLayout")
        self.m_StartTrainBtn = QtWidgets.QPushButton(self.m_Third)
        self.m_StartTrainBtn.setObjectName("m_StartTrainBtn")
        self.verticalLayout.addWidget(self.m_StartTrainBtn)
        self.toolBox.addItem(self.m_Third, "")
        # Tool box page 4: button for the image-detection step.
        self.m_Forth = QtWidgets.QWidget()
        self.m_Forth.setGeometry(QtCore.QRect(0, 0, 200, 597))
        self.m_Forth.setObjectName("m_Forth")
        self.verticalLayout_2 = QtWidgets.QVBoxLayout(self.m_Forth)
        self.verticalLayout_2.setObjectName("verticalLayout_2")
        self.m_DetectSinglePicBtn = QtWidgets.QPushButton(self.m_Forth)
        self.m_DetectSinglePicBtn.setObjectName("m_DetectSinglePicBtn")
        self.verticalLayout_2.addWidget(self.m_DetectSinglePicBtn)
        self.toolBox.addItem(self.m_Forth, "")
        # Tool box page 5: model-conversion button, placed with a fixed
        # geometry because this page has no layout.
        self.m_fifth = QtWidgets.QWidget()
        self.m_fifth.setObjectName("m_fifth")
        self.m_convertModel = QtWidgets.QPushButton(self.m_fifth)
        self.m_convertModel.setGeometry(QtCore.QRect(50, 180, 75, 23))
        self.m_convertModel.setObjectName("m_convertModel")
        self.toolBox.addItem(self.m_fifth, "")
        self.horizontalLayout.addWidget(self.toolBox)
        # --- Tab widget: home page plus one tab per workflow step ---
        self.tabWidget = QtWidgets.QTabWidget(self.centralwidget)
        self.tabWidget.setObjectName("tabWidget")
        # Tab 0 (home): a single expanding label.
        self.home_page = QtWidgets.QWidget()
        self.home_page.setObjectName("home_page")
        self.gridLayout_8 = QtWidgets.QGridLayout(self.home_page)
        self.gridLayout_8.setObjectName("gridLayout_8")
        self.m_homePagelabel = QtWidgets.QLabel(self.home_page)
        sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding)
        sizePolicy.setHorizontalStretch(0)
        sizePolicy.setVerticalStretch(0)
        sizePolicy.setHeightForWidth(self.m_homePagelabel.sizePolicy().hasHeightForWidth())
        self.m_homePagelabel.setSizePolicy(sizePolicy)
        self.m_homePagelabel.setObjectName("m_homePagelabel")
        self.gridLayout_8.addWidget(self.m_homePagelabel, 0, 0, 1, 1)
        self.tabWidget.addTab(self.home_page, "")
        # Tab 1: image directory line edit, image label and file table
        # (the table is at most 300 px wide).
        self.m_FirstW = QtWidgets.QWidget()
        self.m_FirstW.setObjectName("m_FirstW")
        self.gridLayout = QtWidgets.QGridLayout(self.m_FirstW)
        self.gridLayout.setObjectName("gridLayout")
        self.le_imageDir = QtWidgets.QLineEdit(self.m_FirstW)
        self.le_imageDir.setObjectName("le_imageDir")
        self.gridLayout.addWidget(self.le_imageDir, 0, 0, 1, 1)
        self.tableImgList = QtWidgets.QTableWidget(self.m_FirstW)
        self.tableImgList.setMaximumSize(QtCore.QSize(300, 16777215))
        self.tableImgList.setObjectName("tableImgList")
        self.tableImgList.setColumnCount(0)
        self.tableImgList.setRowCount(0)
        self.gridLayout.addWidget(self.tableImgList, 1, 1, 1, 1)
        self.la_image = QtWidgets.QLabel(self.m_FirstW)
        self.la_image.setObjectName("la_image")
        self.gridLayout.addWidget(self.la_image, 1, 0, 1, 1)
        self.tabWidget.addTab(self.m_FirstW, "")
        # Tab 2: label / line-edit pairs for each training parameter, one
        # grid row per parameter, followed by a check box and the OK button.
        self.m_SecondW = QtWidgets.QWidget()
        self.m_SecondW.setObjectName("m_SecondW")
        self.gridLayout_5 = QtWidgets.QGridLayout(self.m_SecondW)
        self.gridLayout_5.setObjectName("gridLayout_5")
        self.m_MaxIterationL = QtWidgets.QLabel(self.m_SecondW)
        self.m_MaxIterationL.setObjectName("m_MaxIterationL")
        self.gridLayout_5.addWidget(self.m_MaxIterationL, 0, 0, 1, 1)
        self.m_MaxIterationEd = QtWidgets.QLineEdit(self.m_SecondW)
        self.m_MaxIterationEd.setObjectName("m_MaxIterationEd")
        self.gridLayout_5.addWidget(self.m_MaxIterationEd, 0, 1, 1, 1)
        self.m_BathSizeL = QtWidgets.QLabel(self.m_SecondW)
        self.m_BathSizeL.setObjectName("m_BathSizeL")
        self.gridLayout_5.addWidget(self.m_BathSizeL, 1, 0, 1, 1)
        self.m_BathSizeEd = QtWidgets.QLineEdit(self.m_SecondW)
        self.m_BathSizeEd.setObjectName("m_BathSizeEd")
        self.gridLayout_5.addWidget(self.m_BathSizeEd, 1, 1, 1, 1)
        self.m_ImageSizeL = QtWidgets.QLabel(self.m_SecondW)
        self.m_ImageSizeL.setObjectName("m_ImageSizeL")
        self.gridLayout_5.addWidget(self.m_ImageSizeL, 2, 0, 1, 1)
        self.m_ImageSizeEd = QtWidgets.QLineEdit(self.m_SecondW)
        self.m_ImageSizeEd.setObjectName("m_ImageSizeEd")
        self.gridLayout_5.addWidget(self.m_ImageSizeEd, 2, 1, 1, 1)
        self.m_ValidationRatioL = QtWidgets.QLabel(self.m_SecondW)
        self.m_ValidationRatioL.setObjectName("m_ValidationRatioL")
        self.gridLayout_5.addWidget(self.m_ValidationRatioL, 3, 0, 1, 1)
        self.m_ValidationRatioEd = QtWidgets.QLineEdit(self.m_SecondW)
        self.m_ValidationRatioEd.setObjectName("m_ValidationRatioEd")
        self.gridLayout_5.addWidget(self.m_ValidationRatioEd, 3, 1, 1, 1)
        self.m_LearningRateL = QtWidgets.QLabel(self.m_SecondW)
        self.m_LearningRateL.setObjectName("m_LearningRateL")
        self.gridLayout_5.addWidget(self.m_LearningRateL, 4, 0, 1, 1)
        self.m_LearningRateEd = QtWidgets.QLineEdit(self.m_SecondW)
        self.m_LearningRateEd.setObjectName("m_LearningRateEd")
        self.gridLayout_5.addWidget(self.m_LearningRateEd, 4, 1, 1, 1)
        self.m_WeightDecayL = QtWidgets.QLabel(self.m_SecondW)
        self.m_WeightDecayL.setObjectName("m_WeightDecayL")
        self.gridLayout_5.addWidget(self.m_WeightDecayL, 5, 0, 1, 1)
        self.m_WeightDecayEd = QtWidgets.QLineEdit(self.m_SecondW)
        self.m_WeightDecayEd.setObjectName("m_WeightDecayEd")
        self.gridLayout_5.addWidget(self.m_WeightDecayEd, 5, 1, 1, 1)
        self.m_isCuda = QtWidgets.QCheckBox(self.m_SecondW)
        self.m_isCuda.setObjectName("m_isCuda")
        self.gridLayout_5.addWidget(self.m_isCuda, 6, 0, 1, 1)
        self.m_OkBtn = QtWidgets.QPushButton(self.m_SecondW)
        self.m_OkBtn.setObjectName("m_OkBtn")
        self.gridLayout_5.addWidget(self.m_OkBtn, 7, 1, 1, 1)
        self.tabWidget.addTab(self.m_SecondW, "")
        # Tab 3: training control buttons, model save path row, progress
        # bar and a text browser.
        self.m_ThirdW = QtWidgets.QWidget()
        self.m_ThirdW.setObjectName("m_ThirdW")
        self.gridLayout_7 = QtWidgets.QGridLayout(self.m_ThirdW)
        self.gridLayout_7.setObjectName("gridLayout_7")
        self.m_trainwidget = QtWidgets.QWidget(self.m_ThirdW)
        self.m_trainwidget.setObjectName("m_trainwidget")
        self.gridLayout_2 = QtWidgets.QGridLayout(self.m_trainwidget)
        self.gridLayout_2.setObjectName("gridLayout_2")
        self.m_initModelBtn = QtWidgets.QPushButton(self.m_trainwidget)
        self.m_initModelBtn.setObjectName("m_initModelBtn")
        self.gridLayout_2.addWidget(self.m_initModelBtn, 0, 0, 1, 1)
        # NOTE: m_startTrainBtn (this tab) differs from m_StartTrainBtn (tool
        # box page 3) only by the case of one letter.
        self.m_startTrainBtn = QtWidgets.QPushButton(self.m_trainwidget)
        self.m_startTrainBtn.setObjectName("m_startTrainBtn")
        self.gridLayout_2.addWidget(self.m_startTrainBtn, 0, 1, 1, 1)
        self.m_pauseTrainBtn = QtWidgets.QPushButton(self.m_trainwidget)
        self.m_pauseTrainBtn.setObjectName("m_pauseTrainBtn")
        self.gridLayout_2.addWidget(self.m_pauseTrainBtn, 1, 0, 1, 1)
        self.m_resumTrainBtn = QtWidgets.QPushButton(self.m_trainwidget)
        self.m_resumTrainBtn.setObjectName("m_resumTrainBtn")
        self.gridLayout_2.addWidget(self.m_resumTrainBtn, 1, 1, 1, 1)
        self.m_stopTrainBtn = QtWidgets.QPushButton(self.m_trainwidget)
        self.m_stopTrainBtn.setObjectName("m_stopTrainBtn")
        self.gridLayout_2.addWidget(self.m_stopTrainBtn, 2, 0, 1, 1)
        self.gridLayout_7.addWidget(self.m_trainwidget, 0, 1, 1, 1)
        # Save-path row; minimum and maximum height are both 100 px.
        self.m_savemodelWidget = QtWidgets.QWidget(self.m_ThirdW)
        self.m_savemodelWidget.setMinimumSize(QtCore.QSize(0, 100))
        self.m_savemodelWidget.setMaximumSize(QtCore.QSize(16777215, 100))
        self.m_savemodelWidget.setObjectName("m_savemodelWidget")
        self.gridLayout_3 = QtWidgets.QGridLayout(self.m_savemodelWidget)
        self.gridLayout_3.setObjectName("gridLayout_3")
        self.m_modelSaveEd = QtWidgets.QLineEdit(self.m_savemodelWidget)
        self.m_modelSaveEd.setObjectName("m_modelSaveEd")
        self.gridLayout_3.addWidget(self.m_modelSaveEd, 0, 1, 1, 1)
        self.m_modelSaveBtn = QtWidgets.QPushButton(self.m_savemodelWidget)
        self.m_modelSaveBtn.setObjectName("m_modelSaveBtn")
        self.gridLayout_3.addWidget(self.m_modelSaveBtn, 0, 3, 1, 1)
        self.m_modelSaveL = QtWidgets.QLabel(self.m_savemodelWidget)
        self.m_modelSaveL.setObjectName("m_modelSaveL")
        self.gridLayout_3.addWidget(self.m_modelSaveL, 0, 0, 1, 1)
        self.gridLayout_7.addWidget(self.m_savemodelWidget, 0, 0, 1, 1)
        # Progress bar spanning both columns; its initial value here is 24.
        self.m_modelTrainProcesssbar = QtWidgets.QProgressBar(self.m_ThirdW)
        self.m_modelTrainProcesssbar.setProperty("value", 24)
        self.m_modelTrainProcesssbar.setObjectName("m_modelTrainProcesssbar")
        self.gridLayout_7.addWidget(self.m_modelTrainProcesssbar, 1, 0, 1, 2)
        self.textBrowser = QtWidgets.QTextBrowser(self.m_ThirdW)
        self.textBrowser.setObjectName("textBrowser")
        self.gridLayout_7.addWidget(self.textBrowser, 3, 0, 1, 2)
        self.tabWidget.addTab(self.m_ThirdW, "")
        # Tab 4: model loading row (fixed 50 px high) above a result label.
        self.m_ForthW = QtWidgets.QWidget()
        self.m_ForthW.setObjectName("m_ForthW")
        self.gridLayout_6 = QtWidgets.QGridLayout(self.m_ForthW)
        self.gridLayout_6.setObjectName("gridLayout_6")
        self.m_loadmodelwidget = QtWidgets.QWidget(self.m_ForthW)
        self.m_loadmodelwidget.setMinimumSize(QtCore.QSize(0, 50))
        self.m_loadmodelwidget.setMaximumSize(QtCore.QSize(16777215, 50))
        self.m_loadmodelwidget.setObjectName("m_loadmodelwidget")
        self.gridLayout_4 = QtWidgets.QGridLayout(self.m_loadmodelwidget)
        self.gridLayout_4.setObjectName("gridLayout_4")
        self.m_loadmodelBtn = QtWidgets.QPushButton(self.m_loadmodelwidget)
        self.m_loadmodelBtn.setObjectName("m_loadmodelBtn")
        self.gridLayout_4.addWidget(self.m_loadmodelBtn, 0, 0, 1, 1)
        self.m_loadmodelEd = QtWidgets.QLineEdit(self.m_loadmodelwidget)
        self.m_loadmodelEd.setObjectName("m_loadmodelEd")
        self.gridLayout_4.addWidget(self.m_loadmodelEd, 0, 1, 1, 1)
        self.gridLayout_6.addWidget(self.m_loadmodelwidget, 0, 0, 1, 1)
        self.la_result = QtWidgets.QLabel(self.m_ForthW)
        self.la_result.setObjectName("la_result")
        self.gridLayout_6.addWidget(self.la_result, 1, 0, 1, 1)
        self.tabWidget.addTab(self.m_ForthW, "")
        # Tab 5: model file line edit, choose button and convert button.
        self.tab = QtWidgets.QWidget()
        self.tab.setObjectName("tab")
        self.gridLayout_9 = QtWidgets.QGridLayout(self.tab)
        self.gridLayout_9.setObjectName("gridLayout_9")
        self.m_choosemodelEd = QtWidgets.QLineEdit(self.tab)
        self.m_choosemodelEd.setObjectName("m_choosemodelEd")
        self.gridLayout_9.addWidget(self.m_choosemodelEd, 0, 0, 1, 1)
        self.m_choosemodelBtn = QtWidgets.QPushButton(self.tab)
        self.m_choosemodelBtn.setObjectName("m_choosemodelBtn")
        self.gridLayout_9.addWidget(self.m_choosemodelBtn, 1, 0, 1, 1)
        self.m_starttransformBtn = QtWidgets.QPushButton(self.tab)
        self.m_starttransformBtn.setObjectName("m_starttransformBtn")
        self.gridLayout_9.addWidget(self.m_starttransformBtn, 2, 0, 1, 1)
        self.tabWidget.addTab(self.tab, "")
        self.horizontalLayout.addWidget(self.tabWidget)
        MainWindow.setCentralWidget(self.centralwidget)
        # --- Menu bar, status bar and bottom tool bar ---
        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 1000, 22))
        self.menubar.setObjectName("menubar")
        self.openmenu = QtWidgets.QMenu(self.menubar)
        self.openmenu.setObjectName("openmenu")
        MainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)
        self.toolBar = QtWidgets.QToolBar(MainWindow)
        self.toolBar.setObjectName("toolBar")
        MainWindow.addToolBar(QtCore.Qt.BottomToolBarArea, self.toolBar)
        self.menubar.addAction(self.openmenu.menuAction())

        # Captions are assigned only after every widget exists.
        self.retranslateUi(MainWindow)
        # Select the last tool box page (index 4) and the last tab (index 5).
        self.toolBox.setCurrentIndex(4)
        self.tabWidget.setCurrentIndex(5)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)

    def retranslateUi(self, MainWindow):
        """Set the window title and every widget caption.

        All strings are routed through ``QCoreApplication.translate`` with
        the ``"MainWindow"`` context.

        :param MainWindow: the window whose title is set.
        """
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "训练界面"))
        # Tool box buttons and page titles.
        self.m_ImageDirBtn.setText(_translate("MainWindow", "选择图像路径"))
        self.toolBox.setItemText(self.toolBox.indexOf(self.m_First), _translate("MainWindow", "第一步"))
        self.m_SetTrainPareBtn.setText(_translate("MainWindow", "设置训练参数"))
        self.toolBox.setItemText(self.toolBox.indexOf(self.m_Second), _translate("MainWindow", "第二步"))
        self.m_StartTrainBtn.setText(_translate("MainWindow", "开始训练"))
        self.toolBox.setItemText(self.toolBox.indexOf(self.m_Third), _translate("MainWindow", "第三步"))
        self.m_DetectSinglePicBtn.setText(_translate("MainWindow", "检测图像"))
        self.toolBox.setItemText(self.toolBox.indexOf(self.m_Forth), _translate("MainWindow", "第四步"))
        self.m_convertModel.setText(_translate("MainWindow", "模型转换"))
        self.toolBox.setItemText(self.toolBox.indexOf(self.m_fifth), _translate("MainWindow", "第五步"))
        # Tab titles and the captions of the widgets on each tab.
        self.m_homePagelabel.setText(_translate("MainWindow", "TextLabel"))
        self.tabWidget.setTabText(self.tabWidget.indexOf(self.home_page), _translate("MainWindow", "home"))
        self.la_image.setText(_translate("MainWindow", "TextLabel"))
        self.tabWidget.setTabText(self.tabWidget.indexOf(self.m_FirstW), _translate("MainWindow", "第一步"))
        self.m_MaxIterationL.setText(_translate("MainWindow", "最大训练次数:"))
        self.m_BathSizeL.setText(_translate("MainWindow", "batch_size(批尺寸)"))
        self.m_ImageSizeL.setText(_translate("MainWindow", "图像尺寸"))
        self.m_ValidationRatioL.setText(_translate("MainWindow", "验证集比例"))
        self.m_LearningRateL.setText(_translate("MainWindow", "学习率"))
        self.m_WeightDecayL.setText(_translate("MainWindow", "权重衰减系数:"))
        self.m_isCuda.setText(_translate("MainWindow", "是否使用显卡训练"))
        self.m_OkBtn.setText(_translate("MainWindow", "OK"))
        self.tabWidget.setTabText(self.tabWidget.indexOf(self.m_SecondW), _translate("MainWindow", "第二步"))
        self.m_initModelBtn.setText(_translate("MainWindow", "初始化"))
        self.m_startTrainBtn.setText(_translate("MainWindow", "开始训练"))
        self.m_pauseTrainBtn.setText(_translate("MainWindow", "暂停训练"))
        self.m_resumTrainBtn.setText(_translate("MainWindow", "继续训练"))
        self.m_stopTrainBtn.setText(_translate("MainWindow", "停止训练"))
        self.m_modelSaveBtn.setText(_translate("MainWindow", "选择路径"))
        self.m_modelSaveL.setText(_translate("MainWindow", "模型保存位置:"))
        self.tabWidget.setTabText(self.tabWidget.indexOf(self.m_ThirdW), _translate("MainWindow", "第三步"))
        self.m_loadmodelBtn.setText(_translate("MainWindow", "加载模型:"))
        self.la_result.setText(_translate("MainWindow", "TextLabel"))
        self.tabWidget.setTabText(self.tabWidget.indexOf(self.m_ForthW), _translate("MainWindow", "第四步"))
        self.m_choosemodelBtn.setText(_translate("MainWindow", "选择模型文件"))
        self.m_starttransformBtn.setText(_translate("MainWindow", "转换"))
        self.tabWidget.setTabText(self.tabWidget.indexOf(self.tab), _translate("MainWindow", "第五步"))
        # Menu and tool bar titles.
        self.openmenu.setTitle(_translate("MainWindow", "打开"))
        self.toolBar.setWindowTitle(_translate("MainWindow", "toolBar"))

训练参数类设置

class trainParameter():
    """Plain holder for the training settings shared between GUI and worker.

    Every numeric setting starts at integer zero and ``Cuda`` at ``False``;
    the four ``*_weights`` fields start at ``1.0``.
    """

    def __init__(self):
        """Populate every field with its initial value."""
        # Settings edited in the parameter form; zero until filled in.
        self.epochs = self.batch_size = self.image_size = 0
        self.validation_ratio = self.lr = self.dw = 0
        self.Cuda = False
        # Weight fields, all initialised to one.
        self.kl_weights = self.l2_weights = 1.0
        self.gms_weights = self.ssim_weights = 1.0

多线程设置

class Thread(QThread):
    """Worker thread that trains an ``XNet`` one epoch at a time.

    Signals:
        valueChange(int): the epoch number, emitted when an epoch begins.
        textChange(int, str): the epoch number and the stringified average
            loss, emitted when an epoch ends.

    The thread can be paused, resumed and stopped from another thread; those
    requests take effect between epochs, never in the middle of one.
    """
    valueChange = pyqtSignal(int)
    textChange = pyqtSignal(int, str)

    def __init__(self, open_dir, train_parameters, save_dir):
        """Store the run configuration and build the network wrapper.

        :param open_dir: image directory handed to ``XNet``.
        :param train_parameters: settings object; ``epochs`` is read by
            ``run`` and the whole object is handed to ``XNet``.
        :param save_dir: output directory handed to ``XNet``.
        """
        super(Thread, self).__init__()
        self._isPause = False
        self._isStop = False
        self._value = 0
        self.cond = QWaitCondition()
        self.mutex = QMutex()
        self.train = train_parameters
        self.open_dir = open_dir
        self.save_dir = save_dir
        self.xnet = XNet(self.open_dir, self.train, self.save_dir)

    def _set_flags(self, pause=None, stop=None):
        """Update the pause/stop flags while holding the mutex."""
        # Changing the flags under the same mutex the worker waits on closes
        # the window in which a wake-up could be missed.
        self.mutex.lock()
        try:
            if pause is not None:
                self._isPause = pause
            if stop is not None:
                self._isStop = stop
        finally:
            self.mutex.unlock()

    def pause(self):
        """Request that training pauses before the next epoch."""
        self._set_flags(pause=True)

    def stop(self):
        """Request that training ends before the next epoch."""
        self._set_flags(stop=True)
        # A paused worker is blocked on the condition; wake it so it can
        # observe the stop request instead of sleeping forever.
        self.cond.wakeAll()

    def resume(self):
        """Clear the pause request and wake the worker."""
        self._set_flags(pause=False)
        self.cond.wakeAll()

    def _wait_while_paused(self):
        """Block until the thread is neither paused nor asked to stop."""
        self.mutex.lock()
        try:
            # Loop rather than a single check: a wake-up alone does not
            # guarantee that the pause flag was cleared.
            while self._isPause and not self._isStop:
                self.cond.wait(self.mutex)
        finally:
            # Always released, including when the caller goes on to stop.
            self.mutex.unlock()

    def run(self):
        """Train for ``self.train.epochs`` epochs, reporting after each one."""
        for epoch in range(1, self.train.epochs + 1):
            self.valueChange.emit(epoch)
            # The mutex is held only for the pause check, not for the whole
            # epoch, so pause()/resume()/stop() never block their caller.
            self._wait_while_paused()
            if self._isStop:
                return
            # No QApplication.processEvents() here: called from this thread
            # it would not refresh the GUI, which runs its own event loop.
            loss_avg = self.xnet.train_one_epoch()
            if np.isnan(loss_avg):
                # Report a large finite loss instead of NaN.
                loss_avg = 1e6

            print('Train Epoch: {} loss: {:.6f}'.format(epoch, loss_avg))
            self.textChange.emit(epoch, str(loss_avg))
            if epoch % 50 == 0:
                self.xnet.save_model(epoch)  # save the weights every 50 epochs

主界面功能

class MainCode(QMainWindow, Ui_MainWindow):
    def __init__(self):
        """Build the widgets, step the splash screen, then initialise the UI.

        Relies on a module-level ``splash`` object that is already showing.
        """
        QMainWindow.__init__(self)
        Ui_MainWindow.__init__(self)
        self.setupUi(self)

        # Every splash message is drawn bottom-centred in white.
        placement = Qt.AlignHCenter | Qt.AlignBottom
        for step in (0, 1):
            sleep(1)
            splash.showMessage('加载进度: %d' % step, placement, Qt.white)
            # Let the splash repaint between the blocking sleeps.
            QApplication.instance().processEvents()

        splash.showMessage('初始化完成', placement, Qt.white)
        splash.finish(self)

        self.train = trainParameter()
        # Starting directory offered by the folder-selection dialogs.
        self.cwd = os.getcwd()

        self.initUi()

背景图片展示

    def showHomepage(self):
        """Put the background picture ``./Lib/train.jpeg`` on the home label."""
        label = self.m_homePagelabel
        picture = QPixmap("./Lib/train.jpeg")
        # Scale to the label's current size, then let the label keep
        # stretching its contents.
        label.setPixmap(picture.scaled(label.width(), label.height()))
        label.setScaledContents(True)

右下角显示当前系统时间

   # 获取当前时间
    def showCurrentTime(self, timeLabel):
        """Write the current date and time into ``timeLabel``.

        :param timeLabel: widget whose ``setText`` receives the formatted
            ``yyyy-MM-dd hh:mm:ss dddd`` string.
        """
        now = QDateTime.currentDateTime()
        timeLabel.setText(now.toString('yyyy-MM-dd hh:mm:ss dddd'))

    # 状态栏显示时间
    def statusShowTime(self):
        """Add a clock label to the status bar and refresh it every second."""
        self.timer = QTimer()
        self.timeLabel = QLabel()
        self.statusbar.addPermanentWidget(self.timeLabel, 0)

        def refresh_clock():
            # Looked up on each tick, so the label attribute is read lazily.
            self.showCurrentTime(self.timeLabel)

        self.timer.timeout.connect(refresh_clock)
        self.timer.start(1000)  # interval in milliseconds

选择图像路径并显示在列表中

   def openfiledir(self):
        openfile_dir = QFileDialog.getExistingDirectory(self.centralwidget, "选择路径", self.cwd)
        # print(openfile_dir)
        self.le_imageDir.setText(openfile_dir)
        # self.le_imageDir.setReadOnly()
        # 寻找路径下以*jpg结尾的图像,放进列表中
        imgdir = QDir(openfile_dir)
        if not imgdir.exists():
            return

        imgList = imgdir.entryList(['*.bmp', '*.jpg', '*.png'],
                                   QtCore.QDir.NoFilter,
                                   QtCore.QDir.Name | QtCore.QDir.IgnoreCase)
        cnt = len(imgList)
        self.tableImgList.clear()
        self.tableImgList.setRowCount(cnt)
        self.tableImgList.setColumnCount(2)
        for i in range(0, cnt):
            self.tableImgList.setItem(i, 0, QTableWidgetItem(imgList[i]))
            self.tableImgList.setItem(i, 1, QTableWidgetItem(imgList[i]))
        #    #  ui->preFileList->setItem(i,0,new QTableWidgetItem(QString::number(i+1)));
        #     ui->preFileList->setItem(i,1,new QTableWidgetItem(m_fileNameList[i]));
        self.tableImgList.setEditTriggers(QAbstractItemView.NoEditTriggers)  #禁止编辑
        # 取列表的第一张
        if cnt > 0:
            imgName = openfile_dir +"/" + imgList[0]
            # print(imgName)
            jpg = QPixmap(imgName).scaled(self.la_image.width(), self.la_image.height())
            # 显示原图
            self.la_image.setPixmap(jpg)

        self.tabWidget.setCurrentIndex(1)

点击列表切换图像

    def drawImage(self):
        """Show the image named by the table's current cell.

        Does nothing when no directory has been chosen or when the table has
        no current item.
        """
        open_dir = self.le_imageDir.text()
        if len(open_dir) == 0:
            return
        # The selection-changed signal can arrive when no cell is current
        # (presumably e.g. while the table is being cleared); calling
        # .text() on None would raise AttributeError.
        current = self.tableImgList.currentItem()
        if current is None:
            return
        imgFile = open_dir + "/" + current.text()
        # Scale the picture to the preview label's current size.
        jpg = QPixmap(imgFile).scaled(self.la_image.width(), self.la_image.height())
        self.la_image.setPixmap(jpg)

读取配置文件

支持使用默认配置及手动输入

# 读取配置文件
    def readConfig(self):
        """Pre-fill the parameter form and save path from ``./config.yaml``.

        The keys read are ``epochs``, ``batch_size``, ``image_size``,
        ``validation_ratio``, ``lr``, ``dw``, ``CUDA`` and ``save_dir``. The
        save directory is created if it does not exist. When the file is
        absent or does not contain a mapping the form is left as it is, so
        the values can still be typed in by hand.

        :raises KeyError: if the mapping lacks one of the keys above.
        """
        import yaml
        config = "./config.yaml"
        # The defaults file is optional; do not abort start-up without it.
        if not os.path.isfile(config):
            return
        with open(config, 'r', encoding='utf8') as file:
            d = yaml.safe_load(file)
        if not isinstance(d, dict):
            return

        self.m_MaxIterationEd.setText(str(d['epochs']))
        self.m_BathSizeEd.setText(str(d['batch_size']))
        self.m_ImageSizeEd.setText(str(d['image_size']))
        self.m_ValidationRatioEd.setText(str(d['validation_ratio']))
        self.m_LearningRateEd.setText(str(d['lr']))
        self.m_WeightDecayEd.setText(str(d['dw']))
        # Mirror the configured value both ways, so a false setting also
        # unchecks the box.
        self.m_isCuda.setChecked(bool(d['CUDA']))
        self.m_modelSaveEd.setText(d['save_dir'])
        # makedirs also creates missing parent directories and tolerates a
        # directory that already exists.
        os.makedirs(d['save_dir'], exist_ok=True)

手动修改训练超参数

def finishSetting(self):
        # 判断不能漏填

        if len(self.m_MaxIterationEd.text()) == 0:
            QMessageBox.information(self, '警告', '训练次数不能为空')
            return

        if len(self.m_BathSizeEd.text()) == 0:
            QMessageBox.information(self, '警告', '批尺寸不能为空')
            return
        if len(self.m_ImageSizeEd.text()) == 0:
            QMessageBox.information(self, '警告', '图像尺寸不能为空')
            return

        if len(self.m_ValidationRatioEd.text()) == 0:
            QMessageBox.information(self, '警告', '验证集比例不能为空')
            return

        if len(self.m_LearningRateEd.text()) == 0:
            QMessageBox.information(self, '警告', '学习率不能为空')
            return
        if len(self.m_WeightDecayEd.text()) == 0:
            QMessageBox.information(self, '警告', '权重衰减系数不能为空')
            return

        self.train.epochs = int(self.m_MaxIterationEd.text())
        self.train.batch_size = int(self.m_BathSizeEd.text())
        self.train.image_size = int(self.m_ImageSizeEd.text())
        self.train.validation_ratio = float(self.m_ValidationRatioEd.text())
        self.train.lr = float(self.m_LearningRateEd.text())
        self.train.dw = float(self.m_WeightDecayEd.text())
        self.train.Cuda = self.m_isCuda.isChecked()

        # 将内容写到文本文档中
        reply = QMessageBox.information(self, "通知", "参数配置完成", QMessageBox.Ok)
        if reply == QMessageBox.Ok:
            self.doShake()

选择模型保存路径

    def selectSaveDir(self):
        """Let the user choose the model output folder.

        The chosen path is written to the save-path line edit. Cancelling
        the dialog keeps the path that is already there.
        """
        save_dir = QFileDialog.getExistingDirectory(self.centralwidget, "选择路径", self.cwd)
        # The dialog returns an empty string when cancelled; do not let that
        # erase a previously configured path.
        if save_dir:
            self.m_modelSaveEd.setText(save_dir)

开始训练

初始化模型

    def trainer(self):
        """Prepare a training run without starting it.

        Builds the worker thread from the chosen image directory, the current
        training parameters and the save directory, opens ``log.txt`` in the
        save directory for appending, connects the worker's signals to the
        progress bar and the text log, and enables the start button.
        A message box is shown and nothing is prepared when the image
        directory or the save directory is empty.
        """
        open_dir = self.le_imageDir.text()
        if len(open_dir) == 0:
            QMessageBox.information(self, '警告', '请先选择图像路径')
            return
        save_dir = self.m_modelSaveEd.text()
        if len(save_dir) == 0:
            # An empty path would otherwise place the log at '/log.txt'.
            QMessageBox.information(self, '警告', '请先选择模型保存位置')
            return
        os.makedirs(save_dir, exist_ok=True)

        # Build the worker first: if that fails, no log file has been opened.
        worker = Thread(open_dir, self.train, save_dir)

        # Release the handle of an earlier run before replacing it.
        previous_log = getattr(self, 'logFile', None)
        if previous_log is not None and not previous_log.closed:
            previous_log.close()
        self.logFile = open(os.path.join(save_dir, 'log.txt'), 'a', encoding='utf8')

        self.m_modelTrainProcesssbar.setMaximum(self.train.epochs)
        self.t = worker
        self.t.valueChange.connect(self.m_modelTrainProcesssbar.setValue)
        self.t.textChange.connect(self.updateText)
        # The thread is started by startTrainer, not here.
        self.m_initModelBtn.setEnabled(False)
        self.m_startTrainBtn.setEnabled(True)

开始训练

    def startTrainer(self):
        """Start the prepared worker and switch the buttons to 'running'."""
        worker = self.t
        worker.start()
        # Start becomes unavailable, pause becomes available.
        button_states = ((self.m_startTrainBtn, False), (self.m_pauseTrainBtn, True))
        for button, enabled in button_states:
            button.setEnabled(enabled)

暂停训练

    def suspendTrainer(self):
        """Ask the worker to pause and make 'resume' the clickable button."""
        self.t.pause()
        pause_btn, resume_btn = self.m_pauseTrainBtn, self.m_resumTrainBtn
        pause_btn.setEnabled(False)
        resume_btn.setEnabled(True)

继续训练

    def wakeTrainer(self):
        """Resume the paused worker and make 'pause' the clickable button."""
        self.t.resume()
        pause_btn, resume_btn = self.m_pauseTrainBtn, self.m_resumTrainBtn
        pause_btn.setEnabled(True)
        resume_btn.setEnabled(False)

停止训练

    def stopTrainer(self):
        """Stop the training run after the user confirms.

        Does nothing when no run has been prepared yet. Otherwise the worker
        is asked to stop (and woken in case it is paused), the log file is
        closed once the worker has finished, and the buttons return to the
        'not prepared' state.
        """
        worker = getattr(self, 't', None)
        if worker is None:
            # The stop button is clickable before any run was initialised.
            return

        reply = QMessageBox.question(self, "停止训练", "您确定要停止训练吗?", QMessageBox.Yes | QMessageBox.No, QMessageBox.No)
        if reply != QMessageBox.Yes:
            return

        worker.stop()
        # A paused worker is blocked waiting; resume() wakes it so it can
        # see the stop request and leave its loop.
        worker.resume()

        log = getattr(self, 'logFile', None)
        if log is not None and not log.closed:
            # The worker may still report the epoch it is finishing, so the
            # log is closed when the thread ends rather than immediately.
            # Closing twice is harmless.
            worker.finished.connect(log.close)
            if not worker.isRunning():
                log.close()

        self.m_initModelBtn.setEnabled(True)
        for button in (self.m_startTrainBtn, self.m_pauseTrainBtn, self.m_resumTrainBtn):
            button.setEnabled(False)


实时显示训练loss

    # 多行文本展示框
    def updateText(self, val, text):
        """Append one ``epoch:<val> loss:<text>`` line to the log and the view.

        :param val: epoch number.
        :param text: loss value already converted to a string.

        The line is written to ``self.logFile`` only while that file exists
        and is open; it is always appended to the text browser, whose cursor
        is then moved to the end.
        """
        msg = "epoch:{0}".format(val) + " " + "loss:" + text
        log = getattr(self, 'logFile', None)
        # A late message can arrive after the log was closed on stop;
        # writing to a closed file would raise ValueError.
        if log is not None and not log.closed:
            log.write('\n{}'.format(msg))
            # Flush so the entry survives an abrupt process exit.
            log.flush()
        self.textBrowser.append(msg)
        self.textBrowser.moveCursor(self.textBrowser.textCursor().End)

点击按钮实现窗口抖动

    # 点击Ok 后窗口抖动一下
    def doShake(self):
        """Play the shake animation on this window itself."""
        target = self
        self.doShakeWindow(target)

    # 下面这个方法可以做成这样的封装给任何控件
    def doShakeWindow(self, target):
        """Play a short shake animation on ``target``.

        :param target: the widget to shake; its ``pos`` property is animated.

        While an animation is in progress the widget carries a
        ``_shake_animation`` attribute, and a further call is ignored.
        """
        if hasattr(target, '_shake_animation'):
            # Already shaking.
            return

        shake = QPropertyAnimation(target, b'pos', target)
        target._shake_animation = shake

        def forget_animation():
            # Drop the marker so the widget can be shaken again.
            delattr(target, '_shake_animation')

        shake.finished.connect(forget_animation)

        origin = target.pos()
        x, y = origin.x(), origin.y()

        shake.setDuration(200)
        shake.setLoopCount(2)
        # (progress, x offset, y offset) relative to the starting position.
        keyframes = (
            (0, 0, 0),
            (0.09, 2, -2),
            (0.18, 4, -4),
            (0.27, 2, -6),
            (0.36, 0, -8),
            (0.45, -2, -10),
            (0.54, -4, -8),
            (0.63, -6, -6),
            (0.72, -8, -4),
            (0.81, -6, -2),
            (0.90, -4, 0),
            (0.99, -2, 2),
        )
        for progress, dx, dy in keyframes:
            shake.setKeyValueAt(progress, QPoint(x + dx, y + dy))
        # Finish exactly where the widget started.
        shake.setEndValue(QPoint(x, y))

        shake.start(shake.DeleteWhenStopped)

信号和槽

    def initUi(self):
        """Show the home page, connect signals to slots and load defaults."""
        self.showHomepage()
        self.tabWidget.setCurrentIndex(0)

        # Give the image-directory button an animated blue shadow effect.
        glow = AnimationShadowEffect(Qt.blue, self.m_ImageDirBtn)
        self.m_ImageDirBtn.setGraphicsEffect(glow)
        glow.start()

        # (signal, slot) pairs, connected in the listed order.
        wiring = (
            (self.m_ImageDirBtn.clicked, self.openfiledir),
            (self.tableImgList.itemSelectionChanged, self.drawImage),
            (self.m_OkBtn.clicked, self.finishSetting),
            (self.m_modelSaveBtn.clicked, self.selectSaveDir),
            (self.m_initModelBtn.clicked, self.trainer),
            (self.m_startTrainBtn.clicked, self.startTrainer),
            (self.m_pauseTrainBtn.clicked, self.suspendTrainer),
            (self.m_resumTrainBtn.clicked, self.wakeTrainer),
            (self.m_stopTrainBtn.clicked, self.stopTrainer),
            (self.m_SetTrainPareBtn.clicked, self.second),
            (self.m_StartTrainBtn.clicked, self.third),
        )
        for signal, slot in wiring:
            signal.connect(slot)

        self.m_modelTrainProcesssbar.setValue(0)
        self.readConfig()
        self.statusShowTime()
        self.m_loadmodelBtn.clicked.connect(self.loadmodelPath)
        # keyboard.add_hotkey('alt+s', self.onShow, suppress=False)  # show the window
        # keyboard.add_hotkey('ctrl+s', self.onHide, suppress=False)  # hide the window

        # Start, pause and resume begin disabled.
        for button in (self.m_startTrainBtn, self.m_pauseTrainBtn, self.m_resumTrainBtn):
            button.setEnabled(False)

退出软件

    def closeEvent(self, event):
        """Ask for confirmation before the window closes.

        Overrides the window's ``closeEvent``. On "Yes" the event is
        accepted and the whole process is terminated with ``os._exit(0)``;
        on "No" the event is ignored and the window stays open.

        :param event: the close event delivered to the window.
        """
        reply = QMessageBox.question(self,
                                     '本程序',
                                     "是否要退出程序?",
                                     QMessageBox.Yes | QMessageBox.No,
                                     QMessageBox.No)
        if reply == QMessageBox.Yes:
            event.accept()
            # os._exit skips normal interpreter shutdown, so buffered log
            # lines must be written out explicitly first.
            log = getattr(self, 'logFile', None)
            if log is not None and not log.closed:
                log.close()
            os._exit(0)
        else:
            # Keep the window open when the user declines.
            event.ignore()

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/shuaijieer/article/details/128377887
今日推荐