Your IP : 18.116.23.103


Current Path : /opt/cloudlinux/venv/lib64/python3.11/site-packages/alembic/ddl/
Upload File :
Current File : //opt/cloudlinux/venv/lib64/python3.11/site-packages/alembic/ddl/impl.py

from __future__ import annotations

from collections import namedtuple
import re
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import Union

from sqlalchemy import cast
from sqlalchemy import schema
from sqlalchemy import text

from . import base
from .. import util
from ..util import sqla_compat

if TYPE_CHECKING:
    from typing import Literal
    from typing import TextIO

    from sqlalchemy.engine import Connection
    from sqlalchemy.engine import Dialect
    from sqlalchemy.engine.cursor import CursorResult
    from sqlalchemy.engine.reflection import Inspector
    from sqlalchemy.sql.elements import ClauseElement
    from sqlalchemy.sql.elements import ColumnElement
    from sqlalchemy.sql.elements import quoted_name
    from sqlalchemy.sql.schema import Column
    from sqlalchemy.sql.schema import Constraint
    from sqlalchemy.sql.schema import ForeignKeyConstraint
    from sqlalchemy.sql.schema import Index
    from sqlalchemy.sql.schema import Table
    from sqlalchemy.sql.schema import UniqueConstraint
    from sqlalchemy.sql.selectable import TableClause
    from sqlalchemy.sql.type_api import TypeEngine

    from .base import _ServerDefault
    from ..autogenerate.api import AutogenContext
    from ..operations.batch import ApplyBatchImpl
    from ..operations.batch import BatchOperationsImpl


class ImplMeta(type):
    def __init__(
        cls,
        classname: str,
        bases: Tuple[Type[DefaultImpl]],
        dict_: Dict[str, Any],
    ):
        newtype = type.__init__(cls, classname, bases, dict_)
        if "__dialect__" in dict_:
            _impls[dict_["__dialect__"]] = cls  # type: ignore[assignment]
        return newtype


_impls: Dict[str, Type[DefaultImpl]] = {}

Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])


class DefaultImpl(metaclass=ImplMeta):
    """Provide the entrypoint for major migration operations,
    including database-specific behavioral variances.

    While individual SQL/DDL constructs already provide
    for database-specific implementations, variances here
    allow for entirely different sequences of operations
    to take place for a particular migration, such as
    SQL Server's special 'IDENTITY INSERT' step for
    bulk inserts.

    Subclasses that set ``__dialect__`` in their class body are registered
    by the metaclass and returned from :meth:`get_by_dialect` for that
    dialect name.

    """

    __dialect__ = "default"

    # Class-wide default; an explicit ``transactional_ddl`` passed to
    # __init__ overrides it on the instance.
    transactional_ddl = False
    # Appended to every statement rendered in ``as_sql`` mode and to the
    # BEGIN / COMMIT strings.
    command_terminator = ";"
    # Groups of type names that compare_type() treats as equivalent.
    type_synonyms: Tuple[Set[str], ...] = ({"NUMERIC", "DECIMAL"},)
    # Regex patterns with a capture group 1; if a pattern matches both the
    # reflected and the metadata type words but captures different text,
    # compare_type() reports a difference.
    type_arg_extract: Sequence[str] = ()
    # Identity attributes removed from the identity diff;
    # on_null is known to be supported only by oracle.
    identity_attrs_ignore: Tuple[str, ...] = ("on_null",)

    def __init__(
        self,
        dialect: Dialect,
        connection: Optional[Connection],
        as_sql: bool,
        transactional_ddl: Optional[bool],
        output_buffer: Optional[TextIO],
        context_opts: Dict[str, Any],
    ) -> None:
        """Set up the implementation for one migration context.

        :param dialect: SQLAlchemy dialect used to compile constructs.
        :param connection: connection used to execute statements when not
         in ``as_sql`` mode (asserted non-``None`` at execution time).
        :param as_sql: when true, statements are rendered as text to
         ``output_buffer`` instead of being executed.
        :param transactional_ddl: overrides the class-level
         ``transactional_ddl`` flag when not ``None``.
        :param output_buffer: text stream receiving rendered SQL.
        :param context_opts: environment options; ``literal_binds`` is read
         from here (default ``False``).
        :raises util.CommandError: if ``literal_binds`` is set without
         ``as_sql``.
        """
        self.dialect = dialect
        self.connection = connection
        self.as_sql = as_sql
        self.literal_binds = context_opts.get("literal_binds", False)

        self.output_buffer = output_buffer
        # Per-instance scratch dict; not read anywhere in this class,
        # presumably used by dialect subclasses or callers.
        self.memo: dict = {}
        self.context_opts = context_opts
        if transactional_ddl is not None:
            self.transactional_ddl = transactional_ddl

        # literal_binds only affects rendered SQL, so reject it when
        # statements are executed on a connection.
        if self.literal_binds and not self.as_sql:
            raise util.CommandError(
                "Can't use literal_binds setting without as_sql mode"
            )

    @classmethod
    def get_by_dialect(cls, dialect: Dialect) -> Type[DefaultImpl]:
        """Return the implementation class registered for ``dialect.name``.

        :raises KeyError: if no class declared that ``__dialect__``.
        """
        return _impls[dialect.name]

    def static_output(self, text: str) -> None:
        """Write ``text`` followed by a blank line to the output buffer and
        flush it.  Requires ``output_buffer`` to be set (asserted)."""
        assert self.output_buffer is not None
        self.output_buffer.write(text + "\n\n")
        self.output_buffer.flush()

    def requires_recreate_in_batch(
        self, batch_op: BatchOperationsImpl
    ) -> bool:
        """Return True if the given :class:`.BatchOperationsImpl`
        would need the table to be recreated and copied in order to
        proceed.

        Normally, only returns True on SQLite when operations other
        than add_column are present.

        """
        return False

    def prep_table_for_batch(
        self, batch_impl: ApplyBatchImpl, table: Table
    ) -> None:
        """Perform any operations needed on a table before a new
        one is created to replace it in batch mode.

        The PG dialect uses this to drop constraints on the table
        before the new one uses those same names.  No-op by default.

        """

    @property
    def bind(self) -> Optional[Connection]:
        """The connection passed to the constructor (may be ``None``)."""
        return self.connection

    def _exec(
        self,
        construct: Union[ClauseElement, str],
        execution_options: Optional[dict[str, Any]] = None,
        multiparams: Sequence[dict] = (),
        params: Dict[str, Any] = util.immutabledict(),
    ) -> Optional[CursorResult]:
        """Execute ``construct``, or render it as text in ``as_sql`` mode.

        :param construct: a SQL/DDL construct, or a plain string which is
         wrapped with ``text()``.
        :param execution_options: applied to the connection before
         executing (not used in ``as_sql`` mode).
        :param multiparams: sequence of parameter dictionaries (list or
         tuple) passed to ``Connection.execute``.
        :param params: a single parameter dictionary, appended after
         ``multiparams`` when non-empty.
        :return: the cursor result when executing; ``None`` in ``as_sql``
         mode.
        :raises Exception: if any parameters are given in ``as_sql`` mode.
        """
        if isinstance(construct, str):
            construct = text(construct)

        if self.as_sql:
            if multiparams or params:
                # TODO: coverage
                raise Exception("Execution arguments not allowed with as_sql")

            # Inline bound values for ordinary statements; DDL elements are
            # compiled without the literal_binds option.
            if self.literal_binds and not isinstance(
                construct, schema.DDLElement
            ):
                compile_kw = dict(compile_kwargs={"literal_binds": True})
            else:
                compile_kw = {}

            compiled = construct.compile(
                dialect=self.dialect, **compile_kw  # type: ignore[arg-type]
            )
            self.static_output(
                str(compiled).replace("\t", "    ").strip()
                + self.command_terminator
            )
            return None

        conn = self.connection
        assert conn is not None
        if execution_options:
            conn = conn.execution_options(**execution_options)
        if params:
            # Build a new tuple rather than using ``+=``: callers such as
            # bulk_insert() pass a list, and list += tuple would fail.
            multiparams = (*multiparams, params)

        return conn.execute(  # type: ignore[call-overload]
            construct, multiparams
        )

    def execute(
        self,
        sql: Union[ClauseElement, str],
        execution_options: Optional[dict[str, Any]] = None,
    ) -> None:
        """Execute (or, in ``as_sql`` mode, render) ``sql``, discarding any
        result."""
        self._exec(sql, execution_options)

    def alter_column(
        self,
        table_name: str,
        column_name: str,
        nullable: Optional[bool] = None,
        server_default: Union[_ServerDefault, Literal[False]] = False,
        name: Optional[str] = None,
        type_: Optional[TypeEngine] = None,
        schema: Optional[str] = None,
        autoincrement: Optional[bool] = None,
        comment: Optional[Union[str, Literal[False]]] = False,
        existing_comment: Optional[str] = None,
        existing_type: Optional[TypeEngine] = None,
        existing_server_default: Optional[_ServerDefault] = None,
        existing_nullable: Optional[bool] = None,
        existing_autoincrement: Optional[bool] = None,
        **kw: Any,
    ) -> None:
        """Emit the ALTER statements needed to change one column.

        Each requested change is emitted as a separate statement, in this
        order: nullability, server default, type, comment, and finally the
        rename, so that every earlier statement still refers to the old
        column name.  A change is skipped when its argument is left at the
        sentinel (``None``, or ``False`` for ``server_default`` and
        ``comment``).  The ``existing_*`` values are forwarded to each DDL
        construct.

        ``autoincrement`` / ``existing_autoincrement`` only trigger a
        warning here.  Extra ``**kw`` are accepted for dialect subclasses
        and ignored by this implementation.
        """
        if autoincrement is not None or existing_autoincrement is not None:
            util.warn(
                "autoincrement and existing_autoincrement "
                "only make sense for MySQL",
                stacklevel=3,
            )
        if nullable is not None:
            self._exec(
                base.ColumnNullable(
                    table_name,
                    column_name,
                    nullable,
                    schema=schema,
                    existing_type=existing_type,
                    existing_server_default=existing_server_default,
                    existing_nullable=existing_nullable,
                    existing_comment=existing_comment,
                )
            )
        if server_default is not False:
            # Extra arguments for the chosen default construct.  This is a
            # separate name so the caller's ``**kw`` is not clobbered.
            default_kw: Dict[str, Any] = {}
            cls_: Type[
                Union[
                    base.ComputedColumnDefault,
                    base.IdentityColumnDefault,
                    base.ColumnDefault,
                ]
            ]
            if sqla_compat._server_default_is_computed(
                server_default, existing_server_default
            ):
                cls_ = base.ComputedColumnDefault
            elif sqla_compat._server_default_is_identity(
                server_default, existing_server_default
            ):
                cls_ = base.IdentityColumnDefault
                default_kw["impl"] = self
            else:
                cls_ = base.ColumnDefault
            self._exec(
                cls_(
                    table_name,
                    column_name,
                    server_default,  # type:ignore[arg-type]
                    schema=schema,
                    existing_type=existing_type,
                    existing_server_default=existing_server_default,
                    existing_nullable=existing_nullable,
                    existing_comment=existing_comment,
                    **default_kw,
                )
            )
        if type_ is not None:
            self._exec(
                base.ColumnType(
                    table_name,
                    column_name,
                    type_,
                    schema=schema,
                    existing_type=existing_type,
                    existing_server_default=existing_server_default,
                    existing_nullable=existing_nullable,
                    existing_comment=existing_comment,
                )
            )

        if comment is not False:
            self._exec(
                base.ColumnComment(
                    table_name,
                    column_name,
                    comment,
                    schema=schema,
                    existing_type=existing_type,
                    existing_server_default=existing_server_default,
                    existing_nullable=existing_nullable,
                    existing_comment=existing_comment,
                )
            )

        # do the new name last ;)
        if name is not None:
            self._exec(
                base.ColumnName(
                    table_name,
                    column_name,
                    name,
                    schema=schema,
                    existing_type=existing_type,
                    existing_server_default=existing_server_default,
                    existing_nullable=existing_nullable,
                )
            )

    def add_column(
        self,
        table_name: str,
        column: Column[Any],
        schema: Optional[Union[str, quoted_name]] = None,
    ) -> None:
        """Emit an ADD COLUMN for ``column`` on ``table_name``."""
        self._exec(base.AddColumn(table_name, column, schema=schema))

    def drop_column(
        self,
        table_name: str,
        column: Column[Any],
        schema: Optional[str] = None,
        **kw,
    ) -> None:
        """Emit a DROP COLUMN; extra ``**kw`` are ignored here."""
        self._exec(base.DropColumn(table_name, column, schema=schema))

    def add_constraint(self, const: Any) -> None:
        """Emit ADD CONSTRAINT, unless the constraint's ``_create_rule``
        callable is set and returns false for this impl."""
        if const._create_rule is None or const._create_rule(self):
            self._exec(schema.AddConstraint(const))

    def drop_constraint(self, const: Constraint) -> None:
        """Emit DROP CONSTRAINT for ``const``."""
        self._exec(schema.DropConstraint(const))

    def rename_table(
        self,
        old_table_name: str,
        new_table_name: Union[str, quoted_name],
        schema: Optional[Union[str, quoted_name]] = None,
    ) -> None:
        """Emit a table rename from ``old_table_name`` to
        ``new_table_name``."""
        self._exec(
            base.RenameTable(old_table_name, new_table_name, schema=schema)
        )

    def create_table(self, table: Table) -> None:
        """Emit CREATE TABLE for ``table``, then its indexes and comments.

        ``before_create`` / ``after_create`` events are dispatched around
        the CREATE TABLE with this impl as the DDL runner.  CREATE INDEX is
        emitted for each entry of ``table.indexes``.  When the dialect
        supports comments but not inline comments, table and column
        comments are emitted as separate statements afterwards.
        """
        table.dispatch.before_create(
            table, self.connection, checkfirst=False, _ddl_runner=self
        )
        self._exec(schema.CreateTable(table))
        table.dispatch.after_create(
            table, self.connection, checkfirst=False, _ddl_runner=self
        )
        for index in table.indexes:
            self._exec(schema.CreateIndex(index))

        # With inline comments the dialect presumably renders them inside
        # CREATE TABLE itself, so separate statements are only needed
        # otherwise.
        with_comment = (
            self.dialect.supports_comments and not self.dialect.inline_comments
        )
        comment = table.comment
        if comment and with_comment:
            self.create_table_comment(table)

        for column in table.columns:
            comment = column.comment
            if comment and with_comment:
                self.create_column_comment(column)

    def drop_table(self, table: Table) -> None:
        """Emit DROP TABLE, dispatching ``before_drop`` / ``after_drop``
        events around it with this impl as the DDL runner."""
        table.dispatch.before_drop(
            table, self.connection, checkfirst=False, _ddl_runner=self
        )
        self._exec(schema.DropTable(table))
        table.dispatch.after_drop(
            table, self.connection, checkfirst=False, _ddl_runner=self
        )

    def create_index(self, index: Index) -> None:
        """Emit CREATE INDEX for ``index``."""
        self._exec(schema.CreateIndex(index))

    def create_table_comment(self, table: Table) -> None:
        """Emit the statement setting ``table``'s comment."""
        self._exec(schema.SetTableComment(table))

    def drop_table_comment(self, table: Table) -> None:
        """Emit the statement removing ``table``'s comment."""
        self._exec(schema.DropTableComment(table))

    def create_column_comment(self, column: ColumnElement[Any]) -> None:
        """Emit the statement setting ``column``'s comment."""
        self._exec(schema.SetColumnComment(column))

    def drop_index(self, index: Index) -> None:
        """Emit DROP INDEX for ``index``."""
        self._exec(schema.DropIndex(index))

    def bulk_insert(
        self,
        table: Union[TableClause, Table],
        rows: List[dict],
        multiinsert: bool = True,
    ) -> None:
        """Insert ``rows`` (a list of dicts keyed by column name) into
        ``table``.

        In ``as_sql`` mode one INSERT is rendered per row, with each value
        wrapped as a literal bind parameter typed from the matching table
        column (values that are already literal bind parameters pass
        through).  Otherwise, with ``multiinsert`` true, all rows are sent
        in one execute call as a list of parameter sets; with it false, one
        INSERT is executed per row.

        :raises TypeError: if ``rows`` is not a list, or its first element
         is not a dict.
        """
        if not isinstance(rows, list):
            raise TypeError("List expected")
        elif rows and not isinstance(rows[0], dict):
            raise TypeError("List of dictionaries expected")
        if self.as_sql:
            for row in rows:
                self._exec(
                    sqla_compat._insert_inline(table).values(
                        **{
                            k: sqla_compat._literal_bindparam(
                                k, v, type_=table.c[k].type
                            )
                            if not isinstance(
                                v, sqla_compat._literal_bindparam
                            )
                            else v
                            for k, v in row.items()
                        }
                    )
                )
        else:
            if rows:
                if multiinsert:
                    self._exec(
                        sqla_compat._insert_inline(table), multiparams=rows
                    )
                else:
                    for row in rows:
                        self._exec(
                            sqla_compat._insert_inline(table).values(**row)
                        )

    def _tokenize_column_type(self, column: Column) -> Params:
        """Split the dialect-rendered type of ``column`` into :class:`Params`.

        The type is compiled with this impl's dialect and lowercased.  Word
        tokens become ``token0`` (the first) and ``tokens`` (the rest).  The
        last parenthesized group is split on commas: ``key=value`` items go
        into ``kwargs`` and everything else into ``args``.
        """
        definition = self.dialect.type_compiler.process(column.type).lower()

        # tokenize the SQLAlchemy-generated version of a type, so that
        # the two can be compared.
        #
        # examples:
        # NUMERIC(10, 5)
        # TIMESTAMP WITH TIMEZONE
        # INTEGER UNSIGNED
        # INTEGER (10) UNSIGNED
        # INTEGER(10) UNSIGNED
        # varchar character set utf8
        #

        tokens = re.findall(r"[\w\-_]+|\(.+?\)", definition)

        term_tokens = []
        paren_term = None

        for token in tokens:
            if re.match(r"^\(.*\)$", token):
                paren_term = token
            else:
                term_tokens.append(token)

        params = Params(term_tokens[0], term_tokens[1:], [], {})

        if paren_term:
            for term in re.findall("[^(),]+", paren_term):
                if "=" in term:
                    # Split on the first "=" only, so a value that itself
                    # contains "=" doesn't raise ValueError on unpacking.
                    key, val = term.split("=", 1)
                    params.kwargs[key.strip()] = val.strip()
                else:
                    params.args.append(term.strip())

        return params

    def _column_types_match(
        self, inspector_params: Params, metadata_params: Params
    ) -> bool:
        """Return True if the two tokenized types name the same type.

        They match when the leading words are equal, or when both the
        complete word sequences, or both leading words, fall within one
        ``type_synonyms`` group (compared case-insensitively).
        """
        if inspector_params.token0 == metadata_params.token0:
            return True

        synonyms = [{t.lower() for t in batch} for batch in self.type_synonyms]
        inspector_all_terms = " ".join(
            [inspector_params.token0] + inspector_params.tokens
        )
        metadata_all_terms = " ".join(
            [metadata_params.token0] + metadata_params.tokens
        )

        for batch in synonyms:
            if {inspector_all_terms, metadata_all_terms}.issubset(batch) or {
                inspector_params.token0,
                metadata_params.token0,
            }.issubset(batch):
                return True
        return False

    def _column_args_match(
        self, inspected_params: Params, meta_params: Params
    ) -> bool:
        """We want to compare column parameters. However, we only want
        to compare parameters that are set. If they both have `collation`,
        we want to make sure they are the same. However, if only one
        specifies it, don't flag it for being less specific.

        Word tokens and positional args are only compared when both sides
        have the same number of them.  Each ``type_arg_extract`` pattern
        is applied to both sides' words and fails the match when both
        match with different captured values.
        """

        if (
            len(meta_params.tokens) == len(inspected_params.tokens)
            and meta_params.tokens != inspected_params.tokens
        ):
            return False

        if (
            len(meta_params.args) == len(inspected_params.args)
            and meta_params.args != inspected_params.args
        ):
            return False

        insp = " ".join(inspected_params.tokens).lower()
        meta = " ".join(meta_params.tokens).lower()

        for reg in self.type_arg_extract:
            mi = re.search(reg, insp)
            mm = re.search(reg, meta)

            if mi and mm and mi.group(1) != mm.group(1):
                return False

        return True

    def compare_type(
        self, inspector_column: Column[Any], metadata_column: Column
    ) -> bool:
        """Returns True if there ARE differences between the types of the two
        columns. Takes impl.type_synonyms into account between reflected
        and metadata types
        """
        inspector_params = self._tokenize_column_type(inspector_column)
        metadata_params = self._tokenize_column_type(metadata_column)

        if not self._column_types_match(inspector_params, metadata_params):
            return True
        if not self._column_args_match(inspector_params, metadata_params):
            return True
        return False

    def compare_server_default(
        self,
        inspector_column,
        metadata_column,
        rendered_metadata_default,
        rendered_inspector_default,
    ):
        """Return True if the rendered defaults differ (plain ``!=``).

        The column arguments are not used by this implementation.
        """
        return rendered_inspector_default != rendered_metadata_default

    def correct_for_autogen_constraints(
        self,
        conn_uniques: Set[UniqueConstraint],
        conn_indexes: Set[Index],
        metadata_unique_constraints: Set[UniqueConstraint],
        metadata_indexes: Set[Index],
    ) -> None:
        """Dialect hook receiving the connection-side and metadata unique
        constraints and indexes for autogenerate; a no-op here."""
        pass

    def cast_for_batch_migrate(self, existing, existing_transfer, new_type):
        """Wrap ``existing_transfer["expr"]`` in a CAST to ``new_type`` when
        the existing column's type affinity differs from ``new_type``'s."""
        if existing.type._type_affinity is not new_type._type_affinity:
            existing_transfer["expr"] = cast(
                existing_transfer["expr"], new_type
            )

    def render_ddl_sql_expr(
        self, expr: ClauseElement, is_server_default: bool = False, **kw: Any
    ) -> str:
        """Render a SQL expression that is typically a server default,
        index expression, etc.

        Bound values are rendered inline and table names are omitted.
        ``is_server_default`` and ``**kw`` are unused in this
        implementation.

        .. versionadded:: 1.0.11

        """

        compile_kw = {
            "compile_kwargs": {"literal_binds": True, "include_table": False}
        }
        return str(
            expr.compile(
                dialect=self.dialect, **compile_kw  # type: ignore[arg-type]
            )
        )

    def _compat_autogen_column_reflect(self, inspector: Inspector) -> Callable:
        """Return :meth:`autogen_column_reflect`; ``inspector`` is unused."""
        return self.autogen_column_reflect

    def correct_for_autogen_foreignkeys(
        self,
        conn_fks: Set[ForeignKeyConstraint],
        metadata_fks: Set[ForeignKeyConstraint],
    ) -> None:
        """Dialect hook receiving the connection-side and metadata foreign
        keys for autogenerate; a no-op here."""
        pass

    def autogen_column_reflect(self, inspector, table, column_info):
        """A hook that is attached to the 'column_reflect' event for when
        a Table is reflected from the database during the autogenerate
        process.

        Dialects can elect to modify the information gathered here.

        """

    def start_migrations(self) -> None:
        """A hook called when :meth:`.EnvironmentContext.run_migrations`
        is called.

        Implementations can set up per-migration-run state here.

        """

    def emit_begin(self) -> None:
        """Emit the string ``BEGIN``, or the backend-specific
        equivalent, on the current connection context.

        This is used in offline mode and typically
        via :meth:`.EnvironmentContext.begin_transaction`.

        """
        self.static_output("BEGIN" + self.command_terminator)

    def emit_commit(self) -> None:
        """Emit the string ``COMMIT``, or the backend-specific
        equivalent, on the current connection context.

        This is used in offline mode and typically
        via :meth:`.EnvironmentContext.begin_transaction`.

        """
        self.static_output("COMMIT" + self.command_terminator)

    def render_type(
        self, type_obj: TypeEngine, autogen_context: AutogenContext
    ) -> Union[str, Literal[False]]:
        """Always returns ``False`` here.  Dialect subclasses may override
        this to return a rendered type string."""
        return False

    def _compare_identity_default(self, metadata_identity, inspector_identity):
        """Compare a metadata identity against a reflected one.

        :return: a 3-tuple ``(diff, ignored, is_different)``:

         - ``diff``: the set of differing identity attribute names, with
           ``identity_attrs_ignore`` removed;
         - ``ignored``: the set of attributes skipped because the metadata
           value equals the ``Identity()`` default;
         - ``is_different``: True if ``diff`` is non-empty or exactly one
           side is truthy.
        """
        # ignored contains the attributes that were not considered
        # because they are assumed to be at their default values in the db.
        diff, ignored = _compare_identity_options(
            sqla_compat._identity_attrs,
            metadata_identity,
            inspector_identity,
            sqla_compat.Identity(),
        )

        meta_always = getattr(metadata_identity, "always", None)
        inspector_always = getattr(inspector_identity, "always", None)
        # None and False are the same in this comparison
        if bool(meta_always) != bool(inspector_always):
            diff.add("always")

        diff.difference_update(self.identity_attrs_ignore)

        # returns 3 values:
        return (
            # different identity attributes
            diff,
            # ignored identity attributes
            ignored,
            # if the two identity should be considered different
            bool(diff) or bool(metadata_identity) != bool(inspector_identity),
        )

    def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
        """Return the ordered tuple of ``index``'s column names."""
        # order of col matters in an index
        return tuple(col.name for col in index.columns)

    def _skip_functional_indexes(self, metadata_indexes, conn_indexes):
        """Remove expression-based indexes from ``metadata_indexes`` in place.

        An index is removed, with a warning, when it is an expression index
        and no index of the same name exists in ``conn_indexes``.
        """
        conn_indexes_by_name = {c.name: c for c in conn_indexes}

        # Iterate over a copy because entries are discarded along the way.
        for idx in list(metadata_indexes):
            if idx.name in conn_indexes_by_name:
                continue
            iex = sqla_compat.is_expression_index(idx)
            if iex:
                util.warn(
                    "autogenerate skipping metadata-specified "
                    "expression-based index "
                    f"{idx.name!r}; dialect {self.__dialect__!r} under "
                    f"SQLAlchemy {sqla_compat.sqlalchemy_version} can't "
                    "reflect these indexes so they can't be compared"
                )
                metadata_indexes.discard(idx)


def _compare_identity_options(
    attributes, metadata_io, inspector_io, default_io
):
    # this can be used for identity or sequence compare.
    # default_io is an instance of IdentityOption with all attributes to the
    # default value.
    diff = set()
    ignored_attr = set()
    for attr in attributes:
        meta_value = getattr(metadata_io, attr, None)
        default_value = getattr(default_io, attr, None)
        conn_value = getattr(inspector_io, attr, None)
        if conn_value != meta_value:
            if meta_value == default_value:
                ignored_attr.add(attr)
            else:
                diff.add(attr)
    return diff, ignored_attr

?>