Skip to content
Draft
227 changes: 123 additions & 104 deletions datadog_lambda/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,6 @@ def extract_context_from_sqs_or_sns_event_or_context(
Falls back to lambda context if no trace data is found in the SQS message attributes.
Set a DSM checkpoint if DSM is enabled and the method for context propagation is supported.
"""
source_arn = ""
event_type = "sqs" if event_source.equals(EventTypes.SQS) else "sns"

# EventBridge => SQS
Expand All @@ -248,91 +247,105 @@ def extract_context_from_sqs_or_sns_event_or_context(
except Exception:
logger.debug("Failed extracting context as EventBridge to SQS.")

try:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the context is extracted from EventBridge, we don't set a checkpoint. Is that expected?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tracers never inject DSM context in the case of EventBridge or Step Functions. I'm not sure this is the PR to be adding that functionality for these event types.

first_record = event.get("Records")[0]
source_arn = first_record.get("eventSourceARN", "")

# logic to deal with SNS => SQS event
if "body" in first_record:
body_str = first_record.get("body")
try:
body = json.loads(body_str)
if body.get("Type", "") == "Notification" and "TopicArn" in body:
logger.debug("Found SNS message inside SQS event")
first_record = get_first_record(create_sns_event(body))
except Exception:
pass

msg_attributes = first_record.get("messageAttributes")
if msg_attributes is None:
sns_record = first_record.get("Sns") or {}
# SNS->SQS event would extract SNS arn without this check
if event_source.equals(EventTypes.SNS):
source_arn = sns_record.get("TopicArn", "")
msg_attributes = sns_record.get("MessageAttributes") or {}
dd_payload = msg_attributes.get("_datadog")
if dd_payload:
# SQS uses dataType and binaryValue/stringValue
# SNS uses Type and Value
dd_json_data = None
dd_json_data_type = dd_payload.get("Type") or dd_payload.get("dataType")
if dd_json_data_type == "Binary":
import base64

dd_json_data = dd_payload.get("binaryValue") or dd_payload.get("Value")
if dd_json_data:
dd_json_data = base64.b64decode(dd_json_data)
elif dd_json_data_type == "String":
dd_json_data = dd_payload.get("stringValue") or dd_payload.get("Value")
else:
logger.debug(
"Datadog Lambda Python only supports extracting trace"
"context from String or Binary SQS/SNS message attributes"
)

if dd_json_data:
dd_data = json.loads(dd_json_data)

if is_step_function_event(dd_data):
apm_context: Context = None
for record in event.get("Records", []):
source_arn = (
record.get("eventSourceARN")
if event_type == "sqs"
else record.get("Sns", {}).get("TopicArn")
)
dd_ctx = None
try:
dd_ctx = _extract_context_from_sqs_or_sns_record(record)
if apm_context is None:
if dd_ctx and is_step_function_event(dd_ctx):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the DSM context can't be in a step_function_event?

In any case, the logic to get the APM context from the dd_ctx should be in its own function, I think.

That function can be _extract_context_from_xray, which you can rename to
_extract_apm_context, and it can take the parameters dd_ctx and record.

Then, code here can be:

if apm_context is None: apm_context = _extract_apm_context(dd_ctx, record) 

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will make the code not only easier to read, but also less error-prone.

Here, if the context is extracted from a Step Functions event, we are not setting a checkpoint. I don't think this is what we want?

try:
return extract_context_from_step_functions(dd_data, None)
return extract_context_from_step_functions(dd_ctx, None)
except Exception:
logger.debug(
"Failed to extract Step Functions context from SQS/SNS event."
)
context = propagator.extract(dd_data)
_dsm_set_checkpoint(dd_data, event_type, source_arn)
return context
elif not dd_ctx:
apm_context = _extract_context_from_xray(record)
else:
apm_context = propagator.extract(dd_ctx)
except Exception as e:
logger.debug("The trace extractor returned with error %s", e)
if config.data_streams_enabled:
_dsm_set_checkpoint(dd_ctx, event_type, source_arn)
if not config.data_streams_enabled:
break

return (
apm_context
if apm_context
else extract_context_from_lambda_context(lambda_context)
)


def _extract_context_from_sqs_or_sns_record(record):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that function looks great to me

# logic to deal with SNS => SQS event
if "body" in record:
body_str = record.get("body")
try:
body = json.loads(body_str)
if body.get("Type", "") == "Notification" and "TopicArn" in body:
logger.debug("Found SNS message inside SQS event")
record = get_first_record(create_sns_event(body))
except Exception:
pass

msg_attributes = record.get("messageAttributes")
if msg_attributes is None:
sns_record = record.get("Sns") or {}
msg_attributes = sns_record.get("MessageAttributes") or {}
dd_payload = msg_attributes.get("_datadog")
if dd_payload:
# SQS uses dataType and binaryValue/stringValue
# SNS uses Type and Value
dd_json_data = None
dd_json_data_type = dd_payload.get("Type") or dd_payload.get("dataType")
if dd_json_data_type == "Binary":
import base64

dd_json_data = dd_payload.get("binaryValue") or dd_payload.get("Value")
if dd_json_data:
dd_json_data = base64.b64decode(dd_json_data)
elif dd_json_data_type == "String":
dd_json_data = dd_payload.get("stringValue") or dd_payload.get("Value")
else:
# Handle case where trace context is injected into attributes.AWSTraceHeader
# example: Root=1-654321ab-000000001234567890abcdef;Parent=0123456789abcdef;Sampled=1
attrs = event.get("Records")[0].get("attributes")
if attrs:
x_ray_header = attrs.get("AWSTraceHeader")
if x_ray_header:
x_ray_context = parse_xray_header(x_ray_header)
trace_id_parts = x_ray_context.get("trace_id", "").split("-")
if len(trace_id_parts) > 2 and trace_id_parts[2].startswith(
DD_TRACE_JAVA_TRACE_ID_PADDING
):
# If it starts with eight 0's padding,
# then this AWSTraceHeader contains Datadog injected trace context
logger.debug(
"Found dd-trace injected trace context from AWSTraceHeader"
)
return Context(
trace_id=int(trace_id_parts[2][8:], 16),
span_id=int(x_ray_context["parent_id"], 16),
sampling_priority=float(x_ray_context["sampled"]),
)
# Still want to set a DSM checkpoint even if DSM context not propagated
_dsm_set_checkpoint(None, event_type, source_arn)
return extract_context_from_lambda_context(lambda_context)
except Exception as e:
logger.debug("The trace extractor returned with error %s", e)
# Still want to set a DSM checkpoint even if DSM context not propagated
_dsm_set_checkpoint(None, event_type, source_arn)
return extract_context_from_lambda_context(lambda_context)
logger.debug(
"Datadog Lambda Python only supports extracting trace"
"context from String or Binary SQS/SNS message attributes"
)

if dd_json_data:
dd_data = json.loads(dd_json_data)
return dd_data
return None


def _extract_context_from_xray(record):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned above, I would change this function to extract the APM context in general, not just the context from X-Ray.

attrs = record.get("attributes")
if attrs:
x_ray_header = attrs.get("AWSTraceHeader")
if x_ray_header:
x_ray_context = parse_xray_header(x_ray_header)
trace_id_parts = x_ray_context.get("trace_id", "").split("-")
if len(trace_id_parts) > 2 and trace_id_parts[2].startswith(
DD_TRACE_JAVA_TRACE_ID_PADDING
):
# If it starts with eight 0's padding,
# then this AWSTraceHeader contains Datadog injected trace context
logger.debug(
"Found dd-trace injected trace context from AWSTraceHeader"
)
return Context(
trace_id=int(trace_id_parts[2][8:], 16),
span_id=int(x_ray_context["parent_id"], 16),
sampling_priority=float(x_ray_context["sampled"]),
)
return None


def _extract_context_from_eventbridge_sqs_event(event):
Expand Down Expand Up @@ -392,31 +405,37 @@ def extract_context_from_kinesis_event(event, lambda_context):
Extract datadog trace context from a Kinesis Stream's base64 encoded data string
Set a DSM checkpoint if DSM is enabled and the method for context propagation is supported.
"""
source_arn = ""
try:
record = get_first_record(event)
source_arn = record.get("eventSourceARN", "")
kinesis = record.get("kinesis")
if not kinesis:
return extract_context_from_lambda_context(lambda_context)
data = kinesis.get("data")
if data:
import base64

b64_bytes = data.encode("ascii")
str_bytes = base64.b64decode(b64_bytes)
data_str = str_bytes.decode("ascii")
data_obj = json.loads(data_str)
dd_ctx = data_obj.get("_datadog")
if dd_ctx:
context = propagator.extract(dd_ctx)
_dsm_set_checkpoint(dd_ctx, "kinesis", source_arn)
return context
except Exception as e:
logger.debug("The trace extractor returned with error %s", e)
# Still want to set a DSM checkpoint even if DSM context not propagated
_dsm_set_checkpoint(None, "kinesis", source_arn)
return extract_context_from_lambda_context(lambda_context)
apm_context: Context = None
for record in event.get("Records", []):
dd_ctx = None
try:
source_arn = record.get("eventSourceARN", "")
kinesis = record.get("kinesis")
if not kinesis:
return extract_context_from_lambda_context(lambda_context)
data = kinesis.get("data")
if data:
import base64

b64_bytes = data.encode("ascii")
str_bytes = base64.b64decode(b64_bytes)
data_str = str_bytes.decode("ascii")
data_obj = json.loads(data_str)
dd_ctx = data_obj.get("_datadog")
if dd_ctx and apm_context is None:
apm_context = propagator.extract(dd_ctx)
except Exception as e:
logger.debug("The trace extractor returned with error %s", e)
if config.data_streams_enabled:
_dsm_set_checkpoint(dd_ctx, "kinesis", source_arn)
if not config.data_streams_enabled:
break
return (
apm_context
if apm_context
else extract_context_from_lambda_context(lambda_context)
)


def _deterministic_sha256_hash(s: str, part: str) -> int:
Expand Down
Loading
Loading