#!/usr/bin/python3

import sys
import argparse
import logging
from logging import handlers
import os.path
import pickle

# Find the right directory when running from the source tree
sys.path.insert(0, "/usr/local/samba/lib/python2.7/site-packages")

from regex import compile as recomp
from regex import search

from yaml import load, dump, YAMLObject, safe_load, safe_dump
try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

from samba.samdb import SamDB
from samba import ldb
from samba import dsdb
from samba.param import LoadParm
from samba.auth import system_session
from samba.credentials import Credentials, AUTO_USE_KERBEROS, MUST_USE_KERBEROS

import getpass
import unicodedata


# Build an ordered list of substitution rules
def pattern_from_list(patterns):
    """Compile several regex fragments into a single alternation.

    Every fragment is wrapped in its own capturing group and the groups are
    joined with ``|``, so the compiled pattern matches wherever any one of
    the fragments matches.

    :param patterns: regular-expression fragments
    :type patterns: iterable of str
    :return: the compiled pattern
    """
    alternatives = ["({})".format(fragment) for fragment in patterns]
    return recomp("|".join(alternatives))


# Ordered (replacement, compiled pattern) pairs.  filter_content applies them
# one after another, so the order of the entries matters.
SUBSTITUTION_RULES = [
    # Strip leading and trailing whitespace.
    ('', pattern_from_list([r"^\s+", r"\s+$"])),
    # Turn ".-." into a single hyphen.
    ('-', pattern_from_list([r"\.-\."])),
    # Two rules are disabled and kept here for reference only:
    #   ('-', pattern_from_list([r"\s+"])),
    #   ('-', pattern_from_list([r"---"])),
    # Turn " - " into a single hyphen.
    ('-', pattern_from_list([r" - "])),
    # Turn every remaining run of whitespace into a hyphen.
    ('-', pattern_from_list([r"\s+"])),
    # Drop every character that is not an ASCII letter, a digit, a dot,
    # an apostrophe or a hyphen.
    ('', pattern_from_list([r"[^A-Za-z0-9\.'-]"])),
    # Drop pairs of consecutive dots.
    ('', pattern_from_list([r"\.\."])),
    # Drop whitespace.
    ('', pattern_from_list([r"\s+"])),
    # Drop a leading or trailing dot.
    ('', pattern_from_list([r"^\.", r"\.$"])),
    # Turn apostrophes into hyphens.
    ('-', pattern_from_list([r"'"])),
]


# Matches the "ou=" / "dc=" marker of a component, together with the
# unescaped comma that may precede it.
DIV_SEP = recomp(r"(?:(?<!\\),)?(?:(?:ou=)|(?:dc=))")
# Matches a run of "ou=..." / "dc=..." components; every component value is
# captured under the group name "div".
DIV_ELEMENT = recomp(r'(?:((ou=)|(dc=))(?P<div>[^,]+)(?:,)?)+')


def filter_content(content, substitution_rules=SUBSTITUTION_RULES):
    """Apply every substitution rule to *content* and return the result.

    :param content: value that may contain forbidden patterns
    :type content: str
    :param substitution_rules: ordered ``(replacement, compiled pattern)``
        pairs; the matches of each pattern are replaced before the next
        rule is applied
    :return: the string obtained after the last substitution
    """
    cleaned = content
    for replacement, regex in substitution_rules:
        cleaned = regex.sub(replacement, cleaned)
    return cleaned


def normalize(div):
    """Return *div* reduced to plain ASCII.

    The string is decomposed with Unicode NFKD normalization, then every
    character that cannot be encoded as ASCII (combining accents, symbols
    without an ASCII decomposition, ...) is silently dropped.

    :param div: text to normalize
    :type div: str
    :return: ASCII-only version of the text
    """
    decomposed = unicodedata.normalize('NFKD', div)
    ascii_bytes = decomposed.encode('ASCII', 'ignore')
    return ascii_bytes.decode('ASCII')


def extract_division_elements(raw_division):
    """Return the values of the ``ou=`` / ``dc=`` components of a division.

    :param raw_division: division string as read from the directory
    :type raw_division: str
    :return: every value captured by the ``div`` group of ``DIV_ELEMENT``,
        or an empty list when the string contains no ``ou=`` / ``dc=``
        component at all
    """
    found = DIV_ELEMENT.search(raw_division)
    if found is None:
        # Without this guard a single attribute value that holds no
        # "ou="/"dc=" component raised AttributeError on None.
        return []
    return found.captures('div')


class Configuration:
    """Configuration of the script, read from a YAML file.

    The file must hold an ``ExtractRules`` list (one mapping of
    ``ExtractRule`` keyword arguments per rule) and may hold a ``General``
    mapping of general options.
    """
    def __init__(self, cfile):
        """
        :param cfile: path of the YAML configuration file
        :type cfile: str
        """
        self.logger = logging.getLogger('manageGroups.configuration')
        self.cfile = cfile
        self.rules = []
        self.general = {}

    def load(self):
        """Load the rules and the general options from the file.

        :return: ``True`` once the file has been loaded
        :raises KeyError: if the file has no ``ExtractRules`` entry
        :raises TypeError: if the file is empty
        """
        self.logger.info('Loading configuration', extra={'level': 'h1'})
        with open(self.cfile, 'r') as infile:
            # yaml.load() without an explicit Loader fails on PyYAML 6 and can
            # build arbitrary Python objects on older releases; the
            # configuration only needs plain mappings, lists and scalars.
            data = safe_load(infile)
        for rule in data['ExtractRules']:
            self.logger.debug('Loading rule {}'.format(rule), extra={'level': 'l1'})
            self.rules.append(ExtractRule(**rule))
        if 'General' in data:
            # An empty "General:" entry is parsed as None; keep a dict.
            self.general = data['General'] or {}
        return True

    def dump(self):
        """Write the configuration back to the file (development helper).

        The output uses the layout that :meth:`load` reads: an
        ``ExtractRules`` list and, when general options exist, a ``General``
        mapping.

        :return: ``True`` once the file has been written
        """
        data = {
            'ExtractRules': [
                # The logger is runtime state, not configuration, and
                # ExtractRule() does not accept it as a keyword argument.
                {key: value for key, value in rl.__dict__.items() if key != 'logger'}
                for rl in self.rules
            ],
        }
        if self.general:
            data['General'] = self.general
        with open(self.cfile, 'w') as outfile:
            safe_dump(data, outfile, default_flow_style=False)
        return True


class ExtractRule(object):
    """Group extraction rule.

    :param name: The rule name
    :type name: `String`
    :param separator: Marker used to decide whether the rule applies
    :type separator: `String`
    :param offset: Number of fields to ignore before the separator
    :type offset: `Int`
    :param ignore: List of strings to ignore (ac, gouv, fr)
    :type ignore: `List`
    :param match: Regexp to match group names
    :type match: `String`
    :param code: Expression used to create group names
    :type code: `String`
    :param suffix: List of group suffixes
    :type suffix: `List`
    :param affectRule: Rules used to put users in the created groups
    :type affectRule: `List`
    :param destOU: Default destination OU
    :type destOU: `String`
    """
    def __init__(self, name="", separator="", offset=0, ignore=None,
                match=None, code=None, suffix=None, affectRule=None,
                destOU=None):
        self.logger = logging.getLogger('manageGroups.extractrules')
        self.name = name
        self.separator = separator
        self.offset = offset
        # A fresh list per instance whenever the caller gives none.
        self.ignore = [] if ignore is None else ignore
        self.match = match
        self.code = code
        self.suffix = [] if suffix is None else suffix
        self.affectRule = [] if affectRule is None else affectRule
        self.destOU = destOU
        self.logger.debug('Initializing ExtractRule instance', extra={'level': 'h1'})
        for attribute, current in vars(self).items():
            self.logger.debug('with {} = {}'.format(attribute, current), extra={'level': 'l1'})

    def show(self):
        """Print the separator, the offset and the ignore list of the rule."""
        report = (
            ("  Rule Separator : {0}", self.separator),
            ("  Rule Offset    : {0}", self.offset),
            ("  Rule To remove : {0}", self.ignore),
        )
        for template, shown in report:
            print(template.format(shown))


class DirectoryUser(object):
    """User read from the directory, with the groups computed for it.

    :param name: the samAccountName attribute in the directory
    :type name: `String`
    :param division: the division elements already extracted from *rdiv*
    :type division: `List`
    :param mail: the mail attribute in the directory
    :type mail: `String`
    :param rdiv: the raw division attribute in the directory
    :type rdiv: `String`
    :param rule: the group generation rule selected for this user
    :type rule: `ExtractRule`

    After construction ``divisions`` holds the divisions extracted with the
    rule and ``groups`` the ``(group name, destination OU)`` pairs generated
    for the user.
    """
    def __init__(self, name, division, mail, rdiv, rule):
        self.logger = logging.getLogger('manageGroups.directoryuser')
        self.logger.info('Initializing DirectoryUser instance for user {}'.format(name), extra={'level': 'h2'})
        self.samaccountname = name
        self.rawDivision = rdiv
        self.mail = mail
        self.rule = rule
        self.divisions = self.__extractDivision__(division)
        self.groups = self.__genGroups__()
        self.logger.debug('DirectoryUser instance initialized with:', extra={'level': 'h3'})
        for key, value in self.__dict__.items():
            self.logger.debug('{} = {}'.format(key, value), extra={'level': 'l3'})

    def __extractDivision__(self, div):
        """Extract the divisions used to build the group names.

        The first applicable strategy of the rule is used:

        * ``ignore``: keep the elements of *div* that are not in the ignore
          list, in reverse order;
        * ``match``: search the regexp in the lower-cased raw division and,
          when it matches, evaluate ``rule.code`` to get the divisions;
        * ``separator``: when the separator is one of the elements, keep the
          elements located before ``index(separator) - offset``, drop the
          first one and reverse the rest.

        Every division is passed through ``filter_content``.

        :param div: division elements of the user
        :type div: `List`
        :return: the extracted divisions, ``[]`` when nothing applies
        """
        self.logger.debug('Extracting division from {}'.format(div), extra={'level': 'h3'})
        if not self.rule:
            return []

        if self.rule.ignore:
            return [filter_content(el) for el in reversed(div) if el not in self.rule.ignore]

        if self.rule.match:
            match = search(self.rule.match, self.rawDivision.lower())
            if not match:
                return []
            # NOTE(security): the configured expression is evaluated as Python
            # code; the configuration file must only be writable by trusted
            # administrators.  The expression is evaluated in this scope, so
            # the local names `self`, `div` and `match` must stay available.
            div = eval(self.rule.code)
            return [filter_content(d) for d in div]

        if self.rule.separator and self.rule.separator in div:
            indx = div.index(self.rule.separator) - self.rule.offset
            # Slicing instead of pop(0) keeps an empty selection from raising.
            kept = div[:indx][1:]
            kept.reverse()
            return [filter_content(d) for d in kept]

        return []

    def debug(self,section,message):
        """ Debug function for debug purpose """
        print("[DEBUG][{0}][{1}]".format(section, message))

    def __isMemberOf__(self, rule, group):
        """Tell whether the user belongs to *group* according to an affect rule.

        Only the "contain" condition is supported: the user is a member when
        the rule value is ``'*'`` or is contained in the user attribute named
        by ``rule['attr']``.

        :param rule: affect rule (keys ``cond``, ``value`` and ``attr``)
        :type rule: `Dict`
        :param group: group name, only used in the log messages
        :return: ``True`` for a member, ``False`` otherwise (including for an
            unsupported condition)
        """
        self.logger.debug('Checking if user {} is member of group {}'.format(self.samaccountname, group), extra={'level': 'h3'})
        if rule['cond'] != "contain":
            return False

        if rule['value'] == '*' or rule['value'] in getattr(self, rule['attr']):
            self.logger.debug('User {} is member of group {}'.format(self.samaccountname, group), extra={'level': 'n3'})
            return True

        self.logger.debug('User {} is not member of group {}'.format(self.samaccountname, group), extra={'level': 'n3'})
        return False

    def __cleanMembership__(self, groups):
        """Return the group list with only the leaves kept for the affect
        rules flagged ``leaves_only``.

        For each suffix flagged ``leaves_only``, the groups carrying that
        suffix whose base name is found inside another base name are
        removed; groups of the other suffixes are left untouched.

        :param groups: ``(group name, destination OU)`` pairs
        :type groups: `List`
        :return: the remaining pairs, in no particular order
        """
        self.logger.debug('Cleaning membership for groups {}'.format(groups), extra={'level': 'h3'})

        def found_in_another(base, bases):
            # The base name is used as a regular expression, as before.
            return any(search(base, other) and base != other for other in bases)

        # `== True` is kept on purpose: only a real boolean (or 1) enables the
        # pruning.  A rule without the key is treated as leaves_only = False.
        leaves_only_suffixes = [rule['suffix'] for rule in self.rule.affectRule
                                if rule.get('leaves_only', False) == True]
        real_suffixes = tuple(sf for sf in (self.rule.suffix or []) if sf)
        superfluous_groups = set()
        for leaves_only_suffix in leaves_only_suffixes:
            if leaves_only_suffix:
                tail = leaves_only_suffix
                bases = [name[:-len(tail)] for name, _dest in groups
                         if name.endswith(tail)]
            else:
                # No suffix (None or ''): consider the groups that carry none
                # of the declared suffixes; str.rpartition() cannot split on
                # None or on an empty string.
                tail = ''
                bases = [name for name, _dest in groups
                         if not name.endswith(real_suffixes)]
            for base in bases:
                if found_in_another(base, bases):
                    superfluous_groups.update(group for group in groups
                                              if group[0] == base + tail)
        return list(set(groups).difference(superfluous_groups))

    def __genGroups__(self):
        """ Generate group list for a user
        If the user matches the "membership" rule the group is added to his
        group list.

        All the groups are generated, the root group and the subgroups.
        For example, for the divisions [ 'ac', 'cpii', 'pne', 'sys' ]
        we generate these groups :
            - ac
            - ac.cpii
            - ac.cpii.pne
            - ac.cpii.pne.sys

        If suffixes are provided ([ None, -gouv, -i2 ]) groups are generated
        for each of them, like :
            - ac
            - ac-gouv
            - ac-i2
            - ac.cpii
            - ac.cpii-gouv
            - ac.cpii-i2
            - ac.cpii.pne
            - ac.cpii.pne-gouv
            - ac.cpii.pne-i2
            - ac.cpii.pne.sys
            - ac.cpii.pne.sys-gouv
            - ac.cpii.pne.sys-i2

        :return: ``(group name, destination OU)`` pairs
        """
        if not self.divisions:
            return []

        grps = []
        root = None
        for grp in self.divisions:
            grp_name = root + "." + grp if root else grp
            root = grp_name

            if not self.rule.suffix:
                grps.append((grp_name, self.rule.destOU))
                continue

            for sf in self.rule.suffix:
                group = grp_name if sf is None else "{0}{1}".format(grp_name, sf)
                for rule in self.rule.affectRule:
                    if rule['suffix'] == sf and self.__isMemberOf__(rule, group):
                        grps.append((group, self.rule.destOU))

        return self.__cleanMembership__(grps)


    def show(self):
        """ Print the object """
        print("================================================================================================")
        print("Name: {0}\nDivision: {1}\nMail:{2}".format(
            self.samaccountname,
            self.divisions,
            self.mail))
        print("Groups :")
        for grp in self.groups:
            print("\t{0}".format(grp))
        print("================================================================================================")
        print("Rule :")
        if self.rule:
            self.rule.show()

class DirectoryData:
    """Users and groups collected from the Samba 4 directory.

    Creating an instance opens the Samba connection, reads the users, computes
    the groups to create and reads the groups that already exist.
    """

    def __init__(self, rules, config):
        """Collect the directory data.

        :param rules: the extraction rules
        :type rules: `List` of `ExtractRule` instances
        :param config: general configuration; ``attributSource`` names the
            user attribute holding the division string (default
            ``departmentNumber``) and ``cachefile`` the membership cache path
        :type config: `Dict`

        Attributes set on the instance: ``samDB`` (Samba connection),
        ``domain_dn``, ``rules``, ``config``, ``sourceAttr``, ``rawUserData``
        (list of `DirectoryUser`), ``groups`` (groups to create) and
        ``sambaGroups`` (groups found in Samba).
        """
        self.logger = logging.getLogger('manageGroups.directorydata')
        self.logger.info('Initializing DirectoryData instance', extra={'level': 'h1'})
        self.samDB = self.__openConn__()
        self.domain_dn = self.samDB.domain_dn()
        self.rules = rules
        self.config = config
        self.sourceAttr = config.get('attributSource', 'departmentNumber')
        self.rawUserData = self.__getRawData__()
        self.groups = self.__getAllGroups__()
        self.sambaGroups = self.__getSambaGroups__()
        self.logger.debug('DirectoryData instance initialized with:', extra={'level': 'h2'})
        for key, value in self.__dict__.items():
            self.logger.debug('{} = {}'.format(key, value), extra={'level': 'l2'})

    @staticmethod
    def _to_text(value):
        """Decode *value* as UTF-8 when it is ``bytes``, else return it as is."""
        return value.decode('utf-8') if isinstance(value, bytes) else value

    @staticmethod
    def _renamed_dn(old_dn, old_grp, new_grp):
        """Return *old_dn* with the group name replaced in its first RDN only.

        The parent part of the DN (everything from the first unescaped comma)
        is copied unchanged, so a group name that also appears in an OU or in
        a domain component does not corrupt the new DN.
        """
        split_at = len(old_dn)
        escaped = False
        for pos, char in enumerate(old_dn):
            if escaped:
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == ',':
                split_at = pos
                break
        rdn, parent = old_dn[:split_at], old_dn[split_at:]
        attr, equals, value = rdn.partition('=')
        return attr + equals + value.replace(old_grp, new_grp) + parent

    def __openConn__(self):
        """Open the Samba connection.

        :return: the opened `SamDB`
        :raises Exception: whatever `SamDB` raised, after logging it
        """
        # Loading Parameters
        lp = LoadParm()

        # Loading Credentials
        creds = Credentials()
        creds.guess(lp)
        creds.set_username('admin')

        try:
            samDB = SamDB(session_info=system_session(), credentials=creds, lp=lp)
            self.logger.debug('Samba connection open for user admin')
        except Exception:
            self.logger.error('Samba connection failed', exc_info=True)
            raise
        return samDB

    def __getSambaGroups__(self):
        """Return the sAMAccountName of every group found in the directory."""
        expression = "(objectClass=group)"
        res = self.samDB.search(self.domain_dn, scope=ldb.SCOPE_SUBTREE,
                                expression=expression,
                                attrs=["samaccountname", "grouptype"])
        grps = [self._to_text(grp.get('samaccountname', idx=0)) for grp in res]

        self.logger.debug('Groups found in directory:', extra={'level': 'h2'})
        for grp in sorted(grps):
            self.logger.debug('{}'.format(grp), extra={'level': 'l2'})
        return grps

    def __getRawData__(self):
        """Build the list of `DirectoryUser` instances.

        Every normal user account carrying the source attribute is read.  For
        each non-empty value of the attribute, one `DirectoryUser` is created
        per rule whose separator is found in the comma-joined division
        elements.

        :return: list of `DirectoryUser`
        :raises Exception: if no user has a non-empty source attribute value
        """
        data = []
        expression = ("(&(objectClass=user)(userAccountControl:%s:=%u))" %
                      (ldb.OID_COMPARATOR_AND, dsdb.UF_NORMAL_ACCOUNT))
        res = [el for el in self.samDB.search(self.domain_dn, scope=ldb.SCOPE_SUBTREE,
                                              expression=expression,
                                              attrs=[self.sourceAttr, "samaccountname", "mail"])
               if self.sourceAttr in el]
        foundrdiv = 0
        for elm in res:
            if not elm:
                continue
            name = self._to_text(elm.get('samaccountname', idx=0))
            # A missing mail is stored as an empty string.
            mail = self._to_text(elm.get('mail', idx=0)) or ''
            for raw in elm.get(self.sourceAttr):
                rdiv = self._to_text(raw)
                if not rdiv:
                    continue
                foundrdiv += 1
                divs = [normalize(div.lower()) for div in extract_division_elements(rdiv)]
                joined = ','.join(divs)

                # Creating the user list with the rules selected with the
                # separator
                for rl in self.rules:
                    if rl and rl.separator in joined:
                        data.append(DirectoryUser(name, divs, mail, rdiv, rl))

        if foundrdiv == 0:
            raise Exception("Bad attribute : {0}".format(self.sourceAttr))
        return data


    def __getAllGroups__(self):
        """Return the sorted, duplicate-free list of groups to create.

        :return: ``(group name, destination OU)`` pairs
        """
        unique = set()
        for user in self.rawUserData:
            unique.update(user.groups)
        # A destination OU may be None; sorting the raw tuples would fail when
        # a group name appears with both None and a string.
        return sorted(unique, key=lambda grp: (grp[0], grp[1] or ''))

    def __renameGroup__(self, old_grp, new_grp):
        """Rename a group: move its entry, then update its sAMAccountName.

        :param old_grp: current sAMAccountName of the group
        :param new_grp: new name of the group
        :raises IndexError: if no entry has *old_grp* as sAMAccountName
        """
        found = self.samDB.search(self.domain_dn, scope=ldb.SCOPE_SUBTREE,
                                  expression="(samaccountname={})".format(old_grp),
                                  attrs=[])
        old_dn = found[0].get('dn').get_linearized()
        new_dn = self._renamed_dn(old_dn, old_grp, new_grp)
        self.samDB.rename(old_dn, new_dn)
        ldif = """dn: {}
changetype: modify
replace: samaccountname
samaccountname: {}""".format(new_dn, new_grp)
        self.samDB.modify_ldif(ldif)

    def __cleanGroups__(self, cachefile, pgroups):
        """Remove what the previous run created and this run no longer wants.

        Users listed in the cache for a group but absent from *pgroups* are
        removed from that group.  Groups present in the cache but absent from
        *pgroups* are renamed with the ``SAUVEGARDE.`` prefix.

        :param cachefile: path of the pickle cache written by the last run
        :param pgroups: group name -> member list computed by this run
        :return: ``False`` if the cache cannot be read, ``True`` otherwise
        """
        self.logger.info('Cleaning groups', extra={'level': 'h2'})
        try:
            # NOTE(security): unpickling executes whatever the file encodes;
            # the cache file must only be writable by trusted users.
            with open(cachefile, "rb") as cache:
                oPgroups = pickle.load(cache)
        except Exception:
            self.logger.error("Error loading cache !!!", exc_info=True)
            return False

        groupsToRemove = list(set(oPgroups.keys()) - set(pgroups.keys()))
        toRemoveFromGrp = {}
        for group, members in pgroups.items():
            if group in oPgroups:
                gone = list(set(oPgroups[group]) - set(members))
                if gone:
                    toRemoveFromGrp[group] = gone

        # Clean Groups
        for grp, users in toRemoveFromGrp.items():
            try:
                self.logger.debug("Removing users from group {0}".format(grp), extra={'level': 'n2'})
                for usr in users:
                    self.logger.debug("Removed {0}".format(usr), extra={'level': 'l3'})
                self.samDB.add_remove_group_members(grp, users,
                                                    add_members_operation=False)
            except Exception:
                self.logger.error("Error removing users from group {0}".format(grp), exc_info=True)

        for grp in groupsToRemove:
            self.logger.debug("Removing group {0}".format(grp), extra={'level': 'n2'})
            try:
                self.__renameGroup__(grp, 'SAUVEGARDE.' + grp)
            except Exception:
                self.logger.error("Error removing group {0}".format(grp), exc_info=True)
        return True

    def createGroups(self):
        """Create every missing group in the Samba directory.

        A group that exists under its ``SAUVEGARDE.`` name is renamed back
        instead of being created.

        :return: ``True`` when no group failed, ``False`` otherwise
        """
        failed = []
        self.logger.info('Creating missing groups', extra={'level': 'h1'})
        for grp, dest in self.groups:
            try:
                if grp in self.sambaGroups:
                    self.logger.debug("Group {0} already exists".format(grp), extra={'level': 'n1'})
                elif 'SAUVEGARDE.' + grp in self.sambaGroups:
                    self.logger.debug("Renaming group : {0}".format(grp), extra={'level': 'n1'})
                    self.__renameGroup__('SAUVEGARDE.' + grp, grp)
                else:
                    self.logger.debug("Creating group : {0}".format(grp), extra={'level': 'n1'})
                    self.samDB.newgroup(grp, groupou=dest)
            except Exception:
                self.logger.error("Error Creating group {0}".format(grp), exc_info=True)
                # Record the failure so that the return value reflects it.
                failed.append(grp)

        return not failed

    def affectUsers(self):
        """Put the users in their groups according to the affect rules.

        The computed memberships are compared with the cache of the previous
        run (see ``__cleanGroups__``), written to the cache file, then applied
        to the directory.

        :return: ``True`` when every group was updated, ``False`` otherwise
        """
        self.logger.info('Affecting users to groups', extra={'level': 'h1'})
        failed = []

        pgroups = {grp: [] for grp, _dest in self.groups}
        for user in self.rawUserData:
            for group, _dest in user.groups:
                if user.samaccountname not in pgroups[group]:
                    pgroups[group].append(user.samaccountname)

        cachefile = self.config.get('cachefile', '/var/lib/eole/extractcache.p')

        if os.path.exists(cachefile):
            self.__cleanGroups__(cachefile, pgroups)

        # The context manager guarantees the cache is flushed and closed.
        with open(cachefile, "wb") as cache:
            pickle.dump(pgroups, cache)

        for key, members in pgroups.items():
            self.logger.debug('Group {}'.format(key), extra={'level': 'h2'})
            try:
                for usr in members:
                    self.logger.debug("Adding user {0} to group {1}".format(usr, key), extra={'level': 'l2'})
                self.samDB.add_remove_group_members(key, members, add_members_operation=True)
            except Exception:
                self.logger.error("Error adding users to group {0}".format(key), exc_info=True)
                failed.append(key)

        return not failed


    def show(self):
        """Print every collected user."""
        if self.rawUserData:
            for dt in self.rawUserData:
                dt.show()

# Command-line interface.  The arguments are parsed as soon as this module is
# executed, before the classes below are defined.
parser = argparse.ArgumentParser(
    description='Create Samba groups from a directory attribute and put users in them.')
parser.add_argument('-c', '--config', help=u'Configuration file', metavar='CONFIG_FILE')
# An unknown level name falls back to "info" when the logger is configured.
parser.add_argument('-l', '--loglevel', help=u'Log level', metavar='LOG_LEVEL', default='info')
parser.add_argument('-s', '--silent', help=u'Silent mode (no output)', action='store_true')
args = parser.parse_args()

class ContextFilter(logging.Filter):
    """Add the layout fields used by the console formatter to log records.

    A record may carry a ``level`` attribute (given through ``extra=``) that
    names one of the keys of ``LEVELS``.  The filter turns it into the
    ``sep``, ``tab``, ``symbol``, ``coltag`` and ``endcoltag`` attributes;
    records without a known ``level`` get empty layout fields and a colour
    chosen from their severity.
    """

    # Layout per "level" tag: text emitted before the line, number of
    # tabulations, bullet symbol and ANSI colour sequence.
    LEVELS = {
            'h1': {'sep': '\n', 'coltag': '\x1b[36m'},
            'h2': {'sep': '\n', 'indent': 1, 'coltag': '\x1b[36m'},
            'h3': {'sep': '\n', 'indent': 2, 'coltag': '\x1b[36m'},
            # The backslash is doubled: '\_' is not a valid escape sequence.
            'l1': {'indent': 1, 'symbol': '\\__ '},
            'l2': {'indent': 2, 'symbol': '\\__ '},
            'l3': {'indent': 3, 'symbol': '\\__ '},
            'n1': {'indent': 1},
            'n2': {'indent': 2},
            'n3': {'indent': 3},
            }

    # ANSI colour used when the layout defines none.
    SEVERITY = {
            logging.ERROR: '\x1b[31m',
            logging.INFO: '\x1b[32m',
            logging.WARNING: '\x1b[33m',
            }

    def filter(self, record, **kwargs):
        """Set the layout attributes on *record*.

        :param record: the log record to decorate
        :return: always ``True``, the record is never discarded
        """
        level = self.LEVELS.get(getattr(record, 'level', None), {})
        record.sep = level.get('sep', '')
        record.tab = '\t' * level.get('indent', 0)
        record.symbol = level.get('symbol', '')
        # The colour of the layout wins over the colour of the severity.
        record.coltag = level.get('coltag', '') or self.SEVERITY.get(record.levelno, '')
        record.endcoltag = "\x1b[0m" if record.coltag else ''
        return True

# Names accepted by --loglevel and the logging level they select.
LOG_LEVELS = {
        'info': logging.INFO,
        'debug': logging.DEBUG,
        'error': logging.ERROR,
        'warning': logging.WARNING,
        }


if args.config:
    # Logging is configured before the configuration is loaded; otherwise the
    # messages emitted while loading it reach no handler and are lost.
    logger = logging.getLogger('manageGroups')
    sys_hdlr = handlers.SysLogHandler(address='/dev/log')
    sys_formatter = logging.Formatter('manageGroups: %(message)s')
    sys_hdlr.setFormatter(sys_formatter)
    logger.addHandler(sys_hdlr)
    if not args.silent:
        # Console output: ContextFilter supplies the layout fields that the
        # format string below refers to.
        std_hdlr = logging.StreamHandler()
        std_formatter = logging.Formatter('%(sep)s%(coltag)s%(tab)s%(symbol)s%(message)s%(endcoltag)s')
        f = ContextFilter()
        std_hdlr.addFilter(f)
        std_hdlr.setFormatter(std_formatter)
        logger.addHandler(std_hdlr)
    # Unknown level names fall back to INFO.
    logger.setLevel(LOG_LEVELS.get(args.loglevel, logging.INFO))

    conf = Configuration(args.config)
    conf.load()

    try:
        data = DirectoryData(conf.rules, conf.general)
        data.createGroups()
        data.affectUsers()
    except Exception as e:
        logger.error("General Error {0}".format(e), exc_info=True)

else:
    parser.print_help()

