Skip to content

feat: add catcher feature to custom status handler (#31) #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions examples/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
serializer,
get,
post,
catcher,
)

import uuid
Expand Down Expand Up @@ -62,10 +63,7 @@ class UserInputSerializer(serializer.Serializer):
def register(request: Request):
new_user = UserInputSerializer(request)

try:
new_user.validate()
except Exception as e:
return str(e), Status.BAD_REQUEST
new_user.validate()

password = new_user.validate_data["password"]
email = new_user.validate_data["email"]
Expand All @@ -85,10 +83,7 @@ def register(request: Request):
def login(request: Request):
user_input = UserInputSerializer(request)

try:
user_input.validate()
except Exception as e:
return str(e), Status.BAD_REQUEST
user_input.validate()

email = user_input.validate_data["email"]
password = user_input.validate_data["password"]
Expand Down Expand Up @@ -129,6 +124,11 @@ def all(request: Request) -> Response:
return serializer.data


@catcher(Status.INTERNAL_SERVER_ERROR)
def catch_500(req: Request, res: Response):
return {"error": res.body}, Status.BAD_REQUEST


pub_router = Router()
pub_router.routes([hello_world, login, register, add])
pub_router.middleware(logger)
Expand All @@ -138,8 +138,10 @@ def all(request: Request) -> Response:
sec_router.middleware(jwt_middleware)
sec_router.middleware(logger)


server = HttpServer(("127.0.0.1", 5555))
server.app_data(AppData())
server.catcher(catch_500)
server.attach(sec_router)
server.attach(pub_router)

Expand Down
35 changes: 35 additions & 0 deletions src/catcher.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use std::sync::Arc;

use pyo3::prelude::*;

use crate::status::Status;

#[derive(Clone)]
#[pyclass]
pub struct Catcher {
pub status: Status,
pub handler: Arc<Py<PyAny>>,
}

#[pymethods]
impl Catcher {
#[new]
pub fn new(status: PyRef<'_, Status>, py: Python<'_>) -> Self {
Self {
status: status.clone(),
handler: Arc::new(py.None()),
}
}

fn __call__(&self, handler: Py<PyAny>) -> PyResult<Self> {
Ok(Self {
handler: Arc::new(handler),
..self.clone()
})
}
}

#[pyfunction]
pub fn catcher(status: PyRef<'_, Status>, py: Python<'_>) -> Catcher {
Catcher::new(status, py)
}
3 changes: 3 additions & 0 deletions src/handling/request_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use crate::{
IntoPyException, MatchitRoute, ProcessRequest,
};

#[allow(clippy::too_many_arguments)]
pub async fn handle_request(
req: HyperRequest<Incoming>,
request_sender: Sender<ProcessRequest>,
Expand All @@ -28,6 +29,7 @@ pub async fn handle_request(
channel_capacity: usize,
cors: Option<Arc<Cors>>,
template: Option<Arc<Template>>,
status_catcher: Option<Arc<HashMap<Status, Py<PyAny>>>>,
) -> Result<HyperResponse<Full<Bytes>>, hyper::http::Error> {
if req.method() == hyper::Method::OPTIONS && cors.is_some() {
let response = cors.as_ref().unwrap().into_response().unwrap();
Expand All @@ -50,6 +52,7 @@ pub async fn handle_request(
route,
response_sender,
cors: cors.clone(),
status_catcher,
};

if request_sender.send(process_request).await.is_ok() {
Expand Down
27 changes: 19 additions & 8 deletions src/handling/response_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ pub async fn handle_response(
loop {
tokio::select! {
Some(process_request) = request_receiver.recv() => {
let response = match process_response(
let request: &Request = &process_request.request;
let mut response = match process_response(
&process_request.router,
process_request.route,
&process_request.request,
request,
) {
Ok(response) => response,
Err(e) => Status::INTERNAL_SERVER_ERROR
Expand All @@ -33,13 +34,23 @@ pub async fn handle_response(
.set_body(e.to_string()),
};

let final_response = if let Some(cors) = process_request.cors {
cors.apply_to_response(response).unwrap()
} else {
response
};
if let Some(cors) = process_request.cors {
response = cors.apply_to_response(response).unwrap();
}

if let Some(status_catcher) = process_request.status_catcher {
if let Some(catcher) = status_catcher.get(&response.status) {
Python::with_gil(|py| {
if let Ok(catcher_response) = catcher.call(py, (request.clone(), response.clone()), None) {
if let Ok(resp) = convert_to_response(catcher_response, py) {
response = resp;
}
}
});
}
}

_ = process_request.response_sender.send(final_response).await;
_ = process_request.response_sender.send(response).await;
}
_ = shutdown_rx.recv() => {break}
}
Expand Down
21 changes: 21 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod catcher;
mod cors;
mod handling;
mod into_response;
Expand All @@ -11,6 +12,7 @@ mod serializer;
mod status;
mod templating;

use catcher::Catcher;
use cors::Cors;
use handling::request_handler::handle_request;
use handling::response_handler::handle_response;
Expand All @@ -32,6 +34,7 @@ use tokio::sync::mpsc::{channel, Sender};
use tokio::sync::Semaphore;

use std::{
collections::HashMap,
net::SocketAddr,
sync::{
atomic::{AtomicBool, Ordering},
Expand All @@ -55,12 +58,14 @@ impl<T, E: ToString> IntoPyException<T> for Result<T, E> {
}
}

#[derive(Clone)]
struct ProcessRequest {
request: Arc<Request>,
router: Arc<Router>,
route: MatchitRoute,
response_sender: Sender<Response>,
cors: Option<Arc<Cors>>,
status_catcher: Option<Arc<HashMap<Status, Py<PyAny>>>>,
}

#[derive(Clone)]
Expand All @@ -73,6 +78,7 @@ struct HttpServer {
channel_capacity: usize,
cors_header: Option<Arc<Cors>>,
template: Option<Arc<Template>>,
status_catchers: Option<Arc<HashMap<Status, Py<PyAny>>>>,
}

#[pymethods]
Expand All @@ -88,6 +94,7 @@ impl HttpServer {
channel_capacity: 100,
cors_header: None,
template: None,
status_catchers: None,
})
}

Expand All @@ -103,6 +110,15 @@ impl HttpServer {
self.template = Some(Arc::new(template))
}

fn catcher(&mut self, catcher: Catcher, py: Python<'_>) {
if self.status_catchers.is_none() {
self.status_catchers = Some(Arc::new(HashMap::new()));
}
if let Some(catchers) = Arc::get_mut(self.status_catchers.as_mut().unwrap()) {
catchers.insert(catcher.status, catcher.handler.clone_ref(py));
}
}

fn run(&self) -> PyResult<()> {
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
Expand Down Expand Up @@ -153,6 +169,7 @@ impl HttpServer {
let app_data = self.app_data.clone();
let cors = self.cors_header.clone();
let template = self.template.clone();
let status_catcher = self.status_catchers.clone();

tokio::spawn(async move {
while running_clone.load(Ordering::SeqCst) {
Expand All @@ -164,6 +181,7 @@ impl HttpServer {
let app_data = app_data.clone();
let cors = cors.clone();
let template = template.clone();
let status_catcher = status_catcher.clone();

tokio::spawn(async move {
let _permit = permit;
Expand All @@ -176,6 +194,7 @@ impl HttpServer {
let app_data = app_data.clone();
let cors = cors.clone();
let template = template.clone();
let status_catcher = status_catcher.clone();

async move {
handle_request(
Expand All @@ -186,6 +205,7 @@ impl HttpServer {
channel_capacity,
cors,
template,
status_catcher,
)
.await
}
Expand Down Expand Up @@ -219,6 +239,7 @@ fn oxapy(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(head, m)?)?;
m.add_function(wrap_pyfunction!(options, m)?)?;
m.add_function(wrap_pyfunction!(static_file, m)?)?;
m.add_function(wrap_pyfunction!(catcher::catcher, m)?)?;

templating_submodule(m)?;
serializer_submodule(m)?;
Expand Down
2 changes: 1 addition & 1 deletion src/routing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ impl Route {
#[new]
#[pyo3(signature=(path, method=None))]
pub fn new(path: String, method: Option<String>, py: Python<'_>) -> Self {
Route {
Self {
method: method.unwrap_or_else(|| "GET".to_string()),
path,
handler: Arc::new(py.None()),
Expand Down
6 changes: 3 additions & 3 deletions src/serializer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,18 +142,18 @@ impl Serializer {
results.push(dict.into());
}
}
return Ok(PyList::new(py, results)?.into_py_any(py)?);
return PyList::new(py, results)?.into_py_any(py);
}

if let Some(inst) = slf.getattr("instance")?.extract::<Option<PyObject>>()? {
let py_repr = slf
.as_ref()
.call_method1("to_representation", (inst.clone_ref(py),))?;
let dict: Bound<PyDict> = py_repr.extract()?;
return Ok(dict.into_py_any(py)?);
return dict.into_py_any(py);
}

Ok(py.None().into())
Ok(py.None())
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use pyo3::prelude::*;

use crate::{into_response::IntoResponse, response::Response};

#[derive(Clone)]
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
#[pyclass]
pub struct Status(pub u16);

Expand Down