diff --git a/adapter/redis.go b/adapter/redis.go index ef709c08..6058faff 100644 --- a/adapter/redis.go +++ b/adapter/redis.go @@ -258,6 +258,8 @@ type RedisServer struct { pubsub *redisPubSub scriptMu sync.RWMutex scriptCache map[string]string + luaPool *luaStatePool + luaPoolOnce sync.Once traceCommands bool traceSeq atomic.Uint64 redisAddr string @@ -406,6 +408,7 @@ func NewRedisServer(listen net.Listener, redisAddr string, store store.MVCCStore leaderClients: make(map[string]*redis.Client), pubsub: newRedisPubSub(), scriptCache: map[string]string{}, + luaPool: newLuaStatePool(), traceCommands: os.Getenv("ELASTICKV_REDIS_TRACE") == "1", baseCtx: baseCtx, baseCancel: baseCancel, diff --git a/adapter/redis_lua.go b/adapter/redis_lua.go index 7b860139..f5e7918f 100644 --- a/adapter/redis_lua.go +++ b/adapter/redis_lua.go @@ -129,10 +129,13 @@ func (r *RedisServer) runLuaScript(conn redcon.Conn, script string, evalArgs [][ return err } defer scriptCtx.Close() - state := newRedisLuaState() - defer state.Close() + + luaPool := r.getLuaPool() + pls := luaPool.get(scriptCtx) + defer luaPool.put(pls) + state := pls.state state.SetContext(ctx) - r.initLuaGlobals(state, scriptCtx, keys, argv) + r.initPooledLuaScriptGlobals(state, keys, argv) chunk, err := state.LoadString(script) if err != nil { @@ -204,32 +207,13 @@ func parseRedisEvalArgs(args [][]byte) ([][]byte, [][]byte, error) { return keys, argv, nil } -func (r *RedisServer) initLuaGlobals(state *lua.LState, ctx *luaScriptContext, keys [][]byte, argv [][]byte) { +// initPooledLuaScriptGlobals installs the per-eval globals (KEYS, ARGV) +// on a pooled *lua.LState. All shared modules (redis, cjson, cmsgpack, +// table.unpack => unpack) are wired once at pool-fill time and restored +// on release; see redis_lua_pool.go for the reset invariant. 
+func (r *RedisServer) initPooledLuaScriptGlobals(state *lua.LState, keys [][]byte, argv [][]byte) { state.SetGlobal("KEYS", makeLuaStringArray(state, keys)) state.SetGlobal("ARGV", makeLuaStringArray(state, argv)) - registerRedisModule(state, ctx) - registerCJSONModule(state) - registerCMsgpackModule(state) - - tableModule := state.GetGlobal("table") - if tbl, ok := tableModule.(*lua.LTable); ok { - if unpack := tbl.RawGetString("unpack"); unpack != lua.LNil { - state.SetGlobal("unpack", unpack) - } - } -} - -func newRedisLuaState() *lua.LState { - state := lua.NewState(lua.Options{SkipOpenLibs: true}) - openLuaLib(state, lua.BaseLibName, lua.OpenBase) - openLuaLib(state, lua.TabLibName, lua.OpenTable) - openLuaLib(state, lua.StringLibName, lua.OpenString) - openLuaLib(state, lua.MathLibName, lua.OpenMath) - - for _, name := range []string{"dofile", "load", "loadfile", "loadstring", "module", "require"} { - state.SetGlobal(name, lua.LNil) - } - return state } func openLuaLib(state *lua.LState, name string, fn lua.LGFunction) { @@ -238,35 +222,6 @@ func openLuaLib(state *lua.LState, name string, fn lua.LGFunction) { state.Call(1, 0) } -func registerRedisModule(state *lua.LState, ctx *luaScriptContext) { - module := state.NewTable() - state.SetFuncs(module, map[string]lua.LGFunction{ - "call": func(scriptState *lua.LState) int { - return luaRedisCommand(scriptState, ctx, true) - }, - "pcall": func(scriptState *lua.LState) int { - return luaRedisCommand(scriptState, ctx, false) - }, - "sha1hex": func(scriptState *lua.LState) int { - scriptState.Push(lua.LString(luaScriptSHA(scriptState.CheckString(1)))) - return 1 - }, - "status_reply": func(scriptState *lua.LState) int { - reply := scriptState.NewTable() - reply.RawSetString(luaTypeOKKey, lua.LString(scriptState.CheckString(1))) - scriptState.Push(reply) - return 1 - }, - "error_reply": func(scriptState *lua.LState) int { - reply := scriptState.NewTable() - reply.RawSetString(luaTypeErrKey, 
lua.LString(scriptState.CheckString(1))) - scriptState.Push(reply) - return 1 - }, - }) - state.SetGlobal("redis", module) -} - func luaRedisCommand(state *lua.LState, ctx *luaScriptContext, raise bool) int { if state.GetTop() == 0 { if raise { diff --git a/adapter/redis_lua_pool.go b/adapter/redis_lua_pool.go new file mode 100644 index 00000000..0b156a18 --- /dev/null +++ b/adapter/redis_lua_pool.go @@ -0,0 +1,586 @@ +package adapter + +import ( + "sync" + "sync/atomic" + + lua "github.com/yuin/gopher-lua" +) + +// luaCtxRegistryKey is the fixed registry key under which each pooled +// *lua.LState stores a pre-allocated *lua.LUserData whose .Value holds +// the per-eval *luaScriptContext. Putting the binding in the state's +// own registry (instead of a global map guarded by sync.RWMutex) means +// every redis.call / redis.pcall lookup is O(1), lock-free, and local +// to the state -- no cross-state contention even under high fan-out +// workloads like BullMQ (~50 lookups/s/script). +const luaCtxRegistryKey = "elastickv_ctx" + +// luaInitialGlobalsHint is the expected number of string-keyed +// globals present on a freshly initialised pooled state (base lib +// helpers + string/math/table tables + redis/cjson/cmsgpack + the +// nil-ed loader placeholders + unpack). Sizing the snapshot map to +// this up front avoids an internal grow during fill. +const luaInitialGlobalsHint = 64 + +// luaResetKeySlack accounts for the handful of user-added globals +// (KEYS, ARGV, and any helpers the script itself defined) that the +// reset routine has to walk. Serves only as a capacity hint for a +// scratch slice in resetPooledLuaState. +const luaResetKeySlack = 8 + +// luaWhitelistedTableHint is a capacity hint for the tableSnapshots +// map -- one entry per nested table value at init (math, string, +// table, redis, cjson, cmsgpack). 
+const luaWhitelistedTableHint = 8 + +// luaStatePool pools *lua.LState instances to cut heap/GC pressure on +// high-rate EVAL / EVALSHA workloads (e.g. BullMQ ~10 scripts/s, where +// each fresh state allocs ~34% of in-use heap via newFuncContext, +// newRegistry, newFunctionProto). +// +// Security invariant: no state must leak between scripts. Each pooled +// state is initialised with a fixed set of base globals (redis, cjson, +// cmsgpack, table/string/math + base lib helpers, and nil-ed loaders). +// Three snapshots are captured at construction time: +// +// - globalsSnapshot: the full (*any*-keyed) _G map at init. Using an +// LValue-keyed map lets the reset path catch non-string-keyed +// leaks like `_G[42] = "secret"` or `_G[true] = "bad"`, which +// would otherwise survive a naive string-only wipe. +// - tableSnapshots: a shallow map from each whitelisted nested +// table (string, math, table, redis, cjson, cmsgpack) to its +// init-time field set. This is what blocks table-poisoning +// attacks such as `string.upper = function() return "pwned" end` +// -- merely restoring the `string` *reference* on _G would leave +// the shared table's fields still mutated. +// - metatableSnapshots: the init-time raw metatable of _G plus of +// every whitelisted nested table. Without this, a script calling +// `setmetatable(_G, { __index = function() return "pwned" end })` +// could leak a poisoned fallback into the next pooled eval via +// any undefined-global access. Same risk for `setmetatable(string, +// ...)` etc. +// +// On release, the reset routine +// +// 1. restores the raw metatable of _G and every whitelisted table +// (LNil if there was none originally), neutering setmetatable +// poisoning, +// 2. walks each snapshotted nested table and restores its contents +// (deletes script-added fields, rebinds original fields), +// 3. 
walks the current global table and deletes every key -- of any +// type -- that is not present in the globals snapshot (removes +// user-added globals such as KEYS, ARGV, GLOBAL_LEAK, _G[42]), +// and +// 4. restores every globals-snapshot key to its original value (so a +// script that did `table = nil` or `redis = evil` cannot poison +// the next script). +// +// Additionally the value stack is truncated to 0 and the script +// context binding is cleared so the redis.call/pcall closures cannot +// be invoked against a stale context. +// +// The redis / cjson / cmsgpack closures are registered ONCE at pool +// fill time and read the per-eval *luaScriptContext out of each +// state's own Lua registry (see luaCtxRegistryKey / ctxBinding), +// which is set on acquire and cleared on release. Closures that +// would otherwise capture a fresh context per eval no longer need +// to be re-registered, which is what makes pooling safe and cheap. +// The registry-backed binding is also the reason redis.call is +// lock-free in the hot path, unlike the first iteration which used +// a package-level map guarded by sync.RWMutex. +type luaStatePool struct { + pool sync.Pool + + // hits / misses are exposed for tests and metrics. + hits atomic.Uint64 + misses atomic.Uint64 +} + +// pooledLuaState wraps a *lua.LState plus the immutable snapshot of +// the globals that were present after base initialisation. Everything +// NOT in globalsSnapshot is treated as user-introduced state and +// removed on release. +type pooledLuaState struct { + state *lua.LState + // globalsSnapshot is a copy of every entry reachable via the + // state's global table at init, keyed by LValue (not just string) + // so scripts cannot smuggle state across evals via non-string + // keys such as _G[42] = "secret". 
+ globalsSnapshot map[lua.LValue]lua.LValue + // tableSnapshots holds the shallow field sets of well-known + // whitelisted tables (string, math, table, redis, cjson, + // cmsgpack) captured at init. On reset we restore each to its + // original contents so a script doing e.g. + // `string.upper = function() return "pwned" end` cannot poison + // subsequent pooled reuses. + // + // The outer map is keyed by the *LTable pointer of the parent + // (e.g. the `string` table) so tableSnapshots survives even if a + // script rebinds the global name (`string = nil`) -- the reset + // restores the global name first, then restores the table's + // internal contents from this snapshot. + tableSnapshots map[*lua.LTable]map[lua.LValue]lua.LValue + // metatableSnapshots holds the init-time raw metatable of every + // snapshotted table (the globals table _G plus each entry in + // tableSnapshots). gopher-lua's base lib exposes setmetatable, so + // a script can do `setmetatable(_G, { __index = function() + // return "pwned" end })` -- the next pooled eval reading any + // undefined global would then fall through the poisoned __index. + // The same risk applies to the standard-library tables (string, + // math, ...). We restore each table's metatable on reset; if the + // original had none, we restore lua.LNil (which strips any + // metatable installed by the script). + metatableSnapshots map[*lua.LTable]lua.LValue + // ctxBinding is a pre-allocated *LUserData stashed in the state's + // registry under luaCtxRegistryKey. Its .Value holds the active + // *luaScriptContext for the duration of an eval. Using the state's + // own registry (instead of a global map + sync.RWMutex) keeps the + // redis.call / redis.pcall lookup lock-free and local, which is + // critical for high-concurrency workloads where a single script + // may issue dozens of redis.call invocations. 
+ ctxBinding *lua.LUserData + // scratchKeys is a reusable slice for collecting table keys during + // reset / resetTableContents. Each reset leaves it sliced to + // [:0] so subsequent resets reuse the underlying array. If a + // pathological script inflates it past luaScratchKeysMaxCap we + // drop the backing array to avoid pinning unbounded memory on + // pooled states. + scratchKeys []lua.LValue +} + +// luaScratchKeysMaxCap bounds the backing array retained by +// scratchKeys across resets. Beyond this we drop the slice so one +// rogue script does not inflate the pool's per-state footprint +// indefinitely. Chosen to cover typical EVAL globals comfortably +// (base stdlib + redis/cjson/cmsgpack + a handful of user globals). +const luaScratchKeysMaxCap = 1024 + +// luaLookupContext returns the *luaScriptContext bound to state for +// the current eval, reading it from the state's own registry. Because +// each pooled *lua.LState is used by at most one goroutine at a time, +// this lookup needs no synchronisation -- unlike the previous global +// map guarded by sync.RWMutex, which under BullMQ-style workloads +// (dozens of redis.call invocations per script, thousands of scripts/s) +// became a global RLock contention point. +// +// The registry entry is a pre-allocated *LUserData (see +// pooledLuaState.ctxBinding) whose .Value is mutated by bind/unbind. +// Reading it therefore amortises to a single pointer load + type +// assertion per redis.call. +func luaLookupContext(state *lua.LState) (*luaScriptContext, bool) { + ud, ok := state.GetField(state.Get(lua.RegistryIndex), luaCtxRegistryKey).(*lua.LUserData) + if !ok || ud == nil { + return nil, false + } + ctx, ok := ud.Value.(*luaScriptContext) + if !ok || ctx == nil { + return nil, false + } + return ctx, true +} + +// getLuaPool returns the RedisServer's pooled lua state pool, +// creating it on first use. 
The constructor path (NewRedisServer) +// always pre-populates r.luaPool; this lazy fallback exists so unit +// tests that construct a bare &RedisServer{} literal (common in this +// package) do not NPE the first time EVAL is exercised. +func (r *RedisServer) getLuaPool() *luaStatePool { + r.luaPoolOnce.Do(func() { + if r.luaPool == nil { + r.luaPool = newLuaStatePool() + } + }) + return r.luaPool +} + +// newLuaStatePool returns a pool that lazily allocates +// *pooledLuaState instances on demand. The pool deliberately does NOT +// set sync.Pool.New: if it did, p.pool.Get() would auto-invoke the +// constructor on an empty pool and we could not distinguish a fresh +// allocation from a reused instance. Instead, get() inspects the +// result of p.pool.Get() -- a nil return signals an empty pool and +// drives the miss counter plus an explicit newPooledLuaState() call. +// This keeps the hit/miss metrics honest, which is what the serial +// reuse tests and the observability counters rely on. +func newLuaStatePool() *luaStatePool { + return &luaStatePool{} +} + +// newPooledLuaState builds a fresh pooled state: base libs, dangerous +// loaders nil-ed, a per-state ctxBinding userdata stashed in the Lua +// registry, redis/cjson/cmsgpack closures wired to that binding, and a +// snapshot of globals for leak-free reset. +func newPooledLuaState() *pooledLuaState { + state := lua.NewState(lua.Options{SkipOpenLibs: true}) + openLuaLib(state, lua.BaseLibName, lua.OpenBase) + openLuaLib(state, lua.TabLibName, lua.OpenTable) + openLuaLib(state, lua.StringLibName, lua.OpenString) + openLuaLib(state, lua.MathLibName, lua.OpenMath) + + for _, name := range []string{"dofile", "load", "loadfile", "loadstring", "module", "require"} { + state.SetGlobal(name, lua.LNil) + } + + // Pre-allocate the per-state context binding and stash it in the + // state's registry. 
redis.call / redis.pcall read this userdata + // (lock-free, per-state) to find the active *luaScriptContext for + // the current eval. + ctxBinding := state.NewUserData() + state.SetField(state.Get(lua.RegistryIndex), luaCtxRegistryKey, ctxBinding) + + registerPooledRedisModule(state) + registerCJSONModule(state) + registerCMsgpackModule(state) + + // Expose table.unpack as the top-level `unpack` just like the + // non-pooled path in initLuaGlobals does -- keeping the base set + // identical across paths avoids subtle semantic drift. + if tableModule, ok := state.GetGlobal("table").(*lua.LTable); ok { + if unpack := tableModule.RawGetString("unpack"); unpack != lua.LNil { + state.SetGlobal("unpack", unpack) + } + } + + globalsSnapshot, tableSnapshots, metatableSnapshots := snapshotGlobals(state) + return &pooledLuaState{ + state: state, + globalsSnapshot: globalsSnapshot, + tableSnapshots: tableSnapshots, + metatableSnapshots: metatableSnapshots, + ctxBinding: ctxBinding, + } +} + +// snapshotGlobals captures the full set of globals (string AND +// non-string keys) plus shallow snapshots of every nested table value +// reachable from _G, AND the raw metatable of each of those tables +// (plus _G itself). Returning all three lets resetPooledLuaState +// defeat three classes of pool-state leaks: +// +// 1. Non-string-keyed globals. Lua allows any non-nil, non-NaN value +// as a table key. A malicious script doing `_G[42] = "secret"` or +// `_G[true] = "bad"` would persist across pool reuse if we only +// snapshotted string keys. Iterating with ForEach over LValue keys +// closes this hole. +// +// 2. Table poisoning. Standard-library tables are mutable in +// gopher-lua, and the snapshot only holds a reference to the +// table object. A script doing `string.upper = function() return +// "pwned" end` mutates the shared table in place; merely +// re-binding the global name `string` to its original LTable +// value on reset is not enough. 
We therefore shallow-snapshot +// every LTable-typed global's contents at init time and restore +// them on reset. Inner tables are not recursed into -- they are +// expected to hold leaf values (functions, numbers, strings) in +// the libraries we install; if that ever changes, extend this. +// +// 3. Metatable poisoning. gopher-lua's base library exposes +// setmetatable, so a script can do +// `setmetatable(_G, { __index = function() return "pwned" end })` +// and the next pooled eval that reads any undefined global (which +// triggers __index) would observe attacker-controlled behaviour. +// The same risk applies to every whitelisted table (string, math, +// ...). Snapshotting each table's raw metatable at init lets +// reset put the original back; when a table had no metatable, +// the snapshot holds lua.LNil and reset strips whatever the +// script installed. +// +// We deliberately skip snapshotting _G's own contents as a "table +// snapshot": _G IS the globals table, so that entry would be +// redundant with the outer globals snapshot. Any other self-reference +// is handled the same way (by *LTable identity). _G's metatable is +// still captured, because the poisoning surface applies to _G too. +// +// We read each table's metatable via the exported LTable.Metatable +// field (not state.GetMetatable) to avoid dispatching through +// __metatable -- we want the raw pointer so SetMetatable can restore +// it verbatim. +func snapshotGlobals(state *lua.LState) ( + map[lua.LValue]lua.LValue, + map[*lua.LTable]map[lua.LValue]lua.LValue, + map[*lua.LTable]lua.LValue, +) { + globals := state.G.Global + snapshot := make(map[lua.LValue]lua.LValue, luaInitialGlobalsHint) + tableSnaps := make(map[*lua.LTable]map[lua.LValue]lua.LValue, luaWhitelistedTableHint) + metaSnaps := make(map[*lua.LTable]lua.LValue, luaWhitelistedTableHint+1) + + // _G itself is a poisoning target (setmetatable(_G, ...)). 
+ metaSnaps[globals] = rawMetatable(globals) + + globals.ForEach(func(k, v lua.LValue) { + snapshot[k] = v + if tbl, ok := v.(*lua.LTable); ok && tbl != globals { + // Shallow copy the table's contents. Keys may be + // non-string (e.g. array-like entries). + inner := make(map[lua.LValue]lua.LValue, tbl.Len()+luaResetKeySlack) + tbl.ForEach(func(ik, iv lua.LValue) { + inner[ik] = iv + }) + tableSnaps[tbl] = inner + // Capture the raw metatable exactly once per *LTable. + // A given library table appears in _G under one name, so + // there is no duplication risk here in practice; even if + // there were, the value would be identical. + if _, seen := metaSnaps[tbl]; !seen { + metaSnaps[tbl] = rawMetatable(tbl) + } + } + }) + return snapshot, tableSnaps, metaSnaps +} + +// rawMetatable returns the LTable's raw metatable field, normalising a +// Go nil into lua.LNil so callers can pass the result straight to +// state.SetMetatable (which requires an LValue, not an untyped nil). +// We bypass state.GetMetatable deliberately: that path respects the +// __metatable field and can return something other than the real +// metatable, which would corrupt restore-on-reset if a script set +// __metatable = "blocked". +func rawMetatable(tbl *lua.LTable) lua.LValue { + if tbl.Metatable == nil { + return lua.LNil + } + return tbl.Metatable +} + +// resetPooledLuaState wipes all user-introduced globals and restores +// the whitelisted ones (including the contents of nested tables like +// `string`, `math`, `redis`), then truncates the value stack. It is +// the heart of the security invariant: anything the script did to +// globals must not be observable by the next user. +// +// Ordering matters: +// 1. Restore every snapshotted table's metatable FIRST. A poisoned +// __index / __newindex would otherwise intercept the subsequent +// RawSet / ForEach work we do to clean up fields. 
In practice +// RawSet bypasses metamethods already, but restoring the +// metatable first keeps any future code that uses non-raw access +// safe-by-construction. +// 2. Reset nested whitelisted tables' field sets. Doing this BEFORE +// restoring the globals' top-level bindings means we mutate the +// ORIGINAL table objects (the ones snapshot still references by +// pointer), even if the script rebound `string = nil` at the +// global level -- the original LTable is still alive and held +// via our tableSnapshots map key. +// 3. Delete top-level globals not in the snapshot (KEYS, ARGV, +// GLOBAL_LEAK, _G[42], etc). We iterate ALL key types, not just +// strings, so non-string-keyed leaks (`_G[42] = "secret"`) do not +// survive. +// 4. Restore top-level whitelisted globals. This fixes e.g. +// `redis = nil` by re-binding `redis` to the original module +// table. +func (p *pooledLuaState) reset() { + globals := p.state.G.Global + + // (1) Restore the raw metatable of every snapshotted table. + // This blocks setmetatable(_G, {__index=...}) and + // setmetatable(string, {...}) from leaking a poisoned fallback + // into the next eval. SetMetatable with lua.LNil strips any + // metatable the script installed where there was none originally. + for tbl, mt := range p.metatableSnapshots { + p.state.SetMetatable(tbl, mt) + } + + // (2) Restore inner contents of every snapshotted whitelisted + // table. This defeats poisoning attacks like + // `string.upper = function() return "pwned" end`. + // + // resetTableContents borrows p.scratchKeys as a working slice. + // We pass it in and receive the (possibly grown) backing array + // back so successive calls within the same reset share one + // allocation. + scratch := p.scratchKeys[:0] + for tbl, originalFields := range p.tableSnapshots { + scratch = resetTableContents(tbl, originalFields, scratch[:0]) + } + + // (3) Collect all current global keys (of any type). 
Mutating + // the table inside ForEach is unsafe, so snapshot keys first. + scratch = scratch[:0] + globals.ForEach(func(k, _ lua.LValue) { + scratch = append(scratch, k) + }) + + // Delete any key not in the init-time snapshot: these are + // user-introduced globals (KEYS, ARGV, GLOBAL_LEAK, _G[42], + // _G[true], ...). + // + // We use RawSet (not RawSetH) because gopher-lua stores integer + // keys in an internal `array` slice rather than `dict`; RawSetH + // only touches `dict`, so a call like RawSetH(LNumber(42), LNil) + // leaves the array entry intact. RawSet dispatches to the right + // storage by key type. + for _, k := range scratch { + if _, keep := p.globalsSnapshot[k]; !keep { + globals.RawSet(k, lua.LNil) + } + } + + // (4) Restore every whitelisted global to its original value. + // This covers the case where a script rebinds an allowed global + // (e.g. `redis = something`) -- we simply put the original back. + for k, v := range p.globalsSnapshot { + globals.RawSet(k, v) + } + + // Drop anything the script may have left on the value stack. + p.state.SetTop(0) + + // Clear any request-scoped context bound to the state via + // LState.SetContext (done in runLuaScript). Without this, the + // pooled *lua.LState keeps a reference to the previous request's + // context.Context -- and transitively anything the context retains + // (timers, cancel funcs, attached values) -- until the state is + // reused or garbage-collected. That causes memory retention and + // delays cancellation propagation for the prior request's chain. + // RemoveContext is the canonical API for this and is preferred over + // SetContext(context.Background()) for clearer intent. + p.state.RemoveContext() + + // Retain scratch for the next reset, but bound the backing array + // so a pathological script that created thousands of globals does + // not permanently bloat every pooled state. 
If we exceeded the + // cap, drop the slice -- the next reset will reallocate at the + // modest default size. + if cap(scratch) > luaScratchKeysMaxCap { + p.scratchKeys = nil + } else { + p.scratchKeys = scratch[:0] + } +} + +// resetTableContents restores tbl's entries so that it exactly +// matches originalFields: extra keys added by the script are deleted, +// and every original key is re-bound to its original value. Inner +// tables are treated as shallow: if a script mutated `string.upper`, +// the original function value (still alive via originalFields) is +// put back; if a script added a new field (`string.pwn = 1`), the +// field is deleted. +// +// scratch is a caller-provided slice used to buffer the current key +// set (we cannot mutate a table while ForEach iterates it). The +// (possibly grown) slice is returned so the caller can keep reusing +// the underlying array across invocations. +func resetTableContents(tbl *lua.LTable, originalFields map[lua.LValue]lua.LValue, scratch []lua.LValue) []lua.LValue { + currentKeys := scratch[:0] + tbl.ForEach(func(k, _ lua.LValue) { + currentKeys = append(currentKeys, k) + }) + for _, k := range currentKeys { + if _, keep := originalFields[k]; !keep { + tbl.RawSet(k, lua.LNil) + } + } + for k, v := range originalFields { + tbl.RawSet(k, v) + } + return currentKeys +} + +// get acquires a pooled state and binds the given *luaScriptContext +// so that redis.call / redis.pcall can see it. Binding is a single +// pointer write to the state-local ctxBinding userdata -- no lock, +// no global map. +// +// Because newLuaStatePool does NOT set sync.Pool.New, p.pool.Get() +// returns nil when the pool is empty; that is the signal for a miss +// (fresh allocation). A non-nil return is a genuine reuse and counts +// as a hit. The defensive type-assertion guard preserves behaviour if +// a future refactor ever puts something unexpected into the pool. 
+func (p *luaStatePool) get(ctx *luaScriptContext) *pooledLuaState { + v := p.pool.Get() + if v == nil { + p.misses.Add(1) + pls := newPooledLuaState() + pls.ctxBinding.Value = ctx + return pls + } + pls, ok := v.(*pooledLuaState) + if !ok || pls == nil { + // Defence in depth: anything other than a *pooledLuaState is + // treated as an allocation miss rather than a silent hit. + p.misses.Add(1) + pls = newPooledLuaState() + pls.ctxBinding.Value = ctx + return pls + } + p.hits.Add(1) + pls.ctxBinding.Value = ctx + return pls +} + +// put resets the state and returns it to the pool. If the state is +// somehow closed (shouldn't happen on the happy path), it is dropped +// so a dead VM is never handed out again. +func (p *luaStatePool) put(pls *pooledLuaState) { + if pls == nil || pls.state == nil { + return + } + // Clear the binding so a stale *luaScriptContext cannot be + // observed via a pooled state that is briefly re-acquired by a + // future get() before the caller writes a fresh context. + if pls.ctxBinding != nil { + pls.ctxBinding.Value = nil + } + if pls.state.IsClosed() { + return + } + pls.reset() + p.pool.Put(pls) +} + +// Hits / Misses are test hooks. They count Get() outcomes, not +// allocations proper, but in practice they track allocation avoidance +// well enough for the "is the pool actually being used?" test. +func (p *luaStatePool) Hits() uint64 { return p.hits.Load() } +func (p *luaStatePool) Misses() uint64 { return p.misses.Load() } + +// registerPooledRedisModule installs redis.call / redis.pcall / +// redis.sha1hex / redis.status_reply / redis.error_reply where the +// call/pcall closures resolve the *luaScriptContext per-invocation +// via luaLookupContext, so a single pre-registered module works for +// every eval the state is reused for. 
+func registerPooledRedisModule(state *lua.LState) { + module := state.NewTable() + state.SetFuncs(module, map[string]lua.LGFunction{ + "call": func(scriptState *lua.LState) int { + ctx, ok := luaLookupContext(scriptState) + // Must guard against ctx == nil as well as !ok: the + // bench path and misuse can luaBindContext(nil), which + // stores a (nil, true) entry. Dereferencing that in + // luaRedisCommand would panic. + if !ok || ctx == nil { + scriptState.RaiseError("redis.call invoked without an active script context") + return 0 + } + return luaRedisCommand(scriptState, ctx, true) + }, + "pcall": func(scriptState *lua.LState) int { + ctx, ok := luaLookupContext(scriptState) + if !ok || ctx == nil { + scriptState.Push(luaErrorTable(scriptState, "redis.pcall invoked without an active script context")) + return 1 + } + return luaRedisCommand(scriptState, ctx, false) + }, + "sha1hex": func(scriptState *lua.LState) int { + scriptState.Push(lua.LString(luaScriptSHA(scriptState.CheckString(1)))) + return 1 + }, + "status_reply": func(scriptState *lua.LState) int { + reply := scriptState.NewTable() + reply.RawSetString(luaTypeOKKey, lua.LString(scriptState.CheckString(1))) + scriptState.Push(reply) + return 1 + }, + "error_reply": func(scriptState *lua.LState) int { + reply := scriptState.NewTable() + reply.RawSetString(luaTypeErrKey, lua.LString(scriptState.CheckString(1))) + scriptState.Push(reply) + return 1 + }, + }) + state.SetGlobal("redis", module) +} diff --git a/adapter/redis_lua_pool_test.go b/adapter/redis_lua_pool_test.go new file mode 100644 index 00000000..f0138b7e --- /dev/null +++ b/adapter/redis_lua_pool_test.go @@ -0,0 +1,626 @@ +package adapter + +import ( + "context" + "sync" + "sync/atomic" + "testing" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" + lua "github.com/yuin/gopher-lua" +) + +// BenchmarkLuaState_NewVsPooled compares the cost of minting a brand +// new *lua.LState per call (matching the pre-pool hot path) 
against +// pulling one out of the pool and resetting. Use: +// +// go test -run='^$' -bench=BenchmarkLuaState_NewVsPooled -benchmem ./adapter/ +// +// On the author's laptop (darwin/arm64, go1.26) it shows roughly a +// 10x reduction in B/op and allocs/op for the pooled path. +func BenchmarkLuaState_NewVsPooled(b *testing.B) { + b.Run("new_state_per_call", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + s := newPooledLuaState() + // Simulate a trivial KEYS/ARGV set + small script. + s.state.SetGlobal("KEYS", s.state.NewTable()) + s.state.SetGlobal("ARGV", s.state.NewTable()) + if err := s.state.DoString(`return 1 + 1`); err != nil { + b.Fatal(err) + } + s.state.Close() + } + }) + + b.Run("pooled_state", func(b *testing.B) { + pool := newLuaStatePool() + // Prime the pool so the first iteration is a hit. + pool.put(pool.get(nil)) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + pls := pool.get(nil) + pls.state.SetGlobal("KEYS", pls.state.NewTable()) + pls.state.SetGlobal("ARGV", pls.state.NewTable()) + if err := pls.state.DoString(`return 1 + 1`); err != nil { + b.Fatal(err) + } + pool.put(pls) + } + }) +} + +// TestLua_VMReuseDoesNotLeakGlobals is the load-bearing safety test +// for the pool. Script A assigns GLOBAL_LEAK = 42 at the Lua level; +// script B then executes on a *lua.LState obtained from the same +// pool and asserts that GLOBAL_LEAK is nil. +// +// It also asserts that script B sees a fresh KEYS / ARGV and that +// the pool did hand back the same underlying *lua.LState (pool hit), +// which is the whole point of the optimisation. +func TestLua_VMReuseDoesNotLeakGlobals(t *testing.T) { + t.Parallel() + + pool := newLuaStatePool() + + // --- Script A: sets a user global ----------------------------- + plsA := pool.get(nil) // nil ctx is fine: scriptA does not call redis.call. 
+ stateA := plsA.state + require.NoError(t, stateA.DoString(`GLOBAL_LEAK = 42`)) + // Also add a random table global to stress the reset path on + // non-scalar user additions. + require.NoError(t, stateA.DoString(`LEAKY_TABLE = { x = 1, y = 2 }`)) + require.Equal(t, lua.LNumber(42), stateA.GetGlobal("GLOBAL_LEAK")) + ptrA := stateA + pool.put(plsA) + + // --- Script B: same pool, no leak ----------------------------- + // sync.Pool is free to allocate a fresh item even immediately + // after a put under race/GC, so we do not assert pointer + // identity here. To assert the pool is effective at all, see + // TestLua_PoolRecordsReuseVsAllocation which uses the hit counter. + // What we DO assert is the security invariant: whichever state + // we got, it must not observe the leaked globals from script A. + _ = ptrA + plsB := pool.get(nil) + stateB := plsB.state + + require.Equal(t, lua.LNil, stateB.GetGlobal("GLOBAL_LEAK"), + "GLOBAL_LEAK leaked from prior script -- security invariant broken") + require.Equal(t, lua.LNil, stateB.GetGlobal("LEAKY_TABLE"), + "LEAKY_TABLE leaked from prior script -- security invariant broken") + + // Whitelisted globals must still be intact for script B. + require.NotEqual(t, lua.LNil, stateB.GetGlobal("redis"), + "redis module missing after pool reuse") + require.NotEqual(t, lua.LNil, stateB.GetGlobal("cjson"), + "cjson module missing after pool reuse") + require.NotEqual(t, lua.LNil, stateB.GetGlobal("cmsgpack"), + "cmsgpack module missing after pool reuse") + require.NotEqual(t, lua.LNil, stateB.GetGlobal("string"), + "string stdlib missing after pool reuse") + + // Script B can still run normal Lua that depends on the + // whitelisted base libs. + require.NoError(t, stateB.DoString(`assert(string.upper("ok") == "OK")`)) + pool.put(plsB) + + // Pool should have registered at least one hit by now. 
+	require.GreaterOrEqual(t, pool.Hits(), uint64(1), "pool never reported a hit")
+}
+
+// TestLua_VMReuseRestoresRebindsWhitelistedGlobals guards against a
+// script that overwrites an allowed global (e.g. `redis = nil`). The
+// reset must put the original back so the next script isn't affected.
+func TestLua_VMReuseRestoresRebindsWhitelistedGlobals(t *testing.T) {
+	t.Parallel()
+
+	pool := newLuaStatePool()
+
+	plsA := pool.get(nil)
+	// Try to sabotage pooled state: wipe redis and hijack string.upper.
+	require.NoError(t, plsA.state.DoString(`redis = nil; string = { upper = function() return "pwned" end }`))
+	require.Equal(t, lua.LNil, plsA.state.GetGlobal("redis"))
+	pool.put(plsA)
+
+	plsB := pool.get(nil)
+	defer pool.put(plsB)
+	require.NotEqual(t, lua.LNil, plsB.state.GetGlobal("redis"),
+		"redis global was not restored after sabotage; security invariant broken")
+
+	// Original string lib must be restored such that string.upper works correctly.
+	require.NoError(t, plsB.state.DoString(`assert(string.upper("abc") == "ABC", "string.upper was poisoned")`))
+}
+
+// TestLua_PoolSerialAcquireReusesState verifies the pool serves
+// existing *lua.LState instances in sequential acquire/release cycles
+// -- the knob we care about for the heap-pressure win. sync.Pool is
+// free to reclaim under GC pressure, so we cannot assert on the exact
+// pointer; instead we count hits vs misses via the test hook.
+func TestLua_PoolSerialAcquireReusesState(t *testing.T) {
+	t.Parallel()
+
+	pool := newLuaStatePool()
+
+	// Prime the pool: this get takes the sole miss so loop gets can hit.
+	pool.put(pool.get(nil))
+
+	const iters = 50
+	for i := 0; i < iters; i++ {
+		pls := pool.get(nil)
+		pool.put(pls)
+	}
+	// At least one hit proves the pool is actually handing back an
+	// existing VM rather than minting a new one every time. 
+ require.GreaterOrEqual(t, pool.Hits(), uint64(1), + "pool never reported a hit; sync.Pool reuse not happening") +} + +// TestLua_PoolRecordsReuseVsAllocation pins down the "is the pool +// actually doing anything?" question via the hit/miss counters. The +// test guards against the subtle regression where sync.Pool.New is +// (re-)configured: with a New func set, p.pool.Get() on an empty +// pool would auto-construct and never return nil, so hit/miss +// tracking would be meaningless. Two sub-scenarios are exercised: +// +// 1. Miss branch: a get() on a brand-new pool has nothing to hand +// out. It must increment the miss counter (fresh allocation) and +// leave hits at zero. This is deterministic -- sync.Pool's own +// scheduling cannot turn an empty pool into a non-empty one. +// 2. Hit branch: after many put/get cycles at least one acquire +// must actually be served from the pool. sync.Pool under -race +// randomises per-P caching and can drop items, so we cannot +// assert on a single put/get round-trip; instead we run a loop +// large enough that the probability of zero reuse is negligible. +// +// If sync.Pool.New were accidentally re-introduced, the miss branch +// (step 1) would fail immediately: Misses would be 0, Hits would be 1. +func TestLua_PoolRecordsReuseVsAllocation(t *testing.T) { + t.Parallel() + + pool := newLuaStatePool() + + // Scenario 1: empty pool -> miss. Deterministic. + plsA := pool.get(nil) + require.NotNil(t, plsA, "get on empty pool must allocate a fresh state, not return nil") + require.Equal(t, uint64(0), pool.Hits(), + "empty pool must not record a hit on first acquire -- sync.Pool.New likely reintroduced") + require.Equal(t, uint64(1), pool.Misses(), + "empty pool must record exactly one miss on first acquire") + pool.put(plsA) + + // Scenario 2: with the state now available, a loop of get/put + // cycles must observe at least one genuine reuse. 
We cannot + // assert on a single round-trip because sync.Pool under -race + // may drop the freshly-put item from the local P cache; over + // many iterations, however, at least one must be served. + const iters = 500 + for i := 0; i < iters; i++ { + pool.put(pool.get(nil)) + } + require.Greater(t, pool.Hits(), uint64(0), + "pool reported zero hits across %d cycles -- reuse not happening", iters) + // The total acquires must sum to Hits + Misses = iters + 1 (the + // initial get outside the loop). This invariant catches a bug + // where get() forgets to increment either counter on some path. + require.Equal(t, uint64(iters+1), pool.Hits()+pool.Misses(), + "hit+miss counters must sum to total acquires; got hits=%d misses=%d", + pool.Hits(), pool.Misses()) +} + +// TestLua_VMReuseNonStringGlobalKeysAreWiped guards against a leak +// vector missed by the original reset: globals keyed by types other +// than string. Lua permits any non-nil, non-NaN value as a table key, +// so a script doing `_G[42] = "leak"` or `_G[true] = "bad"` bypasses a +// naive string-only snapshot/wipe. The LValue-keyed snapshot + the +// RawSet-based reset in pool.reset must catch these. RawSet (rather +// than RawSetH) matters because gopher-lua stores integer keys in the +// array part, and only RawSet dispatches to the right storage by key +// type. +func TestLua_VMReuseNonStringGlobalKeysAreWiped(t *testing.T) { + t.Parallel() + + pool := newLuaStatePool() + + plsA := pool.get(nil) + // Set non-string-keyed globals directly via _G. This is the + // attack surface being regression-tested. + require.NoError(t, plsA.state.DoString(`_G[42] = "leak"; _G[true] = "bad"`)) + // Sanity: script A sees what it set. + require.NoError(t, plsA.state.DoString(`assert(_G[42] == "leak" and _G[true] == "bad")`)) + pool.put(plsA) + + plsB := pool.get(nil) + defer pool.put(plsB) + // If either leaks, DoString errors out via Lua's assert and + // the test fails with the error message. 
+ require.NoError(t, plsB.state.DoString( + `assert(_G[42] == nil and _G[true] == nil, "non-string-keyed global leaked across pool reuse")`)) +} + +// TestLua_VMReuseDoesNotPoisonStringLib regression-tests the table +// poisoning fix. Script A mutates `string.upper` in place (not via +// rebinding the `string` global), which survives a naive snapshot +// that only restores the top-level `string` reference. The new +// tableSnapshots mechanism must restore the original `string.upper` +// function so script B's string.upper("x") == "X" holds. +func TestLua_VMReuseDoesNotPoisonStringLib(t *testing.T) { + t.Parallel() + + pool := newLuaStatePool() + + plsA := pool.get(nil) + require.NoError(t, plsA.state.DoString(` +string.upper = function() return "pwned" end +-- Sanity: script A sees its own sabotage. +assert(string.upper("x") == "pwned") +-- Add a rogue field too -- must also be cleaned up. +string.pwn = 1 +`)) + pool.put(plsA) + + plsB := pool.get(nil) + defer pool.put(plsB) + require.NoError(t, plsB.state.DoString(` +assert(string.upper("x") == "X", "string.upper was poisoned across pool reuse") +assert(string.pwn == nil, "script-added field on string leaked across pool reuse") +-- Same for other whitelisted tables. +assert(type(math.floor) == "function", "math.floor was wiped") +assert(type(table.insert) == "function", "table.insert was wiped") +`)) +} + +// TestLua_VMReuseDoesNotPoisonRedisModule covers the same poisoning +// class but on the pool-registered `redis` table itself. A script +// that replaces redis.sha1hex with a sabotaged implementation must +// not affect subsequent scripts. 
+func TestLua_VMReuseDoesNotPoisonRedisModule(t *testing.T) { + t.Parallel() + + pool := newLuaStatePool() + + plsA := pool.get(nil) + require.NoError(t, plsA.state.DoString(` +redis.sha1hex = function() return "deadbeef" end +assert(redis.sha1hex("x") == "deadbeef") +`)) + pool.put(plsA) + + plsB := pool.get(nil) + defer pool.put(plsB) + // Any non-"deadbeef" digest proves the original sha1hex is back. + require.NoError(t, plsB.state.DoString(` +local got = redis.sha1hex("x") +assert(got ~= "deadbeef", "redis.sha1hex remained poisoned after pool reuse: " .. tostring(got)) +assert(#got == 40, "redis.sha1hex returned non-hex value after reset: " .. tostring(got)) +`)) +} + +// TestLua_VMReuseDoesNotPoisonGlobalsMetatable regression-tests the +// metatable-snapshot fix. gopher-lua's base lib exposes setmetatable, +// so a script can install an __index handler on _G and poison every +// subsequent pooled eval's view of undefined globals. We verify that +// after Script A poisons _G's metatable, Script B -- acquired from +// the pool after A's state is released -- reads `_G.undefined` as +// genuine nil rather than the attacker-supplied sentinel. +// +// The test also pokes `_G[nonExisting]` via a local to make sure the +// leak path is _G's __index specifically and not something else (e.g. +// a leftover global called "undefined"). We use a freshly-minted +// symbol name on both sides to avoid interference with any snapshot +// entry the fix itself would restore. +func TestLua_VMReuseDoesNotPoisonGlobalsMetatable(t *testing.T) { + t.Parallel() + + pool := newLuaStatePool() + + plsA := pool.get(nil) + require.NoError(t, plsA.state.DoString(` +setmetatable(_G, { __index = function() return "leak" end }) +-- Sanity: script A sees its own poisoned __index. 
+assert(_G.some_never_defined_symbol == "leak", + "setmetatable on _G did not take effect inside script A") +`)) + pool.put(plsA) + + plsB := pool.get(nil) + defer pool.put(plsB) + require.NoError(t, plsB.state.DoString(` +-- An undefined global must read as nil, not the attacker sentinel. +local v = _G.some_never_defined_symbol +assert(v == nil, + "globals metatable leaked across pool reuse: got " .. tostring(v)) +-- And installing a fresh metatable on _G must still work (we didn't +-- accidentally lock _G via __metatable or anything similar). +setmetatable(_G, nil) +`)) +} + +// TestLua_VMReuseDoesNotPoisonStringMetatable covers the same +// metatable-poisoning risk, but applied to the string library table. +// gopher-lua resolves method-style calls on string literals (e.g. +// `("x"):upper()`) via the string builtin metatable, not via the +// string table's metatable -- so this test specifically guards +// against a script that installs an __index on the string table and +// relies on subsequent scripts fetching fields off that table (e.g. +// in code that walks the library dynamically). +func TestLua_VMReuseDoesNotPoisonStringMetatable(t *testing.T) { + t.Parallel() + + pool := newLuaStatePool() + + plsA := pool.get(nil) + require.NoError(t, plsA.state.DoString(` +setmetatable(string, { __index = function() return "leak" end }) +-- Sanity: the poisoned __index fires on absent fields. +assert(string.no_such_function == "leak", + "setmetatable on string did not take effect inside script A") +`)) + pool.put(plsA) + + plsB := pool.get(nil) + defer pool.put(plsB) + require.NoError(t, plsB.state.DoString(` +local v = string.no_such_function +assert(v == nil, + "string metatable leaked across pool reuse: got " .. tostring(v)) +-- string.upper must still be the genuine builtin. +assert(string.upper("x") == "X", "string.upper was damaged by reset") +`)) +} + +// TestLua_PoolNilContextProducesErrorNotPanic is the regression test +// for the nil-context nil-pointer deref. 
Before the fix, calling +// redis.call with a pool entry bound to a nil *luaScriptContext -- +// which happens in the bench path via pool.get(nil) -- would panic in +// luaRedisCommand. After the fix it surfaces as a clean Lua error. +func TestLua_PoolNilContextProducesErrorNotPanic(t *testing.T) { + t.Parallel() + + pool := newLuaStatePool() + + pls := pool.get(nil) // explicit nil context + defer pool.put(pls) + + // redis.call must raise a Lua error rather than panicking in + // Go; the returned error wraps the Lua error message. + err := pls.state.DoString(`redis.call("GET", "x")`) + require.Error(t, err, "redis.call with nil context should return an error") + require.Contains(t, err.Error(), "redis.call invoked without an active script context") + + // redis.pcall must not panic either; it should push a Lua + // error table. The DoString itself returns no Go error -- + // pcall is the pcall path -- but the returned value carries + // the err field. + require.NoError(t, pls.state.DoString(` +local reply = redis.pcall("GET", "x") +assert(type(reply) == "table", "redis.pcall should return a table even with nil context") +assert(type(reply.err) == "string", "redis.pcall error reply must carry .err") +assert(reply.err:find("redis.pcall invoked without an active script context") ~= nil, + "redis.pcall error reply text mismatch: " .. tostring(reply.err)) +`)) +} + +// TestRedis_LuaPoolNoGlobalLeakEndToEnd drives the full EVAL path on +// a live RedisServer to make sure the pool integration (not just the +// pool in isolation) holds the security invariant. Script A tries to +// leak GLOBAL_LEAK; script B asserts the leak is gone. +func TestRedis_LuaPoolNoGlobalLeakEndToEnd(t *testing.T) { + nodes, _, _ := createNode(t, 3) + defer shutdown(nodes) + + ctx := context.Background() + rdb := redis.NewClient(&redis.Options{Addr: nodes[0].redisAddress}) + defer func() { _ = rdb.Close() }() + + // Script A: set a leaking global. 
+ _, err := rdb.Eval(ctx, `GLOBAL_LEAK = 42; return 1`, nil).Result() + require.NoError(t, err) + + // Script B: assert that GLOBAL_LEAK is nil from its point of view. + // Returning the raw value would conflate nil with Redis' nil-bulk; + // instead, return a sentinel string and check. + out, err := rdb.Eval(ctx, ` +if GLOBAL_LEAK == nil then + return "clean" +else + return "leaked:" .. tostring(GLOBAL_LEAK) +end`, nil).Result() + require.NoError(t, err) + require.Equal(t, "clean", out, "pooled *lua.LState leaked a global to a subsequent script") + + // Sanity: the pooled state still supports the standard shared modules. + out2, err := rdb.Eval(ctx, `return cjson.encode({a = 1})`, nil).Result() + require.NoError(t, err) + require.Equal(t, `{"a":1}`, out2) +} + +// TestLua_PoolConcurrentContextIsolation is the regression test for +// the HIGH-priority concurrency fix. It asserts that when many +// goroutines concurrently get / bind / lookup / put pooled states, +// each goroutine's redis.call closure observes *its own* +// *luaScriptContext -- never another goroutine's context. +// +// Before the fix, the global luaStateBindings map + sync.RWMutex was +// a single contention point on every redis.call. After the fix, each +// state reads an *LUserData from its own registry, which must never +// point at a different goroutine's context even under heavy +// interleaving. Run with `go test -race -count=5 -run TestLua_Pool`. +func TestLua_PoolConcurrentContextIsolation(t *testing.T) { + t.Parallel() + + pool := newLuaStatePool() + + const ( + goroutines = 64 + lookupsPerScr = 100 + ) + + var ( + mismatches atomic.Int64 + wg sync.WaitGroup + ) + wg.Add(goroutines) + for g := 0; g < goroutines; g++ { + go func() { + defer wg.Done() + // Each goroutine uses a distinct context pointer so that + // observing a wrong-valued pointer is a detectable bug. + ownCtx := &luaScriptContext{} + pls := pool.get(ownCtx) + // Simulate many redis.call lookups inside one script. 
+ for i := 0; i < lookupsPerScr; i++ { + observed, ok := luaLookupContext(pls.state) + if !ok || observed != ownCtx { + mismatches.Add(1) + } + } + pool.put(pls) + }() + } + wg.Wait() + + require.EqualValues(t, 0, mismatches.Load(), + "concurrent goroutines observed a wrong context via luaLookupContext -- state-local binding is broken") +} + +// TestLua_PoolContextIsRegistryBacked asserts the binding lives in the +// state's own Lua registry -- the very thing that frees us from the +// global sync.RWMutex. If a refactor ever reintroduces a global map, +// this test pins down the contract. +func TestLua_PoolContextIsRegistryBacked(t *testing.T) { + t.Parallel() + + pool := newLuaStatePool() + ctx := &luaScriptContext{} + pls := pool.get(ctx) + defer pool.put(pls) + + ud, ok := pls.state.GetField(pls.state.Get(lua.RegistryIndex), luaCtxRegistryKey).(*lua.LUserData) + require.True(t, ok, "ctx binding userdata missing from state registry") + require.Same(t, pls.ctxBinding, ud, "registry userdata differs from pooledLuaState.ctxBinding") + storedCtx, ok := ud.Value.(*luaScriptContext) + require.True(t, ok, "registry userdata value is not a *luaScriptContext") + require.Same(t, ctx, storedCtx, + "registry userdata value does not point at the bound script context") +} + +// TestLua_PoolScratchKeysReused verifies the MEDIUM allocation fix. +// After a reset, pooledLuaState.scratchKeys must retain a non-nil +// backing array (sliced to zero length) so the next reset reuses it +// instead of minting a new one. We also verify the luaScratchKeysMaxCap +// bound kicks in for pathological scripts. +func TestLua_PoolScratchKeysReused(t *testing.T) { + t.Parallel() + + pls := newPooledLuaState() + + // First reset primes scratchKeys from nil to a real backing array. 
+ pls.reset() + require.NotNil(t, pls.scratchKeys, + "scratchKeys still nil after reset; no reuse buffer was retained") + require.Equal(t, 0, len(pls.scratchKeys), + "scratchKeys must be reset to zero length for reuse") + firstCap := cap(pls.scratchKeys) + require.Greater(t, firstCap, 0, "scratchKeys capacity must be > 0 after priming") + + // Second reset must reuse the same backing array (cap unchanged + // or grown, never shrunk). + pls.reset() + require.GreaterOrEqual(t, cap(pls.scratchKeys), firstCap, + "scratchKeys backing array was discarded between resets; no reuse") + + // Force the cap-bound path: manually push scratchKeys past the + // bound, reset, and assert it is dropped. + pls.scratchKeys = make([]lua.LValue, 0, luaScratchKeysMaxCap+16) + pls.reset() + require.LessOrEqual(t, cap(pls.scratchKeys), luaScratchKeysMaxCap, + "scratchKeys was not bounded; pathological scripts can pin unbounded memory") +} + +// BenchmarkLuaLookupContext_Concurrent measures the cost of the +// redis.call context lookup under high fan-out. This is the bench the +// Gemini reviewer called out: ~50 lookups/script/s across concurrent +// scripts used to hammer a global RWMutex. Now it should be a +// lock-free per-state read. +// +// go test -run='^$' -bench=BenchmarkLuaLookupContext_Concurrent -benchtime=5s ./adapter/ +func BenchmarkLuaLookupContext_Concurrent(b *testing.B) { + pool := newLuaStatePool() + // Prime a handful of states so pool.Get is warm. + for i := 0; i < 8; i++ { + pool.put(pool.get(&luaScriptContext{})) + } + + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + ctx := &luaScriptContext{} + for pb.Next() { + pls := pool.get(ctx) + // Simulate 5 redis.call invocations per script. 
+ for i := 0; i < 5; i++ { + if got, ok := luaLookupContext(pls.state); !ok || got != ctx { + b.Fatalf("wrong ctx observed: got=%p want=%p ok=%v", got, ctx, ok) + } + } + pool.put(pls) + } + }) +} + +// TestLua_VMReuseClearsContext verifies that a pooled *lua.LState +// does NOT retain a reference to a previous request's +// context.Context after it has been returned to the pool via +// pool.put. +// +// runLuaScript binds a per-request context onto the state with +// LState.SetContext (redis_lua.go). If pooledLuaState.reset() fails +// to clear that binding, the pooled VM keeps the prior ctx alive +// until it is either reused or garbage collected -- that retains +// any timers / cancel funcs / attached values referenced by the +// context. The reset() path must call LState.RemoveContext (or +// equivalently SetContext(context.Background())) to prevent this. +// +// We check both conditions: +// 1. After put, LState.Context() must NOT return the original +// request ctx (identity compare). +// 2. After put, LState.Context() must be nil (RemoveContext's +// documented post-state). +func TestLua_VMReuseClearsContext(t *testing.T) { + t.Parallel() + + pool := newLuaStatePool() + pls := pool.get(nil) + + // Simulate what runLuaScript does: attach a request-scoped ctx + // to the state. Use WithCancel so we have a distinct, non-Background + // identity that the state would measurably retain if reset() is + // broken. + reqCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + pls.state.SetContext(reqCtx) + + // Sanity: the binding was actually observed by the state. + require.Same(t, reqCtx, pls.state.Context(), + "precondition: SetContext must bind the given ctx identity") + + // Release back to the pool. This is the code path that must + // clear the ctx retention. + pool.put(pls) + + // After put, the pooled state must have dropped the ctx reference. 
+ got := pls.state.Context() + require.Nil(t, got, + "pooled LState must not retain a ctx reference after reset/put") + // Belt-and-braces identity check: even if a future gopher-lua + // version ever returns a non-nil Background-style ctx here, it + // must NOT be the original request ctx. + if got != nil { + require.NotSame(t, reqCtx, got, + "pooled LState leaked the original request ctx across put") + } +}