diff --git a/common.go b/common.go index 32efcf1..84264c1 100644 --- a/common.go +++ b/common.go @@ -175,7 +175,17 @@ func isZero(v interface{}) bool { // Parse path parameters func parsePathParams(c *fiber.Ctx, input interface{}) error { inputValue := reflect.ValueOf(input).Elem() - inputType := reflect.TypeOf(input).Elem() + inputType := dereferenceType(reflect.TypeOf(input).Elem()) + + if inputType.Kind() != reflect.Struct { + return nil + } + if inputValue.Kind() == reflect.Ptr { + if inputValue.IsNil() { + return nil + } + inputValue = inputValue.Elem() + } for i := 0; i < inputType.NumField(); i++ { field := inputType.Field(i) @@ -196,7 +206,17 @@ func parsePathParams(c *fiber.Ctx, input interface{}) error { // Parse query parameters func parseQueryParams(c *fiber.Ctx, input interface{}) error { inputValue := reflect.ValueOf(input).Elem() - inputType := reflect.TypeOf(input).Elem() + inputType := dereferenceType(reflect.TypeOf(input).Elem()) + + if inputType.Kind() != reflect.Struct { + return nil + } + if inputValue.Kind() == reflect.Ptr { + if inputValue.IsNil() { + return nil + } + inputValue = inputValue.Elem() + } for i := 0; i < inputType.NumField(); i++ { field := inputType.Field(i) @@ -219,7 +239,17 @@ func parseQueryParams(c *fiber.Ctx, input interface{}) error { // Parse header parameters func parseHeaderParams(c *fiber.Ctx, input interface{}) error { inputValue := reflect.ValueOf(input).Elem() - inputType := reflect.TypeOf(input).Elem() + inputType := dereferenceType(reflect.TypeOf(input).Elem()) + + if inputType.Kind() != reflect.Struct { + return nil + } + if inputValue.Kind() == reflect.Ptr { + if inputValue.IsNil() { + return nil + } + inputValue = inputValue.Elem() + } for i := 0; i < inputType.NumField(); i++ { field := inputType.Field(i) diff --git a/fiberoapi.go b/fiberoapi.go index c50c03f..dfa9e42 100644 --- a/fiberoapi.go +++ b/fiberoapi.go @@ -283,8 +283,7 @@ func (o *OApiApp) GenerateOpenAPISpec() map[string]interface{} { var schemaRef map[string]interface{} - // For map types, use inline schema instead of reference - if inputType.Kind() == reflect.Map { + if shouldInlineOperationSchema(inputType) { schemaRef = generateSchema(inputType) } else { inputSchemaName := getTypeName(inputType) @@ -313,8 +312,7 @@ func (o *OApiApp) GenerateOpenAPISpec() map[string]interface{} { var schemaRef map[string]interface{} - // For map types, use inline schema instead of reference - if outputType.Kind() == reflect.Map { + if shouldInlineOperationSchema(outputType) { schemaRef = generateSchema(outputType) } else { outputSchemaName := getTypeName(outputType) @@ -335,15 +333,22 @@ func (o *OApiApp) GenerateOpenAPISpec() map[string]interface{} { // Error response (400/500) if op.ErrorType != nil && !isEmptyStruct(op.ErrorType) { - errorSchemaName := getTypeName(op.ErrorType) + errorType := dereferenceType(op.ErrorType) + + var schemaRef map[string]interface{} + if shouldInlineOperationSchema(errorType) { + schemaRef = generateSchema(errorType) + } else { + schemaRef = map[string]interface{}{ + "$ref": "#/components/schemas/" + getTypeName(errorType), + } + } responses["400"] = map[string]interface{}{ "description": "Validation error", "content": map[string]interface{}{ "application/json": map[string]interface{}{ - "schema": map[string]interface{}{ - "$ref": "#/components/schemas/" + errorSchemaName, - }, + "schema": schemaRef, }, }, } @@ -581,6 +586,33 @@ func isTimeType(t reflect.Type) bool { return t != nil && t.Kind() == reflect.Struct && t.Name() == "Time" && t.PkgPath() == "time" } +// shouldInlineOperationSchema reports whether a top-level request, response, +// or error body schema should be inlined rather than emitted as a $ref to +// #/components/schemas/. It mirrors the registration logic in collectAllTypes: +// types that never end up in the components map (primitives, maps, time.Time, +// slices of inlinable elements) must be inlined to avoid dangling $refs. +func shouldInlineOperationSchema(t reflect.Type) bool { + if t == nil { + return false + } + t = dereferenceType(t) + if isTimeType(t) { + return true + } + switch t.Kind() { + case reflect.Map: + return true + case reflect.String, reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + return true + case reflect.Slice: + return shouldInlineOperationSchema(t.Elem()) + } + return false +} + // generateSchema generates an OpenAPI schema from a Go type func generateSchema(t reflect.Type) map[string]interface{} { if t == nil { diff --git a/time_type_test.go b/time_type_test.go index d134b94..45e0b74 100644 --- a/time_type_test.go +++ b/time_type_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "io" "net/http/httptest" + "strings" "testing" "time" @@ -119,3 +120,213 @@ func TestPointerTimeTypeRendersAsDateTimeString(t *testing.T) { t.Errorf("Expected *time.Time to render as string/date-time, got %v", startedAt) } } + +func TestTimeTypeAsTopLevelHandlerRuntime(t *testing.T) { + app := fiber.New() + oapi := New(app, Config{ + EnableValidation: false, + EnableOpenAPIDocs: true, + OpenAPIDocsPath: "/docs", + }) + + Post(oapi, "/timestamp", func(c *fiber.Ctx, req time.Time) (time.Time, *ErrorResponse) { + return req, nil + }, OpenAPIOptions{OperationID: "echoTimestamp"}) + + body := strings.NewReader(`"2024-01-15T10:30:00Z"`) + req := httptest.NewRequest("POST", "/timestamp", body) + req.Header.Set("Content-Type", "application/json") + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if resp.StatusCode != 200 { + respBody, _ := io.ReadAll(resp.Body) + t.Fatalf("Expected status 200, got %d: %s", resp.StatusCode, string(respBody)) + } + + respBody, _ := io.ReadAll(resp.Body) + var got time.Time + if err := json.Unmarshal(respBody, &got); err != nil { + t.Fatalf("Expected response to be a JSON time string, got %s (err: %v)", string(respBody), err) + } + want, _ := time.Parse(time.RFC3339, "2024-01-15T10:30:00Z") + if !got.Equal(want) { + t.Errorf("Expected echoed time %v, got %v", want, got) + } +} + +func TestTimeTypeAsTopLevelInputAndOutput(t *testing.T) { + app := fiber.New() + oapi := New(app) + + Post(oapi, "/timestamp", func(c *fiber.Ctx, req *time.Time) (*time.Time, *ErrorResponse) { + return req, nil + }, OpenAPIOptions{ + OperationID: "echoTimestamp", + Tags: []string{"timestamps"}, + }) + + oapi.SetupDocs() + + req := httptest.NewRequest("GET", "/openapi.json", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + body, _ := io.ReadAll(resp.Body) + var spec map[string]interface{} + if err := json.Unmarshal(body, &spec); err != nil { + t.Fatalf("Failed to parse OpenAPI JSON: %v", err) + } + + paths := spec["paths"].(map[string]interface{}) + timestamp := paths["/timestamp"].(map[string]interface{}) + post := timestamp["post"].(map[string]interface{}) + + reqBody := post["requestBody"].(map[string]interface{}) + reqSchema := reqBody["content"].(map[string]interface{})["application/json"].(map[string]interface{})["schema"].(map[string]interface{}) + if _, hasRef := reqSchema["$ref"]; hasRef { + t.Errorf("Expected top-level time.Time request body to be inlined, got $ref: %v", reqSchema["$ref"]) + } + if reqSchema["type"] != "string" || reqSchema["format"] != "date-time" { + t.Errorf("Expected request body schema to be string/date-time, got %v", reqSchema) + } + + respSchema := post["responses"].(map[string]interface{})["200"].(map[string]interface{})["content"].(map[string]interface{})["application/json"].(map[string]interface{})["schema"].(map[string]interface{}) + if _, hasRef := respSchema["$ref"]; hasRef { + t.Errorf("Expected top-level time.Time response to be inlined, got $ref: %v", respSchema["$ref"]) + } + if respSchema["type"] != "string" || respSchema["format"] != "date-time" { + t.Errorf("Expected response schema to be string/date-time, got %v", respSchema) + } + + if schemas, ok := spec["components"].(map[string]interface{})["schemas"].(map[string]interface{}); ok { + if _, exists := schemas["Time"]; exists { + t.Errorf("time.Time should not produce a 'Time' component schema") + } + } +} + +func TestPrimitiveAndSliceTopLevelSchemasInline(t *testing.T) { + app := fiber.New() + oapi := New(app, Config{ + EnableValidation: false, + EnableOpenAPIDocs: true, + OpenAPIDocsPath: "/docs", + }) + + Post(oapi, "/echo-string", func(c *fiber.Ctx, req string) (string, *ErrorResponse) { + return req, nil + }, OpenAPIOptions{OperationID: "echoString"}) + + Post(oapi, "/echo-tags", func(c *fiber.Ctx, req []string) ([]string, *ErrorResponse) { + return req, nil + }, OpenAPIOptions{OperationID: "echoTags"}) + + oapi.SetupDocs() + + req := httptest.NewRequest("GET", "/openapi.json", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + body, _ := io.ReadAll(resp.Body) + var spec map[string]interface{} + if err := json.Unmarshal(body, &spec); err != nil { + t.Fatalf("Failed to parse OpenAPI JSON: %v", err) + } + + cases := []struct { + path string + wantType string + wantItems bool + wantElem string + schemaName string + }{ + {path: "/echo-string", wantType: "string", schemaName: "String"}, + {path: "/echo-tags", wantType: "array", wantItems: true, wantElem: "string", schemaName: "Array_String"}, + } + + paths := spec["paths"].(map[string]interface{}) + for _, tc := range cases { + op := paths[tc.path].(map[string]interface{})["post"].(map[string]interface{}) + for _, where := range []string{"requestBody", "response"} { + var schema map[string]interface{} + if where == "requestBody" { + schema = op["requestBody"].(map[string]interface{})["content"].(map[string]interface{})["application/json"].(map[string]interface{})["schema"].(map[string]interface{}) + } else { + schema = op["responses"].(map[string]interface{})["200"].(map[string]interface{})["content"].(map[string]interface{})["application/json"].(map[string]interface{})["schema"].(map[string]interface{}) + } + if _, hasRef := schema["$ref"]; hasRef { + t.Errorf("%s %s: expected inline schema, got $ref %v", tc.path, where, schema["$ref"]) + continue + } + if schema["type"] != tc.wantType { + t.Errorf("%s %s: expected type %q, got %v", tc.path, where, tc.wantType, schema["type"]) + } + if tc.wantItems { + items, ok := schema["items"].(map[string]interface{}) + if !ok { + t.Errorf("%s %s: expected items to be set", tc.path, where) + continue + } + if items["type"] != tc.wantElem { + t.Errorf("%s %s: expected items.type %q, got %v", tc.path, where, tc.wantElem, items["type"]) + } + } + } + } + + if schemas, ok := spec["components"].(map[string]interface{})["schemas"].(map[string]interface{}); ok { + for _, tc := range cases { + if _, exists := schemas[tc.schemaName]; exists { + t.Errorf("Did not expect %q to be registered as a component schema", tc.schemaName) + } + } + } +} + +func TestTimeTypeAsTopLevelErrorBody(t *testing.T) { + app := fiber.New() + oapi := New(app) + + Post(oapi, "/timestamp-error", func(c *fiber.Ctx, req *EmptyRequest) (*EmptyRequest, *time.Time) { + return &EmptyRequest{}, nil + }, OpenAPIOptions{ + OperationID: "timestampError", + Tags: []string{"timestamps"}, + }) + + oapi.SetupDocs() + + req := httptest.NewRequest("GET", "/openapi.json", nil) + resp, err := app.Test(req) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + + body, _ := io.ReadAll(resp.Body) + var spec map[string]interface{} + if err := json.Unmarshal(body, &spec); err != nil { + t.Fatalf("Failed to parse OpenAPI JSON: %v", err) + } + + post := spec["paths"].(map[string]interface{})["/timestamp-error"].(map[string]interface{})["post"].(map[string]interface{}) + errSchema := post["responses"].(map[string]interface{})["400"].(map[string]interface{})["content"].(map[string]interface{})["application/json"].(map[string]interface{})["schema"].(map[string]interface{}) + + if _, hasRef := errSchema["$ref"]; hasRef { + t.Errorf("Expected top-level time.Time error body to be inlined, got $ref: %v", errSchema["$ref"]) + } + if errSchema["type"] != "string" || errSchema["format"] != "date-time" { + t.Errorf("Expected error schema to be string/date-time, got %v", errSchema) + } + + if schemas, ok := spec["components"].(map[string]interface{})["schemas"].(map[string]interface{}); ok { + if _, exists := schemas["Time"]; exists { + t.Errorf("time.Time should not produce a 'Time' component schema") + } + } +}