Estoy tratando de hacer el tensorflow
equivalente de torch.transforms.Resize(TRAIN_IMAGE_SIZE)
, que cambia el tamaño de la dimensión de imagen más pequeñaTRAIN_IMAGE_SIZE
. Algo como esto
def transforms(filename):
parts = tf.strings.split(filename, '/')
label = parts[-2]
image = tf.io.read_file(filename)
image = tf.image.decode_jpeg(image)
image = tf.image.convert_image_dtype(image, tf.float32)
# this doesn't work with Dataset.map() because image.shape=(None,None,3) from Dataset.map()
image = largest_sq_crop(image)
image = tf.image.resize(image, (256,256))
return image, label
list_ds = tf.data.Dataset.list_files('{}/*/*'.format(DATASET_PATH))
images_ds = list_ds.map(transforms).batch(4)
La respuesta simple está aquí: Tensorflow: recorta la región cuadrada central más grande de la imagen
Pero cuando uso el método con tf.data.Dataset.map(transforms)
, me sale shape=(None,None,3)
de adentro largest_sq_crop(image)
. El método funciona bien cuando lo llamo normalmente.
python
tensorflow2.0
Miguel
fuente
fuente
EagerTensors
no están disponibles dentro,Dataset.map()
por lo que se desconoce la forma. ¿hay alguna solución?largest_sq_crop
?Respuestas:
Encontré la respuesta. Tenía que ver con el hecho de que mi método de cambio de tamaño funcionaba bien con una ejecución ansiosa, por ejemplo,
tf.executing_eagerly()==True
pero fallaba cuando se usaba dentrodataset.map()
. Al parecer, en ese entorno de ejecución,tf.executing_eagerly()==False
.Mi error fue en la forma en que estaba desempacando la forma de la imagen para obtener dimensiones para escalar. La ejecución del gráfico de Tensorflow no parece admitir el acceso a la
tensor.shape
tupla.Estaba usando dimensiones de forma aguas abajo en mi
dataset.map()
función y arrojó la siguiente excepción porque estaba obteniendo enNone
lugar de un valor.Cuando cambié a desempaquetar manualmente la forma
tf.shape()
, todo funcionó bien.fuente