# (scrape residue, not part of the module) Your IP : 18.222.115.150
# ext/declarative/clsregistry.py
# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Routines to handle the string class registry used by declarative.
This system allows specification of classes and expressions used in
:func:`_orm.relationship` using strings.
"""
import weakref
from ... import exc
from ... import inspection
from ... import util
from ...orm import class_mapper
from ...orm import interfaces
from ...orm.properties import ColumnProperty
from ...orm.properties import RelationshipProperty
from ...orm.properties import SynonymProperty
from ...schema import _get_table_key
# Strong references to the marker objects which we place in
# the _decl_class_registry, which is usually weak referencing.
# _MultipleClassMarker and _ModuleMarker add themselves to this set when
# constructed; they link to classes with weakrefs and discard themselves
# from the set again when all references to contained classes are removed.
_registries = set()
def add_class(classname, cls):
    """Register ``cls`` under ``classname`` in its ``_decl_class_registry``.

    The class is recorded in two places:

    * directly under ``classname``.  If that name already maps to a plain
      class, the entry is replaced by a :class:`_MultipleClassMarker`
      holding both classes; an entry that already is a marker is left
      untouched.
    * in a tree of :class:`_ModuleMarker` objects rooted at the
      ``"_sa_module_registry"`` key, once for every trailing portion of
      the dotted module path of ``cls``.

    :param classname: name under which the class is registered.
    :param cls: class providing ``_decl_class_registry`` and
     ``__module__``.
    """
    registry = cls._decl_class_registry

    if classname not in registry:
        registry[classname] = cls
    else:
        # the name is already taken; promote a plain class entry to a
        # marker that refers to both the new and the previous class.
        previous = registry[classname]
        if not isinstance(previous, _MultipleClassMarker):
            registry[classname] = _MultipleClassMarker([cls, previous])

    try:
        root_module = registry["_sa_module_registry"]
    except KeyError:
        root_module = _ModuleMarker("_sa_module_registry", None)
        registry["_sa_module_registry"] = root_module

    # For a module name such as "myapp.snacks.nuts", build:
    #
    #   myapp -> snacks -> nuts -> (classes)
    #            snacks -> nuts -> (classes)
    #                      nuts -> (classes)
    #
    # so that partial token paths can be used for lookup.
    parts = cls.__module__.split(".")
    for start, first in enumerate(parts):
        node = root_module.get_module(first)
        for part in parts[start + 1 :]:
            node = node.get_module(part)
        node.add_class(classname, cls)
class _MultipleClassMarker(object):
"""refers to multiple classes of the same name
within _decl_class_registry.
"""
__slots__ = "on_remove", "contents", "__weakref__"
def __init__(self, classes, on_remove=None):
self.on_remove = on_remove
self.contents = set(
[weakref.ref(item, self._remove_item) for item in classes]
)
_registries.add(self)
def __iter__(self):
return (ref() for ref in self.contents)
def attempt_get(self, path, key):
if len(self.contents) > 1:
raise exc.InvalidRequestError(
'Multiple classes found for path "%s" '
"in the registry of this declarative "
"base. Please use a fully module-qualified path."
% (".".join(path + [key]))
)
else:
ref = list(self.contents)[0]
cls = ref()
if cls is None:
raise NameError(key)
return cls
def _remove_item(self, ref):
self.contents.remove(ref)
if not self.contents:
_registries.discard(self)
if self.on_remove:
self.on_remove()
def add_item(self, item):
# protect against class registration race condition against
# asynchronous garbage collection calling _remove_item,
# [ticket:3208]
modules = set(
[
cls.__module__
for cls in [ref() for ref in self.contents]
if cls is not None
]
)
if item.__module__ in modules:
util.warn(
"This declarative base already contains a class with the "
"same class name and module name as %s.%s, and will "
"be replaced in the string-lookup table."
% (item.__module__, item.__name__)
)
self.contents.add(weakref.ref(item, self._remove_item))
class _ModuleMarker(object):
"""Refers to a module name within
_decl_class_registry.
"""
__slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"
def __init__(self, name, parent):
self.parent = parent
self.name = name
self.contents = {}
self.mod_ns = _ModNS(self)
if self.parent:
self.path = self.parent.path + [self.name]
else:
self.path = []
_registries.add(self)
def __contains__(self, name):
return name in self.contents
def __getitem__(self, name):
return self.contents[name]
def _remove_item(self, name):
self.contents.pop(name, None)
if not self.contents and self.parent is not None:
self.parent._remove_item(self.name)
_registries.discard(self)
def resolve_attr(self, key):
return getattr(self.mod_ns, key)
def get_module(self, name):
if name not in self.contents:
marker = _ModuleMarker(name, self)
self.contents[name] = marker
else:
marker = self.contents[name]
return marker
def add_class(self, name, cls):
if name in self.contents:
existing = self.contents[name]
existing.add_item(cls)
else:
existing = self.contents[name] = _MultipleClassMarker(
[cls], on_remove=lambda: self._remove_item(name)
)
class _ModNS(object):
__slots__ = ("__parent",)
def __init__(self, parent):
self.__parent = parent
def __getattr__(self, key):
try:
value = self.__parent.contents[key]
except KeyError:
pass
else:
if value is not None:
if isinstance(value, _ModuleMarker):
return value.mod_ns
else:
assert isinstance(value, _MultipleClassMarker)
return value.attempt_get(self.__parent.path, key)
raise AttributeError(
"Module %r has no mapped classes "
"registered under the name %r" % (self.__parent.name, key)
)
class _GetColumns(object):
    """Proxy for attribute access on a class inside string expressions.

    When the class has a mapper, a requested name must be present in the
    mapper's ``all_orm_descriptors``.  For descriptors whose
    ``extension_type`` is ``NOT_EXTENSION``, a synonym is redirected to
    the name it points at, and anything other than a column property is
    rejected.
    """

    __slots__ = ("cls",)

    def __init__(self, cls):
        self.cls = cls

    def __getattr__(self, key):
        """Return attribute ``key`` of the wrapped class.

        :raises InvalidRequestError: if ``key`` is not an ORM descriptor
         of the mapped class, or refers to a non-column property.
        """
        mapper = class_mapper(self.cls, configure=False)
        if not mapper:
            return getattr(self.cls, key)

        descriptors = mapper.all_orm_descriptors
        if key not in descriptors:
            raise exc.InvalidRequestError(
                "Class %r does not have a mapped column named %r"
                % (self.cls, key)
            )

        descriptor = descriptors[key]
        if descriptor.extension_type is interfaces.NOT_EXTENSION:
            prop = descriptor.property
            if isinstance(prop, SynonymProperty):
                # follow the synonym to the attribute it names
                key = prop.name
            elif not isinstance(prop, ColumnProperty):
                raise exc.InvalidRequestError(
                    "Property %r is not an instance of"
                    " ColumnProperty (i.e. does not correspond"
                    " directly to a Column)." % key
                )
        return getattr(self.cls, key)
# Hook _GetColumns into the inspection module: the registered callable
# inspects the class wrapped by the _GetColumns instance instead of the
# proxy itself.  (Presumably this is what inspect() uses when handed a
# _GetColumns -- the behavior of _inspects() is defined outside this file.)
inspection._inspects(_GetColumns)(
    lambda target: inspection.inspect(target.cls)
)
class _GetTable(object):
__slots__ = "key", "metadata"
def __init__(self, key, metadata):
self.key = key
self.metadata = metadata
def __getattr__(self, key):
return self.metadata.tables[_get_table_key(key, self.key)]
def _determine_container(key, value):
    """Wrap a registry entry in a :class:`_GetColumns` proxy.

    A :class:`_MultipleClassMarker` is first reduced to its single class
    via ``attempt_get()``; any other value is wrapped as is.
    """
    target = (
        value.attempt_get([], key)
        if isinstance(value, _MultipleClassMarker)
        else value
    )
    return _GetColumns(target)
class _class_resolver(object):
def __init__(self, cls, prop, fallback, arg, favor_tables=False):
self.cls = cls
self.prop = prop
self.arg = self._declarative_arg = arg
self.fallback = fallback
self._dict = util.PopulateDict(self._access_cls)
self._resolvers = ()
self.favor_tables = favor_tables
def _access_cls(self, key):
cls = self.cls
if self.favor_tables:
if key in cls.metadata.tables:
return cls.metadata.tables[key]
elif key in cls.metadata._schemas:
return _GetTable(key, cls.metadata)
if key in cls._decl_class_registry:
return _determine_container(key, cls._decl_class_registry[key])
if not self.favor_tables:
if key in cls.metadata.tables:
return cls.metadata.tables[key]
elif key in cls.metadata._schemas:
return _GetTable(key, cls.metadata)
if (
"_sa_module_registry" in cls._decl_class_registry
and key in cls._decl_class_registry["_sa_module_registry"]
):
registry = cls._decl_class_registry["_sa_module_registry"]
return registry.resolve_attr(key)
elif self._resolvers:
for resolv in self._resolvers:
value = resolv(key)
if value is not None:
return value
return self.fallback[key]
def _raise_for_name(self, name, err):
util.raise_(
exc.InvalidRequestError(
"When initializing mapper %s, expression %r failed to "
"locate a name (%r). If this is a class name, consider "
"adding this relationship() to the %r class after "
"both dependent classes have been defined."
% (self.prop.parent, self.arg, name, self.cls)
),
from_=err,
)
def _resolve_name(self):
name = self.arg
d = self._dict
rval = None
try:
for token in name.split("."):
if rval is None:
rval = d[token]
else:
rval = getattr(rval, token)
except KeyError as err:
self._raise_for_name(name, err)
except NameError as n:
self._raise_for_name(n.args[0], n)
else:
if isinstance(rval, _GetColumns):
return rval.cls
else:
return rval
def __call__(self):
try:
x = eval(self.arg, globals(), self._dict)
if isinstance(x, _GetColumns):
return x.cls
else:
return x
except NameError as n:
self._raise_for_name(n.args[0], n)
def _resolver(cls, prop):
    """Build the two resolver factories for a relationship on ``cls``.

    Returns ``(resolve_name, resolve_arg)``:

    * ``resolve_name(arg)`` gives the bound ``_resolve_name`` method of
      a new :class:`_class_resolver`;
    * ``resolve_arg(arg, favor_tables=False)`` gives the
      :class:`_class_resolver` itself.

    Both share a fallback namespace made of a copy of the ``sqlalchemy``
    package namespace plus ``foreign`` and ``remote``.
    """
    import sqlalchemy
    from sqlalchemy.orm import foreign, remote

    fallback = dict(sqlalchemy.__dict__, foreign=foreign, remote=remote)

    def resolve_name(arg):
        resolver = _class_resolver(cls, prop, fallback, arg)
        return resolver._resolve_name

    def resolve_arg(arg, favor_tables=False):
        return _class_resolver(
            cls, prop, fallback, arg, favor_tables=favor_tables
        )

    return resolve_name, resolve_arg
def _deferred_relationship(cls, prop):
    """Replace string arguments of a relationship with resolvers.

    For a :class:`RelationshipProperty`, every string-valued
    configuration attribute is replaced by a :class:`_class_resolver`
    (or, for ``argument``, by its ``_resolve_name`` method).  String
    values inside a tuple-form ``backref`` are handled the same way.
    Any other kind of property is returned untouched.

    :param cls: declarative class the property belongs to.
    :param prop: the property to process.
    :return: ``prop`` itself.
    """
    if not isinstance(prop, RelationshipProperty):
        return prop

    resolve_name, resolve_arg = _resolver(cls, prop)

    for attr in (
        "order_by",
        "primaryjoin",
        "secondaryjoin",
        "secondary",
        "_user_defined_foreign_keys",
        "remote_side",
    ):
        value = getattr(prop, attr)
        if isinstance(value, util.string_types):
            # "secondary" names a table, so tables take precedence
            # over same-named classes for that attribute only.
            setattr(
                prop,
                attr,
                resolve_arg(value, favor_tables=attr == "secondary"),
            )

    for attr in ("argument",):
        value = getattr(prop, attr)
        if isinstance(value, util.string_types):
            setattr(prop, attr, resolve_name(value))

    if prop.backref and isinstance(prop.backref, tuple):
        _, kwargs = prop.backref
        for attr in (
            "primaryjoin",
            "secondaryjoin",
            "secondary",
            "foreign_keys",
            "remote_side",
            "order_by",
        ):
            value = kwargs.get(attr)
            if isinstance(value, util.string_types):
                # resolve a backref's "secondary" the same way as the
                # relationship's own, favoring tables.
                kwargs[attr] = resolve_arg(
                    value, favor_tables=attr == "secondary"
                )

    return prop