Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Custom primitives + ordered effects + linear_solve = bug in 0.5 ? #26087

@PhilipVinc

Description

@PhilipVinc

Description

In mpi4jax we use an ordered effect to ensure that our MPI communication is ordered.

Version 0.5 of JAX breaks our ordered-effect logic when our custom primitives are used inside `jax.lax.custom_linear_solve`.
The breakage appears only when the call is jitted; the un-jitted call works fine.

While this might be a bug on our end, the code was working previously, and I'm at a loss as to what could be wrong. Do you have any suggestions about what might be going wrong?

MWE:

# Minimal reproducer: an mpi4jax collective (whose primitive carries an ordered
# effect) is used as the linear operator of jax.scipy.sparse.linalg.cg. cg goes
# through jax.lax.custom_linear_solve, and under jax.jit the lowering fails with
# a KeyError on mpi4jax's OrderedMPIEffect (observed with jax 0.5.0 and
# mpi4jax 0.7.0).
import mpi4jax
from mpi4py import MPI
import jax
from jax.scipy.sparse.linalg import cg

# NOTE(review): the same key `k` is reused for both O and b, and O is never used
# below. Both look irrelevant to the bug being reproduced.
k = jax.random.key(1)
O = jax.random.normal(k, (24, 24))
b = jax.random.normal(k, (24,))

def mat_vec(v):
    """Apply the operator A used by ``cg`` to the vector ``v``.

    Sums ``v`` across all ranks of ``MPI.COMM_WORLD`` via
    ``mpi4jax.allreduce(..., op=MPI.SUM)`` and returns the result. The second
    value returned by ``allreduce`` (a token, per mpi4jax's allreduce in the
    traceback) is discarded. On a single MPI rank this presumably returns ``v``
    unchanged, i.e. A acts as the identity. TODO confirm.
    """
    res, _ = mpi4jax.allreduce(v, op=MPI.SUM, comm=MPI.COMM_WORLD)
    return res
# Wrap the operator in jax.tree_util.Partial so it can be passed as an
# argument to the jitted `cg` call below.
Aop = jax.tree_util.Partial(mat_vec)

# works: the eager (un-jitted) call succeeds
x, info = cg(Aop, b)

# crashes: under jit, lowering raises KeyError: <OrderedMPIEffect ...> inside
# TokenSet.subset, i.e. the token set passed down has no entry for that effect.
x, info = jax.jit(cg)(Aop, b)

And here is the stack trace:

---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File ~/Documents/pythonenvs/netket_pro/bin/ipython:8
      7 sys.argv[0] = re.sub(r"(-script\.pyw|\.exe)?$", "", sys.argv[0])
----> 8 sys.exit(start_ipython())

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/IPython/__init__.py:130, in start_ipython()
    129 from IPython.terminal.ipapp import launch_new_instance
--> 130 return launch_new_instance(argv=argv, **kwargs)

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/traitlets/config/application.py:1074, in launch_instance()
   1073 app = cls.instance(**kwargs)
-> 1074 app.initialize(argv)
   1075 app.start()

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/traitlets/config/application.py:118, in inner()
    117 try:
--> 118     return method(app, *args, **kwargs)
    119 except (TraitError, ArgumentError) as e:

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/IPython/terminal/ipapp.py:284, in initialize()
    283 self.init_extensions()
--> 284 self.init_code()

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/IPython/core/shellapp.py:353, in init_code()
    351 # command-line execution (ipython -i script.py, ipython -m module)
    352 # should *not* be excluded from %whos
--> 353 self._run_cmd_line_code()
    354 self._run_module()

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/IPython/core/shellapp.py:478, in _run_cmd_line_code()
    477 try:
--> 478     self._exec_file(fname, shell_futures=True)
    479 except:

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/IPython/core/shellapp.py:403, in _exec_file()
    401             else:
    402                 # default to python, even without extension
--> 403                 self.shell.safe_execfile(full_filename,
    404                                          self.shell.user_ns,
    405                                          shell_futures=shell_futures,
    406                                          raise_exceptions=True)
    407 finally:

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/IPython/core/interactiveshell.py:2932, in safe_execfile()
   2931     glob, loc = (where + (None, ))[:2]
-> 2932     py3compat.execfile(
   2933         fname, glob, loc,
   2934         self.compile if shell_futures else None)
   2935 except SystemExit as status:
   2936     # If the call was made with 0 or None exit status (sys.exit(0)
   2937     # or sys.exit() ), don't bother showing a traceback, as both of
   (...)
   2943     # For other exit status, we show the exception unless
   2944     # explicitly silenced, but only in short form.

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/IPython/utils/py3compat.py:55, in execfile()
     54 compiler = compiler or compile
---> 55 exec(compiler(f.read(), fname, "exec"), glob, loc)

File ~/Dropbox/Ricerca/Codes/Python/netket/pp.py:19
     18 # crashes
---> 19 x, info = jax.jit(cg)(Aop, b)

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/scipy/sparse/linalg.py:286, in cg()
    234 """Use Conjugate Gradient iteration to solve ``Ax = b``.
    235
    236 The numerics of JAX's ``cg`` should exact match SciPy's ``cg`` (up to
   (...)
    284 jax.lax.custom_linear_solve
    285 """
--> 286 return _isolve(_cg_solve,
    287                A=A, b=b, x0=x0, tol=tol, atol=atol,
    288                maxiter=maxiter, M=M, check_symmetric=True)

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/scipy/sparse/linalg.py:226, in _isolve()
    224 symmetric = all(map(real_valued, tree_leaves(b))) \
    225   if check_symmetric else False
--> 226 x = lax.custom_linear_solve(
    227     A, b, solve=isolve_solve, transpose_solve=isolve_solve,
    228     symmetric=symmetric)
    229 info = None

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/scipy/sparse/linalg.py:128, in _cg_solve()
    126   return x_, r_, gamma_, p_, k + 1
--> 128 r0 = _sub(b, A(x0))
    129 p0 = z0 = M(r0)

File ~/Dropbox/Ricerca/Codes/Python/netket/pp.py:11, in mat_vec()
     10 def mat_vec(v):
---> 11     res, _ = mpi4jax.allreduce(v, op=MPI.SUM, comm=MPI.COMM_WORLD)
     12     return res

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/mpi4jax/_src/validation.py:90, in wrapped()
     84         raise TypeError(
     85             f'{func_name} got unexpected type for argument "{arg}" '
     86             f"(expected: {readable_arg_types}, got: {type(val)})."
     87             f"{extra_message}"
     88         )
---> 90 return function(*args, **kwargs)

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/mpi4jax/_src/collective_ops/allreduce.py:76, in allreduce()
     74     from mpi4jax._src.notoken import allreduce
---> 76     return allreduce(x, op, comm=comm), token
     78 if comm is None:

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/mpi4jax/_src/validation.py:90, in wrapped()
     84         raise TypeError(
     85             f'{func_name} got unexpected type for argument "{arg}" '
     86             f"(expected: {readable_arg_types}, got: {type(val)})."
     87             f"{extra_message}"
     88         )
---> 90 return function(*args, **kwargs)

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/mpi4jax/_src/notoken/collective_ops/allreduce.py:72, in allreduce()
     71 comm = wrap_as_hashable(comm)
---> 72 return mpi_allreduce_p.bind(x, op=op, comm=comm, transpose=False)

JaxStackTraceBeforeTransformation: KeyError: <mpi4jax._src.utils.OrderedMPIEffect object at 0x117c374d0>

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
File ~/Dropbox/Ricerca/Codes/Python/netket/pp.py:19
     16 x, info = cg(Aop, b)
     18 # crashes
---> 19 x, info = jax.jit(cg)(Aop, b)

    [... skipping hidden 1 frame]

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/pjit.py:340, in _cpp_pjit.<locals>.cache_miss(*args, **kwargs)
    335 if config.no_tracing.value:
    336   raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for "
    337                      "`jit`, but 'no_tracing' is set")
    339 (outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, executable,
--> 340  pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
    342 maybe_fastpath_data = _get_fastpath_data(
    343     executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects,
    344     jaxpr.consts, jit_info.abstracted_axes,
    345     pgle_profiler)
    347 return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/pjit.py:198, in _python_pjit_helper(fun, jit_info, *args, **kwargs)
    196   args_flat = map(core.full_lower, args_flat)
    197   core.check_eval_args(args_flat)
--> 198   out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
    199 else:
    200   out_flat = pjit_p.bind(*args_flat, **p.params)

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/pjit.py:1660, in _pjit_call_impl_python(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, compiler_options_kvs, *args)
   1657 compiler_options_kvs = compiler_options_kvs + tuple(pgle_compile_options.items())
   1658 # Passing mutable PGLE profile here since it should be extracted by JAXPR to
   1659 # initialize the fdo_profile compile option.
-> 1660 compiled = _resolve_and_lower(
   1661     args, jaxpr=jaxpr, in_shardings=in_shardings,
   1662     out_shardings=out_shardings, in_layouts=in_layouts,
   1663     out_layouts=out_layouts, resource_env=resource_env,
   1664     donated_invars=donated_invars, name=name, keep_unused=keep_unused,
   1665     inline=inline, lowering_platforms=None,
   1666     lowering_parameters=mlir.LoweringParameters(),
   1667     pgle_profiler=pgle_profiler,
   1668     compiler_options_kvs=compiler_options_kvs,
   1669 ).compile()
   1671 # This check is expensive so only do it if enable_checks is on.
   1672 if compiled._auto_spmd_lowering and config.enable_checks.value:

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/pjit.py:1627, in _resolve_and_lower(args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, lowering_platforms, lowering_parameters, pgle_profiler, compiler_options_kvs)
   1624 in_shardings = _resolve_in_shardings(args, in_shardings)
   1625 in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings,
   1626                                  jaxpr.in_avals)
-> 1627 return _pjit_lower(
   1628     jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
   1629     donated_invars, name, keep_unused, inline, compiler_options_kvs,
   1630     lowering_platforms=lowering_platforms,
   1631     lowering_parameters=lowering_parameters,
   1632     pgle_profiler=pgle_profiler)

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/pjit.py:1792, in _pjit_lower(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, compiler_options_kvs, lowering_platforms, lowering_parameters, pgle_profiler)
   1789 else:
   1790   mesh, api_name = ((resource_env.physical_mesh, 'pjit')
   1791                     if resource_env is not None else (None, 'jit'))
-> 1792 return pxla.lower_sharding_computation(
   1793     jaxpr, api_name, name, in_shardings, out_shardings,
   1794     in_layouts, out_layouts, tuple(donated_invars),
   1795     keep_unused=keep_unused, context_mesh=mesh,
   1796     compiler_options_kvs=compiler_options_kvs,
   1797     lowering_platforms=lowering_platforms,
   1798     lowering_parameters=lowering_parameters,
   1799     pgle_profiler=pgle_profiler)

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/profiler.py:333, in annotate_function.<locals>.wrapper(*args, **kwargs)
    330 @wraps(func)
    331 def wrapper(*args, **kwargs):
    332   with TraceAnnotation(name, **decorator_kwargs):
--> 333     return func(*args, **kwargs)
    334   return wrapper

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/pxla.py:2330, in lower_sharding_computation(closed_jaxpr, api_name, fun_name, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, keep_unused, context_mesh, compiler_options_kvs, lowering_platforms, lowering_parameters, pgle_profiler)
   2324 semantic_in_shardings = SemanticallyEqualShardings(
   2325     in_shardings, global_in_avals)  # type: ignore
   2326 semantic_out_shardings = SemanticallyEqualShardings(
   2327     out_shardings, global_out_avals)  # type: ignore
   2329 (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
-> 2330  nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
   2331      closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
   2332      semantic_out_shardings, in_layouts, out_layouts, num_devices,
   2333      tuple(da_object) if prim_requires_devices else None, donated_invars,
   2334      name_stack, all_default_mem_kind, inout_aliases,
   2335      propagated_out_mem_kinds, platforms,
   2336      lowering_parameters=lowering_parameters,
   2337      abstract_mesh=abstract_mesh)
   2339 # backend and device_assignment is passed through to MeshExecutable because
   2340 # if keep_unused=False and all in_shardings are pruned, then there is no way
   2341 # to get the device_assignment and backend. So pass it to MeshExecutable
   2342 # because we calculate the device_assignment and backend before in_shardings,
   2343 # etc are pruned.
   2344 return MeshComputation(
   2345     str(name_stack),
   2346     module,
   (...)
   2373     intermediate_shardings=unique_intermediate_shardings,
   2374     context_mesh=context_mesh)

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/pxla.py:1960, in _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings, semantic_out_shardings, in_layouts, out_layouts, num_devices, device_assignment, donated_invars, name_stack, all_default_mem_kind, inout_aliases, propagated_out_mem_kinds, platforms, lowering_parameters, abstract_mesh)
   1956 ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
   1957 with dispatch.log_elapsed_time(
   1958       "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time:.9f} sec",
   1959       fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
-> 1960   lowering_result = mlir.lower_jaxpr_to_module(
   1961       module_name,
   1962       closed_jaxpr,
   1963       ordered_effects=ordered_effects,
   1964       backend=backend,
   1965       platforms=platforms,
   1966       axis_context=axis_ctx,
   1967       name_stack=name_stack,
   1968       donated_args=donated_invars,
   1969       replicated_args=replicated_args,
   1970       arg_shardings=in_mlir_shardings,
   1971       result_shardings=out_mlir_shardings,
   1972       in_layouts=in_layouts,
   1973       out_layouts=out_layouts,
   1974       arg_names=jaxpr._debug_info and jaxpr._debug_info.arg_names,
   1975       result_names=jaxpr._debug_info and jaxpr._debug_info.result_paths,
   1976       num_replicas=nreps,
   1977       num_partitions=num_partitions,
   1978       all_default_mem_kind=all_default_mem_kind,
   1979       input_output_aliases=inout_aliases,
   1980       propagated_out_mem_kinds=propagated_out_mem_kinds,
   1981       lowering_parameters=lowering_parameters)
   1982 tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
   1983 unordered_effects = list(
   1984     effects.ordered_effects.filter_not_in(closed_jaxpr.effects))

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:1195, in lower_jaxpr_to_module(***failed resolving arguments***)
   1193   attrs["mhlo.num_partitions"] = i32_attr(num_partitions)
   1194   print("ordered_effects", ordered_effects)
-> 1195   lower_jaxpr_to_fun(
   1196       ctx, "main", jaxpr, ordered_effects,
   1197       name_stack=name_stack,
   1198       public=True,
   1199       replicated_args=replicated_args,
   1200       arg_shardings=arg_shardings,
   1201       result_shardings=result_shardings,
   1202       input_output_aliases=input_output_aliases,
   1203       xla_donated_args=xla_donated_args,
   1204       arg_names=arg_names,
   1205       result_names=result_names,
   1206       arg_memory_kinds=arg_memory_kinds,
   1207       result_memory_kinds=result_memory_kinds,
   1208       arg_layouts=in_layouts,
   1209       result_layouts=out_layouts,
   1210       propagated_out_mem_kinds=propagated_out_mem_kinds)
   1212 try:
   1213   if not ctx.module.operation.verify():

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:1680, in lower_jaxpr_to_fun(ctx, name, jaxpr, effects, name_stack, public, replicated_args, arg_shardings, result_shardings, use_sharding_annotations, input_output_aliases, xla_donated_args, api_name, arg_names, result_names, arg_memory_kinds, result_memory_kinds, arg_layouts, result_layouts, propagated_out_mem_kinds)
   1678   callee_name_stack = name_stack
   1679 consts = [ir_constant(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
-> 1680 out_vals, tokens_out = jaxpr_subcomp(
   1681     ctx, jaxpr.jaxpr, callee_name_stack, tokens_in,
   1682     consts, *args, dim_var_values=dim_var_values)
   1683 outs: list[IrValues] = []
   1684 for eff in effects:

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:1952, in jaxpr_subcomp(ctx, jaxpr, name_stack, tokens, consts, dim_var_values, *args)
   1949   rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)
   1951 assert all(_is_ir_values(v) for v in in_nodes), (eqn, in_nodes)
-> 1952 ans = lower_per_platform(rule_ctx, str(eqn.primitive),
   1953                          platform_rules, default_rule,
   1954                          eqn.effects,
   1955                          *in_nodes, **eqn.params)
   1957 if effects:
   1958   # If there were ordered effects in the primitive, there should be output
   1959   # tokens we need for subsequent ordered effects.
   1960   tokens_out = rule_ctx.tokens_out

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:2070, in lower_per_platform(ctx, description, platform_rules, default_rule, effects, *rule_args, **rule_kwargs)
   2068 # If there is a single rule left just apply the rule, without conditionals.
   2069 if len(kept_rules) == 1:
-> 2070   output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
   2071   map(
   2072       lambda o: wrap_compute_type_in_place(ctx, o.owner),
   2073       filter(_is_not_block_argument, flatten_ir_values(output)),
   2074   )
   2075   map(
   2076       lambda o: wrap_xla_metadata_in_place(ctx, o.owner),
   2077       flatten_ir_values(output),
   2078   )

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:2185, in lower_fun.<locals>.f_lowered(ctx, *args, **params)
   2183 else:
   2184   sub_context = ctx.module_context
-> 2185 out, tokens = jaxpr_subcomp(
   2186     sub_context, jaxpr, ctx.name_stack, ctx.tokens_in,
   2187     _ir_consts(consts), *args,
   2188     dim_var_values=ctx.dim_var_values)
   2189 ctx.set_tokens_out(tokens)
   2190 return out

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:1937, in jaxpr_subcomp(ctx, jaxpr, name_stack, tokens, consts, dim_var_values, *args)
   1934     default_rule = _lowerings[eqn.primitive]
   1936 effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
-> 1937 tokens_in = tokens.subset(effects)
   1938 avals_in = map(aval, eqn.invars)
   1939 rule_ctx = LoweringRuleContext(
   1940     module_context=ctx, primitive=eqn.primitive,
   1941     name_stack=source_info.name_stack,
   1942     avals_in=avals_in,
   1943     avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
   1944     tokens_out=None, jaxpr_eqn_ctx=eqn.ctx, dim_var_values=dim_var_values)

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:1346, in TokenSet.subset(self, effects)
   1344 """Return a subset of the `TokenSet` restricted to a set of `core.Effect`s."""
   1345 #print("tokens are:", self._tokens, "effects are:", effects)
-> 1346 return TokenSet((eff, self._tokens[eff]) for eff in effects)

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:1323, in TokenSet.__init__(self, *args, **kwargs)
   1322 def __init__(self, *args, **kwargs):
-> 1323   self._tokens = collections.OrderedDict(*args, **kwargs)

File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:1346, in <genexpr>(.0)
   1344 """Return a subset of the `TokenSet` restricted to a set of `core.Effect`s."""
   1345 #print("tokens are:", self._tokens, "effects are:", effects)
-> 1346 return TokenSet((eff, self._tokens[eff]) for eff in effects)

KeyError: <mpi4jax._src.utils.OrderedMPIEffect object at 0x117c374d0>

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.5.0
jaxlib: 0.5.0
numpy:  1.26.4
python: 3.13.1 (main, Dec 19 2024, 14:22:59) [Clang 18.1.8 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='mba-10834270.local', release='24.2.0', version='Darwin Kernel Version 24.2.0: Fri Dec  6 19:01:59 PST 2024; root:xnu-11215.61.5~2/RELEASE_ARM64_T6000', machine='arm64')
In [4]: mpi4jax.__version__
Out[4]: '0.7.0'
``

Metadata

Metadata

Assignees

Labels

bug — Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions