Después de entrenar un modelo en Tensorflow:
- ¿Cómo se guarda el modelo entrenado?
- ¿Cómo restaurar más tarde este modelo guardado?
python
tensorflow
machine-learning
model
mathetes
fuente
fuente
Respuestas:
Docs
tutorial exhaustivo y útil -> https://www.tensorflow.org/guide/saved_model
Guía detallada de Keras para guardar modelos -> https://www.tensorflow.org/guide/keras/save_and_serialize
De los documentos:
Salvar
Restaurar
Tensorflow 2
Esto todavía es beta, por lo que desaconsejaría por ahora. Si todavía quieres seguir ese camino, aquí está la
tf.saved_model
guía de usoTensorflow <2
simple_save
Muchas buenas respuestas, para completar agregaré mis 2 centavos: simple_save . También un ejemplo de código independiente que usa la
tf.data.Dataset
API.Python 3; Tensorflow 1.14
Restaurando:
Ejemplo independiente
Publicación original del blog
El siguiente código genera datos aleatorios por el bien de la demostración.
Dataset
y luego suIterator
. Obtenemos el tensor generado por el iterador, llamadoinput_tensor
que servirá como entrada a nuestro modelo.input_tensor
: un RNN bidireccional basado en GRU seguido de un clasificador denso. Porque, porque no.softmax_cross_entropy_with_logits
, optimizada conAdam
. Después de 2 épocas (de 2 lotes cada una), guardamos el modelo "entrenado" contf.saved_model.simple_save
. Si ejecuta el código como está, el modelo se guardará en una carpeta llamadasimple/
en su directorio de trabajo actual.tf.saved_model.loader.load
. Agarramos los marcadores de posición y logits congraph.get_tensor_by_name
y laIterator
operación de inicialización congraph.get_operation_by_name
.Código:
Esto imprimirá:
fuente
tf.contrib.layers
?[n.name for n in graph2.as_graph_def().node]
. Como dice la documentación, guardar simple tiene como objetivo simplificar la interacción con el servicio de tensorflow, este es el punto de los argumentos; sin embargo, otras variables aún se restauran, de lo contrario no se produciría inferencia. Simplemente tome sus variables de interés como lo hice en el ejemplo. Consulte la documentaciónglobal_step
argumento, si se detiene, intente retomar el entrenamiento nuevamente, pensará que es un paso uno. Al menos arruinará las visualizaciones de su tensorboardEstoy mejorando mi respuesta para agregar más detalles para guardar y restaurar modelos.
En (y después) Tensorflow versión 0.11 :
Guarda el modelo:
Restaurar el modelo:
Este y algunos casos de uso más avanzados se han explicado muy bien aquí.
Un tutorial rápido y completo para guardar y restaurar modelos de Tensorflow
fuente
:0
a los nombres?En (y después) TensorFlow versión 0.11.0RC1, puede guardar y restaurar su modelo directamente llamando
tf.train.export_meta_graph
y detf.train.import_meta_graph
acuerdo con https://www.tensorflow.org/programmers_guide/meta_graph .Guardar el modelo
Restaurar el modelo
fuente
<built-in function TF_Run> returned a result with an error set
tf.get_variable_scope().reuse_variables()
seguido devar = tf.get_variable("varname")
. Esto me da el error: "ValueError: el variable varname no existe o no se creó con tf.get_variable ()". ¿Por qué? ¿No debería ser esto posible?Para la versión TensorFlow <0.11.0RC1:
Los puntos de control que se guardan contienen valores para los
Variable
s en su modelo, no el modelo / gráfico en sí, lo que significa que el gráfico debe ser el mismo cuando restaure el punto de control.Aquí hay un ejemplo para una regresión lineal donde hay un ciclo de entrenamiento que guarda puntos de control de variables y una sección de evaluación que restaurará las variables guardadas en una ejecución anterior y calculará predicciones. Por supuesto, también puede restaurar variables y continuar entrenando si lo desea.
Aquí están los documentos para
Variable
s, que cubren el ahorro y la restauración. Y aquí están los documentos para elSaver
.fuente
batch_x
debe ser? ¿Binario? Numpy array?undefined
. ¿Me puede decir cuál es def de FLAGS para este código? @RyanSepassiMi entorno: Python 3.6, Tensorflow 1.3.0
Aunque ha habido muchas soluciones, la mayoría de ellas se basan en
tf.train.Saver
. Cuando cargamos un.ckpt
salvados porSaver
, tenemos que redefinir la red, ya sea tensorflow o utilizar algún nombre raro y recordado duro, por ejemplo'placehold_0:0'
,'dense/Adam/Weight:0'
. Aquí recomiendo usartf.saved_model
, un ejemplo más simple que se muestra a continuación, puede obtener más información al servir un modelo TensorFlow :Guarda el modelo:
Cargue el modelo:
fuente
Hay dos partes en el modelo, la definición del modelo, guardada
Supervisor
comograph.pbtxt
en el directorio del modelo y los valores numéricos de los tensores, guardados en archivos de puntos de control comomodel.ckpt-1003418
.La definición del modelo se puede restaurar usando
tf.import_graph_def
, y los pesos se restauran usandoSaver
.Sin embargo,
Saver
usa una lista de variables de retención de colección especial que se adjunta al modelo Graph, y esta colección no se inicializa usando import_graph_def, por lo que no puede usar las dos juntas en este momento (está en nuestra hoja de ruta para solucionarlo). Por ahora, debe usar el enfoque de Ryan Sepassi: construir manualmente un gráfico con nombres de nodo idénticos y usarloSaver
para cargar los pesos en él.(Alternativamente, podría piratearlo usando
import_graph_def
, usando , creando variables manualmente, y usandotf.add_to_collection(tf.GraphKeys.VARIABLES, variable)
para cada variable, luego usandoSaver
)fuente
También puedes tomar este camino más fácil.
Paso 1: inicializa todas tus variables
Paso 2: guarde la sesión dentro del modelo
Saver
y guárdelaPaso 3: restaurar el modelo
Paso 4: verifica tu variable
Mientras se ejecuta en una instancia de Python diferente, use
fuente
En la mayoría de los casos, guardar y restaurar desde el disco usando a
tf.train.Saver
es su mejor opción:También puede guardar / restaurar la estructura del gráfico en sí (consulte la documentación de MetaGraph para más detalles). Por defecto,
Saver
guarda la estructura del gráfico en un.meta
archivo. Puedes llamarimport_meta_graph()
para restaurarlo. Restaura la estructura del gráfico y devuelve unSaver
que puede usar para restaurar el estado del modelo:Sin embargo, hay casos en los que necesita algo mucho más rápido. Por ejemplo, si implementa una detención temprana, desea guardar puntos de control cada vez que el modelo mejora durante el entrenamiento (según lo medido en el conjunto de validación), luego, si no hay progreso durante algún tiempo, desea volver al mejor modelo. Si guarda el modelo en el disco cada vez que mejora, ralentizará enormemente el entrenamiento. El truco es guardar los estados variables en la memoria , luego restaurarlos más tarde:
Una explicación rápida: cuando crea una variable
X
, TensorFlow crea automáticamente una operación de asignaciónX/Assign
para establecer el valor inicial de la variable. En lugar de crear marcadores de posición y operaciones de asignación adicionales (lo que haría que el gráfico fuera desordenado), solo usamos estas operaciones de asignación existentes. La primera entrada de cada asignación op es una referencia a la variable que se supone que debe inicializar, y la segunda entrada (assign_op.inputs[1]
) es el valor inicial. Entonces, para establecer cualquier valor que queramos (en lugar del valor inicial), necesitamos usarfeed_dict
ay reemplazar el valor inicial. Sí, TensorFlow le permite alimentar un valor para cualquier operación, no solo para marcadores de posición, por lo que funciona bien.fuente
Como dijo Yaroslav, puede piratear la restauración desde un gráfico_def y un punto de control importando el gráfico, creando variables manualmente y luego utilizando un protector.
Implementé esto para mi uso personal, así que pensé en compartir el código aquí.
Enlace: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868
(Esto es, por supuesto, un truco, y no hay garantía de que los modelos guardados de esta manera sigan siendo legibles en futuras versiones de TensorFlow).
fuente
Si es un modelo guardado internamente, solo debe especificar un restaurador para todas las variables como
y úselo para restaurar variables en una sesión actual:
Para el modelo externo, debe especificar la asignación de los nombres de sus variables a sus nombres de variables. Puede ver los nombres de las variables del modelo con el comando
El script inspect_checkpoint.py se puede encontrar en la carpeta './tensorflow/python/tools' de la fuente de Tensorflow.
Para especificar el mapeo, puede usar mi Tensorflow-Worklab , que contiene un conjunto de clases y scripts para entrenar y reentrenar diferentes modelos. Incluye un ejemplo de reciclaje de modelos de ResNet, que se encuentra aquí.
fuente
all_variables()
ahora está en desusoAquí está mi solución simple para los dos casos básicos que difieren en si desea cargar el gráfico del archivo o compilarlo durante el tiempo de ejecución.
Esta respuesta es válida para Tensorflow 0.12+ (incluido 1.0).
Reconstruyendo el gráfico en código
Ahorro
Cargando
Cargando también el gráfico desde un archivo
Cuando utilice esta técnica, asegúrese de que todas sus capas / variables hayan establecido explícitamente nombres únicos.De lo contrario, Tensorflow hará que los nombres sean únicos y, por lo tanto, serán diferentes de los nombres almacenados en el archivo. No es un problema en la técnica anterior, porque los nombres están "destrozados" de la misma manera tanto en la carga como en el almacenamiento.
Ahorro
Cargando
fuente
global_step
variable y los promedios móviles de la normalización de lotes son variables no entrenables, pero definitivamente vale la pena guardar ambas. Además, debe distinguir más claramente la construcción del gráfico de la ejecución de la sesión, por ejemploSaver(...).save()
, creará nuevos nodos cada vez que lo ejecute. Probablemente no sea lo que quieres. Y hay más ...: /También puede consultar ejemplos en TensorFlow / skflow , que ofrece
save
yrestore
métodos que pueden ayudarlo a administrar fácilmente sus modelos. Tiene parámetros que también puede controlar con qué frecuencia desea hacer una copia de seguridad de su modelo.fuente
Si usa tf.train.MonitoredTrainingSession como sesión predeterminada, no necesita agregar código adicional para guardar / restaurar cosas. Simplemente pase un nombre de directorio de punto de control al constructor de MonitoredTrainingSession, usará ganchos de sesión para manejarlos.
fuente
Todas las respuestas aquí son geniales, pero quiero agregar dos cosas.
Primero, para explicar la respuesta de @ user7505159, puede ser importante agregar "./" al principio del nombre del archivo que está restaurando.
Por ejemplo, puede guardar un gráfico sin "./" en el nombre del archivo de esta manera:
Pero para restaurar el gráfico, es posible que deba anteponer un "./" al nombre_archivo:
No siempre necesitará el "./", pero puede causar problemas dependiendo de su entorno y versión de TensorFlow.
También quiero mencionar que
sess.run(tf.global_variables_initializer())
puede ser importante antes de restaurar la sesión.Si recibe un error con respecto a las variables no inicializadas al intentar restaurar una sesión guardada, asegúrese de incluir
sess.run(tf.global_variables_initializer())
antes de lasaver.restore(sess, save_file)
línea. Puede ahorrarte un dolor de cabeza.fuente
Como se describe en el número 6255 :
en vez de
fuente
Según la nueva versión de Tensorflow,
tf.train.Checkpoint
es la forma preferible de guardar y restaurar un modelo:Aquí hay un ejemplo:
Más información y ejemplo aquí.
fuente
Para tensorflow 2.0 , es tan simple como
Restaurar:
fuente
tf.keras Ahorro de modelo con
TF2.0
Veo excelentes respuestas para guardar modelos usando TF1.x. Quiero proporcionar un par de punteros más para guardar
tensorflow.keras
modelos, lo cual es un poco complicado ya que hay muchas maneras de guardar un modelo.Aquí estoy proporcionando un ejemplo de guardar un
tensorflow.keras
modelo en lamodel_path
carpeta en el directorio actual. Esto funciona bien con el tensorflow más reciente (TF2.0). Actualizaré esta descripción si hay algún cambio en el futuro cercano.Guardar y cargar todo el modelo
Guardar y cargar pesas modelo solamente
Si está interesado en guardar solo los pesos del modelo y luego cargar los pesos para restaurar el modelo, entonces
Guardar y restaurar usando la devolución de llamada de punto de control Keras
modelo de guardado con métricas personalizadas
Guardar el modelo de Keras con operaciones personalizadas
Cuando tenemos operaciones personalizadas como en el siguiente caso (
tf.tile
), necesitamos crear una función y envolver con una capa Lambda. De lo contrario, el modelo no se puede guardar.Creo que he cubierto algunas de las muchas formas de guardar el modelo tf.keras. Sin embargo, hay muchas otras formas. Comente a continuación si ve que su caso de uso no está cubierto anteriormente. ¡Gracias!
fuente
Use tf.train.Saver para guardar un modelo, remerber, necesita especificar var_list, si desea reducir el tamaño del modelo. Val_list puede ser tf.trainable_variables o tf.global_variables.
fuente
Puede guardar las variables en la red usando
Para restaurar la red para su reutilización posterior o en otro script, use:
Puntos importantes:
sess
debe ser igual entre la primera y las últimas ejecuciones (estructura coherente).saver.restore
necesita la ruta de la carpeta de los archivos guardados, no una ruta de archivo individual.fuente
Donde quiera guardar el modelo,
Asegúrese de que todos
tf.Variable
tengan nombres, porque es posible que desee restaurarlos más tarde utilizando sus nombres. Y donde quieres predecir,Asegúrese de que el protector se ejecute dentro de la sesión correspondiente. Recuerda eso, si usas el
tf.train.latest_checkpoint('./')
, solo el último punto de control.fuente
Estoy en la versión:
Manera simple es
Salvar:
Restaurar:
fuente
Para tensorflow-2.0
es muy simple.
SALVAR
RESTAURAR
fuente
Siguiendo la respuesta de @Vishnuvardhan Janapati, aquí hay otra forma de guardar y recargar el modelo con capa / métrica / pérdida personalizada bajo TensorFlow 2.0.0
De esta manera, una vez que se ha ejecutado este tipo de códigos, y salvó su modelo con
tf.keras.models.save_model
omodel.save
oModelCheckpoint
de devolución de llamada, puede volver a cargar el modelo sin la necesidad de objetos personalizados precisos, tan simple comofuente
En la nueva versión de tensorflow 2.0, el proceso de guardar / cargar un modelo es mucho más fácil. Debido a la implementación de la API de Keras, una API de alto nivel para TensorFlow.
Para guardar un modelo: Consulte la documentación para referencia: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/save_model
Para cargar un modelo:
https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/load_model
fuente
Aquí hay un ejemplo simple que usa el formato Tensorflow 2.0 SavedModel (que es el formato recomendado, según los documentos ) para un clasificador de conjunto de datos MNIST simple, usando la API funcional Keras sin demasiada fantasía:
¿Qué es
serving_default
?Es el nombre de la definición de firma de la etiqueta que seleccionó (en este caso,
serve
se seleccionó la etiqueta predeterminada ). Además, aquí se explica cómo encontrar las etiquetas y firmas de un modelo usandosaved_model_cli
.Renuncias
Este es solo un ejemplo básico si solo desea ponerlo en funcionamiento, pero de ninguna manera es una respuesta completa, tal vez pueda actualizarlo en el futuro. Solo quería dar un ejemplo simple usando el
SavedModel
TF 2.0 porque no he visto uno, ni siquiera este simple, en ningún lado.La respuesta de @ Tom es un ejemplo de SavedModel, pero no funcionará en Tensorflow 2.0, porque desafortunadamente hay algunos cambios importantes.
La respuesta de @ Vishnuvardhan Janapati dice TF 2.0, pero no es para el formato SavedModel.
fuente