Estaba buscando formas alternativas de guardar un modelo entrenado en PyTorch. Hasta ahora, he encontrado dos alternativas.
- torch.save () para guardar un modelo y torch.load () para cargar un modelo.
- model.state_dict () para guardar un modelo entrenado y model.load_state_dict () para cargar el modelo guardado.
Me he encontrado con esta discusión donde se recomienda el enfoque 2 sobre el enfoque 1.
Mi pregunta es, ¿por qué se prefiere el segundo enfoque? ¿Es solo porque los módulos torch.nn tienen esas dos funciones y se nos recomienda usarlas?
python
serialization
deep-learning
pytorch
tensor
Wasi Ahmad
fuente
fuente
torch.save(model, f)
ytorch.save(model.state_dict(), f)
. Los archivos guardados tienen el mismo tamaño. Ahora estoy confundido. Además, encontré que usar pickle para guardar model.state_dict () es extremadamente lento. Creo que la mejor manera es usarlo,torch.save(model.state_dict(), f)
ya que maneja la creación del modelo, y la antorcha maneja la carga de los pesos del modelo, eliminando así posibles problemas. Referencia: discus.pytorch.org/t/saving-torch-models/838/4pickle
?Respuestas:
He encontrado esta página en su repositorio de Github, solo pegaré el contenido aquí.
Enfoque recomendado para guardar un modelo
Hay dos enfoques principales para serializar y restaurar un modelo.
El primero (recomendado) guarda y carga solo los parámetros del modelo:
Entonces despúes:
El segundo guarda y carga todo el modelo:
Entonces despúes:
Sin embargo, en este caso, los datos serializados están vinculados a las clases específicas y a la estructura de directorio exacta utilizada, por lo que pueden romperse de varias maneras cuando se usan en otros proyectos o después de algunos refactores serios.
fuente
pickle
?Depende de lo que quieras hacer.
Caso n. ° 1: guarde el modelo para usarlo usted mismo por inferencia : guarda el modelo, lo restaura y luego cambia el modelo al modo de evaluación. Esto se hace porque usualmente tienes
BatchNorm
yDropout
capas que por defecto están en modo tren en la construcción:Caso n. ° 2: Guarde el modelo para reanudar el entrenamiento más tarde : si necesita seguir entrenando el modelo que está a punto de guardar, debe guardar más que solo el modelo. También debe guardar el estado del optimizador, las épocas, la puntuación, etc. Lo haría así:
Para reanudar el entrenamiento, haría cosas como:
state = torch.load(filepath)
y luego, para restaurar el estado de cada objeto individual, algo como esto:Como está reanudando el entrenamiento, NO llame
model.eval()
una vez que restaure los estados al cargar.Caso # 3: Modelo para ser usado por otra persona sin acceso a su código : en Tensorflow puede crear un
.pb
archivo que defina tanto la arquitectura como los pesos del modelo. Esto es muy útil, especialmente cuando se usaTensorflow serve
. La forma equivalente de hacer esto en Pytorch sería:De esta manera todavía no es a prueba de balas y dado que Pytorch todavía está experimentando muchos cambios, no lo recomendaría.
fuente
torch.load
devuelve solo un OrderedDict. ¿Cómo se obtiene el modelo para hacer predicciones?El pepinillo biblioteca Python implementa protocolos binarios para serializar y deserializar un objeto Python.
Cuando usted
import torch
(o cuando usa PyTorch) lo haráimport pickle
por usted y no necesita llamarpickle.dump()
ypickle.load()
directamente, cuáles son los métodos para guardar y cargar el objeto.De hecho,
torch.save()
ytorch.load()
lo envolverápickle.dump()
ypickle.load()
para ti.La
state_dict
otra respuesta mencionada merece unas pocas notas más.¿
state_dict
Qué tenemos dentro de PyTorch? En realidad hay dosstate_dict
s.El modelo PyTorch
torch.nn.Module
tienemodel.parameters()
llamada para obtener parámetros que se pueden aprender (w y b). Estos parámetros que se pueden aprender, una vez establecidos al azar, se actualizarán con el tiempo a medida que aprendamos. Los parámetros que se pueden aprender son los primerosstate_dict
.El segundo
state_dict
es el optimizador de estado dict. Recuerda que el optimizador se utiliza para mejorar nuestros parámetros de aprendizaje. Pero el optimizadorstate_dict
es fijo. Nada que aprender allí.Debido a que los
state_dict
objetos son diccionarios de Python, se pueden guardar, actualizar, alterar y restaurar fácilmente, agregando una gran modularidad a los modelos y optimizadores de PyTorch.Creemos un modelo súper simple para explicar esto:
Este código generará lo siguiente:
Tenga en cuenta que este es un modelo mínimo. Puede intentar agregar una pila de secuencial
Tenga en cuenta que solo las capas con parámetros que se pueden aprender (capas convolucionales, capas lineales, etc.) y memorias intermedias registradas (capas de batchnorm) tienen entradas en el modelo
state_dict
.Cosas que no se pueden aprender, pertenecen al objeto optimizador
state_dict
, que contiene información sobre el estado del optimizador, así como los hiperparámetros utilizados.El resto de la historia es igual; en la fase de inferencia (esta es una fase cuando usamos el modelo después del entrenamiento) para predecir; predecimos en función de los parámetros que aprendimos. Entonces, para la inferencia, solo necesitamos guardar los parámetros
model.state_dict()
.Y para usar luego model.load_state_dict (torch.load (filepath)) model.eval ()
Nota: No olvide la última línea,
model.eval()
esto es crucial después de cargar el modelo.Tampoco intentes guardar
torch.save(model.parameters(), filepath)
. Elmodel.parameters()
es solo el objeto generador.Por otro lado,
torch.save(model, filepath)
guarda el objeto del modelo en sí, pero tenga en cuenta que el modelo no tiene el optimizadorstate_dict
. Verifique la otra excelente respuesta de @Jadiel de Armas para guardar la sentencia de estado del optimizador.fuente
Una convención común de PyTorch es guardar modelos usando una extensión de archivo .pt o .pth.
Guardar / cargar todo el modelo Guardar:
Carga:
La clase de modelo debe definirse en alguna parte
fuente
Si desea guardar el modelo y desea reanudar la capacitación más tarde:
GPU única: guardar:
Carga:
GPU múltiple: guardar
Carga:
fuente