跳到主要内容

Transformer架构最佳实践

本文档总结了Transformer架构设计和实现的最佳实践。

架构设计最佳实践

1. 注意力机制优化

高效的注意力实现

import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0

self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads

self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)

def scaled_dot_product_attention(self, Q, K, V, mask=None):
"""缩放点积注意力"""
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)

attention_weights = torch.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, V)

return output, attention_weights

def forward(self, query, key, value, mask=None):
batch_size = query.size(0)

# 线性变换并分头
Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

# 注意力计算
attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)

# 合并多头
attention_output = attention_output.transpose(1, 2).contiguous().view(
batch_size, -1, self.d_model
)

# 输出投影
output = self.W_o(attention_output)

return output, attention_weights

2. 位置编码

位置编码实现

class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()

pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))

pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)

self.register_buffer('pe', pe)

def forward(self, x):
return x + self.pe[:x.size(0), :]

3. 前馈网络

高效的前馈网络

class FeedForward(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(d_ff, d_model)

def forward(self, x):
return self.linear2(self.dropout(torch.relu(self.linear1(x))))

训练最佳实践

1. 学习率调度

Warmup和衰减

from transformers import get_linear_schedule_with_warmup

def create_scheduler(optimizer, num_training_steps, warmup_steps=0.1):
"""创建学习率调度器"""
num_warmup_steps = int(num_training_steps * warmup_steps)

scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps
)

return scheduler

2. 梯度累积

模拟大批次训练

def train_with_gradient_accumulation(model, dataloader, optimizer, accumulation_steps=4):
"""梯度累积训练"""
model.train()
optimizer.zero_grad()

for i, batch in enumerate(dataloader):
loss = model(batch) / accumulation_steps
loss.backward()

if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()

3. 混合精度训练

FP16训练

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

def train_with_amp(model, batch):
"""混合精度训练"""
optimizer.zero_grad()

with autocast():
loss = model(batch)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

优化最佳实践

1. 模型压缩

模型剪枝

import torch.nn.utils.prune as prune

def prune_model(model, pruning_ratio=0.3):
"""模型剪枝"""
for module in model.modules():
if isinstance(module, nn.Linear):
prune.l1_unstructured(module, name='weight', amount=pruning_ratio)
prune.remove(module, 'weight')

2. 量化

INT8量化

import torch.quantization as quantization

def quantize_model(model):
"""模型量化"""
model.eval()
model.qconfig = quantization.get_default_qconfig('fbgemm')
quantization.prepare(model, inplace=True)
# 校准...
quantization.convert(model, inplace=True)
return model

3. 知识蒸馏

模型蒸馏

def distillation_loss(student_logits, teacher_logits, labels, temperature=3.0, alpha=0.7):
"""知识蒸馏损失"""
# 软标签损失
soft_loss = F.kl_div(
F.log_softmax(student_logits / temperature, dim=1),
F.softmax(teacher_logits / temperature, dim=1),
reduction='batchmean'
) * (temperature ** 2)

# 硬标签损失
hard_loss = F.cross_entropy(student_logits, labels)

# 组合损失
return alpha * soft_loss + (1 - alpha) * hard_loss

总结

遵循这些最佳实践可以:

  1. 提高模型性能:通过优化的架构设计
  2. 加速训练:通过高效的训练技巧
  3. 减少模型大小:通过压缩和量化
  4. 提升推理速度:通过模型优化