From bfd8c6e2551c0042b631e711814a15b036e33031 Mon Sep 17 00:00:00 2001 From: bkmashiro <53376445+bkmashiro@users.noreply.github.com> Date: Fri, 19 Jun 2026 13:17:40 +0100 Subject: [PATCH 1/7] feat: add opt-in generic wasm backend --- README.md | 79 ++- cmd/root.go | 5 +- examples/demo-stateful/README.md | 25 + examples/demo-stateful/main.go | 95 ++++ go.mod | 1 + go.sum | 2 + internal/execution/dispatcher.go | 33 +- internal/execution/supervisor/config.go | 5 +- internal/execution/supervisor/models.go | 3 + internal/execution/wasm/adapter.go | 171 +++++++ internal/execution/wasm/config.go | 82 ++++ internal/execution/wasm/dispatcher.go | 376 ++++++++++++++ internal/execution/wasm/dispatcher_test.go | 515 ++++++++++++++++++++ internal/execution/wasm/pool.go | 62 +++ internal/execution/wasm/snapshot.go | 96 ++++ internal/execution/wasm/snapshot_test.go | 258 ++++++++++ internal/execution/wasm/supervisor.go | 221 +++++++++ internal/execution/wasm/testdata/echo.wasm | Bin 0 -> 241 bytes internal/execution/wasm/testdata/echo.wat | 66 +++ internal/execution/wasm/testhelpers_test.go | 46 ++ scripts/demo-wasm.sh | 119 +++++ 21 files changed, 2252 insertions(+), 8 deletions(-) create mode 100644 examples/demo-stateful/README.md create mode 100644 examples/demo-stateful/main.go create mode 100644 internal/execution/wasm/adapter.go create mode 100644 internal/execution/wasm/config.go create mode 100644 internal/execution/wasm/dispatcher.go create mode 100644 internal/execution/wasm/dispatcher_test.go create mode 100644 internal/execution/wasm/pool.go create mode 100644 internal/execution/wasm/snapshot.go create mode 100644 internal/execution/wasm/snapshot_test.go create mode 100644 internal/execution/wasm/supervisor.go create mode 100644 internal/execution/wasm/testdata/echo.wasm create mode 100644 internal/execution/wasm/testdata/echo.wat create mode 100644 internal/execution/wasm/testhelpers_test.go create mode 100755 scripts/demo-wasm.sh diff --git a/README.md b/README.md 
index 2cd6fda..71bea78 100644 --- a/README.md +++ b/README.md @@ -59,10 +59,10 @@ GLOBAL OPTIONS: function --arg value, -a value [ --arg value, -a value ] additional arguments for to the worker process. [$FUNCTION_ARGS] - --command value, -c value the command to invoke to start the worker process. [$FUNCTION_COMMAND] + --command value, -c value the command to invoke to start the worker process, or the WASM module path when --interface=wasm. [$FUNCTION_COMMAND] --cwd value, -d value the working directory for the worker process. [$FUNCTION_WORKING_DIR] --env value, -e value [ --env value, -e value ] additional environment variables for the worker process. [$FUNCTION_ENV] - --interface value, -i value the interface to use for worker process communication. Options: rpc, file. (default: "rpc") [$FUNCTION_INTERFACE] + --interface value, -i value the interface to use for worker communication. Options: rpc, file, wasm. (default: "rpc") [$FUNCTION_INTERFACE] --max-workers value, -n value the maximum number of worker processes to run concurrently. (default: number of CPU cores) [$FUNCTION_MAX_PROCS] rpc @@ -245,6 +245,81 @@ For example, a Wolfram Language evaluation function in `evaluation.wl` would be wolframscript -file evaluation.wl /tmp/shimmy/abc/request-data-123 /tmp/shimmy/abc/response-data-456 ``` +#### WebAssembly (`--interface wasm`, opt-in) + +The WASM interface executes a pre-built WASI module in-process using wazero. This +is an execution backend only: Shimmy still owns the public HTTP/API contract, +request validation, command routing, cases, and response handling. + +Shimmy does not compile evaluator source code at request time and does not infer +a source language from dependency files. Language-specific work belongs in build +or deployment recipes that produce an `eval.wasm` artifact. + +A generic WASM evaluator module must export: + +| Export | Purpose | +|--------|---------| +| `memory` | Guest linear memory. 
| +| `alloc(size: i32) -> i32` | Reserves memory where Shimmy writes the request JSON. | +| `evaluate(ptr: i32, len: i32) -> i32` | Executes one command and returns a response pointer. | + +Shimmy writes this internal adapter envelope into guest memory: + +```json +{ + "method": "eval", + "params": { + "response": "...", + "answer": "...", + "params": {} + } +} +``` + +The response pointer returned by `evaluate` must point at: + +```text +[p:p+4] little-endian uint32 JSON length +[p+4:p+4+len] JSON object bytes +``` + +Run a pre-built WASI module with: + +```shell +FUNCTION_INTERFACE=wasm \ +FUNCTION_WASM_MODULE=/path/to/eval.wasm \ +FUNCTION_MAX_PROCS=1 \ +shimmy serve +``` + +`FUNCTION_COMMAND=/path/to/eval.wasm` is also accepted for compatibility, but +`FUNCTION_WASM_MODULE` is clearer for new deployments. + +Example build recipes: + +```shell +# Go +GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o eval.wasm . + +# Rust +cargo build --target wasm32-wasip1 --release + +# C/C++ +/opt/wasi-sdk/bin/clang --target=wasm32-wasip1 ... -o eval.wasm +/opt/wasi-sdk/bin/clang++ --target=wasm32-wasip1 ... -o eval.wasm +``` + +The backend keeps a warm module instance pool and restores a full linear-memory +snapshot after each request. This gives warm reuse without leaking guest mutable +state between requests. Dirty-page restore, Python runtimes, Pyodide, and package +bundling are intentionally out of scope for this generic backend. + +Try the state-isolation example: + +```shell +scripts/demo-wasm.sh +``` + ### Sandboxed Execution (Linux only, experimental) Shimmy can wrap each worker process in an [nsjail](https://github.com/google/nsjail) sandbox to safely execute arbitrary, untrusted code. 
The sandbox provides: diff --git a/cmd/root.go b/cmd/root.go index 275c31e..f8b9f0b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -47,7 +47,7 @@ functions on arbitrary, serverless platforms.` &cli.StringFlag{ Name: "interface", Aliases: []string{"i"}, - Usage: "the interface to use for worker process communication. Options: rpc, file.", + Usage: "the interface to use for worker communication. Options: rpc, file, wasm.", Value: "rpc", Category: "function", EnvVars: []string{"FUNCTION_INTERFACE"}, @@ -55,10 +55,9 @@ functions on arbitrary, serverless platforms.` &cli.StringFlag{ Name: "command", Aliases: []string{"c"}, - Usage: "the command to invoke to start the worker process.", + Usage: "the command to invoke to start the worker process, or the WASM module path when --interface=wasm.", Category: "function", EnvVars: []string{"FUNCTION_COMMAND"}, - Required: true, }, &cli.StringFlag{ Name: "cwd", diff --git a/examples/demo-stateful/README.md b/examples/demo-stateful/README.md new file mode 100644 index 0000000..ffe4c02 --- /dev/null +++ b/examples/demo-stateful/README.md @@ -0,0 +1,25 @@ +# Demo: stateful WASM evaluator + +This is a deliberately tiny Shimmy-WASM evaluator for live demos. + +It mutates a module-global `invocationCount` on every `eval` call. In a normal +warm worker, that state would leak across requests (`1`, then `2`, then `3`). +Shimmy-WASM snapshots the module memory after startup and restores it after each +request, so every request observes `guest_invocation_count: 1`. + +Use it via the one-command demo runner: + +```bash +scripts/demo-wasm.sh +``` + +What the demo shows: + +1. Build Shimmy. +2. Compile this evaluator to `wasm32-wasip1`. +3. Start `shimmy serve` with `FUNCTION_INTERFACE=wasm`. +4. Send two HTTP grading requests. +5. Assert both responses report `guest_invocation_count == 1`. + +That is the visible end-to-end proof: HTTP request → Shimmy → wazero WASM guest +→ response validation → HTTP response, with per-request state reset. 
diff --git a/examples/demo-stateful/main.go b/examples/demo-stateful/main.go new file mode 100644 index 0000000..80125a4 --- /dev/null +++ b/examples/demo-stateful/main.go @@ -0,0 +1,95 @@ +//go:build wasip1 + +// demo-stateful is a tiny Shimmy-WASM evaluation function for live demos. +// It intentionally mutates module-global state on every request. The host +// should snapshot/restore WASM memory after each call, so the counter reported +// to the next request should still be 1 rather than leaking as 2, 3, ... +package main + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "unsafe" +) + +var reqBuf [256 * 1024]byte +var respBuf [256 * 1024]byte + +// This is deliberately mutable guest state. A non-isolated warm worker would +// leak it across requests; Shimmy-WASM restores the memory snapshot instead. +var invocationCount uint32 +var lastResponse [64]byte + +//go:wasmexport alloc +func alloc(size int32) int32 { + _ = size + return int32(uintptr(unsafe.Pointer(&reqBuf[0]))) +} + +//go:wasmexport evaluate +func evaluate(reqPtr int32, reqLen int32) int32 { + _ = reqPtr + + var req struct { + Method string `json:"method"` + Params struct { + Response string `json:"response"` + Answer string `json:"answer"` + Params map[string]any `json:"params"` + } `json:"params"` + } + + if err := json.Unmarshal(reqBuf[:reqLen], &req); err != nil { + writeResp(map[string]any{"error": map[string]any{"message": err.Error()}}) + return int32(uintptr(unsafe.Pointer(&respBuf[0]))) + } + + invocationCount++ + copy(lastResponse[:], req.Params.Response) + + switch req.Method { + case "eval": + correct := req.Params.Response == req.Params.Answer + feedback := "Correct — and the guest counter is still 1, so snapshot/restore worked." + if !correct { + feedback = fmt.Sprintf("Incorrect: got %q, expected %q. 
Guest counter is still %d.", req.Params.Response, req.Params.Answer, invocationCount) + } + writeResp(map[string]any{ + "command": "eval", + "result": map[string]any{ + "is_correct": correct, + "feedback": feedback, + "guest_invocation_count": invocationCount, + "snapshot_isolation_ok": invocationCount == 1, + }, + }) + case "preview": + writeResp(map[string]any{ + "command": "preview", + "result": map[string]any{ + "preview": map[string]any{"type": "text", "content": req.Params.Response}, + }, + }) + case "healthcheck": + writeResp(map[string]any{ + "command": "healthcheck", + "result": map[string]any{"status": "ok"}, + }) + default: + writeResp(map[string]any{"error": map[string]any{"message": "unknown method: " + req.Method}}) + } + + return int32(uintptr(unsafe.Pointer(&respBuf[0]))) +} + +func writeResp(v map[string]any) { + data, err := json.Marshal(v) + if err != nil { + data = []byte(`{"error":{"message":"marshal failed"}}`) + } + binary.LittleEndian.PutUint32(respBuf[:4], uint32(len(data))) + copy(respBuf[4:], data) +} + +func main() {} diff --git a/go.mod b/go.mod index 10caf84..cee6645 100644 --- a/go.mod +++ b/go.mod @@ -37,6 +37,7 @@ require ( github.com/perimeterx/marshmallow v1.1.5 // indirect github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 // indirect github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible // indirect + github.com/tetratelabs/wazero v1.9.0 github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect github.com/woodsbury/decimal128 v1.3.0 // indirect diff --git a/go.sum b/go.sum index 014f78c..8aeb9f0 100644 --- a/go.sum +++ b/go.sum @@ -110,6 +110,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/supranational/blst v0.3.11 h1:LyU6FolezeWAhvQk0k6O/d49jqgO52MSDDfYgbeoEm4= github.com/supranational/blst v0.3.11/go.mod 
h1:jZJtfjgudtNl4en1tzwPIV3KjUnQUvG3/j+w+fVonLw= +github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I= +github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= diff --git a/internal/execution/dispatcher.go b/internal/execution/dispatcher.go index 300ca3f..f8aec05 100644 --- a/internal/execution/dispatcher.go +++ b/internal/execution/dispatcher.go @@ -2,11 +2,16 @@ package execution import ( "context" + "fmt" + "os" + "sort" + "strings" "go.uber.org/zap" "github.com/lambda-feedback/shimmy/internal/execution/dispatcher" "github.com/lambda-feedback/shimmy/internal/execution/supervisor" + "github.com/lambda-feedback/shimmy/internal/execution/wasm" ) type Dispatcher dispatcher.Dispatcher @@ -32,7 +37,8 @@ type Params struct { } func NewDispatcher(params Params) (dispatcher.Dispatcher, error) { - if params.Config.Supervisor.IO.Interface == supervisor.RpcIO { + switch params.Config.Supervisor.IO.Interface { + case supervisor.RpcIO: return dispatcher.NewDedicatedDispatcher( dispatcher.DedicatedDispatcherParams{ Config: dispatcher.DedicatedDispatcherConfig{ @@ -42,7 +48,30 @@ func NewDispatcher(params Params) (dispatcher.Dispatcher, error) { Log: params.Log, }, ) - } else { + + case supervisor.WasmIO: + wasmProfile := strings.ToLower(strings.TrimSpace(os.Getenv("FUNCTION_WASM_PROFILE"))) + if wasmProfile == "" { + wasmProfile = "generic" + } + if wasmProfile != "generic" { + validProfiles := []string{"generic"} + sort.Strings(validProfiles) + return nil, fmt.Errorf("unsupported FUNCTION_WASM_PROFILE %q; supported values: %s", wasmProfile, strings.Join(validProfiles, ", ")) + } + + cfg := wasm.Config{ + ModulePath: 
params.Config.Supervisor.StartParams.Cmd, + MaxInstances: params.Config.MaxWorkers, + Timeout: params.Config.Supervisor.SendParams.Timeout, + } + d := wasm.NewDispatcher(cfg, params.Log) + if err := d.Start(params.Context); err != nil { + return nil, err + } + return d, nil + + default: return dispatcher.NewPooledDispatcher( dispatcher.PooledDispatcherParams{ Config: dispatcher.PooledDispatcherConfig{ diff --git a/internal/execution/supervisor/config.go b/internal/execution/supervisor/config.go index 520e367..9d7303d 100644 --- a/internal/execution/supervisor/config.go +++ b/internal/execution/supervisor/config.go @@ -24,7 +24,7 @@ type SendConfig struct { // IOInterface describes the interface used to communicate with the worker. type IOConfig struct { // Interface describes the communication between the supervisor - // and the worker. It can be either "rpc" or "file". + // and the worker. It can be "rpc", "file", or "wasm". // // If "rpc", the supervisor will communicate with the worker over // a specified transport. The worker is expected to handle incoming @@ -35,6 +35,9 @@ type IOConfig struct { // containing the message payload and response are passed as args // to the worker process. // + // If "wasm", Shimmy loads a pre-built WASI module from FUNCTION_COMMAND + // or FUNCTION_WASM_MODULE and calls its internal alloc/evaluate adapter ABI. + // // Default is "rpc". Interface IOInterface `conf:"interface"` diff --git a/internal/execution/supervisor/models.go b/internal/execution/supervisor/models.go index e7776db..8f98bcb 100644 --- a/internal/execution/supervisor/models.go +++ b/internal/execution/supervisor/models.go @@ -16,6 +16,9 @@ const ( // FileIO describes communication w/ processes over files FileIO IOInterface = "file" + + // WasmIO describes in-process execution of a pre-built WASI module. 
+ WasmIO IOInterface = "wasm" ) // IOTransport describes the transport mechanism used to communicate with diff --git a/internal/execution/wasm/adapter.go b/internal/execution/wasm/adapter.go new file mode 100644 index 0000000..bc11b04 --- /dev/null +++ b/internal/execution/wasm/adapter.go @@ -0,0 +1,171 @@ +// Package wasm implements a WebAssembly execution backend for shimmy using +// wazero. It exposes a [Dispatcher] that manages a pool of pre-compiled WASM +// module instances and dispatches evaluation requests to them. +// +// # Guest ABI +// +// WASM modules loaded by this backend must export two functions: +// +// alloc(size i32) i32 +// Allocate `size` bytes in guest linear memory and return a pointer to +// the start of the allocation. The host will write the JSON-encoded +// request into this region immediately after the call returns. +// +// evaluate(req_ptr i32, req_len i32) i32 +// Process the JSON request at [req_ptr, req_ptr+req_len). Returns a +// pointer P into guest memory where the response is encoded as: +// bytes [P, P+4) — uint32 little-endian response length L +// bytes [P+4, P+4+L) — L bytes of UTF-8 JSON response +// +// The JSON request envelope has the shape: +// +// {"method": "", "params": {…}} +// +// The JSON response is a plain JSON object (map[string]any) that is returned +// verbatim to the caller. +package wasm + +import ( + "context" + "encoding/binary" + "encoding/json" + "fmt" + "time" + + "github.com/tetratelabs/wazero/api" + "go.uber.org/zap" +) + +// requestEnvelope is the JSON structure written into guest memory for each +// evaluation call. +type requestEnvelope struct { + Method string `json:"method"` + Params map[string]any `json:"params"` +} + +// wasmAdapter performs a single evaluate call against a live wazero api.Module. +// It is stateless and safe to call from one goroutine at a time. 
+type wasmAdapter struct { + mod api.Module + log *zap.Logger + allocFn api.Function // cached exported "alloc" function (M-4 fix) + evalFn api.Function // cached exported "evaluate" function (M-4 fix) +} + +func newWasmAdapter(mod api.Module, log *zap.Logger) *wasmAdapter { + return &wasmAdapter{ + mod: mod, + log: log.Named("adapter_wasm"), + allocFn: mod.ExportedFunction("alloc"), + evalFn: mod.ExportedFunction("evaluate"), + } +} + +// send marshals (method, data) into JSON, writes it into the guest's linear +// memory via alloc, calls evaluate, and reads back the length-prefixed +// response. +func (a *wasmAdapter) send( + ctx context.Context, + method string, + data map[string]any, + timeout time.Duration, +) (map[string]any, error) { + if timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + + // 1. Marshal request envelope. + envelope := requestEnvelope{Method: method, Params: data} + + reqBytes, err := json.Marshal(envelope) + if err != nil { + return nil, fmt.Errorf("wasm: marshal request: %w", err) + } + + reqLen := uint64(len(reqBytes)) + + // 2. Allocate guest memory for the request (cached lookup — M-4 fix). + if a.allocFn == nil { + return nil, fmt.Errorf("wasm: guest module does not export 'alloc'") + } + + allocRes, err := a.allocFn.Call(ctx, reqLen) + if err != nil { + return nil, fmt.Errorf("wasm: alloc(%d): %w", reqLen, err) + } + + reqPtr := allocRes[0] + if reqPtr == 0 { + return nil, fmt.Errorf("wasm: alloc returned NULL (out of memory)") + } + + // 3. Write request bytes into guest memory. + mem := a.mod.Memory() + if mem == nil { + return nil, fmt.Errorf("wasm: guest module has no linear memory") + } + + if !mem.Write(uint32(reqPtr), reqBytes) { + return nil, fmt.Errorf( + "wasm: failed to write %d bytes at ptr=%d (memory size=%d)", + len(reqBytes), reqPtr, mem.Size(), + ) + } + + // 4. Call evaluate (cached lookup — M-4 fix). 
+ if a.evalFn == nil { + return nil, fmt.Errorf("wasm: guest module does not export 'evaluate'") + } + + a.log.Debug("calling evaluate", + zap.String("method", method), + zap.Uint64("req_ptr", reqPtr), + zap.Uint64("req_len", reqLen), + ) + + evalRes, err := a.evalFn.Call(ctx, reqPtr, reqLen) + if err != nil { + return nil, fmt.Errorf("wasm: evaluate: %w", err) + } + + resPtr := uint32(evalRes[0]) + + // 5. Read the 4-byte little-endian length prefix. + lenBytes, ok := mem.Read(resPtr, 4) + if !ok { + return nil, fmt.Errorf("wasm: failed to read response length at ptr=%d", resPtr) + } + + resLen := binary.LittleEndian.Uint32(lenBytes) + + // 6. Read the response JSON body. + // Validate bounds before reading to catch corrupt/malicious response pointers. + if uint64(resPtr)+4+uint64(resLen) > uint64(mem.Size()) { + return nil, fmt.Errorf( + "wasm: response out of bounds: resPtr=%d resLen=%d memSize=%d", + resPtr, resLen, mem.Size(), + ) + } + resBytes, ok := mem.Read(resPtr+4, resLen) + if !ok { + return nil, fmt.Errorf( + "wasm: failed to read %d response bytes at ptr=%d", + resLen, resPtr+4, + ) + } + + a.log.Debug("received response", + zap.Uint32("res_ptr", resPtr), + zap.Uint32("res_len", resLen), + ) + + // 7. Unmarshal response. + var result map[string]any + if err := json.Unmarshal(resBytes, &result); err != nil { + return nil, fmt.Errorf("wasm: unmarshal response: %w", err) + } + + return result, nil +} diff --git a/internal/execution/wasm/config.go b/internal/execution/wasm/config.go new file mode 100644 index 0000000..48e6c0e --- /dev/null +++ b/internal/execution/wasm/config.go @@ -0,0 +1,82 @@ +package wasm + +import ( + "os" + "strconv" + "strings" + "time" +) + +// Config holds the configuration for the opt-in generic WASM execution +// backend. Shimmy consumes an already-built WASI module; source-language +// compilation remains a deployment/build concern outside this package. +type Config struct { + // ModulePath is the path to the .wasm module. 
It is normally populated from + // FUNCTION_COMMAND for compatibility with the rest of Shimmy, or overridden + // by FUNCTION_WASM_MODULE when set. + ModulePath string + + // MaxInstances is the number of warm module instances in the pool. + MaxInstances int + + // Timeout is the per-request deadline passed to the guest evaluate call. + Timeout time.Duration + + // MaxMemoryPages limits WASM linear memory (1 page = 64 KiB). The default is + // intentionally small and can be raised by FUNCTION_WASM_MAX_MEMORY_PAGES. + MaxMemoryPages uint32 + + // AllowedPaths is a comma-separated allowlist of read-only host directories + // exposed to WASI. Empty means no filesystem access. + AllowedPaths []string + + // AllowedEnv is a comma-separated allowlist of host environment variable + // names exposed to WASI. Empty means no environment variables. + AllowedEnv []string + + // CompileCacheDir enables wazero's on-disk compilation cache when set via + // FUNCTION_WASM_COMPILE_CACHE. + CompileCacheDir string +} + +func (c *Config) applyDefaults() { + if c.Timeout == 0 { + c.Timeout = 30 * time.Second + } + if c.MaxMemoryPages == 0 { + c.MaxMemoryPages = 256 // 16 MiB + } +} + +// applyEnv reads FUNCTION_WASM_* overrides. These settings are intentionally +// limited to generic WASM runtime concerns; Python/reactor/package bundling is +// out of scope for this backend. 
+func (c *Config) applyEnv() { + if v := os.Getenv("FUNCTION_WASM_MODULE"); v != "" { + c.ModulePath = v + } + if v := os.Getenv("FUNCTION_WASM_MAX_MEMORY_PAGES"); v != "" { + if n, err := strconv.ParseUint(v, 10, 32); err == nil { + c.MaxMemoryPages = uint32(n) + } + } + if v := os.Getenv("FUNCTION_WASM_ALLOWED_PATHS"); v != "" { + c.AllowedPaths = splitNonEmpty(v, ",") + } + if v := os.Getenv("FUNCTION_WASM_ALLOWED_ENV"); v != "" { + c.AllowedEnv = splitNonEmpty(v, ",") + } + if v := os.Getenv("FUNCTION_WASM_COMPILE_CACHE"); v != "" { + c.CompileCacheDir = v + } +} + +func splitNonEmpty(s, sep string) []string { + var out []string + for _, p := range strings.Split(s, sep) { + if t := strings.TrimSpace(p); t != "" { + out = append(out, t) + } + } + return out +} diff --git a/internal/execution/wasm/dispatcher.go b/internal/execution/wasm/dispatcher.go new file mode 100644 index 0000000..85696d0 --- /dev/null +++ b/internal/execution/wasm/dispatcher.go @@ -0,0 +1,376 @@ +package wasm + +import ( + "context" + "fmt" + "os" + "runtime" + "sync" + "time" + + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" + "go.uber.org/zap" + + "github.com/lambda-feedback/shimmy/internal/execution/dispatcher" +) + +// ErrDispatcherClosed is returned by Send after the dispatcher has begun (or +// completed) Shutdown. Callers should treat it as a terminal error. +var ErrDispatcherClosed = fmt.Errorf("wasm: dispatcher is shut down") + +// Dispatcher implements [dispatcher.Dispatcher] for the WASM execution +// backend. It compiles the .wasm module once at startup, then maintains a pool +// of pre-initialised [wasmSupervisor] instances (one compiled module, N module +// instances). Requests are dispatched by acquiring a supervisor from the pool, +// calling its Send, and returning it to the pool. 
+type Dispatcher struct { + cfg Config + rt wazero.Runtime + compiled wazero.CompiledModule + modCfg wazero.ModuleConfig + pool chan *wasmSupervisor + log *zap.Logger + + // mu protects closed and serialises the closed/push transitions so that a + // replacement supervisor cannot land in the pool after Shutdown has begun + // draining it. + mu sync.Mutex + closed bool + // closedCh is closed atomically with closed=true (under mu) by Shutdown. + // Send selects on it to (a) unblock a pool acquire that is racing Shutdown + // and (b) avoid waiting on an empty pool that Shutdown is about to drain. + closedCh chan struct{} + // pending tracks BOTH in-flight Sends (Add in tryBeginSend, Done via Send's + // defer) AND background goroutines spawned during a Send (replacement + // spawns, discard-shutdowns). Shutdown waits on it before draining the + // pool / closing the runtime. + // + // Invariant: every pending.Add is either (a) made under d.mu after + // observing !closed, or (b) made by code that is itself holding a pending + // count (e.g. discardAsync called from inside Send). This keeps Add from + // racing Shutdown's Wait — if closed is already set, branch (a) skips the + // Add and falls back to a synchronous close; in branch (b) Shutdown is + // guaranteed to still be blocked at Wait on the caller's count. + pending sync.WaitGroup +} + +var _ dispatcher.Dispatcher = (*Dispatcher)(nil) + +// NewDispatcher creates a new WASM dispatcher. Compilation and pool +// initialisation happen in Start. +func NewDispatcher(cfg Config, log *zap.Logger) *Dispatcher { + return &Dispatcher{ + cfg: cfg, + log: log.Named("dispatcher_wasm"), + closedCh: make(chan struct{}), + } +} + +// tryBeginSend atomically checks the closed flag and increments pending. It +// returns false if Shutdown has begun (caller must abort with +// ErrDispatcherClosed); on true the caller MUST call pending.Done exactly +// once when finished. 
Holding a pending count across the entire Send keeps +// Shutdown's Wait blocked while the Send is mid-flight, which is what lets +// discardAsync inside Send safely Add to pending without racing Wait. +func (d *Dispatcher) tryBeginSend() bool { + d.mu.Lock() + defer d.mu.Unlock() + if d.closed { + return false + } + d.pending.Add(1) + return true +} + +// Start reads and compiles the .wasm file, sets up WASI host functions, and +// pre-warms the supervisor pool. +func (d *Dispatcher) Start(ctx context.Context) error { + // Pick up sandbox overrides from FUNCTION_WASM_* env vars (including + // FUNCTION_WASM_MODULE as an alternative to FUNCTION_COMMAND), then apply + // sensible defaults for any fields still at their zero values. + d.cfg.applyEnv() + d.cfg.applyDefaults() + + if d.cfg.ModulePath == "" { + return fmt.Errorf("wasm: ModulePath must be set (FUNCTION_COMMAND or FUNCTION_WASM_MODULE)") + } + + maxInstances := d.cfg.MaxInstances + if maxInstances <= 0 { + maxInstances = runtime.NumCPU() + } + + d.log.Info("starting wasm dispatcher", + zap.String("module", d.cfg.ModulePath), + zap.Int("max_instances", maxInstances), + zap.Uint32("max_memory_pages", d.cfg.MaxMemoryPages), + zap.Duration("timeout", d.cfg.Timeout), + ) + + // Read the .wasm bytes from disk. + wasmBytes, err := os.ReadFile(d.cfg.ModulePath) + if err != nil { + return fmt.Errorf("wasm: read module file %q: %w", d.cfg.ModulePath, err) + } + + // Build the runtime config with memory limit and context-done interruption. + rtCfg := wazero.NewRuntimeConfig(). + // WithCloseOnContextDone causes wazero to interrupt a running WASM module + // when the call context is cancelled or times out, preventing goroutine leaks. + WithCloseOnContextDone(true) + if d.cfg.MaxMemoryPages > 0 { + rtCfg = rtCfg.WithMemoryLimitPages(d.cfg.MaxMemoryPages) + } + + // Wire in on-disk compilation cache when configured. 
+ if d.cfg.CompileCacheDir != "" { + cache, err := wazero.NewCompilationCacheWithDir(d.cfg.CompileCacheDir) + if err != nil { + d.log.Warn("failed to create wazero compilation cache, continuing without cache", + zap.String("dir", d.cfg.CompileCacheDir), + zap.Error(err)) + } else { + rtCfg = rtCfg.WithCompilationCache(cache) + d.log.Info("wazero compilation cache enabled", zap.String("dir", d.cfg.CompileCacheDir)) + } + } + + // Create a single wazero runtime shared by all instances. + rt := wazero.NewRuntimeWithConfig(ctx, rtCfg) + d.rt = rt + + // Instantiate WASI host functions. Most evaluation functions will need at + // least minimal WASI support (e.g. for memory allocation helpers compiled + // from C/Rust/TinyGo). + if _, err := wasi_snapshot_preview1.Instantiate(ctx, rt); err != nil { + _ = rt.Close(ctx) + return fmt.Errorf("wasm: instantiate wasi: %w", err) + } + + // Compile the module once; all instances share the compiled code. + compiled, err := rt.CompileModule(ctx, wasmBytes) + if err != nil { + _ = rt.Close(ctx) + return fmt.Errorf("wasm: compile module: %w", err) + } + d.compiled = compiled + + // Build a locked-down ModuleConfig: no filesystem, no env vars, no + // stdin/stdout/stderr, no args. Only allow nanosleep and wall/mono clocks + // which the Go runtime needs. + modCfg := wazero.NewModuleConfig(). + WithName(""). + WithSysNanosleep(). + WithSysWalltime(). + WithSysNanotime() + + // Filesystem: mount allowed paths read-only; no access by default. + fsCfg := wazero.NewFSConfig() + for _, p := range d.cfg.AllowedPaths { + fsCfg = fsCfg.WithReadOnlyDirMount(p, p) + } + modCfg = modCfg.WithFSConfig(fsCfg) + + // Env vars: expose only explicitly whitelisted variables. + for _, key := range d.cfg.AllowedEnv { + if val, ok := os.LookupEnv(key); ok { + modCfg = modCfg.WithEnv(key, val) + } + } + d.modCfg = modCfg + + // Build the pool. 
+ d.pool = make(chan *wasmSupervisor, maxInstances) + + for i := 0; i < maxInstances; i++ { + sv := newWasmSupervisor(rt, compiled, modCfg, d.cfg.Timeout, d.log) + + if err := sv.Start(ctx); err != nil { + // Clean up already-started supervisors. + _ = drainBufferedPool(ctx, d.pool, d.log) + _ = rt.Close(ctx) + return fmt.Errorf("wasm: start instance %d: %w", i, err) + } + + d.pool <- sv + } + + d.log.Info("wasm dispatcher ready", zap.Int("instances", maxInstances)) + + return nil +} + +// Send acquires a supervisor from the pool, dispatches the request, and +// returns the supervisor to the pool. +func (d *Dispatcher) Send( + ctx context.Context, + method string, + data map[string]any, +) (map[string]any, error) { + if !d.tryBeginSend() { + return nil, ErrDispatcherClosed + } + defer d.pending.Done() + + // Acquire a supervisor, honouring the caller's context AND the shutdown + // signal so we never block forever on a drained pool. + var sv *wasmSupervisor + select { + case sv = <-d.pool: + case <-d.closedCh: + return nil, ErrDispatcherClosed + case <-ctx.Done(): + return nil, fmt.Errorf("wasm: acquire instance: %w", ctx.Err()) + } + + result, err := sv.Send(ctx, method, data) + + // Return the supervisor to the pool only if it is healthy. + // If the snapshot restore failed inside Send, sv.healthy is false and the + // supervisor's state is undefined — discard it and spawn a replacement so + // pool capacity is eventually restored. + if sv.IsHealthy() { + d.returnOrDiscard(sv) + } else { + d.log.Warn("wasm supervisor unhealthy after request — dropping from pool, spawning replacement") + d.discardAsync(sv) + d.spawnReplacementAsync() + } + + if err != nil { + return nil, fmt.Errorf("wasm: send: %w", err) + } + + return result, nil +} + +// returnOrDiscard puts a healthy supervisor back in the pool unless Shutdown +// has begun, in which case the supervisor is closed asynchronously so it does +// not leak past a drained pool. 
+// +// Must be called from a goroutine that already holds a pending count (i.e. +// from inside Send) so that the Add issued by discardAsync is guaranteed to +// happen before Shutdown's pending.Wait can return. +func (d *Dispatcher) returnOrDiscard(sv *wasmSupervisor) { + d.mu.Lock() + if d.closed { + d.mu.Unlock() + d.discardAsync(sv) + return + } + // Push under the lock so it interleaves correctly with Shutdown's + // closed=true → drainPool sequence: either we push before closed is set + // (drainPool sees the supervisor) or we discard via the branch above. + d.pool <- sv + d.mu.Unlock() +} + +// discardAsync closes a discarded supervisor in the background and tracks it +// via the pending WaitGroup so Shutdown can wait for the close to complete +// before tearing down the runtime. +// +// Must be called from a goroutine that already holds a pending count +// (Send, via tryBeginSend). That invariant keeps Shutdown.pending.Wait +// blocked across this Add, eliminating the Add-after-Wait race. +func (d *Dispatcher) discardAsync(sv *wasmSupervisor) { + d.pending.Add(1) + go func() { + defer d.pending.Done() + _ = sv.Shutdown(context.Background()) + }() +} + +// spawnReplacementAsync kicks off spawnOne in a background goroutine, but only +// if the dispatcher is still open. If Shutdown has begun, no replacement is +// scheduled. Tracked via the pending WaitGroup. +func (d *Dispatcher) spawnReplacementAsync() { + d.mu.Lock() + if d.closed { + d.mu.Unlock() + return + } + d.pending.Add(1) + d.mu.Unlock() + go d.spawnOne() +} + +// spawnOne initialises a fresh wasmSupervisor and adds it to the pool. +// Called in a goroutine when an unhealthy supervisor is discarded so that +// pool capacity is eventually restored. Failures are logged but not fatal. +// +// If Shutdown begins while Start is running, the freshly initialised +// supervisor is closed immediately rather than inserted into the drained pool. 
+func (d *Dispatcher) spawnOne() { + defer d.pending.Done() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + d.log.Info("wasm: initialising replacement supervisor") + sv := newWasmSupervisor(d.rt, d.compiled, d.modCfg, d.cfg.Timeout, d.log) + if err := sv.Start(ctx); err != nil { + d.log.Error("wasm: replacement supervisor init failed", zap.Error(err)) + return + } + + d.mu.Lock() + if d.closed { + d.mu.Unlock() + d.log.Info("wasm: replacement supervisor born during shutdown — closing immediately") + _ = sv.Shutdown(context.Background()) + return + } + d.pool <- sv + d.mu.Unlock() + d.log.Info("wasm: replacement supervisor ready") +} + +// Shutdown closes all module instances and the wazero runtime. Idempotent. +func (d *Dispatcher) Shutdown(ctx context.Context) error { + d.mu.Lock() + if d.closed { + d.mu.Unlock() + return nil + } + d.closed = true + // Close the channel under mu so the closed=true / close(closedCh) pair is + // atomic with respect to tryBeginSend: any Send that observes !closed has + // also pending.Add'd before Shutdown can reach pending.Wait. + close(d.closedCh) + d.mu.Unlock() + + d.log.Debug("shutting down wasm dispatcher") + + // Wait for in-flight Sends AND any background goroutines (replacement + // spawns / discard shutdowns) to finish so that no late-created supervisor + // lands in the pool after the drain below, no module is mid-Close while we + // close the runtime, and no Send is running against the wazero runtime + // when we tear it down. + d.pending.Wait() + + // Non-blocking drain: after pending.Wait, no spawn or returnOrDiscard + // goroutine will push to the pool, so we just close everything currently + // buffered. (drainPool's blocking-for-cap-items semantics would deadlock + // here when spawnOne took the closed-shortcut and never pushed.) 
+ for { + select { + case sv := <-d.pool: + if err := sv.Shutdown(ctx); err != nil { + d.log.Warn("error shutting down pooled supervisor", zap.Error(err)) + } + default: + goto drained + } + } +drained: + + if d.rt != nil { + if err := d.rt.Close(ctx); err != nil { + return fmt.Errorf("wasm: close runtime: %w", err) + } + d.rt = nil + } + + return nil +} diff --git a/internal/execution/wasm/dispatcher_test.go b/internal/execution/wasm/dispatcher_test.go new file mode 100644 index 0000000..fb71f0a --- /dev/null +++ b/internal/execution/wasm/dispatcher_test.go @@ -0,0 +1,515 @@ +package wasm + +import ( + "context" + "errors" + "os" + "path/filepath" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tetratelabs/wazero" + "go.uber.org/zap" +) + +// echoModulePath returns the absolute path to the pre-compiled echo.wasm test +// fixture. The fixture is a minimal guest module that always returns +// {"ok":true} regardless of the request, which lets us test the host-side Go +// code (alloc call, memory write, evaluate call, length-prefix parsing, JSON +// unmarshal) without implementing a full language runtime in WAT. +func echoModulePath(t *testing.T) string { + t.Helper() + // __file__ is not available in Go, but runtime.Caller gives us the source + // file path so we can derive testdata/ relative to the test file. + _, filename, _, ok := runtime.Caller(0) + require.True(t, ok, "runtime.Caller failed") + return filepath.Join(filepath.Dir(filename), "testdata", "echo.wasm") +} + +// newTestLogger returns a development zap logger (not a no-op) for unit tests. +func newTestLogger(t *testing.T) *zap.Logger { + t.Helper() + log, err := zap.NewDevelopment() + require.NoError(t, err) + return log +} + +// newEchoDispatcher creates a Dispatcher backed by the echo fixture and starts +// it. The caller is responsible for calling Shutdown.
+func newEchoDispatcher(t *testing.T, maxInstances int) *Dispatcher { + t.Helper() + cfg := Config{ + ModulePath: echoModulePath(t), + MaxInstances: maxInstances, + Timeout: 5 * time.Second, + } + d := NewDispatcher(cfg, newTestLogger(t)) + require.NoError(t, d.Start(context.Background()), "dispatcher start") + return d +} + +// TestDispatcher_StartStop verifies that a Dispatcher can be started and shut +// down cleanly without any interaction in between. +func TestDispatcher_StartStop(t *testing.T) { + d := newEchoDispatcher(t, 1) + err := d.Shutdown(context.Background()) + assert.NoError(t, err) +} + +// TestDispatcher_StartStop_MultipleInstances verifies start/stop with the +// default pool size (NumCPU). +func TestDispatcher_StartStop_MultipleInstances(t *testing.T) { + d := newEchoDispatcher(t, runtime.NumCPU()) + err := d.Shutdown(context.Background()) + assert.NoError(t, err) +} + +// TestDispatcher_Send_BasicResponse sends a single request and checks that the +// echo module returns {"ok":true}. +func TestDispatcher_Send_BasicResponse(t *testing.T) { + d := newEchoDispatcher(t, 1) + t.Cleanup(func() { _ = d.Shutdown(context.Background()) }) + + result, err := d.Send(context.Background(), "test", map[string]any{"hello": "world"}) + require.NoError(t, err) + require.NotNil(t, result) + + ok, exists := result["ok"] + assert.True(t, exists, "response should contain 'ok' key") + assert.Equal(t, true, ok, "response 'ok' should be true") +} + +// TestDispatcher_Send_EmptyParams verifies that Send works with nil params. +func TestDispatcher_Send_EmptyParams(t *testing.T) { + d := newEchoDispatcher(t, 1) + t.Cleanup(func() { _ = d.Shutdown(context.Background()) }) + + result, err := d.Send(context.Background(), "noop", nil) + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, true, result["ok"]) +} + +// TestDispatcher_Send_Concurrent sends 20 requests, at most 10 in flight at +// once, through a pool of 3 instances and verifies that all succeed.
+func TestDispatcher_Send_Concurrent(t *testing.T) { + const ( + numWorkers = 10 + numRequests = 20 + poolSize = 3 + ) + + d := newEchoDispatcher(t, poolSize) + t.Cleanup(func() { _ = d.Shutdown(context.Background()) }) + + type result struct { + res map[string]any + err error + } + + results := make([]result, numRequests) + var wg sync.WaitGroup + wg.Add(numRequests) + + sem := make(chan struct{}, numWorkers) + for i := range numRequests { + sem <- struct{}{} + go func(i int) { + defer wg.Done() + defer func() { <-sem }() + res, err := d.Send(context.Background(), "eval", map[string]any{"i": i}) + results[i] = result{res, err} + }(i) + } + + wg.Wait() + + for i, r := range results { + require.NoError(t, r.err, "request %d failed", i) + require.NotNil(t, r.res, "request %d returned nil result", i) + assert.Equal(t, true, r.res["ok"], "request %d: unexpected result", i) + } +} + +// TestDispatcher_Send_AfterShutdown checks that Send after Shutdown returns +// ErrDispatcherClosed immediately, rather than blocking on the drained pool +// until the caller's context expires. +func TestDispatcher_Send_AfterShutdown(t *testing.T) { + d := newEchoDispatcher(t, 1) + require.NoError(t, d.Shutdown(context.Background())) + + _, err := d.Send(context.Background(), "test", nil) + assert.ErrorIs(t, err, ErrDispatcherClosed, "Send after Shutdown must return ErrDispatcherClosed") +} + +// TestDispatcher_Shutdown_Idempotent verifies that calling Shutdown twice does +// not return an error or double-close the runtime. +func TestDispatcher_Shutdown_Idempotent(t *testing.T) { + d := newEchoDispatcher(t, 1) + require.NoError(t, d.Shutdown(context.Background())) + require.NoError(t, d.Shutdown(context.Background()), "second Shutdown must be a no-op") +} + +// TestDispatcher_ReplacementDuringShutdown exercises the race where Send has +// just discarded an unhealthy supervisor and scheduled a replacement spawn +// while Shutdown begins. 
The replacement spawn must NOT insert a supervisor +// into a drained pool, and Shutdown must wait for the spawn goroutine to +// finish before closing the runtime (otherwise the late supervisor would +// reference a torn-down wazero.Runtime). +func TestDispatcher_ReplacementDuringShutdown(t *testing.T) { + d := newEchoDispatcher(t, 1) + + // Consume the only supervisor in the pool to mimic an in-flight Send. + sv := <-d.pool + + // Simulate Send's unhealthy-path bookkeeping: schedule the discard close + // of the bad supervisor and the spawn of a replacement. + d.discardAsync(sv) + d.spawnReplacementAsync() + + // Shutdown races with the spawn. It must wait for pending background work + // (via d.pending.Wait) before draining the pool and closing the runtime. + require.NoError(t, d.Shutdown(context.Background())) + + // After Shutdown the pool must be empty: any replacement that finished + // initialising during the race window was closed by spawnOne's + // closed-guard rather than inserted. + assert.Equal(t, 0, len(d.pool), "drained pool must be empty after Shutdown") + + // Send after Shutdown returns ErrDispatcherClosed promptly. + _, err := d.Send(context.Background(), "test", nil) + assert.ErrorIs(t, err, ErrDispatcherClosed) +} + +// TestDispatcher_Shutdown_WaitsForInFlightSends drives the original race the +// lifecycle patch is meant to fix: many concurrent Sends are issued while +// Shutdown runs partway through. Without the in-flight tracking, Shutdown +// could close the wazero runtime out from under a live Send (use-after-close), +// or returnOrDiscard/discardAsync could call pending.Add after Shutdown's +// pending.Wait already returned. Both surfaces are caught by -race or by an +// outright panic. +// +// Acceptance: every Send either succeeds or returns ErrDispatcherClosed, never +// any other error; Shutdown returns nil; no panic. 
+func TestDispatcher_Shutdown_WaitsForInFlightSends(t *testing.T) { + d := newEchoDispatcher(t, runtime.NumCPU()) + + const numWorkers = 128 + var ( + wg sync.WaitGroup + successes atomic.Int64 + closedExits atomic.Int64 + unexpected atomic.Int64 + ) + wg.Add(numWorkers) + + start := make(chan struct{}) + for i := 0; i < numWorkers; i++ { + go func() { + defer wg.Done() + <-start + for j := 0; j < 5; j++ { + _, err := d.Send(context.Background(), "eval", map[string]any{"j": j}) + switch { + case err == nil: + successes.Add(1) + case errors.Is(err, ErrDispatcherClosed): + closedExits.Add(1) + return // dispatcher is gone; stop hammering + default: + unexpected.Add(1) + t.Errorf("unexpected error: %v", err) + return + } + } + }() + } + + close(start) + // Give some Sends a chance to begin. + time.Sleep(5 * time.Millisecond) + + require.NoError(t, d.Shutdown(context.Background())) + wg.Wait() + + assert.Zero(t, unexpected.Load(), "no Send may return a non-closed error") + // Post-shutdown Send must return ErrDispatcherClosed promptly (not block). + postCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, err := d.Send(postCtx, "eval", nil) + assert.ErrorIs(t, err, ErrDispatcherClosed) + t.Logf("successes=%d closed_exits=%d", successes.Load(), closedExits.Load()) +} + +// TestDispatcher_Shutdown_UnblocksBlockedSend covers the second race called +// out in the patch: a Send that passed tryBeginSend but finds the pool empty +// (all supervisors are in-use or have been drained by a racing Shutdown). +// Without selecting on closedCh, the Send would block on the empty pool until +// the caller's context expired. With the patch it must return +// ErrDispatcherClosed as soon as Shutdown begins. +func TestDispatcher_Shutdown_UnblocksBlockedSend(t *testing.T) { + d := newEchoDispatcher(t, 1) + // Empty the pool so a Send is forced to block on acquire. 
+ sv := <-d.pool + + type sendResult struct { + err error + } + res := make(chan sendResult, 1) + go func() { + _, err := d.Send(context.Background(), "eval", nil) + res <- sendResult{err: err} + }() + + // Let Send reach the empty-pool select. + time.Sleep(50 * time.Millisecond) + + // Put sv back so the dispatcher's drain has something to clean up + // (otherwise Shutdown sees an empty pool, which is also fine). + d.pool <- sv + + require.NoError(t, d.Shutdown(context.Background())) + + select { + case r := <-res: + // Either the Send got the supervisor before Shutdown drained it + // (succeeded), or Shutdown's closedCh fired first. + if r.err != nil { + assert.ErrorIs(t, r.err, ErrDispatcherClosed) + } + case <-time.After(2 * time.Second): + t.Fatal("Send did not return after Shutdown — closedCh select missing") + } +} + +// TestDispatcher_SpawnReplacementAsync_NoopAfterShutdown asserts that calling +// spawnReplacementAsync on a closed dispatcher is a no-op: it must not +// increment pending and must not launch a goroutine that touches the closed +// runtime. +func TestDispatcher_SpawnReplacementAsync_NoopAfterShutdown(t *testing.T) { + d := newEchoDispatcher(t, 1) + require.NoError(t, d.Shutdown(context.Background())) + + // Should return immediately without scheduling work. + d.spawnReplacementAsync() + + // Wait briefly with a deadline — pending.Wait would block forever if the + // no-op guard regressed and a goroutine were leaked with a stale runtime. + done := make(chan struct{}) + go func() { + d.pending.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("pending.Wait did not return — spawn goroutine leaked after Shutdown") + } +} + +// TestDispatcher_MissingModule checks that Start fails when ModulePath does +// not point to a valid file. 
+func TestDispatcher_MissingModule(t *testing.T) { + cfg := Config{ + ModulePath: "/nonexistent/path/module.wasm", + MaxInstances: 1, + } + d := NewDispatcher(cfg, newTestLogger(t)) + err := d.Start(context.Background()) + assert.Error(t, err, "Start with missing module should fail") +} + +// TestSupervisor_MemoryRestored sends two sequential requests through the same +// supervisor and verifies that both succeed with the same response. This +// exercises the snapshot/restore cycle: after the first evaluate the bump +// allocator's heap_top is advanced, but restoreSnapshot rewinds memory so the +// second call starts from the exact same state. +func TestSupervisor_MemoryRestored(t *testing.T) { + // Use a pool of exactly 1 so both sends use the same supervisor instance. + d := newEchoDispatcher(t, 1) + t.Cleanup(func() { _ = d.Shutdown(context.Background()) }) + + ctx := context.Background() + + r1, err := d.Send(ctx, "first", map[string]any{"seq": 1}) + require.NoError(t, err) + require.NotNil(t, r1) + + r2, err := d.Send(ctx, "second", map[string]any{"seq": 2}) + require.NoError(t, err) + require.NotNil(t, r2) + + // Both responses must be identical {"ok":true}. + assert.Equal(t, r1, r2, "responses must be equal, proving memory was restored between calls") + assert.Equal(t, true, r1["ok"]) + assert.Equal(t, true, r2["ok"]) +} + +// TestSupervisor_MemoryRestored_ManyTimes exercises many sequential calls +// through a single-instance pool to ensure the snapshot/restore cycle is +// stable over repeated invocations. 
+func TestSupervisor_MemoryRestored_ManyTimes(t *testing.T) { + d := newEchoDispatcher(t, 1) + t.Cleanup(func() { _ = d.Shutdown(context.Background()) }) + + ctx := context.Background() + const iters = 50 + + for i := range iters { + res, err := d.Send(ctx, "loop", map[string]any{"i": i}) + require.NoError(t, err, "iteration %d", i) + assert.Equal(t, true, res["ok"], "iteration %d", i) + } +} + +// buildMissingImportModule constructs a valid WASM module that imports a host +// function Shimmy does not provide. Compilation succeeds, but instantiation +// fails inside wasmSupervisor.Start. +func buildMissingImportModule() []byte { + section := func(id byte, payload []byte) []byte { + out := []byte{id} + out = append(out, leb128Encode(uint32(len(payload)))...) + out = append(out, payload...) + return out + } + name := func(s string) []byte { + out := leb128Encode(uint32(len(s))) + out = append(out, []byte(s)...) + return out + } + + module := []byte{0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00} + // Type section: one function type () -> (). + module = append(module, section(1, []byte{0x01, 0x60, 0x00, 0x00})...) + // Import section: one function import env.missing with type index 0. + importPayload := []byte{0x01} + importPayload = append(importPayload, name("env")...) + importPayload = append(importPayload, name("missing")...) + importPayload = append(importPayload, 0x00, 0x00) // kind=func, typeidx=0 + module = append(module, section(2, importPayload)...) + return module +} + +// TestDispatcher_StartFailure_DoesNotBlock verifies that a startup failure while +// initialising the warm instance pool returns an error instead of blocking while +// trying to drain a not-yet-full pool. 
+func TestDispatcher_StartFailure_DoesNotBlock(t *testing.T) { + modulePath := filepath.Join(t.TempDir(), "missing-import.wasm") + require.NoError(t, os.WriteFile(modulePath, buildMissingImportModule(), 0o644)) + + d := NewDispatcher(Config{ + ModulePath: modulePath, + MaxInstances: 2, + Timeout: 5 * time.Second, + }, newTestLogger(t)) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := d.Start(ctx) + require.Error(t, err) + assert.Contains(t, err.Error(), "start instance") +} + +// TestSupervisor_Start_Idempotent verifies that calling Start twice on the +// same supervisor does not error (the second call is a no-op). +func TestSupervisor_Start_Idempotent(t *testing.T) { + ctx := context.Background() + log := newTestLogger(t) + + wasmBytes := echoWasmBytes(t) + + rt, compiled := compileEchoModule(t, ctx, wasmBytes) + t.Cleanup(func() { _ = rt.Close(ctx) }) + + sv := newWasmSupervisor(rt, compiled, wazero.NewModuleConfig().WithName(""), 5*time.Second, log) + require.NoError(t, sv.Start(ctx)) + require.NoError(t, sv.Start(ctx), "second Start must be a no-op") + require.NoError(t, sv.Shutdown(ctx)) +} + +// TestSupervisor_Send_NotStarted checks that Send before Start returns an +// error. +func TestSupervisor_Send_NotStarted(t *testing.T) { + ctx := context.Background() + log := newTestLogger(t) + + wasmBytes := echoWasmBytes(t) + rt, compiled := compileEchoModule(t, ctx, wasmBytes) + t.Cleanup(func() { _ = rt.Close(ctx) }) + + sv := newWasmSupervisor(rt, compiled, wazero.NewModuleConfig().WithName(""), 5*time.Second, log) + // Do NOT call sv.Start. 
+ + _, err := sv.Send(ctx, "test", nil) + assert.Error(t, err, "Send without Start should return an error") +} + +// TestSupervisor_Send_MemoryGrowDetected is the regression test for +// memory.grow snapshot isolation: if the guest expands linear memory during a +// request, the supervisor must (a) detect the growth, (b) zero the grown tail +// so the next request cannot read leaked guest data, (c) surface +// ErrMemoryGrew, and (d) mark itself unhealthy so the dispatcher discards it +// instead of returning it to the pool. +// +// The echo fixture itself never grows memory, so we simulate a request that +// did by growing the module's memory from host code (between Start and Send) +// and writing a recognisable poison pattern into the new pages. After Send +// runs, restoreSnapshot observes mem.Size() > snapshotSize and must trip the +// defensive path. +func TestSupervisor_Send_MemoryGrowDetected(t *testing.T) { + ctx := context.Background() + log := newTestLogger(t) + + wasmBytes := echoWasmBytes(t) + rt, compiled := compileEchoModule(t, ctx, wasmBytes) + t.Cleanup(func() { _ = rt.Close(ctx) }) + + sv := newWasmSupervisor(rt, compiled, wazero.NewModuleConfig().WithName(""), 5*time.Second, log) + require.NoError(t, sv.Start(ctx)) + t.Cleanup(func() { _ = sv.Shutdown(ctx) }) + + require.True(t, sv.IsHealthy(), "supervisor should be healthy after Start") + + // Capture the snapshot size, then grow memory by 1 page (64 KiB) and + // poison the new pages. This simulates a guest that called memory.grow + // during execution and wrote sensitive data into the new pages. 
+ mem := sv.mod.Memory() + require.NotNil(t, mem) + origSize := mem.Size() + require.Equal(t, origSize, sv.snapshotSize, "snapshotSize must be recorded at Take time") + + prevPages, ok := mem.Grow(1) + require.True(t, ok, "memory.Grow must succeed (echo fixture has no max)") + require.Equal(t, origSize/(64*1024), prevPages) + + grownSize := mem.Size() + require.Greater(t, grownSize, origSize, "memory must have grown") + + poison := make([]byte, grownSize-origSize) + for i := range poison { + poison[i] = 0xAB + } + require.True(t, mem.Write(origSize, poison), "poison tail") + + // Issue a request. The echo guest doesn't itself grow memory, but Send's + // post-call restoreSnapshot will observe the host-injected growth and + // trip the defensive path. + _, err := sv.Send(ctx, "test", map[string]any{"hello": "world"}) + require.Error(t, err, "Send must return the restore error") + assert.ErrorIs(t, err, ErrMemoryGrew, "error must wrap ErrMemoryGrew") + + assert.False(t, sv.IsHealthy(), "supervisor must be marked unhealthy after grow detected") + + // The grown tail must have been zeroed so no leftover guest data remains + // in the (now-unhealthy but still-instantiated) module. + tail, readOK := mem.Read(origSize, grownSize-origSize) + require.True(t, readOK) + expected := make([]byte, grownSize-origSize) + assert.Equal(t, expected, []byte(tail), "tail must be zero-filled, not contain poison bytes") +} diff --git a/internal/execution/wasm/pool.go b/internal/execution/wasm/pool.go new file mode 100644 index 0000000..1237b52 --- /dev/null +++ b/internal/execution/wasm/pool.go @@ -0,0 +1,62 @@ +package wasm + +import ( + "context" + + "go.uber.org/zap" +) + +// poolItem is the interface satisfied by any item that can be shut down when +// draining a pool. +type poolItem interface { + Shutdown(ctx context.Context) error +} + +// drainPool receives up to cap(pool) items from the channel and calls Shutdown +// on each. 
This helper is only used when the caller knows the pool is full. +func drainPool[T poolItem](ctx context.Context, pool chan T, log *zap.Logger) error { + if pool == nil { + return nil + } + + var firstErr error + for i := 0; i < cap(pool); i++ { + select { + case item := <-pool: + if err := item.Shutdown(ctx); err != nil { + log.Error("error shutting down pool item", zap.Error(err)) + if firstErr == nil { + firstErr = err + } + } + case <-ctx.Done(): + log.Warn("drainPool: context cancelled, some items may not be shut down", + zap.Int("remaining", cap(pool)-i)) + return ctx.Err() + } + } + return firstErr +} + +// drainBufferedPool shuts down only items currently buffered in the channel. It +// is safe for startup-failure paths where the pool may be only partially filled. +func drainBufferedPool[T poolItem](ctx context.Context, pool chan T, log *zap.Logger) error { + if pool == nil { + return nil + } + + var firstErr error + for { + select { + case item := <-pool: + if err := item.Shutdown(ctx); err != nil { + log.Error("error shutting down pool item", zap.Error(err)) + if firstErr == nil { + firstErr = err + } + } + default: + return firstErr + } + } +} diff --git a/internal/execution/wasm/snapshot.go b/internal/execution/wasm/snapshot.go new file mode 100644 index 0000000..b575091 --- /dev/null +++ b/internal/execution/wasm/snapshot.go @@ -0,0 +1,96 @@ +package wasm + +import ( + "fmt" + + "github.com/tetratelabs/wazero/api" +) + +// SnapshotStrategy abstracts how linear-memory snapshots are taken and +// restored. The default implementation (FullMemcpyStrategy) copies the entire +// memory region on every restore. Future strategies may track dirty pages to +// reduce restore cost for large modules. +// +// Contract (I-4 fix — document ordering and concurrency expectations): +// - Take must be called at least once before Restore. +// - Take may be called multiple times; each call overwrites the previous +// snapshot. 
+// - Calling Restore without a prior Take is a no-op (returns nil) but +// logically meaningless. +// - Implementations are NOT safe for concurrent calls to Take / Restore. +// The caller (wasmSupervisor) must serialise access. +type SnapshotStrategy interface { + // Take captures the current state of the WASM linear memory. + // It is called once after module initialisation. + Take(mem api.Memory) error + + // Restore writes the captured snapshot back into WASM linear memory. + // It is called after every request so the next request sees a clean state. + Restore(mem api.Memory) error + + // Close releases any resources held by the strategy. It is safe to call on a + // zero-value or never-initialised strategy. + Close() error +} + +// --------------------------------------------------------------------------- +// FullMemcpyStrategy +// --------------------------------------------------------------------------- + +// FullMemcpyStrategy is the always-available baseline: it copies the entire +// linear memory into a []byte on Take and writes it all back on Restore. +// Cost is O(total memory size) regardless of how many pages were actually +// written during the request. +type FullMemcpyStrategy struct { + snapshot []byte +} + +// NewFullMemcpyStrategy returns a ready-to-use FullMemcpyStrategy. +func NewFullMemcpyStrategy() *FullMemcpyStrategy { + return &FullMemcpyStrategy{} +} + +// Take implements SnapshotStrategy. +func (f *FullMemcpyStrategy) Take(mem api.Memory) error { + if mem == nil { + f.snapshot = nil + return nil + } + + size := mem.Size() + if size == 0 { + f.snapshot = nil + return nil + } + + buf, ok := mem.Read(0, size) + if !ok { + return fmt.Errorf("snapshot: could not read %d bytes of linear memory", size) + } + + // Make an owned copy — mem.Read may return a slice backed by the wazero + // memory buffer which could be modified by subsequent guest execution. 
+ f.snapshot = make([]byte, len(buf)) + copy(f.snapshot, buf) + + return nil +} + +// Restore implements SnapshotStrategy. +func (f *FullMemcpyStrategy) Restore(mem api.Memory) error { + if f.snapshot == nil || mem == nil { + return nil + } + + if !mem.Write(0, f.snapshot) { + return fmt.Errorf("snapshot: failed to restore %d bytes", len(f.snapshot)) + } + + return nil +} + +// Close implements SnapshotStrategy. FullMemcpyStrategy holds no OS resources. +func (f *FullMemcpyStrategy) Close() error { + f.snapshot = nil + return nil +} diff --git a/internal/execution/wasm/snapshot_test.go b/internal/execution/wasm/snapshot_test.go new file mode 100644 index 0000000..1abe885 --- /dev/null +++ b/internal/execution/wasm/snapshot_test.go @@ -0,0 +1,258 @@ +//go:build !plan9 + +package wasm + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/api" +) + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +// leb128Encode encodes a uint32 as an unsigned LEB128 byte slice. +func leb128Encode(v uint32) []byte { + var buf []byte + for { + b := byte(v & 0x7f) + v >>= 7 + if v != 0 { + b |= 0x80 + } + buf = append(buf, b) + if v == 0 { + break + } + } + return buf +} + +// buildTestMemoryModule constructs a minimal WASM binary that declares exactly +// `pages` pages (64 KiB each) of linear memory. wazero's Module.Memory() +// returns the first memory regardless of whether it is exported, so no export +// section is needed. +// +// Binary layout (WASM spec §5): +// +// \0asm (magic) + version (1) + memory section +// +// This mirrors buildMinimalMemoryModule from snapshot_bench_test.go but +// accepts *testing.T so it can be used in unit tests. 
+func buildTestMemoryModule(t *testing.T, pages int) []byte { + t.Helper() + + // Memory section payload: count=1, limits type=0x00 (min only), min=pages + pagesLEB := leb128Encode(uint32(pages)) + memPayload := append([]byte{0x01, 0x00}, pagesLEB...) + + // Section: id=5 (memory), size=len(payload), payload + memSec := append([]byte{0x05}, append(leb128Encode(uint32(len(memPayload))), memPayload...)...) + + // Full module: magic + version + memory section + module := []byte{0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00} + module = append(module, memSec...) + return module +} + +// newTestWazeroMemory instantiates a minimal WASM module with the given number +// of 64 KiB pages and returns its api.Memory. The runtime and module are +// closed via t.Cleanup. +func newTestWazeroMemory(t *testing.T, pages int) api.Memory { + t.Helper() + ctx := context.Background() + + wasmBin := buildTestMemoryModule(t, pages) + + rt := wazero.NewRuntime(ctx) + t.Cleanup(func() { _ = rt.Close(ctx) }) + + compiled, err := rt.CompileModule(ctx, wasmBin) + require.NoError(t, err, "compile minimal module") + t.Cleanup(func() { _ = compiled.Close(ctx) }) + + mod, err := rt.InstantiateModule(ctx, compiled, wazero.NewModuleConfig().WithName("")) + require.NoError(t, err, "instantiate minimal module") + t.Cleanup(func() { _ = mod.Close(ctx) }) + + mem := mod.Memory() + require.NotNil(t, mem, "module must have linear memory") + return mem +} + +// --------------------------------------------------------------------------- +// TestFullMemcpyStrategy_TakeRestoreRoundtrip +// --------------------------------------------------------------------------- + +// TestFullMemcpyStrategy_TakeRestoreRoundtrip verifies the core contract: +// after Take, mutating the memory and calling Restore brings it back to the +// snapshotted state. +func TestFullMemcpyStrategy_TakeRestoreRoundtrip(t *testing.T) { + mem := newTestWazeroMemory(t, 1) // 1 page = 64 KiB + + // Fill memory with a known pattern. 
+ size := mem.Size() + pattern := make([]byte, size) + for i := range pattern { + pattern[i] = byte(i % 251) + } + require.True(t, mem.Write(0, pattern), "write initial pattern") + + s := NewFullMemcpyStrategy() + t.Cleanup(func() { require.NoError(t, s.Close()) }) + + // Take snapshot. + require.NoError(t, s.Take(mem)) + + // Overwrite memory with zeros (simulated guest write). + zeros := make([]byte, size) + require.True(t, mem.Write(0, zeros), "overwrite with zeros") + + after, ok := mem.Read(0, size) + require.True(t, ok) + require.Equal(t, zeros, []byte(after), "sanity: memory should be all-zeros now") + + // Restore and verify memory matches original pattern. + require.NoError(t, s.Restore(mem)) + + restored, ok := mem.Read(0, size) + require.True(t, ok) + assert.Equal(t, pattern, []byte(restored), "Restore must return memory to snapshotted state") +} + +// --------------------------------------------------------------------------- +// TestFullMemcpyStrategy_TakeNilMemory +// --------------------------------------------------------------------------- + +// TestFullMemcpyStrategy_TakeNilMemory checks that Take(nil) is safe and +// results in a nil snapshot (no panic, no error). +func TestFullMemcpyStrategy_TakeNilMemory(t *testing.T) { + s := NewFullMemcpyStrategy() + t.Cleanup(func() { require.NoError(t, s.Close()) }) + + require.NoError(t, s.Take(nil)) + assert.Nil(t, s.snapshot, "snapshot should be nil after Take(nil)") + + // A subsequent Restore(nil) must also be a no-op. + require.NoError(t, s.Restore(nil)) +} + +// --------------------------------------------------------------------------- +// TestFullMemcpyStrategy_RestoreBeforeTake +// --------------------------------------------------------------------------- + +// TestFullMemcpyStrategy_RestoreBeforeTake verifies that calling Restore on a +// zero-value / never-initialised strategy is a no-op that does not modify +// memory or return an error. 
+func TestFullMemcpyStrategy_RestoreBeforeTake(t *testing.T) { + s := NewFullMemcpyStrategy() + t.Cleanup(func() { require.NoError(t, s.Close()) }) + + mem := newTestWazeroMemory(t, 1) + size := mem.Size() + + // Fill with recognisable data. + data := make([]byte, size) + for i := range data { + data[i] = byte(i % 97) + } + require.True(t, mem.Write(0, data), "write initial data") + + // Snapshot the state so we can compare after Restore. + before, ok := mem.Read(0, size) + require.True(t, ok) + beforeCopy := make([]byte, len(before)) + copy(beforeCopy, before) + + // Restore before any Take — must be a no-op (snapshot is nil). + require.NoError(t, s.Restore(mem)) + + after, ok := mem.Read(0, size) + require.True(t, ok) + assert.Equal(t, beforeCopy, []byte(after), "Restore before Take must leave memory unchanged") +} + +// --------------------------------------------------------------------------- +// TestFullMemcpyStrategy_EmptyMemory +// --------------------------------------------------------------------------- + +// TestFullMemcpyStrategy_EmptyMemory stands in for the zero-size case: it calls +// Take(nil), which hits the mem == nil branch rather than the size == 0 branch. +// Both branches set snapshot to nil, so this checks the shared outcome only; the +// size == 0 branch itself is not executed by this test. +func TestFullMemcpyStrategy_EmptyMemory(t *testing.T) { + s := NewFullMemcpyStrategy() + t.Cleanup(func() { require.NoError(t, s.Close()) }) + + // Take(nil) exercises the "mem == nil" branch which sets snapshot=nil. + require.NoError(t, s.Take(nil)) + assert.Nil(t, s.snapshot, "snapshot must be nil when memory is nil") + + // Restore(nil) must be a no-op.
+ require.NoError(t, s.Restore(nil)) +} + +// --------------------------------------------------------------------------- +// TestFullMemcpyStrategy_CloseIdempotent +// --------------------------------------------------------------------------- + +// TestFullMemcpyStrategy_CloseIdempotent verifies that Close can be called +// multiple times without panicking or returning an error. +func TestFullMemcpyStrategy_CloseIdempotent(t *testing.T) { + s := NewFullMemcpyStrategy() + + mem := newTestWazeroMemory(t, 1) + require.NoError(t, s.Take(mem)) + assert.NotNil(t, s.snapshot, "snapshot should be set after Take") + + // First Close should succeed and clear the snapshot. + require.NoError(t, s.Close()) + assert.Nil(t, s.snapshot, "snapshot should be nil after first Close") + + // Second Close must also be safe. + require.NoError(t, s.Close()) +} + +// --------------------------------------------------------------------------- +// TestFullMemcpyStrategy_SnapshotIsOwnedCopy +// --------------------------------------------------------------------------- + +// TestFullMemcpyStrategy_SnapshotIsOwnedCopy confirms that the snapshot is an +// independent copy of the memory buffer, not an alias into wazero's backing +// store. If Take stored a slice backed by the same underlying array, a +// subsequent guest write would silently corrupt the snapshot. +func TestFullMemcpyStrategy_SnapshotIsOwnedCopy(t *testing.T) { + mem := newTestWazeroMemory(t, 1) + size := mem.Size() + + // Write distinct pattern. + pattern := make([]byte, size) + for i := range pattern { + pattern[i] = byte(i % 199) + } + require.True(t, mem.Write(0, pattern), "write pattern") + + s := NewFullMemcpyStrategy() + t.Cleanup(func() { require.NoError(t, s.Close()) }) + + require.NoError(t, s.Take(mem)) + + // Overwrite memory entirely with 0xFF. 
+ corrupt := make([]byte, size) + for i := range corrupt { + corrupt[i] = 0xFF + } + require.True(t, mem.Write(0, corrupt)) + + // Restore: snapshot must be independent of the wazero buffer. + require.NoError(t, s.Restore(mem)) + + restored, ok := mem.Read(0, size) + require.True(t, ok) + assert.Equal(t, pattern, []byte(restored), "snapshot must be independent copy of original data") +} diff --git a/internal/execution/wasm/supervisor.go b/internal/execution/wasm/supervisor.go new file mode 100644 index 0000000..02dbb80 --- /dev/null +++ b/internal/execution/wasm/supervisor.go @@ -0,0 +1,221 @@ +package wasm + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/api" + "go.uber.org/zap" +) + +// ErrMemoryGrew indicates that the guest expanded linear memory during a +// request beyond the size captured at snapshot time. wazero (and the WASM +// spec) does not allow shrinking linear memory, so the original snapshotted +// state cannot be fully reproduced and the supervisor must be discarded. +var ErrMemoryGrew = errors.New("wasm: linear memory grew beyond snapshotted size") + +// wasmSupervisor manages a single instantiated WASM module. After the module +// is initialised its linear memory is snapshotted; the snapshot is restored +// after every Send so that the next request sees a clean initial state. This +// gives cheap warm-start semantics without re-compiling the module. +type wasmSupervisor struct { + mu sync.Mutex + + runtime wazero.Runtime + compiled wazero.CompiledModule + modCfg wazero.ModuleConfig + + mod api.Module + adapter *wasmAdapter + + // strategy implements snapshot/restore. The generic backend intentionally + // uses the portable full-memory copy strategy; dirty-page optimisation is a + // separate future concern. + strategy SnapshotStrategy + + // healthy is true when the supervisor is in a known-good state and can be + // safely returned to the pool. 
It is set to false when restoreSnapshot fails, + // indicating the WASM module's memory state is undefined. + healthy bool + + // snapshotSize is the linear-memory size (in bytes) captured at Take time. + // restoreSnapshot compares this against the post-request memory size to + // detect memory.grow during execution — wazero cannot shrink memory, so + // any growth invalidates the snapshot and must mark the supervisor unhealthy. + snapshotSize uint32 + + timeout time.Duration + log *zap.Logger +}
+
+// newWasmSupervisor returns a supervisor that remembers the runtime, the
+// compiled module, the module config and the per-request timeout. Nothing is
+// instantiated here; call Start before Send. log must be non-nil because a
+// named child logger is derived from it.
+func newWasmSupervisor(
+	rt wazero.Runtime,
+	compiled wazero.CompiledModule,
+	modCfg wazero.ModuleConfig,
+	timeout time.Duration,
+	log *zap.Logger,
+) *wasmSupervisor {
+	return &wasmSupervisor{
+		runtime:  rt,
+		compiled: compiled,
+		modCfg:   modCfg,
+		timeout:  timeout,
+		log:      log.Named("supervisor_wasm"),
+	}
+}
+
+// Start instantiates the compiled module, runs any WASI start function, then
+// snapshots linear memory.
+//
+// Start is a no-op when the supervisor already holds a module instance. If
+// instantiation fails nothing is retained. If the snapshot fails the fresh
+// instance is closed and the supervisor is put back into the "not started",
+// unhealthy state before the error is returned, so IsHealthy never reports
+// true for a supervisor that has no usable instance.
+func (s *wasmSupervisor) Start(ctx context.Context) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.mod != nil {
+		return nil
+	}
+
+	s.log.Debug("instantiating wasm module")
+
+	// Apply start functions on top of the provided (sandboxed) module config.
+	instCfg := s.modCfg.WithStartFunctions("_initialize", "_start")
+
+	mod, err := s.runtime.InstantiateModule(ctx, s.compiled, instCfg)
+	if err != nil {
+		return fmt.Errorf("wasm: instantiate module: %w", err)
+	}
+
+	s.mod = mod
+	s.adapter = newWasmAdapter(mod, s.log)
+	s.healthy = true
+
+	// Snapshot linear memory so we can restore it before each request.
+	s.strategy = NewFullMemcpyStrategy()
+	if err := s.takeSnapshot(); err != nil {
+		// Best-effort cleanup: the snapshot error is the one worth reporting,
+		// so close errors are deliberately dropped. Every field set above is
+		// rolled back so no reference to the closed instance survives.
+		_ = s.strategy.Close()
+		_ = mod.Close(ctx)
+		s.mod = nil
+		s.adapter = nil
+		s.healthy = false
+		return fmt.Errorf("wasm: snapshot memory: %w", err)
+	}
+
+	// snapshotSize was just recorded by takeSnapshot (0 when the module has
+	// no memory), so it is exactly the number of bytes captured.
+	s.log.Debug("wasm module ready",
+		zap.Uint32("snapshot_bytes", s.snapshotSize),
+		zap.String("strategy", fmt.Sprintf("%T", s.strategy)),
+	)
+
+	return nil
+}
+
+// Send calls the guest's evaluate function, then restores linear memory from
+// the snapshot so the next request starts from a clean state.
+//
+// The restore runs whether or not the guest call succeeded. When the restore
+// itself fails the supervisor is marked unhealthy; the guest error, if any,
+// takes precedence over the restore error in the returned value. Send returns
+// an error without calling the guest when the supervisor is not started.
+func (s *wasmSupervisor) Send(
+	ctx context.Context,
+	method string,
+	data map[string]any,
+) (map[string]any, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.mod == nil || s.adapter == nil {
+		return nil, fmt.Errorf("wasm: supervisor not started")
+	}
+
+	result, err := s.adapter.send(ctx, method, data, s.timeout)
+
+	// Restore memory snapshot to keep state clean for the next request.
+	// If restore fails, mark the supervisor unhealthy so the dispatcher
+	// discards it rather than returning it to the pool with undefined state.
+	if restoreErr := s.restoreSnapshot(); restoreErr != nil {
+		s.log.Error("failed to restore memory snapshot — marking supervisor unhealthy", zap.Error(restoreErr))
+		s.healthy = false
+		if err == nil {
+			err = fmt.Errorf("wasm: restore snapshot: %w", restoreErr)
+		}
+	}
+
+	return result, err
+}
+
+// IsHealthy reports whether the supervisor is in a known-good state.
+// Safe to call without holding s.mu (acquires the lock internally). (I-3 fix)
+func (s *wasmSupervisor) IsHealthy() bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.healthy
+}
+
+// Shutdown closes the module instance and releases resources.
+// The supervisor's references to the module and adapter are dropped and the
+// snapshot strategy is closed even when closing the module fails, so a failed
+// Shutdown never leaves a half-closed instance reachable through Send.
+// Calling Shutdown on a supervisor that is not started is a no-op.
+func (s *wasmSupervisor) Shutdown(ctx context.Context) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.mod == nil {
+		return nil
+	}
+
+	s.log.Debug("shutting down wasm module instance")
+
+	// Detach the instance before closing it so every exit path below leaves
+	// the supervisor in the "not started" state.
+	mod := s.mod
+	s.mod = nil
+	s.adapter = nil
+
+	closeErr := mod.Close(ctx)
+
+	// Release the snapshot regardless of closeErr. A strategy failure is only
+	// logged so that it cannot mask the module close error.
+	if err := s.strategy.Close(); err != nil {
+		s.log.Warn("failed to close snapshot strategy", zap.Error(err))
+	}
+
+	if closeErr != nil {
+		return fmt.Errorf("wasm: close module: %w", closeErr)
+	}
+	return nil
+}
+
+// takeSnapshot captures the guest's linear memory via the active strategy and
+// records the memory size so restoreSnapshot can detect post-snapshot growth.
+// When the module exposes no memory (Memory() returns nil) there is nothing to
+// capture: snapshotSize is set to 0 and the strategy is not invoked.
+// Must be called with s.mu held.
+func (s *wasmSupervisor) takeSnapshot() error {
+	mem := s.mod.Memory()
+	if mem == nil {
+		s.snapshotSize = 0
+		return nil
+	}
+	if err := s.strategy.Take(mem); err != nil {
+		return err
+	}
+	// Record the size only after Take succeeded so snapshotSize always
+	// describes the snapshot the strategy actually holds.
+	s.snapshotSize = mem.Size()
+	return nil
+}
+
+// restoreSnapshot restores the guest's linear memory from the last snapshot
+// via the active strategy. If the guest grew memory during the request
+// (memory.grow), it zero-fills the tail beyond the snapshotted size to prevent
+// leaking guest data into the next request and returns ErrMemoryGrew so the
+// caller (Send) marks the supervisor unhealthy and discards it. Must be called
+// with s.mu held.
+func (s *wasmSupervisor) restoreSnapshot() error {
+	// Nothing to restore when there is no instance or the module has no
+	// linear memory.
+	if s.mod == nil {
+		return nil
+	}
+	mem := s.mod.Memory()
+	if mem == nil {
+		return nil
+	}
+	if err := s.strategy.Restore(mem); err != nil {
+		return err
+	}
+
+	cur := mem.Size()
+	if cur <= s.snapshotSize {
+		return nil
+	}
+
+	// The guest grew memory. Scrub the grown tail in place through the view
+	// returned by Read instead of allocating a second, tail-sized zero buffer:
+	// the amount of growth is chosen by the guest, so a matching host
+	// allocation would let a single request double the memory spent on this
+	// instance just before it is discarded.
+	tail := cur - s.snapshotSize
+	view, ok := mem.Read(s.snapshotSize, tail)
+	if !ok {
+		return fmt.Errorf("wasm: memory grew by %d bytes; zero-fill failed: %w", tail, ErrMemoryGrew)
+	}
+	for i := range view {
+		view[i] = 0
+	}
+	return fmt.Errorf("wasm: memory grew by %d bytes (tail zero-filled): %w", tail, ErrMemoryGrew)
+}
diff --git a/internal/execution/wasm/testdata/echo.wasm b/internal/execution/wasm/testdata/echo.wasm new file mode 100644 index 0000000000000000000000000000000000000000..4e5072d27666d3f20c884e94a126ed1e602facca GIT binary patch literal 241 zcmX}ny$ZrW5Cq`az5JPrT3QLw$|tb6?%pIp(nvzUKd?&cd-!Zt7J_Icf};VO{g|0* zQEnRnAek1@NmppcYm7odBanD%qNZxv%~27Sb=|Ijq&k%KzT8!i^ej3N={y#SnRw)q zW4%;rPx2p>gZlArP;VW+5f0L$J%+s426XNak{e@0uQcxKggA!*d9Y3Com%>&8NJXU QebZ5{)}{gv8w7mx1BL}9Jpcdz literal 0 HcmV?d00001 diff --git a/internal/execution/wasm/testdata/echo.wat b/internal/execution/wasm/testdata/echo.wat new file mode 100644 index 0000000..e49d21b --- /dev/null +++ b/internal/execution/wasm/testdata/echo.wat @@ -0,0 +1,66 @@ +;; echo.wat — minimal guest ABI fixture for wasm package tests. +;; +;; Implements: +;; alloc(size i32) i32 — bump allocator; heap pointer stored at mem[0..3] +;; evaluate(req_ptr i32, req_len i32) i32 +;; — ignores input; always returns fixed response {"ok":true} +;; as a length-prefixed blob: [4-byte LE uint32 len][JSON bytes] +;; +;; The compiled binary (echo.wasm) was generated from this source. +;; {"ok":true} is 11 bytes: 7b 22 6f 6b 22 3a 74 72 75 65 7d +;; +;; Design note: the heap pointer is stored IN linear memory (offset 0, 4 bytes) +;; rather than in a WASM global.
This means the snapshot/restore mechanism +;; (which copies linear memory) correctly resets the allocator state between +;; requests. If a global were used, snapshot/restore would not reset it and +;; the heap pointer would keep advancing across requests. +(module + (memory (export "memory") 1) + + ;; mem[0..3]: heap pointer (i32, LE), initialized to 4 + ;; (offset 0..3 reserved for the pointer itself, so allocations start at 4) + (data (i32.const 0) "\04\00\00\00") + + ;; alloc(size i32) i32 + (func (export "alloc") (param $size i32) (result i32) + (local $ptr i32) + ;; ptr = i32.load(mem[0]) + (local.set $ptr (i32.load (i32.const 0))) + ;; mem[0] = ptr + size + (i32.store (i32.const 0) (i32.add (local.get $ptr) (local.get $size))) + (local.get $ptr) + ) + + ;; evaluate(req_ptr i32, req_len i32) i32 + ;; Returns pointer P where: + ;; mem[P .. P+4) = little-endian uint32 length (11) + ;; mem[P+4 .. P+15) = {"ok":true} + (func (export "evaluate") (param $req_ptr i32) (param $req_len i32) (result i32) + (local $resp_ptr i32) + ;; resp_ptr = i32.load(mem[0]) + (local.set $resp_ptr (i32.load (i32.const 0))) + ;; mem[0] = resp_ptr + 15 (4 bytes length prefix + 11 bytes JSON) + (i32.store (i32.const 0) (i32.add (local.get $resp_ptr) (i32.const 15))) + + ;; Write little-endian length prefix: 11, 0, 0, 0 + (i32.store8 offset=0 (local.get $resp_ptr) (i32.const 11)) + (i32.store8 offset=1 (local.get $resp_ptr) (i32.const 0)) + (i32.store8 offset=2 (local.get $resp_ptr) (i32.const 0)) + (i32.store8 offset=3 (local.get $resp_ptr) (i32.const 0)) + + ;; Write {"ok":true} + (i32.store8 offset=4 (local.get $resp_ptr) (i32.const 0x7b)) ;; { + (i32.store8 offset=5 (local.get $resp_ptr) (i32.const 0x22)) ;; " + (i32.store8 offset=6 (local.get $resp_ptr) (i32.const 0x6f)) ;; o + (i32.store8 offset=7 (local.get $resp_ptr) (i32.const 0x6b)) ;; k + (i32.store8 offset=8 (local.get $resp_ptr) (i32.const 0x22)) ;; " + (i32.store8 offset=9 (local.get $resp_ptr) (i32.const 0x3a)) ;; : + 
(i32.store8 offset=10 (local.get $resp_ptr) (i32.const 0x74)) ;; t + (i32.store8 offset=11 (local.get $resp_ptr) (i32.const 0x72)) ;; r + (i32.store8 offset=12 (local.get $resp_ptr) (i32.const 0x75)) ;; u + (i32.store8 offset=13 (local.get $resp_ptr) (i32.const 0x65)) ;; e + (i32.store8 offset=14 (local.get $resp_ptr) (i32.const 0x7d)) ;; } + + (local.get $resp_ptr) + ) +) diff --git a/internal/execution/wasm/testhelpers_test.go b/internal/execution/wasm/testhelpers_test.go new file mode 100644 index 0000000..73161e5 --- /dev/null +++ b/internal/execution/wasm/testhelpers_test.go @@ -0,0 +1,46 @@ +package wasm + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" +) + +// echoWasmBytes reads the pre-compiled echo fixture from testdata/echo.wasm. +// The fixture is a minimal WASM module that: +// - exports a bump-allocator alloc(size i32) i32 +// - exports evaluate(req_ptr i32, req_len i32) i32 that always returns the +// fixed JSON {"ok":true} as a 4-byte LE length-prefixed blob +// +// The WAT source is kept alongside the binary at testdata/echo.wat for +// reference. The binary was generated using a pure-Go WASM assembler so that +// the test suite requires no external toolchain. +func echoWasmBytes(t *testing.T) []byte { + t.Helper() + path := echoModulePath(t) + b, err := os.ReadFile(path) + require.NoError(t, err, "read echo.wasm fixture") + return b +} + +// compileEchoModule creates a wazero runtime, wires up WASI host functions, +// and compiles the echo WASM bytes into a CompiledModule. The runtime must be +// closed by the caller. 
+func compileEchoModule(t *testing.T, ctx context.Context, wasmBytes []byte) (wazero.Runtime, wazero.CompiledModule) { + t.Helper() + + rt := wazero.NewRuntime(ctx) + _, err := wasi_snapshot_preview1.Instantiate(ctx, rt) + require.NoError(t, err, "instantiate WASI") + + compiled, err := rt.CompileModule(ctx, wasmBytes) + require.NoError(t, err, "compile echo module") + + t.Cleanup(func() { _ = compiled.Close(ctx) }) + + return rt, compiled +} diff --git a/scripts/demo-wasm.sh b/scripts/demo-wasm.sh new file mode 100755 index 0000000..5ef0a99 --- /dev/null +++ b/scripts/demo-wasm.sh @@ -0,0 +1,119 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +PORT="${PORT:-}" +HOST="127.0.0.1" +if [[ -z "${PORT}" ]]; then + PORT="$(python3 - <<'PY' +import socket +s = socket.socket() +s.bind(('127.0.0.1', 0)) +print(s.getsockname()[1]) +s.close() +PY +)" +fi +BASE_URL="http://${HOST}:${PORT}" +BIN="${ROOT}/bin/shimmy-demo" +DEMO_DIR="${ROOT}/examples/demo-stateful" +WASM="${DEMO_DIR}/eval.wasm" +LOG="${ROOT}/.demo-wasm-server.log" + +for cmd in go curl python3; do + if ! command -v "${cmd}" >/dev/null 2>&1; then + echo "error: ${cmd} is required" >&2 + exit 1 + fi +done + +echo "==> Building shimmy demo binary" +(cd "${ROOT}" && go build -trimpath -buildvcs=false -o "${BIN}" .) + +echo "==> Building demo evaluator: examples/demo-stateful -> eval.wasm" +(cd "${DEMO_DIR}" && GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o "${WASM}" .) 
+ +rm -f "${LOG}" +server_pid="" +cleanup() { + if [[ -n "${server_pid}" ]] && kill -0 "${server_pid}" 2>/dev/null; then + kill "${server_pid}" 2>/dev/null || true + for _ in {1..20}; do + kill -0 "${server_pid}" 2>/dev/null || return 0 + sleep 0.1 + done + kill -KILL "${server_pid}" 2>/dev/null || true + wait "${server_pid}" 2>/dev/null || true + fi +} +trap cleanup EXIT + +echo "==> Starting shimmy on ${BASE_URL}" +( + cd "${ROOT}" + exec env \ + LOG_LEVEL=error \ + FUNCTION_INTERFACE=wasm \ + FUNCTION_WASM_MODULE="${WASM}" \ + FUNCTION_MAX_PROCS=1 \ + FUNCTION_TIMEOUT=5s \ + "${BIN}" serve --host "${HOST}" --port "${PORT}" +) >"${LOG}" 2>&1 & +server_pid="$!" + +for _ in {1..60}; do + if ! kill -0 "${server_pid}" 2>/dev/null; then + echo "server exited early; log follows:" >&2 + cat "${LOG}" >&2 || true + exit 1 + fi + if curl -fsS "${BASE_URL}/health" >/dev/null 2>&1; then + break + fi + sleep 0.2 +done + +if ! curl -fsS "${BASE_URL}/health" >/dev/null 2>&1; then + echo "server did not become ready; log follows:" >&2 + cat "${LOG}" >&2 || true + exit 1 +fi + +request_eval() { + local response="$1" + local answer="$2" + curl -fsS \ + -X POST "${BASE_URL}/" \ + -H 'Content-Type: application/json' \ + -H 'Command: eval' \ + --data "{\"response\":\"${response}\",\"answer\":\"${answer}\",\"params\":{}}" +} + +echo "==> Request #1: correct answer" +resp1="$(request_eval 42 42)" +echo "${resp1}" | python3 -m json.tool + +echo "==> Request #2: wrong answer; guest global state should still be reset" +resp2="$(request_eval 41 42)" +echo "${resp2}" | python3 -m json.tool + +RESP1="${resp1}" RESP2="${resp2}" python3 - <<'PY' +import json, os, sys +r1 = json.loads(os.environ['RESP1'])['result'] +r2 = json.loads(os.environ['RESP2'])['result'] +checks = [ + (r1.get('is_correct') is True, 'request #1 should be correct'), + (r2.get('is_correct') is False, 'request #2 should be incorrect'), + (r1.get('guest_invocation_count') == 1, 'request #1 should see guest_invocation_count 
== 1'), + (r2.get('guest_invocation_count') == 1, 'request #2 should still see guest_invocation_count == 1'), + (r1.get('snapshot_isolation_ok') is True, 'request #1 snapshot flag should be true'), + (r2.get('snapshot_isolation_ok') is True, 'request #2 snapshot flag should be true'), +] +failed = [msg for ok, msg in checks if not ok] +if failed: + print('DEMO FAILED:', *failed, sep='\n- ', file=sys.stderr) + sys.exit(1) +print('\n✅ Demo passed: two HTTP requests hit the same warm WASM evaluator, but guest global state was reset after each request.') +PY + +printf '\nServer log: %s\n' "${LOG}" From a9dc47a35c118cb48d6cdd66e6cada4a5329f71b Mon Sep 17 00:00:00 2001 From: bkmashiro <53376445+bkmashiro@users.noreply.github.com> Date: Fri, 19 Jun 2026 13:45:04 +0100 Subject: [PATCH 2/7] docs: add cpp wasm evaluator demo --- README.md | 18 ++- examples/demo-cpp-compare/README.md | 50 ++++++ examples/demo-cpp-compare/evaluator.cpp | 203 ++++++++++++++++++++++++ scripts/demo-cpp-wasm.sh | 140 ++++++++++++++++ 4 files changed, 406 insertions(+), 5 deletions(-) create mode 100644 examples/demo-cpp-compare/README.md create mode 100644 examples/demo-cpp-compare/evaluator.cpp create mode 100755 scripts/demo-cpp-wasm.sh diff --git a/README.md b/README.md index 71bea78..4206de7 100644 --- a/README.md +++ b/README.md @@ -247,9 +247,11 @@ wolframscript -file evaluation.wl /tmp/shimmy/abc/request-data-123 /tmp/shimmy/a #### WebAssembly (`--interface wasm`, opt-in) -The WASM interface executes a pre-built WASI module in-process using wazero. This -is an execution backend only: Shimmy still owns the public HTTP/API contract, -request validation, command routing, cases, and response handling. +The WASM interface executes a pre-built WebAssembly module in-process using +wazero. The module can be a WASI module or a small freestanding module as long +as it exports the Shimmy adapter ABI below. 
This is an execution backend only: +Shimmy still owns the public HTTP/API contract, request validation, command +routing, cases, and response handling. Shimmy does not compile evaluator source code at request time and does not infer a source language from dependency files. Language-specific work belongs in build @@ -304,9 +306,14 @@ GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o eval.wasm . # Rust cargo build --target wasm32-wasip1 --release -# C/C++ +# C/C++ with wasi-sdk /opt/wasi-sdk/bin/clang --target=wasm32-wasip1 ... -o eval.wasm /opt/wasi-sdk/bin/clang++ --target=wasm32-wasip1 ... -o eval.wasm + +# Freestanding C++ with Zig's clang driver +zig c++ -target wasm32-freestanding -nostdlib \ + -Wl,--no-entry -Wl,--export=alloc -Wl,--export=evaluate -Wl,--export-memory \ + -o eval.wasm evaluator.cpp ``` The backend keeps a warm module instance pool and restores a full linear-memory @@ -314,10 +321,11 @@ snapshot after each request. This gives warm reuse without leaking guest mutable state between requests. Dirty-page restore, Python runtimes, Pyodide, and package bundling are intentionally out of scope for this generic backend. -Try the state-isolation example: +Try the state-isolation examples: ```shell scripts/demo-wasm.sh +scripts/demo-cpp-wasm.sh ``` ### Sandboxed Execution (Linux only, experimental) diff --git a/examples/demo-cpp-compare/README.md b/examples/demo-cpp-compare/README.md new file mode 100644 index 0000000..c3b21d3 --- /dev/null +++ b/examples/demo-cpp-compare/README.md @@ -0,0 +1,50 @@ +# C++ Compare Evaluator for Shimmy WASM + +This example is a minimal C++ evaluator that compiles to a WebAssembly module and runs through Shimmy's opt-in WASM backend. 
+ +It intentionally mirrors the shape of a simple Lambda Feedback evaluator: + +- input: `response`, `answer`, and optional feedback strings in `params` +- output: `{ "command": "eval", "result": { "is_correct", "feedback" } }` + +The evaluator also reports `guest_invocation_count` and `snapshot_isolation_ok` so the demo can prove that Shimmy reuses a warm WASM instance while restoring guest memory after each request. + +## Build + +The demo uses Zig's clang-compatible C++ driver because the default macOS Apple clang does not ship a WebAssembly target: + +```bash +zig c++ \ + -target wasm32-freestanding \ + -Oz \ + -nostdlib \ + -fno-exceptions \ + -fno-rtti \ + -Wl,--no-entry \ + -Wl,--export=alloc \ + -Wl,--export=evaluate \ + -Wl,--export-memory \ + -Wl,--initial-memory=2097152 \ + -o eval.wasm \ + evaluator.cpp +``` + +The source avoids libc/libc++ and implements only the small amount of JSON handling needed for this evaluator, so it can be built as a small freestanding WebAssembly module. Real C++ evaluators can use a richer build setup, but still need to expose the same Shimmy WASM ABI: + +```text +memory +alloc(size: i32) -> i32 +evaluate(req_ptr: i32, req_len: i32) -> i32 +``` + +`evaluate` returns a pointer to `[uint32 little-endian response_len][response JSON bytes]`. + +## Run the end-to-end demo + +From the repository root: + +```bash +./scripts/demo-cpp-wasm.sh +``` + +The script builds Shimmy, compiles this evaluator to `eval.wasm`, starts Shimmy with `FUNCTION_INTERFACE=wasm`, sends two HTTP requests, and asserts that both requests see `guest_invocation_count == 1`. diff --git a/examples/demo-cpp-compare/evaluator.cpp b/examples/demo-cpp-compare/evaluator.cpp new file mode 100644 index 0000000..64704c9 --- /dev/null +++ b/examples/demo-cpp-compare/evaluator.cpp @@ -0,0 +1,203 @@ +// A tiny C++ evaluation function that can be compiled directly to WebAssembly. 
+// +// It intentionally avoids libc/libc++ so the build is just a single freestanding +// C++ source file plus Zig's wasm-capable clang driver. The evaluator follows +// Shimmy's internal WASM ABI: +// - export memory +// - alloc(size) -> request pointer +// - evaluate(ptr, len) -> pointer to [u32 little-endian length][JSON bytes] +// +// The business logic is deliberately Lambda Feedback-shaped: compare the +// submitted response with the answer and return feedback from params. + +using u32 = unsigned int; +using i32 = int; +using usize = __SIZE_TYPE__; +using uintptr = __UINTPTR_TYPE__; + +alignas(16) static char request_buffer[256 * 1024]; +alignas(16) static char response_buffer[256 * 1024]; + +// Mutable guest state. Shimmy should restore the warm instance snapshot after +// every request, so this must be 1 for every HTTP call when FUNCTION_MAX_PROCS=1. +static u32 invocation_count = 0; + +static usize cstr_len(const char *s) { + usize n = 0; + while (s[n] != 0) n++; + return n; +} + +static bool bytes_equal(const char *a, usize a_len, const char *b, usize b_len) { + if (a_len != b_len) return false; + for (usize i = 0; i < a_len; i++) { + if (a[i] != b[i]) return false; + } + return true; +} + +static void copy_bytes(char *dst, usize &pos, const char *src, usize len) { + for (usize i = 0; i < len; i++) dst[pos++] = src[i]; +} + +static void append_cstr(char *dst, usize &pos, const char *src) { + copy_bytes(dst, pos, src, cstr_len(src)); +} + +static void append_u32(char *dst, usize &pos, u32 value) { + char tmp[10]; + usize n = 0; + do { + tmp[n++] = char('0' + (value % 10)); + value /= 10; + } while (value != 0); + while (n > 0) dst[pos++] = tmp[--n]; +} + +static void append_json_string(char *dst, usize &pos, const char *src, usize len) { + dst[pos++] = '"'; + for (usize i = 0; i < len; i++) { + char c = src[i]; + if (c == '"' || c == '\\') { + dst[pos++] = '\\'; + dst[pos++] = c; + } else if (c == '\n') { + dst[pos++] = '\\'; + dst[pos++] = 'n'; + } else 
{ + dst[pos++] = c; + } + } + dst[pos++] = '"'; +} + +static bool match_at(const char *json, usize len, usize pos, const char *needle) { + for (usize i = 0; needle[i] != 0; i++) { + if (pos + i >= len || json[pos + i] != needle[i]) return false; + } + return true; +} + +static bool find_json_string(const char *json, usize len, const char *key, + const char *&value, usize &value_len) { + char quoted_key[96]; + usize key_pos = 0; + quoted_key[key_pos++] = '"'; + for (usize i = 0; key[i] != 0 && key_pos + 2 < sizeof(quoted_key); i++) { + quoted_key[key_pos++] = key[i]; + } + quoted_key[key_pos++] = '"'; + quoted_key[key_pos] = 0; + + for (usize i = 0; i < len; i++) { + if (!match_at(json, len, i, quoted_key)) continue; + i += key_pos; + while (i < len && (json[i] == ' ' || json[i] == '\n' || json[i] == '\r' || json[i] == '\t')) i++; + if (i >= len || json[i] != ':') continue; + i++; + while (i < len && (json[i] == ' ' || json[i] == '\n' || json[i] == '\r' || json[i] == '\t')) i++; + if (i >= len || json[i] != '"') continue; + i++; + + usize start = i; + while (i < len) { + if (json[i] == '\\') { + i += 2; + continue; + } + if (json[i] == '"') { + value = json + start; + value_len = i - start; + return true; + } + i++; + } + } + return false; +} + +static i32 write_error(const char *message) { + usize pos = 4; + append_cstr(response_buffer, pos, "{\"error\":{\"message\":"); + append_json_string(response_buffer, pos, message, cstr_len(message)); + append_cstr(response_buffer, pos, "}}" + ); + u32 len = u32(pos - 4); + response_buffer[0] = char(len & 0xff); + response_buffer[1] = char((len >> 8) & 0xff); + response_buffer[2] = char((len >> 16) & 0xff); + response_buffer[3] = char((len >> 24) & 0xff); + return i32(uintptr(response_buffer)); +} + +static i32 write_eval_response(bool is_correct, + const char *feedback, + usize feedback_len) { + usize pos = 4; + append_cstr(response_buffer, pos, "{\"command\":\"eval\",\"result\":{"); + append_cstr(response_buffer, pos, 
"\"is_correct\":"); + append_cstr(response_buffer, pos, is_correct ? "true" : "false"); + append_cstr(response_buffer, pos, ",\"feedback\":"); + append_json_string(response_buffer, pos, feedback, feedback_len); + append_cstr(response_buffer, pos, ",\"guest_invocation_count\":"); + append_u32(response_buffer, pos, invocation_count); + append_cstr(response_buffer, pos, ",\"snapshot_isolation_ok\":"); + append_cstr(response_buffer, pos, invocation_count == 1 ? "true" : "false"); + append_cstr(response_buffer, pos, "}}" + ); + + u32 len = u32(pos - 4); + response_buffer[0] = char(len & 0xff); + response_buffer[1] = char((len >> 8) & 0xff); + response_buffer[2] = char((len >> 16) & 0xff); + response_buffer[3] = char((len >> 24) & 0xff); + return i32(uintptr(response_buffer)); +} + +extern "C" i32 alloc(i32 size) { + if (size <= 0 || usize(size) > sizeof(request_buffer)) return 0; + return i32(uintptr(request_buffer)); +} + +extern "C" i32 evaluate(i32 req_ptr, i32 req_len) { + if (req_ptr == 0 || req_len <= 0) return write_error("empty request"); + + const char *json = reinterpret_cast(uintptr(req_ptr)); + usize len = usize(req_len); + + const char *method = nullptr; + const char *response = nullptr; + const char *answer = nullptr; + const char *correct_feedback = nullptr; + const char *incorrect_feedback = nullptr; + usize method_len = 0; + usize response_len = 0; + usize answer_len = 0; + usize correct_feedback_len = 0; + usize incorrect_feedback_len = 0; + + if (!find_json_string(json, len, "method", method, method_len)) return write_error("missing method"); + if (!bytes_equal(method, method_len, "eval", 4)) return write_error("unsupported method"); + if (!find_json_string(json, len, "response", response, response_len)) return write_error("missing response"); + if (!find_json_string(json, len, "answer", answer, answer_len)) return write_error("missing answer"); + + bool has_correct_feedback = find_json_string(json, len, "correct_response_feedback", correct_feedback, 
correct_feedback_len); + bool has_incorrect_feedback = find_json_string(json, len, "incorrect_response_feedback", incorrect_feedback, incorrect_feedback_len); + + invocation_count++; + + bool is_correct = bytes_equal(response, response_len, answer, answer_len); + if (is_correct) { + if (!has_correct_feedback) { + correct_feedback = "Correct"; + correct_feedback_len = 7; + } + return write_eval_response(true, correct_feedback, correct_feedback_len); + } + + if (!has_incorrect_feedback) { + incorrect_feedback = "Incorrect"; + incorrect_feedback_len = 9; + } + return write_eval_response(false, incorrect_feedback, incorrect_feedback_len); +} diff --git a/scripts/demo-cpp-wasm.sh b/scripts/demo-cpp-wasm.sh new file mode 100755 index 0000000..ccd7a44 --- /dev/null +++ b/scripts/demo-cpp-wasm.sh @@ -0,0 +1,140 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +PORT="${PORT:-}" +HOST="127.0.0.1" +if [[ -z "${PORT}" ]]; then + PORT="$(python3 - <<'PY' +import socket +s = socket.socket() +s.bind(('127.0.0.1', 0)) +print(s.getsockname()[1]) +s.close() +PY +)" +fi +BASE_URL="http://${HOST}:${PORT}" +BIN="${ROOT}/bin/shimmy-demo-cpp" +DEMO_DIR="${ROOT}/examples/demo-cpp-compare" +WASM="${DEMO_DIR}/eval.wasm" +LOG="${ROOT}/.demo-cpp-wasm-server.log" + +for cmd in go curl python3 zig; do + if ! command -v "${cmd}" >/dev/null 2>&1; then + echo "error: ${cmd} is required" >&2 + exit 1 + fi +done + +if [[ ! -f "${DEMO_DIR}/evaluator.cpp" ]]; then + echo "error: missing ${DEMO_DIR}/evaluator.cpp" >&2 + exit 1 +fi + +echo "==> Building shimmy demo binary" +(cd "${ROOT}" && go build -trimpath -buildvcs=false -o "${BIN}" .) 
+ +echo "==> Building C++ evaluator: examples/demo-cpp-compare -> eval.wasm" +zig c++ \ + -target wasm32-freestanding \ + -Oz \ + -nostdlib \ + -fno-exceptions \ + -fno-rtti \ + -Wl,--no-entry \ + -Wl,--export=alloc \ + -Wl,--export=evaluate \ + -Wl,--export-memory \ + -Wl,--initial-memory=2097152 \ + -o "${WASM}" \ + "${DEMO_DIR}/evaluator.cpp" + +file "${WASM}" + +rm -f "${LOG}" +server_pid="" +cleanup() { + if [[ -n "${server_pid}" ]] && kill -0 "${server_pid}" 2>/dev/null; then + kill "${server_pid}" 2>/dev/null || true + for _ in {1..20}; do + kill -0 "${server_pid}" 2>/dev/null || return 0 + sleep 0.1 + done + kill -KILL "${server_pid}" 2>/dev/null || true + wait "${server_pid}" 2>/dev/null || true + fi +} +trap cleanup EXIT + +echo "==> Starting shimmy on ${BASE_URL}" +( + cd "${ROOT}" + exec env \ + LOG_LEVEL=error \ + FUNCTION_INTERFACE=wasm \ + FUNCTION_WASM_MODULE="${WASM}" \ + FUNCTION_MAX_PROCS=1 \ + FUNCTION_TIMEOUT=5s \ + "${BIN}" serve --host "${HOST}" --port "${PORT}" +) >"${LOG}" 2>&1 & +server_pid="$!" + +for _ in {1..60}; do + if ! kill -0 "${server_pid}" 2>/dev/null; then + echo "server exited early; log follows:" >&2 + cat "${LOG}" >&2 || true + exit 1 + fi + if curl -fsS "${BASE_URL}/health" >/dev/null 2>&1; then + break + fi + sleep 0.2 +done + +if ! 
curl -fsS "${BASE_URL}/health" >/dev/null 2>&1; then + echo "server did not become ready; log follows:" >&2 + cat "${LOG}" >&2 || true + exit 1 +fi + +request_eval() { + local response="$1" + local answer="$2" + curl -fsS \ + -X POST "${BASE_URL}/" \ + -H 'Content-Type: application/json' \ + -H 'Command: eval' \ + --data "{\"response\":\"${response}\",\"answer\":\"${answer}\",\"params\":{\"correct_response_feedback\":\"Correct!\",\"incorrect_response_feedback\":\"Try again.\"}}" +} + +echo "==> Request #1: correct answer" +resp1="$(request_eval 42 42)" +echo "${resp1}" | python3 -m json.tool + +echo "==> Request #2: wrong answer; C++ guest global state should still reset" +resp2="$(request_eval 41 42)" +echo "${resp2}" | python3 -m json.tool + +RESP1="${resp1}" RESP2="${resp2}" python3 - <<'PY' +import json, os, sys +r1 = json.loads(os.environ['RESP1'])['result'] +r2 = json.loads(os.environ['RESP2'])['result'] +checks = [ + (r1.get('is_correct') is True, 'request #1 should be correct'), + (r1.get('feedback') == 'Correct!', 'request #1 should use correct feedback'), + (r2.get('is_correct') is False, 'request #2 should be incorrect'), + (r2.get('feedback') == 'Try again.', 'request #2 should use incorrect feedback'), + (r1.get('guest_invocation_count') == 1, 'request #1 should see guest_invocation_count == 1'), + (r2.get('guest_invocation_count') == 1, 'request #2 should still see guest_invocation_count == 1'), + (r1.get('snapshot_isolation_ok') is True, 'request #1 snapshot flag should be true'), + (r2.get('snapshot_isolation_ok') is True, 'request #2 snapshot flag should be true'), +] +failed = [msg for ok, msg in checks if not ok] +if failed: + print('DEMO FAILED:', *failed, sep='\n- ', file=sys.stderr) + sys.exit(1) +print('\n✅ C++ WASM demo passed: Shimmy ran a C++ evaluator compiled to WebAssembly and restored guest state after each request.') +PY + +printf '\nServer log: %s\n' "${LOG}" From f6cf54a22660a6b01c25194fe8600e39f69d5097 Mon Sep 17 00:00:00 2001 
From: bkmashiro <53376445+bkmashiro@users.noreply.github.com> Date: Fri, 19 Jun 2026 14:05:33 +0100 Subject: [PATCH 3/7] test: cover cpp wasm example --- README.md | 4 +- examples/demo-cpp-compare/README.md | 24 ++++- internal/execution/wasm/cpp_example_test.go | 99 +++++++++++++++++++++ 3 files changed, 122 insertions(+), 5 deletions(-) create mode 100644 internal/execution/wasm/cpp_example_test.go diff --git a/README.md b/README.md index 4206de7..04a0e39 100644 --- a/README.md +++ b/README.md @@ -321,7 +321,9 @@ snapshot after each request. This gives warm reuse without leaking guest mutable state between requests. Dirty-page restore, Python runtimes, Pyodide, and package bundling are intentionally out of scope for this generic backend. -Try the state-isolation examples: +Try the state-isolation examples. These are intentionally small synthetic +evaluators for the Go/C++ artifact path; real language/runtime packaging such as +Pyodide is a separate profile/follow-up. ```shell scripts/demo-wasm.sh diff --git a/examples/demo-cpp-compare/README.md b/examples/demo-cpp-compare/README.md index c3b21d3..903316b 100644 --- a/examples/demo-cpp-compare/README.md +++ b/examples/demo-cpp-compare/README.md @@ -1,13 +1,18 @@ # C++ Compare Evaluator for Shimmy WASM -This example is a minimal C++ evaluator that compiles to a WebAssembly module and runs through Shimmy's opt-in WASM backend. +This example is a minimal, self-contained C++ evaluator that compiles to a +WebAssembly module and runs through Shimmy's opt-in WASM backend. It is not a +port of an existing Lambda Feedback repository; it is a small Go/C++-style +artifact example for validating the generic WASM execution path. 
It intentionally mirrors the shape of a simple Lambda Feedback evaluator: - input: `response`, `answer`, and optional feedback strings in `params` - output: `{ "command": "eval", "result": { "is_correct", "feedback" } }` -The evaluator also reports `guest_invocation_count` and `snapshot_isolation_ok` so the demo can prove that Shimmy reuses a warm WASM instance while restoring guest memory after each request. +The evaluator also reports `guest_invocation_count` and `snapshot_isolation_ok` +so the demo and integration test can prove that Shimmy reuses a warm WASM +instance while restoring guest memory after each request. ## Build @@ -29,7 +34,10 @@ zig c++ \ evaluator.cpp ``` -The source avoids libc/libc++ and implements only the small amount of JSON handling needed for this evaluator, so it can be built as a small freestanding WebAssembly module. Real C++ evaluators can use a richer build setup, but still need to expose the same Shimmy WASM ABI: +The source avoids libc/libc++ and implements only the small amount of JSON +handling needed for this evaluator, so it can be built as a small freestanding +WebAssembly module. Real C++ evaluators can use a richer build setup, but still +need to expose the same Shimmy WASM ABI: ```text memory @@ -47,4 +55,12 @@ From the repository root: ./scripts/demo-cpp-wasm.sh ``` -The script builds Shimmy, compiles this evaluator to `eval.wasm`, starts Shimmy with `FUNCTION_INTERFACE=wasm`, sends two HTTP requests, and asserts that both requests see `guest_invocation_count == 1`. +The script builds Shimmy, compiles this evaluator to `eval.wasm`, starts Shimmy +with `FUNCTION_INTERFACE=wasm`, sends two HTTP requests, and asserts that both +requests see `guest_invocation_count == 1`. 
+ +The Go test suite also compiles this example when `zig` is available: + +```bash +go test ./internal/execution/wasm -run TestCppCompareExample_CompilesAndRunsThroughDispatcher -v +``` diff --git a/internal/execution/wasm/cpp_example_test.go b/internal/execution/wasm/cpp_example_test.go new file mode 100644 index 0000000..b59f912 --- /dev/null +++ b/internal/execution/wasm/cpp_example_test.go @@ -0,0 +1,99 @@ +package wasm + +import ( + "context" + "os/exec" + "path/filepath" + "runtime" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func repoRootFromTest(t *testing.T) string { + t.Helper() + _, filename, _, ok := runtime.Caller(0) + require.True(t, ok, "runtime.Caller failed") + return filepath.Clean(filepath.Join(filepath.Dir(filename), "..", "..", "..")) +} + +func buildCppCompareExample(t *testing.T) string { + t.Helper() + zig, err := exec.LookPath("zig") + if err != nil { + t.Skip("zig is required to compile the C++ WASM example") + } + + root := repoRootFromTest(t) + src := filepath.Join(root, "examples", "demo-cpp-compare", "evaluator.cpp") + out := filepath.Join(t.TempDir(), "eval.wasm") + + cmd := exec.Command(zig, "c++", + "-target", "wasm32-freestanding", + "-Oz", + "-nostdlib", + "-fno-exceptions", + "-fno-rtti", + "-Wl,--no-entry", + "-Wl,--export=alloc", + "-Wl,--export=evaluate", + "-Wl,--export-memory", + "-Wl,--initial-memory=2097152", + "-o", out, + src, + ) + output, err := cmd.CombinedOutput() + require.NoError(t, err, "compile C++ WASM example:\n%s", string(output)) + return out +} + +func TestCppCompareExample_CompilesAndRunsThroughDispatcher(t *testing.T) { + modulePath := buildCppCompareExample(t) + + d := NewDispatcher(Config{ + ModulePath: modulePath, + MaxInstances: 1, + Timeout: 5 * time.Second, + }, newTestLogger(t)) + require.NoError(t, d.Start(context.Background())) + t.Cleanup(func() { _ = d.Shutdown(context.Background()) }) + + correct, err := 
d.Send(context.Background(), "eval", map[string]any{ + "response": "42", + "answer": "42", + "params": map[string]any{ + "correct_response_feedback": "Correct!", + "incorrect_response_feedback": "Try again.", + }, + }) + require.NoError(t, err) + + wrong, err := d.Send(context.Background(), "eval", map[string]any{ + "response": "41", + "answer": "42", + "params": map[string]any{ + "correct_response_feedback": "Correct!", + "incorrect_response_feedback": "Try again.", + }, + }) + require.NoError(t, err) + + correctResult, ok := correct["result"].(map[string]any) + require.True(t, ok, "correct response result must be an object: %#v", correct) + wrongResult, ok := wrong["result"].(map[string]any) + require.True(t, ok, "wrong response result must be an object: %#v", wrong) + + assert.Equal(t, "eval", correct["command"]) + assert.Equal(t, true, correctResult["is_correct"]) + assert.Equal(t, "Correct!", correctResult["feedback"]) + assert.EqualValues(t, 1, correctResult["guest_invocation_count"]) + assert.Equal(t, true, correctResult["snapshot_isolation_ok"]) + + assert.Equal(t, "eval", wrong["command"]) + assert.Equal(t, false, wrongResult["is_correct"]) + assert.Equal(t, "Try again.", wrongResult["feedback"]) + assert.EqualValues(t, 1, wrongResult["guest_invocation_count"]) + assert.Equal(t, true, wrongResult["snapshot_isolation_ok"]) +} From 8f4db4ed71fb96bcc5cab2a3ff98ac7f14fcc39c Mon Sep 17 00:00:00 2001 From: bkmashiro <53376445+bkmashiro@users.noreply.github.com> Date: Fri, 19 Jun 2026 14:23:02 +0100 Subject: [PATCH 4/7] docs: make linux wasm test environment explicit --- README.md | 9 ++++++--- examples/demo-cpp-compare/README.md | 5 ++++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 04a0e39..6038359 100644 --- a/README.md +++ b/README.md @@ -321,9 +321,12 @@ snapshot after each request. This gives warm reuse without leaking guest mutable state between requests. 
Dirty-page restore, Python runtimes, Pyodide, and package bundling are intentionally out of scope for this generic backend. -Try the state-isolation examples. These are intentionally small synthetic -evaluators for the Go/C++ artifact path; real language/runtime packaging such as -Pyodide is a separate profile/follow-up. +Try the state-isolation examples. Linux, or a Linux container, is the reference +environment for evaluator build/test recipes. The scripts also run on macOS when +the same toolchain is installed, but CI/reviewer instructions should assume +Linux by default. These are intentionally small synthetic evaluators for the +Go/C++ artifact path; real language/runtime packaging such as Pyodide is a +separate profile/follow-up. ```shell scripts/demo-wasm.sh diff --git a/examples/demo-cpp-compare/README.md b/examples/demo-cpp-compare/README.md index 903316b..510b1aa 100644 --- a/examples/demo-cpp-compare/README.md +++ b/examples/demo-cpp-compare/README.md @@ -16,7 +16,10 @@ instance while restoring guest memory after each request. ## Build -The demo uses Zig's clang-compatible C++ driver because the default macOS Apple clang does not ship a WebAssembly target: +The reference environment for this example is Linux, or a Linux container, with +Zig installed. The same command also works on macOS when Zig is installed; the +point is to rely on an explicit WASM-capable toolchain rather than the host's +default C++ compiler. 
```bash zig c++ \ From 9f6798656c3bf91a68d1609888fa4ea377f34d33 Mon Sep 17 00:00:00 2001 From: bkmashiro <53376445+bkmashiro@users.noreply.github.com> Date: Fri, 19 Jun 2026 14:31:11 +0100 Subject: [PATCH 5/7] test: cover go and rust wasm examples --- README.md | 5 + examples/demo-rust-compare/README.md | 54 ++++++ examples/demo-rust-compare/evaluator.rs | 177 ++++++++++++++++++ examples/demo-stateful/README.md | 7 + ...mple_test.go => artifact_examples_test.go} | 69 ++++++- 5 files changed, 308 insertions(+), 4 deletions(-) create mode 100644 examples/demo-rust-compare/README.md create mode 100644 examples/demo-rust-compare/evaluator.rs rename internal/execution/wasm/{cpp_example_test.go => artifact_examples_test.go} (55%) diff --git a/README.md b/README.md index 6038359..c18f3d5 100644 --- a/README.md +++ b/README.md @@ -306,6 +306,10 @@ GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o eval.wasm . # Rust cargo build --target wasm32-wasip1 --release +# Freestanding Rust +rustc --target wasm32-unknown-unknown --crate-type cdylib \ + -C panic=abort -O -o eval.wasm evaluator.rs + # C/C++ with wasi-sdk /opt/wasi-sdk/bin/clang --target=wasm32-wasip1 ... -o eval.wasm /opt/wasi-sdk/bin/clang++ --target=wasm32-wasip1 ... -o eval.wasm @@ -331,6 +335,7 @@ separate profile/follow-up. ```shell scripts/demo-wasm.sh scripts/demo-cpp-wasm.sh +go test ./internal/execution/wasm -run 'Test(GoStateful|RustCompare|CppCompare)Example_CompilesAndRunsThroughDispatcher' -v ``` ### Sandboxed Execution (Linux only, experimental) diff --git a/examples/demo-rust-compare/README.md b/examples/demo-rust-compare/README.md new file mode 100644 index 0000000..2fe203f --- /dev/null +++ b/examples/demo-rust-compare/README.md @@ -0,0 +1,54 @@ +# Rust Compare Evaluator for Shimmy WASM + +This example is a minimal, self-contained Rust evaluator that compiles to a +WebAssembly module and runs through Shimmy's opt-in WASM backend. 
It is not a +port of an existing Lambda Feedback repository; it is a small Rust artifact +example for validating the generic WASM execution path when a Rust WASM target +is installed. + +It intentionally mirrors the shape of a simple Lambda Feedback evaluator: + +- input: `response`, `answer`, and optional feedback strings in `params` +- output: `{ "command": "eval", "result": { "is_correct", "feedback" } }` + +The evaluator also reports `guest_invocation_count` and `snapshot_isolation_ok` +so the integration test can prove that Shimmy reuses a warm WASM instance while +restoring guest memory after each request. + +## Build + +The reference environment is Linux, or a Linux container, with Rust installed and +the `wasm32-unknown-unknown` target available: + +```bash +rustup target add wasm32-unknown-unknown +rustc \ + --target wasm32-unknown-unknown \ + --crate-type cdylib \ + -C panic=abort \ + -O \ + -o eval.wasm \ + evaluator.rs +``` + +The source is `#![no_std]` and implements only the small amount of JSON handling +needed for this evaluator. Real Rust evaluators can use richer build setup, but +still need to expose the same Shimmy WASM ABI: + +```text +memory +alloc(size: i32) -> i32 +evaluate(req_ptr: i32, req_len: i32) -> i32 +``` + +`evaluate` returns a pointer to `[uint32 little-endian response_len][response JSON bytes]`. + +## Test + +From the repository root: + +```bash +go test ./internal/execution/wasm -run TestRustCompareExample_CompilesAndRunsThroughDispatcher -v +``` + +The test skips when `rustc` or the `wasm32-unknown-unknown` target is unavailable. 
diff --git a/examples/demo-rust-compare/evaluator.rs b/examples/demo-rust-compare/evaluator.rs new file mode 100644 index 0000000..1ef5d6e --- /dev/null +++ b/examples/demo-rust-compare/evaluator.rs @@ -0,0 +1,177 @@ +#![no_std] +#![no_main] + +use core::panic::PanicInfo; + +static mut REQ_BUF: [u8; 256 * 1024] = [0; 256 * 1024]; +static mut RESP_BUF: [u8; 256 * 1024] = [0; 256 * 1024]; + +// Deliberately mutable guest state. Shimmy snapshots this after startup and +// restores it after every request, so each call should observe count == 1. +static mut INVOCATION_COUNT: u32 = 0; + +#[panic_handler] +fn panic(_info: &PanicInfo) -> ! { + loop {} +} + +#[no_mangle] +pub extern "C" fn alloc(_size: i32) -> i32 { + core::ptr::addr_of_mut!(REQ_BUF) as *mut u8 as i32 +} + +#[no_mangle] +pub extern "C" fn evaluate(_req_ptr: i32, req_len: i32) -> i32 { + unsafe { + INVOCATION_COUNT += 1; + } + + let req = unsafe { + core::slice::from_raw_parts(core::ptr::addr_of!(REQ_BUF) as *const u8, req_len as usize) + }; + + let response = json_string_field(req, b"response"); + let answer = json_string_field(req, b"answer"); + let correct_feedback = + json_string_field(req, b"correct_response_feedback").unwrap_or(b"Correct!"); + let incorrect_feedback = + json_string_field(req, b"incorrect_response_feedback").unwrap_or(b"Try again."); + + let is_correct = response.is_some() && answer.is_some() && response == answer; + let feedback = if is_correct { + correct_feedback + } else { + incorrect_feedback + }; + let count = unsafe { INVOCATION_COUNT }; + + write_response(is_correct, feedback, count) +} + +fn json_string_field<'a>(input: &'a [u8], name: &[u8]) -> Option<&'a [u8]> { + let key = find_key(input, name)?; + let mut i = key + name.len() + 2; // leading quote + name + trailing quote + while i < input.len() + && (input[i] == b' ' || input[i] == b'\n' || input[i] == b'\t' || input[i] == b'\r') + { + i += 1; + } + if i >= input.len() || input[i] != b':' { + return None; + } + i += 1; + 
while i < input.len() + && (input[i] == b' ' || input[i] == b'\n' || input[i] == b'\t' || input[i] == b'\r') + { + i += 1; + } + if i >= input.len() || input[i] != b'"' { + return None; + } + i += 1; + let start = i; + while i < input.len() { + match input[i] { + b'"' => return Some(&input[start..i]), + b'\\' => i += 2, + _ => i += 1, + } + } + None +} + +fn find_key(input: &[u8], name: &[u8]) -> Option<usize> { + if input.len() < name.len() + 2 { + return None; + } + let last = input.len() - name.len() - 1; + let mut i = 0; + while i < last { + if input[i] == b'"' + && &input[i + 1..i + 1 + name.len()] == name + && input[i + 1 + name.len()] == b'"' + { + return Some(i); + } + i += 1; + } + None +} + +fn write_response(is_correct: bool, feedback: &[u8], count: u32) -> i32 { + let mut w = Writer::new(); + w.bytes(b"{\"command\":\"eval\",\"result\":{\"is_correct\":"); + w.bytes(if is_correct { b"true" } else { b"false" }); + w.bytes(b",\"feedback\":\""); + w.json_string_bytes(feedback); + w.bytes(b"\",\"guest_invocation_count\":"); + w.u32(count); + w.bytes(b",\"snapshot_isolation_ok\":"); + w.bytes(if count == 1 { b"true" } else { b"false" }); + w.bytes(b"}}"); + w.finish() +} + +struct Writer { + pos: usize, +} + +impl Writer { + fn new() -> Self { + Self { pos: 4 } + } + + fn bytes(&mut self, bytes: &[u8]) { + for &b in bytes { + self.push(b); + } + } + + fn json_string_bytes(&mut self, bytes: &[u8]) { + for &b in bytes { + match b { + b'"' => self.bytes(b"\\\""), + b'\\' => self.bytes(b"\\\\"), + _ => self.push(b), + } + } + } + + fn u32(&mut self, mut n: u32) { + if n == 0 { + self.push(b'0'); + return; + } + let mut digits = [0u8; 10]; + let mut len = 0; + while n > 0 { + digits[len] = b'0' + (n % 10) as u8; + n /= 10; + len += 1; + } + while len > 0 { + len -= 1; + self.push(digits[len]); + } + } + + fn push(&mut self, b: u8) { + unsafe { + let ptr = core::ptr::addr_of_mut!(RESP_BUF) as *mut u8; + *ptr.add(self.pos) = b; + } + self.pos += 1; + } + + fn finish(self) 
-> i32 { + let len = (self.pos - 4) as u32; + unsafe { + let ptr = core::ptr::addr_of_mut!(RESP_BUF) as *mut u8; + *ptr.add(0) = (len & 0xff) as u8; + *ptr.add(1) = ((len >> 8) & 0xff) as u8; + *ptr.add(2) = ((len >> 16) & 0xff) as u8; + *ptr.add(3) = ((len >> 24) & 0xff) as u8; + ptr as i32 + } + } +} diff --git a/examples/demo-stateful/README.md b/examples/demo-stateful/README.md index ffe4c02..b38ef5a 100644 --- a/examples/demo-stateful/README.md +++ b/examples/demo-stateful/README.md @@ -13,6 +13,13 @@ Use it via the one-command demo runner: scripts/demo-wasm.sh ``` +The Go test suite also compiles this example when the local Go toolchain +supports `GOOS=wasip1` and `//go:wasmexport`: + +```bash +go test ./internal/execution/wasm -run TestGoStatefulExample_CompilesAndRunsThroughDispatcher -v +``` + What the demo shows: 1. Build Shimmy. diff --git a/internal/execution/wasm/cpp_example_test.go b/internal/execution/wasm/artifact_examples_test.go similarity index 55% rename from internal/execution/wasm/cpp_example_test.go rename to internal/execution/wasm/artifact_examples_test.go index b59f912..82afc93 100644 --- a/internal/execution/wasm/cpp_example_test.go +++ b/internal/execution/wasm/artifact_examples_test.go @@ -2,9 +2,11 @@ package wasm import ( "context" + "os" "os/exec" "path/filepath" "runtime" + "strings" "testing" "time" @@ -49,8 +51,55 @@ func buildCppCompareExample(t *testing.T) string { return out } -func TestCppCompareExample_CompilesAndRunsThroughDispatcher(t *testing.T) { - modulePath := buildCppCompareExample(t) +func buildGoStatefulExample(t *testing.T) string { + t.Helper() + goBin, err := exec.LookPath("go") + if err != nil { + t.Skip("go is required to compile the Go WASM example") + } + + root := repoRootFromTest(t) + out := filepath.Join(t.TempDir(), "eval.wasm") + cmd := exec.Command(goBin, "build", "-buildmode=c-shared", "-o", out, "./examples/demo-stateful") + cmd.Dir = root + cmd.Env = append(os.Environ(), "GOOS=wasip1", "GOARCH=wasm") 
+ output, err := cmd.CombinedOutput() + if err != nil && strings.Contains(string(output), "requires go1.24 or later") { + t.Skipf("go toolchain does not support //go:wasmexport:\n%s", string(output)) + } + require.NoError(t, err, "compile Go WASM example:\n%s", string(output)) + return out +} + +func buildRustCompareExample(t *testing.T) string { + t.Helper() + rustc, err := exec.LookPath("rustc") + if err != nil { + t.Skip("rustc is required to compile the Rust WASM example") + } + + root := repoRootFromTest(t) + src := filepath.Join(root, "examples", "demo-rust-compare", "evaluator.rs") + require.FileExists(t, src, "Rust WASM example source must exist") + out := filepath.Join(t.TempDir(), "eval.wasm") + cmd := exec.Command(rustc, + "--target", "wasm32-unknown-unknown", + "--crate-type", "cdylib", + "-C", "panic=abort", + "-O", + "-o", out, + src, + ) + output, err := cmd.CombinedOutput() + if err != nil && strings.Contains(string(output), "target may not be installed") { + t.Skipf("rust target wasm32-unknown-unknown is not installed:\n%s", string(output)) + } + require.NoError(t, err, "compile Rust WASM example:\n%s", string(output)) + return out +} + +func assertCompareEvaluatorRunsThroughDispatcher(t *testing.T, modulePath string) { + t.Helper() d := NewDispatcher(Config{ ModulePath: modulePath, @@ -87,13 +136,25 @@ func TestCppCompareExample_CompilesAndRunsThroughDispatcher(t *testing.T) { assert.Equal(t, "eval", correct["command"]) assert.Equal(t, true, correctResult["is_correct"]) - assert.Equal(t, "Correct!", correctResult["feedback"]) + assert.NotEmpty(t, correctResult["feedback"]) assert.EqualValues(t, 1, correctResult["guest_invocation_count"]) assert.Equal(t, true, correctResult["snapshot_isolation_ok"]) assert.Equal(t, "eval", wrong["command"]) assert.Equal(t, false, wrongResult["is_correct"]) - assert.Equal(t, "Try again.", wrongResult["feedback"]) + assert.NotEmpty(t, wrongResult["feedback"]) assert.EqualValues(t, 1, 
wrongResult["guest_invocation_count"]) assert.Equal(t, true, wrongResult["snapshot_isolation_ok"]) } + +func TestCppCompareExample_CompilesAndRunsThroughDispatcher(t *testing.T) { + assertCompareEvaluatorRunsThroughDispatcher(t, buildCppCompareExample(t)) +} + +func TestGoStatefulExample_CompilesAndRunsThroughDispatcher(t *testing.T) { + assertCompareEvaluatorRunsThroughDispatcher(t, buildGoStatefulExample(t)) +} + +func TestRustCompareExample_CompilesAndRunsThroughDispatcher(t *testing.T) { + assertCompareEvaluatorRunsThroughDispatcher(t, buildRustCompareExample(t)) +} From 3092053615aabec23e3249d24a94e8b482d253cb Mon Sep 17 00:00:00 2001 From: bkmashiro <53376445+bkmashiro@users.noreply.github.com> Date: Tue, 23 Jun 2026 15:31:55 +0100 Subject: [PATCH 6/7] test: add package-shaped wasm evaluator examples --- README.md | 31 ++-- examples/demo-cpp-package/Makefile | 23 +++ examples/demo-cpp-package/README.md | 35 ++++ examples/demo-cpp-package/include/compare.hpp | 11 ++ examples/demo-cpp-package/src/compare.cpp | 21 +++ examples/demo-cpp-package/src/evaluator.cpp | 139 ++++++++++++++ examples/demo-go-package/README.md | 27 +++ .../demo-go-package/cmd/evaluator/main.go | 69 +++++++ examples/demo-go-package/go.mod | 3 + .../internal/compare/compare.go | 18 ++ examples/demo-rust-package/Cargo.lock | 7 + examples/demo-rust-package/Cargo.toml | 12 ++ examples/demo-rust-package/README.md | 34 ++++ examples/demo-rust-package/src/compare.rs | 15 ++ examples/demo-rust-package/src/lib.rs | 170 ++++++++++++++++++ .../execution/wasm/artifact_examples_test.go | 74 ++++++++ 16 files changed, 674 insertions(+), 15 deletions(-) create mode 100644 examples/demo-cpp-package/Makefile create mode 100644 examples/demo-cpp-package/README.md create mode 100644 examples/demo-cpp-package/include/compare.hpp create mode 100644 examples/demo-cpp-package/src/compare.cpp create mode 100644 examples/demo-cpp-package/src/evaluator.cpp create mode 100644 examples/demo-go-package/README.md create 
mode 100644 examples/demo-go-package/cmd/evaluator/main.go create mode 100644 examples/demo-go-package/go.mod create mode 100644 examples/demo-go-package/internal/compare/compare.go create mode 100644 examples/demo-rust-package/Cargo.lock create mode 100644 examples/demo-rust-package/Cargo.toml create mode 100644 examples/demo-rust-package/README.md create mode 100644 examples/demo-rust-package/src/compare.rs create mode 100644 examples/demo-rust-package/src/lib.rs diff --git a/README.md b/README.md index c18f3d5..aadc497 100644 --- a/README.md +++ b/README.md @@ -297,29 +297,30 @@ shimmy serve `FUNCTION_COMMAND=/path/to/eval.wasm` is also accepted for compatibility, but `FUNCTION_WASM_MODULE` is clearer for new deployments. -Example build recipes: +Example build recipes, including package-shaped evaluators: ```shell -# Go -GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o eval.wasm . +# Go package/module +cd examples/demo-go-package +GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o eval.wasm ./cmd/evaluator -# Rust -cargo build --target wasm32-wasip1 --release +# Rust crate/package +cd examples/demo-rust-package +cargo build --target wasm32-unknown-unknown --release -# Freestanding Rust -rustc --target wasm32-unknown-unknown --crate-type cdylib \ - -C panic=abort -O -o eval.wasm evaluator.rs +# C++ package with Makefile + Zig's clang driver +cd examples/demo-cpp-package +make wasm OUT=eval.wasm -# C/C++ with wasi-sdk +# C/C++ with wasi-sdk also works when the project exposes the same ABI /opt/wasi-sdk/bin/clang --target=wasm32-wasip1 ... -o eval.wasm /opt/wasi-sdk/bin/clang++ --target=wasm32-wasip1 ... -o eval.wasm - -# Freestanding C++ with Zig's clang driver -zig c++ -target wasm32-freestanding -nostdlib \ - -Wl,--no-entry -Wl,--export=alloc -Wl,--export=evaluate -Wl,--export-memory \ - -o eval.wasm evaluator.cpp ``` +Shimmy intentionally does not run these build commands from `shimmy serve`. 
+Build recipes can be overridden in Makefiles, CI, Dockerfiles, or deployment +scripts; the runtime boundary remains the pre-built `eval.wasm` module. + The backend keeps a warm module instance pool and restores a full linear-memory snapshot after each request. This gives warm reuse without leaking guest mutable state between requests. Dirty-page restore, Python runtimes, Pyodide, and package @@ -335,7 +336,7 @@ separate profile/follow-up. ```shell scripts/demo-wasm.sh scripts/demo-cpp-wasm.sh -go test ./internal/execution/wasm -run 'Test(GoStateful|RustCompare|CppCompare)Example_CompilesAndRunsThroughDispatcher' -v +go test ./internal/execution/wasm -run 'Test(GoStateful|RustCompare|CppCompare|GoPackage|RustPackage|CppPackage)Example_CompilesAndRunsThroughDispatcher' -v ``` ### Sandboxed Execution (Linux only, experimental) diff --git a/examples/demo-cpp-package/Makefile b/examples/demo-cpp-package/Makefile new file mode 100644 index 0000000..12905e1 --- /dev/null +++ b/examples/demo-cpp-package/Makefile @@ -0,0 +1,23 @@ +OUT ?= eval.wasm +ZIG ?= zig + +.PHONY: wasm clean + +wasm: + $(ZIG) c++ \ + -target wasm32-freestanding \ + -Oz \ + -nostdlib \ + -fno-exceptions \ + -fno-rtti \ + -Iinclude \ + -Wl,--no-entry \ + -Wl,--export=alloc \ + -Wl,--export=evaluate \ + -Wl,--export-memory \ + -Wl,--initial-memory=2097152 \ + -o $(OUT) \ + src/evaluator.cpp src/compare.cpp + +clean: + rm -f eval.wasm diff --git a/examples/demo-cpp-package/README.md b/examples/demo-cpp-package/README.md new file mode 100644 index 0000000..f40d315 --- /dev/null +++ b/examples/demo-cpp-package/README.md @@ -0,0 +1,35 @@ +# C++ Package Evaluator for Shimmy WASM + +This example is intentionally package-shaped rather than a single source file: + +```text +Makefile +include/compare.hpp +src/evaluator.cpp +src/compare.cpp +``` + +It demonstrates the intended boundary for real evaluators: the package build +recipe emits an `eval.wasm` artifact, and Shimmy's WASM backend only loads that 
+pre-built module. + +## Build + +```bash +make wasm +``` + +The Makefile uses Zig's clang-compatible C++ driver to produce a freestanding +WebAssembly module. Override the output path with: + +```bash +make wasm OUT=/tmp/eval.wasm +``` + +## Test + +From the repository root: + +```bash +go test ./internal/execution/wasm -run TestCppPackageExample_CompilesAndRunsThroughDispatcher -v +``` diff --git a/examples/demo-cpp-package/include/compare.hpp b/examples/demo-cpp-package/include/compare.hpp new file mode 100644 index 0000000..1f60238 --- /dev/null +++ b/examples/demo-cpp-package/include/compare.hpp @@ -0,0 +1,11 @@ +#pragma once + +using usize = __SIZE_TYPE__; + +struct TextView { + const char *ptr; + usize len; +}; + +bool bytes_equal(TextView a, TextView b); +TextView feedback_for(bool is_correct, TextView correct_feedback, TextView incorrect_feedback); diff --git a/examples/demo-cpp-package/src/compare.cpp b/examples/demo-cpp-package/src/compare.cpp new file mode 100644 index 0000000..3869e63 --- /dev/null +++ b/examples/demo-cpp-package/src/compare.cpp @@ -0,0 +1,21 @@ +#include "compare.hpp" + +static const char kCorrect[] = "Correct!"; +static const char kIncorrect[] = "Try again."; + +bool bytes_equal(TextView a, TextView b) { + if (a.len != b.len) return false; + for (usize i = 0; i < a.len; i++) { + if (a.ptr[i] != b.ptr[i]) return false; + } + return true; +} + +TextView feedback_for(bool is_correct, TextView correct_feedback, TextView incorrect_feedback) { + if (is_correct) { + if (correct_feedback.ptr != nullptr) return correct_feedback; + return TextView{kCorrect, sizeof(kCorrect) - 1}; + } + if (incorrect_feedback.ptr != nullptr) return incorrect_feedback; + return TextView{kIncorrect, sizeof(kIncorrect) - 1}; +} diff --git a/examples/demo-cpp-package/src/evaluator.cpp b/examples/demo-cpp-package/src/evaluator.cpp new file mode 100644 index 0000000..956bdf5 --- /dev/null +++ b/examples/demo-cpp-package/src/evaluator.cpp @@ -0,0 +1,139 @@ +#include 
"compare.hpp" + +using u32 = unsigned int; +using i32 = int; +using uintptr = __UINTPTR_TYPE__; + +alignas(16) static char request_buffer[256 * 1024]; +alignas(16) static char response_buffer[256 * 1024]; +static u32 invocation_count = 0; + +static usize cstr_len(const char *s) { + usize n = 0; + while (s[n] != 0) n++; + return n; +} + +static void copy_bytes(char *dst, usize &pos, const char *src, usize len) { + for (usize i = 0; i < len; i++) dst[pos++] = src[i]; +} + +static void append_cstr(char *dst, usize &pos, const char *src) { + copy_bytes(dst, pos, src, cstr_len(src)); +} + +static void append_u32(char *dst, usize &pos, u32 value) { + char tmp[10]; + usize n = 0; + do { + tmp[n++] = char('0' + (value % 10)); + value /= 10; + } while (value != 0); + while (n > 0) dst[pos++] = tmp[--n]; +} + +static void append_json_string(char *dst, usize &pos, TextView src) { + dst[pos++] = '"'; + for (usize i = 0; i < src.len; i++) { + char c = src.ptr[i]; + if (c == '"' || c == '\\') { + dst[pos++] = '\\'; + dst[pos++] = c; + } else if (c == '\n') { + dst[pos++] = '\\'; + dst[pos++] = 'n'; + } else { + dst[pos++] = c; + } + } + dst[pos++] = '"'; +} + +static bool match_at(const char *json, usize len, usize pos, const char *needle) { + for (usize i = 0; needle[i] != 0; i++) { + if (pos + i >= len || json[pos + i] != needle[i]) return false; + } + return true; +} + +static bool find_json_string(const char *json, usize len, const char *key, TextView &value) { + char quoted_key[96]; + usize key_pos = 0; + quoted_key[key_pos++] = '"'; + for (usize i = 0; key[i] != 0 && key_pos + 2 < sizeof(quoted_key); i++) { + quoted_key[key_pos++] = key[i]; + } + quoted_key[key_pos++] = '"'; + quoted_key[key_pos] = 0; + + for (usize i = 0; i < len; i++) { + if (!match_at(json, len, i, quoted_key)) continue; + i += key_pos; + while (i < len && (json[i] == ' ' || json[i] == '\n' || json[i] == '\r' || json[i] == '\t')) i++; + if (i >= len || json[i] != ':') continue; + i++; + while (i < len 
&& (json[i] == ' ' || json[i] == '\n' || json[i] == '\r' || json[i] == '\t')) i++; + if (i >= len || json[i] != '"') continue; + i++; + + usize start = i; + while (i < len) { + if (json[i] == '\\') { + i += 2; + continue; + } + if (json[i] == '"') { + value = TextView{json + start, i - start}; + return true; + } + i++; + } + } + return false; +} + +static i32 write_eval_response(bool is_correct, TextView feedback) { + usize pos = 4; + append_cstr(response_buffer, pos, "{\"command\":\"eval\",\"result\":{"); + append_cstr(response_buffer, pos, "\"is_correct\":"); + append_cstr(response_buffer, pos, is_correct ? "true" : "false"); + append_cstr(response_buffer, pos, ",\"feedback\":"); + append_json_string(response_buffer, pos, feedback); + append_cstr(response_buffer, pos, ",\"guest_invocation_count\":"); + append_u32(response_buffer, pos, invocation_count); + append_cstr(response_buffer, pos, ",\"snapshot_isolation_ok\":"); + append_cstr(response_buffer, pos, invocation_count == 1 ? "true" : "false"); + append_cstr(response_buffer, pos, "}}"); + + u32 len = u32(pos - 4); + response_buffer[0] = char(len & 0xff); + response_buffer[1] = char((len >> 8) & 0xff); + response_buffer[2] = char((len >> 16) & 0xff); + response_buffer[3] = char((len >> 24) & 0xff); + return i32(uintptr(response_buffer)); +} + +extern "C" i32 alloc(i32 size) { + if (size <= 0 || usize(size) > sizeof(request_buffer)) return 0; + return i32(uintptr(request_buffer)); +} + +extern "C" i32 evaluate(i32 req_ptr, i32 req_len) { + if (req_ptr == 0 || req_len <= 0) return 0; + const char *json = reinterpret_cast<const char *>(uintptr(req_ptr)); + usize len = usize(req_len); + + TextView response{nullptr, 0}; + TextView answer{nullptr, 0}; + TextView correct_feedback{nullptr, 0}; + TextView incorrect_feedback{nullptr, 0}; + + if (!find_json_string(json, len, "response", response)) return 0; + if (!find_json_string(json, len, "answer", answer)) return 0; + find_json_string(json, len, "correct_response_feedback", 
correct_feedback); + find_json_string(json, len, "incorrect_response_feedback", incorrect_feedback); + + invocation_count++; + bool is_correct = bytes_equal(response, answer); + return write_eval_response(is_correct, feedback_for(is_correct, correct_feedback, incorrect_feedback)); +} diff --git a/examples/demo-go-package/README.md b/examples/demo-go-package/README.md new file mode 100644 index 0000000..b3710c4 --- /dev/null +++ b/examples/demo-go-package/README.md @@ -0,0 +1,27 @@ +# Go Package Evaluator for Shimmy WASM + +This example is intentionally package-shaped rather than a single source file: + +```text +go.mod +cmd/evaluator/main.go +internal/compare/compare.go +``` + +It demonstrates the intended boundary for real evaluators: the package build +recipe emits an `eval.wasm` artifact, and Shimmy's WASM backend only loads that +pre-built module. + +## Build + +```bash +GOOS=wasip1 GOARCH=wasm go build -buildmode=c-shared -o eval.wasm ./cmd/evaluator +``` + +## Test + +From the repository root: + +```bash +go test ./internal/execution/wasm -run TestGoPackageExample_CompilesAndRunsThroughDispatcher -v +``` diff --git a/examples/demo-go-package/cmd/evaluator/main.go b/examples/demo-go-package/cmd/evaluator/main.go new file mode 100644 index 0000000..143259b --- /dev/null +++ b/examples/demo-go-package/cmd/evaluator/main.go @@ -0,0 +1,69 @@ +//go:build wasip1 + +package main + +import ( + "encoding/binary" + "encoding/json" + "unsafe" + + "demo-go-package/internal/compare" +) + +var reqBuf [256 * 1024]byte +var respBuf [256 * 1024]byte +var invocationCount uint32 + +//go:wasmexport alloc +func alloc(size int32) int32 { + _ = size + return int32(uintptr(unsafe.Pointer(&reqBuf[0]))) +} + +//go:wasmexport evaluate +func evaluate(reqPtr int32, reqLen int32) int32 { + _ = reqPtr + var req struct { + Method string `json:"method"` + Params struct { + Response string `json:"response"` + Answer string `json:"answer"` + Params map[string]any `json:"params"` + } 
`json:"params"` + } + if err := json.Unmarshal(reqBuf[:reqLen], &req); err != nil { + writeResp(map[string]any{"error": map[string]any{"message": err.Error()}}) + return int32(uintptr(unsafe.Pointer(&respBuf[0]))) + } + if req.Method != "eval" { + writeResp(map[string]any{"error": map[string]any{"message": "unsupported method"}}) + return int32(uintptr(unsafe.Pointer(&respBuf[0]))) + } + + invocationCount++ + isCorrect := compare.IsCorrect(req.Params.Response, req.Params.Answer) + correctFeedback, _ := req.Params.Params["correct_response_feedback"].(string) + incorrectFeedback, _ := req.Params.Params["incorrect_response_feedback"].(string) + feedback := compare.Feedback(isCorrect, correctFeedback, incorrectFeedback) + writeResp(map[string]any{ + "command": "eval", + "result": map[string]any{ + "is_correct": isCorrect, + "feedback": feedback, + "guest_invocation_count": invocationCount, + "snapshot_isolation_ok": invocationCount == 1, + }, + }) + return int32(uintptr(unsafe.Pointer(&respBuf[0]))) +} + +func writeResp(v map[string]any) { + data, err := json.Marshal(v) + if err != nil { + data = []byte(`{"error":{"message":"marshal failed"}}`) + } + binary.LittleEndian.PutUint32(respBuf[:4], uint32(len(data))) + copy(respBuf[4:], data) +} + +func main() {} diff --git a/examples/demo-go-package/go.mod b/examples/demo-go-package/go.mod new file mode 100644 index 0000000..f04ed4f --- /dev/null +++ b/examples/demo-go-package/go.mod @@ -0,0 +1,3 @@ +module demo-go-package + +go 1.24 diff --git a/examples/demo-go-package/internal/compare/compare.go b/examples/demo-go-package/internal/compare/compare.go new file mode 100644 index 0000000..8978f98 --- /dev/null +++ b/examples/demo-go-package/internal/compare/compare.go @@ -0,0 +1,18 @@ +package compare + +func IsCorrect(response, answer string) bool { + return response == answer +} + +func Feedback(isCorrect bool, correctFeedback, incorrectFeedback string) string { + if isCorrect { + if correctFeedback != "" { + return 
correctFeedback + } + return "Correct" + } + if incorrectFeedback != "" { + return incorrectFeedback + } + return "Incorrect" +} diff --git a/examples/demo-rust-package/Cargo.lock b/examples/demo-rust-package/Cargo.lock new file mode 100644 index 0000000..70fcca9 --- /dev/null +++ b/examples/demo-rust-package/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "demo-rust-package" +version = "0.1.0" diff --git a/examples/demo-rust-package/Cargo.toml b/examples/demo-rust-package/Cargo.toml new file mode 100644 index 0000000..c413091 --- /dev/null +++ b/examples/demo-rust-package/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "demo-rust-package" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["cdylib"] + +[profile.release] +panic = "abort" +lto = true +opt-level = "z" diff --git a/examples/demo-rust-package/README.md b/examples/demo-rust-package/README.md new file mode 100644 index 0000000..c595947 --- /dev/null +++ b/examples/demo-rust-package/README.md @@ -0,0 +1,34 @@ +# Rust Package Evaluator for Shimmy WASM + +This example is intentionally crate-shaped rather than a single source file: + +```text +Cargo.toml +src/lib.rs +src/compare.rs +``` + +It demonstrates the intended boundary for real evaluators: the crate build +recipe emits a `.wasm` artifact, and Shimmy's WASM backend only loads that +pre-built module. 
+ +## Build + +```bash +rustup target add wasm32-unknown-unknown +cargo build --target wasm32-unknown-unknown --release +``` + +The output module is: + +```text +target/wasm32-unknown-unknown/release/demo_rust_package.wasm +``` + +## Test + +From the repository root: + +```bash +go test ./internal/execution/wasm -run TestRustPackageExample_CompilesAndRunsThroughDispatcher -v +``` diff --git a/examples/demo-rust-package/src/compare.rs b/examples/demo-rust-package/src/compare.rs new file mode 100644 index 0000000..23cc2a9 --- /dev/null +++ b/examples/demo-rust-package/src/compare.rs @@ -0,0 +1,15 @@ +pub fn is_correct(response: Option<&[u8]>, answer: Option<&[u8]>) -> bool { + response.is_some() && answer.is_some() && response == answer +} + +pub fn feedback<'a>( + is_correct: bool, + correct_feedback: Option<&'a [u8]>, + incorrect_feedback: Option<&'a [u8]>, +) -> &'a [u8] { + if is_correct { + correct_feedback.unwrap_or(b"Correct!") + } else { + incorrect_feedback.unwrap_or(b"Try again.") + } +} diff --git a/examples/demo-rust-package/src/lib.rs b/examples/demo-rust-package/src/lib.rs new file mode 100644 index 0000000..e54fd2d --- /dev/null +++ b/examples/demo-rust-package/src/lib.rs @@ -0,0 +1,170 @@ +#![no_std] +#![no_main] + +mod compare; + +use core::panic::PanicInfo; + +static mut REQ_BUF: [u8; 256 * 1024] = [0; 256 * 1024]; +static mut RESP_BUF: [u8; 256 * 1024] = [0; 256 * 1024]; +static mut INVOCATION_COUNT: u32 = 0; + +#[panic_handler] +fn panic(_info: &PanicInfo) -> ! 
{ + loop {} +} + +#[no_mangle] +pub extern "C" fn alloc(_size: i32) -> i32 { + core::ptr::addr_of_mut!(REQ_BUF) as *mut u8 as i32 +} + +#[no_mangle] +pub extern "C" fn evaluate(_req_ptr: i32, req_len: i32) -> i32 { + unsafe { + INVOCATION_COUNT += 1; + } + + let req = unsafe { + core::slice::from_raw_parts(core::ptr::addr_of!(REQ_BUF) as *const u8, req_len as usize) + }; + + let response = json_string_field(req, b"response"); + let answer = json_string_field(req, b"answer"); + let correct_feedback = json_string_field(req, b"correct_response_feedback"); + let incorrect_feedback = json_string_field(req, b"incorrect_response_feedback"); + + let is_correct = compare::is_correct(response, answer); + let feedback = compare::feedback(is_correct, correct_feedback, incorrect_feedback); + let count = unsafe { INVOCATION_COUNT }; + + write_response(is_correct, feedback, count) +} + +fn json_string_field<'a>(input: &'a [u8], name: &[u8]) -> Option<&'a [u8]> { + let key = find_key(input, name)?; + let mut i = key + name.len() + 2; + while i < input.len() + && (input[i] == b' ' || input[i] == b'\n' || input[i] == b'\t' || input[i] == b'\r') + { + i += 1; + } + if i >= input.len() || input[i] != b':' { + return None; + } + i += 1; + while i < input.len() + && (input[i] == b' ' || input[i] == b'\n' || input[i] == b'\t' || input[i] == b'\r') + { + i += 1; + } + if i >= input.len() || input[i] != b'"' { + return None; + } + i += 1; + let start = i; + while i < input.len() { + match input[i] { + b'"' => return Some(&input[start..i]), + b'\\' => i += 2, + _ => i += 1, + } + } + None +} + +fn find_key(input: &[u8], name: &[u8]) -> Option<usize> { + if input.len() < name.len() + 2 { + return None; + } + let last = input.len() - name.len() - 1; + let mut i = 0; + while i < last { + if input[i] == b'"' + && &input[i + 1..i + 1 + name.len()] == name + && input[i + 1 + name.len()] == b'"' + { + return Some(i); + } + i += 1; + } + None +} + +fn write_response(is_correct: bool, feedback: &[u8], 
count: u32) -> i32 { + let mut w = Writer::new(); + w.bytes(b"{\"command\":\"eval\",\"result\":{\"is_correct\":"); + w.bytes(if is_correct { b"true" } else { b"false" }); + w.bytes(b",\"feedback\":\""); + w.json_string_bytes(feedback); + w.bytes(b"\",\"guest_invocation_count\":"); + w.u32(count); + w.bytes(b",\"snapshot_isolation_ok\":"); + w.bytes(if count == 1 { b"true" } else { b"false" }); + w.bytes(b"}}"); + w.finish() +} + +struct Writer { + pos: usize, +} + +impl Writer { + fn new() -> Self { + Self { pos: 4 } + } + + fn bytes(&mut self, bytes: &[u8]) { + for &b in bytes { + self.push(b); + } + } + + fn json_string_bytes(&mut self, bytes: &[u8]) { + for &b in bytes { + match b { + b'"' => self.bytes(b"\\\""), + b'\\' => self.bytes(b"\\\\"), + _ => self.push(b), + } + } + } + + fn u32(&mut self, mut n: u32) { + if n == 0 { + self.push(b'0'); + return; + } + let mut digits = [0u8; 10]; + let mut len = 0; + while n > 0 { + digits[len] = b'0' + (n % 10) as u8; + n /= 10; + len += 1; + } + while len > 0 { + len -= 1; + self.push(digits[len]); + } + } + + fn push(&mut self, b: u8) { + unsafe { + let ptr = core::ptr::addr_of_mut!(RESP_BUF) as *mut u8; + *ptr.add(self.pos) = b; + } + self.pos += 1; + } + + fn finish(self) -> i32 { + let len = (self.pos - 4) as u32; + unsafe { + let ptr = core::ptr::addr_of_mut!(RESP_BUF) as *mut u8; + *ptr.add(0) = (len & 0xff) as u8; + *ptr.add(1) = ((len >> 8) & 0xff) as u8; + *ptr.add(2) = ((len >> 16) & 0xff) as u8; + *ptr.add(3) = ((len >> 24) & 0xff) as u8; + ptr as i32 + } + } +} diff --git a/internal/execution/wasm/artifact_examples_test.go b/internal/execution/wasm/artifact_examples_test.go index 82afc93..a25420d 100644 --- a/internal/execution/wasm/artifact_examples_test.go +++ b/internal/execution/wasm/artifact_examples_test.go @@ -98,6 +98,68 @@ func buildRustCompareExample(t *testing.T) string { return out } +func buildGoPackageExample(t *testing.T) string { + t.Helper() + goBin, err := exec.LookPath("go") + if err != 
nil { + t.Skip("go is required to compile the Go package WASM example") + } + + root := repoRootFromTest(t) + packageDir := filepath.Join(root, "examples", "demo-go-package") + require.DirExists(t, packageDir, "Go package WASM example must exist") + out := filepath.Join(t.TempDir(), "eval.wasm") + cmd := exec.Command(goBin, "build", "-buildmode=c-shared", "-o", out, "./cmd/evaluator") + cmd.Dir = packageDir + cmd.Env = append(os.Environ(), "GOOS=wasip1", "GOARCH=wasm") + output, err := cmd.CombinedOutput() + if err != nil && strings.Contains(string(output), "requires go1.24 or later") { + t.Skipf("go toolchain does not support //go:wasmexport:\n%s", string(output)) + } + require.NoError(t, err, "compile Go package WASM example:\n%s", string(output)) + return out +} + +func buildRustPackageExample(t *testing.T) string { + t.Helper() + cargo, err := exec.LookPath("cargo") + if err != nil { + t.Skip("cargo is required to compile the Rust package WASM example") + } + + root := repoRootFromTest(t) + packageDir := filepath.Join(root, "examples", "demo-rust-package") + require.DirExists(t, packageDir, "Rust package WASM example must exist") + cmd := exec.Command(cargo, "build", "--target", "wasm32-unknown-unknown", "--release") + cmd.Dir = packageDir + output, err := cmd.CombinedOutput() + if err != nil && strings.Contains(string(output), "target may not be installed") { + t.Skipf("rust target wasm32-unknown-unknown is not installed:\n%s", string(output)) + } + require.NoError(t, err, "compile Rust package WASM example:\n%s", string(output)) + return filepath.Join(packageDir, "target", "wasm32-unknown-unknown", "release", "demo_rust_package.wasm") +} + +func buildCppPackageExample(t *testing.T) string { + t.Helper() + if _, err := exec.LookPath("zig"); err != nil { + t.Skip("zig is required to compile the C++ package WASM example") + } + if _, err := exec.LookPath("make"); err != nil { + t.Skip("make is required to compile the C++ package WASM example") + } + + root := 
repoRootFromTest(t) + packageDir := filepath.Join(root, "examples", "demo-cpp-package") + require.DirExists(t, packageDir, "C++ package WASM example must exist") + out := filepath.Join(t.TempDir(), "eval.wasm") + cmd := exec.Command("make", "wasm", "OUT="+out) + cmd.Dir = packageDir + output, err := cmd.CombinedOutput() + require.NoError(t, err, "compile C++ package WASM example:\n%s", string(output)) + return out +} + func assertCompareEvaluatorRunsThroughDispatcher(t *testing.T, modulePath string) { t.Helper() @@ -158,3 +220,15 @@ func TestGoStatefulExample_CompilesAndRunsThroughDispatcher(t *testing.T) { func TestRustCompareExample_CompilesAndRunsThroughDispatcher(t *testing.T) { assertCompareEvaluatorRunsThroughDispatcher(t, buildRustCompareExample(t)) } + +func TestGoPackageExample_CompilesAndRunsThroughDispatcher(t *testing.T) { + assertCompareEvaluatorRunsThroughDispatcher(t, buildGoPackageExample(t)) +} + +func TestRustPackageExample_CompilesAndRunsThroughDispatcher(t *testing.T) { + assertCompareEvaluatorRunsThroughDispatcher(t, buildRustPackageExample(t)) +} + +func TestCppPackageExample_CompilesAndRunsThroughDispatcher(t *testing.T) { + assertCompareEvaluatorRunsThroughDispatcher(t, buildCppPackageExample(t)) +} From f5604b306ada3a26cb2e82ef45c4e6131be29d9a Mon Sep 17 00:00:00 2001 From: bkmashiro <53376445+bkmashiro@users.noreply.github.com> Date: Tue, 23 Jun 2026 16:03:22 +0100 Subject: [PATCH 7/7] fix: preserve command validation for process backends --- internal/execution/dispatcher.go | 15 ++++++++++ internal/execution/dispatcher_test.go | 42 +++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) create mode 100644 internal/execution/dispatcher_test.go diff --git a/internal/execution/dispatcher.go b/internal/execution/dispatcher.go index f8aec05..016ee05 100644 --- a/internal/execution/dispatcher.go +++ b/internal/execution/dispatcher.go @@ -39,6 +39,9 @@ type Params struct { func NewDispatcher(params Params) (dispatcher.Dispatcher, error) { 
switch params.Config.Supervisor.IO.Interface { case supervisor.RpcIO: + if err := requireProcessWorkerCommand(params.Config.Supervisor); err != nil { + return nil, err + } return dispatcher.NewDedicatedDispatcher( dispatcher.DedicatedDispatcherParams{ Config: dispatcher.DedicatedDispatcherConfig{ @@ -72,6 +75,11 @@ func NewDispatcher(params Params) (dispatcher.Dispatcher, error) { return d, nil default: + if params.Config.Supervisor.IO.Interface == supervisor.FileIO { + if err := requireProcessWorkerCommand(params.Config.Supervisor); err != nil { + return nil, err + } + } return dispatcher.NewPooledDispatcher( dispatcher.PooledDispatcherParams{ Config: dispatcher.PooledDispatcherConfig{ @@ -84,3 +92,10 @@ func NewDispatcher(params Params) (dispatcher.Dispatcher, error) { ) } } + +func requireProcessWorkerCommand(cfg supervisor.Config) error { + if strings.TrimSpace(cfg.StartParams.Cmd) == "" { + return fmt.Errorf("FUNCTION_COMMAND is required when FUNCTION_INTERFACE=%q", cfg.IO.Interface) + } + return nil +} diff --git a/internal/execution/dispatcher_test.go b/internal/execution/dispatcher_test.go new file mode 100644 index 0000000..3cbd314 --- /dev/null +++ b/internal/execution/dispatcher_test.go @@ -0,0 +1,42 @@ +package execution + +import ( + "context" + "strings" + "testing" + + "go.uber.org/zap" + + "github.com/lambda-feedback/shimmy/internal/execution/supervisor" +) + +func TestNewDispatcher_RequiresCommandForProcessInterfaces(t *testing.T) { + tests := []struct { + name string + io supervisor.IOInterface + }{ + {name: "rpc", io: supervisor.RpcIO}, + {name: "file", io: supervisor.FileIO}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewDispatcher(Params{ + Context: context.Background(), + Config: Config{ + MaxWorkers: 1, + Supervisor: supervisor.Config{ + IO: supervisor.IOConfig{Interface: tt.io}, + }, + }, + Log: zap.NewNop(), + }) + if err == nil { + t.Fatal("expected missing command error") + } + if got, want := 
err.Error(), "FUNCTION_COMMAND is required"; !strings.Contains(got, want) { + t.Fatalf("expected error to contain %q, got %q", want, got) + } + }) + } +}