feat: refactor dataset state #87
Conversation
koritsky left a comment
I'm trying to understand what exactly resolved the issue. What I understood from the blog post:
- we want to keep dataset metadata in shared memory so it can be accessed by all processes without being copied
- every time we access a piece of dataset data we increase its refcount
- increasing the refcount turns objects into copy-on-read -> that piece of data is no longer shared but copied into each process -> this gradually increases the unique memory of each process
- the culprit is Python objects, and to avoid this we use `torch.Tensor` whenever we can, because torch serializes objects for multiprocessing by saving them to a file with shared access for all child processes - to implement this we moved all the tensor-like objects into `TensorDict`s, since polars doesn't do this serialization properly.

Is it so?
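If that reading is right, a minimal illustration of the difference (hypothetical, not code from this PR) would look like the following: a plain Python container gets duplicated into each worker as soon as refcount bumps dirty its pages, while a `torch.Tensor` whose storage is moved into shared memory is mapped by all workers without copies.

```python
import torch
import torch.multiprocessing as mp


def reader(shared: torch.Tensor) -> None:
    # Reading only touches the tensor's storage, which lives in a
    # shared-memory file; no per-process copy of the data is made.
    _ = shared.sum()


if __name__ == "__main__":
    samples = torch.arange(1_000_000, dtype=torch.float32)
    samples.share_memory_()  # move storage into shared memory before starting workers

    workers = [mp.Process(target=reader, args=(samples,)) for _ in range(4)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
```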
Reviewed lines:
`from tqdm import tqdm`
…
`@hydra.main(version_base=None)`
❤️
@koritsky I tried to fix the copy-on-read issue ~2 years ago by switching from Python containers to …; IIRC that at least fixed the macro issue of training runs getting OOM'd. Not sure if that attempt was flawed in some way or whether something has changed about polars etc. since. Either way, based on the blog post findings, it seems the correct approach is to keep as much state as possible in …
- update dataset state from a pair of `pl.DataFrame`s to a combo of:
  - `data: TensorDict` for most sample keys (see blog) (a sketch of this split follows below)
  - `meta: pl.DataFrame` for non-tensor-friendly dtypes
  - `streams: dict` for storing metadata about training-time readers
- add dataset saving/loading
- add a `torchdata` node-based dataloader (https://meta-pytorch.org/data/beta/migrate_to_nodes_from_utils.html) for thread workers (see the sketch after this list)
- minor config updates
- tests refactoring
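A minimal sketch of what the three-part dataset state and its saving/loading could look like. The field names and paths here are made up for illustration, not taken from this PR; it only assumes the `tensordict` library's `TensorDict` with its memmap-based persistence and a plain `pl.DataFrame` for the non-tensor columns.

```python
import polars as pl
import torch
from tensordict import TensorDict

# Hypothetical illustration of the data / meta / streams split described above.
data = TensorDict(
    {
        "prices": torch.randn(1_000, 64),
        "mask": torch.ones(1_000, dtype=torch.bool),
    },
    batch_size=[1_000],
)
meta = pl.DataFrame({"symbol": ["AAPL", "MSFT"], "exchange": ["NASDAQ", "NASDAQ"]})
streams = {"train_reader": {"last_index": 0}}

# Saving/loading: a TensorDict can be written out as memory-mapped files, so
# child processes map the same storage instead of pickling per-process copies.
data.memmap_("/tmp/dataset_state/data")
restored = TensorDict.load_memmap("/tmp/dataset_state/data")
```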
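And a hedged sketch of a `torchdata.nodes`-based loader with thread workers, following the migration guide linked above. The sample-construction function and sizes are placeholders, and the exact node classes/signatures may differ between `torchdata` beta releases.

```python
import torch
from torchdata.nodes import Batcher, IterableWrapper, Loader, ParallelMapper


def to_sample(i: int) -> torch.Tensor:
    # Placeholder map function standing in for real sample construction.
    return torch.full((4,), float(i))


node = IterableWrapper(range(1_000))
# method="thread" runs the map function in worker threads rather than processes.
node = ParallelMapper(node, map_fn=to_sample, num_workers=8, method="thread")
node = Batcher(node, batch_size=32)
loader = Loader(node)

for batch in loader:
    # batch is a list of tensors produced by to_sample
    ...
```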