w LLM-ach koszt treningu jest ogromny. Mamy trzy ograniczenia:
Jeśli zmieniam batch size, czyli liczbę przykładów/tokenów używanych w jednym kroku treningowym, to zwykle musz´ też zmienić learning rate, żeby trening nadal był stabilny i efektywny.
Co robi learning rate? (LR)
Learning rate mówi, jak duży krok robi optimizer po policzeniu gradientu.
za mały learning rate → model uczy się wolno
za duży learning rate → model może być niestabilny albo eksplodować
dobry learning rate → model uczy się szybko i stabilnie
Gradient descent wygląda intuicyjnie tak:
nowe_wagi = stare_wagi - learning_rate × gradient
Czyli learning rate jest „siłą ruchu”.
Co robi batch size?
Batch size mówi, z ilu przykładów/tokenów liczymy gradient w jednym kroku.
Dla LLM ważniejsze jest zwykle myślenie w tokenach:
batch size = liczba tokenów użytych do jednego kroku optymalizacji
Przykład:
batch = 4M tokenów
oznacza, że jeden update wag jest liczony na podstawie 4 milionów tokenów.
Większy batch daje bardziej stabilny gradient, bo gradient jest uśredniony z większej liczby przykładów.
mały batch → gradient jest głośny, losowy, niestabilny
duży batch → gradient jest gładszy, dokładniejszy
Ale większy batch ma też minus: robisz mniej update’ów dla tej samej liczby tokenów.
Dlaczego learning rate i batch size są powiązane?
Bo batch size zmienia „jakość” gradientu.
Jeśli batch jest mały, gradient jest szumiący. Wtedy zbyt duży learning rate może wyrzucić model z sensownej ścieżki.
Jeśli batch jest większy, gradient jest mniej szumiący. Możesz pozwolić sobie na większy learning rate, bo kierunek kroku jest bardziej wiarygodny.
Czyli intuicja
większy batch → mniej szumu → można zrobić większy krok
mniejszy batch → więcej szumu → trzeba ostrożniejszego kroku
Linear scaling rule
Najbardziej znana prosta reguła:
Jeśli zwiększasz batch size k razy, zwiększ learning rate k razy
np
batch 256 → LR 0.001
batch 512 → LR 0.002
batch 1024 → LR 0.004
W praktyce działało to dobrze w wielu klasycznych treningach deep learningu, szczególnie przy SGD i dużych batchach, np. w vision models.
Ale dla LLM trzeba uważać, bo pełne liniowe skalowanie często jest zbyt agresywne.
Square root scaling rule
jeśli zwiększasz batch size k razy, zwiększ learning rate √k razy
np
batch 256 → LR 0.001
batch 1024 → LR 0.002
Critical batch size
Tu pojawia się pojęcie critical batch size.
To jest punkt, po którym zwiększanie batch size przestaje dawać duży zysk.
Do pewnego momentu:
większy batch → lepsze wykorzystanie GPU → szybszy trening
Ale po przekroczeniu pewnej granicy:
większy batch → mniej update’ów → gorsza efektywność uczenia
Intuicja
Batch 1M tokenów → dobry trening
Batch 4M tokenów → nadal dobrze
Batch 32M tokenów → GPU pracują, ale każdy update jest zbyt rzadki
w LLM-ach koszt treningu jest ogromny. Mamy trzy ograniczenia:
- compute
- czas
- stabilność treningu
Duży batch pomaga skalować trening na wiele GPU.
Ale jeżeli batch będzie za duży, możesz marnować compute.
Czyli wybór batch size to kompromis:
za mały batch:
- słabe wykorzystanie GPU
- wolniejszy wall-clock training
- więcej szumu
za duży batch:
- mniej kroków optymalizacji
- możliwa gorsza efektywność tokenowa
- większe ryzyko złego minimum / gorszej generalizacji
Learning rate / batch size scaling rules
Learning rate / batch size scaling rules opisują, jak należy zmieniać learning rate, gdy zmieniamy batch size.
Większy batch daje mniej zaszumiony gradient, więc często pozwala na większy learning rate. Jednak po przekroczeniu critical batch size dalsze zwiększanie batcha daje coraz mniejsze korzyści, bo model wykonuje mniej update’ów względem liczby przetworzonych tokenów.
Najważniejsze reguły:
- linear scaling rule: batch × k → LR × k
- square root scaling rule: batch × k → LR × √k
- warmup: przy dużym LR zaczynamy od małej wartości i stopniowo zwiększamy
- decay: po osiągnięciu peak LR learning rate zwykle maleje
- critical batch size: punkt, po którym większy batch daje coraz mniejszy zwrot
W LLM-ach te reguły są bardzo ważne, bo batch size wpływa jednocześnie na stabilność treningu, wykorzystanie GPU, liczbę update’ów i efektywność tokenową.