Trong web dev, có một cấu trúc gần như mọi framework đều có: vòng lặp request, handler, response. Express, FastAPI, Spring Boot, Rails. Bạn học một cái xong, đọc cái khác chỉ mất 30 phút vì cấu trúc giống nhau.
Trong ML, vòng lặp tương đương là training loop. Mọi neural network từ logistic regression đến Llama-3-400B đều training theo đúng 5 bước: forward, loss, backward, optimizer step, scheduler step. Hiểu một cái là hiểu tất cả.
Vấn đề là phần lớn dev đọc tutorial chỉ thấy trainer.fit() của Lightning hoặc Trainer.train() của HuggingFace. Magic. Khi cần debug “tại sao loss không giảm”, “tại sao gradient explode”, “tại sao learning rate nên warm up”, bạn không biết bắt đầu từ đâu vì cái trainer kia đã ăn hết cả pipeline.
Bài này tháo cái trainer ra. Code training loop từ zero bằng PyTorch thuần, train một MLP nhỏ trên synthetic data, rồi mở rộng sang transformer mini. Sau bài này bạn đọc training code của bất cứ ai cũng không bị lạc.
Mental model: 5 bước, lặp lại
Toàn bộ training loop, dù to hay nhỏ, đều theo cấu trúc này:
for epoch in range(num_epochs):
for batch in dataloader:
# 1. Forward: data đi qua model, sinh ra prediction
logits = model(batch.input)
# 2. Loss: so prediction với ground truth
loss = loss_fn(logits, batch.target)
# 3. Backward: tính gradient của loss theo từng parameter
loss.backward()
# 4. Optimizer step: cập nhật parameter theo gradient
optimizer.step()
optimizer.zero_grad()
# 5. Scheduler step: điều chỉnh learning rate
scheduler.step()
Năm dòng, đủ để train mọi thứ từ MNIST classifier đến GPT-4. Phần còn lại của training infrastructure (logging, checkpointing, multi-GPU, mixed precision) là bao quanh 5 dòng này.
Điều cần ngấm sớm: một step training = một batch đi qua đủ 5 bước, một lần cập nhật weights. Số step để train xong = (số sample / batch size) x số epoch. GPT-3 train 300 tỷ tokens với batch size 3.2 triệu tokens, tổng khoảng 95,000 steps. Mỗi step mất khoảng 5 phút trên 10,000 GPU. Tổng 34 ngày.
Phần 1: Forward pass
Forward là phần dễ nhất. Data đi vào, đi qua từng layer, sinh ra output.
import torch
import torch.nn as nn
class SimpleMLP(nn.Module):
def __init__(self, in_dim=784, hidden=256, out_dim=10):
super().__init__()
self.fc1 = nn.Linear(in_dim, hidden)
self.fc2 = nn.Linear(hidden, hidden)
self.fc3 = nn.Linear(hidden, out_dim)
self.act = nn.GELU()
def forward(self, x):
x = self.act(self.fc1(x))
x = self.act(self.fc2(x))
return self.fc3(x)
model = SimpleMLP()
batch_input = torch.randn(32, 784)
logits = model(batch_input)
Forward pass đơn giản là gọi model(input). Đằng sau, PyTorch chạy __call__ của Module, cuối cùng gọi forward().
Có một thứ quan trọng PyTorch làm tự động: build computation graph. Mỗi phép toán trong forward (matmul, add, activation) được ghi lại thành một node trong graph. Graph này sẽ được dùng ở bước backward để tính gradient bằng chain rule.
Nếu bạn không muốn build graph (ví dụ lúc inference), wrap forward trong torch.no_grad():
with torch.no_grad():
logits = model(batch_input)
Trong training, KHÔNG wrap no_grad(). Cần graph để backward.
Phần 2: Loss function
Loss là một con số đo “model dự đoán sai bao nhiêu”. Training là quá trình tối thiểu hoá con số này.
Với classification:
loss_fn = nn.CrossEntropyLoss()
targets = torch.tensor([3, 7, 0, 9])
loss = loss_fn(logits, targets)
CrossEntropyLoss của PyTorch combo LogSoftmax + NLLLoss. Không cần softmax trong model, để raw logits là đủ. Đây là pattern phổ biến nhưng nhiều dev mới hay lặp softmax 2 lần, loss sai mà không biết.
Với LLM, loss vẫn là cross-entropy nhưng tính trên từng token:
loss = nn.functional.cross_entropy(
logits.view(-1, vocab_size),
targets.view(-1),
ignore_index=-100,
)
Bài Probability cho LLM: softmax, cross-entropy, perplexity đã đi sâu vào tại sao cross-entropy là choice mặc định cho LLM, và liên hệ với perplexity. Nếu chưa đọc, đọc lại sau bài này sẽ thấy mượt hơn.
Phần 3: Backward pass
Đây là phần làm nên tên tuổi PyTorch. Một dòng:
loss.backward()
PyTorch sẽ:
- Đi ngược computation graph từ loss về từng parameter
- Áp chain rule tính
d_loss / d_paramcho mọi param córequires_grad=True - Lưu gradient vào
param.grad
Sau khi gọi loss.backward(), mọi param.grad đều có giá trị. Kiểm tra:
for name, p in model.named_parameters():
if p.grad is not None:
print(f"{name}: grad mean={p.grad.mean():.6f}, std={p.grad.std():.6f}")
Hai pitfall hay gặp:
Pitfall 1: quên zero_grad(). Gradient được cộng dồn vào param.grad mỗi lần backward. Nếu không zero, gradient của step trước sẽ tích lũy với step này. Loss diverge ngay. Luôn gọi optimizer.zero_grad() (hoặc model.zero_grad()) trước hoặc sau mỗi step.
Pitfall 2: gradient explosion. Với network sâu (hơn 12 layer) hoặc learning rate quá cao, gradient có thể vọt lên hàng nghìn, sau đó NaN. Cách fix: gradient clipping.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
Cap tổng L2 norm của tất cả gradient ở 1.0. Mọi tutorial LLM đều có dòng này.
Phần 4: Optimizer step
Optimizer quyết định cách dùng gradient để cập nhật parameter. Đơn giản nhất là SGD:
param_new = param_old - learning_rate * gradient
PyTorch:
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer.step()
Nhưng SGD thuần ít khi đủ cho LLM. Optimizer phổ biến nhất hiện nay là AdamW:
optimizer = torch.optim.AdamW(
model.parameters(),
lr=3e-4,
betas=(0.9, 0.95),
eps=1e-8,
weight_decay=0.1,
)
AdamW giữ thêm 2 state buffer cho mỗi parameter: m (first moment, trung bình gradient) và v (second moment, variance gradient). Update rule:
m = beta1 * m + (1 - beta1) * grad
v = beta2 * v + (1 - beta2) * grad^2
param = param - lr * m_hat / (sqrt(v_hat) + eps) - lr * weight_decay * param
Hệ quả: AdamW tốn 3x memory so với SGD vì giữ thêm m và v. Với Llama-3-8B FP32, model weights 32GB, optimizer state thêm 64GB, tổng 96GB chỉ cho optimizer + model, chưa kể activation. Đây là lý do tại sao DeepSpeed ZeRO chia optimizer state qua nhiều GPU (bài 17 sẽ chi tiết).
So sánh optimizer:
| Optimizer | Memory overhead | Pros | Cons |
|---|---|---|---|
| SGD | 1x param | Đơn giản, ít memory | Cần tune lr cẩn thận, convergence chậm |
| SGD + momentum | 2x param | Convergence nhanh hơn | Vẫn cần tune |
| Adam | 3x param | Adaptive lr, robust | Memory cao |
| AdamW | 3x param | Như Adam + decoupled weight decay | Memory cao |
| Lion | 2x param | Mới 2026, ít memory hơn Adam | Chưa được test rộng |
LLM hiện đại gần như 100% dùng AdamW.
Phần 5: Learning rate schedule
Learning rate không nên giữ cố định. Đầu training cần lr lớn để đi nhanh, cuối training cần lr nhỏ để fine-tune. Schedule là cách thay đổi lr theo step.
Schedule phổ biến nhất cho LLM: cosine với warmup.
lr
| warmup cosine decay
max | ___________
| / \___
| / \___
| / \___
min | / \___
|/_______________________________________> step
0 warmup_steps total_steps
Warmup: linear tăng từ 0 đến max trong vài nghìn step đầu. Tránh model rơi vào “bad region” do gradient không ổn định lúc bắt đầu.
Cosine decay: giảm từ max về min theo hàm cos. Smooth, không có cliff.
Implement bằng PyTorch:
from torch.optim.lr_scheduler import LambdaLR
import math
def get_lr_lambda(step, warmup_steps=1000, total_steps=100000, min_ratio=0.1):
if step < warmup_steps:
return step / warmup_steps
progress = (step - warmup_steps) / (total_steps - warmup_steps)
return min_ratio + (1 - min_ratio) * 0.5 * (1 + math.cos(math.pi * progress))
scheduler = LambdaLR(optimizer, lr_lambda=get_lr_lambda)
Gọi scheduler.step() sau mỗi optimizer.step().
GPT-3 dùng: max lr 6e-4, warmup 375 triệu tokens, cosine decay xuống 10% của max.
Llama-3 dùng: max lr 3e-4, warmup 8000 steps, cosine decay xuống 10%.
Phần 6: Full training loop, runnable
Code này train được trên CPU trong vài phút, dùng synthetic data:
import torch
import torch.nn as nn
import math
torch.manual_seed(42)
N, D = 10000, 64
X = torch.randn(N, D)
true_w = torch.randn(D)
y = X @ true_w + 0.1 * torch.randn(N)
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Linear(D, 128), nn.GELU(),
nn.Linear(128, 128), nn.GELU(),
nn.Linear(128, 1),
)
def forward(self, x):
return self.net(x).squeeze(-1)
model = MLP()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
total_steps = 2000
warmup = 100
def lr_lambda(step):
if step < warmup:
return step / warmup
progress = (step - warmup) / (total_steps - warmup)
return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
loss_fn = nn.MSELoss()
batch_size = 64
for step in range(total_steps):
idx = torch.randint(0, N, (batch_size,))
xb, yb = X[idx], y[idx]
pred = model(xb)
loss = loss_fn(pred, yb)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
scheduler.step()
if step % 200 == 0:
current_lr = optimizer.param_groups[0]["lr"]
print(f"step {step:5d} | loss {loss.item():.4f} | lr {current_lr:.6f}")
Output mẫu:
step 0 | loss 65.3217 | lr 0.000010
step 200 | loss 1.2451 | lr 0.000997
step 400 | loss 0.0823 | lr 0.000936
step 600 | loss 0.0245 | lr 0.000815
step 1800 | loss 0.0098 | lr 0.000125
Loss giảm đều, learning rate warmup rồi decay. Pattern này lặp lại y hệt khi train transformer, chỉ khác model lớn hơn và data nhiều hơn.
Pitfall thực tế: loss không giảm
Có một lần tôi train một MLP nhỏ trên synthetic data, loss đứng yên ở 2.3 suốt 500 step. Đoán đủ thứ: lr quá cao, model bị broken, data sai. Cuối cùng nguyên nhân: quên gọi optimizer.zero_grad().
Gradient cộng dồn qua từng step, đến step 100 thì gradient đã to gấp 100 lần value thật. Optimizer cập nhật param theo gradient bị inflate, param nhảy lung tung không hội tụ.
Fix: thêm 1 dòng optimizer.zero_grad() trước loss.backward(). Loss bắt đầu giảm trong step thứ 2.
Bài học: khi loss không giảm, kiểm tra theo thứ tự:
optimizer.zero_grad()có được gọi không?loss.backward()có chạy không (in loss để verify nó là tensor có graph)?- Learning rate có quá cao (in
optimizer.param_groups[0]["lr"])? - Gradient có NaN không (
torch.isnan(p.grad).any()cho mọi param)? - Data có bị normalize lệch không (mean, std)?
90% trường hợp loss không giảm là do 1 trong 5 nguyên nhân trên.
Cheatsheet: PyTorch training API
| Code | Mục đích |
|---|---|
model.train() | Chuyển sang training mode (dropout, batch norm hoạt động) |
optimizer.zero_grad() | Reset gradient buffer về 0 |
loss.backward() | Tính gradient backward qua graph |
optimizer.step() | Cập nhật param theo gradient |
scheduler.step() | Cập nhật learning rate |
torch.nn.utils.clip_grad_norm_(params, max_norm) | Clip gradient L2 norm |
torch.no_grad() | Context manager tắt graph (inference) |
param.grad | Tensor chứa gradient của param đó |
param.requires_grad | Param có cần backward không |
| Hyperparameter | LLM range phổ biến | Lưu ý |
|---|---|---|
| Learning rate | 1e-4 đến 6e-4 | Pretraining cao hơn fine-tune |
| Batch size (tokens) | 0.5M đến 4M | GPT-3 dùng 3.2M, Llama-3 dùng 4M |
| Weight decay | 0.05 đến 0.1 | Decoupled trong AdamW |
| Gradient clip | 1.0 | Gần như universal |
| Warmup steps | 1-3% total steps | GPT-3: 0.4% |
| Beta1, beta2 (Adam) | (0.9, 0.95) | LLM dùng beta2 nhỏ hơn default 0.999 |
Lời kết
Bạn vừa đi qua nguyên tử của ML: training loop. Mọi paper LLM, mọi codebase research đều build trên 5 dòng này. Khi bạn debug training pipeline lần sau, hãy quay lại 5 bước cơ bản trước khi nghĩ đến những thứ phức tạp hơn.
Hands-on song song:
- Copy code trong Phần 6 vào một file
train.py, chạy thử với Python local. Không cần GPU. Verify loss giảm. - Modify: thử thay AdamW bằng SGD, xem convergence chậm hơn bao nhiêu. Thử bỏ warmup, xem loss có spike không.
- Đọc training loop của nanoGPT (
train.py, khoảng 400 dòng). Nhận diện 5 thành phần trên trong đó. Phần còn lại của file là DDP, mixed precision, checkpointing, sẽ học ở bài 16 và 17. - Nếu muốn dataset thật, dùng
tinystoriestừ HuggingFace (datasets.load_dataset("roneneldan/TinyStories")) làm test bed. Tokenize bằng GPT-2 tokenizer rồi train một transformer 6 layer. Chạy trên Colab free tier khoảng 2 tiếng được.
Bài 15 sẽ bàn về Scaling laws Chinchilla: dữ liệu bao nhiêu, parameter bao nhiêu, compute bao nhiêu là tối ưu. Hiểu được scaling laws là biết được “model 7B của Meta train với 15 trillion tokens có phải overkill không”, “training data 100GB của tôi đủ cho model 1B không”. Đây là kiến thức economist của ML engineer.