从DETR backbone 的NestedTensor 到DataLoader, Sampler,collate_fn,再到DETR transformer

1.背景:

在DETR中backbone中,resnet50 的构建继承了backbonebase的类,backbonebase的前向过程如下,这里引入了NestedTensor类。

     # 前向中输入的是NestedTensor这个类的实例,实质就是将图像张量与对应的mask封装到一起。
    def forward(self, tensor_list: NestedTensor):
        xs = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for name, x in xs.items():

            m = tensor_list.mask
            assert m is not None
            # 将mask插值到与输出特征图一致
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, mask)
        return out

NestedTensor,包括tensor和mask两个成员,tensor就是输入的图像。mask跟tensor同高宽但是单通道。

DETR把resnet作为backbone套到了另一个子网络里,这个子网络主要是把tensor list送进resnet网络,然后逐个提取出来其中的节点(也就是里面的Tensor),把每个节点的“mask”提出来做一次采样,然后再打包进自定义的“NestedTensor”中,按照“名称”:Tensor的方式存入输出的out。(这个NestedTensor一个Tensor里打包存了两个变量:x和mask)。

2. DETR网络下NestedTensor的前世今生示例:

2.1 输入

假如我们输入的是如下两张图片,也就说batch为2:
img1 = torch.rand(3, 200, 200),
img2 = torch.rand(3, 200, 250)

x = nested_tensor_from_tensor_list([torch.rand(3, 200, 200), torch.rand(3, 200, 250)])

这里会转成nested_tensor, 为什么要转为nested_tensor呢?

这个nestd_tensor的类型简单说就是把{tensor, mask}打包在一起, tensor就是我们的图片的值,那么mask是什么呢?

当一个batch中的图片大小不一样的时候,我们要把它们处理的整齐,简单说就是把图片都padding成最大的尺寸,padding的方式就是补零,那么batch中的每一张图都有一个mask矩阵,所以mask大小为[2, 200,250], 在img有值的地方是1,补零的地方是0,tensor大小为[2,3,200,250]是经过padding后的。

2.2 提取特征

DETR 提取特征,是把NestedTensor中的tensor, 也就是图片输入到特征提取器中。这里使用的是残差网络resnet-50,tensor经过backbone后的结果就是[2,2048,7,8],下面是残差网络最后一层的结构

(2): Bottleneck(
(conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): FrozenBatchNorm2d()
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): FrozenBatchNorm2d()
(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): FrozenBatchNorm2d()
(relu): ReLU(inplace=True)

另外,关于NestedTensor中的mask, mask采用的方式F.interpolate,最后得到的结果是[2,7,8],backboneBase的前向过程如下:

class BackboneBase(nn.Module):

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        for name, parameter in backbone.named_parameters():
            if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
                parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {'layer4': "0"}

        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):

        xs = self.body(tensor_list.tensors)

        out: Dict[str, NestedTensor] = {}

        for name, x in xs.items():

            m = tensor_list.mask

            assert m is not None

            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]

            out[name] = NestedTensor(x, mask)

        return out

Featuremap的位置编码,position_embedding 的前向如下:

利用三角函数的方式获取position_embedding,输入是上面的NestedTensor={tensor,mask}, 输出最终pos的size为[1,2,256,7,8]

def forward(self, tensor_list: NestedTensor):
        #tensor_list的类型是NestedTensor,内部自动附加了mask,用于表示动态shape,是pytorch中tensor新特性
        x = tensor_list.tensors
        mask = tensor_list.mask
        assert mask is not None
        not_mask = ~mask
        #因为图像是2d的,所以位置编码也分为x,y方向
        # 1 1 1 1 ..  2 2 2 2... 3 3 3...
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        # 1 2 3 4 ... 1 2 3 4...
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
            
		#num_pos_feats = 128
		## 0~127 self.num_pos_feats=128,因为前面输入向量是256,编码是一半sin,一半cos
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        
		## 输出shape=b,h,w,128
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        # 每个特征图的xy位置都编码成256的向量,其中前128是y方向编码,而128是x方向编码
        return pos
        ## b,n=256,h,w

backbone+ position_embedding 中的NestedTensor流程如下:最终输出为

NestedTensor{tensor,mask},和pos。

tensor=[ 2, 2048,7,8],mask=[2,7,8], pos=[1,2,256,7,8]

class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for name, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self[1](x).to(x.tensors.dtype))

        return out, pos

以上是DETR中关于NestedTensor输入在resnet50 backbone以及position_embedding中的历程。

3.COCO格式数据集下的NestedTensor的来龙去脉(一):

我们找到输入数据封装为NesteTensor类型的最初:

data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, collate_fn=utils.collate_fn, num_workers=args.num_workers)
data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val, drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)

在Dataloder中涉及两个重要的参数,Sample()和collate_fn()。

3.1 Dataloder:数据预处理DataLoader及各参数详解

pytorch关于数据处理的功能模块均在torch.utils.data 中,pytorch输入数据PipeLine一般遵循一个“三步走”的策略,操作顺序是这样的:

① 继承Dataset类,自定义数据处理类。必须重载实现len()、getitem()这两个方法。
其中__len__返回数据集样本的数量,而__getitem__应该编写支持数据集索引的函数,例如通过dataset[i]可以得到数据集中的第i+1个数据。在实现自定义类时,一般需要对图像数据做增强处理,和标签处理,__getitem__返回图像和对应label,图像增强的方法可以使用pytorch自带的torchvision.transforms内模块,也可以使用自定义或者其他第三方增强库。

② 导入 DataLoader类,传入参数(上面自定义类的对象) 创建一个DataLoader对象。

③ 循环遍历这个 DataLoader 对象。将img, label加载到模型中进行训练

dataset = MyDataset()           # 第一步:构造Dataset对象
dataloader = DataLoader(dataset)# 第二步:通过DataLoader来构造迭代对象

num_epoches = 100
for epoch in range(num_epoches):# 第三步:逐步迭代数据
    for img, label in dataloader:
        # 训练代码

 pytorch内部默认的数据处理类有如下:

class Dataset(object):

class IterableDataset(Dataset):

class TensorDataset(Dataset): #  封装成tensor的数据集,每一个样本都通过索引张量来获得。

class ConcatDataset(Dataset): #  连接不同的数据集以构成更大的新数据集

class Subset(Dataset):  # 获取指定一个索引序列对应的子数据集

class ChainDataset(IterableDataset):

DataLoader类详解,数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代。

class DataLoader(object):
    Arguments:
        dataset (Dataset): 是一个DataSet对象,表示需要加载的数据集.三步走第一步创建的对象
        batch_size (int, optional): 每一个batch加载多少个样本,即指定batch_size,默认是 1 
        shuffle (bool, optional): 布尔值True或者是False ,表示每一个epoch之后是否对样本进行随机打乱,默认是False
------------------------------------------------------------------------------------
        sampler (Sampler, optional): 自定义从数据集中抽取样本的策略,如果指定这个参数,那么shuffle必须为False
        batch_sampler (Sampler, optional): 此参数很少使用,与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥)
------------------------------------------------------------------------------------
        num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)
        collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数
        pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.默认是False
------------------------------------------------------------------------------------
        drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了,如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。
------------------------------------------------------------------------------------

以上使用的是dataset, batch_size, shuffe, sampler, num_workers, collate_fn, pin_memory,这几个参数。

3.2 sampler参数

sampler参数其实就是一个“采样器”,表示从样本中究竟如何取样,pytorch采样器有如下几个

class Sampler(object):

class SequentialSampler(Sampler):# 顺序采样样本,始终按照同一个顺序。

class RandomSampler(Sampler): # 无放回地随机采样样本元素。

class SubsetRandomSampler(Sampler): # 无放回地按照给定的索引列表采样样本元素

class WeightedRandomSampler(Sampler): # 按照给定的概率来采样样本。

class BatchSampler(Sampler):  # 在一个batch中封装一个其他的采样器。

# torch.utils.data.distributed.DistributedSampler
class DistributedSampler(Sampler): # 采样器可以约束数据加载进数据集的子集。

 默认是采用的采样器如下:

if batch_sampler is None:  # 没有手动传入batch_sampler参数时
    if sampler is None:  # 没有手动传入sampler参数时
        if shuffle:
            sampler = RandomSampler(dataset)  # 随机采样
        else:
            sampler = SequentialSampler(dataset)  # 顺序采样
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
self.__initialized = True

3.3 collate_fn 参数

当继承Dataset类自定义类时,__getitem__方法一般返回一组类似于(image,label)的一个样本,在创建DataLoader类的对象时,collate_fn函数会将batch_size个样本整理成一个batch样本,便于批量训练。

default_collate(batch)中的参数就是这里的 [self.dataset[i] for i in indices],indices是从所有样本的索引中选取的batch_size个索引,表示本次批量获取这些样本进行训练。self.dataset[i]就是自定义Dataset子类中__getitem__返回的结果。

默认的函数default_collate(batch) 只能对大小相同image的batch_size个image整理,如[(img0, label0), (img1, label1),(img2, label2), ] 整理成([img0,img1,img2,], [label0,label1,label2,]), 这里要求多个img的size相同。

当我们的图像大小不同时,需要自定义函数callate_fn来将batch个图像整理成统一大小的,若读取的数据有(img, box, label)这种也需要自定义,因为默认只能处理(img,label)。当然你可以提前将数据集全部整理成统一大小的。

例:目标检测时的自定义collate_fn(),给每个图像添加索引

def collate_fn(self, batch):
        paths, imgs, targets = list(zip(*batch))
        # Remove empty placeholder targets  
        # 有可能__getitem__返回的图像是None, 所以需要过滤掉
        targets = [boxes for boxes in targets if boxes is not None]
        # Add sample index to targets
        # boxes是每张图像上的目标框,但是每个图片上目标框数量不一样呢,所以需要给这些框添加上索引,对应到是哪个图像上的框。
        for i, boxes in enumerate(targets):
            boxes[:, 0] = i
        targets = torch.cat(targets, 0)
        # Selects new image size every tenth batch
        if self.multiscale and self.batch_count % 10 == 0:
            self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))
        # Resize images to input shape
        # 每个图像大小不同呢,所以resize到统一大小
        imgs = torch.stack([resize(img, self.img_size) for img in imgs])
        self.batch_count += 1
        return paths, imgs, targets

其实也可结合使用默认的default_collate

from torch.utils.data.dataloader import default_collate  # 导入这个函数


def collate_fn(batch):
    """
    params:
        batch :是一个列表,列表的长度是 batch_size
               列表的每一个元素是 (x,y) 这样的元组tuple,元祖的两个元素分别是x,y
               大致的格式如下 [(x1,y1),(x2,y2),(x3,y3)...(xn,yn)]
    returns:
        整理之后的新的batch

    """

    # 这一部分是对 batch 进行重新 “校对、整理”的代码

    return default_collate(batch) #返回校对之后的batch,一般就直接推荐使用default_collate
进行包装,因为它里面有很多功能,比如将numpy转化成tensor等操作,这是必须的。

tip: 在使用pytorch时,当加载数据训练for i, batch in enumerate(train_loader):时,可能会出现TypeError: ‘NoneType’ object is not callable这个错误,若遇到更换pytorch版本即可.

4.COCO格式数据集下的NestedTensor的来龙去脉(二):

collate_fn 方法来重新组装一个batch的数据:

它的作用是将一个batch的数据重新组装为自定义的形式,输入参数batch就是原始的一个batch数据,通常在Pytorch中的Dataloader中,会将一个batch的数据组装为((data1, label1), (data2, label2), ...)这样的形式,于是第一行代码的作用就是将其变为[(data1, data2, data3, ...),(label1, label2, label3,...)]这样的形式,然后取出batch[0]即一个batch的图像输入到nested_tensor _from_tensor_list()方法中进行处理,最后将返回结果替代原始的这一个batch图像数据。

def collate_fn(batch):
    batch = list(zip(*batch))
    batch[0] = nested_tensor_from_tensor_list(batch[0])
    return tuple(batch)

为了能够统一batch中所有图像的尺寸,以便形成一个batch,我们需要得到其中的最大尺度(在所有维度上),然后对尺度较小的图像进行填充(padding),同时设置mask以指示哪些部分是padding得来的,以便后续模型能够在有效区域内去学习目标,相当于加入了一部分先验知识。

nested_tensor_from_tensor_list(tensor_list: List[Tensor])实现如下:

def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general

    #得到一个batch中所有图像张量每个维度的最大尺寸

    if tensor_list[0].ndim == 3:

        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])

        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))

        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)

        #指示图像中哪些位置是padding部分
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)

        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
           
             #原始图像中有效部分设为false,以区分padding
            m[: img.shape[1], :img.shape[2]] = False
          

    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)

如何得到batch中每张图像在每个维度上的最大值。

def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes

5. DETR网络结构一览:

The forward expects a NestedTensor, which consists of:
               - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
               - samples.mask: a binary mask of shape [batch_size x H x W],  containing 1 on padded pixels

class DETR(nn.Module):
    """ This is the DETR module that performs object detection """
    def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):
        """ Initializes the model.
        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            num_classes: number of object classes
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.backbone = backbone
        self.aux_loss = aux_loss

    def forward(self, samples: NestedTensor):
        """ The forward expects a NestedTensor, which consists of:
               - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
               - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
            It returns a dict with the following elements:
               - "pred_logits": the classification logits (including no-object) for all queries.
                                Shape= [batch_size x num_queries x (num_classes + 1)]
               - "pred_boxes": The normalized boxes coordinates for all queries, represented as
                               (center_x, center_y, height, width). These values are normalized in [0, 1],
                               relative to the size of each individual image (disregarding possible padding).
                               See PostProcess for information on how to retrieve the unnormalized bounding box.
               - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
                                dictionnaries containing the two above keys for each decoder layer.
        """
        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples) # backbone是一个CNN用于特征提取

        src, mask = features[-1].decompose() #??
        assert mask is not None
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]  # 这里是吧features的其中一部分信息作为src传进Transformer,input_proj是一个卷积层,用来收缩输入的维度,把维度控制到d_model的尺寸(model dimension)

        outputs_class = self.class_embed(hs)  # 为了把Transformer应用于目标检测问题上,作者引入了“类别嵌入网络”和“框嵌入网络”
        outputs_coord = self.bbox_embed(hs).sigmoid()  # 在框嵌入后加入一层sigmoid输出框坐标(原论文中提到是四点坐标,但是要考虑到原图片的尺寸)
        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_boxes': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

6.NestedTensor 在transformer中的摸爬滚打:

transformer encoder

接上第2节,输入NestdTensor 经backbone和positionembedding变为[tensor,mask,pos]后的故事:目前我们拥有src=[ 2, 2048,7,8],mask=[2,7,8], pos=[1,2,256,7,8]

hs = transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]

 # transformer 的输出是元组,分别为Decoder 和Encoder 的输出,因此这里取第一个代表的是Decoder的输出

input_proj:一个卷积层,卷积核为1*1,将压缩通道的作用,将2048压缩到256,所以传入transformer的维度是压缩后的[2,256,7,8]。

 self.input_proj = nn.Conv2d(num_channels, hidden_dim, kernel_size=1)
        # input_proj是将CNN提取的特征维度映射到Transformer隐层的维度,
src, mask = features[-1].decompose()
#取backbone最后一层featuremap, 然后将特征图映射为序列形式

看DETR的前向过程:

 # 前向输入是一个NestedTensor类的对象
    def forward(self, samples, postprocessors=None, targets=None, criterion=None):
        # 首先将样本转换为NestedTensor 对象
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        #第一部分如下所示,先利用CNN提取特征,然后将特征图映射为序列形式,最后输入Transformer进行编、解码得到输出结果。
        # *********************************************************************************
        # 输入到cnn提取特征
        features, pos = self.backbone(samples)   #todo list

        num = self.args.layer1_num

        src, mask = features[num].decompose()
        # 然后将特征图映射为序列形式
        assert mask is not None
        # transformer 的输出是元组,分别为Decoder 和Encoder 的输出,因此这里取第一个代表的是Decoder的输出[0]
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[num])[0]
        # 将query_embedding 的权重作为参数输入到Transformer的前向过程,使用时与position encoding的方式相同,直接相加。

        # 第二部分对输出的维度进行转化,与分类和回归任务所要求的相对应
        # 生成分类与回归的预测结果
        outputs_class = self.class_embed(hs)
        outputs_coord = self.lines_embed(hs).sigmoid()
        # 由于hs包含了Transformer中Decoder每层的输出,因此索引为-1 代表去最后一层的输出
        out = {'pred_logits': outputs_class[-1], 'pred_lines': outputs_coord[-1]}
        # 若指定要计算Decoder 每层预测输出对应的loss,则记录对应的输出结果
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        return out

接着看Transformer:

class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()
		# encode
		# 单层
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        # 由6个单层组成整个encoder
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
		#decode
        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

在进行encoder之前先还有个处理:

bs, c, h, w = src.shape# 这个和我们上面说的一样[2,256,7,8]
src = src.flatten(2).permute(2, 0, 1) # src转为[56,2,256]
pos_embed = pos_embed.flatten(2).permute(2, 0, 1)# pos_embed 转为[56,2,256]
mask = mask.flatten(1) #mask 转为[2,56]

encoder的输入为:src, mask, pos_embed

 q = k = self.with_pos_embed(src, pos)# pos + src
 src2 = self.self_attn(q, k, value=src, key_padding_mask=mask)[0]
 #做self_attention,这个不懂的需要补一下transfomer的知识
 src = src + self.dropout1(src2)# 类似于残差网络的加法
 src = self.norm1(src)# norm,这个不是batchnorm,很简单不在详述
 src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))#两个ffn
 src = src + self.dropout2(src2)# 同上残差加法
 src = self.norm2(src)# norm
 return src

单层的输出依然为src[56, 2, 256],第二个单层的输入依然是:src, mask, pos_embed。循环往复6次结束encoder,得到输出memory, memory的size依然为[56, 2, 256].

Decoder的输入:

tgt = torch.zeros_like(query_embed)
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
                  pos=pos_embed, query_pos=query_embed)
                   

query_embed其实是一个varible,size=[100,2,256],由训练得到,结束后就固定下来了。

class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):

        # tgt 是query embedding,shape是(num_queries,b,hidden_dim)
        # query_pos 是对应tgt的位置编码,shanpe是和tgt一致
        # memory是Encode的输出,shape是(h×w,b,hidden_dim)
        # memory_key_padding_mask 对应encoder的src_key_padding_mask,也是EncoderLayer的key_padding_mask,shape是(b,h×w)
        # pos对应输入到Encoder的位置编码,这里代表memory的位置编码,shape和memory一致。


        output = tgt

        intermediate = []

        #  intermediate = []中记录的是每一层输出后的归一化结果,而每一层的输入是前一层输出(没有归一化)的结果

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        # self.norm 是通过初始化时传进来的参数norm(默认none)设置的,那么self.norm就有可能是none,故以下对此作了判断。

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)

猜你喜欢

转载自blog.csdn.net/qq_35831906/article/details/124524455