@@ -49,6 +49,9 @@ python_binary(
name = "eval_llama_qnn",
srcs = ["eval_llama_qnn.py"],
main_function = "executorch.examples.qualcomm.oss_scripts.llama.eval_llama_qnn.main",
preload_deps = [
"//executorch/extension/llm/custom_ops:model_sharding_py",
],
deps = [
":llama_lib",
"//executorch/examples/models/llama:eval_library",
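The added `preload_deps` entry makes Buck load the `model_sharding_py` dependency into the process before `eval_llama_qnn.py` itself runs, the usual reason being that natively registered custom ops must already exist when the model code references them. When running outside of Buck, custom-op libraries are commonly loaded explicitly; a minimal, hypothetical sketch (the path below is a placeholder, not a real build artifact):

```python
import torch

# Hypothetical path to the built custom-op extension; with Buck this step
# is handled by preload_deps rather than an explicit load.
torch.ops.load_library("path/to/model_sharding_extension.so")
```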
@@ -8,50 +8,74 @@
import copy
import json

from typing import List, Optional, Tuple
import logging
import sys

from typing import List, Tuple

import torch
import torch.nn as nn
from executorch.backends.qualcomm.quantizer.custom_annotation import (
annotate_linear_16a8w_in_affine_layer,
annotate_matmul_16a8w,
)

from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d

from executorch.examples.models.llama.eval_llama_lib import (
build_args_parser,
GraphModuleEvalWrapper,
)

from executorch.examples.models.llama.source_transformation.quantize import (
get_quant_embedding_transform,
)

from executorch.examples.qualcomm.oss_scripts.llama.llama import calibrate

from executorch.examples.qualcomm.oss_scripts.llama.model.static_llama import (
LlamaModel,
ModelArgs,
)

from executorch.examples.qualcomm.utils import make_quantizer

from lm_eval.evaluator import simple_evaluate

from pytorch_tokenizers import get_tokenizer

from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

sys.setrecursionlimit(4096)
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
logging.getLogger().setLevel(logging.INFO)


class WrappedLlamaModel(nn.Module):
def __init__(self, model, use_kv_cache=False, max_seq_len=512, device="cuda"):
def __init__(
self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"
):
super(WrappedLlamaModel, self).__init__()
self.model = model
self.max_seq_len = max_seq_len
self.use_kv_cache = use_kv_cache
self.device = device
self.atten_mask = atten_mask

def forward(
self,
tokens: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
*args,
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
# Pad input if necessary, since LlamaModel requires static shape
if tokens.shape[1] != self.max_seq_len:
tokens = torch.nn.functional.pad(
tokens, (self.max_seq_len - tokens.shape[1], 0)
tokens, (0, self.max_seq_len - tokens.shape[1])
)
atten_mask = (
self.model.get_example_inputs(self.use_kv_cache)[1]
.to(device=self.device)
.to(dtype=torch.bfloat16)
)
return self.model.forward(tokens, atten_mask, input_pos, *args)
return self.model.forward(tokens, self.atten_mask)


def gen_eval_wrapper(model_name, args):
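In the reworked `WrappedLlamaModel`, the attention mask is passed in once at construction and reused, instead of being rebuilt from `get_example_inputs` on every forward call, and `input_pos` is no longer threaded through since the wrapper only exercises the non-KV-cache path. The padding also moves from the left to the right of the sequence: `torch.nn.functional.pad` takes `(left, right)` amounts for the last dimension, so `(0, pad)` keeps the real prompt tokens at positions 0..n-1, which is where the static causal mask expects them. A small illustration:

```python
import torch
import torch.nn.functional as F

tokens = torch.tensor([[101, 102, 103]])  # (batch=1, seq=3)
max_seq_len = 8

# (0, n) pads n zeros on the right of the last dimension, so the real
# tokens keep positions 0..2 under the causal mask.
right_padded = F.pad(tokens, (0, max_seq_len - tokens.shape[1]))
# tensor([[101, 102, 103, 0, 0, 0, 0, 0]])

# (n, 0) would pad on the left and shift the prompt to positions 5..7.
left_padded = F.pad(tokens, (max_seq_len - tokens.shape[1], 0))
# tensor([[0, 0, 0, 0, 0, 101, 102, 103]])
```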
@@ -119,14 +143,69 @@ def permute(w, heads):
layer.feed_forward.prepare_feedfoward_conv()

model.to(dtype=torch.bfloat16)
model.to(args.device)
model.to(device=args.device)

wrapped_model = WrappedLlamaModel(
model, args.use_kv_cache, args.max_seq_length, args.device
tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
tokens = tokens.to(device=args.device)
atten_mask = atten_mask.to(device=args.device)
atten_mask = atten_mask.to(dtype=torch.bfloat16)
inputs = (tokens, atten_mask)

if args.embedding_quantize:
model = get_quant_embedding_transform(
embedding_quantize=args.embedding_quantize
)(model)

model = convert_linear_to_conv2d(model)

if args.ptq:
quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")

custom_annotations = (annotate_matmul_16a8w,)
if args.llama_model == "stories110m":
custom_annotations = custom_annotations + (
annotate_linear_16a8w_in_affine_layer,
)
quantizer = make_quantizer(
quant_dtype=quant_dtype,
per_channel_conv=True,
per_channel_linear=True,
act_observer=MinMaxObserver,
)
quantizer.add_custom_quant_annotations(custom_annotations)

model.has_quant_io = True

with torch.no_grad():
model = torch.export.export(model, inputs, strict=True).module()
if quant_dtype == QuantDtype.use_16a4w_block:
conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
quantizer.set_block_size_map(block_size_map)

model = prepare_pt2e(model, quantizer)

logging.info("Quantizing the model...")

calibrate(
inputs,
"Once upon a time",
model,
tokenizer=tokenizer,
ar_len=args.prefill_ar_len,
max_seq_len=args.max_seq_len,
kv_updater=None,
use_i64_token=use_i64_token,
)

model = convert_pt2e(model)

model = WrappedLlamaModel(
model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
)

return GraphModuleEvalWrapper(
model=wrapped_model,
model=model,
tokenizer=tokenizer,
max_seq_length=args.calibration_seq_length,
use_kv_cache=args.use_kv_cache,
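The new PTQ path in `gen_eval_wrapper` runs optional embedding quantization, `convert_linear_to_conv2d`, a QNN quantizer with the 16a8w matmul annotation (plus the affine-layer annotation for stories110m), `torch.export`, an optional per-conv block-size map for `use_16a4w_block`, `prepare_pt2e`, calibration on a sample prompt, and `convert_pt2e`, so the perplexity numbers reflect the quantized graph rather than the float model. Condensed to the bare PT2E post-training-quantization pattern, reusing the names defined above (a sketch, not a drop-in replacement for the code in this hunk):

```python
import torch
from executorch.backends.qualcomm.quantizer.custom_annotation import annotate_matmul_16a8w
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.qualcomm.utils import make_quantizer
from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

quantizer = make_quantizer(
    quant_dtype=QuantDtype.use_8a8w,
    per_channel_conv=True,
    per_channel_linear=True,
    act_observer=MinMaxObserver,
)
quantizer.add_custom_quant_annotations((annotate_matmul_16a8w,))

# `model` and `inputs` are the eager LlamaModel and (tokens, atten_mask) built above.
exported = torch.export.export(model, inputs, strict=True).module()
prepared = prepare_pt2e(exported, quantizer)   # insert observers
with torch.no_grad():
    prepared(*inputs)                          # calibration pass; the script uses calibrate() on a prompt
quantized = convert_pt2e(prepared)             # fold observers into quant/dequant ops
```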
@@ -167,6 +246,7 @@ def main() -> None:
modelname = "llama2"
parser = build_args_parser()
args = parser.parse_args()
args.llama_model = "llama3_2"
# Overrides this arg, because evaluation requires full logits.
args.generate_full_logits = True
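Forcing `generate_full_logits = True` makes the model return logits for every sequence position rather than only the last one; lm_eval scores tasks such as wikitext by summing per-token log-likelihoods, which needs exactly those full logits. A rough sketch of that computation with toy shapes (not the library's internals):

```python
import torch
import torch.nn.functional as F

# Full logits of shape (batch, seq_len, vocab) let us score every token.
logits = torch.randn(1, 8, 32000)
tokens = torch.randint(0, 32000, (1, 8))

# Shift by one: the logits at position i predict token i+1.
log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
token_ll = log_probs.gather(-1, tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
sequence_ll = token_ll.sum()  # summed log-likelihood of the continuation
```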

@@ -177,7 +257,14 @@ def main() -> None:
args.use_kv_cache = False
args.prefill_ar_len = args.max_seq_length

# Use a fraction of the samples for faster evaluation
args.limit = 0.1
# args.samples = {'wikitext': list(range(1))}

args.device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(args.device)

args.ptq = "8a8w"

eval_llama(modelname, args)
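The remaining overrides disable the KV cache so whole sequences go through the prefill path, evaluate only part of each task (lm_eval treats a float `limit` below 1 as a fraction, so 0.1 means 10% of the samples), run on CUDA when available, and pin the quantization recipe to 8a8w. For reference, the `limit` value is ultimately consumed by `lm_eval.evaluator.simple_evaluate`; a minimal sketch of that downstream call, assuming `eval_wrapper` is the `GraphModuleEvalWrapper` built above (the actual invocation happens inside `eval_llama`, not shown here):

```python
from lm_eval.evaluator import simple_evaluate

# Sketch only: eval_wrapper wraps the quantized graph module.
results = simple_evaluate(
    model=eval_wrapper,
    tasks=["wikitext"],
    limit=0.1,  # float below 1 => evaluate 10% of the task's documents
)
print(results["results"]["wikitext"])
```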
