diff --git a/tests/conftest.py b/tests/conftest.py index 796cfaf310..b8327622f4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1080,6 +1080,28 @@ def inner(response_content, serialize_pydantic=False, request_headers=None): return inner +@pytest.fixture +def get_rate_limit_model_response(): + def inner(request_headers=None): + if request_headers is None: + request_headers = {} + + model_request = HttpxRequest( + "POST", + "/responses", + headers=request_headers, + ) + + response = HttpxResponse( + 429, + request=model_request, + ) + + return response + + return inner + + @pytest.fixture def streaming_chat_completions_model_response(): return [ diff --git a/tests/integrations/litellm/test_litellm.py b/tests/integrations/litellm/test_litellm.py index 40a7dd00c4..107e0b29ad 100644 --- a/tests/integrations/litellm/test_litellm.py +++ b/tests/integrations/litellm/test_litellm.py @@ -465,7 +465,9 @@ def test_embeddings_no_pii( assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"] -def test_exception_handling(sentry_init, capture_events): +def test_exception_handling( + reset_litellm_executor, sentry_init, capture_events, get_rate_limit_model_response +): sentry_init( integrations=[LiteLLMIntegration()], traces_sample_rate=1.0, @@ -474,19 +476,22 @@ def test_exception_handling(sentry_init, capture_events): messages = [{"role": "user", "content": "Hello!"}] - with start_transaction(name="litellm test"): - kwargs = { - "model": "gpt-3.5-turbo", - "messages": messages, - } + client = OpenAI(api_key="z") - _input_callback(kwargs) - _failure_callback( - kwargs, - Exception("API rate limit reached"), - datetime.now(), - datetime.now(), - ) + model_response = get_rate_limit_model_response() + + with mock.patch.object( + client.completions._client._client, + "send", + return_value=model_response, + ): + with start_transaction(name="litellm test"): + with pytest.raises(litellm.RateLimitError): + litellm.completion( + model="gpt-3.5-turbo", + messages=messages, + client=client, + ) # Should have error event and transaction assert len(events) >= 1