# coding: utf-8
from IPy import IP
from os import unlink
from sqlalchemy.types import UnicodeText, Boolean, Integer
from sqlalchemy.orm import relationship
from sqlalchemy import Column, ForeignKeyConstraint
from .elixir import Entity
from arv.lib.util import valid, try_unique_column, trace, \
        get_keyid_from_certifstring, bin_encoding, \
        decrypt_privkey, get_keyid_from_certifstring, suppress_colon, \
        valid_priv_and_cred, cred_end_validity_date, der_to_pem, \
        split_pkcs7
from creole import cert
from arv.db.edge import get_connected_nodes, TmplConnect, Credential, \
                        get_credential_auth, add_credential_auth, \
                        add_certificate_authorities
from arv.lib.logger import logger
from arv.config import ip_sep

def test_mimetype(mimetype):
    if mimetype not in ('ip', 'network'):
        raise ValueError('mimetype must be ip or network')

def valid_tmpl_vertex_zephir(zephir_module, zephir_var_ip1, zephir_var_ip2, mimetype):
    """Check the consistency of the Zephir settings of a TmplVertex.

    Either all three Zephir values are unset, or ``zephir_module`` and
    ``zephir_var_ip1`` are both set.  In the latter case ``zephir_var_ip2``
    must be unset when ``mimetype`` is ``'ip'`` and set otherwise.

    :returns: tuple ``(zephir_module, zephir_var_ip1, zephir_var_ip2)``; the
              values that are set are returned as converted by ``valid``
    :raises ValueError: when the combination of values is inconsistent
    """
    if zephir_module is None:
        # No module: no variable may be provided either.
        if zephir_var_ip1 is not None or zephir_var_ip2 is not None:
            raise ValueError('zephir_module, zephir_var_ip1 and zephir_var_ip2 must be set together')
        return zephir_module, zephir_var_ip1, zephir_var_ip2

    zephir_module = valid(zephir_module, 'integer')
    if zephir_var_ip1 is None:
        raise ValueError('zephir_module and zephir_var_ip1 must be set together')
    zephir_var_ip1 = valid(zephir_var_ip1, 'string')

    if mimetype == 'ip':
        if zephir_var_ip2 is not None:
            raise ValueError('if mimetype is "ip", zephir_var_ip2 must not be set')
    elif zephir_var_ip2 is None:
        raise ValueError('if mimetype is "%s", zephir_var_ip2 must be set' % mimetype)
    else:
        zephir_var_ip2 = valid(zephir_var_ip2, 'string')
    return zephir_module, zephir_var_ip1, zephir_var_ip2

class TmplNode(Entity):
    """Template describing a kind of node (sphynx, etablissement, ...).

    Concrete ``Node`` rows reference a TmplNode.  ``TmplVertex`` rows attached
    to it describe the ip/network endpoints of nodes built from it.
    """
    __tablename__ = 'arv_db_node_tmplnode'
    id = Column("id", Integer, primary_key=True)
    name = Column("name", UnicodeText, unique=True)
    mimetype = Column("mimetype", UnicodeText)
    allow_master = Column("allow_master", Boolean, default=True)
    tmpl_connects_a = relationship('TmplConnect', backref="tmpl_node_a", primaryjoin="TmplConnect.tmpl_node_a_id == TmplNode.id", uselist=True)
    tmpl_connects_b = relationship('TmplConnect', backref="tmpl_node_b", primaryjoin="TmplConnect.tmpl_node_b_id == TmplNode.id", uselist=True)
    state = Column('state', Integer, default=0)

    def add_tmpl_vertex(self, name, mimetype, zephir_module=None, zephir_var_ip1=None, zephir_var_ip2=None):
        """Create a TmplVertex belonging to this template.

        :param name: vertex name; must be unique among this template's vertices
        :param mimetype: ``'ip'`` or ``'network'``
        :param zephir_module: optional Zephir module id
        :param zephir_var_ip1: Zephir variable for ip1 (set with the module)
        :param zephir_var_ip2: Zephir variable for ip2 (only for non-ip mimetype)
        :returns: the new TmplVertex
        :raises Exception: if this template already has a vertex named ``name``
        :raises ValueError: on an invalid mimetype or Zephir combination
        """
        vertex_name = valid(name, 'string')
        vertex_type = valid(mimetype, 'string')
        test_mimetype(vertex_type)
        duplicate = TmplVertex.query.filter_by(name=vertex_name, tmpl_node=self).first()
        if duplicate is not None:
            raise Exception("Name should be unique")
        module, var_ip1, var_ip2 = valid_tmpl_vertex_zephir(
            zephir_module, zephir_var_ip1, zephir_var_ip2, vertex_type)
        return TmplVertex(name=vertex_name, tmpl_node=self, mimetype=vertex_type,
                          zephir_module=module, zephir_var_ip1=var_ip1,
                          zephir_var_ip2=var_ip2)

    def _add_node(self, name, uai, id_zephir, eole_version):
        """Build a Node based on this template (callers use ``add_node``).

        :raises Exception: if ``uai`` is given and a Node named ``name`` exists
        """
        node_name = valid(name, 'string')
        if uai is not None:
            uai = valid(uai, 'string')
            # NOTE(review): the explicit duplicate-name check only runs when a
            # UAI is given; nodes without UAI are only protected by the unique
            # ``name`` column.  Confirm this is intended.
            if Node.query.filter_by(name=node_name).first() is not None:
                raise Exception('serveur name already in database')
        return Node(tmpl_node=self, name=node_name, uai=uai,
                    id_zephir=id_zephir, eole_version=eole_version)

    def add_node(self, name, uai=None, id_zephir=None, eole_version=None):
        """Create a Node from this template through ``try_unique_column``."""
        node_kwargs = dict(name=name, uai=uai, id_zephir=id_zephir,
                           eole_version=eole_version)
        return try_unique_column('add_node', self._add_node, **node_kwargs)

@trace()
def get_tmpl_nodes(mimetype=None):
    """Return TmplNode objects, optionally filtered by mimetype.

    :param mimetype: if given, only templates with this mimetype are returned
    :returns: list of TmplNode
    :raises Exception: if validation or the query fails
    """
    try:
        if mimetype is None:
            return TmplNode.query.all()
        mimetype = valid(mimetype, 'string')
        return TmplNode.query.filter_by(mimetype=mimetype).all()
    except Exception:
        # Narrowed from a bare ``except:`` so that SystemExit and
        # KeyboardInterrupt are no longer converted into a generic error.
        raise Exception('error in get_tmpl_nodes')

@trace()
def add_tmpl_node(name, mimetype='custom', allow_master=True):
    """Create a new TmplNode.

    :param name: template name
    :param mimetype: one of ``'sphynx'``, ``'etablissement'``, ``'roadwarrior'``
                     or ``'custom'``; every value except ``'custom'`` may be
                     used by a single template only
    :param allow_master: flag stored on the template (validated as bool)
    :returns: the new TmplNode
    :raises Exception: wrapping any validation or database error
    """
    known_types = ('sphynx', 'etablissement', 'roadwarrior', 'custom')
    try:
        tmpl_name = valid(name, 'string')
        tmpl_type = valid(mimetype, 'string')
        master_allowed = valid(allow_master, 'bool')
        if tmpl_type not in known_types:
            raise ValueError('TmplNode mimetype must be sphynx, etablissement, roadwarrior or custom')
        singleton = tmpl_type != 'custom'
        if singleton and TmplNode.query.filter_by(mimetype=tmpl_type).first() is not None:
            raise ValueError('TmplNode mimetype %s is already used in database' % tmpl_type)
        return TmplNode(name=tmpl_name, mimetype=tmpl_type, allow_master=master_allowed)
    except Exception as err:
        raise Exception("error in add_tmpl_node: %s" % str(err))

@trace()
def del_tmpl_node(tmpl_node):
    """Delete a TmplNode together with its TmplVertex children.

    :raises Exception: if a Node or a TmplConnect still references the
                       template, or if a deletion fails
    """
    if Node.query.filter_by(tmpl_node=tmpl_node).first():
        raise Exception("TmplNode is already used by a Node, you cannot delete it")
    if tmpl_node.tmpl_connects_a or tmpl_node.tmpl_connects_b:
        raise Exception("TmplNode is already used by a TmplConnect, you cannot delete it")
    try:
        # Iterate over a snapshot: deleting children may alter the
        # ``vertices`` relationship collection while it is being walked.
        for tmplvertex in list(tmpl_node.vertices):
            tmplvertex.delete()
    except Exception:
        raise Exception("error when delete tmplvertex")
    try:
        tmpl_node.delete()
    except Exception:
        raise Exception("error in del_tmpl_node")

@trace()
def create_tmpl_node(name):
    """Build a ``Node`` object with the given name.

    NOTE(review): despite its name this returns a ``Node``, not a
    ``TmplNode`` (compare ``create_tmplvertex``).  It is kept as-is for
    existing callers; confirm the intent.

    :raises TypeError: if ``name`` is rejected by ``valid(..., 'string')``
    """
    try:
        name = valid(name, 'string')
    except Exception:
        # Narrowed from a bare ``except:`` (which also caught SystemExit).
        raise TypeError("unsupported encoding in create_node")
    return Node(name=name)

@trace()
def get_tmpl_node_by_name(tname):
    """Return the TmplNode named ``tname``.

    :raises ValueError: if validation or the query fails
    :raises Exception: if no TmplNode has this name
    """
    try:
        tname = valid(tname, 'string')
        tnode = TmplNode.query.filter_by(name=tname).first()
    except Exception as e:
        raise ValueError(str(e))

    # ``Query.first()`` returns None when nothing matches.  The former
    # ``tnode == ''`` test could never be true, so a missing template was
    # silently returned as None.
    if tnode is None:
        raise Exception( 'no tmpl_node with this name')
    return tnode

class TmplVertex(Entity):
    """Template of an ip/network endpoint belonging to a TmplNode.

    ``mimetype`` is ``'ip'`` or ``'network'``.  The optional Zephir fields
    follow the rules of ``valid_tmpl_vertex_zephir``.
    """
    __tablename__ = 'arv_db_node_tmplvertex'
    id = Column("id", Integer, primary_key=True)
    name = Column("name", UnicodeText)
    mimetype = Column("mimetype", UnicodeText)
    zephir_module = Column("zephir_module", Integer)
    zephir_var_ip1 = Column("zephir_var_ip1", UnicodeText)
    zephir_var_ip2 = Column("zephir_var_ip2", UnicodeText)
    tmpl_node_id = Column(Integer, index=True)
    tmpl_node = relationship('TmplNode', backref="vertices", primaryjoin="TmplVertex.tmpl_node_id == TmplNode.id", uselist=False)
    tmpl_edges_a = relationship('TmplEdge', backref="tmpl_vertex_a", primaryjoin="TmplEdge.tmpl_vertex_a_id == TmplVertex.id", uselist=True)
    tmpl_edges_b = relationship('TmplEdge', backref="tmpl_vertex_b", primaryjoin="TmplEdge.tmpl_vertex_b_id == TmplVertex.id", uselist=True)
    state = Column("state", Integer, default=0)
    __table_args__ = (ForeignKeyConstraint(['tmpl_node_id'], [u'arv_db_node_tmplnode.id'], **{'name': u'arv_db_node_tmplvertex_tmpl_node_id_fk'}),)

    def add_vertex(self, node, ip1, ip2=None):
        """Instantiate this template on ``node`` as a Vertex.

        :param node: Node whose template must be this vertex's TmplNode
        :param ip1: address(es); several values are joined with ``ip_sep``
        :param ip2: second address/mask list; must be None for ``'ip'``
        :returns: the new Vertex (previously it was built and discarded)
        :raises Exception: if ``node`` is based on another TmplNode
        """
        if node.tmpl_node != self.tmpl_node:
            raise Exception('Wrong node for this TmplVertex')
        vertex = Vertex(tmpl_vertex=self, ip1=ip1, ip2=ip2, node=node)
        # set_ips validates the addresses and stores the validated values.
        vertex.set_ips(ip1, ip2)
        return vertex

    def mod_name(self, name):
        """Rename this vertex.

        Names are unique per TmplNode (the rule applied by
        ``TmplNode.add_tmpl_vertex``), so only sibling vertices are checked.

        :raises Exception: if a sibling vertex already uses ``name``
        """
        name = valid(name, 'string')
        if name != self.name:
            clash = TmplVertex.query.filter_by(name=name, tmpl_node=self.tmpl_node).first()
            if clash is not None:
                raise Exception('Name {0} already used'.format(name))
            self.name = name

    def mod_mimetype(self, mimetype):
        """Change the mimetype and reset the Zephir binding when it changes.

        All three Zephir fields are cleared together so the vertex never holds
        a module without its variables, which ``valid_tmpl_vertex_zephir``
        forbids.  Use ``mod_zephir`` to set them again.

        :raises ValueError: if ``mimetype`` is neither ``'ip'`` nor ``'network'``
        """
        mimetype = valid(mimetype, 'string')
        test_mimetype(mimetype)
        if mimetype != self.mimetype:
            self.mimetype = mimetype
            self.zephir_module = None
            self.zephir_var_ip1 = None
            self.zephir_var_ip2 = None

    def mod_zephir(self, zephir_module, zephir_var_ip1, zephir_var_ip2):
        """Validate and store new Zephir settings for this vertex.

        :raises ValueError: when the combination is inconsistent with the
                            current mimetype
        """
        zephir_module, zephir_var_ip1, zephir_var_ip2 = valid_tmpl_vertex_zephir(
            zephir_module, zephir_var_ip1, zephir_var_ip2, self.mimetype)
        # Only assign attributes whose value actually changes.
        if self.zephir_module != zephir_module:
            self.zephir_module = zephir_module
        if self.zephir_var_ip1 != zephir_var_ip1:
            self.zephir_var_ip1 = zephir_var_ip1
        if self.zephir_var_ip2 != zephir_var_ip2:
            self.zephir_var_ip2 = zephir_var_ip2


@trace()
def create_tmplvertex(name):
    """Build a TmplVertex object with the given name.

    :raises ValueError: if ``name`` is invalid or the object cannot be built
    """
    try:
        name = valid(name, 'string')
        return TmplVertex(name=name)
    except Exception:
        # Narrowed from a bare ``except:`` (which also caught SystemExit).
        raise ValueError("unable to create TmplVertex object in internal database")

@trace()
def get_tmpl_vertices(tmplnode=None):
    """Return TmplVertex objects, optionally restricted to one TmplNode.

    :param tmplnode: if given, only the vertices of this template are returned
    :returns: list of TmplVertex
    :raises Exception: if the query fails
    """
    try:
        if tmplnode is None:
            return TmplVertex.query.all()
        return TmplVertex.query.filter_by(tmpl_node=tmplnode).all()
    except Exception:
        # Narrowed from a bare ``except:`` (which also caught SystemExit).
        raise Exception( 'error in get_tmpl_vertices')

@trace()
def get_tmplvertex_by_node(node):
    """Return the TmplVertex of every Vertex attached to ``node``.

    :returns: list of TmplVertex, in the order given by ``get_vertex_by_node``
    """
    # The Vertex relationship is named ``tmpl_vertex``.  The former
    # ``vertex.tmplvertex`` does not exist and raised AttributeError.
    return [vertex.tmpl_vertex for vertex in get_vertex_by_node(node)]

# ____________________________________________________________
class Node(Entity):
    """Concrete server built from a TmplNode.

    It owns credentials, extremities (public/private IPs) and vertices, and
    takes part in Connects as head or tail.
    """
    __tablename__ = 'arv_db_node_node'
    id = Column("id", Integer, primary_key=True)
    name = Column("name", UnicodeText, unique=True)
    uai = Column("uai", UnicodeText)
    id_zephir = Column("id_zephir", Integer)
    eole_version = Column("eole_version", UnicodeText)
    tmpl_node_id = Column(Integer, index=True)
    tmpl_node = relationship('TmplNode', backref="nodes", primaryjoin="Node.tmpl_node_id == TmplNode.id", uselist=False)
    state = Column("state", Integer, default=0)
    credentials = relationship('Credential', backref="node", primaryjoin="Credential.node_id == Node.id", uselist=True)
    tail_conns = relationship('Connect', backref="head_node", primaryjoin="Connect.head_node_id == Node.id", uselist=True)
    head_conns = relationship('Connect', backref="tail_node", primaryjoin="Connect.tail_node_id == Node.id", uselist=True)
    __table_args__ = (ForeignKeyConstraint(['tmpl_node_id'], [u'arv_db_node_tmplnode.id'], **{'name': u'arv_db_node_node_tmpl_node_id_fk'}),)

    @trace(hide_args=[6], hide_kwargs=['passwd'])
    def add_credential(self, private_key, credential=None,
            unsigned_credential=None, cred_auth=None, name=None, passwd=None,
           keyid_cred=None, subjkey=None, old_credential=None):
        """Attach a certificate and its private key to this node.

        The subject (suffix and CN) and the expiration date are read from
        ``credential``.  The ``name`` argument is ignored and replaced by the
        certificate CN.  The private key is decrypted with ``passwd`` only
        when this node is the Sphynx-ARV node.

        :param private_key: private key (str values are encoded to bytes)
        :param credential: signed certificate (str values are encoded to bytes)
        :param unsigned_credential: stored as-is (str encoded to bytes)
        :param cred_auth: issuing CA Credential
        :param passwd: password of the private key
        :param keyid_cred: key identifier (str encoded to bytes)
        :param subjkey: subject key identifier (str encoded to bytes)
        :param old_credential: Credential to update in place; a new one is
                               created when None
        :returns: the new or updated Credential
        :raises ValueError: if the private key does not match the certificate
        """
        # Normalise text inputs to bytes.
        if isinstance(credential, str):
            credential = credential.encode()
        if isinstance(unsigned_credential, str):
            unsigned_credential = unsigned_credential.encode()
        if isinstance(private_key, str):
            private_key = private_key.encode()
        if isinstance(keyid_cred, str):
            keyid_cred = keyid_cred.encode()
        if isinstance(subjkey, str):
            subjkey = subjkey.encode()
        subject = cert.get_subject(cert=credential)
        name = str(subject[1])
        suffix_cred = valid(subject[0], 'string')
        subject = bin_encoding("{0}, CN = {1}".format(suffix_cred, name))
        expiration_date = cred_end_validity_date(credential)

        # Check that the private key matches the certificate (modulus test).
        if not valid_priv_and_cred(private_key, credential, passwd):
            raise ValueError("Credential and private_key doesn't match")
        # Only the Sphynx-ARV node's private key is decrypted before being
        # stored.  Guard the lookup: without a sphynx template or node this
        # node cannot be the Sphynx-ARV node (this used to raise IndexError).
        tmpl_nodes = get_tmpl_nodes('sphynx')
        current_node = None
        if tmpl_nodes and tmpl_nodes[0].nodes:
            current_node = tmpl_nodes[0].nodes[0]
        if current_node is not None and self == current_node:
            private_key = decrypt_privkey(privkey_string=private_key, passwd=passwd)

        if old_credential is None:
            old_credential = Credential()
        old_credential.name = name
        old_credential.suffix_cred = suffix_cred
        old_credential.subject = subject
        old_credential.private_key = private_key
        old_credential.ca = False
        old_credential.credential = credential
        old_credential.unsigned_credential = unsigned_credential
        old_credential.node = self
        old_credential.cred_auth = cred_auth
        old_credential.keyid_cred = keyid_cred
        old_credential.subjkey = subjkey
        old_credential.expiration_date = expiration_date
        return old_credential

    @trace(hide_args=[3], hide_kwargs=['passwd'])
    def import_credential(self, private_key, credential, passwd,
            old_credential=None, old_ca=None):
        """Import a certificate whose issuing CA is known in the database.

        ``credential`` is converted by ``der_to_pem``.  For a ``pkcs7``
        bundle the CA chain is split off and, when ``old_ca`` is None,
        registered through ``add_certificate_authorities``.  The issuing CA is
        then looked up by the certificate's issuer subject.

        :param old_credential: Credential to update instead of creating one
        :param old_ca: expected issuing CA; the new certificate must match it
        :returns: the Credential built by ``add_credential``
        :raises ValueError: if the issuer lookup fails, the CA is unknown, or
                            the CA differs from ``old_ca``
        :raises Exception: if no certificate could be extracted
        """
        ca = None
        cert_type, certificate = der_to_pem(credential)
        if cert_type == "pkcs7":
            ca_chain, certificate = split_pkcs7(certificate)
            if old_ca is None:
                for ca_cred in ca_chain:
                    ca = add_certificate_authorities(ca_cred, 'x509')
        if certificate is None:
            raise Exception("Unknown certificate type")
        try:
            cred_subject = cert.get_issuer_subject(cert=certificate)
            name = cred_subject[1]
            suffix_cred = valid(cred_subject[0], 'string')
            cred_subject = bin_encoding("{0}, CN = {1}".format(suffix_cred, name))
            ca = Credential.query.filter_by(subject=cred_subject).first()
            if ca is None:
                raise Exception("CA does not exists in database, Can't add certificate")
        except Exception as e:
            raise ValueError(str(e))
        if old_ca is not None and old_ca != ca:
            raise ValueError("New credential is issued with another CA than old one's")
        subjkey, keyid_cred = get_keyid_from_certifstring(private_key, passwd=passwd)

        subjkey = suppress_colon(subjkey)
        keyid_cred = suppress_colon(keyid_cred)
        return self.add_credential(private_key=private_key,
                credential=certificate, cred_auth=ca, passwd=passwd,
                keyid_cred=keyid_cred, subjkey=subjkey,
                old_credential=old_credential)

    def add_extremity(self, pub_ip, priv_ip=None):
        """Create an Extremity for this node.

        :param pub_ip: public IP; must not already be used by an Extremity
        :param priv_ip: private IP; defaults to ``pub_ip``
        :returns: the new Extremity
        :raises ValueError: on invalid IPs or a duplicate public IP
        """
        try:
            pub_ip = valid(pub_ip, 'ip')
            if Extremity.query.filter_by(pub_ip=pub_ip).first() is not None:
                raise Exception('Public IP should be unique')
            if priv_ip is not None:
                priv_ip = valid(priv_ip, 'ip')
            else:
                priv_ip = pub_ip
            return Extremity(pub_ip=pub_ip, priv_ip=priv_ip, node=self)
        except Exception as e:
            raise ValueError(str(e))

    def get_credentials(self, cred_auth=None):
        """Return this node's credentials, optionally filtered by issuing CA.

        :raises ValueError: if the query fails
        """
        try:
            if cred_auth is None:
                return Credential.query.filter_by(node=self).all()
            return Credential.query.filter_by(node=self,
                    cred_auth=cred_auth).all()
        except Exception as e:
            raise ValueError(str(e))

    def get_extremities(self):
        """Return this node's extremities.

        :raises ValueError: if the query fails
        """
        try:
            return Extremity.query.filter_by(node=self).all()
        except Exception as e:
            raise ValueError(str(e))

@trace()
def get_etabs():
    """Summarise every Node.

    :returns: list of ``(name, uai, id_zephir, eole_version, template
              mimetype)`` tuples, one per node returned by ``get_nodes()``
    """
    return [(node.name, node.uai, node.id_zephir, node.eole_version,
             node.tmpl_node.mimetype)
            for node in get_nodes()]

def get_nodes(nodea=None):
    """Return nodes, either all of them or the candidate peers of ``nodea``.

    Without argument, every Node is returned.  With ``nodea``, the result
    holds the nodes whose template is linked to ``nodea``'s template by a
    TmplConnect (in either direction).  ``nodea`` itself and the nodes
    returned by ``get_connected_nodes(nodea)`` are excluded.
    """
    if nodea is None:
        return Node.query.all()
    tmpl_node = nodea.tmpl_node
    excluded = set(get_connected_nodes(nodea))
    excluded.add(nodea)
    candidates = set()
    # Templates on the other side of each TmplConnect involving tmpl_node.
    for tmpl_connect in TmplConnect.query.filter_by(tmpl_node_a=tmpl_node).all():
        candidates.update(tmpl_connect.tmpl_node_b.nodes)
    for tmpl_connect in TmplConnect.query.filter_by(tmpl_node_b=tmpl_node).all():
        candidates.update(tmpl_connect.tmpl_node_a.nodes)
    return list(candidates - excluded)

@trace()
def del_node(node):
    """Delete a Node together with its vertices, extremities and credentials.

    :raises Exception: if a Connect still references the node, or if any
                       deletion fails (the message includes the cause)
    """
    if node.tail_conns or node.head_conns:
        raise Exception("Node is already used by a Connect, you cannot delete it")
    try:
        # Snapshot the collection: deleting may alter it while iterating.
        for vertex in list(node.vertices):
            vertex.delete()
        for extremity in Extremity.query.filter_by(node=node).all():
            extremity.delete()
        for credential in Credential.query.filter_by(node=node).all():
            credential.delete()
        node.delete()
    except Exception as e:
        # The former format string had no placeholder, so the ``%`` itself
        # raised TypeError and hid the real error.
        raise Exception("error in del_node: %s" % str(e))


class Vertex(Entity):
    """Instance of a TmplVertex on a Node, holding its IP data.

    ``ip1`` and ``ip2`` may hold several values joined with ``ip_sep``.
    """
    __tablename__ = 'arv_db_node_vertex'
    id = Column("id", Integer, primary_key=True)
    ip1 = Column("ip1", UnicodeText)
    ip2 = Column("ip2", UnicodeText)
    tmpl_vertex_id = Column(Integer, index=True)
    tmpl_vertex = relationship('TmplVertex', backref="vertices", primaryjoin="Vertex.tmpl_vertex_id == TmplVertex.id", uselist=False)
    node_id = Column(Integer, index=True)
    node = relationship('Node', backref="vertices", primaryjoin="Vertex.node_id == Node.id", uselist=False)
    state = Column("state", Integer, default=0)
    __table_args__ = (ForeignKeyConstraint(['tmpl_vertex_id'], [u'arv_db_node_tmplvertex.id'], **{'name': u'arv_db_node_vertex_tmpl_vertex_id_fk'}),
                      ForeignKeyConstraint(['node_id'], [u'arv_db_node_node.id'], **{'name': u'arv_db_node_vertex_node_id_fk'})
    )

    @trace()
    def set_ips(self, ip1, ip2):
        """Validate and store the IP data of this vertex.

        For an ``'ip'`` template, ``ip2`` must be None.  For a ``'network'``
        template, ``ip1`` holds network addresses and ``ip2`` the matching
        masks.  Both lists must have the same length, and every pair must form
        a valid network.

        :raises ValueError: on invalid addresses, a missing or forbidden
                            ``ip2``, or an invalid network/mask pair
        :raises Exception: if ip1 and ip2 hold a different number of values
        :raises TypeError: if the template mimetype is neither ip nor network
        """
        mimetype = self.tmpl_vertex.mimetype
        # Reject unknown mimetypes before touching ip2.
        if mimetype not in ('ip', 'network'):
            raise TypeError('not a valid mimetype %s for tmpl_vertex %s' %
                    (mimetype, self.tmpl_vertex.name))
        vip1 = [valid(value, 'ip') for value in ip1.split(ip_sep)]
        ip1 = ip_sep.join(vip1)
        if mimetype == 'ip':
            if ip2 is not None:
                raise ValueError('No ip2 if mimetype is ip')
        else:
            # Previously a missing ip2 crashed with AttributeError on split().
            if ip2 is None:
                raise ValueError('ip2 must be set if mimetype is network')
            vip2 = [valid(value, 'ip') for value in ip2.split(ip_sep)]
            if len(vip1) != len(vip2):
                raise Exception('Error: length of ip1 different than ip2')
            ip2 = ip_sep.join(vip2)
            try:
                for network, mask in zip(vip1, vip2):
                    IP("%s/%s" % (network, mask))
            except Exception:
                raise ValueError('ip1 not a valid network or ip2 not a valid mask')
        self.ip1 = ip1
        self.ip2 = ip2

@trace()
def del_vertex(vertex):
    """Delete a single Vertex.

    Unlike ``del_node`` or ``del_extremity``, no dependency check is made
    before deleting.
    """
    remove = vertex.delete
    remove()

@trace()
def get_vertex_by_node(node):
    """Return every Vertex attached to ``node`` (as a list)."""
    query = Vertex.query.filter_by(node=node)
    return query.all()

@trace()
def get_vertex_by_node_tmplvertex(node, tmpl_vertex):
    """Return the Vertex instantiating ``tmpl_vertex`` on ``node``, or None."""
    criteria = {'node': node, 'tmpl_vertex': tmpl_vertex}
    return Vertex.query.filter_by(**criteria).first()


class Extremity(Entity):
    """Network extremity of a Node: a public IP and a private IP.

    Connects reference an extremity as their head or tail (see
    ``head_conns`` / ``tail_conns``).
    """
    __tablename__ = 'arv_db_node_extremity'
    id = Column("id", Integer, primary_key=True)
    # Public IP; declared unique across all extremities.
    pub_ip = Column("pub_ip", UnicodeText, unique=True)
    # Private IP; Node.add_extremity defaults it to the public IP.
    priv_ip = Column("priv_ip", UnicodeText)
    node_id = Column(Integer, index=True)
    state = Column("state", Integer, default=0)
    # Owning node; the backref exposes it as Node.extremities.
    node = relationship('Node', backref="extremities", primaryjoin="Extremity.node_id == Node.id", uselist=False)
    # Connects using this extremity as head (backref Connect.head_extr) ...
    head_conns = relationship('Connect', backref="head_extr", primaryjoin="Connect.head_extr_id == Extremity.id", uselist=True)
    # ... and as tail (backref Connect.tail_extr).
    tail_conns = relationship('Connect', backref="tail_extr", primaryjoin="Connect.tail_extr_id == Extremity.id", uselist=True)
    __table_args__ = (ForeignKeyConstraint(['node_id'], [u'arv_db_node_node.id'], **{'name': u'arv_db_node_extremity_node_id_fk'}),)


@trace()
def del_extremity(extremity):
    """Delete ``extremity`` unless a Connect still uses it.

    :raises Exception: if the extremity is the head or tail of a Connect
    """
    in_use = bool(extremity.head_conns) or bool(extremity.tail_conns)
    if in_use:
        raise Exception("Extremity is already used by a Connect, you cannot delete it")
    extremity.delete()

@trace()
def get_extremities():
    """Return every Extremity in the database."""
    # The query result used to be discarded, so callers always got None.
    return Extremity.query.all()

# vim: ts=4 sw=4 expandtab
