5 files changed: +250 -296 lines changed
@@ -0,0 +1,89 @@
{
 "cells": [
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "from labml_nn.lora.experiment import Configs\n",
    "from labml import experiment"
   ],
   "id": "1b9da2e59ffce5d5",
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "source": "experiment.create(name=\"lora_gpt2\")",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "conf = Configs()",
   "id": "31c9bc08eca2592",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "experiment.configs(conf)",
   "id": "fb6ce74326558948",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "conf.initialize()",
   "id": "1456cfab47dee3b",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "with experiment.start():\n",
    "    conf.run()"
   ],
   "id": "3fe4068fd2df9094",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "",
   "id": "d3c3c723ebbe854a",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (ml)",
   "language": "python",
   "name": "ml"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
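For reference, the notebook above amounts to the following short driver script; this is a minimal sketch assembled from the cells, assuming the same labml_nn.lora.experiment module that the first cell imports:

from labml import experiment
from labml_nn.lora.experiment import Configs


def main():
    # Create the experiment and its configuration object.
    experiment.create(name="lora_gpt2")
    conf = Configs()
    experiment.configs(conf)

    # Build the GPT-2 model, load the pretrained weights, and set up the optimizer and data loader.
    conf.initialize()

    # Fine-tune inside the experiment context so the loss is tracked.
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()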
@@ -0,0 +1,137 @@
import torch
from labml import lab, monit, tracker
from labml.configs import BaseConfigs, option
from labml.utils.download import download_file
from labml_helpers.device import DeviceConfigs
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from labml_nn.lora.gpt2 import GPTModel


class Configs(BaseConfigs):
    device: torch.device = DeviceConfigs()
    layer_norm_epsilon: float = 1e-05
    n_embed: int = 768
    n_layer: int = 12
    n_positions: int = 1024
    vocab_size: int = 50257
    epochs: int = 10
    batch_size: int = 32
    learning_rate: float = 1e-4
    context_len: int = 512
    r: int = 32

    text: TensorDataset = "tiny_shakespeare"
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model: GPTModel
    optimizer: torch.optim.Adam
    criterion = torch.nn.CrossEntropyLoss()
    data_loader: DataLoader

    def _load_pretrained_weights(self):
        hf_model = AutoModelForCausalLM.from_pretrained("gpt2")

        state_dict = hf_model.state_dict()

        mapping = {
            'transformer.wte.weight': 'token_embedding.weight',
            'transformer.wpe.weight': 'position_embedding.weight',
            'transformer.ln_f.weight': 'final_norm.weight',
            'transformer.ln_f.bias': 'final_norm.bias',
            'lm_head.weight': 'lm_head.weight'
        }

        for i in range(12):
            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
            mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.c_att.weight'
            mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.c_att.bias'
            mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.c_proj.weight'
            mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.c_proj.bias'
            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
            mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.c_fc.weight'
            mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.c_fc.bias'
            mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.c_proj.weight'
            mapping[f'transformer.h.{i}.mlp.c_proj.bias'] = f'blocks.{i}.ffn.c_proj.bias'

        new_state_dict = {}
        for old_key, new_key in mapping.items():
            if old_key in state_dict:
                new_state_dict[new_key] = state_dict[old_key]

        # Transpose the weight matrices of the Conv1D layers so they can be loaded into linear layers
        convo_layers = ([f'blocks.{i}.ffn.c_fc.weight' for i in range(12)] +
                        [f'blocks.{i}.ffn.c_proj.weight' for i in range(12)] +
                        [f'blocks.{i}.attn.c_att.weight' for i in range(12)] +
                        [f'blocks.{i}.attn.c_proj.weight' for i in range(12)])

        for layer in convo_layers:
            new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)

        self.model.load_state_dict(new_state_dict, strict=False)  # state dict does not have lora weights

        del hf_model
        del state_dict
        del new_state_dict

    def initialize(self):
        self.model = GPTModel(
            layer_norm_epsilon=self.layer_norm_epsilon,
            n_embd=self.n_embed,
            n_layer=self.n_layer,
            n_positions=self.n_positions,
            vocab_size=self.vocab_size,
            r=self.r,
            device=self.device
        ).to(self.device)
        self._load_pretrained_weights()

        self.optimizer = Adam(self.model.parameters(), lr=self.learning_rate)

        self.data_loader = DataLoader(self.text, batch_size=self.batch_size, shuffle=True)

    def run(self):
        for _ in monit.loop(self.epochs):
            for i, batch in monit.enum('Train', self.data_loader):
                inputs = batch[0]
                inputs = inputs.to(self.device)
                labels = inputs.clone()

                outputs = self.model(inputs)

                shift_logits = outputs[..., :-1, :]
                shift_labels = labels[..., 1:]

                loss = self.criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                tracker.add({'loss': loss})

                tracker.save()
                tracker.add_global_step()
            tracker.new_line()


@option(Configs.text)
def tiny_shakespeare(c: Configs):
    """
    ### Tiny Shakespeare dataset

    It will be downloaded from the URL if it is not present
    """
    path = lab.get_data_path() / 'tiny_shakespeare.txt'
    if not path.exists():
        download_file("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", path)
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read()

    tokens = c.tokenizer.encode(text)
    num_batches = len(tokens) // (c.batch_size * c.context_len)
    tokens = tokens[:num_batches * c.batch_size * c.context_len]
    input_ids = torch.tensor(tokens).view(-1, c.context_len)
    return TensorDataset(input_ids)
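The training loop above computes a standard next-token loss: the logits at position t are scored against the token at position t+1, which is why both tensors are shifted before the cross entropy. A minimal sketch of that shift on dummy tensors (shapes only, not the real model):

import torch

logits = torch.randn(2, 5, 11)          # model output: batch x seq x vocab
labels = torch.randint(0, 11, (2, 5))   # the inputs cloned as labels: batch x seq

# Drop the last time step of the logits and the first token of the labels,
# so the prediction at position t is compared with the token at position t + 1.
shift_logits = logits[..., :-1, :]       # batch x (seq - 1) x vocab
shift_labels = labels[..., 1:]           # batch x (seq - 1)

loss = torch.nn.functional.cross_entropy(
    shift_logits.reshape(-1, shift_logits.size(-1)),
    shift_labels.reshape(-1),
)
print(loss.item())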
@@ -1,26 +1,13 @@
 import torch
 import torch.nn as nn
-from transformers import AutoTokenizer
 from labml_nn.lora import Linear, Embedding
 
-tokenizer = AutoTokenizer.from_pretrained("gpt2")
-
-config = {
-    "layer_norm_epsilon": 1e-05,
-    "n_embd": 768,
-    "n_head": 12,
-    "n_layer": 12,
-    "n_positions": 1024,
-    "vocab_size": 50257,
-    "device": "cuda"
-}
-
 
 class FFN(nn.Module):
-    def __init__(self, dim):
+    def __init__(self, dim: int, n_embed: int, r: int):
         super().__init__()
-        self.c_fc = Linear(config['n_embd'], dim, r=32, bias=True)
-        self.c_proj = Linear(dim, config['n_embd'], r=32, bias=True)
+        self.c_fc = Linear(n_embed, dim, r=r, bias=True)
+        self.c_proj = Linear(dim, n_embed, r=r, bias=True)
         self.act = nn.functional.gelu
 
     def forward(self, hidden_states):
@@ -31,15 +18,15 @@ def forward(self, hidden_states):
 
 
 class MultiHeadAttention(nn.Module):
-    def __init__(self):
+    def __init__(self, n_embed: int, r: int):
         super().__init__()
-        self.embed_dim = config['n_embd']
-        self.num_heads = config['n_head']
+        self.embed_dim = n_embed
+        self.num_heads = 12  # GPT-2 base uses 12 attention heads, matching the pretrained weights
         self.head_dim = self.embed_dim // self.num_heads
         self.split_size = self.embed_dim
 
-        self.c_att = Linear(config['n_embd'], config['n_embd'] * 3, r=32, bias=True)
-        self.c_proj = Linear(config['n_embd'], config['n_embd'], r=32, bias=True)
+        self.c_att = Linear(n_embed, n_embed * 3, r=r, bias=True)
+        self.c_proj = Linear(n_embed, n_embed, r=r, bias=True)
 
     def _split_heads(self, tensor, num_heads, attn_head_size):
         """
@@ -76,12 +63,12 @@ def forward(self, hidden_states):
 
 
 class Block(nn.Module):
-    def __init__(self):
+    def __init__(self, n_embed: int, layer_norm_epsilon: float, r: int):
         super().__init__()
-        self.pre_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
-        self.attn = MultiHeadAttention()
-        self.post_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
-        self.ffn = FFN(config['n_embd'] * 4)
+        self.pre_norm = nn.LayerNorm(n_embed, eps=layer_norm_epsilon)
+        self.attn = MultiHeadAttention(n_embed, r)
+        self.post_norm = nn.LayerNorm(n_embed, eps=layer_norm_epsilon)
+        self.ffn = FFN(n_embed * 4, n_embed, r)
 
     def forward(self, hidden_states):
         residual = hidden_states
@@ -99,23 +86,27 @@ def forward(self, hidden_states):
 
 
 class GPTModel(nn.Module):
-    def __init__(self):
+    def __init__(self, layer_norm_epsilon: float, n_embd: int, n_layer: int, n_positions: int,
+                 vocab_size: int, r: int, device: torch.device):
         super().__init__()
 
-        self.token_embedding = Embedding(config['vocab_size'], config['n_embd'], r=32)
-        self.position_embedding = Embedding(config['n_positions'], config['n_embd'], r=32)
+        self.token_embedding = Embedding(vocab_size, n_embd, r=r)
+        self.position_embedding = Embedding(n_positions, n_embd, r=r)
+
+        self.blocks = nn.ModuleList([Block(n_embd, layer_norm_epsilon, r=r)
+                                     for _ in range(n_layer)])
 
-        self.blocks = nn.ModuleList([Block() for _ in range(config['n_layer'])])
+        self.final_norm = nn.LayerNorm(n_embd, eps=layer_norm_epsilon)
 
-        self.final_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])
+        self.lm_head = Linear(n_embd, vocab_size, r=r, bias=False)
 
-        self.lm_head = Linear(config['n_embd'], config['vocab_size'], r=32, bias=False)
+        self.device = device
 
     def forward(self, input_ids):
         batch_size, input_shape = input_ids.size()
 
         token_embeddings = self.token_embedding(input_ids)  # B T C
-        position_ids = torch.arange(input_shape, device=config['device'])  # T C
+        position_ids = torch.arange(input_shape, device=self.device)  # T C
         position_embeddings = self.position_embedding(position_ids)  # B T C
 
         hidden_states = token_embeddings + position_embeddings
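With this change the model no longer reads a module-level config dict; every hyperparameter is passed to the constructor. A minimal usage sketch with the GPT-2 base sizes from Configs above (the output shape is assumed from how the training loop in experiment.py consumes it):

import torch
from labml_nn.lora.gpt2 import GPTModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GPTModel(
    layer_norm_epsilon=1e-5,
    n_embd=768,
    n_layer=12,
    n_positions=1024,
    vocab_size=50257,
    r=32,          # LoRA rank
    device=device,
).to(device)

input_ids = torch.randint(0, 50257, (2, 16), device=device)  # dummy batch: 2 sequences of 16 tokens
logits = model(input_ids)  # batch x seq x vocab_size, as consumed by the training loop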
