reaxion is a flexible, object-oriented implementation for systems of ISM microphysics and chemistry equations, with numerical solvers implemented in JAX, and interfaces for embedding the equations and their Jacobians into other codes.
reaxion might be interesting because it combines two powerful concepts:
- Object-oriented implementation of microphysics and chemistry via the Process class, which implements methods for representing physical processes and composing them into a network in a fully symbolic sympy representation. OOP is nice here because if you want to add a new process to reaxion, you typically only have to do it in one file, and rate expressions never have to be repeated in code. Most processes one would want to implement follow very common patterns (e.g. 2-body processes), so class inheritance is also used to minimize new lines of code. Once you've constructed your system, reaxion can give you the symbolic equations to manipulate and analyze as you please. If you want to solve the equations numerically, Process has methods for substituting known values into numerical solvers. It can also automatically generate compilable implementations of the RHS of the system to embed in your choice of simulation code and plug into your choice of solver.
- Fast, differentiable implementation of nonlinear algebraic and differential-algebraic equation solvers with JAX, written in its functional programming paradigm (e.g. reaxion.numerics.newton_rootsolve). These can achieve excellent numerical throughput running natively on GPUs; in fact, crunching through many independent iterates in place is essentially the best-case application of numerics on GPUs. Differentiability enables sensitivity analysis with respect to all parameters in a single pass, instead of constructing a grid of N parameter variations for N parameters. In principle, this makes it easier to directly answer questions like "How sensitive is this temperature to the abundance of C or the ionization energy of H?".
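The Newton iteration and the sensitivity idea can be illustrated on a one-variable toy problem. This is a plain-Python sketch, not reaxion's newton_rootsolve (whose signature is not shown here): it solves pure-hydrogen ionization balance for made-up, dimensionless rate coefficients, then obtains dx/dk_ion from the implicit function theorem, which is the quantity that differentiating through a converged root solve is meant to deliver.

```python
# Toy pure-hydrogen ionization balance for x = n_H+ / n_Htot with n_e = x * n_Htot.
# Dropping the common factor n_Htot**2, the net rate of change of x is
#     f(x) = k_ion * (1 - x) * x - k_rec * x**2.


def f(x, k_ion, k_rec):
    return k_ion * (1.0 - x) * x - k_rec * x * x


def df_dx(x, k_ion, k_rec):
    return k_ion * (1.0 - 2.0 * x) - 2.0 * k_rec * x


def newton_solve(x0, k_ion, k_rec, tol=1e-12, max_iter=50):
    """Plain Newton-Raphson: x <- x - f(x) / f'(x) until the step is tiny."""
    x = x0
    for _ in range(max_iter):
        step = f(x, k_ion, k_rec) / df_dx(x, k_ion, k_rec)
        x -= step
        if abs(step) < tol:
            return x
    raise RuntimeError("Newton iteration did not converge")


k_ion, k_rec = 3.0, 1.0  # illustrative, dimensionless rate coefficients
x = newton_solve(0.5, k_ion, k_rec)  # converges to k_ion / (k_ion + k_rec) = 0.75

# Implicit function theorem: f(x(k_ion), k_ion) = 0 implies
#     dx/dk_ion = -(df/dk_ion) / (df/dx),
# evaluated once at the converged root; no grid of perturbed solves is needed.
df_dkion = (1.0 - x) * x
dx_dkion = -df_dkion / df_dx(x, k_ion, k_rec)
print(x, dx_dkion)  # analytic values: 0.75 and k_rec / (k_ion + k_rec)**2 = 0.0625
```

Starting from x0 = 0.01 instead, the same iteration converges to the trivial root x = 0 (no free electrons, so nothing ever ionizes), a one-variable version of the initial-guess sensitivity noted in the roadmap below.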
reaxion is in an early prototyping phase right now. Here are some things I would eventually like to add:
- Flexible implementation of a reduced network suitable for RHD simulations in GIZMO and potentially other codes.
- Dust and radiation physics: add the dust energy equation and evolution of photon number densities to the network.
- Interfaces to convert from other existing chemistry network formats to the Process representation.
- Solver robustness upgrades: thermochemical networks can be quite challenging numerically, due to how steeply terms switch on with increasing T. It can be hard to get a solution without good initial guesses.
- If possible, a glue interface allowing an existing compiled hydro code to call the JAX solvers on the fly.
Clone the repo and run pip install . from the repository root, or install the latest release from PyPI via pip install reaxion.
Example of using reaxion to solve for collisional ionization equilibrium (CIE) for a hydrogen-helium mixture and plot the ionization states as a function of temperature.
%matplotlib inline
%config InlineBackend.figure_format='retina'
import numpy as np
from matplotlib import pyplot as plt
import sympy as sp
A simple process is defined by a single reaction with a specified rate.
Let's inspect the structure of a single process, the gas-phase recombination of H+: H+ + e- -> H + hν
from reaxion.processes import CollisionalIonization, GasPhaseRecombination
process = GasPhaseRecombination("H+")
print(f"Name: {process.name}")
print(f"Heating rate coefficient: {process.heat_rate_coefficient}")
print(f"Heating rate per cm^-3: {process.heat}")
print(f"Rate coefficient: {process.rate_coefficient}")
print(f"Recombination rate per cm^-3: {process.rate}")
print(f"RHS of e- number density equation: {process.network['e-']}")
Name: Gas-phase recombination of H+
Heating rate coefficient: -1.46719838641439e-26*sqrt(T)/((0.00119216696847702*sqrt(T) + 1.0)**1.748*(0.563615123664978*sqrt(T) + 1.0)**0.252)
Heating rate per cm^-3: -1.46719838641439e-26*sqrt(T)*n_H+*n_e-/((0.00119216696847702*sqrt(T) + 1.0)**1.748*(0.563615123664978*sqrt(T) + 1.0)**0.252)
Rate coefficient: 1.41621465870114e-10/(sqrt(T)*(0.00119216696847702*sqrt(T) + 1.0)**1.748*(0.563615123664978*sqrt(T) + 1.0)**0.252)
Recombination rate per cm^-3: 1.41621465870114e-10*n_H+*n_e-/(sqrt(T)*(0.00119216696847702*sqrt(T) + 1.0)**1.748*(0.563615123664978*sqrt(T) + 1.0)**0.252)
RHS of e- number density equation: Eq(Derivative(n_e-(t), t), -1.41621465870114e-10*n_H+*n_e-/(sqrt(T)*(0.00119216696847702*sqrt(T) + 1.0)**1.748*(0.563615123664978*sqrt(T) + 1.0)**0.252))
Note that all symbolic representations assume CGS units as is standard in ISM physics.
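Because these are ordinary sympy expressions in CGS units, they can be evaluated directly. To keep the sketch independent of reaxion, the snippet below hard-codes the rate coefficient printed above and evaluates it at T = 10^4 K; the recombination rate per unit volume is that coefficient times n_H+ n_e-. The densities are illustrative values. With reaxion installed, one could instead turn process.rate_coefficient into a numerical function with sympy.lambdify.

```python
import math


def alpha_H_plus(T):
    """H+ recombination rate coefficient in cm^3 s^-1 (T in K),
    copied from the sympy expression printed above."""
    sqrtT = math.sqrt(T)
    return 1.41621465870114e-10 / (
        sqrtT
        * (0.00119216696847702 * sqrtT + 1.0) ** 1.748
        * (0.563615123664978 * sqrtT + 1.0) ** 0.252
    )


T = 1e4                        # K
n_Hplus, n_e = 1.0, 1.0        # cm^-3, illustrative values
alpha = alpha_H_plus(T)
rate = alpha * n_Hplus * n_e   # recombinations per cm^3 per s
print(f"alpha(1e4 K) = {alpha:.3e} cm^3/s, rate = {rate:.3e} cm^-3 s^-1")
```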
Now let's define our full network as a sum of simple processes:
processes = [CollisionalIonization(s) for s in ("H", "He", "He+")] + [GasPhaseRecombination(i) for i in ("H+", "He+", "He++")]
system = sum(processes)
system.subprocesses
[Collisional Ionization of H,
Collisional Ionization of He,
Collisional Ionization of He+,
Gas-phase recombination of H+,
Gas-phase recombination of He+,
Gas-phase recombination of He++]
Summed processes keep track of all their subprocesses, and summing processes also sums all chemical and gas/dust cooling/heating rates. For example, the total net heating rate of the network is:
system.heat
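The summing behaviour can be sketched with Python operator overloading. The class below is a hypothetical, numbers-only stand-in for Process (reaxion's actual implementation is symbolic and is not reproduced here); the detail worth noting is __radd__, which is what lets the built-in sum() start from the integer 0.

```python
class ToyProcess:
    """Hypothetical, numbers-only stand-in for reaxion's Process class.

    reaxion stores sympy expressions; here `heat` is a plain float so the
    sketch runs without reaxion installed.
    """

    def __init__(self, name, heat=0.0, subprocesses=None):
        self.name = name
        self.heat = heat  # net heating rate (negative means net cooling)
        # A single process is its own only subprocess.
        self.subprocesses = subprocesses if subprocesses is not None else [self]

    def __add__(self, other):
        # Combining two processes concatenates their subprocess lists
        # and adds their rates.
        return ToyProcess(
            name=f"{self.name} + {other.name}",
            heat=self.heat + other.heat,
            subprocesses=self.subprocesses + other.subprocesses,
        )

    def __radd__(self, other):
        # sum() starts from the integer 0, so 0 + process must give the process back.
        if other == 0:
            return self
        return NotImplemented

    def __repr__(self):
        return self.name


toy_system = sum([ToyProcess("ionization", heat=-2.0),
                  ToyProcess("recombination", heat=-0.5)])
print(toy_system.subprocesses)  # [ionization, recombination]
print(toy_system.heat)          # -2.5
```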
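We would like to solve for ionization equilibrium given a temperature T and total hydrogen number density n_Htot. Here we use 10^6 temperatures log-spaced between 10^3 and 10^6 K at a fixed n_Htot = 100 cm^-3, together with initial guesses for the abundances: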
Tgrid = np.logspace(3,6,10**6)
ngrid = np.ones_like(Tgrid) * 100
knowns = {"T": Tgrid, "n_Htot": ngrid}
guesses = {
"H": 0.5*np.ones_like(Tgrid),
"He": 1e-5*np.ones_like(Tgrid),
"He+": 1e-5*np.ones_like(Tgrid)
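}
Note that by default, the solver only directly solves for the abundances of H, He and He+ (the species in guesses); the others (H+, He++ and e-) follow from conservation of hydrogen, helium and charge.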
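The guesses above cover only H, He and He+. In the generated C code further down, the electron abundance appears as x8 = -(x_H - 1) - 2*x_He - x_He+ + 2*y and the He++ abundance as x21 = -x_He - x_He+ + y, i.e. the remaining species are fixed by conservation of hydrogen, helium and charge, with y the helium abundance per H nucleus. A minimal sketch of that closure, assuming abundances are normalized to n_Htot (as the solution printed below suggests) and treating y as an explicit input, since its default value in reaxion is not shown here:

```python
def closure(x_H, x_He, x_Heplus, y):
    """Derived abundances per H nucleus (x_i = n_i / n_Htot).

    y is the total helium abundance per H nucleus; it is an assumed,
    explicit input in this sketch.
    """
    x_Hplus = 1.0 - x_H                              # hydrogen conservation
    x_Heplusplus = y - x_He - x_Heplus               # helium conservation
    x_e = x_Hplus + x_Heplus + 2.0 * x_Heplusplus    # charge neutrality
    return {"H+": x_Hplus, "He++": x_Heplusplus, "e-": x_e}


# Fully ionized gas: x_e = 1 + 2y
print(closure(0.0, 0.0, 0.0, y=0.0925))
```

The value y = 0.0925 is only roughly the total helium abundance implied by the printed solution; it is not a documented reaxion default.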
The solve method calls the JAX solver and computes the solution:
sol = system.solve(knowns, guesses,tol=1e-3)
print(sol)
{'He': Array([9.2546351e-02, 9.2546351e-02, 9.2546351e-02, ..., 2.7493625e-09,
2.7493037e-09, 2.7492442e-09], dtype=float32), 'H': Array([9.9999994e-01, 9.9999994e-01, 9.9999994e-01, ..., 6.0612075e-07,
6.0611501e-07, 6.0610921e-07], dtype=float32), 'He+': Array([3.1222404e-13, 3.1222396e-13, 3.1222374e-13, ..., 7.6922206e-06,
7.6921306e-06, 7.6920396e-06], dtype=float32), 'He++': Array([0. , 0. , 0. , ..., 0.09253865, 0.09253865,
0.09253865], dtype=float32), 'H+': Array([5.9604645e-08, 5.9604645e-08, 5.9604645e-08, ..., 9.9999940e-01,
9.9999940e-01, 9.9999940e-01], dtype=float32), 'e-': Array([5.9604957e-08, 5.9604957e-08, 5.9604957e-08, ..., 1.1850843e+00,
1.1850843e+00, 1.1850843e+00], dtype=float32)}
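Because this network has no charge exchange, the hydrogen balance decouples: setting H ionization equal to H+ recombination, k_ion n_H n_e = k_rec n_H+ n_e, the electron density cancels and x_H/x_H+ = k_rec/k_ion. As a quick sanity check of the solver output, the sketch below evaluates that ratio at the hot end of the grid (T = 10^6 K), hard-coding the H+ recombination coefficient printed earlier and the H collisional ionization coefficient read off the generated C code further down. It should reproduce the last 'H' entry above (about 6.06e-7).

```python
import math

T = 1e6  # K, the last point of Tgrid
sqrtT = math.sqrt(T)

# H collisional ionization rate coefficient (cm^3/s), as in the generated code:
# 5.85e-11 * sqrt(T) * exp(-157809.1/T) / (1 + sqrt(10)*sqrt(T)/1000)
k_ion = 5.85e-11 * sqrtT * math.exp(-157809.1 / T) / (1.0 + math.sqrt(10.0) * sqrtT / 1000.0)

# H+ recombination rate coefficient (cm^3/s), printed earlier
k_rec = 1.41621465870114e-10 / (
    sqrtT
    * (0.00119216696847702 * sqrtT + 1.0) ** 1.748
    * (0.563615123664978 * sqrtT + 1.0) ** 0.252
)

# n_e cancels in the H balance, so x_H / x_H+ = k_rec / k_ion with x_H + x_H+ = 1
ratio = k_rec / k_ion
x_H = ratio / (1.0 + ratio)
print(f"x_H(1e6 K) = {x_H:.5e}")
```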
for i, xi in sorted(sol.items()):
plt.loglog(Tgrid, xi, label=i)
plt.legend(labelspacing=0)
plt.ylabel("$x_i$")
plt.xlabel("T (K)")
plt.ylim(1e-4, 3)
Suppose you just want the RHS of the system you're solving, or its Jacobian, because you have a better solver and/or want to embed these equations in some old C or Fortran code without any dependencies. You can do that too with generate_code.
print(system.generate_code(('H','He','He+'), language='c'))
# Computes the RHS function and Jacobianto solve for [x_He, x_H, x_Heplus]
# INDEX CONVENTION: (0: x_He) (1: x_H) (2: x_Heplus)
x0 = 1.0/T;
x1 = sqrt(T);
x2 = pow(n_Htot, 2);
x3 = 1.0/((1.0/1000.0)*sqrt(10)*x1 + 1);
x4 = x1*x2*x3;
x5 = x4*exp(-285335.40000000002*x0);
x6 = x5*x_He;
x7 = x_H - 1;
x8 = -x7 - 2*x_He - x_Heplus + 2*y;
x9 = 2.3800000000000001e-11*x8;
x10 = 1.0/x1;
x11 = x2*(0.0019*pow(T, -1.5)*(1 + 0.29999999999999999*exp(-94000.0*x0))*exp(-470000.0*x0) + 1.9324160622805846e-10*x10*pow(0.00016493478118851054*x1 + 1.0, -1.7891999999999999)*pow(4.8416074481177231*x1 + 1.0, -0.21079999999999999));
x12 = x11*x_Heplus;
x13 = -x12*x8 + x6*x9;
x14 = exp(-157809.10000000001*x0);
x15 = x14*x4;
x16 = x15*x_H;
x17 = 5.8500000000000005e-11*x16;
x18 = -x7;
x19 = pow(0.0011921669684770192*x1 + 1.0, -1.748);
x20 = pow(0.56361512366497779*x1 + 1.0, -0.252);
x21 = -x_He - x_Heplus + y;
x22 = x10*x2;
x23 = x22*pow(0.00059608348423850961*x1 + 1.0, -1.748)*pow(0.2818075618324889*x1 + 1.0, -0.252);
x24 = 5.664858634804579e-10*x23;
x25 = x24*x8;
x26 = exp(-631515*x0);
x27 = x26*x4;
x28 = 4.7600000000000002e-11*x6;
x29 = x5*x9;
x30 = 2*x12;
x31 = -x12 + 2.3800000000000001e-11*x6;
x32 = x11*x8 + x31;
x33 = x19*x20*x22;
x34 = x18*x33;
x35 = 1.4162146587011448e-10*x34;
x36 = -5.68e-12*x1*x2*x26*x3*x_Heplus + x21*x24;
rhs_result[0] = -x13;
rhs_result[1] = 1.4162146587011448e-10*x10*x18*x19*x2*x20*x8 - x17*x8;
rhs_result[2] = x13 + x21*x25 - 5.68e-12*x27*x8*x_Heplus;
jac_result[0] = x28 - x29 - x30;
jac_result[1] = x31;
jac_result[2] = x32;
jac_result[3] = 1.1700000000000001e-10*x16 - 2.8324293174022895e-10*x34;
jac_result[4] = 5.8500000000000005e-11*x1*x14*x2*x3*x_H - 5.8500000000000005e-11*x15*x8 - 1.4162146587011448e-10*x33*x8 - x35;
jac_result[5] = x17 - x35;
jac_result[6] = 1.136e-11*x1*x2*x26*x3*x_Heplus - 1.1329717269609158e-9*x21*x23 - x25 - x28 + x29 + x30;
jac_result[7] = -x31 - x36;
jac_result[8] = -x25 - 5.68e-12*x27*x8 - x32 - x36;
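In the generated C, rhs_result[i] is the i-th RHS component (in the index order given in the header comment) and jac_result is a flat array of 9 entries. Differentiating by hand suggests the layout is row-major with J_ij = ∂f_i/∂x_j: for example, jac_result[1] = x31 = -x12 + 2.38e-11*x6 is exactly ∂(rhs_result[0])/∂x_H, since x_H enters rhs_result[0] only through x8, with ∂x8/∂x_H = -1. The helpers below are hypothetical conveniences for that convention, not part of reaxion.

```python
# Hypothetical helpers (not part of reaxion) for the flat, row-major layout
# of jac_result, where entry [i * n + j] holds d f_i / d x_j.


def jac_index(i, j, n):
    """Flat index of d f_i / d x_j in an n-by-n row-major array."""
    return i * n + j


def unflatten(flat, n):
    """Split a flat row-major list of length n*n into n rows."""
    return [list(flat[i * n:(i + 1) * n]) for i in range(n)]


# Index convention from the generated header: 0 -> x_He, 1 -> x_H, 2 -> x_Heplus
order = ["x_He", "x_H", "x_Heplus"]
k = jac_index(order.index("x_H"), order.index("x_He"), len(order))
print(k)  # 3: jac_result[3] is d(rhs_result[1]) / d x_He
```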
Let's break down what happened there. First, reaxion generates the symbolic functions needed to solve the system, just as it does before solving the system with its own solver:
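func, jac, _ = system.network.solver_functions(('H','He','He+'), return_jac=True)
Here func represents the set of functions f_i whose roots we solve for (the RHS of the system), and jac encodes the Jacobian of f, J_ij = ∂f_i/∂x_j. Next, sympy's common subexpression elimination (sp.cse) factors out repeated subexpressions so that each is computed only once: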
cse, (cse_func, cse_jac) = sp.cse((sp.Matrix(func),sp.Matrix(jac)))
cse
[(x0, 1/T),
(x1, sqrt(T)),
(x2, n_Htot**2),
(x3, 1/(sqrt(10)*x1/1000 + 1)),
(x4, x1*x2*x3),
(x5, x4*exp(-285335.4*x0)),
(x6, x5*x_He),
(x7, x_H - 1),
(x8, -x7 - 2*x_He - x_He+ + 2*y),
(x9, 2.38e-11*x8),
(x10, 1/x1),
(x11,
x2*(0.0019*(1 + 0.3*exp(-94000.0*x0))*exp(-470000.0*x0)/T**1.5 + 1.93241606228058e-10*x10/((0.000164934781188511*x1 + 1.0)**1.7892*(4.84160744811772*x1 + 1.0)**0.2108))),
(x12, x11*x_He+),
(x13, -x12*x8 + x6*x9),
(x14, exp(-157809.1*x0)),
(x15, x14*x4),
(x16, x15*x_H),
(x17, 5.85e-11*x16),
(x18, -x7),
(x19, (0.00119216696847702*x1 + 1.0)**(-1.748)),
(x20, (0.563615123664978*x1 + 1.0)**(-0.252)),
(x21, -x_He - x_He+ + y),
(x22, x10*x2),
(x23,
x22/((0.00059608348423851*x1 + 1.0)**1.748*(0.281807561832489*x1 + 1.0)**0.252)),
(x24, 5.66485863480458e-10*x23),
(x25, x24*x8),
(x26, exp(-631515*x0)),
(x27, x26*x4),
(x28, 4.76e-11*x6),
(x29, x5*x9),
(x30, 2*x12),
(x31, -x12 + 2.38e-11*x6),
(x32, x11*x8 + x31),
(x33, x19*x20*x22),
(x34, x18*x33),
(x35, 1.41621465870114e-10*x34),
(x36, -5.68e-12*x1*x2*x26*x3*x_He+ + x21*x24)]
cse_func
$\displaystyle \left[\begin{matrix}- x_{13}\\1.41621465870114 \cdot 10^{-10} x_{10} x_{18} x_{19} x_{2} x_{20} x_{8} - x_{17} x_{8}\\x_{13} + x_{21} x_{25} - 5.68 \cdot 10^{-12} x_{27} x_{8} x_{He+}\end{matrix}\right]$
cse_jac
$\displaystyle \left[\begin{matrix}x_{28} - x_{29} - x_{30} & x_{31} & x_{32}\\1.17 \cdot 10^{-10} x_{16} - 2.83242931740229 \cdot 10^{-10} x_{34} & 5.85 \cdot 10^{-11} x_{1} x_{14} x_{2} x_{3} x_{H} - 5.85 \cdot 10^{-11} x_{15} x_{8} - 1.41621465870114 \cdot 10^{-10} x_{33} x_{8} - x_{35} & x_{17} - x_{35}\\1.136 \cdot 10^{-11} x_{1} x_{2} x_{26} x_{3} x_{He+} - 1.13297172696092 \cdot 10^{-9} x_{21} x_{23} - x_{25} - x_{28} + x_{29} + x_{30} & - x_{31} - x_{36} & - x_{25} - 5.68 \cdot 10^{-12} x_{27} x_{8} - x_{32} - x_{36}\end{matrix}\right]$
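sp.cse returns a list of (symbol, subexpression) pairs plus the reduced expressions, and the pairs are ordered so that each x_k depends only on the original symbols and earlier x_k's, as in the output above. The toy sketch below (illustrative expressions, not reaxion's) evaluates a CSE form in exactly that order and checks it against direct evaluation; this is also the order in which the emitted C statements run.

```python
import sympy as sp

T, n = sp.symbols("T n", positive=True)
# Two toy expressions that share exp(-1/T) and sqrt(T)
exprs = [n**2 * sp.sqrt(T) * sp.exp(-1 / T), n * sp.sqrt(T) * sp.exp(-1 / T) + 1]

# replacements: list of (x_k, subexpression); reduced: exprs rewritten in the x_k
replacements, reduced = sp.cse(exprs)


def evaluate(values):
    """Evaluate the CSE form in the same order as the emitted C statements."""
    env = dict(values)
    for sym, sub in replacements:  # each x_k may use earlier x_k's
        env[sym] = sub.subs(env)
    return [float(r.subs(env)) for r in reduced]


vals = {T: 2.0, n: 3.0}
direct = [float(e.subs(vals)) for e in exprs]
print(replacements)
print(evaluate(vals), direct)
```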
One can then take these expressions and convert them to the syntax of the code you wish to embed them in:
from sympy.codegen.ast import Assignment
for expr in cse:
print(sp.ccode(Assignment(*expr),standard='c99'))
rhs_result = sp.MatrixSymbol('rhs_result', len(func), 1)
jac_result = sp.MatrixSymbol('jac_result', len(func),len(func))
print()
print(sp.ccode(Assignment(rhs_result, cse_func),standard='c99'))
print()
print(sp.ccode(Assignment(jac_result, cse_jac), standard='c99'))
x0 = 1.0/T;
x1 = sqrt(T);
x2 = pow(n_Htot, 2);
x3 = 1.0/((1.0/1000.0)*sqrt(10)*x1 + 1);
x4 = x1*x2*x3;
x5 = x4*exp(-285335.40000000002*x0);
x6 = x5*x_He;
x7 = x_H - 1;
x8 = -x7 - 2*x_He - x_He+ + 2*y;
x9 = 2.3800000000000001e-11*x8;
x10 = 1.0/x1;
x11 = x2*(0.0019*pow(T, -1.5)*(1 + 0.29999999999999999*exp(-94000.0*x0))*exp(-470000.0*x0) + 1.9324160622805846e-10*x10*pow(0.00016493478118851054*x1 + 1.0, -1.7891999999999999)*pow(4.8416074481177231*x1 + 1.0, -0.21079999999999999));
x12 = x11*x_He+;
x13 = -x12*x8 + x6*x9;
x14 = exp(-157809.10000000001*x0);
x15 = x14*x4;
x16 = x15*x_H;
x17 = 5.8500000000000005e-11*x16;
x18 = -x7;
x19 = pow(0.0011921669684770192*x1 + 1.0, -1.748);
x20 = pow(0.56361512366497779*x1 + 1.0, -0.252);
x21 = -x_He - x_He+ + y;
x22 = x10*x2;
x23 = x22*pow(0.00059608348423850961*x1 + 1.0, -1.748)*pow(0.2818075618324889*x1 + 1.0, -0.252);
x24 = 5.664858634804579e-10*x23;
x25 = x24*x8;
x26 = exp(-631515*x0);
x27 = x26*x4;
x28 = 4.7600000000000002e-11*x6;
x29 = x5*x9;
x30 = 2*x12;
x31 = -x12 + 2.3800000000000001e-11*x6;
x32 = x11*x8 + x31;
x33 = x19*x20*x22;
x34 = x18*x33;
x35 = 1.4162146587011448e-10*x34;
x36 = -5.68e-12*x1*x2*x26*x3*x_He+ + x21*x24;
rhs_result[0] = -x13;
rhs_result[1] = 1.4162146587011448e-10*x10*x18*x19*x2*x20*x8 - x17*x8;
rhs_result[2] = x13 + x21*x25 - 5.68e-12*x27*x8*x_He+;
jac_result[0] = x28 - x29 - x30;
jac_result[1] = x31;
jac_result[2] = x32;
jac_result[3] = 1.1700000000000001e-10*x16 - 2.8324293174022895e-10*x34;
jac_result[4] = 5.8500000000000005e-11*x1*x14*x2*x3*x_H - 5.8500000000000005e-11*x15*x8 - 1.4162146587011448e-10*x33*x8 - x35;
jac_result[5] = x17 - x35;
jac_result[6] = 1.136e-11*x1*x2*x26*x3*x_He+ - 1.1329717269609158e-9*x21*x23 - x25 - x28 + x29 + x30;
jac_result[7] = -x31 - x36;
jac_result[8] = -x25 - 5.68e-12*x27*x8 - x32 - x36;
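The raw sympy output above is not quite compilable C: names like x_He+ are not valid identifiers, whereas generate_code printed x_Heplus instead. One way to handle this, sketched here as an assumption rather than reaxion's actual implementation, is to swap in sanitized symbols with xreplace before calling ccode. The helper name c_safe and the "minus" spelling are hypothetical.

```python
import sympy as sp
from sympy.codegen.ast import Assignment


def c_safe(expr):
    """Rename symbols containing '+' or '-' (e.g. x_He+ -> x_Heplus) so that
    sympy.ccode emits valid C identifiers. The 'minus' spelling is a
    hypothetical choice, not necessarily what reaxion does."""
    renames = {
        s: sp.Symbol(s.name.replace("+", "plus").replace("-", "minus"))
        for s in expr.free_symbols
        if "+" in s.name or "-" in s.name
    }
    return expr.xreplace(renames)


x_Hep = sp.Symbol("x_He+")
n_e = sp.Symbol("n_e-")
lhs = sp.Symbol("x12")
code = sp.ccode(Assignment(lhs, c_safe(2 * x_Hep * n_e)), standard="c99")
print(code)  # a C assignment using x_Heplus and n_eminus
```

Renaming at the symbolic level, rather than by text substitution on the emitted C, avoids accidentally rewriting the + and - operators.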