¿La mejor manera de guardar un modelo entrenado en PyTorch?

192

Estaba buscando formas alternativas de guardar un modelo entrenado en PyTorch. Hasta ahora, he encontrado dos alternativas.

  1. torch.save () para guardar un modelo y torch.load () para cargar un modelo.
  2. model.state_dict () para guardar un modelo entrenado y model.load_state_dict () para cargar el modelo guardado.

Me he encontrado con esta discusión donde se recomienda el enfoque 2 sobre el enfoque 1.

Mi pregunta es, ¿por qué se prefiere el segundo enfoque? ¿Es solo porque los módulos torch.nn tienen esas dos funciones y se nos recomienda usarlas?

Wasi Ahmad
fuente
2
Creo que es porque torch.save () guarda todas las variables intermedias también, como salidas intermedias para uso de propagación inversa. Pero solo necesita guardar los parámetros del modelo, como peso / sesgo, etc. A veces, el primero puede ser mucho más grande que el segundo.
Dawei Yang
2
Lo probé torch.save(model, f)y torch.save(model.state_dict(), f). Los archivos guardados tienen el mismo tamaño. Ahora estoy confundido. Además, encontré que usar pickle para guardar model.state_dict () es extremadamente lento. Creo que la mejor manera es usarlo, torch.save(model.state_dict(), f)ya que maneja la creación del modelo, y la antorcha maneja la carga de los pesos del modelo, eliminando así posibles problemas. Referencia: discus.pytorch.org/t/saving-torch-models/838/4
Dawei Yang
Parece que PyTorch ha abordado esto un poco más explícitamente en su sección de tutoriales : hay mucha buena información allí que no se enumera en las respuestas aquí, incluido el almacenamiento de más de un modelo a la vez y los modelos de inicio cálido.
whlteXbread
¿Qué hay de malo en usar pickle?
Charlie Parker
1
@CharlieParker torch.save se basa en pickle. Lo siguiente es del tutorial vinculado anteriormente: "[torch.save] guardará todo el módulo usando el módulo pickle de Python. La desventaja de este enfoque es que los datos serializados están vinculados a las clases específicas y la estructura de directorios exacta utilizada cuando el modelo se guarda. La razón de esto es porque pickle no guarda la clase de modelo en sí. Más bien, guarda una ruta al archivo que contiene la clase, que se usa durante el tiempo de carga. Debido a esto, su código puede romperse de varias maneras cuando utilizado en otros proyectos o después de refactores ".
David Miller

Respuestas:

214

He encontrado esta página en su repositorio de Github, solo pegaré el contenido aquí.


Enfoque recomendado para guardar un modelo

Hay dos enfoques principales para serializar y restaurar un modelo.

El primero (recomendado) guarda y carga solo los parámetros del modelo:

torch.save(the_model.state_dict(), PATH)

Entonces despúes:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

El segundo guarda y carga todo el modelo:

torch.save(the_model, PATH)

Entonces despúes:

the_model = torch.load(PATH)

Sin embargo, en este caso, los datos serializados están vinculados a las clases específicas y a la estructura de directorio exacta utilizada, por lo que pueden romperse de varias maneras cuando se usan en otros proyectos o después de algunos refactores serios.

dontloo
fuente
8
De acuerdo con @smth discusion.pytorch.org/t/saving-and-loading-a-model-in-pytorch/… el modelo se recarga para entrenar el modelo por defecto. por lo tanto, debe llamar manualmente a the_model.eval () después de cargar, si lo está cargando por inferencia, no reanudar el entrenamiento.
WillZ
el segundo método da error stackoverflow.com/questions/53798009/… en Windows 10. no pudo resolverlo
Gulzar
¿Hay alguna opción para guardar sin necesidad de un acceso para la clase de modelo?
Michael D
Con ese enfoque, ¿cómo hace un seguimiento de los * args y ** kwargs que necesita pasar para el caso de carga?
Mariano Kamp
¿Qué hay de malo en usar pickle?
Charlie Parker
144

Depende de lo que quieras hacer.

Caso n. ° 1: guarde el modelo para usarlo usted mismo por inferencia : guarda el modelo, lo restaura y luego cambia el modelo al modo de evaluación. Esto se hace porque usualmente tienes BatchNormy Dropoutcapas que por defecto están en modo tren en la construcción:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Caso n. ° 2: Guarde el modelo para reanudar el entrenamiento más tarde : si necesita seguir entrenando el modelo que está a punto de guardar, debe guardar más que solo el modelo. También debe guardar el estado del optimizador, las épocas, la puntuación, etc. Lo haría así:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

Para reanudar el entrenamiento, haría cosas como: state = torch.load(filepath)y luego, para restaurar el estado de cada objeto individual, algo como esto:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Como está reanudando el entrenamiento, NO llame model.eval()una vez que restaure los estados al cargar.

Caso # 3: Modelo para ser usado por otra persona sin acceso a su código : en Tensorflow puede crear un .pbarchivo que defina tanto la arquitectura como los pesos del modelo. Esto es muy útil, especialmente cuando se usa Tensorflow serve. La forma equivalente de hacer esto en Pytorch sería:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

De esta manera todavía no es a prueba de balas y dado que Pytorch todavía está experimentando muchos cambios, no lo recomendaría.

Jadiel de Armas
fuente
1
¿Hay un archivo recomendado que termine para los 3 casos? ¿O es siempre .pth?
Verena Haunschmid
1
En el caso # 3 torch.loaddevuelve solo un OrderedDict. ¿Cómo se obtiene el modelo para hacer predicciones?
Alber8295
Hola, ¿puedo saber cómo hacer el mencionado "Caso # 2: Guardar modelo para reanudar la capacitación más tarde"? Logré cargar el punto de control en el modelo, luego no pude ejecutar o reanudar el modelo de entrenamiento como "model.to (device) model = train_model_epoch (modelo, criterio, optimizador, programado, épocas)"
dnez
1
Hola, para el caso uno que es por inferencia, en el documento oficial de pytorch dice que debe guardar el optimizador state_dict para inferencia o para completar el entrenamiento. "Al guardar un punto de control general, para usarlo ya sea para inferencia o para reanudar el entrenamiento, debe guardar más que solo el estado_dict del modelo. Es importante también guardar el estado_dictorio del optimizador, ya que este contiene buffers y parámetros que se actualizan a medida que el modelo entrena . "
Mohammed Awney
1
En el caso # 3, la clase de modelo debe definirse en alguna parte.
Michael D
12

El pepinillo biblioteca Python implementa protocolos binarios para serializar y deserializar un objeto Python.

Cuando usted import torch(o cuando usa PyTorch) lo hará import picklepor usted y no necesita llamar pickle.dump()y pickle.load()directamente, cuáles son los métodos para guardar y cargar el objeto.

De hecho, torch.save()y torch.load()lo envolverá pickle.dump()y pickle.load()para ti.

La state_dictotra respuesta mencionada merece unas pocas notas más.

¿ state_dictQué tenemos dentro de PyTorch? En realidad hay dos state_dicts.

El modelo PyTorch torch.nn.Moduletiene model.parameters()llamada para obtener parámetros que se pueden aprender (w y b). Estos parámetros que se pueden aprender, una vez establecidos al azar, se actualizarán con el tiempo a medida que aprendamos. Los parámetros que se pueden aprender son los primeros state_dict.

El segundo state_dictes el optimizador de estado dict. Recuerda que el optimizador se utiliza para mejorar nuestros parámetros de aprendizaje. Pero el optimizadorstate_dict es fijo. Nada que aprender allí.

Debido a que los state_dictobjetos son diccionarios de Python, se pueden guardar, actualizar, alterar y restaurar fácilmente, agregando una gran modularidad a los modelos y optimizadores de PyTorch.

Creemos un modelo súper simple para explicar esto:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Este código generará lo siguiente:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Tenga en cuenta que este es un modelo mínimo. Puede intentar agregar una pila de secuencial

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

Tenga en cuenta que solo las capas con parámetros que se pueden aprender (capas convolucionales, capas lineales, etc.) y memorias intermedias registradas (capas de batchnorm) tienen entradas en el modelo state_dict.

Cosas que no se pueden aprender, pertenecen al objeto optimizador state_dict , que contiene información sobre el estado del optimizador, así como los hiperparámetros utilizados.

El resto de la historia es igual; en la fase de inferencia (esta es una fase cuando usamos el modelo después del entrenamiento) para predecir; predecimos en función de los parámetros que aprendimos. Entonces, para la inferencia, solo necesitamos guardar los parámetros model.state_dict().

torch.save(model.state_dict(), filepath)

Y para usar luego model.load_state_dict (torch.load (filepath)) model.eval ()

Nota: No olvide la última línea, model.eval()esto es crucial después de cargar el modelo.

Tampoco intentes guardar torch.save(model.parameters(), filepath). El model.parameters()es solo el objeto generador.

Por otro lado, torch.save(model, filepath)guarda el objeto del modelo en sí, pero tenga en cuenta que el modelo no tiene el optimizador state_dict. Verifique la otra excelente respuesta de @Jadiel de Armas para guardar la sentencia de estado del optimizador.

prosti
fuente
Aunque no es una solución directa, ¡la esencia del problema se analiza en profundidad! Voto a favor.
Jason Young
7

Una convención común de PyTorch es guardar modelos usando una extensión de archivo .pt o .pth.

Guardar / cargar todo el modelo Guardar:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Carga:

La clase de modelo debe definirse en alguna parte

model = torch.load(PATH)
model.eval()
duro
fuente
4

Si desea guardar el modelo y desea reanudar la capacitación más tarde:

GPU única: guardar:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Carga:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

GPU múltiple: guardar

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Carga:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU
Joy Mazumder
fuente