-
Notifications
You must be signed in to change notification settings - Fork 22
feat: add pretty run report #416
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2007b1a
86133fa
f554bb2
c554d1a
810bd95
c09c7de
3d7bec7
ff7f688
01802ec
8ca9d10
e217a3f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -76,6 +76,8 @@ class EGraphState: | |
| type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict) | ||
| egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict) | ||
|
|
||
| egg_rule_to_command_decl: dict[str, CommandDecl] = field(default_factory=dict) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we instead just use |
||
|
|
||
| # Cache of egg expressions for converting to egg | ||
| expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict) | ||
|
|
||
|
|
@@ -86,6 +88,11 @@ class EGraphState: | |
| # Counter for deterministic synthetic names assigned to unnamed functions. | ||
| unnamed_function_counter: int = 0 | ||
|
|
||
| # Counter for numeric rule names | ||
| rule_name_counter: int = 0 | ||
| # Mapping from numeric name (str) to command decl | ||
| rule_name_to_command_decl: dict[str, CommandDecl] = field(default_factory=dict) | ||
|
|
||
| def copy(self) -> EGraphState: | ||
| """ | ||
| Returns a copy of the state. The egraph reference is kept the same. Used for pushing/popping. | ||
|
|
@@ -102,6 +109,8 @@ def copy(self) -> EGraphState: | |
| cost_callables=self.cost_callables.copy(), | ||
| expr_to_let_counter=self.expr_to_let_counter, | ||
| unnamed_function_counter=self.unnamed_function_counter, | ||
| rule_name_counter=self.rule_name_counter, | ||
| rule_name_to_command_decl=self.rule_name_to_command_decl.copy(), | ||
| ) | ||
|
|
||
| def _run_program(self, *commands: bindings._Command) -> list[bindings._CommandOutput]: | ||
|
|
@@ -247,6 +256,17 @@ def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912 | |
| case _: | ||
| assert_never(schedule) | ||
|
|
||
| def translate_rule_key(self, egglog_key: str) -> CommandDecl | str: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What if we remove this, and instead store in the |
||
| """ | ||
| Look up the original Python CommandDecl for an egglog rule key. | ||
| """ | ||
| clean_key = egglog_key.removesuffix("=>").removesuffix("<=") | ||
| if clean_key in self.rule_name_to_command_decl: | ||
| return self.rule_name_to_command_decl[clean_key] | ||
| if egglog_key in self.egg_rule_to_command_decl: | ||
| return self.egg_rule_to_command_decl[egglog_key] | ||
| return egglog_key | ||
|
|
||
| def ruleset_to_egg(self, ident: Ident) -> None: | ||
| """ | ||
| Registers a ruleset if it's not already registered. | ||
|
|
@@ -283,24 +303,33 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command | |
| return bindings.ActionCommand(action_egg) | ||
| case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions): | ||
| self.type_ref_to_egg(tp) | ||
| name = str(self.rule_name_counter) | ||
| self.rule_name_counter += 1 | ||
|
Comment on lines
+306
to
+307
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since we now support a name for rewrite/birewrite, could we expose this at the user level as well? Then the logic here would be similar to the RuleDecl handling, which checks for an explicit name and, if there isn't one, generates one. This would entail adding the […] (inline code reference lost in extraction). This isn't strictly necessary for this PR, though, so if you don't feel like doing it here, that's fine. |
||
| self.rule_name_to_command_decl[name] = cmd | ||
| rewrite = bindings.Rewrite( | ||
| span(), | ||
| self._expr_to_egg(lhs), | ||
| self._expr_to_egg(rhs), | ||
| [self.fact_to_egg(c) for c in conditions], | ||
| name, | ||
| ) | ||
| return ( | ||
| bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume) | ||
| if isinstance(cmd, RewriteDecl) | ||
| else bindings.BiRewriteCommand(str(ruleset), rewrite) | ||
| ) | ||
| egg_cmd: bindings._Command | ||
| if isinstance(cmd, RewriteDecl): | ||
| egg_cmd = bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume) | ||
| else: | ||
| egg_cmd = bindings.BiRewriteCommand(str(ruleset), rewrite) | ||
| return egg_cmd | ||
| case RuleDecl(head, body, name): | ||
| if not name: | ||
| name = str(self.rule_name_counter) | ||
| self.rule_name_counter += 1 | ||
| self.rule_name_to_command_decl[name] = cmd | ||
| return bindings.RuleCommand( | ||
| bindings.Rule( | ||
| span(), | ||
| [self.action_to_egg(a) for a in head], | ||
| [self.fact_to_egg(f) for f in body], | ||
| name or "", | ||
| name, | ||
| str(ruleset), | ||
| ) | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,123 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Callable | ||
| from dataclasses import dataclass, field | ||
| from datetime import timedelta | ||
|
|
||
| from . import bindings | ||
| from .declarations import CommandDecl, Declarations | ||
| from .egraph_state import EGraphState | ||
| from .pretty import pretty_decl | ||
|
|
||
|
|
||
| def _format_rule_key(decls: Declarations, key: CommandDecl | str) -> str: | ||
| if isinstance(key, str): | ||
| return key | ||
| return pretty_decl(decls, key) | ||
|
|
||
|
|
||
| @dataclass | ||
| class RuleReport: | ||
| plan: bindings.Plan | None | ||
| search_and_apply_time: timedelta | ||
| num_matches: int | ||
|
|
||
| @classmethod | ||
| def _from_bindings(cls, report: bindings.RuleReport) -> RuleReport: | ||
| return cls( | ||
| plan=report.plan, | ||
| search_and_apply_time=report.search_and_apply_time, | ||
| num_matches=report.num_matches, | ||
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
| class RuleSetReport: | ||
| _decls: Declarations = field(repr=False) | ||
| changed: bool = False | ||
| rule_reports: dict[CommandDecl | str, list[RuleReport]] = field(default_factory=dict) | ||
| search_and_apply_time: timedelta = field(default_factory=timedelta) | ||
| merge_time: timedelta = field(default_factory=timedelta) | ||
|
|
||
| @classmethod | ||
| def _from_bindings( | ||
| cls, report: bindings.RuleSetReport, translate_key: Callable[[str], CommandDecl | str], decls: Declarations | ||
| ) -> RuleSetReport: | ||
| return cls( | ||
| _decls=decls, | ||
| changed=report.changed, | ||
| rule_reports={ | ||
| translate_key(k): [RuleReport._from_bindings(rr) for rr in v] for k, v in report.rule_reports.items() | ||
|
kaeun97 marked this conversation as resolved.
|
||
| }, | ||
| search_and_apply_time=report.search_and_apply_time, | ||
| merge_time=report.merge_time, | ||
| ) | ||
|
|
||
| def __repr__(self) -> str: | ||
| rule_reports_str = {_format_rule_key(self._decls, k): v for k, v in self.rule_reports.items()} | ||
| return ( | ||
| f"RuleSetReport(changed={self.changed}, " | ||
| f"rule_reports={rule_reports_str}, " | ||
| f"search_and_apply_time={self.search_and_apply_time}, " | ||
| f"merge_time={self.merge_time})" | ||
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
| class IterationReport: | ||
| rule_set_report: RuleSetReport | ||
| rebuild_time: timedelta | ||
|
|
||
| @classmethod | ||
| def _from_bindings( | ||
| cls, report: bindings.IterationReport, translate_key: Callable[[str], CommandDecl | str], decls: Declarations | ||
| ) -> IterationReport: | ||
| return cls( | ||
| rule_set_report=RuleSetReport._from_bindings(report.rule_set_report, translate_key, decls), | ||
| rebuild_time=report.rebuild_time, | ||
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
| class RunReport: | ||
| """Python-friendly wrapper around bindings.RunReport.""" | ||
|
|
||
| _decls: Declarations = field(repr=False) | ||
| iterations: list[IterationReport] = field(default_factory=list) | ||
| updated: bool = False | ||
| search_and_apply_time_per_rule: dict[CommandDecl | str, timedelta] = field(default_factory=dict) | ||
| num_matches_per_rule: dict[CommandDecl | str, int] = field(default_factory=dict) | ||
|
Comment on lines
+88
to
+89
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What if we just store CommandDecls here regardless of whether each one has a name, and then change the repr/str to display the name as a string if it has one, and otherwise pretty-print the full command? |
||
| search_and_apply_time_per_ruleset: dict[str, timedelta] = field(default_factory=dict) | ||
| merge_time_per_ruleset: dict[str, timedelta] = field(default_factory=dict) | ||
| rebuild_time_per_ruleset: dict[str, timedelta] = field(default_factory=dict) | ||
|
|
||
| def __repr__(self) -> str: | ||
| time_per_rule = {_format_rule_key(self._decls, k): v for k, v in self.search_and_apply_time_per_rule.items()} | ||
| matches_per_rule = {_format_rule_key(self._decls, k): v for k, v in self.num_matches_per_rule.items()} | ||
| return ( | ||
| f"RunReport(iterations={self.iterations}, " | ||
| f"updated={self.updated}, " | ||
| f"search_and_apply_time_per_rule={time_per_rule}, " | ||
| f"num_matches_per_rule={matches_per_rule}, " | ||
| f"search_and_apply_time_per_ruleset={self.search_and_apply_time_per_ruleset}, " | ||
| f"merge_time_per_ruleset={self.merge_time_per_ruleset}, " | ||
| f"rebuild_time_per_ruleset={self.rebuild_time_per_ruleset})" | ||
| ) | ||
|
|
||
| @classmethod | ||
| def _from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> RunReport: | ||
| return cls( | ||
| _decls=state.__egg_decls__, | ||
| iterations=[ | ||
| IterationReport._from_bindings(it, state.translate_rule_key, state.__egg_decls__) | ||
| for it in report.iterations | ||
| ], | ||
| updated=report.updated, | ||
| search_and_apply_time_per_rule={ | ||
| state.translate_rule_key(k): v for k, v in report.search_and_apply_time_per_rule.items() | ||
| }, | ||
| num_matches_per_rule={state.translate_rule_key(k): v for k, v in report.num_matches_per_rule.items()}, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When we build these dictionaries from the bindings, could we check for duplicate keys (either named or unnamed) and combine the values for them, so that for BiRewrite we don't lose the first one? |
||
| search_and_apply_time_per_ruleset=report.search_and_apply_time_per_ruleset, | ||
| merge_time_per_ruleset=report.merge_time_per_ruleset, | ||
| rebuild_time_per_ruleset=report.rebuild_time_per_ruleset, | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add RunReport to `__all__`?