Reference blog post: https://blog.csdn.net/qq_37652891/article/details/123932772
Dataset preparation
The task is multi-class semantic segmentation of remote sensing images, with 7 classes in total (including background).
Example figures (not reproduced here): image, label_rgb, and label. The label image is not all black; its pixel values are the class indices 0, 1, 2, 3, 4, 5, 6. The same data is used for all of the training below.
Data download
Baidu Cloud: https://pan.baidu.com/s/1zZHnZfBgVWxs6TJW4yjeeQ
Extraction code: 2022
SwinUNet code repository
Dataset processing
The dataset provides the images together with two kinds of labels: RGB-format labels and labels whose pixel values are 0, 1, 2, 3, 4, 5, 6. SwinUNet uses the labels containing the values 0, 1, 2, 3, 4, 5, 6.
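If you only have the RGB-format labels, they can be turned into index labels by mapping each colour to its class. The sketch below assumes a hypothetical palette, because the actual colours of label_rgb are not listed in this post; replace PALETTE with your dataset's colours.

```python
import numpy as np

# HYPOTHETICAL palette: replace with the real colour -> class mapping of label_rgb.
PALETTE = {
    (0, 0, 0): 0,
    (255, 0, 0): 1,
    (0, 255, 0): 2,
    (0, 0, 255): 3,
    (255, 255, 0): 4,
    (255, 0, 255): 5,
    (0, 255, 255): 6,
}


def rgb_to_index(label_rgb: np.ndarray) -> np.ndarray:
    """Map an (H, W, 3) RGB label to an (H, W) uint8 array of class indices."""
    index = np.zeros(label_rgb.shape[:2], dtype=np.uint8)
    matched = np.zeros(label_rgb.shape[:2], dtype=bool)
    for colour, cls in PALETTE.items():
        # True where all three channels equal this palette colour
        mask = np.all(label_rgb == np.array(colour, dtype=label_rgb.dtype), axis=-1)
        index[mask] = cls
        matched |= mask
    if not matched.all():
        raise ValueError("label_rgb contains colours that are not in PALETTE")
    return index
```

Note that cv2.imread returns channels in BGR order, so convert with cv2.cvtColor(img, cv2.COLOR_BGR2RGB) before calling rgb_to_index if the palette is written in RGB.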
1. Dataset
The dataset is stored in the img_datas folder under the SwinUNet root directory: the image folder holds the original images and the label folder holds the label images (7 classes in total, so the label values are 0, 1, 2, 3, 4, 5, 6).
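Before converting anything, it can help to confirm that every label really only contains 0 to 6. A minimal sketch using NumPy (the function name check_label_values is hypothetical):

```python
import numpy as np

NUM_CLASSES = 7  # classes 0..6, including background


def check_label_values(label: np.ndarray, num_classes: int = NUM_CLASSES) -> np.ndarray:
    """Return the distinct values in a label map; raise if any lies outside 0..num_classes-1."""
    values = np.unique(label)  # sorted distinct pixel values
    bad = values[(values < 0) | (values >= num_classes)]
    if bad.size:
        raise ValueError(f"unexpected label values: {bad.tolist()}")
    return values
```

For a real file, pass in cv2.imread(label_path, 0), i.e. the label read as a single-channel image.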
If you use another dataset, check the label values. For example, for binary segmentation with labels 0 and 255, the labels must be remapped to 0 and 1.
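One way to do the 0/255 to 0/1 remapping, as a sketch: thresholding (rather than replacing only exact 255 values) is an assumption that also absorbs in-between values from resizing or JPEG compression.

```python
import numpy as np


def binarize_label(label: np.ndarray, threshold: int = 127) -> np.ndarray:
    """Convert a 0/255 mask into a 0/1 mask: pixels above threshold become 1, the rest 0."""
    return (label > threshold).astype(np.uint8)
```

Save the result as PNG rather than JPEG so that the 0/1 values are not altered by lossy compression.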
```
SwinUNet
├── configs
└── img_datas
    ├── train
    │   ├── image
    │   └── label
    └── test
        ├── image
        └── label
```
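The conversion script below pairs each image with the label that has the same file name, so it is worth checking that the image and label folders match. A minimal sketch (unmatched_files is a hypothetical helper):

```python
import os
from typing import List, Tuple


def unmatched_files(image_dir: str, label_dir: str) -> Tuple[List[str], List[str]]:
    """Return (images without a label, labels without an image), matched by file name."""
    images = set(os.listdir(image_dir))
    labels = set(os.listdir(label_dir))
    return sorted(images - labels), sorted(labels - images)
```

For example, unmatched_files('./img_datas/train/image', './img_datas/train/label') should return two empty lists when the training split is consistent.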
2. Create a file npz.py in the SwinUnet root directory and run it:
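```python
import os

import cv2
import numpy as np


def npz(images_path, labels_path, save_path):
    """Pack each image/label pair into one .npz file under save_path."""
    os.makedirs(save_path, exist_ok=True)
    for name in os.listdir(images_path):
        image_path = os.path.join(images_path, name)
        # The label is assumed to have the same file name as the image
        label_path = os.path.join(labels_path, name)
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Read the label as single-channel instead of three-channel
        label = cv2.imread(label_path, flags=0)
        # Save the npz file; os.path.join supplies the "/" missing between folder and name
        stem = os.path.splitext(name)[0]
        np.savez(os.path.join(save_path, stem + ".npz"), image=image, label=label)


npz('./img_datas/train/image/', './img_datas/train/label/', './data/Synapse/train_npz')
npz('./img_datas/test/image/', './img_datas/test/label/', './data/Synapse/test_vol_h5')
```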
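To confirm the conversion worked, one .npz file can be loaded back and its arrays inspected. A small sketch (describe_npz is a hypothetical helper name):

```python
import numpy as np


def describe_npz(path: str) -> dict:
    """Load one .npz written by npz.py and report the shape and dtype of each array."""
    with np.load(path) as data:
        return {key: (data[key].shape, data[key].dtype.name) for key in data.files}
```

For a file in ./data/Synapse/train_npz you would expect an "image" entry of shape (H, W, 3) and a "label" entry of shape (H, W), both uint8.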
3. Create a file txt.py in the SwinUnet root directory and run it. It generates ./lists/lists_Synapse/train.txt and ./lists/lists_Synapse/test_vol.txt:
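```python
import os


def write_name(npz_dir, txt_path):
    # npz file folder: collect the .npz files, sorted for a stable order
    files = sorted(f for f in os.listdir(npz_dir) if f.endswith('.npz'))
    # txt file path: create its folder if needed, then write one name per line
    os.makedirs(os.path.dirname(txt_path), exist_ok=True)
    with open(txt_path, 'w') as f:
        for i in files:
            # Store the sample name without the ".npz" extension
            f.write(i[:-4] + '\n')


write_name('./data/Synapse/train_npz', './lists/lists_Synapse/train.txt')
write_name('./data/Synapse/test_vol_h5', './lists/lists_Synapse/test_vol.txt')
```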
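A quick consistency check is to make sure every name in a list file has a matching .npz file. This sketch assumes the loader looks for name + ".npz" inside the given folder; missing_npz is a hypothetical helper:

```python
import os


def missing_npz(list_path: str, npz_dir: str) -> list:
    """Return the names in list_path that have no matching <name>.npz in npz_dir."""
    with open(list_path) as f:
        names = [line.strip() for line in f if line.strip()]
    return [n for n in names if not os.path.isfile(os.path.join(npz_dir, n + ".npz"))]
```

For example, missing_npz('./lists/lists_Synapse/train.txt', './data/Synapse/train_npz') should return an empty list.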
4. Download the pre-trained weights and put them in the pretrained_ckpt folder under the SwinUnet directory.
Link: https://pan.baidu.com/s/1-hYwJRlr95Fv08e9AEARww
Extraction code: 2022
Modify the network
1. Modify the train.py file.
The most important setting is the number of classes (num_classes, 7 here); adjust the other settings as your situation requires.
2. Modify the ./datasets/dataset_synapse.py file.
3. Modify the trainer.py file.
(I am not sure why this change is needed here.)
4. Run the code
These settings can be passed in as command-line arguments. Alternatively, write them into the argument definitions as default values with default=. Once the defaults are set, simply run python train.py.
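A minimal sketch of the default= approach with Python's standard argparse. The flag names and default paths are assumptions based on the folders used earlier in this post; rename them to match the arguments actually defined in train.py.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # Flag names are assumed; match them to the ones already defined in train.py.
    parser.add_argument('--root_path', type=str, default='./data/Synapse/train_npz',
                        help='folder containing the training .npz files')
    parser.add_argument('--list_dir', type=str, default='./lists/lists_Synapse',
                        help='folder containing train.txt and test_vol.txt')
    parser.add_argument('--num_classes', type=int, default=7,
                        help='number of classes, including background')
    return parser


# An empty argument list means every value falls back to its default=
args = build_parser().parse_args([])
```

With the defaults in place, a command-line flag still overrides them, e.g. python train.py --num_classes 2 for a binary dataset.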