Обходной путь
Вот исправленная версия, которая не дает сбоев при сохранении:
from keras.layers import Lambda, concatenate
from keras import Model
import tensorflow as tf
def multi_gpu_model(model, gpus):
if isinstance(gpus, (list, tuple)):
num_gpus = len(gpus)
target_gpu_ids = gpus
else:
num_gpus = gpus
target_gpu_ids = range(num_gpus)
def get_slice(data, i, parts):
shape = tf.shape(data)
batch_size = shape[:1]
input_shape = shape[1:]
step = batch_size // parts
if i == num_gpus - 1:
size = batch_size - step * i
else:
size = step
size = tf.concat([size, input_shape], axis=0)
stride = tf.concat([step, input_shape * 0], axis=0)
start = stride * i
return tf.slice(data, start, size)
all_outputs = []
for i in range(len(model.outputs)):
all_outputs.append([])
# Place a copy of the model on each GPU,
# each getting a slice of the inputs.
for i, gpu_id in enumerate(target_gpu_ids):
with tf.device('/gpu:%d' % gpu_id):
with tf.name_scope('replica_%d' % gpu_id):
inputs = []
# Retrieve a slice of the input.
for x in model.inputs:
input_shape = tuple(x.get_shape().as_list())[1:]
slice_i = Lambda(get_slice,
output_shape=input_shape,
arguments={'i': i,
'parts': num_gpus})(x)
inputs.append(slice_i)
# Apply model on slice
# (creating a model replica on the target device).
outputs = model(inputs)
if not isinstance(outputs, list):
outputs = [outputs]
# Save the outputs for merging back together later.
for o in range(len(outputs)):
all_outputs[o].append(outputs[o])
# Merge outputs on CPU.
with tf.device('/cpu:0'):
merged = []
for name, outputs in zip(model.output_names, all_outputs):
merged.append(concatenate(outputs,
axis=0, name=name))
return Model(model.inputs, merged)
Вы можете использовать эту функцию multi_gpu_model
, пока ошибка не будет исправлена в keras. Кроме того, при загрузке модели важно предоставить объект модуля тензорного потока:
model = load_model('multi_gpu_model.h5', {'tf': tf})
Как это устроено
Проблема в строке import tensorflow
в середине multi_gpu_model
:
def multi_gpu_model(model, gpus):
...
import tensorflow as tf
...
Это создает замыкание для лямбда-функции get_slice
, которая включает в себя количество GPU (все в порядке) и модуль тензорного потока (не в порядке). Сохранение модели пытается сериализовать все слои, включая те, которые вызывают get_slice
, и терпит неудачу именно потому, что tf
находится в замыкании.
Решение состоит в том, чтобы переместить импорт из multi_gpu_model
, чтобы tf
стал глобальным объектом, хотя он по-прежнему необходим для работы get_slice
. Это устраняет проблему сохранения, но при загрузке необходимо указать tf
явно.
person
Maxim
schedule
02.01.2018