Tengo un conjunto de datos con 3 clases con los siguientes elementos:
- Clase 1: 900 elementos
- Clase 2: 15000 elementos.
- Clase 3: 800 elementos
Necesito predecir la clase 1 y la clase 3, que indican desviaciones importantes de la norma. La clase 2 es el caso "normal" predeterminado que no me importa.
¿Qué tipo de función de pérdida usaría aquí? Estaba pensando en usar CrossEntropyLoss, pero dado que hay un desequilibrio de clase, supongo que esto debería ponderarse. ¿Cómo funciona eso en la práctica? ¿Te gusta esto (usando PyTorch)?
summed = 900 + 15000 + 800
weight = torch.tensor([900, 15000, 800]) / summed
crit = nn.CrossEntropyLoss(weight=weight)
¿O debería invertirse el peso? es decir, 1 / peso?
¿Es este el enfoque correcto para comenzar o hay otros / mejores métodos que podría usar?
Gracias
fuente
Cuando dices: también puedes usar la clase más pequeña como nominador, que da 0.889, 0.053 y 1.0 respectivamente. Esto es solo una reescalada, los pesos relativos son los mismos.
Pero esta solución está en contradicción con la primera que diste, ¿cómo funciona?
fuente