- added size_splits to functional#3837
Conversation
|
Yeah, I think a single function would be nicer. Both |
|
Thanks for the feedback, I will merge it in |
|
I merged both functions now. Before that, I timed all functions on my machine (
What are your thoughts? Any suggestion on the code / naming / documentation? |
|
@pytorchbot test this please |
|
@ptrblck great initiative, I've been using my own wrapper for a while. Would be nice to have it in the code base. You could simplify the code quite a bit though by using plain python operations instead of invoking torch overheads, e.g. something like def split(tensor, sizes, dim=0):
if dim < 0:
dim += tensor.dim()
if isinstance(sizes, int):
# original code ...
return chunks
if tensor.size(dim) != sum(sizes):
raise ValueError("Sizes do not match tensor size in dim")
nsizes = len(sizes)
sizes = [0] + sizes
return tuple(tensor.narrow(dim, sizes[i], sizes[i + 1])
for i in range(nsizes))Should be slightly faster too. |
|
@ptrblck as soon as you add unit tests for the list of splits case, i can merge this in. |
- added tests in test_split for variable sections splits (pytorch#3837)
|
@flennerhag Thanks for the suggestions! I tried to change some Pytorch code to plain python operations. |
|
@pytorchbot test this please |
|
thanks a lot @ptrblck ! |
Pull request addresses issue #3223
The
splitfunction splits tensors into equally sized chunks.split_sizeslet the user define a list with sizes for each chunk.tf.split combines both functionalities in one function. Maybe this is also desired for Pytorch?
split_sizesseems to be a bit slower (6.991s vs 6.704s)