diff --git a/Cargo.lock b/Cargo.lock index a5c5a1f3..34f2fc37 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -962,6 +962,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-test", + "tower", "tower-http", "tracing", "tracing-subscriber", diff --git a/src/cortex-app-server/Cargo.toml b/src/cortex-app-server/Cargo.toml index 8da2d7bd..1560ee94 100644 --- a/src/cortex-app-server/Cargo.toml +++ b/src/cortex-app-server/Cargo.toml @@ -73,3 +73,4 @@ gethostname = "0.5" [dev-dependencies] tokio-test = { workspace = true } +tower = { version = "0.5", default-features = false, features = ["util"] } diff --git a/src/cortex-app-server/src/api/mod.rs b/src/cortex-app-server/src/api/mod.rs index fc7f16ad..8e7c9865 100644 --- a/src/cortex-app-server/src/api/mod.rs +++ b/src/cortex-app-server/src/api/mod.rs @@ -38,6 +38,7 @@ pub use types::{ /// Create the API routes. pub fn routes() -> Router> { Router::new() + .without_v07_checks() // Health and metrics .route("/health", get(health::health_check)) .route("/metrics", get(health::get_metrics)) diff --git a/src/cortex-app-server/src/lib.rs b/src/cortex-app-server/src/lib.rs index 8e7acdf8..108cb142 100644 --- a/src/cortex-app-server/src/lib.rs +++ b/src/cortex-app-server/src/lib.rs @@ -36,9 +36,8 @@ pub mod websocket; use std::net::SocketAddr; use std::sync::Arc; -use axum::Router; +use axum::{Router, middleware as axum_middleware}; use tokio::net::TcpListener; -use tower_http::cors::CorsLayer; use tower_http::trace::TraceLayer; use tracing::{info, warn}; @@ -131,15 +130,171 @@ pub fn create_router(state: AppState) -> Router { /// This variant is useful when you need to keep a reference to the state /// for cleanup purposes (e.g., during graceful shutdown). pub fn create_router_with_state(state: Arc) -> Router { - let api_routes = api::routes() - .merge(websocket::routes()) - .merge(streaming::routes()) - .merge(share::routes()) - .merge(admin::routes()); + let cors_layer = middleware::cors_layer(&state.config.cors_origins); + + let api_routes = add_api_middleware( + api::routes() + .merge(websocket::routes()) + .merge(streaming::routes()) + .merge(share::routes()) + .merge(admin::routes()), + Arc::clone(&state), + ); Router::new() + .without_v07_checks() .nest("/api/v1", api_routes) .layer(TraceLayer::new_for_http()) - .layer(CorsLayer::permissive()) + .layer(cors_layer) .with_state(state) } + +fn add_api_middleware( + router: Router>, + state: Arc, +) -> Router> { + router + .layer(axum_middleware::from_fn_with_state( + Arc::clone(&state), + middleware::rate_limit_middleware, + )) + .layer(axum_middleware::from_fn( + middleware::content_type_middleware, + )) + .layer(axum_middleware::from_fn_with_state( + state, + middleware::timeout_middleware, + )) + .layer(axum_middleware::from_fn( + middleware::security_headers_middleware, + )) +} + +#[cfg(test)] +mod tests { + use super::*; + use axum::{ + body::Body, + http::{Request, StatusCode, header}, + routing::get, + }; + use tower::ServiceExt; + + async fn test_app(config: ServerConfig) -> Router { + let state = AppState::new(config).await.unwrap(); + create_router(state) + } + + async fn slow_test_handler() -> &'static str { + tokio::time::sleep(std::time::Duration::from_millis(25)).await; + "done" + } + + #[tokio::test] + async fn create_router_applies_security_headers_middleware() { + let app = test_app(ServerConfig::default()).await; + + let response = app + .oneshot( + Request::builder() + .uri("/api/v1/models") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response.headers().get("X-Content-Type-Options").unwrap(), + "nosniff" + ); + assert_eq!(response.headers().get("X-Frame-Options").unwrap(), "DENY"); + } + + #[tokio::test] + async fn create_router_applies_rate_limit_middleware() { + let mut config = ServerConfig::default(); + config.rate_limit.requests_per_minute = 1; + config.rate_limit.burst_size = 1; + config.rate_limit.exempt_paths.clear(); + + let app = test_app(config).await; + + let first = app + .clone() + .oneshot( + Request::builder() + .uri("/api/v1/models") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(first.status(), StatusCode::OK); + + let second = app + .oneshot( + Request::builder() + .uri("/api/v1/models") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS); + assert_eq!(second.headers().get(header::RETRY_AFTER).unwrap(), "60"); + assert_eq!( + second.headers().get("X-Content-Type-Options").unwrap(), + "nosniff" + ); + } + + #[tokio::test] + async fn create_router_applies_content_type_middleware() { + let app = test_app(ServerConfig::default()).await; + + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri("/api/v1/models") + .header(header::CONTENT_TYPE, "text/plain") + .body(Body::from("not json")) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE); + assert_eq!( + response.headers().get("X-Content-Type-Options").unwrap(), + "nosniff" + ); + } + + #[tokio::test] + async fn create_router_applies_timeout_middleware() { + let mut config = ServerConfig::default(); + config.request_timeout = 0; + config.rate_limit.enabled = false; + + let state = Arc::new(AppState::new(config).await.unwrap()); + let app = add_api_middleware( + Router::new().route("/slow", get(slow_test_handler)), + Arc::clone(&state), + ) + .with_state(state); + + let response = app + .oneshot(Request::builder().uri("/slow").body(Body::empty()).unwrap()) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::GATEWAY_TIMEOUT); + assert_eq!( + response.headers().get("X-Content-Type-Options").unwrap(), + "nosniff" + ); + } +} diff --git a/src/cortex-app-server/src/share.rs b/src/cortex-app-server/src/share.rs index ad03753c..ce432afa 100644 --- a/src/cortex-app-server/src/share.rs +++ b/src/cortex-app-server/src/share.rs @@ -22,6 +22,7 @@ use crate::storage::StoredMessage; /// Create share routes. pub fn routes() -> Router> { Router::new() + .without_v07_checks() .route("/share", post(create_share)) .route("/share/:token", get(get_shared_session)) .route("/share/:token", delete(revoke_share)) diff --git a/src/cortex-app-server/src/streaming.rs b/src/cortex-app-server/src/streaming.rs index 840b1a6a..6b89493c 100644 --- a/src/cortex-app-server/src/streaming.rs +++ b/src/cortex-app-server/src/streaming.rs @@ -30,6 +30,7 @@ use crate::state::AppState; /// Create streaming API routes. pub fn routes() -> Router> { Router::new() + .without_v07_checks() // CLI Session management .route("/cli/sessions", post(create_cli_session)) .route("/cli/sessions", get(list_cli_sessions)) diff --git a/src/cortex-app-server/src/websocket.rs b/src/cortex-app-server/src/websocket.rs index f5e21845..7a34412e 100644 --- a/src/cortex-app-server/src/websocket.rs +++ b/src/cortex-app-server/src/websocket.rs @@ -28,6 +28,7 @@ use crate::state::AppState; /// Create WebSocket routes. pub fn routes() -> Router> { Router::new() + .without_v07_checks() .route("/ws", get(websocket_handler)) .route("/ws/sessions/:id", get(session_websocket_handler)) }