Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 7 additions & 27 deletions packages/google-auth/google/auth/aio/transport/mtls.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,34 +25,12 @@
from typing import Optional

from google.auth import exceptions
import google.auth.transport._mtls_helper
import google.auth.transport.mtls
from google.auth.transport._mtls_helper import secure_cert_key_paths

_LOGGER = logging.getLogger(__name__)


@contextlib.contextmanager
def _create_temp_file(content: bytes):
"""Creates a temporary file with the given content.

Args:
content (bytes): The content to write to the file.

Yields:
str: The path to the temporary file.
"""
# Create a temporary file that is readable only by the owner.
fd, file_path = tempfile.mkstemp()
try:
with os.fdopen(fd, "wb") as f:
f.write(content)
yield file_path
finally:
# Securely delete the file after use.
if os.path.exists(file_path):
os.remove(file_path)


def make_client_cert_ssl_context(
cert_bytes: bytes, key_bytes: bytes, passphrase: Optional[bytes] = None
) -> ssl.SSLContext:
Expand All @@ -71,13 +49,15 @@ def make_client_cert_ssl_context(
Raises:
google.auth.exceptions.TransportError: If there is an error loading the certificate.
"""
with _create_temp_file(cert_bytes) as cert_path, _create_temp_file(
key_bytes
) as key_path:
with secure_cert_key_paths(cert_bytes, key_bytes, passphrase=passphrase) as (
cert_path,
key_path,
passphrase_val,
):
try:
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
context.load_cert_chain(
certfile=cert_path, keyfile=key_path, password=passphrase
certfile=cert_path, keyfile=key_path, password=passphrase_val
)
return context
except (ssl.SSLError, OSError, IOError, ValueError, RuntimeError) as exc:
Expand Down
26 changes: 8 additions & 18 deletions packages/google-auth/google/auth/identity_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,9 @@ def __init__(self, trust_chain_path, leaf_cert_callback):

@_helpers.copy_docstring(SubjectTokenSupplier)
def get_subject_token(self, context, request):
# Import OpennSSL inline because it is an extra import only required by customers
# using mTLS.
from OpenSSL import crypto
from cryptography import x509

leaf_cert = crypto.load_certificate(
crypto.FILETYPE_PEM, self._leaf_cert_callback()
)
leaf_cert = x509.load_pem_x509_certificate(self._leaf_cert_callback())
trust_chain = self._read_trust_chain()
cert_chain = []

Expand All @@ -184,9 +180,7 @@ def get_subject_token(self, context, request):
return json.dumps(cert_chain)

def _read_trust_chain(self):
# Import OpennSSL inline because it is an extra import only required by customers
# using mTLS.
from OpenSSL import crypto
from cryptography import x509

certificate_trust_chain = []
# If no trust chain path was provided, return an empty list.
Expand All @@ -204,9 +198,7 @@ def _read_trust_chain(self):
cert_data = b"-----BEGIN CERTIFICATE-----" + cert_block
try:
# Load each certificate and add it to the trust chain.
cert = crypto.load_certificate(
crypto.FILETYPE_PEM, cert_data
)
cert = x509.load_pem_x509_certificate(cert_data)
certificate_trust_chain.append(cert)
except Exception as e:
raise exceptions.RefreshError(
Expand All @@ -221,13 +213,11 @@ def _read_trust_chain(self):
)

def _encode_cert(cert):
# Import OpennSSL inline because it is an extra import only required by customers
# using mTLS.
from OpenSSL import crypto
from cryptography.hazmat.primitives import serialization

return base64.b64encode(
crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)
).decode("utf-8")
return base64.b64encode(cert.public_bytes(serialization.Encoding.DER)).decode(
"utf-8"
)


def _parse_token_data(token_content, format_type="text", subject_token_field_name=None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
import os
import sys

import cffi # type: ignore

from google.auth import exceptions

_LOGGER = logging.getLogger(__name__)
Expand All @@ -45,11 +43,6 @@
)


# Cast SSL_CTX* to void*
def _cast_ssl_ctx_to_void_p_pyopenssl(ssl_ctx):
return ctypes.cast(int(cffi.FFI().cast("intptr_t", ssl_ctx)), ctypes.c_void_p)


# Cast SSL_CTX* to void*
def _cast_ssl_ctx_to_void_p_stdlib(context):
return ctypes.c_void_p.from_address(
Expand Down Expand Up @@ -274,7 +267,7 @@ def attach_to_ssl_context(self, ctx):
if not self._offload_lib.ConfigureSslContext(
self._sign_callback,
ctypes.c_char_p(self._cert),
_cast_ssl_ctx_to_void_p_pyopenssl(ctx._ctx._context),
_cast_ssl_ctx_to_void_p_stdlib(ctx),
):
raise exceptions.MutualTLSChannelError(
"failed to configure ECP Offload SSL context"
Expand Down
Loading
Loading