diff --git a/adapter/dynamodb.go b/adapter/dynamodb.go index 0de6aa393..00a18d981 100644 --- a/adapter/dynamodb.go +++ b/adapter/dynamodb.go @@ -12,8 +12,6 @@ import ( "math/big" "net" "net/http" - "net/http/httputil" - "net/url" "slices" "sort" "strconv" @@ -294,33 +292,15 @@ func (d *DynamoDBServer) Stop() { // (either proxied or an error response was written), false if the request // should be handled locally (i.e. this node is the leader or no leader map is // configured). +// +// Serving reads or writes locally on a follower would expose G2-item-realtime +// stale reads, so every follower request is forwarded to the leader. func (d *DynamoDBServer) proxyToLeader(w http.ResponseWriter, r *http.Request) bool { - if len(d.leaderDynamo) == 0 || d.coordinator == nil { - return false - } - if d.coordinator.IsLeader() { - return false - } - // This node is a follower. All requests must be forwarded to the leader to - // preserve linearizability — serving reads or writes locally on a follower - // causes stale-read anomalies (G2-item-realtime). 
- leader := d.coordinator.RaftLeader() - if leader == "" { - writeDynamoError(w, http.StatusServiceUnavailable, dynamoErrInternal, "no raft leader currently available") - return true - } - targetAddr, ok := d.leaderDynamo[leader] - if !ok || strings.TrimSpace(targetAddr) == "" { - writeDynamoError(w, http.StatusServiceUnavailable, dynamoErrInternal, "leader dynamo address not found") - return true - } - target := &url.URL{Scheme: "http", Host: targetAddr} - proxy := httputil.NewSingleHostReverseProxy(target) - proxy.ErrorHandler = func(rw http.ResponseWriter, _ *http.Request, err error) { - writeDynamoError(rw, http.StatusServiceUnavailable, dynamoErrInternal, "leader proxy error: "+err.Error()) - } - proxy.ServeHTTP(w, r) - return true + return proxyHTTPRequestToLeader(d.coordinator, d.leaderDynamo, dynamoLeaderProxyErrorWriter, w, r) +} + +func dynamoLeaderProxyErrorWriter(w http.ResponseWriter, status int, message string) { + writeDynamoError(w, status, dynamoErrInternal, message) } func (d *DynamoDBServer) handle(w http.ResponseWriter, r *http.Request) { diff --git a/adapter/leader_http_proxy.go b/adapter/leader_http_proxy.go new file mode 100644 index 000000000..8483698b8 --- /dev/null +++ b/adapter/leader_http_proxy.go @@ -0,0 +1,64 @@ +package adapter + +import ( + "net/http" + "net/http/httputil" + "net/url" + "strings" + + "github.com/bootjp/elastickv/kv" +) + +// httpLeaderErrorWriter writes an adapter-specific error envelope (JSON for +// DynamoDB/SQS, XML for S3, RESP for Redis, ...) when the HTTP leader proxy +// cannot forward a follower request. The adapter owns the shape of the body; +// this helper only decides when to call it. +type httpLeaderErrorWriter func(w http.ResponseWriter, status int, message string) + +// proxyHTTPRequestToLeader forwards r to the current Raft leader's HTTP +// adapter endpoint when this node is a follower. 
It returns true when the +// request was handled (either proxied or an error body was written) and +// false when the caller should serve it locally (no leader map configured, +// no coordinator, or this node is the leader). +// +// Error paths: +//  1. no known Raft leader → 503 via errWriter("no raft leader currently available") +//  2. leader id missing from leaderMap → 503 via errWriter("leader address not found") +//  3. reverse-proxy dial/copy failure → 503 via errWriter("leader proxy error: <err>") +// +// leaderMap keys are Raft addresses; values are the matching adapter HTTP +// addresses exported on that same node. +func proxyHTTPRequestToLeader( + coordinator kv.Coordinator, + leaderMap map[string]string, + errWriter httpLeaderErrorWriter, + w http.ResponseWriter, + r *http.Request, +) bool { + if len(leaderMap) == 0 || coordinator == nil { + return false + } + if coordinator.IsLeader() { + return false + } + // Follower ingress: forward to the leader so reads and writes both see a + // single serialization point. Serving locally on a follower would expose + // G2-item-realtime stale reads. 
+ leader := coordinator.RaftLeader() + if leader == "" { + errWriter(w, http.StatusServiceUnavailable, "no raft leader currently available") + return true + } + targetAddr, ok := leaderMap[leader] + if !ok || strings.TrimSpace(targetAddr) == "" { + errWriter(w, http.StatusServiceUnavailable, "leader address not found") + return true + } + target := &url.URL{Scheme: "http", Host: targetAddr} + proxy := httputil.NewSingleHostReverseProxy(target) + proxy.ErrorHandler = func(rw http.ResponseWriter, _ *http.Request, err error) { + errWriter(rw, http.StatusServiceUnavailable, "leader proxy error: "+err.Error()) + } + proxy.ServeHTTP(w, r) + return true +} diff --git a/adapter/s3_auth.go b/adapter/s3_auth.go index 1df6070de..15a61bbd1 100644 --- a/adapter/s3_auth.go +++ b/adapter/s3_auth.go @@ -1,7 +1,6 @@ package adapter import ( - "context" "crypto/subtle" "net/http" "net/url" @@ -12,11 +11,12 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" "github.com/bootjp/elastickv/kv" - "github.com/cockroachdb/errors" ) const ( - s3SigV4Algorithm = "AWS4-HMAC-SHA256" + // s3SigV4Algorithm is an alias of sigv4Algorithm kept for call-site + // readability in the S3 adapter. + s3SigV4Algorithm = sigv4Algorithm // s3UnsignedPayload / s3Streaming* are sentinel values that AWS SDKs may // place in X-Amz-Content-Sha256. 
None of them are a literal SHA-256 hash // of the body, so the PUT pipeline must skip hash validation when it sees @@ -27,8 +27,8 @@ const ( s3StreamingSignedPayload = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD" s3StreamingSignedPayloadTrailer = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD-TRAILER" s3EmptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" - s3DateHeaderFormat = "20060102T150405Z" - s3RequestTimeMaxSkew = 15 * time.Minute + s3DateHeaderFormat = sigv4DateHeaderFormat + s3RequestTimeMaxSkew = sigv4RequestTimeMaxSkew ) // isS3PayloadMarker reports whether the given X-Amz-Content-Sha256 value is a @@ -79,14 +79,6 @@ type s3AuthError struct { Message string } -type s3AuthorizationHeader struct { - Algorithm string - AccessKeyID string - Date string - Region string - Service string -} - func WithS3Region(region string) S3ServerOption { return func(server *S3Server) { if server == nil || strings.TrimSpace(region) == "" { @@ -151,7 +143,7 @@ func (s *S3Server) authorizeRequest(r *http.Request) *s3AuthError { } } - parsed, err := parseS3AuthorizationHeader(authHeader) + parsed, err := parseSigV4AuthorizationHeader(authHeader) if err != nil { return &s3AuthError{ Status: http.StatusForbidden, @@ -224,7 +216,7 @@ func (s *S3Server) authorizeRequest(r *http.Request) *s3AuthError { } } - expectedAuth, err := buildS3AuthorizationHeader(r, parsed.AccessKeyID, secretAccessKey, s.effectiveRegion(), signingTime, payloadHash) + expectedAuth, err := buildSigV4AuthorizationHeader(r, parsed.AccessKeyID, secretAccessKey, "s3", s.effectiveRegion(), signingTime, payloadHash) if err != nil { return &s3AuthError{ Status: http.StatusForbidden, @@ -235,8 +227,8 @@ func (s *S3Server) authorizeRequest(r *http.Request) *s3AuthError { // Compare only the Signature component to avoid false rejections caused by // equivalent Authorization headers that differ in whitespace or parameter // ordering (but carry the same cryptographic signature). 
- gotSig := extractS3Signature(authHeader) - expectedSig := extractS3Signature(expectedAuth) + gotSig := extractSigV4Signature(authHeader) + expectedSig := extractSigV4Signature(expectedAuth) if gotSig == "" || subtle.ConstantTimeCompare([]byte(gotSig), []byte(expectedSig)) != 1 { return &s3AuthError{ Status: http.StatusForbidden, @@ -247,65 +239,6 @@ func (s *S3Server) authorizeRequest(r *http.Request) *s3AuthError { return nil } -func buildS3AuthorizationHeader(r *http.Request, accessKeyID string, secretAccessKey string, region string, signingTime time.Time, payloadHash string) (string, error) { - if r == nil { - return "", errors.New("request is required") - } - clone := r.Clone(context.Background()) - clone.Host = r.Host - clone.Header = clone.Header.Clone() - clone.Header.Del("Authorization") - - signer := v4.NewSigner(func(opts *v4.SignerOptions) { - opts.DisableURIPathEscaping = true - }) - creds := aws.Credentials{ - AccessKeyID: accessKeyID, - SecretAccessKey: secretAccessKey, - Source: "elastickv-s3", - } - if err := signer.SignHTTP(context.Background(), creds, clone, payloadHash, "s3", region, signingTime.UTC()); err != nil { - return "", errors.WithStack(err) - } - return strings.TrimSpace(clone.Header.Get("Authorization")), nil -} - -//nolint:cyclop // AWS authorization parsing is branchy because malformed scopes must be rejected precisely. 
-func parseS3AuthorizationHeader(raw string) (s3AuthorizationHeader, error) { - raw = strings.TrimSpace(raw) - if raw == "" { - return s3AuthorizationHeader{}, errors.New("authorization header is required") - } - algorithm, rest, ok := strings.Cut(raw, " ") - if !ok { - return s3AuthorizationHeader{}, errors.New("authorization header is malformed") - } - out := s3AuthorizationHeader{Algorithm: strings.TrimSpace(algorithm)} - params := strings.Split(rest, ",") - for _, param := range params { - key, value, ok := strings.Cut(strings.TrimSpace(param), "=") - if !ok { - continue - } - if key != "Credential" { - continue - } - scope := strings.Split(value, "/") - if len(scope) != 5 || scope[4] != "aws4_request" { - return s3AuthorizationHeader{}, errors.New("credential scope is malformed") - } - out.AccessKeyID = scope[0] - out.Date = scope[1] - out.Region = scope[2] - out.Service = scope[3] - break - } - if out.AccessKeyID == "" || out.Date == "" || out.Region == "" || out.Service == "" { - return s3AuthorizationHeader{}, errors.New("credential scope is required") - } - return out, nil -} - func normalizeS3PayloadHash(raw string) string { return strings.TrimSpace(raw) } diff --git a/adapter/s3_test.go b/adapter/s3_test.go index f3f27d2be..457a540bd 100644 --- a/adapter/s3_test.go +++ b/adapter/s3_test.go @@ -874,7 +874,7 @@ func newSignedS3Request( signingTime, ) require.NoError(t, err) - expectedAuth, err := buildS3AuthorizationHeader(req, testS3AccessKey, testS3SecretKey, testS3Region, signingTime, payloadHash) + expectedAuth, err := buildSigV4AuthorizationHeader(req, testS3AccessKey, testS3SecretKey, "s3", testS3Region, signingTime, payloadHash) require.NoError(t, err) require.Equal(t, strings.TrimSpace(req.Header.Get("Authorization")), expectedAuth) return req diff --git a/adapter/sigv4.go b/adapter/sigv4.go new file mode 100644 index 000000000..1f26ff103 --- /dev/null +++ b/adapter/sigv4.go @@ -0,0 +1,203 @@ +package adapter + +import ( + "context" + "net/http" + 
"strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/cockroachdb/errors" +) + +// Service-agnostic SigV4 primitives shared by the S3 and SQS adapters. +// The per-adapter files wrap these with their own error ordering and +// service-specific rules (payload hashing for S3, JSON body hashing for +// SQS, etc.). + +const ( + sigv4Algorithm = "AWS4-HMAC-SHA256" + sigv4DateHeaderFormat = "20060102T150405Z" + sigv4ScopeDateFormat = "20060102" + sigv4RequestTimeMaxSkew = 15 * time.Minute + sigv4ScopeTerminator = "aws4_request" + sigv4ScopeParts = 5 +) + +// sigv4AuthorizationHeader captures the Credential scope fields extracted +// from a SigV4 Authorization header ("AWS4-HMAC-SHA256 Credential=<key>/ +// <date>/<region>/<service>/aws4_request, SignedHeaders=..., Signature=..."). +type sigv4AuthorizationHeader struct { + Algorithm string + AccessKeyID string + Date string + Region string + Service string +} + +// parseSigV4AuthorizationHeader decodes the Credential scope from an +// Authorization header. Other fields (SignedHeaders, Signature) are not +// parsed here because the adapter-level verifier rebuilds and compares the +// full signature through the AWS SDK v4 signer. 
+func parseSigV4AuthorizationHeader(raw string) (sigv4AuthorizationHeader, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return sigv4AuthorizationHeader{}, errors.New("authorization header is required") + } + algorithm, rest, ok := strings.Cut(raw, " ") + if !ok { + return sigv4AuthorizationHeader{}, errors.New("authorization header is malformed") + } + credentialValue, ok := findSigV4Param(rest, "Credential") + if !ok { + return sigv4AuthorizationHeader{}, errors.New("credential scope is required") + } + out, err := parseSigV4CredentialScope(credentialValue) + if err != nil { + return sigv4AuthorizationHeader{}, err + } + out.Algorithm = strings.TrimSpace(algorithm) + return out, nil +} + +// findSigV4Param returns the value of the named parameter from the +// "k=v, k=v, ..." tail of an Authorization header. +func findSigV4Param(params, name string) (string, bool) { + for _, param := range strings.Split(params, ",") { + key, value, ok := strings.Cut(strings.TrimSpace(param), "=") + if ok && key == name { + return strings.TrimSpace(value), true + } + } + return "", false +} + +// parseSigV4CredentialScope validates an +// <access-key-id>/<date>/<region>/<service>/aws4_request scope string. +func parseSigV4CredentialScope(value string) (sigv4AuthorizationHeader, error) { + scope := strings.Split(value, "/") + if len(scope) != sigv4ScopeParts || scope[4] != sigv4ScopeTerminator { + return sigv4AuthorizationHeader{}, errors.New("credential scope is malformed") + } + out := sigv4AuthorizationHeader{ + AccessKeyID: scope[0], + Date: scope[1], + Region: scope[2], + Service: scope[3], + } + if out.AccessKeyID == "" || out.Date == "" || out.Region == "" || out.Service == "" { + return sigv4AuthorizationHeader{}, errors.New("credential scope is required") + } + return out, nil +} + +// buildSigV4AuthorizationHeader re-signs r with the given credentials and +// returns the Authorization header the server expects. Used by adapter +// verifiers to compare against the client-supplied Authorization. 
+func buildSigV4AuthorizationHeader( + r *http.Request, + accessKeyID, secretAccessKey, service, region string, + signingTime time.Time, + payloadHash string, +) (string, error) { + if r == nil { + return "", errors.New("request is required") + } + clone := r.Clone(context.Background()) + clone.Host = r.Host + clone.Header = clone.Header.Clone() + clone.Header.Del("Authorization") + + signer := v4.NewSigner(func(opts *v4.SignerOptions) { + opts.DisableURIPathEscaping = true + }) + creds := aws.Credentials{ + AccessKeyID: accessKeyID, + SecretAccessKey: secretAccessKey, + Source: "elastickv-" + service, + } + if err := signer.SignHTTP(context.Background(), creds, clone, payloadHash, service, region, signingTime.UTC()); err != nil { + return "", errors.WithStack(err) + } + return strings.TrimSpace(clone.Header.Get("Authorization")), nil +} + +// extractSigV4Signature returns the hex Signature= value from an +// Authorization header, or "" if missing. +func extractSigV4Signature(auth string) string { + return extractSigV4AuthField(auth, "Signature") +} + +// extractSigV4SignedHeaders returns the semicolon-separated SignedHeaders +// list as a []string, or nil if missing. Header names are lowercased. 
+func extractSigV4SignedHeaders(auth string) []string { + raw := extractSigV4AuthField(auth, "SignedHeaders") + if raw == "" { + return nil + } + parts := strings.Split(raw, ";") + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + out = append(out, strings.ToLower(p)) + } + return out +} + +func extractSigV4AuthField(auth, field string) string { + _, params, ok := strings.Cut(auth, " ") + if !ok { + return "" + } + v, _ := findSigV4Param(params, field) + return v +} + +// buildSigV4AuthorizationHeaderRestricted is a variant of +// buildSigV4AuthorizationHeader that strips any header not listed in +// signedHeaders before re-signing, so the server's re-computation matches +// the client's canonical request even when Go's transport added headers +// (Accept-Encoding, etc.) after the original signature. +func buildSigV4AuthorizationHeaderRestricted( + r *http.Request, + accessKeyID, secretAccessKey, service, region string, + signingTime time.Time, + payloadHash string, + signedHeaders []string, +) (string, error) { + if r == nil { + return "", errors.New("request is required") + } + clone := r.Clone(context.Background()) + clone.Host = r.Host + clone.Header = clone.Header.Clone() + clone.Header.Del("Authorization") + if len(signedHeaders) > 0 { + allowed := make(map[string]struct{}, len(signedHeaders)) + for _, h := range signedHeaders { + allowed[strings.ToLower(strings.TrimSpace(h))] = struct{}{} + } + for name := range clone.Header { + if _, ok := allowed[strings.ToLower(name)]; !ok { + clone.Header.Del(name) + } + } + } + + signer := v4.NewSigner(func(opts *v4.SignerOptions) { + opts.DisableURIPathEscaping = true + }) + creds := aws.Credentials{ + AccessKeyID: accessKeyID, + SecretAccessKey: secretAccessKey, + Source: "elastickv-" + service, + } + if err := signer.SignHTTP(context.Background(), creds, clone, payloadHash, service, region, signingTime.UTC()); err != nil { + return "", 
errors.WithStack(err) + } + return strings.TrimSpace(clone.Header.Get("Authorization")), nil +} diff --git a/adapter/sqs.go b/adapter/sqs.go new file mode 100644 index 000000000..1a3a06fdf --- /dev/null +++ b/adapter/sqs.go @@ -0,0 +1,264 @@ +package adapter + +import ( + "context" + "io" + "net" + "net/http" + "time" + + "github.com/bootjp/elastickv/kv" + "github.com/bootjp/elastickv/store" + "github.com/cockroachdb/errors" + json "github.com/goccy/go-json" +) + +// SQS target prefix for the JSON-1.0 protocol. Every supported operation is +// dispatched by the X-Amz-Target header, mirroring the DynamoDB adapter. +const sqsTargetPrefix = "AmazonSQS." + +const ( + sqsCreateQueueTarget = sqsTargetPrefix + "CreateQueue" + sqsDeleteQueueTarget = sqsTargetPrefix + "DeleteQueue" + sqsListQueuesTarget = sqsTargetPrefix + "ListQueues" + sqsGetQueueUrlTarget = sqsTargetPrefix + "GetQueueUrl" + sqsGetQueueAttributesTarget = sqsTargetPrefix + "GetQueueAttributes" + sqsSetQueueAttributesTarget = sqsTargetPrefix + "SetQueueAttributes" + sqsPurgeQueueTarget = sqsTargetPrefix + "PurgeQueue" + sqsSendMessageTarget = sqsTargetPrefix + "SendMessage" + sqsSendMessageBatchTarget = sqsTargetPrefix + "SendMessageBatch" + sqsReceiveMessageTarget = sqsTargetPrefix + "ReceiveMessage" + sqsDeleteMessageTarget = sqsTargetPrefix + "DeleteMessage" + sqsDeleteMessageBatchTarget = sqsTargetPrefix + "DeleteMessageBatch" + sqsChangeMessageVisibilityTarget = sqsTargetPrefix + "ChangeMessageVisibility" + sqsChangeMessageVisibilityBatchTgt = sqsTargetPrefix + "ChangeMessageVisibilityBatch" + sqsTagQueueTarget = sqsTargetPrefix + "TagQueue" + sqsUntagQueueTarget = sqsTargetPrefix + "UntagQueue" + sqsListQueueTagsTarget = sqsTargetPrefix + "ListQueueTags" +) + +const ( + sqsHealthPath = "/sqs_health" + sqsLeaderHealthPath = "/sqs_leader_health" +) + +const ( + sqsHealthMaxRequestBodyBytes = 1024 + sqsMaxRequestBodyBytes = 1 << 20 + sqsContentTypeJSON = "application/x-amz-json-1.0" +) + +// AWS SQS 
error codes used by the JSON protocol. The canonical list is on the +// "Common Errors" page of the SQS API reference. +const ( + sqsErrInvalidAction = "InvalidAction" + sqsErrNotImplemented = "NotImplemented" + sqsErrInternalFailure = "InternalFailure" + sqsErrServiceUnavailable = "ServiceUnavailable" + sqsErrMalformedRequest = "MalformedQueryString" +) + +type SQSServerOption func(*SQSServer) + +type SQSServer struct { + listen net.Listener + store store.MVCCStore + coordinator kv.Coordinator + httpServer *http.Server + targetHandlers map[string]func(http.ResponseWriter, *http.Request) + leaderSQS map[string]string + region string + staticCreds map[string]string +} + +// WithSQSLeaderMap configures the Raft-address-to-SQS-address mapping used to +// forward requests from followers to the current leader. Format mirrors +// WithDynamoDBLeaderMap / WithS3LeaderMap. +func WithSQSLeaderMap(m map[string]string) SQSServerOption { + return func(s *SQSServer) { + s.leaderSQS = make(map[string]string, len(m)) + for k, v := range m { + s.leaderSQS[k] = v + } + } +} + +func NewSQSServer(listen net.Listener, st store.MVCCStore, coordinate kv.Coordinator, opts ...SQSServerOption) *SQSServer { + s := &SQSServer{ + listen: listen, + store: st, + coordinator: coordinate, + } + s.targetHandlers = map[string]func(http.ResponseWriter, *http.Request){ + sqsCreateQueueTarget: s.createQueue, + sqsDeleteQueueTarget: s.deleteQueue, + sqsListQueuesTarget: s.listQueues, + sqsGetQueueUrlTarget: s.getQueueUrl, + sqsGetQueueAttributesTarget: s.getQueueAttributes, + sqsSetQueueAttributesTarget: s.setQueueAttributes, + sqsPurgeQueueTarget: s.notImplemented("PurgeQueue"), + sqsSendMessageTarget: s.sendMessage, + sqsSendMessageBatchTarget: s.notImplemented("SendMessageBatch"), + sqsReceiveMessageTarget: s.receiveMessage, + sqsDeleteMessageTarget: s.deleteMessage, + sqsDeleteMessageBatchTarget: s.notImplemented("DeleteMessageBatch"), + sqsChangeMessageVisibilityTarget: s.changeMessageVisibility, + 
sqsChangeMessageVisibilityBatchTgt: s.notImplemented("ChangeMessageVisibilityBatch"), + sqsTagQueueTarget: s.notImplemented("TagQueue"), + sqsUntagQueueTarget: s.notImplemented("UntagQueue"), + sqsListQueueTagsTarget: s.notImplemented("ListQueueTags"), + } + mux := http.NewServeMux() + mux.HandleFunc("/", s.handle) + s.httpServer = &http.Server{Handler: mux, ReadHeaderTimeout: time.Second} + for _, opt := range opts { + if opt != nil { + opt(s) + } + } + return s +} + +func (s *SQSServer) Run() error { + if err := s.httpServer.Serve(s.listen); err != nil && !errors.Is(err, http.ErrServerClosed) { + return errors.WithStack(err) + } + return nil +} + +func (s *SQSServer) Stop() { + if s.httpServer != nil { + _ = s.httpServer.Shutdown(context.Background()) + } +} + +func (s *SQSServer) handle(w http.ResponseWriter, r *http.Request) { + if s.serveHealthz(w, r) { + return + } + if s.proxyToLeader(w, r) { + return + } + + if r.Method != http.MethodPost { + w.Header().Set("Allow", http.MethodPost) + writeSQSError(w, http.StatusMethodNotAllowed, sqsErrMalformedRequest, "SQS JSON protocol requires POST") + return + } + + if authErr := s.authorizeSQSRequest(r); authErr != nil { + writeSQSError(w, authErr.Status, authErr.Code, authErr.Message) + return + } + + target := r.Header.Get("X-Amz-Target") + handler, ok := s.targetHandlers[target] + if !ok { + writeSQSError(w, http.StatusBadRequest, sqsErrInvalidAction, "unsupported SQS target: "+target) + return + } + handler(w, r) +} + +func (s *SQSServer) serveHealthz(w http.ResponseWriter, r *http.Request) bool { + if r == nil || r.URL == nil { + return false + } + switch r.URL.Path { + case sqsHealthPath: + if r.Body != nil { + r.Body = http.MaxBytesReader(w, r.Body, sqsHealthMaxRequestBodyBytes) + } + serveSQSHealthz(w, r) + return true + case sqsLeaderHealthPath: + if r.Body != nil { + r.Body = http.MaxBytesReader(w, r.Body, sqsHealthMaxRequestBodyBytes) + } + s.serveSQSLeaderHealthz(w, r) + return true + default: + return 
false + } +} + +func serveSQSHealthz(w http.ResponseWriter, r *http.Request) { + if !writeSQSHealthMethod(w, r) { + return + } + writeSQSHealthBody(w, r, http.StatusOK, "ok\n") +} + +func (s *SQSServer) serveSQSLeaderHealthz(w http.ResponseWriter, r *http.Request) { + if !writeSQSHealthMethod(w, r) { + return + } + if isVerifiedSQSLeader(s.coordinator) { + writeSQSHealthBody(w, r, http.StatusOK, "ok\n") + return + } + writeSQSHealthBody(w, r, http.StatusServiceUnavailable, "not leader\n") +} + +func isVerifiedSQSLeader(coordinator kv.Coordinator) bool { + if coordinator == nil || !coordinator.IsLeader() { + return false + } + return coordinator.VerifyLeader() == nil +} + +func writeSQSHealthMethod(w http.ResponseWriter, r *http.Request) bool { + switch r.Method { + case http.MethodGet, http.MethodHead: + return true + default: + w.Header().Set("Allow", "GET, HEAD") + http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) + return false + } +} + +func writeSQSHealthBody(w http.ResponseWriter, r *http.Request, statusCode int, body string) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(statusCode) + if r.Method == http.MethodHead { + return + } + _, _ = io.WriteString(w, body) +} + +// proxyToLeader forwards the HTTP request to the current SQS leader when this +// node is not the Raft leader. Returns true if the request was handled. +func (s *SQSServer) proxyToLeader(w http.ResponseWriter, r *http.Request) bool { + return proxyHTTPRequestToLeader(s.coordinator, s.leaderSQS, sqsLeaderProxyErrorWriter, w, r) +} + +func sqsLeaderProxyErrorWriter(w http.ResponseWriter, status int, message string) { + writeSQSError(w, status, sqsErrServiceUnavailable, message) +} + +// notImplemented returns a handler that responds with a JSON-protocol +// NotImplemented error so clients get a clean signal while the real handlers +// are still being built out. 
+func (s *SQSServer) notImplemented(op string) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, _ *http.Request) { + writeSQSError(w, http.StatusNotImplemented, sqsErrNotImplemented, op+" is not implemented yet") + } +} + +// writeSQSError emits an SQS JSON-protocol error envelope. AWS returns: +// +// { "__type": "<code>", "message": "<message>" } +// +// with Content-Type application/x-amz-json-1.0 and the x-amzn-ErrorType header +// set to the code. SDKs key off x-amzn-ErrorType first, the body second. +func writeSQSError(w http.ResponseWriter, status int, code string, message string) { + resp := map[string]string{"message": message} + if code != "" { + resp["__type"] = code + w.Header().Set("x-amzn-ErrorType", code) + } + w.Header().Set("Content-Type", sqsContentTypeJSON) + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(resp) +} diff --git a/adapter/sqs_auth.go b/adapter/sqs_auth.go new file mode 100644 index 000000000..c8f49ea77 --- /dev/null +++ b/adapter/sqs_auth.go @@ -0,0 +1,178 @@ +package adapter + +import ( + "bytes" + "crypto/sha256" + "crypto/subtle" + "encoding/hex" + "io" + "net/http" + "strings" + "time" +) + +// Service name used in the SigV4 credential scope. +const sqsSigV4Service = "sqs" + +// sqsAuthError is the SQS-flavored counterpart to s3AuthError. The outer +// handler turns these into the AWS JSON-1.0 error envelope via +// writeSQSError. +type sqsAuthError struct { + Status int + Code string + Message string +} + +// WithSQSRegion configures the signing region the adapter expects inside +// the Credential scope. Empty values retain the previous setting. +func WithSQSRegion(region string) SQSServerOption { + return func(s *SQSServer) { + if s == nil || strings.TrimSpace(region) == "" { + return + } + s.region = strings.TrimSpace(region) + } +} + +// WithSQSStaticCredentials supplies the access-key → secret map the +// adapter will accept. 
Passing an empty map disables authorization +// entirely (open endpoint), matching the S3 adapter's behavior for +// unit-test friendliness. +func WithSQSStaticCredentials(creds map[string]string) SQSServerOption { + return func(s *SQSServer) { + if s == nil || len(creds) == 0 { + return + } + s.staticCreds = make(map[string]string, len(creds)) + for accessKeyID, secretAccessKey := range creds { + accessKeyID = strings.TrimSpace(accessKeyID) + secretAccessKey = strings.TrimSpace(secretAccessKey) + if accessKeyID == "" || secretAccessKey == "" { + continue + } + s.staticCreds[accessKeyID] = secretAccessKey + } + if len(s.staticCreds) == 0 { + s.staticCreds = nil + } + } +} + +func (s *SQSServer) effectiveRegion() string { + if s == nil || strings.TrimSpace(s.region) == "" { + return "us-east-1" + } + return s.region +} + +// authorizeSQSRequest verifies SigV4 on an SQS JSON-protocol request. The +// flow mirrors S3's header-based path (no presigned URLs or streaming +// payloads) but restricts the credential scope service to "sqs". +// +// Returns nil when authorization is either unconfigured (open endpoint) or +// the signature matches. Returns *sqsAuthError otherwise, suitable for +// writeSQSError. +// +// The function must consume and replace r.Body so that downstream handlers +// can still parse the JSON payload. 
+func (s *SQSServer) authorizeSQSRequest(r *http.Request) *sqsAuthError { + if s == nil || r == nil || len(s.staticCreds) == 0 { + return nil + } + authHeader := strings.TrimSpace(r.Header.Get("Authorization")) + if authHeader == "" { + return &sqsAuthError{Status: http.StatusForbidden, Code: "MissingAuthenticationToken", Message: "missing Authorization header"} + } + parsed, authErr := s.validateSQSAuthScope(authHeader) + if authErr != nil { + return authErr + } + secretAccessKey, ok := s.staticCreds[parsed.AccessKeyID] + if !ok { + return &sqsAuthError{Status: http.StatusForbidden, Code: "InvalidClientTokenId", Message: "unknown access key"} + } + signingTime, authErr := validateSQSAuthDate(r, parsed) + if authErr != nil { + return authErr + } + payloadHash, authErr := drainAndHashSQSBody(r) + if authErr != nil { + return authErr + } + return verifySQSSignatureMatches(r, authHeader, parsed.AccessKeyID, secretAccessKey, s.effectiveRegion(), signingTime, payloadHash) +} + +// validateSQSAuthScope parses the Credential scope and rejects scopes +// whose algorithm/service/region do not match this adapter. 
+func (s *SQSServer) validateSQSAuthScope(authHeader string) (sigv4AuthorizationHeader, *sqsAuthError) { + parsed, err := parseSigV4AuthorizationHeader(authHeader) + if err != nil { + return sigv4AuthorizationHeader{}, &sqsAuthError{Status: http.StatusForbidden, Code: "IncompleteSignature", Message: "invalid Authorization header"} + } + if parsed.Algorithm != sigv4Algorithm { + return sigv4AuthorizationHeader{}, &sqsAuthError{Status: http.StatusForbidden, Code: "IncompleteSignature", Message: "unsupported signature algorithm"} + } + if parsed.Service != sqsSigV4Service { + return sigv4AuthorizationHeader{}, &sqsAuthError{Status: http.StatusForbidden, Code: "SignatureDoesNotMatch", Message: "credential scope service must be sqs"} + } + if parsed.Region != s.effectiveRegion() { + return sigv4AuthorizationHeader{}, &sqsAuthError{Status: http.StatusForbidden, Code: "SignatureDoesNotMatch", Message: "credential scope region does not match server region"} + } + return parsed, nil +} + +// validateSQSAuthDate parses X-Amz-Date and verifies it matches the +// Credential scope date and is within the allowed clock-skew window. 
+func validateSQSAuthDate(r *http.Request, parsed sigv4AuthorizationHeader) (time.Time, *sqsAuthError) { + amzDate := strings.TrimSpace(r.Header.Get("X-Amz-Date")) + if amzDate == "" { + return time.Time{}, &sqsAuthError{Status: http.StatusForbidden, Code: "MissingAuthenticationToken", Message: "missing x-amz-date header"} + } + signingTime, err := time.Parse(sigv4DateHeaderFormat, amzDate) + if err != nil { + return time.Time{}, &sqsAuthError{Status: http.StatusForbidden, Code: "IncompleteSignature", Message: "invalid x-amz-date header"} + } + if parsed.Date != signingTime.UTC().Format(sigv4ScopeDateFormat) { + return time.Time{}, &sqsAuthError{Status: http.StatusForbidden, Code: "SignatureDoesNotMatch", Message: "credential scope date does not match x-amz-date"} + } + skew := time.Now().UTC().Sub(signingTime.UTC()) + if skew < 0 { + skew = -skew + } + if skew > sigv4RequestTimeMaxSkew { + return time.Time{}, &sqsAuthError{Status: http.StatusForbidden, Code: "RequestTimeTooSkewed", Message: "The difference between the request time and the server's time is too large"} + } + return signingTime, nil +} + +// drainAndHashSQSBody reads the request body so the signer reproduces the +// client's payload hash, then replaces r.Body so handler code can re-read +// it afterwards. +func drainAndHashSQSBody(r *http.Request) (string, *sqsAuthError) { + body, err := io.ReadAll(http.MaxBytesReader(nil, r.Body, sqsMaxRequestBodyBytes)) + if err != nil { + return "", &sqsAuthError{Status: http.StatusForbidden, Code: "IncompleteSignature", Message: "failed to read request body for signature verification"} + } + r.Body = io.NopCloser(bytes.NewReader(body)) + sum := sha256.Sum256(body) + return hex.EncodeToString(sum[:]), nil +} + +// verifySQSSignatureMatches rebuilds the expected Authorization header and +// compares its hex signature to the one the client sent. +// +// The restricted builder is used so Go's transport-added headers +// (Accept-Encoding etc.) 
// do not leak into the canonical request.
func verifySQSSignatureMatches(r *http.Request, authHeader, accessKeyID, secretAccessKey, region string, signingTime time.Time, payloadHash string) *sqsAuthError {
	signedHeaders := extractSigV4SignedHeaders(authHeader)
	// Re-sign the request server-side with the known secret; if the client
	// held the same secret, the two Authorization headers must agree.
	expectedAuth, err := buildSigV4AuthorizationHeaderRestricted(r, accessKeyID, secretAccessKey, sqsSigV4Service, region, signingTime, payloadHash, signedHeaders)
	if err != nil {
		return &sqsAuthError{Status: http.StatusForbidden, Code: "SignatureDoesNotMatch", Message: "failed to verify request signature"}
	}
	gotSig := extractSigV4Signature(authHeader)
	expectedSig := extractSigV4Signature(expectedAuth)
	// Constant-time comparison so the check does not leak how much of the
	// signature matched via response timing.
	if gotSig == "" || subtle.ConstantTimeCompare([]byte(gotSig), []byte(expectedSig)) != 1 {
		return &sqsAuthError{Status: http.StatusForbidden, Code: "SignatureDoesNotMatch", Message: "request signature does not match"}
	}
	return nil
}
diff --git a/adapter/sqs_auth_test.go b/adapter/sqs_auth_test.go
new file mode 100644
index 000000000..a8f815513
--- /dev/null
+++ b/adapter/sqs_auth_test.go
@@ -0,0 +1,230 @@
package adapter

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"io"
	"net"
	"net/http"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
	json "github.com/goccy/go-json"
)

const (
	testSQSAccessKey = "AKIATESTSQSAAAAAAAAA"
	testSQSSecretKey = "test-secret-key/xxxxxxxxxxxxxxxxxxxxxxxx"
	testSQSRegion    = "us-west-2"
)

// startAuthedSQSServer brings up an SQSServer with static credentials so
// we can exercise the full SigV4 verifier. No coordinator is wired in, so
// the server only executes the auth + health paths (handler requests land
// on unknown targets, returning 400 InvalidAction after auth passes).
func startAuthedSQSServer(t *testing.T) string {
	t.Helper()
	lc := &net.ListenConfig{}
	listener, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	srv := NewSQSServer(listener, nil, nil,
		WithSQSRegion(testSQSRegion),
		WithSQSStaticCredentials(map[string]string{testSQSAccessKey: testSQSSecretKey}),
	)
	go func() {
		_ = srv.Run()
	}()
	t.Cleanup(func() {
		srv.Stop()
	})
	return "http://" + listener.Addr().String()
}

// signSQSRequest signs an SQS request with the AWS SDK v4 signer against
// the supplied credentials, exactly as a real AWS SDK client would.
func signSQSRequest(t *testing.T, base, target string, body []byte, signingTime time.Time, accessKey, secretKey string) *http.Request {
	region := testSQSRegion
	t.Helper()
	req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, base+"/", bytes.NewReader(body))
	if err != nil {
		t.Fatalf("request: %v", err)
	}
	req.Header.Set("Content-Type", sqsContentTypeJSON)
	req.Header.Set("X-Amz-Target", target)
	sum := sha256.Sum256(body)
	payloadHash := hex.EncodeToString(sum[:])

	// DisableURIPathEscaping matches the server-side canonicalization for
	// the "/" request path.
	signer := v4.NewSigner(func(opts *v4.SignerOptions) {
		opts.DisableURIPathEscaping = true
	})
	if err := signer.SignHTTP(context.Background(), aws.Credentials{
		AccessKeyID:     accessKey,
		SecretAccessKey: secretKey,
		Source:          "test",
	}, req, payloadHash, sqsSigV4Service, region, signingTime.UTC()); err != nil {
		t.Fatalf("sign: %v", err)
	}
	return req
}

// doReq executes req and returns the status code plus the JSON error/body
// envelope decoded into a flat string map (empty map for empty bodies).
func doReq(t *testing.T, req *http.Request) (int, map[string]string) {
	t.Helper()
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("do: %v", err)
	}
	defer resp.Body.Close()
	raw, _ := io.ReadAll(resp.Body)
	out := map[string]string{}
	if len(bytes.TrimSpace(raw)) > 0 {
		_ = json.Unmarshal(raw, &out)
	}
	return resp.StatusCode, out
}

func TestSQSAuth_ValidSignaturePassesAuth(t *testing.T) {
	t.Parallel()
	base := startAuthedSQSServer(t)
	body := []byte(`{}`)
	req := signSQSRequest(t, base, "AmazonSQS.NoSuchOperation", body, time.Now().UTC(),
		testSQSAccessKey, testSQSSecretKey)

	// Auth should pass; then the target is unknown, so the dispatcher
	// returns InvalidAction — *not* SignatureDoesNotMatch. That precise
	// ordering is what we care about: auth runs first, dispatch second.
	status, body2 := doReq(t, req)
	if status != http.StatusBadRequest {
		t.Fatalf("status: got %d body %v", status, body2)
	}
	if body2["__type"] != sqsErrInvalidAction {
		t.Fatalf("error type: got %q want %q", body2["__type"], sqsErrInvalidAction)
	}
}

func TestSQSAuth_MissingAuthorizationHeaderRejected(t *testing.T) {
	t.Parallel()
	base := startAuthedSQSServer(t)
	req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, base+"/", bytes.NewReader([]byte("{}")))
	if err != nil {
		t.Fatalf("request: %v", err)
	}
	req.Header.Set("Content-Type", sqsContentTypeJSON)
	req.Header.Set("X-Amz-Target", sqsListQueuesTarget)

	status, body := doReq(t, req)
	if status != http.StatusForbidden {
		t.Fatalf("status: got %d body %v", status, body)
	}
	if body["__type"] != "MissingAuthenticationToken" {
		t.Fatalf("error type: got %q", body["__type"])
	}
}

func TestSQSAuth_WrongServiceScopeRejected(t *testing.T) {
	t.Parallel()
	base := startAuthedSQSServer(t)
	body := []byte(`{}`)
	req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, base+"/", bytes.NewReader(body))
	if err != nil {
		t.Fatalf("request: %v", err)
	}
	req.Header.Set("Content-Type", sqsContentTypeJSON)
	req.Header.Set("X-Amz-Target", sqsListQueuesTarget)
	sum := sha256.Sum256(body)
	payloadHash := hex.EncodeToString(sum[:])
	signer := v4.NewSigner()
	// Sign with service "s3" instead of "sqs".
	if err := signer.SignHTTP(context.Background(), aws.Credentials{
		AccessKeyID:     testSQSAccessKey,
		SecretAccessKey: testSQSSecretKey,
	}, req, payloadHash, "s3", testSQSRegion, time.Now().UTC()); err != nil {
		t.Fatalf("sign: %v", err)
	}

	status, out := doReq(t, req)
	if status != http.StatusForbidden {
		t.Fatalf("status: got %d body %v", status, out)
	}
	if out["__type"] != "SignatureDoesNotMatch" {
		t.Fatalf("error type: got %q", out["__type"])
	}
}

func TestSQSAuth_UnknownAccessKeyRejected(t *testing.T) {
	t.Parallel()
	base := startAuthedSQSServer(t)
	req := signSQSRequest(t, base, sqsListQueuesTarget, []byte("{}"), time.Now().UTC(),
		"AKIAUNKNOWNEEEEEEEEE", "wrong-secret-keyxxxxxxxxxxxxxxxxxxxxxxxx")

	status, out := doReq(t, req)
	if status != http.StatusForbidden {
		t.Fatalf("status: got %d body %v", status, out)
	}
	if out["__type"] != "InvalidClientTokenId" {
		t.Fatalf("error type: got %q", out["__type"])
	}
}

func TestSQSAuth_TamperedBodyRejected(t *testing.T) {
	t.Parallel()
	base := startAuthedSQSServer(t)
	// Sign the original body, then send a different one — the server must
	// hash the body it actually receives and catch the mismatch.
	req := signSQSRequest(t, base, sqsListQueuesTarget, []byte(`{"QueueNamePrefix":"a"}`), time.Now().UTC(),
		testSQSAccessKey, testSQSSecretKey)
	req.Body = io.NopCloser(bytes.NewReader([]byte(`{"QueueNamePrefix":"b"}`)))
	req.ContentLength = int64(len(`{"QueueNamePrefix":"b"}`))

	status, out := doReq(t, req)
	if status != http.StatusForbidden {
		t.Fatalf("status: got %d body %v", status, out)
	}
	if out["__type"] != "SignatureDoesNotMatch" {
		t.Fatalf("error type: got %q body %v", out["__type"], out)
	}
}

func TestSQSAuth_ClockSkewRejected(t *testing.T) {
	t.Parallel()
	base := startAuthedSQSServer(t)
	// 30 minutes is beyond the allowed clock-skew window.
	old := time.Now().UTC().Add(-30 * time.Minute)
	req := signSQSRequest(t, base, sqsListQueuesTarget, []byte("{}"), old,
		testSQSAccessKey, testSQSSecretKey)

	status, out := doReq(t, req)
	if status != http.StatusForbidden {
		t.Fatalf("status: got %d body %v", status, out)
	}
	if out["__type"] != "RequestTimeTooSkewed" {
		t.Fatalf("error type: got %q body %v", out["__type"], out)
	}
}

func TestSQSAuth_NoCredentialsMeansOpenEndpoint(t *testing.T) {
	t.Parallel()
	// No creds configured → scaffold test server already exercises this
	// path (no Authorization header, no 403). This test explicitly
	// double-checks: the dispatch runs even without auth, landing on
	// InvalidAction for an unknown target.
	base := startTestSQSServer(t)
	req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, base+"/", bytes.NewReader([]byte("{}")))
	if err != nil {
		t.Fatalf("request: %v", err)
	}
	req.Header.Set("Content-Type", sqsContentTypeJSON)
	req.Header.Set("X-Amz-Target", "AmazonSQS.NoSuchOperation")

	status, out := doReq(t, req)
	if status != http.StatusBadRequest {
		t.Fatalf("status: got %d body %v", status, out)
	}
	if out["__type"] != sqsErrInvalidAction {
		t.Fatalf("error type: got %q", out["__type"])
	}
}
diff --git a/adapter/sqs_catalog.go b/adapter/sqs_catalog.go
new file mode 100644
index 000000000..9abb74f1b
--- /dev/null
+++ b/adapter/sqs_catalog.go
@@ -0,0 +1,855 @@
package adapter

import (
	"bytes"
	"context"
	"io"
	"net/http"
	"net/url"
	"regexp"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/bootjp/elastickv/kv"
	"github.com/bootjp/elastickv/store"
	"github.com/cockroachdb/errors"
	json "github.com/goccy/go-json"
)

// AWS SQS defaults, reproduced from the public API reference so clients that
// send an empty Attributes map get standard behavior.
const (
	sqsDefaultVisibilityTimeoutSeconds  = 30
	sqsDefaultRetentionSeconds          = 345600 // 4 days
	sqsDefaultDelaySeconds              = 0
	sqsDefaultReceiveMessageWaitSeconds = 0
	sqsDefaultMaximumMessageSize        = 262144 // 256 KiB

	sqsMaxVisibilityTimeoutSeconds      = 43200   // 12 hours
	sqsMaxRetentionSeconds              = 1209600 // 14 days
	sqsMinRetentionSeconds              = 60
	sqsMaxDelaySeconds                  = 900
	sqsMaxReceiveMessageWaitSeconds     = 20
	sqsMinMaximumMessageSize            = 1024
	sqsMaximumAllowedMaximumMessageSize = 262144
	sqsMaxQueueNameLength               = 80
	sqsFIFOQueueNameSuffix              = ".fifo"
	sqsListQueuesDefaultMaxResults      = 1000
	sqsListQueuesHardMaxResults         = 1000
	sqsQueueScanPageLimit               = 1024
)

// AWS error codes specific to the queue catalog.
const (
	sqsErrValidation            = "InvalidParameterValue"
	sqsErrMissingParameter      = "MissingParameter"
	sqsErrQueueNameExists       = "QueueNameExists"
	sqsErrQueueDoesNotExist     = "AWS.SimpleQueueService.NonExistentQueue"
	sqsErrInvalidAttributeName  = "InvalidAttributeName"
	sqsErrInvalidAttributeValue = "InvalidAttributeValue"
)

// Legal queue names: alphanumerics, hyphens, and underscores, optionally
// carrying the FIFO suffix. Length is additionally capped by
// sqsMaxQueueNameLength in validateQueueName.
var sqsQueueNamePattern = regexp.MustCompile(`^[a-zA-Z0-9_-]{1,80}(\.fifo)?$`)

// sqsQueueMeta is the Go mirror of the queue metadata record persisted at
// !sqs|queue|meta|. Serialized as JSON with a short magic prefix so
// future schema migrations can switch encoding without reading back garbage.
type sqsQueueMeta struct {
	Name string `json:"name"`
	// Generation is bumped on every delete/recreate so stale messages from
	// a prior incarnation of the queue become unreachable.
	Generation                uint64            `json:"generation"`
	CreatedAtHLC              uint64            `json:"created_at_hlc,omitempty"`
	IsFIFO                    bool              `json:"is_fifo,omitempty"`
	ContentBasedDedup         bool              `json:"content_based_dedup,omitempty"`
	VisibilityTimeoutSeconds  int64             `json:"visibility_timeout_seconds"`
	MessageRetentionSeconds   int64             `json:"message_retention_seconds"`
	DelaySeconds              int64             `json:"delay_seconds"`
	ReceiveMessageWaitSeconds int64             `json:"receive_message_wait_seconds"`
	MaximumMessageSize        int64             `json:"maximum_message_size"`
	RedrivePolicy             string            `json:"redrive_policy,omitempty"`
	Tags                      map[string]string `json:"tags,omitempty"`
}

// storedSQSMetaPrefix is the magic header prepended to every serialized
// queue-meta record (0x00 'S' 'Q' version-0x01).
var storedSQSMetaPrefix = []byte{0x00, 'S', 'Q', 0x01}

// encodeSQSQueueMeta serializes m as magic-prefix + JSON.
func encodeSQSQueueMeta(m *sqsQueueMeta) ([]byte, error) {
	body, err := json.Marshal(m)
	if err != nil {
		return nil, errors.WithStack(err)
	}
	out := make([]byte, 0, len(storedSQSMetaPrefix)+len(body))
	out = append(out, storedSQSMetaPrefix...)
	out = append(out, body...)
	return out, nil
}

// decodeSQSQueueMeta is the inverse of encodeSQSQueueMeta; it rejects any
// record that does not carry the expected magic prefix.
func decodeSQSQueueMeta(b []byte) (*sqsQueueMeta, error) {
	if !bytes.HasPrefix(b, storedSQSMetaPrefix) {
		return nil, errors.New("unrecognized sqs meta format")
	}
	var m sqsQueueMeta
	if err := json.Unmarshal(b[len(storedSQSMetaPrefix):], &m); err != nil {
		return nil, errors.WithStack(err)
	}
	return &m, nil
}

// sqsAPIError is a typed error that captures the HTTP status and AWS error
// code so handler helpers can fail deep in the call chain and let the top
// level render a consistent envelope via writeSQSErrorFromErr.
type sqsAPIError struct {
	status    int
	errorType string
	message   string
}

// Error implements the error interface; it falls back to the HTTP status
// text when no message was supplied.
func (e *sqsAPIError) Error() string {
	if e == nil {
		return ""
	}
	if e.message != "" {
		return e.message
	}
	return http.StatusText(e.status)
}

// newSQSAPIError constructs a typed SQS API error.
func newSQSAPIError(status int, errorType string, message string) error {
	return &sqsAPIError{status: status, errorType: errorType, message: message}
}

// writeSQSErrorFromErr renders err: typed sqsAPIErrors keep their status and
// code; anything else becomes a 500 InternalFailure.
func writeSQSErrorFromErr(w http.ResponseWriter, err error) {
	var apiErr *sqsAPIError
	if errors.As(err, &apiErr) {
		writeSQSError(w, apiErr.status, apiErr.errorType, apiErr.message)
		return
	}
	writeSQSError(w, http.StatusInternalServerError, sqsErrInternalFailure, err.Error())
}

// writeSQSJSON writes a 200 response with the JSON-encoded payload.
func writeSQSJSON(w http.ResponseWriter, payload any) {
	w.Header().Set("Content-Type", sqsContentTypeJSON)
	w.WriteHeader(http.StatusOK)
	_ = json.NewEncoder(w).Encode(payload)
}

// ------------------------ input decoding ------------------------

type sqsCreateQueueInput struct {
	QueueName  string            `json:"QueueName"`
	Attributes map[string]string `json:"Attributes"`
	// NOTE: the lowercase "tags" key mirrors the AWS CreateQueue JSON wire
	// shape — do not "fix" the casing.
	Tags map[string]string `json:"tags"`
}

type sqsDeleteQueueInput struct {
	QueueUrl string `json:"QueueUrl"`
}

type sqsListQueuesInput struct {
	QueueNamePrefix string `json:"QueueNamePrefix"`
	MaxResults      int    `json:"MaxResults"`
	NextToken       string `json:"NextToken"`
}

type sqsGetQueueUrlInput struct {
	QueueName string `json:"QueueName"`
}
+ +type sqsGetQueueAttributesInput struct { + QueueUrl string `json:"QueueUrl"` + AttributeNames []string `json:"AttributeNames"` +} + +type sqsSetQueueAttributesInput struct { + QueueUrl string `json:"QueueUrl"` + Attributes map[string]string `json:"Attributes"` +} + +func decodeSQSJSONInput(r *http.Request, v any) error { + body, err := io.ReadAll(http.MaxBytesReader(nil, r.Body, sqsMaxRequestBodyBytes)) + if err != nil { + return newSQSAPIError(http.StatusBadRequest, sqsErrMalformedRequest, err.Error()) + } + if len(bytes.TrimSpace(body)) == 0 { + // Empty body is legal for some actions; leave v at its zero value. + return nil + } + if err := json.Unmarshal(body, v); err != nil { + return newSQSAPIError(http.StatusBadRequest, sqsErrMalformedRequest, err.Error()) + } + return nil +} + +// ------------------------ URL helpers ------------------------ + +// queueURL builds the AWS-compatible URL echoed back to clients. We follow +// the endpoint the client addressed so SDKs that re-sign subsequent requests +// against the same URL keep working behind reverse proxies. +func (s *SQSServer) queueURL(r *http.Request, queueName string) string { + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + host := r.Host + if host == "" && s.listen != nil { + host = s.listen.Addr().String() + } + return scheme + "://" + host + "/" + queueName +} + +// scanContinuationSentinel is appended to a key to produce the exclusive +// upper bound for the next scan page — i.e. the smallest key strictly +// greater than k. 
const scanContinuationSentinel = 0x00

// nextScanCursorAfter returns the smallest key strictly greater than key,
// used as the inclusive start of the next scan page.
func nextScanCursorAfter(key []byte) []byte {
	return append(bytes.Clone(key), scanContinuationSentinel)
}

// queueNameFromURL extracts the trailing queue-name segment from a queue
// URL, tolerating an AWS-style account-id path component.
func queueNameFromURL(queueUrl string) (string, error) {
	if strings.TrimSpace(queueUrl) == "" {
		return "", newSQSAPIError(http.StatusBadRequest, sqsErrMissingParameter, "missing QueueUrl")
	}
	parsed, err := url.Parse(queueUrl)
	if err != nil {
		return "", newSQSAPIError(http.StatusBadRequest, sqsErrValidation, "invalid QueueUrl: "+err.Error())
	}
	name := strings.TrimPrefix(parsed.Path, "/")
	if name == "" {
		return "", newSQSAPIError(http.StatusBadRequest, sqsErrValidation, "QueueUrl path is empty")
	}
	// Strip an AWS-style account-id prefix so http://host/12345/MyQueue works.
	if idx := strings.LastIndex(name, "/"); idx >= 0 {
		name = name[idx+1:]
	}
	return name, nil
}

// validateQueueName checks presence, length, and character set. Note the
// length cap applies to the full name including any .fifo suffix.
func validateQueueName(name string) error {
	name = strings.TrimSpace(name)
	if name == "" {
		return newSQSAPIError(http.StatusBadRequest, sqsErrMissingParameter, "missing QueueName")
	}
	if len(name) > sqsMaxQueueNameLength {
		return newSQSAPIError(http.StatusBadRequest, sqsErrValidation, "QueueName too long")
	}
	if !sqsQueueNamePattern.MatchString(name) {
		return newSQSAPIError(http.StatusBadRequest, sqsErrValidation, "QueueName contains invalid characters")
	}
	return nil
}

// ------------------------ attribute parsing ------------------------

// parseAttributesIntoMeta builds a queue-meta record from the request
// attributes, starting from AWS defaults and applying each provided value.
func parseAttributesIntoMeta(name string, attrs map[string]string) (*sqsQueueMeta, error) {
	meta := &sqsQueueMeta{
		Name:                      name,
		IsFIFO:                    strings.HasSuffix(name, sqsFIFOQueueNameSuffix),
		VisibilityTimeoutSeconds:  sqsDefaultVisibilityTimeoutSeconds,
		MessageRetentionSeconds:   sqsDefaultRetentionSeconds,
		DelaySeconds:              sqsDefaultDelaySeconds,
		ReceiveMessageWaitSeconds: sqsDefaultReceiveMessageWaitSeconds,
		MaximumMessageSize:        sqsDefaultMaximumMessageSize,
	}
	if err := applyAttributes(meta, attrs); err != nil {
		return nil, err
	}
	// FifoQueue attribute is authoritative if explicitly set; otherwise the
	// .fifo suffix implies true and callers without the suffix get a
	// standard queue. The attribute must agree with the name's suffix in
	// both directions.
	if v, ok := attrs["FifoQueue"]; ok {
		b, err := strconv.ParseBool(v)
		if err != nil {
			return nil, newSQSAPIError(http.StatusBadRequest, sqsErrInvalidAttributeValue, "FifoQueue must be a boolean")
		}
		if b && !strings.HasSuffix(name, sqsFIFOQueueNameSuffix) {
			return nil, newSQSAPIError(http.StatusBadRequest, sqsErrValidation, "FIFO queue name must end in .fifo")
		}
		if !b && strings.HasSuffix(name, sqsFIFOQueueNameSuffix) {
			return nil, newSQSAPIError(http.StatusBadRequest, sqsErrValidation, "Queue name ends in .fifo but FifoQueue=false")
		}
		meta.IsFIFO = b
	}
	return meta, nil
}

// attributeApplier writes one attribute into meta. Keeping one applier per
// attribute keeps applyAttributes trivial to dispatch through.
type attributeApplier func(meta *sqsQueueMeta, value string) error

var sqsAttributeAppliers = map[string]attributeApplier{
	// FifoQueue is applied after applyAttributes returns (see
	// parseAttributesIntoMeta) because it needs cross-field validation
	// against the queue name.
	"FifoQueue": func(_ *sqsQueueMeta, _ string) error { return nil },
	"VisibilityTimeout": func(m *sqsQueueMeta, v string) error {
		n, err := parseIntAttr("VisibilityTimeout", v, 0, sqsMaxVisibilityTimeoutSeconds)
		if err != nil {
			return err
		}
		m.VisibilityTimeoutSeconds = n
		return nil
	},
	"MessageRetentionPeriod": func(m *sqsQueueMeta, v string) error {
		n, err := parseIntAttr("MessageRetentionPeriod", v, sqsMinRetentionSeconds, sqsMaxRetentionSeconds)
		if err != nil {
			return err
		}
		m.MessageRetentionSeconds = n
		return nil
	},
	"DelaySeconds": func(m *sqsQueueMeta, v string) error {
		n, err := parseIntAttr("DelaySeconds", v, 0, sqsMaxDelaySeconds)
		if err != nil {
			return err
		}
		m.DelaySeconds = n
		return nil
	},
	"ReceiveMessageWaitTimeSeconds": func(m *sqsQueueMeta, v string) error {
		n, err := parseIntAttr("ReceiveMessageWaitTimeSeconds", v, 0, sqsMaxReceiveMessageWaitSeconds)
		if err != nil {
			return err
		}
		m.ReceiveMessageWaitSeconds = n
		return nil
	},
	"MaximumMessageSize": func(m *sqsQueueMeta, v string) error {
		n, err := parseIntAttr("MaximumMessageSize", v, sqsMinMaximumMessageSize, sqsMaximumAllowedMaximumMessageSize)
		if err != nil {
			return err
		}
		m.MaximumMessageSize = n
		return nil
	},
	"ContentBasedDeduplication": func(m *sqsQueueMeta, v string) error {
		b, err := strconv.ParseBool(v)
		if err != nil {
			return newSQSAPIError(http.StatusBadRequest, sqsErrInvalidAttributeValue, "ContentBasedDeduplication must be a boolean")
		}
		m.ContentBasedDedup = b
		return nil
	},
	// RedrivePolicy is stored opaquely; its JSON is not validated here.
	"RedrivePolicy": func(m *sqsQueueMeta, v string) error {
		m.RedrivePolicy = v
		return nil
	},
}

// applyAttributes writes every entry in attrs into meta, returning a typed
// SQS error on the first unknown key or out-of-range value.
+func applyAttributes(meta *sqsQueueMeta, attrs map[string]string) error { + for k, v := range attrs { + apply, ok := sqsAttributeAppliers[k] + if !ok { + return newSQSAPIError(http.StatusBadRequest, sqsErrInvalidAttributeName, "unsupported attribute: "+k) + } + if err := apply(meta, v); err != nil { + return err + } + } + return nil +} + +func parseIntAttr(name, value string, minVal, maxVal int64) (int64, error) { + n, err := strconv.ParseInt(strings.TrimSpace(value), 10, 64) + if err != nil { + return 0, newSQSAPIError(http.StatusBadRequest, sqsErrInvalidAttributeValue, name+" must be an integer") + } + if n < minVal || n > maxVal { + return 0, newSQSAPIError(http.StatusBadRequest, sqsErrInvalidAttributeValue, name+" is out of range") + } + return n, nil +} + +// attributesEqual is used by CreateQueue for idempotency: calling the same +// CreateQueue twice with identical attributes is a no-op; differing values +// must fail with QueueNameExists. +func attributesEqual(a, b *sqsQueueMeta) bool { + if a == nil || b == nil { + return false + } + return a.IsFIFO == b.IsFIFO && + a.ContentBasedDedup == b.ContentBasedDedup && + a.VisibilityTimeoutSeconds == b.VisibilityTimeoutSeconds && + a.MessageRetentionSeconds == b.MessageRetentionSeconds && + a.DelaySeconds == b.DelaySeconds && + a.ReceiveMessageWaitSeconds == b.ReceiveMessageWaitSeconds && + a.MaximumMessageSize == b.MaximumMessageSize && + a.RedrivePolicy == b.RedrivePolicy +} + +// ------------------------ storage primitives ------------------------ + +func (s *SQSServer) nextTxnReadTS(ctx context.Context) uint64 { + maxTS := uint64(0) + if p, ok := s.store.(globalLastCommitTSProvider); ok { + maxTS = p.GlobalLastCommitTS(ctx) + } else if s.store != nil { + maxTS = s.store.LastCommitTS() + } + if s.coordinator != nil { + if clock := s.coordinator.Clock(); clock != nil && maxTS > 0 { + clock.Observe(maxTS) + } + } + if maxTS == 0 { + return 1 + } + return maxTS +} + +func (s *SQSServer) loadQueueMetaAt(ctx 
context.Context, queueName string, ts uint64) (*sqsQueueMeta, bool, error) { + b, err := s.store.GetAt(ctx, sqsQueueMetaKey(queueName), ts) + if err != nil { + if errors.Is(err, store.ErrKeyNotFound) { + return nil, false, nil + } + return nil, false, errors.WithStack(err) + } + meta, err := decodeSQSQueueMeta(b) + if err != nil { + return nil, false, err + } + return meta, true, nil +} + +func (s *SQSServer) loadQueueGenerationAt(ctx context.Context, queueName string, ts uint64) (uint64, error) { + b, err := s.store.GetAt(ctx, sqsQueueGenKey(queueName), ts) + if err != nil { + if errors.Is(err, store.ErrKeyNotFound) { + return 0, nil + } + return 0, errors.WithStack(err) + } + n, err := strconv.ParseUint(string(b), 10, 64) + if err != nil { + return 0, errors.WithStack(err) + } + return n, nil +} + +// ------------------------ handlers ------------------------ + +func (s *SQSServer) createQueue(w http.ResponseWriter, r *http.Request) { + var in sqsCreateQueueInput + if err := decodeSQSJSONInput(r, &in); err != nil { + writeSQSErrorFromErr(w, err) + return + } + if err := validateQueueName(in.QueueName); err != nil { + writeSQSErrorFromErr(w, err) + return + } + requested, err := parseAttributesIntoMeta(in.QueueName, in.Attributes) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + requested.Tags = in.Tags + + if err := s.createQueueWithRetry(r.Context(), requested); err != nil { + writeSQSErrorFromErr(w, err) + return + } + writeSQSJSON(w, map[string]string{"QueueUrl": s.queueURL(r, in.QueueName)}) +} + +func (s *SQSServer) createQueueWithRetry(ctx context.Context, requested *sqsQueueMeta) error { + backoff := transactRetryInitialBackoff + deadline := time.Now().Add(transactRetryMaxDuration) + for range transactRetryMaxAttempts { + done, err := s.tryCreateQueueOnce(ctx, requested) + if err == nil && done { + return nil + } + if err != nil && !isRetryableTransactWriteError(err) { + return err + } + if err := waitRetryWithDeadline(ctx, deadline, 
backoff); err != nil { + return errors.WithStack(err) + } + backoff = nextTransactRetryBackoff(backoff) + } + return newSQSAPIError(http.StatusInternalServerError, sqsErrInternalFailure, "create queue retry attempts exhausted") +} + +// tryCreateQueueOnce runs one read/check/dispatch pass. The first bool reports +// whether the caller should stop retrying: true means the queue now exists +// with the requested attributes, false means the dispatch hit a retryable +// conflict and should be retried after backoff. +func (s *SQSServer) tryCreateQueueOnce(ctx context.Context, requested *sqsQueueMeta) (bool, error) { + readTS := s.nextTxnReadTS(ctx) + existing, exists, err := s.loadQueueMetaAt(ctx, requested.Name, readTS) + if err != nil { + return false, errors.WithStack(err) + } + if exists { + if attributesEqual(existing, requested) { + return true, nil + } + return false, newSQSAPIError(http.StatusBadRequest, sqsErrQueueNameExists, "queue already exists with different attributes") + } + lastGen, err := s.loadQueueGenerationAt(ctx, requested.Name, readTS) + if err != nil { + return false, errors.WithStack(err) + } + requested.Generation = lastGen + 1 + if clock := s.coordinator.Clock(); clock != nil { + requested.CreatedAtHLC = clock.Current() + } + metaBytes, err := encodeSQSQueueMeta(requested) + if err != nil { + return false, errors.WithStack(err) + } + metaKey := sqsQueueMetaKey(requested.Name) + genKey := sqsQueueGenKey(requested.Name) + // StartTS pins OCC to the snapshot we took the existence + generation + // read at, and ReadKeys cover both the meta and generation records so + // a concurrent CreateQueue that committed between our read and our + // dispatch causes ErrWriteConflict and the retry loop re-reads. + // Without this, two races could both decide "queue missing" and both + // write their own generation, leaving the later write on top of a + // record the loser never observed. 
+ req := &kv.OperationGroup[kv.OP]{ + IsTxn: true, + StartTS: readTS, + ReadKeys: [][]byte{metaKey, genKey}, + Elems: []*kv.Elem[kv.OP]{ + {Op: kv.Put, Key: metaKey, Value: metaBytes}, + {Op: kv.Put, Key: genKey, Value: []byte(strconv.FormatUint(requested.Generation, 10))}, + }, + } + if _, err := s.coordinator.Dispatch(ctx, req); err != nil { + return false, errors.WithStack(err) + } + return true, nil +} + +func (s *SQSServer) deleteQueue(w http.ResponseWriter, r *http.Request) { + var in sqsDeleteQueueInput + if err := decodeSQSJSONInput(r, &in); err != nil { + writeSQSErrorFromErr(w, err) + return + } + name, err := queueNameFromURL(in.QueueUrl) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + if err := s.deleteQueueWithRetry(r.Context(), name); err != nil { + writeSQSErrorFromErr(w, err) + return + } + // SQS DeleteQueue returns 200 with an empty body. + writeSQSJSON(w, map[string]any{}) +} + +func (s *SQSServer) deleteQueueWithRetry(ctx context.Context, queueName string) error { + backoff := transactRetryInitialBackoff + deadline := time.Now().Add(transactRetryMaxDuration) + for range transactRetryMaxAttempts { + readTS := s.nextTxnReadTS(ctx) + _, exists, err := s.loadQueueMetaAt(ctx, queueName, readTS) + if err != nil { + return errors.WithStack(err) + } + if !exists { + return newSQSAPIError(http.StatusBadRequest, sqsErrQueueDoesNotExist, "queue does not exist") + } + + // Bump the generation counter so any stragglers under the old + // generation are unreachable by routing. Actual message cleanup + // lands in a follow-up PR along with the message keyspace. + lastGen, err := s.loadQueueGenerationAt(ctx, queueName, readTS) + if err != nil { + return errors.WithStack(err) + } + metaKey := sqsQueueMetaKey(queueName) + genKey := sqsQueueGenKey(queueName) + // StartTS + ReadKeys fence against a concurrent CreateQueue / + // SetQueueAttributes landing between our load and dispatch. 
+ req := &kv.OperationGroup[kv.OP]{ + IsTxn: true, + StartTS: readTS, + ReadKeys: [][]byte{metaKey, genKey}, + Elems: []*kv.Elem[kv.OP]{ + {Op: kv.Del, Key: metaKey}, + {Op: kv.Put, Key: genKey, Value: []byte(strconv.FormatUint(lastGen+1, 10))}, + }, + } + if _, err := s.coordinator.Dispatch(ctx, req); err == nil { + return nil + } else if !isRetryableTransactWriteError(err) { + return errors.WithStack(err) + } + if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { + return errors.WithStack(err) + } + backoff = nextTransactRetryBackoff(backoff) + } + return newSQSAPIError(http.StatusInternalServerError, sqsErrInternalFailure, "delete queue retry attempts exhausted") +} + +func (s *SQSServer) listQueues(w http.ResponseWriter, r *http.Request) { + var in sqsListQueuesInput + if err := decodeSQSJSONInput(r, &in); err != nil { + writeSQSErrorFromErr(w, err) + return + } + maxResults := clampListQueuesMaxResults(in.MaxResults) + + names, err := s.scanQueueNames(r.Context()) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + sort.Strings(names) + filtered := filterByPrefix(names, in.QueueNamePrefix) + start := resolveListQueuesStart(filtered, in.NextToken) + + end := start + maxResults + truncated := end < len(filtered) + if !truncated { + end = len(filtered) + } + page := filtered[start:end] + + urls := make([]string, 0, len(page)) + for _, n := range page { + urls = append(urls, s.queueURL(r, n)) + } + resp := map[string]any{"QueueUrls": urls} + if truncated && len(page) > 0 { + resp["NextToken"] = encodeSQSSegment(page[len(page)-1]) + } + writeSQSJSON(w, resp) +} + +func clampListQueuesMaxResults(requested int) int { + if requested <= 0 { + return sqsListQueuesDefaultMaxResults + } + if requested > sqsListQueuesHardMaxResults { + return sqsListQueuesHardMaxResults + } + return requested +} + +func filterByPrefix(names []string, prefix string) []string { + if prefix == "" { + return names + } + out := names[:0] + for _, n := range names { 
+ if strings.HasPrefix(n, prefix) { + out = append(out, n) + } + } + return out +} + +// resolveListQueuesStart decodes the NextToken boundary and returns the index +// of the first name strictly greater than it. A malformed token behaves as +// "start from the beginning" to match AWS's lenient behavior. +func resolveListQueuesStart(names []string, token string) int { + if token == "" { + return 0 + } + boundary, err := decodeSQSSegment(token) + if err != nil { + return 0 + } + for i, n := range names { + if n > boundary { + return i + } + } + return len(names) +} + +func (s *SQSServer) scanQueueNames(ctx context.Context) ([]string, error) { + prefix := []byte(SqsQueueMetaPrefix) + end := prefixScanEnd(prefix) + start := bytes.Clone(prefix) + readTS := s.nextTxnReadTS(ctx) + var names []string + for { + kvs, err := s.store.ScanAt(ctx, start, end, sqsQueueScanPageLimit, readTS) + if err != nil { + return nil, errors.WithStack(err) + } + if len(kvs) == 0 { + break + } + for _, kvp := range kvs { + if !bytes.HasPrefix(kvp.Key, prefix) { + return names, nil + } + name, ok := queueNameFromMetaKey(kvp.Key) + if !ok { + continue + } + names = append(names, name) + } + if len(kvs) < sqsQueueScanPageLimit { + break + } + start = nextScanCursorAfter(kvs[len(kvs)-1].Key) + if end != nil && bytes.Compare(start, end) > 0 { + break + } + } + return names, nil +} + +func (s *SQSServer) getQueueUrl(w http.ResponseWriter, r *http.Request) { + var in sqsGetQueueUrlInput + if err := decodeSQSJSONInput(r, &in); err != nil { + writeSQSErrorFromErr(w, err) + return + } + if err := validateQueueName(in.QueueName); err != nil { + writeSQSErrorFromErr(w, err) + return + } + _, exists, err := s.loadQueueMetaAt(r.Context(), in.QueueName, s.nextTxnReadTS(r.Context())) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + if !exists { + writeSQSError(w, http.StatusBadRequest, sqsErrQueueDoesNotExist, "queue does not exist") + return + } + writeSQSJSON(w, map[string]string{"QueueUrl": 
s.queueURL(r, in.QueueName)}) +} + +func (s *SQSServer) getQueueAttributes(w http.ResponseWriter, r *http.Request) { + var in sqsGetQueueAttributesInput + if err := decodeSQSJSONInput(r, &in); err != nil { + writeSQSErrorFromErr(w, err) + return + } + name, err := queueNameFromURL(in.QueueUrl) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + meta, exists, err := s.loadQueueMetaAt(r.Context(), name, s.nextTxnReadTS(r.Context())) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + if !exists { + writeSQSError(w, http.StatusBadRequest, sqsErrQueueDoesNotExist, "queue does not exist") + return + } + selection := selectedAttributeNames(in.AttributeNames) + attrs := queueMetaToAttributes(meta, selection) + writeSQSJSON(w, map[string]any{"Attributes": attrs}) +} + +// selectedAttributeNames returns a set of attribute names to include in the +// response. An empty selection, or any entry equal to "All", expands to +// every supported attribute. +func selectedAttributeNames(req []string) map[string]bool { + selection := map[string]bool{} + if len(req) == 0 { + return nil + } + for _, n := range req { + if n == "All" { + return nil + } + selection[n] = true + } + return selection +} + +func queueMetaToAttributes(meta *sqsQueueMeta, selection map[string]bool) map[string]string { + all := map[string]string{ + "VisibilityTimeout": strconv.FormatInt(meta.VisibilityTimeoutSeconds, 10), + "MessageRetentionPeriod": strconv.FormatInt(meta.MessageRetentionSeconds, 10), + "DelaySeconds": strconv.FormatInt(meta.DelaySeconds, 10), + "ReceiveMessageWaitTimeSeconds": strconv.FormatInt(meta.ReceiveMessageWaitSeconds, 10), + "MaximumMessageSize": strconv.FormatInt(meta.MaximumMessageSize, 10), + "FifoQueue": strconv.FormatBool(meta.IsFIFO), + "ContentBasedDeduplication": strconv.FormatBool(meta.ContentBasedDedup), + } + if meta.RedrivePolicy != "" { + all["RedrivePolicy"] = meta.RedrivePolicy + } + if selection == nil { + return all + } + out := 
		make(map[string]string, len(selection))
	for k := range selection {
		// Unknown requested names are silently dropped rather than
		// erroring; only supported attributes are echoed back.
		if v, ok := all[k]; ok {
			out[k] = v
		}
	}
	return out
}

// setQueueAttributes handles SetQueueAttributes by applying the requested
// attribute changes under an OCC retry loop.
func (s *SQSServer) setQueueAttributes(w http.ResponseWriter, r *http.Request) {
	var in sqsSetQueueAttributesInput
	if err := decodeSQSJSONInput(r, &in); err != nil {
		writeSQSErrorFromErr(w, err)
		return
	}
	name, err := queueNameFromURL(in.QueueUrl)
	if err != nil {
		writeSQSErrorFromErr(w, err)
		return
	}
	if err := s.setQueueAttributesWithRetry(r.Context(), name, in.Attributes); err != nil {
		writeSQSErrorFromErr(w, err)
		return
	}
	writeSQSJSON(w, map[string]any{})
}

// setQueueAttributesWithRetry runs a read-modify-write of the queue meta
// record under the shared transact retry budget, backing off between
// write-conflict retries.
func (s *SQSServer) setQueueAttributesWithRetry(ctx context.Context, queueName string, attrs map[string]string) error {
	backoff := transactRetryInitialBackoff
	deadline := time.Now().Add(transactRetryMaxDuration)
	for range transactRetryMaxAttempts {
		readTS := s.nextTxnReadTS(ctx)
		meta, exists, err := s.loadQueueMetaAt(ctx, queueName, readTS)
		if err != nil {
			return errors.WithStack(err)
		}
		if !exists {
			return newSQSAPIError(http.StatusBadRequest, sqsErrQueueDoesNotExist, "queue does not exist")
		}
		if err := applyAttributes(meta, attrs); err != nil {
			return err
		}
		metaBytes, err := encodeSQSQueueMeta(meta)
		if err != nil {
			return errors.WithStack(err)
		}
		metaKey := sqsQueueMetaKey(queueName)
		// StartTS + ReadKeys prevent two concurrent SetQueueAttributes
		// from both reading the same old meta and the later dispatch
		// clobbering the earlier commit's changes.
		req := &kv.OperationGroup[kv.OP]{
			IsTxn:    true,
			StartTS:  readTS,
			ReadKeys: [][]byte{metaKey},
			Elems: []*kv.Elem[kv.OP]{
				{Op: kv.Put, Key: metaKey, Value: metaBytes},
			},
		}
		if _, err := s.coordinator.Dispatch(ctx, req); err == nil {
			return nil
		} else if !isRetryableTransactWriteError(err) {
			return errors.WithStack(err)
		}
		// Retryable conflict: wait (bounded by the shared deadline) and
		// retry with a fresh snapshot.
		if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil {
			return errors.WithStack(err)
		}
		backoff = nextTransactRetryBackoff(backoff)
	}
	return newSQSAPIError(http.StatusInternalServerError, sqsErrInternalFailure, "set queue attributes retry attempts exhausted")
}
diff --git a/adapter/sqs_catalog_test.go b/adapter/sqs_catalog_test.go
new file mode 100644
index 000000000..3dac5fd54
--- /dev/null
+++ b/adapter/sqs_catalog_test.go
@@ -0,0 +1,294 @@
package adapter

import (
	"bytes"
	"context"
	"io"
	"net/http"
	"strings"
	"testing"

	json "github.com/goccy/go-json"
)

// callSQS routes a JSON-protocol request to the given node's SQS endpoint.
// The helper exists so tests read like "createQueue → 200 with a URL" rather
// than having to hand-build X-Amz-Target envelopes every time.
func callSQS(t *testing.T, node Node, target string, in any) (int, map[string]any) {
	t.Helper()
	body, err := json.Marshal(in)
	if err != nil {
		t.Fatalf("marshal: %v", err)
	}
	req, err := http.NewRequestWithContext(context.Background(), http.MethodPost,
		"http://"+node.sqsAddress+"/", bytes.NewReader(body))
	if err != nil {
		t.Fatalf("request: %v", err)
	}
	req.Header.Set("X-Amz-Target", target)
	req.Header.Set("Content-Type", sqsContentTypeJSON)
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("do: %v", err)
	}
	defer resp.Body.Close()
	raw, _ := io.ReadAll(resp.Body)
	out := map[string]any{}
	// Tolerate empty bodies: some actions legitimately return no JSON.
	if len(bytes.TrimSpace(raw)) > 0 {
		if err := json.Unmarshal(raw, &out); err != nil {
			t.Fatalf("decode %q: %v", string(raw), err)
		}
	}
	return resp.StatusCode, out
}

// sqsLeaderNode returns the node that is currently the raft leader, falling
// back to the first node when no leader is observable.
func sqsLeaderNode(t *testing.T, nodes []Node) Node {
	t.Helper()
	for _, n := range nodes {
		if n.engine != nil && n.engine.Leader().Address == n.raftAddress {
			return n
		}
	}
	return nodes[0]
}

// TestSQSServer_CatalogCreateGetList covers the create → lookup → list happy
// path on a single-node cluster.
func TestSQSServer_CatalogCreateGetList(t *testing.T) {
	t.Parallel()
	nodes, _, _ := createNode(t, 1)
	defer shutdown(nodes)
	node := sqsLeaderNode(t, nodes)

	// CreateQueue: 200 with a QueueUrl that ends in the queue name.
	status, out := callSQS(t, node, sqsCreateQueueTarget, map[string]any{
		"QueueName": "orders",
	})
	if status != http.StatusOK {
		t.Fatalf("create: status %d body %v", status, out)
	}
	url, _ := out["QueueUrl"].(string)
	if !strings.HasSuffix(url, "/orders") {
		t.Fatalf("QueueUrl %q does not end in /orders", url)
	}

	// GetQueueUrl returns the same URL.
	status, out = callSQS(t, node, sqsGetQueueUrlTarget, map[string]any{
		"QueueName": "orders",
	})
	if status != http.StatusOK {
		t.Fatalf("getQueueUrl: status %d body %v", status, out)
	}
	if got, _ := out["QueueUrl"].(string); got != url {
		t.Fatalf("GetQueueUrl=%q want %q", got, url)
	}

	// ListQueues sees it.
	status, out = callSQS(t, node, sqsListQueuesTarget, map[string]any{})
	if status != http.StatusOK {
		t.Fatalf("list: status %d body %v", status, out)
	}
	urls, _ := out["QueueUrls"].([]any)
	foundList := false
	for _, u := range urls {
		if s, _ := u.(string); s == url {
			foundList = true
			break
		}
	}
	if !foundList {
		t.Fatalf("ListQueues did not include %q; got %v", url, urls)
	}
}

// TestSQSServer_CatalogCreateIsIdempotent verifies AWS CreateQueue
// semantics: same name + same attributes is a no-op success, same name +
// different attributes is an error.
func TestSQSServer_CatalogCreateIsIdempotent(t *testing.T) {
	t.Parallel()
	nodes, _, _ := createNode(t, 1)
	defer shutdown(nodes)
	node := sqsLeaderNode(t, nodes)

	in := map[string]any{
		"QueueName": "idempotent",
		"Attributes": map[string]string{
			"VisibilityTimeout": "60",
		},
	}
	status1, out1 := callSQS(t, node, sqsCreateQueueTarget, in)
	if status1 != http.StatusOK {
		t.Fatalf("first create: %d %v", status1, out1)
	}
	// Second call with the same attributes must succeed with the same URL.
	status2, out2 := callSQS(t, node, sqsCreateQueueTarget, in)
	if status2 != http.StatusOK {
		t.Fatalf("second create (same attrs): %d %v", status2, out2)
	}
	if out1["QueueUrl"] != out2["QueueUrl"] {
		t.Fatalf("idempotent create returned different URLs: %v vs %v", out1, out2)
	}

	// Third call with differing attributes must fail with QueueNameExists.
	changed := map[string]any{
		"QueueName":  "idempotent",
		"Attributes": map[string]string{"VisibilityTimeout": "120"},
	}
	status3, out3 := callSQS(t, node, sqsCreateQueueTarget, changed)
	if status3 != http.StatusBadRequest {
		t.Fatalf("differing-attrs create: got %d want 400; body %v", status3, out3)
	}
	if got, _ := out3["__type"].(string); got != sqsErrQueueNameExists {
		t.Fatalf("differing-attrs error type: got %q want %q", got, sqsErrQueueNameExists)
	}
}

// TestSQSServer_CatalogGetAndSetAttributes covers defaults, the "All"
// selector, and read-back after SetQueueAttributes.
func TestSQSServer_CatalogGetAndSetAttributes(t *testing.T) {
	t.Parallel()
	nodes, _, _ := createNode(t, 1)
	defer shutdown(nodes)
	node := sqsLeaderNode(t, nodes)

	status, out := callSQS(t, node, sqsCreateQueueTarget, map[string]any{
		"QueueName": "attrs",
	})
	if status != http.StatusOK {
		t.Fatalf("create: %d %v", status, out)
	}
	url, _ := out["QueueUrl"].(string)

	status, out = callSQS(t, node, sqsGetQueueAttributesTarget, map[string]any{
		"QueueUrl":       url,
		"AttributeNames": []string{"All"},
	})
	if status != http.StatusOK {
		t.Fatalf("getAttrs: %d %v", status, out)
	}
	attrs, _ := out["Attributes"].(map[string]any)
	if attrs["VisibilityTimeout"] != "30" {
		t.Fatalf("default VisibilityTimeout = %v, want 30", attrs["VisibilityTimeout"])
	}

	status, out = callSQS(t, node, sqsSetQueueAttributesTarget, map[string]any{
		"QueueUrl": url,
		"Attributes": map[string]string{
			"VisibilityTimeout": "90",
			"DelaySeconds":      "5",
		},
	})
	if status != http.StatusOK {
		t.Fatalf("setAttrs: %d %v", status, out)
	}

	_, out = callSQS(t, node, sqsGetQueueAttributesTarget, map[string]any{
		"QueueUrl":       url,
		"AttributeNames": []string{"VisibilityTimeout", "DelaySeconds"},
	})
	attrs, _ = out["Attributes"].(map[string]any)
	if attrs["VisibilityTimeout"] != "90" || attrs["DelaySeconds"] != "5" {
		t.Fatalf("updated attrs = %v, want VisibilityTimeout=90 DelaySeconds=5", attrs)
	}
}

// TestSQSServer_CatalogDelete covers delete, lookup-after-delete, and
// double-delete error semantics.
func TestSQSServer_CatalogDelete(t *testing.T) {
	t.Parallel()
	nodes, _, _ :=
		createNode(t, 1)
	defer shutdown(nodes)
	node := sqsLeaderNode(t, nodes)

	_, out := callSQS(t, node, sqsCreateQueueTarget, map[string]any{
		"QueueName": "deleteme",
	})
	url, _ := out["QueueUrl"].(string)

	status, out := callSQS(t, node, sqsDeleteQueueTarget, map[string]any{
		"QueueUrl": url,
	})
	if status != http.StatusOK {
		t.Fatalf("delete: %d %v", status, out)
	}

	// GetQueueUrl after delete returns NonExistentQueue.
	status, out = callSQS(t, node, sqsGetQueueUrlTarget, map[string]any{
		"QueueName": "deleteme",
	})
	if status != http.StatusBadRequest {
		t.Fatalf("getQueueUrl after delete: got %d want 400; body %v", status, out)
	}
	if got, _ := out["__type"].(string); got != sqsErrQueueDoesNotExist {
		t.Fatalf("error type: got %q want %q", got, sqsErrQueueDoesNotExist)
	}

	// DeleteQueue on an unknown queue also returns NonExistentQueue.
	status, _ = callSQS(t, node, sqsDeleteQueueTarget, map[string]any{
		"QueueUrl": url,
	})
	if status != http.StatusBadRequest {
		t.Fatalf("second delete: got %d want 400", status)
	}
}

// TestSQSServer_CatalogFIFOValidation checks the name-suffix / FifoQueue
// attribute consistency rules on CreateQueue.
func TestSQSServer_CatalogFIFOValidation(t *testing.T) {
	t.Parallel()
	nodes, _, _ := createNode(t, 1)
	defer shutdown(nodes)
	node := sqsLeaderNode(t, nodes)

	// FIFO name with FifoQueue=false is rejected.
	status, out := callSQS(t, node, sqsCreateQueueTarget, map[string]any{
		"QueueName":  "bad.fifo",
		"Attributes": map[string]string{"FifoQueue": "false"},
	})
	if status != http.StatusBadRequest {
		t.Fatalf("mismatch name/FifoQueue: got %d want 400; body %v", status, out)
	}

	// Non-FIFO name with FifoQueue=true is rejected.
	status, out = callSQS(t, node, sqsCreateQueueTarget, map[string]any{
		"QueueName":  "plain",
		"Attributes": map[string]string{"FifoQueue": "true"},
	})
	if status != http.StatusBadRequest {
		t.Fatalf("plain name + FifoQueue=true: got %d want 400; body %v", status, out)
	}

	// Valid FIFO succeeds and the attribute is echoed back.
	status, out = callSQS(t, node, sqsCreateQueueTarget, map[string]any{
		"QueueName":  "events.fifo",
		"Attributes": map[string]string{"FifoQueue": "true"},
	})
	if status != http.StatusOK {
		t.Fatalf("FIFO create: %d %v", status, out)
	}
	url, _ := out["QueueUrl"].(string)
	_, out = callSQS(t, node, sqsGetQueueAttributesTarget, map[string]any{
		"QueueUrl":       url,
		"AttributeNames": []string{"FifoQueue"},
	})
	attrs, _ := out["Attributes"].(map[string]any)
	if attrs["FifoQueue"] != "true" {
		t.Fatalf("FIFO flag not persisted: %v", attrs)
	}
}

// TestSQSServer_CatalogKeyEncoding exercises the segment and meta-key
// round-trips without a running cluster (pure functions).
func TestSQSServer_CatalogKeyEncoding(t *testing.T) {
	t.Parallel()
	for _, name := range []string{"", "a", "hello world", "queue.fifo", strings.Repeat("x", 80)} {
		encoded := encodeSQSSegment(name)
		decoded, err := decodeSQSSegment(encoded)
		if err != nil {
			t.Fatalf("decode %q: %v", name, err)
		}
		if decoded != name {
			t.Fatalf("round-trip %q -> %q -> %q", name, encoded, decoded)
		}
	}

	// queueNameFromMetaKey round-trips sqsQueueMetaKey.
	name := "round.trip.fifo"
	key := sqsQueueMetaKey(name)
	got, ok := queueNameFromMetaKey(key)
	if !ok || got != name {
		t.Fatalf("queueNameFromMetaKey(sqsQueueMetaKey(%q)) = (%q, %v), want (%q, true)", name, got, ok, name)
	}

	// Unknown prefixes are rejected.
	if _, ok := queueNameFromMetaKey([]byte("random")); ok {
		t.Fatal("queueNameFromMetaKey should reject non-catalog keys")
	}
}
diff --git a/adapter/sqs_keys.go b/adapter/sqs_keys.go
new file mode 100644
index 000000000..f0f402665
--- /dev/null
+++ b/adapter/sqs_keys.go
@@ -0,0 +1,58 @@
package adapter

import (
	"encoding/base64"
	"strings"

	"github.com/cockroachdb/errors"
)

// SQS keyspace prefixes. Kept in sync with the naming in
// docs/design/2026_04_24_proposed_sqs_compatible_adapter.md.
const (
	// SqsQueueMetaPrefix prefixes queue-metadata records.
	SqsQueueMetaPrefix = "!sqs|queue|meta|"
	// SqsQueueGenPrefix prefixes the per-queue monotonic generation counter.
	// Bumped on DeleteQueue / PurgeQueue so keys from an older incarnation of
	// the same queue name cannot leak into a newly created queue.
	SqsQueueGenPrefix = "!sqs|queue|gen|"
)

// sqsQueueMetaKey builds the catalog key for queueName's metadata record.
func sqsQueueMetaKey(queueName string) []byte {
	return []byte(SqsQueueMetaPrefix + encodeSQSSegment(queueName))
}

// sqsQueueGenKey builds the key holding queueName's generation counter.
func sqsQueueGenKey(queueName string) []byte {
	return []byte(SqsQueueGenPrefix + encodeSQSSegment(queueName))
}

// encodeSQSSegment emits a printable, byte-ordered-unique representation of a
// queue name. Base64 raw URL encoding matches the encoding the DynamoDB
// adapter uses for table segments (see encodeDynamoSegment) so that operators
// reading raw keys from Pebble see the same shape across both adapters.
func encodeSQSSegment(v string) string {
	return base64.RawURLEncoding.EncodeToString([]byte(v))
}

// decodeSQSSegment reverses encodeSQSSegment; it fails on any input that is
// not valid raw-URL base64.
func decodeSQSSegment(v string) (string, error) {
	b, err := base64.RawURLEncoding.DecodeString(v)
	if err != nil {
		return "", errors.WithStack(err)
	}
	return string(b), nil
}

// queueNameFromMetaKey pulls the queue name out of a !sqs|queue|meta|
// key. The second return reports success so callers can skip keys that were
// not written by this adapter.
func queueNameFromMetaKey(key []byte) (string, bool) {
	enc, ok := strings.CutPrefix(string(key), SqsQueueMetaPrefix)
	if !ok || enc == "" {
		return "", false
	}
	name, err := decodeSQSSegment(enc)
	if err != nil {
		return "", false
	}
	return name, true
}
diff --git a/adapter/sqs_messages.go b/adapter/sqs_messages.go
new file mode 100644
index 000000000..fe8ba5443
--- /dev/null
+++ b/adapter/sqs_messages.go
@@ -0,0 +1,805 @@
package adapter

import (
	"bytes"
	"context"
	"crypto/md5" //nolint:gosec // AWS SQS ETag specifies MD5; not used as a cryptographic primitive.
	"crypto/rand"
	"encoding/base64"
	"encoding/binary"
	"encoding/hex"
	"net/http"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/bootjp/elastickv/kv"
	"github.com/bootjp/elastickv/store"
	"github.com/cockroachdb/errors"
	json "github.com/goccy/go-json"
)

// Message-keyspace prefixes. The data record holds the message body and
// state; the visibility index is a separate, visible_at-sorted key family
// so ReceiveMessage can find the next visible message with a single bounded
// prefix scan.
const (
	SqsMsgDataPrefix = "!sqs|msg|data|"
	SqsMsgVisPrefix  = "!sqs|msg|vis|"
)

const (
	sqsMessageIDBytes             = 16
	sqsReceiptTokenBytes          = 16
	sqsReceiveDefaultMaxMessages  = 1
	sqsReceiveHardMaxMessages     = 10
	sqsReceiveScanOverfetchFactor = 2
	sqsChangeVisibilityMaxSeconds = sqsMaxVisibilityTimeoutSeconds
	sqsVisScanPageLimit           = 1024
	// Version byte prefixed to encoded receipt handles. Bumped when the
	// on-wire handle format changes so old handles fail to decode loudly.
	sqsReceiptHandleVersion = byte(0x01)
	// Byte sizes used when pre-sizing key buffers. The exact value is not
	// critical; it only avoids one append growth for typical queue/ID
	// lengths.
	sqsKeyCapSmall = 32
	sqsKeyCapLarge = 64
	// Conversion factors for SQS second-granularity inputs.
	sqsMillisPerSecond = 1000
)

// AWS error codes specific to message operations.
const (
	sqsErrReceiptHandleInvalid = "ReceiptHandleIsInvalid"
	sqsErrInvalidReceiptHandle = "InvalidReceiptHandle"
	sqsErrMessageTooLong       = "InvalidParameterValue"
	sqsErrMessageNotInflight   = "MessageNotInflight"
)

// sqsMessageRecord mirrors !sqs|msg|data|... on disk. Visibility state
// (VisibleAtMillis, CurrentReceiptToken, ReceiveCount) lives here rather
// than in a side-record so a single OCC transaction can rotate it.
type sqsMessageRecord struct {
	MessageID           string            `json:"message_id"`
	Body                []byte            `json:"body"`
	MD5OfBody           string            `json:"md5_of_body"`
	MessageAttributes   map[string]string `json:"message_attributes,omitempty"`
	SenderID            string            `json:"sender_id,omitempty"`
	SendTimestampMillis int64             `json:"send_timestamp_millis"`
	AvailableAtMillis   int64             `json:"available_at_millis"`
	VisibleAtMillis     int64             `json:"visible_at_millis"`
	ReceiveCount        int64             `json:"receive_count"`
	FirstReceiveMillis  int64             `json:"first_receive_millis,omitempty"`
	CurrentReceiptToken []byte            `json:"current_receipt_token"`
	QueueGeneration     uint64            `json:"queue_generation"`
}

// storedSQSMsgPrefix is a magic header on serialized message records so a
// decoder can reject bytes that were not written by this adapter.
var storedSQSMsgPrefix = []byte{0x00, 'S', 'M', 0x01}

// encodeSQSMessageRecord serializes m as magic-prefix + JSON.
func encodeSQSMessageRecord(m *sqsMessageRecord) ([]byte, error) {
	body, err := json.Marshal(m)
	if err != nil {
		return nil, errors.WithStack(err)
	}
	out := make([]byte, 0, len(storedSQSMsgPrefix)+len(body))
	out = append(out, storedSQSMsgPrefix...)
	out = append(out, body...)
	return out, nil
}

// decodeSQSMessageRecord is the inverse of encodeSQSMessageRecord; it errors
// on bytes that lack the magic prefix.
func decodeSQSMessageRecord(b []byte) (*sqsMessageRecord, error) {
	if !bytes.HasPrefix(b, storedSQSMsgPrefix) {
		return nil, errors.New("unrecognized sqs message format")
	}
	var m sqsMessageRecord
	if err := json.Unmarshal(b[len(storedSQSMsgPrefix):], &m); err != nil {
		return nil, errors.WithStack(err)
	}
	return &m, nil
}

// ------------------------ key helpers ------------------------

// sqsMsgDataKey builds the data-record key: prefix + queue + gen + id.
func sqsMsgDataKey(queueName string, gen uint64, messageID string) []byte {
	buf := make([]byte, 0, len(SqsMsgDataPrefix)+sqsKeyCapLarge)
	buf = append(buf, SqsMsgDataPrefix...)
	buf = append(buf, encodeSQSSegment(queueName)...)
	buf = appendU64(buf, gen)
	buf = append(buf, encodeSQSSegment(messageID)...)
	return buf
}

// sqsMsgVisKey builds the visibility-index key; the big-endian visible_at
// component makes the family sort by visibility time.
func sqsMsgVisKey(queueName string, gen uint64, visibleAtMillis int64, messageID string) []byte {
	buf := make([]byte, 0, len(SqsMsgVisPrefix)+sqsKeyCapLarge)
	buf = append(buf, SqsMsgVisPrefix...)
	buf = append(buf, encodeSQSSegment(queueName)...)
	buf = appendU64(buf, gen)
	buf = appendU64(buf, uint64MaxZero(visibleAtMillis))
	buf = append(buf, encodeSQSSegment(messageID)...)
	return buf
}

// sqsMsgVisPrefixForQueue returns the key prefix shared by all visibility
// entries of one (queue, generation) pair.
func sqsMsgVisPrefixForQueue(queueName string, gen uint64) []byte {
	buf := make([]byte, 0, len(SqsMsgVisPrefix)+sqsKeyCapSmall)
	buf = append(buf, SqsMsgVisPrefix...)
	buf = append(buf, encodeSQSSegment(queueName)...)
	buf = appendU64(buf, gen)
	return buf
}

// uint64MaxZero clamps negative int64 (which never happens for wall-clock
// timestamps but would silently overflow under uint64() cast) to zero.
func uint64MaxZero(v int64) uint64 {
	if v < 0 {
		return 0
	}
	return uint64(v)
}

// sqsMsgVisScanBounds returns a half-open [start, end) scan range covering
// every visibility entry with visible_at <= maxVisibleAtMillis. The +1 on
// the upper bound makes the inclusive time bound work with an exclusive end
// key (skipped only at the unreachable uint64 maximum).
func sqsMsgVisScanBounds(queueName string, gen uint64, maxVisibleAtMillis int64) (start, end []byte) {
	prefix := sqsMsgVisPrefixForQueue(queueName, gen)
	start = append(bytes.Clone(prefix), zeroU64()...)
	upper := uint64MaxZero(maxVisibleAtMillis)
	if upper < ^uint64(0) {
		upper++
	}
	end = append(bytes.Clone(prefix), encodedU64(upper)...)
	return start, end
}

// appendU64 appends v to dst in big-endian order (sorts numerically).
func appendU64(dst []byte, v uint64) []byte {
	var buf [8]byte
	binary.BigEndian.PutUint64(buf[:], v)
	return append(dst, buf[:]...)
}

// encodedU64 returns v as 8 big-endian bytes.
func encodedU64(v uint64) []byte {
	var buf [8]byte
	binary.BigEndian.PutUint64(buf[:], v)
	return buf[:]
}

// zeroU64 returns 8 zero bytes (the smallest encoded uint64).
func zeroU64() []byte {
	var buf [8]byte
	return buf[:]
}

// ------------------------ message id + receipt handle ------------------------

// newMessageIDHex returns a fresh 16-byte random message ID as hex.
func newMessageIDHex() (string, error) {
	var buf [sqsMessageIDBytes]byte
	if _, err := rand.Read(buf[:]); err != nil {
		return "", errors.WithStack(err)
	}
	return hex.EncodeToString(buf[:]), nil
}

// newReceiptToken returns a fresh 16-byte random receipt token.
func newReceiptToken() ([]byte, error) {
	buf := make([]byte, sqsReceiptTokenBytes)
	if _, err := rand.Read(buf); err != nil {
		return nil, errors.WithStack(err)
	}
	return buf, nil
}

// encodeReceiptHandle packs (queue_gen, message_id, receipt_token) into a
// single opaque blob.
// Format:
//
//	[ 0 ]     byte   version = 0x01
//	[ 1..9 ]  uint64 queue_gen (BE)
//	[ 9..25 ] 16 bytes message_id (raw bytes from hex decode)
//	[ 25..41 ] 16 bytes receipt_token
//
// The result is base64-urlsafe (no padding) so it passes through JSON and
// HTTP query parameters untouched.
func encodeReceiptHandle(queueGen uint64, messageIDHex string, receiptToken []byte) (string, error) {
	if len(receiptToken) != sqsReceiptTokenBytes {
		return "", errors.New("receipt token has wrong length")
	}
	idBytes, err := hex.DecodeString(messageIDHex)
	if err != nil || len(idBytes) != sqsMessageIDBytes {
		return "", errors.New("message id has wrong format")
	}
	buf := make([]byte, 0, 1+8+sqsMessageIDBytes+sqsReceiptTokenBytes)
	buf = append(buf, sqsReceiptHandleVersion)
	buf = appendU64(buf, queueGen)
	buf = append(buf, idBytes...)
	buf = append(buf, receiptToken...)
	return base64.RawURLEncoding.EncodeToString(buf), nil
}

// decodedReceiptHandle is the parsed form of an opaque receipt handle.
type decodedReceiptHandle struct {
	QueueGeneration uint64
	MessageIDHex    string
	ReceiptToken    []byte
}

// decodeReceiptHandle is the strict inverse of encodeReceiptHandle: exact
// length and matching version byte are required so stale handle formats fail
// loudly rather than being misread.
func decodeReceiptHandle(raw string) (*decodedReceiptHandle, error) {
	b, err := base64.RawURLEncoding.DecodeString(raw)
	if err != nil {
		return nil, errors.WithStack(err)
	}
	want := 1 + 8 + sqsMessageIDBytes + sqsReceiptTokenBytes
	if len(b) != want || b[0] != sqsReceiptHandleVersion {
		return nil, errors.New("receipt handle length or version mismatch")
	}
	out := &decodedReceiptHandle{
		QueueGeneration: binary.BigEndian.Uint64(b[1:9]),
		MessageIDHex:    hex.EncodeToString(b[9 : 9+sqsMessageIDBytes]),
		ReceiptToken:    bytes.Clone(b[9+sqsMessageIDBytes:]),
	}
	return out, nil
}

// ------------------------ input decoding ------------------------

type sqsSendMessageInput struct {
	QueueUrl          string            `json:"QueueUrl"`
	MessageBody       string            `json:"MessageBody"`
	DelaySeconds      *int64            `json:"DelaySeconds,omitempty"`
	MessageAttributes map[string]string `json:"MessageAttributes,omitempty"`
}

type
sqsReceiveMessageInput struct {
	QueueUrl            string `json:"QueueUrl"`
	MaxNumberOfMessages int    `json:"MaxNumberOfMessages,omitempty"`
	VisibilityTimeout   *int64 `json:"VisibilityTimeout,omitempty"`
	WaitTimeSeconds     *int64 `json:"WaitTimeSeconds,omitempty"`
}

type sqsDeleteMessageInput struct {
	QueueUrl      string `json:"QueueUrl"`
	ReceiptHandle string `json:"ReceiptHandle"`
}

type sqsChangeVisibilityInput struct {
	QueueUrl          string `json:"QueueUrl"`
	ReceiptHandle     string `json:"ReceiptHandle"`
	VisibilityTimeout int64  `json:"VisibilityTimeout"`
}

// ------------------------ handlers ------------------------

// sendMessage handles SendMessage: validates against queue meta, builds the
// message record, and commits the data record plus its visibility-index
// entry in one transaction.
func (s *SQSServer) sendMessage(w http.ResponseWriter, r *http.Request) {
	var in sqsSendMessageInput
	if err := decodeSQSJSONInput(r, &in); err != nil {
		writeSQSErrorFromErr(w, err)
		return
	}
	queueName, err := queueNameFromURL(in.QueueUrl)
	if err != nil {
		writeSQSErrorFromErr(w, err)
		return
	}
	meta, apiErr := s.loadQueueMetaForSend(r.Context(), queueName, []byte(in.MessageBody))
	if apiErr != nil {
		writeSQSErrorFromErr(w, apiErr)
		return
	}
	delay, apiErr := resolveSendDelay(meta, in.DelaySeconds)
	if apiErr != nil {
		writeSQSErrorFromErr(w, apiErr)
		return
	}
	rec, recordBytes, apiErr := buildSendRecord(meta, in, delay)
	if apiErr != nil {
		writeSQSErrorFromErr(w, apiErr)
		return
	}

	// Data record and visibility entry are written atomically so the index
	// never points at a missing record (or vice versa).
	dataKey := sqsMsgDataKey(queueName, meta.Generation, rec.MessageID)
	visKey := sqsMsgVisKey(queueName, meta.Generation, rec.AvailableAtMillis, rec.MessageID)
	req := &kv.OperationGroup[kv.OP]{
		IsTxn: true,
		Elems: []*kv.Elem[kv.OP]{
			{Op: kv.Put, Key: dataKey, Value: recordBytes},
			{Op: kv.Put, Key: visKey, Value: []byte(rec.MessageID)},
		},
	}
	if _, err := s.coordinator.Dispatch(r.Context(), req); err != nil {
		writeSQSErrorFromErr(w, err)
		return
	}

	writeSQSJSON(w, map[string]string{
		"MessageId":        rec.MessageID,
		"MD5OfMessageBody": rec.MD5OfBody,
		"MD5OfMessageAttributes":
		md5OfAttributesHex(in.MessageAttributes),
	})
}

// loadQueueMetaForSend loads queue meta at a fresh read timestamp and
// applies the send-side validations (existence, body size).
func (s *SQSServer) loadQueueMetaForSend(ctx context.Context, queueName string, body []byte) (*sqsQueueMeta, error) {
	readTS := s.nextTxnReadTS(ctx)
	meta, exists, err := s.loadQueueMetaAt(ctx, queueName, readTS)
	if err != nil {
		return nil, errors.WithStack(err)
	}
	if !exists {
		return nil, newSQSAPIError(http.StatusBadRequest, sqsErrQueueDoesNotExist, "queue does not exist")
	}
	if int64(len(body)) > meta.MaximumMessageSize {
		return nil, newSQSAPIError(http.StatusBadRequest, sqsErrMessageTooLong, "message body exceeds MaximumMessageSize")
	}
	return meta, nil
}

// resolveSendDelay picks the effective delay: the queue default, overridden
// by an explicit in-range request value.
func resolveSendDelay(meta *sqsQueueMeta, requested *int64) (int64, error) {
	delay := meta.DelaySeconds
	if requested == nil {
		return delay, nil
	}
	if *requested < 0 || *requested > sqsMaxDelaySeconds {
		return 0, newSQSAPIError(http.StatusBadRequest, sqsErrInvalidAttributeValue, "DelaySeconds out of range")
	}
	return *requested, nil
}

// buildSendRecord assembles a new message record (fresh ID and receipt
// token, delay applied to its initial visibility) plus its serialized form.
func buildSendRecord(meta *sqsQueueMeta, in sqsSendMessageInput, delay int64) (*sqsMessageRecord, []byte, error) {
	messageID, err := newMessageIDHex()
	if err != nil {
		return nil, nil, errors.WithStack(err)
	}
	token, err := newReceiptToken()
	if err != nil {
		return nil, nil, errors.WithStack(err)
	}
	now := time.Now().UnixMilli()
	availableAt := now + delay*sqsMillisPerSecond
	body := []byte(in.MessageBody)
	rec := &sqsMessageRecord{
		MessageID:           messageID,
		Body:                body,
		MD5OfBody:           sqsMD5Hex(body),
		MessageAttributes:   in.MessageAttributes,
		SendTimestampMillis: now,
		AvailableAtMillis:   availableAt,
		VisibleAtMillis:     availableAt,
		CurrentReceiptToken: token,
		QueueGeneration:     meta.Generation,
	}
	recordBytes, err := encodeSQSMessageRecord(rec)
	if err != nil {
		return nil, nil, errors.WithStack(err)
	}
	return rec, recordBytes, nil
}

//nolint:cyclop // AWS ReceiveMessage branches on per-message eligibility; splitting further just moves the branching
// around.
func (s *SQSServer) receiveMessage(w http.ResponseWriter, r *http.Request) {
	var in sqsReceiveMessageInput
	if err := decodeSQSJSONInput(r, &in); err != nil {
		writeSQSErrorFromErr(w, err)
		return
	}
	queueName, err := queueNameFromURL(in.QueueUrl)
	if err != nil {
		writeSQSErrorFromErr(w, err)
		return
	}
	ctx := r.Context()

	// Use LeaseRead to fence this scan against a leader that silently lost
	// quorum mid-request. When the lease is warm this is a local
	// wall-clock compare; when it is cold it falls back to a full
	// LinearizableRead.
	if _, err := kv.LeaseReadThrough(s.coordinator, ctx); err != nil {
		writeSQSErrorFromErr(w, err)
		return
	}

	readTS := s.nextTxnReadTS(ctx)
	meta, exists, err := s.loadQueueMetaAt(ctx, queueName, readTS)
	if err != nil {
		writeSQSErrorFromErr(w, err)
		return
	}
	if !exists {
		writeSQSError(w, http.StatusBadRequest, sqsErrQueueDoesNotExist, "queue does not exist")
		return
	}
	// NOTE(review): the local `max` shadows the Go 1.21 builtin of the
	// same name within this function.
	max := clampReceiveMaxMessages(in.MaxNumberOfMessages)
	visibilityTimeout := meta.VisibilityTimeoutSeconds
	if in.VisibilityTimeout != nil {
		if *in.VisibilityTimeout < 0 || *in.VisibilityTimeout > sqsChangeVisibilityMaxSeconds {
			writeSQSError(w, http.StatusBadRequest, sqsErrInvalidAttributeValue, "VisibilityTimeout out of range")
			return
		}
		visibilityTimeout = *in.VisibilityTimeout
	}

	// Overfetch candidates so per-message race losses still leave enough
	// to fill the batch.
	candidates, err := s.scanVisibleMessageCandidates(ctx, queueName, meta.Generation, max*sqsReceiveScanOverfetchFactor, readTS)
	if err != nil {
		writeSQSErrorFromErr(w, err)
		return
	}
	delivered := s.rotateMessagesForDelivery(ctx, queueName, meta.Generation, candidates, visibilityTimeout, max, readTS)
	writeSQSJSON(w, map[string]any{"Messages": delivered})
}

// clampReceiveMaxMessages normalizes MaxNumberOfMessages to [1, 10]: zero or
// negative falls back to the default of one.
func clampReceiveMaxMessages(requested int) int {
	if requested <= 0 {
		return sqsReceiveDefaultMaxMessages
	}
	if requested > sqsReceiveHardMaxMessages {
		return sqsReceiveHardMaxMessages
	}
	return requested
}

//
// scanVisibleMessageCandidates returns vis-index entries with
// visible_at <= now, up to limit. Each entry carries the key (needed
// for the delete-old-vis step) and the message_id pointed at by its
// value.
type sqsMsgCandidate struct {
	visKey    []byte
	messageID string
}

func (s *SQSServer) scanVisibleMessageCandidates(ctx context.Context, queueName string, gen uint64, limit int, readTS uint64) ([]sqsMsgCandidate, error) {
	if limit <= 0 {
		return nil, nil
	}
	now := time.Now().UnixMilli()
	start, end := sqsMsgVisScanBounds(queueName, gen, now)
	// Cap the page size so a pathological limit cannot force an unbounded
	// scan.
	page := limit
	if page > sqsVisScanPageLimit {
		page = sqsVisScanPageLimit
	}
	kvs, err := s.store.ScanAt(ctx, start, end, page, readTS)
	if err != nil {
		return nil, errors.WithStack(err)
	}
	out := make([]sqsMsgCandidate, 0, len(kvs))
	for _, kvp := range kvs {
		// Clone the key: the scan result's backing buffer may be reused.
		out = append(out, sqsMsgCandidate{visKey: bytes.Clone(kvp.Key), messageID: string(kvp.Value)})
	}
	return out, nil
}

// rotateMessagesForDelivery runs an OCC transaction per candidate to
// rotate its visibility entry + receipt token. Failures on individual
// messages (races, write-conflict) are skipped rather than aborting the
// whole batch — AWS semantics allow ReceiveMessage to return fewer
// messages than requested.
func (s *SQSServer) rotateMessagesForDelivery(
	ctx context.Context,
	queueName string,
	gen uint64,
	candidates []sqsMsgCandidate,
	visibilityTimeout int64,
	max int,
	readTS uint64,
) []map[string]any {
	delivered := make([]map[string]any, 0, max)
	for _, cand := range candidates {
		if len(delivered) >= max {
			break
		}
		msg, ok := s.tryDeliverCandidate(ctx, queueName, gen, cand, visibilityTimeout, readTS)
		if !ok {
			// Lost a race on this message; move on to the next candidate.
			continue
		}
		delivered = append(delivered, msg)
	}
	return delivered
}

// tryDeliverCandidate attempts to claim one candidate: it reloads the data
// record, rotates visibility + receipt token, and commits under OCC. A false
// return means the candidate was skipped (race, decode failure, conflict).
func (s *SQSServer) tryDeliverCandidate(
	ctx context.Context,
	queueName string,
	gen uint64,
	cand sqsMsgCandidate,
	visibilityTimeout int64,
	readTS uint64,
) (map[string]any, bool) {
	dataKey := sqsMsgDataKey(queueName, gen, cand.messageID)
	raw, err := s.store.GetAt(ctx, dataKey, readTS)
	if err != nil {
		return nil, false
	}
	rec, err := decodeSQSMessageRecord(raw)
	if err != nil {
		return nil, false
	}

	newToken, err := newReceiptToken()
	if err != nil {
		return nil, false
	}
	now := time.Now().UnixMilli()
	newVisibleAt := now + visibilityTimeout*sqsMillisPerSecond
	rec.VisibleAtMillis = newVisibleAt
	rec.CurrentReceiptToken = newToken
	rec.ReceiveCount++
	if rec.FirstReceiveMillis == 0 {
		rec.FirstReceiveMillis = now
	}
	recordBytes, err := encodeSQSMessageRecord(rec)
	if err != nil {
		return nil, false
	}
	newVisKey := sqsMsgVisKey(queueName, gen, newVisibleAt, cand.messageID)
	// StartTS pins the OCC read snapshot to the timestamp we actually
	// loaded the record at. Without it, the coordinator assigns a newer
	// StartTS at dispatch, so a concurrent rotation that committed
	// AFTER our read but BEFORE the assigned StartTS would slip through
	// ReadKeys validation and let this transaction double-deliver.
	// ReadKeys cover both the visibility entry and the data record so
	// any concurrent commit on either produces ErrWriteConflict.
+ req := &kv.OperationGroup[kv.OP]{ + IsTxn: true, + StartTS: readTS, + ReadKeys: [][]byte{cand.visKey, dataKey}, + Elems: []*kv.Elem[kv.OP]{ + {Op: kv.Del, Key: cand.visKey}, + {Op: kv.Put, Key: newVisKey, Value: []byte(cand.messageID)}, + {Op: kv.Put, Key: dataKey, Value: recordBytes}, + }, + } + if _, err := s.coordinator.Dispatch(ctx, req); err != nil { + return nil, false + } + + handle, err := encodeReceiptHandle(gen, cand.messageID, newToken) + if err != nil { + return nil, false + } + return map[string]any{ + "MessageId": cand.messageID, + "ReceiptHandle": handle, + "Body": string(rec.Body), + "MD5OfBody": rec.MD5OfBody, + "Attributes": map[string]string{ + "ApproximateReceiveCount": strconv.FormatInt(rec.ReceiveCount, 10), + "SentTimestamp": strconv.FormatInt(rec.SendTimestampMillis, 10), + "ApproximateFirstReceiveTimestamp": strconv.FormatInt(rec.FirstReceiveMillis, 10), + }, + }, true +} + +func (s *SQSServer) deleteMessage(w http.ResponseWriter, r *http.Request) { + var in sqsDeleteMessageInput + if err := decodeSQSJSONInput(r, &in); err != nil { + writeSQSErrorFromErr(w, err) + return + } + queueName, handle, err := s.parseQueueAndReceipt(in.QueueUrl, in.ReceiptHandle) + if err != nil { + writeSQSErrorFromErr(w, err) + return + } + if err := s.deleteMessageWithRetry(r.Context(), queueName, handle); err != nil { + writeSQSErrorFromErr(w, err) + return + } + writeSQSJSON(w, map[string]any{}) +} + +// deleteMessageWithRetry runs the load-check-commit flow under one OCC +// budget. AWS SQS semantics: a stale receipt handle (message already +// gone, or token rotated by another consumer) is a 200 no-op, NOT an +// error. The only error cases are structural (malformed handle, caught +// before this function) and infrastructure (retry budget exhausted). 
+// ErrWriteConflict on the delete Dispatch means a concurrent rotation +// / delete landed between our read and our commit; we retry so the +// next pass either sees the rotated token (no-op success) or the +// missing record (no-op success). +func (s *SQSServer) deleteMessageWithRetry(ctx context.Context, queueName string, handle *decodedReceiptHandle) error { + backoff := transactRetryInitialBackoff + deadline := time.Now().Add(transactRetryMaxDuration) + for range transactRetryMaxAttempts { + rec, dataKey, readTS, outcome, err := s.loadMessageForDelete(ctx, queueName, handle) + if err != nil { + return err + } + switch outcome { + case sqsDeleteNoOp: + return nil + case sqsDeleteProceed: + // fall through to commit below + } + visKey := sqsMsgVisKey(queueName, handle.QueueGeneration, rec.VisibleAtMillis, rec.MessageID) + // StartTS pins OCC to the snapshot we loaded the record at, so a + // concurrent rotation that commits after our load but before a + // coordinator-assigned StartTS cannot slip past ReadKeys. + req := &kv.OperationGroup[kv.OP]{ + IsTxn: true, + StartTS: readTS, + ReadKeys: [][]byte{dataKey, visKey}, + Elems: []*kv.Elem[kv.OP]{ + {Op: kv.Del, Key: dataKey}, + {Op: kv.Del, Key: visKey}, + }, + } + if _, err := s.coordinator.Dispatch(ctx, req); err == nil { + return nil + } else if !isRetryableTransactWriteError(err) { + return errors.WithStack(err) + } + if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil { + return errors.WithStack(err) + } + backoff = nextTransactRetryBackoff(backoff) + } + return newSQSAPIError(http.StatusInternalServerError, sqsErrInternalFailure, "delete message retry attempts exhausted") +} + +// sqsDeleteOutcome is a ternary tag returned by loadMessageForDelete so +// the caller can cleanly distinguish the AWS-idempotent no-op case from +// the proceed-to-commit case without conflating them with errors. 
type sqsDeleteOutcome int

const (
	// sqsDeleteProceed: receipt token matched; caller should commit the delete.
	sqsDeleteProceed sqsDeleteOutcome = iota
	// sqsDeleteNoOp: record missing or token rotated; caller returns success.
	sqsDeleteNoOp
)

// loadMessageForDelete reads the message record and classifies the
// outcome for AWS-compatible DeleteMessage semantics: structural errors
// propagate; missing records and token mismatches return
// sqsDeleteNoOp; matching tokens return sqsDeleteProceed with the
// loaded record. The readTS it took the snapshot at is returned so the
// caller can pass it as StartTS on the OCC dispatch, pinning the
// read-write conflict detection window.
func (s *SQSServer) loadMessageForDelete(ctx context.Context, queueName string, handle *decodedReceiptHandle) (*sqsMessageRecord, []byte, uint64, sqsDeleteOutcome, error) {
	readTS := s.nextTxnReadTS(ctx)
	dataKey := sqsMsgDataKey(queueName, handle.QueueGeneration, handle.MessageIDHex)
	raw, err := s.store.GetAt(ctx, dataKey, readTS)
	if err != nil {
		if errors.Is(err, store.ErrKeyNotFound) {
			// Already deleted (or expired generation): idempotent no-op.
			return nil, nil, readTS, sqsDeleteNoOp, nil
		}
		return nil, nil, readTS, sqsDeleteProceed, errors.WithStack(err)
	}
	rec, err := decodeSQSMessageRecord(raw)
	if err != nil {
		return nil, nil, readTS, sqsDeleteProceed, errors.WithStack(err)
	}
	if !bytes.Equal(rec.CurrentReceiptToken, handle.ReceiptToken) {
		// Token rotated by a later receive: stale handle, no-op per AWS.
		return nil, nil, readTS, sqsDeleteNoOp, nil
	}
	return rec, dataKey, readTS, sqsDeleteProceed, nil
}

// changeMessageVisibility is the HTTP handler for ChangeMessageVisibility:
// validate the timeout range, resolve queue + receipt handle, then run the
// OCC validate-and-swap flow.
func (s *SQSServer) changeMessageVisibility(w http.ResponseWriter, r *http.Request) {
	var in sqsChangeVisibilityInput
	if err := decodeSQSJSONInput(r, &in); err != nil {
		writeSQSErrorFromErr(w, err)
		return
	}
	if in.VisibilityTimeout < 0 || in.VisibilityTimeout > sqsChangeVisibilityMaxSeconds {
		writeSQSError(w, http.StatusBadRequest, sqsErrInvalidAttributeValue, "VisibilityTimeout out of range")
		return
	}
	queueName, handle, err := s.parseQueueAndReceipt(in.QueueUrl, in.ReceiptHandle)
	if err != nil {
		writeSQSErrorFromErr(w, err)
		return
	}
	if err := s.changeVisibilityWithRetry(r.Context(), queueName, handle, in.VisibilityTimeout); err != nil {
		writeSQSErrorFromErr(w, err)
		return
	}
	writeSQSJSON(w, map[string]any{})
}

// changeVisibilityWithRetry runs the validate-and-swap flow under an OCC
// retry budget. ReadKeys cover the data record and the current vis
// entry; a concurrent receive or delete will bump their commitTS past
// our startTS and we re-validate.
func (s *SQSServer) changeVisibilityWithRetry(ctx context.Context, queueName string, handle *decodedReceiptHandle, newTimeout int64) error {
	backoff := transactRetryInitialBackoff
	deadline := time.Now().Add(transactRetryMaxDuration)
	for range transactRetryMaxAttempts {
		rec, dataKey, readTS, apiErr := s.loadAndVerifyMessage(ctx, queueName, handle)
		if apiErr != nil {
			return apiErr
		}
		now := time.Now().UnixMilli()
		if rec.VisibleAtMillis <= now {
			// The message is already visible again; AWS rejects the call.
			return newSQSAPIError(http.StatusBadRequest, sqsErrMessageNotInflight, "message is not currently in flight")
		}
		oldVisKey := sqsMsgVisKey(queueName, handle.QueueGeneration, rec.VisibleAtMillis, rec.MessageID)
		rec.VisibleAtMillis = now + newTimeout*sqsMillisPerSecond
		recordBytes, err := encodeSQSMessageRecord(rec)
		if err != nil {
			return errors.WithStack(err)
		}
		newVisKey := sqsMsgVisKey(queueName, handle.QueueGeneration, rec.VisibleAtMillis, rec.MessageID)
		// StartTS pins OCC to the snapshot; without it the coordinator
		// would auto-assign a newer StartTS and a concurrent receive /
		// delete that commits between our load and dispatch could slip
		// past the ReadKeys validation.
		req := &kv.OperationGroup[kv.OP]{
			IsTxn:    true,
			StartTS:  readTS,
			ReadKeys: [][]byte{dataKey, oldVisKey},
			Elems: []*kv.Elem[kv.OP]{
				{Op: kv.Del, Key: oldVisKey},
				{Op: kv.Put, Key: newVisKey, Value: []byte(rec.MessageID)},
				{Op: kv.Put, Key: dataKey, Value: recordBytes},
			},
		}
		if _, err := s.coordinator.Dispatch(ctx, req); err == nil {
			return nil
		} else if !isRetryableTransactWriteError(err) {
			return errors.WithStack(err)
		}
		if err := waitRetryWithDeadline(ctx, deadline, backoff); err != nil {
			return errors.WithStack(err)
		}
		backoff = nextTransactRetryBackoff(backoff)
	}
	return newSQSAPIError(http.StatusInternalServerError, sqsErrInternalFailure, "change visibility retry attempts exhausted")
}

// parseQueueAndReceipt extracts the queue name and decodes the receipt
// handle from a DeleteMessage / ChangeMessageVisibility input.
func (s *SQSServer) parseQueueAndReceipt(queueUrl, receiptHandle string) (string, *decodedReceiptHandle, error) {
	queueName, err := queueNameFromURL(queueUrl)
	if err != nil {
		return "", nil, err
	}
	handle, err := decodeReceiptHandle(receiptHandle)
	if err != nil {
		return "", nil, newSQSAPIError(http.StatusBadRequest, sqsErrReceiptHandleInvalid, "receipt handle is not parseable")
	}
	return queueName, handle, nil
}

// loadAndVerifyMessage reads the data record for the given handle and
// verifies that the receipt token matches the current one on record.
// Returns the record, its key, the snapshot timestamp the read ran at,
// or a typed SQS error. Callers use the snapshot as StartTS on the
// OCC dispatch so concurrent commits cannot slip past ReadKeys.
+func (s *SQSServer) loadAndVerifyMessage(ctx context.Context, queueName string, handle *decodedReceiptHandle) (*sqsMessageRecord, []byte, uint64, error) { + readTS := s.nextTxnReadTS(ctx) + dataKey := sqsMsgDataKey(queueName, handle.QueueGeneration, handle.MessageIDHex) + raw, err := s.store.GetAt(ctx, dataKey, readTS) + if err != nil { + if errors.Is(err, store.ErrKeyNotFound) { + return nil, nil, readTS, newSQSAPIError(http.StatusBadRequest, sqsErrReceiptHandleInvalid, "message not found") + } + return nil, nil, readTS, errors.WithStack(err) + } + rec, err := decodeSQSMessageRecord(raw) + if err != nil { + return nil, nil, readTS, errors.WithStack(err) + } + if !bytes.Equal(rec.CurrentReceiptToken, handle.ReceiptToken) { + return nil, nil, readTS, newSQSAPIError(http.StatusBadRequest, sqsErrInvalidReceiptHandle, "receipt handle token does not match") + } + return rec, dataKey, readTS, nil +} + +// ------------------------ small helpers ------------------------ + +func sqsMD5Hex(body []byte) string { + sum := md5.Sum(body) //nolint:gosec // AWS-specified ETag hashing, not a crypto primitive. + return hex.EncodeToString(sum[:]) +} + +// md5OfAttributesHex computes AWS's MD5 of a MessageAttributes map. The +// real AWS format canonicalizes names and types; this adapter only +// returns "" on an empty map and a simple concatenated hash otherwise +// (full canonicalization lives in a follow-up PR along with typed +// attribute values). 
+func md5OfAttributesHex(attrs map[string]string) string { + if len(attrs) == 0 { + return "" + } + keys := make([]string, 0, len(attrs)) + for k := range attrs { + keys = append(keys, k) + } + sort.Strings(keys) + var b strings.Builder + for _, k := range keys { + b.WriteString(k) + b.WriteString("=") + b.WriteString(attrs[k]) + b.WriteString(";") + } + return sqsMD5Hex([]byte(b.String())) +} diff --git a/adapter/sqs_messages_test.go b/adapter/sqs_messages_test.go new file mode 100644 index 000000000..cc86b6db9 --- /dev/null +++ b/adapter/sqs_messages_test.go @@ -0,0 +1,414 @@ +package adapter + +import ( + "encoding/hex" + "net/http" + "strconv" + "testing" + "time" +) + +// createSQSQueueForTest is a small helper so every message-path test does +// not have to repeat the "createQueue -> pull URL" dance. +func createSQSQueueForTest(t *testing.T, node Node, name string) string { + t.Helper() + status, out := callSQS(t, node, sqsCreateQueueTarget, map[string]any{"QueueName": name}) + if status != http.StatusOK { + t.Fatalf("createQueue %q: %d %v", name, status, out) + } + url, _ := out["QueueUrl"].(string) + if url == "" { + t.Fatalf("createQueue %q: empty URL", name) + } + return url +} + +func TestSQSServer_SendReceiveDelete(t *testing.T) { + t.Parallel() + nodes, _, _ := createNode(t, 1) + defer shutdown(nodes) + node := sqsLeaderNode(t, nodes) + queueURL := createSQSQueueForTest(t, node, "work") + + msgID := sendOneMessage(t, node, queueURL, "hello") + receipt := receiveOneMessage(t, node, queueURL, msgID, "hello") + expectNoMessagesVisible(t, node, queueURL, "after-receive") + deleteMessageOK(t, node, queueURL, receipt) + expectNoMessagesVisible(t, node, queueURL, "after-delete") +} + +// sendOneMessage sends a single message and returns the assigned MessageId. +// It asserts the response carries an MD5 over the body that matches what +// sqsMD5Hex would compute locally. 
func sendOneMessage(t *testing.T, node Node, queueURL, body string) string {
	t.Helper()
	status, out := callSQS(t, node, sqsSendMessageTarget, map[string]any{
		"QueueUrl":    queueURL,
		"MessageBody": body,
	})
	if status != http.StatusOK {
		t.Fatalf("send: %d %v", status, out)
	}
	msgID, _ := out["MessageId"].(string)
	if msgID == "" {
		t.Fatalf("no MessageId in %v", out)
	}
	if out["MD5OfMessageBody"] != sqsMD5Hex([]byte(body)) {
		t.Fatalf("md5 mismatch: got %v want %q", out["MD5OfMessageBody"], sqsMD5Hex([]byte(body)))
	}
	return msgID
}

// receiveOneMessage receives exactly one message, asserts its MessageId
// and Body match the expected values, and returns the receipt handle.
func receiveOneMessage(t *testing.T, node Node, queueURL, wantID, wantBody string) string {
	t.Helper()
	status, out := callSQS(t, node, sqsReceiveMessageTarget, map[string]any{
		"QueueUrl":            queueURL,
		"MaxNumberOfMessages": 1,
		"VisibilityTimeout":   60,
	})
	if status != http.StatusOK {
		t.Fatalf("receive: %d %v", status, out)
	}
	msgs, _ := out["Messages"].([]any)
	if len(msgs) != 1 {
		t.Fatalf("got %d messages, want 1 (%v)", len(msgs), out)
	}
	m, _ := msgs[0].(map[string]any)
	if m["Body"] != wantBody {
		t.Fatalf("Body=%v want %q", m["Body"], wantBody)
	}
	if m["MessageId"] != wantID {
		t.Fatalf("MessageId=%v want %q", m["MessageId"], wantID)
	}
	receipt, _ := m["ReceiptHandle"].(string)
	if receipt == "" {
		t.Fatalf("no ReceiptHandle in %v", m)
	}
	return receipt
}

// expectNoMessagesVisible asserts an immediate receive returns zero
// messages; tag labels the failure site in test output.
func expectNoMessagesVisible(t *testing.T, node Node, queueURL, tag string) {
	t.Helper()
	status, out := callSQS(t, node, sqsReceiveMessageTarget, map[string]any{
		"QueueUrl":            queueURL,
		"MaxNumberOfMessages": 1,
	})
	if status != http.StatusOK {
		t.Fatalf("%s receive: %d %v", tag, status, out)
	}
	if msgs, _ := out["Messages"].([]any); len(msgs) != 0 {
		t.Fatalf("%s: got %d messages, want 0", tag, len(msgs))
	}
}

// deleteMessageOK deletes by receipt handle and fails the test on any
// non-200 response.
func deleteMessageOK(t *testing.T, node Node, queueURL, receipt string) {
	t.Helper()
	status, out := callSQS(t, node, sqsDeleteMessageTarget, map[string]any{
		"QueueUrl":      queueURL,
		"ReceiptHandle": receipt,
	})
	if status != http.StatusOK {
		t.Fatalf("delete: %d %v", status, out)
	}
}

func TestSQSServer_DeleteWithStaleReceiptIsIdempotentNoOp(t *testing.T) {
	t.Parallel()
	// AWS SQS semantics: DeleteMessage with a stale receipt handle (token
	// rotated under our feet, or record already gone) must return 200
	// success without deleting. SDK retry paths and batch workers rely on
	// this so a retry after a visibility-expiry re-delivery does not fail
	// loudly. Structurally malformed handles are still an error;
	// token-only mismatches are not.
	nodes, _, _ := createNode(t, 1)
	defer shutdown(nodes)
	node := sqsLeaderNode(t, nodes)
	queueURL := createSQSQueueForTest(t, node, "stale-receipt")

	_, _ = callSQS(t, node, sqsSendMessageTarget, map[string]any{
		"QueueUrl":    queueURL,
		"MessageBody": "x",
	})
	status, out := callSQS(t, node, sqsReceiveMessageTarget, map[string]any{
		"QueueUrl":            queueURL,
		"MaxNumberOfMessages": 1,
		"VisibilityTimeout":   60,
	})
	if status != http.StatusOK {
		t.Fatalf("receive: %d %v", status, out)
	}
	msgs, _ := out["Messages"].([]any)
	if len(msgs) == 0 {
		t.Fatalf("no messages received")
	}
	goodHandle, _ := msgs[0].(map[string]any)["ReceiptHandle"].(string)

	// Flip a byte of the token so the stored token != handle token.
	decoded, err := decodeReceiptHandle(goodHandle)
	if err != nil {
		t.Fatalf("decode: %v", err)
	}
	decoded.ReceiptToken[0] ^= 0xff
	staleHandle, err := encodeReceiptHandle(decoded.QueueGeneration, decoded.MessageIDHex, decoded.ReceiptToken)
	if err != nil {
		t.Fatalf("encode: %v", err)
	}

	// Delete with the stale handle must succeed (no-op) per AWS.
	status, out = callSQS(t, node, sqsDeleteMessageTarget, map[string]any{
		"QueueUrl":      queueURL,
		"ReceiptHandle": staleHandle,
	})
	if status != http.StatusOK {
		t.Fatalf("delete with stale receipt: status=%d body=%v", status, out)
	}

	// The real handle must still work — the stale delete must not have
	// removed the in-flight message out from under the original consumer.
	status, out = callSQS(t, node, sqsDeleteMessageTarget, map[string]any{
		"QueueUrl":      queueURL,
		"ReceiptHandle": goodHandle,
	})
	if status != http.StatusOK {
		t.Fatalf("delete with good receipt after stale no-op: %d %v", status, out)
	}

	// A structurally malformed handle still errors out — only token
	// mismatches are the idempotent no-op case.
	status, out = callSQS(t, node, sqsDeleteMessageTarget, map[string]any{
		"QueueUrl":      queueURL,
		"ReceiptHandle": "not-base64-!!!",
	})
	if status != http.StatusBadRequest {
		t.Fatalf("malformed handle: status=%d body=%v", status, out)
	}
	if out["__type"] != sqsErrReceiptHandleInvalid {
		t.Fatalf("error type for malformed handle: %q want %q", out["__type"], sqsErrReceiptHandleInvalid)
	}
}

func TestSQSServer_ReceiveBatchRespectsMaxMessages(t *testing.T) {
	t.Parallel()
	nodes, _, _ := createNode(t, 1)
	defer shutdown(nodes)
	node := sqsLeaderNode(t, nodes)
	queueURL := createSQSQueueForTest(t, node, "batch")

	const sent = 5
	for i := 0; i < sent; i++ {
		_, _ = callSQS(t, node, sqsSendMessageTarget, map[string]any{
			"QueueUrl":    queueURL,
			"MessageBody": "m-" + strconv.Itoa(i),
		})
	}

	status, out := callSQS(t, node, sqsReceiveMessageTarget, map[string]any{
		"QueueUrl":            queueURL,
		"MaxNumberOfMessages": 3,
		"VisibilityTimeout":   60,
	})
	if status != http.StatusOK {
		t.Fatalf("receive: %d %v", status, out)
	}
	msgs, _ := out["Messages"].([]any)
	if len(msgs) != 3 {
		t.Fatalf("got %d messages, want 3 (%v)", len(msgs), out)
	}

	// A second receive picks up the remaining two.
	status, out = callSQS(t, node, sqsReceiveMessageTarget, map[string]any{
		"QueueUrl":            queueURL,
		"MaxNumberOfMessages": 10,
		"VisibilityTimeout":   60,
	})
	if status != http.StatusOK {
		t.Fatalf("receive#2: %d %v", status, out)
	}
	msgs, _ = out["Messages"].([]any)
	if len(msgs) != 2 {
		t.Fatalf("got %d messages on second receive, want 2", len(msgs))
	}
}

func TestSQSServer_DelaySecondsDefersDelivery(t *testing.T) {
	t.Parallel()
	nodes, _, _ := createNode(t, 1)
	defer shutdown(nodes)
	node := sqsLeaderNode(t, nodes)
	queueURL := createSQSQueueForTest(t, node, "delayed")

	_, _ = callSQS(t, node, sqsSendMessageTarget, map[string]any{
		"QueueUrl":     queueURL,
		"MessageBody":  "later",
		"DelaySeconds": 2,
	})

	// Immediate receive must return nothing.
	status, out := callSQS(t, node, sqsReceiveMessageTarget, map[string]any{
		"QueueUrl":            queueURL,
		"MaxNumberOfMessages": 1,
	})
	if status != http.StatusOK {
		t.Fatalf("receive: %d %v", status, out)
	}
	if msgs, _ := out["Messages"].([]any); len(msgs) != 0 {
		t.Fatalf("expected 0 messages before delay elapsed, got %d", len(msgs))
	}

	// After the delay, the message becomes visible.
	time.Sleep(2100 * time.Millisecond)
	status, out = callSQS(t, node, sqsReceiveMessageTarget, map[string]any{
		"QueueUrl":            queueURL,
		"MaxNumberOfMessages": 1,
		"VisibilityTimeout":   60,
	})
	if status != http.StatusOK {
		t.Fatalf("receive: %d %v", status, out)
	}
	msgs, _ := out["Messages"].([]any)
	if len(msgs) != 1 {
		t.Fatalf("expected 1 message after delay, got %d (%v)", len(msgs), out)
	}
}

func TestSQSServer_VisibilityTimeoutExpiryMakesMessageVisibleAgain(t *testing.T) {
	t.Parallel()
	nodes, _, _ := createNode(t, 1)
	defer shutdown(nodes)
	node := sqsLeaderNode(t, nodes)
	queueURL := createSQSQueueForTest(t, node, "revisible")

	_, _ = callSQS(t, node, sqsSendMessageTarget, map[string]any{
		"QueueUrl":    queueURL,
		"MessageBody": "retry-me",
	})

	// First receive with a very short visibility so this test doesn't
	// sit idle for 30 seconds.
	status, out := callSQS(t, node, sqsReceiveMessageTarget, map[string]any{
		"QueueUrl":            queueURL,
		"MaxNumberOfMessages": 1,
		"VisibilityTimeout":   1,
	})
	if status != http.StatusOK {
		t.Fatalf("receive: %d %v", status, out)
	}
	msgs, _ := out["Messages"].([]any)
	if len(msgs) != 1 {
		t.Fatalf("first receive: got %d messages want 1", len(msgs))
	}
	first, _ := msgs[0].(map[string]any)
	firstAttrs, _ := first["Attributes"].(map[string]any)
	if firstAttrs["ApproximateReceiveCount"] != "1" {
		t.Fatalf("first receive count = %v, want 1", firstAttrs)
	}

	time.Sleep(1200 * time.Millisecond)

	// Second receive should see the same message with ReceiveCount=2.
	status, out = callSQS(t, node, sqsReceiveMessageTarget, map[string]any{
		"QueueUrl":            queueURL,
		"MaxNumberOfMessages": 1,
		"VisibilityTimeout":   60,
	})
	if status != http.StatusOK {
		t.Fatalf("second receive: %d %v", status, out)
	}
	msgs, _ = out["Messages"].([]any)
	if len(msgs) != 1 {
		t.Fatalf("second receive: got %d messages want 1", len(msgs))
	}
	second, _ := msgs[0].(map[string]any)
	if first["MessageId"] != second["MessageId"] {
		t.Fatalf("second receive returned a different message: %v vs %v", first, second)
	}
	secondAttrs, _ := second["Attributes"].(map[string]any)
	if secondAttrs["ApproximateReceiveCount"] != "2" {
		t.Fatalf("second receive count = %v, want 2", secondAttrs)
	}
}

func TestSQSServer_ChangeMessageVisibility(t *testing.T) {
	t.Parallel()
	nodes, _, _ := createNode(t, 1)
	defer shutdown(nodes)
	node := sqsLeaderNode(t, nodes)
	queueURL := createSQSQueueForTest(t, node, "chgvis")

	_, _ = callSQS(t, node, sqsSendMessageTarget, map[string]any{
		"QueueUrl":    queueURL,
		"MessageBody": "bumpy",
	})
	status, out := callSQS(t, node, sqsReceiveMessageTarget, map[string]any{
		"QueueUrl":            queueURL,
		"MaxNumberOfMessages": 1,
		"VisibilityTimeout":   1,
	})
	if status != http.StatusOK {
		t.Fatalf("receive: %d %v", status, out)
	}
	msgs, _ := out["Messages"].([]any)
	if len(msgs) != 1 {
		t.Fatalf("no message received")
	}
	receipt, _ := msgs[0].(map[string]any)["ReceiptHandle"].(string)

	// Extend visibility to 60s.
	status, out = callSQS(t, node, sqsChangeMessageVisibilityTarget, map[string]any{
		"QueueUrl":          queueURL,
		"ReceiptHandle":     receipt,
		"VisibilityTimeout": 60,
	})
	if status != http.StatusOK {
		t.Fatalf("change visibility: %d %v", status, out)
	}

	// After the original 1s expiry, the message must still be hidden.
	time.Sleep(1200 * time.Millisecond)
	status, out = callSQS(t, node, sqsReceiveMessageTarget, map[string]any{
		"QueueUrl":            queueURL,
		"MaxNumberOfMessages": 1,
	})
	if status != http.StatusOK {
		t.Fatalf("receive: %d %v", status, out)
	}
	if msgs, _ := out["Messages"].([]any); len(msgs) != 0 {
		t.Fatalf("expected 0 messages after visibility extended, got %d", len(msgs))
	}
}

func TestSQSServer_ReceiptHandleCodecRoundTrip(t *testing.T) {
	t.Parallel()
	for _, tc := range []struct {
		gen uint64
		id  string
	}{
		{0, "00000000000000000000000000000000"},
		{1, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"},
		{42, "deadbeefdeadbeefdeadbeefdeadbeef"},
	} {
		token := make([]byte, sqsReceiptTokenBytes)
		for i := range token {
			token[i] = byte(i)
		}
		h, err := encodeReceiptHandle(tc.gen, tc.id, token)
		if err != nil {
			t.Fatalf("encode: %v", err)
		}
		back, err := decodeReceiptHandle(h)
		if err != nil {
			t.Fatalf("decode: %v", err)
		}
		if back.QueueGeneration != tc.gen || back.MessageIDHex != tc.id {
			t.Fatalf("round-trip mismatch: %+v vs %d/%s", back, tc.gen, tc.id)
		}
		if hex.EncodeToString(back.ReceiptToken) != hex.EncodeToString(token) {
			t.Fatalf("token round-trip mismatch: %x vs %x", back.ReceiptToken, token)
		}
	}

	// A garbage handle must fail to decode, not crash.
	if _, err := decodeReceiptHandle("!!!"); err == nil {
		t.Fatal("expected decode error for garbage handle")
	}
}
diff --git a/adapter/sqs_test.go b/adapter/sqs_test.go
new file mode 100644
index 000000000..721917174
--- /dev/null
+++ b/adapter/sqs_test.go
@@ -0,0 +1,218 @@
package adapter

import (
	"context"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	json "github.com/goccy/go-json"
)

// startTestSQSServer starts an SQSServer on an ephemeral localhost listener
// with no coordinator or leader map — enough to exercise the health
// endpoints, the dispatch table, and the error envelope.
// Real flows with Raft
// come in later PRs alongside the handler implementations.
func startTestSQSServer(t *testing.T) string {
	t.Helper()
	lc := &net.ListenConfig{}
	listener, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	srv := NewSQSServer(listener, nil, nil)
	go func() {
		_ = srv.Run()
	}()
	t.Cleanup(func() {
		srv.Stop()
	})
	return "http://" + listener.Addr().String()
}

// doRequest issues an HTTP request with optional headers and fails the
// test on construction or transport errors; callers close the body.
func doRequest(t *testing.T, method, url string, body io.Reader, headers map[string]string) *http.Response {
	t.Helper()
	req, err := http.NewRequestWithContext(context.Background(), method, url, body)
	if err != nil {
		t.Fatalf("request: %v", err)
	}
	for k, v := range headers {
		req.Header.Set(k, v)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("do: %v", err)
	}
	return resp
}

// postSQSRequest POSTs a JSON body with the SQS content type and, when
// non-empty, the X-Amz-Target action header set.
func postSQSRequest(t *testing.T, url string, target string, body string) *http.Response {
	t.Helper()
	headers := map[string]string{"Content-Type": sqsContentTypeJSON}
	if target != "" {
		headers["X-Amz-Target"] = target
	}
	return doRequest(t, http.MethodPost, url, strings.NewReader(body), headers)
}

func TestSQSServer_HealthOK(t *testing.T) {
	t.Parallel()
	base := startTestSQSServer(t)

	resp := doRequest(t, http.MethodGet, base+sqsHealthPath, nil, nil)
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("status: got %d want %d", resp.StatusCode, http.StatusOK)
	}
	body, _ := io.ReadAll(resp.Body)
	if got := strings.TrimSpace(string(body)); got != "ok" {
		t.Fatalf("body: got %q want %q", got, "ok")
	}
}

func TestSQSServer_LeaderHealthWithoutCoordinatorIsNotLeader(t *testing.T) {
	t.Parallel()
	base := startTestSQSServer(t)

	resp := doRequest(t, http.MethodGet, base+sqsLeaderHealthPath, nil, nil)
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusServiceUnavailable {
		t.Fatalf("status: got %d want %d", resp.StatusCode, http.StatusServiceUnavailable)
	}
}

func TestSQSServer_HealthRejectsNonGetHead(t *testing.T) {
	t.Parallel()
	base := startTestSQSServer(t)

	resp := doRequest(t, http.MethodPost, base+sqsHealthPath, nil, nil)
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusMethodNotAllowed {
		t.Fatalf("status: got %d want %d", resp.StatusCode, http.StatusMethodNotAllowed)
	}
	if allow := resp.Header.Get("Allow"); !strings.Contains(allow, "GET") {
		t.Fatalf("allow header: got %q", allow)
	}
}

func TestSQSServer_UnknownTargetReturnsInvalidAction(t *testing.T) {
	t.Parallel()
	base := startTestSQSServer(t)

	resp := postSQSRequest(t, base+"/", "AmazonSQS.NoSuchOperation", "{}")
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusBadRequest {
		t.Fatalf("status: got %d want %d", resp.StatusCode, http.StatusBadRequest)
	}
	if got := resp.Header.Get("x-amzn-ErrorType"); got != sqsErrInvalidAction {
		t.Fatalf("error type: got %q want %q", got, sqsErrInvalidAction)
	}
	var body map[string]string
	if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
		t.Fatalf("decode: %v", err)
	}
	if body["__type"] != sqsErrInvalidAction {
		t.Fatalf("body __type: got %q want %q", body["__type"], sqsErrInvalidAction)
	}
}

func TestSQSServer_KnownTargetsReturnNotImplemented(t *testing.T) {
	t.Parallel()
	base := startTestSQSServer(t)

	// Targets that still return NotImplemented. The catalog and core
	// message operations (Create/Delete/List/Get/SetQueue*, SendMessage,
	// ReceiveMessage, DeleteMessage, ChangeMessageVisibility) have real
	// handlers; they are exercised against a single-node cluster by
	// TestSQSServer_Catalog* and TestSQSServer_Send*.
	targets := []string{
		sqsPurgeQueueTarget,
		sqsSendMessageBatchTarget,
		sqsDeleteMessageBatchTarget,
		sqsChangeMessageVisibilityBatchTgt,
		sqsTagQueueTarget,
		sqsUntagQueueTarget,
		sqsListQueueTagsTarget,
	}
	for _, target := range targets {
		t.Run(target, func(t *testing.T) {
			t.Parallel()
			resp := postSQSRequest(t, base+"/", target, "{}")
			defer resp.Body.Close()
			if resp.StatusCode != http.StatusNotImplemented {
				t.Fatalf("status: got %d want %d", resp.StatusCode, http.StatusNotImplemented)
			}
			if got := resp.Header.Get("x-amzn-ErrorType"); got != sqsErrNotImplemented {
				t.Fatalf("error type: got %q want %q", got, sqsErrNotImplemented)
			}
		})
	}
}

func TestSQSServer_RejectsNonPostOnRoot(t *testing.T) {
	t.Parallel()
	base := startTestSQSServer(t)

	resp := doRequest(t, http.MethodGet, base+"/", nil, nil)
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusMethodNotAllowed {
		t.Fatalf("status: got %d want %d", resp.StatusCode, http.StatusMethodNotAllowed)
	}
	if allow := resp.Header.Get("Allow"); !strings.Contains(allow, http.MethodPost) {
		t.Fatalf("allow header: got %q", allow)
	}
}

func TestSQSServer_ErrorEnvelopeShape(t *testing.T) {
	t.Parallel()
	// We check the precise envelope shape via httptest so that SDK parsers
	// that key off x-amzn-ErrorType + __type + message do not regress.
	rec := httptest.NewRecorder()
	writeSQSError(rec, http.StatusBadRequest, sqsErrInvalidAction, "oops")
	if got := rec.Header().Get("Content-Type"); got != sqsContentTypeJSON {
		t.Fatalf("content-type: got %q", got)
	}
	if got := rec.Header().Get("x-amzn-ErrorType"); got != sqsErrInvalidAction {
		t.Fatalf("error type header: got %q", got)
	}
	var body map[string]string
	if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
		t.Fatalf("decode: %v", err)
	}
	if body["__type"] != sqsErrInvalidAction {
		t.Fatalf("__type: got %q", body["__type"])
	}
	if body["message"] != "oops" {
		t.Fatalf("message: got %q", body["message"])
	}
}

// TestSQSServer_StopShutsDown guards against regressions where Stop leaves the
// goroutine leaking; after Stop, Run must return promptly.
func TestSQSServer_StopShutsDown(t *testing.T) {
	t.Parallel()
	lc := &net.ListenConfig{}
	listener, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	srv := NewSQSServer(listener, nil, nil)
	done := make(chan error, 1)
	go func() { done <- srv.Run() }()
	// Give Serve a chance to enter its accept loop.
+ time.Sleep(20 * time.Millisecond) + srv.Stop() + + select { + case err := <-done: + if err != nil { + t.Fatalf("run: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("Run did not return within timeout after Stop") + } +} diff --git a/adapter/test_util.go b/adapter/test_util.go index b43f024af..7dcf22ac2 100644 --- a/adapter/test_util.go +++ b/adapter/test_util.go @@ -52,28 +52,35 @@ func newTestFactory() raftengine.Factory { func shutdown(nodes []Node) { for _, n := range nodes { - if n.opsCancel != nil { - n.opsCancel() - } - n.grpcServer.Stop() - if n.grpcService != nil { - if err := n.grpcService.Close(); err != nil { - log.Printf("grpc service close: %v", err) - } - } - n.redisServer.Stop() - if n.dynamoServer != nil { - n.dynamoServer.Stop() + shutdownNode(n) + } +} + +func shutdownNode(n Node) { + if n.opsCancel != nil { + n.opsCancel() + } + n.grpcServer.Stop() + if n.grpcService != nil { + if err := n.grpcService.Close(); err != nil { + log.Printf("grpc service close: %v", err) } - if n.engine != nil { - if err := n.engine.Close(); err != nil { - log.Printf("engine close: %v", err) - } + } + n.redisServer.Stop() + if n.dynamoServer != nil { + n.dynamoServer.Stop() + } + if n.sqsServer != nil { + n.sqsServer.Stop() + } + if n.engine != nil { + if err := n.engine.Close(); err != nil { + log.Printf("engine close: %v", err) } - if n.closeFactory != nil { - if err := n.closeFactory(); err != nil { - log.Printf("factory close: %v", err) - } + } + if n.closeFactory != nil { + if err := n.closeFactory(); err != nil { + log.Printf("factory close: %v", err) } } } @@ -83,10 +90,12 @@ type portsAdress struct { raft int redis int dynamo int + sqs int grpcAddress string raftAddress string redisAddress string dynamoAddress string + sqsAddress string } const ( @@ -95,6 +104,7 @@ const ( raftPort = 50000 redisPort = 63790 dynamoPort = 28000 + sqsPort = 29000 ) var mu sync.Mutex @@ -102,12 +112,14 @@ var portGrpc atomic.Int32 var portRaft atomic.Int32 var 
portRedis atomic.Int32 var portDynamo atomic.Int32 +var portSQS atomic.Int32 func init() { portGrpc.Store(raftPort) portRaft.Store(grpcPort) portRedis.Store(redisPort) portDynamo.Store(dynamoPort) + portSQS.Store(sqsPort) } func portAssigner() portsAdress { @@ -117,15 +129,18 @@ func portAssigner() portsAdress { rp := portRaft.Add(1) rd := portRedis.Add(1) dn := portDynamo.Add(1) + sq := portSQS.Add(1) return portsAdress{ grpc: int(gp), raft: int(rp), redis: int(rd), dynamo: int(dn), + sqs: int(sq), grpcAddress: net.JoinHostPort("localhost", strconv.Itoa(int(gp))), raftAddress: net.JoinHostPort("localhost", strconv.Itoa(int(rp))), redisAddress: net.JoinHostPort("localhost", strconv.Itoa(int(rd))), dynamoAddress: net.JoinHostPort("localhost", strconv.Itoa(int(dn))), + sqsAddress: net.JoinHostPort("localhost", strconv.Itoa(int(sq))), } } @@ -134,25 +149,29 @@ type Node struct { raftAddress string redisAddress string dynamoAddress string + sqsAddress string grpcServer *grpc.Server grpcService *GRPCServer redisServer *RedisServer dynamoServer *DynamoDBServer + sqsServer *SQSServer opsCancel context.CancelFunc engine raftengine.Engine closeFactory func() error } -func newNode(grpcAddress, raftAddress, redisAddress, dynamoAddress string, engine raftengine.Engine, closeFactory func() error, grpcs *grpc.Server, grpcService *GRPCServer, rd *RedisServer, ds *DynamoDBServer, opsCancel context.CancelFunc) Node { +func newNode(grpcAddress, raftAddress, redisAddress, dynamoAddress, sqsAddress string, engine raftengine.Engine, closeFactory func() error, grpcs *grpc.Server, grpcService *GRPCServer, rd *RedisServer, ds *DynamoDBServer, sq *SQSServer, opsCancel context.CancelFunc) Node { return Node{ grpcAddress: grpcAddress, raftAddress: raftAddress, redisAddress: redisAddress, dynamoAddress: dynamoAddress, + sqsAddress: sqsAddress, grpcServer: grpcs, grpcService: grpcService, redisServer: rd, dynamoServer: ds, + sqsServer: sq, opsCancel: opsCancel, engine: engine, closeFactory: 
closeFactory, @@ -183,6 +202,7 @@ type listeners struct { grpc net.Listener redis net.Listener dynamo net.Listener + sqs net.Listener } func bindListeners(ctx context.Context, lc *net.ListenConfig, port portsAdress) (portsAdress, listeners, bool, error) { @@ -213,10 +233,22 @@ func bindListeners(ctx context.Context, lc *net.ListenConfig, port portsAdress) return port, listeners{}, false, errors.WithStack(err) } + sqsSock, err := lc.Listen(ctx, "tcp", port.sqsAddress) + if err != nil { + _ = grpcSock.Close() + _ = redisSock.Close() + _ = dynamoSock.Close() + if errors.Is(err, unix.EADDRINUSE) { + return port, listeners{}, true, nil + } + return port, listeners{}, false, errors.WithStack(err) + } + return port, listeners{ grpc: grpcSock, redis: redisSock, dynamo: dynamoSock, + sqs: sqsSock, }, false, nil } @@ -386,6 +418,7 @@ func setupNodes(t *testing.T, ctx context.Context, n int, ports []portsAdress) ( grpcSock := lis[i].grpc redisSock := lis[i].redis dynamoSock := lis[i].dynamo + sqsSock := lis[i].sqs result, err := factory.Create(raftengine.FactoryConfig{ LocalID: strconv.Itoa(i), @@ -433,17 +466,24 @@ func setupNodes(t *testing.T, ctx context.Context, n int, ports []portsAdress) ( assert.NoError(t, ds.Run()) }() + sq := NewSQSServer(sqsSock, routedStore, coordinator) + go func() { + assert.NoError(t, sq.Run()) + }() + nodes = append(nodes, newNode( port.grpcAddress, port.raftAddress, port.redisAddress, port.dynamoAddress, + port.sqsAddress, result.Engine, result.Close, s, gs, rd, ds, + sq, opsCancel, )) } diff --git a/main.go b/main.go index 4306388a8..2e5e3c4ac 100644 --- a/main.go +++ b/main.go @@ -79,6 +79,9 @@ var ( s3Region = flag.String("s3Region", "us-east-1", "S3 signing region") s3CredsFile = flag.String("s3CredentialsFile", "", "Path to a JSON file containing static S3 credentials") s3PathStyleOnly = flag.Bool("s3PathStyleOnly", true, "Only accept path-style S3 requests") + sqsAddr = flag.String("sqsAddress", "", "TCP host+port for SQS-compatible API; 
empty to disable") + sqsRegion = flag.String("sqsRegion", "us-east-1", "SQS signing region") + sqsCredsFile = flag.String("sqsCredentialsFile", "", "Path to a JSON file containing static SQS credentials") metricsAddr = flag.String("metricsAddress", "localhost:9090", "TCP host+port for Prometheus metrics") metricsToken = flag.String("metricsToken", "", "Bearer token for Prometheus metrics; required for non-loopback metricsAddress") pprofAddr = flag.String("pprofAddress", "localhost:6060", "TCP host+port for pprof debug endpoints; empty to disable") @@ -93,6 +96,7 @@ var ( raftRedisMap = flag.String("raftRedisMap", "", "Map of Raft address to Redis address (raftAddr=redisAddr,...)") raftS3Map = flag.String("raftS3Map", "", "Map of Raft address to S3 address (raftAddr=s3Addr,...)") raftDynamoMap = flag.String("raftDynamoMap", "", "Map of Raft address to DynamoDB address (raftAddr=dynamoAddr,...)") + raftSqsMap = flag.String("raftSqsMap", "", "Map of Raft address to SQS address (raftAddr=sqsAddr,...)") ) func main() { @@ -212,6 +216,10 @@ func run() error { s3Region: *s3Region, s3CredsFile: *s3CredsFile, s3PathStyleOnly: *s3PathStyleOnly, + sqsAddress: *sqsAddr, + leaderSQS: cfg.leaderSQS, + sqsRegion: *sqsRegion, + sqsCredsFile: *sqsCredsFile, metricsAddress: *metricsAddr, metricsToken: *metricsToken, pprofAddress: *pprofAddr, @@ -238,7 +246,7 @@ func resolveRuntimeInputs() (runtimeConfig, raftEngineType, []raftengine.Server, return runtimeConfig{}, "", nil, false, err } - cfg, err := parseRuntimeConfig(*myAddr, *redisAddr, *s3Addr, *dynamoAddr, *raftGroups, *shardRanges, *raftRedisMap, *raftS3Map, *raftDynamoMap) + cfg, err := parseRuntimeConfig(*myAddr, *redisAddr, *s3Addr, *dynamoAddr, *sqsAddr, *raftGroups, *shardRanges, *raftRedisMap, *raftS3Map, *raftDynamoMap, *raftSqsMap) if err != nil { return runtimeConfig{}, "", nil, false, err } @@ -258,10 +266,11 @@ type runtimeConfig struct { leaderRedis map[string]string leaderS3 map[string]string leaderDynamo 
map[string]string + leaderSQS map[string]string multi bool } -func parseRuntimeConfig(myAddr, redisAddr, s3Addr, dynamoAddr, raftGroups, shardRanges, raftRedisMap, raftS3Map, raftDynamoMap string) (runtimeConfig, error) { +func parseRuntimeConfig(myAddr, redisAddr, s3Addr, dynamoAddr, sqsAddr, raftGroups, shardRanges, raftRedisMap, raftS3Map, raftDynamoMap, raftSqsMap string) (runtimeConfig, error) { groups, err := parseRaftGroups(raftGroups, myAddr) if err != nil { return runtimeConfig{}, errors.Wrapf(err, "failed to parse raft groups") @@ -288,6 +297,10 @@ func parseRuntimeConfig(myAddr, redisAddr, s3Addr, dynamoAddr, raftGroups, shard if err != nil { return runtimeConfig{}, errors.Wrapf(err, "failed to parse raft dynamo map") } + leaderSQS, err := buildLeaderSQS(groups, sqsAddr, raftSqsMap) + if err != nil { + return runtimeConfig{}, errors.Wrapf(err, "failed to parse raft sqs map") + } return runtimeConfig{ groups: groups, @@ -296,6 +309,7 @@ func parseRuntimeConfig(myAddr, redisAddr, s3Addr, dynamoAddr, raftGroups, shard leaderRedis: leaderRedis, leaderS3: leaderS3, leaderDynamo: leaderDynamo, + leaderSQS: leaderSQS, multi: len(groups) > 1, }, nil } @@ -316,6 +330,10 @@ func buildLeaderS3(groups []groupSpec, s3Addr string, raftS3Map string) (map[str return buildLeaderAddrMap(groups, s3Addr, raftS3Map, parseRaftS3Map) } +func buildLeaderSQS(groups []groupSpec, sqsAddr string, raftSqsMap string) (map[string]string, error) { + return buildLeaderAddrMap(groups, sqsAddr, raftSqsMap, parseRaftSQSMap) +} + func buildLeaderDynamo(groups []groupSpec, dynamoAddr string, raftDynamoMap string) (map[string]string, error) { return buildLeaderAddrMap(groups, dynamoAddr, raftDynamoMap, parseRaftDynamoMap) } @@ -824,6 +842,10 @@ type runtimeServerRunner struct { s3Region string s3CredsFile string s3PathStyleOnly bool + sqsAddress string + leaderSQS map[string]string + sqsRegion string + sqsCredsFile string metricsAddress string metricsToken string pprofAddress string @@ -856,6 
+878,9 @@ func (r runtimeServerRunner) start() error { if err := startS3Server(r.ctx, r.lc, r.eg, r.s3Address, r.shardStore, r.coordinate, r.leaderS3, r.s3Region, r.s3CredsFile, r.s3PathStyleOnly, r.readTracker); err != nil { return waitErrgroupAfterStartupFailure(r.cancel, r.eg, err) } + if err := startSQSServer(r.ctx, r.lc, r.eg, r.sqsAddress, r.shardStore, r.coordinate, r.leaderSQS, r.sqsRegion, r.sqsCredsFile); err != nil { + return waitErrgroupAfterStartupFailure(r.cancel, r.eg, err) + } if err := startMetricsServer(r.ctx, r.lc, r.eg, r.metricsAddress, r.metricsToken, r.metricsRegistry.Handler()); err != nil { return waitErrgroupAfterStartupFailure(r.cancel, r.eg, err) } diff --git a/main_bootstrap_e2e_test.go b/main_bootstrap_e2e_test.go index b6040bb45..92474ddbc 100644 --- a/main_bootstrap_e2e_test.go +++ b/main_bootstrap_e2e_test.go @@ -340,7 +340,7 @@ func startBootstrapE2ENode( bootstrapMembers string, engineType raftEngineType, ) (*bootstrapE2ENode, error) { - cfg, err := parseRuntimeConfig(ep.raftAddr, ep.redisAddr, "", "", "", "", "", "", "") + cfg, err := parseRuntimeConfig(ep.raftAddr, ep.redisAddr, "", "", "", "", "", "", "", "", "") if err != nil { return nil, err } diff --git a/main_s3.go b/main_s3.go index 7fae6f16d..5de290c7f 100644 --- a/main_s3.go +++ b/main_s3.go @@ -2,10 +2,7 @@ package main import ( "context" - "encoding/json" - "fmt" "net" - "os" "strings" "github.com/bootjp/elastickv/adapter" @@ -14,15 +11,6 @@ import ( "golang.org/x/sync/errgroup" ) -type s3CredentialFile struct { - Credentials []s3CredentialEntry `json:"credentials"` -} - -type s3CredentialEntry struct { - AccessKeyID string `json:"access_key_id"` - SecretAccessKey string `json:"secret_access_key"` -} - func startS3Server( ctx context.Context, lc *net.ListenConfig, @@ -83,30 +71,5 @@ func startS3Server( } func loadS3StaticCredentials(path string) (map[string]string, error) { - path = strings.TrimSpace(path) - if path == "" { - return nil, nil - } - f, err := 
os.Open(path) - if err != nil { - return nil, errors.WithStack(err) - } - defer f.Close() - file := s3CredentialFile{} - if err := json.NewDecoder(f).Decode(&file); err != nil { - return nil, errors.WithStack(err) - } - out := make(map[string]string, len(file.Credentials)) - for _, cred := range file.Credentials { - accessKeyID := strings.TrimSpace(cred.AccessKeyID) - secretAccessKey := strings.TrimSpace(cred.SecretAccessKey) - if accessKeyID == "" || secretAccessKey == "" { - return nil, errors.New("s3 credentials file contains an empty access key or secret key") - } - if _, exists := out[accessKeyID]; exists { - return nil, errors.WithStack(fmt.Errorf("s3 credentials file contains duplicate access key ID: %q", accessKeyID)) - } - out[accessKeyID] = secretAccessKey - } - return out, nil + return loadSigV4StaticCredentialsFile(path, "s3") } diff --git a/main_sigv4_creds.go b/main_sigv4_creds.go new file mode 100644 index 000000000..9ab517bf7 --- /dev/null +++ b/main_sigv4_creds.go @@ -0,0 +1,58 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + "strings" + + "github.com/cockroachdb/errors" +) + +// sigV4CredentialsFile is the JSON schema both --s3CredentialsFile and +// --sqsCredentialsFile read. Sharing the schema means operators maintain one +// file per deployment regardless of which adapters are enabled. +type sigV4CredentialsFile struct { + Credentials []sigV4CredentialEntry `json:"credentials"` +} + +type sigV4CredentialEntry struct { + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` +} + +// loadSigV4StaticCredentialsFile parses a credentials file and returns a +// map of access-key → secret suitable for WithS3StaticCredentials or +// WithSQSStaticCredentials. An empty path returns a nil map so the caller +// can leave authorization disabled. +// +// labelForErrors appears in error text ("s3 credentials file ...") so +// operators get context on which flag produced the problem. 
+func loadSigV4StaticCredentialsFile(path string, labelForErrors string) (map[string]string, error) { + path = strings.TrimSpace(path) + if path == "" { + return nil, nil + } + f, err := os.Open(path) + if err != nil { + return nil, errors.WithStack(err) + } + defer f.Close() + file := sigV4CredentialsFile{} + if err := json.NewDecoder(f).Decode(&file); err != nil { + return nil, errors.WithStack(err) + } + out := make(map[string]string, len(file.Credentials)) + for _, cred := range file.Credentials { + accessKeyID := strings.TrimSpace(cred.AccessKeyID) + secretAccessKey := strings.TrimSpace(cred.SecretAccessKey) + if accessKeyID == "" || secretAccessKey == "" { + return nil, errors.WithStack(fmt.Errorf("%s credentials file contains an empty access key or secret key", labelForErrors)) + } + if _, exists := out[accessKeyID]; exists { + return nil, errors.WithStack(fmt.Errorf("%s credentials file contains duplicate access key ID: %q", labelForErrors, accessKeyID)) + } + out[accessKeyID] = secretAccessKey + } + return out, nil +} diff --git a/main_sqs.go b/main_sqs.go new file mode 100644 index 000000000..55ca41684 --- /dev/null +++ b/main_sqs.go @@ -0,0 +1,67 @@ +package main + +import ( + "context" + "net" + "strings" + + "github.com/bootjp/elastickv/adapter" + "github.com/bootjp/elastickv/kv" + "github.com/cockroachdb/errors" + "golang.org/x/sync/errgroup" +) + +func startSQSServer( + ctx context.Context, + lc *net.ListenConfig, + eg *errgroup.Group, + sqsAddr string, + shardStore *kv.ShardStore, + coordinate kv.Coordinator, + leaderSQS map[string]string, + region string, + credentialsFile string, +) error { + sqsAddr = strings.TrimSpace(sqsAddr) + if sqsAddr == "" { + return nil + } + sqsL, err := lc.Listen(ctx, "tcp", sqsAddr) + if err != nil { + return errors.Wrapf(err, "failed to listen on %s", sqsAddr) + } + staticCreds, err := loadSigV4StaticCredentialsFile(credentialsFile, "sqs") + if err != nil { + _ = sqsL.Close() + return err + } + sqsServer := 
adapter.NewSQSServer( + sqsL, + shardStore, + coordinate, + adapter.WithSQSLeaderMap(leaderSQS), + adapter.WithSQSRegion(region), + adapter.WithSQSStaticCredentials(staticCreds), + ) + // Two-goroutine shutdown pattern mirrors startS3Server: one goroutine waits + // on either ctx.Done() or Run completion to call Stop, the other runs the + // server and cancels the waiter once it has returned. + runDoneCtx, runDoneCancel := context.WithCancel(context.Background()) + eg.Go(func() error { + select { + case <-ctx.Done(): + sqsServer.Stop() + case <-runDoneCtx.Done(): + } + return nil + }) + eg.Go(func() error { + err := sqsServer.Run() + runDoneCancel() + if err == nil || errors.Is(err, net.ErrClosed) { + return nil + } + return errors.WithStack(err) + }) + return nil +} diff --git a/shard_config.go b/shard_config.go index a7353038b..eee13d204 100644 --- a/shard_config.go +++ b/shard_config.go @@ -32,6 +32,7 @@ var ( ErrInvalidRaftRedisMapEntry = errors.New("invalid raftRedisMap entry") ErrInvalidRaftS3MapEntry = errors.New("invalid raftS3Map entry") ErrInvalidRaftDynamoMapEntry = errors.New("invalid raftDynamoMap entry") + ErrInvalidRaftSQSMapEntry = errors.New("invalid raftSqsMap entry") ErrInvalidRaftBootstrapMembersEntry = errors.New("invalid raftBootstrapMembers entry") ) @@ -128,6 +129,10 @@ func parseRaftDynamoMap(raw string) (map[string]string, error) { return parseRaftAddressMap(raw, ErrInvalidRaftDynamoMapEntry) } +func parseRaftSQSMap(raw string) (map[string]string, error) { + return parseRaftAddressMap(raw, ErrInvalidRaftSQSMapEntry) +} + func parseRaftAddressMap(raw string, invalidEntry error) (map[string]string, error) { out := make(map[string]string) if raw == "" {