# 读取数据
def prepare_dataset(data_dir):
    """Build a shuffled DataLoader over an image-folder dataset.

    Each image is resized to 32x32, converted to a tensor, and normalized
    per channel with the mean/std values given below.

    Args:
        data_dir: Root directory passed to ``datasets.ImageFolder``.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding batches of 16 samples in
        shuffled order.
    """
    # Compose applies its transforms one after another, so they must be
    # supplied as an ordered list. A set literal has no defined iteration
    # order, which could run Normalize before ToTensor.
    data_transforms = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    image_dataset = datasets.ImageFolder(data_dir, data_transforms)
    data_loaders = torch.utils.data.DataLoader(
        image_dataset, batch_size=16, shuffle=True
    )
    return data_loaders
def train():
    """Train a ResNet-18 classifier and save its weights.

    Loads batches via ``prepare_dataset(data_dir)``, optimizes the model with
    Adam (lr=0.0005) against a cross-entropy loss for ``ephchs`` epochs, then
    writes the model's ``state_dict`` to disk.

    Relies on ``data_dir`` and ``ephchs`` being defined outside this function.
    """
    data_loaders = prepare_dataset(data_dir)

    # Define the model.
    # NOTE(review): resnet18() is built with its default output size here;
    # confirm it matches the number of classes in the dataset.
    model = torchvision.models.resnet18()

    # Optimizer.
    optimizer = optim.Adam(model.parameters(), lr=0.0005)

    # Loss function. It must be bound to a name: the training loop calls
    # ``criterion`` on every batch.
    criterion = nn.CrossEntropyLoss()

    for _ in range(ephchs):
        for inputs, labels in data_loaders:
            optimizer.zero_grad()  # reset accumulated gradients
            outputs = model(inputs)  # forward pass
            loss = criterion(outputs, labels)  # compute the loss
            loss.backward()  # backpropagate
            optimizer.step()  # update the weights

    # Persist the trained weights once training has finished.
    torch.save(model.state_dict(), '模型名称')
分类网络搭建思路
猜你喜欢
转载自blog.csdn.net/weixin_47681965/article/details/125722229
今日推荐
周排行