Your IP : 3.145.92.213


Current Path : /opt/hc_python/lib/python3.8/site-packages/sqlalchemy/testing/plugin/
Upload File :
Current File : //opt/hc_python/lib/python3.8/site-packages/sqlalchemy/testing/plugin/pytestplugin.py

# testing/plugin/pytestplugin.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

from __future__ import annotations

import argparse
import collections
from functools import update_wrapper
import inspect
import itertools
import operator
import os
import re
import sys
from typing import TYPE_CHECKING
import uuid

import pytest

try:
    # installed by bootstrap.py
    if not TYPE_CHECKING:
        import sqla_plugin_base as plugin_base
except ImportError:
    # assume we're a package, use traditional import
    from . import plugin_base


def pytest_addoption(parser):
    """Pytest hook: register the plugin's options in a "sqlalchemy" group.

    ``plugin_base.setup_options`` is handed a ``make_option`` callable.
    That callable converts the ``callback`` / ``zeroarg_callback`` keywords
    into ``argparse.Action`` subclasses before adding the option to the
    group.
    """
    sqla_group = parser.getgroup("sqlalchemy")

    def make_option(name, **kw):
        value_callback = kw.pop("callback", None)
        flag_callback = kw.pop("zeroarg_callback", None)

        if value_callback:

            class ValueCallbackAction(argparse.Action):
                """Forward the parsed value(s) to the ``callback`` hook."""

                def __call__(
                    self, parser, namespace, values, option_string=None
                ):
                    value_callback(option_string, values, parser)

            kw["action"] = ValueCallbackAction

        # assigned second, so this one wins if both keywords were passed
        if flag_callback:

            class FlagCallbackAction(argparse.Action):
                """Option taking no argument; calls ``zeroarg_callback``."""

                def __init__(
                    self,
                    option_strings,
                    dest,
                    default=False,
                    required=False,
                    help=None,  # noqa
                ):
                    super().__init__(
                        option_strings=option_strings,
                        dest=dest,
                        nargs=0,
                        const=True,
                        default=default,
                        required=required,
                        help=help,
                    )

                def __call__(
                    self, parser, namespace, values, option_string=None
                ):
                    flag_callback(option_string, values, parser)

            kw["action"] = FlagCallbackAction

        sqla_group.addoption(name, **kw)

    plugin_base.setup_options(make_option)


def pytest_configure(config: pytest.Config):
    """Pytest hook: configure the SQLAlchemy plugin from the pytest config.

    Reads the plugin configuration, folds include/exclude tags into the
    ``-m`` marker expression, registers the xdist hooks when xdist is
    present, performs worker- or controller-specific setup, and finally
    hands the fixture function implementation to ``plugin_base``.
    """
    global DUMP_PYANNOTATE

    plugin_base.read_config(config.rootpath)

    if plugin_base.exclude_tags or plugin_base.include_tags:
        tag_terms = list(plugin_base.include_tags)
        tag_terms.extend(f"not {tag}" for tag in plugin_base.exclude_tags)
        tag_expr = " and ".join(tag_terms)

        user_expr = config.option.markexpr
        if user_expr:
            config.option.markexpr = user_expr + f" and {tag_expr}"
        else:
            config.option.markexpr = tag_expr

    plugin_manager = config.pluginmanager
    if plugin_manager.hasplugin("xdist"):
        plugin_manager.register(XDistHooks())

    if not hasattr(config, "workerinput"):
        # not an xdist worker: clear out a stale idents file, if any
        idents_path = config.option.write_idents
        if idents_path and os.path.exists(idents_path):
            os.remove(idents_path)
    else:
        worker_input = config.workerinput
        plugin_base.restore_important_follower_config(worker_input)
        plugin_base.configure_follower(worker_input["follower_ident"])

    plugin_base.pre_begin(config.option)

    coverage_enabled = bool(getattr(config.option, "cov_source", False))
    plugin_base.set_coverage_flag(coverage_enabled)

    plugin_base.set_fixture_functions(PytestFixtureFunctions)

    if config.option.dump_pyannotate:
        DUMP_PYANNOTATE = True


# switched to True by pytest_configure() when the ``dump_pyannotate``
# option is set
DUMP_PYANNOTATE = False


@pytest.fixture(autouse=True)
def collect_types_fixture():
    """Autouse fixture bracketing each test with pyannotate collection.

    When ``DUMP_PYANNOTATE`` is set, ``collect_types.start()`` runs before
    the test and ``collect_types.stop()`` after it.  Otherwise the fixture
    does nothing.
    """
    collector = None
    if DUMP_PYANNOTATE:
        from pyannotate_runtime import collect_types as collector

        collector.start()
    yield
    # Pair stop() with the start() performed above rather than re-reading
    # the module flag: the imported name only exists if setup ran, so
    # consulting the flag again could hit an unbound local or skip stop().
    if collector is not None:
        collector.stop()


def _log_sqlalchemy_info(session):
    """Print a short banner describing the SQLAlchemy install under test.

    The banner reports the version, whether the user site directory is
    loaded, the package path and whether the compiled extensions are
    enabled.  Output goes through pytest's terminal reporter plugin when it
    is available, otherwise through ``print()``.
    """
    import sqlalchemy
    from sqlalchemy import __version__
    from sqlalchemy.util import has_compiled_ext
    from sqlalchemy.util._has_cy import _CYEXTENSION_MSG

    greet = "sqlalchemy installation"
    if sys.flags.no_user_site:
        site = "no user site"
    else:
        site = "user site loaded"

    lines = [
        f"SQLAlchemy {__version__} ({site})",
        f"Path: {sqlalchemy.__file__}",
    ]

    if not has_compiled_ext():
        lines.append(f"compiled extension not enabled; {_CYEXTENSION_MSG}")
    else:
        from sqlalchemy.cyextension import util

        lines.append(f"compiled extension enabled, e.g. {util.__file__} ")

    reporter = session.config.pluginmanager.get_plugin("terminalreporter")
    if reporter:
        reporter.write_sep("=", greet)
        emit = reporter.write_line
    else:
        # no terminal reporter plugin; fall back to plain print
        print("=" * 25, greet, "=" * 25)
        emit = print

    for line in lines:
        emit(line)


def pytest_sessionstart(session):
    """Pytest hook: print the install banner, then run ``post_begin``."""
    from sqlalchemy.testing import asyncio as sa_asyncio

    _log_sqlalchemy_info(session)
    sa_asyncio._assume_async(plugin_base.post_begin)


def pytest_sessionfinish(session):
    """Pytest hook: run final cleanup and dump pyannotate stats if asked.

    ``plugin_base.final_process_cleanup`` always runs first.  The
    pyannotate statistics are written only when the ``dump_pyannotate``
    option is set; its value is passed to ``collect_types.dump_stats``.
    """
    from sqlalchemy.testing import asyncio as sa_asyncio

    sa_asyncio._maybe_async_provisioning(plugin_base.final_process_cleanup)

    dump_target = session.config.option.dump_pyannotate
    if not dump_target:
        return

    from pyannotate_runtime import collect_types

    collect_types.dump_stats(dump_target)


def pytest_unconfigure(config):
    """Pytest hook: shut down the testing asyncio support on exit."""
    from sqlalchemy.testing import asyncio as sa_asyncio

    sa_asyncio._shutdown()


def pytest_collection_finish(session):
    """Pytest hook: initialize pyannotate type collection when requested.

    Does nothing unless the ``dump_pyannotate`` option is set.  When it
    is, type collection is initialized with a filename filter that rejects
    files whose common path with ``lib/sqlalchemy`` does not contain
    "lib/sqlalchemy", as well as any path containing "testing".
    """
    if not session.config.option.dump_pyannotate:
        return

    from pyannotate_runtime import collect_types

    lib_sqlalchemy = os.path.abspath("lib/sqlalchemy")

    def _filter(filename):
        normalized = os.path.normpath(os.path.abspath(filename))
        shared_prefix = os.path.commonpath([normalized, lib_sqlalchemy])
        if "lib/sqlalchemy" not in shared_prefix or "testing" in normalized:
            return None
        return normalized

    collect_types.init_types_collection(filter_filename=_filter)


class XDistHooks:
    """pytest-xdist hooks creating and dropping each worker's follower DB."""

    def pytest_configure_node(self, node):
        """Assign the node a follower ident and create its follower DB."""
        from sqlalchemy.testing import provision
        from sqlalchemy.testing import asyncio as sa_asyncio

        # the controller fills in the workerinput dictionary for each node;
        # pytest-xdist then transfers it to the worker subprocess
        worker_input = node.workerinput
        plugin_base.memoize_important_follower_config(worker_input)

        follower_ident = f"test_{uuid.uuid4().hex[:12]}"
        worker_input["follower_ident"] = follower_ident

        sa_asyncio._maybe_async_provisioning(
            provision.create_follower_db, follower_ident
        )

    def pytest_testnodedown(self, node, error):
        """Drop the follower DB belonging to the node that went down."""
        from sqlalchemy.testing import provision
        from sqlalchemy.testing import asyncio as sa_asyncio

        follower_ident = node.workerinput["follower_ident"]
        sa_asyncio._maybe_async_provisioning(
            provision.drop_follower_db, follower_ident
        )


def pytest_collection_modifyitems(session, config, items):
    """Pytest hook: expand collected classes into their generated variants.

    Items outside of a class, or inside a class whose name starts with an
    underscore, are dropped.  For the remaining classes the legacy
    ``__backend__`` / ``__only_on__`` / ``__sparse_backend__`` attributes
    are turned into markers, ``plugin_base.generate_sub_tests`` is asked
    for sub-classes, and each original item is replaced by the same-named
    items collected from those sub-classes.  ``items`` is rewritten in
    place, sorted by module name, class name and item name.
    """
    # This would be much easier to do within pytest_pycollect_makeitem.
    # However, pytest iterates through cls.__dict__ while makeitem is being
    # called, which causes a "dictionary changed size" error since new
    # classes are being added to the module on the fly.

    from sqlalchemy.testing import asyncio as sa_asyncio

    # original class -> test name -> items collected from generated classes
    rebuilt_items = collections.defaultdict(
        lambda: collections.defaultdict(list)
    )

    kept_items = []
    for item in items:
        owner = item.getparent(pytest.Class)
        if owner is None or owner.name.startswith("_"):
            continue
        kept_items.append(item)
    items[:] = kept_items

    test_classes = {item.getparent(pytest.Class) for item in items}

    def collect(element):
        for child in element.collect():
            if not isinstance(child, pytest.Collector):
                yield child
            else:
                yield from collect(child)

    def legacy_markers(cls):
        # transfer legacy __backend__ and __sparse_backend__ symbols
        # to be markers
        if getattr(cls, "__backend__", False) or getattr(
            cls, "__only_on__", False
        ):
            return {"backend"}
        if getattr(cls, "__sparse_backend__", False):
            return {"sparse_backend"}
        return set()

    def setup_test_classes():
        for test_class in test_classes:
            target_cls = test_class.cls
            wanted_markers = legacy_markers(target_cls)

            existing_markers = {
                mark.name for mark in test_class.iter_markers()
            }
            add_markers = wanted_markers - existing_markers
            all_markers = existing_markers | add_markers

            for marker in add_markers:
                test_class.add_marker(marker)

            sub_classes = plugin_base.generate_sub_tests(
                target_cls, test_class.module, all_markers
            )
            for sub_cls in sub_classes:
                if sub_cls is target_cls:
                    continue

                per_cls_dict = rebuilt_items[target_cls]
                module = test_class.getparent(pytest.Module)

                new_cls = pytest.Class.from_parent(
                    name=sub_cls.__name__, parent=module
                )
                for marker in add_markers:
                    new_cls.add_marker(marker)

                for fn in collect(new_cls):
                    per_cls_dict[fn.name].append(fn)

    # class requirements will sometimes need to access the DB to check
    # capabilities, so need to do this for async
    sa_asyncio._maybe_async_provisioning(setup_test_classes)

    expanded = []
    for item in items:
        item_cls = item.cls
        if item_cls in rebuilt_items:
            expanded.extend(rebuilt_items[item_cls][item.name])
        else:
            expanded.append(item)

    def sort_key(item):
        return (
            item.getparent(pytest.Module).name,
            item.getparent(pytest.Class).name,
            item.name,
        )

    # the functions attached to a test class do not appear to arrive
    # sorted (with unittest they are), so sort explicitly
    expanded.sort(key=sort_key)
    items[:] = expanded


def pytest_pycollect_makeitem(collector, name, obj):
    """Pytest hook: choose which classes and functions become test items.

    Returns a list of ``pytest.Class`` nodes (one per class-level
    parametrization) for wanted classes, ``None`` for wanted methods so
    that pytest's default logic applies, and an empty list for anything
    else so that it is skipped.
    """
    if inspect.isclass(obj) and plugin_base.want_class(name, obj):
        from sqlalchemy.testing import config

        if config.any_async:
            obj = _apply_maybe_async(obj)

        class_nodes = []
        for parametrize_cls in _parametrize_cls(collector.module, obj):
            class_nodes.append(
                pytest.Class.from_parent(
                    name=parametrize_cls.__name__, parent=collector
                )
            )
        return class_nodes

    is_wanted_method = (
        inspect.isfunction(obj)
        and collector.cls is not None
        and plugin_base.want_method(collector.cls, obj)
    )
    if is_wanted_method:
        # None means, fall back to default logic, which includes
        # method-level parametrize
        return None

    # empty list means skip this item
    return []


def _is_wrapped_coroutine_function(fn):
    """Return True if ``fn``, once fully unwrapped, is a coroutine function.

    The ``__wrapped__`` chain is followed until an object without that
    attribute is reached; that innermost object is what gets inspected.
    """
    innermost = fn
    while True:
        try:
            innermost = innermost.__wrapped__
        except AttributeError:
            break

    return inspect.iscoroutinefunction(innermost)


def _apply_maybe_async(obj, recurse=True):
    """Wrap the ``test_*`` callables of class ``obj`` with ``_maybe_async``.

    Attributes defined directly on ``obj`` are replaced in place when they
    are callable (or a ``classmethod``), are named ``test_*``, have not
    been wrapped already and are not coroutine functions.  With ``recurse``
    the same is applied to every class in the MRO other than ``object``.

    Returns ``obj``.
    """
    from sqlalchemy.testing import asyncio as sa_asyncio

    def make_async(fn, *args, **kwargs):
        return sa_asyncio._maybe_async(fn, *args, **kwargs)

    for name, value in vars(obj).items():
        if not callable(value) and not isinstance(value, classmethod):
            continue
        if getattr(value, "_maybe_async_applied", False):
            continue
        if not name.startswith("test_"):
            continue
        if _is_wrapped_coroutine_function(value):
            continue

        is_classmethod = isinstance(value, classmethod)
        target = value.__func__ if is_classmethod else value

        do_async = _pytest_fn_decorator(make_async)(target)
        if is_classmethod:
            do_async = classmethod(do_async)
        # flag the replacement so that it is not wrapped a second time
        do_async._maybe_async_applied = True

        setattr(obj, name, do_async)

    if recurse:
        for parent_cls in obj.mro()[1:]:
            if parent_cls != object:
                _apply_maybe_async(parent_cls, False)
    return obj


def _parametrize_cls(module, cls):
    """Expand ``cls`` into one subclass per class-level parameter combination.

    A class without its own ``_sa_parametrize`` attribute is returned
    unchanged as ``[cls]``.  Otherwise one subclass is built for every
    element of the cartesian product of the parameter sets; the parameter
    values become class attributes and the subclass is named after the
    original class plus the parameter ids.  Each subclass is also set as an
    attribute of ``module``.

    Raises ``TypeError`` when a parameter set carries no argument names.
    """
    if "_sa_parametrize" not in cls.__dict__:
        return [cls]

    param_specs = cls._sa_parametrize
    argname_groups = [names for names, _ in param_specs]
    value_groups = [values for _, values in param_specs]

    generated = []
    for combo in itertools.product(*value_groups):
        namespace = {}
        for argnames, param in zip(argname_groups, combo):
            if not argnames:
                raise TypeError("need argnames for class-based combinations")
            namespace.update(zip(re.split(r",\s*", argnames), param.values))

        # ids are split on "-" and stripped of non-word characters so the
        # pieces can be joined into the generated class name
        id_tokens = []
        for param in combo:
            id_tokens.extend(
                re.sub(r"\W", "", piece) for piece in param.id.split("-")
            )
        suffix = "_".join(id_tokens)

        subclass_name = f"{cls.__name__}_{suffix}"
        subclass = type.__new__(type, subclass_name, (cls,), namespace)
        setattr(module, subclass_name, subclass)
        generated.append(subclass)
    return generated


# the pytest.Class node whose tests are currently running, or None when
# no class has been started
_current_class = None


def pytest_runtest_setup(item):
    """Pytest hook: start the test class when its first test is set up.

    pytest_runtest_setup runs *before* pytest fixtures with scope="class".
    plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest
    for the whole class and has to run things that are across all current
    databases, so this runs outside of the pytest fixture system altogether,
    using an asyncio greenlet if any engines are async.
    """
    from sqlalchemy.testing import asyncio as sa_asyncio

    global _current_class

    if not isinstance(item, pytest.Function) or _current_class is not None:
        return

    sa_asyncio._maybe_async_provisioning(
        plugin_base.start_test_class_outside_fixtures,
        item.cls,
    )
    _current_class = item.getparent(pytest.Class)


@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_teardown(item, nextitem):
    """Hook wrapper around teardown; finalizes a class after its last test.

    Before the ``yield``, ``plugin_base.after_test`` is run for ``item``.
    After the ``yield``, if ``nextitem`` is ``None`` or belongs to another
    class, ``plugin_base.stop_test_class_outside_fixtures`` is run for
    ``item.cls``.

    If that class-level teardown raises and a failed "call" report was
    recorded, the report's text is appended to the exception's arguments
    so that the original test failure is not lost.  The exception is then
    re-raised.
    """
    # runs inside of pytest function fixture scope
    # after test function runs

    from sqlalchemy.testing import asyncio

    global _current_class, _current_report

    asyncio._maybe_async(plugin_base.after_test, item)

    yield
    # this is now after all the fixture teardown have run, the class can be
    # finalized. Since pytest v7 this finalizer can no longer be added in
    # pytest_runtest_setup since the class has not yet been setup at that
    # time.
    # See https://github.com/pytest-dev/pytest/issues/9343

    if _current_class is not None and (
        # last test or a new class
        nextitem is None
        or nextitem.getparent(pytest.Class) is not _current_class
    ):
        _current_class = None

        try:
            asyncio._maybe_async_provisioning(
                plugin_base.stop_test_class_outside_fixtures, item.cls
            )
        except Exception as e:
            # in case of an exception during teardown attach the original
            # error to the exception message, otherwise it will get lost.
            # The report is only recorded for the "call" phase, so it can
            # still be None here (e.g. the test never reached that phase);
            # dereferencing it unguarded would replace the real teardown
            # error with an AttributeError.
            report = _current_report
            if report is not None and report.failed:
                if not e.args:
                    e.args = (
                        "__Original test failure__:\n" + report.longreprtext,
                    )
                elif e.args[-1] and isinstance(e.args[-1], str):
                    args = list(e.args)
                    args[-1] += (
                        "\n__Original test failure__:\n"
                        + report.longreprtext
                    )
                    e.args = tuple(args)
                else:
                    e.args += (
                        "__Original test failure__",
                        report.longreprtext,
                    )
            raise
        finally:
            _current_report = None


def pytest_runtest_call(item):
    """Pytest hook: run ``plugin_base.before_test`` for ``item``.

    Runs inside of the pytest function fixture scope, before the test
    function itself runs.
    """
    from sqlalchemy.testing import asyncio as sa_asyncio

    module_name = item.module.__name__
    sa_asyncio._maybe_async(
        plugin_base.before_test, item, module_name, item.cls, item.name
    )


# most recent "call"-phase report seen by pytest_runtest_logreport()
_current_report = None


def pytest_runtest_logreport(report):
    """Pytest hook: remember the latest report of the "call" phase.

    Reports for any other phase are ignored.
    """
    global _current_report
    if report.when != "call":
        return
    _current_report = report


@pytest.fixture(scope="class")
def setup_class_methods(request):
    """Class-scoped fixture running the class-level setup/teardown hooks.

    ``setup_test_class`` is called before the class's tests and
    ``teardown_test_class`` after them, each only if the class defines it.
    ``plugin_base.stop_test_class`` is always run last.
    """
    from sqlalchemy.testing import asyncio as sa_asyncio

    test_cls = request.cls

    def run_if_defined(hook_name):
        if hasattr(test_cls, hook_name):
            sa_asyncio._maybe_async(getattr(test_cls, hook_name))

    run_if_defined("setup_test_class")

    yield

    run_if_defined("teardown_test_class")

    sa_asyncio._maybe_async(plugin_base.stop_test_class, test_cls)


@pytest.fixture(scope="function")
def setup_test_methods(request):
    """Function-scoped fixture running the per-test setup/teardown hooks.

    Before the test, ``setup_test`` and then ``setUp`` are called on the
    test instance, each only if present.  After the test,
    ``plugin_base.after_test_fixtures`` is run, followed by ``tearDown``
    and then ``teardown_test``, again only if present.  The numbered
    comments below describe where these steps fall relative to other
    fixtures and hooks.
    """
    from sqlalchemy.testing import asyncio

    # called for each test

    # the test class instance for the current test
    self = request.instance

    # before this fixture runs:

    # 1. function level "autouse" fixtures under py3k (examples: TablesTest
    #    define tables / data, MappedTest define tables / mappers / data)

    # 2. was for p2k. no longer applies

    # 3. run outer xdist-style setup
    if hasattr(self, "setup_test"):
        asyncio._maybe_async(self.setup_test)

    # alembic test suite is using setUp and tearDown
    # xdist methods; support these in the test suite
    # for the near term
    if hasattr(self, "setUp"):
        asyncio._maybe_async(self.setUp)

    # inside the yield:
    # 4. function level fixtures defined on test functions themselves,
    #    e.g. "connection", "metadata" run next

    # 5. pytest hook pytest_runtest_call then runs

    # 6. test itself runs

    yield

    # yield finishes:

    # 7. function level fixtures defined on test functions
    #    themselves, e.g. "connection" rolls back the transaction, "metadata"
    #    emits drop all

    # 8. pytest hook pytest_runtest_teardown hook runs, this is associated
    #    with fixtures close all sessions, provisioning.stop_test_class(),
    #    engines.testing_reaper -> ensure all connection pool connections
    #    are returned, engines created by testing_engine that aren't the
    #    config engine are disposed

    # (there is no step 9 in this list; the call below runs at this point)
    asyncio._maybe_async(plugin_base.after_test_fixtures, self)

    # 10. run xdist-style teardown
    if hasattr(self, "tearDown"):
        asyncio._maybe_async(self.tearDown)

    if hasattr(self, "teardown_test"):
        asyncio._maybe_async(self.teardown_test)

    # 11. was for p2k. no longer applies

    # 12. function level "autouse" fixtures under py3k (examples: TablesTest /
    #    MappedTest delete table data, possibly drop tables and clear mappers
    #    depending on the flags defined by the test class)


def _pytest_fn_decorator(target):
    """Port of langhelpers.decorator with pytest-specific tricks.

    Returns a ``decorate(fn, add_positional_parameters=())`` callable.
    Calling it generates, via ``exec``, a new function carrying the name
    of ``fn`` and a signature rendered from its argspec; the body of that
    function calls ``target(fn, <arguments>)``.

    :param target: callable invoked as ``target(fn, *args, **kw)`` whenever
     the generated function is called.
    """

    from sqlalchemy.util.langhelpers import format_argspec_plus
    from sqlalchemy.util.compat import inspect_getfullargspec

    def _exec_code_in_env(code, env, fn_name):
        # note this is affected by "from __future__ import annotations" at
        # the top; exec'ed code will use non-evaluated annotations
        # which allows us to be more flexible with code rendering
        # in format_argspec_plus()
        exec(code, env)
        return env[fn_name]

    def decorate(fn, add_positional_parameters=()):
        """Build the generated wrapper for ``fn``.

        :param add_positional_parameters: names of additional positional
         parameters to append to the generated signature.  When given,
         only ``__module__``, ``__name__`` and ``pytestmark`` are copied
         from ``fn`` instead of performing a full ``update_wrapper``.
        """
        spec = inspect_getfullargspec(fn)
        if add_positional_parameters:
            # extends the argument list of the spec in place
            spec.args.extend(add_positional_parameters)

        # substitution values for the source template below; the
        # "grouped_args" and "apply_kw" entries used by the template come
        # from format_argspec_plus()
        metadata = dict(
            __target_fn="__target_fn", __orig_fn="__orig_fn", name=fn.__name__
        )
        metadata.update(format_argspec_plus(spec, grouped=False))
        code = (
            """\
def %(name)s%(grouped_args)s:
    return %(__target_fn)s(%(__orig_fn)s, %(apply_kw)s)
"""
            % metadata
        )
        # the exec environment supplies the two names the generated
        # source refers to
        decorated = _exec_code_in_env(
            code, {"__target_fn": target, "__orig_fn": fn}, fn.__name__
        )
        if not add_positional_parameters:
            # getattr(fn, "__func__", fn) reaches the underlying function
            # when fn is a bound method
            decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
            decorated.__wrapped__ = fn
            return update_wrapper(decorated, fn)
        else:
            # this is the pytest hacky part.  don't do a full update wrapper
            # because pytest is really being sneaky about finding the args
            # for the wrapped function
            decorated.__module__ = fn.__module__
            decorated.__name__ = fn.__name__
            if hasattr(fn, "pytestmark"):
                decorated.pytestmark = fn.pytestmark
            return decorated

    return decorate


class PytestFixtureFunctions(plugin_base.FixtureFunctions):
    """pytest-backed implementation of ``plugin_base.FixtureFunctions``.

    The class is handed to ``plugin_base.set_fixture_functions()`` by
    ``pytest_configure()``.
    """

    def skip_test_exception(self, *arg, **kw):
        """Return (not raise) a ``pytest.skip.Exception`` instance."""
        return pytest.skip.Exception(*arg, **kw)

    @property
    def add_to_marker(self):
        """The ``pytest.mark`` object."""
        return pytest.mark

    def mark_base_test_class(self):
        """Return a ``usefixtures`` mark naming this plugin's two fixtures,
        ``setup_class_methods`` and ``setup_test_methods``."""
        return pytest.mark.usefixtures(
            "setup_class_methods", "setup_test_methods"
        )

    # maps a character of the ``id_`` template accepted by combinations()
    # to the function which renders the corresponding argument for the
    # parameter set id
    _combination_id_fns = {
        "i": lambda obj: obj,
        "r": repr,
        "s": str,
        "n": lambda obj: (
            obj.__name__ if hasattr(obj, "__name__") else type(obj).__name__
        ),
    }

    def combinations(self, *arg_sets, **kw):
        """Facade for pytest.mark.parametrize.

        Automatically derives argument names from the callable which in our
        case is always a method on a class with positional arguments.

        ids for parameter sets are derived using an optional template.

        Keyword arguments consumed here:

        * ``argnames`` - comma separated argument names; when omitted for a
          function, the names are taken from its signature minus the first
          argument.
        * ``id_`` - template string with one character per position of an
          argument set.  Positions marked "n", "r", "s" or "a" are passed
          to the test as parameters; positions marked "i", "r", "s" or "n"
          contribute to the parameter set id through
          ``_combination_id_fns``.

        ``exclusions.compound`` objects found within an argument set are
        removed from the parameters and applied to the test when it runs.

        Returns a decorator for a test function or a test class.
        """
        from sqlalchemy.testing import exclusions

        # a single iterator argument is unrolled into the argument sets
        if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
            arg_sets = list(arg_sets[0])

        argnames = kw.pop("argnames", None)

        def _filter_exclusions(args):
            # split args into (plain parameters, exclusions.compound objects)
            result = []
            gathered_exclusions = []
            for a in args:
                if isinstance(a, exclusions.compound):
                    gathered_exclusions.append(a)
                else:
                    result.append(a)

            return result, gathered_exclusions

        id_ = kw.pop("id_", None)

        # tuples of (parameters, exclusions, id) used to build pytest.param
        tobuild_pytest_params = []
        has_exclusions = False
        if id_:
            _combination_id_fns = self._combination_id_fns

            # because itemgetter is not consistent for one argument vs.
            # multiple, make it multiple in all cases and use a slice
            # to omit the first argument
            _arg_getter = operator.itemgetter(
                0,
                *[
                    idx
                    for idx, char in enumerate(id_)
                    if char in ("n", "r", "s", "a")
                ],
            )
            # (getter, formatter) pairs for the positions that take part
            # in the id
            fns = [
                (operator.itemgetter(idx), _combination_id_fns[char])
                for idx, char in enumerate(id_)
                if char in _combination_id_fns
            ]

            for arg in arg_sets:
                if not isinstance(arg, tuple):
                    arg = (arg,)

                fn_params, param_exclusions = _filter_exclusions(arg)

                parameters = _arg_getter(fn_params)[1:]

                if param_exclusions:
                    has_exclusions = True

                # NOTE(review): the id getters index the unfiltered ``arg``
                # whereas _arg_getter indexes the exclusion-filtered
                # ``fn_params``; presumably exclusions are expected to
                # come after the templated positions - confirm
                tobuild_pytest_params.append(
                    (
                        parameters,
                        param_exclusions,
                        "-".join(
                            comb_fn(getter(arg)) for getter, comb_fn in fns
                        ),
                    )
                )

        else:
            for arg in arg_sets:
                if not isinstance(arg, tuple):
                    arg = (arg,)

                fn_params, param_exclusions = _filter_exclusions(arg)

                if param_exclusions:
                    has_exclusions = True

                tobuild_pytest_params.append(
                    (fn_params, param_exclusions, None)
                )

        pytest_params = []
        for parameters, param_exclusions, id_ in tobuild_pytest_params:
            if has_exclusions:
                # once any set has exclusions, every set receives its
                # (possibly empty) exclusion list as a trailing parameter
                parameters += (param_exclusions,)

            param = pytest.param(*parameters, id=id_)
            pytest_params.append(param)

        def decorate(fn):
            if inspect.isclass(fn):
                if has_exclusions:
                    raise NotImplementedError(
                        "exclusions not supported for class level combinations"
                    )
                # give the class its own list rather than appending to
                # one inherited from a base class
                if "_sa_parametrize" not in fn.__dict__:
                    fn._sa_parametrize = []
                fn._sa_parametrize.append((argnames, pytest_params))
                return fn
            else:
                # [1:] omits the first argument of the method
                _fn_argnames = inspect.getfullargspec(fn).args[1:]
                if argnames is None:
                    _argnames = _fn_argnames
                else:
                    _argnames = re.split(r", *", argnames)

                if has_exclusions:
                    # number the extra parameter after the "_exclusions"
                    # parameters already present in the signature
                    existing_exl = sum(
                        1 for n in _fn_argnames if n.startswith("_exclusions")
                    )
                    current_exclusion_name = f"_exclusions_{existing_exl}"
                    _argnames += [current_exclusion_name]

                    @_pytest_fn_decorator
                    def check_exclusions(fn, *args, **kw):
                        # the last positional argument is the exclusion
                        # list added above; it is not passed on to the test
                        _exclusions = args[-1]
                        if _exclusions:
                            exlu = exclusions.compound().add(*_exclusions)
                            fn = exlu(fn)
                        return fn(*args[:-1], **kw)

                    fn = check_exclusions(
                        fn, add_positional_parameters=(current_exclusion_name,)
                    )

                return pytest.mark.parametrize(_argnames, pytest_params)(fn)

        return decorate

    def param_ident(self, *parameters):
        """Return a ``pytest.param`` using the first positional argument as
        its id and the remaining arguments as its values."""
        ident = parameters[0]
        return pytest.param(*parameters[1:], id=ident)

    def fixture(self, *arg, **kw):
        """Wrapper for ``pytest.fixture`` which additionally applies
        ``asyncio._maybe_async_wrapper`` to the fixture function when
        ``config.any_async`` is set.

        May be used either as ``@fixture`` or as ``@fixture(...)``.
        """
        from sqlalchemy.testing import config
        from sqlalchemy.testing import asyncio

        # wrapping pytest.fixture function.  determine if
        # decorator was called as @fixture or @fixture().
        if len(arg) > 0 and callable(arg[0]):
            # was called as @fixture (no parentheses), so we have the
            # function to wrap.
            fn = arg[0]
            arg = arg[1:]
        else:
            # was called as @fixture(...), don't have the function yet.
            fn = None

        # create a pytest.fixture marker.  because the fn is not being
        # passed, this is always a pytest.FixtureFunctionMarker()
        # object (or whatever pytest is calling it when you read this)
        # that is waiting for a function.
        fixture = pytest.fixture(*arg, **kw)

        # now apply wrappers to the function, including fixture itself

        def wrap(fn):
            if config.any_async:
                fn = asyncio._maybe_async_wrapper(fn)
            # other wrappers may be added here

            # now apply FixtureFunctionMarker
            fn = fixture(fn)

            return fn

        if fn:
            return wrap(fn)
        else:
            return wrap

    def get_current_test_name(self):
        """Return the ``PYTEST_CURRENT_TEST`` environment variable, or None
        if it is not set."""
        return os.environ.get("PYTEST_CURRENT_TEST")

    def async_test(self, fn):
        """Wrap ``fn`` so that calling it hands ``fn`` and the call
        arguments to ``asyncio._run_coroutine_function``.

        The wrapper itself returns None.
        """
        from sqlalchemy.testing import asyncio

        @_pytest_fn_decorator
        def decorate(fn, *args, **kwargs):
            asyncio._run_coroutine_function(fn, *args, **kwargs)

        return decorate(fn)

?>