Faster-RCNN code interpretation 3: Make your own data loader

Foreword

I have recently been trying to reproduce Faster-RCNN. I am not yet able to write all of the code from scratch, so instead I am working through someone else's implementation and writing up my own interpretation of it.

The code comes from a very capable Bilibili uploader who published it on GitHub. Since the code is open source, linking to it here should not be a problem:

Bilibili link: https://www.bilibili.com/video/BV1of4y1m7nj/?vd_source=afeab8b555e5eb1bfa1e7f267262cbf2

GitHub link: https://github.com/WZMIAOMIAO/deep-learning-for-image-processing

Purpose

The uploader has already made a good video explaining his code, but I sometimes prefer learning from blog posts, and the video is about 6 hours long, so I tend to fall asleep while watching it. I therefore plan to write blog posts to record my study notes.

What's done so far

Part 1: Detailed introduction of the VOC dataset

Part 2: Faster-RCNN code interpretation 2: Getting started quickly

Part 3: Faster-RCNN code interpretation 3: Make your own data loader (this article)

Directory Structure

1. Introduction:

This part is relatively simple (especially if you have read my earlier image classification loader write-up or have implemented one yourself): it comes down to defining a dataset class.

2. Interpretation of my_dataset.py file:

To define your own dataset class, you first inherit from torch's Dataset class, and you need to define at least three methods: __init__, __len__ and __getitem__.

The general framework can then be written as:

from torch.utils.data import Dataset


class VOCDataSet(Dataset):
    """Read and parse the PASCAL VOC2007/2012 dataset."""

    def __init__(self):
        pass

    def __len__(self):
        pass

    def __getitem__(self, idx):
        pass

Okay, let's implement them one by one.

2.1 init method:

First, we need to define the input parameters. If we were implementing this from scratch, we would have to work out which parameters are needed. Since we are interpreting existing code, we can look directly at the parameters the author defined:

  • voc_root: the root directory where the dataset is located
  • year: whether to read the 2007 or the 2012 dataset; the default is 2012
  • transforms: the preprocessing method; the default is None
  • txt_name: which split file to load (train or val); the default is the training set, i.e. train.txt

The first step makes the code more tolerant: check that the incoming parameters are valid, then build the required paths:

# year must be "2007" or "2012"; anything else raises an error
assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
# be tolerant about how the root path is given
if "VOCdevkit" in voc_root:
    # if the argument is .\VOCdevkit, join it directly into .\VOCdevkit\VOC2012
    self.root = os.path.join(voc_root, f"VOC{year}")
else:
    # if the argument is . , join it into .\VOCdevkit\VOC2012
    self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
# build the image directory and the annotation directory
self.img_root = os.path.join(self.root, "JPEGImages")
self.annotations_root = os.path.join(self.root, "Annotations")
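
The path handling above can be tried on its own. The following is a minimal sketch rather than the author's code: resolve_voc_root is a hypothetical helper name, and the directory name "data" is made up for illustration.

```python
import os


def resolve_voc_root(voc_root: str, year: str = "2012") -> str:
    """Return the VOC<year> directory for either accepted form of voc_root.

    Hypothetical helper that mirrors the branch used in __init__.
    """
    assert year in ["2007", "2012"], "year must be in ['2007', '2012']"
    if "VOCdevkit" in voc_root:
        # voc_root already contains VOCdevkit
        return os.path.join(voc_root, f"VOC{year}")
    # voc_root is the parent directory of VOCdevkit
    return os.path.join(voc_root, "VOCdevkit", f"VOC{year}")


# Both ways of passing the same location resolve to the same directory.
from_parent = resolve_voc_root("data", "2012")
from_devkit = resolve_voc_root(os.path.join("data", "VOCdevkit"), "2012")
print(from_parent)
print(from_devkit)
```

Note that the check is a plain substring test, so any path that merely contains the text VOCdevkit takes the first branch.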

The second step is to read the split txt file (train or val) under .\VOCdevkit\VOC2012\ImageSets\Main (if you don't know why, see the first article, the introduction to the VOC dataset), and to join each name in it with the .xml suffix to obtain the annotation file paths for that split:

# read the train or val split file
txt_path = os.path.join(self.root, "ImageSets", "Main", txt_name)
assert os.path.exists(txt_path), "not found {} file.".format(txt_name)
# join each file name (e.g. 2007_000027) with the .xml suffix to get the real annotation path
with open(txt_path) as read:
    xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                for line in read.readlines() if len(line.strip()) > 0]
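
To see what this list comprehension produces, here is a small self-contained sketch. It uses an in-memory file instead of a real train.txt; build_xml_list is a hypothetical helper name, and the second image id is only an illustrative stand-in.

```python
import io
import os


def build_xml_list(split_file, annotations_root):
    """Turn the lines of a split file into annotation paths, skipping blank lines."""
    return [os.path.join(annotations_root, line.strip() + ".xml")
            for line in split_file if len(line.strip()) > 0]


# Stand-in for train.txt: two image ids, one blank line, some trailing spaces.
fake_split = io.StringIO("2007_000027\n\n2007_000032  \n")
xml_paths = build_xml_list(fake_split, "Annotations")
for path in xml_paths:
    print(path)
```

The blank line is dropped and the trailing spaces are stripped, so only two annotation paths remain.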

The third step is to read the xml files one by one and convert their contents into dictionaries. The main purpose is to check that each xml file is usable:

# the list of annotation files that pass the checks
self.xml_list = []
# check that every xml file exists and read its content
for xml_path in xml_list:
    if os.path.exists(xml_path) is False:
        print(f"Warning: not found '{xml_path}', skip this annotation file.")
        continue
    # the xml file exists, so check it for targets
    with open(xml_path) as fid:
        xml_str = fid.read()
    # build the xml object
    xml = etree.fromstring(xml_str)
    # convert the node content into a dict and take everything under the annotation node
    data = self.parse_xml_to_dict(xml)["annotation"]
    # an annotation without any object node is unusable, so skip it
    if "object" not in data:
        print(f"INFO: no objects in {xml_path}, skip this annotation file.")
        continue
    self.xml_list.append(xml_path)
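
The effect of this filter can be illustrated with two tiny made-up annotations. This sketch is a simplification under stated assumptions: it uses the standard library's xml.etree.ElementTree instead of lxml, and it looks up the object node directly rather than going through parse_xml_to_dict. The helper name has_objects is hypothetical.

```python
import xml.etree.ElementTree as ET  # stdlib stand-in for lxml's etree


def has_objects(xml_str: str) -> bool:
    """Return True if the annotation contains at least one <object> node."""
    root = ET.fromstring(xml_str)
    return root.find("object") is not None


with_object = """<annotation>
  <filename>a.jpg</filename>
  <object><name>person</name></object>
</annotation>"""
without_object = "<annotation><filename>b.jpg</filename></annotation>"

# Keep only the annotations that actually describe an object.
candidates = [("a.xml", with_object), ("b.xml", without_object)]
kept = [name for name, xml_str in candidates if has_objects(xml_str)]
print(kept)
```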

​ The fourth step is to load the category json file and read the contents inside:

# load the class file: 20 classes in total, numbered from 1 because 0 is reserved for the background
json_file = './pascal_voc_classes.json'
assert os.path.exists(json_file), "{} file not exist.".format(json_file)
with open(json_file, 'r') as f:
    self.class_dict = json.load(f)
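
For orientation, the class file maps class names to integer ids. The sketch below uses an inline excerpt instead of the real file. Only the person → 15 entry is taken from this article; the other two entries are assumed for illustration.

```python
import json

# Illustrative excerpt only; the real file lists all 20 VOC classes.
json_text = '{"aeroplane": 1, "bicycle": 2, "person": 15}'
class_dict = json.loads(json_text)

# Name -> id, as used when building labels in __getitem__.
print(class_dict["person"])

# Id -> name, handy for turning predicted ids back into readable labels.
index_to_name = {index: name for name, index in class_dict.items()}
print(index_to_name[15])
```

Because id 0 is reserved for the background, it never appears as a value in the dictionary.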

​ Finally, put the preprocessing function into a variable:

self.transforms = transforms

**Summary:** After the processing above, we have several main variables:

  • self.xml_list: the file paths of the xml annotation files for the chosen split
  • self.transforms: the preprocessing method
  • self.class_dict: the class dictionary, in the form {'person': 15}

The debugger shows the contents of these variables:

[Figure: debugger view of the variables]

2.2 len method:

The len method is the simplest one. It returns the number of samples in the dataset:

def __len__(self):
    # return the number of usable annotation files
    return len(self.xml_list)

2.3 getitem method:

This method is as important as the init method. Its job is to return an image together with the label information that belongs to it.

def __getitem__(self, idx):
    pass

Here, idx is the required parameter of this method. It is the index of the sample being requested, supplied by the data loader (in random order when shuffling is enabled), and it is used to look up values in the variables defined in the init method.
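
The relationship between the index and the dataset can be pictured with a toy example. This is a simplified sketch and not how torch's DataLoader is actually implemented; ToyDataset and its file names are hypothetical.

```python
import random


class ToyDataset:
    """Hypothetical stand-in for VOCDataSet holding a few fake annotation paths."""

    def __init__(self, xml_list):
        self.xml_list = xml_list

    def __len__(self):
        return len(self.xml_list)

    def __getitem__(self, idx):
        # idx is chosen by the caller; the dataset only looks the item up
        return self.xml_list[idx]


dataset = ToyDataset(["a.xml", "b.xml", "c.xml"])

# Roughly what a shuffling loader does: pick an order of indices, then index.
indices = list(range(len(dataset)))
random.shuffle(indices)
visited = [dataset[i] for i in indices]
print(sorted(visited))
```

Every sample is visited exactly once per pass; only the order changes.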

First, take the xml file at that index, open it, and read the contents under the root node:

# read the xml file at index idx
xml_path = self.xml_list[idx]
with open(xml_path) as fid:
    xml_str = fid.read()
# build the xml object
xml = etree.fromstring(xml_str)
# take the root node and convert it into a dict
data = self.parse_xml_to_dict(xml)["annotation"]

Here, data holds everything under the annotation node of the xml file, as shown in the following figure:

[Figure: the annotation node of an xml file]

The debugger shows the actual value:

[Figure: debugger view of data]

Next, **each xml file corresponds to one image, and its filename node stores the image name,** so read the image name from the parsed xml and open the image:

# build the path of the image that belongs to this xml file
img_path = os.path.join(self.img_root, data["filename"])
# open the image
image = Image.open(img_path)
# make sure the image is a JPEG; the author guards against other file types slipping in
if image.format != "JPEG":
    raise ValueError("Image '{}' format not JPEG".format(img_path))

Next, initialize some variables:

# initialize the containers
boxes = []      # bounding boxes
labels = []     # class labels
iscrowd = []    # whether each object is marked as difficult

Now the main loop begins.

First, iterate over the content under the object nodes of the xml file:

# iterate over the content under the object nodes of the xml file
for obj in data["object"]:

Here, obj is the value shown in the following figure:

[Figure: debugger view of obj]

You can also see it in the xml file:

[Figure: an object node in the xml file]

Next, get the coordinates of the object's ground-truth bounding box (top-left and bottom-right corners). Note that all of the following code sits inside the for loop above:

# get the coordinates of the bounding box
xmin = float(obj["bndbox"]["xmin"])
xmax = float(obj["bndbox"]["xmax"])
ymin = float(obj["bndbox"]["ymin"])
ymax = float(obj["bndbox"]["ymax"])

​ Check if there is a problem with the bounding box:

# further data check: some annotations may have w or h equal to 0, which makes the regression loss NaN
if xmax <= xmin or ymax <= ymin:
    print("Warning: in '{}' xml, there are some bbox w/h <=0".format(xml_path))
    continue

Then append the coordinates to boxes and the label to labels, check whether the object is marked as difficult, and append that flag to iscrowd:

boxes.append([xmin, ymin, xmax, ymax])
# add the label: obj["name"] = person, self.class_dict[obj["name"]] = 15
labels.append(self.class_dict[obj["name"]])
# record whether the object is marked as difficult
if "difficult" in obj:
    iscrowd.append(int(obj["difficult"]))
else:
    iscrowd.append(0)
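
Putting the loop body together, the following sketch runs the same steps on made-up object dictionaries. The function name collect_targets and the sample values are hypothetical; the dictionaries only imitate the shape that parse_xml_to_dict produces.

```python
def collect_targets(objects, class_dict):
    """Build boxes, labels and iscrowd lists from parsed <object> dicts."""
    boxes, labels, iscrowd = [], [], []
    for obj in objects:
        xmin = float(obj["bndbox"]["xmin"])
        xmax = float(obj["bndbox"]["xmax"])
        ymin = float(obj["bndbox"]["ymin"])
        ymax = float(obj["bndbox"]["ymax"])
        if xmax <= xmin or ymax <= ymin:
            continue  # zero or negative width/height: skip this box
        boxes.append([xmin, ymin, xmax, ymax])
        labels.append(class_dict[obj["name"]])
        # a missing "difficult" field counts as 0
        iscrowd.append(int(obj.get("difficult", 0)))
    return boxes, labels, iscrowd


# Made-up objects; all values are strings, as they come out of the xml.
objects = [
    {"name": "person", "difficult": "0",
     "bndbox": {"xmin": "10", "ymin": "20", "xmax": "110", "ymax": "220"}},
    {"name": "person",
     "bndbox": {"xmin": "50", "ymin": "60", "xmax": "50", "ymax": "90"}},  # width 0
]
boxes, labels, iscrowd = collect_targets(objects, {"person": 15})
print(boxes, labels, iscrowd)
```

The second object has zero width, so it is skipped and only one box, one label and one flag remain.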

​ Then, convert all variable types to tensor format (the loop has ended at this point):

# convert everything to tensors
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)
iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
image_id = torch.tensor([idx])

Next, compute the area of each bounding box from its four coordinates, mainly so that later steps such as IoU computation can use it:

# boxes = [[xmin, ymin, xmax, ymax], [xmin, ymin, xmax, ymax], ...]
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
# (ymax - ymin) * (xmax - xmin), i.e. the area of each box
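
As a quick worked example of this formula, here is the same computation on plain Python lists, so it runs without torch. The two boxes are made up.

```python
# Two made-up boxes in [xmin, ymin, xmax, ymax] order.
boxes = [[10.0, 20.0, 110.0, 220.0],
         [0.0, 0.0, 50.0, 30.0]]

# Same formula as the tensor expression, applied row by row:
# (ymax - ymin) * (xmax - xmin)
area = [(b[3] - b[1]) * (b[2] - b[0]) for b in boxes]
print(area)  # [20000.0, 1500.0]
```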

Finally, put all of the values above into a dictionary:

# put everything into one dictionary
target = {}
target["boxes"] = boxes
target["labels"] = labels
target["image_id"] = image_id
target["area"] = area
target["iscrowd"] = iscrowd

​ Then, preprocess the image and return the image and its corresponding value:

# apply the transforms; these are the author's own implementations, not the official ones
if self.transforms is not None:
    image, target = self.transforms(image, target)
return image, target
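
The transforms here take both the image and the target because a geometric change to the image must also be applied to the boxes. The sketch below illustrates that calling convention with a hypothetical FlipHorizontally class. It is not the author's implementation, and it represents the image as a plain list of pixel rows so that no image library is needed.

```python
class FlipHorizontally:
    """Hypothetical transform: mirror the image and its boxes left to right."""

    def __call__(self, image, target):
        width = len(image[0])
        # reverse every pixel row
        flipped_image = [row[::-1] for row in image]
        # mirror the x coordinates; xmin and xmax swap roles
        flipped_boxes = [[width - xmax, ymin, width - xmin, ymax]
                         for xmin, ymin, xmax, ymax in target["boxes"]]
        new_target = dict(target, boxes=flipped_boxes)
        return flipped_image, new_target


# A made-up 4x2 "image" with one box covering its left-most column.
image = [[1, 2, 3, 4],
         [5, 6, 7, 8]]
target = {"boxes": [[0, 0, 1, 2]], "labels": [15]}

transform = FlipHorizontally()
image, target = transform(image, target)
print(image)
print(target["boxes"])
```

After the flip the box covers the right-most column, while the labels are left unchanged.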

Finally, the debugger shows the values of these variables:

[Figure: debugger view of image and target]

2.4 Auxiliary method: get_height_and_width

Role: get the height and width of the image.

This one is very simple: the values are read directly from the xml file, so we do not need to compute them ourselves:

def get_height_and_width(self, idx):
    # get the height and width of the image
    # read the xml file
    xml_path = self.xml_list[idx]
    with open(xml_path) as fid:
        xml_str = fid.read()
    # build the xml object
    xml = etree.fromstring(xml_str)
    # take the root node
    data = self.parse_xml_to_dict(xml)["annotation"]
    # read the height and width
    data_height = int(data["size"]["height"])
    data_width = int(data["size"]["width"])
    return data_height, data_width

2.5 Auxiliary method: parse_xml_to_dict

Main function: parse xml data into a dictionary, that is, convert each node of the form <node>value</node> into {'node': 'value'}.

This method is implemented recursively. There is not much more to say; to see exactly how it works, trace it step by step on a small xml file:

def parse_xml_to_dict(self, xml):
    """
    Parse an xml tree into a dictionary; based on TensorFlow's recursive_parse_xml_to_dict.
    """

    if len(xml) == 0:  # reached a leaf node: return the tag and its text
        # xml.tag is the node name
        # xml.text is the node value
        return {xml.tag: xml.text}

    result = {}
    # for every child node of this xml node
    for child in xml:
        child_result = self.parse_xml_to_dict(child)  # recurse into the child
        if child.tag != 'object':
            result[child.tag] = child_result[child.tag]
        else:
            if child.tag not in result:  # there may be several objects, so store them in a list
                result[child.tag] = []
            result[child.tag].append(child_result[child.tag])
    return {xml.tag: result}
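
To trace the recursion, the same logic can be run as a standalone function on a small made-up annotation. This sketch assumes the standard library's xml.etree.ElementTree in place of lxml; the sample xml and its values are invented for illustration.

```python
import xml.etree.ElementTree as ET  # stdlib stand-in for lxml's etree


def parse_xml_to_dict(xml):
    """Recursively convert an XML element into nested dicts (objects go in a list)."""
    if len(xml) == 0:  # leaf node: map the tag to its text
        return {xml.tag: xml.text}

    result = {}
    for child in xml:
        child_result = parse_xml_to_dict(child)
        if child.tag != "object":
            result[child.tag] = child_result[child.tag]
        else:
            # several <object> nodes may exist, so collect them in a list
            if child.tag not in result:
                result[child.tag] = []
            result[child.tag].append(child_result[child.tag])
    return {xml.tag: result}


# Made-up annotation with two objects.
xml_str = """<annotation>
  <filename>example.jpg</filename>
  <size><width>500</width><height>375</height></size>
  <object><name>person</name><difficult>0</difficult></object>
  <object><name>dog</name><difficult>1</difficult></object>
</annotation>"""

data = parse_xml_to_dict(ET.fromstring(xml_str))["annotation"]
print(data["filename"])
print(data["size"])
print([obj["name"] for obj in data["object"]])
```

All leaf values come back as strings, which is why the dataset code wraps them in float() or int() before using them.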

2.6 Auxiliary method: coco_index

This method does the same work as the getitem method, except that it does not read the image, so I won't go into details.

3. Summary:

The my_dataset.py file implements the dataset class used for data loading. The idea is simple, but there is still a fair amount of code.

In addition, the author includes sample code for using this class at the end of the file. You can uncomment it and run it to see the result:

[Figure: output of the sample code]

Origin blog.csdn.net/weixin_46676835/article/details/130166036