#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2015 Matt Martz <matt@sivel.net>
# Copyright (C) 2015 Rackspace US, Inc.
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import print_function

import abc
import argparse
import ast
import os
import re
import sys

from distutils.version import StrictVersion
from fnmatch import fnmatch

from ansible import __version__ as ansible_version
from ansible.executor.module_common import REPLACER_WINDOWS
from ansible.plugins import module_loader
from ansible.utils.module_docs import BLACKLIST_MODULES, get_docstring

from module_args import get_argument_spec

from schema import doc_schema, option_schema

from utils import CaptureStd

import yaml


# Directory names that are skipped when walking a module directory tree
BLACKLIST_DIRS = frozenset(('.git', 'test', '.github', '.idea'))
# Matches a (possibly empty) run of tab characters
INDENT_REGEX = re.compile(r'([\t]*)')
# Maps a regular expression, applied with ``re.search`` to the names of
# imported modules, to the policy for that import:
#   new_only -- when True the import is only an error in new modules
#   msg      -- the error message that is reported
BLACKLIST_IMPORTS = {
    'requests': {
        'new_only': True,
        'msg': ('requests import found, should use '
                'ansible.module_utils.urls instead')
    },
    # Raw string: ``\.`` is a regex escape, not a string escape
    r'boto(?:\.|$)': {
        'new_only': True,
        'msg': ('boto import found, new modules should use boto3')
    },
}


class Validator(object):
    """Base class for validators.

    A Validator instance checks exactly one object and accumulates the
    findings for it.  Create a separate instance for every object that is
    scanned.
    """
    __metaclass__ = abc.ABCMeta

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard all collected results"""
        # Error and warning messages
        self.errors = []
        self.warnings = []
        # Trace objects printed alongside errors and warnings respectively
        self.traces = []
        self.warning_traces = []

    @abc.abstractproperty
    def object_name(self):
        """Name of the validated object"""

    @abc.abstractproperty
    def object_path(self):
        """Path of the validated object"""

    @abc.abstractmethod
    def validate(self, reset=True):
        """Generate the test results; clears earlier results when ``reset``
        is true"""
        if not reset:
            return
        self.reset()

    def report(self, warnings=False):
        """Print the collected results and return the number of errors.

        Warnings (and the traces belonging to them) are only printed when
        ``warnings`` is true, and they never count towards the return value.
        """
        show_warnings = bool(warnings and self.warnings)
        has_output = bool(self.errors) or show_warnings
        banner = '=' * 76

        if has_output:
            print(banner)
            print(self.object_path)
            print(banner)

        pending_traces = list(self.traces)
        if show_warnings:
            pending_traces += self.warning_traces

        for trace in pending_traces:
            print('TRACE:')
            # Indent every line of the trace by four spaces
            print('\n    '.join(('    %s' % trace).splitlines()))

        for error in self.errors:
            print('ERROR: %s' % error)

        if warnings:
            # Warnings do not increment the exit status
            for warning in self.warnings:
                print('WARNING: %s' % warning)

        if has_output:
            print()

        return len(self.errors)


class ModuleValidator(Validator):
    """Validator for a single module file (python ``.py`` or powershell
    ``.ps1``)."""
    # Glob patterns matched against the file name; matching files are not
    # checked
    BLACKLIST_PATTERNS = ('.git*', '*.pyc', '*.pyo', '.*', '*.md', '*.txt')
    # File names that are not checked
    BLACKLIST_FILES = frozenset(('.git', '.gitignore', '.travis.yml',
                                 'shippable.yml',
                                 '.gitattributes', '.gitmodules', 'COPYING',
                                 '__init__.py', 'VERSION', 'test-docs.sh'))
    # BLACKLIST_FILES combined with the BLACKLIST_MODULES imported from
    # ansible.utils.module_docs
    BLACKLIST = BLACKLIST_FILES.union(BLACKLIST_MODULES)

    # Powershell modules that are not required to have a companion .py
    # documentation file
    PS_DOC_BLACKLIST = frozenset((
        'async_status.ps1',
        'slurp.ps1',
        'setup.ps1'
    ))

    def __init__(self, path, analyze_arg_spec=False):
        """Read the module at ``path`` and try to parse it as python.

        ``path`` is the file to validate.  When ``analyze_arg_spec`` is true,
        validation also inspects the module's argument spec.
        """
        super(ModuleValidator, self).__init__()

        self.path = path
        self.basename = os.path.basename(self.path)
        self.name = os.path.splitext(self.basename)[0]

        self.analyze_arg_spec = analyze_arg_spec

        # Set by validate() when the file has neither a .py nor a .ps1
        # extension, so that it is still checked as a python module
        self._python_module_override = False

        with open(path) as f:
            self.text = f.read()
        self.length = len(self.text.splitlines())
        try:
            self.ast = ast.parse(self.text)
        except Exception:
            # The text is not parseable python.  validate() reports this for
            # python modules.  KeyboardInterrupt and SystemExit are
            # deliberately allowed to propagate.
            self.ast = None

    @property
    def object_name(self):
        """File name of the module (``self.basename``)"""
        return self.basename

    @property
    def object_path(self):
        """Path of the module as it was passed to the constructor"""
        return self.path

    def _python_module(self):
        """Return True when the file has a .py extension or is being treated
        as a python module through the override flag"""
        return bool(self.path.endswith('.py') or
                    self._python_module_override)

    def _powershell_module(self):
        """Return True when the file has a .ps1 extension"""
        return self.path.endswith('.ps1')

    def _just_docs(self):
        """Return True when every top level statement of the module is an
        assignment; False when there is no parsed AST"""
        try:
            statements = self.ast.body
        except AttributeError:
            # self.ast is None when the file could not be parsed
            return False
        return all(isinstance(node, ast.Assign) for node in statements)

    def _is_new_module(self):
        """Return True when the module loader does not already know a plugin
        with this module's name"""
        # The symlink test has to look at the module file itself; the bare
        # module name would be resolved relative to the current directory
        if self.name.startswith("_") and not os.path.islink(self.path):
            # This is a deprecated module, so look up the *old* name
            return not module_loader.has_plugin(self.name[1:])
        return not module_loader.has_plugin(self.name)

    def _check_interpreter(self, powershell=False):
        """Record an error when the module text does not start with the
        expected interpreter line"""
        if powershell:
            # The powershell interpreter line must be the complete first line
            expected = '#!powershell\n'
            message = 'Interpreter line is not "#!powershell"'
        else:
            expected = '#!/usr/bin/python'
            message = 'Interpreter line is not "#!/usr/bin/python"'

        if not self.text.startswith(expected):
            self.errors.append(message)

    def _check_for_sys_exit(self):
        """Record an error when the module text contains ``sys.exit(``"""
        if 'sys.exit(' not in self.text:
            return
        message = 'sys.exit() call found. Should be exit_json/fail_json'
        self.errors.append(message)

    def _check_for_gpl3_header(self):
        """Record an error unless the module text contains both phrases of
        the GPLv3 header"""
        # Both phrases are required: the license name alone does not identify
        # the version, and "version 3" alone can appear in unrelated text
        if ('GNU General Public License' not in self.text or
                'version 3' not in self.text):
            self.errors.append('GPLv3 license header not found')

    def _check_for_tabs(self):
        """Record an error for every line whose indentation contains a tab.

        The reported line number is 1-based; the column is the 0-based index
        of the first tab in the line.
        """
        for line_no, line in enumerate(self.text.splitlines(), 1):
            # Only the leading whitespace is indentation; a tab later in the
            # line (for example inside a string literal) is not reported
            indent = line[:len(line) - len(line.lstrip())]
            if '\t' in indent:
                self.errors.append('indentation contains tabs. line %d '
                                   'column %d' % (line_no,
                                                  indent.index('\t')))

    def _find_blacklist_imports(self):
        """Record an error for every blacklisted ``import`` statement.

        Looks at top level ``import`` statements and at ``import`` statements
        directly inside a top level try/except block (its body and handlers).
        Entries of BLACKLIST_IMPORTS flagged ``new_only`` are only reported
        for new modules.
        """
        # ast.TryExcept only exists on python 2; python 3 uses ast.Try
        try_types = getattr(ast, 'TryExcept', None) or ast.Try

        for child in self.ast.body:
            names = []
            if isinstance(child, ast.Import):
                names.extend(child.names)
            elif isinstance(child, try_types):
                # Work on a copy so the handler bodies are not appended to
                # the try body of the shared AST
                bodies = list(child.body)
                for handler in child.handlers:
                    bodies.extend(handler.body)
                for grandchild in bodies:
                    if isinstance(grandchild, ast.Import):
                        names.extend(grandchild.names)

            for name in names:
                for blacklist_import, options in BLACKLIST_IMPORTS.items():
                    if not re.search(blacklist_import, name.name):
                        continue
                    if not options['new_only'] or self._is_new_module():
                        self.errors.append(options['msg'])

    def _find_module_utils(self, main):
        """Return the line numbers of top level ``ansible.module_utils``
        imports.

        Records an error when no such import exists, and a warning when one
        exists but ``ansible.module_utils.basic`` was not found.  ``main`` is
        accepted for the caller's convenience and is not used.
        """
        linenos = []
        found_basic = False
        for child in self.ast.body:
            if not isinstance(child, (ast.Import, ast.ImportFrom)):
                continue

            names = [alias.name for alias in child.names]
            # Only ``from X import ...`` has a module, and it is None for a
            # relative import such as ``from . import x``
            module = getattr(child, 'module', None)
            if module:
                names.insert(0, module)
                if module.endswith('.basic'):
                    found_basic = True

            if any(n.startswith('ansible.module_utils') for n in names):
                linenos.append(child.lineno)
                # ``from ansible.module_utils import basic``
                if any(alias.name == 'basic' for alias in child.names):
                    found_basic = True

        if not linenos:
            self.errors.append('Did not find a module_utils import')
        elif not found_basic:
            self.warnings.append('Did not find "ansible.module_utils.basic" '
                                 'import')

        return linenos

    def _find_main_call(self):
        """Return the line number of the ``main()`` call, or 0 if there is
        none.

        The call is searched at the top level of the module and directly
        inside top level ``if __name__ ...`` blocks.  Records an error when
        no call is found or when a call is not at the end of the file.
        """
        lineno = 0
        if_bodies = []
        for child in self.ast.body:
            if isinstance(child, ast.If):
                try:
                    if child.test.left.id == '__name__':
                        if_bodies.extend(child.body)
                except AttributeError:
                    # The test is not a comparison starting with a plain name
                    pass

        # Work on a copy: extending self.ast.body itself would permanently
        # add the ``if`` bodies to the module's top level statements, which
        # the other checks inspect as well
        bodies = list(self.ast.body)
        bodies.extend(if_bodies)

        for child in bodies:
            if not (isinstance(child, ast.Expr) and
                    isinstance(child.value, ast.Call)):
                continue
            func = child.value.func
            if isinstance(func, ast.Name) and func.id == 'main':
                lineno = child.lineno
                if lineno < self.length - 1:
                    self.errors.append('Call to main() not the last '
                                       'line')

        if not lineno:
            self.errors.append('Did not find a call to main')

        return lineno

    def _find_has_import(self):
        """Warn about top level try/except blocks that contain an ``import``
        but never assign a ``HAS_*`` name.

        Both the try body and the handler bodies are inspected; the name
        comparison is case-insensitive.
        """
        # ast.TryExcept only exists on python 2; python 3 uses ast.Try
        try_types = getattr(ast, 'TryExcept', None) or ast.Try

        for child in self.ast.body:
            if not isinstance(child, try_types):
                continue

            # Work on a copy so the handler bodies are not appended to the
            # try body of the shared AST
            bodies = list(child.body)
            for handler in child.handlers:
                bodies.extend(handler.body)

            found_try_except_import = False
            found_has = False
            for grandchild in bodies:
                if isinstance(grandchild, ast.Import):
                    found_try_except_import = True
                elif isinstance(grandchild, ast.Assign):
                    for target in grandchild.targets:
                        # Attribute, subscript and tuple targets have no
                        # ``id`` and cannot be a HAS_ flag
                        if (isinstance(target, ast.Name) and
                                target.id.lower().startswith('has_')):
                            found_has = True

            if found_try_except_import and not found_has:
                self.warnings.append('Found Try/Except block without HAS_ '
                                     'assignment')

    def _find_ps_replacers(self):
        """Record an error for each required marker that is missing from the
        powershell module text"""
        required = (
            ('WANT_JSON', 'WANT_JSON not found in module'),
            (REPLACER_WINDOWS,
             '"%s" not found in module' % REPLACER_WINDOWS),
        )
        for marker, message in required:
            if marker not in self.text:
                self.errors.append(message)

    def _find_ps_docs_py_file(self):
        """Record an error when a powershell module has no companion .py
        file next to it; modules in PS_DOC_BLACKLIST are exempt"""
        if self.object_name in self.PS_DOC_BLACKLIST:
            return
        # Swap only the file extension; str.replace would also rewrite a
        # '.ps1' that occurs earlier in the path (e.g. in a directory name)
        py_path = os.path.splitext(self.path)[0] + '.py'
        if not os.path.isfile(py_path):
            self.errors.append('Missing python documentation file')

    def _get_docs(self):
        """Collect the DOCUMENTATION, EXAMPLES and RETURN strings.

        Returns a dict keyed by those three names.  Each entry is a dict with
        the assigned string under ``value`` (None when not found) and the
        line number of the assignment under ``lineno`` (0 when not found).
        """
        docs = dict(
            (section, {'value': None, 'lineno': 0})
            for section in ('DOCUMENTATION', 'EXAMPLES', 'RETURN')
        )
        string_types = (str, type(u''))

        for child in self.ast.body:
            if not isinstance(child, ast.Assign):
                continue

            # Only string literals are usable; anything else is treated as if
            # the section was not provided
            text = getattr(child.value, 's', None)
            if not isinstance(text, string_types):
                continue

            for grandchild in child.targets:
                # Tuple, attribute and subscript targets have no ``id``
                if not isinstance(grandchild, ast.Name):
                    continue
                if grandchild.id not in docs:
                    continue
                if grandchild.id == 'EXAMPLES':
                    # The first character is dropped -- presumably the
                    # newline that follows the opening quotes
                    docs['EXAMPLES']['value'] = text[1:]
                else:
                    docs[grandchild.id]['value'] = text
                docs[grandchild.id]['lineno'] = child.lineno

        return docs

    def _validate_docs_schema(self, doc):
        """Validate the parsed DOCUMENTATION against ``doc_schema`` and each
        of its options against ``option_schema``.

        Every schema error is recorded as an error prefixed with the dotted
        path of the offending key.
        """
        # NOTE(review): this relies on the schema callables raising an
        # exception that carries an ``errors`` list whose items have ``path``
        # and ``error_message`` -- confirm against the schema module
        errors = []
        try:
            doc_schema(doc)
        except Exception as e:
            errors.extend(e.errors)

        # A document that is not a mapping, or whose ``options`` is null or
        # not a mapping, has no options to check individually; doc_schema
        # above is responsible for reporting that
        options = doc.get('options') if isinstance(doc, dict) else None
        if not isinstance(options, dict):
            options = {}

        for key, option in options.items():
            try:
                option_schema(option)
            except Exception as e:
                for error in e.errors:
                    # Prefix the path so the message names the option
                    error.path[:0] = ['options', key]
                errors.extend(e.errors)

        for error in errors:
            path = [str(p) for p in error.path]
            self.errors.append('DOCUMENTATION.%s: %s' %
                               ('.'.join(path), error.error_message))

    def _validate_docs(self):
        """Validate the DOCUMENTATION, EXAMPLES and RETURN strings.

        DOCUMENTATION must be valid YAML holding a mapping; it is then run
        through ``get_docstring``, the schema check and the version_added
        checks.  EXAMPLES must not be empty.  A missing RETURN is an error
        for new modules and a warning otherwise; when present it must be
        valid YAML.
        """
        doc_info = self._get_docs()
        documentation = doc_info['DOCUMENTATION']
        try:
            doc = yaml.safe_load(documentation['value'])
        except yaml.YAMLError as e:
            self._record_yaml_error(e, 'DOCUMENTATION',
                                    documentation['lineno'])
        except AttributeError:
            # Presumably raised by safe_load when the value is None, i.e.
            # no DOCUMENTATION string was found
            self.errors.append('No DOCUMENTATION provided')
        else:
            if not doc:
                # An empty string parses to None
                self.errors.append('No DOCUMENTATION provided')
            elif not isinstance(doc, dict):
                self.errors.append('DOCUMENTATION is not a YAML dictionary')
            else:
                with CaptureStd():
                    try:
                        get_docstring(self.path, verbose=True)
                    except AssertionError:
                        fragment = doc.get('extends_documentation_fragment')
                        self.errors.append('DOCUMENTATION fragment missing: '
                                           '%s' % fragment)
                    except Exception as e:
                        self.traces.append(e)
                        self.errors.append('Unknown DOCUMENTATION error, see '
                                           'TRACE')

                self._validate_docs_schema(doc)
                self._check_version_added(doc)
                self._check_for_new_args(doc)

        if not doc_info['EXAMPLES']['value']:
            self.errors.append('No EXAMPLES provided')

        returns = doc_info['RETURN']
        if not returns['value']:
            if self._is_new_module():
                self.errors.append('No RETURN documentation provided')
            else:
                self.warnings.append('No RETURN provided')
        else:
            try:
                yaml.safe_load(returns['value'])
            except yaml.YAMLError as e:
                self._record_yaml_error(e, 'RETURN', returns['lineno'])

    def _record_yaml_error(self, exc, section, lineno):
        """Record a YAML parse error for ``section`` as a trace and an error.

        ``lineno`` is the line of the module on which the section starts.
        """
        self.traces.append(exc)

        mark = getattr(exc, 'problem_mark', None)
        if mark is None:
            # Not every YAML error carries a position
            self.errors.append('%s is not valid YAML' % section)
            return

        # This offsets the error line number to where the section starts so
        # we can just go to that line in the module
        mark.line += lineno - 1
        mark.name = '%s.%s' % (self.name, section)
        self.errors.append('%s is not valid YAML. Line %d column %d' %
                           (section, mark.line + 1, mark.column + 1))

    def _check_version_added(self, doc):
        """For new modules, record an error unless ``version_added`` is a
        valid version equal to the major.minor of the running ansible"""
        if not self._is_new_module():
            return

        declared = doc.get('version_added', '0.0')
        try:
            version_added = StrictVersion(str(declared))
        except ValueError:
            self.errors.append('version_added is not a valid version '
                               'number: %r' % declared)
            return

        # Only major.minor of the ansible version is compared
        should_be = '.'.join(ansible_version.split('.')[:2])

        if version_added != StrictVersion(should_be):
            self.errors.append('version_added should be %s. Currently %s' %
                               (should_be, version_added))

    def _validate_argument_spec(self):
        """When argument spec analysis is enabled, record an error for every
        argument that is both required and has a default"""
        if not self.analyze_arg_spec:
            return

        for arg, data in get_argument_spec(self.path).items():
            if not data.get('required'):
                continue
            # ``object`` acts as the "no default given" sentinel
            if data.get('default', object) != object:
                self.errors.append('"%s" is marked as required but specifies '
                                   'a default. Arguments with a default '
                                   'should not be marked as required' % arg)

    def _check_for_new_args(self, doc):
        """For existing modules, check ``version_added`` of options that the
        already installed copy of the module does not document.

        A new option must declare a valid ``version_added`` equal to the
        major.minor of the running ansible, unless the module itself was
        added in that same version.  Problems reading the pre-existing
        documentation are recorded as warnings.
        """
        if self._is_new_module():
            return
        if not isinstance(doc, dict):
            # Nothing to compare; the schema check reports malformed docs
            return

        with CaptureStd():
            try:
                existing = module_loader.find_plugin(self.name, mod_type='.py')
                existing_doc, _, _ = get_docstring(existing, verbose=True)
                # ``options`` may be present but null
                existing_options = existing_doc.get('options') or {}
            except AssertionError:
                fragment = doc.get('extends_documentation_fragment')
                self.warnings.append('Pre-existing DOCUMENTATION fragment '
                                     'missing: %s' % fragment)
                return
            except Exception as e:
                self.warning_traces.append(e)
                self.warnings.append('Unknown pre-existing DOCUMENTATION '
                                     'error, see TRACE. Submodule refs may '
                                     'need updated')
                return

        try:
            mod_version_added = StrictVersion(
                str(existing_doc.get('version_added', '0.0'))
            )
        except ValueError:
            mod_version_added = StrictVersion('0.0')

        options = doc.get('options')
        if not isinstance(options, dict):
            # Missing, null or malformed options: nothing to iterate
            options = {}

        should_be = '.'.join(ansible_version.split('.')[:2])
        strict_ansible_version = StrictVersion(should_be)

        for option, details in options.items():
            if existing_options.get(option):
                # Not a new option
                continue

            try:
                version_added = StrictVersion(
                    str(details.get('version_added', '0.0'))
                )
            except ValueError:
                version_added = details.get('version_added', '0.0')
                self.errors.append('version_added for new option (%s) '
                                   'is not a valid version number: %r' %
                                   (option, version_added))
                continue
            except Exception:
                # If there is any other exception it should have been caught
                # in schema validation, so we won't duplicate errors by
                # listing it again
                continue

            if (strict_ansible_version != mod_version_added and
                    version_added != strict_ansible_version):
                self.errors.append('version_added for new option (%s) should '
                                   'be %s. Currently %s' %
                                   (option, should_be, version_added))

    def validate(self, reset=True):
        """Run all checks that apply to this module file.

        Blacklisted files are skipped.  Results are collected in
        ``self.errors``, ``self.warnings`` and the trace lists; earlier
        results are cleared first unless ``reset`` is false.
        """
        super(ModuleValidator, self).validate(reset=reset)

        # Blacklists -- these files are not checked
        if not frozenset((self.basename,
                          self.name)).isdisjoint(self.BLACKLIST):
            return
        if any(fnmatch(self.basename, pat)
               for pat in self.BLACKLIST_PATTERNS):
            return

        if not self._python_module() and not self._powershell_module():
            self.errors.append('Official Ansible modules must have a .py '
                               'extension for python modules or a .ps1 '
                               'for powershell modules')
            # Still run the python checks against the file
            self._python_module_override = True

        if self._python_module() and self.ast is None:
            self.errors.append('Python SyntaxError while parsing module')
            try:
                # Compile again only to obtain the exception for the trace
                compile(self.text, self.path, 'exec')
            except Exception as e:
                self.traces.append(e)
            return

        if self._python_module():
            self._validate_docs()

            if not self._just_docs():
                self._validate_argument_spec()
                self._check_for_sys_exit()
                self._find_blacklist_imports()
                main = self._find_main_call()
                self._find_module_utils(main)
                self._find_has_import()
                self._check_for_tabs()

        if self._powershell_module():
            self._find_ps_replacers()
            self._find_ps_docs_py_file()

        self._check_for_gpl3_header()
        if not self._just_docs():
            self._check_interpreter(powershell=self._powershell_module())


class PythonPackageValidator(Validator):
    """Validator for a single directory: checks that it contains an
    ``__init__.py``."""
    # Directory names that are not checked
    BLACKLIST_FILES = frozenset(('__pycache__',))

    def __init__(self, path):
        super(PythonPackageValidator, self).__init__()

        self.path = path
        self.basename = os.path.basename(path)

    @property
    def object_name(self):
        """Name of the directory (basename of ``path``)"""
        return self.basename

    @property
    def object_path(self):
        """Path of the directory as it was passed to the constructor"""
        return self.path

    def validate(self, reset=True):
        """Record an error when the directory has no ``__init__.py``.

        Earlier results are cleared first unless ``reset`` is false.
        """
        super(PythonPackageValidator, self).validate(reset=reset)

        if self.basename in self.BLACKLIST_FILES:
            return

        init_file = os.path.join(self.path, '__init__.py')
        if not os.path.exists(init_file):
            self.errors.append('Ansible module subdirectories must contain an '
                               '__init__.py')


def re_compile(value):
    """
    Compile ``value`` as a regular expression for use as an argparse ``type``

    Argparse expects type callables to raise TypeError, whereas re.compile
    raises re.error, so an invalid pattern is re-raised as TypeError
    """

    try:
        pattern = re.compile(value)
    except re.error as exc:
        raise TypeError(exc)
    return pattern


def main():
    """Command line entry point.

    Validates every module file given on the command line, and every file
    and subdirectory below each given directory, prints the reports and
    exits with the number of errors found (capped at 255).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('modules', nargs='+',
                        help='Path to module or module directory')
    parser.add_argument('-w', '--warnings', help='Show warnings',
                        action='store_true')
    parser.add_argument('--exclude', help='RegEx exclusion pattern',
                        type=re_compile)
    parser.add_argument('--arg-spec', help='Analyze module argument spec',
                        action='store_true', default=False)
    args = parser.parse_args()

    args.modules[:] = [m.rstrip('/') for m in args.modules]

    def excluded(path):
        """Return True when ``path`` matches the --exclude pattern"""
        return bool(args.exclude and args.exclude.search(path))

    error_count = 0

    for module in args.modules:
        if os.path.isfile(module):
            # An excluded file is skipped on its own; the remaining
            # arguments are still validated and earlier errors still count
            if not excluded(module):
                mv = ModuleValidator(module, analyze_arg_spec=args.arg_spec)
                mv.validate()
                error_count += mv.report(args.warnings)
            continue

        for root, dirs, files in os.walk(module):
            # First path component below ``module``
            basedir = root[len(module)+1:].split('/', 1)[0]
            if basedir in BLACKLIST_DIRS:
                continue
            for dirname in dirs:
                if root == module and dirname in BLACKLIST_DIRS:
                    continue
                path = os.path.join(root, dirname)
                if excluded(path):
                    continue
                pv = PythonPackageValidator(path)
                pv.validate()
                error_count += pv.report(args.warnings)

            for filename in files:
                path = os.path.join(root, filename)
                if excluded(path):
                    continue
                mv = ModuleValidator(path, analyze_arg_spec=args.arg_spec)
                mv.validate()
                error_count += mv.report(args.warnings)

    # Exit statuses are truncated to 8 bits on POSIX, so an uncapped count of
    # e.g. 256 errors would be reported as success
    sys.exit(min(error_count, 255))


if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # Interrupted with Ctrl-C: stop without printing a traceback
        pass
