Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit daafee4

Browse files
authored
Make Multi Node sampler cycle forever (pytorch#1424)
* multi node sampler cycle forever * test for test flakiness * test for test flakiness
1 parent 4ec4548 commit daafee4

File tree

3 files changed

+29
-2
lines changed

3 files changed

+29
-2
lines changed

test/nodes/test_multi_node_weighted_sampler.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def setUp(self) -> None:
3030
for i in range(self._num_datasets)
3131
}
3232
self.weights = {f"ds{i}": self._weights_fn(i) for i in range(self._num_datasets)}
33+
self.equal_weights = {f"ds{i}": 1.0 for i in range(self._num_datasets)}
3334

3435
def test_torchdata_nodes_imports(self) -> None:
3536
try:
@@ -149,6 +150,23 @@ def test_multi_node_weighted_sampler_cycle_until_all_exhausted(self) -> None:
149150
self.assertEqual(sorted(datasets_in_results), ["ds0", "ds1", "ds2", "ds3"])
150151
mixer.reset()
151152

153+
def test_multi_node_weighted_sampler_cycle_forever(self) -> None:
154+
"""Test MultiNodeWeightedSampler with stop criteria CYCLE_FOREVER"""
155+
mixer = MultiNodeWeightedSampler(
156+
self.datasets,
157+
self.equal_weights,
158+
stop_criteria=StopCriteria.CYCLE_FOREVER,
159+
)
160+
161+
num_yielded = 0
162+
_it = iter(mixer)
163+
while num_yielded < 256: # total number of samples is 4 * 10 = 40, 256 is an arbitrary larger number
164+
next(_it)
165+
num_yielded += 1
166+
167+
mixer_num_yielded = mixer.get_state()[MultiNodeWeightedSampler.NUM_YIELDED_KEY]
168+
self.assertEqual(mixer_num_yielded, num_yielded)
169+
152170
@parameterized.expand([(1, 8), (8, 32)])
153171
def test_multi_node_weighted_batch_sampler_set_rank_world_size(self, rank, world_size):
154172
"""Test MultiNodeWeightedSampler with different rank and world size"""

torchdata/nodes/samplers/multi_node_weighted_sampler.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def _validate(self) -> None:
9191
StopCriteria.CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
9292
StopCriteria.ALL_DATASETS_EXHAUSTED,
9393
StopCriteria.FIRST_DATASET_EXHAUSTED,
94+
StopCriteria.CYCLE_FOREVER,
9495
]:
9596
raise ValueError(
9697
f"Invalid {self.stop_criteria=}. stop_criteria must be one of: CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED, FIRST_DATASET_EXHAUSTED, ALL_DATASETS_EXHAUSTED"
@@ -144,6 +145,10 @@ def _get_new_weighted_sampler(self, initial_state=None):
144145
)
145146

146147
def _check_for_stop_iteration(self) -> None:
148+
if self.stop_criteria == StopCriteria.CYCLE_FOREVER:
149+
# If StopCriteria is CYCLE_FOREVER, we should never raise StopIteration
150+
return
151+
147152
if all(self._datasets_exhausted.values()):
148153
# Raise StopIteration if all datasets are exhausted,
149154
# this covers for both ALL_DATASETS_EXHAUSTED and CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED
@@ -174,14 +179,14 @@ def next(self) -> T:
174179
# Mark the dataset as exhausted
175180
self._datasets_exhausted[key] = True
176181

177-
# Based on updated _check_for_stop_iteration, check if we should raise StopIteration
182+
# Based on updated _datasets_exhausted, use _check_for_stop_iteration to check if we should raise StopIteration
178183
self._check_for_stop_iteration()
179184

180185
# If StopCriteria is ALL_DATASETS_EXHAUSTED, move to next key
181186
if self.stop_criteria == StopCriteria.ALL_DATASETS_EXHAUSTED:
182187
continue
183188

184-
# If StopCriteria is CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED,
189+
# If StopCriteria is CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED or CYCLE_FOREVER,
185190
# reset the iterator and try again
186191
self.source_nodes[key].reset()
187192
item = next(self.source_nodes[key])

torchdata/nodes/samplers/stop_criteria.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@ class StopCriteria:
1717
dataset is seen exactly once. No wraparound or restart will be performed.
1818
1919
3) FIRST_DATASET_EXHAUSTED: Stop when the first dataset is exhausted.
20+
21+
4) CYCLE_FOREVER: Cycle through the datasets by reinitializing each exhausted source node.
22+
This is useful when the trainer wants control over a certain number of steps instead of epochs.
2023
"""
2124

2225
CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED = "CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED"
2326
ALL_DATASETS_EXHAUSTED = "ALL_DATASETS_EXHAUSTED"
2427
FIRST_DATASET_EXHAUSTED = "FIRST_DATASET_EXHAUSTED"
28+
CYCLE_FOREVER = "CYCLE_FOREVER"

0 commit comments

Comments
 (0)