Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions datadog_lambda/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging

from wrapt import wrap_function_wrapper as wrap
from wrapt.importer import when_imported

from datadog_lambda.tracing import get_dd_trace_context

Expand All @@ -29,7 +30,7 @@ def patch_all():
Datadog trace context.
"""
_patch_httplib()
_patch_requests()
_ensure_patch_requests()


def _patch_httplib():
Expand All @@ -45,7 +46,20 @@ def _patch_httplib():
logger.debug("Patched %s", httplib_module)


def _patch_requests():
def _ensure_patch_requests():
"""
`requests` is third-party, may not be installed or used,
but ensure it gets patched if installed and used.
"""
if "requests" in sys.modules:
# already imported, patch now
_patch_requests(sys.modules["requests"])
else:
# patch when imported
when_imported("requests")(_patch_requests)


def _patch_requests(module):
"""
Patch the high-level HTTP client module `requests`
if it's installed.
Expand All @@ -66,9 +80,9 @@ def _wrap_requests_request(func, instance, args, kwargs):
into the outgoing requests.
"""
context = get_dd_trace_context()
if "headers" in kwargs:
if "headers" in kwargs and isinstance(kwargs["headers"], dict):
kwargs["headers"].update(context)
elif len(args) >= 5:
elif len(args) >= 5 and isinstance(args[4], dict):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

goodness

args[4].update(context)
else:
kwargs["headers"] = context
Expand All @@ -86,9 +100,9 @@ def _wrap_httplib_request(func, instance, args, kwargs):
the Datadog trace headers into the outgoing requests.
"""
context = get_dd_trace_context()
if "headers" in kwargs:
if "headers" in kwargs and isinstance(kwargs["headers"], dict):
kwargs["headers"].update(context)
elif len(args) >= 4:
elif len(args) >= 4 and isinstance(args[3], dict):
args[3].update(context)
else:
kwargs["headers"] = context
Expand Down
56 changes: 33 additions & 23 deletions datadog_lambda/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,37 +53,45 @@ class _LambdaDecorator(object):
and extracts/injects trace context.
"""

_force_new = False
_force_wrap = False

def __new__(cls, func):
"""
If the decorator is accidentally applied to the same function multiple times,
only the first one takes effect.
wrap only once.

If _force_new, always return a real decorator, useful for unit tests.
If _force_wrap, always return a real decorator, useful for unit tests.
"""
if cls._force_new or not getattr(func, "_dd_wrapped", False):
wrapped = super(_LambdaDecorator, cls).__new__(cls)
wrapped._dd_wrapped = True
return wrapped
else:
return _NoopDecorator(func)
try:
if cls._force_wrap or not isinstance(func, _LambdaDecorator):
wrapped = super(_LambdaDecorator, cls).__new__(cls)
logger.debug("datadog_lambda_wrapper wrapped")
return wrapped
else:
logger.debug("datadog_lambda_wrapper already wrapped")
return _NoopDecorator(func)
except Exception:
traceback.print_exc()
return func

def __init__(self, func):
"""Executes when the wrapped function gets wrapped"""
self.func = func
self.flush_to_log = os.environ.get("DD_FLUSH_TO_LOG", "").lower() == "true"
self.logs_injection = (
os.environ.get("DD_LOGS_INJECTION", "true").lower() == "true"
)

# Inject trace correlation ids to logs
if self.logs_injection:
inject_correlation_ids()

# Patch HTTP clients to propagate Datadog trace context
patch_all()
logger.debug("datadog_lambda_wrapper initialized")
try:
self.func = func
self.flush_to_log = os.environ.get("DD_FLUSH_TO_LOG", "").lower() == "true"
self.logs_injection = (
os.environ.get("DD_LOGS_INJECTION", "true").lower() == "true"
)

# Inject trace correlation ids to logs
if self.logs_injection:
inject_correlation_ids()

# Patch HTTP clients to propagate Datadog trace context
patch_all()
logger.debug("datadog_lambda_wrapper initialized")
except Exception:
traceback.print_exc()

def __call__(self, event, context, **kwargs):
"""Executes when the wrapped function gets called"""
Expand All @@ -97,21 +105,23 @@ def __call__(self, event, context, **kwargs):
self._after(event, context)

def _before(self, event, context):
set_cold_start()
try:
set_cold_start()
submit_invocations_metric(context)
# Extract Datadog trace context from incoming requests
extract_dd_trace_context(event)

# Set log correlation ids using extracted trace context
set_correlation_ids()
logger.debug("datadog_lambda_wrapper _before() done")
except Exception:
traceback.print_exc()

def _after(self, event, context):
try:
if not self.flush_to_log:
lambda_stats.flush(float("inf"))
logger.debug("datadog_lambda_wrapper _after() done")
except Exception:
traceback.print_exc()

Expand Down
2 changes: 0 additions & 2 deletions tests/integration/handle.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import json

from datadog_lambda.metric import lambda_metric
from datadog_lambda.wrapper import datadog_lambda_wrapper

Expand Down
5 changes: 2 additions & 3 deletions tests/integration/http_requests.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import requests

from datadog_lambda.metric import lambda_metric
Expand All @@ -12,7 +11,7 @@ def handle(event, context):
"tests.integration.count", 21, tags=["test:integration", "role:hello"]
)

us_response = requests.get("https://ip-ranges.datadoghq.com/")
eu_response = requests.get("https://ip-ranges.datadoghq.eu/")
requests.get("https://ip-ranges.datadoghq.com/")
requests.get("https://ip-ranges.datadoghq.eu/")

return {"statusCode": 200, "body": {"message": "hello, dog!"}}
37 changes: 28 additions & 9 deletions tests/test_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@

from datadog_lambda.patch import (
_patch_httplib,
_patch_requests,
_ensure_patch_requests,
)
from datadog_lambda.constants import TraceHeader


class TestPatchHTTPClients(unittest.TestCase):

def setUp(self):
patcher = patch('datadog_lambda.patch.get_dd_trace_context')
patcher = patch("datadog_lambda.patch.get_dd_trace_context")
self.mock_get_dd_trace_context = patcher.start()
self.mock_get_dd_trace_context.return_value = {
TraceHeader.TRACE_ID: '123',
TraceHeader.PARENT_ID: '321',
TraceHeader.SAMPLING_PRIORITY: '2',
TraceHeader.TRACE_ID: "123",
TraceHeader.PARENT_ID: "321",
TraceHeader.SAMPLING_PRIORITY: "2",
}
self.addCleanup(patcher.stop)

Expand All @@ -34,10 +34,29 @@ def test_patch_httplib(self):
self.mock_get_dd_trace_context.assert_called()

def test_patch_requests(self):
_patch_requests()
_ensure_patch_requests()
import requests
r = requests.get("https://www.datadoghq.com/")
self.mock_get_dd_trace_context.assert_called()
self.assertEqual(r.request.headers[TraceHeader.TRACE_ID], '123')
self.assertEqual(r.request.headers[TraceHeader.PARENT_ID], '321')
self.assertEqual(r.request.headers[TraceHeader.SAMPLING_PRIORITY], '2')
self.assertEqual(r.request.headers[TraceHeader.TRACE_ID], "123")
self.assertEqual(r.request.headers[TraceHeader.PARENT_ID], "321")
self.assertEqual(r.request.headers[TraceHeader.SAMPLING_PRIORITY], "2")

def test_patch_requests_with_headers(self):
_ensure_patch_requests()
import requests
r = requests.get("https://www.datadoghq.com/", headers={"key": "value"})
self.mock_get_dd_trace_context.assert_called()
self.assertEqual(r.request.headers["key"], "value")
self.assertEqual(r.request.headers[TraceHeader.TRACE_ID], "123")
self.assertEqual(r.request.headers[TraceHeader.PARENT_ID], "321")
self.assertEqual(r.request.headers[TraceHeader.SAMPLING_PRIORITY], "2")

def test_patch_requests_with_headers_none(self):
_ensure_patch_requests()
import requests
r = requests.get("https://www.datadoghq.com/", headers=None)
self.mock_get_dd_trace_context.assert_called()
self.assertEqual(r.request.headers[TraceHeader.TRACE_ID], "123")
self.assertEqual(r.request.headers[TraceHeader.PARENT_ID], "321")
self.assertEqual(r.request.headers[TraceHeader.SAMPLING_PRIORITY], "2")
6 changes: 3 additions & 3 deletions tests/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class TestDatadogLambdaWrapper(unittest.TestCase):
def setUp(self):
# Force @datadog_lambda_wrapper to always create a real
# (not no-op) wrapper.
datadog_lambda_wrapper._force_new = True
datadog_lambda_wrapper._force_wrap = True

patcher = patch("datadog_lambda.metric.lambda_stats")
self.mock_metric_lambda_stats = patcher.start()
Expand Down Expand Up @@ -265,9 +265,9 @@ def test_only_one_wrapper_in_use(self):
def lambda_handler(event, context):
lambda_metric("test.metric", 100)

# Turn off _force_new to emulate the nested wrapper scenario,
# Turn off _force_wrap to emulate the nested wrapper scenario,
# the second @datadog_lambda_wrapper should actually be no-op.
datadog_lambda_wrapper._force_new = False
datadog_lambda_wrapper._force_wrap = False

lambda_handler_double_wrapped = datadog_lambda_wrapper(lambda_handler)

Expand Down