Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
793 changes: 439 additions & 354 deletions Cargo.lock

Large diffs are not rendered by default.

22 changes: 10 additions & 12 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@ opentelemetry_sdk = "0.28"
# egglog-core-relations = { path = "../egg-smol/core-relations" }
# egglog-ast = { path = "../egg-smol/egglog-ast" }
# egglog-reports = { path = "../egg-smol/egglog-reports" }
egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug", default-features = false }
egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }


egglog = { git = "https://github.com/egraphs-good/egglog.git", rev = "2e5657b", default-features = false }
egglog-ast = { git = "https://github.com/egraphs-good/egglog.git", rev = "2e5657b" }
egglog-core-relations = { git = "https://github.com/egraphs-good/egglog.git", rev = "2e5657b" }
egglog-reports = { git = "https://github.com/egraphs-good/egglog.git", rev = "2e5657b" }
egglog-bridge = { git = "https://github.com/egraphs-good/egglog.git", rev = "2e5657b" }
egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental", branch = "main", default-features = false }
egraph-serialize = { version = "0.3", features = ["serde", "graphviz"] }
serde_json = "1"
Expand All @@ -52,11 +50,11 @@ base64 = "0.22.1"
# egglog-reports = { path = "../egg-smol/egglog-reports" }
# egglog-bridge = { path = "../egg-smol/egglog-bridge" }

egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", branch = "fix-container-fn-bug" }
egglog = { git = "https://github.com/saulshanabrook/egg-smol.git", rev = "2e5657b" }
egglog-ast = { git = "https://github.com/saulshanabrook/egg-smol.git", rev = "2e5657b" }
egglog-core-relations = { git = "https://github.com/saulshanabrook/egg-smol.git", rev = "2e5657b" }
egglog-bridge = { git = "https://github.com/saulshanabrook/egg-smol.git", rev = "2e5657b" }
egglog-reports = { git = "https://github.com/saulshanabrook/egg-smol.git", rev = "2e5657b" }

# enable debug symbols for easier profiling
[profile.release]
Expand Down
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ _This project uses semantic versioning_

## 13.1.0 (2026-03-25)

- Add Python-friendly `RunReport` wrapper that returns `CommandDecl` objects as rule keys instead of raw egglog s-expression strings, with pretty-printed Python syntax in `str()` output [#416](https://github.com/egraphs-good/egglog-python/pull/416)
- Improve high-level Python ergonomics and docs [#397](https://github.com/egraphs-good/egglog-python/pull/397)
- Add `EGraph.freeze()`, returning a `FrozenEGraph` snapshot that can be pretty-printed back into replayable high-level Python actions for debugging and inspection.
- Add a variadic `EGraph(*actions, seminaive=True, save_egglog_string=False)` constructor so actions can be registered at construction time, and export `ActionLike` from `egglog` for typing code that works with `EGraph.register(...)` and the constructor.
Expand Down
5 changes: 4 additions & 1 deletion python/egglog/bindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,11 @@ class Rewrite:
lhs: _Expr
rhs: _Expr
conditions: list[_Fact]
name: str

def __new__(cls, span: _Span, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = ...) -> Rewrite: ...
def __new__(
cls, span: _Span, lhs: _Expr, rhs: _Expr, conditions: list[_Fact] = ..., name: str = ...
) -> Rewrite: ...

@final
class RunConfig:
Expand Down
17 changes: 8 additions & 9 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .egraph_state import *
from .ipython_magic import IN_IPYTHON
from .pretty import pretty_decl
from .run_report import RunReport
from .runtime import *
from .thunk import *

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add `RunReport` to `__all__`?

Expand Down Expand Up @@ -953,36 +954,34 @@ def output(self) -> None:
raise NotImplementedError(msg)

@overload
def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> bindings.RunReport: ...
def run(self, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> RunReport: ...

@overload
def run(self, schedule: Schedule, /) -> bindings.RunReport: ...
def run(self, schedule: Schedule, /) -> RunReport: ...

@_TRACER.start_as_current_span("run")
def run(
self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None
) -> bindings.RunReport:
def run(self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Ruleset | None = None) -> RunReport:
"""
Run the egraph until the given limit or until the given facts are true.
"""
if isinstance(limit_or_schedule, int):
limit_or_schedule = run(ruleset, *until) * limit_or_schedule
return self._run_schedule(limit_or_schedule)

def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
def _run_schedule(self, schedule: Schedule) -> RunReport:
self._add_decls(schedule)
cmd = self._state.run_schedule_to_egg(schedule.schedule)
(command_output,) = self._run_program(cmd)
assert isinstance(command_output, bindings.RunScheduleOutput)
return command_output.report
return RunReport._from_bindings(command_output.report, self._state)

def stats(self) -> bindings.RunReport:
def stats(self) -> RunReport:
"""
Returns the overall run report for the egraph.
"""
(output,) = self._run_program(bindings.PrintOverallStatistics(span(1), None))
assert isinstance(output, bindings.OverallStatistics)
return output.report
return RunReport._from_bindings(output.report, self._state)

def check_bool(self, *facts: FactLike) -> bool:
"""
Expand Down
41 changes: 35 additions & 6 deletions python/egglog/egraph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class EGraphState:
type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict)

egg_rule_to_command_decl: dict[str, CommandDecl] = field(default_factory=dict)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we instead just use rule_name_to_command_decl, so we can remove this additional mapping and there is just one source? We will know which ones are named, because we can see if the CommandDecl has a name or not. We can also update it to be more specific and just go from str to RuleDecl | BiRewriteDecl | RewriteDecl I believe.


# Cache of egg expressions for converting to egg
expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)

Expand All @@ -86,6 +88,11 @@ class EGraphState:
# Counter for deterministic synthetic names assigned to unnamed functions.
unnamed_function_counter: int = 0

# Counter for numeric rule names
rule_name_counter: int = 0
# Mapping from numeric name (str) to command decl
rule_name_to_command_decl: dict[str, CommandDecl] = field(default_factory=dict)

def copy(self) -> EGraphState:
"""
Returns a copy of the state. The egraph reference is kept the same. Used for pushing/popping.
Expand All @@ -102,6 +109,8 @@ def copy(self) -> EGraphState:
cost_callables=self.cost_callables.copy(),
expr_to_let_counter=self.expr_to_let_counter,
unnamed_function_counter=self.unnamed_function_counter,
rule_name_counter=self.rule_name_counter,
rule_name_to_command_decl=self.rule_name_to_command_decl.copy(),
)

def _run_program(self, *commands: bindings._Command) -> list[bindings._CommandOutput]:
Expand Down Expand Up @@ -247,6 +256,17 @@ def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912
case _:
assert_never(schedule)

def translate_rule_key(self, egglog_key: str) -> CommandDecl | str:
    """
    Look up the original Python CommandDecl for an egglog rule key.

    Lookup order:

    1. ``egglog_key`` exactly as given, in ``rule_name_to_command_decl``.
    2. ``egglog_key`` with a trailing ``"=>"`` and then a trailing ``"<="`` removed,
       in ``rule_name_to_command_decl``.
    3. ``egglog_key`` exactly as given, in ``egg_rule_to_command_decl``.

    If none of these match, ``egglog_key`` itself is returned unchanged.
    """
    # NOTE(review): a reviewer suggested dropping this helper and instead registering the
    # "<=" / "=>" variants in `rule_name_to_command_decl` when a bi-rewrite is added, so that an
    # unknown key raises instead of silently falling back to the raw string.
    named = self.rule_name_to_command_decl
    # Try the untouched key first so a registered name that itself ends in "=>" or "<=" is not
    # mangled by the suffix stripping below.
    if egglog_key in named:
        return named[egglog_key]
    clean_key = egglog_key.removesuffix("=>").removesuffix("<=")
    if clean_key in named:
        return named[clean_key]
    return self.egg_rule_to_command_decl.get(egglog_key, egglog_key)

def ruleset_to_egg(self, ident: Ident) -> None:
"""
Registers a ruleset if it's not already registered.
Expand Down Expand Up @@ -283,24 +303,33 @@ def command_to_egg(self, cmd: CommandDecl, ruleset: Ident) -> bindings._Command
return bindings.ActionCommand(action_egg)
case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions):
self.type_ref_to_egg(tp)
name = str(self.rule_name_counter)
self.rule_name_counter += 1
Comment on lines +306 to +307
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we now support name for rewrite/birewrite, could we expose this to the user level as well? And then this logic here would be similar to the RuleDecl handling, where it checks for an explicit name and if it doesn't have one generates one. This would entail adding the name to pretty.py, declarations.py and egraph.py I believe.

This isn't strictly necessary for this PR though so if you don't feel like doing this here that's fine.

self.rule_name_to_command_decl[name] = cmd
rewrite = bindings.Rewrite(
span(),
self._expr_to_egg(lhs),
self._expr_to_egg(rhs),
[self.fact_to_egg(c) for c in conditions],
name,
)
return (
bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume)
if isinstance(cmd, RewriteDecl)
else bindings.BiRewriteCommand(str(ruleset), rewrite)
)
egg_cmd: bindings._Command
if isinstance(cmd, RewriteDecl):
egg_cmd = bindings.RewriteCommand(str(ruleset), rewrite, cmd.subsume)
else:
egg_cmd = bindings.BiRewriteCommand(str(ruleset), rewrite)
return egg_cmd
case RuleDecl(head, body, name):
if not name:
name = str(self.rule_name_counter)
self.rule_name_counter += 1
self.rule_name_to_command_decl[name] = cmd
return bindings.RuleCommand(
bindings.Rule(
span(),
[self.action_to_egg(a) for a in head],
[self.fact_to_egg(f) for f in body],
name or "",
name,
str(ruleset),
)
)
Expand Down
123 changes: 123 additions & 0 deletions python/egglog/run_report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import timedelta

from . import bindings
from .declarations import CommandDecl, Declarations
from .egraph_state import EGraphState
from .pretty import pretty_decl


def _format_rule_key(decls: Declarations, key: CommandDecl | str) -> str:
    """
    Render a rule key for display.

    String keys are returned unchanged; any other key is pretty-printed with ``pretty_decl``
    using ``decls``.
    """
    if isinstance(key, str):
        return key
    return pretty_decl(decls, key)


@dataclass
class RuleReport:
    """
    Python-side copy of a ``bindings.RuleReport``.

    Holds the same three values as the bindings object: ``plan``, ``search_and_apply_time`` and
    ``num_matches``.
    """

    plan: bindings.Plan | None
    search_and_apply_time: timedelta
    num_matches: int

    @classmethod
    def _from_bindings(cls, report: bindings.RuleReport) -> RuleReport:
        """Build a ``RuleReport`` by copying the three fields off ``report``."""
        return cls(
            plan=report.plan,
            search_and_apply_time=report.search_and_apply_time,
            num_matches=report.num_matches,
        )


@dataclass
class RuleSetReport:
    """
    Python-side copy of a ``bindings.RuleSetReport`` with translated rule keys.

    The keys of ``rule_reports`` are whatever the ``translate_key`` callable passed to
    ``_from_bindings`` returns for each egglog key (a ``CommandDecl`` or a plain string).
    """

    # Declarations handed to `_format_rule_key` by `__repr__` to render non-string keys.
    _decls: Declarations = field(repr=False)
    changed: bool = False
    rule_reports: dict[CommandDecl | str, list[RuleReport]] = field(default_factory=dict)
    search_and_apply_time: timedelta = field(default_factory=timedelta)
    merge_time: timedelta = field(default_factory=timedelta)

    @classmethod
    def _from_bindings(
        cls, report: bindings.RuleSetReport, translate_key: Callable[[str], CommandDecl | str], decls: Declarations
    ) -> RuleSetReport:
        """
        Convert a bindings report, translating every rule key with ``translate_key``.

        When several egglog keys translate to the same key, their report lists are concatenated
        in iteration order instead of the later one replacing the earlier one.
        """
        merged: dict[CommandDecl | str, list[RuleReport]] = {}
        for egg_key, egg_reports in report.rule_reports.items():
            # setdefault + extend keeps earlier entries when two egglog keys collide after
            # translation (presumably the two directions of a bi-rewrite).
            merged.setdefault(translate_key(egg_key), []).extend(
                RuleReport._from_bindings(egg_report) for egg_report in egg_reports
            )
        return cls(
            _decls=decls,
            changed=report.changed,
            rule_reports=merged,
            search_and_apply_time=report.search_and_apply_time,
            merge_time=report.merge_time,
        )

    def __repr__(self) -> str:
        """Return a repr in which every rule key is rendered through ``_format_rule_key``."""
        rule_reports_str = {_format_rule_key(self._decls, k): v for k, v in self.rule_reports.items()}
        return (
            f"RuleSetReport(changed={self.changed}, "
            f"rule_reports={rule_reports_str}, "
            f"search_and_apply_time={self.search_and_apply_time}, "
            f"merge_time={self.merge_time})"
        )


@dataclass
class IterationReport:
    """
    Python-side copy of a ``bindings.IterationReport``.

    Pairs the converted ``RuleSetReport`` for the iteration with its ``rebuild_time``.
    """

    rule_set_report: RuleSetReport
    rebuild_time: timedelta

    @classmethod
    def _from_bindings(
        cls, report: bindings.IterationReport, translate_key: Callable[[str], CommandDecl | str], decls: Declarations
    ) -> IterationReport:
        """
        Convert a bindings iteration report.

        ``translate_key`` and ``decls`` are forwarded unchanged to ``RuleSetReport._from_bindings``;
        ``rebuild_time`` is copied as-is.
        """
        return cls(
            rule_set_report=RuleSetReport._from_bindings(report.rule_set_report, translate_key, decls),
            rebuild_time=report.rebuild_time,
        )


@dataclass
class RunReport:
    """
    Python-friendly wrapper around bindings.RunReport.

    Per-rule dictionaries are keyed by the result of ``EGraphState.translate_rule_key`` (a
    ``CommandDecl`` or a plain string); per-ruleset dictionaries are taken from the bindings
    report unchanged.
    """

    # Declarations handed to `_format_rule_key` by `__repr__` to render non-string keys.
    _decls: Declarations = field(repr=False)
    iterations: list[IterationReport] = field(default_factory=list)
    updated: bool = False
    # NOTE(review): a reviewer suggested always keying the two per-rule dicts below by CommandDecl,
    # whether or not the rule is named, and having repr/str show the name when the rule has one
    # and the pretty-printed command otherwise.
    search_and_apply_time_per_rule: dict[CommandDecl | str, timedelta] = field(default_factory=dict)
    num_matches_per_rule: dict[CommandDecl | str, int] = field(default_factory=dict)
    search_and_apply_time_per_ruleset: dict[str, timedelta] = field(default_factory=dict)
    merge_time_per_ruleset: dict[str, timedelta] = field(default_factory=dict)
    rebuild_time_per_ruleset: dict[str, timedelta] = field(default_factory=dict)

    def __repr__(self) -> str:
        """Return a repr in which every per-rule key is rendered through ``_format_rule_key``."""
        time_per_rule = {_format_rule_key(self._decls, k): v for k, v in self.search_and_apply_time_per_rule.items()}
        matches_per_rule = {_format_rule_key(self._decls, k): v for k, v in self.num_matches_per_rule.items()}
        return (
            f"RunReport(iterations={self.iterations}, "
            f"updated={self.updated}, "
            f"search_and_apply_time_per_rule={time_per_rule}, "
            f"num_matches_per_rule={matches_per_rule}, "
            f"search_and_apply_time_per_ruleset={self.search_and_apply_time_per_ruleset}, "
            f"merge_time_per_ruleset={self.merge_time_per_ruleset}, "
            f"rebuild_time_per_ruleset={self.rebuild_time_per_ruleset})"
        )

    @classmethod
    def _from_bindings(cls, report: bindings.RunReport, state: EGraphState) -> RunReport:
        """
        Convert a bindings run report using ``state`` to translate rule keys.

        When several egglog keys translate to the same key, their times / match counts are added
        together instead of the later value replacing the earlier one.
        """
        translate_key = state.translate_rule_key
        decls = state.__egg_decls__

        # Review request: combine values for duplicate translated keys so that the first
        # direction of a BiRewrite is not lost. Summing via `in` avoids assuming a zero value.
        time_per_rule: dict[CommandDecl | str, timedelta] = {}
        for egg_key, elapsed in report.search_and_apply_time_per_rule.items():
            key = translate_key(egg_key)
            time_per_rule[key] = time_per_rule[key] + elapsed if key in time_per_rule else elapsed

        matches_per_rule: dict[CommandDecl | str, int] = {}
        for egg_key, count in report.num_matches_per_rule.items():
            key = translate_key(egg_key)
            matches_per_rule[key] = matches_per_rule[key] + count if key in matches_per_rule else count

        return cls(
            _decls=decls,
            iterations=[IterationReport._from_bindings(it, translate_key, decls) for it in report.iterations],
            updated=report.updated,
            search_and_apply_time_per_rule=time_per_rule,
            num_matches_per_rule=matches_per_rule,
            search_and_apply_time_per_ruleset=report.search_and_apply_time_per_ruleset,
            merge_time_per_ruleset=report.merge_time_per_ruleset,
            rebuild_time_per_ruleset=report.rebuild_time_per_ruleset,
        )
Loading
Loading