mosaix-pde

mosaix-pde is a JAX package for optimizing pattern-forming PDEs that arise in different areas of physics. It provides code for PDE optimization and control with gradient-based methods and reinforcement learning. We use diffrax for time stepping and implement system-specific solvers, such as semi-implicit Fourier methods and Strang splitting.

The full documentation is available on Read the Docs.

Installation

To install the package, we recommend cloning the GitHub repo and installing it locally:

git clone https://github.com/acoh64/mosaix-pde.git
cd mosaix-pde
conda create -y -n mosaix-pde-env python=3.12
conda activate mosaix-pde-env
pip install -e .

By default, this installs the CPU version of JAX. To use an NVIDIA GPU (CUDA 12), also run:

pip install -U "jax[cuda12]"
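After installing, it can be worth confirming which backend JAX actually picked up. The snippet below is a small sketch and not part of mosaix-pde: `describe_jax_backend` is a hypothetical helper name, and it relies only on `jax.default_backend()` and `jax.devices()`. It returns a message instead of failing when JAX cannot be imported.

```python
def describe_jax_backend() -> str:
    """Return a one-line summary of the backend JAX will use (hypothetical helper)."""
    try:
        import jax
    except ImportError:
        return "jax is not installed in this environment"
    # default_backend() names the platform in use; devices() lists visible devices.
    devices = ", ".join(str(d) for d in jax.devices())
    return f"backend={jax.default_backend()} devices=[{devices}]"


print(describe_jax_backend())
```

If the GPU install worked, the reported backend should no longer be the CPU one.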

Usage

Here is an example of solving the Cahn-Hilliard equation in 2D with periodic boundary conditions using a semi-implicit Fourier method:

import jax
import jax.numpy as jnp

from mosaix_pde import PDEModel
from mosaix_pde import CahnHilliard2DPeriodic
from mosaix_pde import SemiImplicitFourierSpectral
from mosaix_pde import Domain
from mosaix_pde import PeriodicCNN

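# Grid: 128 x 128 points on a square of side 1.28 (grid spacing 0.01)
Nx = Ny = 128
Lx = Ly = 0.01 * Nx

# Square domain centered at the origin
domain = Domain((Nx, Ny), ((-Lx / 2, Lx / 2), (-Ly / 2, Ly / 2)), "dimensionless")

opt_model = PDEModel(
    equation_type=CahnHilliard2DPeriodic,
    domain=domain,
    solver_type=SemiImplicitFourierSpectral,
)

# Equation parameters: constant kappa plus concentration-dependent mu(c) and D(c)
params = {
    "kappa": 0.002,
    "mu": lambda c: jnp.log(c / (1.0 - c)) + 3.0 * (1.0 - 2.0 * c),
    "D": lambda c: c * (1.0 - c),
}

solver_params = {"A": 0.5}

# Initial condition: small Gaussian noise around c = 0.5
key = jax.random.PRNGKey(0)
y0 = jnp.clip(0.01 * jax.random.normal(key, (Nx, Ny)) + 0.5, 0.0, 1.0)

# 100 output times between t = 0 and t = 0.02
ts = jnp.linspace(0.0, 0.02, 100)

sol = opt_model.solve(params, y0, ts, solver_params, dt0=0.000001, max_steps=1000000)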
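For readers curious about what a semi-implicit Fourier update involves, here is a minimal sketch. It is not the package's solver. It assumes the equation has the form dc/dt = div(D(c) grad(mu(c) - kappa lap c)) and that the solver parameter A acts as a constant reference mobility that stabilizes the stiff fourth-order term; both are assumptions about the package. The names `make_wavenumbers` and `semi_implicit_step` are hypothetical. The sketch uses plain NumPy so that it runs without the package installed.

```python
import numpy as np


def make_wavenumbers(n, length):
    """Angular wavenumber grids for an n x n periodic box of side `length`."""
    k = 2.0 * np.pi * np.fft.fftfreq(n, d=length / n)
    return np.meshgrid(k, k, indexing="ij")


def semi_implicit_step(c, dt, kx, ky, kappa, A, mu, D):
    """One semi-implicit Fourier step (assumed scheme, see lead-in).

    The full right-hand side is evaluated explicitly, and the stiff term
    A * kappa * k^4 is treated implicitly through the denominator.
    """
    k2 = kx**2 + ky**2
    c_hat = np.fft.fft2(c)
    # Total chemical potential mu(c) - kappa * lap(c); lap -> -k2 in Fourier space.
    mu_hat = np.fft.fft2(mu(c)) + kappa * k2 * c_hat
    # Gradient of the chemical potential, back in real space.
    gx = np.real(np.fft.ifft2(1j * kx * mu_hat))
    gy = np.real(np.fft.ifft2(1j * ky * mu_hat))
    # Divergence of the flux D(c) * grad(mu), computed in Fourier space.
    m = D(c)
    rhs_hat = 1j * kx * np.fft.fft2(m * gx) + 1j * ky * np.fft.fft2(m * gy)
    c_hat_new = c_hat + dt * rhs_hat / (1.0 + dt * A * kappa * k2**2)
    return np.real(np.fft.ifft2(c_hat_new))


# Small demo using the same mu, D, kappa, A, and grid spacing as the example above.
mu = lambda c: np.log(c / (1.0 - c)) + 3.0 * (1.0 - 2.0 * c)
D = lambda c: c * (1.0 - c)

n, length = 32, 0.32
kx, ky = make_wavenumbers(n, length)
rng = np.random.default_rng(0)
c0 = 0.5 + 0.01 * rng.standard_normal((n, n))

c = c0.copy()
for _ in range(20):
    c = semi_implicit_step(c, 1e-6, kx, ky, kappa=0.002, A=0.5, mu=mu, D=D)

print("mean drift:", abs(c.mean() - c0.mean()))
```

Because the right-hand side is a divergence, its zero-wavenumber component vanishes, so the mean concentration should be conserved up to round-off.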

Next, here is an example of using the previous solution as a dataset to fit a neural network for the chemical potential term:

# Use the solution from the previous example as the training dataset
data = {"ys": sol, "ts": ts}

# Neural network that replaces the chemical potential term
model = PeriodicCNN(
    in_channels=1,
    hidden_channels=(32, 64, 64),
    out_channels=1,
    kernel_size=3,
    key=jax.random.PRNGKey(0),
)

# mu is learned; kappa and D are held fixed
init_params = {"mu": model}
static_params = {"kappa": 0.002, "D": lambda c: c * (1.0 - c)}
solver_parameters = {"A": 0.5}
weights = {"mu": None}
lambda_reg = 0.0

inds = [[30, 40, 50], [50, 60, 70], [70, 80, 90]]

res = opt_model.train(
    data,
    inds,
    init_params,
    static_params,
    solver_parameters,
    weights,
    lambda_reg,
    method="mse",
    max_steps=100,
)
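The exact meaning of `inds` and of `method="mse"` is defined by the package and its documentation. As a rough illustration of the general idea of trajectory matching, the sketch below computes a mean squared error over index windows, where the first index of each window is taken as the starting snapshot and the remaining indices as targets. That reading is an assumption, not a description of `train`. The names `windowed_mse`, `windows`, and `rollout` are hypothetical, and a toy exponential decay stands in for a PDE solve.

```python
import numpy as np


def windowed_mse(ys, windows, rollout):
    """Mean squared error between rolled-out predictions and reference snapshots.

    ys:      array of shape (T, ...) holding reference snapshots
    windows: list of index lists; first index = start, remaining = targets
    rollout: callable(y_start, start, targets) -> predicted snapshots at targets
    """
    total, count = 0.0, 0
    for w in windows:
        start, targets = w[0], list(w[1:])
        pred = rollout(ys[start], start, targets)
        err = pred - ys[targets]
        total += float(np.sum(err**2))
        count += err.size
    return total / count


# Toy data: a uniform field decaying at a known rate (stand-in for a PDE solution).
ts = np.linspace(0.0, 0.02, 100)
true_rate = 50.0
ys = np.stack([np.ones((4, 4)) * np.exp(-true_rate * t) for t in ts])


def make_rollout(rate):
    """Build a rollout function for the toy decay model with the given rate."""
    def rollout(y_start, start, targets):
        return np.stack(
            [y_start * np.exp(-rate * (ts[j] - ts[start])) for j in targets]
        )
    return rollout


windows = [[30, 40, 50], [50, 60, 70], [70, 80, 90]]
loss_true = windowed_mse(ys, windows, make_rollout(true_rate))
loss_wrong = windowed_mse(ys, windows, make_rollout(80.0))
print("loss at true rate:", loss_true)
print("loss at wrong rate:", loss_wrong)
```

In a gradient-based setting, a loss of this kind would be differentiated with respect to the learnable parameters (for example with `jax.grad`) rather than evaluated at two hand-picked rates.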

Current Model Implementations

This package is designed to support pattern-forming PDEs across a wide range of physical systems. We currently implement variants of the following equations:

  • Cahn-Hilliard equation
    • 2D with periodic boundary conditions
    • 3D with periodic boundary conditions
    • 2D with smoothed boundary method
  • Allen-Cahn equation
    • 2D with periodic boundary conditions
    • 2D with constant current conditions + Butler-Volmer kinetics (for battery applications)
    • 2D with smoothed boundary method
    • 2D with smoothed boundary method and constant current conditions + Butler-Volmer kinetics (for battery applications)
  • Gross-Pitaevskii equation
    • Reduced 2D with periodic boundary conditions
    • Rotating reduced 2D with periodic boundary conditions
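The package description mentions Strang splitting among its system-specific solvers. As a rough illustration for the Gross-Pitaevskii entries above, here is a minimal split-step Fourier sketch with Strang splitting. It is not the package's solver. It assumes the dimensionless form i dpsi/dt = -(1/2) lap(psi) + g |psi|^2 psi on a periodic 2D box, with no trapping potential and no rotation. The name `strang_step` and the parameter `g` are hypothetical. It uses plain NumPy so that it runs without the package installed.

```python
import numpy as np


def strang_step(psi, dt, k2, g):
    """One Strang step: half nonlinear, full kinetic, half nonlinear."""
    # Nonlinear half step: a pure phase rotation, so |psi| is unchanged.
    psi = psi * np.exp(-0.5j * dt * g * np.abs(psi) ** 2)
    # Kinetic full step, applied exactly in Fourier space.
    psi = np.fft.ifft2(np.exp(-0.5j * dt * k2) * np.fft.fft2(psi))
    # Second nonlinear half step.
    psi = psi * np.exp(-0.5j * dt * g * np.abs(psi) ** 2)
    return psi


# Periodic box of side 2*pi with a smooth, periodic initial wavefunction.
n, length = 64, 2.0 * np.pi
dx = length / n
x = np.linspace(-length / 2, length / 2, n, endpoint=False)
X, Y = np.meshgrid(x, x, indexing="ij")
k = 2.0 * np.pi * np.fft.fftfreq(n, d=dx)
KX, KY = np.meshgrid(k, k, indexing="ij")
k2 = KX**2 + KY**2

psi = (1.0 + 0.1 * np.cos(X) * np.cos(Y)).astype(complex)
norm0 = np.sum(np.abs(psi) ** 2) * dx * dx

for _ in range(200):
    psi = strang_step(psi, 1e-3, k2, g=1.0)

norm1 = np.sum(np.abs(psi) ** 2) * dx * dx
print("relative norm change:", abs(norm1 - norm0) / norm0)
```

Both substeps are unitary, so the norm of the wavefunction should be conserved up to round-off, which makes it a convenient sanity check for this kind of scheme.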

Running Tests

To run the tests in the tests/ directory:

pytest tests/

TODO

  • Arbitrary boundary conditions
  • Implicit time stepping
  • Multi-GPU support
  • Extend to non-Cartesian domains
  • WandB logging and checkpointing

License

This code is published under the MIT License.

Acknowledgments

This project builds on the excellent JAX ecosystem for scientific computing, and we gratefully acknowledge the open-source libraries it depends on, including diffrax.

We especially thank Patrick Kidger for developing amazing JAX software.
