#!/usr/bin/python3
# -*- coding: utf-8 -*-

import sys
import sqlite3
import re
import subprocess

from pkg_resources import parse_version
from arv.db.initialize import initialize_database, commit_database
from arv.db.edge import *
from arv.db.node import *
from arv.db.version import add_mod_DbVersion, get_DbVersion
from arv.config import (sqlite_database, strongswan_database,
                strongswan_tmpconfigs_directory, arv_address_ip, arv_db_version)
from arv.lib.util import cred_end_validity_date, bin_encoding, valid

from creole import cert
from pyeole.ihm import question_ouinon
from creole.client import CreoleClient
from pyeole.service import unmanaged_service

import glob
from os.path import isfile, join
from os import unlink, listdir

def gen_local_ca():
    """Generate the local certificate authority for ARV.

    Feeds a configuration dictionary to ``creole.cert.load_conf`` (CA
    certificate written to ``<ssl_dir>/certs/CaCert.pem``, start index "01",
    validity '5475') and then asks ``cert.gen_ca`` for a CA built with the
    "ac-ext" extensions. Any exception, including one raised while building
    the configuration, is printed and the process exits with status 1.

    NOTE(review): ``ssl_dir``, ``ca_conf_file`` and ``x509_default_key_bits``
    are not defined in the visible part of this file; presumably they arrive
    through one of the star imports — confirm, otherwise this always exits 1.
    """
    try:
        ca_settings = {
            'ssl_dir': ssl_dir,
            'start_index': "01",
            'ca_conf_file': ca_conf_file,
            'ca_file': join(ssl_dir, 'certs/CaCert.pem'),
            'ssl_default_key_bits': x509_default_key_bits,
            'ssl_default_cert_time': '5475',
        }
        cert.load_conf(ca_settings)
        cert.gen_ca(extensions="ac-ext")
    except Exception as exc:
        print("ERROR: --- unable to create the certificate --- : %s " % str(exc))
        sys.exit(1)

def populate_database():
    """Create and seed a brand-new ARV database.

    Steps, in order:
      * create the schema (``initialize_database(create=True)``);
      * register the Sphynx, Etablissement and Roadwarrior node templates;
      * record the schema version ``arv_db_version``;
      * add this server (Creole ``nom_machine`` / ``numero_etab``) as a
        Sphynx node, with a public-IP extremity when ``arv_address_ip`` is set;
      * store the CA certificate read from ``cert.ca_file`` as the local
        authority credential, and add the 'sphynx' credential for the node
        (flagged 'autosigned').

    NOTE(review): ``cert.ca_file`` is presumably the path configured by
    ``gen_local_ca()``, so that function must run first — confirm.
    """
    conf_eole = CreoleClient().get_creole()
    initialize_database(create=True)
    # Node templates; only the Sphynx template is needed again below.
    tmpl_sphynx = add_tmpl_node(name="Sphynx", mimetype='sphynx')
    add_tmpl_node(name="Etablissement", mimetype='etablissement')
    add_tmpl_node(name="Roadwarrior", mimetype='roadwarrior')
    add_mod_DbVersion(arv_db_version)
    # Register this server as the Sphynx node.
    print("Ajout du serveur sphynx...")
    name = conf_eole['nom_machine']
    rne = conf_eole['numero_etab']
    node1 = tmpl_sphynx.add_node(name=name, uai=rne)
    if arv_address_ip is not None:
        node1.add_extremity(pub_ip=arv_address_ip)

    # Authority credential: read the CA certificate, closing the file promptly.
    with open(cert.ca_file, 'r') as ca_fd:
        credential = ca_fd.read()
    add_credential_auth(credential=credential, local=True)
    add_credential('sphynx', 'eole', node1, 'autosigned')
# Becomes True when, during 'instance', the operator agrees to wipe and re-create the ARV database.
init = False

def upgrade_subject_field_in_credential_table():
    """Rewrite each credential subject as "<suffix>/CN=<name>".

    The suffix and common name come from ``cert.get_subject`` on the stored
    certificate. The suffix goes through ``valid(..., 'string')`` and the
    whole string through ``bin_encoding``. A row is only touched when the
    rebuilt value differs from what is stored.
    """
    for row in get_all_credentials():
        parsed = cert.get_subject(cert=row.credential)
        common_name = str(parsed[1])
        rebuilt = bin_encoding("%s/CN=%s" % (valid(parsed[0], 'string'), common_name))
        if rebuilt != row.subject:
            row.subject = rebuilt

def upgrade_credential_table():
    """Convert legacy "/CN=" credential subjects to "<suffix>, CN = <name>".

    Bytes subjects are first decoded to ``str``. Every credential whose
    subject contains "/CN=" gets its subject rebuilt from the certificate
    (``cert.get_subject``, passed through ``bin_encoding``). Its
    ``suffix_cred`` is set to the validated suffix.
    """
    for certif in get_all_credentials():
        if isinstance(certif.subject, bytes):
            certif.subject = certif.subject.decode()
        # "/CN=" holds no regex metacharacters, so a substring test is exact;
        # no need to compile a pattern on every iteration.
        if '/CN=' in certif.subject:
            subject = cert.get_subject(cert=certif.credential)
            suffix_cred = subject[0]
            certif.subject = bin_encoding("{0}, CN = {1}".format(suffix_cred, subject[1]))
            certif.suffix_cred = valid(suffix_cred, 'string')

def encode_subject_credential_table():
    """Re-encode credentials whose ``subject`` is still stored as bytes.

    For each credential with a bytes subject, the subject is rebuilt from the
    certificate as "<suffix>, CN = <name>" (passed through ``bin_encoding``).
    ``suffix_cred`` is set to the validated suffix. Credentials with a
    non-bytes subject are left untouched.
    """
    for certif in get_all_credentials():
        if not isinstance(certif.subject, bytes):
            continue
        # Kept as before: the decoded value is replaced just below unless
        # get_subject()/valid() raises, in which case it stays decoded.
        certif.subject = certif.subject.decode()
        subject = cert.get_subject(cert=certif.credential)
        suffix_cred = subject[0]
        certif.subject = bin_encoding("{0}, CN = {1}".format(suffix_cred, subject[1]))
        certif.suffix_cred = valid(suffix_cred, 'string')

# Only at 'instance' time: offer to reset an existing ARV database. On
# confirmation, also delete the strongSwan database and the temporary
# strongSwan *.db files.
if 'instance' in sys.argv and isfile(sqlite_database):
    question = 'Voulez-vous réinitialiser la base ARV et perdre vos modifications ?'
    if question_ouinon(question, level='warn') == 'oui':
        init = True
        unlink(sqlite_database)
        if isfile(strongswan_database):
            unlink(strongswan_database)
        tmp_pattern = '%s*.db' % strongswan_tmpconfigs_directory
        for filename in glob.glob(tmp_pattern):
            unlink(filename)

# `systemctl is-enabled` exits with status 0 when the unit is enabled; the
# result decides whether the arv service is started again afterwards.
cmd = ['systemctl', 'is-enabled', 'arv']
# Output is discarded via DEVNULL: the previous Popen(stdout=PIPE,
# stderr=PIPE) + wait() never read the pipes, which can block on a full pipe
# and left the pipe file descriptors open.
arv_enabled = subprocess.call(cmd, stdout=subprocess.DEVNULL,
                              stderr=subprocess.DEVNULL, shell=False) == 0
def _column_already_exists(cursor, table, column, sql_type):
    """Try to add ``column`` to ``table``; return True if it was already there.

    Runs ``ALTER TABLE <table> ADD COLUMN <column> <sql_type>`` on ``cursor``.
    Returns True only when SQLite refuses with "duplicate column name:
    <column>", i.e. this migration was applied earlier. Returns False when
    the column has just been added. It also returns False when any other
    ``sqlite3.OperationalError`` occurs; such errors are otherwise ignored.
    Other exception types propagate.
    """
    # Identifiers cannot be bound as SQL parameters; every caller below passes
    # hard-coded names, never external input.
    try:
        cursor.execute('ALTER TABLE {0} ADD COLUMN {1} {2}'.format(table, column, sql_type))
    except sqlite3.OperationalError as err:
        return str(err) == 'duplicate column name: {0}'.format(column)
    return False


def _arv_service(action):
    """Run ``action`` ('stop'/'start') on the arv service; print any error and exit(1)."""
    try:
        unmanaged_service(action, 'arv', 'service')
    except Exception as e:
        print(e)
        sys.exit(1)


if not isfile(sqlite_database):
    # No database yet: build the CA and seed a new database while arv is stopped.
    _arv_service('stop')
    gen_local_ca()
    populate_database()
    commit_database()
    if arv_enabled:
        _arv_service('start')
else:
    # Existing database: migrate it in place.
    if not init:
        initialize_database()
    conn = sqlite3.connect(sqlite_database)
    c = conn.cursor()
    # Each flag below ends up True when its migration step had nothing to do.
    # The order of the steps is deliberately kept unchanged.
    expiration_date_exception = _column_already_exists(
        c, 'arv_db_edge_credential', 'expiration_date', 'TEXT')
    id_zephir_exception = _column_already_exists(
        c, 'arv_db_node_node', 'id_zephir', 'INTEGER')
    eole_version_exception = _column_already_exists(
        c, 'arv_db_node_node', 'eole_version', 'TEXT')
    tmplconnect_leftsendcert_exception = _column_already_exists(
        c, 'arv_db_edge_tmplconnect', 'leftsendcert', 'TEXT')
    connect_leftsendcert_exception = _column_already_exists(
        c, 'arv_db_edge_connect', 'leftsendcert', 'TEXT')
    tmplconnect_fragmentation_exception = _column_already_exists(
        c, 'arv_db_edge_tmplconnect', 'fragmentation', 'TEXT')
    connect_fragmentation_exception = _column_already_exists(
        c, 'arv_db_edge_connect', 'fragmentation', 'TEXT')
    arv_db_version_table = False
    upgrade_subject_field = False
    certificate_subject_openssl = False
    certificate_encode_subject = False
    # Subject-field upgrade runs only when no schema version is recorded.
    try:
        if get_DbVersion() is None:
            upgrade_subject_field_in_credential_table()
        else:
            upgrade_subject_field = True
    except Exception as e:
        print("Probleme durant l'upgrade du champ subject", e)
    rw_dns_exception = _column_already_exists(
        c, 'arv_db_edge_connect', 'rw_dns', 'TEXT')
    try:
        c.execute('create table arv_db_version_dbversion ( id INTEGER NOT NULL, version TEXT , PRIMARY KEY (id))')
        add_mod_DbVersion(version=arv_db_version)
    except sqlite3.OperationalError as err:
        if str(err) == 'table arv_db_version_dbversion already exists':
            arv_db_version_table = True
    # Version-gated credential migration to 1.0.1.
    try:
        db_version = get_DbVersion()
        upgrade_credential_table_db_version = '1.0.1'
        if parse_version(db_version) < parse_version(upgrade_credential_table_db_version):
            upgrade_credential_table()
            add_mod_DbVersion(version=upgrade_credential_table_db_version)
        else:
            certificate_subject_openssl = True
    except Exception as e:
        print("Probleme lors de l upgrade de la table credential", e)
    # Version-gated credential migration to 1.0.2.
    try: #34339
        db_version = get_DbVersion()
        upgrade_credential_table_db_version = '1.0.2'
        if parse_version(db_version) < parse_version(upgrade_credential_table_db_version):
            encode_subject_credential_table()
            add_mod_DbVersion(version=upgrade_credential_table_db_version)
        else:
            certificate_encode_subject = True
    except Exception as e:
        print("Probleme lors de l encoding des subjects de la table credential", e)

    # id_zephir_exception / eole_version_exception are not consulted: no
    # follow-up data migration below depends on those columns.
    nothing_to_migrate = (expiration_date_exception
                          and tmplconnect_leftsendcert_exception
                          and connect_leftsendcert_exception
                          and tmplconnect_fragmentation_exception
                          and connect_fragmentation_exception
                          and upgrade_subject_field
                          and rw_dns_exception
                          and arv_db_version_table
                          and certificate_subject_openssl
                          and certificate_encode_subject)
    conn.close()
    _arv_service('stop')
    if nothing_to_migrate:
        # Up to date: just restart the service (if enabled) and stop here.
        if arv_enabled:
            _arv_service('start')
        sys.exit(0)
    # Backfill values for the columns that were just added.
    if not expiration_date_exception:
        for cred in get_all_credentials():
            cred.expiration_date = cred_end_validity_date(cred.credential)
    if not tmplconnect_leftsendcert_exception:
        for tmplconnect in get_tmpl_connects():
            tmplconnect.leftsendcert = 'always'
    if not connect_leftsendcert_exception:
        for connect in get_connects():
            connect.leftsendcert = 'always'
    if not tmplconnect_fragmentation_exception:
        for tmplconnect in get_tmpl_connects():
            tmplconnect.fragmentation = 'no'
    if not connect_fragmentation_exception:
        for connect in get_connects():
            connect.fragmentation = 'no'
    if not rw_dns_exception:
        # NOTE(review): presumably add_tmpl_node returns the existing
        # Roadwarrior template rather than creating a duplicate — confirm.
        tmplnode3 = add_tmpl_node(name="Roadwarrior", mimetype='roadwarrior')
        tmplnode3.add_tmpl_vertex(name="RW_SourceIP", mimetype='ip')
    commit_database()
    if arv_enabled:
        _arv_service('start')
