 import copy
 import json

-from typing import List, Optional, Tuple
+import logging
+import sys
+
+from typing import List, Tuple

 import torch
 import torch.nn as nn
+from executorch.backends.qualcomm.quantizer.custom_annotation import (
+    annotate_linear_16a8w_in_affine_layer,
+    annotate_matmul_16a8w,
+)
+
+from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d

 from executorch.examples.models.llama.eval_llama_lib import (
     build_args_parser,
     GraphModuleEvalWrapper,
 )

+from executorch.examples.models.llama.source_transformation.quantize import (
+    get_quant_embedding_transform,
+)
+
+from executorch.examples.qualcomm.oss_scripts.llama.llama import calibrate
+
 from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
     LlamaModel,
     ModelArgs,
 )
+
+from executorch.examples.qualcomm.utils import make_quantizer
+
 from lm_eval.evaluator import simple_evaluate

 from pytorch_tokenizers import get_tokenizer

+from torchao.quantization.pt2e import MinMaxObserver
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+sys.setrecursionlimit(4096)
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+logging.getLogger().setLevel(logging.INFO)
+

 class WrappedLlamaModel(nn.Module):
-    def __init__(self, model, use_kv_cache=False, max_seq_len=512, device="cuda"):
+    def __init__(
+        self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"
+    ):
         super(WrappedLlamaModel, self).__init__()
         self.model = model
         self.max_seq_len = max_seq_len
         self.use_kv_cache = use_kv_cache
         self.device = device
+        self.atten_mask = atten_mask

     def forward(
         self,
         tokens: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
         *args,
     ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
         # Pad input if necessary, since LlamaModel requires static shape
         if tokens.shape[1] != self.max_seq_len:
             tokens = torch.nn.functional.pad(
-                tokens, (self.max_seq_len - tokens.shape[1], 0)
+                tokens, (0, self.max_seq_len - tokens.shape[1])
             )
-        atten_mask = (
-            self.model.get_example_inputs(self.use_kv_cache)[1]
-            .to(device=self.device)
-            .to(dtype=torch.bfloat16)
-        )
-        return self.model.forward(tokens, atten_mask, input_pos, *args)
+        return self.model.forward(tokens, self.atten_mask)


 def gen_eval_wrapper(model_name, args):
@@ -119,14 +143,69 @@ def permute(w, heads):
             layer.feed_forward.prepare_feedfoward_conv()

     model.to(dtype=torch.bfloat16)
-    model.to(args.device)
+    model.to(device=args.device)

-    wrapped_model = WrappedLlamaModel(
-        model, args.use_kv_cache, args.max_seq_length, args.device
+    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+    tokens = tokens.to(device=args.device)
+    atten_mask = atten_mask.to(device=args.device)
+    atten_mask = atten_mask.to(dtype=torch.bfloat16)
+    inputs = (tokens, atten_mask)
+
+    if args.embedding_quantize:
+        model = get_quant_embedding_transform(
+            embedding_quantize=args.embedding_quantize
+        )(model)
+
+    model = convert_linear_to_conv2d(model)
+
+    if args.ptq:
+        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
+
+        custom_annotations = (annotate_matmul_16a8w,)
+        if args.llama_model == "stories110m":
+            custom_annotations = custom_annotations + (
+                annotate_linear_16a8w_in_affine_layer,
+            )
+        quantizer = make_quantizer(
+            quant_dtype=quant_dtype,
+            per_channel_conv=True,
+            per_channel_linear=True,
+            act_observer=MinMaxObserver,
+        )
+        quantizer.add_custom_quant_annotations(custom_annotations)
+
+        model.has_quant_io = True
+
+        with torch.no_grad():
+            model = torch.export.export(model, inputs, strict=True).module()
+            if quant_dtype == QuantDtype.use_16a4w_block:
+                conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
+                block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
+                quantizer.set_block_size_map(block_size_map)
+
+            model = prepare_pt2e(model, quantizer)
+
+        logging.info("Quantizing the model...")
+
+        calibrate(
+            inputs,
+            "Once upon a time",
+            model,
+            tokenizer=tokenizer,
+            ar_len=args.prefill_ar_len,
+            max_seq_len=args.max_seq_len,
+            kv_updater=None,
+            use_i64_token=use_i64_token,
+        )
+
+        model = convert_pt2e(model)
+
+    model = WrappedLlamaModel(
+        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
     )

     return GraphModuleEvalWrapper(
-        model=wrapped_model,
+        model=model,
         tokenizer=tokenizer,
         max_seq_length=args.calibration_seq_length,
         use_kv_cache=args.use_kv_cache,
@@ -167,6 +246,7 @@ def main() -> None:
     modelname = "llama2"
     parser = build_args_parser()
     args = parser.parse_args()
+    args.llama_model = "llama3_2"
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True

@@ -177,7 +257,14 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length

+    # Use fewer samples for faster evaluation
+    args.limit = 0.1
+    # args.samples = {'wikitext': list(range(1))}
+
     args.device = "cuda" if torch.cuda.is_available() else "cpu"
+    torch.set_default_device(args.device)
+
+    args.ptq = "8a8w"

     eval_llama(modelname, args)
