Specifying which GPU Keras uses

GPU indices are zero-based: index 0 is the first GPU.

Example: CUDA_VISIBLE_DEVICES=2 python3 predict_3.py
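The same restriction can also be applied from inside the script, as long as the variable is set before TensorFlow or Keras is imported. The following is a minimal sketch; the helper name `select_gpus` is hypothetical and is not part of Keras or TensorFlow.

```python
import os


def select_gpus(device_ids):
    """Restrict this process to the given GPU indices (hypothetical helper).

    CUDA numbers devices from 0, and the value is a comma-separated list.
    """
    value = ",".join(str(i) for i in device_ids)
    os.environ["CUDA_VISIBLE_DEVICES"] = value
    return value


# Call this before importing tensorflow or keras; once CUDA has been
# initialized, changing the variable no longer has any effect.
visible = select_gpus([2])
print(visible)
```

Inside the process, the selected GPU is then renumbered, so the single visible device is addressed as `/gpu:0`.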



How can I run a Keras model on multiple GPUs?

We recommend doing so using the TensorFlow backend. There are two ways to run a single model on multiple GPUs: data parallelism and device parallelism.

In most cases, what you need is data parallelism.

Data parallelism

Data parallelism consists of replicating the target model once on each device and using each replica to process a different fraction of the input data. Keras has a built-in utility, keras.utils.multi_gpu_model, which can produce a data-parallel version of any model and achieves quasi-linear speedup on up to 8 GPUs.

For more information, see the documentation for multi_gpu_model. Here is a quick example:

from keras.utils import multi_gpu_model

# Replicates `model` on 8 GPUs.
# This assumes that your machine has 8 available GPUs.
parallel_model = multi_gpu_model(model, gpus=8)
parallel_model.compile(loss='categorical_crossentropy',
                       optimizer='rmsprop')

# This `fit` call will be distributed on 8 GPUs.
# Since the batch size is 256, each GPU will process 32 samples.
parallel_model.fit(x, y, epochs=20, batch_size=256)
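To make the per-GPU arithmetic concrete, here is a small sketch of how one batch could be divided into contiguous slices, one per replica. It illustrates the idea only and is not the Keras implementation; the function name `split_batch` and the choice to give any remainder to the last replica are assumptions made for this example.

```python
def split_batch(batch_size, num_gpus):
    """Return (start, end) index pairs, one slice of the batch per GPU."""
    step = batch_size // num_gpus
    bounds = []
    for i in range(num_gpus):
        start = i * step
        # Assumption: the last replica also takes any leftover samples.
        end = batch_size if i == num_gpus - 1 else start + step
        bounds.append((start, end))
    return bounds


shards = split_batch(256, 8)
sizes = [end - start for start, end in shards]
print(sizes)
```

With a batch size of 256 and 8 GPUs, every slice holds 32 samples, matching the comment in the example above. A batch size that is a multiple of the GPU count keeps the load even across replicas.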

Device parallelism

Device parallelism consists of running different parts of the same model on different devices. It works best for models that have a parallel architecture, e.g. a model with two branches.

This can be achieved by using TensorFlow device scopes. Here is a quick example:

import tensorflow as tf
import keras

# Model where a shared LSTM is used to encode two different sequences in parallel
input_a = keras.Input(shape=(140, 256))
input_b = keras.Input(shape=(140, 256))

shared_lstm = keras.layers.LSTM(64)

# Process the first sequence on one GPU
with tf.device('/gpu:0'):
    encoded_a = shared_lstm(input_a)
# Process the next sequence on another GPU
with tf.device('/gpu:1'):
    encoded_b = shared_lstm(input_b)

# Concatenate results on CPU
with tf.device('/cpu:0'):
    merged_vector = keras.layers.concatenate([encoded_a, encoded_b],
                                             axis=-1)
