Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7898,6 +7898,98 @@ def test_static_llm_model(self): # noqa: C901
device_inference_speed, expected_inference_speed
)

def test_static_llm_eval_limit(self):
    """Check that changing only ``--eval_limit`` changes the reported perplexity.

    The model is compiled once with a fixed calibration setting
    (``--calib_tasks wikitext --calib_limit 1``). The script is then run
    twice more with ``--pre_gen_pte`` pointing at the same artifact
    directory, using ``--eval_limit`` 1 and 3. The test passes when the two
    ``wiki_ppl`` values differ, i.e. calibration and evaluation are
    configured independently.
    """
    if not self.required_envs([self.model_name]):
        self.skipTest("missing required envs")
    assert (
        self.model_name in self.llm_specs
    ), f"Unable to find {self.model_name} under model_specs."

    def run_llama(extra_cmds):
        """Run llama.py with the shared flags plus *extra_cmds*.

        Returns the JSON-decoded message received on the
        ``(self.ip, self.port)`` listener. Fails the test if that message
        contains an ``"Error"`` key.
        """
        cmds = [
            "python",
            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
            "--artifact",
            self.artifact_dir,
            "--build_folder",
            self.build_folder,
            "--prompt",
            "I would like to learn python, could you teach me with a simple example?",
            "--temperature",
            "0",
            "--decoder_model",
            self.model_name,
            "--model_mode",
            "kv",
            "--max_seq_len",
            "1024",
            "--max_context_len",
            "1024",
            "--skip_user_prompt_calibration",
            "--soc_model",
            self.soc_model,
            "--target",
            self.target,
            "--ip",
            self.ip,
            "--port",
            str(self.port),
            "--seed",
            str(1126),
            "--backend",
            self.backend,
        ]
        cmds.extend(extra_cmds)
        if self.host:
            cmds.extend(["--host", self.host])
        elif self.enable_x86_64:
            cmds.extend(["--enable_x86_64"])

        # Bind the listener BEFORE spawning the child so the child can never
        # try to connect to a port that is not being listened on yet.
        with Listener((self.ip, self.port)) as listener:
            p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
            # NOTE(review): accept() blocks until something connects; if the
            # script exits without connecting this will hang. Confirm that
            # every exit path of the script reports back.
            # The connection is a context manager; closing it here avoids
            # leaking one socket per script invocation.
            with listener.accept() as conn:
                p.communicate()
                msg = json.loads(conn.recv())

        if "Error" in msg:
            self.fail(msg["Error"])
        return msg

    # Compile once with a fixed calibration so quantization is identical
    # across the eval runs below.
    run_llama(
        [
            "--compile_only",
            "--calib_tasks",
            "wikitext",
            "--calib_limit",
            "1",
        ]
    )

    def eval_ppl(eval_limit):
        """Evaluate the pre-generated artifacts on wikitext; return ``wiki_ppl``."""
        extra_cmds = [
            "--pre_gen_pte",
            self.artifact_dir,
            "--eval_methods",
            "tasks_eval",
            "--eval_tasks",
            "wikitext",
            "--eval_limit",
            str(eval_limit),
        ]
        if self.device:
            extra_cmds.extend(["--device", self.device])
        return run_llama(extra_cmds)["wiki_ppl"]

    ppl_eval_1 = eval_ppl(1)
    ppl_eval_3 = eval_ppl(3)
    logging.info(
        f"wiki_ppl: eval_limit=1: {ppl_eval_1}, eval_limit=3: {ppl_eval_3}"
    )
    # Different eval limits with identical quantization should give
    # different perplexities.
    self.assertNotEqual(ppl_eval_1, ppl_eval_3)

def test_codegen2_1b(self):
if not self.required_envs():
self.skipTest("missing required envs")
Expand Down
Loading