>>>This article builds on PyTorch's official TensorBoard tutorial and refines it further<<<
---- Below is the address of PyTorch's official TensorBoard tutorial. If you are interested, you can refer to it as well.
---- I don't know which TensorBoard features you usually use. Personally, I mainly use four of them:
> ①Save the network's structure diagram: the GRAPHS tab of TensorBoard shows the structure of the model, so every module of the model can be inspected clearly.
> ②Record the training loss, validation accuracy, and learning rate in the SCALARS tab.
> ③Inspect the distribution of each layer's weights in the HISTOGRAMS tab.
> ④Save prediction information for some images: the IMAGES tab stores the prediction results for a given set of images at each step.
>>>The network model used in this article is ResNet. I won't go into its principles or how to build it. Let's move on to the topic<<<
First, create a folder in the project. The images saved in it are predicted during training, and the results are added to TensorBoard. You also need to prepare a label.txt file that lists the label of each image, as shown in the following figure:
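The parsing code in plot_class_preds() expects label.txt to hold one "file_name class_name" pair per line. Below is a minimal sketch of that format, with a hypothetical parse_label_file helper that applies the same two-field check. The image file names are made up, and unlike the original (which splits on single spaces and drops empty pieces), str.split() here also accepts tabs as separators:

```python
import os
import tempfile


def parse_label_file(label_path):
    """Hypothetical helper: return [(file_name, class_name), ...] from label.txt.

    Mirrors the format check in plot_class_preds(): every non-empty line must
    contain exactly two whitespace-separated fields.
    """
    pairs = []
    with open(label_path, "r") as rd:
        for line in rd:
            line = line.strip()
            if not line:
                continue  # blank lines are ignored
            fields = line.split()  # splits on any run of whitespace
            assert len(fields) == 2, "label format error, expect file_name and class_name"
            pairs.append((fields[0], fields[1]))
    return pairs


# Hypothetical label.txt contents: image file name, then class name.
example = "daisy01.jpg daisy\nrose01.jpg roses\n\ntulip01.jpg  tulips\n"

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "label.txt")
    with open(path, "w") as f:
        f.write(example)
    pairs = parse_label_file(path)

print(pairs)
# [('daisy01.jpg', 'daisy'), ('rose01.jpg', 'roses'), ('tulip01.jpg', 'tulips')]
```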
-> How to download ResNet pretrained weights: in the statement import torchvision.models.resnet, Ctrl+click resnet to open its source, where you can find the download links for the weights you want. This experiment does not use pretrained weights, however: with them, acc and loss barely change during training, because accuracy already reaches 97% in the first epoch.
Step-by-step explanation of the TensorBoard statements:
Instantiate the SummaryWriter object. Its parameter is the folder in which the TensorBoard files are saved; when the statement executes, the TensorBoard file is created and saved there automatically.
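A minimal sketch of this step. The folder name runs/flower_experiment is a placeholder (the article's own path isn't shown here), and it is created under a temporary directory so the example runs anywhere:

```python
import os
import tempfile

from torch.utils.tensorboard import SummaryWriter

# Placeholder log folder; the training script would use its own path.
log_dir = os.path.join(tempfile.mkdtemp(), "runs", "flower_experiment")

# The writer creates log_dir if it does not exist and writes its data to an
# events.out.tfevents.* file inside it.
tb_writer = SummaryWriter(log_dir=log_dir)
tb_writer.close()  # flush pending data to disk

print(sorted(os.listdir(log_dir)))
```

With a real log folder such as runs/, the dashboard is started with tensorboard --logdir=runs.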
After the model is instantiated, a zero tensor is also created. Why create it? Adding the network structure diagram requires running a forward pass through the model: the diagram is built by tracing how this tensor propagates forward through the model. Its values therefore do not matter; it only needs to have the same size as the input images.
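A sketch of the graph step. A tiny nn.Sequential stands in for ResNet (any nn.Module is traced the same way), and the 3x224x224 input size and 5 output classes are assumptions:

```python
import tempfile

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

# Hypothetical stand-in for ResNet.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 5),  # 5 output classes (assumed)
)

tb_writer = SummaryWriter(log_dir=tempfile.mkdtemp())

# Dummy input: a batch of one all-zero image shaped like the real inputs.
# add_graph traces the forward pass of this tensor to build the GRAPHS view.
init_img = torch.zeros((1, 3, 224, 224))
tb_writer.add_graph(model, init_img)
tb_writer.close()

# The same tensor passes through the model without errors.
print(model(init_img).shape)  # torch.Size([1, 5])
```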
After each epoch, that is, once the validation code has run, the epoch's average training loss, validation accuracy, and learning rate are saved. -------------Note: tb_writer.add_scalar takes the tag as its first parameter; the second is the value recorded during training, which here is a plain float rather than a tensor; the third is the current training step. tb_writer.add_figure draws the prediction results for the specified images into a figure and saves it in TensorBoard: the first parameter is the figure's title, the second is the figure object, and the third is the current training step.
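The per-epoch logging can be sketched as follows. The model, the metric values, and the figure are placeholders; the add_histogram loop covers item ③ (HISTOGRAMS) and is an assumption about how the weights are logged, not necessarily the author's exact code:

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no display needed
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

model = nn.Linear(4, 2)  # hypothetical stand-in for the real network
log_dir = tempfile.mkdtemp()
tb_writer = SummaryWriter(log_dir=log_dir)
tags = ["train_loss", "accuracy", "learning_rate"]

for epoch in range(3):
    # Placeholder values; in real training these come from the epoch's
    # training and validation loops and from the optimizer.
    mean_loss = 1.0 / (epoch + 1)
    val_acc = 0.5 + 0.1 * epoch
    lr = 0.01

    # add_scalar(tag, scalar_value, global_step): plain Python floats here.
    tb_writer.add_scalar(tags[0], mean_loss, epoch)
    tb_writer.add_scalar(tags[1], val_acc, epoch)
    tb_writer.add_scalar(tags[2], lr, epoch)

    # One histogram per parameter tensor, shown under HISTOGRAMS.
    for name, param in model.named_parameters():
        tb_writer.add_histogram(tag=name, values=param, global_step=epoch)

    # add_figure(tag, figure, global_step): any matplotlib figure is accepted;
    # in the article this would be the figure returned by plot_class_preds().
    fig = plt.figure()
    plt.plot([0, 1], [0, val_acc])
    tb_writer.add_figure("predictions vs. actuals", figure=fig, global_step=epoch)

tb_writer.close()
print(sorted(os.listdir(log_dir)))
```

Because plot_class_preds() returns None when the image folder or label.txt is missing, guarding the call with if fig is not None avoids passing None to add_figure.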
The fig is generated by the plot_class_preds() function. Its transform parameter is the image preprocessing used for the validation set, and the fourth parameter, num_plot, is the number of images to display:
import os
import json
import torch
import matplotlib.pyplot as plt
from PIL import Image


def plot_class_preds(net, images_dir: str, transform, num_plot: int = 5, device="cpu"):
    if not os.path.exists(images_dir):
        print("not found {} path, ignore add figure.".format(images_dir))
        return None
    label_path = os.path.join(images_dir, "label.txt")
    if not os.path.exists(label_path):
        print("not found {} file, ignore add figure".format(label_path))
        return None

    # read class_indict
    json_label_path = './class_indices.json'
    assert os.path.exists(json_label_path), "not found {}".format(json_label_path)
    with open(json_label_path, 'r') as json_file:
        flower_class = json.load(json_file)  # {"0": "daisy"}
    class_indices = dict((v, k) for k, v in flower_class.items())  # {"daisy": "0"}

    # read the label.txt file: one "file_name class_name" pair per line
    label_info = []
    with open(label_path, "r") as rd:
        for line in rd.readlines():
            line = line.strip()
            if len(line) == 0:
                continue
            split_info = [i for i in line.split(" ") if len(i) > 0]
            assert len(split_info) == 2, "label format error, expect file_name and class_name"
            image_name, class_name = split_info
            image_path = os.path.join(images_dir, image_name)
            # skip the entry if the image file does not exist
            if not os.path.exists(image_path):
                print("not found {}, skip.".format(image_path))
                continue
            # skip the entry if its class is not one of the known classes
            if class_name not in class_indices.keys():
                print("unrecognized category {}, skip".format(class_name))
                continue
            label_info.append([image_path, class_name])
    if len(label_info) == 0:
        return None

    # keep only the first num_plot entries
    label_info = label_info[:num_plot]
    num_imgs = len(label_info)
    images = []
    labels = []
    for img_path, class_name in label_info:
        img = Image.open(img_path).convert("RGB")
        labels.append(int(class_indices[class_name]))
        # preprocessing: the same transform as the validation set
        images.append(transform(img))

    # batch the images and run inference
    images = torch.stack(images, dim=0).to(device)
    with torch.no_grad():
        output = net(images)
        probs, preds = torch.max(torch.softmax(output, dim=1), dim=1)
        probs = probs.cpu().numpy()
        preds = preds.cpu().numpy()

    # figsize is (width, height) in inches
    fig = plt.figure(figsize=(num_imgs * 2.5, 3), dpi=100)
    for i in range(num_imgs):
        # 1 row of subplots, num_imgs columns; draw subplot number i+1
        ax = fig.add_subplot(1, num_imgs, i + 1, xticks=[], yticks=[])
        # CHW -> HWC
        npimg = images[i].cpu().numpy().transpose(1, 2, 0)
        # undo the normalization: mean [0.485, 0.456, 0.406], std [0.229, 0.224, 0.225]
        npimg = (npimg * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
        ax.imshow(npimg.clip(0, 255).astype('uint8'))
        title = "{}, {:.2f}%\n(label: {})".format(
            flower_class[str(preds[i])],   # predicted class
            probs[i] * 100,                # predicted probability
            flower_class[str(labels[i])])  # true class
        ax.set_title(title, color=("green" if preds[i] == labels[i] else "red"))
    return fig
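Why the plotting loop multiplies by std and adds mean: assuming the validation transform ends with torchvision's transforms.Normalize using those mean/std values, each channel x is mapped to (x - mean) / std, so the inverse is x * std + mean. A quick NumPy round trip on a random image (shape and seed are arbitrary) checks the formula and shows that the length-3 lists broadcast over the last axis, which is why the tensor is transposed from CHW to HWC first:

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

rng = np.random.default_rng(0)
img_hwc = rng.random((4, 4, 3))  # hypothetical 4x4 RGB image in [0, 1], HWC layout

# What Normalize does per channel (written here for HWC layout).
normalized = (img_hwc - mean) / std

# What the plotting code does to undo it; mean/std broadcast over the
# last (channel) axis.
restored = (normalized * std + mean) * 255
pixels = restored.clip(0, 255).astype("uint8")

print(np.allclose(restored / 255, img_hwc))  # True
```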