[feat] Add AimCallback for distributed runs using the hugging face API #3284
Conversation
There is a single aim.Run, which the main worker initializes and manages. All auxiliary workers (the local_rank 0 workers hosted on other nodes) collect their metrics and forward them to the main worker, which records them in Aim.

Signed-off-by: Vassilis Vassiliadis <[email protected]>
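To make that data flow concrete, below is a rough, self-contained sketch of the pattern: auxiliary workers only collect and forward metrics, and the main worker is the only process that writes to the run. It is an illustration, not the PR's implementation; a multiprocessing.Queue stands in for whatever channel the callback establishes via main_host/main_port, and a plain dict stands in for aim.Run.

# Conceptual sketch only -- NOT the PR's implementation.
# A multiprocessing.Queue stands in for the main_host/main_port channel,
# and a plain dict stands in for the single aim.Run owned by the main worker.
import multiprocessing as mp


def auxiliary_worker(rank, queue):
    # Auxiliary workers never touch the run; they only forward their metrics.
    for step in range(3):
        queue.put({"name": "loss", "value": 1.0 / (step + 1), "context": {"rank": rank}})
    queue.put(None)  # sentinel: this worker has finished


def main_worker(num_auxiliary, queue):
    run = {}  # stand-in for the single aim.Run
    finished = 0
    while finished < num_auxiliary:
        msg = queue.get()
        if msg is None:
            finished += 1
            continue
        key = (msg["name"], tuple(sorted(msg["context"].items())))
        run.setdefault(key, []).append(msg["value"])  # stand-in for run.track(...)
    print(run)


if __name__ == "__main__":
    queue = mp.Queue()
    workers = [mp.Process(target=auxiliary_worker, args=(rank, queue)) for rank in (1, 2)]
    for w in workers:
        w.start()
    main_worker(num_auxiliary=len(workers), queue=queue)
    for w in workers:
        w.join()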
I have a test that I ran on my laptop. It uses the following config:

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_forward_prefetch: false
fsdp_offload_params: false
fsdp_sharding_strategy: 1
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_cpu_ram_efficient_loading: true
fsdp_sync_module_states: true
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
num_processes: 4

And the following script:

import collections
import random
import time
import typing
import os
import logging
import dataclasses
logging.basicConfig(level=20)
log = logging.getLogger()
log.info("initialize")
import aim.distributed_hugging_face
class Callback(aim.distributed_hugging_face.AimCallback):
    # I use this just to avoid spinning up an AIM Server, it collects the metrics, computes some statistical data,
    # and then dumps those on the disk.
    def on_train_end(self, args, state, control, **kwargs):
        run_metrics = []
        run_hash = None
        if state.is_world_process_zero:
            run_hash = self._run.hash
            run_metrics = [
                (m.name, m.context.to_dict(), m.values.values_list())
                for m in self._run.metrics()
            ]
        super().on_train_end(args, state, control, **kwargs)
        if not state.is_world_process_zero:
            return
        metrics = []
        for name, context, values in run_metrics:
            try:
                len_values = 0
                _sum = 0
                avg = None
                _min = None
                _max = None
                for x in values:
                    if x is None:
                        continue
                    len_values += 1
                    _sum += x
                    if _min is None or _min > x:
                        _min = x
                    if _max is None or _max < x:
                        _max = x
                if len_values > 0:
                    avg = _sum / len_values
                values = {
                    "avg": avg,
                    "max": _max,
                    "min": _min,
                }
            except ValueError:
                # Don't aggregate properties that are weird
                pass
            metrics.append(
                {
                    "name": name,
                    "values": values,
                    "context": context,
                }
            )
        # Standard
        import json

        data = json.dumps(
            {
                "run_hash": run_hash,
                "metrics": metrics,
                "world_size": self.distributed_information.world_size,
            },
            indent=2,
        )
        print(data)
        with open("aim_info.json", "w", encoding="utf-8") as f:
            f.write(data)
cb = Callback(
    main_port=8788,
    repo="my_aim",
    # main_host="localhost",
    system_tracking_interval=1,
    # This is so that we can test this on a single machine
    workers_only_on_rank_0=False,
)
log.info("train_init " + str(dataclasses.asdict(cb.distributed_information)))
sleep_seconds = 10
stateType = collections.namedtuple("state", field_names=["is_world_process_zero"])
state = stateType(is_world_process_zero=cb.distributed_information.rank == 0)
cb.on_train_begin(args=None, state=state, control=None)
sleep_seconds += random.random() * 30
log.info(
    f"sleep for {sleep_seconds}" + str(dataclasses.asdict(cb.distributed_information))
)
time.sleep(sleep_seconds)
log.info("train_end" + str(dataclasses.asdict(cb.distributed_information)))
cb.on_train_end(args=None, state=state, control=None)
log.info("stopped" + str(dataclasses.asdict(cb.distributed_information)))I call the above It takes about a minute to run and then prints the following: {
"run_hash": "05b71ecb672249e1871cec25",
"metrics": [
{
"name": "__system__cpu",
"values": {
"avg": 0.21071428571428577,
"max": 0.6,
"min": 0.1
},
"context": {
"node_rank": 1,
"rank": 1
}
},
{
"name": "__system__disk_percent",
"values": {
"avg": 0.8999999999999996,
"max": 0.9,
"min": 0.9
},
"context": {
"node_rank": 1,
"rank": 1
}
},
{
"name": "__system__memory_percent",
"values": {
"avg": 43.09037757142856,
"max": 43.487668,
"min": 42.708874
},
"context": {
"node_rank": 1,
"rank": 1
}
},
{
"name": "__system__p_memory_percent",
"values": {
"avg": 0.37263560714285715,
"max": 0.372696,
"min": 0.372386
},
"context": {
"node_rank": 1,
"rank": 1
}
},
{
"name": "__system__cpu",
"values": {
"avg": 0.5999999999999998,
"max": 1.0,
"min": 0.0
},
"context": {}
},
{
"name": "__system__disk_percent",
"values": {
"avg": 0.8999999999999996,
"max": 0.9,
"min": 0.9
},
"context": {}
},
{
"name": "__system__memory_percent",
"values": {
"avg": 43.09092353333333,
"max": 43.537211,
"min": 42.702723
},
"context": {}
},
{
"name": "__system__p_memory_percent",
"values": {
"avg": 0.3828104666666667,
"max": 0.383759,
"min": 0.380683
},
"context": {}
}
],
"world_size": 2
}

It uses 2 workers, both of which run on the same node. If you think this is a reasonable approach, we could try it out on actual distributed nodes and verify it works as expected.
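A possible configuration for an actual multi-node run might look like the sketch below (an assumption, not something exercised in this test): it reuses the keyword arguments from the test script above, the hostname is purely illustrative, and the default value of workers_only_on_rank_0 is assumed to restrict auxiliary workers to local_rank 0 on each node.

# Hypothetical multi-node configuration -- an assumption, not part of the PR or the test above.
from aim.distributed_hugging_face import AimCallback

cb = AimCallback(
    main_host="node-0.example.com",  # illustrative hostname of the node that owns the aim.Run
    main_port=8788,
    repo="my_aim",
    system_tracking_interval=1,
    # workers_only_on_rank_0 is left at its (assumed) default here; the single-machine test
    # above sets it to False only so that both workers can run on the same node.
)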
alberttorosyan left a comment
Hey @VassilisVassiliadis! Thanks for your contribution. Changes look good! Approving this PR.
Thanks!
Resolves #3148