¿Hay alguna posibilidad de cambiar la métrica utilizada por la devolución de llamada de detención temprana en Keras?

12

Cuando se utiliza la devolución de llamada de detención temprana en el entrenamiento de Keras, se detiene cuando alguna métrica (generalmente pérdida de validación) no aumenta. ¿Hay alguna manera de usar otra métrica (como precisión, recuperación, medida f) en lugar de pérdida de validación? Todos los ejemplos que he visto hasta ahora son similares a este: callbacks.EarlyStopping (monitor = 'val_loss', paciencia = 5, detallado = 0, modo = 'auto')

P.Joseph
fuente

Respuestas:

10

Puede usar cualquier función métrica que haya especificado al compilar el modelo.

Digamos que tiene la siguiente función métrica:

def my_metric(y_true, y_pred):
     return some_metric_computation(y_true, y_pred)

El único requisito para esta función es que toma acepta la verdadera y la y predicha.

Cuando compila el modelo, especifica esta métrica, de forma similar a cómo especifica la compilación en métricas como 'precisión':

model.compile(metrics=['accuracy', my_metric], ...)

Tenga en cuenta que estamos utilizando el nombre de la función my_metric sin '' (en contraste con la 'precisión' incorporada).

Luego, si define su EarlyStopping, simplemente use el nombre de la función (esta vez con ''):

EarlyStopping(monitor='my_metric', mode='min')

Asegúrese de especificar el modo (min si más bajo es mejor, max si más alto es mejor).

Puede usarlo como cualquier métrica integrada. Esto probablemente también funcione con otras devoluciones de llamada como ModelCheckpoint (pero no lo he probado). Internamente, Keras simplemente agrega la nueva métrica a la lista de métricas disponibles para este modelo utilizando el nombre de la función.

Si especifica datos para la validación en su model.fit (...), también puede usarlos para EarlyStopping usando 'val_my_metric'.

Miguel
fuente
3

Por supuesto, ¡solo crea el tuyo!

class EarlyStopByF1(keras.callbacks.Callback):
    def __init__(self, value = 0, verbose = 0):
        super(keras.callbacks.Callback, self).__init__()
        self.value = value
        self.verbose = verbose


    def on_epoch_end(self, epoch, logs={}):
         predict = np.asarray(self.model.predict(self.validation_data[0]))
         target = self.validation_data[1]
         score = f1_score(target, prediction)
         if score > self.value:
            if self.verbose >0:
                print("Epoch %05d: early stopping Threshold" % epoch)
            self.model.stop_training = True


callbacks = [EarlyStopByF1(value = .90, verbose =1)]
model.fit(X, y, batch_size = 32, nb_epoch=nb_epoch, verbose = 1, 
validation_data(X_val,y_val), callbacks=callbacks)

No he probado esto, pero ese debería ser el sabor general de cómo lo haces. Si no funciona, avíseme y volveré a intentarlo durante el fin de semana. También estoy asumiendo que ya tienes tu propia puntuación f1 implementada. Si no solo importa para sklearn.

Sombrero de copa
fuente
+1 Todavía funciona a partir del 11/02/2020 utilizando las últimas Keras y Python 3.7
Austin