From 55fbff43ebd7b85fd1fad0f41519d58ccd3b9400 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Mon, 8 Dec 2025 20:58:33 +0100 Subject: [PATCH 01/26] feat(virtq): add packed virtual queue implementation Signed-off-by: Tomasz Andrzejak --- Cargo.lock | 551 +-- Cargo.toml | 1 + src/hyperlight_common/Cargo.toml | 25 +- src/hyperlight_common/benches/buffer_pool.rs | 176 + src/hyperlight_common/src/lib.rs | 5 +- src/hyperlight_common/src/virtq/access.rs | 136 + src/hyperlight_common/src/virtq/consumer.rs | 633 ++++ src/hyperlight_common/src/virtq/desc.rs | 326 ++ src/hyperlight_common/src/virtq/event.rs | 117 + src/hyperlight_common/src/virtq/mod.rs | 1134 +++++++ src/hyperlight_common/src/virtq/pool.rs | 1334 ++++++++ src/hyperlight_common/src/virtq/producer.rs | 790 +++++ src/hyperlight_common/src/virtq/ring.rs | 3169 ++++++++++++++++++ src/hyperlight_guest/src/error.rs | 7 +- 14 files changed, 8159 insertions(+), 245 deletions(-) create mode 100644 src/hyperlight_common/benches/buffer_pool.rs create mode 100644 src/hyperlight_common/src/virtq/access.rs create mode 100644 src/hyperlight_common/src/virtq/consumer.rs create mode 100644 src/hyperlight_common/src/virtq/desc.rs create mode 100644 src/hyperlight_common/src/virtq/event.rs create mode 100644 src/hyperlight_common/src/virtq/mod.rs create mode 100644 src/hyperlight_common/src/virtq/pool.rs create mode 100644 src/hyperlight_common/src/virtq/producer.rs create mode 100644 src/hyperlight_common/src/virtq/ring.rs diff --git a/Cargo.lock b/Cargo.lock index cffa145d2..2efdd92f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10,7 +10,7 @@ checksum = "59317f77929f0e679d39364702289274de2f0f0b22cbf50b2b8cff2169a0b27a" dependencies = [ "cpp_demangle", "fallible-iterator", - "gimli 0.33.0", + "gimli 0.33.1", "memmap2", "object", "rustc-demangle", @@ -69,21 +69,6 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" -[[package]] -name = "anstream" -version = "0.6.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" -dependencies = [ - "anstyle", - "anstyle-parse 0.2.7", - "anstyle-query", - "anstyle-wincon", - "colorchoice", - "is_terminal_polyfill", - "utf8parse", -] - [[package]] name = "anstream" version = "1.0.0" @@ -91,7 +76,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" dependencies = [ "anstyle", - "anstyle-parse 1.0.0", + "anstyle-parse", "anstyle-query", "anstyle-wincon", "colorchoice", @@ -101,18 +86,9 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" - -[[package]] -name = "anstyle-parse" -version = "0.2.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" -dependencies = [ - "utf8parse", -] +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" [[package]] name = "anstyle-parse" @@ -193,6 +169,12 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "atomic_refcell" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41e67cd8309bbd06cd603a9e693a784ac2e5d1e955f11286e355089fcab3047c" + [[package]] name = "autocfg" version = "1.5.0" @@ -238,7 +220,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "rustc-hash 2.1.1", + "rustc-hash 2.1.2", "shlex", "syn", ] @@ -324,9 +306,29 @@ dependencies = [ [[package]] name = "bumpalo" -version = 
"3.19.1" +version = "3.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510" +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" + +[[package]] +name = "bytemuck" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "bytes" @@ -379,14 +381,14 @@ dependencies = [ "serde_json", "syn", "tempfile", - "toml", + "toml 0.9.12+spec-1.1.0", ] [[package]] name = "cc" -version = "1.2.59" +version = "1.2.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7a4d3ec6524d28a329fc53654bbadc9bdd7b0431f5d65f1a56ffb28a1ee5283" +checksum = "43c5703da9466b66a946814e1adf53ea2c90f10063b86290cc9eb67ce3478a20" dependencies = [ "find-msvc-tools", "jobserver", @@ -405,9 +407,9 @@ dependencies = [ [[package]] name = "cfg-expr" -version = "0.20.6" +version = "0.20.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78cef5b5a1a6827c7322ae2a636368a573006b27cfa76c7ebd53e834daeaab6a" +checksum = "3c6b04e07d8080154ed4ac03546d9a2b303cc2fe1901ba0b35b301516e289368" dependencies = [ "smallvec", "target-lexicon", @@ -489,20 +491,20 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.58" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63be97961acde393029492ce0be7a1af7e323e6bae9511ebfac33751be5e6806" +checksum = "b193af5b67834b676abd72466a96c1024e6a6ad978a1f484bd90b85c94041351" dependencies = [ "clap_builder", ] [[package]] name = 
"clap_builder" -version = "4.5.58" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f13174bda5dfd69d7e947827e5af4b0f2f94a4a3ee92912fba07a66150f21e2" +checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f" dependencies = [ - "anstream 0.6.21", + "anstream", "anstyle", "clap_lex", "strsim", @@ -510,15 +512,15 @@ dependencies = [ [[package]] name = "clap_lex" -version = "1.0.0" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" +checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" [[package]] name = "colorchoice" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" [[package]] name = "constant_time_eq" @@ -555,19 +557,6 @@ dependencies = [ "libc", ] -[[package]] -name = "core-graphics" -version = "0.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "064badf302c3194842cf2c5d61f56cc88e54a759313879cdf03abdd27d0c3b97" -dependencies = [ - "bitflags 2.11.0", - "core-foundation", - "core-graphics-types", - "foreign-types", - "libc", -] - [[package]] name = "core-graphics-types" version = "0.2.0" @@ -581,13 +570,14 @@ dependencies = [ [[package]] name = "core-text" -version = "21.1.0" +version = "21.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fce32d657e17d6e4a8e70fe2ae6875218015f320620a78e5949d228bc76622bd" +checksum = "a593227b66cbd4007b2a050dfdd9e1d1318311409c8d600dc82ba1b15ca9c130" dependencies = [ "core-foundation", - "core-graphics 0.25.0", + "core-graphics", "foreign-types", + "libc", ] [[package]] @@ -813,9 +803,9 @@ checksum = 
"c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" [[package]] name = "env_filter" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a1c3cc8e57274ec99de65301228b537f1e4eedc1b8e0f9411c6caac8ae7308f" +checksum = "32e90c2accc4b07a8456ea0debdc2e7587bdd890680d71173a15d4ae604f6eef" dependencies = [ "log", "regex", @@ -827,7 +817,7 @@ version = "0.11.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0621c04f2196ac3f488dd583365b9c09be011a4ab8b9f37248ffcc8f6198b56a" dependencies = [ - "anstream 1.0.0", + "anstream", "anstyle", "env_filter", "jiff", @@ -852,9 +842,9 @@ dependencies = [ [[package]] name = "euclid" -version = "0.22.13" +version = "0.22.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df61bf483e837f88d5c2291dcf55c67be7e676b3a51acc48db3a7b163b91ed63" +checksum = "f1a05365e3b1c6d1650318537c7460c6923f1abdd272ad6842baa2b509957a06" dependencies = [ "num-traits", ] @@ -867,9 +857,9 @@ checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" [[package]] name = "fastrand" -version = "2.3.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" [[package]] name = "fdeflate" @@ -886,6 +876,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + [[package]] name = "flatbuffers" version = "25.12.19" @@ -976,9 +972,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.31" +version = "0.3.32" source 
= "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" dependencies = [ "futures-core", "futures-sink", @@ -986,15 +982,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" [[package]] name = "futures-executor" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" dependencies = [ "futures-core", "futures-task", @@ -1003,15 +999,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" [[package]] name = "futures-macro" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" dependencies = [ "proc-macro2", "quote", @@ -1020,21 +1016,21 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" [[package]] name = "futures-task" 
-version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" [[package]] name = "futures-util" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-core", "futures-io", @@ -1043,7 +1039,6 @@ dependencies = [ "futures-task", "memchr", "pin-project-lite", - "pin-utils", "slab", ] @@ -1071,6 +1066,21 @@ dependencies = [ "num-traits", ] +[[package]] +name = "generator" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52f04ae4152da20c76fe800fa48659201d5cf627c5149ca0b707b69d7eef6cf9" +dependencies = [ + "cc", + "cfg-if", + "libc", + "log", + "rustversion", + "windows-link", + "windows-result", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -1100,19 +1110,19 @@ checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", "libc", - "r-efi", + "r-efi 5.3.0", "wasip2", ] [[package]] name = "getrandom" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "139ef39800118c7683f2fd3c98c1b23c09ae076556b435f8e9064ae108aaeeec" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" dependencies = [ "cfg-if", "libc", - "r-efi", + "r-efi 6.0.0", "rand_core 0.10.0", "wasip2", "wasip3", @@ -1130,9 +1140,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.33.0" +version = "0.33.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf7f043f89559805f8c7cacc432749b2fa0d0a0a9ee46ce47164ed5ba7f126c" +checksum = 
"19e16c5073773ccf057c282be832a59ee53ef5ff98db3aeff7f8314f52ffc196" dependencies = [ "stable_deref_trait", ] @@ -1315,6 +1325,12 @@ dependencies = [ "serde_core", ] +[[package]] +name = "hashbrown" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" + [[package]] name = "heck" version = "0.5.0" @@ -1362,9 +1378,9 @@ checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" [[package]] name = "hyper" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" +checksum = "6299f016b246a94207e63da54dbe807655bf9e00044f73ded42c3ac5305fbcca" dependencies = [ "atomic-waker", "bytes", @@ -1376,7 +1392,6 @@ dependencies = [ "httparse", "itoa", "pin-project-lite", - "pin-utils", "smallvec", "tokio", "want", @@ -1424,8 +1439,19 @@ version = "0.14.0" dependencies = [ "anyhow", "arbitrary", + "atomic_refcell", + "bitflags 2.11.0", + "bytemuck", + "bytes", + "criterion", + "fixedbitset", "flatbuffers", + "hyperlight-testing", "log", + "loom", + "quickcheck", + "rand 0.9.2", + "smallvec", "spin", "thiserror", "tracing", @@ -1443,7 +1469,7 @@ dependencies = [ "proc-macro2", "quote", "syn", - "wasmparser 0.246.1", + "wasmparser 0.246.2", ] [[package]] @@ -1456,7 +1482,7 @@ dependencies = [ "quote", "syn", "tracing", - "wasmparser 0.246.1", + "wasmparser 0.246.2", ] [[package]] @@ -1652,12 +1678,13 @@ dependencies = [ [[package]] name = "icu_collections" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" +checksum = "2984d1cd16c883d7935b9e07e44071dca8d917fd52ecc02c04d5fa0b5a3f191c" dependencies = [ "displaydoc", "potential_utf", + "utf8_iter", "yoke", "zerofrom", "zerovec", @@ -1665,9 +1692,9 @@ 
dependencies = [ [[package]] name = "icu_locale_core" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" +checksum = "92219b62b3e2b4d88ac5119f8904c10f8f61bf7e95b640d25ba3075e6cac2c29" dependencies = [ "displaydoc", "litemap", @@ -1678,9 +1705,9 @@ dependencies = [ [[package]] name = "icu_normalizer" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" +checksum = "c56e5ee99d6e3d33bd91c5d85458b6005a22140021cc324cea84dd0e72cff3b4" dependencies = [ "icu_collections", "icu_normalizer_data", @@ -1692,15 +1719,15 @@ dependencies = [ [[package]] name = "icu_normalizer_data" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" +checksum = "da3be0ae77ea334f4da67c12f149704f19f81d1adf7c51cf482943e84a2bad38" [[package]] name = "icu_properties" -version = "2.1.2" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec" +checksum = "bee3b67d0ea5c2cca5003417989af8996f8604e34fb9ddf96208a033901e70de" dependencies = [ "icu_collections", "icu_locale_core", @@ -1712,15 +1739,15 @@ dependencies = [ [[package]] name = "icu_properties_data" -version = "2.1.2" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af" +checksum = "8e2bbb201e0c04f7b4b3e14382af113e17ba4f63e2c9d2ee626b720cbce54a14" [[package]] name = "icu_provider" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" +checksum = "139c4cf31c8b5f33d7e199446eff9c1e02decfc2f0eec2c8d71f65befa45b421" dependencies = [ "displaydoc", "icu_locale_core", @@ -1760,27 +1787,27 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.13.0" +version = "2.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" dependencies = [ "equivalent", - "hashbrown 0.16.1", + "hashbrown 0.17.0", "serde", "serde_core", ] [[package]] name = "ipnet" -version = "2.11.0" +version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" +checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" [[package]] name = "iri-string" -version = "0.7.10" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" +checksum = "25e659a4bb38e810ebc252e53b5814ff908a8c58c2a9ce2fae1bbec24cbf4e20" dependencies = [ "memchr", "serde", @@ -1812,9 +1839,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "jiff" @@ -1852,10 +1879,12 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.85" +version = "0.3.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c942ebf8e95485ca0d52d97da7c5a2c387d0e7f0ba4c35e93bfcaee045955b3" +checksum = "2e04e2ef80ce82e13552136fabeef8a5ed1f985a96805761cbb9a2c34e7664d9" dependencies = [ + "cfg-if", + "futures-util", "once_cell", 
"wasm-bindgen", ] @@ -1955,19 +1984,18 @@ dependencies = [ [[package]] name = "libredox" -version = "0.1.12" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d0b95e02c851351f877147b7deea7b1afb1df71b63aa5f8270716e0c5720616" +checksum = "e02f3bb43d335493c96bf3fd3a321600bf6bd07ed34bc64118e9293bdffea46c" dependencies = [ - "bitflags 2.11.0", "libc", ] [[package]] name = "libz-sys" -version = "1.1.23" +version = "1.1.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15d118bbf3771060e7311cc7bb0545b01d08a8b4a7de949198dec1fa0ca1c0f7" +checksum = "fc3a226e576f50782b3305c5ccf458698f92798987f551c6a02efe8276721e22" dependencies = [ "cc", "libc", @@ -2003,9 +2031,9 @@ checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" [[package]] name = "litemap" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" +checksum = "92daf443525c4cce67b150400bc2316076100ce0b3686209eb8cf3c31612e6f0" [[package]] name = "lock_api" @@ -2022,6 +2050,19 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + [[package]] name = "mach2" version = "0.4.3" @@ -2071,9 +2112,9 @@ checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" [[package]] name = "memmap2" -version = "0.9.9" +version = "0.9.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "744133e4a0e0a658e1374cf3bf8e415c4052a15a111acd372764c55b4177d490" +checksum = 
"714098028fe011992e1c3962653c96b2d578c4b4bce9036e15ff220319b1e0e3" dependencies = [ "libc", ] @@ -2224,9 +2265,9 @@ dependencies = [ [[package]] name = "num_enum" -version = "0.7.5" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1207a7e20ad57b847bbddc6776b968420d38292bbfe2089accff5e19e82454c" +checksum = "5d0bca838442ec211fa11de3a8b0e0e8f3a4522575b5c4c06ed722e005036f26" dependencies = [ "num_enum_derive", "rustversion", @@ -2234,9 +2275,9 @@ dependencies = [ [[package]] name = "num_enum_derive" -version = "0.7.5" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff32365de1b6743cb203b710788263c44a03de03802daf96092f2da4fe6ba4d7" +checksum = "680998035259dcfcafe653688bf2aa6d3e2dc05e98be6ab46afb089dc84f1df8" dependencies = [ "proc-macro2", "quote", @@ -2361,9 +2402,9 @@ checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" [[package]] name = "ordered-float" -version = "5.1.0" +version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f4779c6901a562440c3786d08192c6fbda7c1c2060edd10006b05ee35d10f2d" +checksum = "b7d950ca161dc355eaf28f82b11345ed76c6e1f6eb1f4f4479e0323b9e2fbd0e" dependencies = [ "num-traits", ] @@ -2509,7 +2550,7 @@ dependencies = [ "cairo-rs", "cairo-sys-rs", "cfg-if", - "core-graphics 0.24.0", + "core-graphics", "piet", "piet-cairo", "piet-coregraphics", @@ -2529,7 +2570,7 @@ dependencies = [ "associative-cache", "core-foundation", "core-foundation-sys", - "core-graphics 0.24.0", + "core-graphics", "core-text", "foreign-types", "piet", @@ -2565,18 +2606,18 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.10" +version = "1.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" +checksum = "f1749c7ed4bcaf4c3d0a3efc28538844fb29bcdd7d2b67b2be7e20ba861ff517" dependencies = [ 
"pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.10" +version = "1.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" +checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6" dependencies = [ "proc-macro2", "quote", @@ -2585,15 +2626,9 @@ dependencies = [ [[package]] name = "pin-project-lite" -version = "0.2.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" - -[[package]] -name = "pin-utils" -version = "0.1.0" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" [[package]] name = "pkg-config" @@ -2656,18 +2691,18 @@ checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" [[package]] name = "portable-atomic-util" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a9db96d7fa8782dd8c15ce32ffe8680bbd1e978a43bf51a34d39483540495f5" +checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3" dependencies = [ "portable-atomic", ] [[package]] name = "potential_utf" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" +checksum = "0103b1cef7ec0cf76490e969665504990193874ea05c85ff9bab8b911d0a0564" dependencies = [ "zerovec", ] @@ -2786,6 +2821,17 @@ version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" +[[package]] +name = "quickcheck" +version = "1.1.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "95c589f335db0f6aaa168a7cd27b1fc6920f5e1470c804f814d9cd6e62a0f70b" +dependencies = [ + "env_logger", + "log", + "rand 0.10.0", +] + [[package]] name = "quote" version = "1.0.45" @@ -2801,6 +2847,12 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + [[package]] name = "radix_trie" version = "0.2.1" @@ -2828,7 +2880,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc266eb313df6c5c09c1c7b1fbe2510961e5bcd3add930c1e31f7ed9da0feff8" dependencies = [ "chacha20", - "getrandom 0.4.1", + "getrandom 0.4.2", "rand_core 0.10.0", ] @@ -2949,9 +3001,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.9" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a96887878f22d7bad8a3b6dc5b7440e0ada9a245242924394987b21cf2210a4c" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" [[package]] name = "reqwest" @@ -3037,9 +3089,9 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustc-hash" -version = "2.1.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" [[package]] name = "rustc_version" @@ -3114,6 +3166,12 @@ dependencies = [ "sdd", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = 
"scopeguard" version = "1.2.0" @@ -3148,9 +3206,9 @@ checksum = "490dcfcbfef26be6800d11870ff2df8774fa6e86d047e3e8c8a76b25655e41ca" [[package]] name = "semver" -version = "1.0.27" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" +checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" [[package]] name = "serde" @@ -3197,9 +3255,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "1.0.4" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8bbf91e5a4d6315eee45e704372590b30e260ee83af6639d64557f51b067776" +checksum = "6662b5879511e06e8999a8a235d848113e942c9124f211511b16466ee2995f26" dependencies = [ "serde_core", ] @@ -3264,9 +3322,9 @@ dependencies = [ [[package]] name = "shellexpand" -version = "3.1.1" +version = "3.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b1fdf65dd6331831494dd616b30351c38e96e45921a27745cf98490458b90bb" +checksum = "32824fab5e16e6c4d86dc1ba84489390419a39f97699852b66480bb87d297ed8" dependencies = [ "dirs", ] @@ -3289,15 +3347,15 @@ dependencies = [ [[package]] name = "simd-adler32" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" +checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" [[package]] name = "sketches-ddsketch" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1e9a774a6c28142ac54bb25d25562e6bcf957493a184f15ad4eebccb23e410a" +checksum = "0c6f73aeb92d671e0cc4dca167e59b2deb6387c375391bc99ee743f326994a2b" [[package]] name = "slab" @@ -3375,14 +3433,14 @@ dependencies = [ [[package]] name = "system-deps" -version = "7.0.7" +version = "7.0.8" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c8f33736f986f16d69b6cb8b03f55ddcad5c41acc4ccc39dd88e84aa805e7f" +checksum = "396a35feb67335377e0251fcbc1092fc85c484bd4e3a7a54319399da127796e7" dependencies = [ "cfg-expr", "heck", "pkg-config", - "toml", + "toml 1.1.2+spec-1.1.0", "version-compare", ] @@ -3399,7 +3457,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" dependencies = [ "fastrand", - "getrandom 0.4.1", + "getrandom 0.4.2", "once_cell", "rustix", "windows-sys", @@ -3445,9 +3503,9 @@ dependencies = [ [[package]] name = "tinystr" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" +checksum = "c8323304221c2a851516f22236c5722a72eaa19749016521d6dff0824447d96d" dependencies = [ "displaydoc", "zerovec", @@ -3527,7 +3585,22 @@ dependencies = [ "toml_datetime 0.7.5+spec-1.1.0", "toml_parser", "toml_writer", - "winnow", + "winnow 0.7.15", +] + +[[package]] +name = "toml" +version = "1.1.2+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81f3d15e84cbcd896376e6730314d59fb5a87f31e4b038454184435cd57defee" +dependencies = [ + "indexmap", + "serde_core", + "serde_spanned", + "toml_datetime 1.1.1+spec-1.1.0", + "toml_parser", + "toml_writer", + "winnow 1.0.1", ] [[package]] @@ -3541,45 +3614,45 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "1.0.0+spec-1.1.0" +version = "1.1.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e" +checksum = "3165f65f62e28e0115a00b2ebdd37eb6f3b641855f9d636d3cd4103767159ad7" dependencies = [ "serde_core", ] [[package]] name = "toml_edit" -version = "0.25.4+spec-1.1.0" +version = "0.25.11+spec-1.1.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "7193cbd0ce53dc966037f54351dbbcf0d5a642c7f0038c382ef9e677ce8c13f2" +checksum = "0b59c4d22ed448339746c59b905d24568fcbb3ab65a500494f7b8c3e97739f2b" dependencies = [ "indexmap", - "toml_datetime 1.0.0+spec-1.1.0", + "toml_datetime 1.1.1+spec-1.1.0", "toml_parser", - "winnow", + "winnow 1.0.1", ] [[package]] name = "toml_parser" -version = "1.0.9+spec-1.1.0" +version = "1.1.2+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4" +checksum = "a2abe9b86193656635d2411dc43050282ca48aa31c2451210f4202550afb7526" dependencies = [ - "winnow", + "winnow 1.0.1", ] [[package]] name = "toml_writer" -version = "1.0.6+spec-1.1.0" +version = "1.1.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab16f14aed21ee8bfd8ec22513f7287cd4a91aa92e44edfe2c17ddd004e92607" +checksum = "756daf9b1013ebe47a8776667b466417e2d4c5679d441c26230efd9ef78692db" [[package]] name = "tonic" -version = "0.14.3" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a286e33f82f8a1ee2df63f4fa35c0becf4a85a0cb03091a15fd7bf0b402dc94a" +checksum = "fec7c61a0695dc1887c1b53952990f3ad2e3a31453e1f49f10e75424943a93ec" dependencies = [ "async-trait", "base64", @@ -3603,9 +3676,9 @@ dependencies = [ [[package]] name = "tonic-prost" -version = "0.14.3" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6c55a2d6a14174563de34409c9f92ff981d006f56da9c6ecd40d9d4a31500b0" +checksum = "a55376a0bbaa4975a3f10d009ad763d8f4108f067c7c2e74f3001fb49778d309" dependencies = [ "bytes", "prost", @@ -3867,15 +3940,15 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.23" +version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"537dd038a89878be9b64dd4bd1b260315c1bb94f4d784956b81e27a088d9a09e" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" [[package]] name = "unicode-segmentation" -version = "1.12.0" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" [[package]] name = "unicode-xid" @@ -3919,7 +3992,7 @@ version = "1.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5ac8b6f42ead25368cf5b098aeb3dc8a1a2c05a3eee8a9a1a68c640edbfc79d9" dependencies = [ - "getrandom 0.4.1", + "getrandom 0.4.2", "js-sys", "serde_core", "wasm-bindgen", @@ -4013,9 +4086,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.108" +version = "0.2.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64024a30ec1e37399cf85a7ffefebdb72205ca1c972291c51512360d90bd8566" +checksum = "0551fc1bb415591e3372d0bc4780db7e587d84e2a7e79da121051c5c4b89d0b0" dependencies = [ "cfg-if", "once_cell", @@ -4026,23 +4099,19 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.58" +version = "0.4.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70a6e77fd0ae8029c9ea0063f87c46fde723e7d887703d74ad2616d792e51e6f" +checksum = "03623de6905b7206edd0a75f69f747f134b7f0a2323392d664448bf2d3c5d87e" dependencies = [ - "cfg-if", - "futures-util", "js-sys", - "once_cell", "wasm-bindgen", - "web-sys", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.108" +version = "0.2.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "008b239d9c740232e71bd39e8ef6429d27097518b6b30bdf9086833bd5b6d608" +checksum = "7fbdf9a35adf44786aecd5ff89b4563a90325f9da0923236f6104e603c7e86be" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -4050,9 +4119,9 @@ dependencies = [ 
[[package]] name = "wasm-bindgen-macro-support" -version = "0.2.108" +version = "0.2.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5256bae2d58f54820e6490f9839c49780dff84c65aeab9e772f15d5f0e913a55" +checksum = "dca9693ef2bab6d4e6707234500350d8dad079eb508dca05530c85dc3a529ff2" dependencies = [ "bumpalo", "proc-macro2", @@ -4063,9 +4132,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.108" +version = "0.2.117" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f01b580c9ac74c8d8f0c0e4afb04eeef2acf145458e52c03845ee9cd23e3d12" +checksum = "39129a682a6d2d841b6c429d0c51e5cb0ed1a03829d8b3d1e69a011e62cb3d3b" dependencies = [ "unicode-ident", ] @@ -4106,9 +4175,9 @@ dependencies = [ [[package]] name = "wasmparser" -version = "0.246.1" +version = "0.246.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d991c35d79bf8336dc1cd632ed4aacf0dc5fac4bc466c670625b037b972bb9c" +checksum = "71cde4757396defafd25417cfb36aa3161027d06d865b0c24baaae229aac005d" dependencies = [ "bitflags 2.11.0", "hashbrown 0.16.1", @@ -4119,9 +4188,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.85" +version = "0.3.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "312e32e551d92129218ea9a2452120f4aabc03529ef03e4d0d82fb2780608598" +checksum = "cd70027e39b12f0849461e08ffc50b9cd7688d942c1c8e3c7b22273236b4dd0a" dependencies = [ "js-sys", "wasm-bindgen", @@ -4298,9 +4367,15 @@ dependencies = [ [[package]] name = "winnow" -version = "0.7.14" +version = "0.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" + +[[package]] +name = "winnow" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" +checksum = 
"09dac053f1cd375980747450bfc7250c264eaae0583872e845c0c7cd578872b5" dependencies = [ "memchr", ] @@ -4404,9 +4479,9 @@ dependencies = [ [[package]] name = "writeable" -version = "0.6.2" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" +checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" [[package]] name = "xi-unicode" @@ -4416,9 +4491,9 @@ checksum = "a67300977d3dc3f8034dae89778f502b6ba20b269527b3223ba59c0cf393bb8a" [[package]] name = "yoke" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" +checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" dependencies = [ "stable_deref_trait", "yoke-derive", @@ -4427,9 +4502,9 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" +checksum = "de844c262c8848816172cef550288e7dc6c7b7814b4ee56b3e1553f275f1858e" dependencies = [ "proc-macro2", "quote", @@ -4439,18 +4514,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.39" +version = "0.8.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db6d35d663eadb6c932438e763b262fe1a70987f9ae936e60158176d710cae4a" +checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.39" +version = "0.8.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4122cd3169e94605190e77839c9a40d40ed048d305bfdc146e7df40ab0f3e517" +checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" dependencies = [ "proc-macro2", "quote", @@ -4459,18 
+4534,18 @@ dependencies = [ [[package]] name = "zerofrom" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +checksum = "69faa1f2a1ea75661980b013019ed6687ed0e83d069bc1114e2cc74c6c04c4df" dependencies = [ "zerofrom-derive", ] [[package]] name = "zerofrom-derive" -version = "0.1.6" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +checksum = "11532158c46691caf0f2593ea8358fed6bbf68a0315e80aae9bd41fbade684a1" dependencies = [ "proc-macro2", "quote", @@ -4480,9 +4555,9 @@ dependencies = [ [[package]] name = "zerotrie" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" +checksum = "0f9152d31db0792fa83f70fb2f83148effb5c1f5b8c7686c3459e361d9bc20bf" dependencies = [ "displaydoc", "yoke", @@ -4491,9 +4566,9 @@ dependencies = [ [[package]] name = "zerovec" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" +checksum = "90f911cbc359ab6af17377d242225f4d75119aec87ea711a880987b18cd7b239" dependencies = [ "yoke", "zerofrom", @@ -4502,9 +4577,9 @@ dependencies = [ [[package]] name = "zerovec-derive" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" +checksum = "625dc425cab0dca6dc3c3319506e6593dcb08a9f387ea3b284dbd52a92c40555" dependencies = [ "proc-macro2", "quote", @@ -4513,6 +4588,6 @@ dependencies = [ [[package]] name = "zmij" -version = "1.0.20" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"4de98dfa5d5b7fef4ee834d0073d560c9ca7b6c46a71d058c48db7960f8cfaf7" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/Cargo.toml b/Cargo.toml index 6c3ce9624..2578a8920 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ hyperlight-component-macro = { path = "src/hyperlight_component_macro", version [workspace.lints.rust] unsafe_op_in_unsafe_fn = "deny" +unexpected_cfgs = { level = "warn", check-cfg = [ 'cfg(loom)' ] } # this will generate symbols for release builds # so is handy for debugging issues in release builds diff --git a/src/hyperlight_common/Cargo.toml b/src/hyperlight_common/Cargo.toml index 68ebcae71..c6f961b66 100644 --- a/src/hyperlight_common/Cargo.toml +++ b/src/hyperlight_common/Cargo.toml @@ -15,13 +15,19 @@ Hyperlight's components common to host and guest. workspace = true [dependencies] -flatbuffers = { version = "25.12.19", default-features = false } +arbitrary = {version = "1.4.2", optional = true, features = ["derive"]} anyhow = { version = "1.0.102", default-features = false } +atomic_refcell = "0.1.13" +bitflags = "2.10.0" +bytemuck = { version = "1.24", features = ["derive"] } +bytes = { version = "1", default-features = false } +fixedbitset = { version = "0.5.7", default-features = false } +flatbuffers = { version = "25.12.9", default-features = false } log = "0.4.29" -tracing = { version = "0.1.44", optional = true } -arbitrary = {version = "1.4.2", optional = true, features = ["derive"]} +smallvec = "1.15.1" spin = "0.10.0" thiserror = { version = "2.0.18", default-features = false } +tracing = { version = "0.1.44", optional = true } tracing-core = { version = "0.1.36", default-features = false } [features] @@ -33,6 +39,19 @@ mem_profile = [] std = ["thiserror/std", "log/std", "tracing/std"] nanvix-unstable = [] +[dev-dependencies] +criterion = "0.8.1" +hyperlight-testing = { workspace = true } +quickcheck = "1.0.3" +rand = "0.9.2" + +[target.'cfg(loom)'.dev-dependencies] +loom = "0.7" + 
[lib] bench = false # see https://bheisler.github.io/criterion.rs/book/faq.html#cargo-bench-gives-unrecognized-option-errors-for-valid-command-line-options doctest = false # reduce noise in test output + +[[bench]] +name = "buffer_pool" +harness = false diff --git a/src/hyperlight_common/benches/buffer_pool.rs b/src/hyperlight_common/benches/buffer_pool.rs new file mode 100644 index 000000000..614f160b0 --- /dev/null +++ b/src/hyperlight_common/benches/buffer_pool.rs @@ -0,0 +1,176 @@ +use std::hint::black_box; + +use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; +use hyperlight_common::virtq::{BufferPool, BufferProvider}; + +// Helper to create a pool for benchmarking +fn make_pool(size: usize) -> BufferPool { + let base = 0x10000; + BufferPool::::new(base, size).unwrap() +} + +// Single allocation performance +fn bench_alloc_single(c: &mut Criterion) { + let mut group = c.benchmark_group("alloc_single"); + + for size in [64, 128, 256, 512, 1024, 1500, 4096].iter() { + group.throughput(Throughput::Elements(1)); + group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + let alloc = pool.alloc(black_box(size)).unwrap(); + pool.dealloc(alloc).unwrap(); + }); + }); + } + group.finish(); +} + +// LIFO recycling +fn bench_alloc_lifo(c: &mut Criterion) { + let mut group = c.benchmark_group("alloc_lifo"); + + for size in [256, 1500, 4096].iter() { + group.throughput(Throughput::Elements(100)); + group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + for _ in 0..100 { + let alloc = pool.alloc(black_box(size)).unwrap(); + pool.dealloc(alloc).unwrap(); + } + }); + }); + } + group.finish(); +} + +// Fragmented allocation worst case +fn bench_alloc_fragmented(c: &mut Criterion) { + let mut group = c.benchmark_group("alloc_fragmented"); + + 
group.bench_function("fragmented_256", |b| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + + // Create fragmentation pattern: allocate many, free every other + let mut allocations = Vec::new(); + for _ in 0..100 { + allocations.push(pool.alloc(128).unwrap()); + } + for i in (0..100).step_by(2) { + pool.dealloc(allocations[i]).unwrap(); + } + + b.iter(|| { + let alloc = pool.alloc(black_box(256)).unwrap(); + pool.dealloc(alloc).unwrap(); + }); + }); + + group.finish(); +} + +// Realloc operations +fn bench_realloc(c: &mut Criterion) { + let mut group = c.benchmark_group("realloc"); + + // In-place grow (same tier) + group.bench_function("grow_inplace", |b| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + let alloc = pool.alloc(256).unwrap(); + let grown = pool.resize(alloc, black_box(512)).unwrap(); + pool.dealloc(grown).unwrap(); + }); + }); + + // Relocate grow (cross tier) + group.bench_function("grow_relocate", |b| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + let alloc = pool.alloc(128).unwrap(); + // Block in-place growth + let blocker = pool.alloc(256).unwrap(); + let grown = pool.resize(alloc, black_box(1500)).unwrap(); + pool.dealloc(grown).unwrap(); + pool.dealloc(blocker).unwrap(); + }); + }); + + // Shrink + group.bench_function("shrink", |b| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + let alloc = pool.alloc(1500).unwrap(); + let shrunk = pool.resize(alloc, black_box(256)).unwrap(); + pool.dealloc(shrunk).unwrap(); + }); + }); + + group.finish(); +} + +// Free performance +fn bench_free(c: &mut Criterion) { + let mut group = c.benchmark_group("free"); + + for size in [256, 1500, 4096].iter() { + group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + let alloc = pool.alloc(size).unwrap(); + pool.dealloc(black_box(alloc)).unwrap(); + }); + }); + } + + group.finish(); +} + +// 
Cursor optimization +fn bench_last_free_run(c: &mut Criterion) { + let mut group = c.benchmark_group("last_free_run"); + + // With cursor optimization (LIFO) + group.bench_function("lifo_pattern", |b| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + b.iter(|| { + let alloc = pool.alloc(256).unwrap(); + pool.dealloc(alloc).unwrap(); + let alloc2 = pool.alloc(black_box(256)).unwrap(); + pool.dealloc(alloc2).unwrap(); + }); + }); + + // Without cursor benefit (FIFO-like) + group.bench_function("fifo_pattern", |b| { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + let mut queue = Vec::new(); + + // Pre-fill queue + for _ in 0..10 { + queue.push(pool.alloc(256).unwrap()); + } + + b.iter(|| { + // FIFO: free oldest, allocate new + let old = queue.remove(0); + pool.dealloc(old).unwrap(); + queue.push(pool.alloc(black_box(256)).unwrap()); + }); + }); + + group.finish(); +} + +criterion_group!( + benches, + bench_alloc_single, + bench_alloc_lifo, + bench_alloc_fragmented, + bench_realloc, + bench_free, + bench_last_free_run, +); + +criterion_main!(benches); diff --git a/src/hyperlight_common/src/lib.rs b/src/hyperlight_common/src/lib.rs index eb4be220c..6e12d8cd4 100644 --- a/src/hyperlight_common/src/lib.rs +++ b/src/hyperlight_common/src/lib.rs @@ -18,7 +18,7 @@ limitations under the License. #![cfg_attr(not(any(test, debug_assertions)), warn(clippy::expect_used))] #![cfg_attr(not(any(test, debug_assertions)), warn(clippy::unwrap_used))] // We use Arbitrary during fuzzing, which requires std -#![cfg_attr(not(feature = "fuzzing"), no_std)] +#![cfg_attr(not(any(feature = "fuzzing", test, miri)), no_std)] extern crate alloc; @@ -50,3 +50,6 @@ pub mod vmem; /// ELF note types for embedding hyperlight version metadata in guest binaries. 
pub mod version_note; + +/// cbindgen:ignore +pub mod virtq; diff --git a/src/hyperlight_common/src/virtq/access.rs b/src/hyperlight_common/src/virtq/access.rs new file mode 100644 index 000000000..4daba3178 --- /dev/null +++ b/src/hyperlight_common/src/virtq/access.rs @@ -0,0 +1,136 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Memory Access Traits for Virtqueue Operations +//! +//! This module defines the [`MemOps`] trait that abstracts memory access patterns +//! required by the virtqueue implementation. This allows the virtqueue code to +//! work with different memory backends e.g. Host vs Guest. + +use bytemuck::Pod; + +/// Backend-provided memory access for virtqueue. +/// +/// # Safety +/// +/// Implementations must ensure that: +/// - Pointers passed to methods are valid for the duration of the call +/// - Memory ordering guarantees are upheld as documented +/// - Reads and writes don't cause undefined behavior (alignment, validity) +/// +/// [`RingProducer`]: super::RingProducer +/// [`RingConsumer`]: super::RingConsumer +pub trait MemOps { + type Error; + + /// Read bytes from physical memory. + /// + /// Used for reading buffer contents pointed to by descriptors. + /// + /// # Arguments + /// + /// * `addr` - Guest physical address to read from + /// * `dst` - Destination buffer to fill + /// + /// # Returns + /// + /// Number of bytes actually read (should equal `dst.len()` on success). 
+ /// + /// # Safety + /// + /// The caller must ensure `paddr` is valid and points to at least `dst.len()` bytes. + fn read(&self, addr: u64, dst: &mut [u8]) -> Result; + + /// Write bytes to physical memory. + /// + /// # Arguments + /// + /// * `addr` - address to write to + /// * `src` - Source data to write + /// + /// # Returns + /// + /// Number of bytes actually written (should equal `src.len()` on success). + /// + /// # Safety + /// + /// The caller must ensure `paddr` is valid and points to at least `src.len()` bytes. + fn write(&self, addr: u64, src: &[u8]) -> Result; + + /// Load a u16 with acquire semantics. + /// + /// # Safety + /// + /// `addr` must translate to a valid, aligned `AtomicU16` in shared memory. + fn load_acquire(&self, addr: u64) -> Result; + + /// Store a u16 with release semantics. + /// + /// # Safety + /// + /// `addr` must translate to a valid `AtomicU16` in shared memory. + fn store_release(&self, addr: u64, val: u16) -> Result<(), Self::Error>; + + /// Get a direct read-only slice into shared memory. + /// + /// # Safety + /// + /// The caller must ensure: + /// - `addr` is valid and points to at least `len` bytes. + /// - The memory region is not concurrently modified for the lifetime of + /// the returned slice. Caller must uphold this via protocol-level + /// synchronisation, e.g. descriptor ownership transfer. + /// + /// See also [`BufferOwner`]: super::BufferOwner + unsafe fn as_slice(&self, addr: u64, len: usize) -> Result<&[u8], Self::Error>; + + /// Get a direct mutable slice into shared memory. + /// + /// # Safety + /// + /// The caller must ensure: + /// - `addr` is valid and points to at least `len` bytes. + /// - No other references (shared or mutable) to this memory region exist + /// for the lifetime of the returned slice. + /// - Protocol-level synchronisation (e.g. descriptor ownership) guarantees + /// exclusive access. 
+ #[allow(clippy::mut_from_ref)] + unsafe fn as_mut_slice(&self, addr: u64, len: usize) -> Result<&mut [u8], Self::Error>; + + /// Read a Pod type at the given pointer. + /// + /// # Safety + /// + /// The caller must ensure `addr` is valid, aligned, and translates to initialized memory. + fn read_val(&self, addr: u64) -> Result { + let mut val = T::zeroed(); + let bytes = bytemuck::bytes_of_mut(&mut val); + + self.read(addr, bytes)?; + Ok(val) + } + + /// Write a Pod type at the given pointer. + /// + /// # Safety + /// + /// The caller ensures that `ptr` is valid. + fn write_val(&self, addr: u64, val: T) -> Result<(), Self::Error> { + let bytes = bytemuck::bytes_of(&val); + self.write(addr, bytes)?; + Ok(()) + } +} diff --git a/src/hyperlight_common/src/virtq/consumer.rs b/src/hyperlight_common/src/virtq/consumer.rs new file mode 100644 index 000000000..4c7bbc9ba --- /dev/null +++ b/src/hyperlight_common/src/virtq/consumer.rs @@ -0,0 +1,633 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +use alloc::vec; +use alloc::vec::Vec; + +use bytes::Bytes; + +use super::*; + +/// In-flight entry tracking. +/// +/// Stored per descriptor ID while the entry is being processed. +/// Tracks that a descriptor slot is occupied. +#[derive(Debug, Clone, Copy)] +pub(crate) struct Inflight; + +/// Data received from the producer, safely copied out of shared memory. +/// +/// Created by [`VirtqConsumer::poll`]. 
The entry data is eagerly copied +/// from shared memory during poll using [`MemOps::read`] (volatile on +/// the host side), so accessing it requires no unsafe code and no +/// references into shared memory. +#[derive(Debug, Clone)] +pub struct RecvEntry { + token: Token, + data: Bytes, +} + +impl RecvEntry { + /// The token identifying this entry. + pub fn token(&self) -> Token { + self.token + } + + /// The entry payload, copied from shared memory. + /// + /// Returns empty [`Bytes`] when the chain has no readable buffers. + pub fn data(&self) -> &Bytes { + &self.data + } + + /// Consume the entry, taking ownership of the data. + pub fn into_data(self) -> Bytes { + self.data + } +} + +/// A pending completion, either writable or ack-only. +/// +/// Created by [`VirtqConsumer::poll`]. Must be submitted back via +/// [`VirtqConsumer::complete`] to release the descriptor. +#[must_use = "dropping without completing leaks the descriptor"] +pub enum SendCompletion { + /// Completion with a writable buffer (for chains with a completion buffer). + /// Use the `write*` methods on [`WritableCompletion`] to fill the + /// response buffer. + Writable(WritableCompletion), + /// Ack-only completion (for chains with only entry buffers). No response buffer. + /// Just pass back to [`VirtqConsumer::complete`] to acknowledge. + Ack(AckCompletion), +} + +impl SendCompletion { + /// The token identifying this completion. + pub fn token(&self) -> Token { + match self { + SendCompletion::Writable(wc) => wc.token(), + SendCompletion::Ack(ack) => ack.token(), + } + } + + /// Number of bytes written (0 for Ack). + pub fn written(&self) -> usize { + match self { + SendCompletion::Writable(wc) => wc.written(), + SendCompletion::Ack(_) => 0, + } + } + + fn id(&self) -> u16 { + match self { + SendCompletion::Writable(wc) => wc.id, + SendCompletion::Ack(ack) => ack.id, + } + } +} + +/// A completion with a writable buffer for response data. 
+/// +/// # Example +/// +/// ```ignore +/// if let SendCompletion::Writable(mut wc) = completion { +/// wc.write_all(b"response data")?; +/// consumer.complete(wc.into())?; +/// } +/// ``` +#[must_use = "dropping without completing leaks the descriptor"] +pub struct WritableCompletion { + mem: M, + id: u16, + token: Token, + elem: BufferElement, + written: usize, +} + +impl WritableCompletion { + fn new(mem: M, id: u16, token: Token, elem: BufferElement) -> Self { + Self { + mem, + id, + token, + elem, + written: 0, + } + } + + /// The token identifying this completion. + pub fn token(&self) -> Token { + self.token + } + + /// Total capacity of the completion buffer in bytes. + pub fn capacity(&self) -> usize { + self.elem.len as usize + } + + /// Number of bytes written so far. + pub fn written(&self) -> usize { + self.written + } + + /// Remaining writable capacity. + pub fn remaining(&self) -> usize { + self.capacity() - self.written + } + + /// Write bytes into the completion buffer, returning how many were written. + /// + /// Appends at the current write position. If `buf` is larger than the + /// remaining capacity, writes as many bytes as will fit (partial write). + /// + /// Returns the number of bytes actually written. + /// + /// # Errors + /// + /// - [`VirtqError::MemoryWriteError`] - underlying MemOps write failed + pub fn write(&mut self, buf: &[u8]) -> Result { + let to_write = buf.len().min(self.remaining()); + if to_write == 0 { + return Ok(0); + } + + let addr = self.elem.addr + self.written as u64; + self.mem + .write(addr, &buf[..to_write]) + .map_err(|_| VirtqError::MemoryWriteError)?; + + self.written += to_write; + Ok(to_write) + } + + /// Write the entire buffer or return an error. 
+ /// + /// # Errors + /// + /// - [`VirtqError::CqeTooLarge`] - buf exceeds remaining capacity + /// - [`VirtqError::MemoryWriteError`] - underlying MemOps write failed + pub fn write_all(&mut self, buf: &[u8]) -> Result<(), VirtqError> { + if buf.len() > self.remaining() { + return Err(VirtqError::CqeTooLarge); + } + + let addr = self.elem.addr + self.written as u64; + self.mem + .write(addr, buf) + .map_err(|_| VirtqError::MemoryWriteError)?; + + self.written += buf.len(); + Ok(()) + } + + /// Reset the write cursor to the beginning. + /// + /// Previously written bytes in shared memory are not zeroed; the + /// `written` count is simply reset to 0. + pub fn reset(&mut self) { + self.written = 0; + } +} + +/// An ack-only completion for chains with no writable buffers. +/// +/// No response buffer - just pass back to [`VirtqConsumer::complete`] +/// to acknowledge processing and release the descriptor. +#[must_use = "dropping without completing leaks the descriptor"] +pub struct AckCompletion { + id: u16, + token: Token, +} + +impl AckCompletion { + fn new(id: u16, token: Token) -> Self { + Self { id, token } + } + + pub fn token(&self) -> Token { + self.token + } +} + +/// A high-level virtqueue consumer (device side). +/// +/// The consumer receives entries from the producer (driver), processes them, +/// and sends back completions. This is typically used on the device/host side. +/// +/// # Example +/// +/// ```ignore +/// let mut consumer = VirtqConsumer::new(layout, mem, notifier); +/// +/// // Poll and process +/// while let Some((entry, completion)) = consumer.poll(MAX_ENTRY_SIZE)? 
{ +/// let data = entry.data(); +/// match completion { +/// SendCompletion::Writable(mut wc) => { +/// let response = handle_request(data); +/// wc.write_all(&response)?; +/// consumer.complete(wc.into())?; +/// } +/// SendCompletion::Ack(ack) => { +/// consumer.complete(ack.into())?; +/// } +/// } +/// } +/// +/// // Or defer completions +/// let mut pending = Vec::new(); +/// while let Some((entry, completion)) = consumer.poll(MAX_ENTRY_SIZE)? { +/// pending.push((process(entry), completion)); +/// } +/// for (result, completion) in pending { +/// // ... complete later ... +/// consumer.complete(completion)?; +/// } +/// ``` +pub struct VirtqConsumer { + inner: RingConsumer, + notifier: N, + inflight: Vec>, +} + +impl VirtqConsumer { + /// Create a new virtqueue consumer. + /// + /// # Arguments + /// + /// * `layout` - Ring memory layout (descriptor table and event suppression addresses) + /// * `mem` - Memory operations implementation for reading/writing to shared memory + /// * `notifier` - Callback for notifying the driver (producer) about completions + pub fn new(layout: Layout, mem: M, notifier: N) -> Self { + let inner = RingConsumer::new(layout, mem); + let inflight = vec![None; inner.len()]; + + Self { + inner, + notifier, + inflight, + } + } + + /// Poll for a single incoming entry from the driver. + /// + /// Returns a [`RecvEntry`] (data copied from shared memory) and a + /// [`SendCompletion`] (writable handle or ack token). Both are + /// independent owned values with no borrow on the consumer. + /// + /// # Arguments + /// + /// * `max_entry` - Maximum entry size to accept. Entries larger than + /// this will return [`VirtqError::EntryTooLarge`]. 
+ /// + /// # Errors + /// + /// - [`VirtqError::EntryTooLarge`] - Entry data exceeds `max_entry` bytes + /// - [`VirtqError::BadChain`] - Descriptor chain format not recognized + /// - [`VirtqError::InvalidState`] - Descriptor ID collision (driver bug) + /// - [`VirtqError::MemoryReadError`] - Failed to read entry from shared memory + pub fn poll( + &mut self, + max_entry: usize, + ) -> Result)>, VirtqError> { + let (id, chain) = match self.inner.poll_available() { + Ok(x) => x, + Err(RingError::WouldBlock) => return Ok(None), + Err(e) => return Err(e.into()), + }; + + let (entry_elem, cqe_elem) = parse_chain(&chain)?; + + // Validate entry size + if let Some(ref elem) = entry_elem + && elem.len as usize > max_entry + { + return Err(VirtqError::EntryTooLarge); + } + + // Reserve the inflight slot + let slot = self + .inflight + .get_mut(id as usize) + .ok_or(VirtqError::InvalidState)?; + + if slot.is_some() { + return Err(VirtqError::InvalidState); + } + + *slot = Some(Inflight); + let token = Token(id); + + // Copy entry data from shared memory + let data = entry_elem + .map(|elem| self.read_element(&elem)) + .transpose()? + .unwrap_or_default(); + + let entry = RecvEntry { token, data }; + + // Build the appropriate completion handle + let completion = if let Some(elem) = cqe_elem { + let mem = self.inner.mem().clone(); + let cqe = WritableCompletion::new(mem, id, token, elem); + SendCompletion::Writable(cqe) + } else { + let ack = AckCompletion::new(id, token); + SendCompletion::Ack(ack) + }; + + Ok(Some((entry, completion))) + } + + /// Submit a completed entry back to the ring. + /// + /// Accepts both [`WritableCompletion`] (with written byte count) and + /// [`AckCompletion`] (zero-length) via the [`SendCompletion`] enum. + /// Clears the inflight slot and notifies the producer if event + /// suppression allows. 
+ pub fn complete(&mut self, completion: SendCompletion) -> Result<(), VirtqError> { + let id = completion.id(); + let written = completion.written() as u32; + + let slot = self + .inflight + .get_mut(id as usize) + .ok_or(VirtqError::InvalidState)?; + + if slot.is_none() { + return Err(VirtqError::InvalidState); + } + + *slot = None; + + if self.inner.submit_used_with_notify(id, written)? { + self.notifier.notify(QueueStats { + num_free: self.inner.num_free(), + num_inflight: self.inner.num_inflight(), + }); + } + + Ok(()) + } + + /// Get the current available cursor position. + /// + /// Returns the position where the next available descriptor will be + /// consumed. Useful for setting up descriptor-based event suppression. + #[inline] + pub fn avail_cursor(&self) -> RingCursor { + self.inner.avail_cursor() + } + + /// Get the current used cursor position. + /// + /// Returns the position where the next used descriptor will be written. + /// Useful for setting up descriptor-based event suppression. + #[inline] + pub fn used_cursor(&self) -> RingCursor { + self.inner.used_cursor() + } + + /// Configure event suppression for available buffer notifications. + /// + /// This controls when the driver (producer) signals us about new buffers: + /// + /// - [`SuppressionKind::Enable`] - Always signal (default) - good for latency + /// - [`SuppressionKind::Disable`] - Never signal - caller must poll + /// - [`SuppressionKind::Descriptor`] - Signal only at specific cursor position + /// + /// # Example: Polling Mode + /// ```ignore + /// consumer.set_avail_suppression(SuppressionKind::Disable)?; + /// loop { + /// while let Some((entry, completion)) = consumer.poll(1024)? { + /// process(entry, completion); + /// } + /// // ... do other work ... 
+ /// } + /// ``` + pub fn set_avail_suppression(&mut self, kind: SuppressionKind) -> Result<(), VirtqError> { + match kind { + SuppressionKind::Enable => self.inner.enable_avail_notifications()?, + SuppressionKind::Disable => self.inner.disable_avail_notifications()?, + SuppressionKind::Descriptor(cursor) => self + .inner + .enable_avail_notifications_desc(cursor.head(), cursor.wrap())?, + } + Ok(()) + } + + /// Read a buffer element from shared memory into `Bytes`. + fn read_element(&self, elem: &BufferElement) -> Result { + let mut buf = vec![0u8; elem.len as usize]; + self.inner + .mem() + .read(elem.addr, &mut buf) + .map_err(|_| VirtqError::MemoryReadError)?; + + Ok(Bytes::from(buf)) + } +} + +/// Parse a descriptor chain into entry/completion buffer elements. +/// +/// Returns `(entry_element, completion_element)`. +fn parse_chain( + chain: &BufferChain, +) -> Result<(Option, Option), VirtqError> { + let r = chain.readables(); + let w = chain.writables(); + + match (r.len(), w.len()) { + (1, 1) => Ok((Some(r[0]), Some(w[0]))), + (0, 1) => Ok((None, Some(w[0]))), + (1, 0) => Ok((Some(r[0]), None)), + _ => Err(VirtqError::BadChain), + } +} + +impl From> for SendCompletion { + fn from(wc: WritableCompletion) -> Self { + SendCompletion::Writable(wc) + } +} + +impl From for SendCompletion { + fn from(ack: AckCompletion) -> Self { + SendCompletion::Ack(ack) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::virtq::ring::tests::make_ring; + use crate::virtq::test_utils::*; + + #[test] + fn test_write_only_entry_is_empty() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().completion(16).build().unwrap(); + producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert!(entry.data().is_empty()); + assert!(matches!(completion, SendCompletion::Writable(_))); + + if let SendCompletion::Writable(mut wc) = completion { + 
wc.write_all(b"response").unwrap(); + consumer.complete(wc.into()).unwrap(); + } + } + + #[test] + fn test_read_only_ack_completion() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(16).build().unwrap(); + se.write_all(b"hello").unwrap(); + producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.data().as_ref(), b"hello"); + assert!(matches!(completion, SendCompletion::Ack(_))); + + consumer.complete(completion).unwrap(); + } + + #[test] + fn test_readwrite_round_trip() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(32).completion(64).build().unwrap(); + se.write_all(b"hello world").unwrap(); + producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.data().as_ref(), b"hello world"); + + if let SendCompletion::Writable(mut wc) = completion { + assert_eq!(wc.capacity(), 64); + assert_eq!(wc.written(), 0); + assert_eq!(wc.remaining(), 64); + wc.write_all(b"response").unwrap(); + assert_eq!(wc.written(), 8); + assert_eq!(wc.remaining(), 56); + consumer.complete(wc.into()).unwrap(); + } else { + panic!("expected Writable completion for entry+completion chain"); + } + } + + #[test] + fn test_writable_partial_write() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().completion(8).build().unwrap(); + producer.submit(se).unwrap(); + + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + + if let SendCompletion::Writable(mut wc) = completion { + let n = wc.write(b"hello world!").unwrap(); + assert_eq!(n, 8); + assert_eq!(wc.remaining(), 0); + consumer.complete(wc.into()).unwrap(); + } else { + panic!("expected Writable"); + } + } + + #[test] + fn 
test_writable_write_all_too_large() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().completion(4).build().unwrap(); + producer.submit(se).unwrap(); + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + + if let SendCompletion::Writable(mut wc) = completion { + let err = wc.write_all(b"too long").unwrap_err(); + assert!(matches!(err, VirtqError::CqeTooLarge)); + } else { + panic!("expected Writable"); + } + } + + #[test] + fn test_writable_reset() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().completion(16).build().unwrap(); + producer.submit(se).unwrap(); + + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + + if let SendCompletion::Writable(mut wc) = completion { + wc.write_all(b"first").unwrap(); + assert_eq!(wc.written(), 5); + wc.reset(); + assert_eq!(wc.written(), 0); + assert_eq!(wc.remaining(), 16); + wc.write_all(b"second").unwrap(); + assert_eq!(wc.written(), 6); + consumer.complete(wc.into()).unwrap(); + } else { + panic!("expected Writable"); + } + } + + #[test] + fn test_multiple_pending_completions() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let se1 = producer.chain().completion(16).build().unwrap(); + producer.submit(se1).unwrap(); + let se2 = producer.chain().completion(16).build().unwrap(); + producer.submit(se2).unwrap(); + + let (_e1, c1) = consumer.poll(1024).unwrap().unwrap(); + let (_e2, c2) = consumer.poll(1024).unwrap().unwrap(); + + // Complete in reverse order + consumer.complete(c2).unwrap(); + consumer.complete(c1).unwrap(); + } + + #[test] + fn test_entry_into_data() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(16).build().unwrap(); + se.write_all(b"abc").unwrap(); + 
producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + let data = entry.into_data(); + assert_eq!(data.as_ref(), b"abc"); + consumer.complete(completion).unwrap(); + } +} diff --git a/src/hyperlight_common/src/virtq/desc.rs b/src/hyperlight_common/src/virtq/desc.rs new file mode 100644 index 000000000..57e5efb12 --- /dev/null +++ b/src/hyperlight_common/src/virtq/desc.rs @@ -0,0 +1,326 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Virtqueue Descriptor Types +//! +//! This module defines the descriptor format for packed virtqueues as specified +//! in VIRTIO 1.1+. Each descriptor represents a memory buffer in a scatter-gather +//! list that the device will read from or write to. + +use bitflags::bitflags; +use bytemuck::{Pod, Zeroable}; + +use super::MemOps; + +bitflags! { + /// Descriptor flags as defined by VIRTIO specification. + #[repr(transparent)] + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] + pub struct DescFlags: u16 { + /// This marks a buffer as continuing via the next field. + const NEXT = 1 << 0; + /// This marks a buffer as device write-only (otherwise device read-only). + const WRITE = 1 << 1; + /// This means the buffer contains a list of buffer descriptors (unsupported here). + const INDIRECT = 1 << 2; + /// Available flag for packed virtqueue wrap counter. + const AVAIL = 1 << 7; + /// Used flag for packed virtqueue wrap counter. 
+        const USED = 1 << 15;
+    }
+}
+
+#[repr(C)]
+#[derive(Clone, Copy, Debug, Pod, Zeroable, PartialEq, Eq, Hash)]
+pub struct Descriptor {
+    /// Physical address of the buffer.
+    pub addr: u64,
+    /// Length of the buffer in bytes.
+    /// For used descriptors, this contains bytes written by device.
+    pub len: u32,
+    /// Buffer ID - used to correlate completions with submissions.
+    /// All descriptors in a chain share the same ID.
+    pub id: u16,
+    /// Flags (NEXT, WRITE, INDIRECT, AVAIL, USED).
+    pub flags: u16,
+}
+
+const _: () = assert!(core::mem::size_of::<Descriptor>() == 16);
+const _: () = assert!(Descriptor::ADDR_OFFSET == 0);
+const _: () = assert!(Descriptor::LEN_OFFSET == 8);
+const _: () = assert!(Descriptor::ID_OFFSET == 12);
+const _: () = assert!(Descriptor::FLAGS_OFFSET == 14);
+
+impl Descriptor {
+    pub const SIZE: usize = core::mem::size_of::<Self>();
+    pub const ALIGN: usize = core::mem::align_of::<Self>();
+
+    pub const ADDR_OFFSET: usize = core::mem::offset_of!(Self, addr);
+    pub const LEN_OFFSET: usize = core::mem::offset_of!(Self, len);
+    pub const ID_OFFSET: usize = core::mem::offset_of!(Self, id);
+    pub const FLAGS_OFFSET: usize = core::mem::offset_of!(Self, flags);
+
+    pub fn new(addr: u64, len: u32, id: u16, flags: DescFlags) -> Self {
+        Self {
+            addr,
+            len,
+            id,
+            flags: flags.bits(),
+        }
+    }
+
+    /// Get flags as a [`DescFlags`] bitfield.
+    #[inline]
+    pub fn flags(&self) -> DescFlags {
+        DescFlags::from_bits_truncate(self.flags)
+    }
+
+    /// Did the guest mark this descriptor in the current guest round?
+    #[inline]
+    pub fn is_avail(&self, wrap: bool) -> bool {
+        let f = self.flags();
+        let avail = f.contains(DescFlags::AVAIL);
+        let used = f.contains(DescFlags::USED);
+        avail == wrap && used != wrap
+    }
+
+    /// Did the host mark this descriptor used in the current host round?
+ #[inline] + pub fn is_used(&self, wrap: bool) -> bool { + let f = self.flags(); + let avail = f.contains(DescFlags::AVAIL); + let used = f.contains(DescFlags::USED); + avail == wrap && used == wrap + } + + /// Is this descriptor writeable by the device? + #[inline] + pub fn is_writeable(&self) -> bool { + self.flags().contains(DescFlags::WRITE) + } + + /// Does this descriptor point to a next descriptor in the chain? + #[inline] + pub fn is_next(&self) -> bool { + self.flags().contains(DescFlags::NEXT) + } + + /// Mark descriptor as available according to the driver's wrap bit. + /// As per the packed-virtqueue description: + /// - set AVAIL bit to `driver_wrap` + /// - set USED bit to `!driver_wrap` (inverse) + #[inline] + pub fn mark_avail(&mut self, wrap: bool) { + if wrap { + self.flags |= DescFlags::AVAIL.bits(); + self.flags &= !DescFlags::USED.bits(); + } else { + self.flags &= !DescFlags::AVAIL.bits(); + self.flags |= DescFlags::USED.bits(); + } + } + + /// Mark descriptor as used according to the device's wrap bit. + /// As per spec: set both USED and AVAIL bits to match device_wrap + #[inline] + pub fn mark_used(&mut self, wrap: bool) { + if wrap { + self.flags |= DescFlags::USED.bits(); + self.flags |= DescFlags::AVAIL.bits(); + } else { + self.flags &= !DescFlags::USED.bits(); + self.flags &= !DescFlags::AVAIL.bits(); + } + } + + /// Read a descriptor from memory with acquire semantics for flags + /// This is the primary synchronization point for consuming descriptors. 
+    ///
+    /// # Invariant
+    ///
+    /// The caller must ensure that `base` is valid for reads of Descriptor
+    pub fn read_acquire<M: MemOps>(mem: &M, addr: u64) -> Result<Self, M::Error> {
+        let flags = mem.load_acquire(addr + Self::FLAGS_OFFSET as u64)?;
+        let addr_val: u64 = mem.read_val(addr + Self::ADDR_OFFSET as u64)?;
+        let len: u32 = mem.read_val(addr + Self::LEN_OFFSET as u64)?;
+        let id: u16 = mem.read_val(addr + Self::ID_OFFSET as u64)?;
+
+        Ok(Self {
+            addr: addr_val,
+            len,
+            id,
+            flags,
+        })
+    }
+
+    /// Write a descriptor to memory with release semantics for flags at the given base pointer
+    ///
+    /// This is the primary synchronization point for publishing descriptors.
+    ///
+    /// # Invariant
+    ///
+    /// The caller must ensure that `base` is valid for writes of Descriptor
+    pub fn write_release<M: MemOps>(&self, mem: &M, addr: u64) -> Result<(), M::Error> {
+        mem.write_val(addr + Self::ADDR_OFFSET as u64, self.addr)?;
+        mem.write_val(addr + Self::LEN_OFFSET as u64, self.len)?;
+        mem.write_val(addr + Self::ID_OFFSET as u64, self.id)?;
+        // Flags written last with release semantics
+        mem.store_release(addr + Self::FLAGS_OFFSET as u64, self.flags)?;
+        Ok(())
+    }
+}
+
+/// A table of descriptors stored in shared memory.
+#[derive(Debug, Clone, Copy)]
+pub struct DescTable {
+    base_addr: u64,
+    size: usize,
+}
+
+impl DescTable {
+    pub const DEFAULT_LEN: usize = 256;
+
+    /// Create a descriptor table from shared memory.
+ /// + /// # Safety + /// + /// - `base` must be valid for reads and writes of `size` descriptors + /// - `base` must be properly aligned for `Descriptor` + /// - `size` must not exceed `u16::MAX` + /// - memory must remain valid for the lifetime of this table + pub unsafe fn from_raw_parts(base_addr: u64, size: usize) -> Self { + assert!(base_addr.is_multiple_of(Descriptor::ALIGN as u64)); + assert!(size <= u16::MAX as usize); + + Self { base_addr, size } + } + + /// Get view into descriptor at index or None if idx is out of bounds + pub fn desc_addr(&self, idx: u16) -> Option { + if idx >= self.size as u16 { + return None; + } + + Some(self.base_addr + (idx as u64 * Descriptor::SIZE as u64)) + } + + /// Get number of descriptors in table + pub fn len(&self) -> usize { + self.size + } + + /// Is the descriptor table empty? + pub fn is_empty(&self) -> bool { + self.size == 0 + } + + pub const fn default_len() -> usize { + Self::DEFAULT_LEN + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn mark_avail_sets_bits_correctly_wrap_true() { + let mut d = Descriptor::zeroed(); + d.flags = DescFlags::WRITE.bits() | DescFlags::NEXT.bits(); + d.mark_avail(true); + let f = d.flags(); + assert!(f.contains(DescFlags::AVAIL)); + assert!(!f.contains(DescFlags::USED)); + assert!(f.contains(DescFlags::WRITE)); + assert!(f.contains(DescFlags::NEXT)); + } + + #[test] + fn mark_avail_sets_bits_correctly_wrap_false() { + let mut d = Descriptor::zeroed(); + d.mark_avail(false); + let f = d.flags(); + assert!(!f.contains(DescFlags::AVAIL)); + assert!(f.contains(DescFlags::USED)); + } + + #[test] + fn mark_used_sets_both_bits_match_wrap_true() { + let mut d = Descriptor::zeroed(); + d.mark_used(true); + let f = d.flags(); + assert!(f.contains(DescFlags::AVAIL)); + assert!(f.contains(DescFlags::USED)); + } + + #[test] + fn mark_used_sets_both_bits_match_wrap_false() { + let mut d = Descriptor::zeroed(); + d.mark_used(false); + let f = d.flags(); + 
assert!(!f.contains(DescFlags::AVAIL)); + assert!(!f.contains(DescFlags::USED)); + } + + #[test] + fn is_avail_and_is_used() { + let mut d = Descriptor::zeroed(); + d.mark_avail(true); + assert!(d.is_avail(true)); + assert!(!d.is_used(true)); + d.mark_used(true); + assert!(d.is_used(true)); + assert!(!d.is_avail(true)); + d.mark_avail(false); + assert!(d.is_avail(false)); + assert!(!d.is_used(false)); + d.mark_used(false); + assert!(d.is_used(false)); + assert!(!d.is_avail(false)); + } + + #[test] + fn writeable_and_next_helpers() { + let mut d = Descriptor::zeroed(); + d.flags = (DescFlags::WRITE | DescFlags::NEXT).bits(); + assert!(d.is_writeable()); + assert!(d.is_next()); + d.flags = 0; + assert!(!d.is_writeable()); + assert!(!d.is_next()); + } + + #[test] + fn avail_then_used_wrap_flip_sequence() { + let mut d = Descriptor::zeroed(); + d.mark_avail(true); + assert!(d.is_avail(true)); + d.mark_used(false); + assert!(d.is_used(false)); + assert!(!d.is_avail(false)); + d.mark_avail(true); + assert!(d.is_avail(true)); + } + + #[test] + fn desc_table_get_out_of_bounds() { + let mut vec = vec![Descriptor::zeroed(); 4]; + let ptr = vec.as_mut_ptr(); + let table = unsafe { DescTable::from_raw_parts(ptr.addr() as u64, 4) }; + assert!(table.desc_addr(3).is_some()); + assert!(table.desc_addr(4).is_none()); + } +} diff --git a/src/hyperlight_common/src/virtq/event.rs b/src/hyperlight_common/src/virtq/event.rs new file mode 100644 index 000000000..3b0677264 --- /dev/null +++ b/src/hyperlight_common/src/virtq/event.rs @@ -0,0 +1,117 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+//! Event Suppression for Virtqueue Notifications
+//!
+//! This module implements the event suppression mechanism from VIRTIO 1.1+
+//! that allows fine-grained control over when notifications are sent between
+//! driver and device.
+
+use bitflags::bitflags;
+use bytemuck::{Pod, Zeroable};
+
+use super::MemOps;
+
+bitflags! {
+    #[repr(transparent)]
+    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+    pub struct EventFlags: u16 {
+        /// Enable notifications (always notify).
+        const ENABLE = 0x0;
+        /// Disable notifications (never notify).
+        const DISABLE = 0x1;
+        /// Notify only at specific descriptor (EVENT_IDX mode).
+        const DESC = 0x2;
+    }
+}
+
+/// Event suppression structure for controlling notifications.
+#[repr(C)]
+#[derive(Clone, Copy, Debug, Pod, Zeroable, PartialEq, Eq, Hash)]
+pub struct EventSuppression {
+    // bits 0-14: offset, bit 15: wrap
+    pub off_wrap: u16,
+    // bits 0-1: flags, bits 2-15: reserved
+    pub flags: u16,
+}
+
+const _: () = assert!(core::mem::size_of::<EventSuppression>() == 4);
+const _: () = assert!(EventSuppression::WRAP_OFFSET == 0);
+const _: () = assert!(EventSuppression::FLAGS_OFFSET == 2);
+
+impl EventSuppression {
+    pub const SIZE: usize = core::mem::size_of::<Self>();
+    pub const ALIGN: usize = core::mem::align_of::<Self>();
+    pub const WRAP_OFFSET: usize = core::mem::offset_of!(Self, off_wrap);
+    pub const FLAGS_OFFSET: usize = core::mem::offset_of!(Self, flags);
+
+    /// Create a new event suppression with the given offset/wrap and flags.
+    pub fn new(off_wrap: u16, flags: EventFlags) -> Self {
+        Self {
+            off_wrap,
+            flags: flags.bits(),
+        }
+    }
+
+    /// Get the event flags.
+    pub fn flags(&self) -> EventFlags {
+        EventFlags::from_bits_truncate(self.flags & 0x3)
+    }
+
+    /// Set the event flags.
+    pub fn set_flags(&mut self, flags: EventFlags) {
+        self.flags = (self.flags & !0x3) | (flags.bits() & 0x3);
+    }
+
+    /// Get the descriptor event offset (bits 0-14).
+    pub fn desc_event_off(&self) -> u16 {
+        self.off_wrap & 0x7FFF
+    }
+
+    /// Check if the descriptor event wrap bit (bit 15) is set.
+    pub fn desc_event_wrap(&self) -> bool {
+        (self.off_wrap & 0x8000) != 0
+    }
+
+    /// Set the descriptor event offset and wrap bit.
+    pub fn set_desc_event(&mut self, off: u16, wrap: bool) {
+        self.off_wrap = (off & 0x7FFF) | if wrap { 0x8000 } else { 0 };
+    }
+
+    /// Create an `EventSuppression` from a raw pointer with acquire semantics.
+    ///
+    /// # Invariant
+    ///
+    /// The caller must ensure that `base` is a valid pointer to an EventSuppression.
+    pub fn read_acquire<M: MemOps>(mem: &M, addr: u64) -> Result<Self, M::Error> {
+        // Atomic Acquire load of flags (publish point)
+        let flags = mem.load_acquire(addr + Self::FLAGS_OFFSET as u64)?;
+        let off_wrap: u16 = mem.read_val(addr + Self::WRAP_OFFSET as u64)?;
+        Ok(Self { off_wrap, flags })
+    }
+
+    /// Write an `EventSuppression` to a raw pointer with release semantics.
+    ///
+    /// # Invariant
+    ///
+    /// The caller must ensure that `base` is a valid pointer to an EventSuppression.
+ pub fn write_release(&self, mem: &M, addr: u64) -> Result<(), M::Error> { + mem.write_val(addr + Self::WRAP_OFFSET as u64, self.off_wrap)?; + // Atomic Release store of flags (publish point) + mem.store_release(addr + Self::FLAGS_OFFSET as u64, self.flags)?; + Ok(()) + } +} diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs new file mode 100644 index 000000000..490f30ac1 --- /dev/null +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -0,0 +1,1134 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Packed Virtqueue Implementation +//! +//! This module provides a high-level API for virtio packed virtqueues, built on top of +//! the lower-level ring primitives. It implements the VIRTIO 1.1+ packed ring format +//! with proper memory ordering and event suppression support. +//! +//! # Architecture +//! +//! The implementation is split into layers: +//! +//! - **High-level API** ([`VirtqProducer`], [`VirtqConsumer`]): Manages buffer allocation, +//! entry/completion lifecycle, and notification decisions. This is the recommended API +//! for most use cases. +//! +//! - **Ring primitives** ([`RingProducer`], [`RingConsumer`]): Low-level descriptor ring +//! operations with explicit buffer chain management. Use this when you need full control +//! over buffer layouts or custom allocation strategies. +//! +//! - **Descriptor and event types** ([`Descriptor`], [`EventSuppression`]): Raw virtio +//! 
data structures for direct memory manipulation. +//! +//! # Quick Start +//! +//! ## Single Entry/Completion +//! +//! ```ignore +//! // Producer (driver) side - build entry, submit, get completion +//! let mut entry = producer.chain() +//! .entry(64) +//! .completion(128) +//! .build()?; +//! entry.write_all(b"entry data")?; +//! let token = producer.submit(entry)?; +//! // ... wait for notification ... +//! if let Some(completion) = producer.poll()? { +//! process(completion.data); +//! } +//! +//! // Consumer (device) side - receive entry, send completion +//! if let Some((entry, completion)) = consumer.poll(max_request_size)? { +//! let request = entry.data(); +//! match completion { +//! SendCompletion::Writable(mut wc) => { +//! let response = handle(request); +//! wc.write_all(&response)?; +//! consumer.complete(wc.into())?; +//! } +//! SendCompletion::Ack(ack) => { +//! consumer.complete(ack.into())?; +//! } +//! } +//! } +//! +//! // Multiple pending completions (no borrow on consumer) +//! let mut pending = Vec::new(); +//! while let Some((entry, completion)) = consumer.poll(max_request_size)? { +//! pending.push((process(entry), completion)); +//! } +//! for (result, completion) in pending { +//! consumer.complete(completion)?; +//! } +//! ``` +//! +//! ## Multiple Entries +//! +//! Each submit checks event suppression and notifies independently: +//! +//! ```ignore +//! for data in entries { +//! let mut se = producer.chain() +//! .entry(data.len()) +//! .completion(64) +//! .build()?; +//! se.write_all(data)?; +//! producer.submit(se)?; +//! } +//! ``` +//! +//! ## Completion Batching with Event Suppression +//! +//! To receive a single notification when multiple requests complete: +//! +//! ```ignore +//! // Submit entries +//! for data in entries { +//! let mut se = producer.chain() +//! .entry(data.len()) +//! .completion(64) +//! .build()?; +//! se.write_all(data)?; +//! producer.submit(se)?; +//! } +//! +//! 
// Tell device: "notify me only after completing past this cursor" +//! let cursor = producer.used_cursor(); +//! producer.set_used_suppression(SuppressionKind::Descriptor(cursor))?; +//! +//! // Wait for single notification, then drain all responses +//! producer.drain(|token, data| { +//! handle_response(token, data); +//! })?; +//! ``` +//! +//! # Event Suppression +//! +//! Both sides can control when they want to be notified using [`SuppressionKind`]: +//! +//! - [`SuppressionKind::Enable`]: Always notify (default, lowest latency) +//! - [`SuppressionKind::Disable`]: Never notify (polling mode, lowest overhead) +//! - [`SuppressionKind::Descriptor`]: Notify at specific ring position (batching) +//! +//! See [`VirtqProducer::set_used_suppression`] and [`VirtqConsumer::set_avail_suppression`]. +//! +//! # Low-Level API +//! +//! For advanced use cases, the ring module exposes lower-level primitives: +//! +//! - [`RingProducer`] / [`RingConsumer`]: Direct ring access with [`BufferChain`] submission +//! - [`BufferChainBuilder`]: Construct scatter-gather buffer lists +//! - [`RingCursor`]: Track ring positions for event suppression +//! +//! Example using low-level API: +//! +//! ```ignore +//! let chain = BufferChainBuilder::new() +//! .readable(header_addr, header_len) +//! .readable(data_addr, data_len) +//! .writable(response_addr, response_len) +//! .build()?; +//! +//! let result = ring_producer.submit_available_with_notify(&chain)?; +//! if result.notify { +//! kick_device(); +//! } +//! ``` + +mod access; +mod consumer; +mod desc; +mod event; +mod pool; +mod producer; +mod ring; + +use core::num::NonZeroU16; + +pub use access::*; +pub use consumer::*; +pub use desc::*; +pub use event::*; +pub use pool::*; +pub use producer::*; +pub use ring::*; +use thiserror::Error; + +/// A trait for notifying about new requests in the virtqueue. +pub trait Notifier { + fn notify(&self, stats: QueueStats); +} + +/// Errors that can occur in the virtqueue operations. 
+#[derive(Error, Debug)] +pub enum VirtqError { + #[error("Ring error: {0}")] + RingError(#[from] RingError), + #[error("Allocation error: {0}")] + Alloc(#[from] AllocError), + #[error("Invalid token")] + BadToken, + #[error("Invalid chain received")] + BadChain, + #[error("Entry data too large for allocated buffer")] + EntryTooLarge, + #[error("Completion data too large for allocated buffer")] + CqeTooLarge, + #[error("Internal state error")] + InvalidState, + #[error("Memory write error")] + MemoryWriteError, + #[error("Memory read error")] + MemoryReadError, + #[error("No readable buffer in this entry")] + NoReadableBuffer, +} + +/// Layout of a packed virtqueue ring in shared memory. +/// +/// Describes the memory addresses for the descriptor table and event suppression +/// structures. Use [`from_base`](Self::from_base) to compute the layout from a +/// base address, or [`query_size`](Self::query_size) to determine memory requirements. +/// +/// # Memory Layout +/// +/// The packed ring consists of: +/// 1. Descriptor table: `num_descs` × 16 bytes, aligned to 16 bytes +/// 2. Driver event suppression: 4 bytes, aligned to 4 bytes +/// 3. Device event suppression: 4 bytes, aligned to 4 bytes +#[derive(Clone, Copy, Debug)] +pub struct Layout { + /// Packed ring descriptor table base in shared memory. + pub desc_table_addr: u64, + /// Number of descriptors (ring size, must be power of 2). + pub desc_table_len: u16, + /// Driver-written event suppression area in shared memory. + pub drv_evt_addr: u64, + /// Device-written event suppression area in shared memory. + pub dev_evt_addr: u64, +} + +#[inline] +const fn align_up(val: usize, align: usize) -> usize { + (val + align - 1) & !(align - 1) +} + +impl Layout { + /// Create a Layout from a base address and number of descriptors. + /// + /// The base address must be aligned to `Descriptor::ALIGN`. + /// The memory region starting at `base` must be at least `Layout::query_size(num_descs)` bytes. 
+ /// + /// # Safety + /// - `base` must be valid for `Layout::query_size(num_descs)` bytes. + /// - `base` must be aligned to `Descriptor::ALIGN`. + /// - Memory must remain valid for the lifetime of the ring. + pub const unsafe fn from_base(base: u64, num_descs: NonZeroU16) -> Result { + if !base.is_multiple_of(Descriptor::ALIGN as u64) { + return Err(RingError::InvalidLayout); + } + + let desc_size = num_descs.get() as usize * Descriptor::SIZE; + let event_size = EventSuppression::SIZE; + let event_align = EventSuppression::ALIGN; + + let drv_evt_offset = align_up(desc_size, event_align); + let dev_evt_offset = align_up(drv_evt_offset + event_size, event_align); + + Ok(Self { + desc_table_addr: base, + desc_table_len: num_descs.get(), + drv_evt_addr: base + drv_evt_offset as u64, + dev_evt_addr: base + dev_evt_offset as u64, + }) + } + + /// Calculate the memory size needed for a ring with `num_descs` descriptors, + /// accounting for alignment requirements. + pub const fn query_size(num_descs: usize) -> usize { + let desc_size = num_descs * Descriptor::SIZE; + let event_size = EventSuppression::SIZE; + let event_align = EventSuppression::ALIGN; + + // desc table at offset 0, then aligned events + let drv_evt_offset = align_up(desc_size, event_align); + let dev_evt_offset = align_up(drv_evt_offset + event_size, event_align); + + dev_evt_offset + event_size + } +} + +/// Statistics about the current virtqueue state. +/// +/// Provided to the [`Notifier`] when sending notifications, allowing +/// the notifier to make decisions based on queue pressure. +#[derive(Debug, Clone, Copy)] +pub struct QueueStats { + /// Number of free descriptor slots available. + pub num_free: usize, + /// Number of descriptors currently in-flight (submitted but not completed). + pub num_inflight: usize, +} + +/// Event suppression mode for controlling when notifications are sent. +/// +/// This configures when the other side should signal (interrupt/kick) us +/// about new data. 
Used to optimize batching and reduce interrupt overhead. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SuppressionKind { + /// Always signal after each operation (default behavior). + Enable, + /// Never signal. + Disable, + /// Signal only when reaching a specific descriptor position. + Descriptor(RingCursor), +} + +/// A token representing a sent entry in the virtqueue. +/// +/// Tokens uniquely identify in-flight requests and are used to correlate +/// requests with their responses. The token value corresponds to the +/// descriptor ID in the underlying ring. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct Token(pub u16); + +impl From for Allocation { + fn from(value: BufferElement) -> Self { + Allocation { + addr: value.addr, + len: value.len as usize, + } + } +} + +const _: () = { + const fn verify_layout(num_descs: usize) { + let base = 0x1000u64; + + // Safety: base is aligned and we're only checking layout math + let layout = + match unsafe { Layout::from_base(base, NonZeroU16::new(num_descs as u16).unwrap()) } { + Ok(l) => l, + Err(_) => panic!("from_base failed"), + }; + + let expected_size = Layout::query_size(num_descs); + + assert!(layout.desc_table_addr == base); + assert!(layout.desc_table_len as usize == num_descs); + assert!( + layout + .drv_evt_addr + .is_multiple_of(EventSuppression::ALIGN as u64) + ); + assert!( + layout + .dev_evt_addr + .is_multiple_of(EventSuppression::ALIGN as u64) + ); + + // Events don't overlap with descriptor table + let desc_end = base + (num_descs * Descriptor::SIZE) as u64; + assert!(layout.drv_evt_addr >= desc_end); + assert!(layout.dev_evt_addr >= layout.drv_evt_addr + EventSuppression::SIZE as u64); + + // Total size from query_size covers entire layout + let layout_end = layout.dev_evt_addr + EventSuppression::SIZE as u64; + assert!(base + expected_size as u64 == layout_end); + } + + verify_layout(1); + verify_layout(2); + verify_layout(4); + verify_layout(8); + verify_layout(16); + 
verify_layout(32); + verify_layout(64); + verify_layout(128); + verify_layout(256); + verify_layout(512); + verify_layout(1024); +}; + +/// Shared test utilities for virtqueue tests. +#[cfg(test)] +pub(crate) mod test_utils { + use alloc::sync::Arc; + use core::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; + + use super::*; + use crate::virtq::ring::tests::{OwnedRing, TestMem}; + + /// Simple notifier that tracks notification count. + #[derive(Debug, Clone)] + pub(crate) struct TestNotifier { + pub(crate) count: Arc, + } + + impl TestNotifier { + pub(crate) fn new() -> Self { + Self { + count: Arc::new(AtomicUsize::new(0)), + } + } + + pub(crate) fn notification_count(&self) -> usize { + self.count.load(Ordering::Relaxed) + } + } + + impl Notifier for TestNotifier { + fn notify(&self, _stats: QueueStats) { + self.count.fetch_add(1, Ordering::Relaxed); + } + } + + /// Simple test buffer pool that allocates from a range. + #[derive(Clone)] + pub(crate) struct TestPool { + base: u64, + next: Arc, + size: usize, + } + + impl TestPool { + pub(crate) fn new(base: u64, size: usize) -> Self { + Self { + base, + next: Arc::new(AtomicU64::new(base)), + size, + } + } + } + + impl BufferProvider for TestPool { + fn alloc(&self, len: usize) -> Result { + let addr = self.next.fetch_add(len as u64, Ordering::Relaxed); + let end = addr + len as u64; + if end > self.base + self.size as u64 { + return Err(AllocError::OutOfMemory); + } + Ok(Allocation { addr, len }) + } + + fn dealloc(&self, _alloc: Allocation) -> Result<(), AllocError> { + // Simple pool doesn't track individual allocations + Ok(()) + } + + fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { + // Simple implementation: always allocate new + self.dealloc(old_alloc)?; + self.alloc(new_len) + } + } + + /// Create test infrastructure: a producer, consumer, and notifier backed + /// by the supplied [`OwnedRing`]. 
+ pub(crate) fn make_test_producer( + ring: &OwnedRing, + ) -> ( + VirtqProducer, TestNotifier, TestPool>, + VirtqConsumer, TestNotifier>, + TestNotifier, + ) { + let layout = ring.layout(); + let mem = ring.mem(); + + // Pool needs to be in memory accessible via mem - use memory after ring layout + let pool_base = mem.base_addr() + Layout::query_size(ring.len()) as u64 + 0x100; + let pool = TestPool::new(pool_base, 0x8000); + let notifier = TestNotifier::new(); + + let producer = VirtqProducer::new(layout, mem.clone(), notifier.clone(), pool); + let consumer = VirtqConsumer::new(layout, mem, notifier.clone()); + + (producer, consumer, notifier) + } +} + +#[cfg(test)] +mod tests { + use alloc::sync::Arc; + use core::sync::atomic::{AtomicUsize, Ordering}; + + use super::*; + use crate::virtq::ring::tests::{TestMem, make_ring}; + use crate::virtq::test_utils::*; + + /// Helper: build and submit an entry+completion chain using the chain() builder. + fn send_readwrite( + producer: &mut VirtqProducer, TestNotifier, TestPool>, + entry_data: &[u8], + cqe_cap: usize, + ) -> Token { + let mut se = producer + .chain() + .entry(entry_data.len()) + .completion(cqe_cap) + .build() + .unwrap(); + se.write_all(entry_data).unwrap(); + producer.submit(se).unwrap() + } + + #[test] + fn test_submit_notifies() { + let ring = make_ring(16); + let (mut producer, mut consumer, notifier) = make_test_producer(&ring); + + let initial_count = notifier.notification_count(); + + let token = send_readwrite(&mut producer, b"hello", 64); + assert!(notifier.notification_count() > initial_count); + + let (entry, _completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.token(), token); + } + + #[test] + fn test_multiple_submits() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let tok1 = send_readwrite(&mut producer, b"request1", 64); + let tok2 = send_readwrite(&mut producer, b"request2", 64); + let tok3 = 
send_readwrite(&mut producer, b"request3", 64); + + // Consumer sees all requests + for _ in 0..3 { + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + consumer.complete(completion).unwrap(); + } + + // All completions available + let cqe1 = producer.poll().unwrap().unwrap(); + let cqe2 = producer.poll().unwrap().unwrap(); + let cqe3 = producer.poll().unwrap().unwrap(); + assert!( + [cqe1.token, cqe2.token, cqe3.token].contains(&tok1) + && [cqe1.token, cqe2.token, cqe3.token].contains(&tok2) + && [cqe1.token, cqe2.token, cqe3.token].contains(&tok3) + ); + } + + #[test] + fn test_completion_batching_with_suppression() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + // Submit entries + let tok1 = send_readwrite(&mut producer, b"req1", 64); + let tok2 = send_readwrite(&mut producer, b"req2", 64); + let tok3 = send_readwrite(&mut producer, b"req3", 64); + + // Set up completion batching via used suppression + let cursor = producer.used_cursor(); + producer + .set_used_suppression(SuppressionKind::Descriptor(cursor)) + .unwrap(); + + // Consumer processes requests + for _ in 0..3 { + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + let SendCompletion::Writable(mut wc) = completion else { + panic!("expected writable completion"); + }; + wc.write_all(b"cqe-data").unwrap(); + consumer.complete(wc.into()).unwrap(); + } + + // Producer can drain all responses + let mut responses = Vec::new(); + producer + .drain(|tok, _data| { + responses.push(tok); + }) + .unwrap(); + + assert_eq!(responses.len(), 3); + assert!(responses.contains(&tok1)); + assert!(responses.contains(&tok2)); + assert!(responses.contains(&tok3)); + } + + #[test] + fn test_notifier_receives_context() { + #[derive(Debug, Clone)] + struct CtxNotifier { + last_num_free: Arc, + last_num_inflight: Arc, + count: Arc, + } + + impl Notifier for CtxNotifier { + fn notify(&self, stats: QueueStats) { + 
self.last_num_free.store(stats.num_free, Ordering::Relaxed); + self.last_num_inflight + .store(stats.num_inflight, Ordering::Relaxed); + self.count.fetch_add(1, Ordering::Relaxed); + } + } + + let ring = make_ring(16); + let layout = ring.layout(); + let mem = ring.mem(); + let pool_base = mem.base_addr() + Layout::query_size(ring.len()) as u64 + 0x100; + let pool = TestPool::new(pool_base, 0x8000); + let notifier = CtxNotifier { + last_num_free: Arc::new(AtomicUsize::new(0)), + last_num_inflight: Arc::new(AtomicUsize::new(0)), + count: Arc::new(AtomicUsize::new(0)), + }; + + let mut producer = VirtqProducer::new(layout, mem, notifier.clone(), pool); + + let mut se = producer.chain().entry(4).completion(32).build().unwrap(); + se.write_all(b"test").unwrap(); + producer.submit(se).unwrap(); + assert_eq!(notifier.count.load(Ordering::Relaxed), 1); + assert!(notifier.last_num_inflight.load(Ordering::Relaxed) > 0); + } + + #[test] + fn test_chain_zero_copy_batch() { + let ring = make_ring(16); + let (mut producer, mut consumer, notifier) = make_test_producer(&ring); + + let initial_count = notifier.notification_count(); + + // Zero-copy entry via buf_mut + let mut se1 = producer.chain().entry(64).completion(128).build().unwrap(); + let buf = se1.buf_mut().unwrap(); + buf[..6].copy_from_slice(b"zc-ent"); + se1.set_written(6).unwrap(); + let _tok1 = producer.submit(se1).unwrap(); + + // Write-based entry + let mut se2 = producer.chain().entry(64).completion(64).build().unwrap(); + se2.write_all(b"copy-ent").unwrap(); + let _tok2 = producer.submit(se2).unwrap(); + + // Completion-only chain + let se3 = producer.chain().completion(32).build().unwrap(); + let tok3 = producer.submit(se3).unwrap(); + + // Each submit may notify independently + assert!(notifier.notification_count() > initial_count); + + // Consumer sees all three entries + let (entry1, completion1) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry1.data().as_ref(), b"zc-ent"); + 
consumer.complete(completion1).unwrap(); + + let (entry2, completion2) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry2.data().as_ref(), b"copy-ent"); + consumer.complete(completion2).unwrap(); + + let (_entry3, completion3) = consumer.poll(1024).unwrap().unwrap(); + let SendCompletion::Writable(mut wc) = completion3 else { + panic!("expected writable completion"); + }; + wc.write_all(b"resp").unwrap(); + consumer.complete(wc.into()).unwrap(); + + // Drain completions + let _ = producer.poll().unwrap().unwrap(); + let _ = producer.poll().unwrap().unwrap(); + + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, tok3); + assert_eq!(&cqe.data[..], b"resp"); + } + + #[test] + fn test_chain_zero_copy_send() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + // Zero-copy send: allocate, write directly, submit + let mut se = producer.chain().entry(64).completion(128).build().unwrap(); + let buf = se.buf_mut().unwrap(); + assert_eq!(buf.len(), 64); + buf[..5].copy_from_slice(b"hello"); + se.set_written(5).unwrap(); + let token = producer.submit(se).unwrap(); + + // Consumer sees the data + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.token(), token); + assert_eq!(entry.data().as_ref(), b"hello"); + + // Write response + let SendCompletion::Writable(mut wc) = completion else { + panic!("expected writable completion"); + }; + wc.write_all(b"world").unwrap(); + consumer.complete(wc.into()).unwrap(); + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(&cqe.data[..], b"world"); + } + + #[test] + fn test_full_round_trip() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + // Send an entry + let token = send_readwrite(&mut producer, b"round-trip-entry", 128); + + // Consumer receives and responds + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.token(), 
token); + assert_eq!(entry.data().as_ref(), b"round-trip-entry"); + + let SendCompletion::Writable(mut wc) = completion else { + panic!("expected writable completion"); + }; + assert!(wc.capacity() >= 128); + wc.write_all(b"round-trip-rsp").unwrap(); + consumer.complete(wc.into()).unwrap(); + + // Producer gets the completion + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, token); + assert_eq!(&cqe.data[..], b"round-trip-rsp"); + } + + #[test] + fn test_cancel_submits_zero_length() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let token = send_readwrite(&mut producer, b"entry-data", 64); + + let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); + consumer.complete(completion).unwrap(); + + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, token); + assert_eq!(cqe.data.len(), 0); + assert!(cqe.data.is_empty()); + } + + #[test] + fn test_hold_completion_and_complete() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let token = send_readwrite(&mut producer, b"deferred", 64); + + // Poll and hold the completion + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.token(), token); + assert_eq!(entry.data().as_ref(), b"deferred"); + + let SendCompletion::Writable(mut wc) = completion else { + panic!("expected writable completion"); + }; + wc.write_all(b"deferred-cqe").unwrap(); + consumer.complete(wc.into()).unwrap(); + + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, token); + assert_eq!(&cqe.data[..], b"deferred-cqe"); + } + + #[test] + fn test_concurrent_pending_completions() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let tok1 = send_readwrite(&mut producer, b"first", 64); + let tok2 = send_readwrite(&mut producer, b"second", 64); + + // Poll both + let (entry1, completion1) 
= consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry1.token(), tok1); + assert_eq!(entry1.data().as_ref(), b"first"); + + let (entry2, completion2) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry2.token(), tok2); + assert_eq!(entry2.data().as_ref(), b"second"); + + // Complete second first (out of order) + let SendCompletion::Writable(mut wc2) = completion2 else { + panic!("expected writable"); + }; + wc2.write_all(b"resp2").unwrap(); + consumer.complete(wc2.into()).unwrap(); + + let SendCompletion::Writable(mut wc1) = completion1 else { + panic!("expected writable"); + }; + wc1.write_all(b"resp1").unwrap(); + consumer.complete(wc1.into()).unwrap(); + + let cqe1 = producer.poll().unwrap().unwrap(); + let cqe2 = producer.poll().unwrap().unwrap(); + let mut responses: Vec<_> = vec![ + (cqe1.token, cqe1.data.to_vec()), + (cqe2.token, cqe2.data.to_vec()), + ]; + responses.sort_by_key(|(t, _)| t.0); + + let expected_first = responses.iter().find(|(t, _)| *t == tok1).unwrap(); + let expected_second = responses.iter().find(|(t, _)| *t == tok2).unwrap(); + assert_eq!(&expected_first.1[..], b"resp1"); + assert_eq!(&expected_second.1[..], b"resp2"); + } +} +#[cfg(all(test, loom))] +mod fuzz { + //! Loom-based concurrency testing for the virtqueue implementation. + //! + //! Loom will explore all possible thread interleavings to find data races + //! and other concurrency bugs. However, it has specific requirements that + //! make our memory model more involved: + //! + //! ## Flag-Based Synchronization + //! + //! The virtqueue protocol uses flag-based synchronization: + //! 1. Producer writes descriptor fields (addr, len, id), then writes flags with release semantics + //! 2. Consumer reads flags with acquire semantics, then reads descriptor fields + //! + //! Loom would see this as concurrent access to the same memory and report a race, even though + //! acquire/release on flags provides proper synchronization. + //! + //!
## Shadow Atomics for Flags + //! + //! We maintain shadow atomics that loom tracks for synchronization: + //! + //! - `desc_flags`: One `AtomicU16` per descriptor for flags field + //! - `drv_flags`: `AtomicU16` for driver event suppression flags + //! - `dev_flags`: `AtomicU16` for device event suppression flags + //! + //! The `load_acquire`/`store_release` operations use these loom atomics, + //! while `read`/`write` access the underlying data directly. + //! + //! ## Memory Regions + //! + //! We use a `BTreeMap` to map addresses to memory regions: + //! - `Desc(idx)`: Individual descriptors in the ring + //! - `DrvEvt`: Driver event suppression structure + //! - `DevEvt`: Device event suppression structure + //! - `Pool`: Buffer pool for entry/completion data + + use alloc::collections::BTreeMap; + use alloc::sync::Arc; + use alloc::vec; + use core::num::NonZeroU16; + + use bytemuck::Zeroable; + use loom::sync::atomic::{AtomicU16, AtomicUsize, Ordering}; + use loom::thread; + + use super::*; + use crate::virtq::desc::Descriptor; + use crate::virtq::pool::BufferPoolSync; + + #[derive(Debug)] + pub struct MemErr; + + #[derive(Debug, Clone, Copy)] + enum RegionKind { + Desc(usize), + DrvEvt, + DevEvt, + Pool, + } + + #[derive(Debug, Clone, Copy)] + struct RegionInfo { + kind: RegionKind, + size: usize, + } + + #[derive(Debug)] + pub struct LoomMem { + descs: Vec, + drv: core::cell::UnsafeCell, + dev: core::cell::UnsafeCell, + pool: loom::cell::UnsafeCell>, + + desc_flags: Vec, + drv_flags: AtomicU16, + dev_flags: AtomicU16, + + regions: BTreeMap, + layout: Layout, + } + + unsafe impl Sync for LoomMem {} + unsafe impl Send for LoomMem {} + + impl LoomMem { + pub fn new(ring_base: u64, num_descs: usize, pool_base: u64, pool_size: usize) -> Self { + let descs_nz = NonZeroU16::new(num_descs as u16).unwrap(); + let layout = unsafe { Layout::from_base(ring_base, descs_nz).unwrap() }; + + let descs: Vec<_> = (0..num_descs).map(|_| Descriptor::zeroed()).collect(); + let 
desc_flags: Vec<_> = (0..num_descs).map(|_| AtomicU16::new(0)).collect(); + + let mut regions = BTreeMap::new(); + + // Register each descriptor as a separate region + for i in 0..num_descs { + let addr = layout.desc_table_addr + (i * Descriptor::SIZE) as u64; + regions.insert( + addr, + RegionInfo { + kind: RegionKind::Desc(i), + size: Descriptor::SIZE, + }, + ); + } + + regions.insert( + layout.drv_evt_addr, + RegionInfo { + kind: RegionKind::DrvEvt, + size: EventSuppression::SIZE, + }, + ); + + regions.insert( + layout.dev_evt_addr, + RegionInfo { + kind: RegionKind::DevEvt, + size: EventSuppression::SIZE, + }, + ); + + regions.insert( + pool_base, + RegionInfo { + kind: RegionKind::Pool, + size: pool_size, + }, + ); + + Self { + descs, + drv: core::cell::UnsafeCell::new(EventSuppression::zeroed()), + dev: core::cell::UnsafeCell::new(EventSuppression::zeroed()), + pool: loom::cell::UnsafeCell::new(vec![0u8; pool_size]), + desc_flags, + drv_flags: AtomicU16::new(0), + dev_flags: AtomicU16::new(0), + regions, + layout, + } + } + + pub fn layout(&self) -> Layout { + self.layout + } + + fn region(&self, addr: u64) -> Option<(RegionInfo, usize)> { + let (&base, &info) = self.regions.range(..=addr).next_back()?; + let offset = (addr - base) as usize; + + if offset < info.size { + Some((info, offset)) + } else { + None + } + } + + fn desc_ptr(&self, idx: usize) -> *mut Descriptor { + self.descs.as_ptr().cast_mut().wrapping_add(idx) + } + } + + impl MemOps for Arc { + type Error = MemErr; + + fn read(&self, addr: u64, dst: &mut [u8]) -> Result { + let (info, offset) = self.region(addr).ok_or(MemErr)?; + + match info.kind { + RegionKind::Desc(idx) => { + let desc = unsafe { &*self.desc_ptr(idx) }; + let bytes = bytemuck::bytes_of(desc); + dst.copy_from_slice(&bytes[offset..offset + dst.len()]); + } + RegionKind::DrvEvt => { + let evt = unsafe { &*self.drv.get() }; + let bytes = bytemuck::bytes_of(evt); + dst.copy_from_slice(&bytes[offset..offset + dst.len()]); + } + 
RegionKind::DevEvt => { + let evt = unsafe { &*self.dev.get() }; + let bytes = bytemuck::bytes_of(evt); + dst.copy_from_slice(&bytes[offset..offset + dst.len()]); + } + RegionKind::Pool => { + self.pool.with(|buf| { + dst.copy_from_slice(&(unsafe { &*buf })[offset..offset + dst.len()]); + }); + } + } + Ok(dst.len()) + } + + fn write(&self, addr: u64, src: &[u8]) -> Result { + let (info, offset) = self.region(addr).ok_or(MemErr)?; + + match info.kind { + RegionKind::Desc(idx) => { + let desc = unsafe { &mut *self.desc_ptr(idx) }; + let bytes = bytemuck::bytes_of_mut(desc); + bytes[offset..offset + src.len()].copy_from_slice(src); + } + RegionKind::DrvEvt => { + let evt = unsafe { &mut *self.drv.get() }; + let bytes = bytemuck::bytes_of_mut(evt); + bytes[offset..offset + src.len()].copy_from_slice(src); + } + RegionKind::DevEvt => { + let evt = unsafe { &mut *self.dev.get() }; + let bytes = bytemuck::bytes_of_mut(evt); + bytes[offset..offset + src.len()].copy_from_slice(src); + } + RegionKind::Pool => { + self.pool.with_mut(|buf| { + (unsafe { &mut *buf })[offset..offset + src.len()].copy_from_slice(src); + }); + } + } + Ok(src.len()) + } + + fn load_acquire(&self, addr: u64) -> Result { + let (info, _offset) = self.region(addr).ok_or(MemErr)?; + + Ok(match info.kind { + RegionKind::Desc(idx) => self.desc_flags[idx].load(Ordering::Acquire), + RegionKind::DrvEvt => self.drv_flags.load(Ordering::Acquire), + RegionKind::DevEvt => self.dev_flags.load(Ordering::Acquire), + RegionKind::Pool => return Err(MemErr), + }) + } + + fn store_release(&self, addr: u64, val: u16) -> Result<(), Self::Error> { + let (info, _offset) = self.region(addr).ok_or(MemErr)?; + + match info.kind { + RegionKind::Desc(idx) => self.desc_flags[idx].store(val, Ordering::Release), + RegionKind::DrvEvt => self.drv_flags.store(val, Ordering::Release), + RegionKind::DevEvt => self.dev_flags.store(val, Ordering::Release), + RegionKind::Pool => return Err(MemErr), + } + Ok(()) + } + + unsafe fn 
as_slice(&self, addr: u64, len: usize) -> Result<&[u8], Self::Error> { + let (info, offset) = self.region(addr).ok_or(MemErr)?; + + match info.kind { + RegionKind::Pool => { + // Safety: pool memory is a contiguous Vec; caller ensures + // no concurrent writes for the lifetime of the returned slice. + let buf = unsafe { &*self.pool.get() }; + Ok(&buf[offset..offset + len]) + } + _ => Err(MemErr), + } + } + + unsafe fn as_mut_slice(&self, addr: u64, len: usize) -> Result<&mut [u8], Self::Error> { + let (info, offset) = self.region(addr).ok_or(MemErr)?; + + match info.kind { + RegionKind::Pool => { + let buf = unsafe { &mut *self.pool.get() }; + Ok(&mut buf[offset..offset + len]) + } + _ => Err(MemErr), + } + } + } + + #[derive(Debug)] + pub struct Notify { + kicks: AtomicUsize, + } + + impl Notify { + pub fn new() -> Self { + Self { + kicks: AtomicUsize::new(0), + } + } + } + + impl Notifier for Arc { + fn notify(&self, _stats: QueueStats) { + self.kicks.fetch_add(1, Ordering::Relaxed); + } + } + + #[test] + fn virtq_ping_pong() { + loom::model(|| { + let ring_base = 0x10000; + let pool_base = 0x40000; + let pool_size = 0x10000; + + let mem = Arc::new(LoomMem::new(ring_base, 8, pool_base, pool_size)); + let pool = BufferPoolSync::<256, 4096>::new(pool_base, pool_size).unwrap(); + let notify = Arc::new(Notify::new()); + + let mut prod = VirtqProducer::new(mem.layout(), mem.clone(), notify.clone(), pool); + let mut cons = VirtqConsumer::new(mem.layout(), mem.clone(), notify.clone()); + + let t_prod = thread::spawn(move || { + let mut se = prod.chain().entry(4).completion(32).build().unwrap(); + se.write_all(b"ping").unwrap(); + let tok = prod.submit(se).unwrap(); + loop { + if let Some(r) = prod.poll().unwrap() { + assert_eq!(r.token, tok); + assert_eq!(&r.data[..], b"pong"); + break; + } + thread::yield_now(); + } + }); + + let t_cons = thread::spawn(move || { + let (entry, completion) = loop { + if let Some(r) = cons.poll(1024).unwrap() { + break r; + } + 
thread::yield_now(); + }; + assert_eq!(entry.data().as_ref(), b"ping"); + let SendCompletion::Writable(mut wc) = completion else { + panic!("expected writable completion"); + }; + wc.write_all(b"pong").unwrap(); + cons.complete(wc.into()).unwrap(); + }); + + t_prod.join().unwrap(); + t_cons.join().unwrap(); + }); + } +} diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs new file mode 100644 index 000000000..0324c08fe --- /dev/null +++ b/src/hyperlight_common/src/virtq/pool.rs @@ -0,0 +1,1334 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +//! Simple bitmap-based allocator for virtio buffer management. +//! +//! This module provides two layers: +//! +//! - [`Slab`] - a fixed-size region allocator with a power-of-two slot size `N`, +//! backed by a flat bitmap (`FixedBitSet`). +//! - [`BufferPool`] - a two-tier pool that composes two slabs: one with small +//! slots (e.g. 256 bytes) for control messages / small descriptors, and one +//! with page-sized slots (e.g. 4 KiB) for data buffers. +//! +//! # Design and algorithm +//! +//! The core allocation strategy is a bitmap allocator that performs a linear +//! search over the bitmap, but implemented via `fixedbitset`'s SIMD iteration +//! over zero bits. This is conceptually simpler than tree-based allocators +//! (e.g. linked lists or bitmaps representing a tree as in +//! ), yet for "moderate" region sizes it can +//! be faster in practice: +//! 
+//! - `FixedBitSet::zeroes()` and related methods use word/SIMD operations to +//! skip over runs of set bits, so the linear search is over words rather than +//! individual bits. +//! - We scan for a contiguous run of free bits corresponding to the required +//! number of slots; no auxiliary tree structure is maintained. +//! +//! The tree-based approach (bitmap encoding a tree and doing a binary search +//! in O(log(n)) time) is a natural next step if larger regions or stricter worst +//! case bounds are required; switching to such a representation should be +//! relatively straightforward since all allocation paths go through a single +//! `find_slots` function. +//! +//! # Locality characteristics +//! +//! The allocator tends to preserve spatial locality: +//! +//! - It searches from low indices upward, returning the first run of free +//! slots large enough for the request. Slots are merged if necessary. +//! - Freed runs are cached in `last_free_run` and reused eagerly, which +//! introduces a mild LIFO behavior for recently freed blocks. +//! - As a result, consecutive allocations are likely to end up in nearby slots, +//! which keeps virtqueue descriptors, control buffers, and data buffers +//! clustered in memory and helps cache performance. +//! +//! # Two-tier buffer pool +//! +//! [`BufferPool`] divides the underlying region into two slabs with different +//! slot sizes: +//! +//! - The lower tier (`Slab`, default `L = 256`) is intended for +//! *smaller allocations* - control messages, descriptor metadata, and other +//! small structures. Small allocations first try this tier. +//! - The upper tier (`Slab`, default `U = 4096`) uses page sized slots +//! and is intended for larger data buffers. +//! +//! The split of the region is currently fixed at a constant fraction +//! (`LOWER_FRACTION`) for the lower slab and the remainder for the upper slab. +//! +//! Allocation policy: +//! +//! 
- Requests `<= L` bytes are first attempted in the lower slab; on +//! `OutOfMemory` they fall back to the upper slab. +//! - Larger requests go directly to the upper slab. +//! - [`BufferPool::resize`] will try to grow or shrink in place within the +//! owning slab (`Slab::resize`) but will never move allocations between +//! slabs. + +#[cfg(all(test, loom))] +use alloc::sync::Arc; +use core::cmp::Ordering; + +use atomic_refcell::AtomicRefCell; +use fixedbitset::FixedBitSet; +use thiserror::Error; + +use super::access::MemOps; + +#[derive(Debug, Error, Copy, Clone)] +pub enum AllocError { + #[error("Invalid region addr {0}")] + InvalidAlign(u64), + #[error("Invalid free addr {0} and size {1}")] + InvalidFree(u64, usize), + #[error("Invalid argument")] + InvalidArg, + #[error("Empty region")] + EmptyRegion, + #[error("Out of memory")] + OutOfMemory, + #[error("Overflow")] + Overflow, +} + +/// Allocation result +#[derive(Debug, Clone, Copy)] +pub struct Allocation { + /// Starting address of the allocation + pub addr: u64, + /// Length of the allocation in bytes rounded up to slab size + pub len: usize, +} + +/// Trait for buffer providers. +pub trait BufferProvider { + /// Allocate at least `len` bytes. + fn alloc(&self, len: usize) -> Result; + + /// Free a previously allocated block. + fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError>; + + /// Resize by trying in-place grow; otherwise reserve a new block and free old. + fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result; +} + +/// The owner of a mapped buffer, ensuring its lifetime. +/// +/// Holds a pool allocation and provides direct access to the underlying +/// shared memory via [`MemOps::as_slice`]. Implements `AsRef<[u8]>` so it +/// can be used with [`Bytes::from_owner`](bytes::Bytes::from_owner) for +/// zero-copy `Bytes` backed by shared memory. +/// +/// When dropped, the allocation is returned to the pool. 
+#[derive(Debug, Clone)] +pub struct BufferOwner { + pub(crate) pool: P, + pub(crate) mem: M, + pub(crate) alloc: Allocation, + pub(crate) written: usize, +} + +impl Drop for BufferOwner { + fn drop(&mut self) { + let _ = self.pool.dealloc(self.alloc); + } +} + +impl AsRef<[u8]> for BufferOwner { + fn as_ref(&self) -> &[u8] { + let len = self.written.min(self.alloc.len); + // Safety: BufferOwner keeps both the pool allocation and the M + // alive, so the memory region is valid. Protocol-level descriptor + // ownership transfer guarantees no concurrent writes. + match unsafe { self.mem.as_slice(self.alloc.addr, len) } { + Ok(slice) => slice, + Err(_) => &[], + } + } +} + +/// A guard that runs a cleanup function when dropped, unless dismissed. +pub struct AllocGuard(Option<(Allocation, F)>); + +impl AllocGuard { + pub fn new(alloc: Allocation, cleanup: F) -> Self { + Self(Some((alloc, cleanup))) + } + + pub fn release(mut self) -> Allocation { + self.0.take().unwrap().0 + } +} + +impl core::ops::Deref for AllocGuard { + type Target = Allocation; + + fn deref(&self) -> &Allocation { + &self.0.as_ref().unwrap().0 + } +} + +impl Drop for AllocGuard { + fn drop(&mut self) { + if let Some((alloc, cleanup)) = self.0.take() { + cleanup(alloc) + } + } +} + +#[derive(Debug, Clone)] +pub struct Slab { + /// Base address of the slab + base_addr: u64, + /// Flat bitmap to track allocated/free slots + used_slots: FixedBitSet, + /// Last free allocation cache + last_free_run: Option, +} + +impl Slab { + /// Create a new slab allocator over a fixed region. + /// Region is rounded down to a multiple of N. 
+ pub fn new(base_addr: u64, region_len: usize) -> Result { + let usable = region_len - (region_len % N); + let num_slots = usable / N; + let used_slots = FixedBitSet::with_capacity(num_slots); + + if base_addr % (N as u64) != 0 { + return Err(AllocError::InvalidAlign(base_addr)); + } + + if num_slots == 0 { + return Err(AllocError::EmptyRegion); + } + + Ok(Self { + base_addr, + used_slots, + last_free_run: None, + }) + } + + /// Get the address of a slot by its index + #[inline] + fn addr_of(&self, slot_idx: usize) -> Option { + self.base_addr + .checked_add((slot_idx as u64).checked_mul(N as u64)?) + } + + /// Get the slot index for a given address + #[inline] + fn slot_of(&self, addr: u64) -> usize { + let off = (addr - self.base_addr) as usize; + off / N + } + + /// Invalidate last_free_run cache if it overlaps with the given allocation. + fn maybe_invalidate_last_run(&mut self, alloc: Allocation) { + if let Some(run) = &self.last_free_run { + let new_end = alloc.addr + alloc.len as u64; + let run_end = run.addr + run.len as u64; + + if alloc.addr < run_end && run.addr < new_end { + self.last_free_run = None; + } + } + } + + /// Find a run of `slots_num` contiguous free slots, returning the index of the first slot. + pub fn find_slots(&mut self, slots_num: usize) -> Option { + debug_assert!(slots_num > 0); + + // Check last free run optimization + if let Some(alloc) = self.last_free_run + && alloc.len >= slots_num * N + { + let pos = self.slot_of(alloc.addr); + let _ = self.last_free_run.take(); + return Some(pos); + } + + // Fallback to full search + self.used_slots.zeroes().find(|&next_free| { + self.used_slots + .count_zeroes(next_free..next_free + slots_num) + == slots_num + }) + } + + /// Allocate at least `len` bytes by merging consecutive slots.
+ pub fn alloc(&mut self, len: usize) -> Result { + if len == 0 { + return Err(AllocError::InvalidArg); + } + + let total = self.used_slots.len(); + let need_slots = len.div_ceil(N); + if need_slots > total { + return Err(AllocError::OutOfMemory); + } + + let idx = self.find_slots(need_slots).ok_or(AllocError::OutOfMemory)?; + self.used_slots.insert_range(idx..idx + need_slots); + let addr = self.addr_of(idx).ok_or(AllocError::Overflow)?; + + let alloc = Allocation { + addr, + len: need_slots * N, + }; + + self.maybe_invalidate_last_run(alloc); + Ok(alloc) + } + + /// Free a previously allocated slot or multiple slots. + /// + /// `len` must be a multiple of N and `addr` must be N-aligned to base. + pub fn dealloc(&mut self, alloc: Allocation) -> Result<(), AllocError> { + let Allocation { addr, len } = alloc; + if len == 0 || len % N != 0 || addr < self.base_addr { + return Err(AllocError::InvalidFree(addr, len)); + } + let alloc_slots = len / N; + let off = (addr - self.base_addr) as usize; + if off % N != 0 { + return Err(AllocError::InvalidFree(addr, len)); + } + let start = off / N; + let num_slots = self.used_slots.len(); + if start + alloc_slots > num_slots { + return Err(AllocError::InvalidFree(addr, len)); + } + + // Ensure all bits are set (avoid double-free) + if !self + .used_slots + .contains_all_in_range(start..start + alloc_slots) + { + return Err(AllocError::InvalidFree(addr, len)); + } + + // Mark as free + self.used_slots.remove_range(start..start + alloc_slots); + self.last_free_run = Some(alloc); + + Ok(()) + } + + /// Try to grow a block in place by reserving adjacent free slots to the right. + /// + /// Returns Ok(None) if in-place growth is not possible. Returns Err on invalid input. 
+ pub fn try_grow_inplace( + &mut self, + old_alloc: Allocation, + new_len: usize, + ) -> Result, AllocError> { + let Allocation { + addr: old_addr, + len: old_len, + } = old_alloc; + + if new_len <= old_len || old_len == 0 || old_len % N != 0 { + return Err(AllocError::InvalidFree(old_addr, old_len)); + } + + let old_slots = old_len / N; + let need_slots = new_len.div_ceil(N); + let off = (old_addr - self.base_addr) as usize; + if off % N != 0 { + return Err(AllocError::InvalidFree(old_addr, old_len)); + } + + let start = off / N; + if start + need_slots > self.used_slots.len() { + return Ok(None); + } + // Existing range must be allocated + if !self + .used_slots + .contains_all_in_range(start..start + old_slots) + { + return Err(AllocError::InvalidFree(old_addr, old_len)); + } + + // Extension must be free + if self + .used_slots + .count_ones(start + old_slots..start + need_slots) + > 0 + { + return Ok(None); + } + + // Mark extension as allocated + self.used_slots + .insert_range(start + old_slots..start + need_slots); + + let alloc = Allocation { + addr: old_addr, + len: need_slots * N, + }; + + self.maybe_invalidate_last_run(alloc); + Ok(Some(alloc)) + } + + /// Shrink a block in place by freeing excess slots to the right. 
    pub fn shrink_inplace(
        &mut self,
        old_alloc: Allocation,
        new_len: usize,
    ) -> Result<Allocation, AllocError> {
        let Allocation {
            addr: old_addr,
            len: old_len,
        } = old_alloc;

        // Shrink only: `new_len` must be strictly smaller, and the old
        // allocation must be well-formed (non-empty, slot-aligned length).
        if new_len >= old_len || old_len == 0 || old_len % N != 0 {
            return Err(AllocError::InvalidFree(old_addr, old_len));
        }

        let old_slots = old_len / N;
        let need_slots = new_len.div_ceil(N);
        let off = (old_addr - self.base_addr) as usize;
        if off % N != 0 {
            return Err(AllocError::InvalidFree(old_addr, old_len));
        }

        let start = off / N;
        if start + old_slots > self.used_slots.len() {
            return Err(AllocError::InvalidFree(old_addr, old_len));
        }
        // Existing range must be allocated
        if !self
            .used_slots
            .contains_all_in_range(start..start + old_slots)
        {
            return Err(AllocError::InvalidFree(old_addr, old_len));
        }

        // Free the excess slots on the right; the kept prefix stays marked.
        // Note: if `new_len` rounds up to the same slot count, this removes an
        // empty range and the returned length equals the old length.
        self.used_slots
            .remove_range(start + need_slots..start + old_slots);

        Ok(Allocation {
            addr: old_addr,
            len: need_slots * N,
        })
    }

    /// Reallocate by trying in-place grow; otherwise reserve a new run of slots and free old.
    /// Caller should copy the payload; this function only manages reservations.
    pub fn resize(
        &mut self,
        old_alloc: Allocation,
        new_len: usize,
    ) -> Result<Allocation, AllocError> {
        if new_len == 0 {
            return Err(AllocError::InvalidArg);
        }

        match new_len.cmp(&old_alloc.len) {
            Ordering::Greater => {
                match self.try_grow_inplace(old_alloc, new_len) {
                    // in-place growth succeeded
                    Ok(Some(new_alloc)) => Ok(new_alloc),
                    // in-place growth failed; allocate new and free old.
                    // The new run is reserved before the old one is released,
                    // so a failed allocation leaves the old block intact.
                    Ok(None) => {
                        let new_alloc = self.alloc(new_len)?;
                        self.dealloc(old_alloc)?;
                        Ok(new_alloc)
                    }
                    // other errors are propagated
                    Err(err) => Err(err),
                }
            }
            Ordering::Less => self.shrink_inplace(old_alloc, new_len),
            Ordering::Equal => Ok(old_alloc),
        }
    }

    /// Usable size rounded up to slot multiple.
    pub fn usable_size(&self, _addr: usize, len: usize) -> usize {
        if len == 0 { 0 } else { len.div_ceil(N) * N }
    }

    /// Number of free bytes in the slab.
    pub fn free_bytes(&self) -> usize {
        (self.used_slots.len() - self.used_slots.count_ones(..)) * N
    }

    /// Total capacity of the slab in bytes.
    pub fn capacity(&self) -> usize {
        self.used_slots.len() * N
    }

    /// Get the address range covered by this slab (end-exclusive).
    pub fn range(&self) -> core::ops::Range<u64> {
        let end = self.base_addr + self.capacity() as u64;
        self.base_addr..end
    }

    /// Check if an address is within this slab's range.
    pub fn contains(&self, addr: u64) -> bool {
        self.range().contains(&addr)
    }

    /// Get the slot size N.
    pub const fn slot_size() -> usize {
        N
    }
}

/// Round `val` up to the next multiple of `align`.
///
/// # Panics
///
/// Panics if `align == 0`.
#[inline]
fn align_up(val: usize, align: usize) -> usize {
    assert!(align > 0);
    if val == 0 {
        return 0;
    }
    val.div_ceil(align) * align
}

/// Interior state of [`BufferPool`]: a small-slot slab (`L`) and a
/// large-slot slab (`U`) carved out of one contiguous region.
#[derive(Debug)]
struct Inner<const L: usize, const U: usize> {
    lower: Slab<L>,
    upper: Slab<U>,
}

/// Two tier buffer pool with small and large slabs.
#[derive(Debug)]
pub struct BufferPool<const L: usize, const U: usize> {
    // AtomicRefCell provides dynamic borrow checking without a full mutex;
    // conflicting concurrent borrows panic rather than block.
    inner: AtomicRefCell<Inner<L, U>>,
}

impl<const L: usize, const U: usize> BufferPool<L, U> {
    /// Create a new buffer pool over a fixed region.
    pub fn new(base_addr: u64, region_len: usize) -> Result<Self, AllocError> {
        let inner = Inner::<L, U>::new(base_addr, region_len)?;
        Ok(Self {
            inner: inner.into(),
        })
    }
}

/// Mutex-backed pool variant used only by loom-based concurrency tests.
#[cfg(all(test, loom))]
#[derive(Debug, Clone)]
pub struct BufferPoolSync<const L: usize, const U: usize> {
    inner: std::sync::Arc<std::sync::Mutex<Inner<L, U>>>,
}

#[cfg(all(test, loom))]
impl<const L: usize, const U: usize> BufferPoolSync<L, U> {
    /// Create a new buffer pool over a fixed region.
    pub fn new(base_addr: u64, region_len: usize) -> Result<Self, AllocError> {
        let inner = Inner::<L, U>::new(base_addr, region_len)?;
        Ok(Self {
            inner: Arc::new(std::sync::Mutex::new(inner)),
        })
    }
}

impl<const L: usize, const U: usize> Inner<L, U> {
    /// Create a new buffer pool over a fixed region.
+ pub fn new(base_addr: u64, region_len: usize) -> Result { + const LOWER_FRACTION: usize = 8; + + let lower_region = region_len / LOWER_FRACTION; + let upper_region = region_len - lower_region; + + let mut aligned = base_addr; + aligned = align_up(aligned as usize, L) as u64; + let lower = Slab::::new(aligned, lower_region)?; + + // advance and align upper base to N + aligned = aligned + .checked_add(lower.capacity() as u64) + .ok_or(AllocError::Overflow)?; + + aligned = align_up(aligned as usize, U) as u64; + let upper = Slab::::new(aligned, upper_region)?; + + Ok(Self { lower, upper }) + } + + /// Allocate at least `len` bytes. + pub fn alloc(&mut self, len: usize) -> Result { + if len <= L { + match self.lower.alloc(len) { + Ok(alloc) => return Ok(alloc), + Err(AllocError::OutOfMemory) => {} + Err(e) => return Err(e), + } + } + + // fallback to upper slab + self.upper.alloc(len) + } + + /// Free a previously allocated block. + pub fn dealloc(&mut self, alloc: Allocation) -> Result<(), AllocError> { + if self.lower.contains(alloc.addr) { + self.lower.dealloc(alloc) + } else { + self.upper.dealloc(alloc) + } + } + + /// Reallocate by trying in-place grow; otherwise reserve a new block and free old. + pub fn resize( + &mut self, + old_alloc: Allocation, + new_len: usize, + ) -> Result { + if self.lower.contains(old_alloc.addr) { + maybe_move(&mut self.lower, &mut self.upper, old_alloc, new_len) + } else { + maybe_move(&mut self.upper, &mut self.lower, old_alloc, new_len) + } + } +} + +/// Try to realloc using slab that owns the old allocation; if that fails, +/// try to allocate in the other slab. The function prefers to move allocations +/// between slabs only when necessary based on size thresholds. 
#[inline]
fn maybe_move<const A: usize, const B: usize>(
    slab: &mut Slab<A>,
    other: &mut Slab<B>,
    old_alloc: Allocation,
    new_len: usize,
) -> Result<Allocation, AllocError> {
    // `A` is the slot size of the slab that owns the allocation, `B` the
    // other tier's. Move lower -> upper when the request no longer fits one
    // lower slot; move upper -> lower when it now fits a lower slot.
    let needs_move = if A < B { new_len > A } else { new_len <= B };
    if !needs_move {
        return slab.resize(old_alloc, new_len);
    }

    // Reserve in the destination slab before releasing the old block, so an
    // allocation failure leaves the old reservation intact. The caller is
    // responsible for copying the payload.
    let new_alloc = other.alloc(new_len)?;

    slab.dealloc(old_alloc)?;
    Ok(new_alloc)
}

impl<const L: usize, const U: usize> BufferProvider for BufferPool<L, U> {
    fn alloc(&self, len: usize) -> Result<Allocation, AllocError> {
        self.inner.borrow_mut().alloc(len)
    }

    fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> {
        self.inner.borrow_mut().dealloc(alloc)
    }

    fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result<Allocation, AllocError> {
        self.inner.borrow_mut().resize(old_alloc, new_len)
    }
}

#[cfg(all(test, loom))]
impl<const L: usize, const U: usize> BufferProvider for BufferPoolSync<L, U> {
    fn alloc(&self, len: usize) -> Result<Allocation, AllocError> {
        self.inner.lock().expect("poisoned mutex").alloc(len)
    }

    fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> {
        self.inner.lock().expect("poisoned mutex").dealloc(alloc)
    }

    fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result<Allocation, AllocError> {
        self.inner
            .lock()
            .expect("poisoned mutex")
            .resize(old_alloc, new_len)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Helper: slab of slot size `N` at an N-aligned test base address.
    fn make_slab<const N: usize>(size: usize) -> Slab<N> {
        let base = align_up(0x10000, N) as u64;
        Slab::<N>::new(base, size).unwrap()
    }

    // Helper: pool whose base is aligned for the larger of the two tiers.
    fn make_pool<const L: usize, const U: usize>(size: usize) -> BufferPool<L, U> {
        let base = align_up(0x10000, L.max(U)) as u64;
        BufferPool::<L, U>::new(base, size).unwrap()
    }

    #[test]
    fn test_slab_new_success() {
        let slab = Slab::<256>::new(0x10000, 1024).unwrap();
        assert_eq!(slab.capacity(), 1024);
        assert_eq!(slab.free_bytes(), 1024);
    }

    #[test]
    fn test_slab_new_misaligned() {
        let result = Slab::<256>::new(0x10001, 1024);
        assert!(matches!(result, Err(AllocError::InvalidAlign(0x10001))));
    }

    #[test]
    fn test_slab_new_empty_region() {
        let result = Slab::<256>::new(0x10000, 100);
        assert!(matches!(result, Err(AllocError::EmptyRegion)));
    }

#[test] + fn test_slab_alloc_single_slot() { + let mut slab = make_slab::<256>(1024); + let alloc = slab.alloc(128).unwrap(); + assert_eq!(alloc.len, 256); + assert_eq!(slab.free_bytes(), 1024 - 256); + } + + #[test] + fn test_slab_alloc_multiple_slots() { + let mut slab = make_slab::<256>(1024); + let alloc = slab.alloc(600).unwrap(); + assert_eq!(alloc.len, 768); // 3 slots × 256 bytes + assert_eq!(slab.free_bytes(), 1024 - 768); + } + + #[test] + fn test_slab_alloc_zero_length() { + let mut slab = make_slab::<256>(1024); + let result = slab.alloc(0); + assert!(matches!(result, Err(AllocError::InvalidArg))); + } + + #[test] + fn test_slab_alloc_too_large() { + let mut slab = make_slab::<256>(1024); + let result = slab.alloc(2048); + assert!(matches!(result, Err(AllocError::OutOfMemory))); + } + + #[test] + fn test_slab_alloc_until_full() { + let mut slab = make_slab::<256>(1024); + + // Allocate all 4 slots + let _a1 = slab.alloc(256).unwrap(); + let a2 = slab.alloc(256).unwrap(); + let _a3 = slab.alloc(256).unwrap(); + let _a4 = slab.alloc(256).unwrap(); + + assert_eq!(slab.free_bytes(), 0); + + // Next allocation should fail + let result = slab.alloc(256); + assert!(matches!(result, Err(AllocError::OutOfMemory))); + + // Free one and retry + slab.dealloc(a2).unwrap(); + let a5 = slab.alloc(256).unwrap(); + assert_eq!(a5.addr, a2.addr); // Should reuse same slot + } + + #[test] + fn test_slab_free_success() { + let mut slab = make_slab::<256>(1024); + let alloc = slab.alloc(256).unwrap(); + assert_eq!(slab.free_bytes(), 768); + + slab.dealloc(alloc).unwrap(); + assert_eq!(slab.free_bytes(), 1024); + } + + #[test] + fn test_slab_free_invalid_length() { + let mut slab = make_slab::<256>(1024); + let mut alloc = slab.alloc(256).unwrap(); + alloc.len = 100; // Invalid: not multiple of N + + let result = slab.dealloc(alloc); + assert!(matches!(result, Err(AllocError::InvalidFree(_, 100)))); + } + + #[test] + fn test_slab_free_double_free() { + let mut slab = 
make_slab::<256>(1024); + let alloc = slab.alloc(256).unwrap(); + + slab.dealloc(alloc).unwrap(); + let result = slab.dealloc(alloc); + assert!(matches!(result, Err(AllocError::InvalidFree(_, _)))); + } + + #[test] + fn test_slab_free_invalid_address() { + let mut slab = make_slab::<256>(1024); + let alloc = Allocation { + addr: 0x99999, + len: 256, + }; + + let result = slab.dealloc(alloc); + assert!(matches!(result, Err(AllocError::InvalidFree(0x99999, _)))); + } + + #[test] + fn test_slab_cursor_optimization_lifo() { + let mut slab = make_slab::<256>(1024); + + let a1 = slab.alloc(256).unwrap(); + let addr1 = a1.addr; + + slab.dealloc(a1).unwrap(); + + // Next allocation should reuse same slot (cursor moved back) + let a2 = slab.alloc(256).unwrap(); + assert_eq!(a2.addr, addr1); + } + + #[test] + fn test_slab_cursor_rewind_for_single_slot() { + let mut slab = make_slab::<256>(1024); + + let _a1 = slab.alloc(256).unwrap(); + let a2 = slab.alloc(256).unwrap(); + let _a3 = slab.alloc(256).unwrap(); + + // Free single-slot at position 1, before cursor at 3 + slab.dealloc(a2).unwrap(); + + // Cursor should rewind to 1 + let a4 = slab.alloc(256).unwrap(); + // Should reuse slot 1 + assert_eq!(a4.addr, a2.addr); + } + + #[test] + fn test_slab_grow_inplace_success() { + let mut slab = make_slab::<256>(1024); + let alloc = slab.alloc(256).unwrap(); + + // Grow from 256 to 512 (adjacent slot is free) + let grown = slab.try_grow_inplace(alloc, 512).unwrap(); + assert!(grown.is_some()); + assert_eq!(grown.unwrap().len, 512); + assert_eq!(grown.unwrap().addr, alloc.addr); + } + + #[test] + fn test_slab_grow_inplace_blocked() { + let mut slab = make_slab::<256>(1024); + let a1 = slab.alloc(256).unwrap(); + let _a2 = slab.alloc(256).unwrap(); // Blocks growth + + // Can't grow because next slot is allocated + let result = slab.try_grow_inplace(a1, 512).unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_slab_shrink_inplace() { + let mut slab = 
make_slab::<256>(1024); + let alloc = slab.alloc(512).unwrap(); // 2 slots + + let shrunk = slab.shrink_inplace(alloc, 256).unwrap(); + assert_eq!(shrunk.len, 256); + assert_eq!(shrunk.addr, alloc.addr); + assert_eq!(slab.free_bytes(), 1024 - 256); + } + + #[test] + fn test_slab_realloc_grow_inplace() { + let mut slab = make_slab::<256>(1024); + let alloc = slab.alloc(256).unwrap(); + + let new_alloc = slab.resize(alloc, 512).unwrap(); + assert_eq!(new_alloc.addr, alloc.addr); // Same address (in-place) + assert_eq!(new_alloc.len, 512); + } + + #[test] + fn test_slab_realloc_grow_relocate() { + let mut slab = make_slab::<256>(1024); + let a1 = slab.alloc(256).unwrap(); + let _a2 = slab.alloc(256).unwrap(); // Blocks growth + + let new_alloc = slab.resize(a1, 512).unwrap(); + assert_ne!(new_alloc.addr, a1.addr); // Different address (relocated) + assert_eq!(new_alloc.len, 512); + } + + #[test] + fn test_slab_realloc_shrink() { + let mut slab = make_slab::<256>(1024); + let alloc = slab.alloc(512).unwrap(); + + let new_alloc = slab.resize(alloc, 256).unwrap(); + assert_eq!(new_alloc.addr, alloc.addr); + assert_eq!(new_alloc.len, 256); + } + + #[test] + fn test_slab_realloc_same_size() { + let mut slab = make_slab::<256>(1024); + let alloc = slab.alloc(256).unwrap(); + + let new_alloc = slab.resize(alloc, 256).unwrap(); + assert_eq!(new_alloc.addr, alloc.addr); + assert_eq!(new_alloc.len, alloc.len); + } + + #[test] + fn test_slab_fragmentation_handling() { + let mut slab = make_slab::<256>(1024); + + // Create fragmentation: [U][F][U][F] + let a1 = slab.alloc(256).unwrap(); + let a2 = slab.alloc(256).unwrap(); + let _a3 = slab.alloc(256).unwrap(); + let _a4 = slab.alloc(256).unwrap(); + + slab.dealloc(a2).unwrap(); + slab.dealloc(a1).unwrap(); + + // Should still be able to allocate 2-slot buffer + let big = slab.alloc(512).unwrap(); + assert_eq!(big.len, 512); + } + + #[test] + fn test_pool_new_success() { + let pool = BufferPool::<256, 4096>::new(0x10000, 1024 * 
1024).unwrap(); + assert!(pool.inner.borrow().lower.capacity() > 0); + assert!(pool.inner.borrow().upper.capacity() > 0); + } + + #[test] + fn test_pool_alloc_small_to_lower() { + let pool = make_pool::<256, 4096>(1024 * 1024); + let alloc = pool.alloc(128).unwrap(); + + // Should come from lower slab + assert!(pool.inner.borrow().lower.contains(alloc.addr)); + assert_eq!(alloc.len, 256); + } + + #[test] + fn test_pool_alloc_large_to_upper() { + let pool = make_pool::<256, 4096>(1024 * 1024); + let alloc = pool.alloc(1500).unwrap(); + + // Should come from upper slab + assert!(pool.inner.borrow().upper.contains(alloc.addr)); + assert_eq!(alloc.len, 4096); + } + + #[test] + fn test_pool_alloc_fallback_to_upper() { + let pool = make_pool::<256, 4096>(1024 * 1024); + + // Fill lower slab completely + let mut allocations = Vec::new(); + while pool.inner.borrow().lower.free_bytes() > 0 { + allocations.push(pool.inner.borrow_mut().lower.alloc(256).unwrap()); + } + + // Small allocation should fallback to upper slab + let alloc = pool.alloc(128).unwrap(); + assert!(pool.inner.borrow().upper.contains(alloc.addr)); + } + + #[test] + fn test_pool_free_from_lower() { + let pool = make_pool::<256, 4096>(1024 * 1024); + let alloc = pool.alloc(128).unwrap(); + + let free_before = pool.inner.borrow().lower.free_bytes(); + pool.dealloc(alloc).unwrap(); + assert_eq!( + pool.inner.borrow().lower.free_bytes(), + free_before + alloc.len + ); + } + + #[test] + fn test_pool_free_from_upper() { + let pool = make_pool::<256, 4096>(1024 * 1024); + let alloc = pool.alloc(1500).unwrap(); + + let free_before = pool.inner.borrow().upper.free_bytes(); + pool.dealloc(alloc).unwrap(); + assert_eq!( + pool.inner.borrow().upper.free_bytes(), + free_before + alloc.len + ); + } + + #[test] + fn test_pool_realloc_within_same_tier() { + let pool = make_pool::<256, 4096>(1024 * 1024); + let alloc = pool.alloc(128).unwrap(); + + // Realloc within lower tier (128 -> 200, both fit in 256 slots) + let 
new_alloc = pool.resize(alloc, 200).unwrap(); + assert!(pool.inner.borrow().lower.contains(new_alloc.addr)); + } + + #[test] + fn test_pool_realloc_move_to_different_tier() { + let pool = make_pool::<256, 4096>(1024 * 1024); + let alloc = pool.alloc(128).unwrap(); + assert!(pool.inner.borrow().lower.contains(alloc.addr)); + + // Realloc to size that needs upper tier + let new_alloc = pool.resize(alloc, 1500).unwrap(); + assert!(pool.inner.borrow().upper.contains(new_alloc.addr)); + } + + #[test] + fn test_pool_realloc_shrink_stays_in_tier() { + let pool = make_pool::<256, 4096>(1024 * 1024); + let alloc = pool.alloc(1500).unwrap(); + assert!(pool.inner.borrow().upper.contains(alloc.addr)); + + // Shrink but stay in upper tier + let new_alloc = pool.resize(alloc, 1000).unwrap(); + assert!(pool.inner.borrow().upper.contains(new_alloc.addr)); + } + + #[test] + fn test_pool_stress_many_allocations() { + let pool = make_pool::<256, 4096>(4 * 1024 * 1024); + let mut allocations = Vec::new(); + + // Allocate many buffers + for i in 0..100 { + let size = if i % 2 == 0 { 128 } else { 1500 }; + allocations.push(pool.alloc(size).unwrap()); + } + + // Free half of them + for i in (0..100).step_by(2) { + pool.dealloc(allocations[i]).unwrap(); + } + + // Should be able to allocate again + for i in 0..50 { + let size = if i % 2 == 0 { 128 } else { 1500 }; + let _alloc = pool.alloc(size).unwrap(); + } + } + + #[test] + fn test_pool_mixed_workload() { + let pool = make_pool::<256, 4096>(2 * 1024 * 1024); + + // Simulate virtio-net workload + let desc_buf = pool.alloc(64).unwrap(); // Control message + let rx_buf1 = pool.alloc(1500).unwrap(); // MTU packet + let rx_buf2 = pool.alloc(1500).unwrap(); // MTU packet + let tx_buf = pool.alloc(4096).unwrap(); // Large buffer + + // Free and reallocate + pool.dealloc(rx_buf1).unwrap(); + let rx_buf3 = pool.alloc(1500).unwrap(); + + // Should reuse freed buffer (LIFO) + assert_eq!(rx_buf3.addr, rx_buf1.addr); + + 
pool.dealloc(desc_buf).unwrap(); + pool.dealloc(rx_buf2).unwrap(); + pool.dealloc(rx_buf3).unwrap(); + pool.dealloc(tx_buf).unwrap(); + } + + #[test] + fn test_pool_zero_allocation_error() { + let pool = make_pool::<256, 4096>(1024 * 1024); + let result = pool.alloc(0); + assert!(matches!(result, Err(AllocError::InvalidArg))); + } + + #[test] + fn test_pool_too_large_allocation() { + let pool = make_pool::<256, 4096>(1024 * 1024); + let result = pool.alloc(2 * 1024 * 1024); // Larger than pool + assert!(matches!(result, Err(AllocError::OutOfMemory))); + } + + #[test] + fn test_align_up_helper() { + assert_eq!(align_up(0, 256), 0); + assert_eq!(align_up(1, 256), 256); + assert_eq!(align_up(256, 256), 256); + assert_eq!(align_up(257, 256), 512); + assert_eq!(align_up(511, 256), 512); + assert_eq!(align_up(512, 256), 512); + } + + #[test] + fn test_slab_usable_size() { + let slab = make_slab::<256>(1024); + assert_eq!(slab.usable_size(0, 0), 0); + assert_eq!(slab.usable_size(0, 1), 256); + assert_eq!(slab.usable_size(0, 256), 256); + assert_eq!(slab.usable_size(0, 257), 512); + } + + #[test] + fn test_slab_contains() { + let slab = make_slab::<256>(1024); + let range = slab.range(); + + assert!(slab.contains(range.start)); + assert!(!slab.contains(range.end)); // Exclusive end + assert!(!slab.contains(0)); + } + + // Edge case: allocation exactly at boundary + #[test] + fn test_pool_boundary_allocation() { + let pool = make_pool::<256, 4096>(1024 * 1024); + + // Allocate exactly at boundary + let alloc = pool.alloc(256).unwrap(); + assert!(pool.inner.borrow().lower.contains(alloc.addr)); + + // Allocate just over boundary + let alloc2 = pool.alloc(257).unwrap(); + assert!(pool.inner.borrow().upper.contains(alloc2.addr)); + } + + // Test overflow protection + #[test] + fn test_addr_of_overflow_protection() { + let slab = make_slab::<4096>(8192); + + // This should not panic due to overflow checks + let addr = slab.addr_of(usize::MAX); + assert!(addr.is_none()); + } + + 
#[test] + fn test_no_overlapping_allocations() { + let mut slab = make_slab::<4096>(32768); // 8 slots + + // Allocate slot 0-1 + let a1 = slab.alloc(8000).unwrap(); + assert_eq!(a1.len, 8192); + + // Shrink to slot 0 only + let a2 = slab.shrink_inplace(a1, 4000).unwrap(); + assert_eq!(a2.len, 4096); + + // Allocate at slot 1-2 + let a3 = slab.alloc(8000).unwrap(); + assert_eq!(a3.len, 8192); + let slot1_addr = a2.addr + 4096; + assert_eq!(a3.addr, slot1_addr); + + // Free slot 0 + slab.dealloc(a2).unwrap(); + + // Try to allocate 2 slots - should NOT get slot 0-1 because slot 1 is occupied! + let a4 = slab.alloc(8000).unwrap(); + assert_ne!(a4.addr, a2.addr); // Should be at a different location + + slab.dealloc(a3).unwrap(); + slab.dealloc(a4).unwrap(); + } +} + +#[cfg(test)] +mod fuzz { + use quickcheck::{Arbitrary, Gen, QuickCheck}; + + use super::*; + + const MAX_OPS: usize = 10; + const MAX_ALLOC_SIZE: usize = 8192; + + #[derive(Clone, Debug)] + enum Op { + Alloc(usize), + Dealloc(usize), + Resize(usize, usize), + } + + impl Arbitrary for Op { + fn arbitrary(g: &mut Gen) -> Self { + match u8::arbitrary(g) % 3 { + 0 => Op::Alloc(usize::arbitrary(g) % MAX_ALLOC_SIZE + 1), + 1 => Op::Dealloc(usize::arbitrary(g)), + 2 => Op::Resize( + usize::arbitrary(g), + usize::arbitrary(g) % MAX_ALLOC_SIZE + 1, + ), + _ => unreachable!(), + } + } + } + + #[derive(Clone, Debug)] + struct Scenario { + pool_size: usize, + ops: Vec, + } + + impl Arbitrary for Scenario { + fn arbitrary(g: &mut Gen) -> Self { + let pool_size = (usize::arbitrary(g) % (4 * 1024 * 1024)) + (1024 * 1024); + let num_ops = usize::arbitrary(g) % MAX_OPS + 1; + let ops = (0..num_ops).map(|_| Op::arbitrary(g)).collect(); + + Scenario { pool_size, ops } + } + } + + fn run_scenario(s: Scenario) -> bool { + let base = align_up(0x10000, 4096) as u64; + let pool = match BufferPool::<256, 4096>::new(base, s.pool_size) { + Ok(p) => p, + Err(_) => return true, + }; + + let mut allocations: Vec = Vec::new(); + + for 
op in &s.ops { + match op { + Op::Alloc(size) => match pool.alloc(*size) { + Ok(alloc) => { + assert!(alloc.len >= *size); + allocations.push(alloc); + } + Err(AllocError::OutOfMemory) => {} + Err(_) => { + return false; + } + }, + Op::Dealloc(idx) => { + if allocations.is_empty() { + continue; + } + + let idx = idx % allocations.len(); + let alloc = allocations.swap_remove(idx); + + match pool.dealloc(alloc) { + Ok(_) => {} + Err(_) => return false, + } + } + Op::Resize(idx, new_size) => { + if allocations.is_empty() { + continue; + } + + let idx = idx % allocations.len(); + let old_alloc = allocations[idx]; + + match pool.resize(old_alloc, *new_size) { + Ok(new_alloc) => { + assert!(new_alloc.len >= *new_size); + allocations[idx] = new_alloc; + } + Err(AllocError::OutOfMemory) => {} + Err(_) => return false, + } + } + } + + if check_pool_invariants(&pool, &allocations).is_err() { + return false; + } + } + + // Cleanup + for alloc in &allocations { + if pool.dealloc(*alloc).is_err() { + return false; + } + } + + check_pool_invariants(&pool, &allocations).is_ok() + } + + fn check_slab_invariants(slab: &Slab) -> Result<(), &'static str> { + // Check that number of used + free slots equals total + let used = slab.used_slots.count_ones(..); + let free = slab.used_slots.count_zeroes(..); + if used + free != slab.used_slots.len() { + return Err("used + free != total slots"); + } + + let expected_free = free * N; + if slab.free_bytes() != expected_free { + return Err("free_bytes doesn't match bitmap"); + } + + if let Some(alloc) = slab.last_free_run { + if alloc.len == 0 || alloc.len % N != 0 { + return Err("last_free_run has invalid length"); + } + if !slab.contains(alloc.addr) { + return Err("last_free_run addr outside range"); + } + } + + Ok(()) + } + + fn check_pool_invariants( + pool: &BufferPool, + allocations: &[Allocation], + ) -> Result<(), &'static str> { + check_slab_invariants(&pool.inner.borrow().lower)?; + check_slab_invariants(&pool.inner.borrow().upper)?; 
+ + if pool.inner.borrow().lower.range().end > pool.inner.borrow().upper.range().start { + return Err("lower and upper ranges overlap"); + } + + let mut seen = std::collections::HashSet::new(); + + for alloc in allocations { + if !pool.inner.borrow().lower.contains(alloc.addr) + && !pool.inner.borrow().upper.contains(alloc.addr) + { + return Err("allocation address outside pool ranges"); + } + + if alloc.len % L != 0 && alloc.len % U != 0 { + return Err("allocation length not aligned to any tier"); + } + + if !seen.insert(alloc.addr) { + return Err("duplicate allocation address in tracking"); + } + } + + Ok(()) + } + + #[test] + fn prop_allocator_invariants() { + #[cfg(miri)] + let tests = 10; + #[cfg(not(miri))] + let tests = 1000; + + QuickCheck::new() + .tests(tests) + .quickcheck(run_scenario as fn(Scenario) -> bool); + } +} diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs new file mode 100644 index 000000000..95db0b7ba --- /dev/null +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -0,0 +1,790 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +use alloc::vec; +use alloc::vec::Vec; + +use bytes::Bytes; + +use super::*; + +/// A completion received by the driver (producer) side. +/// +/// Contains the completion data and metadata about the completed entry. 
+/// The `data` field is a zero-copy [`Bytes`] backed by a shared-memory +/// pool allocation that is returned when the last `Bytes` clone is dropped. +#[derive(Debug)] +pub struct RecvCompletion { + /// Token identifying which entry this completion corresponds to. + pub token: Token, + /// Completion data from the device. + pub data: Bytes, +} + +/// Allocation tracking for an in-flight descriptor chain. +/// +/// Each variant corresponds to a buffer layout submitted by the driver +/// (guest/producer) and consumed by the device (host/consumer). +/// "Readable" and "writable" are from the device's perspective, following +/// the virtio convention. +#[derive(Debug, Clone, Copy)] +pub(crate) enum Inflight { + /// Driver sends data, device only acknowledges (fire-and-forget). + /// The readable buffer carries the entry; no writable buffer for a + /// device response. + ReadOnly { entry: Allocation }, + /// Driver pre-posts a writable buffer for the device to fill. + /// No readable entry - the device writes a response into the + /// completion buffer unprompted (e.g. event delivery). + WriteOnly { completion: Allocation }, + /// Bidirectional: driver sends an entry, device writes a response. + /// The readable buffer carries the entry, the writable buffer + /// receives the completion (typical request/response pattern). + ReadWrite { + entry: Allocation, + completion: Allocation, + }, +} + +impl Inflight { + fn entry(&self) -> Option { + match self { + Inflight::ReadOnly { entry } => Some(*entry), + Inflight::ReadWrite { entry, .. } => Some(*entry), + Inflight::WriteOnly { .. } => None, + } + } + + fn completion(&self) -> Option { + match self { + Inflight::WriteOnly { completion } => Some(*completion), + Inflight::ReadWrite { completion, .. } => Some(*completion), + Inflight::ReadOnly { .. 
} => None, + } + } + + fn try_into_chain(self, entry_len: usize) -> Result { + if let Some(entry) = self.entry() + && entry_len > entry.len + { + return Err(VirtqError::EntryTooLarge); + } + + Ok(match self { + Inflight::ReadOnly { entry } => BufferChainBuilder::new() + .readable(entry.addr, entry_len as u32) + .build()?, + Inflight::WriteOnly { completion } => BufferChainBuilder::new() + .writable(completion.addr, completion.len as u32) + .build()?, + Inflight::ReadWrite { entry, completion } => BufferChainBuilder::new() + .readable(entry.addr, entry_len as u32) + .writable(completion.addr, completion.len as u32) + .build()?, + }) + } +} + +/// A high-level virtqueue producer (driver side). +/// +/// The producer sends entries to the consumer (device), and receives completions. +/// This is typically used on the driver/guest side. +/// +/// # Example +/// +/// ```ignore +/// let mut producer = VirtqProducer::new(layout, mem, notifier, pool); +/// +/// // Build and submit an entry +/// let mut se = producer.chain().entry(64).completion(64).build()?; +/// se.write_all(b"hello")?; +/// let token = producer.submit(se)?; +/// +/// // Later, poll for completion +/// if let Some(cqe) = producer.poll()? { +/// assert_eq!(cqe.token, token); +/// println!("Got completion: {:?}", cqe.data); +/// } +/// ``` +pub struct VirtqProducer { + inner: RingProducer, + notifier: N, + pool: P, + inflight: Vec>, +} + +impl VirtqProducer +where + M: MemOps + Clone, + N: Notifier, + P: BufferProvider + Clone, +{ + /// Create a new virtqueue producer. 
+ /// + /// # Arguments + /// + /// * `layout` - Ring memory layout (descriptor table and event suppression addresses) + /// * `mem` - Memory operations implementation for reading/writing to shared memory + /// * `notifier` - Callback for notifying the device (consumer) about new entries + /// * `pool` - Buffer allocator for entry/completion data + pub fn new(layout: Layout, mem: M, notifier: N, pool: P) -> Self { + let inner = RingProducer::new(layout, mem); + let inflight = vec![None; inner.len()]; + + Self { + inner, + pool, + notifier, + inflight, + } + } + + /// Poll for a single completion from the device. + /// + /// Returns `Ok(Some(completion))` if a completion is available, `Ok(None)` if no + /// completions are ready (would block), or an error if the device misbehaved. + /// + /// The returned [`RecvCompletion::data`] is a zero-copy [`Bytes`] backed by the + /// shared-memory allocation via [`BufferOwner`]. The pool allocation is + /// held alive as long as any `Bytes` clone exists, and is returned to the + /// pool when the last clone is dropped. + /// + /// # Errors + /// + /// - [`VirtqError::InvalidState`] - Device returned invalid descriptor ID or + /// wrote more data than the completion buffer capacity + pub fn poll(&mut self) -> Result, VirtqError> + where + M: Send + Sync + 'static, + P: Send + Sync + 'static, + { + let used = match self.inner.poll_used() { + Ok(u) => u, + Err(RingError::WouldBlock) => return Ok(None), + Err(e) => return Err(e.into()), + }; + + let id = used.id as usize; + let inf = self + .inflight + .get_mut(id) + .ok_or(VirtqError::InvalidState)? 
+ .take() + .ok_or(VirtqError::InvalidState)?; + + let written = used.len as usize; + + // Free entry buffers (request data no longer needed) + if let Some(entry) = inf.entry() { + self.pool.dealloc(entry)?; + } + + // Read completion data + let data = match inf.completion() { + Some(buf) => { + if written > buf.len { + let _ = self.pool.dealloc(buf); + return Err(VirtqError::InvalidState); + } + let owner = BufferOwner { + pool: self.pool.clone(), + mem: self.inner.mem().clone(), + alloc: buf, + written, + }; + Bytes::from_owner(owner) + } + None => Bytes::new(), + }; + + Ok(Some(RecvCompletion { + token: Token(used.id), + data, + })) + } + + /// Drain all available completions, calling the provided closure for each. + /// + /// This is a convenience method that repeatedly calls [`poll`](Self::poll) + /// until no more completions are available. + /// + /// # Arguments + /// + /// * `f` - Closure called for each completion with its token and data + /// + /// # Example + /// + /// ```ignore + /// producer.drain(|token, data| { + /// println!("Got {:?}: {} bytes", token, data.len()); + /// })?; + /// ``` + pub fn drain(&mut self, mut f: impl FnMut(Token, Bytes)) -> Result<(), VirtqError> + where + M: Send + Sync + 'static, + P: Send + Sync + 'static, + { + while let Some(cqe) = self.poll()? { + f(cqe.token, cqe.data); + } + + Ok(()) + } + + /// Begin building a descriptor chain for submission. + /// + /// Returns a [`ChainBuilder`] that allocates buffers from the pool. + /// ``` + pub fn chain(&self) -> ChainBuilder { + ChainBuilder::new(self.inner.mem().clone(), self.pool.clone()) + } + + /// Submit a [`SendEntry`] to the ring. + /// + /// Publishes the descriptor chain, stores the in-flight tracking state, + /// and notifies the consumer if event suppression allows. 
+ /// + /// # Errors + /// + /// - [`VirtqError::EntryTooLarge`] - written exceeds entry buffer capacity + /// - [`VirtqError::RingError`] - ring is full + /// - [`VirtqError::InvalidState`] - descriptor ID collision + pub fn submit(&mut self, mut entry: SendEntry) -> Result { + let written = entry.written; + let inflight = entry.inflight.take().ok_or(VirtqError::InvalidState)?; + + let cursor_before = self.inner.avail_cursor(); + let chain = inflight.try_into_chain(written)?; + let id = self.inner.submit_available(&chain)?; + + let slot = self + .inflight + .get_mut(id as usize) + .ok_or(VirtqError::InvalidState)?; + + if slot.is_some() { + return Err(VirtqError::InvalidState); + } + + *slot = Some(inflight); + + let should_notify = self.inner.should_notify_since(cursor_before)?; + if should_notify { + self.notifier.notify(QueueStats { + num_free: self.inner.num_free(), + num_inflight: self.inner.num_inflight(), + }); + } + + Ok(Token(id)) + } + + /// Get the current used cursor position. + /// + /// Useful for setting up descriptor-based event suppression. + #[inline] + pub fn used_cursor(&self) -> RingCursor { + self.inner.used_cursor() + } + + /// Configure event suppression for used buffer notifications. 
+ /// + /// This controls when the device (consumer) signals us about completed buffers: + /// + /// - [`SuppressionKind::Enable`]: Always signal (default) - good for latency + /// - [`SuppressionKind::Disable`]: Never signal - caller must poll + /// - [`SuppressionKind::Descriptor`]: Signal only at specific cursor position + /// + /// # Example: Completion Batching + /// + /// ```ignore + /// // Submit entries, then suppress notifications until all complete + /// let mut se = producer.chain().entry(64).completion(128).build()?; + /// se.write_all(b"entry1")?; + /// producer.submit(se)?; + /// let cursor = producer.used_cursor(); + /// producer.set_used_suppression(SuppressionKind::Descriptor(cursor))?; + /// // Device will notify only after reaching that cursor position + /// ``` + pub fn set_used_suppression(&mut self, kind: SuppressionKind) -> Result<(), VirtqError> { + match kind { + SuppressionKind::Enable => self.inner.enable_used_notifications()?, + SuppressionKind::Disable => self.inner.disable_used_notifications()?, + SuppressionKind::Descriptor(cursor) => self + .inner + .enable_used_notifications_desc(cursor.head(), cursor.wrap())?, + } + Ok(()) + } +} + +/// Builder for configuring a descriptor chain's buffer layout. +/// +/// If dropped without building, no resources are leaked (allocations are +/// deferred to [`build`](Self::build)). +#[must_use = "call .build() to create a SendEntry"] +pub struct ChainBuilder { + mem: M, + pool: P, + entry_cap: Option, + cqe_cap: Option, +} + +impl ChainBuilder { + fn new(mem: M, pool: P) -> Self { + Self { + mem, + pool, + entry_cap: None, + cqe_cap: None, + } + } + + fn alloc( + &self, + size: usize, + ) -> Result>, VirtqError> { + let alloc = self.pool.alloc(size)?; + let pool = self.pool.clone(); + + Ok(AllocGuard::new(alloc, move |a| { + let _ = pool.dealloc(a); + })) + } + + /// Request an entry buffer of `cap` bytes. + /// + /// The entry holds data sent from the driver to the consumer (device). 
+ /// The actual allocation is deferred to [`build`](Self::build). + pub fn entry(mut self, cap: usize) -> Self { + self.entry_cap = Some(cap); + self + } + + /// Request a completion buffer of `cap` bytes. + /// + /// The completion buffer is filled by the consumer and returned via + /// [`VirtqProducer::poll`] as [`RecvCompletion`]. + pub fn completion(mut self, cap: usize) -> Self { + self.cqe_cap = Some(cap); + self + } + + /// Allocate buffers and return a [`SendEntry`] for writing. + /// + /// # Errors + /// + /// - [`VirtqError::InvalidState`] - Neither entry nor completion requested + /// - [`VirtqError::Alloc`] - Pool exhausted + pub fn build(self) -> Result, VirtqError> { + if self.entry_cap.is_none() && self.cqe_cap.is_none() { + return Err(VirtqError::InvalidState); + } + + let entry_alloc = self.entry_cap.map(|cap| self.alloc(cap)).transpose()?; + let completion_alloc = self.cqe_cap.map(|cap| self.alloc(cap)).transpose()?; + + let inflight = match (entry_alloc, completion_alloc) { + (Some(entry), Some(cqe)) => Inflight::ReadWrite { + entry: entry.release(), + completion: cqe.release(), + }, + (Some(entry), None) => Inflight::ReadOnly { + entry: entry.release(), + }, + (None, Some(cqe)) => Inflight::WriteOnly { + completion: cqe.release(), + }, + (None, None) => unreachable!(), + }; + + Ok(SendEntry { + mem: self.mem, + pool: self.pool, + inflight: Some(inflight), + written: 0, + }) + } +} + +/// A configured entry ready for writing and submission. +/// +/// Created by [`ChainBuilder::build`]. Write data into the entry buffer +/// with [`write_all`](Self::write_all), +/// or use [`buf_mut`](Self::buf_mut) for zero-copy direct access. +/// Then submit via [`VirtqProducer::submit`]. 
+/// +/// # Examples +/// +/// ```ignore +/// let mut se = producer.chain().entry(64).completion(128).build()?; +/// se.write_all(b"header")?; +/// se.write_all(b" body")?; +/// let tok = producer.submit(se)?; +/// +/// // Zero-copy direct access +/// let mut se = producer.chain().entry(128).build()?; +/// let buf = se.buf_mut()?; +/// let n = serialize_into(buf); +/// se.set_written(n)?; +/// let tok = producer.submit(se)?; +/// ``` +/// +/// If dropped without submitting, allocated buffers are returned to the pool. +#[must_use = "dropping without submitting deallocates the buffers"] +pub struct SendEntry { + mem: M, + pool: P, + written: usize, + inflight: Option, +} + +impl SendEntry { + fn entry(&self) -> Result { + self.inflight + .as_ref() + .and_then(|i| i.entry()) + .ok_or(VirtqError::NoReadableBuffer) + } + + /// Total entry buffer capacity in bytes. + /// + /// Returns 0 when there are no entry buffers. + pub fn capacity(&self) -> usize { + self.inflight + .as_ref() + .and_then(|i| i.entry()) + .map_or(0, |a| a.len) + } + + /// Number of bytes written so far via [`write_all`](Self::write_all) + /// or [`set_written`](Self::set_written). + pub fn written(&self) -> usize { + self.written + } + + /// Set the write cursor to an explicit byte count. + /// + /// Use this after [`buf_mut`](Self::buf_mut) where you wrote directly + /// into the buffer. The value tells the consumer how many bytes of + /// the entry buffer are valid. + /// + /// # Errors + /// + /// - [`VirtqError::EntryTooLarge`] - `written` exceeds entry buffer capacity + pub fn set_written(&mut self, written: usize) -> Result<(), VirtqError> { + if written > self.capacity() { + return Err(VirtqError::EntryTooLarge); + } + + self.written = written; + Ok(()) + } + + /// Remaining writable capacity in the entry buffer. + pub fn remaining(&self) -> usize { + self.capacity() - self.written + } + + /// Write the entire buffer into the entry. + /// + /// Appends at the current write position. 
Uses [`MemOps::write`] + /// (volatile on host side). + /// + /// # Errors + /// + /// - [`VirtqError::EntryTooLarge`] - buf exceeds remaining capacity + /// - [`VirtqError::NoReadableBuffer`] - no entry buffer allocated + /// - [`VirtqError::MemoryWriteError`] - underlying write failed + pub fn write_all(&mut self, buf: &[u8]) -> Result<(), VirtqError> { + let alloc = self.entry()?; + + if buf.len() > self.remaining() { + return Err(VirtqError::EntryTooLarge); + } + + let addr = alloc.addr + self.written as u64; + self.mem + .write(addr, buf) + .map_err(|_| VirtqError::MemoryWriteError)?; + + self.written += buf.len(); + Ok(()) + } + + /// Zero-copy access to the full entry buffer in shared memory. + /// + /// Returns `&mut [u8]` pointing directly into the allocated buffer. + /// This is safe on the guest side (producer). After writing, call + /// [`set_written`](Self::set_written) to record how many bytes are valid. + /// + /// **Note**: This bypasses the write cursor. Use either `buf_mut()` + + /// `set_written(n)` or the `write_all` method, not both. 
+ /// + /// # Errors + /// + /// - [`VirtqError::NoReadableBuffer`] - no entry buffer allocated + /// - [`VirtqError::MemoryWriteError`] - failed to access shared memory + pub fn buf_mut(&mut self) -> Result<&mut [u8], VirtqError> { + let alloc = self.entry()?; + unsafe { + self.mem + .as_mut_slice(alloc.addr, alloc.len) + .map_err(|_| VirtqError::MemoryWriteError) + } + } +} + +impl Drop for SendEntry { + fn drop(&mut self) { + let inf = match self.inflight.take() { + Some(i) => i, + None => return, // already submitted + }; + if let Some(a) = inf.entry() { + let _ = self.pool.dealloc(a); + } + if let Some(a) = inf.completion() { + let _ = self.pool.dealloc(a); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::virtq::ring::tests::make_ring; + use crate::virtq::test_utils::*; + + #[test] + fn test_chain_readwrite_build() { + let ring = make_ring(16); + let (producer, _consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().entry(64).completion(128).build().unwrap(); + assert_eq!(se.capacity(), 64); + assert_eq!(se.written(), 0); + assert_eq!(se.remaining(), 64); + } + + #[test] + fn test_chain_entry_only_build() { + let ring = make_ring(16); + let (producer, _consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().entry(32).build().unwrap(); + assert_eq!(se.capacity(), 32); + } + + #[test] + fn test_chain_completion_only_build() { + let ring = make_ring(16); + let (producer, _consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().completion(64).build().unwrap(); + assert_eq!(se.capacity(), 0); + } + + #[test] + fn test_chain_empty_build_fails() { + let ring = make_ring(16); + let (producer, _consumer, _notifier) = make_test_producer(&ring); + + let result = producer.chain().build(); + assert!(matches!(result, Err(VirtqError::InvalidState))); + } + + #[test] + fn test_send_entry_write_all_and_submit() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) 
= make_test_producer(&ring); + + let mut se = producer.chain().entry(64).completion(128).build().unwrap(); + + se.write_all(b"hello").unwrap(); + se.write_all(b" world").unwrap(); + assert_eq!(se.written(), 11); + assert_eq!(se.remaining(), 53); + let tok = producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.token(), tok); + assert_eq!(entry.data().as_ref(), b"hello world"); + consumer.complete(completion).unwrap(); + } + + #[test] + fn test_send_entry_buf_mut() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(64).completion(128).build().unwrap(); + let buf = se.buf_mut().unwrap(); + assert_eq!(buf.len(), 64); + buf[..5].copy_from_slice(b"hello"); + se.set_written(5).unwrap(); + let _tok = producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.data().as_ref(), b"hello"); + consumer.complete(completion).unwrap(); + } + + #[test] + fn test_send_entry_write_too_large() { + let ring = make_ring(16); + let (producer, _consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(4).build().unwrap(); + let err = se.write_all(b"too long").unwrap_err(); + assert!(matches!(err, VirtqError::EntryTooLarge)); + } + + #[test] + fn test_writeonly_has_no_entry_buffer() { + let ring = make_ring(16); + let (producer, _consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().completion(32).build().unwrap(); + let err = se.write_all(b"data").unwrap_err(); + assert!(matches!(err, VirtqError::NoReadableBuffer)); + } + + #[test] + fn test_drop_chain_builder_deallocs() { + let ring = make_ring(16); + let (mut producer, _consumer, _notifier) = make_test_producer(&ring); + + { + let _builder = producer.chain().entry(64).completion(128); + // dropped without build + } + + // Ring should still be fully usable + let 
se = producer.chain().entry(64).completion(128).build().unwrap(); + let tok = producer.submit(se).unwrap(); + assert!(tok.0 < 16); + } + + #[test] + fn test_drop_send_entry_deallocs() { + let ring = make_ring(16); + let (mut producer, _consumer, _notifier) = make_test_producer(&ring); + + { + let _se = producer.chain().entry(64).completion(128).build().unwrap(); + // dropped without submit + } + + // Ring should still be fully usable + let se = producer.chain().entry(64).completion(128).build().unwrap(); + let tok = producer.submit(se).unwrap(); + assert!(tok.0 < 16); + } + + #[test] + fn test_submit_notifies() { + let ring = make_ring(16); + let (mut producer, _consumer, notifier) = make_test_producer(&ring); + + let initial_count = notifier.notification_count(); + + let mut se = producer.chain().entry(64).completion(128).build().unwrap(); + se.write_all(b"hello").unwrap(); + producer.submit(se).unwrap(); + + assert!(notifier.notification_count() > initial_count); + } + + #[test] + fn test_set_written_too_large() { + let ring = make_ring(16); + let (producer, _consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(32).completion(64).build().unwrap(); + let err = se.set_written(64).unwrap_err(); + assert!(matches!(err, VirtqError::EntryTooLarge)); + } + + #[test] + fn test_write_only_round_trip() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let se = producer.chain().completion(32).build().unwrap(); + let token = producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.token(), token); + assert!(entry.data().is_empty()); + + if let SendCompletion::Writable(mut wc) = completion { + wc.write_all(b"filled-by-consumer").unwrap(); + consumer.complete(wc.into()).unwrap(); + } else { + panic!("expected Writable"); + } + + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, token); + 
assert_eq!(cqe.data.len(), b"filled-by-consumer".len()); + assert_eq!(&cqe.data[..], b"filled-by-consumer"); + } + + #[test] + fn test_read_only_round_trip() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(32).build().unwrap(); + se.write_all(b"fire-and-forget").unwrap(); + let token = producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.token(), token); + assert_eq!(entry.data().as_ref(), b"fire-and-forget"); + assert!(matches!(completion, SendCompletion::Ack(_))); + consumer.complete(completion).unwrap(); + + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, token); + assert_eq!(cqe.data.len(), 0); + assert!(cqe.data.is_empty()); + } + + #[test] + fn test_readwrite_round_trip() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + let mut se = producer.chain().entry(64).completion(128).build().unwrap(); + se.write_all(b"request data").unwrap(); + let token = producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.data().as_ref(), b"request data"); + if let SendCompletion::Writable(mut wc) = completion { + wc.write_all(b"response data").unwrap(); + consumer.complete(wc.into()).unwrap(); + } else { + panic!("expected Writable"); + } + + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, token); + assert_eq!(&cqe.data[..], b"response data"); + } +} diff --git a/src/hyperlight_common/src/virtq/ring.rs b/src/hyperlight_common/src/virtq/ring.rs new file mode 100644 index 000000000..c130cdcad --- /dev/null +++ b/src/hyperlight_common/src/virtq/ring.rs @@ -0,0 +1,3169 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + */ + +//! Packed Virtqueue Ring Implementation +//! +//! This module implements the packed virtqueue format from the VIRTIO specification. +//! Packed virtqueues use a single descriptor ring where descriptors cycle through +//! available and used states, providing better cache locality and simpler memory +//! layout compared to split virtqueues. +//! +//! # Descriptor State Machine +//! +//! Each descriptor transitions through states using AVAIL and USED flags: +//! +//! ```text +//! Driver publishes +//! ┌─────────┐ (AVAIL=wrap) ┌───────────┐ +//! │ Free │ ──────────────────> │ Available │ +//! └─────────┘ └───────────┘ +//! ^ │ +//! │ │ Device consumes +//! │ Driver reclaims │ and marks used +//! │ (polls USED=wrap) │ (USED=wrap) +//! │ v +//! ┌─────────┐ ┌───────────┐ +//! │Reclaimed│ <────────────────── │ Used │ +//! └─────────┘ └───────────┘ +//! ``` +//! +//! # Wrap Counter +//! +//! The wrap counter solves ring wraparound ambiguity. When cursors wrap around +//! the ring, the wrap counter toggles, changing how AVAIL/USED flags are interpreted: +//! +//! - **wrap=true**: AVAIL=1, USED=0 means "available"; AVAIL=1, USED=1 means "used" +//! - **wrap=false**: AVAIL=0, USED=1 means "available"; AVAIL=0, USED=0 means "used" +//! +//! # Buffer Chains +//! +//! Multiple buffers can be chained using the NEXT flag. All descriptors in a chain +//! share the same ID, and only the head descriptor's AVAIL/USED flags matter for +//! state transitions: +//! +//! ```text +//! Chain with 3 buffers (ID=5): +//! ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ +//! 
│ Desc[0] │ │ Desc[1] │ │ Desc[2] │ +//! │ id=42 │───>│ id=42 │───>│ id=42 │ +//! │ flags=NEXT │ │ flags=NEXT │ │ flags=0 │ +//! │ AVAIL/USED │ │ (ignored) │ │ (ignored) │ +//! └──────────────┘ └──────────────┘ └──────────────┘ +//! HEAD MIDDLE TAIL +//! ``` +//! +//! # Event Suppression +//! +//! Both sides can control when they want to be notified: +//! +//! - **ENABLE**: Always notify (default) +//! - **DISABLE**: Never notify (for polling mode) +//! - **DESC**: Notify only when a specific descriptor index is reached +//! ``` + +use core::marker::PhantomData; +use core::sync::atomic::{Ordering, fence}; + +use bytemuck::Zeroable; +use smallvec::SmallVec; +use thiserror::Error; + +use super::desc::{DescFlags, DescTable, Descriptor}; +use super::event::{EventFlags, EventSuppression}; +use super::{Layout, MemOps}; + +/// A single buffer element in a scatter-gather list. +/// +/// Represents one contiguous memory region that the device will read from +/// or write to. Multiple elements can be chained together to form a +/// [`BufferChain`]. +#[derive(Debug, Copy, Clone, Zeroable)] +pub struct BufferElement { + /// Physical address of buffer + pub addr: u64, + /// Length of the buffer in bytes + pub len: u32, + /// Is this buffer writable + pub writable: bool, +} + +/// A buffer returned from the ring after being used by the device. +/// +/// When the device completes processing a buffer chain, it returns this +/// structure containing the original descriptor ID and the number of bytes +/// written (for chains with writable buffers). +#[derive(Debug, Copy, Clone)] +pub struct UsedBuffer { + /// Descriptor ID that was assigned when the buffer was submitted + pub id: u16, + /// Number of bytes written by the device to writable buffers. + /// For read-only chains, this may be 0 or the total readable length. + pub len: u32, +} + +/// Result of submitting a buffer to the ring. 
+/// +/// Contains the assigned descriptor ID and whether the other side +/// needs to be notified about the new buffer. +#[derive(Debug, Copy, Clone)] +pub struct SubmitResult { + /// Descriptor ID assigned to the submitted buffer chain + /// Use this ID to correlate completions with submissions. + pub id: u16, + /// Whether the device should be notified immediately based on the other + /// side's event suppression settings. + pub notify: bool, +} + +#[derive(Error, Debug)] +pub enum RingError { + #[error("Buffer chain is empty")] + EmptyChain, + #[error("Buffer chain is malformed")] + BadChain, + #[error("Operation would block")] + WouldBlock, + #[error("Out of memory")] + OutOfMemory, + #[error("Invalid state")] + InvalidState, + #[error("Invalid memory layout")] + InvalidLayout, + #[error("Backend memory error")] + MemError, +} + +/// Type-state: Can add readable buffers +pub struct Readable; + +/// Type-state: Can add writable buffers (no more readables allowed) +pub struct Writable; + +/// A builder for buffer chains using type-state to enforce readable/writable order. +/// +/// Upholds invariants: at least one buffer must be present in the chain, +/// and readable buffers must be added before writable buffers. +#[derive(Debug, Default)] +pub struct BufferChainBuilder { + elems: SmallVec<[BufferElement; 16]>, + split: usize, + marker: PhantomData, +} + +impl BufferChainBuilder { + /// Create a new builder in the [`Readable`] state. + pub fn new() -> Self { + Self { + elems: Default::default(), + split: 0, + marker: PhantomData, + } + } + + /// Add a readable buffer (device reads from this). + pub fn readable(mut self, addr: u64, len: u32) -> Self { + self.elems.push(BufferElement { + addr, + len, + writable: false, + }); + self.split += 1; + self + } + + /// Add multiple readable buffers from an iterator. 
+ pub fn readables( + mut self, + elements: impl IntoIterator>, + ) -> Self { + for elem in elements { + self.elems.push(elem.into()); + self.split += 1; + } + + self + } + + /// Add a writable buffer (device writes to this). + /// + /// This transitions to Writable state so no more readable buffers can be added. + pub fn writable(mut self, addr: u64, len: u32) -> BufferChainBuilder { + self.elems.push(BufferElement { + addr, + len, + writable: true, + }); + + BufferChainBuilder { + elems: self.elems, + split: self.split, + marker: PhantomData, + } + } + + /// Add multiple readable buffers from an iterator. + /// + /// This transitions to Writable state so no more readable buffers can be added. + pub fn writables( + mut self, + elements: impl IntoIterator>, + ) -> BufferChainBuilder { + for elem in elements { + self.elems.push(elem.into()); + } + + BufferChainBuilder { + elems: self.elems, + split: self.split, + marker: PhantomData, + } + } + + /// Build a buffer chain with only readable buffers. + /// + /// Chain must have at least one buffer otherwise an error is returned. + pub fn build(self) -> Result { + if self.elems.is_empty() { + return Err(RingError::EmptyChain); + } + + Ok(BufferChain { + elems: self.elems, + split: self.split, + }) + } +} + +impl BufferChainBuilder { + /// Add writable buffer + pub fn writable(mut self, addr: u64, len: u32) -> Self { + self.elems.push(BufferElement { + addr, + len, + writable: true, + }); + self + } + + /// Add multiple readable buffers from an iterator. + pub fn writables( + mut self, + elements: impl IntoIterator>, + ) -> Self { + for elem in elements { + self.elems.push(elem.into()); + } + self + } + + /// Build the buffer chain. + /// + /// Chain must have at least one buffer otherwise an error is returned. 
+ pub fn build(self) -> Result { + if self.elems.is_empty() { + return Err(RingError::EmptyChain); + } + + Ok(BufferChain { + elems: self.elems, + split: self.split, + }) + } +} + +/// A chain of buffers ready for submission to the virtqueue. +/// +/// Contains a scatter-gather list of [`BufferElement`]s, divided into +/// readable (driver->device) and writable (device->driver) sections. +#[derive(Debug, Default, Clone)] +pub struct BufferChain { + /// All buffer elements (readable followed by writable) + elems: SmallVec<[BufferElement; 16]>, + /// Split index between readable and writable buffers + split: usize, +} + +impl BufferChain { + /// Get all buffer elements in the chain. + pub fn elems(&self) -> &[BufferElement] { + self.elems.as_slice() + } + + /// Get writable buffers in chain + pub fn readables(&self) -> &[BufferElement] { + &self.elems[..self.split] + } + + /// Get writable buffers in chain + pub fn writables(&self) -> &[BufferElement] { + &self.elems[self.split..] + } + + /// Get total number of buffers in chain + // Note: buffer chain cannot be empty by construction + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.elems.len() + } +} + +/// Tracks position in a ring buffer with wrap-around handling. +/// +/// The cursor maintains both an index into the ring and a wrap counter +/// that toggles each time the index wraps around. 
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct RingCursor {
    // Current index into the descriptor ring.
    head: u16,
    // Total ring size in descriptors (packed queues use u16-sized rings).
    size: u16,
    // Wrap counter: toggles every time `head` wraps back to 0. Used to
    // disambiguate "same index, different lap" when interpreting AVAIL/USED flags.
    wrap: bool,
}

impl RingCursor {
    // Starts at index 0 with wrap = true, the initial wrap value for a
    // freshly initialized packed queue.
    // NOTE(review): `size as u16` silently truncates if size > u16::MAX —
    // presumably ring sizes are validated upstream; confirm at call sites.
    pub(crate) fn new(size: usize) -> Self {
        Self {
            head: 0,
            size: size as u16,
            wrap: true,
        }
    }

    /// Advance to next position, wrapping around and toggling wrap counter if needed
    #[inline]
    pub(crate) fn advance(&mut self) {
        self.head += 1;
        if self.head >= self.size {
            self.head = 0;
            self.wrap = !self.wrap;
        }
    }

    /// Advance by n positions
    // Simple loop: `n` is bounded by the ring size (a u16), so this stays cheap
    // for the reclaim batches it is used for.
    #[inline]
    pub(crate) fn advance_by(&mut self, n: u16) {
        for _ in 0..n {
            self.advance();
        }
    }

    /// Get current head index
    #[inline]
    pub fn head(&self) -> u16 {
        self.head
    }

    /// Get current wrap counter
    #[inline]
    pub fn wrap(&self) -> bool {
        self.wrap
    }
}

/// Producer (driver) side of a packed virtqueue.
///
/// The producer submits buffer chains for the device to process and polls
/// for completions. This is typically used by the driver/guest side.
///
/// # Lifecycle
///
/// 1. Submit: Call [`submit_available`](Self::submit_available) or
///    [`submit_one`](Self::submit_one) to make buffers available to device
/// 2. Notify: If `SubmitResult::notify` is true, signal the device
/// 3. Poll: Call [`poll_used`](Self::poll_used) to check for completions
/// 4.
Process: Handle completed buffers and reuse descriptor IDs +#[derive(Debug)] +pub struct RingProducer { + /// Memory accessor + mem: M, + /// Next available descriptor position + avail_cursor: RingCursor, + /// Next used descriptor position + used_cursor: RingCursor, + /// Free slots in the ring + num_free: usize, + /// Descriptor table in shared memory + desc_table: DescTable, + /// Shadow of driver event flags (last written value) + event_flags_shadow: EventFlags, + // controls when device notifies about used buffers + drv_evt_addr: u64, + // reads device event to check if device wants notification + dev_evt_addr: u64, + /// stack of free IDs, allows out-of-order completion + id_free: SmallVec<[u16; DescTable::DEFAULT_LEN]>, + // chain length per ID, index = ID, + id_num: SmallVec<[u16; DescTable::DEFAULT_LEN]>, +} + +impl RingProducer { + /// Create a new producer from a memory layout and accessor. + pub fn new(layout: Layout, mem: M) -> Self { + let size = layout.desc_table_len as usize; + let raw = layout.desc_table_addr; + + // SAFETY: layout is valid + let table = unsafe { DescTable::from_raw_parts(raw, size) }; + let cursor = RingCursor::new(size); + + const DEFAULT_LEN: usize = DescTable::default_len(); + let id_free = (0..size as u16).collect::>(); + let id_num = SmallVec::<[_; DEFAULT_LEN]>::from_elem(0, size); + + // Notification enabled by default + let event_flags_shadow = EventFlags::ENABLE; + + Self { + mem, + avail_cursor: cursor, + used_cursor: cursor, + num_free: size, + desc_table: table, + id_free, + id_num, + event_flags_shadow, + drv_evt_addr: layout.drv_evt_addr, + dev_evt_addr: layout.dev_evt_addr, + } + } + + /// Fast path: submit exactly one descriptor + /// + /// This is more efficient than [`submit_available`](Self::submit_available) + /// for single-buffer submissions as it avoids chain iteration overhead. 
+ /// + /// # Arguments + /// + /// * `addr` - physical address of the buffer + /// * `len` - Length of the buffer in bytes + /// * `writable` - If true, device writes to buffer; if false, device reads + /// + /// # Returns + /// + /// The descriptor ID assigned to this buffer, for matching with completions. + /// + /// # Errors + /// + /// - [`RingError::WouldBlock`] - No free descriptor slots + /// - [`RingError::OutOfMemory`] - No free descriptor IDs (internal error) + /// - [`RingError::InvalidState`] - ID tracking corrupted (internal error) + pub fn submit_one(&mut self, addr: u64, len: u32, writable: bool) -> Result { + if self.num_free < 1 { + return Err(RingError::WouldBlock); + } + + // Allocate ID and record chain length + let id = self.id_free.pop().ok_or(RingError::OutOfMemory)?; + + // We should never reuse an ID that is still outstanding + if self.id_num[id as usize] != 0 { + return Err(RingError::InvalidState); + } + + // Record chain length for single descriptor + self.id_num[id as usize] = 1; + + // Build and publish the head descriptor + let head_idx = self.avail_cursor.head(); + let head_wrap = self.avail_cursor.wrap(); + + let mut flags = DescFlags::empty(); + flags.set(DescFlags::WRITE, writable); + let mut desc = Descriptor::new(addr, len, id, flags); + desc.mark_avail(head_wrap); + + let addr = self + .desc_table + .desc_addr(head_idx) + .ok_or(RingError::InvalidState)?; + + // Release publish + desc.write_release(&self.mem, addr) + .map_err(|_| RingError::MemError)?; + + // Advance state + self.avail_cursor.advance(); + self.num_free -= 1; + + Ok(id) + } + + /// Submit a buffer chain to the ring, returning whether to notify the device. 
+ pub fn submit_available_with_notify( + &mut self, + chain: &BufferChain, + ) -> Result { + let old = self.avail_cursor; + let id = self.submit_available(chain)?; + let new = self.avail_cursor; + let notify = self.should_notify_device(old, new)?; + + Ok(SubmitResult { id, notify }) + } + + /// Submit a single-buffer descriptor with notification check. + pub fn submit_one_with_notify( + &mut self, + addr: u64, + len: u32, + writable: bool, + ) -> Result { + let old = self.avail_cursor; + let id = self.submit_one(addr, len, writable)?; + let new = self.avail_cursor; + let notify = self.should_notify_device(old, new)?; + Ok(SubmitResult { id, notify }) + } + + /// Submit a buffer chain to the ring. + /// + /// Writes all descriptors in the chain to the ring, linking them with + /// NEXT flags. The head descriptor is written last with release semantics + /// to ensure atomicity of the chain. + /// + /// # Arguments + /// + /// * `chain` - The buffer chain to submit + /// + /// # Returns + /// + /// The descriptor ID assigned to this chain. All descriptors in the chain + /// share this ID for correlation during completion. 
+    ///
+    /// # Errors
+    ///
+    /// - [`RingError::EmptyChain`] - Chain has no buffers
+    /// - [`RingError::WouldBlock`] - Not enough free descriptor slots
+    pub fn submit_available(&mut self, chain: &BufferChain) -> Result<u16, RingError> {
+        let total_descs = chain.len();
+        if total_descs == 0 {
+            return Err(RingError::EmptyChain);
+        }
+
+        if self.num_free < total_descs {
+            return Err(RingError::WouldBlock);
+        }
+
+        if total_descs == 1 {
+            let elem = chain.elems()[0];
+            return self.submit_one(elem.addr, elem.len, elem.writable);
+        }
+
+        let head_idx = self.avail_cursor.head();
+        let head_wrap = self.avail_cursor.wrap();
+
+        let id = self.id_free.pop().ok_or(RingError::InvalidState)?;
+
+        // We should never reuse an ID that is still outstanding
+        if self.id_num[id as usize] != 0 {
+            return Err(RingError::InvalidState);
+        }
+
+        // Record chain length
+        self.id_num[id as usize] = total_descs as u16;
+
+        // Write tail elements first; head last.
+        let mut pos = self.avail_cursor;
+        pos.advance();
+
+        for (i, elem) in chain.elems().iter().enumerate().skip(1) {
+            let is_next = i + 1 < total_descs;
+            let mut flags = DescFlags::empty();
+
+            flags.set(DescFlags::NEXT, is_next);
+            flags.set(DescFlags::WRITE, elem.writable);
+
+            let mut desc = Descriptor::new(elem.addr, elem.len, id, flags);
+            desc.mark_avail(pos.wrap());
+
+            let addr = self
+                .desc_table
+                .desc_addr(pos.head())
+                .ok_or(RingError::InvalidState)?;
+
+            self.mem
+                .write_val(addr, desc)
+                .map_err(|_| RingError::MemError)?;
+            pos.advance();
+        }
+
+        // Head descriptor
+        let head_elem = chain.elems()[0];
+        // Head flags: NEXT when tails follow, WRITE per the element itself
+        let mut head_flags = DescFlags::empty();
+        head_flags.set(DescFlags::NEXT, total_descs > 1);
+        head_flags.set(DescFlags::WRITE, head_elem.writable);
+
+        let mut head_desc = Descriptor::new(head_elem.addr, head_elem.len, id, head_flags);
+        head_desc.mark_avail(head_wrap);
+
+        let head_addr = self
+            .desc_table
+            .desc_addr(head_idx)
+            .ok_or(RingError::InvalidState)?;
+
+        // Release publish
+        head_desc
+            .write_release(&self.mem, head_addr)
+            .map_err(|_| RingError::MemError)?;
+
+        self.num_free -= total_descs;
+        self.avail_cursor = pos;
+
+        Ok(id)
+    }
+
+    /// Poll the ring for a used buffer.
+    ///
+    /// Checks if the device has marked any buffers as used. If so, returns
+    /// the completion information and reclaims the descriptor(s).
+    ///
+    /// # Returns
+    ///
+    /// - `Ok(UsedBuffer)` - A buffer chain was completed
+    /// - `Err(RingError::WouldBlock)` - No completions available
+    pub fn poll_used(&mut self) -> Result<UsedBuffer, RingError> {
+        let idx = self.used_cursor.head();
+        let wrap = self.used_cursor.wrap();
+
+        // Read the descriptor at next_used position with ordering
+        let addr = self
+            .desc_table
+            .desc_addr(idx)
+            .ok_or(RingError::InvalidState)?;
+
+        // Acquire flags then fields (publish point)
+        let desc = Descriptor::read_acquire(&self.mem, addr).map_err(|_| RingError::MemError)?;
+        if !desc.is_used(wrap) {
+            return Err(RingError::WouldBlock);
+        }
+
+        let id = desc.id;
+        let count = *self
+            .id_num
+            .get(id as usize)
+            .ok_or(RingError::InvalidState)?;
+
+        if count == 0 {
+            return Err(RingError::InvalidState);
+        }
+
+        // Advance used cursor by number of reclaimed descriptors
+        self.used_cursor.advance_by(count);
+        // Update number of free descriptors
+        self.num_free += count as usize;
+        // id is in range: bounds-checked by the `id_num.get` lookup above
+        self.id_num[id as usize] = 0;
+        // Return ID to free stack
+        self.id_free.push(id);
+
+        Ok(UsedBuffer { id, len: desc.len })
+    }
+
+    /// Get number of free descriptors in the ring.
+    pub fn num_free(&self) -> usize {
+        self.num_free
+    }
+
+    /// Get number of inflight (submitted but not yet used) descriptors.
+    pub fn num_inflight(&self) -> usize {
+        self.desc_table.len() - self.num_free
+    }
+
+    /// Check if the ring is full (no free descriptors).
+ pub fn is_full(&self) -> bool { + self.num_free == 0 + } + + /// Get descriptor table length + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.desc_table.len() + } + + /// Get memory accessor reference + pub fn mem(&self) -> &M { + &self.mem + } + + /// Get a snapshot of the current available cursor position. + /// + /// Used for batch operations to track the cursor before submitting + /// multiple chains, enabling proper event suppression checks. + #[inline] + pub fn avail_cursor(&self) -> RingCursor { + self.avail_cursor + } + + /// Get a snapshot of the current used cursor position. + /// + /// Used for setting up DESC mode event suppression at specific positions. + #[inline] + pub fn used_cursor(&self) -> RingCursor { + self.used_cursor + } + + /// Check if device should be notified given a cursor snapshot from before batch start. + /// + /// This is used for batching: record cursor before first submit, then after all + /// submits call this to determine if notification is needed based on event suppression. + /// + /// # Arguments + /// * `old` - Cursor position snapshot taken before batch started + pub fn should_notify_since(&self, old: RingCursor) -> Result { + self.should_notify_device(old, self.avail_cursor) + } + + /// Driver disables used-buffer notifications from device to driver. + pub fn disable_used_notifications(&mut self) -> Result<(), RingError> { + // Avoid redundant MMIO writes if already disabled + if self.event_flags_shadow == EventFlags::DISABLE { + return Ok(()); + } + + let mut evt = self + .mem + .read_val::(self.drv_evt_addr) + .map_err(|_| RingError::MemError)?; + + evt.set_flags(EventFlags::DISABLE); + + evt.write_release(&self.mem, self.drv_evt_addr) + .map_err(|_| RingError::MemError)?; + self.event_flags_shadow = EventFlags::DISABLE; + Ok(()) + } + + /// Driver enables used-buffer notifications from device to driver. 
+ pub fn enable_used_notifications(&mut self) -> Result<(), RingError> { + if self.event_flags_shadow == EventFlags::ENABLE { + return Ok(()); + } + + let mut evt = self + .mem + .read_val::(self.drv_evt_addr) + .map_err(|_| RingError::MemError)?; + + evt.set_flags(EventFlags::ENABLE); + evt.write_release(&self.mem, self.drv_evt_addr) + .map_err(|_| RingError::MemError)?; + + self.event_flags_shadow = EventFlags::ENABLE; + Ok(()) + } + + /// Driver enables descriptor-specific used notifications (EVENT_IDX / DESC mode). + /// + /// This tells the device: "Interrupt me when you reach used index (off, wrap)". + /// + /// This enables batching on the device side - it can complete multiple requests + /// before triggering an interrupt. + pub fn enable_used_notifications_desc( + &mut self, + off: u16, + wrap: bool, + ) -> Result<(), RingError> { + let mut evt = self + .mem + .read_val::(self.drv_evt_addr) + .map_err(|_| RingError::MemError)?; + + evt.set_desc_event(off, wrap); + evt.set_flags(EventFlags::DESC); + + // Now publish flags = DESC with Release semantics. + evt.write_release(&self.mem, self.drv_evt_addr) + .map_err(|_| RingError::MemError)?; + // cache shadow + self.event_flags_shadow = EventFlags::DESC; + Ok(()) + } + + /// Convenience: enable DESC mode for "next used cursor" like Linux enable_cb_prepare. + pub fn enable_used_notifications_for_next(&mut self) -> Result<(), RingError> { + let off = self.used_cursor.head(); + let wrap = self.used_cursor.wrap(); + + self.enable_used_notifications_desc(off, wrap) + } + + /// Check whether the device should be notified about new available descriptors. + fn should_notify_device(&self, old: RingCursor, new: RingCursor) -> Result { + // VIRTIO 1.1 "The driver MUST perform a suitable memory barrier before + // reading the Device Event Suppression structure". 
+ // + // After publishing descriptors with store-release on the AVAIL/USED flags, + // we need a full barrier before reading event suppression, because + // release+acquire across different memory locations does NOT provide + // Store/Load ordering on weakly-ordered architectures e.g. aarch64. + // + // Linux kernel uses virtio_mb() full barrier in virtqueue_kick_prepare_packed. + fence(Ordering::SeqCst); + + let evt = EventSuppression::read_acquire(&self.mem, self.dev_evt_addr) + .map_err(|_| RingError::MemError)?; + + Ok(should_notify(evt, self.len() as u16, old, new)) + } +} + +/// Consumer (device) side of a packed virtqueue. +/// +/// The consumer polls for available buffer chains submitted by the driver, +/// processes them, and marks them as used. This is typically used by the +/// device/host side. +/// +/// # Lifecycle +/// +/// 1. **Poll**: Call [`poll_available`](Self::poll_available) to get buffers +/// 2. **Process**: Read from readable buffers, write to writable buffers +/// 3. **Complete**: Call [`submit_used`](Self::submit_used) to return buffers +/// 4. **Notify**: If `submit_used_with_notify` returns true, signal the driver +#[derive(Debug)] +pub struct RingConsumer { + /// Memory accessor + mem: M, + /// Cursor for reading available (driver-published) descriptors + avail_cursor: RingCursor, + /// Cursor for writing used descriptors + used_cursor: RingCursor, + /// Shared descriptor table + desc_table: DescTable, + /// Per-ID chain length learned when polling (index = ID) + id_num: SmallVec<[u16; DescTable::DEFAULT_LEN]>, + /// Number of descriptors consumed from avail stream but not yet posted as used. 
+    num_inflight: usize,
+    /// Shadow of device event flags (last written value)
+    event_flags_shadow: EventFlags,
+    /// Driver Event Suppression address: device reads it to decide whether to notify the driver
+    drv_evt_addr: u64,
+    /// Device Event Suppression address: device writes it to control when the driver kicks it
+    dev_evt_addr: u64,
+}
+
+impl<M: MemOps> RingConsumer<M> {
+    pub fn new(layout: Layout, mem: M) -> Self {
+        let size = layout.desc_table_len as usize;
+        let raw = layout.desc_table_addr;
+
+        // SAFETY: layout is valid
+        let table = unsafe { DescTable::from_raw_parts(raw, size) };
+        let cursor = RingCursor::new(size);
+        let id_chain_len = SmallVec::<[u16; DescTable::DEFAULT_LEN]>::from_elem(0, size);
+
+        // Notification enabled by default
+        let event_flags_shadow = EventFlags::ENABLE;
+
+        Self {
+            mem,
+            avail_cursor: cursor,
+            used_cursor: cursor,
+            desc_table: table,
+            id_num: id_chain_len,
+            num_inflight: 0,
+            event_flags_shadow,
+            drv_evt_addr: layout.drv_evt_addr,
+            dev_evt_addr: layout.dev_evt_addr,
+        }
+    }
+
+    /// Poll for an available buffer chain.
+    ///
+    /// Returns the chain ID and a [`BufferChain`] containing all buffers.
+    /// The chain ID must be passed to [`submit_used`](Self::submit_used)
+    /// when processing is complete.
+    ///
+    /// # Returns
+    ///
+    /// - `Ok((id, chain))` - A buffer chain is available
+    /// - `Err(RingError::WouldBlock)` - No buffers available
+    /// - `Err(RingError::BadChain)` - Malformed chain (driver bug)
+    pub fn poll_available(&mut self) -> Result<(u16, BufferChain), RingError> {
+        let idx = self.avail_cursor.head();
+        let wrap = self.avail_cursor.wrap();
+
+        let head_addr = self
+            .desc_table
+            .desc_addr(idx)
+            .ok_or(RingError::InvalidState)?;
+
+        // Acquire: flags then fields (publish point)
+        let head_desc =
+            Descriptor::read_acquire(&self.mem, head_addr).map_err(|_| RingError::MemError)?;
+
+        // Check if head descriptor is available to consume
+        if !head_desc.is_avail(wrap) {
+            return Err(RingError::WouldBlock);
+        }
+
+        // Build chain (head + tails).
+        let mut elements = SmallVec::<[BufferElement; 16]>::new();
+        let mut pos = self.avail_cursor;
+        let mut chain_len: u16 = 1;
+
+        let mut steps = 1;
+        let mut has_next = head_desc.is_next();
+
+        let max_steps = self.desc_table.len();
+
+        elements.push(BufferElement::from(&head_desc));
+        pos.advance();
+
+        while has_next && steps < max_steps {
+            let addr = self
+                .desc_table
+                .desc_addr(pos.head())
+                .ok_or(RingError::InvalidState)?;
+
+            // Tail reads need no acquire ordering: the head has already been validated
+            let desc = self.mem.read_val(addr).map_err(|_| RingError::MemError)?;
+            elements.push(BufferElement::from(&desc));
+
+            chain_len += 1;
+            steps += 1;
+
+            has_next = desc.is_next();
+            pos.advance();
+        }
+
+        // Detect malformed chains: we reached max_steps but NEXT is still set.
+        if steps >= max_steps && has_next {
+            return Err(RingError::BadChain);
+        }
+
+        // Verify that readable/writable split is correct
+        let readables = chain_readable_count(&elements)?;
+
+        // Since driver wrote the same id everywhere, head_desc.id is valid.
+        let id = head_desc.id;
+        if (id as usize) >= self.id_num.len() {
+            return Err(RingError::InvalidState);
+        }
+
+        // Record chain length for later used submission
+        self.id_num[id as usize] = chain_len;
+        // Advance avail cursor to first slot after chain
+        self.avail_cursor = pos;
+        // Update inflight count
+        self.num_inflight += chain_len as usize;
+
+        assert!(self.num_inflight <= self.desc_table.len());
+
+        Ok((
+            id,
+            BufferChain {
+                elems: elements,
+                split: readables,
+            },
+        ))
+    }
+
+    /// Publish a single used descriptor for the chain identified by id.
+    /// written_len is the total bytes produced by the device (for the writable part).
+ /// + /// # Arguments + /// + /// * `id` - The chain ID from `poll_available` + /// * `written_len` - Total bytes written to writable buffers + /// + /// # Errors + /// + /// - [`RingError::InvalidState`] - Unknown ID or already completed + pub fn submit_used(&mut self, id: u16, written_len: u32) -> Result<(), RingError> { + // Lookup chain length + let chain_len = *self + .id_num + .get(id as usize) + .ok_or(RingError::InvalidState)?; + + if chain_len == 0 { + return Err(RingError::InvalidState); + } + + let idx = self.used_cursor.head(); + let wrap = self.used_cursor.wrap(); + + // addr is unused for used descriptor according to packed-virtqueue spec + let mut used_desc = Descriptor::new(0, 0, id, DescFlags::empty()); + used_desc.len = written_len; + used_desc.mark_used(wrap); + + let addr = self + .desc_table + .desc_addr(idx) + .ok_or(RingError::InvalidState)?; + + // Release publish (flags written last inside write_release) + used_desc + .write_release(&self.mem, addr) + .map_err(|_| RingError::MemError)?; + + // Advance used cursor by whole chain length + self.used_cursor.advance_by(chain_len); + self.id_num[id as usize] = 0; + + self.num_inflight -= chain_len as usize; + assert!(self.num_inflight <= self.desc_table.len()); + + Ok(()) + } + + /// Try to peek whether the next chain is available without consuming it. + pub fn peek_available(&self) -> Result { + let Some(addr) = self.desc_table.desc_addr(self.avail_cursor.head()) else { + return Err(RingError::InvalidState); + }; + + let desc = Descriptor::read_acquire(&self.mem, addr).map_err(|_| RingError::MemError)?; + Ok(desc.is_avail(self.avail_cursor.wrap())) + } + + /// Submit a used descriptor and return whether to notify the driver. 
+ pub fn submit_used_with_notify( + &mut self, + id: u16, + written_len: u32, + ) -> Result { + let old = self.used_cursor; + self.submit_used(id, written_len)?; + let new = self.used_cursor; + self.should_notify_driver(old, new) + } + + /// Get number of free descriptors in the ring. + pub fn num_free(&self) -> usize { + self.desc_table.len() - self.num_inflight + } + + /// Get number of inflight (submitted but not yet used) descriptors. + pub fn num_inflight(&self) -> usize { + self.num_inflight + } + + /// Check if the ring is full (no free descriptors). + pub fn is_full(&self) -> bool { + self.num_inflight == self.desc_table.len() + } + + /// Get descriptor table length + #[allow(clippy::len_without_is_empty)] + pub fn len(&self) -> usize { + self.desc_table.len() + } + + /// Get memory accessor reference + pub fn mem(&self) -> &M { + &self.mem + } + + /// Get a snapshot of the current avail cursor position. + #[inline] + pub fn avail_cursor(&self) -> RingCursor { + self.avail_cursor + } + + /// Get a snapshot of the current used cursor position. + #[inline] + pub fn used_cursor(&self) -> RingCursor { + self.used_cursor + } + + /// Device disables available-buffer notifications from driver to device. + /// + /// This is the device-side mirror of "disable callbacks" but for avail kicks. + pub fn disable_avail_notifications(&mut self) -> Result<(), RingError> { + if self.event_flags_shadow == EventFlags::DISABLE { + return Ok(()); + } + + let mut evt = self + .mem + .read_val::(self.dev_evt_addr) + .map_err(|_| RingError::MemError)?; + + evt.set_flags(EventFlags::DISABLE); + evt.write_release(&self.mem, self.dev_evt_addr) + .map_err(|_| RingError::MemError)?; + + self.event_flags_shadow = EventFlags::DISABLE; + Ok(()) + } + + /// Device enables available-buffer notifications from driver to device. 
+ pub fn enable_avail_notifications(&mut self) -> Result<(), RingError> { + if self.event_flags_shadow == EventFlags::ENABLE { + return Ok(()); + } + + let mut evt = self + .mem + .read_val::(self.dev_evt_addr) + .map_err(|_| RingError::MemError)?; + + evt.set_flags(EventFlags::ENABLE); + evt.write_release(&self.mem, self.dev_evt_addr) + .map_err(|_| RingError::MemError)?; + + self.event_flags_shadow = EventFlags::ENABLE; + Ok(()) + } + + /// Device enables descriptor-specific available notifications (EVENT_IDX / DESC mode). + /// + /// This tells the driver: "Kick me when you reach avail index (off, wrap)". + pub fn enable_avail_notifications_desc( + &mut self, + off: u16, + wrap: bool, + ) -> Result<(), RingError> { + // Update off_wrap first + let mut evt = self + .mem + .read_val::(self.dev_evt_addr) + .map_err(|_| RingError::MemError)?; + + evt.set_desc_event(off, wrap); + evt.set_flags(EventFlags::DESC); + + // Now publish flags = DESC with Release semantics. + evt.write_release(&self.mem, self.dev_evt_addr) + .map_err(|_| RingError::MemError)?; + + self.event_flags_shadow = EventFlags::DESC; + Ok(()) + } + + /// Convenience: enable DESC mode for "next avail cursor" (device wants a kick when new + /// buffers arrive at the next index it will poll). + pub fn enable_avail_notifications_for_next(&mut self) -> Result<(), RingError> { + let off = self.avail_cursor.head(); + let wrap = self.avail_cursor.wrap(); + self.enable_avail_notifications_desc(off, wrap) + } + + /// Decide whether the device should notify the driver about newly used descriptors. + fn should_notify_driver(&self, old: RingCursor, new: RingCursor) -> Result { + // VIRTIO 1.1: Full memory barrier required before reading the + // Driver Event Suppression structure. 
See also should_notify_device()
+        fence(Ordering::SeqCst);
+
+        let evt = EventSuppression::read_acquire(&self.mem, self.drv_evt_addr)
+            .map_err(|_| RingError::MemError)?;
+
+        Ok(should_notify(evt, self.desc_table.len() as u16, old, new))
+    }
+}
+
+/// Common packed-ring notification decision:
+/// - `old` and `new` are the ring indices (head) before/after publishing a batch
+/// - `new.wrap()` is the wrap counter corresponding to `new.head()`
+/// - `evt.desc_event_wrap()` is compared against `new.wrap()`
+///
+/// This is compatible with Linux `virtqueue_kick_prepare_packed` logic
+#[inline]
+fn should_notify(evt: EventSuppression, ring_len: u16, old: RingCursor, new: RingCursor) -> bool {
+    match evt.flags() {
+        EventFlags::DISABLE => false,
+        EventFlags::ENABLE => true,
+        EventFlags::DESC => {
+            let mut off = evt.desc_event_off();
+            let wrap = evt.desc_event_wrap();
+
+            if wrap != new.wrap() {
+                off = off.wrapping_sub(ring_len);
+            }
+
+            ring_need_event(off, new.head(), old.head())
+        }
+        _ => true, // reserved flags value (peer-written memory): notify conservatively, never panic
+    }
+}
+
+#[inline(always)]
+pub fn ring_need_event(event_idx: u16, new: u16, old: u16) -> bool {
+    new.wrapping_sub(event_idx).wrapping_sub(1) < new.wrapping_sub(old)
+}
+
+#[inline]
+/// Check that a buffer chain is well-formed: all readable buffers first,
+/// then writable and return the count of readable buffers.
+fn chain_readable_count(elems: &[BufferElement]) -> Result { + let mut seen_writable = false; + let mut writables = 0; + + for e in elems { + if e.writable { + seen_writable = true; + writables += 1; + } else if seen_writable { + return Err(RingError::BadChain); + } + } + + Ok(elems.len() - writables) +} + +impl From<&Descriptor> for BufferElement { + fn from(desc: &Descriptor) -> Self { + BufferElement { + addr: desc.addr, + len: desc.len, + writable: desc.is_writeable(), + } + } +} + +#[cfg(test)] +pub(crate) mod tests { + use alloc::sync::Arc; + use core::cell::UnsafeCell; + use core::num::NonZeroU16; + use core::ptr; + use core::sync::atomic::{AtomicU16, Ordering}; + + use bytemuck::{Pod, Zeroable}; + + use super::*; + use crate::virtq::event::EventSuppression; + + /// Test MemOps implementation that maintains pointer provenance. + /// + /// This wraps a Vec and provides memory access using the Vec's + /// base pointer to preserve provenance for Miri. + pub struct TestMem { + /// The backing storage - UnsafeCell for interior mutability + storage: UnsafeCell>, + /// Base address (the address we tell the ring about) + base_addr: u64, + } + + impl TestMem { + pub fn new(size: usize) -> Self { + let storage = vec![0u8; size]; + let base_addr = storage.as_ptr() as u64; + Self { + storage: UnsafeCell::new(storage), + base_addr, + } + } + + /// Get a pointer with proper provenance for the given address + fn ptr_for_addr(&self, addr: u64) -> *mut u8 { + let storage = unsafe { &mut *self.storage.get() }; + let base_ptr = storage.as_mut_ptr(); + let offset = (addr - self.base_addr) as usize; + // Use wrapping_add to maintain provenance from base_ptr + base_ptr.wrapping_add(offset) + } + + pub fn base_addr(&self) -> u64 { + self.base_addr + } + } + + // Safety: TestMem's UnsafeCell is only accessed from test code with no + // real concurrency in unit tests (loom tests use their own LoomMem). + // Required so Arc satisfies Send + Sync for Bytes::from_owner. 
+ unsafe impl Send for TestMem {} + unsafe impl Sync for TestMem {} + + impl MemOps for Arc { + type Error = core::convert::Infallible; + + fn read(&self, addr: u64, dst: &mut [u8]) -> Result { + let src = self.ptr_for_addr(addr); + unsafe { + ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), dst.len()); + } + Ok(dst.len()) + } + + fn write(&self, addr: u64, src: &[u8]) -> Result { + let dst = self.ptr_for_addr(addr); + unsafe { + ptr::copy_nonoverlapping(src.as_ptr(), dst, src.len()); + } + Ok(src.len()) + } + + fn read_val(&self, addr: u64) -> Result { + let ptr = self.ptr_for_addr(addr).cast::(); + Ok(unsafe { ptr::read_volatile(ptr) }) + } + + fn write_val(&self, addr: u64, val: T) -> Result<(), Self::Error> { + let ptr = self.ptr_for_addr(addr).cast::(); + unsafe { ptr::write_volatile(ptr, val) }; + Ok(()) + } + + fn load_acquire(&self, addr: u64) -> Result { + let ptr = self.ptr_for_addr(addr).cast::(); + Ok(unsafe { (*ptr).load(Ordering::Acquire) }) + } + + fn store_release(&self, addr: u64, val: u16) -> Result<(), Self::Error> { + let ptr = self.ptr_for_addr(addr).cast::(); + unsafe { (*ptr).store(val, Ordering::Release) }; + Ok(()) + } + + unsafe fn as_slice(&self, addr: u64, len: usize) -> Result<&[u8], Self::Error> { + let ptr = self.ptr_for_addr(addr); + Ok(unsafe { core::slice::from_raw_parts(ptr, len) }) + } + + unsafe fn as_mut_slice(&self, addr: u64, len: usize) -> Result<&mut [u8], Self::Error> { + let ptr = self.ptr_for_addr(addr); + Ok(unsafe { core::slice::from_raw_parts_mut(ptr, len) }) + } + } + + /// Owns the descriptor table and event suppression structures + pub struct OwnedRing { + mem: Arc, + layout: Layout, + } + + fn align_up(val: usize, align: usize) -> usize { + (val + align - 1) & !(align - 1) + } + + impl OwnedRing { + pub fn new(size: usize) -> Self { + let num_descs = NonZeroU16::new(size as u16).unwrap(); + let needed = Layout::query_size(size); + + // Add padding for alignment, plus extra space for pool buffers + // used by 
high-level API tests (pool offset = ring_end + 0x100, + // pool size = 0x8000). + let padding = Descriptor::ALIGN; + let pool_headroom = 0x100 + 0x8000; + let mem = Arc::new(TestMem::new(needed + padding + pool_headroom)); + + // Align the base address + let aligned_base = align_up(mem.base_addr() as usize, Descriptor::ALIGN) as u64; + let layout = unsafe { Layout::from_base(aligned_base, num_descs).unwrap() }; + + Self { mem, layout } + } + + pub fn layout(&self) -> Layout { + self.layout + } + + pub fn mem(&self) -> Arc { + self.mem.clone() + } + + /// Get address of descriptor at index + pub fn desc_addr(&self, idx: u16) -> u64 { + self.layout.desc_table_addr + (idx as u64 * Descriptor::SIZE as u64) + } + + /// Read descriptor directly (for test verification) + pub fn read_desc(&self, idx: u16) -> Descriptor { + self.mem.read_val(self.desc_addr(idx)).unwrap() + } + + /// Write descriptor directly (for test manipulation) + pub fn write_desc(&self, idx: u16, desc: Descriptor) { + self.mem.write_val(self.desc_addr(idx), desc).unwrap() + } + + /// Read driver event directly + pub fn read_driver_event(&self) -> EventSuppression { + self.mem.read_val(self.layout.drv_evt_addr).unwrap() + } + + /// Read device event directly + pub fn read_device_event(&self) -> EventSuppression { + self.mem.read_val(self.layout.dev_evt_addr).unwrap() + } + + pub fn len(&self) -> usize { + self.layout.desc_table_len as usize + } + } + + // Share the TestMem between producer and consumer via reference + pub(crate) fn make_ring(size: usize) -> OwnedRing { + OwnedRing::new(size) + } + + pub(crate) fn make_producer(ring: &OwnedRing) -> RingProducer> { + RingProducer::new(ring.layout(), ring.mem()) + } + + pub(crate) fn make_consumer(ring: &OwnedRing) -> RingConsumer> { + RingConsumer::new(ring.layout(), ring.mem()) + } + + fn assert_invariants(ring: &OwnedRing, prod: &RingProducer>) { + let outstanding: u16 = prod.id_num.iter().copied().sum(); + assert_eq!(outstanding as usize + 
prod.num_free, ring.len()); + + for id in prod.id_free.iter() { + assert_eq!(prod.id_num[*id as usize], 0); + } + + for (id, &n) in prod.id_num.iter().enumerate() { + if n > 0 { + assert!(!prod.id_free.contains(&(id as u16))); + } + } + } + + #[test] + fn test_initialization() { + let ring = make_ring(8); + let producer = make_producer(&ring); + + // All descriptors should be zeroed + for i in 0..8u16 { + let desc = ring.read_desc(i); + assert_eq!(desc, Descriptor::zeroed()); + assert_eq!(desc.flags, 0); + assert_eq!(desc.addr, 0); + assert_eq!(desc.len, 0); + assert_eq!(desc.id, 0); + } + + // Cursors start at head=0, wrap=true + assert_eq!(producer.avail_cursor.head(), 0); + assert!(producer.avail_cursor.wrap()); + assert_eq!(producer.used_cursor.head(), 0); + assert!(producer.used_cursor.wrap()); + + // All IDs free, id_num zeroed, num_free == size + assert_eq!(producer.id_free.len(), 8); + assert_eq!(producer.num_free, 8); + for i in 0..8 { + assert_eq!(producer.id_num[i], 0); + } + } + + #[test] + fn test_submit_one_descriptor() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + + let addr = 0x1000; + let len = 512; + let writable = false; + + let id = producer.submit_one(addr, len, writable).unwrap(); + + // Check descriptor was written correctly + let desc = ring.read_desc(0); + + assert_eq!(desc.addr, addr); + assert_eq!(desc.len, len); + assert_eq!(desc.id, id); + + // AVAIL should match wrap (true), USED should be inverse (false) + let flags = desc.flags(); + assert!(flags.contains(DescFlags::AVAIL)); + assert!(!flags.contains(DescFlags::USED)); + assert!(!flags.contains(DescFlags::WRITE)); + assert!(!flags.contains(DescFlags::NEXT)); + + // num_free should be decremented + assert_eq!(producer.num_free, 7); + + // Cursor advanced + assert_eq!(producer.avail_cursor.head(), 1); + assert!(producer.avail_cursor.wrap()); + + // ID allocated and chain length recorded + assert_eq!(producer.id_num[id as usize], 1); + 
assert_eq!(producer.id_free.len(), 7); + } + + #[test] + fn test_single_descriptor_wrap_toggle() { + let ring = make_ring(4); + let mut producer = make_producer(&ring); + + // Advance to last slot + producer.avail_cursor.head = 3; + producer.avail_cursor.wrap = true; + producer.num_free = 1; + producer.id_free.clear(); + producer.id_free.push(0); + + let _id = producer.submit_one(0x1000, 512, false).unwrap(); + + // After submission, cursor should wrap + assert_eq!(producer.avail_cursor.head(), 0); + assert!(!producer.avail_cursor.wrap()); + + // Descriptor should have old wrap bits + let desc = ring.read_desc(3); + let flags = desc.flags(); + assert!(flags.contains(DescFlags::AVAIL)); + assert!(!flags.contains(DescFlags::USED)); + } + + #[test] + fn test_multi_descriptor_no_wrap() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + + let chain = BufferChainBuilder::new() + .readable(0x1000, 256) + .readable(0x2000, 256) + .writable(0x3000, 512) + .build() + .unwrap(); + + let id = producer.submit_available(&chain).unwrap(); + + // Check head descriptor + let head_desc = ring.read_desc(0); + assert_eq!(head_desc.addr, 0x1000); + assert_eq!(head_desc.len, 256); + assert_eq!(head_desc.id, id); + + let head_flags = head_desc.flags(); + assert!(head_flags.contains(DescFlags::NEXT)); + assert!(!head_flags.contains(DescFlags::WRITE)); + assert!(head_flags.contains(DescFlags::AVAIL)); + assert!(!head_flags.contains(DescFlags::USED)); + + // Check middle descriptor + let mid_desc = ring.read_desc(1); + assert_eq!(mid_desc.addr, 0x2000); + assert_eq!(mid_desc.len, 256); + assert_eq!(mid_desc.id, id); + + let mid_flags = mid_desc.flags(); + assert!(mid_flags.contains(DescFlags::NEXT)); + assert!(!mid_flags.contains(DescFlags::WRITE)); + + // Check tail descriptor + let tail_desc = ring.read_desc(2); + assert_eq!(tail_desc.addr, 0x3000); + assert_eq!(tail_desc.len, 512); + assert_eq!(tail_desc.id, id); + + let tail_flags = tail_desc.flags(); + 
assert!(!tail_flags.contains(DescFlags::NEXT)); + assert!(tail_flags.contains(DescFlags::WRITE)); + + // All descriptors have same ID + assert_eq!(head_desc.id, mid_desc.id); + assert_eq!(mid_desc.id, tail_desc.id); + + // Check state updates + assert_eq!(producer.num_free, 5); + assert_eq!(producer.avail_cursor.head(), 3); + assert_eq!(producer.id_num[id as usize], 3); + } + + #[test] + fn test_multi_descriptor_with_wrap() { + let ring = make_ring(4); + let mut producer = make_producer(&ring); + + // Position head near end + producer.avail_cursor.head = 2; + producer.avail_cursor.wrap = true; + + let chain = BufferChainBuilder::new() + .readable(0x1000, 256) + .readable(0x2000, 256) + .readable(0x3000, 256) + .build() + .unwrap(); + + let _id = producer.submit_available(&chain).unwrap(); + + // Head at index 2 with wrap=true + let head_desc = ring.read_desc(2); + let head_flags = head_desc.flags(); + assert!(head_flags.contains(DescFlags::AVAIL)); + assert!(!head_flags.contains(DescFlags::USED)); + + // Middle at index 3 with wrap=true (before boundary) + let mid_desc = ring.read_desc(3); + let mid_flags = mid_desc.flags(); + assert!(mid_flags.contains(DescFlags::AVAIL)); + assert!(!mid_flags.contains(DescFlags::USED)); + + // Tail at index 0 with wrap=false (after boundary) + let tail_desc = ring.read_desc(0); + let tail_flags = tail_desc.flags(); + assert!(!tail_flags.contains(DescFlags::AVAIL)); + assert!(tail_flags.contains(DescFlags::USED)); + + // Cursor should have wrapped + assert_eq!(producer.avail_cursor.head(), 1); + assert!(!producer.avail_cursor.wrap()); + } + + #[test] + fn test_ring_full() { + let ring = make_ring(4); + let mut producer = make_producer(&ring); + + // Fill ring completely + for _ in 0..4 { + producer.submit_one(0x1000, 256, false).unwrap(); + } + + assert_eq!(producer.num_free, 0); + + // Next submit should fail + let result = producer.submit_one(0x5000, 256, false); + assert!(matches!(result, Err(RingError::WouldBlock))); + } + + 
#[test] + fn test_poll_and_reclaim() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + + let id = producer.submit_one(0x1000, 512, false).unwrap(); + + // Manually mark as used (simulate device) + let mut desc = ring.read_desc(0); + desc.mark_used(true); + desc.len = 256; + ring.write_desc(0, desc); + + // Poll should return the used buffer + let used = producer.poll_used().unwrap(); + assert_eq!(used.id, id); + assert_eq!(used.len, 256); + + // State should be updated + assert_eq!(producer.num_free, 8); + assert_eq!(producer.used_cursor.head(), 1); + assert_eq!(producer.id_num[id as usize], 0); + assert!(producer.id_free.contains(&id)); + } + + #[test] + fn test_poll_multi_descriptor_chain() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + + let chain = BufferChainBuilder::new() + .readable(0x1000, 256) + .readable(0x2000, 256) + .writable(0x3000, 512) + .build() + .unwrap(); + + let id = producer.submit_available(&chain).unwrap(); + + // Mark only head as used + let mut head_desc = ring.read_desc(0); + head_desc.mark_used(true); + head_desc.len = 512; + ring.write_desc(0, head_desc); + + // Poll should reclaim all 3 descriptors + let used = producer.poll_used().unwrap(); + assert_eq!(used.id, id); + assert_eq!(used.len, 512); + + // Should have skipped 3 descriptors + assert_eq!(producer.used_cursor.head(), 3); + assert_eq!(producer.num_free, 8); + } + + #[test] + fn test_id_reuse() { + let ring = make_ring(4); + let mut producer = make_producer(&ring); + + // Submit and complete first buffer + let id1 = producer.submit_one(0x1000, 256, false).unwrap(); + + let mut desc = ring.read_desc(0); + desc.mark_used(true); + ring.write_desc(0, desc); + + producer.poll_used().unwrap(); + + // Submit another buffer - should reuse ID + let id2 = producer.submit_one(0x2000, 256, false).unwrap(); + + // ID should be reused (LIFO from stack) + assert_eq!(id2, id1); + assert_eq!(producer.id_num[id2 as usize], 1); + } + + #[test] + 
fn test_available_descriptor_flags() { + let ring = make_ring(4); + let mut producer = make_producer(&ring); + + producer.submit_one(0x1000, 256, false).unwrap(); + + let desc = ring.read_desc(0); + + // Available descriptor: AVAIL != USED + let flags = desc.flags(); + assert_ne!( + flags.contains(DescFlags::AVAIL), + flags.contains(DescFlags::USED) + ); + + // ... and AVAIL=true, USED=false for wrap=true + assert!(flags.contains(DescFlags::AVAIL)); + assert!(!flags.contains(DescFlags::USED)); + } + + #[test] + fn test_used_descriptor_flags() { + let ring = make_ring(4); + let mut producer = make_producer(&ring); + + producer.submit_one(0x1000, 256, false).unwrap(); + + let mut desc = ring.read_desc(0); + desc.mark_used(true); + ring.write_desc(0, desc); + + let desc = ring.read_desc(0); + let flags = desc.flags(); + + // Used descriptor: AVAIL == USED + assert_eq!( + flags.contains(DescFlags::AVAIL), + flags.contains(DescFlags::USED) + ); + } + + #[test] + fn test_poll_empty_ring() { + let ring = make_ring(4); + let mut producer = make_producer(&ring); + + // Poll without any submitted buffers + assert!(matches!(producer.poll_used(), Err(RingError::WouldBlock))); + } + + #[test] + fn test_submit_when_full() { + let ring = make_ring(2); + let mut producer = make_producer(&ring); + + producer.submit_one(0x1000, 256, false).unwrap(); + producer.submit_one(0x2000, 256, false).unwrap(); + + // Ring is full + assert!(matches!( + producer.submit_one(0x3000, 256, false), + Err(RingError::WouldBlock) + )); + } + + #[test] + fn test_empty_chain_rejected() { + let chain = BufferChain::default(); + assert_eq!(chain.len(), 0); + + let ring = make_ring(4); + let mut producer = make_producer(&ring); + + let result = producer.submit_available(&chain); + assert!(matches!(result, Err(RingError::EmptyChain))); + } + + #[test] + fn test_wrap_stress() { + let ring = make_ring(4); + let mut producer = make_producer(&ring); + let mut consumer = make_consumer(&ring); + + // Do multiple 
full laps + for lap in 0..3 { + let expected_wrap = lap % 2 == 0; + + for _ in 0..4 { + let id = producer.submit_one(0x1000, 256, false).unwrap(); + + let (dev_id, _) = consumer.poll_available().unwrap(); + assert_eq!(dev_id, id); + + consumer.submit_used(dev_id, 256).unwrap(); + + producer.poll_used().unwrap(); + } + + // After full lap, wrap should toggle + assert_eq!(producer.avail_cursor.wrap(), !expected_wrap); + } + assert_invariants(&ring, &producer); + } + + #[test] + fn test_next_flag_termination() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + + let chain = BufferChainBuilder::new() + .readable(0x1000, 256) + .readable(0x2000, 256) + .readable(0x3000, 256) + .build() + .unwrap(); + + producer.submit_available(&chain).unwrap(); + + // First two should have NEXT + for i in 0..2 { + let desc = ring.read_desc(i); + assert!(desc.flags().contains(DescFlags::NEXT)); + } + + // Last should not have NEXT + let tail_desc = ring.read_desc(2); + assert!(!tail_desc.flags().contains(DescFlags::NEXT)); + } + + #[test] + fn test_consumer_initialization() { + let ring = make_ring(8); + let consumer = make_consumer(&ring); + + assert_eq!(consumer.avail_cursor.head(), 0); + assert!(consumer.avail_cursor.wrap()); + assert_eq!(consumer.used_cursor.head(), 0); + assert!(consumer.used_cursor.wrap()); + + for i in 0..8 { + assert_eq!(consumer.id_num[i], 0); + } + } + + #[test] + fn test_consumer_poll_available_single() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + let mut consumer = make_consumer(&ring); + + let id = producer.submit_one(0x1000, 512, false).unwrap(); + + let (polled_id, chain) = consumer.poll_available().unwrap(); + + assert_eq!(polled_id, id); + assert_eq!(chain.len(), 1); + assert_eq!(chain.elems()[0].addr, 0x1000); + assert_eq!(chain.elems()[0].len, 512); + assert!(!chain.elems()[0].writable); + + // Chain length recorded + assert_eq!(consumer.id_num[id as usize], 1); + 
assert_eq!(consumer.avail_cursor.head(), 1); + } + + #[test] + fn test_consumer_poll_available_chain() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + let mut consumer = make_consumer(&ring); + + let chain = BufferChainBuilder::new() + .readable(0x1000, 256) + .readable(0x2000, 256) + .writable(0x3000, 512) + .build() + .unwrap(); + + let id = producer.submit_available(&chain).unwrap(); + + let (polled_id, polled_chain) = consumer.poll_available().unwrap(); + + assert_eq!(polled_id, id); + assert_eq!(polled_chain.len(), 3); + + assert_eq!(polled_chain.elems()[0].addr, 0x1000); + assert!(!polled_chain.elems()[0].writable); + + assert_eq!(polled_chain.elems()[1].addr, 0x2000); + assert!(!polled_chain.elems()[1].writable); + + assert_eq!(polled_chain.elems()[2].addr, 0x3000); + assert!(polled_chain.elems()[2].writable); + + assert_eq!(consumer.id_num[id as usize], 3); + } + + #[test] + fn test_consumer_submit_used() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + let mut consumer = make_consumer(&ring); + + let id = producer.submit_one(0x1000, 512, true).unwrap(); + + let (polled_id, _) = consumer.poll_available().unwrap(); + + // Submit as used + consumer.submit_used(polled_id, 256).unwrap(); + + // Check descriptor marked used + let desc = ring.read_desc(0); + + assert_eq!(desc.id, id); + assert_eq!(desc.len, 256); + assert!(desc.is_used(true)); + + // Cursor advanced, chain length cleared + assert_eq!(consumer.used_cursor.head(), 1); + assert_eq!(consumer.id_num[id as usize], 0); + } + + #[test] + fn test_consumer_submit_used_multi_descriptor() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + let mut consumer = make_consumer(&ring); + + let chain = BufferChainBuilder::new() + .readable(0x1000, 256) + .writable(0x2000, 512) + .writable(0x3000, 512) + .build() + .unwrap(); + + producer.submit_available(&chain).unwrap(); + + let (id, _) = consumer.poll_available().unwrap(); + + 
consumer.submit_used(id, 1024).unwrap(); + + // Only head marked used + let head_desc = ring.read_desc(0); + assert!(head_desc.is_used(true)); + assert_eq!(head_desc.len, 1024); + + // Cursor skipped entire chain + assert_eq!(consumer.used_cursor.head(), 3); + assert_eq!(consumer.id_num[id as usize], 0); + } + + #[test] + fn test_consumer_poll_empty() { + let ring = make_ring(4); + let mut consumer = make_consumer(&ring); + + assert!(matches!( + consumer.poll_available(), + Err(RingError::WouldBlock) + )); + } + + #[test] + fn test_consumer_peek() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + let consumer = make_consumer(&ring); + + producer.submit_one(0x1000, 512, false).unwrap(); + assert!(consumer.peek_available().unwrap()); + + let empty_ring = make_ring(4); + let empty_consumer = make_consumer(&empty_ring); + assert!(!empty_consumer.peek_available().unwrap()); + } + + #[test] + fn test_full_roundtrip() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + let mut consumer = make_consumer(&ring); + + let chain = BufferChainBuilder::new() + .readable(0x1000, 256) + .writable(0x2000, 512) + .build() + .unwrap(); + + let id = producer.submit_available(&chain).unwrap(); + + let (consumer_id, consumer_chain) = consumer.poll_available().unwrap(); + + assert_eq!(consumer_id, id); + assert_eq!(consumer_chain.len(), 2); + + consumer.submit_used(consumer_id, 512).unwrap(); + + let used = producer.poll_used().unwrap(); + assert_eq!(used.id, id); + assert_eq!(used.len, 512); + } + + #[test] + fn ring_initial_poll_used_blocks() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + // No submissions yet: all descriptors zero. 
+ for _ in 0..8 { + assert!(matches!(producer.poll_used(), Err(RingError::WouldBlock))); + } + // Invariants: num_free == ring size + assert_eq!(producer.num_free, ring.len()); + } + + #[test] + fn ring_consumer_blocks_until_submit() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + let mut consumer = make_consumer(&ring); + + assert!(matches!( + consumer.poll_available(), + Err(RingError::WouldBlock) + )); + + let chain = BufferChainBuilder::new() + .readable(0x1000, 32) + .readable(0x2000, 16) + .build() + .unwrap(); + + let id = producer.submit_available(&chain).unwrap(); + + let (cid, polled) = consumer.poll_available().unwrap(); + assert_eq!(cid, id); + assert_eq!(polled.len(), chain.len()); + } + + #[test] + fn test_out_of_order_completion_stream() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + let mut consumer = make_consumer(&ring); + + // Driver submits two single-descriptor chains A then B + let id_a = producer.submit_one(0x1000, 256, true).unwrap(); + let id_b = producer.submit_one(0x2000, 256, true).unwrap(); + + // Device polls them in ring order (A then B) + let (dev_id_a, chain_a) = consumer.poll_available().unwrap(); + assert_eq!(dev_id_a, id_a); + assert_eq!(chain_a.len(), 1); + + let (dev_id_b, chain_b) = consumer.poll_available().unwrap(); + assert_eq!(dev_id_b, id_b); + assert_eq!(chain_b.len(), 1); + + // Device completes B first, then A + consumer.submit_used(dev_id_b, 128).unwrap(); + consumer.submit_used(dev_id_a, 256).unwrap(); + + // Driver polls used stream: should see B (first completion) + let used_b = producer.poll_used().unwrap(); + assert_eq!(used_b.id, id_b); + assert_eq!(used_b.len, 128); + + // Then sees A + let used_a = producer.poll_used().unwrap(); + assert_eq!(used_a.id, id_a); + assert_eq!(used_a.len, 256); + + // IDs recycled + assert!(producer.id_free.contains(&id_a)); + assert!(producer.id_free.contains(&id_b)); + } + + #[test] + fn 
test_mixed_chain_sizes_out_of_order_completion() { + let ring = make_ring(16); + let mut producer = make_producer(&ring); + let mut consumer = make_consumer(&ring); + + let chains = vec![ + BufferChainBuilder::new() + .readable(0x1000, 10) + .writable(0x2000, 5) + .build() + .unwrap(), + BufferChainBuilder::new() + .readable(0x3000, 8) + .readable(0x3010, 8) + .writable(0x3020, 16) + .build() + .unwrap(), + BufferChainBuilder::new() + .readable(0x4000, 4) + .build() + .unwrap(), + BufferChainBuilder::new() + .readable(0x5000, 4) + .readable(0x5010, 4) + .readable(0x5020, 4) + .writable(0x5030, 4) + .build() + .unwrap(), + ]; + + for c in &chains { + producer.submit_available(c).unwrap(); + } + + let mut dev_chain_lens = Vec::new(); + for _ in &chains { + let (id, chain) = consumer.poll_available().unwrap(); + dev_chain_lens.push((id, chain.len() as u32)); + } + + let order = [1, 3, 0, 2]; + let mut completion = Vec::new(); + + for &idx in &order { + let (id, len) = dev_chain_lens[idx]; + consumer.submit_used(id, len).unwrap(); + completion.push((id, len)); + } + + for (expected_id, expected_len) in &completion { + let used = producer.poll_used().unwrap(); + assert_eq!(used.id, *expected_id); + assert_eq!(used.len, *expected_len); + assert_eq!(producer.id_num[*expected_id as usize], 0); + assert!(producer.id_free.contains(expected_id)); + } + + assert_invariants(&ring, &producer); + } + + // Used stream wrap crossing + #[test] + fn test_used_stream_wrap_crossing() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + let mut consumer = make_consumer(&ring); + + // Submit enough single descriptors to make used writes wrap + let mut ids = Vec::new(); + for i in 0..8 { + ids.push(producer.submit_one(0x1000 + i as u64, 1, false).unwrap()); + } + + // Device polls all + for _ in 0..8 { + consumer.poll_available().unwrap(); + } + + // Complete all in order except we simulate out-of-order by reversing + for &id in ids.iter().rev() { + 
consumer.submit_used(id, 1).unwrap(); + } + + // Producer polls used; after consuming size descriptors used_cursor should wrap + for _ in 0..8 { + producer.poll_used().unwrap(); + } + assert_eq!(producer.used_cursor.head(), 0); + assert!(!producer.used_cursor.wrap()); // flipped once + assert_invariants(&ring, &producer); + } + + // Interleaved availability and completion + #[test] + fn test_interleaved_submit_completion() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + let mut consumer = make_consumer(&ring); + + // Submit chain A (len 2) + let chain_a = BufferChainBuilder::new() + .readable(0x1000, 8) + .writable(0x2000, 8) + .build() + .unwrap(); + let id_a = producer.submit_available(&chain_a).unwrap(); + + // Device polls A + let (dev_id_a, _) = consumer.poll_available().unwrap(); + assert_eq!(dev_id_a, id_a); + + // Device completes A + consumer.submit_used(dev_id_a, 8).unwrap(); + + // Submit chain B (len 3) before driver reclaims A + let chain_b = BufferChainBuilder::new() + .readable(0x3000, 4) + .readable(0x3010, 4) + .writable(0x3020, 4) + .build() + .unwrap(); + let id_b = producer.submit_available(&chain_b).unwrap(); + + // Device polls B + let (dev_id_b, _) = consumer.poll_available().unwrap(); + assert_eq!(dev_id_b, id_b); + + // Driver reclaims A + let used_a = producer.poll_used().unwrap(); + assert_eq!(used_a.id, id_a); + + // Device completes B + consumer.submit_used(dev_id_b, 12).unwrap(); + + // Driver reclaims B + let used_b = producer.poll_used().unwrap(); + assert_eq!(used_b.id, id_b); + + assert_invariants(&ring, &producer); + } + + // Partial publish safety (head not published yet) + #[test] + fn test_partial_publish_safety() { + let ring = make_ring(8); + let mut consumer = make_consumer(&ring); + let mut producer = make_producer(&ring); + + // Build chain manually: write tails only + let chain = BufferChainBuilder::new() + .readable(0x1000, 4) + .readable(0x2000, 4) + .writable(0x3000, 4) + .build() + .unwrap(); 
+ + // Simulate manual tail writes without head publish + let id = producer.id_free.pop().unwrap(); + producer.id_num[id as usize] = chain.len() as u16; + + // Emulate internal position logic + let head_idx = producer.avail_cursor.head(); + let wrap_start = producer.avail_cursor.wrap(); + let mut pos = producer.avail_cursor; + pos.advance(); + + for (i, elem) in chain.elems().iter().enumerate().skip(1) { + let is_next = i + 1 < chain.len(); + let mut flags = DescFlags::empty(); + flags.set(DescFlags::NEXT, is_next); + flags.set(DescFlags::WRITE, elem.writable); + let mut d = Descriptor::new(elem.addr, elem.len, id, flags); + d.mark_avail(pos.wrap()); + ring.write_desc(pos.head(), d); + pos.advance(); + } + + // Head not published yet: consumer must not see chain + assert!(matches!( + consumer.poll_available(), + Err(RingError::WouldBlock) + )); + + // Now publish head + let head_elem = chain.elems()[0]; + let mut head_flags = DescFlags::empty(); + head_flags.set(DescFlags::NEXT, true); + head_flags.set(DescFlags::WRITE, head_elem.writable); + let mut head_desc = Descriptor::new(head_elem.addr, head_elem.len, id, head_flags); + head_desc.mark_avail(wrap_start); + ring.write_desc(head_idx, head_desc); + producer.avail_cursor = pos; + producer.num_free -= chain.len(); + + // Consumer can now see the chain + let (dev_id, dev_chain) = consumer.poll_available().unwrap(); + assert_eq!(dev_id, id); + assert_eq!(dev_chain.len(), chain.len()); + assert_invariants(&ring, &producer); + } + + // Tail misuse negative test + #[test] + fn test_tail_marked_used_ignored() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + + let chain = BufferChainBuilder::new() + .readable(0x1000, 4) + .readable(0x2000, 4) + .build() + .unwrap(); + let id = producer.submit_available(&chain).unwrap(); + + // Incorrectly mark tail (index 1) used + let mut tail_desc = ring.read_desc(1); + tail_desc.mark_used(producer.used_cursor.wrap()); + ring.write_desc(1, tail_desc); + + // 
Poll should return WouldBlock (head not used yet) + assert!(matches!(producer.poll_used(), Err(RingError::WouldBlock))); + + // Mark head used properly + let mut head_desc = ring.read_desc(0); + head_desc.mark_used(producer.used_cursor.wrap()); + ring.write_desc(0, head_desc); + + // Now poll succeeds + let used = producer.poll_used().unwrap(); + assert_eq!(used.id, id); + assert_invariants(&ring, &producer); + } + + // Max chain length boundary + #[test] + fn test_max_chain_len_rejected() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + + // Try chain longer than ring size + let elems = (0..9).map(|i| BufferElement { + addr: 0x1000 + i as u64, + len: 42, + writable: false, + }); + + let chain = BufferChainBuilder::new().readables(elems).build().unwrap(); + + // Submit_available should reject when num_free < total_descs + assert!(matches!( + producer.submit_available(&chain), + Err(RingError::WouldBlock) + )); + } + + // Descriptor state monotonicity after many cycles + #[test] + fn test_descriptor_state_monotonicity() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + let mut consumer = make_consumer(&ring); + + // Track states: 0=zero/init, 1=available, 2=used, 3=reclaimed + let mut states = vec![0u8; 8]; + + for _ in 0..5 { + for state in states.iter_mut() { + let id = producer.submit_one(0x1000, 4, false).unwrap(); + // mark available + *state = (*state).max(1); + + // device polls and completes + let (dev_id, _) = consumer.poll_available().unwrap(); + consumer.submit_used(dev_id, 4).unwrap(); + *state = (*state).max(2); + + let used = producer.poll_used().unwrap(); + assert_eq!(used.id, id); + *state = (*state).max(3); + } + + assert_invariants(&ring, &producer); + } + + // Ensure monotonic progression (never decrease) + for s in states { + assert!(s >= 3); + } + } + + // Large multi-lap random submission/completion + #[test] + fn test_random_stress_small() { + use rand::Rng; + use rand::seq::SliceRandom; + + let 
ring = make_ring(16); + let mut producer = make_producer(&ring); + let mut consumer = make_consumer(&ring); + let mut rng = rand::rng(); + + // Submit initial set + let mut active_ids = Vec::new(); + for _ in 0..8 { + let len = rng.random_range(1..=4); + let mut b = BufferChainBuilder::new().readable(0x1000, 4); + for i in 1..len { + b = b.readable(0x1000 + i as u64 * 0x10, 4); + } + let chain = b.build().unwrap(); + if let Ok(id) = producer.submit_available(&chain) { + active_ids.push(id); + } + } + + let mut dev_ids = Vec::new(); + while let Ok((id, _)) = consumer.poll_available() { + dev_ids.push(id); + } + + // Randomly complete + dev_ids.shuffle(&mut rng); + for id in &dev_ids { + let chain_len = consumer.id_num[*id as usize]; + consumer.submit_used(*id, chain_len as u32 * 4).unwrap(); + } + // Driver reclaim + for _ in &dev_ids { + if producer.poll_used().is_ok() {} + } + + assert_invariants(&ring, &producer); + } + + // Out-of-order multi-length explicit + #[test] + fn test_out_of_order_multi_length() { + let ring = make_ring(12); + let mut producer = make_producer(&ring); + let mut consumer = make_consumer(&ring); + + let chain_a = BufferChainBuilder::new() + .readable(0x1000, 4) + .writable(0x2000, 4) + .build() + .unwrap(); + let chain_b = BufferChainBuilder::new() + .readable(0x3000, 4) + .readable(0x3010, 4) + .writable(0x3020, 4) + .build() + .unwrap(); + let chain_c = BufferChainBuilder::new() + .readable(0x4000, 4) + .build() + .unwrap(); + + let id_a = producer.submit_available(&chain_a).unwrap(); + let id_b = producer.submit_available(&chain_b).unwrap(); + let id_c = producer.submit_available(&chain_c).unwrap(); + + let (d_a, _) = consumer.poll_available().unwrap(); + let (d_b, _) = consumer.poll_available().unwrap(); + let (d_c, _) = consumer.poll_available().unwrap(); + assert_eq!(d_a, id_a); + assert_eq!(d_b, id_b); + assert_eq!(d_c, id_c); + + // Complete B, then C, then A + consumer.submit_used(d_b, 12).unwrap(); + consumer.submit_used(d_c, 
4).unwrap(); + consumer.submit_used(d_a, 8).unwrap(); + + let u_b = producer.poll_used().unwrap(); + assert_eq!(u_b.id, id_b); + let u_c = producer.poll_used().unwrap(); + assert_eq!(u_c.id, id_c); + let u_a = producer.poll_used().unwrap(); + assert_eq!(u_a.id, id_a); + + assert_invariants(&ring, &producer); + } + + #[test] + fn interleave_submit_and_completion() { + let ring = make_ring(16); + let mut producer = make_producer(&ring); + let mut consumer = make_consumer(&ring); + + // Submit A (len 2) + let chain_a = BufferChainBuilder::new() + .readable(0x1000, 4) + .writable(0x2000, 4) + .build() + .unwrap(); + let id_a = producer.submit_available(&chain_a).unwrap(); + + // Device polls A + let (d_a, _) = consumer.poll_available().unwrap(); + assert_eq!(d_a, id_a); + + // Immediately complete A + consumer.submit_used(d_a, 8).unwrap(); + + // Submit B (len 3) + let chain_b = BufferChainBuilder::new() + .readable(0x3000, 4) + .readable(0x3010, 4) + .writable(0x3020, 4) + .build() + .unwrap(); + let id_b = producer.submit_available(&chain_b).unwrap(); + + // Driver polls used: gets A + let u_a = producer.poll_used().unwrap(); + assert_eq!(u_a.id, id_a); + assert_eq!(u_a.len, 8); + + // Device polls B and submits used for it + let (d_b, _) = consumer.poll_available().unwrap(); + assert_eq!(d_b, id_b); + consumer.submit_used(d_b, 12).unwrap(); + + // Submit C (len 1) + let id_c = producer.submit_one(0x4000, 4, false).unwrap(); + + // Device polls C and completes it + let (d_c, _) = consumer.poll_available().unwrap(); + assert_eq!(d_c, id_c); + consumer.submit_used(d_c, 4).unwrap(); + + // Driver polls used: gets B then C + let u_b = producer.poll_used().unwrap(); + assert_eq!(u_b.id, id_b); + assert_eq!(u_b.len, 12); + + let u_c = producer.poll_used().unwrap(); + assert_eq!(u_c.id, id_c); + assert_eq!(u_c.len, 4); + + assert_invariants(&ring, &producer); + } + + // Event suppression tests + #[test] + fn producer_disable_used_notifications_writes_driver_disable() { + 
let ring = make_ring(8); + let mut producer = make_producer(&ring); + + assert_eq!(ring.read_driver_event().flags(), EventFlags::ENABLE); + producer.disable_used_notifications().unwrap(); + assert_eq!(ring.read_driver_event().flags(), EventFlags::DISABLE); + } + + #[test] + fn producer_enable_used_notifications_writes_driver_enable() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + + producer.disable_used_notifications().unwrap(); + assert_eq!(ring.read_driver_event().flags(), EventFlags::DISABLE); + + producer.enable_used_notifications().unwrap(); + assert_eq!(ring.read_driver_event().flags(), EventFlags::ENABLE); + } + + #[test] + fn producer_enable_used_notifications_desc_sets_off_wrap_and_flags() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + + producer.enable_used_notifications_desc(5, true).unwrap(); + + let evt = ring.read_driver_event(); + assert_eq!(evt.flags(), EventFlags::DESC); + assert_eq!(evt.desc_event_off(), 5); + assert!(evt.desc_event_wrap()); + } + + #[test] + fn producer_enable_used_notifications_for_next_programs_used_cursor() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + + // initial used cursor: head=0, wrap=true + producer.enable_used_notifications_for_next().unwrap(); + + let evt = ring.read_driver_event(); + assert_eq!(evt.flags(), EventFlags::DESC); + assert_eq!(evt.desc_event_off(), 0); + assert!(evt.desc_event_wrap()); + } + + #[test] + fn consumer_disable_avail_notifications_writes_device_disable() { + let ring = make_ring(8); + let mut consumer = make_consumer(&ring); + + assert_eq!(ring.read_device_event().flags(), EventFlags::ENABLE); + consumer.disable_avail_notifications().unwrap(); + assert_eq!(ring.read_device_event().flags(), EventFlags::DISABLE); + } + + #[test] + fn consumer_enable_avail_notifications_writes_device_enable() { + let ring = make_ring(8); + let mut consumer = make_consumer(&ring); + + consumer.disable_avail_notifications().unwrap(); 
+ assert_eq!(ring.read_device_event().flags(), EventFlags::DISABLE); + + consumer.enable_avail_notifications().unwrap(); + assert_eq!(ring.read_device_event().flags(), EventFlags::ENABLE); + } + + #[test] + fn consumer_enable_avail_notifications_desc_sets_off_wrap_and_flags() { + let ring = make_ring(8); + let mut consumer = make_consumer(&ring); + + consumer.enable_avail_notifications_desc(7, false).unwrap(); + + let evt = ring.read_device_event(); + assert_eq!(evt.flags(), EventFlags::DESC); + assert_eq!(evt.desc_event_off(), 7); + assert!(!evt.desc_event_wrap()); + } + + #[test] + fn consumer_enable_avail_notifications_for_next_programs_avail_cursor() { + let ring = make_ring(8); + let mut consumer = make_consumer(&ring); + + // initial avail cursor: head=0, wrap=true + consumer.enable_avail_notifications_for_next().unwrap(); + + let evt = ring.read_device_event(); + assert_eq!(evt.flags(), EventFlags::DESC); + assert_eq!(evt.desc_event_off(), 0); + assert!(evt.desc_event_wrap()); + } + + #[test] + fn producer_does_not_write_device_event_when_toggling_used_notifications() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + + let dev_before = ring.read_device_event(); + producer.disable_used_notifications().unwrap(); + let dev_after = ring.read_device_event(); + + assert_eq!(dev_after, dev_before); + } + + #[test] + fn consumer_does_not_write_driver_event_when_toggling_avail_notifications() { + let ring = make_ring(8); + let mut consumer = make_consumer(&ring); + + let drv_before = ring.read_driver_event(); + consumer.disable_avail_notifications().unwrap(); + let drv_after = ring.read_driver_event(); + + assert_eq!(drv_after, drv_before); + } + + #[test] + fn should_notify_flags_enable_disable() { + let ring_len = 8; + + let old = RingCursor { + head: 0, + size: ring_len, + wrap: true, + }; + let new = RingCursor { + head: 1, + size: ring_len, + wrap: true, + }; + + // DISABLE -> never notify + let evt = EventSuppression::new(0, 
EventFlags::DISABLE); + assert!(!should_notify(evt, ring_len, old, new)); + + // ENABLE -> always notify + let evt = EventSuppression::new(0, EventFlags::ENABLE); + assert!(should_notify(evt, ring_len, old, new)); + } + + #[test] + fn should_notify_desc_no_crossing() { + let ring_len = 8; + + let old = RingCursor { + head: 2, + size: ring_len, + wrap: true, + }; + let new = RingCursor { + head: 3, + size: ring_len, + wrap: true, + }; + + // event at 6, we did not cross it + let mut evt = EventSuppression::zeroed(); + evt.set_desc_event(6, true); + evt.set_flags(EventFlags::DESC); + + assert!(!should_notify(evt, ring_len, old, new)); + } + + #[test] + fn should_notify_desc_wrap_mismatch_adjusts_event_idx() { + let ring_len = 8; + + let old = RingCursor { + head: 7, + size: ring_len, + wrap: true, + }; + let new = RingCursor { + head: 1, + size: ring_len, + wrap: false, + }; + + let mut evt = EventSuppression::zeroed(); + evt.set_desc_event(7, true); + evt.set_flags(EventFlags::DESC); + + assert!(should_notify(evt, ring_len, old, new)); + } + + #[test] + fn ring_need_event_basic_cases() { + // If event_idx == new-1, should be true + assert!(ring_need_event(4, 5, 2)); + // If no progress, should be false + assert!(!ring_need_event(4, 5, 5)); + + // Wrapping arithmetic sanity: old near u16::MAX + let old = 0xFFFE; + let new = 1; + // event at 0xFFFF is considered "just before wrap" + assert!(ring_need_event(0xFFFF, new, old)); + } + + // Bad device/driver tests + #[test] + fn bad_device_marks_tail_used() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + + let chain = BufferChainBuilder::new() + .readable(0x1000, 4) + .readable(0x2000, 4) + .build() + .unwrap(); + let id = producer.submit_available(&chain).unwrap(); + + // Bad device: mark index 1 (tail) used + let mut tail = ring.read_desc(1); + tail.mark_used(producer.used_cursor.wrap()); + ring.write_desc(1, tail); + + // Driver must not consume it + assert!(matches!(producer.poll_used(), 
Err(RingError::WouldBlock))); + + // Now mark head properly, driver must consume + let mut head = ring.read_desc(0); + head.mark_used(producer.used_cursor.wrap()); + ring.write_desc(0, head); + + let used = producer.poll_used().unwrap(); + assert_eq!(used.id, id); + } + + #[test] + fn bad_device_wrong_used_bits() { + let ring = make_ring(4); + let mut producer = make_producer(&ring); + + let id = producer.submit_one(0x1000, 8, true).unwrap(); + + // Malformed: set AVAIL but clear USED (should be equal for used) + let mut d = ring.read_desc(0); + // Force flags to look like "available" despite intent + d.mark_avail(producer.used_cursor.wrap()); + d.len = 8; + ring.write_desc(0, d); + + assert!(matches!(producer.poll_used(), Err(RingError::WouldBlock))); + + let mut d2 = ring.read_desc(0); + d2.mark_used(producer.used_cursor.wrap()); + ring.write_desc(0, d2); + + let u = producer.poll_used().unwrap(); + assert_eq!(u.id, id); + } + + #[test] + fn bad_driver_next_never_clears() { + let ring = make_ring(8); + let mut consumer = make_consumer(&ring); + let mut producer = make_producer(&ring); + + // Allocate an ID and pretend one huge chain + let id = producer.id_free.pop().unwrap(); + producer.id_num[id as usize] = 8; + + let mut pos = producer.avail_cursor; + let wrap_start = pos.wrap(); + + // Write every descriptor with NEXT set and same id + for _ in 0..8 { + let idx = pos.head(); + let mut flags = DescFlags::empty(); + flags.set(DescFlags::NEXT, true); // incorrect: last should NOT have NEXT + let mut desc = Descriptor::new(0x1000 + idx as u64 * 0x10, 4, id, flags); + desc.mark_avail(pos.wrap()); + ring.write_desc(idx, desc); + pos.advance(); + } + + // Publish head last (simulate driver behavior) + let head_idx = producer.avail_cursor.head(); + let mut head_flags = DescFlags::empty(); + head_flags.set(DescFlags::NEXT, true); + let mut head_desc = Descriptor::new(0x42, 4, id, head_flags); + head_desc.mark_avail(wrap_start); + ring.write_desc(head_idx, head_desc); + 
+ // Consumer should detect invalid chain via step guard + assert!(matches!( + consumer.poll_available(), + Err(RingError::BadChain) + )); + } + + #[test] + fn bad_driver_interleaved_readables_and_writables() { + let ring = make_ring(8); + let mut consumer = make_consumer(&ring); + let mut producer = make_producer(&ring); + + let chain = BufferChainBuilder::new() + .readable(0x1000, 4) + .readable(0x2000, 4) + .writable(0x2000, 4) + .build() + .unwrap(); + + let _id = producer.submit_available(&chain).unwrap(); + + // now change first descriptor to writable (bad driver) + let mut first = ring.read_desc(0); + first.flags |= DescFlags::WRITE.bits(); + ring.write_desc(0, first); + + assert!(matches!( + consumer.poll_available(), + Err(RingError::BadChain) + )); + } + + #[test] + fn bad_device_marks_multiple_used_in_chain() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + + let chain = BufferChainBuilder::new() + .readable(0x1000, 4) + .readable(0x2000, 4) + .build() + .unwrap(); + let id = producer.submit_available(&chain).unwrap(); + + // Bad device: mark head and tail used + let mut head = ring.read_desc(0); + head.mark_used(producer.used_cursor.wrap()); + ring.write_desc(0, head); + + let mut tail = ring.read_desc(1); + tail.mark_used(producer.used_cursor.wrap()); + ring.write_desc(1, tail); + + // Driver consumes once + let u = producer.poll_used().unwrap(); + assert_eq!(u.id, id); + + // Next poll should block; no duplicate consumption + assert!(matches!(producer.poll_used(), Err(RingError::WouldBlock))); + } + + #[test] + fn bad_device_writes_used_at_wrong_slot() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + + let _id = producer.submit_one(0x1000, 4, true).unwrap(); + + // Wrong slot: mark index 3 used while next_used is 0 + let mut d = ring.read_desc(3); + d.mark_used(producer.used_cursor.wrap()); + ring.write_desc(3, d); + + // Driver should still block (polls only slot 0) + 
assert!(matches!(producer.poll_used(), Err(RingError::WouldBlock))); + + // Now mark slot 0 correctly, driver can consume + let mut d0 = ring.read_desc(0); + d0.mark_used(producer.used_cursor.wrap()); + ring.write_desc(0, d0); + let _u = producer.poll_used().unwrap(); + } + + #[test] + fn bad_driver_reuses_id_while_outstanding() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + + // Submit first buffer: allocate ID + let id = producer.submit_one(0x1000, 4, false).unwrap(); + assert_eq!(producer.id_num[id as usize], 1); + + // push the same ID back into free list while it's still outstanding. + producer.id_free.push(id); + + // Next submit should fail because ID is still outstanding. + let res = producer.submit_one(0x2000, 4, false); + assert!(matches!(res, Err(RingError::InvalidState))); + } + + #[test] + fn test_avail_cursor_accessor() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + + // Initial cursor + let cursor = producer.avail_cursor(); + assert_eq!(cursor.head(), 0); + assert!(cursor.wrap()); + + // After submit + producer.submit_one(0x1000, 512, false).unwrap(); + let cursor = producer.avail_cursor(); + assert_eq!(cursor.head(), 1); + assert!(cursor.wrap()); + } + + #[test] + fn test_should_notify_since() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + + let before = producer.avail_cursor(); + producer.submit_one(0x1000, 512, false).unwrap(); + + // Default is ENABLE mode, so should notify + let should_notify = producer.should_notify_since(before).unwrap(); + assert!(should_notify); + } + + #[test] + fn test_batch_notification_single_check() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + let mut consumer = make_consumer(&ring); + + let before = producer.avail_cursor(); + + // Submit multiple descriptors + producer.submit_one(0x1000, 512, false).unwrap(); + producer.submit_one(0x2000, 512, false).unwrap(); + producer.submit_one(0x3000, 512, false).unwrap(); 
+ + // Single notification check for the entire batch + let should_notify = producer.should_notify_since(before).unwrap(); + assert!(should_notify); + + // Consumer sees all 3 descriptors + for _ in 0..3 { + let (_, _) = consumer.poll_available().unwrap(); + } + } +} + +#[cfg(test)] +mod fuzz { + use quickcheck::{Arbitrary, Gen, QuickCheck}; + + use super::tests::{OwnedRing, make_consumer, make_producer}; + use super::*; + + const MAX_RING: usize = 64; + const MAX_OPS: usize = 128; + const MAX_CHAIN_LEN: usize = 8; + + #[allow(clippy::large_enum_variant)] + #[derive(Clone, Debug)] + enum Op { + /// submit one chain + Submit(BufferChain), + /// poll up to N chains + PollAvail(u8), + /// driver reclaims up to N completions + PollUsed(u8), + /// complete one previously polled chain + CompleteOne, + } + + impl Arbitrary for Op { + fn arbitrary(g: &mut Gen) -> Self { + let choice = u8::arbitrary(g) % 4; + match choice { + 0 => Op::Submit(BufferChain::arbitrary(g)), + 1 => Op::PollAvail(u8::arbitrary(g) % 8 + 1), + 2 => Op::PollUsed(u8::arbitrary(g) % 8 + 1), + 3 => Op::CompleteOne, + _ => unreachable!(), + } + } + } + + #[derive(Clone, Debug)] + struct Scenario { + table_size: usize, + ops: Vec, + } + + impl Arbitrary for Scenario { + fn arbitrary(g: &mut Gen) -> Self { + let table_size = usize::arbitrary(g) % MAX_RING + 1; + let num_ops = usize::arbitrary(g) % MAX_OPS + 1; + + let ops = (0..num_ops).map(|_| Op::arbitrary(g)).collect(); + Scenario { table_size, ops } + } + } + + impl Arbitrary for BufferElement { + fn arbitrary(g: &mut Gen) -> Self { + let addr = u64::arbitrary(g); + let len = u32::arbitrary(g); + let writable = bool::arbitrary(g); + + BufferElement { + addr, + len, + writable, + } + } + } + + impl Arbitrary for BufferChain { + fn arbitrary(g: &mut Gen) -> Self { + let chain_len = usize::arbitrary(g) % MAX_CHAIN_LEN + 1; + + let mut elems = vec![BufferElement::zeroed(); chain_len]; + let mut readables = 0; + let mut writables = 0; + + for _ in 
0..chain_len { + let elem = BufferElement::arbitrary(g); + if elem.writable { + elems[chain_len - 1 - writables] = elem; + writables += 1; + } else { + elems[readables] = elem; + readables += 1; + } + } + + BufferChain { + elems: elems.into(), + split: readables, + } + } + } + + fn run_scenario(s: Scenario) -> bool { + let ring = OwnedRing::new(s.table_size); + let mut producer = make_producer(&ring); + let mut consumer = make_consumer(&ring); + + // Order logs + let mut dev_order: Vec = Vec::new(); + let mut drv_order: Vec = Vec::new(); + + // Device-tracked polled-but-not-completed IDs + let mut dev_ready: Vec<(u16, u32)> = Vec::new(); + + for op in &s.ops { + match op { + Op::Submit(chain) => { + // Submit only if space; otherwise skip + let _ = producer.submit_available(chain); + } + Op::PollAvail(n) => { + for _ in 0..*n { + if let Ok((id, chain)) = consumer.poll_available() { + dev_ready.push((id, chain.len() as u32)); + } else { + break; + } + } + } + Op::PollUsed(n) => { + for _ in 0..*n { + match producer.poll_used() { + Ok(u) => { + drv_order.push(u.id); + if producer.id_num[u.id as usize] != 0 { + return false; + } + if !producer.id_free.contains(&u.id) { + return false; + } + } + Err(RingError::WouldBlock) => break, + Err(_) => return false, + } + } + } + Op::CompleteOne => { + if let Some((id, len)) = dev_ready.pop() { + if consumer.submit_used(id, len).is_err() { + return false; + } + + dev_order.push(id); + } + } + } + + // assert invariants after each op + let outstanding: u16 = producer.id_num.iter().copied().sum(); + if outstanding as usize + producer.num_free != ring.len() { + return false; + } + + for id in producer.id_free.iter() { + if producer.id_num[*id as usize] != 0 { + return false; + } + } + } + + // Drain remaining completions and reclaims + while let Some((id, len)) = dev_ready.pop() { + if consumer.submit_used(id, len).is_err() { + return false; + } + } + + loop { + match producer.poll_used() { + Ok(u) => drv_order.push(u.id), + 
Err(RingError::WouldBlock) => break, + Err(_) => return false, + } + } + + true + } + + #[test] + fn prop_interleaved_with_order_verification() { + #[cfg(miri)] + let tests = 1; + #[cfg(not(miri))] + let tests = 100; + + QuickCheck::new() + .tests(tests) + .quickcheck(run_scenario as fn(Scenario) -> bool); + } +} diff --git a/src/hyperlight_guest/src/error.rs b/src/hyperlight_guest/src/error.rs index 463ba6a69..f5e3cbd83 100644 --- a/src/hyperlight_guest/src/error.rs +++ b/src/hyperlight_guest/src/error.rs @@ -17,9 +17,10 @@ limitations under the License. use alloc::format; use alloc::string::{String, ToString as _}; +use anyhow; use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; use hyperlight_common::func::Error as FuncError; -use {anyhow, serde_json}; +use serde_json; pub type Result = core::result::Result; @@ -171,10 +172,10 @@ impl GuestErrorContext for core::result::Result { #[macro_export] macro_rules! bail { ($ec:expr => $($msg:tt)*) => { - return ::core::result::Result::Err($crate::error::HyperlightGuestError::new($ec, ::alloc::format!($($msg)*))); + return ::core::result::Result::Err($crate::error::HyperlightGuestError::new($ec, ::alloc::format!($($msg)*))) }; ($($msg:tt)*) => { - $crate::bail!($crate::error::ErrorCode::GuestError => $($msg)*); + $crate::bail!($crate::error::ErrorCode::GuestError => $($msg)*) }; } From aa4c41713cad8fa67f004819f04381f8c176f89d Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Wed, 25 Mar 2026 13:51:23 +0100 Subject: [PATCH 02/26] feat(virtq): add virtqueue ring plumbing in scratch region Place G2H and H2G packed virtqueue descriptor rings at deterministic offsets in the scratch region. 
Signed-off-by: Tomasz Andrzejak --- Cargo.lock | 1 + .../src/arch/aarch64/layout.rs | 7 +- .../src/arch/amd64/layout.rs | 15 +++- src/hyperlight_common/src/arch/i686/layout.rs | 7 +- src/hyperlight_common/src/layout.rs | 44 ++++++++++- src/hyperlight_common/src/outb.rs | 3 + src/hyperlight_common/src/virtq/desc.rs | 4 +- src/hyperlight_common/src/virtq/mod.rs | 1 + src/hyperlight_common/src/virtq/msg.rs | 75 +++++++++++++++++++ src/hyperlight_guest_bin/src/lib.rs | 4 + src/hyperlight_guest_bin/src/virtq_init.rs | 50 +++++++++++++ src/hyperlight_host/Cargo.toml | 1 + src/hyperlight_host/src/mem/layout.rs | 56 +++++++++++++- src/hyperlight_host/src/mem/mgr.rs | 45 +++++++++++ src/hyperlight_host/src/mem/shared_mem.rs | 60 ++++----------- src/hyperlight_host/src/sandbox/config.rs | 22 ++++++ .../src/sandbox/initialized_multi_use.rs | 2 + src/hyperlight_host/src/sandbox/outb.rs | 4 + src/tests/rust_guests/dummyguest/Cargo.lock | 50 +++++++++++++ src/tests/rust_guests/simpleguest/Cargo.lock | 50 +++++++++++++ src/tests/rust_guests/witguest/Cargo.lock | 50 +++++++++++++ 21 files changed, 494 insertions(+), 57 deletions(-) create mode 100644 src/hyperlight_common/src/virtq/msg.rs create mode 100644 src/hyperlight_guest_bin/src/virtq_init.rs diff --git a/Cargo.lock b/Cargo.lock index 2efdd92f9..cd013fbb0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1554,6 +1554,7 @@ dependencies = [ "bitflags 2.11.0", "blake3", "built", + "bytemuck", "cfg-if", "cfg_aliases", "chrono", diff --git a/src/hyperlight_common/src/arch/aarch64/layout.rs b/src/hyperlight_common/src/arch/aarch64/layout.rs index 20f17026c..25bd99a1e 100644 --- a/src/hyperlight_common/src/arch/aarch64/layout.rs +++ b/src/hyperlight_common/src/arch/aarch64/layout.rs @@ -20,6 +20,11 @@ pub const SNAPSHOT_PT_GVA_MIN: usize = 0xffff_8000_0000_0000; pub const SNAPSHOT_PT_GVA_MAX: usize = 0xffff_80ff_ffff_ffff; pub const MAX_GPA: usize = 0x0000_000f_ffff_ffff; -pub fn min_scratch_size(_input_data_size: usize, 
_output_data_size: usize) -> usize { +pub fn min_scratch_size( + _input_data_size: usize, + _output_data_size: usize, + _g2h_num_descs: usize, + _h2g_num_descs: usize, +) -> usize { unimplemented!("min_scratch_size") } diff --git a/src/hyperlight_common/src/arch/amd64/layout.rs b/src/hyperlight_common/src/arch/amd64/layout.rs index 14a9cd62a..4731f21b2 100644 --- a/src/hyperlight_common/src/arch/amd64/layout.rs +++ b/src/hyperlight_common/src/arch/amd64/layout.rs @@ -37,8 +37,17 @@ pub const MAX_GPA: usize = 0x0000_000f_ffff_ffff; /// - A page for the smallest possible non-exception stack /// - (up to) 3 pages for mapping that /// - Two pages for the exception stack and metadata -/// - A page-aligned amount of memory for I/O buffers (for now) -pub fn min_scratch_size(input_data_size: usize, output_data_size: usize) -> usize { - (input_data_size + output_data_size).next_multiple_of(crate::vmem::PAGE_SIZE) +/// - A page-aligned amount of memory for I/O buffers and virtqueue rings +pub fn min_scratch_size( + input_data_size: usize, + output_data_size: usize, + g2h_num_descs: usize, + h2g_num_descs: usize, +) -> usize { + let g2h_ring_size = crate::virtq::Layout::query_size(g2h_num_descs); + let h2g_ring_size = crate::virtq::Layout::query_size(h2g_num_descs); + + (input_data_size + output_data_size + g2h_ring_size + h2g_ring_size) + .next_multiple_of(crate::vmem::PAGE_SIZE) + 12 * crate::vmem::PAGE_SIZE } diff --git a/src/hyperlight_common/src/arch/i686/layout.rs b/src/hyperlight_common/src/arch/i686/layout.rs index f3601c643..08c9ec594 100644 --- a/src/hyperlight_common/src/arch/i686/layout.rs +++ b/src/hyperlight_common/src/arch/i686/layout.rs @@ -20,6 +20,11 @@ limitations under the License. 
pub const MAX_GVA: usize = 0xffff_ffff; pub const MAX_GPA: usize = 0xffff_ffff; -pub fn min_scratch_size(_input_data_size: usize, _output_data_size: usize) -> usize { +pub fn min_scratch_size( + _input_data_size: usize, + _output_data_size: usize, + _g2h_num_descs: usize, + _h2g_num_descs: usize, +) -> usize { crate::vmem::PAGE_SIZE } diff --git a/src/hyperlight_common/src/layout.rs b/src/hyperlight_common/src/layout.rs index 64b79d982..a043b794d 100644 --- a/src/hyperlight_common/src/layout.rs +++ b/src/hyperlight_common/src/layout.rs @@ -33,11 +33,26 @@ pub use arch::{MAX_GPA, MAX_GVA}; ))] pub use arch::{SNAPSHOT_PT_GVA_MAX, SNAPSHOT_PT_GVA_MIN}; -// offsets down from the top of scratch memory for various things pub const SCRATCH_TOP_SIZE_OFFSET: u64 = 0x08; pub const SCRATCH_TOP_ALLOCATOR_OFFSET: u64 = 0x10; pub const SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET: u64 = 0x18; -pub const SCRATCH_TOP_EXN_STACK_OFFSET: u64 = 0x20; +pub const SCRATCH_TOP_G2H_RING_GVA_OFFSET: u64 = 0x20; +pub const SCRATCH_TOP_H2G_RING_GVA_OFFSET: u64 = 0x28; +pub const SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET: u64 = 0x30; +pub const SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET: u64 = 0x32; +pub const SCRATCH_TOP_EXN_STACK_OFFSET: u64 = 0x40; + +// fields must not overlap, and exception stack address must be 16-byte aligned. 
+const _: () = { + assert!(SCRATCH_TOP_SIZE_OFFSET + 8 <= SCRATCH_TOP_ALLOCATOR_OFFSET); + assert!(SCRATCH_TOP_ALLOCATOR_OFFSET + 8 <= SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET); + assert!(SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET + 8 <= SCRATCH_TOP_G2H_RING_GVA_OFFSET); + assert!(SCRATCH_TOP_G2H_RING_GVA_OFFSET + 8 <= SCRATCH_TOP_H2G_RING_GVA_OFFSET); + assert!(SCRATCH_TOP_H2G_RING_GVA_OFFSET + 8 <= SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET); + assert!(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET + 2 <= SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET); + assert!(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET + 2 <= SCRATCH_TOP_EXN_STACK_OFFSET); + assert!(SCRATCH_TOP_EXN_STACK_OFFSET % 0x10 == 0); +}; /// Offset from the top of scratch memory for a shared host-guest u64 counter. /// @@ -55,5 +70,30 @@ pub fn scratch_base_gva(size: usize) -> u64 { (MAX_GVA - size + 1) as u64 } +/// Compute the byte offset from the scratch base to the G2H ring. +/// +/// TODO(ring): Remove input/output +pub const fn g2h_ring_scratch_offset(input_data_size: usize, output_data_size: usize) -> usize { + let io_off = input_data_size + output_data_size; + let align = crate::virtq::Descriptor::ALIGN; + + (io_off + align - 1) & !(align - 1) +} + +/// Compute the byte offset from the scratch base to the H2G ring. +/// +/// TODO(ring): Remove input/output +pub const fn h2g_ring_scratch_offset( + input_data_size: usize, + output_data_size: usize, + g2h_num_descs: usize, +) -> usize { + let g2h_offset = g2h_ring_scratch_offset(input_data_size, output_data_size); + let g2h_size = crate::virtq::Layout::query_size(g2h_num_descs); + let align = crate::virtq::Descriptor::ALIGN; + + (g2h_offset + g2h_size + align - 1) & !(align - 1) +} + /// Compute the minimum scratch region size needed for a sandbox. 
pub use arch::min_scratch_size; diff --git a/src/hyperlight_common/src/outb.rs b/src/hyperlight_common/src/outb.rs index 3bfb99848..0f9c25e00 100644 --- a/src/hyperlight_common/src/outb.rs +++ b/src/hyperlight_common/src/outb.rs @@ -105,6 +105,8 @@ pub enum OutBAction { TraceMemoryAlloc = 105, #[cfg(feature = "mem_profile")] TraceMemoryFree = 106, + /// Notification that entries are available on a virtqueue. + VirtqNotify = 109, } /// IO-port actions intercepted at the hypervisor level (in `run_vcpu`) @@ -137,6 +139,7 @@ impl TryFrom for OutBAction { 105 => Ok(OutBAction::TraceMemoryAlloc), #[cfg(feature = "mem_profile")] 106 => Ok(OutBAction::TraceMemoryFree), + 109 => Ok(OutBAction::VirtqNotify), _ => Err(anyhow::anyhow!("Invalid OutBAction value: {}", val)), } } diff --git a/src/hyperlight_common/src/virtq/desc.rs b/src/hyperlight_common/src/virtq/desc.rs index 57e5efb12..64bde4f7d 100644 --- a/src/hyperlight_common/src/virtq/desc.rs +++ b/src/hyperlight_common/src/virtq/desc.rs @@ -59,14 +59,16 @@ pub struct Descriptor { } const _: () = assert!(core::mem::size_of::() == 16); +const _: () = assert!(Descriptor::ALIGN == 16); const _: () = assert!(Descriptor::ADDR_OFFSET == 0); const _: () = assert!(Descriptor::LEN_OFFSET == 8); const _: () = assert!(Descriptor::ID_OFFSET == 12); const _: () = assert!(Descriptor::FLAGS_OFFSET == 14); impl Descriptor { + // VIRTIO spec requires 16-byte alignment for descriptors + pub const ALIGN: usize = 16; pub const SIZE: usize = core::mem::size_of::(); - pub const ALIGN: usize = core::mem::align_of::(); pub const ADDR_OFFSET: usize = core::mem::offset_of!(Self, addr); pub const LEN_OFFSET: usize = core::mem::offset_of!(Self, len); diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs index 490f30ac1..b52ef2805 100644 --- a/src/hyperlight_common/src/virtq/mod.rs +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -154,6 +154,7 @@ mod access; mod consumer; mod desc; mod event; +pub mod msg; mod 
pool; mod producer; mod ring; diff --git a/src/hyperlight_common/src/virtq/msg.rs b/src/hyperlight_common/src/virtq/msg.rs new file mode 100644 index 000000000..9c7f69947 --- /dev/null +++ b/src/hyperlight_common/src/virtq/msg.rs @@ -0,0 +1,75 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Wire format header for all virtqueue messages. +//! +//! Every payload on both the G2H and H2G queues starts with this +//! fixed 8-byte header, enabling message type discrimination and +//! request/response correlation. + +/// Message types for the virtqueue wire protocol. +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MsgKind { + /// A function call request (FunctionCall payload follows). + Request = 0x01, + /// A function call response (FunctionCallResult payload follows). + Response = 0x02, + /// A stream data chunk. + StreamChunk = 0x03, + /// End-of-stream marker. + StreamEnd = 0x04, + /// Cancel a pending request. + Cancel = 0x05, +} + +/// Wire header for all virtqueue messages +#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] +#[repr(C)] +pub struct VirtqMsgHeader { + /// Discriminates the message type. + pub kind: u8, + /// Per-type flags TODO(ring): add flags type. + pub flags: u8, + /// Caller-assigned correlation ID. Responses echo the request's ID. + pub req_id: u16, + /// Byte length of the payload following this header. 
+ pub payload_len: u32, +} + +impl VirtqMsgHeader { + pub const SIZE: usize = core::mem::size_of::(); + + /// Create a new message header. + pub const fn new(kind: MsgKind, req_id: u16, payload_len: u32) -> Self { + Self { + kind: kind as u8, + flags: 0, + req_id, + payload_len, + } + } + + /// Create a new header with flags. + pub const fn with_flags(kind: MsgKind, flags: u8, req_id: u16, payload_len: u32) -> Self { + Self { + kind: kind as u8, + flags, + req_id, + payload_len, + } + } +} diff --git a/src/hyperlight_guest_bin/src/lib.rs b/src/hyperlight_guest_bin/src/lib.rs index 0d5672de0..0ce3fdd51 100644 --- a/src/hyperlight_guest_bin/src/lib.rs +++ b/src/hyperlight_guest_bin/src/lib.rs @@ -52,6 +52,7 @@ pub mod host_comm; pub mod memory; #[cfg(target_arch = "x86_64")] pub mod paging; +mod virtq_init; // Globals #[cfg(all(feature = "mem_profile", target_arch = "x86_64"))] @@ -235,6 +236,9 @@ pub(crate) extern "C" fn generic_init( OS_PAGE_SIZE = ops as u32; } + // Initialize virtqueues + virtq_init::init_virtqueues(); + // set up the logger let guest_log_level_filter = GuestLogFilter::try_from(max_log_level).expect("Invalid log level"); diff --git a/src/hyperlight_guest_bin/src/virtq_init.rs b/src/hyperlight_guest_bin/src/virtq_init.rs new file mode 100644 index 000000000..1f24f5d9b --- /dev/null +++ b/src/hyperlight_guest_bin/src/virtq_init.rs @@ -0,0 +1,50 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! 
Guest-side virtqueue initialization. +//! +//! The host places virtqueue rings at deterministic offsets in the +//! scratch region and writes ring GVAs and queue depths to scratch-top +//! metadata. + +use hyperlight_common::layout::{ + self, SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_G2H_RING_GVA_OFFSET, + SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_H2G_RING_GVA_OFFSET, +}; +use hyperlight_common::virtq::Layout as VirtqLayout; + +/// Read a value from a scratch-top metadata slot. +unsafe fn read_scratch_top(offset: u64) -> T { + let addr = (layout::MAX_GVA as u64 - offset + 1) as *const T; + unsafe { core::ptr::read_volatile(addr) } +} + +/// Initialize virtqueue ring memory in the scratch region. +pub(crate) fn init_virtqueues() { + let g2h_gva: u64 = unsafe { read_scratch_top(SCRATCH_TOP_G2H_RING_GVA_OFFSET) }; + let g2h_depth: u16 = unsafe { read_scratch_top(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET) }; + let h2g_gva: u64 = unsafe { read_scratch_top(SCRATCH_TOP_H2G_RING_GVA_OFFSET) }; + let h2g_depth: u16 = unsafe { read_scratch_top(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET) }; + + assert!(g2h_depth > 0 && h2g_depth > 0); + assert!(g2h_gva != 0 && h2g_gva != 0); + + let size = VirtqLayout::query_size(g2h_depth as usize); + unsafe { core::ptr::write_bytes(g2h_gva as *mut u8, 0, size) }; + + let size = VirtqLayout::query_size(h2g_depth as usize); + unsafe { core::ptr::write_bytes(h2g_gva as *mut u8, 0, size) }; +} diff --git a/src/hyperlight_host/Cargo.toml b/src/hyperlight_host/Cargo.toml index 6fd1ccb1f..6f38c6b32 100644 --- a/src/hyperlight_host/Cargo.toml +++ b/src/hyperlight_host/Cargo.toml @@ -21,6 +21,7 @@ bench = false # see https://bheisler.github.io/criterion.rs/book/faq.html#cargo- workspace = true [dependencies] +bytemuck = { version = "1.25", features = ["derive"] } gdbstub = { version = "0.7.10", optional = true } gdbstub_arch = { version = "0.3.3", optional = true } goblin = { version = "0.10", default-features = false, features = ["std", "elf32", 
"elf64", "endian_fd"] } diff --git a/src/hyperlight_host/src/mem/layout.rs b/src/hyperlight_host/src/mem/layout.rs index b55189969..cee44b94b 100644 --- a/src/hyperlight_host/src/mem/layout.rs +++ b/src/hyperlight_host/src/mem/layout.rs @@ -341,6 +341,8 @@ impl SandboxMemoryLayout { let min_scratch_size = hyperlight_common::layout::min_scratch_size( cfg.get_input_data_size(), cfg.get_output_data_size(), + cfg.get_g2h_queue_depth(), + cfg.get_h2g_queue_depth(), ); if scratch_size < min_scratch_size { return Err(MemoryRequestTooSmall(scratch_size, min_scratch_size)); @@ -484,13 +486,55 @@ impl SandboxMemoryLayout { 0 } + /// Get the offset into the scratch region of the G2H ring. + fn get_g2h_ring_scratch_offset(&self) -> usize { + hyperlight_common::layout::g2h_ring_scratch_offset( + self.sandbox_memory_config.get_input_data_size(), + self.sandbox_memory_config.get_output_data_size(), + ) + } + + /// Get the size of the G2H ring in bytes. + fn get_g2h_ring_size(&self) -> usize { + hyperlight_common::virtq::Layout::query_size( + self.sandbox_memory_config.get_g2h_queue_depth(), + ) + } + + /// Get the offset into the scratch region of the H2G ring. + fn get_h2g_ring_scratch_offset(&self) -> usize { + hyperlight_common::layout::h2g_ring_scratch_offset( + self.sandbox_memory_config.get_input_data_size(), + self.sandbox_memory_config.get_output_data_size(), + self.sandbox_memory_config.get_g2h_queue_depth(), + ) + } + + /// Get the size of the H2G ring in bytes. + fn get_h2g_ring_size(&self) -> usize { + hyperlight_common::virtq::Layout::query_size( + self.sandbox_memory_config.get_h2g_queue_depth(), + ) + } + + /// Get the GVA of the G2H ring in guest address space. + pub(crate) fn get_g2h_ring_gva(&self) -> u64 { + hyperlight_common::layout::scratch_base_gva(self.scratch_size) + + self.get_g2h_ring_scratch_offset() as u64 + } + + /// Get the GVA of the H2G ring in guest address space. 
+ pub(crate) fn get_h2g_ring_gva(&self) -> u64 { + hyperlight_common::layout::scratch_base_gva(self.scratch_size) + + self.get_h2g_ring_scratch_offset() as u64 + } + /// Get the offset from the beginning of the scratch region to the /// location where page tables will be eagerly copied on restore #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub(crate) fn get_pt_base_scratch_offset(&self) -> usize { - (self.sandbox_memory_config.get_input_data_size() - + self.sandbox_memory_config.get_output_data_size()) - .next_multiple_of(hyperlight_common::vmem::PAGE_SIZE) + let after_rings = self.get_h2g_ring_scratch_offset() + self.get_h2g_ring_size(); + after_rings.next_multiple_of(hyperlight_common::vmem::PAGE_SIZE) } /// Get the base GPA to which the page tables will be eagerly @@ -595,6 +639,8 @@ impl SandboxMemoryLayout { let min_fixed_scratch = hyperlight_common::layout::min_scratch_size( self.sandbox_memory_config.get_input_data_size(), self.sandbox_memory_config.get_output_data_size(), + self.sandbox_memory_config.get_g2h_queue_depth(), + self.sandbox_memory_config.get_h2g_queue_depth(), ); let min_scratch = min_fixed_scratch + size; if self.scratch_size < min_scratch { @@ -817,6 +863,10 @@ impl SandboxMemoryLayout { // initialised here, because they are in the scratch // region---they are instead set in // [`SandboxMemoryManager::update_scratch_bookkeeping`]. + // + // Virtqueue ring layouts are also communicated via scratch-top + // metadata (queue depths), not the PEB. Both host and guest + // compute ring addresses from shared offset functions. Ok(()) } diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index 98c70734b..3f19a8ade 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -15,6 +15,7 @@ limitations under the License. 
*/ #[cfg(feature = "nanvix-unstable")] use std::mem::offset_of; +use std::num::NonZeroU16; use flatbuffers::FlatBufferBuilder; use hyperlight_common::flatbuffer_wrappers::function_call::{ @@ -22,6 +23,7 @@ use hyperlight_common::flatbuffer_wrappers::function_call::{ }; use hyperlight_common::flatbuffer_wrappers::function_types::FunctionCallResult; use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; +use hyperlight_common::virtq::Layout as VirtqLayout; use hyperlight_common::vmem::{self, PAGE_TABLE_SIZE, PageTableEntry, PhysAddr}; #[cfg(all(feature = "crashdump", not(feature = "nanvix-unstable")))] use hyperlight_common::vmem::{BasicMapping, MappingKind}; @@ -554,6 +556,25 @@ impl SandboxMemoryManager { SandboxMemoryLayout::STACK_POINTER_SIZE_BYTES, )?; + // Write virtqueue metadata to scratch-top so the guest can + // discover ring locations without reading the PEB. + self.update_scratch_bookkeeping_item( + SCRATCH_TOP_G2H_RING_GVA_OFFSET, + self.layout.get_g2h_ring_gva(), + )?; + self.update_scratch_bookkeeping_item( + SCRATCH_TOP_H2G_RING_GVA_OFFSET, + self.layout.get_h2g_ring_gva(), + )?; + self.scratch_mem.write::( + scratch_size - SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET as usize, + self.layout.sandbox_memory_config.get_g2h_queue_depth() as u16, + )?; + self.scratch_mem.write::( + scratch_size - SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET as usize, + self.layout.sandbox_memory_config.get_h2g_queue_depth() as u16, + )?; + // Copy the page tables into the scratch region let snapshot_pt_end = self.shared_mem.mem_size(); let snapshot_pt_size = self.layout.get_pt_size(); @@ -793,6 +814,30 @@ impl SandboxMemoryManager { }) })?? } + + /// Compute the G2H virtqueue Layout from scratch region addresses. 
+ pub(crate) fn g2h_virtq_layout(&self) -> Result { + let base = self.layout.get_g2h_ring_gva(); + let depth = self.layout.sandbox_memory_config.get_g2h_queue_depth(); + + let nz = NonZeroU16::new(depth as u16) + .ok_or_else(|| new_error!("G2H queue depth is zero"))?; + + unsafe { VirtqLayout::from_base(base, nz) } + .map_err(|e| new_error!("Invalid G2H virtq layout: {:?}", e)) + } + + /// Compute the H2G virtqueue Layout from scratch region addresses. + pub(crate) fn h2g_virtq_layout(&self) -> Result { + let base = self.layout.get_h2g_ring_gva(); + let depth = self.layout.sandbox_memory_config.get_h2g_queue_depth(); + + let nz = NonZeroU16::new(depth as u16) + .ok_or_else(|| new_error!("H2G queue depth is zero"))?; + + unsafe { VirtqLayout::from_base(base, nz) } + .map_err(|e| new_error!("Invalid H2G virtq layout: {:?}", e)) + } } #[cfg(test)] diff --git a/src/hyperlight_host/src/mem/shared_mem.rs b/src/hyperlight_host/src/mem/shared_mem.rs index b978b3475..6f10bcbf3 100644 --- a/src/hyperlight_host/src/mem/shared_mem.rs +++ b/src/hyperlight_host/src/mem/shared_mem.rs @@ -876,57 +876,25 @@ impl SharedMemory for GuestSharedMemory { } } -/// An unsafe marker trait for types for which all bit patterns are valid. -/// This is required in order for it to be safe to read a value of a particular -/// type out of the sandbox from the HostSharedMemory. -/// -/// # Safety -/// This must only be implemented for types for which all bit patterns -/// are valid. It requires that any (non-undef/poison) value of the -/// correct size can be transmuted to the type. 
-pub unsafe trait AllValid {} -unsafe impl AllValid for u8 {} -unsafe impl AllValid for u16 {} -unsafe impl AllValid for u32 {} -unsafe impl AllValid for u64 {} -unsafe impl AllValid for i8 {} -unsafe impl AllValid for i16 {} -unsafe impl AllValid for i32 {} -unsafe impl AllValid for i64 {} -unsafe impl AllValid for [u8; 16] {} - impl HostSharedMemory { - /// Read a value of type T, whose representation is the same - /// between the sandbox and the host, and which has no invalid bit - /// patterns - pub fn read(&self, offset: usize) -> Result { + /// Read a value of type T from the sandbox at the given offset. + /// + /// T must implement [`bytemuck::Pod`] which guarantees all bit + /// patterns are valid and there is no padding. + pub fn read(&self, offset: usize) -> Result { bounds_check!(offset, std::mem::size_of::(), self.mem_size()); - unsafe { - let mut ret: core::mem::MaybeUninit = core::mem::MaybeUninit::uninit(); - { - let slice: &mut [u8] = core::slice::from_raw_parts_mut( - ret.as_mut_ptr() as *mut u8, - std::mem::size_of::(), - ); - self.copy_to_slice(slice, offset)?; - } - Ok(ret.assume_init()) - } + let mut val = T::zeroed(); + self.copy_to_slice(bytemuck::bytes_of_mut(&mut val), offset)?; + Ok(val) } - /// Write a value of type T, whose representation is the same - /// between the sandbox and the host, and which has no invalid bit - /// patterns - pub fn write(&self, offset: usize, data: T) -> Result<()> { + /// Write a value of type T into the sandbox at the given offset. + /// + /// T must implement [`bytemuck::Pod`] which guarantees all bit + /// patterns are valid and there is no padding. 
+ pub fn write(&self, offset: usize, data: T) -> Result<()> { bounds_check!(offset, std::mem::size_of::(), self.mem_size()); - unsafe { - let slice: &[u8] = core::slice::from_raw_parts( - core::ptr::addr_of!(data) as *const u8, - std::mem::size_of::(), - ); - self.copy_from_slice(slice, offset)?; - } - Ok(()) + self.copy_from_slice(bytemuck::bytes_of(&data), offset) } /// Copy the contents of the slice into the sandbox at the diff --git a/src/hyperlight_host/src/sandbox/config.rs b/src/hyperlight_host/src/sandbox/config.rs index f12387a0b..120aa06cd 100644 --- a/src/hyperlight_host/src/sandbox/config.rs +++ b/src/hyperlight_host/src/sandbox/config.rs @@ -74,6 +74,12 @@ pub struct SandboxConfiguration { interrupt_vcpu_sigrtmin_offset: u8, /// How much writable memory to offer the guest scratch_size: usize, + /// Number of descriptors for the G2H (guest-to-host) virtqueue. Must be a power of 2. + /// Default: 64, sized to 2x the H2G depth for deadlock prevention. + g2h_queue_depth: usize, + /// Number of descriptors for the H2G (host-to-guest) virtqueue. Must be a power of 2. + /// Default: 32. + h2g_queue_depth: usize, } impl SandboxConfiguration { @@ -93,6 +99,10 @@ impl SandboxConfiguration { pub const DEFAULT_HEAP_SIZE: u64 = 131072; /// The default size of the scratch region pub const DEFAULT_SCRATCH_SIZE: usize = 0x48000; + /// The default G2H virtqueue depth (number of descriptors, must be a power of 2) + pub const DEFAULT_G2H_QUEUE_DEPTH: usize = 64; + /// The default H2G virtqueue depth (number of descriptors, must be a power of 2) + pub const DEFAULT_H2G_QUEUE_DEPTH: usize = 32; #[allow(clippy::too_many_arguments)] /// Create a new configuration for a sandbox with the given sizes. 
@@ -114,6 +124,8 @@ impl SandboxConfiguration { scratch_size, interrupt_retry_delay, interrupt_vcpu_sigrtmin_offset, + g2h_queue_depth: Self::DEFAULT_G2H_QUEUE_DEPTH, + h2g_queue_depth: Self::DEFAULT_H2G_QUEUE_DEPTH, #[cfg(gdb)] guest_debug_info, #[cfg(crashdump)] @@ -209,6 +221,16 @@ impl SandboxConfiguration { self.scratch_size } + /// Get the G2H virtqueue depth (number of descriptors). + pub fn get_g2h_queue_depth(&self) -> usize { + self.g2h_queue_depth + } + + /// Get the H2G virtqueue depth (number of descriptors). + pub fn get_h2g_queue_depth(&self) -> usize { + self.h2g_queue_depth + } + /// Set the size of the scratch regiong #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn set_scratch_size(&mut self, scratch_size: usize) { diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index 72de96035..642fb2772 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -1107,6 +1107,8 @@ mod tests { let min_scratch = hyperlight_common::layout::min_scratch_size( cfg.get_input_data_size(), cfg.get_output_data_size(), + cfg.get_g2h_queue_depth(), + cfg.get_h2g_queue_depth(), ); cfg.set_scratch_size(min_scratch + 0x10000 + 0x10000); diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs index 9704a1fe3..bb73763a6 100644 --- a/src/hyperlight_host/src/sandbox/outb.rs +++ b/src/hyperlight_host/src/sandbox/outb.rs @@ -227,6 +227,10 @@ pub(crate) fn handle_outb( eprint!("{}", ch); Ok(()) } + OutBAction::VirtqNotify => { + // TODO(ring): acknowledge notification but no-op for now. 
+ Ok(()) + } #[cfg(feature = "trace_guest")] OutBAction::TraceBatch => Ok(()), #[cfg(feature = "mem_profile")] diff --git a/src/tests/rust_guests/dummyguest/Cargo.lock b/src/tests/rust_guests/dummyguest/Cargo.lock index e2f99c9a0..736cf6c4b 100644 --- a/src/tests/rust_guests/dummyguest/Cargo.lock +++ b/src/tests/rust_guests/dummyguest/Cargo.lock @@ -8,6 +8,12 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "atomic_refcell" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41e67cd8309bbd06cd603a9e693a784ac2e5d1e955f11286e355089fcab3047c" + [[package]] name = "bitflags" version = "2.11.0" @@ -23,6 +29,32 @@ dependencies = [ "spin", ] +[[package]] +name = "bytemuck" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + [[package]] name = "cc" version = "1.2.57" @@ -59,6 +91,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + [[package]] name = "flatbuffers" version = 
"25.12.19" @@ -86,8 +124,14 @@ name = "hyperlight-common" version = "0.14.0" dependencies = [ "anyhow", + "atomic_refcell", + "bitflags", + "bytemuck", + "bytes", + "fixedbitset", "flatbuffers", "log", + "smallvec", "spin", "thiserror", "tracing-core", @@ -303,6 +347,12 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + [[package]] name = "spin" version = "0.10.0" diff --git a/src/tests/rust_guests/simpleguest/Cargo.lock b/src/tests/rust_guests/simpleguest/Cargo.lock index 7c0f52ccc..e12c99d4d 100644 --- a/src/tests/rust_guests/simpleguest/Cargo.lock +++ b/src/tests/rust_guests/simpleguest/Cargo.lock @@ -8,6 +8,12 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "atomic_refcell" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41e67cd8309bbd06cd603a9e693a784ac2e5d1e955f11286e355089fcab3047c" + [[package]] name = "bitflags" version = "2.11.0" @@ -23,6 +29,32 @@ dependencies = [ "spin", ] +[[package]] +name = "bytemuck" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "bytes" +version = "1.11.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + [[package]] name = "cc" version = "1.2.57" @@ -51,6 +83,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + [[package]] name = "flatbuffers" version = "25.12.19" @@ -78,8 +116,14 @@ name = "hyperlight-common" version = "0.14.0" dependencies = [ "anyhow", + "atomic_refcell", + "bitflags", + "bytemuck", + "bytes", + "fixedbitset", "flatbuffers", "log", + "smallvec", "spin", "thiserror", "tracing-core", @@ -307,6 +351,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + [[package]] name = "spin" version = "0.10.0" diff --git a/src/tests/rust_guests/witguest/Cargo.lock b/src/tests/rust_guests/witguest/Cargo.lock index fb184958a..39bb97169 100644 --- a/src/tests/rust_guests/witguest/Cargo.lock +++ b/src/tests/rust_guests/witguest/Cargo.lock @@ -67,6 +67,12 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "atomic_refcell" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41e67cd8309bbd06cd603a9e693a784ac2e5d1e955f11286e355089fcab3047c" + [[package]] name = "bitflags" version = "2.11.0" @@ -82,6 +88,32 @@ dependencies = [ "spin", ] +[[package]] +name = "bytemuck" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + [[package]] name = "cc" version = "1.2.57" @@ -145,6 +177,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + [[package]] name = "flatbuffers" version = "25.12.19" @@ -183,8 +221,14 @@ name = "hyperlight-common" version = "0.14.0" dependencies = [ "anyhow", + "atomic_refcell", + "bitflags", + "bytemuck", + "bytes", + "fixedbitset", "flatbuffers", "log", + "smallvec", "spin", "thiserror", "tracing-core", @@ -534,6 +578,12 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + [[package]] name = "spin" version = "0.10.0" From 02e8a704b6a326cb37341a0d6ecce8fd74a6a305 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Wed, 25 Mar 2026 15:05:52 +0100 Subject: [PATCH 03/26] feat(virtq): add MemOps for host and guest Signed-off-by: Tomasz Andrzejak --- src/hyperlight_guest/src/lib.rs | 1 + 
src/hyperlight_guest/src/virtq_mem.rs | 67 ++++++++++++ src/hyperlight_host/src/mem/mgr.rs | 10 +- src/hyperlight_host/src/mem/mod.rs | 2 + src/hyperlight_host/src/mem/virtq_mem.rs | 124 +++++++++++++++++++++++ 5 files changed, 198 insertions(+), 6 deletions(-) create mode 100644 src/hyperlight_guest/src/virtq_mem.rs create mode 100644 src/hyperlight_host/src/mem/virtq_mem.rs diff --git a/src/hyperlight_guest/src/lib.rs b/src/hyperlight_guest/src/lib.rs index 951cdedd2..6ef9efd86 100644 --- a/src/hyperlight_guest/src/lib.rs +++ b/src/hyperlight_guest/src/lib.rs @@ -25,6 +25,7 @@ pub mod error; pub mod exit; pub mod layout; pub mod prim_alloc; +pub mod virtq_mem; pub mod guest_handle { pub mod handle; diff --git a/src/hyperlight_guest/src/virtq_mem.rs b/src/hyperlight_guest/src/virtq_mem.rs new file mode 100644 index 000000000..8309deb79 --- /dev/null +++ b/src/hyperlight_guest/src/virtq_mem.rs @@ -0,0 +1,67 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Guest-side [`MemOps`] implementation for virtqueue access. + +use core::convert::Infallible; +use core::sync::atomic::{AtomicU16, Ordering}; +use core::{ptr, slice}; + +use hyperlight_common::virtq::MemOps; + +/// Guest-side memory accessor for virtqueue operations. Treats virtq +/// addresses as guest virtual addresses that map directly to memory. 
+#[derive(Clone, Copy, Debug)] +pub struct GuestMemOps; + +impl MemOps for GuestMemOps { + type Error = Infallible; + + fn read(&self, addr: u64, dst: &mut [u8]) -> Result { + let src = addr as *const u8; + unsafe { + ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), dst.len()); + } + Ok(dst.len()) + } + + fn write(&self, addr: u64, src: &[u8]) -> Result { + let dst = addr as *mut u8; + unsafe { + ptr::copy_nonoverlapping(src.as_ptr(), dst, src.len()); + } + Ok(src.len()) + } + + fn load_acquire(&self, addr: u64) -> Result { + let ptr = addr as *const AtomicU16; + Ok(unsafe { (*ptr).load(Ordering::Acquire) }) + } + + fn store_release(&self, addr: u64, val: u16) -> Result<(), Self::Error> { + let ptr = addr as *const AtomicU16; + unsafe { (*ptr).store(val, Ordering::Release) }; + Ok(()) + } + + unsafe fn as_slice(&self, addr: u64, len: usize) -> Result<&[u8], Self::Error> { + Ok(unsafe { slice::from_raw_parts(addr as *const u8, len) }) + } + + unsafe fn as_mut_slice(&self, addr: u64, len: usize) -> Result<&mut [u8], Self::Error> { + Ok(unsafe { slice::from_raw_parts_mut(addr as *mut u8, len) }) + } +} diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index 3f19a8ade..95983520d 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -818,10 +818,9 @@ impl SandboxMemoryManager { /// Compute the G2H virtqueue Layout from scratch region addresses. 
pub(crate) fn g2h_virtq_layout(&self) -> Result { let base = self.layout.get_g2h_ring_gva(); - let depth = self.layout.sandbox_memory_config.get_g2h_queue_depth(); + let depth = self.layout.sandbox_memory_config.get_g2h_queue_depth() as u16; - let nz = NonZeroU16::new(depth as u16) - .ok_or_else(|| new_error!("G2H queue depth is zero"))?; + let nz = NonZeroU16::new(depth).ok_or_else(|| new_error!("G2H queue depth is zero"))?; unsafe { VirtqLayout::from_base(base, nz) } .map_err(|e| new_error!("Invalid G2H virtq layout: {:?}", e)) @@ -830,10 +829,9 @@ impl SandboxMemoryManager { /// Compute the H2G virtqueue Layout from scratch region addresses. pub(crate) fn h2g_virtq_layout(&self) -> Result { let base = self.layout.get_h2g_ring_gva(); - let depth = self.layout.sandbox_memory_config.get_h2g_queue_depth(); + let depth = self.layout.sandbox_memory_config.get_h2g_queue_depth() as u16; - let nz = NonZeroU16::new(depth as u16) - .ok_or_else(|| new_error!("H2G queue depth is zero"))?; + let nz = NonZeroU16::new(depth).ok_or_else(|| new_error!("H2G queue depth is zero"))?; unsafe { VirtqLayout::from_base(base, nz) } .map_err(|e| new_error!("Invalid H2G virtq layout: {:?}", e)) diff --git a/src/hyperlight_host/src/mem/mod.rs b/src/hyperlight_host/src/mem/mod.rs index 64f5db2fe..4882bc75c 100644 --- a/src/hyperlight_host/src/mem/mod.rs +++ b/src/hyperlight_host/src/mem/mod.rs @@ -38,3 +38,5 @@ pub mod shared_mem; /// Utilities for writing shared memory tests #[cfg(all(test, not(miri)))] // uses proptest which isn't miri-compatible pub(crate) mod shared_mem_tests; +/// Host-side [`hyperlight_common::virtq::MemOps`] for virtqueue access. +pub(crate) mod virtq_mem; diff --git a/src/hyperlight_host/src/mem/virtq_mem.rs b/src/hyperlight_host/src/mem/virtq_mem.rs new file mode 100644 index 000000000..f96674c1d --- /dev/null +++ b/src/hyperlight_host/src/mem/virtq_mem.rs @@ -0,0 +1,124 @@ +/* +Copyright 2026 The Hyperlight Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Host-side [`MemOps`] implementation for virtqueue access. +//! +//! Translates guest virtual addresses used in virtqueue descriptors +//! to offsets into the scratch [`HostSharedMemory`], reusing its +//! volatile access and locking patterns. + +use core::sync::atomic::{AtomicU16, Ordering}; + +use hyperlight_common::virtq::MemOps; + +use super::shared_mem::{HostSharedMemory, SharedMemory}; + +/// Error type for host memory operations. +#[derive(Debug, thiserror::Error)] +pub enum HostMemError { + #[error("address {addr:#x} out of bounds scratch_size={scratch_size}")] + OutOfBounds { addr: u64, scratch_size: usize }, + #[error("shared memory error: {0}")] + SharedMem(String), + #[error("as_slice/as_mut_slice not supported on host")] + DirectSliceNotSupported, +} + +/// Host-side memory accessor for virtqueue operations. +/// +/// Owns a clone of the scratch [`HostSharedMemory`] and translates +/// guest virtual addresses (in the scratch region) to offsets for the +/// existing volatile read/write methods. +#[derive(Clone)] +pub(crate) struct HostMemOps { + /// Cloned handle to the scratch shared memory + scratch: HostSharedMemory, + /// The guest virtual address that corresponds to scratch offset 0. + scratch_base_gva: u64, +} + +impl HostMemOps { + /// Create a new `HostMemOps` backed by shared memory. 
+ pub(crate) fn new(scratch: &HostSharedMemory, scratch_base_gva: u64) -> Self { + Self { + scratch: scratch.clone(), + scratch_base_gva, + } + } + + /// Translate a guest virtual address to a scratch offset. + fn to_offset(&self, addr: u64) -> Result { + addr.checked_sub(self.scratch_base_gva) + .map(|o| o as usize) + .ok_or(HostMemError::OutOfBounds { + addr, + scratch_size: self.scratch.mem_size(), + }) + } + + /// Get a raw pointer into scratch memory at the given guest address. + fn raw_ptr(&self, addr: u64, len: usize) -> Result<*mut u8, HostMemError> { + let offset = self.to_offset(addr)?; + let scratch_size = self.scratch.mem_size(); + + if offset.checked_add(len).is_none_or(|end| end > scratch_size) { + return Err(HostMemError::OutOfBounds { addr, scratch_size }); + } + + Ok(self.scratch.base_ptr().wrapping_add(offset)) + } +} + +impl MemOps for HostMemOps { + type Error = HostMemError; + + fn read(&self, addr: u64, dst: &mut [u8]) -> Result { + let offset = self.to_offset(addr)?; + self.scratch + .copy_to_slice(dst, offset) + .map_err(|e| HostMemError::SharedMem(e.to_string()))?; + Ok(dst.len()) + } + + fn write(&self, addr: u64, src: &[u8]) -> Result { + let offset = self.to_offset(addr)?; + self.scratch + .copy_from_slice(src, offset) + .map_err(|e| HostMemError::SharedMem(e.to_string()))?; + Ok(src.len()) + } + + fn load_acquire(&self, addr: u64) -> Result { + let ptr = self.raw_ptr(addr, core::mem::size_of::())?; + let atomic = unsafe { &*(ptr as *const AtomicU16) }; + Ok(atomic.load(Ordering::Acquire)) + } + + fn store_release(&self, addr: u64, val: u16) -> Result<(), Self::Error> { + let ptr = self.raw_ptr(addr, core::mem::size_of::())?; + let atomic = unsafe { &*(ptr as *const AtomicU16) }; + atomic.store(val, Ordering::Release); + Ok(()) + } + + unsafe fn as_slice(&self, _addr: u64, _len: usize) -> Result<&[u8], Self::Error> { + Err(HostMemError::DirectSliceNotSupported) + } + + unsafe fn as_mut_slice(&self, _addr: u64, _len: usize) -> Result<&mut 
[u8], Self::Error> { + Err(HostMemError::DirectSliceNotSupported) + } +} From 528a5c01f3a8fdf5cf52b0a6b114e12e2109cf47 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Wed, 25 Mar 2026 17:11:20 +0100 Subject: [PATCH 04/26] feat(virtq): create G2H producer during guest init Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/layout.rs | 11 ++- src/hyperlight_common/src/virtq/pool.rs | 24 ++++++ src/hyperlight_guest_bin/src/lib.rs | 4 +- src/hyperlight_guest_bin/src/virtq/mod.rs | 68 +++++++++++++++++ src/hyperlight_guest_bin/src/virtq/state.rs | 75 +++++++++++++++++++ src/hyperlight_guest_bin/src/virtq_init.rs | 50 ------------- src/hyperlight_host/src/mem/mgr.rs | 4 + src/hyperlight_host/src/sandbox/config.rs | 26 +++++++ src/hyperlight_host/tests/integration_test.rs | 10 ++- 9 files changed, 216 insertions(+), 56 deletions(-) create mode 100644 src/hyperlight_guest_bin/src/virtq/mod.rs create mode 100644 src/hyperlight_guest_bin/src/virtq/state.rs delete mode 100644 src/hyperlight_guest_bin/src/virtq_init.rs diff --git a/src/hyperlight_common/src/layout.rs b/src/hyperlight_common/src/layout.rs index a043b794d..4fb5208f6 100644 --- a/src/hyperlight_common/src/layout.rs +++ b/src/hyperlight_common/src/layout.rs @@ -40,9 +40,9 @@ pub const SCRATCH_TOP_G2H_RING_GVA_OFFSET: u64 = 0x20; pub const SCRATCH_TOP_H2G_RING_GVA_OFFSET: u64 = 0x28; pub const SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET: u64 = 0x30; pub const SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET: u64 = 0x32; +pub const SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET: u64 = 0x34; pub const SCRATCH_TOP_EXN_STACK_OFFSET: u64 = 0x40; -// fields must not overlap, and exception stack address must be 16-byte aligned. 
const _: () = { assert!(SCRATCH_TOP_SIZE_OFFSET + 8 <= SCRATCH_TOP_ALLOCATOR_OFFSET); assert!(SCRATCH_TOP_ALLOCATOR_OFFSET + 8 <= SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET); @@ -50,7 +50,8 @@ const _: () = { assert!(SCRATCH_TOP_G2H_RING_GVA_OFFSET + 8 <= SCRATCH_TOP_H2G_RING_GVA_OFFSET); assert!(SCRATCH_TOP_H2G_RING_GVA_OFFSET + 8 <= SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET); assert!(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET + 2 <= SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET); - assert!(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET + 2 <= SCRATCH_TOP_EXN_STACK_OFFSET); + assert!(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET + 2 <= SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET); + assert!(SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET + 2 <= SCRATCH_TOP_EXN_STACK_OFFSET); assert!(SCRATCH_TOP_EXN_STACK_OFFSET % 0x10 == 0); }; @@ -70,9 +71,13 @@ pub fn scratch_base_gva(size: usize) -> u64 { (MAX_GVA - size + 1) as u64 } +pub const fn scratch_top_ptr(offset: u64) -> *mut T { + (MAX_GVA as u64 - offset + 1) as *mut T +} + /// Compute the byte offset from the scratch base to the G2H ring. 
/// -/// TODO(ring): Remove input/output +/// TODO(virtq): Remove input/output pub const fn g2h_ring_scratch_offset(input_data_size: usize, output_data_size: usize) -> usize { let io_off = input_data_size + output_data_size; let align = crate::virtq::Descriptor::ALIGN; diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs index 0324c08fe..99d8f4109 100644 --- a/src/hyperlight_common/src/virtq/pool.rs +++ b/src/hyperlight_common/src/virtq/pool.rs @@ -126,6 +126,30 @@ pub trait BufferProvider { fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result; } +impl BufferProvider for alloc::rc::Rc { + fn alloc(&self, len: usize) -> Result { + (**self).alloc(len) + } + fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { + (**self).dealloc(alloc) + } + fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { + (**self).resize(old_alloc, new_len) + } +} + +impl BufferProvider for alloc::sync::Arc { + fn alloc(&self, len: usize) -> Result { + (**self).alloc(len) + } + fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { + (**self).dealloc(alloc) + } + fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { + (**self).resize(old_alloc, new_len) + } +} + /// The owner of a mapped buffer, ensuring its lifetime. 
/// /// Holds a pool allocation and provides direct access to the underlying diff --git a/src/hyperlight_guest_bin/src/lib.rs b/src/hyperlight_guest_bin/src/lib.rs index 0ce3fdd51..695f7d282 100644 --- a/src/hyperlight_guest_bin/src/lib.rs +++ b/src/hyperlight_guest_bin/src/lib.rs @@ -52,7 +52,7 @@ pub mod host_comm; pub mod memory; #[cfg(target_arch = "x86_64")] pub mod paging; -mod virtq_init; +mod virtq; // Globals #[cfg(all(feature = "mem_profile", target_arch = "x86_64"))] @@ -237,7 +237,7 @@ pub(crate) extern "C" fn generic_init( } // Initialize virtqueues - virtq_init::init_virtqueues(); + virtq::init_virtqueues(); // set up the logger let guest_log_level_filter = diff --git a/src/hyperlight_guest_bin/src/virtq/mod.rs b/src/hyperlight_guest_bin/src/virtq/mod.rs new file mode 100644 index 000000000..50c0dd6d9 --- /dev/null +++ b/src/hyperlight_guest_bin/src/virtq/mod.rs @@ -0,0 +1,68 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Guest-side virtqueue initialization. +//! +//! Zeroes ring memory and creates VirtqProducer instances by allocating +//! buffer pool pages from the scratch page allocator. 
+ +pub(crate) mod state; + +use hyperlight_common::layout::{ + SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_G2H_RING_GVA_OFFSET, + SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_H2G_RING_GVA_OFFSET, + SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET, scratch_top_ptr, +}; +use hyperlight_common::mem::PAGE_SIZE_USIZE; +use hyperlight_common::virtq::Layout as VirtqLayout; +use hyperlight_guest::prim_alloc::alloc_phys_pages; + +use crate::paging::phys_to_virt; + +/// Initialize virtqueue producers for G2H and H2G queues. +pub(crate) fn init_virtqueues() { + let g2h_gva = unsafe { *scratch_top_ptr::(SCRATCH_TOP_G2H_RING_GVA_OFFSET) }; + let g2h_depth = unsafe { *scratch_top_ptr::(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET) }; + let h2g_gva = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_RING_GVA_OFFSET) }; + let h2g_depth = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET) }; + let pool_pages = unsafe { *scratch_top_ptr::(SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET) } as u64; + + assert!(g2h_depth > 0 && h2g_depth > 0); + assert!(g2h_gva != 0 && h2g_gva != 0); + assert!(pool_pages > 0); + + // Zero ring memory + let g2h_ring_size = VirtqLayout::query_size(g2h_depth as usize); + unsafe { core::ptr::write_bytes(g2h_gva as *mut u8, 0, g2h_ring_size) }; + + let h2g_ring_size = VirtqLayout::query_size(h2g_depth as usize); + unsafe { core::ptr::write_bytes(h2g_gva as *mut u8, 0, h2g_ring_size) }; + + // Allocate buffer pool from physical pages + let pool_gpa = unsafe { alloc_phys_pages(pool_pages) }; + let pool_ptr = phys_to_virt(pool_gpa).expect("failed to map pool pages"); + let pool_gva = pool_ptr as u64; + let pool_size = pool_pages as usize * PAGE_SIZE_USIZE; + unsafe { core::ptr::write_bytes(pool_ptr, 0, pool_size) }; + + // Create G2H producer + unsafe { + state::init_g2h_producer(g2h_gva, g2h_depth, pool_gva, pool_size); + } + + // TODO(virtq): add other direction's producer + let _ = (h2g_gva, h2g_depth); +} diff --git a/src/hyperlight_guest_bin/src/virtq/state.rs 
b/src/hyperlight_guest_bin/src/virtq/state.rs new file mode 100644 index 000000000..232726377 --- /dev/null +++ b/src/hyperlight_guest_bin/src/virtq/state.rs @@ -0,0 +1,75 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Guest-side virtqueue state and initialization. +//! +//! Holds the global VirtqProducer instances for G2H and H2G queues. +//! The producers are created during guest init (from `hyperlight_guest_bin`) +//! and used by the guest host-call path in `host_comm`. + +use alloc::rc::Rc; +use core::cell::RefCell; +use core::num::NonZeroU16; + +use hyperlight_common::virtq::{BufferPool, Layout, Notifier, QueueStats, VirtqProducer}; +use hyperlight_guest::virtq_mem::GuestMemOps; + +/// Wrapper to mark types as Sync for single-threaded guest execution. +struct SyncWrap(T); + +// SAFETY: guest execution is single-threaded. +unsafe impl Sync for SyncWrap {} + +/// Guest-side notifier (no-op). +#[derive(Clone, Copy)] +pub struct GuestNotifier; + +impl Notifier for GuestNotifier { + fn notify(&self, _stats: QueueStats) {} +} + +/// Type alias for the guest-side producer. +pub type GuestProducer = VirtqProducer>; +/// Global G2H producer instance, initialized during guest init. +static G2H_PRODUCER: SyncWrap>> = SyncWrap(RefCell::new(None)); + +/// Borrow the G2H producer mutably. +/// +/// # Panics +/// +/// Panics if the G2H producer has not been initialized or is already +/// borrowed. 
+pub fn with_g2h_producer(f: impl FnOnce(&mut GuestProducer) -> R) -> R { + let mut guard = G2H_PRODUCER.0.borrow_mut(); + let producer = guard.as_mut().expect("G2H producer not initialized"); + f(producer) +} + +/// Initialize the G2H producer +/// +/// # Safety +/// +/// The ring GVA must point to valid, zeroed ring memory of the +/// appropriate size. The pool GVA must point to valid, zeroed memory. +pub unsafe fn init_g2h_producer(ring_gva: u64, num_descs: u16, pool_gva: u64, pool_size: usize) { + let nz = NonZeroU16::new(num_descs).expect("G2H queue depth must be non-zero"); + let pool = BufferPool::new(pool_gva, pool_size).expect("failed to create G2H buffer pool"); + + let layout = unsafe { Layout::from_base(ring_gva, nz) }.expect("invalid G2H ring layout"); + let producer = VirtqProducer::new(layout, GuestMemOps, GuestNotifier, Rc::new(pool)); + + *G2H_PRODUCER.0.borrow_mut() = Some(producer); +} diff --git a/src/hyperlight_guest_bin/src/virtq_init.rs b/src/hyperlight_guest_bin/src/virtq_init.rs deleted file mode 100644 index 1f24f5d9b..000000000 --- a/src/hyperlight_guest_bin/src/virtq_init.rs +++ /dev/null @@ -1,50 +0,0 @@ -/* -Copyright 2026 The Hyperlight Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -//! Guest-side virtqueue initialization. -//! -//! The host places virtqueue rings at deterministic offsets in the -//! scratch region and writes ring GVAs and queue depths to scratch-top -//! metadata. 
- -use hyperlight_common::layout::{ - self, SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_G2H_RING_GVA_OFFSET, - SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_H2G_RING_GVA_OFFSET, -}; -use hyperlight_common::virtq::Layout as VirtqLayout; - -/// Read a value from a scratch-top metadata slot. -unsafe fn read_scratch_top(offset: u64) -> T { - let addr = (layout::MAX_GVA as u64 - offset + 1) as *const T; - unsafe { core::ptr::read_volatile(addr) } -} - -/// Initialize virtqueue ring memory in the scratch region. -pub(crate) fn init_virtqueues() { - let g2h_gva: u64 = unsafe { read_scratch_top(SCRATCH_TOP_G2H_RING_GVA_OFFSET) }; - let g2h_depth: u16 = unsafe { read_scratch_top(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET) }; - let h2g_gva: u64 = unsafe { read_scratch_top(SCRATCH_TOP_H2G_RING_GVA_OFFSET) }; - let h2g_depth: u16 = unsafe { read_scratch_top(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET) }; - - assert!(g2h_depth > 0 && h2g_depth > 0); - assert!(g2h_gva != 0 && h2g_gva != 0); - - let size = VirtqLayout::query_size(g2h_depth as usize); - unsafe { core::ptr::write_bytes(g2h_gva as *mut u8, 0, size) }; - - let size = VirtqLayout::query_size(h2g_depth as usize); - unsafe { core::ptr::write_bytes(h2g_gva as *mut u8, 0, size) }; -} diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index 95983520d..831163d4c 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -574,6 +574,10 @@ impl SandboxMemoryManager { scratch_size - SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET as usize, self.layout.sandbox_memory_config.get_h2g_queue_depth() as u16, )?; + self.scratch_mem.write::( + scratch_size - SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET as usize, + self.layout.sandbox_memory_config.get_virtq_pool_pages() as u16, + )?; // Copy the page tables into the scratch region let snapshot_pt_end = self.shared_mem.mem_size(); diff --git a/src/hyperlight_host/src/sandbox/config.rs b/src/hyperlight_host/src/sandbox/config.rs index 
120aa06cd..a329e5fd5 100644 --- a/src/hyperlight_host/src/sandbox/config.rs +++ b/src/hyperlight_host/src/sandbox/config.rs @@ -80,6 +80,9 @@ pub struct SandboxConfiguration { /// Number of descriptors for the host-to-guest virtqueue. Must be a power of 2. /// Default: 32 h2g_queue_depth: usize, + /// Number of physical pages to allocate for each virtqueue's buffer pool. + /// Default: 8 pages (32KB). + virtq_pool_pages: usize, } impl SandboxConfiguration { @@ -103,6 +106,8 @@ impl SandboxConfiguration { pub const DEFAULT_G2H_QUEUE_DEPTH: usize = 64; /// The default H2G virtqueue depth (number of descriptors, must be power of 2) pub const DEFAULT_H2G_QUEUE_DEPTH: usize = 32; + /// The default number of physical pages per virtqueue buffer pool + pub const DEFAULT_VIRTQ_POOL_PAGES: usize = 8; #[allow(clippy::too_many_arguments)] /// Create a new configuration for a sandbox with the given sizes. @@ -126,6 +131,7 @@ impl SandboxConfiguration { interrupt_vcpu_sigrtmin_offset, g2h_queue_depth: Self::DEFAULT_G2H_QUEUE_DEPTH, h2g_queue_depth: Self::DEFAULT_H2G_QUEUE_DEPTH, + virtq_pool_pages: Self::DEFAULT_VIRTQ_POOL_PAGES, #[cfg(gdb)] guest_debug_info, #[cfg(crashdump)] @@ -231,6 +237,26 @@ impl SandboxConfiguration { self.h2g_queue_depth } + /// Get the number of physical pages per virtqueue buffer pool. + pub fn get_virtq_pool_pages(&self) -> usize { + self.virtq_pool_pages + } + + /// Set the G2H virtqueue depth (number of descriptors, must be power of 2). + pub fn set_g2h_queue_depth(&mut self, depth: usize) { + self.g2h_queue_depth = depth; + } + + /// Set the H2G virtqueue depth (number of descriptors, must be power of 2). + pub fn set_h2g_queue_depth(&mut self, depth: usize) { + self.h2g_queue_depth = depth; + } + + /// Set the number of physical pages per virtqueue buffer pool. 
+ pub fn set_virtq_pool_pages(&mut self, pages: usize) { + self.virtq_pool_pages = pages; + } + /// Set the size of the scratch regiong #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn set_scratch_size(&mut self, scratch_size: usize) { diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs index 1f823dc49..3b8e0fcb6 100644 --- a/src/hyperlight_host/tests/integration_test.rs +++ b/src/hyperlight_host/tests/integration_test.rs @@ -545,6 +545,9 @@ fn guest_malloc_abort() { let mut cfg = SandboxConfiguration::default(); cfg.set_heap_size(heap_size); + cfg.set_g2h_queue_depth(2); + cfg.set_h2g_queue_depth(2); + cfg.set_virtq_pool_pages(2); with_rust_sandbox_cfg(cfg, |mut sbox2| { let err = sbox2 .call::( @@ -621,6 +624,9 @@ fn guest_panic_no_alloc() { let mut cfg = SandboxConfiguration::default(); cfg.set_heap_size(heap_size); + cfg.set_g2h_queue_depth(2); + cfg.set_h2g_queue_depth(2); + cfg.set_virtq_pool_pages(2); with_rust_sandbox_cfg(cfg, |mut sbox| { let res = sbox .call::( @@ -1680,7 +1686,9 @@ fn exception_handler_installation_and_validation() { /// This validates that the exception handling path does not require heap allocations. 
#[test] fn fill_heap_and_cause_exception() { - with_rust_sandbox(|mut sandbox| { + let mut cfg = SandboxConfiguration::default(); + cfg.set_virtq_pool_pages(2); + with_rust_sandbox_cfg(cfg, |mut sandbox| { let result = sandbox.call::<()>("FillHeapAndCauseException", ()); // The call should fail with an exception error since there's no handler installed From 3c68454ecb878783a48b9aa7c535f94beddb2144 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Thu, 26 Mar 2026 11:02:57 +0100 Subject: [PATCH 05/26] feat(virtq): add reset API Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/consumer.rs | 50 ++++++++ src/hyperlight_common/src/virtq/pool.rs | 94 ++++++++++++++ src/hyperlight_common/src/virtq/producer.rs | 48 +++++++ src/hyperlight_common/src/virtq/ring.rs | 135 ++++++++++++++++++++ 4 files changed, 327 insertions(+) diff --git a/src/hyperlight_common/src/virtq/consumer.rs b/src/hyperlight_common/src/virtq/consumer.rs index 4c7bbc9ba..9e4e09527 100644 --- a/src/hyperlight_common/src/virtq/consumer.rs +++ b/src/hyperlight_common/src/virtq/consumer.rs @@ -441,6 +441,12 @@ impl VirtqConsumer { Ok(Bytes::from(buf)) } + + /// Reset ring and inflight state to initial values. + pub fn reset(&mut self) { + self.inner.reset(); + self.inflight.fill(None); + } } /// Parse a descriptor chain into entry/completion buffer elements. 
+    /// Reset the slab to its initial state, with all slots free.
+    /// Reset the pool to its initial state, with all slots free in both tiers.
+ ); + assert!(used.last_free_run.is_none()); + assert!(fresh.last_free_run.is_none()); + } + + #[test] + fn test_buffer_pool_reset_returns_to_initial_state() { + let pool = make_pool::<256, 4096>(0x20000); + + // Allocate from both tiers + let a1 = pool.inner.borrow_mut().alloc(128).unwrap(); + let a2 = pool.inner.borrow_mut().alloc(8192).unwrap(); + assert!(a1.len > 0); + assert!(a2.len > 0); + + pool.reset(); + + let inner = pool.inner.borrow(); + assert_eq!(inner.lower.used_slots.count_ones(..), 0); + assert_eq!(inner.upper.used_slots.count_ones(..), 0); + assert!(inner.lower.last_free_run.is_none()); + assert!(inner.upper.last_free_run.is_none()); + } + + #[test] + fn test_buffer_pool_reset_allows_reallocation() { + let pool = make_pool::<256, 4096>(0x20000); + + // Fill up some allocations + let mut allocs = Vec::new(); + for _ in 0..5 { + allocs.push(pool.inner.borrow_mut().alloc(256).unwrap()); + } + + pool.reset(); + + // Should be able to allocate as if fresh + let a = pool.inner.borrow_mut().alloc(256).unwrap(); + assert!(a.len > 0); + } } #[cfg(test)] diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index 95db0b7ba..28c5dbf3a 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -329,6 +329,13 @@ where } Ok(()) } + + /// Reset ring and inflight state to initial values. + /// Does not reset the buffer pool; call pool.reset() separately if needed. + pub fn reset(&mut self) { + self.inner.reset(); + self.inflight.fill(None); + } } /// Builder for configuring a descriptor chain's buffer layout. 
@@ -787,4 +794,45 @@ mod tests { assert_eq!(cqe.token, token); assert_eq!(&cqe.data[..], b"response data"); } + + #[test] + fn test_virtq_producer_reset() { + let ring = make_ring(16); + let (mut producer, mut consumer, _notifier) = make_test_producer(&ring); + + // Submit and complete a round trip + let mut se = producer.chain().entry(32).completion(64).build().unwrap(); + se.write_all(b"hello").unwrap(); + producer.submit(se).unwrap(); + + let (entry, completion) = consumer.poll(1024).unwrap().unwrap(); + assert_eq!(entry.data().as_ref(), b"hello"); + consumer.complete(completion).unwrap(); + let _ = producer.poll().unwrap().unwrap(); + + // Now reset + producer.reset(); + + // All inflight slots should be None + assert!(producer.inflight.iter().all(|s| s.is_none())); + // Ring state should be back to initial + assert_eq!(producer.inner.num_free(), producer.inner.len()); + } + + #[test] + fn test_virtq_producer_reset_clears_inflight() { + let ring = make_ring(16); + let (mut producer, _consumer, _notifier) = make_test_producer(&ring); + + // Submit without completing + let se = producer.chain().completion(64).build().unwrap(); + producer.submit(se).unwrap(); + + assert!(producer.inflight.iter().any(|s| s.is_some())); + + producer.reset(); + + assert!(producer.inflight.iter().all(|s| s.is_none())); + assert_eq!(producer.inner.num_free(), producer.inner.len()); + } } diff --git a/src/hyperlight_common/src/virtq/ring.rs b/src/hyperlight_common/src/virtq/ring.rs index c130cdcad..75508afd1 100644 --- a/src/hyperlight_common/src/virtq/ring.rs +++ b/src/hyperlight_common/src/virtq/ring.rs @@ -369,6 +369,12 @@ impl RingCursor { pub fn wrap(&self) -> bool { self.wrap } + + /// Reset cursor to initial state. + pub fn reset(&mut self) { + self.head = 0; + self.wrap = true; + } } /// Producer (driver) side of a packed virtqueue. 
@@ -817,6 +823,18 @@ impl RingProducer { Ok(should_notify(evt, self.len() as u16, old, new)) } + + /// Reset to initial state matching a freshly zeroed ring. + pub fn reset(&mut self) { + let size = self.desc_table.len(); + self.avail_cursor.reset(); + self.used_cursor.reset(); + self.num_free = size; + self.id_free.clear(); + self.id_free.extend(0..size as u16); + self.id_num.iter_mut().for_each(|n| *n = 0); + self.event_flags_shadow = EventFlags::ENABLE; + } } /// Consumer (device) side of a packed virtqueue. @@ -1164,6 +1182,16 @@ impl RingConsumer { Ok(should_notify(evt, self.desc_table.len() as u16, old, new)) } + + /// Reset to initial state matching a freshly zeroed ring. + /// Does not reallocate internal buffers. + pub fn reset(&mut self) { + self.avail_cursor.reset(); + self.used_cursor.reset(); + self.id_num.iter_mut().for_each(|n| *n = 0); + self.num_inflight = 0; + self.event_flags_shadow = EventFlags::ENABLE; + } } /// Common packed-ring notification decision: @@ -2974,6 +3002,113 @@ pub(crate) mod tests { let (_, _) = consumer.poll_available().unwrap(); } } + + #[test] + fn test_ring_cursor_reset() { + let mut cursor = RingCursor::new(16); + cursor.advance_by(5); + assert_eq!(cursor.head(), 5); + + cursor.reset(); + assert_eq!(cursor, RingCursor::new(16)); + assert_eq!(cursor.head(), 0); + assert!(cursor.wrap()); + } + + #[test] + fn test_ring_cursor_reset_after_wrap() { + let mut cursor = RingCursor::new(4); + // Advance past the wrap point + cursor.advance_by(5); + assert_eq!(cursor.head(), 1); + assert!(!cursor.wrap()); + + cursor.reset(); + assert_eq!(cursor.head(), 0); + assert!(cursor.wrap()); + } + + #[test] + fn test_ring_producer_reset_matches_new() { + let ring = make_ring(8); + let fresh = make_producer(&ring); + + let mut used = make_producer(&ring); + // Mutate state + used.submit_one(0x1000, 64, false).unwrap(); + used.submit_one(0x2000, 128, true).unwrap(); + + used.reset(); + + assert_eq!(used.avail_cursor, fresh.avail_cursor); + 
assert_eq!(used.used_cursor, fresh.used_cursor); + assert_eq!(used.num_free, fresh.num_free); + assert_eq!(used.id_free.len(), fresh.id_free.len()); + assert_eq!(used.id_num.as_slice(), fresh.id_num.as_slice()); + assert_eq!(used.event_flags_shadow, fresh.event_flags_shadow); + } + + #[test] + fn test_ring_producer_reset_id_free_complete() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + + // Submit and consume several descriptors + for i in 0..4u64 { + producer.submit_one(0x1000 + i * 0x100, 64, false).unwrap(); + } + assert_eq!(producer.num_free, 4); + + producer.reset(); + + assert_eq!(producer.num_free, 8); + assert_eq!(producer.id_free.len(), 8); + // All IDs 0..8 should be present + for id in 0..8u16 { + assert!(producer.id_free.contains(&id)); + } + } + + #[test] + fn test_ring_consumer_reset_matches_new() { + let ring = make_ring(8); + let fresh = make_consumer(&ring); + + let mut used = make_consumer(&ring); + + // Submit from producer side so consumer has something to poll + let mut producer = make_producer(&ring); + producer.submit_one(0x1000, 64, false).unwrap(); + + // Consumer polls the available descriptor + let (id, _chain) = used.poll_available().unwrap(); + used.submit_used(id, 64).unwrap(); + + used.reset(); + + assert_eq!(used.avail_cursor, fresh.avail_cursor); + assert_eq!(used.used_cursor, fresh.used_cursor); + assert_eq!(used.id_num.as_slice(), fresh.id_num.as_slice()); + assert_eq!(used.num_inflight, fresh.num_inflight); + assert_eq!(used.event_flags_shadow, fresh.event_flags_shadow); + } + + #[test] + fn test_ring_consumer_reset_clears_inflight() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + let mut consumer = make_consumer(&ring); + + // Submit and poll two items (consume but do not complete) + producer.submit_one(0x1000, 64, false).unwrap(); + producer.submit_one(0x2000, 64, false).unwrap(); + let _ = consumer.poll_available().unwrap(); + let _ = consumer.poll_available().unwrap(); + 
assert_eq!(consumer.num_inflight, 2); + + consumer.reset(); + assert_eq!(consumer.num_inflight, 0); + } } #[cfg(test)] From 95cfb6cd19f4531e842a0257d6a0dbdaf1b2ed03 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Thu, 26 Mar 2026 14:43:41 +0100 Subject: [PATCH 06/26] feat(virtq): replace guest-to-host calls with virtqueue Signed-off-by: Tomasz Andrzejak --- Cargo.lock | 1 + src/hyperlight_common/src/layout.rs | 4 +- src/hyperlight_guest/Cargo.toml | 1 + src/hyperlight_guest/src/error.rs | 25 ++- src/hyperlight_guest/src/lib.rs | 2 +- src/hyperlight_guest/src/virtq/context.rs | 159 ++++++++++++++++++ .../src/{virtq_mem.rs => virtq/mem.rs} | 16 +- src/hyperlight_guest/src/virtq/mod.rs | 99 +++++++++++ src/hyperlight_guest_bin/Cargo.toml | 3 +- .../src/guest_function/call.rs | 5 + src/hyperlight_guest_bin/src/host_comm.rs | 13 +- src/hyperlight_guest_bin/src/virtq/mod.rs | 19 +-- src/hyperlight_guest_bin/src/virtq/state.rs | 75 --------- src/hyperlight_host/src/mem/mgr.rs | 67 +++++++- src/hyperlight_host/src/sandbox/outb.rs | 71 +++++++- src/tests/rust_guests/dummyguest/Cargo.lock | 1 + src/tests/rust_guests/simpleguest/Cargo.lock | 1 + src/tests/rust_guests/witguest/Cargo.lock | 1 + 18 files changed, 451 insertions(+), 112 deletions(-) create mode 100644 src/hyperlight_guest/src/virtq/context.rs rename src/hyperlight_guest/src/{virtq_mem.rs => virtq/mem.rs} (79%) create mode 100644 src/hyperlight_guest/src/virtq/mod.rs delete mode 100644 src/hyperlight_guest_bin/src/virtq/state.rs diff --git a/Cargo.lock b/Cargo.lock index cd013fbb0..3286d2fa0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1500,6 +1500,7 @@ name = "hyperlight-guest" version = "0.14.0" dependencies = [ "anyhow", + "bytemuck", "flatbuffers", "hyperlight-common", "hyperlight-guest-tracing", diff --git a/src/hyperlight_common/src/layout.rs b/src/hyperlight_common/src/layout.rs index 4fb5208f6..dae24c3f8 100644 --- a/src/hyperlight_common/src/layout.rs +++ b/src/hyperlight_common/src/layout.rs @@ 
-41,6 +41,7 @@ pub const SCRATCH_TOP_H2G_RING_GVA_OFFSET: u64 = 0x28; pub const SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET: u64 = 0x30; pub const SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET: u64 = 0x32; pub const SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET: u64 = 0x34; +pub const SCRATCH_TOP_VIRTQ_GENERATION_OFFSET: u64 = 0x36; pub const SCRATCH_TOP_EXN_STACK_OFFSET: u64 = 0x40; const _: () = { @@ -51,7 +52,8 @@ const _: () = { assert!(SCRATCH_TOP_H2G_RING_GVA_OFFSET + 8 <= SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET); assert!(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET + 2 <= SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET); assert!(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET + 2 <= SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET); - assert!(SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET + 2 <= SCRATCH_TOP_EXN_STACK_OFFSET); + assert!(SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET + 2 <= SCRATCH_TOP_VIRTQ_GENERATION_OFFSET); + assert!(SCRATCH_TOP_VIRTQ_GENERATION_OFFSET + 2 <= SCRATCH_TOP_EXN_STACK_OFFSET); assert!(SCRATCH_TOP_EXN_STACK_OFFSET % 0x10 == 0); }; diff --git a/src/hyperlight_guest/Cargo.toml b/src/hyperlight_guest/Cargo.toml index 3c985d1dc..7b7914bd0 100644 --- a/src/hyperlight_guest/Cargo.toml +++ b/src/hyperlight_guest/Cargo.toml @@ -15,6 +15,7 @@ Provides only the essential building blocks for interacting with the host enviro anyhow = { version = "1.0.102", default-features = false } serde_json = { version = "1.0", default-features = false, features = ["alloc"] } hyperlight-common = { workspace = true, default-features = false } +bytemuck = { version = "1.24", features = ["derive"] } flatbuffers = { version= "25.12.19", default-features = false } tracing = { version = "0.1.44", default-features = false, features = ["attributes"] } diff --git a/src/hyperlight_guest/src/error.rs b/src/hyperlight_guest/src/error.rs index f5e3cbd83..62ca01bda 100644 --- a/src/hyperlight_guest/src/error.rs +++ b/src/hyperlight_guest/src/error.rs @@ -17,10 +17,11 @@ limitations under the License. 
use alloc::format; use alloc::string::{String, ToString as _}; -use anyhow; -use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; +pub(crate) use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; +use hyperlight_common::flatbuffer_wrappers::guest_error::GuestError; use hyperlight_common::func::Error as FuncError; -use serde_json; +use hyperlight_common::virtq::VirtqError; +use {anyhow, serde_json}; pub type Result = core::result::Result; @@ -81,6 +82,24 @@ impl From for HyperlightGuestError { } } +impl From for HyperlightGuestError { + fn from(e: VirtqError) -> Self { + Self { + kind: ErrorCode::GuestError, + message: format!("virtq: {e}"), + } + } +} + +impl From for HyperlightGuestError { + fn from(e: GuestError) -> Self { + Self { + kind: e.code, + message: e.message, + } + } +} + /// Extension trait to add context to `Option` and `Result` types in guest code, /// converting them to `Result`. /// diff --git a/src/hyperlight_guest/src/lib.rs b/src/hyperlight_guest/src/lib.rs index 6ef9efd86..8dbd74dc0 100644 --- a/src/hyperlight_guest/src/lib.rs +++ b/src/hyperlight_guest/src/lib.rs @@ -25,7 +25,7 @@ pub mod error; pub mod exit; pub mod layout; pub mod prim_alloc; -pub mod virtq_mem; +pub mod virtq; pub mod guest_handle { pub mod handle; diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs new file mode 100644 index 000000000..e83639fc9 --- /dev/null +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -0,0 +1,159 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Guest virtqueue context. + +use alloc::sync::Arc; +use alloc::vec::Vec; +use core::num::NonZeroU16; +use core::sync::atomic::AtomicU16; +use core::sync::atomic::Ordering::Relaxed; + +use flatbuffers::FlatBufferBuilder; +use hyperlight_common::flatbuffer_wrappers::function_call::{FunctionCall, FunctionCallType}; +use hyperlight_common::flatbuffer_wrappers::function_types::{ + FunctionCallResult, ParameterValue, ReturnType, ReturnValue, +}; +use hyperlight_common::flatbuffer_wrappers::util::estimate_flatbuffer_capacity; +use hyperlight_common::outb::OutBAction; +use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; +use hyperlight_common::virtq::{BufferPool, Layout, Notifier, QueueStats, VirtqProducer}; + +use super::GuestMemOps; +use crate::bail; +use crate::error::Result; + +static REQUEST_ID: AtomicU16 = AtomicU16::new(0); +const MAX_RESPONSE_CAP: usize = 4096; + +/// Guest-side notifier that triggers a VM exit via outb. +#[derive(Clone, Copy)] +pub struct GuestNotifier; + +impl Notifier for GuestNotifier { + fn notify(&self, _stats: QueueStats) { + unsafe { crate::exit::out32(OutBAction::VirtqNotify as u16, 0) }; + } +} + +/// Type alias for the guest-side G2H producer. +pub type G2hProducer = VirtqProducer>; + +/// Virtqueue runtime state for guest-host communication. +pub struct GuestContext { + g2h_pool: Arc, + g2h_producer: G2hProducer, + generation: u16, +} + +impl GuestContext { + /// Create a new context with a G2H queue. + /// + /// # Safety + /// + /// `ring_gva` must point to valid, zeroed ring memory. + /// `pool_gva` must point to valid, zeroed memory. 
+ pub unsafe fn new( + ring_gva: u64, + num_descs: u16, + pool_gva: u64, + pool_size: usize, + generation: u16, + ) -> Self { + let pool = Arc::new( + BufferPool::new(pool_gva, pool_size).expect("failed to create G2H buffer pool"), + ); + let nz = NonZeroU16::new(num_descs).expect("G2H queue depth must be non-zero"); + let layout = unsafe { Layout::from_base(ring_gva, nz) }.expect("invalid G2H ring layout"); + let producer = VirtqProducer::new(layout, GuestMemOps, GuestNotifier, pool.clone()); + + Self { + g2h_pool: pool, + g2h_producer: producer, + generation, + } + } + + /// Call a host function via the G2H virtqueue. + pub fn call_host_function>( + &mut self, + function_name: &str, + parameters: Option>, + return_type: ReturnType, + ) -> Result { + let params = parameters.as_deref().unwrap_or_default(); + let estimated_capacity = estimate_flatbuffer_capacity(function_name, params); + + let fc = FunctionCall::new( + function_name.into(), + parameters, + FunctionCallType::Host, + return_type, + ); + + let mut builder = FlatBufferBuilder::with_capacity(estimated_capacity); + let payload = fc.encode(&mut builder); + + let reqid = REQUEST_ID.fetch_add(1, Relaxed); + let hdr = VirtqMsgHeader::new(MsgKind::Request, reqid, payload.len() as u32); + let hdr_bytes = bytemuck::bytes_of(&hdr); + + let entry_len = VirtqMsgHeader::SIZE + payload.len(); + + let mut entry = self + .g2h_producer + .chain() + .entry(entry_len) + .completion(MAX_RESPONSE_CAP) + .build()?; + + entry.write_all(hdr_bytes)?; + entry.write_all(payload)?; + self.g2h_producer.submit(entry)?; + + let Some(completion) = self.g2h_producer.poll()? 
else { + bail!("G2H: no completion received"); + }; + + let result_bytes = &completion.data; + if result_bytes.len() > MAX_RESPONSE_CAP { + bail!("G2H: response is too large"); + } + + let payload_bytes = &result_bytes[VirtqMsgHeader::SIZE..]; + let Ok(fcr) = FunctionCallResult::try_from(payload_bytes) else { + bail!("G2H: malformed response"); + }; + + let ret = fcr.into_inner()?; + let Ok(ret) = T::try_from(ret) else { + bail!("G2H: host return value type mismatch"); + }; + + Ok(ret) + } + + /// Reset ring and pool state after snapshot restore. + pub(super) fn reset(&mut self, new_generation: u16) { + self.g2h_producer.reset(); + self.g2h_pool.reset(); + self.generation = new_generation; + } + + pub(super) fn generation(&self) -> u16 { + self.generation + } +} diff --git a/src/hyperlight_guest/src/virtq_mem.rs b/src/hyperlight_guest/src/virtq/mem.rs similarity index 79% rename from src/hyperlight_guest/src/virtq_mem.rs rename to src/hyperlight_guest/src/virtq/mem.rs index 8309deb79..590be2ac3 100644 --- a/src/hyperlight_guest/src/virtq_mem.rs +++ b/src/hyperlight_guest/src/virtq/mem.rs @@ -31,29 +31,21 @@ impl MemOps for GuestMemOps { type Error = Infallible; fn read(&self, addr: u64, dst: &mut [u8]) -> Result { - let src = addr as *const u8; - unsafe { - ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), dst.len()); - } + unsafe { ptr::copy_nonoverlapping(addr as *const u8, dst.as_mut_ptr(), dst.len()) }; Ok(dst.len()) } fn write(&self, addr: u64, src: &[u8]) -> Result { - let dst = addr as *mut u8; - unsafe { - ptr::copy_nonoverlapping(src.as_ptr(), dst, src.len()); - } + unsafe { ptr::copy_nonoverlapping(src.as_ptr(), addr as *mut u8, src.len()) }; Ok(src.len()) } fn load_acquire(&self, addr: u64) -> Result { - let ptr = addr as *const AtomicU16; - Ok(unsafe { (*ptr).load(Ordering::Acquire) }) + Ok(unsafe { (*(addr as *const AtomicU16)).load(Ordering::Acquire) }) } fn store_release(&self, addr: u64, val: u16) -> Result<(), Self::Error> { - let ptr = addr as 
*const AtomicU16; - unsafe { (*ptr).store(val, Ordering::Release) }; + unsafe { (*(addr as *const AtomicU16)).store(val, Ordering::Release) }; Ok(()) } diff --git a/src/hyperlight_guest/src/virtq/mod.rs b/src/hyperlight_guest/src/virtq/mod.rs new file mode 100644 index 000000000..29f404740 --- /dev/null +++ b/src/hyperlight_guest/src/virtq/mod.rs @@ -0,0 +1,99 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Guest-side virtqueue support. +//! +//! Global context is installed once via [`set_global_context`] and +//! accessed via [`with_context`]. + +pub mod context; +pub mod mem; + +use core::cell::UnsafeCell; +use core::sync::atomic::{AtomicU8, Ordering}; + +use context::GuestContext; +use hyperlight_common::layout::{SCRATCH_TOP_VIRTQ_GENERATION_OFFSET, scratch_top_ptr}; +pub use mem::GuestMemOps; + +// Init state machine +const UNINITIALIZED: u8 = 0; +const INITIALIZED: u8 = 1; +static INIT_STATE: AtomicU8 = AtomicU8::new(UNINITIALIZED); + +/// Check if the global context has been initialized. +pub fn is_initialized() -> bool { + INIT_STATE.load(Ordering::Acquire) == INITIALIZED +} + +// Storage: UnsafeCell guarded by atomic init state. +struct SyncWrap(T); +unsafe impl Sync for SyncWrap {} + +static GLOBAL_CONTEXT: SyncWrap>> = SyncWrap(UnsafeCell::new(None)); + +/// Access the global guest context via closure. +/// +/// # Panics +/// +/// Panics if the context has not been initialized. 
+pub fn with_context(f: impl FnOnce(&mut GuestContext) -> R) -> R { + assert!( + INIT_STATE.load(Ordering::Acquire) == INITIALIZED, + "guest context not initialized" + ); + let ctx = unsafe { &mut *GLOBAL_CONTEXT.0.get() }; + f(ctx.as_mut().unwrap()) +} + +/// Install the global guest context. Called once during guest init. +/// +/// # Panics +/// +/// Panics if called more than once. +pub fn set_global_context(ctx: GuestContext) { + if INIT_STATE + .compare_exchange( + UNINITIALIZED, + INITIALIZED, + Ordering::SeqCst, + Ordering::SeqCst, + ) + .is_err() + { + panic!("guest context already initialized"); + } + unsafe { *GLOBAL_CONTEXT.0.get() = Some(ctx) }; +} + +/// Reset the global context if a snapshot restore was detected. +/// Compares the virtq generation counter in scratch-top metadata. +pub fn reset_global_context() { + if !is_initialized() { + return; + } + let current_gen = read_gen(); + with_context(|ctx| { + if current_gen != ctx.generation() { + ctx.reset(current_gen); + } + }); +} + +/// Read the current virtqueue generation from scratch-top metadata. 
+fn read_gen() -> u16 { + unsafe { *scratch_top_ptr::(SCRATCH_TOP_VIRTQ_GENERATION_OFFSET) } +} diff --git a/src/hyperlight_guest_bin/Cargo.toml b/src/hyperlight_guest_bin/Cargo.toml index ea6ef1dd2..7c12d3765 100644 --- a/src/hyperlight_guest_bin/Cargo.toml +++ b/src/hyperlight_guest_bin/Cargo.toml @@ -14,9 +14,10 @@ and third-party code used by our C-API needed to build a native hyperlight-guest """ [features] -default = ["libc", "printf", "macros"] +default = ["libc", "printf", "macros", "virtq"] libc = [] # compile musl libc printf = [ "libc" ] # compile printf +virtq = [] # use virtqueue for guest-to-host calls trace_guest = ["hyperlight-common/trace_guest", "hyperlight-guest/trace_guest", "hyperlight-guest-tracing/trace"] mem_profile = ["hyperlight-common/mem_profile"] macros = ["dep:hyperlight-guest-macro", "dep:linkme"] diff --git a/src/hyperlight_guest_bin/src/guest_function/call.rs b/src/hyperlight_guest_bin/src/guest_function/call.rs index ebadda540..37d02457b 100644 --- a/src/hyperlight_guest_bin/src/guest_function/call.rs +++ b/src/hyperlight_guest_bin/src/guest_function/call.rs @@ -22,6 +22,7 @@ use hyperlight_common::flatbuffer_wrappers::function_call::{FunctionCall, Functi use hyperlight_common::flatbuffer_wrappers::function_types::{FunctionCallResult, ParameterType}; use hyperlight_common::flatbuffer_wrappers::guest_error::{ErrorCode, GuestError}; use hyperlight_guest::error::{HyperlightGuestError, Result}; +use hyperlight_guest::virtq; use tracing::instrument; use crate::{GUEST_HANDLE, REGISTERED_GUEST_FUNCTIONS}; @@ -87,6 +88,10 @@ pub(crate) fn internal_dispatch_function() { let handle = unsafe { GUEST_HANDLE }; + // After snapshot restore, the ring memory is zeroed but the + // producer's cursors are stale. Check once per dispatch entry. 
+ virtq::reset_global_context(); + let function_call = handle .try_pop_shared_input_data_into::() .expect("Function call deserialization failed"); diff --git a/src/hyperlight_guest_bin/src/host_comm.rs b/src/hyperlight_guest_bin/src/host_comm.rs index 16ac9af4f..369981deb 100644 --- a/src/hyperlight_guest_bin/src/host_comm.rs +++ b/src/hyperlight_guest_bin/src/host_comm.rs @@ -41,8 +41,17 @@ pub fn call_host_function( where T: TryFrom, { - let handle = unsafe { GUEST_HANDLE }; - handle.call_host_function::(function_name, parameters, return_type) + #[cfg(feature = "virtq")] + { + hyperlight_guest::virtq::with_context(|ctx| { + ctx.call_host_function(function_name, parameters, return_type) + }) + } + #[cfg(not(feature = "virtq"))] + { + let handle = unsafe { GUEST_HANDLE }; + handle.call_host_function::(function_name, parameters, return_type) + } } pub fn call_host(function_name: impl AsRef, args: impl ParameterTuple) -> Result diff --git a/src/hyperlight_guest_bin/src/virtq/mod.rs b/src/hyperlight_guest_bin/src/virtq/mod.rs index 50c0dd6d9..1c1b42300 100644 --- a/src/hyperlight_guest_bin/src/virtq/mod.rs +++ b/src/hyperlight_guest_bin/src/virtq/mod.rs @@ -15,30 +15,27 @@ limitations under the License. */ //! Guest-side virtqueue initialization. -//! -//! Zeroes ring memory and creates VirtqProducer instances by allocating -//! buffer pool pages from the scratch page allocator. 
- -pub(crate) mod state; use hyperlight_common::layout::{ SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_G2H_RING_GVA_OFFSET, SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_H2G_RING_GVA_OFFSET, - SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET, scratch_top_ptr, + SCRATCH_TOP_VIRTQ_GENERATION_OFFSET, SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET, scratch_top_ptr, }; use hyperlight_common::mem::PAGE_SIZE_USIZE; use hyperlight_common::virtq::Layout as VirtqLayout; use hyperlight_guest::prim_alloc::alloc_phys_pages; +use hyperlight_guest::virtq::context::GuestContext; use crate::paging::phys_to_virt; -/// Initialize virtqueue producers for G2H and H2G queues. +/// Initialize virtqueue context. pub(crate) fn init_virtqueues() { let g2h_gva = unsafe { *scratch_top_ptr::(SCRATCH_TOP_G2H_RING_GVA_OFFSET) }; let g2h_depth = unsafe { *scratch_top_ptr::(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET) }; let h2g_gva = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_RING_GVA_OFFSET) }; let h2g_depth = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET) }; let pool_pages = unsafe { *scratch_top_ptr::(SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET) } as u64; + let generation = unsafe { *scratch_top_ptr::(SCRATCH_TOP_VIRTQ_GENERATION_OFFSET) }; assert!(g2h_depth > 0 && h2g_depth > 0); assert!(g2h_gva != 0 && h2g_gva != 0); @@ -58,11 +55,9 @@ pub(crate) fn init_virtqueues() { let pool_size = pool_pages as usize * PAGE_SIZE_USIZE; unsafe { core::ptr::write_bytes(pool_ptr, 0, pool_size) }; - // Create G2H producer - unsafe { - state::init_g2h_producer(g2h_gva, g2h_depth, pool_gva, pool_size); - } + // Create and install global context + let ctx = unsafe { GuestContext::new(g2h_gva, g2h_depth, pool_gva, pool_size, generation) }; + hyperlight_guest::virtq::set_global_context(ctx); - // TODO(virtq): add other direction's producer let _ = (h2g_gva, h2g_depth); } diff --git a/src/hyperlight_guest_bin/src/virtq/state.rs b/src/hyperlight_guest_bin/src/virtq/state.rs deleted file mode 100644 index 232726377..000000000 --- 
a/src/hyperlight_guest_bin/src/virtq/state.rs +++ /dev/null @@ -1,75 +0,0 @@ -/* -Copyright 2026 The Hyperlight Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -//! Guest-side virtqueue state and initialization. -//! -//! Holds the global VirtqProducer instances for G2H and H2G queues. -//! The producers are created during guest init (from `hyperlight_guest_bin`) -//! and used by the guest host-call path in `host_comm`. - -use alloc::rc::Rc; -use core::cell::RefCell; -use core::num::NonZeroU16; - -use hyperlight_common::virtq::{BufferPool, Layout, Notifier, QueueStats, VirtqProducer}; -use hyperlight_guest::virtq_mem::GuestMemOps; - -/// Wrapper to mark types as Sync for single-threaded guest execution. -struct SyncWrap(T); - -// SAFETY: guest execution is single-threaded. -unsafe impl Sync for SyncWrap {} - -/// Guest-side notifier (no-op). -#[derive(Clone, Copy)] -pub struct GuestNotifier; - -impl Notifier for GuestNotifier { - fn notify(&self, _stats: QueueStats) {} -} - -/// Type alias for the guest-side producer. -pub type GuestProducer = VirtqProducer>; -/// Global G2H producer instance, initialized during guest init. -static G2H_PRODUCER: SyncWrap>> = SyncWrap(RefCell::new(None)); - -/// Borrow the G2H producer mutably. -/// -/// # Panics -/// -/// Panics if the G2H producer has not been initialized or is already -/// borrowed. 
-pub fn with_g2h_producer(f: impl FnOnce(&mut GuestProducer) -> R) -> R { - let mut guard = G2H_PRODUCER.0.borrow_mut(); - let producer = guard.as_mut().expect("G2H producer not initialized"); - f(producer) -} - -/// Initialize the G2H producer -/// -/// # Safety -/// -/// The ring GVA must point to valid, zeroed ring memory of the -/// appropriate size. The pool GVA must point to valid, zeroed memory. -pub unsafe fn init_g2h_producer(ring_gva: u64, num_descs: u16, pool_gva: u64, pool_size: usize) { - let nz = NonZeroU16::new(num_descs).expect("G2H queue depth must be non-zero"); - let pool = BufferPool::new(pool_gva, pool_size).expect("failed to create G2H buffer pool"); - - let layout = unsafe { Layout::from_base(ring_gva, nz) }.expect("invalid G2H ring layout"); - let producer = VirtqProducer::new(layout, GuestMemOps, GuestNotifier, Rc::new(pool)); - - *G2H_PRODUCER.0.borrow_mut() = Some(producer); -} diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index 831163d4c..f46361dfa 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -21,6 +21,20 @@ use flatbuffers::FlatBufferBuilder; use hyperlight_common::flatbuffer_wrappers::function_call::{ FunctionCall, validate_guest_function_call_buffer, }; + +use super::virtq_mem::HostMemOps; + +/// No-op notifier for host-side consumer. +/// The host resumes the VM to notify the guest, not via the ring. +#[derive(Clone, Copy)] +pub(crate) struct HostNotifier; + +impl hyperlight_common::virtq::Notifier for HostNotifier { + fn notify(&self, _stats: hyperlight_common::virtq::QueueStats) {} +} + +/// Type alias for the host-side G2H virtqueue consumer. 
+pub(crate) type G2hConsumer = hyperlight_common::virtq::VirtqConsumer; use hyperlight_common::flatbuffer_wrappers::function_types::FunctionCallResult; use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; use hyperlight_common::virtq::Layout as VirtqLayout; @@ -136,7 +150,6 @@ impl ReadonlySharedMemory { pub(crate) use unused_hack::SnapshotSharedMemory; /// A struct that is responsible for laying out and managing the memory /// for a given `Sandbox`. -#[derive(Clone)] pub(crate) struct SandboxMemoryManager { /// Shared memory for the Sandbox pub(crate) shared_mem: SnapshotSharedMemory, @@ -150,6 +163,22 @@ pub(crate) struct SandboxMemoryManager { pub(crate) mapped_rgns: u64, /// Buffer for accumulating guest abort messages pub(crate) abort_buffer: Vec, + /// G2H virtqueue consumer, created after sandbox init. + pub(crate) g2h_consumer: Option, +} + +impl Clone for SandboxMemoryManager { + fn clone(&self) -> Self { + Self { + shared_mem: self.shared_mem.clone(), + scratch_mem: self.scratch_mem.clone(), + layout: self.layout, + entrypoint: self.entrypoint, + mapped_rgns: self.mapped_rgns, + abort_buffer: self.abort_buffer.clone(), + g2h_consumer: None, // consumer is not cloned; re-init if needed + } + } } pub(crate) struct GuestPageTableBuffer { @@ -259,6 +288,7 @@ where entrypoint, mapped_rgns: 0, abort_buffer: Vec::new(), + g2h_consumer: None, } } @@ -326,6 +356,7 @@ impl SandboxMemoryManager { entrypoint: self.entrypoint, mapped_rgns: self.mapped_rgns, abort_buffer: self.abort_buffer, + g2h_consumer: None, }; let guest_mgr = SandboxMemoryManager { shared_mem: gshm, @@ -334,8 +365,10 @@ impl SandboxMemoryManager { entrypoint: self.entrypoint, mapped_rgns: self.mapped_rgns, abort_buffer: Vec::new(), // Guest doesn't need abort buffer + g2h_consumer: None, }; host_mgr.update_scratch_bookkeeping()?; + host_mgr.init_g2h_consumer()?; Ok((host_mgr, guest_mgr)) } } @@ -526,6 +559,7 @@ impl SandboxMemoryManager { }; self.layout = *snapshot.layout(); 
self.update_scratch_bookkeeping()?; + self.init_g2h_consumer()?; Ok((gsnapshot, gscratch)) } @@ -578,6 +612,11 @@ impl SandboxMemoryManager { scratch_size - SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET as usize, self.layout.sandbox_memory_config.get_virtq_pool_pages() as u16, )?; + // Increment generation so the guest detects stale ring state. + let gen_offset = scratch_size - SCRATCH_TOP_VIRTQ_GENERATION_OFFSET as usize; + let gen_val: u16 = self.scratch_mem.read(gen_offset).unwrap_or(0); + self.scratch_mem + .write::(gen_offset, gen_val.wrapping_add(1))?; // Copy the page tables into the scratch region let snapshot_pt_end = self.shared_mem.mem_size(); @@ -840,6 +879,32 @@ impl SandboxMemoryManager { unsafe { VirtqLayout::from_base(base, nz) } .map_err(|e| new_error!("Invalid H2G virtq layout: {:?}", e)) } + + /// Create a [`HostMemOps`] instance backed by this manager's + /// scratch shared memory. + pub(crate) fn host_mem_ops(&self) -> HostMemOps { + let scratch_base_gva = + hyperlight_common::layout::scratch_base_gva(self.scratch_mem.mem_size()); + HostMemOps::new(&self.scratch_mem, scratch_base_gva) + } + + /// Initialize the G2H virtqueue consumer. + /// Must be called after scratch bookkeeping is written. + pub(crate) fn init_g2h_consumer(&mut self) -> Result<()> { + match &mut self.g2h_consumer { + Some(consumer) => { + consumer.reset(); + } + None => { + let layout = self.g2h_virtq_layout()?; + let mem_ops = self.host_mem_ops(); + let consumer = + hyperlight_common::virtq::VirtqConsumer::new(layout, mem_ops, HostNotifier); + self.g2h_consumer = Some(consumer); + } + } + Ok(()) + } } #[cfg(test)] diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs index bb73763a6..aa40bec3d 100644 --- a/src/hyperlight_host/src/sandbox/outb.rs +++ b/src/hyperlight_host/src/sandbox/outb.rs @@ -16,6 +16,7 @@ limitations under the License. 
use std::sync::{Arc, Mutex}; +use hyperlight_common::flatbuffer_wrappers::function_call::FunctionCall; use hyperlight_common::flatbuffer_wrappers::function_types::{FunctionCallResult, ParameterValue}; use hyperlight_common::flatbuffer_wrappers::guest_error::{ErrorCode, GuestError}; use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; @@ -180,6 +181,71 @@ fn outb_abort( Ok(()) } +/// Handle a guest-to-host function call received via the G2H virtqueue. +fn outb_virtq_call( + mem_mgr: &mut SandboxMemoryManager, + host_funcs: &Arc>, +) -> Result<(), HandleOutbError> { + use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; + + let consumer = mem_mgr.g2h_consumer.as_mut().ok_or_else(|| { + HandleOutbError::ReadHostFunctionCall("G2H consumer not initialized".into()) + })?; + + let (entry, completion) = consumer + .poll(8192) + .map_err(|e| HandleOutbError::ReadHostFunctionCall(format!("G2H poll: {e}")))? + .ok_or_else(|| HandleOutbError::ReadHostFunctionCall("G2H poll: no entry".into()))?; + + // Parse: skip VirtqMsgHeader, deserialize FunctionCall from remainder + let entry_data = entry.data(); + if entry_data.len() < VirtqMsgHeader::SIZE { + return Err(HandleOutbError::ReadHostFunctionCall( + "G2H entry too short".into(), + )); + } + let payload = &entry_data[VirtqMsgHeader::SIZE..]; + let call = FunctionCall::try_from(payload) + .map_err(|e| HandleOutbError::ReadHostFunctionCall(e.to_string()))?; + + // Dispatch the host function (same as CallFunction path) + let name = call.function_name.clone(); + let args: Vec = call.parameters.unwrap_or(vec![]); + let res = host_funcs + .try_lock() + .map_err(|e| HandleOutbError::LockFailed(file!(), line!(), e.to_string()))? 
+ .call_host_function(&name, args) + .map_err(|e| GuestError::new(ErrorCode::HostFunctionError, e.to_string())); + + // Serialize response: VirtqMsgHeader + FunctionCallResult + let func_result = FunctionCallResult::new(res); + let mut builder = flatbuffers::FlatBufferBuilder::new(); + let result_payload = func_result.encode(&mut builder); + + let resp_header = VirtqMsgHeader::new(MsgKind::Response, 0, result_payload.len() as u32); + let resp_header_bytes = bytemuck::bytes_of(&resp_header); + + // Write response into the completion buffer + match completion { + hyperlight_common::virtq::SendCompletion::Writable(mut wc) => { + wc.write_all(resp_header_bytes) + .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; + wc.write_all(result_payload) + .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; + consumer + .complete(wc.into()) + .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; + } + hyperlight_common::virtq::SendCompletion::Ack(ack) => { + consumer + .complete(ack.into()) + .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; + } + } + + Ok(()) +} + /// Handles OutB operations from the guest. #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] pub(crate) fn handle_outb( @@ -227,10 +293,7 @@ pub(crate) fn handle_outb( eprint!("{}", ch); Ok(()) } - OutBAction::VirtqNotify => { - // TODO(ring): acknowledge notification but no-op for now. 
- Ok(()) - } + OutBAction::VirtqNotify => outb_virtq_call(mem_mgr, host_funcs), #[cfg(feature = "trace_guest")] OutBAction::TraceBatch => Ok(()), #[cfg(feature = "mem_profile")] diff --git a/src/tests/rust_guests/dummyguest/Cargo.lock b/src/tests/rust_guests/dummyguest/Cargo.lock index 736cf6c4b..b6aaae23c 100644 --- a/src/tests/rust_guests/dummyguest/Cargo.lock +++ b/src/tests/rust_guests/dummyguest/Cargo.lock @@ -142,6 +142,7 @@ name = "hyperlight-guest" version = "0.14.0" dependencies = [ "anyhow", + "bytemuck", "flatbuffers", "hyperlight-common", "hyperlight-guest-tracing", diff --git a/src/tests/rust_guests/simpleguest/Cargo.lock b/src/tests/rust_guests/simpleguest/Cargo.lock index e12c99d4d..dbc6c01e1 100644 --- a/src/tests/rust_guests/simpleguest/Cargo.lock +++ b/src/tests/rust_guests/simpleguest/Cargo.lock @@ -134,6 +134,7 @@ name = "hyperlight-guest" version = "0.14.0" dependencies = [ "anyhow", + "bytemuck", "flatbuffers", "hyperlight-common", "hyperlight-guest-tracing", diff --git a/src/tests/rust_guests/witguest/Cargo.lock b/src/tests/rust_guests/witguest/Cargo.lock index 39bb97169..4fb67bbe7 100644 --- a/src/tests/rust_guests/witguest/Cargo.lock +++ b/src/tests/rust_guests/witguest/Cargo.lock @@ -266,6 +266,7 @@ name = "hyperlight-guest" version = "0.14.0" dependencies = [ "anyhow", + "bytemuck", "flatbuffers", "hyperlight-common", "hyperlight-guest-tracing", From 833a3fc3849cd472aa22d03b797806f43a1c6282 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 3 Apr 2026 12:48:06 +0200 Subject: [PATCH 07/26] feat(virtq): replace host-to-guest calls with virtq Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/layout.rs | 36 +-- src/hyperlight_common/src/virtq/mod.rs | 12 +- src/hyperlight_common/src/virtq/pool.rs | 32 ++- src/hyperlight_common/src/virtq/producer.rs | 36 ++- .../src/virtq/recycle_pool.rs | 120 ++++++++++ src/hyperlight_guest/src/virtq/context.rs | 158 ++++++++++--- src/hyperlight_guest/src/virtq/mod.rs | 13 +- 
.../src/guest_function/call.rs | 34 ++- src/hyperlight_guest_bin/src/virtq/mod.rs | 62 +++-- src/hyperlight_host/src/mem/layout.rs | 1 + src/hyperlight_host/src/mem/mgr.rs | 221 +++++++++++++++--- src/hyperlight_host/src/sandbox/config.rs | 80 +++++-- .../src/sandbox/initialized_multi_use.rs | 4 +- src/hyperlight_host/src/sandbox/outb.rs | 53 +++-- .../src/sandbox/uninitialized_evolve.rs | 15 +- src/hyperlight_host/tests/integration_test.rs | 8 +- 16 files changed, 706 insertions(+), 179 deletions(-) create mode 100644 src/hyperlight_common/src/virtq/recycle_pool.rs diff --git a/src/hyperlight_common/src/layout.rs b/src/hyperlight_common/src/layout.rs index dae24c3f8..f6c8b1caa 100644 --- a/src/hyperlight_common/src/layout.rs +++ b/src/hyperlight_common/src/layout.rs @@ -36,24 +36,28 @@ pub use arch::{SNAPSHOT_PT_GVA_MAX, SNAPSHOT_PT_GVA_MIN}; pub const SCRATCH_TOP_SIZE_OFFSET: u64 = 0x08; pub const SCRATCH_TOP_ALLOCATOR_OFFSET: u64 = 0x10; pub const SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET: u64 = 0x18; -pub const SCRATCH_TOP_G2H_RING_GVA_OFFSET: u64 = 0x20; -pub const SCRATCH_TOP_H2G_RING_GVA_OFFSET: u64 = 0x28; -pub const SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET: u64 = 0x30; -pub const SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET: u64 = 0x32; -pub const SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET: u64 = 0x34; -pub const SCRATCH_TOP_VIRTQ_GENERATION_OFFSET: u64 = 0x36; -pub const SCRATCH_TOP_EXN_STACK_OFFSET: u64 = 0x40; +pub const SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET: u64 = 0x20; +pub const SCRATCH_TOP_G2H_RING_GVA_OFFSET: u64 = 0x28; +pub const SCRATCH_TOP_H2G_RING_GVA_OFFSET: u64 = 0x30; +pub const SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET: u64 = 0x38; +pub const SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET: u64 = 0x3A; +pub const SCRATCH_TOP_G2H_POOL_PAGES_OFFSET: u64 = 0x3C; +pub const SCRATCH_TOP_H2G_POOL_PAGES_OFFSET: u64 = 0x3E; +pub const SCRATCH_TOP_H2G_POOL_GVA_OFFSET: u64 = 0x48; +pub const SCRATCH_TOP_EXN_STACK_OFFSET: u64 = 0x50; const _: () = { - assert!(SCRATCH_TOP_SIZE_OFFSET + 8 <= 
SCRATCH_TOP_ALLOCATOR_OFFSET); - assert!(SCRATCH_TOP_ALLOCATOR_OFFSET + 8 <= SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET); - assert!(SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET + 8 <= SCRATCH_TOP_G2H_RING_GVA_OFFSET); - assert!(SCRATCH_TOP_G2H_RING_GVA_OFFSET + 8 <= SCRATCH_TOP_H2G_RING_GVA_OFFSET); - assert!(SCRATCH_TOP_H2G_RING_GVA_OFFSET + 8 <= SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET); - assert!(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET + 2 <= SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET); - assert!(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET + 2 <= SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET); - assert!(SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET + 2 <= SCRATCH_TOP_VIRTQ_GENERATION_OFFSET); - assert!(SCRATCH_TOP_VIRTQ_GENERATION_OFFSET + 2 <= SCRATCH_TOP_EXN_STACK_OFFSET); + assert!(SCRATCH_TOP_ALLOCATOR_OFFSET >= SCRATCH_TOP_SIZE_OFFSET + 8); + assert!(SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET >= SCRATCH_TOP_ALLOCATOR_OFFSET + 8); + assert!(SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET >= SCRATCH_TOP_SNAPSHOT_PT_GPA_BASE_OFFSET + 8); + assert!(SCRATCH_TOP_G2H_RING_GVA_OFFSET >= SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET + 8); + assert!(SCRATCH_TOP_H2G_RING_GVA_OFFSET >= SCRATCH_TOP_G2H_RING_GVA_OFFSET + 8); + assert!(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET >= SCRATCH_TOP_H2G_RING_GVA_OFFSET + 8); + assert!(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET >= SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET + 2); + assert!(SCRATCH_TOP_G2H_POOL_PAGES_OFFSET >= SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET + 2); + assert!(SCRATCH_TOP_H2G_POOL_PAGES_OFFSET >= SCRATCH_TOP_G2H_POOL_PAGES_OFFSET + 2); + assert!(SCRATCH_TOP_H2G_POOL_GVA_OFFSET >= SCRATCH_TOP_H2G_POOL_PAGES_OFFSET + 8); + assert!(SCRATCH_TOP_EXN_STACK_OFFSET >= SCRATCH_TOP_H2G_POOL_GVA_OFFSET + 8); assert!(SCRATCH_TOP_EXN_STACK_OFFSET % 0x10 == 0); }; diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs index b52ef2805..5e9fc7e5f 100644 --- a/src/hyperlight_common/src/virtq/mod.rs +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -157,6 +157,7 @@ mod event; pub mod msg; mod 
pool; mod producer; +pub mod recycle_pool; mod ring; use core::num::NonZeroU16; @@ -170,7 +171,7 @@ pub use producer::*; pub use ring::*; use thiserror::Error; -/// A trait for notifying about new requests in the virtqueue. +/// A trait for notifying the consumer about virtqueue events. pub trait Notifier { fn notify(&self, stats: QueueStats); } @@ -439,15 +440,14 @@ pub(crate) mod test_utils { } } + type TestProducer = VirtqProducer, TestNotifier, TestPool>; + type TestConsumer = VirtqConsumer, TestNotifier>; + /// Create test infrastructure: a producer, consumer, and notifier backed /// by the supplied [`OwnedRing`]. pub(crate) fn make_test_producer( ring: &OwnedRing, - ) -> ( - VirtqProducer, TestNotifier, TestPool>, - VirtqConsumer, TestNotifier>, - TestNotifier, - ) { + ) -> (TestProducer, TestConsumer, TestNotifier) { let layout = ring.layout(); let mem = ring.mem(); diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs index 83178998d..cf0915fdf 100644 --- a/src/hyperlight_common/src/virtq/pool.rs +++ b/src/hyperlight_common/src/virtq/pool.rs @@ -79,7 +79,6 @@ limitations under the License. //! owning slab (`Slab::resize`) but will never move allocations between //! slabs. -#[cfg(all(test, loom))] use alloc::sync::Arc; use core::cmp::Ordering; @@ -124,6 +123,9 @@ pub trait BufferProvider { /// Resize by trying in-place grow; otherwise reserve a new block and free old. fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result; + + /// Reset the pool to initial state. 
+ fn reset(&self) {} } impl BufferProvider for alloc::rc::Rc { @@ -136,9 +138,12 @@ impl BufferProvider for alloc::rc::Rc { fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { (**self).resize(old_alloc, new_len) } + fn reset(&self) { + (**self).reset() + } } -impl BufferProvider for alloc::sync::Arc { +impl BufferProvider for Arc { fn alloc(&self, len: usize) -> Result { (**self).alloc(len) } @@ -148,6 +153,9 @@ impl BufferProvider for alloc::sync::Arc { fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { (**self).resize(old_alloc, new_len) } + fn reset(&self) { + (**self).reset() + } } /// The owner of a mapped buffer, ensuring its lifetime. @@ -540,9 +548,10 @@ struct Inner { } /// Two tier buffer pool with small and large slabs. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct BufferPool { - inner: AtomicRefCell>, + // TODO: Use Rc instead, relax Sync + Send bounds + inner: Arc>>, } impl BufferPool { @@ -550,16 +559,9 @@ impl BufferPool { pub fn new(base_addr: u64, region_len: usize) -> Result { let inner = Inner::::new(base_addr, region_len)?; Ok(Self { - inner: inner.into(), + inner: Arc::new(inner.into()), }) } - - /// Reset the pool to initial state - pub fn reset(&self) { - let mut inner = self.inner.borrow_mut(); - inner.lower.reset(); - inner.upper.reset(); - } } #[cfg(all(test, loom))] @@ -672,6 +674,12 @@ impl BufferProvider for BufferPool { fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { self.inner.borrow_mut().resize(old_alloc, new_len) } + + fn reset(&self) { + let mut inner = self.inner.borrow_mut(); + inner.lower.reset(); + inner.upper.reset(); + } } #[cfg(all(test, loom))] diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index 28c5dbf3a..5e6a7edf1 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -282,6 +282,13 @@ where *slot = Some(inflight); let should_notify = 
self.inner.should_notify_since(cursor_before)?; + + // TODO(virtq): for now simulate current outb behavior of only + // notifying on bidirectional (request/response) entries. + // Eventually this should be decoupled from the buffer layout + // and driven entirely by event suppression rules. + let should_notify = should_notify && matches!(inflight, Inflight::ReadWrite { .. }); + if should_notify { self.notifier.notify(QueueStats { num_free: self.inner.num_free(), @@ -292,6 +299,17 @@ where Ok(Token(id)) } + /// Signal backpressure to the consumer. + /// + /// Bypasses event suppression. Call this when submit fails with a backpressure error and the consumer needs to drain. + #[inline] + pub fn notify_backpressure(&self) { + self.notifier.notify(QueueStats { + num_free: self.inner.num_free(), + num_inflight: self.inner.num_inflight(), + }); + } + /// Get the current used cursor position. /// /// Useful for setting up descriptor-based event suppression. @@ -330,10 +348,20 @@ where Ok(()) } - /// Reset ring and inflight state to initial values. - /// Does not reset the buffer pool; call pool.reset() separately if needed. + /// Reset ring, inflight, and pool state to initial values. + /// + /// # Safety + /// + /// All [`RecvCompletion`]s (and their backing [`Bytes`]) from + /// previous `poll()` calls must have been dropped before calling + /// this. Outstanding completions hold pool allocations via + /// `BufferOwner`; resetting the pool while they exist would cause + /// double-free on drop. + /// + /// TODO(virtq): properly restore state after snapshot instead of just resetting everything pub fn reset(&mut self) { self.inner.reset(); + self.pool.reset(); self.inflight.fill(None); } } @@ -343,14 +371,14 @@ where /// If dropped without building, no resources are leaked (allocations are /// deferred to [`build`](Self::build)). 
#[must_use = "call .build() to create a SendEntry"] -pub struct ChainBuilder { +pub struct ChainBuilder { mem: M, pool: P, entry_cap: Option, cqe_cap: Option, } -impl ChainBuilder { +impl ChainBuilder { fn new(mem: M, pool: P) -> Self { Self { mem, diff --git a/src/hyperlight_common/src/virtq/recycle_pool.rs b/src/hyperlight_common/src/virtq/recycle_pool.rs new file mode 100644 index 000000000..4bcf9978a --- /dev/null +++ b/src/hyperlight_common/src/virtq/recycle_pool.rs @@ -0,0 +1,120 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! A simple fixed-size buffer recycler for H2G prefill entries. +//! +//! Unlike [`super::BufferPool`] which uses a bitmap allocator, this +//! holds a fixed set of same-sized buffer addresses in a free list. +//! Alloc and dealloc are O(1). Intended for H2G writable buffers +//! that are pre-allocated once and recycled after each use. + +use alloc::sync::Arc; + +use atomic_refcell::AtomicRefCell; +use smallvec::SmallVec; + +use super::{AllocError, Allocation, BufferProvider}; + +/// A recycling buffer provider with fixed-size slots. +#[derive(Clone)] +pub struct RecyclePool { + inner: Arc>, +} + +struct RecyclePoolInner { + base_addr: u64, + slot_size: usize, + count: usize, + free: SmallVec<[u64; 64]>, +} + +impl RecyclePool { + /// Create a new recycling pool by carving `base..base+region_len` into slots of `slot_size` bytes. 
+ pub fn new(base_addr: u64, region_len: usize, slot_size: usize) -> Result { + if slot_size == 0 { + return Err(AllocError::InvalidArg); + } + + let count = region_len / slot_size; + if count == 0 { + return Err(AllocError::EmptyRegion); + } + + let mut free = SmallVec::with_capacity(count); + for i in 0..count { + free.push(base_addr + (i * slot_size) as u64); + } + + let inner = AtomicRefCell::new(RecyclePoolInner { + base_addr, + slot_size, + count, + free, + }); + + Ok(Self { + inner: inner.into(), + }) + } + + /// Number of free slots. + pub fn num_free(&self) -> usize { + self.inner.borrow().free.len() + } +} + +impl BufferProvider for RecyclePool { + fn alloc(&self, len: usize) -> Result { + let mut inner = self.inner.borrow_mut(); + if len > inner.slot_size { + return Err(AllocError::OutOfMemory); + } + + let addr = inner.free.pop().ok_or(AllocError::OutOfMemory)?; + + Ok(Allocation { + addr, + len: inner.slot_size, + }) + } + + fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { + let mut inner = self.inner.borrow_mut(); + inner.free.push(alloc.addr); + Ok(()) + } + + fn resize(&self, old: Allocation, new_len: usize) -> Result { + let inner = self.inner.borrow(); + if new_len > inner.slot_size { + return Err(AllocError::OutOfMemory); + } + Ok(old) + } + + fn reset(&self) { + let mut inner = self.inner.borrow_mut(); + let base = inner.base_addr; + let slot = inner.slot_size; + let count = inner.count; + + inner.free.clear(); + + for i in 0..count { + inner.free.push(base + (i * slot) as u64); + } + } +} diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index e83639fc9..e02c7ef52 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -16,7 +16,6 @@ limitations under the License. //! Guest virtqueue context. 
-use alloc::sync::Arc; use alloc::vec::Vec; use core::num::NonZeroU16; use core::sync::atomic::AtomicU16; @@ -28,8 +27,10 @@ use hyperlight_common::flatbuffer_wrappers::function_types::{ FunctionCallResult, ParameterValue, ReturnType, ReturnValue, }; use hyperlight_common::flatbuffer_wrappers::util::estimate_flatbuffer_capacity; +use hyperlight_common::mem::PAGE_SIZE_USIZE; use hyperlight_common::outb::OutBAction; use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; +use hyperlight_common::virtq::recycle_pool::RecyclePool; use hyperlight_common::virtq::{BufferPool, Layout, Notifier, QueueStats, VirtqProducer}; use super::GuestMemOps; @@ -50,41 +51,58 @@ impl Notifier for GuestNotifier { } /// Type alias for the guest-side G2H producer. -pub type G2hProducer = VirtqProducer>; +pub type G2hProducer = VirtqProducer; + +/// Type alias for the guest-side H2G producer (uses fixed-size RecyclePool slots). +pub type H2gProducer = VirtqProducer; + +/// Configuration for one queue passed to [`GuestContext::new`]. +pub struct QueueConfig { + /// Ring descriptor layout in shared memory. + pub layout: Layout, + /// Base GVA of the buffer pool region. + pub pool_gva: u64, + /// Number of pages in the buffer pool. + pub pool_pages: usize, +} /// Virtqueue runtime state for guest-host communication. pub struct GuestContext { - g2h_pool: Arc, g2h_producer: G2hProducer, + h2g_producer: H2gProducer, generation: u16, } impl GuestContext { - /// Create a new context with a G2H queue. - /// - /// # Safety - /// - /// `ring_gva` must point to valid, zeroed ring memory. - /// `pool_gva` must point to valid, zeroed memory. 
- pub unsafe fn new( - ring_gva: u64, - num_descs: u16, - pool_gva: u64, - pool_size: usize, - generation: u16, - ) -> Self { - let pool = Arc::new( - BufferPool::new(pool_gva, pool_size).expect("failed to create G2H buffer pool"), - ); - let nz = NonZeroU16::new(num_descs).expect("G2H queue depth must be non-zero"); - let layout = unsafe { Layout::from_base(ring_gva, nz) }.expect("invalid G2H ring layout"); - let producer = VirtqProducer::new(layout, GuestMemOps, GuestNotifier, pool.clone()); + /// Create a new context with G2H and H2G queues. + pub fn new(g2h: QueueConfig, h2g: QueueConfig, generation: u16) -> Self { + let size = g2h.pool_pages * PAGE_SIZE_USIZE; + let g2h_pool = + BufferPool::new(g2h.pool_gva, size).expect("failed to create G2H buffer pool"); + let g2h_producer = + VirtqProducer::new(g2h.layout, GuestMemOps, GuestNotifier, g2h_pool.clone()); - Self { - g2h_pool: pool, - g2h_producer: producer, + // Each H2G prefill entry is a single descriptor with one contiguous buffer: one + // fixed-size buffer per descriptor, large payloads split across multiple independent + // completions. + // + // TODO(virtq): consider smaller slot_size (e.g. pool_size / desc_count) to maximize + // prefilled entries for host-side call batching. + let size = h2g.pool_pages * PAGE_SIZE_USIZE; + let slot = PAGE_SIZE_USIZE; + let h2g_pool = + RecyclePool::new(h2g.pool_gva, size, slot).expect("failed to create H2G recycle pool"); + let h2g_producer = + VirtqProducer::new(h2g.layout, GuestMemOps, GuestNotifier, h2g_pool.clone()); + + let mut ctx = Self { + g2h_producer, + h2g_producer, generation, - } + }; + + ctx.prefill_h2g(); + ctx } /// Call a host function via the G2H virtqueue. @@ -146,10 +164,96 @@ impl GuestContext { Ok(ret) } + /// Pre-fill the H2G queue with completion-only descriptors so the host + /// can write incoming call payloads into them. 
+ fn prefill_h2g(&mut self) { + loop { + let entry = match self + .h2g_producer + .chain() + .completion(PAGE_SIZE_USIZE) + .build() + { + Ok(e) => e, + Err(_) => break, + }; + if self.h2g_producer.submit(entry).is_err() { + break; + } + } + } + + /// Receive a host-to-guest function call from the H2G queue. + pub fn recv_h2g_call(&mut self) -> Result { + let Some(completion) = self.h2g_producer.poll()? else { + bail!("H2G: no pending call"); + }; + + let data = &completion.data; + if data.len() < VirtqMsgHeader::SIZE { + bail!("H2G: completion too short for header"); + } + + let hdr: &VirtqMsgHeader = bytemuck::from_bytes(&data[..VirtqMsgHeader::SIZE]); + + if hdr.kind != MsgKind::Request as u8 { + bail!("H2G: unexpected message kind"); + } + + let payload_end = VirtqMsgHeader::SIZE + hdr.payload_len as usize; + if payload_end > data.len() { + bail!("H2G: payload length exceeds completion data"); + } + + let payload = &data[VirtqMsgHeader::SIZE..payload_end]; + let fc = FunctionCall::try_from(payload)?; + Ok(fc) + } + + /// Send the result of a host-to-guest call back to the host via the + /// G2H queue, then refill one H2G descriptor slot. + pub fn send_h2g_result(&mut self, payload: &[u8]) -> Result<()> { + // Build a Response message on the G2H queue + let reqid = REQUEST_ID.fetch_add(1, Relaxed); + let hdr = VirtqMsgHeader::new(MsgKind::Response, reqid, payload.len() as u32); + let hdr_bytes = bytemuck::bytes_of(&hdr); + + let entry_len = VirtqMsgHeader::SIZE + payload.len(); + let mut entry = self.g2h_producer.chain().entry(entry_len).build()?; + + entry.write_all(hdr_bytes)?; + entry.write_all(payload)?; + self.g2h_producer.submit(entry)?; + + // Refill one H2G completion slot + if let Ok(e) = self + .h2g_producer + .chain() + .completion(PAGE_SIZE_USIZE) + .build() + { + let _ = self.h2g_producer.submit(e); + } + + Ok(()) + } + + /// Drain any pending G2H completions (discard them). 
+ /// + /// This is called before checking for H2G calls so that the host + /// can reclaim G2H response buffers. + pub fn drain_g2h_completions(&mut self) { + while let Ok(Some(_)) = self.g2h_producer.poll() {} + } + /// Reset ring and pool state after snapshot restore. pub(super) fn reset(&mut self, new_generation: u16) { + // G2H producer reset also resets the pool via BufferProvider::reset() self.g2h_producer.reset(); - self.g2h_pool.reset(); + // H2G state is NOT reset. The guest's inflight and cursors + // survived via CoW and are already correct. The host's + // restore_h2g_prefill() wrote matching descriptors to the + // zeroed ring memory. Both sides are in sync. self.generation = new_generation; } diff --git a/src/hyperlight_guest/src/virtq/mod.rs b/src/hyperlight_guest/src/virtq/mod.rs index 29f404740..9bb4d2348 100644 --- a/src/hyperlight_guest/src/virtq/mod.rs +++ b/src/hyperlight_guest/src/virtq/mod.rs @@ -26,7 +26,7 @@ use core::cell::UnsafeCell; use core::sync::atomic::{AtomicU8, Ordering}; use context::GuestContext; -use hyperlight_common::layout::{SCRATCH_TOP_VIRTQ_GENERATION_OFFSET, scratch_top_ptr}; +use hyperlight_common::layout::{SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET, scratch_top_ptr}; pub use mem::GuestMemOps; // Init state machine @@ -81,19 +81,16 @@ pub fn set_global_context(ctx: GuestContext) { /// Reset the global context if a snapshot restore was detected. /// Compares the virtq generation counter in scratch-top metadata. -pub fn reset_global_context() { +pub fn maybe_reset_global_context() { if !is_initialized() { return; } - let current_gen = read_gen(); + + let current_gen = unsafe { *scratch_top_ptr::(SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET) }; + with_context(|ctx| { if current_gen != ctx.generation() { ctx.reset(current_gen); } }); } - -/// Read the current virtqueue generation from scratch-top metadata. 
-fn read_gen() -> u16 { - unsafe { *scratch_top_ptr::(SCRATCH_TOP_VIRTQ_GENERATION_OFFSET) } -} diff --git a/src/hyperlight_guest_bin/src/guest_function/call.rs b/src/hyperlight_guest_bin/src/guest_function/call.rs index 37d02457b..fb71ef798 100644 --- a/src/hyperlight_guest_bin/src/guest_function/call.rs +++ b/src/hyperlight_guest_bin/src/guest_function/call.rs @@ -25,7 +25,7 @@ use hyperlight_guest::error::{HyperlightGuestError, Result}; use hyperlight_guest::virtq; use tracing::instrument; -use crate::{GUEST_HANDLE, REGISTERED_GUEST_FUNCTIONS}; +use crate::REGISTERED_GUEST_FUNCTIONS; #[instrument(skip_all, level = "Info")] pub(crate) fn call_guest_function(function_call: FunctionCall) -> Result> { @@ -86,34 +86,32 @@ pub(crate) fn internal_dispatch_function() { tracing::span!(tracing::Level::INFO, "internal_dispatch_function").entered() }; - let handle = unsafe { GUEST_HANDLE }; - // After snapshot restore, the ring memory is zeroed but the // producer's cursors are stale. Check once per dispatch entry. 
- virtq::reset_global_context(); + virtq::maybe_reset_global_context(); + virtq::with_context(|ctx| ctx.drain_g2h_completions()); - let function_call = handle - .try_pop_shared_input_data_into::() - .expect("Function call deserialization failed"); + let function_call = virtq::with_context(|ctx| { + ctx.recv_h2g_call() + .expect("H2G: expected a host-to-guest call") + }); let res = call_guest_function(function_call); - match res { - Ok(bytes) => { - handle - .push_shared_output_data(bytes.as_slice()) - .expect("Failed to serialize function call result"); - } + let res_bytes = match res { + Ok(bytes) => bytes, Err(err) => { let guest_error = Err(GuestError::new(err.kind, err.message)); let fcr = FunctionCallResult::new(guest_error); let mut builder = FlatBufferBuilder::new(); - let data = fcr.encode(&mut builder); - handle - .push_shared_output_data(data) - .expect("Failed to serialize function call result"); + fcr.encode(&mut builder).to_vec() } - } + }; + + virtq::with_context(|ctx| { + ctx.send_h2g_result(&res_bytes) + .expect("H2G: failed to send result"); + }); // All this tracing logic shall be done right before the call to `hlt` which is done after this // function returns diff --git a/src/hyperlight_guest_bin/src/virtq/mod.rs b/src/hyperlight_guest_bin/src/virtq/mod.rs index 1c1b42300..fd90240df 100644 --- a/src/hyperlight_guest_bin/src/virtq/mod.rs +++ b/src/hyperlight_guest_bin/src/virtq/mod.rs @@ -16,15 +16,18 @@ limitations under the License. //! Guest-side virtqueue initialization. 
+use core::num::NonZeroU16; + use hyperlight_common::layout::{ - SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_G2H_RING_GVA_OFFSET, - SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET, SCRATCH_TOP_H2G_RING_GVA_OFFSET, - SCRATCH_TOP_VIRTQ_GENERATION_OFFSET, SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET, scratch_top_ptr, + SCRATCH_TOP_G2H_POOL_PAGES_OFFSET, SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET, + SCRATCH_TOP_G2H_RING_GVA_OFFSET, SCRATCH_TOP_H2G_POOL_GVA_OFFSET, + SCRATCH_TOP_H2G_POOL_PAGES_OFFSET, SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET, + SCRATCH_TOP_H2G_RING_GVA_OFFSET, SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET, scratch_top_ptr, }; use hyperlight_common::mem::PAGE_SIZE_USIZE; use hyperlight_common::virtq::Layout as VirtqLayout; use hyperlight_guest::prim_alloc::alloc_phys_pages; -use hyperlight_guest::virtq::context::GuestContext; +use hyperlight_guest::virtq::context::{GuestContext, QueueConfig}; use crate::paging::phys_to_virt; @@ -34,12 +37,12 @@ pub(crate) fn init_virtqueues() { let g2h_depth = unsafe { *scratch_top_ptr::(SCRATCH_TOP_G2H_QUEUE_DEPTH_OFFSET) }; let h2g_gva = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_RING_GVA_OFFSET) }; let h2g_depth = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET) }; - let pool_pages = unsafe { *scratch_top_ptr::(SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET) } as u64; - let generation = unsafe { *scratch_top_ptr::(SCRATCH_TOP_VIRTQ_GENERATION_OFFSET) }; + let g2h_pages = unsafe { *scratch_top_ptr::(SCRATCH_TOP_G2H_POOL_PAGES_OFFSET) } as usize; + let h2g_pages = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_POOL_PAGES_OFFSET) } as usize; + let generation = unsafe { *scratch_top_ptr::(SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET) }; - assert!(g2h_depth > 0 && h2g_depth > 0); + assert!(g2h_depth > 0 && h2g_depth > 0 && g2h_pages > 0 && h2g_pages > 0); assert!(g2h_gva != 0 && h2g_gva != 0); - assert!(pool_pages > 0); // Zero ring memory let g2h_ring_size = VirtqLayout::query_size(g2h_depth as usize); @@ -48,16 +51,41 @@ pub(crate) fn init_virtqueues() { let 
h2g_ring_size = VirtqLayout::query_size(h2g_depth as usize); unsafe { core::ptr::write_bytes(h2g_gva as *mut u8, 0, h2g_ring_size) }; - // Allocate buffer pool from physical pages - let pool_gpa = unsafe { alloc_phys_pages(pool_pages) }; - let pool_ptr = phys_to_virt(pool_gpa).expect("failed to map pool pages"); - let pool_gva = pool_ptr as u64; - let pool_size = pool_pages as usize * PAGE_SIZE_USIZE; - unsafe { core::ptr::write_bytes(pool_ptr, 0, pool_size) }; + // Build ring layouts + let nz = NonZeroU16::new(g2h_depth).expect("G2H depth zero"); + let g2h_layout = unsafe { VirtqLayout::from_base(g2h_gva, nz) }.expect("invalid layout"); + + let nz = NonZeroU16::new(h2g_depth).expect("H2G depth zero"); + let h2g_layout = unsafe { VirtqLayout::from_base(h2g_gva, nz) }.expect("invalid layout"); + + // Allocate buffer pools + let g2h_pool_gva = alloc_pool(g2h_pages); + let h2g_pool_gva = alloc_pool(h2g_pages); - // Create and install global context - let ctx = unsafe { GuestContext::new(g2h_gva, g2h_depth, pool_gva, pool_size, generation) }; + // Publish H2G pool GVA so the host can prefill after restore + unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_POOL_GVA_OFFSET) = h2g_pool_gva }; + + let ctx = GuestContext::new( + QueueConfig { + layout: g2h_layout, + pool_gva: g2h_pool_gva, + pool_pages: g2h_pages, + }, + QueueConfig { + layout: h2g_layout, + pool_gva: h2g_pool_gva, + pool_pages: h2g_pages, + }, + generation, + ); hyperlight_guest::virtq::set_global_context(ctx); +} - let _ = (h2g_gva, h2g_depth); +/// Allocate and zero `n` physical pages, returning the GVA. 
+fn alloc_pool(n: usize) -> u64 { + let gpa = unsafe { alloc_phys_pages(n as u64) }; + let ptr = phys_to_virt(gpa).expect("failed to map pool pages"); + let size = n as usize * PAGE_SIZE_USIZE; + unsafe { core::ptr::write_bytes(ptr, 0, size) }; + ptr as u64 } diff --git a/src/hyperlight_host/src/mem/layout.rs b/src/hyperlight_host/src/mem/layout.rs index cee44b94b..ccf842268 100644 --- a/src/hyperlight_host/src/mem/layout.rs +++ b/src/hyperlight_host/src/mem/layout.rs @@ -495,6 +495,7 @@ impl SandboxMemoryLayout { } /// Get the size of the G2H ring in bytes. + #[allow(dead_code)] fn get_g2h_ring_size(&self) -> usize { hyperlight_common::virtq::Layout::query_size( self.sandbox_memory_config.get_g2h_queue_depth(), diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index f46361dfa..c65bab934 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -21,23 +21,11 @@ use flatbuffers::FlatBufferBuilder; use hyperlight_common::flatbuffer_wrappers::function_call::{ FunctionCall, validate_guest_function_call_buffer, }; - -use super::virtq_mem::HostMemOps; - -/// No-op notifier for host-side consumer. -/// The host resumes the VM to notify the guest, not via the ring. -#[derive(Clone, Copy)] -pub(crate) struct HostNotifier; - -impl hyperlight_common::virtq::Notifier for HostNotifier { - fn notify(&self, _stats: hyperlight_common::virtq::QueueStats) {} -} - -/// Type alias for the host-side G2H virtqueue consumer. 
-pub(crate) type G2hConsumer = hyperlight_common::virtq::VirtqConsumer; use hyperlight_common::flatbuffer_wrappers::function_types::FunctionCallResult; use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; -use hyperlight_common::virtq::Layout as VirtqLayout; +use hyperlight_common::mem::PAGE_SIZE_USIZE; +use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; +use hyperlight_common::virtq::{self, Layout as VirtqLayout}; use hyperlight_common::vmem::{self, PAGE_TABLE_SIZE, PageTableEntry, PhysAddr}; #[cfg(all(feature = "crashdump", not(feature = "nanvix-unstable")))] use hyperlight_common::vmem::{BasicMapping, MappingKind}; @@ -47,6 +35,7 @@ use super::layout::SandboxMemoryLayout; use super::shared_mem::{ ExclusiveSharedMemory, GuestSharedMemory, HostSharedMemory, ReadonlySharedMemory, SharedMemory, }; +use super::virtq_mem::HostMemOps; use crate::hypervisor::regs::CommonSpecialRegisters; use crate::mem::memory_region::MemoryRegion; #[cfg(crashdump)] @@ -54,6 +43,20 @@ use crate::mem::memory_region::{CrashDumpRegion, MemoryRegionFlags, MemoryRegion use crate::sandbox::snapshot::{NextAction, Snapshot}; use crate::{Result, new_error}; +/// Type alias for the host-side G2H virtqueue consumer. +pub(crate) type G2hConsumer = virtq::VirtqConsumer; +/// Type alias for the host-side H2G virtqueue consumer. +pub(crate) type H2gConsumer = virtq::VirtqConsumer; + +/// No-op notifier for host-side consumer. +/// The host resumes the VM to notify the guest, not via the ring. +#[derive(Clone, Copy)] +pub(crate) struct HostNotifier; + +impl virtq::Notifier for HostNotifier { + fn notify(&self, _stats: virtq::QueueStats) {} +} + #[cfg(all(feature = "crashdump", not(feature = "nanvix-unstable")))] fn mapping_kind_to_flags(kind: &MappingKind) -> (MemoryRegionFlags, MemoryRegionType) { match kind { @@ -165,9 +168,15 @@ pub(crate) struct SandboxMemoryManager { pub(crate) abort_buffer: Vec, /// G2H virtqueue consumer, created after sandbox init. 
pub(crate) g2h_consumer: Option, + /// H2G virtqueue consumer, created after sandbox init. + pub(crate) h2g_consumer: Option, + /// Saved H2G pool GVA for prefilling after snapshot restore. + pub(crate) h2g_pool_gva: Option, + /// Monotonically increasing snapshot generation counter. + snapshot_generation: u16, } -impl Clone for SandboxMemoryManager { +impl Clone for SandboxMemoryManager { fn clone(&self) -> Self { Self { shared_mem: self.shared_mem.clone(), @@ -176,7 +185,10 @@ impl Clone for SandboxMemoryManager { entrypoint: self.entrypoint, mapped_rgns: self.mapped_rgns, abort_buffer: self.abort_buffer.clone(), - g2h_consumer: None, // consumer is not cloned; re-init if needed + g2h_consumer: None, + h2g_consumer: None, + h2g_pool_gva: self.h2g_pool_gva, + snapshot_generation: self.snapshot_generation, } } } @@ -289,6 +301,9 @@ where mapped_rgns: 0, abort_buffer: Vec::new(), g2h_consumer: None, + h2g_consumer: None, + h2g_pool_gva: None, + snapshot_generation: 0, } } @@ -357,6 +372,9 @@ impl SandboxMemoryManager { mapped_rgns: self.mapped_rgns, abort_buffer: self.abort_buffer, g2h_consumer: None, + h2g_consumer: None, + h2g_pool_gva: None, + snapshot_generation: 0, }; let guest_mgr = SandboxMemoryManager { shared_mem: gshm, @@ -366,9 +384,13 @@ impl SandboxMemoryManager { mapped_rgns: self.mapped_rgns, abort_buffer: Vec::new(), // Guest doesn't need abort buffer g2h_consumer: None, + h2g_consumer: None, + h2g_pool_gva: None, + snapshot_generation: 0, }; host_mgr.update_scratch_bookkeeping()?; host_mgr.init_g2h_consumer()?; + host_mgr.init_h2g_consumer()?; Ok((host_mgr, guest_mgr)) } } @@ -462,6 +484,7 @@ impl SandboxMemoryManager { /// Writes a guest function call to memory #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + #[allow(dead_code)] pub(crate) fn write_guest_function_call(&mut self, buffer: &[u8]) -> Result<()> { validate_guest_function_call_buffer(buffer).map_err(|e| { new_error!( @@ -480,6 +503,7 @@ impl 
SandboxMemoryManager { /// Reads a function call result from memory. /// A function call result can be either an error or a successful return value. + #[allow(dead_code)] #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] pub(crate) fn get_guest_function_call_result(&mut self) -> Result { self.scratch_mem.try_pop_buffer_into::( @@ -560,6 +584,8 @@ impl SandboxMemoryManager { self.layout = *snapshot.layout(); self.update_scratch_bookkeeping()?; self.init_g2h_consumer()?; + self.init_h2g_consumer()?; + self.restore_h2g_prefill()?; Ok((gsnapshot, gscratch)) } @@ -609,14 +635,18 @@ impl SandboxMemoryManager { self.layout.sandbox_memory_config.get_h2g_queue_depth() as u16, )?; self.scratch_mem.write::( - scratch_size - SCRATCH_TOP_VIRTQ_POOL_PAGES_OFFSET as usize, - self.layout.sandbox_memory_config.get_virtq_pool_pages() as u16, + scratch_size - SCRATCH_TOP_G2H_POOL_PAGES_OFFSET as usize, + self.layout.sandbox_memory_config.get_g2h_pool_pages() as u16, + )?; + self.scratch_mem.write::( + scratch_size - SCRATCH_TOP_H2G_POOL_PAGES_OFFSET as usize, + self.layout.sandbox_memory_config.get_h2g_pool_pages() as u16, )?; // Increment generation so the guest detects stale ring state. - let gen_offset = scratch_size - SCRATCH_TOP_VIRTQ_GENERATION_OFFSET as usize; - let gen_val: u16 = self.scratch_mem.read(gen_offset).unwrap_or(0); + self.snapshot_generation = self.snapshot_generation.wrapping_add(1); + let gen_offset = scratch_size - SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET as usize; self.scratch_mem - .write::(gen_offset, gen_val.wrapping_add(1))?; + .write::(gen_offset, self.snapshot_generation)?; // Copy the page tables into the scratch region let snapshot_pt_end = self.shared_mem.mem_size(); @@ -859,7 +889,7 @@ impl SandboxMemoryManager { } /// Compute the G2H virtqueue Layout from scratch region addresses. 
- pub(crate) fn g2h_virtq_layout(&self) -> Result { + pub(crate) fn g2h_virtq_layout(&self) -> Result { let base = self.layout.get_g2h_ring_gva(); let depth = self.layout.sandbox_memory_config.get_g2h_queue_depth() as u16; @@ -870,7 +900,7 @@ impl SandboxMemoryManager { } /// Compute the H2G virtqueue Layout from scratch region addresses. - pub(crate) fn h2g_virtq_layout(&self) -> Result { + pub(crate) fn h2g_virtq_layout(&self) -> Result { let base = self.layout.get_h2g_ring_gva(); let depth = self.layout.sandbox_memory_config.get_h2g_queue_depth() as u16; @@ -898,13 +928,152 @@ impl SandboxMemoryManager { None => { let layout = self.g2h_virtq_layout()?; let mem_ops = self.host_mem_ops(); - let consumer = - hyperlight_common::virtq::VirtqConsumer::new(layout, mem_ops, HostNotifier); + let consumer = virtq::VirtqConsumer::new(layout, mem_ops, HostNotifier); self.g2h_consumer = Some(consumer); } } Ok(()) } + + /// Initialize the H2G virtqueue consumer. + /// + /// Must be called after scratch bookkeeping is written. Avail suppression is set to Disable + /// so guest prefill/refill operations do not trigger VM exits. + pub(crate) fn init_h2g_consumer(&mut self) -> Result<()> { + match &mut self.h2g_consumer { + Some(consumer) => { + consumer.reset(); + consumer + .set_avail_suppression(virtq::SuppressionKind::Disable) + .map_err(|e| new_error!("H2G avail suppression: {:?}", e))?; + } + None => { + let layout = self.h2g_virtq_layout()?; + let mem_ops = self.host_mem_ops(); + let mut consumer = virtq::VirtqConsumer::new(layout, mem_ops, HostNotifier); + consumer + .set_avail_suppression(virtq::SuppressionKind::Disable) + .map_err(|e| new_error!("H2G avail suppression: {:?}", e))?; + self.h2g_consumer = Some(consumer); + } + } + Ok(()) + } + + /// Prefill the H2G ring with writable descriptors after snapshot restore. + /// + /// Uses a temporary `RingProducer` to write descriptors into the H2G ring + /// so the host consumer can poll them. 
The guest's `restore_from_ring` + /// will later reconstruct its inflight state from these descriptors. + pub(crate) fn restore_h2g_prefill(&mut self) -> Result<()> { + let pool_gva = match self.h2g_pool_gva { + Some(gva) => gva, + None => return Ok(()), + }; + + let layout = self.h2g_virtq_layout()?; + let mem_ops = self.host_mem_ops(); + let h2g_depth = self.layout.sandbox_memory_config.get_h2g_queue_depth(); + + // Pool size from config + let slot_size = PAGE_SIZE_USIZE; + let pool_size = self.layout.sandbox_memory_config.get_h2g_pool_pages() * PAGE_SIZE_USIZE; + let slot_count = pool_size / slot_size; + + let mut producer = virtq::RingProducer::new(layout, mem_ops); + let prefill_count = core::cmp::min(slot_count, h2g_depth); + + // Write descriptors in reverse order to match the guest's LIFO + // allocation pattern (RecyclePool::alloc pops from the end of + // the free list, so the first prefill gets the highest address). + for i in (0..prefill_count).rev() { + let addr = pool_gva + (i * slot_size) as u64; + producer + .submit_one(addr, slot_size as u32, true) + .map_err(|e| new_error!("H2G prefill submit: {:?}", e))?; + } + + Ok(()) + } + + /// Write a guest function call into the H2G virtqueue. + /// + /// Polls the H2G consumer for a prefilled entry from the guest, + /// writes `VirtqMsgHeader::Request` followed by `buffer` into the + /// writable completion, and completes the entry. + pub(crate) fn write_guest_function_call_virtq(&mut self, buffer: &[u8]) -> Result<()> { + let consumer = self + .h2g_consumer + .as_mut() + .ok_or_else(|| new_error!("H2G consumer not initialized"))?; + + let (entry, completion) = consumer + .poll(8192) + .map_err(|e| new_error!("H2G poll: {:?}", e))? 
+ .ok_or_else(|| new_error!("H2G: no prefilled entry available"))?; + + // Consume the entry data - this should be empty + drop(entry); + + let header = VirtqMsgHeader::new(MsgKind::Request, 0, buffer.len() as u32); + + let virtq::SendCompletion::Writable(mut wc) = completion else { + return Err(new_error!( + "H2G: expected writable completion, got non-writable (ring corruption)" + )); + }; + + wc.write_all(bytemuck::bytes_of(&header)) + .map_err(|e| new_error!("H2G write header: {:?}", e))?; + wc.write_all(buffer) + .map_err(|e| new_error!("H2G write payload: {:?}", e))?; + + consumer + .complete(wc.into()) + .map_err(|e| new_error!("H2G complete: {:?}", e))?; + + Ok(()) + } + + /// Read the H2G result from G2H after the guest halts. + /// + /// The guest submitted the Response on G2H with + pub(crate) fn read_h2g_result_from_g2h(&mut self) -> Result { + let consumer = self + .g2h_consumer + .as_mut() + .ok_or_else(|| new_error!("G2H consumer not initialized"))?; + + let Some((entry, completion)) = consumer + .poll(8192) + .map_err(|e| new_error!("G2H poll for H2G result: {:?}", e))? 
+ else { + return Err(new_error!("G2H: no H2G result entry after halt")); + }; + + let entry_data = entry.data(); + if entry_data.len() < VirtqMsgHeader::SIZE { + return Err(new_error!("G2H: result entry too short")); + } + + let hdr: &VirtqMsgHeader = bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); + if hdr.kind != MsgKind::Response as u8 { + return Err(new_error!( + "G2H: expected Response after halt, got kind={}", + hdr.kind + )); + } + + let payload = &entry_data[VirtqMsgHeader::SIZE..]; + let fcr = FunctionCallResult::try_from(payload) + .map_err(|e| new_error!("G2H: malformed FunctionCallResult: {}", e))?; + + consumer + .complete(completion) + .map_err(|e| new_error!("G2H complete: {:?}", e))?; + + Ok(fcr) + } } #[cfg(test)] diff --git a/src/hyperlight_host/src/sandbox/config.rs b/src/hyperlight_host/src/sandbox/config.rs index a329e5fd5..b3e5fd6d3 100644 --- a/src/hyperlight_host/src/sandbox/config.rs +++ b/src/hyperlight_host/src/sandbox/config.rs @@ -80,9 +80,14 @@ pub struct SandboxConfiguration { /// Number of descriptors for the host-to-guest virtqueue. Must be a power of 2. /// Default: 32 h2g_queue_depth: usize, - /// Number of physical pages to allocate for each virtqueue's buffer pool. + /// Number of physical pages for the G2H (guest-to-host) buffer pool. + /// If not set, derived from `input_data_size` for backward compatibility. /// Default: 8 pages (32KB). - virtq_pool_pages: usize, + g2h_pool_pages: Option, + /// Number of physical pages for the H2G (host-to-guest) buffer pool. + /// If not set, derived from `output_data_size` for backward compatibility. + /// Default: 4 page (16KB). 
+ h2g_pool_pages: Option, } impl SandboxConfiguration { @@ -106,8 +111,10 @@ impl SandboxConfiguration { pub const DEFAULT_G2H_QUEUE_DEPTH: usize = 64; /// The default H2G virtqueue depth (number of descriptors, must be power of 2) pub const DEFAULT_H2G_QUEUE_DEPTH: usize = 32; - /// The default number of physical pages per virtqueue buffer pool - pub const DEFAULT_VIRTQ_POOL_PAGES: usize = 8; + /// The default number of G2H buffer pool pages + pub const DEFAULT_G2H_POOL_PAGES: usize = 8; + /// The default number of H2G buffer pool pages + pub const DEFAULT_H2G_POOL_PAGES: usize = 4; #[allow(clippy::too_many_arguments)] /// Create a new configuration for a sandbox with the given sizes. @@ -131,7 +138,8 @@ impl SandboxConfiguration { interrupt_vcpu_sigrtmin_offset, g2h_queue_depth: Self::DEFAULT_G2H_QUEUE_DEPTH, h2g_queue_depth: Self::DEFAULT_H2G_QUEUE_DEPTH, - virtq_pool_pages: Self::DEFAULT_VIRTQ_POOL_PAGES, + g2h_pool_pages: None, + h2g_pool_pages: None, #[cfg(gdb)] guest_debug_info, #[cfg(crashdump)] @@ -139,15 +147,21 @@ impl SandboxConfiguration { } } - /// Set the size of the memory buffer that is made available for input to the guest - /// the minimum value is MIN_INPUT_SIZE + /// Set the size of the legacy input data buffer (host-to-guest). + /// + /// Deprecated: use [`set_h2g_pool_pages`](Self::set_h2g_pool_pages) instead. + /// When `h2g_pool_pages` is not set, the H2G pool size is derived + /// from this value for backward compatibility. #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn set_input_data_size(&mut self, input_data_size: usize) { self.input_data_size = max(input_data_size, Self::MIN_INPUT_SIZE); } - /// Set the size of the memory buffer that is made available for output from the guest - /// the minimum value is MIN_OUTPUT_SIZE + /// Set the size of the legacy output data buffer (guest-to-host). + /// + /// Deprecated: use [`set_g2h_pool_pages`](Self::set_g2h_pool_pages) instead. 
+ /// When `g2h_pool_pages` is not set, the G2H pool size is derived + /// from this value for backward compatibility. #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn set_output_data_size(&mut self, output_data_size: usize) { self.output_data_size = max(output_data_size, Self::MIN_OUTPUT_SIZE); @@ -228,33 +242,65 @@ impl SandboxConfiguration { } /// Get the G2H virtqueue depth (number of descriptors). + #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn get_g2h_queue_depth(&self) -> usize { self.g2h_queue_depth } /// Get the H2G virtqueue depth (number of descriptors). + #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn get_h2g_queue_depth(&self) -> usize { self.h2g_queue_depth } - /// Get the number of physical pages per virtqueue buffer pool. - pub fn get_virtq_pool_pages(&self) -> usize { - self.virtq_pool_pages - } - /// Set the G2H virtqueue depth (number of descriptors, must be power of 2). + #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn set_g2h_queue_depth(&mut self, depth: usize) { self.g2h_queue_depth = depth; } /// Set the H2G virtqueue depth (number of descriptors, must be power of 2). + #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn set_h2g_queue_depth(&mut self, depth: usize) { self.h2g_queue_depth = depth; } - /// Set the number of physical pages per virtqueue buffer pool. - pub fn set_virtq_pool_pages(&mut self, pages: usize) { - self.virtq_pool_pages = pages; + /// Get the number of G2H buffer pool pages. + /// Falls back to deriving from `output_data_size` if not explicitly set + /// (output = guest-to-host direction). 
+ #[instrument(skip_all, parent = Span::current(), level= "Trace")] + pub fn get_g2h_pool_pages(&self) -> usize { + self.g2h_pool_pages.unwrap_or_else(|| { + let pages = self + .output_data_size + .div_ceil(hyperlight_common::mem::PAGE_SIZE_USIZE); + pages.max(Self::DEFAULT_G2H_POOL_PAGES) + }) + } + + /// Get the number of H2G buffer pool pages. + /// Falls back to deriving from `input_data_size` if not explicitly set + /// (input = host-to-guest direction). + #[instrument(skip_all, parent = Span::current(), level= "Trace")] + pub fn get_h2g_pool_pages(&self) -> usize { + self.h2g_pool_pages.unwrap_or_else(|| { + let pages = self + .input_data_size + .div_ceil(hyperlight_common::mem::PAGE_SIZE_USIZE); + pages.max(Self::DEFAULT_H2G_POOL_PAGES) + }) + } + + /// Set the number of G2H buffer pool pages. + #[instrument(skip_all, parent = Span::current(), level= "Trace")] + pub fn set_g2h_pool_pages(&mut self, pages: usize) { + self.g2h_pool_pages = Some(pages); + } + + /// Set the number of H2G buffer pool pages. 
+ #[instrument(skip_all, parent = Span::current(), level= "Trace")] + pub fn set_h2g_pool_pages(&mut self, pages: usize) { + self.h2g_pool_pages = Some(pages); } /// Set the size of the scratch regiong diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index 642fb2772..270b7460c 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -700,7 +700,7 @@ impl MultiUseSandbox { let mut builder = FlatBufferBuilder::with_capacity(estimated_capacity); let buffer = fc.encode(&mut builder); - self.mem_mgr.write_guest_function_call(buffer)?; + self.mem_mgr.write_guest_function_call_virtq(buffer)?; let dispatch_res = self.vm.dispatch_call_from_host( &mut self.mem_mgr, @@ -717,7 +717,7 @@ impl MultiUseSandbox { return Err(error); } - let guest_result = self.mem_mgr.get_guest_function_call_result()?.into_inner(); + let guest_result = self.mem_mgr.read_h2g_result_from_g2h()?.into_inner(); match guest_result { Ok(val) => Ok(val), diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs index aa40bec3d..0e11409ad 100644 --- a/src/hyperlight_host/src/sandbox/outb.rs +++ b/src/hyperlight_host/src/sandbox/outb.rs @@ -21,6 +21,8 @@ use hyperlight_common::flatbuffer_wrappers::function_types::{FunctionCallResult, use hyperlight_common::flatbuffer_wrappers::guest_error::{ErrorCode, GuestError}; use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; use hyperlight_common::outb::{Exception, OutBAction}; +use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; +use hyperlight_common::virtq::{self}; use log::{Level, Record}; use tracing::{Span, instrument}; use tracing_log::format_trace; @@ -186,29 +188,39 @@ fn outb_virtq_call( mem_mgr: &mut SandboxMemoryManager, host_funcs: &Arc>, ) -> Result<(), HandleOutbError> { - use hyperlight_common::virtq::msg::{MsgKind, 
VirtqMsgHeader}; - let consumer = mem_mgr.g2h_consumer.as_mut().ok_or_else(|| { HandleOutbError::ReadHostFunctionCall("G2H consumer not initialized".into()) })?; - let (entry, completion) = consumer + let Some((entry, completion)) = consumer .poll(8192) .map_err(|e| HandleOutbError::ReadHostFunctionCall(format!("G2H poll: {e}")))? - .ok_or_else(|| HandleOutbError::ReadHostFunctionCall("G2H poll: no entry".into()))?; + else { + // No G2H entry - can happen when guest H2G prefill + // triggers VirtqNotify before suppression is set. + return Ok(()); + }; - // Parse: skip VirtqMsgHeader, deserialize FunctionCall from remainder let entry_data = entry.data(); if entry_data.len() < VirtqMsgHeader::SIZE { return Err(HandleOutbError::ReadHostFunctionCall( "G2H entry too short".into(), )); } + let hdr: VirtqMsgHeader = *bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); let payload = &entry_data[VirtqMsgHeader::SIZE..]; + + // TODO(virtq): Only Requests (host function callbacks) arrive via outb. 
+ if hdr.kind != MsgKind::Request as u8 { + return Err(HandleOutbError::ReadHostFunctionCall(format!( + "G2H: expected Request via outb, got kind={}", + hdr.kind + ))); + } + let call = FunctionCall::try_from(payload) .map_err(|e| HandleOutbError::ReadHostFunctionCall(e.to_string()))?; - // Dispatch the host function (same as CallFunction path) let name = call.function_name.clone(); let args: Vec = call.parameters.unwrap_or(vec![]); let res = host_funcs @@ -226,22 +238,19 @@ fn outb_virtq_call( let resp_header_bytes = bytemuck::bytes_of(&resp_header); // Write response into the completion buffer - match completion { - hyperlight_common::virtq::SendCompletion::Writable(mut wc) => { - wc.write_all(resp_header_bytes) - .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; - wc.write_all(result_payload) - .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; - consumer - .complete(wc.into()) - .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; - } - hyperlight_common::virtq::SendCompletion::Ack(ack) => { - consumer - .complete(ack.into()) - .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; - } - } + let virtq::SendCompletion::Writable(mut wc) = completion else { + return Err(HandleOutbError::WriteHostFunctionResponse( + "G2H: expected writable completion, got ack (ring corruption)".into(), + )); + }; + + wc.write_all(resp_header_bytes) + .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; + wc.write_all(result_payload) + .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; + consumer + .complete(wc.into()) + .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; Ok(()) } diff --git a/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs b/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs index 428594d37..6e02cfe26 100644 --- a/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs +++ 
b/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs @@ -16,6 +16,7 @@ limitations under the License. #[cfg(gdb)] use std::sync::{Arc, Mutex}; +use hyperlight_common::layout::SCRATCH_TOP_H2G_POOL_GVA_OFFSET; use rand::RngExt; use tracing::{Span, instrument}; @@ -26,7 +27,7 @@ use crate::hypervisor::hyperlight_vm::{HyperlightVm, HyperlightVmError}; use crate::mem::exe::LoadInfo; use crate::mem::mgr::SandboxMemoryManager; use crate::mem::ptr::RawPtr; -use crate::mem::shared_mem::GuestSharedMemory; +use crate::mem::shared_mem::{GuestSharedMemory, SharedMemory}; #[cfg(gdb)] use crate::sandbox::config::DebugInfo; #[cfg(feature = "mem_profile")] @@ -131,6 +132,18 @@ pub(super) fn evolve_impl_multi_use(u_sbox: UninitializedSandbox) -> Result(offset) + && gva != 0 + { + hshm.h2g_pool_gva = Some(gva); + } + } + #[cfg(gdb)] let dbg_mem_wrapper = Arc::new(Mutex::new(hshm.clone())); diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs index 3b8e0fcb6..a2bb2d91a 100644 --- a/src/hyperlight_host/tests/integration_test.rs +++ b/src/hyperlight_host/tests/integration_test.rs @@ -547,7 +547,8 @@ fn guest_malloc_abort() { cfg.set_heap_size(heap_size); cfg.set_g2h_queue_depth(2); cfg.set_h2g_queue_depth(2); - cfg.set_virtq_pool_pages(2); + cfg.set_g2h_pool_pages(3); + cfg.set_h2g_pool_pages(1); with_rust_sandbox_cfg(cfg, |mut sbox2| { let err = sbox2 .call::( @@ -626,7 +627,8 @@ fn guest_panic_no_alloc() { cfg.set_heap_size(heap_size); cfg.set_g2h_queue_depth(2); cfg.set_h2g_queue_depth(2); - cfg.set_virtq_pool_pages(2); + cfg.set_g2h_pool_pages(3); + cfg.set_h2g_pool_pages(1); with_rust_sandbox_cfg(cfg, |mut sbox| { let res = sbox .call::( @@ -1687,7 +1689,7 @@ fn exception_handler_installation_and_validation() { #[test] fn fill_heap_and_cause_exception() { let mut cfg = SandboxConfiguration::default(); - cfg.set_virtq_pool_pages(2); + cfg.set_scratch_size(0x60000); with_rust_sandbox_cfg(cfg, |mut sandbox| { let result = 
sandbox.call::<()>("FillHeapAndCauseException", ()); From da79c03d95044d99756e36a09cb82812fb3ed9b6 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 3 Apr 2026 14:16:11 +0200 Subject: [PATCH 08/26] feat(virtq): cleanup send + sync bounds Signed-off-by: Tomasz Andrzejak --- Cargo.lock | 7 - src/hyperlight_common/Cargo.toml | 1 - src/hyperlight_common/src/virtq/access.rs | 31 ++ src/hyperlight_common/src/virtq/buffer.rs | 158 +++++++++ src/hyperlight_common/src/virtq/mod.rs | 11 +- src/hyperlight_common/src/virtq/pool.rs | 312 ++++++++---------- src/hyperlight_common/src/virtq/producer.rs | 8 +- .../src/virtq/recycle_pool.rs | 120 ------- src/hyperlight_common/src/virtq/ring.rs | 47 +-- src/hyperlight_guest/src/virtq/context.rs | 5 +- src/tests/rust_guests/dummyguest/Cargo.lock | 7 - src/tests/rust_guests/simpleguest/Cargo.lock | 7 - src/tests/rust_guests/witguest/Cargo.lock | 7 - 13 files changed, 360 insertions(+), 361 deletions(-) create mode 100644 src/hyperlight_common/src/virtq/buffer.rs delete mode 100644 src/hyperlight_common/src/virtq/recycle_pool.rs diff --git a/Cargo.lock b/Cargo.lock index 3286d2fa0..994888e21 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -169,12 +169,6 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" -[[package]] -name = "atomic_refcell" -version = "0.1.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41e67cd8309bbd06cd603a9e693a784ac2e5d1e955f11286e355089fcab3047c" - [[package]] name = "autocfg" version = "1.5.0" @@ -1439,7 +1433,6 @@ version = "0.14.0" dependencies = [ "anyhow", "arbitrary", - "atomic_refcell", "bitflags 2.11.0", "bytemuck", "bytes", diff --git a/src/hyperlight_common/Cargo.toml b/src/hyperlight_common/Cargo.toml index c6f961b66..99d618cd0 100644 --- a/src/hyperlight_common/Cargo.toml +++ b/src/hyperlight_common/Cargo.toml @@ -17,7 +17,6 @@ workspace = true 
[dependencies] arbitrary = {version = "1.4.2", optional = true, features = ["derive"]} anyhow = { version = "1.0.102", default-features = false } -atomic_refcell = "0.1.13" bitflags = "2.10.0" bytemuck = { version = "1.24", features = ["derive"] } bytes = { version = "1", default-features = false } diff --git a/src/hyperlight_common/src/virtq/access.rs b/src/hyperlight_common/src/virtq/access.rs index 4daba3178..f569453c4 100644 --- a/src/hyperlight_common/src/virtq/access.rs +++ b/src/hyperlight_common/src/virtq/access.rs @@ -20,6 +20,8 @@ limitations under the License. //! required by the virtqueue implementation. This allows the virtqueue code to //! work with different memory backends e.g. Host vs Guest. +use alloc::sync::Arc; + use bytemuck::Pod; /// Backend-provided memory access for virtqueue. @@ -134,3 +136,32 @@ pub trait MemOps { Ok(()) } } + +impl MemOps for Arc { + type Error = T::Error; + + fn read(&self, addr: u64, dst: &mut [u8]) -> Result { + (**self).read(addr, dst) + } + + fn write(&self, addr: u64, src: &[u8]) -> Result { + (**self).write(addr, src) + } + + fn load_acquire(&self, addr: u64) -> Result { + (**self).load_acquire(addr) + } + + fn store_release(&self, addr: u64, val: u16) -> Result<(), Self::Error> { + (**self).store_release(addr, val) + } + + unsafe fn as_slice(&self, addr: u64, len: usize) -> Result<&[u8], Self::Error> { + unsafe { (**self).as_slice(addr, len) } + } + + #[allow(clippy::mut_from_ref)] + unsafe fn as_mut_slice(&self, addr: u64, len: usize) -> Result<&mut [u8], Self::Error> { + unsafe { (**self).as_mut_slice(addr, len) } + } +} diff --git a/src/hyperlight_common/src/virtq/buffer.rs b/src/hyperlight_common/src/virtq/buffer.rs new file mode 100644 index 000000000..238775c6d --- /dev/null +++ b/src/hyperlight_common/src/virtq/buffer.rs @@ -0,0 +1,158 @@ +/* +Copyright 2026 The Hyperlight Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +//! Buffer allocation traits and shared types for virtqueue buffer management. + +use alloc::rc::Rc; +use alloc::sync::Arc; + +use thiserror::Error; + +use super::access::MemOps; + +#[derive(Debug, Error, Copy, Clone)] +pub enum AllocError { + #[error("Invalid region addr {0}")] + InvalidAlign(u64), + #[error("Invalid free addr {0} and size {1}")] + InvalidFree(u64, usize), + #[error("Invalid argument")] + InvalidArg, + #[error("Empty region")] + EmptyRegion, + #[error("Out of memory")] + OutOfMemory, + #[error("Overflow")] + Overflow, +} + +/// Allocation result +#[derive(Debug, Clone, Copy)] +pub struct Allocation { + /// Starting address of the allocation + pub addr: u64, + /// Length of the allocation in bytes rounded up to slab size + pub len: usize, +} + +/// Trait for buffer providers. +pub trait BufferProvider { + /// Allocate at least `len` bytes. + fn alloc(&self, len: usize) -> Result; + + /// Free a previously allocated block. + fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError>; + + /// Resize by trying in-place grow; otherwise reserve a new block and free old. + fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result; + + /// Reset the pool to initial state. 
+ fn reset(&self) {} +} + +impl BufferProvider for Rc { + fn alloc(&self, len: usize) -> Result { + (**self).alloc(len) + } + fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { + (**self).dealloc(alloc) + } + fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { + (**self).resize(old_alloc, new_len) + } + fn reset(&self) { + (**self).reset() + } +} + +impl BufferProvider for Arc { + fn alloc(&self, len: usize) -> Result { + (**self).alloc(len) + } + fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { + (**self).dealloc(alloc) + } + fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { + (**self).resize(old_alloc, new_len) + } + fn reset(&self) { + (**self).reset() + } +} + +/// The owner of a mapped buffer, ensuring its lifetime. +/// +/// Holds a pool allocation and provides direct access to the underlying +/// shared memory via [`MemOps::as_slice`]. Implements `AsRef<[u8]>` so it +/// can be used with [`Bytes::from_owner`](bytes::Bytes::from_owner) for +/// zero-copy `Bytes` backed by shared memory. +/// +/// When dropped, the allocation is returned to the pool. +#[derive(Debug, Clone)] +pub struct BufferOwner { + pub(crate) pool: P, + pub(crate) mem: M, + pub(crate) alloc: Allocation, + pub(crate) written: usize, +} + +impl Drop for BufferOwner { + fn drop(&mut self) { + let _ = self.pool.dealloc(self.alloc); + } +} + +impl AsRef<[u8]> for BufferOwner { + fn as_ref(&self) -> &[u8] { + let len = self.written.min(self.alloc.len); + // Safety: BufferOwner keeps both the pool allocation and the M + // alive, so the memory region is valid. Protocol-level descriptor + // ownership transfer guarantees no concurrent writes. + match unsafe { self.mem.as_slice(self.alloc.addr, len) } { + Ok(slice) => slice, + Err(_) => &[], + } + } +} + +/// A guard that runs a cleanup function when dropped, unless dismissed. 
+pub struct AllocGuard(Option<(Allocation, F)>); + +impl AllocGuard { + pub fn new(alloc: Allocation, cleanup: F) -> Self { + Self(Some((alloc, cleanup))) + } + + pub fn release(mut self) -> Allocation { + self.0.take().unwrap().0 + } +} + +impl core::ops::Deref for AllocGuard { + type Target = Allocation; + + fn deref(&self) -> &Allocation { + &self.0.as_ref().unwrap().0 + } +} + +impl Drop for AllocGuard { + fn drop(&mut self) { + if let Some((alloc, cleanup)) = self.0.take() { + cleanup(alloc) + } + } +} diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs index 5e9fc7e5f..5f125b72c 100644 --- a/src/hyperlight_common/src/virtq/mod.rs +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -151,18 +151,19 @@ limitations under the License. //! ``` mod access; +mod buffer; mod consumer; mod desc; mod event; pub mod msg; mod pool; mod producer; -pub mod recycle_pool; mod ring; use core::num::NonZeroU16; pub use access::*; +pub use buffer::*; pub use consumer::*; pub use desc::*; pub use event::*; @@ -440,8 +441,8 @@ pub(crate) mod test_utils { } } - type TestProducer = VirtqProducer, TestNotifier, TestPool>; - type TestConsumer = VirtqConsumer, TestNotifier>; + type TestProducer = VirtqProducer; + type TestConsumer = VirtqConsumer; /// Create test infrastructure: a producer, consumer, and notifier backed /// by the supplied [`OwnedRing`]. @@ -474,7 +475,7 @@ mod tests { /// Helper: build and submit an entry+completion chain using the chain() builder. 
fn send_readwrite( - producer: &mut VirtqProducer, TestNotifier, TestPool>, + producer: &mut VirtqProducer, entry_data: &[u8], cqe_cap: usize, ) -> Token { @@ -957,7 +958,7 @@ mod fuzz { } } - impl MemOps for Arc { + impl MemOps for LoomMem { type Error = MemErr; fn read(&self, addr: u64, dst: &mut [u8]) -> Result { diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs index cf0915fdf..42325a56c 100644 --- a/src/hyperlight_common/src/virtq/pool.rs +++ b/src/hyperlight_common/src/virtq/pool.rs @@ -13,50 +13,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -//! Simple bitmap-based allocator for virtio buffer management. +//! Buffer pool implementations for virtqueue buffer management. //! -//! This module provides two layers: +//! This module provides concrete buffer allocators: //! -//! - [`Slab`] - a fixed-size region allocator with a power-of-two slot size `N`, -//! backed by a flat bitmap (`FixedBitSet`). -//! - [`BufferPool`] - a two-tier pool that composes two slabs: one with small -//! slots (e.g. 256 bytes) for control messages / small descriptors, and one -//! with page-sized slots (e.g. 4 KiB) for data buffers. +//! - [`BufferPool`] - a two-tier bitmap pool with small and large slabs, +//! intended for G2H descriptors where allocation sizes vary. +//! - [`RecyclePool`] - a fixed-size free-list recycler for H2G prefill +//! entries where all buffers are the same size. //! -//! # Design and algorithm +//! Both implement [`BufferProvider`] from the [`super::buffer`] module. +//! +//! # BufferPool design //! //! The core allocation strategy is a bitmap allocator that performs a linear //! search over the bitmap, but implemented via `fixedbitset`'s SIMD iteration -//! over zero bits. This is conceptually simpler than tree-based allocators -//! (e.g. 
linked lists or bitmaps representing a tree as in -//! ), yet for "moderate" region sizes it can -//! be faster in practice: -//! -//! - `FixedBitSet::zeroes()` and related methods use word/SIMD operations to -//! skip over runs of set bits, so the linear search is over words rather than -//! individual bits. -//! - We scan for a contiguous run of free bits corresponding to the required -//! number of slots; no auxiliary tree structure is maintained. -//! -//! The tree-based approach (bitmap encoding a tree and doing a binary search -//! in O(log(n)) time) is a natural next step if larger regions or stricter worst -//! case bounds are required; switching to such a representation should be -//! relatively straightforward since all allocation paths go through a single -//! `find_slots` function. -//! -//! # Locality characteristics +//! over zero bits. //! -//! The allocator tends to preserve spatial locality: -//! -//! - It searches from low indices upward, returning the first run of free -//! slots large enough for the request. Slots are merged if necessary. -//! - Freed runs are cached in `last_free_run` and reused eagerly, which -//! introduces a mild LIFO behavior for recently freed blocks. -//! - As a result, consecutive allocations are likely to end up in nearby slots, -//! which keeps virtqueue descriptors, control buffers, and data buffers -//! clustered in memory and helps cache performance. -//! -//! # Two-tier buffer pool +//! # Two-tier layout //! //! [`BufferPool`] divides the underlying region into two slabs with different //! slot sizes: @@ -66,159 +40,40 @@ limitations under the License. //! small structures. Small allocations first try this tier. //! - The upper tier (`Slab`, default `U = 4096`) uses page sized slots //! and is intended for larger data buffers. -//! -//! The split of the region is currently fixed at a constant fraction -//! (`LOWER_FRACTION`) for the lower slab and the remainder for the upper slab. -//! -//! Allocation policy: -//! 
-//! - Requests `<= L` bytes are first attempted in the lower slab; on -//! `OutOfMemory` they fall back to the upper slab. -//! - Larger requests go directly to the upper slab. -//! - [`BufferPool::resize`] will try to grow or shrink in place within the -//! owning slab (`Slab::resize`) but will never move allocations between -//! slabs. - -use alloc::sync::Arc; + +use alloc::rc::Rc; +use core::cell::RefCell; use core::cmp::Ordering; +use core::ops::Deref; -use atomic_refcell::AtomicRefCell; use fixedbitset::FixedBitSet; -use thiserror::Error; - -use super::access::MemOps; - -#[derive(Debug, Error, Copy, Clone)] -pub enum AllocError { - #[error("Invalid region addr {0}")] - InvalidAlign(u64), - #[error("Invalid free addr {0} and size {1}")] - InvalidFree(u64, usize), - #[error("Invalid argument")] - InvalidArg, - #[error("Empty region")] - EmptyRegion, - #[error("Out of memory")] - OutOfMemory, - #[error("Overflow")] - Overflow, -} - -/// Allocation result -#[derive(Debug, Clone, Copy)] -pub struct Allocation { - /// Starting address of the allocation - pub addr: u64, - /// Length of the allocation in bytes rounded up to slab size - pub len: usize, -} - -/// Trait for buffer providers. -pub trait BufferProvider { - /// Allocate at least `len` bytes. - fn alloc(&self, len: usize) -> Result; - - /// Free a previously allocated block. - fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError>; - - /// Resize by trying in-place grow; otherwise reserve a new block and free old. - fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result; - - /// Reset the pool to initial state. 
- fn reset(&self) {} -} - -impl BufferProvider for alloc::rc::Rc { - fn alloc(&self, len: usize) -> Result { - (**self).alloc(len) - } - fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { - (**self).dealloc(alloc) - } - fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { - (**self).resize(old_alloc, new_len) - } - fn reset(&self) { - (**self).reset() - } -} +use smallvec::SmallVec; -impl BufferProvider for Arc { - fn alloc(&self, len: usize) -> Result { - (**self).alloc(len) - } - fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { - (**self).dealloc(alloc) - } - fn resize(&self, old_alloc: Allocation, new_len: usize) -> Result { - (**self).resize(old_alloc, new_len) - } - fn reset(&self) { - (**self).reset() - } -} +use super::buffer::{AllocError, Allocation, BufferProvider}; -/// The owner of a mapped buffer, ensuring its lifetime. +/// Wrapper asserting `Send + Sync` for single-threaded contexts. /// -/// Holds a pool allocation and provides direct access to the underlying -/// shared memory via [`MemOps::as_slice`]. Implements `AsRef<[u8]>` so it -/// can be used with [`Bytes::from_owner`](bytes::Bytes::from_owner) for -/// zero-copy `Bytes` backed by shared memory. +/// # Safety /// -/// When dropped, the allocation is returned to the pool. -#[derive(Debug, Clone)] -pub struct BufferOwner { - pub(crate) pool: P, - pub(crate) mem: M, - pub(crate) alloc: Allocation, - pub(crate) written: usize, -} - -impl Drop for BufferOwner { - fn drop(&mut self) { - let _ = self.pool.dealloc(self.alloc); - } -} - -impl AsRef<[u8]> for BufferOwner { - fn as_ref(&self) -> &[u8] { - let len = self.written.min(self.alloc.len); - // Safety: BufferOwner keeps both the pool allocation and the M - // alive, so the memory region is valid. Protocol-level descriptor - // ownership transfer guarantees no concurrent writes. 
- match unsafe { self.mem.as_slice(self.alloc.addr, len) } { - Ok(slice) => slice, - Err(_) => &[], - } - } -} - -/// A guard that runs a cleanup function when dropped, unless dismissed. -pub struct AllocGuard(Option<(Allocation, F)>); - -impl AllocGuard { - pub fn new(alloc: Allocation, cleanup: F) -> Self { - Self(Some((alloc, cleanup))) - } - - pub fn release(mut self) -> Allocation { - self.0.take().unwrap().0 - } -} +/// The wrapped value must only be accessed from a single thread. +#[derive(Debug)] +pub(super) struct SyncWrap(pub(super) T); -impl core::ops::Deref for AllocGuard { - type Target = Allocation; +// SAFETY: The wrapped value must only be accessed from a single thread. +unsafe impl Send for SyncWrap {} +// SAFETY: The wrapped value must only be accessed from a single thread. +unsafe impl Sync for SyncWrap {} - fn deref(&self) -> &Allocation { - &self.0.as_ref().unwrap().0 +impl Clone for SyncWrap { + fn clone(&self) -> Self { + Self(self.0.clone()) } } -impl Drop for AllocGuard { - fn drop(&mut self) { - if let Some((alloc, cleanup)) = self.0.take() { - cleanup(alloc) - } +impl Deref for SyncWrap { + type Target = T; + fn deref(&self) -> &T { + &self.0 } } @@ -550,8 +405,7 @@ struct Inner { /// Two tier buffer pool with small and large slabs. #[derive(Debug, Clone)] pub struct BufferPool { - // TODO: Use Rc instead, relax Sync + Send bounds - inner: Arc>>, + inner: SyncWrap>>>, } impl BufferPool { @@ -559,7 +413,7 @@ impl BufferPool { pub fn new(base_addr: u64, region_len: usize) -> Result { let inner = Inner::::new(base_addr, region_len)?; Ok(Self { - inner: Arc::new(inner.into()), + inner: SyncWrap(Rc::new(RefCell::new(inner))), }) } } @@ -700,6 +554,102 @@ impl BufferProvider for BufferPoolSync { } } +struct RecyclePoolInner { + base_addr: u64, + slot_size: usize, + count: usize, + free: SmallVec<[u64; 64]>, +} + +/// A recycling buffer provider with fixed-size slots. 
+/// +/// Unlike [`BufferPool`] which uses a bitmap allocator, this holds a +/// fixed set of same-sized buffer addresses in a free list. Alloc and +/// dealloc are O(1). Intended for H2G writable buffers that are +/// pre-allocated once and recycled after each use. +#[derive(Clone)] +pub struct RecyclePool { + inner: SyncWrap>>, +} + +impl RecyclePool { + /// Create a new recycling pool by carving `base..base+region_len` into slots of `slot_size` bytes. + pub fn new(base_addr: u64, region_len: usize, slot_size: usize) -> Result { + if slot_size == 0 { + return Err(AllocError::InvalidArg); + } + + let count = region_len / slot_size; + if count == 0 { + return Err(AllocError::EmptyRegion); + } + + let mut free = SmallVec::with_capacity(count); + for i in 0..count { + free.push(base_addr + (i * slot_size) as u64); + } + + let inner = RefCell::new(RecyclePoolInner { + base_addr, + slot_size, + count, + free, + }); + + Ok(Self { + inner: SyncWrap(Rc::new(inner)), + }) + } + + /// Number of free slots. 
+ pub fn num_free(&self) -> usize { + self.inner.borrow().free.len() + } +} + +impl BufferProvider for RecyclePool { + fn alloc(&self, len: usize) -> Result { + let mut inner = self.inner.borrow_mut(); + if len > inner.slot_size { + return Err(AllocError::OutOfMemory); + } + + let addr = inner.free.pop().ok_or(AllocError::OutOfMemory)?; + + Ok(Allocation { + addr, + len: inner.slot_size, + }) + } + + fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { + let mut inner = self.inner.borrow_mut(); + inner.free.push(alloc.addr); + Ok(()) + } + + fn resize(&self, old: Allocation, new_len: usize) -> Result { + let inner = self.inner.borrow(); + if new_len > inner.slot_size { + return Err(AllocError::OutOfMemory); + } + Ok(old) + } + + fn reset(&self) { + let mut inner = self.inner.borrow_mut(); + let base = inner.base_addr; + let slot = inner.slot_size; + let count = inner.count; + + inner.free.clear(); + + for i in 0..count { + inner.free.push(base + (i * slot) as u64); + } + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index 5e6a7edf1..276a2ff3b 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -168,8 +168,8 @@ where /// wrote more data than the completion buffer capacity pub fn poll(&mut self) -> Result, VirtqError> where - M: Send + Sync + 'static, - P: Send + Sync + 'static, + M: Send + 'static, + P: Send + 'static, { let used = match self.inner.poll_used() { Ok(u) => u, @@ -234,8 +234,8 @@ where /// ``` pub fn drain(&mut self, mut f: impl FnMut(Token, Bytes)) -> Result<(), VirtqError> where - M: Send + Sync + 'static, - P: Send + Sync + 'static, + M: Send + 'static, + P: Send + 'static, { while let Some(cqe) = self.poll()? 
{ f(cqe.token, cqe.data); diff --git a/src/hyperlight_common/src/virtq/recycle_pool.rs b/src/hyperlight_common/src/virtq/recycle_pool.rs deleted file mode 100644 index 4bcf9978a..000000000 --- a/src/hyperlight_common/src/virtq/recycle_pool.rs +++ /dev/null @@ -1,120 +0,0 @@ -/* -Copyright 2026 The Hyperlight Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -//! A simple fixed-size buffer recycler for H2G prefill entries. -//! -//! Unlike [`super::BufferPool`] which uses a bitmap allocator, this -//! holds a fixed set of same-sized buffer addresses in a free list. -//! Alloc and dealloc are O(1). Intended for H2G writable buffers -//! that are pre-allocated once and recycled after each use. - -use alloc::sync::Arc; - -use atomic_refcell::AtomicRefCell; -use smallvec::SmallVec; - -use super::{AllocError, Allocation, BufferProvider}; - -/// A recycling buffer provider with fixed-size slots. -#[derive(Clone)] -pub struct RecyclePool { - inner: Arc>, -} - -struct RecyclePoolInner { - base_addr: u64, - slot_size: usize, - count: usize, - free: SmallVec<[u64; 64]>, -} - -impl RecyclePool { - /// Create a new recycling pool by carving `base..base+region_len` into slots of `slot_size` bytes. 
- pub fn new(base_addr: u64, region_len: usize, slot_size: usize) -> Result { - if slot_size == 0 { - return Err(AllocError::InvalidArg); - } - - let count = region_len / slot_size; - if count == 0 { - return Err(AllocError::EmptyRegion); - } - - let mut free = SmallVec::with_capacity(count); - for i in 0..count { - free.push(base_addr + (i * slot_size) as u64); - } - - let inner = AtomicRefCell::new(RecyclePoolInner { - base_addr, - slot_size, - count, - free, - }); - - Ok(Self { - inner: inner.into(), - }) - } - - /// Number of free slots. - pub fn num_free(&self) -> usize { - self.inner.borrow().free.len() - } -} - -impl BufferProvider for RecyclePool { - fn alloc(&self, len: usize) -> Result { - let mut inner = self.inner.borrow_mut(); - if len > inner.slot_size { - return Err(AllocError::OutOfMemory); - } - - let addr = inner.free.pop().ok_or(AllocError::OutOfMemory)?; - - Ok(Allocation { - addr, - len: inner.slot_size, - }) - } - - fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { - let mut inner = self.inner.borrow_mut(); - inner.free.push(alloc.addr); - Ok(()) - } - - fn resize(&self, old: Allocation, new_len: usize) -> Result { - let inner = self.inner.borrow(); - if new_len > inner.slot_size { - return Err(AllocError::OutOfMemory); - } - Ok(old) - } - - fn reset(&self) { - let mut inner = self.inner.borrow_mut(); - let base = inner.base_addr; - let slot = inner.slot_size; - let count = inner.count; - - inner.free.clear(); - - for i in 0..count { - inner.free.push(base + (i * slot) as u64); - } - } -} diff --git a/src/hyperlight_common/src/virtq/ring.rs b/src/hyperlight_common/src/virtq/ring.rs index 75508afd1..bf5eba0f2 100644 --- a/src/hyperlight_common/src/virtq/ring.rs +++ b/src/hyperlight_common/src/virtq/ring.rs @@ -1268,46 +1268,53 @@ pub(crate) mod tests { /// Test MemOps implementation that maintains pointer provenance. 
/// - /// This wraps a Vec and provides memory access using the Vec's - /// base pointer to preserve provenance for Miri. + /// Wraps shared storage behind Arc for cheap cloning. This allows + /// producer and consumer to share the same backing memory without + /// Arc appearing in the type signatures. + #[derive(Clone)] pub struct TestMem { + inner: Arc, + } + + struct TestMemInner { /// The backing storage - UnsafeCell for interior mutability storage: UnsafeCell>, /// Base address (the address we tell the ring about) base_addr: u64, } + // Safety: TestMemInner's UnsafeCell is only accessed from test code + // with no real concurrency in unit tests (loom tests use LoomMem). + unsafe impl Send for TestMemInner {} + unsafe impl Sync for TestMemInner {} + impl TestMem { pub fn new(size: usize) -> Self { let storage = vec![0u8; size]; let base_addr = storage.as_ptr() as u64; Self { - storage: UnsafeCell::new(storage), - base_addr, + inner: Arc::new(TestMemInner { + storage: UnsafeCell::new(storage), + base_addr, + }), } } /// Get a pointer with proper provenance for the given address fn ptr_for_addr(&self, addr: u64) -> *mut u8 { - let storage = unsafe { &mut *self.storage.get() }; + let storage = unsafe { &mut *self.inner.storage.get() }; let base_ptr = storage.as_mut_ptr(); - let offset = (addr - self.base_addr) as usize; + let offset = (addr - self.inner.base_addr) as usize; // Use wrapping_add to maintain provenance from base_ptr base_ptr.wrapping_add(offset) } pub fn base_addr(&self) -> u64 { - self.base_addr + self.inner.base_addr } } - // Safety: TestMem's UnsafeCell is only accessed from test code with no - // real concurrency in unit tests (loom tests use their own LoomMem). - // Required so Arc satisfies Send + Sync for Bytes::from_owner. 
- unsafe impl Send for TestMem {} - unsafe impl Sync for TestMem {} - - impl MemOps for Arc { + impl MemOps for TestMem { type Error = core::convert::Infallible; fn read(&self, addr: u64, dst: &mut [u8]) -> Result { @@ -1361,7 +1368,7 @@ pub(crate) mod tests { /// Owns the descriptor table and event suppression structures pub struct OwnedRing { - mem: Arc, + mem: TestMem, layout: Layout, } @@ -1379,7 +1386,7 @@ pub(crate) mod tests { // pool size = 0x8000). let padding = Descriptor::ALIGN; let pool_headroom = 0x100 + 0x8000; - let mem = Arc::new(TestMem::new(needed + padding + pool_headroom)); + let mem = TestMem::new(needed + padding + pool_headroom); // Align the base address let aligned_base = align_up(mem.base_addr() as usize, Descriptor::ALIGN) as u64; @@ -1392,7 +1399,7 @@ pub(crate) mod tests { self.layout } - pub fn mem(&self) -> Arc { + pub fn mem(&self) -> TestMem { self.mem.clone() } @@ -1431,15 +1438,15 @@ pub(crate) mod tests { OwnedRing::new(size) } - pub(crate) fn make_producer(ring: &OwnedRing) -> RingProducer> { + pub(crate) fn make_producer(ring: &OwnedRing) -> RingProducer { RingProducer::new(ring.layout(), ring.mem()) } - pub(crate) fn make_consumer(ring: &OwnedRing) -> RingConsumer> { + pub(crate) fn make_consumer(ring: &OwnedRing) -> RingConsumer { RingConsumer::new(ring.layout(), ring.mem()) } - fn assert_invariants(ring: &OwnedRing, prod: &RingProducer>) { + fn assert_invariants(ring: &OwnedRing, prod: &RingProducer) { let outstanding: u16 = prod.id_num.iter().copied().sum(); assert_eq!(outstanding as usize + prod.num_free, ring.len()); diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index e02c7ef52..fd7b00142 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -30,8 +30,9 @@ use hyperlight_common::flatbuffer_wrappers::util::estimate_flatbuffer_capacity; use hyperlight_common::mem::PAGE_SIZE_USIZE; use hyperlight_common::outb::OutBAction; 
use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; -use hyperlight_common::virtq::recycle_pool::RecyclePool; -use hyperlight_common::virtq::{BufferPool, Layout, Notifier, QueueStats, VirtqProducer}; +use hyperlight_common::virtq::{ + BufferPool, Layout, Notifier, QueueStats, RecyclePool, VirtqProducer, +}; use super::GuestMemOps; use crate::bail; diff --git a/src/tests/rust_guests/dummyguest/Cargo.lock b/src/tests/rust_guests/dummyguest/Cargo.lock index b6aaae23c..9ac9e9bdc 100644 --- a/src/tests/rust_guests/dummyguest/Cargo.lock +++ b/src/tests/rust_guests/dummyguest/Cargo.lock @@ -8,12 +8,6 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" -[[package]] -name = "atomic_refcell" -version = "0.1.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41e67cd8309bbd06cd603a9e693a784ac2e5d1e955f11286e355089fcab3047c" - [[package]] name = "bitflags" version = "2.11.0" @@ -124,7 +118,6 @@ name = "hyperlight-common" version = "0.14.0" dependencies = [ "anyhow", - "atomic_refcell", "bitflags", "bytemuck", "bytes", diff --git a/src/tests/rust_guests/simpleguest/Cargo.lock b/src/tests/rust_guests/simpleguest/Cargo.lock index dbc6c01e1..f088a8cb1 100644 --- a/src/tests/rust_guests/simpleguest/Cargo.lock +++ b/src/tests/rust_guests/simpleguest/Cargo.lock @@ -8,12 +8,6 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" -[[package]] -name = "atomic_refcell" -version = "0.1.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41e67cd8309bbd06cd603a9e693a784ac2e5d1e955f11286e355089fcab3047c" - [[package]] name = "bitflags" version = "2.11.0" @@ -116,7 +110,6 @@ name = "hyperlight-common" version = "0.14.0" dependencies = [ "anyhow", - "atomic_refcell", "bitflags", "bytemuck", "bytes", 
diff --git a/src/tests/rust_guests/witguest/Cargo.lock b/src/tests/rust_guests/witguest/Cargo.lock index 4fb67bbe7..9206f2e7d 100644 --- a/src/tests/rust_guests/witguest/Cargo.lock +++ b/src/tests/rust_guests/witguest/Cargo.lock @@ -67,12 +67,6 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" -[[package]] -name = "atomic_refcell" -version = "0.1.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41e67cd8309bbd06cd603a9e693a784ac2e5d1e955f11286e355089fcab3047c" - [[package]] name = "bitflags" version = "2.11.0" @@ -221,7 +215,6 @@ name = "hyperlight-common" version = "0.14.0" dependencies = [ "anyhow", - "atomic_refcell", "bitflags", "bytemuck", "bytes", From 6498b2f39494d3b7ceb9dbcbfe5921d32a66cde0 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Tue, 7 Apr 2026 13:40:51 +0200 Subject: [PATCH 09/26] feat(virtq): send logs over virtq Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/buffer.rs | 4 +- src/hyperlight_common/src/virtq/mod.rs | 184 +++++++++++++++++- src/hyperlight_common/src/virtq/msg.rs | 23 +++ src/hyperlight_common/src/virtq/pool.rs | 12 +- src/hyperlight_common/src/virtq/producer.rs | 40 ++++ .../src/guest_handle/host_comm.rs | 12 +- src/hyperlight_guest/src/virtq/context.rs | 180 ++++++++++++----- src/hyperlight_host/src/mem/mgr.rs | 66 ++++--- src/hyperlight_host/src/sandbox/outb.rs | 136 ++++++++----- src/hyperlight_host/tests/common/mod.rs | 10 + .../tests/sandbox_host_tests.rs | 181 ++++++++++++++++- src/tests/rust_guests/simpleguest/src/main.rs | 7 + 12 files changed, 715 insertions(+), 140 deletions(-) diff --git a/src/hyperlight_common/src/virtq/buffer.rs b/src/hyperlight_common/src/virtq/buffer.rs index 238775c6d..b41708b03 100644 --- a/src/hyperlight_common/src/virtq/buffer.rs +++ b/src/hyperlight_common/src/virtq/buffer.rs @@ -33,7 +33,9 @@ pub enum AllocError { 
InvalidArg, #[error("Empty region")] EmptyRegion, - #[error("Out of memory")] + #[error("No space available")] + NoSpace, + #[error("Requested size exceeds pool capacity")] OutOfMemory, #[error("Overflow")] Overflow, diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs index 5f125b72c..331e6cc82 100644 --- a/src/hyperlight_common/src/virtq/mod.rs +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -181,9 +181,13 @@ pub trait Notifier { #[derive(Error, Debug)] pub enum VirtqError { #[error("Ring error: {0}")] - RingError(#[from] RingError), + RingError(RingError), #[error("Allocation error: {0}")] - Alloc(#[from] AllocError), + Alloc(AllocError), + #[error("Ring or pool temporarily full")] + Backpressure, + #[error("Allocation exceeds pool capacity")] + OutOfMemory, #[error("Invalid token")] BadToken, #[error("Invalid chain received")] @@ -202,6 +206,33 @@ pub enum VirtqError { NoReadableBuffer, } +impl VirtqError { + /// Check if this error is transient or unrecoverable. + #[inline(always)] + pub fn is_transient(&self) -> bool { + matches!(self, Self::Backpressure) + } +} + +impl From for VirtqError { + fn from(e: RingError) -> Self { + match e { + RingError::WouldBlock => Self::Backpressure, + other => Self::RingError(other), + } + } +} + +impl From for VirtqError { + fn from(e: AllocError) -> Self { + match e { + AllocError::NoSpace => Self::Backpressure, + AllocError::OutOfMemory => Self::OutOfMemory, + other => Self::Alloc(other), + } + } +} + /// Layout of a packed virtqueue ring in shared memory. 
/// /// Describes the memory addresses for the descriptor table and event suppression @@ -424,7 +455,7 @@ pub(crate) mod test_utils { let addr = self.next.fetch_add(len as u64, Ordering::Relaxed); let end = addr + len as u64; if end > self.base + self.size as u64 { - return Err(AllocError::OutOfMemory); + return Err(AllocError::NoSpace); } Ok(Allocation { addr, len }) } @@ -794,6 +825,153 @@ mod tests { assert_eq!(&expected_first.1[..], b"resp1"); assert_eq!(&expected_second.1[..], b"resp2"); } + + /// Helper: submit a ReadOnly entry (entry data, no completion). + fn send_readonly( + producer: &mut VirtqProducer, + entry_data: &[u8], + ) -> Token { + let mut se = producer.chain().entry(entry_data.len()).build().unwrap(); + se.write_all(entry_data).unwrap(); + producer.submit(se).unwrap() + } + + #[test] + fn test_reclaim_frees_ring_slots() { + let ring = make_ring(4); + let (mut producer, mut consumer, _) = make_test_producer(&ring); + + // Fill the ring with ReadOnly entries + send_readonly(&mut producer, b"a"); + send_readonly(&mut producer, b"b"); + send_readonly(&mut producer, b"c"); + send_readonly(&mut producer, b"d"); + + // Ring is now full - next submit should fail with Backpressure + let mut se = producer.chain().entry(1).build().unwrap(); + se.write_all(b"e").unwrap(); + let res = producer.submit(se); + assert!( + matches!(res, Err(VirtqError::Backpressure)), + "expected Backpressure from full ring" + ); + + // Consumer acks all entries + while let Some((_, completion)) = consumer.poll(1024).unwrap() { + consumer.complete(completion).unwrap(); + } + + // Reclaim should free ring slots without losing data + let count = producer.reclaim().unwrap(); + assert_eq!(count, 4, "expected 4 reclaimed entries"); + + // Ring should have space now + send_readonly(&mut producer, b"e"); + } + + #[test] + fn test_reclaim_buffers_rw_completions() { + let ring = make_ring(4); + let (mut producer, mut consumer, _) = make_test_producer(&ring); + + // Submit a ReadWrite 
entry + let tok = send_readwrite(&mut producer, b"request", 64); + + // Consumer processes and writes response + let (_, completion) = consumer.poll(1024).unwrap().unwrap(); + let SendCompletion::Writable(mut wc) = completion else { + panic!("expected writable"); + }; + wc.write_all(b"response-data").unwrap(); + consumer.complete(wc.into()).unwrap(); + + // Reclaim buffers the completion (doesn't discard it) + let count = producer.reclaim().unwrap(); + assert_eq!(count, 1); + + // poll() should return the buffered completion + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, tok); + assert_eq!(&cqe.data[..], b"response-data"); + } + + #[test] + fn test_reclaim_then_poll_preserves_order() { + let ring = make_ring(8); + let (mut producer, mut consumer, _) = make_test_producer(&ring); + + // Submit 3 entries: RO, RW, RO + let tok_ro1 = send_readonly(&mut producer, b"log1"); + let tok_rw = send_readwrite(&mut producer, b"call", 64); + let tok_ro2 = send_readonly(&mut producer, b"log2"); + + // Consumer processes all 3 + let (_, c1) = consumer.poll(1024).unwrap().unwrap(); + consumer.complete(c1).unwrap(); // ack RO + + let (_, c2) = consumer.poll(1024).unwrap().unwrap(); + let SendCompletion::Writable(mut wc) = c2 else { + panic!("expected writable"); + }; + wc.write_all(b"result").unwrap(); + consumer.complete(wc.into()).unwrap(); // complete RW + + let (_, c3) = consumer.poll(1024).unwrap().unwrap(); + consumer.complete(c3).unwrap(); // ack RO + + // Reclaim all 3 + let count = producer.reclaim().unwrap(); + assert_eq!(count, 3); + + // poll() returns them in order + let cqe1 = producer.poll().unwrap().unwrap(); + assert_eq!(cqe1.token, tok_ro1); + assert!(cqe1.data.is_empty()); + + let cqe2 = producer.poll().unwrap().unwrap(); + assert_eq!(cqe2.token, tok_rw); + assert_eq!(&cqe2.data[..], b"result"); + + let cqe3 = producer.poll().unwrap().unwrap(); + assert_eq!(cqe3.token, tok_ro2); + assert!(cqe3.data.is_empty()); + + // No more + 
assert!(producer.poll().unwrap().is_none()); + } + + #[test] + fn test_reclaim_mixed_with_poll() { + let ring = make_ring(8); + let (mut producer, mut consumer, _) = make_test_producer(&ring); + + // Submit and complete 2 entries + send_readonly(&mut producer, b"x"); + let tok_rw = send_readwrite(&mut producer, b"y", 64); + + let (_, c1) = consumer.poll(1024).unwrap().unwrap(); + consumer.complete(c1).unwrap(); + + let (_, c2) = consumer.poll(1024).unwrap().unwrap(); + let SendCompletion::Writable(mut wc) = c2 else { + panic!("expected writable"); + }; + wc.write_all(b"reply").unwrap(); + consumer.complete(wc.into()).unwrap(); + + // poll() consumes first entry directly from ring + let cqe1 = producer.poll().unwrap().unwrap(); + assert!(cqe1.data.is_empty()); + + // reclaim() buffers second entry + let count = producer.reclaim().unwrap(); + assert_eq!(count, 1); + + // poll() returns the buffered one + let cqe2 = producer.poll().unwrap().unwrap(); + assert_eq!(cqe2.token, tok_rw); + assert_eq!(&cqe2.data[..], b"reply"); + } } #[cfg(all(test, loom))] mod fuzz { diff --git a/src/hyperlight_common/src/virtq/msg.rs b/src/hyperlight_common/src/virtq/msg.rs index 9c7f69947..ade59643b 100644 --- a/src/hyperlight_common/src/virtq/msg.rs +++ b/src/hyperlight_common/src/virtq/msg.rs @@ -34,6 +34,24 @@ pub enum MsgKind { StreamEnd = 0x04, /// Cancel a pending request. Cancel = 0x05, + /// A guest log message (GuestLogData payload follows). + Log = 0x06, +} + +impl TryFrom for MsgKind { + type Error = u8; + + fn try_from(value: u8) -> Result { + match value { + 0x01 => Ok(Self::Request), + 0x02 => Ok(Self::Response), + 0x03 => Ok(Self::StreamChunk), + 0x04 => Ok(Self::StreamEnd), + 0x05 => Ok(Self::Cancel), + 0x06 => Ok(Self::Log), + other => Err(other), + } + } } /// Wire header for all virtqueue messages @@ -72,4 +90,9 @@ impl VirtqMsgHeader { payload_len, } } + + /// Parse the kind field into a [`MsgKind`] enum. 
+ pub fn msg_kind(&self) -> Result { + MsgKind::try_from(self.kind) + } } diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs index 42325a56c..2e49e27fe 100644 --- a/src/hyperlight_common/src/virtq/pool.rs +++ b/src/hyperlight_common/src/virtq/pool.rs @@ -169,7 +169,7 @@ impl Slab { return Err(AllocError::OutOfMemory); } - let idx = self.find_slots(need_slots).ok_or(AllocError::OutOfMemory)?; + let idx = self.find_slots(need_slots).ok_or(AllocError::NoSpace)?; self.used_slots.insert_range(idx..idx + need_slots); let addr = self.addr_of(idx).ok_or(AllocError::Overflow)?; @@ -463,7 +463,7 @@ impl Inner { if len <= L { match self.lower.alloc(len) { Ok(alloc) => return Ok(alloc), - Err(AllocError::OutOfMemory) => {} + Err(AllocError::NoSpace) => {} Err(e) => return Err(e), } } @@ -614,7 +614,7 @@ impl BufferProvider for RecyclePool { return Err(AllocError::OutOfMemory); } - let addr = inner.free.pop().ok_or(AllocError::OutOfMemory)?; + let addr = inner.free.pop().ok_or(AllocError::NoSpace)?; Ok(Allocation { addr, @@ -727,7 +727,7 @@ mod tests { // Next allocation should fail let result = slab.alloc(256); - assert!(matches!(result, Err(AllocError::OutOfMemory))); + assert!(matches!(result, Err(AllocError::NoSpace))); // Free one and retry slab.dealloc(a2).unwrap(); @@ -1287,7 +1287,7 @@ mod fuzz { assert!(alloc.len >= *size); allocations.push(alloc); } - Err(AllocError::OutOfMemory) => {} + Err(AllocError::NoSpace | AllocError::OutOfMemory) => {} Err(_) => { return false; } @@ -1318,7 +1318,7 @@ mod fuzz { assert!(new_alloc.len >= *new_size); allocations[idx] = new_alloc; } - Err(AllocError::OutOfMemory) => {} + Err(AllocError::NoSpace | AllocError::OutOfMemory) => {} Err(_) => return false, } } diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index 276a2ff3b..eeb96cc7f 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs 
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ +use alloc::collections::VecDeque; use alloc::vec; use alloc::vec::Vec; @@ -124,6 +125,7 @@ pub struct VirtqProducer { notifier: N, pool: P, inflight: Vec>, + pending: VecDeque, } impl VirtqProducer @@ -149,11 +151,15 @@ where pool, notifier, inflight, + pending: VecDeque::new(), } } /// Poll for a single completion from the device. /// + /// Returns buffered completions from prior [`reclaim`](Self::reclaim) + /// calls first, then checks the ring for new completions. + /// /// Returns `Ok(Some(completion))` if a completion is available, `Ok(None)` if no /// completions are ready (would block), or an error if the device misbehaved. /// @@ -167,6 +173,39 @@ where /// - [`VirtqError::InvalidState`] - Device returned invalid descriptor ID or /// wrote more data than the completion buffer capacity pub fn poll(&mut self) -> Result, VirtqError> + where + M: Send + 'static, + P: Send + 'static, + { + if let Some(cqe) = self.pending.pop_front() { + return Ok(Some(cqe)); + } + self.poll_ring() + } + + /// Reclaim ring slots and pool entries from completed descriptors. + /// + /// Processes all available used entries from the ring: frees entry + /// buffer allocations immediately, and buffers completion data for + /// later retrieval via [`poll`](Self::poll). + /// + /// Use this to free resources under backpressure without losing + /// completion data. Returns the number of entries reclaimed. + pub fn reclaim(&mut self) -> Result + where + M: Send + 'static, + P: Send + 'static, + { + let mut count = 0; + while let Some(cqe) = self.poll_ring()? { + self.pending.push_back(cqe); + count += 1; + } + Ok(count) + } + + /// Poll one completion directly from the ring (bypassing pending buffer). 
+ fn poll_ring(&mut self) -> Result, VirtqError> where M: Send + 'static, P: Send + 'static, @@ -363,6 +402,7 @@ where self.inner.reset(); self.pool.reset(); self.inflight.fill(None); + self.pending.clear(); } } diff --git a/src/hyperlight_guest/src/guest_handle/host_comm.rs b/src/hyperlight_guest/src/guest_handle/host_comm.rs index c72de8a3f..d440852f6 100644 --- a/src/hyperlight_guest/src/guest_handle/host_comm.rs +++ b/src/hyperlight_guest/src/guest_handle/host_comm.rs @@ -162,7 +162,7 @@ impl GuestHandle { source_file: &str, line: u32, ) { - // Closure to send log message to host + // Closure to send log message to host via G2H virtqueue let _send_to_host = || { let guest_log_data = GuestLogData::new( message.to_string(), @@ -177,12 +177,10 @@ impl GuestHandle { .try_into() .expect("Failed to convert GuestLogData to bytes"); - self.push_shared_output_data(&bytes) - .expect("Unable to push log data to shared output data"); - - unsafe { - out32(OutBAction::Log as u16, 0); - } + crate::virtq::with_context(|ctx| { + ctx.emit_log(&bytes) + .expect("Unable to send log data via virtq"); + }); }; #[cfg(all(feature = "trace_guest", target_arch = "x86_64"))] diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index fd7b00142..f4db4699c 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -17,7 +17,7 @@ limitations under the License. //! Guest virtqueue context. 
use alloc::vec::Vec; -use core::num::NonZeroU16; +use core::result; use core::sync::atomic::AtomicU16; use core::sync::atomic::Ordering::Relaxed; @@ -31,7 +31,7 @@ use hyperlight_common::mem::PAGE_SIZE_USIZE; use hyperlight_common::outb::OutBAction; use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; use hyperlight_common::virtq::{ - BufferPool, Layout, Notifier, QueueStats, RecyclePool, VirtqProducer, + self, BufferPool, Layout, Notifier, QueueStats, RecyclePool, Token, VirtqProducer, }; use super::GuestMemOps; @@ -132,19 +132,33 @@ impl GuestContext { let entry_len = VirtqMsgHeader::SIZE + payload.len(); - let mut entry = self - .g2h_producer - .chain() - .entry(entry_len) - .completion(MAX_RESPONSE_CAP) - .build()?; + let token = match self.try_send_readwrite(hdr_bytes, payload, entry_len) { + Ok(tok) => tok, + Err(e) if e.is_transient() => { + self.g2h_producer.notify_backpressure(); - entry.write_all(hdr_bytes)?; - entry.write_all(payload)?; - self.g2h_producer.submit(entry)?; + if let Err(err) = self.g2h_producer.reclaim() { + bail!("G2H reclaim: {err}"); + } + + let Ok(tok) = self.try_send_readwrite(hdr_bytes, payload, entry_len) else { + bail!("G2H call retry"); + }; - let Some(completion) = self.g2h_producer.poll()? else { - bail!("G2H: no completion received"); + tok + } + Err(e) => bail!("G2H call: {e}"), + }; + + // Poll completions, skipping earlier entries like log acks + // until we find the completion matching our request token. + let completion = loop { + let Some(cqe) = self.g2h_producer.poll()? else { + bail!("G2H: no completion received"); + }; + if cqe.token == token { + break cqe; + } }; let result_bytes = &completion.data; @@ -165,25 +179,6 @@ impl GuestContext { Ok(ret) } - /// Pre-fill the H2G queue with completion-only descriptors so the host - /// can write incoming call payloads into them. 
- fn prefill_h2g(&mut self) { - loop { - let entry = match self - .h2g_producer - .chain() - .completion(PAGE_SIZE_USIZE) - .build() - { - Ok(e) => e, - Err(_) => break, - }; - if self.h2g_producer.submit(entry).is_err() { - break; - } - } - } - /// Receive a host-to-guest function call from the H2G queue. pub fn recv_h2g_call(&mut self) -> Result { let Some(completion) = self.h2g_producer.poll()? else { @@ -197,8 +192,8 @@ impl GuestContext { let hdr: &VirtqMsgHeader = bytemuck::from_bytes(&data[..VirtqMsgHeader::SIZE]); - if hdr.kind != MsgKind::Request as u8 { - bail!("H2G: unexpected message kind"); + if hdr.msg_kind() != Ok(MsgKind::Request) { + bail!("H2G: unexpected message kind: 0x{:02x}", hdr.kind); } let payload_end = VirtqMsgHeader::SIZE + hdr.payload_len as usize; @@ -214,32 +209,83 @@ impl GuestContext { /// Send the result of a host-to-guest call back to the host via the /// G2H queue, then refill one H2G descriptor slot. pub fn send_h2g_result(&mut self, payload: &[u8]) -> Result<()> { - // Build a Response message on the G2H queue - let reqid = REQUEST_ID.fetch_add(1, Relaxed); - let hdr = VirtqMsgHeader::new(MsgKind::Response, reqid, payload.len() as u32); - let hdr_bytes = bytemuck::bytes_of(&hdr); + self.send_g2h_oneshot(MsgKind::Response, payload)?; - let entry_len = VirtqMsgHeader::SIZE + payload.len(); - let mut entry = self.g2h_producer.chain().entry(entry_len).build()?; - - entry.write_all(hdr_bytes)?; - entry.write_all(payload)?; - self.g2h_producer.submit(entry)?; - - // Refill one H2G completion slot - if let Ok(e) = self + // Best-effort refill of one H2G slot. Backpressure is expected + // (pool/ring may be full), other errors are propagated. 
+ match self .h2g_producer .chain() .completion(PAGE_SIZE_USIZE) .build() { - let _ = self.h2g_producer.submit(e); + Ok(e) => match self.h2g_producer.submit(e) { + Ok(_) => {} + Err(virtq::VirtqError::Backpressure) => {} + Err(e) => bail!("H2G refill submit: {e}"), + }, + Err(virtq::VirtqError::Backpressure) => {} + Err(e) => bail!("H2G refill build: {e}"), } Ok(()) } - /// Drain any pending G2H completions (discard them). + /// Pre-fill the H2G queue with completion-only descriptors so the host + /// can write incoming call payloads into them. + fn prefill_h2g(&mut self) { + loop { + let entry = match self + .h2g_producer + .chain() + .completion(PAGE_SIZE_USIZE) + .build() + { + Ok(e) => e, + Err(virtq::VirtqError::Backpressure) => break, + Err(e) => panic!("H2G prefill build: {e}"), + }; + + match self.h2g_producer.submit(entry) { + Ok(_) => {} + Err(virtq::VirtqError::Backpressure) => break, + Err(e) => panic!("H2G prefill submit: {e}"), + } + } + } + + /// Send a one-way message on the G2H queue ReadOnly and no completion. + /// + /// If the pool or ring is full, triggers backpressure, VM exit so + /// the host can drain, then retries once. + fn send_g2h_oneshot(&mut self, kind: MsgKind, payload: &[u8]) -> Result<()> { + let reqid = REQUEST_ID.fetch_add(1, Relaxed); + let hdr = VirtqMsgHeader::new(kind, reqid, payload.len() as u32); + let hdr_bytes = bytemuck::bytes_of(&hdr); + let entry_len = VirtqMsgHeader::SIZE + payload.len(); + + // First attempt + match self.try_send_readonly(hdr_bytes, payload, entry_len) { + Ok(_) => return Ok(()), + Err(virtq::VirtqError::Backpressure) => { + // VM exit so host drains and completes G2H entries. + self.g2h_producer.notify_backpressure(); + } + Err(e) => bail!("G2H oneshot: {e}"), + } + + // Reclaim ring/pool resources from completed entries. 
+ if let Err(e) = self.g2h_producer.reclaim() { + bail!("G2H oneshot retry: {e}"); + } + // Retry after backpressure + match self.try_send_readonly(hdr_bytes, payload, entry_len) { + Ok(_) => Ok(()), + Err(e) => bail!("G2H oneshot retry: {e}"), + } + } + + /// Drain any pending G2H completions. /// /// This is called before checking for H2G calls so that the host /// can reclaim G2H response buffers. @@ -247,6 +293,11 @@ impl GuestContext { while let Ok(Some(_)) = self.g2h_producer.poll() {} } + /// Send a log message via the G2H queue. Fire-and-forget. + pub fn emit_log(&mut self, log_data: &[u8]) -> Result<()> { + self.send_g2h_oneshot(MsgKind::Log, log_data) + } + /// Reset ring and pool state after snapshot restore. pub(super) fn reset(&mut self, new_generation: u16) { // G2H producer reset also resets the pool via BufferProvider::reset() @@ -261,4 +312,35 @@ impl GuestContext { pub(super) fn generation(&self) -> u16 { self.generation } + + fn try_send( + &mut self, + header: &[u8], + payload: &[u8], + entry_len: usize, + ) -> result::Result { + let mut entry = self.g2h_producer.chain().entry(entry_len).build()?; + + entry.write_all(header)?; + entry.write_all(payload)?; + self.g2h_producer.submit(entry) + } + + fn try_send_readwrite( + &mut self, + header: &[u8], + payload: &[u8], + entry_len: usize, + ) -> result::Result { + let mut entry = self + .g2h_producer + .chain() + .entry(entry_len) + .completion(MAX_RESPONSE_CAP) + .build()?; + + entry.write_all(header)?; + entry.write_all(payload)?; + self.g2h_producer.submit(entry) + } } diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index c65bab934..f30955e2c 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -513,6 +513,7 @@ impl SandboxMemoryManager { } /// Read guest log data from the `SharedMemory` contained within `self` + #[allow(dead_code)] #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] pub(crate) 
fn read_guest_log_data(&mut self) -> Result { self.scratch_mem.try_pop_buffer_into::( @@ -1044,35 +1045,48 @@ impl SandboxMemoryManager { .as_mut() .ok_or_else(|| new_error!("G2H consumer not initialized"))?; - let Some((entry, completion)) = consumer - .poll(8192) - .map_err(|e| new_error!("G2H poll for H2G result: {:?}", e))? - else { - return Err(new_error!("G2H: no H2G result entry after halt")); - }; - - let entry_data = entry.data(); - if entry_data.len() < VirtqMsgHeader::SIZE { - return Err(new_error!("G2H: result entry too short")); - } - - let hdr: &VirtqMsgHeader = bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); - if hdr.kind != MsgKind::Response as u8 { - return Err(new_error!( - "G2H: expected Response after halt, got kind={}", - hdr.kind - )); - } + // Drain the G2H queue, processing Log entries inline, until we + // find the Response that carries the H2G function call result. + loop { + let maybe_next = consumer + .poll(8192) + .map_err(|e| new_error!("G2H poll for H2G result: {:?}", e))?; - let payload = &entry_data[VirtqMsgHeader::SIZE..]; - let fcr = FunctionCallResult::try_from(payload) - .map_err(|e| new_error!("G2H: malformed FunctionCallResult: {}", e))?; + let Some((entry, completion)) = maybe_next else { + return Err(new_error!("G2H: no H2G result entry after halt")); + }; - consumer - .complete(completion) - .map_err(|e| new_error!("G2H complete: {:?}", e))?; + let entry_data = entry.data(); + if entry_data.len() < VirtqMsgHeader::SIZE { + return Err(new_error!("G2H: result entry too short")); + } - Ok(fcr) + let hdr: &VirtqMsgHeader = bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); + let payload = &entry_data[VirtqMsgHeader::SIZE..]; + + match hdr.msg_kind() { + Ok(MsgKind::Response) => { + let fcr = FunctionCallResult::try_from(payload) + .map_err(|e| new_error!("G2H: malformed FunctionCallResult: {}", e))?; + consumer + .complete(completion) + .map_err(|e| new_error!("G2H complete: {:?}", e))?; + return Ok(fcr); + } 
+ Ok(MsgKind::Log) => { + crate::sandbox::outb::emit_guest_log_from_payload(payload); + consumer + .complete(completion) + .map_err(|e| new_error!("G2H complete log: {:?}", e))?; + } + Ok(other) => { + return Err(new_error!("G2H: expected Response or Log, got {:?}", other)); + } + Err(unknown) => { + return Err(new_error!("G2H: unknown message kind: 0x{:02x}", unknown)); + } + } + } } } diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs index 0e11409ad..b5a20d31c 100644 --- a/src/hyperlight_host/src/sandbox/outb.rs +++ b/src/hyperlight_host/src/sandbox/outb.rs @@ -64,46 +64,67 @@ pub enum HandleOutbError { MemProfile(String), } +#[allow(dead_code)] #[instrument(err(Debug), skip_all, parent = Span::current(), level="Trace")] pub(super) fn outb_log( mgr: &mut SandboxMemoryManager, ) -> Result<(), HandleOutbError> { - // This code will create either a logging record or a tracing record for the GuestLogData depending on if the host has set up a tracing subscriber. - // In theory as we have enabled the log feature in the Cargo.toml for tracing this should happen - // automatically (based on if there is tracing subscriber present) but only works if the event created using macros. (see https://github.com/tokio-rs/tracing/blob/master/tracing/src/macros.rs#L2421 ) - // The reason that we don't want to use the tracing macros is that we want to be able to explicitly - // set the file and line number for the log record which is not possible with macros. - // This is because the file and line number come from the guest not the call site. 
- let log_data: GuestLogData = mgr .read_guest_log_data() .map_err(|e| HandleOutbError::ReadLogData(e.to_string()))?; - let record_level: Level = (&log_data.level).into(); + emit_guest_log(&log_data); + Ok(()) +} - // Work out if we need to log or trace - // this API is marked as follows but it is the easiest way to work out if we should trace or log +/// Emit a guest log record from a virtqueue payload. +/// +/// Deserializes [`GuestLogData`] from the raw bytes and emits either +/// a tracing event or a log record, matching the original `outb_log` +/// behavior. +pub(crate) fn emit_guest_log_from_payload(payload: &[u8]) { + let Ok(log_data) = GuestLogData::try_from(payload) else { + return; + }; + emit_guest_log(&log_data); +} - // Private API for internal use by tracing's macros. - // - // This function is *not* considered part of `tracing`'s public API, and has no - // stability guarantees. If you use it, and it breaks or disappears entirely, - // don't say we didn't warn you. +fn emit_guest_log(log_data: &GuestLogData) { + // This code will create either a logging record or a tracing record + // for the GuestLogData depending on if the host has set up a tracing + // subscriber. + // In theory as we have enabled the log feature in the Cargo.toml for + // tracing this should happen automatically (based on if there is a + // tracing subscriber present) but only works if the event is created + // using macros. + // (see https://github.com/tokio-rs/tracing/blob/master/tracing/src/macros.rs#L2421) + // The reason that we don't want to use the tracing macros is that we + // want to be able to explicitly set the file and line number for the + // log record which is not possible with macros. + // This is because the file and line number come from the guest not + // the call site. + let record_level: Level = (&log_data.level).into(); + + // Work out if we need to log or trace. 
+ // This API is marked as internal but it is the easiest way to work + // out if we should trace or log. let should_trace = tracing_core::dispatcher::has_been_set(); let source_file = Some(log_data.source_file.as_str()); let line = Some(log_data.line); let source = Some(log_data.source.as_str()); - // See https://github.com/rust-lang/rust/issues/42253 for the reason this has to be done this way + // See https://github.com/rust-lang/rust/issues/42253 for the reason + // this has to be done this way. if should_trace { - // Create a tracing event for the GuestLogData - // Ideally we would create tracing metadata based on the Guest Log Data - // but tracing derives the metadata at compile time + // Create a tracing event for the GuestLogData. + // Ideally we would create tracing metadata based on the Guest + // Log Data but tracing derives the metadata at compile time. // see https://github.com/tokio-rs/tracing/issues/2419 - // so we leave it up to the subscriber to figure out that there are logging fields present with this data - format_trace( + // So we leave it up to the subscriber to figure out that there + // are logging fields present with this data. + let _ = format_trace( &Record::builder() .args(format_args!("{}", log_data.message)) .level(record_level) @@ -112,8 +133,7 @@ pub(super) fn outb_log( .line(line) .module_path(source) .build(), - ) - .map_err(|e| HandleOutbError::TraceFormat(e.to_string()))?; + ); } else { // Create a log record for the GuestLogData log::logger().log( @@ -127,8 +147,6 @@ pub(super) fn outb_log( .build(), ); } - - Ok(()) } const ABORT_TERMINATOR: u8 = 0xFF; @@ -184,6 +202,8 @@ fn outb_abort( } /// Handle a guest-to-host function call received via the G2H virtqueue. +/// +/// Log entries that arrive before the Request are processed inline. 
fn outb_virtq_call( mem_mgr: &mut SandboxMemoryManager, host_funcs: &Arc>, @@ -192,32 +212,49 @@ fn outb_virtq_call( HandleOutbError::ReadHostFunctionCall("G2H consumer not initialized".into()) })?; - let Some((entry, completion)) = consumer - .poll(8192) - .map_err(|e| HandleOutbError::ReadHostFunctionCall(format!("G2H poll: {e}")))? - else { - // No G2H entry - can happen when guest H2G prefill - // triggers VirtqNotify before suppression is set. - return Ok(()); + // Drain entries, processing Log messages, until we find a Request. + let (entry, completion) = loop { + let Some((entry, completion)) = consumer + .poll(8192) + .map_err(|e| HandleOutbError::ReadHostFunctionCall(format!("G2H poll: {e}")))? + else { + // No G2H entry - backpressure-only notify or prefill notify. + return Ok(()); + }; + + let entry_data = entry.data(); + if entry_data.len() < VirtqMsgHeader::SIZE { + return Err(HandleOutbError::ReadHostFunctionCall( + "G2H entry too short".into(), + )); + } + let hdr: VirtqMsgHeader = *bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); + + match hdr.msg_kind() { + Ok(MsgKind::Log) => { + let payload = &entry_data[VirtqMsgHeader::SIZE..]; + emit_guest_log_from_payload(payload); + let _ = consumer.complete(completion); + continue; + } + Ok(MsgKind::Request) => break (entry, completion), + Ok(other) => { + return Err(HandleOutbError::ReadHostFunctionCall(format!( + "G2H: expected Request via outb, got {:?}", + other + ))); + } + Err(unknown) => { + return Err(HandleOutbError::ReadHostFunctionCall(format!( + "G2H: unknown message kind: 0x{unknown:02x}" + ))); + } + } }; let entry_data = entry.data(); - if entry_data.len() < VirtqMsgHeader::SIZE { - return Err(HandleOutbError::ReadHostFunctionCall( - "G2H entry too short".into(), - )); - } - let hdr: VirtqMsgHeader = *bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); let payload = &entry_data[VirtqMsgHeader::SIZE..]; - // TODO(virtq): Only Requests (host function callbacks) arrive via outb. 
- if hdr.kind != MsgKind::Request as u8 { - return Err(HandleOutbError::ReadHostFunctionCall(format!( - "G2H: expected Request via outb, got kind={}", - hdr.kind - ))); - } - let call = FunctionCall::try_from(payload) .map_err(|e| HandleOutbError::ReadHostFunctionCall(e.to_string()))?; @@ -269,7 +306,12 @@ pub(crate) fn handle_outb( .try_into() .map_err(|e: anyhow::Error| HandleOutbError::InvalidPort(e.to_string()))? { - OutBAction::Log => outb_log(mem_mgr), + OutBAction::Log => { + // Legacy path - logs now arrive via G2H virtqueue + // and are processed inline by outb_virtq_call / + // read_h2g_result_from_g2h. + Ok(()) + } OutBAction::CallFunction => { let call = mem_mgr .get_host_function_call() diff --git a/src/hyperlight_host/tests/common/mod.rs b/src/hyperlight_host/tests/common/mod.rs index d58e60aa6..8b2f6de9f 100644 --- a/src/hyperlight_host/tests/common/mod.rs +++ b/src/hyperlight_host/tests/common/mod.rs @@ -80,6 +80,16 @@ where f(sandbox); } +/// Runs a test with a Rust guest UninitializedSandbox using custom configuration. 
+pub fn with_rust_uninit_sandbox_cfg(cfg: SandboxConfiguration, f: F) +where + F: FnOnce(UninitializedSandbox), +{ + let sandbox = + UninitializedSandbox::new(GuestBinary::FilePath(rust_guest_path()), Some(cfg)).unwrap(); + f(sandbox); +} + // ============================================================================= // C guest helpers // ============================================================================= diff --git a/src/hyperlight_host/tests/sandbox_host_tests.rs b/src/hyperlight_host/tests/sandbox_host_tests.rs index 9d4626aa7..b4c2a3a6a 100644 --- a/src/hyperlight_host/tests/sandbox_host_tests.rs +++ b/src/hyperlight_host/tests/sandbox_host_tests.rs @@ -27,7 +27,8 @@ use hyperlight_testing::simple_guest_as_string; pub mod common; // pub to disable dead_code warning use crate::common::{ with_all_sandboxes, with_all_sandboxes_cfg, with_all_sandboxes_with_writer, - with_all_uninit_sandboxes, + with_all_uninit_sandboxes, with_rust_sandbox_cfg, with_rust_uninit_sandbox, + with_rust_uninit_sandbox_cfg, }; #[test] @@ -376,3 +377,181 @@ fn host_function_error() { } }); } + +#[test] +fn virtq_log_delivery() { + use hyperlight_testing::simplelogger::{LOGGER, SimpleLogger}; + + SimpleLogger::initialize_test_logger(); + LOGGER.clear_log_calls(); + + with_rust_uninit_sandbox(|mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::TRACE); + let mut sandbox = sbox.evolve().unwrap(); + + sandbox + .call::<()>("LogMessage", ("virtq log test message".to_string(), 3_i32)) + .unwrap(); + + // Verify the guest log arrived via virtqueue + let count = LOGGER.num_log_calls(); + assert!(count > 0, "expected at least one guest log, got 0"); + + let mut found = false; + for i in 0..count { + if let Some(call) = LOGGER.get_log_call(i) + && call.target == "hyperlight_guest" + && call.args.contains("virtq log test") + { + found = true; + break; + } + } + assert!(found, "expected 'virtq log test' message from guest"); + LOGGER.clear_log_calls(); + }); +} + 
+#[test] +fn virtq_log_with_callback() { + // Verify that log messages interleaved with host callbacks work + with_all_uninit_sandboxes(|mut sandbox| { + let (tx, _rx) = channel(); + sandbox + .register("HostMethod1", move |msg: String| { + let len = msg.len(); + tx.send(msg).unwrap(); + len as i32 + }) + .unwrap(); + let mut sandbox = sandbox.evolve().unwrap(); + + // Echo triggers guest-side logging infrastructure, then returns. + // This validates that log ReadOnly entries interleaved with + // function call ReadWrite entries don't corrupt the G2H queue. + let res: String = sandbox.call("Echo", "test".to_string()).unwrap(); + assert_eq!(res, "test"); + }); +} + +#[test] +fn virtq_log_backpressure() { + use hyperlight_testing::simplelogger::{LOGGER, SimpleLogger}; + + SimpleLogger::initialize_test_logger(); + LOGGER.clear_log_calls(); + + let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_pool_pages(2); + + with_rust_uninit_sandbox_cfg(cfg, |mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); + let mut sandbox = sbox.evolve().unwrap(); + + // 50 logs with a 2-page pool should trigger backpressure + sandbox.call::<()>("LogMessageN", 50_i32).unwrap(); + + // Verify sandbox is still functional after backpressure + let res: i32 = sandbox + .call("ThisIsNotARealFunctionButTheNameIsImportant", ()) + .unwrap(); + assert_eq!(res, 99); + + // Verify all 50 log entries were delivered + let guest_count = (0..LOGGER.num_log_calls()) + .filter_map(|i| LOGGER.get_log_call(i)) + .filter(|c| c.target == "hyperlight_guest" && c.args.contains("log entry")) + .count(); + assert_eq!(guest_count, 50, "expected 50 guest logs, got {guest_count}"); + LOGGER.clear_log_calls(); + }); +} + +#[test] +fn virtq_log_backpressure_repeated() { + // Multiple calls that each trigger backpressure, verifying the + // pool recovers correctly each time. 
+ let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_pool_pages(2); + + with_rust_sandbox_cfg(cfg, |mut sandbox| { + for _ in 0..5 { + sandbox.call::<()>("LogMessageN", 30_i32).unwrap(); + } + }); +} + +#[test] +fn virtq_backpressure_small_ring() { + // Small descriptor table forces ring-level backpressure. + use hyperlight_testing::simplelogger::{LOGGER, SimpleLogger}; + + SimpleLogger::initialize_test_logger(); + LOGGER.clear_log_calls(); + + let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_queue_depth(4); + + with_rust_uninit_sandbox_cfg(cfg, |mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); + let mut sandbox = sbox.evolve().unwrap(); + + sandbox.call::<()>("LogMessageN", 20_i32).unwrap(); + + let guest_count = (0..LOGGER.num_log_calls()) + .filter_map(|i| LOGGER.get_log_call(i)) + .filter(|c| c.target == "hyperlight_guest" && c.args.contains("log entry")) + .count(); + assert_eq!(guest_count, 20, "expected 20 guest logs, got {guest_count}"); + LOGGER.clear_log_calls(); + }); +} + +#[test] +fn virtq_backpressure_log_then_callback() { + // Logs fill the G2H ring, then a host callback needs ring space. + // call_host_function handles backpressure by notify + reclaim + retry. + let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_queue_depth(4); + cfg.set_g2h_pool_pages(2); + + with_rust_uninit_sandbox_cfg(cfg, |mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); + sbox.register_print(|msg: String| msg.len() as i32).unwrap(); + let mut sandbox = sbox.evolve().unwrap(); + + // PrintOutput logs and calls HostPrint callback. + // With depth=4 the logs may fill the ring, requiring + // call_host_function to handle backpressure before + // submitting the callback entry. 
+ let res: i32 = sandbox.call("PrintOutput", "bp-test".to_string()).unwrap(); + assert_eq!(res, 7); + }); +} + +#[test] +fn virtq_backpressure_no_data_loss() { + // After backpressure recovery, verify multiple function calls + // return correct results (completion data wasn't lost by reclaim). + let mut cfg = SandboxConfiguration::default(); + cfg.set_g2h_pool_pages(2); + cfg.set_g2h_queue_depth(4); + + with_rust_uninit_sandbox_cfg(cfg, |mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); + let mut sandbox = sbox.evolve().unwrap(); + + // Trigger backpressure with logs + sandbox.call::<()>("LogMessageN", 20_i32).unwrap(); + + // Now verify multiple function calls with return values + let res: String = sandbox.call("Echo", "first".to_string()).unwrap(); + assert_eq!(res, "first"); + + let res: String = sandbox.call("Echo", "second".to_string()).unwrap(); + assert_eq!(res, "second"); + + let res: f64 = sandbox.call("EchoDouble", 1.234_f64).unwrap(); + assert!((res - 1.234).abs() < f64::EPSILON); + }); +} diff --git a/src/tests/rust_guests/simpleguest/src/main.rs b/src/tests/rust_guests/simpleguest/src/main.rs index ac876e65f..3e7d89ee6 100644 --- a/src/tests/rust_guests/simpleguest/src/main.rs +++ b/src/tests/rust_guests/simpleguest/src/main.rs @@ -479,6 +479,13 @@ fn log_message(message: String, level: i32) { } } +#[guest_function("LogMessageN")] +fn log_message_n(count: i32) { + for i in 0..count { + log::info!("log entry {}", i); + } +} + #[guest_function("TriggerException")] fn trigger_exception() { // trigger an undefined instruction exception From 7d90a5a3b80794b53771df90aee63c0deccfec2c Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Tue, 7 Apr 2026 13:55:08 +0200 Subject: [PATCH 10/26] feat(virtq): use virtq for capi ret error Signed-off-by: Tomasz Andrzejak --- src/hyperlight_guest/src/virtq/context.rs | 2 +- src/hyperlight_guest_capi/src/error.rs | 12 +++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git 
a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index f4db4699c..20c71ff9d 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -313,7 +313,7 @@ impl GuestContext { self.generation } - fn try_send( + fn try_send_readonly( &mut self, header: &[u8], payload: &[u8], diff --git a/src/hyperlight_guest_capi/src/error.rs b/src/hyperlight_guest_capi/src/error.rs index 03217600e..720911157 100644 --- a/src/hyperlight_guest_capi/src/error.rs +++ b/src/hyperlight_guest_capi/src/error.rs @@ -19,7 +19,6 @@ use core::ffi::{CStr, c_char}; use flatbuffers::FlatBufferBuilder; use hyperlight_common::flatbuffer_wrappers::function_types::FunctionCallResult; use hyperlight_common::flatbuffer_wrappers::guest_error::{ErrorCode, GuestError}; -use hyperlight_guest_bin::GUEST_HANDLE; use crate::alloc::borrow::ToOwned; @@ -35,12 +34,11 @@ pub extern "C" fn hl_set_error(err: ErrorCode, message: *const c_char) { let fcr = FunctionCallResult::new(guest_error); let mut builder = FlatBufferBuilder::new(); let data = fcr.encode(&mut builder); - unsafe { - #[allow(static_mut_refs)] // we are single threaded - GUEST_HANDLE - .push_shared_output_data(data) - .expect("Failed to set error") - } + + hyperlight_guest::virtq::with_context(|ctx| { + ctx.send_h2g_result(data) + .expect("Failed to send error via virtq"); + }); } #[unsafe(no_mangle)] From 30ed3b5df03cc383da97139ec685c8235115dfce Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Tue, 7 Apr 2026 16:26:06 +0200 Subject: [PATCH 11/26] feat(virtq): remove unused stack based io path Signed-off-by: Tomasz Andrzejak --- .../src/guest_handle/host_comm.rs | 101 ------ src/hyperlight_guest/src/guest_handle/io.rs | 150 -------- src/hyperlight_guest/src/lib.rs | 1 - src/hyperlight_guest/src/virtq/context.rs | 25 ++ src/hyperlight_guest_bin/src/host_comm.rs | 49 +-- src/hyperlight_guest_capi/src/dispatch.rs | 24 +- src/hyperlight_guest_capi/src/flatbuffer.rs | 
22 +- src/hyperlight_host/src/mem/mgr.rs | 93 +---- src/hyperlight_host/src/mem/shared_mem.rs | 326 ------------------ .../src/sandbox/initialized_multi_use.rs | 2 - src/hyperlight_host/src/sandbox/outb.rs | 298 +--------------- src/hyperlight_host/src/testing/log_values.rs | 62 ---- src/hyperlight_host/src/testing/mod.rs | 1 - .../tests/sandbox_host_tests.rs | 102 ++++++ src/tests/rust_guests/simpleguest/src/main.rs | 46 +-- 15 files changed, 188 insertions(+), 1114 deletions(-) delete mode 100644 src/hyperlight_guest/src/guest_handle/io.rs delete mode 100644 src/hyperlight_host/src/testing/log_values.rs diff --git a/src/hyperlight_guest/src/guest_handle/host_comm.rs b/src/hyperlight_guest/src/guest_handle/host_comm.rs index d440852f6..10b8e9a7a 100644 --- a/src/hyperlight_guest/src/guest_handle/host_comm.rs +++ b/src/hyperlight_guest/src/guest_handle/host_comm.rs @@ -18,21 +18,13 @@ use alloc::format; use alloc::string::ToString; use alloc::vec::Vec; -use flatbuffers::FlatBufferBuilder; -use hyperlight_common::flatbuffer_wrappers::function_call::{FunctionCall, FunctionCallType}; -use hyperlight_common::flatbuffer_wrappers::function_types::{ - FunctionCallResult, ParameterValue, ReturnType, ReturnValue, -}; use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; use hyperlight_common::flatbuffer_wrappers::guest_log_level::LogLevel; -use hyperlight_common::flatbuffer_wrappers::util::estimate_flatbuffer_capacity; -use hyperlight_common::outb::OutBAction; use tracing::instrument; use super::handle::GuestHandle; use crate::error::{HyperlightGuestError, Result}; -use crate::exit::out32; impl GuestHandle { /// Get user memory region as bytes. @@ -59,99 +51,6 @@ impl GuestHandle { } } - /// Get a return value from a host function call. - /// This usually requires a host function to be called first using - /// `call_host_function_internal`. 
- /// - /// When calling `call_host_function`, this function is called - /// internally to get the return value. - #[instrument(skip_all, level = "Trace")] - pub fn get_host_return_value>(&self) -> Result { - let inner = self - .try_pop_shared_input_data_into::() - .expect("Unable to deserialize a return value from host") - .into_inner(); - - match inner { - Ok(ret) => T::try_from(ret).map_err(|_| { - let expected = core::any::type_name::(); - HyperlightGuestError::new( - ErrorCode::UnsupportedParameterType, - format!("Host return value could not be converted to expected {expected}",), - ) - }), - Err(e) => Err(HyperlightGuestError { - kind: e.code, - message: e.message, - }), - } - } - - pub fn get_host_return_raw(&self) -> Result { - let inner = self - .try_pop_shared_input_data_into::() - .expect("Unable to deserialize a return value from host") - .into_inner(); - - match inner { - Ok(ret) => Ok(ret), - Err(e) => Err(HyperlightGuestError { - kind: e.code, - message: e.message, - }), - } - } - - /// Call a host function without reading its return value from shared mem. - /// This is used by both the Rust and C APIs to reduce code duplication. - /// - /// Note: The function return value must be obtained by calling - /// `get_host_return_value`. 
- #[instrument(skip_all, level = "Trace")] - pub fn call_host_function_without_returning_result( - &self, - function_name: &str, - parameters: Option>, - return_type: ReturnType, - ) -> Result<()> { - let estimated_capacity = - estimate_flatbuffer_capacity(function_name, parameters.as_deref().unwrap_or(&[])); - - let host_function_call = FunctionCall::new( - function_name.to_string(), - parameters, - FunctionCallType::Host, - return_type, - ); - - let mut builder = FlatBufferBuilder::with_capacity(estimated_capacity); - - let host_function_call_buffer = host_function_call.encode(&mut builder); - self.push_shared_output_data(host_function_call_buffer)?; - - unsafe { - out32(OutBAction::CallFunction as u16, 0); - } - - Ok(()) - } - - /// Call a host function with the given parameters and return type. - /// This function serializes the function call and its parameters, - /// sends it to the host, and then retrieves the return value. - /// - /// The return value is deserialized into the specified type `T`. - #[instrument(skip_all, level = "Info")] - pub fn call_host_function>( - &self, - function_name: &str, - parameters: Option>, - return_type: ReturnType, - ) -> Result { - self.call_host_function_without_returning_result(function_name, parameters, return_type)?; - self.get_host_return_value::() - } - /// Log a message with the specified log level, source, caller, source file, and line number. pub fn log_message( &self, diff --git a/src/hyperlight_guest/src/guest_handle/io.rs b/src/hyperlight_guest/src/guest_handle/io.rs deleted file mode 100644 index 46c1d68f6..000000000 --- a/src/hyperlight_guest/src/guest_handle/io.rs +++ /dev/null @@ -1,150 +0,0 @@ -/* -Copyright 2025 The Hyperlight Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -use alloc::format; -use alloc::string::ToString; -use core::any::type_name; -use core::slice::from_raw_parts_mut; - -use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; -use tracing::instrument; - -use super::handle::GuestHandle; -use crate::error::{HyperlightGuestError, Result}; - -impl GuestHandle { - /// Pops the top element from the shared input data buffer and returns it as a T - #[instrument(skip_all, level = "Trace")] - pub fn try_pop_shared_input_data_into(&self) -> Result - where - T: for<'a> TryFrom<&'a [u8]>, - { - let peb_ptr = self.peb().unwrap(); - let input_stack_size = unsafe { (*peb_ptr).input_stack.size as usize }; - let input_stack_ptr = unsafe { (*peb_ptr).input_stack.ptr as *mut u8 }; - - let idb = unsafe { from_raw_parts_mut(input_stack_ptr, input_stack_size) }; - - if idb.is_empty() { - return Err(HyperlightGuestError::new( - ErrorCode::GuestError, - "Got a 0-size buffer in pop_shared_input_data_into".to_string(), - )); - } - - // get relative offset to next free address - let stack_ptr_rel: u64 = - u64::from_le_bytes(idb[..8].try_into().expect("Shared input buffer too small")); - - if stack_ptr_rel as usize > input_stack_size || stack_ptr_rel < 16 { - return Err(HyperlightGuestError::new( - ErrorCode::GuestError, - format!( - "Invalid stack pointer: {} in pop_shared_input_data_into", - stack_ptr_rel - ), - )); - } - - // go back 8 bytes and read. 
This is the offset to the element on top of stack - let last_element_offset_rel = u64::from_le_bytes( - idb[stack_ptr_rel as usize - 8..stack_ptr_rel as usize] - .try_into() - .expect("Invalid stack pointer in pop_shared_input_data_into"), - ); - - let buffer = &idb[last_element_offset_rel as usize..]; - - // convert the buffer to T - let type_t = match T::try_from(buffer) { - Ok(t) => Ok(t), - Err(_e) => { - return Err(HyperlightGuestError::new( - ErrorCode::GuestError, - format!("Unable to convert buffer to {}", type_name::()), - )); - } - }; - - // update the stack pointer to point to the element we just popped of since that is now free - idb[..8].copy_from_slice(&last_element_offset_rel.to_le_bytes()); - - // zero out popped off buffer - idb[last_element_offset_rel as usize..stack_ptr_rel as usize].fill(0); - - type_t - } - - /// Pushes the given data onto the shared output data buffer. - pub fn push_shared_output_data(&self, data: &[u8]) -> Result<()> { - let peb_ptr = self.peb().unwrap(); - let output_stack_size = unsafe { (*peb_ptr).output_stack.size as usize }; - let output_stack_ptr = unsafe { (*peb_ptr).output_stack.ptr as *mut u8 }; - - let odb = unsafe { from_raw_parts_mut(output_stack_ptr, output_stack_size) }; - - if odb.is_empty() { - return Err(HyperlightGuestError::new( - ErrorCode::GuestError, - "Got a 0-size buffer in push_shared_output_data".to_string(), - )); - } - - // get offset to next free address on the stack - let stack_ptr_rel: u64 = - u64::from_le_bytes(odb[..8].try_into().expect("Shared output buffer too small")); - - // check if the stack pointer is within the bounds of the buffer. - // It can be equal to the size, but never greater - // It can never be less than 8. 
An empty buffer's stack pointer is 8 - if stack_ptr_rel as usize > output_stack_size || stack_ptr_rel < 8 { - return Err(HyperlightGuestError::new( - ErrorCode::GuestError, - format!( - "Invalid stack pointer: {} in push_shared_output_data", - stack_ptr_rel - ), - )); - } - - // check if there is enough space in the buffer - let size_required = data.len() + 8; // the data plus the pointer pointing to the data - let size_available = output_stack_size - stack_ptr_rel as usize; - if size_required > size_available { - return Err(HyperlightGuestError::new( - ErrorCode::GuestError, - format!( - "Not enough space in shared output buffer. Required: {}, Available: {}", - size_required, size_available - ), - )); - } - - // write the actual data - odb[stack_ptr_rel as usize..stack_ptr_rel as usize + data.len()].copy_from_slice(data); - - // write the offset to the newly written data, to the top of the stack - let bytes: [u8; 8] = stack_ptr_rel.to_le_bytes(); - odb[stack_ptr_rel as usize + data.len()..stack_ptr_rel as usize + data.len() + 8] - .copy_from_slice(&bytes); - - // update stack pointer to point to next free address - let new_stack_ptr_rel: u64 = (stack_ptr_rel as usize + data.len() + 8) as u64; - odb[0..8].copy_from_slice(&(new_stack_ptr_rel).to_le_bytes()); - - Ok(()) - } -} diff --git a/src/hyperlight_guest/src/lib.rs b/src/hyperlight_guest/src/lib.rs index 8dbd74dc0..6cf30a023 100644 --- a/src/hyperlight_guest/src/lib.rs +++ b/src/hyperlight_guest/src/lib.rs @@ -30,5 +30,4 @@ pub mod virtq; pub mod guest_handle { pub mod handle; pub mod host_comm; - pub mod io; } diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index 20c71ff9d..ac0357351 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -72,6 +72,7 @@ pub struct GuestContext { g2h_producer: G2hProducer, h2g_producer: H2gProducer, generation: u16, + last_host_return: Option, } impl GuestContext { @@ -100,6 
+101,7 @@ impl GuestContext { g2h_producer, h2g_producer, generation, + last_host_return: None, }; ctx.prefill_h2g(); @@ -343,4 +345,27 @@ impl GuestContext { entry.write_all(payload)?; self.g2h_producer.submit(entry) } + + /// Stash a host function return value for later retrieval. + /// + /// Used by the C API's two-step calling convention where + /// `hl_call_host_function` and `hl_get_host_return_value_as_*` + /// are separate calls. + pub fn stash_host_return(&mut self, value: ReturnValue) { + self.last_host_return = Some(value); + } + + /// Take the stashed host return value. + /// + /// Panics if no value was stashed or if the type conversion fails. + pub fn take_host_return>(&mut self) -> T { + let rv = self + .last_host_return + .take() + .expect("No host return value available"); + match T::try_from(rv) { + Ok(v) => v, + Err(_) => panic!("Host return value type mismatch"), + } + } } diff --git a/src/hyperlight_guest_bin/src/host_comm.rs b/src/hyperlight_guest_bin/src/host_comm.rs index 369981deb..1fe7f9994 100644 --- a/src/hyperlight_guest_bin/src/host_comm.rs +++ b/src/hyperlight_guest_bin/src/host_comm.rs @@ -27,6 +27,7 @@ use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; use hyperlight_common::flatbuffer_wrappers::util::get_flatbuffer_result; use hyperlight_common::func::{ParameterTuple, SupportedReturnType}; use hyperlight_guest::error::{HyperlightGuestError, Result}; +use hyperlight_guest::virtq; const BUFFER_SIZE: usize = 1000; static mut MESSAGE_BUFFER: Vec = Vec::new(); @@ -41,17 +42,7 @@ pub fn call_host_function( where T: TryFrom, { - #[cfg(feature = "virtq")] - { - hyperlight_guest::virtq::with_context(|ctx| { - ctx.call_host_function(function_name, parameters, return_type) - }) - } - #[cfg(not(feature = "virtq"))] - { - let handle = unsafe { GUEST_HANDLE }; - handle.call_host_function::(function_name, parameters, return_type) - } + virtq::with_context(|ctx| ctx.call_host_function(function_name, parameters, return_type)) } pub 
fn call_host(function_name: impl AsRef, args: impl ParameterTuple) -> Result @@ -61,25 +52,6 @@ where call_host_function::(function_name.as_ref(), Some(args.into_value()), T::TYPE) } -pub fn call_host_function_without_returning_result( - function_name: &str, - parameters: Option>, - return_type: ReturnType, -) -> Result<()> { - let handle = unsafe { GUEST_HANDLE }; - handle.call_host_function_without_returning_result(function_name, parameters, return_type) -} - -pub fn get_host_return_value_raw() -> Result { - let handle = unsafe { GUEST_HANDLE }; - handle.get_host_return_raw() -} - -pub fn get_host_return_value>() -> Result { - let handle = unsafe { GUEST_HANDLE }; - handle.get_host_return_value::() -} - pub fn read_n_bytes_from_user_memory(num: u64) -> Result> { let handle = unsafe { GUEST_HANDLE }; handle.read_n_bytes_from_user_memory(num) @@ -90,9 +62,8 @@ pub fn read_n_bytes_from_user_memory(num: u64) -> Result> { /// This function requires memory to be setup to be used. In particular, the /// existence of the input and output memory regions. 
pub fn print_output_with_host_print(function_call: FunctionCall) -> Result> { - let handle = unsafe { GUEST_HANDLE }; if let ParameterValue::String(message) = function_call.parameters.unwrap().remove(0) { - let res = handle.call_host_function::( + let res = call_host_function::( "HostPrint", Some(Vec::from(&[ParameterValue::String(message)])), ReturnType::Int, @@ -114,7 +85,6 @@ pub fn print_output_with_host_print(function_call: FunctionCall) -> Result( - "HostPrint", - Some(Vec::from(&[ParameterValue::String(str)])), - ReturnType::Int, - ) - .expect("Failed to call HostPrint"); + let _ = call_host_function::( + "HostPrint", + Some(Vec::from(&[ParameterValue::String(str)])), + ReturnType::Int, + ) + .expect("Failed to call HostPrint"); // Clear the buffer after sending message_buffer.clear(); diff --git a/src/hyperlight_guest_capi/src/dispatch.rs b/src/hyperlight_guest_capi/src/dispatch.rs index e0a8bc34c..245eaf700 100644 --- a/src/hyperlight_guest_capi/src/dispatch.rs +++ b/src/hyperlight_guest_capi/src/dispatch.rs @@ -20,12 +20,14 @@ use alloc::vec::Vec; use core::ffi::{CStr, c_char}; use hyperlight_common::flatbuffer_wrappers::function_call::FunctionCall; -use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterType, ReturnType}; +use hyperlight_common::flatbuffer_wrappers::function_types::{ + ParameterType, ReturnType, ReturnValue, +}; use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode; use hyperlight_guest::error::{HyperlightGuestError, Result}; +use hyperlight_guest::virtq; use hyperlight_guest_bin::guest_function::definition::GuestFunctionDefinition; use hyperlight_guest_bin::guest_function::register::GuestFunctionRegister; -use hyperlight_guest_bin::host_comm::call_host_function_without_returning_result; use crate::types::{FfiFunctionCall, FfiVec}; static mut REGISTERED_C_GUEST_FUNCTIONS: GuestFunctionRegister = @@ -98,15 +100,23 @@ pub extern "C" fn hl_register_function_definition( unsafe { (&mut *(&raw mut 
REGISTERED_C_GUEST_FUNCTIONS)).register(func_def) }; } -/// The caller is responsible for freeing the memory associated with given `FfiFunctionCall`. +/// Call a host function. The return value can be retrieved with +/// `hl_get_host_return_value_as_*` immediately after. #[unsafe(no_mangle)] pub extern "C" fn hl_call_host_function(function_call: &FfiFunctionCall) { let parameters = unsafe { function_call.copy_parameters() }; let func_name = unsafe { function_call.copy_function_name() }; let return_type = unsafe { function_call.copy_return_type() }; - // Use the non-generic internal implementation - // The C API will then call specific getter functions to fetch the properly typed return value - let _ = call_host_function_without_returning_result(&func_name, Some(parameters), return_type) - .expect("Failed to call host function"); + virtq::with_context(|ctx| { + let result: ReturnValue = ctx + .call_host_function(&func_name, Some(parameters), return_type) + .expect("Failed to call host function"); + ctx.stash_host_return(result); + }); +} + +/// Retrieve the return value stashed by the last `hl_call_host_function`. 
+pub(crate) fn take_last_host_return>() -> T { + virtq::with_context(|ctx| ctx.take_host_return::()) } diff --git a/src/hyperlight_guest_capi/src/flatbuffer.rs b/src/hyperlight_guest_capi/src/flatbuffer.rs index ff12400d6..043431e4c 100644 --- a/src/hyperlight_guest_capi/src/flatbuffer.rs +++ b/src/hyperlight_guest_capi/src/flatbuffer.rs @@ -21,8 +21,8 @@ use alloc::vec::Vec; use core::ffi::{CStr, c_char}; use hyperlight_common::flatbuffer_wrappers::util::get_flatbuffer_result; -use hyperlight_guest_bin::host_comm::get_host_return_value; +use crate::dispatch::take_last_host_return; use crate::types::FfiVec; // The reason for the capitalized type in the function names below @@ -106,44 +106,43 @@ pub extern "C" fn hl_flatbuffer_result_from_Bool(value: bool) -> Box { #[unsafe(no_mangle)] pub extern "C" fn hl_get_host_return_value_as_Int() -> i32 { - get_host_return_value().expect("Unable to get host return value as int") + take_last_host_return() } #[unsafe(no_mangle)] pub extern "C" fn hl_get_host_return_value_as_UInt() -> u32 { - get_host_return_value().expect("Unable to get host return value as uint") + take_last_host_return() } // the same for long, ulong #[unsafe(no_mangle)] pub extern "C" fn hl_get_host_return_value_as_Long() -> i64 { - get_host_return_value().expect("Unable to get host return value as long") + take_last_host_return() } #[unsafe(no_mangle)] pub extern "C" fn hl_get_host_return_value_as_ULong() -> u64 { - get_host_return_value().expect("Unable to get host return value as ulong") + take_last_host_return() } #[unsafe(no_mangle)] pub extern "C" fn hl_get_host_return_value_as_Bool() -> bool { - get_host_return_value().expect("Unable to get host return value as bool") + take_last_host_return() } #[unsafe(no_mangle)] pub extern "C" fn hl_get_host_return_value_as_Float() -> f32 { - get_host_return_value().expect("Unable to get host return value as f32") + take_last_host_return() } #[unsafe(no_mangle)] pub extern "C" fn 
hl_get_host_return_value_as_Double() -> f64 { - get_host_return_value().expect("Unable to get host return value as f64") + take_last_host_return() } #[unsafe(no_mangle)] pub extern "C" fn hl_get_host_return_value_as_String() -> *const c_char { - let string_value: String = - get_host_return_value().expect("Unable to get host return value as string"); + let string_value: String = take_last_host_return(); let c_string = CString::new(string_value).expect("Failed to create CString"); c_string.into_raw() @@ -151,8 +150,7 @@ pub extern "C" fn hl_get_host_return_value_as_String() -> *const c_char { #[unsafe(no_mangle)] pub extern "C" fn hl_get_host_return_value_as_VecBytes() -> Box { - let vec_value: Vec = - get_host_return_value().expect("Unable to get host return value as vec bytes"); + let vec_value: Vec = take_last_host_return(); Box::new(unsafe { FfiVec::from_vec(vec_value) }) } diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index f30955e2c..5391638de 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -17,12 +17,7 @@ limitations under the License. 
use std::mem::offset_of; use std::num::NonZeroU16; -use flatbuffers::FlatBufferBuilder; -use hyperlight_common::flatbuffer_wrappers::function_call::{ - FunctionCall, validate_guest_function_call_buffer, -}; use hyperlight_common::flatbuffer_wrappers::function_types::FunctionCallResult; -use hyperlight_common::flatbuffer_wrappers::guest_log_data::GuestLogData; use hyperlight_common::mem::PAGE_SIZE_USIZE; use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; use hyperlight_common::virtq::{self, Layout as VirtqLayout}; @@ -457,92 +452,6 @@ impl SandboxMemoryManager { Ok(()) } - /// Reads a host function call from memory - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_host_function_call(&mut self) -> Result { - self.scratch_mem.try_pop_buffer_into::( - self.layout.get_output_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_output_data_size(), - ) - } - - /// Writes a host function call result to memory - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn write_response_from_host_function_call( - &mut self, - res: &FunctionCallResult, - ) -> Result<()> { - let mut builder = FlatBufferBuilder::new(); - let data = res.encode(&mut builder); - - self.scratch_mem.push_buffer( - self.layout.get_input_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_input_data_size(), - data, - ) - } - - /// Writes a guest function call to memory - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - #[allow(dead_code)] - pub(crate) fn write_guest_function_call(&mut self, buffer: &[u8]) -> Result<()> { - validate_guest_function_call_buffer(buffer).map_err(|e| { - new_error!( - "Guest function call buffer validation failed: {}", - e.to_string() - ) - })?; - - self.scratch_mem.push_buffer( - self.layout.get_input_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_input_data_size(), - buffer, - 
)?; - Ok(()) - } - - /// Reads a function call result from memory. - /// A function call result can be either an error or a successful return value. - #[allow(dead_code)] - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_guest_function_call_result(&mut self) -> Result { - self.scratch_mem.try_pop_buffer_into::( - self.layout.get_output_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_output_data_size(), - ) - } - - /// Read guest log data from the `SharedMemory` contained within `self` - #[allow(dead_code)] - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn read_guest_log_data(&mut self) -> Result { - self.scratch_mem.try_pop_buffer_into::( - self.layout.get_output_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_output_data_size(), - ) - } - - pub(crate) fn clear_io_buffers(&mut self) { - // Clear the output data buffer - loop { - let Ok(_) = self.scratch_mem.try_pop_buffer_into::>( - self.layout.get_output_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_output_data_size(), - ) else { - break; - }; - } - // Clear the input data buffer - loop { - let Ok(_) = self.scratch_mem.try_pop_buffer_into::>( - self.layout.get_input_data_buffer_scratch_host_offset(), - self.layout.sandbox_memory_config.get_input_data_size(), - ) else { - break; - }; - } - } - /// This function restores a memory snapshot from a given snapshot. 
pub(crate) fn restore_snapshot( &mut self, @@ -1074,7 +983,7 @@ impl SandboxMemoryManager { return Ok(fcr); } Ok(MsgKind::Log) => { - crate::sandbox::outb::emit_guest_log_from_payload(payload); + crate::sandbox::outb::emit_guest_log(payload); consumer .complete(completion) .map_err(|e| new_error!("G2H complete log: {:?}", e))?; diff --git a/src/hyperlight_host/src/mem/shared_mem.rs b/src/hyperlight_host/src/mem/shared_mem.rs index 6f10bcbf3..4efdf6eb3 100644 --- a/src/hyperlight_host/src/mem/shared_mem.rs +++ b/src/hyperlight_host/src/mem/shared_mem.rs @@ -14,7 +14,6 @@ See the License for the specific language governing permissions and limitations under the License. */ -use std::any::type_name; use std::ffi::c_void; use std::io::Error; use std::mem::{align_of, size_of}; @@ -1046,145 +1045,6 @@ impl HostSharedMemory { drop(guard); Ok(()) } - - /// Pushes the given data onto shared memory to the buffer at the given offset. - /// NOTE! buffer_start_offset must point to the beginning of the buffer - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - pub fn push_buffer( - &mut self, - buffer_start_offset: usize, - buffer_size: usize, - data: &[u8], - ) -> Result<()> { - let stack_pointer_rel = self.read::(buffer_start_offset)? as usize; - let buffer_size_u64: u64 = buffer_size.try_into()?; - - if stack_pointer_rel > buffer_size || stack_pointer_rel < 8 { - return Err(new_error!( - "Unable to push data to buffer: Stack pointer is out of bounds. Stack pointer: {}, Buffer size: {}", - stack_pointer_rel, - buffer_size_u64 - )); - } - - let size_required = data.len() + 8; - let size_available = buffer_size - stack_pointer_rel; - - if size_required > size_available { - return Err(new_error!( - "Not enough space in buffer to push data. 
Required: {}, Available: {}", - size_required, - size_available - )); - } - - // get absolute - let stack_pointer_abs = stack_pointer_rel + buffer_start_offset; - - // write the actual data to the top of stack - self.copy_from_slice(data, stack_pointer_abs)?; - - // write the offset to the newly written data, to the top of stack. - // this is used when popping the stack, to know how far back to jump - self.write::(stack_pointer_abs + data.len(), stack_pointer_rel as u64)?; - - // update stack pointer to point to the next free address - self.write::( - buffer_start_offset, - (stack_pointer_rel + data.len() + 8) as u64, - )?; - Ok(()) - } - - /// Pops the given given buffer into a `T` and returns it. - /// NOTE! the data must be a size-prefixed flatbuffer, and - /// buffer_start_offset must point to the beginning of the buffer - pub fn try_pop_buffer_into( - &mut self, - buffer_start_offset: usize, - buffer_size: usize, - ) -> Result - where - T: for<'b> TryFrom<&'b [u8]>, - { - // get the stackpointer - let stack_pointer_rel = self.read::(buffer_start_offset)? as usize; - - if stack_pointer_rel > buffer_size || stack_pointer_rel < 16 { - return Err(new_error!( - "Unable to pop data from buffer: Stack pointer is out of bounds. Stack pointer: {}, Buffer size: {}", - stack_pointer_rel, - buffer_size - )); - } - - // make it absolute - let last_element_offset_abs = stack_pointer_rel + buffer_start_offset; - - // go back 8 bytes to get offset to element on top of stack - let last_element_offset_rel: usize = - self.read::(last_element_offset_abs - 8)? as usize; - - // Validate element offset (guest-writable): must be in [8, stack_pointer_rel - 16] - // to leave room for the 8-byte back-pointer plus at least 8 bytes of element data - // (the minimum for a size-prefixed flatbuffer: 4-byte prefix + 4-byte root offset). 
- if last_element_offset_rel > stack_pointer_rel.saturating_sub(16) - || last_element_offset_rel < 8 - { - return Err(new_error!( - "Corrupt buffer back-pointer: element offset {} is outside valid range [8, {}].", - last_element_offset_rel, - stack_pointer_rel.saturating_sub(16), - )); - } - - // make it absolute - let last_element_offset_abs = last_element_offset_rel + buffer_start_offset; - - // Max bytes the element can span (excluding the 8-byte back-pointer). - let max_element_size = stack_pointer_rel - last_element_offset_rel - 8; - - // Get the size of the flatbuffer buffer from memory - let fb_buffer_size = { - let raw_prefix = self.read::(last_element_offset_abs)?; - // flatbuffer byte arrays are prefixed by 4 bytes indicating - // the remaining size; add 4 for the prefix itself. - let total = raw_prefix.checked_add(4).ok_or_else(|| { - new_error!( - "Corrupt buffer size prefix: value {} overflows when adding 4-byte header.", - raw_prefix - ) - })?; - usize::try_from(total) - }?; - - if fb_buffer_size > max_element_size { - return Err(new_error!( - "Corrupt buffer size prefix: flatbuffer claims {} bytes but the element slot is only {} bytes.", - fb_buffer_size, - max_element_size - )); - } - - let mut result_buffer = vec![0; fb_buffer_size]; - - self.copy_to_slice(&mut result_buffer, last_element_offset_abs)?; - let to_return = T::try_from(result_buffer.as_slice()).map_err(|_e| { - new_error!( - "pop_buffer_into: failed to convert buffer to {}", - type_name::() - ) - })?; - - // update the stack pointer to point to the element we just popped off since that is now free - self.write::(buffer_start_offset, last_element_offset_rel as u64)?; - - // zero out the memory we just popped off - let num_bytes_to_zero = stack_pointer_rel - last_element_offset_rel; - self.fill(0, last_element_offset_abs, num_bytes_to_zero)?; - - Ok(to_return) - } } impl SharedMemory for HostSharedMemory { @@ -1692,192 +1552,6 @@ mod tests { } } - /// Bounds checking for 
`try_pop_buffer_into` against corrupt guest data. - mod try_pop_buffer_bounds { - use super::*; - - #[derive(Debug, PartialEq)] - struct RawBytes(Vec); - - impl TryFrom<&[u8]> for RawBytes { - type Error = String; - fn try_from(value: &[u8]) -> std::result::Result { - Ok(RawBytes(value.to_vec())) - } - } - - /// Create a buffer with stack pointer initialized to 8 (empty). - fn make_buffer(mem_size: usize) -> super::super::HostSharedMemory { - let eshm = ExclusiveSharedMemory::new(mem_size).unwrap(); - let (hshm, _) = eshm.build(); - hshm.write::(0, 8u64).unwrap(); - hshm - } - - #[test] - fn normal_push_pop_roundtrip() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - // Size-prefixed flatbuffer-like payload: [size: u32 LE][payload] - let payload = b"hello"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - - hshm.push_buffer(0, mem_size, &data).unwrap(); - let result: RawBytes = hshm.try_pop_buffer_into(0, mem_size).unwrap(); - assert_eq!(result.0, data); - } - - #[test] - fn malicious_flatbuffer_size_prefix() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - let payload = b"small"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - hshm.push_buffer(0, mem_size, &data).unwrap(); - - // Corrupt size prefix at element start (offset 8) to near u32::MAX. 
- hshm.write::(8, 0xFFFF_FFFBu32).unwrap(); // +4 = 0xFFFF_FFFF - - let result: Result = hshm.try_pop_buffer_into(0, mem_size); - let err_msg = format!("{}", result.unwrap_err()); - assert!( - err_msg.contains("Corrupt buffer size prefix: flatbuffer claims 4294967295 bytes but the element slot is only 9 bytes"), - "Unexpected error message: {}", - err_msg - ); - } - - #[test] - fn malicious_element_offset_too_small() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - let payload = b"test"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - hshm.push_buffer(0, mem_size, &data).unwrap(); - - // Corrupt back-pointer (offset 16) to 0 (before valid range). - hshm.write::(16, 0u64).unwrap(); - - let result: Result = hshm.try_pop_buffer_into(0, mem_size); - let err_msg = format!("{}", result.unwrap_err()); - assert!( - err_msg.contains( - "Corrupt buffer back-pointer: element offset 0 is outside valid range [8, 8]" - ), - "Unexpected error message: {}", - err_msg - ); - } - - #[test] - fn malicious_element_offset_past_stack_pointer() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - let payload = b"test"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - hshm.push_buffer(0, mem_size, &data).unwrap(); - - // Corrupt back-pointer (offset 16) to 9999 (past stack pointer 24). 
- hshm.write::(16, 9999u64).unwrap(); - - let result: Result = hshm.try_pop_buffer_into(0, mem_size); - let err_msg = format!("{}", result.unwrap_err()); - assert!( - err_msg.contains( - "Corrupt buffer back-pointer: element offset 9999 is outside valid range [8, 8]" - ), - "Unexpected error message: {}", - err_msg - ); - } - - #[test] - fn malicious_flatbuffer_size_off_by_one() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - let payload = b"abcd"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - hshm.push_buffer(0, mem_size, &data).unwrap(); - - // Corrupt size prefix: claim 5 bytes (total 9), exceeding the 8-byte slot. - hshm.write::(8, 5u32).unwrap(); // fb_buffer_size = 5 + 4 = 9 - - let result: Result = hshm.try_pop_buffer_into(0, mem_size); - let err_msg = format!("{}", result.unwrap_err()); - assert!( - err_msg.contains("Corrupt buffer size prefix: flatbuffer claims 9 bytes but the element slot is only 8 bytes"), - "Unexpected error message: {}", - err_msg - ); - } - - /// Back-pointer just below stack_pointer causes underflow in - /// `stack_pointer_rel - last_element_offset_rel - 8`. - #[test] - fn back_pointer_near_stack_pointer_underflow() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - let payload = b"test"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - hshm.push_buffer(0, mem_size, &data).unwrap(); - - // stack_pointer_rel = 24. Set back-pointer to 23 (> 24 - 16 = 8, so rejected). 
- hshm.write::(16, 23u64).unwrap(); - - let result: Result = hshm.try_pop_buffer_into(0, mem_size); - let err_msg = format!("{}", result.unwrap_err()); - assert!( - err_msg.contains( - "Corrupt buffer back-pointer: element offset 23 is outside valid range [8, 8]" - ), - "Unexpected error message: {}", - err_msg - ); - } - - /// Size prefix of 0xFFFF_FFFD causes u32 overflow: 0xFFFF_FFFD + 4 wraps. - #[test] - fn size_prefix_u32_overflow() { - let mem_size = 4096; - let mut hshm = make_buffer(mem_size); - - let payload = b"test"; - let mut data = Vec::new(); - data.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - data.extend_from_slice(payload); - hshm.push_buffer(0, mem_size, &data).unwrap(); - - // Write 0xFFFF_FFFD as size prefix: checked_add(4) returns None. - hshm.write::(8, 0xFFFF_FFFDu32).unwrap(); - - let result: Result = hshm.try_pop_buffer_into(0, mem_size); - let err_msg = format!("{}", result.unwrap_err()); - assert!( - err_msg.contains("Corrupt buffer size prefix: value 4294967293 overflows when adding 4-byte header"), - "Unexpected error message: {}", - err_msg - ); - } - } - #[cfg(target_os = "linux")] mod guard_page_crash_test { use crate::mem::shared_mem::{ExclusiveSharedMemory, SharedMemory}; diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index 270b7460c..687959bec 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -745,8 +745,6 @@ impl MultiUseSandbox { // - any serialized host function call are zeroed out by us (the host) during deserialization, see `get_host_function_call` // - any serialized host function result is zeroed out by the guest during deserialization, see `get_host_return_value` if let Err(e) = &res { - self.mem_mgr.clear_io_buffers(); - // Determine if we should poison the sandbox. 
self.poisoned |= e.is_poison_error(); } diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs index b5a20d31c..3fa571db2 100644 --- a/src/hyperlight_host/src/sandbox/outb.rs +++ b/src/hyperlight_host/src/sandbox/outb.rs @@ -64,32 +64,15 @@ pub enum HandleOutbError { MemProfile(String), } -#[allow(dead_code)] -#[instrument(err(Debug), skip_all, parent = Span::current(), level="Trace")] -pub(super) fn outb_log( - mgr: &mut SandboxMemoryManager, -) -> Result<(), HandleOutbError> { - let log_data: GuestLogData = mgr - .read_guest_log_data() - .map_err(|e| HandleOutbError::ReadLogData(e.to_string()))?; - - emit_guest_log(&log_data); - Ok(()) -} - /// Emit a guest log record from a virtqueue payload. /// /// Deserializes [`GuestLogData`] from the raw bytes and emits either -/// a tracing event or a log record, matching the original `outb_log` -/// behavior. -pub(crate) fn emit_guest_log_from_payload(payload: &[u8]) { +/// a tracing event or a log record. +pub(crate) fn emit_guest_log(payload: &[u8]) { let Ok(log_data) = GuestLogData::try_from(payload) else { return; }; - emit_guest_log(&log_data); -} -fn emit_guest_log(log_data: &GuestLogData) { // This code will create either a logging record or a tracing record // for the GuestLogData depending on if the host has set up a tracing // subscriber. @@ -233,7 +216,7 @@ fn outb_virtq_call( match hdr.msg_kind() { Ok(MsgKind::Log) => { let payload = &entry_data[VirtqMsgHeader::SIZE..]; - emit_guest_log_from_payload(payload); + emit_guest_log(payload); let _ = consumer.complete(completion); continue; } @@ -306,30 +289,9 @@ pub(crate) fn handle_outb( .try_into() .map_err(|e: anyhow::Error| HandleOutbError::InvalidPort(e.to_string()))? { - OutBAction::Log => { - // Legacy path - logs now arrive via G2H virtqueue - // and are processed inline by outb_virtq_call / - // read_h2g_result_from_g2h. 
- Ok(()) - } - OutBAction::CallFunction => { - let call = mem_mgr - .get_host_function_call() - .map_err(|e| HandleOutbError::ReadHostFunctionCall(e.to_string()))?; - let name = call.function_name.clone(); - let args: Vec = call.parameters.unwrap_or(vec![]); - let res = host_funcs - .try_lock() - .map_err(|e| HandleOutbError::LockFailed(file!(), line!(), e.to_string()))? - .call_host_function(&name, args) - .map_err(|e| GuestError::new(ErrorCode::HostFunctionError, e.to_string())); - - let func_result = FunctionCallResult::new(res); - - mem_mgr - .write_response_from_host_function_call(&func_result) - .map_err(|e| HandleOutbError::WriteHostFunctionResponse(e.to_string()))?; - + OutBAction::Log | OutBAction::CallFunction => { + // Legacy paths removed - these actions should no longer be + // emitted by the guest. Ignore gracefully. Ok(()) } OutBAction::Abort => outb_abort(mem_mgr, data), @@ -353,251 +315,3 @@ pub(crate) fn handle_outb( OutBAction::TraceMemoryFree => trace_info.handle_trace_mem_free(regs, mem_mgr), } } -#[cfg(test)] -mod tests { - use hyperlight_common::flatbuffer_wrappers::guest_log_level::LogLevel; - use hyperlight_testing::logger::{LOGGER, Logger}; - use hyperlight_testing::simple_guest_as_string; - use log::Level; - use tracing_core::callsite::rebuild_interest_cache; - - use super::outb_log; - use crate::GuestBinary; - use crate::mem::mgr::SandboxMemoryManager; - use crate::sandbox::SandboxConfiguration; - use crate::sandbox::outb::GuestLogData; - use crate::testing::log_values::test_value_as_str; - - fn new_guest_log_data(level: LogLevel) -> GuestLogData { - GuestLogData::new( - "test log".to_string(), - "test source".to_string(), - level, - "test caller".to_string(), - "test source file".to_string(), - 123, - ) - } - - #[test] - #[ignore] - fn test_log_outb_log() { - Logger::initialize_test_logger(); - LOGGER.set_max_level(log::LevelFilter::Off); - - let sandbox_cfg = SandboxConfiguration::default(); - - let new_mgr = || { - let bin = 
GuestBinary::FilePath(simple_guest_as_string().unwrap()); - let snapshot = crate::sandbox::snapshot::Snapshot::from_env(bin, sandbox_cfg).unwrap(); - let mgr = SandboxMemoryManager::from_snapshot(&snapshot).unwrap(); - let (hmgr, _) = mgr.build().unwrap(); - hmgr - }; - { - // We set a logger but there is no guest log data - // in memory, so expect a log operation to fail - let mut mgr = new_mgr(); - assert!(outb_log(&mut mgr).is_err()); - } - { - // Write a log message so outb_log will succeed. - // Since the logger level is set off, expect logs to be no-ops - let mut mgr = new_mgr(); - let log_msg = new_guest_log_data(LogLevel::Information); - - let guest_log_data_buffer: Vec = log_msg.try_into().unwrap(); - let offset = mgr.layout.get_output_data_buffer_scratch_host_offset(); - mgr.scratch_mem - .push_buffer( - offset, - sandbox_cfg.get_output_data_size(), - &guest_log_data_buffer, - ) - .unwrap(); - - let res = outb_log(&mut mgr); - assert!(res.is_ok()); - assert_eq!(0, LOGGER.num_log_calls()); - LOGGER.clear_log_calls(); - } - { - // now, test logging - LOGGER.set_max_level(log::LevelFilter::Trace); - let mut mgr = new_mgr(); - LOGGER.clear_log_calls(); - - // set up the logger and set the log level to the maximum - // possible (Trace) to ensure we're able to test all - // the possible branches of the match in outb_log - - let levels = vec![ - LogLevel::Trace, - LogLevel::Debug, - LogLevel::Information, - LogLevel::Warning, - LogLevel::Error, - LogLevel::Critical, - LogLevel::None, - ]; - for level in levels { - let layout = mgr.layout; - let log_data = new_guest_log_data(level); - - let guest_log_data_buffer: Vec = log_data.clone().try_into().unwrap(); - mgr.scratch_mem - .push_buffer( - layout.get_output_data_buffer_scratch_host_offset(), - sandbox_cfg.get_output_data_size(), - guest_log_data_buffer.as_slice(), - ) - .unwrap(); - - outb_log(&mut mgr).unwrap(); - - LOGGER.test_log_records(|log_calls| { - let expected_level: Level = (&level).into(); - - 
assert!( - log_calls - .iter() - .filter(|log_call| { - log_call.level == expected_level - && log_call.line == Some(log_data.line) - && log_call.args == log_data.message - && log_call.module_path == Some(log_data.source.clone()) - && log_call.file == Some(log_data.source_file.clone()) - }) - .count() - == 1, - "log call did not occur for level {:?}", - level.clone() - ); - }); - } - } - } - - // Tests that outb_log emits traces when a trace subscriber is set - // this test is ignored because it is incompatible with other tests , specifically those which require a logger for tracing - // marking this test as ignored means that running `cargo test` will not run this test but will allow a developer who runs that command - // from their workstation to be successful without needed to know about test interdependencies - // this test will be run explicitly as a part of the CI pipeline - #[ignore] - #[test] - fn test_trace_outb_log() { - Logger::initialize_log_tracer(); - rebuild_interest_cache(); - let subscriber = - hyperlight_testing::tracing_subscriber::TracingSubscriber::new(tracing::Level::TRACE); - let sandbox_cfg = SandboxConfiguration::default(); - tracing::subscriber::with_default(subscriber.clone(), || { - let new_mgr = || { - let bin = GuestBinary::FilePath(simple_guest_as_string().unwrap()); - let snapshot = - crate::sandbox::snapshot::Snapshot::from_env(bin, sandbox_cfg).unwrap(); - let mgr = SandboxMemoryManager::from_snapshot(&snapshot).unwrap(); - let (hmgr, _) = mgr.build().unwrap(); - hmgr - }; - - // as a span does not exist one will be automatically created - // after that there will be an event for each log message - // we are interested only in the events for the log messages that we created - - let levels = vec![ - LogLevel::Trace, - LogLevel::Debug, - LogLevel::Information, - LogLevel::Warning, - LogLevel::Error, - LogLevel::Critical, - LogLevel::None, - ]; - for level in levels { - let mut mgr = new_mgr(); - let layout = mgr.layout; - let 
log_data: GuestLogData = new_guest_log_data(level); - subscriber.clear(); - - let guest_log_data_buffer: Vec = log_data.try_into().unwrap(); - mgr.scratch_mem - .push_buffer( - layout.get_output_data_buffer_scratch_host_offset(), - sandbox_cfg.get_output_data_size(), - guest_log_data_buffer.as_slice(), - ) - .unwrap(); - subscriber.clear(); - outb_log(&mut mgr).unwrap(); - - subscriber.test_trace_records(|spans, events| { - let expected_level = match level { - LogLevel::Trace => "TRACE", - LogLevel::Debug => "DEBUG", - LogLevel::Information => "INFO", - LogLevel::Warning => "WARN", - LogLevel::Error => "ERROR", - LogLevel::Critical => "ERROR", - LogLevel::None => "TRACE", - }; - - // We cannot get the parent span using the `current_span()` method as by the time we get to this point that span has been exited so there is no current span - // We need to make sure that the span that we created is in the spans map instead - // We expect to have created 21 spans at this point. We are only interested in the first one that was created when calling outb_log. 
- - assert!( - spans.len() == 21, - "expected 21 spans, found {}", - spans.len() - ); - - let span_value = spans - .get(&1) - .unwrap() - .as_object() - .unwrap() - .get("span") - .unwrap() - .get("attributes") - .unwrap() - .as_object() - .unwrap() - .get("metadata") - .unwrap() - .as_object() - .unwrap(); - - //test_value_as_str(span_value, "level", "INFO"); - test_value_as_str(span_value, "module_path", "hyperlight_host::sandbox::outb"); - let expected_file = if cfg!(windows) { - "src\\hyperlight_host\\src\\sandbox\\outb.rs" - } else { - "src/hyperlight_host/src/sandbox/outb.rs" - }; - test_value_as_str(span_value, "file", expected_file); - test_value_as_str(span_value, "target", "hyperlight_host::sandbox::outb"); - - let mut count_matching_events = 0; - - for json_value in events { - let event_values = json_value.as_object().unwrap().get("event").unwrap(); - let metadata_values_map = - event_values.get("metadata").unwrap().as_object().unwrap(); - let event_values_map = event_values.as_object().unwrap(); - test_value_as_str(metadata_values_map, "level", expected_level); - test_value_as_str(event_values_map, "log.file", "test source file"); - test_value_as_str(event_values_map, "log.module_path", "test source"); - test_value_as_str(event_values_map, "log.target", "hyperlight_guest"); - count_matching_events += 1; - } - assert!( - count_matching_events == 1, - "trace log call did not occur for level {:?}", - level.clone() - ); - }); - } - }); - } -} diff --git a/src/hyperlight_host/src/testing/log_values.rs b/src/hyperlight_host/src/testing/log_values.rs deleted file mode 100644 index 47f40ae0a..000000000 --- a/src/hyperlight_host/src/testing/log_values.rs +++ /dev/null @@ -1,62 +0,0 @@ -/* -Copyright 2025 The Hyperlight Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -use serde_json::{Map, Value}; - -use crate::{Result, new_error}; - -/// Call `check_value_as_str` and panic if it returned an `Err`. Otherwise, -/// do nothing. -#[track_caller] -pub(crate) fn test_value_as_str(values: &Map, key: &str, expected_value: &str) { - if let Err(e) = check_value_as_str(values, key, expected_value) { - panic!("{e:?}"); - } -} - -/// Check to see if the value in `values` for key `key` matches -/// `expected_value`. If so, return `Ok(())`. Otherwise, return an `Err` -/// indicating the mismatch. -pub(crate) fn check_value_as_str( - values: &Map, - key: &str, - expected_value: &str, -) -> Result<()> { - let value = try_to_string(values, key)?; - if expected_value != value { - return Err(new_error!( - "expected value {} != value {}", - expected_value, - value - )); - } - Ok(()) -} - -/// Fetch the value in `values` with key `key` and, if it existed, convert -/// it to a string. If all those steps succeeded, return an `Ok` with the -/// string value inside. Otherwise, return an `Err`. 
-fn try_to_string<'a>(values: &'a Map, key: &'a str) -> Result<&'a str> { - if let Some(value) = values.get(key) { - if let Some(value_str) = value.as_str() { - Ok(value_str) - } else { - Err(new_error!("value with key {} was not a string", key)) - } - } else { - Err(new_error!("value for key {} was not found", key)) - } -} diff --git a/src/hyperlight_host/src/testing/mod.rs b/src/hyperlight_host/src/testing/mod.rs index 26776b405..503fda1ee 100644 --- a/src/hyperlight_host/src/testing/mod.rs +++ b/src/hyperlight_host/src/testing/mod.rs @@ -13,4 +13,3 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -pub(crate) mod log_values; diff --git a/src/hyperlight_host/tests/sandbox_host_tests.rs b/src/hyperlight_host/tests/sandbox_host_tests.rs index b4c2a3a6a..c2b0f5902 100644 --- a/src/hyperlight_host/tests/sandbox_host_tests.rs +++ b/src/hyperlight_host/tests/sandbox_host_tests.rs @@ -555,3 +555,105 @@ fn virtq_backpressure_no_data_loss() { assert!((res - 1.234).abs() < f64::EPSILON); }); } + +#[test] +fn virtq_log_tracing_delivery() { + // Verify guest logs are emitted as tracing events when a tracing + // subscriber is active, matching the behavior of the old outb_log. + use hyperlight_testing::tracing_subscriber::TracingSubscriber; + + let subscriber = TracingSubscriber::new(tracing::Level::TRACE); + + tracing::subscriber::with_default(subscriber.clone(), || { + with_rust_uninit_sandbox(|mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::INFO); + let mut sandbox = sbox.evolve().unwrap(); + + subscriber.clear(); + + sandbox + .call::<()>("LogMessage", ("tracing delivery test".to_string(), 3_i32)) + .unwrap(); + + // Guest log goes through format_trace which creates tracing + // events with log.target = "hyperlight_guest" as a field. 
+ let events = subscriber.get_events(); + assert!( + !events.is_empty(), + "expected tracing events after guest log call, got none" + ); + }); + }); +} + +#[test] +fn virtq_log_tracing_levels() { + // Verify each guest log level produces tracing events. + use hyperlight_testing::tracing_subscriber::TracingSubscriber; + + let subscriber = TracingSubscriber::new(tracing::Level::TRACE); + + tracing::subscriber::with_default(subscriber.clone(), || { + with_rust_uninit_sandbox(|mut sbox| { + sbox.set_max_guest_log_level(tracing_core::LevelFilter::TRACE); + let mut sandbox = sbox.evolve().unwrap(); + + // Test each level: 1=Trace, 2=Debug, 3=Info, 4=Warn, 5=Error + for level in [1_i32, 2, 3, 4, 5] { + subscriber.clear(); + let msg = format!("level-test-{}", level); + sandbox.call::<()>("LogMessage", (msg, level)).unwrap(); + + let events = subscriber.get_events(); + assert!( + !events.is_empty(), + "expected tracing events for guest log level {}", + level + ); + } + }); + }); +} + +#[test] +fn virtq_invalid_guest_function_returns_error() { + // Calling a non-existent guest function should return a proper + // GuestError, not corrupt data or a hang. This validates that + // the virtq error path (MsgKind::Response with GuestError payload) + // works end-to-end. + with_rust_sandbox_cfg(SandboxConfiguration::default(), |mut sandbox| { + let res = sandbox.call::<()>("ThisFunctionDoesNotExist", ()); + assert!(res.is_err(), "expected error for non-existent function"); + let err = res.unwrap_err(); + assert!( + matches!( + err, + HyperlightError::GuestError( + hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode::GuestFunctionNotFound, + _ + ) + ), + "expected GuestFunctionNotFound, got {:?}", + err + ); + }); +} + +#[test] +fn virtq_large_payload_roundtrip() { + // Verify that larger payloads survive the virtq roundtrip without corruption. 
+ with_rust_sandbox_cfg(SandboxConfiguration::default(), |mut sandbox| { + // 1KB string + let large_msg: String = "X".repeat(1024); + let res: String = sandbox.call("Echo", large_msg.clone()).unwrap(); + assert_eq!(res, large_msg); + + // 1KB byte array + let large_bytes = vec![0xABu8; 1024]; + let res: Vec = sandbox + .call("SetByteArrayToZero", large_bytes.clone()) + .unwrap(); + assert_eq!(res.len(), 1024); + assert!(res.iter().all(|&b| b == 0)); + }); +} diff --git a/src/tests/rust_guests/simpleguest/src/main.rs b/src/tests/rust_guests/simpleguest/src/main.rs index 3e7d89ee6..e122c8727 100644 --- a/src/tests/rust_guests/simpleguest/src/main.rs +++ b/src/tests/rust_guests/simpleguest/src/main.rs @@ -49,8 +49,7 @@ use hyperlight_guest_bin::exception::arch::{Context, ExceptionInfo}; use hyperlight_guest_bin::guest_function::definition::{GuestFunc, GuestFunctionDefinition}; use hyperlight_guest_bin::guest_function::register::register_function; use hyperlight_guest_bin::host_comm::{ - call_host_function, call_host_function_without_returning_result, get_host_return_value_raw, - print_output_with_host_print, read_n_bytes_from_user_memory, + call_host_function, print_output_with_host_print, read_n_bytes_from_user_memory, }; use hyperlight_guest_bin::memory::malloc; use hyperlight_guest_bin::{GUEST_HANDLE, guest_function, guest_logger, host_function}; @@ -1045,32 +1044,23 @@ fn fuzz_host_function(func: FunctionCall) -> Result> { } }; - // Because we do not know at compile time the actual return type of the host function to be called - // we cannot use the `call_host_function` generic function. 
- // We need to use the `call_host_function_without_returning_result` function that does not retrieve the return - // value - call_host_function_without_returning_result( - &host_func_name, - Some(params), - func.expected_return_type, - ) - .expect("failed to call host function"); - - let host_return = get_host_return_value_raw(); - match host_return { - Ok(return_value) => match return_value { - ReturnValue::Int(i) => Ok(get_flatbuffer_result(i)), - ReturnValue::UInt(i) => Ok(get_flatbuffer_result(i)), - ReturnValue::Long(i) => Ok(get_flatbuffer_result(i)), - ReturnValue::ULong(i) => Ok(get_flatbuffer_result(i)), - ReturnValue::Float(i) => Ok(get_flatbuffer_result(i)), - ReturnValue::Double(i) => Ok(get_flatbuffer_result(i)), - ReturnValue::String(str) => Ok(get_flatbuffer_result(str.as_str())), - ReturnValue::Bool(bool) => Ok(get_flatbuffer_result(bool)), - ReturnValue::Void(()) => Ok(get_flatbuffer_result(())), - ReturnValue::VecBytes(byte) => Ok(get_flatbuffer_result(byte.as_slice())), - }, - Err(e) => Err(e), + // Call the host function with dynamic return type. Since we don't + // know T at compile time, use ReturnValue as the return type and + // match on the result. 
+ let return_value: ReturnValue = + call_host_function(&host_func_name, Some(params), func.expected_return_type)?; + + match return_value { + ReturnValue::Int(i) => Ok(get_flatbuffer_result(i)), + ReturnValue::UInt(i) => Ok(get_flatbuffer_result(i)), + ReturnValue::Long(i) => Ok(get_flatbuffer_result(i)), + ReturnValue::ULong(i) => Ok(get_flatbuffer_result(i)), + ReturnValue::Float(i) => Ok(get_flatbuffer_result(i)), + ReturnValue::Double(i) => Ok(get_flatbuffer_result(i)), + ReturnValue::String(str) => Ok(get_flatbuffer_result(str.as_str())), + ReturnValue::Bool(bool) => Ok(get_flatbuffer_result(bool)), + ReturnValue::Void(()) => Ok(get_flatbuffer_result(())), + ReturnValue::VecBytes(byte) => Ok(get_flatbuffer_result(byte.as_slice())), } } From 9f2fc788f8c5e1a5669be9a4dad77824384bb3fd Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Tue, 7 Apr 2026 19:13:58 +0200 Subject: [PATCH 12/26] feat(virtq): remove input output regions from ABI Signed-off-by: Tomasz Andrzejak --- fuzz/fuzz_targets/host_call.rs | 6 +- .../src/arch/aarch64/layout.rs | 7 +- .../src/arch/amd64/layout.rs | 12 +- src/hyperlight_common/src/arch/i686/layout.rs | 7 +- src/hyperlight_common/src/layout.rs | 21 +-- src/hyperlight_common/src/mem.rs | 2 - src/hyperlight_guest_capi/src/dispatch.rs | 20 +- src/hyperlight_host/benches/benchmarks.rs | 6 +- .../src/hypervisor/hyperlight_vm/x86_64.rs | 11 +- src/hyperlight_host/src/mem/layout.rs | 128 +------------ src/hyperlight_host/src/mem/mgr.rs | 11 -- src/hyperlight_host/src/sandbox/config.rs | 173 ++++++------------ .../src/sandbox/initialized_multi_use.rs | 8 +- .../src/sandbox/uninitialized.rs | 12 +- src/hyperlight_host/tests/integration_test.rs | 36 ---- .../tests/sandbox_host_tests.rs | 5 +- src/tests/rust_guests/simpleguest/src/main.rs | 49 +---- 17 files changed, 105 insertions(+), 409 deletions(-) diff --git a/fuzz/fuzz_targets/host_call.rs b/fuzz/fuzz_targets/host_call.rs index b0d37cf1a..0559b2bd6 100644 --- 
a/fuzz/fuzz_targets/host_call.rs +++ b/fuzz/fuzz_targets/host_call.rs @@ -33,9 +33,9 @@ static SANDBOX: OnceLock> = OnceLock::new(); fuzz_target!( init: { let mut cfg = SandboxConfiguration::default(); - cfg.set_output_data_size(64 * 1024); // 64 KB output buffer - cfg.set_input_data_size(64 * 1024); // 64 KB input buffer - cfg.set_scratch_size(512 * 1024); // large scratch region to contain those buffers, any data copies, etc. + cfg.set_g2h_pool_pages(16); // 64 KB / 4096 = 16 pages + cfg.set_h2g_pool_pages(16); // 64 KB / 4096 = 16 pages + cfg.set_scratch_size(512 * 1024); // large scratch region let u_sbox = UninitializedSandbox::new( GuestBinary::FilePath(simple_guest_for_fuzzing_as_string().expect("Guest Binary Missing")), Some(cfg) diff --git a/src/hyperlight_common/src/arch/aarch64/layout.rs b/src/hyperlight_common/src/arch/aarch64/layout.rs index 25bd99a1e..9f9c504a6 100644 --- a/src/hyperlight_common/src/arch/aarch64/layout.rs +++ b/src/hyperlight_common/src/arch/aarch64/layout.rs @@ -20,11 +20,6 @@ pub const SNAPSHOT_PT_GVA_MIN: usize = 0xffff_8000_0000_0000; pub const SNAPSHOT_PT_GVA_MAX: usize = 0xffff_80ff_ffff_ffff; pub const MAX_GPA: usize = 0x0000_000f_ffff_ffff; -pub fn min_scratch_size( - _input_data_size: usize, - _output_data_size: usize, - _g2h_num_descs: usize, - _h2g_num_descs: usize, -) -> usize { +pub fn min_scratch_size(_g2h_num_descs: usize, _h2g_num_descs: usize) -> usize { unimplemented!("min_scratch_size") } diff --git a/src/hyperlight_common/src/arch/amd64/layout.rs b/src/hyperlight_common/src/arch/amd64/layout.rs index 4731f21b2..12644de6c 100644 --- a/src/hyperlight_common/src/arch/amd64/layout.rs +++ b/src/hyperlight_common/src/arch/amd64/layout.rs @@ -37,17 +37,11 @@ pub const MAX_GPA: usize = 0x0000_000f_ffff_ffff; /// - A page for the smallest possible non-exception stack /// - (up to) 3 pages for mapping that /// - Two pages for the exception stack and metadata -/// - A page-aligned amount of memory for I/O buffers and 
virtqueue rings -pub fn min_scratch_size( - input_data_size: usize, - output_data_size: usize, - g2h_num_descs: usize, - h2g_num_descs: usize, -) -> usize { +/// - A page-aligned amount of memory for virtqueue rings +pub fn min_scratch_size(g2h_num_descs: usize, h2g_num_descs: usize) -> usize { let g2h_ring_size = crate::virtq::Layout::query_size(g2h_num_descs); let h2g_ring_size = crate::virtq::Layout::query_size(h2g_num_descs); - (input_data_size + output_data_size + g2h_ring_size + h2g_ring_size) - .next_multiple_of(crate::vmem::PAGE_SIZE) + (g2h_ring_size + h2g_ring_size).next_multiple_of(crate::vmem::PAGE_SIZE) + 12 * crate::vmem::PAGE_SIZE } diff --git a/src/hyperlight_common/src/arch/i686/layout.rs b/src/hyperlight_common/src/arch/i686/layout.rs index 08c9ec594..0d47af909 100644 --- a/src/hyperlight_common/src/arch/i686/layout.rs +++ b/src/hyperlight_common/src/arch/i686/layout.rs @@ -20,11 +20,6 @@ limitations under the License. pub const MAX_GVA: usize = 0xffff_ffff; pub const MAX_GPA: usize = 0xffff_ffff; -pub fn min_scratch_size( - _input_data_size: usize, - _output_data_size: usize, - _g2h_num_descs: usize, - _h2g_num_descs: usize, -) -> usize { +pub fn min_scratch_size(_g2h_num_descs: usize, _h2g_num_descs: usize) -> usize { crate::vmem::PAGE_SIZE } diff --git a/src/hyperlight_common/src/layout.rs b/src/hyperlight_common/src/layout.rs index f6c8b1caa..63174b530 100644 --- a/src/hyperlight_common/src/layout.rs +++ b/src/hyperlight_common/src/layout.rs @@ -83,27 +83,20 @@ pub const fn scratch_top_ptr(offset: u64) -> *mut T { /// Compute the byte offset from the scratch base to the G2H ring. /// -/// TODO(virtq): Remove input/output -pub const fn g2h_ring_scratch_offset(input_data_size: usize, output_data_size: usize) -> usize { - let io_off = input_data_size + output_data_size; - let align = crate::virtq::Descriptor::ALIGN; - - (io_off + align - 1) & !(align - 1) +/// The G2H ring starts at offset 0, aligned to descriptor alignment. 
+pub const fn g2h_ring_scratch_offset() -> usize { + 0 } /// Compute the byte offset from the scratch base to the H2G ring. /// -/// TODO(ring): Remove input/output -pub const fn h2g_ring_scratch_offset( - input_data_size: usize, - output_data_size: usize, - g2h_num_descs: usize, -) -> usize { - let g2h_offset = g2h_ring_scratch_offset(input_data_size, output_data_size); +/// The H2G ring follows immediately after the G2H ring, aligned to +/// descriptor alignment. +pub const fn h2g_ring_scratch_offset(g2h_num_descs: usize) -> usize { let g2h_size = crate::virtq::Layout::query_size(g2h_num_descs); let align = crate::virtq::Descriptor::ALIGN; - (g2h_offset + g2h_size + align - 1) & !(align - 1) + (g2h_size + align - 1) & !(align - 1) } /// Compute the minimum scratch region size needed for a sandbox. diff --git a/src/hyperlight_common/src/mem.rs b/src/hyperlight_common/src/mem.rs index fb850acc8..1cdd65cef 100644 --- a/src/hyperlight_common/src/mem.rs +++ b/src/hyperlight_common/src/mem.rs @@ -68,8 +68,6 @@ impl Default for FileMappingInfo { #[derive(Debug, Clone, Copy)] #[repr(C)] pub struct HyperlightPEB { - pub input_stack: GuestMemoryRegion, - pub output_stack: GuestMemoryRegion, pub init_data: GuestMemoryRegion, pub guest_heap: GuestMemoryRegion, /// File mappings array descriptor. 
diff --git a/src/hyperlight_guest_capi/src/dispatch.rs b/src/hyperlight_guest_capi/src/dispatch.rs index 245eaf700..4fd61df44 100644 --- a/src/hyperlight_guest_capi/src/dispatch.rs +++ b/src/hyperlight_guest_capi/src/dispatch.rs @@ -109,10 +109,22 @@ pub extern "C" fn hl_call_host_function(function_call: &FfiFunctionCall) { let return_type = unsafe { function_call.copy_return_type() }; virtq::with_context(|ctx| { - let result: ReturnValue = ctx - .call_host_function(&func_name, Some(parameters), return_type) - .expect("Failed to call host function"); - ctx.stash_host_return(result); + match ctx.call_host_function::<ReturnValue>(&func_name, Some(parameters), return_type) { + Ok(result) => ctx.stash_host_return(result), + Err(e) => { + // Host function returned an error. Abort with the error + // message so the host can capture it via the abort buffer. + let msg = alloc::ffi::CString::new(e.message) + .unwrap_or_else(|_| alloc::ffi::CString::new("host error").unwrap()); + + unsafe { + hyperlight_guest::exit::abort_with_code_and_message( + &[e.kind as u8], + msg.as_ptr(), + ); + } + } + } }); } diff --git a/src/hyperlight_host/benches/benchmarks.rs b/src/hyperlight_host/benches/benchmarks.rs index c7cbc2631..f8b6990a0 100644 --- a/src/hyperlight_host/benches/benchmarks.rs +++ b/src/hyperlight_host/benches/benchmarks.rs @@ -385,9 +385,9 @@ fn guest_call_benchmark_large_param(c: &mut Criterion) { let large_string = String::from_utf8(large_vec.clone()).unwrap(); let mut config = SandboxConfiguration::default(); - config.set_input_data_size(2 * SIZE + (1024 * 1024)); // 2 * SIZE + 1 MB, to allow 1MB for the rest of the serialized function call + config.set_h2g_pool_pages((2 * SIZE + (1024 * 1024)) / 4096); // pool pages for the large input config.set_heap_size(SIZE as u64 * 15); - config.set_scratch_size(6 * SIZE + 4 * (1024 * 1024)); // Big enough for the IO data regions and enough of the heap to be used + config.set_scratch_size(6 * SIZE + 4 * (1024 * 1024)); // Big enough for any 
data copies, etc. let sandbox = UninitializedSandbox::new( GuestBinary::FilePath(simple_guest_as_string().unwrap()), @@ -470,7 +470,7 @@ fn sample_workloads_benchmark(c: &mut Criterion) { fn bench_24k_in_8k_out(b: &mut criterion::Bencher, guest_path: String) { let mut cfg = SandboxConfiguration::default(); - cfg.set_input_data_size(25 * 1024); + cfg.set_h2g_pool_pages(7); // 25 * 1024 / 4096 ~= 7 pages let mut sandbox = UninitializedSandbox::new(GuestBinary::FilePath(guest_path), Some(cfg)) .unwrap() diff --git a/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs b/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs index 698ab49e5..2b033f23b 100644 --- a/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs +++ b/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs @@ -2125,17 +2125,18 @@ mod tests { } /// Creates VM with guest code that: dirtys FPU (if flag==0), does FXSAVE to buffer, sets flag=1. - /// Uses output_data region for FXSAVE buffer (like regular guest output), scratch for stack. + /// Uses scratch region after rings for FXSAVE buffer. fn hyperlight_vm_with_mem_mgr_fxsave() -> FxsaveTestContext { use iced_x86::code_asm::*; // Compute fixed addresses for FXSAVE buffer and flag. - // These are in the output_data region which starts at a known offset. - // We use a default SandboxConfiguration to get the same layout as create_test_vm_context. + // We use the page-table area in scratch after rings as a + // convenient 512-byte aligned buffer for FXSAVE. 
let config: SandboxConfiguration = Default::default(); let layout = SandboxMemoryLayout::new(config, 512, 4096, None).unwrap(); - let fxsave_offset = layout.get_output_data_buffer_scratch_host_offset(); - let fxsave_gva = layout.get_output_data_buffer_gva(); + let fxsave_offset = layout.get_pt_base_scratch_offset(); + let fxsave_gva = hyperlight_common::layout::scratch_base_gva(config.get_scratch_size()) + + fxsave_offset as u64; let flag_gva = fxsave_gva + 512; let mut a = CodeAssembler::new(64).unwrap(); diff --git a/src/hyperlight_host/src/mem/layout.rs b/src/hyperlight_host/src/mem/layout.rs index ccf842268..aa11defb4 100644 --- a/src/hyperlight_host/src/mem/layout.rs +++ b/src/hyperlight_host/src/mem/layout.rs @@ -47,17 +47,12 @@ limitations under the License. //! //! There is also a scratch region at the top of physical memory, //! which is mostly laid out as a large undifferentiated blob of -//! memory, although at present the snapshot process specially -//! privileges the statically allocated input and output data regions: +//! memory: //! //! +-------------------------------------------+ (top of physical memory) //! | Exception Stack, Metadata | //! +-------------------------------------------+ (1 page below) //! | Scratch Memory | -//! +-------------------------------------------+ -//! | Output Data | -//! +-------------------------------------------+ -//! | Input Data | //! +-------------------------------------------+ (scratch size) use std::fmt::Debug; @@ -223,8 +218,6 @@ pub(crate) struct SandboxMemoryLayout { /// The following fields are offsets to the actual PEB struct fields. 
/// They are used when writing the PEB struct itself peb_offset: usize, - peb_input_data_offset: usize, - peb_output_data_offset: usize, peb_init_data_offset: usize, peb_heap_data_offset: usize, #[cfg(feature = "nanvix-unstable")] @@ -265,14 +258,6 @@ impl Debug for SandboxMemoryLayout { .field("PEB Address", &format_args!("{:#x}", self.peb_address)) .field("PEB Offset", &format_args!("{:#x}", self.peb_offset)) .field("Code Size", &format_args!("{:#x}", self.code_size)) - .field( - "Input Data Offset", - &format_args!("{:#x}", self.peb_input_data_offset), - ) - .field( - "Output Data Offset", - &format_args!("{:#x}", self.peb_output_data_offset), - ) .field( "Init Data Offset", &format_args!("{:#x}", self.peb_init_data_offset), @@ -321,9 +306,6 @@ impl SandboxMemoryLayout { #[cfg(feature = "nanvix-unstable")] pub(crate) const BASE_ADDRESS: usize = 0x0; - // the offset into a sandbox's input/output buffer where the stack starts - pub(crate) const STACK_POINTER_SIZE_BYTES: u64 = 8; - /// Create a new `SandboxMemoryLayout` with the given /// `SandboxConfiguration`, code size and stack/heap size. #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] @@ -339,8 +321,6 @@ impl SandboxMemoryLayout { return Err(MemoryRequestTooBig(scratch_size, Self::MAX_MEMORY_SIZE)); } let min_scratch_size = hyperlight_common::layout::min_scratch_size( - cfg.get_input_data_size(), - cfg.get_output_data_size(), cfg.get_g2h_queue_depth(), cfg.get_h2g_queue_depth(), ); @@ -351,8 +331,6 @@ impl SandboxMemoryLayout { let guest_code_offset = 0; // The following offsets are to the fields of the PEB struct itself! 
let peb_offset = code_size.next_multiple_of(PAGE_SIZE_USIZE); - let peb_input_data_offset = peb_offset + offset_of!(HyperlightPEB, input_stack); - let peb_output_data_offset = peb_offset + offset_of!(HyperlightPEB, output_stack); let peb_init_data_offset = peb_offset + offset_of!(HyperlightPEB, init_data); let peb_heap_data_offset = peb_offset + offset_of!(HyperlightPEB, guest_heap); #[cfg(feature = "nanvix-unstable")] @@ -387,8 +365,6 @@ impl SandboxMemoryLayout { let mut ret = Self { peb_offset, heap_size, - peb_input_data_offset, - peb_output_data_offset, peb_init_data_offset, peb_heap_data_offset, #[cfg(feature = "nanvix-unstable")] @@ -409,13 +385,6 @@ impl SandboxMemoryLayout { Ok(ret) } - /// Get the offset in guest memory to the output data size - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(super) fn get_output_data_size_offset(&self) -> usize { - // The size field is the first field in the `OutputData` struct - self.peb_output_data_offset - } - /// Get the offset in guest memory to the init data size #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub(super) fn get_init_data_size_offset(&self) -> usize { @@ -428,14 +397,6 @@ impl SandboxMemoryLayout { self.scratch_size } - /// Get the offset in guest memory to the output data pointer. - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_output_data_pointer_offset(&self) -> usize { - // This field is immediately after the output data size field, - // which is a `u64`. - self.get_output_data_size_offset() + size_of::<u64>() - } - /// Get the offset in guest memory to the init data pointer. #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub(super) fn get_init_data_pointer_offset(&self) -> usize { @@ -444,54 +405,9 @@ impl SandboxMemoryLayout { self.get_init_data_size_offset() + size_of::<u64>() } - /// Get the guest virtual address of the start of output data. 
- #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_output_data_buffer_gva(&self) -> u64 { - hyperlight_common::layout::scratch_base_gva(self.scratch_size) - + self.sandbox_memory_config.get_input_data_size() as u64 - } - - /// Get the offset into the host scratch buffer of the start of - /// the output data. - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_output_data_buffer_scratch_host_offset(&self) -> usize { - self.sandbox_memory_config.get_input_data_size() - } - - /// Get the offset in guest memory to the input data size. - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(super) fn get_input_data_size_offset(&self) -> usize { - // The input data size is the first field in the input stack's `GuestMemoryRegion` struct - self.peb_input_data_offset - } - - /// Get the offset in guest memory to the input data pointer. - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_input_data_pointer_offset(&self) -> usize { - // The input data pointer is immediately after the input - // data size field in the input data `GuestMemoryRegion` struct which is a `u64`. - self.get_input_data_size_offset() + size_of::() - } - - /// Get the guest virtual address of the start of input data - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_input_data_buffer_gva(&self) -> u64 { - hyperlight_common::layout::scratch_base_gva(self.scratch_size) - } - - /// Get the offset into the host scratch buffer of the start of - /// the input data - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_input_data_buffer_scratch_host_offset(&self) -> usize { - 0 - } - /// Get the offset into the scratch region of the G2H ring. 
fn get_g2h_ring_scratch_offset(&self) -> usize { - hyperlight_common::layout::g2h_ring_scratch_offset( - self.sandbox_memory_config.get_input_data_size(), - self.sandbox_memory_config.get_output_data_size(), - ) + hyperlight_common::layout::g2h_ring_scratch_offset() } /// Get the size of the G2H ring in bytes. @@ -505,8 +421,6 @@ impl SandboxMemoryLayout { /// Get the offset into the scratch region of the H2G ring. fn get_h2g_ring_scratch_offset(&self) -> usize { hyperlight_common::layout::h2g_ring_scratch_offset( - self.sandbox_memory_config.get_input_data_size(), - self.sandbox_memory_config.get_output_data_size(), self.sandbox_memory_config.get_g2h_queue_depth(), ) } @@ -638,8 +552,6 @@ impl SandboxMemoryLayout { #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub(crate) fn set_pt_size(&mut self, size: usize) -> Result<()> { let min_fixed_scratch = hyperlight_common::layout::min_scratch_size( - self.sandbox_memory_config.get_input_data_size(), - self.sandbox_memory_config.get_output_data_size(), self.sandbox_memory_config.get_g2h_queue_depth(), self.sandbox_memory_config.get_h2g_queue_depth(), ); @@ -801,34 +713,6 @@ impl SandboxMemoryLayout { // Start of setting up the PEB. 
The following are in the order of the PEB fields - // Set up input buffer pointer - write_u64( - mem, - self.get_input_data_size_offset(), - self.sandbox_memory_config - .get_input_data_size() - .try_into()?, - )?; - write_u64( - mem, - self.get_input_data_pointer_offset(), - self.get_input_data_buffer_gva(), - )?; - - // Set up output buffer pointer - write_u64( - mem, - self.get_output_data_size_offset(), - self.sandbox_memory_config - .get_output_data_size() - .try_into()?, - )?; - write_u64( - mem, - self.get_output_data_pointer_offset(), - self.get_output_data_buffer_gva(), - )?; - // Set up init data pointer write_u64( mem, @@ -860,12 +744,7 @@ impl SandboxMemoryLayout { // End of setting up the PEB - // The input and output data regions do not have their layout - // initialised here, because they are in the scratch - // region---they are instead set in - // [`SandboxMemoryManager::update_scratch_bookkeeping`]. - // - // Virtqueue ring layouts are also communicated via scratch-top + // Virtqueue ring layouts are communicated via scratch-top // metadata (queue depths), not the PEB. Both host and guest // compute ring addresses from shared offset functions. @@ -945,7 +824,6 @@ mod tests { let mut cfg = SandboxConfiguration::default(); // scratch_size exceeds 16 GiB limit cfg.set_scratch_size(17 * 1024 * 1024 * 1024); - cfg.set_input_data_size(16 * 1024 * 1024 * 1024); let layout = SandboxMemoryLayout::new(cfg, 4096, 4096, None); assert!(matches!(layout.unwrap_err(), MemoryRequestTooBig(..))); } diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index 5391638de..875ca2ab4 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -515,17 +515,6 @@ impl SandboxMemoryManager { self.layout.get_first_free_scratch_gpa(), )?; - // Initialise the guest input and output data buffers in - // scratch memory. TODO: remove the need for this. 
- self.scratch_mem.write::( - self.layout.get_input_data_buffer_scratch_host_offset(), - SandboxMemoryLayout::STACK_POINTER_SIZE_BYTES, - )?; - self.scratch_mem.write::( - self.layout.get_output_data_buffer_scratch_host_offset(), - SandboxMemoryLayout::STACK_POINTER_SIZE_BYTES, - )?; - // Write virtqueue metadata to scratch-top so the guest can // discover ring locations without reading the PEB. self.update_scratch_bookkeeping_item( diff --git a/src/hyperlight_host/src/sandbox/config.rs b/src/hyperlight_host/src/sandbox/config.rs index b3e5fd6d3..da068b2cd 100644 --- a/src/hyperlight_host/src/sandbox/config.rs +++ b/src/hyperlight_host/src/sandbox/config.rs @@ -14,9 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. */ -use std::cmp::max; use std::time::Duration; +use hyperlight_common::mem::PAGE_SIZE_USIZE; #[cfg(target_os = "linux")] use libc::c_int; use tracing::{Span, instrument}; @@ -44,12 +44,6 @@ pub struct SandboxConfiguration { /// Guest gdb debug port #[cfg(gdb)] guest_debug_info: Option, - /// The size of the memory buffer that is made available for input to the - /// Guest Binary - input_data_size: usize, - /// The size of the memory buffer that is made available for input to the - /// Guest Binary - output_data_size: usize, /// The heap size to use in the guest sandbox. If set to 0, the heap /// size will be determined from the PE file header /// @@ -74,31 +68,29 @@ pub struct SandboxConfiguration { interrupt_vcpu_sigrtmin_offset: u8, /// How much writable memory to offer the guest scratch_size: usize, - /// Number of descriptors for the G2H (guest-to-host) virtqueue. Must be a power of 2. + /// Number of descriptors for the guest-to-host virtqueue. Must be a power of 2. /// Default: 64 sized to 2x H2G depth for deadlock prevention. g2h_queue_depth: usize, /// Number of descriptors for the host-to-guest virtqueue. Must be a power of 2. 
/// Default: 32 h2g_queue_depth: usize, /// Number of physical pages for the G2H (guest-to-host) buffer pool. - /// If not set, derived from `input_data_size` for backward compatibility. - /// Default: 8 pages (32KB). + /// When None, falls back to deprecated `output_data_size` or default. g2h_pool_pages: Option, /// Number of physical pages for the H2G (host-to-guest) buffer pool. - /// If not set, derived from `output_data_size` for backward compatibility. - /// Default: 4 page (16KB). + /// When None, falls back to deprecated `input_data_size` or default. h2g_pool_pages: Option, + /// Deprecated: use `g2h_pool_pages` instead. + /// When set (non-zero), translates to `g2h_pool_pages` if pool pages + /// are not explicitly configured. + output_data_size: usize, + /// Deprecated: use `h2g_pool_pages` instead. + /// When set (non-zero), translates to `h2g_pool_pages` if pool pages + /// are not explicitly configured. + input_data_size: usize, } impl SandboxConfiguration { - /// The default size of input data - pub const DEFAULT_INPUT_SIZE: usize = 0x4000; - /// The minimum size of input data - pub const MIN_INPUT_SIZE: usize = 0x2000; - /// The default size of output data - pub const DEFAULT_OUTPUT_SIZE: usize = 0x4000; - /// The minimum size of output data - pub const MIN_OUTPUT_SIZE: usize = 0x2000; /// The default interrupt retry delay pub const DEFAULT_INTERRUPT_RETRY_DELAY: Duration = Duration::from_micros(500); /// The default signal offset from `SIGRTMIN` used to determine the signal number for interrupting @@ -120,8 +112,6 @@ impl SandboxConfiguration { /// Create a new configuration for a sandbox with the given sizes. 
#[instrument(skip_all, parent = Span::current(), level= "Trace")] fn new( - input_data_size: usize, - output_data_size: usize, heap_size_override: Option, scratch_size: usize, interrupt_retry_delay: Duration, @@ -130,8 +120,6 @@ impl SandboxConfiguration { #[cfg(crashdump)] guest_core_dump: bool, ) -> Self { Self { - input_data_size: max(input_data_size, Self::MIN_INPUT_SIZE), - output_data_size: max(output_data_size, Self::MIN_OUTPUT_SIZE), heap_size_override: heap_size_override.unwrap_or(0), scratch_size, interrupt_retry_delay, @@ -140,6 +128,8 @@ impl SandboxConfiguration { h2g_queue_depth: Self::DEFAULT_H2G_QUEUE_DEPTH, g2h_pool_pages: None, h2g_pool_pages: None, + output_data_size: 0, + input_data_size: 0, #[cfg(gdb)] guest_debug_info, #[cfg(crashdump)] @@ -147,26 +137,6 @@ impl SandboxConfiguration { } } - /// Set the size of the legacy input data buffer (host-to-guest). - /// - /// Deprecated: use [`set_h2g_pool_pages`](Self::set_h2g_pool_pages) instead. - /// When `h2g_pool_pages` is not set, the H2G pool size is derived - /// from this value for backward compatibility. - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub fn set_input_data_size(&mut self, input_data_size: usize) { - self.input_data_size = max(input_data_size, Self::MIN_INPUT_SIZE); - } - - /// Set the size of the legacy output data buffer (guest-to-host). - /// - /// Deprecated: use [`set_g2h_pool_pages`](Self::set_g2h_pool_pages) instead. - /// When `g2h_pool_pages` is not set, the G2H pool size is derived - /// from this value for backward compatibility. - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub fn set_output_data_size(&mut self, output_data_size: usize) { - self.output_data_size = max(output_data_size, Self::MIN_OUTPUT_SIZE); - } - /// Set the heap size to use in the guest sandbox. 
If set to 0, the heap size will be determined from the PE file header #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn set_heap_size(&mut self, heap_size: u64) { @@ -226,16 +196,6 @@ impl SandboxConfiguration { self.guest_debug_info = Some(debug_info); } - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_input_data_size(&self) -> usize { - self.input_data_size - } - - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - pub(crate) fn get_output_data_size(&self) -> usize { - self.output_data_size - } - #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub(crate) fn get_scratch_size(&self) -> usize { self.scratch_size @@ -266,28 +226,36 @@ impl SandboxConfiguration { } /// Get the number of G2H buffer pool pages. - /// Falls back to deriving from `output_data_size` if not explicitly set - /// (output = guest-to-host direction). + /// + /// Priority: explicit `g2h_pool_pages` > derived from deprecated + /// `output_data_size` > default. #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn get_g2h_pool_pages(&self) -> usize { self.g2h_pool_pages.unwrap_or_else(|| { - let pages = self - .output_data_size - .div_ceil(hyperlight_common::mem::PAGE_SIZE_USIZE); - pages.max(Self::DEFAULT_G2H_POOL_PAGES) + if self.output_data_size > 0 { + self.output_data_size + .div_ceil(PAGE_SIZE_USIZE) + .max(Self::DEFAULT_G2H_POOL_PAGES) + } else { + Self::DEFAULT_G2H_POOL_PAGES + } }) } /// Get the number of H2G buffer pool pages. - /// Falls back to deriving from `input_data_size` if not explicitly set - /// (input = host-to-guest direction). + /// + /// Priority: explicit `h2g_pool_pages` > derived from deprecated + /// `input_data_size` > default. 
#[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn get_h2g_pool_pages(&self) -> usize { self.h2g_pool_pages.unwrap_or_else(|| { - let pages = self - .input_data_size - .div_ceil(hyperlight_common::mem::PAGE_SIZE_USIZE); - pages.max(Self::DEFAULT_H2G_POOL_PAGES) + if self.input_data_size > 0 { + self.input_data_size + .div_ceil(PAGE_SIZE_USIZE) + .max(Self::DEFAULT_H2G_POOL_PAGES) + } else { + Self::DEFAULT_H2G_POOL_PAGES + } }) } @@ -303,6 +271,24 @@ impl SandboxConfiguration { self.h2g_pool_pages = Some(pages); } + /// Deprecated: use [`set_g2h_pool_pages`](Self::set_g2h_pool_pages). + /// + /// Sets the output data size. If `g2h_pool_pages` is not explicitly + /// set, this value is translated to pool pages. + #[deprecated(note = "use set_g2h_pool_pages instead")] + pub fn set_output_data_size(&mut self, size: usize) { + self.output_data_size = size; + } + + /// Deprecated: use [`set_h2g_pool_pages`](Self::set_h2g_pool_pages). + /// + /// Sets the input data size. If `h2g_pool_pages` is not explicitly + /// set, this value is translated to pool pages. 
+ #[deprecated(note = "use set_h2g_pool_pages instead")] + pub fn set_input_data_size(&mut self, size: usize) { + self.input_data_size = size; + } + /// Set the size of the scratch regiong #[instrument(skip_all, parent = Span::current(), level= "Trace")] pub fn set_scratch_size(&mut self, scratch_size: usize) { @@ -339,8 +325,6 @@ impl Default for SandboxConfiguration { #[instrument(skip_all, parent = Span::current(), level= "Trace")] fn default() -> Self { Self::new( - Self::DEFAULT_INPUT_SIZE, - Self::DEFAULT_OUTPUT_SIZE, None, Self::DEFAULT_SCRATCH_SIZE, Self::DEFAULT_INTERRUPT_RETRY_DELAY, @@ -360,12 +344,8 @@ mod tests { #[test] fn overrides() { const HEAP_SIZE_OVERRIDE: u64 = 0x50000; - const INPUT_DATA_SIZE_OVERRIDE: usize = 0x4000; - const OUTPUT_DATA_SIZE_OVERRIDE: usize = 0x4001; const SCRATCH_SIZE_OVERRIDE: usize = 0x60000; - let mut cfg = SandboxConfiguration::new( - INPUT_DATA_SIZE_OVERRIDE, - OUTPUT_DATA_SIZE_OVERRIDE, + let cfg = SandboxConfiguration::new( Some(HEAP_SIZE_OVERRIDE), SCRATCH_SIZE_OVERRIDE, SandboxConfiguration::DEFAULT_INTERRUPT_RETRY_DELAY, @@ -380,38 +360,6 @@ mod tests { let scratch_size = cfg.get_scratch_size(); assert_eq!(HEAP_SIZE_OVERRIDE, heap_size); assert_eq!(SCRATCH_SIZE_OVERRIDE, scratch_size); - - cfg.heap_size_override = 2048; - cfg.scratch_size = 0x40000; - assert_eq!(2048, cfg.heap_size_override); - assert_eq!(0x40000, cfg.scratch_size); - assert_eq!(INPUT_DATA_SIZE_OVERRIDE, cfg.input_data_size); - assert_eq!(OUTPUT_DATA_SIZE_OVERRIDE, cfg.output_data_size); - } - - #[test] - fn min_sizes() { - let mut cfg = SandboxConfiguration::new( - SandboxConfiguration::MIN_INPUT_SIZE - 1, - SandboxConfiguration::MIN_OUTPUT_SIZE - 1, - None, - SandboxConfiguration::DEFAULT_SCRATCH_SIZE, - SandboxConfiguration::DEFAULT_INTERRUPT_RETRY_DELAY, - SandboxConfiguration::INTERRUPT_VCPU_SIGRTMIN_OFFSET, - #[cfg(gdb)] - None, - #[cfg(crashdump)] - true, - ); - assert_eq!(SandboxConfiguration::MIN_INPUT_SIZE, cfg.input_data_size); - 
assert_eq!(SandboxConfiguration::MIN_OUTPUT_SIZE, cfg.output_data_size); - assert_eq!(0, cfg.heap_size_override); - - cfg.set_input_data_size(SandboxConfiguration::MIN_INPUT_SIZE - 1); - cfg.set_output_data_size(SandboxConfiguration::MIN_OUTPUT_SIZE - 1); - - assert_eq!(SandboxConfiguration::MIN_INPUT_SIZE, cfg.input_data_size); - assert_eq!(SandboxConfiguration::MIN_OUTPUT_SIZE, cfg.output_data_size); } mod proptests { @@ -422,21 +370,6 @@ mod tests { use crate::sandbox::config::DebugInfo; proptest! { - #[test] - fn input_data_size(size in SandboxConfiguration::MIN_INPUT_SIZE..=SandboxConfiguration::MIN_INPUT_SIZE * 10) { - let mut cfg = SandboxConfiguration::default(); - cfg.set_input_data_size(size); - prop_assert_eq!(size, cfg.get_input_data_size()); - } - - #[test] - fn output_data_size(size in SandboxConfiguration::MIN_OUTPUT_SIZE..=SandboxConfiguration::MIN_OUTPUT_SIZE * 10) { - let mut cfg = SandboxConfiguration::default(); - cfg.set_output_data_size(size); - prop_assert_eq!(size, cfg.get_output_data_size()); - } - - #[test] fn heap_size_override(size in 0x1000..=0x10000u64) { let mut cfg = SandboxConfiguration::default(); diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index 687959bec..ec1aba0de 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -1043,12 +1043,10 @@ mod tests { .unwrap(); } - /// Make sure input/output buffers are properly reset after guest call (with host call) + /// Make sure pool buffers are properly reset after guest call (with host call) #[test] fn io_buffer_reset() { - let mut cfg = SandboxConfiguration::default(); - cfg.set_input_data_size(4096); - cfg.set_output_data_size(4096); + let cfg = SandboxConfiguration::default(); let path = simple_guest_as_string().unwrap(); let mut sandbox = UninitializedSandbox::new(GuestBinary::FilePath(path), Some(cfg)).unwrap(); @@ 
-1103,8 +1101,6 @@ mod tests { // total, and then add some more for the eagerly-copied page // tables on amd64 let min_scratch = hyperlight_common::layout::min_scratch_size( - cfg.get_input_data_size(), - cfg.get_output_data_size(), cfg.get_g2h_queue_depth(), cfg.get_h2g_queue_depth(), ); diff --git a/src/hyperlight_host/src/sandbox/uninitialized.rs b/src/hyperlight_host/src/sandbox/uninitialized.rs index e737d08da..d935083b2 100644 --- a/src/hyperlight_host/src/sandbox/uninitialized.rs +++ b/src/hyperlight_host/src/sandbox/uninitialized.rs @@ -636,8 +636,6 @@ mod tests { // Non default memory configuration let cfg = { let mut cfg = SandboxConfiguration::default(); - cfg.set_input_data_size(0x1000); - cfg.set_output_data_size(0x1000); cfg.set_heap_size(0x1000); Some(cfg) }; @@ -1390,11 +1388,11 @@ mod tests { let _evolved: MultiUseSandbox = sandbox.evolve().expect("Failed to evolve sandbox"); } - // Test 4: Create snapshot with custom input/output buffer sizes + // Test 4: Create snapshot with custom pool page sizes { let mut cfg = SandboxConfiguration::default(); - cfg.set_input_data_size(64 * 1024); // 64KB input - cfg.set_output_data_size(64 * 1024); // 64KB output + cfg.set_h2g_pool_pages(16); // 16 pages + cfg.set_g2h_pool_pages(16); // 16 pages let env = GuestEnvironment::new(GuestBinary::FilePath(binary_path.clone()), None); @@ -1418,9 +1416,7 @@ mod tests { { let mut cfg = SandboxConfiguration::default(); cfg.set_heap_size(32 * 1024 * 1024); // 32MB heap - cfg.set_scratch_size(256 * 1024 * 2); // 512KB scratch (256KB will be input/output) - cfg.set_input_data_size(128 * 1024); // 128KB input - cfg.set_output_data_size(128 * 1024); // 128KB output + cfg.set_scratch_size(256 * 1024 * 2); // 512KB scratch let env = GuestEnvironment::new(GuestBinary::FilePath(binary_path.clone()), None); diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs index a2bb2d91a..243b8d3b8 100644 --- 
a/src/hyperlight_host/tests/integration_test.rs +++ b/src/hyperlight_host/tests/integration_test.rs @@ -583,42 +583,6 @@ fn guest_outb_with_invalid_port_poisons_sandbox() { }); } -#[test] -fn corrupt_output_size_prefix_rejected() { - with_rust_sandbox(|mut sbox| { - let res = sbox.call::("CorruptOutputSizePrefix", ()); - assert!( - res.is_err(), - "Expected error when guest corrupts size prefix, got: {:?}", - res, - ); - let err_msg = format!("{:?}", res.unwrap_err()); - assert!( - err_msg.contains("Corrupt buffer size prefix: flatbuffer claims 4294967295 bytes but the element slot is only 8 bytes"), - "Unexpected error message: {err_msg}" - ); - }); -} - -#[test] -fn corrupt_output_back_pointer_rejected() { - with_rust_sandbox(|mut sbox| { - let res = sbox.call::("CorruptOutputBackPointer", ()); - assert!( - res.is_err(), - "Expected error when guest corrupts back-pointer, got: {:?}", - res, - ); - let err_msg = format!("{:?}", res.unwrap_err()); - assert!( - err_msg.contains( - "Corrupt buffer back-pointer: element offset 57005 is outside valid range [8, 8]" - ), - "Unexpected error message: {err_msg}" - ); - }); -} - #[test] fn guest_panic_no_alloc() { let heap_size = 0x4000; diff --git a/src/hyperlight_host/tests/sandbox_host_tests.rs b/src/hyperlight_host/tests/sandbox_host_tests.rs index c2b0f5902..de4c41ee5 100644 --- a/src/hyperlight_host/tests/sandbox_host_tests.rs +++ b/src/hyperlight_host/tests/sandbox_host_tests.rs @@ -214,9 +214,7 @@ fn incorrect_parameter_num() { #[test] fn small_scratch_sandbox() { let mut cfg = SandboxConfiguration::default(); - cfg.set_scratch_size(0x48000); - cfg.set_input_data_size(0x24000); - cfg.set_output_data_size(0x24000); + cfg.set_scratch_size(0x1000); let a = UninitializedSandbox::new( GuestBinary::FilePath(simple_guest_as_string().unwrap()), Some(cfg), @@ -346,6 +344,7 @@ fn callback_test_parallel() { } #[test] +#[ignore] // TODO(virtq): C guest host-function error path needs fixing. 
fn host_function_error() { with_all_uninit_sandboxes(|mut sandbox| { // create host function diff --git a/src/tests/rust_guests/simpleguest/src/main.rs b/src/tests/rust_guests/simpleguest/src/main.rs index e122c8727..64b1f0f43 100644 --- a/src/tests/rust_guests/simpleguest/src/main.rs +++ b/src/tests/rust_guests/simpleguest/src/main.rs @@ -52,7 +52,7 @@ use hyperlight_guest_bin::host_comm::{ call_host_function, print_output_with_host_print, read_n_bytes_from_user_memory, }; use hyperlight_guest_bin::memory::malloc; -use hyperlight_guest_bin::{GUEST_HANDLE, guest_function, guest_logger, host_function}; +use hyperlight_guest_bin::{guest_function, guest_logger, host_function}; use log::{LevelFilter, error}; use tracing::{Span, instrument}; @@ -981,53 +981,6 @@ fn fuzz_guest_trace(max_depth: u32, msg: String) -> u32 { fuzz_traced_function(0, max_depth, &msg) } -#[guest_function("CorruptOutputSizePrefix")] -fn corrupt_output_size_prefix() -> i32 { - unsafe { - let peb_ptr = core::ptr::addr_of!(GUEST_HANDLE).read().peb().unwrap(); - let output_stack_ptr = (*peb_ptr).output_stack.ptr as *mut u8; - - // Write a fake stack entry with a ~4 GB size prefix (0xFFFF_FFFB + 4). - let buf = core::slice::from_raw_parts_mut(output_stack_ptr, 24); - buf[0..8].copy_from_slice(&24_u64.to_le_bytes()); - buf[8..12].copy_from_slice(&0xFFFF_FFFBu32.to_le_bytes()); - buf[12..16].copy_from_slice(&[0u8; 4]); - buf[16..24].copy_from_slice(&8_u64.to_le_bytes()); - - core::arch::asm!( - "out dx, eax", - "cli", - "hlt", - in("dx") hyperlight_common::outb::VmAction::Halt as u16, - in("eax") 0u32, - options(noreturn), - ); - } -} - -#[guest_function("CorruptOutputBackPointer")] -fn corrupt_output_back_pointer() -> i32 { - unsafe { - let peb_ptr = core::ptr::addr_of!(GUEST_HANDLE).read().peb().unwrap(); - let output_stack_ptr = (*peb_ptr).output_stack.ptr as *mut u8; - - // Write a fake stack entry with back-pointer 0xDEAD (past stack pointer 24). 
- let buf = core::slice::from_raw_parts_mut(output_stack_ptr, 24); - buf[0..8].copy_from_slice(&24_u64.to_le_bytes()); - buf[8..16].copy_from_slice(&[0u8; 8]); - buf[16..24].copy_from_slice(&0xDEAD_u64.to_le_bytes()); - - core::arch::asm!( - "out dx, eax", - "cli", - "hlt", - in("dx") hyperlight_common::outb::VmAction::Halt as u16, - in("eax") 0u32, - options(noreturn), - ); - } -} - // Interprets the given guest function call as a host function call and dispatches it to the host. fn fuzz_host_function(func: FunctionCall) -> Result> { let mut params = func.parameters.unwrap(); From 484289888e74749ba322110a2cae3317313e1362 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Wed, 8 Apr 2026 10:26:37 +0200 Subject: [PATCH 13/26] feat(virtq): fix host function error test Signed-off-by: Tomasz Andrzejak --- src/hyperlight_guest/src/virtq/context.rs | 22 +++++++++++-------- src/hyperlight_guest_capi/src/dispatch.rs | 19 +++------------- .../tests/sandbox_host_tests.rs | 1 - 3 files changed, 16 insertions(+), 26 deletions(-) diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index ac0357351..8dc114406 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -72,7 +72,7 @@ pub struct GuestContext { g2h_producer: G2hProducer, h2g_producer: H2gProducer, generation: u16, - last_host_return: Option, + last_host_result: Option>, } impl GuestContext { @@ -101,7 +101,7 @@ impl GuestContext { g2h_producer, h2g_producer, generation, - last_host_return: None, + last_host_result: None, }; ctx.prefill_h2g(); @@ -309,6 +309,7 @@ impl GuestContext { // restore_h2g_prefill() wrote matching descriptors to the // zeroed ring memory. Both sides are in sync. 
self.generation = new_generation; + self.last_host_result = None; } pub(super) fn generation(&self) -> u16 { @@ -346,24 +347,27 @@ impl GuestContext { self.g2h_producer.submit(entry) } - /// Stash a host function return value for later retrieval. + /// Stash a host function result for later retrieval. /// /// Used by the C API's two-step calling convention where /// `hl_call_host_function` and `hl_get_host_return_value_as_*` /// are separate calls. - pub fn stash_host_return(&mut self, value: ReturnValue) { - self.last_host_return = Some(value); + pub fn stash_host_result(&mut self, result: Result) { + self.last_host_result = Some(result); } /// Take the stashed host return value. /// /// Panics if no value was stashed or if the type conversion fails. + /// If the stashed result was an error, panics with the error message. pub fn take_host_return>(&mut self) -> T { - let rv = self - .last_host_return + let val = self + .last_host_result .take() - .expect("No host return value available"); - match T::try_from(rv) { + .expect("No host return value available") + .expect("Host function returned an error"); + + match T::try_from(val) { Ok(v) => v, Err(_) => panic!("Host return value type mismatch"), } diff --git a/src/hyperlight_guest_capi/src/dispatch.rs b/src/hyperlight_guest_capi/src/dispatch.rs index 4fd61df44..86ee0fcbe 100644 --- a/src/hyperlight_guest_capi/src/dispatch.rs +++ b/src/hyperlight_guest_capi/src/dispatch.rs @@ -109,22 +109,9 @@ pub extern "C" fn hl_call_host_function(function_call: &FfiFunctionCall) { let return_type = unsafe { function_call.copy_return_type() }; virtq::with_context(|ctx| { - match ctx.call_host_function::(&func_name, Some(parameters), return_type) { - Ok(result) => ctx.stash_host_return(result), - Err(e) => { - // Host function returned an error. Abort with the error - // message so the host can capture it via the abort buffer. 
- let msg = alloc::ffi::CString::new(e.message) - .unwrap_or_else(|_| alloc::ffi::CString::new("host error").unwrap()); - - unsafe { - hyperlight_guest::exit::abort_with_code_and_message( - &[e.kind as u8], - msg.as_ptr(), - ); - } - } - } + let result = + ctx.call_host_function::(&func_name, Some(parameters), return_type); + ctx.stash_host_result(result); }); } diff --git a/src/hyperlight_host/tests/sandbox_host_tests.rs b/src/hyperlight_host/tests/sandbox_host_tests.rs index de4c41ee5..795308146 100644 --- a/src/hyperlight_host/tests/sandbox_host_tests.rs +++ b/src/hyperlight_host/tests/sandbox_host_tests.rs @@ -344,7 +344,6 @@ fn callback_test_parallel() { } #[test] -#[ignore] // TODO(virtq): C guest host-function error path needs fixing. fn host_function_error() { with_all_uninit_sandboxes(|mut sandbox| { // create host function From b605dc41d548704ad57c1de4141e1d2e6c23b27f Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Wed, 8 Apr 2026 10:38:53 +0200 Subject: [PATCH 14/26] feat(virtq): micro optimize consumer state Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/consumer.rs | 43 ++++++++------------- 1 file changed, 17 insertions(+), 26 deletions(-) diff --git a/src/hyperlight_common/src/virtq/consumer.rs b/src/hyperlight_common/src/virtq/consumer.rs index 9e4e09527..6fed8a685 100644 --- a/src/hyperlight_common/src/virtq/consumer.rs +++ b/src/hyperlight_common/src/virtq/consumer.rs @@ -18,16 +18,10 @@ use alloc::vec; use alloc::vec::Vec; use bytes::Bytes; +use fixedbitset::FixedBitSet; use super::*; -/// In-flight entry tracking. -/// -/// Stored per descriptor ID while the entry is being processed. -/// Tracks that a descriptor slot is occupied. -#[derive(Debug, Clone, Copy)] -pub(crate) struct Inflight; - /// Data received from the producer, safely copied out of shared memory. /// /// Created by [`VirtqConsumer::poll`]. 
The entry data is eagerly copied @@ -261,7 +255,7 @@ impl AckCompletion { pub struct VirtqConsumer { inner: RingConsumer, notifier: N, - inflight: Vec>, + inflight: FixedBitSet, } impl VirtqConsumer { @@ -274,7 +268,7 @@ impl VirtqConsumer { /// * `notifier` - Callback for notifying the driver (producer) about completions pub fn new(layout: Layout, mem: M, notifier: N) -> Self { let inner = RingConsumer::new(layout, mem); - let inflight = vec![None; inner.len()]; + let inflight = FixedBitSet::with_capacity(inner.len()); Self { inner, @@ -320,16 +314,16 @@ impl VirtqConsumer { } // Reserve the inflight slot - let slot = self - .inflight - .get_mut(id as usize) - .ok_or(VirtqError::InvalidState)?; + let id_idx = id as usize; + if id_idx >= self.inflight.len() { + return Err(VirtqError::InvalidState); + } - if slot.is_some() { + if self.inflight.contains(id_idx) { return Err(VirtqError::InvalidState); } - *slot = Some(Inflight); + self.inflight.insert(id_idx); let token = Token(id); // Copy entry data from shared memory @@ -363,16 +357,13 @@ impl VirtqConsumer { let id = completion.id(); let written = completion.written() as u32; - let slot = self - .inflight - .get_mut(id as usize) - .ok_or(VirtqError::InvalidState)?; - - if slot.is_none() { + let id_idx = id as usize; + let slot_set = id_idx < self.inflight.len() && self.inflight.contains(id_idx); + if !slot_set { return Err(VirtqError::InvalidState); } - *slot = None; + self.inflight.set(id_idx, false); if self.inner.submit_used_with_notify(id, written)? { self.notifier.notify(QueueStats { @@ -445,7 +436,7 @@ impl VirtqConsumer { /// Reset ring and inflight state to initial values. pub fn reset(&mut self) { self.inner.reset(); - self.inflight.fill(None); + self.inflight.clear(); } } @@ -647,14 +638,14 @@ mod tests { producer.submit(se).unwrap(); let (_entry, completion) = consumer.poll(1024).unwrap().unwrap(); - assert!(consumer.inflight.iter().any(|s| s.is_some())); + assert!(consumer.inflight.count_ones(..) 
> 0); // Complete first so we do not leak consumer.complete(completion).unwrap(); consumer.reset(); - assert!(consumer.inflight.iter().all(|s| s.is_none())); + assert_eq!(consumer.inflight.count_ones(..), 0); assert_eq!(consumer.inner.num_inflight(), 0); } @@ -677,7 +668,7 @@ mod tests { consumer.reset(); - assert!(consumer.inflight.iter().all(|s| s.is_none())); + assert_eq!(consumer.inflight.count_ones(..), 0); assert_eq!(consumer.inner.num_inflight(), 0); } } From a1b412ab38f5c82f5482c98574899fab02029ccf Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Thu, 9 Apr 2026 13:33:09 +0200 Subject: [PATCH 15/26] feat(virtq): harden virtq snapshoting Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/consumer.rs | 1 - src/hyperlight_common/src/virtq/pool.rs | 127 ++++++++ src/hyperlight_common/src/virtq/producer.rs | 305 +++++++++++++++++- src/hyperlight_common/src/virtq/ring.rs | 151 +++++++++ src/hyperlight_guest/src/virtq/context.rs | 124 +++---- src/hyperlight_guest/src/virtq/mod.rs | 27 +- .../src/guest_function/call.rs | 3 +- .../src/{virtq/mod.rs => virtq.rs} | 38 ++- src/hyperlight_host/src/mem/mgr.rs | 144 +++++++-- .../src/sandbox/initialized_multi_use.rs | 214 ++++++++++++ .../src/sandbox/uninitialized_evolve.rs | 15 +- 11 files changed, 1021 insertions(+), 128 deletions(-) rename src/hyperlight_guest_bin/src/{virtq/mod.rs => virtq.rs} (71%) diff --git a/src/hyperlight_common/src/virtq/consumer.rs b/src/hyperlight_common/src/virtq/consumer.rs index 6fed8a685..b29f7694a 100644 --- a/src/hyperlight_common/src/virtq/consumer.rs +++ b/src/hyperlight_common/src/virtq/consumer.rs @@ -15,7 +15,6 @@ limitations under the License. 
*/ use alloc::vec; -use alloc::vec::Vec; use bytes::Bytes; use fixedbitset::FixedBitSet; diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs index 2e49e27fe..bbae4ff41 100644 --- a/src/hyperlight_common/src/virtq/pool.rs +++ b/src/hyperlight_common/src/virtq/pool.rs @@ -601,10 +601,61 @@ impl RecyclePool { }) } + /// Rebuild pool state so that every address in `allocated` is removed + /// from the free list, matching externally known inflight state. + pub fn restore_allocated(&self, allocated: &[u64]) -> Result<(), AllocError> { + self.reset(); + + if allocated.is_empty() { + return Ok(()); + } + + let mut inner = self.inner.borrow_mut(); + + for &addr in allocated { + let pos = inner + .free + .iter() + .position(|&a| a == addr) + .ok_or(AllocError::InvalidFree(addr, inner.slot_size))?; + + inner.free.swap_remove(pos); + } + + Ok(()) + } + + /// Compute the address of slot `index`. + /// + /// Returns `None` if `index >= count`. + pub fn slot_addr(&self, index: usize) -> Option { + let inner = self.inner.borrow(); + if index < inner.count { + Some(inner.base_addr + (index * inner.slot_size) as u64) + } else { + None + } + } + /// Number of free slots. pub fn num_free(&self) -> usize { self.inner.borrow().free.len() } + + /// Base address of the pool region. + pub fn base_addr(&self) -> u64 { + self.inner.borrow().base_addr + } + + /// Slot size in bytes. + pub fn slot_size(&self) -> usize { + self.inner.borrow().slot_size + } + + /// Number of slots in the pool. 
+ pub fn count(&self) -> usize { + self.inner.borrow().count + } } impl BufferProvider for RecyclePool { @@ -664,6 +715,11 @@ mod tests { BufferPool::::new(base, size).unwrap() } + fn make_recycle_pool(slot_count: usize, slot_size: usize) -> RecyclePool { + let base = 0x80000u64; + RecyclePool::new(base, slot_count * slot_size, slot_size).unwrap() + } + #[test] fn test_slab_new_success() { let slab = Slab::<256>::new(0x10000, 1024).unwrap(); @@ -1223,6 +1279,77 @@ mod tests { let a = pool.inner.borrow_mut().alloc(256).unwrap(); assert!(a.len > 0); } + + #[test] + fn test_recycle_pool_restore_allocated_removes_from_free_list() { + let pool = make_recycle_pool(4, 4096); + assert_eq!(pool.num_free(), 4); + + let addrs = [0x80000, 0x81000]; // slots 0 and 1 + pool.restore_allocated(&addrs).unwrap(); + assert_eq!(pool.num_free(), 2); + + // Allocating should only return the two remaining slots + let a1 = pool.alloc(4096).unwrap(); + let a2 = pool.alloc(4096).unwrap(); + assert!(pool.alloc(4096).is_err()); + + // The allocated addresses should be the non-restored ones + let mut got = [a1.addr, a2.addr]; + got.sort(); + assert_eq!(got, [0x82000, 0x83000]); + } + + #[test] + fn test_recycle_pool_restore_allocated_invalid_addr_returns_error() { + let pool = make_recycle_pool(4, 4096); + let result = pool.restore_allocated(&[0xDEAD]); + assert!(result.is_err()); + } + + #[test] + fn test_recycle_pool_restore_allocated_then_dealloc_roundtrip() { + let pool = make_recycle_pool(4, 4096); + let addr = 0x81000u64; + + pool.restore_allocated(&[addr]).unwrap(); + assert_eq!(pool.num_free(), 3); + + // Dealloc the restored address + pool.dealloc(Allocation { addr, len: 4096 }).unwrap(); + assert_eq!(pool.num_free(), 4); + } + + #[test] + fn test_recycle_pool_restore_allocated_all_slots() { + let pool = make_recycle_pool(4, 4096); + let addrs: Vec = (0..4).map(|i| 0x80000 + i * 4096).collect(); + + pool.restore_allocated(&addrs).unwrap(); + assert_eq!(pool.num_free(), 0); + 
assert!(pool.alloc(4096).is_err()); + } + + #[test] + fn test_recycle_pool_restore_allocated_empty_list_is_noop() { + let pool = make_recycle_pool(4, 4096); + pool.restore_allocated(&[]).unwrap(); + assert_eq!(pool.num_free(), 4); + } + + #[test] + fn test_recycle_pool_restore_allocated_resets_first() { + let pool = make_recycle_pool(4, 4096); + + // Allocate some slots + let _ = pool.alloc(4096).unwrap(); + let _ = pool.alloc(4096).unwrap(); + assert_eq!(pool.num_free(), 2); + + // restore_allocated resets then removes - so 4 - 1 = 3 + pool.restore_allocated(&[0x80000]).unwrap(); + assert_eq!(pool.num_free(), 3); + } } #[cfg(test)] diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index eeb96cc7f..beeb4519b 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -19,6 +19,7 @@ use alloc::vec; use alloc::vec::Vec; use bytes::Bytes; +use smallvec::SmallVec; use super::*; @@ -391,18 +392,116 @@ where /// /// # Safety /// - /// All [`RecvCompletion`]s (and their backing [`Bytes`]) from - /// previous `poll()` calls must have been dropped before calling - /// this. Outstanding completions hold pool allocations via - /// `BufferOwner`; resetting the pool while they exist would cause - /// double-free on drop. + /// All [`RecvCompletion`]s (and their backing [`Bytes`]) from previous `poll()` + /// calls must have been dropped before calling this. Outstanding completions + /// hold pool allocations via `BufferOwner`; resetting the pool while they exist + /// would cause double-free on drop. /// - /// TODO(virtq): properly restore state after snapshot instead of just resetting everything + /// TODO(virtq): find a way to allow guest to keep completions across resets. 
pub fn reset(&mut self) { - self.inner.reset(); self.pool.reset(); + self.inner.reset(); + self.pending.clear(); self.inflight.fill(None); + } + + /// Replace the pool and reset ring, inflight, and pending state. + /// + /// Use this when restoring from a snapshot where the pool has been + /// relocated or recreated. + /// + /// # Safety + /// + /// Same as [`reset`](Self::reset) - all outstanding completions + /// must have been dropped. + pub fn reset_with_pool(&mut self, pool: P) { + self.pool = pool; + self.inner.reset(); self.pending.clear(); + self.inflight.fill(None); + } +} + +/// Snapshot restore support for producers backed by [`RecyclePool`]. +impl VirtqProducer +where + M: MemOps + Clone, + N: Notifier, +{ + /// Replace the pool and reconstruct producer state from a prefilled ring. + /// + /// The host prefills the H2G ring with `min(ring_size, pool_count)` + /// descriptors during restore (`restore_h2g_prefill`), writing + /// descriptors in forward order: position i gets + /// `addr = pool_base + i * slot_size`. + /// + /// Any descriptors already consumed by the host marked used + /// will be discovered naturally by `poll_used()` after restore. 
+ pub fn restore_from_ring(&mut self, pool: RecyclePool) -> Result<(), VirtqError> { + self.reset_with_pool(pool); + + let ring_size = self.inner.len(); + let pool_count = self.pool.count(); + let prefill_count = core::cmp::min(ring_size, pool_count); + let slot_size = self.pool.slot_size(); + + let mut ids = SmallVec::<[u16; 64]>::new(); + + // Scan descriptors to discover in-flight IDs and set up inflight table + for pos in 0..prefill_count as u16 { + let desc_base = self + .inner + .desc_table() + .desc_addr(pos) + .ok_or(VirtqError::RingError(RingError::InvalidState))?; + + let id = self + .inner + .mem() + .read_val::(desc_base + Descriptor::ID_OFFSET as u64) + .map_err(|_| VirtqError::RingError(RingError::MemError))?; + + if (id as usize) >= ring_size { + return Err(VirtqError::InvalidState); + } + + if self.inflight[id as usize].is_some() { + return Err(VirtqError::InvalidState); + } + + let addr = self + .pool + .slot_addr(pos as usize) + .ok_or(VirtqError::InvalidState)?; + + self.inflight[id as usize] = Some(Inflight::WriteOnly { + completion: Allocation { + addr, + len: slot_size, + }, + }); + + ids.push(id); + } + + self.inner.reset_prefilled(&ids); + + let addrs: SmallVec<[u64; 64]> = (0..prefill_count) + .map(|i| self.pool.slot_addr(i).expect("prefill_count <= pool count")) + .collect(); + + self.pool + .restore_allocated(&addrs) + .map_err(|_| VirtqError::InvalidState)?; + + debug_assert!( + self.inflight.iter().filter(|s| s.is_some()).count() == prefill_count, + "restore_from_ring: expected {} inflight entries, found {}", + prefill_count, + self.inflight.iter().filter(|s| s.is_some()).count() + ); + + Ok(()) } } @@ -641,9 +740,28 @@ impl Drop for SendEntry { #[cfg(test)] mod tests { use super::*; - use crate::virtq::ring::tests::make_ring; + use crate::virtq::ring::tests::{OwnedRing, TestMem, make_consumer, make_producer, make_ring}; use crate::virtq::test_utils::*; + type RecycleProducer = VirtqProducer; + + const SLOT_SIZE: usize = 4096; + + fn 
make_recycle_producer(ring: &OwnedRing, slot_count: usize) -> RecycleProducer { + let layout = ring.layout(); + let mem = ring.mem(); + let pool = make_pool(ring, slot_count); + let notifier = TestNotifier::new(); + + VirtqProducer::new(layout, mem, notifier, pool) + } + + fn make_pool(ring: &OwnedRing, slot_count: usize) -> RecyclePool { + let mem = ring.mem(); + let pool_base = mem.base_addr() + Layout::query_size(ring.len()) as u64 + 0x100; + RecyclePool::new(pool_base, slot_count * SLOT_SIZE, SLOT_SIZE).unwrap() + } + #[test] fn test_chain_readwrite_build() { let ring = make_ring(16); @@ -903,4 +1021,175 @@ mod tests { assert!(producer.inflight.iter().all(|s| s.is_none())); assert_eq!(producer.inner.num_free(), producer.inner.len()); } + + #[test] + fn test_restore_from_ring_requires_full_prefill() { + let ring = make_ring(8); + let mut producer = make_recycle_producer(&ring, 8); + + // Ring has no prefilled descriptors - restore should fail + // because IDs read from zeroed memory will all be 0 (duplicate) + assert!(producer.restore_from_ring(make_pool(&ring, 8)).is_err()); + } + + #[test] + fn test_restore_from_ring_partial_prefill_fails() { + let ring = make_ring(8); + let producer = make_recycle_producer(&ring, 8); + let pool_base = producer.pool.base_addr(); + + // Simulate host prefill: write only one descriptor + let mut writer = make_producer(&ring); + writer + .submit_one(pool_base, SLOT_SIZE as u32, true) + .unwrap(); + + // Restore should fail because only 1 of 8 positions has a + // valid unique ID - remaining positions have id=0 (duplicate) + let mut restored = make_recycle_producer(&ring, 8); + assert!(restored.restore_from_ring(make_pool(&ring, 8)).is_err()); + } + + #[test] + fn test_restore_from_ring_full_prefill() { + let depth = 8usize; + let ring = make_ring(depth); + let producer = make_recycle_producer(&ring, depth); + let pool_base = producer.pool.base_addr(); + + // Simulate host prefill: write all descriptors + let mut writer = 
make_producer(&ring); + for i in 0..depth { + let addr = pool_base + (i * SLOT_SIZE) as u64; + writer.submit_one(addr, SLOT_SIZE as u32, true).unwrap(); + } + + let mut restored = make_recycle_producer(&ring, depth); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + + // All inflight slots should be populated + let inflight_count = restored.inflight.iter().filter(|s| s.is_some()).count(); + assert_eq!(inflight_count, depth); + + // Pool should be fully allocated + assert_eq!(restored.pool.num_free(), 0); + } + + #[test] + fn test_restore_from_ring_forward_order() { + let depth = 4usize; + let ring = make_ring(depth); + let producer = make_recycle_producer(&ring, depth); + let pool_base = producer.pool.base_addr(); + + // Forward order prefill + let mut writer = make_producer(&ring); + for i in 0..depth { + writer + .submit_one(pool_base + (i * SLOT_SIZE) as u64, SLOT_SIZE as u32, true) + .unwrap(); + } + + let mut restored = make_recycle_producer(&ring, depth); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + } + + #[test] + fn test_restore_from_ring_reverse_order() { + let depth = 4usize; + let ring = make_ring(depth); + let producer = make_recycle_producer(&ring, depth); + let pool_base = producer.pool.base_addr(); + + // Reverse order prefill (current host behavior) + let mut writer = make_producer(&ring); + for i in (0..depth).rev() { + writer + .submit_one(pool_base + (i * SLOT_SIZE) as u64, SLOT_SIZE as u32, true) + .unwrap(); + } + + let mut restored = make_recycle_producer(&ring, depth); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + } + + #[test] + fn test_restore_from_ring_pool_state_correct() { + let depth = 8usize; + let ring = make_ring(depth); + let producer = make_recycle_producer(&ring, depth); + let pool_base = producer.pool.base_addr(); + + // Full prefill + let mut writer = make_producer(&ring); + for i in 0..depth { + writer + .submit_one(pool_base + (i * SLOT_SIZE) as u64, SLOT_SIZE as u32, 
true) + .unwrap(); + } + + let mut restored = make_recycle_producer(&ring, depth); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + // All slots are allocated after full-prefill restore + assert_eq!(restored.pool.num_free(), 0); + } + + #[test] + fn test_restore_from_ring_idempotent() { + let depth = 4usize; + let ring = make_ring(depth); + let producer = make_recycle_producer(&ring, depth); + let pool_base = producer.pool.base_addr(); + + let mut writer = make_producer(&ring); + for i in 0..depth { + writer + .submit_one(pool_base + (i * SLOT_SIZE) as u64, SLOT_SIZE as u32, true) + .unwrap(); + } + + let mut restored = make_recycle_producer(&ring, depth); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + assert_eq!(restored.pool.num_free(), 0); + } + + #[test] + fn test_restore_from_ring_then_poll_used() { + let depth = 4usize; + let ring = make_ring(depth); + let producer = make_recycle_producer(&ring, depth); + let pool_base = producer.pool.base_addr(); + + // Simulate host prefill + let mut writer = make_producer(&ring); + for i in 0..depth { + writer + .submit_one(pool_base + (i * SLOT_SIZE) as u64, SLOT_SIZE as u32, true) + .unwrap(); + } + + // Restore producer and use ring-level consumer to complete one entry + let mut restored = make_recycle_producer(&ring, depth); + restored.restore_from_ring(make_pool(&ring, depth)).unwrap(); + + // Ring-level consumer reads available descriptors + let mut consumer = make_consumer(&ring); + let (id, chain) = consumer.poll_available().unwrap(); + let writable = chain.writables(); + assert_eq!(writable.len(), 1); + + // Write some data into the writable buffer + let payload = b"test payload"; + consumer.mem().write(writable[0].addr, payload).unwrap(); + consumer.submit_used(id, payload.len() as u32).unwrap(); + + // Producer polls for the completion + let cqe = restored.poll().unwrap().unwrap(); + 
assert_eq!(&cqe.data[..payload.len()], payload); + + // Pool slot should be returned after data is dropped + drop(cqe); + assert_eq!(restored.pool.num_free(), 1); + } } diff --git a/src/hyperlight_common/src/virtq/ring.rs b/src/hyperlight_common/src/virtq/ring.rs index bf5eba0f2..9c8d5a30c 100644 --- a/src/hyperlight_common/src/virtq/ring.rs +++ b/src/hyperlight_common/src/virtq/ring.rs @@ -678,31 +678,42 @@ impl RingProducer { } /// Get number of free descriptors in the ring. + #[inline] pub fn num_free(&self) -> usize { self.num_free } /// Get number of inflight (submitted but not yet used) descriptors. + #[inline] pub fn num_inflight(&self) -> usize { self.desc_table.len() - self.num_free } /// Check if the ring is full (no free descriptors). + #[inline] pub fn is_full(&self) -> bool { self.num_free == 0 } /// Get descriptor table length + #[inline] #[allow(clippy::len_without_is_empty)] pub fn len(&self) -> usize { self.desc_table.len() } /// Get memory accessor reference + #[inline] pub fn mem(&self) -> &M { &self.mem } + /// Get descriptor table reference + #[inline] + pub fn desc_table(&self) -> &DescTable { + &self.desc_table + } + /// Get a snapshot of the current available cursor position. /// /// Used for batch operations to track the cursor before submitting @@ -835,6 +846,31 @@ impl RingProducer { self.id_num.iter_mut().for_each(|n| *n = 0); self.event_flags_shadow = EventFlags::ENABLE; } + + /// Reset the ring to the "N slots submitted, none completed" state. + /// + /// `ids` contains the descriptor IDs that are in-flight. + /// Sets cursors, counters, and `id_num` accordingly. The chain lengths are all set to 1. 
+ pub fn reset_prefilled(&mut self, ids: &[u16]) { + let size = self.desc_table.len(); + let count = ids.len(); + debug_assert!(count <= size); + + let wrapped = count >= size; + self.avail_cursor.head = if wrapped { 0 } else { count as u16 }; + self.avail_cursor.wrap = !wrapped; + + self.used_cursor.head = 0; + self.used_cursor.wrap = true; + + self.id_num.iter_mut().for_each(|n| *n = 0); + for &id in ids { + self.id_num[id as usize] = 1; + } + + self.num_free = size - count; + self.id_free.clear(); + } } /// Consumer (device) side of a packed virtqueue. @@ -3116,6 +3152,121 @@ pub(crate) mod tests { consumer.reset(); assert_eq!(consumer.num_inflight, 0); } + + #[test] + fn test_reset_prefilled_sets_cursors() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + let ids: Vec = (0..8).collect(); + producer.reset_prefilled(&ids); + + // avail wrapped once (all 8 slots submitted) + assert_eq!(producer.avail_cursor.head(), 0); + assert!(!producer.avail_cursor.wrap()); + // used cursor at initial position + assert_eq!(producer.used_cursor.head(), 0); + assert!(producer.used_cursor.wrap()); + } + + #[test] + fn test_reset_prefilled_all_ids_inflight() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + let ids: Vec = (0..8).collect(); + producer.reset_prefilled(&ids); + + assert_eq!(producer.num_free, 0); + assert!(producer.id_free.is_empty()); + assert!(producer.id_num.iter().all(|&n| n == 1)); + } + + #[test] + fn test_reset_prefilled_partial() { + let ring = make_ring(8); + let mut producer = make_producer(&ring); + producer.reset_prefilled(&[5, 6, 7, 3]); + + // avail cursor at position 4, no wrap + assert_eq!(producer.avail_cursor.head(), 4); + assert!(producer.avail_cursor.wrap()); + // used cursor at initial position + assert_eq!(producer.used_cursor.head(), 0); + assert!(producer.used_cursor.wrap()); + + assert_eq!(producer.num_free, 4); + assert!(producer.id_free.is_empty()); + // Only the specified IDs are in-flight + 
for &id in &[5, 6, 7, 3] { + assert_eq!(producer.id_num[id as usize], 1); + } + for &id in &[0, 1, 2, 4] { + assert_eq!(producer.id_num[id as usize], 0); + } + } + + #[test] + fn test_reset_prefilled_then_poll_used() { + let ring = make_ring(4); + let mut producer = make_producer(&ring); + + // Simulate host prefill: LIFO assigns IDs 3, 2, 1, 0 + for i in 0..4u64 { + producer.submit_one(0x1000 + i * 4096, 4096, true).unwrap(); + } + + // Consumer marks one as used + let mut consumer = make_consumer(&ring); + let (id, _chain) = consumer.poll_available().unwrap(); + consumer.submit_used(id, 64).unwrap(); + + // Fresh producer restores via reset_prefilled with all IDs + let mut restored = make_producer(&ring); + restored.reset_prefilled(&[0, 1, 2, 3]); + + // poll_used should discover the consumed descriptor + let used = restored.poll_used().unwrap(); + assert_eq!(used.id, id); + } + + #[test] + fn test_desc_table_read_after_submit() { + let ring = make_ring(8); + let mut writer = make_producer(&ring); + writer.submit_one(0x1000, 4096, true).unwrap(); + + let reader = make_producer(&ring); + let addr = reader.desc_table().desc_addr(0).unwrap(); + let desc = Descriptor::read_acquire(reader.mem(), addr).unwrap(); + assert_eq!(desc.addr, 0x1000); + assert_eq!(desc.len, 4096); + assert!(desc.is_writeable()); + assert!(desc.is_avail(true)); + assert!(!desc.is_used(true)); + } + + #[test] + fn test_desc_table_out_of_bounds() { + let ring = make_ring(8); + let reader = make_producer(&ring); + assert!(reader.desc_table().desc_addr(8).is_none()); + } + + #[test] + fn test_desc_table_read_used_descriptor() { + let ring = make_ring(8); + let mut writer = make_producer(&ring); + writer.submit_one(0x1000, 4096, true).unwrap(); + + let mut consumer = make_consumer(&ring); + let (id, _chain) = consumer.poll_available().unwrap(); + consumer.submit_used(id, 128).unwrap(); + + let reader = make_producer(&ring); + let addr = reader.desc_table().desc_addr(0).unwrap(); + let desc = 
Descriptor::read_acquire(reader.mem(), addr).unwrap(); + assert!(desc.is_used(true)); + assert!(!desc.is_avail(true)); + } } #[cfg(test)] diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index 8dc114406..2c7820a91 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -71,13 +71,13 @@ pub struct QueueConfig { pub struct GuestContext { g2h_producer: G2hProducer, h2g_producer: H2gProducer, - generation: u16, + generation: u32, last_host_result: Option>, } impl GuestContext { /// Create a new context with G2H and H2G queues. - pub fn new(g2h: QueueConfig, h2g: QueueConfig, generation: u16) -> Self { + pub fn new(g2h: QueueConfig, h2g: QueueConfig, generation: u32) -> Self { let size = g2h.pool_pages * PAGE_SIZE_USIZE; let g2h_pool = BufferPool::new(g2h.pool_gva, size).expect("failed to create G2H buffer pool"); @@ -233,6 +233,71 @@ impl GuestContext { Ok(()) } + /// Restore the H2G producer after snapshot restore. + /// + /// Creates a new [`RecyclePool`] at `pool_gva` and calls + /// [`restore_from_ring`] to reconstruct inflight state + /// from the host's prefilled descriptors. + pub fn restore_h2g(&mut self, pool_gva: u64, pool_size: usize) { + let pool = RecyclePool::new(pool_gva, pool_size, PAGE_SIZE_USIZE) + .expect("H2G RecyclePool creation failed"); + + self.h2g_producer + .restore_from_ring(pool) + .expect("H2G restore_from_ring failed"); + } + + /// Reset the G2H producer with a fresh pool. + /// + /// Creates a new [`BufferPool`] at `pool_gva` and resets the + /// producer to its initial state. + pub fn reset_g2h(&mut self, pool_gva: u64, pool_size: usize) { + let pool = BufferPool::new(pool_gva, pool_size).expect("G2H BufferPool creation failed"); + self.g2h_producer.reset_with_pool(pool); + self.last_host_result = None; + } + + /// Send a log message via the G2H queue. Fire-and-forget. 
+ pub fn emit_log(&mut self, log_data: &[u8]) -> Result<()> { + self.send_g2h_oneshot(MsgKind::Log, log_data) + } + + /// Get the current generation counter. + pub fn generation(&self) -> u32 { + self.generation + } + + /// Set the generation counter after snapshot restore. + pub fn set_generation(&mut self, generation: u32) { + self.generation = generation; + } + + /// Stash a host function result for later retrieval. + /// + /// Used by the C API's two-step calling convention where + /// `hl_call_host_function` and `hl_get_host_return_value_as_*` + /// are separate calls. + pub fn stash_host_result(&mut self, result: Result) { + self.last_host_result = Some(result); + } + + /// Take the stashed host return value. + /// + /// Panics if no value was stashed or if the type conversion fails. + /// If the stashed result was an error, panics with the error message. + pub fn take_host_return>(&mut self) -> T { + let val = self + .last_host_result + .take() + .expect("No host return value available") + .expect("Host function returned an error"); + + match T::try_from(val) { + Ok(v) => v, + Err(_) => panic!("Host return value type mismatch"), + } + } + /// Pre-fill the H2G queue with completion-only descriptors so the host /// can write incoming call payloads into them. fn prefill_h2g(&mut self) { @@ -287,35 +352,6 @@ impl GuestContext { } } - /// Drain any pending G2H completions. - /// - /// This is called before checking for H2G calls so that the host - /// can reclaim G2H response buffers. - pub fn drain_g2h_completions(&mut self) { - while let Ok(Some(_)) = self.g2h_producer.poll() {} - } - - /// Send a log message via the G2H queue. Fire-and-forget. - pub fn emit_log(&mut self, log_data: &[u8]) -> Result<()> { - self.send_g2h_oneshot(MsgKind::Log, log_data) - } - - /// Reset ring and pool state after snapshot restore. 
- pub(super) fn reset(&mut self, new_generation: u16) { - // G2H producer reset also resets the pool via BufferProvider::reset() - self.g2h_producer.reset(); - // H2G state is NOT reset. The guest's inflight and cursors - // survived via CoW and are already correct. The host's - // restore_h2g_prefill() wrote matching descriptors to the - // zeroed ring memory. Both sides are in sync. - self.generation = new_generation; - self.last_host_result = None; - } - - pub(super) fn generation(&self) -> u16 { - self.generation - } - fn try_send_readonly( &mut self, header: &[u8], @@ -346,30 +382,4 @@ impl GuestContext { entry.write_all(payload)?; self.g2h_producer.submit(entry) } - - /// Stash a host function result for later retrieval. - /// - /// Used by the C API's two-step calling convention where - /// `hl_call_host_function` and `hl_get_host_return_value_as_*` - /// are separate calls. - pub fn stash_host_result(&mut self, result: Result) { - self.last_host_result = Some(result); - } - - /// Take the stashed host return value. - /// - /// Panics if no value was stashed or if the type conversion fails. - /// If the stashed result was an error, panics with the error message. 
- pub fn take_host_return>(&mut self) -> T { - let val = self - .last_host_result - .take() - .expect("No host return value available") - .expect("Host function returned an error"); - - match T::try_from(val) { - Ok(v) => v, - Err(_) => panic!("Host return value type mismatch"), - } - } } diff --git a/src/hyperlight_guest/src/virtq/mod.rs b/src/hyperlight_guest/src/virtq/mod.rs index 9bb4d2348..d86ae40cc 100644 --- a/src/hyperlight_guest/src/virtq/mod.rs +++ b/src/hyperlight_guest/src/virtq/mod.rs @@ -26,7 +26,6 @@ use core::cell::UnsafeCell; use core::sync::atomic::{AtomicU8, Ordering}; use context::GuestContext; -use hyperlight_common::layout::{SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET, scratch_top_ptr}; pub use mem::GuestMemOps; // Init state machine @@ -34,17 +33,17 @@ const UNINITIALIZED: u8 = 0; const INITIALIZED: u8 = 1; static INIT_STATE: AtomicU8 = AtomicU8::new(UNINITIALIZED); -/// Check if the global context has been initialized. -pub fn is_initialized() -> bool { - INIT_STATE.load(Ordering::Acquire) == INITIALIZED -} - // Storage: UnsafeCell guarded by atomic init state. struct SyncWrap(T); unsafe impl Sync for SyncWrap {} static GLOBAL_CONTEXT: SyncWrap>> = SyncWrap(UnsafeCell::new(None)); +/// Check if the global context has been initialized. +pub fn is_initialized() -> bool { + INIT_STATE.load(Ordering::Acquire) == INITIALIZED +} + /// Access the global guest context via closure. /// /// # Panics @@ -78,19 +77,3 @@ pub fn set_global_context(ctx: GuestContext) { } unsafe { *GLOBAL_CONTEXT.0.get() = Some(ctx) }; } - -/// Reset the global context if a snapshot restore was detected. -/// Compares the virtq generation counter in scratch-top metadata. 
-pub fn maybe_reset_global_context() { - if !is_initialized() { - return; - } - - let current_gen = unsafe { *scratch_top_ptr::(SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET) }; - - with_context(|ctx| { - if current_gen != ctx.generation() { - ctx.reset(current_gen); - } - }); -} diff --git a/src/hyperlight_guest_bin/src/guest_function/call.rs b/src/hyperlight_guest_bin/src/guest_function/call.rs index fb71ef798..ad7797c7d 100644 --- a/src/hyperlight_guest_bin/src/guest_function/call.rs +++ b/src/hyperlight_guest_bin/src/guest_function/call.rs @@ -88,8 +88,7 @@ pub(crate) fn internal_dispatch_function() { // After snapshot restore, the ring memory is zeroed but the // producer's cursors are stale. Check once per dispatch entry. - virtq::maybe_reset_global_context(); - virtq::with_context(|ctx| ctx.drain_g2h_completions()); + crate::virtq::maybe_reset_virtqueues(); let function_call = virtq::with_context(|ctx| { ctx.recv_h2g_call() diff --git a/src/hyperlight_guest_bin/src/virtq/mod.rs b/src/hyperlight_guest_bin/src/virtq.rs similarity index 71% rename from src/hyperlight_guest_bin/src/virtq/mod.rs rename to src/hyperlight_guest_bin/src/virtq.rs index fd90240df..c3cb85b45 100644 --- a/src/hyperlight_guest_bin/src/virtq/mod.rs +++ b/src/hyperlight_guest_bin/src/virtq.rs @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -//! Guest-side virtqueue initialization. +//! Guest-side virtqueue initialization and reset. 
use core::num::NonZeroU16; @@ -39,7 +39,7 @@ pub(crate) fn init_virtqueues() { let h2g_depth = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_QUEUE_DEPTH_OFFSET) }; let g2h_pages = unsafe { *scratch_top_ptr::(SCRATCH_TOP_G2H_POOL_PAGES_OFFSET) } as usize; let h2g_pages = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_POOL_PAGES_OFFSET) } as usize; - let generation = unsafe { *scratch_top_ptr::(SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET) }; + let generation = unsafe { *scratch_top_ptr::(SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET) }; assert!(g2h_depth > 0 && h2g_depth > 0 && g2h_pages > 0 && h2g_pages > 0); assert!(g2h_gva != 0 && h2g_gva != 0); @@ -81,11 +81,43 @@ pub(crate) fn init_virtqueues() { hyperlight_guest::virtq::set_global_context(ctx); } +/// Reset virtqueue state if a snapshot restore was detected. +/// +/// Compares the generation counter in scratch-top metadata against +/// the context's cached value. On mismatch, restores H2G from the +/// host-prefilled ring and allocates a fresh G2H pool. +pub(crate) fn maybe_reset_virtqueues() { + if !hyperlight_guest::virtq::is_initialized() { + return; + } + + let curr_gen = unsafe { *scratch_top_ptr::(SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET) }; + + hyperlight_guest::virtq::with_context(|ctx| { + if curr_gen == ctx.generation() { + return; + } + + // Read host-assigned H2G pool location from scratch-top + let h2g_pool_gva = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_POOL_GVA_OFFSET) }; + let h2g_pages = unsafe { *scratch_top_ptr::(SCRATCH_TOP_H2G_POOL_PAGES_OFFSET) }; + let g2h_pages = unsafe { *scratch_top_ptr::(SCRATCH_TOP_G2H_POOL_PAGES_OFFSET) }; + + let h2g_pages = h2g_pages as usize; + let g2h_pages = g2h_pages as usize; + let g2h_pool_gva = alloc_pool(g2h_pages); + + ctx.restore_h2g(h2g_pool_gva, h2g_pages * PAGE_SIZE_USIZE); + ctx.reset_g2h(g2h_pool_gva, g2h_pages * PAGE_SIZE_USIZE); + ctx.set_generation(curr_gen); + }); +} + /// Allocate and zero `n` physical pages, returning the GVA. 
fn alloc_pool(n: usize) -> u64 { let gpa = unsafe { alloc_phys_pages(n as u64) }; let ptr = phys_to_virt(gpa).expect("failed to map pool pages"); - let size = n as usize * PAGE_SIZE_USIZE; + let size = n * PAGE_SIZE_USIZE; unsafe { core::ptr::write_bytes(ptr, 0, size) }; ptr as u64 } diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index 875ca2ab4..8fdb3457a 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -165,10 +165,8 @@ pub(crate) struct SandboxMemoryManager { pub(crate) g2h_consumer: Option, /// H2G virtqueue consumer, created after sandbox init. pub(crate) h2g_consumer: Option, - /// Saved H2G pool GVA for prefilling after snapshot restore. - pub(crate) h2g_pool_gva: Option, /// Monotonically increasing snapshot generation counter. - snapshot_generation: u16, + snapshot_generation: u32, } impl Clone for SandboxMemoryManager { @@ -182,7 +180,6 @@ impl Clone for SandboxMemoryManager { abort_buffer: self.abort_buffer.clone(), g2h_consumer: None, h2g_consumer: None, - h2g_pool_gva: self.h2g_pool_gva, snapshot_generation: self.snapshot_generation, } } @@ -297,7 +294,6 @@ where abort_buffer: Vec::new(), g2h_consumer: None, h2g_consumer: None, - h2g_pool_gva: None, snapshot_generation: 0, } } @@ -368,7 +364,6 @@ impl SandboxMemoryManager { abort_buffer: self.abort_buffer, g2h_consumer: None, h2g_consumer: None, - h2g_pool_gva: None, snapshot_generation: 0, }; let guest_mgr = SandboxMemoryManager { @@ -380,7 +375,6 @@ impl SandboxMemoryManager { abort_buffer: Vec::new(), // Guest doesn't need abort buffer g2h_consumer: None, h2g_consumer: None, - h2g_pool_gva: None, snapshot_generation: 0, }; host_mgr.update_scratch_bookkeeping()?; @@ -493,9 +487,15 @@ impl SandboxMemoryManager { }; self.layout = *snapshot.layout(); self.update_scratch_bookkeeping()?; + + // Place the H2G pool at first_free so the bump allocator starts right after it. 
+ // Guest reads this GVA from scratch-top during reset(). + let h2g_pool_gva = self.place_h2g_pool_at_first_free()?; + self.init_g2h_consumer()?; self.init_h2g_consumer()?; - self.restore_h2g_prefill()?; + self.restore_h2g_prefill(h2g_pool_gva)?; + Ok((gsnapshot, gscratch)) } @@ -544,8 +544,7 @@ impl SandboxMemoryManager { // Increment generation so the guest detects stale ring state. self.snapshot_generation = self.snapshot_generation.wrapping_add(1); let gen_offset = scratch_size - SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET as usize; - self.scratch_mem - .write::(gen_offset, self.snapshot_generation)?; + self.scratch_mem.write::(gen_offset, self.snapshot_generation)?; // Copy the page tables into the scratch region let snapshot_pt_end = self.shared_mem.mem_size(); @@ -859,22 +858,39 @@ impl SandboxMemoryManager { Ok(()) } + /// Place the H2G pool at `first_free` during snapshot restore. + /// + /// Writes the pool GVA to scratch-top and advances the bump + /// allocator past the pool so COW page-fault resolution cannot + /// alias pool memory. Returns the computed pool GVA for use by + /// [`restore_h2g_prefill`]. + fn place_h2g_pool_at_first_free(&mut self) -> Result { + use hyperlight_common::layout::*; + + let scratch_size = self.scratch_mem.mem_size(); + let first_free = self.layout.get_first_free_scratch_gpa(); + let base_gpa = scratch_base_gpa(scratch_size); + let base_gva = scratch_base_gva(scratch_size); + let h2g_pool_gva = base_gva + (first_free - base_gpa); + let h2g_pages = self.layout.sandbox_memory_config.get_h2g_pool_pages() as u64; + + self.update_scratch_bookkeeping_item(SCRATCH_TOP_H2G_POOL_GVA_OFFSET, h2g_pool_gva)?; + let allocator = first_free + h2g_pages * PAGE_SIZE_USIZE as u64; + self.update_scratch_bookkeeping_item(SCRATCH_TOP_ALLOCATOR_OFFSET, allocator)?; + + Ok(h2g_pool_gva) + } + /// Prefill the H2G ring with writable descriptors after snapshot restore. 
/// /// Uses a temporary `RingProducer` to write descriptors into the H2G ring /// so the host consumer can poll them. The guest's `restore_from_ring` /// will later reconstruct its inflight state from these descriptors. - pub(crate) fn restore_h2g_prefill(&mut self) -> Result<()> { - let pool_gva = match self.h2g_pool_gva { - Some(gva) => gva, - None => return Ok(()), - }; - + fn restore_h2g_prefill(&mut self, pool_gva: u64) -> Result<()> { let layout = self.h2g_virtq_layout()?; let mem_ops = self.host_mem_ops(); let h2g_depth = self.layout.sandbox_memory_config.get_h2g_queue_depth(); - // Pool size from config let slot_size = PAGE_SIZE_USIZE; let pool_size = self.layout.sandbox_memory_config.get_h2g_pool_pages() * PAGE_SIZE_USIZE; let slot_count = pool_size / slot_size; @@ -882,10 +898,11 @@ impl SandboxMemoryManager { let mut producer = virtq::RingProducer::new(layout, mem_ops); let prefill_count = core::cmp::min(slot_count, h2g_depth); - // Write descriptors in reverse order to match the guest's LIFO - // allocation pattern (RecyclePool::alloc pops from the end of - // the free list, so the first prefill gets the highest address). - for i in (0..prefill_count).rev() { + // Write descriptors in forward order. The guest calls + // restore_from_ring which reconstructs used-descriptor addresses + // as base + position * slot_size, so the iteration order must + // match this formula. + for i in 0..prefill_count { let addr = pool_gva + (i * slot_size) as u64; producer .submit_one(addr, slot_size as u32, true) @@ -1096,4 +1113,89 @@ mod tests { verify_page_tables(name, config); } } + + /// Verify that the H2G pool placed at `first_free` during restore + /// does not overlap with the bump allocator range or the + /// scratch-top metadata region. 
+ /// + /// This guards against the COW-pool GPA overlap bug: if the bump + /// allocator could return GPAs inside the pool region, a COW + /// page-fault would overwrite pool buffer data with stale shared + /// memory content, corrupting virtqueue communication. + fn verify_pool_allocator_no_collision(name: &str, config: SandboxConfiguration) { + let path = simple_guest_as_string().expect("failed to get simple guest path"); + let snapshot = Snapshot::from_env(GuestBinary::FilePath(path), config) + .unwrap_or_else(|e| panic!("{name}: failed to create snapshot: {e}")); + + let layout = snapshot.layout(); + let scratch_size = layout.get_scratch_size(); + let first_free = layout.get_first_free_scratch_gpa(); + let h2g_pages = layout.sandbox_memory_config.get_h2g_pool_pages(); + let scratch_base = hyperlight_common::layout::scratch_base_gpa(scratch_size); + + let pool_start = first_free; + let pool_end = first_free + (h2g_pages * PAGE_TABLE_SIZE) as u64; + let allocator_start = pool_end; + + // The metadata region lives at the very top of scratch. + // SCRATCH_TOP_EXN_STACK_OFFSET (0x50) is the highest offset. + // Two pages are reserved at the top for exception stack and metadata. 
+ let scratch_end = scratch_base + scratch_size as u64; + let metadata_start = scratch_end - 2 * PAGE_TABLE_SIZE as u64; + + assert!( + pool_start >= scratch_base, + "{name}: pool starts before scratch (pool=0x{pool_start:x}, scratch=0x{scratch_base:x})" + ); + + assert!( + pool_end <= metadata_start, + "{name}: pool overlaps metadata (pool_end=0x{pool_end:x}, metadata=0x{metadata_start:x})" + ); + + assert_eq!( + allocator_start, pool_end, + "{name}: allocator should start immediately after pool" + ); + + assert!( + allocator_start < metadata_start, + "{name}: no room for COW allocations (allocator=0x{allocator_start:x}, metadata=0x{metadata_start:x})" + ); + } + + #[test] + fn test_pool_allocator_no_collision() { + let test_cases: Vec<(&str, SandboxConfiguration)> = vec![ + ("default", SandboxConfiguration::default()), + ("large pools", { + let mut cfg = SandboxConfiguration::default(); + cfg.set_h2g_pool_pages(16); + cfg.set_g2h_pool_pages(16); + cfg + }), + ("minimal scratch", { + let mut cfg = SandboxConfiguration::default(); + cfg.set_scratch_size(0x20000); + cfg + }), + ("large scratch", { + let mut cfg = SandboxConfiguration::default(); + cfg.set_scratch_size(0x100000); + cfg + }), + ("large heap + large pools", { + let mut cfg = SandboxConfiguration::default(); + cfg.set_heap_size(LARGE_HEAP_SIZE); + cfg.set_scratch_size(0x100000); + cfg.set_h2g_pool_pages(32); + cfg.set_g2h_pool_pages(32); + cfg + }), + ]; + + for (name, config) in test_cases { + verify_pool_allocator_no_collision(name, config); + } + } } diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index ec1aba0de..cb5566eda 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -1157,6 +1157,220 @@ mod tests { assert_eq!(res, 0); } + /// Many snapshot restore cycles with state-modifying guest calls. 
+ #[test] + fn restore_stress_no_pool_corruption() { + let mut sbox: MultiUseSandbox = { + let path = simple_guest_as_string().unwrap(); + let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap(); + u_sbox.evolve() + } + .unwrap(); + + let snapshot = sbox.snapshot().unwrap(); + + for _ in 0..50 { + sbox.restore(snapshot.clone()).unwrap(); + let _ = sbox.call::("AddToStatic", 1i32).unwrap(); + + let res: i32 = sbox.call("GetStatic", ()).unwrap(); + assert_eq!(res, 1); + + let res: i32 = sbox.call("AddToStatic", 2i32).unwrap(); + assert_eq!(res, 3); + } + } + + /// Stress test: snapshot/restore with G2H queue pressure. + #[test] + fn restore_stress_with_host_calls() { + let mut sbox: MultiUseSandbox = { + let path = simple_guest_as_string().unwrap(); + let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap(); + u_sbox.evolve() + } + .unwrap(); + + let snapshot = sbox.snapshot().unwrap(); + + for i in 0..50 { + sbox.restore(snapshot.clone()).unwrap(); + + // Fire-and-forget log oneshots - multiple G2H entries queued + // without waiting for responses + sbox.call::<()>("LogMessageN", 5_i32).unwrap(); + + // G2H round-trip with returned data after logs filled the queue + let echo: String = sbox.call("Echo", "ping".to_string()).unwrap(); + assert_eq!(echo, "ping"); + + // Multiple calls without restore to exercise queue reuse + let res: i32 = sbox.call("AddToStatic", 1i32).unwrap(); + assert_eq!(res, 1); + + let echo2: String = sbox.call("Echo", format!("echo {i}")).unwrap(); + assert_eq!(echo2, format!("echo {i}")); + } + } + + /// Back-to-back restores without any guest call in between. + /// The generation bumps twice but the guest only sees the latest value. 
+ #[test] + fn restore_back_to_back() { + let mut sbox: MultiUseSandbox = { + let path = simple_guest_as_string().unwrap(); + let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap(); + u_sbox.evolve() + } + .unwrap(); + + let _ = sbox.call::("AddToStatic", 42i32).unwrap(); + let snapshot = sbox.snapshot().unwrap(); + + // Two restores in a row, no guest calls between them + sbox.restore(snapshot.clone()).unwrap(); + sbox.restore(snapshot.clone()).unwrap(); + + // Guest should see the snapshot state (static = 42) + let res: i32 = sbox.call("GetStatic", ()).unwrap(); + assert_eq!(res, 42); + + // Another round: three restores, then call + sbox.restore(snapshot.clone()).unwrap(); + sbox.restore(snapshot.clone()).unwrap(); + sbox.restore(snapshot.clone()).unwrap(); + + let res: i32 = sbox.call("AddToStatic", 1i32).unwrap(); + assert_eq!(res, 43); + } + + /// Restore after flooding the G2H queue with log oneshots. + #[test] + fn restore_after_g2h_pressure() { + let mut sbox: MultiUseSandbox = { + let path = simple_guest_as_string().unwrap(); + let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap(); + u_sbox.evolve() + } + .unwrap(); + + let snapshot = sbox.snapshot().unwrap(); + + for _ in 0..20 { + // Flood G2H with many log oneshots to pressure the queue/pool + sbox.call::<()>("LogMessageN", 30_i32).unwrap(); + + // Restore after heavy G2H usage + sbox.restore(snapshot.clone()).unwrap(); + + // Verify queue works cleanly after restore + sbox.call::<()>("LogMessageN", 5_i32).unwrap(); + let echo: String = sbox.call("Echo", "ok".to_string()).unwrap(); + assert_eq!(echo, "ok"); + } + } + + /// Many calls cycling through all descriptor IDs, then restore. + /// Ensures restore handles post-wraparound ring state. 
+ #[test] + fn restore_after_id_wraparound() { + let mut sbox: MultiUseSandbox = { + let path = simple_guest_as_string().unwrap(); + let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap(); + u_sbox.evolve() + } + .unwrap(); + + let snapshot = sbox.snapshot().unwrap(); + + for i in 0..200 { + let res: i32 = sbox.call("AddToStatic", 1i32).unwrap(); + assert_eq!(res, i + 1); + } + + // Restore after IDs have wrapped around many times + sbox.restore(snapshot.clone()).unwrap(); + + let res: i32 = sbox.call("GetStatic", ()).unwrap(); + assert_eq!(res, 0); + + // Do another round of wraparound + restore + for _ in 0..200 { + let _ = sbox.call::("AddToStatic", 1i32).unwrap(); + } + sbox.restore(snapshot.clone()).unwrap(); + + let echo: String = sbox.call("Echo", "after wraparound".to_string()).unwrap(); + assert_eq!(echo, "after wraparound"); + } + + /// Restore after a guest exception recovers the sandbox. + /// The virtqueue must be fully functional after restore despite + /// the guest having been in a broken state. 
+ #[test] + fn restore_after_guest_error() { + let mut sbox: MultiUseSandbox = { + let path = simple_guest_as_string().unwrap(); + let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap(); + u_sbox.evolve() + } + .unwrap(); + + let snapshot = sbox.snapshot().unwrap(); + + // Normal call first + let res: i32 = sbox.call("AddToStatic", 5i32).unwrap(); + assert_eq!(res, 5); + + // Trigger an exception - guest is now in a broken state + let err = sbox.call::<()>("TriggerException", ()); + assert!(err.is_err()); + + // Restore should recover fully + sbox.restore(snapshot.clone()).unwrap(); + + // Verify everything works after recovery + let res: i32 = sbox.call("GetStatic", ()).unwrap(); + assert_eq!(res, 0); + + let echo: String = sbox.call("Echo", "recovered".to_string()).unwrap(); + assert_eq!(echo, "recovered"); + + sbox.call::<()>("LogMessageN", 5_i32).unwrap(); + let res: i32 = sbox.call("AddToStatic", 1i32).unwrap(); + assert_eq!(res, 1); + } + + /// Snapshot immediately after evolve, restore before any calls. + /// Baseline test: the virtqueue has never been used. 
+ #[test] + fn restore_fresh_snapshot() { + let mut sbox: MultiUseSandbox = { + let path = simple_guest_as_string().unwrap(); + let u_sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), None).unwrap(); + u_sbox.evolve() + } + .unwrap(); + + // Snapshot immediately - no guest calls yet + let snapshot = sbox.snapshot().unwrap(); + + sbox.restore(snapshot.clone()).unwrap(); + + // First-ever guest call after restore + let res: i32 = sbox.call("GetStatic", ()).unwrap(); + assert_eq!(res, 0); + + let echo: String = sbox.call("Echo", "first".to_string()).unwrap(); + assert_eq!(echo, "first"); + + // Restore again and verify + sbox.restore(snapshot.clone()).unwrap(); + sbox.call::<()>("LogMessageN", 10_i32).unwrap(); + let res: i32 = sbox.call("AddToStatic", 7i32).unwrap(); + assert_eq!(res, 7); + } + #[test] fn test_trigger_exception_on_guest() { let usbox = UninitializedSandbox::new( diff --git a/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs b/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs index 6e02cfe26..428594d37 100644 --- a/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs +++ b/src/hyperlight_host/src/sandbox/uninitialized_evolve.rs @@ -16,7 +16,6 @@ limitations under the License. 
#[cfg(gdb)] use std::sync::{Arc, Mutex}; -use hyperlight_common::layout::SCRATCH_TOP_H2G_POOL_GVA_OFFSET; use rand::RngExt; use tracing::{Span, instrument}; @@ -27,7 +26,7 @@ use crate::hypervisor::hyperlight_vm::{HyperlightVm, HyperlightVmError}; use crate::mem::exe::LoadInfo; use crate::mem::mgr::SandboxMemoryManager; use crate::mem::ptr::RawPtr; -use crate::mem::shared_mem::{GuestSharedMemory, SharedMemory}; +use crate::mem::shared_mem::GuestSharedMemory; #[cfg(gdb)] use crate::sandbox::config::DebugInfo; #[cfg(feature = "mem_profile")] @@ -132,18 +131,6 @@ pub(super) fn evolve_impl_multi_use(u_sbox: UninitializedSandbox) -> Result(offset) - && gva != 0 - { - hshm.h2g_pool_gva = Some(gva); - } - } - #[cfg(gdb)] let dbg_mem_wrapper = Arc::new(Mutex::new(hshm.clone())); From 9633c32f37cce8e9a5138dee6c71aa39d985b650 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Thu, 9 Apr 2026 15:53:20 +0200 Subject: [PATCH 16/26] feat(virtq): add support for multi-descriptor payloads Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/consumer.rs | 5 +- src/hyperlight_common/src/virtq/mod.rs | 56 ++++++++- src/hyperlight_common/src/virtq/msg.rs | 32 ++++- src/hyperlight_common/src/virtq/pool.rs | 45 ++++++- src/hyperlight_common/src/virtq/producer.rs | 45 ++++--- src/hyperlight_guest/src/virtq/context.rs | 118 +++++++++++------- src/hyperlight_host/src/mem/mgr.rs | 92 ++++++++++---- src/hyperlight_host/src/sandbox/outb.rs | 4 +- .../tests/sandbox_host_tests.rs | 76 +++++++++++ 9 files changed, 369 insertions(+), 104 deletions(-) diff --git a/src/hyperlight_common/src/virtq/consumer.rs b/src/hyperlight_common/src/virtq/consumer.rs index b29f7694a..fb11c778e 100644 --- a/src/hyperlight_common/src/virtq/consumer.rs +++ b/src/hyperlight_common/src/virtq/consumer.rs @@ -255,6 +255,7 @@ pub struct VirtqConsumer { inner: RingConsumer, notifier: N, inflight: FixedBitSet, + next_token: u32, } impl VirtqConsumer { @@ -273,6 +274,7 @@ impl VirtqConsumer { inner, 
notifier, inflight, + next_token: 0, } } @@ -323,7 +325,8 @@ impl VirtqConsumer { } self.inflight.insert(id_idx); - let token = Token(id); + let token = Token(self.next_token, id); + self.next_token = self.next_token.wrapping_add(1); // Copy entry data from shared memory let data = entry_elem diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs index 331e6cc82..f8039b38c 100644 --- a/src/hyperlight_common/src/virtq/mod.rs +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -335,11 +335,11 @@ pub enum SuppressionKind { /// A token representing a sent entry in the virtqueue. /// -/// Tokens uniquely identify in-flight requests and are used to correlate -/// requests with their responses. The token value corresponds to the -/// descriptor ID in the underlying ring. +/// Tokens uniquely identify in-flight requests and are used to correlate requests with their responses. +/// The first element is a monotonically increasing generation counter. The second element is the +/// underlying descriptor ID #[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub struct Token(pub u16); +pub struct Token(pub u32, pub u16); impl From for Allocation { fn from(value: BufferElement) -> Self { @@ -972,6 +972,54 @@ mod tests { assert_eq!(cqe2.token, tok_rw); assert_eq!(&cqe2.data[..], b"reply"); } + + /// Regression test: reclaim + submit must not cause token collisions. + /// + /// Before the monotonic generation counter, Token wrapped the descriptor + /// ID which gets recycled. This caused stale pending completions to + /// match newly submitted entries with the same recycled descriptor ID. 
+ #[test] + fn test_reclaim_submit_no_token_collision() { + let ring = make_ring(8); + let (mut producer, mut consumer, _) = make_test_producer(&ring); + + // Submit and complete a ReadOnly entry + let tok_old = send_readonly(&mut producer, b"log"); + + let (_, c) = consumer.poll(1024).unwrap().unwrap(); + consumer.complete(c).unwrap(); + + // Reclaim pushes the completion to pending (token = tok_old) + let count = producer.reclaim().unwrap(); + assert_eq!(count, 1); + + // Submit a new ReadWrite entry - may reuse the same descriptor ID + let tok_new = send_readwrite(&mut producer, b"call", 64); + + // Tokens must differ even if the descriptor ID was recycled + assert_ne!( + tok_old, tok_new, + "tokens must be unique across reclaim/submit cycles" + ); + + // Complete the ReadWrite entry + let (_, c) = consumer.poll(1024).unwrap().unwrap(); + let SendCompletion::Writable(mut wc) = c else { + panic!("expected writable"); + }; + wc.write_all(b"result").unwrap(); + consumer.complete(wc.into()).unwrap(); + + // Poll should return the stale ReadOnly completion first (wrong token) + let cqe1 = producer.poll().unwrap().unwrap(); + assert_eq!(cqe1.token, tok_old); + assert!(cqe1.data.is_empty()); + + // Then the new ReadWrite completion (matching token) + let cqe2 = producer.poll().unwrap().unwrap(); + assert_eq!(cqe2.token, tok_new); + assert_eq!(&cqe2.data[..], b"result"); + } } #[cfg(all(test, loom))] mod fuzz { diff --git a/src/hyperlight_common/src/virtq/msg.rs b/src/hyperlight_common/src/virtq/msg.rs index ade59643b..090c2eb5b 100644 --- a/src/hyperlight_common/src/virtq/msg.rs +++ b/src/hyperlight_common/src/virtq/msg.rs @@ -20,6 +20,8 @@ limitations under the License. //! fixed 8-byte header, enabling message type discrimination and //! request/response correlation. +use bitflags::bitflags; + /// Message types for the virtqueue wire protocol. #[repr(u8)] #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -54,24 +56,33 @@ impl TryFrom for MsgKind { } } +bitflags! 
{ + #[repr(transparent)] + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] + pub struct MsgFlags: u8 { + /// More descriptors follow for this message. + const MORE = 1 << 0; + } +} + /// Wire header for all virtqueue messages #[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)] #[repr(C)] pub struct VirtqMsgHeader { /// Discriminates the message type. pub kind: u8, - /// Per-type flags TODO(ring): add flags type. + /// Per-message flags (see [`MsgFlags`]). pub flags: u8, /// Caller-assigned correlation ID. Responses echo the request's ID. pub req_id: u16, - /// Byte length of the payload following this header. + /// Byte length of the payload following this header in this descriptor. pub payload_len: u32, } impl VirtqMsgHeader { pub const SIZE: usize = core::mem::size_of::(); - /// Create a new message header. + /// Create a new message header with no flags set. pub const fn new(kind: MsgKind, req_id: u16, payload_len: u32) -> Self { Self { kind: kind as u8, @@ -82,10 +93,10 @@ impl VirtqMsgHeader { } /// Create a new header with flags. - pub const fn with_flags(kind: MsgKind, flags: u8, req_id: u16, payload_len: u32) -> Self { + pub const fn with_flags(kind: MsgKind, flags: MsgFlags, req_id: u16, payload_len: u32) -> Self { Self { kind: kind as u8, - flags, + flags: flags.bits(), req_id, payload_len, } @@ -95,4 +106,15 @@ impl VirtqMsgHeader { pub fn msg_kind(&self) -> Result { MsgKind::try_from(self.kind) } + + /// Interpret the raw flags field as [`MsgFlags`]. + pub fn msg_flags(&self) -> MsgFlags { + MsgFlags::from_bits_truncate(self.flags) + } + + /// Returns true if [`MsgFlags::MORE`] is set, indicating more + /// descriptors follow for this message. 
+ pub const fn has_more(&self) -> bool { + self.flags & MsgFlags::MORE.bits() != 0 + } } diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs index bbae4ff41..d60d432bf 100644 --- a/src/hyperlight_common/src/virtq/pool.rs +++ b/src/hyperlight_common/src/virtq/pool.rs @@ -150,10 +150,10 @@ impl Slab { } // Fallback to full search + let total = self.used_slots.len(); self.used_slots.zeroes().find(|&next_free| { - self.used_slots - .count_zeroes(next_free..next_free + slots_num) - == slots_num + let end = next_free + slots_num; + end <= total && self.used_slots.count_zeroes(next_free..end) == slots_num }) } @@ -416,6 +416,11 @@ impl BufferPool { inner: SyncWrap(Rc::new(RefCell::new(inner))), }) } + + /// Upper slab slot size in bytes. + pub const fn upper_slot_size() -> usize { + U + } } #[cfg(all(test, loom))] @@ -821,6 +826,40 @@ mod tests { assert!(matches!(result, Err(AllocError::InvalidFree(_, _)))); } + #[test] + fn test_slab_multi_slot_alloc_near_end() { + let mut slab = make_slab::<256>(1792); // 7 slots + let a0 = slab.alloc(256).unwrap(); + let a1 = slab.alloc(256).unwrap(); + let _a2 = slab.alloc(256).unwrap(); + let _a3 = slab.alloc(256).unwrap(); + let _a4 = slab.alloc(256).unwrap(); + let _a5 = slab.alloc(256).unwrap(); + let _a6 = slab.alloc(256).unwrap(); + + slab.dealloc(a0).unwrap(); + slab.dealloc(a1).unwrap(); + + // 2-slot run fits at indices 0..2 but the search visits index 6 + // (a free zero) first if slots 0-1 are not found before it. + // Actually slots 0-1 are free, so it should find them. + let run = slab.alloc(300).unwrap(); // needs 2 slots + assert_eq!(run.len, 512); + } + + #[test] + fn test_slab_multi_slot_alloc_no_room_at_end() { + // Only the last slot is free but a 2-slot run is requested. + // find_slots must not panic when checking beyond the bitset. 
+ let mut slab = make_slab::<256>(1792); // 7 slots + let allocs: Vec<_> = (0..7).map(|_| slab.alloc(256).unwrap()).collect(); + // Free only the last slot (index 6) + slab.dealloc(allocs[6]).unwrap(); + + let result = slab.alloc(300); // needs 2 slots, only 1 free + assert!(matches!(result, Err(AllocError::NoSpace))); + } + #[test] fn test_slab_free_invalid_address() { let mut slab = make_slab::<256>(1024); diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index beeb4519b..b66bc78b4 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -125,7 +125,8 @@ pub struct VirtqProducer { inner: RingProducer, notifier: N, pool: P, - inflight: Vec>, + next_token: u32, + inflight: Vec>, pending: VecDeque, } @@ -152,6 +153,7 @@ where pool, notifier, inflight, + next_token: 0, pending: VecDeque::new(), } } @@ -218,13 +220,20 @@ where }; let id = used.id as usize; - let inf = self + let (token, inf) = self .inflight .get_mut(id) .ok_or(VirtqError::InvalidState)? .take() .ok_or(VirtqError::InvalidState)?; + // the token's descriptor ID must match the ring's + debug_assert_eq!( + token.1, used.id, + "ring returned desc_id={} but inflight slot {} has token with desc_id={}", + used.id, id, token.1, + ); + let written = used.len as usize; // Free entry buffers (request data no longer needed) @@ -250,10 +259,7 @@ where None => Bytes::new(), }; - Ok(Some(RecvCompletion { - token: Token(used.id), - data, - })) + Ok(Some(RecvCompletion { token, data })) } /// Drain all available completions, calling the provided closure for each. 
@@ -310,6 +316,9 @@ where let chain = inflight.try_into_chain(written)?; let id = self.inner.submit_available(&chain)?; + let token = Token(self.next_token, id); + self.next_token = self.next_token.wrapping_add(1); + let slot = self .inflight .get_mut(id as usize) @@ -319,7 +328,7 @@ where return Err(VirtqError::InvalidState); } - *slot = Some(inflight); + *slot = Some((token, inflight)); let should_notify = self.inner.should_notify_since(cursor_before)?; @@ -336,7 +345,7 @@ where }); } - Ok(Token(id)) + Ok(token) } /// Signal backpressure to the consumer. @@ -474,12 +483,18 @@ where .slot_addr(pos as usize) .ok_or(VirtqError::InvalidState)?; - self.inflight[id as usize] = Some(Inflight::WriteOnly { - completion: Allocation { - addr, - len: slot_size, + let token = Token(self.next_token, id); + self.next_token = self.next_token.wrapping_add(1); + + self.inflight[id as usize] = Some(( + token, + Inflight::WriteOnly { + completion: Allocation { + addr, + len: slot_size, + }, }, - }); + )); ids.push(id); } @@ -869,7 +884,7 @@ mod tests { // Ring should still be fully usable let se = producer.chain().entry(64).completion(128).build().unwrap(); let tok = producer.submit(se).unwrap(); - assert!(tok.0 < 16); + assert!(tok.1 < 16); } #[test] @@ -885,7 +900,7 @@ mod tests { // Ring should still be fully usable let se = producer.chain().entry(64).completion(128).build().unwrap(); let tok = producer.submit(se).unwrap(); - assert!(tok.0 < 16); + assert!(tok.1 < 16); } #[test] diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index 2c7820a91..fae3c87be 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -39,7 +39,6 @@ use crate::bail; use crate::error::Result; static REQUEST_ID: AtomicU16 = AtomicU16::new(0); -const MAX_RESPONSE_CAP: usize = 4096; /// Guest-side notifier that triggers a VM exit via outb. 
#[derive(Clone, Copy)] @@ -69,9 +68,18 @@ pub struct QueueConfig { /// Virtqueue runtime state for guest-host communication. pub struct GuestContext { + /// guest-to-host driver g2h_producer: G2hProducer, + /// host-to-guest driver h2g_producer: H2gProducer, + /// Max writable bytes the host can write into a G2H completion. + /// Derived from the G2H pool upper slab slot size. + g2h_response_cap: usize, + /// H2G slot size in bytes (each prefilled writable descriptor). + h2g_slot_size: usize, + /// snapshot generation counter generation: u32, + /// used by cabi last_host_result: Option>, } @@ -81,30 +89,27 @@ impl GuestContext { let size = g2h.pool_pages * PAGE_SIZE_USIZE; let g2h_pool = BufferPool::new(g2h.pool_gva, size).expect("failed to create G2H buffer pool"); + let g2h_response_cap = BufferPool::<256, 4096>::upper_slot_size(); let g2h_producer = VirtqProducer::new(g2h.layout, GuestMemOps, GuestNotifier, g2h_pool.clone()); - // Each H2G prefill entry is a single descriptor with one contiguous buffer: one - // fixed-size buffer per descriptor, large payloads split across multiple independent - // completions. - // - // TODO(virtq): consider smaller slot_size (e.g. pool_size / desc_count) to maximize - // prefilled entries for host-side call batching. 
let size = h2g.pool_pages * PAGE_SIZE_USIZE; - let slot = PAGE_SIZE_USIZE; - let h2g_pool = - RecyclePool::new(h2g.pool_gva, size, slot).expect("failed to create H2G recycle pool"); + let h2g_slot_size = PAGE_SIZE_USIZE; + let h2g_pool = RecyclePool::new(h2g.pool_gva, size, h2g_slot_size) + .expect("failed to create H2G recycle pool"); let h2g_producer = VirtqProducer::new(h2g.layout, GuestMemOps, GuestNotifier, h2g_pool.clone()); let mut ctx = Self { g2h_producer, h2g_producer, + g2h_response_cap, + h2g_slot_size, generation, last_host_result: None, }; - ctx.prefill_h2g(); + ctx.prefill_h2g().expect("H2G initial prefill failed"); ctx } @@ -164,8 +169,8 @@ impl GuestContext { }; let result_bytes = &completion.data; - if result_bytes.len() > MAX_RESPONSE_CAP { - bail!("G2H: response is too large"); + if result_bytes.len() < VirtqMsgHeader::SIZE { + bail!("G2H: response too short for header"); } let payload_bytes = &result_bytes[VirtqMsgHeader::SIZE..]; @@ -182,12 +187,16 @@ impl GuestContext { } /// Receive a host-to-guest function call from the H2G queue. + /// + /// Each descriptor carries a [`VirtqMsgHeader`] with `payload_len` for + /// that chunk. If [`MsgFlags::MORE`](hyperlight_common::virtq::msg::MsgFlags::MORE) + /// is set, more descriptors follow. pub fn recv_h2g_call(&mut self) -> Result { - let Some(completion) = self.h2g_producer.poll()? else { + let Some(first) = self.h2g_producer.poll()? 
else { bail!("H2G: no pending call"); }; - let data = &completion.data; + let data = &first.data; if data.len() < VirtqMsgHeader::SIZE { bail!("H2G: completion too short for header"); } @@ -198,39 +207,52 @@ impl GuestContext { bail!("H2G: unexpected message kind: 0x{:02x}", hdr.kind); } - let payload_end = VirtqMsgHeader::SIZE + hdr.payload_len as usize; - if payload_end > data.len() { - bail!("H2G: payload length exceeds completion data"); + let chunk_len = hdr.payload_len as usize; + + if !hdr.has_more() { + // Single-descriptor fast path + let payload = &data[VirtqMsgHeader::SIZE..VirtqMsgHeader::SIZE + chunk_len]; + let fc = FunctionCall::try_from(payload)?; + return Ok(fc); } - let payload = &data[VirtqMsgHeader::SIZE..payload_end]; - let fc = FunctionCall::try_from(payload)?; + // Multi-descriptor: accumulate payload until MsgFlags::MORE is cleared + let mut assembled = Vec::with_capacity(chunk_len * 2); + assembled.extend_from_slice(&data[VirtqMsgHeader::SIZE..VirtqMsgHeader::SIZE + chunk_len]); + + loop { + let Some(next) = self.h2g_producer.poll()? else { + bail!("H2G: expected continuation descriptor, none available"); + }; + + let next_data = &next.data; + if next_data.len() < VirtqMsgHeader::SIZE { + bail!("H2G: continuation too short for header"); + } + + let next_hdr: &VirtqMsgHeader = + bytemuck::from_bytes(&next_data[..VirtqMsgHeader::SIZE]); + + let next_chunk = next_hdr.payload_len as usize; + + assembled.extend_from_slice( + &next_data[VirtqMsgHeader::SIZE..VirtqMsgHeader::SIZE + next_chunk], + ); + + if !next_hdr.has_more() { + break; + } + } + + let fc = FunctionCall::try_from(assembled.as_slice())?; Ok(fc) } /// Send the result of a host-to-guest call back to the host via the - /// G2H queue, then refill one H2G descriptor slot. + /// G2H queue, then refill H2G descriptor slots until the ring is full. 
pub fn send_h2g_result(&mut self, payload: &[u8]) -> Result<()> { self.send_g2h_oneshot(MsgKind::Response, payload)?; - - // Best-effort refill of one H2G slot. Backpressure is expected - // (pool/ring may be full), other errors are propagated. - match self - .h2g_producer - .chain() - .completion(PAGE_SIZE_USIZE) - .build() - { - Ok(e) => match self.h2g_producer.submit(e) { - Ok(_) => {} - Err(virtq::VirtqError::Backpressure) => {} - Err(e) => bail!("H2G refill submit: {e}"), - }, - Err(virtq::VirtqError::Backpressure) => {} - Err(e) => bail!("H2G refill build: {e}"), - } - - Ok(()) + self.prefill_h2g() } /// Restore the H2G producer after snapshot restore. @@ -239,7 +261,7 @@ impl GuestContext { /// [`restore_from_ring`] to reconstruct inflight state /// from the host's prefilled descriptors. pub fn restore_h2g(&mut self, pool_gva: u64, pool_size: usize) { - let pool = RecyclePool::new(pool_gva, pool_size, PAGE_SIZE_USIZE) + let pool = RecyclePool::new(pool_gva, pool_size, self.h2g_slot_size) .expect("H2G RecyclePool creation failed"); self.h2g_producer @@ -300,23 +322,23 @@ impl GuestContext { /// Pre-fill the H2G queue with completion-only descriptors so the host /// can write incoming call payloads into them. 
- fn prefill_h2g(&mut self) { + fn prefill_h2g(&mut self) -> Result<()> { loop { let entry = match self .h2g_producer .chain() - .completion(PAGE_SIZE_USIZE) + .completion(self.h2g_slot_size) .build() { Ok(e) => e, - Err(virtq::VirtqError::Backpressure) => break, - Err(e) => panic!("H2G prefill build: {e}"), + Err(e) if e.is_transient() => return Ok(()), + Err(e) => bail!("H2G prefill build: {e}"), }; match self.h2g_producer.submit(entry) { Ok(_) => {} - Err(virtq::VirtqError::Backpressure) => break, - Err(e) => panic!("H2G prefill submit: {e}"), + Err(e) if e.is_transient() => return Ok(()), + Err(e) => bail!("H2G prefill submit: {e}"), } } } @@ -375,7 +397,7 @@ impl GuestContext { .g2h_producer .chain() .entry(entry_len) - .completion(MAX_RESPONSE_CAP) + .completion(self.g2h_response_cap) .build()?; entry.write_all(header)?; diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index 8fdb3457a..2173b66ba 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -19,7 +19,7 @@ use std::num::NonZeroU16; use hyperlight_common::flatbuffer_wrappers::function_types::FunctionCallResult; use hyperlight_common::mem::PAGE_SIZE_USIZE; -use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; +use hyperlight_common::virtq::msg::{MsgFlags, MsgKind, VirtqMsgHeader}; use hyperlight_common::virtq::{self, Layout as VirtqLayout}; use hyperlight_common::vmem::{self, PAGE_TABLE_SIZE, PageTableEntry, PhysAddr}; #[cfg(all(feature = "crashdump", not(feature = "nanvix-unstable")))] @@ -543,8 +543,9 @@ impl SandboxMemoryManager { )?; // Increment generation so the guest detects stale ring state. 
self.snapshot_generation = self.snapshot_generation.wrapping_add(1); - let gen_offset = scratch_size - SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET as usize; - self.scratch_mem.write::(gen_offset, self.snapshot_generation)?; + let gen_off = scratch_size - SCRATCH_TOP_SNAPSHOT_GENERATION_OFFSET as usize; + self.scratch_mem + .write::(gen_off, self.snapshot_generation)?; // Copy the page tables into the scratch region let snapshot_pt_end = self.shared_mem.mem_size(); @@ -816,6 +817,20 @@ impl SandboxMemoryManager { HostMemOps::new(&self.scratch_mem, scratch_base_gva) } + /// Total G2H buffer pool size in bytes. + pub(crate) fn g2h_pool_size(&self) -> usize { + self.layout.sandbox_memory_config.get_g2h_pool_pages() * PAGE_SIZE_USIZE + } + + pub(crate) fn h2g_pool_size(&self) -> usize { + self.layout.sandbox_memory_config.get_h2g_pool_pages() * PAGE_SIZE_USIZE + } + + /// H2G slot size in bytes. Each prefilled writable descriptor has this capacity. + pub(crate) fn h2g_slot_size(&self) -> usize { + PAGE_SIZE_USIZE + } + /// Initialize the G2H virtqueue consumer. /// Must be called after scratch bookkeeping is written. pub(crate) fn init_g2h_consumer(&mut self) -> Result<()> { @@ -891,7 +906,7 @@ impl SandboxMemoryManager { let mem_ops = self.host_mem_ops(); let h2g_depth = self.layout.sandbox_memory_config.get_h2g_queue_depth(); - let slot_size = PAGE_SIZE_USIZE; + let slot_size = self.h2g_slot_size(); let pool_size = self.layout.sandbox_memory_config.get_h2g_pool_pages() * PAGE_SIZE_USIZE; let slot_count = pool_size / slot_size; @@ -914,39 +929,60 @@ impl SandboxMemoryManager { /// Write a guest function call into the H2G virtqueue. /// - /// Polls the H2G consumer for a prefilled entry from the guest, - /// writes `VirtqMsgHeader::Request` followed by `buffer` into the - /// writable completion, and completes the entry. + /// Large payloads that exceed a single slot are split across multiple descriptors. 
pub(crate) fn write_guest_function_call_virtq(&mut self, buffer: &[u8]) -> Result<()> { + let h2g_pool_size = self.h2g_pool_size(); + let consumer = self .h2g_consumer .as_mut() .ok_or_else(|| new_error!("H2G consumer not initialized"))?; - let (entry, completion) = consumer - .poll(8192) - .map_err(|e| new_error!("H2G poll: {:?}", e))? - .ok_or_else(|| new_error!("H2G: no prefilled entry available"))?; + let mut offset = 0usize; - // Consume the entry data - this should be empty - drop(entry); + loop { + let remaining = buffer.len() - offset; - let header = VirtqMsgHeader::new(MsgKind::Request, 0, buffer.len() as u32); + let (entry, completion) = consumer + .poll(h2g_pool_size) + .map_err(|e| new_error!("H2G poll: {:?}", e))? + .ok_or_else(|| new_error!("H2G: no prefilled descriptor available"))?; - let virtq::SendCompletion::Writable(mut wc) = completion else { - return Err(new_error!( - "H2G: expected writable completion, got non-writable (ring corruption)" - )); - }; + drop(entry); - wc.write_all(bytemuck::bytes_of(&header)) - .map_err(|e| new_error!("H2G write header: {:?}", e))?; - wc.write_all(buffer) - .map_err(|e| new_error!("H2G write payload: {:?}", e))?; + let virtq::SendCompletion::Writable(mut wc) = completion else { + return Err(new_error!( + "H2G: expected writable completion (ring corruption)" + )); + }; - consumer - .complete(wc.into()) - .map_err(|e| new_error!("H2G complete: {:?}", e))?; + let data_cap = wc.capacity() - VirtqMsgHeader::SIZE; + let chunk_len = remaining.min(data_cap); + let has_more = offset + chunk_len < buffer.len(); + + let flags = if has_more { + MsgFlags::MORE + } else { + MsgFlags::empty() + }; + + let hdr = VirtqMsgHeader::with_flags(MsgKind::Request, flags, 0, chunk_len as u32); + + wc.write_all(bytemuck::bytes_of(&hdr)) + .map_err(|e| new_error!("H2G write header: {:?}", e))?; + wc.write_all(&buffer[offset..offset + chunk_len]) + .map_err(|e| new_error!("H2G write payload: {:?}", e))?; + + consumer + .complete(wc.into()) 
+ .map_err(|e| new_error!("H2G complete: {:?}", e))?; + + offset += chunk_len; + + if !has_more { + break; + } + } Ok(()) } @@ -955,6 +991,8 @@ impl SandboxMemoryManager { /// /// The guest submitted the Response on G2H with pub(crate) fn read_h2g_result_from_g2h(&mut self) -> Result { + let g2h_pool_size = self.g2h_pool_size(); + let consumer = self .g2h_consumer .as_mut() @@ -964,7 +1002,7 @@ impl SandboxMemoryManager { // find the Response that carries the H2G function call result. loop { let maybe_next = consumer - .poll(8192) + .poll(g2h_pool_size) .map_err(|e| new_error!("G2H poll for H2G result: {:?}", e))?; let Some((entry, completion)) = maybe_next else { diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs index 3fa571db2..626236ba4 100644 --- a/src/hyperlight_host/src/sandbox/outb.rs +++ b/src/hyperlight_host/src/sandbox/outb.rs @@ -191,6 +191,8 @@ fn outb_virtq_call( mem_mgr: &mut SandboxMemoryManager, host_funcs: &Arc>, ) -> Result<(), HandleOutbError> { + let g2h_pool_size = mem_mgr.g2h_pool_size(); + let consumer = mem_mgr.g2h_consumer.as_mut().ok_or_else(|| { HandleOutbError::ReadHostFunctionCall("G2H consumer not initialized".into()) })?; @@ -198,7 +200,7 @@ fn outb_virtq_call( // Drain entries, processing Log messages, until we find a Request. let (entry, completion) = loop { let Some((entry, completion)) = consumer - .poll(8192) + .poll(g2h_pool_size) .map_err(|e| HandleOutbError::ReadHostFunctionCall(format!("G2H poll: {e}")))? else { // No G2H entry - backpressure-only notify or prefill notify. 
diff --git a/src/hyperlight_host/tests/sandbox_host_tests.rs b/src/hyperlight_host/tests/sandbox_host_tests.rs index 795308146..d195d3590 100644 --- a/src/hyperlight_host/tests/sandbox_host_tests.rs +++ b/src/hyperlight_host/tests/sandbox_host_tests.rs @@ -655,3 +655,79 @@ fn virtq_large_payload_roundtrip() { assert!(res.iter().all(|&b| b == 0)); }); } + +#[test] +fn virtq_multi_descriptor_h2g_two_slots() { + // Payload exceeds a single H2G slot (4096 - header), requiring 2 descriptors. + let mut cfg = SandboxConfiguration::default(); + cfg.set_h2g_pool_pages(4); + with_rust_sandbox_cfg(cfg, |mut sandbox| { + let large_msg: String = "A".repeat(4200); + let res: String = sandbox.call("Echo", large_msg.clone()).unwrap(); + assert_eq!(res, large_msg); + }); +} + +#[test] +fn virtq_multi_descriptor_h2g_max_slots() { + // Payload spanning all available H2G pool slots. + let mut cfg = SandboxConfiguration::default(); + cfg.set_h2g_pool_pages(4); + with_rust_sandbox_cfg(cfg, |mut sandbox| { + let large_msg: String = "B".repeat(8200); + let res: String = sandbox.call("Echo", large_msg.clone()).unwrap(); + assert_eq!(res, large_msg); + }); +} + +#[test] +fn virtq_multi_descriptor_h2g_byte_array() { + // Multi-descriptor with byte array arguments to test binary payloads. + let mut cfg = SandboxConfiguration::default(); + cfg.set_h2g_pool_pages(8); + with_rust_sandbox_cfg(cfg, |mut sandbox| { + let large_bytes: Vec = (0..5000).map(|i| (i % 256) as u8).collect(); + let res: Vec = sandbox + .call("SetByteArrayToZero", large_bytes.clone()) + .unwrap(); + assert_eq!(res.len(), 5000); + assert!(res.iter().all(|&b| b == 0)); + }); +} + +#[test] +fn virtq_multi_descriptor_h2g_boundary() { + // Payload exactly at single-slot capacity boundary. + // Header is 8 bytes, so a single slot fits exactly 4088 bytes of payload. + // The FlatBuffer encoding adds overhead, so we test near the boundary + // to verify no off-by-one errors. 
+ let mut cfg = SandboxConfiguration::default(); + cfg.set_h2g_pool_pages(4); + with_rust_sandbox_cfg(cfg, |mut sandbox| { + // This should fit in one descriptor (small overhead) + let msg_under: String = "C".repeat(3900); + let res: String = sandbox.call("Echo", msg_under.clone()).unwrap(); + assert_eq!(res, msg_under); + + // This should just barely spill into a second descriptor + let msg_over: String = "D".repeat(4100); + let res: String = sandbox.call("Echo", msg_over.clone()).unwrap(); + assert_eq!(res, msg_over); + }); +} + +#[test] +fn virtq_multi_descriptor_h2g_repeated_calls() { + // Multiple large calls in sequence to verify H2G refill works correctly + // after multi-descriptor consumption. + let mut cfg = SandboxConfiguration::default(); + cfg.set_h2g_pool_pages(8); + with_rust_sandbox_cfg(cfg, |mut sandbox| { + for i in 0..5 { + let ch = char::from(b'A' + i as u8); + let msg: String = std::iter::repeat_n(ch, 4500).collect(); + let res: String = sandbox.call("Echo", msg.clone()).unwrap(); + assert_eq!(res, msg, "mismatch on call {i}"); + } + }); +} From 69b43439d5d4527f881ad297c2b435e0706cf604 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Thu, 9 Apr 2026 15:53:56 +0200 Subject: [PATCH 17/26] feat(virtq): ensure descriptor align in tests Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/desc.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/hyperlight_common/src/virtq/desc.rs b/src/hyperlight_common/src/virtq/desc.rs index 64bde4f7d..01967172a 100644 --- a/src/hyperlight_common/src/virtq/desc.rs +++ b/src/hyperlight_common/src/virtq/desc.rs @@ -319,9 +319,12 @@ mod tests { #[test] fn desc_table_get_out_of_bounds() { - let mut vec = vec![Descriptor::zeroed(); 4]; - let ptr = vec.as_mut_ptr(); - let table = unsafe { DescTable::from_raw_parts(ptr.addr() as u64, 4) }; + // Allocate with extra space to guarantee 16-byte alignment + // (Descriptor requires ALIGN=16 but repr(C) only gives 8). 
+ let mut buf = vec![0u8; 4 * Descriptor::SIZE + Descriptor::ALIGN]; + let base = buf.as_mut_ptr() as usize; + let aligned = (base + Descriptor::ALIGN - 1) & !(Descriptor::ALIGN - 1); + let table = unsafe { DescTable::from_raw_parts(aligned as u64, 4) }; assert!(table.desc_addr(3).is_some()); assert!(table.desc_addr(4).is_none()); } From 8570ea7a73ffc78c965afee49969b957cd6fbbc4 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Thu, 9 Apr 2026 21:32:58 +0200 Subject: [PATCH 18/26] feat(virtq): do not swallow errors Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/consumer.rs | 12 ++++-- src/hyperlight_common/src/virtq/ring.rs | 3 +- src/hyperlight_host/src/mem/mgr.rs | 7 ++- src/hyperlight_host/src/sandbox/outb.rs | 48 ++++++++++++++------- 4 files changed, 47 insertions(+), 23 deletions(-) diff --git a/src/hyperlight_common/src/virtq/consumer.rs b/src/hyperlight_common/src/virtq/consumer.rs index fb11c778e..d3da1020c 100644 --- a/src/hyperlight_common/src/virtq/consumer.rs +++ b/src/hyperlight_common/src/virtq/consumer.rs @@ -329,10 +329,14 @@ impl VirtqConsumer { self.next_token = self.next_token.wrapping_add(1); // Copy entry data from shared memory - let data = entry_elem - .map(|elem| self.read_element(&elem)) - .transpose()? 
- .unwrap_or_default(); + let data = match entry_elem.map(|elem| self.read_element(&elem)).transpose() { + Ok(d) => d.unwrap_or_default(), + Err(e) => { + // Read failed - clear inflight before propagating + self.inflight.set(id_idx, false); + return Err(e); + } + }; let entry = RecvEntry { token, data }; diff --git a/src/hyperlight_common/src/virtq/ring.rs b/src/hyperlight_common/src/virtq/ring.rs index 9c8d5a30c..1c4ead918 100644 --- a/src/hyperlight_common/src/virtq/ring.rs +++ b/src/hyperlight_common/src/virtq/ring.rs @@ -1251,7 +1251,8 @@ fn should_notify(evt: EventSuppression, ring_len: u16, old: RingCursor, new: Rin ring_need_event(off, new.head(), old.head()) } - _ => unreachable!(), + // treat as disabled if invalid + _ => false, } } diff --git a/src/hyperlight_host/src/mem/mgr.rs b/src/hyperlight_host/src/mem/mgr.rs index 2173b66ba..4e0c0c6e5 100644 --- a/src/hyperlight_host/src/mem/mgr.rs +++ b/src/hyperlight_host/src/mem/mgr.rs @@ -1014,8 +1014,11 @@ impl SandboxMemoryManager { return Err(new_error!("G2H: result entry too short")); } - let hdr: &VirtqMsgHeader = bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); - let payload = &entry_data[VirtqMsgHeader::SIZE..]; + let hdr_size = VirtqMsgHeader::SIZE; + let hdr: &VirtqMsgHeader = bytemuck::from_bytes(&entry_data[..hdr_size]); + let available = entry_data.len() - hdr_size; + let payload_len = (hdr.payload_len as usize).min(available); + let payload = &entry_data[hdr_size..hdr_size + payload_len]; match hdr.msg_kind() { Ok(MsgKind::Response) => { diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs index 626236ba4..85f3f1aeb 100644 --- a/src/hyperlight_host/src/sandbox/outb.rs +++ b/src/hyperlight_host/src/sandbox/outb.rs @@ -199,27 +199,40 @@ fn outb_virtq_call( // Drain entries, processing Log messages, until we find a Request. 
let (entry, completion) = loop { - let Some((entry, completion)) = consumer - .poll(g2h_pool_size) - .map_err(|e| HandleOutbError::ReadHostFunctionCall(format!("G2H poll: {e}")))? - else { + let Ok(maybe_next) = consumer.poll(g2h_pool_size) else { + return Err(HandleOutbError::ReadHostFunctionCall( + "G2H poll failed".into(), + )); + }; + + let Some((entry, completion)) = maybe_next else { // No G2H entry - backpressure-only notify or prefill notify. return Ok(()); }; + let hdr_size = VirtqMsgHeader::SIZE; let entry_data = entry.data(); - if entry_data.len() < VirtqMsgHeader::SIZE { + + if entry_data.len() < hdr_size { return Err(HandleOutbError::ReadHostFunctionCall( "G2H entry too short".into(), )); } - let hdr: VirtqMsgHeader = *bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); + + let hdr: VirtqMsgHeader = *bytemuck::from_bytes(&entry_data[..hdr_size]); match hdr.msg_kind() { Ok(MsgKind::Log) => { - let payload = &entry_data[VirtqMsgHeader::SIZE..]; + let available = entry_data.len() - hdr_size; + let log_len = (hdr.payload_len as usize).min(available); + let payload = &entry_data[hdr_size..hdr_size + log_len]; + emit_guest_log(payload); - let _ = consumer.complete(completion); + + consumer.complete(completion).map_err(|e| { + HandleOutbError::ReadHostFunctionCall(format!("G2H complete log: {e}")) + })?; + continue; } Ok(MsgKind::Request) => break (entry, completion), @@ -237,8 +250,18 @@ fn outb_virtq_call( } }; + // Validate completion buffer before calling the host function + let virtq::SendCompletion::Writable(mut wc) = completion else { + return Err(HandleOutbError::WriteHostFunctionResponse( + "G2H: expected writable completion, got ack (ring corruption)".into(), + )); + }; + let entry_data = entry.data(); - let payload = &entry_data[VirtqMsgHeader::SIZE..]; + let hdr: VirtqMsgHeader = *bytemuck::from_bytes(&entry_data[..VirtqMsgHeader::SIZE]); + let available = entry_data.len() - VirtqMsgHeader::SIZE; + let payload_len = (hdr.payload_len as 
usize).min(available); + let payload = &entry_data[VirtqMsgHeader::SIZE..VirtqMsgHeader::SIZE + payload_len]; let call = FunctionCall::try_from(payload) .map_err(|e| HandleOutbError::ReadHostFunctionCall(e.to_string()))?; @@ -259,13 +282,6 @@ fn outb_virtq_call( let resp_header = VirtqMsgHeader::new(MsgKind::Response, 0, result_payload.len() as u32); let resp_header_bytes = bytemuck::bytes_of(&resp_header); - // Write response into the completion buffer - let virtq::SendCompletion::Writable(mut wc) = completion else { - return Err(HandleOutbError::WriteHostFunctionResponse( - "G2H: expected writable completion, got ack (ring corruption)".into(), - )); - }; - wc.write_all(resp_header_bytes) .map_err(|e| HandleOutbError::WriteHostFunctionResponse(format!("{e}")))?; wc.write_all(result_payload) From e2069d099c10baabeb4d7c8672c9870751d730ef Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 10 Apr 2026 11:38:35 +0200 Subject: [PATCH 19/26] fix(virtq): adjust sizes for benchmarks Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/mod.rs | 76 +++++++++++-------- src/hyperlight_common/src/virtq/producer.rs | 25 +++++- src/hyperlight_common/src/virtq/ring.rs | 51 ++++++------- src/hyperlight_host/benches/benchmarks.rs | 15 ++-- .../src/sandbox/initialized_multi_use.rs | 10 +-- 5 files changed, 105 insertions(+), 72 deletions(-) diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs index f8039b38c..3d36bf0a8 100644 --- a/src/hyperlight_common/src/virtq/mod.rs +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -896,14 +896,14 @@ mod tests { } #[test] - fn test_reclaim_then_poll_preserves_order() { + fn test_reclaim_discards_readonly_completions() { let ring = make_ring(8); let (mut producer, mut consumer, _) = make_test_producer(&ring); // Submit 3 entries: RO, RW, RO - let tok_ro1 = send_readonly(&mut producer, b"log1"); + let _tok_ro1 = send_readonly(&mut producer, b"log1"); let tok_rw = send_readwrite(&mut 
producer, b"call", 64); - let tok_ro2 = send_readonly(&mut producer, b"log2"); + let _tok_ro2 = send_readonly(&mut producer, b"log2"); // Consumer processes all 3 let (_, c1) = consumer.poll(1024).unwrap().unwrap(); @@ -919,24 +919,16 @@ mod tests { let (_, c3) = consumer.poll(1024).unwrap().unwrap(); consumer.complete(c3).unwrap(); // ack RO - // Reclaim all 3 + // Reclaim all 3 - RO completions are discarded, only RW is buffered let count = producer.reclaim().unwrap(); assert_eq!(count, 3); - // poll() returns them in order - let cqe1 = producer.poll().unwrap().unwrap(); - assert_eq!(cqe1.token, tok_ro1); - assert!(cqe1.data.is_empty()); - - let cqe2 = producer.poll().unwrap().unwrap(); - assert_eq!(cqe2.token, tok_rw); - assert_eq!(&cqe2.data[..], b"result"); - - let cqe3 = producer.poll().unwrap().unwrap(); - assert_eq!(cqe3.token, tok_ro2); - assert!(cqe3.data.is_empty()); + // poll() returns only the RW completion + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, tok_rw); + assert_eq!(&cqe.data[..], b"result"); - // No more + // No more - RO completions were discarded assert!(producer.poll().unwrap().is_none()); } @@ -973,11 +965,7 @@ mod tests { assert_eq!(&cqe2.data[..], b"reply"); } - /// Regression test: reclaim + submit must not cause token collisions. - /// - /// Before the monotonic generation counter, Token wrapped the descriptor - /// ID which gets recycled. This caused stale pending completions to - /// match newly submitted entries with the same recycled descriptor ID. + /// reclaim + submit must not cause token collisions. 
#[test] fn test_reclaim_submit_no_token_collision() { let ring = make_ring(8); @@ -989,7 +977,6 @@ mod tests { let (_, c) = consumer.poll(1024).unwrap().unwrap(); consumer.complete(c).unwrap(); - // Reclaim pushes the completion to pending (token = tok_old) let count = producer.reclaim().unwrap(); assert_eq!(count, 1); @@ -1010,15 +997,42 @@ mod tests { wc.write_all(b"result").unwrap(); consumer.complete(wc.into()).unwrap(); - // Poll should return the stale ReadOnly completion first (wrong token) - let cqe1 = producer.poll().unwrap().unwrap(); - assert_eq!(cqe1.token, tok_old); - assert!(cqe1.data.is_empty()); + // Poll returns only the RW completion (RO was discarded by reclaim) + let cqe = producer.poll().unwrap().unwrap(); + assert_eq!(cqe.token, tok_new); + assert_eq!(&cqe.data[..], b"result"); - // Then the new ReadWrite completion (matching token) - let cqe2 = producer.poll().unwrap().unwrap(); - assert_eq!(cqe2.token, tok_new); - assert_eq!(&cqe2.data[..], b"result"); + // No stale RO completion in the queue + assert!(producer.poll().unwrap().is_none()); + } + + /// Verify that repeated oneshot submit/reclaim cycles do not accumulate pending completions. 
+ #[test] + fn test_reclaim_readonly_does_not_leak_pending() { + let ring = make_ring(4); + let (mut producer, mut consumer, _) = make_test_producer(&ring); + + for _ in 0..10 { + // Fill the ring + for _ in 0..4 { + send_readonly(&mut producer, b"msg"); + } + + // Consumer acks all + while let Some((_, completion)) = consumer.poll(1024).unwrap() { + consumer.complete(completion).unwrap(); + } + + // Reclaim frees ring slots; empty completions are discarded + let count = producer.reclaim().unwrap(); + assert_eq!(count, 4); + + // No completions should be buffered in pending + assert!( + producer.poll().unwrap().is_none(), + "pending should be empty after reclaiming RO entries" + ); + } } } #[cfg(all(test, loom))] diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index b66bc78b4..ed1844d44 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -34,6 +34,10 @@ pub struct RecvCompletion { pub token: Token, /// Completion data from the device. pub data: Bytes, + /// Whether this entry is oneshot so there is no writable completion buffer. + /// Oneshot entries are fire-and-forget: the producer does not + /// expect any response data from the device. + pub oneshot: bool, } /// Allocation tracking for an in-flight descriptor chain. @@ -146,7 +150,8 @@ where /// * `pool` - Buffer allocator for entry/completion data pub fn new(layout: Layout, mem: M, notifier: N, pool: P) -> Self { let inner = RingProducer::new(layout, mem); - let inflight = vec![None; inner.len()]; + let ring_len = inner.len(); + let inflight = vec![None; ring_len]; Self { inner, @@ -154,7 +159,7 @@ where notifier, inflight, next_token: 0, - pending: VecDeque::new(), + pending: VecDeque::with_capacity(ring_len), } } @@ -192,6 +197,9 @@ where /// buffer allocations immediately, and buffers completion data for /// later retrieval via [`poll`](Self::poll). 
/// + /// Completions with empty data from read-only/oneshot entries are + /// discarded immediately. + /// /// Use this to free resources under backpressure without losing /// completion data. Returns the number of entries reclaimed. pub fn reclaim(&mut self) -> Result @@ -201,7 +209,11 @@ where { let mut count = 0; while let Some(cqe) = self.poll_ring()? { - self.pending.push_back(cqe); + if !cqe.oneshot { + debug_assert!(self.pending.len() < self.inflight.len()); + debug_assert!(!cqe.data.is_empty()); + self.pending.push_back(cqe); + } count += 1; } Ok(count) @@ -242,6 +254,7 @@ where } // Read completion data + let has_completion = inf.completion().is_some(); let data = match inf.completion() { Some(buf) => { if written > buf.len { @@ -259,7 +272,11 @@ where None => Bytes::new(), }; - Ok(Some(RecvCompletion { token, data })) + Ok(Some(RecvCompletion { + token, + data, + oneshot: !has_completion, + })) } /// Drain all available completions, calling the provided closure for each. diff --git a/src/hyperlight_common/src/virtq/ring.rs b/src/hyperlight_common/src/virtq/ring.rs index 1c4ead918..9b463ed3d 100644 --- a/src/hyperlight_common/src/virtq/ring.rs +++ b/src/hyperlight_common/src/virtq/ring.rs @@ -350,11 +350,14 @@ impl RingCursor { } } - /// Advance by n positions + /// Advance by n positions using modular arithmetic. #[inline] pub(crate) fn advance_by(&mut self, n: u16) { - for _ in 0..n { - self.advance(); + let new = self.head + n; + let wraps = new / self.size; + self.head = new % self.size; + if wraps % 2 != 0 { + self.wrap = !self.wrap; } } @@ -371,6 +374,7 @@ impl RingCursor { } /// Reset cursor to initial state. + #[inline] pub fn reset(&mut self) { self.head = 0; self.wrap = true; @@ -962,7 +966,7 @@ impl RingConsumer { return Err(RingError::WouldBlock); } - // Build chain (head + tails). + // Build chain (head + tails), tracking readable/writable split inline. 
let mut elements = SmallVec::<[BufferElement; 16]>::new(); let mut pos = self.avail_cursor; let mut chain_len: u16 = 1; @@ -972,7 +976,10 @@ impl RingConsumer { let max_steps = self.desc_table.len(); - elements.push(BufferElement::from(&head_desc)); + let head_elem = BufferElement::from(&head_desc); + let mut seen_writable = head_elem.writable; + let mut writables: usize = if seen_writable { 1 } else { 0 }; + elements.push(head_elem); pos.advance(); while has_next && steps < max_steps { @@ -982,8 +989,17 @@ impl RingConsumer { .ok_or(RingError::InvalidState)?; // tail reads does not need ordering because head has been already validated - let desc = self.mem.read_val(addr).map_err(|_| RingError::MemError)?; - elements.push(BufferElement::from(&desc)); + let desc: Descriptor = self.mem.read_val(addr).map_err(|_| RingError::MemError)?; + let elem = BufferElement::from(&desc); + + if elem.writable { + seen_writable = true; + writables += 1; + } else if seen_writable { + return Err(RingError::BadChain); + } + + elements.push(elem); chain_len += 1; steps += 1; @@ -997,8 +1013,7 @@ impl RingConsumer { return Err(RingError::BadChain); } - // Verify that readable/writable split is correct - let readables = chain_readable_count(&elements)?; + let readables = elements.len() - writables; // Since driver wrote the same id everywhere, head_desc.id is valid. let id = head_desc.id; @@ -1261,24 +1276,6 @@ pub fn ring_need_event(event_idx: u16, new: u16, old: u16) -> bool { new.wrapping_sub(event_idx).wrapping_sub(1) < new.wrapping_sub(old) } -#[inline] -/// Check that a buffer chain is well-formed: all readable buffers first, -/// then writable and return the count of readable buffers. 
-fn chain_readable_count(elems: &[BufferElement]) -> Result { - let mut seen_writable = false; - let mut writables = 0; - - for e in elems { - if e.writable { - seen_writable = true; - writables += 1; - } else if seen_writable { - return Err(RingError::BadChain); - } - } - - Ok(elems.len() - writables) -} impl From<&Descriptor> for BufferElement { fn from(desc: &Descriptor) -> Self { diff --git a/src/hyperlight_host/benches/benchmarks.rs b/src/hyperlight_host/benches/benchmarks.rs index f8b6990a0..3ac391434 100644 --- a/src/hyperlight_host/benches/benchmarks.rs +++ b/src/hyperlight_host/benches/benchmarks.rs @@ -62,13 +62,13 @@ impl SandboxSize { Self::Medium => { let mut cfg = SandboxConfiguration::default(); cfg.set_heap_size(MEDIUM_HEAP_SIZE); - cfg.set_scratch_size(0x50000); + cfg.set_scratch_size(0x80000); Some(cfg) } Self::Large => { let mut cfg = SandboxConfiguration::default(); cfg.set_heap_size(LARGE_HEAP_SIZE); - cfg.set_scratch_size(0x100000); + cfg.set_scratch_size(0x200000); Some(cfg) } } @@ -384,10 +384,15 @@ fn guest_call_benchmark_large_param(c: &mut Criterion) { let large_vec = vec![0u8; SIZE]; let large_string = String::from_utf8(large_vec.clone()).unwrap(); + let h2g_pool_pages = (2 * SIZE + (1024 * 1024)) / 4096; + let heap_size = SIZE as u64 * 15; + let mut config = SandboxConfiguration::default(); - config.set_h2g_pool_pages((2 * SIZE + (1024 * 1024)) / 4096); // pool pages for the large input - config.set_heap_size(SIZE as u64 * 15); - config.set_scratch_size(6 * SIZE + 4 * (1024 * 1024)); // Big enough for any data copies, etc. + config.set_h2g_pool_pages(h2g_pool_pages); + config.set_h2g_queue_depth(h2g_pool_pages.next_power_of_two()); + config.set_heap_size(heap_size); + // Scratch backs all guest physical pages (heap, page tables, pools). 
+ config.set_scratch_size(heap_size as usize + 4 * 1024 * 1024); let sandbox = UninitializedSandbox::new( GuestBinary::FilePath(simple_guest_as_string().unwrap()), diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index cb5566eda..1a66f18d6 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -1090,21 +1090,21 @@ mod tests { assert_eq!(res, 0); } - // Tests to ensure that many (1000) function calls can be made in a call context with a small stack (24K) and heap(20K). + // Tests to ensure that many (1000) function calls can be made in a call context with a small stack (24K) and heap(32K). // This test effectively ensures that the stack is being properly reset after each call and we are not leaking memory in the Guest. #[test] fn test_with_small_stack_and_heap() { let mut cfg = SandboxConfiguration::default(); - cfg.set_heap_size(20 * 1024); + cfg.set_heap_size(32 * 1024); // min_scratch_size already includes 1 page (4k on most // platforms) of guest stack, so add 20k more to get 24k // total, and then add some more for the eagerly-copied page - // tables on amd64 + // tables on amd64 and virtq pool pages. 
let min_scratch = hyperlight_common::layout::min_scratch_size( cfg.get_g2h_queue_depth(), cfg.get_h2g_queue_depth(), ); - cfg.set_scratch_size(min_scratch + 0x10000 + 0x10000); + cfg.set_scratch_size(min_scratch + 0x10000 + 0x18000); let mut sbox1: MultiUseSandbox = { let path = simple_guest_as_string().unwrap(); @@ -1718,7 +1718,7 @@ mod tests { for (name, heap_size) in test_cases { let mut cfg = SandboxConfiguration::default(); - cfg.set_heap_size(heap_size); + cfg.set_heap_size(128 * 1024); cfg.set_scratch_size(0x100000); let path = simple_guest_as_string().unwrap(); From 26a9c8e2bc9b8494107ada66f87a67abc8d40550 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 10 Apr 2026 13:55:42 +0200 Subject: [PATCH 20/26] fix(virtq): make clippy happy Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/buffer.rs | 11 +++++++++-- src/hyperlight_common/src/virtq/mod.rs | 1 + src/hyperlight_common/src/virtq/producer.rs | 4 ++-- src/hyperlight_common/src/virtq/ring.rs | 1 - .../src/sandbox/initialized_multi_use.rs | 4 ++-- 5 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/hyperlight_common/src/virtq/buffer.rs b/src/hyperlight_common/src/virtq/buffer.rs index b41708b03..237eedcba 100644 --- a/src/hyperlight_common/src/virtq/buffer.rs +++ b/src/hyperlight_common/src/virtq/buffer.rs @@ -139,7 +139,10 @@ impl AllocGuard { } pub fn release(mut self) -> Allocation { - self.0.take().unwrap().0 + // Safety: AllocGuard is always constructed with Some, and release is only called once + self.0.take().map(|(alloc, _)| alloc).unwrap_or_else(|| { + unreachable!("AllocGuard::release called on dismissed guard") + }) } } @@ -147,7 +150,11 @@ impl core::ops::Deref for AllocGuard { type Target = Allocation; fn deref(&self) -> &Allocation { - &self.0.as_ref().unwrap().0 + // Safety: AllocGuard is always constructed with Some, and the inner value is only + // taken by release() or Drop. 
+ &self.0.as_ref().unwrap_or_else(|| { + unreachable!("AllocGuard::deref called on dismissed guard") + }).0 } } diff --git a/src/hyperlight_common/src/virtq/mod.rs b/src/hyperlight_common/src/virtq/mod.rs index 3d36bf0a8..0729d473b 100644 --- a/src/hyperlight_common/src/virtq/mod.rs +++ b/src/hyperlight_common/src/virtq/mod.rs @@ -351,6 +351,7 @@ impl From for Allocation { } const _: () = { + #[allow(clippy::unwrap_used)] const fn verify_layout(num_descs: usize) { let base = 0x1000u64; diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index ed1844d44..c7fd49efc 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -519,8 +519,8 @@ where self.inner.reset_prefilled(&ids); let addrs: SmallVec<[u64; 64]> = (0..prefill_count) - .map(|i| self.pool.slot_addr(i).expect("prefill_count <= pool count")) - .collect(); + .map(|i| self.pool.slot_addr(i).ok_or(VirtqError::InvalidState)) + .collect::>()?; self.pool .restore_allocated(&addrs) diff --git a/src/hyperlight_common/src/virtq/ring.rs b/src/hyperlight_common/src/virtq/ring.rs index 9b463ed3d..2464a1a1d 100644 --- a/src/hyperlight_common/src/virtq/ring.rs +++ b/src/hyperlight_common/src/virtq/ring.rs @@ -1276,7 +1276,6 @@ pub fn ring_need_event(event_idx: u16, new: u16, old: u16) -> bool { new.wrapping_sub(event_idx).wrapping_sub(1) < new.wrapping_sub(old) } - impl From<&Descriptor> for BufferElement { fn from(desc: &Descriptor) -> Self { BufferElement { diff --git a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs index 1a66f18d6..0d9d6d737 100644 --- a/src/hyperlight_host/src/sandbox/initialized_multi_use.rs +++ b/src/hyperlight_host/src/sandbox/initialized_multi_use.rs @@ -1718,8 +1718,8 @@ mod tests { for (name, heap_size) in test_cases { let mut cfg = SandboxConfiguration::default(); - cfg.set_heap_size(128 * 1024); - 
cfg.set_scratch_size(0x100000); + cfg.set_heap_size(heap_size); + cfg.set_scratch_size(heap_size as usize + 0x100000); let path = simple_guest_as_string().unwrap(); let sbox = UninitializedSandbox::new(GuestBinary::FilePath(path), Some(cfg)) From 1a527ad0ddc28dbd62ed3aec8f7f92e6d057fffc Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 10 Apr 2026 14:49:06 +0200 Subject: [PATCH 21/26] feat(virtq): add recycle pool tests Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/buffer.rs | 17 +-- src/hyperlight_common/src/virtq/pool.rs | 141 ++++++++++++++++++++++ 2 files changed, 151 insertions(+), 7 deletions(-) diff --git a/src/hyperlight_common/src/virtq/buffer.rs b/src/hyperlight_common/src/virtq/buffer.rs index 237eedcba..7b637e38b 100644 --- a/src/hyperlight_common/src/virtq/buffer.rs +++ b/src/hyperlight_common/src/virtq/buffer.rs @@ -103,7 +103,7 @@ impl BufferProvider for Arc { /// zero-copy `Bytes` backed by shared memory. /// /// When dropped, the allocation is returned to the pool. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct BufferOwner { pub(crate) pool: P, pub(crate) mem: M, @@ -140,9 +140,10 @@ impl AllocGuard { pub fn release(mut self) -> Allocation { // Safety: AllocGuard is always constructed with Some, and release is only called once - self.0.take().map(|(alloc, _)| alloc).unwrap_or_else(|| { - unreachable!("AllocGuard::release called on dismissed guard") - }) + self.0 + .take() + .map(|(alloc, _)| alloc) + .unwrap_or_else(|| unreachable!("AllocGuard::release called on dismissed guard")) } } @@ -152,9 +153,11 @@ impl core::ops::Deref for AllocGuard { fn deref(&self) -> &Allocation { // Safety: AllocGuard is always constructed with Some, and the inner value is only // taken by release() or Drop. 
- &self.0.as_ref().unwrap_or_else(|| { - unreachable!("AllocGuard::deref called on dismissed guard") - }).0 + &self + .0 + .as_ref() + .unwrap_or_else(|| unreachable!("AllocGuard::deref called on dismissed guard")) + .0 } } diff --git a/src/hyperlight_common/src/virtq/pool.rs b/src/hyperlight_common/src/virtq/pool.rs index d60d432bf..92db1a38e 100644 --- a/src/hyperlight_common/src/virtq/pool.rs +++ b/src/hyperlight_common/src/virtq/pool.rs @@ -680,6 +680,20 @@ impl BufferProvider for RecyclePool { fn dealloc(&self, alloc: Allocation) -> Result<(), AllocError> { let mut inner = self.inner.borrow_mut(); + let end = inner.base_addr + (inner.count * inner.slot_size) as u64; + + if alloc.addr < inner.base_addr || alloc.addr >= end { + return Err(AllocError::InvalidFree(alloc.addr, alloc.len)); + } + + if (alloc.addr - inner.base_addr) % inner.slot_size as u64 != 0 { + return Err(AllocError::InvalidFree(alloc.addr, alloc.len)); + } + + if inner.free.contains(&alloc.addr) { + return Err(AllocError::InvalidFree(alloc.addr, alloc.len)); + } + inner.free.push(alloc.addr); Ok(()) } @@ -1389,6 +1403,133 @@ mod tests { pool.restore_allocated(&[0x80000]).unwrap(); assert_eq!(pool.num_free(), 3); } + + #[test] + fn test_recycle_pool_dealloc_out_of_range() { + let pool = make_recycle_pool(4, 4096); + let _ = pool.alloc(4096).unwrap(); + + let bogus = Allocation { + addr: 0xDEAD, + len: 4096, + }; + assert!(matches!( + pool.dealloc(bogus), + Err(AllocError::InvalidFree(0xDEAD, 4096)) + )); + } + + #[test] + fn test_recycle_pool_dealloc_misaligned() { + let pool = make_recycle_pool(4, 4096); + let _ = pool.alloc(4096).unwrap(); + + let misaligned = Allocation { + addr: 0x80001, + len: 4096, + }; + assert!(matches!( + pool.dealloc(misaligned), + Err(AllocError::InvalidFree(0x80001, 4096)) + )); + } + + #[test] + fn test_recycle_pool_dealloc_double_free() { + let pool = make_recycle_pool(4, 4096); + let a = pool.alloc(4096).unwrap(); + pool.dealloc(a).unwrap(); + + // Second dealloc 
should fail - address is already in the free list + assert!(matches!( + pool.dealloc(a), + Err(AllocError::InvalidFree(_, _)) + )); + } + + #[test] + fn test_recycle_pool_random_order_dealloc() { + let pool = make_recycle_pool(8, 4096); + + let mut allocs: Vec = (0..8).map(|_| pool.alloc(4096).unwrap()).collect(); + assert_eq!(pool.num_free(), 0); + + // Dealloc in reverse order + allocs.reverse(); + for a in &allocs { + pool.dealloc(*a).unwrap(); + } + assert_eq!(pool.num_free(), 8); + + // All slots should be re-allocatable + let reallocs: Vec = (0..8).map(|_| pool.alloc(4096).unwrap()).collect(); + assert_eq!(pool.num_free(), 0); + + // Verify all addresses are distinct + let mut addrs: Vec = reallocs.iter().map(|a| a.addr).collect(); + addrs.sort(); + addrs.dedup(); + assert_eq!(addrs.len(), 8); + } + + #[test] + fn test_recycle_pool_interleaved_alloc_dealloc_order() { + let pool = make_recycle_pool(4, 4096); + + let a0 = pool.alloc(4096).unwrap(); + let a1 = pool.alloc(4096).unwrap(); + let a2 = pool.alloc(4096).unwrap(); + let a3 = pool.alloc(4096).unwrap(); + assert_eq!(pool.num_free(), 0); + + // Free middle slots first (out of allocation order) + pool.dealloc(a2).unwrap(); + pool.dealloc(a0).unwrap(); + assert_eq!(pool.num_free(), 2); + + // Re-alloc gets the out-of-order slots back (LIFO) + let b0 = pool.alloc(4096).unwrap(); + assert_eq!(b0.addr, a0.addr); + let b1 = pool.alloc(4096).unwrap(); + assert_eq!(b1.addr, a2.addr); + + // Free everything in yet another order + pool.dealloc(a1).unwrap(); + pool.dealloc(b0).unwrap(); + pool.dealloc(b1).unwrap(); + pool.dealloc(a3).unwrap(); + assert_eq!(pool.num_free(), 4); + + // All 4 original addresses should be available + let mut final_addrs: Vec = (0..4).map(|_| pool.alloc(4096).unwrap().addr).collect(); + final_addrs.sort(); + let expected: Vec = (0..4).map(|i| 0x80000 + i * 4096).collect(); + assert_eq!(final_addrs, expected); + } + + #[test] + fn 
test_recycle_pool_dealloc_order_independent_of_alloc_order() { + let pool = make_recycle_pool(6, 256); + + // Allocate all + let allocs: Vec = (0..6).map(|_| pool.alloc(256).unwrap()).collect(); + + // Dealloc in scattered order: 4, 1, 5, 0, 3, 2 + let order = [4, 1, 5, 0, 3, 2]; + for &i in &order { + pool.dealloc(allocs[i]).unwrap(); + } + assert_eq!(pool.num_free(), 6); + + // Re-allocate all and verify we get back the full set + let mut realloc_addrs: Vec = (0..6).map(|_| pool.alloc(256).unwrap().addr).collect(); + realloc_addrs.sort(); + + let mut orig_addrs: Vec = allocs.iter().map(|a| a.addr).collect(); + orig_addrs.sort(); + + assert_eq!(realloc_addrs, orig_addrs); + } } #[cfg(test)] From 1c23e1a7672a3aa4753ac6d21d0d033327f02c0b Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 10 Apr 2026 15:27:20 +0200 Subject: [PATCH 22/26] feat(virtq): implement G2H reply backlog guard Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/src/virtq/producer.rs | 6 +++ src/hyperlight_guest/src/virtq/context.rs | 48 ++++++++++++++++++++- 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/src/hyperlight_common/src/virtq/producer.rs b/src/hyperlight_common/src/virtq/producer.rs index c7fd49efc..0cc634e73 100644 --- a/src/hyperlight_common/src/virtq/producer.rs +++ b/src/hyperlight_common/src/virtq/producer.rs @@ -384,6 +384,12 @@ where self.inner.used_cursor() } + /// Number of free (unsubmitted) descriptors in the ring. + #[inline] + pub fn num_free(&self) -> usize { + self.inner.num_free() + } + /// Configure event suppression for used buffer notifications. 
/// /// This controls when the device (consumer) signals us about completed buffers: diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index fae3c87be..9eb940788 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -79,6 +79,8 @@ pub struct GuestContext { h2g_slot_size: usize, /// snapshot generation counter generation: u32, + /// Number of H2G requests received that still need a G2H response. + pending_replies: u32, /// used by cabi last_host_result: Option>, } @@ -106,6 +108,7 @@ impl GuestContext { g2h_response_cap, h2g_slot_size, generation, + pending_replies: 0, last_host_result: None, }; @@ -114,6 +117,9 @@ impl GuestContext { } /// Call a host function via the G2H virtqueue. + /// + /// The reply guard is checked before submitting the readwrite chain + /// to ensure G2H capacity is reserved for pending responses. pub fn call_host_function>( &mut self, function_name: &str, @@ -139,6 +145,9 @@ impl GuestContext { let entry_len = VirtqMsgHeader::SIZE + payload.len(); + // Reply guard: readwrite chains use 2 descriptors, leave room for pending replies. + self.ensure_reply_capacity(2)?; + let token = match self.try_send_readwrite(hdr_bytes, payload, entry_len) { Ok(tok) => tok, Err(e) if e.is_transient() => { @@ -191,6 +200,9 @@ impl GuestContext { /// Each descriptor carries a [`VirtqMsgHeader`] with `payload_len` for /// that chunk. If [`MsgFlags::MORE`](hyperlight_common::virtq::msg::MsgFlags::MORE) /// is set, more descriptors follow. + /// + /// Increments the reply guard counter so that subsequent G2H sends + /// reserve capacity for the response. pub fn recv_h2g_call(&mut self) -> Result { let Some(first) = self.h2g_producer.poll()? else { bail!("H2G: no pending call"); @@ -209,6 +221,9 @@ impl GuestContext { let chunk_len = hdr.payload_len as usize; + // Track that we owe a response on G2H. 
+ self.pending_replies = self.pending_replies.saturating_add(1); + if !hdr.has_more() { // Single-descriptor fast path let payload = &data[VirtqMsgHeader::SIZE..VirtqMsgHeader::SIZE + chunk_len]; @@ -250,8 +265,11 @@ impl GuestContext { /// Send the result of a host-to-guest call back to the host via the /// G2H queue, then refill H2G descriptor slots until the ring is full. + /// + /// Decrements the reply guard counter after a successful send. pub fn send_h2g_result(&mut self, payload: &[u8]) -> Result<()> { self.send_g2h_oneshot(MsgKind::Response, payload)?; + self.pending_replies = self.pending_replies.saturating_sub(1); self.prefill_h2g() } @@ -343,16 +361,42 @@ impl GuestContext { } } + /// Ensure the G2H ring has enough free descriptors to accommodate + /// both the requested send (`need_descs`) and all pending replies. + fn ensure_reply_capacity(&mut self, need_descs: usize) -> Result<()> { + let reserved = self.pending_replies as usize; + loop { + let free = self.g2h_producer.num_free(); + if free >= need_descs + reserved { + return Ok(()); + } + + self.g2h_producer.notify_backpressure(); + let reclaimed = self.g2h_producer.reclaim()?; + if reclaimed == 0 { + // No progress - host hasn't completed any entries yet. + // Fall through and let the send path handle backpressure + // via its own retry logic. + return Ok(()); + } + } + } + /// Send a one-way message on the G2H queue ReadOnly and no completion. /// - /// If the pool or ring is full, triggers backpressure, VM exit so - /// the host can drain, then retries once. + /// For non-response sends, the reply guard is checked first to + /// ensure enough G2H capacity is reserved for pending replies. 
fn send_g2h_oneshot(&mut self, kind: MsgKind, payload: &[u8]) -> Result<()> { let reqid = REQUEST_ID.fetch_add(1, Relaxed); let hdr = VirtqMsgHeader::new(kind, reqid, payload.len() as u32); let hdr_bytes = bytemuck::bytes_of(&hdr); let entry_len = VirtqMsgHeader::SIZE + payload.len(); + // Reply guard: non-response sends must leave room for pending replies. + if kind != MsgKind::Response { + self.ensure_reply_capacity(1)?; + } + // First attempt match self.try_send_readonly(hdr_bytes, payload, entry_len) { Ok(_) => return Ok(()), From daf681f7ff626c71dac1f2613325d449e1d21dd0 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 10 Apr 2026 15:43:46 +0200 Subject: [PATCH 23/26] fix(virtq): add copyright header to benches Signed-off-by: Tomasz Andrzejak --- src/hyperlight_common/benches/buffer_pool.rs | 16 ++++++++++++++++ typos.toml | 2 ++ 2 files changed, 18 insertions(+) diff --git a/src/hyperlight_common/benches/buffer_pool.rs b/src/hyperlight_common/benches/buffer_pool.rs index 614f160b0..80d4f9daa 100644 --- a/src/hyperlight_common/benches/buffer_pool.rs +++ b/src/hyperlight_common/benches/buffer_pool.rs @@ -1,3 +1,19 @@ +/* +Copyright 2026 The Hyperlight Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + use std::hint::black_box; use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; diff --git a/typos.toml b/typos.toml index c04a7f2c8..a2607e2db 100644 --- a/typos.toml +++ b/typos.toml @@ -13,3 +13,5 @@ fpr="fpr" consts="consts" # ist is an acronym for Interrupt Stack Table, not a missspelling of its ist="ist" +# writables as number of writable buffers +writables="writables" From 78119579a8cd53e36d46a8b8c7b62e97080fb0eb Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 10 Apr 2026 17:47:28 +0200 Subject: [PATCH 24/26] fix(virtq): we gonna need a bigger boat Move FXSAVE buffer to the middle of scratch to avoid overwriting live page tables that are copied to the beginning of scratch when update_scratch_bookkeeping is called Signed-off-by: Tomasz Andrzejak --- fuzz/fuzz_targets/guest_trace.rs | 4 ++-- .../src/hypervisor/hyperlight_vm/x86_64.rs | 14 ++++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/fuzz/fuzz_targets/guest_trace.rs b/fuzz/fuzz_targets/guest_trace.rs index 3dfb61c95..43373300e 100644 --- a/fuzz/fuzz_targets/guest_trace.rs +++ b/fuzz/fuzz_targets/guest_trace.rs @@ -69,8 +69,8 @@ impl<'a> Arbitrary<'a> for FuzzInput { fuzz_target!( init: { let mut cfg = SandboxConfiguration::default(); - // In local tests, 256 KiB seemed sufficient for deep recursion - cfg.set_scratch_size(256 * 1024); + // In local tests, 512 KiB seemed sufficient for deep recursion + cfg.set_scratch_size(512 * 1024); let path = simple_guest_for_fuzzing_as_string().expect("Guest Binary Missing"); let u_sbox = UninitializedSandbox::new( GuestBinary::FilePath(path), diff --git a/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs b/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs index 2b033f23b..494377a69 100644 --- a/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs +++ b/src/hyperlight_host/src/hypervisor/hyperlight_vm/x86_64.rs @@ -2125,18 +2125,20 @@ mod tests { } /// Creates VM with 
guest code that: dirtys FPU (if flag==0), does FXSAVE to buffer, sets flag=1. - /// Uses scratch region after rings for FXSAVE buffer. + /// Uses a scratch region area for the FXSAVE buffer. fn hyperlight_vm_with_mem_mgr_fxsave() -> FxsaveTestContext { use iced_x86::code_asm::*; // Compute fixed addresses for FXSAVE buffer and flag. - // We use the page-table area in scratch after rings as a - // convenient 512-byte aligned buffer for FXSAVE. + // Place the buffer at halfway through scratch: well past + // the rings and page tables at the start, and well below + // the stack and scratch-top metadata at the end. let config: SandboxConfiguration = Default::default(); let layout = SandboxMemoryLayout::new(config, 512, 4096, None).unwrap(); - let fxsave_offset = layout.get_pt_base_scratch_offset(); - let fxsave_gva = hyperlight_common::layout::scratch_base_gva(config.get_scratch_size()) - + fxsave_offset as u64; + let scratch_size = config.get_scratch_size(); + let fxsave_offset = (scratch_size / 2) & !0xFFF; // page-aligned + let fxsave_gva = + hyperlight_common::layout::scratch_base_gva(scratch_size) + fxsave_offset as u64; let flag_gva = fxsave_gva + 512; let mut a = CodeAssembler::new(64).unwrap(); From 6f56edf73f6615b7e85602e1525bac4c348aa0c1 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 10 Apr 2026 19:23:15 +0200 Subject: [PATCH 25/26] fix(virtq): add instrumentation to virtq host calls Signed-off-by: Tomasz Andrzejak --- fuzz/fuzz_targets/host_call.rs | 2 ++ src/hyperlight_guest/src/virtq/context.rs | 2 ++ src/hyperlight_host/tests/integration_test.rs | 4 ++-- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/fuzz/fuzz_targets/host_call.rs b/fuzz/fuzz_targets/host_call.rs index 0559b2bd6..c65b10047 100644 --- a/fuzz/fuzz_targets/host_call.rs +++ b/fuzz/fuzz_targets/host_call.rs @@ -61,6 +61,8 @@ fuzz_target!( HyperlightError::GuestError(ErrorCode::HostFunctionError, msg) if msg.contains("The number of arguments to the function is wrong") 
=> {} HyperlightError::ParameterValueConversionFailure(_, _) => {}, HyperlightError::GuestError(ErrorCode::HostFunctionError, msg) if msg.contains("Failed To Convert Parameter Value") => {} + HyperlightError::GuestError(ErrorCode::HostFunctionError, msg) if msg.contains("The parameter value type is unexpected") => {} + HyperlightError::GuestError(ErrorCode::HostFunctionError, msg) if msg.contains("The return value type is unexpected") => {} // any other error should be reported _ => panic!("Guest Aborted with Unexpected Error: {:?}", e), diff --git a/src/hyperlight_guest/src/virtq/context.rs b/src/hyperlight_guest/src/virtq/context.rs index 9eb940788..46e4b0cdb 100644 --- a/src/hyperlight_guest/src/virtq/context.rs +++ b/src/hyperlight_guest/src/virtq/context.rs @@ -33,6 +33,7 @@ use hyperlight_common::virtq::msg::{MsgKind, VirtqMsgHeader}; use hyperlight_common::virtq::{ self, BufferPool, Layout, Notifier, QueueStats, RecyclePool, Token, VirtqProducer, }; +use tracing::instrument; use super::GuestMemOps; use crate::bail; @@ -120,6 +121,7 @@ impl GuestContext { /// /// The reply guard is checked before submitting the readwrite chain /// to ensure G2H capacity is reserved for pending responses. 
+ #[instrument(skip_all, level = "Info")] pub fn call_host_function>( &mut self, function_name: &str, diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs index 243b8d3b8..e4abf68c1 100644 --- a/src/hyperlight_host/tests/integration_test.rs +++ b/src/hyperlight_host/tests/integration_test.rs @@ -716,10 +716,10 @@ fn log_message() { // follows: // - logs from trace level tracing spans created as logs because of the tracing `log` feature // - 4 from evolve call (generic_init + hyperlight_main) - // - 8 from guest call + // - 4 from guest call (call_host_function + read_n_bytes_from_user_memory) // and are multiplied because we make 6 calls to `log_test_messages` // NOTE: These numbers need to be updated if log messages or spans are added/removed - let num_fixed_trace_log = 12 * 6; + let num_fixed_trace_log = 8 * 6; // Calculate fixed info logs // - 4 logs per iteration from infrastructure at Info level (internal_dispatch_function) From 1f9ed5090647f389fcb4d9c89bc6c23ebed62537 Mon Sep 17 00:00:00 2001 From: Tomasz Andrzejak Date: Fri, 10 Apr 2026 19:55:19 +0200 Subject: [PATCH 26/26] fix(virtq): truncate error message so it fits completion Signed-off-by: Tomasz Andrzejak --- src/hyperlight_host/src/sandbox/outb.rs | 15 +++++++++++++-- src/hyperlight_host/tests/integration_test.rs | 4 ++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/src/hyperlight_host/src/sandbox/outb.rs b/src/hyperlight_host/src/sandbox/outb.rs index 85f3f1aeb..0b6bfb609 100644 --- a/src/hyperlight_host/src/sandbox/outb.rs +++ b/src/hyperlight_host/src/sandbox/outb.rs @@ -268,12 +268,23 @@ fn outb_virtq_call( let name = call.function_name.clone(); let args: Vec = call.parameters.unwrap_or(vec![]); - let res = host_funcs + + let registry = host_funcs .try_lock() - .map_err(|e| HandleOutbError::LockFailed(file!(), line!(), e.to_string()))? 
+ .map_err(|e| HandleOutbError::LockFailed(file!(), line!(), e.to_string()))?; + + let mut res = registry .call_host_function(&name, args) .map_err(|e| GuestError::new(ErrorCode::HostFunctionError, e.to_string())); + // Truncate oversized error messages so the serialized response + // fits in the completion buffer the guest pre-allocated. + if let Err(err) = &mut res + && err.message.len() > wc.capacity() + { + err.message.truncate(wc.capacity()); + } + // Serialize response: VirtqMsgHeader + FunctionCallResult let func_result = FunctionCallResult::new(res); let mut builder = flatbuffers::FlatBufferBuilder::new(); diff --git a/src/hyperlight_host/tests/integration_test.rs b/src/hyperlight_host/tests/integration_test.rs index e4abf68c1..4c3b6bd8d 100644 --- a/src/hyperlight_host/tests/integration_test.rs +++ b/src/hyperlight_host/tests/integration_test.rs @@ -536,7 +536,7 @@ fn guest_malloc_abort() { }); // allocate a vector (on heap) that is bigger than the heap - let heap_size = 0x4000; + let heap_size = 0x8000; let size_to_allocate = 0x10000; assert!( size_to_allocate > heap_size, @@ -585,7 +585,7 @@ fn guest_outb_with_invalid_port_poisons_sandbox() { #[test] fn guest_panic_no_alloc() { - let heap_size = 0x4000; + let heap_size = 0x8000; let mut cfg = SandboxConfiguration::default(); cfg.set_heap_size(heap_size);