2222from typing import Any
2323from typing import Optional
2424from typing import TypeVar
25+ from typing import Union
2526
2627import apache_beam as beam
2728from apache_beam .coders import DillCoder
3738from apache_beam .ml .anomaly .specifiable import Specifiable
3839from apache_beam .ml .inference .base import RunInference
3940from apache_beam .transforms .userstate import ReadModifyWriteStateSpec
41+ from apache_beam .typehints .typehints import TupleConstraint
4042
4143KeyT = TypeVar ('KeyT' )
42- TempKeyT = TypeVar ('TempKeyT' , bound = int )
43- InputT = tuple [KeyT , beam .Row ]
44- KeyedInputT = tuple [KeyT , tuple [TempKeyT , beam .Row ]]
45- KeyedOutputT = tuple [KeyT , tuple [TempKeyT , AnomalyResult ]]
46- OutputT = tuple [KeyT , AnomalyResult ]
44+ TempKeyT = TypeVar ('TempKeyT' , bound = str )
45+ InputT = beam .Row
46+ OutputT = AnomalyResult
47+ KeyedInputT = tuple [KeyT , beam .Row ]
48+ KeyedOutputT = tuple [KeyT , AnomalyResult ]
49+ NestedKeyedInputT = tuple [KeyT , tuple [TempKeyT , beam .Row ]]
50+ NestedKeyedOutputT = tuple [KeyT , tuple [TempKeyT , AnomalyResult ]]
4751
4852
4953class _ScoreAndLearnDoFn (beam .DoFn ):
@@ -86,9 +90,9 @@ def score_and_learn(self, data):
8690
8791 def process (
8892 self ,
89- element : KeyedInputT ,
93+ element : NestedKeyedInputT ,
9094 model_state = beam .DoFn .StateParam (MODEL_STATE_INDEX ),
91- ** kwargs ) -> Iterable [KeyedOutputT ]:
95+ ** kwargs ) -> Iterable [NestedKeyedOutputT ]:
9296
9397 k1 , (k2 , data ) = element
9498 self ._underlying : AnomalyDetector = model_state .read ()
@@ -107,8 +111,8 @@ def process(
107111 model_state .write (self ._underlying )
108112
109113
110- class RunScoreAndLearn (beam .PTransform [beam .PCollection [KeyedInputT ],
111- beam .PCollection [KeyedOutputT ]]):
114+ class RunScoreAndLearn (beam .PTransform [beam .PCollection [NestedKeyedInputT ],
115+ beam .PCollection [NestedKeyedOutputT ]]):
112116 """Applies the _ScoreAndLearnDoFn to a PCollection of data.
113117
114118 This PTransform scores and learns from data points using an anomaly
@@ -121,8 +125,8 @@ def __init__(self, detector: AnomalyDetector):
121125 self ._detector = detector
122126
123127 def expand (
124- self ,
125- input : beam . PCollection [ KeyedInputT ] ) -> beam .PCollection [KeyedOutputT ]:
128+ self , input : beam . PCollection [ NestedKeyedInputT ]
129+ ) -> beam .PCollection [NestedKeyedOutputT ]:
126130 return input | beam .ParDo (_ScoreAndLearnDoFn (self ._detector .to_spec ()))
127131
128132
@@ -185,7 +189,8 @@ def __init__(self, threshold_fn_spec: Spec):
185189 assert not self ._threshold_fn .is_stateful , \
186190 "This DoFn can only take stateless function as threshold_fn"
187191
188- def process (self , element : KeyedOutputT , ** kwargs ) -> Iterable [KeyedOutputT ]:
192+ def process (self , element : NestedKeyedOutputT ,
193+ ** kwargs ) -> Iterable [NestedKeyedOutputT ]:
189194 """Processes a batch of anomaly results using a stateless ThresholdFn.
190195
191196 Args:
@@ -235,9 +240,9 @@ def __init__(self, threshold_fn_spec: Spec):
235240
236241 def process (
237242 self ,
238- element : KeyedOutputT ,
243+ element : NestedKeyedOutputT ,
239244 threshold_state = beam .DoFn .StateParam (THRESHOLD_STATE_INDEX ),
240- ** kwargs ) -> Iterable [KeyedOutputT ]:
245+ ** kwargs ) -> Iterable [NestedKeyedOutputT ]:
241246 """Processes a batch of anomaly results using a stateful ThresholdFn.
242247
243248 For each input element, this DoFn retrieves the stateful `ThresholdFn` from
@@ -273,8 +278,9 @@ def process(
273278 threshold_state .write (self ._threshold_fn )
274279
275280
276- class RunThresholdCriterion (beam .PTransform [beam .PCollection [KeyedOutputT ],
277- beam .PCollection [KeyedOutputT ]]):
281+ class RunThresholdCriterion (
282+ beam .PTransform [beam .PCollection [NestedKeyedOutputT ],
283+ beam .PCollection [NestedKeyedOutputT ]]):
278284 """Applies a threshold criterion to anomaly detection results.
279285
280286 This PTransform applies a `ThresholdFn` to the anomaly scores in
@@ -288,8 +294,8 @@ def __init__(self, threshold_criterion: ThresholdFn):
288294 self ._threshold_fn = threshold_criterion
289295
290296 def expand (
291- self ,
292- input : beam . PCollection [ KeyedOutputT ] ) -> beam .PCollection [KeyedOutputT ]:
297+ self , input : beam . PCollection [ NestedKeyedOutputT ]
298+ ) -> beam .PCollection [NestedKeyedOutputT ]:
293299
294300 if self ._threshold_fn .is_stateful :
295301 return (
@@ -301,8 +307,9 @@ def expand(
301307 | beam .ParDo (_StatelessThresholdDoFn (self ._threshold_fn .to_spec ())))
302308
303309
304- class RunAggregationStrategy (beam .PTransform [beam .PCollection [KeyedOutputT ],
305- beam .PCollection [KeyedOutputT ]]):
310+ class RunAggregationStrategy (
311+ beam .PTransform [beam .PCollection [NestedKeyedOutputT ],
312+ beam .PCollection [NestedKeyedOutputT ]]):
306313 """Applies an aggregation strategy to grouped anomaly detection results.
307314
308315 This PTransform aggregates anomaly predictions from multiple models or
@@ -319,8 +326,8 @@ def __init__(
319326 self ._agg_model_id = agg_model_id
320327
321328 def expand (
322- self ,
323- input : beam . PCollection [ KeyedOutputT ] ) -> beam .PCollection [KeyedOutputT ]:
329+ self , input : beam . PCollection [ NestedKeyedOutputT ]
330+ ) -> beam .PCollection [NestedKeyedOutputT ]:
324331 post_gbk = (
325332 input | beam .MapTuple (lambda k , v : ((k , v [0 ]), v [1 ]))
326333 | beam .GroupByKey ())
@@ -367,8 +374,8 @@ def expand(
367374 return ret
368375
369376
370- class RunOneDetector (beam .PTransform [beam .PCollection [KeyedInputT ],
371- beam .PCollection [KeyedOutputT ]]):
377+ class RunOneDetector (beam .PTransform [beam .PCollection [NestedKeyedInputT ],
378+ beam .PCollection [NestedKeyedOutputT ]]):
372379 """Runs a single anomaly detector on a PCollection of data.
373380
374381 This PTransform applies a single `AnomalyDetector` to the input data,
@@ -381,8 +388,8 @@ def __init__(self, detector):
381388 self ._detector = detector
382389
383390 def expand (
384- self ,
385- input : beam . PCollection [ KeyedInputT ] ) -> beam .PCollection [KeyedOutputT ]:
391+ self , input : beam . PCollection [ NestedKeyedInputT ]
392+ ) -> beam .PCollection [NestedKeyedOutputT ]:
386393 model_id = getattr (
387394 self ._detector ,
388395 "_model_id" ,
@@ -402,8 +409,8 @@ def expand(
402409 return ret
403410
404411
405- class RunOfflineDetector (beam .PTransform [beam .PCollection [KeyedInputT ],
406- beam .PCollection [KeyedOutputT ]]):
412+ class RunOfflineDetector (beam .PTransform [beam .PCollection [NestedKeyedInputT ],
413+ beam .PCollection [NestedKeyedOutputT ]]):
407414 """Runs a offline anomaly detector on a PCollection of data.
408415
409416 This PTransform applies a `OfflineDetector` to the input data, handling
@@ -416,7 +423,7 @@ def __init__(self, offline_detector: OfflineDetector):
416423 self ._offline_detector = offline_detector
417424
418425 def _restore_and_convert (
419- self , elem : tuple [tuple [Any , Any , beam .Row ], Any ]) -> KeyedOutputT :
426+ self , elem : tuple [tuple [Any , Any , beam .Row ], Any ]) -> NestedKeyedOutputT :
420427 """Converts the model output to AnomalyResult.
421428
422429 Args:
@@ -454,8 +461,8 @@ def _select_features(self, elem: tuple[Any,
454461 for k in self ._offline_detector ._features }))
455462
456463 def expand (
457- self ,
458- input : beam . PCollection [ KeyedInputT ] ) -> beam .PCollection [KeyedOutputT ]:
464+ self , input : beam . PCollection [ NestedKeyedInputT ]
465+ ) -> beam .PCollection [NestedKeyedOutputT ]:
459466 model_uuid = f"{ self ._offline_detector ._model_id } :{ uuid .uuid4 ().hex [:6 ]} "
460467
461468 # Call RunInference Transform with the keyed model handler
@@ -488,8 +495,9 @@ def expand(
488495 return ret
489496
490497
491- class RunEnsembleDetector (beam .PTransform [beam .PCollection [KeyedInputT ],
492- beam .PCollection [KeyedOutputT ]]):
498+ class RunEnsembleDetector (beam .PTransform [beam .PCollection [NestedKeyedInputT ],
499+ beam .PCollection [NestedKeyedOutputT ]]
500+ ):
493501 """Runs an ensemble of anomaly detectors on a PCollection of data.
494502
495503 This PTransform applies an `EnsembleAnomalyDetector` to the input data,
@@ -502,8 +510,8 @@ def __init__(self, ensemble_detector: EnsembleAnomalyDetector):
502510 self ._ensemble_detector = ensemble_detector
503511
504512 def expand (
505- self ,
506- input : beam . PCollection [ KeyedInputT ] ) -> beam .PCollection [KeyedOutputT ]:
513+ self , input : beam . PCollection [ NestedKeyedInputT ]
514+ ) -> beam .PCollection [NestedKeyedOutputT ]:
507515 model_uuid = f"{ self ._ensemble_detector ._model_id } :{ uuid .uuid4 ().hex [:6 ]} "
508516
509517 assert self ._ensemble_detector ._sub_detectors is not None
@@ -548,8 +556,10 @@ def expand(
548556 return ret
549557
550558
551- class AnomalyDetection (beam .PTransform [beam .PCollection [InputT ],
552- beam .PCollection [OutputT ]]):
559+ class AnomalyDetection (beam .PTransform [beam .PCollection [Union [InputT ,
560+ KeyedInputT ]],
561+ beam .PCollection [Union [OutputT ,
562+ KeyedOutputT ]]]):
553563 """Performs anomaly detection on a PCollection of data.
554564
555565 This PTransform applies an `AnomalyDetector` or `EnsembleAnomalyDetector` to
@@ -576,8 +586,8 @@ def __init__(
576586
577587 def expand (
578588 self ,
579- input : beam .PCollection [InputT ],
580- ) -> beam .PCollection [OutputT ]:
589+ input : beam .PCollection [Union [ InputT , KeyedInputT ] ],
590+ ) -> beam .PCollection [Union [ OutputT , KeyedOutputT ] ]:
581591
582592 # Add a temporary unique key per data point to facilitate grouping the
583593 # outputs from multiple anomaly detectors for the same data point.
@@ -600,20 +610,43 @@ def expand(
600610 #
601611 # We select uuid.uuid1() for its inclusion of node information, making it
602612 # more suitable for parallel execution environments.
603- add_temp_key_fn : Callable [[InputT ], KeyedInputT ] \
604- = lambda e : (e [0 ], (str (uuid .uuid1 ()), e [1 ]))
605- keyed_input = (input | "Add temp key" >> beam .Map (add_temp_key_fn ))
606613
614+ if isinstance (input .element_type , TupleConstraint ):
615+ keyed_input = input
616+ else :
617+ # Add a default key 0 if the input is unkeyed.
618+ keyed_input = input | beam .WithKeys (0 )
619+
620+ add_temp_key_fn : Callable [[KeyedInputT ], NestedKeyedInputT ]
621+ run_detector : beam .PTransform
607622 if isinstance (self ._root_detector , EnsembleAnomalyDetector ):
608- keyed_output = (keyed_input | RunEnsembleDetector (self ._root_detector ))
609- elif isinstance (self ._root_detector , OfflineDetector ):
610- keyed_output = (keyed_input | RunOfflineDetector (self ._root_detector ))
623+ add_temp_key_fn = lambda e : (e [0 ], (str (uuid .uuid1 ()), e [1 ]))
624+ run_detector = RunEnsembleDetector (self ._root_detector )
611625 else :
612- keyed_output = (keyed_input | RunOneDetector (self ._root_detector ))
626+ # If there is only one non-ensemble detector, temp key can be the same
627+ # because we don't need it to identify each input during result
628+ # aggregation.
629+ add_temp_key_fn = lambda e : (e [0 ], ("" , e [1 ]))
630+ if isinstance (self ._root_detector , OfflineDetector ):
631+ run_detector = RunOfflineDetector (self ._root_detector )
632+ else :
633+ run_detector = RunOneDetector (self ._root_detector )
634+
635+ nested_keyed_input = (
636+ keyed_input | "Add temp key" >> beam .Map (add_temp_key_fn ))
613637
614- # remove the temporary key and simplify the output.
615- remove_temp_key_fn : Callable [[KeyedOutputT ], OutputT ] \
638+ nested_keyed_output = nested_keyed_input | run_detector
639+
640+ # Remove the temporary key and simplify the output.
641+ remove_temp_key_fn : Callable [[NestedKeyedOutputT ], KeyedOutputT ] \
616642 = lambda e : (e [0 ], e [1 ][1 ])
617- ret = keyed_output | "Remove temp key" >> beam .Map (remove_temp_key_fn )
643+ keyed_output = nested_keyed_output | "Remove temp key" >> beam .Map (
644+ remove_temp_key_fn )
645+
646+ if isinstance (input .element_type , TupleConstraint ):
647+ ret = keyed_output
648+ else :
649+ # Remove the default key if the input is unkeyed.
650+ ret = keyed_output | beam .Values ()
618651
619652 return ret
0 commit comments