"""This module contains a base class for a Matrix Product State (MPS).

An MPS looks roughly like this::

    |   -- B[0] -- B[1] -- B[2] -- ...
    |       |       |      |

We use the following label convention for the `B` (where arrows indicate `qconj`)::

    |  vL ->- B ->- vR
    |         |
    |         ^
    |         p

We store one 3-leg tensor `_B[i]` with labels ``'vL', 'vR', 'p'`` for each of the `L` sites
``0 <= i < L``.
Additionally, we store ``L+1`` singular value arrays `_S[ib]`, one on each bond ``0 <= ib <= L``,
independent of the boundary conditions.
``_S[ib]`` gives the singular values on the bond between sites ``ib-1`` and ``ib``.
Note, however, that e.g. :attr:`MPS.chi` only returns the dimensions of the
:attr:`MPS.nontrivial_bonds`, which depend on the boundary conditions.

We restrict ourselves to normalized states (i.e. ``np.linalg.norm(psi._S[ib]) == 1`` up to
roundoff errors).
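
As an illustration of these conventions, the following minimal sketch uses plain NumPy arrays
instead of :class:`~tenpy.linalg.np_conserved.Array` and ignores charges. It assumes, only for
this sketch, that the legs of each `B` are ordered as ``(vL, p, vR)``; the actual `_B` have
labeled legs in arbitrary order. It stores the product state ``|up, down, up>`` and contracts
it to the full wave function::

    import numpy as np

    L, d = 3, 2
    p_state = [0, 1, 0]  # example product state: up, down, up
    Bs = []
    for i in range(L):
        B = np.zeros((1, d, 1))  # legs ordered (vL, p, vR) in this sketch
        B[0, p_state[i], 0] = 1.
        Bs.append(B)
    Ss = [np.ones(1) for ib in range(L + 1)]  # one array per bond 0 <= ib <= L

    # contract the chain; the result has legs (vL, p0, p1, p2, vR)
    psi = Bs[0]
    for B in Bs[1:]:
        psi = np.tensordot(psi, B, axes=(-1, 0))
    psi = psi.reshape([d] * L)  # drop the trivial outer bonds
    norms = [np.linalg.norm(S) for S in Ss]  # all equal to 1 for a normalized state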

For efficient simulations, it is crucial that the MPS is in a 'canonical form'.
The different forms and boundary conditions are most easily described in Vidal's
:math:`\Gamma, \Lambda` notation [1]_.

Valid MPS boundary conditions (not to be confused with `bc_coupling` of
:class:`tenpy.models.model.CouplingModel`) are the following:

==========  ===================================================================================
`bc`        description
==========  ===================================================================================
'finite'    Finite MPS, ``G0 s1 G1 ... s{L-1} G{L-1}``. This is achieved
            by using a trivial left and right bond ``s[0] = s[-1] = np.array([1.])``.
'segment'   Generalization of 'finite', describes an MPS embedded in left and right
            environments. The left environment is described by ``chi[0]`` *orthonormal* states
            which are weighted by the singular values ``s[0]``. Similarly, ``s[L]`` weights some
            right orthonormal states. You can think of the left and right states as being
            generated by additional MPS, such that the overall structure is something like
            ``... s L s L [s0 G0 s1 G1 ... s{L-1} G{L-1} s{L}] R s R s R ...``
            (where we save the part in the brackets ``[ ... ]``).
'infinite'  Infinite MPS (iMPS): we save an 'MPS unit cell' ``[s0 G0 s1 G1 ... s{L-1} G{L-1}]``
            which is repeated periodically, identifying all indices modulo ``self.L``.
            In particular, the last bond ``L`` is identified with ``0``.
            (The MPS unit cell can differ from a lattice unit cell.)
==========  ===================================================================================
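
The following plain-Python sketch lists, for each `bc`, which of the ``L+1`` bonds carry
nontrivial singular values. It mirrors the logic of :attr:`MPS.nontrivial_bonds`, but the
function name is hypothetical::

    def nontrivial_bond_indices(L, bc):
        # indices ib of the arrays _S[ib] which are not fixed by the boundary conditions
        if bc == 'finite':
            return list(range(1, L))  # _S[0] = _S[L] = [1.] are trivial
        elif bc == 'segment':
            return list(range(0, L + 1))  # outer bonds weight the environment states
        elif bc == 'infinite':
            return list(range(0, L))  # bond L is identified with bond 0
        raise ValueError('invalid boundary condition: ' + repr(bc))

For ``L = 4`` this gives ``[1, 2, 3]`` for 'finite', ``[0, 1, 2, 3, 4]`` for 'segment'
and ``[0, 1, 2, 3]`` for 'infinite'.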

An MPS can be in different 'canonical forms' (see [Vidal2004]_, [Schollwoeck2011]_).
To take care of the different canonical forms, algorithms should use methods like
:meth:`get_theta`, :meth:`get_B` and :meth:`set_B` instead of accessing `_B` directly,
as they return the `B` in the desired form (which can be chosen as an argument).

======== ========== =======================================================================
`form`   tuple      description
======== ========== =======================================================================
``'B'``  (0, 1)     right canonical: ``_B[i] = -- Gamma[i] -- s[i+1]--``
                    The default form, which algorithms assume.
``'C'``  (0.5, 0.5) symmetric form: ``_B[i] = -- s[i]**0.5 -- Gamma[i] -- s[i+1]**0.5--``
``'A'``  (1, 0)     left canonical: ``_B[i] = -- s[i] -- Gamma[i] --``.
                    For stability reasons, we recommend *not* using this form.
``'G'``  (0, 0)     Save only ``_B[i] = -- Gamma[i] --``.
                    For stability reasons, we recommend *not* using this form.
``None`` ``None``   General non-canonical form.
                    Valid form for initialization, but you need to call
                    :meth:`canonicalize` (or sub-functions) before using algorithms.
======== ========== =======================================================================
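
Since the forms only differ by powers of the singular values on the virtual legs, converting
between them amounts to rescaling these legs. A minimal NumPy sketch, assuming legs ordered
``(vL, p, vR)`` and a hypothetical helper name (the class provides :meth:`MPS.convert_form`
for this purpose)::

    import numpy as np

    def rescale_form(B, SL, SR, old_form, new_form):
        # B = SL**old_form[0] -- Gamma -- SR**old_form[1]; return the same site in `new_form`
        nuL = new_form[0] - old_form[0]
        nuR = new_form[1] - old_form[1]
        return SL[:, None, None]**nuL * B * SR[None, None, :]**nuR

    rng = np.random.default_rng(0)
    Gamma = rng.normal(size=(3, 2, 4))
    SL = np.array([0.8, 0.5, 0.33])  # example singular values left of the site
    SR = np.array([0.7, 0.6, 0.3, 0.24])  # example singular values right of the site
    B = Gamma * SR  # 'B' form (0, 1)
    A = rescale_form(B, SL, SR, (0., 1.), (1., 0.))  # 'A' form (1, 0)
    C = rescale_form(B, SL, SR, (0., 1.), (0.5, 0.5))  # 'C' form

Negative exponents divide by singular values, which is numerically delicate if some of them
are tiny.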

.. todo ::

    - expectation values
    - canonicalize()
    - much much more ....
    - proper documentation
"""

from __future__ import division
import numpy as np
import itertools
import warnings
import scipy.sparse as sparse
import scipy.sparse.linalg.eigen.arpack

from ..linalg import np_conserved as npc
from ..tools.misc import to_iterable, argsort
from ..tools.math import lcm, speigs


class MPS(object):
    r"""A Matrix Product State, finite (MPS) or infinite (iMPS).

    Parameters
    ----------
    sites : list of :class:`~tenpy.networks.site.Site`
        Defines the local Hilbert space for each site.
    Bs : list of :class:`~tenpy.linalg.np_conserved.Array`
        The 'matrices' of the MPS. Labels are ``vL, vR, p`` (in any order).
    SVs : list of 1D array
        The singular values on *each* bond. Should always have length `L+1`.
        Entries out of :attr:`nontrivial_bonds` are ignored.
    bc : ``'finite' | 'segment' | 'infinite'``
        Boundary conditions as described in the table of the module doc-string.
    form : (list of) {``'B' | 'A' | 'C' | 'G' | None`` | tuple(float, float)}
        The form of the stored 'matrices', see the table in the module doc-string.
        A single choice holds for all of the entries.

    Attributes
    ----------
    L
    chi
    finite
    nontrivial_bonds
    sites : list of :class:`~tenpy.networks.site.Site`
        Defines the local Hilbert space for each site.
    bc : {'finite', 'segment', 'infinite'}
        Boundary conditions as described in the table of the module doc-string.
    form : list of {``None`` | tuple(float, float)}
        Describes the canonical form on each site.
        ``None`` means non-canonical form.
        For ``form = (nuL, nuR)``, the stored ``_B[i]`` are
        ``s**form[0] -- Gamma -- s**form[1]`` (in Vidal's notation).
    chinfo : :class:`~tenpy.linalg.np_conserved.ChargeInfo`
        The nature of the charge.
    dtype : type
        The data type of the `_B`.
    _B : list of :class:`npc.Array`
        The 'matrices' of the MPS. Labels are ``vL, vR, p`` (in any order).
        We recommend using :meth:`get_B` and :meth:`set_B`, which will take care of the different
        canonical forms.
    _S : None | list of 1D arrays
        The singular values on each virtual bond, length ``L+1``.
        May be ``None`` if the MPS is not in canonical form.
        Otherwise, ``_S[i]`` is to the left of ``_B[i]``.
        We recommend using :meth:`get_SL`, :meth:`get_SR`, :meth:`set_SL`, :meth:`set_SR`, which
        take proper care of the boundary conditions.
    _valid_forms : dict
        Mapping for canonical forms to a tuple ``(nuL, nuR)`` indicating that
        ``self._B[i] = s[i]**nuL -- Gamma[i] -- s[i+1]**nuR`` is saved.
    _valid_bc : tuple of str
        Valid boundary conditions.
    """

    # Canonical form conventions: the saved B = s**nu[0]--Gamma--s**nu[1].
    # For the canonical forms, ``nu[0] + nu[1] = 1``
    _valid_forms = {
        'A': (1., 0.),
        'C': (0.5, 0.5),
        'B': (0., 1.),
        'G': (0., 0.),  # like Vidal's `Gamma`.
        None: None,  # means 'not in any canonical form'
    }

    # valid boundary conditions. Don't overwrite this!
    _valid_bc = ('finite', 'segment', 'infinite')

    def __init__(self, sites, Bs, SVs, bc='finite', form='B'):
        self.sites = list(sites)
        self.chinfo = self.sites[0].leg.chinfo
        self.dtype = dtype = np.result_type(*[B.dtype for B in Bs])
        self.form = self._parse_form(form)
        self.bc = bc  # one of ``'finite', 'segment', 'infinite'``.

        # make copies of Bs and SVs
        self._B = [B.astype(dtype, copy=True) for B in Bs]
        self._S = [None] * (self.L + 1)
        for i in range(self.L + 1)[self.nontrivial_bonds]:
            self._S[i] = np.array(SVs[i], dtype=float)
        if self.bc == 'infinite':
            self._S[-1] = self._S[0]
        elif self.bc == 'finite':
            self._S[0] = self._S[-1] = np.ones([1])
        self.test_sanity()

    def test_sanity(self):
        """Sanity check. Raises Errors if something is wrong."""
        if self.bc not in self._valid_bc:
            raise ValueError("invalid boundary condition: " + repr(self.bc))
        if len(self._B) != self.L:
            raise ValueError("wrong len of self._B")
        if len(self._S) != self.L + 1:
            raise ValueError("wrong len of self._S")
        for i, B in enumerate(self._B):
            if not set(['vL', 'vR', 'p']) <= set(B.get_leg_labels()):
                raise ValueError("B has wrong labels " + repr(B.get_leg_labels()))
            B.test_sanity()  # recursive...
            if self._S[i].shape[-1] != B.get_leg('vL').ind_len or \
                    self._S[i+1].shape[0] != B.get_leg('vR').ind_len:
                raise ValueError("shape of B incompatible with len of singular values")
            if not self.finite or i + 1 < self.L:
                B2 = self._B[(i + 1) % self.L]
                B.get_leg('vR').test_contractible(B2.get_leg('vL'))
        if self.bc == 'finite':
            if len(self._S[0]) != 1 or len(self._S[-1]) != 1:
                raise ValueError("non-trivial outer bonds for finite MPS")
        elif self.bc == 'infinite':
            if np.any(self._S[self.L] != self._S[0]):
                raise ValueError("iMPS with S[0] != S[L]")
        assert len(self.form) == self.L
        for f in self.form:
            if f is not None:
                assert isinstance(f, tuple)
                assert len(f) == 2

    @classmethod
    def from_product_state(cls,
                           sites,
                           p_state,
                           bc='finite',
                           dtype=float,
                           form='B',
                           chargeL=None):
        """Construct a matrix product state from a given product state.

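        Ignoring charges, the construction roughly amounts to the following NumPy sketch, where
        `product_state_Bs` is a hypothetical name. The actual method additionally determines the
        :class:`~tenpy.linalg.np_conserved.LegCharge` of the virtual legs::

            import numpy as np

            def product_state_Bs(p_state, d, dtype=float):
                Bs = []
                for p in p_state:
                    if np.ndim(p) == 0:  # an int selects a basis state
                        B = np.zeros((d, 1, 1), dtype)
                        B[p, 0, 0] = 1.
                    else:  # a 1D array is the local wave function
                        if len(p) != d:
                            raise ValueError('p_state incompatible with local dim: ' + repr(p))
                        B = np.array(p, dtype).reshape((d, 1, 1))
                    Bs.append(B)  # legs ordered ('p', 'vL', 'vR')
                return Bs

            Bs = product_state_Bs([0, [0.6, 0.8]], d=2)

        All singular values are trivial, ``SVs = [[1.]] * (L + 1)``.
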
        Parameters
        ----------
        sites : list of :class:`~tenpy.networks.site.Site`
            The sites defining the local Hilbert space.
        p_state : iterable of {int | 1D array}
            Defines the product state.
            If ``p_state[i]`` is an int, then site ``i`` is in state ``p_state[i]``.
            If ``p_state[i]`` is an array, then site ``i`` has the wavefunction ``p_state[i]``.
        bc : {'infinite', 'finite', 'segment'}
            MPS boundary conditions. See docstring of :class:`MPS`.
        dtype : type or string
            The data type of the array entries.
        form : (list of) {``'B' | 'A' | 'C' | 'G' | None`` | tuple(float, float)}
            Defines the canonical form. See module doc-string.
            A single choice holds for all of the entries.
        chargeL : charges
            Bond charges at bond 0, which are purely conventional.

        """
        sites = list(sites)
        L = len(sites)
        p_state = list(p_state)
        if len(p_state) != L:
            raise ValueError("Length of p_state does not match number of sites.")
        ci = sites[0].leg.chinfo
        Bs = []
        chargeL = ci.make_valid(chargeL)  # sets to zero if `None`
        legL = npc.LegCharge.from_qflat(ci, chargeL)

        for i, site in enumerate(sites):
            try:
                iter(p_state[i])
                if len(p_state[i]) != site.dim:
                    raise ValueError("p_state incompatible with local dim:" + repr(p_state[i]))
                B = np.array(p_state[i], dtype).reshape((site.dim, 1, 1))
            except TypeError:
                B = np.zeros((site.dim, 1, 1), dtype)
                B[p_state[i], 0, 0] = 1.0
            # calculate the LegCharge of the right leg
            legs = [site.leg, legL, None]  # other legs are known
            legs = npc.detect_legcharge(B, ci, legs, None, qconj=-1)
            B = npc.Array.from_ndarray(B, legs, dtype)
            B.set_leg_labels(['p', 'vL', 'vR'])
            Bs.append(B)
            legL = legs[-1].conj()  # prepare for next `i`
        if bc == 'infinite':
            # for an iMPS, the last leg has to match the first one.
            # so we need to gauge `qtotal` of the last `B` such that the right leg matches.
            chdiff = Bs[-1].get_leg('vR').charges[0] - Bs[0].get_leg('vL').charges[0]
            Bs[-1] = Bs[-1].gauge_total_charge('vR', ci.make_valid(chdiff))
        SVs = [[1.]] * (L + 1)
        return cls(sites, Bs, SVs, form=form, bc=bc)

    @classmethod
    def from_full(cls, sites, psi, form='B', cutoff=1.e-16):
        """Construct an MPS from a single tensor `psi` with one leg per physical site.

        Performs a sequence of SVDs of `psi` to split off the `B` matrices and obtain the singular
        values; the result is in canonical form.
        This is only well-defined for `finite` boundary conditions.

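        A plain-NumPy sketch of the same right-to-left sequence of SVDs, ignoring charges. The
        function name, the leg order ``(vL, p, vR)`` and the default cutoff are assumptions of
        this sketch, not part of the method::

            import numpy as np

            def full_to_B_form(psi, cutoff=1.e-14):
                dims = psi.shape  # one leg per site, psi is normalized
                L = len(dims)
                Bs = [None] * L
                Ss = [np.ones(1) for ib in range(L + 1)]
                rest = psi.reshape((1, ) + dims)  # add a trivial leg vL
                chi_R = 1  # dimension of the bond right of site i
                for i in range(L - 1, 0, -1):
                    # split (vL.p0...p{i-1}) | (p{i}.vR)
                    M = rest.reshape(-1, dims[i] * chi_R)
                    U, S, Vh = np.linalg.svd(M, full_matrices=False)
                    keep = S > cutoff
                    U, S, Vh = U[:, keep], S[keep], Vh[keep, :]
                    S = S / np.linalg.norm(S)
                    Bs[i] = Vh.reshape(len(S), dims[i], chi_R)  # right-orthonormal
                    Ss[i] = S
                    rest = (U * S).reshape((1, ) + dims[:i] + (len(S), ))
                    chi_R = len(S)
                Bs[0] = rest.reshape(1, dims[0], chi_R)
                return Bs, Ss
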
        Parameters
        ----------
        sites : list of :class:`~tenpy.networks.site.Site`
            The sites defining the local Hilbert space.
        psi : :class:`~tenpy.linalg.np_conserved.Array`
            The full wave function to be represented as an MPS.
            Should have labels ``'p0', 'p1', ...,  'p{L-1}'``.
        form : ``'B' | 'A' | 'C' | 'G'``
            The canonical form of the resulting MPS, see module doc-string.
        cutoff : float
            Cutoff of singular values used in the SVDs.

        Returns
        -------
        psi_mps : :class:`MPS`
            MPS representation of `psi`, normalized and in canonical form.
        """
        if form not in ['B', 'A', 'C', 'G']:
            raise ValueError("Invalid form: " + repr(form))
        # perform SVDs to bring it into 'B' form, afterwards change the form.
        L = len(sites)
        assert (L >= 2)
        B_list = [None] * L
        S_list = [1] * (L + 1)
        labels = ['p' + str(i) for i in range(L)]
        psi.itranspose(labels)
        # combine legs from left
        psi = psi.add_trivial_leg(0, label='vL', qconj=+1)
        for i in range(0, L - 1):
            psi = psi.combine_legs([0, 1])  # combines the legs until `i`
        psi = psi.add_trivial_leg(2, label='vR', qconj=-1)
        # now psi has only three legs: ``'(((vL.p0).p1)...p{L-2})', 'p{L-1}', 'vR'``
        for i in range(L - 1, 0, -1):
            # split off B[i]
            psi = psi.combine_legs([labels[i], 'vR'])
            psi, S, B = npc.svd(psi, inner_labels=['vR', 'vL'], cutoff=cutoff)
            S /= np.linalg.norm(S)  # normalize
            psi.iscale_axis(S, 1)
            B_list[i] = B.split_legs(1).replace_label(labels[i], 'p')
            S_list[i] = S
            psi = psi.split_legs(0)
        psi = psi.combine_legs([labels[0], 'vR'])
        psi, S, B = npc.svd(
            psi, qtotal_LR=[None, psi.qtotal], inner_labels=['vR', 'vL'], cutoff=cutoff)
        assert (psi.shape == (1, 1))
        S_list[0] = np.ones([1], dtype=float)
        B_list[0] = B.split_legs(1).replace_label(labels[0], 'p')
        res = cls(sites, B_list, S_list, bc='finite', form='B')
        if form != 'B':
            res.convert_form(form)
        return res

    def copy(self):
        """Returns a copy of `self`.

        The copy still shares the sites, chinfo, and LegCharges of the _B,
        but the values of B and S are deeply copied.
        """
        # __init__ makes deep copies of B, S
        return MPS(self.sites, self._B, self._S, self.bc, self.form)

    @property
    def L(self):
        """Number of physical sites. For an iMPS the len of the MPS unit cell."""
        return len(self.sites)

    @property
    def dim(self):
        """List of local physical dimensions."""
        return [site.dim for site in self.sites]

    @property
    def finite(self):
        "Distinguish MPS (``True; bc='finite', 'segment'`` ) vs. iMPS (``False; bc='infinite'``)"
        assert (self.bc in self._valid_bc)
        return self.bc != 'infinite'

    @property
    def chi(self):
        """Dimensions of the (nontrivial) virtual bonds."""
        # s.shape[0] == len(s) for 1D numpy array, but works also for a 2D npc Array.
        return [s.shape[0] for s in self._S[self.nontrivial_bonds]]

    @property
    def nontrivial_bonds(self):
        """Slice of the non-trivial bond indices, depending on ``self.bc``."""
        if self.bc == 'finite':
            return slice(1, self.L)
        elif self.bc == 'segment':
            return slice(0, self.L + 1)
        elif self.bc == 'infinite':
            return slice(0, self.L)

    def get_B(self, i, form='B', copy=False, cutoff=1.e-16):
        """return (view of) `B` at site `i` in canonical form.

        Parameters
        ----------
        i : int
            Index choosing the site.
        form : ``'B' | 'A' | 'C' | 'G' | None`` | tuple(float, float)
            The (canonical) form of the returned B.
            For ``None``, return the matrix in whatever form it is.
        copy : bool
            Whether to return a copy even if `form` matches the current form.
        cutoff : float
            During DMRG with a mixer, `S` may be a matrix for which we need the inverse.
            This is calculated as the Moore-Penrose pseudo-inverse, which uses a cutoff for the
            singular values.

        Returns
        -------
        B : :class:`~tenpy.linalg.np_conserved.Array`
            The MPS 'matrix' `B` at site `i` with leg labels ``vL, vR, p`` (in undefined order).
            May be a view of the matrix (if ``copy=False``),
            or a copy (if the form changed or ``copy=True``)

        Raises
        ------
        ValueError : if self is not in canonical form and ``form != None``.
        """
        i = self._to_valid_index(i)
        form = self._to_valid_form(form)
        return self._convert_form_i(self._B[i], i, self.form[i], form, copy, cutoff)

    def set_B(self, i, B, form='B'):
        """set `B` at site `i`.

        Parameters
        ----------
        i : int
            Index choosing the site.
        B : :class:`~tenpy.linalg.np_conserved.Array`
            The 'matrix' at site `i`. Should have leg labels ``vL, vR, p`` (in any order).
        form : ``'B' | 'A' | 'C' | 'G' | None`` | tuple(float, float)
            The (canonical) form of the `B` to set.
            ``None`` stands for non-canonical form.
        """
        i = self._to_valid_index(i)
        self.form[i] = self._to_valid_form(form)
        self._B[i] = B

    def get_SL(self, i):
        """return singular values on the left of site `i`"""
        i = self._to_valid_index(i)
        return self._S[i]

    def get_SR(self, i):
        """return singular values on the right of site `i`"""
        i = self._to_valid_index(i)
        return self._S[i + 1]

    def set_SL(self, i, S):
        """set singular values on the left of site `i`"""
        i = self._to_valid_index(i)
        self._S[i] = S
        if not self.finite and i == 0:
            self._S[self.L] = S

    def set_SR(self, i, S):
        """set singular values on the right of site `i`"""
        i = self._to_valid_index(i)
        self._S[i + 1] = S
        if not self.finite and i == self.L - 1:
            self._S[0] = S

    def get_op(self, op_list, i):
        """Given a list of operators, select the one corresponding to site `i`.

        Parameters
        ----------
        op_list : (list of) {str | :class:`~tenpy.linalg.np_conserved.Array`}
            List of operators from which we choose. We assume that ``op_list[j]`` acts on site
            ``j``. If the length is shorter than `L`, we repeat it periodically.
            Strings are translated using :meth:`~tenpy.networks.site.Site.get_op` of site `i`.
        i : int
            Index of the site on which the operator acts.

        Returns
        -------
        op : :class:`~tenpy.linalg.np_conserved.Array`
            One of the entries in `op_list`, not copied.
        """
        i = self._to_valid_index(i)
        op = op_list[i % len(op_list)]
        if isinstance(op, str):
            op = self.sites[i].get_op(op)
        return op

    def get_theta(self, i, n=2, cutoff=1.e-16, formL=1., formR=1.):
        """Calculates the `n`-site wavefunction on ``sites[i:i+n]``.

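        For a finite chain with all `B` in 'B' form, the result corresponds to the following NumPy
        sketch (hypothetical helper, legs ordered ``(vL, p, vR)``, no periodic wrapping of the
        site indices)::

            import numpy as np

            def theta_from_B_form(Bs, Ss, i, n=2, formL=1., formR=1.):
                # Ss[i] is left of Bs[i], Ss[i + n] is right of Bs[i + n - 1]
                theta = Ss[i][:, None, None]**formL * Bs[i]
                for k in range(1, n):
                    theta = np.tensordot(theta, Bs[i + k], axes=(-1, 0))
                # the 'B' form already contains s**1 on the right; rescale to s**formR
                return theta * Ss[i + n]**(formR - 1.)

            s = np.full(2, 2**-0.5)  # Bell state (|00> + |11>)/sqrt(2)
            Bs = [np.eye(2).reshape(1, 2, 2) * s, np.eye(2).reshape(2, 2, 1)]
            Ss = [np.ones(1), s, np.ones(1)]
            theta = theta_from_B_form(Bs, Ss, 0, n=2)  # shape (1, 2, 2, 1)
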
        Parameters
        ----------
        i : int
            Site index.
        n : int
            Number of sites. The result lives on ``sites[i:i+n]``.
        cutoff : float
            During DMRG with a mixer, `S` may be a matrix for which we need the inverse.
            This is calculated as the Moore-Penrose pseudo-inverse, which uses a cutoff for the
            singular values.
        formL : float
            Exponent for the singular values to the left.
        formR : float
            Exponent for the singular values to the right.

        Returns
        -------
        theta : :class:`~tenpy.linalg.np_conserved.Array`
            The n-site wave function with leg labels ``vL, vR, p0, p1, .... p{n-1}``
            (in undefined order).
            In Vidal's notation (with s=lambda, G=Gamma):
            ``theta = s**formL G_i s G_{i+1} s ... G_{i+n-1} s**formR``.
        """
        i = self._to_valid_index(i)
        if self.finite:
            if (i < 0 or i + n > self.L):
                raise ValueError("i = {0:d} out of bounds".format(i))

        if self.form[i] is None or self.form[(i + n - 1) % self.L] is None:
            # we allow intermediate `form`=None except at the very left and right.
            raise ValueError("can't calculate theta for non-canonical form")

        # the following code is an equivalent to::
        #
        #   theta = self.get_B(i, form='B').replace_label('p', 'p0')
        #   theta.iscale_axis(self.get_SL(i)** formL, 'vL')
        #   for k in range(1, n):
        #       j = (i + k) % self.L
        #       B = self.get_B(j, form='B').replace_label('p', 'p'+str(k))
        #       theta = npc.tensordot(theta, B, ['vR', 'vL'])
        #   theta.iscale_axis(self.get_SR(i + n - 1)** (formR-1.), 'vR')
        #   return theta
        #
        # However, the following code is numerically more stable if ``self.form`` is not `B`
        # (since it avoids unnecessary `scale_axis`) and works also for intermediate sites with
        # ``self.form[j] = None``.

        fL, fR = self.form[i]  # left / right form exponent
        copy = (fL == 0 and fR == 0)  # otherwise, a copy is performed later by `scale_axis`.
        theta = self.get_B(i, form=None, copy=copy)  # in the current form
        theta = self._replace_p_label(theta, 0)
        theta = self._scale_axis_B(theta, self.get_SL(i), formL - fL, 'vL', cutoff)
        for k in range(1, n):  # nothing if n=1.
            j = (i + k) % self.L
            B = self._replace_p_label(self.get_B(j, None, False), k)
            if self.form[j] is not None:
                fL_j, fR_j = self.form[j]
                if fR is not None:
                    B = self._scale_axis_B(B, self.get_SL(j), 1. - fL_j - fR, 'vL', cutoff)
                # otherwise we can just hope it's fine.
                fR = fR_j
            else:
                fR = None
            theta = npc.tensordot(theta, B, axes=('vR', 'vL'))
        # here, fR = self.form[i+n-1][1]
        theta = self._scale_axis_B(theta, self.get_SR(i + n - 1), formR - fR, 'vR', cutoff)
        return theta

    def convert_form(self, new_form='B'):
        """tranform self into different canonical form (by scaling the legs with singular values).

        Parameters
        ----------
        new_form : (list of) {``'B' | 'A' | 'C' | 'G' | None`` | tuple(float, float)}
            The form the stored 'matrices'. The table in module doc-string.
            A single choice holds for all of the entries.

        Raises
        ------
        ValueError : if trying to convert from a ``None`` form. Use :meth:`canonicalize` instead!
        """
        new_forms = self._parse_form(new_form)
        for i, form in enumerate(new_forms):
            new_B = self.get_B(i, form=form, copy=False)  # calculates the desired form.
            self.set_B(i, new_B, form=form)

    def entanglement_entropy(self, n=1, bonds=None, for_matrix_S=False):
        r"""Calculate the (half-chain) entanglement entropy for all nontrivial bonds.

        Consider a cut of the system into :math:`A = \{ j: j < i \}` and :math:`B = \{ j: j \geq i\}`.
        This defines the von Neumann entanglement entropy as
        :math:`S(A, n=1) = -tr(\rho_A \log(\rho_A)) = S(B, n=1)`.
        The generalizations for ``n != 1, n>0`` are the Renyi entropies:
        :math:`S(A, n) = \frac{1}{1-n} \log(tr(\rho_A^n)) = S(B, n)`.

        This function calculates the entropy for a cut at different bonds `i`, for which
        the eigenvalues of the reduced density matrices :math:`\rho_A` and :math:`\rho_B` are
        given by the squared Schmidt values `S` of the bond.

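        In terms of the Schmidt values ``s`` of one bond, with :math:`p = s^2`, the entropies
        can be sketched in NumPy as follows (hypothetical helper name)::

            import numpy as np

            def entropy_from_schmidt(s, n=1):
                p = s[s > 1.e-20]**2  # eigenvalues of rho_A; drop zeros to avoid log(0)
                if n == 1:
                    return -np.sum(p * np.log(p))
                elif n == np.inf:
                    return -np.log(np.max(p))
                return np.log(np.sum(p**n)) / (1. - n)

            s = np.full(2, 2**-0.5)  # maximally entangled bond
            S1 = entropy_from_schmidt(s)  # log(2) up to rounding, as are all Renyi entropies here
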
        Parameters
        ----------
        n : int/float
            Selects which entropy to calculate;
            `n=1` (default) is the usual von Neumann entanglement entropy.
        bonds : ``None`` | (iterable of) int
            Selects the bonds at which the entropy should be calculated.
            ``None`` defaults to ``range(0, L+1)[self.nontrivial_bonds]``.
        for_matrix_S : bool
            Whether to calculate the entanglement entropy even if the `_S` are matrices.
            Since this costs :math:`O(\chi^3)` compared to the usual :math:`O(\chi)`,
            we raise an error by default.

        Returns
        -------
        entropies : 1D ndarray
            Entanglement entropies for half-cuts.
            `entropies[j]` contains the entropy for a cut at bond ``bonds[j]``
            (i.e. left to site ``bonds[j]``).
        """
        if bonds is None:
            nt = self.nontrivial_bonds
            bonds = range(nt.start, nt.stop)
        else:
            bonds = to_iterable(bonds)  # allow a single int as documented
        res = []
        for ib in bonds:
            s = self._S[ib]
            if len(s.shape) > 1:
                if for_matrix_S:
                    # explicitly calculate Schmidt values by diagonalizing (s^dagger s)
                    s = npc.eigvals(npc.tensordot(s.conj(), s, axes=[0, 0]))
                    s = np.sqrt(s[s > 1.e-40])
                else:
                    raise ValueError("entropy with non-diagonal schmidt values")
            s = s[s > 1.e-20]  # just for stability reasons / to avoid NaN in log
            if n == 1:
                s2 = s * s  # eigenvalues of the reduced density matrix
                res.append(-np.inner(np.log(s2), s2))
            elif n == np.inf:
                res.append(-2. * np.log(np.max(s)))
            else:  # general n != 1, inf
                res.append(np.log(np.sum(s**(2 * n))) / (1. - n))
        return np.array(res)

    def overlap(self, other):
        """Compute overlap :math:`<self|other>`.

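        For finite boundary conditions, this amounts to the following NumPy sketch, which
        contracts the two chains from left to right (hypothetical helper, legs ordered
        ``(vL, p, vR)``, both states in 'B' form with trivial outer bonds)::

            import numpy as np

            def overlap_finite(Bs_bra, Bs_ket):
                env = np.ones((1, 1))  # trivial left boundary, legs (bra vR, ket vR)
                for A, B in zip(Bs_bra, Bs_ket):
                    env = np.einsum('ab,apc,bpd->cd', env, A.conj(), B)
                return env[0, 0]  # trivial right boundary

            up = np.array([1., 0.]).reshape(1, 2, 1)
            s = np.full(2, 2**-0.5)
            bell = [np.eye(2).reshape(1, 2, 2) * s, np.eye(2).reshape(2, 2, 1)]
            ov = overlap_finite([up, up], bell)  # <00|bell> = 1/sqrt(2)
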
        Parameters
        ----------
        other : :class:`MPS`
            An MPS with the same physical sites.

        Returns
        -------
        overlap : dtype
            The contraction <self|other>.
        env : MPSEnvironment
            The environment (storing the LP and RP) used to calculate the overlap.
        """
        if not self.finite:
            # requires MPSTransferMatrix for finding dominant left/right parts
            raise NotImplementedError("TODO")
        env = MPSEnvironment(self, other)
        return env.full_contraction(0), env

    def expectation_value(self, ops, sites=None, axes=None):
        """Expectation value ``<psi|ops|psi>`` of (n-site) operator(s).

        Given the MPS in canonical form, it calculates n-site expectation values.
        For example, the contraction for a two-site (`n` = 2) operator on site `i` would look like::

            |          .--S--B[i]--B[i+1]--.
            |          |     |     |       |
            |          |     |-----|       |
            |          |     | op  |       |
            |          |     |-----|       |
            |          |     |     |       |
            |          .--S--B*[i]-B*[i+1]-.

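        For a single-site operator (`n` = 1), the contraction reduces to the following NumPy
        sketch, where ``theta`` stands for the result of ``self.get_theta(i, n=1)`` with legs
        assumed to be ordered ``(vL, p, vR)``, and ``op`` has legs ``(p, p*)``
        (hypothetical helper name)::

            import numpy as np

            def expectation_value_1site(theta, op):
                # 'p*' of op is contracted with theta, 'p' with conj(theta)
                return np.einsum('apb,pq,aqb->', theta.conj(), op, theta)

            Sz = np.diag([0.5, -0.5])
            theta_up = np.array([1., 0.]).reshape(1, 2, 1)
            sz_up = expectation_value_1site(theta_up, Sz)  # 0.5
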
        Parameters
        ----------
        ops : (list of) { :class:`~tenpy.linalg.np_conserved.Array` | str }
            The operators, for which the expectation value should be taken.
            All operators should have the same number of legs (namely `2 n`).
            If less than ``len(sites)`` operators are given, we repeat them periodically.
            Strings (like ``'Id', 'Sz'``) are translated into single-site operators defined by
            `self.sites`.
        sites : None | list of int
            List of site indices. Expectation values are evaluated there.
            If ``None`` (default), the entire chain is taken (clipping for finite b.c.)
        axes : None | (list of str, list of str)
            Two lists of each `n` leg labels giving the physical legs of the operator used for
            contraction. The first `n` legs are contracted with the conjugated `B`,
            the second `n` legs with the non-conjugated `B`.
            ``None`` defaults to ``(['p'], ['p*'])`` for single site operators (`n` = 1), or
            ``(['p0', 'p1', ... 'p{n-1}'], ['p0*', 'p1*', .... 'p{n-1}*'])`` for `n` > 1.

        Returns
        -------
        exp_vals : 1D ndarray
            Expectation values, ``exp_vals[i] = <psi|ops[i]|psi>``, where ``ops[i]`` acts on
            site(s) ``j, j+1, ..., j+{n-1}`` with ``j=sites[i]``.

        Examples
        --------
        One site examples (n=1):

        >>> psi.expectation_value('Sz')
        [Sz0, Sz1, ..., Sz{L-1}]
        >>> psi.expectation_value(['Sz', 'Sx'])
        [Sz0, Sx1, Sz2, Sx3, ... ]
        >>> psi.expectation_value('Sz', sites=[0, 3, 4])
        [Sz0, Sz3, Sz4]

        Two site example (n=2), assuming homogeneous sites:

        >>> SzSx = npc.outer(psi.sites[0].Sz.replace_labels(['p', 'p*'], ['p0', 'p0*']),
        ...                  psi.sites[1].Sx.replace_labels(['p', 'p*'], ['p1', 'p1*']))
        >>> psi.expectation_value(SzSx)
        [Sz0Sx1, Sz1Sx2, Sz2Sx3, ... ]   # with len ``L-1`` for finite bc, or ``L`` for infinite

        Example measuring <psi|SzSx|psi> on every second site, for inhomogeneous sites:

        >>> SzSx_list = [npc.outer(psi.sites[i].Sz.replace_labels(['p', 'p*'], ['p0', 'p0*']),
        ...                        psi.sites[i+1].Sx.replace_labels(['p', 'p*'], ['p1', 'p1*']))
        ...              for i in range(0, psi.L-1, 2)]
        >>> psi.expectation_value(SzSx_list, range(0, psi.L-1, 2))
        [Sz0Sx1, Sz2Sx3, Sz4Sx5, ...]

        """
        ops, sites, n, th_labels, (axes_p, axes_pstar) = self._expectation_value_args(ops, sites,
                                                                                      axes)
        vLvR_axes_p = ('vL', 'vR') + tuple(axes_p)
        E = []
        for i in sites:
            op = self.get_op(ops, i)
            theta = self.get_theta(i, n)
            C = npc.tensordot(op, theta, axes=[axes_pstar, th_labels[2:]])
            E.append(npc.inner(theta, C, axes=[th_labels, vLvR_axes_p], do_conj=True))
        return np.real_if_close(np.array(E))

    def correlation_function(self,
                             ops1,
                             ops2,
                             sites1=None,
                             sites2=None,
                             opstr=None,
                             str_on_first=False,
                             hermitian=False):
        """Correlation function  ``<psi|op1_i op2_j|psi>`` of single site operators `op1`, `op2`.

        Given the MPS in canonical form, it calculates n-site expectation values.
        For examples the contraction for a two-site operator on site `i` would look like::

            |          .--S--B[i]--B[i+1]--...--B[j]---.
            |          |     |     |            |      |
            |          |     |     |            op2    |
            |          |     op1   |            |      |
            |          |     |     |            |      |
            |          .--S--B*[i]-B*[i+1]-...--B*[j]--.

        Onsite terms are taken in the order ``<psi | op1 op2 | psi>``.

        If `opstr` is given and ``str_on_first=True``, it calculates::

            |           for i < j                               for i > j
            |
            |          .--S--B[i]---B[i+1]--...- B[j]---.     .--S--B[j]---B[j+1]--...- B[i]---.
            |          |     |      |            |      |     |     |      |            |      |
            |          |     opstr  opstr        op2    |     |     op2    |            |      |
            |          |     |      |            |      |     |     |      |            |      |
            |          |     op1    |            |      |     |     opstr  opstr        op1    |
            |          |     |      |            |      |     |     |      |            |      |
            |          .--S--B*[i]--B*[i+1]-...- B*[j]--.     .--S--B*[j]--B*[j+1]-...- B*[i]--.

        For ``i==j``, no `opstr` is included.
        For ``str_on_first=False`` (default), the `opstr` on site ``min(i, j)`` is always left out.

        Strings (like ``'Id', 'Sz'``) in the operator lists are translated into single-site
        operators defined by the :class:`~tenpy.networks.site.Site` on which they act.
        Each operator should have the two legs ``'p', 'p*'``.

        Parameters
        ----------
        ops1 : (list of) { :class:`~tenpy.linalg.np_conserved.Array` | str }
            First operator of the correlation function (acting after ops2).
            ``ops1[x]`` acts on site ``sites1[x]``.
            If fewer than ``len(sites1)`` operators are given, we repeat them periodically.
        ops2 : (list of) { :class:`~tenpy.linalg.np_conserved.Array` | str }
            Second operator of the correlation function (acting before ops1).
            ``ops2[y]`` acts on site ``sites2[y]``.
            If fewer than ``len(sites2)`` operators are given, we repeat them periodically.
        sites1 : None | int | list of int
            List of site indices; a single `int` is translated to ``range(0, sites1)``.
            ``None`` defaults to all sites ``range(0, L)``.
            Is sorted before use, i.e. the order is ignored.
        sites2 : None | int | list of int
            List of site indices; a single `int` is translated to ``range(0, sites2)``.
            ``None`` defaults to all sites ``range(0, L)``.
            Is sorted before use, i.e. the order is ignored.
        opstr : None | (list of) { :class:`~tenpy.linalg.np_conserved.Array` | str }
            Ignored by default (``None``).
            Operator(s) to be inserted between ``ops1`` and ``ops2``.
            If given as a list, ``opstr[r]`` is inserted at site `r` (independent of `sites1` and
            `sites2`).
        str_on_first : bool
            Whether the `opstr` is included on the site ``min(i, j)``.
            Note the order, which is chosen that way to handle fermionic Jordan-Wigner strings
            correctly. (In other words: choose ``str_on_first=True`` for fermions!)
        hermitian : bool
            Optimization flag: if ``sites1 == sites2`` and ``ops1[i]^\dagger == ops2[i]``
            (which is not checked explicitly!), the resulting ``C[x, y]`` will be hermitian.
            We can use that to avoid calculations, so ``hermitian=True`` will run faster.

        Returns
        -------
        C : 2D ndarray
            The correlation function ``C[x, y] = <psi|ops1[i] ops2[j]|psi>``,
            where ``ops1[i]`` acts on site ``i=sites1[x]`` and ``ops2[j]`` on site ``j=sites2[y]``.
            If `opstr` is given, it gives (for ``str_on_first=True``):

            - For ``i < j``: ``C[x, y] = <psi|ops1[i] prod_{i <= r < j} opstr[r] ops2[j]|psi>``.
            - For ``i > j``: ``C[x, y] = <psi|prod_{j <= r < i} opstr[r] ops1[i] ops2[j]|psi>``.
            - For ``i = j``: ``C[x, y] = <psi|ops1[i] ops2[j]|psi>``.

            The condition ``<= r`` is replaced by a strict ``< r``, if ``str_on_first=False``.
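
        Examples
        --------
        The meaning of ``C[x, y]``, including the `opstr` and `str_on_first` conventions, can be
        checked against a brute-force evaluation on a dense state vector. The following is a
        minimal numpy sketch for ``i < j`` that is independent of this class; the helpers
        `site_op` and `corr` and the choice of operators are hypothetical and only serve as a
        reference for the definition above.

```python
import numpy as np

L, d = 4, 2
rng = np.random.default_rng(1)
psi = rng.normal(size=d**L) + 1j * rng.normal(size=d**L)
psi /= np.linalg.norm(psi)                      # normalized dense state

Id = np.eye(d)
Sz = np.diag([0.5, -0.5])
Sp = np.array([[0., 1.], [0., 0.]])             # S^+ in the basis (up, down)
Sm = Sp.T                                       # S^- = (S^+)^dagger
JW = np.diag([1., -1.])                         # example string operator (sign operator)


def site_op(op, i):
    # embed the single-site operator `op` at site `i` of the L-site chain
    mats = [Id] * L
    mats[i] = op
    full = mats[0]
    for m in mats[1:]:
        full = np.kron(full, m)
    return full


def corr(op1, i, op2, j, opstr=None, str_on_first=False):
    # dense reference for <psi| op1_i prod_r opstr_r op2_j |psi> with i < j
    assert i < j
    full = site_op(op1, i)
    if opstr is not None:
        start = i if str_on_first else i + 1    # include the string on site i?
        for r in range(start, j):
            full = full @ site_op(opstr, r)
    full = full @ site_op(op2, j)
    return np.vdot(psi, full @ psi)
```

        Since ``Sp @ JW == -Sp``, including the sign operator on the first site with
        ``str_on_first=True`` flips the sign of ``corr(Sp, 0, Sm, 3, opstr=JW)``.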
        """
        ops1, ops2, sites1, sites2, opstr = self._correlation_function_args(ops1, ops2, sites1,
                                                                            sites2, opstr)
        if hermitian and not np.array_equal(sites1, sites2):
            warnings.warn("MPS correlation function can't use the hermitian flag")
            hermitian = False
        C = np.empty((len(sites1), len(sites2)), dtype=np.complex128)
        for x, i in enumerate(sites1):
            # j > i
            j_gtr = sites2[sites2 > i]
            if len(j_gtr) > 0:
                C_gtr = self._corr_up_diag(ops1, ops2, i, j_gtr, opstr, str_on_first, True)
                C[x, (sites2 > i)] = C_gtr
                if hermitian:
                    C[x + 1:, x] = np.conj(C_gtr)
            # j == i
            j_eq = sites2[sites2 == i]
            if len(j_eq) > 0:
                # on-site correlation function
                op12 = npc.tensordot(self.get_op(ops1, i), self.get_op(ops2, i), axes=['p*', 'p'])
                C[x, (sites2 == i)] = self.expectation_value(op12, i, [['p'], ['p*']])
        if not hermitian:
            #  j < i
            for y, j in enumerate(sites2):
                i_gtr = sites1[sites1 > j]
                if len(i_gtr) > 0:
                    C[(sites1 > j), y] = self._corr_up_diag(ops2, ops1, j, i_gtr, opstr,
                                                            str_on_first, False)
                    # exchange ops1 and ops2 : they commute on different sites,
                    # but we apply opstr after op1 (using the last argument = False)
        return np.real_if_close(C)

    def __str__(self):
        """Some status information about the MPS"""
        res = ["MPS, L={L:d}, bc={bc!r}.".format(L=self.L, bc=self.bc)]
        res.append("chi: " + str(self.chi))
        if self.L > 10:
            res.append("first two sites: " + repr(self.sites[0]) + " " + repr(self.sites[1]))
            res.append("first two forms: " + " ".join([repr(f) for f in self.form[:2]]))
        else:
            res.append("sites: " + " ".join([repr(s) for s in self.sites]))
            res.append("forms: " + " ".join([repr(f) for f in self.form]))
        return "\n".join(res)

    def _to_valid_index(self, i):
        """make sure `i` is a valid index (depending on `self.bc`)."""
        if not self.finite:
            return i % self.L
        if i < 0:
            i += self.L
        if i >= self.L or i < 0:
            raise ValueError("i = {0:d} out of bounds for finite MPS".format(i))
        return i

    def _parse_form(self, form):
        """parse `form` = (list of) {tuple | key of _valid_forms} to list of tuples"""
        if isinstance(form, tuple):
            return [form] * self.L
        form = to_iterable(form)
        if len(form) == 1:
            form = [form[0]] * self.L
        if len(form) != self.L:
            raise ValueError("Wrong len of `form`: " + repr(form))
        return [self._to_valid_form(f) for f in form]

    def _to_valid_form(self, form):
        """parse `form` = {tuple | key of _valid_forms} to a tuple"""
        if isinstance(form, tuple):
            return form
        return self._valid_forms[form]

    def _convert_form_i(self, B, i, form, new_form, copy=True, cutoff=1.e-16):
        """transform `B[i]` from canonical form `form` into canonical form `new_form`.

        ======== ======== ================================================
        form     new_form action
        ======== ======== ================================================
        *        ``None`` return (copy of) B
        tuple    tuple    scale the legs 'vL' and 'vR' of B appropriately
                          with ``self.get_SL(i)`` and ``self.get_SR(i)``.
        ``None`` tuple    raise ValueError
        ======== ======== ================================================
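
        A minimal numpy sketch of the ``tuple -> tuple`` row, assuming the convention that the
        form ``(a, b)`` stands for ``SL**a Gamma SR**b`` in Vidal's notation (e.g. ``(1, 0)``
        for 'A' and ``(0, 1)`` for 'B'). Plain arrays with leg order ``(vL, p, vR)`` replace the
        npc Arrays, and `to_form` and `convert_form` are hypothetical helpers.

```python
import numpy as np

rng = np.random.default_rng(0)
chiL, d, chiR = 3, 2, 4
Gamma = rng.normal(size=(chiL, d, chiR))     # legs (vL, p, vR)
SL = rng.uniform(0.1, 1.0, size=chiL)        # singular values on the left bond
SR = rng.uniform(0.1, 1.0, size=chiR)        # singular values on the right bond


def to_form(G, a, b):
    # build SL**a Gamma SR**b by scaling the 'vL' and 'vR' legs
    return G * (SL**a)[:, None, None] * (SR**b)[None, None, :]


def convert_form(B, form, new_form):
    # like the tuple -> tuple row: scale each leg with S**(new - old)
    (old_L, old_R), (new_L, new_R) = form, new_form
    B = B * (SL**(new_L - old_L))[:, None, None]
    return B * (SR**(new_R - old_R))[None, None, :]


A = to_form(Gamma, 1., 0.)    # 'A' form
B = to_form(Gamma, 0., 1.)    # 'B' form
```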
        """
        if new_form is None or form == new_form:
            if copy:
                return B.copy()
            return B  # nothing to do
        if form is None:
            raise ValueError("can't convert form of non-canonical state!")
        old_L, old_R = form
        new_L, new_R = new_form
        B = self._scale_axis_B(B, self.get_SL(i), new_L - old_L, 'vL', cutoff)
        B = self._scale_axis_B(B, self.get_SR(i), new_R - old_R, 'vR', cutoff)
        return B

    def _scale_axis_B(self, B, S, form_diff, axis_B, cutoff):
        """Scale an axis of B with S to bring it in desired form.

        If S is just 1D (as usual, e.g. during TEBD), this function just performs
        ``B.scale_axis(S**form_diff, axis_B)``.

        However, during DMRG with the mixer, S might actually be a 2D matrix.
        For ``form_diff = -1``, we need to calculate the inverse of S, more precisely the
        (Moore-Penrose) pseudo inverse, see :func:`~tenpy.linalg.np_conserved.pinv`.
        The cutoff is only used in that case.

        Returns scaled B."""
        if form_diff == 0:
            return B  # nothing to do
        if isinstance(S, npc.Array):
            if S.rank != 2:
                raise ValueError("Expect 2D npc.Array or 1D numpy ndarray")
            if form_diff == -1:
                S = npc.pinv(S, cutoff)
            elif form_diff != 1.:
                raise ValueError("Can't scale/tensordot a 2D `S` for non-integer `form_diff`")

            # Hack: mpo.MPOEnvironment.full_contraction uses ``axis_B == 'vL*'``
            if axis_B == 'vL' or axis_B == 'vL*':
                B = npc.tensordot(S, B, axes=[1, axis_B]).replace_label(0, axis_B)
            elif axis_B == 'vR' or axis_B == 'vR*':
                B = npc.tensordot(B, S, axes=[axis_B, 0]).replace_label(-1, axis_B)
            else:
                raise ValueError("This should never happen: unexpected leg for scaling with S")
            return B
        else:
            if form_diff != 1.:
                S = S**form_diff
            return B.scale_axis(S, axis_B)

    def _replace_p_label(self, A, k):
        """Return npc Array `A` with replaced label, ``'p' -> 'p'+str(k)``.

        Instead of re-implementing `get_theta`, the derived `PurificationMPS` needs only to
        implement this function."""
        return A.replace_label('p', 'p' + str(k))

    def _expectation_value_args(self, ops, sites, axes):
        """parse the arguments of self.expectation_value()"""
        ops = npc.to_iterable_arrays(ops)
        n = self.get_op(ops, 0).rank // 2  # same as int(rank/2)
        L = self.L
        if sites is None:
            if self.finite:
                sites = range(L - (n - 1))
            else:
                sites = range(L)
        sites = to_iterable(sites)
        th_labels = ['vL', 'vR'] + ['p' + str(j) for j in range(n)]
        if axes is None:
            if n == 1:
                axes = (['p'], ['p*'])
            else:
                axes = (th_labels[2:], [lbl + '*' for lbl in th_labels[2:]])
        axes_p, axes_pstar = axes
        if len(axes_p) != n or len(axes_pstar) != n:
            raise ValueError("Len of axes does not match operator n=" + str(n))
        return ops, sites, n, th_labels, axes

    def _correlation_function_args(self, ops1, ops2, sites1, sites2, opstr):
        """get default arguments of self.correlation_function()"""
        if sites1 is None:
            sites1 = range(0, self.L)
        elif type(sites1) == int:
            sites1 = range(0, sites1)
        if sites2 is None:
            sites2 = range(0, self.L)
        elif type(sites2) == int:
            sites2 = range(0, sites2)
        ops1 = npc.to_iterable_arrays(ops1)
        ops2 = npc.to_iterable_arrays(ops2)
        opstr = npc.to_iterable_arrays(opstr)
        sites1 = np.sort(sites1)
        sites2 = np.sort(sites2)
        return ops1, ops2, sites1, sites2, opstr

    def _corr_up_diag(self, ops1, ops2, i, j_gtr, opstr, str_on_first, apply_opstr_first):
        """correlation function above the diagonal: for fixed i and all j in j_gtr, j > i."""
        op1 = self.get_op(ops1, i)
        opstr1 = self.get_op(opstr, i)
        if opstr1 is not None:
            axes = ['p*', 'p'] if apply_opstr_first else ['p', 'p*']
            op1 = npc.tensordot(op1, opstr1, axes=axes)
        theta = self.get_theta(i, n=1)
        C = npc.tensordot(op1, theta, axes=['p*', 'p0'])
        C = npc.tensordot(theta.conj(), C, axes=[['p0*', 'vL*'], ['p', 'vL']])
        # C has legs 'vR*', 'vR'
        js = list(j_gtr[::-1])  # stack of j, sorted *descending*
        res = []
        for r in range(i + 1, js[0] + 1):  # js[0] is the maximum
            B = self.get_B(r, form='B')
            C = npc.tensordot(C, B, axes=['vR', 'vL'])
            if r == js[-1]:
                Cij = npc.tensordot(self.get_op(ops2, r), C, axes=['p*', 'p'])
                Cij = npc.inner(B.conj(), Cij, axes=[['vL*', 'p*', 'vR*'], ['vR*', 'p', 'vR']])
                res.append(Cij)
                js.pop()
            if len(js) > 0:
                op = self.get_op(opstr, r)
                if op is not None:
                    C = npc.tensordot(op, C, axes=['p*', 'p'])
                C = npc.tensordot(B.conj(), C, axes=[['vL*', 'p*'], ['vR*', 'p']])
        return res


class MPSEnvironment(object):
    """Stores partial contractions of :math:`<bra|Op|ket>` for local operators `Op`.

    The network for a contraction :math:`<bra|Op|ket>` of a local operator `Op`, for example
    at sites `i, i+1`, looks like::

        |     .-----M[0]--- ... --M[1]---M[2]--- ... ->--.
        |     |     |             |      |               |
        |     |     |             |------|               |
        |     LP[0] |             |  Op  |               RP[-1]
        |     |     |             |------|               |
        |     |     |             |      |               |
        |     .-----N[0]*-- ... --N[1]*--N[2]*-- ... -<--.

    Of course, we can also calculate the overlap `<bra|ket>` by using the special case ``Op = Id``.

    We use the following label convention (where arrows indicate `qconj`)::

        |    .-->- vR           vL ->-.
        |    |                        |
        |    LP                       RP
        |    |                        |
        |    .--<- vR*         vL* -<-.

    To avoid recalculations of the whole network e.g. in the DMRG sweeps,
    we store the contractions up to some site index in this class.
    For ``bc='finite','segment'``, the very left and right part ``LP[0]`` and
    ``RP[-1]`` are trivial and don't change,
    but for ``bc='infinite'`` they might be updated
    (by inserting another unit cell to the left/right).

    The MPS `bra` and `ket` have to be in canonical form.
    All the environments are constructed without the singular values on the open bond.
    In other words, we contract left-canonical `A` to the left parts `LP`
    and right-canonical `B` to the right parts `RP`.
    Thus, the special case ``bra == ket`` should yield identity matrices for `LP` and `RP`.
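
    The last statement can be checked with a minimal numpy sketch. It assumes plain arrays with
    leg order ``(vL, p, vR)`` instead of npc Arrays, builds a random left-canonical `A` from a
    QR decomposition, and uses a hypothetical helper `contract_LP` that mirrors the two
    tensordots of :meth:`_contract_LP`; with ``bra == ket`` the left part stays the identity.

```python
import numpy as np

rng = np.random.default_rng(2)
d, chi = 2, 3
# left-canonical A from the orthonormal columns of Q:
# sum_{vL, p} conj(A[vL, p, b]) * A[vL, p, c] = delta_{b, c}
Q, _ = np.linalg.qr(rng.normal(size=(chi * d, chi)) + 1j * rng.normal(size=(chi * d, chi)))
A = Q.reshape(chi, d, chi)


def contract_LP(LP, A_ket, A_bra):
    # LP has legs (vR*, vR); mirrors the two tensordots of _contract_LP
    LP = np.tensordot(LP, A_ket, axes=([1], [0]))                  # (vR*, p, vR)
    return np.tensordot(A_bra.conj(), LP, axes=([0, 1], [0, 1]))   # (vR*, vR)


LP = np.eye(chi)            # identity left part
for _ in range(3):          # a few sites of a uniform MPS with bra == ket
    LP = contract_LP(LP, A, A)
```
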

    .. todo :
        Functionality to find the dominant LP/RP in the TD limit -> requires MPSTransferMatrix

    Parameters
    ----------
    bra : :class:`~tenpy.networks.mps.MPS`
        The MPS to project on. Should be given in usual 'ket' form;
        we call `conj()` on the matrices directly.
    ket : :class:`~tenpy.networks.mps.MPS`
        The MPS on which the local operator acts. If ``None``, `bra` is used.
    firstLP : ``None`` | :class:`~tenpy.linalg.np_conserved.Array`
        Initial very left part. If ``None``, build trivial one.
    lastRP : ``None`` | :class:`~tenpy.linalg.np_conserved.Array`
        Initial very right part. If ``None``, build trivial one.
    age_LP : int
        The number of physical sites involved in the contraction yielding `firstLP`.
    age_RP : int
        The number of physical sites involved in the contraction yielding `lastRP`.

    Attributes
    ----------
    L : int
        Number of physical sites. For iMPS the len of the MPS unit cell.
    dtype : type | string
        The data type of the Array entries.
    bra, ket: :class:`~tenpy.networks.mps.MPS`
        The two MPS for the contraction.
    _LP : list of {``None`` | :class:`~tenpy.linalg.np_conserved.Array`}
        Left parts of the environment, len `L`.
        ``LP[i]`` contains the contraction strictly left of site `i`
        (or ``None``, if we don't have it calculated).
    _RP : list of {``None`` | :class:`~tenpy.linalg.np_conserved.Array`}
        Right parts of the environment, len `L`.
        ``RP[i]`` contains the contraction strictly right of site `i`
        (or ``None``, if we don't have it calculated).
    _LP_age : list of int|``None``
        Used for book-keeping, how large the DMRG system grew:
        ``_LP_age[i]`` stores the number of physical sites involved in the contraction
        network which yields ``self._LP[i]``.
    _RP_age : list of int|``None``
        Used for book-keeping, how large the DMRG system grew:
        ``_RP_age[i]`` stores the number of physical sites involved in the contraction
        network which yields ``self._RP[i]``.
    """

    def __init__(self, bra, ket, firstLP=None, lastRP=None, age_LP=0, age_RP=0):
        if ket is None:
            ket = bra
        self.bra = bra
        self.ket = ket
        self.L = L = bra.L
        self.finite = bra.finite
        self.dtype = np.promote_types(bra.dtype, ket.dtype)
        self._LP = [None] * L
        self._RP = [None] * L
        self._LP_age = [None] * L
        self._RP_age = [None] * L
        if firstLP is None:
            # Build trivial very first LP
            leg_bra = bra.get_B(0).get_leg('vL')
            leg_ket = ket.get_B(0).get_leg('vL').conj()
            leg_ket.test_contractible(leg_bra)
            # should work for both finite and segment bc
            firstLP = npc.diag(1., leg_bra, dtype=self.dtype)
            firstLP.set_leg_labels(['vR*', 'vR'])
        self.set_LP(0, firstLP, age=age_LP)
        if lastRP is None:
            # Build trivial very last RP
            leg_bra = bra.get_B(L - 1).get_leg('vR')
            leg_ket = ket.get_B(L - 1).get_leg('vR').conj()
            leg_ket.test_contractible(leg_bra)
            lastRP = npc.diag(1., leg_bra, dtype=self.dtype)  # (leg_bra, leg_ket)
            lastRP.set_leg_labels(['vL*', 'vL'])
        self.set_RP(L - 1, lastRP, age=age_RP)
        self.test_sanity()

    def test_sanity(self):
        assert (self.bra.L == self.ket.L)
        assert (self.bra.finite == self.ket.finite)
        # check that the network is contractable
        for b_s, k_s in zip(self.bra.sites, self.ket.sites):
            b_s.leg.test_equal(k_s.leg)
        assert any([LP is not None for LP in self._LP])
        assert any([RP is not None for RP in self._RP])

    def get_LP(self, i, store=True):
        """Calculate LP at given site from nearest available one (including `i`).

        Parameters
        ----------
        i : int
            The returned `LP` will contain the contraction *strictly* left of site `i`.
        store : bool
            Whether to store the calculated `LP` in `self` (``True``) or discard them (``False``).

        Returns
        -------
        LP_i : :class:`~tenpy.linalg.np_conserved.Array`
            Contraction of everything left of site `i`,
            with labels ``'vR*', 'vR'`` for `bra`, `ket`.
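
        Examples
        --------
        The look-up-and-extend strategy, including the age book-keeping, can be illustrated by a
        toy sketch for a finite chain. Tuples of site indices stand in for the contracted npc
        Arrays, and all names below are hypothetical; only the control flow mirrors this method.

```python
# Toy model of the LP cache for a finite chain: contracting site j appends j to a
# tuple, so every stored left part records which sites it contains.
L = 6
cache = [None] * L
ages = [None] * L
cache[0], ages[0] = (), 0        # trivial left part LP[0], as set up in __init__


def contract_LP(j, LP):
    # hypothetical stand-in for _contract_LP
    return LP + (j,)


def get_LP(i, store=True):
    # nearest stored left part at or left of site i (finite chain: no wrap-around)
    i0 = max(k for k in range(i + 1) if cache[k] is not None)
    LP, age_i0 = cache[i0], ages[i0]
    for j in range(i0, i):       # contract the sites i0, ..., i-1
        LP = contract_LP(j, LP)
        if store:
            cache[j + 1], ages[j + 1] = LP, age_i0 + j - i0 + 1
    return LP
```

        With ``store=False`` the intermediate results are discarded, so a later call starts again
        from the nearest stored left part.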
        """
        # find nearest available LP to the left.
        for i0 in range(i, i - self.L, -1):
            LP = self._LP[self._to_valid_index(i0)]
            if LP is not None:
                break
            # (for finite, LP[0] should always be set, so we should abort at latest with i0=0)
        else:  # no break called
            raise ValueError("No left part in the system???")
        age_i0 = self.get_LP_age(i0)
        for j in range(i0, i):
            LP = self._contract_LP(j, LP)
            if store:
                self.set_LP(j + 1, LP, age=age_i0 + j - i0 + 1)
        return LP

    def get_RP(self, i, store=True):
        """Calculate RP at given site from nearest available one (including `i`).

        Parameters
        ----------
        i : int
            The returned `RP` will contain the contraction *strictly* right of site `i`.
        store : bool
            Whether to store the calculated `RP` in `self` (``True``) or discard them (``False``).

        Returns
        -------
        RP_i : :class:`~tenpy.linalg.np_conserved.Array`
            Contraction of everything right of site `i`,
            with labels ``'vL*', 'vL'`` for `bra`, `ket`.
        """
        # find nearest available RP to the right.
        for i0 in range(i, i + self.L):
            RP = self._RP[self._to_valid_index(i0)]
            if RP is not None:
                break
            # (for finite, RP[-1] should always be set, so we should abort at latest with i0=L-1)
        else:  # no break called
            raise ValueError("No right part in the system???")
        age_i0 = self.get_RP_age(i0)
        for j in range(i0, i, -1):
            RP = self._contract_RP(j, RP)
            if store:
                self.set_RP(j - 1, RP, age=age_i0 + i0 - j + 1)
        return RP

    def get_LP_age(self, i):
        """Return number of physical sites in the contractions of get_LP(i). Might be ``None``."""
        return self._LP_age[self._to_valid_index(i)]

    def get_RP_age(self, i):
        """Return number of physical sites in the contractions of get_RP(i). Might be ``None``."""
        return self._RP_age[self._to_valid_index(i)]

    def set_LP(self, i, LP, age):
        """Store part to the left of site `i`."""
        i = self._to_valid_index(i)
        self._LP[i] = LP
        self._LP_age[i] = age

    def set_RP(self, i, RP, age):
        """Store part to the right of site `i`."""
        i = self._to_valid_index(i)
        self._RP[i] = RP
        self._RP_age[i] = age

    def del_LP(self, i):
        """Delete stored part strictly to the left of site `i`."""
        i = self._to_valid_index(i)
        self._LP[i] = None
        self._LP_age[i] = None

    def del_RP(self, i):
        """Delete stored part strictly to the right of site `i`."""
        i = self._to_valid_index(i)
        self._RP[i] = None
        self._RP_age[i] = None

    def full_contraction(self, i0):
        """Calculate the overlap by a full contraction of the network.

        The full contraction of the environments gives the value ``<bra|ket>``,
        i.e. if `bra` is `ket`, the squared norm (``1`` for a normalized state). For this
        purpose, this function contracts ``get_LP(i0+1, store=False)`` and
        ``get_RP(i0, store=False)`` with the singular values on the bond ``i0, i0+1``.

        Parameters
        ----------
        i0 : int
            Site index.
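
        Examples
        --------
        For two sites and ``i0 = 0``, the result can be compared with a dense overlap. The
        following minimal numpy sketch (plain arrays, not this class) assumes that each 2-site
        state is split by an SVD into a left-canonical `A`, singular values `S` and a
        right-canonical `B`, and contracts the scaled left part with the right part.

```python
import numpy as np

rng = np.random.default_rng(3)
d = 2


def random_state():
    # random normalized two-site state v[p0, p1]
    v = rng.normal(size=(d, d)) + 1j * rng.normal(size=(d, d))
    return v / np.linalg.norm(v)


phi, psi = random_state(), random_state()     # bra, ket
A_b, S_b, B_b = np.linalg.svd(phi)           # phi = A_b @ diag(S_b) @ B_b
A_k, S_k, B_k = np.linalg.svd(psi)

LP = A_b.conj().T @ A_k                       # site 0 contracted; legs (vR*, vR)
LP = S_b.conj()[:, None] * LP * S_k[None, :]  # scale both legs with S on bond (0, 1)
RP = B_b.conj() @ B_k.T                       # site 1 contracted; legs (vL*, vL)
overlap = np.sum(LP * RP)                     # contract vR* with vL* and vR with vL
```
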
        """
        if self.ket.finite and i0 + 1 == self.L:
            # special case to handle `_to_valid_index` correctly:
            # get_LP(L) is not valid for finite b.c., so we need to calculate it explicitly.
            LP = self.get_LP(i0, store=False)
            LP = self._contract_LP(i0, LP)
        else:
            LP = self.get_LP(i0 + 1, store=False)
        # multiply with `S`: a bit of a hack: use 'private' MPS._scale_axis_B
        S_bra = self.bra.get_SR(i0).conj()
        LP = self.bra._scale_axis_B(LP, S_bra, form_diff=1., axis_B='vR*', cutoff=0.)
        # cutoff is not used for form_diff = 1
        S_ket = self.ket.get_SR(i0)
        LP = self.bra._scale_axis_B(LP, S_ket, form_diff=1., axis_B='vR', cutoff=0.)
        RP = self.get_RP(i0, store=False)
        return npc.inner(LP, RP, axes=[['vR*', 'vR'], ['vL*', 'vL']], do_conj=False)

    def expectation_value(self, ops, sites=None, axes=None):
        """Expectation value ``<bra|ops|ket>`` of (n-site) operator(s).

        Calculates n-site expectation values of operators sandwiched between bra and ket.
        For example, the contraction for a two-site operator on site `i` would look like::

            |          .--S--B[i]--B[i+1]--.
            |          |     |     |       |
            |          |     |-----|       |
            |          LP[i] | op  |       RP[i+1]
            |          |     |-----|       |
            |          |     |     |       |
            |          .--S--B*[i]-B*[i+1]-.

        Here, the `B` are taken from `ket`, the `B*` from `bra`.
        The call structure is the same as for :meth:`MPS.expectation_value`.

        Parameters
        ----------
        ops : (list of) { :class:`~tenpy.linalg.np_conserved.Array` | str }
            The operators for which the expectation value should be taken.
            All operators should have the same number of legs (namely `2 n`).
            If fewer than ``len(sites)`` operators are given, we repeat them periodically.
            Strings (like ``'Id', 'Sz'``) are translated into single-site operators defined by
            `self.sites`.
        sites : list
            List of site indices. Expectation values are evaluated there.
            If ``None`` (default), the entire chain is taken (clipping for finite b.c.)
        axes : None | (list of str, list of str)
            Two lists of each `n` leg labels giving the physical legs of the operator used for
            contraction. The first `n` legs are contracted with the conjugated `B`,
            the second `n` legs with the non-conjugated `B`.
            ``None`` defaults to ``(['p'], ['p*'])`` for single site (n=1), or
            ``(['p0', 'p1', ... 'p{n-1}'], ['p0*', 'p1*', .... 'p{n-1}*'])`` for `n` > 1.

        Returns
        -------
        exp_vals : 1D ndarray
            Expectation values, ``exp_vals[i] = <bra|ops[i]|ket>``, where ``ops[i]`` acts on
            site(s) ``j, j+1, ..., j+{n-1}`` with ``j=sites[i]``.

        Examples
        --------
        One site examples (n=1), assuming ``env = MPSEnvironment(psi, psi2)``:

        >>> env.expectation_value('Sz')
        [Sz0, Sz1, ..., Sz{L-1}]
        >>> env.expectation_value(['Sz', 'Sx'])
        [Sz0, Sx1, Sz2, Sx3, ... ]
        >>> env.expectation_value('Sz', sites=[0, 3, 4])
        [Sz0, Sz3, Sz4]

        Two site example (n=2), assuming homogeneous sites:

        >>> SzSx = npc.outer(psi.sites[0].Sz.replace_labels(['p', 'p*'], ['p0', 'p0*']),
        ...                  psi.sites[1].Sx.replace_labels(['p', 'p*'], ['p1', 'p1*']))
        >>> env.expectation_value(SzSx)
        [Sz0Sx1, Sz1Sx2, Sz2Sx3, ... ]   # with len ``L-1`` for finite bc, or ``L`` for infinite

        Example measuring ``<psi|SzSx|psi2>`` on every second site, for inhomogeneous sites:

        >>> SzSx_list = [npc.outer(psi.sites[i].Sz.replace_labels(['p', 'p*'], ['p0', 'p0*']),
        ...                        psi.sites[i+1].Sx.replace_labels(['p', 'p*'], ['p1', 'p1*']))
        ...              for i in range(0, psi.L-1, 2)]
        >>> env.expectation_value(SzSx_list, range(0, psi.L-1, 2))
        [Sz0Sx1, Sz2Sx3, Sz4Sx5, ...]
        """
        ops, sites, n, th_labels, (axes_p, axes_pstar) = self.bra._expectation_value_args(
            ops, sites, axes)
        vLvR_axes_p = ('vR*', 'vL*') + tuple(axes_p)
        E = []
        for i in sites:
            LP = self.get_LP(i, store=True)
            RP = self.get_RP(i, store=True)
            op = ops[i % len(ops)]
            if type(op) == str:
                op = self.ket.sites[i].get_op(op)
            C = self.ket.get_theta(i, n)  # vL, vR, p0, p1, ...
            C = npc.tensordot(op, C, axes=[axes_pstar, th_labels[2:]])  # axes_p + (vL, vR)
            C = npc.tensordot(LP, C, axes=['vR', 'vL'])  # vR*, axes_p, vR
            C = npc.tensordot(C, RP, axes=['vR', 'vL'])  # vR*, axes_p, vL*
            theta_bra = self.bra.get_theta(i, n)  # th_labels == (vL, vR, p0, p1, ...)
            E.append(npc.inner(theta_bra, C, axes=[th_labels, vLvR_axes_p], do_conj=True))
        return np.real_if_close(np.array(E))

    def _contract_LP(self, i, LP):
        """Contract LP with the tensors on site `i` to form ``self._LP[i+1]``"""
        LP = npc.tensordot(LP, self.ket.get_B(i, form='A'), axes=('vR', 'vL'))
        LP = npc.tensordot(
            self.bra.get_B(i, form='A').conj(), LP, axes=(['p*', 'vL*'], ['p', 'vR*']))
        return LP  # labels 'vR*', 'vR'

    def _contract_RP(self, i, RP):
        """Contract RP with the tensors on site `i` to form ``self._RP[i-1]``"""
        RP = npc.tensordot(self.ket.get_B(i, form='B'), RP, axes=('vR', 'vL'))
        RP = npc.tensordot(
            self.bra.get_B(i, form='B').conj(), RP, axes=(['p*', 'vR*'], ['p', 'vL*']))
        return RP  # labels 'vL*', 'vL'

    def _to_valid_index(self, i):
        """Make sure `i` is a valid index (depending on `ket.bc`)."""
        return self.ket._to_valid_index(i)


class TransferMatrix(sparse.linalg.LinearOperator):
    r"""Transfer matrix of two MPS (bra & ket).

    For an iMPS in the thermodynamic limit, we need to find the 'dominant `LP`' (and `RP`).
    This means nothing else than taking the transfer matrix of the unit cell and finding the
    (left/right) eigenvector with the largest (magnitude) eigenvalue, since it will dominate
    :math:`LP (TM)^n` (or :math:`(TM)^n RP`) in the limit :math:`n \rightarrow \infty`, whatever
    the initial `LP` is. This class provides exactly that functionality with :meth:`dominant_vec`.

    Given two MPS, we define the transfer matrix as::

        |    ---M[i]---M[i+1]- ... --M[i+L-1]---
        |       |      |             |
        |    ---N[j]*--N[j+1]* ... --N[j+L-1]*--

    Here the `M` denote the `B` of the ket (which are not necessarily in canonical form)
    and `N` the ones of the bra, respectively.

    To view it as a `matrix`, we combine the left and right indices to pipes::

        |  (vL.vL*) ->-TM->- (vR.vR*)

    In general, the transfer matrix is not Hermitian; thus you shouldn't use Lanczos.
    Instead, we support selecting a charge sector and acting on numpy arrays, which allows using
    the scipy.sparse routines.
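
    As an illustration of this approach (not of the interface of this class), here is a minimal
    sketch for a one-site unit cell with ``bra == ket`` and no charge conservation. It assumes
    plain numpy arrays with leg order ``(vL, p, vR)`` and a random right-canonical `B` built
    from a QR decomposition; `matvec` acts on `RP`, and :func:`scipy.sparse.linalg.eigs` finds
    the dominant eigenvector, which for a right-canonical `B` should be proportional to the
    identity with an eigenvalue of magnitude 1.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, eigs

rng = np.random.default_rng(0)
chi, d = 4, 2
# Q has orthonormal columns, so B[vL, p, vR] = Q[p*chi + vR, vL] is right-canonical:
# sum_{p, vR} B[a, p, vR] * B[b, p, vR] = delta_{a, b}
Q, _ = np.linalg.qr(rng.normal(size=(d * chi, chi)))
B = Q.T.reshape(chi, d, chi)


def matvec(vec):
    RP = vec.reshape(chi, chi)                                  # legs (vL, vL*)
    tmp = np.tensordot(B, RP, axes=([2], [0]))                  # (vL, p, vL*)
    new = np.tensordot(tmp, B.conj(), axes=([1, 2], [1, 2]))    # (vL, vL*)
    return new.reshape(-1)


TM = LinearOperator((chi * chi, chi * chi), matvec=matvec, dtype=B.dtype)
vals, vecs = eigs(TM, k=1, which='LM')
RP = vecs[:, 0].reshape(chi, chi)
RP = RP / RP[0, 0]            # fix the arbitrary normalization and phase
```
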

    Parameters
    ----------
    bra : MPS
        The MPS which is to be (complex) conjugated
    ket : MPS
        The MPS which is not (complex) conjugated.
    chinfo : :class:`~tenpy.linalg.np_conserved.ChargeInfo`
        The nature of the charge.
    shift_bra : int
        We start the `N` of the bra at site `shift_bra`.
    shift_ket : int | None
        We start the `M` of the ket at site `shift_ket`. ``None`` defaults to `shift_bra`.
    transpose : bool
        Whether `self.matvec` acts on `RP` (``False``) or `LP` (``True``).
    charge_sector : None | charges | ``0``
        Selects the charge sector of `RP` (for `transpose`=False) which is used for the `matvec`
        with a dense ndarray.
        ``None`` stands for *all* sectors, ``0`` stands for the zero-charge sector.
        Defaults to ``0``, i.e. assumes the dominant eigenvector is in charge sector 0.

    Attributes
    ----------
    L : int
        Number of physical sites involved in the transfer matrix, i.e. the least common multiple
        of `bra.L` and `ket.L`.
    shape : (int, int)
        The dimensions for the selected charge sector.
    transpose : bool
        Whether `self.matvec` acts on `RP` (``False``) or `LP` (``True``).
    qtotal : charges
        Total charge of the transfer matrix (which is gauged away in matvec).
    matvec_count : int
        The number of `matvec` operations performed.
    _bra_N : list of npc.Array
        Complex conjugated matrices of the bra, transposed for fast `matvec`.
    _ket_M : list of npc.Array
        The matrices of the ket, transposed for fast `matvec`.
    _pipe : :class:`~tenpy.linalg.charges.LegPipe`
        Pipe corresponding to ``'(vL.vL*)'`` for ``transpose=False``
        or to ``'(vR.vR*)'`` for ``transpose=True``.
    _charge_sector : None | charges
        The charge sector of `RP` (`LP` for `transpose`) which is used for the :meth:`matvec`
        with a dense ndarray. ``None`` stands for *all* sectors.
    _mask : bool ndarray
        Selects the indices of the pipe which are used in `matvec`, mapping from

    .. todo ::
        One could create a separate `LinearOperator` class working for both numpy and npc arrays,
        checking the type of `vec`. It should implement `matvec` exactly as here.
        Problem: at the time of writing, the npc.Lanczos.LinalgOperator also takes
        rank-n Arrays as inputs for matvec! How to infer `shape` and the necessary pipe?
        Still, it might also be useful in algorithms.exact_diag.
    """

    def __init__(self, bra, ket, shift_bra=0, shift_ket=None, transpose=False, charge_sector=0):
        L = self.L = lcm(bra.L, ket.L)
        if shift_ket is None:
            shift_ket = shift_bra
        self.shift_bra = shift_bra
        self.shift_ket = shift_ket
        self.transpose = transpose
        if not transpose:
            M = self._ket_M = [
                ket.get_B(i, form=None).itranspose(['vL', 'p', 'vR'])
                for i in range(shift_ket, shift_ket + L)
            ]
            N = self._bra_N = [
                bra.get_B(i, form=None).conj().itranspose(['p*', 'vR*', 'vL*'])
                for i in range(shift_bra, shift_bra + L)
            ]
        else:
            M = self._ket_M = [
                ket.get_B(i, form=None).itranspose(['vR', 'p', 'vL'])
                for i in range(shift_ket, shift_ket + L)
            ]
            N = self._bra_N = [
                bra.get_B(i, form=None).conj().itranspose(['p*', 'vL*', 'vR*'])
                for i in range(shift_bra, shift_bra + L)
            ]
        self.chinfo = bra.chinfo
        if ket.chinfo != bra.chinfo:
            raise ValueError("incompatible charges")
        self.qtotal = self.chinfo.make_valid(np.sum([B.qtotal for B in M + N], axis=0))
        self._pipe = npc.LegPipe([M[0].get_leg('vL'), N[0].get_leg('vL*')], qconj=+1)
        if transpose:
            self._pipe = self._pipe.conj()
        self.shape = (self._pipe.ind_len, self._pipe.ind_len)  # overwritten by the charge_sector setter
        self.dtype = np.promote_types(bra.dtype, ket.dtype)
        self.matvec_count = 0
        self._charge_sector = None
        self._mask = None
        self.charge_sector = charge_sector  # uses the setter

    @property
    def charge_sector(self):
        """Charge sector of `RP` (`LP` for `transpose`) which is used for `matvec` with ndarray."""
        return self._charge_sector

    @charge_sector.setter
    def charge_sector(self, value):
        if type(value) == int and value == 0:
            value = self.chinfo.make_valid()  # zero charges
        elif value is not None:
            value = self.chinfo.make_valid(value)
        self._charge_sector = value
        if value is not None:
            self._mask = np.all(self._pipe.to_qflat() == value[np.newaxis, :], axis=1)
            self.shape = tuple([np.sum(self._mask)] * 2)
        else:
            chi2 = self._pipe.ind_len
            self.shape = (chi2, chi2)
            self._mask = np.ones([chi2], dtype=bool)

    def matvec(self, vec):
        """Apply the transfer matrix to `vec`.

        Parameters
        ----------
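        vec : ndarray | :class:`~tenpy.linalg.np_conserved.Array`
            The vector to which the transfer matrix is applied: either a 1D ndarray with the
            entries of the selected charge sector, or an npc Array with the legs of `RP`
            (`LP` for `transpose`), possibly combined into a pipe.
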
        Returns
        -------
        TM_vec : ndarray | :class:`~tenpy.linalg.np_conserved.Array`
            The transfer matrix applied to `vec`, in the same type/form as `vec` was given.
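
        Notes
        -----
        For orientation, a minimal dense numpy sketch of what `matvec` computes for
        ``transpose=False`` (not part of this module; the helper names are hypothetical).
        Each `B` has the leg order ``'vL', 'p', 'vR'``, and the sites are contracted from right
        to left into `RP`. For right-canonical `B` and ``bra == ket``, the identity on the
        right bond should be mapped to the identity:

        ```python
        import numpy as np


        def transfer_matrix_matvec(RP, ket_Bs, bra_Bs):
            # dense analogue of `_matvec_npc` with transpose=False:
            # RP has legs [vL, vL*] on the bond right of the last site
            for M, N in zip(reversed(ket_Bs), reversed(bra_Bs)):
                tmp = np.tensordot(M, RP, axes=([2], [0]))  # [vL, p, vL*]
                RP = np.tensordot(tmp, N.conj(), axes=([1, 2], [1, 2]))  # [vL, vL*]
            return RP


        def random_right_canonical(chi_left, d, chi_right, rng):
            # B with sum_{p, b} B[a, p, b] conj(B[a', p, b]) = delta_{a, a'}
            shape = (d * chi_right, chi_left)
            X = rng.normal(size=shape) + 1j * rng.normal(size=shape)
            Q, _ = np.linalg.qr(X)  # Q has orthonormal columns
            return Q.T.reshape(chi_left, d, chi_right)


        rng = np.random.default_rng(0)
        # two-site unit cell with bond dimensions 3 -> 4 -> 3
        Bs = [random_right_canonical(3, 2, 4, rng), random_right_canonical(4, 2, 3, rng)]
        result = transfer_matrix_matvec(np.eye(3), Bs, Bs)
        ```

        Iterating over the sites in forward order instead would try to contract the
        4-dimensional `vR` leg of the first site with the 3-dimensional `RP`.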
        """
        self.matvec_count += 1
        # handle both npc Arrays and ndarray....
        if isinstance(vec, npc.Array):
            return self._matvec_npc(vec)
        return self._matvec_flat(vec)

    def _matvec_flat(self, vec):
        """matvec operation for a numpy ndarray corresponding to the selected charge sector."""
        # convert into npc Array
        npc_vec = self._flat_to_npc(vec)
        # apply the transfer matrix
        npc_vec = self._matvec_npc(npc_vec)
        # convert back into numpy ndarray.
        return self._npc_to_flat(npc_vec)

    def _matvec_npc(self, vec):
        """Given `vec` as an npc.Array, apply the transfer matrix.

        The bra/ket legs of `vec` may be in a pipe, but don't have to be.
        We return it the same way as we got it (with the same legs and charges)."""
        pipe = None
        if vec.rank == 1:
            vec = vec.split_legs(0)
            pipe = self._pipe
        qtotal = vec.qtotal
        legs = vec.legs
        # the actual work
        if not self.transpose:
            # contract the sites from right to left into `RP`
            for N, M in zip(reversed(self._bra_N), reversed(self._ket_M)):
                vec = npc.tensordot(M, vec, axes=['vR', 'vL'])
                vec = npc.tensordot(vec, N, axes=[['p', 'vL*'], ['p*', 'vR*']])  # vL, vL*
        else:
            # contract the sites from left to right into `LP`
            for N, M in zip(self._bra_N, self._ket_M):
                vec = npc.tensordot(M, vec, axes=['vL', 'vR'])
                vec = npc.tensordot(vec, N, axes=[['p', 'vR*'], ['p*', 'vL*']])  # vR, vR*
        if np.any(self.qtotal != 0):
            # Hack: replace leg charges and qtotal -> effectively gauge `self.qtotal` away.
            vec.qtotal = qtotal
            vec.legs = legs
            vec.test_sanity()  # Should be fine, but who knows...
        if pipe is not None:
            vec = vec.combine_legs([0, 1], pipes=pipe)
        return vec

    def _flat_to_npc(self, vec):
        """Convert flat vector of selected charge sector into npc Array.

        Parameters
        ----------
        vec : 1D ndarray
            Numpy vector to be converted. Should have the entries according to self.charge_sector.

        Returns
        -------
        npc_vec : :class:`~tenpy.linalg.np_conserved.Array`
            Same as `vec`, but converted into an npc Array.
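
        Notes
        -----
        A dense numpy sketch of this embedding and of its inverse in :meth:`_npc_to_flat`,
        leaving out the conversion to npc. Here `mask` plays the role of ``self._mask``; the
        helper names are hypothetical:

        ```python
        import numpy as np


        def flat_to_full(vec, mask):
            # embed the entries of the selected sector into a zero vector of full length;
            # the dtype follows `vec`, so complex entries are not truncated
            vec = np.asarray(vec)
            full = np.zeros(mask.shape[0], dtype=vec.dtype)
            full[mask] = vec
            return full


        def full_to_flat(full, mask):
            # inverse operation: keep only the entries of the selected sector
            return full[mask]


        mask = np.array([True, False, True, False])
        full = flat_to_full([1.0 + 2.0j, 3.0], mask)
        ```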
        """
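        vec = np.asarray(vec)
        # keep complex entries: don't default to a real float array
        full_vec = np.zeros(self._pipe.ind_len, dtype=np.promote_types(self.dtype, vec.dtype))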
        full_vec[self._mask] = vec
        return npc.Array.from_ndarray(full_vec, [self._pipe])

    def _npc_to_flat(self, npc_vec):
        """Convert npc Array with qtotal = self.charge_sector into ndarray.

        Parameters
        ----------
        npc_vec : :class:`~tenpy.linalg.np_conserved.Array`
            Npc Array to be converted. Should have the entries according to self.charge_sector.

        Returns
        -------
        vec : 1D ndarray
            Same as `npc_vec`, but converted into a flat array.
        """
        if self._charge_sector is not None and np.any(npc_vec.qtotal != self._charge_sector):
            raise ValueError("npc_vec.qtotal and charge sector don't match!")
        return npc_vec.to_ndarray()[self._mask]

    def eigenvectors(self, num_ev=1, max_num_ev=None, max_tol=1.e-12, which='LM', **kwargs):
        """Find (dominant) eigenvector(s) of self using scipy.sparse.

        If no charge_sector was selected, we look in *all* charge sectors.

        Parameters
        ----------
        num_ev : int
            Number of eigenvalues/vectors to look for.
        max_num_ev : int
            :func:`scipy.sparse.linalg.eigs` sometimes raises an `ArpackNoConvergence` error for
            small `num_ev`, which might be avoided by increasing `num_ev`. As a work-around,
            we retry in the case of this error, just with larger `num_ev`, up to `max_num_ev`.
            ``None`` defaults to ``num_ev + 2``.
        max_tol : float
            After the first `ArpackNoConvergence` error, we increase the `tol` argument to (at
            least) this value.
        which : str
            Which eigenvalues to look for, see :func:`scipy.sparse.linalg.eigs`.
        **kwargs :
            Further keyword arguments are given to :func:`~tenpy.tools.math.speigs`.

        Returns
        -------
        eta : 1D ndarray
            The eigenvalues, sorted according to `which`.
        w : list of :class:`~tenpy.linalg.np_conserved.Array`
            The eigenvectors corresponding to `eta`, as npc.Array with LegPipe.
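
        Notes
        -----
        This method works with the sparse :func:`~tenpy.tools.math.speigs`, one charge sector at
        a time. As a dense cross-check for small bond dimensions, a minimal numpy sketch
        (hypothetical helper name, single-site unit cell, ``bra == ket``) builds the full
        transfer matrix acting on ``'(vL.vL*)'`` and sorts its spectrum like ``which='LM'``:

        ```python
        import numpy as np


        def dense_transfer_matrix(B):
            # single-site transfer matrix for bra == ket, legs of B ordered [vL, p, vR];
            # maps RP flattened row-major as '(vL.vL*)' to T RP flattened the same way
            chi = B.shape[0]
            T = np.einsum('apb,cpd->acbd', B, B.conj())
            return T.reshape(chi * chi, chi * chi)


        rng = np.random.default_rng(1)
        chi, d = 3, 2
        X = rng.normal(size=(d * chi, chi)) + 1j * rng.normal(size=(d * chi, chi))
        B = np.linalg.qr(X)[0].T.reshape(chi, d, chi)  # right-canonical B

        eta, w = np.linalg.eig(dense_transfer_matrix(B))
        perm = np.argsort(-np.abs(eta))  # 'LM': largest magnitude first
        eta, w = eta[perm], w[:, perm]
        ```

        For right-canonical `B`, the largest-magnitude eigenvalue should be 1; for generic
        random `B`, its eigenvector reshaped to ``(chi, chi)`` is proportional to the identity.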
        """
        if max_num_ev is None:
            max_num_ev = num_ev + 2
        if self.charge_sector is None:
            # Try for all charge sectors
            eta = []
            A = []
            for chsect in self._pipe.charge_sectors():
                self.charge_sector = chsect
                eta_cs, A_cs = self.eigenvectors(num_ev, max_num_ev, max_tol, which, **kwargs)
                eta.extend(eta_cs)
                A.extend(A_cs)
            self.charge_sector = None
        else:
            # for given charge sector
            for k in range(num_ev, max_num_ev + 1):
                if k > num_ev:
                    warnings.warn("increased `num_ev` to " + str(k))
                try:
                    eta, A = speigs(self, k=k, which=which, **kwargs)
                    A = np.real_if_close(A)
                    A = [self._flat_to_npc(A[:, j]) for j in range(A.shape[1])]
                    break
                except scipy.sparse.linalg.ArpackNoConvergence:
                    if k == max_num_ev:
                        raise
                    # just retry with larger k and 'tol'
                    kwargs['tol'] = max(max_tol, kwargs.get('tol', 0))
        # sort
        perm = argsort(eta, which)
        return np.array(eta)[perm], [A[j] for j in perm]
