Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 74512d9

Browse files
authored
[YAML] Add a spec provider for transforms taking specifiable arguments (#35187)
* Add a test provider for specifiable transforms and try it on AnomalyDetection. Also add support for callables in specs. * Minor renaming. * Fix lints.
1 parent f7fc608 commit 74512d9

3 files changed

Lines changed: 176 additions & 0 deletions

File tree

sdks/python/apache_beam/yaml/yaml_provider.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1475,6 +1475,7 @@ def standard_providers():
14751475
from apache_beam.yaml.yaml_mapping import create_mapping_providers
14761476
from apache_beam.yaml.yaml_join import create_join_providers
14771477
from apache_beam.yaml.yaml_io import io_providers
1478+
from apache_beam.yaml.yaml_specifiable import create_spec_providers
14781479

14791480
return merge_providers(
14801481
YamlProviders.create_builtin_provider(),
@@ -1483,6 +1484,7 @@ def standard_providers():
14831484
create_combine_providers(),
14841485
create_join_providers(),
14851486
io_providers(),
1487+
create_spec_providers(),
14861488
load_providers(yaml_utils.locate_data_file('standard_providers.yaml')))
14871489

14881490

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from apache_beam.io.filesystems import FileSystems
19+
from apache_beam.ml.anomaly.specifiable import Spec
20+
from apache_beam.ml.anomaly.transforms import AnomalyDetection
21+
from apache_beam.ml.anomaly.transforms import Specifiable
22+
from apache_beam.utils import python_callable
23+
from apache_beam.yaml.yaml_provider import InlineProvider
24+
25+
26+
def maybe_make_specifiable(v):
  """Recursively converts YAML-style dicts into the objects they describe.

  A dict is interpreted, in this order, as:

  * a spec, when it has both ``type`` and ``config`` keys: the config is
    converted recursively and the resulting ``Spec`` is passed to
    ``Specifiable.from_spec``;
  * an inline callable, when it has a ``callable`` key: its value is wrapped
    in ``PythonCallableWithSource``;
  * a callable loaded from a script, when it has both ``path`` and ``name``
    keys: the file at ``path`` is read and ``name`` is loaded from its
    contents via ``PythonCallableWithSource.load_from_script``;
  * otherwise, a plain mapping whose values are converted recursively.

  Anything that is not a dict (lists included) is returned unchanged.

  Args:
    v: The value to convert.

  Returns:
    The object built from ``v``, a new dict with converted values, or ``v``
    itself when it is not a dict.

  Raises:
    ValueError: If ``callable`` is combined with ``path`` or ``name``.
  """
  if not isinstance(v, dict):
    return v

  if "type" in v and "config" in v:
    return Specifiable.from_spec(
        Spec(type=v["type"], config=maybe_make_specifiable(v["config"])))

  if "callable" in v:
    # An inline callable and a script-based one are mutually exclusive.
    if "path" in v or "name" in v:
      raise ValueError(
          "Cannot specify 'callable' with 'path' and 'name' for function.")
    return python_callable.PythonCallableWithSource(v["callable"])

  if "path" in v and "name" in v:
    # Close the handle deterministically instead of leaving it to the
    # garbage collector.
    with FileSystems.open(v["path"]) as script_file:
      script_source = script_file.read().decode()
    return python_callable.PythonCallableWithSource.load_from_script(
        script_source, v["name"])

  return {key: maybe_make_specifiable(value) for key, value in v.items()}
47+
48+
49+
class SpecProvider(InlineProvider):
  """Inline provider that converts its arguments before building a transform.

  Every argument value is passed through ``maybe_make_specifiable`` before
  the registered transform factory is invoked.
  """
  def create_transform(self, type, args, yaml_create_transform):
    """Builds the transform registered under ``type``.

    Args:
      type: Key of the factory to look up in ``self._transform_factories``.
      args: Mapping of argument names to raw values.
      yaml_create_transform: Not used by this implementation.

    Returns:
      Whatever the factory returns when called with the converted arguments
      as keyword arguments.
    """
    factory = self._transform_factories[type]
    converted_args = {}
    for arg_name, raw_value in args.items():
      converted_args[arg_name] = maybe_make_specifiable(raw_value)
    return factory(**converted_args)
56+
57+
58+
def create_spec_providers():
  """Returns a ``SpecProvider`` that serves the ``AnomalyDetection`` transform."""
  spec_transform_factories = {"AnomalyDetection": AnomalyDetection}
  return SpecProvider(spec_transform_factories)
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import logging
19+
import unittest
20+
from typing import Callable
21+
22+
import apache_beam as beam
23+
from apache_beam.ml.anomaly.base import AnomalyDetector
24+
from apache_beam.ml.anomaly.specifiable import specifiable
25+
from apache_beam.testing.util import assert_that
26+
from apache_beam.testing.util import equal_to
27+
from apache_beam.utils import python_callable
28+
from apache_beam.yaml.yaml_transform import YamlTransform
29+
30+
# Extra transform factories passed to YamlTransform in the tests below.
# PyMap wraps a Python source string into a callable and maps it over the
# input.
TEST_PROVIDERS = dict(
    PyMap=lambda fn: beam.Map(python_callable.PythonCallableWithSource(fn)),
)
33+
34+
35+
@specifiable
class FakeDetector(AnomalyDetector):  # pylint: disable=unused-variable
  """Test detector whose score is a caller-supplied function of the input.

  ``learn_one`` does nothing; ``score_one`` applies the configured callable
  to the first item produced by iterating over the row.
  """
  def __init__(self, fn: Callable):
    super().__init__()
    self._fn = fn

  def learn_one(self, x: beam.Row) -> None:
    """Does nothing with the sample."""

  def score_one(self, x: beam.Row) -> float:
    """Returns ``fn`` applied to the first item obtained by iterating ``x``."""
    return self._fn(next(iter(x)))
47+
48+
49+
class YamlSpecifiableTransformTest(unittest.TestCase):
  """Runs AnomalyDetection from YAML with spec-style detector arguments.

  Each test feeds five keyed rows through a YAML ``chain`` and uses the
  ``PyMap`` test provider to extract one field of the first prediction.
  """
  def test_specifiable_transform(self):
    """Nested ``type``/``config`` specs configure a ZScore detector."""
    TRAIN_DATA = [
        (0, beam.Row(x=1)),
        (0, beam.Row(x=2)),
        (0, beam.Row(x=2)),
        (0, beam.Row(x=4)),
        (0, beam.Row(x=9)),
    ]
    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
        pickle_library='cloudpickle')) as p:
      result = p | beam.Create(TRAIN_DATA) | YamlTransform(
          '''
          type: chain
          transforms:
            - type: AnomalyDetection
              config:
                detector:
                  type: 'ZScore'
                  config:
                    sub_stat_tracker:
                      type: 'IncSlidingMeanTracker'
                      config:
                        window_size: 5
                    stdev_tracker:
                      type: 'IncSlidingStdevTracker'
                      config:
                        window_size: 5
            - type: PyMap
              config:
                fn: "lambda x: (x[1].predictions[0].label)"
          ''',
          providers=TEST_PROVIDERS)
      assert_that(result, equal_to([-2, -2, 0, 1, 1]))

  def test_specifiable_transform_with_callable(self):
    """A ``callable`` entry in the spec is passed to FakeDetector as ``fn``."""
    TRAIN_DATA = [
        (0, beam.Row(x=1)),
        (0, beam.Row(x=2)),
        (0, beam.Row(x=2)),
        (0, beam.Row(x=4)),
        (0, beam.Row(x=9)),
    ]
    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
        pickle_library='cloudpickle')) as p:
      result = p | beam.Create(TRAIN_DATA) | YamlTransform(
          '''
          type: chain
          transforms:
            - type: AnomalyDetection
              config:
                detector:
                  type: 'FakeDetector'
                  config:
                    fn:
                      callable: "lambda x: x * 10.0"
            - type: PyMap
              config:
                fn: "lambda x: (x[1].predictions[0].score)"
          ''',
          providers=TEST_PROVIDERS)
      assert_that(result, equal_to([10.0, 20.0, 20.0, 40.0, 90.0]))
111+
112+
113+
if __name__ == '__main__':
  # Raise the root logger to INFO before handing control to unittest.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()

0 commit comments

Comments
 (0)