diff --git a/src/firebase_functions/https_fn.py b/src/firebase_functions/https_fn.py index 7e692de..7233e7f 100644 --- a/src/firebase_functions/https_fn.py +++ b/src/firebase_functions/https_fn.py @@ -432,14 +432,19 @@ def example(request: Request) -> Response: options = HttpsOptions(**kwargs) def on_request_inner_decorator(func: _C1): + func_with_init = _core._with_init(func) + + if options.cors is not None: + wrapped_function = _cross_origin( + methods=options.cors.cors_methods, + origins=options.cors.cors_origins, + )(func_with_init) + else: + wrapped_function = func_with_init + @_functools.wraps(func) def on_request_wrapped(request: Request) -> Response: - if options.cors is not None: - return _cross_origin( - methods=options.cors.cors_methods, - origins=options.cors.cors_origins, - )(func)(request) - return _core._with_init(func)(request) + return wrapped_function(request) _util.set_func_endpoint_attr( on_request_wrapped, diff --git a/tests/test_https_fn.py b/tests/test_https_fn.py index 1748b36..7a937e2 100644 --- a/tests/test_https_fn.py +++ b/tests/test_https_fn.py @@ -9,6 +9,7 @@ from werkzeug.test import EnvironBuilder from firebase_functions import core, https_fn +from firebase_functions.options import CorsOptions class TestHttps(unittest.TestCase): @@ -42,6 +43,34 @@ def init(): self.assertEqual(hello, "world") + def test_on_request_calls_init_function_with_cors(self): + app = Flask(__name__) + + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + func = Mock(__name__="example_func", return_value="OK") + + with app.test_request_context("/"): + environ = EnvironBuilder( + method="POST", + json={ + "data": {"test": "value"}, + }, + ).get_environ() + request = Request(environ) + decorated_func = https_fn.on_request( + cors=CorsOptions(cors_origins="*", cors_methods="GET") + )(func) + + decorated_func(request) + + self.assertEqual(hello, "world") + def test_on_call_calls_init_function(self): app = Flask(__name__)