#######################################################################
# Copyright (C) 2009-2012 by Carnegie Mellon University.
#
# @OPENSOURCE_HEADER_START@
#
# Use of the SILK system and related source code is subject to the terms
# of the following licenses:
#
# GNU Public License (GPL) Rights pursuant to Version 2, June 1991
# Government Purpose License Rights (GPLR) pursuant to DFARS 252.227.7013
#
# NO WARRANTY
#
# ANY INFORMATION, MATERIALS, SERVICES, INTELLECTUAL PROPERTY OR OTHER
# PROPERTY OR RIGHTS GRANTED OR PROVIDED BY CARNEGIE MELLON UNIVERSITY
# PURSUANT TO THIS LICENSE (HEREINAFTER THE "DELIVERABLES") ARE ON AN
# "AS-IS" BASIS. CARNEGIE MELLON UNIVERSITY MAKES NO WARRANTIES OF ANY
# KIND, EITHER EXPRESS OR IMPLIED AS TO ANY MATTER INCLUDING, BUT NOT
# LIMITED TO, WARRANTY OF FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY, INFORMATIONAL CONTENT, NONINFRINGEMENT, OR ERROR-FREE
# OPERATION. CARNEGIE MELLON UNIVERSITY SHALL NOT BE LIABLE FOR INDIRECT,
# SPECIAL OR CONSEQUENTIAL DAMAGES, SUCH AS LOSS OF PROFITS OR INABILITY
# TO USE SAID INTELLECTUAL PROPERTY, UNDER THIS LICENSE, REGARDLESS OF
# WHETHER SUCH PARTY WAS AWARE OF THE POSSIBILITY OF SUCH DAMAGES.
# LICENSEE AGREES THAT IT WILL NOT MAKE ANY WARRANTY ON BEHALF OF
# CARNEGIE MELLON UNIVERSITY, EXPRESS OR IMPLIED, TO ANY PERSON
# CONCERNING THE APPLICATION OF OR THE RESULTS TO BE OBTAINED WITH THE
# DELIVERABLES UNDER THIS LICENSE.
#
# Licensee hereby agrees to defend, indemnify, and hold harmless Carnegie
# Mellon University, its trustees, officers, employees, and agents from
# all claims or demands made against them (and any related losses,
# expenses, or attorney's fees) arising out of, or relating to Licensee's
# and/or its sub licensees' negligent use or willful misuse of or
# negligent conduct or willful misconduct regarding the Software,
# facilities, or other rights or assistance granted by Carnegie Mellon
# University under this License, including, but not limited to, any
# claims of product liability, personal injury, death, damage to
# property, or violation of any laws or regulations.
#
# Carnegie Mellon University Software Engineering Institute authored
# documents are sponsored by the U.S. Department of Defense under
# Contract FA8721-05-C-0003. Carnegie Mellon University retains
# copyrights in all material produced under this contract. The U.S.
# Government retains a non-exclusive, royalty-free license to publish or
# reproduce these documents, or allow others to do so, for U.S.
# Government purposes only pursuant to the copyright license under the
# contract clause at 252.227.7013.
#
# @OPENSOURCE_HEADER_END@
#
#######################################################################

#######################################################################
# $SiLK: sendrcv_daemon_test.py c9a75ce720a4 2012-02-13 15:58:52Z mthomas $
#######################################################################

import os
import os.path
import tempfile
import random
import stat
import re
import shutil
import sys
from gencerts import generate_signed_cert, generate_ca_cert
from daemon_test import check_call, Daemon, Log_manager, get_ephemeral_port, Dirobject
from itertools import chain

try:
    import hashlib
    md5_new = hashlib.md5
    sha1_new = hashlib.sha1
except ImportError:
    import md5
    md5_new = md5.new
    import sha
    sha1_new = sha.new


# Chunk size, in bytes, used when writing random test files in
# create_random_file() and when reading files back in checksum_file().
CHUNKSIZE = 2048

class Sndrcv_base(Daemon):
    """State and argument building shared by the rwsender/rwreceiver wrappers.

    Tracks the connection role ("client" or "server"), the listening port,
    the peers on the other side, and optional TLS material, and turns them
    into the command-line switches common to both programs.
    """

    def __init__(self, name=None, **kwds):
        Daemon.__init__(self, name, **kwds)
        # Role of this daemon; System._connect() assigns it explicitly.
        self.mode = "client"
        # Listening port, only meaningful in server mode.
        self.port = None
        # Server mode: identifiers of the clients allowed to connect.
        self.clients = []
        # Client mode: (ident, address, port) tuples of servers to contact.
        self.servers = []
        # TLS material; left as None unless the System enabled TLS.
        self.ca_cert = None
        self.ca_key = None
        self.cert = None

    def create_cert(self):
        """Generate this daemon's CA-signed certificate inside basedir."""
        signing_pair = (self.ca_key, self.ca_cert)
        self.cert = generate_signed_cert(self.basedir, signing_pair,
                                         "key.pem", "key.p12")

    def init(self):
        """Run Daemon.init, create working directories and, with TLS, a cert."""
        Daemon.init(self)
        if self.ca_cert:
            # A "cert" directory is only needed when TLS is in use.
            self.dirs.append("cert")
        self.create_dirs()
        if self.ca_cert:
            self.create_cert()

    def get_args(self):
        """Return the Daemon arguments extended with mode/TLS/peer switches."""
        args = Daemon.get_args(self)
        args.extend(['--mode', self.mode, '--identifier', self.name])
        if self.ca_cert:
            args.extend(['--tls-ca', os.path.abspath(self.ca_cert),
                         '--tls-pkcs12', os.path.abspath(self.cert)])
        if self.mode != "server":
            for ident, addr, port in self.servers:
                spec = ':'.join((ident, addr, str(port)))
                args.extend(['--server-address', spec])
            return args
        args.extend(['--server-port', str(self.port)])
        for peer in self.clients:
            args.extend(['--client-ident', peer])
        return args

    def _check_file(self, dir, finfo):
        """Compare a file in directory *dir* with its expected fingerprint.

        finfo has the shape (path, (size, sha1_obj, md5_obj)), as returned
        by create_random_file(); only the basename of path is used, joined
        onto self.dirname[dir].

        Returns (None, path) when everything matches, otherwise
        (message, path) describing the first discrepancy found.
        """
        orig_path, expected = finfo
        want_size, want_sha, want_md5 = expected
        target = os.path.join(self.dirname[dir], os.path.basename(orig_path))
        if not os.path.exists(target):
            return ("Does not exist", target)
        got_size, got_sha, got_md5 = checksum_file(target)
        if got_size != want_size:
            return ("Size mismatch (%s != %s)" % (want_size, got_size), target)
        # SHA is checked before MD5, so a double mismatch reports SHA.
        for label, want, got in (("SHA", want_sha, got_sha),
                                 ("MD5", want_md5, got_md5)):
            if got.digest() != want.digest():
                return ("%s mismatch (%s != %s)"
                        % (label, want.hexdigest(), got.hexdigest()), target)
        return (None, target)


class Rwsender(Sndrcv_base):
    """Test wrapper around the rwsender daemon.

    Files to transfer are placed in the "in" directory; rwsender is also
    given "proc" and "error" directories and a polling interval.
    """

    def __init__(self, name=None, polling_interval=5, filters=None, **kwds):
        Sndrcv_base.__init__(self, name, prog_env="RWSENDER", **kwds)
        self.exe_name = "rwsender"
        # A literal [] default would be a single list shared by every
        # Rwsender, so filters appended to one instance would leak into all
        # others. Build a fresh list per instance instead.
        if filters is None:
            filters = []
        # Sequence of (ident, regexp) pairs, each passed as --filter.
        self.filters = filters
        self.polling_interval = polling_interval
        self.dirs = ["in", "proc", "error"]

    def get_args(self):
        """Return the common arguments plus rwsender directory/filter switches."""
        args = Sndrcv_base.get_args(self)
        args += ['--incoming-directory', os.path.abspath(self.dirname["in"]),
                 '--processing-directory',
                 os.path.abspath(self.dirname["proc"]),
                 '--error-directory', os.path.abspath(self.dirname["error"]),
                 '--polling-interval', str(self.polling_interval)]
        for ident, regexp in self.filters:
            args.extend(["--filter", ident + ':' + regexp])
        return args

    def send_random_file(self, suffix="", prefix="random", size=(0, 0)):
        """Create a random file in the "in" directory.

        Returns create_random_file()'s (path, (size, sha1, md5)) tuple.
        """
        return create_random_file(suffix=suffix, prefix=prefix,
                                  dir=self.dirname["in"], size=size)

    def send_files(self, files):
        """Copy existing files into the "in" directory.

        files: iterable of (path, info) pairs, e.g. as returned by
        create_random_file(); only the path element is used.
        """
        for path, _info in files:
            shutil.copy(path, self.dirname["in"])

    def check_error(self, data):
        """Check that *data*'s file landed, intact, in the "error" directory."""
        return self._check_file("error", data)


class Rwreceiver(Sndrcv_base):
    """Test wrapper around the rwreceiver daemon.

    Received files are written to the "dest" directory; an optional post
    command is forwarded to rwreceiver unchanged.
    """

    def __init__(self, name=None, post_command=None, **kwds):
        Sndrcv_base.__init__(self, name, prog_env="RWRECEIVER", **kwds)
        self.exe_name = "rwreceiver"
        self.dirs = ["dest"]
        # Passed as --post-command when set (truthy).
        self.post_command = post_command

    def get_args(self):
        """Return the common arguments plus destination and post-command."""
        args = Sndrcv_base.get_args(self)
        dest_dir = os.path.abspath(self.dirname["dest"])
        args.extend(['--destination-directory', dest_dir])
        if not self.post_command:
            return args
        args.extend(['--post-command', self.post_command])
        return args

    def check_sent(self, data):
        """Check that *data*'s file arrived, intact, in the "dest" directory."""
        return self._check_file("dest", data)


class System(Dirobject):
    """A group of rwsender/rwreceiver daemons wired together for a test.

    Every daemon connected through one System shares its Log_manager.
    All clients must be of one class (Rwsender or Rwreceiver), and all
    servers must be of the other class.
    """

    def __init__(self):
        Dirobject.__init__(self)
        self.create_dirs()
        self.logs = Log_manager()
        self.logs.start()
        # Classes of the client and server sides; set by the first
        # _connect() call and enforced on every later one.
        self.client_type = None
        self.server_type = None
        self.clients = set()
        self.servers = set()
        # CA material produced by create_ca_cert() when TLS is requested.
        self.ca_cert = None
        self.ca_key = None

    def create_ca_cert(self):
        """Generate the CA key and certificate ('ca_cert.pem') in basedir."""
        self.ca_key, self.ca_cert = generate_ca_cert(self.basedir,
                                                     'ca_cert.pem')

    def connect(self, clients, servers, tls=False):
        """Connect each of *clients* to each of *servers*.

        clients, servers: a single Rwsender/Rwreceiver or an iterable of them.
        tls: when true, make sure a CA exists and give it to every daemon
            connected by this call.

        Raises ValueError if a pairing uses incompatible daemon types.
        """
        # Create the CA only once. Regenerating it on a later connect()
        # call could leave daemons from earlier calls with a different CA
        # than daemons connected later.
        if tls and self.ca_cert is None:
            self.create_ca_cert()
        if isinstance(clients, Sndrcv_base):
            clients = [clients]
        if isinstance(servers, Sndrcv_base):
            servers = [servers]
        for client in clients:
            for server in servers:
                self._connect(client, server, tls)

    def _connect(self, client, server, tls):
        """Wire one client to one server and register both with the system."""
        if not isinstance(client, Sndrcv_base):
            raise ValueError("Can only connect rwsenders and rwreceivers")
        if not self.client_type:
            # The first client decides which class plays which role.
            if isinstance(client, Rwsender):
                self.client_type = Rwsender
                self.server_type = Rwreceiver
            else:
                self.client_type = Rwreceiver
                self.server_type = Rwsender
        if not isinstance(client, self.client_type):
            raise ValueError("Client must be of type %s" %
                             self.client_type.__name__)
        if not isinstance(server, self.server_type):
            raise ValueError("Server must be of type %s" %
                             self.server_type.__name__)
        client.mode = "client"
        server.mode = "server"

        # A server keeps its port across connections to multiple clients.
        if server.port is None:
            server.port = get_ephemeral_port()

        client.servers.append((server.name, "localhost", server.port))
        server.clients.append(client.name)

        self.clients.add(client)
        self.servers.add(server)

        client.log_manager = self.logs
        server.log_manager = self.logs

        if tls:
            client.ca_cert = self.ca_cert
            server.ca_cert = self.ca_cert
            client.ca_key = self.ca_key
            server.ca_key = self.ca_key

    def _forall(self, call, which, *args, **kwds):
        """Invoke method *call* on the selected daemons; return the results.

        which: "clients", "servers", or any other value for both
        (clients are visited before servers).
        """
        if which == "clients":
            targets = self.clients
        elif which == "servers":
            targets = self.servers
        else:
            targets = chain(self.clients, self.servers)
        # Build the list eagerly: on Python 3 map() is lazy, and callers
        # such as start() discard the result, so the calls would never run.
        return [getattr(obj, call)(*args, **kwds) for obj in targets]

    def start(self, which = None):
        """Initialise and then start the selected daemons (default: all)."""
        self._forall("init", which)
        self._forall("start", which)

    def stop(self, which = None):
        """Stop the selected daemons and clean up.

        Removes the selected daemons' base directories and this system's
        base directory, and stops the shared log manager (regardless of
        *which*). Returns the list of the daemons' stop() results.
        """
        status = self._forall("stop", which)
        self._forall("remove_basedir", which)
        self.remove_basedir()
        self.logs.stop()
        return status

    def get_logger(self):
        """Return the Log_manager shared by all connected daemons."""
        return self.logs


def create_random_file(suffix="", prefix="random", dir=None, size=(0, 0)):
    """Create a temporary file filled with random bytes.

    suffix, prefix, dir: passed to tempfile.mkstemp().
    size: (min, max) byte counts; the actual size is chosen with
        random.randint(min, max), inclusive on both ends.

    Returns (path, (nbytes, sha1_obj, md5_obj)). The hash objects have
    been fed exactly the bytes written to the file.
    """
    (handle, path) = tempfile.mkstemp(suffix, prefix, dir)
    # Binary mode is required. Text mode may translate newline bytes on
    # some platforms, which corrupts the recorded size and checksums, and
    # on Python 3 it rejects bytes outright.
    f = os.fdopen(handle, "wb")
    numbytes = random.randint(size[0], size[1])
    totalbytes = numbytes
    checksum_sha = sha1_new()
    checksum_md5 = md5_new()
    try:
        while numbytes:
            length = min(numbytes, CHUNKSIZE)
            try:
                chunk = os.urandom(length)
            except NotImplementedError:
                # Fallback for platforms without an OS randomness source.
                chunk = ''.join(chr(random.getrandbits(8))
                                for x in xrange(0, length))
            f.write(chunk)
            checksum_sha.update(chunk)
            checksum_md5.update(chunk)
            numbytes -= length
    finally:
        # Close even if writing fails, so the descriptor is not leaked.
        f.close()
    return (path, (totalbytes, checksum_sha, checksum_md5))


def checksum_file(path):
    """Compute the size, SHA-1 and MD5 of the file at *path*.

    Returns (size, sha1_obj, md5_obj). The size comes from fstat() on the
    open file, and the hash objects are fed the file's full contents.
    """
    checksum_sha = sha1_new()
    checksum_md5 = md5_new()
    f = open(path, 'rb')
    try:
        size = os.fstat(f.fileno())[stat.ST_SIZE]
        # Read in uniform CHUNKSIZE pieces (previously only the first read
        # used CHUNKSIZE and later reads used a hard-coded 1024).
        while True:
            data = f.read(CHUNKSIZE)
            if not data:
                break
            checksum_sha.update(data)
            checksum_md5.update(data)
    finally:
        # Close the file even if reading or hashing raises.
        f.close()
    return (size, checksum_sha, checksum_md5)

