最近一直在学pytorch,copy了几个经典的入门问题。现在作一下总结。
首先,做的小项目主要有
分类问题:Mnist手写体识别、FashionMnist识别、猫狗大战
语义分割:Unet分割肝脏图像、遥感图像
先把语义分割的心得总结一下,目前只是一部分,以后还会随着学习的深入慢慢往里面加新的感悟。
1)对于二分类问题
1. Unet输出channel:对于二分类问题,类别数为2,channel为1,用uint8的单通道灰度图像表示类别就行(0/1)。
2. label是单通道灰度图像,直接读取。
3. 损失函数:nn.sigmoid + nn.BCELoss / nn.BCEWithLogitsLoss
2)对于多分类问题
1. Unet输出channel: 输出channel是类别数。
2. label是单通道的灰度图像,用不同的灰度级表示不同的类别。具体操作是先在data的类中_get_item_方法里将label进行one hot编码读取(即多通道图片,一个通道一个类别),然后输入Unet进行训练。此时得到的模型输出应该也是one hot编码格式,最后用argmax把多通道图片转换成单通道图片(不同灰度级表示不同类别)。
3. 损失函数:nn.CrossEntorpyLoss来计算。
4. 此外,这种多分类的方法有时候精度相对不高,可以转化成多个二分类问题,最后合成在一起。