diff --git a/packages/google-auth/google/auth/transport/grpc.py b/packages/google-auth/google/auth/transport/grpc.py index e541d20ca0a4..f7e2d036d21e 100644 --- a/packages/google-auth/google/auth/transport/grpc.py +++ b/packages/google-auth/google/auth/transport/grpc.py @@ -16,8 +16,10 @@ from __future__ import absolute_import +import logging import logging +_LOGGER = logging.getLogger(__name__) from google.auth import exceptions from google.auth.transport import _mtls_helper from google.oauth2 import service_account @@ -208,7 +210,7 @@ def my_client_cert_callback(): channel = google.auth.transport.grpc.secure_authorized_channel( credentials, request, mtls_endpoint) - + Args: credentials (google.auth.credentials.Credentials): The credentials to add to requests. @@ -253,6 +255,7 @@ def my_client_cert_callback(): ) # If SSL credentials are not explicitly set, try client_cert_callback and ADC. + cached_cert = None if not ssl_credentials: use_client_cert = _mtls_helper.check_use_client_cert() if use_client_cert and client_cert_callback: @@ -261,10 +264,12 @@ def my_client_cert_callback(): ssl_credentials = grpc.ssl_channel_credentials( certificate_chain=cert, private_key=key ) + cached_cert = cert elif use_client_cert: # Use application default SSL credentials. - adc_ssl_credentils = SslCredentials() - ssl_credentials = adc_ssl_credentils.ssl_credentials + adc_ssl_credentials = SslCredentials() + ssl_credentials = adc_ssl_credentials.ssl_credentials + cached_cert = adc_ssl_credentials._cached_cert else: ssl_credentials = grpc.ssl_channel_credentials() @@ -272,9 +277,27 @@ def my_client_cert_callback(): composite_credentials = grpc.composite_channel_credentials( ssl_credentials, google_auth_credentials ) - - return grpc.secure_channel(target, composite_credentials, **kwargs) - + is_retry = kwargs.pop("_is_retry", False) + channel = grpc.secure_channel(target, composite_credentials, **kwargs) + # Check if we are already inside a retry to avoid infinite recursion + if cached_cert and not is_retry: + # Package arguments to recreate the channel if rotation occurs + factory_args = { + "credentials": credentials, + "request": request, + "target": target, + "ssl_credentials": None, + "client_cert_callback": client_cert_callback, + "_is_retry": True, # Hidden flag to stop recursion + **kwargs + } + interceptor = _MTLSCallInterceptor() + + wrapper = _MTLSRefreshingChannel(target, factory_args, channel, cached_cert) + + interceptor._wrapper = wrapper + return grpc.intercept_channel(wrapper, interceptor) + return channel class SslCredentials: """Class for application default SSL credentials. @@ -292,6 +315,7 @@ class SslCredentials: def __init__(self): use_client_cert = _mtls_helper.check_use_client_cert() + self._cached_cert = None if not use_client_cert: self._is_mtls = False else: @@ -323,6 +347,7 @@ def ssl_credentials(self): self._ssl_credentials = grpc.ssl_channel_credentials( certificate_chain=cert, private_key=key ) + self._cached_cert = cert except exceptions.ClientCertError as caught_exc: new_exc = exceptions.MutualTLSChannelError(caught_exc) raise new_exc from caught_exc @@ -335,3 +360,77 @@ def ssl_credentials(self): def is_mtls(self): """Indicates if the created SSL channel credentials is mutual TLS.""" return self._is_mtls + +class _MTLSCallInterceptor(grpc.UnaryUnaryClientInterceptor): + def __init__(self): + self._wrapper = None + self._max_retries = 2 # Set your desired limit here + + def _should_retry(self, code, retry_count): + if code != grpc.StatusCode.UNAUTHENTICATED or not self._wrapper: + return False + + if retry_count >= self._max_retries: + _LOGGER.debug("Max retries reached (%d/%d).", retry_count, self._max_retries) + return False + + # Fingerprint check logic + _, _, cached_fp, current_fp = _mtls_helper.check_parameters_for_unauthorized_response(self._wrapper._cached_cert) + return cached_fp != current_fp + + def intercept_unary_unary(self, continuation, client_call_details, request): + retry_count = 0 + + while True: + try: + # Every time we call continuation(), our Wrapper (which is the channel + # being intercepted) will point to its CURRENT active raw channel. + response = continuation(client_call_details, request) + status_code = response.code() + except grpc.RpcError as e: + status_code = e.code() + if not self._should_retry(status_code, retry_count): + raise e + # If we should retry, we fall through to the refresh logic below + + if self._should_retry(status_code, retry_count): + retry_count += 1 + # Tell the wrapper to swap the channel. + # We don't need the wrapper to execute the retry; the loop does it! + self._wrapper.refresh_logic(retry_count) + continue # Jump back to the start of the while loop + + return response + +class _MTLSRefreshingChannel(grpc.Channel): + def __init__(self, target, factory_args, initial_channel, initial_cert): + self._target = target + self._factory_args = factory_args + self._channel = initial_channel + self._cached_cert = initial_cert + self._lock = threading.Lock() + + def refresh_logic(self, count): + with self._lock: + # Re-check inside lock to prevent race conditions + _, _, cached_fp, current_fp = _mtls_helper.check_parameters_for_unauthorized_response(self._cached_cert) + if cached_fp != current_fp: + _LOGGER.debug("Wrapper: Refreshing mTLS channel. Retry count: %d", count) + old_channel = self._channel + self._channel = secure_authorized_channel(**self._factory_args) + + creds = _mtls_helper.get_client_ssl_credentials() + self._cached_cert = creds[1] + old_channel.close() + + def unary_unary(self, method, *args, **kwargs): + # Always return a callable from the CURRENT channel + return self._channel.unary_unary(method, *args, **kwargs) + + # Mandatory passthroughs + def unary_stream(self, method, *args, **kwargs): return self._channel.unary_stream(method, *args, **kwargs) + def stream_unary(self, method, *args, **kwargs): return self._channel.stream_unary(method, *args, **kwargs) + def stream_stream(self, method, *args, **kwargs): return self._channel.stream_stream(method, *args, **kwargs) + def subscribe(self, *args, **kwargs): return self._channel.subscribe(*args, **kwargs) + def unsubscribe(self, *args, **kwargs): return self._channel.unsubscribe(*args, **kwargs) + def close(self): self._channel.close() diff --git a/packages/google-auth/tests_async/transport/test_aiohttp_requests.py b/packages/google-auth/tests_async/transport/test_aiohttp_requests.py index d6a24da2e302..4f4a41265d34 100644 --- a/packages/google-auth/tests_async/transport/test_aiohttp_requests.py +++ b/packages/google-auth/tests_async/transport/test_aiohttp_requests.py @@ -121,12 +121,14 @@ async def test_unsupported_session(self): with pytest.raises(ValueError): await aiohttp_requests.Request(http) - def test_timeout(self): - http = mock.create_autospec( - aiohttp.ClientSession, instance=True, _auto_decompress=False - ) - request = aiohttp_requests.Request(http) - request(url="http://example.com", method="GET", timeout=5) + @pytest.mark.asyncio + async def test_timeout(self): + http = mock.create_autospec( + aiohttp.ClientSession, instance=True, _auto_decompress=False + ) + request = aiohttp_requests.Request(http) + await request(url="http://example.com", method="GET", timeout=5) + class CredentialsStub(google.auth._credentials_async.Credentials):