#!/usr/bin/python3
#
# Script for comparing buildstats from two different builds
#
# Copyright (c) 2016, Intel Corporation.
#
# This program is free software; you can redistribute it and/or modify it
# under the terms and conditions of the GNU General Public License,
# version 2, as published by the Free Software Foundation.
#
# This program is distributed in the hope it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
# more details.
#
import argparse
import glob
import json
import logging
import math
import os
import re
import sys
from collections import namedtuple
from operator import attrgetter

# Set up logging: messages are emitted through the root logger at INFO level
# and formatted as "LEVEL: message". main() lowers the level to DEBUG when
# the --debug option is given.
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
log = logging.getLogger()


class ScriptError(Exception):
    """Error raised by this script for failures it detects itself.

    main() catches this exception, logs its message and returns a non-zero
    exit status.
    """


# Field names of one row of the task comparison; the same names are accepted
# as sort keys by the --sort-by option
taskdiff_fields = (
    'pkg', 'pkg_op',
    'task', 'task_op',
    'value1', 'value2',
    'absdiff', 'reldiff',
)
TaskDiff = namedtuple('TaskDiff', taskdiff_fields)


class BSTask(dict):
    """Buildstats of one task execution.

    The data is kept as plain dict items; the properties below expose the
    values that can be compared between builds.
    """

    def __init__(self, *args, **kwargs):
        # Seed all known fields first so that whatever the caller passes in
        # overrides the defaults
        self.update({'start_time': None,
                     'elapsed_time': None,
                     'status': None,
                     'iostat': {},
                     'rusage': {},
                     'child_rusage': {}})
        super().__init__(*args, **kwargs)

    def _rusage_count(self, counter):
        """Return an rusage counter of the task, including its children"""
        children = self['child_rusage']
        if not children:
            # Child rusage may have been optimized out
            return self['rusage'][counter]
        return self['rusage'][counter] + children[counter]

    @property
    def cputime(self):
        """User plus system CPU time consumed by the task"""
        own = self['rusage']
        total = own['ru_stime'] + own['ru_utime']
        children = self['child_rusage']
        if not children:
            # Child rusage may have been optimized out
            return total
        return total + children['ru_stime'] + children['ru_utime']

    @property
    def walltime(self):
        """Wall clock time the task took"""
        return self['elapsed_time']

    @property
    def read_bytes(self):
        """Amount of bytes read from the block layer"""
        iostat = self['iostat']
        return iostat['read_bytes']

    @property
    def write_bytes(self):
        """Amount of bytes written to the block layer"""
        iostat = self['iostat']
        return iostat['write_bytes']

    @property
    def read_ops(self):
        """Count of read operations on the block layer"""
        return self._rusage_count('ru_inblock')

    @property
    def write_ops(self):
        """Count of write operations on the block layer"""
        return self._rusage_count('ru_oublock')


def read_buildstats_file(buildstat_file):
    """Convert buildstat text file into dict/json

    The file is parsed as "key: value" lines. The 'Started' and 'Ended'
    timestamps, the 'IO ...' counters, the 'rusage ...' / 'Child rusage ...'
    values and the 'Status' are stored; all other keys are ignored.

    Returns a BSTask. Raises ScriptError if the file does not contain both
    a 'Started' and an 'Ended' line.
    """
    bs_task = BSTask()
    log.debug("Reading task buildstats from %s", buildstat_file)
    # Both timestamps start out unset so that a file lacking either line is
    # reported as invalid below instead of failing on an unbound variable
    start_time = None
    end_time = None
    with open(buildstat_file) as fobj:
        for line in fobj:
            key, sep, val = line.partition(':')
            if not sep:
                # Not a "key: value" line (e.g. a blank line), nothing to parse
                continue
            val = val.strip()
            if key == 'Started':
                start_time = float(val)
                bs_task['start_time'] = start_time
            elif key == 'Ended':
                end_time = float(val)
            elif key.startswith('IO '):
                # "IO <counter>: <int>"
                bs_task['iostat'][key.split()[1]] = int(val)
            elif 'rusage' in key:
                # "rusage <field>: <val>" or "Child rusage <field>: <val>"
                split = key.split()
                ru_key = split[-1]
                # The CPU times are fractional seconds, everything else is
                # an integer counter
                if ru_key in ('ru_stime', 'ru_utime'):
                    val = float(val)
                else:
                    val = int(val)
                ru_type = 'rusage' if split[0] == 'rusage' else 'child_rusage'
                bs_task[ru_type][ru_key] = val
            elif key == 'Status':
                bs_task['status'] = val
    if end_time is None or start_time is None:
        raise ScriptError("{} looks like a invalid buildstats file".format(buildstat_file))
    bs_task['elapsed_time'] = end_time - start_time
    return bs_task


def read_buildstats_dir(bs_dir):
    """Read buildstats directory

    bs_dir must contain a 'build_stats' file. Every subdirectory is treated
    as one recipe, named "<name>-[<epoch>_]<version>-<revision>", and every
    entry inside it as the buildstats file of one task.

    Returns a dict mapping recipe name to a dict with the keys 'nevr',
    'name', 'epoch', 'version', 'revision' and 'tasks', where 'tasks' maps
    the task name to a one-element list holding the parsed task.

    Raises ScriptError if bs_dir is not a buildstats directory, if a
    subdirectory name cannot be split into name and version, or if the same
    recipe name appears more than once.
    """
    def split_nevr(nevr):
        """Split name and version information from recipe "nevr" string"""
        n_e_v, sep, revision = nevr.rpartition('-')
        match = None
        if sep:
            match = re.match(r'^(?P<name>\S+)-((?P<epoch>[0-9]{1,5})_)?(?P<version>[0-9]\S*)$',
                             n_e_v)
            if not match:
                # If we're not able to parse a version starting with a
                # number, just take the part after last dash
                match = re.match(r'^(?P<name>\S+)-((?P<epoch>[0-9]{1,5})_)?(?P<version>[^-]+)$',
                                 n_e_v)
        if not match:
            # Report names we cannot parse as a script error instead of
            # crashing on the failed match
            raise ScriptError("Unable to parse recipe name and version from "
                              "'{}'".format(nevr))
        return (match.group('name'), match.group('epoch'),
                match.group('version'), revision)

    if not os.path.isfile(os.path.join(bs_dir, 'build_stats')):
        raise ScriptError("{} does not look like a buildstats directory".format(bs_dir))

    log.debug("Reading buildstats directory %s", bs_dir)

    buildstats = {}
    for dirname in os.listdir(bs_dir):
        recipe_dir = os.path.join(bs_dir, dirname)
        if not os.path.isdir(recipe_dir):
            continue
        name, epoch, version, revision = split_nevr(dirname)
        recipe_bs = {'nevr': dirname,
                     'name': name,
                     'epoch': epoch,
                     'version': version,
                     'revision': revision,
                     'tasks': {}}
        for task in os.listdir(recipe_dir):
            # Each task holds a list of samples so that several buildstats
            # can later be joined with bs_append()
            recipe_bs['tasks'][task] = [read_buildstats_file(
                    os.path.join(recipe_dir, task))]
        if name in buildstats:
            raise ScriptError("Cannot handle multiple versions of the same "
                              "package ({})".format(name))
        buildstats[name] = recipe_bs

    return buildstats


def bs_append(dst, src):
    """Extend the task sample lists of dst with the samples from src.

    dst is modified in place. Both buildstats must contain the same
    packages, with the same 'nevr' and the same set of tasks; otherwise
    ScriptError is raised.
    """
    if set(src) != set(dst):
        raise ScriptError("Refusing to join buildstats, set of packages is "
                          "different")
    for pkg in dst:
        ours = dst[pkg]
        theirs = src[pkg]
        if ours['nevr'] != theirs['nevr']:
            raise ScriptError("Refusing to join buildstats, package version "
                              "differs: {} vs. {}".format(ours['nevr'], theirs['nevr']))
        our_tasks = ours['tasks']
        their_tasks = theirs['tasks']
        if set(their_tasks) != set(our_tasks):
            raise ScriptError("Refusing to join buildstats, set of tasks "
                              "in {} differ".format(pkg))
        for taskname in our_tasks:
            our_tasks[taskname].extend(their_tasks[taskname])


def read_buildstats_json(path):
    """Load buildstats from a JSON file.

    The file holds a list of recipe dicts. For each recipe the 'nevr' string
    is composed from name, epoch, version and revision, and every task entry
    is wrapped into a one-element list of BSTask. Returns a dict mapping
    recipe name to recipe data; raises ScriptError if a name occurs twice.
    """
    buildstats = {}
    with open(path) as fobj:
        recipes = json.load(fobj)

    for recipe in recipes:
        name = recipe['name']
        if name in buildstats:
            raise ScriptError("Cannot handle multiple versions of the same "
                              "package ({})".format(name))

        epoch = recipe['epoch']
        if epoch is not None:
            nevr = "{}-{}_{}-{}".format(name, epoch, recipe['version'], recipe['revision'])
        else:
            nevr = "{}-{}-{}".format(name, recipe['version'], recipe['revision'])
        recipe['nevr'] = nevr

        # Replace the raw task data in place, iterating over a snapshot of
        # the task names
        tasks = recipe['tasks']
        for task in list(tasks):
            tasks[task] = [BSTask(tasks[task])]

        buildstats[name] = recipe

    return buildstats


def read_buildstats(path, multi):
    """Read buildstats

    path may be a JSON file, a buildstats directory (one containing a
    'build_stats' file) or a directory whose entries are JSON files and/or
    buildstats directories. In the last case all entries are read and joined
    with bs_append(); more than one entry is only accepted if multi is true.

    Returns the buildstats dict. Raises ScriptError if path does not exist,
    if several buildstats are found without multi, or if none is found.
    """
    if not os.path.exists(path):
        raise ScriptError("No such file or directory: {}".format(path))

    if os.path.isfile(path):
        return read_buildstats_json(path)

    if os.path.isfile(os.path.join(path, 'build_stats')):
        return read_buildstats_dir(path)

    # Handle a non-buildstat directory. The directory name is escaped so that
    # glob special characters in it are taken literally.
    subpaths = sorted(glob.glob(glob.escape(path) + '/*'))
    if len(subpaths) > 1:
        if multi:
            log.info("Averaging over {} buildstats from {}".format(
                     len(subpaths), path))
        else:
            raise ScriptError("Multiple buildstats found in '{}'. Please give "
                              "a single buildstat directory or use the --multi "
                              "option".format(path))
    bs = None
    for subpath in subpaths:
        if os.path.isfile(subpath):
            tmpbs = read_buildstats_json(subpath)
        else:
            tmpbs = read_buildstats_dir(subpath)
        if not bs:
            bs = tmpbs
        else:
            log.debug("Joining buildstats")
            bs_append(bs, tmpbs)

    if not bs:
        raise ScriptError("No buildstats found under {}".format(path))
    return bs


def print_ver_diff(bs1, bs2):
    """Print package version differences

    Compares the packages of two buildstats dicts (package name -> recipe
    data with 'epoch', 'version', 'revision' and 'nevr') and prints the
    packages that only exist in bs2 (new), only exist in bs1 (deleted), and
    the common packages whose revision, version or epoch differs. A common
    package is listed once, under the most significant field that changed
    (epoch, then version, then revision). Nothing is printed for sections
    without packages.
    """
    def print_section(title, pkgs, row):
        """Print a titled, underlined section with one row per package"""
        if not pkgs:
            return
        print("\n" + title)
        print("-" * len(title))
        for pkg in sorted(pkgs):
            print(row(pkg))

    def changed_row(field):
        """Get a row formatter showing the old and new value of field"""
        def row(pkg):
            field1 = "{} -> {}".format(bs1[pkg][field], bs2[pkg][field])
            field2 = "{} -> {}".format(bs1[pkg]['nevr'], bs2[pkg]['nevr'])
            return "  {0:{maxlen}} {1:<20}    ({2})".format(
                    pkg, field1, field2, maxlen=maxlen)
        return row

    pkgs1 = set(bs1)
    pkgs2 = set(bs2)
    new_pkgs = pkgs2 - pkgs1
    deleted_pkgs = pkgs1 - pkgs2

    echanged = []
    vchanged = []
    rchanged = []
    for pkg in pkgs1 & pkgs2:
        if bs1[pkg]['epoch'] != bs2[pkg]['epoch']:
            echanged.append(pkg)
        elif bs1[pkg]['version'] != bs2[pkg]['version']:
            vchanged.append(pkg)
        elif bs1[pkg]['revision'] != bs2[pkg]['revision']:
            rchanged.append(pkg)

    # Width of the package name column; there may be no packages at all
    maxlen = max((len(pkg) for pkg in pkgs1 | pkgs2), default=0)
    fmt_str = "  {:{maxlen}} ({})"

    print_section("NEW PACKAGES:", new_pkgs,
                  lambda pkg: fmt_str.format(pkg, bs2[pkg]['nevr'], maxlen=maxlen))
    print_section("DELETED PACKAGES:", deleted_pkgs,
                  lambda pkg: fmt_str.format(pkg, bs1[pkg]['nevr'], maxlen=maxlen))
    print_section("REVISION CHANGED:", rchanged, changed_row('revision'))
    print_section("VERSION CHANGED:", vchanged, changed_row('version'))
    print_section("EPOCH CHANGED:", echanged, changed_row('epoch'))


def print_task_diff(bs1, bs2, val_type, min_val=0, min_absdiff=0, sort_by=('absdiff',)):
    """Print a per-task comparison of one attribute between two buildstats

    For every task found in either buildstats, the attribute named by
    val_type is averaged over all samples of the task and the two averages
    are compared. The resulting table is followed by a summary of the
    cumulative values of both buildstats.

    Arguments:
        bs1, bs2 -- buildstats dicts mapping package name to recipe data,
            whose 'tasks' entry maps task name to a list of task samples
        val_type -- name of the sample attribute to compare, e.g. 'cputime'
        min_val -- leave out tasks whose larger average is below this value
        min_absdiff -- leave out tasks whose absolute difference is below
            this value
        sort_by -- TaskDiff field names to sort by, most significant first;
            a leading '-' reverses the order for that field
    """
    def val_to_str(val, human_readable=False):
        """Convert raw value to printable string"""
        def hms_time(secs):
            """Get time in human-readable HH:MM:SS format"""
            h = int(secs / 3600)
            m = int((secs % 3600) / 60)
            s = secs % 60
            if h == 0:
                return "{:02d}:{:04.1f}".format(m, s)
            else:
                return "{:d}:{:02d}:{:04.1f}".format(h, m, s)

        def prefix_index(log_base, step, prefixes):
            """Get the index of the unit prefix matching the size of val"""
            if not val:
                # log() is not defined for zero
                return 0
            idx = int(math.log(abs(val), log_base) / step)
            # Tiny fractions would give a negative index and huge values one
            # past the end of the prefix table
            return max(0, min(idx, len(prefixes) - 1))

        if 'time' in val_type:
            if human_readable:
                return hms_time(val)
            else:
                return "{:.1f}s".format(val)
        elif 'bytes' in val_type and human_readable:
            prefix = ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi']
            dec = prefix_index(2, 10, prefix)
            prec = 1 if dec > 0 else 0
            return "{:.{prec}f}{}B".format(val / (2 ** (10 * dec)),
                                           prefix[dec], prec=prec)
        elif 'ops' in val_type and human_readable:
            prefix = ['', 'k', 'M', 'G', 'T', 'P']
            dec = prefix_index(1000, 1, prefix)
            prec = 1 if dec > 0 else 0
            return "{:.{prec}f}{}ops".format(val / (1000 ** dec),
                                             prefix[dec], prec=prec)
        return str(int(val))

    def rel_diff(old, new):
        """Get change from old to new in percent, infinite without baseline"""
        if old == 0:
            return float('inf')
        return 100 * (new - old) / old

    def sum_vals(buildstats):
        """Get cumulative sum of all tasks"""
        total = 0.0
        for recipe_data in buildstats.values():
            for bs_task in recipe_data['tasks'].values():
                # Each task counts with the average over its samples
                total += sum(getattr(b, val_type) for b in bs_task) / len(bs_task)
        return total

    tasks_diff = []

    if min_val:
        print("Ignoring tasks less than {} ({})".format(
                val_to_str(min_val, True), val_to_str(min_val)))
    if min_absdiff:
        print("Ignoring differences less than {} ({})".format(
                val_to_str(min_absdiff, True), val_to_str(min_absdiff)))

    # Prepare the data
    pkgs = set(bs1.keys()).union(set(bs2.keys()))
    for pkg in pkgs:
        tasks1 = bs1[pkg]['tasks'] if pkg in bs1 else {}
        tasks2 = bs2[pkg]['tasks'] if pkg in bs2 else {}
        # '+' marks something only present in bs2, '-' only present in bs1
        if not tasks1:
            pkg_op = '+ '
        elif not tasks2:
            pkg_op = '- '
        else:
            pkg_op = '  '

        for task in set(tasks1.keys()).union(set(tasks2.keys())):
            task_op = '  '
            if task in tasks1:
                # Average over all values
                val1 = [getattr(b, val_type) for b in bs1[pkg]['tasks'][task]]
                val1 = sum(val1) / len(val1)
            else:
                task_op = '+ '
                val1 = 0
            if task in tasks2:
                # Average over all values
                val2 = [getattr(b, val_type) for b in bs2[pkg]['tasks'][task]]
                val2 = sum(val2) / len(val2)
            else:
                val2 = 0
                task_op = '- '

            reldiff = rel_diff(val1, val2)

            if max(val1, val2) < min_val:
                log.debug("Filtering out %s:%s (%s)", pkg, task,
                          val_to_str(max(val1, val2)))
                continue
            if abs(val2 - val1) < min_absdiff:
                log.debug("Filtering out %s:%s (difference of %s)", pkg, task,
                          val_to_str(val2-val1))
                continue
            tasks_diff.append(TaskDiff(pkg, pkg_op, task, task_op, val1, val2,
                                       val2-val1, reldiff))

    # Sort our list. Sorting is stable, so going through the fields from the
    # least to the most significant one gives a multi-level sort
    for field in reversed(sort_by):
        if field.startswith('-'):
            field = field[1:]
            reverse = True
        else:
            reverse = False
        tasks_diff = sorted(tasks_diff, key=attrgetter(field), reverse=reverse)

    linedata = [('  ', 'PKG', '  ', 'TASK', 'ABSDIFF', 'RELDIFF',
                val_type.upper() + '1', val_type.upper() + '2')]
    field_lens = {'len_{}'.format(i): len(f) for i, f in enumerate(linedata[0])}

    # Prepare fields in string format and measure field lengths
    for diff in tasks_diff:
        # The task marker is redundant when the whole package was added or
        # removed
        task_prefix = diff.task_op if diff.pkg_op == '  ' else '  '
        linedata.append((diff.pkg_op, diff.pkg, task_prefix, diff.task,
                         val_to_str(diff.absdiff),
                         '{:+.1f}%'.format(diff.reldiff),
                         val_to_str(diff.value1),
                         val_to_str(diff.value2)))
        for i, field in enumerate(linedata[-1]):
            key = 'len_{}'.format(i)
            field_lens[key] = max(field_lens[key], len(field))

    # Print data
    print()
    for fields in linedata:
        print("{:{len_0}}{:{len_1}}  {:{len_2}}{:{len_3}}  {:>{len_4}}  {:>{len_5}}  {:>{len_6}} -> {:{len_7}}".format(
                *fields, **field_lens))

    # Print summary of the diffs
    total1 = sum_vals(bs1)
    total2 = sum_vals(bs2)
    print("\nCumulative {}:".format(val_type))
    print("  {}    {:+.1f}%    {} ({}) -> {} ({})".format(
                val_to_str(total2 - total1), rel_diff(total1, total2),
                val_to_str(total1, True), val_to_str(total1),
                val_to_str(total2, True), val_to_str(total2)))


def parse_args(argv):
    """Parse the command line and return the resulting namespace.

    The --min-val and --min-absdiff options default to a value that depends
    on --diff-attr; the defaults are resolved here after parsing. With
    --ver-diff the --multi option is switched off.
    """
    # Filter thresholds per diff attribute. The dicts themselves serve as
    # the argparse defaults, so an option that was not given on the command
    # line can be recognized by identity after parsing.
    min_val_defaults = {'cputime': 3.0,
                        'read_bytes': 524288,
                        'write_bytes': 524288,
                        'read_ops': 500,
                        'write_ops': 500,
                        'walltime': 5}
    min_absdiff_defaults = {'cputime': 1.0,
                            'read_bytes': 131072,
                            'write_bytes': 131072,
                            'read_ops': 50,
                            'write_ops': 50,
                            'walltime': 2}

    parser = argparse.ArgumentParser(
            description="\nScript for comparing buildstats of two separate builds.",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    add = parser.add_argument

    add('--debug', '-d', action='store_true', help="Verbose logging")
    add('--ver-diff', action='store_true',
        help="Show package version differences and exit")
    add('--diff-attr', default='cputime', choices=min_val_defaults.keys(),
        help="Buildstat attribute which to compare")
    add('--min-val', default=min_val_defaults, type=float,
        help="Filter out tasks less than MIN_VAL. "
             "Default depends on --diff-attr.")
    add('--min-absdiff', default=min_absdiff_defaults, type=float,
        help="Filter out tasks whose difference is less than "
             "MIN_ABSDIFF, Default depends on --diff-attr.")
    add('--sort-by', default='absdiff',
        help="Comma-separated list of field sort order. "
             "Prepend the field name with '-' for reversed sort. "
             "Available fields are: {}".format(', '.join(taskdiff_fields)))
    add('--multi', action='store_true',
        help="Read all buildstats from the given paths and "
             "average over them")
    add('buildstats1', metavar='BUILDSTATS1', help="'Left' buildstat")
    add('buildstats2', metavar='BUILDSTATS2', help="'Right' buildstat")

    args = parser.parse_args(argv)

    # There is no need to read all buildstats when only the package versions
    # are looked at
    if args.ver_diff:
        args.multi = False

    # Resolve the attribute-specific defaults of the filter options
    attr = args.diff_attr
    if args.min_val is min_val_defaults:
        args.min_val = min_val_defaults[attr]
    if args.min_absdiff is min_absdiff_defaults:
        args.min_absdiff = min_absdiff_defaults[attr]

    return args


def main(argv=None):
    """Run the script with the given argument list.

    Returns 0 on success and 1 if a ScriptError was raised while reading or
    comparing the buildstats. An invalid --sort-by field terminates the
    process with exit status 1.
    """
    args = parse_args(argv)
    if args.debug:
        log.setLevel(logging.DEBUG)

    # Validate sort fields
    sort_by = args.sort_by.split(',')
    for field in sort_by:
        if field.lstrip('-') in taskdiff_fields:
            continue
        log.error("Invalid sort field '%s' (must be one of: %s)" %
                  (field, ', '.join(taskdiff_fields)))
        sys.exit(1)

    try:
        bs1 = read_buildstats(args.buildstats1, args.multi)
        bs2 = read_buildstats(args.buildstats2, args.multi)

        if not args.ver_diff:
            print_task_diff(bs1, bs2, args.diff_attr, args.min_val,
                            args.min_absdiff, sort_by)
        else:
            print_ver_diff(bs1, bs2)
    except ScriptError as err:
        log.error(str(err))
        return 1
    return 0

# Run main() only when executed as a script and use its return value as the
# exit status of the process
if __name__ == "__main__":
    sys.exit(main())