Estoy convirtiendo un modelo keras SRGAN que se ejecutará en una GPU para ejecutarse en una TPU.

El SRGAN usa una activación PReLU y cuando uso tf.keras.layers.PReLU en mi API funcional

x = layers.PReLU(alpha_initializer = 'zeros', alpha_regularizer = None, alpha_constraint = None, shared_axes=[1,2])(x)

Obtengo el siguiente TypeError cuando encajo el modelo usando el TPU

TypeError: bad operand type for unary -: 'ReplicatedVariable'

Si cambio la capa PReLU con:

x = tf.Activations('relu')(x)

el error desaparece

¿Alguien ha visto este problema? Creo que podría estar relacionado con la función de llamada de la clase PReLU:

@tf_export('keras.layers.PReLU')
Class PReLU(layer):
...
  def call(self, inputs, mask=None):
  ...
  else:
  neg = -self.alpha * K.relu(-inputs)
user989129
fuente