# Test script (metric calculation)
#
# Evaluates a trained FCN model on the test set and reports segmentation metrics.

import torch as t
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader

from evalution_segmentaion import eval_semantic_segmentation
from dataset import LoadDataset
from Models import FCN
import cfg

# Select the compute device: GPU if available, otherwise CPU.
device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')
# Number of segmentation classes, taken from the project config.
num_class = cfg.DATASET_Num_Class
# Batch size used for evaluation.
BATCH_SIZE = 4
# History of mIoU values that beat the previous best; seeded with 0.
miou_list = [0]

# Test-set loader. Shuffling is disabled: metrics are averaged per batch, and
# per-batch mIoU depends on which samples share a batch, so a fixed order makes
# repeated runs give identical numbers.
# NOTE(review): with num_workers > 0, platforms whose default multiprocessing
# start method is 'spawn' (Windows, macOS on Python >= 3.8) re-import this
# module in every worker; the script's top-level code should then be moved
# under `if __name__ == '__main__':`.
Load_test = LoadDataset([cfg.TEST_ROOT, cfg.TEST_LABEL], cfg.crop_size)
test_data = DataLoader(Load_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# Build the model, switch it to inference mode, move it to the device, and
# load the trained weights.
net = FCN.FCN(num_class)
net.eval()
net.to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
# (and vice versa) instead of failing on device mismatch.
net.load_state_dict(t.load("./Results/weights/xxx.pth", map_location=device))

# Metric accumulators (running sums over test batches). The ``train_`` prefix
# is historical; these hold test-set values.
train_acc = 0        # sum of per-batch 'mean_class_accuracy'
train_miou = 0       # sum of per-batch 'miou'
train_class_acc = 0  # element-wise sum of complete per-batch 'class_accuracy' vectors
train_mpa = 0        # sum of per-batch 'pixel_accuracy'
# Number of batches whose 'class_accuracy' vector was incomplete and was
# therefore left out of the class-accuracy average.
error = 0

# Inference only: disabling autograd avoids building a computation graph for
# every batch, which saves memory and time.
with t.no_grad():
    for i, sample in enumerate(test_data):
        data = sample['img'].to(device)
        label = sample['label'].to(device)
        out = net(data)
        out = F.log_softmax(out, dim=1)

        # Prediction = index of the maximum along dim 1 (the dimension that
        # log_softmax normalizes over, presumably the class channel). Both
        # arrays are split along their first axis (presumably the batch)
        # into a list, which is what the evaluator is given.
        pre_label = list(out.max(dim=1)[1].cpu().numpy())
        true_label = list(label.cpu().numpy())

        eval_metrix = eval_semantic_segmentation(pre_label, true_label)
        train_acc = eval_metrix['mean_class_accuracy'] + train_acc
        train_miou = eval_metrix['miou'] + train_miou
        train_mpa = eval_metrix['pixel_accuracy'] + train_mpa
        # Validation: a 'class_accuracy' vector shorter than num_class cannot
        # be summed element-wise with full-length ones, so such batches are
        # excluded from the class-accuracy average (and printed as 0 below).
        # NOTE(review): assumes a complete vector has num_class entries (the
        # original hard-coded 12) — confirm against cfg.DATASET_Num_Class.
        if len(eval_metrix['class_accuracy']) < num_class:
            eval_metrix['class_accuracy'] = 0
            error += 1
        else:
            train_class_acc = train_class_acc + eval_metrix['class_accuracy']

        print(eval_metrix['class_accuracy'], '================', i)

# Average the accumulated metrics. acc/miou/mpa were summed over EVERY batch,
# so they are divided by the total batch count; class accuracy was summed only
# over batches with a complete vector, so it is divided by that count instead.
# (Dividing everything by the valid count would inflate acc/miou/mpa.)
num_batches = len(test_data)
if num_batches == 0:
    raise RuntimeError('test set is empty; nothing to evaluate')
num_valid = num_batches - error
test_acc = train_acc / num_batches
test_miou = train_miou / num_batches
test_mpa = train_mpa / num_batches
# NaN instead of ZeroDivisionError when no batch had a complete vector.
test_class_acc = train_class_acc / num_valid if num_valid else float('nan')

epoch_str = ('test_acc :{:.5f} ,test_miou:{:.5f}, test_mpa:{:.5f}, test_class_acc :{:}'.format(
    test_acc, test_miou, test_mpa, test_class_acc))

# Record a new best mIoU, but always report the result: gating the print on
# beating the seed value 0 would silently hide a zero-mIoU run.
if test_miou > max(miou_list):
    miou_list.append(test_miou)
print(epoch_str + '==========last')

# Source: blog.csdn.net/qq_41375318/article/details/108572108