import ldap
import re
import traceback
import base64

from context import ServerError

def easydb_server_start(easydb_context):
    """Plugin start hook: register this module's LDAP callbacks with easydb.

    Each callback is registered under the same name as the module-level
    function that implements it.
    """
    easydb_context.get_logger('base.ldap')
    for callback_name in ('ldap_get_user', 'ldap_get_groups'):
        easydb_context.register_callback(callback_name, {'callback': callback_name})

def ldap_get_user(easydb_context, easydb_info):
    """Callback: authenticate ``login``/``password`` from ``easydb_info['parameters']``.

    Returns whatever ``ldap_get_user_and_groups`` returns. Any exception,
    including missing parameters, is logged with its traceback and turned
    into ``None``.
    """
    logger = easydb_context.get_logger('base.ldap')
    try:
        parameters = easydb_info['parameters']
        return ldap_get_user_and_groups(
            easydb_context, logger,
            parameters['login'],
            parameters['password'])
    except Exception as exc:
        logger.error("exception in LDAP module: %s\n%s" % (repr(exc), traceback.format_exc()))
        return None

def ldap_get_groups(easydb_context, easydb_info):
    """Callback: resolve the user record and groups for ``login`` without a password bind.

    Calls ``ldap_get_user_and_groups`` with ``skip_bind=True``. Any exception,
    including a missing ``login`` parameter, is logged with its traceback
    and turned into ``None``.
    """
    logger = easydb_context.get_logger('base.ldap')
    try:
        login = easydb_info['parameters']['login']
        return ldap_get_user_and_groups(
            easydb_context, logger, login, None, skip_bind = True)
    except Exception as exc:
        logger.error("exception in LDAP module: %s\n%s" % (repr(exc), traceback.format_exc()))
        return None

class LdapEntry (object):
    """Connection plus search definition for one section of an LDAP config entry.

    ``conf[group]`` (``group`` is ``'user'`` or ``'group'`` in this module)
    must be a dict containing ``server``, ``basedn`` and ``filter``.
    Optional keys are:

    - ``protocol``: URL scheme, default ``'ldap'``.
    - ``port``
    - ``scope``: ``'sub'``, ``'one'`` or ``'base'``, default ``'sub'``.
    - ``user`` / ``password``: credentials bound with in ``init()`` when both
      are set.
    """

    # config 'scope' value -> python-ldap search scope constant
    scope_map = {
        'sub': ldap.SCOPE_SUBTREE,
        'one': ldap.SCOPE_ONELEVEL,
        'base': ldap.SCOPE_BASE,
    }

    def __init__(self, logger, fmt, conf, group):
        """Validate ``conf[group]`` and precompute URL, scope, base DN and filter.

        :param logger: logger used for debug output of searches.
        :param fmt: ``Formatter``-like object used to expand the filter template.
        :param conf: one LDAP config entry (dict).
        :param group: name of the section of ``conf`` to use.
        :raises Exception: if the section is missing or not a dict, lacks a
            required key, or names an unknown scope.
        """
        self.logger = logger
        self.fmt = fmt
        if group not in conf or not isinstance(conf[group], dict):
            raise Exception("no '%s' in LDAP config entry" % group)
        for attr in ('server', 'basedn', 'filter'):
            if attr not in conf[group]:
                raise Exception("'%s' missing in '%s' config entry" % (attr, group))
        self.group = group
        self.conf = conf[group]
        self.url = self.conf.get('protocol', 'ldap') + '://' + self.conf['server']
        if 'port' in self.conf:
            self.url += ':' + str(self.conf['port'])
        scope_name = self.conf.get('scope', 'sub')
        if scope_name not in self.scope_map:
            # report a config error instead of an opaque KeyError
            raise Exception("invalid scope '%s' in '%s' config entry (expected one of: %s)" % (
                scope_name, group, ", ".join(sorted(self.scope_map))))
        self.scope = self.scope_map[scope_name]
        self.basedn = self.conf['basedn']
        self.filterstr = self.conf['filter']

        self.machine_user = self.conf.get('user')
        self.machine_pass = self.conf.get('password')

    def init(self):
        """Create the python-ldap connection object and disable referral chasing.

        If both a machine ``user`` and ``password`` are configured, bind with
        them right away.
        """
        self.ldap = ldap.initialize(self.url)
        self.ldap.set_option(ldap.OPT_REFERRALS, 0)

        if self.machine_user and self.machine_pass:
            self.bind(self.machine_user, self.machine_pass)

    def bind(self, binddn, password):
        """Perform a simple bind as ``binddn`` with ``password``.

        :raises Exception: if ``password`` is empty or ``None``. A simple bind
            with a DN and an empty password is an "unauthenticated bind"
            (RFC 4513, section 5.1.2). Many servers report that as success, so
            allowing it would let anyone authenticate as any DN.
        """
        if not password:
            raise Exception("refusing LDAP simple bind for '%s' with empty password" % binddn)
        self.ldap.simple_bind_s(binddn, password.encode('utf-8'))

    def search(self, parameters):
        """Run the configured search, expanding the filter template from ``parameters``.

        All substituted values are LDAP-escaped. If exactly one variable used
        in the filter has a list value with more than one element, the filter
        is expanded once per element and the pieces are OR-ed as ``(|...)``.

        :returns: tuple of ``(dn, attributes)`` pairs. Entries whose dn is
            ``None`` are dropped (presumably search references/referrals).
        :raises Exception: if more than one filter variable is multi-valued.
        """
        variables = self.fmt.get_variables(self.filterstr)
        multivals = {}
        for var in variables:
            value = parameters.get(var)
            if isinstance(value, list) and len(value) > 1:
                multivals[var] = list(value)

        if not multivals:
            flt = self.fmt(self.filterstr, parameters, self.fmt.escape_ldap)
        elif len(multivals) > 1:
            # cheap solution for now, just one multi-value variable allowed
            raise Exception("multiple multi-value variables for filter '{0}' ({1}), currently unsupported".format(
                self.filterstr, ", ".join(multivals.keys())))
        else:
            ((var, values),) = multivals.items()
            subfilters = []
            for val in values:
                subparams = parameters.copy()
                subparams[var] = val
                subfilters.append(self.fmt(self.filterstr, subparams, self.fmt.escape_ldap))
            flt = '(|' + ''.join(subfilters) + ')'

        self.logger.debug("filter: %s" % flt)
        r = self.ldap.search_s(self.basedn,
            scope = self.scope,
            filterstr = flt)

        r = tuple(entry for entry in r if entry[0] is not None)

        self.logger.debug("found record: %s" % (r,))

        return r

class Formatter (object):
    """%-style template expander used for LDAP filters and user/group fields.

    Templates use ``%(name)s`` placeholders. Replacement values may be:

    - ``str``
    - ``bytes``, decoded as UTF-8 with a base64 fallback.
    - lists of those, joined with ``;``.

    Placeholders without a replacement expand to the empty string.
    """

    def __init__(self):
        # Raw string: '\(' in a plain literal is an invalid escape sequence
        # (DeprecationWarning, SyntaxWarning since Python 3.12).
        self.find_fmt_re = re.compile(r'%\(([a-zA-Z0-9_.]*)\)s')

    @staticmethod
    def _to_unicode(x):
        """Return ``x`` as text.

        ``None`` and ``str`` are returned unchanged. Bytes are decoded as
        UTF-8; if that fails, the base64 encoding is returned as a ``str``.
        """
        if x is None or isinstance(x, str):
            return x
        try:
            return x.decode('utf-8')
        except UnicodeDecodeError:
            # .decode() rather than str(): str(b'...') would give "b'...'"
            return base64.b64encode(x).decode('ascii')

    def __call__(self, fmt, replacements, escape_method = None):
        """Expand template ``fmt`` with values from ``replacements``.

        :param fmt: template containing ``%(name)s`` placeholders.
        :param replacements: mapping of placeholder names to values; it is
            not modified.
        :param escape_method: optional callable applied to every value after
            text conversion (e.g. ``Formatter.escape_ldap``).
        :returns: the expanded string.
        """
        repl = {}
        for k, v in replacements.items():
            if isinstance(v, list):
                v = ';'.join(map(self._to_unicode, v))
            elif isinstance(v, (bytes, bytearray)):
                # Single raw attribute values must be decoded too; otherwise
                # they render as "b'...'" or break escape_ldap.
                v = self._to_unicode(v)
            repl[k] = v
        if escape_method is not None:
            repl = {k: escape_method(v) for k, v in repl.items()}
        for match in self.find_fmt_re.findall(fmt):
            if match not in repl:
                repl[match] = ''
        return fmt % repl

    def get_variables(self, fmt):
        """Return the placeholder names used in ``fmt``, in order of appearance."""
        return self.find_fmt_re.findall(fmt)

    @staticmethod
    def escape_ldap(text):
        """Escape ``text`` for use as a value inside an LDAP search filter.

        The text is encoded as UTF-8. Every byte that is a control character
        (< 32), above 127, or one of ``\\ , # + < > ; " ( ) *``, NUL or space
        is written as ``\\xx`` (two lowercase hex digits). All other bytes
        are kept as-is.
        """
        special = "\\,#+<>;\"()*\x00 "
        parts = []
        for n in text.encode('utf-8'):
            if n < 32 or n > 127 or chr(n) in special:
                parts.append('\\%02x' % n)
            else:
                parts.append(chr(n))
        return ''.join(parts)

    @staticmethod
    def _regex_replace(val, conf):
        """Apply ``re.sub(conf['regex_match'], conf['regex_replace'], val)`` if both keys are set."""
        if 'regex_match' in conf and 'regex_replace' in conf:
            val = re.sub(conf['regex_match'], conf['regex_replace'], Formatter._to_unicode(val))
        return val

    @staticmethod
    def _val_to_str(val):
        """Convert a value, or each element of a list, to text via ``_to_unicode``."""
        if isinstance(val, list):
            return [Formatter._to_unicode(v) for v in val]
        return Formatter._to_unicode(val)

    @classmethod
    def apply_mapping(cls, replacements, mapping_conf, keep_lists = False):
        """Add derived variables to ``replacements`` in place.

        For each ``var: conf`` in ``mapping_conf``, the value of
        ``replacements[conf['attr']]`` is converted to text, optionally
        regex-rewritten, and stored as ``replacements[var]``. List values
        stay lists (rewritten per item) when ``keep_lists`` is true;
        otherwise they are joined with ``;`` first. Missing or empty
        source values leave ``var`` unset.

        :raises Exception: if a mapping entry lacks ``attr``.
        """
        for var, conf in mapping_conf.items():
            if 'attr' not in conf:
                raise Exception("expected 'attr' in mapping config for {0}".format(var))
            val = cls._val_to_str(replacements.get(conf['attr']))
            if not val:
                continue
            if isinstance(val, list):
                if keep_lists:
                    replacements[var] = [cls._regex_replace(item, conf) for item in val]
                else:
                    replacements[var] = cls._regex_replace(";".join(val), conf)
            else:
                replacements[var] = cls._regex_replace(val, conf)

def ldap_get_user_and_groups(easydb_context, logger, login, password, skip_bind = False):
    """Look up ``login`` in the configured LDAP servers and build an easydb user record.

    The entries of the ``system.ldap`` config are tried in order. The first
    entry whose ``user`` search finds exactly one LDAP entry is used, and
    later entries are not consulted.

    :param easydb_context: easydb plugin context (provides ``get_config``).
    :param logger: logger for debug/info output.
    :param login: login name entered by the user. The user search filter can
        use it as ``login`` (lower-cased), ``Login`` (as given) or ``LOGIN``
        (upper-cased).
    :param password: password verified by binding as the found user's DN.
        It is not used when ``skip_bind`` is true.
    :param skip_bind: if true, only look the user up without verifying a
        password (``ldap_get_groups`` passes true).
    :returns: ``None`` in three cases: there is no LDAP config, the password
        is empty while ``skip_bind`` is false, or no config entry finds the
        user. Otherwise a dict with ``user`` (``login``, ``displayname``),
        ``_groups`` (list of group names) and ``_emails``.
    :raises Exception: if a search finds more than one user, the computed
        login is empty, the config is invalid, or an LDAP operation
        (including the user bind) fails.
    """
    ldap_config = easydb_context.get_config('system.ldap')
    if not ldap_config:
        logger.debug('no LDAP configuration, fail silently')
        return None

    if not skip_bind and not password:
        # A simple bind with a DN and an empty password is an
        # "unauthenticated bind" (RFC 4513, section 5.1.2). Many servers
        # report it as success, so it must never count as a valid login.
        logger.info('empty password for LDAP login, refusing to authenticate')
        return None

    fmt = Formatter()

    for conf in ldap_config:
        env = conf.get('environment', {})

        user_entry = LdapEntry(logger, fmt, conf, 'user')
        user_entry.init()
        rec = user_entry.search({
            'login': login.lower(),
            'Login': login,
            'LOGIN': login.upper(),
        })

        if not rec:
            continue  # no user found with this config entry
        if len(rec) > 1:
            raise Exception("more than one user found, search filter not sufficient")

        (user_dn, user_rec) = rec[0]

        if not skip_bind:
            user_entry.bind(user_dn, password)

        user_record = {'user.' + k: v for k, v in user_rec.items()}
        user_record['user.dn'] = user_dn

        mapping_conf = env.get('mapping')
        if isinstance(mapping_conf, dict):
            fmt.apply_mapping(user_record, mapping_conf)

        user_env = env.get('user', {})
        login_format = user_env.get('login', '%(user.dn)s')
        displayname_format = Formatter._to_unicode(user_env.get('displayname', '%(user.dn)s'))

        ulogin = fmt(login_format, user_record)
        if not ulogin:
            raise Exception("empty login (format was %s)" % login_format)

        displayname = fmt(displayname_format, user_record)
        emails = _ldap_build_emails(fmt, user_env.get('email'), user_record)

        groups = set()
        if 'group' in conf:
            groups = _ldap_collect_groups(logger, fmt, conf, env, user_record, mapping_conf)

        result = {
            'user': {
                'login': ulogin,
                'displayname': displayname,
            },
            '_groups': list(groups),
            '_emails': emails,
        }
        logger.info("user record: %s" % result)
        return result

    return None

def _ldap_build_emails(fmt, email_format, user_record):
    """Return the ``_emails`` list: one primary address, or empty if unconfigured or blank."""
    if not email_format:
        return []
    email = fmt(email_format, user_record)
    if not email:
        return []
    return [{
        'email': email,
        'needs_confirmation': False,
        'use_for_email': True,
        'send_email': True,
        'is_primary': True,
    }]

def _ldap_collect_groups(logger, fmt, conf, env, user_record, mapping_conf):
    """Run the ``group`` search for the user and return the set of group names.

    For each group entry found, the user and group attributes are merged
    (group keys prefixed ``group.``) and the mapping is applied with lists
    kept. Each ``environment.groups`` rule then contributes the value of
    its ``attr``, optionally split on its ``divider``.
    """
    groups = set()
    group_entry = LdapEntry(logger, fmt, conf, 'group')
    group_entry.init()

    for (_group_dn, group_rec) in group_entry.search(user_record):
        group_record = {'group.' + k: v for k, v in group_rec.items()}
        user_and_group = user_record.copy()
        user_and_group.update(group_record)
        if isinstance(mapping_conf, dict):
            fmt.apply_mapping(user_and_group, mapping_conf, True)

        for group_env in env.get('groups', {}):
            if 'attr' not in group_env:
                raise Exception('"attr" missing in "environment.groups"')
            val = user_and_group.get(group_env['attr'])
            if val is None:
                continue
            divider = group_env.get('divider')
            for item in (val if isinstance(val, list) else [val]):
                _ldap_add_groups(groups, Formatter._to_unicode(item), divider)

    return groups

def _ldap_add_groups(groups, val, divider):
    val = val.strip()
    if not len(val):
        return
    if divider:
        for g in val.split(divider):
            groups.add(g)
    else:
        groups.add(val)

