Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -228,64 +228,62 @@ def keys(self, carrier: Dict) -> List:

getter = AiohttpGetter()

def create_aiohttp_middleware(tracer_provider: trace.TracerProvider | None = None):
_tracer = tracer_provider.get_tracer(
__name__, __version__
) if tracer_provider else tracer

@web.middleware
async def _middleware(request, handler):
"""Middleware for aiohttp implementing tracing logic"""
if not is_http_instrumentation_enabled() or _excluded_urls.url_disabled(
request.url.path
):
return await handler(request)

span_name, additional_attributes = get_default_span_details(request)

req_attrs = collect_request_attributes(request)
duration_attrs = _parse_duration_attrs(req_attrs)
active_requests_count_attrs = _parse_active_request_count_attrs(req_attrs)

duration_histogram = meter.create_histogram(
name=MetricInstruments.HTTP_SERVER_DURATION,
unit="ms",
description="Measures the duration of inbound HTTP requests.",
)

@web.middleware
async def middleware(request, handler):
"""Middleware for aiohttp implementing tracing logic"""
if not is_http_instrumentation_enabled() or _excluded_urls.url_disabled(
request.url.path
):
return await handler(request)

span_name, additional_attributes = get_default_span_details(request)

req_attrs = collect_request_attributes(request)
duration_attrs = _parse_duration_attrs(req_attrs)
active_requests_count_attrs = _parse_active_request_count_attrs(req_attrs)
active_requests_counter = meter.create_up_down_counter(
name=MetricInstruments.HTTP_SERVER_ACTIVE_REQUESTS,
unit="requests",
description="measures the number of concurrent HTTP requests those are currently in flight",
)

duration_histogram = meter.create_histogram(
name=MetricInstruments.HTTP_SERVER_DURATION,
unit="ms",
description="Measures the duration of inbound HTTP requests.",
)
with _tracer.start_as_current_span(
span_name,
context=extract(request, getter=getter),
kind=trace.SpanKind.SERVER,
) as span:
attributes = collect_request_attributes(request)
attributes.update(additional_attributes)
span.set_attributes(attributes)
start = default_timer()
active_requests_counter.add(1, active_requests_count_attrs)
try:
resp = await handler(request)
set_status_code(span, resp.status)
except web.HTTPException as ex:
set_status_code(span, ex.status_code)
raise
finally:
duration = max((default_timer() - start) * 1000, 0)
duration_histogram.record(duration, duration_attrs)
active_requests_counter.add(-1, active_requests_count_attrs)
return resp
return _middleware

middleware = create_aiohttp_middleware() # for backwards compatibility

active_requests_counter = meter.create_up_down_counter(
name=MetricInstruments.HTTP_SERVER_ACTIVE_REQUESTS,
unit="requests",
description="measures the number of concurrent HTTP requests those are currently in flight",
)

with tracer.start_as_current_span(
span_name,
context=extract(request, getter=getter),
kind=trace.SpanKind.SERVER,
) as span:
attributes = collect_request_attributes(request)
attributes.update(additional_attributes)
span.set_attributes(attributes)
start = default_timer()
active_requests_counter.add(1, active_requests_count_attrs)
try:
resp = await handler(request)
set_status_code(span, resp.status)
except web.HTTPException as ex:
set_status_code(span, ex.status_code)
raise
finally:
duration = max((default_timer() - start) * 1000, 0)
duration_histogram.record(duration, duration_attrs)
active_requests_counter.add(-1, active_requests_count_attrs)
return resp


class _InstrumentedApplication(web.Application):
"""Insert tracing middleware"""

def __init__(self, *args, **kwargs):
middlewares = kwargs.pop("middlewares", [])
middlewares.insert(0, middleware)
kwargs["middlewares"] = middlewares
super().__init__(*args, **kwargs)


class AioHttpServerInstrumentor(BaseInstrumentor):
Expand All @@ -296,7 +294,22 @@ class AioHttpServerInstrumentor(BaseInstrumentor):
"""

def _instrument(self, **kwargs):
tracer_provider = kwargs.get("tracer_provider", None)
assert tracer_provider is None or isinstance(
tracer_provider, trace.TracerProvider
)
self._original_app = web.Application

_middleware = create_aiohttp_middleware(tracer_provider=tracer_provider)
class _InstrumentedApplication(web.Application):
"""Insert tracing middleware"""

def __init__(self, *args, **kwargs):
middlewares = kwargs.pop("middlewares", [])
middlewares.insert(0, _middleware)
kwargs["middlewares"] = middlewares
super().__init__(*args, **kwargs)

setattr(web, "Application", _InstrumentedApplication)

def _uninstrument(self, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from opentelemetry.test.globals_test import reset_trace_globals
from opentelemetry.test.test_base import TestBase
from opentelemetry.util._importlib_metadata import entry_points

from opentelemetry.sdk.trace.sampling import ParentBased, TraceIdRatioBased

class HTTPMethod(Enum):
"""HTTP methods and descriptions"""
Expand Down Expand Up @@ -76,9 +76,9 @@ def fixture_suppress():

@pytest_asyncio.fixture(name="server_fixture")
async def fixture_server_fixture(tracer, aiohttp_server, suppress):
_, memory_exporter = tracer
tracer_provider, memory_exporter = tracer

AioHttpServerInstrumentor().instrument()
AioHttpServerInstrumentor().instrument(tracer_provider=tracer_provider)

app = aiohttp.web.Application()
app.add_routes([aiohttp.web.get("/test-path", default_handler)])
Expand Down Expand Up @@ -195,3 +195,34 @@ async def handler(request):
# Clean up
AioHttpServerInstrumentor().uninstrument()
memory_exporter.clear()



@pytest.mark.asyncio
@pytest.mark.parametrize(
"tracer", [TestBase().create_tracer_provider(sampler=ParentBased(TraceIdRatioBased(0.05)))]
)
async def test_non_global_tracer_provider(
tracer,
server_fixture,
aiohttp_client,
):
n_requests = 1000
collection_ratio = 0.05
n_expected_trace_ids = n_requests * collection_ratio

_, memory_exporter = tracer
server, _ = server_fixture

assert len(memory_exporter.get_finished_spans()) == 0

client = await aiohttp_client(server)
for _ in range(n_requests):
await client.get("/test-path")

trace_ids = {
span.context.trace_id
for span in memory_exporter.get_finished_spans()
if span.context is not None
}
assert 0.5 * n_expected_trace_ids <= len(trace_ids) <= 1.5 * n_expected_trace_ids