¿Cómo se pueden tener parámetros en un modelo de pytorch que no sean hojas y estar en el gráfico de cálculo?

10

Estoy tratando de actualizar / cambiar los parámetros de un modelo de red neuronal y luego hacer que el paso directo de la red neuronal actualizada esté en el gráfico de cálculo (no importa cuántos cambios / actualizaciones hagamos).

Intenté esta idea, pero cada vez que lo hago, pytorch configura mis tensores actualizados (dentro del modelo) para que sean hojas, lo que mata el flujo de gradientes a las redes que quiero recibir. Mata el flujo de gradientes porque los nodos de hoja no son parte del gráfico de cálculo de la forma en que quiero que sean (ya que no son realmente hojas).

He intentado varias cosas pero nada parece funcionar. Creé un código ficticio autónomo que imprime los gradientes de las redes que deseo tener gradientes:

import torch
import torch.nn as nn

import copy

from collections import OrderedDict

# img = torch.randn([8,3,32,32])
# targets = torch.LongTensor([1, 2, 0, 6, 2, 9, 4, 9])
# img = torch.randn([1,3,32,32])
# targets = torch.LongTensor([1])
x = torch.randn(1)
target = 12.0*x**2

criterion = nn.CrossEntropyLoss()

#loss_net = nn.Sequential(OrderedDict([('conv0',nn.Conv2d(in_channels=3,out_channels=10,kernel_size=32))]))
loss_net = nn.Sequential(OrderedDict([('fc0', nn.Linear(in_features=1,out_features=1))]))

hidden = torch.randn(size=(1,1),requires_grad=True)
updater_net = nn.Sequential(OrderedDict([('fc0',nn.Linear(in_features=1,out_features=1))]))
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
#
nb_updates = 2
for i in range(nb_updates):
    print(f'i = {i}')
    new_params = copy.deepcopy( loss_net.state_dict() )
    ## w^<t> := f(w^<t-1>,delta^<t-1>)
    for (name, w) in loss_net.named_parameters():
        print(f'name = {name}')
        print(w.size())
        hidden = updater_net(hidden).view(1)
        print(hidden.size())
        #delta = ((hidden**2)*w/2)
        delta = w + hidden
        wt = w + delta
        print(wt.size())
        new_params[name] = wt
        #del loss_net.fc0.weight
        #setattr(loss_net.fc0, 'weight', nn.Parameter( wt ))
        #setattr(loss_net.fc0, 'weight', wt)
        #loss_net.fc0.weight = wt
        #loss_net.fc0.weight = nn.Parameter( wt )
    ##
    loss_net.load_state_dict(new_params)
#
print()
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
outputs = loss_net(x)
loss_val = 0.5*(target - outputs)**2
loss_val.backward()
print()
print(f'-- params that dont matter if they have gradients --')
print(f'loss_net.grad = {loss_net.fc0.weight.grad}')
print('-- params we want to have gradients --')
print(f'hidden.grad = {hidden.grad}')
print(f'updater_net.fc0.weight.grad = {updater_net.fc0.weight.grad}')
print(f'updater_net.fc0.bias.grad = {updater_net.fc0.bias.grad}')

si alguien sabe cómo hacer esto, por favor deme un ping ... configuré el número de veces para actualizar en 2 porque la operación de actualización debería estar en el gráfico de cálculo un número arbitrario de veces ... por lo que DEBE funcionar para 2)


Publicación fuertemente relacionada:

Publicación cruzada:

Pinocho
fuente
¿Intentaste argumentos para backward? A saber retain_graph=Truey / o create_graph=True?
Szymon Maszke

Respuestas:

3

NO FUNCIONA CORRECTAMENTE porque los módulos de parámetros nombrados se eliminan.


Parece que esto funciona:

import torch
import torch.nn as nn

from torchviz import make_dot

import copy

from collections import OrderedDict

# img = torch.randn([8,3,32,32])
# targets = torch.LongTensor([1, 2, 0, 6, 2, 9, 4, 9])
# img = torch.randn([1,3,32,32])
# targets = torch.LongTensor([1])
x = torch.randn(1)
target = 12.0*x**2

criterion = nn.CrossEntropyLoss()

#loss_net = nn.Sequential(OrderedDict([('conv0',nn.Conv2d(in_channels=3,out_channels=10,kernel_size=32))]))
loss_net = nn.Sequential(OrderedDict([('fc0', nn.Linear(in_features=1,out_features=1))]))

hidden = torch.randn(size=(1,1),requires_grad=True)
updater_net = nn.Sequential(OrderedDict([('fc0',nn.Linear(in_features=1,out_features=1))]))
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
#
def del_attr(obj, names):
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])
def set_attr(obj, names, val):
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)

nb_updates = 2
for i in range(nb_updates):
    print(f'i = {i}')
    new_params = copy.deepcopy( loss_net.state_dict() )
    ## w^<t> := f(w^<t-1>,delta^<t-1>)
    for (name, w) in list(loss_net.named_parameters()):
        hidden = updater_net(hidden).view(1)
        #delta = ((hidden**2)*w/2)
        delta = w + hidden
        wt = w + delta
        del_attr(loss_net, name.split("."))
        set_attr(loss_net, name.split("."), wt)
    ##
#
print()
print(f'updater_net.fc0.weight.is_leaf = {updater_net.fc0.weight.is_leaf}')
print(f'loss_net.fc0.weight.is_leaf = {loss_net.fc0.weight.is_leaf}')
outputs = loss_net(x)
loss_val = 0.5*(target - outputs)**2
loss_val.backward()
print()
print(f'-- params that dont matter if they have gradients --')
print(f'loss_net.grad = {loss_net.fc0.weight.grad}')
print('-- params we want to have gradients --')
print(f'hidden.grad = {hidden.grad}') # None because this is not a leaf, it is overriden in the for loop above.
print(f'updater_net.fc0.weight.grad = {updater_net.fc0.weight.grad}')
print(f'updater_net.fc0.bias.grad = {updater_net.fc0.bias.grad}')
make_dot(loss_val)

salida:

updater_net.fc0.weight.is_leaf = True
i = 0
i = 1

updater_net.fc0.weight.is_leaf = True
loss_net.fc0.weight.is_leaf = False

-- params that dont matter if they have gradients --
loss_net.grad = None
-- params we want to have gradients --
hidden.grad = None
updater_net.fc0.weight.grad = tensor([[0.7152]])
updater_net.fc0.bias.grad = tensor([-7.4249])

Reconocimiento: poderoso albanD del equipo de pytorch: https://discuss.pytorch.org/t/how-does-one-have-the-parameters-of-a-model-not-be-leafs/70076/9?u= Pinocho

Pinocho
fuente
chicos, esto está mal, no usen este código, no permite propagar gradientes por más de 1 paso. Utilice esto en su lugar: github.com/facebookresearch/higher
Pinocho
esto no funciona ppl!
Pinocho
la biblioteca superior tampoco funciona para mí todavía.
Pinocho