Your IP : 3.145.83.96
# MySQL Connector/Python - MySQL driver written in Python.
# Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
# MySQL Connector/Python is licensed under the terms of the GPLv2
# <http://www.gnu.org/licenses/old-licenses/gpl-2.0.html>, like most
# MySQL Connectors. There are special exceptions to the terms and
# conditions of the GPLv2 as it is applied to this software, see the
# FOSS License Exception
# <http://www.mysql.com/about/legal/licensing/foss-exception.html>.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
"""Implementation of Statements."""
import json
import re
from .errors import ProgrammingError
from .expr import ExprParser
from .compat import STRING_TYPES
from .constants import Algorithms, Securities
from .dbdoc import DbDoc
from .protobuf import mysqlx_crud_pb2 as MySQLxCrud
from .result import SqlResult, Result, ColumnType
class Expr(object):
    """Thin wrapper that carries an expression string.

    Args:
        expr (string): The expression text; it is stored unchanged.
    """

    def __init__(self, expr):
        # Kept verbatim so consumers can read it back through ``.expr``.
        self.expr = expr
def flexible_params(*values):
    """Normalize arguments that may arrive variadic or as one sequence.

    Args:
        *values: Either several positional values, or exactly one list or
                 tuple holding the values.

    Returns:
        The single list/tuple argument itself when exactly one was given,
        otherwise the tuple of positional values.
    """
    single_sequence = (len(values) == 1 and
                       isinstance(values[0], (list, tuple)))
    return values[0] if single_sequence else values
def is_quoted_identifier(identifier, sql_mode=""):
    """Check if the given identifier is quoted.

    Backticks always count as quotes; double quotes count only when
    ``ANSI_QUOTES`` appears in ``sql_mode``.

    Args:
        identifier (string): Identifier to check.
        sql_mode (Optional[string]): SQL mode.

    Returns:
        `True` if the identifier starts and ends with the same accepted
        quote character, and `False` otherwise.
    """
    # A quoted identifier needs an opening and a closing quote. This also
    # keeps an empty string from raising IndexError and stops a lone quote
    # character from being reported as already quoted.
    if len(identifier) < 2:
        return False
    quotes = ("`", '"') if "ANSI_QUOTES" in sql_mode else ("`",)
    return any(identifier[0] == quote and identifier[-1] == quote
               for quote in quotes)
def quote_identifier(identifier, sql_mode=""):
    """Quote an identifier, escaping embedded quote characters.

    An identifier that is already quoted is returned untouched. Otherwise
    it is wrapped in double quotes when ``ANSI_QUOTES`` appears in
    ``sql_mode`` and in backticks in every other case; occurrences of the
    quote character inside the name are doubled.

    Args:
        identifier (string): Identifier to quote.
        sql_mode (Optional[string]): SQL mode.

    Returns:
        A string with the quoted identifier.
    """
    if is_quoted_identifier(identifier, sql_mode):
        return identifier
    if "ANSI_QUOTES" in sql_mode:
        template, quote = '"{0}"', '"'
    else:
        template, quote = "`{0}`", "`"
    escaped = identifier.replace(quote, quote * 2)
    return template.format(escaped)
def quote_multipart_identifier(identifiers, sql_mode=""):
    """Quote every part of a multi-part identifier and join with dots.

    Args:
        identifiers (iterable): List of identifiers to quote.
        sql_mode (Optional[string]): SQL mode.

    Returns:
        A string with each part quoted and the parts separated by ".".
    """
    quoted_parts = []
    for part in identifiers:
        quoted_parts.append(quote_identifier(part, sql_mode))
    return ".".join(quoted_parts)
def parse_table_name(default_schema, table_name, sql_mode=""):
    """Split a possibly schema-qualified table name into its two parts.

    Args:
        default_schema (string): Schema to use when ``table_name`` carries
                                 no schema part.
        table_name (string): Table name, optionally prefixed by a schema
                             and optionally quoted.
        sql_mode (Optional[string]): SQL mode; ``ANSI_QUOTES`` selects the
                                     double quote as quote character.

    Returns:
        tuple: ``(schema, table)`` with surrounding quote characters
        stripped from the parsed parts.
    """
    quote = '"' if "ANSI_QUOTES" in sql_mode else "`"
    # When quotes are present, split only on a dot that is followed by an
    # opening quote so a dot inside a quoted name is not taken for the
    # schema separator.
    delimiter = ".{0}".format(quote) if quote in table_name else "."
    parts = table_name.split(delimiter, 1)
    # Compare by value: ``is 1`` relied on CPython's small-int caching.
    if len(parts) == 1:
        schema = default_schema
    else:
        schema = parts[0].strip(quote)
    return (schema, parts[-1].strip(quote),)
class Statement(object):
    """Common base for all statement objects.

    Args:
        target (object): The target database object, it can be
                         :class:`mysqlx.Collection` or :class:`mysqlx.Table`.
        doc_based (bool): `True` if it is document based.
    """

    def __init__(self, target, doc_based=True):
        self._target = target
        self._doc_based = doc_based
        # Without a (truthy) target there is no connection to borrow.
        if target:
            self._connection = target._connection
        else:
            self._connection = None

    @property
    def target(self):
        """object: The database object target.
        """
        return self._target

    @property
    def schema(self):
        """:class:`mysqlx.Schema`: The Schema object.
        """
        return self._target.schema

    def execute(self):
        """Execute the statement.

        Raises:
            NotImplementedError: Always; subclasses must override it.
        """
        raise NotImplementedError
class FilterableStatement(Statement):
    """Base class for statements that accept filtering criteria.

    Besides the search condition it keeps projection, limit, sort,
    grouping, having and placeholder-binding state. Each piece has a
    matching ``_has_*`` flag that is switched on once the piece is set.

    Args:
        target (object): The target database object, it can be
                         :class:`mysqlx.Collection` or :class:`mysqlx.Table`.
        doc_based (Optional[bool]): `True` if it is document based
                                    (default: `True`).
        condition (Optional[str]): Sets the search condition to filter
                                   documents or records.
    """

    def __init__(self, target, doc_based=True, condition=None):
        super(FilterableStatement, self).__init__(target=target,
                                                  doc_based=doc_based)
        self._has_projection = False
        self._has_where = False
        self._has_limit = False
        self._has_sort = False
        self._has_group_by = False
        self._has_having = False
        self._has_bindings = False
        self._binding_map = {}
        self._bindings = []
        if condition is not None:
            self.where(condition)

    def where(self, condition):
        """Sets the search condition to filter.

        Args:
            condition (str): Sets the search condition to filter documents or
                             records.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.
        """
        self._has_where = True
        self._where = condition
        expr = ExprParser(condition, not self._doc_based)
        self._where_expr = expr.expr()
        # Read the placeholder map only after parsing the condition.
        self._binding_map = expr.placeholder_name_to_position
        return self

    def _projection(self, *fields):
        """Store the projection as joined text and as parsed expression.

        Args:
            *fields: The field expressions, variadic or as one list/tuple.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.
        """
        fields = flexible_params(*fields)
        self._has_projection = True
        self._projection_str = ",".join(fields)
        self._projection_expr = ExprParser(
            self._projection_str,
            not self._doc_based).parse_table_select_projection()
        return self

    def limit(self, row_count, offset=0):
        """Sets the maximum number of records or documents to be returned.

        Args:
            row_count (int): The maximum number of records or documents.
            offset (Optional[int]): The number of records or documents to
                                    skip.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.
        """
        self._has_limit = True
        self._limit_offset = offset
        self._limit_row_count = row_count
        return self

    def sort(self, *sort_clauses):
        """Sets the sorting criteria.

        Args:
            *sort_clauses: The expression strings defining the sort criteria.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.
        """
        sort_clauses = flexible_params(*sort_clauses)
        self._has_sort = True
        self._sort_str = ",".join(sort_clauses)
        self._sort_expr = ExprParser(self._sort_str,
                                     not self._doc_based).parse_order_spec()
        return self

    def _group_by(self, *fields):
        """Store the grouping criteria as joined text and parsed list.

        Args:
            *fields: The grouping expressions, variadic or as one list/tuple.
        """
        fields = flexible_params(*fields)
        self._has_group_by = True
        self._grouping_str = ",".join(fields)
        self._grouping = ExprParser(self._grouping_str,
                                    not self._doc_based).parse_expr_list()

    def _having(self, condition):
        """Parse and store the having condition.

        Args:
            condition (str): The having condition.
        """
        self._has_having = True
        # NOTE: the result is stored under the same name as this method, so
        # after the first call ``self._having`` on this instance is the
        # parsed expression and no longer the bound method. The attribute
        # name is kept because the value is read back under that name.
        self._having = ExprParser(condition, not self._doc_based).expr()

    def bind(self, *args):
        """Binds a value to a specific placeholder.

        Args:
            *args: The name of the placeholder and the value to bind.
                   A :class:`mysqlx.DbDoc` object or a JSON string
                   representation can be used.

        Returns:
            mysqlx.FilterableStatement: FilterableStatement object.

        Raises:
            ProgrammingError: If the number of arguments is invalid.
        """
        count = len(args)
        # Validate first so that a call without arguments raises the
        # documented ProgrammingError instead of an IndexError.
        if count not in (1, 2):
            raise ProgrammingError("Invalid number of arguments to bind")
        self._has_bindings = True
        if count == 1:
            self._bind_single(args[0])
        else:
            self._bindings.append({"name": args[0], "value": args[1]})
        return self

    def _bind_single(self, obj):
        """Bind every key/value pair of a document or JSON string.

        Values that are neither a :class:`mysqlx.DbDoc` nor a string are
        ignored.

        Args:
            obj (object): A :class:`mysqlx.DbDoc` or a JSON string.
        """
        if isinstance(obj, DbDoc):
            # Re-enter through bind() with the document's string form.
            self.bind(str(obj))
        elif isinstance(obj, STRING_TYPES):
            for name, value in json.loads(obj).items():
                self.bind(name, value)

    def execute(self):
        """Execute the statement.

        Raises:
            NotImplementedError: This method must be implemented.
        """
        raise NotImplementedError
class SqlStatement(Statement):
    """Statement that runs a raw SQL string on a connection.

    Args:
        connection (mysqlx.connection.Connection): Connection object.
        sql (string): The sql statement to be executed.
    """

    def __init__(self, connection, sql):
        # There is no target object; the connection is supplied directly.
        super(SqlStatement, self).__init__(target=None, doc_based=False)
        self._sql = sql
        self._connection = connection

    def execute(self):
        """Execute the statement.

        Returns:
            mysqlx.SqlResult: SqlResult object.
        """
        connection = self._connection
        connection.send_sql(self._sql)
        return SqlResult(connection)
class AddStatement(Statement):
    """Statement that adds documents to a collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
    """

    def __init__(self, collection):
        super(AddStatement, self).__init__(target=collection)
        self._values = []
        self._ids = []

    def add(self, *values):
        """Adds a list of documents into a collection.

        Values that are not already :class:`mysqlx.DbDoc` instances are
        wrapped in one.

        Args:
            *values: The documents to be added into the collection.

        Returns:
            mysqlx.AddStatement: AddStatement object.
        """
        for item in flexible_params(*values):
            document = item if isinstance(item, DbDoc) else DbDoc(item)
            self._values.append(document)
        return self

    def execute(self):
        """Execute the statement.

        With no documents queued, an empty result is returned without
        contacting the connection.

        Returns:
            mysqlx.Result: Result object.
        """
        if len(self._values) == 0:
            return Result()

        # Record the id of every queued document before sending.
        for document in self._values:
            doc_id = document.ensure_id()
            self._ids.append(doc_id)

        return self._connection.send_insert(self)
class UpdateSpec(object):
    """One update operation: its type, its parsed source and its value.

    Args:
        update_type: A ``MySQLxCrud.UpdateOperation`` value.
        source (string): Column name for ``SET`` operations, document path
                         for every other operation.
        value (Optional[object]): The value used by the operation.
    """

    def __init__(self, update_type, source, value=None):
        if update_type == MySQLxCrud.UpdateOperation.SET:
            self._table_set(source, value)
            return

        self.update_type = update_type
        self.source = source
        # A leading "$" is dropped before the document path is parsed.
        if len(source) > 0 and source[0] == '$':
            self.source = source[1:]
        parser = ExprParser(self.source, False)
        self.source = parser.document_field().identifier
        self.value = value

    def _table_set(self, source, value):
        """Populate the spec for a table column ``SET`` operation."""
        self.update_type = MySQLxCrud.UpdateOperation.SET
        parser = ExprParser(source, True)
        self.source = parser.parse_table_update_field()
        self.value = value
class ModifyStatement(FilterableStatement):
    """Statement that updates documents of a Collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
        condition (Optional[str]): Sets the search condition to identify the
                                   documents to be updated.
    """

    def __init__(self, collection, condition=None):
        super(ModifyStatement, self).__init__(target=collection,
                                              condition=condition)
        self._update_ops = []

    def set(self, doc_path, value):
        """Sets or updates attributes on documents in a collection.

        Args:
            doc_path (string): The document path of the item to be set.
            value (string): The value to be set on the specified attribute.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        operation = UpdateSpec(MySQLxCrud.UpdateOperation.ITEM_SET,
                               doc_path, value)
        self._update_ops.append(operation)
        return self

    def change(self, doc_path, value):
        """Queue an ``ITEM_REPLACE`` operation for a document path.

        Args:
            doc_path (string): The document path of the item to be set.
            value (object): The value to be set on the specified attribute.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        operation = UpdateSpec(MySQLxCrud.UpdateOperation.ITEM_REPLACE,
                               doc_path, value)
        self._update_ops.append(operation)
        return self

    def unset(self, *doc_paths):
        """Removes attributes from documents in a collection.

        Args:
            *doc_paths: The document paths of the attributes to be removed,
                        variadic or as one list/tuple.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        # Build every spec before queueing any of them.
        removals = [UpdateSpec(MySQLxCrud.UpdateOperation.ITEM_REMOVE, path)
                    for path in flexible_params(*doc_paths)]
        self._update_ops.extend(removals)
        return self

    def array_insert(self, field, value):
        """Insert a value into the specified array in documents of a
        collection.

        Args:
            field (string): A document path that identifies the array
                            attribute and position where the value will be
                            inserted.
            value (object): The value to be inserted.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        operation = UpdateSpec(MySQLxCrud.UpdateOperation.ARRAY_INSERT,
                               field, value)
        self._update_ops.append(operation)
        return self

    def array_append(self, doc_path, value):
        """Queue an ``ARRAY_APPEND`` operation for an array attribute in
        documents of a collection.

        Args:
            doc_path (string): A document path that identifies the array
                               attribute.
            value (object): The value to be appended.

        Returns:
            mysqlx.ModifyStatement: ModifyStatement object.
        """
        operation = UpdateSpec(MySQLxCrud.UpdateOperation.ARRAY_APPEND,
                               doc_path, value)
        self._update_ops.append(operation)
        return self

    def execute(self):
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object.
        """
        connection = self._connection
        return connection.update(self)
class FindStatement(FilterableStatement):
    """A statement for document selection on a Collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
        condition (Optional[str]): An optional expression to identify the
                                   documents to be retrieved. If not specified
                                   all the documents will be included on the
                                   result unless a limit is set.
    """

    def __init__(self, collection, condition=None):
        super(FindStatement, self).__init__(collection, True, condition)

    def fields(self, *fields):
        """Sets a document field filter.

        Args:
            *fields: The string expressions identifying the fields to be
                     extracted.

        Returns:
            mysqlx.FindStatement: FindStatement object.
        """
        return self._projection(*fields)

    def group_by(self, *fields):
        """Sets a grouping criteria for the resultset.

        Args:
            *fields: The string expressions identifying the grouping criteria.

        Returns:
            mysqlx.FindStatement: FindStatement object.
        """
        self._group_by(*fields)
        return self

    def having(self, condition):
        """Sets a condition for records to be considered in aggregate
        function operations.

        Args:
            condition (string): A condition on the aggregate functions used
                                on the grouping criteria.

        Returns:
            mysqlx.FindStatement: FindStatement object.
        """
        # Resolve the helper on the class, not on the instance. The base
        # helper stores the parsed expression in the instance attribute
        # ``_having``, which hides the method after the first call, so
        # ``self._having(condition)`` would then try to call the expression.
        type(self)._having(self, condition)
        return self

    def execute(self):
        """Execute the statement.

        Returns:
            mysqlx.DocResult: DocResult object.
        """
        return self._connection.find(self)
class SelectStatement(FilterableStatement):
    """A statement for record retrieval operations on a Table.

    Args:
        table (mysqlx.Table): The Table object.
        *fields: The fields to be retrieved.
    """

    def __init__(self, table, *fields):
        super(SelectStatement, self).__init__(table, False)
        self._projection(*fields)

    def group_by(self, *fields):
        """Sets a grouping criteria for the resultset.

        Args:
            *fields: The fields identifying the grouping criteria.

        Returns:
            mysqlx.SelectStatement: SelectStatement object.
        """
        self._group_by(*fields)
        return self

    def having(self, condition):
        """Sets a condition for records to be considered in aggregate
        function operations.

        Args:
            condition (str): A condition on the aggregate functions used on
                             the grouping criteria.

        Returns:
            mysqlx.SelectStatement: SelectStatement object.
        """
        # Keep the raw text so get_sql() can emit it; the base helper only
        # stores the parsed expression.
        self._having_str = condition
        # Resolve the helper on the class, not on the instance. The base
        # helper stores the parsed expression in the instance attribute
        # ``_having``, which hides the method after the first call.
        type(self)._having(self, condition)
        return self

    def execute(self):
        """Execute the statement.

        Returns:
            mysqlx.RowResult: RowResult object.
        """
        return self._connection.find(self)

    def get_sql(self):
        """Build the SQL text of this select statement.

        Clauses that were not set are omitted. An empty projection is
        rendered as ``*``.

        Returns:
            string: The ``SELECT`` statement.
        """
        where = " WHERE {0}".format(self._where) if self._has_where else ""
        group_by = " GROUP BY {0}".format(self._grouping_str) if \
            self._has_group_by else ""
        having = ""
        if self._has_having:
            # Prefer the condition text recorded by having(); fall back to
            # the stored ``_having`` value when only the private helper was
            # used to set the clause.
            having = " HAVING {0}".format(
                getattr(self, "_having_str", self._having))
        order_by = " ORDER BY {0}".format(self._sort_str) if self._has_sort \
            else ""
        limit = " LIMIT {0} OFFSET {1}".format(
            self._limit_row_count,
            self._limit_offset) if self._has_limit else ""
        # __init__ always sets ``_projection_str`` (to "" when no fields
        # are given), so the fallback has to cover the empty string too.
        select = getattr(self, '_projection_str', "*") or "*"
        stmt = ("SELECT {select} FROM {schema}.{table}{where}{group}{having}"
                "{order}{limit}".format(
                    select=select,
                    schema=self.schema.name, table=self.target.name,
                    limit=limit, where=where, group=group_by,
                    having=having, order=order_by))
        return stmt
class InsertStatement(Statement):
    """Statement that inserts rows into a Table.

    Args:
        table (mysqlx.Table): The Table object.
        *fields: The fields to be inserted.
    """

    def __init__(self, table, *fields):
        super(InsertStatement, self).__init__(target=table, doc_based=False)
        self._values = []
        self._fields = flexible_params(*fields)

    def values(self, *values):
        """Set the values to be inserted.

        Each call queues one row, stored as a list.

        Args:
            *values: The values of the columns to be inserted.

        Returns:
            mysqlx.InsertStatement: InsertStatement object.
        """
        row = list(flexible_params(*values))
        self._values.append(row)
        return self

    def execute(self):
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object.
        """
        connection = self._connection
        return connection.send_insert(self)
class UpdateStatement(FilterableStatement):
    """Statement that updates records of a Table.

    Args:
        table (mysqlx.Table): The Table object.
        *fields: Accepted for signature compatibility; not used here.
    """

    def __init__(self, table, *fields):
        super(UpdateStatement, self).__init__(target=table, doc_based=False)
        self._update_ops = []

    def set(self, field, value):
        """Updates the column value on records in a table.

        Args:
            field (string): The column name to be updated.
            value (object): The value to be set on the specified column.

        Returns:
            mysqlx.UpdateStatement: UpdateStatement object.
        """
        operation = UpdateSpec(MySQLxCrud.UpdateOperation.SET, field, value)
        self._update_ops.append(operation)
        return self

    def execute(self):
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object
        """
        connection = self._connection
        return connection.update(self)
class RemoveStatement(FilterableStatement):
    """Statement that removes documents from a collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
    """

    def __init__(self, collection):
        # ``target`` is the first positional parameter of the base class.
        super(RemoveStatement, self).__init__(collection)

    def execute(self):
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object.
        """
        connection = self._connection
        return connection.delete(self)
class DeleteStatement(FilterableStatement):
    """Statement that deletes rows from a table.

    Args:
        table (mysqlx.Table): The Table object.
        condition (Optional[str]): The string with the filter expression of
                                   the rows to be deleted.
    """

    def __init__(self, table, condition=None):
        # Base signature order is (target, doc_based, condition).
        super(DeleteStatement, self).__init__(table, False, condition)

    def execute(self):
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object.
        """
        connection = self._connection
        return connection.delete(self)
class CreateCollectionIndexStatement(Statement):
    """Statement that creates an index on a collection.

    Args:
        collection (mysqlx.Collection): Collection.
        index_name (string): Index name.
        is_unique (bool): `True` if the index is unique.
    """

    def __init__(self, collection, index_name, is_unique):
        super(CreateCollectionIndexStatement, self).__init__(target=collection)
        self._fields = []
        self._is_unique = is_unique
        self._index_name = index_name

    def field(self, document_path, column_type, is_required):
        """Add the field specification to this index creation statement.

        Args:
            document_path (string): The document path.
            column_type (string): The column type.
            is_required (bool): `True` if the field is required.

        Returns:
            mysqlx.CreateCollectionIndexStatement: \
                                   CreateCollectionIndexStatement object.
        """
        spec = (document_path, column_type, is_required,)
        self._fields.append(spec)
        return self

    def execute(self):
        """Execute the statement.

        The field specifications are flattened, in insertion order, into
        trailing positional arguments of the admin command.

        Returns:
            mysqlx.Result: Result object.
        """
        flattened = []
        for spec in self._fields:
            flattened.extend(spec)
        target = self._target
        return self._connection.execute_nonquery(
            "xplugin", "create_collection_index", True,
            target.schema.name, target.name, self._index_name,
            self._is_unique, *flattened)
class DropCollectionIndexStatement(Statement):
    """Statement that drops an index of a collection.

    Args:
        collection (mysqlx.Collection): The Collection object.
        index_name (string): The index name.
    """

    def __init__(self, collection, index_name):
        super(DropCollectionIndexStatement, self).__init__(target=collection)
        self._index_name = index_name

    def execute(self):
        """Execute the statement.

        Returns:
            mysqlx.Result: Result object.
        """
        target = self._target
        arguments = (target.schema.name, target.name, self._index_name)
        return self._connection.execute_nonquery(
            "xplugin", "drop_collection_index", True, *arguments)
class TableIndex(object):
    """Definition of a table index used when building ``CREATE TABLE``.

    Args:
        name (string): Name of the index.
        index_type (int): ``TableIndex.UNIQUE_INDEX`` or ``TableIndex.INDEX``.
        columns (iterable): Names of the indexed columns.
    """
    UNIQUE_INDEX = 1
    INDEX = 2

    def __init__(self, name, index_type, columns):
        self._name = name
        self._index_type = index_type
        self._columns = columns

    def get_sql(self):
        """Build the index clause.

        Returns:
            string: ``[UNIQUE ]INDEX <name> (<col>,<col>...)``.
        """
        # Compare by value; an identity check on ints only works by
        # accident of interpreter caching.
        prefix = "UNIQUE " if self._index_type == TableIndex.UNIQUE_INDEX \
            else ""
        return "{0}INDEX {1} ({2})".format(prefix, self._name,
                                           ",".join(self._columns))
class CreateViewStatement(Statement):
    """A statement for creating views.

    The algorithm defaults to ``Algorithms.UNDEFINED`` and the SQL security
    mode to ``Securities.DEFINER``.

    Args:
        view (mysqlx.View): The View object.
        replace (Optional[bool]): `True` to add replace.
    """

    def __init__(self, view, replace=False):
        super(CreateViewStatement, self).__init__(target=view, doc_based=False)
        self._view = view
        self._schema = view.schema
        self._name = view.name
        self._replace = replace
        self._columns = []
        self._algorithm = Algorithms.UNDEFINED
        self._security = Securities.DEFINER
        self._definer = None
        self._defined_as = None
        self._check_option = None

    def columns(self, columns):
        """Sets the column names.

        Args:
            columns (list): The list of column names.

        Returns:
            mysqlx.CreateViewStatement: CreateViewStatement object.
        """
        self._columns = [quote_identifier(col) for col in columns]
        return self

    def algorithm(self, algorithm):
        """Sets the algorithm.

        Args:
            algorithm (mysqlx.constants.ALGORITHMS): The algorithm.

        Returns:
            mysqlx.CreateViewStatement: CreateViewStatement object.
        """
        self._algorithm = algorithm
        return self

    def security(self, security):
        """Sets the SQL security mode.

        Args:
            security (mysqlx.constants.SECURITIES): The SQL security mode.

        Returns:
            mysqlx.CreateViewStatement: CreateViewStatement object.
        """
        self._security = security
        return self

    def definer(self, definer):
        """Sets the definer.

        Args:
            definer (string): The definer.

        Returns:
            mysqlx.CreateViewStatement: CreateViewStatement object.
        """
        self._definer = definer
        return self

    def defined_as(self, statement):
        """Sets the SelectStatement statement for describing the view.

        Args:
            statement (mysqlx.SelectStatement): SelectStatement object. A
                value of any other type is formatted into the SQL as is.

        Returns:
            mysqlx.CreateViewStatement: CreateViewStatement object.
        """
        self._defined_as = statement
        return self

    def with_check_option(self, check_option):
        """Sets the check option.

        Args:
            check_option (mysqlx.constants.CHECK_OPTIONS): The check option.

        Returns:
            mysqlx.CreateViewStatement: CreateViewStatement object.
        """
        self._check_option = check_option
        return self

    def execute(self):
        """Execute the statement to create a view.

        Returns:
            mysqlx.View: View object.
        """
        replace = " OR REPLACE" if self._replace else ""
        definer = " DEFINER = {0}".format(self._definer) \
            if self._definer else ""
        columns = " ({0})".format(", ".join(self._columns)) \
            if self._columns else ""
        view_name = quote_multipart_identifier((self._schema.name, self._name))
        check_option = " WITH {0} CHECK OPTION".format(self._check_option) \
            if self._check_option else ""
        # A SelectStatement defines no string conversion, so formatting it
        # directly would embed its object repr; render its SQL instead.
        defined_as = self._defined_as
        if isinstance(defined_as, SelectStatement):
            defined_as = defined_as.get_sql()
        sql = ("CREATE{replace} ALGORITHM = {algorithm}{definer} "
               "SQL SECURITY {security} VIEW {view_name}{columns} "
               "AS {defined_as}{check_option}"
               "".format(replace=replace, algorithm=self._algorithm,
                         definer=definer, security=self._security,
                         view_name=view_name, columns=columns,
                         defined_as=defined_as,
                         check_option=check_option))
        self._connection.execute_nonquery("sql", sql)
        return self._view
class AlterViewStatement(CreateViewStatement):
    """A statement for altering views.

    Args:
        view (mysqlx.View): The View object.
    """

    def __init__(self, view):
        super(AlterViewStatement, self).__init__(view)

    def execute(self):
        """Execute the statement to alter a view.

        Returns:
            mysqlx.View: View object.
        """
        definer = " DEFINER = {0}".format(self._definer) \
            if self._definer else ""
        columns = " ({0})".format(", ".join(self._columns)) \
            if self._columns else ""
        view_name = quote_multipart_identifier((self._schema.name, self._name))
        check_option = " WITH {0} CHECK OPTION".format(self._check_option) \
            if self._check_option else ""
        # A SelectStatement defines no string conversion, so formatting it
        # directly would embed its object repr; render its SQL instead.
        defined_as = self._defined_as
        if isinstance(defined_as, SelectStatement):
            defined_as = defined_as.get_sql()
        sql = ("ALTER ALGORITHM = {algorithm}{definer} "
               "SQL SECURITY {security} VIEW {view_name}{columns} "
               "AS {defined_as}{check_option}"
               "".format(algorithm=self._algorithm, definer=definer,
                         security=self._security, view_name=view_name,
                         columns=columns, defined_as=defined_as,
                         check_option=check_option))
        self._connection.execute_nonquery("sql", sql)
        return self._view
class CreateTableStatement(Statement):
    """Statement that builds and runs a ``CREATE TABLE`` command.

    Args:
        schema (mysqlx.Schema): The Schema object.
        table_name (string): The name for the new table.
    """

    # Captures the table reference that follows FROM in a select string so
    # that as_select() can prefix it with the schema name.
    tbl_frmt = re.compile(r"(from\s+)([`\"].+[`\"]|[^\.]+)(\s|$)",
                          re.IGNORECASE)

    def __init__(self, schema, table_name):
        super(CreateTableStatement, self).__init__(schema)
        self._name = table_name
        self._temp = False
        self._auto_inc = 0
        # Optional table options and alternative sources.
        self._charset = None
        self._collation = None
        self._comment = None
        self._as = None
        self._like = None
        # Pieces of the create definition.
        self._columns = []
        self._f_keys = []
        self._indices = []
        self._p_keys = []
        self._u_indices = []
        self._tbl_repl = r"\1{0}.\2\3".format(self.schema.get_name())

    @property
    def table_name(self):
        """string: The fully qualified name of the Table.
        """
        name_parts = parse_table_name(self.schema.name, self._name)
        return quote_multipart_identifier(name_parts)

    def _get_table_opts(self):
        """Build the comma separated table options string."""
        templates = ["AUTO_INCREMENT = {inc}"]
        if self._charset:
            templates.append("DEFAULT CHARACTER SET = {charset}")
        if self._collation:
            templates.append("DEFAULT COLLATE = {collation}")
        if self._comment:
            templates.append("COMMENT = '{comment}'")
        return ",".join(templates).format(
            inc=self._auto_inc, charset=self._charset,
            collation=self._collation, comment=self._comment)

    def _get_create_def(self):
        """Build the comma separated create definition.

        Order: primary key clause, columns, foreign keys, indexes, unique
        indexes.
        """
        clauses = []
        if self._p_keys:
            clauses.append("PRIMARY KEY ({0})".format(",".join(self._p_keys)))
        groups = (self._columns, self._f_keys, self._indices,
                  self._u_indices)
        for group in groups:
            for definition in group:
                clauses.append(definition.get_sql())
        return ",".join(clauses)

    def like(self, table_name):
        """Create table with the definition of another existing Table.

        Args:
            table_name (string): Name of the source table.

        Returns:
            mysqlx.CreateTableStatement: CreateTableStatement object.
        """
        source_parts = parse_table_name(self.schema.name, table_name)
        self._like = quote_multipart_identifier(source_parts)
        return self

    def as_select(self, select):
        """Create the Table and fill it with values from a Select Statement.

        A value that is neither a string nor a
        :class:`mysqlx.SelectStatement` is ignored.

        Args:
            select (object): Select Statement. Can be a string or an instance
                             of :class:`mysqlx.SelectStatement`.

        Returns:
            mysqlx.CreateTableStatement: CreateTableStatement object.
        """
        if isinstance(select, STRING_TYPES):
            pattern = CreateTableStatement.tbl_frmt
            self._as = pattern.sub(self._tbl_repl, select)
        elif isinstance(select, SelectStatement):
            self._as = select.get_sql()
        return self

    def add_column(self, column_def):
        """Add a Column to the Table.

        Args:
            column_def (MySQLx.ColumnDef): Column Definition object.

        Returns:
            mysqlx.CreateTableStatement: CreateTableStatement object.
        """
        schema_name = self.schema.get_name()
        column_def.set_schema(schema_name)
        self._columns.append(column_def)
        return self

    def add_primary_key(self, *keys):
        """Add multiple Primary Keys to the Table.

        Args:
            *keys: Fields to be used as Primary Keys.

        Returns:
            mysqlx.CreateTableStatement: CreateTableStatement object.
        """
        self._p_keys.extend(flexible_params(*keys))
        return self

    def add_index(self, index_name, *cols):
        """Adds an Index to the Table.

        Args:
            index_name (string): Name of the Index.
            *cols: Fields to be used as an Index.

        Returns:
            mysqlx.CreateTableStatement: CreateTableStatement object.
        """
        index = TableIndex(index_name, TableIndex.INDEX,
                           flexible_params(*cols))
        self._indices.append(index)
        return self

    def add_unique_index(self, index_name, *cols):
        """Adds a Unique Index to the Table.

        Args:
            index_name (string): Name of the Unique Index.
            *cols: Fields to be used as a Unique Index.

        Returns:
            mysqlx.CreateTableStatement: CreateTableStatement object.
        """
        index = TableIndex(index_name, TableIndex.UNIQUE_INDEX,
                           flexible_params(*cols))
        self._u_indices.append(index)
        return self

    def add_foreign_key(self, name, key):
        """Adds a Foreign Key to the Table.

        Args:
            name (string): Name set on the Foreign Key Definition.
            key (MySQLx.ForeignKeyDef): The Foreign Key Definition object.

        Returns:
            mysqlx.CreateTableStatement: CreateTableStatement object.
        """
        schema_name = self.schema.get_name()
        key.set_schema(schema_name)
        key.set_name(name)
        self._f_keys.append(key)
        return self

    def set_initial_auto_increment(self, inc):
        """Set the initial Auto Increment value for the table.

        Args:
            inc (int): The initial AUTO_INCREMENT value for the table.

        Returns:
            mysqlx.CreateTableStatement: CreateTableStatement object.
        """
        self._auto_inc = inc
        return self

    def set_default_charset(self, charset):
        """Sets the default Charset type for the Table.

        Args:
            charset (string): Charset type.

        Returns:
            mysqlx.CreateTableStatement: CreateTableStatement object.
        """
        self._charset = charset
        return self

    def set_default_collation(self, collation):
        """Sets the default Collation type for the Table.

        Args:
            collation (string): Collation type.

        Returns:
            mysqlx.CreateTableStatement: CreateTableStatement object.
        """
        self._collation = collation
        return self

    def set_comment(self, comment):
        """Add a comment to the Table.

        Args:
            comment (string): Comment to be added to the Table.

        Returns:
            mysqlx.CreateTableStatement: CreateTableStatement object.
        """
        self._comment = comment
        return self

    def temporary(self):
        """Set the Table to be Temporary.

        Returns:
            mysqlx.CreateTableStatement: CreateTableStatement object.
        """
        self._temp = True
        return self

    def execute(self):
        """Execute the statement.

        Returns:
            mysqlx.Table: Table object.
        """
        table_type = "TEMPORARY TABLE" if self._temp else "TABLE"
        create = "CREATE {table_type} {name}".format(name=self.table_name,
                                                     table_type=table_type)
        # LIKE copies another table's definition, so the column definition
        # and table options are left out of that form.
        if self._like:
            template = "{create} LIKE {query}"
        else:
            template = "{create} ({create_def}) {table_opts} {query}"
        stmt = template.format(
            create=create,
            query=self._like or self._as or "",
            create_def=self._get_create_def(),
            table_opts=self._get_table_opts())
        self._connection.execute_nonquery("sql", stmt, False)
        return self.schema.get_table(self._name)
class ColumnDefBase(object):
    """Base class holding the parameters common to column definitions.

    Args:
        name (string): Name of the column.
        type (MySQLx.ColumnType): Type of the column.
        size (int): Size of the column.
    """

    def __init__(self, name, type, size):
        # Identity of the column.
        self._name = name
        self._type = type
        self._size = size
        # Modifiers, all off until the matching method is called.
        self._not_null = False
        self._p_key = False
        self._u_index = False
        self._comment = ""
        self._default_schema = None

    def not_null(self):
        """Disable NULL values for this column.

        Returns:
            mysqlx.ColumnDefBase: ColumnDefBase object.
        """
        self._not_null = True
        return self

    def unique_index(self):
        """Set current column as a Unique Index.

        Returns:
            mysqlx.ColumnDefBase: ColumnDefBase object.
        """
        self._u_index = True
        return self

    def comment(self, comment):
        """Add a comment to the column.

        Args:
            comment (string): Comment to be added to the column.

        Returns:
            mysqlx.ColumnDefBase: ColumnDefBase object.
        """
        self._comment = comment
        return self

    def primary(self):
        """Sets the Column as a Primary Key.

        Returns:
            mysqlx.ColumnDefBase: ColumnDefBase object.
        """
        self._p_key = True
        return self

    def set_schema(self, schema):
        """Store the default schema for this definition.

        Args:
            schema (string): The schema name.
        """
        self._default_schema = schema
class ColumnDef(ColumnDefBase):
"""Class containing the complete definition of the Column.
Args:
name (string): Name of the column.
type (MySQL.ColumnType): Type of the column.
size (int): Size of the column.
"""
def __init__(self, name, type, size=None):
super(ColumnDef, self).__init__(name, type, size)
self._ref = None
self._default = None
self._decimals = None
self._ref_table = None
self._binary = False
self._auto_inc = False
self._unsigned = False
self._values = []
self._ref_fields = []
self._charset = None
self._collation = None
def _data_type(self):
type_def = ""
if self._size and (ColumnType.is_numeric(self._type) or \
ColumnType.is_char(self._type) or ColumnType.is_binary(self._type)):
type_def = "({0})".format(self._size)
elif ColumnType.is_decimals(self._type) and self._size:
type_def = "({0}, {1})".format(self._size, self._decimals or 0)
elif ColumnType.is_finite_set(self._type):
type_def = "({0})".format(",".join(self._values))
if self._unsigned:
type_def = "{0} UNSIGNED".format(type_def)
if self._binary:
type_def = "{0} BINARY".format(type_def)
if self._charset:
type_def = "{0} CHARACTER SET {1}".format(type_def, self._charset)
if self._collation:
type_def = "{0} COLLATE {1}".format(type_def, self._collation)
return "{0} {1}".format(ColumnType.to_string(self._type), type_def)
def _col_definition(self):
null = " NOT NULL" if self._not_null else " NULL"
auto_inc = " AUTO_INCREMENT" if self._auto_inc else ""
default = " DEFAULT {default}" if self._default else ""
comment = " COMMENT '{comment}'" if self._comment else ""
defn = "{0}{1}{2}{3}{4}".format(self._data_type(), null, default,
auto_inc, comment)
if self._p_key:
defn = "{0} PRIMARY KEY".format(defn)
elif self._u_index:
defn = "{0} UNIQUE KEY".format(defn)
if self._ref_table and self._ref_fields:
ref_table = quote_multipart_identifier(parse_table_name(
self._default_schema, self._ref_table))
defn = "{0} REFERENCES {1} ({2})".format(defn, ref_table,
",".join(self._ref_fields))
return defn.format(default=self._default, comment=self._comment)
def set_default(self, default_val):
    """Set the value used in the DEFAULT clause of this Column.

    Args:
        default_val (object): The default value of the Column. An
                              :class:`Expr` contributes its expression
                              text as is, `None` is stored as the SQL
                              keyword NULL, and any other value is stored
                              as its ``repr()``.

    Returns:
        mysqlx.ColumnDef: ColumnDef object.
    """
    if default_val is None:
        rendered = "NULL"
    elif isinstance(default_val, Expr):
        rendered = default_val.expr
    else:
        rendered = repr(default_val)
    self._default = rendered
    return self
def auto_increment(self):
    """Flag this Column as AUTO_INCREMENT.

    Returns:
        mysqlx.ColumnDef: This ColumnDef object, to allow call chaining.
    """
    self._auto_inc = True
    return self
def foreign_key(self, name, *refs):
    """Make this Column reference fields of another Table.

    Args:
        name (string): Name of the referenced Table.
        *refs: Referenced fields, given either as separate arguments or
               as a single list/tuple.

    Returns:
        mysqlx.ColumnDef: This ColumnDef object, to allow call chaining.
    """
    referenced_fields = flexible_params(*refs)
    self._ref_fields = referenced_fields
    self._ref_table = name
    return self
def unsigned(self):
    """Flag this Column as UNSIGNED.

    Returns:
        mysqlx.ColumnDef: This ColumnDef object, to allow call chaining.
    """
    self._unsigned = True
    return self
def decimals(self, size):
    """Set the number of decimals of this Column.

    Args:
        size (int): Number of decimals; it is rendered after the column
                    size for decimal types.

    Returns:
        mysqlx.ColumnDef: This ColumnDef object, to allow call chaining.
    """
    self._decimals = size
    return self
def charset(self, charset):
    """Set the character set of this Column.

    Args:
        charset (string): Name of the character set.

    Returns:
        mysqlx.ColumnDef: This ColumnDef object, to allow call chaining.
    """
    self._charset = charset
    return self
def collation(self, collation):
    """Set the collation of this Column.

    Args:
        collation (string): Name of the collation.

    Returns:
        mysqlx.ColumnDef: This ColumnDef object, to allow call chaining.
    """
    self._collation = collation
    return self
def binary(self):
    """Flag this Column with the BINARY attribute.

    Returns:
        mysqlx.ColumnDef: This ColumnDef object, to allow call chaining.
    """
    self._binary = True
    return self
def values(self, *values):
    """Set the Enum/Set values.

    Args:
        *values: Values for Enum/Set type Column, given either as
                 separate arguments or as a single list/tuple. Each value
                 is stored as its ``repr()``.

    Returns:
        mysqlx.ColumnDef: ColumnDef object.
    """
    # Store a real list: on Python 3 ``map()`` returns a one-shot iterator
    # that is exhausted by the first ``",".join(...)``, so generating the
    # SQL a second time would produce an empty value list.
    self._values = [repr(value) for value in flexible_params(*values)]
    return self
def get_sql(self):
    """Return the SQL fragment that defines this Column.

    Returns:
        string: The column name, a space, and the column definition.
    """
    definition = self._col_definition()
    return "{0} {1}".format(self._name, definition)
class GeneratedColumnDef(ColumnDef):
    """Column definition whose value is generated from an expression.

    Args:
        name: Name of the column.
        col_type: Type of the column.
        expr (Expr): The expression used to generate the value of this
                     column.
    """

    def __init__(self, name, col_type, expr):
        super(GeneratedColumnDef, self).__init__(name, col_type)
        assert isinstance(expr, Expr)
        # Only the expression text is kept.
        self._expr = expr.expr
        self._stored = False

    def stored(self):
        """Mark the Generated Column as STORED.

        Returns:
            mysqlx.GeneratedColumnDef: This GeneratedColumnDef object, to
                                       allow call chaining.
        """
        self._stored = True
        return self

    def get_sql(self):
        """Return the SQL fragment that defines this Generated Column.

        Returns:
            string: The parent class definition followed by the
                    GENERATED ALWAYS AS clause and, when requested, the
                    STORED keyword.
        """
        column_sql = super(GeneratedColumnDef, self).get_sql()
        storage = " STORED" if self._stored else ""
        return "{0} GENERATED ALWAYS AS ({1}){2}".format(
            column_sql, self._expr, storage)
class ForeignKeyDef(object):
    """Class describing a Foreign Key."""
    # Referential actions accepted by on_update() and on_delete().
    NO_ACTION = 1
    RESTRICT = 2
    CASCADE = 3
    SET_NULL = 4

    def __init__(self):
        self._fields = []
        self._f_fields = []
        self._name = None
        self._f_table = None
        self._default_schema = None
        self._update_action = self._action(ForeignKeyDef.NO_ACTION)
        self._delete_action = self._action(ForeignKeyDef.NO_ACTION)

    def _action(self, action):
        """Translate an action constant into its SQL keyword.

        Args:
            action (int): One of the ForeignKeyDef action constants.

        Returns:
            string: The SQL keyword of the action; any value that is not
                    RESTRICT, CASCADE or SET_NULL gives "NO ACTION".
        """
        # Compare by value. Identity (`is`) only matches when the caller
        # passes the very same int object, which is an interpreter caching
        # detail and fails for equal values such as int subclasses.
        if action == ForeignKeyDef.RESTRICT:
            return "RESTRICT"
        if action == ForeignKeyDef.CASCADE:
            return "CASCADE"
        if action == ForeignKeyDef.SET_NULL:
            return "SET NULL"
        return "NO ACTION"

    def set_name(self, name):
        """Set the name of the Foreign Key.

        Args:
            name (string): Name of the Foreign Key.
        """
        self._name = name

    def set_schema(self, schema):
        """Set the schema used to resolve the referenced table name.

        Args:
            schema (string): Name of the default schema.
        """
        self._default_schema = schema

    def fields(self, *fields):
        """Set the fields of the table which constitute the Foreign Key.

        Args:
            *fields: Fields in the given table which constitute the
                     Foreign Key, given either as separate arguments or as
                     a single list/tuple.

        Returns:
            mysqlx.ForeignKeyDef: ForeignKeyDef object.
        """
        self._fields = flexible_params(*fields)
        return self

    def refers_to(self, name, *refs):
        """Set the referenced table name and its fields.

        Args:
            name (string): Name of the referenced table.
            *refs: Fields in the referenced table, given either as
                   separate arguments or as a single list/tuple.

        Returns:
            mysqlx.ForeignKeyDef: ForeignKeyDef object.
        """
        self._f_fields = flexible_params(*refs)
        self._f_table = name
        return self

    def on_update(self, action):
        """Define the action on updating a Foreign Key.

        Args:
            action (int): Action to be performed on updating the reference.
                          Can be any of the following values:
                          1. ForeignKeyDef.NO_ACTION
                          2. ForeignKeyDef.RESTRICT
                          3. ForeignKeyDef.CASCADE
                          4. ForeignKeyDef.SET_NULL

        Returns:
            mysqlx.ForeignKeyDef: ForeignKeyDef object.
        """
        self._update_action = self._action(action)
        return self

    def on_delete(self, action):
        """Define the action on deleting a Foreign Key.

        Args:
            action (int): Action to be performed on deleting the reference.
                          Can be any of the following values:
                          1. ForeignKeyDef.NO_ACTION
                          2. ForeignKeyDef.RESTRICT
                          3. ForeignKeyDef.CASCADE
                          4. ForeignKeyDef.SET_NULL

        Returns:
            mysqlx.ForeignKeyDef: ForeignKeyDef object.
        """
        self._delete_action = self._action(action)
        return self

    def get_sql(self):
        """Return the SQL fragment that defines this Foreign Key.

        Returns:
            string: The FOREIGN KEY ... REFERENCES ... clause followed by
                    the ON UPDATE and ON DELETE actions.
        """
        # An unset name must give an unnamed key, not one called "None".
        name = "" if self._name is None else self._name
        update = "ON UPDATE {0}".format(self._update_action)
        delete = "ON DELETE {0}".format(self._delete_action)
        key = "FOREIGN KEY {0}({1}) REFERENCES {2} ({3})".format(
            name, ",".join(self._fields), quote_multipart_identifier(
                parse_table_name(self._default_schema, self._f_table)),
            ",".join(self._f_fields))
        return "{0} {1} {2}".format(key, update, delete)