Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ New Features
- Added a new Milky Way potential model: ``MilkyWayPotential2022``, which is based on
updated measurements of the disk structure and circular velocity curve of the disk.

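- Added the ability to use leapfrog or Ruth4 integration (in addition to the default DOP853) within ``DirectNBody``, selected with the new ``Integrator`` argument to ``DirectNBody.integrate_orbit()``.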


Bug fixes
---------
Expand Down
155 changes: 108 additions & 47 deletions gala/dynamics/nbody/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,28 @@
# Third-party
import numpy as np

from ...integrate.cyintegrators.leapfrog import leapfrog_integrate_nbody
from ...integrate.cyintegrators.ruth4 import ruth4_integrate_nbody
from ...integrate.timespec import parse_time_specification
from ...potential import Hamiltonian, NullPotential, StaticFrame
from ...units import UnitSystem
from ...util import atleast_2d
from ...integrate.timespec import parse_time_specification
from .. import Orbit, PhaseSpacePosition

from .nbody import direct_nbody_dop853, nbody_acceleration

# Public API of this module.
__all__ = ["DirectNBody"]


class DirectNBody:

def __init__(self, w0, particle_potentials, external_potential=None,
frame=None, units=None, save_all=True):
def __init__(
self,
w0,
particle_potentials,
external_potential=None,
frame=None,
units=None,
save_all=True,
):
"""Perform orbit integration using direct N-body forces between
particles, optionally in an external background potential.

Expand Down Expand Up @@ -50,21 +57,27 @@ def __init__(self, w0, particle_potentials, external_potential=None,

"""
if not isinstance(w0, PhaseSpacePosition):
raise TypeError("Initial conditions `w0` must be a "
"gala.dynamics.PhaseSpacePosition object, "
"not '{}'".format(w0.__class__.__name__))
raise TypeError(
"Initial conditions `w0` must be a "
"gala.dynamics.PhaseSpacePosition object, "
"not '{}'".format(w0.__class__.__name__)
)

if len(w0.shape) > 0:
if w0.shape[0] != len(particle_potentials):
raise ValueError("The number of initial conditions in `w0` must"
" match the number of particle potentials "
"passed in with `particle_potentials`.")
raise ValueError(
"The number of initial conditions in `w0` must"
" match the number of particle potentials "
"passed in with `particle_potentials`."
)

# TODO: this is a MAJOR HACK
if w0.shape[0] > 524288: # see MAX_NBODY in _nbody.pyx
raise NotImplementedError("We currently only support direct "
"N-body integration for <= 524288 "
"particles.")
raise NotImplementedError(
"We currently only support direct "
"N-body integration for <= 524288 "
"particles."
)

# First, figure out how to get units - first place to check is the arg
if units is None:
Expand All @@ -80,11 +93,13 @@ def __init__(self, w0, particle_potentials, external_potential=None,

# Now, if units are still None, raise an error!
if units is None:
raise ValueError("Could not determine units from input! You must "
"either (1) pass in the unit system with `units`,"
"(2) set the units on one of the "
"particle_potentials, OR (3) pass in an "
"`external_potential` with valid units.")
raise ValueError(
"Could not determine units from input! You must "
"either (1) pass in the unit system with `units`,"
"(2) set the units on one of the "
"particle_potentials, OR (3) pass in an "
"`external_potential` with valid units."
)
if not isinstance(units, UnitSystem):
units = UnitSystem(units)

Expand Down Expand Up @@ -112,12 +127,13 @@ def __init__(self, w0, particle_potentials, external_potential=None,
self.particle_potentials = _particle_potentials
self.save_all = save_all

self.H = Hamiltonian(self.external_potential,
frame=self.frame)
self.H = Hamiltonian(self.external_potential, frame=self.frame)
if not self.H.c_enabled:
raise ValueError("Input potential must be C-enabled: one or more "
"components in the input external potential are "
"Python-only.")
raise ValueError(
"Input potential must be C-enabled: one or more "
"components in the input external potential are "
"Python-only."
)

self.w0 = w0

Expand All @@ -132,36 +148,33 @@ def w0(self, value):

def _cache_w0(self):
    """Cache the initial conditions as plain arrays in the internal unit system.

    Positions and velocities from ``self.w0`` are decomposed into
    ``self.units`` and stripped of their units. They are then promoted with
    ``atleast_2d(..., insert_axis=1)``, presumably so that a single body is
    handled like a one-element ensemble (TODO confirm against ``atleast_2d``).
    The stacked, transposed, C-contiguous array stored in ``self._c_w0`` is
    what gets handed to the compiled N-body routines.

    Sets ``self._pos``, ``self._vel`` and ``self._c_w0``.
    """
    self._pos = atleast_2d(self.w0.xyz.decompose(self.units).value, insert_axis=1)
    self._vel = atleast_2d(self.w0.v_xyz.decompose(self.units).value, insert_axis=1)
    # Stack positions on top of velocities, then transpose so each row holds
    # one body's phase-space coordinates; the C code needs contiguous memory.
    self._c_w0 = np.ascontiguousarray(np.vstack((self._pos, self._vel)).T)

def __repr__(self):
if self.w0.shape:
return "<{} bodies={}>".format(self.__class__.__name__,
self.w0.shape[0])
return "<{} bodies={}>".format(self.__class__.__name__, self.w0.shape[0])
else:
return "<{} bodies=1>".format(self.__class__.__name__)

def _nbody_acceleration(self, t=0.0):
    """Compute the mutual N-body acceleration at the location of each body.

    Only the particle-particle forces from ``self.particle_potentials`` are
    included here. The external potential is added separately by
    ``acceleration``.

    Parameters
    ----------
    t : float, optional
        Evaluation time, passed straight through to ``nbody_acceleration``.

    Returns
    -------
    ndarray
        The transpose of the array returned by ``nbody_acceleration``,
        evaluated at the cached, unitless initial conditions ``self._c_w0``.
    """
    nbody_acc = nbody_acceleration(self._c_w0, t, self.particle_potentials)
    return nbody_acc.T

def acceleration(self, t=0.):
def acceleration(self, t=0.0):
"""
Compute the acceleration at the location of each N body, including the
external potential.
"""
nbody_acc = self._nbody_acceleration(t=t) * self.units['acceleration']
nbody_acc = self._nbody_acceleration(t=t) * self.units["acceleration"]
ext_acc = self.external_potential.acceleration(self.w0, t=t)
return nbody_acc + ext_acc

def integrate_orbit(self, Integrator=None, Integrator_kwargs=None, **time_spec):
    """
    Integrate the initial conditions in the combined external potential
    plus N-body forces.

    Parameters
    ----------
    Integrator : class, optional
        The integrator class to use. Supported values are
        ``DOPRI853Integrator`` (the default), ``LeapfrogIntegrator`` and
        ``Ruth4Integrator`` from ``gala.integrate``. Any other value raises
        ``NotImplementedError``.
    Integrator_kwargs : dict, optional
        Accepted for API compatibility but currently ignored: none of the
        N-body integration routines called here take extra options.
    **time_spec
        Specification of the integration times. It is passed to
        ``parse_time_specification`` together with ``self.units``.

    Returns
    -------
    orbits : `~gala.dynamics.Orbit` or `~gala.dynamics.PhaseSpacePosition`
        If ``self.save_all`` is true, the full orbits of the particles.
        Otherwise, only their final phase-space positions. Particles are
        returned in the same order as the input ``w0``.

    Raises
    ------
    NotImplementedError
        If ``Integrator`` is not one of the supported integrator classes.
    """
    from gala.integrate import (
        DOPRI853Integrator,
        LeapfrogIntegrator,
        Ruth4Integrator,
    )

    if Integrator is None:
        Integrator = DOPRI853Integrator

    # Prepare the time-stepping array
    t = parse_time_specification(self.units, **time_spec)

    # The N-body routines assume every massive body comes before every test
    # particle (NullPotential). A stable argsort of the "is test particle"
    # flags (False < True) moves test particles to the end and keeps the
    # relative order within each group.
    is_test_particle = np.array(
        [isinstance(pp, NullPotential) for pp in self.particle_potentials],
        dtype=bool,
    )
    idx = np.argsort(is_test_particle, kind="stable")
    pps = [self.particle_potentials[i] for i in idx]
    reorg_w0 = np.ascontiguousarray(self._c_w0[idx])

    if Integrator is LeapfrogIntegrator:
        _, ws = leapfrog_integrate_nbody(
            self.H, reorg_w0, t, pps, store_all=int(self.save_all)
        )
    elif Integrator is Ruth4Integrator:
        _, ws = ruth4_integrate_nbody(
            self.H, reorg_w0, t, pps, store_all=int(self.save_all)
        )
    elif Integrator is DOPRI853Integrator:
        ws = direct_nbody_dop853(reorg_w0, t, self.H, pps, save_all=self.save_all)
    else:
        raise NotImplementedError(
            f"N-body integration is currently not supported with the {Integrator} "
            "integrator class"
        )

    if self.save_all:
        # ws is presumably (ntimes, nparticles, 6) here (TODO confirm for every
        # routine). For a 3-D array, axis=2 is the same as axis=-1; rolling it
        # to the front puts the coordinate axis first.
        pos = np.rollaxis(np.array(ws[..., :3]), axis=2)
        vel = np.rollaxis(np.array(ws[..., 3:]), axis=2)

        orbits = Orbit(
            pos=pos * self.units["length"],
            vel=vel * self.units["length"] / self.units["time"],
            t=t * self.units["time"],
            hamiltonian=self.H,
        )

    else:
        # Final state only: presumably (nparticles, 6), so transpose to put the
        # coordinate axis first.
        pos = np.array(ws[..., :3]).T
        vel = np.array(ws[..., 3:]).T

        orbits = PhaseSpacePosition(
            pos=pos * self.units["length"],
            vel=vel * self.units["length"] / self.units["time"],
            frame=self.frame,
        )

    # Undo the massive-first reordering. The argsort of a permutation is its
    # inverse, so no (n, n) index matrix is needed; that matrix would take
    # O(n**2) memory for large particle counts.
    undo_idx = np.argsort(idx)
    return orbits[..., undo_idx]
15 changes: 11 additions & 4 deletions gala/dynamics/nbody/nbody.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ np.import_array()
from libc.math cimport sqrt
from cpython.exc cimport PyErr_CheckSignals

from ...potential import Hamiltonian
from ...potential import Hamiltonian, NullPotential
from ...potential.potential.cpotential cimport (CPotentialWrapper,
MAX_N_COMPONENTS, CPotential)
from ...potential.frame.cframe cimport CFrameWrapper
Expand Down Expand Up @@ -58,9 +58,13 @@ cpdef direct_nbody_dop853(double [:, ::1] w0, double[::1] t,
By default, this integration procedure stores the full time series of all
orbits, but this may use a lot of memory. If you just want to store the
final state of the orbits, pass ``save_all=False``.

NOTE: This assumes that all massive bodies are organized at the start of w0 and
particle_potentials, and all test particles are *after* the massive bodies.
"""
cdef:
unsigned nparticles = w0.shape[0]
unsigned nbody = 0
unsigned ndim = w0.shape[1]
unsigned ntimes = len(t)

Expand All @@ -85,25 +89,28 @@ cpdef direct_nbody_dop853(double [:, ::1] w0, double[::1] t,
f"of particle potentials passed in ({nparticles} vs. "
f"{len(particle_potentials)}).")

for pot in particle_potentials:
if not isinstance(pot, NullPotential):
nbody += 1

# Extract the CPotential objects from the particle potentials.
for i in range(nparticles):
c_particle_potentials[i] = &(<CPotentialWrapper>(particle_potentials[i].c_instance)).cpotential

# We need a void pointer for any other arguments
args = <void *>(&c_particle_potentials[0])

# TODONOW: fix below - need to pass in how many massive particles, total number of orbits and
if save_all:
all_w = dop853_helper_save_all(&cp, &cf,
<FcnEqDiff> Fwrapper_direct_nbody,
w0, t,
ndim, nparticles, nparticles, args,
ndim, nparticles, nbody, args,
ntimes, atol, rtol, nmax, 0)
else:
all_w = dop853_helper(&cp, &cf,
<FcnEqDiff> Fwrapper_direct_nbody,
w0, t,
ndim, nparticles, nparticles, args, ntimes,
ndim, nparticles, nbody, args, ntimes,
atol, rtol, nmax, 0)
all_w = np.array(all_w).reshape(nparticles, ndim)

Expand Down
Loading