Your IP : 18.221.147.141


Current Path : /opt/cloudlinux/venv/lib64/python3.11/site-packages/lvestats/lib/commons/
Upload File :
Current File : //opt/cloudlinux/venv/lib64/python3.11/site-packages/lvestats/lib/commons/func.py

# coding=utf-8
#
# Copyright © Cloud Linux GmbH & Cloud Linux Software, Inc 2010-2019 All Rights Reserved
#
# Licensed under CLOUD LINUX LICENSE AGREEMENT
# http://cloudlinux.com/docs/LICENSE.TXT

import csv
import errno
import fcntl
import os
import pwd
import tempfile
import time
import uuid
from _socket import error as SocketError
from _socket import gethostname
from typing import Any
from urllib.parse import urlparse
from xml.dom import DOMException, minidom
from xml.parsers.expat import ExpatError

from sqlalchemy import engine as sqlalchemy_engine
from prettytable import PrettyTable

from clcommon import FormattedException, cpapi
from clcommon.clcaptain import mkdir as mkdir_p
from clcommon.clfunc import uid_max
from clcommon.clproc import LIMIT_LVP_ID
from clcommon.cpapi.cpapiexceptions import (
    CPAPIExternalProgramFailed,
    NotSupported,
    ParsingError,
)
from clcommon.lib import MySQLGovernor, MySQLGovException
from lvestats.lib.config import USER_NOTIFICATIONS_OFF_MARKER, get_max_lve_id

# Cache filled by get_current_max_lve_id() on its first call; None until then
MAX_LVE_ID = None
# File whose leading characters get_lve_version() parses
LVE_FILE = '/proc/lve/list'
# Readable by every user on the system; present in governor-mysql >= 1.1-14
GOVERNOR_CONFIG = '/var/run/mysql-governor-config.xml'
# Readable by root only
GOVERNOR_CONFIG_OLD = '/etc/container/mysql-governor.xml'


def get_current_max_lve_id():
    """
    Return the current maximum LVE ID.

    The value is obtained from get_max_lve_id() the first time it is needed
    and kept in the module-level MAX_LVE_ID for all later calls.
    """
    global MAX_LVE_ID  # pylint: disable=global-statement
    if MAX_LVE_ID is not None:
        return MAX_LVE_ID
    MAX_LVE_ID = get_max_lve_id()
    return MAX_LVE_ID


def gcd(a, b):
    """Return the greatest common divisor of a and b (Euclid's algorithm)."""
    while b:
        remainder = a % b
        a = b
        b = remainder
    return a


def merge_dicts(*dicts):
    """
    Combine the given dicts into a new one; later arguments win on key clashes.

    >>> merge_dicts({1: 2}, {3: 4})
    {1: 2, 3: 4}
    >>> merge_dicts({1: '1', 2: '2'}, {2: 'two', 3: 'three'})
    {1: '1', 2: 'two', 3: 'three'}
    """
    merged = {}
    for mapping in dicts:
        merged.update(mapping)
    return merged


def get_chunks(_list, chunk_length=500):
    """
    Yield consecutive slices of _list, each at most chunk_length items long.

    :type chunk_length: int
    :type _list: list
    :rtype: Generator[list]
    """
    total = len(_list)
    for start in range(0, total, chunk_length):
        stop = start + chunk_length
        yield _list[start:stop]


class LVEVersionError(Exception):
    """Raised when the LVE version cannot be detected."""

    def __init__(self, message=''):
        super().__init__(f"Can't detect LVE version; {message}")


def get_lve_version(_lve_file=LVE_FILE):
    """
    Read the LVE version from the beginning of the LVE list file.

    Only the first 5 characters are read; the version is the integer that
    precedes the first ':' in them.

    :param str _lve_file: path to the file to read
    :return int: parsed version
    :raises LVEVersionError: if the file cannot be read or the prefix
        is not an integer
    """
    try:
        with open(_lve_file, 'r', encoding='utf-8') as stream:
            header = stream.read(5)  # the version lives in the first 5 symbols
    except (OSError, IOError) as e:
        raise LVEVersionError(e) from e

    try:
        version_field = header.split(':')[0]
        return int(version_field)
    except (TypeError, IndexError, ValueError) as e:
        raise LVEVersionError(
            f"error parsing line '{header}' from file {_lve_file}"
        ) from e


def get_governor_mode(governor=None):
    """
    Ask the governor object for its mode.

    :type governor: MySQLGovernor
    :param governor: object to query; a new MySQLGovernor is created when
        the argument is omitted or falsy
    :rtype: str
    :return:
        the mode reported by the governor ("on", "single", "off",
        "abusers" or "all"), or "none" when the governor reports nothing
        or raises MySQLGovException
    """
    if not governor:
        governor = MySQLGovernor()
    try:
        return governor.get_governor_mode() or "none"
    except MySQLGovException:
        return "none"


def get_governor_config_path():
    """
    Pick the governor config file to read.

    :rtype: str|None
    :return: None when /usr/sbin/db_governor is absent, otherwise
        GOVERNOR_CONFIG if that file exists, else GOVERNOR_CONFIG_OLD
    """
    # GOVERNOR_CONFIG can stay on disk after mysql-governor is uninstalled, so
    # the binary is checked first (needed for compatibility with mysql-governor 1.1-15)
    if not os.path.exists('/usr/sbin/db_governor'):
        return None
    if os.path.exists(GOVERNOR_CONFIG):
        return GOVERNOR_CONFIG
    # fall back to the config used by governor-mysql < 1.1-14
    return GOVERNOR_CONFIG_OLD


# TODO: replace with MysqlGovernor().get_governor_status_by_username
def get_governor_ignore_for_user(username):
    """
    Tell whether the governor config marks the given user with mode "ignore".

    :type username: str
    :rtype: bool|None
    :return:
        None if no governor config path is available or the config cannot
        be read/parsed; True if the first <user> element with a matching
        "name" and a "mode" attribute has mode "ignore" (case-insensitive);
        False otherwise
    """
    config_path = get_governor_config_path()
    if config_path is None:
        return None
    try:
        document = minidom.parse(config_path)
        for user_node in document.getElementsByTagName("user"):
            attrs = user_node.attributes
            try:
                if attrs["name"].value != username:
                    continue
                # mode = "restrict|norestrict|ignore"
                return attrs["mode"].value.lower() == "ignore"
            except KeyError:
                # element lacks "name" or "mode": look at the next one
                continue
    except (IOError, ExpatError, DOMException):
        return None
    return False


class DomainException(FormattedException):
    """Raised when the domains or aliases of a user cannot be obtained."""


def get_governor_status():
    """
    Return the status reported by a freshly created MySQLGovernor.

    :rtype: (str, MySQLGovException|None)
    """
    return MySQLGovernor().get_governor_status()


def get_reseller_domains(reseller_name):
    """
    Return the reseller's user -> domain mapping from the control panel API.

    :param reseller_name: reseller's name
    :rtype: dict[str, str|None]
    :return: mapping from cpapi.reseller_domains, or an empty dict when the
        call fails with CPAPIExternalProgramFailed, NotSupported or
        AttributeError
    """
    try:
        domains = cpapi.reseller_domains(reseller_name)
    except (CPAPIExternalProgramFailed, NotSupported, AttributeError):
        domains = {}
    return domains


def get_aliases(username, domain, raise_exception=True):
    """
    Return the aliases of the given user's domain.

    :type username: str
    :type domain: str
    :type raise_exception: bool
    :param raise_exception: when False, a lookup failure yields an empty
        list instead of an exception
    :raises DomainException: if aliases cannot be obtained and
        raise_exception is True
    :rtype: list[str]
    """
    try:
        return cpapi.useraliases(username, domain)
    except (KeyError, IOError, TypeError, cpapi.cpapiexceptions.NotSupported,
            cpapi.cpapiexceptions.NoPanelUser, ParsingError) as e:
        if not raise_exception:
            return []
        raise DomainException(str(e)) from e


def get_domains(username, raise_exception=True):
    """
    Return the domain names of the given user.

    Each entry returned by cpapi.userdomains is reduced to its first item.

    :type username: str
    :type raise_exception: bool
    :param raise_exception: when False, a lookup failure yields an empty
        list instead of an exception
    :raises DomainException: if domains cannot be obtained and
        raise_exception is True
    :rtype: list[str]
    """
    try:
        domain_records = cpapi.userdomains(username)
    except (KeyError, IOError, TypeError, cpapi.cpapiexceptions.NotSupported,
            cpapi.cpapiexceptions.NoPanelUser, ParsingError) as e:
        if not raise_exception:
            return []
        raise DomainException(str(e)) from e
    return [record[0] for record in domain_records]


def get_domain(username):
    """
    Return the first domain of the given user.

    :param username: user name string
    :type username: str
    :return: the first entry of get_domains(username), or None when the
        user has no domains
    :rtype: str|None
    :raises DomainException: propagated from get_domains when the domains
        cannot be obtained
    """
    user_domains = get_domains(username)
    return user_domains[0] if user_domains else None


def init_database(db_engine, db_name):
    """
    Create a database named db_name using a connection from db_engine.

    :param db_engine: object whose connect() returns a connection offering
        execute() and close()
    :param db_name: name of the database to create; it is interpolated into
        the statement verbatim, so it must come from a trusted source
    :return: None
    """
    conn = db_engine.connect()
    try:
        conn.execute('COMMIT')  # we can't run CREATE DATABASE in one transaction
        conn.execute(f"CREATE DATABASE {db_name}")
    finally:
        # release the connection even when one of the statements fails
        conn.close()


def init_database_rnd(db_engine, prefix='tmp_unittest_'):
    """
    Create a database whose name is prefix plus 8 random characters.

    Used by unit tests.

    :param db_engine: engine passed through to init_database
    :param str prefix: leading part of the database name
    :return str: name of the created database
    """
    random_suffix = str(uuid.uuid4())[:8]
    db_name = prefix + random_suffix
    init_database(db_engine, db_name)
    return db_name


def _pg_version_numbers(raw_version):
    """
    Extract the leading numeric components of a PostgreSQL server_version.

    '9.6.22' -> [9, 6, 22]; '12.9 (Ubuntu 12.9-0ubuntu0.20.04.1)' -> [12, 9];
    '17beta1' -> [17]. Parsing stops at the first component that is not
    purely made of ASCII digits.
    """
    numbers = []
    for component in raw_version.strip().split(' ')[0].split('.'):
        digits = ''
        for char in component:
            if char not in '0123456789':
                break
            digits += char
        if not digits:
            break
        numbers.append(int(digits))
        if digits != component:
            break  # e.g. '6rc1': nothing after a suffix is trustworthy
    return numbers


def drop_database(uri, force=True):
    """
    Drop the database named in uri (supports SQLite, MySQL and PostgreSQL).

    :param uri: database URI; the path part (without the leading '/') is
        the database name, or the file path for sqlite
    :param force: kill all connections and sessions using the database
        before dropping it (PostgreSQL and MySQL only)
    :return: True when the database was dropped, None when uri carries no
        database name
    """
    uri_ = urlparse(uri)
    db_name = uri_.path[1:]  # cut '/' on strings start
    if not db_name:
        return None

    if uri_.scheme == 'sqlite':
        os.unlink(db_name)
        return True

    db_engine = sqlalchemy_engine.create_engine(uri_.scheme + '://' + uri_.netloc)
    conn = db_engine.connect()
    try:
        # NOTE(review): db_name is interpolated into SQL verbatim; callers
        # must only pass trusted URIs.
        # force kill all session for dropped database
        if force and 'postgresql' in db_engine.name:
            # http://dba.stackexchange.com/questions/11893/force-drop-db-while-others-may-be-connected
            conn.execute(f"UPDATE pg_database SET datallowconn = 'false' where datname = '{db_name}'")
            # raw value: CL - '9.6.22', Ubuntu - '12.9 (Ubuntu 12.9-0ubuntu0.20.04.1)'
            raw_version = conn.execute("SHOW server_version").fetchall()[0][0]
            version = _pg_version_numbers(raw_version)
            # the column is called procpid before 9.2 and pid from 9.2 on;
            # an unparsable version is treated as a modern server
            pid_column = 'procpid' if version and version < [9, 2] else 'pid'
            conn.execute(
                f"SELECT pg_terminate_backend({pid_column}) FROM pg_stat_activity WHERE datname = '{db_name}'"
            )
        elif force and 'mysql' in db_engine.name:
            cursor = conn.execute("SHOW PROCESSLIST")
            for fetched in cursor.fetchall():
                # row layout starts with Id, User, Host, db
                if fetched[3] == db_name:
                    conn.execute(f"KILL {fetched[0]}")

        conn.execute("COMMIT")
        conn.execute(f"DROP DATABASE {db_name}")
    finally:
        # do not leak the connection when any statement above fails
        conn.close()
    return True


def _release_line():
    path = '/etc/redhat-release'
    if not os.path.exists(path):
        return ''
    with open(path, encoding='utf-8') as release_file:
        return release_file.readline()


def cl_ver(release_line=None):
    """
    Return CloudLinux version as list [major, minor]
    or [] if can't obtain version

    The version is the whitespace-separated token that follows the word
    'release'; its dot-separated parts that are not integers are skipped.

    :param str release_line: line to parse; read via _release_line() when None
    :return list: split version
    >>> cl_ver('CloudLinux Server release 5.11 (Vladislav Volkov)')
    [5, 11]
    >>> cl_ver('CloudLinux Server release 6.8 (Oleg Makarov)')
    [6, 8]
    >>> cl_ver('CloudLinux release 7.2 (Valeri Kubasov)')
    [7, 2]
    >>> cl_ver('Incorrect release line')
    []
    >>> cl_ver('CloudLinux release')
    []
    """
    if release_line is None:
        release_line = _release_line()
    tokens = release_line.split()
    try:
        ver_str = tokens[tokens.index('release') + 1]
    except (ValueError, IndexError):
        # no 'release' word, or nothing follows it
        return []
    ver = []
    for part in ver_str.split('.'):
        try:
            ver.append(int(part))
        except ValueError:
            # non-numeric components are deliberately ignored
            continue
    return ver


def _reseller_users(reseller=None):
    """Return the cplogin of every cpapi.cpinfo entry whose reseller equals the given one."""
    users = []
    for cplogin, owner in cpapi.cpinfo(keyls=('cplogin', 'reseller')):
        if reseller == owner:
            users.append(cplogin)
    return users


def get_users_for_reseller(reseller=None):
    """
    Return the users of the given reseller.

    Uses cpapi.reseller_users when cpapi provides it and the local
    _reseller_users fallback otherwise; a CPAPIException yields [].
    """
    fetch_users = getattr(cpapi, 'reseller_users', _reseller_users)
    try:
        users = fetch_users(reseller)
    except cpapi.cpapiexceptions.CPAPIException:
        users = []
    return users


def touch(fname, times=None):
    """
    Equivalent of the console ``touch`` command; also sets mode 0o644.

    :param str fname: path to file
    :param None|tuple times: (atime, mtime); None means the current time
    """
    # append mode creates the file if it is missing and keeps existing content
    with open(fname, mode='a', encoding='utf-8'):
        os.utime(fname, times)
    os.chmod(fname, 0o644)


def atomic_write_csv(filename, data):
    """
    Write rows to filename as CSV through _atomic_write.

    :type data: list
    :type filename: str
    """
    _atomic_write(filename, lambda csv_file: csv.writer(csv_file).writerows(data))


def atomic_write_str(filename, data):
    """
    Write a string to filename through _atomic_write.

    :type data: str
    :type filename: str
    """
    _atomic_write(filename, lambda str_file: str_file.write(data))


def _atomic_write(filename, write_function):
    """
    :type write_function: (file) -> None
    :type filename: str
    """
    retry_count, retry_interval = 3, 0.3
    exc = None

    with tempfile.NamedTemporaryFile('w',
                                     delete=False,
                                     dir=os.path.dirname(filename)) as f:
        tmp_filename = f.name
        write_function(f)
        f.flush()
        for _ in range(retry_count):
            try:
                os.fsync(f.fileno())
                break
            except OSError as e:
                exc = e
                time.sleep(retry_interval)
        else:
            raise exc

    os.chmod(tmp_filename, 0o644)  # NamedTemporaryFile creates file with 0o600
    os.rename(tmp_filename, filename)


class reboot_lock(object):
    """
    Context manager that takes an exclusive flock() on a lock file.

    The lock file is opened (and created if needed) in the constructor.
    Entering waits for the lock, polling every 0.1 s; leaving releases the
    lock, closes the descriptor and removes the lock file.

    :param fname: path of the lock file; defaults to /var/lve/lvestats.lock
    :param timeout: seconds to wait for the lock, None to wait forever.
        NOTE: when the timeout expires the ``with`` body still runs, just
        without holding the lock.
    """

    def __init__(self, fname=None, timeout=None):
        self.timeout = timeout
        if fname is None:
            self.fname = reboot_lock._build_lock_filename()
        else:
            self.fname = fname
        self.lock = os.open(self.fname, os.O_WRONLY | os.O_CREAT, 0o600)

    @staticmethod
    def _build_lock_filename():
        """Return the default lock file path, creating /var/lve if it is absent."""
        lock_file_dir = '/var/lve'
        # Create /var/lve directory if it is absent
        if not os.path.exists(lock_file_dir):
            mkdir_p(lock_file_dir)
        return os.path.join(lock_file_dir, 'lvestats.lock')

    def __enter__(self):
        start_time = time.time()
        while True:
            try:
                fcntl.flock(self.lock, fcntl.LOCK_EX | fcntl.LOCK_NB)
                return
            except IOError as e:
                # raise on unrelated IOErrors
                if e.errno != errno.EAGAIN:
                    raise
                if self.timeout is not None and time.time() - start_time > self.timeout:
                    # waited long enough: proceed without the lock
                    return
                time.sleep(0.1)

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            fcntl.flock(self.lock, fcntl.LOCK_UN)
        finally:
            # always give the descriptor back, even if unlocking failed
            os.close(self.lock)
            try:
                os.unlink(self.fname)
            except OSError:
                # best effort: the file may already have been removed
                pass


def deserialize_lve_id(serialized_value):
    """
    Split a serialized value into container id and reseller flag.

    Values above the current maximum LVE ID, other than LIMIT_LVP_ID itself,
    are treated as reseller containers and have that maximum subtracted.

    :type serialized_value: int
    :rtype: tuple[int, bool]
    :return: (container_id, is_reseller)
    """
    max_lve_id = get_current_max_lve_id()
    is_reseller = serialized_value > max_lve_id and serialized_value != LIMIT_LVP_ID
    if is_reseller:
        return serialized_value - max_lve_id, True
    return serialized_value, False


def serialize_lve_id(lve_id, lvp_id):
    """
    Build the serialized int for a (lve_id, lvp_id) pair.

    When lve_id equals LIMIT_LVP_ID and lvp_id is non-zero the result is
    the current maximum LVE ID plus lvp_id; otherwise it is lve_id itself.

    :type lvp_id: int
    :type lve_id: int
    :rtype: int
    """
    is_reseller_container = lve_id == LIMIT_LVP_ID and lvp_id != 0
    if not is_reseller_container:
        return lve_id
    return get_current_max_lve_id() + lvp_id


def get_hostname():
    """
    Return the host name reported by gethostname().

    :return: the host name, or 'N/A' if the lookup raises a socket error
    """
    try:
        return gethostname()
    except SocketError:
        return 'N/A'


def skip_user_by_maxuid(user_id):
    """
    Tell whether the user must be skipped because of the configured max uid.

    :param user_id: user id
    :return: True when user_id is greater than uid_max()
    """
    max_uid = uid_max()
    return user_id > max_uid


def get_all_user_domains(username, include_aliases=True):
    """
    Return every domain of the user without duplicates.

    The result combines get_domains() with, when include_aliases is true,
    the get_aliases() of each of those domains. Lookup failures are not
    raised; they simply contribute nothing. Order is not defined.
    """
    domains = get_domains(username, raise_exception=False)
    aliases = []
    if include_aliases:
        for domain in domains:
            domain_aliases = get_aliases(username, domain, raise_exception=False)
            if not domain_aliases:
                continue
            aliases.extend(domain_aliases)
    return list(set(domains + aliases))


def normalize_domain(domain):
    """Return domain without one leading 'www.' prefix, if it has one."""
    return domain.removeprefix('www.')


def user_should_be_notified(user: str) -> bool:
    """
    Tell whether the user still wants notifications.

    Returns False when the marker file USER_NOTIFICATIONS_OFF_MARKER exists
    in the .lvestats directory of the user's home, True otherwise.
    """
    home_dir = pwd.getpwnam(user).pw_dir
    marker_path = os.path.join(home_dir, '.lvestats', USER_NOTIFICATIONS_OFF_MARKER)
    return not os.path.exists(marker_path)


def get_ascii_table(rows: list[list[Any]],
                    fields: list[str] | None = None,
                    left_padding: int = 0,
                    padding_width: int = 0) -> str:
    """
    Render rows as an ASCII table for console output.

    :param rows: table rows, one list of cell values per row
    :param fields: column names; every named column is left-aligned
    :param left_padding: number of spaces on the left of each cell
    :param padding_width: number of spaces on both sides of each cell
    :return: the table as a string
    """
    fields = fields or []
    # PrettyTable's option is called left_padding_width; an unknown keyword
    # such as left_padding would be silently ignored
    table = PrettyTable(fields,
                        left_padding_width=left_padding,
                        padding_width=padding_width)
    for field in fields:
        table.align[field] = "l"  # align left
    table.add_rows(rows)
    return table.get_string()

?>