From 754c30cfcdc0bb196aff8cec59186b66535d0c22 Mon Sep 17 00:00:00 2001 From: Martin Molzer Date: Sun, 28 Jun 2026 14:38:56 +0200 Subject: [PATCH] disable reference-types hack by shimming a few more functions, we can finally work around the reference-types transformations in wasm-bindgen. Specifically, casts are transformed into imports, and the signature of imports are transformed by externref passes. We must make sure that these imports do not land in the indirect function table, because only the main module is will also get fixed for these new signatures. --- crates/wasm_split_cli/src/emit.rs | 237 ++++++++++++++--------- crates/wasm_split_cli/src/split_point.rs | 20 +- 2 files changed, 169 insertions(+), 88 deletions(-) diff --git a/crates/wasm_split_cli/src/emit.rs b/crates/wasm_split_cli/src/emit.rs index ade6116..cc9eb63 100644 --- a/crates/wasm_split_cli/src/emit.rs +++ b/crates/wasm_split_cli/src/emit.rs @@ -557,16 +557,20 @@ pub struct OutputExport<'a> { pub index: u32, } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Copy)] -enum OutputFunctionKind { - Defined, - IndirectStub, -} - #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] -struct OutputFunction { - input_func_id: InputFuncId, - kind: OutputFunctionKind, +enum OutputFunction { + DefinedFromInput { + input_func_id: InputFuncId, + local_index: usize, + }, + IndirectCallShim { + input_func_id: InputFuncId, + shim_index: usize, + }, + LocalShim { + input_func_id: InputFuncId, + shim_index: usize, + }, } struct ModuleEmitState<'a> { @@ -581,6 +585,7 @@ struct ModuleEmitState<'a> { defined_functions: Vec, dep_to_local_index: HashMap, + local_shims: HashMap, } impl RelocTarget for ModuleEmitState<'_> { @@ -672,6 +677,7 @@ impl<'a> ModuleEmitState<'a> { let mut exports = vec![]; let mut defined_functions = vec![]; let mut dep_to_local_index = HashMap::new(); + let mut local_shims = HashMap::new(); for (func, func_import) in emit_state.input_module.imported_funcs.iter().enumerate() { if !output_module_info @@ -680,6 +686,10 @@ impl<'a> ModuleEmitState<'a> { { continue; } + debug_assert!( + output_module_index == 0, + "expected a function import to happen in the main module" + ); let import = &emit_state.input_module.imports[func_import.import_id]; let local_func = num_func_imports; dep_to_local_index.insert(DepNode::Function(func), local_func); @@ -692,23 +702,29 @@ impl<'a> ModuleEmitState<'a> { } for dep in &output_module_info.included_symbols { - // define it - match dep { - DepNode::Function(func) - if *func >= emit_state.input_module.imported_funcs.len() => + if let &dep @ DepNode::Function(input_func) = dep { + let is_import = input_func < emit_state.input_module.imported_funcs.len(); + if !is_import { + let local_func = num_func_imports + defined_functions.len(); + defined_functions.push(OutputFunction::DefinedFromInput { + input_func_id: input_func, + local_index: local_func, + }); + dep_to_local_index.insert(dep, local_func); + } + if (is_import && program_info.shared_deps.contains(&dep)) + || program_info.needs_shim_in_main.contains(&input_func) { let local_func = num_func_imports + defined_functions.len(); - dep_to_local_index.insert(DepNode::Function(*func), local_func); - defined_functions.push(OutputFunction { - input_func_id: *func, - kind: OutputFunctionKind::Defined, + defined_functions.push(OutputFunction::LocalShim { + input_func_id: input_func, + shim_index: local_func, }); + local_shims.insert(input_func, local_func); + tracing::trace!("Adding shim for {input_func} -> {local_func}"); } - // already processed in the previous loop - DepNode::Function(_) => {} - // definition is implicit in the generate methods - _ => {} } + // for other types of symbols, the definition is implicit } let mut also_needs_indirect_table = @@ -721,12 +737,12 @@ impl<'a> ModuleEmitState<'a> { match used_shared { DepNode::Function(func) => { let local_func = num_func_imports + defined_functions.len(); - dep_to_local_index.insert(DepNode::Function(*func), local_func); - also_needs_indirect_table = true; // Because we use it in the indirect stub - defined_functions.push(OutputFunction { + defined_functions.push(OutputFunction::IndirectCallShim { input_func_id: *func, - kind: OutputFunctionKind::IndirectStub, + shim_index: local_func, }); + dep_to_local_index.insert(DepNode::Function(*func), local_func); + also_needs_indirect_table = true; // Because we use it in the indirect stub } dep @ DepNode::Global(_) | dep @ DepNode::Table(_) @@ -829,6 +845,7 @@ impl<'a> ModuleEmitState<'a> { imports, exports, dep_to_local_index, + local_shims, } } @@ -902,8 +919,15 @@ impl<'a> ModuleEmitState<'a> { fn generate_function_section(&mut self) { let mut section = wasm_encoder::FunctionSection::new(); - for OutputFunction { input_func_id, .. } in self.defined_functions.iter() { - section.function(self.input_module.func_type_id(*input_func_id) as u32); + for defined_func in self.defined_functions.iter() { + let func_ty = match defined_func { + OutputFunction::DefinedFromInput { input_func_id, .. } + | OutputFunction::LocalShim { input_func_id, .. } + | OutputFunction::IndirectCallShim { input_func_id, .. } => { + self.input_module.func_type_id(*input_func_id) + } + }; + section.function(func_ty as u32); } self.output_module.section(§ion); } @@ -1026,14 +1050,20 @@ impl<'a> ModuleEmitState<'a> { .map(|table_index| -> Result { let input_func_id = self.emit_state.indirect_functions.table_entries[table_index - 1]; - let output_func_id = *self - .dep_to_local_index - .get(&DepNode::Function(input_func_id)) - .ok_or_else(|| { - anyhow!( - "No output function corresponding to input function {input_func_id:?}" - ) - })?; + // If we have a local shim, we insert that into the indirect function table instead + // of the import directly. + // The reason is that wasm-bindgen will rewrite the signature of externref imports. + // Having such a function in the IFT would lead to mismatching signatures. The call + // in the shim will get rewritten correctly (in the main module). + // Note: for calls (and other reloc targets), we instead prefer not going through the shim. + let shim_func_id = self.local_shims.get(&input_func_id); + let output_func_id = shim_func_id.or_else(|| { + self.dep_to_local_index + .get(&DepNode::Function(input_func_id)) + }); + let Some(&output_func_id) = output_func_id else { + bail!("No output function corresponding to input function {input_func_id:?}") + }; Ok(output_func_id as u32) }) .collect::>>()?; @@ -1088,22 +1118,38 @@ impl<'a> ModuleEmitState<'a> { self.output_module.section(§ion); } - fn generate_indirect_stub( + fn generate_indirect_shim( &self, indirect_index: usize, type_id: usize, ) -> wasm_encoder::Function { let func_type = &self.input_module.types[type_id]; let mut func = wasm_encoder::Function::new([]); + let mut sink = func.instructions(); for (param_i, _param_type) in func_type.params().iter().enumerate() { - func.instruction(&wasm_encoder::Instruction::LocalGet(param_i as u32)); + sink.local_get(param_i as u32); } - func.instruction(&wasm_encoder::Instruction::I32Const(indirect_index as i32)); - func.instruction(&wasm_encoder::Instruction::CallIndirect { - type_index: type_id as u32, - table_index: 0, - }); - func.instruction(&wasm_encoder::Instruction::End); + sink.i32_const(indirect_index as i32); + // TODO: can optionally use return_call_indirect for tail call optimizations + sink.call_indirect(0, type_id as u32); + sink.end(); + func + } + + fn generate_local_shim( + &self, + imported_func_id: usize, + type_id: usize, + ) -> wasm_encoder::Function { + let func_type = &self.input_module.types[type_id]; + let mut func = wasm_encoder::Function::new([]); + let mut sink = func.instructions(); + for (param_i, _param_type) in func_type.params().iter().enumerate() { + sink.local_get(param_i as u32); + } + // TODO: can optionally use return_call for tail call optimizations + sink.call(imported_func_id as u32); + sink.end(); func } @@ -1115,53 +1161,57 @@ impl<'a> ModuleEmitState<'a> { ); let mut section = wasm_encoder::CodeSection::new(); for output_func in self.defined_functions.iter() { - match output_func.kind { - OutputFunctionKind::Defined - if self - .emit_state - .no_reloc_stubs - .contains(&output_func.input_func_id) => + match *output_func { + OutputFunction::DefinedFromInput { input_func_id, .. } + if self.emit_state.no_reloc_stubs.contains(&input_func_id) => { let body = self.input_module.defined_funcs - [output_func.input_func_id - self.input_module.imported_funcs.len()] + [input_func_id - self.input_module.imported_funcs.len()] .body .clone(); FuncReencoder { dep_to_local_index: &self.dep_to_local_index, - func_id: output_func.input_func_id, + func_id: input_func_id, } .parse_function_body(&mut section, body) - .with_context(|| { - format!("re-encoding no-reloc func {}", output_func.input_func_id) - })?; + .with_context(|| format!("re-encoding no-reloc func {}", input_func_id))?; } - OutputFunctionKind::Defined => { + OutputFunction::DefinedFromInput { input_func_id, .. } => { let input_func = &self.input_module.defined_funcs - [output_func.input_func_id - self.input_module.imported_funcs.len()]; + [input_func_id - self.input_module.imported_funcs.len()]; let relocated_def = self .get_relocated_data(input_func.body.range()) .with_context(|| { format!( "when emitted definition of func[{}] in module {}", - output_func.input_func_id, self.output_module_index, + input_func_id, self.output_module_index, ) })?; section.raw(&relocated_def); } - OutputFunctionKind::IndirectStub => { + OutputFunction::IndirectCallShim { input_func_id, .. } => { let indirect_index = self .emit_state .indirect_functions .function_table_index - .get(&output_func.input_func_id) + .get(&input_func_id) .unwrap(); - let function = self.generate_indirect_stub( + let function = self.generate_indirect_shim( *indirect_index, - self.input_module.func_type_id(output_func.input_func_id), + self.input_module.func_type_id(input_func_id), ); section.function(&function); } + OutputFunction::LocalShim { input_func_id, .. } => { + let emitted_func_id = *self + .dep_to_local_index + .get(&DepNode::Function(input_func_id)) + .expect("input to have a name in this output module"); + let type_id = self.input_module.func_type_id(input_func_id); + let function = self.generate_local_shim(emitted_func_id, type_id); + section.function(&function); + } } } self.output_module.section(§ion); @@ -1277,21 +1327,44 @@ impl<'a> ModuleEmitState<'a> { let mut name_map = wasm_encoder::NameMap::new(); let mut locals_map = wasm_encoder::IndirectNameMap::new(); let mut labels_map = wasm_encoder::IndirectNameMap::new(); - for OutputFunction { input_func_id, .. } in &self.defined_functions { - let Some(&output_func_id) = self - .dep_to_local_index - .get(&DepNode::Function(*input_func_id)) - else { - continue; - }; - if let Some(name) = self.input_module.names.functions.get(input_func_id) { - name_map.append(output_func_id as u32, name); - } - if let Some(name_map) = self.input_module.names.locals.get(input_func_id) { - locals_map.append(output_func_id as u32, &convert_name_map(name_map)?); - } - if let Some(name_map) = self.input_module.names.labels.get(input_func_id) { - labels_map.append(output_func_id as u32, &convert_name_map(name_map)?); + for output_function in &self.defined_functions { + match output_function { + OutputFunction::DefinedFromInput { input_func_id, .. } => { + let Some(&output_func_id) = self + .dep_to_local_index + .get(&DepNode::Function(*input_func_id)) + else { + continue; + }; + if let Some(&name) = self.input_module.names.functions.get(input_func_id) { + name_map.append(output_func_id as u32, name); + } + if let Some(name_map) = self.input_module.names.locals.get(input_func_id) { + locals_map.append(output_func_id as u32, &convert_name_map(name_map)?); + } + if let Some(name_map) = self.input_module.names.labels.get(input_func_id) { + labels_map.append(output_func_id as u32, &convert_name_map(name_map)?); + } + } + &OutputFunction::IndirectCallShim { + ref input_func_id, + shim_index: local_index, + .. + } => { + if let Some(&name) = self.input_module.names.functions.get(input_func_id) { + name_map + .append(local_index as u32, &format!("{name} indirect_call_shim")); + } + } + &OutputFunction::LocalShim { + ref input_func_id, + shim_index: local_index, + .. + } => { + if let Some(&name) = self.input_module.names.functions.get(input_func_id) { + name_map.append(local_index as u32, &format!("{name} local_shim")); + } + } } } section.functions(&name_map); @@ -1328,19 +1401,9 @@ impl<'a> ModuleEmitState<'a> { fn generate_target_features_section(&mut self) { for custom in self.input_module.custom_sections.iter() { if custom.name == "target_features" { - // Another wasm-bindgen hack: To make sure reference-types is not detected, replace the feature string :) - let mut data: Vec = custom.data.into(); - if self.is_main() { - // 0x0f is the length of the following string - let needle = b"+\x0freference-types"; - if let Some(pos) = data.windows(needle.len()).position(|feat| feat == needle) { - data[pos..pos + needle.len()].copy_from_slice(b"+\x0fREFERENCE-TYPES"); - } - } - self.output_module.section(&wasm_encoder::CustomSection { name: custom.name.into(), - data: data.into(), + data: custom.data.into(), }); } } diff --git a/crates/wasm_split_cli/src/split_point.rs b/crates/wasm_split_cli/src/split_point.rs index 15041f5..73dfd55 100644 --- a/crates/wasm_split_cli/src/split_point.rs +++ b/crates/wasm_split_cli/src/split_point.rs @@ -356,6 +356,14 @@ pub struct SplitProgramInfo { /// the input deterministically. Options can (and should) influence this if they lead to different /// output modules. pub canary_export_name: String, + /// Related to [wasm-bindgen hack]. + /// wasm-bindgen replaces a part of cast-functions with an import. This import then gets its + /// signature transformed in some cases in the externref pass. Hence, two things need to happen: + /// - the replaced function must be in the main module (done by coloring it and its dependencies + /// with the main module). + /// - an additional shim function needs to be used in the indirect_function_table instead of it, + /// because other modules expect the original signature. + pub needs_shim_in_main: HashSet, } impl SplitProgramInfo { @@ -475,6 +483,7 @@ pub fn compute_split_modules( graph_analysis.explore(roots, SplitModuleIdentifier::Split(module_name.clone())); } + let mut program_info = SplitProgramInfo::default(); // We "paint" each dependency with the modules it must be loaded in, then put them into that module // accordingly. let mut painter = graph_analysis.into_painter(); @@ -488,10 +497,19 @@ pub fn compute_split_modules( .or_default() .included_symbols .insert(node); + let DepNode::Function(func_id) = node else { + continue; + }; + if color == &SplitModuleIdentifier::Main + && dep_graph + .get(&node) + .is_some_and(|deps| !deps.is_disjoint(&wbg_rooting_deps)) + { + program_info.needs_shim_in_main.insert(func_id); + } } // Now, check for each module which of its dependencies it needs to import from some other module. - let mut program_info = SplitProgramInfo::default(); for out_module in split_module_contents.values_mut() { let needed_symbols = out_module .included_symbols