#!/usr/bin/env python
import sys

from res.enkf import EnKFMain, ResConfig, ESUpdate, ErtRunContext
from res.enkf.enums import RealizationStateEnum, HookRuntime
from ecl.util.util import BoolVector
import res
from res.util import ResLog

import argparse


def setup_fs(ert, target="default"):
    """Look up the source and target file systems for a run.

    :param ert: object exposing ``getEnkfFsManager()``.
    :param target: name passed to the manager's ``getFileSystem``.
    :return: tuple ``(source_fs, target_fs)`` where ``source_fs`` is the
        manager's current file system and ``target_fs`` is whatever the
        manager returns for ``target``.
    """
    manager = ert.getEnkfFsManager()
    # Tuple elements are evaluated left to right, so the current file
    # system is fetched before the target one is requested.
    return manager.getCurrentFileSystem(), manager.getFileSystem(target)


def resconfig(config_file):
    """Build a ``ResConfig`` from the given user config file.

    :param config_file: value forwarded as ``user_config_file``.
    :return: the newly created ``ResConfig`` instance.
    """
    loaded_config = ResConfig(user_config_file=config_file)
    return loaded_config


def _test_run(ert):
    """Run an ensemble experiment with a mask that only enables index 0.

    The run uses the current file system as simulation file system and
    iteration number 0.

    :param ert: the ``EnKFMain`` instance to run with.
    """
    current_fs = setup_fs(ert)[0]

    model_cfg = ert.getModelConfig()
    data_kw = ert.getDataKW()

    # Every entry defaults to False; only the first one is switched on.
    first_only = BoolVector(default_value=False)
    first_only[0] = True

    context = ErtRunContext.ensemble_experiment(
        sim_fs=current_fs,
        mask=first_only,
        path_fmt=model_cfg.getRunpathFormat(),
        jobname_fmt=model_cfg.getJobnameFormat(),
        subst_list=data_kw,
        itr=0,
    )

    _run_ensemble_experiment(ert, context, ert.getEnkfSimulationRunner())


def _experiment_run(ert):
    """Run an ensemble experiment with every mask entry enabled.

    The mask is sized to ``ert.getEnsembleSize()`` and defaults to True.
    The run uses the current file system as simulation file system and
    iteration number 0.

    :param ert: the ``EnKFMain`` instance to run with.
    """
    current_fs = setup_fs(ert)[0]

    model_cfg = ert.getModelConfig()
    data_kw = ert.getDataKW()

    all_active = BoolVector(
        default_value=True,
        initial_size=ert.getEnsembleSize(),
    )

    context = ErtRunContext.ensemble_experiment(
        sim_fs=current_fs,
        mask=all_active,
        path_fmt=model_cfg.getRunpathFormat(),
        jobname_fmt=model_cfg.getJobnameFormat(),
        subst_list=data_kw,
        itr=0,
    )

    _run_ensemble_experiment(ert, context, ert.getEnkfSimulationRunner())


def _ensemble_smoother_run(ert, target_case):
    """Run prior simulations, a smoother update and a rerun.

    Steps, in order:

    1. Run the ensemble (iteration 0) from the current file system with
       ``target_case`` as target file system.
    2. Run the PRE_UPDATE workflows, perform the smoother update and run
       the POST_UPDATE workflows.
    3. Switch to the target file system and run the ensemble again
       (iteration 1) for the entries selected by the state-map mask.

    :param ert: the ``EnKFMain`` instance to run with.
    :param target_case: name of the file system receiving the update.
    :raises AssertionError: if the smoother update reports failure.
    """
    prior_fs, posterior_fs = setup_fs(ert, target_case)

    model_cfg = ert.getModelConfig()
    data_kw = ert.getDataKW()

    all_active = BoolVector(
        default_value=True,
        initial_size=ert.getEnsembleSize(),
    )

    prior_context = ErtRunContext.ensemble_smoother(
        sim_fs=prior_fs,
        target_fs=posterior_fs,
        mask=all_active,
        path_fmt=model_cfg.getRunpathFormat(),
        jobname_fmt=model_cfg.getJobnameFormat(),
        subst_list=data_kw,
        itr=0,
    )

    runner = ert.getEnkfSimulationRunner()
    _run_ensemble_experiment(ert, prior_context, runner)
    runner.runWorkflows(HookRuntime.PRE_UPDATE)

    if not ESUpdate(ert).smootherUpdate(prior_context):
        raise AssertionError("Analysis of simulation failed!")

    runner.runWorkflows(HookRuntime.POST_UPDATE)

    # Continue from the file system the update was written to.
    ert.getEnkfFsManager().switchFileSystem(prior_context.get_target_fs())
    updated_fs = prior_context.get_target_fs()

    # Mask built from the state map for entries matching either state flag.
    wanted_state = (RealizationStateEnum.STATE_HAS_DATA
                    | RealizationStateEnum.STATE_INITIALIZED)
    rerun_mask = updated_fs.getStateMap().createMask(wanted_state)

    rerun_context = ErtRunContext.ensemble_smoother(
        sim_fs=updated_fs,
        target_fs=None,
        mask=rerun_mask,
        path_fmt=model_cfg.getRunpathFormat(),
        jobname_fmt=model_cfg.getJobnameFormat(),
        subst_list=data_kw,
        itr=1,
    )

    _run_ensemble_experiment(ert, rerun_context, runner)


def _run_ensemble_experiment(ert, run_context, sim_runner):
    """Create run paths, run the simulations and the surrounding hooks.

    PRE_SIMULATION workflows run after the run paths are created and
    before the simulations; POST_SIMULATION workflows run only when the
    success check has passed.

    :param ert: the ``EnKFMain`` instance; provides the queue config and
        is used for the success check.
    :param run_context: run context handed to the simulation runner.
    :param sim_runner: simulation runner executing paths, hooks and jobs.
    :raises AssertionError: if too few realizations were successful.
    """
    sim_runner.createRunPath(run_context)
    sim_runner.runWorkflows(HookRuntime.PRE_SIMULATION)

    queue = ert.get_queue_config().create_job_queue()
    success_count = sim_runner.runEnsembleExperiment(queue, run_context)

    # Abort before reporting or running post hooks when too much failed.
    _assert_minium_realizations_success(ert, success_count)

    summary = "{} of the realizations were successful".format(success_count)
    print(summary)
    sim_runner.runWorkflows(HookRuntime.POST_SIMULATION)


def _assert_minium_realizations_success(ert, num_successful_realizations):
    """Raise unless enough realizations were successful.

    :param ert: object exposing ``analysisConfig()`` and
        ``getEnsembleSize()``.
    :param num_successful_realizations: number of successful realizations.
    :raises AssertionError: if no realization succeeded, or if the
        analysis config reports that there are not enough of them.
    """
    if num_successful_realizations == 0:
        raise AssertionError("Simulation failed! All realizations failed!")

    analysis_config = ert.analysisConfig()
    enough = analysis_config.haveEnoughRealisations(
        num_successful_realizations, ert.getEnsembleSize())
    if enough:
        return

    template = ("Too many simulations have failed! You can add/adjust MIN_REALIZATIONS to allow failures in your simulations.\n\n"
                "Check ERT log file '%s' or simulation folder for details.")
    raise AssertionError(template % ResLog.getFilename())


def main():
    """Run the ERT mode selected on the command line.

    Exactly three arguments are expected after the script name: the
    config file, the mode and the target case. Supported modes are
    ``test_run``, ``ensemble_experiment`` and ``ensemble_smoother``; the
    target case is only used by ``ensemble_smoother``.

    :raises AssertionError: if the number of arguments is wrong or the
        mode is not one of the supported modes.
    """
    # The ert_cli script should be called from ert.in. The arguments are parsed and verified in ert.in
    cli_args = sys.argv[1:]
    if len(cli_args) < 3:
        raise AssertionError("Required arguments are missing, the config-file, "
                             "mode and target case must be provided")
    if len(cli_args) > 3:
        raise AssertionError("Too many arguments, only the config-file, "
                             "mode and target case must be provided")
    config_file, mode, target_case = cli_args

    # Reject unknown modes before the configuration is loaded, instead of
    # loading everything and then silently doing nothing.
    supported_modes = ("test_run", "ensemble_experiment", "ensemble_smoother")
    if mode not in supported_modes:
        raise AssertionError("Unknown mode '%s', must be one of: %s"
                             % (mode, ", ".join(supported_modes)))

    config = resconfig(config_file)
    ert = EnKFMain(config)

    if mode == "test_run":
        _test_run(ert)
    elif mode == "ensemble_experiment":
        _experiment_run(ert)
    else:
        _ensemble_smoother_run(ert, target_case)


# Run the command line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
