Interpretación del código fuente del metaaprendizaje agnóstico del modelo (MAML)

Recientemente, he leído el código fuente de MAML durante mucho tiempo y también aprendí TF. Quiero escribir un artículo para resumir el código fuente de MAML y desvelar el misterio de MAML.

papel

Inserte la descripción de la imagen aquí

Resumen

Proponemos un algoritmo para el metaaprendizaje que es agnóstico del modelo, en el sentido de que es compatible con cualquier modelo entrenado con descenso de gradiente y aplicable a una variedad de problemas de aprendizaje diferentes, incluida la clasificación, la regresión y el aprendizaje por refuerzo. El objetivo del metaaprendizaje es entrenar un modelo en una variedad de tareas de aprendizaje, de modo que pueda resolver nuevas tareas de aprendizaje utilizando solo una pequeña cantidad de muestras de capacitación. En nuestro enfoque, los parámetros del modelo se entrenan explícitamente de modo que una pequeña cantidad de pasos de gradiente con una pequeña cantidad de datos de entrenamiento de una nueva tarea producirán un buen rendimiento de generalización en esa tarea.De hecho, nuestro método entrena el modelo para que sea fácil de ajustar. Demostramos que este enfoque conduce a un rendimiento de vanguardia en dos puntos de referencia de clasificación de imágenes de pocos disparos, produce buenos resultados en la regresión de pocos disparos y acelera el ajuste para el aprendizaje de refuerzo de gradiente de políticas con políticas de redes neuronales.

Resumen: MAML encuentra un buen parámetro inicial en lugar de 0, lo que reduce en gran medida el tiempo de entrenamiento y la cantidad de muestras.

algoritmo

Inserte la descripción de la imagen aquí
Inserte la descripción de la imagen aquí
Como se muestra en la figura, tenemos el parámetro inicial θ, hay 3 tareas y cada tarea tiene el mejor parámetro θ *. En este momento, θ puede tener 3 direcciones de descenso de gradiente, pero no elegimos el descenso de gradiente, sino que dimos un paso en la dirección compartida por estos 3 puntos. De esta forma, el θ recién obtenido solo requiere unos pocos pasos para llegar al θ * de otras tareas.

Específicamente: El sexto paso del algoritmo calcula el θ * de cada tarea (no θ *, porque solo algunos pasos no se han completado), asumiendo que θ ha alcanzado este paso, se convierte en θi *, y luego encuentra este punto. gradiente, sumar y promediar, obtener una dirección neutral, ir en esta dirección, es decir, actualizar el original θ en el paso 8 y obtener un metaaprendiz

Código fuente

El código fuente que estoy buscando no es el código fuente proporcionado en este documento, sino la versión simplificada del código fuente implementada por "dragen1860" de acuerdo con el código fuente oficial. Para el enlace , puede ver lo más destacado de otros:

  • adoptado de la implementación oficial de cbfin con un rendimiento equivalente en mini-imagenet
  • estilo de código pequeño y limpio y muy fácil de seguir a partir de comentarios en casi todas las líneas
  • mejoras más rápidas y triviales, por ejemplo. 0,335 s por época en comparación con 0,563 s por época, lo que ahorra hasta 3,8 horas para un proceso de formación total de 60.000

Estructura de archivo

Inserte la descripción de la imagen aquí

Flujo de algoritmos de metaaprendizaje

  1. Entrar desde la entrada de la función principal
  2. Configuración de parámetros
  • nway: 5, número de clasificación, como gato, perro, caballo ...
  • kshot: 1, el número de muestras
  • kquery : 15 ,?
  • meta_batchsz: 4, el número de lotes en el metaaprendizaje, es decir, el número de tareas
  • K: 5, para encontrar el mejor θ * para cada tarea, MAML puede realizar un descenso de gradiente K, no un tiempo fijo
  1. Generar datos (tensor)

Revisemos primero el conjunto de soporte y el conjunto de consultas: cada tarea es una tarea del aprendizaje automático tradicional, incluido el conjunto de entrenamiento y el conjunto de pruebas, pero es fácil de confundir, no lo llamaremos conjunto de soporte y conjunto de consultas (estos dos Cómo ¿Grande es la configuración establecida?), luego se usan 4 tareas como meta-tren, llamado conjunto de trenes, y 4 tareas se usan como meta-prueba, llamado conjunto de prueba

Inserte la descripción de la imagen aquí

Los siguientes son el conjunto de soporte y el conjunto de consultas de las 4 tareas en la fase de meta-tren

# image_tensor: [4, 80, 84*84*3]
support_x = tf.slice(image_tensor, [0, 0, 0], [-1,  nway *  kshot, -1], name='support_x')
query_x = tf.slice(image_tensor, [0,  nway *  kshot, 0], [-1, -1, -1], name='query_x')
support_y = tf.slice(label_tensor, [0, 0, 0], [-1,  nway *  kshot, -1], name='support_y')
query_y = tf.slice(label_tensor, [0,  nway *  kshot, 0], [-1, -1, -1], name='query_y')
# support_x : [4, 1*5, 84*84*3]
# query_x   : [4, 15*5, 84*84*3]
# support_y : [4, 5, 5]
# query_y   : [4, 15*5, 5]

Inserte la descripción de la imagen aquí

El mismo método para construir dos conjuntos de la etapa de metaprueba

# construct test tensors.
image_tensor, label_tensor = db.make_data_tensor(training=False)
support_x_test = tf.slice(image_tensor, [0, 0, 0], [-1,  nway *  kshot, -1], name='support_x_test')
query_x_test = tf.slice(image_tensor, [0,  nway *  kshot, 0], [-1, -1, -1],  name='query_x_test')
support_y_test = tf.slice(label_tensor, [0, 0, 0], [-1,  nway *  kshot, -1],  name='support_y_test')
query_y_test = tf.slice(label_tensor, [0,  nway *  kshot, 0], [-1, -1, -1],  name='query_y_test')

El resultado final es que solo se dibuja el contenido de la tarea de un conjunto de trenes. El support_x real contiene 4 tareas. Además, en 5way, el conjunto de soporte debe ser el mismo que el conjunto de consultas:
Inserte la descripción de la imagen aquí

  1. Construya el modelo MAML, llame al método de construcción (los siguientes están todos en el método de construcción, salte del paso 8)
#这里的参数如84用来做tensor的reshape,为什么是这个数我也不知道
model = MAML(84, 3, 5)
model.build(support_x, support_y, query_x, query_y, K, meta_batchsz, mode='train')
  1. Luego ingresamos el método de construcción: para cada tarea, llamamos al algoritmo meta_task
result = tf.map_fn(meta_task, elems=(support_xb, support_yb, query_xb, query_yb),dtype=out_dtype, parallel_iterations=meta_batchsz, name='map_fn')

Este algoritmo meta_task corresponde al apartado marcado en rojo en el algoritmo, es decir, encontrar el mejor parámetro θi * para cada tarea:

Inserte la descripción de la imagen aquí

  1. Echemos un vistazo a los detalles del algoritmo meta_task, que en realidad es el proceso de los parámetros normales de actualización de propagación hacia atrás y derivación hacia adelante:

Inserte la descripción de la imagen aquí
Usamos supportx más peso para calcular el gradiente, el descenso del gradiente para obtener el peso rápido, usamos este peso rápido para probar el conjunto de consultas para obtener la pérdida de consultas y luego actualizar iterativamente el peso rápido K veces

Tenga en cuenta que solo escribí descenso de gradiente de un paso en mi imagen, y hay un descenso de gradiente de K pasos en el código real. Cada vez que se obtiene el peso rápido en el conjunto de soporte, la pérdida se calcula en el conjunto de consultas y se obtiene la pérdida de consultas.

  1. Fuera del circuito, realice un segundo descenso de gradiente

Calculamos la pérdida de tantas tareas en el conjunto de consultas y calculamos el gradiente para la pérdida de consultas.

# meta-train optim
optimizer = tf.train.AdamOptimizer(self.meta_lr, name='meta_optim')
# meta-train gradients, query_losses[-1] is the accumulated loss across over tasks.
gvs = optimizer.compute_gradients(self.query_losses[-1])

Luego actualice el gradiente del parámetro real θ

# meta-train grads clipping
gvs = [(tf.clip_by_norm(grad, 10), var) for grad, var in gvs]
# update theta
self.meta_op = optimizer.apply_gradients(gvs)

Este paso corresponde a la sección roja del algoritmo:
Inserte la descripción de la imagen aquí

  1. Salte de la compilación,
    es decir, todo lo anterior es una línea de código para hacer:
if  training:
		model.build(support_x, support_y, query_x, query_y, K, meta_batchsz, mode='train')
		model.build(support_x_test, support_y_test, query_x_test, query_y_test, K, meta_batchsz, mode='eval')
	else:
		model.build(support_x_test, support_y_test, query_x_test, query_y_test, K + 5, meta_batchsz, mode='test')

Luego, ingrese el método train (), es decir, 600,000 iteraciones. Cada iteración completa las siguientes funciones.
No he entendido la matriz de resultados aquí.

# this is the main op
		ops = [model.meta_op]

		# add summary and print op
		if iteration % 200 == 0:
			ops.extend([model.summ_op,
			            model.query_losses[0], model.query_losses[-1],
			            model.query_accs[0], model.query_accs[-1]])

		# run all ops
		result = sess.run(ops)

		# summary
		if iteration % 200 == 0:
			# summ_op
			# tb.add_summary(result[1], iteration)
			# query_losses[0]
			prelosses.append(result[2])
			# query_losses[-1]
			postlosses.append(result[3])
			# query_accs[0]
			preaccs.append(result[4])
			# query_accs[-1]
			postaccs.append(result[5])

			print(iteration, '\tloss:', np.mean(prelosses), '=>', np.mean(postlosses),
			      '\t\tacc:', np.mean(preaccs), '=>', np.mean(postaccs))
			prelosses, postlosses, preaccs, postaccs = [], [], [], []

		# evaluation
		if iteration % 2000 == 0:
			# DO NOT write as a = b = [], in that case a=b
			# DO NOT use train variable as we have train func already.
			acc1s, acc2s = [], []
			# sample 20 times to get more accurate statistics.
			for _ in range(200):
				acc1, acc2 = sess.run([model.test_query_accs[0],
				                        model.test_query_accs[-1]])
				acc1s.append(acc1)
				acc2s.append(acc2)

			acc = np.mean(acc2s)
			print('>>>>\t\tValidation accs: ', np.mean(acc1s), acc, 'best:', best_acc, '\t\t<<<<')

			if acc - best_acc > 0.05 or acc > 0.4:
				saver.save(sess, os.path.join('ckpt', 'mini.mdl'))
				best_acc = acc
				print('saved into ckpt:', acc)

saver.save(sess, os.path.join('ckpt', 'mini.mdl'))Guarde los parámetros del modelo y obtenga el metaaprendiz, luego ingresaremos al paso de prueba para ver si este aprendiz es realmente como dice el resumen del autor, ¿ puede completar el entrenamiento del modelo con unos pocos pasos de descenso de gradiente + una pequeña cantidad de datos?

Flujo de algoritmos de metaprueba

Entrene a θ en soporte para K + 5 veces para obtener θ *, y verifique que θ * sea bueno en consulta, porque las clases en el conjunto de soporte y el conjunto de consultas son las mismas.

ops = [model.test_support_acc]
ops.extend(model.test_query_accs)
result = sess.run(ops)
test_accs.append(result)

Uso de código

El archivo Léame es muy claro, pero estoy bajo el sistema win10, que es un poco diferente. El uso específico es el siguiente:

  1. Descargue la colección de imágenes de imagenet desde el enlace proporcionado por el autor, alrededor de cientos de miles de imágenes, 3G
  2. Modifique el archivo proc_images.py y cambie el comando python linux a windows
path = 'C:/Users/Administrator/Desktop/MAML-TensorFlow-master/miniimagenet/'
# Put in correct directory
for datatype in ['train', 'val', 'test']:
    os.system('mkdir ' + datatype)

    with open(datatype + '.csv', 'r') as f:
        reader = csv.reader(f, delimiter=',')
        last_label = ''
        for i, row in enumerate(reader):
            if i == 0:  # skip the headers
                continue
            label = row[1]
            image_name = row[0]
            if label != last_label:
                cur_dir = ''+datatype + '/' + label + '/'
                if not os.path.exists(path + cur_dir):
                    os.mkdir(path + cur_dir)
                last_label = label
            print( path+image_name + ' ' + path+cur_dir)
            #os.system('cpoy images/' + image_name + ' ' + cur_dir)
            shutil.move(path+'images/'+image_name, path+cur_dir)
  1. Configurar el entorno con aniconda: python3.6 TF1.15.0, conda enable está activado
  2. En el nuevo entorno, python main.py está bien
  3. Resultado: la velocidad es extremadamente lenta. Los resultados de usar la CPU de su computadora durante varias horas son los siguientes:

Inserte la descripción de la imagen aquí

Puede verse que la tasa de precisión está mejorando gradualmente.

2020.8.27todo :

  • En el siguiente paso, después de ejecutar los resultados, debe ver la correspondencia entre la parte experimental del artículo y este resultado.

  • Y haz un dibujo en el equipo de prueba.

  • Además, analice el contenido de estas dos funciones test val.

Supongo que te gusta

Origin blog.csdn.net/Protocols7/article/details/108250998
Recomendado
Clasificación