[WIP] Fallback mechanism for mx.np operators#16923
Conversation
| # try to fallback to official NumPy op | ||
| onp_op = _get_np_op(name) | ||
| new_inputs = [arg.asnumpy() if isinstance(arg, ndarray) else arg for arg in inputs] | ||
| out = onp_op(*new_inputs, **kwargs) |
There was a problem hiding this comment.
It will break the computational graph, and could not compute the gradient.
There was a problem hiding this comment.
We are aware of this. More sophisticated fallback mechanism is illustrated in #16698 by leveraging CustomOp. To reach 100% NumPy op coverage within a month, this the simplest and fastest pathway though. In the future, we will gradually replace those fallback ops with native implementation in backend.
There was a problem hiding this comment.
It may be better to use mx.autograd.Function to wrap these numpy operators.
There was a problem hiding this comment.
@wkcn
mx.autograd.Function currently does not support DeepNumpy, some extra infrastructure is required.
Also, I am not sure if mx.autograd.Function could be integrated into HybridBlock, I cannot find corresponding cases covered in the unit tests.
Fix lint Fix
fe362cc to
2cd2094
Compare
* Add fallback mechanism Fix lint Fix * Add unit tests for linalg.cond and heaviside * Add spacing * Fix lint * Skip python2 for dispatching array function Co-authored-by: Hao Jin <[email protected]>
* Add fallback mechanism Fix lint Fix * Add unit tests for linalg.cond and heaviside * Add spacing * Fix lint * Skip python2 for dispatching array function Co-authored-by: Hao Jin <[email protected]>
* Add fallback mechanism Fix lint Fix * Add unit tests for linalg.cond and heaviside * Add spacing * Fix lint * Skip python2 for dispatching array function Co-authored-by: Hao Jin <[email protected]>
Description
Fallback mechanism for
mx.npoperators.@haojin2