-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Closed
Labels
bugSomething isn't workingSomething isn't working
Description
Description
In mpi4jax we use an ordered effect to ensure that our MPI communication is ordered.
Version 0.5 of jax breaks our ordered effect logic when we use our custom primitives inside of jax.lax.custom_linear_solve.
The breakage only appears if the call is jitted; nothing goes wrong if it is not jitted.
While this might be a bug on our end, the code was working previously, and I'm at a loss as to what could be wrong. Do you have any suggestions about what might be going wrong?
MWE:
import mpi4jax
from mpi4py import MPI
import jax
from jax.scipy.sparse.linalg import cg
k = jax.random.key(1)
O = jax.random.normal(k, (24, 24))
b = jax.random.normal(k, (24,))
def mat_vec(v):
    res, _ = mpi4jax.allreduce(v, op=MPI.SUM, comm=MPI.COMM_WORLD)
    return res
Aop = jax.tree_util.Partial(mat_vec)
# works
x, info = cg(Aop, b)
# crashes
x, info = jax.jit(cg)(Aop, b)
and the stack trace:
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File ~/Documents/pythonenvs/netket_pro/bin/ipython:8
7 sys.argv[0] = re.sub(r"(-script\.pyw|\.exe)?$", "", sys.argv[0])
----> 8 sys.exit(start_ipython())
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/IPython/__init__.py:130, in start_ipython()
129 from IPython.terminal.ipapp import launch_new_instance
--> 130 return launch_new_instance(argv=argv, **kwargs)
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/traitlets/config/application.py:1074, in launch_instance()
1073 app = cls.instance(**kwargs)
-> 1074 app.initialize(argv)
1075 app.start()
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/traitlets/config/application.py:118, in inner()
117 try:
--> 118 return method(app, *args, **kwargs)
119 except (TraitError, ArgumentError) as e:
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/IPython/terminal/ipapp.py:284, in initialize()
283 self.init_extensions()
--> 284 self.init_code()
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/IPython/core/shellapp.py:353, in init_code()
351 # command-line execution (ipython -i script.py, ipython -m module)
352 # should *not* be excluded from %whos
--> 353 self._run_cmd_line_code()
354 self._run_module()
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/IPython/core/shellapp.py:478, in _run_cmd_line_code()
477 try:
--> 478 self._exec_file(fname, shell_futures=True)
479 except:
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/IPython/core/shellapp.py:403, in _exec_file()
401 else:
402 # default to python, even without extension
--> 403 self.shell.safe_execfile(full_filename,
404 self.shell.user_ns,
405 shell_futures=shell_futures,
406 raise_exceptions=True)
407 finally:
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/IPython/core/interactiveshell.py:2932, in safe_execfile()
2931 glob, loc = (where + (None, ))[:2]
-> 2932 py3compat.execfile(
2933 fname, glob, loc,
2934 self.compile if shell_futures else None)
2935 except SystemExit as status:
2936 # If the call was made with 0 or None exit status (sys.exit(0)
2937 # or sys.exit() ), don't bother showing a traceback, as both of
(...)
2943 # For other exit status, we show the exception unless
2944 # explicitly silenced, but only in short form.
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/IPython/utils/py3compat.py:55, in execfile()
54 compiler = compiler or compile
---> 55 exec(compiler(f.read(), fname, "exec"), glob, loc)
File ~/Dropbox/Ricerca/Codes/Python/netket/pp.py:19
18 # crashes
---> 19 x, info = jax.jit(cg)(Aop, b)
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/scipy/sparse/linalg.py:286, in cg()
234 """Use Conjugate Gradient iteration to solve ``Ax = b``.
235
236 The numerics of JAX's ``cg`` should exact match SciPy's ``cg`` (up to
(...)
284 jax.lax.custom_linear_solve
285 """
--> 286 return _isolve(_cg_solve,
287 A=A, b=b, x0=x0, tol=tol, atol=atol,
288 maxiter=maxiter, M=M, check_symmetric=True)
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/scipy/sparse/linalg.py:226, in _isolve()
224 symmetric = all(map(real_valued, tree_leaves(b))) \
225 if check_symmetric else False
--> 226 x = lax.custom_linear_solve(
227 A, b, solve=isolve_solve, transpose_solve=isolve_solve,
228 symmetric=symmetric)
229 info = None
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/scipy/sparse/linalg.py:128, in _cg_solve()
126 return x_, r_, gamma_, p_, k + 1
--> 128 r0 = _sub(b, A(x0))
129 p0 = z0 = M(r0)
File ~/Dropbox/Ricerca/Codes/Python/netket/pp.py:11, in mat_vec()
10 def mat_vec(v):
---> 11 res, _ = mpi4jax.allreduce(v, op=MPI.SUM, comm=MPI.COMM_WORLD)
12 return res
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/mpi4jax/_src/validation.py:90, in wrapped()
84 raise TypeError(
85 f'{func_name} got unexpected type for argument "{arg}" '
86 f"(expected: {readable_arg_types}, got: {type(val)})."
87 f"{extra_message}"
88 )
---> 90 return function(*args, **kwargs)
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/mpi4jax/_src/collective_ops/allreduce.py:76, in allreduce()
74 from mpi4jax._src.notoken import allreduce
---> 76 return allreduce(x, op, comm=comm), token
78 if comm is None:
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/mpi4jax/_src/validation.py:90, in wrapped()
84 raise TypeError(
85 f'{func_name} got unexpected type for argument "{arg}" '
86 f"(expected: {readable_arg_types}, got: {type(val)})."
87 f"{extra_message}"
88 )
---> 90 return function(*args, **kwargs)
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/mpi4jax/_src/notoken/collective_ops/allreduce.py:72, in allreduce()
71 comm = wrap_as_hashable(comm)
---> 72 return mpi_allreduce_p.bind(x, op=op, comm=comm, transpose=False)
JaxStackTraceBeforeTransformation: KeyError: <mpi4jax._src.utils.OrderedMPIEffect object at 0x117c374d0>
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
KeyError Traceback (most recent call last)
File ~/Dropbox/Ricerca/Codes/Python/netket/pp.py:19
16 x, info = cg(Aop, b)
18 # crashes
---> 19 x, info = jax.jit(cg)(Aop, b)
[... skipping hidden 1 frame]
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/pjit.py:340, in _cpp_pjit.<locals>.cache_miss(*args, **kwargs)
335 if config.no_tracing.value:
336 raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for "
337 "`jit`, but 'no_tracing' is set")
339 (outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, executable,
--> 340 pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
342 maybe_fastpath_data = _get_fastpath_data(
343 executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects,
344 jaxpr.consts, jit_info.abstracted_axes,
345 pgle_profiler)
347 return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/pjit.py:198, in _python_pjit_helper(fun, jit_info, *args, **kwargs)
196 args_flat = map(core.full_lower, args_flat)
197 core.check_eval_args(args_flat)
--> 198 out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
199 else:
200 out_flat = pjit_p.bind(*args_flat, **p.params)
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/pjit.py:1660, in _pjit_call_impl_python(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, compiler_options_kvs, *args)
1657 compiler_options_kvs = compiler_options_kvs + tuple(pgle_compile_options.items())
1658 # Passing mutable PGLE profile here since it should be extracted by JAXPR to
1659 # initialize the fdo_profile compile option.
-> 1660 compiled = _resolve_and_lower(
1661 args, jaxpr=jaxpr, in_shardings=in_shardings,
1662 out_shardings=out_shardings, in_layouts=in_layouts,
1663 out_layouts=out_layouts, resource_env=resource_env,
1664 donated_invars=donated_invars, name=name, keep_unused=keep_unused,
1665 inline=inline, lowering_platforms=None,
1666 lowering_parameters=mlir.LoweringParameters(),
1667 pgle_profiler=pgle_profiler,
1668 compiler_options_kvs=compiler_options_kvs,
1669 ).compile()
1671 # This check is expensive so only do it if enable_checks is on.
1672 if compiled._auto_spmd_lowering and config.enable_checks.value:
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/pjit.py:1627, in _resolve_and_lower(args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, lowering_platforms, lowering_parameters, pgle_profiler, compiler_options_kvs)
1624 in_shardings = _resolve_in_shardings(args, in_shardings)
1625 in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings,
1626 jaxpr.in_avals)
-> 1627 return _pjit_lower(
1628 jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env,
1629 donated_invars, name, keep_unused, inline, compiler_options_kvs,
1630 lowering_platforms=lowering_platforms,
1631 lowering_parameters=lowering_parameters,
1632 pgle_profiler=pgle_profiler)
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/pjit.py:1792, in _pjit_lower(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, compiler_options_kvs, lowering_platforms, lowering_parameters, pgle_profiler)
1789 else:
1790 mesh, api_name = ((resource_env.physical_mesh, 'pjit')
1791 if resource_env is not None else (None, 'jit'))
-> 1792 return pxla.lower_sharding_computation(
1793 jaxpr, api_name, name, in_shardings, out_shardings,
1794 in_layouts, out_layouts, tuple(donated_invars),
1795 keep_unused=keep_unused, context_mesh=mesh,
1796 compiler_options_kvs=compiler_options_kvs,
1797 lowering_platforms=lowering_platforms,
1798 lowering_parameters=lowering_parameters,
1799 pgle_profiler=pgle_profiler)
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/profiler.py:333, in annotate_function.<locals>.wrapper(*args, **kwargs)
330 @wraps(func)
331 def wrapper(*args, **kwargs):
332 with TraceAnnotation(name, **decorator_kwargs):
--> 333 return func(*args, **kwargs)
334 return wrapper
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/pxla.py:2330, in lower_sharding_computation(closed_jaxpr, api_name, fun_name, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, keep_unused, context_mesh, compiler_options_kvs, lowering_platforms, lowering_parameters, pgle_profiler)
2324 semantic_in_shardings = SemanticallyEqualShardings(
2325 in_shardings, global_in_avals) # type: ignore
2326 semantic_out_shardings = SemanticallyEqualShardings(
2327 out_shardings, global_out_avals) # type: ignore
2329 (module, keepalive, host_callbacks, unordered_effects, ordered_effects,
-> 2330 nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
2331 closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
2332 semantic_out_shardings, in_layouts, out_layouts, num_devices,
2333 tuple(da_object) if prim_requires_devices else None, donated_invars,
2334 name_stack, all_default_mem_kind, inout_aliases,
2335 propagated_out_mem_kinds, platforms,
2336 lowering_parameters=lowering_parameters,
2337 abstract_mesh=abstract_mesh)
2339 # backend and device_assignment is passed through to MeshExecutable because
2340 # if keep_unused=False and all in_shardings are pruned, then there is no way
2341 # to get the device_assignment and backend. So pass it to MeshExecutable
2342 # because we calculate the device_assignment and backend before in_shardings,
2343 # etc are pruned.
2344 return MeshComputation(
2345 str(name_stack),
2346 module,
(...)
2373 intermediate_shardings=unique_intermediate_shardings,
2374 context_mesh=context_mesh)
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/pxla.py:1960, in _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings, semantic_out_shardings, in_layouts, out_layouts, num_devices, device_assignment, donated_invars, name_stack, all_default_mem_kind, inout_aliases, propagated_out_mem_kinds, platforms, lowering_parameters, abstract_mesh)
1956 ordered_effects = list(effects.ordered_effects.filter_in(closed_jaxpr.effects))
1957 with dispatch.log_elapsed_time(
1958 "Finished jaxpr to MLIR module conversion {fun_name} in {elapsed_time:.9f} sec",
1959 fun_name=str(name_stack), event=dispatch.JAXPR_TO_MLIR_MODULE_EVENT):
-> 1960 lowering_result = mlir.lower_jaxpr_to_module(
1961 module_name,
1962 closed_jaxpr,
1963 ordered_effects=ordered_effects,
1964 backend=backend,
1965 platforms=platforms,
1966 axis_context=axis_ctx,
1967 name_stack=name_stack,
1968 donated_args=donated_invars,
1969 replicated_args=replicated_args,
1970 arg_shardings=in_mlir_shardings,
1971 result_shardings=out_mlir_shardings,
1972 in_layouts=in_layouts,
1973 out_layouts=out_layouts,
1974 arg_names=jaxpr._debug_info and jaxpr._debug_info.arg_names,
1975 result_names=jaxpr._debug_info and jaxpr._debug_info.result_paths,
1976 num_replicas=nreps,
1977 num_partitions=num_partitions,
1978 all_default_mem_kind=all_default_mem_kind,
1979 input_output_aliases=inout_aliases,
1980 propagated_out_mem_kinds=propagated_out_mem_kinds,
1981 lowering_parameters=lowering_parameters)
1982 tuple_args = dispatch.should_tuple_args(len(global_in_avals), backend.platform)
1983 unordered_effects = list(
1984 effects.ordered_effects.filter_not_in(closed_jaxpr.effects))
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:1195, in lower_jaxpr_to_module(***failed resolving arguments***)
1193 attrs["mhlo.num_partitions"] = i32_attr(num_partitions)
1194 print("ordered_effects", ordered_effects)
-> 1195 lower_jaxpr_to_fun(
1196 ctx, "main", jaxpr, ordered_effects,
1197 name_stack=name_stack,
1198 public=True,
1199 replicated_args=replicated_args,
1200 arg_shardings=arg_shardings,
1201 result_shardings=result_shardings,
1202 input_output_aliases=input_output_aliases,
1203 xla_donated_args=xla_donated_args,
1204 arg_names=arg_names,
1205 result_names=result_names,
1206 arg_memory_kinds=arg_memory_kinds,
1207 result_memory_kinds=result_memory_kinds,
1208 arg_layouts=in_layouts,
1209 result_layouts=out_layouts,
1210 propagated_out_mem_kinds=propagated_out_mem_kinds)
1212 try:
1213 if not ctx.module.operation.verify():
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:1680, in lower_jaxpr_to_fun(ctx, name, jaxpr, effects, name_stack, public, replicated_args, arg_shardings, result_shardings, use_sharding_annotations, input_output_aliases, xla_donated_args, api_name, arg_names, result_names, arg_memory_kinds, result_memory_kinds, arg_layouts, result_layouts, propagated_out_mem_kinds)
1678 callee_name_stack = name_stack
1679 consts = [ir_constant(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
-> 1680 out_vals, tokens_out = jaxpr_subcomp(
1681 ctx, jaxpr.jaxpr, callee_name_stack, tokens_in,
1682 consts, *args, dim_var_values=dim_var_values)
1683 outs: list[IrValues] = []
1684 for eff in effects:
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:1952, in jaxpr_subcomp(ctx, jaxpr, name_stack, tokens, consts, dim_var_values, *args)
1949 rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)
1951 assert all(_is_ir_values(v) for v in in_nodes), (eqn, in_nodes)
-> 1952 ans = lower_per_platform(rule_ctx, str(eqn.primitive),
1953 platform_rules, default_rule,
1954 eqn.effects,
1955 *in_nodes, **eqn.params)
1957 if effects:
1958 # If there were ordered effects in the primitive, there should be output
1959 # tokens we need for subsequent ordered effects.
1960 tokens_out = rule_ctx.tokens_out
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:2070, in lower_per_platform(ctx, description, platform_rules, default_rule, effects, *rule_args, **rule_kwargs)
2068 # If there is a single rule left just apply the rule, without conditionals.
2069 if len(kept_rules) == 1:
-> 2070 output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
2071 map(
2072 lambda o: wrap_compute_type_in_place(ctx, o.owner),
2073 filter(_is_not_block_argument, flatten_ir_values(output)),
2074 )
2075 map(
2076 lambda o: wrap_xla_metadata_in_place(ctx, o.owner),
2077 flatten_ir_values(output),
2078 )
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:2185, in lower_fun.<locals>.f_lowered(ctx, *args, **params)
2183 else:
2184 sub_context = ctx.module_context
-> 2185 out, tokens = jaxpr_subcomp(
2186 sub_context, jaxpr, ctx.name_stack, ctx.tokens_in,
2187 _ir_consts(consts), *args,
2188 dim_var_values=ctx.dim_var_values)
2189 ctx.set_tokens_out(tokens)
2190 return out
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:1937, in jaxpr_subcomp(ctx, jaxpr, name_stack, tokens, consts, dim_var_values, *args)
1934 default_rule = _lowerings[eqn.primitive]
1936 effects = list(effects_lib.ordered_effects.filter_in(eqn.effects))
-> 1937 tokens_in = tokens.subset(effects)
1938 avals_in = map(aval, eqn.invars)
1939 rule_ctx = LoweringRuleContext(
1940 module_context=ctx, primitive=eqn.primitive,
1941 name_stack=source_info.name_stack,
1942 avals_in=avals_in,
1943 avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
1944 tokens_out=None, jaxpr_eqn_ctx=eqn.ctx, dim_var_values=dim_var_values)
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:1346, in TokenSet.subset(self, effects)
1344 """Return a subset of the `TokenSet` restricted to a set of `core.Effect`s."""
1345 #print("tokens are:", self._tokens, "effects are:", effects)
-> 1346 return TokenSet((eff, self._tokens[eff]) for eff in effects)
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:1323, in TokenSet.__init__(self, *args, **kwargs)
1322 def __init__(self, *args, **kwargs):
-> 1323 self._tokens = collections.OrderedDict(*args, **kwargs)
File ~/Documents/pythonenvs/netket_pro/lib/python3.13/site-packages/jax/_src/interpreters/mlir.py:1346, in <genexpr>(.0)
1344 """Return a subset of the `TokenSet` restricted to a set of `core.Effect`s."""
1345 #print("tokens are:", self._tokens, "effects are:", effects)
-> 1346 return TokenSet((eff, self._tokens[eff]) for eff in effects)
KeyError: <mpi4jax._src.utils.OrderedMPIEffect object at 0x117c374d0>
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.5.0
jaxlib: 0.5.0
numpy: 1.26.4
python: 3.13.1 (main, Dec 19 2024, 14:22:59) [Clang 18.1.8 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='mba-10834270.local', release='24.2.0', version='Darwin Kernel Version 24.2.0: Fri Dec 6 19:01:59 PST 2024; root:xnu-11215.61.5~2/RELEASE_ARM64_T6000', machine='arm64')
In [4]: mpi4jax.__version__
Out[4]: '0.7.0'
Metadata
Metadata
Assignees
Labels
bugSomething isn't workingSomething isn't working