From 83486b97ef60f8553eeb2ec7643f9a7f34c67531 Mon Sep 17 00:00:00 2001 From: winskuo-quic Date: Fri, 6 Mar 2026 09:55:16 +0800 Subject: [PATCH 1/3] Qualcomm AI Engine Direct - Minimal Inference Runtime Core Requirement 1. Removed from_blob tensor creation 2. Compile and Linking Option optimization 3. Function visibility optimization 4. Expose Power Config to user --- CMakeLists.txt | 11 ++++ backends/qualcomm/CMakeLists.txt | 29 +++++++--- .../qualcomm/aot/python/PyQnnManagerAdaptor.h | 2 +- .../qualcomm/aot/wrappers/TensorWrapper.h | 2 +- backends/qualcomm/runtime/QnnExecuTorch.h | 10 +++- backends/qualcomm/runtime/QnnManager.cpp | 21 ++++--- backends/qualcomm/runtime/SharedBuffer.cpp | 4 +- .../backends/direct_mode/CMakeLists.txt | 1 + backends/qualcomm/tests/test_qnn_delegate.py | 56 +++++++++++++++++++ backends/qualcomm/utils/utils.py | 3 +- examples/qualcomm/utils.py | 15 ++++- 11 files changed, 128 insertions(+), 26 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6c2251f933b..ca28a7d39fd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -50,6 +50,17 @@ project(executorch) set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}) +# Hexagon toolchain with release build complains about code in third party +# libraries. 
+if(${CMAKE_SYSTEM_PROCESSOR} MATCHES Hexagon AND ${CMAKE_BUILD_TYPE} STREQUAL + "Release" +) + add_compile_options( + -Wno-error=format -Wno-error=implicit-int-conversion + -Wno-error=unused-variable -Wno-error=unused-function + ) +endif() + # --- ExecuTorch Version --- # Parse version from version.txt (single source of truth) file(READ "${EXECUTORCH_ROOT}/version.txt" ET_VERSION_STRING) diff --git a/backends/qualcomm/CMakeLists.txt b/backends/qualcomm/CMakeLists.txt index c1f96b1df14..6303114d2fd 100644 --- a/backends/qualcomm/CMakeLists.txt +++ b/backends/qualcomm/CMakeLists.txt @@ -77,11 +77,12 @@ if(${ANDROID}) find_library(android_log log) endif() -add_compile_options("-Wall" "-Werror" "-Wno-sign-compare") +add_compile_options("-Wall" "-Werror" "-fvisibility=hidden") add_compile_definitions(C10_USING_CUSTOM_GENERATED_MACROS) -# GNU emit wanring for ignored attributes Unfortunately, we use [[maybe_unused]] -# which can be ignored by GNU. So we make it a warning, not an error in GNU. +# GNU emits warning for ignored attributes Unfortunately, we use +# [[maybe_unused]] which can be ignored by GNU. So we make it a warning, not an +# error in GNU. if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") add_compile_options("-Wno-error=attributes") add_link_options("-flto=auto") @@ -89,10 +90,21 @@ endif() if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") # strip symbols - add_link_options("-s") + add_link_options(LINKER:-s,--gc-sections) + if(${CMAKE_SYSTEM_PROCESSOR} MATCHES Hexagon) + add_compile_options( + "-Os" + "-ffunction-sections" + "-fdata-sections" + "-frtti" + "-fno-exceptions" + "-fomit-frame-pointer" + "-fno-asynchronous-unwind-tables" + ) + else() - # --gc-sections is added by torch. 
- add_compile_options("-O3" "-ffunction-sections" "-fdata-sections" "-frtti") + add_compile_options("-O3" "-ffunction-sections" "-fdata-sections" "-frtti") + endif() endif() include_directories( @@ -230,9 +242,8 @@ target_link_libraries( qnn_schema shared_buffer qnn_dlc_manager ) target_link_libraries( - qnn_executorch_backend - PRIVATE qnn_executorch_header qnn_schema qnn_manager executorch_core - extension_tensor qnn_backend_options + qnn_executorch_backend PRIVATE qnn_executorch_header qnn_schema qnn_manager + executorch_core qnn_backend_options ) if(${CMAKE_SYSTEM_PROCESSOR} MATCHES Hexagon) diff --git a/backends/qualcomm/aot/python/PyQnnManagerAdaptor.h b/backends/qualcomm/aot/python/PyQnnManagerAdaptor.h index e1850bc4fa8..3ce7c233559 100644 --- a/backends/qualcomm/aot/python/PyQnnManagerAdaptor.h +++ b/backends/qualcomm/aot/python/PyQnnManagerAdaptor.h @@ -250,7 +250,7 @@ class PyQnnManager { std::vector>>& op_wrappers) { QnnExecuTorchContextBinary binary_info; - for (int i = 0; i < graph_names.size(); ++i) { + for (uint32_t i = 0; i < graph_names.size(); ++i) { if (qnn_manager_->Compile(graph_names[i], op_wrappers[i]) != executorch::runtime::Error::Ok) { QNN_EXECUTORCH_LOG_ERROR("Fail to compile QNN graph"); diff --git a/backends/qualcomm/aot/wrappers/TensorWrapper.h b/backends/qualcomm/aot/wrappers/TensorWrapper.h index d4a5a67347e..6f20a807820 100644 --- a/backends/qualcomm/aot/wrappers/TensorWrapper.h +++ b/backends/qualcomm/aot/wrappers/TensorWrapper.h @@ -76,7 +76,7 @@ class TensorWrapper { rank); return; } - for (int i = 0; i < rank; ++i) { + for (size_t i = 0; i < rank; ++i) { QNN_TENSOR_VER_PTR(tensor_)->dimensions[i] = dims[i]; } } diff --git a/backends/qualcomm/runtime/QnnExecuTorch.h b/backends/qualcomm/runtime/QnnExecuTorch.h index ccd02273c4f..a6fb9438190 100644 --- a/backends/qualcomm/runtime/QnnExecuTorch.h +++ b/backends/qualcomm/runtime/QnnExecuTorch.h @@ -63,14 +63,18 @@ struct CustomMemTensorInfo { /// alignment as 
MemoryAllocator::kDefaultAlignment. /// See runtime/core/memory_allocator.h. The function returns a valid pointer /// if allocation is successful. -void* QnnExecuTorchAllocCustomMem(size_t bytes, size_t alignment); +__attribute__((__visibility__("default"))) void* QnnExecuTorchAllocCustomMem( + size_t bytes, + size_t alignment); /// Add tensor to custom memory with custom type descriptor. Create memory /// handle to tensor wrapper during execution -void QnnExecuTorchAddCustomMemTensorAddr(void* tensor_addr, void* custom_mem); +__attribute__((__visibility__("default"))) void +QnnExecuTorchAddCustomMemTensorAddr(void* tensor_addr, void* custom_mem); /// Free the allocated shared memory. -void QnnExecuTorchFreeCustomMem(void* buffer_ptr); +__attribute__((__visibility__("default"))) void QnnExecuTorchFreeCustomMem( + void* buffer_ptr); #ifdef __cplusplus } diff --git a/backends/qualcomm/runtime/QnnManager.cpp b/backends/qualcomm/runtime/QnnManager.cpp index 05ae46360f7..4d86af4a622 100644 --- a/backends/qualcomm/runtime/QnnManager.cpp +++ b/backends/qualcomm/runtime/QnnManager.cpp @@ -12,7 +12,6 @@ #include #include #include -#include #include #include #include @@ -427,11 +426,17 @@ Error QnnManager::Execute( QNN_TENSOR_VER_PTR(output_tensor)->dimensions + QNN_TENSOR_VER_PTR(output_tensor)->rank); - auto dump_tensor = executorch::extension::from_blob( - QNN_TENSOR_VER_PTR(output_tensor)->clientBuf.data, - sizes, + std::vector stride_size(sizes.size(), 0); + // Avoid using from_blob as it significantly increases shared library + // size. 
+ executorch::aten::TensorImpl tensor_impl( qnn_dtype_to_scalar_type_[QNN_TENSOR_VER_PTR(output_tensor) - ->dataType]); + ->dataType], + sizes.size(), + sizes.data(), + QNN_TENSOR_VER_PTR(output_tensor)->clientBuf.data, + nullptr, + stride_size.data()); executorch::runtime::event_tracer_log_output_delegate< executorch::aten::Tensor>( @@ -439,7 +444,7 @@ Error QnnManager::Execute( QNN_TENSOR_VER_PTR(output_tensor)->name, /*delegate_debug_id=*/ static_cast(-1), - *dump_tensor); + executorch::aten::Tensor(&tensor_impl)); } } @@ -547,7 +552,7 @@ Error QnnManager::CompileDlc() { // Mapping memory address for the input and output of mutable buffer std::unordered_map mutable_buffer_id_to_memory_map; - for (int i = 0; i < graphInfo.numInputTensors; ++i) { + for (uint32_t i = 0; i < graphInfo.numInputTensors; ++i) { auto tw = CreateTensorWrapper(graphInfo.inputTensors[i]); tw->UpdateQnnTensorMeta(graphInfo.inputTensors[i]); @@ -560,7 +565,7 @@ Error QnnManager::CompileDlc() { } graph_inputs.push_back(tw); } - for (int i = 0; i < graphInfo.numOutputTensors; ++i) { + for (uint32_t i = 0; i < graphInfo.numOutputTensors; ++i) { auto tw = CreateTensorWrapper(graphInfo.outputTensors[i]); tw->UpdateQnnTensorMeta(graphInfo.outputTensors[i]); int mutable_buffer_id = ExtractMutableBufferNumber(tw->GetName()); diff --git a/backends/qualcomm/runtime/SharedBuffer.cpp b/backends/qualcomm/runtime/SharedBuffer.cpp index d79f8041932..8158bc5c6f2 100644 --- a/backends/qualcomm/runtime/SharedBuffer.cpp +++ b/backends/qualcomm/runtime/SharedBuffer.cpp @@ -21,7 +21,7 @@ std::size_t std::hash::operator()( hash_val ^= std::hash()(info.custom_mem); hash_val ^= std::hash()(info.pos); hash_val ^= std::hash()(info.tensor_bytes); - for (int i = 0; i < info.rank; ++i) { + for (size_t i = 0; i < info.rank; ++i) { hash_val ^= std::hash()(info.shape[i]); } hash_val ^= std::hash()(info.rank); @@ -36,7 +36,7 @@ bool operator==( (lhs.tensor_addr == rhs.tensor_addr && lhs.custom_mem == rhs.custom_mem && 
lhs.pos == rhs.pos && lhs.tensor_bytes == rhs.tensor_bytes && lhs.rank == rhs.rank && lhs.dtype == rhs.dtype); - for (int i = 0; i < lhs.rank; ++i) { + for (size_t i = 0; i < lhs.rank; ++i) { is_same &= lhs.shape[i] == rhs.shape[i]; } return is_same; diff --git a/backends/qualcomm/runtime/backends/direct_mode/CMakeLists.txt b/backends/qualcomm/runtime/backends/direct_mode/CMakeLists.txt index 137ff800f12..8beebddc343 100644 --- a/backends/qualcomm/runtime/backends/direct_mode/CMakeLists.txt +++ b/backends/qualcomm/runtime/backends/direct_mode/CMakeLists.txt @@ -45,3 +45,4 @@ target_link_libraries( ${HEXAGON_TOOLS_ROOT}/Tools/target/hexagon/lib/${DSP_VERSION}/G0/pic/libc++.so.1 ${HEXAGON_TOOLS_ROOT}/Tools/target/hexagon/lib/${DSP_VERSION}/G0/pic/libc++abi.so.1 ) +target_compile_options(qnn_executorch_skel PRIVATE "-fvisibility=default") diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 4e23f43c2ea..dd097dfab30 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -29,6 +29,7 @@ from executorch.backends.qualcomm.debugger.utils import generate_optrace from executorch.backends.qualcomm.serialization.qc_schema import ( QnnExecuTorchBackendType, + QnnExecuTorchHtpPerformanceMode, ) from executorch.backends.qualcomm.tests.utils import ( convert_pt2e, @@ -4847,6 +4848,33 @@ def setUp(self): saver=False, ) + def test_qnn_backend_compile_time_option_htp_performance(self): + backend_options = generate_htp_compiler_spec( + use_fp16=True, + htp_performance_mode=QnnExecuTorchHtpPerformanceMode.kHtpHighPowerSaver, + ) + TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=self.chipset_table[TestQNN.model], + backend_options=backend_options, + ) + module = SimpleModel() # noqa: F405 + sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) + + def output_callback(log_msg): + msg = log_msg.stdout + # Refer to HtpDevice.cpp for 
the following values + min_voltage = "coreVoltageCornerMin 80" + self.assertTrue(min_voltage in msg, f"Expecting '{min_voltage} ' in log") + + runtime_extra_commands = " --log_level 4" + self.lower_module_and_test_output( + module, + sample_input, + extra_cmds=runtime_extra_commands, + output_callback=partial(output_callback), + save_inference_speed=True, + ) + def test_qnn_backend_dump_intermediate_outputs_topk(self): TestQNN.dump_intermediate_outputs = True backend_options = generate_htp_compiler_spec(use_fp16=True) @@ -5436,6 +5464,34 @@ def setUp(self): saver=False, ) + def test_qnn_backend_compile_time_option_htp_performance(self): + backend_options = generate_htp_compiler_spec( + use_fp16=False, + htp_performance_mode=QnnExecuTorchHtpPerformanceMode.kHtpHighPowerSaver, + ) + TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=self.chipset_table[TestQNN.model], + backend_options=backend_options, + ) + module = SimpleModel() # noqa: F405 + sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) + module = self.get_qdq_module(module, sample_input) + + def output_callback(log_msg): + msg = log_msg.stdout + # Refer to HtpDevice.cpp for the following values + min_voltage = "coreVoltageCornerMin 80" + self.assertTrue(min_voltage in msg, f"Expecting '{min_voltage} ' in log") + + runtime_extra_commands = " --log_level 4" + self.lower_module_and_test_output( + module, + sample_input, + extra_cmds=runtime_extra_commands, + output_callback=partial(output_callback), + save_inference_speed=True, + ) + def test_qnn_backend_dump_intermediate_outputs_simple_model(self): TestQNN.dump_intermediate_outputs = True backend_options = generate_htp_compiler_spec(use_fp16=False) diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index d45ef294bba..c26bd9c6022 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -994,6 +994,7 @@ def generate_htp_compiler_spec( use_multi_contexts: bool 
= False, use_weight_sharing: bool = False, use_slc_allocator: bool = False, + htp_performance_mode: QnnExecuTorchHtpPerformanceMode = QnnExecuTorchHtpPerformanceMode.kHtpBurst, ) -> QnnExecuTorchBackendOptions: """ Helper function generating backend options for QNN HTP @@ -1025,7 +1026,7 @@ def generate_htp_compiler_spec( # This actually is not an option which can affect the compiled blob. # But we don't have other place to pass this option at execution stage. # TODO: enable voting mechanism in runtime and make this as an option - htp_options.performance_mode = QnnExecuTorchHtpPerformanceMode.kHtpBurst + htp_options.performance_mode = htp_performance_mode htp_options.use_multi_contexts = use_multi_contexts htp_options.use_weight_sharing = use_weight_sharing htp_options.use_dlbc = use_dlbc diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index d73f0863767..f8e768ef784 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -35,6 +35,7 @@ from executorch.backends.qualcomm.serialization.qc_schema import ( QcomChipset, QnnExecuTorchBackendType, + QnnExecuTorchHtpPerformanceMode, QnnExecuTorchOpPackageOptions, ) from executorch.backends.qualcomm.utils.constants import ( @@ -483,6 +484,7 @@ def build_executorch_binary( optrace=False, op_package_options: QnnExecuTorchOpPackageOptions = None, direct_mode_build_path=None, + htp_performance_mode: QnnExecuTorchHtpPerformanceMode = QnnExecuTorchHtpPerformanceMode.kHtpBurst, ): """ A function to generate an ExecuTorch binary for Qualcomm platforms. @@ -508,6 +510,8 @@ def build_executorch_binary( optrace (bool, optional): Enable optrace mode for performance analysis if set to True. op_package_options: Optional structure to specify op packages loaded and used by the backend. + direct_mode_build_path (string, optional): Path to build folder for direct mode. + htp_performance_mode (QnnExecuTorchHtpPerformanceMode, optional): Option to set the performance mode for htp backend. 
Returns: None: The function writes the output to a specified .pte file. @@ -517,7 +521,8 @@ def build_executorch_binary( backend_options = { QnnExecuTorchBackendType.kGpuBackend: generate_gpu_compiler_spec(), QnnExecuTorchBackendType.kHtpBackend: generate_htp_compiler_spec( - use_fp16=False if quant_dtype is not None else True + use_fp16=False if quant_dtype is not None else True, + htp_performance_mode=htp_performance_mode, ), }[backend] compile_spec = generate_qnn_executorch_compiler_spec( @@ -1038,6 +1043,14 @@ def setup_common_args_and_variables(): type=str, ) + parser.add_argument( + "--htp_performance_mode", + type=int, + choices=list(QnnExecuTorchHtpPerformanceMode), + help="Specify performance mode for htp from 0-8, default to burst(2). For more info, refer to qc_schema.py", + default=2, + ) + # QNN_SDK_ROOT might also be an argument, but it is used in various places. # So maybe it's fine to just use the environment. if "QNN_SDK_ROOT" not in os.environ: From 8abd7651d84830b1b2831066eb50eb896d1326bf Mon Sep 17 00:00:00 2001 From: Winston Kuo Date: Tue, 31 Mar 2026 14:25:11 +0800 Subject: [PATCH 2/3] Fix External CI --- .gitignore | 1 + CMakeLists.txt | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 4ddbb7c49ad..1d0b4882c06 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ cmake-out* cmake-out-android/ build-android/ build-x86/ +build-hexagon/ dist/ arm-scratch/ executorch.egg-info diff --git a/CMakeLists.txt b/CMakeLists.txt index ca28a7d39fd..3e2bea0d432 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -52,8 +52,8 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}) # Hexagon toolchain with release build complains about code in third party # libraries. 
-if(${CMAKE_SYSTEM_PROCESSOR} MATCHES Hexagon AND ${CMAKE_BUILD_TYPE} STREQUAL - "Release" +if("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "Hexagon" AND "${CMAKE_BUILD_TYPE}" + STREQUAL "Release" ) add_compile_options( -Wno-error=format -Wno-error=implicit-int-conversion From 3c9a652a3dbfe007a92412bdb3db596045419db6 Mon Sep 17 00:00:00 2001 From: Winston Kuo Date: Tue, 7 Apr 2026 10:34:14 +0800 Subject: [PATCH 3/3] Update stride computation --- backends/qualcomm/runtime/QnnManager.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/backends/qualcomm/runtime/QnnManager.cpp b/backends/qualcomm/runtime/QnnManager.cpp index 4d86af4a622..dafdee62824 100644 --- a/backends/qualcomm/runtime/QnnManager.cpp +++ b/backends/qualcomm/runtime/QnnManager.cpp @@ -426,7 +426,14 @@ Error QnnManager::Execute( QNN_TENSOR_VER_PTR(output_tensor)->dimensions + QNN_TENSOR_VER_PTR(output_tensor)->rank); - std::vector stride_size(sizes.size(), 0); + // Compute contiguous strides from sizes (e.g. [2,3,4] -> [12,4,1]). + std::vector stride_size(sizes.size()); + if (!sizes.empty()) { + stride_size.back() = 1; + for (int i = sizes.size() - 2; i >= 0; --i) { + stride_size[i] = stride_size[i + 1] * sizes[i + 1]; + } + } // Avoid using from_blob as it significantly increases shared library // size. executorch::aten::TensorImpl tensor_impl(