"""A collection of tests for tenpy.linalg.charges"""
# Copyright 2018 TeNPy Developers

import tenpy.linalg.charges as charges
import numpy as np
import numpy.testing as npt
import nose.tools as nst
import itertools as it
from random_test import gen_random_legcharge

# Reference charge data shared by the tests: an unsorted variant (suffix *_us) and its
# sorted counterpart (suffix *_s). Everything uses a single charge, hence the column shape.

# Unsorted flat charges, spelled out as runs of equal values:
# value -6 three times, then -4 three times, then 4 twice, ... (61 entries in total).
qflat_us = np.repeat(
    [-6, -4, 4, -4, -2, 0, -2, 0, 2, 4, 6],  # charge of each run
    [3, 3, 2, 4, 11, 11, 3, 4, 14, 4, 2],  # length of each run
).reshape((-1, 1))
slices_us = np.array([0, 3, 6, 8, 12, 23, 34, 37, 41, 55, 59, 61])
charges_us = np.array([[-6], [-4], [4], [-4], [-2], [0], [-2], [0], [2], [4], [6]])

# Sorted variant: equal charges become contiguous blocks.
qflat_s = np.sort(qflat_us, axis=0)
slices_s = np.array([0, 3, 10, 24, 39, 53, 59, 61])
charges_s = np.array([[-6], [-4], [-2], [0], [2], [4], [6]])

# Same sorted data as a {charge tuple: slice} mapping, with plain Python ints as keys and bounds.
qdict_s = {
    (int(q[0]), ): slice(int(start), int(stop))
    for q, start, stop in zip(charges_s, slices_s[:-1], slices_s[1:])
}

# ChargeInfo with a single charge, shared by the LegCharge / LegPipe tests below.
# NOTE(review): the argument [1] is presumably the modulus of that charge (1 meaning an
# unbounded U(1)-like charge in tenpy) -- confirm against ChargeInfo's documentation.
ch_1 = charges.ChargeInfo([1])


def test_ChargeInfo():
    """Test construction, validity checks and equality of `ChargeInfo`."""
    # A ChargeInfo without any charges.
    trivial = charges.ChargeInfo()
    trivial.test_sanity()
    print("trivial: ", trivial)
    nst.eq_(trivial.qnumber, 0)

    # Two charges, named 'some' and ''.
    chinfo = charges.ChargeInfo([3, 1], ['some', ''])
    print("nontrivial chinfo: ", chinfo)
    nst.eq_(chinfo.qnumber, 2)

    # (charge vector, expected result of check_valid)
    cases = [([0, 2], True), ([2, 0], True), ([5, 3], False), ([-2, -3], False)]
    for qvec, expected in cases:
        print(qvec, expected)
        result = chinfo.check_valid(np.array([qvec]))
        print(result)
        nst.eq_(result, expected)

    # make_valid applied row by row must agree with make_valid on all rows at once.
    raw = [qvec for qvec, _ in cases]
    row_by_row = np.array([chinfo.make_valid(qvec) for qvec in raw])
    npt.assert_equal(row_by_row, chinfo.make_valid(raw))

    # Equality: identical arguments compare equal, a different charge name does not.
    same = charges.ChargeInfo([3, 1], ['some', ''])
    assert same == chinfo
    renamed = charges.ChargeInfo([3, 1], ['other', ''])
    assert renamed != chinfo


def test__find_row_differences():
    """Compare `_find_row_differences` with a brute-force scan over the rows."""
    for qflat in (qflat_us, qflat_s):
        diff = charges._find_row_differences(qflat)
        # Reference: 0, every row index whose row differs from the previous row, and the length.
        expected = [0]
        for row in range(1, len(qflat)):
            if np.any(qflat[row - 1] != qflat[row]):
                expected.append(row)
        expected.append(len(qflat))
        npt.assert_equal(diff, expected)


def test_LegCharge():
    """Test construction, conversion, sorting/bunching and indexing of `LegCharge`."""
    # Construction from flat charges, for sorted and unsorted input.
    lcs = charges.LegCharge.from_qflat(ch_1, qflat_s).bunch()[1]
    npt.assert_equal(lcs.charges, charges_s)  # check from_qflat
    npt.assert_equal(lcs.slices, slices_s)  # check from_qflat
    npt.assert_equal(lcs.to_qflat(), qflat_s)  # check to_qflat
    lcus = charges.LegCharge.from_qflat(ch_1, qflat_us).bunch()[1]
    npt.assert_equal(lcus.charges, charges_us)  # check from_qflat
    npt.assert_equal(lcus.slices, slices_us)  # check from_qflat
    npt.assert_equal(lcus.to_qflat(), qflat_us)  # check to_qflat

    # Construction from / conversion to a {charges: slice} dictionary.
    lc = charges.LegCharge.from_qdict(ch_1, qdict_s)
    npt.assert_equal(lc.charges, charges_s)  # check from_qdict
    npt.assert_equal(lc.slices, slices_s)  # check from_qdict
    npt.assert_equal(lc.to_qdict(), qdict_s)  # check to_qdict
    nst.eq_(lcs.is_sorted(), True)
    nst.eq_(lcs.is_blocked(), True)
    nst.eq_(lcus.is_sorted(), False)
    nst.eq_(lcus.is_blocked(), False)

    # test sort & bunch
    lcus_charges = lcus.charges.copy()
    pqind, lcus_s = lcus.sort(bunch=False)
    lcus_s.test_sanity()
    npt.assert_equal(lcus_charges, lcus.charges)  # don't change the old instance
    npt.assert_equal(lcus_s.charges, lcus.charges[pqind])  # permutation returned by sort ok?
    nst.eq_(lcus_s.is_sorted(), True)
    nst.eq_(lcus_s.is_bunched(), False)
    nst.eq_(lcus_s.is_blocked(), False)
    nst.eq_(lcus_s.ind_len, lcus.ind_len)
    nst.eq_(lcus_s.block_number, lcus.block_number)
    _, lcus_sb = lcus_s.bunch()
    lcus_sb.test_sanity()
    # Sorting + bunching the unsorted leg must reproduce exactly the sorted reference data:
    # the unsorted blocks of equal charge are merged into one block each.
    npt.assert_equal(lcus_sb.charges, charges_s)
    npt.assert_equal(lcus_sb.slices, slices_s)
    # Reset the cached flag; presumably is_sorted()/is_blocked() short-circuit on it,
    # so this makes them actually run the check.
    lcus_sb.sorted = False
    nst.eq_(lcus_sb.is_sorted(), True)
    nst.eq_(lcus_sb.is_bunched(), True)
    nst.eq_(lcus_sb.is_blocked(), True)
    nst.eq_(lcus_sb.ind_len, lcus.ind_len)

    # test get_qindex: each flat index lies in the returned block at the returned offset.
    for i in range(lcs.ind_len):
        qidx, idx_in_block = lcs.get_qindex(i)
        assert (lcs.slices[qidx] <= i < lcs.slices[qidx + 1])
        assert (lcs.slices[qidx] + idx_in_block == i)


def test_LegPipe():
    """Test `LegPipe` for every sort/bunch combination, on legs from `gen_random_legcharge`."""
    shape = (20, 10, 8)
    legs = [gen_random_legcharge(ch_1, s) for s in shape]
    for sort, bunch in it.product([True, False], repeat=2):
        pipe = charges.LegPipe(legs, sort=sort, bunch=bunch)
        pipe.test_sanity()
        assert pipe.ind_len == np.prod(shape)
        print(pipe.q_map)
        # _map_incoming_qind: feed all possible incoming qindices (columns 3: of q_map),
        # shuffled so that the order is non-trivial, and check that each maps back to its row.
        incoming = pipe.q_map[:, 3:].copy()
        np.random.shuffle(incoming)
        rows = pipe._map_incoming_qind(incoming)
        for i, qinds in enumerate(incoming):
            row = rows[i]
            npt.assert_equal(pipe.q_map[row, 3:], qinds)
            # The pipe block size (columns 0 and 1 of q_map) equals the product of the
            # incoming block sizes.
            block_sizes = [leg.slices[j + 1] - leg.slices[j] for leg, j in zip(legs, qinds)]
            nst.eq_(np.prod(block_sizes), pipe.q_map[row, 1] - pipe.q_map[row, 0])
        # pipe.map_incoming_flat is tested by test_np_conserved.
        # pipe.map_incoming_flat is tested by test_np_conserved.


if __name__ == "__main__":
    # Allow running the tests as a plain script, in definition order, without a test runner.
    for _test in (test_ChargeInfo, test__find_row_differences, test_LegCharge, test_LegPipe):
        _test()
