Notes:
1
if pretrained:
    if 'transform_input' not in kwargs:
        kwargs['transform_input'] = True
    if 'aux_logits' in kwargs:
        original_aux_logits = kwargs['aux_logits']
        kwargs['aux_logits'] = False
    else:
        original_aux_logits = False
    # we are loading weights from a pretrained model
    kwargs['init_weights'] = False
    model = Inception3(**kwargs)
    # state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
    #                                       progress=progress)
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)
    if not original_aux_logits:
        model.aux_logits = False
        del model.AuxLogits
    return model
Set every aux-related flag to False and delete the auxiliary classifier (model.AuxLogits).
2
Resize the input to N x 3 x 299 x 299, in RGB color mode.
from PIL import Image

def Inception_loader(path):
    # ANTIALIAS is Pillow's high-quality (Lanczos) downsampling filter
    return Image.open(path).resize((299, 299), Image.ANTIALIAS).convert('RGB')
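A quick self-contained sanity check of the loader. Note that Pillow 10 removed the `Image.ANTIALIAS` name, so this sketch uses `Image.LANCZOS`, which is the same filter; the probe image and file names are mine:

```python
import os
import tempfile
from PIL import Image

def inception_loader(path):
    # Image.LANCZOS is the same high-quality filter that ANTIALIAS aliased
    return Image.open(path).resize((299, 299), Image.LANCZOS).convert('RGB')

# Feed it a throwaway grayscale image of the wrong size and check the result.
with tempfile.TemporaryDirectory() as d:
    probe = os.path.join(d, 'probe.png')
    Image.new('L', (640, 480), color=128).save(probe)
    img = inception_loader(probe)
    print(img.size, img.mode)  # (299, 299) RGB
```

Passing this function as the `loader=` argument of `torchvision.datasets.ImageFolder` then guarantees every sample already has the size and color mode Inception v3 expects.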
3
I ran into a question with the following code:
def _transform_input(self, x):
    if self.transform_input:
        x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
        x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
        x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
        x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
    return x
Is it doing some per-channel computation? I disabled it in my case (transform_input=False). The constants are recognizable, though: 0.485 / 0.456 / 0.406 and 0.229 / 0.224 / 0.225 are the standard ImageNet per-channel mean and std. Assuming x arrives ImageNet-normalized, each line undoes that normalization and re-applies a mean-0.5 / std-0.5 normalization, i.e. the [-1, 1] scaling that the original TensorFlow-trained Inception weights expect: ((x * std + mean) - 0.5) / 0.5 = x * (std / 0.5) + (mean - 0.5) / 0.5.
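That algebra can be verified numerically. A small sketch under the assumption above (all variable names are mine):

```python
import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)  # ImageNet mean
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)   # ImageNet std

raw = torch.rand(2, 3, 8, 8)   # raw pixel values in [0, 1]
x = (raw - mean) / std         # an ImageNet-normalized batch

# the three per-channel lines of _transform_input, written with broadcasting
transformed = x * (std / 0.5) + (mean - 0.5) / 0.5

# TF-style normalization of the same pixels: scale [0, 1] to [-1, 1]
expected = (raw - 0.5) / 0.5
print(torch.allclose(transformed, expected, atol=1e-6))  # True
```

So transform_input is purely a change of normalization convention; if the dataloader already normalizes to [-1, 1] (mean 0.5, std 0.5 per channel), turning it off is the right call.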
(Original draft: night of 2021-01-24)