File tree

2 files changed, +10 −6 lines changed
@@ -79,7 +79,7 @@ def __init__(self, in_features: int, out_features: int, bias: bool,
         # Matrix $A \in \mathbb{R}^{r \times k}$
         self.lora_a = nn.Parameter(torch.empty((r, in_features)))
         # Matrix $B \in \mathbb{R}^{d \times r}$, we keep $A$ and $B$ transposed
-        self.lora_b = nn.Parameter(torch.empty((outfeatures, r)))
+        self.lora_b = nn.Parameter(torch.empty((out_features, r)))

         with torch.no_grad():
             # Initialize $A$ similar to a weight matrix in a normal linear layer
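The change above only corrects the `outfeatures` typo, but it helps to see how the two matrices are meant to fit together. Below is a minimal sketch of a LoRA linear layer built around the same `lora_a` (shape `(r, in_features)`) and `lora_b` (shape `(out_features, r)`) parameters; the class name, the `alpha`/`scaling` handling, and the frozen base `weight` are assumptions for illustration, not part of this diff.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALinearSketch(nn.Module):
    # Hypothetical illustration; shapes follow the diff: A is (r, in_features), B is (out_features, r)
    def __init__(self, in_features: int, out_features: int, r: int, alpha: int = 1):
        super().__init__()
        # Frozen pretrained weight (loaded elsewhere; bias omitted for brevity)
        self.weight = nn.Parameter(torch.empty((out_features, in_features)), requires_grad=False)
        # Matrix $A \in \mathbb{R}^{r \times k}$
        self.lora_a = nn.Parameter(torch.empty((r, in_features)))
        # Matrix $B \in \mathbb{R}^{d \times r}$, kept transposed like $A$
        self.lora_b = nn.Parameter(torch.empty((out_features, r)))
        self.scaling = alpha / r
        with torch.no_grad():
            # Initialize $A$ like a normal linear layer weight; $B$ starts at zero,
            # so the adapter is a no-op until training updates it
            nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
            nn.init.zeros_(self.lora_b)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen base projection plus the scaled low-rank update $x A^\top B^\top$
        return F.linear(x, self.weight) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling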
@@ -81,14 +81,14 @@ def _load_pretrained_weights(self):

         # Mapping (`hf: ours`) of decoder layers
         for i in range(12):
-            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
-            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
+            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.attn_norm.weight'
+            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.attn_norm.bias'
             mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.qkv_projection.weight'
             mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.qkv_projection.bias'
             mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.output_projection.weight'
             mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.output_projection.bias'
-            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
-            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
+            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.ffn_norm.weight'
+            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.ffn_norm.bias'
             mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.linear_in.weight'
             mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.linear_in.bias'
             mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.linear_out.weight'
@@ -110,7 +110,11 @@ def _load_pretrained_weights(self):
             new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)

         # Load our model. We use `strict = False` because the state does not have LoRA weights
-        self.model.load_state_dict(new_state_dict, strict=False)
+        missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)
+
+        # Make sure that only the LoRA weights are missing
+        assert all('lora' in key for key in missing_keys)
+        assert not unexpected_keys

     def initialize(self):
         """
