2 files changed: +104 −14 lines changed
Buck target for eval_llama_qnn (adds the model_sharding_py custom-op target to preload_deps):

```diff
@@ -49,6 +49,9 @@ python_binary(
     name = "eval_llama_qnn",
     srcs = ["eval_llama_qnn.py"],
     main_function = "executorch.examples.qualcomm.oss_scripts.llama.eval_llama_qnn.main",
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:model_sharding_py",
+    ],
     deps = [
         ":llama_lib",
         "//executorch/examples/models/llama:eval_library",
```
eval_llama_qnn.py (imports, logging setup, and the reworked WrappedLlamaModel):

```diff
@@ -8,50 +8,74 @@
 import copy
 import json
 
-from typing import List, Optional, Tuple
+import logging
+import sys
+
+from typing import List, Tuple
 
 import torch
 import torch.nn as nn
+from executorch.backends.qualcomm.quantizer.custom_annotation import (
+    annotate_linear_16a8w_in_affine_layer,
+    annotate_matmul_16a8w,
+)
+
+from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d
 
 from executorch.examples.models.llama.eval_llama_lib import (
     build_args_parser,
     GraphModuleEvalWrapper,
 )
 
+from executorch.examples.models.llama.source_transformation.quantize import (
+    get_quant_embedding_transform,
+)
+
+from executorch.examples.qualcomm.oss_scripts.llama.llama import calibrate
+
 from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
     LlamaModel,
     ModelArgs,
 )
+
+from executorch.examples.qualcomm.utils import make_quantizer
+
 from lm_eval.evaluator import simple_evaluate
 
 from pytorch_tokenizers import get_tokenizer
 
+from torchao.quantization.pt2e import MinMaxObserver
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+sys.setrecursionlimit(4096)
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+logging.getLogger().setLevel(logging.INFO)
+
 
 class WrappedLlamaModel(nn.Module):
-    def __init__(self, model, use_kv_cache=False, max_seq_len=512, device="cuda"):
+    def __init__(
+        self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"
+    ):
         super(WrappedLlamaModel, self).__init__()
         self.model = model
         self.max_seq_len = max_seq_len
         self.use_kv_cache = use_kv_cache
         self.device = device
+        self.atten_mask = atten_mask
 
     def forward(
         self,
         tokens: torch.Tensor,
-        input_pos: Optional[torch.Tensor] = None,
         *args,
     ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
         # Pad input if necessary, since LlamaModel requires static shape
         if tokens.shape[1] != self.max_seq_len:
             tokens = torch.nn.functional.pad(
-                tokens, (self.max_seq_len - tokens.shape[1], 0)
+                tokens, (0, self.max_seq_len - tokens.shape[1])
             )
-        atten_mask = (
-            self.model.get_example_inputs(self.use_kv_cache)[1]
-            .to(device=self.device)
-            .to(dtype=torch.bfloat16)
-        )
-        return self.model.forward(tokens, atten_mask, input_pos, *args)
+        return self.model.forward(tokens, self.atten_mask)
 
 
 def gen_eval_wrapper(model_name, args):
```
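Two behavior changes in WrappedLlamaModel.forward are easy to miss in the hunk above: the attention mask is now taken from the constructor argument instead of being rebuilt from get_example_inputs on every call, and short prompts are padded on the right of the sequence rather than the left. The sketch below only illustrates the torch.nn.functional.pad semantics behind that second change; the tensors are made up for the example:

```python
import torch
import torch.nn.functional as F

max_seq_len = 8
tokens = torch.tensor([[5, 6, 7]])  # (batch=1, seq=3) toy prompt

# F.pad takes (left, right) pairs starting from the last dimension, so
# (0, n) appends n pad tokens after the prompt (the new right padding) ...
right_padded = F.pad(tokens, (0, max_seq_len - tokens.shape[1]))
# ... while (n, 0) prepends them (the old left padding).
left_padded = F.pad(tokens, (max_seq_len - tokens.shape[1], 0))

print(right_padded)  # tensor([[5, 6, 7, 0, 0, 0, 0, 0]])
print(left_padded)   # tensor([[0, 0, 0, 0, 0, 5, 6, 7]])
```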
```diff
@@ -119,14 +143,69 @@ def permute(w, heads):
         layer.feed_forward.prepare_feedfoward_conv()
 
     model.to(dtype=torch.bfloat16)
-    model.to(args.device)
+    model.to(device=args.device)
 
-    wrapped_model = WrappedLlamaModel(
-        model, args.use_kv_cache, args.max_seq_length, args.device
+    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+    tokens = tokens.to(device=args.device)
+    atten_mask = atten_mask.to(device=args.device)
+    atten_mask = atten_mask.to(dtype=torch.bfloat16)
+    inputs = (tokens, atten_mask)
+
+    if args.embedding_quantize:
+        model = get_quant_embedding_transform(
+            embedding_quantize=args.embedding_quantize
+        )(model)
+
+    model = convert_linear_to_conv2d(model)
+
+    if args.ptq:
+        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
+
+        custom_annotations = (annotate_matmul_16a8w,)
+        if args.llama_model == "stories110m":
+            custom_annotations = custom_annotations + (
+                annotate_linear_16a8w_in_affine_layer,
+            )
+        quantizer = make_quantizer(
+            quant_dtype=quant_dtype,
+            per_channel_conv=True,
+            per_channel_linear=True,
+            act_observer=MinMaxObserver,
+        )
+        quantizer.add_custom_quant_annotations(custom_annotations)
+
+        model.has_quant_io = True
+
+        with torch.no_grad():
+            model = torch.export.export(model, inputs, strict=True).module()
+            if quant_dtype == QuantDtype.use_16a4w_block:
+                conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
+                block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
+                quantizer.set_block_size_map(block_size_map)
+
+            model = prepare_pt2e(model, quantizer)
+
+        logging.info("Quantizing the model...")
+
+        calibrate(
+            inputs,
+            "Once upon a time",
+            model,
+            tokenizer=tokenizer,
+            ar_len=args.prefill_ar_len,
+            max_seq_len=args.max_seq_len,
+            kv_updater=None,
+            use_i64_token=use_i64_token,
+        )
+
+        model = convert_pt2e(model)
+
+    model = WrappedLlamaModel(
+        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
     )
 
     return GraphModuleEvalWrapper(
-        model=wrapped_model,
+        model=model,
         tokenizer=tokenizer,
         max_seq_length=args.calibration_seq_length,
         use_kv_cache=args.use_kv_cache,
```
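The wrapper path now calls convert_linear_to_conv2d before quantization. The executorch utility's actual rewrite is not part of this diff; the snippet below is only a self-contained illustration of why such a rewrite is possible at all: a Linear layer can be expressed exactly as a 1x1 Conv2d over a 1x1 spatial grid, which is the kind of transformation the utility's name suggests. The layer sizes and tensors here are illustrative, not from the PR:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

linear = nn.Linear(16, 32)

# Carry the same parameters into a 1x1 conv: weight (out, in) -> (out, in, 1, 1).
conv = nn.Conv2d(16, 32, kernel_size=1)
conv.weight.data = linear.weight.data.view(32, 16, 1, 1).clone()
conv.bias.data = linear.bias.data.clone()

x = torch.randn(4, 16)
out_linear = linear(x)
# Feed the same activations as a (batch, channels, 1, 1) "image".
out_conv = conv(x.view(4, 16, 1, 1)).view(4, 32)

print(torch.allclose(out_linear, out_conv, atol=1e-5))  # True, up to float rounding
```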
```diff
@@ -167,6 +246,7 @@ def main() -> None:
     modelname = "llama2"
     parser = build_args_parser()
     args = parser.parse_args()
+    args.llama_model = "llama3_2"
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True
 
```
```diff
@@ -177,7 +257,14 @@ def main() -> None:
     args.use_kv_cache = False
     args.prefill_ar_len = args.max_seq_length
 
+    # To do fewer samples for faster evaluation
+    args.limit = 0.1
+    # args.samples = {'wikitext': list(range(1))}
+
     args.device = "cuda" if torch.cuda.is_available() else "cpu"
+    torch.set_default_device(args.device)
+
+    args.ptq = "8a8w"
 
     eval_llama(modelname, args)
 
```
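Taken together with main() forcing args.ptq = "8a8w", gen_eval_wrapper now runs the standard PT2E post-training quantization flow: export the model, let the quantizer insert observers via prepare_pt2e, push calibration data through the observed module, and fold the collected ranges into quantize/dequantize ops with convert_pt2e. The sketch below condenses that flow onto a toy module using the same helpers the diff imports; the toy model, example input, and the assumption that the QNN quantizer annotates a bare Linear in this setting are mine, not the PR's:

```python
import torch
import torch.nn as nn

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.utils import make_quantizer
from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class TinyModel(nn.Module):
    """Placeholder stand-in for the exported Llama module."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(1, 16),)

# 1. Export to an ATen-level graph module, as the diff does before quantizing.
module = torch.export.export(TinyModel(), example_inputs, strict=True).module()

# 2. Build the QNN quantizer; the arguments mirror the diff's make_quantizer call,
#    with the "8a8w" dtype that main() selects.
quantizer = make_quantizer(
    quant_dtype=QuantDtype.use_8a8w,
    per_channel_conv=True,
    per_channel_linear=True,
    act_observer=MinMaxObserver,
)

# 3. Observe, calibrate, convert. The real script calibrates with a tokenized
#    prompt via calibrate(...); here a single forward pass stands in for that.
module = prepare_pt2e(module, quantizer)
with torch.no_grad():
    module(*example_inputs)
module = convert_pt2e(module)
```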