[jit] Respect order of Parameters in rnn.py#18198
Conversation
eellison
left a comment
There was a problem hiding this comment.
Sorry what exactly was the issue previously and what is the new behavior?
| ConstantParameterList(std::shared_ptr<Module> module) | ||
| : module_(std::move(module)) {} | ||
| ConstantParameterList(std::shared_ptr<Module> module, py::list params) | ||
| : module_(std::move(module)), params_(std::move(params)) {} |
There was a problem hiding this comment.
Maybe convert & store the params to std::string here ? i imagine that's less overhead
There was a problem hiding this comment.
It'd be more since it would duplicate the strings, the py::list is just a thin wrapper around a Python list with the same reference semantics
| return self._flat_weights | ||
|
|
||
| def _get_flat_weights_names(self): | ||
| all_weights = [[weight for weight in weights] for weights in self._all_weights] |
There was a problem hiding this comment.
imo the 4 weight(s) on one line is kidn of confusing
|
|
||
| def _get_flat_weights_names(self): | ||
| all_weights = [[weight for weight in weights] for weights in self._all_weights] | ||
| return [p for layerparams in all_weights for p in layerparams] |
There was a problem hiding this comment.
ditto with the two layerparams / ps
|
@eellison clarified reasons in PR message |
eellison
left a comment
There was a problem hiding this comment.
Looks good, a little hard to follow.
| @_parameter_list | ||
| def get_flat_weights(self): | ||
| def _get_flat_weights_names(self): | ||
| return [weight for weights in self._all_weights for weight in weights] |
There was a problem hiding this comment.
do you mean rewriting? This is hard to read
facebook-github-bot
left a comment
There was a problem hiding this comment.
@driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
@driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
@driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Previously to get a list of parameters this code was just putting them in the reverse order in which they were defined, which is not always right. This PR allows parameter lists to define the order themselves. To do this parameter lists need to have a corresponding function that provides the names of the parameters. Pull Request resolved: pytorch#18198 Differential Revision: D14966270 Pulled By: driazati fbshipit-source-id: 59331aa59408660069785906304b2088c19534b2
Summary: Previously to get a list of parameters this code was just putting them in the reverse order in which they were defined, which is not always right. This PR allows parameter lists to define the order themselves. To do this parameter lists need to have a corresponding function that provides the names of the parameters. Pull Request resolved: pytorch#18198 Differential Revision: D14966270 Pulled By: driazati fbshipit-source-id: 59331aa59408660069785906304b2088c19534b2
Previously to get a list of parameters this code was just putting them in the reverse order in which they were defined, which is not always right. This PR allows parameter lists to define the order themselves. To do this parameter lists need to have a corresponding function that provides the names of the parameters.