【附代码】【入门级】多任务分类学习

1.数据获取与处理

使用CIFAR-10[2]数据集,该数据集根据MIT许可证提供。

该数据集由60000张32x32像素的RGB图像组成,分为10个不同的类别。它被分为50000个训练样本和10000个测试样本,并且是完美平衡的,这意味着数据集包含每个类6000个图像。

可以通过执行以下操作轻松加载数据集:

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

数据集包含以下类别:飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。多任务模型要学习的两个任务将是这些标签上的分类,请参见:

  • 任务1:在修改后的CIFAR10数据集上进行多类别分类(飞机、汽车、鸟、猫、狗、青蛙、船和卡车标签,修改说明如下)。

  • 任务2:二分类(标签为动物和载体)。

猜你喜欢

转载自blog.csdn.net/allein_STR/article/details/129509468