Estoy tratando de actualizar / cambiar los parámetros de un modelo de red neuronal y luego hacer que el paso directo de la red neuronal actualizada esté en el gráfico de cálculo (no importa cuántos cambios / actualizaciones hagamos).
Intenté esta idea, pero cada vez que lo hago, pytorch configura mis tensores actualizados (dentro del modelo) para que sean hojas, lo que mata el flujo de gradientes a las redes que quiero recibir. Mata el flujo de gradientes porque los nodos de hoja no son parte del gráfico de cálculo de la forma en que quiero que sean (ya que no son realmente hojas).
He intentado varias cosas pero nada parece funcionar. Creé un código ficticio autónomo que imprime los gradientes de las redes que deseo tener gradientes:
import torch
import torch.nn as nn
import copy
from collections import OrderedDict
# img = torch.randn([8,3,32,32])
# targets = torch.LongTensor([1, 2, 0, 6, 2, 9, 4, 9])
# img = torch.randn([1,3,32,32])
# targets = torch.LongTensor([1])
x = torch.randn(1)
target = 12.0*x**2
criterion = nn.CrossEntropyLoss()
#loss_net = nn.Sequential(OrderedDict([('conv0',nn.Conv2d(in_channels=3,out_channels=10,kernel_size=32))]))
loss_net = nn.Sequential(OrderedDict([('fc0', nn.Linear(in_features=1,out_features=1))]))
hidden = torch.randn(size=(1,1),requires_grad=True)
updater_net = nn.Sequential(OrderedDict([('fc0',nn.Linear(in_features=1,out_features=1))]))
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
#
nb_updates = 2
for i in range(nb_updates):
print(f'i = {i}')
new_params = copy.deepcopy( loss_net.state_dict() )
## w^<t> := f(w^<t-1>,delta^<t-1>)
for (name, w) in loss_net.named_parameters():
print(f'name = {name}')
print(w.size())
hidden = updater_net(hidden).view(1)
print(hidden.size())
#delta = ((hidden**2)*w/2)
delta = w + hidden
wt = w + delta
print(wt.size())
new_params[name] = wt
#del loss_net.fc0.weight
#setattr(loss_net.fc0, 'weight', nn.Parameter( wt ))
#setattr(loss_net.fc0, 'weight', wt)
#loss_net.fc0.weight = wt
#loss_net.fc0.weight = nn.Parameter( wt )
##
loss_net.load_state_dict(new_params)
#
print()
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
outputs = loss_net(x)
loss_val = 0.5*(target - outputs)**2
loss_val.backward()
print()
print(f'-- params that dont matter if they have gradients --')
print(f'loss_net.grad = {loss_net.fc0.weight.grad}')
print('-- params we want to have gradients --')
print(f'hidden.grad = {hidden.grad}')
print(f'updater_net.fc0.weight.grad = {updater_net.fc0.weight.grad}')
print(f'updater_net.fc0.bias.grad = {updater_net.fc0.bias.grad}')
si alguien sabe cómo hacer esto, por favor deme un ping ... configuré el número de veces para actualizar en 2 porque la operación de actualización debería estar en el gráfico de cálculo un número arbitrario de veces ... por lo que DEBE funcionar para 2)
Publicación fuertemente relacionada:
- SO: ¿Cómo se pueden tener parámetros en un modelo de pytorch que no sean hojas y estar en el gráfico de cálculo?
- foro de pytorch: https://discuss.pytorch.org/t/how-does-one-have-the-parameters-of-a-model-not-be-leafs/70076
Publicación cruzada:
backward
? A saberretain_graph=True
y / ocreate_graph=True
?Respuestas:
NO FUNCIONA CORRECTAMENTE porque los módulos de parámetros nombrados se eliminan.
Parece que esto funciona:
salida:
Reconocimiento: poderoso albanD del equipo de pytorch: https://discuss.pytorch.org/t/how-does-one-have-the-parameters-of-a-model-not-be-leafs/70076/9?u= Pinocho
fuente
Debes tratar de mantener los mismos tensores, no crear nuevos.
Vaya por su
data
atributo y establezca un nuevo valor.Esto funcionó para mí en esta pregunta: ¿Cómo asignar un nuevo valor a una variable pytorch sin romper la propagación hacia atrás?
fuente