Pointwise Tag, use in aot_autograd partitioner #90029
Conversation
…partitioner [ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90029
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 Failures, 3 Pending. As of commit fe63c54, the following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -346,6 +365,9 @@ def is_tensor_node(x):
    # Natalia said that we should allow recomputing indexing :)
    default_recomputable_ops += [aten.index]

    # add more generally?
    default_recomputable_ops += pointwise_ops()
I recall the DTensor pointwise op set might not be a complete set of all the pointwise operators in ATen (although it covers most of them). Do you need to include all possible pointwise ops initially, or can we increase the coverage step by step?
Does DTensor need the narrower set here?
Yeah, possibly DTensor needs a narrower set, as some of the pointwise ops might not make much sense for DTensor, but I would prefer the tag in native_functions to reflect all pointwise ops and leave the subsystems (i.e. DTensor or aot autograd) to decide and filter out the ops they need from the tagged set.
  desc: |
    Pointwise operators are operators where each element of the output is computed only by accessing
    the corresponding element of all the broadcasted inputs. The output shape will be the broadcasted
    shape of the inputs.
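As an illustration of this definition (not part of the patch; torch.add and the specific shapes below are just examples):

```python
import torch

a = torch.randn(3, 1)
b = torch.randn(1, 4)
out = torch.add(a, b)  # pointwise: each output element is a[i, 0] + b[0, j]

# The output shape is the broadcast of the input shapes.
assert out.shape == torch.broadcast_shapes(a.shape, b.shape)

# Perturbing one input element only changes the output elements it broadcasts to.
a2 = a.clone()
a2[1, 0] += 1.0
changed_rows = (torch.add(a2, b) != out).nonzero()[:, 0].unique()
assert changed_rows.tolist() == [1]
```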
Does pointwise imply that it is implemented with TensorIterator?
Pointwise is more a description of how an operator is computed than its underlying implementation. I would imagine all pointwise operators should be computed with TensorIterator but I don't think that's a necessary condition. Open to input though.
The main reason I ask is because there is a bunch of subtle striding behavior, which is unlikely to be done correctly if you're not using TensorIterator under the hood. So, it matters materially for the more subtle invariants whether or not you're buying into "all of the TensorIterator semantics". We don't have to call it TI but I am interested in knowing if we are giving these guarantees.
To add some context about the DTensor pointwise op set: it was initially assembled by manually inspecting every op in native_functions.yaml while referencing https://pytorch.org/docs/stable/torch.html#pointwise-ops, then adding some missing pointwise ops when we actually tried it on real models. I believe this set still does not include all possible pointwise ops (but it's close).
For the meaning of pointwise ops, I think the description from @eellison is a fair enough definition of a pointwise op, but I agree we might need some formal algorithmic guard on it to ensure existing and newly added pointwise ops get tagged too.
                ops.append(opoverloadpacket)
                break

    return ops
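For context, a sketch of what a helper like this might look like, reconstructed around the fragment above rather than copied from the PR: it walks torch.ops.aten and collects the overload packets that have at least one overload carrying the pointwise tag (this assumes torch.Tag.pointwise and OpOverload.tags are available).

```python
import functools

import torch


@functools.lru_cache(None)
def pointwise_ops():
    # Collect every aten OpOverloadPacket with at least one overload tagged
    # pointwise. Note that torch.ops.aten is populated lazily, which is the
    # hazard discussed in the next comment.
    ops = []
    for attr_name in dir(torch.ops.aten):
        opoverloadpacket = getattr(torch.ops.aten, attr_name)
        if not isinstance(opoverloadpacket, torch._ops.OpOverloadPacket):
            continue
        for overload_name in opoverloadpacket.overloads():
            overload = getattr(opoverloadpacket, overload_name)
            if torch.Tag.pointwise in overload.tags:
                # aot autograd currently works with packets, not overloads
                ops.append(opoverloadpacket)
                break
    return ops
```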
If at all possible, we should figure out a way to do this that doesn't involve glomming all of the operators into an LRU cache. The main hazard of doing it this way is that torch.ops.aten is lazily loaded, so there is no guarantee that you will actually have all of the operators at the time you call this function. And this will be even worse if you want to support pointwise ops outside of the aten namespace.
hmmm, what would you suggest? I think it would make sense for there to exist a database that maps each tag to its corresponding operators. cc @anjali411.
That could be lazily loaded as well, but the invocation here would force loading.
yeah, that's a good idea. I think we can just create it in codegen when the native functions are parsed.
Currently there's no API for users to add tags from the Python library API, but in the future we'll have to ensure that this db stays up to date when new ops are added.
I mentioned this in WC but to repeat it here: for use in the partitioner, you don't need the entire set; at the point where you need to test whether an op is pointwise or not, just check the tag directly.
On my end, I need more clarity on what exactly pointwise means, and when it is permissible to add an operator to this set. "We just used the same list from DTensor" doesn't cut it in the long term.
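A minimal sketch of the "check the tag directly" suggestion above (is_pointwise is an illustrative name, not an existing helper):

```python
import torch


def is_pointwise(op) -> bool:
    # Check the tag at the point of use instead of materializing the full set.
    if isinstance(op, torch._ops.OpOverload):
        return torch.Tag.pointwise in op.tags
    if isinstance(op, torch._ops.OpOverloadPacket):
        return any(
            torch.Tag.pointwise in getattr(op, name).tags
            for name in op.overloads()
        )
    return False


print(is_pointwise(torch.ops.aten.add.Tensor))  # expected True once add carries the tag
```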
I started from the DTensor supported list so that it would be easier to remove that list going forward if it carried extra invariants separate from the definition of pointwise ops, and so that you wouldn't have to individually review every operator I'm adding here. I think:
is a concise and adequate description of pointwise. I'm happy to hear alternative definitions. As it stands, the list in AOT Autograd was missing the following operators, which prevents efficient recomputation and the corresponding speedups of dynamo models.
I guess my main concern is (1) how do we know this list, as it stands, is correct, and (2) how do we make sure people keep this list up to date as new operators are added. In the original tags proposal, every tag was supposed to be accompanied by a test that could programmatically determine whether an operator was in the set or not. There isn't any such test currently, which suggests we're pretty unlikely to keep this list up to date going forward.
To ensure that the list is correct and keep it up to date: I recall I had a conversation with @albanD about this. Alban suggested that we can compute the vjp of the operator, and if the Jacobian is a diagonal matrix, then we are certain this is a pointwise op. Maybe we can use this approach to guard the list?
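A rough sketch of that guard for a unary op; looks_pointwise is an illustrative name, not an existing API, and this only probes a single sample input:

```python
import torch


def looks_pointwise(fn, x):
    # If fn is pointwise, output element i depends only on input element i,
    # so the Jacobian at any input should be diagonal.
    jac = torch.autograd.functional.jacobian(fn, x)
    jac = jac.reshape(x.numel(), x.numel())
    off_diag = jac - torch.diag(torch.diagonal(jac))
    return bool(torch.all(off_diag == 0))


x = torch.randn(8, dtype=torch.double)
print(looks_pointwise(torch.sin, x))                     # True
print(looks_pointwise(lambda t: torch.cumsum(t, 0), x))  # False
```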
Checking that the vjp is diagonal will certainly let you test that there aren't functions that are incorrectly tagged this way. But do you really want to test the inverse (that a new function isn't missing the tag) this way? Seems a bit questionable.
tbh it is pretty hard to test programmatically here with the definition above. You would have to test every function, every sample for that function, and every input value for these. I think there are two options here:
I'm also ok with white-box approaches; e.g., we grep the C++ source code or something and check that it uses TensorIterator.
@wanchaol how did you come up with this list originally?
…t_autograd partitioner" Takes the pointwise op list from [DTensor](https://github.com/pytorch/pytorch/blob/master/torch/distributed/_tensor/ops/pointwise_ops.py#L36) as an initial starting point for pointwise ops, and feeds them to the aot autograd partitioner. [ghstack-poisoned]
Okay, I added all of the functions in UnaryOps.cpp, BinaryOps.cpp, TensorCompare.cpp, and TernaryOps.cpp, and I added a test that checks broadcasted shapes. If people have further thoughts on testing, I can do something more exhaustive that we agree on. One test that might make sense would be a DEBUG test that checks, when running TensorIterator, that the most recently dispatched operator includes the pointwise tag. This would be very expensive and we don't even have a DEBUG build right now, but I think it would work well.
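A sketch of the kind of broadcast-shape check described above (check_pointwise_shape is an illustrative name, not the PR's actual test helper):

```python
import torch
from torch.utils._pytree import tree_flatten


def check_pointwise_shape(op, *tensor_inputs):
    # For an op tagged pointwise, every tensor output should have the
    # broadcasted shape of the tensor inputs.
    expected = torch.broadcast_shapes(*(t.shape for t in tensor_inputs))
    out = op(*tensor_inputs)
    for out_elem in tree_flatten(out)[0]:
        if isinstance(out_elem, torch.Tensor):
            assert out_elem.shape == expected, (op, out_elem.shape, expected)


check_pointwise_shape(torch.add, torch.randn(3, 1), torch.randn(1, 4))
check_pointwise_shape(torch.where, torch.randn(3, 1) > 0, torch.randn(1, 4), torch.randn(3, 4))
```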
I agree canonical is problematic, but that's no excuse lol. (And pointwise is being used for much more substantive stuff in PyTorch than canonical right now, so let's hold it to a higher standard)
This doesn't look like it does the regex?
@ezyang did you mean as part of the build process? I grepped for DEFINE_DISPATCH locally and then used that to tag all of the operators.
I mean, have a unit test that runs whatever regex you did, and then checks that the tags match (both positively and negatively).
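A sketch of that kind of unit test; the file list, regex, and comparison step are assumptions rather than the PR's actual test:

```python
import re
from pathlib import Path

# Grep the pointwise kernel files for DEFINE_DISPATCH(..._stub) declarations
# and compare the collected names against the ops carrying the pointwise tag.
regex = re.compile(r"DEFINE_DISPATCH\((\w+)_stub\)")
source_files = [
    "aten/src/ATen/native/UnaryOps.cpp",   # illustrative subset
    "aten/src/ATen/native/BinaryOps.cpp",
]

stub_names = set()
for path in source_files:
    stub_names |= set(regex.findall(Path(path).read_text()))

# stub_names can then be checked, positively and negatively, against the
# operator names tagged pointwise in native_functions.yaml.
```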
"aten.mode.values", | ||
) | ||
|
||
regex = re.compile(r"DEFINE_DISPATCH\(.*_stub") |
Hrrmmm, DEFINE_DISPATCH doesn't actually mean it uses TensorIterator, does it? lol
"aten/src/ATen/native/TensorCompare.cpp", | ||
] | ||
|
||
allowed_functions = ( |
The name allowed here is confusing; what you actually mean is manually denylisted for the pointwise tag, don't you?
Good enough for now. Thank you for humoring me.
@@ -346,6 +365,9 @@ def is_tensor_node(x):
    # Natalia said that we should allow recomputing indexing :)
    default_recomputable_ops += [aten.index]

    # add more generally?
    default_recomputable_ops += pointwise_ops()
We should probably also check that inductor has a lowering for it... will update in another PR.
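A rough sketch of that follow-up check, assuming the internal torch._inductor.lowering.lowerings registry (an internal API whose keys and location may change):

```python
import torch
from torch._inductor.lowering import lowerings  # internal registry; may change


def has_inductor_lowering(op) -> bool:
    # lowerings maps aten ops to inductor lowering functions; an op without an
    # entry would fall back rather than being lowered natively.
    return op in lowerings


print(has_inductor_lowering(torch.ops.aten.add.Tensor))
```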
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
    out_shape = torch._refs._broadcast_shapes(*shapes)

    for out_elem in tree_flatten(out):
do we actually have any pointwise ops that return more than 1 tensor?
oh maybe the _foreach_* ops fall in this category
frexp
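For example, frexp is pointwise but returns two tensors, each with the input's shape:

```python
import torch

# frexp decomposes each element into mantissa and exponent, so it is a
# pointwise op with two outputs; both share the input's shape.
mantissa, exponent = torch.frexp(torch.randn(3, 4))
assert mantissa.shape == exponent.shape == (3, 4)
```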
Merge failed. Reason: 2 additional jobs have failed; the first few of them are: linux-binary-manywheel, linux-binary-manywheel / manywheel-py3_7-cuda11_6-test / build. Details for Dev Infra team: raised by workflow job.
@pytorchbot merge -f "unrelated flakey failure"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…partitioner (pytorch#90029) Takes the pointwise op list from [DTensor](https://github.com/pytorch/pytorch/blob/master/torch/distributed/_tensor/ops/pointwise_ops.py#L36) as an initial starting point for pointwise ops, and feeds them to the aot autograd partitioner. Pull Request resolved: pytorch#90029 Approved by: https://github.com/ezyang
Stack from ghstack (oldest at bottom):
Takes the pointwise op list from DTensor as an initial starting point for pointwise ops, and feeds them to the aot autograd partitioner.
Edit: expanded to cover all ops in Unary, Binary, Ternary, and TensorCompare.