So Good~ Running BERT with PyTorch on a Mac

Preface

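PyTorch finally supports hardware acceleration on the Mac (the MPS backend, introduced in PyTorch 1.12). My verdict in two words: so good~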

Over the weekend I set up the environment on my own machine, which is configured as follows:

In this post I then compare the running time of a convolution and of a BERT model with and without Mac hardware acceleration.

Installation

The hardware-accelerated PyTorch build for the Mac installs with the command given on the official website:

https://pytorch.org/get-started/locally/

conda install pytorch torchvision torchaudio -c pytorch
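To confirm that the installed build can actually use the GPU, a quick check along these lines can be run. It is a sketch that relies only on `torch.backends.mps.is_built()` and `torch.backends.mps.is_available()` (present since PyTorch 1.12); `mps_status` and `smoke_test` are hypothetical helper names. MPS also requires macOS 12.3 or later, so `available` can be `False` even when `built` is `True`.

```python
import torch


def mps_status() -> dict:
    """Report whether this PyTorch build can use the Mac GPU via MPS."""
    return {
        # True if this binary was compiled with MPS support.
        "built": torch.backends.mps.is_built(),
        # True only if MPS can actually run on this machine and OS version.
        "available": torch.backends.mps.is_available(),
    }


def smoke_test(device: str) -> float:
    """Run a tiny matrix product on `device` and return the result as a float."""
    a = torch.ones(4, 4, device=device)
    # Every entry of a @ a is 4, and there are 16 entries, so the sum is 64.
    return (a @ a).sum().item()


if __name__ == "__main__":
    status = mps_status()
    print(status)
    print("cpu:", smoke_test("cpu"))
    if status["available"]:
        print("mps:", smoke_test("mps"))
```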

Quick test

Use a convolution to compare the running time with and without hardware acceleration.

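import time

import torch


def benchmark(dev, n_iter=1000):
    # Time n_iter forward passes of a 3x3 convolution on device `dev`.
    conv = torch.nn.Conv2d(10, 10, 3).to(dev)
    img = torch.randn(64, 10, 64, 64).to(dev)
    with torch.no_grad():
        conv(img).cpu()  # warm-up: the first call includes one-time setup cost
        t0 = time.perf_counter()
        for _ in range(n_iter):
            out = conv(img)
        # MPS runs kernels asynchronously; copying the last result to the CPU
        # waits for the queued GPU work, so the timing covers the real compute.
        out.cpu()
        t1 = time.perf_counter()
    print('Use {}, time:{}'.format(dev, t1 - t0))


if torch.backends.mps.is_available():
    benchmark('mps:0')
benchmark('cpu')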

Results

BERT test

For this test I use Hugging Face's GLUE example code.

Data preparation

Run the following script to download the data.

''' Script for downloading all GLUE data.

Note: for legal reasons, we are unable to host MRPC.
You can either use the version hosted by the SentEval team, which is already tokenized,
or you can download the original data from (https://download.microsoft.com/download/D/4/6/D46FF87A-F6B9-4252-AA8B-3604ED519838/MSRParaphraseCorpus.msi) and extract the data from it manually.
For Windows users, you can run the .msi file. For Mac and Linux users, consider an external library such as 'cabextract' (see below for an example).
You should then rename and place specific files in a folder (see below for an example).

mkdir MRPC
cabextract MSRParaphraseCorpus.msi -d MRPC
cat MRPC/_2DEC3DBE877E4DB192D17C0256E90F1D | tr -d $'\r' > MRPC/msr_paraphrase_train.txt
cat MRPC/_D7B391F9EAFF4B1B8BCE8F21B20B1B61 | tr -d $'\r' > MRPC/msr_paraphrase_test.txt
rm MRPC/_*
rm MSRParaphraseCorpus.msi

1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now.
2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray!
'''

import argparse
import io
import os
import sys
import urllib.error
import urllib.request
import zipfile

TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
# MRPC has no archive here; format_mrpc() builds it from the SentEval files below.
TASK2PATH = {"CoLA": 'https://dl.fbaipublicfiles.com/glue/data/CoLA.zip',
             "SST": 'https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',
             "QQP": 'https://dl.fbaipublicfiles.com/glue/data/QQP-clean.zip',
             "STS": 'https://dl.fbaipublicfiles.com/glue/data/STS-B.zip',
             "MNLI": 'https://dl.fbaipublicfiles.com/glue/data/MNLI.zip',
             "QNLI": 'https://dl.fbaipublicfiles.com/glue/data/QNLIv2.zip',
             "RTE": 'https://dl.fbaipublicfiles.com/glue/data/RTE.zip',
             "WNLI": 'https://dl.fbaipublicfiles.com/glue/data/WNLI.zip',
             "diagnostic": 'https://dl.fbaipublicfiles.com/glue/data/AX.tsv'}

MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'


def download_and_extract(task, data_dir):
    print("Downloading and extracting %s..." % task)
    if task == "MNLI":
        print("\tNote (12/10/20): This script no longer downloads SNLI. "
              "You will need to manually download and format the data to use SNLI.")
    data_file = "%s.zip" % task
    urllib.request.urlretrieve(TASK2PATH[task], data_file)
    with zipfile.ZipFile(data_file) as zip_ref:
        zip_ref.extractall(data_dir)
    os.remove(data_file)
    print("\tCompleted!")


def format_mrpc(data_dir, path_to_data):
    print("Processing MRPC...")
    mrpc_dir = os.path.join(data_dir, "MRPC")
    if not os.path.isdir(mrpc_dir):
        os.mkdir(mrpc_dir)
    if path_to_data:
        mrpc_train_file = os.path.join(path_to_data, "msr_paraphrase_train.txt")
        mrpc_test_file = os.path.join(path_to_data, "msr_paraphrase_test.txt")
    else:
        try:
            mrpc_train_file = os.path.join(mrpc_dir, "msr_paraphrase_train.txt")
            mrpc_test_file = os.path.join(mrpc_dir, "msr_paraphrase_test.txt")
            urllib.request.urlretrieve(MRPC_TRAIN, mrpc_train_file)
            urllib.request.urlretrieve(MRPC_TEST, mrpc_test_file)
        except urllib.error.HTTPError:
            print("Error downloading MRPC")
            return
    assert os.path.isfile(mrpc_train_file), "Train data not found at %s" % mrpc_train_file
    assert os.path.isfile(mrpc_test_file), "Test data not found at %s" % mrpc_test_file

    # Rewrite the test split as test.tsv, dropping the label column.
    with io.open(mrpc_test_file, encoding='utf-8') as data_fh, \
            io.open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding='utf-8') as test_fh:
        data_fh.readline()  # skip the header row
        test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
        for idx, row in enumerate(data_fh):
            label, id1, id2, s1, s2 = row.strip().split('\t')
            test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))

    try:
        # TASK2PATH has no "MRPC" entry, so this raises KeyError and the
        # train/dev split below is skipped unless a dev-ID URL is added there.
        urllib.request.urlretrieve(TASK2PATH["MRPC"], os.path.join(mrpc_dir, "dev_ids.tsv"))
    except (KeyError, urllib.error.HTTPError):
        print("\tError downloading standard development IDs for MRPC. You will need to manually split your data.")
        return

    dev_ids = []
    with io.open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding='utf-8') as ids_fh:
        for row in ids_fh:
            dev_ids.append(row.strip().split('\t'))

    # Split the original training file into train.tsv and dev.tsv by ID pair.
    with io.open(mrpc_train_file, encoding='utf-8') as data_fh, \
            io.open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding='utf-8') as train_fh, \
            io.open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding='utf-8') as dev_fh:
        header = data_fh.readline()
        train_fh.write(header)
        dev_fh.write(header)
        for row in data_fh:
            label, id1, id2, s1, s2 = row.strip().split('\t')
            if [id1, id2] in dev_ids:
                dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
            else:
                train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))

    print("\tCompleted!")


def download_diagnostic(data_dir):
    print("Downloading and extracting diagnostic...")
    if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
        os.mkdir(os.path.join(data_dir, "diagnostic"))
    data_file = os.path.join(data_dir, "diagnostic", "diagnostic.tsv")
    urllib.request.urlretrieve(TASK2PATH["diagnostic"], data_file)
    print("\tCompleted!")


def get_tasks(task_names):
    task_names = task_names.split(',')
    if "all" in task_names:
        tasks = TASKS
    else:
        tasks = []
        for task_name in task_names:
            assert task_name in TASKS, "Task %s not found!" % task_name
            tasks.append(task_name)
    return tasks


def main(arguments):
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data')
    parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
                        type=str, default='all')
    parser.add_argument('--path_to_mrpc',
                        help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_test.txt',
                        type=str, default='')
    args = parser.parse_args(arguments)

    if not os.path.isdir(args.data_dir):
        os.mkdir(args.data_dir)
    tasks = get_tasks(args.tasks)

    for task in tasks:
        if task == 'MRPC':
            format_mrpc(args.data_dir, args.path_to_mrpc)
        elif task == 'diagnostic':
            download_diagnostic(args.data_dir)
        else:
            download_and_extract(task, args.data_dir)


if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
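The script decides whether a training row goes to the dev split with `[id1, id2] in dev_ids`, which scans a list for every row. A minimal sketch of the same train/dev split using a set of ID tuples is below; `split_mrpc` is a hypothetical helper that works on lists of lines rather than files so it is easy to try out, and the two rows in the usage part are made-up toy data.

```python
def split_mrpc(train_lines, dev_id_lines):
    """Split MRPC training rows into (train, dev) lists of lines.

    train_lines:  tab-separated rows of msr_paraphrase_train.txt, header first.
    dev_id_lines: "id1<TAB>id2" rows, the format the script expects for dev_ids.tsv.
    Both returned lists keep the header as their first line.
    """
    # A set of (id1, id2) tuples makes each membership test O(1) on average,
    # instead of a linear scan over a list of [id1, id2] pairs.
    dev_ids = {tuple(line.strip().split('\t')) for line in dev_id_lines if line.strip()}
    rows = iter(train_lines)
    header = next(rows)
    train, dev = [header], [header]
    for row in rows:
        _, id1, id2, _, _ = row.strip().split('\t')  # label, ids, two sentences
        (dev if (id1, id2) in dev_ids else train).append(row)
    return train, dev


if __name__ == "__main__":
    header = "Quality\t#1 ID\t#2 ID\t#1 String\t#2 String\n"
    rows = [header,
            "1\t101\t102\tA cat sat.\tA cat was sitting.\n",
            "0\t201\t202\tIt rained.\tIt was sunny.\n"]
    train, dev = split_mrpc(rows, ["201\t202\n"])
    print(len(train) - 1, "train rows,", len(dev) - 1, "dev rows")  # 1 train rows, 1 dev rows
```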

Environment setup

The requirements are as follows:

accelerate
datasets >= 1.8.0
sentencepiece != 0.1.92
scipy
scikit-learn
protobuf
numpy==1.17.3
#torch >= 1.3

Code

I use Hugging Face's run_glue_no_trainer.py (from examples/pytorch/text-classification in the transformers repository).

The launch command is:

export TASK_NAME=mrpc

python run_glue_no_trainer.py \
  --model_name_or_path Pretrained_LMs/bert-base-cased \
  --task_name $TASK_NAME \
  --max_length 128 \
  --per_device_train_batch_size 32 \
  --learning_rate 2e-5 \
  --num_train_epochs 3 \
  --output_dir ./output/$TASK_NAME/

In the script, I changed the device as follows:

    accelerator.state.device = 'mps'
    print('-' * 100)
    print(accelerator.state.device)
    print('-' * 100)
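Overriding `accelerator.state.device` patches Accelerate's shared state from outside. In a plain PyTorch training loop, the usual pattern is to pick the device once and move the model and every batch to it. A minimal sketch, assuming batches are dicts of tensors as produced by a tokenizer's data collator; `select_device`, `to_device`, and the `nn.Linear` stand-in for BERT are hypothetical and not part of run_glue_no_trainer.py.

```python
import torch
from torch import nn


def select_device() -> torch.device:
    # Prefer the Mac GPU (MPS backend) when usable; otherwise fall back to the CPU.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def to_device(batch: dict, device: torch.device) -> dict:
    # A tokenized batch is a dict of tensors (input_ids, attention_mask, labels, ...).
    return {name: tensor.to(device) for name, tensor in batch.items()}


if __name__ == "__main__":
    device = select_device()
    model = nn.Linear(128, 2).to(device)  # hypothetical stand-in for the BERT classifier
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    batch = to_device({"features": torch.randn(32, 128),
                       "labels": torch.randint(0, 2, (32,))}, device)
    loss = nn.functional.cross_entropy(model(batch["features"]), batch["labels"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(device, loss.item())
```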

Results

On the CPU, training takes about 1 h:

Num processes: 1
Process index: 0
Local process index: 0
Device: cpu
...
07/16/2022 17:13:00 - INFO - __main__ - ***** Running training *****
07/16/2022 17:13:00 - INFO - __main__ -   Num examples = 3668
07/16/2022 17:13:00 - INFO - __main__ -   Num Epochs = 3
07/16/2022 17:13:00 - INFO - __main__ -   Instantaneous batch size per device = 32
07/16/2022 17:13:00 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 32
07/16/2022 17:13:00 - INFO - __main__ -   Gradient Accumulation steps = 1
07/16/2022 17:13:00 - INFO - __main__ -   Total optimization steps = 345
  2%|███▌                                                                                                                                                                                                       | 6/345 [01:06<1:03:49, 11.30s/it]

With hardware acceleration, it takes about 20 min:

Num processes: 1
Process index: 0
Local process index: 0
Device: mps
...
07/16/2022 17:14:29 - INFO - __main__ - ***** Running training *****
07/16/2022 17:14:29 - INFO - __main__ -   Num examples = 3668
07/16/2022 17:14:29 - INFO - __main__ -   Num Epochs = 3
07/16/2022 17:14:29 - INFO - __main__ -   Instantaneous batch size per device = 32
07/16/2022 17:14:29 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 32
07/16/2022 17:14:29 - INFO - __main__ -   Gradient Accumulation steps = 1
07/16/2022 17:14:29 - INFO - __main__ -   Total optimization steps = 345
  5%|██████████▋                                                                                                                                                                                                 | 18/345 [01:03<20:14,  3.71s/it]
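The "about 1 h" and "about 20 min" figures follow from the progress bars: both runs have 345 optimization steps, at 11.30 s/it on the CPU and 3.71 s/it on MPS. The small calculation below projects the full run time and the speedup, assuming the per-step time stays constant over the whole run (the logs only cover the first few steps); `estimate` is a hypothetical helper.

```python
def estimate(total_steps: int, sec_per_step: float) -> float:
    """Projected wall-clock minutes for a run at a constant per-step time."""
    return total_steps * sec_per_step / 60


STEPS = 345                 # "Total optimization steps" in both logs
CPU_S, MPS_S = 11.30, 3.71  # s/it shown by the two progress bars

if __name__ == "__main__":
    cpu_min = estimate(STEPS, CPU_S)
    mps_min = estimate(STEPS, MPS_S)
    speedup = CPU_S / MPS_S
    print(f"CPU ~{cpu_min:.0f} min, MPS ~{mps_min:.0f} min, speedup ~{speedup:.1f}x")
    # prints: CPU ~65 min, MPS ~21 min, speedup ~3.0x
```

So on this machine, MPS makes BERT fine-tuning roughly three times faster than the CPU.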

The Mac's Activity Monitor confirms that the program does use the GPU for acceleration.

Bug fix

While running, the following error occurred:

OMP: Error #15: Initializing libomp.dylib, but found libiomp5.dylib already initialized.

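I resolved it by following this link. If Python was installed via Conda, the MKL bundled with Conda's numpy package can easily conflict with libraries already on the system; the options are to update the numpy package in Conda or to switch to the system library.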

My solution: downgrade numpy, here to version 1.17.3:

pip install numpy==1.17.3


Reposted from blog.csdn.net/qq_40676033/article/details/125826730