Tutorial de tăiere - Tutoriale PyTorch 1

Tehnicile de învățare profundă de ultimă generație se bazează pe modele supra-parametrizate greu de implementat. Dimpotrivă, se știe că rețelele neuronale biologice utilizează o conectivitate redusă eficientă. Identificarea tehnicilor optime pentru comprimarea modelelor prin reducerea numărului de parametri din ele este importantă pentru a reduce consumul de memorie, baterie și hardware fără a sacrifica precizia, a implementa modele ușoare pe dispozitiv și pentru a garanta confidențialitatea cu calculul privat pe dispozitiv. Pe frontul cercetării, tăierea este utilizată pentru a investiga diferențele din dinamica învățării dintre rețelele supra-parametrizate și sub-parametrizate, pentru a studia rolul subrețelelor și inițializărilor rare („bilete de loterie”) ca o tehnică de căutare a arhitecturii neuronale distructive și Marea.

tăiere

În acest tutorial, veți învăța cum să utilizați torch.nn.utils.prune pentru a sparsifica rețelele neuronale și cum să îl extindeți pentru a implementa propria tehnică de tăiere personalizată.

Cerințe¶

Creați un model

În acest tutorial, folosim arhitectura LeNet de la LeCun și colab., 1998.

Inspectați un modul¶

Să inspectăm stratul de conv1 (netuns) din modelul nostru LeNet. Deocamdată va conține greutatea și părtinirea a doi parametri și nu există tampoane.

Tunderea unui modul¶

Pentru a tăia un modul (în acest exemplu, stratul conv1 al arhitecturii noastre LeNet), selectați mai întâi o tehnică de tăiere dintre cele disponibile în torch.nn.utils.prune (sau implementați-vă propriul prin subclasarea BasePruningMethod). Apoi, specificați modulul și numele parametrului de tăiat în cadrul acelui modul. În cele din urmă, utilizând argumentele adecvate ale cuvintelor cheie solicitate de tehnica de tăiere selectată, specificați parametrii de tăiere.

În acest exemplu, vom tăia la întâmplare 30% din conexiunile din parametrul numit greutate în stratul conv1. Modulul este transmis ca primul argument către funcție; numele identifică parametrul din cadrul acelui modul utilizând identificatorul său de șir; și suma indică fie procentul conexiunilor la prune (dacă este un float între 0. și 1.), fie numărul absolut de conexiuni la prune (dacă este un număr întreg negativ).

Tunderea acționează prin eliminarea greutății din parametri și înlocuirea acesteia cu un nou parametru numit weight_orig (adică adăugarea „_orig” la numele parametrului inițial). weight_orig stochează versiunea netezată a tensorului. Tendința nu a fost tăiată, așa că va rămâne intactă.

Masca de tăiere generată de tehnica de tăiere selectată mai sus este salvată ca tampon de modul numit weight_mask (adică adăugând „_mask” la numele parametrului inițial).

Pentru ca trecerea înainte să funcționeze fără modificări, atributul de greutate trebuie să existe. Tehnicile de tăiere implementate în torch.nn.utils.prune calculează versiunea tăiată a greutății (prin combinarea măștii cu parametrul original) și le stochează în greutatea atributului. Rețineți, acesta nu mai este un parametru al modulului, acum este pur și simplu un atribut.

În cele din urmă, tăierea se aplică înainte de fiecare trecere înainte folosind forward_pre_hooks ale PyTorch. Mai exact, atunci când modulul este tăiat, așa cum am făcut aici, acesta va achiziționa un forward_pre_hook pentru fiecare parametru asociat acestuia care este tăiat. În acest caz, deoarece până acum am tăiat doar parametrul original numit greutate, va fi prezent un singur cârlig.

Pentru completare, putem acum să tundem și părtinirea, pentru a vedea cum se modifică parametrii, tampoanele, cârligele și atributele modulului. Doar pentru a încerca o altă tehnică de tăiere, aici tăiem cele mai mici 3 intrări în prejudecată prin norma L1, așa cum este implementat în funcția de tăiere l1_unstructured.