r"""This module contains a base class for a Matrix Product State (MPS).

An MPS looks roughly like this::

    |   -- B[0] -- B[1] -- B[2] -- ...
    |       |       |      |

We use the following label convention for the `B` (where arrows indicate `qconj`)::

    |  vL ->- B ->- vR
    |         |
    |         ^
    |         p

We store one 3-leg tensor `_B[i]` with labels ``'vL', 'vR', 'p'`` for each of the `L` sites
``0 <= i < L``.
Additionally, we store ``L+1`` singular value arrays `_S[ib]` on each bond ``0 <= ib <= L``,
independent of the boundary conditions.
``_S[ib]`` gives the singular values on the bond ``ib-1, ib``.
However, be aware that e.g. :attr:`~tenpy.networks.mps.MPS.chi` returns only the dimensions of the
:attr:`~tenpy.networks.mps.MPS.nontrivial_bonds`, which depend on the boundary conditions.
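
As an illustration of this storage layout, here is a minimal plain-NumPy sketch. It is not part
of this module: the axis order ``(vL, p, vR)``, the bond dimensions and the placeholder values are
assumptions chosen only to show how the ``L`` tensors and the ``L+1`` singular value arrays fit
together (the class itself uses labelled :class:`~tenpy.linalg.np_conserved.Array` legs).

```python
import numpy as np

# Illustrative finite MPS with L = 3 sites of local dimension d = 2.
# Axes are ordered (vL, p, vR) here; the tensor entries are placeholders.
L, d = 3, 2
chis = [1, 2, 2, 1]  # bond dimensions on bonds ib = 0, ..., L
Bs = [np.zeros((chis[i], d, chis[i + 1])) for i in range(L)]
Ss = [np.ones(chi) / np.sqrt(chi) for chi in chis]  # L+1 normalized arrays

# _S[ib] sits on bond (ib-1, ib): left of B[ib] and right of B[ib-1],
# so its length must match leg 'vL' of B[ib] and leg 'vR' of B[ib-1].
for i, B in enumerate(Bs):
    assert Ss[i].shape[0] == B.shape[0]      # leg 'vL' of B[i]
    assert Ss[i + 1].shape[0] == B.shape[2]  # leg 'vR' of B[i]
print(len(Bs), len(Ss))  # prints: 3 4
```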

The matrices and singular values always represent a normalized state
(i.e. ``np.linalg.norm(psi._S[ib]) == 1`` up to roundoff errors),
but we keep track of the norm in :attr:`~tenpy.networks.mps.MPS.norm`
(which is respected by :meth:`~tenpy.networks.mps.MPS.overlap`, ...).
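
The relation between the stored singular values and :attr:`~tenpy.networks.mps.MPS.norm` can be
seen on a single bipartition. In this sketch (plain NumPy, with an arbitrary random matrix standing
in for an unnormalized two-part wave function; it does not reproduce how the class computes
anything), the Schmidt values have a 2-norm equal to the norm of the state, so storing
``S / norm`` keeps ``_S`` normalized while the overall factor is kept separately.

```python
import numpy as np

rng = np.random.default_rng(0)
psi = rng.normal(size=(4, 4))              # unnormalized wave function psi[left, right]
S = np.linalg.svd(psi, compute_uv=False)   # Schmidt values of the cut
norm = np.linalg.norm(S)                   # equals sqrt(<psi|psi>)
S_stored = S / norm                        # normalized, as kept in _S
assert np.isclose(norm, np.linalg.norm(psi))  # Frobenius norm of psi
assert np.isclose(np.linalg.norm(S_stored), 1.)
```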

For efficient simulations, it is crucial that the MPS is in a 'canonical form'.
The different forms and boundary conditions are most easily described in Vidal's
:math:`\Gamma, \Lambda` notation [Vidal2004]_.

Valid MPS boundary conditions (not to be confused with `bc_coupling` of
:class:`tenpy.models.model.CouplingModel`) are the following:

==========  ===================================================================================
`bc`        description
==========  ===================================================================================
'finite'    Finite MPS, ``G0 s1 G1 ... s{L-1} G{L-1}``. This is achieved
            by using a trivial left and right bond ``s[0] = s[-1] = np.array([1.])``.
'segment'   Generalization of 'finite', describes an MPS embedded in left and right
            environments. The left environment is described by ``chi[0]`` *orthonormal* states
            which are weighted by the singular values ``s[0]``. Similarly, ``s[L]`` weights some
            right orthonormal states. You can think of the left and right states to be
            generated by additional MPS, such that the overall structure is something like
            ``... s L s L [s0 G0 s1 G1 ... s{L-1} G{L-1} s{L}] R s R s R ...``
            (where we save the part in the brackets ``[ ... ]`` ).
'infinite'  Infinite MPS (iMPS): we save an 'MPS unit cell' ``[s0 G0 s1 G1 ... s{L-1} G{L-1}]``
            which is repeated periodically, identifying all indices modulo ``self.L``.
            In particular, the last bond ``L`` is identified with the first bond ``0``.
            (The MPS unit cell can differ from a lattice unit cell.)
==========  ===================================================================================
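
Which bonds count as nontrivial for each `bc` follows directly from the table above. The function
below simply restates the logic of :attr:`~tenpy.networks.mps.MPS.nontrivial_bonds` as a free
function; its name and signature are illustrative and not part of the API.

```python
def nontrivial_bonds(bc, L):
    # Same slices as the MPS.nontrivial_bonds property, indexing bonds 0, ..., L.
    if bc == 'finite':
        return slice(1, L)       # outer bonds 0 and L are trivial
    elif bc == 'segment':
        return slice(0, L + 1)   # both outer bonds carry environment states
    elif bc == 'infinite':
        return slice(0, L)       # bond L is identified with bond 0
    raise ValueError('invalid boundary condition: ' + repr(bc))


L = 4
bonds = list(range(L + 1))
for bc in ('finite', 'segment', 'infinite'):
    print(bc, bonds[nontrivial_bonds(bc, L)])
# finite [1, 2, 3]
# segment [0, 1, 2, 3, 4]
# infinite [0, 1, 2, 3]
```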

An MPS can be in different 'canonical forms' (see [Vidal2004]_, [Schollwoeck2011]_).
To take care of the different canonical forms, algorithms should use functions like
:meth:`~tenpy.networks.mps.MPS.get_theta`, :meth:`~tenpy.networks.mps.MPS.get_B`
and :meth:`~tenpy.networks.mps.MPS.set_B` instead of accessing `_B` directly,
as they return the `B` in the desired form (which can be chosen as an argument).
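
To make this concrete, here is a plain-NumPy sketch; it is not the library's implementation, all
helper names are hypothetical, and tensors use the axis order ``(vL, p, vR)``. It splits a
normalized full state into 'B' form by successive SVDs from the right (similar in spirit to
:meth:`~tenpy.networks.mps.MPS.from_full`, but without truncation) and builds the ``n``-site
``theta`` as ``diag(S_L) B[i] ... B[i+n-1]``, which is what
:meth:`~tenpy.networks.mps.MPS.get_theta` returns for its default exponents when every tensor is
stored in 'B' form.

```python
import numpy as np


def right_canonical_mps(psi, L, d):
    # Split a normalized state with d**L entries into 'B' form tensors (vL, p, vR)
    # by successive SVDs from the right, without any truncation.
    Bs, Ss = [None] * L, [None] * (L + 1)
    Ss[0] = Ss[L] = np.ones(1)  # trivial outer bonds of a finite MPS
    rest = psi.reshape(-1, 1)   # (p0...p{L-1}, vR)
    for i in range(L - 1, 0, -1):
        chi_R = rest.shape[1]
        rest = rest.reshape(-1, d * chi_R)  # (p0...p{i-1}, p_i.vR)
        U, S, Vh = np.linalg.svd(rest, full_matrices=False)
        Bs[i] = Vh.reshape(-1, d, chi_R)    # right-orthonormal: B = Gamma -- s_R
        Ss[i] = S
        rest = U * S                        # push the singular values to the left
    Bs[0] = rest.reshape(1, d, -1)
    return Bs, Ss


def theta_from_B(Bs, Ss, i, n):
    # theta = diag(S_L) -- B[i] -- ... -- B[i+n-1], axes (vL, p_i, ..., p_{i+n-1}, vR)
    theta = Ss[i][:, None, None] * Bs[i]
    for k in range(1, n):
        theta = np.tensordot(theta, Bs[i + k], axes=(-1, 0))  # contract vR with vL
    return theta


rng = np.random.default_rng(1)
L, d = 3, 2
psi = rng.normal(size=d**L)
psi /= np.linalg.norm(psi)
Bs, Ss = right_canonical_mps(psi, L, d)
assert np.allclose(theta_from_B(Bs, Ss, 0, L).reshape(-1), psi)  # recovers the full state
theta = theta_from_B(Bs, Ss, 1, 2)
print(theta.shape)  # (2, 2, 2, 1)
assert np.isclose(np.linalg.norm(theta), 1.)  # norm of theta equals norm of S_L
```

Because the 'B' tensors to the right are right-orthonormal, only the singular values on the left
bond are needed to turn the product into a properly normalized local wave function.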

======== ========== =======================================================================
`form`   tuple      description
======== ========== =======================================================================
``'B'``  (0, 1)     right canonical: ``_B[i] = -- Gamma[i] -- s[i+1]--``
                    The default form, which algorithms assume.
``'C'``  (0.5, 0.5) symmetric form: ``_B[i] = -- s[i]**0.5 -- Gamma[i] -- s[i+1]**0.5--``
``'A'``  (1, 0)     left canonical: ``_B[i] = -- s[i] -- Gamma[i] --``.
                    For stability reasons, we recommend to *not* use this form.
``'G'``  (0, 0)     Save only ``_B[i] = -- Gamma[i] --``.
                    For stability reasons, we recommend to *not* use this form.
``None`` ``None``   General non-canonical form.
                    Valid form for initialization, but you need to call
                    :meth:`~tenpy.networks.mps.MPS.canonical_form` (or similar)
                    before using algorithms.
======== ========== =======================================================================
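
The tuples in the table give the rescaling needed to switch forms: since the stored tensor is
``s[i]**nuL -- Gamma[i] -- s[i+1]**nuR``, going from ``(nuL, nuR)`` to ``(nuL', nuR')`` means
multiplying the left leg by ``s[i]**(nuL'-nuL)`` and the right leg by ``s[i+1]**(nuR'-nuR)``.
The plain-NumPy sketch below uses a hypothetical helper, the axis order ``(vL, p, vR)`` and real
arrays (so no complex conjugation is needed), and assumes strictly positive singular values, so no
pseudo-inverse cutoff is required. It converts a two-site 'B' form MPS into 'A' form and checks the
right/left orthonormality that distinguishes the two.

```python
import numpy as np


def convert_form(B, S_left, S_right, old, new):
    # B = s_L**old[0] -- Gamma -- s_R**old[1]  ->  s_L**new[0] -- Gamma -- s_R**new[1]
    # Assumes strictly positive singular values (no pseudo-inverse cutoff).
    dL, dR = new[0] - old[0], new[1] - old[1]
    return (S_left ** dL)[:, None, None] * B * (S_right ** dR)[None, None, :]


rng = np.random.default_rng(2)
d = 3
psi = rng.normal(size=(d, d))
psi /= np.linalg.norm(psi)
U, S, Vh = np.linalg.svd(psi)          # psi = U diag(S) Vh, so Gamma[0] = U, Gamma[1] = Vh
s = [np.ones(1), S, np.ones(1)]        # s[0], s[1], s[2] for L = 2 (finite)
Bs = [(U * S).reshape(1, d, d), Vh.reshape(d, d, 1)]  # 'B' = (0, 1): Gamma[i] -- s[i+1]
As = [convert_form(Bs[i], s[i], s[i + 1], (0., 1.), (1., 0.)) for i in range(2)]  # 'A' = (1, 0)

# 'B' tensors are right-orthonormal, 'A' tensors are left-orthonormal.
for B in Bs:
    assert np.allclose(np.einsum('apb,cpb->ac', B, B), np.eye(B.shape[0]))
for A in As:
    assert np.allclose(np.einsum('apb,apc->bc', A, A), np.eye(A.shape[2]))
# Both forms describe the same state.
assert np.allclose(np.tensordot(As[0], As[1], axes=(2, 0)).reshape(d, d), psi)
```

Converting to 'A' divides by ``s[i+1]``, which is why that form (and 'G') becomes unstable as soon
as small singular values appear.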
"""
# Copyright 2018 TeNPy Developers

import numpy as np
import warnings
import scipy.sparse.linalg.eigen.arpack

from ..linalg import np_conserved as npc
from ..linalg import sparse
from ..tools.misc import to_iterable, argsort
from ..tools.math import lcm, speigs, entropy
from functools import reduce
from ..algorithms.truncation import TruncationError, svd_theta

__all__ = ['MPS', 'MPSEnvironment', 'TransferMatrix']


class MPS(object):
    r"""A Matrix Product State, finite (MPS) or infinite (iMPS).

    Parameters
    ----------
    sites : list of :class:`~tenpy.networks.site.Site`
        Defines the local Hilbert space for each site.
    Bs : list of :class:`~tenpy.linalg.np_conserved.Array`
        The 'matrices' of the MPS. Labels are ``vL, vR, p`` (in any order).
    SVs : list of 1D array
        The singular values on *each* bond. Should always have length `L+1`.
        Entries out of :attr:`nontrivial_bonds` are ignored.
    bc : ``'finite' | 'segment' | 'infinite'``
        Boundary conditions as described in the table of the module doc-string.
    form : (list of) {``'B' | 'A' | 'C' | 'G' | None`` | tuple(float, float)}
        The form of the stored 'matrices', see table in module doc-string.
        A single choice holds for all of the entries.
    norm : float
        The norm of the state, stored in :attr:`norm`.

    Attributes
    ----------
    L
    chi
    finite
    nontrivial_bonds
    sites : list of :class:`~tenpy.networks.site.Site`
        Defines the local Hilbert space for each site.
    bc : {'finite', 'segment', 'infinite'}
        Boundary conditions as described in above table.
    form : list of {``None`` | tuple(float, float)}
        Describes the canonical form on each site.
        ``None`` means non-canonical form.
        For ``form = (nuL, nuR)``, the stored ``_B[i]`` are
        ``s**form[0] -- Gamma -- s**form[1]`` (in Vidal's notation).
    chinfo : :class:`~tenpy.linalg.np_conserved.ChargeInfo`
        The nature of the charge.
    dtype : type
        The data type of the ``_B``.
    norm : float
        The norm of the state, i.e. ``sqrt(<psi|psi>)``.
    _B : list of :class:`npc.Array`
        The 'matrices' of the MPS. Labels are ``vL, vR, p`` (in any order).
        We recommend using :meth:`get_B` and :meth:`set_B`, which will take care of the different
        canonical forms.
    _S : None | list of 1D arrays
        The singular values on each virtual bond, length ``L+1``.
        May be ``None`` if the MPS is not in canonical form.
        Otherwise, ``_S[i]`` is to the left of ``_B[i]``.
        We recommend using :meth:`get_SL`, :meth:`get_SR`, :meth:`set_SL`, :meth:`set_SR`, which
        take proper care of the boundary conditions.
    _valid_forms : dict
        Mapping for canonical forms to a tuple ``(nuL, nuR)`` indicating that
        ``self._B[i] = s[i]**nuL -- Gamma[i] -- s[i+1]**nuR`` is saved.
    _valid_bc : tuple of str
        Valid boundary conditions.
    """

    # Canonical form conventions: the saved B = s**nu[0]--Gamma--s**nu[1].
    # For the canonical forms, ``nu[0] + nu[1] = 1``
    _valid_forms = {
        'A': (1., 0.),
        'C': (0.5, 0.5),
        'B': (0., 1.),
        'G': (0., 0.),  # like Vidal's `Gamma`.
        None: None,  # means 'not in any canonical form'
    }

    # valid boundary conditions. Don't overwrite this!
    _valid_bc = ('finite', 'segment', 'infinite')
    _p_label = ['p']

    def __init__(self, sites, Bs, SVs, bc='finite', form='B', norm=1.):
        self.sites = list(sites)
        self.chinfo = self.sites[0].leg.chinfo
        self.dtype = dtype = np.find_common_type([B.dtype for B in Bs], [])
        self.form = self._parse_form(form)
        self.bc = bc  # one of ``'finite', 'segment', 'infinite'``.
        self.norm = norm

        # make copies of Bs and SVs
        self._B = [B.astype(dtype, copy=True) for B in Bs]
        self._S = [None] * (self.L + 1)
        for i in range(self.L + 1)[self.nontrivial_bonds]:
            self._S[i] = np.array(SVs[i], dtype=np.float64)
        if self.bc == 'infinite':
            self._S[-1] = self._S[0]
        elif self.bc == 'finite':
            self._S[0] = self._S[-1] = np.ones([1])
        self.test_sanity()

    def test_sanity(self):
        """Sanity check. Raises Errors if something is wrong."""
        if self.bc not in self._valid_bc:
            raise ValueError("invalid boundary condition: " + repr(self.bc))
        if len(self._B) != self.L:
            raise ValueError("wrong len of self._B")
        if len(self._S) != self.L + 1:
            raise ValueError("wrong len of self._S")
        for i, B in enumerate(self._B):
            if not set(['vL', 'vR', 'p']) <= set(B.get_leg_labels()):
                raise ValueError("B has wrong labels " + repr(B.get_leg_labels()))
            B.test_sanity()  # recursive...
            if self._S[i].shape[-1] != B.get_leg('vL').ind_len or \
                    self._S[i+1].shape[0] != B.get_leg('vR').ind_len:
                raise ValueError("shape of B incompatible with len of singular values")
            if not self.finite or i + 1 < self.L:
                B2 = self._B[(i + 1) % self.L]
                B.get_leg('vR').test_contractible(B2.get_leg('vL'))
        if self.bc == 'finite':
            if len(self._S[0]) != 1 or len(self._S[-1]) != 1:
                raise ValueError("non-trivial outer bonds for finite MPS")
        elif self.bc == 'infinite':
            if np.any(self._S[self.L] != self._S[0]):
                raise ValueError("iMPS with S[0] != S[L]")
        assert len(self.form) == self.L
        for f in self.form:
            if f is not None:
                assert isinstance(f, tuple)
                assert len(f) == 2

    @classmethod
    def from_product_state(cls,
                           sites,
                           p_state,
                           bc='finite',
                           dtype=np.float64,
                           form='B',
                           chargeL=None):
        """Construct a matrix product state from a given product state.

        Parameters
        ----------
        sites : list of :class:`~tenpy.networks.site.Site`
            The sites defining the local Hilbert space.
        p_state : iterable of {int | str | 1D array}
            Defines the product state.
            If ``p_state[i]`` is int, then site ``i`` is in state ``p_state[i]``.
            If ``p_state[i]`` is str, then site ``i`` is in state
            ``sites[i].state_index(p_state[i])``.
            If ``p_state[i]`` is an array, then the wave function of site ``i`` is ``p_state[i]``.
            Note that the meaning of an int can change depending on the charges;
            see the warning in the doc-string of :class:`~tenpy.networks.site.Site`.
        bc : {'infinite', 'finite', 'segment'}
            MPS boundary conditions. See docstring of :class:`MPS`.
        dtype : type or string
            The data type of the array entries.
        form : (list of) {``'B' | 'A' | 'C' | 'G' | None`` | tuple(float, float)}
            Defines the canonical form. See module doc-string.
            A single choice holds for all of the entries.
        chargeL : charges
            Charges of the left leg at bond 0, which are purely conventional.

        Returns
        -------
        product_mps : :class:`MPS`
            An MPS representing the specified product state.

        Examples
        --------
        To get a Neel state:

        >>> M = SpinChain(model_params)
        >>> p_state = ["up", "down"]*(L//2)  # repeats entries L/2 times
        >>> psi = MPS.from_product_state(M.lat.mps_sites(), p_state, bc=M.lat.bc_MPS)

        For Spin S=1/2, you could get a state with all sites pointing in negative x-direction with:

        >>> neg_x_state = np.array([1., -1.])
        >>> p_state = [neg_x_state/np.linalg.norm(neg_x_state)]*L  # other parameters as above

        """
        sites = list(sites)
        L = len(sites)
        p_state = list(p_state)
        if len(p_state) != L:
            raise ValueError("Length of p_state does not match number of sites.")
        ci = sites[0].leg.chinfo
        Bs = []
        chargeL = ci.make_valid(chargeL)  # sets to zero if `None`
        legL = npc.LegCharge.from_qflat(ci, [chargeL])
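        for p_st, site in zip(p_state, sites):
            if isinstance(p_st, str):
                p_st = site.state_index(p_st)  # translate a state label into its index
            try:
                iter(p_st)  # raises TypeError for an int index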
                if len(p_st) != site.dim:
                    raise ValueError("p_state incompatible with local dim:" + repr(p_st))
                B = np.array(p_st, dtype).reshape((site.dim, 1, 1))
            except TypeError:
                B = np.zeros((site.dim, 1, 1), dtype)
                B[p_st, 0, 0] = 1.0
            Bs.append(B)
        SVs = [[1.]] * (L + 1)
        return cls.from_Bflat(sites, Bs, SVs, bc=bc, dtype=dtype, form=form, legL=legL)

    @classmethod
    def from_Bflat(cls,
                   sites,
                   Bflat,
                   SVs=None,
                   bc='finite',
                   dtype=np.float64,
                   form='B',
                   legL=None):
        """Construct a matrix product state from a set of numpy arrays and singular values.

        Parameters
        ----------
        sites : list of :class:`~tenpy.networks.site.Site`
            The sites defining the local Hilbert space.
        Bflat : iterable of numpy ndarrays
            The matrix defining the MPS on each site, with legs ``'p', 'vL', 'vR'``
            (physical, virtual left/right).
        SVs : list of 1D array | ``None``
            The singular values on *each* bond. Should always have length `L+1`.
            By default (``None``), set all singular values to the same value.
            Entries out of :attr:`nontrivial_bonds` are ignored.
        bc : {'infinite', 'finite', 'segment'}
            MPS boundary conditions. See docstring of :class:`MPS`.
        dtype : type or string
            The data type of the array entries.
        form : (list of) {``'B' | 'A' | 'C' | 'G' | None`` | tuple(float, float)}
            Defines the canonical form of `Bflat`. See module doc-string.
            A single choice holds for all of the entries.
        legL : LegCharge | ``None``
            Leg charges at bond 0, which are purely conventional.
            If ``None``, use trivial charges.

        Returns
        -------
        mps : :class:`MPS`
            An MPS with the matrices `Bflat` converted to npc arrays.
        """
        sites = list(sites)
        L = len(sites)
        Bflat = list(Bflat)
        if len(Bflat) != L:
            raise ValueError("Length of Bflat does not match number of sites.")
        ci = sites[0].leg.chinfo
        if legL is None:
            legL = npc.LegCharge.from_qflat(ci, [ci.make_valid(None)] * Bflat[0].shape[1])
        if SVs is None:
            SVs = [np.ones(B.shape[1]) / np.sqrt(B.shape[1]) for B in Bflat]
            SVs.append(np.ones(Bflat[-1].shape[2]) / np.sqrt(Bflat[-1].shape[2]))
        Bs = []
        for i, site in enumerate(sites):
            B = np.array(Bflat[i], dtype)
            # calculate the LegCharge of the right leg
            legs = [site.leg, legL, None]  # other legs are known
            legs = npc.detect_legcharge(B, ci, legs, None, qconj=-1)
            B = npc.Array.from_ndarray(B, legs, dtype)
            B.iset_leg_labels(['p', 'vL', 'vR'])
            Bs.append(B)
            legL = legs[-1].conj()  # prepare for next `i`
        if bc == 'infinite':
            # for an iMPS, the last leg has to match the first one.
            # so we need to gauge `qtotal` of the last `B` such that the right leg matches.
            chdiff = Bs[-1].get_leg('vR').charges[0] - Bs[0].get_leg('vL').charges[0]
            Bs[-1] = Bs[-1].gauge_total_charge('vR', ci.make_valid(chdiff))
        return cls(sites, Bs, SVs, form=form, bc=bc)

    @classmethod
    def from_full(cls, sites, psi, form='B', cutoff=1.e-16, normalize=True):
        """Construct an MPS from a single tensor `psi` with one leg per physical site.

        Performs a sequence of SVDs of psi to split off the `B` matrices and obtain the singular
        values; the result will be in canonical form.
        Obviously, this is only well-defined for `finite` boundary conditions.

        Parameters
        ----------
        sites : list of :class:`~tenpy.networks.site.Site`
            The sites defining the local Hilbert space.
        psi : :class:`~tenpy.linalg.np_conserved.Array`
            The full wave function to be represented as an MPS.
            Should have labels ``'p0', 'p1', ..., 'p{L-1}'``.
        form : ``'B' | 'A' | 'C' | 'G'``
            The canonical form of the resulting MPS, see module doc-string.
        cutoff : float
            Cutoff of singular values used in the SVDs.
        normalize : bool
            Whether the resulting MPS should have 'norm' 1.

        Returns
        -------
        psi_mps : :class:`MPS`
            MPS representation of `psi`, in canonical form and possibly normalized.
        """
        if form not in ['B', 'A', 'C', 'G']:
            raise ValueError("Invalid form: " + repr(form))
        # perform SVDs to bring it into 'B' form, afterwards change the form.
        L = len(sites)
        assert (L >= 2)
        B_list = [None] * L
        S_list = [1] * (L + 1)
        norm = 1. if normalize else npc.norm(psi)
        labels = ['p' + str(i) for i in range(L)]
        psi.itranspose(labels)
        # combine legs from left
        psi = psi.add_trivial_leg(0, label='vL', qconj=+1)
        for i in range(0, L - 1):
            psi = psi.combine_legs([0, 1])  # combines the legs until `i`
        psi = psi.add_trivial_leg(2, label='vR', qconj=-1)
        # now psi has only three legs: ``'(((vL.p0).p1)...p{L-2})', 'p{L-1}', 'vR'``
        for i in range(L - 1, 0, -1):
            # split off B[i]
            psi = psi.combine_legs([labels[i], 'vR'])
            psi, S, B = npc.svd(psi, inner_labels=['vR', 'vL'], cutoff=cutoff)
            S /= np.linalg.norm(S)  # normalize
            psi.iscale_axis(S, 1)
            B_list[i] = B.split_legs(1).replace_label(labels[i], 'p')
            S_list[i] = S
            psi = psi.split_legs(0)
        psi = psi.combine_legs([labels[0], 'vR'])
        psi, S, B = npc.svd(
            psi, qtotal_LR=[None, psi.qtotal], inner_labels=['vR', 'vL'], cutoff=cutoff)
        assert (psi.shape == (1, 1))
        S_list[0] = np.ones([1], dtype=np.float64)
        B_list[0] = B.split_legs(1).replace_label(labels[0], 'p')
        res = cls(sites, B_list, S_list, bc='finite', form='B', norm=norm)
        if form != 'B':
            res.convert_form(form)
        return res

    @classmethod
    def from_singlets(cls,
                      site,
                      L,
                      pairs,
                      up='up',
                      down='down',
                      lonely=[],
                      lonely_state=0,
                      bc='finite'):
        """Create an MPS of entangled singlets.

        Parameters
        ----------
        site : :class:`~tenpy.networks.site.Site`
            The `site` defining the local Hilbert space, taken uniformly for all sites.
        L : int
            The number of sites.
        pairs : list of (int, int)
            Pairs of sites to be entangled; the returned MPS will have a singlet
            for each pair in `pairs`.
        up, down : int | str
            A singlet is defined as ``(|up down> - |down up>)/2**0.5``,
            ``up`` and ``down`` give state indices or labels defined on the corresponding site.
        lonely : list of int
            Sites which are not included into a singlet pair.
        lonely_state : int | str
            The state for the lonely sites.
        bc : {'infinite', 'finite', 'segment'}
            MPS boundary conditions. See docstring of :class:`MPS`.

        Returns
        -------
        singlet_mps : :class:`MPS`
            An MPS representing singlets on the specified bonds.
        """
        # sort each pair s.t. i < j
        pairs = [((i, j) if i < j else (j, i)) for (i, j) in pairs]
        # sort by smaller site of the pair
        pairs.sort(key=lambda x: x[0])
        pairs.append((L, L))
        lonely = sorted(lonely) + [L]
        # generate building block tensors
        up = site.state_index(up)
        down = site.state_index(down)
        mask = np.zeros(site.dim, dtype=np.bool_)
        mask[up] = mask[down] = True
        Open = npc.diag(1., site.leg)[:, mask]
        Close = np.zeros([site.dim, site.dim], dtype=np.float64)
        Close[up, down] = -1.
        Close[down, up] = 1.
        Close = npc.Array.from_ndarray(Close, [site.leg, site.leg])  # no conj() !
        Close = Close[mask, :]
        Id = npc.eye_like(Close, 0)
        Lonely = np.zeros(site.dim, dtype=np.float64)
        Lonely[lonely_state] = 1
        Lonely = npc.Array.from_ndarray(Lonely, [site.leg])
        Bs = []
        Ss = [np.ones(1)]
        forms = []
        open_singlets = []  # the k-th open singlet should be closed at site open_singlets[k]
        Ts = []  # the tensors on the current site
        labels_L = []
        for i in range(L):
            labels_R = labels_L[:]
            next_Ts = Ts[:]
            if i == pairs[0][0]:  # open a new singlet
                j = pairs[0][1]
                lbl = 's{0:d}-{1:d}'.format(i, j)
                pairs.pop(0)
                open_singlets.append(j)
                next_Ts.append(Id.copy().iset_leg_labels([lbl + 'L', lbl]))
                Open.iset_leg_labels(['p', lbl])
                Ts.append(Open.copy(deep=False))
                labels_R.append(lbl)
                forms.append('A')
            elif i == lonely[0]:  # just a lonely state
                Ts.append(Lonely)
                lonely.pop(0)
                forms.append('B')
            else:  # close a singlet
                k = open_singlets.index(i)
                Close.iset_leg_labels([labels_L[k] + 'L', 'p'])
                Ts[k] = Close
                next_Ts.pop(k)
                open_singlets.pop(k)
                labels_R.pop(k)
                forms.append('B')
            # generate `B` from `Ts`
            B = reduce(npc.outer, Ts)
            labels_L = [lbl_ + 'L' for lbl_ in labels_L]
            if len(labels_L) > 0 and len(labels_R) > 0:
                B = B.combine_legs([labels_L, labels_R], new_axes=[0, 2], qconj=[+1, -1])
                B.iset_leg_labels(['vL', 'p', 'vR'])
            elif len(labels_L) == 0 and len(labels_R) == 0:
                B = B.add_trivial_leg(0, label='vL', qconj=+1)
                B = B.add_trivial_leg(2, label='vR', qconj=-1)
                B.iset_leg_labels(['vL', 'p', 'vR'])
            elif len(labels_L) == 0:
                B = B.combine_legs([labels_R], new_axes=[1], qconj=[-1])
                B.iset_leg_labels(['p', 'vR'])
                B = B.add_trivial_leg(0, label='vL', qconj=+1)
            else:  # len(labels_R) == 0
                B = B.combine_legs([labels_L], new_axes=[0], qconj=[+1])
                B.iset_leg_labels(['vL', 'p'])
                B = B.add_trivial_leg(2, label='vR', qconj=-1)
            Bs.append(B)
            N = 2**len(labels_R)
            Ss.append(np.ones(N) / (N**0.5))
            Ts = next_Ts
            labels_L = labels_R
        return cls([site] * L, Bs, Ss, bc=bc, form=forms)

    def copy(self):
        """Returns a copy of `self`.

        The copy still shares the sites, chinfo, and LegCharges of the _B,
        but the values of B and S are deeply copied.
        """
        # __init__ makes deep copies of B, S
        return MPS(self.sites, self._B, self._S, self.bc, self.form, self.norm)

    @property
    def L(self):
        """Number of physical sites. For an iMPS the len of the MPS unit cell."""
        return len(self.sites)

    @property
    def dim(self):
        """List of local physical dimensions."""
        return [site.dim for site in self.sites]

    @property
    def finite(self):
        "Distinguish MPS (``True; bc='finite', 'segment'`` ) vs. iMPS (``False; bc='infinite'``)"
        assert (self.bc in self._valid_bc)
        return self.bc != 'infinite'

    @property
    def chi(self):
        """Dimensions of the (nontrivial) virtual bonds."""
        # s.shape[0] == len(s) for 1D numpy array, but works also for a 2D npc Array.
        return [s.shape[0] for s in self._S[self.nontrivial_bonds]]

    @property
    def nontrivial_bonds(self):
        """Slice of the non-trivial bond indices, depending on ``self.bc``."""
        if self.bc == 'finite':
            return slice(1, self.L)
        elif self.bc == 'segment':
            return slice(0, self.L + 1)
        elif self.bc == 'infinite':
            return slice(0, self.L)

    def get_B(self, i, form='B', copy=False, cutoff=1.e-16):
        """Return (view of) `B` at site `i` in canonical form.

        Parameters
        ----------
        i : int
            Index choosing the site.
        form : ``'B' | 'A' | 'C' | 'G' | None`` | tuple(float, float)
            The (canonical) form of the returned B.
            For ``None``, return the matrix in whatever form it is.
        copy : bool
            Whether to return a copy even if `form` matches the current form.
        cutoff : float
            During DMRG with a mixer, `S` may be a matrix for which we need the inverse.
            This is calculated as the Moore-Penrose pseudo-inverse, which uses a cutoff for the
            singular values.

        Returns
        -------
        B : :class:`~tenpy.linalg.np_conserved.Array`
            The MPS 'matrix' `B` at site `i` with leg labels ``vL, vR, p`` (in undefined order).
            May be a view of the matrix (if ``copy=False``),
            or a copy (if the form changed or ``copy=True``)

        Raises
        ------
        ValueError : if self is not in canonical form and ``form != None``.
        """
        i = self._to_valid_index(i)
        form = self._to_valid_form(form)
        return self._convert_form_i(self._B[i], i, self.form[i], form, copy, cutoff)

    def set_B(self, i, B, form='B'):
        """Set `B` at site `i`.

        Parameters
        ----------
        i : int
            Index choosing the site.
        B : :class:`~tenpy.linalg.np_conserved.Array`
            The 'matrix' at site `i`. Should have leg labels ``vL, vR, p`` (in any order).
        form : ``'B' | 'A' | 'C' | 'G' | None`` | tuple(float, float)
            The (canonical) form of the `B` to set.
            ``None`` stands for non-canonical form.
        """
        i = self._to_valid_index(i)
        self.form[i] = self._to_valid_form(form)
        self.dtype = np.find_common_type([self.dtype, B.dtype], [])
        self._B[i] = B

    def get_SL(self, i):
        """Return singular values on the left of site `i`"""
        i = self._to_valid_index(i)
        return self._S[i]

    def get_SR(self, i):
        """Return singular values on the right of site `i`"""
        i = self._to_valid_index(i)
        return self._S[i + 1]

    def set_SL(self, i, S):
        """Set singular values on the left of site `i`"""
        i = self._to_valid_index(i)
        self._S[i] = S
        if not self.finite and i == 0:
            self._S[self.L] = S

    def set_SR(self, i, S):
        """Set singular values on the right of site `i`"""
        i = self._to_valid_index(i)
        self._S[i + 1] = S
        if not self.finite and i == self.L - 1:
            self._S[0] = S

    def get_op(self, op_list, i):
        """Given a list of operators, select the one corresponding to site `i`.

        Parameters
        ----------
        op_list : (list of) {str | npc.array}
            List of operators from which we choose. We assume that ``op_list[j]`` acts on site
            ``j``. If the length is shorter than `L`, we repeat it periodically.
            Strings are translated using :meth:`~tenpy.networks.site.Site.get_op` of site `i`.
        i : int
            Index of the site on which the operator acts.

        Returns
        -------
        op : npc.array
            One of the entries in `op_list`, not copied.
        """
        i = self._to_valid_index(i)
        op = op_list[i % len(op_list)]
        if isinstance(op, str):
            op = self.sites[i].get_op(op)
        return op

    def get_theta(self, i, n=2, cutoff=1.e-16, formL=1., formR=1.):
        """Calculates the `n`-site wavefunction on ``sites[i:i+n]``.

        Parameters
        ----------
        i : int
            Site index.
        n : int
            Number of sites. The result lives on ``sites[i:i+n]``.
        cutoff : float
            During DMRG with a mixer, `S` may be a matrix for which we need the inverse.
            This is calculated as the Moore-Penrose pseudo-inverse, which uses a cutoff for the
            singular values.
        formL : float
            Exponent for the singular values to the left.
        formR : float
            Exponent for the singular values to the right.

        Returns
        -------
        theta : :class:`~tenpy.linalg.np_conserved.Array`
            The n-site wave function with leg labels ``vL, vR, p0, p1, ..., p{n-1}``
            (in undefined order).
            In Vidal's notation (with s=lambda, G=Gamma):
            ``theta = s**formL G_i s G_{i+1} s ... G_{i+n-1} s**formR``.
        """
        i = self._to_valid_index(i)
        if self.finite:
            if (i < 0 or i + n > self.L):
                raise ValueError("i = {0:d} out of bounds".format(i))

        if self.form[i] is None or self.form[(i + n - 1) % self.L] is None:
            # we allow intermediate `form`=None except at the very left and right.
            raise ValueError("can't calculate theta for non-canonical form")

        # the following code is an equivalent to::
        #
        #   theta = self.get_B(i, form='B').replace_label('p', 'p0')
        #   theta.iscale_axis(self.get_SL(i)** formL, 'vL')
        #   for k in range(1, n):
        #       j = (i + k) % self.L
        #       B = self.get_B(j, form='B').replace_label('p', 'p'+str(k))
        #       theta = npc.tensordot(theta, B, ['vR', 'vL'])
        #   theta.iscale_axis(self.get_SR(i + n - 1)** (formR-1.), 'vR')
        #   return theta
        #
        # However, the following code is numerically more stable if ``self.form`` is not `B`
        # (since it avoids unnecessary `scale_axis`) and works also for intermediate sites with
        # ``self.form[j] = None``.

        fL, fR = self.form[i]  # left / right form exponent
        copy = (fL == 0 and fR == 0)  # otherwise, a copy is performed later by `scale_axis`.
        theta = self.get_B(i, form=None, copy=copy)  # in the current form
        theta = self._replace_p_label(theta, 0)
        theta = self._scale_axis_B(theta, self.get_SL(i), formL - fL, 'vL', cutoff)
        for k in range(1, n):  # nothing if n=1.
            j = (i + k) % self.L
            B = self._replace_p_label(self.get_B(j, None, False), k)
            if self.form[j] is not None:
                fL_j, fR_j = self.form[j]
                if fR is not None:
                    B = self._scale_axis_B(B, self.get_SL(j), 1. - fL_j - fR, 'vL', cutoff)
                # otherwise we can just hope it's fine.
                fR = fR_j
            else:
                fR = None
            theta = npc.tensordot(theta, B, axes=('vR', 'vL'))
        # here, fR = self.form[i+n-1][1]
        theta = self._scale_axis_B(theta, self.get_SR(i + n - 1), formR - fR, 'vR', cutoff)
        return theta

    def convert_form(self, new_form='B'):
        """Transform self into a different canonical form (by scaling the legs with singular values).

        Parameters
        ----------
        new_form : (list of) {``'B' | 'A' | 'C' | 'G' | None`` | tuple(float, float)}
            The form the stored 'matrices'. The table in module doc-string.
            A single choice holds for all of the entries.

        Raises
        ------
        ValueError : if trying to convert from a ``None`` form. Use :meth:`canonical_form` instead!
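
        Notes
        -----
        A minimal NumPy sketch of the rescaling behind a change of form, assuming plain ndarrays
        with leg order ``(vL, p, vR)`` and 1D singular values. ``rescale_form`` and
        ``right_canonical`` are hypothetical helpers for illustration, not part of this module.
        A form ``(aL, aR)`` stands for ``S_L**aL Gamma S_R**aR`` in Vidal's notation, so switching
        to ``(bL, bR)`` multiplies the legs by ``S_L**(bL - aL)`` and ``S_R**(bR - aR)``:

        ```python
        import numpy as np


        def rescale_form(B, S_L, S_R, old_form, new_form):
            # B has legs (vL, p, vR); a form is a tuple (left exponent, right exponent)
            (aL, aR), (bL, bR) = old_form, new_form
            B = B * (S_L ** (bL - aL))[:, np.newaxis, np.newaxis]  # scale the 'vL' leg
            return B * (S_R ** (bR - aR))[np.newaxis, np.newaxis, :]  # scale the 'vR' leg


        def right_canonical(psi, d, L):
            # split a normalized state vector into 'B' form tensors by SVDs from the right;
            # Ss[i] holds the singular values on the bond left of site i
            Bs, Ss = [None] * L, [None] * L + [np.ones(1)]
            M = psi.reshape(d ** L, 1)
            for i in range(L - 1, -1, -1):
                chi_R = M.shape[1]
                U, S, Vh = np.linalg.svd(M.reshape(-1, d * chi_R), full_matrices=False)
                Bs[i], Ss[i] = Vh.reshape(-1, d, chi_R), S
                M = U * S  # the remaining left part, carried to the next step
            return Bs, Ss


        rng = np.random.default_rng(0)
        psi = rng.normal(size=2 ** 3)
        psi /= np.linalg.norm(psi)
        Bs, Ss = right_canonical(psi, 2, 3)
        A = rescale_form(Bs[1], Ss[1], Ss[2], (0., 1.), (1., 0.))  # 'B' -> 'A'
        # 'A' form is left-orthonormal: contracting 'vL' and 'p' with the conjugate gives 1
        print(np.allclose(np.tensordot(A.conj(), A, axes=([0, 1], [0, 1])), np.eye(A.shape[2])))
        ```

        Converting a right-orthonormal 'B' tensor ``(0, 1)`` into the 'A' form ``(1, 0)`` in this
        way gives a left-orthonormal tensor because the singular values come from an exact
        Schmidt decomposition; a ``None`` form has no such exponents, hence the error above.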
        """
        new_forms = self._parse_form(new_form)
        for i, form in enumerate(new_forms):
            new_B = self.get_B(i, form=form, copy=False)  # calculates the desired form.
            self.set_B(i, new_B, form=form)

    def entanglement_entropy(self, n=1, bonds=None, for_matrix_S=False):
        r"""Calculate the (half-chain) entanglement entropy for all nontrivial bonds.

        Consider a bipartition of the system into :math:`A = \{ j: j < i_b \}` and
        :math:`B = \{ j: j \geq i_b\}` and the reduced density matrix :math:`\rho_A = tr_B(\rho)`.
        The von-Neumann entanglement entropy is defined as
        :math:`S(A, n=1) = -tr(\rho_A \log(\rho_A)) = S(B, n=1)`.
        The generalizations for ``n != 1, n>0`` are the Renyi entropies:
        :math:`S(A, n) = \frac{1}{1-n} \log(tr(\rho_A^n)) = S(B, n)`.

        This function calculates the entropy for a cut at different bonds `i`, for which
        the eigenvalues of the reduced density matrices :math:`\rho_A` and :math:`\rho_B` are
        given by the squared Schmidt values `S` of the bond.

        Parameters
        ----------
        n : int/float
            Selects which entropy to calculate;
            `n=1` (default) is the usual von-Neumann entanglement entropy.
        bonds : ``None`` | (iterable of) int
            Selects the bonds at which the entropy should be calculated.
            ``None`` defaults to ``range(0, L+1)[self.nontrivial_bonds]``.
        for_matrix_S : bool
            Switch to calculate the entanglement entropy even if the `_S` are matrices.
            Since :math:`O(\chi^3)` is expensive compared to the usual :math:`O(\chi)`,
            we raise an error by default.

        Returns
        -------
        entropies : 1D ndarray
            Entanglement entropies for half-cuts.
            `entropies[j]` contains the entropy for a cut at bond ``bonds[j]``
            (i.e. left to site ``bonds[j]``).
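
        Notes
        -----
        A minimal sketch of the formulas above for a single bond, assuming the Schmidt values
        are a plain 1D array; ``entropy_from_schmidt`` is a hypothetical name (the method itself
        passes ``s**2`` to :func:`~tenpy.tools.math.entropy`):

        ```python
        import numpy as np


        def entropy_from_schmidt(S, n=1):
            # S: Schmidt values of one bond; p = S**2 are the eigenvalues of rho_A
            p = np.asarray(S, dtype=float) ** 2
            p = p[p > 1.e-30]  # 0 * log(0) = 0, so drop vanishing weights
            if n == 1:
                return -np.sum(p * np.log(p))  # von Neumann entropy
            return np.log(np.sum(p ** n)) / (1. - n)  # Renyi entropy of order n


        S = np.sqrt([0.5, 0.25, 0.25])
        print(entropy_from_schmidt(S), 1.5 * np.log(2))  # von Neumann: both agree
        print(entropy_from_schmidt(S, n=2), -np.log(0.375))  # Renyi n=2: both agree
        ```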
        """
        if bonds is None:
            nt = self.nontrivial_bonds
            bonds = range(nt.start, nt.stop)
        res = []
        for ib in bonds:
            s = self._S[ib]
            if len(s.shape) > 1:
                if for_matrix_S:
                    # explicitly calculate Schmidt values by diagonalizing (s^dagger s)
                    s = npc.eigvals(npc.tensordot(s.conj(), s, axes=[0, 0]))
                    res.append(entropy(s, n))
                else:
                    raise ValueError("entropy with non-diagonal schmidt values")
            else:
                res.append(entropy(s**2, n))
        return np.array(res)

    def entanglement_entropy_segment(self, segment=[0], first_site=None, n=1):
        r"""Calculate entanglement entropy for general geometry of the bipartition.

        This function is similar to :meth:`entanglement_entropy`,
        but allows a more general geometry: the region `A` can be a segment of a *few* sites.

        This is achieved by explicitly calculating the reduced density matrix of `A`
        and thus works only for small segments.

        Parameters
        ----------
        segment : list of int
            Given a first site `i`, the region ``A_i`` is defined to be ``[i+j for j in segment]``.
        first_site : ``None`` | (iterable of) int
            Calculate the entropy for segments starting at these sites.
            ``None`` defaults to ``range(L-segment[-1])`` for finite
            or `range(L)` for infinite boundary conditions.
        n : int | float
            Selects which entropy to calculate;
            `n=1` (default) is the usual von-Neumann entanglement entropy,
            otherwise the `n`-th Renyi entropy.

        Returns
        -------
        entropies : 1D ndarray
            ``entropies[i]`` contains the entropy for the region ``A_i`` defined above.
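
        Notes
        -----
        A dense-vector sketch of the same idea, assuming the full state is available as an
        ndarray with one axis per site (so it is only feasible for a handful of sites);
        ``segment_entropy`` is a hypothetical helper that, like the method, takes the entropy
        from the eigenvalues of the reduced density matrix:

        ```python
        import numpy as np


        def segment_entropy(psi, sites_A, n=1):
            # psi: dense state with one axis per site; sites_A: the sites of region A
            sites_A = list(sites_A)
            rest = [j for j in range(psi.ndim) if j not in sites_A]
            dim_A = int(np.prod([psi.shape[j] for j in sites_A]))
            M = np.transpose(psi, sites_A + rest).reshape(dim_A, -1)  # matrix (A, rest)
            p = np.linalg.eigvalsh(M @ M.conj().T)  # spectrum of rho_A = tr_rest |psi><psi|
            p = p[p > 1.e-14]
            if n == 1:
                return -np.sum(p * np.log(p))
            return np.log(np.sum(p ** n)) / (1. - n)


        # Bell pair on sites 0 and 2, site 1 in the state |0>
        psi = np.zeros((2, 2, 2))
        psi[0, 0, 0] = psi[1, 0, 1] = 1. / np.sqrt(2.)
        print(segment_entropy(psi, [0]), segment_entropy(psi, [0, 2]))  # log(2) and 0, up to rounding
        ```

        For this Bell pair, the non-consecutive region ``[0, 2]`` has zero entropy while ``[0]``
        alone has ``log(2)``.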
        """
        # Side remark: there is a trick to calculate the entanglement for large regions `A_i`
        # of consecutive sites (in our notation, ``segment = range(La)``).
        # To get the entanglement entropy, diagonalize:
        #     --theta---
        #       | | |
        #     --theta*--
        #  Diagonalization is O(chi^6), compared to O(d^{3*La})
        segment = np.sort(segment)
        if first_site is None:
            if self.finite:
                first_site = range(0, self.L - segment[-1])
            else:
                first_site = range(self.L)
        comb_legs = [
            self._get_p_labels(len(segment), False),
            self._get_p_labels(len(segment), True)
        ]
        res = []
        for i0 in first_site:
            rho = self.get_rho_segment(segment + i0)
            rho = rho.combine_legs(comb_legs, qconj=[+1, -1])
            p = npc.eigvalsh(rho)
            res.append(entropy(p, n))
        return np.array(res)

    def entanglement_spectrum(self, by_charge=False):
        r"""Return the entanglement energy spectrum.

        Parameters
        ----------
        by_charge : bool
            Whether we should sort the spectrum on each bond by the possible charges.

        Returns
        -------
        ent_spectrum : list
            For each (non-trivial) bond the entanglement spectrum.
            If `by_charge` is ``False``, return (for each bond) a sorted 1D ndarray
            with the convention :math:`S_i^2 = e^{-\xi_i}`, where :math:`S_i` labels a Schmidt value
            and :math:`\xi_i` labels the entanglement 'energy' in the returned spectrum.
            If `by_charge` is ``True``, return a list of tuples ``(charge, sub_spectrum)``
            for each possible charge on that bond.
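
        Notes
        -----
        A minimal sketch of the bookkeeping for a single bond, assuming one integer charge per
        Schmidt value (a simplification of the charge blocks the method reads via
        ``leg.get_slice``); ``spectrum_by_charge`` is a hypothetical helper:

        ```python
        import numpy as np


        def spectrum_by_charge(S, charges):
            # S: Schmidt values of one bond; charges: one integer charge per Schmidt value
            xi = -2. * np.log(S)  # convention S_i**2 = exp(-xi_i)
            charges = np.asarray(charges)
            return [(q, np.sort(xi[charges == q])) for q in np.unique(charges)]


        S = np.sqrt([0.5, 0.3, 0.2])
        for q, xi_q in spectrum_by_charge(S, [0, 1, 0]):
            print(q, xi_q)
        ```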
        """
        if by_charge:
            res = []
            for i in range(self.L + 1)[self.nontrivial_bonds]:
                ss = -2.*np.log(self._S[i])
                if i < self.L:
                    leg = self._B[i].get_leg('vL')
                else:  # i == L: segment b.c.
                    leg = self._B[i - 1].get_leg('vR')
                spectrum = [(leg.get_charge(qi), np.sort(ss[leg.get_slice(qi)]))
                            for qi in range(leg.block_number)]
                res.append(spectrum)
            return res
        else:
            return [np.sort(-2.*np.log(ss)) for ss in self._S[self.nontrivial_bonds]]

    def get_rho_segment(self, segment):
        """Return reduced density matrix for a segment.

        Note that the dimension of rho_A scales exponentially in the length of the segment.

        Parameters
        ----------
        segment : iterable of int
            Sites for which the reduced density matrix is to be calculated.
            Assumed to be sorted.

        Returns
        -------
        rho : :class:`~tenpy.linalg.np_conserved.Array`
            Reduced density matrix of the segment sites.
            Labels ``'p0', 'p1', ..., 'p{k-1}', 'p0*', 'p1*', ..., 'p{k-1}*'``
            with ``k=len(segment)``.
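
        Notes
        -----
        A NumPy sketch of the consecutive case (the first branch of the code below), assuming
        plain ndarrays with leg order ``(vL, p, vR)`` in 'B' form and ``Ss[i]`` the singular values
        left of site ``i``; ``right_canonical`` and ``rho_consecutive`` are hypothetical helpers.
        It contracts ``theta`` with its conjugate over ``vL`` and ``vR`` and compares with the
        partial trace of the dense state:

        ```python
        import numpy as np


        def right_canonical(psi, d, L):
            # 'B' form tensors (vL, p, vR) by SVDs from the right; Ss[i] is left of site i
            Bs, Ss = [None] * L, [None] * L + [np.ones(1)]
            M = psi.reshape(d ** L, 1)
            for i in range(L - 1, -1, -1):
                chi_R = M.shape[1]
                U, S, Vh = np.linalg.svd(M.reshape(-1, d * chi_R), full_matrices=False)
                Bs[i], Ss[i] = Vh.reshape(-1, d, chi_R), S
                M = U * S
            return Bs, Ss


        def rho_consecutive(Bs, Ss, i0, k):
            # theta = diag(S_L) B[i0] ... B[i0+k-1] with legs (vL, p0, ..., p{k-1}, vR)
            theta = Ss[i0][:, np.newaxis, np.newaxis] * Bs[i0]
            for j in range(i0 + 1, i0 + k):
                theta = np.tensordot(theta, Bs[j], axes=(-1, 0))
            # contract 'vL' and 'vR' with the conjugate: legs (p0, ..., p0*, ...)
            return np.tensordot(theta, theta.conj(), axes=([0, -1], [0, -1]))


        rng = np.random.default_rng(2)
        psi = rng.normal(size=(2, 2, 2, 2))
        psi /= np.linalg.norm(psi)
        Bs, Ss = right_canonical(psi, 2, 4)
        rho = rho_consecutive(Bs, Ss, 1, 2)  # sites 1 and 2
        # same result from the dense state, tracing out sites 0 and 3
        dense = np.tensordot(psi, psi.conj(), axes=([0, 3], [0, 3]))
        print(np.allclose(rho, dense))
        ```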
        """
        segment = np.asarray(segment)
        if np.all(segment[1:] == segment[:-1] + 1):  # consecutive
            theta = self.get_theta(segment[0], segment[-1] - segment[0] + 1)
            rho = npc.tensordot(theta, theta.conj(), axes=(['vL', 'vR'], ['vL*', 'vR*']))
            return rho
        rho = self.get_theta(segment[0], 1)
        rho = npc.tensordot(rho, rho.conj(), axes=('vL', 'vL*'))
        k = 1
        for i in range(segment[0] + 1, segment[-1]):
            B = self.get_B(i)
            if i == segment[k]:
                B = self._replace_p_label(B, k)
                k += 1
                rho = npc.tensordot(rho, B, axes=('vR', 'vL'))
                rho = npc.tensordot(rho, B.conj(), axes=('vR*', 'vL*'))
            else:
                rho = npc.tensordot(rho, B, axes=('vR', 'vL'))
                rho = npc.tensordot(rho, B.conj(), axes=(['vR*', 'p'], ['vL*', 'p*']))
        B = self._replace_p_label(self.get_B(segment[-1]), k)
        rho = npc.tensordot(rho, B, axes=('vR', 'vL'))
        rho = npc.tensordot(rho, B.conj(), axes=(['vR*', 'vR'], ['vL*', 'vR*']))
        return rho

    def mutinf_two_site(self, max_range=None, n=1):
        """Calculate the two-site mutual information :math:`I(i:j)`.

        Calculates :math:`I(i:j) = S(i) + S(j) - S(i,j)`,
        where :math:`S(i)` is the single site entropy on site :math:`i`
        and :math:`S(i,j)` the two-site entropy on sites :math:`i,j`.

        Parameters
        ----------
        max_range : int
            Maximal distance ``|i-j|`` for which the mutual information should be calculated.
            ``None`` defaults to `L-1`.
        n : float
            Selects the entropy to use, see :func:`~tenpy.tools.math.entropy`.

        Returns
        -------
        coords : 2D array
            Coordinates for the mutinf array.
        mutinf : 1D array
            ``mutinf[k]`` is the mutual information :math:`I(i:j)` between the
            sites ``i, j = coords[k]``.
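
        Notes
        -----
        A dense-state sketch of the definition above, assuming the state is given as an ndarray
        with one axis per site; ``vn_entropy``, ``rho_sites`` and ``mutual_information`` are
        hypothetical helpers that only use the von Neumann entropy (``n=1``):

        ```python
        import numpy as np


        def vn_entropy(rho):
            p = np.linalg.eigvalsh(rho)
            p = p[p > 1.e-14]  # drop (numerically) vanishing eigenvalues
            return -np.sum(p * np.log(p))


        def rho_sites(psi, sites):
            # reduced density matrix (as a matrix) of `sites` for a dense state psi
            rest = [j for j in range(psi.ndim) if j not in sites]
            dim = int(np.prod([psi.shape[j] for j in sites]))
            M = np.transpose(psi, list(sites) + rest).reshape(dim, -1)
            return M @ M.conj().T


        def mutual_information(psi, i, j):
            # I(i:j) = S(i) + S(j) - S(i,j)
            return (vn_entropy(rho_sites(psi, [i])) + vn_entropy(rho_sites(psi, [j]))
                    - vn_entropy(rho_sites(psi, [i, j])))


        # Bell pair on sites 0 and 2, site 1 in the state |0>
        psi = np.zeros((2, 2, 2))
        psi[0, 0, 0] = psi[1, 0, 1] = 1. / np.sqrt(2.)
        print(mutual_information(psi, 0, 2), mutual_information(psi, 0, 1))
        ```

        For this Bell pair, ``I(0:2) = 2 log(2)``, while site 1 is uncorrelated with both others.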
        """
        #  Basically the code of get_rho_segment and entanglement_entropy,
        #  but optimized to run in O(L^2)
        if max_range is None:
            max_range = self.L
        S_i = self.entanglement_entropy_segment(n=n)  # single-site entropy
        legs_ij = self._get_p_labels(2, False), self._get_p_labels(2, True)
        # = (['p0', 'p1'], ['p0*', 'p1*'])
        contr_legs = (
            ['vR*'] + self._get_p_label(1, False),  # 'vR*', 'p1'
            ['vL*'] + self._get_p_label(1, True))  # 'vL*', 'p1*'
        mutinf = []
        coord = []
        for i in range(self.L):
            rho = self.get_theta(i, 1)
            rho = npc.tensordot(rho, rho.conj(), axes=('vL', 'vL*'))
            jmax = i + max_range + 1
            if self.finite:
                jmax = min(jmax, self.L)
            for j in range(i + 1, jmax):
                B = self._replace_p_label(self.get_B(j, form='B'), 1)  # 'vL', 'vR', 'p1'
                rho = npc.tensordot(rho, B, axes=['vR', 'vL'])
                rho_ij = npc.tensordot(rho, B.conj(), axes=(['vR*', 'vR'], ['vL*', 'vR*']))
                rho_ij = rho_ij.combine_legs(legs_ij, qconj=[+1, -1])
                S_ij = entropy(npc.eigvalsh(rho_ij), n)
                mutinf.append(S_i[i] + S_i[j % self.L] - S_ij)
                coord.append((i, j))
                if j + 1 < jmax:
                    rho = npc.tensordot(rho, B.conj(), axes=contr_legs)
        return np.array(coord), np.array(mutinf)

    def overlap(self, other):
        """Compute overlap :math:`<self|other>`.

        Parameters
        ----------
        other : :class:`MPS`
            An MPS with the same physical sites.

        Returns
        -------
        overlap : dtype
            The contraction <self|other>, taking into account the :attr:`norm` of both MPS.
        env : MPSEnvironment
            The environment (storing the LP and RP) used to calculate the overlap.
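
        Notes
        -----
        A minimal NumPy sketch of the contraction for a finite MPS, assuming lists of
        ``(vL, p, vR)`` ndarrays with trivial outer bonds; it ignores :attr:`norm`, and
        ``mps_overlap`` and ``to_dense`` are hypothetical helpers. The left environment is grown
        site by site, which costs roughly ``O(L d chi^3)`` instead of the exponential cost of
        building the full state vectors:

        ```python
        import numpy as np


        def mps_overlap(Bs_bra, Bs_ket):
            # finite MPS as lists of (vL, p, vR) tensors with trivial (dim 1) outer bonds
            env = np.ones((1, 1))  # legs (vR*, vR) of the contracted left part
            for Bb, Bk in zip(Bs_bra, Bs_ket):
                env = np.tensordot(env, Bb.conj(), axes=(0, 0))  # (vR, p*, vR*)
                env = np.tensordot(env, Bk, axes=([0, 1], [0, 1]))  # (vR*, vR)
            return env[0, 0]


        def to_dense(Bs):
            # contract an MPS into a full state vector (only for checking)
            theta = Bs[0]
            for B in Bs[1:]:
                theta = np.tensordot(theta, B, axes=(-1, 0))
            return theta.reshape(-1)


        rng = np.random.default_rng(3)
        shapes = [(1, 2, 3), (3, 2, 3), (3, 2, 1)]
        bra = [rng.normal(size=s) + 1j * rng.normal(size=s) for s in shapes]
        ket = [rng.normal(size=s) + 1j * rng.normal(size=s) for s in shapes]
        print(np.isclose(mps_overlap(bra, ket), np.vdot(to_dense(bra), to_dense(ket))))
        ```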
        """
        if not self.finite:
            # requires TransferMatrix for finding dominant left/right parts
            raise NotImplementedError("TODO")
        env = MPSEnvironment(self, other)
        return env.full_contraction(0), env

    def expectation_value(self, ops, sites=None, axes=None):
        """Expectation value ``<psi|ops|psi>/<psi|psi>`` of (n-site) operator(s).

        Given the MPS in canonical form, it calculates n-site expectation values.
        For example, the contraction for a two-site (`n` = 2) operator on site `i` would look like::

            |          .--S--B[i]--B[i+1]--.
            |          |     |     |       |
            |          |     |-----|       |
            |          |     | op  |       |
            |          |     |-----|       |
            |          |     |     |       |
            |          .--S--B*[i]-B*[i+1]-.

        Parameters
        ----------
        ops : (list of) { :class:`~tenpy.linalg.np_conserved.Array` | str }
            The operators for which the expectation value should be taken.
            All operators should have the same number of legs (namely `2 n`).
            If less than `self.L` operators are given, we repeat them periodically.
            Strings (like ``'Id', 'Sz'``) are translated into single-site operators defined by
            :attr:`sites`.
        sites : None | list of int
            List of site indices. Expectation values are evaluated there.
            If ``None`` (default), the entire chain is taken (clipping for finite b.c.)
        axes : None | (list of str, list of str)
            Two lists, each of `n` leg labels, giving the physical legs of the operator used for
            contraction. The first `n` legs are contracted with conjugated `B`,
            the second `n` legs with the non-conjugated `B`.
            ``None`` defaults to ``(['p'], ['p*'])`` for single site operators (`n` = 1), or
            ``(['p0', 'p1', ... 'p{n-1}'], ['p0*', 'p1*', ... 'p{n-1}*'])`` for `n` > 1.

        Returns
        -------
        exp_vals : 1D ndarray
            Expectation values, ``exp_vals[i] = <psi|ops[i]|psi>``, where ``ops[i]`` acts on
            site(s) ``j, j+1, ..., j+{n-1}`` with ``j=sites[i]``.

        Examples
        --------
        One site examples (n=1):

        >>> psi.expectation_value('Sz')
        [Sz0, Sz1, ..., Sz{L-1}]
        >>> psi.expectation_value(['Sz', 'Sx'])
        [Sz0, Sx1, Sz2, Sx3, ... ]
        >>> psi.expectation_value('Sz', sites=[0, 3, 4])
        [Sz0, Sz3, Sz4]

        Two site example (n=2), assuming homogeneous sites:

        >>> SzSx = npc.outer(psi.sites[0].Sz.replace_labels(['p', 'p*'], ['p0', 'p0*']),
                             psi.sites[1].Sx.replace_labels(['p', 'p*'], ['p1', 'p1*']))
        >>> psi.expectation_value(SzSx)
        [Sz0Sx1, Sz1Sx2, Sz2Sx3, ... ]   # with len L-1 for finite bc, or L for infinite

        Example measuring ``<psi|SzSx|psi>`` on every second site, for inhomogeneous sites:

        >>> SzSx_list = [npc.outer(psi.sites[i].Sz.replace_labels(['p', 'p*'], ['p0', 'p0*']),
                                   psi.sites[i+1].Sx.replace_labels(['p', 'p*'], ['p1', 'p1*']))
                         for i in range(0, psi.L-1, 2)]
        >>> psi.expectation_value(SzSx_list, range(0, psi.L-1, 2))
        [Sz0Sx1, Sz2Sx3, Sz4Sx5, ...]
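
        A plain-NumPy sketch of the one-site contraction shown above, assuming
        ``theta = diag(S_L) B`` with legs ``(vL, p, vR)`` and an operator matrix with legs
        ``(p, p*)``; ``one_site_expectation`` is a hypothetical helper, checked against the
        dense two-site state:

        ```python
        import numpy as np


        def one_site_expectation(S_L, B, op):
            # theta = diag(S_L) B with legs (vL, p, vR); op with legs (p, p*)
            theta = S_L[:, np.newaxis, np.newaxis] * B
            C = np.transpose(np.tensordot(op, theta, axes=(1, 1)), (1, 0, 2))  # op acting on p
            return np.vdot(theta, C)  # <theta|C>, np.vdot conjugates its first argument


        Sz = np.diag([0.5, -0.5])
        rng = np.random.default_rng(4)
        psi = rng.normal(size=(2, 2))
        psi /= np.linalg.norm(psi)
        U, S, Vh = np.linalg.svd(psi)  # Schmidt decomposition of a two-site state
        dense = np.einsum('ab,bc,ac->', psi.conj(), Sz, psi)  # <psi|Sz on site 1|psi>
        print(np.isclose(one_site_expectation(S, Vh.reshape(2, 2, 1), Sz), dense))
        ```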

        """
        ops, sites, n, (op_ax_p, op_ax_pstar) = self._expectation_value_args(ops, sites, axes)
        ax_p = ['p' + str(k) for k in range(n)]
        ax_pstar = ['p' + str(k) + '*' for k in range(n)]
        E = []
        for i in sites:
            op = self.get_op(ops, i)
            op = op.replace_labels(op_ax_p + op_ax_pstar, ax_p + ax_pstar)
            theta = self.get_theta(i, n)
            C = npc.tensordot(op, theta, axes=[ax_pstar, ax_p])  # C has same labels as theta
            E.append(npc.inner(theta, C, axes=[theta.get_leg_labels()] * 2, do_conj=True))
        return np.real_if_close(np.array(E))

    def correlation_function(self,
                             ops1,
                             ops2,
                             sites1=None,
                             sites2=None,
                             opstr=None,
                             str_on_first=True,
                             hermitian=False):
        r"""Correlation function  ``<psi|op1_i op2_j|psi>/<psi|psi>`` of single site operators.

        Given the MPS in canonical form, it calculates two-point correlation functions.
        For example, the contraction for single-site operators `op1` on site `i` and `op2` on a
        site `j > i` would look like::

            |          .--S--B[i]--B[i+1]--...--B[j]---.
            |          |     |     |            |      |
            |          |     |     |            op2    |
            |          |     op1   |            |      |
            |          |     |     |            |      |
            |          .--S--B*[i]-B*[i+1]-...--B*[j]--.

        Onsite terms are taken in the order ``<psi | op1 op2 | psi>``.

        If `opstr` is given and ``str_on_first=True``, it calculates::

            |           for i < j                               for i > j
            |
            |          .--S--B[i]---B[i+1]--...- B[j]---.     .--S--B[j]---B[j+1]--...- B[i]---.
            |          |     |      |            |      |     |     |      |            |      |
            |          |     opstr  opstr        op2    |     |     op2    |            |      |
            |          |     |      |            |      |     |     |      |            |      |
            |          |     op1    |            |      |     |     opstr  opstr        op1    |
            |          |     |      |            |      |     |     |      |            |      |
            |          .--S--B*[i]--B*[i+1]-...- B*[j]--.     .--S--B*[j]--B*[j+1]-...- B*[i]--.

        For ``i==j``, no `opstr` is included.
        For ``str_on_first=False``, the `opstr` on site ``min(i, j)`` is always left out.

        Strings (like ``'Id', 'Sz'``) in the operator lists are translated into single-site
        operators defined by the :class:`~tenpy.networks.site.Site` on which they act.
        Each operator should have the two legs ``'p', 'p*'``.

        Parameters
        ----------
        ops1 : (list of) { :class:`~tenpy.linalg.np_conserved.Array` | str }
            First operator of the correlation function (acting after ops2).
            ``ops1[x]`` acts on site ``sites1[x]``.
            If less than ``len(sites1)`` operators are given, we repeat them periodically.
        ops2 : (list of) { :class:`~tenpy.linalg.np_conserved.Array` | str }
            Second operator of the correlation function (acting before ops1).
            ``ops2[y]`` acts on site ``sites2[y]``.
            If less than ``len(sites2)`` operators are given, we repeat them periodically.
        sites1 : None | int | list of int
            List of site indices; a single `int` is translated to ``range(0, sites1)``.
            ``None`` defaults to all sites ``range(0, L)``.
            Is sorted before use, i.e. the order is ignored.
        sites2 : None | int | list of int
            List of site indices; a single `int` is translated to ``range(0, sites2)``.
            ``None`` defaults to all sites ``range(0, L)``.
            Is sorted before use, i.e. the order is ignored.
        opstr : None | (list of) { :class:`~tenpy.linalg.np_conserved.Array` | str }
            Ignored by default (``None``).
            Operator(s) to be inserted between ``ops1`` and ``ops2``.
            If given as a list, ``opstr[r]`` is inserted at site `r` (independent of `sites1` and
            `sites2`).
        str_on_first : bool
            Whether the `opstr` is included on the site ``min(i, j)``.
            Note the order, which is chosen that way to handle fermionic Jordan-Wigner strings
            correctly. (In other words: choose ``str_on_first=True`` for fermions!)
        hermitian : bool
            Optimization flag: if ``sites1 == sites2`` and ``ops1[i]^\dagger == ops2[i]``
            (which is not checked explicitly!), the resulting ``C[x, y]`` will be hermitian.
            We can use that to avoid calculations, so ``hermitian=True`` will run faster.

        Returns
        -------
        C : 2D ndarray
            The correlation function ``C[x, y] = <psi|ops1[i] ops2[j]|psi>``,
            where ``ops1[i]`` acts on site ``i=sites1[x]`` and ``ops2[j]`` on site ``j=sites2[y]``.
            If `opstr` is given, it gives (for ``str_on_first=True``):

            - For ``i < j``: ``C[x, y] = <psi|ops1[i] prod_{i <= r < j} opstr[r] ops2[j]|psi>``.
            - For ``i > j``: ``C[x, y] = <psi|prod_{j <= r < i} opstr[r] ops1[i] ops2[j]|psi>``.
            - For ``i = j``: ``C[x, y] = <psi|ops1[i] ops2[j]|psi>``.

            The condition ``<= r`` is replaced by a strict ``< r``, if ``str_on_first=False``.
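
        Notes
        -----
        A plain-NumPy sketch of the ``i < j`` contraction drawn above, assuming a finite MPS in
        'B' form given as ``(vL, p, vR)`` ndarrays; ``right_canonical``, ``correlation_ij`` and
        ``dense_op`` are hypothetical helpers. Because the 'B' form is right-orthonormal, the
        network starts from ``S[i]**2`` on the left and closes with a trace over ``vR`` at site
        ``j``:

        ```python
        import numpy as np


        def right_canonical(psi, d, L):
            # 'B' form tensors (vL, p, vR) by SVDs from the right; Ss[i] is left of site i
            Bs, Ss = [None] * L, [None] * L + [np.ones(1)]
            M = psi.reshape(d ** L, 1)
            for i in range(L - 1, -1, -1):
                chi_R = M.shape[1]
                U, S, Vh = np.linalg.svd(M.reshape(-1, d * chi_R), full_matrices=False)
                Bs[i], Ss[i] = Vh.reshape(-1, d, chi_R), S
                M = U * S
            return Bs, Ss


        def correlation_ij(Bs, Ss, op1, op2, i, j, opstr=None, str_on_first=True):
            # <op1_i [opstr ...] op2_j> for i < j; single-site operators are (p, p*) matrices
            d = Bs[i].shape[1]
            string = np.eye(d) if opstr is None else opstr
            first = op1 @ string if str_on_first else op1  # opstr acts before op1
            # left part: diag(S_i**2), then site i; legs (vR, vR*)
            env = np.einsum('a,aqb,pq,apc->bc', Ss[i] ** 2, Bs[i], first, Bs[i].conj())
            for k in range(i + 1, j):
                env = np.einsum('bc,bqd,pq,cpe->de', env, Bs[k], string, Bs[k].conj())
            # right-orthonormality: the right part reduces to a trace over 'vR'
            return np.einsum('bc,bqd,pq,cpd->', env, Bs[j], op2, Bs[j].conj())


        def dense_op(ops_by_site, L, d):
            # kron product with identities elsewhere (site 0 is the leftmost factor)
            full = np.ones((1, 1))
            for k in range(L):
                full = np.kron(full, ops_by_site.get(k, np.eye(d)))
            return full


        rng = np.random.default_rng(5)
        psi = rng.normal(size=2 ** 4)
        psi /= np.linalg.norm(psi)
        Bs, Ss = right_canonical(psi, 2, 4)
        Sx, Sz, P = np.array([[0., .5], [.5, 0.]]), np.diag([.5, -.5]), np.diag([1., -1.])
        C = correlation_ij(Bs, Ss, Sx, Sz, 0, 3, opstr=P)
        print(np.isclose(C, psi @ dense_op({0: Sx @ P, 1: P, 2: P, 3: Sz}, 4, 2) @ psi))
        ```

        The dense check builds the full operator explicitly, which is only possible for a few
        sites.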
        """
        ops1, ops2, sites1, sites2, opstr = self._correlation_function_args(
            ops1, ops2, sites1, sites2, opstr)
        if hermitian and not np.array_equal(sites1, sites2):
            warnings.warn("MPS correlation function can't use the hermitian flag")
            hermitian = False
        C = np.empty((len(sites1), len(sites2)), dtype=complex)
        for x, i in enumerate(sites1):
            # j > i
            j_gtr = sites2[sites2 > i]
            if len(j_gtr) > 0:
                C_gtr = self._corr_up_diag(ops1, ops2, i, j_gtr, opstr, str_on_first, True)
                C[x, (sites2 > i)] = C_gtr
                if hermitian:
                    C[x + 1:, x] = np.conj(C_gtr)
            # j == i
            j_eq = sites2[sites2 == i]
            if len(j_eq) > 0:
                # on-site correlation function
                op12 = npc.tensordot(self.get_op(ops1, i), self.get_op(ops2, i), axes=['p*', 'p'])
                C[x, (sites2 == i)] = self.expectation_value(op12, i, [['p'], ['p*']])
        if not hermitian:
            #  j < i
            for y, j in enumerate(sites2):
                i_gtr = sites1[sites1 > j]
                if len(i_gtr) > 0:
                    C[(sites1 > j), y] = self._corr_up_diag(ops2, ops1, j, i_gtr, opstr,
                                                            str_on_first, False)
                    # exchange ops1 and ops2 : they commute on different sites,
                    # but we apply opstr after op1 (using the last argument = False)
        return np.real_if_close(C)

    def norm_test(self):
        """Check that self is in canonical form.

        Returns
        -------
        norm_error : array, shape (L, 2)
            For each site the norm error to the left and right.
            The error ``norm_error[i, 0]`` is defined as the norm-difference between
            the following networks::

                |   --theta[i]---.       --s[i]--.
                |       |        |    vs         |
                |   --theta*[i]--.       --s[i]--.

            Similarly, ``norm_error[i, 1]`` is the norm-difference of::

                |   .--theta[i]---         .--s[i+1]--
                |   |    |          vs     |
                |   .--theta*[i]--         .--s[i+1]--
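
        Notes
        -----
        A plain-NumPy sketch of the two checks above, assuming ``(vL, p, vR)`` ndarrays in 'B'
        form and 1D singular values; ``right_canonical`` and ``norm_errors`` are hypothetical
        helpers. For an MPS obtained from exact SVDs, both errors vanish up to rounding:

        ```python
        import numpy as np


        def right_canonical(psi, d, L):
            # 'B' form tensors (vL, p, vR) by SVDs from the right; Ss[i] is left of site i
            Bs, Ss = [None] * L, [None] * L + [np.ones(1)]
            M = psi.reshape(d ** L, 1)
            for i in range(L - 1, -1, -1):
                chi_R = M.shape[1]
                U, S, Vh = np.linalg.svd(M.reshape(-1, d * chi_R), full_matrices=False)
                Bs[i], Ss[i] = Vh.reshape(-1, d, chi_R), S
                M = U * S
            return Bs, Ss


        def norm_errors(Bs, Ss):
            err = np.empty((len(Bs), 2))
            for i, B in enumerate(Bs):
                th = Ss[i][:, np.newaxis, np.newaxis] * B  # theta = diag(S_L) B
                rho_L = np.tensordot(th, th.conj(), axes=([1, 2], [1, 2]))  # (vL, vL*)
                err[i, 0] = np.linalg.norm(rho_L - np.diag(Ss[i] ** 2))
                rho_R = np.tensordot(th, th.conj(), axes=([0, 1], [0, 1]))  # (vR, vR*)
                err[i, 1] = np.linalg.norm(rho_R - np.diag(Ss[i + 1] ** 2))
            return err


        rng = np.random.default_rng(9)
        psi = rng.normal(size=2 ** 4)
        psi /= np.linalg.norm(psi)
        Bs, Ss = right_canonical(psi, 2, 4)
        print(np.max(norm_errors(Bs, Ss)))  # of the order of machine precision
        ```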

        """
        err = np.empty((self.L, 2), dtype=float)
        lbl_R = (self._get_p_label(0, star=False) + ['vR'],
                 self._get_p_label(0, star=True) + ['vR*'])
        lbl_L = (self._get_p_label(0, star=False) + ['vL'],
                 self._get_p_label(0, star=True) + ['vL*'])
        for i in range(self.L):
            th = self.get_theta(i, 1)
            rho_L = npc.tensordot(th, th.conj(), axes=lbl_R)
            S = self.get_SL(i)
            if isinstance(S, npc.Array):  # during DMRG with mixer, S may be a 2D npc.Array
                if S.rank != 2:
                    raise ValueError("Expect 2D npc.Array or 1D numpy ndarray")
                rho_L2 = npc.tensordot(S, S.conj(), axes=['vR', 'vR*'])
            else:
                rho_L2 = npc.diag(S**2, rho_L.get_leg('vL'), dtype=rho_L.dtype)
            err[i, 0] = npc.norm(rho_L - rho_L2)
            rho_R = npc.tensordot(th, th.conj(), axes=lbl_L)
            S = self.get_SR(i)
            if isinstance(S, npc.Array):
                if S.rank != 2:
                    raise ValueError("Expect 2D npc.Array or 1D numpy ndarray")
                rho_R2 = npc.tensordot(S, S.conj(), axes=['vL', 'vL*'])
            else:
                rho_R2 = npc.diag(S**2, rho_R.get_leg('vR'), dtype=rho_L.dtype)
            err[i, 1] = npc.norm(rho_R - rho_R2)
        return err

    def canonical_form(self, renormalize=True):
        """Bring self into canonical 'B' form, calculate singular values.

        Works only for finite/segment boundary conditions.
        If any `B` is in `form` ``None``, it does *not* use any of the singular values `S`
        (for 'finite' boundary conditions, or only the very left `S` for 'segment' b.c.).
        If all sites have a `form` label (like ``'A','B'``), it respects the `form` to ensure
        that one `S` is included per bond.

        .. todo ::
            Should we try to avoid carrying around the total charge of the B matrices?
            Also, implement 'canonical_form_infinite' by diagonalizing the transfer matrix...

        Parameters
        ----------
        renormalize : bool
            Whether a change in the norm should be discarded or used to update :attr:`norm`.

        Returns
        -------
        U_L, V_R : :class:`~tenpy.linalg.np_conserved.Array`
            Only returned for ``'segment'`` boundary conditions.
            The unitaries defining the new left and right Schmidt states in terms of the old ones,
            with legs ``'vL', 'vR'``.
        """
        if self.finite:
            return self._canonical_form_finite(renormalize)
        else:
            self._canonical_form_infinite()

    def correlation_length(self, num_ev=1, charge_sector=None, tol_ev0=1.e-8):
        r"""Calculate the correlation length by diagonalizing the transfer matrix.

        Works only for infinite MPS, where the transfer matrix is a useful concept.
        For an MPS, any correlation function splits into :math:`C(A_i, B_j) = A'_i T^{j-i-1} B'_j`
        with some parts left and right and the :math:`j-i-1`-th power of the transfer matrix in
        between. The largest eigenvalue is 1 and gives the dominant contribution of
        :math:`A'_i E_1 * 1^{j-i-1} * E_1^T B'_j = <A> <B>`, and the second largest one
        gives a contribution :math:`\propto \lambda_2^{j-i-1}`.
        Thus :math:`\lambda_2 = \exp(-\frac{1}{\xi})`.
        Assumes that `self` is in canonical form.

        .. todo ::
            might want to insert OpString.

        Parameters
        ----------
        num_ev : int
            We look for the `num_ev` + 1 largest eigenvalues.
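        charge_sector : ``None`` | 0 | charges
            The charge sector of the transfer matrix,
            in which we look for the most dominant eigenvalues.
        tol_ev0 : float
            Tolerance for the check that the largest eigenvalue is 1;
            a larger deviation raises a :class:`ValueError`.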

        Returns
        -------
        xi : float | 1D array
            For ``num_ev == 1``, return just the correlation length,
            otherwise an array of the `num_ev` largest correlation lengths.
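
        Notes
        -----
        A dense sketch of the definition above, assuming a unit cell of right-canonical
        ``(vL, p, vR)`` ndarrays with a common bond dimension. Unlike the method, which uses
        :class:`TransferMatrix` and its ``eigenvectors``, the hypothetical helper
        ``correlation_length`` diagonalizes the full ``chi**2 x chi**2`` matrix with
        ``np.linalg.eigvals``, which is only feasible for small `chi`:

        ```python
        import numpy as np


        def correlation_length(Bs):
            # Bs: unit cell of right-canonical (vL, p, vR) tensors with common bond dimension chi
            chi = Bs[0].shape[0]
            T = np.eye(chi * chi)
            for B in Bs:
                # one-site transfer matrix with legs ((vL, vL*), (vR, vR*))
                T_site = np.einsum('apb,cpd->acbd', B, B.conj()).reshape(chi * chi, chi * chi)
                T = T @ T_site
            E = np.linalg.eigvals(T)
            E = E[np.argsort(-np.abs(E))]  # sort by magnitude, largest first
            if len(E) < 2 or abs(E[1]) < 1.e-15:
                return 0.  # no subleading eigenvalue: zero correlation length
            return -len(Bs) / np.log(abs(E[1] / E[0]))


        # two-state Markov chain with flip probability eps: B[a, s, b] = delta_{a s} sqrt(P[s, b])
        eps = 0.1
        P = np.array([[1. - eps, eps], [eps, 1. - eps]])
        B = np.zeros((2, 2, 2))
        for s in range(2):
            B[s, s, :] = np.sqrt(P[s])
        print(correlation_length([B]), -1. / np.log(1. - 2. * eps))
        ```

        The transfer matrix of this example has the eigenvalues ``1`` and ``1 - 2 eps`` (besides
        zeros), so the two printed numbers agree, in line with :math:`\lambda_2 = \exp(-1/\xi)`.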
        """
        self.convert_form('B')  # ensure uniform canonical form.
        T = TransferMatrix(self, self, charge_sector=charge_sector)
        E, V = T.eigenvectors(num_ev + 1, which='LM')
        E = E[np.argsort(-np.abs(E))]  # sort descending by magnitude
        if abs(E[0] - 1.) > tol_ev0:
            raise ValueError("largest eigenvalue not one: was not in canonical form!")
        if len(E) < 2:
            return 0.  # only a single eigenvector: zero correlation length
        if num_ev == 1:
            return -1. / np.log(abs(E[1] / E[0])) * self.L
        return -1. / np.log(np.abs(E[1:num_ev + 1] / E[0])) * self.L

    def add(self, other, alpha, beta):
        """Return an MPS which represents ``alpha|self> + beta|other>``.

        Works only for 'finite' boundary conditions.
        Takes into account :attr:`norm`.

        Parameters
        ----------
        other : :class:`MPS`
            Another MPS of the same length to be added with self.
        alpha, beta : complex float
            Prefactors for self and other. We calculate
            ``alpha * |self> + beta * |other>``

        Returns
        -------
        sum : :class:`MPS`
            An MPS representing ``alpha|self> + beta|other>``.
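
        Notes
        -----
        A plain-NumPy sketch of the block structure used below, assuming finite MPS given as
        ``(vL, p, vR)`` ndarrays with trivial outer bonds; ``add_mps`` and ``to_dense`` are
        hypothetical helpers, and the result is *not* brought back into canonical form (the
        method calls ``_canonical_form_finite`` for that). The first tensor is the row
        ``[alpha B1, beta B2]``, the middle tensors are block diagonal and the last one is a
        column:

        ```python
        import numpy as np


        def add_mps(Bs1, Bs2, alpha, beta):
            # Bs1, Bs2: lists (same length >= 2) of (vL, p, vR) tensors, trivial outer bonds
            L = len(Bs1)
            Bs = [np.concatenate([alpha * Bs1[0], beta * Bs2[0]], axis=2)]  # row
            for B1, B2 in zip(Bs1[1:L - 1], Bs2[1:L - 1]):
                (l1, d, r1), (l2, _, r2) = B1.shape, B2.shape
                B = np.zeros((l1 + l2, d, r1 + r2), dtype=np.result_type(B1, B2))
                B[:l1, :, :r1] = B1  # block diagonal in (vL, vR)
                B[l1:, :, r1:] = B2
                Bs.append(B)
            Bs.append(np.concatenate([Bs1[-1], Bs2[-1]], axis=0))  # column
            return Bs


        def to_dense(Bs):
            # contract an MPS into a full state vector (only for checking)
            theta = Bs[0]
            for B in Bs[1:]:
                theta = np.tensordot(theta, B, axes=(-1, 0))
            return theta.reshape(-1)


        rng = np.random.default_rng(6)
        Bs1 = [rng.normal(size=s) for s in [(1, 2, 2), (2, 2, 3), (3, 2, 1)]]
        Bs2 = [rng.normal(size=s) for s in [(1, 2, 3), (3, 2, 2), (2, 2, 1)]]
        psi = to_dense(add_mps(Bs1, Bs2, 0.3, -1.2))
        print(np.allclose(psi, 0.3 * to_dense(Bs1) - 1.2 * to_dense(Bs2)))
        ```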
        """
        L = self.L
        assert (other.L == L and L >= 2)  # (one could generalize this function...)
        assert (self.bc == 'finite')  # not clear for segment: are left states orthogonal?
        # TODO: should gauge qtotal to zero.
        legs = ['vL', 'vR'] + self._p_label
        # alpha and beta appear only on the first site
        alpha = alpha * self.norm
        beta = beta * other.norm
        Bs = [
            npc.grid_concat(
                [[alpha * self.get_B(0).transpose(legs), beta * other.get_B(0).transpose(legs)]],
                axes=[0, 1])
        ]
        for i in range(1, L - 1):
            B1 = self.get_B(i).transpose(legs)
            B2 = other.get_B(i).transpose(legs)
            grid = [[B1, npc.zeros([B1.get_leg('vL'), B2.get_leg('vR')] + B1.legs[2:])],
                    [npc.zeros([B2.get_leg('vL'), B1.get_leg('vR')] + B1.legs[2:]), B2]]
            Bs.append(npc.grid_concat(grid, [0, 1]))
        Bs.append(
            npc.grid_concat(
                [[self.get_B(L - 1).transpose(legs)], [other.get_B(L - 1).transpose(legs)]],
                axes=[0, 1]))

        Ss = [np.ones(1)] + [np.ones(B.shape[1]) for B in Bs]
        psi = MPS(self.sites, Bs, Ss, 'finite', None)
        # bring to canonical form, calculate Ss
        psi._canonical_form_finite(renormalize=False)
        return psi

    def apply_local_op(self, i, op, unitary=None, renormalize=False):
        """Apply a local operator to `self`.

        Note that this destroys the canonical form if the local operator is non-unitary.
        Therefore, this function calls :meth:`canonical_form` if necessary.

        Parameters
        ----------
        i : int
            Index of the site on which the operator should act.
        op : str | npc.Array
            A physical operator acting on site `i`, with legs ``'p', 'p*'``.
            Strings (like ``'Id', 'Sz'``) are translated into single-site operators defined by
            :attr:`sites`.
        unitary : None | bool
            Whether `op` is unitary, i.e., whether the canonical form is preserved (``True``)
            or whether we should call :meth:`canonical_form` (``False``).
            ``None`` checks whether ``norm(op dagger(op) - identity)`` is small.
        renormalize : bool
            Whether the final state should keep track of the norm (False, default) or be
            renormalized to have norm 1 (True).
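
        Notes
        -----
        A plain-NumPy sketch of the unitarity check and of why it matters, assuming a single
        ``(vL, p, vR)`` tensor in 'B' form and an operator matrix with legs ``(p, p*)``;
        ``apply_op_to_B`` and ``right_orthonormal`` are hypothetical helpers. A unitary operator
        keeps the tensor right-orthonormal, while a non-unitary one (here ``Sz``) does not, which
        is why :meth:`canonical_form` is needed in that case:

        ```python
        import numpy as np


        def apply_op_to_B(B, op, tol=1.e-14):
            # B: (vL, p, vR); op: (p, p*). Returns the new tensor and whether op is unitary.
            unitary = np.linalg.norm(op @ op.conj().T - np.eye(op.shape[0])) < tol
            new_B = np.transpose(np.tensordot(op, B, axes=(1, 1)), (1, 0, 2))  # (vL, p, vR)
            return new_B, unitary


        def right_orthonormal(B):
            # 'B' form condition: contracting 'p' and 'vR' with the conjugate gives 1
            BB = np.tensordot(B, B.conj(), axes=([1, 2], [1, 2]))
            return np.allclose(BB, np.eye(B.shape[0]))


        rng = np.random.default_rng(7)
        B = np.linalg.svd(rng.normal(size=(2, 4)), full_matrices=False)[2].reshape(2, 2, 2)
        X, Sz = np.array([[0., 1.], [1., 0.]]), np.diag([0.5, -0.5])
        for op in (X, Sz):
            new_B, unitary = apply_op_to_B(B, op)
            print(unitary, right_orthonormal(new_B))
        ```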

        """
        i = self._to_valid_index(i)
        if isinstance(op, str):
            op = self.sites[i].get_op(op)
        if unitary is None:
            op_op_dagger = npc.tensordot(op, op.conj(), axes=['p*', 'p'])
            unitary = npc.norm(op_op_dagger - npc.eye_like(op_op_dagger.legs[0])) < 1.e-14
        opB = npc.tensordot(op, self._B[i], axes=['p*', 'p'])
        self._B[i] = opB
        self.dtype = np.result_type(self.dtype, opB.dtype)
        if not unitary:
            self.canonical_form(renormalize)

    def swap_sites(self, i, swap_op='auto', trunc_par={}):
        """Swap the two neighboring sites `i` and `i+1` (inplace).

        Exchange two neighboring sites: form theta, 'swap' the physical legs and split
        with an svd. While the 'swap' is just a transposition/relabeling for bosons, one needs to
        be careful about the sign for fermions.

        Parameters
        ----------
        i : int
            Swap the two sites at positions `i` and `i+1`.
        swap_op : ``None`` | ``'auto'`` | :class:`~tenpy.linalg.np_conserved.Array`
            The operator used to swap the physical legs of the two-site wave function `theta`.
            For ``None``, just transpose/relabel the legs, for ``'auto'`` also take care of
            fermionic signs. Alternatively, give an npc :class:`~tenpy.linalg.np_conserved.Array`
            which represents the full operator used for the swap.
            Should have legs ``['p0', 'p1', 'p0*', 'p1*']`` with ``'p0', 'p1*'`` contractible.
        trunc_par : dict
            Parameters for truncation, see :func:`~tenpy.algorithms.truncation.truncate`.

        Returns
        -------
        trunc_err : :class:`~tenpy.algorithms.truncation.TruncationError`
            The error of the represented state introduced by the truncation after the swap.
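
        Notes
        -----
        A plain-NumPy sketch of the ``'auto'`` sign, assuming ``theta`` with legs
        ``(vL, p0, p1, vR)`` and integer Jordan-Wigner exponents ``n`` per basis state;
        ``swap_theta`` is a hypothetical helper that uses ``(-1)**(n_i n_{i+1})`` directly instead
        of ``exp(i pi n_i n_{i+1})`` as in the code below (equal for integer exponents):

        ```python
        import numpy as np


        def swap_theta(theta, JW_L, JW_R):
            # theta: (vL, p0, p1, vR); JW_L, JW_R: Jordan-Wigner exponent n of each basis state
            sign = (-1.) ** np.outer(JW_L, JW_R)  # (-1)**(n_i * n_{i+1})
            theta = theta * sign[np.newaxis, :, :, np.newaxis]
            return np.transpose(theta, (0, 2, 1, 3))  # exchange the physical legs


        rng = np.random.default_rng(8)
        theta = rng.normal(size=(3, 2, 2, 4))
        JW = np.array([0, 1])  # spinless fermions: n = 0 for empty, 1 for occupied
        swapped = swap_theta(theta, JW, JW)
        print(np.allclose(swapped[:, 1, 1, :], -theta[:, 1, 1, :]))  # both occupied: sign flip
        ```

        Swapping twice restores `theta`, and for bosons (all exponents zero) the swap reduces to
        the plain transposition used for ``swap_op=None``.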
        """
        siteL, siteR = self.sites[self._to_valid_index(i)], self.sites[self._to_valid_index(i+1)]
        if swap_op == 'auto':
            # get sign for Fermions.
            # If we write the wave function as
            # psi = sum_{ [n_i]} psi_[n_i] prod_i (c^dagger_i)^{n_i}  |vac>
            # we see that exchanging i <-> i+1 introduces the phase obtained by commuting
            # (c^dagger_i)^{n_i} with (c^dagger_{i+1})^{n_{i+1}}.
            # This gives a sign (-1)^{n_i * n_{i+1}}.
            # site.JW_exponent is the `n_i` in the above equations, for each physical index.
            sign = siteL.JW_exponent[:, np.newaxis] * siteR.JW_exponent[np.newaxis, :]
            if np.any(sign):
                dL, dR = siteL.dim, siteR.dim
                sign = np.real_if_close(np.exp(1.j*np.pi*sign.reshape([dL*dR])))
                swap_op = np.diag(sign).reshape([dL, dR, dL, dR])
                legs = [siteL.leg, siteR.leg, siteL.leg.conj(), siteR.leg.conj()]
                swap_op = npc.Array.from_ndarray(swap_op, legs)
                swap_op.iset_leg_labels(['p1', 'p0', 'p0*', 'p1*'])
            else:  # no sign necessary
                swap_op = None  # continue with transposition as for Bosons
        theta = self.get_theta(i, n=2)
        C = self.get_theta(i, n=2, formL=0.)  # inversion free, see also TEBDEngine.update_bond()
        if swap_op is None:
            # just replace the labels, effectively this is a transposition.
            theta.ireplace_labels(['p0', 'p1'], ['p1', 'p0'])
            C.ireplace_labels(['p0', 'p1'], ['p1', 'p0'])
        elif isinstance(swap_op, npc.Array):
            theta = npc.tensordot(swap_op, theta, axes=[['p0*', 'p1*'], ['p0', 'p1']])
            C = npc.tensordot(swap_op, C, axes=(['p0*', 'p1*'], ['p0', 'p1']))
        else:
            raise ValueError("Invalid swap_op: got " + repr(swap_op))
        theta = theta.combine_legs([('vL', 'p0'), ('vR', 'p1')], qconj=[+1, -1])
        U, S, V, err, renormalize = svd_theta(theta, trunc_par, inner_labels=['vR', 'vL'])
        B_R = V.split_legs(1).ireplace_label('p1', 'p')
        B_L = npc.tensordot(C.combine_legs(('vR', 'p1'), pipes=theta.legs[1]),
                            V.conj(),
                            axes=['(vR.p1)', '(vR*.p1*)'])
        B_L.ireplace_labels(['vL*', 'p0'], ['vR', 'p'])
        B_L /= renormalize  # re-normalize to <psi|psi> = 1
        self.set_SR(i, S)
        self.set_B(i, B_L, 'B')
        self.set_B(i+1, B_R, 'B')
        self.sites[self._to_valid_index(i)] = siteR  # swap 'sites' as well
        self.sites[self._to_valid_index(i+1)] = siteL
        return err

    def permute_sites(self, perm, swap_op='auto', trunc_par={}, verbose=0):
        """Apply the permutation `perm` to the state (inplace).

        Parameters
        ----------
        perm : ndarray[ndim=1, int]
            The applied permutation: the site at position `i` is moved to position ``perm[i]``,
            i.e. ``psi.permute_sites(perm)[perm[i]] = psi[i]``
            (where ``[i]`` indicates the `i`-th site).
        swap_op : ``None`` | ``'auto'`` | :class:`~tenpy.linalg.np_conserved.Array`
            The operator used to swap the phyiscal legs of a two-site wave function `theta`,
            see :meth:`swap_sites`.
        trunc_par : dict
            Parameters for truncation, see :func:`~tenpy.algorithms.truncation.truncate`.
        verbose : float
            Level of verbosity, print status messages if verbose > 0.

        Returns
        -------
        trunc_err : :class:`~tenpy.algorithms.truncation.TruncationError`
            The error of the represented state introduced by the truncation after the swaps.
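
        Examples
        --------
        A pure-Python sketch of the swap schedule used below (an insertion/gnome sort with
        adjacent swaps only); the helper name `swap_schedule` is hypothetical. Each recorded
        index `i` corresponds to one call of ``swap_sites(i)``, so the number of swaps equals
        the number of inversions in `perm`::

            def swap_schedule(perm):
                perm = list(perm)  # work on a copy
                swaps = []
                i = 0
                while i < len(perm) - 1:
                    if perm[i] > perm[i + 1]:
                        perm[i], perm[i + 1] = perm[i + 1], perm[i]
                        swaps.append(i)  # here, permute_sites calls self.swap_sites(i)
                        if i > 0:
                            i -= 1  # step back: the moved entry may still be out of order
                    else:
                        i += 1
                return swaps

            assert swap_schedule([1, 2, 3, 0, 6, 7, 8, 5]) == [2, 1, 0, 6, 5, 4]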
        """
        # In order to keep sites close together, we always scan from the left,
        # keeping everything up to `i` in strictly ascending order.
        # => more or less an 'insertion' sort algorithm.
        # Works nicely for permutations like [1,2,3,0,6,7,8,5] (swapping the 0 and 5 around).
        # For [ 2 3 4 5 6 7 0 1], it splits 0 and 1 apart (first swapping the 0 down, then the 1)
        trunc_err = TruncationError()
        num_swaps = 0
        perm = list(perm)  # work on a copy, don't modify the caller's `perm` in place
        i = 0
        while i < self.L-1:
            if perm[i] > perm[i+1]:
                if verbose > 1:
                    print(i, ": chi = ", self._S[i+1].shape[0], end='')
                trunc = self.swap_sites(i, swap_op, trunc_par)
                if verbose > 1:
                    print("->", self._S[i+1].shape[0], ". eps = ", trunc.eps)
                num_swaps += 1
                x, y = perm[i], perm[i+1]
                perm[i+1], perm[i] = x, y
                # step back by one: the moved entry might still be smaller than its new left
                # neighbor, while everything up to i-1 was already sorted before this swap
                if i > 0:
                    i -= 1
                trunc_err += trunc
            else:
                i += 1
        if verbose > 0:
            print("Total swaps in permute_sites:", num_swaps, repr(trunc_err))
        return trunc_err

    def compute_K(self, perm, swap_op='auto', trunc_par={}, canonicalize=1.e-6, verbose=0):
        r"""Compute the momentum quantum numbers of the entanglement spectrum for 2D states.

        Works for an infinite MPS living on a cylinder, infinitely long in `x` direction and with
        periodic boundary conditions in `y` direction.
        If the state is invariant under 'rotations' around the cylinder axis, one can find the
        momentum quantum numbers of it. (The rotation is nothing more than a translation in `y`.)
        This function permutes some sites (on a copy of `self`) to enact the rotation, and then
        finds the dominant eigenvector of the mixed transfer matrix to get the quantum numbers,
        along the lines of [PollmannTurner2012]_.


        Parameters
        ----------
        perm : 1D ndarray | :class:`~tenpy.models.lattice.Lattice`
            Permutation to be applied to the physical indices, see :meth:`permute_sites`.
            If a lattice is given, we use it to read out the lattice structure and shift
            each site by one lattice-vector in y-direction (assuming periodic boundary conditions).
            (If you have a :class:`~tenpy.models.model.CouplingModel`,
            give its `lat` attribute for this argument)
        swap_op : ``None`` | ``'auto'`` | :class:`~tenpy.linalg.np_conserved.Array`
            The operator used to swap the physical legs of a two-site wave function `theta`,
            see :meth:`swap_sites`.
        trunc_par : dict
            Parameters for truncation, see :func:`~tenpy.algorithms.truncation.truncate`.
        canonicalize : float
            Check that `self` is in canonical form; call :meth:`canonical_form`
            if :meth:`norm_test` yields ``np.linalg.norm(self.norm_test()) > canonicalize``.
        verbose : float
            Level of verbosity, print status messages if verbose > 0.

        Returns
        -------
        U : :class:`~tenpy.linalg.np_conserved.Array`
            Unitary representation of the applied permutation on left Schmidt states.
        W : ndarray
            1D array of the form ``S**2 exp(i K)``, where `S` are the singular values
            on the left bond.
        q : :class:`~tenpy.linalg.charges.LegCharge`
            LegCharge corresponding to `W`.
        ov : complex
            The eigenvalue of the mixed transfer matrix `<psi|T|psi>` per :attr:`L` sites.
        trunc_err : :class:`~tenpy.algorithms.truncation.TruncationError`
            The error of the represented state introduced by the truncation after swaps.
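
        Examples
        --------
        Since ``W = S**2 exp(i K)`` up to normalization, the momenta and the entanglement
        spectrum can be read off with numpy, e.g. after
        ``U, W, q, ov, trunc_err = psi.compute_K(perm)``. The values of `W` below are made up
        for illustration and are not output of this function::

            import numpy as np

            S2 = np.array([0.6, 0.3, 0.1])  # hypothetical squared singular values
            K = np.array([0., np.pi / 2, -np.pi / 2])  # hypothetical momenta in (-pi, pi]
            W = S2 * np.exp(1.j * K)
            W = W / np.sum(np.abs(W))  # same normalization as the returned `W`
            K_read = np.angle(W)  # momentum quantum numbers
            ent_spectrum = -np.log(np.abs(W))  # entanglement energies -log(S**2)
            assert np.allclose(K_read, K)
            assert np.allclose(ent_spectrum, -np.log(S2))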
        """
        from ..models.lattice import Lattice  # dynamical import to avoid import loops
        if self.finite:
            raise ValueError("Works only for infinite b.c.")

        if isinstance(perm, Lattice):
            lat = perm
            assert lat.dim >= 2  # ensure that the lattice is at least 2D
            assert lat.N_sites == self.L
            shifted_lat_order = lat.order.copy()
            shifted_lat_order[:, 1] = np.mod(shifted_lat_order[:, 1] + 1, lat.Ls[1])
            perm = lat.lat2mps_idx(shifted_lat_order)
            if verbose > 1:
                print("permutation: ", perm)
        # preliminary: check canonical form
        self.convert_form('B')
        norm_err = np.linalg.norm(self.norm_test())
        if norm_err > canonicalize:
            warnings.warn("self.norm_test() = {0!r} ==> canonicalize".format(norm_err))
            self.canonical_form()
        # get copy of self
        psi_t = self.copy()
        # apply permutation
        perm = np.asarray(perm)
        trunc_err = psi_t.permute_sites(perm, swap_op, trunc_par, verbose/10.)
        # re-check canonical form
        norm_err = np.linalg.norm(psi_t.norm_test())
        if norm_err > canonicalize:
            warnings.warn("psi_t.norm_test() = {0!r} ==> canonicalize".format(norm_err))
            psi_t.canonical_form()
        psi_t.convert_form('B')
        TM = TransferMatrix(self, psi_t, transpose=True, charge_sector=0)
        # Find left dominant eigenvector of this mixed transfer matrix.
        # Because we are in B form and get the left eigenvector,
        # the resulting vector should be sUs up to a scaling.
        ov, sUs = TM.eigenvectors()
        if verbose > 0:
            print("compute_K: overlap ", ov[0], ", |o| = 1. -", 1. - np.abs(ov[0]))
            # (should be 1 if state is invariant under translations)
            print("compute_K: truncation error ", trunc_err.eps)
        sUs = sUs[0].split_legs(0)
        _, sUs_blocked = sUs.as_completely_blocked()
        W = npc.eigvals(sUs_blocked, sort='m>')
        # W = s^2 exp(i K ) up to overall scaling
        # Strip S's from U
        inv_S = 1./self.get_SL(0)
        U = sUs.scale_axis(inv_S, 0).iscale_axis(inv_S, 1)
        # U should be unitary - scale it
        U *= (np.sqrt(U.shape[0])/npc.norm(U))
        return U, W/np.sum(np.abs(W)), sUs_blocked.legs[0], ov[0], trunc_err

    def __str__(self):
        """Some status information about the MPS."""
        res = ["MPS, L={L:d}, bc={bc!r}.".format(L=self.L, bc=self.bc)]
        res.append("chi: " + str(self.chi))
        if self.L > 10:
            res.append("first two sites: " + repr(self.sites[0]) + " " + repr(self.sites[1]))
            res.append("first two forms: " + " ".join([repr(f) for f in self.form[:2]]))
        else:
            res.append("sites: " + " ".join([repr(s) for s in self.sites]))
            res.append("forms: " + " ".join([repr(f) for f in self.form]))
        return "\n".join(res)

    def _to_valid_index(self, i):
        """Make sure `i` is a valid index (depending on `self.bc`)."""
        if not self.finite:
            return i % self.L
        if i < 0:
            i += self.L
        if i >= self.L or i < 0:
            raise ValueError("i = {0:d} out of bounds for finite MPS".format(i))
        return i

    def _parse_form(self, form):
        """Parse `form` = (list of) {tuple | key of _valid_forms} to list of tuples"""
        if isinstance(form, tuple):
            return [form] * self.L
        form = to_iterable(form)
        if len(form) == 1:
            form = [form[0]] * self.L
        if len(form) != self.L:
            raise ValueError("Wrong len of `form`: " + repr(form))
        return [self._to_valid_form(f) for f in form]

    def _to_valid_form(self, form):
        """Parse `form` = {tuple | key of _valid_forms} to a tuple"""
        if isinstance(form, tuple):
            return form
        return self._valid_forms[form]

    def _convert_form_i(self, B, i, form, new_form, copy=True, cutoff=1.e-16):
        """Transform `B[i]` from canonical form `form` into canonical form `new_form`.

        ======== ======== ================================================
        form     new_form action
        ======== ======== ================================================
        *        ``None`` return (copy of) B
        tuple    tuple    scale the legs 'vL' and 'vR' of B appropriately
                          with ``self.get_SL(i)`` and ``self.get_SR(i)``.
        ``None`` tuple    raise ValueError
        ======== ======== ================================================
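
        A dense numpy sketch for a single site in Vidal's notation, assuming the forms
        'B' = ``(0, 1)`` (i.e. ``Gamma S_R``) and 'A' = ``(1, 0)`` (i.e. ``S_L Gamma``).
        Converting 'B' into 'A' scales 'vL' with ``S_L**1`` and 'vR' with ``S_R**-1``;
        shapes and values are arbitrary choices for illustration::

            import numpy as np

            rng = np.random.default_rng(0)
            chiL, d, chiR = 3, 2, 4
            Gamma = rng.normal(size=(chiL, d, chiR))  # legs 'vL', 'p', 'vR'
            S_L = np.array([0.8, 0.5, 0.33])  # hypothetical singular values left of site i
            S_R = np.array([0.7, 0.5, 0.4, 0.3])  # hypothetical singular values right of site i
            B = Gamma * S_R[np.newaxis, np.newaxis, :]  # form (0, 1)
            # new_L - old_L = 1 - 0 = 1 and new_R - old_R = 0 - 1 = -1
            A = B * (S_L**1.)[:, np.newaxis, np.newaxis] * (S_R**-1.)[np.newaxis, np.newaxis, :]
            assert np.allclose(A, S_L[:, np.newaxis, np.newaxis] * Gamma)  # form (1, 0)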
        """
        if new_form is None or form == new_form:
            if copy:
                return B.copy()
            return B  # nothing to do
        if form is None:
            raise ValueError("can't convert form of non-canonical state!")
        old_L, old_R = form
        new_L, new_R = new_form
        B = self._scale_axis_B(B, self.get_SL(i), new_L - old_L, 'vL', cutoff)
        B = self._scale_axis_B(B, self.get_SR(i), new_R - old_R, 'vR', cutoff)
        return B

    def _scale_axis_B(self, B, S, form_diff, axis_B, cutoff):
        """Scale an axis of B with S to bring it in desired form.

        If S is just 1D (as usual, e.g. during TEBD), this function just performs
        ``B.scale_axis(S**form_diff, axis_B)``.

        However, during the DMRG with mixer, S might actually be a 2D matrix.
        For ``form_diff = -1``, we need to calculate the inverse of S, more precisely the
        (Moore-Penrose) pseudo inverse, see :func:`~tenpy.linalg.np_conserved.pinv`.
        The cutoff is only used in that case.
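
        A dense numpy sketch of the 2D case for ``axis_B='vL'``, with `np.linalg.pinv` standing
        in for :func:`~tenpy.linalg.np_conserved.pinv` and a random (hence generically
        invertible) `S` as an assumption: scaling with ``form_diff=-1`` and then with
        ``form_diff=+1`` gives back the original `B`::

            import numpy as np

            rng = np.random.default_rng(1)
            B = rng.normal(size=(3, 2, 4))  # legs 'vL', 'p', 'vR'
            S = rng.normal(size=(3, 3))  # 2D `S`, as may occur with the DMRG mixer
            # form_diff = -1: contract the pseudo inverse of S onto 'vL'
            B_inv = np.tensordot(np.linalg.pinv(S), B, axes=[1, 0])
            # form_diff = +1: contract S itself onto 'vL'
            B_back = np.tensordot(S, B_inv, axes=[1, 0])
            assert np.allclose(B_back, B)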

        Returns scaled B."""
        if form_diff == 0:
            return B  # nothing to do
        if isinstance(S, npc.Array):
            if S.rank != 2:
                raise ValueError("Expect 2D npc.Array or 1D numpy ndarray")
            if form_diff == -1:
                S = npc.pinv(S, cutoff)
            elif form_diff != 1.:
                raise ValueError("Can't scale/tensordot a 2D `S` unless `form_diff` is +1 or -1")

            # Hack: mpo.MPOEnvironment.full_contraction uses ``axis_B == 'vL*'``
            if axis_B == 'vL' or axis_B == 'vL*':
                B = npc.tensordot(S, B, axes=[1, axis_B]).replace_label(0, axis_B)
            elif axis_B == 'vR' or axis_B == 'vR*':
                B = npc.tensordot(B, S, axes=[axis_B, 0]).replace_label(-1, axis_B)
            else:
                raise ValueError("This should never happen: unexpected leg for scaling with S")
            return B
        else:
            if form_diff != 1.:
                S = S**form_diff
            return B.scale_axis(S, axis_B)

    def _replace_p_label(self, A, k):
        """Return npc Array `A` with replaced label, ``'p' -> 'p'+str(k)``.

        This is done for each of the 'physical labels' in :attr:`_p_label`.
        With a clever use of this function, the re-implementation of various functions
        (like get_theta) in derived classes with multiple legs per site can be avoided.
        """
        return A.replace_labels(self._p_label, self._get_p_label(k, False))

    def _get_p_label(self, k, star=False):
        """return list of physical label(s) with additional str(k) and possibly a '*'."""
        if star == 'both':
            return [lbl + str(k) for lbl in self._p_label] + \
                   [lbl + str(k)+'*' for lbl in self._p_label]
        elif star:
            return [lbl + str(k) + '*' for lbl in self._p_label]
        else:
            return [lbl + str(k) for lbl in self._p_label]

    def _get_p_labels(self, ks, star=False):
        """join ``self._get_p_label(k) for k in range(ks)`` to a single list."""
        res = []
        for k in range(ks):
            res.extend(self._get_p_label(k, star))
        return res

    def _expectation_value_args(self, ops, sites, axes):
        """parse the arguments of self.expectation_value()"""
        ops = npc.to_iterable_arrays(ops)
        n = self.get_op(ops, 0).rank // 2  # same as int(rank/2)
        L = self.L
        if sites is None:
            if self.finite:
                sites = range(L - (n - 1))
            else:
                sites = range(L)
        sites = to_iterable(sites)
        if axes is None:
            if n == 1:
                axes = (['p'], ['p*'])
            else:
                axes = (self._get_p_labels(n), self._get_p_labels(n, True))
        # check number of axes
        ax_p, ax_pstar = axes
        if len(ax_p) != n or len(ax_pstar) != n:
            raise ValueError("Len of axes does not match to n-site operator with n=" + str(n))
        return ops, sites, n, axes

    def _correlation_function_args(self, ops1, ops2, sites1, sites2, opstr):
        """get default arguments of self.correlation_function()"""
        if sites1 is None:
            sites1 = range(0, self.L)
        elif isinstance(sites1, int):
            sites1 = range(0, sites1)
        if sites2 is None:
            sites2 = range(0, self.L)
        elif isinstance(sites2, int):
            sites2 = range(0, sites2)
        ops1 = npc.to_iterable_arrays(ops1)
        ops2 = npc.to_iterable_arrays(ops2)
        opstr = npc.to_iterable_arrays(opstr)
        sites1 = np.sort(sites1)
        sites2 = np.sort(sites2)
        return ops1, ops2, sites1, sites2, opstr

    def _corr_up_diag(self, ops1, ops2, i, j_gtr, opstr, str_on_first, apply_opstr_first):
        """correlation function above the diagonal: for fixed i and all j in j_gtr, j > i."""
        op1 = self.get_op(ops1, i)
        opstr1 = self.get_op(opstr, i)
        if opstr1 is not None:
            axes = ['p*', 'p'] if apply_opstr_first else ['p', 'p*']
            op1 = npc.tensordot(op1, opstr1, axes=axes)
        theta = self.get_theta(i, n=1)
        C = npc.tensordot(op1, theta, axes=['p*', 'p0'])
        C = npc.tensordot(theta.conj(), C, axes=[['p0*', 'vL*'], ['p', 'vL']])
        # C has legs 'vR*', 'vR'
        js = list(j_gtr[::-1])  # stack of j, sorted *descending*
        res = []
        for r in range(i + 1, js[0] + 1):  # js[0] is the maximum
            B = self.get_B(r, form='B')
            C = npc.tensordot(C, B, axes=['vR', 'vL'])
            if r == js[-1]:
                Cij = npc.tensordot(self.get_op(ops2, r), C, axes=['p*', 'p'])
                Cij = npc.inner(B.conj(), Cij, axes=[['vL*', 'p*', 'vR*'], ['vR*', 'p', 'vR']])
                res.append(Cij)
                js.pop()
            if len(js) > 0:
                op = self.get_op(opstr, r)
                if op is not None:
                    C = npc.tensordot(op, C, axes=['p*', 'p'])
                C = npc.tensordot(B.conj(), C, axes=[['vL*', 'p*'], ['vR*', 'p']])
        return res

    def _canonical_form_finite(self, renormalize):
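        """Bring a finite MPS (``bc='finite'`` or ``'segment'``) into canonical 'B' form.

        First sweep left to right with QR decompositions to obtain left-canonical `A` tensors,
        then sweep right to left with SVDs to obtain right-canonical `B` tensors and the
        singular values on each bond. Unless `renormalize`, the norm is absorbed into
        :attr:`norm`. For ``bc='segment'``, the unitaries `U` and `VR_segment` acting on the
        boundary bonds are returned.

        The following standalone numpy sketch illustrates the two sweeps for ``bc='finite'``
        on dense tensors with legs ``(vL, p, vR)``. It ignores charges, `renormalize` and
        segment boundaries; the helpers `canonicalize_finite` and `contract` are hypothetical::

            import numpy as np

            def canonicalize_finite(Ms):
                # returns right-canonical Bs and the singular values Ss[i] on bond (i-1, i)
                L = len(Ms)
                # sweep left to right: QR yields left-canonical As, R is pushed to the right
                As = []
                R = np.ones((1, 1))
                for M in Ms[:-1]:
                    M = np.tensordot(R, M, axes=[1, 0])
                    chiL, d, chiR = M.shape
                    Q, R = np.linalg.qr(M.reshape(chiL * d, chiR))
                    As.append(Q.reshape(chiL, d, Q.shape[1]))
                M = np.tensordot(R, Ms[-1], axes=[1, 0])
                # sweep right to left: SVDs yield right-canonical Bs and singular values
                Bs, Ss = [None] * L, [None] * L
                for i in range(L - 1, -1, -1):
                    chiL, d, chiR = M.shape
                    U, S, V = np.linalg.svd(M.reshape(chiL, d * chiR), full_matrices=False)
                    S = S / np.linalg.norm(S)  # normalize
                    Bs[i], Ss[i] = V.reshape(len(S), d, chiR), S
                    if i > 0:
                        M = np.tensordot(As[i - 1], U * S[np.newaxis, :], axes=[2, 0])
                Bs[0] = Bs[0] * U[0, 0]  # trivial phase on the left bond, better keep it
                return Bs, Ss

            def contract(Ms):
                # full state vector of the MPS, for checking
                psi = Ms[0]
                for M in Ms[1:]:
                    psi = np.tensordot(psi, M, axes=[-1, 0])
                return psi.reshape(-1)

            rng = np.random.default_rng(2)
            d, chi = 2, 3
            Ms = [rng.normal(size=s) for s in [(1, d, chi), (chi, d, chi), (chi, d, 1)]]
            Bs, Ss = canonicalize_finite(Ms)
            for B in Bs:  # right canonical: sum_{p, vR} B B^* = 1
                eye = np.tensordot(B, B.conj(), axes=[[1, 2], [1, 2]])
                assert np.allclose(eye, np.eye(B.shape[0]))
            psi = contract(Ms)
            assert np.allclose(contract(Bs), psi / np.linalg.norm(psi))
        """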
        assert (self.finite)
        L = self.L
        assert (L > 2)  # otherwise implement yourself...
        # normalize very left singular values
        S = self.get_SL(0)
        if self.bc == 'segment':
            if S is None:
                raise ValueError("Need S[0] for segment boundary conditions.")
            self.set_SL(0,
                        S / np.linalg.norm(S))  # must have correct singular values to the left...
            S = self.get_SR(L - 1)
            self.set_SR(L - 1, S / np.linalg.norm(S))
        else:  # bc == 'finite':
            self.set_SL(0, np.array([1.]))  # trivial singular value on very left/right
            self.set_SR(L - 1, np.array([1.]))
        # sweep from left to right to bring it into left canonical form.
        if any([(f is None) for f in self.form]):
            # ignore any 'S' and canonical form
            M = self.get_B(0, None)
            form = None
        else:
            # we actually had a canonical form before, so we should *not* ignore the 'S'
            M = self.get_theta(0, n=1).replace_labels(self._get_p_label(0), self._p_label)
            form = 'B'  # for other 'M'
        if self.bc == 'segment':
            M.iscale_axis(self.get_SL(0), axis='vL')
        Q, R = npc.qr(M.combine_legs(['vL'] + self._p_label), inner_labels=['vR', 'vL'])
        # Q = unitary, R has to be multiplied to the right
        self.set_B(0, Q.split_legs(0), form='A')
        for i in range(1, L - 1):
            M = self.get_B(i, form)
            M = npc.tensordot(R, M, axes=['vR', 'vL'])
            Q, R = npc.qr(M.combine_legs(['vL'] + self._p_label), inner_labels=['vR', 'vL'])
            # Q is unitary, i.e. left canonical, R has to be multiplied to the right
            self.set_B(i, Q.split_legs(0), form='A')
        M = self.get_B(L - 1, None)
        M = npc.tensordot(R, M, axes=['vR', 'vL'])
        if self.bc == 'segment':
            # also need to calculate new singular values on the very right
            U, S, VR_segment = npc.svd(M.combine_legs(['vL'] + self._p_label),
                                       cutoff=0., inner_labels=['vR', 'vL'])
            S /= np.linalg.norm(S)
            self.set_SR(L - 1, S)
            M = U.scale_axis(S, 1).split_legs(0)
        # sweep from right to left, calculating all the singular values
        U, S, V = npc.svd(
            M.combine_legs(['vR'] + self._p_label, qconj=-1), cutoff=0., inner_labels=['vR', 'vL'])
        if not renormalize:
            self.norm = self.norm * np.linalg.norm(S)
        S = S / np.linalg.norm(S)  # normalize
        self.set_SL(L - 1, S)
        self.set_B(L - 1, V.split_legs(1), form='B')
        for i in range(L - 2, -1, -1):
            M = self.get_B(i, 'A')
            M = npc.tensordot(M, U.scale_axis(S, 'vR'), axes=['vR', 'vL'])
            U, S, V = npc.svd(M.combine_legs(['vR'] + self._p_label, qconj=-1),
                              cutoff=0.,
                              inner_labels=['vR', 'vL'])
            S = S / np.linalg.norm(S)  # normalize
            self.set_SL(i, S)
            self.set_B(i, V.split_legs(1), form='B')
        if self.bc == 'finite':
            assert len(S) == 1
            self._B[0] *= U[0, 0]  # just a trivial phase factor, but better keep it
        # Done. For 'finite' bc, U is only a trivial phase (absorbed into B[0] above).
        # For 'segment' bc, U and VR_segment may be non-trivial unitaries (re-shuffling the
        # states of the left/right environments), so we return them to the caller.
        if self.bc == 'segment':
            return U, VR_segment

    def _canonical_form_infinite(self):
        raise NotImplementedError("TODO")


class MPSEnvironment(object):
    """Stores partial contractions of :math:`<bra|Op|ket>` for local operators `Op`.

    The network for a contraction :math:`<bra|Op|ket>` of a local operator `Op`, say acting
    on sites `i, i+1`, looks like::

        |     .-----M[0]--- ... --M[i]---M[i+1]- ... ->--.
        |     |     |             |      |               |
        |     |     |             |------|               |
        |     LP[0] |             |  Op  |               RP[-1]
        |     |     |             |------|               |
        |     |     |             |      |               |
        |     .-----N[0]*-- ... --N[i]*--N[i+1]* ... -<--.

    Of course, we can also calculate the overlap `<bra|ket>` by using the special case ``Op = Id``.

    We use the following label convention (where arrows indicate `qconj`)::

        |    .-->- vR           vL ->-.
        |    |                        |
        |    LP                       RP
        |    |                        |
        |    .--<- vR*         vL* -<-.

    To avoid recalculations of the whole network e.g. in the DMRG sweeps,
    we store the contractions up to some site index in this class.
    For ``bc='finite','segment'``, the very left and right part ``LP[0]`` and
    ``RP[-1]`` are trivial and don't change,
    but for ``bc='infinite'`` they might be updated
    (by inserting another unit cell to the left/right).

    The MPS `bra` and `ket` have to be in canonical form.
    All the environments are constructed without the singular values on the open bond.
    In other words, we contract left-canonical `A` to the left parts `LP`
    and right-canonical `B` to the right parts `RP`.
    Thus, the special case ``ket=bra`` should yield identity matrices for `LP` and `RP`.
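
    A dense numpy sketch of one step of this left contraction, with a hypothetical helper
    `contract_LP` (not the actual implementation, which works on npc Arrays with the labels
    above). For ``ket=bra`` with a left-canonical `A`, the identity is mapped to the identity::

        import numpy as np

        def contract_LP(LP, A_bra, A_ket):
            # LP has legs (vR*, vR); A_bra and A_ket have legs (vL, p, vR)
            LP = np.tensordot(LP, A_ket, axes=[1, 0])  # legs (vR*, p, vR)
            # contract 'vL*', 'p*' of the conjugated bra with 'vR*', 'p' of LP
            return np.tensordot(A_bra.conj(), LP, axes=[[0, 1], [0, 1]])  # (vR*, vR)

        rng = np.random.default_rng(3)
        chiL, d, chiR = 2, 2, 3
        Q, _ = np.linalg.qr(rng.normal(size=(chiL * d, chiR)))  # orthonormal columns
        A = Q.reshape(chiL, d, chiR)  # left-canonical
        assert np.allclose(contract_LP(np.eye(chiL), A, A), np.eye(chiR))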

    .. todo ::
        Functionality to find the dominant LP/RP in the TD limit -> requires TransferMatrix

    .. todo ::
        Doesn't work for different qtotal in ket._B / bra._B -> Need MPS.gauge_qtotal()

    Parameters
    ----------
    bra : :class:`~tenpy.networks.mps.MPS`
        The MPS to project on. Should be given in usual 'ket' form;
        we call `conj()` on the matrices directly.
    ket : :class:`~tenpy.networks.mps.MPS`
        The MPS on which the local operator acts. If ``None``, use `bra`.
    firstLP : ``None`` | :class:`~tenpy.linalg.np_conserved.Array`
        Initial very left part. If ``None``, build trivial one.
    lastRP : ``None`` | :class:`~tenpy.linalg.np_conserved.Array`
        Initial very right part. If ``None``, build trivial one.
    age_LP : int
        The number of physical sites involved in the contraction yielding `firstLP`.
    age_RP : int
        The number of physical sites involved in the contraction yielding `lastRP`.

    Attributes
    ----------
    L : int
        Number of physical sites. For iMPS the len of the MPS unit cell.
    dtype : type | string
        The data type of the Array entries.
    bra, ket : :class:`~tenpy.networks.mps.MPS`
        The two MPS for the contraction.
    _LP : list of {``None`` | :class:`~tenpy.linalg.np_conserved.Array`}
        Left parts of the environment, len `L`.
        ``LP[i]`` contains the contraction strictly left of site `i`
        (or ``None``, if we don't have it calculated).
    _RP : list of {``None`` | :class:`~tenpy.linalg.np_conserved.Array`}
        Right parts of the environment, len `L`.
        ``RP[i]`` contains the contraction strictly right of site `i`
        (or ``None``, if we don't have it calculated).
    _LP_age : list of int | ``None``
        Used for book-keeping, how large the DMRG system grew:
        ``_LP_age[i]`` stores the number of physical sites involved in the contraction
        network which yields ``self._LP[i]``.
    _RP_age : list of int | ``None``
        Used for book-keeping, how large the DMRG system grew:
        ``_RP_age[i]`` stores the number of physical sites involved in the contraction
        network which yields ``self._RP[i]``.
    """

    def __init__(self, bra, ket, firstLP=None, lastRP=None, age_LP=0, age_RP=0):
        if ket is None:
            ket = bra
        self.bra = bra
        self.ket = ket
        self.L = L = bra.L
        self.finite = bra.finite
        self.dtype = np.find_common_type([bra.dtype, ket.dtype], [])
        self._LP = [None] * L
        self._RP = [None] * L
        self._LP_age = [None] * L
        self._RP_age = [None] * L
        if firstLP is None:
            # Build trivial very first LP
            leg_bra = bra.get_B(0).get_leg('vL')
            leg_ket = ket.get_B(0).get_leg('vL').conj()
            leg_ket.test_contractible(leg_bra)
            # should work for both finite and segment bc
            firstLP = npc.diag(1., leg_bra, dtype=self.dtype)
            firstLP.iset_leg_labels(['vR*', 'vR'])
        self.set_LP(0, firstLP, age=age_LP)
        if lastRP is None:
            # Build trivial very last RP
            leg_bra = bra.get_B(L - 1).get_leg('vR')
            leg_ket = ket.get_B(L - 1).get_leg('vR').conj()
            leg_ket.test_contractible(leg_bra)
            lastRP = npc.diag(1., leg_bra, dtype=self.dtype)  # (leg_bra, leg_ket)
            lastRP.iset_leg_labels(['vL*', 'vL'])
        self.set_RP(L - 1, lastRP, age=age_RP)
        self.test_sanity()

    def test_sanity(self):
        assert (self.bra.L == self.ket.L)
        assert (self.bra.finite == self.ket.finite)
        # check that the network is contractible
        for b_s, k_s in zip(self.bra.sites, self.ket.sites):
            b_s.leg.test_equal(k_s.leg)
        assert any([LP is not None for LP in self._LP])
        assert any([RP is not None for RP in self._RP])

    def get_LP(self, i, store=True):
        """Calculate LP at given site from nearest available one (including `i`).

        Parameters
        ----------
        i : int
            The returned `LP` will contain the contraction *strictly* left of site `i`.
        store : bool
            Whether to store the calculated `LP` in `self` (``True``) or discard them (``False``).

        Returns
        -------
        LP_i : :class:`~tenpy.linalg.np_conserved.Array`
            Contraction of everything left of site `i`,
            with labels ``'vR*', 'vR'`` for `bra`, `ket`.
        """
        # find nearest available LP to the left.
        for i0 in range(i, i - self.L, -1):
            LP = self._LP[self._to_valid_index(i0)]
            if LP is not None:
                break
            # (for finite, LP[0] should always be set, so we should abort at latest with i0=0)
        else:  # no break called
            raise ValueError("No left part in the system???")
        age_i0 = self.get_LP_age(i0)
        for j in range(i0, i):
            LP = self._contract_LP(j, LP)
            if store:
                self.set_LP(j + 1, LP, age=age_i0 + j - i0 + 1)
        return LP

    def get_RP(self, i, store=True):
        """Calculate RP at given site from nearest available one (including `i`).

        Parameters
        ----------
        i : int
            The returned `RP` will contain the contraction *strictly* right of site `i`.
        store : bool
            Whether to store the calculated `RP` in `self` (``True``) or discard them (``False``).

        Returns
        -------
        RP_i : :class:`~tenpy.linalg.np_conserved.Array`
            Contraction of everything right of site `i`,
            with labels ``'vL*', 'vL'`` for `bra`, `ket`.
        """
        # find nearest available RP to the right.
        for i0 in range(i, i + self.L):
            RP = self._RP[self._to_valid_index(i0)]
            if RP is not None:
                break
            # (for finite, RP[-1] should always be set, so we should abort at latest with i0=L-1)
        else:  # no break called
            raise ValueError("No right part in the system???")
        age_i0 = self.get_RP_age(i0)
        for j in range(i0, i, -1):
            RP = self._contract_RP(j, RP)
            if store:
                self.set_RP(j - 1, RP, age=age_i0 + i0 - j + 1)
        return RP

    def get_LP_age(self, i):
        """Return number of physical sites in the contractions of get_LP(i). Might be ``None``."""
        return self._LP_age[self._to_valid_index(i)]

    def get_RP_age(self, i):
        """Return number of physical sites in the contractions of get_RP(i). Might be ``None``."""
        return self._RP_age[self._to_valid_index(i)]

    def set_LP(self, i, LP, age):
        """Store part to the left of site `i`."""
        i = self._to_valid_index(i)
        self._LP[i] = LP
        self._LP_age[i] = age

    def set_RP(self, i, RP, age):
        """Store part to the right of site `i`."""
        i = self._to_valid_index(i)
        self._RP[i] = RP
        self._RP_age[i] = age

    def del_LP(self, i):
        """Delete stored part strictly to the left of site `i`."""
        i = self._to_valid_index(i)
        self._LP[i] = None
        self._LP_age[i] = None

    def del_RP(self, i):
        """Delete stored part strictly to the right of site `i`."""
        i = self._to_valid_index(i)
        self._RP[i] = None
        self._RP_age[i] = None

    def full_contraction(self, i0):
        """Calculate the overlap by a full contraction of the network.

        The full contraction of the environments gives the overlap ``<bra|ket>``,
        taking into account :attr:`MPS.norm` of both `bra` and `ket`.
        For this purpose, this function contracts
        ``get_LP(i0+1, store=False)`` and ``get_RP(i0, store=False)``.

        Parameters
        ----------
        i0 : int
            Site index.
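
        Examples
        --------
        A dense numpy sketch of the final contraction, assuming ``bra = ket`` in canonical form
        with ``norm = 1`` (values chosen for illustration): then `LP` and `RP` are identities
        and the overlap reduces to ``sum(S**2) = 1``::

            import numpy as np

            S = np.array([0.8, 0.6])  # normalized singular values on the bond (i0, i0+1)
            LP = np.eye(2)  # legs ('vR*', 'vR')
            RP = np.eye(2)  # legs ('vL*', 'vL')
            LP = S.conj()[:, np.newaxis] * LP * S[np.newaxis, :]  # scale 'vR*' and 'vR' with S
            overlap = np.sum(LP * RP)  # inner product pairing 'vR*'-'vL*' and 'vR'-'vL'
            assert np.isclose(overlap, 1.)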
        """
        if self.ket.finite and i0 + 1 == self.L:
            # special case to handle `_to_valid_index` correctly:
            # get_LP(L) is not valid for finite b.c, so we need to calculate it explicitly.
            LP = self.get_LP(i0, store=False)
            LP = self._contract_LP(i0, LP)
        else:
            LP = self.get_LP(i0 + 1, store=False)
        # multiply with `S`: a bit of a hack: use 'private' MPS._scale_axis_B
        S_bra = self.bra.get_SR(i0).conj()
        LP = self.bra._scale_axis_B(LP, S_bra, form_diff=1., axis_B='vR*', cutoff=0.)
        # cutoff is not used for form_diff = 1
        S_ket = self.ket.get_SR(i0)
        LP = self.bra._scale_axis_B(LP, S_ket, form_diff=1., axis_B='vR', cutoff=0.)
        RP = self.get_RP(i0, store=False)
        contr = npc.inner(LP, RP, axes=[['vR*', 'vR'], ['vL*', 'vL']], do_conj=False)
        return contr * self.bra.norm * self.ket.norm

    def expectation_value(self, ops, sites=None, axes=None):
        """Expectation value ``<bra|ops|ket>`` of (n-site) operator(s).

        Calculates n-site expectation values of operators sandwiched between bra and ket.
        For example, the contraction for a two-site operator on site `i` would look like::

            |          .--S--B[i]--B[i+1]--.
            |          |     |     |       |
            |          |     |-----|       |
            |          LP[i] | op  |       RP[i+1]
            |          |     |-----|       |
            |          |     |     |       |
            |          .--S--B*[i]-B*[i+1]-.

        Here, the `B` are taken from `ket`, the `B*` from `bra`.
        The call structure is the same as for :meth:`MPS.expectation_value`.

        .. warning ::
            In contrast to :meth:`MPS.expectation_value`, this function does not normalize;
            instead, it also takes into account :attr:`MPS.norm` of both `bra` and `ket`.

        Parameters
        ----------
        ops : (list of) { :class:`~tenpy.linalg.np_conserved.Array` | str }
            The operators for which the expectation value should be taken.
            All operators should have the same number of legs (namely `2 n`).
            If less than ``len(sites)`` operators are given, we repeat them periodically.
            Strings (like ``'Id', 'Sz'``) are translated into single-site operators defined by
            :attr:`sites`.
        sites : list
            List of site indices. Expectation values are evaluated there.
            If ``None`` (default), the entire chain is taken (clipping for finite b.c.)
        axes : None | (list of str, list of str)
            Two lists of `n` leg labels each, giving the physical legs of the operator used for
            contraction. The first `n` legs are contracted with conjugated `B`,
            the second `n` legs with the non-conjugated `B`.
            ``None`` defaults to ``(['p'], ['p*'])`` for single site (n=1), or
            ``(['p0', 'p1', ... 'p{n-1}'], ['p0*', 'p1*', ... 'p{n-1}*'])`` for `n` > 1.

        Returns
        -------
        exp_vals : 1D ndarray
            Expectation values, ``exp_vals[i] = <bra|ops[i]|ket>``, where ``ops[i]`` acts on
            site(s) ``j, j+1, ..., j+{n-1}`` with ``j=sites[i]``.

        Examples
        --------
        One site examples (n=1):

        >>> env.expectation_value('Sz')
        [Sz0, Sz1, ..., Sz{L-1}]
        >>> env.expectation_value(['Sz', 'Sx'])
        [Sz0, Sx1, Sz2, Sx3, ... ]
        >>> env.expectation_value('Sz', sites=[0, 3, 4])
        [Sz0, Sz3, Sz4]

        Two site example (n=2), assuming homogeneous sites:

        >>> SzSx = npc.outer(psi.sites[0].Sz.replace_labels(['p', 'p*'], ['p0', 'p0*']),
        ...                  psi.sites[1].Sx.replace_labels(['p', 'p*'], ['p1', 'p1*']))
        >>> env.expectation_value(SzSx)
        [Sz0Sx1, Sz1Sx2, Sz2Sx3, ... ]   # with len L-1 for finite bc, or L for infinite

        Example measuring ``<bra|SzSx|ket>`` on every second site, for inhomogeneous sites:

        >>> SzSx_list = [npc.outer(psi.sites[i].Sz.replace_labels(['p', 'p*'], ['p0', 'p0*']),
        ...                        psi.sites[i+1].Sx.replace_labels(['p', 'p*'], ['p1', 'p1*']))
        ...              for i in range(0, psi.L-1, 2)]
        >>> env.expectation_value(SzSx_list, range(0, psi.L-1, 2))
        [Sz0Sx1, Sz2Sx3, Sz4Sx5, ...]
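
        A minimal dense numpy sketch of the two-site contraction above, assuming plain ndarrays
        with axes ``(vL, p, vR)`` instead of npc.Array, random (non-canonical) tensors without
        singular values, and ``L = 4``; all function names are hypothetical. As in this method,
        the starred legs ``p0*, p1*`` of the operator are contracted with the ket, the result
        is sandwiched between `LP` and `RP`, and finally contracted with the conjugated bra.
        The value is compared with the dense ``<bra|Sz_1 Sx_2|ket>``.

        ```python
        import numpy as np

        rng = np.random.default_rng(1)
        d, chi = 2, 3
        Sz = np.diag([0.5, -0.5])
        Sx = np.array([[0., 0.5], [0.5, 0.]])


        def random_mps(bond_dims):
            # random complex tensors with axes (vL, p, vR)
            return [rng.normal(size=(dl, d, dr)) + 1j * rng.normal(size=(dl, d, dr))
                    for dl, dr in zip(bond_dims[:-1], bond_dims[1:])]


        def dense(Ms):
            # dense state vector; site 0 is the most significant physical index
            psi = np.ones((1, 1))
            for M in Ms:
                psi = np.einsum('xa,apb->xpb', psi, M).reshape(-1, M.shape[2])
            return psi[:, 0]


        def env_left(bra, ket, i):
            # contraction of sites 0, ..., i-1, axes (vR*, vR)
            LP = np.ones((1, 1))
            for Mb, Mk in zip(bra[:i], ket[:i]):
                LP = np.einsum('ab,bpc,apd->dc', LP, Mk, Mb.conj())
            return LP


        def env_right(bra, ket, j):
            # contraction of sites j+1, ..., L-1, axes (vL, vL*)
            RP = np.ones((1, 1))
            for Mb, Mk in zip(reversed(bra[j + 1:]), reversed(ket[j + 1:])):
                RP = np.einsum('apb,bc,dpc->ad', Mk, RP, Mb.conj())
            return RP


        def two_site_expval(bra, ket, op, i):
            # <bra|op|ket> for op with axes (p0, p1, p0*, p1*) acting on sites i, i+1
            LP, RP = env_left(bra, ket, i), env_right(bra, ket, i + 1)
            theta_ket = np.einsum('apb,bqc->apqc', ket[i], ket[i + 1])  # (vL, p0, p1, vR)
            theta_bra = np.einsum('apb,bqc->apqc', bra[i], bra[i + 1])
            C = np.einsum('pqrs,arsc->apqc', op, theta_ket)  # p0*, p1* contracted with ket
            return np.einsum('xa,apqc,cy,xpqy->', LP, C, RP, theta_bra.conj())


        bond_dims = [1, chi, chi, chi, 1]
        bra, ket = random_mps(bond_dims), random_mps(bond_dims)
        op = np.einsum('ac,bd->abcd', Sz, Sx)  # axes (p0, p1, p0*, p1*)
        val = two_site_expval(bra, ket, op, 1)
        # dense operator with Sz on site 1 and Sx on site 2
        full_op = np.kron(np.kron(np.eye(d), np.kron(Sz, Sx)), np.eye(d))
        expected = np.vdot(dense(bra), full_op @ dense(ket))
        print(np.isclose(val, expected))  # True
        ```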
        """
        ops, sites, n, (op_ax_p, op_ax_pstar) = self.ket._expectation_value_args(ops, sites, axes)
        ax_p = ['p' + str(k) for k in range(n)]
        ax_pstar = ['p' + str(k) + '*' for k in range(n)]
        E = []
        for i in sites:
            LP = self.get_LP(i, store=True)
            RP = self.get_RP(i, store=True)
            op = self.ket.get_op(ops, i)
            op = op.replace_labels(op_ax_p + op_ax_pstar, ax_p + ax_pstar)
            C = self.ket.get_theta(i, n)
            th_labels = C.get_leg_labels()  # vL, vR, p0, p1, ...
            C = npc.tensordot(op, C, axes=[ax_pstar, ax_p])  # same labels
            C = npc.tensordot(LP, C, axes=['vR', 'vL'])  # axes_p + (vR*, vR)
            C = npc.tensordot(C, RP, axes=['vR', 'vL'])  # axes_p + (vR*, vL*)
            C.ireplace_labels(['vR*', 'vL*'], ['vL', 'vR'])  # back to original theta labels
            theta_bra = self.bra.get_theta(i, n)
            E.append(npc.inner(theta_bra, C, axes=[th_labels] * 2, do_conj=True))
        return np.real_if_close(np.array(E)) * self.bra.norm * self.ket.norm

    def _contract_LP(self, i, LP):
        """Contract LP with the tensors on site `i` to form ``self._LP[i+1]``"""
        LP = npc.tensordot(LP, self.ket.get_B(i, form='A'), axes=('vR', 'vL'))
        LP = npc.tensordot(
            self.bra.get_B(i, form='A').conj(), LP, axes=(['p*', 'vL*'], ['p', 'vR*']))
        return LP  # labels 'vR*', 'vR'

    def _contract_RP(self, i, RP):
        """Contract RP with the tensors on site `i` to form ``self._RP[i-1]``"""
        RP = npc.tensordot(self.ket.get_B(i, form='B'), RP, axes=('vR', 'vL'))
        RP = npc.tensordot(
            self.bra.get_B(i, form='B').conj(), RP, axes=(['p*', 'vR*'], ['p', 'vL*']))
        return RP  # labels 'vL', 'vL*'

    def _to_valid_index(self, i):
        """Make sure `i` is a valid index (depending on `ket.bc`)."""
        return self.ket._to_valid_index(i)


class TransferMatrix(sparse.NpcLinearOperator):
    r"""Transfer matrix of two MPS (bra & ket).

    For an iMPS in the thermodynamic limit, we often need to find the 'dominant `LP`' (and `RP`).
    This means nothing else than taking the transfer matrix of the unit cell and finding the
    (left/right) eigenvector with the largest (magnitude) eigenvalue, since it will dominate
    :math:`LP (TM)^n` (or :math:`(TM)^n RP`) in the limit :math:`n \rightarrow \infty`, whatever
    the initial `LP` is. This class provides exactly that functionality with :meth:`eigenvectors`.

    Given two MPS, we define the transfer matrix as::

        |    ---M[i]---M[i+1]- ... --M[i+L-1]---
        |       |      |             |
        |    ---N[j]*--N[j+1]* ... --N[j+L-1]*--

    Here `M` denotes the `B` of the ket and `N` the `B` of the bra.
    To view it as a `matrix`, we combine the left and right indices to pipes::

        |  (vL.vL*) ->-TM->- (vR.vR*)   acting on  (vL.vL*) ->-RP
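
    A minimal dense numpy sketch of what :meth:`matvec` does for ``transpose=False``,
    assuming a one-site unit cell, ``bra == ket`` and plain ndarrays with axes ``(vL, p, vR)``
    instead of npc.Array; the names `transfer_matrix`, `apply_tm` and the random
    right-canonical `B` are illustrative only. It builds the transfer matrix with combined legs
    ``(vL.vL*)`` and ``(vR.vR*)``, checks it against the leg-by-leg contraction, and finds the
    dominant eigenvector with dense :func:`numpy.linalg.eig` instead of the sparse solver used
    in :meth:`eigenvectors`. For a right-canonical `B` the dominant right eigenvector is
    expected to be the identity, with eigenvalue 1.

    ```python
    import numpy as np

    rng = np.random.default_rng(2)
    d, chi = 2, 4

    # right-canonical B with axes (vL, p, vR): reshaped to (chi, d*chi) it has orthonormal rows
    X = rng.normal(size=(d * chi, chi)) + 1j * rng.normal(size=(d * chi, chi))
    Q, _ = np.linalg.qr(X)  # reduced QR: Q has orthonormal columns, shape (d*chi, chi)
    B = Q.conj().T.reshape(chi, d, chi)


    def transfer_matrix(M, N):
        # dense TM with rows (vL.vL*) and columns (vR.vR*); M from the ket, N from the bra
        k = M.shape[0]
        return np.einsum('apb,cpd->acbd', M, N.conj()).reshape(k * k, k * k)


    def apply_tm(M, N, RP):
        # leg-by-leg contraction acting on RP with axes (vL, vL*), right to left
        return np.einsum('apb,bd,cpd->ac', M, RP, N.conj())


    T = transfer_matrix(B, B)
    RP = rng.normal(size=(chi, chi))
    print(np.allclose(T @ RP.ravel(), apply_tm(B, B, RP).ravel()))  # True

    eta, W = np.linalg.eig(T)
    j = np.argmax(np.abs(eta))  # 'LM': largest magnitude
    RP_dom = W[:, j].reshape(chi, chi)
    RP_dom = RP_dom / RP_dom[0, 0]  # fix normalization and phase of the eigenvector
    print(np.isclose(eta[j], 1.), np.allclose(RP_dom, np.eye(chi)))  # True True
    ```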

    .. warning ::
        We don't use any canonical form of these `M` and `N` and no singular values,
        only the matrices as stored. If the stored matrices of your state are not all in the
        same canonical form, use :meth:`MPS.convert_form` beforehand to ensure a uniform one.

    .. todo ::
        tests give warnings....

    Parameters
    ----------
    bra : MPS
        The MPS which is to be (complex) conjugated
    ket : MPS
        The MPS which is not (complex) conjugated.
    shift_bra : int
        We start the `N` of the bra at site `shift_bra`.
    shift_ket : int | None
        We start the `M` of the ket at site `shift_ket`. ``None`` defaults to `shift_bra`.
    transpose : bool
        Whether `self.matvec` acts on `RP` (``False``) or `LP` (``True``).
    charge_sector : None | charges | ``0``
        Selects the charge sector of the vector onto which the linear operator acts.
        ``None`` stands for *all* sectors, ``0`` stands for the zero-charge sector.
        Defaults to ``0``, i.e., *assumes* the dominant eigenvector is in charge sector 0.


    Attributes
    ----------
    L : int
        Number of physical sites involved in the transfer matrix, i.e. the least common multiple
        of `bra.L` and `ket.L`.
    shift_bra : int
        We start the `N` of the bra at site `shift_bra`.
    shift_ket : int
        We start the `M` of the ket at site `shift_ket`.
    transpose : bool
        Whether `self.matvec` acts on `RP` (``False``) or `LP` (``True``).
    qtotal : charges
        Total charge of the transfer matrix (which is gauged away in matvec).
    _bra_N : list of npc.Array
        Complex conjugated matrices of the bra, transposed for fast `matvec`.
    _ket_M : list of npc.Array
        The matrices of the ket, transposed for fast `matvec`.
    _pipe : :class:`~tenpy.linalg.charges.LegPipe`
        Pipe corresponding to ``'(vL.vL*)'`` for ``transpose=False``
        or to ``'(vR.vR*)'`` for ``transpose=True``.
    charge_sector : None | charges
        The charge sector of `RP` (`LP` for `transpose`) which is used for the :meth:`matvec`
        with a dense ndarray. ``None`` stands for *all* sectors.
    _mask : bool ndarray
        Selects the indices of the pipe which are used in `matvec`, mapping from
    """

    def __init__(self, bra, ket, shift_bra=0, shift_ket=None, transpose=False, charge_sector=0):
        L = self.L = lcm(bra.L, ket.L)
        if shift_ket is None:
            shift_ket = shift_bra
        self.shift_bra = shift_bra
        self.shift_ket = shift_ket
        self.transpose = transpose
        if ket.chinfo != bra.chinfo:
            raise ValueError("incompatible charges")
        if not transpose:  # right to left
            label = '(vL.vL*)'  # what we act on
            M = self._ket_M = [
                ket.get_B(i, form=None).itranspose(['vL', 'p', 'vR'])
                for i in reversed(range(shift_ket, shift_ket + L))
            ]
            N = self._bra_N = [
                bra.get_B(i, form=None).conj().itranspose(['p*', 'vR*', 'vL*'])
                for i in reversed(range(shift_bra, shift_bra + L))
            ]
            pipe = npc.LegPipe([M[0].get_leg('vR'), N[0].get_leg('vR*')], qconj=-1).conj()
        else:  # left to right
            label = '(vR.vR*)'
            M = self._ket_M = [
                ket.get_B(i, form=None).itranspose(['vR', 'p', 'vL'])
                for i in range(shift_ket, shift_ket + L)
            ]
            N = self._bra_N = [
                bra.get_B(i, form=None).conj().itranspose(['p*', 'vL*', 'vR*'])
                for i in range(shift_bra, shift_bra + L)
            ]
            pipe = npc.LegPipe([M[0].get_leg('vL'), N[0].get_leg('vL*')], qconj=-1).conj()
        dtype = np.promote_types(bra.dtype, ket.dtype)
        self.flat_linop = sparse.FlatLinearOperator(self.matvec, pipe, dtype, charge_sector, label)
        self.qtotal = bra.chinfo.make_valid(np.sum([B.qtotal for B in M + N], axis=0))

    def matvec(self, vec):
        """Given `vec` as an npc.Array, apply the transfer matrix.

        Parameters
        ----------
        vec : :class:`~tenpy.linalg.np_conserved.Array`
            Vector to act on with the transfer matrix.
            If not `transposed`, `vec` is the right part `RP` of an environment,
            with legs ``'(vL.vL*)'`` in a pipe or separately.
            If `transposed`, the left part `LP` of an environment with legs ``'(vR.vR*)'``.

        Returns
        -------
        mat_vec : :class:`~tenpy.linalg.np_conserved.Array`
            The transfer matrix applied to `vec`, in the same form as given: with the legs
            combined in a pipe if `vec` had a single leg, or as separate legs otherwise.
        """
        pipe = None
        if vec.rank == 1:
            vec = vec.split_legs(0)
            pipe = self.flat_linop.leg
        qtotal = vec.qtotal
        legs = vec.legs
        # the actual work
        if not self.transpose:  # right to left
            for N, M in zip(self._bra_N, self._ket_M):
                vec = npc.tensordot(M, vec, axes=['vR', 'vL'])
                vec = npc.tensordot(vec, N, axes=[['p', 'vL*'], ['p*', 'vR*']])
        else:  # left to right
            for N, M in zip(self._bra_N, self._ket_M):
                vec = npc.tensordot(M, vec, axes=['vL', 'vR'])
                vec = npc.tensordot(vec, N, axes=[['p', 'vR*'], ['p*', 'vL*']])
        if np.any(self.qtotal != 0):
            # Hack: replace leg charges and qtotal -> effectively gauge `self.qtotal` away.
            vec.qtotal = qtotal
            vec.legs = legs
            vec.test_sanity()  # Should be fine, but who knows...
        if pipe is not None:
            vec = vec.combine_legs([0, 1], pipes=pipe)
        return vec

    def eigenvectors(self, num_ev=1, max_num_ev=None, max_tol=1.e-12, which='LM', **kwargs):
        """Find (dominant) eigenvector(s) of self using scipy.sparse.

        If no charge_sector was selected, we look in *all* charge sectors.

        Parameters
        ----------
        num_ev : int
            Number of eigenvalues/vectors to look for.
        max_num_ev : int
            :func:`~tenpy.tools.math.speigs` sometimes raises an
            :class:`~scipy.sparse.linalg.ArpackNoConvergence` error for small `num_ev`,
            which might be avoided by increasing `num_ev`. As a work-around,
            we try it again in the case of an error, just with larger `num_ev` up to `max_num_ev`.
            ``None`` defaults to ``num_ev + 2``.
        max_tol : float
            After the first :class:`~scipy.sparse.linalg.ArpackNoConvergence` we increase the
            `tol` argument to (at least) that value.
        which : str
            Which eigenvalues to look for, see :func:`scipy.sparse.linalg.eigs`.
        **kwargs :
            Further keyword arguments are given to :func:`~tenpy.tools.math.speigs`.

        Returns
        -------
        eta : 1D ndarray
            The eigenvalues, sorted according to `which`.
        w : list of :class:`~tenpy.linalg.np_conserved.Array`
            The eigenvectors corresponding to `eta`, as npc.Array with LegPipe.
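
        Notes
        -----
        A minimal pure-Python sketch of the retry strategy used for a fixed charge sector,
        assuming a generic `solver` callable in place of :func:`~tenpy.tools.math.speigs`.
        `NoConvergence` and `flaky_solver` are hypothetical stand-ins for
        :class:`~scipy.sparse.linalg.ArpackNoConvergence` and an eigensolver that only
        converges for ``k >= 3``. The recorded calls show how `k` grows up to `max_num_ev`
        and how `tol` is raised to `max_tol` after the first failure.

        ```python
        import warnings


        class NoConvergence(Exception):
            # hypothetical stand-in for scipy.sparse.linalg.ArpackNoConvergence
            pass


        def eigs_with_retry(solver, num_ev, max_num_ev=None, max_tol=1.e-12, **kwargs):
            # call solver(k=k, **kwargs) for k = num_ev, ..., max_num_ev until it converges
            if max_num_ev is None:
                max_num_ev = num_ev + 2
            for k in range(num_ev, max_num_ev + 1):
                if k > num_ev:
                    warnings.warn("increased `num_ev` to " + str(k))
                try:
                    return solver(k=k, **kwargs)
                except NoConvergence:
                    if k == max_num_ev:
                        raise
                    # retry with larger k and a tolerance of at least max_tol
                    kwargs['tol'] = max(max_tol, kwargs.get('tol', 0))
            raise ValueError("need max_num_ev >= num_ev")


        calls = []


        def flaky_solver(k, tol=0.):
            # hypothetical solver: records its arguments and converges only for k >= 3
            calls.append((k, tol))
            if k < 3:
                raise NoConvergence()
            return [1.] * k


        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # silence the "increased `num_ev`" warnings
            result = eigs_with_retry(flaky_solver, num_ev=1)
        print(calls)  # [(1, 0.0), (2, 1e-12), (3, 1e-12)]
        ```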
        """
        if max_num_ev is None:
            max_num_ev = num_ev + 2
        flat_linop = self.flat_linop
        if flat_linop.charge_sector is None:
            # Try for all charge sectors
            eta = []
            A = []
            for chsect in flat_linop.possible_charge_sectors:
                flat_linop.charge_sector = chsect
                eta_cs, A_cs = self.eigenvectors(num_ev, max_num_ev, max_tol, which, **kwargs)
                eta.extend(eta_cs)
                A.extend(A_cs)
            flat_linop.charge_sector = None
        else:
            # for given charge sector
            for k in range(num_ev, max_num_ev + 1):
                if k > num_ev:
                    warnings.warn("increased `num_ev` to " + str(k))
                try:
                    eta, A = speigs(flat_linop, k=k, which=which, **kwargs)
                    A = np.real_if_close(A)
                    A = [flat_linop.flat_to_npc(A[:, j]) for j in range(A.shape[1])]
                    break
                except scipy.sparse.linalg.ArpackNoConvergence:
                    if k == max_num_ev:
                        raise
                    # just retry with larger k and 'tol'
                    kwargs['tol'] = max(max_tol, kwargs.get('tol', 0))
        # sort
        perm = argsort(eta, which)
        return np.array(eta)[perm], [A[j] for j in perm]
