Lessons learned from writing a picture-captioning AI project in PyTorch (learning from GitHub)

Reposted from blog.csdn.net/weixin_41817576/article/details/94629732

Contents

1 Overall project architecture
1.1 File names and their roles
1.2 Overall approach to creating the project
2 Lessons learned for each file
2.1 The main script main.py
2.1.1 Command-line arguments: argparse.ArgumentParser
2.1.2 General structure of main.py
2.1.3 Setting up the available GPU environment
2.1.4 The main function, carrying the core logic
2.1.5 The train function: training
2.1.6 The validate function: validation
2.1.7 if __name__ == '__main__': remaining setup outside the main logic
2.2 The model: model.py
2.3 Preprocessing: prepro.py
2.4 Data loading: data_loader.py
2.5 Utilities: utils.py
2.6 Building the validation dataset
2.7 Other general lessons
2.7.1 Comments
2.7.2 nohup


1 Overall project architecture

1.1 File names and their roles

Python file names and their purposes:

main.py — the main script
model.py — the model
prepro.py — preprocessing
data_loader.py — dataset-loading helpers
flickr8k_dataloader.py — dataset-loading helper specific to Flickr8k
compute_mean_val.py — computes the mean and standard deviation of the dataset images
utils.py — utility functions
make_val_dataset.py — builds the validation dataset

1.2 Overall approach to creating the project

  1. Start by writing main.py and use it to sort out the overall flow. When you hit a variable that does not exist yet, pretend it is already defined, leave a marker, skip the details, and keep writing; this lays out the overall logic, and the markers keep main.py consistent with the helper files you write next.
  2. While writing main.py you will discover which model, utility, and preprocessing helper files are needed, and you will understand the required functional interfaces from a top-down view.
  3. Write the preprocessing file prepro.py, transforming the data into the model-input format that the relevant parts of main.py require.
  4. Write the model file model.py. Based on the preprocessed data format and the underlying theory (a paper you read, or an idea of your own), build the model in PyTorch. When you need the data-loading class, do as in main.py: leave a marker, skip the details, and keep writing.
  5. Write data_loader.py. Based on what the model in model.py requires of its input data, build your own data-loading class on top of PyTorch's torch.utils.data.DataLoader.
  6. Write the utility file utils.py. Gather the general-purpose functionality needed by the main, preprocessing, data-loading, and model files (especially the parts you skipped earlier) into this one file.
  7. Write make_val_dataset.py to build the validation set, usually by selecting some samples out of the training data. Ideally the selected samples no longer take part in training, so that validation stays objective and fair (a minimal sketch of this idea follows below).
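Since the section on make_val_dataset.py is cut off at the end of this post, here is only a minimal sketch of the idea in step 7, assuming the Flickr8k annotation format used later in flickr8k_dataloader.py (lines like "id.jpg#rank<tab>caption"); the helper name and the ratio are made up for illustration:

    import random

    def split_annotations(ann_file, val_ratio=0.1, seed=42):
        """Hypothetical helper: hold out a fraction of the images for validation."""
        with open(ann_file, 'r') as f:
            lines = f.readlines()
        # group caption lines by image id, so no image appears in both splits
        by_image = {}
        for line in lines:
            image_id = line.split('.jpg#')[0]
            by_image.setdefault(image_id, []).append(line)
        ids = sorted(by_image)
        random.Random(seed).shuffle(ids)
        val_ids = set(ids[:int(len(ids) * val_ratio)])
        train_lines = [l for i in ids if i not in val_ids for l in by_image[i]]
        val_lines = [l for i in ids if i in val_ids for l in by_image[i]]
        return train_lines, val_lines

The two returned line lists can then be written to separate annotation files, and the held-out images never enter the training loop.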

2 Lessons learned for each file

2.1 The main script main.py

2.1.1 Command-line arguments: argparse.ArgumentParser

First comes building the command-line arguments, defined as follows:

    import argparse

    parser = argparse.ArgumentParser()  # command-line argument parser
    parser.add_argument(
        '--model_path',           # argument name
        type=str,                 # type
        default='./models/',      # default value
        help='path for saving trained models')  # help text
    # define the other command-line arguments ...
    args = parser.parse_args()  # parse the command-line arguments
    print(args)                 # print the arguments for inspection

Recommendations:

  1. Put this at module level (global scope), which makes it easy to reference everywhere.
  2. Always build command-line arguments with argparse.ArgumentParser(); it is standardized and concise, and every option is self-explanatory at a glance.

When using an argument:

    model_path = args.model_path

and then refer to model_path. Of course, if the value is not reused later, you can simply use args.model_path directly.
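As a quick sanity check, parse_args also accepts an explicit argument list, which is handy for testing the parser without touching the real command line (a generic illustration, not from the original project):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='./models/',
                        help='path for saving trained models')

    # equivalent to running: python main.py --model_path ./checkpoints/
    args = parser.parse_args(['--model_path', './checkpoints/'])
    print(args.model_path)  # -> ./checkpoints/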

2.1.2 General structure of main.py

  • imports
  • setting up the available GPU environment
  • the main function, carrying the core logic
  • the train function (training)
  • the validate function (validation)
  • if __name__ == '__main__': the remaining setup outside the main logic

2.1.3 Setting up the available GPU environment

When a GPU is available there is generally only one; if there are several, pick number x by writing cuda:x. When no GPU is available, fall back to CPU mode.

    device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

When using it, move variables, models, or computations onto the GPU:

    imgs = imgs.to(device)                        # move the images to the GPU
    decoder = decoder.to(device)                  # move the decoder to the GPU
    criterion = nn.CrossEntropyLoss().to(device)  # move the loss computation to the GPU

2.1.4 The main function, carrying the core logic

Part 1: Preloading

Usually a few things need to be loaded first, such as the vocabulary or a saved model:

    # load the vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

This uses the pickle module:

    import pickle
  
  

Part 2: Pre-defining variables

Before training, some variables have to be created up front, and this splits on whether there has been a previous training run.

a) If there has, load the previously saved checkpoint (usually a dict saved with torch).

An example of saving (usually a standalone function in utils.py):

    def save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer, bleu4,
                        is_best):
        """
        Saves model checkpoint.
        :param data_name: base name of processed dataset
        :param epoch: epoch number
        :param epochs_since_improvement: number of epochs since last improvement in BLEU-4 score
        :param encoder: encoder model
        :param decoder: decoder model
        :param encoder_optimizer: optimizer to update encoder's weights, if fine-tuning
        :param decoder_optimizer: optimizer to update decoder's weights
        :param bleu4: validation BLEU-4 score for this epoch
        :param is_best: is this checkpoint the best so far?
        """
        state = {
            'epoch': epoch,
            'epochs_since_improvement': epochs_since_improvement,
            'bleu-4': bleu4,
            'encoder': encoder,
            'decoder': decoder,
            'encoder_optimizer': encoder_optimizer,
            'decoder_optimizer': decoder_optimizer
        }
        filename = 'checkpoint_' + data_name + '.pth.tar'
        torch.save(state, filename)
        # if this checkpoint is the best so far, save a copy so it cannot be overwritten by a worse one
        if is_best:
            torch.save(state, 'BEST_' + filename)

When loading, use torch.load; the result is a dict, read with ordinary key access:

    checkpoint = torch.load(args.checkpoint)
    start_epoch = checkpoint['epoch'] + 1
    epochs_since_improvement = checkpoint['epochs_since_improvement']
    best_bleu4 = checkpoint['bleu-4']
    decoder = checkpoint['decoder']
    decoder_optimizer = checkpoint['decoder_optimizer']
    encoder = checkpoint['encoder']
    encoder_optimizer = checkpoint['encoder_optimizer']
    if fine_tune_encoder is True and encoder_optimizer is None:
        encoder.fine_tune(fine_tune_encoder)  # enable fine-tuning of the encoder
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=args.encoder_lr)  # encoder optimizer

b) If there has been no previous training run, define fresh variables:

    decoder = AttnDecoderRNN(
        attention_dim=args.attention_dim,
        embed_dim=args.embed_dim,
        decoder_dim=args.decoder_dim,
        vocab_size=len(vocab),
        dropout=args.dropout)  # decoder
    decoder_optimizer = torch.optim.Adam(
        params=filter(lambda p: p.requires_grad, decoder.parameters()),
        lr=args.decoder_lr)  # decoder optimizer
    encoder = EncoderCNN()  # encoder
    encoder.fine_tune(args.fine_tune_encoder)  # toggle encoder fine-tuning
    encoder_optimizer = torch.optim.Adam(
        params=filter(lambda p: p.requires_grad, encoder.parameters()),
        lr=args.encoder_lr) if args.fine_tune_encoder else None  # encoder optimizer
    best_bleu4 = args.best_bleu4

As you can see, filter with a lambda expression is used throughout to pick out only the trainable parameters, and the optimizer of choice is the widely used and robust Adam.

Part 3: The loss function

Next, define the loss function, for example cross entropy:

    criterion = nn.CrossEntropyLoss().to(device)

This uses the import:

    import torch.nn as nn
  
  

Part 4: Dataset loaders

As described earlier, the usual approach is to build your own loader on top of torch.utils.data.DataLoader, e.g.:

    flickr = DataLoader(
        root=root, json=json, vocab=vocab, rank=rank, transform=transform)
    data_loader = torch.utils.data.DataLoader(
        dataset=flickr,
        batch_size=batch_size,
        shuffle=shuffle,          # shuffle the data
        num_workers=num_workers,  # number of subprocesses used for data loading
        collate_fn=collate_fn)

Here the dataset argument is a dataset subclass that inherits from torch.utils.data.Dataset. Subclassing torch.utils.data.Dataset requires implementing two methods:

  1. __getitem__(self, index): supports integer indexing from 0 to len(self) exclusive, i.e. given an index it returns the corresponding sample.
  2. __len__(self): returns the total number of samples.

The concrete implementation:

    class DataLoader(data.Dataset):
        def __init__(self, root, json, vocab, rank, transform=None):
            self.root = root
            self.flickr = flickr8k(
                ann_text_location=json, imgs_location=root, ann_rank=rank)
            self.vocab = vocab
            self.rank = rank
            self.transform = transform

        # supports integer indexing from 0 to len(self) exclusive
        def __getitem__(self, index):
            flickr = self.flickr
            vocab = self.vocab
            # ann: annotation
            caption = flickr.anns[index]['caption']
            img_id = flickr.anns[index]['image_id']
            path = flickr.loadImg(img_id)
            image = Image.open(path).convert('RGB')
            if self.transform is not None:
                image = self.transform(image)
            tokens = nltk.tokenize.word_tokenize(str(caption).lower())  # tokenize
            caption = []
            caption.append(vocab('<start>'))
            caption.extend([vocab(token) for token in tokens])
            caption.append(vocab('<end>'))
            target = torch.Tensor(caption)
            return image, target

        def __len__(self):
            return len(self.flickr.anns)

The collate_fn argument is a custom method for assembling a batch, i.e. what each training iteration receives:

    def collate_fn(data):
        # sort the samples by caption length, longest first
        data.sort(key=lambda x: len(x[1]), reverse=True)
        images, captions = zip(*data)
        images = torch.stack(images, 0)  # stack the image tensors along a new dimension
        lengths = [len(cap) for cap in captions]
        targets = torch.zeros(len(captions), max(lengths)).long()
        for i, cap in enumerate(captions):
            end = lengths[i]
            targets[i, :end] = cap[:end]
        return images, targets, lengths

Each call returns a batch of images, the corresponding captions, and the caption lengths.

With these pieces, wrap everything in our own loader factory get_loader, which returns a DataLoader object for loading data:


  
  
    def get_loader(root, json, vocab, transform, batch_size, rank, shuffle,
                   num_workers):
        flickr = DataLoader(
            root=root, json=json, vocab=vocab, rank=rank, transform=transform)
        # data loader for the flickr dataset
        # each iteration returns (images, captions, lengths)
        # images: tensor of shape (batch_size, 3, 224, 224)
        # captions: tensor of shape (batch_size, padded_length)
        # lengths: list of the valid length of each caption; length is (batch_size)
        data_loader = torch.utils.data.DataLoader(
            dataset=flickr,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=collate_fn)  # merges a list of samples to form a mini-batch
        return data_loader

Then creating our DataLoaders follows naturally:

    train_loader = get_loader(
        args.image_dir,
        args.caption_path,
        vocab,
        transform,
        args.batch_size,
        args.rank,
        shuffle=True,
        num_workers=args.num_workers)  # training data loader
    val_loader = get_loader(
        args.image_dir_val,
        args.caption_path_val,
        vocab,
        transform,
        args.batch_size,
        args.rank,
        shuffle=True,
        num_workers=args.num_workers)  # validation data loader

Part 5: The training and validation loop

The usual pattern is a for loop that sets an upper bound on the number of epochs (the counts could also be handled inside the train and validate functions); each epoch trains, then validates, prints progress along the way, and the final model is saved at the end.

But training can overfit, or go many epochs without any improvement, so consider the following.

1) Set a cap on the number of epochs since the historical best score last improved, and break out of the loop once it is reached, to avoid wasting time:

    if args.epochs_since_improvement == 20:  # no improvement for 20 epochs: stop
        break
    # train
    # validate
    is_best = recent_bleu4 > best_bleu4         # is this epoch the best so far?
    best_bleu4 = max(recent_bleu4, best_bleu4)  # keep the best BLEU-4 value
    if not is_best:  # still no improvement
        args.epochs_since_improvement += 1
        print("\nEpoch since last improvement: %d\n" %
              (args.epochs_since_improvement, ))  # epochs since the last improvement
    else:  # this epoch improved on the best
        args.epochs_since_improvement = 0  # reset the counter

2) When a number of epochs pass without improvement but the exit cap from point 1 has not been reached, consider lowering the learning rate:

    if args.epochs_since_improvement > 0 and args.epochs_since_improvement % 8 == 0:
        adjust_learning_rate(decoder_optimizer, 0.8)  # shrink the decoder learning rate by a fixed factor
        if args.fine_tune_encoder:
            adjust_learning_rate(encoder_optimizer, 0.8)  # shrink the encoder learning rate by a fixed factor

Part 6: Saving the model

Finally, save intermediate checkpoints; in the end you generally keep two models, the final one and the historical best:

    save_checkpoint(args.data_name, epoch, args.epochs_since_improvement,
                    encoder, decoder, encoder_optimizer, decoder_optimizer,
                    recent_bleu4, is_best)  # save a model checkpoint

The custom save_checkpoint function is the one already shown in Part 2 above; it lives in utils.py (see the complete listing in section 2.5).

2.1.5 The train function: training

    def train(train_loader, encoder, decoder, criterion, encoder_optimizer,
              decoder_optimizer, epoch):

First put the encoder and decoder into training mode:

    decoder.train()  # set the decoder module to training mode
    encoder.train()  # set the encoder module to training mode

The next few variables use the AverageMeter class from the utils.py utility file, a helper that tracks a metric's latest value (val), average (avg), sum, and count:

    # AverageMeter tracks a metric's latest value (val), average (avg), sum, and count
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top5accs = AverageMeter()

AverageMeter is defined in utils as:

    class AverageMeter(object):
        """
        Tracks a metric's latest value, average, sum, and count.
        """
        def __init__(self):
            self.reset()

        def reset(self):
            self.val = 0
            self.avg = 0
            self.sum = 0
            self.count = 0

        def update(self, val, n=1):
            self.val = val
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count

Then fetch data from the DataLoader defined earlier:

    for i, (imgs, caps, caplens) in enumerate(train_loader):

The logic after that is roughly:

  1. move the data onto the GPU
  2. predict the outputs
  3. compute the loss
  4. add the regularization term to the loss
  5. clear the optimizers' gradients
  6. back-propagate
  7. step the optimizers
  8. take the 5 largest elements of the scores along the given dimension (for top-5 accuracy)
  9. print progress every so many iterations

The complete train function for reference (note: the decoder call is written here to match the forward(encoder_out, encoded_captions, caption_lengths) signature of the decoder defined in model.py below):

    # training
    def train(train_loader, encoder, decoder, criterion, encoder_optimizer,
              decoder_optimizer, epoch):
        decoder.train()  # set the decoder module to training mode
        encoder.train()  # set the encoder module to training mode
        # AverageMeter tracks a metric's latest value (val), average (avg), sum, and count
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top5accs = AverageMeter()
        start = time.time()  # start timing
        for i, (imgs, caps, caplens) in enumerate(train_loader):
            data_time.update(time.time() - start)
            # move the images and captions to the GPU
            imgs = imgs.to(device)
            caps = caps.to(device)
            imgs = encoder(imgs)  # run the encoder
            scores, caps_sorted, decode_lengths, alphas = decoder(imgs, caps, caplens)  # run the decoder
            scores = pack_padded_sequence(
                scores, decode_lengths, batch_first=True)  # pack a padded batch of variable-length sequences
            # since decoding starts at <start>, the targets are the words after <start>
            targets = caps_sorted[:, 1:]
            targets = pack_padded_sequence(
                targets, decode_lengths, batch_first=True)
            scores = scores.data
            targets = targets.data
            loss = criterion(scores, targets)  # compute the loss with the chosen criterion
            loss += args.alpha_c * ((1. - alphas.sum(dim=1))**2).mean()  # add the regularization term
            decoder_optimizer.zero_grad()  # clear all decoder gradients
            if encoder_optimizer is not None:
                encoder_optimizer.zero_grad()  # clear all encoder gradients
            loss.backward()  # back-propagate the loss
            if args.grad_clip is not None:
                clip_gradient(decoder_optimizer,
                              args.grad_clip)  # clip the gradients computed during backprop, to avoid explosion
                if encoder_optimizer is not None:
                    clip_gradient(encoder_optimizer, args.grad_clip)
            decoder_optimizer.step()  # step the decoder optimizer
            if encoder_optimizer is not None:
                encoder_optimizer.step()  # step the encoder optimizer
            top5 = accuracy(scores, targets, 5)  # top-5 accuracy, from the 5 largest elements along the given dimension
            losses.update(loss.item(), sum(decode_lengths))
            top5accs.update(top5, sum(decode_lengths))
            batch_time.update(time.time() - start)
            start = time.time()
            # time to print a round of logs
            if i % args.log_step == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(
                          epoch,
                          i,
                          len(train_loader),
                          batch_time=batch_time,
                          data_time=data_time,
                          loss=losses,
                          top5=top5accs))
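The pack_padded_sequence calls above are the trick that drops the padded timesteps before the loss is computed. A minimal standalone demonstration of what the .data field contains (my own illustration, assuming a reasonably recent PyTorch):

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence

    # two padded sequences with true lengths 3 and 2, sorted longest first
    padded = torch.tensor([[1, 2, 3],
                           [4, 5, 0]])  # the 0 is padding
    packed = pack_padded_sequence(padded, lengths=[3, 2], batch_first=True)
    print(packed.data)  # tensor([1, 4, 2, 5, 3]) -- only the un-padded elements survive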

2.1.6 The validate function: validation

The validation function is similar; the extra part is computing a BLEU-4 score to evaluate the model.

Key code:

    from nltk.translate.bleu_score import corpus_bleu

    # compute the BLEU-4 score
    bleu4 = corpus_bleu(references, hypotheses)
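Note the nesting corpus_bleu expects: a list of reference captions per image, and one flat hypothesis per image. A toy illustration with made-up token ids:

    from nltk.translate.bleu_score import corpus_bleu

    # one image, one reference caption, hypothesis identical to the reference
    references = [[[2, 8, 5, 9, 3]]]  # corpus -> per-image list -> token list
    hypotheses = [[2, 8, 5, 9, 3]]    # corpus -> token list
    print(corpus_bleu(references, hypotheses))  # 1.0 for a perfect match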

The complete validate function (as in train, the decoder call matches the model.py signature):

    # evaluate on the validation set
    def validate(val_loader, encoder, decoder, criterion):
        """
        Performs one epoch's validation.
        :param val_loader: DataLoader for validation data.
        :param encoder: encoder model
        :param decoder: decoder model
        :param criterion: loss layer
        :return: BLEU-4 score
        """
        decoder.eval()  # set the module to evaluation mode (no dropout or batchnorm)
        if encoder is not None:
            encoder.eval()
        batch_time = AverageMeter()
        losses = AverageMeter()
        top5accs = AverageMeter()
        start = time.time()
        references = list()  # references (true captions) for computing the BLEU-4 score
        hypotheses = list()  # hypotheses (predictions)
        # iterate over batches
        for i, (imgs, caps, caplens) in enumerate(val_loader):
            # move to the GPU
            imgs = imgs.to(device)
            caps = caps.to(device)
            # forward pass
            if encoder is not None:
                imgs = encoder(imgs)
            scores, caps_sorted, decode_lengths, alphas = decoder(imgs, caps, caplens)
            # since decoding starts at <start>, the targets are all the words after <start>, up to <end>
            targets = caps_sorted[:, 1:]
            # remove the timesteps we did not decode at, i.e. the padding;
            # pack_padded_sequence is a handy trick for exactly this
            scores_copy = scores.clone()
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)
            targets = pack_padded_sequence(
                targets, decode_lengths, batch_first=True)
            scores = scores.data
            targets = targets.data
            loss = criterion(scores, targets)  # compute the loss
            # add the doubly stochastic attention regularization
            loss += args.alpha_c * ((1. - alphas.sum(dim=1))**2).mean()
            # track the metrics
            losses.update(loss.item(), sum(decode_lengths))
            top5 = accuracy(scores, targets, 5)
            top5accs.update(top5, sum(decode_lengths))
            batch_time.update(time.time() - start)
            start = time.time()
            if i % args.log_step == 0:
                print('Validation: [{0}/{1}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(
                          i,
                          len(val_loader),
                          batch_time=batch_time,
                          loss=losses,
                          top5=top5accs))
            # store the references (true captions) and hypotheses (predictions) for each image:
            # for n images we want n hypotheses, and the references a, b, c... per image, i.e.
            # references = [[ref1a, ref1b, ref1c], [ref2a, ref2b, ref2c], ...]
            # hypotheses = [hyp1, hyp2, ...]
            # References
            for j in range(caps_sorted.shape[0]):
                img_caps = caps_sorted[j].tolist()
                # keep only real words: drop <start>, <end>, and the padding
                clean_cap = [
                    w for w in img_caps
                    if w not in {vocab('<start>'), vocab('<end>'), vocab('<pad>')}
                ]
                references.append([clean_cap])  # one reference caption per image here
            # Hypotheses
            _, preds = torch.max(scores_copy, dim=2)
            preds = preds.tolist()
            temp_preds = list()
            for j, p in enumerate(preds):
                temp_preds.append(preds[j][:decode_lengths[j]])  # strip the trailing padding
            preds = temp_preds
            hypotheses.extend(preds)
            assert len(references) == len(hypotheses)
        # compute the BLEU-4 score
        bleu4 = corpus_bleu(references, hypotheses)
        print(
            '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}\n'
            .format(loss=losses, top5=top5accs, bleu=bleu4))
        return bleu4

2.1.7 if __name__ == '__main__': the remaining setup outside the main logic

You can also rename the process here, so that on a shared server everyone can see whose job is whose and nobody kills it by accident:

    if __name__ == '__main__':
        setproctitle.setproctitle("张晋豪的python caption flickr8k")
        main(args)

2.2 The model: model.py

This is where the neural network is defined with PyTorch. In the simplest case, just subclass the nn.Module parent class and override the forward method. forward maps each batch of inputs (its parameters) to predicted outputs (its return value).

You can of course define other helper methods as well, such as fine_tune for fine-tuning.
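In skeleton form, the pattern looks like this (a generic illustration, not part of this project):

    import torch.nn as nn

    class MyModel(nn.Module):
        def __init__(self, in_dim, out_dim):
            super(MyModel, self).__init__()
            self.fc = nn.Linear(in_dim, out_dim)  # declare layers in __init__

        def forward(self, x):
            return self.fc(x)  # map a batch of inputs to predictions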

Concrete examples from this project follow.

The CNN encoder:

    class EncoderCNN(nn.Module):
        def __init__(self, encoded_image_size=14):
            super(EncoderCNN, self).__init__()
            resnet = models.resnet101(pretrained=True)
            # children() returns an iterator over the direct submodules
            modules = list(resnet.children())[:-2]
            self.resnet = nn.Sequential(*modules)
            self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size,
                                                       encoded_image_size))
            self.fine_tune()

        def forward(self, images):
            out = self.resnet(images)
            out = self.adaptive_pool(out)
            out = out.permute(0, 2, 3, 1)  # reorder the tensor axes to (batch, H, W, channels)
            return out

        def fine_tune(self, fine_tune=True):
            for p in self.resnet.parameters():
                p.requires_grad = False
            for c in list(self.resnet.children())[5:]:
                for p in c.parameters():
                    p.requires_grad = fine_tune

The attention decoder:

    class AttnDecoderRNN(nn.Module):
        def __init__(self,
                     attention_dim,
                     embed_dim,
                     decoder_dim,
                     vocab_size,
                     encoder_dim=2048,
                     dropout=0.5):
            super(AttnDecoderRNN, self).__init__()
            self.encoder_dim = encoder_dim
            self.attention_dim = attention_dim
            self.embed_dim = embed_dim
            self.decoder_dim = decoder_dim
            self.vocab_size = vocab_size
            self.dropout = dropout
            self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
            self.embedding = nn.Embedding(vocab_size, embed_dim)
            self.dropout = nn.Dropout(p=self.dropout)
            self.decode_step = nn.LSTMCell(
                embed_dim + encoder_dim, decoder_dim, bias=True)
            self.init_h = nn.Linear(encoder_dim, decoder_dim)
            self.init_c = nn.Linear(encoder_dim, decoder_dim)
            self.f_beta = nn.Linear(
                decoder_dim,
                encoder_dim)  # linear layer to create a sigmoid-activated gate
            self.sigmoid = nn.Sigmoid()
            self.fc = nn.Linear(decoder_dim, vocab_size)
            self.init_weights()

        def init_weights(self):
            self.embedding.weight.data.uniform_(-0.1, 0.1)
            self.fc.bias.data.fill_(0)
            self.fc.weight.data.uniform_(-0.1, 0.1)

        def load_pretrained_embeddings(self, embeddings):
            # nn.Parameter registers the tensor as part of the model, so the
            # optimizer can update it during training; use it when you want a
            # tensor's values to keep being optimized as learning proceeds
            self.embedding.weight = nn.Parameter(embeddings)

        def fine_tune_embeddings(self, fine_tune=True):
            for p in self.embedding.parameters():
                p.requires_grad = fine_tune

        def init_hidden_state(self, encoder_out):
            mean_encoder_out = encoder_out.mean(dim=1)
            h = self.init_h(mean_encoder_out)
            c = self.init_c(mean_encoder_out)
            return h, c

        def forward(self, encoder_out, encoded_captions, caption_lengths):
            """
            :return: scores for vocabulary, sorted encoded captions, decode lengths, weights
            """
            batch_size = encoder_out.size(0)
            encoder_dim = encoder_out.size(-1)
            vocab_size = self.vocab_size
            encoder_out = encoder_out.view(batch_size, -1,
                                           encoder_dim)  # view is PyTorch's reshape
            num_pixels = encoder_out.size(1)
            embeddings = self.embedding(encoded_captions)
            h, c = self.init_hidden_state(encoder_out)
            decode_lengths = [l - 1 for l in caption_lengths]
            predictions = torch.zeros(batch_size, max(decode_lengths),
                                      vocab_size).to(device)
            alphas = torch.zeros(batch_size, max(decode_lengths),
                                 num_pixels).to(device)
            # the batch is predicted as one unit,
            # one word of every caption per timestep;
            # when the short captions finish, only the longer ones continue.
            # the dataloader already sorted the batch, caption length decreasing
            for t in range(max(decode_lengths)):
                batch_size_t = sum([l > t for l in decode_lengths])
                attention_weighted_encoding, alpha = self.attention(
                    encoder_out[:batch_size_t], h[:batch_size_t])
                gate = self.sigmoid(self.f_beta(h[:batch_size_t]))
                attention_weighted_encoding = gate * attention_weighted_encoding
                h, c = self.decode_step(
                    torch.cat([
                        embeddings[:batch_size_t, t, :],
                        attention_weighted_encoding
                    ],
                              dim=1), (h[:batch_size_t], c[:batch_size_t]))
                preds = self.fc(self.dropout(h))
                predictions[:batch_size_t, t, :] = preds
                alphas[:batch_size_t, t, :] = alpha
            return predictions, encoded_captions, decode_lengths, alphas

The Attention helper class:

    class Attention(nn.Module):
        def __init__(self, encoder_dim, decoder_dim, attention_dim):
            super(Attention, self).__init__()
            self.encoder_att = nn.Linear(encoder_dim, attention_dim)
            self.decoder_att = nn.Linear(decoder_dim, attention_dim)
            self.full_att = nn.Linear(attention_dim, 1)
            self.relu = nn.ReLU()
            self.softmax = nn.Softmax(dim=1)

        def forward(self, encoder_out, decoder_hidden):
            att1 = self.encoder_att(encoder_out)
            att2 = self.decoder_att(decoder_hidden)
            # unsqueeze(arg) inserts a dimension of size 1 at position arg
            # squeeze(arg) removes dimension arg if its size is 1
            att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)
            alpha = self.softmax(att)
            attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(
                dim=1)
            return attention_weighted_encoding, alpha
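For reference, the forward pass above implements the additive (soft) attention from Show, Attend and Tell. Writing the encoder's pixel features as $a_i$ (projected by encoder_att, $W_e$), the decoder hidden state as $h$ (projected by decoder_att, $W_d$), and the weight vector of full_att as $w_f$:

$$e_i = w_f^\top \,\mathrm{ReLU}(W_e a_i + W_d h), \qquad \alpha = \mathrm{softmax}(e), \qquad z = \sum_i \alpha_i a_i,$$

so alpha holds the attention weights and attention_weighted_encoding is $z$.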

2.3 Preprocessing: prepro.py

Preprocessing depends on the task: for NLP it is mainly building the vocabulary, while for CV it is mainly resizing, denoising, and normalizing the images.

A picture-captioning project by its nature needs both the NLP and the CV preprocessing.

Part 1: NLP, building the vocabulary


  
  
    from flickr8k_dataloader import flickr8k

    class Vocabulary(object):
        """Simple vocabulary wrapper."""
        def __init__(self):
            self.word2idx = {}
            self.idx2word = {}
            self.idx = 0

        def add_word(self, word):
            if not word in self.word2idx:
                self.word2idx[word] = self.idx
                self.idx2word[self.idx] = word
                self.idx += 1

        def __call__(self, word):
            if not word in self.word2idx:
                return self.word2idx['<unk>']
            return self.word2idx[word]

        def __len__(self):
            return len(self.word2idx)

    def build_vocab(json, threshold):
        """Build a simple vocabulary wrapper."""
        flickr = flickr8k(ann_text_location=json)
        counter = Counter()
        anns_length = len(flickr.anns)
        for id in range(anns_length):
            caption = str(flickr.anns[id]['caption'])
            tokens = nltk.tokenize.word_tokenize(caption.lower())
            counter.update(tokens)
            if id % 1000 == 0:
                print("[%d/%d] Tokenized the captions." % (id, anns_length))
        # words whose frequency falls below 'threshold' are discarded
        words = [word for word, cnt in counter.items() if cnt >= threshold]
        # create the vocabulary and add some special tokens
        vocab = Vocabulary()
        vocab.add_word('<pad>')
        vocab.add_word('<start>')
        vocab.add_word('<end>')
        vocab.add_word('<unk>')
        # add the words to the vocabulary
        for i, word in enumerate(words):
            vocab.add_word(word)
        return vocab

This uses the flickr8k helper class from my flickr8k_dataloader.py.

The complete flickr8k_dataloader.py file:


  
  
    # coding=utf-8
    '''
    Reads the flickr8k dataset.
    '''
    import re
    import os

    class flickr8k():
        def __init__(
                self,
                ann_text_location='/mnt/disk2/flickr8k/Flickr8k_text/Flickr8k.lemma.token.txt',
                imgs_location='/mnt/disk2/flickr8k/Flickr8k_Dataset/Flickr8k_Dataset/',
                ann_rank=4):
            '''
            Helper class for reading the flickr8k dataset.
            :param ann_text_location: location of the annotation file
            :param imgs_location: location of the image folder
            :param ann_rank: which rank of annotation to select
            '''
            self.ann_text_location = ann_text_location
            self.ann_rank = ann_rank
            self.imgs_location = imgs_location
            self.anns = self.read_anns()

        def read_anns(self):
            '''
            Reads the image ids (without .jpg) and the annotations.
            :returns: anns, a list whose elements are dicts: {'image_id': image_id, 'caption': image_annotation}
            '''
            anns = []
            with open(self.ann_text_location, 'r') as raw_ann_text:
                ann_text_lines = raw_ann_text.readlines()
            match_re = r'(.*)\.jpg#' + str(self.ann_rank) + r'\s+(.*)'
            for line in ann_text_lines:
                matchObj = re.match(match_re, line)
                if matchObj:
                    image_id = matchObj.group(1)
                    image_annotation = matchObj.group(2)
                    image = {'image_id': image_id, 'caption': image_annotation}
                    anns.append(image)
            return anns

        def loadImg(self, img_id):
            '''
            Returns the full path of one image.
            :param img_id: the image id (without .jpg)
            :returns: img_path, the full path of the image
            '''
            img_path = os.path.join(self.imgs_location, img_id + '.jpg')
            return img_path

    # test
    # if __name__ == "__main__":
    #     f = flickr8k()
    #     print('f.anns[0] ', f.anns[0])
    #     print('len(f.anns)', len(f.anns))
    #     id = f.anns[0]['image_id']
    #     path = f.loadImg(id)
    #     print('path', path)

Part 2: CV, adjusting the images

    from PIL import Image

    def resize_image(image):
        width, height = image.size
        # the resize is based on the shorter of width and height:
        # the longer side is center-cropped to that base length, giving a square
        if width > height:
            left = (width - height) / 2
            right = width - left
            top = 0
            bottom = height
        else:
            top = (height - width) / 2
            bottom = height - top
            left = 0
            right = width
        image = image.crop((left, top, right, bottom))
        image = image.resize([224, 224], Image.ANTIALIAS)  # ANTIALIAS: high-quality resampling
        return image
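Neither compute_mean_val.py nor the transform passed to get_loader is shown in this write-up, so here is only a minimal sketch of both, assuming torchvision is used: estimate the per-channel mean and standard deviation over the resized images, then build the Normalize transform that the data loaders expect. The paths reuse the ones from prepro.py; everything else is illustrative:

    import os
    import torch
    import torchvision.transforms as transforms
    from PIL import Image

    folder = '/mnt/disk2/flickr8k/Flickr8k_Dataset/Flickr8k_Dataset_resized/'
    to_tensor = transforms.ToTensor()

    total = torch.zeros(3)     # per-channel sum of pixel values
    total_sq = torch.zeros(3)  # per-channel sum of squared pixel values
    n_pixels = 0
    for name in os.listdir(folder):
        img = to_tensor(Image.open(os.path.join(folder, name)).convert('RGB'))  # (3, H, W), values in [0, 1]
        total += img.sum(dim=(1, 2))
        total_sq += (img ** 2).sum(dim=(1, 2))
        n_pixels += img.shape[1] * img.shape[2]

    mean = total / n_pixels
    std = (total_sq / n_pixels - mean ** 2).sqrt()  # Var[X] = E[X^2] - E[X]^2

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean.tolist(), std=std.tolist()),
    ])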

Part 3: The two accompanying main functions (build the vocabulary; resize the images and save them)

    def main(args):
        vocab = build_vocab(json=args.caption_path, threshold=args.threshold)
        vocab_path = args.vocab_path
        with open(vocab_path, 'wb') as f:
            pickle.dump(vocab, f)
        print("Total vocabulary size: %d" % len(vocab))
        print("Saved the vocabulary wrapper to '%s'" % vocab_path)
        folder = '/mnt/disk2/flickr8k/Flickr8k_Dataset/Flickr8k_Dataset/'
        resized_folder = '/mnt/disk2/flickr8k/Flickr8k_Dataset/Flickr8k_Dataset_resized/'
        if not os.path.exists(resized_folder):
            os.makedirs(resized_folder)
        print('Start resizing images.')
        image_files = os.listdir(folder)
        num_images = len(image_files)
        for i, image_file in enumerate(image_files):
            with open(os.path.join(folder, image_file), 'rb') as f:
                with Image.open(f) as image:
                    image = resize_image(image)  # resize the image
                    image.save(
                        os.path.join(resized_folder, image_file),
                        image.format)  # save the resized image
            if i % 100 == 0:
                print('Resized images: %d/%d' % (i, num_images))

    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        parser.add_argument(
            '--caption_path',
            type=str,
            default='/mnt/disk2/flickr8k/Flickr8k_text/Flickr8k.lemma.token.txt',
            help='path for train annotation file')
        parser.add_argument(
            '--vocab_path',
            type=str,
            default='/mnt/disk2/flickr8k/Flickr8k_Dataset/vocab.pkl',
            help='path for saving vocabulary wrapper')
        parser.add_argument(
            '--threshold',
            type=int,
            default=1,
            help='minimum word count threshold')
        args = parser.parse_args()
        main(args)

2.4 Data loading: data_loader.py

Part 4 of section 2.1.4 (the dataset loaders) already walked through this file in full, so no further commentary; the complete code is below:

    # coding=utf-8
    import os
    import nltk
    import torch
    import torch.utils.data as data
    from PIL import Image
    from flickr8k_dataloader import flickr8k

    class DataLoader(data.Dataset):
        def __init__(self, root, json, vocab, rank, transform=None):
            self.root = root
            self.flickr = flickr8k(
                ann_text_location=json, imgs_location=root, ann_rank=rank)
            self.vocab = vocab
            self.rank = rank
            self.transform = transform

        # supports integer indexing from 0 to len(self) exclusive
        def __getitem__(self, index):
            flickr = self.flickr
            vocab = self.vocab
            # ann: annotation
            caption = flickr.anns[index]['caption']
            img_id = flickr.anns[index]['image_id']
            path = flickr.loadImg(img_id)
            image = Image.open(path).convert('RGB')
            if self.transform is not None:
                image = self.transform(image)
            tokens = nltk.tokenize.word_tokenize(str(caption).lower())  # tokenize
            caption = []
            caption.append(vocab('<start>'))
            caption.extend([vocab(token) for token in tokens])
            caption.append(vocab('<end>'))
            target = torch.Tensor(caption)
            return image, target

        def __len__(self):
            return len(self.flickr.anns)

    def collate_fn(data):
        # sort the samples by caption length, longest first
        data.sort(key=lambda x: len(x[1]), reverse=True)
        images, captions = zip(*data)
        images = torch.stack(images, 0)  # stack the image tensors along a new dimension
        lengths = [len(cap) for cap in captions]
        targets = torch.zeros(len(captions), max(lengths)).long()
        for i, cap in enumerate(captions):
            end = lengths[i]
            targets[i, :end] = cap[:end]
        return images, targets, lengths

    def get_loader(root, json, vocab, transform, batch_size, rank, shuffle,
                   num_workers):
        flickr = DataLoader(
            root=root, json=json, vocab=vocab, rank=rank, transform=transform)
        # data loader for the flickr dataset
        # each iteration returns (images, captions, lengths)
        # images: tensor of shape (batch_size, 3, 224, 224)
        # captions: tensor of shape (batch_size, padded_length)
        # lengths: list of the valid length of each caption; length is (batch_size)
        data_loader = torch.utils.data.DataLoader(
            dataset=flickr,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=collate_fn)  # merges a list of samples to form a mini-batch
        return data_loader

2.5 Utilities: utils.py

These are mostly small helpers, all mentioned earlier: clip_gradient (clips the gradients computed during backpropagation, to avoid gradient explosion), save_checkpoint (saves intermediate models), AverageMeter (helper class tracking a metric's latest value, average, sum, and count), adjust_learning_rate (shrinks the learning rate by a given factor), and accuracy (computes top-k accuracy from the predicted and true labels). The docstrings are written well enough to read directly.


  
  
    # coding=utf-8
    import numpy as np
    import torch

    def init_embedding(embeddings):
        """
        Fills an embedding tensor with values from a uniform distribution.
        :param embeddings: embedding tensor
        """
        bias = np.sqrt(3.0 / embeddings.size(1))
        torch.nn.init.uniform_(embeddings, -bias, bias)

    def load_embeddings(emb_file, word_map):
        """
        Creates an embedding tensor for the specified word map, for loading into the model.
        :param emb_file: file containing embeddings (stored in GloVe format)
        :param word_map: word map
        :return: embeddings in the same order as the words in the word map, and the embedding dimension
        """
        # find the embedding dimension
        with open(emb_file, 'r') as f:
            emb_dim = len(f.readline().split(' ')) - 1
        vocab = set(word_map.keys())
        # create a tensor to hold the embeddings, then initialize it
        embeddings = torch.FloatTensor(len(vocab), emb_dim)
        init_embedding(embeddings)
        # read the embedding file
        print("\nLoading embeddings...")
        for line in open(emb_file, 'r'):
            line = line.split(' ')
            emb_word = line[0]
            # process the word vector:
            # strip out whitespace, then convert the strings to floats
            embedding = list(
                map(lambda t: float(t),
                    filter(lambda n: n and not n.isspace(), line[1:])))
            # ignore words that are not in train_vocab
            if emb_word not in vocab:
                continue
            # record the word's vector in embeddings
            embeddings[word_map[emb_word]] = torch.FloatTensor(embedding)
        return embeddings, emb_dim

    def clip_gradient(optimizer, grad_clip):
        """
        Clips gradients computed during backpropagation, to avoid gradient explosion.
        :param optimizer: optimizer with the gradients to be clipped
        :param grad_clip: clip value
        """
        for group in optimizer.param_groups:
            for param in group['params']:
                if param.grad is not None:
                    # clamp every element into the range [min, max]:
                    # in-range values stay unchanged, out-of-range ones become min or max
                    param.grad.data.clamp_(-grad_clip, grad_clip)

    def save_checkpoint(data_name, epoch, epochs_since_improvement, encoder,
                        decoder, encoder_optimizer, decoder_optimizer, bleu4,
                        is_best):
        """
        Saves model checkpoint.
        :param data_name: base name of processed dataset
        :param epoch: epoch number
        :param epochs_since_improvement: number of epochs since last improvement in BLEU-4 score
        :param encoder: encoder model
        :param decoder: decoder model
        :param encoder_optimizer: optimizer to update encoder's weights, if fine-tuning
        :param decoder_optimizer: optimizer to update decoder's weights
        :param bleu4: validation BLEU-4 score for this epoch
        :param is_best: is this checkpoint the best so far?
        """
        state = {
            'epoch': epoch,
            'epochs_since_improvement': epochs_since_improvement,
            'bleu-4': bleu4,
            'encoder': encoder,
            'decoder': decoder,
            'encoder_optimizer': encoder_optimizer,
            'decoder_optimizer': decoder_optimizer
        }
        filename = 'checkpoint_' + data_name + '.pth.tar'
        torch.save(state, filename)
        # if this checkpoint is the best so far, save a copy so it cannot be overwritten by a worse one
        if is_best:
            torch.save(state, 'BEST_' + filename)

    class AverageMeter(object):
        """
        Tracks a metric's latest value, average, sum, and count.
        """
        def __init__(self):
            self.reset()

        def reset(self):
            self.val = 0
            self.avg = 0
            self.sum = 0
            self.count = 0

        def update(self, val, n=1):
            self.val = val
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count

    def adjust_learning_rate(optimizer, shrink_factor):
        """
        Shrinks the learning rate by a specified factor.
        :param optimizer: optimizer whose learning rate must be shrunk.
        :param shrink_factor: factor in interval (0, 1) to multiply learning rate with.
        """
        print("\nDECAYING learning rate.")
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * shrink_factor
        print(
            "The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'], ))

    def accuracy(scores, targets, k):
        """
