Skip to content
111 changes: 105 additions & 6 deletions packages/google-auth/google/auth/transport/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

from __future__ import absolute_import

import logging
import logging

_LOGGER = logging.getLogger(__name__)
from google.auth import exceptions
from google.auth.transport import _mtls_helper
from google.oauth2 import service_account
Expand Down Expand Up @@ -208,7 +210,7 @@ def my_client_cert_callback():

channel = google.auth.transport.grpc.secure_authorized_channel(
credentials, request, mtls_endpoint)

Args:
credentials (google.auth.credentials.Credentials): The credentials to
add to requests.
Expand Down Expand Up @@ -253,6 +255,7 @@ def my_client_cert_callback():
)

# If SSL credentials are not explicitly set, try client_cert_callback and ADC.
cached_cert = None
if not ssl_credentials:
use_client_cert = _mtls_helper.check_use_client_cert()
if use_client_cert and client_cert_callback:
Expand All @@ -261,20 +264,40 @@ def my_client_cert_callback():
ssl_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
cached_cert = cert
elif use_client_cert:
# Use application default SSL credentials.
adc_ssl_credentils = SslCredentials()
ssl_credentials = adc_ssl_credentils.ssl_credentials
adc_ssl_credentials = SslCredentials()
ssl_credentials = adc_ssl_credentials.ssl_credentials
cached_cert = adc_ssl_credentials._cached_cert
else:
ssl_credentials = grpc.ssl_channel_credentials()

# Combine the ssl credentials and the authorization credentials.
composite_credentials = grpc.composite_channel_credentials(
ssl_credentials, google_auth_credentials
)

return grpc.secure_channel(target, composite_credentials, **kwargs)

is_retry = kwargs.pop("_is_retry", False)
channel = grpc.secure_channel(target, composite_credentials, **kwargs)
# Check if we are already inside a retry to avoid infinite recursion
if cached_cert and not is_retry:
# Package arguments to recreate the channel if rotation occurs
factory_args = {
"credentials": credentials,
"request": request,
"target": target,
"ssl_credentials": None,
"client_cert_callback": client_cert_callback,
"_is_retry": True, # Hidden flag to stop recursion
**kwargs
}
interceptor = _MTLSCallInterceptor()

wrapper = _MTLSRefreshingChannel(target, factory_args, channel, cached_cert)

interceptor._wrapper = wrapper
return grpc.intercept_channel(wrapper, interceptor)
return channel

class SslCredentials:
"""Class for application default SSL credentials.
Expand All @@ -292,6 +315,7 @@ class SslCredentials:

def __init__(self):
use_client_cert = _mtls_helper.check_use_client_cert()
self._cached_cert = None
if not use_client_cert:
self._is_mtls = False
else:
Expand Down Expand Up @@ -323,6 +347,7 @@ def ssl_credentials(self):
self._ssl_credentials = grpc.ssl_channel_credentials(
certificate_chain=cert, private_key=key
)
self._cached_cert = cert
except exceptions.ClientCertError as caught_exc:
new_exc = exceptions.MutualTLSChannelError(caught_exc)
raise new_exc from caught_exc
Expand All @@ -335,3 +360,77 @@ def ssl_credentials(self):
def is_mtls(self):
"""Indicates if the created SSL channel credentials is mutual TLS."""
return self._is_mtls

class _MTLSCallInterceptor(grpc.UnaryUnaryClientInterceptor):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The _MTLSCallInterceptor currently only implements intercept_unary_unary. To fully support certificate rotation for all gRPC call types, it should also implement intercept_unary_stream, intercept_stream_unary, and intercept_stream_stream.

def __init__(self):
self._wrapper = None
self._max_retries = 2 # Set your desired limit here

def _should_retry(self, code, retry_count):
if code != grpc.StatusCode.UNAUTHENTICATED or not self._wrapper:
return False

if retry_count >= self._max_retries:
_LOGGER.debug("Max retries reached (%d/%d).", retry_count, self._max_retries)
return False

# Fingerprint check logic
_, _, cached_fp, current_fp = _mtls_helper.check_parameters_for_unauthorized_response(self._wrapper._cached_cert)
return cached_fp != current_fp

def intercept_unary_unary(self, continuation, client_call_details, request):
retry_count = 0

while True:
try:
# Every time we call continuation(), our Wrapper (which is the channel
# being intercepted) will point to its CURRENT active raw channel.
response = continuation(client_call_details, request)
status_code = response.code()
except grpc.RpcError as e:
status_code = e.code()
if not self._should_retry(status_code, retry_count):
raise e
# If we should retry, we fall through to the refresh logic below

if self._should_retry(status_code, retry_count):
retry_count += 1
# Tell the wrapper to swap the channel.
# We don't need the wrapper to execute the retry; the loop does it!
self._wrapper.refresh_logic(retry_count)
continue # Jump back to the start of the while loop

return response

class _MTLSRefreshingChannel(grpc.Channel):
def __init__(self, target, factory_args, initial_channel, initial_cert):
self._target = target
self._factory_args = factory_args
self._channel = initial_channel
self._cached_cert = initial_cert
self._lock = threading.Lock()
Comment on lines +406 to +411
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

_MTLSRefreshingChannel should implement __enter__ and __exit__ to support being used as a context manager, which is a standard pattern for gRPC channels.

    def __init__(self, target, factory_args, initial_channel, initial_cert):
        self._target = target
        self._factory_args = factory_args
        self._channel = initial_channel
        self._cached_cert = initial_cert
        self._lock = threading.Lock()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


def refresh_logic(self, count):
with self._lock:
# Re-check inside lock to prevent race conditions
_, _, cached_fp, current_fp = _mtls_helper.check_parameters_for_unauthorized_response(self._cached_cert)
if cached_fp != current_fp:
_LOGGER.debug("Wrapper: Refreshing mTLS channel. Retry count: %d", count)
old_channel = self._channel
self._channel = secure_authorized_channel(**self._factory_args)

creds = _mtls_helper.get_client_ssl_credentials()
self._cached_cert = creds[1]
old_channel.close()

def unary_unary(self, method, *args, **kwargs):
# Always return a callable from the CURRENT channel
return self._channel.unary_unary(method, *args, **kwargs)

# Mandatory passthroughs
def unary_stream(self, method, *args, **kwargs): return self._channel.unary_stream(method, *args, **kwargs)
def stream_unary(self, method, *args, **kwargs): return self._channel.stream_unary(method, *args, **kwargs)
def stream_stream(self, method, *args, **kwargs): return self._channel.stream_stream(method, *args, **kwargs)
def subscribe(self, *args, **kwargs): return self._channel.subscribe(*args, **kwargs)
def unsubscribe(self, *args, **kwargs): return self._channel.unsubscribe(*args, **kwargs)
def close(self): self._channel.close()
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,14 @@ async def test_unsupported_session(self):
with pytest.raises(ValueError):
await aiohttp_requests.Request(http)

def test_timeout(self):
http = mock.create_autospec(
aiohttp.ClientSession, instance=True, _auto_decompress=False
)
request = aiohttp_requests.Request(http)
request(url="http://example.com", method="GET", timeout=5)
@pytest.mark.asyncio
async def test_timeout(self):
http = mock.create_autospec(
aiohttp.ClientSession, instance=True, _auto_decompress=False
)
request = aiohttp_requests.Request(http)
await request(url="http://example.com", method="GET", timeout=5)



class CredentialsStub(google.auth._credentials_async.Credentials):
Expand Down
Loading