diff --git a/aws_embedded_metrics/metric_scope/__init__.py b/aws_embedded_metrics/metric_scope/__init__.py index 47044bc..b185f38 100644 --- a/aws_embedded_metrics/metric_scope/__init__.py +++ b/aws_embedded_metrics/metric_scope/__init__.py @@ -18,35 +18,70 @@ def metric_scope(fn): # type: ignore + if inspect.isasyncgenfunction(fn): + @wraps(fn) + async def async_gen_wrapper(*args, **kwargs): # type: ignore + logger = create_metrics_logger() + if "metrics" in inspect.signature(fn).parameters: + kwargs["metrics"] = logger + + try: + fn_gen = fn(*args, **kwargs) + while True: + result = await fn_gen.__anext__() + await logger.flush() + yield result + except Exception as ex: + await logger.flush() + if not isinstance(ex, StopIteration): + raise + + return async_gen_wrapper + + elif inspect.isgeneratorfunction(fn): + @wraps(fn) + def gen_wrapper(*args, **kwargs): # type: ignore + logger = create_metrics_logger() + if "metrics" in inspect.signature(fn).parameters: + kwargs["metrics"] = logger + + try: + fn_gen = fn(*args, **kwargs) + while True: + result = next(fn_gen) + asyncio.run(logger.flush()) + yield result + except Exception as ex: + asyncio.run(logger.flush()) + if not isinstance(ex, StopIteration): + raise - if asyncio.iscoroutinefunction(fn): + return gen_wrapper + elif asyncio.iscoroutinefunction(fn): @wraps(fn) - async def wrapper(*args, **kwargs): # type: ignore + async def async_wrapper(*args, **kwargs): # type: ignore logger = create_metrics_logger() if "metrics" in inspect.signature(fn).parameters: kwargs["metrics"] = logger + try: return await fn(*args, **kwargs) - except Exception as e: - raise e finally: await logger.flush() - return wrapper - else: + return async_wrapper + else: @wraps(fn) def wrapper(*args, **kwargs): # type: ignore logger = create_metrics_logger() if "metrics" in inspect.signature(fn).parameters: kwargs["metrics"] = logger + try: return fn(*args, **kwargs) - except Exception as e: - raise e finally: - loop = asyncio.get_event_loop() - loop.run_until_complete(logger.flush()) + asyncio.run(logger.flush()) return wrapper diff --git a/examples/README.md b/examples/README.md index 493bf86..4527755 100644 --- a/examples/README.md +++ b/examples/README.md @@ -8,7 +8,9 @@ With Docker images, using the `awslogs` log driver will send your container logs ## ECS and Fargate -With ECS and Fargate, you can use the `awslogs` log driver to have your logs sent to CloudWatch Logs on your behalf. After configuring your task to use the `awslogs` log driver, you may write your EMF logs to STDOUT and they will be processed. +With ECS and Fargate, you can use the `awsfirelens` (recommended) or `awslogs` log driver to have your logs sent to CloudWatch Logs on your behalf. After configuring the options for your preferred log driver, you may write your EMF logs to STDOUT and they will be processed. + +[`awsfirelens` documentation](https://github.com/aws/amazon-cloudwatch-logs-for-fluent-bit) [ECS documentation on `awslogs` log driver](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/using_awslogs.html) diff --git a/setup.py b/setup.py index 27c3417..7e8cc13 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="aws-embedded-metrics", - version="3.2.0", + version="3.3.0", author="Amazon Web Services", author_email="jarnance@amazon.com", description="AWS Embedded Metrics Package", diff --git a/tests/metric_scope/test_metric_scope.py b/tests/metric_scope/test_metric_scope.py index 20bf131..9ebd1f1 100644 --- a/tests/metric_scope/test_metric_scope.py +++ b/tests/metric_scope/test_metric_scope.py @@ -168,6 +168,43 @@ def my_handler(metrics): actual_timestamp_second = int(round(logger.context.meta["Timestamp"] / 1000)) assert expected_timestamp_second == actual_timestamp_second + +def test_sync_scope_iterates_generator(mock_logger): + expected_results = [1, 2] + + @metric_scope + def my_handler(): + yield from expected_results + raise Exception("test exception") + + actual_results = [] + with pytest.raises(Exception, match="test exception"): + for result in my_handler(): + actual_results.append(result) + + assert actual_results == expected_results + assert InvocationTracker.invocations == 3 + + +@pytest.mark.asyncio +async def test_async_scope_iterates_async_generator(mock_logger): + expected_results = [1, 2] + + @metric_scope + async def my_handler(): + for item in expected_results: + yield item + await asyncio.sleep(1) + raise Exception("test exception") + + actual_results = [] + with pytest.raises(Exception, match="test exception"): + async for result in my_handler(): + actual_results.append(result) + + assert actual_results == expected_results + assert InvocationTracker.invocations == 3 + # Test helpers