diff --git a/examples/api/main.py b/examples/api/main.py index 18dd9eb..9014da6 100644 --- a/examples/api/main.py +++ b/examples/api/main.py @@ -16,6 +16,7 @@ serializer, get, post, + catcher, ) import uuid @@ -62,10 +63,7 @@ class UserInputSerializer(serializer.Serializer): def register(request: Request): new_user = UserInputSerializer(request) - try: - new_user.validate() - except Exception as e: - return str(e), Status.BAD_REQUEST + new_user.validate() password = new_user.validate_data["password"] email = new_user.validate_data["email"] @@ -85,10 +83,7 @@ def register(request: Request): def login(request: Request): user_input = UserInputSerializer(request) - try: - user_input.validate() - except Exception as e: - return str(e), Status.BAD_REQUEST + user_input.validate() email = user_input.validate_data["email"] password = user_input.validate_data["password"] @@ -129,6 +124,11 @@ def all(request: Request) -> Response: return serializer.data +@catcher(Status.INTERNAL_SERVER_ERROR) +def catch_500(req: Request, res: Response): + return {"error": res.body}, Status.BAD_REQUEST + + pub_router = Router() pub_router.routes([hello_world, login, register, add]) pub_router.middleware(logger) @@ -138,8 +138,10 @@ def all(request: Request) -> Response: sec_router.middleware(jwt_middleware) sec_router.middleware(logger) + server = HttpServer(("127.0.0.1", 5555)) server.app_data(AppData()) +server.catcher(catch_500) server.attach(sec_router) server.attach(pub_router) diff --git a/src/catcher.rs b/src/catcher.rs new file mode 100644 index 0000000..d585159 --- /dev/null +++ b/src/catcher.rs @@ -0,0 +1,35 @@ +use std::sync::Arc; + +use pyo3::prelude::*; + +use crate::status::Status; + +#[derive(Clone)] +#[pyclass] +pub struct Catcher { + pub status: Status, + pub handler: Arc>, +} + +#[pymethods] +impl Catcher { + #[new] + pub fn new(status: PyRef<'_, Status>, py: Python<'_>) -> Self { + Self { + status: status.clone(), + handler: Arc::new(py.None()), + } + } + + fn __call__(&self, handler: Py) -> PyResult { + Ok(Self { + handler: Arc::new(handler), + ..self.clone() + }) + } +} + +#[pyfunction] +pub fn catcher(status: PyRef<'_, Status>, py: Python<'_>) -> Catcher { + Catcher::new(status, py) +} diff --git a/src/handling/request_handler.rs b/src/handling/request_handler.rs index 068deef..ad83423 100644 --- a/src/handling/request_handler.rs +++ b/src/handling/request_handler.rs @@ -20,6 +20,7 @@ use crate::{ IntoPyException, MatchitRoute, ProcessRequest, }; +#[allow(clippy::too_many_arguments)] pub async fn handle_request( req: HyperRequest, request_sender: Sender, @@ -28,6 +29,7 @@ pub async fn handle_request( channel_capacity: usize, cors: Option>, template: Option>, + status_catcher: Option>>>, ) -> Result>, hyper::http::Error> { if req.method() == hyper::Method::OPTIONS && cors.is_some() { let response = cors.as_ref().unwrap().into_response().unwrap(); @@ -50,6 +52,7 @@ pub async fn handle_request( route, response_sender, cors: cors.clone(), + status_catcher, }; if request_sender.send(process_request).await.is_ok() { diff --git a/src/handling/response_handler.rs b/src/handling/response_handler.rs index 2716642..f83ffb3 100644 --- a/src/handling/response_handler.rs +++ b/src/handling/response_handler.rs @@ -21,10 +21,11 @@ pub async fn handle_response( loop { tokio::select! { Some(process_request) = request_receiver.recv() => { - let response = match process_response( + let request: &Request = &process_request.request; + let mut response = match process_response( &process_request.router, process_request.route, - &process_request.request, + request, ) { Ok(response) => response, Err(e) => Status::INTERNAL_SERVER_ERROR @@ -33,13 +34,23 @@ pub async fn handle_response( .set_body(e.to_string()), }; - let final_response = if let Some(cors) = process_request.cors { - cors.apply_to_response(response).unwrap() - } else { - response - }; + if let Some(cors) = process_request.cors { + response = cors.apply_to_response(response).unwrap(); + } + + if let Some(status_catcher) = process_request.status_catcher { + if let Some(catcher) = status_catcher.get(&response.status) { + Python::with_gil(|py| { + if let Ok(catcher_response) = catcher.call(py, (request.clone(), response.clone()), None) { + if let Ok(resp) = convert_to_response(catcher_response, py) { + response = resp; + } + } + }); + } + } - _ = process_request.response_sender.send(final_response).await; + _ = process_request.response_sender.send(response).await; } _ = shutdown_rx.recv() => {break} } diff --git a/src/lib.rs b/src/lib.rs index 915b60b..35a6c2f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +mod catcher; mod cors; mod handling; mod into_response; @@ -11,6 +12,7 @@ mod serializer; mod status; mod templating; +use catcher::Catcher; use cors::Cors; use handling::request_handler::handle_request; use handling::response_handler::handle_response; @@ -32,6 +34,7 @@ use tokio::sync::mpsc::{channel, Sender}; use tokio::sync::Semaphore; use std::{ + collections::HashMap, net::SocketAddr, sync::{ atomic::{AtomicBool, Ordering}, @@ -55,12 +58,14 @@ impl IntoPyException for Result { } } +#[derive(Clone)] struct ProcessRequest { request: Arc, router: Arc, route: MatchitRoute, response_sender: Sender, cors: Option>, + status_catcher: Option>>>, } #[derive(Clone)] @@ -73,6 +78,7 @@ struct HttpServer { channel_capacity: usize, cors_header: Option>, template: Option>, + status_catchers: Option>>>, } #[pymethods] @@ -88,6 +94,7 @@ impl HttpServer { channel_capacity: 100, cors_header: None, template: None, + status_catchers: None, }) } @@ -103,6 +110,15 @@ impl HttpServer { self.template = Some(Arc::new(template)) } + fn catcher(&mut self, catcher: Catcher, py: Python<'_>) { + if self.status_catchers.is_none() { + self.status_catchers = Some(Arc::new(HashMap::new())); + } + if let Some(catchers) = Arc::get_mut(self.status_catchers.as_mut().unwrap()) { + catchers.insert(catcher.status, catcher.handler.clone_ref(py)); + } + } + fn run(&self) -> PyResult<()> { let runtime = tokio::runtime::Builder::new_multi_thread() .enable_all() @@ -153,6 +169,7 @@ impl HttpServer { let app_data = self.app_data.clone(); let cors = self.cors_header.clone(); let template = self.template.clone(); + let status_catcher = self.status_catchers.clone(); tokio::spawn(async move { while running_clone.load(Ordering::SeqCst) { @@ -164,6 +181,7 @@ impl HttpServer { let app_data = app_data.clone(); let cors = cors.clone(); let template = template.clone(); + let status_catcher = status_catcher.clone(); tokio::spawn(async move { let _permit = permit; @@ -176,6 +194,7 @@ impl HttpServer { let app_data = app_data.clone(); let cors = cors.clone(); let template = template.clone(); + let status_catcher = status_catcher.clone(); async move { handle_request( @@ -186,6 +205,7 @@ impl HttpServer { channel_capacity, cors, template, + status_catcher, ) .await } @@ -219,6 +239,7 @@ fn oxapy(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(head, m)?)?; m.add_function(wrap_pyfunction!(options, m)?)?; m.add_function(wrap_pyfunction!(static_file, m)?)?; + m.add_function(wrap_pyfunction!(catcher::catcher, m)?)?; templating_submodule(m)?; serializer_submodule(m)?; diff --git a/src/routing.rs b/src/routing.rs index 5d86126..63fc297 100644 --- a/src/routing.rs +++ b/src/routing.rs @@ -18,7 +18,7 @@ impl Route { #[new] #[pyo3(signature=(path, method=None))] pub fn new(path: String, method: Option, py: Python<'_>) -> Self { - Route { + Self { method: method.unwrap_or_else(|| "GET".to_string()), path, handler: Arc::new(py.None()), diff --git a/src/serializer/mod.rs b/src/serializer/mod.rs index 06cd3e1..238bddd 100644 --- a/src/serializer/mod.rs +++ b/src/serializer/mod.rs @@ -142,7 +142,7 @@ impl Serializer { results.push(dict.into()); } } - return Ok(PyList::new(py, results)?.into_py_any(py)?); + return PyList::new(py, results)?.into_py_any(py); } if let Some(inst) = slf.getattr("instance")?.extract::>()? { @@ -150,10 +150,10 @@ impl Serializer { .as_ref() .call_method1("to_representation", (inst.clone_ref(py),))?; let dict: Bound = py_repr.extract()?; - return Ok(dict.into_py_any(py)?); + return dict.into_py_any(py); } - Ok(py.None().into()) + Ok(py.None()) } } diff --git a/src/status.rs b/src/status.rs index 77da296..0a43298 100644 --- a/src/status.rs +++ b/src/status.rs @@ -5,7 +5,7 @@ use pyo3::prelude::*; use crate::{into_response::IntoResponse, response::Response}; -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] #[pyclass] pub struct Status(pub u16);