-
Notifications
You must be signed in to change notification settings - Fork 1.7k
feat(google-auth): grpc cert rotation handling #16597
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
agrawalradhika-cell
wants to merge
9
commits into
googleapis:main
Choose a base branch
from
agrawalradhika-cell:agrawalradhika-cell-patch-1
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+113
−12
Draft
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
54a513a
Update grpc.py for cert rotation handling
agrawalradhika-cell b0eafcf
Update grpc.py for cert rotation handling
agrawalradhika-cell 271c279
chore: fix the lint spaces
agrawalradhika-cell 1d6a2fe
fix: Add await for async tests
agrawalradhika-cell cbd868e
Apply suggestion from @gemini-code-assist[bot]
agrawalradhika-cell bcb6e41
Apply suggestion from @gemini-code-assist[bot]
agrawalradhika-cell f6c49cb
Apply suggestion from @gemini-code-assist[bot]
agrawalradhika-cell fe8deed
Apply suggestion from @gemini-code-assist[bot]
agrawalradhika-cell 7d39f3f
Apply suggestion from @gemini-code-assist[bot]
agrawalradhika-cell File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,8 +16,10 @@ | |
|
|
||
| from __future__ import absolute_import | ||
|
|
||
| import logging | ||
| import logging | ||
|
|
||
| _LOGGER = logging.getLogger(__name__) | ||
| from google.auth import exceptions | ||
| from google.auth.transport import _mtls_helper | ||
| from google.oauth2 import service_account | ||
|
|
@@ -208,7 +210,7 @@ def my_client_cert_callback(): | |
|
|
||
| channel = google.auth.transport.grpc.secure_authorized_channel( | ||
| credentials, request, mtls_endpoint) | ||
|
|
||
| Args: | ||
| credentials (google.auth.credentials.Credentials): The credentials to | ||
| add to requests. | ||
|
|
@@ -253,6 +255,7 @@ def my_client_cert_callback(): | |
| ) | ||
|
|
||
| # If SSL credentials are not explicitly set, try client_cert_callback and ADC. | ||
| cached_cert = None | ||
| if not ssl_credentials: | ||
| use_client_cert = _mtls_helper.check_use_client_cert() | ||
| if use_client_cert and client_cert_callback: | ||
|
|
@@ -261,20 +264,40 @@ def my_client_cert_callback(): | |
| ssl_credentials = grpc.ssl_channel_credentials( | ||
| certificate_chain=cert, private_key=key | ||
| ) | ||
| cached_cert = cert | ||
| elif use_client_cert: | ||
| # Use application default SSL credentials. | ||
| adc_ssl_credentils = SslCredentials() | ||
| ssl_credentials = adc_ssl_credentils.ssl_credentials | ||
| adc_ssl_credentials = SslCredentials() | ||
| ssl_credentials = adc_ssl_credentials.ssl_credentials | ||
| cached_cert = adc_ssl_credentials._cached_cert | ||
| else: | ||
| ssl_credentials = grpc.ssl_channel_credentials() | ||
|
|
||
| # Combine the ssl credentials and the authorization credentials. | ||
| composite_credentials = grpc.composite_channel_credentials( | ||
| ssl_credentials, google_auth_credentials | ||
| ) | ||
|
|
||
| return grpc.secure_channel(target, composite_credentials, **kwargs) | ||
|
|
||
| is_retry = kwargs.pop("_is_retry", False) | ||
| channel = grpc.secure_channel(target, composite_credentials, **kwargs) | ||
| # Check if we are already inside a retry to avoid infinite recursion | ||
| if cached_cert and not is_retry: | ||
| # Package arguments to recreate the channel if rotation occurs | ||
| factory_args = { | ||
| "credentials": credentials, | ||
| "request": request, | ||
| "target": target, | ||
| "ssl_credentials": None, | ||
| "client_cert_callback": client_cert_callback, | ||
| "_is_retry": True, # Hidden flag to stop recursion | ||
| **kwargs | ||
| } | ||
| interceptor = _MTLSCallInterceptor() | ||
|
|
||
| wrapper = _MTLSRefreshingChannel(target, factory_args, channel, cached_cert) | ||
|
|
||
| interceptor._wrapper = wrapper | ||
| return grpc.intercept_channel(wrapper, interceptor) | ||
| return channel | ||
|
|
||
| class SslCredentials: | ||
| """Class for application default SSL credentials. | ||
|
|
@@ -292,6 +315,7 @@ class SslCredentials: | |
|
|
||
| def __init__(self): | ||
| use_client_cert = _mtls_helper.check_use_client_cert() | ||
| self._cached_cert = None | ||
| if not use_client_cert: | ||
| self._is_mtls = False | ||
| else: | ||
|
|
@@ -323,6 +347,7 @@ def ssl_credentials(self): | |
| self._ssl_credentials = grpc.ssl_channel_credentials( | ||
| certificate_chain=cert, private_key=key | ||
| ) | ||
| self._cached_cert = cert | ||
| except exceptions.ClientCertError as caught_exc: | ||
| new_exc = exceptions.MutualTLSChannelError(caught_exc) | ||
| raise new_exc from caught_exc | ||
|
|
@@ -335,3 +360,77 @@ def ssl_credentials(self): | |
| def is_mtls(self): | ||
| """Indicates if the created SSL channel credentials is mutual TLS.""" | ||
| return self._is_mtls | ||
|
|
||
| class _MTLSCallInterceptor(grpc.UnaryUnaryClientInterceptor): | ||
| def __init__(self): | ||
| self._wrapper = None | ||
| self._max_retries = 2 # Set your desired limit here | ||
|
|
||
| def _should_retry(self, code, retry_count): | ||
| if code != grpc.StatusCode.UNAUTHENTICATED or not self._wrapper: | ||
| return False | ||
|
|
||
| if retry_count >= self._max_retries: | ||
| _LOGGER.debug("Max retries reached (%d/%d).", retry_count, self._max_retries) | ||
| return False | ||
|
|
||
| # Fingerprint check logic | ||
| _, _, cached_fp, current_fp = _mtls_helper.check_parameters_for_unauthorized_response(self._wrapper._cached_cert) | ||
| return cached_fp != current_fp | ||
|
|
||
| def intercept_unary_unary(self, continuation, client_call_details, request): | ||
| retry_count = 0 | ||
|
|
||
| while True: | ||
| try: | ||
| # Every time we call continuation(), our Wrapper (which is the channel | ||
| # being intercepted) will point to its CURRENT active raw channel. | ||
| response = continuation(client_call_details, request) | ||
| status_code = response.code() | ||
| except grpc.RpcError as e: | ||
| status_code = e.code() | ||
| if not self._should_retry(status_code, retry_count): | ||
| raise e | ||
| # If we should retry, we fall through to the refresh logic below | ||
|
|
||
| if self._should_retry(status_code, retry_count): | ||
| retry_count += 1 | ||
| # Tell the wrapper to swap the channel. | ||
| # We don't need the wrapper to execute the retry; the loop does it! | ||
| self._wrapper.refresh_logic(retry_count) | ||
| continue # Jump back to the start of the while loop | ||
|
|
||
| return response | ||
|
|
||
| class _MTLSRefreshingChannel(grpc.Channel): | ||
| def __init__(self, target, factory_args, initial_channel, initial_cert): | ||
| self._target = target | ||
| self._factory_args = factory_args | ||
| self._channel = initial_channel | ||
| self._cached_cert = initial_cert | ||
| self._lock = threading.Lock() | ||
|
Comment on lines
+406
to
+411
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
def __init__(self, target, factory_args, initial_channel, initial_cert):
self._target = target
self._factory_args = factory_args
self._channel = initial_channel
self._cached_cert = initial_cert
self._lock = threading.Lock()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close() |
||
|
|
||
| def refresh_logic(self, count): | ||
| with self._lock: | ||
| # Re-check inside lock to prevent race conditions | ||
| _, _, cached_fp, current_fp = _mtls_helper.check_parameters_for_unauthorized_response(self._cached_cert) | ||
| if cached_fp != current_fp: | ||
| _LOGGER.debug("Wrapper: Refreshing mTLS channel. Retry count: %d", count) | ||
| old_channel = self._channel | ||
| self._channel = secure_authorized_channel(**self._factory_args) | ||
|
|
||
| creds = _mtls_helper.get_client_ssl_credentials() | ||
| self._cached_cert = creds[1] | ||
| old_channel.close() | ||
|
|
||
| def unary_unary(self, method, *args, **kwargs): | ||
| # Always return a callable from the CURRENT channel | ||
| return self._channel.unary_unary(method, *args, **kwargs) | ||
|
|
||
| # Mandatory passthroughs | ||
| def unary_stream(self, method, *args, **kwargs): return self._channel.unary_stream(method, *args, **kwargs) | ||
| def stream_unary(self, method, *args, **kwargs): return self._channel.stream_unary(method, *args, **kwargs) | ||
| def stream_stream(self, method, *args, **kwargs): return self._channel.stream_stream(method, *args, **kwargs) | ||
| def subscribe(self, *args, **kwargs): return self._channel.subscribe(*args, **kwargs) | ||
| def unsubscribe(self, *args, **kwargs): return self._channel.unsubscribe(*args, **kwargs) | ||
| def close(self): self._channel.close() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
_MTLSCallInterceptorcurrently only implementsintercept_unary_unary. To fully support certificate rotation for all gRPC call types, it should also implementintercept_unary_stream,intercept_stream_unary, andintercept_stream_stream.