Estoy confundido sobre el método view()
en el siguiente fragmento de código.
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
Mi confusión es con respecto a la siguiente línea.
x = x.view(-1, 16*5*5)
¿Qué hace la tensor.view()
función? He visto su uso en muchos lugares, pero no puedo entender cómo interpreta sus parámetros.
¿Qué sucede si le doy valores negativos como parámetros a la view()
función? Por ejemplo, ¿qué pasa si llamo tensor_variable.view(1, 1, -1)
,?
¿Alguien puede explicar el principio principal de la view()
función con algunos ejemplos?
reshape
en PyTorch?Hagamos algunos ejemplos, de más simple a más difícil.
El
view
método devuelve un tensor con los mismos datos que elself
tensor (lo que significa que el tensor devuelto tiene el mismo número de elementos), pero con una forma diferente. Por ejemplo:Suponiendo que ese
-1
no es uno de los parámetros, cuando los multiplica, el resultado debe ser igual al número de elementos en el tensor. Si lo hace:a.view(3, 3)
generará unaRuntimeError
forma porque (3 x 3) no es válida para la entrada con 16 elementos. En otras palabras: 3 x 3 no es igual a 16 sino a 9.Puede usar
-1
uno de los parámetros que pasa a la función, pero solo una vez. Todo lo que sucede es que el método hará los cálculos matemáticos sobre cómo llenar esa dimensión. Por ejemploa.view(2, -1, 4)
es equivalente aa.view(2, 2, 4)
. [16 / (2 x 4) = 2]Observe que el tensor devuelto comparte los mismos datos . Si realiza un cambio en la "vista", está cambiando los datos del tensor original:
Ahora, para un caso de uso más complejo. La documentación dice que cada nueva dimensión de vista debe ser un subespacio de una dimensión original, o solo abarcar d, d + 1, ..., d + k que satisfagan la siguiente condición similar a la contigüidad que para todo i = 0,. .., k - 1, zancada [i] = zancada [i + 1] x tamaño [i + 1] . De lo contrario,
contiguous()
debe llamarse antes de que se pueda ver el tensor. Por ejemplo:Tenga en cuenta que for
a_t
, stride [0]! = Stride [1] x size [1] since 24! = 2 x 3fuente
torch.Tensor.view()
En pocas palabras,
torch.Tensor.view()
que está inspirado ennumpy.ndarray.reshape()
onumpy.reshape()
, crea una nueva vista del tensor, siempre que la nueva forma sea compatible con la forma del tensor original.Comprendamos esto en detalle utilizando un ejemplo concreto.
Con este tensor
t
de la forma(18,)
, nuevos puntos de vista pueden solamente ser creados por las siguientes formas:(1, 18)
o equivalente(1, -1)
o o equivalente o o equivalente o o equivalente o o equivalente o o equivalente o(-1, 18)
(2, 9)
(2, -1)
(-1, 9)
(3, 6)
(3, -1)
(-1, 6)
(6, 3)
(6, -1)
(-1, 3)
(9, 2)
(9, -1)
(-1, 2)
(18, 1)
(18, -1)
(-1, 1)
Como ya podemos observar en las tuplas de forma anteriores, la multiplicación de los elementos de la tupla de forma (por ejemplo
2*9
,3*6
etc.) siempre debe ser igual al número total de elementos en el tensor original (18
en nuestro ejemplo).Otra cosa a observar es que usamos un
-1
en uno de los lugares en cada una de las tuplas de forma. Al usar a-1
, somos perezosos al hacer el cálculo nosotros mismos y delegamos la tarea a PyTorch para hacer el cálculo de ese valor para la forma cuando crea la nueva vista . Una cosa importante a tener en cuenta es que solo podemos usar una sola-1
en la tupla de forma. Los valores restantes deben ser suministrados explícitamente por nosotros. Else PyTorch se quejará lanzando unRuntimeError
:Entonces, con todas las formas mencionadas anteriormente, PyTorch siempre devolverá una nueva vista del tensor original
t
. Esto básicamente significa que solo cambia la información de paso del tensor para cada una de las nuevas vistas que se solicitan.A continuación se muestran algunos ejemplos que ilustran cómo se cambian los pasos de los tensores con cada nueva vista .
Ahora, veremos los avances de las nuevas vistas :
Esa es la magia de la
view()
función. Simplemente cambia los pasos del tensor (original) para cada una de las nuevas vistas , siempre que la forma de la nueva vista sea compatible con la forma original.Otra cosa interesante uno podría observar desde las tuplas zancadas es que el valor del elemento en el 0 º posición es igual al valor del elemento en el 1 st posición de la tupla forma.
Esto es porque:
la zancada
(6, 1)
dice que para ir de un elemento al siguiente elemento a lo largo de la 0 ª dimensión, tenemos que saltar o tomar 6 pasos. (es decir, para ir de0
a6
, uno tiene que tomar 6 pasos). Pero para ir de un elemento al siguiente elemento en la 1ª dimensión, solo necesitamos un paso (por ejemplo, ir de2
a3
).Por lo tanto, la información de pasos está en el corazón de cómo se accede a los elementos desde la memoria para realizar el cálculo.
torch.reshape ()
Esta función devolvería una vista y es exactamente lo mismo que usar
torch.Tensor.view()
siempre que la nueva forma sea compatible con la forma del tensor original. De lo contrario, devolverá una copia.Sin embargo, las notas de
torch.reshape()
advierte que:fuente
Me di cuenta de que
x.view(-1, 16 * 5 * 5)
es equivalente ax.flatten(1)
, donde el parámetro 1 indica que el proceso de aplanar comienza desde la primera dimensión (sin aplanar la dimensión de 'muestra') Como puede ver, el último uso es semánticamente más claro y más fácil de usar, por lo que prefierenflatten()
.fuente
Puede leer
-1
como número dinámico de parámetros o "cualquier cosa". Por eso solo puede haber un parámetro-1
enview()
.Si pregunta
x.view(-1,1)
esto, generará una forma de tensor[anything, 1]
dependiendo del número de elementos enx
. Por ejemplo:Saldrá:
fuente
weights.reshape(a, b)
devolverá un nuevo tensor con los mismos datos que los pesos con tamaño (a, b) ya que copia los datos en otra parte de la memoria.weights.resize_(a, b)
devuelve el mismo tensor con una forma diferente. Sin embargo, si la nueva forma da como resultado menos elementos que el tensor original, algunos elementos se eliminarán del tensor (pero no de la memoria). Si la nueva forma da como resultado más elementos que el tensor original, los nuevos elementos no se inicializarán en la memoria.weights.view(a, b)
devolverá un nuevo tensor con los mismos datos que los pesos con tamaño (a, b)fuente
Realmente me gustaron los ejemplos de @Jadiel de Armas.
Me gustaría agregar una pequeña idea de cómo se ordenan los elementos para .view (...)
fuente
Tratemos de entender la vista con los siguientes ejemplos:
-1 como valor de argumento es una manera fácil de calcular el valor de decir x siempre que conozcamos los valores de y, z o al revés en el caso de 3d y para 2d nuevamente, una manera fácil de calcular el valor de decir x siempre que saber valores de y o viceversa.
fuente