Skip to content
9 changes: 6 additions & 3 deletions datadog_lambda/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import os
import logging
from datadog_lambda.cold_start import initialize_cold_start_tracing

# Install the import-timing hooks (when enabled) before anything else in the
# package is imported: only modules imported after this call get recorded.
initialize_cold_start_tracing()

# The minor version corresponds to the Lambda layer version.
# E.g., version 0.5.0 gets packaged into layer version 5.
try:
Expand All @@ -7,8 +13,5 @@

__version__ = importlib_metadata.version(__name__)

import os
import logging

# Package logger; its level is taken from the DD_LOG_LEVEL environment
# variable (upper-cased before lookup) and defaults to INFO when unset.
logger = logging.getLogger(__name__)
logger.setLevel(logging.getLevelName(os.environ.get("DD_LOG_LEVEL", "INFO").upper()))
200 changes: 200 additions & 0 deletions datadog_lambda/cold_start.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
import time
import os
from typing import List, Hashable
import logging

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Cold-start state flags and their initial values.
# NOTE(review): the code that reads and updates these is outside this excerpt;
# presumably is_cold_start()/set_cold_start() -- confirm.
_cold_start = True
_lambda_container_initialized = False

Expand All @@ -21,3 +28,196 @@ def is_cold_start():
def get_cold_start_tag():
    """Build the ``cold_start:<value>`` tag to attach to metrics.

    ``<value>`` is the lower-cased string form of ``is_cold_start()``.
    """
    flag = str(is_cold_start()).lower()
    return "cold_start:{}".format(flag)


class ImportNode(object):
    """One module import and the imports it triggered while executing.

    Attributes are plain, publicly writable fields:
        module_name: name of the imported module.
        full_file_path: origin of the module (may be None).
        start_time_ns: timestamp, in nanoseconds, taken when the import began.
        end_time_ns: timestamp, in nanoseconds, taken when the import finished;
            None until it is filled in.
        children: nested ImportNode objects, in the order they were started.
    """

    def __init__(self, module_name, full_file_path, start_time_ns, end_time_ns=None):
        # Identity of the imported module.
        self.module_name = module_name
        self.full_file_path = full_file_path
        # Timing window of the import.
        self.end_time_ns = end_time_ns
        self.start_time_ns = start_time_ns
        # Every node owns its own, initially empty, list of children.
        self.children = []


# Finished top-level imports: pop_node() appends a node here when it completes
# and no other import is left on the stack.
root_nodes: List[ImportNode] = []
# Imports currently in progress, innermost last; maintained by
# push_node()/pop_node().
import_stack: List[ImportNode] = []
# Loaders whose exec_module has already been replaced by wrap_find_spec(), so
# that a loader is never wrapped twice.
already_wrapped_loaders = set()


def reset_node_stacks():
    """Discard every recorded import node.

    Both module-level lists are emptied *in place* instead of being rebound to
    new list objects. Rebinding would leave every reference captured earlier
    (for example a default argument evaluated at definition time, or a
    ``from ... import root_nodes`` in another module) pointing at a stale list
    that no longer receives new nodes.
    """
    root_nodes.clear()
    import_stack.clear()


def push_node(module_name, file_path):
    """Record that the import of ``module_name`` has just started.

    A new ImportNode stamped with the current ``time.time_ns()`` is attached
    as a child of the import currently on top of the stack (if any) and then
    becomes the new top of the stack.
    """
    started = ImportNode(module_name, file_path, time.time_ns())
    if import_stack:
        parent = import_stack[-1]
        parent.children.append(started)
    import_stack.append(started)


def pop_node(module_name):
    """Record that the import of ``module_name`` has finished.

    The node on top of the stack is always removed. Only when its name matches
    ``module_name`` is it stamped with the current ``time.time_ns()``; a
    mismatching node is dropped from the stack without an end time. A matching
    node that leaves the stack empty is a top-level import and is appended to
    ``root_nodes``.
    """
    if not import_stack:
        return
    finished = import_stack.pop()
    if finished.module_name == module_name:
        finished.end_time_ns = time.time_ns()
        if not import_stack:
            # Nothing encloses this import: it is a root of the import tree.
            root_nodes.append(finished)


def wrap_exec_module(original_exec_module):
    """Return a replacement for ``exec_module`` that times the module execution.

    The returned callable pushes an import node before delegating to
    ``original_exec_module`` and pops it afterwards, whether the original
    returns or raises. Any failure while reading the module spec or pushing the
    node is ignored, in which case the original is called untimed.
    """

    def wrapped_method(module):
        try:
            spec = module.__spec__
            push_node(spec.name, spec.origin)
        except Exception:
            # Bookkeeping must never prevent the module from being executed.
            pushed = False
        else:
            pushed = True

        try:
            return original_exec_module(module)
        finally:
            # Only balance the stack when the push above actually happened.
            if pushed:
                pop_node(spec.name)

    return wrapped_method


def wrap_find_spec(original_find_spec):
    """Return a replacement for a finder's ``find_spec`` that instruments loaders.

    The returned callable forwards all arguments to ``original_find_spec`` and
    returns its result unchanged. On the way, the spec's loader gets its
    ``exec_module`` wrapped with ``wrap_exec_module`` -- once per loader, and
    only for hashable loaders that expose ``exec_module``.
    """

    def wrapped_find_spec(*args, **kwargs):
        spec = original_find_spec(*args, **kwargs)
        if spec is None:
            return None

        loader = getattr(spec, "loader", None)
        # Loaders are remembered in a set, so unhashable ones are left alone.
        if loader is None or not isinstance(loader, Hashable):
            return spec
        if loader in already_wrapped_loaders:
            return spec
        if not hasattr(loader, "exec_module"):
            return spec

        try:
            loader.exec_module = wrap_exec_module(loader.exec_module)
            already_wrapped_loaders.add(loader)
        except Exception as e:
            logger.debug("Failed to wrap the loader. %s", e)
        return spec

    return wrapped_find_spec


def initialize_cold_start_tracing():
    """Hook the import system so that module imports get timed.

    Nothing happens unless this is a cold start and both ``DD_TRACE_ENABLED``
    and ``DD_COLD_START_TRACING`` are "true" (case-insensitive; each defaults to
    "true" when unset). When enabled, the ``find_spec`` of every finder on
    ``sys.meta_path`` is replaced by its ``wrap_find_spec`` wrapper.
    """
    if not is_cold_start():
        return
    for flag in ("DD_TRACE_ENABLED", "DD_COLD_START_TRACING"):
        if os.environ.get(flag, "true").lower() != "true":
            return

    import sys

    # The current implementation only supports Python 3.7 and newer.
    if sys.version_info < (3, 7):
        return

    for finder in sys.meta_path:
        try:
            finder.find_spec = wrap_find_spec(finder.find_spec)
        except Exception:
            # Best effort: a finder that has no find_spec, or that rejects the
            # assignment, is simply left uninstrumented.
            pass


class ColdStartTracer(object):
    """Convert the recorded import tree into spans describing a cold start.

    One ``aws.lambda.load`` span is created for the whole cold start, with one
    ``aws.lambda.import`` span beneath it for every recorded import that is
    slow enough and not explicitly ignored.
    """

    def __init__(
        self,
        tracer,
        function_name,
        cold_start_span_finish_time_ns,
        trace_ctx,
        min_duration_ms: int,
        ignored_libs: List[str] = None,
    ):
        """Store the tracing configuration.

        Args:
            tracer: tracer used to create spans; must provide
                ``context_provider.activate()`` and ``trace()``.
            function_name: resource name given to the cold start span.
            cold_start_span_finish_time_ns: finish time of the cold start
                span, in nanoseconds.
            trace_ctx: trace context that is (re)activated before spans are
                started.
            min_duration_ms: imports shorter than this many milliseconds are
                not traced.
            ignored_libs: module names that are never traced. Defaults to an
                empty list.
        """
        self._tracer = tracer
        self.function_name = function_name
        self.cold_start_span_finish_time_ns = cold_start_span_finish_time_ns
        self.min_duration_ms = min_duration_ms
        self.trace_ctx = trace_ctx
        # Build a fresh list per instance instead of sharing one mutable
        # default object between every tracer.
        self.ignored_libs = [] if ignored_libs is None else ignored_libs
        self.need_to_reactivate_context = True

    def trace(self, root_nodes: "List[ImportNode]" = None):
        """Emit the cold start span and the import spans beneath it.

        Args:
            root_nodes: top-level import nodes to trace. When omitted, the
                module-level ``root_nodes`` list is used. The list is drained
                (emptied) by this call.

        Does nothing when there are no nodes to trace.
        """
        if root_nodes is None:
            # Resolve the module-level list at call time. A default bound at
            # definition time would keep pointing at the original list object
            # even after the module-level name has been rebound.
            root_nodes = globals()["root_nodes"]
        if not root_nodes:
            return
        # The cold start begins with the first recorded top-level import.
        cold_start_span_start_time_ns = root_nodes[0].start_time_ns
        cold_start_span = self.create_cold_start_span(cold_start_span_start_time_ns)
        while root_nodes:
            self.trace_tree(root_nodes.pop(), cold_start_span)
        self.finish_span(cold_start_span, self.cold_start_span_finish_time_ns)

    def trace_tree(self, import_node: "ImportNode", parent_span):
        """Create a span for ``import_node`` and, recursively, its children.

        The node and its whole subtree are skipped when the node never
        received an end time, lasted less than ``min_duration_ms``, or names a
        module listed in ``ignored_libs``.
        """
        if import_node.end_time_ns is None:
            # The import was never marked as finished, so it has no duration;
            # skip it rather than fail and abort the whole trace.
            return
        duration_ns = import_node.end_time_ns - import_node.start_time_ns
        if (
            duration_ns < self.min_duration_ms * 1e6  # ms -> ns
            or import_node.module_name in self.ignored_libs
        ):
            return

        span = self.start_span(
            "aws.lambda.import", import_node.module_name, import_node.start_time_ns
        )
        tags = {
            "resource_names": import_node.module_name,
            "resource.name": import_node.module_name,
            "filename": import_node.full_file_path,
            "operation_name": self.get_operation_name(import_node.full_file_path),
        }
        span.set_tags(tags)
        if parent_span:
            span.parent_id = parent_span.span_id
        for child_node in import_node.children:
            self.trace_tree(child_node, span)
        self.finish_span(span, import_node.end_time_ns)

    def create_cold_start_span(self, start_time_ns):
        """Start and tag the ``aws.lambda.load`` span covering the cold start."""
        span = self.start_span("aws.lambda.load", self.function_name, start_time_ns)
        tags = {
            "resource_names": self.function_name,
            "resource.name": self.function_name,
            "operation_name": "aws.lambda.load",
        }
        span.set_tags(tags)
        return span

    def start_span(self, span_type, resource, start_time_ns):
        """Start a span named ``span_type`` and backdate it to ``start_time_ns``."""
        if self.need_to_reactivate_context:
            self._tracer.context_provider.activate(
                self.trace_ctx
            )  # reactivate required after each finish() call
            self.need_to_reactivate_context = False
        span_kwargs = {
            "service": "aws.lambda",
            "resource": resource,
            "span_type": span_type,
        }
        span = self._tracer.trace(span_type, **span_kwargs)
        # The import already happened; overwrite the start time accordingly.
        span.start_ns = start_time_ns
        return span

    def finish_span(self, span, finish_time_ns):
        """Finish ``span`` at ``finish_time_ns`` (passed to the span in seconds)."""
        span.finish(finish_time_ns / 1e9)
        self.need_to_reactivate_context = True

    def get_operation_name(self, filename: str):
        """Derive the operation name of an import from the module's file path.

        Returns:
            "aws.lambda.import_core_module" when there is no file path,
            "aws.lambda.import_layer" for paths under ``/opt/``,
            "aws.lambda.import_runtime" for paths under ``/var/lang/``,
            and "aws.lambda.import" for anything else, including non-string
            values.
        """
        if filename is None:
            return "aws.lambda.import_core_module"
        if isinstance(filename, str):
            if filename.startswith("/opt/"):
                return "aws.lambda.import_layer"
            if filename.startswith("/var/lang/"):
                return "aws.lambda.import_runtime"
        return "aws.lambda.import"
45 changes: 43 additions & 2 deletions datadog_lambda/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# under the Apache License Version 2.0.
# This product includes software developed at Datadog (https://www.datadoghq.com/).
# Copyright 2019 Datadog, Inc.

import base64
import os
import logging
Expand All @@ -12,7 +11,7 @@
from time import time_ns

from datadog_lambda.extension import should_use_extension, flush_extension
from datadog_lambda.cold_start import set_cold_start, is_cold_start
from datadog_lambda.cold_start import set_cold_start, is_cold_start, ColdStartTracer
from datadog_lambda.constants import (
TraceContextSource,
XraySubsegment,
Expand All @@ -38,6 +37,7 @@
create_inferred_span,
InferredSpanInfo,
is_authorizer_response,
tracer,
)
from datadog_lambda.trigger import (
extract_trigger_tags,
Expand Down Expand Up @@ -131,6 +131,28 @@ def __init__(self, func):
self.decode_authorizer_context = (
os.environ.get("DD_DECODE_AUTHORIZER_CONTEXT", "true").lower() == "true"
)
self.cold_start_tracing = (
os.environ.get("DD_COLD_START_TRACING", "true").lower() == "true"
)
self.min_cold_start_trace_duration = 3
if "DD_MIN_COLD_START_DURATION" in os.environ:
try:
self.min_cold_start_trace_duration = int(
os.environ["DD_MIN_COLD_START_DURATION"]
)
except Exception:
logger.debug("Malformatted env DD_MIN_COLD_START_DURATION")
self.cold_start_trace_skip_lib = [
"ddtrace.internal.compat",
"ddtrace.filters",
]
if "DD_COLD_START_TRACE_SKIP_LIB" in os.environ:
try:
self.cold_start_trace_skip_lib = os.environ[
"DD_COLD_START_TRACE_SKIP_LIB"
].split(",")
except Exception:
logger.debug("Malformatted for env DD_COLD_START_TRACE_SKIP_LIB")
self.response = None
if profiling_env_var:
self.prof = profiler.Profiler(env=env_env_var, service=service_env_var)
Expand Down Expand Up @@ -257,6 +279,11 @@ def _after(self, event, context):
create_dd_dummy_metadata_subsegment(
self.trigger_tags, XraySubsegment.LAMBDA_FUNCTION_TAGS_KEY
)
should_trace_cold_start = (
dd_tracing_enabled and self.cold_start_tracing and is_cold_start()
)
if should_trace_cold_start:
trace_ctx = tracer.current_trace_context()

if self.span:
if dd_capture_lambda_payload_enabled:
Expand All @@ -276,6 +303,20 @@ def _after(self, event, context):
else:
self.inferred_span.finish()

if should_trace_cold_start:
try:
following_span = self.span or self.inferred_span
ColdStartTracer(
tracer,
self.function_name,
following_span.start_ns,
trace_ctx,
self.min_cold_start_trace_duration,
self.cold_start_trace_skip_lib,
).trace()
except Exception as e:
logger.debug("Failed to create cold start spans. %s", e)

if not self.flush_to_log or should_use_extension:
flush_stats()
if should_use_extension:
Expand Down
1 change: 1 addition & 0 deletions tests/integration/serverless.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ provider:
DD_TRACE_ENABLED: true
DD_API_KEY: ${env:DD_API_KEY}
DD_TRACE_MANAGED_SERVICES: true
DD_COLD_START_TRACING: false
timeout: 15
deploymentBucket:
name: integration-tests-serververless-deployment-bucket
Expand Down
Loading