diff --git a/README.md b/README.md index e3b4d50..e33413b 100644 --- a/README.md +++ b/README.md @@ -13,15 +13,16 @@ HTTP API providing user/client message handling for an fmsg host. Exposes CRUD o | `FMSG_JWT_ISSUER` | *(prod, required with JWKS)* | Expected `iss` claim value (e.g. `https://idp.fmsg.io`). Tokens with a different issuer are rejected. | | `FMSG_JWT_AUDIENCE` | *(optional)* | When set, tokens must include this value in their `aud` claim. | | `FMSG_API_JWT_SECRET` | *(dev)* | HMAC secret for HS256 token verification. Used only in dev mode (when `FMSG_JWT_JWKS_URL` is unset). Prefix with `base64:` to supply a base64-encoded key. Either this or `FMSG_JWT_JWKS_URL` must be set. | -| `FMSG_TLS_CERT` | *(optional)* | Path to the TLS certificate file (e.g. `/etc/letsencrypt/live/example.com/fullchain.pem`). When set with `FMSG_TLS_KEY`, enables HTTPS on port 443. | +| `FMSG_TLS_CERT` | *(optional)* | Path to the TLS certificate file (e.g. `/etc/letsencrypt/live/example.com/fullchain.pem`). When set with `FMSG_TLS_KEY`, enables HTTPS. | | `FMSG_TLS_KEY` | *(optional)* | Path to the TLS private key file (e.g. `/etc/letsencrypt/live/example.com/privkey.pem`). Must be set together with `FMSG_TLS_CERT`. | -| `FMSG_API_PORT` | `8000` | TCP port for plain HTTP mode (ignored when TLS is enabled) | +| `FMSG_API_PORT` | `443` (TLS) / `8000` (plain) | TCP port to listen on. | | `FMSG_ID_URL` | `http://127.0.0.1:8080` | Base URL of the fmsgid identity service | | `FMSG_API_RATE_LIMIT`| `10` | Max sustained requests per second per IP | | `FMSG_API_RATE_BURST`| `20` | Max burst size for the per-IP rate limiter | | `FMSG_API_MAX_DATA_SIZE`| `10` | Maximum message data size in megabytes | | `FMSG_API_MAX_ATTACH_SIZE`| `10` | Maximum attachment file size in megabytes | | `FMSG_API_MAX_MSG_SIZE`| `20` | Maximum total message size (data + attachments) in megabytes | +| `FMSG_CORS_ORIGINS` | *(optional)* | Comma-separated list of browser origins allowed via CORS, e.g. `https://fmsg.io,https://www.fmsg.io`. Use `*` to allow any origin. When unset, no CORS headers are emitted (server-to-server callers are unaffected). | Standard PostgreSQL environment variables (`PGHOST`, `PGPORT`, `PGUSER`, `PGPASSWORD`, `PGDATABASE`) are used for database connectivity. @@ -85,7 +86,8 @@ go test ./... ### TLS mode (production) -Set `FMSG_TLS_CERT` and `FMSG_TLS_KEY` to enable HTTPS on port `443`. +Set `FMSG_TLS_CERT` and `FMSG_TLS_KEY` to enable HTTPS. Listens on port `443` +by default; override with `FMSG_API_PORT`. ```bash export FMSG_DATA_DIR=/opt/fmsg/data @@ -107,6 +109,10 @@ go run . Omit the TLS variables to run a plain HTTP server. Override the port with `FMSG_API_PORT` (default `8000`). +This is the recommended mode when fronting fmsg-webapi with Apache, nginx, or +any other reverse proxy that already terminates TLS (e.g. Apache on `:443` +proxying `https://fmsgapi.example.com/` to `http://127.0.0.1:8000/`). + ```bash export FMSG_DATA_DIR=/var/lib/fmsgd/ export FMSG_API_JWT_SECRET=changeme diff --git a/src/main.go b/src/main.go index e0b4d9f..5966ea1 100644 --- a/src/main.go +++ b/src/main.go @@ -51,6 +51,10 @@ func main() { maxAttachSize := int64(envOrDefaultInt("FMSG_API_MAX_ATTACH_SIZE", 10)) * 1024 * 1024 maxMsgSize := int64(envOrDefaultInt("FMSG_API_MAX_MSG_SIZE", 20)) * 1024 * 1024 + // CORS: comma-separated list of allowed browser origins, e.g. + // "https://fmsg.io,https://www.fmsg.io". Empty disables CORS. + corsOrigins := parseCSV(os.Getenv("FMSG_CORS_ORIGINS")) + // Connect to PostgreSQL (uses standard PG* environment variables). ctx := context.Background() database, err := db.New(ctx, "") @@ -73,6 +77,16 @@ func main() { // Create Gin router. router := gin.Default() + // CORS must run before authentication so that browser preflight (OPTIONS) + // requests, which do not carry the Authorization header, are answered + // directly instead of being rejected by the JWT middleware. + if len(corsOrigins) > 0 { + corsCfg := middleware.DefaultCORSConfig() + corsCfg.AllowedOrigins = corsOrigins + router.Use(middleware.NewCORS(corsCfg)) + log.Printf("CORS enabled for origins: %s", strings.Join(corsOrigins, ", ")) + } + // Global rate limiter. router.Use(middleware.NewRateLimiter(ctx, float64(rateLimit), rateBurst)) @@ -109,8 +123,9 @@ func main() { } if tlsEnabled { - srv.Addr = ":443" - log.Println("fmsg-webapi starting on :443") + port := envOrDefault("FMSG_API_PORT", "443") + srv.Addr = ":" + port + log.Printf("fmsg-webapi starting on :%s (HTTPS)", port) srv.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} if err = srv.ListenAndServeTLS(tlsCert, tlsKey); err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatalf("server error: %v", err) @@ -134,6 +149,21 @@ func mustEnv(key string) string { return v } +// parseCSV splits a comma-separated string into trimmed, non-empty values. +func parseCSV(s string) []string { + if s == "" { + return nil + } + parts := strings.Split(s, ",") + out := make([]string, 0, len(parts)) + for _, p := range parts { + if v := strings.TrimSpace(p); v != "" { + out = append(out, v) + } + } + return out +} + // envOrDefault returns the environment variable value or defaultValue when unset. func envOrDefault(key, defaultValue string) string { if v := os.Getenv(key); v != "" { diff --git a/src/middleware/cors.go b/src/middleware/cors.go new file mode 100644 index 0000000..7ba4272 --- /dev/null +++ b/src/middleware/cors.go @@ -0,0 +1,123 @@ +package middleware + +import ( + "net/http" + "strconv" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +// CORSConfig configures the CORS middleware. +type CORSConfig struct { + // AllowedOrigins is the list of exact origins permitted to access the API + // from a browser, e.g. "https://fmsg.io". A single entry of "*" allows any + // origin (only valid when credentials are not used). An empty list + // disables CORS entirely. + AllowedOrigins []string + // AllowedMethods are the HTTP methods returned in the preflight response. + AllowedMethods []string + // AllowedHeaders are the request headers returned in the preflight response. + AllowedHeaders []string + // MaxAge controls how long browsers may cache the preflight result. + MaxAge time.Duration +} + +// DefaultCORSConfig returns a CORSConfig populated with values appropriate for +// this API: GET/POST/PUT/DELETE/OPTIONS plus Authorization and Content-Type +// request headers, with a 10 minute preflight cache. Callers must still set +// AllowedOrigins. +func DefaultCORSConfig() CORSConfig { + return CORSConfig{ + AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, + AllowedHeaders: []string{"Authorization", "Content-Type"}, + MaxAge: 10 * time.Minute, + } +} + +// NewCORS returns a Gin middleware that handles CORS preflight requests and +// adds the Access-Control-Allow-* headers to matching cross-origin responses. +// +// Behaviour: +// - Requests without an Origin header pass through untouched. +// - When Origin matches an entry in AllowedOrigins (or AllowedOrigins is +// {"*"}), the appropriate Access-Control-Allow-* headers are added. +// - OPTIONS preflight requests are short-circuited with 204 so they never +// reach downstream auth middleware (which would reject them for missing +// the Authorization header). +// - When Origin is present but not allowed, the request is allowed to +// continue without CORS headers; the browser will then block the +// response, which is the standard CORS failure mode. +func NewCORS(cfg CORSConfig) gin.HandlerFunc { + if len(cfg.AllowedOrigins) == 0 { + // CORS disabled; return a no-op middleware. + return func(c *gin.Context) { c.Next() } + } + + trimmedOrigins := make([]string, 0, len(cfg.AllowedOrigins)) + allowed := make(map[string]struct{}, len(cfg.AllowedOrigins)) + for _, o := range cfg.AllowedOrigins { + origin := strings.TrimSpace(o) + if origin == "" { + continue + } + trimmedOrigins = append(trimmedOrigins, origin) + allowed[origin] = struct{}{} + } + if len(trimmedOrigins) == 0 { + // CORS disabled; return a no-op middleware. + return func(c *gin.Context) { c.Next() } + } + allowAny := len(trimmedOrigins) == 1 && trimmedOrigins[0] == "*" + + methods := strings.Join(cfg.AllowedMethods, ", ") + headers := strings.Join(cfg.AllowedHeaders, ", ") + maxAge := strconv.Itoa(int(cfg.MaxAge.Seconds())) + + return func(c *gin.Context) { + origin := c.GetHeader("Origin") + if origin == "" { + c.Next() + return + } + + // Always advertise that the response varies by Origin so caches + // (browser + intermediaries) don't serve a response keyed only on + // the URL across different origins. + c.Writer.Header().Add("Vary", "Origin") + + _, ok := allowed[origin] + if !ok && !allowAny { + // Not an allowed origin. Don't add CORS headers; let the request + // proceed (the browser will block the response). + c.Next() + return + } + + if allowAny { + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + } else { + c.Writer.Header().Set("Access-Control-Allow-Origin", origin) + } + + if c.Request.Method == http.MethodOptions { + // Preflight. + c.Writer.Header().Add("Vary", "Access-Control-Request-Method") + c.Writer.Header().Add("Vary", "Access-Control-Request-Headers") + if methods != "" { + c.Writer.Header().Set("Access-Control-Allow-Methods", methods) + } + if headers != "" { + c.Writer.Header().Set("Access-Control-Allow-Headers", headers) + } + if cfg.MaxAge > 0 { + c.Writer.Header().Set("Access-Control-Max-Age", maxAge) + } + c.AbortWithStatus(http.StatusNoContent) + return + } + + c.Next() + } +} diff --git a/src/middleware/cors_test.go b/src/middleware/cors_test.go new file mode 100644 index 0000000..2b648cf --- /dev/null +++ b/src/middleware/cors_test.go @@ -0,0 +1,129 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func newCORSTestRouter(origins []string) *gin.Engine { + r := gin.New() + cfg := DefaultCORSConfig() + cfg.AllowedOrigins = origins + r.Use(NewCORS(cfg)) + r.GET("/x", func(c *gin.Context) { c.String(http.StatusOK, "ok") }) + r.POST("/x", func(c *gin.Context) { c.String(http.StatusOK, "ok") }) + return r +} + +func TestCORS_NoOriginPassesThrough(t *testing.T) { + r := newCORSTestRouter([]string{"https://fmsg.io"}) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/x", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" { + t.Errorf("Access-Control-Allow-Origin = %q, want empty", got) + } +} + +func TestCORS_AllowedOriginGetsHeaders(t *testing.T) { + r := newCORSTestRouter([]string{"https://fmsg.io"}) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/x", nil) + req.Header.Set("Origin", "https://fmsg.io") + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "https://fmsg.io" { + t.Errorf("Access-Control-Allow-Origin = %q, want https://fmsg.io", got) + } + if got := w.Header().Get("Vary"); got == "" { + t.Errorf("Vary header missing") + } +} + +func TestCORS_DisallowedOriginGetsNoHeaders(t *testing.T) { + r := newCORSTestRouter([]string{"https://fmsg.io"}) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/x", nil) + req.Header.Set("Origin", "https://evil.example") + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status = %d, want 200", w.Code) + } + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" { + t.Errorf("Access-Control-Allow-Origin = %q, want empty", got) + } +} + +func TestCORS_PreflightShortCircuits(t *testing.T) { + r := gin.New() + cfg := DefaultCORSConfig() + cfg.AllowedOrigins = []string{"https://fmsg.io"} + r.Use(NewCORS(cfg)) + // Downstream middleware that would reject if reached. + r.Use(func(c *gin.Context) { + c.AbortWithStatus(http.StatusUnauthorized) + }) + r.POST("/x", func(c *gin.Context) { c.String(http.StatusOK, "ok") }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodOptions, "/x", nil) + req.Header.Set("Origin", "https://fmsg.io") + req.Header.Set("Access-Control-Request-Method", "POST") + req.Header.Set("Access-Control-Request-Headers", "Authorization, Content-Type") + r.ServeHTTP(w, req) + + if w.Code != http.StatusNoContent { + t.Fatalf("status = %d, want 204", w.Code) + } + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "https://fmsg.io" { + t.Errorf("Access-Control-Allow-Origin = %q", got) + } + if got := w.Header().Get("Access-Control-Allow-Methods"); got == "" { + t.Errorf("Access-Control-Allow-Methods missing") + } + if got := w.Header().Get("Access-Control-Allow-Headers"); got == "" { + t.Errorf("Access-Control-Allow-Headers missing") + } + if got := w.Header().Get("Access-Control-Max-Age"); got == "" { + t.Errorf("Access-Control-Max-Age missing") + } +} + +func TestCORS_Wildcard(t *testing.T) { + r := newCORSTestRouter([]string{"*"}) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/x", nil) + req.Header.Set("Origin", "https://anything.example") + r.ServeHTTP(w, req) + + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" { + t.Errorf("Access-Control-Allow-Origin = %q, want *", got) + } +} + +func TestCORS_DisabledWhenNoOrigins(t *testing.T) { + r := newCORSTestRouter(nil) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/x", nil) + req.Header.Set("Origin", "https://fmsg.io") + r.ServeHTTP(w, req) + + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" { + t.Errorf("Access-Control-Allow-Origin = %q, want empty", got) + } +}