Gradient multiplier (contrib) operator#13632
Conversation
Missing test for backwards pass
|
@mxnet-label-bot add[Operator, pr-awaiting-review] |
|
Shouldn't we have a more generic gradient multiplier operator? What d you think? |
|
That is certainly possible, shall I rewrite it? |
szha
left a comment
There was a problem hiding this comment.
Thanks for contributing the op. The forward and backward logic can utilize existing kernels such as those in identity and broadcast_scalar_mul.
|
@szha Thanks for the feedback, good points. However, I have a hard time finding those kernels, to me they seem to be deeply integrated into other operators. Could you please point me to the right functions? |
|
@szha Dumped the header file and used forward and backward from identity / scalar_mul. |
| .set_attr_parser([](NodeAttrs* attrs) { | ||
| attrs->parsed = std::stod(attrs->dict["scalar"]); | ||
| }) | ||
| .set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1, false, true, true>) |
There was a problem hiding this comment.
Do you also plan to support sparse inputs/outputs? If not, you don't have to register FInferStorageType and FComputeEx (by default it infers dense storage and uses FCompute).
There was a problem hiding this comment.
Since the operator is very simple I thought it would be easy to support sparse data as well. What do I need to change to have full support?
|
Thinking to rename the operator to gradient multiplier. Any thoughts? |
| DispatchMode* dispatch_mode, | ||
| std::vector<int> *in_attrs, | ||
| std::vector<int> *out_attrs) { | ||
| CHECK_EQ(in_attrs->size(), 1); |
There was a problem hiding this comment.
This method has no indentation. Is this expected?
There was a problem hiding this comment.
It does, not sure why github shows it wrong
Retrigger flaky test
| [](const NodeAttrs& attrs){ | ||
| return std::vector<bool>{true}; | ||
| }) | ||
| .add_argument("scalar", "float", "scalar input"); |
There was a problem hiding this comment.
consider making this description more informative (e.g. X multiplier)
There was a problem hiding this comment.
Good point, updated.
Improved the description of the scalar multiplier
|
@szha @ThomasDelteil merge? |
* Added the gradient reversal contrib operator Missing test for backwards pass * Fixed linting errors * Fixed forward test * Added random forward / backward test for gradient reversal * Update test_contrib_operator.py * Fixed typo in gradient reversal op description * Replace forward code with the identitiy implementation * Fixed typos in function docs * Changed default behavior to identity * Replaced backward code with scalar_mul * Fixed backward operator and unit test * Renamed operator to gradient multiplier * Update test_contrib_operator.py Retrigger flaky test * Update gradient_multiplier_op.cc Improved the description of the scalar multiplier
* Added the gradient reversal contrib operator Missing test for backwards pass * Fixed linting errors * Fixed forward test * Added random forward / backward test for gradient reversal * Update test_contrib_operator.py * Fixed typo in gradient reversal op description * Replace forward code with the identitiy implementation * Fixed typos in function docs * Changed default behavior to identity * Replaced backward code with scalar_mul * Fixed backward operator and unit test * Renamed operator to gradient multiplier * Update test_contrib_operator.py Retrigger flaky test * Update gradient_multiplier_op.cc Improved the description of the scalar multiplier
Description
Adds the gradient multiplier operator that is mostly used in unsupervised adversarial domain adaptation.
In short: on forward pass it acts as identity transform; on backwards it multiplies the gradients with a scalar constant (lambda).
See full description here: http://proceedings.mlr.press/v37/ganin15.pdf
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.