Supongamos que tengo un tensor de Tensorflow. ¿Cómo obtengo las dimensiones (forma) del tensor como valores enteros? Sé que hay dos métodos, tensor.get_shape()
y tf.shape(tensor)
, pero no puedo obtener los valores de forma como int32
valores enteros .
Por ejemplo, a continuación, he creado un tensor 2-D y necesito obtener la cantidad de filas y columnas int32
para poder llamar reshape()
para crear un tensor de forma (num_rows * num_cols, 1)
. Sin embargo, el método tensor.get_shape()
devuelve valores como Dimension
tipo, no int32
.
import tensorflow as tf
import numpy as np
sess = tf.Session()
tensor = tf.convert_to_tensor(np.array([[1001,1002,1003],[3,4,5]]), dtype=tf.float32)
sess.run(tensor)
# array([[ 1001., 1002., 1003.],
# [ 3., 4., 5.]], dtype=float32)
tensor_shape = tensor.get_shape()
tensor_shape
# TensorShape([Dimension(2), Dimension(3)])
print tensor_shape
# (2, 3)
num_rows = tensor_shape[0] # ???
num_cols = tensor_shape[1] # ???
tensor2 = tf.reshape(tensor, (num_rows*num_cols, 1))
# Traceback (most recent call last):
# File "<stdin>", line 1, in <module>
# File "/usr/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 1750, in reshape
# name=name)
# File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 454, in apply_op
# as_ref=input_arg.is_ref)
# File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 621, in convert_to_tensor
# ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
# File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/constant_op.py", line 180, in _constant_tensor_conversion_function
# return constant(v, dtype=dtype, name=name)
# File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/constant_op.py", line 163, in constant
# tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape))
# File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 353, in make_tensor_proto
# _AssertCompatible(values, dtype)
# File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/tensor_util.py", line 290, in _AssertCompatible
# (dtype.name, repr(mismatch), type(mismatch).__name__))
# TypeError: Expected int32, got Dimension(6) of type 'Dimension' instead.
fuente
tf.reshape()
, pero realmente me gustaría obtenernum_rows
ynum_cols
como enteros para otras operaciones.tensor.get_shape().as_list()
as_list()
funciona. Por favor agréguelo a su respuesta y lo aceptaré.num_rows, num_cols = x.get_shape().as_list()
Otra forma de resolver esto es así:
tensor_shape[0].value
Esto devolverá el valor int del objeto Dimension.
fuente
para un tensor 2-D, puede obtener el número de filas y columnas como int32 usando el siguiente código:
rows, columns = map(lambda i: i.value, tensor.get_shape())
fuente
2.0 Respuesta compatible : En
Tensorflow 2.x (2.1)
, puede obtener las dimensiones (forma) del tensor como valores enteros, como se muestra en el siguiente Código:Método 1 (usando
tf.shape
) :import tensorflow as tf c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) Shape = c.shape.as_list() print(Shape) # [2,3]
Método 2 (usando
tf.get_shape()
) :import tensorflow as tf c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) Shape = c.get_shape().as_list() print(Shape) # [2,3]
fuente
Otra solución simple es usar la
map()
siguiente:Esto convierte todos los
Dimension
objetos aint
fuente
En versiones posteriores (probadas con TensorFlow 1.14) hay una forma más complicada de obtener la forma de un tensor. Puede usar
tensor.shape
para obtener la forma del tensor.fuente