#!/usr/bin/env python

"""
Primarily a tool to load test a running REST API deployment (i.e. Apache
and Cassandra backend).  Spawns processes as denoted by the -p flag
and returns the results.

Examples:

query_api -p 120 : query 120 endpoints in order returned by api.

query_api -p 200 -r -a 3600 -v : query 200 random hourly endpoints and
    include shortened query uri in output.

query_api -r -a 3600 -c max -vv : query default number of random hourly 
    max aggregation endpoints and dump returned data to the output.

query_api -D sunn-ar1 : dump all interface endpoints from device 
    sunn-ar1.

query_api -D lbl-mr2 -E in : dump all incoming data (including error 
    and discard) endpoints from device lbl-mr2.

query_api -D lbl-mr2 -E traffic : dump all traffic (/in and /out)
    endpoints from device lbl-mr2.

See query_api -h for all options.

"""

import json
from multiprocessing import Process, Queue
from optparse import OptionParser
import os
import pprint
import random
import re
import requests
import sys
import time

from esmond.cassandra import AGG_TYPES

class UrlContainer(object):
    """Holds a full request URL and exposes pieces of its '/'-separated path."""

    def __init__(self, url):
        self.url = url

    def get_device(self):
        """Return the sixth '/'-separated piece of the URL (index 5).

        Raises IndexError when the URL has fewer than six pieces.
        """
        # Splitting at most six times cannot change what sits at index 5.
        pieces = self.url.split('/', 6)
        return pieces[5]

    def get_shortened_request(self):
        """Return the URL from its sixth '/'-separated piece onward.

        Yields an empty string when the URL has fewer than six pieces.
        """
        pieces = self.url.split('/')
        tail = pieces[5:]
        return '/'.join(tail)

class FetchProcess(Process):
    """Worker process that performs one GET against the API and reports back.

    The result is delivered to the parent as a tuple on ``q``:
    ``(message, url, status, text)`` - see ``generate_output``.

    Args:
        q: queue the result tuple is put on.
        host: host part used to build ``http://<host><uri>``.
        uri: request path appended directly after the host.
        delay: upper bound (seconds) for the random sleep before the request.
        agg: optional value sent as the ``agg`` query parameter.
        cf: optional value sent as the ``cf`` query parameter.
        last: optional number of hours; when set, ``begin`` is sent as
            "now minus that many hours" in epoch seconds.
    """

    def __init__(self, q, host, uri, delay, agg=None, cf=None, last=None):
        super(FetchProcess, self).__init__()
        self.q = q
        self.host = host
        self.uri = uri
        self.delay = delay
        self.agg = agg
        self.cf = cf
        self.last = last
        # Filled in by run() inside the child process.
        self.duration = None
        self.text = None
        self.status = None
        self.url = None

    def generate_output(self):
        """Return the ``(message, url, status, text)`` tuple sent to the parent."""
        s = '{0} in {1}'.format(self.name, self.duration)
        return (s, self.url, self.status, self.text)

    def _build_params(self):
        """Assemble the query parameters from the optional settings."""
        params = {}
        if self.agg:
            params['agg'] = self.agg
        if self.cf:
            params['cf'] = self.cf
        if self.last:
            params['begin'] = int(time.time() - (self.last * 60 * 60))
        return params

    def run(self):
        """Sleep a random 0..delay seconds, issue the GET and queue the result.

        A result is always put on the queue. If the request itself fails
        (connection refused, DNS failure, ...) the tuple carries the attempted
        URL, a status of None and the exception as a JSON string, so the
        parent can report an error instead of blocking forever on an empty
        queue.
        """
        params = self._build_params()
        time.sleep(random.randint(0, self.delay))
        target = 'http://{0}{1}'.format(self.host, self.uri)
        try:
            r = requests.get(target, params=params)
        except requests.exceptions.RequestException as e:
            self.url = UrlContainer(target)
            # The consumer JSON-decodes non-empty text, so encode the error.
            self.text = json.dumps(repr(e))
        else:
            self.url = UrlContainer(r.url)
            self.duration = r.elapsed
            self.status = r.status_code
            self.text = r.text
        self.q.put(self.generate_output())

class ResponseContainer(object):
    """Wraps one worker result and formats it for output.

    Args:
        message: summary string produced by the worker.
        url: object offering ``get_device()`` and ``get_shortened_request()``.
        status: HTTP status code of the response.
        data: raw response body text; decoded from JSON when possible.
        verbose: verbosity level (0, 1, 2, ...) controlling ``response_text``.
    """

    def __init__(self, message, url, status, data, verbose):
        self.message = message
        self.url = url
        self.status = status
        self.data = data
        self.verbose = verbose

        if self.data and self.data.strip():
            try:
                self.data = json.loads(self.data)
            except ValueError:
                # Body is not JSON (for instance an HTML error page): keep the
                # raw text so the error can still be reported rather than
                # aborting the whole run.
                pass

        self.pp = pprint.PrettyPrinter(indent=4)

    @property
    def is_error(self):
        """True for any status other than 200."""
        return self.status != 200

    @property
    def error_message(self):
        """One-line description of a failed request."""
        s = '{0} - '.format(self.message)
        if self.status == 404:
            s += '404 for endpoint:'
        else:
            s += 'got error {0}'.format(self.data)
        s += ' ' + self.url.get_shortened_request()

        return s

    @property
    def has_data(self):
        """True when the decoded body is a dict with a non-empty 'data' entry."""
        return bool(
            isinstance(self.data, dict)
            and 'data' in self.data
            and len(self.data['data'])
        )

    @property
    def response_text(self):
        """Output line(s) for a successful request, scaled by verbosity."""
        m = self.message

        if self.has_data and self.verbose < 2:
            # note that data was found unless dumping payload
            m += ' data found {0}'.format(self.url.get_device())

        if self.verbose >= 1:
            # add shortened request uri for context
            m += ' ' + self.url.get_shortened_request()

        if self.verbose >= 2 and self.has_data:
            # add actual data payload
            m += '\nHad data:\n'
            m += self.pp.pformat(self.data)
            m += '\n' + '-' * 20

        return m


class RequestGenerator(object):
    """Builds the list of endpoint URIs to query by walking the REST API.

    The walk goes device list -> interface list per device -> endpoint list
    per interface, each level fetched over HTTP from ``hostname``.

    Args:
        hostname: host the API is served from.
        process_limit: cap on collected URIs (ignored when a device filter
            is set).
        randomize: when true, device and interface lists are put in random
            order.
        intercloud: when true, only devices whose name ends in '-cr5' are kept.
        device_filter: when non-empty, only the device with this exact name
            is kept.
        endpoint_filter: one of ``endpoint_keywords``; selects which endpoint
            names are kept.
    """
    endpoint_keywords = ['all', 'traffic', 'discard', 'error', 'in', 'out']

    def __init__(self, hostname, process_limit, randomize, intercloud,
        device_filter, endpoint_filter):
        self.hostname = hostname
        self.process_limit = process_limit
        self.randomize = randomize
        self.intercloud = intercloud
        self.device_filter = device_filter
        self.endpoint_filter = endpoint_filter

        self.device_uris = []
        self.interface_uris = []
        self.endpoint_uris = []

        # True once a "Filtering on ..." notice has been printed.
        self.filter_warning = False
        # keyword -> list of endpoint names, built on first use.
        self.endpoint_filters = {}

    def _limit_exceeded(self, content):
        """Return True once ``content`` holds at least ``process_limit`` items.

        Always False when filtering on a single device.
        """
        if self.device_filter:
            return False
        return len(content) >= self.process_limit

    def _warn_once(self, message):
        """Print ``message`` unless a filter notice was already printed."""
        if not self.filter_warning:
            print(message)
            self.filter_warning = True

    def _filter_device(self, data):
        """Return True when the device described by ``data`` should be skipped."""
        if self.intercloud and not data['name'].endswith('-cr5'):
            self._warn_once('Filtering on intercloud devices')
            return True
        if self.device_filter and not data['name'] == self.device_filter:
            self._warn_once(
                'Filtering on device {0} - skipping process limits'.format(self.device_filter))
            return True
        return False

    def _generate_devices(self):
        """Fetch the device list and record each kept device's interface URI."""
        if self.device_uris:
            return  # only do it once

        r = requests.get('http://{0}/v1/device/?limit=0'.format(self.hostname))

        if r.status_code == 200 and \
            r.headers.get('content-type') == 'application/json':
            for device in json.loads(r.text):
                if self._filter_device(device):
                    continue
                self.device_uris.extend(
                    child['uri'] for child in device['children']
                    if child['name'] == 'interface')

        if self.randomize:
            self.device_uris = random.sample(self.device_uris, len(self.device_uris))

    def _generate_interfaces(self):
        """Fetch each device's interface listing and record the interface URIs."""
        for device in self.device_uris:
            if self._limit_exceeded(self.interface_uris):
                break
            r = requests.get('http://{0}/{1}/'.format(self.hostname, device))
            if r.status_code != 200:
                # Same policy as the endpoint walk: skip what cannot be
                # fetched instead of trying to parse an error body.
                continue
            data = json.loads(r.text)
            self.interface_uris.extend(
                child['resource_uri'] for child in data['children'])

        if self.randomize:
            self.interface_uris = random.sample(self.interface_uris, len(self.interface_uris))

    def _filter_endpoints(self, name):
        """Return True when endpoint ``name`` is not selected by the filter."""
        if not self.endpoint_filters:
            traffic = ['in', 'out']
            discard = ['discard/in', 'discard/out']
            errors  = ['error/in', 'error/out']

            all_endpoints = [traffic, discard, errors]

            self.endpoint_filters = {
                'all': traffic + discard + errors,
                'traffic': traffic,
                'discard': discard,
                'error': errors,
                # first / second entry of every pair
                'in': [pair[0] for pair in all_endpoints],
                'out': [pair[1] for pair in all_endpoints],
            }

        return name not in self.endpoint_filters[self.endpoint_filter]

    def _generate_endpoints(self):
        """Fetch each interface's children and record the selected endpoint URIs."""
        for interface in self.interface_uris:
            if self._limit_exceeded(self.endpoint_uris):
                break
            r = requests.get('http://{0}/{1}/'.format(self.hostname, interface))
            if r.status_code != 200:
                continue
            data = json.loads(r.text)
            if data['children']:
                self.endpoint_uris.extend(
                    child['uri'] for child in data['children']
                    if not self._filter_endpoints(child['name']))

    def get_device_list(self):
        """Yield the fourth '/'-separated piece (index 3) of each device URI."""
        self._generate_devices()
        for d in self.device_uris:
            yield d.split('/')[3]

    def get_endpoint_list(self):
        """Walk the API and return the endpoint URIs to query.

        The list is truncated to ``process_limit`` entries unless a device
        filter is active, in which case every endpoint found is returned.
        """
        self._generate_devices()
        self._generate_interfaces()
        self._generate_endpoints()

        if self.device_filter:
            return self.endpoint_uris
        return self.endpoint_uris[0:self.process_limit]


# Worker list populated by main(): one FetchProcess is appended per endpoint
# URI, then main() starts every entry and collects each one's result.
processes = []

def main():
    """Parse options, launch one FetchProcess per endpoint and print results.

    Returns:
        0 on success (including the --list path), 1 when the supplied
        options fail validation.
    """
    usage = '%prog [ options ]'
    parser = OptionParser(usage=usage)
    parser.add_option('-H', '--hostname', metavar='HOST',
            type='string', dest='hostname', 
            help='Host running rest api (default=%default).', default='localhost')
    parser.add_option('-p', '--processes', metavar='PROCESSES',
            type='int', dest='processes', default=50,
            help='Number of client processes to launch (default=%default).')
    parser.add_option('-d', '--delay', metavar='DELAY',
            type='int', dest='delay', default=2,
            help='Second range generated by randint to introduce thread startup entropy (default=%default).')
    parser.add_option('-a', '--agg', metavar='AGG',
            type='int', dest='agg', default=None,
            help='Aggregation level in seconds.')
    parser.add_option('-c', '--cf', metavar='CF',
            type='string', dest='cf', default=None,
            help='Select consolidation function (min/max/average).')
    parser.add_option('-l', '--last', metavar='LAST',
            type='int', dest='last', default=0,
            help='Set optional time range to last n hours. If unset, api defaults to last hour.')
    parser.add_option('-r', '--randomize',
            dest='randomize', action='store_true', default=False,
            help='Randomize endpoint list to get heterogenous results (default=%default).')
    parser.add_option('-I', '--intercloud',
            dest='intercloud', action='store_true', default=False,
            help='Filter to intercloud devices (ending in -cr5).')
    parser.add_option('-D', '--device', metavar='DEVICE',
            type='string', dest='device', default='',
            help='Device name to filter on (lbl-mr2, etc).')
    parser.add_option('-E', '--endpoint', metavar='ENDPOINT',
            type='string', dest='endpoint', default='all',
            help='Endpoint type to filter on ' + \
            str(tuple(RequestGenerator.endpoint_keywords)) + \
            ' (default=%default).')
    parser.add_option('-L', '--list',
            dest='list', action='store_true', default=False,
            help='Dump list of devices to output and exit.')
    # A count option starts from an integer so later comparisons are numeric.
    parser.add_option('-v', '--verbose',
            dest='verbose', action='count', default=0,
            help='Verbose output - -v, -vv, etc.')
    options, args = parser.parse_args()

    arg_error = False

    if options.cf and options.cf not in AGG_TYPES:
        print('--cf flag must be one of {0}'.format(AGG_TYPES))
        arg_error = True

    if options.cf and not options.agg:
        print('must specify --agg arg when using --cf flag')
        arg_error = True

    if options.intercloud and options.device:
        print('specify either --intercloud or --device filter')
        arg_error = True

    # 'all' is itself one of the keywords, so a plain membership test covers
    # it (the old extra comparison was against the builtin all(), not 'all').
    if options.endpoint not in RequestGenerator.endpoint_keywords:
        print('invalid option {0} for endpoint filter'.format(options.endpoint))
        arg_error = True

    if arg_error:
        parser.print_help()
        return 1

    # Done parsing args

    rgen = RequestGenerator(options.hostname, options.processes,
        options.randomize, options.intercloud, options.device,
        options.endpoint)

    if options.list:
        print('Dumping device list and exiting')
        for d in rgen.get_device_list():
            print(d)
        return 0

    endpoint_uris = rgen.get_endpoint_list()

    for u in endpoint_uris:
        # Each worker gets its own queue to hand back its single result.
        p = FetchProcess(Queue(), options.hostname, u, options.delay,
            options.agg, options.cf, options.last)
        processes.append(p)

    # Partial line, completed by 'go!' below. Flush it before starting the
    # workers so it is displayed first and is not left in a buffer the child
    # processes could inherit.
    sys.stdout.write('launching {0} client processes... '.format(len(endpoint_uris)))
    sys.stdout.flush()

    data_found = 0

    t = time.time()
    for p in processes:
        p.start()

    print('go!')

    for p in processes:
        # Take the result off the queue BEFORE joining: a process that has
        # queued data may not terminate until that data is consumed, so
        # joining first can deadlock on large payloads.
        result = p.q.get()
        p.join()
        r = ResponseContainer(*result, verbose=options.verbose)
        if r.is_error:
            print(r.error_message)
            continue
        if r.has_data:
            data_found += 1
        print(r.response_text)

    print('fetched {0} endpoints in {1} seconds - {2} with data'.format(
            len(endpoint_uris), 
            time.time() - t,
            data_found
        ))

    return 0

if __name__ == '__main__':
    # Hand main()'s return value (0 = ok, 1 = bad arguments) to the shell
    # as the process exit status instead of discarding it.
    sys.exit(main())