1
1
import getpass
2
-
3
- import dill
4
2
import os
5
- import pandas as pd
6
- import pkg_resources
7
3
import platform
8
4
from base64 import b64encode , b64decode
9
5
from datetime import datetime
10
6
from itertools import product
11
7
from uuid import uuid4
12
8
9
+ import dill
10
+ import pandas as pd
11
+ import pkg_resources
12
+ import pymongo
13
+
13
14
from omegaml .backends .basemodel import BaseModelBackend
14
15
from omegaml .documents import Metadata
15
- from omegaml .util import _raise , settings
16
+ from omegaml .util import _raise , settings , ensure_index
16
17
17
18
18
19
class ExperimentBackend (BaseModelBackend ):
@@ -183,12 +184,24 @@ def active_run(self):
183
184
def status (self , run = None ):
184
185
return 'STOPPED'
185
186
186
- def start (self ):
187
+ def start (self , run = None ):
187
188
raise NotImplementedError
188
189
189
190
def stop (self ):
190
191
raise NotImplementedError
191
192
193
+ def start_runtime (self ):
194
+ # hook to signal the runtime is starting a task inside a worker
195
+ # this is unlike the .start() method which is called to start a run
196
+ # which can happen in the client or in the runtime
197
+ pass
198
+
199
+ def stop_runtime (self ):
200
+ # hook to signal the runtime has completed a task inside a worker
201
+ # this is unlike the .stop() method which is called to stop a run
202
+ # which can happen in the client or in the runtime
203
+ self .flush ()
204
+
192
205
def log_event (self , event , key , value , step = None , ** extra ):
193
206
raise NotImplementedError
194
207
@@ -211,6 +224,9 @@ def tensorflow_callback(self):
211
224
def data (self , experiment = None , run = None , event = None , step = None , key = None , raw = False ):
212
225
raise NotImplementedError
213
226
227
+ def flush (self ):
228
+ pass
229
+
214
230
@property
215
231
def _data_name (self ):
216
232
return f'.experiments/{ self ._experiment } '
@@ -219,7 +235,7 @@ def _data_name(self):
219
235
class NoTrackTracker (TrackingProvider ):
220
236
""" A default tracker that does not record anything """
221
237
222
- def start (self ):
238
+ def start (self , run = None ):
223
239
pass
224
240
225
241
def stop (self ):
@@ -261,6 +277,12 @@ class OmegaSimpleTracker(TrackingProvider):
261
277
_ensure_active = lambda self , r : r if r is not None else _raise (
262
278
ValueError ('no active run, call .start() or .use() ' ))
263
279
280
+ def __init__ (self , * args , ** kwargs ):
281
+ super ().__init__ (* args , ** kwargs )
282
+ self .log_buffer = []
283
+ self .max_buffer = 10
284
+ self ._initialize_dataset ()
285
+
264
286
def active_run (self , run = None ):
265
287
""" set the lastest run as the active run
266
288
@@ -293,8 +315,9 @@ def use(self, run=None):
293
315
294
316
@property
295
317
def _latest_run (self ):
296
- data = self .data (event = 'start' , raw = True )
297
- run = data [- 1 ]['run' ] if data is not None and len (data ) > 0 else None
318
+ cursor = self .data (event = 'start' , lazy = True )
319
+ data = list (cursor .sort ('data.run' , - 1 ).limit (1 )) if cursor else None
320
+ run = data [- 1 ].get ('data' , {}).get ('run' ) if data is not None and len (data ) > 0 else None
298
321
return run
299
322
300
323
def status (self , run = None ):
@@ -307,7 +330,7 @@ def status(self, run=None):
307
330
status in 'STARTED', 'STOPPED'
308
331
"""
309
332
self ._run = run or self ._run or self ._latest_run
310
- data = self .data (event = ( 'start' , 'stop' ) , run = self ._run , raw = True )
333
+ data = self .data (event = [ 'start' , 'stop' ] , run = self ._run , raw = True )
311
334
no_runs = data is None or len (data ) == 0
312
335
has_stop = sum (1 for row in (data or []) if row .get ('event' ) == 'stop' )
313
336
return 'PENDING' if no_runs else 'STOPPED' if has_stop else 'STARTED'
@@ -320,7 +343,7 @@ def start(self, run=None):
320
343
self ._run = run or (self ._latest_run or 0 ) + 1
321
344
self ._startdt = datetime .utcnow ()
322
345
data = self ._common_log_data ('start' , key = None , value = None , step = None , dt = self ._startdt )
323
- self ._write_log (data )
346
+ self ._write_log (data , immediate = True )
324
347
return self ._run
325
348
326
349
def stop (self ):
@@ -331,6 +354,14 @@ def stop(self):
331
354
self ._stopdt = datetime .utcnow ()
332
355
data = self ._common_log_data ('stop' , key = None , value = None , step = None , dt = self ._stopdt )
333
356
self ._write_log (data )
357
+ self .flush ()
358
+
359
+ def flush (self ):
360
+ # passing list of list forces insert_many
361
+ if self .log_buffer :
362
+ self ._store .put (self .log_buffer , self ._data_name ,
363
+ noversion = True , as_many = True )
364
+ self .log_buffer .clear ()
334
365
335
366
def _common_log_data (self , event , key , value , step = None , dt = None , ** extra ):
336
367
if isinstance (value , dict ):
@@ -358,8 +389,10 @@ def _common_log_data(self, event, key, value, step=None, dt=None, **extra):
358
389
data .update (self ._extra_log ) if self ._extra_log else None
359
390
return data
360
391
361
- def _write_log (self , data ):
362
- self ._store .put (data , self ._data_name , noversion = True )
392
+ def _write_log (self , data , immediate = False ):
393
+ self .log_buffer .append (data )
394
+ if immediate or len (self .log_buffer ) > self .max_buffer :
395
+ self .flush ()
363
396
364
397
def log_artifact (self , obj , name , step = None , ** extra ):
365
398
""" log any object to the current run
@@ -396,7 +429,7 @@ def log_artifact(self, obj, name, step=None, **extra):
396
429
meta = self ._model_store .put (obj , f'.experiments/.artefacts/{ objname } ' )
397
430
format = 'model'
398
431
rawdata = meta .name
399
- elif self ._store .get_backend_by_obj (obj ) is not None :
432
+ elif self ._store .get_backend_byobj (obj ) is not None :
400
433
objname = uuid4 ().hex
401
434
meta = self ._store .put (obj , f'.experiments/.artefacts/{ objname } ' )
402
435
format = 'dataset'
@@ -485,16 +518,17 @@ def log_extra(self, remove=False, **kwargs):
485
518
consume (deletions , maxlen = 0 )
486
519
487
520
def data (self , experiment = None , run = None , event = None , step = None , key = None , raw = False ,
488
- ** extra ):
521
+ lazy = False , ** extra ):
489
522
""" build a dataframe of all stored data
490
523
491
524
Args:
492
- experiment (str): the name of the experiment, defaults to its current value
525
+ experiment (str|list ): the name of the experiment, defaults to its current value
493
526
run (int|list): the run(s) to get data back, defaults to current run, use 'all' for all
494
527
event (str|list): the event(s) to include
495
528
step (int|list): the step(s) to include
496
529
key (str|list): the key(s) to include
497
530
raw (bool): if True returns the raw data instead of a DataFrame
531
+ lazy (bool): if True returns the Cursor instead of data, ignores raw
498
532
499
533
Returns:
500
534
* data (DataFrame) if raw == False
@@ -519,12 +553,19 @@ def data(self, experiment=None, run=None, event=None, step=None, key=None, raw=F
519
553
for k , v in extra .items ():
520
554
if valid (k ):
521
555
filter [f'data.{ k } ' ] = op (v )
522
- data = self ._store .get (self ._data_name , filter = filter )
523
- if data is not None and not raw :
556
+ data = self ._store .get (self ._data_name , filter = filter , lazy = lazy )
557
+ if data is not None and not raw and not lazy :
524
558
data = pd .DataFrame .from_records (data )
525
559
data .sort_values ('dt' , inplace = True ) if 'dt' in data .columns else None
526
560
return data
527
561
562
+ def _initialize_dataset (self , force = False ):
563
+ # create indexes when the dataset is first created
564
+ if not force and self ._store .exists (self ._data_name ):
565
+ return
566
+ coll = self ._store .collection (self ._data_name )
567
+ ensure_index (coll , {'data.run' : pymongo .ASCENDING , 'data.event' : pymongo .ASCENDING })
568
+
528
569
def restore_artifact (self , key = None , experiment = None , run = None , step = None , value = None ):
529
570
""" restore a logged artificat
530
571
@@ -596,7 +637,7 @@ class OmegaProfilingTracker(OmegaSimpleTracker):
596
637
def __init__ (self , * args , ** kwargs ):
597
638
super ().__init__ (* args , ** kwargs )
598
639
self .profile_logs = []
599
- self .max_buffer = 6
640
+ self .max_buffer = 10
600
641
601
642
def log_profile (self , data ):
602
643
""" the callback for BackgroundProfiler """
@@ -605,27 +646,32 @@ def log_profile(self, data):
605
646
self .flush ()
606
647
607
648
def flush (self ):
608
- for step , data in enumerate (self .profile_logs ):
609
- # record the actual time instead of logging time (avoid buffering delays)
610
- dt = data .get ('profile_dt' )
611
- for k , v in data .items ():
612
- self .log_event ('profile' , k , v , step = step , dt = dt )
613
- self .profile_logs = []
614
-
615
- def start (self ):
649
+ def log_items ():
650
+ for step , data in enumerate (self .profile_logs ):
651
+ # record the actual time instead of logging time (avoid buffering delays)
652
+ dt = data .get ('profile_dt' )
653
+ for k , v in data .items ():
654
+ item = self ._common_log_data ('profile' , k , v , step = step , dt = dt )
655
+ yield item
656
+ if self .profile_logs :
657
+ self ._store .put ([item for item in log_items ()], self ._data_name ,
658
+ index = ['event' ], as_many = True , noversion = True )
659
+ self .profile_logs = []
660
+
661
+ def start_runtime (self ):
616
662
self .profiler = BackgroundProfiler (callback = self .log_profile )
617
663
self .profiler .start ()
618
- super ().start ()
664
+ super ().start_runtime ()
619
665
620
- def stop (self ):
666
+ def stop_runtime (self ):
621
667
self .profiler .stop ()
622
668
self .flush ()
623
- super ().stop ()
669
+ super ().stop_runtime ()
624
670
625
671
626
672
try :
627
673
from tensorflow import keras
628
- except :
674
+ except Exception :
629
675
pass
630
676
else :
631
677
class TensorflowCallback (keras .callbacks .Callback ):
0 commit comments