Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
288 changes: 147 additions & 141 deletions src/passes/GlobalEffects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,168 +22,172 @@
#include "ir/effects.h"
#include "ir/module-utils.h"
#include "pass.h"
#include "support/hash.h"
#include "support/unique_deferring_queue.h"
#include "wasm.h"

namespace wasm {

struct GenerateGlobalEffects : public Pass {
void run(Module* module) override {
// First, we do a scan of each function to see what effects they have,
// including which functions they call directly (so that we can compute
// transitive effects later).

struct FuncInfo {
// Effects in this function.
std::optional<EffectAnalyzer> effects;

// Directly-called functions from this function.
std::unordered_set<Name> calledFunctions;
};

ModuleUtils::ParallelFunctionAnalysis<FuncInfo> analysis(
*module, [&](Function* func, FuncInfo& funcInfo) {
if (func->imported()) {
// Imports can do anything, so we need to assume the worst anyhow,
// which is the same as not specifying any effects for them in the
// map (which we do by not setting funcInfo.effects).
return;
}

// Gather the effects.
funcInfo.effects.emplace(getPassOptions(), *module, func);

if (funcInfo.effects->calls) {
// There are calls in this function, which we will analyze in detail.
// Clear the |calls| field first, and we'll handle calls of all sorts
// below.
funcInfo.effects->calls = false;

// Clear throws as well, as we are "forgetting" calls right now, and
// want to forget their throwing effect as well. If we see something
// else that throws, below, then we'll note that there.
funcInfo.effects->throws_ = false;

struct CallScanner
: public PostWalker<CallScanner,
UnifiedExpressionVisitor<CallScanner>> {
Module& wasm;
PassOptions& options;
FuncInfo& funcInfo;

CallScanner(Module& wasm, PassOptions& options, FuncInfo& funcInfo)
: wasm(wasm), options(options), funcInfo(funcInfo) {}

void visitExpression(Expression* curr) {
ShallowEffectAnalyzer effects(options, wasm, curr);
if (auto* call = curr->dynCast<Call>()) {
// Note the direct call.
funcInfo.calledFunctions.insert(call->target);
} else if (effects.calls) {
// This is an indirect call of some sort, so we must assume the
// worst. To do so, clear the effects, which indicates nothing
// is known (so anything is possible).
// TODO: We could group effects by function type etc.
funcInfo.effects.reset();
} else {
// No call here, but update throwing if we see it. (Only do so,
// however, if we have effects; if we cleared it - see before -
// then we assume the worst anyhow, and have nothing to update.)
if (effects.throws_ && funcInfo.effects) {
funcInfo.effects->throws_ = true;
}
namespace {

struct FuncInfo {
// Effects in this function.
std::optional<EffectAnalyzer> effects;

// Directly-called functions from this function.
std::unordered_set<Name> calledFunctions;
};

std::map<Function*, FuncInfo> analyzeFuncs(Module& module,
const PassOptions& passOptions) {
ModuleUtils::ParallelFunctionAnalysis<FuncInfo> analysis(
module, [&](Function* func, FuncInfo& funcInfo) {
if (func->imported()) {
// Imports can do anything, so we need to assume the worst anyhow,
// which is the same as not specifying any effects for them in the
// map (which we do by not setting funcInfo.effects).
return;
}

// Gather the effects.
funcInfo.effects.emplace(passOptions, module, func);

if (funcInfo.effects->calls) {
// There are calls in this function, which we will analyze in detail.
// Clear the |calls| field first, and we'll handle calls of all sorts
// below.
funcInfo.effects->calls = false;

// Clear throws as well, as we are "forgetting" calls right now, and
// want to forget their throwing effect as well. If we see something
// else that throws, below, then we'll note that there.
funcInfo.effects->throws_ = false;

struct CallScanner
: public PostWalker<CallScanner,
UnifiedExpressionVisitor<CallScanner>> {
Module& wasm;
const PassOptions& options;
FuncInfo& funcInfo;

CallScanner(Module& wasm,
const PassOptions& options,
FuncInfo& funcInfo)
: wasm(wasm), options(options), funcInfo(funcInfo) {}

void visitExpression(Expression* curr) {
ShallowEffectAnalyzer effects(options, wasm, curr);
if (auto* call = curr->dynCast<Call>()) {
// Note the direct call.
funcInfo.calledFunctions.insert(call->target);
} else if (effects.calls) {
// This is an indirect call of some sort, so we must assume the
// worst. To do so, clear the effects, which indicates nothing
// is known (so anything is possible).
// TODO: We could group effects by function type etc.
funcInfo.effects.reset();
} else {
// No call here, but update throwing if we see it. (Only do so,
// however, if we have effects; if we cleared it - see before -
// then we assume the worst anyhow, and have nothing to update.)
if (effects.throws_ && funcInfo.effects) {
funcInfo.effects->throws_ = true;
}
}
};
CallScanner scanner(*module, getPassOptions(), funcInfo);
scanner.walkFunction(func);
}
});

// Compute the transitive closure of effects. To do so, first construct for
// each function a list of the functions that it is called by (so we need to
// propagate its effects to them), and then we'll construct the closure of
// that.
//
// callers[foo] = [func that calls foo, another func that calls foo, ..]
//
std::unordered_map<Name, std::unordered_set<Name>> callers;

// Our work queue contains info about a new call pair: a call from a caller
// to a called function, that is information we then apply and propagate.
using CallPair = std::pair<Name, Name>; // { caller, called }
UniqueDeferredQueue<CallPair> work;
for (auto& [func, info] : analysis.map) {
for (auto& called : info.calledFunctions) {
work.push({func->name, called});
}
};
CallScanner scanner(module, passOptions, funcInfo);
scanner.walkFunction(func);
}
});

return std::move(analysis.map);
}

// Propagate effects from callees to callers transitively
// e.g. if A -> B -> C (A calls B which calls C)
// Then B inherits effects from C and A inherits effects from both B and C.
void propagateEffects(
const Module& module,
const std::unordered_map<Name, std::unordered_set<Name>>& in,
std::map<Function*, FuncInfo>& funcInfos) {

std::unordered_set<std::pair<Name, Name>> processed;
std::deque<std::pair<Name, Name>> work;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you compare using a deque with using a UniqueDeferredQueue or a stack vector?

In DAE2 I found that using a stack vector was significantly faster than using UniqueDeferredQueue, but I don't think I tried a deque.


for (const auto& [callee, callers] : in) {
for (const auto& caller : callers) {
work.emplace_back(callee, caller);
processed.emplace(callee, caller);
}
}

auto propagate = [&](Name callee, Name caller) {
auto& callerEffects = funcInfos.at(module.getFunction(caller)).effects;
const auto& calleeEffects =
funcInfos.at(module.getFunction(callee)).effects;
if (!callerEffects) {
return;
}

if (!calleeEffects) {
callerEffects.reset();
return;
Comment on lines +133 to +135
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this correct? The caller might have effects from its own body or from other callees.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nullopt here actually means "don't know" or "all effects" rather than "no effects". If we don't know the effects of the callee then we also don't know the effects of the caller.

I agree that this is misleading though. I think I can introduce a static method in EffectAnalyzer that gives back an EffectAnalyzer that contains all effects, that way we don't need an optional at all here and the logic stays more uniform, we just always merge unconditionally. Does that sound good?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that sounds much simpler 👍

}

// Compute the transitive closure of the call graph, that is, fill out
// |callers| so that it contains the list of all callers - even through a
// chain - of each function.
while (!work.empty()) {
auto [caller, called] = work.pop();

// We must not already have an entry for this call (that would imply we
// are doing wasted work).
assert(!callers[called].contains(caller));

// Apply the new call information.
callers[called].insert(caller);

// We just learned that |caller| calls |called|. It also calls
// transitively, which we need to propagate to all places unaware of that
// information yet.
//
// caller => called => called by called
//
auto& calledInfo = analysis.map[module->getFunction(called)];
for (auto calledByCalled : calledInfo.calledFunctions) {
if (!callers[calledByCalled].contains(caller)) {
work.push({caller, calledByCalled});
}
callerEffects->mergeIn(*calleeEffects);
};

while (!work.empty()) {
auto [callee, caller] = work.back();
work.pop_back();

if (callee == caller) {
auto& callerEffects = funcInfos.at(module.getFunction(caller)).effects;
if (callerEffects) {
callerEffects->trap = true;
Comment on lines +145 to +148
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, this does not handle mutual recursion the same way the previous implementation did. The previous implementation could find arbitrary recursion cycles by just looking at self edges in the transitive closure of the reverse call graph. But we don't have that transitive closure of the graph anymore, so looking for self edges is no longer sufficient to find nontrivial recursion cycles.

(So actually we may want to use the SCC utility to find cycles here after all 🫣)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic here is basically the same as the way we computed the transitive closure before, the difference is that we just don't materialize the graph of the transitive closure. I think the case you're saying is tested here (or let me know if you have a different case in mind).

In this case we have 1 -> 2 and 2 -> 1 in the work list. Then we pop off 1 -> 2 and push 1 -> 1 which ends up hitting the cycle detection.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see... but isn't this inefficient? If we have call graph A -> B -> C, after we merge C's effects into B and then merge B's effects into A, we know that merging C's effects into A would not change anything.

But landing this and then iterating with smaller changes sounds fine, especially since this already improves on the status quo.

}
}

// Now that we have transitively propagated all static calls, apply that
// information. First, apply infinite recursion: if a function can call
// itself then it might recurse infinitely, which we consider an effect (a
// trap).
for (auto& [func, info] : analysis.map) {
if (callers[func->name].contains(func->name)) {
if (info.effects) {
info.effects->trap = true;
}
// Even if nothing changed, we still need to keep traversing the callers
// to look for a potential cycle which adds a trap affect on the above
// lines.
Comment on lines +152 to +154
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would also be possible to look at whether anything was changed by either the cycle-handling code above or the effect merge here. If not, then we don't need to propagate further. (This will be even simpler if we find cycles and apply the trap effect as a separate step before computing this fixed point.)

propagate(callee, caller);

const auto& callerCallers = in.find(caller);
if (callerCallers == in.end()) {
continue;
}

for (const Name& callerCaller : callerCallers->second) {
if (processed.contains({callee, callerCaller})) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use the bool returned from emplace to avoid this separate call to contains.

continue;
}

processed.emplace(callee, callerCaller);
work.emplace_back(callee, callerCaller);
}
}
}

struct GenerateGlobalEffects : public Pass {
void run(Module* module) override {
std::map<Function*, FuncInfo> funcInfos =
analyzeFuncs(*module, getPassOptions());

// Next, apply function effects to their callers.
for (auto& [func, info] : analysis.map) {
auto& funcEffects = info.effects;

for (auto& caller : callers[func->name]) {
auto& callerEffects = analysis.map[module->getFunction(caller)].effects;
if (!callerEffects) {
// Nothing is known for the caller, which is already the worst case.
continue;
}

if (!funcEffects) {
// Nothing is known for the called function, which means nothing is
// known for the caller either.
callerEffects.reset();
continue;
}

// Add func's effects to the caller.
callerEffects->mergeIn(*funcEffects);
// callee : caller
std::unordered_map<Name, std::unordered_set<Name>> callers;
for (const auto& [func, info] : funcInfos) {
for (const auto& callee : info.calledFunctions) {
callers[callee].insert(func->name);
}
}

propagateEffects(*module, callers, funcInfos);

// Generate the final data, starting from a blank slate where nothing is
// known.
for (auto& [func, info] : analysis.map) {
for (auto& [func, info] : funcInfos) {
func->effects.reset();
if (!info.effects) {
continue;
Expand All @@ -202,6 +206,8 @@ struct DiscardGlobalEffects : public Pass {
}
};

} // namespace

Pass* createGenerateGlobalEffectsPass() { return new GenerateGlobalEffects(); }

Pass* createDiscardGlobalEffectsPass() { return new DiscardGlobalEffects(); }
Expand Down
Loading