#!/usr/bin/env python
# Copyright 2008 (c) Niels Provos
# All rights reserved.
#
# Probes your local resolver to determine the distribution of its
# source and query id space; if it's not random, your resolver can be
# poisoned more easily.
#
# This tool requires that the Unix utility dig is in the path; if you
# want to run this on Windows, you need to reimplement the
# ResolveAddress function.
#
#
# How To Use
#
# Simple:
# $ ./dnspredict.py          - uses default resolver and provides terse
#                              information
# $ ./dnspredict.py 10.0.0.1 - chooses 10.0.0.1 as resolver and provides
#                              terse information
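# Complex:
# $ ./dnspredict.py --output /tmp/model - probes the resolver and saves the
#                                         model to /tmp/model
# $ ./dnspredict.py --input /tmp/model  - analyzes a previously saved model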
#

import Queue
import getopt
import marshal
import math
import os
import random
import subprocess
import sys
import threading
import time

# Number of testing queries (under osd.honeyd.org) sent per probe run;
# main() overrides it with the -s option.
sample_size = 256

def fact(n):
    """Returns n! for a non-negative integral n (an int or a whole float)."""
    assert n >= 0
    assert math.floor(n) == n
    # Rejects floats too large to be incremented exactly (e.g. infinity).
    assert n+1 > n
    product = 1
    for k in range(2, int(n) + 1):
        product *= k
    return product

def ResolveAddress(query, resolver=None):
    """DNS resolves an address by running dig.

    Args:
      query: the name whose A record is looked up.
      resolver: optional IP address of the resolver to ask; when None, dig
        uses the system's default resolver.

    Returns:
      dig's "+short" output with surrounding whitespace removed, or an
      empty string if dig could not be started.
    """
    # Build an argument list instead of a shell command line so that neither
    # the query nor the (user supplied) resolver is interpreted by a shell.
    command = ['dig']
    if resolver:
        command.append('@%s' % resolver)
    command.extend(['+short', query, 'a'])

    try:
        process = subprocess.Popen(command, stdout=subprocess.PIPE,
                                   universal_newlines=True)
    except OSError:
        # dig is missing or not executable; like the old shell based
        # version, report that as "no answer" rather than failing.
        return ''
    # communicate() waits for dig to exit and closes the pipe.
    data = process.communicate()[0]

    return data.strip()

def RunToSum(data):
    """Turns per-length run counts into "at least this long" counts.

    Args:
      data: dict mapping a run length (a positive integer) to the number of
        runs of exactly that length.

    Returns:
      A dict mapping every length r from 1 to max(data) to the number of
      runs of length >= r, or {} if `data` is empty.
    """
    if not data:
        return {}

    # Emit every length, not just those that occurred: a length without
    # runs of exactly that size still has all the longer runs counted
    # towards it, and consumers look lengths up with a default of 0.
    sumrun = {}
    total = 0
    for length in range(max(data), 0, -1):
        total += data.get(length, 0)
        sumrun[length] = total

    return sumrun

def ZScores(data, num):
    """Z-scores of observed run counts against their expected values.

    Args:
      data: dict mapping a run length r to the number of runs of length
        >= r (as produced by RunToSum).
      num: number of elements in the sequence the runs were taken from.

    Returns:
      A list of at most four z-scores, for r = 1, 2, ... in order.
    """
    if not data:
        maxkey = 5
    else:
        maxkey = min(10, max(data.keys()))
    # Only the first four scores are ever reported.
    maxkey = min(maxkey, 4)

    result = []
    for r in range(1, maxkey + 1):
        actual = data.get(r, 0)
        expected = float((r + 1) * num - (r**2 + r - 1)) / float(fact(r + 2))
        if expected <= 0:
            # The sequence is too short for runs this long, so the z-score
            # is undefined (and sqrt would fail). The numerator is concave
            # in r and positive at r = 1, so once it is non-positive it
            # stays that way for every longer run; stop here.
            break
        stddev = math.sqrt(expected)
        result.append(float(actual - expected) / stddev)

    return result

def MeanStdDevDiff(data):
    """Summarizes the differences between consecutive elements of `data`.

    Args:
      data: a sequence of numbers.

    Returns:
      A (median, mean, stddev) tuple describing data[i+1] - data[i], where
      stddev is the sample standard deviation (0.0 when there is only one
      difference). Returns (0, 0, 0) if `data` has fewer than two elements.
    """
    differences = [later - earlier
                   for (later, earlier) in zip(data[1:], data[:-1])]
    if not differences:
        return (0, 0, 0)

    num = len(differences)
    mean = sum(differences) / float(num)
    if num > 1:
        tmp = sum([(x - mean) ** 2 for x in differences])
        stddev = math.sqrt(tmp / float(num - 1))
    else:
        # The sample standard deviation of a single value is undefined
        # (division by n - 1 == 0).
        stddev = 0.0

    differences.sort()
    middle = num // 2
    if num % 2 == 0:
        # Average the two middle elements, at indices middle - 1 and middle.
        median = (differences[middle - 1] + differences[middle]) / 2.0
    else:
        median = differences[middle]

    return (median, mean, stddev)
    
def RunsUpDown(data):
    """Runs-up-and-down randomness test for a sequence of numbers.

    The differences between neighbouring elements are split into maximal
    runs of increases, decreases and repeats. The runs up and the runs down
    are then scored against their expected counts with ZScores.

    Returns:
      A (z-scores for runs up, z-scores for runs down) tuple, or
      ([0], [0]) if `data` has fewer than two elements.
    """
    steps = [later - earlier for (later, earlier) in zip(data[1:], data[:-1])]
    if not steps:
        return ([0], [0])

    # direction (-1 down, 0 flat, 1 up) -> {run length: number of runs}
    runs = {-1: {}, 0: {}, 1: {}}

    def record(direction, length):
        lengths = runs[direction]
        lengths[length] = lengths.get(length, 0) + 1

    current = None
    length = 1
    for step in steps:
        direction = (step > 0) - (step < 0)
        if direction == current:
            length += 1
            continue
        if current is not None:
            record(current, length)
        current = direction
        length = 1
    # The final run is still open when the loop ends.
    record(current, length)

    ups = ZScores(RunToSum(runs[1]), len(data))
    downs = ZScores(RunToSum(runs[-1]), len(data))
    return (ups, downs)

def Histogram(data):
    """Counts how often each value occurs in `data`.

    Returns:
      A list of (value, count) pairs sorted by count, most frequent first.
      Values with equal counts are in no particular order.
    """
    hist = {}
    for value in data:
        hist[value] = hist.get(value, 0) + 1

    # A key function instead of the deprecated cmp-style comparator: it is
    # evaluated once per item and is still a stable sort.
    return sorted(hist.items(), key=lambda item: item[1], reverse=True)

class DNSQuery(threading.Thread):
    """Worker thread that resolves queued names until told to stop.

    Every name taken from the input queue is looked up with ResolveAddress
    and the (name, answer) pair is put on the output queue. The string
    'stop' is the shutdown sentinel.
    """

    def __init__(self, input_queue, output_queue, resolver=None):
        threading.Thread.__init__(self)
        self._resolver = resolver
        self._output_q = output_queue
        self._input_q = input_queue

    def run(self):
        query = self._input_q.get()
        # 'stop' lets us know that we are done
        while query != 'stop':
            self._output_q.put((query, ResolveAddress(query, self._resolver)))
            query = self._input_q.get()

class RandomSequence:
    """Randomness statistics for a sequence of integers (ports or qids).

    All statistics are computed once, at construction time.
    """
    # Spread test: the standard deviation of the differences between
    # consecutive values must reach this for the sequence to count as random.
    _MIN_STDDEV = 15000
    # Thresholds for randomCategory: below _POOR_STDDEV is POOR, below
    # _FAIR_STDDEV is FAIR, anything else is GOOD.
    _POOR_STDDEV = 3000
    _FAIR_STDDEV = 7000
    # Runs z-scores with a larger magnitude are flagged as not random.
    # NOTE(review): 1.95 looks like the ~95% two-sided normal critical value,
    # widened by a factor of 1.5 -- confirm intent.
    _MAX_ZSCORE = 1.5 * 1.95

    def __init__(self, sequence):
        self._sequence = sequence

        # (median, mean, stddev) of the consecutive differences.
        self._meanstd = MeanStdDevDiff(sequence)
        # (value, count) pairs, most frequent first.
        self._hist = Histogram(sequence)
        # (z-scores of runs up, z-scores of runs down).
        self._zscores = RunsUpDown(sequence)

    def _isBadScore(self, zscore):
        """True if a runs z-score is too extreme for a random sequence."""
        return abs(zscore) > self._MAX_ZSCORE

    def __str__(self):
        output = 'Median: %.1f Mean: %d StdDev: %.0f' % self._meanstd
        # Same threshold as isRandom, not a separate copy of the number.
        if self._meanstd[2] < self._MIN_STDDEV:
            output += ' (NR)'
        output += '\n'

        # Up to five values seen more than once; the histogram is sorted by
        # count, so the first singleton ends the list.
        repeats = ''
        for (port, count) in self._hist[:5]:
            if count <= 1:
                break
            repeats += '%d:%d ' % (port, count)
        output += 'Top repeating: ' + (repeats or 'None') + '\n'

        for (off, name) in [(0, 'Up'), (1, 'Down')]:
            output += '\tRuns (%s) (Z-Score):\t' % name
            for (i, zscore) in enumerate(self._zscores[off]):
                output += '%d=%f ' % (i + 1, zscore)
                if self._isBadScore(zscore):
                    output += '(NR) '
            output += '\n'

        return output

    @property
    def isRandom(self):
        """True if the sequence passes both the spread and the runs test.

        The standard deviation of consecutive differences must be at least
        _MIN_STDDEV, and at most one runs z-score (up and down combined) may
        exceed _MAX_ZSCORE in magnitude.
        """
        if self._meanstd[2] < self._MIN_STDDEV:
            return False
        bad = 0
        for zscores in self._zscores:
            for zscore in zscores:
                if self._isBadScore(zscore):
                    bad += 1
        return bad < 2

    @property
    def stdDev(self):
        """Standard deviation of the differences between consecutive values."""
        return self._meanstd[2]

    @property
    def randomCategory(self):
        """'POOR', 'FAIR' or 'GOOD' by spread, plus the isRandom verdict."""
        if self.isRandom:
            randcat = 'random'
        else:
            randcat = 'not random'
        stddev = self._meanstd[2]
        if stddev < self._POOR_STDDEV:
            return 'POOR (%s)' % randcat
        elif stddev < self._FAIR_STDDEV:
            return 'FAIR (%s)' % randcat
        return 'GOOD (%s)' % randcat
            

class Model:
    """Port and query id samples collected from one resolver, with stats.

    Every answer is expected to be a dotted quad whose first two octets are
    read as the resolver's source port and whose last two are read as the
    query id. Answers to *.osd.honeyd.org form the testing set and answers
    to *.dso.honeyd.org form the validation set.
    """
    def __init__(self, resolver, external_ip, queue):
        """Drains `queue` and computes randomness statistics.

        Args:
          resolver: the resolver the samples came from (may be None).
          external_ip: the IP address that contacted the name server, or
            None if unknown.
          queue: a Queue of (query, result) pairs; it is emptied.

        Raises:
          ValueError: if a query belongs to neither probe domain.
        """
        self._resolver = resolver
        self._external_ip = external_ip
        self._testing_set = []
        self._validation_set = []

        while not queue.empty():
            (query, result) = queue.get()
            parsed = self._ParsePortQid(result)
            # sometimes we cannot get a (usable) result
            if parsed is None:
                continue
            (port, qid) = parsed

            if query.endswith("osd.honeyd.org"):
                self._testing_set.append((query, result, port, qid))
            elif query.endswith("dso.honeyd.org"):
                self._validation_set.append((query, result, port, qid))
            else:
                # An explicit error rather than an assert, which python -O
                # would silently strip.
                raise ValueError("unknown query domain: %s" % query)

        # Answers can arrive in any order; sort the testing set by the
        # numeric query prefix so the sequences follow query numbering.
        self._testing_set.sort(key=lambda x: int(x[0].split('.')[0]))

        self._port = [x[2] for x in self._testing_set]
        self._qid = [x[3] for x in self._testing_set]

        self._validation_port = [x[2] for x in self._validation_set]
        self._validation_qid = [x[3] for x in self._validation_set]

        self._random_port = RandomSequence(self.port)
        self._random_qid = RandomSequence(self.qid)

        self._random_v_port = RandomSequence(self.validation_port)
        self._random_v_qid = RandomSequence(self.validation_qid)

    @staticmethod
    def _ParsePortQid(result):
        """Decodes a dotted-quad answer into (port, qid); None if unusable."""
        if not result:
            return None
        octets = result.split('.')
        if len(octets) != 4:
            return None
        try:
            (o1, o2, o3, o4) = [int(octet) for octet in octets]
        except ValueError:
            # e.g. a non-numeric answer that happens to have four labels
            return None
        return (o1 * 256 + o2, o3 * 256 + o4)

    @property
    def resolver(self):
        """The resolver against which this model was computed."""
        return self._resolver

    @property
    def external_ip(self):
        """The IP address that contacted the name server."""
        return self._external_ip

    @property
    def testing_set(self):
        """Returns all instances in the testing set."""
        return self._testing_set

    @property
    def validation_set(self):
        """Returns all instances in the validation set."""
        return self._validation_set

    @property
    def port(self):
        """Source ports of the testing set, in query order."""
        return self._port

    @property
    def qid(self):
        """Query ids of the testing set, in query order."""
        return self._qid

    @property
    def validation_port(self):
        """Source ports of the validation set."""
        return self._validation_port

    @property
    def validation_qid(self):
        """Query ids of the validation set."""
        return self._validation_qid

    @property
    def random_port(self):
        """RandomSequence statistics for the testing-set ports."""
        return self._random_port

    @property
    def random_qid(self):
        """RandomSequence statistics for the testing-set query ids."""
        return self._random_qid

    @property
    def random_v_port(self):
        """RandomSequence statistics for the validation-set ports."""
        return self._random_v_port

    @property
    def random_v_qid(self):
        """RandomSequence statistics for the validation-set query ids."""
        return self._random_v_qid

class Controller:
    """Fans DNS probe queries out to a pool of DNSQuery worker threads."""

    def __init__(self, num_threads, resolver=None):
        self._resolver = resolver
        self._num_threads = num_threads
        self._work_queue = Queue.Queue()
        self._result_queue = Queue.Queue()

        # Workers start right away and block until work is queued.
        self._threads = []
        for _ in range(num_threads):
            worker = DNSQuery(self._work_queue, self._result_queue, resolver)
            self._threads.append(worker)
            worker.start()

    def _run(self, suffix, nqueries, increment=True):
        """Queues `nqueries` lookups of names under `suffix`.

        With `increment` the names are '0.<suffix>', '1.<suffix>', ...;
        otherwise every name is 'fixed.<suffix>'.
        """
        for index in range(nqueries):
            if increment:
                prefix = str(index)
            else:
                prefix = 'fixed'
            # we could add an instance to our queries to create a
            # different namespace for each tool run.
            self._work_queue.put('%s.%s' % (prefix, suffix))

    def _terminate_threads(self):
        """Sends one 'stop' per worker and waits for all of them to exit."""
        for _ in range(self._num_threads):
            self._work_queue.put('stop')

        for worker in self._threads:
            worker.join()

    def ProduceModel(self):
        """Runs all probe queries and builds a Model from the answers.

        Sends `sample_size` testing queries under osd.honeyd.org and 64
        validation queries under dso.honeyd.org, stops the workers, then
        looks up ns.osd.honeyd.org to learn the external IP.
        """
        began = time.time()
        print >>sys.stderr, "Starting DNS queries."

        self._run("osd.honeyd.org", sample_size)
        self._run("dso.honeyd.org", 64)

        # The 'stop' sentinels are queued after every query, so once all
        # workers are joined every answer is in the result queue.
        self._terminate_threads()

        elapsed = time.time() - began
        print >>sys.stderr, "Finishing DNS queries (%.1f seconds)" % elapsed

        external_ip = ResolveAddress("ns.osd.honeyd.org", self._resolver)
        return Model(self._resolver, external_ip, self._result_queue)

class TextView:
    """Renders a Model as a plain text report."""

    def __init__(self, model):
        self._model = model
        # When True, Show() also prints the validation set verdicts; nothing
        # in this class sets it to True.
        self._show_validation = False

    def Show(self, fp, verbose=False, terse=False):
        """Displays the model.

        Args:
          fp: file-like object the report is printed to.
          verbose: also print the list of all observed ports.
          terse: print only the standard deviation and category for ports
            and query ids instead of the full statistics.
        """

        model = self._model

        print >>fp, 'Resolver analysis: %s (ext: %s)' % (
            model.resolver and model.resolver or 'Unknown Resolver',
            model.external_ip)
        print >>fp, 'Queries: %d' % len(model.testing_set)
        # Too few answers for the statistics to mean anything.
        if len(model.testing_set) < 10:
            print >>fp, '\tInsufficient data'
            return
        if not terse:
            # The statistics text ends in a newline of its own, hence the
            # trailing commas that suppress print's newline.
            print >>fp, 'Port Statistics:'
            print >>fp, '\t', model.random_port,
            print >>fp, '\t', \
                  model.random_port.isRandom and 'Okay' or 'NOT RANDOM'
            print >>fp, 'Qid Statistics:'
            print >>fp, '\t', model.random_qid,
            print >>fp, '\t', \
                  model.random_qid.isRandom and 'Okay' or 'NOT RANDOM'
        else:
            print >>fp, 'Port Statistics: StdDev:', \
                  int(model.random_port.stdDev), \
                  model.random_port.randomCategory
            print >>fp, 'Qid Statistics: StdDev:', \
                  int(model.random_qid.stdDev), \
                  model.random_qid.randomCategory

        if self._show_validation:
            print >>fp, 'Validation (%d): Port: %s Qid: %s' % (
                len(model.validation_set),
                model.random_v_port.isRandom and 'Okay' or 'NOT RANDOM',
                model.random_v_qid.isRandom and 'Okay' or 'NOT RANDOM')
        if verbose:
            print >>fp, 'All ports:\n', model.port
        


def usage(name):
    """Writes the command line help for program `name` to stderr."""
    text = (
        '%s: [-h] [-v] [-s num] [--input=filename] [--output=filename]\n'
        '\t[--resolver=ip] [resolver-ip]\n'
        '\t-h                  display this help message\n'
        '\t-v                  verbose output, including all observed ports\n'
        '\t-s num              number of probe queries to send\n'
        '\t--resolver=ip       the ip address of a resolver, otherwise we\n'
        '\t                    use the default resolver\n'
        '\t--input=filename    reads a model from the filename\n'
        '\t--output=filename   probes a server and writes a model\n'
        '\tresolver-ip         same as --resolver=ip\n') % name
    # Keep the blank line the former print statement added after the text.
    sys.stderr.write(text + '\n')
       

def LoadModel(filename):
    """Loads a model saved by WriteModel from `filename`.

    The file holds a marshalled list [resolver, testing_set, validation_set]
    that may be followed by external_ip. Only the (query, result) pairs of
    the two sets are used; Model recomputes ports and query ids from them.

    NOTE: marshal is not secure against erroneous or maliciously
    constructed data; only load model files from trusted sources.
    """
    # marshal data is binary; read it in binary mode and always close.
    fp = open(filename, 'rb')
    try:
        data = marshal.load(fp)
    finally:
        fp.close()

    q = Queue.Queue()
    for (query, result, _port, _qid) in data[1]:
        q.put((query, result))
    for (query, result, _port, _qid) in data[2]:
        q.put((query, result))

    # Some model files carry no external IP.
    external_ip = None
    if len(data) >= 4:
        external_ip = data[3]

    return Model(data[0], external_ip, q)

def CreateModel(resolver=None):
    """Probes `resolver` (or the default resolver) and returns its Model.

    As a side effect the verdict is reported back by looking up the name
    <p><pl><q><ql>.osd.honeyd.org. Here <p> and <q> are 'r' (random) or
    'n' (not random) for ports and query ids. <pl> and <ql> are each
    standard deviation divided by 3000 and truncated to an integer.
    """
    model = Controller(16, resolver).ProduceModel()

    # report randomness results back for analysis
    labels = []
    for stats in (model.random_port, model.random_qid):
        if stats.isRandom:
            labels.append('r')
        else:
            labels.append('n')
        labels.append('%d' % (stats.stdDev / 3000))
    ResolveAddress('%s.osd.honeyd.org' % ''.join(labels), resolver)

    return model

def WriteModel(model, filename):
    """Saves `model` to `filename` in the format read by LoadModel.

    The file holds a marshalled list:
    [resolver, testing_set, validation_set, external_ip].
    """
    data = [model.resolver,
            model.testing_set, model.validation_set,
            model.external_ip]
    # marshal produces binary data: write in binary mode, and close the
    # file so everything is flushed before the caller goes on.
    fp = open(filename, 'wb')
    try:
        marshal.dump(data, fp)
    finally:
        fp.close()
    
def main(argv):
    """Command line entry point; `argv` is the full argument vector.

    Probes a resolver or loads a saved model (--input), then prints the
    analysis to stdout. The resolver is the one given with --resolver, else
    the first positional argument, else the default resolver. With
    --output the freshly probed model is also saved.
    """
    global sample_size

    try:
        optlist, args = getopt.getopt(argv[1:], 'hvs:', [
            'output=', 'input=', 'resolver='])
    except getopt.GetoptError:
        print >>sys.stderr, '%s: %s' % (argv[0], sys.exc_info()[1])
        usage(argv[0])
        sys.exit(2)

    resolver = None
    verbose = False
    filename_create = None
    filename_load = None

    for (o, a) in optlist:
        if o == '-h':
            usage(argv[0])
            sys.exit(0)
        elif o == '-v':
            verbose = True
        elif o == '-s':
            try:
                sample_size = int(a)
            except ValueError:
                print >>sys.stderr, '%s: -s needs a number, not %r' % (
                    argv[0], a)
                usage(argv[0])
                sys.exit(2)
        elif o == '--resolver':
            resolver = a
        elif o == '--input':
            filename_load = a
        elif o == '--output':
            filename_create = a

    # A positional argument names the resolver unless --resolver was given.
    # Take it from getopt's leftovers (argv[1] may be an option such as -v),
    # and before probing, so that --output honours it too.
    if not resolver and args:
        resolver = args[0]

    model = None
    terse = False
    if filename_create:
        model = CreateModel(resolver)
        WriteModel(model, filename_create)
    elif filename_load:
        model = LoadModel(filename_load)

    if model is None:
        # Plain probing gives the short summary unless -v was given.
        if not verbose:
            terse = True
        model = CreateModel(resolver)

    view = TextView(model)
    view.Show(sys.stdout, verbose, terse)

if __name__ == '__main__':
    main(sys.argv)
