From 9f7e79fda1c10eea18d0a492c48c976c651231cd Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Fri, 24 Apr 2026 05:05:44 +0900 Subject: [PATCH 1/9] feat(sqs): scaffold SQS-compatible adapter Start the SQS-compatible HTTP adapter per docs/design/2026_04_24_proposed_sqs_compatible_adapter.md. This first increment wires the server skeleton, dispatch table, error envelope, health endpoints, and leader-proxy path so subsequent PRs can drop real handlers into a tested frame. Included: - adapter/sqs.go: SQSServer struct, NewSQSServer with SQSServerOption, Run/Stop, /sqs_health + /sqs_leader_health, X-Amz-Target dispatch for all Milestone-1/2 SQS operations (CreateQueue ... ListQueueTags), and the JSON-1.0 error envelope (__type + x-amzn-ErrorType + message). Every target currently replies 501 NotImplemented; handlers will be filled in separately. - adapter/sqs_test.go: health endpoints, method gating, error-envelope shape, unknown-target -> InvalidAction, every known target -> NotImplemented, and a Stop-unblocks-Run regression guard. - adapter/leader_http_proxy.go: shared proxyHTTPRequestToLeader helper factored out of the DynamoDB adapter. DynamoDB and SQS both delegate to it via per-adapter error writers, removing the structural dup the linter flagged. No main.go wiring or flag changes in this PR; that lands alongside the first real handler so the binary does not expose a 501-only endpoint. --- adapter/dynamodb.go | 36 ++--- adapter/leader_http_proxy.go | 64 +++++++++ adapter/sqs.go | 257 +++++++++++++++++++++++++++++++++++ adapter/sqs_test.go | 223 ++++++++++++++++++++++++++++++ 4 files changed, 552 insertions(+), 28 deletions(-) create mode 100644 adapter/leader_http_proxy.go create mode 100644 adapter/sqs.go create mode 100644 adapter/sqs_test.go diff --git a/adapter/dynamodb.go b/adapter/dynamodb.go index 0de6aa393..00a18d981 100644 --- a/adapter/dynamodb.go +++ b/adapter/dynamodb.go @@ -12,8 +12,6 @@ import ( "math/big" "net" "net/http" - "net/http/httputil" - "net/url" "slices" "sort" "strconv" @@ -294,33 +292,15 @@ func (d *DynamoDBServer) Stop() { // (either proxied or an error response was written), false if the request // should be handled locally (i.e. this node is the leader or no leader map is // configured). +// +// Serving reads or writes locally on a follower would expose G2-item-realtime +// stale reads, so every follower request is forwarded to the leader. func (d *DynamoDBServer) proxyToLeader(w http.ResponseWriter, r *http.Request) bool { - if len(d.leaderDynamo) == 0 || d.coordinator == nil { - return false - } - if d.coordinator.IsLeader() { - return false - } - // This node is a follower. All requests must be forwarded to the leader to - // preserve linearizability — serving reads or writes locally on a follower - // causes stale-read anomalies (G2-item-realtime). 
- leader := d.coordinator.RaftLeader() - if leader == "" { - writeDynamoError(w, http.StatusServiceUnavailable, dynamoErrInternal, "no raft leader currently available") - return true - } - targetAddr, ok := d.leaderDynamo[leader] - if !ok || strings.TrimSpace(targetAddr) == "" { - writeDynamoError(w, http.StatusServiceUnavailable, dynamoErrInternal, "leader dynamo address not found") - return true - } - target := &url.URL{Scheme: "http", Host: targetAddr} - proxy := httputil.NewSingleHostReverseProxy(target) - proxy.ErrorHandler = func(rw http.ResponseWriter, _ *http.Request, err error) { - writeDynamoError(rw, http.StatusServiceUnavailable, dynamoErrInternal, "leader proxy error: "+err.Error()) - } - proxy.ServeHTTP(w, r) - return true + return proxyHTTPRequestToLeader(d.coordinator, d.leaderDynamo, dynamoLeaderProxyErrorWriter, w, r) +} + +func dynamoLeaderProxyErrorWriter(w http.ResponseWriter, status int, message string) { + writeDynamoError(w, status, dynamoErrInternal, message) } func (d *DynamoDBServer) handle(w http.ResponseWriter, r *http.Request) { diff --git a/adapter/leader_http_proxy.go b/adapter/leader_http_proxy.go new file mode 100644 index 000000000..8483698b8 --- /dev/null +++ b/adapter/leader_http_proxy.go @@ -0,0 +1,64 @@ +package adapter + +import ( + "net/http" + "net/http/httputil" + "net/url" + "strings" + + "github.com/bootjp/elastickv/kv" +) + +// httpLeaderErrorWriter writes an adapter-specific error envelope (JSON for +// DynamoDB/SQS, XML for S3, RESP for Redis, ...) when the HTTP leader proxy +// cannot forward a follower request. The adapter owns the shape of the body; +// this helper only decides when to call it. +type httpLeaderErrorWriter func(w http.ResponseWriter, status int, message string) + +// proxyHTTPRequestToLeader forwards r to the current Raft leader's HTTP +// adapter endpoint when this node is a follower. It returns true when the +// request was handled (either proxied or an error body was written) and +// false when the caller should serve it locally (no leader map configured, +// no coordinator, or this node is the leader). +// +// Error paths: +// 1. no known Raft leader → 503 via errWriter("no raft leader currently available") +// 2. leader id missing from leaderMap → 503 via errWriter("leader address not found") +// 3. reverse-proxy dial/copy failure → 503 via errWriter("leader proxy error: ") +// +// leaderMap keys are Raft addresses; values are the matching adapter HTTP +// addresses exported on that same node. +func proxyHTTPRequestToLeader( + coordinator kv.Coordinator, + leaderMap map[string]string, + errWriter httpLeaderErrorWriter, + w http.ResponseWriter, + r *http.Request, +) bool { + if len(leaderMap) == 0 || coordinator == nil { + return false + } + if coordinator.IsLeader() { + return false + } + // Follower ingress: forward to the leader so reads and writes both see a + // single serialization point. Serving locally on a follower would expose + // G2-item-realtime stale reads. 
+ leader := coordinator.RaftLeader() + if leader == "" { + errWriter(w, http.StatusServiceUnavailable, "no raft leader currently available") + return true + } + targetAddr, ok := leaderMap[leader] + if !ok || strings.TrimSpace(targetAddr) == "" { + errWriter(w, http.StatusServiceUnavailable, "leader address not found") + return true + } + target := &url.URL{Scheme: "http", Host: targetAddr} + proxy := httputil.NewSingleHostReverseProxy(target) + proxy.ErrorHandler = func(rw http.ResponseWriter, _ *http.Request, err error) { + errWriter(rw, http.StatusServiceUnavailable, "leader proxy error: "+err.Error()) + } + proxy.ServeHTTP(w, r) + return true +} diff --git a/adapter/sqs.go b/adapter/sqs.go new file mode 100644 index 000000000..545d4fd81 --- /dev/null +++ b/adapter/sqs.go @@ -0,0 +1,257 @@ +package adapter + +import ( + "context" + "io" + "net" + "net/http" + "time" + + "github.com/bootjp/elastickv/kv" + "github.com/bootjp/elastickv/store" + "github.com/cockroachdb/errors" + json "github.com/goccy/go-json" +) + +// SQS target prefix for the JSON-1.0 protocol. Every supported operation is +// dispatched by the X-Amz-Target header, mirroring the DynamoDB adapter. +const sqsTargetPrefix = "AmazonSQS." + +const ( + sqsCreateQueueTarget = sqsTargetPrefix + "CreateQueue" + sqsDeleteQueueTarget = sqsTargetPrefix + "DeleteQueue" + sqsListQueuesTarget = sqsTargetPrefix + "ListQueues" + sqsGetQueueUrlTarget = sqsTargetPrefix + "GetQueueUrl" + sqsGetQueueAttributesTarget = sqsTargetPrefix + "GetQueueAttributes" + sqsSetQueueAttributesTarget = sqsTargetPrefix + "SetQueueAttributes" + sqsPurgeQueueTarget = sqsTargetPrefix + "PurgeQueue" + sqsSendMessageTarget = sqsTargetPrefix + "SendMessage" + sqsSendMessageBatchTarget = sqsTargetPrefix + "SendMessageBatch" + sqsReceiveMessageTarget = sqsTargetPrefix + "ReceiveMessage" + sqsDeleteMessageTarget = sqsTargetPrefix + "DeleteMessage" + sqsDeleteMessageBatchTarget = sqsTargetPrefix + "DeleteMessageBatch" + sqsChangeMessageVisibilityTarget = sqsTargetPrefix + "ChangeMessageVisibility" + sqsChangeMessageVisibilityBatchTgt = sqsTargetPrefix + "ChangeMessageVisibilityBatch" + sqsTagQueueTarget = sqsTargetPrefix + "TagQueue" + sqsUntagQueueTarget = sqsTargetPrefix + "UntagQueue" + sqsListQueueTagsTarget = sqsTargetPrefix + "ListQueueTags" +) + +const ( + sqsHealthPath = "/sqs_health" + sqsLeaderHealthPath = "/sqs_leader_health" +) + +const ( + sqsHealthMaxRequestBodyBytes = 1024 + sqsMaxRequestBodyBytes = 1 << 20 + sqsContentTypeJSON = "application/x-amz-json-1.0" +) + +// AWS SQS error codes used by the JSON protocol. The canonical list is on the +// "Common Errors" page of the SQS API reference. +const ( + sqsErrInvalidAction = "InvalidAction" + sqsErrNotImplemented = "NotImplemented" + sqsErrInternalFailure = "InternalFailure" + sqsErrServiceUnavailable = "ServiceUnavailable" + sqsErrMalformedRequest = "MalformedQueryString" +) + +type SQSServerOption func(*SQSServer) + +type SQSServer struct { + listen net.Listener + store store.MVCCStore + coordinator kv.Coordinator + httpServer *http.Server + targetHandlers map[string]func(http.ResponseWriter, *http.Request) + leaderSQS map[string]string +} + +// WithSQSLeaderMap configures the Raft-address-to-SQS-address mapping used to +// forward requests from followers to the current leader. Format mirrors +// WithDynamoDBLeaderMap / WithS3LeaderMap. 
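+//
+// For example (addresses are illustrative, not taken from this repo's
+// configs), a node whose Raft address is 10.0.0.1:50000 and whose SQS
+// listener is 10.0.0.1:29000 would appear as:
+//
+//	WithSQSLeaderMap(map[string]string{"10.0.0.1:50000": "10.0.0.1:29000"})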
+func WithSQSLeaderMap(m map[string]string) SQSServerOption { + return func(s *SQSServer) { + s.leaderSQS = make(map[string]string, len(m)) + for k, v := range m { + s.leaderSQS[k] = v + } + } +} + +func NewSQSServer(listen net.Listener, st store.MVCCStore, coordinate kv.Coordinator, opts ...SQSServerOption) *SQSServer { + s := &SQSServer{ + listen: listen, + store: st, + coordinator: coordinate, + } + s.targetHandlers = map[string]func(http.ResponseWriter, *http.Request){ + sqsCreateQueueTarget: s.notImplemented("CreateQueue"), + sqsDeleteQueueTarget: s.notImplemented("DeleteQueue"), + sqsListQueuesTarget: s.notImplemented("ListQueues"), + sqsGetQueueUrlTarget: s.notImplemented("GetQueueUrl"), + sqsGetQueueAttributesTarget: s.notImplemented("GetQueueAttributes"), + sqsSetQueueAttributesTarget: s.notImplemented("SetQueueAttributes"), + sqsPurgeQueueTarget: s.notImplemented("PurgeQueue"), + sqsSendMessageTarget: s.notImplemented("SendMessage"), + sqsSendMessageBatchTarget: s.notImplemented("SendMessageBatch"), + sqsReceiveMessageTarget: s.notImplemented("ReceiveMessage"), + sqsDeleteMessageTarget: s.notImplemented("DeleteMessage"), + sqsDeleteMessageBatchTarget: s.notImplemented("DeleteMessageBatch"), + sqsChangeMessageVisibilityTarget: s.notImplemented("ChangeMessageVisibility"), + sqsChangeMessageVisibilityBatchTgt: s.notImplemented("ChangeMessageVisibilityBatch"), + sqsTagQueueTarget: s.notImplemented("TagQueue"), + sqsUntagQueueTarget: s.notImplemented("UntagQueue"), + sqsListQueueTagsTarget: s.notImplemented("ListQueueTags"), + } + mux := http.NewServeMux() + mux.HandleFunc("/", s.handle) + s.httpServer = &http.Server{Handler: mux, ReadHeaderTimeout: time.Second} + for _, opt := range opts { + if opt != nil { + opt(s) + } + } + return s +} + +func (s *SQSServer) Run() error { + if err := s.httpServer.Serve(s.listen); err != nil && !errors.Is(err, http.ErrServerClosed) { + return errors.WithStack(err) + } + return nil +} + +func (s *SQSServer) Stop() { + if s.httpServer != nil { + _ = s.httpServer.Shutdown(context.Background()) + } +} + +func (s *SQSServer) handle(w http.ResponseWriter, r *http.Request) { + if s.serveHealthz(w, r) { + return + } + if s.proxyToLeader(w, r) { + return + } + + if r.Method != http.MethodPost { + w.Header().Set("Allow", http.MethodPost) + writeSQSError(w, http.StatusMethodNotAllowed, sqsErrMalformedRequest, "SQS JSON protocol requires POST") + return + } + + target := r.Header.Get("X-Amz-Target") + handler, ok := s.targetHandlers[target] + if !ok { + writeSQSError(w, http.StatusBadRequest, sqsErrInvalidAction, "unsupported SQS target: "+target) + return + } + handler(w, r) +} + +func (s *SQSServer) serveHealthz(w http.ResponseWriter, r *http.Request) bool { + if r == nil || r.URL == nil { + return false + } + switch r.URL.Path { + case sqsHealthPath: + if r.Body != nil { + r.Body = http.MaxBytesReader(w, r.Body, sqsHealthMaxRequestBodyBytes) + } + serveSQSHealthz(w, r) + return true + case sqsLeaderHealthPath: + if r.Body != nil { + r.Body = http.MaxBytesReader(w, r.Body, sqsHealthMaxRequestBodyBytes) + } + s.serveSQSLeaderHealthz(w, r) + return true + default: + return false + } +} + +func serveSQSHealthz(w http.ResponseWriter, r *http.Request) { + if !writeSQSHealthMethod(w, r) { + return + } + writeSQSHealthBody(w, r, http.StatusOK, "ok\n") +} + +func (s *SQSServer) serveSQSLeaderHealthz(w http.ResponseWriter, r *http.Request) { + if !writeSQSHealthMethod(w, r) { + return + } + if isVerifiedSQSLeader(s.coordinator) { + writeSQSHealthBody(w, r, 
http.StatusOK, "ok\n")
+		return
+	}
+	writeSQSHealthBody(w, r, http.StatusServiceUnavailable, "not leader\n")
+}
+
+func isVerifiedSQSLeader(coordinator kv.Coordinator) bool {
+	if coordinator == nil || !coordinator.IsLeader() {
+		return false
+	}
+	return coordinator.VerifyLeader() == nil
+}
+
+func writeSQSHealthMethod(w http.ResponseWriter, r *http.Request) bool {
+	switch r.Method {
+	case http.MethodGet, http.MethodHead:
+		return true
+	default:
+		w.Header().Set("Allow", "GET, HEAD")
+		http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
+		return false
+	}
+}
+
+func writeSQSHealthBody(w http.ResponseWriter, r *http.Request, statusCode int, body string) {
+	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
+	w.WriteHeader(statusCode)
+	if r.Method == http.MethodHead {
+		return
+	}
+	_, _ = io.WriteString(w, body)
+}
+
+// proxyToLeader forwards the HTTP request to the current SQS leader when this
+// node is not the Raft leader. Returns true if the request was handled.
+func (s *SQSServer) proxyToLeader(w http.ResponseWriter, r *http.Request) bool {
+	return proxyHTTPRequestToLeader(s.coordinator, s.leaderSQS, sqsLeaderProxyErrorWriter, w, r)
+}
+
+func sqsLeaderProxyErrorWriter(w http.ResponseWriter, status int, message string) {
+	writeSQSError(w, status, sqsErrServiceUnavailable, message)
+}
+
+// notImplemented returns a handler that responds with a JSON-protocol
+// NotImplemented error so clients get a clean signal while the real handlers
+// are still being built out.
+func (s *SQSServer) notImplemented(op string) func(http.ResponseWriter, *http.Request) {
+	return func(w http.ResponseWriter, _ *http.Request) {
+		writeSQSError(w, http.StatusNotImplemented, sqsErrNotImplemented, op+" is not implemented yet")
+	}
+}
+
+// writeSQSError emits an SQS JSON-protocol error envelope. AWS returns:
+//
+//	{ "__type": "<code>", "message": "<message>" }
+//
+// with Content-Type application/x-amz-json-1.0 and the x-amzn-ErrorType header
+// set to the code. SDKs key off x-amzn-ErrorType first, the body second.
+func writeSQSError(w http.ResponseWriter, status int, code string, message string) {
+	resp := map[string]string{"message": message}
+	if code != "" {
+		resp["__type"] = code
+		w.Header().Set("x-amzn-ErrorType", code)
+	}
+	w.Header().Set("Content-Type", sqsContentTypeJSON)
+	w.WriteHeader(status)
+	_ = json.NewEncoder(w).Encode(resp)
+}
diff --git a/adapter/sqs_test.go b/adapter/sqs_test.go
new file mode 100644
index 000000000..078d4f05f
--- /dev/null
+++ b/adapter/sqs_test.go
@@ -0,0 +1,223 @@
+package adapter
+
+import (
+	"context"
+	"io"
+	"net"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+	"time"
+
+	json "github.com/goccy/go-json"
+)
+
+// startTestSQSServer starts an SQSServer on an ephemeral localhost listener
+// with no coordinator or leader map — enough to exercise the health
+// endpoints, the dispatch table, and the error envelope. Real flows with Raft
+// come in later PRs alongside the handler implementations.
+func startTestSQSServer(t *testing.T) string { + t.Helper() + lc := &net.ListenConfig{} + listener, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + srv := NewSQSServer(listener, nil, nil) + go func() { + _ = srv.Run() + }() + t.Cleanup(func() { + srv.Stop() + }) + return "http://" + listener.Addr().String() +} + +func doRequest(t *testing.T, method, url string, body io.Reader, headers map[string]string) *http.Response { + t.Helper() + req, err := http.NewRequestWithContext(context.Background(), method, url, body) + if err != nil { + t.Fatalf("request: %v", err) + } + for k, v := range headers { + req.Header.Set(k, v) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("do: %v", err) + } + return resp +} + +func postSQSRequest(t *testing.T, url string, target string, body string) *http.Response { + t.Helper() + headers := map[string]string{"Content-Type": sqsContentTypeJSON} + if target != "" { + headers["X-Amz-Target"] = target + } + return doRequest(t, http.MethodPost, url, strings.NewReader(body), headers) +} + +func TestSQSServer_HealthOK(t *testing.T) { + t.Parallel() + base := startTestSQSServer(t) + + resp := doRequest(t, http.MethodGet, base+sqsHealthPath, nil, nil) + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("status: got %d want %d", resp.StatusCode, http.StatusOK) + } + body, _ := io.ReadAll(resp.Body) + if got := strings.TrimSpace(string(body)); got != "ok" { + t.Fatalf("body: got %q want %q", got, "ok") + } +} + +func TestSQSServer_LeaderHealthWithoutCoordinatorIsNotLeader(t *testing.T) { + t.Parallel() + base := startTestSQSServer(t) + + resp := doRequest(t, http.MethodGet, base+sqsLeaderHealthPath, nil, nil) + defer resp.Body.Close() + if resp.StatusCode != http.StatusServiceUnavailable { + t.Fatalf("status: got %d want %d", resp.StatusCode, http.StatusServiceUnavailable) + } +} + +func TestSQSServer_HealthRejectsNonGetHead(t *testing.T) { + t.Parallel() + base := startTestSQSServer(t) + + resp := doRequest(t, http.MethodPost, base+sqsHealthPath, nil, nil) + defer resp.Body.Close() + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Fatalf("status: got %d want %d", resp.StatusCode, http.StatusMethodNotAllowed) + } + if allow := resp.Header.Get("Allow"); !strings.Contains(allow, "GET") { + t.Fatalf("allow header: got %q", allow) + } +} + +func TestSQSServer_UnknownTargetReturnsInvalidAction(t *testing.T) { + t.Parallel() + base := startTestSQSServer(t) + + resp := postSQSRequest(t, base+"/", "AmazonSQS.NoSuchOperation", "{}") + defer resp.Body.Close() + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status: got %d want %d", resp.StatusCode, http.StatusBadRequest) + } + if got := resp.Header.Get("x-amzn-ErrorType"); got != sqsErrInvalidAction { + t.Fatalf("error type: got %q want %q", got, sqsErrInvalidAction) + } + var body map[string]string + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + t.Fatalf("decode: %v", err) + } + if body["__type"] != sqsErrInvalidAction { + t.Fatalf("body __type: got %q want %q", body["__type"], sqsErrInvalidAction) + } +} + +func TestSQSServer_KnownTargetsReturnNotImplemented(t *testing.T) { + t.Parallel() + base := startTestSQSServer(t) + + targets := []string{ + sqsCreateQueueTarget, + sqsDeleteQueueTarget, + sqsListQueuesTarget, + sqsGetQueueUrlTarget, + sqsGetQueueAttributesTarget, + sqsSetQueueAttributesTarget, + sqsPurgeQueueTarget, + sqsSendMessageTarget, + sqsSendMessageBatchTarget, + 
sqsReceiveMessageTarget, + sqsDeleteMessageTarget, + sqsDeleteMessageBatchTarget, + sqsChangeMessageVisibilityTarget, + sqsChangeMessageVisibilityBatchTgt, + sqsTagQueueTarget, + sqsUntagQueueTarget, + sqsListQueueTagsTarget, + } + for _, target := range targets { + t.Run(target, func(t *testing.T) { + t.Parallel() + resp := postSQSRequest(t, base+"/", target, "{}") + defer resp.Body.Close() + if resp.StatusCode != http.StatusNotImplemented { + t.Fatalf("status: got %d want %d", resp.StatusCode, http.StatusNotImplemented) + } + if got := resp.Header.Get("x-amzn-ErrorType"); got != sqsErrNotImplemented { + t.Fatalf("error type: got %q want %q", got, sqsErrNotImplemented) + } + }) + } +} + +func TestSQSServer_RejectsNonPostOnRoot(t *testing.T) { + t.Parallel() + base := startTestSQSServer(t) + + resp := doRequest(t, http.MethodGet, base+"/", nil, nil) + defer resp.Body.Close() + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Fatalf("status: got %d want %d", resp.StatusCode, http.StatusMethodNotAllowed) + } + if allow := resp.Header.Get("Allow"); !strings.Contains(allow, http.MethodPost) { + t.Fatalf("allow header: got %q", allow) + } +} + +func TestSQSServer_ErrorEnvelopeShape(t *testing.T) { + t.Parallel() + // We check the precise envelope shape via httptest so that SDK parsers + // that key off x-amzn-ErrorType + __type + message do not regress. + rec := httptest.NewRecorder() + writeSQSError(rec, http.StatusBadRequest, sqsErrInvalidAction, "oops") + if got := rec.Header().Get("Content-Type"); got != sqsContentTypeJSON { + t.Fatalf("content-type: got %q", got) + } + if got := rec.Header().Get("x-amzn-ErrorType"); got != sqsErrInvalidAction { + t.Fatalf("error type header: got %q", got) + } + var body map[string]string + if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil { + t.Fatalf("decode: %v", err) + } + if body["__type"] != sqsErrInvalidAction { + t.Fatalf("__type: got %q", body["__type"]) + } + if body["message"] != "oops" { + t.Fatalf("message: got %q", body["message"]) + } +} + +// TestSQSServer_StopShutsDown guards against regressions where Stop leaves the +// goroutine leaking; after Stop, Run must return promptly. +func TestSQSServer_StopShutsDown(t *testing.T) { + t.Parallel() + lc := &net.ListenConfig{} + listener, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + srv := NewSQSServer(listener, nil, nil) + done := make(chan error, 1) + go func() { done <- srv.Run() }() + // Give Serve a chance to enter its accept loop. + time.Sleep(20 * time.Millisecond) + srv.Stop() + + select { + case err := <-done: + if err != nil { + t.Fatalf("run: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Run did not return within timeout after Stop") + } +} From e46ada90690d47a789aeeb7a18afdb0465d2b9b6 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Fri, 24 Apr 2026 14:59:03 +0900 Subject: [PATCH 2/9] feat(sqs): wire SQS adapter into main.go and test harness Add the minimal plumbing to bring an SQSServer online next to the existing DynamoDB and S3 servers. No new handlers yet - every target still returns NotImplemented - but the binary and the test harness can now open an SQS endpoint so subsequent PRs that add real operations have somewhere to land. - shard_config.go: ErrInvalidRaftSQSMapEntry + parseRaftSQSMap. - main.go: --sqsAddress / --raftSqsMap flags, leaderSQS in runtimeConfig, buildLeaderSQS, and startSQSServer wired into runtimeServerRunner.start. sqsAddress defaults to empty. 
- main_sqs.go: startSQSServer mirrors main_s3.go two-goroutine shutdown pattern. - adapter/test_util.go: portsAdress / listeners / Node / newNode / setupNodes carry an sqsAddress and *SQSServer per node, shutdown calls Stop via an extracted shutdownNode helper, portSQS starts at 29000 so tests do not collide with the dynamo range at 28000. - main_bootstrap_e2e_test.go: update parseRuntimeConfig signature. --- adapter/test_util.go | 82 ++++++++++++++++++++++++++++---------- main.go | 23 ++++++++++- main_bootstrap_e2e_test.go | 2 +- main_sqs.go | 58 +++++++++++++++++++++++++++ shard_config.go | 5 +++ 5 files changed, 146 insertions(+), 24 deletions(-) create mode 100644 main_sqs.go diff --git a/adapter/test_util.go b/adapter/test_util.go index b43f024af..7dcf22ac2 100644 --- a/adapter/test_util.go +++ b/adapter/test_util.go @@ -52,28 +52,35 @@ func newTestFactory() raftengine.Factory { func shutdown(nodes []Node) { for _, n := range nodes { - if n.opsCancel != nil { - n.opsCancel() - } - n.grpcServer.Stop() - if n.grpcService != nil { - if err := n.grpcService.Close(); err != nil { - log.Printf("grpc service close: %v", err) - } - } - n.redisServer.Stop() - if n.dynamoServer != nil { - n.dynamoServer.Stop() + shutdownNode(n) + } +} + +func shutdownNode(n Node) { + if n.opsCancel != nil { + n.opsCancel() + } + n.grpcServer.Stop() + if n.grpcService != nil { + if err := n.grpcService.Close(); err != nil { + log.Printf("grpc service close: %v", err) } - if n.engine != nil { - if err := n.engine.Close(); err != nil { - log.Printf("engine close: %v", err) - } + } + n.redisServer.Stop() + if n.dynamoServer != nil { + n.dynamoServer.Stop() + } + if n.sqsServer != nil { + n.sqsServer.Stop() + } + if n.engine != nil { + if err := n.engine.Close(); err != nil { + log.Printf("engine close: %v", err) } - if n.closeFactory != nil { - if err := n.closeFactory(); err != nil { - log.Printf("factory close: %v", err) - } + } + if n.closeFactory != nil { + if err := n.closeFactory(); err != nil { + log.Printf("factory close: %v", err) } } } @@ -83,10 +90,12 @@ type portsAdress struct { raft int redis int dynamo int + sqs int grpcAddress string raftAddress string redisAddress string dynamoAddress string + sqsAddress string } const ( @@ -95,6 +104,7 @@ const ( raftPort = 50000 redisPort = 63790 dynamoPort = 28000 + sqsPort = 29000 ) var mu sync.Mutex @@ -102,12 +112,14 @@ var portGrpc atomic.Int32 var portRaft atomic.Int32 var portRedis atomic.Int32 var portDynamo atomic.Int32 +var portSQS atomic.Int32 func init() { portGrpc.Store(raftPort) portRaft.Store(grpcPort) portRedis.Store(redisPort) portDynamo.Store(dynamoPort) + portSQS.Store(sqsPort) } func portAssigner() portsAdress { @@ -117,15 +129,18 @@ func portAssigner() portsAdress { rp := portRaft.Add(1) rd := portRedis.Add(1) dn := portDynamo.Add(1) + sq := portSQS.Add(1) return portsAdress{ grpc: int(gp), raft: int(rp), redis: int(rd), dynamo: int(dn), + sqs: int(sq), grpcAddress: net.JoinHostPort("localhost", strconv.Itoa(int(gp))), raftAddress: net.JoinHostPort("localhost", strconv.Itoa(int(rp))), redisAddress: net.JoinHostPort("localhost", strconv.Itoa(int(rd))), dynamoAddress: net.JoinHostPort("localhost", strconv.Itoa(int(dn))), + sqsAddress: net.JoinHostPort("localhost", strconv.Itoa(int(sq))), } } @@ -134,25 +149,29 @@ type Node struct { raftAddress string redisAddress string dynamoAddress string + sqsAddress string grpcServer *grpc.Server grpcService *GRPCServer redisServer *RedisServer dynamoServer *DynamoDBServer + sqsServer *SQSServer opsCancel 
context.CancelFunc engine raftengine.Engine closeFactory func() error } -func newNode(grpcAddress, raftAddress, redisAddress, dynamoAddress string, engine raftengine.Engine, closeFactory func() error, grpcs *grpc.Server, grpcService *GRPCServer, rd *RedisServer, ds *DynamoDBServer, opsCancel context.CancelFunc) Node { +func newNode(grpcAddress, raftAddress, redisAddress, dynamoAddress, sqsAddress string, engine raftengine.Engine, closeFactory func() error, grpcs *grpc.Server, grpcService *GRPCServer, rd *RedisServer, ds *DynamoDBServer, sq *SQSServer, opsCancel context.CancelFunc) Node { return Node{ grpcAddress: grpcAddress, raftAddress: raftAddress, redisAddress: redisAddress, dynamoAddress: dynamoAddress, + sqsAddress: sqsAddress, grpcServer: grpcs, grpcService: grpcService, redisServer: rd, dynamoServer: ds, + sqsServer: sq, opsCancel: opsCancel, engine: engine, closeFactory: closeFactory, @@ -183,6 +202,7 @@ type listeners struct { grpc net.Listener redis net.Listener dynamo net.Listener + sqs net.Listener } func bindListeners(ctx context.Context, lc *net.ListenConfig, port portsAdress) (portsAdress, listeners, bool, error) { @@ -213,10 +233,22 @@ func bindListeners(ctx context.Context, lc *net.ListenConfig, port portsAdress) return port, listeners{}, false, errors.WithStack(err) } + sqsSock, err := lc.Listen(ctx, "tcp", port.sqsAddress) + if err != nil { + _ = grpcSock.Close() + _ = redisSock.Close() + _ = dynamoSock.Close() + if errors.Is(err, unix.EADDRINUSE) { + return port, listeners{}, true, nil + } + return port, listeners{}, false, errors.WithStack(err) + } + return port, listeners{ grpc: grpcSock, redis: redisSock, dynamo: dynamoSock, + sqs: sqsSock, }, false, nil } @@ -386,6 +418,7 @@ func setupNodes(t *testing.T, ctx context.Context, n int, ports []portsAdress) ( grpcSock := lis[i].grpc redisSock := lis[i].redis dynamoSock := lis[i].dynamo + sqsSock := lis[i].sqs result, err := factory.Create(raftengine.FactoryConfig{ LocalID: strconv.Itoa(i), @@ -433,17 +466,24 @@ func setupNodes(t *testing.T, ctx context.Context, n int, ports []portsAdress) ( assert.NoError(t, ds.Run()) }() + sq := NewSQSServer(sqsSock, routedStore, coordinator) + go func() { + assert.NoError(t, sq.Run()) + }() + nodes = append(nodes, newNode( port.grpcAddress, port.raftAddress, port.redisAddress, port.dynamoAddress, + port.sqsAddress, result.Engine, result.Close, s, gs, rd, ds, + sq, opsCancel, )) } diff --git a/main.go b/main.go index 4306388a8..a9bb6dbac 100644 --- a/main.go +++ b/main.go @@ -79,6 +79,7 @@ var ( s3Region = flag.String("s3Region", "us-east-1", "S3 signing region") s3CredsFile = flag.String("s3CredentialsFile", "", "Path to a JSON file containing static S3 credentials") s3PathStyleOnly = flag.Bool("s3PathStyleOnly", true, "Only accept path-style S3 requests") + sqsAddr = flag.String("sqsAddress", "", "TCP host+port for SQS-compatible API; empty to disable") metricsAddr = flag.String("metricsAddress", "localhost:9090", "TCP host+port for Prometheus metrics") metricsToken = flag.String("metricsToken", "", "Bearer token for Prometheus metrics; required for non-loopback metricsAddress") pprofAddr = flag.String("pprofAddress", "localhost:6060", "TCP host+port for pprof debug endpoints; empty to disable") @@ -93,6 +94,7 @@ var ( raftRedisMap = flag.String("raftRedisMap", "", "Map of Raft address to Redis address (raftAddr=redisAddr,...)") raftS3Map = flag.String("raftS3Map", "", "Map of Raft address to S3 address (raftAddr=s3Addr,...)") raftDynamoMap = flag.String("raftDynamoMap", "", "Map of 
Raft address to DynamoDB address (raftAddr=dynamoAddr,...)") + raftSqsMap = flag.String("raftSqsMap", "", "Map of Raft address to SQS address (raftAddr=sqsAddr,...)") ) func main() { @@ -212,6 +214,8 @@ func run() error { s3Region: *s3Region, s3CredsFile: *s3CredsFile, s3PathStyleOnly: *s3PathStyleOnly, + sqsAddress: *sqsAddr, + leaderSQS: cfg.leaderSQS, metricsAddress: *metricsAddr, metricsToken: *metricsToken, pprofAddress: *pprofAddr, @@ -238,7 +242,7 @@ func resolveRuntimeInputs() (runtimeConfig, raftEngineType, []raftengine.Server, return runtimeConfig{}, "", nil, false, err } - cfg, err := parseRuntimeConfig(*myAddr, *redisAddr, *s3Addr, *dynamoAddr, *raftGroups, *shardRanges, *raftRedisMap, *raftS3Map, *raftDynamoMap) + cfg, err := parseRuntimeConfig(*myAddr, *redisAddr, *s3Addr, *dynamoAddr, *sqsAddr, *raftGroups, *shardRanges, *raftRedisMap, *raftS3Map, *raftDynamoMap, *raftSqsMap) if err != nil { return runtimeConfig{}, "", nil, false, err } @@ -258,10 +262,11 @@ type runtimeConfig struct { leaderRedis map[string]string leaderS3 map[string]string leaderDynamo map[string]string + leaderSQS map[string]string multi bool } -func parseRuntimeConfig(myAddr, redisAddr, s3Addr, dynamoAddr, raftGroups, shardRanges, raftRedisMap, raftS3Map, raftDynamoMap string) (runtimeConfig, error) { +func parseRuntimeConfig(myAddr, redisAddr, s3Addr, dynamoAddr, sqsAddr, raftGroups, shardRanges, raftRedisMap, raftS3Map, raftDynamoMap, raftSqsMap string) (runtimeConfig, error) { groups, err := parseRaftGroups(raftGroups, myAddr) if err != nil { return runtimeConfig{}, errors.Wrapf(err, "failed to parse raft groups") @@ -288,6 +293,10 @@ func parseRuntimeConfig(myAddr, redisAddr, s3Addr, dynamoAddr, raftGroups, shard if err != nil { return runtimeConfig{}, errors.Wrapf(err, "failed to parse raft dynamo map") } + leaderSQS, err := buildLeaderSQS(groups, sqsAddr, raftSqsMap) + if err != nil { + return runtimeConfig{}, errors.Wrapf(err, "failed to parse raft sqs map") + } return runtimeConfig{ groups: groups, @@ -296,6 +305,7 @@ func parseRuntimeConfig(myAddr, redisAddr, s3Addr, dynamoAddr, raftGroups, shard leaderRedis: leaderRedis, leaderS3: leaderS3, leaderDynamo: leaderDynamo, + leaderSQS: leaderSQS, multi: len(groups) > 1, }, nil } @@ -316,6 +326,10 @@ func buildLeaderS3(groups []groupSpec, s3Addr string, raftS3Map string) (map[str return buildLeaderAddrMap(groups, s3Addr, raftS3Map, parseRaftS3Map) } +func buildLeaderSQS(groups []groupSpec, sqsAddr string, raftSqsMap string) (map[string]string, error) { + return buildLeaderAddrMap(groups, sqsAddr, raftSqsMap, parseRaftSQSMap) +} + func buildLeaderDynamo(groups []groupSpec, dynamoAddr string, raftDynamoMap string) (map[string]string, error) { return buildLeaderAddrMap(groups, dynamoAddr, raftDynamoMap, parseRaftDynamoMap) } @@ -824,6 +838,8 @@ type runtimeServerRunner struct { s3Region string s3CredsFile string s3PathStyleOnly bool + sqsAddress string + leaderSQS map[string]string metricsAddress string metricsToken string pprofAddress string @@ -856,6 +872,9 @@ func (r runtimeServerRunner) start() error { if err := startS3Server(r.ctx, r.lc, r.eg, r.s3Address, r.shardStore, r.coordinate, r.leaderS3, r.s3Region, r.s3CredsFile, r.s3PathStyleOnly, r.readTracker); err != nil { return waitErrgroupAfterStartupFailure(r.cancel, r.eg, err) } + if err := startSQSServer(r.ctx, r.lc, r.eg, r.sqsAddress, r.shardStore, r.coordinate, r.leaderSQS); err != nil { + return waitErrgroupAfterStartupFailure(r.cancel, r.eg, err) + } if err := startMetricsServer(r.ctx, r.lc, 
r.eg, r.metricsAddress, r.metricsToken, r.metricsRegistry.Handler()); err != nil { return waitErrgroupAfterStartupFailure(r.cancel, r.eg, err) } diff --git a/main_bootstrap_e2e_test.go b/main_bootstrap_e2e_test.go index b6040bb45..92474ddbc 100644 --- a/main_bootstrap_e2e_test.go +++ b/main_bootstrap_e2e_test.go @@ -340,7 +340,7 @@ func startBootstrapE2ENode( bootstrapMembers string, engineType raftEngineType, ) (*bootstrapE2ENode, error) { - cfg, err := parseRuntimeConfig(ep.raftAddr, ep.redisAddr, "", "", "", "", "", "", "") + cfg, err := parseRuntimeConfig(ep.raftAddr, ep.redisAddr, "", "", "", "", "", "", "", "", "") if err != nil { return nil, err } diff --git a/main_sqs.go b/main_sqs.go new file mode 100644 index 000000000..a6371e96e --- /dev/null +++ b/main_sqs.go @@ -0,0 +1,58 @@ +package main + +import ( + "context" + "net" + "strings" + + "github.com/bootjp/elastickv/adapter" + "github.com/bootjp/elastickv/kv" + "github.com/cockroachdb/errors" + "golang.org/x/sync/errgroup" +) + +func startSQSServer( + ctx context.Context, + lc *net.ListenConfig, + eg *errgroup.Group, + sqsAddr string, + shardStore *kv.ShardStore, + coordinate kv.Coordinator, + leaderSQS map[string]string, +) error { + sqsAddr = strings.TrimSpace(sqsAddr) + if sqsAddr == "" { + return nil + } + sqsL, err := lc.Listen(ctx, "tcp", sqsAddr) + if err != nil { + return errors.Wrapf(err, "failed to listen on %s", sqsAddr) + } + sqsServer := adapter.NewSQSServer( + sqsL, + shardStore, + coordinate, + adapter.WithSQSLeaderMap(leaderSQS), + ) + // Two-goroutine shutdown pattern mirrors startS3Server: one goroutine waits + // on either ctx.Done() or Run completion to call Stop, the other runs the + // server and cancels the waiter once it has returned. + runDoneCtx, runDoneCancel := context.WithCancel(context.Background()) + eg.Go(func() error { + select { + case <-ctx.Done(): + sqsServer.Stop() + case <-runDoneCtx.Done(): + } + return nil + }) + eg.Go(func() error { + err := sqsServer.Run() + runDoneCancel() + if err == nil || errors.Is(err, net.ErrClosed) { + return nil + } + return errors.WithStack(err) + }) + return nil +} diff --git a/shard_config.go b/shard_config.go index a7353038b..eee13d204 100644 --- a/shard_config.go +++ b/shard_config.go @@ -32,6 +32,7 @@ var ( ErrInvalidRaftRedisMapEntry = errors.New("invalid raftRedisMap entry") ErrInvalidRaftS3MapEntry = errors.New("invalid raftS3Map entry") ErrInvalidRaftDynamoMapEntry = errors.New("invalid raftDynamoMap entry") + ErrInvalidRaftSQSMapEntry = errors.New("invalid raftSqsMap entry") ErrInvalidRaftBootstrapMembersEntry = errors.New("invalid raftBootstrapMembers entry") ) @@ -128,6 +129,10 @@ func parseRaftDynamoMap(raw string) (map[string]string, error) { return parseRaftAddressMap(raw, ErrInvalidRaftDynamoMapEntry) } +func parseRaftSQSMap(raw string) (map[string]string, error) { + return parseRaftAddressMap(raw, ErrInvalidRaftSQSMapEntry) +} + func parseRaftAddressMap(raw string, invalidEntry error) (map[string]string, error) { out := make(map[string]string) if raw == "" { From 94904526437de895063154e5fd3b92344e288e04 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Fri, 24 Apr 2026 15:10:13 +0900 Subject: [PATCH 3/9] feat(sqs): implement queue catalog CRUD Wire up the first real SQS handlers on top of the existing MVCC / Coordinator primitives: - CreateQueue (idempotent when attributes match, QueueNameExists otherwise; FIFO suffix + FifoQueue attribute cross-validated; every supported attribute validated against its AWS range) - DeleteQueue 
(tombstones the meta record and bumps the generation counter so a
  subsequent CreateQueue with the same name starts a new incarnation;
  actual message-keyspace reclaim lands with Stage 4)
- ListQueues (prefix-scans !sqs|queue|meta|, supports QueueNamePrefix /
  MaxResults / NextToken)
- GetQueueUrl
- GetQueueAttributes / SetQueueAttributes

Design choices:

- Metadata is stored as JSON with a four-byte magic prefix so future
  encoding migrations can switch formats without reading back garbage.
- Segment encoding uses base64 raw URL for queue names, matching the
  DynamoDB adapter's encodeDynamoSegment so operators reading raw Pebble
  keys see the same shape across adapters.
- Attribute validation is table-driven via sqsAttributeAppliers so adding
  an attribute does not grow applyAttributes's cyclomatic complexity.
- The OCC retry loop is split into tryCreateQueueOnce so each retry pass
  is self-contained and the outer loop only deals with backoff.
- All mutating operations go through the leader via the shared
  proxyHTTPRequestToLeader helper, matching the DynamoDB adapter.

Tests (adapter/sqs_catalog_test.go) spin up a real single-node cluster via
createNode(t, 1) and exercise the end-to-end JSON-protocol path:
Create/Get/List, idempotent Create, QueueNameExists on attribute change,
Get/SetAttributes, Delete + follow-up NonExistentQueue, FIFO
name/attribute validation, and key-encoding round trips.

The scaffold "known targets return NotImplemented" test was updated to
skip the now-implemented catalog targets.
---
 adapter/sqs.go              |  12 +-
 adapter/sqs_catalog.go      | 832 ++++++++++++++++++++++++++++++++++++
 adapter/sqs_catalog_test.go | 294 +++++++++++++
 adapter/sqs_keys.go         |  58 +++
 adapter/sqs_test.go         |  11 +-
 5 files changed, 1195 insertions(+), 12 deletions(-)
 create mode 100644 adapter/sqs_catalog.go
 create mode 100644 adapter/sqs_catalog_test.go
 create mode 100644 adapter/sqs_keys.go

diff --git a/adapter/sqs.go b/adapter/sqs.go
index 545d4fd81..d0ec3c150 100644
--- a/adapter/sqs.go
+++ b/adapter/sqs.go
@@ -88,12 +88,12 @@ func NewSQSServer(listen net.Listener, st store.MVCCStore, coordinate kv.Coordin
 		coordinator: coordinate,
 	}
 	s.targetHandlers = map[string]func(http.ResponseWriter, *http.Request){
-		sqsCreateQueueTarget:               s.notImplemented("CreateQueue"),
-		sqsDeleteQueueTarget:               s.notImplemented("DeleteQueue"),
-		sqsListQueuesTarget:                s.notImplemented("ListQueues"),
-		sqsGetQueueUrlTarget:               s.notImplemented("GetQueueUrl"),
-		sqsGetQueueAttributesTarget:        s.notImplemented("GetQueueAttributes"),
-		sqsSetQueueAttributesTarget:        s.notImplemented("SetQueueAttributes"),
+		sqsCreateQueueTarget:               s.createQueue,
+		sqsDeleteQueueTarget:               s.deleteQueue,
+		sqsListQueuesTarget:                s.listQueues,
+		sqsGetQueueUrlTarget:               s.getQueueUrl,
+		sqsGetQueueAttributesTarget:        s.getQueueAttributes,
+		sqsSetQueueAttributesTarget:        s.setQueueAttributes,
 		sqsPurgeQueueTarget:                s.notImplemented("PurgeQueue"),
 		sqsSendMessageTarget:               s.notImplemented("SendMessage"),
 		sqsSendMessageBatchTarget:          s.notImplemented("SendMessageBatch"),
diff --git a/adapter/sqs_catalog.go b/adapter/sqs_catalog.go
new file mode 100644
index 000000000..010e96802
--- /dev/null
+++ b/adapter/sqs_catalog.go
@@ -0,0 +1,832 @@
+package adapter
+
+import (
+	"bytes"
+	"context"
+	"io"
+	"net/http"
+	"net/url"
+	"regexp"
+	"sort"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/bootjp/elastickv/kv"
+	"github.com/bootjp/elastickv/store"
+	"github.com/cockroachdb/errors"
+	json "github.com/goccy/go-json"
+)
+
+// AWS SQS defaults, reproduced from the public API reference so clients that
+// send an empty Attributes map get standard behavior.
+const (
+	sqsDefaultVisibilityTimeoutSeconds  = 30
+	sqsDefaultRetentionSeconds          = 345600 // 4 days
+	sqsDefaultDelaySeconds              = 0
+	sqsDefaultReceiveMessageWaitSeconds = 0
+	sqsDefaultMaximumMessageSize        = 262144 // 256 KiB
+
+	sqsMaxVisibilityTimeoutSeconds      = 43200 // 12 hours
+	sqsMaxRetentionSeconds              = 1209600 // 14 days
+	sqsMinRetentionSeconds              = 60
+	sqsMaxDelaySeconds                  = 900
+	sqsMaxReceiveMessageWaitSeconds     = 20
+	sqsMinMaximumMessageSize            = 1024
+	sqsMaximumAllowedMaximumMessageSize = 262144
+	sqsMaxQueueNameLength               = 80
+	sqsFIFOQueueNameSuffix              = ".fifo"
+	sqsListQueuesDefaultMaxResults      = 1000
+	sqsListQueuesHardMaxResults         = 1000
+	sqsQueueScanPageLimit               = 1024
+)
+
+// AWS error codes specific to the queue catalog.
+const (
+	sqsErrValidation            = "InvalidParameterValue"
+	sqsErrMissingParameter      = "MissingParameter"
+	sqsErrQueueNameExists       = "QueueNameExists"
+	sqsErrQueueDoesNotExist     = "AWS.SimpleQueueService.NonExistentQueue"
+	sqsErrInvalidAttributeName  = "InvalidAttributeName"
+	sqsErrInvalidAttributeValue = "InvalidAttributeValue"
+)
+
+var sqsQueueNamePattern = regexp.MustCompile(`^[a-zA-Z0-9_-]{1,80}(\.fifo)?$`)
+
+// sqsQueueMeta is the Go mirror of the queue metadata record persisted at
+// !sqs|queue|meta|<encoded name>. Serialized as JSON with a short magic
+// prefix so future schema migrations can switch encoding without reading
+// back garbage.
+type sqsQueueMeta struct {
+	Name                      string            `json:"name"`
+	Generation                uint64            `json:"generation"`
+	CreatedAtHLC              uint64            `json:"created_at_hlc,omitempty"`
+	IsFIFO                    bool              `json:"is_fifo,omitempty"`
+	ContentBasedDedup         bool              `json:"content_based_dedup,omitempty"`
+	VisibilityTimeoutSeconds  int64             `json:"visibility_timeout_seconds"`
+	MessageRetentionSeconds   int64             `json:"message_retention_seconds"`
+	DelaySeconds              int64             `json:"delay_seconds"`
+	ReceiveMessageWaitSeconds int64             `json:"receive_message_wait_seconds"`
+	MaximumMessageSize        int64             `json:"maximum_message_size"`
+	RedrivePolicy             string            `json:"redrive_policy,omitempty"`
+	Tags                      map[string]string `json:"tags,omitempty"`
+}
+
+var storedSQSMetaPrefix = []byte{0x00, 'S', 'Q', 0x01}
+
+func encodeSQSQueueMeta(m *sqsQueueMeta) ([]byte, error) {
+	body, err := json.Marshal(m)
+	if err != nil {
+		return nil, errors.WithStack(err)
+	}
+	out := make([]byte, 0, len(storedSQSMetaPrefix)+len(body))
+	out = append(out, storedSQSMetaPrefix...)
+	out = append(out, body...)
+	return out, nil
+}
+
+func decodeSQSQueueMeta(b []byte) (*sqsQueueMeta, error) {
+	if !bytes.HasPrefix(b, storedSQSMetaPrefix) {
+		return nil, errors.New("unrecognized sqs meta format")
+	}
+	var m sqsQueueMeta
+	if err := json.Unmarshal(b[len(storedSQSMetaPrefix):], &m); err != nil {
+		return nil, errors.WithStack(err)
+	}
+	return &m, nil
+}
+
+// sqsAPIError is a typed error that captures the HTTP status and AWS error
+// code so handler helpers can fail deep in the call chain and let the top
+// level render a consistent envelope via writeSQSErrorFromErr.
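+//
+// For example, a validation helper can fail with
+//
+//	newSQSAPIError(http.StatusBadRequest, sqsErrValidation, "QueueName too long")
+//
+// and the handler's only job is to call writeSQSErrorFromErr(w, err);
+// errors.As recovers the *sqsAPIError even through errors.WithStack wrapping.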
+type sqsAPIError struct { + status int + errorType string + message string +} + +func (e *sqsAPIError) Error() string { + if e == nil { + return "" + } + if e.message != "" { + return e.message + } + return http.StatusText(e.status) +} + +func newSQSAPIError(status int, errorType string, message string) error { + return &sqsAPIError{status: status, errorType: errorType, message: message} +} + +func writeSQSErrorFromErr(w http.ResponseWriter, err error) { + var apiErr *sqsAPIError + if errors.As(err, &apiErr) { + writeSQSError(w, apiErr.status, apiErr.errorType, apiErr.message) + return + } + writeSQSError(w, http.StatusInternalServerError, sqsErrInternalFailure, err.Error()) +} + +func writeSQSJSON(w http.ResponseWriter, payload any) { + w.Header().Set("Content-Type", sqsContentTypeJSON) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(payload) +} + +// ------------------------ input decoding ------------------------ + +type sqsCreateQueueInput struct { + QueueName string `json:"QueueName"` + Attributes map[string]string `json:"Attributes"` + Tags map[string]string `json:"tags"` +} + +type sqsDeleteQueueInput struct { + QueueUrl string `json:"QueueUrl"` +} + +type sqsListQueuesInput struct { + QueueNamePrefix string `json:"QueueNamePrefix"` + MaxResults int `json:"MaxResults"` + NextToken string `json:"NextToken"` +} + +type sqsGetQueueUrlInput struct { + QueueName string `json:"QueueName"` +} + +type sqsGetQueueAttributesInput struct { + QueueUrl string `json:"QueueUrl"` + AttributeNames []string `json:"AttributeNames"` +} + +type sqsSetQueueAttributesInput struct { + QueueUrl string `json:"QueueUrl"` + Attributes map[string]string `json:"Attributes"` +} + +func decodeSQSJSONInput(r *http.Request, v any) error { + body, err := io.ReadAll(http.MaxBytesReader(nil, r.Body, sqsMaxRequestBodyBytes)) + if err != nil { + return newSQSAPIError(http.StatusBadRequest, sqsErrMalformedRequest, err.Error()) + } + if len(bytes.TrimSpace(body)) == 0 { + // Empty body is legal for some actions; leave v at its zero value. + return nil + } + if err := json.Unmarshal(body, v); err != nil { + return newSQSAPIError(http.StatusBadRequest, sqsErrMalformedRequest, err.Error()) + } + return nil +} + +// ------------------------ URL helpers ------------------------ + +// queueURL builds the AWS-compatible URL echoed back to clients. We follow +// the endpoint the client addressed so SDKs that re-sign subsequent requests +// against the same URL keep working behind reverse proxies. +func (s *SQSServer) queueURL(r *http.Request, queueName string) string { + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + host := r.Host + if host == "" && s.listen != nil { + host = s.listen.Addr().String() + } + return scheme + "://" + host + "/" + queueName +} + +// scanContinuationSentinel is appended to a key to produce the exclusive +// upper bound for the next scan page — i.e. the smallest key strictly +// greater than k. 
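+//
+// For example, nextScanCursorAfter([]byte("a")) returns {'a', 0x00}: the
+// immediate successor of "a" in bytewise order, so the next page's scan
+// starts strictly after the last key already returned.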
+const scanContinuationSentinel = 0x00
+
+func nextScanCursorAfter(key []byte) []byte {
+	return append(bytes.Clone(key), scanContinuationSentinel)
+}
+
+func queueNameFromURL(queueUrl string) (string, error) {
+	if strings.TrimSpace(queueUrl) == "" {
+		return "", newSQSAPIError(http.StatusBadRequest, sqsErrMissingParameter, "missing QueueUrl")
+	}
+	parsed, err := url.Parse(queueUrl)
+	if err != nil {
+		return "", newSQSAPIError(http.StatusBadRequest, sqsErrValidation, "invalid QueueUrl: "+err.Error())
+	}
+	name := strings.TrimPrefix(parsed.Path, "/")
+	if name == "" {
+		return "", newSQSAPIError(http.StatusBadRequest, sqsErrValidation, "QueueUrl path is empty")
+	}
+	// Strip an AWS-style account-id prefix so http://host/12345/MyQueue works.
+	if idx := strings.LastIndex(name, "/"); idx >= 0 {
+		name = name[idx+1:]
+	}
+	return name, nil
+}
+
+func validateQueueName(name string) error {
+	// Validate the raw name: the anchored pattern rejects whitespace, so a
+	// padded QueueName fails here instead of being stored verbatim.
+	if strings.TrimSpace(name) == "" {
+		return newSQSAPIError(http.StatusBadRequest, sqsErrMissingParameter, "missing QueueName")
+	}
+	if len(name) > sqsMaxQueueNameLength {
+		return newSQSAPIError(http.StatusBadRequest, sqsErrValidation, "QueueName too long")
+	}
+	if !sqsQueueNamePattern.MatchString(name) {
+		return newSQSAPIError(http.StatusBadRequest, sqsErrValidation, "QueueName contains invalid characters")
+	}
+	return nil
+}
+
+// ------------------------ attribute parsing ------------------------
+
+func parseAttributesIntoMeta(name string, attrs map[string]string) (*sqsQueueMeta, error) {
+	meta := &sqsQueueMeta{
+		Name:                      name,
+		IsFIFO:                    strings.HasSuffix(name, sqsFIFOQueueNameSuffix),
+		VisibilityTimeoutSeconds:  sqsDefaultVisibilityTimeoutSeconds,
+		MessageRetentionSeconds:   sqsDefaultRetentionSeconds,
+		DelaySeconds:              sqsDefaultDelaySeconds,
+		ReceiveMessageWaitSeconds: sqsDefaultReceiveMessageWaitSeconds,
+		MaximumMessageSize:        sqsDefaultMaximumMessageSize,
+	}
+	if err := applyAttributes(meta, attrs); err != nil {
+		return nil, err
+	}
+	// FifoQueue attribute is authoritative if explicitly set; otherwise the
+	// .fifo suffix implies true and callers without the suffix get a
+	// standard queue.
+	if v, ok := attrs["FifoQueue"]; ok {
+		b, err := strconv.ParseBool(v)
+		if err != nil {
+			return nil, newSQSAPIError(http.StatusBadRequest, sqsErrInvalidAttributeValue, "FifoQueue must be a boolean")
+		}
+		if b && !strings.HasSuffix(name, sqsFIFOQueueNameSuffix) {
+			return nil, newSQSAPIError(http.StatusBadRequest, sqsErrValidation, "FIFO queue name must end in .fifo")
+		}
+		if !b && strings.HasSuffix(name, sqsFIFOQueueNameSuffix) {
+			return nil, newSQSAPIError(http.StatusBadRequest, sqsErrValidation, "Queue name ends in .fifo but FifoQueue=false")
+		}
+		meta.IsFIFO = b
+	}
+	return meta, nil
+}
+
+// attributeApplier writes one attribute into meta. Keeping one applier per
+// attribute keeps applyAttributes trivial to dispatch through.
+type attributeApplier func(meta *sqsQueueMeta, value string) error
+
+var sqsAttributeAppliers = map[string]attributeApplier{
+	// FifoQueue is applied after applyAttributes returns (see
+	// parseAttributesIntoMeta) because it needs cross-field validation
+	// against the queue name.
+ "FifoQueue": func(_ *sqsQueueMeta, _ string) error { return nil }, + "VisibilityTimeout": func(m *sqsQueueMeta, v string) error { + n, err := parseIntAttr("VisibilityTimeout", v, 0, sqsMaxVisibilityTimeoutSeconds) + if err != nil { + return err + } + m.VisibilityTimeoutSeconds = n + return nil + }, + "MessageRetentionPeriod": func(m *sqsQueueMeta, v string) error { + n, err := parseIntAttr("MessageRetentionPeriod", v, sqsMinRetentionSeconds, sqsMaxRetentionSeconds) + if err != nil { + return err + } + m.MessageRetentionSeconds = n + return nil + }, + "DelaySeconds": func(m *sqsQueueMeta, v string) error { + n, err := parseIntAttr("DelaySeconds", v, 0, sqsMaxDelaySeconds) + if err != nil { + return err + } + m.DelaySeconds = n + return nil + }, + "ReceiveMessageWaitTimeSeconds": func(m *sqsQueueMeta, v string) error { + n, err := parseIntAttr("ReceiveMessageWaitTimeSeconds", v, 0, sqsMaxReceiveMessageWaitSeconds) + if err != nil { + return err + } + m.ReceiveMessageWaitSeconds = n + return nil + }, + "MaximumMessageSize": func(m *sqsQueueMeta, v string) error { + n, err := parseIntAttr("MaximumMessageSize", v, sqsMinMaximumMessageSize, sqsMaximumAllowedMaximumMessageSize) + if err != nil { + return err + } + m.MaximumMessageSize = n + return nil + }, + "ContentBasedDeduplication": func(m *sqsQueueMeta, v string) error { + b, err := strconv.ParseBool(v) + if err != nil { + return newSQSAPIError(http.StatusBadRequest, sqsErrInvalidAttributeValue, "ContentBasedDeduplication must be a boolean") + } + m.ContentBasedDedup = b + return nil + }, + "RedrivePolicy": func(m *sqsQueueMeta, v string) error { + m.RedrivePolicy = v + return nil + }, +} + +// applyAttributes writes every entry in attrs into meta, returning a typed +// SQS error on the first unknown key or out-of-range value. +func applyAttributes(meta *sqsQueueMeta, attrs map[string]string) error { + for k, v := range attrs { + apply, ok := sqsAttributeAppliers[k] + if !ok { + return newSQSAPIError(http.StatusBadRequest, sqsErrInvalidAttributeName, "unsupported attribute: "+k) + } + if err := apply(meta, v); err != nil { + return err + } + } + return nil +} + +func parseIntAttr(name, value string, minVal, maxVal int64) (int64, error) { + n, err := strconv.ParseInt(strings.TrimSpace(value), 10, 64) + if err != nil { + return 0, newSQSAPIError(http.StatusBadRequest, sqsErrInvalidAttributeValue, name+" must be an integer") + } + if n < minVal || n > maxVal { + return 0, newSQSAPIError(http.StatusBadRequest, sqsErrInvalidAttributeValue, name+" is out of range") + } + return n, nil +} + +// attributesEqual is used by CreateQueue for idempotency: calling the same +// CreateQueue twice with identical attributes is a no-op; differing values +// must fail with QueueNameExists. 
+func attributesEqual(a, b *sqsQueueMeta) bool { + if a == nil || b == nil { + return false + } + return a.IsFIFO == b.IsFIFO && + a.ContentBasedDedup == b.ContentBasedDedup && + a.VisibilityTimeoutSeconds == b.VisibilityTimeoutSeconds && + a.MessageRetentionSeconds == b.MessageRetentionSeconds && + a.DelaySeconds == b.DelaySeconds && + a.ReceiveMessageWaitSeconds == b.ReceiveMessageWaitSeconds && + a.MaximumMessageSize == b.MaximumMessageSize && + a.RedrivePolicy == b.RedrivePolicy +} + +// ------------------------ storage primitives ------------------------ + +func (s *SQSServer) nextTxnReadTS(ctx context.Context) uint64 { + maxTS := uint64(0) + if p, ok := s.store.(globalLastCommitTSProvider); ok { + maxTS = p.GlobalLastCommitTS(ctx) + } else if s.store != nil { + maxTS = s.store.LastCommitTS() + } + if s.coordinator != nil { + if clock := s.coordinator.Clock(); clock != nil && maxTS > 0 { + clock.Observe(maxTS) + } + } + if maxTS == 0 { + return 1 + } + return maxTS +} + +func (s *SQSServer) loadQueueMetaAt(ctx context.Context, queueName string, ts uint64) (*sqsQueueMeta, bool, error) { + b, err := s.store.GetAt(ctx, sqsQueueMetaKey(queueName), ts) + if err != nil { + if errors.Is(err, store.ErrKeyNotFound) { + return nil, false, nil + } + return nil, false, errors.WithStack(err) + } + meta, err := decodeSQSQueueMeta(b) + if err != nil { + return nil, false, err + } + return meta, true, nil +} + +func (s *SQSServer) loadQueueGenerationAt(ctx context.Context, queueName string, ts uint64) (uint64, error) { + b, err := s.store.GetAt(ctx, sqsQueueGenKey(queueName), ts) + if err != nil { + if errors.Is(err, store.ErrKeyNotFound) { + return 0, nil + } + return 0, errors.WithStack(err) + } + n, err := strconv.ParseUint(string(b), 10, 64) + if err != nil { + return 0, errors.WithStack(err) + } + return n, nil +} + +// ------------------------ handlers ------------------------ + +func (s *SQSServer) createQueue(w http.ResponseWriter, r *http.Request) { + var in sqsCreateQueueInput + if err := decodeSQSJSONInput(r, &in); err != nil { + writeSQSErrorFromErr(w, err) + return + } + if err := validateQueueName(in.QueueName); err != nil { + writeSQSErrorFromErr(w, err) + return + } + requested, err := parseAttributesIntoMeta(in.QueueName, in.Attributes) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + requested.Tags = in.Tags + + if err := s.createQueueWithRetry(r.Context(), requested); err != nil { + writeSQSErrorFromErr(w, err) + return + } + writeSQSJSON(w, map[string]string{"QueueUrl": s.queueURL(r, in.QueueName)}) +} + +func (s *SQSServer) createQueueWithRetry(ctx context.Context, requested *sqsQueueMeta) error { + backoff := transactRetryInitialBackoff + deadline := time.Now().Add(transactRetryMaxDuration) + for range transactRetryMaxAttempts { + done, err := s.tryCreateQueueOnce(ctx, requested) + if err == nil && done { + return nil + } + if err != nil && !isRetryableTransactWriteError(err) { + return err + } + if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { + return errors.WithStack(err) + } + backoff = nextTransactRetryBackoff(backoff) + } + return newSQSAPIError(http.StatusInternalServerError, sqsErrInternalFailure, "create queue retry attempts exhausted") +} + +// tryCreateQueueOnce runs one read/check/dispatch pass. The first bool reports +// whether the caller should stop retrying: true means the queue now exists +// with the requested attributes, false means the dispatch hit a retryable +// conflict and should be retried after backoff. 
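+//
+// The idempotency check deliberately ignores Name, Generation, CreatedAtHLC,
+// and Tags (see attributesEqual), so a repeated CreateQueue that differs only
+// in tags is treated as a no-op. A false result paired with a non-retryable
+// error (for example QueueNameExists) is terminal: createQueueWithRetry
+// returns it to the caller instead of retrying.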
+func (s *SQSServer) tryCreateQueueOnce(ctx context.Context, requested *sqsQueueMeta) (bool, error) { + readTS := s.nextTxnReadTS(ctx) + existing, exists, err := s.loadQueueMetaAt(ctx, requested.Name, readTS) + if err != nil { + return false, errors.WithStack(err) + } + if exists { + if attributesEqual(existing, requested) { + return true, nil + } + return false, newSQSAPIError(http.StatusBadRequest, sqsErrQueueNameExists, "queue already exists with different attributes") + } + lastGen, err := s.loadQueueGenerationAt(ctx, requested.Name, readTS) + if err != nil { + return false, errors.WithStack(err) + } + requested.Generation = lastGen + 1 + if clock := s.coordinator.Clock(); clock != nil { + requested.CreatedAtHLC = clock.Current() + } + metaBytes, err := encodeSQSQueueMeta(requested) + if err != nil { + return false, errors.WithStack(err) + } + req := &kv.OperationGroup[kv.OP]{ + IsTxn: true, + Elems: []*kv.Elem[kv.OP]{ + {Op: kv.Put, Key: sqsQueueMetaKey(requested.Name), Value: metaBytes}, + {Op: kv.Put, Key: sqsQueueGenKey(requested.Name), Value: []byte(strconv.FormatUint(requested.Generation, 10))}, + }, + } + if _, err := s.coordinator.Dispatch(ctx, req); err != nil { + return false, errors.WithStack(err) + } + return true, nil +} + +func (s *SQSServer) deleteQueue(w http.ResponseWriter, r *http.Request) { + var in sqsDeleteQueueInput + if err := decodeSQSJSONInput(r, &in); err != nil { + writeSQSErrorFromErr(w, err) + return + } + name, err := queueNameFromURL(in.QueueUrl) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + if err := s.deleteQueueWithRetry(r.Context(), name); err != nil { + writeSQSErrorFromErr(w, err) + return + } + // SQS DeleteQueue returns 200 with an empty body. + writeSQSJSON(w, map[string]any{}) +} + +func (s *SQSServer) deleteQueueWithRetry(ctx context.Context, queueName string) error { + backoff := transactRetryInitialBackoff + deadline := time.Now().Add(transactRetryMaxDuration) + for range transactRetryMaxAttempts { + readTS := s.nextTxnReadTS(ctx) + _, exists, err := s.loadQueueMetaAt(ctx, queueName, readTS) + if err != nil { + return errors.WithStack(err) + } + if !exists { + return newSQSAPIError(http.StatusBadRequest, sqsErrQueueDoesNotExist, "queue does not exist") + } + + // Bump the generation counter so any stragglers under the old + // generation are unreachable by routing. Actual message cleanup + // lands in a follow-up PR along with the message keyspace. 
+ lastGen, err := s.loadQueueGenerationAt(ctx, queueName, readTS) + if err != nil { + return errors.WithStack(err) + } + req := &kv.OperationGroup[kv.OP]{ + IsTxn: true, + Elems: []*kv.Elem[kv.OP]{ + {Op: kv.Del, Key: sqsQueueMetaKey(queueName)}, + {Op: kv.Put, Key: sqsQueueGenKey(queueName), Value: []byte(strconv.FormatUint(lastGen+1, 10))}, + }, + } + if _, err := s.coordinator.Dispatch(ctx, req); err == nil { + return nil + } else if !isRetryableTransactWriteError(err) { + return errors.WithStack(err) + } + if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { + return errors.WithStack(err) + } + backoff = nextTransactRetryBackoff(backoff) + } + return newSQSAPIError(http.StatusInternalServerError, sqsErrInternalFailure, "delete queue retry attempts exhausted") +} + +func (s *SQSServer) listQueues(w http.ResponseWriter, r *http.Request) { + var in sqsListQueuesInput + if err := decodeSQSJSONInput(r, &in); err != nil { + writeSQSErrorFromErr(w, err) + return + } + maxResults := clampListQueuesMaxResults(in.MaxResults) + + names, err := s.scanQueueNames(r.Context()) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + sort.Strings(names) + filtered := filterByPrefix(names, in.QueueNamePrefix) + start := resolveListQueuesStart(filtered, in.NextToken) + + end := start + maxResults + truncated := end < len(filtered) + if !truncated { + end = len(filtered) + } + page := filtered[start:end] + + urls := make([]string, 0, len(page)) + for _, n := range page { + urls = append(urls, s.queueURL(r, n)) + } + resp := map[string]any{"QueueUrls": urls} + if truncated && len(page) > 0 { + resp["NextToken"] = encodeSQSSegment(page[len(page)-1]) + } + writeSQSJSON(w, resp) +} + +func clampListQueuesMaxResults(requested int) int { + if requested <= 0 { + return sqsListQueuesDefaultMaxResults + } + if requested > sqsListQueuesHardMaxResults { + return sqsListQueuesHardMaxResults + } + return requested +} + +func filterByPrefix(names []string, prefix string) []string { + if prefix == "" { + return names + } + out := names[:0] + for _, n := range names { + if strings.HasPrefix(n, prefix) { + out = append(out, n) + } + } + return out +} + +// resolveListQueuesStart decodes the NextToken boundary and returns the index +// of the first name strictly greater than it. A malformed token behaves as +// "start from the beginning" to match AWS's lenient behavior. 
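+//
+// For example, with sorted names [a b c] and a token that decodes to "b",
+// the next page resumes at index 2 ("c").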
+func resolveListQueuesStart(names []string, token string) int { + if token == "" { + return 0 + } + boundary, err := decodeSQSSegment(token) + if err != nil { + return 0 + } + for i, n := range names { + if n > boundary { + return i + } + } + return len(names) +} + +func (s *SQSServer) scanQueueNames(ctx context.Context) ([]string, error) { + prefix := []byte(SqsQueueMetaPrefix) + end := prefixScanEnd(prefix) + start := bytes.Clone(prefix) + readTS := s.nextTxnReadTS(ctx) + var names []string + for { + kvs, err := s.store.ScanAt(ctx, start, end, sqsQueueScanPageLimit, readTS) + if err != nil { + return nil, errors.WithStack(err) + } + if len(kvs) == 0 { + break + } + for _, kvp := range kvs { + if !bytes.HasPrefix(kvp.Key, prefix) { + return names, nil + } + name, ok := queueNameFromMetaKey(kvp.Key) + if !ok { + continue + } + names = append(names, name) + } + if len(kvs) < sqsQueueScanPageLimit { + break + } + start = nextScanCursorAfter(kvs[len(kvs)-1].Key) + if end != nil && bytes.Compare(start, end) > 0 { + break + } + } + return names, nil +} + +func (s *SQSServer) getQueueUrl(w http.ResponseWriter, r *http.Request) { + var in sqsGetQueueUrlInput + if err := decodeSQSJSONInput(r, &in); err != nil { + writeSQSErrorFromErr(w, err) + return + } + if err := validateQueueName(in.QueueName); err != nil { + writeSQSErrorFromErr(w, err) + return + } + _, exists, err := s.loadQueueMetaAt(r.Context(), in.QueueName, s.nextTxnReadTS(r.Context())) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + if !exists { + writeSQSError(w, http.StatusBadRequest, sqsErrQueueDoesNotExist, "queue does not exist") + return + } + writeSQSJSON(w, map[string]string{"QueueUrl": s.queueURL(r, in.QueueName)}) +} + +func (s *SQSServer) getQueueAttributes(w http.ResponseWriter, r *http.Request) { + var in sqsGetQueueAttributesInput + if err := decodeSQSJSONInput(r, &in); err != nil { + writeSQSErrorFromErr(w, err) + return + } + name, err := queueNameFromURL(in.QueueUrl) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + meta, exists, err := s.loadQueueMetaAt(r.Context(), name, s.nextTxnReadTS(r.Context())) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + if !exists { + writeSQSError(w, http.StatusBadRequest, sqsErrQueueDoesNotExist, "queue does not exist") + return + } + selection := selectedAttributeNames(in.AttributeNames) + attrs := queueMetaToAttributes(meta, selection) + writeSQSJSON(w, map[string]any{"Attributes": attrs}) +} + +// selectedAttributeNames returns a set of attribute names to include in the +// response. An empty selection, or any entry equal to "All", expands to +// every supported attribute. 
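+//
+// A nil return means "no filtering": queueMetaToAttributes answers nil with
+// the full attribute map.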
+func selectedAttributeNames(req []string) map[string]bool { + selection := map[string]bool{} + if len(req) == 0 { + return nil + } + for _, n := range req { + if n == "All" { + return nil + } + selection[n] = true + } + return selection +} + +func queueMetaToAttributes(meta *sqsQueueMeta, selection map[string]bool) map[string]string { + all := map[string]string{ + "VisibilityTimeout": strconv.FormatInt(meta.VisibilityTimeoutSeconds, 10), + "MessageRetentionPeriod": strconv.FormatInt(meta.MessageRetentionSeconds, 10), + "DelaySeconds": strconv.FormatInt(meta.DelaySeconds, 10), + "ReceiveMessageWaitTimeSeconds": strconv.FormatInt(meta.ReceiveMessageWaitSeconds, 10), + "MaximumMessageSize": strconv.FormatInt(meta.MaximumMessageSize, 10), + "FifoQueue": strconv.FormatBool(meta.IsFIFO), + "ContentBasedDeduplication": strconv.FormatBool(meta.ContentBasedDedup), + } + if meta.RedrivePolicy != "" { + all["RedrivePolicy"] = meta.RedrivePolicy + } + if selection == nil { + return all + } + out := make(map[string]string, len(selection)) + for k := range selection { + if v, ok := all[k]; ok { + out[k] = v + } + } + return out +} + +func (s *SQSServer) setQueueAttributes(w http.ResponseWriter, r *http.Request) { + var in sqsSetQueueAttributesInput + if err := decodeSQSJSONInput(r, &in); err != nil { + writeSQSErrorFromErr(w, err) + return + } + name, err := queueNameFromURL(in.QueueUrl) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + if err := s.setQueueAttributesWithRetry(r.Context(), name, in.Attributes); err != nil { + writeSQSErrorFromErr(w, err) + return + } + writeSQSJSON(w, map[string]any{}) +} + +func (s *SQSServer) setQueueAttributesWithRetry(ctx context.Context, queueName string, attrs map[string]string) error { + backoff := transactRetryInitialBackoff + deadline := time.Now().Add(transactRetryMaxDuration) + for range transactRetryMaxAttempts { + readTS := s.nextTxnReadTS(ctx) + meta, exists, err := s.loadQueueMetaAt(ctx, queueName, readTS) + if err != nil { + return errors.WithStack(err) + } + if !exists { + return newSQSAPIError(http.StatusBadRequest, sqsErrQueueDoesNotExist, "queue does not exist") + } + if err := applyAttributes(meta, attrs); err != nil { + return err + } + metaBytes, err := encodeSQSQueueMeta(meta) + if err != nil { + return errors.WithStack(err) + } + req := &kv.OperationGroup[kv.OP]{ + IsTxn: true, + Elems: []*kv.Elem[kv.OP]{ + {Op: kv.Put, Key: sqsQueueMetaKey(queueName), Value: metaBytes}, + }, + } + if _, err := s.coordinator.Dispatch(ctx, req); err == nil { + return nil + } else if !isRetryableTransactWriteError(err) { + return errors.WithStack(err) + } + if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { + return errors.WithStack(err) + } + backoff = nextTransactRetryBackoff(backoff) + } + return newSQSAPIError(http.StatusInternalServerError, sqsErrInternalFailure, "set queue attributes retry attempts exhausted") +} diff --git a/adapter/sqs_catalog_test.go b/adapter/sqs_catalog_test.go new file mode 100644 index 000000000..3dac5fd54 --- /dev/null +++ b/adapter/sqs_catalog_test.go @@ -0,0 +1,294 @@ +package adapter + +import ( + "bytes" + "context" + "io" + "net/http" + "strings" + "testing" + + json "github.com/goccy/go-json" +) + +// callSQS routes a JSON-protocol request to the given node's SQS endpoint. +// The helper exists so tests read like "createQueue → 200 with a URL" rather +// than having to hand-build X-Amz-Target envelopes every time. 
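+//
+// Typical usage:
+//
+//	status, out := callSQS(t, node, sqsCreateQueueTarget,
+//		map[string]any{"QueueName": "orders"})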
+func callSQS(t *testing.T, node Node, target string, in any) (int, map[string]any) { + t.Helper() + body, err := json.Marshal(in) + if err != nil { + t.Fatalf("marshal: %v", err) + } + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, + "http://"+node.sqsAddress+"/", bytes.NewReader(body)) + if err != nil { + t.Fatalf("request: %v", err) + } + req.Header.Set("X-Amz-Target", target) + req.Header.Set("Content-Type", sqsContentTypeJSON) + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("do: %v", err) + } + defer resp.Body.Close() + raw, _ := io.ReadAll(resp.Body) + out := map[string]any{} + if len(bytes.TrimSpace(raw)) > 0 { + if err := json.Unmarshal(raw, &out); err != nil { + t.Fatalf("decode %q: %v", string(raw), err) + } + } + return resp.StatusCode, out +} + +func sqsLeaderNode(t *testing.T, nodes []Node) Node { + t.Helper() + for _, n := range nodes { + if n.engine != nil && n.engine.Leader().Address == n.raftAddress { + return n + } + } + return nodes[0] +} + +func TestSQSServer_CatalogCreateGetList(t *testing.T) { + t.Parallel() + nodes, _, _ := createNode(t, 1) + defer shutdown(nodes) + node := sqsLeaderNode(t, nodes) + + // CreateQueue: 200 with a QueueUrl that ends in the queue name. + status, out := callSQS(t, node, sqsCreateQueueTarget, map[string]any{ + "QueueName": "orders", + }) + if status != http.StatusOK { + t.Fatalf("create: status %d body %v", status, out) + } + url, _ := out["QueueUrl"].(string) + if !strings.HasSuffix(url, "/orders") { + t.Fatalf("QueueUrl %q does not end in /orders", url) + } + + // GetQueueUrl returns the same URL. + status, out = callSQS(t, node, sqsGetQueueUrlTarget, map[string]any{ + "QueueName": "orders", + }) + if status != http.StatusOK { + t.Fatalf("getQueueUrl: status %d body %v", status, out) + } + if got, _ := out["QueueUrl"].(string); got != url { + t.Fatalf("GetQueueUrl=%q want %q", got, url) + } + + // ListQueues sees it. + status, out = callSQS(t, node, sqsListQueuesTarget, map[string]any{}) + if status != http.StatusOK { + t.Fatalf("list: status %d body %v", status, out) + } + urls, _ := out["QueueUrls"].([]any) + foundList := false + for _, u := range urls { + if s, _ := u.(string); s == url { + foundList = true + break + } + } + if !foundList { + t.Fatalf("ListQueues did not include %q; got %v", url, urls) + } +} + +func TestSQSServer_CatalogCreateIsIdempotent(t *testing.T) { + t.Parallel() + nodes, _, _ := createNode(t, 1) + defer shutdown(nodes) + node := sqsLeaderNode(t, nodes) + + in := map[string]any{ + "QueueName": "idempotent", + "Attributes": map[string]string{ + "VisibilityTimeout": "60", + }, + } + status1, out1 := callSQS(t, node, sqsCreateQueueTarget, in) + if status1 != http.StatusOK { + t.Fatalf("first create: %d %v", status1, out1) + } + // Second call with the same attributes must succeed with the same URL. + status2, out2 := callSQS(t, node, sqsCreateQueueTarget, in) + if status2 != http.StatusOK { + t.Fatalf("second create (same attrs): %d %v", status2, out2) + } + if out1["QueueUrl"] != out2["QueueUrl"] { + t.Fatalf("idempotent create returned different URLs: %v vs %v", out1, out2) + } + + // Third call with differing attributes must fail with QueueNameExists. 
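+	// AWS reports QueueNameExists only when the requested attributes differ
+	// from the existing queue's, which is the case constructed below.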
+ changed := map[string]any{ + "QueueName": "idempotent", + "Attributes": map[string]string{"VisibilityTimeout": "120"}, + } + status3, out3 := callSQS(t, node, sqsCreateQueueTarget, changed) + if status3 != http.StatusBadRequest { + t.Fatalf("differing-attrs create: got %d want 400; body %v", status3, out3) + } + if got, _ := out3["__type"].(string); got != sqsErrQueueNameExists { + t.Fatalf("differing-attrs error type: got %q want %q", got, sqsErrQueueNameExists) + } +} + +func TestSQSServer_CatalogGetAndSetAttributes(t *testing.T) { + t.Parallel() + nodes, _, _ := createNode(t, 1) + defer shutdown(nodes) + node := sqsLeaderNode(t, nodes) + + status, out := callSQS(t, node, sqsCreateQueueTarget, map[string]any{ + "QueueName": "attrs", + }) + if status != http.StatusOK { + t.Fatalf("create: %d %v", status, out) + } + url, _ := out["QueueUrl"].(string) + + status, out = callSQS(t, node, sqsGetQueueAttributesTarget, map[string]any{ + "QueueUrl": url, + "AttributeNames": []string{"All"}, + }) + if status != http.StatusOK { + t.Fatalf("getAttrs: %d %v", status, out) + } + attrs, _ := out["Attributes"].(map[string]any) + if attrs["VisibilityTimeout"] != "30" { + t.Fatalf("default VisibilityTimeout = %v, want 30", attrs["VisibilityTimeout"]) + } + + status, out = callSQS(t, node, sqsSetQueueAttributesTarget, map[string]any{ + "QueueUrl": url, + "Attributes": map[string]string{ + "VisibilityTimeout": "90", + "DelaySeconds": "5", + }, + }) + if status != http.StatusOK { + t.Fatalf("setAttrs: %d %v", status, out) + } + + _, out = callSQS(t, node, sqsGetQueueAttributesTarget, map[string]any{ + "QueueUrl": url, + "AttributeNames": []string{"VisibilityTimeout", "DelaySeconds"}, + }) + attrs, _ = out["Attributes"].(map[string]any) + if attrs["VisibilityTimeout"] != "90" || attrs["DelaySeconds"] != "5" { + t.Fatalf("updated attrs = %v, want VisibilityTimeout=90 DelaySeconds=5", attrs) + } +} + +func TestSQSServer_CatalogDelete(t *testing.T) { + t.Parallel() + nodes, _, _ := createNode(t, 1) + defer shutdown(nodes) + node := sqsLeaderNode(t, nodes) + + _, out := callSQS(t, node, sqsCreateQueueTarget, map[string]any{ + "QueueName": "deleteme", + }) + url, _ := out["QueueUrl"].(string) + + status, out := callSQS(t, node, sqsDeleteQueueTarget, map[string]any{ + "QueueUrl": url, + }) + if status != http.StatusOK { + t.Fatalf("delete: %d %v", status, out) + } + + // GetQueueUrl after delete returns NonExistentQueue. + status, out = callSQS(t, node, sqsGetQueueUrlTarget, map[string]any{ + "QueueName": "deleteme", + }) + if status != http.StatusBadRequest { + t.Fatalf("getQueueUrl after delete: got %d want 400; body %v", status, out) + } + if got, _ := out["__type"].(string); got != sqsErrQueueDoesNotExist { + t.Fatalf("error type: got %q want %q", got, sqsErrQueueDoesNotExist) + } + + // DeleteQueue on an unknown queue also returns NonExistentQueue. + status, _ = callSQS(t, node, sqsDeleteQueueTarget, map[string]any{ + "QueueUrl": url, + }) + if status != http.StatusBadRequest { + t.Fatalf("second delete: got %d want 400", status) + } +} + +func TestSQSServer_CatalogFIFOValidation(t *testing.T) { + t.Parallel() + nodes, _, _ := createNode(t, 1) + defer shutdown(nodes) + node := sqsLeaderNode(t, nodes) + + // FIFO name with FifoQueue=false is rejected. 
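+	// The .fifo suffix and the FifoQueue attribute must agree in both
+	// directions; each mismatch below expects a 400.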
+ status, out := callSQS(t, node, sqsCreateQueueTarget, map[string]any{ + "QueueName": "bad.fifo", + "Attributes": map[string]string{"FifoQueue": "false"}, + }) + if status != http.StatusBadRequest { + t.Fatalf("mismatch name/FifoQueue: got %d want 400; body %v", status, out) + } + + // Non-FIFO name with FifoQueue=true is rejected. + status, out = callSQS(t, node, sqsCreateQueueTarget, map[string]any{ + "QueueName": "plain", + "Attributes": map[string]string{"FifoQueue": "true"}, + }) + if status != http.StatusBadRequest { + t.Fatalf("plain name + FifoQueue=true: got %d want 400; body %v", status, out) + } + + // Valid FIFO succeeds and the attribute is echoed back. + status, out = callSQS(t, node, sqsCreateQueueTarget, map[string]any{ + "QueueName": "events.fifo", + "Attributes": map[string]string{"FifoQueue": "true"}, + }) + if status != http.StatusOK { + t.Fatalf("FIFO create: %d %v", status, out) + } + url, _ := out["QueueUrl"].(string) + _, out = callSQS(t, node, sqsGetQueueAttributesTarget, map[string]any{ + "QueueUrl": url, + "AttributeNames": []string{"FifoQueue"}, + }) + attrs, _ := out["Attributes"].(map[string]any) + if attrs["FifoQueue"] != "true" { + t.Fatalf("FIFO flag not persisted: %v", attrs) + } +} + +func TestSQSServer_CatalogKeyEncoding(t *testing.T) { + t.Parallel() + for _, name := range []string{"", "a", "hello world", "queue.fifo", strings.Repeat("x", 80)} { + encoded := encodeSQSSegment(name) + decoded, err := decodeSQSSegment(encoded) + if err != nil { + t.Fatalf("decode %q: %v", name, err) + } + if decoded != name { + t.Fatalf("round-trip %q -> %q -> %q", name, encoded, decoded) + } + } + + // queueNameFromMetaKey round-trips sqsQueueMetaKey. + name := "round.trip.fifo" + key := sqsQueueMetaKey(name) + got, ok := queueNameFromMetaKey(key) + if !ok || got != name { + t.Fatalf("queueNameFromMetaKey(sqsQueueMetaKey(%q)) = (%q, %v), want (%q, true)", name, got, ok, name) + } + + // Unknown prefixes are rejected. + if _, ok := queueNameFromMetaKey([]byte("random")); ok { + t.Fatal("queueNameFromMetaKey should reject non-catalog keys") + } +} diff --git a/adapter/sqs_keys.go b/adapter/sqs_keys.go new file mode 100644 index 000000000..f0f402665 --- /dev/null +++ b/adapter/sqs_keys.go @@ -0,0 +1,58 @@ +package adapter + +import ( + "encoding/base64" + "strings" + + "github.com/cockroachdb/errors" +) + +// SQS keyspace prefixes. Kept in sync with the naming in +// docs/design/2026_04_24_proposed_sqs_compatible_adapter.md. +const ( + // SqsQueueMetaPrefix prefixes queue-metadata records. + SqsQueueMetaPrefix = "!sqs|queue|meta|" + // SqsQueueGenPrefix prefixes the per-queue monotonic generation counter. + // Bumped on DeleteQueue / PurgeQueue so keys from an older incarnation of + // the same queue name cannot leak into a newly created queue. + SqsQueueGenPrefix = "!sqs|queue|gen|" +) + +func sqsQueueMetaKey(queueName string) []byte { + return []byte(SqsQueueMetaPrefix + encodeSQSSegment(queueName)) +} + +func sqsQueueGenKey(queueName string) []byte { + return []byte(SqsQueueGenPrefix + encodeSQSSegment(queueName)) +} + +// encodeSQSSegment emits a printable, byte-ordered-unique representation of a +// queue name. Base64 raw URL encoding matches the encoding the DynamoDB +// adapter uses for table segments (see encodeDynamoSegment) so that operators +// reading raw keys from Pebble see the same shape across both adapters. 
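+//
+// For example, "orders.fifo" encodes to "b3JkZXJzLmZpZm8"; decodeSQSSegment
+// inverts the mapping exactly, including for the empty string.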
+func encodeSQSSegment(v string) string { + return base64.RawURLEncoding.EncodeToString([]byte(v)) +} + +func decodeSQSSegment(v string) (string, error) { + b, err := base64.RawURLEncoding.DecodeString(v) + if err != nil { + return "", errors.WithStack(err) + } + return string(b), nil +} + +// queueNameFromMetaKey pulls the queue name out of a !sqs|queue|meta| +// key. The second return reports success so callers can skip keys that were +// not written by this adapter. +func queueNameFromMetaKey(key []byte) (string, bool) { + enc, ok := strings.CutPrefix(string(key), SqsQueueMetaPrefix) + if !ok || enc == "" { + return "", false + } + name, err := decodeSQSSegment(enc) + if err != nil { + return "", false + } + return name, true +} diff --git a/adapter/sqs_test.go b/adapter/sqs_test.go index 078d4f05f..9cf122171 100644 --- a/adapter/sqs_test.go +++ b/adapter/sqs_test.go @@ -124,13 +124,12 @@ func TestSQSServer_KnownTargetsReturnNotImplemented(t *testing.T) { t.Parallel() base := startTestSQSServer(t) + // Targets that still return NotImplemented. The catalog targets + // (CreateQueue/DeleteQueue/ListQueues/GetQueueUrl/GetQueueAttributes/ + // SetQueueAttributes) are covered by TestSQSServer_Catalog* against a + // real single-node cluster because they require a coordinator to + // dispatch transactions. targets := []string{ - sqsCreateQueueTarget, - sqsDeleteQueueTarget, - sqsListQueuesTarget, - sqsGetQueueUrlTarget, - sqsGetQueueAttributesTarget, - sqsSetQueueAttributesTarget, sqsPurgeQueueTarget, sqsSendMessageTarget, sqsSendMessageBatchTarget, From 61fa917675b5d30126c580331f96309490da1084 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Fri, 24 Apr 2026 15:20:42 +0900 Subject: [PATCH 4/9] feat(sqs): SigV4 verifier shared with S3, wire static credentials Refactor S3's SigV4 primitives into adapter/sigv4.go and implement the SQS JSON-protocol verifier on top of them. - adapter/sigv4.go: service-agnostic SigV4 helpers (parseSigV4AuthorizationHeader, buildSigV4AuthorizationHeader, buildSigV4AuthorizationHeaderRestricted, extractSigV4Signature / SignedHeaders / Auth fields). The *Restricted builder strips transport-added headers that are not listed in SignedHeaders, so re-signing a server-side request does not fail because Go's http client added Accept-Encoding after the client signed. - adapter/s3_auth.go: delegates parse/build/extract to the shared helpers. Error ordering and the presigned-URL path are unchanged to keep the existing S3 tests green. - adapter/sqs_auth.go: WithSQSRegion / WithSQSStaticCredentials options and the JSON-protocol verifier, split into validateSQSAuthScope / validateSQSAuthDate / drainAndHashSQSBody / verifySQSSignatureMatches so each pass is independently testable. When no credentials are configured the endpoint stays open (matching S3's scaffold-test-friendly behavior). - adapter/sqs.go: SQSServer gains region + staticCreds fields and calls authorizeSQSRequest after the leader-proxy check, before dispatching to the handler map. - adapter/sqs_auth_test.go: happy-path SigV4, missing Authorization, wrong-service scope, unknown access key, tampered body, clock skew, and the no-credentials open-endpoint case. All tests drive requests through the actual AWS SDK v4 signer so the adapter stays compatible with real SDK callers. - main_sigv4_creds.go: loadSigV4StaticCredentialsFile, the shared JSON credentials loader; main_s3.go.loadS3StaticCredentials now delegates to it. 
- main_sqs.go + main.go: --sqsRegion / --sqsCredentialsFile flags and per-runner fields, passed through to WithSQSRegion / WithSQSStaticCredentials. --- adapter/s3_auth.go | 85 ++------------- adapter/s3_test.go | 2 +- adapter/sigv4.go | 203 ++++++++++++++++++++++++++++++++++ adapter/sqs.go | 7 ++ adapter/sqs_auth.go | 178 ++++++++++++++++++++++++++++++ adapter/sqs_auth_test.go | 230 +++++++++++++++++++++++++++++++++++++++ main.go | 8 +- main_s3.go | 39 +------ main_sigv4_creds.go | 58 ++++++++++ main_sqs.go | 9 ++ 10 files changed, 703 insertions(+), 116 deletions(-) create mode 100644 adapter/sigv4.go create mode 100644 adapter/sqs_auth.go create mode 100644 adapter/sqs_auth_test.go create mode 100644 main_sigv4_creds.go diff --git a/adapter/s3_auth.go b/adapter/s3_auth.go index 1df6070de..15a61bbd1 100644 --- a/adapter/s3_auth.go +++ b/adapter/s3_auth.go @@ -1,7 +1,6 @@ package adapter import ( - "context" "crypto/subtle" "net/http" "net/url" @@ -12,11 +11,12 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" "github.com/bootjp/elastickv/kv" - "github.com/cockroachdb/errors" ) const ( - s3SigV4Algorithm = "AWS4-HMAC-SHA256" + // s3SigV4Algorithm is an alias of sigv4Algorithm kept for call-site + // readability in the S3 adapter. + s3SigV4Algorithm = sigv4Algorithm // s3UnsignedPayload / s3Streaming* are sentinel values that AWS SDKs may // place in X-Amz-Content-Sha256. None of them are a literal SHA-256 hash // of the body, so the PUT pipeline must skip hash validation when it sees @@ -27,8 +27,8 @@ const ( s3StreamingSignedPayload = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD" s3StreamingSignedPayloadTrailer = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD-TRAILER" s3EmptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" - s3DateHeaderFormat = "20060102T150405Z" - s3RequestTimeMaxSkew = 15 * time.Minute + s3DateHeaderFormat = sigv4DateHeaderFormat + s3RequestTimeMaxSkew = sigv4RequestTimeMaxSkew ) // isS3PayloadMarker reports whether the given X-Amz-Content-Sha256 value is a @@ -79,14 +79,6 @@ type s3AuthError struct { Message string } -type s3AuthorizationHeader struct { - Algorithm string - AccessKeyID string - Date string - Region string - Service string -} - func WithS3Region(region string) S3ServerOption { return func(server *S3Server) { if server == nil || strings.TrimSpace(region) == "" { @@ -151,7 +143,7 @@ func (s *S3Server) authorizeRequest(r *http.Request) *s3AuthError { } } - parsed, err := parseS3AuthorizationHeader(authHeader) + parsed, err := parseSigV4AuthorizationHeader(authHeader) if err != nil { return &s3AuthError{ Status: http.StatusForbidden, @@ -224,7 +216,7 @@ func (s *S3Server) authorizeRequest(r *http.Request) *s3AuthError { } } - expectedAuth, err := buildS3AuthorizationHeader(r, parsed.AccessKeyID, secretAccessKey, s.effectiveRegion(), signingTime, payloadHash) + expectedAuth, err := buildSigV4AuthorizationHeader(r, parsed.AccessKeyID, secretAccessKey, "s3", s.effectiveRegion(), signingTime, payloadHash) if err != nil { return &s3AuthError{ Status: http.StatusForbidden, @@ -235,8 +227,8 @@ func (s *S3Server) authorizeRequest(r *http.Request) *s3AuthError { // Compare only the Signature component to avoid false rejections caused by // equivalent Authorization headers that differ in whitespace or parameter // ordering (but carry the same cryptographic signature). 
- gotSig := extractS3Signature(authHeader) - expectedSig := extractS3Signature(expectedAuth) + gotSig := extractSigV4Signature(authHeader) + expectedSig := extractSigV4Signature(expectedAuth) if gotSig == "" || subtle.ConstantTimeCompare([]byte(gotSig), []byte(expectedSig)) != 1 { return &s3AuthError{ Status: http.StatusForbidden, @@ -247,65 +239,6 @@ func (s *S3Server) authorizeRequest(r *http.Request) *s3AuthError { return nil } -func buildS3AuthorizationHeader(r *http.Request, accessKeyID string, secretAccessKey string, region string, signingTime time.Time, payloadHash string) (string, error) { - if r == nil { - return "", errors.New("request is required") - } - clone := r.Clone(context.Background()) - clone.Host = r.Host - clone.Header = clone.Header.Clone() - clone.Header.Del("Authorization") - - signer := v4.NewSigner(func(opts *v4.SignerOptions) { - opts.DisableURIPathEscaping = true - }) - creds := aws.Credentials{ - AccessKeyID: accessKeyID, - SecretAccessKey: secretAccessKey, - Source: "elastickv-s3", - } - if err := signer.SignHTTP(context.Background(), creds, clone, payloadHash, "s3", region, signingTime.UTC()); err != nil { - return "", errors.WithStack(err) - } - return strings.TrimSpace(clone.Header.Get("Authorization")), nil -} - -//nolint:cyclop // AWS authorization parsing is branchy because malformed scopes must be rejected precisely. -func parseS3AuthorizationHeader(raw string) (s3AuthorizationHeader, error) { - raw = strings.TrimSpace(raw) - if raw == "" { - return s3AuthorizationHeader{}, errors.New("authorization header is required") - } - algorithm, rest, ok := strings.Cut(raw, " ") - if !ok { - return s3AuthorizationHeader{}, errors.New("authorization header is malformed") - } - out := s3AuthorizationHeader{Algorithm: strings.TrimSpace(algorithm)} - params := strings.Split(rest, ",") - for _, param := range params { - key, value, ok := strings.Cut(strings.TrimSpace(param), "=") - if !ok { - continue - } - if key != "Credential" { - continue - } - scope := strings.Split(value, "/") - if len(scope) != 5 || scope[4] != "aws4_request" { - return s3AuthorizationHeader{}, errors.New("credential scope is malformed") - } - out.AccessKeyID = scope[0] - out.Date = scope[1] - out.Region = scope[2] - out.Service = scope[3] - break - } - if out.AccessKeyID == "" || out.Date == "" || out.Region == "" || out.Service == "" { - return s3AuthorizationHeader{}, errors.New("credential scope is required") - } - return out, nil -} - func normalizeS3PayloadHash(raw string) string { return strings.TrimSpace(raw) } diff --git a/adapter/s3_test.go b/adapter/s3_test.go index f3f27d2be..457a540bd 100644 --- a/adapter/s3_test.go +++ b/adapter/s3_test.go @@ -874,7 +874,7 @@ func newSignedS3Request( signingTime, ) require.NoError(t, err) - expectedAuth, err := buildS3AuthorizationHeader(req, testS3AccessKey, testS3SecretKey, testS3Region, signingTime, payloadHash) + expectedAuth, err := buildSigV4AuthorizationHeader(req, testS3AccessKey, testS3SecretKey, "s3", testS3Region, signingTime, payloadHash) require.NoError(t, err) require.Equal(t, strings.TrimSpace(req.Header.Get("Authorization")), expectedAuth) return req diff --git a/adapter/sigv4.go b/adapter/sigv4.go new file mode 100644 index 000000000..1f26ff103 --- /dev/null +++ b/adapter/sigv4.go @@ -0,0 +1,203 @@ +package adapter + +import ( + "context" + "net/http" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/cockroachdb/errors" +) + +// Service-agnostic SigV4 
primitives shared by the S3 and SQS adapters.
+// The per-adapter files wrap these with their own error ordering and
+// service-specific rules (payload hashing for S3, JSON body hashing for
+// SQS, etc.).
+
+const (
+	sigv4Algorithm          = "AWS4-HMAC-SHA256"
+	sigv4DateHeaderFormat   = "20060102T150405Z"
+	sigv4ScopeDateFormat    = "20060102"
+	sigv4RequestTimeMaxSkew = 15 * time.Minute
+	sigv4ScopeTerminator    = "aws4_request"
+	sigv4ScopeParts         = 5
+)
+
+// sigv4AuthorizationHeader captures the Credential scope fields extracted
+// from a SigV4 Authorization header ("AWS4-HMAC-SHA256 Credential=<access-key-id>/
+// <date>/<region>/<service>/aws4_request, SignedHeaders=..., Signature=...").
+type sigv4AuthorizationHeader struct {
+	Algorithm   string
+	AccessKeyID string
+	Date        string
+	Region      string
+	Service     string
+}
+
+// parseSigV4AuthorizationHeader decodes the Credential scope from an
+// Authorization header. Other fields (SignedHeaders, Signature) are not
+// parsed here because the adapter-level verifier rebuilds and compares the
+// full signature through the AWS SDK v4 signer.
+func parseSigV4AuthorizationHeader(raw string) (sigv4AuthorizationHeader, error) {
+	raw = strings.TrimSpace(raw)
+	if raw == "" {
+		return sigv4AuthorizationHeader{}, errors.New("authorization header is required")
+	}
+	algorithm, rest, ok := strings.Cut(raw, " ")
+	if !ok {
+		return sigv4AuthorizationHeader{}, errors.New("authorization header is malformed")
+	}
+	credentialValue, ok := findSigV4Param(rest, "Credential")
+	if !ok {
+		return sigv4AuthorizationHeader{}, errors.New("credential scope is required")
+	}
+	out, err := parseSigV4CredentialScope(credentialValue)
+	if err != nil {
+		return sigv4AuthorizationHeader{}, err
+	}
+	out.Algorithm = strings.TrimSpace(algorithm)
+	return out, nil
+}
+
+// findSigV4Param returns the value of the named parameter from the
+// "k=v, k=v, ..." tail of an Authorization header.
+func findSigV4Param(params, name string) (string, bool) {
+	for _, param := range strings.Split(params, ",") {
+		key, value, ok := strings.Cut(strings.TrimSpace(param), "=")
+		if ok && key == name {
+			return strings.TrimSpace(value), true
+		}
+	}
+	return "", false
+}
+
+// parseSigV4CredentialScope validates an
+// <access-key-id>/<date>/<region>/<service>/aws4_request scope string.
+func parseSigV4CredentialScope(value string) (sigv4AuthorizationHeader, error) {
+	scope := strings.Split(value, "/")
+	if len(scope) != sigv4ScopeParts || scope[4] != sigv4ScopeTerminator {
+		return sigv4AuthorizationHeader{}, errors.New("credential scope is malformed")
+	}
+	out := sigv4AuthorizationHeader{
+		AccessKeyID: scope[0],
+		Date:        scope[1],
+		Region:      scope[2],
+		Service:     scope[3],
+	}
+	if out.AccessKeyID == "" || out.Date == "" || out.Region == "" || out.Service == "" {
+		return sigv4AuthorizationHeader{}, errors.New("credential scope is required")
+	}
+	return out, nil
+}
+
+// buildSigV4AuthorizationHeader re-signs r with the given credentials and
+// returns the Authorization header the server expects. Used by adapter
+// verifiers to compare against the client-supplied Authorization.
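+//
+// Callers compare only the Signature parameter of the result (see
+// verifySQSSignatureMatches and the S3 verifier), so whitespace and
+// parameter-ordering differences between equivalent headers are harmless.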
+func buildSigV4AuthorizationHeader( + r *http.Request, + accessKeyID, secretAccessKey, service, region string, + signingTime time.Time, + payloadHash string, +) (string, error) { + if r == nil { + return "", errors.New("request is required") + } + clone := r.Clone(context.Background()) + clone.Host = r.Host + clone.Header = clone.Header.Clone() + clone.Header.Del("Authorization") + + signer := v4.NewSigner(func(opts *v4.SignerOptions) { + opts.DisableURIPathEscaping = true + }) + creds := aws.Credentials{ + AccessKeyID: accessKeyID, + SecretAccessKey: secretAccessKey, + Source: "elastickv-" + service, + } + if err := signer.SignHTTP(context.Background(), creds, clone, payloadHash, service, region, signingTime.UTC()); err != nil { + return "", errors.WithStack(err) + } + return strings.TrimSpace(clone.Header.Get("Authorization")), nil +} + +// extractSigV4Signature returns the hex Signature= value from an +// Authorization header, or "" if missing. +func extractSigV4Signature(auth string) string { + return extractSigV4AuthField(auth, "Signature") +} + +// extractSigV4SignedHeaders returns the semicolon-separated SignedHeaders +// list as a []string, or nil if missing. Header names are lowercased. +func extractSigV4SignedHeaders(auth string) []string { + raw := extractSigV4AuthField(auth, "SignedHeaders") + if raw == "" { + return nil + } + parts := strings.Split(raw, ";") + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + out = append(out, strings.ToLower(p)) + } + return out +} + +func extractSigV4AuthField(auth, field string) string { + _, params, ok := strings.Cut(auth, " ") + if !ok { + return "" + } + v, _ := findSigV4Param(params, field) + return v +} + +// buildSigV4AuthorizationHeaderRestricted is a variant of +// buildSigV4AuthorizationHeader that strips any header not listed in +// signedHeaders before re-signing, so the server's re-computation matches +// the client's canonical request even when Go's transport added headers +// (Accept-Encoding, etc.) after the original signature. 
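+//
+// Typical case: the client signs host;x-amz-date;x-amz-target, then Go's
+// transport appends Accept-Encoding: gzip in flight; dropping the unsigned
+// header before re-signing keeps both canonical requests identical.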
+func buildSigV4AuthorizationHeaderRestricted( + r *http.Request, + accessKeyID, secretAccessKey, service, region string, + signingTime time.Time, + payloadHash string, + signedHeaders []string, +) (string, error) { + if r == nil { + return "", errors.New("request is required") + } + clone := r.Clone(context.Background()) + clone.Host = r.Host + clone.Header = clone.Header.Clone() + clone.Header.Del("Authorization") + if len(signedHeaders) > 0 { + allowed := make(map[string]struct{}, len(signedHeaders)) + for _, h := range signedHeaders { + allowed[strings.ToLower(strings.TrimSpace(h))] = struct{}{} + } + for name := range clone.Header { + if _, ok := allowed[strings.ToLower(name)]; !ok { + clone.Header.Del(name) + } + } + } + + signer := v4.NewSigner(func(opts *v4.SignerOptions) { + opts.DisableURIPathEscaping = true + }) + creds := aws.Credentials{ + AccessKeyID: accessKeyID, + SecretAccessKey: secretAccessKey, + Source: "elastickv-" + service, + } + if err := signer.SignHTTP(context.Background(), creds, clone, payloadHash, service, region, signingTime.UTC()); err != nil { + return "", errors.WithStack(err) + } + return strings.TrimSpace(clone.Header.Get("Authorization")), nil +} diff --git a/adapter/sqs.go b/adapter/sqs.go index d0ec3c150..6b30085bc 100644 --- a/adapter/sqs.go +++ b/adapter/sqs.go @@ -67,6 +67,8 @@ type SQSServer struct { httpServer *http.Server targetHandlers map[string]func(http.ResponseWriter, *http.Request) leaderSQS map[string]string + region string + staticCreds map[string]string } // WithSQSLeaderMap configures the Raft-address-to-SQS-address mapping used to @@ -144,6 +146,11 @@ func (s *SQSServer) handle(w http.ResponseWriter, r *http.Request) { return } + if authErr := s.authorizeSQSRequest(r); authErr != nil { + writeSQSError(w, authErr.Status, authErr.Code, authErr.Message) + return + } + target := r.Header.Get("X-Amz-Target") handler, ok := s.targetHandlers[target] if !ok { diff --git a/adapter/sqs_auth.go b/adapter/sqs_auth.go new file mode 100644 index 000000000..c8f49ea77 --- /dev/null +++ b/adapter/sqs_auth.go @@ -0,0 +1,178 @@ +package adapter + +import ( + "bytes" + "crypto/sha256" + "crypto/subtle" + "encoding/hex" + "io" + "net/http" + "strings" + "time" +) + +// Service name used in the SigV4 credential scope. +const sqsSigV4Service = "sqs" + +// sqsAuthError is the SQS-flavored counterpart to s3AuthError. The outer +// handler turns these into the AWS JSON-1.0 error envelope via +// writeSQSError. +type sqsAuthError struct { + Status int + Code string + Message string +} + +// WithSQSRegion configures the signing region the adapter expects inside +// the Credential scope. Empty values retain the previous setting. +func WithSQSRegion(region string) SQSServerOption { + return func(s *SQSServer) { + if s == nil || strings.TrimSpace(region) == "" { + return + } + s.region = strings.TrimSpace(region) + } +} + +// WithSQSStaticCredentials supplies the access-key → secret map the +// adapter will accept. Passing an empty map disables authorization +// entirely (open endpoint), matching the S3 adapter's behavior for +// unit-test friendliness. 
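+//
+// Entries with a blank access key or secret are skipped rather than
+// rejected; if every entry is blank the map is treated as absent.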
+func WithSQSStaticCredentials(creds map[string]string) SQSServerOption { + return func(s *SQSServer) { + if s == nil || len(creds) == 0 { + return + } + s.staticCreds = make(map[string]string, len(creds)) + for accessKeyID, secretAccessKey := range creds { + accessKeyID = strings.TrimSpace(accessKeyID) + secretAccessKey = strings.TrimSpace(secretAccessKey) + if accessKeyID == "" || secretAccessKey == "" { + continue + } + s.staticCreds[accessKeyID] = secretAccessKey + } + if len(s.staticCreds) == 0 { + s.staticCreds = nil + } + } +} + +func (s *SQSServer) effectiveRegion() string { + if s == nil || strings.TrimSpace(s.region) == "" { + return "us-east-1" + } + return s.region +} + +// authorizeSQSRequest verifies SigV4 on an SQS JSON-protocol request. The +// flow mirrors S3's header-based path (no presigned URLs or streaming +// payloads) but restricts the credential scope service to "sqs". +// +// Returns nil when authorization is either unconfigured (open endpoint) or +// the signature matches. Returns *sqsAuthError otherwise, suitable for +// writeSQSError. +// +// The function must consume and replace r.Body so that downstream handlers +// can still parse the JSON payload. +func (s *SQSServer) authorizeSQSRequest(r *http.Request) *sqsAuthError { + if s == nil || r == nil || len(s.staticCreds) == 0 { + return nil + } + authHeader := strings.TrimSpace(r.Header.Get("Authorization")) + if authHeader == "" { + return &sqsAuthError{Status: http.StatusForbidden, Code: "MissingAuthenticationToken", Message: "missing Authorization header"} + } + parsed, authErr := s.validateSQSAuthScope(authHeader) + if authErr != nil { + return authErr + } + secretAccessKey, ok := s.staticCreds[parsed.AccessKeyID] + if !ok { + return &sqsAuthError{Status: http.StatusForbidden, Code: "InvalidClientTokenId", Message: "unknown access key"} + } + signingTime, authErr := validateSQSAuthDate(r, parsed) + if authErr != nil { + return authErr + } + payloadHash, authErr := drainAndHashSQSBody(r) + if authErr != nil { + return authErr + } + return verifySQSSignatureMatches(r, authHeader, parsed.AccessKeyID, secretAccessKey, s.effectiveRegion(), signingTime, payloadHash) +} + +// validateSQSAuthScope parses the Credential scope and rejects scopes +// whose algorithm/service/region do not match this adapter. +func (s *SQSServer) validateSQSAuthScope(authHeader string) (sigv4AuthorizationHeader, *sqsAuthError) { + parsed, err := parseSigV4AuthorizationHeader(authHeader) + if err != nil { + return sigv4AuthorizationHeader{}, &sqsAuthError{Status: http.StatusForbidden, Code: "IncompleteSignature", Message: "invalid Authorization header"} + } + if parsed.Algorithm != sigv4Algorithm { + return sigv4AuthorizationHeader{}, &sqsAuthError{Status: http.StatusForbidden, Code: "IncompleteSignature", Message: "unsupported signature algorithm"} + } + if parsed.Service != sqsSigV4Service { + return sigv4AuthorizationHeader{}, &sqsAuthError{Status: http.StatusForbidden, Code: "SignatureDoesNotMatch", Message: "credential scope service must be sqs"} + } + if parsed.Region != s.effectiveRegion() { + return sigv4AuthorizationHeader{}, &sqsAuthError{Status: http.StatusForbidden, Code: "SignatureDoesNotMatch", Message: "credential scope region does not match server region"} + } + return parsed, nil +} + +// validateSQSAuthDate parses X-Amz-Date and verifies it matches the +// Credential scope date and is within the allowed clock-skew window. 
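+//
+// The skew check is symmetric: a timestamp too far in the future is
+// rejected with the same RequestTimeTooSkewed error as a stale one.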
+func validateSQSAuthDate(r *http.Request, parsed sigv4AuthorizationHeader) (time.Time, *sqsAuthError) { + amzDate := strings.TrimSpace(r.Header.Get("X-Amz-Date")) + if amzDate == "" { + return time.Time{}, &sqsAuthError{Status: http.StatusForbidden, Code: "MissingAuthenticationToken", Message: "missing x-amz-date header"} + } + signingTime, err := time.Parse(sigv4DateHeaderFormat, amzDate) + if err != nil { + return time.Time{}, &sqsAuthError{Status: http.StatusForbidden, Code: "IncompleteSignature", Message: "invalid x-amz-date header"} + } + if parsed.Date != signingTime.UTC().Format(sigv4ScopeDateFormat) { + return time.Time{}, &sqsAuthError{Status: http.StatusForbidden, Code: "SignatureDoesNotMatch", Message: "credential scope date does not match x-amz-date"} + } + skew := time.Now().UTC().Sub(signingTime.UTC()) + if skew < 0 { + skew = -skew + } + if skew > sigv4RequestTimeMaxSkew { + return time.Time{}, &sqsAuthError{Status: http.StatusForbidden, Code: "RequestTimeTooSkewed", Message: "The difference between the request time and the server's time is too large"} + } + return signingTime, nil +} + +// drainAndHashSQSBody reads the request body so the signer reproduces the +// client's payload hash, then replaces r.Body so handler code can re-read +// it afterwards. +func drainAndHashSQSBody(r *http.Request) (string, *sqsAuthError) { + body, err := io.ReadAll(http.MaxBytesReader(nil, r.Body, sqsMaxRequestBodyBytes)) + if err != nil { + return "", &sqsAuthError{Status: http.StatusForbidden, Code: "IncompleteSignature", Message: "failed to read request body for signature verification"} + } + r.Body = io.NopCloser(bytes.NewReader(body)) + sum := sha256.Sum256(body) + return hex.EncodeToString(sum[:]), nil +} + +// verifySQSSignatureMatches rebuilds the expected Authorization header and +// compares its hex signature to the one the client sent. +// +// The restricted builder is used so Go's transport-added headers +// (Accept-Encoding etc.) do not leak into the canonical request. 
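+//
+// Signatures are compared with subtle.ConstantTimeCompare so a mismatch is
+// not observable through response timing.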
+func verifySQSSignatureMatches(r *http.Request, authHeader, accessKeyID, secretAccessKey, region string, signingTime time.Time, payloadHash string) *sqsAuthError { + signedHeaders := extractSigV4SignedHeaders(authHeader) + expectedAuth, err := buildSigV4AuthorizationHeaderRestricted(r, accessKeyID, secretAccessKey, sqsSigV4Service, region, signingTime, payloadHash, signedHeaders) + if err != nil { + return &sqsAuthError{Status: http.StatusForbidden, Code: "SignatureDoesNotMatch", Message: "failed to verify request signature"} + } + gotSig := extractSigV4Signature(authHeader) + expectedSig := extractSigV4Signature(expectedAuth) + if gotSig == "" || subtle.ConstantTimeCompare([]byte(gotSig), []byte(expectedSig)) != 1 { + return &sqsAuthError{Status: http.StatusForbidden, Code: "SignatureDoesNotMatch", Message: "request signature does not match"} + } + return nil +} diff --git a/adapter/sqs_auth_test.go b/adapter/sqs_auth_test.go new file mode 100644 index 000000000..a8f815513 --- /dev/null +++ b/adapter/sqs_auth_test.go @@ -0,0 +1,230 @@ +package adapter + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "io" + "net" + "net/http" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + json "github.com/goccy/go-json" +) + +const ( + testSQSAccessKey = "AKIATESTSQSAAAAAAAAA" + testSQSSecretKey = "test-secret-key/xxxxxxxxxxxxxxxxxxxxxxxx" + testSQSRegion = "us-west-2" +) + +// startAuthedSQSServer brings up an SQSServer with static credentials so +// we can exercise the full SigV4 verifier. No coordinator is wired in, so +// the server only executes the auth + health paths (handler requests land +// on unknown targets, returning 400 InvalidAction after auth passes). +func startAuthedSQSServer(t *testing.T) string { + t.Helper() + lc := &net.ListenConfig{} + listener, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + srv := NewSQSServer(listener, nil, nil, + WithSQSRegion(testSQSRegion), + WithSQSStaticCredentials(map[string]string{testSQSAccessKey: testSQSSecretKey}), + ) + go func() { + _ = srv.Run() + }() + t.Cleanup(func() { + srv.Stop() + }) + return "http://" + listener.Addr().String() +} + +// signSQSRequest signs an SQS request with the AWS SDK v4 signer against +// the supplied credentials, exactly as a real AWS SDK client would. 
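+// The payload hash is computed over the body bytes passed in, so tests can
+// swap r.Body afterwards to simulate in-flight tampering.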
+func signSQSRequest(t *testing.T, base, target string, body []byte, signingTime time.Time, accessKey, secretKey string) *http.Request { + region := testSQSRegion + t.Helper() + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, base+"/", bytes.NewReader(body)) + if err != nil { + t.Fatalf("request: %v", err) + } + req.Header.Set("Content-Type", sqsContentTypeJSON) + req.Header.Set("X-Amz-Target", target) + sum := sha256.Sum256(body) + payloadHash := hex.EncodeToString(sum[:]) + + signer := v4.NewSigner(func(opts *v4.SignerOptions) { + opts.DisableURIPathEscaping = true + }) + if err := signer.SignHTTP(context.Background(), aws.Credentials{ + AccessKeyID: accessKey, + SecretAccessKey: secretKey, + Source: "test", + }, req, payloadHash, sqsSigV4Service, region, signingTime.UTC()); err != nil { + t.Fatalf("sign: %v", err) + } + return req +} + +func doReq(t *testing.T, req *http.Request) (int, map[string]string) { + t.Helper() + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("do: %v", err) + } + defer resp.Body.Close() + raw, _ := io.ReadAll(resp.Body) + out := map[string]string{} + if len(bytes.TrimSpace(raw)) > 0 { + _ = json.Unmarshal(raw, &out) + } + return resp.StatusCode, out +} + +func TestSQSAuth_ValidSignaturePassesAuth(t *testing.T) { + t.Parallel() + base := startAuthedSQSServer(t) + body := []byte(`{}`) + req := signSQSRequest(t, base, "AmazonSQS.NoSuchOperation", body, time.Now().UTC(), + testSQSAccessKey, testSQSSecretKey) + + // Auth should pass; then the target is unknown, so the dispatcher + // returns InvalidAction — *not* SignatureDoesNotMatch. That precise + // ordering is what we care about: auth runs first, dispatch second. + status, body2 := doReq(t, req) + if status != http.StatusBadRequest { + t.Fatalf("status: got %d body %v", status, body2) + } + if body2["__type"] != sqsErrInvalidAction { + t.Fatalf("error type: got %q want %q", body2["__type"], sqsErrInvalidAction) + } +} + +func TestSQSAuth_MissingAuthorizationHeaderRejected(t *testing.T) { + t.Parallel() + base := startAuthedSQSServer(t) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, base+"/", bytes.NewReader([]byte("{}"))) + if err != nil { + t.Fatalf("request: %v", err) + } + req.Header.Set("Content-Type", sqsContentTypeJSON) + req.Header.Set("X-Amz-Target", sqsListQueuesTarget) + + status, body := doReq(t, req) + if status != http.StatusForbidden { + t.Fatalf("status: got %d body %v", status, body) + } + if body["__type"] != "MissingAuthenticationToken" { + t.Fatalf("error type: got %q", body["__type"]) + } +} + +func TestSQSAuth_WrongServiceScopeRejected(t *testing.T) { + t.Parallel() + base := startAuthedSQSServer(t) + body := []byte(`{}`) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, base+"/", bytes.NewReader(body)) + if err != nil { + t.Fatalf("request: %v", err) + } + req.Header.Set("Content-Type", sqsContentTypeJSON) + req.Header.Set("X-Amz-Target", sqsListQueuesTarget) + sum := sha256.Sum256(body) + payloadHash := hex.EncodeToString(sum[:]) + signer := v4.NewSigner() + // Sign with service "s3" instead of "sqs". 
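+	// The signature itself is cryptographically valid; only the service
+	// segment of the credential scope is wrong, which validateSQSAuthScope
+	// rejects before any signature comparison runs.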
+ if err := signer.SignHTTP(context.Background(), aws.Credentials{ + AccessKeyID: testSQSAccessKey, + SecretAccessKey: testSQSSecretKey, + }, req, payloadHash, "s3", testSQSRegion, time.Now().UTC()); err != nil { + t.Fatalf("sign: %v", err) + } + + status, out := doReq(t, req) + if status != http.StatusForbidden { + t.Fatalf("status: got %d body %v", status, out) + } + if out["__type"] != "SignatureDoesNotMatch" { + t.Fatalf("error type: got %q", out["__type"]) + } +} + +func TestSQSAuth_UnknownAccessKeyRejected(t *testing.T) { + t.Parallel() + base := startAuthedSQSServer(t) + req := signSQSRequest(t, base, sqsListQueuesTarget, []byte("{}"), time.Now().UTC(), + "AKIAUNKNOWNEEEEEEEEE", "wrong-secret-keyxxxxxxxxxxxxxxxxxxxxxxxx") + + status, out := doReq(t, req) + if status != http.StatusForbidden { + t.Fatalf("status: got %d body %v", status, out) + } + if out["__type"] != "InvalidClientTokenId" { + t.Fatalf("error type: got %q", out["__type"]) + } +} + +func TestSQSAuth_TamperedBodyRejected(t *testing.T) { + t.Parallel() + base := startAuthedSQSServer(t) + // Sign the original body, then send a different one — the server must + // hash the body it actually receives and catch the mismatch. + req := signSQSRequest(t, base, sqsListQueuesTarget, []byte(`{"QueueNamePrefix":"a"}`), time.Now().UTC(), + testSQSAccessKey, testSQSSecretKey) + req.Body = io.NopCloser(bytes.NewReader([]byte(`{"QueueNamePrefix":"b"}`))) + req.ContentLength = int64(len(`{"QueueNamePrefix":"b"}`)) + + status, out := doReq(t, req) + if status != http.StatusForbidden { + t.Fatalf("status: got %d body %v", status, out) + } + if out["__type"] != "SignatureDoesNotMatch" { + t.Fatalf("error type: got %q body %v", out["__type"], out) + } +} + +func TestSQSAuth_ClockSkewRejected(t *testing.T) { + t.Parallel() + base := startAuthedSQSServer(t) + old := time.Now().UTC().Add(-30 * time.Minute) + req := signSQSRequest(t, base, sqsListQueuesTarget, []byte("{}"), old, + testSQSAccessKey, testSQSSecretKey) + + status, out := doReq(t, req) + if status != http.StatusForbidden { + t.Fatalf("status: got %d body %v", status, out) + } + if out["__type"] != "RequestTimeTooSkewed" { + t.Fatalf("error type: got %q body %v", out["__type"], out) + } +} + +func TestSQSAuth_NoCredentialsMeansOpenEndpoint(t *testing.T) { + t.Parallel() + // No creds configured → scaffold test server already exercises this + // path (no Authorization header, no 403). This test explicitly + // double-checks: the dispatch runs even without auth, landing on + // InvalidAction for an unknown target. 
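+	// startTestSQSServer wires no static credentials, so authorizeSQSRequest
+	// returns nil immediately and dispatch proceeds to target lookup.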
+ base := startTestSQSServer(t) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, base+"/", bytes.NewReader([]byte("{}"))) + if err != nil { + t.Fatalf("request: %v", err) + } + req.Header.Set("Content-Type", sqsContentTypeJSON) + req.Header.Set("X-Amz-Target", "AmazonSQS.NoSuchOperation") + + status, out := doReq(t, req) + if status != http.StatusBadRequest { + t.Fatalf("status: got %d body %v", status, out) + } + if out["__type"] != sqsErrInvalidAction { + t.Fatalf("error type: got %q", out["__type"]) + } +} diff --git a/main.go b/main.go index a9bb6dbac..2e5e3c4ac 100644 --- a/main.go +++ b/main.go @@ -80,6 +80,8 @@ var ( s3CredsFile = flag.String("s3CredentialsFile", "", "Path to a JSON file containing static S3 credentials") s3PathStyleOnly = flag.Bool("s3PathStyleOnly", true, "Only accept path-style S3 requests") sqsAddr = flag.String("sqsAddress", "", "TCP host+port for SQS-compatible API; empty to disable") + sqsRegion = flag.String("sqsRegion", "us-east-1", "SQS signing region") + sqsCredsFile = flag.String("sqsCredentialsFile", "", "Path to a JSON file containing static SQS credentials") metricsAddr = flag.String("metricsAddress", "localhost:9090", "TCP host+port for Prometheus metrics") metricsToken = flag.String("metricsToken", "", "Bearer token for Prometheus metrics; required for non-loopback metricsAddress") pprofAddr = flag.String("pprofAddress", "localhost:6060", "TCP host+port for pprof debug endpoints; empty to disable") @@ -216,6 +218,8 @@ func run() error { s3PathStyleOnly: *s3PathStyleOnly, sqsAddress: *sqsAddr, leaderSQS: cfg.leaderSQS, + sqsRegion: *sqsRegion, + sqsCredsFile: *sqsCredsFile, metricsAddress: *metricsAddr, metricsToken: *metricsToken, pprofAddress: *pprofAddr, @@ -840,6 +844,8 @@ type runtimeServerRunner struct { s3PathStyleOnly bool sqsAddress string leaderSQS map[string]string + sqsRegion string + sqsCredsFile string metricsAddress string metricsToken string pprofAddress string @@ -872,7 +878,7 @@ func (r runtimeServerRunner) start() error { if err := startS3Server(r.ctx, r.lc, r.eg, r.s3Address, r.shardStore, r.coordinate, r.leaderS3, r.s3Region, r.s3CredsFile, r.s3PathStyleOnly, r.readTracker); err != nil { return waitErrgroupAfterStartupFailure(r.cancel, r.eg, err) } - if err := startSQSServer(r.ctx, r.lc, r.eg, r.sqsAddress, r.shardStore, r.coordinate, r.leaderSQS); err != nil { + if err := startSQSServer(r.ctx, r.lc, r.eg, r.sqsAddress, r.shardStore, r.coordinate, r.leaderSQS, r.sqsRegion, r.sqsCredsFile); err != nil { return waitErrgroupAfterStartupFailure(r.cancel, r.eg, err) } if err := startMetricsServer(r.ctx, r.lc, r.eg, r.metricsAddress, r.metricsToken, r.metricsRegistry.Handler()); err != nil { diff --git a/main_s3.go b/main_s3.go index 7fae6f16d..5de290c7f 100644 --- a/main_s3.go +++ b/main_s3.go @@ -2,10 +2,7 @@ package main import ( "context" - "encoding/json" - "fmt" "net" - "os" "strings" "github.com/bootjp/elastickv/adapter" @@ -14,15 +11,6 @@ import ( "golang.org/x/sync/errgroup" ) -type s3CredentialFile struct { - Credentials []s3CredentialEntry `json:"credentials"` -} - -type s3CredentialEntry struct { - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"secret_access_key"` -} - func startS3Server( ctx context.Context, lc *net.ListenConfig, @@ -83,30 +71,5 @@ func startS3Server( } func loadS3StaticCredentials(path string) (map[string]string, error) { - path = strings.TrimSpace(path) - if path == "" { - return nil, nil - } - f, err := os.Open(path) - if err != nil { - return 
nil, errors.WithStack(err) - } - defer f.Close() - file := s3CredentialFile{} - if err := json.NewDecoder(f).Decode(&file); err != nil { - return nil, errors.WithStack(err) - } - out := make(map[string]string, len(file.Credentials)) - for _, cred := range file.Credentials { - accessKeyID := strings.TrimSpace(cred.AccessKeyID) - secretAccessKey := strings.TrimSpace(cred.SecretAccessKey) - if accessKeyID == "" || secretAccessKey == "" { - return nil, errors.New("s3 credentials file contains an empty access key or secret key") - } - if _, exists := out[accessKeyID]; exists { - return nil, errors.WithStack(fmt.Errorf("s3 credentials file contains duplicate access key ID: %q", accessKeyID)) - } - out[accessKeyID] = secretAccessKey - } - return out, nil + return loadSigV4StaticCredentialsFile(path, "s3") } diff --git a/main_sigv4_creds.go b/main_sigv4_creds.go new file mode 100644 index 000000000..9ab517bf7 --- /dev/null +++ b/main_sigv4_creds.go @@ -0,0 +1,58 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + "strings" + + "github.com/cockroachdb/errors" +) + +// sigV4CredentialsFile is the JSON schema both --s3CredentialsFile and +// --sqsCredentialsFile read. Sharing the schema means operators maintain one +// file per deployment regardless of which adapters are enabled. +type sigV4CredentialsFile struct { + Credentials []sigV4CredentialEntry `json:"credentials"` +} + +type sigV4CredentialEntry struct { + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` +} + +// loadSigV4StaticCredentialsFile parses a credentials file and returns a +// map of access-key → secret suitable for WithS3StaticCredentials or +// WithSQSStaticCredentials. An empty path returns a nil map so the caller +// can leave authorization disabled. +// +// labelForErrors appears in error text ("s3 credentials file ...") so +// operators get context on which flag produced the problem. 
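+//
+// Duplicate access key IDs fail loudly instead of last-entry-wins, so a
+// copy-paste slip in the file cannot silently shadow a credential.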
+func loadSigV4StaticCredentialsFile(path string, labelForErrors string) (map[string]string, error) {
+	path = strings.TrimSpace(path)
+	if path == "" {
+		return nil, nil
+	}
+	f, err := os.Open(path)
+	if err != nil {
+		return nil, errors.WithStack(err)
+	}
+	defer f.Close()
+	file := sigV4CredentialsFile{}
+	if err := json.NewDecoder(f).Decode(&file); err != nil {
+		return nil, errors.WithStack(err)
+	}
+	out := make(map[string]string, len(file.Credentials))
+	for _, cred := range file.Credentials {
+		accessKeyID := strings.TrimSpace(cred.AccessKeyID)
+		secretAccessKey := strings.TrimSpace(cred.SecretAccessKey)
+		if accessKeyID == "" || secretAccessKey == "" {
+			return nil, errors.WithStack(fmt.Errorf("%s credentials file contains an empty access key or secret key", labelForErrors))
+		}
+		if _, exists := out[accessKeyID]; exists {
+			return nil, errors.WithStack(fmt.Errorf("%s credentials file contains duplicate access key ID: %q", labelForErrors, accessKeyID))
+		}
+		out[accessKeyID] = secretAccessKey
+	}
+	return out, nil
+}
diff --git a/main_sqs.go b/main_sqs.go
index a6371e96e..55ca41684 100644
--- a/main_sqs.go
+++ b/main_sqs.go
@@ -19,6 +19,8 @@ func startSQSServer(
 	shardStore *kv.ShardStore,
 	coordinate kv.Coordinator,
 	leaderSQS map[string]string,
+	region string,
+	credentialsFile string,
 ) error {
 	sqsAddr = strings.TrimSpace(sqsAddr)
 	if sqsAddr == "" {
@@ -28,11 +30,18 @@
 	if err != nil {
 		return errors.Wrapf(err, "failed to listen on %s", sqsAddr)
 	}
+	staticCreds, err := loadSigV4StaticCredentialsFile(credentialsFile, "sqs")
+	if err != nil {
+		_ = sqsL.Close()
+		return err
+	}
 	sqsServer := adapter.NewSQSServer(
 		sqsL,
 		shardStore,
 		coordinate,
 		adapter.WithSQSLeaderMap(leaderSQS),
+		adapter.WithSQSRegion(region),
+		adapter.WithSQSStaticCredentials(staticCreds),
 	)
 	// Two-goroutine shutdown pattern mirrors startS3Server: one goroutine waits
 	// on either ctx.Done() or Run completion to call Stop, the other runs the

From 387a19c295418ebf9f768e70edf666c265b7f4bf Mon Sep 17 00:00:00 2001
From: "Yoshiaki Ueda (bootjp)"
Date: Fri, 24 Apr 2026 15:30:07 +0900
Subject: [PATCH 5/9] feat(sqs): SendMessage, ReceiveMessage, DeleteMessage,
 ChangeMessageVisibility

Implement the core message path against the visibility-indexed keyspace
described in docs/design/2026_04_24_proposed_sqs_compatible_adapter.md.

Storage layout:
- !sqs|msg|data| - the full message record: body, MD5, attributes, send /
  available / visible timestamps, receive count, and the current rotating
  receipt token.
- !sqs|msg|vis| - a visibility index keyed by visible_at. ReceiveMessage
  scans the range [0, now+1) to find the next visible messages without any
  background sweeper.
- Receipt handles are base64url(version || queue_gen || message_id ||
  receipt_token) and carry everything Delete / ChangeVisibility need to
  locate the record and verify ownership.

Handlers:
- sendMessage: validates body size against MaximumMessageSize, resolves
  effective DelaySeconds, then writes the data record and the matching
  visibility entry in one OCC transaction.
- receiveMessage: fences the scan with LeaseReadThrough so the snapshot scan
  stays local on the leader inside the lease window and falls back to
  LinearizableRead when the lease is cold. For each candidate it runs a
  single-message OCC transaction that deletes the old vis entry, inserts a
  new one at (now + visibility timeout), rotates the receipt token, and
  bumps ReceiveCount. Race losers are skipped so a batch returns whatever
  it could deliver.
- deleteMessage: verifies the receipt token on the data record, then
  atomically drops the data + vis keys.
- changeMessageVisibility: swaps the vis entry to a new visible_at and
  updates the record; rejects messages whose visibility already expired
  with MessageNotInflight.

Tests (adapter/sqs_messages_test.go) spin up a single-node cluster and
cover:

- Send -> Receive -> Delete happy path, with MD5 and MessageId checks.
- Tampered receipt handle rejected with InvalidReceiptHandle.
- MaxNumberOfMessages returns a partial batch across two receives.
- DelaySeconds defers delivery until the delay elapses.
- Visibility-timeout expiry re-delivers the same message with
  ReceiveCount=2.
- ChangeMessageVisibility extends the in-flight window.
- Receipt-handle codec round trips for several (queue_gen, message_id,
  token) combinations.

Scaffold "returns NotImplemented" test updated to skip the newly
implemented targets (Send/Receive/Delete/ChangeMessageVisibility).
PurgeQueue, batch APIs, and tag APIs remain NotImplemented for now.
---
 adapter/sqs.go               |   8 +-
 adapter/sqs_messages.go      | 704 +++++++++++++++++++++++++++++++++++
 adapter/sqs_messages_test.go | 388 +++++++++++++++++++
 adapter/sqs_test.go          |  14 +-
 4 files changed, 1101 insertions(+), 13 deletions(-)
 create mode 100644 adapter/sqs_messages.go
 create mode 100644 adapter/sqs_messages_test.go

diff --git a/adapter/sqs.go b/adapter/sqs.go
index 6b30085bc..1a3a06fdf 100644
--- a/adapter/sqs.go
+++ b/adapter/sqs.go
@@ -97,12 +97,12 @@ func NewSQSServer(listen net.Listener, st store.MVCCStore, coordinate kv.Coordin
 		sqsGetQueueAttributesTarget:        s.getQueueAttributes,
 		sqsSetQueueAttributesTarget:        s.setQueueAttributes,
 		sqsPurgeQueueTarget:                s.notImplemented("PurgeQueue"),
-		sqsSendMessageTarget:               s.notImplemented("SendMessage"),
+		sqsSendMessageTarget:               s.sendMessage,
 		sqsSendMessageBatchTarget:          s.notImplemented("SendMessageBatch"),
-		sqsReceiveMessageTarget:            s.notImplemented("ReceiveMessage"),
-		sqsDeleteMessageTarget:             s.notImplemented("DeleteMessage"),
+		sqsReceiveMessageTarget:            s.receiveMessage,
+		sqsDeleteMessageTarget:             s.deleteMessage,
 		sqsDeleteMessageBatchTarget:        s.notImplemented("DeleteMessageBatch"),
-		sqsChangeMessageVisibilityTarget:   s.notImplemented("ChangeMessageVisibility"),
+		sqsChangeMessageVisibilityTarget:   s.changeMessageVisibility,
 		sqsChangeMessageVisibilityBatchTgt: s.notImplemented("ChangeMessageVisibilityBatch"),
 		sqsTagQueueTarget:                  s.notImplemented("TagQueue"),
 		sqsUntagQueueTarget:                s.notImplemented("UntagQueue"),
diff --git a/adapter/sqs_messages.go b/adapter/sqs_messages.go
new file mode 100644
index 000000000..1c7682eb9
--- /dev/null
+++ b/adapter/sqs_messages.go
@@ -0,0 +1,704 @@
+package adapter
+
+import (
+	"bytes"
+	"context"
+	"crypto/md5" //nolint:gosec // The SQS API specifies MD5 checksums (MD5OfMessageBody); not used as a cryptographic primitive.
+	"crypto/rand"
+	"encoding/base64"
+	"encoding/binary"
+	"encoding/hex"
+	"net/http"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/bootjp/elastickv/kv"
+	"github.com/bootjp/elastickv/store"
+	"github.com/cockroachdb/errors"
+	json "github.com/goccy/go-json"
+)
+
+// Message-keyspace prefixes. The data record holds the message body and
+// state; the visibility index is a separate, visible_at-sorted key family
+// so ReceiveMessage can find the next visible message with a single bounded
+// prefix scan.
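+//
+// Conceptual layout of the two families (segments are run through
+// encodeSQSSegment, <gen> and <visible_at> are big-endian uint64s; the
+// "|" separators between segments below are only illustrative):
+//
+//	!sqs|msg|data|<queue>|<gen>|<message_id>             -> sqsMessageRecord
+//	!sqs|msg|vis|<queue>|<gen>|<visible_at>|<message_id> -> message_id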
+const ( + SqsMsgDataPrefix = "!sqs|msg|data|" + SqsMsgVisPrefix = "!sqs|msg|vis|" +) + +const ( + sqsMessageIDBytes = 16 + sqsReceiptTokenBytes = 16 + sqsReceiveDefaultMaxMessages = 1 + sqsReceiveHardMaxMessages = 10 + sqsReceiveScanOverfetchFactor = 2 + sqsChangeVisibilityMaxSeconds = sqsMaxVisibilityTimeoutSeconds + sqsVisScanPageLimit = 1024 + // Version byte prefixed to encoded receipt handles. Bumped when the + // on-wire handle format changes so old handles fail to decode loudly. + sqsReceiptHandleVersion = byte(0x01) + // Byte sizes used when pre-sizing key buffers. The exact value is not + // critical; it only avoids one append growth for typical queue/ID + // lengths. + sqsKeyCapSmall = 32 + sqsKeyCapLarge = 64 + // Conversion factors for SQS second-granularity inputs. + sqsMillisPerSecond = 1000 +) + +// AWS error codes specific to message operations. +const ( + sqsErrReceiptHandleInvalid = "ReceiptHandleIsInvalid" + sqsErrInvalidReceiptHandle = "InvalidReceiptHandle" + sqsErrMessageTooLong = "InvalidParameterValue" + sqsErrMessageNotInflight = "MessageNotInflight" +) + +// sqsMessageRecord mirrors !sqs|msg|data|... on disk. Visibility state +// (VisibleAtMillis, CurrentReceiptToken, ReceiveCount) lives here rather +// than in a side-record so a single OCC transaction can rotate it. +type sqsMessageRecord struct { + MessageID string `json:"message_id"` + Body []byte `json:"body"` + MD5OfBody string `json:"md5_of_body"` + MessageAttributes map[string]string `json:"message_attributes,omitempty"` + SenderID string `json:"sender_id,omitempty"` + SendTimestampMillis int64 `json:"send_timestamp_millis"` + AvailableAtMillis int64 `json:"available_at_millis"` + VisibleAtMillis int64 `json:"visible_at_millis"` + ReceiveCount int64 `json:"receive_count"` + FirstReceiveMillis int64 `json:"first_receive_millis,omitempty"` + CurrentReceiptToken []byte `json:"current_receipt_token"` + QueueGeneration uint64 `json:"queue_generation"` +} + +var storedSQSMsgPrefix = []byte{0x00, 'S', 'M', 0x01} + +func encodeSQSMessageRecord(m *sqsMessageRecord) ([]byte, error) { + body, err := json.Marshal(m) + if err != nil { + return nil, errors.WithStack(err) + } + out := make([]byte, 0, len(storedSQSMsgPrefix)+len(body)) + out = append(out, storedSQSMsgPrefix...) + out = append(out, body...) + return out, nil +} + +func decodeSQSMessageRecord(b []byte) (*sqsMessageRecord, error) { + if !bytes.HasPrefix(b, storedSQSMsgPrefix) { + return nil, errors.New("unrecognized sqs message format") + } + var m sqsMessageRecord + if err := json.Unmarshal(b[len(storedSQSMsgPrefix):], &m); err != nil { + return nil, errors.WithStack(err) + } + return &m, nil +} + +// ------------------------ key helpers ------------------------ + +func sqsMsgDataKey(queueName string, gen uint64, messageID string) []byte { + buf := make([]byte, 0, len(SqsMsgDataPrefix)+sqsKeyCapLarge) + buf = append(buf, SqsMsgDataPrefix...) + buf = append(buf, encodeSQSSegment(queueName)...) + buf = appendU64(buf, gen) + buf = append(buf, encodeSQSSegment(messageID)...) + return buf +} + +func sqsMsgVisKey(queueName string, gen uint64, visibleAtMillis int64, messageID string) []byte { + buf := make([]byte, 0, len(SqsMsgVisPrefix)+sqsKeyCapLarge) + buf = append(buf, SqsMsgVisPrefix...) + buf = append(buf, encodeSQSSegment(queueName)...) + buf = appendU64(buf, gen) + buf = appendU64(buf, uint64MaxZero(visibleAtMillis)) + buf = append(buf, encodeSQSSegment(messageID)...) 
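+	// The big-endian visible_at segment makes lexicographic key order match
+	// chronological order within one queue generation, which is what lets
+	// ReceiveMessage scan "everything visible by now" as a single key range.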
+ return buf +} + +func sqsMsgVisPrefixForQueue(queueName string, gen uint64) []byte { + buf := make([]byte, 0, len(SqsMsgVisPrefix)+sqsKeyCapSmall) + buf = append(buf, SqsMsgVisPrefix...) + buf = append(buf, encodeSQSSegment(queueName)...) + buf = appendU64(buf, gen) + return buf +} + +// uint64MaxZero clamps negative int64 (which never happens for wall-clock +// timestamps but would silently overflow under uint64() cast) to zero. +func uint64MaxZero(v int64) uint64 { + if v < 0 { + return 0 + } + return uint64(v) +} + +func sqsMsgVisScanBounds(queueName string, gen uint64, maxVisibleAtMillis int64) (start, end []byte) { + prefix := sqsMsgVisPrefixForQueue(queueName, gen) + start = append(bytes.Clone(prefix), zeroU64()...) + upper := uint64MaxZero(maxVisibleAtMillis) + if upper < ^uint64(0) { + upper++ + } + end = append(bytes.Clone(prefix), encodedU64(upper)...) + return start, end +} + +func appendU64(dst []byte, v uint64) []byte { + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], v) + return append(dst, buf[:]...) +} + +func encodedU64(v uint64) []byte { + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], v) + return buf[:] +} + +func zeroU64() []byte { + var buf [8]byte + return buf[:] +} + +// ------------------------ message id + receipt handle ------------------------ + +func newMessageIDHex() (string, error) { + var buf [sqsMessageIDBytes]byte + if _, err := rand.Read(buf[:]); err != nil { + return "", errors.WithStack(err) + } + return hex.EncodeToString(buf[:]), nil +} + +func newReceiptToken() ([]byte, error) { + buf := make([]byte, sqsReceiptTokenBytes) + if _, err := rand.Read(buf); err != nil { + return nil, errors.WithStack(err) + } + return buf, nil +} + +// encodeReceiptHandle packs (queue_gen, message_id, receipt_token) into a +// single opaque blob. Format: +// +// [ 0 ] byte version = 0x01 +// [ 1..9 ] uint64 queue_gen (BE) +// [ 9..25 ] 16 bytes message_id (raw bytes from hex decode) +// [ 25..41 ] 16 bytes receipt_token +// +// The result is base64-urlsafe (no padding) so it passes through JSON and +// HTTP query parameters untouched. +func encodeReceiptHandle(queueGen uint64, messageIDHex string, receiptToken []byte) (string, error) { + if len(receiptToken) != sqsReceiptTokenBytes { + return "", errors.New("receipt token has wrong length") + } + idBytes, err := hex.DecodeString(messageIDHex) + if err != nil || len(idBytes) != sqsMessageIDBytes { + return "", errors.New("message id has wrong format") + } + buf := make([]byte, 0, 1+8+sqsMessageIDBytes+sqsReceiptTokenBytes) + buf = append(buf, sqsReceiptHandleVersion) + buf = appendU64(buf, queueGen) + buf = append(buf, idBytes...) + buf = append(buf, receiptToken...) 
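+	// 41 bytes total (1 + 8 + 16 + 16) before encoding; RawURLEncoding keeps
+	// the handle free of '+', '/', and '=' so it survives JSON and URLs.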
+ return base64.RawURLEncoding.EncodeToString(buf), nil +} + +type decodedReceiptHandle struct { + QueueGeneration uint64 + MessageIDHex string + ReceiptToken []byte +} + +func decodeReceiptHandle(raw string) (*decodedReceiptHandle, error) { + b, err := base64.RawURLEncoding.DecodeString(raw) + if err != nil { + return nil, errors.WithStack(err) + } + want := 1 + 8 + sqsMessageIDBytes + sqsReceiptTokenBytes + if len(b) != want || b[0] != sqsReceiptHandleVersion { + return nil, errors.New("receipt handle length or version mismatch") + } + out := &decodedReceiptHandle{ + QueueGeneration: binary.BigEndian.Uint64(b[1:9]), + MessageIDHex: hex.EncodeToString(b[9 : 9+sqsMessageIDBytes]), + ReceiptToken: bytes.Clone(b[9+sqsMessageIDBytes:]), + } + return out, nil +} + +// ------------------------ input decoding ------------------------ + +type sqsSendMessageInput struct { + QueueUrl string `json:"QueueUrl"` + MessageBody string `json:"MessageBody"` + DelaySeconds *int64 `json:"DelaySeconds,omitempty"` + MessageAttributes map[string]string `json:"MessageAttributes,omitempty"` +} + +type sqsReceiveMessageInput struct { + QueueUrl string `json:"QueueUrl"` + MaxNumberOfMessages int `json:"MaxNumberOfMessages,omitempty"` + VisibilityTimeout *int64 `json:"VisibilityTimeout,omitempty"` + WaitTimeSeconds *int64 `json:"WaitTimeSeconds,omitempty"` +} + +type sqsDeleteMessageInput struct { + QueueUrl string `json:"QueueUrl"` + ReceiptHandle string `json:"ReceiptHandle"` +} + +type sqsChangeVisibilityInput struct { + QueueUrl string `json:"QueueUrl"` + ReceiptHandle string `json:"ReceiptHandle"` + VisibilityTimeout int64 `json:"VisibilityTimeout"` +} + +// ------------------------ handlers ------------------------ + +func (s *SQSServer) sendMessage(w http.ResponseWriter, r *http.Request) { + var in sqsSendMessageInput + if err := decodeSQSJSONInput(r, &in); err != nil { + writeSQSErrorFromErr(w, err) + return + } + queueName, err := queueNameFromURL(in.QueueUrl) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + meta, apiErr := s.loadQueueMetaForSend(r.Context(), queueName, []byte(in.MessageBody)) + if apiErr != nil { + writeSQSErrorFromErr(w, apiErr) + return + } + delay, apiErr := resolveSendDelay(meta, in.DelaySeconds) + if apiErr != nil { + writeSQSErrorFromErr(w, apiErr) + return + } + rec, recordBytes, apiErr := buildSendRecord(meta, in, delay) + if apiErr != nil { + writeSQSErrorFromErr(w, apiErr) + return + } + + dataKey := sqsMsgDataKey(queueName, meta.Generation, rec.MessageID) + visKey := sqsMsgVisKey(queueName, meta.Generation, rec.AvailableAtMillis, rec.MessageID) + req := &kv.OperationGroup[kv.OP]{ + IsTxn: true, + Elems: []*kv.Elem[kv.OP]{ + {Op: kv.Put, Key: dataKey, Value: recordBytes}, + {Op: kv.Put, Key: visKey, Value: []byte(rec.MessageID)}, + }, + } + if _, err := s.coordinator.Dispatch(r.Context(), req); err != nil { + writeSQSErrorFromErr(w, err) + return + } + + writeSQSJSON(w, map[string]string{ + "MessageId": rec.MessageID, + "MD5OfMessageBody": rec.MD5OfBody, + "MD5OfMessageAttributes": md5OfAttributesHex(in.MessageAttributes), + }) +} + +func (s *SQSServer) loadQueueMetaForSend(ctx context.Context, queueName string, body []byte) (*sqsQueueMeta, error) { + readTS := s.nextTxnReadTS(ctx) + meta, exists, err := s.loadQueueMetaAt(ctx, queueName, readTS) + if err != nil { + return nil, errors.WithStack(err) + } + if !exists { + return nil, newSQSAPIError(http.StatusBadRequest, sqsErrQueueDoesNotExist, "queue does not exist") + } + if int64(len(body)) > 
meta.MaximumMessageSize { + return nil, newSQSAPIError(http.StatusBadRequest, sqsErrMessageTooLong, "message body exceeds MaximumMessageSize") + } + return meta, nil +} + +func resolveSendDelay(meta *sqsQueueMeta, requested *int64) (int64, error) { + delay := meta.DelaySeconds + if requested == nil { + return delay, nil + } + if *requested < 0 || *requested > sqsMaxDelaySeconds { + return 0, newSQSAPIError(http.StatusBadRequest, sqsErrInvalidAttributeValue, "DelaySeconds out of range") + } + return *requested, nil +} + +func buildSendRecord(meta *sqsQueueMeta, in sqsSendMessageInput, delay int64) (*sqsMessageRecord, []byte, error) { + messageID, err := newMessageIDHex() + if err != nil { + return nil, nil, errors.WithStack(err) + } + token, err := newReceiptToken() + if err != nil { + return nil, nil, errors.WithStack(err) + } + now := time.Now().UnixMilli() + availableAt := now + delay*sqsMillisPerSecond + body := []byte(in.MessageBody) + rec := &sqsMessageRecord{ + MessageID: messageID, + Body: body, + MD5OfBody: sqsMD5Hex(body), + MessageAttributes: in.MessageAttributes, + SendTimestampMillis: now, + AvailableAtMillis: availableAt, + VisibleAtMillis: availableAt, + CurrentReceiptToken: token, + QueueGeneration: meta.Generation, + } + recordBytes, err := encodeSQSMessageRecord(rec) + if err != nil { + return nil, nil, errors.WithStack(err) + } + return rec, recordBytes, nil +} + +//nolint:cyclop // AWS ReceiveMessage branches on per-message eligibility; splitting further just moves the branching around. +func (s *SQSServer) receiveMessage(w http.ResponseWriter, r *http.Request) { + var in sqsReceiveMessageInput + if err := decodeSQSJSONInput(r, &in); err != nil { + writeSQSErrorFromErr(w, err) + return + } + queueName, err := queueNameFromURL(in.QueueUrl) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + ctx := r.Context() + + // Use LeaseRead to fence this scan against a leader that silently lost + // quorum mid-request. When the lease is warm this is a local + // wall-clock compare; when it is cold it falls back to a full + // LinearizableRead. 
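+	// (Without the fence, a deposed leader could keep serving its stale
+	// snapshot and hand out receipt handles for messages the new leader
+	// has already redelivered.)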
+ if _, err := kv.LeaseReadThrough(s.coordinator, ctx); err != nil { + writeSQSErrorFromErr(w, err) + return + } + + readTS := s.nextTxnReadTS(ctx) + meta, exists, err := s.loadQueueMetaAt(ctx, queueName, readTS) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + if !exists { + writeSQSError(w, http.StatusBadRequest, sqsErrQueueDoesNotExist, "queue does not exist") + return + } + max := clampReceiveMaxMessages(in.MaxNumberOfMessages) + visibilityTimeout := meta.VisibilityTimeoutSeconds + if in.VisibilityTimeout != nil { + if *in.VisibilityTimeout < 0 || *in.VisibilityTimeout > sqsChangeVisibilityMaxSeconds { + writeSQSError(w, http.StatusBadRequest, sqsErrInvalidAttributeValue, "VisibilityTimeout out of range") + return + } + visibilityTimeout = *in.VisibilityTimeout + } + + candidates, err := s.scanVisibleMessageCandidates(ctx, queueName, meta.Generation, max*sqsReceiveScanOverfetchFactor, readTS) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + delivered := s.rotateMessagesForDelivery(ctx, queueName, meta.Generation, candidates, visibilityTimeout, max, readTS) + writeSQSJSON(w, map[string]any{"Messages": delivered}) +} + +func clampReceiveMaxMessages(requested int) int { + if requested <= 0 { + return sqsReceiveDefaultMaxMessages + } + if requested > sqsReceiveHardMaxMessages { + return sqsReceiveHardMaxMessages + } + return requested +} + +// scanVisibleMessageCandidates returns vis-index entries with +// visible_at <= now, up to limit. Each entry carries the key (needed +// for the delete-old-vis step) and the message_id pointed at by its +// value. +type sqsMsgCandidate struct { + visKey []byte + messageID string +} + +func (s *SQSServer) scanVisibleMessageCandidates(ctx context.Context, queueName string, gen uint64, limit int, readTS uint64) ([]sqsMsgCandidate, error) { + if limit <= 0 { + return nil, nil + } + now := time.Now().UnixMilli() + start, end := sqsMsgVisScanBounds(queueName, gen, now) + page := limit + if page > sqsVisScanPageLimit { + page = sqsVisScanPageLimit + } + kvs, err := s.store.ScanAt(ctx, start, end, page, readTS) + if err != nil { + return nil, errors.WithStack(err) + } + out := make([]sqsMsgCandidate, 0, len(kvs)) + for _, kvp := range kvs { + out = append(out, sqsMsgCandidate{visKey: bytes.Clone(kvp.Key), messageID: string(kvp.Value)}) + } + return out, nil +} + +// rotateMessagesForDelivery runs an OCC transaction per candidate to +// rotate its visibility entry + receipt token. Failures on individual +// messages (races, write-conflict) are skipped rather than aborting the +// whole batch — AWS semantics allow ReceiveMessage to return fewer +// messages than requested. 
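+// An empty Messages array is therefore a valid response even when the
+// scan found candidates: every one of them may have been won by a
+// concurrent consumer between the snapshot scan and the per-message
+// commit.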
+func (s *SQSServer) rotateMessagesForDelivery( + ctx context.Context, + queueName string, + gen uint64, + candidates []sqsMsgCandidate, + visibilityTimeout int64, + max int, + readTS uint64, +) []map[string]any { + delivered := make([]map[string]any, 0, max) + for _, cand := range candidates { + if len(delivered) >= max { + break + } + msg, ok := s.tryDeliverCandidate(ctx, queueName, gen, cand, visibilityTimeout, readTS) + if !ok { + continue + } + delivered = append(delivered, msg) + } + return delivered +} + +func (s *SQSServer) tryDeliverCandidate( + ctx context.Context, + queueName string, + gen uint64, + cand sqsMsgCandidate, + visibilityTimeout int64, + readTS uint64, +) (map[string]any, bool) { + dataKey := sqsMsgDataKey(queueName, gen, cand.messageID) + raw, err := s.store.GetAt(ctx, dataKey, readTS) + if err != nil { + return nil, false + } + rec, err := decodeSQSMessageRecord(raw) + if err != nil { + return nil, false + } + + newToken, err := newReceiptToken() + if err != nil { + return nil, false + } + now := time.Now().UnixMilli() + newVisibleAt := now + visibilityTimeout*sqsMillisPerSecond + rec.VisibleAtMillis = newVisibleAt + rec.CurrentReceiptToken = newToken + rec.ReceiveCount++ + if rec.FirstReceiveMillis == 0 { + rec.FirstReceiveMillis = now + } + recordBytes, err := encodeSQSMessageRecord(rec) + if err != nil { + return nil, false + } + newVisKey := sqsMsgVisKey(queueName, gen, newVisibleAt, cand.messageID) + req := &kv.OperationGroup[kv.OP]{ + IsTxn: true, + Elems: []*kv.Elem[kv.OP]{ + {Op: kv.Del, Key: cand.visKey}, + {Op: kv.Put, Key: newVisKey, Value: []byte(cand.messageID)}, + {Op: kv.Put, Key: dataKey, Value: recordBytes}, + }, + } + if _, err := s.coordinator.Dispatch(ctx, req); err != nil { + return nil, false + } + + handle, err := encodeReceiptHandle(gen, cand.messageID, newToken) + if err != nil { + return nil, false + } + return map[string]any{ + "MessageId": cand.messageID, + "ReceiptHandle": handle, + "Body": string(rec.Body), + "MD5OfBody": rec.MD5OfBody, + "Attributes": map[string]string{ + "ApproximateReceiveCount": strconv.FormatInt(rec.ReceiveCount, 10), + "SentTimestamp": strconv.FormatInt(rec.SendTimestampMillis, 10), + "ApproximateFirstReceiveTimestamp": strconv.FormatInt(rec.FirstReceiveMillis, 10), + }, + }, true +} + +func (s *SQSServer) deleteMessage(w http.ResponseWriter, r *http.Request) { + var in sqsDeleteMessageInput + if err := decodeSQSJSONInput(r, &in); err != nil { + writeSQSErrorFromErr(w, err) + return + } + queueName, handle, err := s.parseQueueAndReceipt(in.QueueUrl, in.ReceiptHandle) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + rec, dataKey, apiErr := s.loadAndVerifyMessage(r.Context(), queueName, handle) + if apiErr != nil { + writeSQSErrorFromErr(w, apiErr) + return + } + visKey := sqsMsgVisKey(queueName, handle.QueueGeneration, rec.VisibleAtMillis, rec.MessageID) + req := &kv.OperationGroup[kv.OP]{ + IsTxn: true, + Elems: []*kv.Elem[kv.OP]{ + {Op: kv.Del, Key: dataKey}, + {Op: kv.Del, Key: visKey}, + }, + } + if _, err := s.coordinator.Dispatch(r.Context(), req); err != nil { + writeSQSErrorFromErr(w, err) + return + } + writeSQSJSON(w, map[string]any{}) +} + +func (s *SQSServer) changeMessageVisibility(w http.ResponseWriter, r *http.Request) { + var in sqsChangeVisibilityInput + if err := decodeSQSJSONInput(r, &in); err != nil { + writeSQSErrorFromErr(w, err) + return + } + if in.VisibilityTimeout < 0 || in.VisibilityTimeout > sqsChangeVisibilityMaxSeconds { + writeSQSError(w, http.StatusBadRequest, 
sqsErrInvalidAttributeValue, "VisibilityTimeout out of range")
+		return
+	}
+	queueName, handle, err := s.parseQueueAndReceipt(in.QueueUrl, in.ReceiptHandle)
+	if err != nil {
+		writeSQSErrorFromErr(w, err)
+		return
+	}
+	rec, dataKey, apiErr := s.loadAndVerifyMessage(r.Context(), queueName, handle)
+	if apiErr != nil {
+		writeSQSErrorFromErr(w, apiErr)
+		return
+	}
+	now := time.Now().UnixMilli()
+	if rec.VisibleAtMillis <= now {
+		writeSQSError(w, http.StatusBadRequest, sqsErrMessageNotInflight, "message is not currently in flight")
+		return
+	}
+
+	oldVisKey := sqsMsgVisKey(queueName, handle.QueueGeneration, rec.VisibleAtMillis, rec.MessageID)
+	rec.VisibleAtMillis = now + in.VisibilityTimeout*sqsMillisPerSecond
+	recordBytes, err := encodeSQSMessageRecord(rec)
+	if err != nil {
+		writeSQSErrorFromErr(w, err)
+		return
+	}
+	newVisKey := sqsMsgVisKey(queueName, handle.QueueGeneration, rec.VisibleAtMillis, rec.MessageID)
+	req := &kv.OperationGroup[kv.OP]{
+		IsTxn: true,
+		Elems: []*kv.Elem[kv.OP]{
+			{Op: kv.Del, Key: oldVisKey},
+			{Op: kv.Put, Key: newVisKey, Value: []byte(rec.MessageID)},
+			{Op: kv.Put, Key: dataKey, Value: recordBytes},
+		},
+	}
+	if _, err := s.coordinator.Dispatch(r.Context(), req); err != nil {
+		writeSQSErrorFromErr(w, err)
+		return
+	}
+	writeSQSJSON(w, map[string]any{})
+}
+
+// parseQueueAndReceipt extracts the queue name and decodes the receipt
+// handle from a DeleteMessage / ChangeMessageVisibility input.
+func (s *SQSServer) parseQueueAndReceipt(queueUrl, receiptHandle string) (string, *decodedReceiptHandle, error) {
+	queueName, err := queueNameFromURL(queueUrl)
+	if err != nil {
+		return "", nil, err
+	}
+	handle, err := decodeReceiptHandle(receiptHandle)
+	if err != nil {
+		return "", nil, newSQSAPIError(http.StatusBadRequest, sqsErrReceiptHandleInvalid, "receipt handle is not parseable")
+	}
+	return queueName, handle, nil
+}
+
+// loadAndVerifyMessage reads the data record for the given handle and
+// verifies that the receipt token matches the current one on record.
+// Returns the record, its key, or a typed SQS error.
+func (s *SQSServer) loadAndVerifyMessage(ctx context.Context, queueName string, handle *decodedReceiptHandle) (*sqsMessageRecord, []byte, error) {
+	readTS := s.nextTxnReadTS(ctx)
+	dataKey := sqsMsgDataKey(queueName, handle.QueueGeneration, handle.MessageIDHex)
+	raw, err := s.store.GetAt(ctx, dataKey, readTS)
+	if err != nil {
+		if errors.Is(err, store.ErrKeyNotFound) {
+			return nil, nil, newSQSAPIError(http.StatusBadRequest, sqsErrReceiptHandleInvalid, "message not found")
+		}
+		return nil, nil, errors.WithStack(err)
+	}
+	rec, err := decodeSQSMessageRecord(raw)
+	if err != nil {
+		return nil, nil, errors.WithStack(err)
+	}
+	if !bytes.Equal(rec.CurrentReceiptToken, handle.ReceiptToken) {
+		return nil, nil, newSQSAPIError(http.StatusBadRequest, sqsErrInvalidReceiptHandle, "receipt handle token does not match")
+	}
+	return rec, dataKey, nil
+}
+
+// ------------------------ small helpers ------------------------
+
+func sqsMD5Hex(body []byte) string {
+	sum := md5.Sum(body) //nolint:gosec // AWS-specified MD5OfMessageBody checksum, not a crypto primitive.
+	return hex.EncodeToString(sum[:])
+}
+
+// md5OfAttributesHex computes AWS's MD5 of a MessageAttributes map. The
+// real AWS format canonicalizes names and types; this adapter only
+// returns "" on an empty map and a simple concatenated hash otherwise
+// (full canonicalization lives in a follow-up PR along with typed
+// attribute values).
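+//
+// For example, {"b": "2", "a": "1"} hashes the bytes "a=1;b=2;": keys are
+// sorted first so the digest is stable across Go's randomized map
+// iteration order.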
+func md5OfAttributesHex(attrs map[string]string) string { + if len(attrs) == 0 { + return "" + } + keys := make([]string, 0, len(attrs)) + for k := range attrs { + keys = append(keys, k) + } + // Stable order for determinism. + for i := 0; i < len(keys); i++ { + for j := i + 1; j < len(keys); j++ { + if keys[j] < keys[i] { + keys[i], keys[j] = keys[j], keys[i] + } + } + } + var b strings.Builder + for _, k := range keys { + b.WriteString(k) + b.WriteString("=") + b.WriteString(attrs[k]) + b.WriteString(";") + } + return sqsMD5Hex([]byte(b.String())) +} diff --git a/adapter/sqs_messages_test.go b/adapter/sqs_messages_test.go new file mode 100644 index 000000000..688cda999 --- /dev/null +++ b/adapter/sqs_messages_test.go @@ -0,0 +1,388 @@ +package adapter + +import ( + "encoding/hex" + "net/http" + "strconv" + "testing" + "time" +) + +// createSQSQueueForTest is a small helper so every message-path test does +// not have to repeat the "createQueue -> pull URL" dance. +func createSQSQueueForTest(t *testing.T, node Node, name string) string { + t.Helper() + status, out := callSQS(t, node, sqsCreateQueueTarget, map[string]any{"QueueName": name}) + if status != http.StatusOK { + t.Fatalf("createQueue %q: %d %v", name, status, out) + } + url, _ := out["QueueUrl"].(string) + if url == "" { + t.Fatalf("createQueue %q: empty URL", name) + } + return url +} + +func TestSQSServer_SendReceiveDelete(t *testing.T) { + t.Parallel() + nodes, _, _ := createNode(t, 1) + defer shutdown(nodes) + node := sqsLeaderNode(t, nodes) + queueURL := createSQSQueueForTest(t, node, "work") + + msgID := sendOneMessage(t, node, queueURL, "hello") + receipt := receiveOneMessage(t, node, queueURL, msgID, "hello") + expectNoMessagesVisible(t, node, queueURL, "after-receive") + deleteMessageOK(t, node, queueURL, receipt) + expectNoMessagesVisible(t, node, queueURL, "after-delete") +} + +// sendOneMessage sends a single message and returns the assigned MessageId. +// It asserts the response carries an MD5 over the body that matches what +// sqsMD5Hex would compute locally. 
+func sendOneMessage(t *testing.T, node Node, queueURL, body string) string { + t.Helper() + status, out := callSQS(t, node, sqsSendMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "MessageBody": body, + }) + if status != http.StatusOK { + t.Fatalf("send: %d %v", status, out) + } + msgID, _ := out["MessageId"].(string) + if msgID == "" { + t.Fatalf("no MessageId in %v", out) + } + if out["MD5OfMessageBody"] != sqsMD5Hex([]byte(body)) { + t.Fatalf("md5 mismatch: got %v want %q", out["MD5OfMessageBody"], sqsMD5Hex([]byte(body))) + } + return msgID +} + +func receiveOneMessage(t *testing.T, node Node, queueURL, wantID, wantBody string) string { + t.Helper() + status, out := callSQS(t, node, sqsReceiveMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "MaxNumberOfMessages": 1, + "VisibilityTimeout": 60, + }) + if status != http.StatusOK { + t.Fatalf("receive: %d %v", status, out) + } + msgs, _ := out["Messages"].([]any) + if len(msgs) != 1 { + t.Fatalf("got %d messages, want 1 (%v)", len(msgs), out) + } + m, _ := msgs[0].(map[string]any) + if m["Body"] != wantBody { + t.Fatalf("Body=%v want %q", m["Body"], wantBody) + } + if m["MessageId"] != wantID { + t.Fatalf("MessageId=%v want %q", m["MessageId"], wantID) + } + receipt, _ := m["ReceiptHandle"].(string) + if receipt == "" { + t.Fatalf("no ReceiptHandle in %v", m) + } + return receipt +} + +func expectNoMessagesVisible(t *testing.T, node Node, queueURL, tag string) { + t.Helper() + status, out := callSQS(t, node, sqsReceiveMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "MaxNumberOfMessages": 1, + }) + if status != http.StatusOK { + t.Fatalf("%s receive: %d %v", tag, status, out) + } + if msgs, _ := out["Messages"].([]any); len(msgs) != 0 { + t.Fatalf("%s: got %d messages, want 0", tag, len(msgs)) + } +} + +func deleteMessageOK(t *testing.T, node Node, queueURL, receipt string) { + t.Helper() + status, out := callSQS(t, node, sqsDeleteMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "ReceiptHandle": receipt, + }) + if status != http.StatusOK { + t.Fatalf("delete: %d %v", status, out) + } +} + +func TestSQSServer_DeleteWithWrongReceiptRejected(t *testing.T) { + t.Parallel() + nodes, _, _ := createNode(t, 1) + defer shutdown(nodes) + node := sqsLeaderNode(t, nodes) + queueURL := createSQSQueueForTest(t, node, "wrong-receipt") + + // Send + receive once to get a real message on the queue. + _, _ = callSQS(t, node, sqsSendMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "MessageBody": "x", + }) + status, out := callSQS(t, node, sqsReceiveMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "MaxNumberOfMessages": 1, + "VisibilityTimeout": 60, + }) + if status != http.StatusOK { + t.Fatalf("receive: %d %v", status, out) + } + msgs, _ := out["Messages"].([]any) + if len(msgs) == 0 { + t.Fatalf("no messages received") + } + goodHandle, _ := msgs[0].(map[string]any)["ReceiptHandle"].(string) + + // Tamper the receipt token portion of the handle and expect + // InvalidReceiptHandle. 
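+	// Flipping one token byte keeps the handle structurally valid (length
+	// and version still match), so decode succeeds and only the stored-token
+	// comparison fails.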
+ decoded, err := decodeReceiptHandle(goodHandle) + if err != nil { + t.Fatalf("decode: %v", err) + } + decoded.ReceiptToken[0] ^= 0xff + badHandle, err := encodeReceiptHandle(decoded.QueueGeneration, decoded.MessageIDHex, decoded.ReceiptToken) + if err != nil { + t.Fatalf("encode: %v", err) + } + status, out = callSQS(t, node, sqsDeleteMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "ReceiptHandle": badHandle, + }) + if status != http.StatusBadRequest { + t.Fatalf("delete with bad receipt: status=%d body=%v", status, out) + } + if out["__type"] != sqsErrInvalidReceiptHandle { + t.Fatalf("error type: %q want %q", out["__type"], sqsErrInvalidReceiptHandle) + } +} + +func TestSQSServer_ReceiveBatchRespectsMaxMessages(t *testing.T) { + t.Parallel() + nodes, _, _ := createNode(t, 1) + defer shutdown(nodes) + node := sqsLeaderNode(t, nodes) + queueURL := createSQSQueueForTest(t, node, "batch") + + const sent = 5 + for i := 0; i < sent; i++ { + _, _ = callSQS(t, node, sqsSendMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "MessageBody": "m-" + strconv.Itoa(i), + }) + } + + status, out := callSQS(t, node, sqsReceiveMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "MaxNumberOfMessages": 3, + "VisibilityTimeout": 60, + }) + if status != http.StatusOK { + t.Fatalf("receive: %d %v", status, out) + } + msgs, _ := out["Messages"].([]any) + if len(msgs) != 3 { + t.Fatalf("got %d messages, want 3 (%v)", len(msgs), out) + } + + // A second receive picks up the remaining two. + status, out = callSQS(t, node, sqsReceiveMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "MaxNumberOfMessages": 10, + "VisibilityTimeout": 60, + }) + if status != http.StatusOK { + t.Fatalf("receive#2: %d %v", status, out) + } + msgs, _ = out["Messages"].([]any) + if len(msgs) != 2 { + t.Fatalf("got %d messages on second receive, want 2", len(msgs)) + } +} + +func TestSQSServer_DelaySecondsDefersDelivery(t *testing.T) { + t.Parallel() + nodes, _, _ := createNode(t, 1) + defer shutdown(nodes) + node := sqsLeaderNode(t, nodes) + queueURL := createSQSQueueForTest(t, node, "delayed") + + _, _ = callSQS(t, node, sqsSendMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "MessageBody": "later", + "DelaySeconds": 2, + }) + + // Immediate receive must return nothing. + status, out := callSQS(t, node, sqsReceiveMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "MaxNumberOfMessages": 1, + }) + if status != http.StatusOK { + t.Fatalf("receive: %d %v", status, out) + } + if msgs, _ := out["Messages"].([]any); len(msgs) != 0 { + t.Fatalf("expected 0 messages before delay elapsed, got %d", len(msgs)) + } + + // After the delay, the message becomes visible. 
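+	// (2100ms leaves 100ms of slack over the 2s DelaySeconds so the test
+	// does not flake on a slow scheduler.)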
+ time.Sleep(2100 * time.Millisecond) + status, out = callSQS(t, node, sqsReceiveMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "MaxNumberOfMessages": 1, + "VisibilityTimeout": 60, + }) + if status != http.StatusOK { + t.Fatalf("receive: %d %v", status, out) + } + msgs, _ := out["Messages"].([]any) + if len(msgs) != 1 { + t.Fatalf("expected 1 message after delay, got %d (%v)", len(msgs), out) + } +} + +func TestSQSServer_VisibilityTimeoutExpiryMakesMessageVisibleAgain(t *testing.T) { + t.Parallel() + nodes, _, _ := createNode(t, 1) + defer shutdown(nodes) + node := sqsLeaderNode(t, nodes) + queueURL := createSQSQueueForTest(t, node, "revisible") + + _, _ = callSQS(t, node, sqsSendMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "MessageBody": "retry-me", + }) + + // First receive with a very short visibility so this test doesn't + // sit idle for 30 seconds. + status, out := callSQS(t, node, sqsReceiveMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "MaxNumberOfMessages": 1, + "VisibilityTimeout": 1, + }) + if status != http.StatusOK { + t.Fatalf("receive: %d %v", status, out) + } + msgs, _ := out["Messages"].([]any) + if len(msgs) != 1 { + t.Fatalf("first receive: got %d messages want 1", len(msgs)) + } + first, _ := msgs[0].(map[string]any) + firstAttrs, _ := first["Attributes"].(map[string]any) + if firstAttrs["ApproximateReceiveCount"] != "1" { + t.Fatalf("first receive count = %v, want 1", firstAttrs) + } + + time.Sleep(1200 * time.Millisecond) + + // Second receive should see the same message with ReceiveCount=2. + status, out = callSQS(t, node, sqsReceiveMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "MaxNumberOfMessages": 1, + "VisibilityTimeout": 60, + }) + if status != http.StatusOK { + t.Fatalf("second receive: %d %v", status, out) + } + msgs, _ = out["Messages"].([]any) + if len(msgs) != 1 { + t.Fatalf("second receive: got %d messages want 1", len(msgs)) + } + second, _ := msgs[0].(map[string]any) + if first["MessageId"] != second["MessageId"] { + t.Fatalf("second receive returned a different message: %v vs %v", first, second) + } + secondAttrs, _ := second["Attributes"].(map[string]any) + if secondAttrs["ApproximateReceiveCount"] != "2" { + t.Fatalf("second receive count = %v, want 2", secondAttrs) + } +} + +func TestSQSServer_ChangeMessageVisibility(t *testing.T) { + t.Parallel() + nodes, _, _ := createNode(t, 1) + defer shutdown(nodes) + node := sqsLeaderNode(t, nodes) + queueURL := createSQSQueueForTest(t, node, "chgvis") + + _, _ = callSQS(t, node, sqsSendMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "MessageBody": "bumpy", + }) + status, out := callSQS(t, node, sqsReceiveMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "MaxNumberOfMessages": 1, + "VisibilityTimeout": 1, + }) + if status != http.StatusOK { + t.Fatalf("receive: %d %v", status, out) + } + msgs, _ := out["Messages"].([]any) + if len(msgs) != 1 { + t.Fatalf("no message received") + } + receipt, _ := msgs[0].(map[string]any)["ReceiptHandle"].(string) + + // Extend visibility to 60s. + status, out = callSQS(t, node, sqsChangeMessageVisibilityTarget, map[string]any{ + "QueueUrl": queueURL, + "ReceiptHandle": receipt, + "VisibilityTimeout": 60, + }) + if status != http.StatusOK { + t.Fatalf("change visibility: %d %v", status, out) + } + + // After the original 1s expiry, the message must still be hidden. 
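+	// (1200ms is past the original 1s visibility but far inside the new
+	// 60s window, so any redelivery here would mean the extension did not
+	// take effect.)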
+ time.Sleep(1200 * time.Millisecond) + status, out = callSQS(t, node, sqsReceiveMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "MaxNumberOfMessages": 1, + }) + if status != http.StatusOK { + t.Fatalf("receive: %d %v", status, out) + } + if msgs, _ := out["Messages"].([]any); len(msgs) != 0 { + t.Fatalf("expected 0 messages after visibility extended, got %d", len(msgs)) + } +} + +func TestSQSServer_ReceiptHandleCodecRoundTrip(t *testing.T) { + t.Parallel() + for _, tc := range []struct { + gen uint64 + id string + }{ + {0, "00000000000000000000000000000000"}, + {1, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"}, + {42, "deadbeefdeadbeefdeadbeefdeadbeef"}, + } { + token := make([]byte, sqsReceiptTokenBytes) + for i := range token { + token[i] = byte(i) + } + h, err := encodeReceiptHandle(tc.gen, tc.id, token) + if err != nil { + t.Fatalf("encode: %v", err) + } + back, err := decodeReceiptHandle(h) + if err != nil { + t.Fatalf("decode: %v", err) + } + if back.QueueGeneration != tc.gen || back.MessageIDHex != tc.id { + t.Fatalf("round-trip mismatch: %+v vs %d/%s", back, tc.gen, tc.id) + } + if hex.EncodeToString(back.ReceiptToken) != hex.EncodeToString(token) { + t.Fatalf("token round-trip mismatch: %x vs %x", back.ReceiptToken, token) + } + } + + // A garbage handle must fail to decode, not crash. + if _, err := decodeReceiptHandle("!!!"); err == nil { + t.Fatal("expected decode error for garbage handle") + } +} diff --git a/adapter/sqs_test.go b/adapter/sqs_test.go index 9cf122171..721917174 100644 --- a/adapter/sqs_test.go +++ b/adapter/sqs_test.go @@ -124,19 +124,15 @@ func TestSQSServer_KnownTargetsReturnNotImplemented(t *testing.T) { t.Parallel() base := startTestSQSServer(t) - // Targets that still return NotImplemented. The catalog targets - // (CreateQueue/DeleteQueue/ListQueues/GetQueueUrl/GetQueueAttributes/ - // SetQueueAttributes) are covered by TestSQSServer_Catalog* against a - // real single-node cluster because they require a coordinator to - // dispatch transactions. + // Targets that still return NotImplemented. The catalog and core + // message operations (Create/Delete/List/Get/SetQueue*, SendMessage, + // ReceiveMessage, DeleteMessage, ChangeMessageVisibility) have real + // handlers; they are exercised against a single-node cluster by + // TestSQSServer_Catalog* and TestSQSServer_Send*. targets := []string{ sqsPurgeQueueTarget, - sqsSendMessageTarget, sqsSendMessageBatchTarget, - sqsReceiveMessageTarget, - sqsDeleteMessageTarget, sqsDeleteMessageBatchTarget, - sqsChangeMessageVisibilityTarget, sqsChangeMessageVisibilityBatchTgt, sqsTagQueueTarget, sqsUntagQueueTarget, From 4bae6d05155f2fc907e31c64e02e777fd1fa60bc Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Fri, 24 Apr 2026 16:44:54 +0900 Subject: [PATCH 6/9] fix(sqs): close receipt-token and visibility races under OCC Address the P1 findings from the codex review on #610 and the gemini micro-nit: - tryDeliverCandidate now passes ReadKeys=[old_vis_key, data_key] on its commit so a concurrent ReceiveMessage that rotates the same message commits a newer version and our Dispatch returns ErrWriteConflict. Skipping the candidate on conflict prevents the "two workers both think they delivered" duplicate. - deleteMessage and changeMessageVisibility lift their single-shot validate-and-commit into retry loops that mirror the catalog handlers. 
Both pass ReadKeys covering the data record and the current vis entry; on ErrWriteConflict we re-validate the token, which either succeeds (the state we just saw is still current) or returns InvalidReceiptHandle (someone else rotated the token first, which is the correct AWS semantics for a stale handle). - md5OfAttributesHex: drop the hand-rolled O(n^2) selection sort in favor of sort.Strings. No behavior change on the happy path; these are purely correctness-under-contention fixes. Existing tests still cover the single-writer flows, and the retry loops inherit the same test-wide timeout knobs the catalog code uses. --- adapter/sqs_messages.go | 142 +++++++++++++++++++++++++--------------- 1 file changed, 91 insertions(+), 51 deletions(-) diff --git a/adapter/sqs_messages.go b/adapter/sqs_messages.go index 1c7682eb9..4bd261c58 100644 --- a/adapter/sqs_messages.go +++ b/adapter/sqs_messages.go @@ -9,6 +9,7 @@ import ( "encoding/binary" "encoding/hex" "net/http" + "sort" "strconv" "strings" "time" @@ -519,8 +520,13 @@ func (s *SQSServer) tryDeliverCandidate( return nil, false } newVisKey := sqsMsgVisKey(queueName, gen, newVisibleAt, cand.messageID) + // ReadKeys include the visibility entry and the data record so a + // concurrent ReceiveMessage that rotated the same message commits a + // newer version and forces our Dispatch to return ErrWriteConflict. + // Skipping the candidate on conflict prevents duplicate delivery. req := &kv.OperationGroup[kv.OP]{ - IsTxn: true, + IsTxn: true, + ReadKeys: [][]byte{cand.visKey, dataKey}, Elems: []*kv.Elem[kv.OP]{ {Op: kv.Del, Key: cand.visKey}, {Op: kv.Put, Key: newVisKey, Value: []byte(cand.messageID)}, @@ -559,26 +565,49 @@ func (s *SQSServer) deleteMessage(w http.ResponseWriter, r *http.Request) { writeSQSErrorFromErr(w, err) return } - rec, dataKey, apiErr := s.loadAndVerifyMessage(r.Context(), queueName, handle) - if apiErr != nil { - writeSQSErrorFromErr(w, apiErr) - return - } - visKey := sqsMsgVisKey(queueName, handle.QueueGeneration, rec.VisibleAtMillis, rec.MessageID) - req := &kv.OperationGroup[kv.OP]{ - IsTxn: true, - Elems: []*kv.Elem[kv.OP]{ - {Op: kv.Del, Key: dataKey}, - {Op: kv.Del, Key: visKey}, - }, - } - if _, err := s.coordinator.Dispatch(r.Context(), req); err != nil { + if err := s.deleteMessageWithRetry(r.Context(), queueName, handle); err != nil { writeSQSErrorFromErr(w, err) return } writeSQSJSON(w, map[string]any{}) } +// deleteMessageWithRetry runs the receipt-token check and the delete +// transaction under one OCC budget. A concurrent ReceiveMessage / +// ChangeMessageVisibility that rotates the token between validate and +// commit produces ErrWriteConflict; we retry by re-validating the +// current token against the handle. If the token has actually been +// rotated under us, the next pass returns InvalidReceiptHandle. 
+func (s *SQSServer) deleteMessageWithRetry(ctx context.Context, queueName string, handle *decodedReceiptHandle) error { + backoff := transactRetryInitialBackoff + deadline := time.Now().Add(transactRetryMaxDuration) + for range transactRetryMaxAttempts { + rec, dataKey, apiErr := s.loadAndVerifyMessage(ctx, queueName, handle) + if apiErr != nil { + return apiErr + } + visKey := sqsMsgVisKey(queueName, handle.QueueGeneration, rec.VisibleAtMillis, rec.MessageID) + req := &kv.OperationGroup[kv.OP]{ + IsTxn: true, + ReadKeys: [][]byte{dataKey, visKey}, + Elems: []*kv.Elem[kv.OP]{ + {Op: kv.Del, Key: dataKey}, + {Op: kv.Del, Key: visKey}, + }, + } + if _, err := s.coordinator.Dispatch(ctx, req); err == nil { + return nil + } else if !isRetryableTransactWriteError(err) { + return errors.WithStack(err) + } + if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { + return errors.WithStack(err) + } + backoff = nextTransactRetryBackoff(backoff) + } + return newSQSAPIError(http.StatusInternalServerError, sqsErrInternalFailure, "delete message retry attempts exhausted") +} + func (s *SQSServer) changeMessageVisibility(w http.ResponseWriter, r *http.Request) { var in sqsChangeVisibilityInput if err := decodeSQSJSONInput(r, &in); err != nil { @@ -594,40 +623,58 @@ func (s *SQSServer) changeMessageVisibility(w http.ResponseWriter, r *http.Reque writeSQSErrorFromErr(w, err) return } - rec, dataKey, apiErr := s.loadAndVerifyMessage(r.Context(), queueName, handle) - if apiErr != nil { - writeSQSErrorFromErr(w, apiErr) - return - } - now := time.Now().UnixMilli() - if rec.VisibleAtMillis <= now { - writeSQSError(w, http.StatusBadRequest, sqsErrMessageNotInflight, "message is not currently in flight") - return - } - - oldVisKey := sqsMsgVisKey(queueName, handle.QueueGeneration, rec.VisibleAtMillis, rec.MessageID) - rec.VisibleAtMillis = now + in.VisibilityTimeout*sqsMillisPerSecond - recordBytes, err := encodeSQSMessageRecord(rec) - if err != nil { - writeSQSErrorFromErr(w, err) - return - } - newVisKey := sqsMsgVisKey(queueName, handle.QueueGeneration, rec.VisibleAtMillis, rec.MessageID) - req := &kv.OperationGroup[kv.OP]{ - IsTxn: true, - Elems: []*kv.Elem[kv.OP]{ - {Op: kv.Del, Key: oldVisKey}, - {Op: kv.Put, Key: newVisKey, Value: []byte(rec.MessageID)}, - {Op: kv.Put, Key: dataKey, Value: recordBytes}, - }, - } - if _, err := s.coordinator.Dispatch(r.Context(), req); err != nil { + if err := s.changeVisibilityWithRetry(r.Context(), queueName, handle, in.VisibilityTimeout); err != nil { writeSQSErrorFromErr(w, err) return } writeSQSJSON(w, map[string]any{}) } +// changeVisibilityWithRetry runs the validate-and-swap flow under an OCC +// retry budget. ReadKeys cover the data record and the current vis +// entry; a concurrent receive or delete will bump their commitTS past +// our startTS and we re-validate. 
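+// The loop reuses transactRetryInitialBackoff, transactRetryMaxAttempts,
+// and transactRetryMaxDuration, so the message path is tuned by the same
+// knobs as the catalog handlers.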
+func (s *SQSServer) changeVisibilityWithRetry(ctx context.Context, queueName string, handle *decodedReceiptHandle, newTimeout int64) error { + backoff := transactRetryInitialBackoff + deadline := time.Now().Add(transactRetryMaxDuration) + for range transactRetryMaxAttempts { + rec, dataKey, apiErr := s.loadAndVerifyMessage(ctx, queueName, handle) + if apiErr != nil { + return apiErr + } + now := time.Now().UnixMilli() + if rec.VisibleAtMillis <= now { + return newSQSAPIError(http.StatusBadRequest, sqsErrMessageNotInflight, "message is not currently in flight") + } + oldVisKey := sqsMsgVisKey(queueName, handle.QueueGeneration, rec.VisibleAtMillis, rec.MessageID) + rec.VisibleAtMillis = now + newTimeout*sqsMillisPerSecond + recordBytes, err := encodeSQSMessageRecord(rec) + if err != nil { + return errors.WithStack(err) + } + newVisKey := sqsMsgVisKey(queueName, handle.QueueGeneration, rec.VisibleAtMillis, rec.MessageID) + req := &kv.OperationGroup[kv.OP]{ + IsTxn: true, + ReadKeys: [][]byte{dataKey, oldVisKey}, + Elems: []*kv.Elem[kv.OP]{ + {Op: kv.Del, Key: oldVisKey}, + {Op: kv.Put, Key: newVisKey, Value: []byte(rec.MessageID)}, + {Op: kv.Put, Key: dataKey, Value: recordBytes}, + }, + } + if _, err := s.coordinator.Dispatch(ctx, req); err == nil { + return nil + } else if !isRetryableTransactWriteError(err) { + return errors.WithStack(err) + } + if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { + return errors.WithStack(err) + } + backoff = nextTransactRetryBackoff(backoff) + } + return newSQSAPIError(http.StatusInternalServerError, sqsErrInternalFailure, "change visibility retry attempts exhausted") +} + // parseQueueAndReceipt extracts the queue name and decodes the receipt // handle from a DeleteMessage / ChangeMessageVisibility input. func (s *SQSServer) parseQueueAndReceipt(queueUrl, receiptHandle string) (string, *decodedReceiptHandle, error) { @@ -685,14 +732,7 @@ func md5OfAttributesHex(attrs map[string]string) string { for k := range attrs { keys = append(keys, k) } - // Stable order for determinism. - for i := 0; i < len(keys); i++ { - for j := i + 1; j < len(keys); j++ { - if keys[j] < keys[i] { - keys[i], keys[j] = keys[j], keys[i] - } - } - } + sort.Strings(keys) var b strings.Builder for _, k := range keys { b.WriteString(k) From 6685fbf530a810c28a743595549eae034b5afdc7 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Fri, 24 Apr 2026 16:55:08 +0900 Subject: [PATCH 7/9] fix(sqs): make DeleteMessage idempotent on stale receipt handles MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AWS SQS semantics: DeleteMessage with a stale receipt handle (token rotated under us, or record already gone) is a 200 success no-op, NOT an InvalidReceiptHandle. This is relied on by SDK retry paths and batch workers, and also by clients that do a visibility-expiry redelivery dance. - deleteMessageWithRetry now calls a new loadMessageForDelete helper that returns a sqsDeleteOutcome tag. Missing record and token mismatch both map to sqsDeleteNoOp -> return nil (handler renders 200 empty body). Token-match proceeds to the OCC delete txn. Only structural errors (malformed handle) and retry-budget exhaustion still propagate as errors. - ChangeMessageVisibility keeps the strict behavior — AWS also errors there when the handle is stale, so changeVisibilityWithRetry is unchanged. - Test rename: TestSQSServer_DeleteWithWrongReceiptRejected -> TestSQSServer_DeleteWithStaleReceiptIsIdempotentNoOp. 
The test now asserts that a token-mismatched delete returns 200, that the real handle still works afterwards (the stale no-op must not have "stolen" the in-flight message), and that a structurally malformed handle still fails with ReceiptHandleIsInvalid. The ReceiveCount / first-receive-timestamp behavior the same codex review flagged is already implemented correctly (tryDeliverCandidate: "if rec.FirstReceiveMillis == 0 { ... = now }"), so no code change was needed for that finding beyond the design doc clarification that landed on feat/sqs_compatible_adapter. --- adapter/sqs_messages.go | 62 ++++++++++++++++++++++++++++++------ adapter/sqs_messages_test.go | 46 ++++++++++++++++++++------ 2 files changed, 89 insertions(+), 19 deletions(-) diff --git a/adapter/sqs_messages.go b/adapter/sqs_messages.go index 4bd261c58..614b7f2d6 100644 --- a/adapter/sqs_messages.go +++ b/adapter/sqs_messages.go @@ -572,19 +572,28 @@ func (s *SQSServer) deleteMessage(w http.ResponseWriter, r *http.Request) { writeSQSJSON(w, map[string]any{}) } -// deleteMessageWithRetry runs the receipt-token check and the delete -// transaction under one OCC budget. A concurrent ReceiveMessage / -// ChangeMessageVisibility that rotates the token between validate and -// commit produces ErrWriteConflict; we retry by re-validating the -// current token against the handle. If the token has actually been -// rotated under us, the next pass returns InvalidReceiptHandle. +// deleteMessageWithRetry runs the load-check-commit flow under one OCC +// budget. AWS SQS semantics: a stale receipt handle (message already +// gone, or token rotated by another consumer) is a 200 no-op, NOT an +// error. The only error cases are structural (malformed handle, caught +// before this function) and infrastructure (retry budget exhausted). +// ErrWriteConflict on the delete Dispatch means a concurrent rotation +// / delete landed between our read and our commit; we retry so the +// next pass either sees the rotated token (no-op success) or the +// missing record (no-op success). func (s *SQSServer) deleteMessageWithRetry(ctx context.Context, queueName string, handle *decodedReceiptHandle) error { backoff := transactRetryInitialBackoff deadline := time.Now().Add(transactRetryMaxDuration) for range transactRetryMaxAttempts { - rec, dataKey, apiErr := s.loadAndVerifyMessage(ctx, queueName, handle) - if apiErr != nil { - return apiErr + rec, dataKey, outcome, err := s.loadMessageForDelete(ctx, queueName, handle) + if err != nil { + return err + } + switch outcome { + case sqsDeleteNoOp: + return nil + case sqsDeleteProceed: + // fall through to commit below } visKey := sqsMsgVisKey(queueName, handle.QueueGeneration, rec.VisibleAtMillis, rec.MessageID) req := &kv.OperationGroup[kv.OP]{ @@ -608,6 +617,41 @@ func (s *SQSServer) deleteMessageWithRetry(ctx context.Context, queueName string return newSQSAPIError(http.StatusInternalServerError, sqsErrInternalFailure, "delete message retry attempts exhausted") } +// sqsDeleteOutcome is a ternary tag returned by loadMessageForDelete so +// the caller can cleanly distinguish the AWS-idempotent no-op case from +// the proceed-to-commit case without conflating them with errors. 
+type sqsDeleteOutcome int + +const ( + sqsDeleteProceed sqsDeleteOutcome = iota + sqsDeleteNoOp +) + +// loadMessageForDelete reads the message record and classifies the +// outcome for AWS-compatible DeleteMessage semantics: structural errors +// propagate; missing records and token mismatches return +// sqsDeleteNoOp; matching tokens return sqsDeleteProceed with the +// loaded record. +func (s *SQSServer) loadMessageForDelete(ctx context.Context, queueName string, handle *decodedReceiptHandle) (*sqsMessageRecord, []byte, sqsDeleteOutcome, error) { + readTS := s.nextTxnReadTS(ctx) + dataKey := sqsMsgDataKey(queueName, handle.QueueGeneration, handle.MessageIDHex) + raw, err := s.store.GetAt(ctx, dataKey, readTS) + if err != nil { + if errors.Is(err, store.ErrKeyNotFound) { + return nil, nil, sqsDeleteNoOp, nil + } + return nil, nil, sqsDeleteProceed, errors.WithStack(err) + } + rec, err := decodeSQSMessageRecord(raw) + if err != nil { + return nil, nil, sqsDeleteProceed, errors.WithStack(err) + } + if !bytes.Equal(rec.CurrentReceiptToken, handle.ReceiptToken) { + return nil, nil, sqsDeleteNoOp, nil + } + return rec, dataKey, sqsDeleteProceed, nil +} + func (s *SQSServer) changeMessageVisibility(w http.ResponseWriter, r *http.Request) { var in sqsChangeVisibilityInput if err := decodeSQSJSONInput(r, &in); err != nil { diff --git a/adapter/sqs_messages_test.go b/adapter/sqs_messages_test.go index 688cda999..cc86b6db9 100644 --- a/adapter/sqs_messages_test.go +++ b/adapter/sqs_messages_test.go @@ -112,14 +112,19 @@ func deleteMessageOK(t *testing.T, node Node, queueURL, receipt string) { } } -func TestSQSServer_DeleteWithWrongReceiptRejected(t *testing.T) { +func TestSQSServer_DeleteWithStaleReceiptIsIdempotentNoOp(t *testing.T) { t.Parallel() + // AWS SQS semantics: DeleteMessage with a stale receipt handle (token + // rotated under our feet, or record already gone) must return 200 + // success without deleting. SDK retry paths and batch workers rely on + // this so a retry after a visibility-expiry re-delivery does not fail + // loudly. Structurally malformed handles are still an error; + // token-only mismatches are not. nodes, _, _ := createNode(t, 1) defer shutdown(nodes) node := sqsLeaderNode(t, nodes) - queueURL := createSQSQueueForTest(t, node, "wrong-receipt") + queueURL := createSQSQueueForTest(t, node, "stale-receipt") - // Send + receive once to get a real message on the queue. _, _ = callSQS(t, node, sqsSendMessageTarget, map[string]any{ "QueueUrl": queueURL, "MessageBody": "x", @@ -138,26 +143,47 @@ func TestSQSServer_DeleteWithWrongReceiptRejected(t *testing.T) { } goodHandle, _ := msgs[0].(map[string]any)["ReceiptHandle"].(string) - // Tamper the receipt token portion of the handle and expect - // InvalidReceiptHandle. + // Flip a byte of the token so the stored token != handle token. decoded, err := decodeReceiptHandle(goodHandle) if err != nil { t.Fatalf("decode: %v", err) } decoded.ReceiptToken[0] ^= 0xff - badHandle, err := encodeReceiptHandle(decoded.QueueGeneration, decoded.MessageIDHex, decoded.ReceiptToken) + staleHandle, err := encodeReceiptHandle(decoded.QueueGeneration, decoded.MessageIDHex, decoded.ReceiptToken) if err != nil { t.Fatalf("encode: %v", err) } + + // Delete with the stale handle must succeed (no-op) per AWS. 
+ status, out = callSQS(t, node, sqsDeleteMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "ReceiptHandle": staleHandle, + }) + if status != http.StatusOK { + t.Fatalf("delete with stale receipt: status=%d body=%v", status, out) + } + + // The real handle must still work — the stale delete must not have + // removed the in-flight message out from under the original consumer. + status, out = callSQS(t, node, sqsDeleteMessageTarget, map[string]any{ + "QueueUrl": queueURL, + "ReceiptHandle": goodHandle, + }) + if status != http.StatusOK { + t.Fatalf("delete with good receipt after stale no-op: %d %v", status, out) + } + + // A structurally malformed handle still errors out — only token + // mismatches are the idempotent no-op case. status, out = callSQS(t, node, sqsDeleteMessageTarget, map[string]any{ "QueueUrl": queueURL, - "ReceiptHandle": badHandle, + "ReceiptHandle": "not-base64-!!!", }) if status != http.StatusBadRequest { - t.Fatalf("delete with bad receipt: status=%d body=%v", status, out) + t.Fatalf("malformed handle: status=%d body=%v", status, out) } - if out["__type"] != sqsErrInvalidReceiptHandle { - t.Fatalf("error type: %q want %q", out["__type"], sqsErrInvalidReceiptHandle) + if out["__type"] != sqsErrReceiptHandleInvalid { + t.Fatalf("error type for malformed handle: %q want %q", out["__type"], sqsErrReceiptHandleInvalid) } } From 1846fd018b3cf91d98f3a48833a5ac1492b7ee80 Mon Sep 17 00:00:00 2001 From: "Yoshiaki Ueda (bootjp)" Date: Fri, 24 Apr 2026 17:18:04 +0900 Subject: [PATCH 8/9] fix(sqs): pin OCC StartTS to validated-read snapshot MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit codex round 2 on #610 flagged two P1 OCC safety holes: the ReceiveMessage rotation and the DeleteMessage commit both left StartTS=0, so the coordinator auto-assigned a newer StartTS at Dispatch time. A concurrent rotation that committed between our load and the auto-assigned StartTS would fall below the ReadKeys check and slip through, allowing stale-snapshot writes to land — double delivery in the receive path, stealing an in-flight message in the delete path. Fix: carry the snapshot readTS we actually loaded the record at through to the Dispatch call and set it as StartTS on the OperationGroup. The existing ReadKeys then fence correctly against any commit in (readTS, now]. - tryDeliverCandidate: StartTS=readTS on the rotation txn. - loadMessageForDelete: now returns the readTS it took; caller (deleteMessageWithRetry) threads it into the delete txn. - loadAndVerifyMessage: now returns the readTS too; changeVisibilityWithRetry threads it into the swap txn. No behavior change under no-contention. Existing tests still pass; a Jepsen-style concurrent receive/delete test that exposes the race will land in the follow-up workload PR alongside FIFO. --- adapter/sqs_messages.go | 57 ++++++++++++++++++++++++++--------------- 1 file changed, 37 insertions(+), 20 deletions(-) diff --git a/adapter/sqs_messages.go b/adapter/sqs_messages.go index 614b7f2d6..fe8ba5443 100644 --- a/adapter/sqs_messages.go +++ b/adapter/sqs_messages.go @@ -520,12 +520,16 @@ func (s *SQSServer) tryDeliverCandidate( return nil, false } newVisKey := sqsMsgVisKey(queueName, gen, newVisibleAt, cand.messageID) - // ReadKeys include the visibility entry and the data record so a - // concurrent ReceiveMessage that rotated the same message commits a - // newer version and forces our Dispatch to return ErrWriteConflict. - // Skipping the candidate on conflict prevents duplicate delivery. 
+ // StartTS pins the OCC read snapshot to the timestamp we actually + // loaded the record at. Without it, the coordinator assigns a newer + // StartTS at dispatch, so a concurrent rotation that committed + // AFTER our read but BEFORE the assigned StartTS would slip through + // ReadKeys validation and let this transaction double-deliver. + // ReadKeys cover both the visibility entry and the data record so + // any concurrent commit on either produces ErrWriteConflict. req := &kv.OperationGroup[kv.OP]{ IsTxn: true, + StartTS: readTS, ReadKeys: [][]byte{cand.visKey, dataKey}, Elems: []*kv.Elem[kv.OP]{ {Op: kv.Del, Key: cand.visKey}, @@ -585,7 +589,7 @@ func (s *SQSServer) deleteMessageWithRetry(ctx context.Context, queueName string backoff := transactRetryInitialBackoff deadline := time.Now().Add(transactRetryMaxDuration) for range transactRetryMaxAttempts { - rec, dataKey, outcome, err := s.loadMessageForDelete(ctx, queueName, handle) + rec, dataKey, readTS, outcome, err := s.loadMessageForDelete(ctx, queueName, handle) if err != nil { return err } @@ -596,8 +600,12 @@ func (s *SQSServer) deleteMessageWithRetry(ctx context.Context, queueName string // fall through to commit below } visKey := sqsMsgVisKey(queueName, handle.QueueGeneration, rec.VisibleAtMillis, rec.MessageID) + // StartTS pins OCC to the snapshot we loaded the record at, so a + // concurrent rotation that commits after our load but before a + // coordinator-assigned StartTS cannot slip past ReadKeys. req := &kv.OperationGroup[kv.OP]{ IsTxn: true, + StartTS: readTS, ReadKeys: [][]byte{dataKey, visKey}, Elems: []*kv.Elem[kv.OP]{ {Op: kv.Del, Key: dataKey}, @@ -631,25 +639,27 @@ const ( // outcome for AWS-compatible DeleteMessage semantics: structural errors // propagate; missing records and token mismatches return // sqsDeleteNoOp; matching tokens return sqsDeleteProceed with the -// loaded record. -func (s *SQSServer) loadMessageForDelete(ctx context.Context, queueName string, handle *decodedReceiptHandle) (*sqsMessageRecord, []byte, sqsDeleteOutcome, error) { +// loaded record. The readTS it took the snapshot at is returned so the +// caller can pass it as StartTS on the OCC dispatch, pinning the +// read-write conflict detection window. 
+func (s *SQSServer) loadMessageForDelete(ctx context.Context, queueName string, handle *decodedReceiptHandle) (*sqsMessageRecord, []byte, uint64, sqsDeleteOutcome, error) { readTS := s.nextTxnReadTS(ctx) dataKey := sqsMsgDataKey(queueName, handle.QueueGeneration, handle.MessageIDHex) raw, err := s.store.GetAt(ctx, dataKey, readTS) if err != nil { if errors.Is(err, store.ErrKeyNotFound) { - return nil, nil, sqsDeleteNoOp, nil + return nil, nil, readTS, sqsDeleteNoOp, nil } - return nil, nil, sqsDeleteProceed, errors.WithStack(err) + return nil, nil, readTS, sqsDeleteProceed, errors.WithStack(err) } rec, err := decodeSQSMessageRecord(raw) if err != nil { - return nil, nil, sqsDeleteProceed, errors.WithStack(err) + return nil, nil, readTS, sqsDeleteProceed, errors.WithStack(err) } if !bytes.Equal(rec.CurrentReceiptToken, handle.ReceiptToken) { - return nil, nil, sqsDeleteNoOp, nil + return nil, nil, readTS, sqsDeleteNoOp, nil } - return rec, dataKey, sqsDeleteProceed, nil + return rec, dataKey, readTS, sqsDeleteProceed, nil } func (s *SQSServer) changeMessageVisibility(w http.ResponseWriter, r *http.Request) { @@ -682,7 +692,7 @@ func (s *SQSServer) changeVisibilityWithRetry(ctx context.Context, queueName str backoff := transactRetryInitialBackoff deadline := time.Now().Add(transactRetryMaxDuration) for range transactRetryMaxAttempts { - rec, dataKey, apiErr := s.loadAndVerifyMessage(ctx, queueName, handle) + rec, dataKey, readTS, apiErr := s.loadAndVerifyMessage(ctx, queueName, handle) if apiErr != nil { return apiErr } @@ -697,8 +707,13 @@ func (s *SQSServer) changeVisibilityWithRetry(ctx context.Context, queueName str return errors.WithStack(err) } newVisKey := sqsMsgVisKey(queueName, handle.QueueGeneration, rec.VisibleAtMillis, rec.MessageID) + // StartTS pins OCC to the snapshot; without it the coordinator + // would auto-assign a newer StartTS and a concurrent receive / + // delete that commits between our load and dispatch could slip + // past the ReadKeys validation. req := &kv.OperationGroup[kv.OP]{ IsTxn: true, + StartTS: readTS, ReadKeys: [][]byte{dataKey, oldVisKey}, Elems: []*kv.Elem[kv.OP]{ {Op: kv.Del, Key: oldVisKey}, @@ -735,25 +750,27 @@ func (s *SQSServer) parseQueueAndReceipt(queueUrl, receiptHandle string) (string // loadAndVerifyMessage reads the data record for the given handle and // verifies that the receipt token matches the current one on record. -// Returns the record, its key, or a typed SQS error. -func (s *SQSServer) loadAndVerifyMessage(ctx context.Context, queueName string, handle *decodedReceiptHandle) (*sqsMessageRecord, []byte, error) { +// Returns the record, its key, the snapshot timestamp the read ran at, +// or a typed SQS error. Callers use the snapshot as StartTS on the +// OCC dispatch so concurrent commits cannot slip past ReadKeys. 
+func (s *SQSServer) loadAndVerifyMessage(ctx context.Context, queueName string, handle *decodedReceiptHandle) (*sqsMessageRecord, []byte, uint64, error) {
 	readTS := s.nextTxnReadTS(ctx)
 	dataKey := sqsMsgDataKey(queueName, handle.QueueGeneration, handle.MessageIDHex)
 	raw, err := s.store.GetAt(ctx, dataKey, readTS)
 	if err != nil {
 		if errors.Is(err, store.ErrKeyNotFound) {
-			return nil, nil, newSQSAPIError(http.StatusBadRequest, sqsErrReceiptHandleInvalid, "message not found")
+			return nil, nil, readTS, newSQSAPIError(http.StatusBadRequest, sqsErrReceiptHandleInvalid, "message not found")
 		}
-		return nil, nil, errors.WithStack(err)
+		return nil, nil, readTS, errors.WithStack(err)
 	}
 	rec, err := decodeSQSMessageRecord(raw)
 	if err != nil {
-		return nil, nil, errors.WithStack(err)
+		return nil, nil, readTS, errors.WithStack(err)
 	}
 	if !bytes.Equal(rec.CurrentReceiptToken, handle.ReceiptToken) {
-		return nil, nil, newSQSAPIError(http.StatusBadRequest, sqsErrInvalidReceiptHandle, "receipt handle token does not match")
+		return nil, nil, readTS, newSQSAPIError(http.StatusBadRequest, sqsErrInvalidReceiptHandle, "receipt handle token does not match")
 	}
-	return rec, dataKey, nil
+	return rec, dataKey, readTS, nil
 }

 // ------------------------ small helpers ------------------------

From 13e75f4e93a8aeef9dea6abc70c5ed5b61f2f0fe Mon Sep 17 00:00:00 2001
From: "Yoshiaki Ueda (bootjp)"
Date: Fri, 24 Apr 2026 17:24:50 +0900
Subject: [PATCH 9/9] fix(sqs): pin catalog OCC StartTS to validated-read snapshot

codex round 3 on #610 flagged the catalog CRUD path for the same
missing-StartTS hole the message path has already fixed:

- tryCreateQueueOnce reads queue absence + generation at readTS, then
  dispatches Put without StartTS or ReadKeys. Two racing creates can
  both observe "queue missing" and both commit, and the coordinator
  auto-assigns StartTS so OCC never rejects the second one.
- deleteQueueWithRetry and setQueueAttributesWithRetry have the
  same pattern (read at readTS, write with no StartTS/ReadKeys),
  which lets a concurrent create/delete/set clobber a
  just-committed change.

Fix: carry readTS through to the dispatch as StartTS and add ReadKeys
covering the meta and (for create/delete) generation records. The
existing retry loops already handle ErrWriteConflict, so contention
now retries instead of silently committing stale state.

No behavior change at the single-writer level; the existing
idempotent-create test and delete-generation-bump test still pass.
---
 adapter/sqs_catalog.go | 39 +++++++++++++++++++++++++++++++--------
 1 file changed, 31 insertions(+), 8 deletions(-)

diff --git a/adapter/sqs_catalog.go b/adapter/sqs_catalog.go
index 010e96802..9abb74f1b 100644
--- a/adapter/sqs_catalog.go
+++ b/adapter/sqs_catalog.go
@@ -502,11 +502,22 @@ func (s *SQSServer) tryCreateQueueOnce(ctx context.Context, requested *sqsQueueM
 	if err != nil {
 		return false, errors.WithStack(err)
 	}
+	metaKey := sqsQueueMetaKey(requested.Name)
+	genKey := sqsQueueGenKey(requested.Name)
+	// StartTS pins OCC to the snapshot we took the existence + generation
+	// read at, and ReadKeys cover both the meta and generation records so
+	// a concurrent CreateQueue that committed between our read and our
+	// dispatch causes ErrWriteConflict and the retry loop re-reads.
+	// Without this, two racing creates could both decide "queue missing"
+	// and both write their own generation, leaving the later write on top
+	// of a record the loser never observed.
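	// Interleaving sketch of the lost update described above, with
	// hypothetical snapshot timestamps s1 < s2 (illustration only):
	//
	//	create A: read meta/gen at s1 -> queue missing
	//	create B: read meta/gen at s2 -> queue missing
	//	create A: Dispatch commits meta + gen
	//	create B: Dispatch commits too when StartTS is auto-assigned
	//	          after A's commit; with StartTS = s2 and
	//	          ReadKeys = {metaKey, genKey}, A's commit falls in
	//	          (s2, now] and B gets ErrWriteConflict and re-reads.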
req := &kv.OperationGroup[kv.OP]{ - IsTxn: true, + IsTxn: true, + StartTS: readTS, + ReadKeys: [][]byte{metaKey, genKey}, Elems: []*kv.Elem[kv.OP]{ - {Op: kv.Put, Key: sqsQueueMetaKey(requested.Name), Value: metaBytes}, - {Op: kv.Put, Key: sqsQueueGenKey(requested.Name), Value: []byte(strconv.FormatUint(requested.Generation, 10))}, + {Op: kv.Put, Key: metaKey, Value: metaBytes}, + {Op: kv.Put, Key: genKey, Value: []byte(strconv.FormatUint(requested.Generation, 10))}, }, } if _, err := s.coordinator.Dispatch(ctx, req); err != nil { @@ -554,11 +565,17 @@ func (s *SQSServer) deleteQueueWithRetry(ctx context.Context, queueName string) if err != nil { return errors.WithStack(err) } + metaKey := sqsQueueMetaKey(queueName) + genKey := sqsQueueGenKey(queueName) + // StartTS + ReadKeys fence against a concurrent CreateQueue / + // SetQueueAttributes landing between our load and dispatch. req := &kv.OperationGroup[kv.OP]{ - IsTxn: true, + IsTxn: true, + StartTS: readTS, + ReadKeys: [][]byte{metaKey, genKey}, Elems: []*kv.Elem[kv.OP]{ - {Op: kv.Del, Key: sqsQueueMetaKey(queueName)}, - {Op: kv.Put, Key: sqsQueueGenKey(queueName), Value: []byte(strconv.FormatUint(lastGen+1, 10))}, + {Op: kv.Del, Key: metaKey}, + {Op: kv.Put, Key: genKey, Value: []byte(strconv.FormatUint(lastGen+1, 10))}, }, } if _, err := s.coordinator.Dispatch(ctx, req); err == nil { @@ -812,10 +829,16 @@ func (s *SQSServer) setQueueAttributesWithRetry(ctx context.Context, queueName s if err != nil { return errors.WithStack(err) } + metaKey := sqsQueueMetaKey(queueName) + // StartTS + ReadKeys prevent two concurrent SetQueueAttributes + // from both reading the same old meta and the later dispatch + // clobbering the earlier commit's changes. req := &kv.OperationGroup[kv.OP]{ - IsTxn: true, + IsTxn: true, + StartTS: readTS, + ReadKeys: [][]byte{metaKey}, Elems: []*kv.Elem[kv.OP]{ - {Op: kv.Put, Key: sqsQueueMetaKey(queueName), Value: metaBytes}, + {Op: kv.Put, Key: metaKey, Value: metaBytes}, }, } if _, err := s.coordinator.Dispatch(ctx, req); err == nil {
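
All three catalog sites now share one read-modify-write shape: snapshot
read, txn pinned to that snapshot, retry on conflict. A minimal sketch
of that shape follows; retryPut and mutate are hypothetical names,
kv.ErrWriteConflict is assumed to be the coordinator's exported
conflict error, and the real loops additionally honor
transactRetryMaxDuration as a deadline (imports as in the adapter
package: context, time, kv, store, errors).

// retryPut sketches the catalog read-modify-write loop: read a
// snapshot, build a txn pinned to that snapshot, and retry on write
// conflict so contention re-reads instead of clobbering.
func (s *SQSServer) retryPut(ctx context.Context, key []byte, mutate func([]byte) ([]byte, error)) error {
	backoff := transactRetryInitialBackoff
	for range transactRetryMaxAttempts {
		// 1. Snapshot read: everything we assert about is read at readTS.
		//    A missing record surfaces as old == nil for mutate to handle.
		readTS := s.nextTxnReadTS(ctx)
		old, err := s.store.GetAt(ctx, key, readTS)
		if err != nil && !errors.Is(err, store.ErrKeyNotFound) {
			return errors.WithStack(err)
		}
		next, err := mutate(old)
		if err != nil {
			return err
		}
		// 2. Pin OCC to that snapshot: any commit touching key in
		//    (readTS, now] must surface as a conflict, not a lost update.
		req := &kv.OperationGroup[kv.OP]{
			IsTxn:    true,
			StartTS:  readTS,
			ReadKeys: [][]byte{key},
			Elems:    []*kv.Elem[kv.OP]{{Op: kv.Put, Key: key, Value: next}},
		}
		_, err = s.coordinator.Dispatch(ctx, req)
		if err == nil {
			return nil
		}
		if !errors.Is(err, kv.ErrWriteConflict) {
			return errors.WithStack(err)
		}
		// 3. Conflict: back off and re-read so the next attempt mutates
		//    the state the winner actually committed.
		time.Sleep(backoff)
		backoff *= 2
	}
	return errors.New("sqs catalog txn: retry budget exhausted")
}

setQueueAttributesWithRetry is exactly this shape with mutate decoding,
patching, and re-encoding the meta record; create and delete also put
genKey in ReadKeys because the generation record is part of what they
asserted at readTS.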