Use SwinUnet to train your own dataset

Reference blog post: https://blog.csdn.net/qq_37652891/article/details/123932772

Dataset preparation

The task is multi-class semantic segmentation of remote sensing images, with 7 classes in total (including background).
Example figures (not reproduced here):
image: the original remote sensing image
label_rgb: the color (RGB) version of the label
label: the label used for training. It looks almost all black but is not; its pixel values are the class indices 0,1,2,3,4,5,6. The training below uses this data.

Data download
Baidu cloud: https://pan.baidu.com/s/1zZHnZfBgVWxs6TJW4yjeeQ

Extraction code: 2022

SwinUNet code address

Dataset processing

The dataset provides the images (image) and two kinds of labels: RGB-format labels and labels whose pixel values are 0,1,2,3,4,5,6. SwinUNet uses the label images containing the values 0,1,2,3,4,5,6.

1. Dataset

The dataset is stored in the SwinUNet root directory: the image folder holds the original images and the label folder holds the label images (7 classes in total, so the label values are 0,1,2,3,4,5,6).
If you use another dataset, pay attention to the label values. For example, for binary segmentation with labels of 0 or 255, the labels need to be remapped to 0 or 1.
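A quick way to follow this advice is to list any label value outside the expected range before building the .npz files. The sketch below is a hypothetical helper, not part of SwinUNet: `check_label_values` and `binarize_label` are names introduced here, and it assumes each label has already been loaded as a single-channel NumPy array (for example with `cv2.imread(path, 0)`, as in npz.py).

```python
import numpy as np

def check_label_values(label, num_classes):
    """Return the sorted unique values in `label` that fall outside [0, num_classes)."""
    values = np.unique(label)
    return sorted(int(v) for v in values if v < 0 or v >= num_classes)

def binarize_label(label):
    """Map a 0/255 binary mask to 0/1 (any nonzero pixel becomes class 1)."""
    return (label > 0).astype(np.uint8)

if __name__ == "__main__":
    seven_class = np.array([[0, 1, 2], [3, 4, 6]], dtype=np.uint8)
    print(check_label_values(seven_class, num_classes=7))  # []
    mask = np.array([[0, 255], [255, 0]], dtype=np.uint8)
    print(check_label_values(mask, num_classes=2))          # [255]
    print(binarize_label(mask).tolist())                    # [[0, 1], [1, 0]]
```

For the 7-class data here, `check_label_values(label, 7)` should return an empty list for every label image; a non-empty result points to labels that still need remapping.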

SwinUNet
├── configs
├── img_datas
│   ├── train
│   │   ├── image
│   │   └── label
│   └── test
│       ├── image
│       └── label
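Since npz.py looks up each label by the image's file name, every image needs a label with exactly the same name, extension included. A minimal pairing check, assuming the layout above (`unmatched_files` is a hypothetical helper name):

```python
import os

def unmatched_files(image_dir, label_dir):
    """Return (images without a label, labels without an image), matched by file name."""
    images = set(os.listdir(image_dir))
    labels = set(os.listdir(label_dir))
    return sorted(images - labels), sorted(labels - images)

if __name__ == "__main__":
    for split in ("train", "test"):
        img_dir = os.path.join("img_datas", split, "image")
        lab_dir = os.path.join("img_datas", split, "label")
        if not (os.path.isdir(img_dir) and os.path.isdir(lab_dir)):
            print(f"{split}: directory missing, skipped")
            continue
        missing_label, missing_image = unmatched_files(img_dir, lab_dir)
        print(f"{split}: {len(missing_label)} images without label, "
              f"{len(missing_image)} labels without image")
```

Both counts should be 0 before running npz.py.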

2. Create a file named npz.py in the SwinUnet root directory and run it. It saves each image and its label together as one .npz file.

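import os
import cv2
import numpy as np

def npz(images_path, labels_path, save_path):
    # Make sure the output folder exists
    os.makedirs(save_path, exist_ok=True)
    for name in os.listdir(images_path):
        image_path = os.path.join(images_path, name)
        # The label must have the same file name as the image
        label_path = os.path.join(labels_path, name)

        image = cv2.imread(image_path)
        if image is None:
            print('skip unreadable image:', image_path)
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Read the label as single-channel (flags=0 means grayscale),
        # converting a three-channel label to one channel
        label = cv2.imread(label_path, flags=0)
        if label is None:
            print('skip missing label:', label_path)
            continue
        # Save the npz file (os.path.join adds the missing "/" between folder and name)
        stem = os.path.splitext(name)[0]
        np.savez(os.path.join(save_path, stem + '.npz'), image=image, label=label)

npz('./img_datas/train/image/', './img_datas/train/label/', './data/Synapse/train_npz')

npz('./img_datas/test/image/', './img_datas/test/label/', './data/Synapse/test_vol_h5')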
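To confirm the conversion worked, you can open one of the generated files and look at the stored arrays. This is a small sketch using only NumPy; `describe_npz` is a hypothetical helper name. For RGB input, `image` should be (H, W, 3) uint8 and `label` should be (H, W) uint8.

```python
import numpy as np

def describe_npz(path):
    """Return {array name: (shape, dtype string)} for every array stored in an .npz file."""
    with np.load(path) as data:
        return {key: (data[key].shape, str(data[key].dtype)) for key in data.files}

if __name__ == "__main__":
    import glob
    files = sorted(glob.glob("./data/Synapse/train_npz/*.npz"))
    if files:
        print(files[0], describe_npz(files[0]))
    else:
        print("no .npz files found")
```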

3. Create a file named txt.py in the SwinUnet root directory and run it.

Its purpose is to generate the ./lists/lists_Synapse/train.txt and ./lists/lists_Synapse/test_vol.txt files, which list the sample names (without the .npz extension), one per line.

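import os

def write_name(npz_dir, txt_path):
    # Folder containing the .npz files
    files = sorted(os.listdir(npz_dir))
    # Make sure the folder for the .txt list exists
    os.makedirs(os.path.dirname(txt_path), exist_ok=True)
    # Path of the .txt file; "with" closes it when done
    with open(txt_path, 'w') as f:
        for i in files:
            if i.endswith('.npz'):
                # Write the name without the .npz extension
                f.write(i[:-4] + '\n')

write_name('./data/Synapse/train_npz', './lists/lists_Synapse/train.txt')
write_name('./data/Synapse/test_vol_h5', './lists/lists_Synapse/test_vol.txt')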
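A mismatch between the list files and the .npz folders only shows up later as a file-not-found error during training, so a quick consistency check can save time. A sketch, assuming the paths used above; `missing_entries` is a hypothetical helper name:

```python
import os

def missing_entries(list_file, npz_dir):
    """Return names listed in `list_file` that have no matching `<name>.npz` in `npz_dir`."""
    with open(list_file) as f:
        names = [line.strip() for line in f if line.strip()]
    return [n for n in names if not os.path.isfile(os.path.join(npz_dir, n + ".npz"))]

if __name__ == "__main__":
    pairs = [("./lists/lists_Synapse/train.txt", "./data/Synapse/train_npz"),
             ("./lists/lists_Synapse/test_vol.txt", "./data/Synapse/test_vol_h5")]
    for list_file, npz_dir in pairs:
        if os.path.isfile(list_file):
            print(list_file, "missing:", missing_entries(list_file, npz_dir))
```

An empty "missing" list for both files means every listed sample can be found.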

4. Download the pre-trained weights and put them in the pretrained_ckpt folder under the SwinUnet directory.

Link: https://pan.baidu.com/s/1-hYwJRlr95Fv08e9AEARww
Extraction code: 2022


Modify the network

1. Modify the train.py file

The most important setting is the number of classes (7 here); adjust the other settings to your situation.

2. Modify the ./datasets/dataset_synapse.py file
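The screenshot with the actual change is not reproduced here. As an illustration only, the sketch below shows a dataset that reads the .npz files produced by npz.py through the names in the list files, for either split. `NpzSegDataset` is a hypothetical class, not SwinUnet's own dataset code; the channel-first layout, the division by 255 and the int64 label type are assumptions chosen to suit a PyTorch segmentation model. It defines `__len__` and `__getitem__`, which is what `torch.utils.data.DataLoader` requires of a map-style dataset.

```python
import os
import numpy as np

class NpzSegDataset:
    """Hypothetical map-style dataset over .npz files named in a list file."""

    def __init__(self, data_dir, list_file, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        with open(list_file) as f:
            self.sample_list = [line.strip() for line in f if line.strip()]

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        name = self.sample_list[idx]
        with np.load(os.path.join(self.data_dir, name + ".npz")) as data:
            image, label = data["image"], data["label"]
        # Assumption: RGB image stored as (H, W, 3) uint8 -> (3, H, W) float32 in [0, 1]
        image = image.transpose(2, 0, 1).astype(np.float32) / 255.0
        # Class indices 0..6 as int64, the dtype PyTorch's cross-entropy expects
        label = label.astype(np.int64)
        sample = {"image": image, "label": label}
        if self.transform is not None:
            sample = self.transform(sample)
        sample["case_name"] = name
        return sample
```

For example, `NpzSegDataset('./data/Synapse/train_npz', './lists/lists_Synapse/train.txt')` gives one sample per line of train.txt. Augmentation can be plugged in through `transform`, a callable that takes and returns the sample dict.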

3. Modify the trainer.py file

I'm not sure why this change is needed here.

4. Run the code

These settings can be passed in as command-line arguments; alternatively, write their default values into train.py with default=. Once the defaults are set, simply run python train.py.
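A minimal sketch of the `default=` approach with Python's standard argparse module. The argument names below are assumptions modeled on the options SwinUnet's train.py commonly exposes, and the values are illustrative; check them against your own copy of train.py. If your train.py also sets the number of classes in a per-dataset configuration dictionary, change it there as well, otherwise that value may override `--num_classes`.

```python
import argparse

def build_parser():
    parser = argparse.ArgumentParser()
    # Assumed argument names and illustrative defaults; compare with your train.py
    parser.add_argument("--root_path", type=str, default="./data/Synapse/train_npz")
    parser.add_argument("--dataset", type=str, default="Synapse")
    parser.add_argument("--list_dir", type=str, default="./lists/lists_Synapse")
    parser.add_argument("--num_classes", type=int, default=7)  # 7 classes incl. background
    parser.add_argument("--cfg", type=str,
                        default="configs/swin_tiny_patch4_window7_224_lite.yaml")
    parser.add_argument("--output_dir", type=str, default="./output")
    parser.add_argument("--max_epochs", type=int, default=150)
    parser.add_argument("--batch_size", type=int, default=24)
    parser.add_argument("--img_size", type=int, default=224)
    return parser
```

In train.py, `args = build_parser().parse_args()` then reads the command line: plain `python train.py` uses 7 classes, while `python train.py --num_classes 2 --batch_size 8` overrides both values for a single run.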


Source: blog.csdn.net/weixin_44669966/article/details/125623961