@@ -81,14 +81,14 @@ def _load_pretrained_weights(self):
 
         # Mapping (`hf: ours`) of decoder layers
         for i in range(12):
-            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.pre_norm.weight'
-            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.pre_norm.bias'
+            mapping[f'transformer.h.{i}.ln_1.weight'] = f'blocks.{i}.attn_norm.weight'
+            mapping[f'transformer.h.{i}.ln_1.bias'] = f'blocks.{i}.attn_norm.bias'
             mapping[f'transformer.h.{i}.attn.c_attn.weight'] = f'blocks.{i}.attn.qkv_projection.weight'
             mapping[f'transformer.h.{i}.attn.c_attn.bias'] = f'blocks.{i}.attn.qkv_projection.bias'
             mapping[f'transformer.h.{i}.attn.c_proj.weight'] = f'blocks.{i}.attn.output_projection.weight'
             mapping[f'transformer.h.{i}.attn.c_proj.bias'] = f'blocks.{i}.attn.output_projection.bias'
-            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.post_norm.weight'
-            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.post_norm.bias'
+            mapping[f'transformer.h.{i}.ln_2.weight'] = f'blocks.{i}.ffn_norm.weight'
+            mapping[f'transformer.h.{i}.ln_2.bias'] = f'blocks.{i}.ffn_norm.bias'
             mapping[f'transformer.h.{i}.mlp.c_fc.weight'] = f'blocks.{i}.ffn.linear_in.weight'
             mapping[f'transformer.h.{i}.mlp.c_fc.bias'] = f'blocks.{i}.ffn.linear_in.bias'
             mapping[f'transformer.h.{i}.mlp.c_proj.weight'] = f'blocks.{i}.ffn.linear_out.weight'
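The mapping above only records how the Hugging Face GPT-2 parameter names line up with this model's module names (`attn_norm`, `ffn_norm`, `qkv_projection`, ...). For readers following along, a minimal sketch of how such a mapping is typically applied, assuming the pretrained weights are available as a plain dict `hf_state_dict`; the helper name `remap_state_dict` is illustrative and not part of this commit:

```python
def remap_state_dict(hf_state_dict: dict, mapping: dict) -> dict:
    """Copy each Hugging Face tensor under the parameter name our model expects (hypothetical helper)."""
    new_state_dict = {}
    for hf_key, our_key in mapping.items():
        new_state_dict[our_key] = hf_state_dict[hf_key]
    return new_state_dict
```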
@@ -110,7 +110,11 @@ def _load_pretrained_weights(self):
             new_state_dict[layer] = torch.transpose(new_state_dict[layer], 0, 1)
 
         # Load our model. We use `strict=False` because the state dict does not contain LoRA weights
-        self.model.load_state_dict(new_state_dict, strict=False)
+        missing_keys, unexpected_keys = self.model.load_state_dict(new_state_dict, strict=False)
+
+        # Make sure that only the LoRA weights are missing from the checkpoint
+        assert all('lora' in key for key in missing_keys)
+        assert not unexpected_keys
 
     def initialize(self):
         """
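The new assertions rely on the fact that `load_state_dict(strict=False)` does not raise on mismatches; it returns a named tuple `(missing_keys, unexpected_keys)` that can be inspected afterwards. A minimal, self-contained illustration of that behaviour (not from this repo):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
# Deliberately omit the bias so that it shows up as a missing key.
result = model.load_state_dict({'weight': torch.zeros(4, 4)}, strict=False)
print(result.missing_keys)     # ['bias'] -- parameters the checkpoint did not provide
print(result.unexpected_keys)  # []       -- checkpoint entries the model does not recognize
```

Here the check guarantees that the only parameters left at their freshly initialized values are the LoRA adapters, and that every tensor in the converted checkpoint actually landed somewhere in the model.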