# cython: boundscheck=False
# cython: nonecheck=False
# cython: cdivision=True
# cython: wraparound=False
# cython: profile=False
# cython: language_level=3

""" Fourth-order Ruth symplectic integration in Cython. """

# Third-party
import numpy as np
cimport numpy as np
np.import_array()

# Project
from ...potential.potential.cpotential cimport CPotentialWrapper
from ...potential.frame import StaticFrame

cdef extern from "frame/src/cframe.h":
    ctypedef struct CFrameType:
        pass

cdef extern from "potential/src/cpotential.h":
    ctypedef struct CPotential:
        pass

cdef extern from "potential/src/cpotential.h":
    void c_gradient(CPotential *p, double t, double *q, double *grad) nogil


cdef void c_ruth4_step(CPotential *p, int half_ndim, double t, double dt,
                       double *cs, double *ds,
                       double *w, double *grad) nogil:
    """Advance one phase-space vector by a single 4th-order Ruth step.

    The step is four stages. Each stage kicks the momenta with weight
    ``ds[j]`` using the potential gradient at the current positions, then
    drifts the positions with weight ``cs[j]`` using the updated momenta.

    Parameters
    ----------
    p : CPotential*
        C-level potential used to evaluate the gradient.
    half_ndim : int
        Number of position coordinates. ``w`` holds ``2 * half_ndim``
        values laid out as ``[q..., p...]``.
    t : double
        Time at the *start* of the step. The stage time is advanced by
        ``cs[j] * dt`` after every drift, so each gradient is evaluated at
        the time the positions actually correspond to (this only matters
        for time-dependent potentials).
    dt : double
        Step size.
    cs, ds : double*
        Length-4 drift and kick coefficient arrays.
    w : double*
        Phase-space vector, updated in place.
    grad : double*
        Scratch buffer of length ``half_ndim``; overwritten.
    """
    cdef:
        int j, k
        double t_stage = t

    for j in range(4):
        # A zero kick coefficient (ds[0] for the Ruth scheme) leaves the
        # momenta unchanged, so skip the costly gradient evaluation.
        if ds[j] != 0.:
            # The buffer is zeroed before every call -- presumably
            # c_gradient accumulates into ``grad`` (TODO confirm).
            for k in range(half_ndim):
                grad[k] = 0.
            c_gradient(p, t_stage, w, grad)

            # kick: p <- p - d_j * dPhi/dq * dt
            for k in range(half_ndim):
                w[half_ndim + k] = w[half_ndim + k] - ds[j] * grad[k] * dt

        # drift: q <- q + c_j * p * dt, using the freshly kicked momenta
        for k in range(half_ndim):
            w[k] = w[k] + cs[j] * w[half_ndim + k] * dt

        # The positions now correspond to a later time; keep the clock in
        # sync so the next kick sees the matching potential.
        t_stage = t_stage + cs[j] * dt

cpdef ruth4_integrate_hamiltonian(hamiltonian,
                                  double[:, ::1] w0,
                                  double[::1] t):
    """
    Integrate orbits with the 4th-order Ruth symplectic scheme.

    CAUTION: Interpretation of axes is different here! We need the
    arrays to be C ordered and easy to iterate over, so here the
    axes are (norbits, ndim).

    Parameters
    ----------
    hamiltonian
        Must expose ``c_enabled``, ``frame`` (a ``StaticFrame``) and
        ``potential.c_instance`` (a ``CPotentialWrapper``).
    w0 : double[:, ::1]
        Initial conditions, shape ``(norbits, ndim)``. Each row is laid out
        as ``[q..., p...]``, so ``ndim`` must be a positive even number.
    t : double[::1]
        Output times. Only ``t[1] - t[0]`` is used as the step size, so the
        time grid is assumed to be uniformly spaced.

    Returns
    -------
    t : ndarray
        The input time array.
    w : ndarray
        Shape ``(ntimes, norbits, ndim)``; ``w[0]`` equals ``w0``.

    Raises
    ------
    TypeError
        If the Hamiltonian lacks C-level support or its frame is not a
        ``StaticFrame``.
    ValueError
        If ``t`` is empty or ``ndim`` is not a positive even number.
    """

    if not hamiltonian.c_enabled:
        raise TypeError("Input Hamiltonian object does not support C-level access.")

    if not isinstance(hamiltonian.frame, StaticFrame):
        raise TypeError("Ruth4 integration is currently only supported "
                        "for StaticFrame, not {}."
                        .format(hamiltonian.frame.__class__.__name__))

    # Bounds checking is disabled for this module, so invalid shapes must be
    # rejected here rather than turning into out-of-bounds memory access.
    if t.shape[0] < 1:
        raise ValueError("Time array must contain at least one time.")

    if w0.shape[1] < 2 or w0.shape[1] % 2 != 0:
        raise ValueError("Phase-space dimensionality must be a positive "
                         "even number, got {}.".format(w0.shape[1]))

    cdef:
        # temporary scalars
        int i, j, k
        int n = w0.shape[0]
        int ndim = w0.shape[1]
        int half_ndim = ndim // 2

        int ntimes = t.shape[0]
        double dt = 0.

        # Integrator coefficients
        double two_13 = 2 ** (1./3.)
        double[::1] cs = np.array([
            1. / (2. * (2. - two_13)),
            (1. - two_13) / (2.*(2. - two_13)),
            (1. - two_13) / (2.*(2. - two_13)),
            1. / (2.*(2. - two_13))
        ], dtype='f8')

        double[::1] ds = np.array([
            0.,
            1. / (2. - two_13),
            -two_13 / (2. - two_13),
            1. / (2. - two_13)
        ], dtype='f8')

        # scratch buffer for the gradient, reused by every orbit and step
        double[::1] grad = np.zeros(half_ndim)

        # return array, shape (ntimes, norbits, ndim)
        double[:, :, ::1] all_w = np.zeros((ntimes, n, ndim))

        # Pull the C-level potential struct out of the Python wrapper here,
        # since the integration loop below runs without the GIL.
        CPotential cp = (<CPotentialWrapper>(hamiltonian.potential.c_instance)).cpotential

    # A single time has no step to take; only read t[1] when it exists.
    if ntimes > 1:
        dt = t[1] - t[0]

    # save initial conditions (memoryview slice assignment copies the data)
    all_w[0, :, :] = w0

    with nogil:

        for j in range(1, ntimes):
            for i in range(n):
                # Start from the previous output state and advance it in place
                for k in range(ndim):
                    all_w[j, i, k] = all_w[j-1, i, k]

                # Step from t[j-1] to t[j]; the step advances the stage time
                # internally from this start time.
                c_ruth4_step(&cp, half_ndim, t[j-1], dt,
                             &cs[0], &ds[0],
                             &all_w[j, i, 0], &grad[0])

    return np.asarray(t), np.asarray(all_w)
