Clasificare multietichetă - Cum se calculează greutăți dezechilibrate pentru BCEWithLogitsLoss în pirtorch -

Încerc să rezolv o problemă cu mai multe etichete cu 270 de etichete și am convertit etichetele țintă într-o singură formă codificată la cald. Folosesc BCEWithLogitsLoss (). Deoarece datele de antrenament sunt dezechilibrate, folosesc argumentul pos_weight, dar sunt confuz.

pos_weight (Tensor, opțional) - o pondere de exemple pozitive. Trebuie să fie un vector cu lungimea egală cu numărul de clase.

Trebuie să dau un număr total de valori pozitive ale fiecărei etichete ca tensor sau înseamnă altceva prin greutăți?

calculează

3 Răspunsuri 3

Documentația PyTorch pentru BCEWithLogitsLoss recomandă poz_weight să fie un raport între numărul negativ și numărul pozitiv pentru fiecare clasă.

Deci, dacă len (set de date) este 1000, elementul 0 al codificării dvs. multihot are 100 de numere pozitive, atunci elementul 0 al vectorului pos_weights_ ar trebui să fie 900/100 = 9. Asta înseamnă că pierderea binară încrucișată se va comporta ca și cum setul de date conține 900 de exemple pozitive în loc de 100.

Iată implementarea mea: