Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 28df7cb

Browse files
authored
[AnomalyDetection] Support unkeyed input for the transform. (#34897)
* Support keyed and unkeyed input for AnomalyDetection.
  - Restructure test code and prepare for adding tests for unkeyed input.
  - Add tests for unkeyed input.
  - Optimize the case when running one non-ensemble detector.
* Correct TempKeyT's bound.
* Minor changes to comments.
* Fix lints.
* Fix lints.
* Change the default key from None to 0 when adding it to unkeyed input.
1 parent 23dc447 commit 28df7cb

2 files changed

Lines changed: 241 additions & 192 deletions

File tree

sdks/python/apache_beam/ml/anomaly/transforms.py

Lines changed: 82 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
from typing import Any
2323
from typing import Optional
2424
from typing import TypeVar
25+
from typing import Union
2526

2627
import apache_beam as beam
2728
from apache_beam.coders import DillCoder
@@ -37,13 +38,16 @@
3738
from apache_beam.ml.anomaly.specifiable import Specifiable
3839
from apache_beam.ml.inference.base import RunInference
3940
from apache_beam.transforms.userstate import ReadModifyWriteStateSpec
41+
from apache_beam.typehints.typehints import TupleConstraint
4042

4143
KeyT = TypeVar('KeyT')
42-
TempKeyT = TypeVar('TempKeyT', bound=int)
43-
InputT = tuple[KeyT, beam.Row]
44-
KeyedInputT = tuple[KeyT, tuple[TempKeyT, beam.Row]]
45-
KeyedOutputT = tuple[KeyT, tuple[TempKeyT, AnomalyResult]]
46-
OutputT = tuple[KeyT, AnomalyResult]
44+
TempKeyT = TypeVar('TempKeyT', bound=str)
45+
InputT = beam.Row
46+
OutputT = AnomalyResult
47+
KeyedInputT = tuple[KeyT, beam.Row]
48+
KeyedOutputT = tuple[KeyT, AnomalyResult]
49+
NestedKeyedInputT = tuple[KeyT, tuple[TempKeyT, beam.Row]]
50+
NestedKeyedOutputT = tuple[KeyT, tuple[TempKeyT, AnomalyResult]]
4751

4852

4953
class _ScoreAndLearnDoFn(beam.DoFn):
@@ -86,9 +90,9 @@ def score_and_learn(self, data):
8690

8791
def process(
8892
self,
89-
element: KeyedInputT,
93+
element: NestedKeyedInputT,
9094
model_state=beam.DoFn.StateParam(MODEL_STATE_INDEX),
91-
**kwargs) -> Iterable[KeyedOutputT]:
95+
**kwargs) -> Iterable[NestedKeyedOutputT]:
9296

9397
k1, (k2, data) = element
9498
self._underlying: AnomalyDetector = model_state.read()
@@ -107,8 +111,8 @@ def process(
107111
model_state.write(self._underlying)
108112

109113

110-
class RunScoreAndLearn(beam.PTransform[beam.PCollection[KeyedInputT],
111-
beam.PCollection[KeyedOutputT]]):
114+
class RunScoreAndLearn(beam.PTransform[beam.PCollection[NestedKeyedInputT],
115+
beam.PCollection[NestedKeyedOutputT]]):
112116
"""Applies the _ScoreAndLearnDoFn to a PCollection of data.
113117
114118
This PTransform scores and learns from data points using an anomaly
@@ -121,8 +125,8 @@ def __init__(self, detector: AnomalyDetector):
121125
self._detector = detector
122126

123127
def expand(
124-
self,
125-
input: beam.PCollection[KeyedInputT]) -> beam.PCollection[KeyedOutputT]:
128+
self, input: beam.PCollection[NestedKeyedInputT]
129+
) -> beam.PCollection[NestedKeyedOutputT]:
126130
return input | beam.ParDo(_ScoreAndLearnDoFn(self._detector.to_spec()))
127131

128132

@@ -185,7 +189,8 @@ def __init__(self, threshold_fn_spec: Spec):
185189
assert not self._threshold_fn.is_stateful, \
186190
"This DoFn can only take stateless function as threshold_fn"
187191

188-
def process(self, element: KeyedOutputT, **kwargs) -> Iterable[KeyedOutputT]:
192+
def process(self, element: NestedKeyedOutputT,
193+
**kwargs) -> Iterable[NestedKeyedOutputT]:
189194
"""Processes a batch of anomaly results using a stateless ThresholdFn.
190195
191196
Args:
@@ -235,9 +240,9 @@ def __init__(self, threshold_fn_spec: Spec):
235240

236241
def process(
237242
self,
238-
element: KeyedOutputT,
243+
element: NestedKeyedOutputT,
239244
threshold_state=beam.DoFn.StateParam(THRESHOLD_STATE_INDEX),
240-
**kwargs) -> Iterable[KeyedOutputT]:
245+
**kwargs) -> Iterable[NestedKeyedOutputT]:
241246
"""Processes a batch of anomaly results using a stateful ThresholdFn.
242247
243248
For each input element, this DoFn retrieves the stateful `ThresholdFn` from
@@ -273,8 +278,9 @@ def process(
273278
threshold_state.write(self._threshold_fn)
274279

275280

276-
class RunThresholdCriterion(beam.PTransform[beam.PCollection[KeyedOutputT],
277-
beam.PCollection[KeyedOutputT]]):
281+
class RunThresholdCriterion(
282+
beam.PTransform[beam.PCollection[NestedKeyedOutputT],
283+
beam.PCollection[NestedKeyedOutputT]]):
278284
"""Applies a threshold criterion to anomaly detection results.
279285
280286
This PTransform applies a `ThresholdFn` to the anomaly scores in
@@ -288,8 +294,8 @@ def __init__(self, threshold_criterion: ThresholdFn):
288294
self._threshold_fn = threshold_criterion
289295

290296
def expand(
291-
self,
292-
input: beam.PCollection[KeyedOutputT]) -> beam.PCollection[KeyedOutputT]:
297+
self, input: beam.PCollection[NestedKeyedOutputT]
298+
) -> beam.PCollection[NestedKeyedOutputT]:
293299

294300
if self._threshold_fn.is_stateful:
295301
return (
@@ -301,8 +307,9 @@ def expand(
301307
| beam.ParDo(_StatelessThresholdDoFn(self._threshold_fn.to_spec())))
302308

303309

304-
class RunAggregationStrategy(beam.PTransform[beam.PCollection[KeyedOutputT],
305-
beam.PCollection[KeyedOutputT]]):
310+
class RunAggregationStrategy(
311+
beam.PTransform[beam.PCollection[NestedKeyedOutputT],
312+
beam.PCollection[NestedKeyedOutputT]]):
306313
"""Applies an aggregation strategy to grouped anomaly detection results.
307314
308315
This PTransform aggregates anomaly predictions from multiple models or
@@ -319,8 +326,8 @@ def __init__(
319326
self._agg_model_id = agg_model_id
320327

321328
def expand(
322-
self,
323-
input: beam.PCollection[KeyedOutputT]) -> beam.PCollection[KeyedOutputT]:
329+
self, input: beam.PCollection[NestedKeyedOutputT]
330+
) -> beam.PCollection[NestedKeyedOutputT]:
324331
post_gbk = (
325332
input | beam.MapTuple(lambda k, v: ((k, v[0]), v[1]))
326333
| beam.GroupByKey())
@@ -367,8 +374,8 @@ def expand(
367374
return ret
368375

369376

370-
class RunOneDetector(beam.PTransform[beam.PCollection[KeyedInputT],
371-
beam.PCollection[KeyedOutputT]]):
377+
class RunOneDetector(beam.PTransform[beam.PCollection[NestedKeyedInputT],
378+
beam.PCollection[NestedKeyedOutputT]]):
372379
"""Runs a single anomaly detector on a PCollection of data.
373380
374381
This PTransform applies a single `AnomalyDetector` to the input data,
@@ -381,8 +388,8 @@ def __init__(self, detector):
381388
self._detector = detector
382389

383390
def expand(
384-
self,
385-
input: beam.PCollection[KeyedInputT]) -> beam.PCollection[KeyedOutputT]:
391+
self, input: beam.PCollection[NestedKeyedInputT]
392+
) -> beam.PCollection[NestedKeyedOutputT]:
386393
model_id = getattr(
387394
self._detector,
388395
"_model_id",
@@ -402,8 +409,8 @@ def expand(
402409
return ret
403410

404411

405-
class RunOfflineDetector(beam.PTransform[beam.PCollection[KeyedInputT],
406-
beam.PCollection[KeyedOutputT]]):
412+
class RunOfflineDetector(beam.PTransform[beam.PCollection[NestedKeyedInputT],
413+
beam.PCollection[NestedKeyedOutputT]]):
407414
"""Runs a offline anomaly detector on a PCollection of data.
408415
409416
This PTransform applies a `OfflineDetector` to the input data, handling
@@ -416,7 +423,7 @@ def __init__(self, offline_detector: OfflineDetector):
416423
self._offline_detector = offline_detector
417424

418425
def _restore_and_convert(
419-
self, elem: tuple[tuple[Any, Any, beam.Row], Any]) -> KeyedOutputT:
426+
self, elem: tuple[tuple[Any, Any, beam.Row], Any]) -> NestedKeyedOutputT:
420427
"""Converts the model output to AnomalyResult.
421428
422429
Args:
@@ -454,8 +461,8 @@ def _select_features(self, elem: tuple[Any,
454461
for k in self._offline_detector._features}))
455462

456463
def expand(
457-
self,
458-
input: beam.PCollection[KeyedInputT]) -> beam.PCollection[KeyedOutputT]:
464+
self, input: beam.PCollection[NestedKeyedInputT]
465+
) -> beam.PCollection[NestedKeyedOutputT]:
459466
model_uuid = f"{self._offline_detector._model_id}:{uuid.uuid4().hex[:6]}"
460467

461468
# Call RunInference Transform with the keyed model handler
@@ -488,8 +495,9 @@ def expand(
488495
return ret
489496

490497

491-
class RunEnsembleDetector(beam.PTransform[beam.PCollection[KeyedInputT],
492-
beam.PCollection[KeyedOutputT]]):
498+
class RunEnsembleDetector(beam.PTransform[beam.PCollection[NestedKeyedInputT],
499+
beam.PCollection[NestedKeyedOutputT]]
500+
):
493501
"""Runs an ensemble of anomaly detectors on a PCollection of data.
494502
495503
This PTransform applies an `EnsembleAnomalyDetector` to the input data,
@@ -502,8 +510,8 @@ def __init__(self, ensemble_detector: EnsembleAnomalyDetector):
502510
self._ensemble_detector = ensemble_detector
503511

504512
def expand(
505-
self,
506-
input: beam.PCollection[KeyedInputT]) -> beam.PCollection[KeyedOutputT]:
513+
self, input: beam.PCollection[NestedKeyedInputT]
514+
) -> beam.PCollection[NestedKeyedOutputT]:
507515
model_uuid = f"{self._ensemble_detector._model_id}:{uuid.uuid4().hex[:6]}"
508516

509517
assert self._ensemble_detector._sub_detectors is not None
@@ -548,8 +556,10 @@ def expand(
548556
return ret
549557

550558

551-
class AnomalyDetection(beam.PTransform[beam.PCollection[InputT],
552-
beam.PCollection[OutputT]]):
559+
class AnomalyDetection(beam.PTransform[beam.PCollection[Union[InputT,
560+
KeyedInputT]],
561+
beam.PCollection[Union[OutputT,
562+
KeyedOutputT]]]):
553563
"""Performs anomaly detection on a PCollection of data.
554564
555565
This PTransform applies an `AnomalyDetector` or `EnsembleAnomalyDetector` to
@@ -576,8 +586,8 @@ def __init__(
576586

577587
def expand(
578588
self,
579-
input: beam.PCollection[InputT],
580-
) -> beam.PCollection[OutputT]:
589+
input: beam.PCollection[Union[InputT, KeyedInputT]],
590+
) -> beam.PCollection[Union[OutputT, KeyedOutputT]]:
581591

582592
# Add a temporary unique key per data point to facilitate grouping the
583593
# outputs from multiple anomaly detectors for the same data point.
@@ -600,20 +610,43 @@ def expand(
600610
#
601611
# We select uuid.uuid1() for its inclusion of node information, making it
602612
# more suitable for parallel execution environments.
603-
add_temp_key_fn: Callable[[InputT], KeyedInputT] \
604-
= lambda e: (e[0], (str(uuid.uuid1()), e[1]))
605-
keyed_input = (input | "Add temp key" >> beam.Map(add_temp_key_fn))
606613

614+
if isinstance(input.element_type, TupleConstraint):
615+
keyed_input = input
616+
else:
617+
# Add a default key 0 if the input is unkeyed.
618+
keyed_input = input | beam.WithKeys(0)
619+
620+
add_temp_key_fn: Callable[[KeyedInputT], NestedKeyedInputT]
621+
run_detector: beam.PTransform
607622
if isinstance(self._root_detector, EnsembleAnomalyDetector):
608-
keyed_output = (keyed_input | RunEnsembleDetector(self._root_detector))
609-
elif isinstance(self._root_detector, OfflineDetector):
610-
keyed_output = (keyed_input | RunOfflineDetector(self._root_detector))
623+
add_temp_key_fn = lambda e: (e[0], (str(uuid.uuid1()), e[1]))
624+
run_detector = RunEnsembleDetector(self._root_detector)
611625
else:
612-
keyed_output = (keyed_input | RunOneDetector(self._root_detector))
626+
# If there is only one non-ensemble detector, temp key can be the same
627+
# because we don't need it to identify each input during result
628+
# aggregation.
629+
add_temp_key_fn = lambda e: (e[0], ("", e[1]))
630+
if isinstance(self._root_detector, OfflineDetector):
631+
run_detector = RunOfflineDetector(self._root_detector)
632+
else:
633+
run_detector = RunOneDetector(self._root_detector)
634+
635+
nested_keyed_input = (
636+
keyed_input | "Add temp key" >> beam.Map(add_temp_key_fn))
613637

614-
# remove the temporary key and simplify the output.
615-
remove_temp_key_fn: Callable[[KeyedOutputT], OutputT] \
638+
nested_keyed_output = nested_keyed_input | run_detector
639+
640+
# Remove the temporary key and simplify the output.
641+
remove_temp_key_fn: Callable[[NestedKeyedOutputT], KeyedOutputT] \
616642
= lambda e: (e[0], e[1][1])
617-
ret = keyed_output | "Remove temp key" >> beam.Map(remove_temp_key_fn)
643+
keyed_output = nested_keyed_output | "Remove temp key" >> beam.Map(
644+
remove_temp_key_fn)
645+
646+
if isinstance(input.element_type, TupleConstraint):
647+
ret = keyed_output
648+
else:
649+
# Remove the default key if the input is unkeyed.
650+
ret = keyed_output | beam.Values()
618651

619652
return ret

0 commit comments

Comments
 (0)