Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions adapter/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ type RedisServer struct {
pubsub *redisPubSub
scriptMu sync.RWMutex
scriptCache map[string]string
luaPool *luaStatePool
luaPoolOnce sync.Once
traceCommands bool
traceSeq atomic.Uint64
redisAddr string
Expand Down Expand Up @@ -406,6 +408,7 @@ func NewRedisServer(listen net.Listener, redisAddr string, store store.MVCCStore
leaderClients: make(map[string]*redis.Client),
pubsub: newRedisPubSub(),
scriptCache: map[string]string{},
luaPool: newLuaStatePool(),
traceCommands: os.Getenv("ELASTICKV_REDIS_TRACE") == "1",
baseCtx: baseCtx,
baseCancel: baseCancel,
Expand Down
67 changes: 11 additions & 56 deletions adapter/redis_lua.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,13 @@ func (r *RedisServer) runLuaScript(conn redcon.Conn, script string, evalArgs [][
return err
}
defer scriptCtx.Close()
state := newRedisLuaState()
defer state.Close()

luaPool := r.getLuaPool()
pls := luaPool.get(scriptCtx)
defer luaPool.put(pls)
state := pls.state
state.SetContext(ctx)
r.initLuaGlobals(state, scriptCtx, keys, argv)
r.initPooledLuaScriptGlobals(state, keys, argv)

chunk, err := state.LoadString(script)
if err != nil {
Expand Down Expand Up @@ -204,32 +207,13 @@ func parseRedisEvalArgs(args [][]byte) ([][]byte, [][]byte, error) {
return keys, argv, nil
}

func (r *RedisServer) initLuaGlobals(state *lua.LState, ctx *luaScriptContext, keys [][]byte, argv [][]byte) {
// initPooledLuaScriptGlobals installs the per-eval globals (KEYS, ARGV)
// on a pooled *lua.LState. All shared modules (redis, cjson, cmsgpack,
// table.unpack => unpack) are wired once at pool-fill time and restored
// on release; see redis_lua_pool.go for the reset invariant.
func (r *RedisServer) initPooledLuaScriptGlobals(state *lua.LState, keys [][]byte, argv [][]byte) {
state.SetGlobal("KEYS", makeLuaStringArray(state, keys))
state.SetGlobal("ARGV", makeLuaStringArray(state, argv))
registerRedisModule(state, ctx)
registerCJSONModule(state)
registerCMsgpackModule(state)

tableModule := state.GetGlobal("table")
if tbl, ok := tableModule.(*lua.LTable); ok {
if unpack := tbl.RawGetString("unpack"); unpack != lua.LNil {
state.SetGlobal("unpack", unpack)
}
}
}

func newRedisLuaState() *lua.LState {
state := lua.NewState(lua.Options{SkipOpenLibs: true})
openLuaLib(state, lua.BaseLibName, lua.OpenBase)
openLuaLib(state, lua.TabLibName, lua.OpenTable)
openLuaLib(state, lua.StringLibName, lua.OpenString)
openLuaLib(state, lua.MathLibName, lua.OpenMath)

for _, name := range []string{"dofile", "load", "loadfile", "loadstring", "module", "require"} {
state.SetGlobal(name, lua.LNil)
}
return state
}

func openLuaLib(state *lua.LState, name string, fn lua.LGFunction) {
Expand All @@ -238,35 +222,6 @@ func openLuaLib(state *lua.LState, name string, fn lua.LGFunction) {
state.Call(1, 0)
}

func registerRedisModule(state *lua.LState, ctx *luaScriptContext) {
module := state.NewTable()
state.SetFuncs(module, map[string]lua.LGFunction{
"call": func(scriptState *lua.LState) int {
return luaRedisCommand(scriptState, ctx, true)
},
"pcall": func(scriptState *lua.LState) int {
return luaRedisCommand(scriptState, ctx, false)
},
"sha1hex": func(scriptState *lua.LState) int {
scriptState.Push(lua.LString(luaScriptSHA(scriptState.CheckString(1))))
return 1
},
"status_reply": func(scriptState *lua.LState) int {
reply := scriptState.NewTable()
reply.RawSetString(luaTypeOKKey, lua.LString(scriptState.CheckString(1)))
scriptState.Push(reply)
return 1
},
"error_reply": func(scriptState *lua.LState) int {
reply := scriptState.NewTable()
reply.RawSetString(luaTypeErrKey, lua.LString(scriptState.CheckString(1)))
scriptState.Push(reply)
return 1
},
})
state.SetGlobal("redis", module)
}

func luaRedisCommand(state *lua.LState, ctx *luaScriptContext, raise bool) int {
if state.GetTop() == 0 {
if raise {
Expand Down
Loading
Loading