TensorFlow, ¿por qué hay 3 archivos después de guardar el modelo?

113

Después de leer los documentos , guardé un modelo TensorFlow, aquí está mi código de demostración:

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

pero después de eso, encontré que hay 3 archivos

model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

Y no puedo restaurar el modelo restaurando el model.ckptarchivo, ya que no existe tal archivo. Aqui esta mi codigo

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

Entonces, ¿por qué hay 3 archivos?

GoingMyWay
fuente
2
¿Descubrió cómo abordar esto? ¿Cómo puedo volver a cargar el modelo (usando Keras)?
rajkiran

Respuestas:

116

Prueba esto:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
    saver.restore(sess, "/tmp/model.ckpt")

El método de guardado de TensorFlow guarda tres tipos de archivos porque almacena la estructura del gráfico por separado de los valores de las variables . El .metaarchivo describe la estructura del gráfico guardado, por lo que debe importarlo antes de restaurar el punto de control (de lo contrario, no sabe a qué variables corresponden los valores guardados del punto de control).

Alternativamente, puede hacer esto:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

...

# Now load the checkpoint variable values
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "/tmp/model.ckpt")

Aunque no hay ningún archivo con nombre model.ckpt, aún se refiere al punto de control guardado por ese nombre al restaurarlo. Del saver.pycódigo fuente :

Los usuarios solo necesitan interactuar con el prefijo especificado por el usuario ... en lugar de cualquier nombre de ruta física.

TK Bartel
fuente
1
por lo que no se utilizan .index y .data? Entonces, ¿cuándo se utilizan esos 2 archivos?
ajfbiw.s
26
@ ajfbiw.s .meta almacena la estructura del gráfico, .data almacena los valores de cada variable en el gráfico, .index identifica el checkpiont. Entonces, en el ejemplo anterior: import_meta_graph usa .meta, y saver.restore usa .data e .index
TK Bartel
Oh ya veo. Gracias.
ajfbiw.s
1
¿Alguna posibilidad de que hayas guardado el modelo con una versión de TensorFlow diferente a la que estás usando para cargarlo? ( github.com/tensorflow/tensorflow/issues/5639 )
TK Bartel
5
¿Alguien sabe lo que significan eso 00000y los 00001números? en variables.data-?????-of-?????archivo
Ivan Talalaev
55
  • archivo meta : describe la estructura del gráfico guardado, incluye GraphDef, SaverDef, etc. luego aplica tf.train.import_meta_graph('/tmp/model.ckpt.meta'), restaurará Savery Graph.

  • archivo de índice : es una tabla inmutable string-string (tensorflow :: table :: Table). Cada clave es un nombre de un tensor y su valor es un BundleEntryProto serializado. Cada BundleEntryProto describe los metadatos de un tensor: cuál de los archivos de "datos" contiene el contenido de un tensor, el desplazamiento en ese archivo, la suma de comprobación, algunos datos auxiliares, etc.

  • archivo de datos : es la colección TensorBundle, guarda los valores de todas las variables.

Guangcong Liu
fuente
Tengo el archivo pb que tengo para la clasificación de imágenes. ¿Puedo usarlo para clasificar videos en tiempo real?
¿Me puede hacer saber, usando Keras 2, cómo cargo el modelo si está guardado como 3 archivos?
rajkiran
5

Estoy restaurando incrustaciones de palabras entrenadas del tutorial de Word2Vec tensorflow.

En caso de que haya creado varios puntos de control:

por ejemplo, los archivos creados se ven así

model.ckpt-55695.data-00000-of-00001

model.ckpt-55695.index

model.ckpt-55695.meta

prueba esto

def restore_session(self, session):
   saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
   saver.restore(session, './tmp/model.ckpt-55695')

al llamar a restore_session ():

def test_word2vec():
   opts = Options()    
   with tf.Graph().as_default(), tf.Session() as session:
       with tf.device("/cpu:0"):            
           model = Word2Vec(opts, session)
           model.restore_session(session)
           model.get_embedding("assistance")
Steven Wong
fuente
¿Qué significa "00000-of-00001" en "model.ckpt-55695.data-00000-of-00001"?
hafiz031
0

Si entrenó a una CNN con deserción, por ejemplo, podría hacer esto:

def predict(image, model_name):
    """
    image -> single image, (width, height, channels)
    model_name -> model file that was saved without any extensions
    """
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./' + model_name + '.meta')
        saver.restore(sess, './' + model_name)
        # Substitute 'logits' with your model
        prediction = tf.argmax(logits, 1)
        # 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension
        return prediction.eval(feed_dict={x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0})
happy_sisyphus
fuente