#!/usr/bin/env python

import warnings
from argparse import ArgumentParser
import os
import os.path as op
import toml
import datetime
import platform

# Warnings raised while the AFQ package and its submodules are imported are
# ignored for the duration of this block (presumably to keep the
# command-line output clean); the previous warning filters are restored when
# the block exits.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    import AFQ
    from AFQ import api
    from AFQ.utils.bin import get_default_args
    import AFQ.segmentation as seg
    import AFQ.tractography as aft

    # Configure the root logger at INFO level and create the logger that
    # the rest of this script reports through.
    import logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger('pyAFQ')

def process_defaults(defaults_dict):
    """
    Render a dictionary of default values as TOML-style assignment lines.

    Parameters
    ----------
    defaults_dict : dict
        Maps option names to their default values.

    Returns
    -------
    list of str
        One ``name = value`` entry per item, in the iteration order of
        `defaults_dict`. Strings are wrapped in single quotes, booleans
        become ``true``/``false``, callables are represented by their
        quoted ``__name__`` and every other value is inserted as formatted
        by an f-string (so ``None`` becomes the text ``None``).
    """
    def render(value):
        # Strings are tested first and booleans before the generic
        # fallback, so that True/False are written in TOML spelling rather
        # than Python's.
        if isinstance(value, str):
            return f"'{value}'"
        if isinstance(value, bool):
            return "true" if value else "false"
        if callable(value):
            return f"'{value.__name__}'"
        return f"{value}"

    return [f"{name} = {render(value)}"
            for name, value in defaults_dict.items()]


# Keyword defaults of the tracking, segmentation and cleaning entry points.
# They are shown in the usage text and used as fallbacks for options that
# the configuration file does not set.
track_defaults = get_default_args(aft.track)
seg_defaults = get_default_args(seg.Segmentation.__init__)
clean_defaults = get_default_args(seg.clean_bundle)
clean_defaults_list = process_defaults(clean_defaults)

# TOML has no null value, so a default of None is displayed as an empty
# string instead.
none_val = "''"


def replace_none(myl, rep=none_val):
    """
    Return a copy of `myl` with the text ``None`` replaced by `rep`.

    NOTE: every occurrence of the substring ``None`` in an entry is
    replaced, not only an entry whose whole value is ``None``.
    """
    return [v.replace('None', rep) for v in myl]


# For all lists that contain a "None" default value, replace it with an
# empty string to not break toml syntax. As of Feb-7-2020, that is only
# track_defaults_list and seg_defaults_list
track_defaults_list = replace_none(process_defaults(track_defaults))
seg_defaults_list = replace_none(process_defaults(seg_defaults))

# Help text for the command line. The bodies of the [tracking],
# [segmentation] and [cleaning] sections are generated from the defaults of
# the corresponding functions, one ``name = value`` entry per line.
#
# NOTE(review): the text lists ``seg_algo`` under [bundles], but this script
# looks up "seg_algo" in the segmentation parameters; it also shows
# ``scalars = ['dti_fa', 'dti_md']`` while the script falls back to
# ["fa", "md"] when the option is missing. Confirm which is intended.
tracking_defaults_text = "\n".join(track_defaults_list)
segmentation_defaults_text = "\n".join(seg_defaults_list)
cleaning_defaults_text = "\n".join(clean_defaults_list)

usage = f"""pyAFQ /path/to/afq_config.toml

Runs full AFQ processing as specified in the configuration file.

For details about configuration, see instructions in:
https://yeatmanlab.github.io/pyAFQ/usage.html#running-the-pyafq-pipeline

The default values of the configuration are:

```
[files]
dmriprep_path = '/path/to/dmriprep/folder'

[bundles]
bundles = {api.BUNDLES}
seg_algo = 'AFQ'
scalars_model = 'DTI'
scalars = ['dti_fa', 'dti_md']

[tracking]
{tracking_defaults_text}
wm_labels = [250, 251, 252, 253, 254, 255, 41, 2, 16, 77]

[segmentation]
{segmentation_defaults_text}

[cleaning]
{cleaning_defaults_text}

[compute]
dask_it = false
force_recompute = false
virtual_frame_buffer = false
```
"""

# The text must be passed by keyword: the first positional parameter of
# ArgumentParser is ``prog`` (the program name), so passing it positionally
# made the whole help text the program name, which is then repeated in
# front of every error message. argparse %-formats the usage string, hence
# literal percent signs are escaped.
parser = ArgumentParser(usage=usage.replace("%", "%%"))

# Positional argument: the TOML file that configures the run.
parser.add_argument(dest='config', action="store", type=str,
                    help="Path to config file")

# Opt-out switch for the usage tracking performed by this script.
parser.add_argument('--notrack', action="store_true", default=False,
                    help="Do not record this call using Google Analytics")

opts = parser.parse_args()

# Usage tracking is on unless --notrack was given. The user is told about it
# before the event is sent, and popylar is only imported when it is needed.
track_usage = not opts.notrack
if track_usage:
    tracking_notice = (
        "Your use of pyAFQ is being recorded using Google Analytics. "
        "For more information, please refer to the pyAFQ documentation: "
        "https://yeatmanlab.github.io/pyAFQ/usage.html#usage-tracking-with-google-analytics. "  #noqa
        "To turn this off, use the `--notrack` flag when using the pyAFQ CLI")
    logger.info(tracking_notice)

    import popylar
    popylar.track_event(AFQ._ga_id, "pyAFQ_cli", "CLI called")

config = toml.load(opts.config)
# Delete empty values that are intended to be "None" so that the parser can
# substitute default values in appropriately.
# The empty string is what the ``none_val`` placeholder ('') denotes once
# parsed, so values are compared against it directly rather than through
# eval(). Top-level entries that are not tables are left untouched instead
# of failing on a missing ``.items()``.
for component, section in list(config.items()):
    if not isinstance(section, dict):
        continue
    config[component] = {k: v for k, v in section.items() if v != ""}

files = config.get("files", {})

# The location of the preprocessed data is the only setting without a
# default.
dmriprep_path = files.get("dmriprep_path")

if dmriprep_path is None:
    raise RuntimeError("Config file must provide dmriprep_path")

# Record provenance of this run next to the results. The key promises UTC,
# so the current time is requested in UTC explicitly (a bare
# ``datetime.now()`` returns naive local time); the ISO string therefore
# carries a ``+00:00`` offset.
config['pyAFQ'] = {
    'utc_time_started':
        datetime.datetime.now(datetime.timezone.utc).isoformat('T'),
    'version': AFQ.__version__,
    'platform': platform.system(),
}

# The 'afq' folder is created in the parent directory of dmriprep_path.
# The path is normalized first because os.path.split() on a path ending in
# a separator returns the path itself as the head, which would place 'afq'
# inside the dmriprep folder instead of beside it.
afq_path = op.join(op.split(op.normpath(dmriprep_path))[0], 'afq')
os.makedirs(afq_path, exist_ok=True)

afq_metadata_file = op.join(afq_path, 'afq_metadata.toml')

# Written once now, so that the start of the run is on record even if the
# processing below does not complete.
with open(afq_metadata_file, 'w') as ff:
    toml.dump(config, ff)

def _section_params(section, defaults):
    """Return `defaults` with every value that `section` sets overridden."""
    return {name: section.get(name, fallback)
            for name, fallback in defaults.items()}


# [files]: options read from the "files" table, each with a fallback value.
sub_prefix = files.get("sub_prefix", "sub")
dwi_folder = files.get("dwi_folder", "dwi")
dwi_file = files.get("dwi_file", "*dwi")
anat_folder = files.get("anat_folder", "anat")
anat_file = files.get("anat_file", "*T1w*")
seg_file = files.get("seg_file", "*aparc+aseg*")
reg_template = files.get("reg_template", None)

# [tracking]: only options that appear in the tracking defaults are kept;
# wm_labels is read separately with its own fallback.
tracking = config.get("tracking", {})
tracking_params = _section_params(tracking, track_defaults)
wm_labels = tracking.get(
    "wm_labels",
    [250, 251, 252, 253, 254, 255, 41, 2, 16, 77])

# [segmentation]: b0_threshold is additionally pulled out as its own value.
segmentation = config.get("segmentation", {})
segmentation_params = _section_params(segmentation, seg_defaults)
b0_threshold = segmentation_params.get("b0_threshold", 0)

# [bundles]
# NOTE(review): scalars_model and scalars are read here but are not among
# the arguments given to api.AFQ later in this script - confirm whether
# they should be forwarded.
bundles = config.get("bundles", {})
bundle_names = bundles.get("bundles", api.BUNDLES)
scalars_model = bundles.get("scalars_model", "DTI")
scalars = bundles.get("scalars", ["fa", "md"])

# [cleaning]
cleaning = config.get("cleaning", {})
clean_params = _section_params(cleaning, clean_defaults)

# [compute]: all switches are off unless the configuration enables them.
compute = config.get("compute", {})
dask_it = compute.get("dask_it", False)
force_recompute = compute.get("force_recompute", False)
virtual_frame_buffer = compute.get("virtual_frame_buffer", False)

# Gather everything that was read from the configuration and hand it to the
# AFQ object as keyword arguments. b0_threshold is converted to float on
# the way in.
afq_kwargs = dict(
    sub_prefix=sub_prefix,
    dwi_folder=dwi_folder,
    dwi_file=dwi_file,
    anat_folder=anat_folder,
    anat_file=anat_file,
    seg_file=seg_file,
    b0_threshold=float(b0_threshold),
    bundle_names=bundle_names,
    dask_it=dask_it,
    force_recompute=force_recompute,
    reg_template=reg_template,
    wm_labels=wm_labels,
    virtual_frame_buffer=virtual_frame_buffer,
    segmentation_params=segmentation_params,
    tracking_params=tracking_params,
    clean_params=clean_params,
)
myafq = api.AFQ(dmriprep_path, **afq_kwargs)


# Do all the things. The calls are made strictly in this order.
myafq.set_dti_cfa()
myafq.set_dti_pdd()
myafq.set_template_xform()

# The ROI and GIF exports are only run for the "afq" segmentation
# algorithm; the name is compared case-insensitively.
seg_algo_name = segmentation_params["seg_algo"].lower()
if seg_algo_name == "afq":
    myafq.export_rois()
    myafq.export_ROI_gifs()
    myafq.export_bundle_gif()

myafq.export_bundles()
myafq.combine_profiles()

# If you got this far, you can report on time ended and record that.
# The key promises UTC, so the current time is requested in UTC explicitly
# (a bare ``datetime.now()`` returns naive local time); the ISO string
# therefore carries a ``+00:00`` offset.
config['pyAFQ']['utc_time_ended'] = datetime.datetime.now(
    datetime.timezone.utc).isoformat('T')

# Rewrite the metadata file so that it now includes the end time as well.
with open(afq_metadata_file, 'w') as ff:
    toml.dump(config, ff)
