CV | Emotionally Enhanced Talking Face Generation: Paper Walkthrough and Code Implementation

 This post summarizes the paper Emotionally Enhanced Talking Face Generation, walks through running the project, and explains the code.

Emotionally Enhanced Talking Face Generation

Paper: https://arxiv.org/pdf/2303.11548.pdf

Code: GitHub - sahilg06/EmoGen: PyTorch Implementation for Paper "Emotionally Enhanced Talking Face Generation"

(After cloning the project and downloading the weights, you can run inference directly.)

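Contents

Paper Summary

Project Implementation

1. Environment Setup

2. Data Processing and Running the Project

3. Training

3.1. Training the Expert Discriminator

3.2. Training the Emotion Discriminator

3.3. Training the Final Model

4. Inference

Problems Encountered and Solutions (PS)

Code Walkthrough (in Execution Order)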


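Paper Summary

Contributions

  • Given an input video, synthesize an arbitrary face speaking with a chosen emotion.
  • Proposes a new deep learning model that generates photorealistic lip-synced talking-face videos featuring different emotions and the corresponding expressions.
  • Introduces a multimodal framework for generating lip-synced videos that are agnostic to arbitrary identity, language, and emotion.
  • Develops a responsive web-based interface for generating emotional talking faces in real time.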

Model Architecture

 

Project Implementation

1. Environment Setup

Ubuntu (Docker container), PyTorch with GPU support, CUDA 11.7

Clone the project:

git clone https://github.com/sahilg06/EmoGen
cd EmoGen

 Install the dependencies:

sudo apt-get install ffmpeg
pip install -r requirements.txt
# additional dependency
pip install albumentations

 Configure the download tooling by installing git lfs:

curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash
apt-get install git-lfs

git lfs install

Download the CREMA-D dataset (a plain git clone produces broken files):

git lfs clone https://github.com/CheyneyComputerScience/CREMA-D.git

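2. Data Processing and Running the Project

Run:

python convertFPS.py -i /workspace/facegan/EmoGen/CREMA-D/VideoFlash -o /workspace/facegan/EmoGen/CREMA-D/flv-output

This re-encodes every video file (.flv) in VideoFlash at 25 fps and saves it as an .mp4 file in the flv-output folder (choose the folder name yourself).

If it fails, see [PS3].

(This takes about 25 minutes.)

Next, process the MP4 files: crop the face region from every frame and store each clip in its own folder, named after the video file (the name encodes the actor and the emotion):

Run:

python preprocess_crema-d.py --data_root /workspace/facegan/EmoGen/CREMA-D/flv-output --preprocessed_root preprocessed_dataset/

If it fails, see [PS4]/[PS5]. (Figure: preprocessing in progress.)

(This takes about 3 hours.)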
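Each folder under preprocessed_dataset/ is named after its source video, and CREMA-D names its clips ActorID_Sentence_Emotion_Intensity (for example 1018_MTI_DIS_XX, the file mentioned in [PS3]). The sketch below shows how the emotion can be read back from a folder name. It is illustrative only: `CremaClip` and `parse_clip_name` are hypothetical helpers, and how EmoGen's own data loader extracts labels is not shown in this post.

```python
from typing import NamedTuple


class CremaClip(NamedTuple):
    actor: str      # e.g. "1018"
    sentence: str   # three-letter sentence code, e.g. "MTI"
    emotion: str    # ANG, DIS, FEA, HAP, NEU or SAD
    intensity: str  # LO, MD, HI or XX (unspecified)


def parse_clip_name(folder_name: str) -> CremaClip:
    """Split a CREMA-D clip name such as '1018_MTI_DIS_XX' into its four fields."""
    parts = folder_name.split("_")
    if len(parts) != 4:
        raise ValueError(f"not a CREMA-D clip name: {folder_name!r}")
    return CremaClip(*parts)


clip = parse_clip_name("1018_MTI_DIS_XX")
print(clip.emotion)  # DIS
```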

3. Training

Training consists of three main steps:

  1. Train the expert lip-sync discriminator
  2. Train the emotion discriminator
  3. Train the EmoGen model

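3.1. Training the Expert Discriminator

(If you only have one GPU, or limited GPU memory, you need to adjust the code; for example, color_syncnet_train.py hard-codes os.environ['CUDA_VISIBLE_DEVICES']='2'.)

The relevant parameters are in hparams.py, where you can change batch_size and num_workers (defaults: batch_size=16, num_workers=16).

python color_syncnet_train.py --data_root preprocessed_dataset/ --checkpoint_dir sync-checkpoint

 Training starts.

After 16 hours of training it had reached epoch 210; the checkpoints are saved in the sync-checkpoint folder.

 The weights produced here are those of the trained SyncNet network with residual skip connections.

Because the number of epochs is set very high, you can stop training at any point with Ctrl+C.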

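3.2. Training the Emotion Discriminator

python emotion_disc_train.py -i preprocessed_dataset/ -o emo-checkpoint

 Training starts.

 Because the number of epochs is set very high (num_epochs defaults to 10000), you can stop training with Ctrl+C. Note that the script deletes an existing output folder (here emo-checkpoint) when it starts, so move earlier checkpoints elsewhere before re-running it.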

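3.3. Training the Final Model

python train.py --data_root preprocessed_dataset/ --checkpoint_dir emogen-checkpoint --syncnet_checkpoint_path sync-checkpoint/checkpoint_step000011256.pth --emotion_disc_path emo-checkpoint/disc_emo_23000.pth

- emogen-checkpoint: all weights from this training run are saved in this folder

- sync-checkpoint/checkpoint_step000011256.pth: the weight file from training the expert discriminator (use the step number of your own checkpoint)

- emo-checkpoint/disc_emo_23000.pth: the weight file from training the emotion discriminator (use the step number of your own checkpoint)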

4. Inference

*Before running inference, comment out the following code:

lines 108 and 113 of models/wav2lip.py, changing them to:

        # emotion = emotion.unsqueeze(1).repeat(1, 5, 1)  # (B, T, 6)

        # emotion = torch.cat([emotion[:, i] for i in range(emotion.size(1))], dim=0)  # (B*T, 6)
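Why these two lines have to go: during training, each sample presumably carries T = 5 consecutive frames that the model flattens into B*T rows, so the per-clip emotion vector of shape (B, 6) is repeated T times to match, as the shape comments indicate. At inference every row of the batch is already a single frame, so repeating the emotion makes it T times taller than the audio embedding it is concatenated with. The numpy sketch below reproduces only the shapes; numpy stands in for torch, and B = 64 and the audio embedding width of 512 are illustrative assumptions. With B = 64 the mismatch is 64 vs 320 rows, the same numbers as the error in [PS6], which suggests (but does not prove) that [PS6] comes from leaving these lines active.

```python
import numpy as np

B, T, N_EMO = 64, 5, 6
AUDIO_DIM = 512  # assumed embedding width, for illustration only


def expand_emotion(emotion: np.ndarray, T: int) -> np.ndarray:
    """numpy version of the two lines above: (B, 6) -> (B, T, 6) -> (B*T, 6)."""
    repeated = np.repeat(emotion[:, None, :], T, axis=1)               # (B, T, 6)
    return np.concatenate([repeated[:, i] for i in range(T)], axis=0)  # (B*T, 6)


emotion = np.zeros((B, N_EMO))
emotion[:, 3] = 1.0                       # the same emotion for every frame
audio_embedding = np.zeros((B, AUDIO_DIM))  # inference: one row per frame

expanded = expand_emotion(emotion, T)
print(expanded.shape)  # (320, 6)
try:
    np.concatenate([audio_embedding, expanded], axis=1)  # 64 rows vs 320 rows
except ValueError as err:
    print("shape mismatch:", err)

fused = np.concatenate([audio_embedding, emotion], axis=1)  # without the repeat
print(fused.shape)  # (64, 518)
```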

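 Then run inference:

python inference.py --checkpoint_path emogen-checkpoint/<trained_checkpoint>.pth --face <your_video>.mp4 --audio <audio_file> --emotion <emotion>
--checkpoint_path: the trained checkpoint
--face: your own .mp4 file
--audio: an audio file: *.wav, *.mp3, or even a video file; the code extracts the audio from it automatically
--emotion: the emotion to generate, chosen from the categorical emotions [HAP, SAD, FEA, ANG, DIS, NEU]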
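The emotion enters the model as a 6-dimensional vector (the (B, 6) shape in the wav2lip.py comments), so the --emotion code has to be turned into a one-hot vector at some point. A minimal sketch of that mapping; `EMOTIONS`, its order and `emotion_to_onehot` are assumptions taken from the list above, and the order actually used by EmoGen's inference.py may differ.

```python
import numpy as np

# Hypothetical label order, copied from the --emotion list above.
EMOTIONS = ["HAP", "SAD", "FEA", "ANG", "DIS", "NEU"]


def emotion_to_onehot(code: str) -> np.ndarray:
    """Map an emotion code such as 'HAP' to a one-hot vector of length 6."""
    code = code.upper()
    if code not in EMOTIONS:
        raise ValueError(f"unknown emotion {code!r}, expected one of {EMOTIONS}")
    onehot = np.zeros(len(EMOTIONS), dtype=np.float32)
    onehot[EMOTIONS.index(code)] = 1.0
    return onehot


print(emotion_to_onehot("HAP"))  # [1. 0. 0. 0. 0. 0.]
```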

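(Output omitted.)

At inference time, the wav2lip model concatenates (torch.cat) the audio embedding (audio_embedding) with the emotion embedding (emotion_embedding).

You can also skip training altogether: download the checkpoint (available on Google Drive) and run: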

python inference.py --checkpoint_path checkpoint.pth --face temp.mp4 --audio temp.wav  --emotion HAP

Problems Encountered and Solutions (PS)

[PS1] Installing ffmpeg in the Docker container fails with:
Err:1 http://security.ubuntu.com/ubuntu focal-updates/main amd64 libwebp6 amd64 0.6.1-2ubuntu0.20.04.1
  404  Not Found [IP: 185.125.190.39 80]
Err:2 http://security.ubuntu.com/ubuntu focal-updates/main amd64 libwebpmux3 amd64 0.6.1-2ubuntu0.20.04.1
  404  Not Found [IP: 185.125.190.39 80]

 Cause analysis

On the Linux host, ffmpeg (version 4.2.7) works fine; the failure occurs only inside the Docker container.

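In the Docker container, refresh the package index first (404 errors like the ones above usually mean the index is out of date), then install the auxiliary packages and ffmpeg:

apt-get update
apt-get install yasm
apt-get install libx264-dev
apt-get install libfdk-aac-dev
apt-get install libmp3lame-dev
apt-get install libopus-dev
apt-get install libvpx-dev
apt install ffmpeg

After this, ffmpeg installs successfully.

(Figure: a sample video from the dataset's VideoFlash folder.)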

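[PS2] TypeError: makedirs() got an unexpected keyword argument 'exist_ok'

Solution: remove exist_ok=True. (This error typically means the script is being run with Python 2, whose os.makedirs has no exist_ok argument; it exists since Python 3.2, so running the script with python3 also fixes it.)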

[PS3] /workspace/facegan/EmoGen/CREMA-D/VideoFlash/1018_MTI_DIS_XX.flv: Invalid data found when processing input

Solution: the file was corrupted when the data was downloaded; after re-downloading the data everything worked.

[PS4] AttributeError: partially initialized module 'cv2' has no attribute 'gapi_wip_gst_GStreamerPipeline' (most likely due to a circular import)

Check the installed opencv_python version: it was 4.7.0.72.

Downgrade it:

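[PS5] ImportError: libSM.so.6: cannot open shared object file: No such file or directory

Because I am working inside a Docker container (which lacks GUI libraries such as libSM), install the headless build instead:

pip install opencv-python-headless==4.3.0.36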

[PS6] RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 64 but got size 320 for tensor number 1 in the list.

Code Walkthrough (in Execution Order)

convertFPS.py

import argparse
import os
import subprocess

# video file extensions to convert
VIDEO_EXTS = ('.mp4', '.mpg', '.mov', '.flv')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=__doc__)
    # -i is short for --input-folder: path to the raw dataset videos
    parser.add_argument("-i", "--input-folder", type=str, help='Path to folder that contains video files')
    # -fps: target frame rate, i.e. frames shown per second; higher means smoother motion
    parser.add_argument("-fps", type=float, help='Target FPS', default=25.0)
    # -o is short for --output-folder: where the converted videos are written
    parser.add_argument("-o", "--output-folder", type=str, help='Path to output folder')
    args = parser.parse_args()

    # Create the output folder. exist_ok needs Python 3.2+ (the TypeError in [PS2]
    # means an older interpreter is running the script; use python3).
    os.makedirs(args.output_folder, exist_ok=True)
    fileList = []
    # Walk the input folder recursively and collect every video file (mp4, mpg, mov, flv)
    for root, dirnames, filenames in os.walk(args.input_folder):
        for filename in filenames:
            if os.path.splitext(filename)[1].lower() in VIDEO_EXTS:
                fileList.append(os.path.join(root, filename))

    # Re-encode each video at the target frame rate (-fps, previously hard-coded to 25)
    # and save it as .mp4, mirroring the input folder structure in the output folder
    for file in fileList:
        out_path = os.path.splitext(file.replace(args.input_folder, args.output_folder))[0] + ".mp4"
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        # an argument list (no shell=True) keeps paths containing spaces intact
        subprocess.run(["ffmpeg", "-i", file, "-r", str(args.fps), "-y", out_path])

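To cut a segment out of a video with ffmpeg by start and end time, use the following command:

ffmpeg -ss [start] -i [input] -to [end] -c copy [output]

Options

Option      Purpose
-ss         seek position (where reading starts)
-i          input file (required by ffmpeg)
-t          duration
-to         end position
-c          codec
copy        reuse the source file's codec (no re-encoding)
[start]     start time
[end]       end time
[duration]  duration
[input]     input file path
[output]    output file path

-r: frames per second (sets the output frame rate; a lower frame rate also compresses the video)
Note: put -ss before -i (fast input seeking). In that position ffmpeg resets the timestamps to zero, so a -to placed after -i is measured from [start] and behaves like -t.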
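When the trimming is scripted from Python (as convertFPS.py does for re-encoding), building the command as an argument list avoids shell quoting problems. A minimal sketch under these assumptions: `build_trim_command` is a hypothetical helper, times are in seconds, and the clip length is passed with -t (duration = end - start) so it does not depend on how -to is interpreted after input seeking. With -c copy the cut points may snap to keyframes.

```python
def build_trim_command(src: str, dst: str, start: float, end: float) -> list:
    """Return an ffmpeg stream-copy trim command as an argument list.

    -ss is placed before -i (input seeking); the length is given with -t.
    """
    if end <= start:
        raise ValueError("end must be after start")
    return ["ffmpeg", "-ss", f"{start}", "-i", src,
            "-t", f"{end - start}", "-c", "copy", dst]


cmd = build_trim_command("input.mp4", "clip.mp4", 2.0, 5.5)
print(" ".join(cmd))  # ffmpeg -ss 2.0 -i input.mp4 -t 3.5 -c copy clip.mp4
# subprocess.run(cmd, check=True)  # run it once ffmpeg is installed
```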
preprocess_crema-d.py

import sys

if sys.version_info < (3, 2):
	raise Exception("Must be using >= Python 3.2")

from os import listdir, path

if not path.isfile('face_detection/detection/sfd/s3fd.pth'):
	raise FileNotFoundError('Save the s3fd model to face_detection/detection/sfd/s3fd.pth '
							'before running this script!')

import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import argparse, os, cv2, traceback, subprocess
from tqdm import tqdm
from glob import glob
import audio
from hparams import hparams as hp

import face_detection

parser = argparse.ArgumentParser()
# number of GPUs across which face detection runs in parallel
parser.add_argument('--ngpu', help='Number of GPUs across which to run in parallel', default=1, type=int)
# face-detection batch size per GPU (default 32)
parser.add_argument('--batch_size', help='Single GPU Face detection batch size', default=32, type=int)
# dataset root: the 25 fps .mp4 folder produced by convertFPS.py (the help text says LRS2, but CREMA-D is used here)
parser.add_argument("--data_root", help="Root folder of the LRS2 dataset", required=True)
# root folder for the preprocessed dataset
parser.add_argument("--preprocessed_root", help="Root folder of the preprocessed dataset", required=True)

args = parser.parse_args()
# build one face detector per GPU (used to locate the face in every frame)
fa = [face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, 
									device='cuda:{}'.format(id)) for id in range(args.ngpu)]

# ffmpeg command used by process_audio_file to extract each clip's audio track to audio.wav
template = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'
# template2 = 'ffmpeg -hide_banner -loglevel panic -threads 1 -y -i {} -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 {}'

def process_video_file(vfile, args, gpu_id):
	video_stream = cv2.VideoCapture(vfile)
	
	frames = []
	while 1:
		still_reading, frame = video_stream.read()
		if not still_reading:
			video_stream.release()
			break
		frames.append(frame)
	
	vidname = os.path.basename(vfile).split('.')[0]
	#dirname = vfile.split('/')[-2]

	fulldir = path.join(args.preprocessed_root, vidname)
	os.makedirs(fulldir, exist_ok=True)

	batches = [frames[i:i + args.batch_size] for i in range(0, len(frames), args.batch_size)]

	i = -1
	for fb in batches:
		preds = fa[gpu_id].get_detections_for_batch(np.asarray(fb))

		for j, f in enumerate(preds):
			i += 1
			if f is None:
				continue
			# crop the detected box (x1, y1, x2, y2) from the frame and save it as <i>.jpg
			x1, y1, x2, y2 = f
			cv2.imwrite(path.join(fulldir, '{}.jpg'.format(i)), fb[j][y1:y2, x1:x2])

def process_audio_file(vfile, args):
	vidname = os.path.basename(vfile).split('.')[0]
	#dirname = vfile.split('/')[-2]

	fulldir = path.join(args.preprocessed_root, vidname)
	os.makedirs(fulldir, exist_ok=True)

	wavpath = path.join(fulldir, 'audio.wav')

	command = template.format(vfile, wavpath)
	subprocess.call(command, shell=True)

	
def mp_handler(job):
	vfile, args, gpu_id = job
	try:
		process_video_file(vfile, args, gpu_id)
	except KeyboardInterrupt:
		exit(0)
	except:
		traceback.print_exc()
		
def main(args):
	print('Started processing for {} with {} GPUs'.format(args.data_root, args.ngpu))
    
    #
	filelist = glob(path.join(args.data_root, '*.mp4'))

	jobs = [(vfile, args, i%args.ngpu) for i, vfile in enumerate(filelist)]
	p = ThreadPoolExecutor(args.ngpu)
	futures = [p.submit(mp_handler, j) for j in jobs]
	_ = [r.result() for r in tqdm(as_completed(futures), total=len(futures))]

	print('Dumping audios...')

	for vfile in tqdm(filelist):
		try:
			process_audio_file(vfile, args)
		except KeyboardInterrupt:
			exit(0)
		except:
			traceback.print_exc()
			continue

if __name__ == '__main__':
	main(args)
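process_video_file numbers the saved crops with a global frame index i that keeps counting across detection batches, and simply skips frames in which no face is found. A clip folder can therefore have gaps in its <i>.jpg numbering, and Dataset.get_window in color_syncnet_train.py returns None for any window that touches a gap. The plain-Python sketch below replays that bookkeeping; the `detections` list is a hypothetical stand-in for what fa[gpu_id].get_detections_for_batch returns (one box or None per frame), and both helper names are illustrative.

```python
from typing import List, Optional, Sequence, Tuple

Box = Tuple[int, int, int, int]


def saved_frame_ids(detections: Sequence[Optional[Box]], batch_size: int) -> List[int]:
    """Replay the indexing in process_video_file: which <i>.jpg files get written."""
    batches = [detections[i:i + batch_size] for i in range(0, len(detections), batch_size)]
    ids, i = [], -1
    for fb in batches:
        for f in fb:
            i += 1            # the frame index keeps counting across batches
            if f is None:     # no face found: nothing is written for this index
                continue
            ids.append(i)
    return ids


def window_is_complete(ids: Sequence[int], start: int, T: int = 5) -> bool:
    """Mirror Dataset.get_window: frames start .. start+T-1 must all exist."""
    present = set(ids)
    return all(fid in present for fid in range(start, start + T))


box = (0, 0, 10, 10)
dets = [box, None, box, box, box, box, box]
ids = saved_frame_ids(dets, batch_size=3)
print(ids)                         # [0, 2, 3, 4, 5, 6]
print(window_is_complete(ids, 0))  # False (frame 1 is missing)
print(window_is_complete(ids, 2))  # True
```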

color_syncnet_train.py

from os.path import dirname, join, basename, isfile, isdir
from tqdm import tqdm

from models import SyncNet_color as SyncNet
import audio

import torch
from torch import nn
from torch import optim
from torch.utils.tensorboard import SummaryWriter
import torch.backends.cudnn as cudnn
from torch.utils import data as data_utils
import numpy as np

from glob import glob

import os, random, cv2, argparse
import albumentations as A
from hparams import hparams, get_image_list

parser = argparse.ArgumentParser(description='Code to train the expert lip-sync discriminator')

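# path to the preprocessed dataset (the help text says LRS2, but here it is the preprocessed CREMA-D folder)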
parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True)

parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
parser.add_argument('--checkpoint_path', help='Resumed from this checkpoint', default=None, type=str)

args = parser.parse_args()


global_step = 0
global_epoch = 0
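# Hard-coded to GPU 2: on a machine with fewer GPUs, change this (e.g. to '0'),
# otherwise no GPU is visible and training falls back to the CPU (use_cuda: False)
os.environ['CUDA_VISIBLE_DEVICES']='2'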
use_cuda = torch.cuda.is_available()
print('use_cuda: {}'.format(use_cuda))

syncnet_T = 5
emonet_T = 5
syncnet_mel_step_size = 16

class Dataset(object):
    def __init__(self, split):
        #self.all_videos = get_image_list(args.data_root, split)
        self.all_videos = [join(args.data_root, f) for f in os.listdir(args.data_root) if isdir(join(args.data_root, f))]
        print('Num files: ', len(self.all_videos))

        # to apply same augmentation for all the frames
        target = {}
        for i in range(1, emonet_T):
            target['image' + str(i)] = 'image'
        
        self.augments = A.Compose([
                        A.RandomBrightnessContrast(p=0.2),    
                        A.RandomGamma(p=0.2),    
                        A.CLAHE(p=0.2),
                        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=50, val_shift_limit=50, p=0.2),  
                        A.ChannelShuffle(p=0.2), 
                        A.RGBShift(p=0.2),
                        A.RandomBrightness(p=0.2),
                        A.RandomContrast(p=0.2),
                        A.GaussNoise(var_limit=(10.0, 50.0), p=0.25),
                    ], additional_targets=target, p=0.8)
    
    def augmentVideo(self, video):
        args = {}
        args['image'] = video[0, :, :, :]
        for i in range(1, emonet_T):
            args['image' + str(i)] = video[i, :, :, :]
        result = self.augments(**args)
        video[0, :, :, :] = result['image']
        for i in range(1, emonet_T):
            video[i, :, :, :] = result['image' + str(i)]
        return video

    def get_frame_id(self, frame):
        return int(basename(frame).split('.')[0])

    def get_window(self, start_frame):
        start_id = self.get_frame_id(start_frame)
        vidname = dirname(start_frame)

        window_fnames = []
        for frame_id in range(start_id, start_id + syncnet_T):
            frame = join(vidname, '{}.jpg'.format(frame_id))
            if not isfile(frame):
                return None
            window_fnames.append(frame)
        return window_fnames

    def crop_audio_window(self, spec, start_frame):
        # num_frames = (T x hop_size * fps) / sample_rate
        start_frame_num = self.get_frame_id(start_frame)
        start_idx = int(80. * (start_frame_num / float(hparams.fps)))

        end_idx = start_idx + syncnet_mel_step_size

        return spec[start_idx : end_idx, :]


    def __len__(self):
        return len(self.all_videos)

    def __getitem__(self, idx):
        while 1:
            idx = random.randint(0, len(self.all_videos) - 1)
            vidname = self.all_videos[idx]
            #print(vidname)

            img_names = list(glob(join(vidname, '*.jpg')))
            if len(img_names) <= 3 * syncnet_T:
                continue
            img_name = random.choice(img_names)
            wrong_img_name = random.choice(img_names)
            while wrong_img_name == img_name:
                wrong_img_name = random.choice(img_names)

            if random.choice([True, False]):
                y = torch.ones(1).float()
                chosen = img_name
            else:
                y = torch.zeros(1).float()
                chosen = wrong_img_name

            window_fnames = self.get_window(chosen)
            if window_fnames is None:
                continue

            window = []
            all_read = True
            for fname in window_fnames:
                img = cv2.imread(fname)
                if img is None:
                    all_read = False
                    break
                try:
                    img = cv2.resize(img, (hparams.img_size, hparams.img_size))
                except Exception as e:
                    all_read = False
                    break

                window.append(img)

            if not all_read: continue

            try:
                wavpath = join(vidname, "audio.wav")
                wav = audio.load_wav(wavpath, hparams.sample_rate)

                orig_mel = audio.melspectrogram(wav).T
            except Exception as e:
                continue

            mel = self.crop_audio_window(orig_mel.copy(), img_name)

            if (mel.shape[0] != syncnet_mel_step_size):
                continue

            # H x W x 3 * T
            window = np.asarray(window)
            aug_results = self.augmentVideo(window)
            window = np.split(aug_results, syncnet_T, axis=0)

            x = np.concatenate(window, axis=3) / 255.
            x = np.squeeze(x, axis=0).transpose(2, 0, 1)
            # print(x.shape)
            x = x[:, x.shape[1]//2:]

            x = torch.FloatTensor(x)
            mel = torch.FloatTensor(mel.T).unsqueeze(0)

            return x, mel, y

logloss = nn.BCELoss()
def cosine_loss(a, v, y):
    d = nn.functional.cosine_similarity(a, v)
    loss = logloss(d.unsqueeze(1), y)

    return loss

def train(device, model, train_data_loader, test_data_loader, optimizer,
          checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
    
    #scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=6)

    global global_step, global_epoch
    resumed_step = global_step
    num_batches = len(train_data_loader)
    
    while global_epoch < nepochs:
        print('Epoch: {}'.format(global_epoch))
        running_loss = 0.
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (x, mel, y) in prog_bar:
            model.train()
            optimizer.zero_grad()
            
            # Transform data to CUDA device
            x = x.to(device)
            mel = mel.to(device)

            a, v = model(mel, x)
            y = y.to(device)

            loss = cosine_loss(a, v, y)
            loss.backward()
            optimizer.step()

            global_step += 1
            cur_session_steps = global_step - resumed_step
            running_loss += loss.item()

            # if global_step == 1 or global_step % checkpoint_interval == 0:
            #     save_checkpoint(
            #         model, optimizer, global_step, checkpoint_dir, global_epoch)

            # if global_step % hparams.syncnet_eval_interval == 0:
            #     with torch.no_grad():
            #         eval_loss = eval_model(test_data_loader, global_step, device, model, checkpoint_dir)

            prog_bar.set_description('Loss: {}'.format(running_loss / (step + 1)))

        writer.add_scalar("Loss/train", running_loss/num_batches, global_epoch)

        with torch.no_grad():
            eval_loss = eval_model(test_data_loader, global_step, device, model, checkpoint_dir)
            if(global_epoch % 50 == 0):
                save_checkpoint(model, optimizer, global_step, checkpoint_dir, global_epoch)

        global_epoch += 1

def eval_model(test_data_loader, global_step, device, model, checkpoint_dir):
    eval_steps = 1400
    print('Evaluating for {} steps'.format(eval_steps))
    losses = []
    while 1:
        for step, (x, mel, y) in enumerate(test_data_loader):

            model.eval()

            # Transform data to CUDA device
            x = x.to(device)

            mel = mel.to(device)

            a, v = model(mel, x)
            y = y.to(device)

            loss = cosine_loss(a, v, y)
            losses.append(loss.item())

            if step > eval_steps: break

        averaged_loss = sum(losses) / len(losses)
        print(averaged_loss)
        writer.add_scalar("Loss/val", averaged_loss, global_step)

        return averaged_loss

def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch):

    checkpoint_path = join(
        checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step))
    optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
    torch.save({
        "state_dict": model.state_dict(),
        "optimizer": optimizer_state,
        "global_step": step,
        "global_epoch": epoch,
    }, checkpoint_path)
    print("Saved checkpoint:", checkpoint_path)

def _load(checkpoint_path):
    if use_cuda:
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint

def load_checkpoint(path, model, optimizer, reset_optimizer=False):
    global global_step
    global global_epoch

    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    model.load_state_dict(checkpoint["state_dict"])
    if not reset_optimizer:
        optimizer_state = checkpoint["optimizer"]
        if optimizer_state is not None:
            print("Load optimizer state from {}".format(path))
            optimizer.load_state_dict(checkpoint["optimizer"])
    global_step = checkpoint["global_step"]
    global_epoch = checkpoint["global_epoch"]

    return model

if __name__ == "__main__":
    checkpoint_dir = args.checkpoint_dir
    checkpoint_path = args.checkpoint_path

    if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir)

    # Dataset and Dataloader setup
    #train_dataset = Dataset('train')
    #test_dataset = Dataset('val')

    full_dataset = Dataset('train')
    train_size = int(0.95 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42))

    train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=hparams.syncnet_batch_size, shuffle=True,
        num_workers=hparams.num_workers)

    test_data_loader = data_utils.DataLoader(
        test_dataset, batch_size=hparams.syncnet_batch_size,
        num_workers=8)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Model
    model = SyncNet().to(device)
    #model = nn.DataParallel(SyncNet(), device_ids=[1,2]).to(device)

    print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))

    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=hparams.syncnet_lr,betas=(0.5,0.999))

    if checkpoint_path is not None:
        load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer=False)

    writer = SummaryWriter('runs/crema-d_disc_exp2_data_aug')

    train(device, model, train_data_loader, test_data_loader, optimizer,
          checkpoint_dir=checkpoint_dir,
          checkpoint_interval=hparams.syncnet_checkpoint_interval,
          nepochs=hparams.nepochs)

    writer.flush()
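The reshaping in Dataset.__getitem__ is easy to lose track of. The sketch below replays it on dummy arrays to show the shapes of x and mel. It assumes hparams.img_size = 96, fps = 25 and an audio front end at 16 kHz with hop size 200 and 80 mel bins, i.e. 80 mel frames per second, which is where the constant 80 in crop_audio_window comes from; these values follow the original Wav2Lip hparams and may differ in your hparams.py. The helper names are illustrative.

```python
import numpy as np

T = 5               # syncnet_T
IMG_SIZE = 96       # assumed hparams.img_size
MEL_STEP = 16       # syncnet_mel_step_size
MEL_PER_SEC = 80.0  # assumed 16000 Hz / hop 200
FPS = 25.0          # assumed hparams.fps


def build_visual_input(window: np.ndarray) -> np.ndarray:
    """window: (T, H, W, 3) frames, laid out like np.asarray(window) in __getitem__."""
    frames = np.split(window, T, axis=0)                 # T arrays of (1, H, W, 3)
    x = np.concatenate(frames, axis=3) / 255.            # (1, H, W, 3*T)
    x = np.squeeze(x, axis=0).transpose(2, 0, 1)         # (3*T, H, W)
    return x[:, x.shape[1] // 2:]                        # lower half of the face only


def crop_mel(spec: np.ndarray, start_frame_num: int) -> np.ndarray:
    """Same arithmetic as crop_audio_window: video frame index -> mel row index."""
    start_idx = int(MEL_PER_SEC * (start_frame_num / FPS))
    return spec[start_idx:start_idx + MEL_STEP, :]


window = np.zeros((T, IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8)
x = build_visual_input(window)
spec = np.zeros((200, 80))          # hypothetical 2.5 s mel spectrogram, (time, n_mels)
mel = crop_mel(spec, 25).T[None]    # mirrors torch.FloatTensor(mel.T).unsqueeze(0)
print(x.shape, mel.shape)  # (15, 48, 96) (1, 80, 16)
```

Under these assumptions a window of 16 mel rows spans 16 / 80 = 0.2 s, which is exactly 5 video frames at 25 fps, so the audio window lines up with the T = 5 frames the discriminator sees.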

This script trains the lip-sync discriminator on top of the SyncNet network (SyncNet_color in models).
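The training signal of this discriminator is cosine_loss: the cosine similarity between the audio embedding a and the video embedding v is treated as the probability that the pair is in sync and scored with nn.BCELoss against y (1 = frames and audio window match, 0 = frames taken from a randomly chosen wrong window). Below is a minimal numpy restatement, assuming the embeddings are non-negative (presumably the case after the ReLU layers of a Wav2Lip-style SyncNet) so the similarity lies in [0, 1]. The eps clipping is this sketch's own safeguard; nn.BCELoss instead clamps its log terms.

```python
import numpy as np


def cosine_bce_loss(a: np.ndarray, v: np.ndarray, y: np.ndarray, eps: float = 1e-7) -> float:
    """numpy sketch of cosine_loss: binary cross-entropy on cos(a, v).

    a, v: (B, D) audio / video embeddings; y: (B,) with 1 = in sync, 0 = off sync.
    """
    d = np.sum(a * v, axis=1) / (np.linalg.norm(a, axis=1) * np.linalg.norm(v, axis=1) + eps)
    d = np.clip(d, eps, 1.0 - eps)  # keep log() finite; assumes d is already >= 0
    return float(-np.mean(y * np.log(d) + (1.0 - y) * np.log(1.0 - d)))


a = np.array([[1.0, 0.0], [1.0, 0.0]])
v = np.array([[1.0, 0.0], [0.0, 1.0]])  # row 0 aligned, row 1 orthogonal
print(cosine_bce_loss(a, v, np.array([1.0, 0.0])))  # close to 0: labels agree with similarity
print(cosine_bce_loss(a, v, np.array([0.0, 1.0])))  # large: labels swapped
```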

emotion_disc_train.py

import argparse
import json
import os
from tqdm import tqdm
import random as rn
import shutil

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score
from torch.utils.tensorboard import SummaryWriter

from models import emo_disc
from datagen_aug import Dataset

def initParams():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("-i", "--in-path", type=str, help="Input folder containing train data", default=None, required=True)
    # parser.add_argument("-v", "--val-path", type=str, help="Input folder containing validation data", default=None, required=True)
    parser.add_argument("-o", "--out-path", type=str, help="output folder", default='../models/def', required=True)

    parser.add_argument('--num_epochs', type=int, default=10000)
    parser.add_argument("--batch-size", type=int, default=64)

    parser.add_argument('--lr_emo', type=float, default=1e-06)

    parser.add_argument("--gpu-no", type=str, help="select gpu", default='1')
    parser.add_argument('--seed', type=int, default=9)

    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_no

    args.batch_size = args.batch_size * max(int(torch.cuda.device_count()), 1)
    args.steplr = 200

    args.filters = [64, 128, 256, 512, 512]
    #-----------------------------------------#
    #           Reproducible results          #
    #-----------------------------------------#
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    rn.seed(args.seed)
    torch.manual_seed(args.seed)
    #-----------------------------------------#
   
    if not os.path.exists(args.out_path):
        os.makedirs(args.out_path)
    else:
        shutil.rmtree(args.out_path)
        os.mkdir(args.out_path)

    with open(os.path.join(args.out_path, 'args.txt'), 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    args.cuda = torch.cuda.is_available() 
    print('Cuda device available: ', args.cuda)
    args.device = torch.device("cuda" if args.cuda else "cpu") 
    args.kwargs = {'num_workers': 0, 'pin_memory': True} if args.cuda else {}

    return args

def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d or type(m) == nn.Conv1d:
        torch.nn.init.xavier_uniform_(m.weight)

def enableGrad(model, requires_grad):
    for p in model.parameters():
        p.requires_grad_(requires_grad)
  

def train():
    args = initParams()
    
    trainDset = Dataset(args)

    train_loader = torch.utils.data.DataLoader(trainDset,
                                               batch_size=args.batch_size, 
                                               shuffle=True,
                                               drop_last=True,
                                               **args.kwargs)
    
    device_ids = list(range(torch.cuda.device_count()))
    
    disc_emo = emo_disc.DISCEMO().to(args.device)
    disc_emo.apply(init_weights)
    #disc_emo = nn.DataParallel(disc_emo, device_ids)

    emo_loss_disc = nn.CrossEntropyLoss()

    num_batches = len(train_loader)
    print(args.batch_size, num_batches)

    global_step = 0
    
    for epoch in range(args.num_epochs):
        print('Epoch: {}'.format(epoch))
        prog_bar = tqdm(enumerate(train_loader))
        running_loss = 0.
        for step, (x, y) in prog_bar:
            video, emotion = x.to(args.device), y.to(args.device)

            disc_emo.train()

            disc_emo.opt.zero_grad() # .module is because of nn.DataParallel 

            class_real = disc_emo(video)

            loss = emo_loss_disc(class_real, torch.argmax(emotion, dim=1))

            running_loss += loss.item()

            loss.backward()
            disc_emo.opt.step() # .module is because of nn.DataParallel 
            
            # every 1000 steps, save the weights as disc_emo_<step>.pth
            if global_step % 1000 == 0:
                print('Saving the network')
                torch.save(disc_emo.state_dict(), os.path.join(args.out_path, f'disc_emo_{global_step}.pth'))
                print('Network has been saved')
            
            prog_bar.set_description('classification Loss: {}'.format(running_loss / (step + 1)))

            global_step += 1

        writer.add_scalar("classification Loss", running_loss/num_batches, epoch)
        
        disc_emo.scheduler.step() # .module is because of nn.DataParallel 

if __name__ == "__main__":

    writer = SummaryWriter('runs/emo_disc_exp4')
    train()
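The emotion labels arrive one-hot encoded, while nn.CrossEntropyLoss expects class indices, which is why the script passes torch.argmax(emotion, dim=1) as the target. A minimal numpy sketch of the same computation (softmax cross-entropy averaged over the batch); `cross_entropy_from_onehot` is an illustrative name, not a function from the repository.

```python
import numpy as np


def cross_entropy_from_onehot(logits: np.ndarray, onehot: np.ndarray) -> float:
    """Mean softmax cross-entropy; the target class is the argmax of each one-hot row."""
    targets = np.argmax(onehot, axis=1)                     # (B,) class indices
    shifted = logits - logits.max(axis=1, keepdims=True)    # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(targets)), targets].mean())


logits = np.zeros((2, 6))       # a classifier with no preference among 6 emotions
onehot = np.eye(6)[[0, 3]]      # two samples: class 0 and class 3
print(round(cross_entropy_from_onehot(logits, onehot), 4))  # 1.7918, i.e. ln(6)
```

An untrained, uniform classifier therefore starts around ln(6) ≈ 1.79, which is a useful reference point when reading the "classification Loss" curve in TensorBoard.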

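Downloading CREMA-D with a plain git clone:

git clone https://github.com/CheyneyComputerScience/CREMA-D
# the MP3, WAV, FLV and other media files come out broken after cloning; use the git lfs clone command from section 1 instead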


Reposted from blog.csdn.net/weixin_44649780/article/details/130961092