¿Cómo enumerar todas las operaciones utilizadas en Tensorflow SavedModel?

10

Si tensorflow.saved_model.saveguardo mi modelo usando la función en formato Guardado, ¿cómo puedo recuperar qué Tensorflow Ops se utilizan en este modelo después? Como el modelo se puede restaurar, estas operaciones se almacenan en el gráfico, supongo que está en el saved_model.pbarchivo. Si cargo este protobuf (por lo tanto, no todo el modelo), la parte de la biblioteca del protobuf enumera estos, pero esto no está documentado y etiquetado como una característica experimental por ahora. Los modelos creados en Tensorflow 1.x no tendrán esta parte.

Entonces, ¿cuál es una forma rápida y confiable de recuperar una lista de Operaciones usadas (Me gusta MatchingFileso WriteFile) de un modelo en formato GuardadoModelo?

Ahora mismo puedo congelar todo, como lo tensorflowjs-converterhace. Como también verifican las operaciones compatibles. Actualmente, esto no funciona cuando hay un LSTM en el modelo, vea aquí . ¿Hay una mejor manera de hacer esto, ya que los Ops definitivamente están ahí?

Un modelo de ejemplo:

class FileReader(tf.Module):

@tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
def read_disk(self, file_name):
    input_scalar = tf.reshape(file_name, [])
    output = tf.io.read_file(input_scalar)
    return tf.stack([output], name='content')

file_reader = FileReader()

tf.saved_model.save(file_reader, 'file_reader')

Se espera en la salida todas las operaciones, que contienen en este caso al menos:

  • ReadFilecomo se describe aquí
  • ...
sampers
fuente
1
Es difícil decir exactamente lo que quieres, qué es saved_model.pb, ¿es tf.GraphDefun SavedModelmensaje de protobuf? Si tiene una tf.GraphDefllamada gd, puede obtener la lista de operaciones usadas con sorted(set(n.op for n in gd.node)). Si tiene un modelo cargado, puede hacerlo sorted(set(op.type for op in tf.get_default_graph().get_operations())). Si es un SavedModel, puede obtenerlo tf.GraphDef(p saved_model.meta_graphs[0].graph_def. Ej .).
jdehesa
Quiero recuperar las operaciones de un modelo guardado guardado. De hecho, la última opción que está describiendo. ¿Cuál es la saved_modelvariable en tu último ejemplo? El resultado de tf.saved_model.load('/path/to/model')o cargando el protobuf del archivo saved_model.pb.
Sampers

Respuestas:

1

Si saved_model.pbes un SavedModelmensaje protobuf, entonces obtiene las operaciones directamente desde allí. Digamos que creamos un modelo de la siguiente manera:

import tensorflow as tf

class FileReader(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
    def read_disk(self, file_name):
        input_scalar = tf.reshape(file_name, [])
        output = tf.io.read_file(input_scalar)
        return tf.stack([output], name='content')

file_reader = FileReader()
tf.saved_model.save(file_reader, 'tmp')

Ahora podemos encontrar las operaciones utilizadas por ese modelo así:

from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

saved_model = SavedModel()
with open('tmp/saved_model.pb', 'rb') as f:
    saved_model.ParseFromString(f.read())
model_op_names = set()
# Iterate over every metagraph in case there is more than one
for meta_graph in saved_model.meta_graphs:
    # Add operations in the graph definition
    model_op_names.update(node.op for node in meta_graph.graph_def.node)
    # Go through the functions in the graph definition
    for func in meta_graph.graph_def.library.function:
        # Add operations in each function
        model_op_names.update(node.op for node in func.node_def)
# Convert to list, sorted if you want
model_op_names = sorted(model_op_names)
print(*model_op_names, sep='\n')
# Const
# Identity
# MergeV2Checkpoints
# NoOp
# Pack
# PartitionedCall
# Placeholder
# ReadFile
# Reshape
# RestoreV2
# SaveV2
# ShardedFilename
# StatefulPartitionedCall
# StringJoin
jdehesa
fuente
Intenté algo como esto, pero desafortunadamente esto no es lo que espero que haga: Digamos que tengo un modelo que hace esto: input_scalar = tf.reshape(file_name, []) output = tf.io.read_file(input_scalar) return tf.stack([output], name='content')Entonces, el ReadFile Op como se muestra aquí está allí, pero no está impreso.
Sampers
1
@sampers He editado la respuesta con un ejemplo como el que sugieres. Obtengo la ReadFileoperación en la salida. ¿Es posible que, en su caso real, esa operación no esté entre la entrada y la salida del modelo guardado? En ese caso, creo que podría podarse.
jdehesa
De hecho, con el modelo dado funciona. Desafortunadamente para un módulo hecho en tf2, no lo hace. Si creo un tf.Module con 1 función con una anotación de file_nameargumento @tf.function, que contiene las llamadas que enumeré en mi comentario anterior, da la siguiente lista:Const, NoOp, PartitionedCall, Placeholder, StatefulPartitionedCall
sampers
agregó un modelo a mi pregunta
sampers
@sampers He actualizado mi respuesta. Estaba usando TF 1.x antes, no estaba familiarizado con los cambios en los objetos de definición de gráficos en TF 2.x, creo que la respuesta ahora cubre todo en el modelo guardado. Creo que las operaciones correspondientes a la función Python que escribió están en saved_model.meta_graphs[0].graph_def.library.function[0](la node_defcolección dentro de ese objeto de función).
jdehesa