38
38
from typing import Optional
39
39
40
40
import elasticapm
41
- from elasticapm .base import Client , get_client
41
+ from elasticapm .base import Client
42
42
from elasticapm .conf import constants
43
43
from elasticapm .utils import encoding , get_name_from_func , nested_key
44
44
from elasticapm .utils .disttracing import TraceParent
@@ -66,25 +66,25 @@ def handler(event, context):
66
66
return {"statusCode": r.status_code, "body": "Success!"}
67
67
"""
68
68
69
- def __init__ (self , name : Optional [str ] = None , ** kwargs ) -> None :
69
+ def __init__ (self , name : Optional [str ] = None , elasticapm_client : Optional [ Client ] = None , ** kwargs ) -> None :
70
70
self .name = name
71
71
self .event = {}
72
72
self .context = {}
73
73
self .response = None
74
+ self .instrumented = False
75
+ self .client = elasticapm_client # elasticapm_client is intended for testing only
74
76
75
77
# Disable all background threads except for transport
76
78
kwargs ["metrics_interval" ] = "0ms"
77
79
kwargs ["central_config" ] = False
78
80
kwargs ["cloud_provider" ] = "none"
79
81
kwargs ["framework_name" ] = "AWS Lambda"
80
- if "service_name" not in kwargs :
82
+ if "service_name" not in kwargs and "ELASTIC_APM_SERVICE_NAME" not in os . environ :
81
83
kwargs ["service_name" ] = os .environ ["AWS_LAMBDA_FUNCTION_NAME" ]
84
+ if "service_version" not in kwargs and "ELASTIC_APM_SERVICE_VERSION" not in os .environ :
85
+ kwargs ["service_version" ] = os .environ .get ("AWS_LAMBDA_FUNCTION_VERSION" )
82
86
83
- self .client = get_client ()
84
- if not self .client :
85
- self .client = Client (** kwargs )
86
- if not self .client .config .debug and self .client .config .instrument and self .client .config .enabled :
87
- elasticapm .instrument ()
87
+ self .client_kwargs = kwargs
88
88
89
89
def __call__ (self , func ):
90
90
self .name = self .name or get_name_from_func (func )
@@ -96,6 +96,21 @@ def decorated(*args, **kwds):
96
96
self .event , self .context = args
97
97
else :
98
98
self .event , self .context = {}, {}
99
+ # We delay client creation until the function is called, so that
100
+ # multiple @capture_serverless instances in the same file don't create
101
+ # multiple clients
102
+ if not self .client :
103
+ # Don't use get_client() as we may have a config mismatch due to **kwargs
104
+ self .client = Client (** self .client_kwargs )
105
+ if (
106
+ not self .instrumented
107
+ and not self .client .config .debug
108
+ and self .client .config .instrument
109
+ and self .client .config .enabled
110
+ ):
111
+ elasticapm .instrument ()
112
+ self .instrumented = True
113
+
99
114
if not self .client .config .debug and self .client .config .instrument and self .client .config .enabled :
100
115
with self :
101
116
self .response = func (* args , ** kwds )
@@ -124,10 +139,21 @@ def __enter__(self):
124
139
)
125
140
if self .httpmethod : # API Gateway
126
141
self .source = "api"
127
- if os .environ .get ("AWS_LAMBDA_FUNCTION_NAME" ):
128
- transaction_name = "{} {}" .format (self .httpmethod , os .environ ["AWS_LAMBDA_FUNCTION_NAME" ])
142
+ if nested_key (self .event , "requestContext" , "httpMethod" ):
143
+ # API v1
144
+ resource = "/{}{}" .format (
145
+ nested_key (self .event , "requestContext" , "stage" ),
146
+ nested_key (self .event , "requestContext" , "resourcePath" ),
147
+ )
129
148
else :
130
- transaction_name = self .name
149
+ # API v2
150
+ route_key = nested_key (self .event , "requestContext" , "routeKey" )
151
+ route_key = f"/{ route_key } " if route_key .startswith ("$" ) else route_key .split (" " , 1 )[- 1 ]
152
+ resource = "/{}{}" .format (
153
+ nested_key (self .event , "requestContext" , "stage" ),
154
+ route_key ,
155
+ )
156
+ transaction_name = "{} {}" .format (self .httpmethod , resource )
131
157
elif "Records" in self .event and len (self .event ["Records" ]) == 1 :
132
158
record = self .event ["Records" ][0 ]
133
159
if record .get ("eventSource" ) == "aws:s3" : # S3
@@ -203,21 +229,17 @@ def set_metadata_and_context(self, coldstart: bool) -> None:
203
229
faas ["coldstart" ] = coldstart
204
230
faas ["trigger" ] = {"type" : "other" }
205
231
faas ["execution" ] = self .context .aws_request_id
232
+ arn = self .context .invoked_function_arn
233
+ if len (arn .split (":" )) > 7 :
234
+ arn = ":" .join (arn .split (":" )[:7 ])
235
+ faas ["id" ] = arn
236
+ faas ["name" ] = os .environ .get ("AWS_LAMBDA_FUNCTION_NAME" )
237
+ faas ["version" ] = os .environ .get ("AWS_LAMBDA_FUNCTION_VERSION" )
206
238
207
239
if self .source == "api" :
208
240
faas ["trigger" ]["type" ] = "http"
209
241
faas ["trigger" ]["request_id" ] = self .event ["requestContext" ]["requestId" ]
210
- path = (
211
- self .event ["requestContext" ].get ("resourcePath" )
212
- or self .event ["requestContext" ]["http" ]["path" ].split (self .event ["requestContext" ]["stage" ])[- 1 ]
213
- )
214
- service_context ["origin" ] = {
215
- "name" : "{} {}/{}" .format (
216
- self .httpmethod ,
217
- self .event ["requestContext" ]["stage" ],
218
- path ,
219
- )
220
- }
242
+ service_context ["origin" ] = {"name" : self .event ["requestContext" ]["domainName" ]}
221
243
service_context ["origin" ]["id" ] = self .event ["requestContext" ]["apiId" ]
222
244
service_context ["origin" ]["version" ] = self .event .get ("version" , "1.0" )
223
245
cloud_context ["origin" ] = {}
@@ -236,13 +258,18 @@ def set_metadata_and_context(self, coldstart: bool) -> None:
236
258
cloud_context ["origin" ]["region" ] = record ["awsRegion" ]
237
259
cloud_context ["origin" ]["account" ] = {"id" : record ["eventSourceARN" ].split (":" )[4 ]}
238
260
cloud_context ["origin" ]["provider" ] = "aws"
239
- message_context ["queue" ] = service_context ["origin" ]["name" ]
261
+ message_context ["queue" ] = { "name" : service_context ["origin" ]["name" ]}
240
262
if "SentTimestamp" in record ["attributes" ]:
241
263
message_context ["age" ] = {"ms" : int ((time .time () * 1000 ) - int (record ["attributes" ]["SentTimestamp" ]))}
242
264
if self .client .config .capture_body in ("transactions" , "all" ) and "body" in record :
243
265
message_context ["body" ] = record ["body" ]
244
266
if self .client .config .capture_headers and record .get ("messageAttributes" ):
245
- message_context ["headers" ] = record ["messageAttributes" ]
267
+ headers = {}
268
+ for k , v in record ["messageAttributes" ].items ():
269
+ if v and v .get ("stringValue" ):
270
+ headers [k ] = v .get ("stringValue" )
271
+ if headers :
272
+ message_context ["headers" ] = headers
246
273
elif self .source == "sns" :
247
274
record = self .event ["Records" ][0 ]
248
275
faas ["trigger" ]["type" ] = "pubsub"
@@ -256,7 +283,7 @@ def set_metadata_and_context(self, coldstart: bool) -> None:
256
283
cloud_context ["origin" ]["region" ] = record ["Sns" ]["TopicArn" ].split (":" )[3 ]
257
284
cloud_context ["origin" ]["account_id" ] = record ["Sns" ]["TopicArn" ].split (":" )[4 ]
258
285
cloud_context ["origin" ]["provider" ] = "aws"
259
- message_context ["queue" ] = service_context ["origin" ]["name" ]
286
+ message_context ["queue" ] = { "name" : service_context ["origin" ]["name" ]}
260
287
if "Timestamp" in record ["Sns" ]:
261
288
message_context ["age" ] = {
262
289
"ms" : int (
@@ -270,7 +297,12 @@ def set_metadata_and_context(self, coldstart: bool) -> None:
270
297
if self .client .config .capture_body in ("transactions" , "all" ) and "Message" in record ["Sns" ]:
271
298
message_context ["body" ] = record ["Sns" ]["Message" ]
272
299
if self .client .config .capture_headers and record ["Sns" ].get ("MessageAttributes" ):
273
- message_context ["headers" ] = record ["Sns" ]["MessageAttributes" ]
300
+ headers = {}
301
+ for k , v in record ["Sns" ]["MessageAttributes" ].items ():
302
+ if v and v .get ("Type" ) == "String" :
303
+ headers [k ] = v .get ("Value" )
304
+ if headers :
305
+ message_context ["headers" ] = headers
274
306
elif self .source == "s3" :
275
307
record = self .event ["Records" ][0 ]
276
308
faas ["trigger" ]["type" ] = "datasource"
@@ -291,11 +323,7 @@ def set_metadata_and_context(self, coldstart: bool) -> None:
291
323
"name" : os .environ .get ("AWS_EXECUTION_ENV" ),
292
324
"version" : platform .python_version (),
293
325
}
294
- arn = self .context .invoked_function_arn
295
- if len (arn .split (":" )) > 7 :
296
- arn = ":" .join (arn .split (":" )[:7 ])
297
- metadata ["service" ]["id" ] = arn
298
- metadata ["service" ]["version" ] = os .environ .get ("AWS_LAMBDA_FUNCTION_VERSION" )
326
+ metadata ["service" ]["version" ] = self .client .config .service_version
299
327
metadata ["service" ]["node" ] = {"configured_name" : os .environ .get ("AWS_LAMBDA_LOG_STREAM_NAME" )}
300
328
# This is the one piece of metadata that requires deep merging. We add it manually
301
329
# here to avoid having to deep merge in _transport.add_metadata()
@@ -315,7 +343,7 @@ def set_metadata_and_context(self, coldstart: bool) -> None:
315
343
# faas doesn't actually belong in context, but we handle this in to_dict
316
344
elasticapm .set_context (faas , "faas" )
317
345
if message_context :
318
- elasticapm .set_context (service_context , "message" )
346
+ elasticapm .set_context (message_context , "message" )
319
347
self .client ._transport .add_metadata (metadata )
320
348
321
349
0 commit comments