"""Unit tests for all time series regressors."""

__author__ = ["mloning", "TonyBagnall", "fkiraly"]


import numpy as np
import pandas as pd

from sktime.datatypes import check_is_mtype
from sktime.tests.test_all_estimators import BaseFixtureGenerator, QuickTester

# Each entry is a ``(failure message, predicate)`` pair. The predicate receives
# the tag dictionary of a dataset and returns True when the constraint holds;
# every constraint is an implication "if <condition> then <requirement>".
_tag_constraints = [
    (
        "n_dimensions should be equal to one if 'is_univariate' is True",
        lambda tags: not tags["is_univariate"] or tags["n_dimensions"] == 1,
    ),
    (
        "n_panels should be equal to one if 'is_one_panel' is True",
        lambda tags: not tags["is_one_panel"] or tags["n_panels"] == 1,
    ),
    (
        "hierarchical datasets should have n_panels greater than one",
        lambda tags: not (tags["n_hierarchy_levels"] > 0) or tags["n_panels"] > 1,
    ),
]


class ForecastingDatasetFixtureGenerator(BaseFixtureGenerator):
    """Fixture generator for forecasting dataset tests.

    Fixtures parameterized
    ----------------------
    estimator_class: estimator inheriting from BaseObject
        ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
    estimator_instance: instance of estimator inheriting from BaseObject
        ranges over estimator classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
        instances are generated by create_test_instance class method
    scenario: instance of TestScenario
        ranges over all scenarios returned by retrieve_scenarios
    """

    # note: this should be separate from TestAllForecastingDatasets
    #   additional fixtures, parameters, etc should be added here
    #   TestAllForecastingDatasets should contain the tests only

    # presumably consumed by BaseFixtureGenerator to restrict the generated
    # fixtures to objects of type "dataset_forecasting" - confirm in the base
    estimator_type_filter = "dataset_forecasting"


class TestAllForecastingDatasets(ForecastingDatasetFixtureGenerator, QuickTester):
    """Module level tests for all sktime forecasting datasets.

    Each test loads the dataset's target via ``estimator_instance.load("y")``
    and checks that a property derived from the data agrees with the
    corresponding tag returned by ``get_tag``/``get_tags``.
    """

    def test_tag_is_one_series(self, estimator_instance):
        """Check that the ``is_one_series`` tag matches the loaded ``y``.

        A pandas ``y`` is a single series if its index has one level, or if all
        rows share the same value on every index level except the last one.
        A numpy ``y`` of at most two dimensions is a single series (rows are
        time points, columns are variables). A 3D array is treated as
        (instances, variables, time) and is a single series only if it holds
        exactly one instance.

        Raises
        ------
        ValueError
            If ``y`` is neither a pandas object nor a numpy array.
        """
        expected = estimator_instance.get_tag("is_one_series")
        y = estimator_instance.load("y")

        if isinstance(y, (pd.DataFrame, pd.Series)):
            index = y.index
            is_one_series = index.nlevels == 1 or index.droplevel(-1).nunique() == 1
        elif isinstance(y, np.ndarray):
            # columns of a 2D array are variables, not separate series - this
            # matches the pandas branch (a multi-column frame is one series)
            # and test_tag_is_univariate, which reads shape[1] as variables
            is_one_series = y.ndim <= 2 or y.shape[0] == 1
        else:
            raise ValueError(f'Unexpected type "{type(y)}" for y')
        assert is_one_series == expected

    def test_tag_n_panels(self, estimator_instance):
        """Check that the ``n_panels`` tag matches the loaded ``y``.

        Hierarchical ``y`` (mtype ``pd_multiindex_hier``) counts the distinct
        index values left after dropping the last index level. Any other ``y``,
        including a ``pd-multiindex`` panel, counts as a single panel.
        """
        expected = estimator_instance.get_tag("n_panels")
        y = estimator_instance.load("y")

        if check_is_mtype(y, "pd_multiindex_hier"):
            # NOTE(review): dropping only the last (time) level counts the
            # individual series rather than the panels above them; confirm
            # whether ``droplevel([-1, -2])`` was intended here.
            n_panels = len(y.index.droplevel(-1).drop_duplicates())
        else:
            n_panels = 1

        assert n_panels == expected

    def test_tag_n_hierarchy_levels(self, estimator_instance):
        """Check that the ``n_hierarchy_levels`` tag matches the loaded ``y``.

        Only hierarchical ``y`` (mtype ``pd_multiindex_hier``) has hierarchy
        levels, counted as the number of index levels minus the last one; all
        other ``y`` (series, panels, numpy arrays) have zero.
        """
        expected = estimator_instance.get_tag("n_hierarchy_levels")
        y = estimator_instance.load("y")

        # ``y.index`` is only touched for hierarchical data, so non-pandas
        # inputs fall through to 0 instead of raising AttributeError
        if check_is_mtype(y, "pd_multiindex_hier"):
            n_hierarchy_levels = y.index.nlevels - 1
        else:
            n_hierarchy_levels = 0

        assert n_hierarchy_levels == expected

    def test_tag_constraints(self, estimator_instance):
        """Check the dataset's tags against every rule in ``_tag_constraints``."""
        tags = estimator_instance.get_tags()
        for message, holds in _tag_constraints:
            assert holds(tags), message

    def test_tag_is_univariate(self, estimator_instance):
        """Check that the ``is_univariate`` tag matches the number of columns.

        A ``pd.Series`` or a 1D numpy array has one column; otherwise the
        number of columns is ``y.shape[1]``. Univariate datasets must have
        exactly one column, multivariate datasets more than one.
        """
        is_univariate = estimator_instance.get_tag("is_univariate")
        y = estimator_instance.load("y")

        # a 1D array has no second axis, so shape[1] would raise IndexError
        if isinstance(y, pd.Series) or (isinstance(y, np.ndarray) and y.ndim == 1):
            n_columns = 1
        else:
            n_columns = y.shape[1]

        if is_univariate:
            assert n_columns == 1
        else:
            assert n_columns > 1
