#!/usr/bin/env python

# Copyright (C) 2005 by Peter V. Radatti
# Release identifier as (version, release, patchlevel, release date); main()
# formats these four fields into the startup banner.
VERSION = (1, 3, 0, 'March 2005')

import socket, thread, os, sys, urlparse, string, time, pwd
from errno import *

# Interpreters that predate the True/False builtins get integer stand-ins.
try:
    True, False
except NameError:
    True, False = 1, 0

# Likewise, supply an empty `object` base class when the builtin is missing,
# so that the "class X(object)" definitions below still work.
try:
    object
except NameError:
    class object:
        pass

class config(object):
    """Process-wide settings, read as class attributes.

    configure() overrides these from the ini file, and main() overrides the
    two port numbers from the command line.
    """

    # Commands VFind.connect() sends to the daemon right after it reports ready.
    INIT_COMMANDS = ['ENABLE EXPAND', 'ENABLE ORIGINAL', 'ENABLE HTML 1']
    # Address and port main() binds the proxy's listening socket to.
    LISTEN_ADDRESS = '0.0.0.0'
    LISTEN_PORT = 8082
    # Address and port of the VFind daemon that VFind.connect() connects to.
    VFIND_ADDRESS = '127.0.0.1'
    VFIND_PORT = 8081
    # Upstream proxy: when PROXY_PORT is set, HTTPproxy.connect() connects
    # here instead of to the host named in the request URL.
    PROXY_ADDRESS = None
    PROXY_PORT = None
    # Seconds without a ping() after which Watcher times a thread out; a
    # false value makes main() use NoWatcher (no timeouts) instead.
    THREAD_TIMEOUT = None
    # Upper bound for the retry counters, which are used as sleep seconds.
    MAX_RETRY_WAIT = 60
    # Directory in which HTTPproxy.tempfile() places the files to be scanned.
    TEMP_DIR = '/tmp'
    # configure() replaces '$VSTK_HOME' in the next three paths with the
    # VSTK_HOME environment variable ('.' when it is unset).
    LOG_FILE = '$VSTK_HOME/var/log/vfproxy.log'
    PID_FILE = '$VSTK_HOME/var/run/vfproxy.pid'
    INI_FILE = '$VSTK_HOME/data/vfproxy/vfproxy.ini'
    # User whose uid/gid daemonize() switches to when started as root.
    UNPRIVILEGED_USER = 'nobody'
    # When true, a request whose URL carries no host is sent to
    # PROXY_ADDRESS:PROXY_PORT instead of being rejected (HTTPproxy.parse_url).
    TRANSPARENT = False


# Error-page material keyed by HTTP status code (as a string).  The 'default'
# entry supplies every field; error_message() overlays the entry for the
# specific code on top of it, so an entry only lists what differs.
# 'html' is a %-format template filled from a dict; configure() later
# prefixes it with the HTTP response header and may add entries read from
# template files.
ERRORS = {

    '403': {
        'text': ('You have received this notice because VFind has detected'
                 ' that the action you have taken has caused a threat to the'
                 ' computer network.  The name and the description of the threat'
                 ' is listed below.  If you think you have received the message'
                 ' in error, please contact your network administrator.'),
        'notice': 'Notice',
        'warning': 'Warning',
        'threat': 'Threat',
    },

    'default': {
        'html': ('<!DOCTYPE html PUBLIC "-//W3C//DTD HTML 4.01 Transitional//EN">\r\n'
                 '<html>\r\n'
                 ' <head>\r\n'
                 '  <title>CyberSoft Threat Detection</title>\r\n'
                 '  <style type="text/css" media="screen"><!--\r\n'
                 'h1 { font: 14pt/14pt Verdana; color: red; font-weight: bold; font-style: italic; }\r\n'
                 'h2 { font: 12pt/14pt Verdana; font-weight: bold; }\r\n'
                 'li { font: 10pt/14pt Verdana; }\r\n'
                 'p { font: 10pt/14pt Verdana; }\r\n'
                 '--></style>\r\n'
                 ' </head>\r\n'
                 ' <body bgcolor="white">\r\n'
                 '  <blockquote>\r\n'
                 '   <h2>%(code)s: %(cause)s</h2>\r\n'
                 '   <h1>CyberSoft Threat Detection %(notice)s</h1>\r\n'
                 '   <p><b>%(warning)s:</b> %(text)s</p>\r\n'
                 '   <p><b>%(threat)s Detected:</b></p>\r\n'
                 '   <ul>\r\n'
                 '    %(threats)s\r\n'
                 '   </ul>\r\n'
                 '  </blockquote>\r\n'
                 ' </body>\r\n'
                 '</html>\r\n'),
        # Spelling of "whether" corrected; this text is shown to end users.
        'text': ('You have received this notice because VFind failed to'
                 ' determine whether the action you have taken could cause'
                 ' a threat to the computer network. The name and the'
                 ' description of the error is listed below.  If this error'
                 ' persists, please contact your network administrator.'),
        'notice': 'Error',
        'warning': 'Error',
        'threat': 'Error',
    },
}


def htmlquote(text):
    """Return `text` with '&', '<' and '>' replaced by HTML entities.

    Uses the string's own replace() method, as configure() already does,
    rather than the deprecated string.replace() function; the result is
    the same.
    """
    # '&' must be handled first, otherwise the ampersands introduced by the
    # other two replacements would be escaped a second time.
    text = text.replace('&', '&amp;')
    text = text.replace('<', '&lt;')
    return text.replace('>', '&gt;')


# errno values that syserror() reports as 503 rather than 500; callers also
# use this list to decide whether to back off and retry.
resource_errors = []
for error in ['EAGAIN', 'ENOMEM', 'EBUSY', 'ENFILE', 'EMFILE',
              'ENOSPC', 'ETIMEDOUT', 'EDEADLK']:
    # Not every platform defines every name.  Look each one up among the
    # module globals (filled by "from errno import *") and skip the missing
    # ones; a plain lookup does the job without eval().
    try:
        resource_errors.append(globals()[error])
    except KeyError:
        pass

def syserror(message):
    """Normalise an error into a (code, errno, message) triple.

    `message` may be an exception carrying errno/strerror attributes, an
    (errno, message) pair, or anything else, which is reported through
    str() with errno 0.

    Returns (503, errno, message) when errno is listed in resource_errors,
    and (500, errno, message) otherwise.  errno is always an integer and
    message is never None, so callers can format them with '%d' and '%s'.
    """
    original = message
    try:
        errno, message = message.errno, message.strerror
    except AttributeError:
        try:
            errno, message = message
        except (ValueError, TypeError):
            errno, message = 0, None

    # An exception built from a single string (IOError('text')) has errno
    # and strerror set to None, and a two-character string "unpacks" into
    # two characters.  Neither is a real errno: fall back to errno 0 and
    # the text of the original object.  type() is compared directly because
    # `int` is not usable with isinstance() on the oldest interpreters this
    # file supports.
    if type(errno) is not type(0):
        errno, message = 0, None
    if message is None:
        message = str(original)

    if errno in resource_errors:
        return 503, errno, message
    else:
        return 500, errno, message


def configure():
    errors = 0

    vstk_home = os.environ.get('VSTK_HOME', '.')
    config.LOG_FILE = config.LOG_FILE.replace('$VSTK_HOME', vstk_home)
    config.PID_FILE = config.PID_FILE.replace('$VSTK_HOME', vstk_home)
    config.INI_FILE = config.INI_FILE.replace('$VSTK_HOME', vstk_home)

    try:
        ini = open(config.INI_FILE, 'r')
    except IOError:
        pass
    else:
        linenum = 0
        while True:
            linenum = linenum + 1
            line = ini.readline()
            if not line:
                break
            line = string.strip(line)
            if not line or line[0] == '#':
                continue
            line = string.split(line, '=')
            variable = string.join(string.split(string.upper(string.strip(line[0]))), '_')
            value = string.strip(string.join(line[1:], '='))
            try:
                setattr(config, variable, eval(value))
            except Exception:
                try:
                    setattr(config, variable, value)
                except Exception:
                    errors = errors + 1
                    sys.stderr.write('Syntax error in vfproxy.ini line %d: %s\n'
                                       % (linenum, string.join(line, '=')))

    variables = { 'code': None, 'cause': None, 'count': None, 'threats': None }
    for error in ERRORS.values():
        for variable in error.keys():
            variables[variable] = None
    variables = variables.keys()

    dir = os.path.join(vstk_home, 'data', 'vfproxy')
    try:
        files = os.listdir(dir)
    except OSError, message:
        code, errno, message = syserror(message)
        if errno != ENOENT:
            sys.stderr.write('%s: %s\n' % (dir, message))
            sys.exit(1)
        files = []

    for file in files:
        name = string.split(file, '.')
        if len(name) != 2:
            continue
        name, ext = name
        if ext not in ['text', 'html']:
            continue

        content = open(os.path.join(dir, file), 'r').read()
        if ext ==  'html':
            content = string.replace(content, '%', '%%')
            for variable in variables:
                 content = string.replace(content,
                                          '<vfproxy-' + variable + '/>',
                                          '%(' + variable + ')s')
        if ext ==  'text':
            content = htmlquote(content)

        content = string.replace(content, '\n', '\r\n')
        ERRORS.setdefault(name, {})[ext] = content

    for error in ERRORS.values():
        if error.has_key('html'):
            error['html'] = (
                'HTTP/1.0 %(code)s %(cause)s\r\n'
                'Content-Type: text/html\r\n'
                'Connection: close\r\n'
                '\r\n'
            ) + error['html']

    if errors:
        sys.exit(1)


def debug(message):
    """Trace hook for diagnostic messages; currently produces no output."""
    # To enable tracing: sys.stderr.write(message + '\n')
    return None

def dot(dot):
    """Progress-marker hook (one character per event); currently silent."""
    # To enable progress output: sys.stderr.write(dot)
    return None

def list_item(text):
    """Wrap `text`, HTML-escaped, in an <li> element."""
    escaped = htmlquote(text)
    return '<li>' + escaped + '</li>'

def error_message(code, cause, threats):
    """Render the complete error response for HTTP status `code`.

    code    -- HTTP status code; str(code) selects the ERRORS entry
    cause   -- short reason, available to the template as %(cause)s
    threats -- list of strings, each rendered as an escaped <li> item

    Returns the 'html' template filled in with the merged fields.
    """
    # NOTE(review): `cause` reaches the template without htmlquote(), unlike
    # the threat items -- confirm that callers never pass markup in it.
    items = map(list_item, threats)
    edit = {
        'code': str(code),
        'cause': cause,
        'count': len(threats),
        'threats': string.join(items, '\n'),
    }

    # Layer the fields: defaults first, then the entry for this code (if any).
    for overlay in (ERRORS['default'], ERRORS.get(str(code))):
        if overlay is not None:
            edit.update(overlay)

    # Pluralise the heading unless exactly one item is listed.
    if len(threats) != 1:
        edit['threat'] = edit['threat'] + 's'

    template = edit['html']
    return template % edit


class ProtocolError(Exception):
    """Raised by the VFind client when the daemon is not ready, rejects a
    command, or is not connected.  Raised as ProtocolError(0, text)."""

class BailOut(Exception):
    """Raised to abandon the current request; HTTPproxy.handle() and
    Watcher.watch() catch it."""

class bail(object):
    """Placeholder whose every use raises BailOut.

    Reading an attribute that is not otherwise defined, setting any
    attribute, or calling the object raises BailOut.  HTTPproxy installs it
    in place of its sockets and of some of its methods once a request has
    been shut down.
    """

    def __getattr__(*ignored):
        raise BailOut()

    def __setattr__(*ignored):
        raise BailOut()

    def __call__(*ignored):
        raise BailOut()

# Only one instance is ever needed; it takes over the class's name.
bail = bail()


def sendall(socket, buffer):
    """Send all of `buffer` through `socket`, repeating send() as needed.

    send() may accept only part of the data; whatever it did not take is
    sent again until nothing remains.  send() is called once even when
    `buffer` is empty.
    """
    pending = buffer
    while True:
        accepted = socket.send(pending)
        if accepted == len(pending):
            return
        pending = pending[accepted:]


class makefile(object):
    """Small file-like wrapper around a socket.

    readline() may receive more than one line; the surplus is kept in
    self.buffer (a list of strings) and is handed out first by later
    read() and readline() calls.
    """

    def __init__(self, socket, mode):
        # `mode` is accepted but not used.
        self.buffer = []
        self.socket = socket

    def flush(self):
        """No-op; write() sends its data immediately."""
        return None

    def close(self):
        """Close the underlying socket."""
        self.socket.close()

    def write(self, data):
        """Send all of `data` through the socket."""
        sendall(self.socket, data)

    def read(self, size):
        """Return up to `size` characters.

        Buffered data is served first, one buffered chunk per call; the
        socket is only read when nothing is buffered.
        """
        if not self.buffer:
            return self.socket.recv(size)

        head = self.buffer[0]
        if len(head) > size:
            # Hand out the front of the chunk and keep the remainder.
            self.buffer[0] = head[size:]
            return head[:size]

        del self.buffer[0]
        return head

    def readline(self):
        """Return the next line including its '\\n'.

        At end of stream the unterminated remainder is returned, or ''
        when there is nothing left.
        """
        # Receive until the newest chunk contains a newline or the peer
        # stops sending.
        while not (self.buffer and '\n' in self.buffer[-1]):
            data = self.socket.recv(512)
            if not data:
                break
            self.buffer.append(data)

        if not self.buffer:
            return ''

        chunks = self.buffer
        self.buffer = []

        # Position just past the first newline of the last chunk; 0 means
        # the chunk holds no newline at all.
        cut = chunks[-1].find('\n') + 1
        if cut and cut < len(chunks[-1]):
            # Keep whatever follows the newline for the next call.
            self.buffer = [chunks[-1][cut:]]
            chunks[-1] = chunks[-1][:cut]

        return ''.join(chunks)


class VFind(object):
    """Connection to the VFind scanning daemon, shared by all proxy threads.

    run() starts one background thread that owns the connection: it
    connects (and reconnects), reads the daemon's response lines and hands
    the results to the proxy threads waiting in check().

    self.clients maps a thread id to (lock, messages).  The lock is held
    while a scan is outstanding and released when its result is complete;
    `messages` collects what the daemon reported for that scan.
    self.mutex guards clients, server and readline.
    """

    def __init__(self, port):
        self.clients = {}
        self.mutex = thread.allocate_lock()
        self.server = None
        self.readline = None
        self.port = port
        self.retries = 0

    def connect(self):
        """Return (server socket, readline), connecting first if necessary.

        A new connection waits for the daemon's 'ready' greeting and then
        sends config.INIT_COMMANDS.  Raises ProtocolError when the daemon
        answers unexpectedly; socket errors propagate.
        """
        self.mutex.acquire()
        try:
            if not self.server or not self.readline:
                server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                server.connect((config.VFIND_ADDRESS, self.port))
                self.readline = makefile(server, 'r').readline
                self.server = server

                ready = string.strip(self.readline())
                debug('<< ' + ready)
                if string.lower(ready) != 'ready':
                    raise ProtocolError(0, 'VFind daemon not ready: ' + ready)
                for command in config.INIT_COMMANDS:
                    self.command(command)

            return self.server, self.readline

        finally:
            self.mutex.release()

    def command(self, command):
        """Send one command line and require a response starting with 'ok'.

        Raises ProtocolError(0, response) for any other response.
        """
        debug('>> ' + command)
        sendall(self.server, command + '\r\n')
        response = string.strip(self.readline())
        debug('<< ' + response)
        if string.lower(string.split(response, ' ')[0]) != 'ok':
            raise ProtocolError(0, response)

    def _wake(self, lock):
        """Release a waiter's lock, tolerating one that is already released."""
        try:
            lock.release()
        except thread.error:
            # The waiter was already woken (for example by a duplicate
            # 'done' line); there is nothing left to do.
            pass

    def disconnect(self, message):
        """Close the daemon connection and wake every waiting client.

        Each waiter gets 'ERROR: <message>' appended to its results.  Does
        nothing when no connection is open.
        """
        self.mutex.acquire()
        try:
            if not self.server:
                return

            for client in self.clients.values():
                client[1].append('ERROR: ' + message)
                self._wake(client[0])

            self.server.close()
            self.server = None
            self.readline = None
        finally:
            self.mutex.release()

    def receive(self):
        """Return the next response line, stripped, or '' on failure.

        On end of stream or on an error the connection is torn down and
        every waiting client is notified.  After a failed attempt the
        reader backs off for `retries` seconds, growing by one per
        consecutive failure up to config.MAX_RETRY_WAIT.
        """
        message = 'EOF'
        delay = 0

        try:
            server, readline = self.connect()
            response = readline()
        except (IOError, OSError, socket.error, ProtocolError), message:
            code, errno, message = syserror(message)
            # Back off after every failure, not only resource shortages:
            # a refused connection would otherwise be retried in a tight
            # loop for as long as the daemon is down.
            if self.retries < config.MAX_RETRY_WAIT:
                self.retries = self.retries + 1
            delay = self.retries
            response = ''

        if response:
            debug('<< ' + string.rstrip(response))
            self.retries = 0
        else:
            # Tell the waiters first, then sleep, so they are not kept
            # waiting for the length of the back-off.
            self.disconnect('VFind daemon disconnected: ' + message)
            if delay:
                time.sleep(delay)

        return string.strip(response)

    def run(self):
        """Start the background reader thread."""
        thread.start_new_thread(self.handle, ())

    def handle(self):
        """Reader loop: dispatch daemon responses to the waiting clients.

        Lines have the form "<id> <response> <rest>".  'queued' and
        'scanning' are ignored, 'infected' records a threat, any response
        starting with 'e' records an error and completes the scan, and
        'done' completes the scan.

        Lines that cannot be parsed, and responses for an id nobody is
        waiting on, are skipped: an exception here would end this thread
        and leave every later check() blocked forever.
        """
        while True:
            line = self.receive()
            if not line:
                continue

            words = string.split(line, ' ')
            if len(words) < 2 or not words[1]:
                debug('?? ' + line)
                continue
            id, response = words[:2]
            line = string.join(words[2:], ' ')
            response = string.lower(response)

            if response in ('queued', 'scanning'):
                continue

            try:
                id = int(id)
            except ValueError:
                debug('?? ' + line)
                continue

            self.mutex.acquire()
            try:
                client = self.clients.get(id)
                if client is not None:
                    if response == 'infected':
                        client[1].append(string.split(line, ' : ')[0])

                    if response[0] == 'e':
                        client[1].append(response + ' ' + line)
                        response = 'done'

                    if response == 'done':
                        self._wake(client[0])

            finally:
                self.mutex.release()

    def check(self, filename):
        """Have the daemon scan `filename`; block until the result is in.

        Returns the list of messages collected for this scan: empty when
        nothing was reported, otherwise threat names and/or 'ERROR ...'
        strings.  The calling thread's id identifies the request.
        """
        id = thread.get_ident()
        try:
            self.mutex.acquire()
            try:
                client = thread.allocate_lock()
                self.clients[id] = (client, [])
                client.acquire()
            finally:
                self.mutex.release()

            messages = self.clients[id][1]
            try:
                self.sendline('%d SCAN/FILE %s' % (id, filename))
            except (IOError, OSError, socket.error, ProtocolError), message:
                messages.append('ERROR %d: %s' % syserror(message)[1:])
            except Exception, message:
                messages.append('ERROR: %s: %s' % (sys.exc_info()[0], message))
            else:
                client.acquire() # wait for server release
                client.release()

            return self.clients[id][1]

        finally:
            try:
                del self.clients[id]
            except KeyError:
                pass

    def sendline(self, command):
        """Send one line to the daemon.

        Raises ProtocolError(0, 'VFind server is down') when no connection
        is open.
        """
        self.mutex.acquire()
        try:
            server = self.server
            if not server:
                raise ProtocolError(0, 'VFind server is down')
        finally:
            self.mutex.release()

        debug('>> ' + command)
        sendall(server, command + '\r\n')


class Watcher(object):
    """Times out proxy threads that have been inactive for too long.

    Threads report activity through ping().  A background thread started
    by run() calls timeout() on every thread whose last ping is older than
    config.THREAD_TIMEOUT seconds.

    self.threads maps a thread id to (proxy, last ping time, activity) and
    is guarded by self.mutex.
    """

    def __init__(self):
        self.threads = {}
        self.next = time.time() + config.THREAD_TIMEOUT
        self.mutex = thread.allocate_lock()

    def run(self):
        """Start the background watcher thread."""
        thread.start_new_thread(self.watch, ())

    def watch(self):
        """Watcher loop: purge, then sleep for as long as purge() asks."""
        while True:
            try:
                time.sleep(self.purge())
            except BailOut:
                dot('?')

    def ping(self, who, activity=None):
        """Record that proxy `who` is alive.

        `activity` describes what it is doing; None keeps the previously
        recorded description ('' when there is none).
        """
        self.mutex.acquire()
        try:
            if activity is None:
                activity = self.threads.get(who.id, ('',))[-1]
            self.threads[who.id] = (who, time.time(), activity)
        finally:
            self.mutex.release()

    def delete(self, id):
        """Stop watching thread `id`; unknown ids are ignored."""
        self.mutex.acquire()
        try:
            try:
                del self.threads[id]
            except KeyError:
                pass
        finally:
            self.mutex.release()

    def purge(self):
        """Time out expired threads; return the seconds until the next pass.

        The returned delay is the time left until the oldest thread that is
        still live (not expired, not gagged) reaches config.THREAD_TIMEOUT,
        or THREAD_TIMEOUT itself when there is no such thread.
        """
        # Work on a snapshot so the registry is not locked while calling
        # into the proxies.
        self.mutex.acquire()
        try:
            threads = self.threads.items()
        finally:
            self.mutex.release()

        debug('Purge: %d threads' % (len(threads),))

        oldest = 0      # greatest age among the threads still being watched
        now = time.time()
        cutoff = now - config.THREAD_TIMEOUT
        for id, (who, timestamp, activity) in threads:
            if timestamp <= cutoff:
                debug('Purging thread ' + str(id))
                try:
                    who.mutex.acquire()
                    try:
                        who.timeout(activity)
                    finally:
                        who.mutex.release()
                except(IOError, OSError, socket.error), message:
                    debug('Thread %d purge failed: %s' % (id, str(message)))
                except BailOut:
                    # The proxy was already gagged or closed.  Carry on with
                    # the remaining threads; letting this escape would abort
                    # the pass and make watch() purge again without sleeping.
                    debug('Thread %d purge failed: %s' % (id, 'already shut down'))
            elif oldest < now - timestamp:
                who.mutex.acquire()
                try:
                    gagged = who.timeout is bail
                finally:
                    who.mutex.release()
                if not gagged:
                    oldest = now - timestamp

        # Sleep until the oldest live thread would expire.  Every age
        # recorded above is below THREAD_TIMEOUT, so the delay is positive.
        # Sleeping for the age itself would let a thread run for up to
        # about twice the timeout before being noticed.
        delay = config.THREAD_TIMEOUT - oldest
        debug('Next purge in %d seconds' % (delay,))
        return delay


class NoWatcher(object):
    """Stand-in for Watcher, used by main() when no thread timeout is
    configured.  Every method accepts the same arguments and does nothing."""

    def ping(self, who, activity=None):
        """Ignore the activity report."""
        return None

    def run(self):
        """Start nothing."""
        return None

    def delete(self, id):
        """Nothing is tracked, so nothing is removed."""
        return None


class HTTPproxy(object):
    """Serves one client connection of the scanning proxy.

    run() starts a thread that reads the client's request, forwards it to
    the target server (or to the configured upstream proxy), stores POST
    data and response content in a temporary file, has them scanned, and
    relays them only when the scanner reports nothing.  Problems are
    reported to the client as an HTML error page.

    self.mutex serialises this thread against the watcher thread, which
    may call timeout() at any moment.  Once the connection has been
    "gagged", timeout/fail/run are replaced by the `bail` object, so any
    further use raises BailOut.
    """

    def __init__(self, client, scanner, watcher):
        self.activity = None
        self.server = None
        self.work = None                # path of the current temporary file
        self.scanner = scanner
        self.content_length = None      # from the request's Content-Length
        self.watcher = watcher
        self.host = None
        self.port = 80
        self.client = makefile(client, 'r+')
        self.mutex = thread.allocate_lock()

    def ping(self, activity=None):
        """Report activity to the watcher."""
        self.watcher.ping(self, activity)

    def run(self):
        """Start the thread that serves this connection."""
        thread.start_new_thread(self.handle, ())

    def fail(self, code, cause, data):
        """Send an error page to the client, then raise BailOut.

        code  -- HTTP status code
        cause -- short reason, used in the status line and the page
        data  -- list of detail strings for the page

        Nothing is sent when the connection was already gagged.  Always
        raises BailOut.
        """
        dot('#')
        self.mutex.acquire()
        try:
            gagged = self.gag()
        finally:
            self.mutex.release()
        if not gagged:
            self.client.flush()
            message = error_message(code, cause, data)
            self.client.write(message)
            self.client.flush()
        raise BailOut()

    def tempfile(self):
        """Return the temporary file path for this thread."""
        # NOTE(review): the name is predictable and TEMP_DIR defaults to the
        # shared /tmp -- consider a safer way of creating this file.
        return '%s/VFind_%06x' % (config.TEMP_DIR, self.id)

    def copy(self, source, target, size=None):
        """Copy from `source` to `target` in chunks of up to 8192.

        Stops after `size` characters when a size is given, otherwise when
        `source` returns no more data.  Pings the watcher after every read
        and every write.
        """
        bufsiz = 8192
        while True:
            if size is not None and size < bufsiz:
                bufsiz = size

            buf = source.read(bufsiz)

            self.ping()
            if not buf:
                break
            target.write(buf)
            self.ping()
            if size is not None:
                size = size - len(buf)
                if size <= 0:
                    break

    def load(self, source, what, size=None):
        """Store data read from `source` in a fresh temporary file."""
        self.ping('receiving ' + what)
        self.work = self.tempfile()
        target = open(self.work, 'w')
        try:
            self.copy(source, target, size)
        finally:
            # Close the file even when the copy is interrupted by an error
            # or by a timeout.
            target.close()

    def dump(self, target, what, size=None):
        """Send the temporary file's content to `target` and remove it."""
        self.ping('sending ' + what)
        source = open(self.work, 'r')
        try:
            # The name can go right away; the open handle keeps the data
            # readable until it is closed.
            try:
                os.remove(self.work)
            except OSError:
                pass
            self.work = None
            self.copy(source, target, size)
        finally:
            source.close()

    def receive_header(self, source, what):
        """Read header lines from `source` up to the first empty line.

        Returns the lines as a list.  Once at least two lines have been
        collected, a line starting with whitespace is joined onto the
        previous one.  Apart from the first line, lines without ':' are
        dropped.
        """
        self.ping('receiving file header from ' + what)
        header = []
        while True:
            line = string.rstrip(source.readline())
            self.ping()
            if not line:
                break

            if len(header) > 1 and line[0] in string.whitespace:
                header[-1] = header[-1] + ' ' + string.lstrip(line)
            elif not header or ':' in line:
                header.append(line)

        return header

    def scan(self, what):
        """Have the temporary file scanned.

        Returns True when nothing was reported.  Otherwise fail() is
        called (which raises BailOut): 503 when the only report is an
        error, 403 when threats were found.
        """
        self.ping('scanning ' + what)
        infected = self.scanner.check(self.work)
        self.ping()
        if infected:
            if len(infected) == 1 and infected[0][:5] == 'ERROR':
                self.fail(503, 'Could not scan for threats', infected)
            else:
                self.fail(403, 'Forbidden ' + what, infected)
        return not infected

    def handle(self):
        """Thread entry point: serve the request, then always clean up."""
        self.id = thread.get_ident()
        # Shown to the user in the timeout page ("Timeout <activity>");
        # spelling corrected.
        self.ping('initializing')

        try:
            try:
                self.process()
            finally:
                self.mutex.acquire()
                try:
                    self.close()
                finally:
                    self.mutex.release()
            dot('.')
        except BailOut:
            pass
        except (IOError, OSError, socket.error, thread.error), message:
            code, errno, message = syserror(message)
            sys.stderr.write('ERROR %d: %s\n' % (errno, message))

    def connect(self):
        """Return a socket connected to the upstream proxy when one is
        configured, otherwise to the request's host and port."""
        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        if config.PROXY_PORT:
            server.connect((config.PROXY_ADDRESS, config.PROXY_PORT))
        else:
            server.connect((self.host, self.port))
        return server

    def parse_url(self, url):
        """Set self.host / self.port from `url`; return the URL without
        scheme and host.

        Calls fail(400, ...) for a bad port, or when the URL names no host
        and transparent mode is off.
        """
        urldata = urlparse.urlparse(url)
        self.host = urldata[1]

        if ':' in self.host:
            try:
                # The split is inside the try: a host part with more than
                # one ':' makes the unpacking raise ValueError, which used
                # to escape uncaught and end the thread without a reply.
                self.host, port = string.split(self.host, ':')
                port = int(port)
                if not (0 < port <= 0xffff):
                    raise ValueError()
            except ValueError:
                self.fail(400, 'Bad port number in URL', [url])
            else:
                self.port = port

        if not self.host:
            if config.TRANSPARENT:
                self.port = config.PROXY_PORT
                self.host = config.PROXY_ADDRESS
            else:
                self.fail(400, 'No host in URL', [url])

        return urlparse.urlunparse(('', '') + urldata[2:])

    def parse_header(self, header):
        """Validate and rewrite the client's request header.

        Returns (header, method).  The request is downgraded to HTTP/1.0,
        Connection and Proxy-Connection are forced to 'close', and unless
        a non-transparent upstream proxy is in use the request URL is
        reduced to its path.  Content-Length is stored in
        self.content_length.

        Calls fail() (raising BailOut) with 400 for an empty or malformed
        request, and with 501 for an unsupported method.
        """
        if not header:
            self.fail(400, 'No data from client', ['The client request is empty'])

        request = string.split(header[0], ' ') # GET http://example.com/url HTTP/1.0
        method = request[0]

        if method not in ['GET', 'HEAD', 'POST']:
            self.fail(501, 'Unsupported access method: %s' % (method,),
                      ['The %s access method is not supported;'
                       ' only GET, HEAD, and POST are supported' % (method,)])

        # A request line without URL or protocol version used to raise an
        # uncaught IndexError below; answer with 400 instead.
        if len(request) < 3:
            self.fail(400, 'Malformed request line', [header[0]])

        if request[2] == 'HTTP/1.1':     # force http/1.0 - no keepalive
            request[2] = 'HTTP/1.0'

        for i in range(len(header)):
            words = string.split(string.lower(header[i]), ':')
            if words[0] == 'content-length':
                # A non-numeric or negative length used to raise an
                # uncaught ValueError (here or later in the socket read).
                try:
                    length = int(words[1])
                    if length < 0:
                        raise ValueError()
                except (ValueError, IndexError):
                    self.fail(400, 'Bad Content-Length header', [header[i]])
                else:
                    self.content_length = length
            elif words[0] == 'connection':
                header[i] = 'Connection: close'
            elif words[0] == 'proxy-connection':
                header[i] = 'Proxy-Connection: close'

        if config.PROXY_PORT and not config.TRANSPARENT:
            # The upstream proxy needs the full URL; only extract the host.
            self.parse_url(request[1])
        else:
            request[1] = self.parse_url(request[1])

        header[0] = string.join(request, ' ')

        return header, method

    def send_header(self, header):
        """Send the header lines and the terminating blank line to the
        server."""
        sendall(self.server, string.join(header + ['', ''], '\r\n'))

    def post_data(self):
        """Receive the POST body, scan it, and forward it to the server.

        Calls fail(411, ...) when the request had no Content-Length.
        """
        if self.content_length is None:
            self.fail(411, 'Content length required for POST',
               ['Client sent a POST request with no Content-Length header'])

        self.load(self.client, 'POST data', self.content_length)
        if self.scan('POST data'):
            self.dump(makefile(self.server, 'w'), 'POST data to ' + self.host)

    def transfer_content(self):
        """Receive the server's response, scan it, and relay it to the
        client unless the connection has been gagged in the meantime."""
        server = makefile(self.server, 'r')
        try:
            header = self.receive_header(server, self.host)
            self.load(server, 'file content from ' + self.host, None)
        finally:
            server.close()

        if self.scan('file content'):
            self.mutex.acquire()
            try:
                gagged = self.gag()
            finally:
                self.mutex.release()
            if not gagged:
                self.client.write(string.join(header+['',''],'\r\n'))
                self.dump(self.client, 'file content to client')

    def process(self):
        """Serve one request from start to finish.

        `activity` names the step in progress so that an I/O error can be
        reported to the client as 'Could not <activity>'.
        """
        try:
            activity = 'receive request for %s'
            self.ping('receiving header from client ')
            header = self.receive_header(self.client, 'client')

            activity = 'parse request for %s'
            header, method = self.parse_header(header)

            activity = 'connect to %%s on port %d' % (self.port,)
            self.ping('connecting to ' + self.host)
            self.server = self.connect()

            activity = 'send header to %s'
            self.send_header(header)

            if method == 'POST':
                activity = 'POST data to %s'
                self.post_data()

            self.client.flush()

            # NOTE(review): for HEAD nothing is read back from the server or
            # relayed to the client here -- confirm that this is intended.
            if method in ['GET', 'POST']:
                activity = 'transfer content from %s to client'
                self.transfer_content()

            self.server.close()

        except (IOError, OSError, socket.error), message:
            code, errno, message = syserror(message)
            self.fail(code, 'Could not ' + activity % (self.host,),
                      ['ERROR %d: %s' % (errno, message)])

    def close(self):
        """Release everything held for this connection.

        Must be called with self.mutex held.  Unregisters from the
        watcher, removes the temporary file, closes both sockets and
        replaces them with `bail`.
        """
        debug(':: Close')
        assert(self.mutex.locked())
        if self.watcher is not None:
            self.watcher.delete(self.id)

        if self.work is not None:
            try:
                os.remove(self.work)
            except OSError:
                pass
        self.work = None

        if self.server is not None:
            try:
                self.server.close()
            except (IOError, OSError, socket.error):
                pass
        self.server = bail

        if self.client is not None:
            try:
                self.client.close()
            except (IOError, OSError, socket.error):
                pass
        self.client = bail

    def gag(self):
        """Claim the right to write the final response to the client.

        Must be called with self.mutex held.  Returns False to the first
        caller and True to every later one; from then on timeout, fail and
        run raise BailOut.
        """
        assert(self.mutex.locked())
        debug(':: GAG')
        gagged = self.timeout is bail
        self.timeout = bail
        self.fail = bail
        self.run = bail
        return gagged

    def timeout(self, activity):
        """Called by the watcher (with self.mutex held) for an inactive
        connection: send a 504 page unless already gagged, then close."""
        assert(self.mutex.locked())
        dot('!')
        debug('Thread %d timed out while %s' % (self.id, activity))
        gagged = self.gag()
        if not gagged:
            self.client.flush()
            self.client.write(error_message(504, 'Timeout',
                                            ['Timeout ' + activity]))
            self.client.flush()
        self.close()
        # The serving thread bails out at its next ping().
        self.ping = bail


def usage():
    """Print the command-line synopsis to stderr and exit with status 1."""
    program = sys.argv[0]
    synopsis = 'usage: %s [<VFind port> <listen port>]\n' % (program,)
    sys.stderr.write(synopsis)
    sys.exit(1)

def daemonize():
    """Detach from the terminal and continue as a background process.

    Forks twice; both intermediate parents exit with status 0, the second
    one after writing the daemon's pid to config.PID_FILE (when set).
    Between the forks the process changes to '/', starts a new session,
    clears the umask, and, when config.LOG_FILE is set, redirects stdin
    from /dev/null and stdout/stderr to the log file.

    When running as root with config.UNPRIVILEGED_USER set, the surviving
    process switches to that user's gid and uid.
    """
    if config.UNPRIVILEGED_USER:
        # Resolved before the first fork.
        uid, gid = pwd.getpwnam(config.UNPRIVILEGED_USER)[2:4]

    if os.fork() > 0:
        sys.exit(0)

    os.chdir('/')
    os.setsid()
    os.umask(0)

    if config.LOG_FILE:
        sys.stdin.flush()
        sys.stdout.flush()
        sys.stderr.flush()
        stdin = open('/dev/null', 'r')
        stdout = open(config.LOG_FILE, 'a+')
        stderr = open(config.LOG_FILE, 'a+', 0)
        os.dup2(stdin.fileno(), sys.stdin.fileno())
        os.dup2(stdout.fileno(), sys.stdout.fileno())
        os.dup2(stderr.fileno(), sys.stderr.fileno())

    pid = os.fork()
    if pid > 0:
        if config.PID_FILE:
            pidfile = open(config.PID_FILE, 'w', 0)
            try:
                pidfile.write('%s\n' % (pid,))
            finally:
                pidfile.close()
        sys.exit(0)

    if os.getuid() == 0 and config.UNPRIVILEGED_USER:
        # setgid()/setuid() do not touch the supplementary group list, so
        # root's extra groups would survive the switch.  Replace the list
        # first, while the process is still privileged enough to do so.
        if hasattr(os, 'setgroups'):
            os.setgroups([gid])
        os.setgid(gid)
        os.setuid(uid)

def main():
    """Program entry point: configure, listen, daemonize, serve forever.

    With two command-line arguments, they replace the VFind port and the
    listen port.  Any other argument count, or a non-numeric port, prints
    the usage text and exits.
    """
    configure()

    argc = len(sys.argv)
    if argc == 3:
        try:
            config.VFIND_PORT = int(sys.argv[1])
            config.LISTEN_PORT = int(sys.argv[2])
        except ValueError:
            usage()
    elif argc != 1:
        usage()

    # The listening socket is bound before daemonize() runs.
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server.bind((config.LISTEN_ADDRESS, config.LISTEN_PORT))
    server.listen(5)

    # Thread timeouts are optional; NoWatcher offers the same interface.
    if not config.THREAD_TIMEOUT:
        watcher = NoWatcher()
    else:
        watcher = Watcher()
    vfind = VFind(config.VFIND_PORT)

    banner = ('VFind HTTP Proxy Version %d,'
              ' Release %d,'
              ' Patchlevel %d (%s)\n' % VERSION)
    sys.stderr.write(banner)

    daemonize()
    watcher.run()
    vfind.run()

    # Accept loop: one HTTPproxy thread per client.  After a resource-type
    # error, wait one second longer per consecutive failure (limited by
    # config.MAX_RETRY_WAIT); a successful accept resets the wait.
    retries = 0
    while True:
        try:
            debug(':: Client connected')
            connection = server.accept()[0]
            HTTPproxy(connection, vfind, watcher).run()
        except (IOError, OSError, socket.error, thread.error), message:
            code, errno, message = syserror(message)
            sys.stderr.write('ERROR %d: %s\n' % (errno, message))
            if errno in resource_errors:
                if retries < config.MAX_RETRY_WAIT:
                    retries = retries + 1
                time.sleep(retries)
        else:
            retries = 0

if __name__ == '__main__':
    main()
