TL;DR
Motivation
Our current APIs for nD parallel training are low-level and complex for common users ... Ideally, we want a simpler, high-level API like this:
Single Device Code
dataset = ...
data_loader = torch.utils.data.DataLoader(dataset, ...)

class Net(nn.Module):
    ...

def optimizer_fn(model):
    ...
    return torch.optim.Adam(model_param_groups, ...)

def lr_scheduler_fn(optimizer):
    ...
    return torch.optim.lr_scheduler.StepLR(optimizer, ...)

model = Net(...)
optimizer = optimizer_fn(model)
scheduler = lr_scheduler_fn(optimizer)

for epoch in range(10):
    for batch in data_loader:
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()
    scheduler.step()

torch.save(model.state_dict(), "/path/to/checkpoint")
torch.save(optimizer.state_dict(), "/path/to/checkpoint")
torch.save(scheduler.state_dict(), "/path/to/checkpoint")

veScale High-Level API for nD Parallel Training
dataset = ...

### zero code change on model
class Net(nn.Module):
    ...

def optimizer_fn(model):
    ...
    return torch.optim.Adam(model_param_groups, ...)

def lr_scheduler_fn(optimizer):
    ...
    return torch.optim.lr_scheduler.StepLR(optimizer, ...)

### create giant model without OOM
model = vescale.deferred_init(Net, ...)

### generate plan of nD parallel training under user constraints
# $ constraints = { "pipeline_parallel.split_method" : "flops",
# $                 "tensor_parallel.sharding_policy" : "megatron" }
plan = vescale.generate_plan(constraints, model)
# $ print(plan)
# $   pipeline_parallel.split_points : ["layer1", "layer3", ...]
# $   tensor_parallel.sharding_plan : { "layer2.weight" : [Shard(dim=0)], ... }

### create nD parallel model and optimizer, specified by the plan
model, optimizer, scheduler, data_loader = vescale.parallelize(plan, model, optimizer_fn, lr_scheduler_fn, dataset)

### zero code change on training loop
for epoch in range(10):
    for batch in data_loader:
        optimizer.zero_grad()
        ### trains nD parallel model as if on single device
        loss = model(batch)
        loss.backward()
        optimizer.step()
    scheduler.step()

### saves nD parallel model and optimizer
vescale.save("/path/to/checkpoint", { "plan": plan, "model": model, "optimizer": optimizer, "lr_scheduler": scheduler })
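For discussion, the generated plan can be thought of as a plain, serializable config. The sketch below just re-expresses the printed plan from the snippet above as a Python dict; the nesting, the field names, and the import path for Shard are assumptions, not a settled format:

# a sketch of the generated plan as a serializable config; structure is illustrative only
from vescale.dtensor.placement_types import Shard  # assumed import path for the Shard placement

plan = {
    "pipeline_parallel": {
        "split_method": "flops",
        "split_points": ["layer1", "layer3"],
    },
    "tensor_parallel": {
        "sharding_policy": "megatron",
        "sharding_plan": {"layer2.weight": [Shard(dim=0)]},
    },
}

Keeping the plan serializable is what lets it be stored next to the model/optimizer state in vescale.save and reused to rebuild the same nD parallelization on resume.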
Idea
- Single Device Abstraction for nD Parallel Training
  - Common users only ever see this high-level API
  - Common users only need to write <10 LoC in their training scripts
  - veScale handles all complexities under the hood (e.g., all low-level APIs)
- This is a unified API for nD parallelism (see the sketch after this list)
  - Take an nD parallel "Plan" (a.k.a. config)
  - Create the nD DeviceMesh
  - Create each D of the nD parallelism
  - Support both Eager and Compile Mode
- This is a future-proof API
  - can extend to future DeviceMesh
  - can extend to future parallelisms (e.g., EP, CP, *P)
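To make the "unified API" bullets concrete, here is a minimal sketch of how the single entry point might be layered over the low-level pieces. It uses PyTorch's init_device_mesh for the mesh step; apply_pipeline_parallel, apply_tensor_parallel, apply_data_parallel, and build_distributed_loader are hypothetical placeholders standing in for veScale's low-level APIs, and the mesh shape is only an example. None of this is the actual implementation:

from torch.distributed.device_mesh import init_device_mesh

def parallelize_sketch(plan, model, optimizer_fn, lr_scheduler_fn, dataset):
    # 1. take the nD parallel plan and build the nD DeviceMesh it describes
    #    (example shape: 2-way pp x 2-way dp x 2-way tp)
    mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "dp", "tp"))
    # 2. create each D of the nD parallelism, one dimension at a time
    model = apply_pipeline_parallel(model, mesh["pp"], plan)  # hypothetical: split at plan's split_points
    model = apply_tensor_parallel(model, mesh["tp"], plan)    # hypothetical: shard per plan's sharding_plan
    model = apply_data_parallel(model, mesh["dp"], plan)      # hypothetical: replicate + gradient sync
    # 3. rebuild optimizer, scheduler, and data loader against the parallelized model
    optimizer = optimizer_fn(model)
    lr_scheduler = lr_scheduler_fn(optimizer)
    data_loader = build_distributed_loader(dataset, mesh["dp"])  # hypothetical helper
    return model, optimizer, lr_scheduler, data_loader

The point of the sketch is the layering: common users only ever touch the plan and the one parallelize call, while DeviceMesh construction and per-dimension setup stay under the hood.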
Feedback is all we need :)