diff --git a/echo.go b/echo.go index 5e706f8bd..bb7ce4aab 100644 --- a/echo.go +++ b/echo.go @@ -54,6 +54,7 @@ import ( "os" "os/signal" "path/filepath" + "strconv" "strings" "sync" "sync/atomic" @@ -100,6 +101,7 @@ type Echo struct { // formParseMaxMemory is passed to Context for multipart form parsing (See http.Request.ParseMultipartForm) formParseMaxMemory int64 + AutoHead bool } // JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. @@ -288,6 +290,11 @@ type Config struct { // FormParseMaxMemory is default value for memory limit that is used // when parsing multipart forms (See (*http.Request).ParseMultipartForm) FormParseMaxMemory int64 + + // AutoHead enables automatic registration of HEAD routes for GET routes. + // When enabled, a HEAD request to a GET-only path will be handled automatically + // using the same handler as GET, with the response body suppressed. + AutoHead bool } // NewWithConfig creates an instance of Echo with given configuration. @@ -326,6 +333,9 @@ func NewWithConfig(config Config) *Echo { if config.FormParseMaxMemory > 0 { e.formParseMaxMemory = config.FormParseMaxMemory } + if config.AutoHead { + e.AutoHead = config.AutoHead + } return e } @@ -421,6 +431,67 @@ func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler { } } +// headResponseWriter wraps an http.ResponseWriter and suppresses the response body +// while preserving headers and status code. Used for automatic HEAD route handling. +// It counts the bytes that would have been written so we can set Content-Length accurately. +type headResponseWriter struct { + http.ResponseWriter + bytesWritten int64 + statusCode int + wroteHeader bool +} + +// Write intercepts writes to the response body and counts bytes without actually writing them. +func (hw *headResponseWriter) Write(b []byte) (int, error) { + if !hw.wroteHeader { + hw.statusCode = http.StatusOK + hw.wroteHeader = true + } + hw.bytesWritten += int64(len(b)) + // Return success without actually writing the body for HEAD requests + return len(b), nil +} + +// WriteHeader intercepts the status code but still writes it to the underlying ResponseWriter. +func (hw *headResponseWriter) WriteHeader(statusCode int) { + if !hw.wroteHeader { + hw.statusCode = statusCode + hw.wroteHeader = true + hw.ResponseWriter.WriteHeader(statusCode) + } +} + +// Unwrap returns the underlying http.ResponseWriter for compatibility with echo.Response unwrapping. +func (hw *headResponseWriter) Unwrap() http.ResponseWriter { + return hw.ResponseWriter +} + +func wrapHeadHandler(handler HandlerFunc) HandlerFunc { + return func(c *Context) error { + if c.Request().Method != http.MethodHead { + return handler(c) + } + originalWriter := c.Response() + headWriter := &headResponseWriter{ResponseWriter: originalWriter} + + c.SetResponse(headWriter) + defer func() { + c.SetResponse(originalWriter) + }() + err := handler(c) + + if headWriter.bytesWritten > 0 { + originalWriter.Header().Set("Content-Length", strconv.FormatInt(headWriter.bytesWritten, 10)) + } + + if !headWriter.wroteHeader && headWriter.statusCode > 0 { + originalWriter.WriteHeader(headWriter.statusCode) + } + + return err + } +} + // Pre adds middleware to the chain which is run before router tries to find matching route. // Meaning middleware is executed even for 404 (not found) cases. func (e *Echo) Pre(middleware ...MiddlewareFunc) { @@ -634,6 +705,20 @@ func (e *Echo) add(route Route) (RouteInfo, error) { if paramsCount > e.contextPathParamAllocSize.Load() { e.contextPathParamAllocSize.Store(paramsCount) } + + // Auto-register HEAD route for GET if AutoHead is enabled + if e.AutoHead && route.Method == http.MethodGet { + headRoute := Route{ + Method: http.MethodHead, + Path: route.Path, + Handler: wrapHeadHandler(route.Handler), + Middlewares: route.Middlewares, + Name: route.Name, + } + // Attempt to add HEAD route, but ignore errors if an explicit HEAD route already exists + _, _ = e.router.Add(headRoute) + } + return ri, nil } @@ -642,6 +727,7 @@ func (e *Echo) add(route Route) (RouteInfo, error) { func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { ri, err := e.add( Route{ + Method: method, Path: path, Handler: handler, diff --git a/echo_test.go b/echo_test.go index b5045e111..950a52042 100644 --- a/echo_test.go +++ b/echo_test.go @@ -1233,6 +1233,159 @@ func TestDefaultHTTPErrorHandler_CommitedResponse(t *testing.T) { assert.Equal(t, http.StatusOK, resp.Code) } +func TestAutoHeadRoute(t *testing.T) { + tests := []struct { + name string + autoHead bool + method string + wantBody bool + wantCode int + wantCLen bool // expect Content-Length header + }{ + { + name: "AutoHead disabled - HEAD returns 405", + autoHead: false, + method: http.MethodHead, + wantCode: http.StatusMethodNotAllowed, + wantBody: false, + }, + { + name: "AutoHead enabled - HEAD returns 200 with Content-Length", + autoHead: true, + method: http.MethodHead, + wantCode: http.StatusOK, + wantBody: false, + wantCLen: true, + }, + { + name: "GET request works normally with AutoHead enabled", + autoHead: true, + method: http.MethodGet, + wantCode: http.StatusOK, + wantBody: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create Echo instance with AutoHead configuration + e := New() + e.AutoHead = tt.autoHead + + // Register a simple GET route + testBody := "Hello, World!" + e.GET("/hello", func(c *Context) error { + return c.String(http.StatusOK, testBody) + }) + + // Create request and response + req := httptest.NewRequest(tt.method, "/hello", nil) + rec := httptest.NewRecorder() + + // Serve the request + e.ServeHTTP(rec, req) + + // Verify status code + if rec.Code != tt.wantCode { + t.Errorf("expected status %d, got %d", tt.wantCode, rec.Code) + } + + // Verify response body + if tt.wantBody { + if rec.Body.String() != testBody { + t.Errorf("expected body %q, got %q", testBody, rec.Body.String()) + } + } else { + if rec.Body.String() != "" { + t.Errorf("expected empty body for HEAD, got %q", rec.Body.String()) + } + } + + // Verify Content-Length header for HEAD + if tt.wantCLen && tt.method == http.MethodHead { + clen := rec.Header().Get("Content-Length") + if clen == "" { + t.Error("expected Content-Length header for HEAD request") + } + } + }) + } +} + +func TestAutoHeadExplicitHeadTakesPrecedence(t *testing.T) { + e := New() + e.AutoHead = true + + // Register explicit HEAD route FIRST with custom behavior + e.HEAD("/api/users", func(c *Context) error { + c.Response().Header().Set("X-Custom-Header", "explicit-head") + return c.NoContent(http.StatusOK) + }) + + // Then register GET route - AutoHead will try to add a HEAD route but fail silently + // since one already exists + e.GET("/api/users", func(c *Context) error { + return c.JSON(http.StatusOK, map[string]string{"name": "John"}) + }) + + // Test that the explicit HEAD route behavior is preserved + req := httptest.NewRequest(http.MethodHead, "/api/users", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", rec.Code) + } + + if rec.Header().Get("X-Custom-Header") != "explicit-head" { + t.Error("expected explicit HEAD route to be used") + } + + // Verify body is empty + if rec.Body.String() != "" { + t.Errorf("expected empty body for HEAD, got %q", rec.Body.String()) + } +} + +func TestAutoHeadWithMiddleware(t *testing.T) { + e := New() + e.AutoHead = true + + // Add request logger middleware + middlewareExecuted := false + e.Use(func(next HandlerFunc) HandlerFunc { + return func(c *Context) error { + middlewareExecuted = true + c.Response().Header().Set("X-Middleware", "executed") + return next(c) + } + }) + + // Register GET route + e.GET("/test", func(c *Context) error { + return c.String(http.StatusOK, "test response") + }) + + // Test HEAD request goes through middleware + req := httptest.NewRequest(http.MethodHead, "/test", nil) + rec := httptest.NewRecorder() + + middlewareExecuted = false + e.ServeHTTP(rec, req) + + if !middlewareExecuted { + t.Error("middleware should execute for automatic HEAD route") + } + + if rec.Header().Get("X-Middleware") != "executed" { + t.Error("middleware header not set") + } + + if rec.Body.String() != "" { + t.Errorf("expected empty body for HEAD, got %q", rec.Body.String()) + } +} + func benchmarkEchoRoutes(b *testing.B, routes []testRoute) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -1278,3 +1431,23 @@ func BenchmarkEchoGitHubAPIMisses(b *testing.B) { func BenchmarkEchoParseAPI(b *testing.B) { benchmarkEchoRoutes(b, parseAPI) } + +func BenchmarkAutoHeadRoute(b *testing.B) { + e := New() + e.AutoHead = true + + e.GET("/bench", func(c *Context) error { + return c.String(http.StatusOK, "benchmark response body") + }) + + req := httptest.NewRequest(http.MethodHead, "/bench", nil) + rec := httptest.NewRecorder() + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rec.Body.Reset() + e.ServeHTTP(rec, req) + } +}