四、Tensorflow的分布式训练

TensorFlow中的集群(cluster)指的是一系列能够针对图(Graph)进行分布式计算任务(task)。每个任务是同服务(server)相关联的。TensorFlow中的服务会包含一个用于创建session的主节点和至少一个用于图运算的工作节点,一个集群可以被拆分为一个活着多个作业(job),每个作业可以包含至少一个任务。

以下的例子是一个最简单的例子

1、服务端代码:

import tensorflow as tf

'''
运行命令:
python tensf_server_01 --job_name=ps --task_index=0
python tensf_server_01 --job_name=ps --task_index=0
python tensf_server_01 --job_name=work --task_index=0
python tensf_server_01 --job_name=work --task_index=1
python tensf_server_01 --job_name=work --task_index=2

'''

#1、配置服务器相关信息
#因为tensorflow底层代码中,默认就是使用ps和work分别表示两类不同的工作节点
#ps:变量/张量的初始化,存储相关节点
#work:变量/张量的计算/运算的相关节点
ps_host = ['127.0.0.1:33331','127.0.0.1:33332']
work_hosts = ['127.0.0.1:33333','127.0.0.1:33334','127.0.0.1:33335']
cluster = tf.train.ClusterSpec({'ps':ps_host,'work':work_hosts})

#2、定义一些运行参数(在运行该python文件的时候就可以制定这些参数了)
tf.app.flags.DEFINE_string('job_name',default_value='work',docstring="One of 'ps' or 'work'")
tf.app.flags.DEFINE_integer('task_index',default_value=0,docstring="Index of task within the job")
FLAGS = tf.app.flags.FLAGS

#2、启动服务
#_下划线表示占位符
def main(_):
    print(FLAGS.job_name)
    server = tf.train.Server(cluster,job_name=FLAGS.job_name,task_index=FLAGS.task_index)
    server.join()

if __name__ == '__main__':
    #底层默认会调用main方法
    tf.app.run()

2、client端的代码:

import tensorflow as tf
import numpy as np


#1、构建图
#表示使用ps的job,task:0表示使用第一个配置,也就是127.0.0.1:33331
with tf.device('/job:ps/task:0'):
    #构造函数
    x = tf.constant(np.random.rand(100).astype(np.float32))

with tf.device('/job:ps/task:1'):
    y = y = x * 0.2 +0.3

#2、运行
with tf.Session(target='grpc://127.0.0.1:33335',
                config=tf.ConfigProto(log_device_placement=True)) as sess:
    sess.run(y)

猜你喜欢

转载自www.cnblogs.com/allen-GC/p/10721064.html
今日推荐