Testing your own dataset with the official SwinUnet code (after training)

****************************************************

Writing this up is not easy. Besides bookmarking, don't forget to leave a like!

****************************************************

---------Start

First, refer to the training process in the previous article, because testing requires the weights produced by training.

1. Check the related files

1.1 Check that test_vol.txt lists the names of the .npz files used for testing

[Figure: contents of test_vol.txt]

[Figure: the .npz files of the test set]
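As a quick sanity check, a short script like the one below can confirm that every name in test_vol.txt has a matching .npz file containing 'image' and 'label' arrays. The paths here are placeholders for your own layout:

    import os
    import numpy as np

    # Placeholder paths -- point these at your own list file and data directory.
    list_file = './lists/lists_Synapse/test_vol.txt'
    data_dir = './datasets/Synapse/test'

    with open(list_file) as f:
        names = [line.strip() for line in f if line.strip()]

    for name in names:
        path = os.path.join(data_dir, name + '.npz')
        assert os.path.exists(path), 'missing file: ' + path
        data = np.load(path)
        # Each test sample should provide an 'image' and a 'label' array.
        print(name, data['image'].shape, data['label'].shape)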

1.2 Check the model weights file

[Figure: the trained model weights file]
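To confirm the checkpoint loads cleanly before running the full test, a minimal check can be done first (the path below is an assumption; use the weights file produced by your own training run):

    import torch

    # Hypothetical checkpoint path from the training run.
    ckpt_path = './output/epoch_149.pth'
    state_dict = torch.load(ckpt_path, map_location='cpu')
    # A plain state_dict is a mapping from parameter names to tensors.
    print(len(state_dict), 'entries loaded')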

2. Modify some code

2.1 Modify dataset_synapse.py

[Figure: the section of dataset_synapse.py to modify]

            slice_name = self.sample_list[idx].strip('\n')
            data_path = os.path.join(self.data_dir, slice_name + '.npz')
            data = np.load(data_path)
            image, label = data['image'], data['label']
            # Changed: convert the numpy arrays to tensors, HWC -> CHW
            image = torch.from_numpy(image.astype(np.float32))
            image = image.permute(2, 0, 1)
            label = torch.from_numpy(label.astype(np.float32))
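For context, this snippet replaces the test branch of `__getitem__` in `Synapse_dataset`. Below is an abbreviated sketch of the whole method after the change; the training branch is elided and the rest follows the repo's dataset_synapse.py, so treat it as orientation rather than a drop-in file:

    def __getitem__(self, idx):
        if self.split == "train":
            ...  # training branch unchanged
        else:
            # Modified: load an .npz volume for testing
            slice_name = self.sample_list[idx].strip('\n')
            data_path = os.path.join(self.data_dir, slice_name + '.npz')
            data = np.load(data_path)
            image, label = data['image'], data['label']
            # Convert the numpy arrays to tensors; HWC -> CHW for the network
            image = torch.from_numpy(image.astype(np.float32)).permute(2, 0, 1)
            label = torch.from_numpy(label.astype(np.float32))

        sample = {'image': image, 'label': label}
        sample['case_name'] = self.sample_list[idx].strip('\n')
        return sample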

2.2 Modify the test.py code

Modify the relevant parameters and file paths. `is_savenii` controls whether the prediction result images are saved.

[Figures: the parameters and file paths to modify in test.py]
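The exact flag names vary between repo versions, so the excerpt below is illustrative rather than authoritative; it shows the kind of parameters that usually need editing (names and defaults here are assumptions):

    import argparse

    parser = argparse.ArgumentParser()
    # Illustrative parameters typically edited in test.py.
    parser.add_argument('--volume_path', type=str, default='./datasets/Synapse/test',
                        help='directory containing the test .npz files')
    parser.add_argument('--list_dir', type=str, default='./lists/lists_Synapse',
                        help='directory containing test_vol.txt')
    parser.add_argument('--num_classes', type=int, default=2,
                        help='background plus foreground classes')
    parser.add_argument('--output_dir', type=str, default='./output',
                        help='directory holding the trained weights')
    parser.add_argument('--is_savenii', action='store_true',
                        help='whether to save prediction results during inference')
    args = parser.parse_args()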

2.3 Modify the util.py code (two cases)

First case: save the raw prediction as a grayscale image, where each pixel's value is the index of the class it was assigned to (for example 0: background, 1: target 1, 2: target 2, ...). Because the class indices are tiny values, the saved image looks almost entirely black.


    # Imports required by this function (add them at the top of util.py if missing)
    import numpy as np
    import torch
    from PIL import Image
    from scipy.ndimage import zoom


    def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_save_path=None, case=None, z_spacing=1):
        image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
        _, x, y = image.shape

        # Resize the image to the network input size (224x224)
        if x != patch_size[0] or y != patch_size[1]:
            image = zoom(image, (1, patch_size[0] / x, patch_size[1] / y), order=3)
        input = torch.from_numpy(image).unsqueeze(0).float().cuda()
        net.eval()
        with torch.no_grad():
            out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0)
            out = out.cpu().detach().numpy()
            # Resize the prediction back to the original image size
            if x != patch_size[0] or y != patch_size[1]:
                prediction = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
            else:
                prediction = out
        metric_list = []
        for i in range(1, classes):
            # calculate_metric_percase is already defined elsewhere in util.py
            metric_list.append(calculate_metric_percase(prediction == i, label == i))

        if test_save_path is not None:
            # Save the prediction as a grayscale class-index image
            prediction = Image.fromarray(np.uint8(prediction)).convert('L')
            prediction.save(test_save_path + '/' + case + '.png')
        return metric_list
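For reference, this function is called once per test case inside test.py's inference loop. The condensed sketch below assumes the loop variables from the repo's test.py (`testloader`, `model`, `args`, `test_save_path`) and is illustrative rather than exact:

    # Illustrative inference loop; testloader, model, args come from test.py.
    metric_sum = 0.0
    for i_batch, sampled_batch in enumerate(testloader):
        image, label = sampled_batch['image'], sampled_batch['label']
        case_name = sampled_batch['case_name'][0]
        metric_i = test_single_volume(image, label, model, classes=args.num_classes,
                                      patch_size=[args.img_size, args.img_size],
                                      test_save_path=test_save_path, case=case_name)
        metric_sum += np.array(metric_i)
    # Average (Dice, HD95) per foreground class over all test cases
    mean_per_class = metric_sum / len(testloader)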

Second case: save a visible image by mapping each class to a different color. Simply replace the `if test_save_path is not None:` block in the code above with the following.

    # Render the different class regions in color
    # Binary segmentation: background is black, class 1 is green
    # (requires `import copy` at the top of util.py)
    if test_save_path is not None:
        a1 = copy.deepcopy(prediction)
        a2 = copy.deepcopy(prediction)
        a3 = copy.deepcopy(prediction)
        # R channel
        a1[a1 == 1] = 0
        # G channel
        a2[a2 == 1] = 255
        # B channel
        a3[a3 == 1] = 0
        a1 = Image.fromarray(np.uint8(a1)).convert('L')
        a2 = Image.fromarray(np.uint8(a2)).convert('L')
        a3 = Image.fromarray(np.uint8(a3)).convert('L')
        prediction = Image.merge('RGB', [a1, a2, a3])
        prediction.save(test_save_path + '/' + case + '.png')
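The three-channel copy approach above is written for a binary task. With more classes, mapping each class index to an RGB color generalizes it; here is a minimal sketch (the color choices and the `colorize` helper are my own, not from the repo):

    import numpy as np
    from PIL import Image

    # Map each class index to an RGB color; index 0 (background) stays black.
    PALETTE = {0: (0, 0, 0), 1: (0, 255, 0), 2: (255, 0, 0), 3: (0, 0, 255)}

    def colorize(prediction):
        """Turn a 2-D array of class indices into a color image."""
        h, w = prediction.shape
        rgb = np.zeros((h, w, 3), dtype=np.uint8)
        for cls, color in PALETTE.items():
            rgb[prediction == cls] = color
        return Image.fromarray(rgb, mode='RGB')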

At this point, the setup is complete. Right-click and run; if the console shows output like the following, everything is working. The weights here were trained for only one epoch, so the predictions are all 0.
[Figure: console output during testing]

3. View the prediction results

View the log file

[Figure: the test log file]

View the prediction result images

[Figure: a saved prediction image]
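Because the grayscale images from the first case store raw class indices, they look nearly black in an ordinary viewer; stretching the values makes the regions visible. A small sketch (the file paths are placeholders):

    import numpy as np
    from PIL import Image

    num_classes = 2  # adjust to your setting
    # Hypothetical path to a prediction saved by test_single_volume
    pred = np.array(Image.open('./predictions/case0001.png'))
    scale = 255 // max(num_classes - 1, 1)
    visible = (pred.astype(np.int32) * scale).astype(np.uint8)
    Image.fromarray(visible).save('./predictions/case0001_visible.png')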

Summary: SwinUnet is built mainly from Swin Transformer blocks, so when the amount of data is small the training results are very poor and cannot compare with TransUnet. Since some operations can only be expressed in words, this can only be a brief description. If you have any questions, leave a comment or send a private message below. Apologies if I am slow to reply, and thank you very much!


Origin blog.csdn.net/qq_37652891/article/details/123938713