Docker容器中实现Tensorflow分布式训练

Docker容器中实现Tensorflow分布式训练

一、简介

Tensorflow分布式介绍:tensorflow分布式训练主要有以下几种形式–单机多卡、多机单卡、多机多卡;以上几种形式是基于PS结构的,使用的通信方式–同步(同步SGD)、异步(异步SGD) 。

  • 环境
    • win10+ 虚拟机 + Centos7 + docker + tensorflow
  • 内容
    • 本文的主要内容是使用docker容器来模拟多机多卡的情况

二、在docker中实现分布式训练

  1. Tensorflow镜像的制作

    1. 编辑Dockerfile 文件vim tensorflow-cpu,并写入
      FROM centos:7.3.1611
      MAINTAINER urmsone
      RUN yum install -y vim  net-tools
      RUN yum install -y epel-release && yum install -y gcc python-devel python2-pip && pip install --upgrade pip && pip install jupyter && python -m ipykernel.kernelspec
      RUN pip install -i https://pypi.tuna.tsinghua.edu.cn/simple/ https://mirrors.tuna.tsinghua.edu.cn/tensorflow/linux/cpu/tensorflow-1.3.0-cp27-none-linux_x86_64.whl
      # 复制分布式脚本到容器根目录
      COPY mnist_replica.py /
      # 复制data目录下的mnist数据集到容器的/tmp/mnist-data/目录中
      COPY data/* /tmp/mnist-data/
      
    2. 拉取分布式训练脚本
      curl https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/tools/dist_test/python/mnist_replica.py -o mnist_replica.py
    3. 准备mnist数据集
    4. 完成镜像制作docker build -t tensorflow:1 -f tensorflow-cpu .
  2. 启动容器

    1. 启动容器ps,作为参数服务器
      docker run -it --name ps -p 2222 --rm tensorflow:1 /bin/bash
    2. 启动容器worker1,作为计算服务器1
      docker run -it --name worker1 -p 2222 --rm tensorflow:1 /bin/bash
    3. 启动容器worker2,作为计算服务器2
      docker run -it --name worker2 -p 2222 --rm tensorflow:1 /bin/bash
  3. 在容器中搭建集群并开始分布式训练

    1. 在ps容器中执行
      python mnist_replica.py --ps_hosts=172.17.0.2:2222 --worker_hosts=172.17.0.5:2222,172.17.0.3:2222 --job_name=ps --task_index=0
    2. 在worker1容器中执行
      python mnist_replica.py --ps_hosts=172.17.0.2:2222 --worker_hosts=172.17.0.5:2222,172.17.0.3:2222 --job_name=worker --task_index=0
    3. 在worker2容器中执行
      python mnist_replica.py --ps_hosts=172.17.0.2:2222 --worker_hosts=172.17.0.5:2222,172.17.0.3:2222 --job_name=worker --task_index=1

    注:python命令中的–ps_hosts为ps容器的ip地址,–worker_hosts为worker1和worker2的ip地址;可以使用命令docker inspect ps |grep Addr查看容器的ip地址

  4. 报错总结

    1. tensorflow.python.framework.errors_impl.UnknownError: Could not start gRPC server
      原因:上次运行的python mnist_replica.py没有中断,再次运行时,导致发生未知错误。
      解决办法:只需让正在运行的程序终止运行,然后再重新运行就好了。
    2. TensorFlow IOError: [Errno socket error] [Errno 104] Connection reset by peer
      原因:input_data.read_data_sets()读取mnist数据集时,如果文件不存在,会自动的远程拉取。如拉取数据集的远程url需要翻墙才能访问时,就会报以上的网络错误。
      解决方法:更换不需翻墙的url或者手动下载mnist数据集。

猜你喜欢

转载自blog.csdn.net/Urms_handsomeyu/article/details/86940028
今日推荐