Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions import-test.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
(module
;; CHECK: (type $t (func (param i32)))
(type $t (func (param i32)))

;; CHECK: (import "" "" (func $imported-func (type $t) (param i32)))
;; (import "" "" (func $imported-func (type $t)))
(import "" "" (func $imported-func (type $t)))

(elem declare $imported-func)

;; CHECK: (func $nop (type $t) (param $0 i32)
;; CHECK-NEXT: (nop)
;; CHECK-NEXT: )
(func $nop (param i32)
)

;; CHECK: (func $indirect-calls (type $1) (param $ref (ref $t))
;; CHECK-NEXT: (call_ref $t
;; CHECK-NEXT: (i32.const 1)
;; CHECK-NEXT: (local.get $ref)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
(func $indirect-calls (param $ref (ref $t))
(call_ref $t (i32.const 1) (local.get $ref))
)

;; CHECK: (func $f (type $1) (param $ref (ref $t))
;; CHECK-NEXT: (nop)
;; CHECK-NEXT: )
(func $f (param $ref (ref $t))
;; $indirect-calls might end up calling an imported function,
;; so we don't know anything about effects here
(call $indirect-calls (local.get $ref))
)
)
2 changes: 1 addition & 1 deletion src/ir/module-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ template<typename T> inline void iterModuleItems(Module& wasm, T visitor) {
template<typename K, typename V> using DefaultMap = std::map<K, V>;
template<typename T,
Mutability Mut = Immutable,
template<typename, typename> class MapT = DefaultMap>
template<typename, typename, typename...> class MapT = DefaultMap>
struct ParallelFunctionAnalysis {
Module& wasm;

Expand Down
13 changes: 11 additions & 2 deletions src/ir/subtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,11 @@ struct SubTypes {
// false, we stop. Returns the last value returned to it, that is, returns
// true if we did not stop early, and false if we did.
template<typename F>
bool iterSubTypes(HeapType type, Index depth, F func) const {
bool iterSubTypes(HeapType type, Index depth, F func) const
requires requires(F func, HeapType subtype, Index depth) {
{ func(subtype, depth) } -> std::same_as<bool>;
}
{
// Start by traversing the type itself.
if (!func(type, 0)) {
return false;
Expand Down Expand Up @@ -219,7 +223,12 @@ struct SubTypes {
}

// As above, but iterate to the maximum depth.
template<typename F> bool iterSubTypes(HeapType type, F func) const {
template<typename F>
bool iterSubTypes(HeapType type, F func) const
requires requires(F func, HeapType subtype, Index depth) {
{ func(subtype, depth) } -> std::same_as<bool>;
}
{
return iterSubTypes(type, std::numeric_limits<Index>::max(), func);
}

Expand Down
158 changes: 124 additions & 34 deletions src/passes/GlobalEffects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
*/

//
// Handle the computation of global effects. The effects are stored on the
// PassOptions structure; see more details there.
// Handle the computation of global effects. The effects are stored on
// Function::effects; see more details there.
//

#include "ir/effects.h"
#include "ir/element-utils.h"
#include "ir/module-utils.h"
#include "ir/subtypes.h"
#include "ir/table-utils.h"
#include "pass.h"
#include "support/unique_deferring_queue.h"
#include "wasm.h"
Expand All @@ -39,12 +42,15 @@ struct FuncInfo {

// Directly-called functions from this function.
std::unordered_set<Name> calledFunctions;

// Types that are targets of indirect calls.
std::unordered_set<HeapType> indirectCalledTypes;
};

std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
const PassOptions& passOptions) {
ModuleUtils::ParallelFunctionAnalysis<FuncInfo> analysis(
module, [&](Function* func, FuncInfo& funcInfo) {
std::unordered_map<Function*, FuncInfo>
analyzeFuncs(Module& module, const PassOptions& passOptions) {
ModuleUtils::ParallelFunctionAnalysis<FuncInfo, Immutable, std::unordered_map>
analysis(module, [&](Function* func, FuncInfo& funcInfo) {
if (func->imported()) {
// Imports can do anything, so we need to assume the worst anyhow,
// which is the same as not specifying any effects for them in the
Expand Down Expand Up @@ -84,11 +90,21 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
// Note the direct call.
funcInfo.calledFunctions.insert(call->target);
} else if (effects.calls) {
// This is an indirect call of some sort, so we must assume the
// worst. To do so, clear the effects, which indicates nothing
// is known (so anything is possible).
// TODO: We could group effects by function type etc.
funcInfo.effects = UnknownEffects;
if (!options.closedWorld) {
funcInfo.effects = UnknownEffects;
return;
}

HeapType type;
if (auto* callRef = curr->dynCast<CallRef>()) {
type = callRef->target->type.getHeapType();
} else if (auto* callIndirect = curr->dynCast<CallIndirect>()) {
type = callIndirect->heapType;
} else {
assert("Unexpected type of call");
}

funcInfo.indirectCalledTypes.insert(type);
} else {
// No call here, but update throwing if we see it. (Only do so,
// however, if we have effects; if we cleared it - see before -
Expand All @@ -107,48 +123,124 @@ std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
return std::move(analysis.map);
}

using CallGraphNode = std::variant<Name, HeapType>;

// Build a call graph for indirect and direct calls.
// key (callee) -> value (caller)
// Name -> Name : callee is called directly by caller
// Name -> HeapType : callee is a potential target of a virtual call
// with this HeapType HeapType -> Name : callee is indirectly called by
// caller HeapType -> HeapType : callee is a subtype of caller If we're
// running in an open world, we only include Name -> Name edges.
std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>
buildReverseCallGraph(Module& module,
const std::unordered_map<Function*, FuncInfo>& funcInfos,
bool closedWorld) {
// callee : caller
std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>> callers;

if (!closedWorld) {
for (const auto& [func, info] : funcInfos) {
// Name -> Name for direct calls
for (const auto& callee : info.calledFunctions) {
callers[callee].insert(func->name);
}
}

return callers;
}

std::unordered_set<HeapType> allIndirectCalledTypes;

for (const auto& [func, info] : funcInfos) {
// Name -> Name for direct calls
for (const auto& callee : info.calledFunctions) {
callers[callee].insert(func->name);
}

// HeapType -> Name for indirect calls
for (const auto& calleeType : info.indirectCalledTypes) {
callers[calleeType].insert(func->name);
}

// Name -> HeapType for function types
// TODO: only look at functions that are addressable
// i.e. appear in a (ref.func) or are exported
callers[func->name].insert(func->type.getHeapType());

allIndirectCalledTypes.insert(func->type.getHeapType());
}

SubTypes subtypes(module);
for (auto type : allIndirectCalledTypes) {
subtypes.iterSubTypes(type, [&callers, type](HeapType sub, Index _) {
// HeapType -> HeapType
// A subtype is a 'callee' of its supertype.
// Supertypes need to inherit effects from their subtypes since they may
// be called via a ref to the subtype.
callers[sub].insert(type);
return true;
});
}

return callers;
}

// Propagate effects from callees to callers transitively
// e.g. if A -> B -> C (A calls B which calls C)
// Then B inherits effects from C and A inherits effects from both B and C.
void propagateEffects(
const Module& module,
const std::unordered_map<Name, std::unordered_set<Name>>& reverseCallGraph,
std::map<Function*, FuncInfo>& funcInfos) {
const std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>&
reverseCallGraph,
std::unordered_map<Function*, FuncInfo>& funcInfos) {

UniqueNonrepeatingDeferredQueue<std::pair<Name, Name>> work;
using CallGraphEdge = std::pair<CallGraphNode, CallGraphNode>;
UniqueNonrepeatingDeferredQueue<CallGraphEdge> work;

for (const auto& [callee, callers] : reverseCallGraph) {
// We only care about roots that will lead to a Name -> Name connection
// If there's a HeapType with no Name callee, we don't need to process it
// anyway.
if (!std::holds_alternative<Name>(callee)) {
continue;
}
for (const auto& caller : callers) {
work.push(std::pair(callee, caller));
}
}

auto propagate = [&](Name callee, Name caller) {
auto propagate = [&](const CallGraphNode& calleeNode,
const CallGraphNode& callerNode) {
if (!std::holds_alternative<Name>(calleeNode) ||
!std::holds_alternative<Name>(callerNode)) {
return;
}

Name callee = std::get<Name>(calleeNode);
Name caller = std::get<Name>(callerNode);
auto& callerEffects = funcInfos.at(module.getFunction(caller)).effects;
const auto& calleeEffects =
funcInfos.at(module.getFunction(callee)).effects;
if (!callerEffects) {
if (callerEffects == UnknownEffects) {
return;
}

if (!calleeEffects) {
if (calleeEffects == UnknownEffects) {
callerEffects = UnknownEffects;
return;
}

callerEffects->mergeIn(*calleeEffects);
if (callee == caller) {
callerEffects->trap = true;
} else {
callerEffects->mergeIn(*calleeEffects);
}
};

while (!work.empty()) {
auto [callee, caller] = work.pop();

if (callee == caller) {
auto& callerEffects = funcInfos.at(module.getFunction(caller)).effects;
if (callerEffects) {
callerEffects->trap = true;
}
}

// Even if nothing changed, we still need to keep traversing the callers
// to look for a potential cycle which adds a trap affect on the above
// lines.
Expand All @@ -159,32 +251,30 @@ void propagateEffects(
continue;
}

for (const Name& callerCaller : callerCallers->second) {
// TODO: handle exact refs here
for (const CallGraphNode& callerCaller : callerCallers->second) {
work.push(std::pair(callee, callerCaller));
}
}
}

struct GenerateGlobalEffects : public Pass {
void run(Module* module) override {
std::map<Function*, FuncInfo> funcInfos =
std::unordered_map<Function*, FuncInfo> funcInfos =
analyzeFuncs(*module, getPassOptions());

// callee : caller
std::unordered_map<Name, std::unordered_set<Name>> callers;
for (const auto& [func, info] : funcInfos) {
for (const auto& callee : info.calledFunctions) {
callers[callee].insert(func->name);
}
}
std::unordered_map<CallGraphNode, std::unordered_set<CallGraphNode>>
callers =
buildReverseCallGraph(*module, funcInfos, getPassOptions().closedWorld);

propagateEffects(*module, callers, funcInfos);

// Generate the final data, starting from a blank slate where nothing is
// known.
for (auto& [func, info] : funcInfos) {
func->effects.reset();
if (!info.effects) {
if (info.effects == UnknownEffects) {
continue;
}

Expand Down
Loading
Loading