Transformer架构最佳实践
本文档总结了Transformer架构设计和实现的最佳实践。
架构设计最佳实践
1. 注意力机制优化
高效的注意力实现
import torch
import torch.nn as nn
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def scaled_dot_product_attention(self, Q, K, V, mask=None):
"""缩放点积注意力"""
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attention_weights = torch.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, V)
return output, attention_weights
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
# 线性变换并分头
Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 注意力计算
attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)
# 合并多头
attention_output = attention_output.transpose(1, 2).contiguous().view(
batch_size, -1, self.d_model
)
# 输出投影
output = self.W_o(attention_output)
return output, attention_weights
2. 位置编码
位置编码实现
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
return x + self.pe[:x.size(0), :]
3. 前馈网络
高效的前馈网络
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ff, d_model)
def forward(self, x):
return self.linear2(self.dropout(torch.relu(self.linear1(x))))
训练最佳实践
1. 学习率调度
Warmup和衰减
from transformers import get_linear_schedule_with_warmup
def create_scheduler(optimizer, num_training_steps, warmup_steps=0.1):
"""创建学习率调度器"""
num_warmup_steps = int(num_training_steps * warmup_steps)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps
)
return scheduler
2. 梯度累积
模拟大批次训练
def train_with_gradient_accumulation(model, dataloader, optimizer, accumulation_steps=4):
"""梯度累积训练"""
model.train()
optimizer.zero_grad()
for i, batch in enumerate(dataloader):
loss = model(batch) / accumulation_steps
loss.backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
3. 混合精度训练
FP16训练
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
def train_with_amp(model, batch):
"""混合精度训练"""
optimizer.zero_grad()
with autocast():
loss = model(batch)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
优化最佳实践
1. 模型压缩
模型剪枝
import torch.nn.utils.prune as prune
def prune_model(model, pruning_ratio=0.3):
"""模型剪枝"""
for module in model.modules():
if isinstance(module, nn.Linear):
prune.l1_unstructured(module, name='weight', amount=pruning_ratio)
prune.remove(module, 'weight')
2. 量化
INT8量化
import torch.quantization as quantization
def quantize_model(model):
"""模型量化"""
model.eval()
model.qconfig = quantization.get_default_qconfig('fbgemm')
quantization.prepare(model, inplace=True)
# 校准...
quantization.convert(model, inplace=True)
return model
3. 知识蒸馏
模型蒸馏
def distillation_loss(student_logits, teacher_logits, labels, temperature=3.0, alpha=0.7):
"""知识蒸馏损失"""
# 软标签损失
soft_loss = F.kl_div(
F.log_softmax(student_logits / temperature, dim=1),
F.softmax(teacher_logits / temperature, dim=1),
reduction='batchmean'
) * (temperature ** 2)
# 硬标签损失
hard_loss = F.cross_entropy(student_logits, labels)
# 组合损失
return alpha * soft_loss + (1 - alpha) * hard_loss
总结
遵循这些最佳实践可以:
- 提高模型性能:通过优化的架构设计
- 加速训练:通过高效的训练技巧
- 减少模型大小:通过压缩和量化
- 提升推理速度:通过模型优化