Pytorch: forma correcta de usar mapas de peso personalizados en arquitectura unet

8

Hay un famoso truco en la arquitectura de u-net para usar mapas de peso personalizados para aumentar la precisión. A continuación se detallan los detalles

ingrese la descripción de la imagen aquí

Ahora, al preguntar aquí y en muchos otros lugares, conozco dos enfoques. Quiero saber cuál es el correcto o ¿hay algún otro enfoque correcto que sea más correcto?

1) Primero es usar el torch.nn.Functionalmétodo en el ciclo de entrenamiento

loss = torch.nn.functional.cross_entropy(output, target, w) donde w será el peso personalizado calculado.

2) El segundo es usar reduction='none'en la función de llamada de pérdida fuera del ciclo de entrenamiento criterion = torch.nn.CrossEntropy(reduction='none')

y luego en el ciclo de entrenamiento multiplicando con el peso personalizado

gt # Ground truth, format torch.long
pd # Network output
W # per-element weighting based on the distance map from UNet
loss = criterion(pd, gt)
loss = W*loss # Ensure that weights are scaled appropriately
loss = torch.sum(loss.flatten(start_dim=1), axis=0) # Sums the loss per image
loss = torch.mean(loss) # Average across a batch

Ahora, estoy un poco confundido ¿cuál es el correcto o hay alguna otra manera, o ambos tienen razón?

marca
fuente

Respuestas:

3

La parte de ponderación parece simplemente una entropía cruzada ponderada que se realiza así para el número de clases (2 en el ejemplo a continuación).

weights = torch.FloatTensor([.3, .7])
loss_func = nn.CrossEntropyLoss(weight=weights)

EDITAR:

¿Has visto esta implementación de Patrick Black?

# Set properties
batch_size = 10
out_channels = 2
W = 10
H = 10

# Initialize logits etc. with random
logits = torch.FloatTensor(batch_size, out_channels, H, W).normal_()
target = torch.LongTensor(batch_size, H, W).random_(0, out_channels)
weights = torch.FloatTensor(batch_size, 1, H, W).random_(1, 3)

# Calculate log probabilities
logp = F.log_softmax(logits)

# Gather log probabilities with respect to target
logp = logp.gather(1, target.view(batch_size, 1, H, W))

# Multiply with weights
weighted_logp = (logp * weights).view(batch_size, -1)

# Rescale so that loss is in approx. same interval
weighted_loss = weighted_logp.sum(1) / weights.view(batch_size, -1).sum(1)

# Average over mini-batch
weighted_loss = -1. * weighted_loss.mean()
jchaykow
fuente
La cuestión es que el peso se calcula mediante una determinada función aquí y no es discreto. Para obtener más información, aquí hay un documento - arxiv.org/abs/1505.04597
Mark
1
@ Mark oh ya veo ahora. Por lo tanto, es una salida de pérdida en píxeles. Y los bordes se calculan previamente usando una biblioteca como opencvo algo así, y luego esas posiciones de píxeles se guardan para cada imagen y luego se multiplican por los tensores de pérdida más adelante durante el entrenamiento para que el algoritmo se centre en reducir la pérdida en esas áreas.
jchaykow
Gracias. Este legítimo parece una respuesta, intentaré verificarlo e implementarlo más y aceptaré tu respuesta después.
Mark
¿Puedes explicar la intuición detrás de esta línealogp = logp.gather(1, target.view(batch_size, 1, H, W))
Mark
0

Tenga en cuenta que torch.nn.CrossEntropyLoss () es una clase que llama a torch.nn.functional. Ver https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.html#CrossEntropyLoss

Puede usar los pesos cuando defina los criterios. Comparándolos funcionalmente, ambos métodos son iguales.

Ahora, no entiendo su idea de calcular la pérdida dentro del ciclo de entrenamiento en el método 1 y fuera del ciclo de entrenamiento en el método 2. si calcula la pérdida fuera del ciclo, ¿cómo va a propagarse hacia atrás?

Devansh Bisla
fuente
No estaba confundido entre usar torch.nn.CrossEntropyLoss() y torch.nn.functional.cross_entropy(output, target, w), estaba confundido sobre cómo usar mapas de peso personalizados en la pérdida. Consulte este documento: arxiv.org/abs/1505.04597 y avíseme, si aún no puede entender lo que soy. preguntando
Mark
1
Si lo entiendo correctamente, creo que el método 2 es el correcto. Los pesos (w) dentro de la pérdida torch.nn.functional.cross_entropy (output, target, w) son pesos para las clases que no son w (x) en la fórmula. Podemos probarlo fácilmente con un pequeño script.
Devansh Bisla
Sí, incluso estoy llegando a la misma conclusión. Volveré a usted si mi red funciona como se esperaba y marcaré la respuesta como aceptada.
Mark
está bien, no grad can be implicitly created only for scalar outputsfunciona. Lo estoy obteniendo cuando ejecuto el método loss = loss * w
Mark
¿Estás seguro de que los estás resumiendo o tomando la media?
Devansh Bisla