#!/usr/bin/python
# -*- encoding: utf8 -*-

"""
Ein einfacher threading ident-Server (RFC1413), der mit Verbindungen umgehen kann,
die über einen Load Balancer laufen.

Aufruf: noris-identd.py [configfile [port]]

configfile ist eine INI-Datei, die die Zuordnungen der Loadbalancer 
und die Logging-Konfiguration enthält. Wenn kein Dateiname angegeben wird,
ist der Default "/etc/noris-identd.ini"
  
Im Abschnitt [map] der Datei steht jeweils als Key die IP des realen Rechners (z.B. schlunz),
im Value die zugehörige IP des Loadbalancers, auf dem der identd läuft (z.B. krempel).
Es können numerische IPv4-Adressen oder FQDNs verwendet werden.

IPv6 wird (noch) nicht unterstützt

Zur Log-Konfiguration  siehe die Dokumentation des Moduls 'logging' in python.
"""


import ConfigParser, logging, logging.config, pwd, re, socket, SocketServer, struct, threading, sys, time, pprint

# Parser für die Zeiten aus /sys/net/tcp, 
# liefert local_address, local_port, remote_address, remote_port
nettcp_re = re.compile(r'^[ \d]+: ([A-F\d]{8}):([A-F\d]{4}) ([A-F\d]{8}):([A-F\d]{4})(?: +[A-F:\d]+){2} +([A-F\d]{2}):[A-F\d]{8} [A-F\d]{8} +([\d]+)')
NETTCP_STATUS_TIMEWAIT = 3

# Parser für Abfrage-Anforderungen
request_re = re.compile(r'^ *([\d]+) *, *([\d]+) *$')

class IllegalRequest(Exception):
    def __init__(self, msg):
        self.msg = msg
    
    def __str__(self):
        return self.msg

class UsageError(Exception):
    def __init__(self, msg):
        self.msg = msg
        
    def __str__(self):
        return self.msg

def hex_ip(ipstring):
    """Umwandeln einer IP-Strings ("127.0.0.1") in eine Folge von Hexziffern, wie sie /proc/net/tcp verwendet
    
    >>> hex_ip('127.0.0.1')
    '0100007F'
    """
    return "%08X" % struct.unpack('=L', socket.inet_aton(ipstring))

def hex_port(portstring):
    """Umwandeln eines-Port-Strings in eine Folge von Hexziffern, wie sie von /proc/net/tcp verwendet werden
    
    >>> hex_port('56228')
    'DBA4'
    
    >>> hex_port('1')
    '0001'
    """

    return "%04X" % int(portstring)

def any(iterable):
    for v in iterable:
        if v:
            return True
    return False

class ThreadedTCPServer(SocketServer.ThreadingMixIn, SocketServer.TCPServer):
    """Siehe Beispiele zu SocketServer ..."""
    allow_reuse_address = True
    
class IdentRequestHandler(SocketServer.BaseRequestHandler):
    """Ident-Request beantworten
    
    mapping: die (Klassen-)Variable muß z.B. durch
             type() oder Ableitung gesetzt werden.
             Sie ist ein dict; im key steht die IP-Adresse des Load-Balancers
             als Folge von Hexadezimalzahlen, wie von hex_ip() erzeugt
             (z.B. von intra), im value die zugehörige reale Adresse (z.B. intra1)
             im gleichen Format
    """
     
    def handle(self):
        local_port = remote_port = 0
        try:
            log_info = '?'
            client_address = self.client_address[0]
            server_address, server_port = self.request.getsockname()
            logger = logging.getLogger()
            logger.debug("%s connected on %s" % (client_address, server_address))
    
            ## ips ermitteln, und in proc Format umwandeln
            server_hex = hex_ip(server_address)
            client_hex = hex_ip(client_address)
            client_hex_set = set([client_hex])
            if client_hex in self.mapping:
                client_hex_set.add(self.mapping[client_hex])
                logger.debug("mapping: remote %s could appear as %s" % (client_hex, self.mapping[client_hex]))
            question = self.request.recv(1023).strip()
            log_info = "query from client %s at server %s: %s" % (client_address, server_address, question)
            match = request_re.match(question)
            if not match:
                raise IllegalRequest('Cannot parse request from %s: %r' % (client_address, question))
            local_port, remote_port = [int(port) for port in match.groups()]
            local_port_hex = "%04X" % local_port
            remote_port_hex = "%04X" % remote_port
            try:
                procnettcp = open('/proc/net/tcp','r')
                # Kopfzeile überspringen
                procnettcp.readline()
                
                for line in procnettcp:
                    match = nettcp_re.match(line)
                    if not match:
                        logging.warn("Cannot parse line from /proc/net/tcp: %r, skipping" % line)
                        continue
                    p_local_addr, p_local_port, p_remote_addr, p_remote_port, p_state, p_uid = match.groups()
                    if int(p_state) != NETTCP_STATUS_TIMEWAIT \
                            and local_port_hex == p_local_port \
                            and remote_port_hex == p_remote_port \
                            and server_hex == p_local_addr \
                            and p_remote_addr in client_hex_set:
                        try:
                            login = pwd.getpwuid(int(p_uid)).pw_name
                            response = 'USERID:UNIX:%s' % login
                            logger.debug("found: %s" % login)
                        except KeyError:
                            response = 'ERROR:NO-USER'
                            logger.warn("could not find pwentry for id %s" % p_uid)
                        break
                else:
                    response = 'ERROR:NO-USER'
                    logger.info('%s: no match, local=%s:%s, remote=%s:%s' % (
                        log_info, client_hex, local_port_hex, server_hex, remote_port_hex))
            finally:
                if procnettcp:
                    procnettcp.close()
        except IllegalRequest, err:
            logger.warn(err)
            return
        except StandardError, err:
            self.request.send('%d,%d:ERROR:UNKNOWN-ERROR\r\n' % (local_port, remote_port))
            logger.error('unknown error during %s: %s' % (log_info,err))
            raise

        response = '%d,%d:%s' % (local_port, remote_port, response)
        self.request.send('%s\r\n' % response)
        logger.debug('%s: %s' % (log_info, response))

def start_server(mapping, port):
    logger = logging.getLogger()
    logger.info('new server started on port %d' % port)
    logger.debug('mapping: %s' % pprint.pformat(mapping))
    hex_mapping = dict((hex_ip(k), hex_ip(v)) for k,v in mapping.iteritems())
    HandlerClass=type('InitializedIdentRequestHandler', (IdentRequestHandler, object), 
                      {'mapping': hex_mapping})
    server = ThreadedTCPServer(('', port), HandlerClass)
    server_thread = threading.Thread(target=server.serve_forever)
    server_thread.setDaemon(True)
    server_thread.start()
    try:        
        while True:
            # wait for keyboard interrupt
            time.sleep(6000)
    except KeyboardInterrupt:
        logger.info("shutting down because of keyboard interrupt")
        pass
    except:
        logger.info("shutting down")       
        
def resolve(hostname):
    """returns a list of IPv4 address as string list
    """
    result = []
    for _1, _2, _3, _4, host_port in socket.getaddrinfo(hostname, 0, socket.AF_INET, 0, socket.SOL_TCP):
        addr,port = host_port
        if addr not in result:
            result.append(addr)
    return result
    
def read_config(pathname):
    """liest die Konfigurationsdatei und gibt ein mapping für start_server() zurück
    """
    config = ConfigParser.RawConfigParser()
    config.read(pathname)
    mapping = {}
    for key, value in config.items("map"):
        value_ips = resolve(value)
        if not value_ips:
            raise ValueError("Hostname on right side does not resolve: %s\n" % value)
        if len(value_ips) > 1:
            raise ValueError("Hostname on right side resolves to multiple addresses: %s\n" % value)
        value_ip = value_ips[0]
        for key_ip in resolve(key):
            if key_ip in mapping:
                raise ValueError("Duplicate entry in [map] for key %s\n" % key)
            mapping[key_ip] = value_ip
    return mapping

def init_logging(pathname):
    configured = False
    if pathname:
        try:
            logging.config.fileConfig(pathname)
            configured = True
        except ConfigParser.NoSectionError, err:
            sys.stderr.write(
                'Hint: Cannot parse logging config:\n%s\n'
                'Falling back to debug logging\n' % str(err))
    if not configured:
        logging.basicConfig(level=logging.DEBUG)

if __name__ == '__main__':
    try:
        cmd = sys.argv[0]
        configfile = '/etc/noris-identd.ini'
        port = 113
        if any(arg.startswith('-') for arg in sys.argv):
            raise UsageError("%s does not take any options" % cmd)
        if len(sys.argv) >= 2:
            configfile = sys.argv[1]
        if len(sys.argv) >= 3:
            try:
                port = int(sys.argv[2])
            except ValueError:
                raise UsageError("%s: Wrong port: %s" % (cmd, sys.argv[2]))
        if len(sys.argv) > 3:
            raise UsageError("%s: Too many parameters")    
        try:
            init_logging(configfile)
            mapping = read_config(configfile)
        except ValueError, err:
            sys.stderr.write("%s: Error in Config: %s\n" % (cmd,str(err)))
            sys.exit(1)
        start_server(mapping, port)
    except UsageError, err:
        sys.stderr.write("%s\nUsage: %s [configfile [port]]\n" % (str(err), cmd))
        sys.exit(2)
    
