Skip to content
1 change: 1 addition & 0 deletions CHANGELOG.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ endif::[]
* Change default for `sanitize_field_names` to sanitize `*auth*` instead of `authorization` {pull}1494[#1494]
* Add `span_stack_trace_min_duration` to replace deprecated `span_frames_min_duration` {pull}1498[#1498]
* Enable exact_match span compression by default {pull}1504[#1504]
* Allow parent celery tasks to specify the downstream `parent_span_id` in celery headers {pull}1500[#1500]

[float]
===== Bug fixes
Expand Down
28 changes: 17 additions & 11 deletions elasticapm/contrib/celery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,24 +68,30 @@ def set_celery_headers(headers=None, **kwargs):
transaction = execution_context.get_transaction()
if transaction is not None:
trace_parent = transaction.trace_parent
trace_parent_string = trace_parent.to_string()

headers.update({"elasticapm": {"trace_parent_string": trace_parent_string}})
# Customize parent span id (if provided)
apm_headers = headers.get("elasticapm", dict())
if "parent_span_id" in apm_headers:
trace_parent = trace_parent.copy_from()
trace_parent.span_id = apm_headers["parent_span_id"]

apm_headers["trace_parent_string"] = trace_parent.to_string()
headers.update(elasticapm=apm_headers)


def get_trace_parent(celery_task):
"""
Return a trace parent contained in the request headers of a Celery Task object or None
"""
trace_parent = None
with suppress(AttributeError, KeyError, TypeError):
if celery_task.request.headers is not None:
trace_parent_string = celery_task.request.headers["elasticapm"]["trace_parent_string"]
trace_parent = TraceParent.from_string(trace_parent_string)
else:
trace_parent_string = celery_task.request.elasticapm["trace_parent_string"]
trace_parent = TraceParent.from_string(trace_parent_string)
return trace_parent
read_from_inner_headers = lambda: celery_task.request.headers["elasticapm"]["trace_parent_string"]
read_from_request = lambda: celery_task.request.elasticapm["trace_parent_string"]

for read_fun in (read_from_request, read_from_inner_headers):
with suppress(AttributeError, KeyError, TypeError):
trace_parent_string = read_fun()
return TraceParent.from_string(trace_parent_string)

return None


def register_instrumentation(client):
Expand Down