🐛 Describe the bug
When preparing a simple per-channel scale model for QAT with `EthosUQuantizer` + `prepare_qat_pt2e`, the resulting `prepared_model` is missing the `scale` parameter: `prepared_model.parameters()` is empty. This breaks the QAT training flow, because constructing the optimizer fails with:

```
ValueError: optimizer got an empty parameter list
```

The same setup (model, data, optimizer) works with the XNNPACK quantizer, so the problem appears to be specific to the Arm Ethos-U quantizer path.
Script to reproduce:

```python
import torch
from executorch.backends.arm.ethosu.compile_spec import EthosUCompileSpec
from executorch.backends.arm.quantizer import EthosUQuantizer, get_symmetric_quantization_config
from torch.export import export
from torchao.quantization.pt2e import move_exported_model_to_train
from torchao.quantization.pt2e.quantize_pt2e import prepare_qat_pt2e

SEED = 0
NUM_CHANNELS = 3
NUM_STEPS = 8
LR = 1.0
MOMENTUM = 0.3


class MyScale(torch.nn.Module):
    """Per-channel scale module."""

    def __init__(self, num_channels: int) -> None:
        super().__init__()
        self.num_channels = num_channels
        self.scale = torch.nn.Parameter(torch.ones(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = [1] * x.dim()  # [1, 1, ...]
        shape[1] = self.num_channels  # [1, NUM_CHANNELS, 1, 1, ...]
        return x * self.scale.view(shape)


def _make_batch() -> tuple[torch.Tensor, torch.Tensor]:
    x = torch.randn(16, NUM_CHANNELS, 8, 12, dtype=torch.float32)
    # Simulate a target that applies a different scale to each channel.
    # These are the numbers the model should learn.
    y = x * torch.tensor([2.7, 0.6, -1.4], dtype=torch.float32).view(1, NUM_CHANNELS, 1, 1)
    return x, y


def train_step(model: torch.nn.Module, optimizer: torch.optim.Optimizer) -> torch.Tensor:
    x, y = _make_batch()
    pred = model(x)
    loss = torch.nn.functional.mse_loss(pred, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss


torch.manual_seed(SEED)

# Float
model = MyScale(NUM_CHANNELS)
print(f"[float] num model parameters: {len(list(model.parameters()))}")
print("[float] training ...")
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
for step in range(NUM_STEPS):
    loss = train_step(model, optimizer)
    params = [p.data.squeeze() for p in model.parameters()]
    print(f" step={step + 1} loss={loss.item():.4f} {params=}")

# QAT
model = MyScale(NUM_CHANNELS)
example_inputs, _ = _make_batch()
exported_graph = export(model, (example_inputs,)).module(check_guards=False)
quantizer = EthosUQuantizer(EthosUCompileSpec(target="ethos-u55-128"))
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
prepared_model = prepare_qat_pt2e(exported_graph, quantizer)
prepared_model = move_exported_model_to_train(prepared_model)
print(f"[qat] num model parameters: {len(list(prepared_model.parameters()))}")
print("[qat] training ...")
optimizer = torch.optim.SGD(prepared_model.parameters(), lr=LR, momentum=MOMENTUM)
for step in range(NUM_STEPS):
    loss = train_step(prepared_model, optimizer)
    # Report the parameters of the model actually being trained.
    params = [p.data.squeeze() for p in prepared_model.parameters()]
    print(f" step={step + 1} loss={loss.item():.4f} {params=}")
```
Actual output:

```
[float] num model parameters: 1
[float] training ...
 step=1 loss=2.8747 params=[tensor([ 2.1537, 0.7403, -0.5351])]
 step=2 loss=0.3676 params=[tensor([ 2.8816, 0.5678, -1.5891])]
 step=3 loss=0.0239 params=[tensor([ 2.9718, 0.5378, -1.7790])]
 step=4 loss=0.0732 params=[tensor([ 2.8246, 0.5720, -1.5816])]
 step=5 loss=0.0160 params=[tensor([ 2.6995, 0.6003, -1.4040])]
 step=6 loss=0.0000 params=[tensor([ 2.6623, 0.6086, -1.3481])]
 step=7 loss=0.0014 params=[tensor([ 2.6767, 0.6055, -1.3646])]
 step=8 loss=0.0006 params=[tensor([ 2.6965, 0.6009, -1.3937])]
[qat] num model parameters: 0
[qat] training ...
Traceback (most recent call last):
  File "/home/wsluser/proj/nn-deploy-kit/examples/inspect_qat_repeated_backward/run_minimal_scale.py", line 74, in <module>
    optimizer = torch.optim.SGD(prepared_model.parameters(), lr=LR, momentum=MOMENTUM)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wsluser/proj/nn-deploy-kit/.pixi/envs/default/lib/python3.12/site-packages/torch/optim/sgd.py", line 65, in __init__
    super().__init__(params, defaults)
  File "/home/wsluser/proj/nn-deploy-kit/.pixi/envs/default/lib/python3.12/site-packages/torch/optim/optimizer.py", line 403, in __init__
    raise ValueError("optimizer got an empty parameter list")
ValueError: optimizer got an empty parameter list
```
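To narrow down where `scale` ends up, a small diagnostic can report how a module stores its tensors. The sketch below is not part of the original repro: `describe_tensor_state` is a hypothetical helper that lists registered parameters, registered buffers and, for a `torch.fx.GraphModule`, the attribute targets read by `get_attr` nodes. If `scale` shows up under `get_attr` but not under `parameters` for `prepared_model`, that would suggest the tensor is still used by the graph but is no longer registered as an `nn.Parameter`. That reading is an assumption about the cause, not a confirmed finding.

```python
import torch
import torch.fx


def describe_tensor_state(module: torch.nn.Module) -> dict[str, list[str]]:
    """Hypothetical diagnostic: report how a module stores its tensors.

    - "parameters": names of registered nn.Parameter objects
    - "buffers": names of registered buffers
    - "get_attr": for fx.GraphModule instances, attribute targets read by
      get_attr nodes in the graph (empty for ordinary modules)
    """
    state: dict[str, list[str]] = {
        "parameters": [name for name, _ in module.named_parameters()],
        "buffers": [name for name, _ in module.named_buffers()],
        "get_attr": [],
    }
    if isinstance(module, torch.fx.GraphModule):
        state["get_attr"] = [
            str(node.target) for node in module.graph.nodes if node.op == "get_attr"
        ]
    return state
```

In the repro, `print(describe_tensor_state(prepared_model))` could be placed right after `prepare_qat_pt2e` and compared against the same call on `exported_graph`.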
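Expected behavior: the prepared model should still expose `scale` as a trainable parameter, so that the QAT training loop can optimize it without crashing, as it does with the XNNPACK quantizer.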
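The expected behavior could be expressed as a regression check. This is a sketch under one stated assumption: QAT preparation may add trainable state but should never drop any of the float model's trainable elements. `trainable_numel` and `params_preserved` are hypothetical helpers, not ExecuTorch or torchao APIs.

```python
import torch


def trainable_numel(module: torch.nn.Module) -> int:
    """Count the elements of all parameters that require gradients."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


def params_preserved(float_model: torch.nn.Module, prepared_model: torch.nn.Module) -> bool:
    """Hypothetical regression check.

    Returns True if the prepared model keeps at least as many trainable
    elements as the float model (extra QAT state is allowed).
    """
    return trainable_numel(prepared_model) >= trainable_numel(float_model)
```

In the repro, `assert params_preserved(model, prepared_model)` placed before the optimizer is created would fail with the Ethos-U quantizer, since `MyScale(3)` has 3 trainable elements and the prepared model reports none.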
Versions
Collecting environment information...
PyTorch version: 2.11.0+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.4 LTS (x86_64)
GCC version: (conda-forge gcc 14.3.0-15) 14.3.0
Clang version: Could not collect
CMake version: version 3.31.10
Libc version: glibc-2.39
Python version: 3.12.12 | packaged by conda-forge | (main, Oct 22 2025, 23:25:55) [GCC 14.3.0] (64-bit runtime)
Python platform: Linux-6.6.114.1-microsoft-standard-WSL2-x86_64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 8
On-line CPU(s) list: 0-7
Vendor ID: GenuineIntel
Model name: 11th Gen Intel(R) Core(TM) i7-1185G7 @ 3.00GHz
CPU family: 6
Model: 140
Thread(s) per core: 2
Core(s) per socket: 4
Socket(s): 1
Stepping: 1
BogoMIPS: 5990.42
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves vnmi avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid movdiri movdir64b fsrm avx512_vp2intersect md_clear flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 192 KiB (4 instances)
L1i cache: 128 KiB (4 instances)
L2 cache: 5 MiB (4 instances)
L3 cache: 12 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-7
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Indirect target selection: Mitigation; Aligned branch/return thunks
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsa: Not affected
Vulnerability Tsx async abort: Not affected
Vulnerability Vmscape: Not affected
Versions of relevant libraries:
[pip3] executorch==1.3.0a0+fa857bd
[pip3] numpy==2.4.3
[pip3] nvidia-cublas==13.1.0.3
[pip3] nvidia-cuda-cupti==13.0.85
[pip3] nvidia-cuda-nvrtc==13.0.88
[pip3] nvidia-cuda-runtime==13.0.96
[pip3] nvidia-cudnn-cu13==9.19.0.56
[pip3] nvidia-cufft==12.0.0.61
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-cusolver==12.0.4.66
[pip3] nvidia-cusparse==12.6.3.3
[pip3] nvidia-cusparselt-cu13==0.8.0
[pip3] nvidia-nccl-cu13==2.28.9
[pip3] nvidia-nvjitlink==13.0.88
[pip3] nvidia-nvtx==13.0.85
[pip3] pytorch-lightning==2.6.1
[pip3] pytorch_tokenizers==1.1.0
[pip3] torch==2.11.0
[pip3] torchao==0.17.0+git42bcdc491
[pip3] torchaudio==2.11.0+cpu
[pip3] torchdata==0.11.0
[pip3] torchmetrics==1.9.0
[pip3] torchsr==1.0.4
[pip3] torchtune==0.0.0
[pip3] torchvision==0.26.0+cpu
[pip3] triton==3.6.0
[conda] Could not collect
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani