分类网络搭建思路

读取数据
def prepare_dataset(data_dir, batch_size=16, shuffle=True, image_size=(32, 32)):
    """Build a DataLoader over an image-classification directory.

    Images are loaded with ``datasets.ImageFolder``. Each image is resized to
    ``image_size``, converted to a tensor, and normalized with the commonly
    used ImageNet per-channel mean/std.

    Args:
        data_dir: Root directory in the layout ``ImageFolder`` expects
            (one sub-directory per class).
        batch_size: Samples per batch (default 16, as before).
        shuffle: Whether the loader reshuffles each epoch (default True).
        image_size: Target size passed to ``transforms.Resize``
            (default ``(32, 32)``, as before).

    Returns:
        A ``torch.utils.data.DataLoader`` over the ``ImageFolder`` dataset.
    """
    # Compose applies its transforms in sequence order, so this must be a list.
    # The original used a set literal, which has no defined order. Normalize
    # only accepts tensors, so it fails whenever it runs before ToTensor.
    data_transforms = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image_dataset = datasets.ImageFolder(data_dir, data_transforms)
    return torch.utils.data.DataLoader(
        image_dataset, batch_size=batch_size, shuffle=shuffle
    )

def _module_setting(name):
    """Return the module-level global *name*, raising NameError if it is unset."""
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"name {name!r} is not defined; pass it to train() explicitly"
        ) from None


def train(data_dir=None, epochs=None, save_path='模型名称', lr=0.0005):
    """Train a ResNet-18 classifier on an ImageFolder dataset and save its weights.

    Args:
        data_dir: Dataset root handed to ``prepare_dataset``. If omitted, the
            module-level ``data_dir`` is used (the original script read that global).
        epochs: Number of passes over the dataset. If omitted, the module-level
            ``ephchs`` is used (the original script read that global).
        save_path: Destination of ``torch.save(model.state_dict(), ...)``.
            The default is the original placeholder ('模型名称', "model name");
            pass a real file path.
        lr: Adam learning rate (default 0.0005, as before).

    Raises:
        NameError: If ``data_dir``/``epochs`` is omitted and the matching
            module-level global is not defined.
    """
    if data_dir is None:
        data_dir = _module_setting('data_dir')
    if epochs is None:
        epochs = _module_setting('ephchs')

    data_loaders = prepare_dataset(data_dir)

    # Model: size the classification head to the dataset's class count.
    # Otherwise torchvision's default 1000-way ImageNet head would be used.
    num_classes = len(data_loaders.dataset.classes)
    model = torchvision.models.resnet18(num_classes=num_classes)
    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Loss function. The original built this but never bound it to `criterion`.
    criterion = nn.CrossEntropyLoss()

    model.train()
    for _ in range(epochs):
        for inputs, labels in data_loaders:
            optimizer.zero_grad()              # clear gradients from the previous step
            outputs = model(inputs)            # forward pass
            loss = criterion(outputs, labels)  # compute the loss
            loss.backward()                    # backpropagate
            optimizer.step()                   # update the weights

    torch.save(model.state_dict(), save_path)

猜你喜欢

转载自blog.csdn.net/weixin_47681965/article/details/125722229
今日推荐