
mikegrudic/reaxion

reaxion

reaxion is a flexible, object-oriented implementation for systems of ISM microphysics and chemistry equations, with numerical solvers implemented in JAX, and interfaces for embedding the equations and their Jacobians into other codes.

Do we really need yet another ISM code?

reaxion might be interesting because it combines two powerful concepts:

  1. Object-oriented implementation of microphysics and chemistry via the Process class, which implements methods for representing physical processes and composing them into a network, all in a fully symbolic sympy representation. OOP is nice here because adding a new process to reaxion typically touches only one file, and rate expressions never have to be repeated in-code. Most processes one would want to implement follow very common patterns (e.g. 2-body processes), so class inheritance is also used to minimize new lines of code. Once you've constructed your system, reaxion can give you the symbolic equations to manipulate and analyze as you please. If you want to solve the equations numerically, Process has methods for substituting known values into numerical solvers. It can also automatically generate compilable implementations of the RHS of the system to embed in your choice of simulation code and plug into your choice of solver.
  2. Fast, differentiable implementation of nonlinear algebraic and differential-algebraic equation solvers with JAX, implemented in its functional programming paradigm (e.g. reaxion.numerics.newton_rootsolve). These can achieve excellent numerical throughput running natively on GPUs; in fact, crunching iterates in-place is essentially the best-case application of numerics on GPUs. Differentiability enables sensitivity analysis of an output with respect to all parameters in a single (reverse-mode) pass, instead of constructing a grid of N parameter variations for N parameters. This makes it easier in principle to directly answer questions like "How sensitive is this temperature to the abundance of C or the ionization energy of H?", etc.
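To make the second point concrete, here is a toy sketch in JAX. It is not reaxion's own reaxion.numerics.newton_rootsolve, and it does not assume that function's signature. The sketch solves pure-hydrogen collisional ionization equilibrium with Newton's method at many temperatures at once (jax.vmap), then differentiates the solution with respect to T (jax.grad). The rate coefficients are transcribed from the Quickstart output further down. Two choices are ours alone: iterating on u = ln x_H, and unrolling a fixed 60 Newton steps.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision: abundances span many decades


def k_ion_H(T):
    # H collisional ionization coefficient (cm^3 s^-1), as it appears in the generated code in the Quickstart
    return 5.85e-11 * jnp.sqrt(T) * jnp.exp(-157809.1 / T) / (1.0 + jnp.sqrt(T / 1e5))


def alpha_H(T):
    # H+ recombination coefficient (cm^3 s^-1), from GasPhaseRecombination("H+").rate_coefficient
    s = jnp.sqrt(T)
    return 1.41621465870114e-10 / (
        s * (0.00119216696847702 * s + 1.0) ** 1.748 * (0.563615123664978 * s + 1.0) ** 0.252
    )


def residual(u, T):
    # d x_H / dt divided by the positive factor n_Htot^2 x_e, written in terms of u = ln x_H
    x_H = jnp.exp(u)
    return alpha_H(T) * (1.0 - x_H) - k_ion_H(T) * x_H


def x_H_equilibrium(T, n_iter=60):
    # Plain Newton iteration, unrolled for a fixed number of steps so that
    # reverse-mode autodiff can differentiate through every iterate.
    T = jnp.asarray(T, dtype=jnp.float64)
    dres_du = jax.grad(residual, argnums=0)

    def newton_step(_, u):
        return u - residual(u, T) / dres_du(u, T)

    u0 = jnp.log(jnp.asarray(0.5, dtype=jnp.float64))  # initial guess x_H = 0.5
    u = jax.lax.fori_loop(0, n_iter, newton_step, u0)
    return jnp.exp(u)


Tgrid = jnp.logspace(3, 6, 1000)
x_H = jax.vmap(x_H_equilibrium)(Tgrid)               # solve every temperature at once
dxH_dT = jax.vmap(jax.grad(x_H_equilibrium))(Tgrid)  # sensitivity d x_H / d T, same machinery
```

In equilibrium the common factor $n_{\rm H,tot}^2 x_{\rm e}$ cancels. This toy $x_{\rm H}$ should therefore track the H curve of the full H-He Quickstart solution: at $T = 10^6$ K both give about $6.06 \times 10^{-7}$. Differentiating through unrolled iterations is the simplest option. Implicit differentiation (via the implicit function theorem) avoids storing every iterate.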

Roadmap

reaxion is in an early prototyping phase right now. Here are some things I would eventually like to add:

  • Flexible implementation of a reduced network suitable for RHD simulations in GIZMO and potentially other codes.
  • Dust and radiation physics: add the dust energy equation and evolution of photon number densities to the network.
  • Interfaces to convert from other existing chemistry network formats to the Process representation.
  • Solver robustness upgrades: thermochemical networks can be quite challenging numerically, due to how steeply terms switch on with increasing T. It can be hard to get a solution without good initial guesses.
  • If possible, glue interface allowing an existing compiled hydro code to call the JAX solvers on-the-fly.

Installation

Clone the repo and run pip install . from the directory, or install the latest release from pypi via pip install reaxion.

Quickstart: Collisional Ionization Equilibrium

Example of using reaxion to solve for collisional ionization equilibrium (CIE) for a hydrogen-helium mixture and plot the ionization states as a function of temperature.

%matplotlib inline
%config InlineBackend.figure_format='retina'
import numpy as np
from matplotlib import pyplot as plt
import sympy as sp

Simple processes

A simple process is defined by a single reaction, with a specified rate.

Let's inspect the structure of a single process, the gas-phase recombination of H+: H+ + e- -> H + hν

from reaxion.processes import CollisionalIonization, GasPhaseRecombination

process = GasPhaseRecombination("H+")
print(f"Name: {process.name}")
print(f"Heating rate coefficient: {process.heat_rate_coefficient}")
print(f"Heating rate per cm^3: {process.heat}")
print(f"Rate coefficient: {process.rate_coefficient}")
print(f"Recombination rate per cm^3: {process.rate}")
print(f"RHS of e- number density equation: {process.network['e-']}")
Name: Gas-phase recombination of H+
Heating rate coefficient: -1.46719838641439e-26*sqrt(T)/((0.00119216696847702*sqrt(T) + 1.0)**1.748*(0.563615123664978*sqrt(T) + 1.0)**0.252)
Heating rate per cm^3: -1.46719838641439e-26*sqrt(T)*n_H+*n_e-/((0.00119216696847702*sqrt(T) + 1.0)**1.748*(0.563615123664978*sqrt(T) + 1.0)**0.252)
Rate coefficient: 1.41621465870114e-10/(sqrt(T)*(0.00119216696847702*sqrt(T) + 1.0)**1.748*(0.563615123664978*sqrt(T) + 1.0)**0.252)
Recombination rate per cm^3: 1.41621465870114e-10*n_H+*n_e-/(sqrt(T)*(0.00119216696847702*sqrt(T) + 1.0)**1.748*(0.563615123664978*sqrt(T) + 1.0)**0.252)
RHS of e- number density equation: Eq(Derivative(n_e-(t), t), -1.41621465870114e-10*n_H+*n_e-/(sqrt(T)*(0.00119216696847702*sqrt(T) + 1.0)**1.748*(0.563615123664978*sqrt(T) + 1.0)**0.252))

Note that all symbolic representations assume CGS units, as is standard in ISM physics.
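As a quick check of the units, the sketch below evaluates the H+ recombination coefficient printed above at $T = 10^4$ K in plain Python. The function name alpha_rr_H is ours. The coefficient comes out in cm^3 s^-1, at about $4.2 \times 10^{-13}$. Multiplying by $n_{\rm H+} n_{\rm e-}$ (each in cm^-3) gives a rate in cm^-3 s^-1.

```python
import math


def alpha_rr_H(T):
    """H+ recombination rate coefficient in cm^3 s^-1 (T in K), transcribed from the
    rate_coefficient printed for GasPhaseRecombination("H+")."""
    s = math.sqrt(T)
    return 1.41621465870114e-10 / (
        s * (0.00119216696847702 * s + 1.0) ** 1.748 * (0.563615123664978 * s + 1.0) ** 0.252
    )


T = 1e4        # K
n_Hplus = 1.0  # cm^-3
n_e = 1.0      # cm^-3

alpha = alpha_rr_H(T)         # cm^3 s^-1
rate = alpha * n_Hplus * n_e  # recombinations per cm^3 per s, i.e. cm^-3 s^-1
t_rec = 1.0 / (alpha * n_e)   # recombination time of a single H+ ion, in s
print(f"alpha = {alpha:.3e} cm^3/s, rate = {rate:.3e} cm^-3 s^-1, t_rec = {t_rec / 3.156e7:.3e} yr")
```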

Composing processes

Now let's define our full network as a sum of simple processes:

processes = [CollisionalIonization(s) for s in ("H", "He", "He+")] + [GasPhaseRecombination(i) for i in ("H+", "He+", "He++")]
system = sum(processes)

system.subprocesses
[Collisional Ionization of H,
 Collisional Ionization of He,
 Collisional Ionization of He+,
 Gas-phase recombination of H+,
 Gas-phase recombination of He+,
 Gas-phase recombination of He++]

Summed processes keep track of all subprocesses, e.g. the total net heating rate is:

system.heat

$\displaystyle - \frac{1.55 \cdot 10^{-26} n_{He+} n_{e-}}{T^{0.3647}} - \frac{1.2746917300104 \cdot 10^{-21} \sqrt{T} n_{H} n_{e-} e^{- \frac{157809.1}{T}}}{\frac{\sqrt{10} \sqrt{T}}{1000} + 1} - \frac{1.46719838641439 \cdot 10^{-26} \sqrt{T} n_{H+} n_{e-}}{\left(0.00119216696847702 \sqrt{T} + 1.0\right)^{1.748} \left(0.563615123664978 \sqrt{T} + 1.0\right)^{0.252}} - \frac{9.37661057635428 \cdot 10^{-22} \sqrt{T} n_{He} n_{e-} e^{- \frac{285335.4}{T}}}{\frac{\sqrt{10} \sqrt{T}}{1000} + 1} - \frac{4.9524176975855 \cdot 10^{-22} \sqrt{T} n_{He+} n_{e-} e^{- \frac{631515}{T}}}{\frac{\sqrt{10} \sqrt{T}}{1000} + 1} - \frac{5.86879354565754 \cdot 10^{-26} \sqrt{T} n_{He++} n_{e-}}{\left(0.00119216696847702 \sqrt{T} + 1.0\right)^{1.748} \left(0.563615123664978 \sqrt{T} + 1.0\right)^{0.252}}$

Summing processes also sums all chemical and gas/dust cooling/heating rates.
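The bookkeeping behind sum(processes) can be sketched with a hypothetical ToyProcess class. This is not reaxion's implementation, and the actual Process class is richer. The key detail is __radd__: Python's built-in sum() starts from 0, so 0 + process has to return the process itself. Here the rate coefficients k and alpha and the energy chi are bare symbols, and recombination cooling is omitted for brevity.

```python
import sympy as sp


class ToyProcess:
    """Hypothetical, minimal stand-in for a composable process (not reaxion's Process class).

    Holds a symbolic heating rate and a dict mapping species -> contribution to dn/dt.
    """

    def __init__(self, name, heat=0, network=None, subprocesses=None):
        self.name = name
        self.heat = sp.sympify(heat)
        self.network = dict(network or {})
        self.subprocesses = list(subprocesses) if subprocesses is not None else [self]

    def __add__(self, other):
        if not isinstance(other, ToyProcess):
            return NotImplemented
        network = dict(self.network)
        for species, rhs in other.network.items():
            network[species] = network.get(species, 0) + rhs  # add contributions to the same dn/dt
        return ToyProcess(
            f"{self.name} + {other.name}",
            heat=self.heat + other.heat,  # heating/cooling rates also add
            network=network,
            subprocesses=self.subprocesses + other.subprocesses,
        )

    def __radd__(self, other):
        # sum() starts from the integer 0, so 0 + process must return the process unchanged
        return self if other == 0 else NotImplemented

    def __repr__(self):
        return self.name


n_H, n_Hp, n_e, k, alpha, chi = sp.symbols("n_H n_H+ n_e- k alpha chi", positive=True)

ionization = ToyProcess(
    "Collisional ionization of H",
    heat=-chi * k * n_H * n_e,  # each ionization removes energy chi from the gas
    network={"H": -k * n_H * n_e, "H+": k * n_H * n_e, "e-": k * n_H * n_e},
)
recombination = ToyProcess(
    "Recombination of H+",
    network={"H": alpha * n_Hp * n_e, "H+": -alpha * n_Hp * n_e, "e-": -alpha * n_Hp * n_e},
)

system = sum([ionization, recombination])
print(system.subprocesses)
print(system.network["e-"])
```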

Solving ionization equilibrium

We would like to solve for ionization equilibrium given a temperature $T$ and overall H number density $n_{\rm H,tot}$. We define a dictionary of those input quantities, and another holding the initial guesses for the species in the reduced network.

Tgrid = np.logspace(3,6,10**6)
ngrid = np.ones_like(Tgrid) * 100

knowns = {"T": Tgrid, "n_Htot": ngrid}

guesses = {
    "H": 0.5*np.ones_like(Tgrid),
    "He": 1e-5*np.ones_like(Tgrid),
    "He+": 1e-5*np.ones_like(Tgrid)
}

Note that by default, the solver only directly solves for $n_{\rm H}$, $n_{\rm He}$ and $n_{\rm He+}$ because $n_{\rm H+}$, $n_{\rm He++}$, and $n_{\rm e-}$ are eliminated by conservation equations. So we only need initial guesses for those 3 quantities. By default the solver takes abundances $x_i = n_i / n_{\rm H,tot}$ as inputs and outputs.
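The three eliminated species follow from H nuclei conservation, He nuclei conservation, and charge neutrality. The sketch below writes those relations out in plain Python. The function name and y_He (the total He abundance $n_{\rm He,tot}/n_{\rm H,tot}$) are our own labels. The generated code further down contains a symbol y that enters in exactly this way: its x8 equals $x_{\rm e-}$ and its x21 equals $x_{\rm He++}$. The hottest grid point of the solution printed below is consistent with $y \approx 0.0925$.

```python
def eliminated_abundances(x_H, x_He, x_Heplus, y_He):
    """Recover the species eliminated by conservation, as abundances x_i = n_i / n_H,tot.

    y_He is the total He abundance n_He,tot / n_H,tot (our name for it).
    """
    x_Hplus = 1.0 - x_H                            # H nuclei:  x_H + x_H+ = 1
    x_Heplusplus = y_He - x_He - x_Heplus          # He nuclei: x_He + x_He+ + x_He++ = y_He
    x_e = x_Hplus + x_Heplus + 2.0 * x_Heplusplus  # charge neutrality
    return {"H+": x_Hplus, "He++": x_Heplusplus, "e-": x_e}


# Hottest grid point (T = 1e6 K) of the solution printed below; y_He taken from the low-T He abundance
hot = eliminated_abundances(x_H=6.0610921e-07, x_He=2.7492442e-09, x_Heplus=7.6920396e-06, y_He=9.2546351e-02)
print(hot)  # should sit close to the printed H+, He++ and e- values at the same point
```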

The solve method calls the JAX solver and computes the solution:

sol = system.solve(knowns, guesses,tol=1e-3)
print(sol)
{'He': Array([9.2546351e-02, 9.2546351e-02, 9.2546351e-02, ..., 2.7493625e-09,
       2.7493037e-09, 2.7492442e-09], dtype=float32), 'H': Array([9.9999994e-01, 9.9999994e-01, 9.9999994e-01, ..., 6.0612075e-07,
       6.0611501e-07, 6.0610921e-07], dtype=float32), 'He+': Array([3.1222404e-13, 3.1222396e-13, 3.1222374e-13, ..., 7.6922206e-06,
       7.6921306e-06, 7.6920396e-06], dtype=float32), 'He++': Array([0.        , 0.        , 0.        , ..., 0.09253865, 0.09253865,
       0.09253865], dtype=float32), 'H+': Array([5.9604645e-08, 5.9604645e-08, 5.9604645e-08, ..., 9.9999940e-01,
       9.9999940e-01, 9.9999940e-01], dtype=float32), 'e-': Array([5.9604957e-08, 5.9604957e-08, 5.9604957e-08, ..., 1.1850843e+00,
       1.1850843e+00, 1.1850843e+00], dtype=float32)}
for i, xi in sorted(sol.items()):
    plt.loglog(Tgrid, xi, label=i)
plt.legend(labelspacing=0)
plt.ylabel("$x_i$")
plt.xlabel("T (K)")
plt.ylim(1e-4,3)
(0.0001, 3)

[Figure: abundance $x_i$ of each species versus temperature $T$ (K), log-log axes]

Generating code

Suppose you just want the RHS of the system you're solving, or its Jacobian, because you have a better solver and/or want to embed these equations in some old C or Fortran code without any dependencies. You can do that too with generate_code.

print(system.generate_code(('H','He','He+'),language='c'))
# Computes the RHS function and Jacobianto solve for [x_He, x_H, x_Heplus]

# INDEX CONVENTION: (0: x_He) (1: x_H) (2: x_Heplus)

x0 = 1.0/T; 
x1 = sqrt(T); 
x2 = pow(n_Htot, 2); 
x3 = 1.0/((1.0/1000.0)*sqrt(10)*x1 + 1); 
x4 = x1*x2*x3; 
x5 = x4*exp(-285335.40000000002*x0); 
x6 = x5*x_He; 
x7 = x_H - 1; 
x8 = -x7 - 2*x_He - x_Heplus + 2*y; 
x9 = 2.3800000000000001e-11*x8; 
x10 = 1.0/x1; 
x11 = x2*(0.0019*pow(T, -1.5)*(1 + 0.29999999999999999*exp(-94000.0*x0))*exp(-470000.0*x0) + 1.9324160622805846e-10*x10*pow(0.00016493478118851054*x1 + 1.0, -1.7891999999999999)*pow(4.8416074481177231*x1 + 1.0, -0.21079999999999999)); 
x12 = x11*x_Heplus; 
x13 = -x12*x8 + x6*x9; 
x14 = exp(-157809.10000000001*x0); 
x15 = x14*x4; 
x16 = x15*x_H; 
x17 = 5.8500000000000005e-11*x16; 
x18 = -x7; 
x19 = pow(0.0011921669684770192*x1 + 1.0, -1.748); 
x20 = pow(0.56361512366497779*x1 + 1.0, -0.252); 
x21 = -x_He - x_Heplus + y; 
x22 = x10*x2; 
x23 = x22*pow(0.00059608348423850961*x1 + 1.0, -1.748)*pow(0.2818075618324889*x1 + 1.0, -0.252); 
x24 = 5.664858634804579e-10*x23; 
x25 = x24*x8; 
x26 = exp(-631515*x0); 
x27 = x26*x4; 
x28 = 4.7600000000000002e-11*x6; 
x29 = x5*x9; 
x30 = 2*x12; 
x31 = -x12 + 2.3800000000000001e-11*x6; 
x32 = x11*x8 + x31; 
x33 = x19*x20*x22; 
x34 = x18*x33; 
x35 = 1.4162146587011448e-10*x34; 
x36 = -5.68e-12*x1*x2*x26*x3*x_Heplus + x21*x24;

rhs_result[0] = -x13;
rhs_result[1] = 1.4162146587011448e-10*x10*x18*x19*x2*x20*x8 - x17*x8;
rhs_result[2] = x13 + x21*x25 - 5.68e-12*x27*x8*x_Heplus;

jac_result[0] = x28 - x29 - x30;
jac_result[1] = x31;
jac_result[2] = x32;
jac_result[3] = 1.1700000000000001e-10*x16 - 2.8324293174022895e-10*x34;
jac_result[4] = 5.8500000000000005e-11*x1*x14*x2*x3*x_H - 5.8500000000000005e-11*x15*x8 - 1.4162146587011448e-10*x33*x8 - x35;
jac_result[5] = x17 - x35;
jac_result[6] = 1.136e-11*x1*x2*x26*x3*x_Heplus - 1.1329717269609158e-9*x21*x23 - x25 - x28 + x29 + x30;
jac_result[7] = -x31 - x36;
jac_result[8] = -x25 - 5.68e-12*x27*x8 - x32 - x36;
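The generated code stores the 3x3 Jacobian in a flat array. Comparing jac_result with the cse_jac matrix printed further down shows that the layout is row-major: jac_result[3*i + j] holds $J_{ij} = \partial f_i / \partial x_j$, with $i$ and $j$ following the INDEX CONVENTION line. A small helper for code that wants the 2D form follows; the function names are hypothetical.

```python
SPECIES = ("x_He", "x_H", "x_Heplus")  # order taken from the generated INDEX CONVENTION line


def flat_index(i, j, n=len(SPECIES)):
    """Position of J_ij = d f_i / d x_j in the flat, row-major jac_result array."""
    return i * n + j


def unflatten_jacobian(jac_result, n=len(SPECIES)):
    """Rebuild the n x n Jacobian (as a list of rows) from a flat row-major array."""
    return [[jac_result[flat_index(i, j, n)] for j in range(n)] for i in range(n)]


# d f_1 / d x_0, i.e. how the x_H equation responds to x_He, sits in jac_result[3]
print(flat_index(SPECIES.index("x_H"), SPECIES.index("x_He")))
```

NumPy users can get the same matrix with np.asarray(jac_result).reshape(3, 3), since reshape defaults to row-major (C) order.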

Let's break down what happened there. First, reaxion generates the symbolic functions needed to solve the system, just as it does before solving the system with its own solver:

func, jac, _ = system.network.solver_functions(('H','He','He+'),return_jac=True)

Here func represents the set of functions $f_i$ such that $f_i = 0$ solves the system, and jac encodes the Jacobian of f, $J_{ij} = \frac{\partial f_i}{\partial x_j}$, i.e. the derivatives with respect to the solved variables. The two share many subexpressions, so before implementing them one should apply common subexpression elimination (CSE) to simplify the code and evaluate the functions more efficiently:
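The same func/jac/CSE pipeline can be tried on a toy one-variable system in plain sympy, without reaxion. The toy is pure hydrogen, so $x_{\rm e} = 1 - x_{\rm H}$, and the $x_{\rm H}$ equation uses bare symbols k and alpha for the rate coefficients. Matrix.jacobian builds $J_{ij}$, and sp.cse pulls out shared pieces such as $n_{\rm H,tot}^2$.

```python
import sympy as sp

x_H, k, alpha, n = sp.symbols("x_H k alpha n_Htot", positive=True)
x_e = 1 - x_H  # pure hydrogen: every free electron comes from an H+ ion

# f = 0 at ionization equilibrium: d x_H / dt with the n_Htot^2 prefactor seen in the generated code
func = [n**2 * (alpha * (1 - x_H) * x_e - k * x_H * x_e)]
jac = sp.Matrix(func).jacobian([x_H])  # J_ij = d f_i / d x_j

replacements, (cse_func, cse_jac) = sp.cse((sp.Matrix(func), jac))
for sym, sub in replacements:
    print(sym, "=", sub)
print(cse_func, cse_jac)
```

Substituting the replacements back in reverse order recovers the original expressions. That is a cheap sanity check to run before pasting generated code anywhere.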

cse, (cse_func, cse_jac) = sp.cse((sp.Matrix(func),sp.Matrix(jac)))

cse
[(x0, 1/T),
 (x1, sqrt(T)),
 (x2, n_Htot**2),
 (x3, 1/(sqrt(10)*x1/1000 + 1)),
 (x4, x1*x2*x3),
 (x5, x4*exp(-285335.4*x0)),
 (x6, x5*x_He),
 (x7, x_H - 1),
 (x8, -x7 - 2*x_He - x_He+ + 2*y),
 (x9, 2.38e-11*x8),
 (x10, 1/x1),
 (x11,
  x2*(0.0019*(1 + 0.3*exp(-94000.0*x0))*exp(-470000.0*x0)/T**1.5 + 1.93241606228058e-10*x10/((0.000164934781188511*x1 + 1.0)**1.7892*(4.84160744811772*x1 + 1.0)**0.2108))),
 (x12, x11*x_He+),
 (x13, -x12*x8 + x6*x9),
 (x14, exp(-157809.1*x0)),
 (x15, x14*x4),
 (x16, x15*x_H),
 (x17, 5.85e-11*x16),
 (x18, -x7),
 (x19, (0.00119216696847702*x1 + 1.0)**(-1.748)),
 (x20, (0.563615123664978*x1 + 1.0)**(-0.252)),
 (x21, -x_He - x_He+ + y),
 (x22, x10*x2),
 (x23,
  x22/((0.00059608348423851*x1 + 1.0)**1.748*(0.281807561832489*x1 + 1.0)**0.252)),
 (x24, 5.66485863480458e-10*x23),
 (x25, x24*x8),
 (x26, exp(-631515*x0)),
 (x27, x26*x4),
 (x28, 4.76e-11*x6),
 (x29, x5*x9),
 (x30, 2*x12),
 (x31, -x12 + 2.38e-11*x6),
 (x32, x11*x8 + x31),
 (x33, x19*x20*x22),
 (x34, x18*x33),
 (x35, 1.41621465870114e-10*x34),
 (x36, -5.68e-12*x1*x2*x26*x3*x_He+ + x21*x24)]
cse_func

$\displaystyle \left[\begin{matrix}- x_{13}\\1.41621465870114 \cdot 10^{-10} x_{10} x_{18} x_{19} x_{2} x_{20} x_{8} - x_{17} x_{8}\\x_{13} + x_{21} x_{25} - 5.68 \cdot 10^{-12} x_{27} x_{8} x_{He+}\end{matrix}\right]$

cse_jac

$\displaystyle \left[\begin{matrix}x_{28} - x_{29} - x_{30} & x_{31} & x_{32}\\1.17 \cdot 10^{-10} x_{16} - 2.83242931740229 \cdot 10^{-10} x_{34} & 5.85 \cdot 10^{-11} x_{1} x_{14} x_{2} x_{3} x_{H} - 5.85 \cdot 10^{-11} x_{15} x_{8} - 1.41621465870114 \cdot 10^{-10} x_{33} x_{8} - x_{35} & x_{17} - x_{35}\\1.136 \cdot 10^{-11} x_{1} x_{2} x_{26} x_{3} x_{He+} - 1.13297172696092 \cdot 10^{-9} x_{21} x_{23} - x_{25} - x_{28} + x_{29} + x_{30} & - x_{31} - x_{36} & - x_{25} - 5.68 \cdot 10^{-12} x_{27} x_{8} - x_{32} - x_{36}\end{matrix}\right]$

You can then convert these expressions into the syntax of the code you wish to embed them in:

from sympy.codegen.ast import Assignment
for expr in cse:
    print(sp.ccode(Assignment(*expr),standard='c99'))

rhs_result = sp.MatrixSymbol('rhs_result', len(func), 1)
jac_result = sp.MatrixSymbol('jac_result', len(func),len(func))
print()
print(sp.ccode(Assignment(rhs_result, cse_func),standard='c99'))
print()
print(sp.ccode(Assignment(jac_result, cse_jac),standard='c99'))
x0 = 1.0/T;
x1 = sqrt(T);
x2 = pow(n_Htot, 2);
x3 = 1.0/((1.0/1000.0)*sqrt(10)*x1 + 1);
x4 = x1*x2*x3;
x5 = x4*exp(-285335.40000000002*x0);
x6 = x5*x_He;
x7 = x_H - 1;
x8 = -x7 - 2*x_He - x_He+ + 2*y;
x9 = 2.3800000000000001e-11*x8;
x10 = 1.0/x1;
x11 = x2*(0.0019*pow(T, -1.5)*(1 + 0.29999999999999999*exp(-94000.0*x0))*exp(-470000.0*x0) + 1.9324160622805846e-10*x10*pow(0.00016493478118851054*x1 + 1.0, -1.7891999999999999)*pow(4.8416074481177231*x1 + 1.0, -0.21079999999999999));
x12 = x11*x_He+;
x13 = -x12*x8 + x6*x9;
x14 = exp(-157809.10000000001*x0);
x15 = x14*x4;
x16 = x15*x_H;
x17 = 5.8500000000000005e-11*x16;
x18 = -x7;
x19 = pow(0.0011921669684770192*x1 + 1.0, -1.748);
x20 = pow(0.56361512366497779*x1 + 1.0, -0.252);
x21 = -x_He - x_He+ + y;
x22 = x10*x2;
x23 = x22*pow(0.00059608348423850961*x1 + 1.0, -1.748)*pow(0.2818075618324889*x1 + 1.0, -0.252);
x24 = 5.664858634804579e-10*x23;
x25 = x24*x8;
x26 = exp(-631515*x0);
x27 = x26*x4;
x28 = 4.7600000000000002e-11*x6;
x29 = x5*x9;
x30 = 2*x12;
x31 = -x12 + 2.3800000000000001e-11*x6;
x32 = x11*x8 + x31;
x33 = x19*x20*x22;
x34 = x18*x33;
x35 = 1.4162146587011448e-10*x34;
x36 = -5.68e-12*x1*x2*x26*x3*x_He+ + x21*x24;

rhs_result[0] = -x13;
rhs_result[1] = 1.4162146587011448e-10*x10*x18*x19*x2*x20*x8 - x17*x8;
rhs_result[2] = x13 + x21*x25 - 5.68e-12*x27*x8*x_He+;

jac_result[0] = x28 - x29 - x30;
jac_result[1] = x31;
jac_result[2] = x32;
jac_result[3] = 1.1700000000000001e-10*x16 - 2.8324293174022895e-10*x34;
jac_result[4] = 5.8500000000000005e-11*x1*x14*x2*x3*x_H - 5.8500000000000005e-11*x15*x8 - 1.4162146587011448e-10*x33*x8 - x35;
jac_result[5] = x17 - x35;
jac_result[6] = 1.136e-11*x1*x2*x26*x3*x_He+ - 1.1329717269609158e-9*x21*x23 - x25 - x28 + x29 + x30;
jac_result[7] = -x31 - x36;
jac_result[8] = -x25 - 5.68e-12*x27*x8 - x32 - x36;
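Note that the output above still contains identifiers such as x_He+, which are not valid C. Names containing "-", as in species like e-, have the same problem. By contrast, generate_code printed x_Heplus. If you take the manual sp.ccode route, one option is to rename such symbols first. The helper c_safe below is hypothetical; it simply mirrors the "plus" spelling seen in the generate_code output.

```python
import sympy as sp
from sympy.codegen.ast import Assignment


def c_safe(expr):
    """Hypothetical helper: rename symbols whose names are not valid identifiers,
    e.g. x_He+ -> x_Heplus, keeping each symbol's assumptions."""
    renames = {
        s: sp.Symbol(s.name.replace("+", "plus").replace("-", "minus"), **s.assumptions0)
        for s in expr.free_symbols
        if not s.name.isidentifier()
    }
    return expr.xreplace(renames)


x_He, x_Heplus, y = sp.symbols("x_He x_He+ y")
code = sp.ccode(Assignment(sp.Symbol("x21"), c_safe(-x_He - x_Heplus + y)), standard="c99")
print(code)
```

The same call works on the right-hand sides in cse and on the cse_func and cse_jac matrices before they are printed, since sympy matrices also support free_symbols and xreplace.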

About

Flexible, object-oriented implementation of ISM microphysics and chemistry with fast JAX solvers