Transformer Implementation Details
Basic Implementation Framework
Using Hugging Face Transformers
Hugging Face Transformers is the most widely used Transformer implementation library, providing a rich collection of pretrained models and a convenient API.
from transformers import AutoModel, AutoTokenizer
import torch
# Load a pretrained model and its tokenizer
model = AutoModel.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Basic inference
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)
Core Component Implementation
1. Multi-Head Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # 1. Project the inputs to Q, K, V
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)
        # 2. Reshape into multiple heads: (batch, heads, seq_len, d_k)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        # 3. Apply scaled dot-product attention
        attention_output = self.scaled_dot_product_attention(Q, K, V, mask)
        # 4. Concatenate the heads back to (batch, seq_len, d_model)
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )
        # 5. Final linear projection
        return self.W_o(attention_output)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attention_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attention_weights, V)
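A quick sanity check of the module's output shape (a minimal sketch; the batch size, sequence length, and model dimensions below are arbitrary example values):

# Sanity check with example sizes: batch=2, seq_len=10, d_model=512, 8 heads
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)
out = mha(x, x, x)            # self-attention: query, key, and value are the same tensor
print(out.shape)              # torch.Size([2, 10, 512])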
2. Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Store as (1, max_len, d_model) so it broadcasts over batch-first inputs,
        # matching the (batch, seq_len, d_model) layout used by the attention module above
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        return x + self.pe[:, :x.size(1), :]
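The buffer implements the sinusoidal encoding from the original Transformer paper, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). A minimal shape check, with arbitrary example sizes:

pos_enc = PositionalEncoding(d_model=512)
x = torch.zeros(2, 10, 512)          # (batch, seq_len, d_model)
print(pos_enc(x).shape)              # torch.Size([2, 10, 512])
print(pos_enc(x)[0, 0, :4])          # encoding values at position 0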
3. Feed-Forward Network
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Position-wise FFN: expand to d_ff, apply ReLU, then project back to d_model
        return self.linear2(self.dropout(F.relu(self.linear1(x))))
4. Encoder Layer
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention + residual connection + layer normalization
        attn_output = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        # Feed-forward network + residual connection + layer normalization
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x
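The components above can be assembled into a complete encoder stack. The sketch below shows one way to do this; the TransformerEncoder class and the vocab_size / num_layers values are illustrative additions, not part of the original code:

class TransformerEncoder(nn.Module):
    """Minimal encoder stack built from the components defined above (illustrative sketch)."""
    def __init__(self, vocab_size, d_model=512, num_heads=8, d_ff=2048,
                 num_layers=6, dropout=0.1, max_len=5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model

    def forward(self, input_ids, mask=None):
        # Scale embeddings by sqrt(d_model), as in the original Transformer paper
        x = self.embedding(input_ids) * math.sqrt(self.d_model)
        x = self.dropout(self.pos_encoding(x))
        for layer in self.layers:
            x = layer(x, mask)
        return x

# Example: encode 2 sequences of length 10 over a toy vocabulary of 1000 tokens
encoder = TransformerEncoder(vocab_size=1000, num_layers=2)
tokens = torch.randint(0, 1000, (2, 10))
print(encoder(tokens).shape)   # torch.Size([2, 10, 512])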
Training Implementation
Data Preprocessing
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            # Labels are needed so the training loop below can compute a loss
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }

# Create the data loader (texts and labels here are placeholder example data)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
texts = ["I love this movie", "This film was terrible"]   # example inputs
labels = [1, 0]                                           # example binary labels
dataset = TextDataset(texts, labels, tokenizer)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
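Inspecting one batch confirms the tensor shapes the training loop below expects (a sketch based on the example data above):

batch = next(iter(dataloader))
print(batch['input_ids'].shape)       # e.g. torch.Size([2, 512]) for the two example texts
print(batch['attention_mask'].shape)  # same shape as input_ids
print(batch['labels'])                # e.g. tensor([0, 1]); order depends on shuffling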
Model Training
import torch.optim as optim
from transformers import get_linear_schedule_with_warmup
def train_model(model, dataloader, num_epochs=3):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=2e-5)
    total_steps = len(dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            optimizer.zero_grad()
            # Passing labels makes models with a task head (e.g. AutoModelForSequenceClassification)
            # return a loss in outputs.loss
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}')
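For example, the loop can fine-tune a classification head on the example data above (a sketch; AutoModelForSequenceClassification with num_labels=2 is an assumption for a binary sentiment task):

from transformers import AutoModelForSequenceClassification

# Assumes a binary classification task matching the example labels above
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
train_model(model, dataloader, num_epochs=3)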
Inference Implementation
Text Classification
from transformers import AutoModelForSequenceClassification
import torch.nn.functional as F
def classify_text(text, model, tokenizer):
    model.eval()
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
        outputs = model(**inputs)
        predictions = F.softmax(outputs.logits, dim=-1)
        return predictions.numpy()

# Usage example
# Note: loading 'bert-base-uncased' this way initializes the classification head randomly;
# in practice, load a fine-tuned checkpoint (e.g. the model trained above)
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
text = "This is a great movie!"
predictions = classify_text(text, model, tokenizer)
print(f"Predictions: {predictions}")
Text Generation
from transformers import GPT2LMHeadModel, GPT2Tokenizer
def generate_text(prompt, model, tokenizer, max_length=100):
    model.eval()
    with torch.no_grad():
        inputs = tokenizer.encode(prompt, return_tensors='pt')
        outputs = model.generate(
            inputs,
            max_length=max_length,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Usage example
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
prompt = "Once upon a time"
generated_text = generate_text(prompt, model, tokenizer)
print(f"Generated text: {generated_text}")