Pytorch 模型输入报错踩坑笔记

前言

这一篇希望能长期记录一下写 pytorch 时踩到的各种的坑,有些时候偷个小懒一时半会儿不会遇到问题,但会在很久之后背刺(捂脸)。

CSDN 应该都是目的导向型的比较受欢迎,带着问题 google 搜到的文章。这种泛读的经验文章估计又吃灰了hhh。

正文

输入使用 squeeze()

谨慎使用 torch.squeeze,会将所有长度为 1 的维度都抹掉,所以注意补充参数。

下面的场景中,读入的 audio_input 的 shape 为 [batch_size, 1, audio_length] 。额外的维度1是音频加载库从文件中读取音频时自带的声道维度。但是使用的模型默认单声道,如果传入该维度会报错。于是使用下面的代码去除该维度。

audio_input = torch.squeeze(audio_input)

在很长的时间中没有出现任何问题。

直到有一次评估的时候,在最后一个样本的时候,他崩了。因为只有一个样本,所以输入的 shape 为 [1,1, audio_length] ,这导致前两个维度都被 squeeze 了。训练的时候没有遇到这个问题因为会 drop last 。所以建议加上 index 参数,确定 squeeze 第几个维度。例如切换为下面的写法。

audio_input = audio_input.squeeze(1)

猜你喜欢

转载自blog.csdn.net/Haulyn5/article/details/129497169
今日推荐