このデモでは、事前にトレーニングされた重みに基づいて DETR の簡易バージョンを実装し、画像を予測し、最後に予測効果を示します。
まず、必要な関連ライブラリをインポートします。
from PIL import Image
import requests
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'
import torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
# 推理过程,不需要梯度
torch.set_grad_enabled(False);
まずデモモデルの完全なコードを表示します
class DETRdemo(nn.Module):
"""
DETR的一个基本实现
此演示Demo与论文中的完整DETR模型有一下不同点:
* 使用的是可学习位置编码(论文中使用的是正余弦位置编码)
* 位置编码在输入端传入 (原文在注意力模块传入)
* 采用fc bbox 预测器 (没有采用MLP)
该模型在 COCO val5k 上达到约 40 AP,在 Tesla V100 上以约 28 FPS 的速度运行。
仅支持批量大小为1。
"""
def __init__(self, num_classes, hidden_dim=256, nheads=8,
num_encoder_layers=6, num_decoder_layers=6):
# hidden_dim: 隐藏状态的神经单元个数,也就是隐藏层的节点数,应该可以按计算需要“随意”设置。
super().__init__()
# create ResNet-50 backbone
# 创建Resnet50
self.backbone = resnet50()
# 删除最后的全连接层
del self.backbone.fc
# create conversion layer
# 将骨干网络的输出特征图维度映射到Transformer输入所需的维度
self.conv = nn.Conv2d(2048, hidden_dim, 1)
# create a default PyTorch transformer
# nheads代表多头注意力的"头数"
self.transformer = nn.Transformer(
hidden_dim, nheads, num_encoder_layers, num_decoder_layers)
# prediction heads, one extra class for predicting non-empty slots
# num_classes需要在原有来别数量上多加一个non-empty类
self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
# note that in baseline DETR linear_bbox layer is 3-layer MLP
# 标准DETR模型中最后的输出层由三个全连接层构成而非一个全连接层
# bbox的形式是(x,y,w,h),因此是四维
self.linear_bbox = nn.Linear(hidden_dim, 4)
# output positional encodings (object queries)
# 用于解码器输入的位置编码,100代表最终解码出100个物体
# 即对一张图片(最多)检测出100个物体
self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
# spatial positional encodings
# note that in baseline DETR we use sine positional encodings
# 用于编码器输入的位置编码
# 对特征图的行、列分别进行位置编码,而后会将两者结果拼接
# 因此维度格式hidden_dim的一半,前128是x后128是y
# nn.Parameter() 在指定Tensor中随机生成参数
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
def forward(self, inputs):
# propagate inputs through ResNet-50 up to avg-pool layer
x = self.backbone.conv1(inputs)
x = self.backbone.bn1(x)
x = self.backbone.relu(x)
x = self.backbone.maxpool(x)
x = self.backbone.layer1(x)
x = self.backbone.layer2(x)
x = self.backbone.layer3(x)
x = self.backbone.layer4(x)
# convert from 2048 to 256 feature planes for the transformer
# 将backbone的输出维度转换为Transformer输入所需的维度
# h.shape = (1, hidden_dim, H, W)
h = self.conv(x)
# construct positional encodings
H, W = h.shape[-2:]
# Tensor.unsqueeze() 在指定位置插入新维度
# Tensor.repeat() 沿某个维度复制Tensor
# self.col_embed[:W].shape = (W, hidden_dim / 2) hidden_dim = 256
# self.col_embed[:W].unsqueeze(0).shape = (1, W, 128)
# self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1).shape = (H, W, 128)
# torch.cat(...).flatten(0, 1).shape = (HxW, 256)
# torch.cat(...).flatten(0, 1).unsqueeze(1).shape = (HxW, 256, 256)
# pos.shape = (HxW, 1, 256) (HxW, 1, hidden_dim) 这里中间加一维是对应batch的维度
pos = torch.cat([
self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
], dim=-1).flatten(0, 1).unsqueeze(1)
# propagate through the transformer
# 输出到Transformer中h的维度为(HxW, batch, hidden_dim),
# query_pos的维度为(100, 1, hidden_dim)
# Tensor.permute() 按照指定维度顺序对Tonser进行转职
# h.flatten(2).shape = (1, hidden_dim, HxW)
# h.flatten(2).permute(2, 0, 1).shape = (HxW, 1, hidden_dim)
# h.shape = (1, 100, hidden_dim)
h = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),
self.query_pos.unsqueeze(1)).transpose(0, 1)
# finally project transformer outputs to class labels and bounding boxes
# 输出预测物体类别(batch, 100, num_classes + 1)
# 预测的物体bbox(batch, 100, 4)
# 之所以sigmoid是因为回归的是归一化的值
return {
'pred_logits': self.linear_class(h),
'pred_boxes': self.linear_bbox(h).sigmoid()}
モデルは主にバックボーン、トランスフォーマー、最終的に予測出力を形成する線形層で構成されており、さらに、バックボーンが出力する特徴マップの次元を入力に必要な次元にマッピングするために畳み込み層が必要です変圧器の。
Transformer を知っている友人なら、Transformer は入力シーケンス内の各部分の位置関係を理解できないことを知っているはずです。そのため、通常は位置エンコーディングを追加する必要があり、ここでも同様です。
# output positional encodings (object queries)
# 用于解码器输入的位置编码,100代表最终解码出100个物体
# 即对一张图片(最多)检测出100个物体
self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
# spatial positional encodings
# note that in baseline DETR we use sine positional encodings
# 用于编码器输入的位置编码
# 对特征图的行、列分别进行位置编码,而后会将两者结果拼接
# 因此维度格式hidden_dim的一半,前128是x后128是y
# nn.Parameter() 在指定Tensor中随机生成参数
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
上の図では、行コードと列コードの最初の次元は 50 です。これは、ここでのデフォルトのバックボーンによって出力される特徴マップのサイズが 50x50 を超えないことを意味します。
非常に簡潔かつ明確なモデルの初期化メソッドはこれで終わりです。次に、モデルの前方プロセスを見てみましょう。
def forward(self, inputs):
# propagate inputs through ResNet-50 up to avg-pool layer
x = self.backbone.conv1(inputs)
x = self.backbone.bn1(x)
x = self.backbone.relu(x)
x = self.backbone.maxpool(x)
x = self.backbone.layer1(x)
x = self.backbone.layer2(x)
x = self.backbone.layer3(x)
x = self.backbone.layer4(x)
# convert from 2048 to 256 feature planes for the transformer
# 将backbone的输出维度转换为Transformer输入所需的维度
# h.shape = (1, hidden_dim, H, W)
h = self.conv(x)
# construct positional encodings
H, W = h.shape[-2:]
# Tensor.unsqueeze() 在指定位置插入新维度
# Tensor.repeat() 沿某个维度复制Tensor
# self.col_embed[:W].shape = (W, hidden_dim / 2) hidden_dim = 256 # self.col_embed[:W].unsqueeze(0).shape = (1, W, 128) # self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1).shape = (H, W, 128) # torch.cat(...).flatten(0, 1).shape = (HxW, 256) # torch.cat(...).flatten(0, 1).unsqueeze(1).shape = (HxW, 256, 256) # pos.shape = (HxW, 1, 256) (HxW, 1, hidden_dim) 这里中间加一维是对应batch的维度
pos = torch.cat([
self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
], dim=-1).flatten(0, 1).unsqueeze(1)
上図の部分は、画像をバックボーンに入力して特徴を抽出し、出力特徴マップの次元を変換して、位置エンコーディング テンソルを構築する部分です。ここでの位置エンコーディング テンソルの実装は、特徴マップの行と列をエンコードしてそれらを結合し、同時にエンコーダーの入力に適応するために次元変換を実行することです。
以下は、エンコードとデコードのために上記の部分を Transformer に入力し、最後にデコードされた結果を線形層に入力して最終的な予測結果を形成します。
# propagate through the transformer
# 输出到Transformer中h的维度为(HxW, batch, hidden_dim),
# query_pos的维度为(100, 1, hidden_dim)
# Tensor.permute() 按照指定维度顺序对Tonser进行转职
# h.flatten(2).shape = (1, hidden_dim, HxW)
# h.flatten(2).permute(2, 0, 1).shape = (HxW, 1, hidden_dim)
# h.shape = (1, 100, hidden_dim)
h = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),
self.query_pos.unsqueeze(1)).transpose(0, 1)
# finally project transformer outputs to class labels and bounding boxes
# 输出预测物体类别(batch, 100, num_classes + 1)
# 预测的物体bbox(batch, 100, 4)
# 之所以sigmoid是因为回归的是归一化的值
return {
'pred_logits': self.linear_class(h),
'pred_boxes': self.linear_bbox(h).sigmoid()}
上の図では、Transformer の出力次元の順序が調整されているため、h の最終次元は (batch, 100, hidden_dim) になっていることに注意してください。
転送プロセス全体はこれで終わりです。コードを取得する必要があるというプレッシャーはないと思いませんか?
以下は、入力画像と出力 bbox の処理です。
# standard PyTorch mean-std input image normalization
transform = T.Compose([
T.Resize(800),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# for output bounding box post-processing
# 将bbox 的形式由中心坐标和宽高转换为左上角和右下角坐标
def box_cxcywh_to_xyxy(x):
# x_c,y_c,w,h的shape都是(batch,)
x_c, y_c, w, h = x.unbind(1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=1)
def rescale_bboxes(out_bbox, size):
img_w, img_h = size
b = box_cxcywh_to_xyxy(out_bbox)
# 将bbox坐标由归一化的值转换为基于图像尺寸的绝对坐标值
b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
return b
出力されるbboxは、まず中心点と幅と高さの座標から、長方形ボックスの左上隅と右下隅の座標に変換し、同時に正規化された値が返されるので、画像サイズに応じて絶対座標値に変換する必要があります。
ここで、予測結果を取得するための推論プロセス全体をカプセル化するメソッドを定義します。
def detect(im, model, transform):
# mean-std normalize the input image (batch-size: 1)
# 对图像进行预处理并转换为Tensor, 在最外层加一维对应的batch的维度
# 这里是1, 代表对一场图片做检测
img = transform(im).unsqueeze(0)
# demo model only support by default images with aspect ratio between 0.5 and 2
# if you want to use images with an aspect ratio outside this range # rescale your image so that the maximum size is at most 1333 for best results assert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'
# propagate through the model
outputs = model(img)
# keep only predictions with 0.7+ confidence
# 去掉背景那一类, 并且对于模型输出预测的100个物体,
# 只取置信度大于0.7的那批
probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
keep = probas.max(-1).values > 0.7
# convert boxes from [0; 1] to image scales
bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
return probas[keep], bboxes_scaled
ここで言及すべき点があります。torch1.5 バージョンでは、tensor.max() への戻りは torch.return_types.max(values=tensor(xxx), indices=tensor(xxx)) ですが、torch1.0 ではこれです。メソッドの戻り値はタプルです。
COCO データセットのカテゴリが選択され、合計 80 カテゴリが選択されますが、インデックスは 1 ~ 90 です。
# COCO classes
CLASSES = [
'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
'toothbrush'
]
# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
上図の COLORS は、bbox の長方形の色を描画するために使用されます。
COCO のカテゴリ インデックスは 1 ~ 90 なので、モデルをインスタンス化できるようになりました。そのため、num_classes パラメータを 91 に設定する必要があります。
detr = DETRdemo(num_classes=91)
state_dict = torch.hub.load_state_dict_from_url(
url='https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth',
map_location='cpu', check_hash=True)
detr.load_state_dict(state_dict)
detr.eval();
OK、すべての準備が整ったので、画像を検出しましょう。
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw)
scores, boxes = detect(im, detr, transform)
# 可以看到置信度大于0.7的只有5个物体
print(scores.shape)
print(boxes.shape)
ご覧のとおり、モデルはこの画像で 5 つのオブジェクトを検出し、最終的に結果を視覚化します。
def plot_results(pil_img, prob, boxes):
plt.figure(figsize=(16,10))
# 现实原始图片
plt.imshow(pil_img)
ax = plt.gca()
for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), COLORS * 100):
# 在图片上画出物体的bbox
ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
fill=False, color=c, linewidth=3))
# 该物体置信度最大的类别
cl = p.argmax()
text = f'{
CLASSES[cl]}: {
p[cl]:0.2f}'
# 在物体bbox的左上角写出其预测类别机器相应的置信度
ax.text(xmin, ymin, text, fontsize=15,
bbox=dict(facecolor='yellow', alpha=0.5))
plt.axis('off')
plt.show()
plot_results(im, scores, boxes)
最終的な効果を以下の図に示します
。