训练失败模式与排障指南
1. 常见失败模式分类
训练失败
├── 数值问题
│ ├── Loss 变为 NaN
│ ├── Loss 尖峰(Spike)
│ └── 梯度爆炸
├── 内存问题
│ ├── CUDA OOM (Out of Memory)
│ ├── CPU OOM
│ └── 磁盘空间不足
├── 数据问题
│ ├── Token 越界
│ ├── 格式错误
│ └── 分布不均
├── 基础设施问题
│ ├── 分布式同步失败
│ ├── Checkpoint 损坏
│ └── 网络超时
└── 性能问题
├── 训练速度过慢
├── GPU 利用率低
└── 通信瓶颈
2. NaN Loss
症状
- Loss 突然变为
NaN
- 之后所有输出都是
NaN
- 模型权重变为
inf/NaN
根因分析
| 原因 |
概率 |
检测方法 |
| 学习率过高 |
高 |
降低 LR 10 倍后重试 |
| FP16 溢出 |
高 |
检查是否使用 FP16 而非 BF16 |
| 数据包含 NaN/Inf |
中 |
检查输入 Token ID 是否合法 |
| 梯度爆炸 |
中 |
监控梯度范数 |
| Embedding 查表越界 |
低 |
检查 Token ID 范围 |
排障步骤
# 1. 检测 NaN
if torch.isnan(loss):
print("NaN detected!")
# 检查输入
print(f"Input NaN: {torch.isnan(inputs['input_ids']).any()}")
print(f"Labels NaN: {torch.isnan(inputs['labels']).any()}")
# 检查模型参数
for name, param in model.named_parameters():
if torch.isnan(param).any():
print(f"NaN in parameter: {name}")
# 2. 添加梯度钩子
def nan_hook(module, input, output):
if isinstance(output, torch.Tensor) and torch.isnan(output).any():
print(f"NaN in {module.__class__.__name__}")
for module in model.modules():
module.register_forward_hook(nan_hook)
修复方案
- 切换到 BF16:BF16 动态范围更大,不易溢出
- 降低学习率:尝试当前 LR 的 1/10
- 启用梯度裁剪:
max_grad_norm=1.0
- 使用 Loss Scaling(FP16 时):GradScaler
- 检查数据质量:过滤 NaN/Inf Token
3. Loss 尖峰(Loss Spike)
症状
- Loss 从稳定值突然大幅跳升
- 可能自行恢复,也可能持续
根因分析
| 原因 |
描述 |
| 坏数据批次 |
包含异常样本(超长序列、罕见 Token) |
| 学习率调度问题 |
Warmup 不够或 Decay 过早 |
| 梯度累积同步问题 |
分布式训练中梯度未正确同步 |
| 数据分布变化 |
切换到新的数据分片 |
排障步骤
- 记录触发 Spike 的具体 Step
- 重现该 Step 的数据批次,检查内容
- 检查 LR Schedule:确保 Warmup 充分
- 增加梯度裁剪
- 检查数据加载器的 Shuffle 行为
修复方案
# 梯度裁剪(最常见修复)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# LR Warmup(确保充分预热)
from transformers import get_cosine_schedule_with_warmup
scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=100, # 预热步数
num_training_steps=total_steps
)
# 数据过滤:跳过异常样本
def filter_bad_samples(example):
if len(example['input_ids']) > max_length:
return False
if any(tid >= vocab_size for tid in example['input_ids']):
return False
return True
4. CUDA OOM (Out of Memory)
症状
torch.cuda.OutOfMemoryError: CUDA out of memory.
Tried to allocate X.XX GiB. GPU 0 has a total capacity of XX.XX GiB
根因分析
| 场景 |
原因 |
| 训练初期 OOM |
模型/批次太大,基本配置不合理 |
| 训练中期 OOM |
序列长度不一致、内存泄漏 |
| 特定 Step OOM |
某些样本特别长 |
| 推理时 OOM |
KV Cache 增长 |
修复方案(按优先级)
- 减小 micro_batch_size
- 启用梯度检查点:
gradient_checkpointing=True
- 使用 LoRA/QLoRA 替代全量微调
- 减小最大序列长度:
max_seq_length
- 启用梯度累积:补偿减小的 batch size
- 使用 DeepSpeed ZeRO Offload:卸载到 CPU
- 清理碎片:
torch.cuda.empty_cache()
# 动态批次处理:截断超长序列
def preprocess(examples):
input_ids = tokenizer(examples["text"], truncation=True,
max_length=2048)["input_ids"]
return {"input_ids": input_ids}
5. 分布式训练问题
症状
- Loss 在不同 Rank 上不一致
- 训练挂起(Hang)
- NCCL 超时错误
排障步骤
- 检查进程同步:确保所有 Rank 执行相同的操作
- 验证梯度同步:
torch.distributed.all_reduce 后检查梯度一致
- 检查 Checkpoint 加载:确保所有 Rank 加载相同的初始权重
- 网络诊断:
nccl-tests 测试集群通信带宽
常见修复
# 确保 DDP 中 find_unused_parameters 正确设置
model = DDP(model, find_unused_parameters=False) # 如果所有参数都使用
model = DDP(model, find_unused_parameters=True) # 如果有未使用参数(如 LoRA)
# 设置 NCCL 超时
import datetime
torch.distributed.init_process_group(
backend="nccl",
timeout=datetime.timedelta(seconds=7200)
)
6. 训练速度慢
诊断
# 监控 GPU 利用率
# 命令行: nvidia-smi dmon -s u -d 1
# 在训练代码中计时
import time
start = time.time()
for step, batch in enumerate(dataloader):
loss = model(**batch).loss
loss.backward()
optimizer.step()
if step % 100 == 0:
elapsed = time.time() - start
tokens_per_sec = step * batch_size * seq_len / elapsed
print(f"Step {step}: {tokens_per_sec:.0f} tokens/s")
优化手段
| 优化 |
效果 |
实施难度 |
| 启用 Flash Attention 2 |
+30-50% 速度 |
低 |
| 使用 Unsloth |
+100-400% 速度 |
低 |
| bf16 混合精度 |
+50-100% 速度 |
低 |
| 数据预加载(dataloader workers) |
+10-30% |
低 |
| 打包(Packing) |
+20-50% |
中 |
| 使用 Liger Kernel |
+10-20% |
低 |
7. Checkpoint 问题
症状
- 恢复训练后 Loss 异常
- 加载 Checkpoint 报错
- 保存 Checkpoint 超时
最佳实践
# 定期保存,保留最近 N 个 Checkpoint
training_args = TrainingArguments(
save_strategy="steps",
save_steps=500,
save_total_limit=3, # 只保留最近 3 个
)
# 保存前验证
def safe_save(model, path):
# 验证模型参数无 NaN
for name, param in model.named_parameters():
assert not torch.isnan(param).any(), f"NaN in {name}"
model.save_pretrained(path)
# Checkpoint 恢复
trainer.train(resume_from_checkpoint="checkpoint-5000")
8. 排障检查清单
训练启动前:
训练过程中监控: