Your IP : 18.117.168.40
# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls
# mypy: no-warn-return-any, allow-any-generics
from __future__ import annotations
from typing import Any
from typing import ClassVar
from typing import Dict
from typing import Generic
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import UniqueConstraint
from typing_extensions import TypeGuard
from .. import util
from ..util import sqla_compat
if TYPE_CHECKING:
from typing import Literal
from alembic.autogenerate.api import AutogenContext
from alembic.ddl.impl import DefaultImpl
# The kinds of schema objects that autogenerate compares: table-level
# constraints as well as indexes.
CompareConstraintType = Union[Constraint, Index]

# Type variable over the comparable objects; parameterizes _constraint_sig.
_C = TypeVar("_C", bound=CompareConstraintType)

# Registry mapping a schema object's ``__visit_name__`` to the
# _constraint_sig subclass that handles it.  Each subclass fills in its own
# entry from its ``_register()`` classmethod.
_clsreg: Dict[str, Type[_constraint_sig]] = dict()
class ComparisonResult(NamedTuple):
    """Outcome of comparing a metadata-side object with a reflected one.

    ``status`` is one of ``"equal"``, ``"different"`` or ``"skip"``;
    ``message`` is a human-readable explanation of that status.  Build
    instances through the :meth:`Equal`, :meth:`Different` and
    :meth:`Skip` factory classmethods.
    """

    status: Literal["equal", "different", "skip"]
    message: str

    @property
    def is_equal(self) -> bool:
        """True when :attr:`status` is ``"equal"``."""
        status = self.status
        return status == "equal"

    @property
    def is_different(self) -> bool:
        """True when :attr:`status` is ``"different"``."""
        status = self.status
        return status == "different"

    @property
    def is_skip(self) -> bool:
        """True when :attr:`status` is ``"skip"``."""
        status = self.status
        return status == "skip"

    @classmethod
    def Equal(cls) -> ComparisonResult:
        """Result stating that the two constraints are equal."""
        return cls(status="equal", message="The two constraints are equal")

    @classmethod
    def Different(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult:
        """Result stating that the constraints differ.

        :param reason: a single reason, or several reasons that are joined
         with ``", "`` into the message.
        """
        reasons = util.to_list(reason)
        return cls(status="different", message=", ".join(reasons))

    @classmethod
    def Skip(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult:
        """Result stating that the constraint could not be compared.

        The message is logged, but the constraints will be otherwise
        considered equal, meaning that no migration command will be
        generated.

        :param reason: a single reason, or several reasons that are joined
         with ``", "`` into the message.
        """
        reasons = util.to_list(reason)
        return cls(status="skip", message=", ".join(reasons))
class _constraint_sig(Generic[_C]):
    """Hashable, comparable "signature" of a constraint or index.

    Concrete subclasses wrap one kind of schema object.  On class creation
    they register themselves in ``_clsreg`` under that object's
    ``__visit_name__`` (see :meth:`_register`), and their ``__init__``
    fills in ``impl``, ``const``, ``name`` and the signature data.  Every
    instance records whether it was built from the user's metadata
    (``_is_metadata=True``) or from the reflected database.
    """

    const: _C

    _sig: Tuple[Any, ...]
    name: Optional[sqla_compat._ConstraintNameDefined]

    impl: DefaultImpl

    # Kind flags; subclasses flip exactly one of these to True.  They are
    # read by is_index_sig / is_uq_sig / is_fk_sig.
    _is_index: ClassVar[bool] = False
    _is_fk: ClassVar[bool] = False
    _is_uq: ClassVar[bool] = False

    _is_metadata: bool

    def __init_subclass__(cls, **kw: Any) -> None:
        # Chain to Generic.__init_subclass__ first.  It computes the
        # subclass's ``__parameters__``; skipping it would leave subclasses
        # such as ``_constraint_sig[Index]`` wrongly reporting ``(_C,)``.
        super().__init_subclass__(**kw)
        cls._register()

    @classmethod
    def _register(cls):
        """Add ``cls`` to ``_clsreg``; every concrete subclass overrides."""
        raise NotImplementedError()

    def __init__(
        self, is_metadata: bool, impl: DefaultImpl, const: _C
    ) -> None:
        raise NotImplementedError()

    def compare_to_reflected(
        self, other: _constraint_sig[Any]
    ) -> ComparisonResult:
        """Compare this metadata-side signature with a reflected one.

        Both signatures must share the same ``impl``.  ``self`` must come
        from metadata and ``other`` from reflection.  The actual
        comparison is delegated to :meth:`_compare_to_reflected`.
        """
        assert self.impl is other.impl
        assert self._is_metadata
        assert not other._is_metadata

        return self._compare_to_reflected(other)

    def _compare_to_reflected(
        self, other: _constraint_sig[_C]
    ) -> ComparisonResult:
        """Subclass hook performing the kind-specific comparison."""
        raise NotImplementedError()

    @classmethod
    def from_constraint(
        cls, is_metadata: bool, impl: DefaultImpl, constraint: _C
    ) -> _constraint_sig[_C]:
        """Build the registered signature subclass for ``constraint``.

        The subclass is looked up in ``_clsreg`` by the object's
        ``__visit_name__``.  An unregistered kind raises ``KeyError``.
        """
        # these could be cached by constraint/impl, however, if the
        # constraint is modified in place, then the sig is wrong. the mysql
        # impl currently does this, and if we fixed that we can't be sure
        # someone else might do it too, so play it safe.
        sig = _clsreg[constraint.__visit_name__](is_metadata, impl, constraint)
        return sig

    def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]:
        """Return the constraint's final name for ``context.dialect``."""
        return sqla_compat._get_constraint_final_name(
            self.const, context.dialect
        )

    @util.memoized_property
    def is_named(self):
        """Whether the constraint counts as named for ``impl.dialect``."""
        return sqla_compat._constraint_is_named(self.const, self.impl.dialect)

    @util.memoized_property
    def unnamed(self) -> Tuple[Any, ...]:
        """The signature without the name; by default ``_sig``."""
        return self._sig

    @util.memoized_property
    def unnamed_no_options(self) -> Tuple[Any, ...]:
        """Signature without name or options; overridden where supported."""
        raise NotImplementedError()

    @util.memoized_property
    def _full_sig(self) -> Tuple[Any, ...]:
        # name + unnamed signature; the basis for __eq__/__ne__/__hash__
        return (self.name,) + self.unnamed

    def __eq__(self, other: object) -> bool:
        # Non-signature objects have no ``_full_sig``.  Return
        # NotImplemented for them so Python falls back to its default
        # handling, instead of raising AttributeError.
        if not isinstance(other, _constraint_sig):
            return NotImplemented
        return self._full_sig == other._full_sig

    def __ne__(self, other: object) -> bool:
        if not isinstance(other, _constraint_sig):
            return NotImplemented
        return self._full_sig != other._full_sig

    def __hash__(self) -> int:
        # consistent with __eq__: both are derived from _full_sig
        return hash(self._full_sig)
class _uq_constraint_sig(_constraint_sig[UniqueConstraint]):
    """Signature for a :class:`UniqueConstraint`.

    The unnamed signature is the tuple of the constraint's column names in
    sorted order, so column ordering does not affect equality.
    """

    _is_uq = True
    is_unique = True

    @classmethod
    def _register(cls) -> None:
        _clsreg["unique_constraint"] = cls

    def __init__(
        self,
        is_metadata: bool,
        impl: DefaultImpl,
        const: UniqueConstraint,
    ) -> None:
        self._is_metadata = is_metadata
        self.impl = impl
        self.const = const
        self.name = sqla_compat.constraint_name_or_none(const.name)
        self._sig = tuple(sorted(col.name for col in const.columns))

    @property
    def column_names(self) -> Tuple[str, ...]:
        """Column names in the constraint's declared order (not sorted)."""
        return tuple(col.name for col in self.const.columns)

    def _compare_to_reflected(
        self, other: _constraint_sig[_C]
    ) -> ComparisonResult:
        """Delegate to the impl's ``compare_unique_constraint`` hook."""
        assert self._is_metadata
        assert is_uq_sig(other)
        return self.impl.compare_unique_constraint(self.const, other.const)
class _ix_constraint_sig(_constraint_sig[Index]):
    """Signature for an :class:`Index`.

    The unnamed signature is the uniqueness flag followed by the name of
    each index expression (``None`` for an expression with no ``name``),
    in declaration order.  Indexes always report themselves as named.
    """

    _is_index = True

    name: sqla_compat._ConstraintName

    @classmethod
    def _register(cls) -> None:
        _clsreg["index"] = cls

    def __init__(
        self, is_metadata: bool, impl: DefaultImpl, const: Index
    ) -> None:
        self._is_metadata = is_metadata
        self.impl = impl
        self.const = const
        self.name = const.name
        self.is_unique = bool(const.unique)

    def _compare_to_reflected(
        self, other: _constraint_sig[_C]
    ) -> ComparisonResult:
        """Delegate to the impl's ``compare_indexes`` hook."""
        assert self._is_metadata
        assert is_index_sig(other)
        return self.impl.compare_indexes(self.const, other.const)

    @util.memoized_property
    def has_expressions(self):
        """Whether sqla_compat classifies this as an expression index."""
        return sqla_compat.is_expression_index(self.const)

    @util.memoized_property
    def column_names(self) -> Tuple[str, ...]:
        """Names of the index's table columns, in order."""
        return tuple(col.name for col in self.const.columns)

    @util.memoized_property
    def column_names_optional(self) -> Tuple[Optional[str], ...]:
        """Each index expression's ``name``, or ``None`` when it has none."""
        return tuple(
            getattr(expr, "name", None) for expr in self.const.expressions
        )

    @util.memoized_property
    def is_named(self):
        # unlike the base class, no dialect check: indexes count as named
        return True

    @util.memoized_property
    def unnamed(self):
        return (self.is_unique, *self.column_names_optional)
class _fk_constraint_sig(_constraint_sig[ForeignKeyConstraint]):
    """Signature for a :class:`ForeignKeyConstraint`.

    The unnamed signature holds source/target schema, table and columns,
    then the normalized ON UPDATE and ON DELETE actions, and finally a
    single three-state deferrability value.
    """

    _is_fk = True

    @classmethod
    def _register(cls) -> None:
        _clsreg["foreign_key_constraint"] = cls

    def __init__(
        self,
        is_metadata: bool,
        impl: DefaultImpl,
        const: ForeignKeyConstraint,
    ) -> None:
        self._is_metadata = is_metadata
        self.impl = impl
        self.const = const
        self.name = sqla_compat.constraint_name_or_none(const.name)
        (
            self.source_schema,
            self.source_table,
            self.source_columns,
            self.target_schema,
            self.target_table,
            self.target_columns,
            onupdate,
            ondelete,
            deferrable,
            initially,
        ) = sqla_compat._fk_spec(const)

        def normalize_action(action):
            # lower-case the action.  A missing action and an explicit
            # "no action" both collapse to None.
            if not action:
                return None
            lowered = action.lower()
            return None if lowered == "no action" else lowered

        # fold INITIALLY + DEFERRABLE into one three-state value
        if initially and initially.lower() == "deferred":
            deferral = "initially_deferrable"
        elif deferrable:
            deferral = "deferrable"
        else:
            deferral = "not deferrable"

        self._sig: Tuple[Any, ...] = (
            self.source_schema,
            self.source_table,
            tuple(self.source_columns),
            self.target_schema,
            self.target_table,
            tuple(self.target_columns),
            normalize_action(onupdate),
            normalize_action(ondelete),
            deferral,
        )

    @util.memoized_property
    def unnamed_no_options(self):
        # only the source and target column references, with no actions
        # and no deferrability
        source_part = (
            self.source_schema,
            self.source_table,
            tuple(self.source_columns),
        )
        target_part = (
            self.target_schema,
            self.target_table,
            tuple(self.target_columns),
        )
        return source_part + target_part
def is_index_sig(sig: _constraint_sig) -> TypeGuard[_ix_constraint_sig]:
    """Tell whether ``sig`` is an index signature.

    Reads the class-level ``_is_index`` flag, and narrows ``sig`` to
    ``_ix_constraint_sig`` for type checkers.
    """
    flag = sig._is_index
    return flag
def is_uq_sig(sig: _constraint_sig) -> TypeGuard[_uq_constraint_sig]:
    """Tell whether ``sig`` is a unique-constraint signature.

    Reads the class-level ``_is_uq`` flag, and narrows ``sig`` to
    ``_uq_constraint_sig`` for type checkers.
    """
    flag = sig._is_uq
    return flag
def is_fk_sig(sig: _constraint_sig) -> TypeGuard[_fk_constraint_sig]:
    """Tell whether ``sig`` is a foreign-key-constraint signature.

    Reads the class-level ``_is_fk`` flag, and narrows ``sig`` to
    ``_fk_constraint_sig`` for type checkers.
    """
    flag = sig._is_fk
    return flag