"""Unit tests for all sktime datasets."""

__author__ = ["mloning", "TonyBagnall", "fkiraly"]


import types
from contextlib import nullcontext

import numpy as np
import pandas as pd
import pytest

from sktime.datasets.base._base import InvalidSetError
from sktime.datatypes._check import check_is_mtype
from sktime.tests.test_all_estimators import BaseFixtureGenerator, QuickTester


def is_generator(obj):
    """Check whether ``obj`` is a generator object.

    Parameters
    ----------
    obj : object
        Any object, e.g., one of the outputs of a dataset's ``load`` method.

    Returns
    -------
    bool
        True if ``obj`` is an instance of ``types.GeneratorType``, else False.
        Generator *functions* and other iterators (e.g. ``iter(list)``) are
        not generator objects and yield False.
    """
    return isinstance(obj, (types.GeneratorType,))


class DatasetFixtureGenerator(BaseFixtureGenerator):
    """Fixture generator for dataset tests.

    Fixtures parameterized
    ----------------------
    estimator_class: dataset class inheriting from BaseObject
        ranges over dataset classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
    estimator_instance: instance of dataset class inheriting from BaseObject
        ranges over dataset classes not excluded by EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
        instances are generated by create_test_instance class method
    scenario: instance of TestScenario
        ranges over all scenarios returned by retrieve_scenarios
    """

    # note: this should be separate from TestAllDatasets
    #   additional fixtures, parameters, etc should be added here
    #   TestAllDatasets should contain the tests only

    # presumably used by BaseFixtureGenerator to restrict the estimator
    # fixtures to objects of type "dataset" - confirm in the base class
    estimator_type_filter = "dataset"


class TestAllDatasets(DatasetFixtureGenerator, QuickTester):
    """Module level tests for all sktime datasets."""

    @staticmethod
    def _count_instances(y):
        """Return the number of instances contained in ``y``.

        The count is determined by the mtype of ``y``:

        * ``pd-multiindex``: number of distinct values occurring in the
          outermost index level
        * ``pd.Series`` / ``pd.DataFrame``: 1 (a single series)
        * ``pd_multiindex_hier``: number of distinct index entries after
          dropping the last index level (presumably the time level)
        * anything else: ``len(y)``
        """
        if check_is_mtype(y, "pd-multiindex"):
            # Count values that actually occur instead of ``index.levels[0]``:
            # pandas does not prune unused levels when a frame is subset
            # (e.g. a train/test split), so ``levels`` can hold absent ids.
            return y.index.get_level_values(0).nunique()
        if check_is_mtype(y, "pd.Series") or check_is_mtype(y, "pd.DataFrame"):
            return 1
        if check_is_mtype(y, "pd_multiindex_hier"):
            return y.index.droplevel(-1).nunique()
        return len(y)

    def _check_n_instances_of_set(self, estimator_instance, key, tag):
        """Check that the instance count of the set ``key`` matches tag ``tag``.

        If the tag value is positive, ``load(key)`` must return an object with
        exactly that many instances. Otherwise the dataset has no such set,
        and ``load(key)`` must raise ``InvalidSetError``.
        """
        n_expected = estimator_instance.get_tag(tag)

        contextwrapper = (
            nullcontext() if n_expected > 0 else pytest.raises(InvalidSetError)
        )
        with contextwrapper:
            y_set = estimator_instance.load(key)
            assert self._count_instances(y_set) == n_expected

    def test_load_output_type(self, estimator_instance):
        """Verify that every available key loads an object of the expected type.

        Keys other than ``"cv"`` must load a pandas DataFrame/Series, a numpy
        array, or None; the ``"cv"`` key must load a generator.
        """
        available_keys = list(estimator_instance.keys())
        assert len(available_keys) > 0

        outputs = estimator_instance.load(*available_keys)
        # ``load`` with a single key returns that object itself rather than a
        # 1-tuple (see ``load("y")`` in the tests below), so wrap it here to
        # avoid indexing into the loaded DataFrame/Series.
        if len(available_keys) == 1:
            outputs = (outputs,)
        assert len(outputs) == len(available_keys)

        for key, output in zip(available_keys, outputs):
            if key == "cv":
                assert is_generator(output), f"load('cv') returned {type(output)}"
            else:
                assert isinstance(
                    output, (pd.DataFrame, pd.Series, np.ndarray, type(None))
                ), f"load({key!r}) returned unexpected type {type(output)}"

    def test_tag_n_instances(self, estimator_instance):
        """Check that the ``n_instances`` tag matches the instances in ``y``."""
        n_instances = estimator_instance.get_tag("n_instances")
        y = estimator_instance.load("y")
        assert self._count_instances(y) == n_instances

    def test_tag_n_instances_train(self, estimator_instance):
        """Check the number of instances in the training set.

        If the dataset has no training set (``n_instances_train`` is not
        positive), check that loading ``y_train`` raises ``InvalidSetError``.
        """
        self._check_n_instances_of_set(
            estimator_instance, "y_train", "n_instances_train"
        )

    def test_tag_n_instances_test(self, estimator_instance):
        """Check the number of instances in the test set.

        If the dataset has no test set (``n_instances_test`` is not positive),
        check that loading ``y_test`` raises ``InvalidSetError``.
        """
        self._check_n_instances_of_set(
            estimator_instance, "y_test", "n_instances_test"
        )

    def test_tag_n_splits(self, estimator_instance):
        """Check the number of splits yielded by ``cv``.

        If the ``n_splits`` tag is greater than 1, the ``cv`` generator must
        yield exactly that many splits; otherwise loading ``cv`` must raise
        ``InvalidSetError``.
        """
        n_splits = estimator_instance.get_tag("n_splits")
        if n_splits > 1:
            cv = estimator_instance.load("cv")
            assert len(list(cv)) == n_splits
        else:
            with pytest.raises(InvalidSetError):
                estimator_instance.load("cv")
