From 0698e55dc25d90deb584686c95bdc97956bffaae Mon Sep 17 00:00:00 2001 From: default Date: Fri, 22 May 2026 02:14:18 +0000 Subject: [PATCH] Revamp the rust part of the tutorial with soluttions --- case-study-linear-algebra-compiler/Cargo.lock | 400 +++++++++++------- case-study-linear-algebra-compiler/Cargo.toml | 2 +- .../rust-toolchain.toml | 2 +- .../solution.patch | 369 ++++++++++++++++ .../src/defn.solution.egg | 104 +++++ .../src/main.rs | 109 +++-- 6 files changed, 788 insertions(+), 198 deletions(-) create mode 100644 case-study-linear-algebra-compiler/solution.patch create mode 100644 case-study-linear-algebra-compiler/src/defn.solution.egg diff --git a/case-study-linear-algebra-compiler/Cargo.lock b/case-study-linear-algebra-compiler/Cargo.lock index e707c14..8afa5ec 100644 --- a/case-study-linear-algebra-compiler/Cargo.lock +++ b/case-study-linear-algebra-compiler/Cargo.lock @@ -2,15 +2,6 @@ # It is not intended for manual editing. version = 4 -[[package]] -name = "add_primitive" -version = "0.1.0" -source = "git+https://github.com/egraphs-good/egglog.git?rev=24ea499#24ea499f7301cdec8198a381b5428589deb8b231" -dependencies = [ - "quote", - "syn 2.0.102", -] - [[package]] name = "aho-corasick" version = "1.1.3" @@ -148,6 +139,16 @@ version = "3.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "793db76d6187cd04dff33004d8e6c9cc4e05cd330500379d2394209271b4aeee" +[[package]] +name = "cc" +version = "1.2.62" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98" +dependencies = [ + "find-msvc-tools", + "shlex", +] + [[package]] name = "cfg-if" version = "1.0.1" @@ -209,44 +210,6 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" -[[package]] -name = "concurrency" -version = "0.1.0" -source = "git+https://github.com/egraphs-good/egglog-backend.git?rev=a0b98d3#a0b98d355601bff4d6c13cf4420a875a14b484e8" -dependencies = [ - "arc-swap", - "rayon", -] - -[[package]] -name = "core-relations" -version = "0.1.0" -source = "git+https://github.com/egraphs-good/egglog-backend.git?rev=a0b98d3#a0b98d355601bff4d6c13cf4420a875a14b484e8" -dependencies = [ - "anyhow", - "bumpalo", - "concurrency", - "crossbeam-queue", - "dashmap", - "dyn-clone", - "fixedbitset 0.5.7", - "hashbrown 0.15.4", - "indexmap", - "lazy_static", - "log", - "num", - "numeric-id", - "once_cell", - "petgraph 0.6.5", - "rand", - "rayon", - "rustc-hash", - "smallvec", - "thiserror 2.0.12", - "union-find", - "web-time", -] - [[package]] name = "cpufeatures" version = "0.2.17" @@ -322,6 +285,27 @@ dependencies = [ "typenum", ] +[[package]] +name = "csv" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "704a3c26996a80471189265814dbc2c257598b96b8a7feae2d31ace646bb9782" +dependencies = [ + "memchr", +] + [[package]] name = "dashmap" version = "6.1.0" @@ -369,68 +353,167 @@ checksum = "1c7a8fb8a9fbf66c1f703fe16184d10ca0ee9d23be5b4436400408ba54a95005" [[package]] name = "egglog" -version = "0.5.0" -source = "git+https://github.com/egraphs-good/egglog.git?rev=24ea499#24ea499f7301cdec8198a381b5428589deb8b231" +version = "2.0.0" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=codex%2Fsplit-scheduler-can-stop-report#ebba7bb902bdc1b0f377b6bb22c06ac305912674" dependencies = [ - "add_primitive", "chrono", "clap", - "core-relations", + "csv", "dyn-clone", + "egglog-add-primitive", + "egglog-ast", "egglog-bridge", + "egglog-core-relations", + "egglog-numeric-id", + "egglog-reports", "egraph-serialize", + "enum-map", "env_logger", - "hashbrown 0.15.4", + "hashbrown 0.16.0", "im-rc", "indexmap", - "lazy_static", "log", + "mimalloc", "num", - "numeric-id", "ordered-float", + "rayon", "rustc-hash", - "thiserror 2.0.12", + "serde_json", + "thiserror", "web-time", ] +[[package]] +name = "egglog-add-primitive" +version = "2.0.0" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=codex%2Fsplit-scheduler-can-stop-report#ebba7bb902bdc1b0f377b6bb22c06ac305912674" +dependencies = [ + "quote", + "syn 2.0.102", +] + +[[package]] +name = "egglog-ast" +version = "2.0.0" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=codex%2Fsplit-scheduler-can-stop-report#ebba7bb902bdc1b0f377b6bb22c06ac305912674" +dependencies = [ + "ordered-float", +] + [[package]] name = "egglog-bridge" -version = "0.1.0" -source = "git+https://github.com/egraphs-good/egglog-backend.git?rev=a0b98d3#a0b98d355601bff4d6c13cf4420a875a14b484e8" +version = "2.0.0" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=codex%2Fsplit-scheduler-can-stop-report#ebba7bb902bdc1b0f377b6bb22c06ac305912674" dependencies = [ "anyhow", - "core-relations", "dyn-clone", - "hashbrown 0.15.4", + "egglog-core-relations", + "egglog-numeric-id", + "egglog-reports", + "egglog-union-find", + "hashbrown 0.16.0", "indexmap", "log", "num-rational", - "numeric-id", "once_cell", - "petgraph 0.6.5", + "ordered-float", "rayon", "smallvec", - "thiserror 1.0.69", - "union-find", + "thiserror", + "web-time", +] + +[[package]] +name = "egglog-concurrency" +version = "2.0.0" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=codex%2Fsplit-scheduler-can-stop-report#ebba7bb902bdc1b0f377b6bb22c06ac305912674" +dependencies = [ + "arc-swap", + "bumpalo", + "egglog-numeric-id", + "rayon", + "smallvec", +] + +[[package]] +name = "egglog-core-relations" +version = "2.0.0" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=codex%2Fsplit-scheduler-can-stop-report#ebba7bb902bdc1b0f377b6bb22c06ac305912674" +dependencies = [ + "anyhow", + "bumpalo", + "crossbeam", + "crossbeam-queue", + "dashmap", + "dyn-clone", + "egglog-concurrency", + "egglog-numeric-id", + "egglog-reports", + "egglog-union-find", + "fixedbitset", + "hashbrown 0.16.0", + "indexmap", + "log", + "num", + "once_cell", + "rand 0.9.4", + "rayon", + "rustc-hash", + "smallvec", + "thiserror", "web-time", ] [[package]] name = "egglog-experimental" version = "0.1.0" -source = "git+https://github.com/egraphs-good/egglog-experimental.git?rev=202078f#202078fee12499f0160293f172546548f909b3e3" +source = "git+https://github.com/egraphs-good/egglog-experimental.git?rev=a0a911f#a0a911fe806ca919bccda14baececeaad6cd9819" dependencies = [ "egglog", + "egglog-ast", + "egglog-reports", "lazy_static", "log", "num", ] +[[package]] +name = "egglog-numeric-id" +version = "2.0.0" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=codex%2Fsplit-scheduler-can-stop-report#ebba7bb902bdc1b0f377b6bb22c06ac305912674" +dependencies = [ + "rayon", +] + +[[package]] +name = "egglog-reports" +version = "2.0.0" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=codex%2Fsplit-scheduler-can-stop-report#ebba7bb902bdc1b0f377b6bb22c06ac305912674" +dependencies = [ + "clap", + "hashbrown 0.16.0", + "indexmap", + "rustc-hash", + "serde", + "serde_json", + "web-time", +] + +[[package]] +name = "egglog-union-find" +version = "2.0.0" +source = "git+https://github.com/saulshanabrook/egg-smol.git?branch=codex%2Fsplit-scheduler-can-stop-report#ebba7bb902bdc1b0f377b6bb22c06ac305912674" +dependencies = [ + "crossbeam", + "egglog-concurrency", + "egglog-numeric-id", +] + [[package]] name = "egraph-serialize" -version = "0.2.0" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c31c5c0d7f760f9c1c84e21d73dcd3b3ce7a4770c27689f56a0db26e0f3e79ca" +checksum = "0977732fb537ace6f8c15ce160ebdda78b6502b4866d3b904e4fe752e2be4702" dependencies = [ "graphviz-rust", "indexmap", @@ -455,6 +538,26 @@ dependencies = [ "log", ] +[[package]] +name = "enum-map" +version = "2.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6866f3bfdf8207509a033af1a75a7b08abda06bbaaeae6669323fd5a097df2e9" +dependencies = [ + "enum-map-derive", +] + +[[package]] +name = "enum-map-derive" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f282cfdfe92516eb26c2af8589c274c7c17681f5ecc03c18255fe741c6aa64eb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.102", +] + [[package]] name = "env_filter" version = "0.1.3" @@ -501,10 +604,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] -name = "fixedbitset" -version = "0.4.2" +name = "find-msvc-tools" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" [[package]] name = "fixedbitset" @@ -514,9 +617,9 @@ checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" [[package]] name = "foldhash" -version = "0.1.5" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" [[package]] name = "generic-array" @@ -528,17 +631,6 @@ dependencies = [ "version_check", ] -[[package]] -name = "getrandom" -version = "0.2.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" -dependencies = [ - "cfg-if", - "libc", - "wasi 0.11.1+wasi-snapshot-preview1", -] - [[package]] name = "getrandom" version = "0.3.3" @@ -548,14 +640,14 @@ dependencies = [ "cfg-if", "libc", "r-efi", - "wasi 0.14.2+wasi-0.2.4", + "wasi", ] [[package]] name = "graphviz-rust" -version = "0.6.6" +version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27dafd1ac303e0dfb347a3861d9ac440859bab26ec2f534bbceb262ea492a1e0" +checksum = "dee83cefff83c5dd5f34c603145f4e8d478e70cc17873049b6a36eeaf37b250a" dependencies = [ "dot-generator", "dot-structures", @@ -563,7 +655,7 @@ dependencies = [ "into-attr-derive", "pest", "pest_derive", - "rand", + "rand 0.9.4", "tempfile", ] @@ -578,10 +670,17 @@ name = "hashbrown" version = "0.15.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5" + +[[package]] +name = "hashbrown" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d" dependencies = [ "allocator-api2", "equivalent", "foldhash", + "serde", ] [[package]] @@ -606,7 +705,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af1955a75fa080c677d3972822ec4bad316169ab1cfc6c257a942c2265dbe5fe" dependencies = [ "bitmaps", - "rand_core", + "rand_core 0.6.4", "rand_xoshiro", "sized-chunks", "typenum", @@ -721,7 +820,7 @@ dependencies = [ "ena", "itertools", "lalrpop-util", - "petgraph 0.7.1", + "petgraph", "pico-args", "regex", "regex-syntax", @@ -754,6 +853,15 @@ version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" +[[package]] +name = "libmimalloc-sys" +version = "0.1.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2892ae4ea6fa2cb7acb0e236a6880d39523239cd9089de71d220910ccc806790" +dependencies = [ + "cc", +] + [[package]] name = "linear-algebra-compiler" version = "0.1.0" @@ -792,6 +900,15 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "mimalloc" +version = "0.1.51" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebca48a43116bc25f18a61360f1be98412f50cc218f5e52c823086b999a4a21a" +dependencies = [ + "libmimalloc-sys", +] + [[package]] name = "new_debug_unreachable" version = "1.0.6" @@ -871,15 +988,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "numeric-id" -version = "0.1.0" -source = "git+https://github.com/egraphs-good/egglog-backend.git?rev=a0b98d3#a0b98d355601bff4d6c13cf4420a875a14b484e8" -dependencies = [ - "lazy_static", - "rayon", -] - [[package]] name = "once_cell" version = "1.21.3" @@ -894,12 +1002,12 @@ checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad" [[package]] name = "ordered-float" -version = "3.9.2" +version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1e1c390732d15f1d48471625cd92d154e66db2c56645e29a9cd26f4699f72dc" +checksum = "b7d950ca161dc355eaf28f82b11345ed76c6e1f6eb1f4f4479e0323b9e2fbd0e" dependencies = [ "num-traits", - "rand", + "rand 0.8.5", "serde", ] @@ -933,7 +1041,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "198db74531d58c70a361c42201efde7e2591e976d518caf7662a47dc5720e7b6" dependencies = [ "memchr", - "thiserror 2.0.12", + "thiserror", "ucd-trie", ] @@ -971,23 +1079,13 @@ dependencies = [ "sha2", ] -[[package]] -name = "petgraph" -version = "0.6.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" -dependencies = [ - "fixedbitset 0.4.2", - "indexmap", -] - [[package]] name = "petgraph" version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ - "fixedbitset 0.5.7", + "fixedbitset", "indexmap", ] @@ -1066,20 +1164,28 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ - "libc", - "rand_chacha", - "rand_core", + "rand_core 0.6.4", "serde", ] +[[package]] +name = "rand" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c5af06bb1b7d3216d91932aed5265164bf384dc89cd6ba05cf59a35f5f76ea" +dependencies = [ + "rand_chacha", + "rand_core 0.9.5", +] + [[package]] name = "rand_chacha" -version = "0.3.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.9.5", ] [[package]] @@ -1088,17 +1194,25 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.16", "serde", ] +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom", +] + [[package]] name = "rand_xoshiro" version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa" dependencies = [ - "rand_core", + "rand_core 0.6.4", ] [[package]] @@ -1259,6 +1373,12 @@ dependencies = [ "keccak", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "siphasher" version = "1.0.1" @@ -1328,7 +1448,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" dependencies = [ "fastrand", - "getrandom 0.3.3", + "getrandom", "once_cell", "rustix", "windows-sys", @@ -1344,33 +1464,13 @@ dependencies = [ "windows-sys", ] -[[package]] -name = "thiserror" -version = "1.0.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" -dependencies = [ - "thiserror-impl 1.0.69", -] - [[package]] name = "thiserror" version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" dependencies = [ - "thiserror-impl 2.0.12", -] - -[[package]] -name = "thiserror-impl" -version = "1.0.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.102", + "thiserror-impl", ] [[package]] @@ -1408,16 +1508,6 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" -[[package]] -name = "union-find" -version = "0.1.0" -source = "git+https://github.com/egraphs-good/egglog-backend.git?rev=a0b98d3#a0b98d355601bff4d6c13cf4420a875a14b484e8" -dependencies = [ - "concurrency", - "crossbeam", - "numeric-id", -] - [[package]] name = "utf8parse" version = "0.2.2" @@ -1440,12 +1530,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "wasi" -version = "0.11.1+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" - [[package]] name = "wasi" version = "0.14.2+wasi-0.2.4" diff --git a/case-study-linear-algebra-compiler/Cargo.toml b/case-study-linear-algebra-compiler/Cargo.toml index 67c8cd4..31f7a5d 100644 --- a/case-study-linear-algebra-compiler/Cargo.toml +++ b/case-study-linear-algebra-compiler/Cargo.toml @@ -7,7 +7,7 @@ edition = "2021" lalrpop-util = { version = "0.22.2", features = [ "lexer", ] } -egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental.git", rev = "202078f" } +egglog-experimental = { git = "https://github.com/egraphs-good/egglog-experimental.git", rev = "a0a911f" } regex = "1.11.1" [build-dependencies] diff --git a/case-study-linear-algebra-compiler/rust-toolchain.toml b/case-study-linear-algebra-compiler/rust-toolchain.toml index b8889a3..d72668b 100644 --- a/case-study-linear-algebra-compiler/rust-toolchain.toml +++ b/case-study-linear-algebra-compiler/rust-toolchain.toml @@ -1,2 +1,2 @@ [toolchain] -channel = "1.87.0" +channel = "1.91.0" diff --git a/case-study-linear-algebra-compiler/solution.patch b/case-study-linear-algebra-compiler/solution.patch new file mode 100644 index 0000000..d85055e --- /dev/null +++ b/case-study-linear-algebra-compiler/solution.patch @@ -0,0 +1,369 @@ +--- src/main.rs ++++ src/main.rs +@@ -1,5 +1,4 @@ + #![allow(dead_code)] +-#![allow(unreachable_code)] + #![allow(unused_imports)] + #![allow(unused_mut)] + #![allow(unused_variables)] +@@ -24,7 +23,7 @@ + use crate::util::*; + + fn egglog_program() -> &'static str { +- include_str!("defn.egg") ++ include_str!("defn.solution.egg") + } + + fn main() { +@@ -44,118 +43,85 @@ + } + }; + +- // Problem 1: run defn.egg +- // +- // Create a new experimental EGraph with `new_experimental_egraph()` and run +- // the egglog program from defn.egg using +- // `egraph.parse_and_run_program(None, program)`. ++ // Problem 1 + let mut egraph = new_experimental_egraph(); +- todo!("Problem 1"); ++ egraph.parse_and_run_program(None, egglog_program()).unwrap(); + +- // Problem 2: +- // +- // For each variable declaration in the input program, +- // bind the egglog variable x to either (MVar x) or (SVar x) +- // +- // For each matrix declaration, additionally insert its dimension +- // information into the "MatrixDim" relation in the E-graph: +- // +- // (relation MatrixDim (String i64 i64)) +- // +- // Hint: use `egraph.parse_and_run_program(None, &format!("..."))`. ++ // Problem 2 + for decl in core_bindings.declares.iter() { + let x = &decl.var; + if let Type::Matrix { nrows, ncols } = decl.ty { +- todo!("Problem 2") ++ egraph ++ .parse_and_run_program( ++ None, ++ &format!( ++ "(let {x} (MVar \"{x}\")) (MatrixDim \"{x}\" {nrows} {ncols})" ++ ), ++ ) ++ .unwrap(); + } else { +- todo!("Problem 2") ++ egraph ++ .parse_and_run_program(None, &format!("(let {x} (SVar \"{x}\"))")) ++ .unwrap(); + } + } + +- // Problem 3: +- // +- // For each variable assignment in the input program, +- // bind the egglog variable to its corresponding expression. +- // +- // We have provided [`to_egglog_expr`] in util.rs that converts a [`CoreExpr`] +- // to an egglog AST expression [`egglog::ast::Expr`]. +- // +- // To run a `let` command, build a +- // Command::Action(Action::Let(span!(), name, expr)) +- // and pass it to `egraph.run_program(vec![cmd])`. ++ // Problem 3 + for bind in core_bindings.bindings.iter() { + let var = &bind.var; +- let expr = &bind.expr; +- +- todo!("Problem 3") ++ let expr = to_egglog_expr(&bind.expr); ++ let cmd = Command::Action(Action::Let(span!(), var.clone(), expr)); ++ egraph.run_program(vec![cmd]).unwrap(); ++ } ++ ++ // Problem 6: saturate analysis before each optimization step (replaces Problem 4's `(run 20)`) ++ egraph ++ .parse_and_run_program( ++ None, ++ "(run-schedule (repeat 20 (saturate analysis) (run optimization)))", ++ ) ++ .unwrap(); ++ ++ // Problem 7: use FirstNScheduler for the optimization ruleset ++ let scheduler_id = egraph.add_scheduler(Box::new(FirstNScheduler { n: 3 })); ++ for _ in 0..20 { ++ loop { ++ let report = egraph.step_rules("analysis").unwrap(); ++ if report.can_stop { ++ break; ++ } ++ } ++ let report = egraph ++ .step_rules_with_scheduler(scheduler_id, "optimization") ++ .unwrap(); ++ if report.can_stop { ++ break; ++ } + } + +- // Problem 4: +- // +- // Now we have inserted all the ASTs and definitions. We will run our rules. +- // To start with, let's run our rules 20 times with `(run 20)`: +- // +- // egraph.parse_and_run_program(None, "(run 20)").unwrap(); +- +- todo!("Problem 4"); +- +- // Problem 5: +- // +- // Extract the optimized program using `DynamicCostModel` and `Extractor`. +- // +- // Steps: +- // a) Evaluate the output variable to get its (ArcSort, Value): +- // egraph.eval_expr(&exprs::var(&output.var)) +- // b) Compute an extractor: +- // Extractor::compute_costs_from_rootsorts(Some(vec![sort]), &egraph, DynamicCostModel) +- // c) Extract the best term: +- // extractor.extract_best(&egraph, &mut termdag, value) +- // d) Convert to CoreBindings using `termdag_to_bindings` ++ // Problem 5 + let output = core_bindings.bindings.last().unwrap(); +- let bindings: CoreBindings = todo!("Problem 5"); ++ let (sort, value) = egraph.eval_expr(&exprs::var(&output.var)).unwrap(); ++ let extractor = Extractor::compute_costs_from_rootsorts( ++ Some(vec![sort]), ++ &egraph, ++ DynamicCostModel, ++ ); ++ let mut termdag = TermDag::default(); ++ let (_cost, term_id) = extractor ++ .extract_best(&egraph, &mut termdag, value) ++ .unwrap(); ++ let bindings = termdag_to_bindings( ++ core_bindings.declares.clone(), ++ &termdag, ++ termdag.get(term_id), ++ ); + + // Print the optimized bindings + println!("{bindings}"); +- +- // Problem 6: +- // +- // Break down the rules into optimization rules and analysis rules by adding +- // `:ruleset optimization` and `:ruleset analysis` annotations in defn.egg, +- // and declaring the rulesets: +- // +- // (ruleset optimization) +- // (ruleset analysis) +- // +- // Then replace `(run 20)` above with a schedule that: +- // - Saturates the analysis rules before each optimization step +- // - Runs the optimization rules once per iteration +- // +- // Hint: use `(run-schedule (repeat 20 (saturate analysis) (run optimization)))`. +- // +- // Also update `egglog_program()` to include your updated defn file. +- +- // Problem 7: +- // +- // Fill in the blanks for the [`FirstNScheduler`] below. FirstNScheduler +- // applies at most `n` matches of a rule in each iteration. Compared +- // to the default scheduler, it allows the E-graph to grow more gently. +- // +- // Register the scheduler with the egraph: +- // +- // let scheduler_id = egraph.add_scheduler(Box::new(FirstNScheduler { n: 3 })); +- // +- // Then use `egraph.step_rules_with_scheduler(scheduler_id, "optimization")` +- // inside a loop to run the optimization ruleset with the scheduler. +- // +- // Update Problem 6's schedule to use this scheduler for optimization rules. + } + +-// Problem 7: +-// +-// Implement the `filter_matches` method. +-// FirstNScheduler should apply at most `n` matches of a rule per iteration: +-// - If there are <= n matches: apply all of them, return false (done). +-// - If there are > n matches: apply the first n, return true (more work remains). ++// Problem 7: FirstNScheduler implementation + #[derive(Clone)] + struct FirstNScheduler { + n: usize, +@@ -163,25 +129,29 @@ + + impl Scheduler for FirstNScheduler { + fn filter_matches(&mut self, _rule: &str, _ruleset: &str, matches: &mut Matches) -> bool { +- todo!("Problem 7") ++ if matches.match_size() <= self.n { ++ matches.choose_all(); ++ false ++ } else { ++ for i in 0..self.n { ++ matches.choose(i); ++ } ++ true ++ } + } + } + +-// Problem 8: +-// +-// We are going to define an alternative cost model that assigns the *depth* +-// of an AST as its cost. An extractor using this model will prefer shallower terms. +-// +-// The cost of a compound node is: max(children_costs) + enode_cost +-// The cost of a leaf (primitive) is: 0 ++// Problem 8: AstDepthCostModel + // +-// Use this cost model in the extractor from Problem 5 instead of DynamicCostModel. ++// The cost of a term is its AST depth (the length of the longest root-to-leaf path). ++// The extractor therefore prefers the shallowest equivalent term. + pub struct AstDepthCostModel; + + pub type C = usize; + impl CostModel for AstDepthCostModel { + fn fold(&self, _head: &str, children_cost: &[C], head_cost: C) -> C { +- todo!("Problem 8") ++ let max_child = children_cost.iter().copied().max().unwrap_or(0); ++ max_child.saturating_add(head_cost) + } + + fn enode_cost( +@@ -190,7 +160,7 @@ + _func: &egglog::Function, + _row: &egglog::FunctionRow, + ) -> C { +- todo!("Problem 8") ++ 1 + } + + fn container_cost( +@@ -200,7 +170,7 @@ + _value: egglog::Value, + element_costs: &[C], + ) -> C { +- todo!("Problem 8") ++ element_costs.iter().copied().max().unwrap_or(0) + } + + fn base_value_cost( +@@ -209,6 +179,6 @@ + _sort: &egglog::ArcSort, + _value: egglog::Value, + ) -> C { +- todo!("Problem 8") ++ 0 + } + } +--- /dev/null ++++ src/defn.solution.egg +@@ -0,0 +1,104 @@ ++(datatype Scalar ++ (Num i64) ++ (SVar String) ++ (SAdd Scalar Scalar) ++ (SMul Scalar Scalar) ++ (SSub Scalar Scalar) ++ (SDiv Scalar Scalar)) ++ ++(with-dynamic-cost ++ (datatype Matrix ++ (MVar String) ++ (Scale Scalar Matrix) ++ (MAdd Matrix Matrix) ++ (MMul Matrix Matrix) ++ ) ++) ++ ++(relation MatrixDim (String i64 i64)) ++ ++(ruleset optimization) ++(ruleset analysis) ++ ++;; Commutativity ++(rewrite (SAdd x y) (SAdd y x) :ruleset optimization) ++(rewrite (SMul x y) (SMul y x) :ruleset optimization) ++(rewrite (MAdd A B) (MAdd B A) :ruleset optimization) ++ ++;; Associativity ++(birewrite (SAdd (SAdd x y) z) (SAdd x (SAdd y z)) :ruleset optimization) ++(birewrite (SMul (SMul x y) z) (SMul x (SMul y z)) :ruleset optimization) ++(birewrite (Scale y (Scale x A)) (Scale (SMul x y) A) :ruleset optimization) ++(birewrite (MAdd (MAdd A B) C) (MAdd A (MAdd B C)) :ruleset optimization) ++(birewrite (MMul (MMul A B) C) (MMul A (MMul B C)) :ruleset optimization) ++ ++;; Distributivity ++(birewrite (SMul x (SAdd y z)) (SAdd (SMul x y) (SMul x z)) :ruleset optimization) ++(birewrite (MMul A (MAdd B C)) (MAdd (MMul A B) (MMul A C)) :ruleset optimization) ++(birewrite (Scale x (MAdd A B)) (MAdd (Scale x A) (Scale x B)) :ruleset optimization) ++(birewrite (Scale (SAdd x y) A) (MAdd (Scale x A) (Scale y A)) :ruleset optimization) ++ ++(rewrite (MMul (Scale a A) B) (Scale a (MMul A B)) :ruleset optimization) ++(rewrite (MMul A (Scale a B)) (Scale a (MMul A B)) :ruleset optimization) ++ ++;; Identity ++(rewrite (SAdd x (Num 0)) x :ruleset optimization) ++(rewrite (SMul x (Num 1)) x :ruleset optimization) ++(rewrite (SMul x (Num 0)) (Num 0) :ruleset optimization) ++ ++;; Analysis for dimensions of matrices ++(function nrows (Matrix) i64 :no-merge) ++(function ncols (Matrix) i64 :no-merge) ++ ++(rule ((MatrixDim m i j)) ++ ((set (nrows (MVar m)) i) ++ (set (ncols (MVar m)) j)) ++ :ruleset analysis) ++ ++(rule ( ++ (= e (Scale x A)) ++ (= i (nrows A)) (= j (ncols A)) ++) ( ++ (set (nrows e) i) ++ (set (ncols e) j) ++) :ruleset analysis) ++ ++(rule ( ++ (= e (MAdd A B)) ++ (= i (nrows A)) (= j (ncols A)) ++ (= i (nrows B)) (= j (ncols B)) ++) ( ++ (set (nrows e) i) ++ (set (ncols e) j) ++) :ruleset analysis) ++ ++(rule ( ++ (= e (MMul A B)) ++ (= i (nrows A)) (= k (ncols A)) ++ (= k (nrows B)) (= j (ncols B)) ++) ( ++ (set (nrows e) i) ++ (set (ncols e) j) ++) :ruleset analysis) ++ ++(rule ((= e (Scale a A)) ++ (= n (nrows A)) ++ (= m (ncols A)) ++) ( ++ (set-cost (Scale a A) (* n m)) ++) :ruleset analysis) ++ ++(rule ((= e (MAdd A B)) ++ (= n (nrows A)) ++ (= m (ncols A)) ++) ( ++ (set-cost (MAdd A B) (* n m)) ++) :ruleset analysis) ++ ++(rule ((= e (MMul A B)) ++ (= n (nrows A)) ++ (= k (ncols A)) ++ (= m (ncols B)) ++) ( ++ (set-cost (MMul A B) (* (* n m) k)) ++) :ruleset analysis) diff --git a/case-study-linear-algebra-compiler/src/defn.solution.egg b/case-study-linear-algebra-compiler/src/defn.solution.egg new file mode 100644 index 0000000..88d5356 --- /dev/null +++ b/case-study-linear-algebra-compiler/src/defn.solution.egg @@ -0,0 +1,104 @@ +(datatype Scalar + (Num i64) + (SVar String) + (SAdd Scalar Scalar) + (SMul Scalar Scalar) + (SSub Scalar Scalar) + (SDiv Scalar Scalar)) + +(with-dynamic-cost + (datatype Matrix + (MVar String) + (Scale Scalar Matrix) + (MAdd Matrix Matrix) + (MMul Matrix Matrix) + ) +) + +(relation MatrixDim (String i64 i64)) + +(ruleset optimization) +(ruleset analysis) + +;; Commutativity +(rewrite (SAdd x y) (SAdd y x) :ruleset optimization) +(rewrite (SMul x y) (SMul y x) :ruleset optimization) +(rewrite (MAdd A B) (MAdd B A) :ruleset optimization) + +;; Associativity +(birewrite (SAdd (SAdd x y) z) (SAdd x (SAdd y z)) :ruleset optimization) +(birewrite (SMul (SMul x y) z) (SMul x (SMul y z)) :ruleset optimization) +(birewrite (Scale y (Scale x A)) (Scale (SMul x y) A) :ruleset optimization) +(birewrite (MAdd (MAdd A B) C) (MAdd A (MAdd B C)) :ruleset optimization) +(birewrite (MMul (MMul A B) C) (MMul A (MMul B C)) :ruleset optimization) + +;; Distributivity +(birewrite (SMul x (SAdd y z)) (SAdd (SMul x y) (SMul x z)) :ruleset optimization) +(birewrite (MMul A (MAdd B C)) (MAdd (MMul A B) (MMul A C)) :ruleset optimization) +(birewrite (Scale x (MAdd A B)) (MAdd (Scale x A) (Scale x B)) :ruleset optimization) +(birewrite (Scale (SAdd x y) A) (MAdd (Scale x A) (Scale y A)) :ruleset optimization) + +(rewrite (MMul (Scale a A) B) (Scale a (MMul A B)) :ruleset optimization) +(rewrite (MMul A (Scale a B)) (Scale a (MMul A B)) :ruleset optimization) + +;; Identity +(rewrite (SAdd x (Num 0)) x :ruleset optimization) +(rewrite (SMul x (Num 1)) x :ruleset optimization) +(rewrite (SMul x (Num 0)) (Num 0) :ruleset optimization) + +;; Analysis for dimensions of matrices +(function nrows (Matrix) i64 :no-merge) +(function ncols (Matrix) i64 :no-merge) + +(rule ((MatrixDim m i j)) + ((set (nrows (MVar m)) i) + (set (ncols (MVar m)) j)) + :ruleset analysis) + +(rule ( + (= e (Scale x A)) + (= i (nrows A)) (= j (ncols A)) +) ( + (set (nrows e) i) + (set (ncols e) j) +) :ruleset analysis) + +(rule ( + (= e (MAdd A B)) + (= i (nrows A)) (= j (ncols A)) + (= i (nrows B)) (= j (ncols B)) +) ( + (set (nrows e) i) + (set (ncols e) j) +) :ruleset analysis) + +(rule ( + (= e (MMul A B)) + (= i (nrows A)) (= k (ncols A)) + (= k (nrows B)) (= j (ncols B)) +) ( + (set (nrows e) i) + (set (ncols e) j) +) :ruleset analysis) + +(rule ((= e (Scale a A)) + (= n (nrows A)) + (= m (ncols A)) +) ( + (set-cost (Scale a A) (* n m)) +) :ruleset analysis) + +(rule ((= e (MAdd A B)) + (= n (nrows A)) + (= m (ncols A)) +) ( + (set-cost (MAdd A B) (* n m)) +) :ruleset analysis) + +(rule ((= e (MMul A B)) + (= n (nrows A)) + (= k (ncols A)) + (= m (ncols B)) +) ( + (set-cost (MMul A B) (* (* n m) k)) +) :ruleset analysis) diff --git a/case-study-linear-algebra-compiler/src/main.rs b/case-study-linear-algebra-compiler/src/main.rs index cc1c154..c85d3fb 100644 --- a/case-study-linear-algebra-compiler/src/main.rs +++ b/case-study-linear-algebra-compiler/src/main.rs @@ -6,15 +6,15 @@ use std::io::Read; -use egglog_experimental::ast::Command; -use egglog_experimental::scheduler::Matches; -use egglog_experimental::{self as egglog, add_scheduler_builder, DynamicCostModel}; +use egglog_experimental::ast::{Action, Command}; +use egglog_experimental::extract::{CostModel, Extractor}; +use egglog_experimental::scheduler::{Matches, Scheduler}; use egglog_experimental::{ + self as egglog, + DynamicCostModel, TermDag, ast::Literal, - extract::CostModel, new_experimental_egraph, prelude::{exprs::*, *}, - scheduler::Scheduler, }; use crate::ast::{CoreBindings, Type}; @@ -45,6 +45,10 @@ fn main() { }; // Problem 1: run defn.egg + // + // Create a new experimental EGraph with `new_experimental_egraph()` and run + // the egglog program from defn.egg using + // `egraph.parse_and_run_program(None, program)`. let mut egraph = new_experimental_egraph(); todo!("Problem 1"); @@ -58,6 +62,7 @@ fn main() { // // (relation MatrixDim (String i64 i64)) // + // Hint: use `egraph.parse_and_run_program(None, &format!("..."))`. for decl in core_bindings.declares.iter() { let x = &decl.var; if let Type::Matrix { nrows, ncols } = decl.ty { @@ -70,29 +75,42 @@ fn main() { // Problem 3: // // For each variable assignment in the input program, - // bind the egglog variable x to its corresponding expression. + // bind the egglog variable to its corresponding expression. + // + // We have provided [`to_egglog_expr`] in util.rs that converts a [`CoreExpr`] + // to an egglog AST expression [`egglog::ast::Expr`]. // - // We have provided [`to_egglog_expr`] function that converts a [`CoreExpr`] - // to an egglog expression [`egglog::ast::Expr`] + // To run a `let` command, build a + // Command::Action(Action::Let(span!(), name, expr)) + // and pass it to `egraph.run_program(vec![cmd])`. for bind in core_bindings.bindings.iter() { let var = &bind.var; let expr = &bind.expr; - + todo!("Problem 3") } // Problem 4: // // Now we have inserted all the ASTs and definitions. We will run our rules. - // To start with, let's run our rules 20 times (`(run 20)`). + // To start with, let's run our rules 20 times with `(run 20)`: + // + // egraph.parse_and_run_program(None, "(run 20)").unwrap(); todo!("Problem 4"); // Problem 5: // - // Extract the optimized program using the `DynamicCostModel` from egglog_experimental. - // The extracted program is a directed acyclic graph (DAG) and has type [`TermDag`]` and [`Term`]. - // We have provided the method [`termdag_to_bindings`] to convert to [`CoreBindings`]. + // Extract the optimized program using `DynamicCostModel` and `Extractor`. + // + // Steps: + // a) Evaluate the output variable to get its (ArcSort, Value): + // egraph.eval_expr(&exprs::var(&output.var)) + // b) Compute an extractor: + // Extractor::compute_costs_from_rootsorts(Some(vec![sort]), &egraph, DynamicCostModel) + // c) Extract the best term: + // extractor.extract_best(&egraph, &mut termdag, value) + // d) Convert to CoreBindings using `termdag_to_bindings` let output = core_bindings.bindings.last().unwrap(); let bindings: CoreBindings = todo!("Problem 5"); @@ -101,29 +119,43 @@ fn main() { // Problem 6: // - // Break down the rules into optimization rules and analysis rules. - // Improve the schedule above with the more refined rulesets. + // Break down the rules into optimization rules and analysis rules by adding + // `:ruleset optimization` and `:ruleset analysis` annotations in defn.egg, + // and declaring the rulesets: + // + // (ruleset optimization) + // (ruleset analysis) + // + // Then replace `(run 20)` above with a schedule that: + // - Saturates the analysis rules before each optimization step + // - Runs the optimization rules once per iteration + // + // Hint: use `(run-schedule (repeat 20 (saturate analysis) (run optimization)))`. + // + // Also update `egglog_program()` to include your updated defn file. + + // Problem 7: + // + // Fill in the blanks for the [`FirstNScheduler`] below. FirstNScheduler + // applies at most `n` matches of a rule in each iteration. Compared + // to the default scheduler, it allows the E-graph to grow more gently. + // + // Register the scheduler with the egraph: + // + // let scheduler_id = egraph.add_scheduler(Box::new(FirstNScheduler { n: 3 })); + // + // Then use `egraph.step_rules_with_scheduler(scheduler_id, "optimization")` + // inside a loop to run the optimization ruleset with the scheduler. + // + // Update Problem 6's schedule to use this scheduler for optimization rules. } // Problem 7: // -// Fill in the blanks for the [`FirstNScheduler`]. FirstNScheduler -// applies at most `n` matches of a rule in each iteration. Compared -// to the default scheduler, it allows the E-graph to grow more gently. -// -// Add this scheduler to the global scheduler list in egglog-experimental with -// -// add_scheduler_builder("first-n".into(), Box::new(new_first_n_scheduler)); -// -// Update the `run-schedule` so that the optimization ruleset uses this scheduler. -pub fn new_first_n_scheduler(_egraph: &EGraph, exprs: &[egglog::ast::Expr]) -> Box { - assert!(exprs.len() == 1); - let egglog::ast::Expr::Lit(_, Literal::Int(n)) = exprs[0] else { - panic!("wrong arguments to first n scheduler"); - }; - Box::new(FirstNScheduler { n: n as usize }) -} - +// Implement the `filter_matches` method. +// FirstNScheduler should apply at most `n` matches of a rule per iteration: +// - If there are <= n matches: apply all of them, return false (done). +// - If there are > n matches: apply the first n, return true (more work remains). #[derive(Clone)] struct FirstNScheduler { n: usize, @@ -137,12 +169,13 @@ impl Scheduler for FirstNScheduler { // Problem 8: // -// We are going to define an alternative cost model that is not "sum of node costs". +// We are going to define an alternative cost model that assigns the *depth* +// of an AST as its cost. An extractor using this model will prefer shallower terms. // -// The cost model, named [`AstDepthCostModel`] assigns the depth of an AST as its cost, -// so an extractor using this cost model will extract a term with the smallest depth. +// The cost of a compound node is: max(children_costs) + enode_cost +// The cost of a leaf (primitive) is: 0 // -// Use this cost model in our extractor. +// Use this cost model in the extractor from Problem 5 instead of DynamicCostModel. pub struct AstDepthCostModel; pub type C = usize; @@ -160,7 +193,7 @@ impl CostModel for AstDepthCostModel { todo!("Problem 8") } - fn container_primitive( + fn container_cost( &self, _egraph: &EGraph, _sort: &egglog::ArcSort, @@ -170,7 +203,7 @@ impl CostModel for AstDepthCostModel { todo!("Problem 8") } - fn leaf_primitive( + fn base_value_cost( &self, _egraph: &EGraph, _sort: &egglog::ArcSort,