李宏毅2022机器学习HW8解析

准备工作

作业八是异常检测(Anomaly Detection),需要助教代码和数据集,运行代码过程中保持联网可以自动下载数据集,已经有数据集的情况可关闭助教代码中的下载数据部分。关注本公众号,可获得代码和数据集(文末有方法)。

提交地址

Kaggle:www.kaggle.com/competitions/ml2022spring-hw8,有想讨论沟通的同学可进QQ群:156013866。以下为作业解析,详细代码见文末。

Simple Baseline (AUC>0.53150)

方法:直接运行助教代码。注意在本地或kaggle上运行时候,需要调整相应的文件名称或者路径。提交kaggle的score是:0.53158 。

Medium Baseline (AUC>0.73171)

方法:CNN模型 + 减少latent dim。助教代码中使用的是VAE模型,这里我使用了CNN模型,并添加了线性层减少中间数据维度。提交kaggle的socre是:0.74087。另外我也微调了FCN模型和VAE模型,使用FCN模型可以很容易得到0.75以上的结果,但是VAE模型只能在0.60左右徘徊,可能是VAE的强大‘脑补’能力导致,也许可通过增加模型大小来解决。

扫描二维码关注公众号,回复: 15655771 查看本文章

Strong Baseline (AUC>0.77196)

方法:ResNet模型+更多epoch+小batch size。模型的encoder换成了ResNet,epoch增加为100,batch size减小为128。提交kaggle的socre是:0.77437。这里我一开始参考助教作业PPT,实验了融合模型:cnn + fcn,发现结果并不理想。

Boss Baseline (Acc>0.79506)

方法:ResNet模型+辅助网络。在strong baseline的基础上,使用了额外的一个decoder辅助网络来,ResNet网络与原来的训练方法一致,decoder网络的损失函数受resnet控制,结果也比resnet更强,提交的文件也是通过该decoder计算出来的,具体过程可以看代码,这里我先不解释为什么这么做,看谁能悟出来,算是一个小彩蛋,李老师的课堂上没讲过这个方法,现有的文章也很少有提及到它,有兴趣的可以在交流群里沟通或者私信我。该方法提交kaggle的socre是:0.79557

作业八答案获得方式:

  1. 关注微信公众号 “机器学习手艺人” 

  2. 后台回复关键词:202208

猜你喜欢

转载自blog.csdn.net/weixin_42369818/article/details/125292835