diff --git a/crates/wasm_split_cli/src/emit.rs b/crates/wasm_split_cli/src/emit.rs index ade6116..cc9eb63 100644 --- a/crates/wasm_split_cli/src/emit.rs +++ b/crates/wasm_split_cli/src/emit.rs @@ -557,16 +557,20 @@ pub struct OutputExport<'a> { pub index: u32, } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Copy)] -enum OutputFunctionKind { - Defined, - IndirectStub, -} - #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] -struct OutputFunction { - input_func_id: InputFuncId, - kind: OutputFunctionKind, +enum OutputFunction { + DefinedFromInput { + input_func_id: InputFuncId, + local_index: usize, + }, + IndirectCallShim { + input_func_id: InputFuncId, + shim_index: usize, + }, + LocalShim { + input_func_id: InputFuncId, + shim_index: usize, + }, } struct ModuleEmitState<'a> { @@ -581,6 +585,7 @@ struct ModuleEmitState<'a> { defined_functions: Vec, dep_to_local_index: HashMap, + local_shims: HashMap, } impl RelocTarget for ModuleEmitState<'_> { @@ -672,6 +677,7 @@ impl<'a> ModuleEmitState<'a> { let mut exports = vec![]; let mut defined_functions = vec![]; let mut dep_to_local_index = HashMap::new(); + let mut local_shims = HashMap::new(); for (func, func_import) in emit_state.input_module.imported_funcs.iter().enumerate() { if !output_module_info @@ -680,6 +686,10 @@ impl<'a> ModuleEmitState<'a> { { continue; } + debug_assert!( + output_module_index == 0, + "expected a function import to happen in the main module" + ); let import = &emit_state.input_module.imports[func_import.import_id]; let local_func = num_func_imports; dep_to_local_index.insert(DepNode::Function(func), local_func); @@ -692,23 +702,29 @@ impl<'a> ModuleEmitState<'a> { } for dep in &output_module_info.included_symbols { - // define it - match dep { - DepNode::Function(func) - if *func >= emit_state.input_module.imported_funcs.len() => + if let &dep @ DepNode::Function(input_func) = dep { + let is_import = input_func < emit_state.input_module.imported_funcs.len(); + if !is_import { + 
let local_func = num_func_imports + defined_functions.len(); + defined_functions.push(OutputFunction::DefinedFromInput { + input_func_id: input_func, + local_index: local_func, + }); + dep_to_local_index.insert(dep, local_func); + } + if (is_import && program_info.shared_deps.contains(&dep)) + || program_info.needs_shim_in_main.contains(&input_func) { let local_func = num_func_imports + defined_functions.len(); - dep_to_local_index.insert(DepNode::Function(*func), local_func); - defined_functions.push(OutputFunction { - input_func_id: *func, - kind: OutputFunctionKind::Defined, + defined_functions.push(OutputFunction::LocalShim { + input_func_id: input_func, + shim_index: local_func, }); + local_shims.insert(input_func, local_func); + tracing::trace!("Adding shim for {input_func} -> {local_func}"); } - // already processed in the previous loop - DepNode::Function(_) => {} - // definition is implicit in the generate methods - _ => {} } + // for other types of symbols, the definition is implicit } let mut also_needs_indirect_table = @@ -721,12 +737,12 @@ impl<'a> ModuleEmitState<'a> { match used_shared { DepNode::Function(func) => { let local_func = num_func_imports + defined_functions.len(); - dep_to_local_index.insert(DepNode::Function(*func), local_func); - also_needs_indirect_table = true; // Because we use it in the indirect stub - defined_functions.push(OutputFunction { + defined_functions.push(OutputFunction::IndirectCallShim { input_func_id: *func, - kind: OutputFunctionKind::IndirectStub, + shim_index: local_func, }); + dep_to_local_index.insert(DepNode::Function(*func), local_func); + also_needs_indirect_table = true; // Because we use it in the indirect stub } dep @ DepNode::Global(_) | dep @ DepNode::Table(_) @@ -829,6 +845,7 @@ impl<'a> ModuleEmitState<'a> { imports, exports, dep_to_local_index, + local_shims, } } @@ -902,8 +919,15 @@ impl<'a> ModuleEmitState<'a> { fn generate_function_section(&mut self) { let mut section = 
wasm_encoder::FunctionSection::new(); - for OutputFunction { input_func_id, .. } in self.defined_functions.iter() { - section.function(self.input_module.func_type_id(*input_func_id) as u32); + for defined_func in self.defined_functions.iter() { + let func_ty = match defined_func { + OutputFunction::DefinedFromInput { input_func_id, .. } + | OutputFunction::LocalShim { input_func_id, .. } + | OutputFunction::IndirectCallShim { input_func_id, .. } => { + self.input_module.func_type_id(*input_func_id) + } + }; + section.function(func_ty as u32); } self.output_module.section(&section); } @@ -1026,14 +1050,20 @@ impl<'a> ModuleEmitState<'a> { .map(|table_index| -> Result { let input_func_id = self.emit_state.indirect_functions.table_entries[table_index - 1]; - let output_func_id = *self - .dep_to_local_index - .get(&DepNode::Function(input_func_id)) - .ok_or_else(|| { - anyhow!( - "No output function corresponding to input function {input_func_id:?}" - ) - })?; + // If we have a local shim, we insert that into the indirect function table instead + // of the import directly. + // The reason is that wasm-bindgen will rewrite the signature of externref imports. + // Having such a function in the IFT would lead to mismatching signatures. The call + // in the shim will get rewritten correctly (in the main module). + // Note: for calls (and other reloc targets), we instead prefer not going through the shim.
+ let shim_func_id = self.local_shims.get(&input_func_id); + let output_func_id = shim_func_id.or_else(|| { + self.dep_to_local_index + .get(&DepNode::Function(input_func_id)) + }); + let Some(&output_func_id) = output_func_id else { + bail!("No output function corresponding to input function {input_func_id:?}") + }; Ok(output_func_id as u32) }) .collect::>>()?; @@ -1088,22 +1118,38 @@ impl<'a> ModuleEmitState<'a> { self.output_module.section(&section); } - fn generate_indirect_stub( + fn generate_indirect_shim( &self, indirect_index: usize, type_id: usize, ) -> wasm_encoder::Function { let func_type = &self.input_module.types[type_id]; let mut func = wasm_encoder::Function::new([]); + let mut sink = func.instructions(); for (param_i, _param_type) in func_type.params().iter().enumerate() { - func.instruction(&wasm_encoder::Instruction::LocalGet(param_i as u32)); + sink.local_get(param_i as u32); } - func.instruction(&wasm_encoder::Instruction::I32Const(indirect_index as i32)); - func.instruction(&wasm_encoder::Instruction::CallIndirect { - type_index: type_id as u32, - table_index: 0, - }); - func.instruction(&wasm_encoder::Instruction::End); + sink.i32_const(indirect_index as i32); + // TODO: can optionally use return_call_indirect for tail call optimizations + sink.call_indirect(0, type_id as u32); + sink.end(); + func + } + + fn generate_local_shim( + &self, + imported_func_id: usize, + type_id: usize, + ) -> wasm_encoder::Function { + let func_type = &self.input_module.types[type_id]; + let mut func = wasm_encoder::Function::new([]); + let mut sink = func.instructions(); + for (param_i, _param_type) in func_type.params().iter().enumerate() { + sink.local_get(param_i as u32); + } + // TODO: can optionally use return_call for tail call optimizations + sink.call(imported_func_id as u32); + sink.end(); func } @@ -1115,53 +1161,57 @@ impl<'a> ModuleEmitState<'a> { ); let mut section = wasm_encoder::CodeSection::new(); for output_func in self.defined_functions.iter() { -
match output_func.kind { - OutputFunctionKind::Defined - if self - .emit_state - .no_reloc_stubs - .contains(&output_func.input_func_id) => + match *output_func { + OutputFunction::DefinedFromInput { input_func_id, .. } + if self.emit_state.no_reloc_stubs.contains(&input_func_id) => { let body = self.input_module.defined_funcs - [output_func.input_func_id - self.input_module.imported_funcs.len()] + [input_func_id - self.input_module.imported_funcs.len()] .body .clone(); FuncReencoder { dep_to_local_index: &self.dep_to_local_index, - func_id: output_func.input_func_id, + func_id: input_func_id, } .parse_function_body(&mut section, body) - .with_context(|| { - format!("re-encoding no-reloc func {}", output_func.input_func_id) - })?; + .with_context(|| format!("re-encoding no-reloc func {}", input_func_id))?; } - OutputFunctionKind::Defined => { + OutputFunction::DefinedFromInput { input_func_id, .. } => { let input_func = &self.input_module.defined_funcs - [output_func.input_func_id - self.input_module.imported_funcs.len()]; + [input_func_id - self.input_module.imported_funcs.len()]; let relocated_def = self .get_relocated_data(input_func.body.range()) .with_context(|| { format!( "when emitted definition of func[{}] in module {}", - output_func.input_func_id, self.output_module_index, + input_func_id, self.output_module_index, ) })?; section.raw(&relocated_def); } - OutputFunctionKind::IndirectStub => { + OutputFunction::IndirectCallShim { input_func_id, .. } => { let indirect_index = self .emit_state .indirect_functions .function_table_index - .get(&output_func.input_func_id) + .get(&input_func_id) .unwrap(); - let function = self.generate_indirect_stub( + let function = self.generate_indirect_shim( *indirect_index, - self.input_module.func_type_id(output_func.input_func_id), + self.input_module.func_type_id(input_func_id), ); section.function(&function); } + OutputFunction::LocalShim { input_func_id, .. 
} => { + let emitted_func_id = *self + .dep_to_local_index + .get(&DepNode::Function(input_func_id)) + .expect("input to have a name in this output module"); + let type_id = self.input_module.func_type_id(input_func_id); + let function = self.generate_local_shim(emitted_func_id, type_id); + section.function(&function); + } } } self.output_module.section(&section); @@ -1277,21 +1327,44 @@ impl<'a> ModuleEmitState<'a> { let mut name_map = wasm_encoder::NameMap::new(); let mut locals_map = wasm_encoder::IndirectNameMap::new(); let mut labels_map = wasm_encoder::IndirectNameMap::new(); - for OutputFunction { input_func_id, .. } in &self.defined_functions { - let Some(&output_func_id) = self - .dep_to_local_index - .get(&DepNode::Function(*input_func_id)) - else { - continue; - }; - if let Some(name) = self.input_module.names.functions.get(input_func_id) { - name_map.append(output_func_id as u32, name); - } - if let Some(name_map) = self.input_module.names.locals.get(input_func_id) { - locals_map.append(output_func_id as u32, &convert_name_map(name_map)?); - } - if let Some(name_map) = self.input_module.names.labels.get(input_func_id) { - labels_map.append(output_func_id as u32, &convert_name_map(name_map)?); + for output_function in &self.defined_functions { + match output_function { + OutputFunction::DefinedFromInput { input_func_id, ..
} => { + let Some(&output_func_id) = self + .dep_to_local_index + .get(&DepNode::Function(*input_func_id)) + else { + continue; + }; + if let Some(&name) = self.input_module.names.functions.get(input_func_id) { + name_map.append(output_func_id as u32, name); + } + if let Some(name_map) = self.input_module.names.locals.get(input_func_id) { + locals_map.append(output_func_id as u32, &convert_name_map(name_map)?); + } + if let Some(name_map) = self.input_module.names.labels.get(input_func_id) { + labels_map.append(output_func_id as u32, &convert_name_map(name_map)?); + } + } + &OutputFunction::IndirectCallShim { + ref input_func_id, + shim_index: local_index, + .. + } => { + if let Some(&name) = self.input_module.names.functions.get(input_func_id) { + name_map + .append(local_index as u32, &format!("{name} indirect_call_shim")); + } + } + &OutputFunction::LocalShim { + ref input_func_id, + shim_index: local_index, + .. + } => { + if let Some(&name) = self.input_module.names.functions.get(input_func_id) { + name_map.append(local_index as u32, &format!("{name} local_shim")); + } + } } } section.functions(&name_map); @@ -1328,19 +1401,9 @@ impl<'a> ModuleEmitState<'a> { fn generate_target_features_section(&mut self) { for custom in self.input_module.custom_sections.iter() { if custom.name == "target_features" { - // Another wasm-bindgen hack: To make sure reference-types is not detected, replace the feature string :) - let mut data: Vec = custom.data.into(); - if self.is_main() { - // 0x0f is the length of the following string - let needle = b"+\x0freference-types"; - if let Some(pos) = data.windows(needle.len()).position(|feat| feat == needle) { - data[pos..pos + needle.len()].copy_from_slice(b"+\x0fREFERENCE-TYPES"); - } - } - self.output_module.section(&wasm_encoder::CustomSection { name: custom.name.into(), - data: data.into(), + data: custom.data.into(), }); } } diff --git a/crates/wasm_split_cli/src/split_point.rs b/crates/wasm_split_cli/src/split_point.rs index 
15041f5..73dfd55 100644 --- a/crates/wasm_split_cli/src/split_point.rs +++ b/crates/wasm_split_cli/src/split_point.rs @@ -356,6 +356,14 @@ pub struct SplitProgramInfo { /// the input deterministically. Options can (and should) influence this if they lead to different /// output modules. pub canary_export_name: String, + /// Related to [wasm-bindgen hack]. + /// wasm-bindgen replaces a part of cast-functions with an import. This import then gets its + /// signature transformed in some cases in the externref pass. Hence, two things need to happen: + /// - the replaced function must be in the main module (done by coloring it and its dependencies + /// with the main module). + /// - an additional shim function needs to be used in the indirect_function_table instead of it, + /// because other modules expect the original signature. + pub needs_shim_in_main: HashSet, } impl SplitProgramInfo { @@ -475,6 +483,7 @@ pub fn compute_split_modules( graph_analysis.explore(roots, SplitModuleIdentifier::Split(module_name.clone())); } + let mut program_info = SplitProgramInfo::default(); // We "paint" each dependency with the modules it must be loaded in, then put them into that module // accordingly. let mut painter = graph_analysis.into_painter(); @@ -488,10 +497,19 @@ pub fn compute_split_modules( .or_default() .included_symbols .insert(node); + let DepNode::Function(func_id) = node else { + continue; + }; + if color == &SplitModuleIdentifier::Main + && dep_graph + .get(&node) + .is_some_and(|deps| !deps.is_disjoint(&wbg_rooting_deps)) + { + program_info.needs_shim_in_main.insert(func_id); + } } // Now, check for each module which of its dependencies it needs to import from some other module. - let mut program_info = SplitProgramInfo::default(); for out_module in split_module_contents.values_mut() { let needed_symbols = out_module .included_symbols