Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 150 additions & 87 deletions crates/wasm_split_cli/src/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -557,16 +557,20 @@ pub struct OutputExport<'a> {
pub index: u32,
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Copy)]
enum OutputFunctionKind {
Defined,
IndirectStub,
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct OutputFunction {
input_func_id: InputFuncId,
kind: OutputFunctionKind,
/// One function emitted into the function and code sections of an output module.
///
/// Every variant remembers the input function it derives from; the function
/// section uses that function's type for all three variants.
enum OutputFunction {
/// A function whose body is taken from the input module's defined function
/// `input_func_id`.
DefinedFromInput {
input_func_id: InputFuncId,
// Index of this function in the output module; function imports are counted first.
local_index: usize,
},
/// A generated function that forwards its parameters to `input_func_id` with a
/// `call_indirect`, using that function's index in the indirect function table.
IndirectCallShim {
input_func_id: InputFuncId,
// Index of the shim itself in the output module; function imports are counted first.
shim_index: usize,
},
/// A generated function that forwards its parameters with a direct `call` to the
/// output function that `input_func_id` maps to in this module.
LocalShim {
input_func_id: InputFuncId,
// Index of the shim itself in the output module; function imports are counted first.
shim_index: usize,
},
}

struct ModuleEmitState<'a> {
Expand All @@ -581,6 +585,7 @@ struct ModuleEmitState<'a> {
defined_functions: Vec<OutputFunction>,

dep_to_local_index: HashMap<DepNode, usize>,
local_shims: HashMap<InputFuncId, usize>,
}

impl RelocTarget for ModuleEmitState<'_> {
Expand Down Expand Up @@ -672,6 +677,7 @@ impl<'a> ModuleEmitState<'a> {
let mut exports = vec![];
let mut defined_functions = vec![];
let mut dep_to_local_index = HashMap::new();
let mut local_shims = HashMap::new();

for (func, func_import) in emit_state.input_module.imported_funcs.iter().enumerate() {
if !output_module_info
Expand All @@ -680,6 +686,10 @@ impl<'a> ModuleEmitState<'a> {
{
continue;
}
debug_assert!(
output_module_index == 0,
"expected a function import to happen in the main module"
);
let import = &emit_state.input_module.imports[func_import.import_id];
let local_func = num_func_imports;
dep_to_local_index.insert(DepNode::Function(func), local_func);
Expand All @@ -692,23 +702,29 @@ impl<'a> ModuleEmitState<'a> {
}

for dep in &output_module_info.included_symbols {
// define it
match dep {
DepNode::Function(func)
if *func >= emit_state.input_module.imported_funcs.len() =>
if let &dep @ DepNode::Function(input_func) = dep {
let is_import = input_func < emit_state.input_module.imported_funcs.len();
if !is_import {
let local_func = num_func_imports + defined_functions.len();
defined_functions.push(OutputFunction::DefinedFromInput {
input_func_id: input_func,
local_index: local_func,
});
dep_to_local_index.insert(dep, local_func);
}
if (is_import && program_info.shared_deps.contains(&dep))
|| program_info.needs_shim_in_main.contains(&input_func)
{
let local_func = num_func_imports + defined_functions.len();
dep_to_local_index.insert(DepNode::Function(*func), local_func);
defined_functions.push(OutputFunction {
input_func_id: *func,
kind: OutputFunctionKind::Defined,
defined_functions.push(OutputFunction::LocalShim {
input_func_id: input_func,
shim_index: local_func,
});
local_shims.insert(input_func, local_func);
tracing::trace!("Adding shim for {input_func} -> {local_func}");
}
// already processed in the previous loop
DepNode::Function(_) => {}
// definition is implicit in the generate methods
_ => {}
}
// for other types of symbols, the definition is implicit
}

let mut also_needs_indirect_table =
Expand All @@ -721,12 +737,12 @@ impl<'a> ModuleEmitState<'a> {
match used_shared {
DepNode::Function(func) => {
let local_func = num_func_imports + defined_functions.len();
dep_to_local_index.insert(DepNode::Function(*func), local_func);
also_needs_indirect_table = true; // Because we use it in the indirect stub
defined_functions.push(OutputFunction {
defined_functions.push(OutputFunction::IndirectCallShim {
input_func_id: *func,
kind: OutputFunctionKind::IndirectStub,
shim_index: local_func,
});
dep_to_local_index.insert(DepNode::Function(*func), local_func);
also_needs_indirect_table = true; // Because we use it in the indirect stub
}
dep @ DepNode::Global(_)
| dep @ DepNode::Table(_)
Expand Down Expand Up @@ -829,6 +845,7 @@ impl<'a> ModuleEmitState<'a> {
imports,
exports,
dep_to_local_index,
local_shims,
}
}

Expand Down Expand Up @@ -902,8 +919,15 @@ impl<'a> ModuleEmitState<'a> {

fn generate_function_section(&mut self) {
let mut section = wasm_encoder::FunctionSection::new();
for OutputFunction { input_func_id, .. } in self.defined_functions.iter() {
section.function(self.input_module.func_type_id(*input_func_id) as u32);
for defined_func in self.defined_functions.iter() {
let func_ty = match defined_func {
OutputFunction::DefinedFromInput { input_func_id, .. }
| OutputFunction::LocalShim { input_func_id, .. }
| OutputFunction::IndirectCallShim { input_func_id, .. } => {
self.input_module.func_type_id(*input_func_id)
}
};
section.function(func_ty as u32);
}
self.output_module.section(&section);
}
Expand Down Expand Up @@ -1026,14 +1050,20 @@ impl<'a> ModuleEmitState<'a> {
.map(|table_index| -> Result<u32> {
let input_func_id =
self.emit_state.indirect_functions.table_entries[table_index - 1];
let output_func_id = *self
.dep_to_local_index
.get(&DepNode::Function(input_func_id))
.ok_or_else(|| {
anyhow!(
"No output function corresponding to input function {input_func_id:?}"
)
})?;
// If we have a local shim, we insert that into the indirect function table instead
// of the import directly.
// The reason is that wasm-bindgen will rewrite the signature of externref imports.
// Having such a function in the IFT would lead to mismatching signatures. The call
// in the shim will get rewritten correctly (in the main module).
// Note: for calls (and other reloc targets), we instead prefer not going through the shim.
let shim_func_id = self.local_shims.get(&input_func_id);
let output_func_id = shim_func_id.or_else(|| {
self.dep_to_local_index
.get(&DepNode::Function(input_func_id))
});
let Some(&output_func_id) = output_func_id else {
bail!("No output function corresponding to input function {input_func_id:?}")
};
Ok(output_func_id as u32)
})
.collect::<Result<Vec<_>>>()?;
Expand Down Expand Up @@ -1088,22 +1118,38 @@ impl<'a> ModuleEmitState<'a> {
self.output_module.section(&section);
}

fn generate_indirect_stub(
fn generate_indirect_shim(
&self,
indirect_index: usize,
type_id: usize,
) -> wasm_encoder::Function {
let func_type = &self.input_module.types[type_id];
let mut func = wasm_encoder::Function::new([]);
let mut sink = func.instructions();
for (param_i, _param_type) in func_type.params().iter().enumerate() {
func.instruction(&wasm_encoder::Instruction::LocalGet(param_i as u32));
sink.local_get(param_i as u32);
}
func.instruction(&wasm_encoder::Instruction::I32Const(indirect_index as i32));
func.instruction(&wasm_encoder::Instruction::CallIndirect {
type_index: type_id as u32,
table_index: 0,
});
func.instruction(&wasm_encoder::Instruction::End);
sink.i32_const(indirect_index as i32);
// TODO: can optionally use return_call_indirect for tail call optimizations
sink.call_indirect(0, type_id as u32);
sink.end();
func
}

/// Builds the body of a shim that forwards straight to another function of this module.
///
/// The generated function pushes each parameter of the signature `type_id` in order
/// and then issues a direct `call` to `imported_func_id`.
///
/// * `imported_func_id` - output-module index of the function being wrapped.
/// * `type_id` - index of the function type shared by the shim and its target.
fn generate_local_shim(
    &self,
    imported_func_id: usize,
    type_id: usize,
) -> wasm_encoder::Function {
    let signature = &self.input_module.types[type_id];
    // The shim declares no locals of its own; it only re-pushes its parameters.
    let mut shim = wasm_encoder::Function::new([]);
    let mut body = shim.instructions();
    for (position, _) in signature.params().iter().enumerate() {
        body.local_get(position as u32);
    }
    // TODO: can optionally use return_call for tail call optimizations
    body.call(imported_func_id as u32);
    body.end();
    shim
}

Expand All @@ -1115,53 +1161,57 @@ impl<'a> ModuleEmitState<'a> {
);
let mut section = wasm_encoder::CodeSection::new();
for output_func in self.defined_functions.iter() {
match output_func.kind {
OutputFunctionKind::Defined
if self
.emit_state
.no_reloc_stubs
.contains(&output_func.input_func_id) =>
match *output_func {
OutputFunction::DefinedFromInput { input_func_id, .. }
if self.emit_state.no_reloc_stubs.contains(&input_func_id) =>
{
let body = self.input_module.defined_funcs
[output_func.input_func_id - self.input_module.imported_funcs.len()]
[input_func_id - self.input_module.imported_funcs.len()]
.body
.clone();

FuncReencoder {
dep_to_local_index: &self.dep_to_local_index,
func_id: output_func.input_func_id,
func_id: input_func_id,
}
.parse_function_body(&mut section, body)
.with_context(|| {
format!("re-encoding no-reloc func {}", output_func.input_func_id)
})?;
.with_context(|| format!("re-encoding no-reloc func {}", input_func_id))?;
}
OutputFunctionKind::Defined => {
OutputFunction::DefinedFromInput { input_func_id, .. } => {
let input_func = &self.input_module.defined_funcs
[output_func.input_func_id - self.input_module.imported_funcs.len()];
[input_func_id - self.input_module.imported_funcs.len()];
let relocated_def = self
.get_relocated_data(input_func.body.range())
.with_context(|| {
format!(
"when emitted definition of func[{}] in module {}",
output_func.input_func_id, self.output_module_index,
input_func_id, self.output_module_index,
)
})?;
section.raw(&relocated_def);
}
OutputFunctionKind::IndirectStub => {
OutputFunction::IndirectCallShim { input_func_id, .. } => {
let indirect_index = self
.emit_state
.indirect_functions
.function_table_index
.get(&output_func.input_func_id)
.get(&input_func_id)
.unwrap();
let function = self.generate_indirect_stub(
let function = self.generate_indirect_shim(
*indirect_index,
self.input_module.func_type_id(output_func.input_func_id),
self.input_module.func_type_id(input_func_id),
);
section.function(&function);
}
OutputFunction::LocalShim { input_func_id, .. } => {
let emitted_func_id = *self
.dep_to_local_index
.get(&DepNode::Function(input_func_id))
.expect("input to have a name in this output module");
let type_id = self.input_module.func_type_id(input_func_id);
let function = self.generate_local_shim(emitted_func_id, type_id);
section.function(&function);
}
}
}
self.output_module.section(&section);
Expand Down Expand Up @@ -1277,21 +1327,44 @@ impl<'a> ModuleEmitState<'a> {
let mut name_map = wasm_encoder::NameMap::new();
let mut locals_map = wasm_encoder::IndirectNameMap::new();
let mut labels_map = wasm_encoder::IndirectNameMap::new();
for OutputFunction { input_func_id, .. } in &self.defined_functions {
let Some(&output_func_id) = self
.dep_to_local_index
.get(&DepNode::Function(*input_func_id))
else {
continue;
};
if let Some(name) = self.input_module.names.functions.get(input_func_id) {
name_map.append(output_func_id as u32, name);
}
if let Some(name_map) = self.input_module.names.locals.get(input_func_id) {
locals_map.append(output_func_id as u32, &convert_name_map(name_map)?);
}
if let Some(name_map) = self.input_module.names.labels.get(input_func_id) {
labels_map.append(output_func_id as u32, &convert_name_map(name_map)?);
for output_function in &self.defined_functions {
match output_function {
OutputFunction::DefinedFromInput { input_func_id, .. } => {
let Some(&output_func_id) = self
.dep_to_local_index
.get(&DepNode::Function(*input_func_id))
else {
continue;
};
if let Some(&name) = self.input_module.names.functions.get(input_func_id) {
name_map.append(output_func_id as u32, name);
}
if let Some(name_map) = self.input_module.names.locals.get(input_func_id) {
locals_map.append(output_func_id as u32, &convert_name_map(name_map)?);
}
if let Some(name_map) = self.input_module.names.labels.get(input_func_id) {
labels_map.append(output_func_id as u32, &convert_name_map(name_map)?);
}
}
&OutputFunction::IndirectCallShim {
ref input_func_id,
shim_index: local_index,
..
} => {
if let Some(&name) = self.input_module.names.functions.get(input_func_id) {
name_map
.append(local_index as u32, &format!("{name} indirect_call_shim"));
}
}
&OutputFunction::LocalShim {
ref input_func_id,
shim_index: local_index,
..
} => {
if let Some(&name) = self.input_module.names.functions.get(input_func_id) {
name_map.append(local_index as u32, &format!("{name} local_shim"));
}
}
}
}
section.functions(&name_map);
Expand Down Expand Up @@ -1328,19 +1401,9 @@ impl<'a> ModuleEmitState<'a> {
fn generate_target_features_section(&mut self) {
for custom in self.input_module.custom_sections.iter() {
if custom.name == "target_features" {
// Another wasm-bindgen hack: To make sure reference-types is not detected, replace the feature string :)
let mut data: Vec<u8> = custom.data.into();
if self.is_main() {
// 0x0f is the length of the following string
let needle = b"+\x0freference-types";
if let Some(pos) = data.windows(needle.len()).position(|feat| feat == needle) {
data[pos..pos + needle.len()].copy_from_slice(b"+\x0fREFERENCE-TYPES");
}
}

self.output_module.section(&wasm_encoder::CustomSection {
name: custom.name.into(),
data: data.into(),
data: custom.data.into(),
});
}
}
Expand Down
Loading