Skip to content
1 change: 1 addition & 0 deletions CHANGELOG.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ endif::[]
* Fix an issue where compressed spans would count against `transaction_max_spans` {pull}1377[#1377]
* Make sure HTTP connections are not re-used after a process fork {pull}1374[#1374]
* Update the `User-Agent` header to the new https://github.com/elastic/apm/pull/514[spec] {pull}1378[#1378]
* Improve status_code handling in AWS Lambda integration {pull}1382[#1382]


[[release-notes-6.x]]
Expand Down
61 changes: 32 additions & 29 deletions elasticapm/contrib/serverless/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import os
import platform
import time
from typing import Optional

import elasticapm
from elasticapm.base import Client, get_client
Expand Down Expand Up @@ -65,7 +66,7 @@ def handler(event, context):
return {"statusCode": r.status_code, "body": "Success!"}
"""

def __init__(self, name=None, **kwargs):
def __init__(self, name: Optional[str] = None, **kwargs) -> None:
self.name = name
self.event = {}
self.context = {}
Expand Down Expand Up @@ -163,18 +164,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
lambda: get_data_from_response(self.response, capture_headers=self.client.config.capture_headers),
"response",
)
status_code = None
try:
for k, v in self.response.items():
if k.lower() == "statuscode":
status_code = v
break
except AttributeError:
pass
if status_code:
result = "HTTP {}xx".format(int(status_code) // 100)
elasticapm.set_transaction_result(result, override=False)

if "statusCode" in self.response:
try:
result = "HTTP {}xx".format(int(self.response["statusCode"]) // 100)
elasticapm.set_transaction_result(result, override=False)
except ValueError:
logger.warning("Lambda function's statusCode was not formed as an int. Assuming 5xx result.")
elasticapm.set_transaction_result("HTTP 5xx", override=False)
if exc_val:
self.client.capture_exception(exc_info=(exc_type, exc_val, exc_tb), handled=False)
if self.source == "api":
Expand All @@ -194,7 +190,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
except ValueError:
logger.warning("flush timed out")

def set_metadata_and_context(self, coldstart):
def set_metadata_and_context(self, coldstart: bool) -> None:
"""
Process the metadata and context fields for this request
"""
Expand All @@ -218,8 +214,8 @@ def set_metadata_and_context(self, coldstart):
service_context["origin"] = {
"name": "{} {}/{}".format(
self.httpmethod,
path,
self.event["requestContext"]["stage"],
path,
)
}
service_context["origin"]["id"] = self.event["requestContext"]["apiId"]
Expand All @@ -240,9 +236,9 @@ def set_metadata_and_context(self, coldstart):
cloud_context["origin"]["region"] = record["awsRegion"]
cloud_context["origin"]["account"] = {"id": record["eventSourceARN"].split(":")[4]}
cloud_context["origin"]["provider"] = "aws"
message_context["queue"] = record["eventSourceARN"]
message_context["queue"] = service_context["origin"]["name"]
if "SentTimestamp" in record["attributes"]:
message_context["age"] = int((time.time() * 1000) - int(record["attributes"]["SentTimestamp"]))
message_context["age"] = {"ms": int((time.time() * 1000) - int(record["attributes"]["SentTimestamp"]))}
if self.client.config.capture_body in ("transactions", "all") and "body" in record:
message_context["body"] = record["body"]
if self.client.config.capture_headers and record.get("messageAttributes"):
Expand All @@ -260,15 +256,17 @@ def set_metadata_and_context(self, coldstart):
cloud_context["origin"]["region"] = record["Sns"]["TopicArn"].split(":")[3]
cloud_context["origin"]["account_id"] = record["Sns"]["TopicArn"].split(":")[4]
cloud_context["origin"]["provider"] = "aws"
message_context["queue"] = record["Sns"]["TopicArn"]
message_context["queue"] = service_context["origin"]["name"]
if "Timestamp" in record["Sns"]:
message_context["age"] = int(
(
datetime.datetime.now()
- datetime.datetime.strptime(record["Sns"]["Timestamp"], r"%Y-%m-%dT%H:%M:%S.%fZ")
).total_seconds()
* 1000
)
message_context["age"] = {
"ms": int(
(
datetime.datetime.now()
- datetime.datetime.strptime(record["Sns"]["Timestamp"], r"%Y-%m-%dT%H:%M:%S.%fZ")
).total_seconds()
* 1000
)
}
if self.client.config.capture_body in ("transactions", "all") and "Message" in record["Sns"]:
message_context["body"] = record["Sns"]["Message"]
if self.client.config.capture_headers and record["Sns"].get("MessageAttributes"):
Expand Down Expand Up @@ -321,7 +319,7 @@ def set_metadata_and_context(self, coldstart):
self.client._transport.add_metadata(metadata)


def get_data_from_request(event, capture_body=False, capture_headers=True):
def get_data_from_request(event: dict, capture_body: bool = False, capture_headers: bool = True) -> dict:
"""
Capture context data from API gateway event
"""
Expand Down Expand Up @@ -353,21 +351,26 @@ def get_data_from_request(event, capture_body=False, capture_headers=True):
return result


def get_data_from_response(response, capture_headers=True):
def get_data_from_response(response: dict, capture_headers: bool = True) -> dict:
"""
Capture response data from lambda return
"""
result = {}

if "statusCode" in response:
result["status_code"] = response["statusCode"]
try:
result["status_code"] = int(response["statusCode"])
except ValueError:
# statusCode wasn't formed as an int
# we don't log here, as we will have already logged at transaction.result handling
result["status_code"] = 500

if capture_headers and "headers" in response:
result["headers"] = response["headers"]
return result


def get_url_dict(event):
def get_url_dict(event: dict) -> dict:
"""
Reconstruct URL from API Gateway
"""
Expand Down
13 changes: 7 additions & 6 deletions tests/contrib/serverless/aws_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,22 +117,23 @@ def test_request_data(event_api, event_api2):


def test_response_data():
response = {"statusCode": 200, "headers": {"foo": "bar"}}

response = {"statusCode": "200", "headers": {"foo": "bar"}}
data = get_data_from_response(response, capture_headers=True)

assert data["status_code"] == 200
assert data["headers"]["foo"] == "bar"

response["statusCode"] = 400
data = get_data_from_response(response, capture_headers=False)

assert data["status_code"] == 200
assert data["status_code"] == 400
assert "headers" not in data

data = get_data_from_response({}, capture_headers=False)

assert not data

response["statusCode"] = "2xx"
data = get_data_from_response(response, capture_headers=True)
assert data["status_code"] == 500


def test_capture_serverless_api_gateway(event_api, context, elasticapm_client):

Expand Down