diff --git a/temporalio/contrib/opentelemetry.py b/temporalio/contrib/opentelemetry.py index 2e0f43201..84773fd43 100644 --- a/temporalio/contrib/opentelemetry.py +++ b/temporalio/contrib/opentelemetry.py @@ -74,12 +74,22 @@ class should return the workflow interceptor subclass from def __init__( self, tracer: Optional[opentelemetry.trace.Tracer] = None, + *, + always_create_workflow_spans: bool = False, ) -> None: """Initialize a OpenTelemetry tracing interceptor. Args: tracer: The tracer to use. Defaults to :py:func:`opentelemetry.trace.get_tracer`. + always_create_workflow_spans: When false, the default, spans are + only created in workflows when an overarching span from the + client is present. In cases of starting a workflow elsewhere, + e.g. CLI or schedules, a client-created span is not present and + workflow spans will not be created. Setting this to true will + create spans in workflows no matter what, but there is a risk of + them being orphans since they may not have a parent span after + replaying. """ self.tracer = tracer or opentelemetry.trace.get_tracer(__name__) # To customize any of this, users must subclass. We intentionally don't @@ -90,6 +100,7 @@ def __init__( self.text_map_propagator: opentelemetry.propagators.textmap.TextMapPropagator = default_text_map_propagator # TODO(cretz): Should I be using the configured one at the client and activity level? self.payload_converter = temporalio.converter.PayloadConverter.default + self._always_create_workflow_spans = always_create_workflow_spans def intercept_client( self, next: temporalio.client.OutboundInterceptor @@ -165,10 +176,15 @@ def _start_as_current_span( def _completed_workflow_span( self, params: _CompletedWorkflowSpanParams - ) -> _CarrierDict: + ) -> Optional[_CarrierDict]: # Carrier to context, start span, set span as current on context, # context back to carrier + # If the parent is missing and user hasn't said to always create, do not + # create + if params.parent_missing and not self._always_create_workflow_spans: + return None + # Extract the context context = self.text_map_propagator.extract(params.context) # Create link if there is a span present @@ -286,7 +302,7 @@ class _InputWithHeaders(Protocol): class _WorkflowExternFunctions(TypedDict): __temporal_opentelemetry_completed_span: Callable[ - [_CompletedWorkflowSpanParams], _CarrierDict + [_CompletedWorkflowSpanParams], Optional[_CarrierDict] ] @@ -299,6 +315,7 @@ class _CompletedWorkflowSpanParams: link_context: Optional[_CarrierDict] exception: Optional[Exception] kind: opentelemetry.trace.SpanKind + parent_missing: bool _interceptor_context_key = opentelemetry.context.create_key( @@ -529,17 +546,13 @@ def _completed_span( exception: Optional[Exception] = None, kind: opentelemetry.trace.SpanKind = opentelemetry.trace.SpanKind.INTERNAL, ) -> None: - # If there is no span on the context, we do not create a span - if opentelemetry.trace.get_current_span() is opentelemetry.trace.INVALID_SPAN: - return None - # If we are replaying and they don't want a span on replay, no span if temporalio.workflow.unsafe.is_replaying() and not new_span_even_on_replay: return None # Create the span. First serialize current context to carrier. - context_carrier: _CarrierDict = {} - self.text_map_propagator.inject(context_carrier) + new_context_carrier: _CarrierDict = {} + self.text_map_propagator.inject(new_context_carrier) # Invoke info = temporalio.workflow.info() attributes: Dict[str, opentelemetry.util.types.AttributeValue] = { @@ -548,11 +561,11 @@ def _completed_span( } if additional_attributes: attributes.update(additional_attributes) - context_carrier = self._extern_functions[ + updated_context_carrier = self._extern_functions[ "__temporal_opentelemetry_completed_span" ]( _CompletedWorkflowSpanParams( - context=context_carrier, + context=new_context_carrier, name=span_name, # Always set span attributes as workflow ID and run ID attributes=attributes, @@ -560,13 +573,15 @@ def _completed_span( link_context=link_context_carrier, exception=exception, kind=kind, + parent_missing=opentelemetry.trace.get_current_span() + is opentelemetry.trace.INVALID_SPAN, ) ) # Add to outbound if needed - if add_to_outbound: + if add_to_outbound and updated_context_carrier: add_to_outbound.headers = self._context_carrier_to_headers( - context_carrier, add_to_outbound.headers + updated_context_carrier, add_to_outbound.headers ) def _set_on_context( diff --git a/tests/contrib/test_opentelemetry.py b/tests/contrib/test_opentelemetry.py index e9969aa89..5ecf49126 100644 --- a/tests/contrib/test_opentelemetry.py +++ b/tests/contrib/test_opentelemetry.py @@ -332,6 +332,56 @@ def dump_spans( return ret +@workflow.defn +class SimpleWorkflow: + @workflow.run + async def run(self) -> str: + return "done" + + +async def test_opentelemetry_always_create_workflow_spans(client: Client): + # Create a tracer that has an in-memory exporter + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + tracer = get_tracer(__name__, tracer_provider=provider) + + # Create a worker with an interceptor without always create + async with Worker( + client, + task_queue=f"task_queue_{uuid.uuid4()}", + workflows=[SimpleWorkflow], + interceptors=[TracingInterceptor(tracer)], + ) as worker: + assert "done" == await client.execute_workflow( + SimpleWorkflow.run, + id=f"workflow_{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + # Confirm the spans are not there + spans = exporter.get_finished_spans() + logging.debug("Spans:\n%s", "\n".join(dump_spans(spans, with_attributes=False))) + assert len(spans) == 0 + + # Now create a worker with an interceptor with always create + async with Worker( + client, + task_queue=f"task_queue_{uuid.uuid4()}", + workflows=[SimpleWorkflow], + interceptors=[TracingInterceptor(tracer, always_create_workflow_spans=True)], + ) as worker: + assert "done" == await client.execute_workflow( + SimpleWorkflow.run, + id=f"workflow_{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + # Confirm the spans are not there + spans = exporter.get_finished_spans() + logging.debug("Spans:\n%s", "\n".join(dump_spans(spans, with_attributes=False))) + assert len(spans) > 0 + assert spans[0].name == "RunWorkflow:SimpleWorkflow" + + # TODO(cretz): Additional tests to write # * query without interceptor (no headers) # * workflow without interceptor (no headers) but query with interceptor (headers)