#!/usr/bin/env python2.7
#
# Copyright (c) 2013-present Facebook. All rights reserved.
#

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import csv
import io
import json
import logging
import multiprocessing
import os
import platform
import re
import shutil
import stat
import subprocess
import sys
import tempfile
import time
import traceback
import zipfile

# Infer imports
import inferlib
import utils

# Name of the human-readable summary file that collect_results() writes
# inside args.infer_out.
ANALYSIS_SUMMARY_OUTPUT = 'analysis_summary.txt'

# Local buck configuration file patched by prepare_build(), and the name
# under which a pre-existing configuration is kept until cleanup() restores
# it. Both are relative paths, i.e. resolved against the current directory.
BUCK_CONFIG = '.buckconfig.local'
BUCK_CONFIG_BACKUP = '.buckconfig.local.backup_generated_by_BuckAnalyze'
# Buck output directories, computed from the working directory at import
# time.
DEFAULT_BUCK_OUT = os.path.join(os.getcwd(), 'buck-out')
DEFAULT_BUCK_OUT_GEN = os.path.join(DEFAULT_BUCK_OUT, 'gen')

# Member names looked up inside each jar by load_report() and load_stats().
INFER_REPORT = os.path.join(utils.BUCK_INFER_OUT, utils.CSV_REPORT_FILENAME)
INFER_STATS = os.path.join(utils.BUCK_INFER_OUT, utils.STATS_FILENAME)

# Template of the shell wrapper that buck invokes instead of javac.
# {0} is the python interpreter and {1} the inferJ command line (already
# joined with spaces, so it is deliberately left unquoted). The arguments
# buck passes to the wrapper are forwarded with "$@" so that arguments
# containing whitespace or glob characters reach inferJ unchanged.
INFERJ_SCRIPT = """\
#!/bin/sh
{0} {1} javac "$@"
"""

# %-style template of a [tools] section pointing javac at a given path.
# NOTE(review): not referenced by any code visible in this file;
# prepare_build() builds its own section string.
LOCAL_CONFIG = """\
    [tools]
        javac = %s
"""


def prepare_build(args):
    """Creates script that redirects javac calls to inferJ and a local buck
    configuration that tells buck to use that script.

    Reads args.analyzer, args.debug, args.no_filtering, args.analyzer_mode
    and args.infer_out.

    Side effects:
      - creates the directory <args.infer_out>/cache if it is missing;
      - sets NO_BUCKD=1 and INFER_ANALYSIS=1 in os.environ;
      - writes an executable inferJ_*.sh wrapper in the current directory;
      - moves an existing BUCK_CONFIG to BUCK_CONFIG_BACKUP and writes a new
        BUCK_CONFIG made of the previous content followed by a [tools]
        section pointing javac at the wrapper.

    Returns the list of paths created here that the caller should remove
    afterwards: the cache directory and the wrapper script.

    Re-raises subprocess.CalledProcessError if looking up inferJ fails.
    """
    inferJ_options = [
        '--buck',
        '--analyzer',
        args.analyzer,
    ]

    if args.debug:
        inferJ_options.append('--debug')

    if args.no_filtering:
        inferJ_options.append('--no-filtering')

    if args.analyzer_mode:
        inferJ_options.append('--analyzer_mode')
        inferJ_options.append(args.analyzer_mode)

    # Create a temporary directory as a cache for jar files.
    infer_cache_dir = os.path.join(args.infer_out, 'cache')
    if not os.path.isdir(infer_cache_dir):
        os.mkdir(infer_cache_dir)
    inferJ_options.append('--infer_cache')
    inferJ_options.append(infer_cache_dir)

    try:
        inferJ = ' '.join(
            [utils.get_cmd_in_bin_dir('inferJ')] + inferJ_options)
    except subprocess.CalledProcessError:
        logging.error('Could not find inferJ')
        # Bare raise keeps the original traceback.
        raise

    # Disable the use of buckd as this script modifies .buckconfig.local
    logging.info('Disabling buckd: export NO_BUCKD=1')
    os.environ['NO_BUCKD'] = '1'

    # make sure INFER_ANALYSIS is set when buck is called
    logging.info('Setup Infer analysis mode for Buck: export INFER_ANALYSIS=1')
    os.environ['INFER_ANALYSIS'] = '1'

    # Create a script to be called by buck. delete=False keeps the file on
    # disk after the handle is closed; the caller removes it later.
    with tempfile.NamedTemporaryFile(delete=False,
                                     prefix='inferJ_',
                                     suffix='.sh',
                                     dir='.') as inferJ_script:
        logging.info('Creating %s' % inferJ_script.name)
        script_body = INFERJ_SCRIPT.format(sys.executable, inferJ)
        inferJ_script.write(script_body.encode('utf-8'))

    # Add the executable bit on top of the existing permissions.
    script_mode = os.stat(inferJ_script.name).st_mode
    os.chmod(inferJ_script.name, script_mode | stat.S_IEXEC)

    # Backup and patch local buck config
    patched_config = ''
    if os.path.isfile(BUCK_CONFIG):
        logging.info('Backing up %s to %s', BUCK_CONFIG, BUCK_CONFIG_BACKUP)
        shutil.move(BUCK_CONFIG, BUCK_CONFIG_BACKUP)
        with open(BUCK_CONFIG_BACKUP) as buckconfig:
            # Keep the previous content verbatim: joining the lines with
            # '\n' would double every line break since each line already
            # carries its own.
            patched_config = buckconfig.read()
        # Make sure the appended section header starts on its own line even
        # when the previous config has no trailing newline.
        if patched_config and not patched_config.endswith('\n'):
            patched_config += '\n'

    # NOTE(review): if the previous config already has a [tools] section
    # this appends a second one; how buck merges duplicates is not checked
    # here.
    javac_section = '[tools]\n{0}javac = {1}'.format(
        ' ' * 4,
        inferJ_script.name)
    patched_config += javac_section
    with open(BUCK_CONFIG, 'w') as buckconfig:
        buckconfig.write(patched_config)

    return [infer_cache_dir, inferJ_script.name]


def java_targets():
    """Returns the set of targets printed by `buck targets` for the
    android_library and java_library rule types.

    Logs an error and re-raises subprocess.CalledProcessError when the buck
    command fails.
    """
    buck_cmd = ['buck', 'targets', '--type']
    buck_cmd += ['android_library', 'java_library']
    try:
        raw_output = subprocess.check_output(buck_cmd)
    except subprocess.CalledProcessError as e:
        logging.error('Could not compute java library targets')
        raise e
    # One target per line of output.
    target_lines = raw_output.decode().strip().split('\n')
    return set(target_lines)


def is_alias(target):
    """Tells whether target is an alias, i.e. contains no ':' character."""
    if ':' in target:
        return False
    return True


def expand_target(target, java_targets):
    """Expands an alias into the library targets it depends on.

    A target that is not an alias is returned unchanged as a one-element
    list. For an alias, `buck audit classpath --dot` is run and every quoted
    name appearing on an edge line that is also a member of java_targets is
    collected; the result is then a set.

    Logs an error and re-raises subprocess.CalledProcessError when the buck
    command fails.
    """
    if not is_alias(target):
        return [target]

    audit_cmd = ['buck', 'audit', 'classpath', '--dot', target]
    try:
        dot_output = subprocess.check_output(audit_cmd)
    except subprocess.CalledProcessError as e:
        logging.error('Could not expand target {0}'.format(target))
        raise e

    # Matches lines holding two double-quoted names (the two ends of an
    # edge in the dot output).
    edge_re = re.compile('.*"(.*)".*"(.*)".*')
    expanded = set()
    for dot_line in dot_output.decode().split('\n'):
        edge = edge_re.match(dot_line)
        if not edge:
            continue
        expanded.update(
            node for node in edge.groups() if node in java_targets)
    return expanded


def normalize_target(target):
    """Returns target with a leading '//' added, unless it is an alias or
    already starts with '//'.
    """
    already_normalized = is_alias(target) or target.startswith('//')
    return target if already_normalized else '//' + target


def determine_library_targets(args):
    """ Normalizes args.targets in place and, when at least one of them is
        an alias, replaces the list with the library targets obtained by
        expanding every entry with buck audit.
        Buck targets directly passed as argument are not expanded.
        With args.verbose set, the resulting targets are logged at debug
        level. """

    normalized = [normalize_target(t) for t in args.targets]
    args.targets = normalized

    if any(is_alias(t) for t in normalized):
        # Computed once and shared by all expansions.
        known_libraries = java_targets()
        expanded = set()
        for t in normalized:
            expanded.update(expand_target(t, known_libraries))
        args.targets = list(expanded)

    if not args.verbose:
        return
    logging.debug('Targets to analyze:')
    for target in args.targets:
        logging.debug(target)


def init_stats(args, start_time):
    """Builds the dictionary of statistics that do not depend on any target.

    The result has three sections: 'float' (empty), 'int' (core count,
    current time and rounded start_time) and 'normal' (string-valued
    information about the run, the machine and the repository).
    """
    int_section = {
        'cores': multiprocessing.cpu_count(),
        'time': int(time.time()),
        'start_time': int(round(start_time)),
    }
    normal_section = {
        'debug': str(args.debug),
        'analyzer': args.analyzer,
        'machine': platform.machine(),
        'node': platform.node(),
        'project': os.path.basename(os.getcwd()),
        'revision': utils.vcs_revision(),
        'branch': utils.vcs_branch(),
        'system': platform.system(),
        'infer_version': utils.infer_version(),
        'infer_branch': utils.infer_branch(),
    }
    return {
        'float': {},
        'int': int_section,
        'normal': normal_section,
    }


def store_performances_csv(infer_out, stats):
    """Stores the statistics about performances into a CSV file to be
        exported to a database.

        The file is infer_out/utils.CSV_PERF_FILENAME and holds a header row
        followed by one row of values taken from stats['int'] and
        stats['normal']. A column missing from both raises KeyError."""
    columns = ['infer_version', 'project', 'revision', 'files', 'lines',
               'cores', 'system', 'machine', 'node', 'total_time',
               'capture_time', 'analysis_time', 'reporting_time', 'time']
    perf_path = os.path.join(infer_out, utils.CSV_PERF_FILENAME)
    with open(perf_path, 'w') as perf_file:
        perf_writer = csv.writer(perf_file)
        # Entries of 'normal' take precedence over same-named 'int' entries.
        merged = {}
        merged.update(stats['int'])
        merged.update(stats['normal'])
        row = [merged[column] for column in columns]
        perf_writer.writerow(columns)
        perf_writer.writerow(row)
        perf_file.flush()


def get_harness_code():
    """Returns a text listing the content of every entry of
    DEFAULT_BUCK_OUT_GEN whose name contains 'InferGeneratedHarness', each
    one preceded by its file name.
    """
    chunks = ['\nGenerated harness code:\n']
    for entry in os.listdir(DEFAULT_BUCK_OUT_GEN):
        if 'InferGeneratedHarness' not in entry:
            continue
        chunks.append('\n' + entry + ':\n')
        harness_path = os.path.join(DEFAULT_BUCK_OUT_GEN, entry)
        with open(harness_path, 'r') as harness_file:
            chunks.append(harness_file.read())
    chunks.append('\n')
    return ''.join(chunks)


def get_basic_stats(stats):
    """Formats the integer statistics as a human-readable summary.

    The summary gives the number of files and lines with the total time,
    the capture and analysis times, then one line per remaining entry of
    stats['int'] (sorted by key) under the 'Errors found' heading.
    """
    counters = stats['int']

    summary = '{0} files ({1} lines) analyzed in {2}s\n\n'.format(
        counters['files'], counters['lines'], counters['total_time'])
    summary += 'Capture time: {0}s\nAnalysis time: {1}s\n\n'.format(
        counters['capture_time'], counters['analysis_time'])

    # Entries describing the run itself rather than a kind of error.
    not_an_error = frozenset([
        'files', 'lines', 'cores', 'time', 'start_time', 'capture_time',
        'analysis_time', 'reporting_time', 'total_time',
    ])
    error_lines = [
        '  {0:>8}  {1}\n'.format(count, name)
        for name, count in sorted(counters.items())
        if name not in not_an_error
    ]

    return summary + 'Errors found:\n\n' + ''.join(error_lines) + '\n'


def get_buck_stats():
    """Summarizes the build trace found under DEFAULT_BUCK_OUT.

    Counts, per value, the 'success_type' entries found in the 'args' of
    the trace events and returns the counts as a text sorted by value.
    Returns the empty string, after logging the error, when the trace file
    cannot be read (IOError).
    """
    trace_filename = os.path.join(
        DEFAULT_BUCK_OUT, 'log', 'traces', 'build.trace')
    try:
        with open(trace_filename, 'r') as trace_file:
            events = json.load(trace_file)
    except IOError as e:
        logging.error('Caught %s: %s' % (e.__class__.__name__, str(e)))
        logging.error(traceback.format_exc())
        return ''

    status_counts = {}
    for event in events:
        event_args = event['args']
        if 'success_type' not in event_args:
            continue
        status = event_args['success_type']
        status_counts[status] = status_counts.get(status, 0) + 1

    message_parts = ['Buck build statistics:\n\n']
    for status, count in sorted(status_counts.items()):
        message_parts.append('  {0:>8}  {1}\n'.format(count, status))
    return ''.join(message_parts)


class NotFoundInJar(Exception):
    """Raised by load_stats() and load_report() when the requested member
    is missing from the jar."""


def load_stats(opened_jar):
    """Returns the JSON-decoded content of the INFER_STATS member of
    opened_jar.

    Raises NotFoundInJar when a KeyError occurs, which is how a missing
    member is reported by the archive.
    """
    try:
        raw_stats = opened_jar.read(INFER_STATS)
        return json.loads(raw_stats.decode())
    except KeyError:
        raise NotFoundInJar


def load_report(opened_jar):
    """Returns the rows of the CSV stored as the INFER_REPORT member of
    opened_jar, as a list of lists of strings.

    Raises NotFoundInJar when a KeyError occurs, which is how a missing
    member is reported by the archive.
    """
    try:
        report_text = opened_jar.read(INFER_REPORT).decode()
        report_rows = csv.reader(io.StringIO(report_text))
        return list(report_rows)
    except KeyError:
        raise NotFoundInJar


def rows_remove_duplicates(rows):
    """Returns the rows with later repetitions dropped, keeping the first
    occurrence of each row and the original order.
    """
    already_seen = set()
    unique_rows = []
    for row in rows:
        # Rows are lists; a tuple copy makes them usable as set members.
        fingerprint = tuple(row)
        if fingerprint not in already_seen:
            already_seen.add(fingerprint)
            unique_rows.append(row)
    return unique_rows


def collect_results(args, start_time):
    """Walks through DEFAULT_BUCK_OUT_GEN, collects the results stored in the
    jar files of the different buck targets and stores the merged report in
    args.infer_out, in the file named utils.CSV_REPORT_FILENAME.

    Also written under args.infer_out: the analysis summary
    (ANALYSIS_SUMMARY_OUTPUT) and, when at least one report was found, the
    JSON report, the performance CSV and the statistics file.

    When no report is found, an empty CSV report is created and the function
    returns early. Raises Exception when the reports found do not all share
    the same header row.
    """

    # The summary file is created (truncated) here with the buck statistics;
    # the analysis statistics are appended at the end of this function.
    buck_stats = get_buck_stats()
    logging.info(buck_stats)
    with open(os.path.join(args.infer_out, ANALYSIS_SUMMARY_OUTPUT), 'w') as f:
        f.write(buck_stats)

    all_rows = []
    headers = []
    stats = init_stats(args, start_time)

    # Despite its name, this list is used as an exclusion list: keys matching
    # one of these patterns are NOT summed across targets (see the
    # `if not any(...)` test below).
    accumulation_whitelist = list(map(re.compile, [
        '^cores$',
        '^time$',
        '^start_time$',
        '.*_pc',
    ]))

    expected_analyzer = stats['normal']['analyzer']
    expected_version = stats['normal']['infer_version']

    for root, _, files in os.walk(DEFAULT_BUCK_OUT_GEN):
        for f in [f for f in files if f.endswith('.jar')]:
            path = os.path.join(root, f)
            try:
                with zipfile.ZipFile(path) as jar:
                    # Accumulate integers and float values
                    target_stats = load_stats(jar)

                    # NOTE(review): a stats file without these keys raises a
                    # KeyError that is not handled by the except clauses
                    # below.
                    found_analyzer = target_stats['normal']['analyzer']
                    found_version = target_stats['normal']['infer_version']

                    # Skip jars produced by another analyzer or another
                    # version of Infer than the ones of this run.
                    if (found_analyzer != expected_analyzer
                            or found_version != expected_version):
                        continue
                    else:
                        for type_k in ['int', 'float']:
                            items = target_stats.get(type_k, {}).items()
                            for key, value in items:
                                if not any(map(lambda r: r.match(key),
                                           accumulation_whitelist)):
                                    old_value = stats[type_k].get(key, 0)
                                    stats[type_k][key] = old_value + value

                    # If the report is missing, NotFoundInJar is raised here,
                    # after the statistics above were already accumulated.
                    rows = load_report(jar)
                    if len(rows) > 0:
                        # First row is the header, the others are the data.
                        headers.append(rows[0])
                        all_rows.extend(rows[1:])

                    # Override normals
                    stats['normal'].update(target_stats.get('normal', {}))
            except NotFoundInJar:
                # The jar has no stats or no report: nothing (more) to
                # collect from it.
                pass
            except zipfile.BadZipfile:
                logging.warn('Bad zip file %s', path)

    csv_report = os.path.join(args.infer_out, utils.CSV_REPORT_FILENAME)
    bugs_out = os.path.join(args.infer_out, utils.BUGS_FILENAME)

    if len(headers) == 0:
        # Still create an empty report file, then stop: none of the outputs
        # produced below is written in this case.
        with open(csv_report, 'w'):
            pass
        logging.info('No reports found')
        return
    elif len(headers) > 1:
        # All the reports must share the same header row to be merged.
        if any(map(lambda x: x != headers[0], headers)):
            raise Exception('Inconsistent reports found')

    # Convert all float values to integer values
    for key, value in stats.get('float', {}).items():
        stats['int'][key] = int(round(value))

    # Delete the float entries before exporting the results
    del(stats['float'])

    # Merged report: one header row followed by the de-duplicated data rows.
    with open(csv_report, 'w') as report:
        writer = csv.writer(report)
        writer.writerows([headers[0]] + rows_remove_duplicates(all_rows))
        report.flush()

    # export the CSV rows to JSON
    utils.create_json_report(args.infer_out)

    print('\n')
    inferlib.print_errors(csv_report, bugs_out)

    stats['int']['total_time'] = int(round(utils.elapsed_time(start_time)))

    store_performances_csv(args.infer_out, stats)

    stats_filename = os.path.join(args.infer_out, utils.STATS_FILENAME)
    with open(stats_filename, 'w') as stats_out:
        json.dump(stats, stats_out)

    basic_stats = get_basic_stats(stats)

    if args.print_harness:
        harness_code = get_harness_code()
        basic_stats += harness_code

    logging.info(basic_stats)

    # Append to the summary file created at the top of this function.
    with open(os.path.join(args.infer_out, ANALYSIS_SUMMARY_OUTPUT), 'a') as f:
        f.write(basic_stats)


def cleanup(temp_files):
    """Removes the generated .buckconfig.local and the temporary files, then
    restores the backed-up buck configuration if there is one.

    temp_files is the list of files or directories returned by
    prepare_build(); it is empty when prepare_build() did not complete.

    BUCK_CONFIG is only removed when it can be assumed to be the generated
    one, i.e. when a backup exists or when prepare_build() completed
    (temp_files is not empty). Otherwise the file is the user's own
    configuration, which was never backed up, and deleting it would lose
    it. The trade-off: a generated config left by a prepare_build() that
    failed while writing it, with no previous config to back up, is not
    removed.

    Failures to remove a path are logged and do not prevent the following
    removals nor the restoration of the backup.
    """
    backup_exists = os.path.isfile(BUCK_CONFIG_BACKUP)

    to_remove = list(temp_files)
    if backup_exists or temp_files:
        to_remove.insert(0, BUCK_CONFIG)

    for path in to_remove:
        try:
            logging.info('Removing %s' % path)
            if os.path.isdir(path):
                shutil.rmtree(path)
            else:
                os.unlink(path)
        except (IOError, OSError):
            # os.unlink and shutil.rmtree raise OSError, which is distinct
            # from IOError in Python 2.
            logging.error('Could not remove %s' % path)

    if backup_exists:
        logging.info('Restoring %s', BUCK_CONFIG)
        shutil.move(BUCK_CONFIG_BACKUP, BUCK_CONFIG)


if __name__ == '__main__':
    # Command line: the options of inferlib.base_parser, three extra flags
    # and the list of build targets.
    parser = argparse.ArgumentParser(parents=[inferlib.base_parser])
    for flag, description in [
        ('--verbose', 'Print buck compilation steps'),
        ('--no-cache', 'Do not use buck distributed cache'),
        ('--print-harness', 'Print generated harness code (Android only)'),
    ]:
        parser.add_argument(flag, action='store_true', help=description)
    parser.add_argument('targets', nargs='*', metavar='target',
                        help='Build targets to analyze')
    args = parser.parse_args()

    utils.configure_logging(args.verbose)
    timer = utils.Timer(logging.info)
    temp_files = []

    try:
        start_time = time.time()
        logging.info('Starting the analysis')
        infer_analyze = utils.get_cmd_in_bin_dir('InferAnalyze')
        subprocess.check_call([infer_analyze, '-version'])

        if not os.path.isdir(args.infer_out):
            os.mkdir(args.infer_out)

        timer.start('Preparing build...')
        temp_files.extend(prepare_build(args))
        timer.stop('Build prepared')

        # TODO(t3786463) Start buckd.

        timer.start('Computing library targets')
        determine_library_targets(args)
        timer.stop('%d targets computed', len(args.targets))

        timer.start('Running buck...')
        buck_cmd = ['buck', 'build']
        if args.no_cache:
            buck_cmd.append('--no-cache')
        if args.verbose:
            buck_cmd.extend(['-v', '2'])
        subprocess.check_call(buck_cmd + args.targets)
        timer.stop('Buck finished')

        timer.start('Collecting results...')
        collect_results(args, start_time)
        timer.stop('Done')

    except KeyboardInterrupt:
        # An interruption by the user is not reported as a failure.
        timer.stop('Exiting')
        sys.exit(0)
    except Exception as e:
        timer.stop('Failed')
        logging.error('Caught %s: %s' % (e.__class__.__name__, str(e)))
        logging.error(traceback.format_exc())
        sys.exit(1)
    finally:
        # Runs on every exit path, including the sys.exit() calls above.
        cleanup(temp_files)


# vim: set sw=4 ts=4 et:
