forked from mosaicml/composer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcallback_settings.py
More file actions
328 lines (290 loc) · 10 KB
/
Copy pathcallback_settings.py
File metadata and controls
328 lines (290 loc) · 10 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
import contextlib
import os
from typing import Any
from unittest import mock
from unittest.mock import MagicMock
import pytest
from torch.utils.data import DataLoader
import composer.callbacks
import composer.loggers
import composer.profiler
from composer import Callback
from composer.callbacks import (
EarlyStopper,
ExportForInferenceCallback,
FreeOutputs,
Generate,
ImageVisualizer,
MemoryMonitor,
MemorySnapshot,
MLPerfCallback,
OOMObserver,
SpeedMonitor,
SystemMetricsMonitor,
ThresholdStopper,
)
from composer.callbacks.load_checkpoint import LoadCheckpoint
from composer.loggers import (
CometMLLogger,
ConsoleLogger,
LoggerDestination,
MLFlowLogger,
NeptuneLogger,
ProgressBarLogger,
RemoteUploaderDownloader,
TensorboardLogger,
WandBLogger,
)
from composer.models.base import ComposerModel
from composer.utils import dist
from composer.utils.device import get_device
from tests.common import get_module_subclasses
from tests.common.datasets import RandomClassificationDataset, dummy_gpt_lm_dataloader
from tests.common.models import SimpleModel, configure_tiny_gpt2_hf_model
# Optional-dependency probes.
#
# Each ``_*_INSTALLED`` flag records whether the matching third-party package
# could be imported in this environment. The flags are consumed by the
# ``skipif`` marks in ``_callback_marks``. The probed module is deleted again
# right after the import because only the success/failure outcome matters.
try:
    import wandb
    _WANDB_INSTALLED = True
    del wandb  # unused
except ImportError:
    _WANDB_INSTALLED = False

try:
    import tensorboard
    _TENSORBOARD_INSTALLED = True
    del tensorboard  # unused
except ImportError:
    _TENSORBOARD_INSTALLED = False

try:
    import comet_ml
    del comet_ml  # unused; dropped before the key lookup so the name never lingers
    # Comet counts as usable only when an API key is configured as well:
    # the lookup raises KeyError if COMET_API_KEY is not set.
    os.environ['COMET_API_KEY']
    _COMETML_INSTALLED = True
except (ImportError, KeyError):
    _COMETML_INSTALLED = False

try:
    import mlperf_logging
    _MLPERF_INSTALLED = True
    del mlperf_logging  # unused
except ImportError:
    _MLPERF_INSTALLED = False

try:
    import mlflow
    _MLFLOW_INSTALLED = True
    del mlflow  # unused
except ImportError:
    _MLFLOW_INSTALLED = False

try:
    import libcloud
    _LIBCLOUD_INSTALLED = True
    del libcloud  # unused
except ImportError:
    _LIBCLOUD_INSTALLED = False

try:
    # The NVML bindings are imported as ``pynvml``. The earlier spelling
    # (``pynmvl``) does not match that module name, so the probe could not
    # succeed for the real package. The flag keeps its historical name because
    # ``_callback_marks`` refers to it.
    import pynvml
    _PYNMVL_INSTALLED = True
    del pynvml  # unused
except ImportError:
    _PYNMVL_INSTALLED = False

try:
    import neptune
    _NEPTUNE_INSTALLED = True
    del neptune  # unused
except ImportError:
    _NEPTUNE_INSTALLED = False
# Keyword arguments used to construct each callback/logger/profiler class in
# tests, keyed by the class itself. ``get_cb_kwargs`` reads this table and
# falls back to an empty dict for any class that has no entry here.
_callback_kwargs: dict[type[Callback], dict[str, Any]] = {
    Generate: {'prompts': ['a', 'b', 'c'], 'interval': '1ba', 'batch_size': 2, 'max_new_tokens': 20},
    RemoteUploaderDownloader: {
        'bucket_uri': 'libcloud://.',
        'backend_kwargs': {'provider': 'local', 'container': '.', 'provider_kwargs': {'key': '.'}},
        'use_procs': False,
        'num_concurrent_uploads': 1,
    },
    ThresholdStopper: {'monitor': 'MulticlassAccuracy', 'dataloader_label': 'train', 'threshold': 0.99},
    EarlyStopper: {'monitor': 'MulticlassAccuracy', 'dataloader_label': 'train'},
    ExportForInferenceCallback: {'save_format': 'torchscript', 'save_path': '/tmp/model.pth'},
    MLPerfCallback: {'root_folder': '.', 'index': 0},
    SpeedMonitor: {'window_size': 1},
    NeptuneLogger: {'mode': 'debug'},
    WandBLogger: {'init_kwargs': {'mode': 'offline'}},
    composer.profiler.Profiler: {
        # A MagicMock stands in for a real trace handler.
        'trace_handlers': [MagicMock()],
        'schedule': composer.profiler.cyclic_schedule(),
    },
    LoadCheckpoint: {'load_path': 'fake-path'},
}
# Pytest marks attached to each class when it is turned into a ``pytest.param``
# by ``_to_pytest_param``. Two kinds of marks appear here:
#   * ``skipif``: the class needs an optional package that was not importable
#     (see the ``_*_INSTALLED`` flags above).
#   * ``filterwarnings``: the class emits a warning that is safe to ignore in tests.
# Classes without an entry are parametrized with no extra marks.
_callback_marks: dict[
    type[Callback],
    list[pytest.MarkDecorator],
] = {
    RemoteUploaderDownloader: [
        pytest.mark.filterwarnings(
            # post_close might not be called if being used outside of the trainer
            r'ignore:Implicitly cleaning up:ResourceWarning',
        ),
        pytest.mark.skipif(not _LIBCLOUD_INSTALLED, reason='Libcloud is optional'),
    ],
    MemoryMonitor: [
        pytest.mark.filterwarnings(
            r'ignore:The memory monitor only works on CUDA devices, but the model is on cpu:UserWarning',
        ),
    ],
    MemorySnapshot: [
        pytest.mark.filterwarnings(
            r'ignore:The memory snapshot only works on CUDA devices, but the model is on cpu:UserWarning',
        ),
    ],
    OOMObserver: [
        pytest.mark.filterwarnings(
            r'ignore:The oom observer only works on CUDA devices, but the model is on cpu:UserWarning',
        ),
    ],
    MLPerfCallback: [pytest.mark.skipif(not _MLPERF_INSTALLED, reason='MLPerf is optional')],
    WandBLogger: [
        pytest.mark.filterwarnings(r'ignore:unclosed file:ResourceWarning'),
        pytest.mark.skipif(not _WANDB_INSTALLED, reason='Wandb is optional'),
    ],
    ProgressBarLogger: [
        pytest.mark.filterwarnings(
            r'ignore:Specifying the ProgressBarLogger via `loggers` is not recommended as.*:Warning',
        ),
    ],
    ConsoleLogger: [
        pytest.mark.filterwarnings(
            r'ignore:Specifying the ConsoleLogger via `loggers` is not recommended as.*:Warning',
        ),
    ],
    CometMLLogger: [pytest.mark.skipif(not _COMETML_INSTALLED, reason='comet_ml is optional')],
    TensorboardLogger: [pytest.mark.skipif(not _TENSORBOARD_INSTALLED, reason='Tensorboard is optional')],
    # ImageVisualizer is gated on the same flag as WandBLogger.
    ImageVisualizer: [pytest.mark.skipif(not _WANDB_INSTALLED, reason='Wandb is optional')],
    MLFlowLogger: [pytest.mark.skipif(not _MLFLOW_INSTALLED, reason='mlflow is optional')],
    # The skip reason names the real package (``pynvml``); the flag keeps its
    # historical identifier.
    SystemMetricsMonitor: [pytest.mark.skipif(not _PYNMVL_INSTALLED, reason='pynvml is optional')],
    NeptuneLogger: [pytest.mark.skipif(not _NEPTUNE_INSTALLED, reason='neptune is optional')],
}
def _mlflow_patch():
try:
import mlflow.utils.file_utils
original_is_directory = mlflow.utils.file_utils.is_directory
def patched_is_directory(path):
if path.endswith('.trash'):
return True
return original_is_directory(path)
return mock.patch('mlflow.utils.file_utils.is_directory', patched_is_directory)
except ImportError:
return contextlib.nullcontext()
# Patch factories keyed by callback class, consulted by ``get_cb_patches``.
# Each value is a zero-argument callable returning a context manager, so a new
# patch object is produced on every lookup.
_callback_patches: dict[type[Callback], Any] = {
    LoadCheckpoint: lambda: mock.patch('composer.callbacks.load_checkpoint.load_checkpoint'),
    MLFlowLogger: _mlflow_patch,
}
def get_cb_patches(impl: type[Callback]):
    """Return the context manager to wrap around usage of ``impl`` in tests.

    Args:
        impl: The callback class to look up in ``_callback_patches``.

    Returns:
        ``contextlib.nullcontext()`` when ``impl`` has no entry (or its entry is
        ``None``); the result of calling the entry when it is callable;
        otherwise the entry itself.
    """
    registered = _callback_patches.get(impl)
    if registered is None:
        return contextlib.nullcontext()
    # Factories are invoked so each caller gets its own patch object.
    return registered() if callable(registered) else registered
def get_cb_kwargs(impl: type[Callback]):
    """Return the constructor keyword arguments registered for ``impl``.

    A shallow copy is returned, so a caller that adds, removes or rebinds
    top-level keys does not alter ``_callback_kwargs`` for later callers. The
    values themselves are not copied and remain shared.

    Args:
        impl: The callback class to look up in ``_callback_kwargs``.

    Returns:
        dict: A new dict with the registered keyword arguments, or an empty
        dict when ``impl`` has no entry.
    """
    return dict(_callback_kwargs.get(impl, {}))
def _to_pytest_param(impl):
    """Wrap ``impl`` in a ``pytest.param``, attaching its marks from ``_callback_marks`` if any.

    Args:
        impl: The class to parametrize over.

    Returns:
        ``pytest.param(impl, marks=...)`` when ``impl`` has an entry in
        ``_callback_marks``; plain ``pytest.param(impl)`` otherwise.
    """
    try:
        registered_marks = _callback_marks[impl]
    except KeyError:
        # No special handling registered for this class.
        return pytest.param(impl)
    return pytest.param(impl, marks=registered_marks)
def get_cbs_and_marks(callbacks: bool = False, loggers: bool = False, profilers: bool = False):
    """Collect ``pytest.param`` objects for the selected groups of :class:`.Callback` classes.

    Every class is wrapped with the marks registered for it in ``_callback_marks``:
    ``skipif`` marks for optional dependencies and ``filterwarnings`` marks for
    warnings that are safe to ignore.

    Intended usage::

        import pytest
        from tests.callbacks.callback_settings import get_cbs_and_marks, get_cb_kwargs

        @pytest.mark.parametrize("cb_cls", get_cbs_and_marks(callbacks=True, loggers=True, profilers=True))
        def test_something(cb_cls: type[Callback]):
            cb_kwargs = get_cb_kwargs(cb_cls)
            cb = cb_cls(**cb_kwargs)
            assert isinstance(cb, Callback)

    Args:
        callbacks: Include the ``Callback`` subclasses found in ``composer.callbacks``.
        loggers: Include the ``LoggerDestination`` subclasses found in ``composer.loggers``.
        profilers: Include the ``Callback`` subclasses found in ``composer.profiler``.

    Returns:
        list: One ``pytest.param`` per discovered class, in the order
        callbacks, loggers, profilers.

    Raises:
        ValueError: If nothing was collected (for example, when all three flags
            are false).
    """
    # (enabled?, module to scan, base class to match), in the order results are emitted.
    sources = (
        (callbacks, composer.callbacks, Callback),
        (loggers, composer.loggers, LoggerDestination),
        (profilers, composer.profiler, Callback),
    )

    discovered = []
    for enabled, module, base_cls in sources:
        if enabled:
            discovered += get_module_subclasses(module, base_cls)

    params = [_to_pytest_param(cls) for cls in discovered]
    if not params:
        raise ValueError('callbacks, loggers, or profilers must be True')
    return params
def get_cb_hparams_and_marks():
    """Collect ``pytest.param`` objects for ``callback_registry`` and ``logger_registry`` entries.

    Each entry is meant to carry the ``skipif`` marks for optional dependencies
    and the ``filterwarnings`` marks for warnings that are safe to ignore.

    Note:
        The list of implementations has not been populated yet, so this
        currently always returns an empty list.

    Intended usage::

        import pytest
        from tests.common.hparams import construct_from_yaml
        from tests.callbacks.callback_settings import get_cb_hparams_and_marks, get_cb_kwargs

        @pytest.mark.parametrize("constructor", get_cb_hparams_and_marks())
        def test_something(constructor: Callable, yaml_dict: dict[str, Any]):
            yaml_dict = get_cb_kwargs(constructor)
            construct_from_yaml(constructor, yaml_dict=yaml_dict)

    Returns:
        list: One ``pytest.param`` per registry entry (currently none).
    """
    # TODO: (Hanlin) populate this
    registry_entries = []

    params = []
    for entry in registry_entries:
        params.append(_to_pytest_param(entry))
    return params
def get_cb_model_and_datasets(
    cb: Callback,
    dl_size=100,
    **default_dl_kwargs,
) -> tuple[ComposerModel, DataLoader, DataLoader]:
    """Pick a model and two dataloaders suitable for exercising ``cb``.

    Args:
        cb: The callback instance under test; its type selects the model.
        dl_size: ``size`` passed to each dataset / dataloader factory.
        **default_dl_kwargs: Extra keyword arguments forwarded to both
            ``DataLoader`` constructors. They are not used when ``cb`` is a
            ``Generate`` callback.

    Returns:
        A ``(model, dataloader, dataloader)`` tuple. For ``Generate`` this is the
        tiny GPT-2 model with two ``dummy_gpt_lm_dataloader`` loaders; otherwise
        a ``SimpleModel`` with two loaders over ``RandomClassificationDataset``.
    """
    if isinstance(cb, Generate):
        # Short-circuits: the world size is only queried when the device is named 'cpu'.
        cpu_multi_rank = get_device(None).name == 'cpu' and dist.get_world_size() > 1
        if cpu_multi_rank:
            pytest.xfail(
                'GPT2 is not currently supported with DDP. See https://github.com/huggingface/transformers/issues/22482 for more details.',
            )
        gpt2_model = configure_tiny_gpt2_hf_model()
        first_loader = dummy_gpt_lm_dataloader(size=dl_size)
        second_loader = dummy_gpt_lm_dataloader(size=dl_size)
        return gpt2_model, first_loader, second_loader

    simple_model = SimpleModel()
    if isinstance(cb, FreeOutputs):
        # For FreeOutputs the model's metrics are replaced with an empty mapping.
        def _no_metrics(is_train=False):
            return {}

        simple_model.get_metrics = _no_metrics

    first_loader = DataLoader(RandomClassificationDataset(size=dl_size), **default_dl_kwargs)
    second_loader = DataLoader(RandomClassificationDataset(size=dl_size), **default_dl_kwargs)
    return simple_model, first_loader, second_loader