# Stray text from page extraction ("Your IP :"); not part of this module.
import os
import platform
import socket
import ssl
import sys
import typing

import _ssl  # type: ignore[import-not-found]

from ._ssl_constants import (
    _original_SSLContext,
    _original_super_SSLContext,
    _truststore_SSLContext_dunder_class,
    _truststore_SSLContext_super_class,
)

if platform.system() == "Windows":
    from ._windows import _configure_context, _verify_peercerts_impl
elif platform.system() == "Darwin":
    from ._macos import _configure_context, _verify_peercerts_impl
else:
    from ._openssl import _configure_context, _verify_peercerts_impl

if typing.TYPE_CHECKING:
    from pip._vendor.typing_extensions import Buffer
# From typeshed/stdlib/ssl.pyi
_StrOrBytesPath: typing.TypeAlias = str | bytes | os.PathLike[str] | os.PathLike[bytes]
_PasswordType: typing.TypeAlias = str | bytes | typing.Callable[[], str | bytes]
def inject_into_ssl() -> None:
    """Injects the :class:`truststore.SSLContext` into the ``ssl``
    module by replacing :class:`ssl.SSLContext`.

    The ``SSLContext`` reference held by the vendored urllib3 module
    (``pip._vendor.urllib3.util.ssl_``) is replaced as well when that
    module can be imported; otherwise that step is skipped.
    """
    setattr(ssl, "SSLContext", SSLContext)
    # urllib3 holds on to its own reference of ssl.SSLContext
    # so we need to replace that reference too.
    try:
        import pip._vendor.urllib3.util.ssl_ as urllib3_ssl

        setattr(urllib3_ssl, "SSLContext", SSLContext)
    except ImportError:
        # urllib3 is not importable, so there is no second reference to patch.
        pass
def extract_from_ssl() -> None:
    """Restores the :class:`ssl.SSLContext` class to its original state

    Also restores the ``SSLContext`` reference held by the vendored urllib3
    module (``pip._vendor.urllib3.util.ssl_``) when that module can be
    imported; otherwise that step is skipped.
    """
    setattr(ssl, "SSLContext", _original_SSLContext)
    try:
        import pip._vendor.urllib3.util.ssl_ as urllib3_ssl

        urllib3_ssl.SSLContext = _original_SSLContext  # type: ignore[assignment]
    except ImportError:
        # urllib3 is not importable, so there is no second reference to restore.
        pass
class SSLContext(_truststore_SSLContext_super_class):  # type: ignore[misc]
    """SSLContext API that uses system certificates on all platforms

    Every instance owns an inner context (``self._ctx``) created from
    ``_original_SSLContext``. Configuration methods and properties delegate
    to that inner context; ``wrap_socket()`` and ``wrap_bio()`` additionally
    run ``_verify_peercerts()`` on the resulting TLS object.
    """

    @property  # type: ignore[misc]
    def __class__(self) -> type:
        # Dirty hack to get around isinstance() checks
        # for ssl.SSLContext instances in aiohttp/trustme
        # when using non-CPython implementations.
        return _truststore_SSLContext_dunder_class or SSLContext

    def __init__(self, protocol: int = None) -> None:  # type: ignore[assignment]
        """Create the inner context and hook verification into ``wrap_bio()``.

        :param protocol: forwarded unchanged to ``_original_SSLContext``.
        """
        self._ctx = _original_SSLContext(protocol)

        class TruststoreSSLObject(ssl.SSLObject):
            # This object exists because wrap_bio() doesn't
            # immediately do the handshake so we need to do
            # certificate verifications after SSLObject.do_handshake()

            def do_handshake(self) -> None:
                ret = super().do_handshake()
                _verify_peercerts(self, server_hostname=self.server_hostname)
                return ret

        self._ctx.sslobject_class = TruststoreSSLObject

    def wrap_socket(
        self,
        sock: socket.socket,
        server_side: bool = False,
        do_handshake_on_connect: bool = True,
        suppress_ragged_eofs: bool = True,
        server_hostname: str | None = None,
        session: ssl.SSLSession | None = None,
    ) -> ssl.SSLSocket:
        """Wrap ``sock`` with the inner context, then verify the peer certificates.

        All arguments are forwarded to the inner context's ``wrap_socket()``.
        Any exception raised by ``_verify_peercerts()`` propagates to the caller.
        """
        # Use a context manager here because the
        # inner SSLContext holds on to our state
        # but also does the actual handshake.
        with _configure_context(self._ctx):
            ssl_sock = self._ctx.wrap_socket(
                sock,
                server_side=server_side,
                server_hostname=server_hostname,
                do_handshake_on_connect=do_handshake_on_connect,
                suppress_ragged_eofs=suppress_ragged_eofs,
                session=session,
            )
        try:
            _verify_peercerts(ssl_sock, server_hostname=server_hostname)
        except Exception:
            # NOTE(review): the handler body was missing from the reviewed
            # source; closing the socket and re-raising is a reconstruction
            # that avoids leaking the wrapped socket -- confirm.
            ssl_sock.close()
            raise
        return ssl_sock

    def wrap_bio(
        self,
        incoming: ssl.MemoryBIO,
        outgoing: ssl.MemoryBIO,
        server_side: bool = False,
        server_hostname: str | None = None,
        session: ssl.SSLSession | None = None,
    ) -> ssl.SSLObject:
        """Wrap a pair of memory BIOs with the inner context.

        No verification is done here: the object class installed in
        ``__init__`` verifies the peer certificates after ``do_handshake()``.
        """
        with _configure_context(self._ctx):
            ssl_obj = self._ctx.wrap_bio(
                incoming,
                outgoing,
                server_hostname=server_hostname,
                server_side=server_side,
                session=session,
            )
        return ssl_obj

    def load_verify_locations(
        self,
        cafile: str | bytes | os.PathLike[str] | os.PathLike[bytes] | None = None,
        capath: str | bytes | os.PathLike[str] | os.PathLike[bytes] | None = None,
        cadata: typing.Union[str, "Buffer", None] = None,
    ) -> None:
        """Delegate to the inner context's ``load_verify_locations()``."""
        return self._ctx.load_verify_locations(
            cafile=cafile, capath=capath, cadata=cadata
        )

    def load_cert_chain(
        self,
        certfile: _StrOrBytesPath,
        keyfile: _StrOrBytesPath | None = None,
        password: _PasswordType | None = None,
    ) -> None:
        """Delegate to the inner context's ``load_cert_chain()``."""
        return self._ctx.load_cert_chain(
            certfile=certfile, keyfile=keyfile, password=password
        )

    def load_default_certs(
        self, purpose: ssl.Purpose = ssl.Purpose.SERVER_AUTH
    ) -> None:
        """Delegate to the inner context's ``load_default_certs()``."""
        return self._ctx.load_default_certs(purpose)

    def set_alpn_protocols(self, alpn_protocols: typing.Iterable[str]) -> None:
        """Delegate to the inner context's ``set_alpn_protocols()``."""
        return self._ctx.set_alpn_protocols(alpn_protocols)

    def set_npn_protocols(self, npn_protocols: typing.Iterable[str]) -> None:
        """Delegate to the inner context's ``set_npn_protocols()``."""
        return self._ctx.set_npn_protocols(npn_protocols)

    def set_ciphers(self, __cipherlist: str) -> None:
        """Delegate to the inner context's ``set_ciphers()``."""
        return self._ctx.set_ciphers(__cipherlist)

    def get_ciphers(self) -> typing.Any:
        """Return the inner context's ``get_ciphers()`` result."""
        return self._ctx.get_ciphers()

    def session_stats(self) -> dict[str, int]:
        """Return the inner context's ``session_stats()`` result."""
        return self._ctx.session_stats()

    def cert_store_stats(self) -> dict[str, int]:
        """Not supported by this wrapper; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def set_default_verify_paths(self) -> None:
        """Delegate to the inner context's ``set_default_verify_paths()``."""
        # NOTE(review): the body was missing from the reviewed source;
        # delegation mirrors the sibling methods -- confirm.
        self._ctx.set_default_verify_paths()

    @typing.overload
    def get_ca_certs(
        self, binary_form: typing.Literal[False] = ...
    ) -> list[typing.Any]: ...

    @typing.overload
    def get_ca_certs(self, binary_form: typing.Literal[True] = ...) -> list[bytes]: ...

    @typing.overload
    def get_ca_certs(self, binary_form: bool = ...) -> typing.Any: ...

    def get_ca_certs(self, binary_form: bool = False) -> list[typing.Any] | list[bytes]:
        """Not supported by this wrapper; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    # The properties below read from the inner context. Setters either assign
    # the attribute on the inner context or call the descriptor on
    # ``_original_super_SSLContext`` with the inner context as the target.
    # NOTE(review): the descriptor route is presumably there to bypass the
    # stdlib ``ssl.SSLContext`` property setters once ``ssl.SSLContext`` has
    # been replaced by ``inject_into_ssl()`` -- confirm.

    @property
    def check_hostname(self) -> bool:
        return self._ctx.check_hostname

    @check_hostname.setter
    def check_hostname(self, value: bool) -> None:
        self._ctx.check_hostname = value

    @property
    def hostname_checks_common_name(self) -> bool:
        return self._ctx.hostname_checks_common_name

    @hostname_checks_common_name.setter
    def hostname_checks_common_name(self, value: bool) -> None:
        self._ctx.hostname_checks_common_name = value

    @property
    def keylog_filename(self) -> str:
        return self._ctx.keylog_filename

    @keylog_filename.setter
    def keylog_filename(self, value: str) -> None:
        self._ctx.keylog_filename = value

    @property
    def maximum_version(self) -> ssl.TLSVersion:
        return self._ctx.maximum_version

    @maximum_version.setter
    def maximum_version(self, value: ssl.TLSVersion) -> None:
        _original_super_SSLContext.maximum_version.__set__(  # type: ignore[attr-defined]
            self._ctx, value
        )

    @property
    def minimum_version(self) -> ssl.TLSVersion:
        return self._ctx.minimum_version

    @minimum_version.setter
    def minimum_version(self, value: ssl.TLSVersion) -> None:
        _original_super_SSLContext.minimum_version.__set__(  # type: ignore[attr-defined]
            self._ctx, value
        )

    @property
    def options(self) -> ssl.Options:
        return self._ctx.options

    @options.setter
    def options(self, value: ssl.Options) -> None:
        _original_super_SSLContext.options.__set__(  # type: ignore[attr-defined]
            self._ctx, value
        )

    @property
    def post_handshake_auth(self) -> bool:
        return self._ctx.post_handshake_auth

    @post_handshake_auth.setter
    def post_handshake_auth(self, value: bool) -> None:
        self._ctx.post_handshake_auth = value

    @property
    def protocol(self) -> ssl._SSLMethod:
        return self._ctx.protocol

    @property
    def security_level(self) -> int:
        return self._ctx.security_level

    @property
    def verify_flags(self) -> ssl.VerifyFlags:
        return self._ctx.verify_flags

    @verify_flags.setter
    def verify_flags(self, value: ssl.VerifyFlags) -> None:
        _original_super_SSLContext.verify_flags.__set__(  # type: ignore[attr-defined]
            self._ctx, value
        )

    @property
    def verify_mode(self) -> ssl.VerifyMode:
        return self._ctx.verify_mode

    @verify_mode.setter
    def verify_mode(self, value: ssl.VerifyMode) -> None:
        _original_super_SSLContext.verify_mode.__set__(  # type: ignore[attr-defined]
            self._ctx, value
        )
# Python 3.13+ makes get_unverified_chain() a public API that only returns DER
# encoded certificates. We detect whether we need to call public_bytes() for 3.10->3.12
# Pre-3.13 returned None instead of an empty list from get_unverified_chain()
if sys.version_info >= (3, 13):
def _get_unverified_chain_bytes(sslobj: ssl.SSLObject) -> list[bytes]:
unverified_chain = sslobj.get_unverified_chain() or () # type: ignore[attr-defined]
return [
cert if isinstance(cert, bytes) else cert.public_bytes(_ssl.ENCODING_DER)
for cert in unverified_chain
def _get_unverified_chain_bytes(sslobj: ssl.SSLObject) -> list[bytes]:
unverified_chain = sslobj.get_unverified_chain() or () # type: ignore[attr-defined]
return [cert.public_bytes(_ssl.ENCODING_DER) for cert in unverified_chain]
def _verify_peercerts(
    sock_or_sslobj: ssl.SSLSocket | ssl.SSLObject, server_hostname: str | None
) -> None:
    """
    Verifies the peer certificates from an SSLSocket or SSLObject
    against the certificates in the OS trust store.

    :param sock_or_sslobj: TLS object whose ``context`` and peer chain are used.
    :param server_hostname: forwarded to the platform verification routine.
    """
    sslobj: ssl.SSLObject = sock_or_sslobj  # type: ignore[assignment]
    # Follow the ``_sslobj`` attribute until we reach an object exposing
    # get_unverified_chain(). If an object in the chain has no ``_sslobj``
    # attribute, stop unwrapping and continue with what we have.
    try:
        while not hasattr(sslobj, "get_unverified_chain"):
            sslobj = sslobj._sslobj  # type: ignore[attr-defined]
    except AttributeError:
        pass

    cert_bytes = _get_unverified_chain_bytes(sslobj)
    # The platform-specific implementation is selected at import time.
    _verify_peercerts_impl(
        sock_or_sslobj.context, cert_bytes, server_hostname=server_hostname
    )