Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions elasticapm/instrumentation/packages/dbapi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,30 +170,46 @@ def extract_signature(sql):
keyword = "INTO" if sql_type == "INSERT" else "FROM"
sql_type = sql_type + " " + keyword

table_name = look_for_table(sql, keyword)
object_name = look_for_table(sql, keyword)
elif sql_type in ["CREATE", "DROP"]:
# 2nd word is part of SQL type
sql_type = sql_type + sql[first_space:second_space]
table_name = ""
object_name = ""
elif sql_type == "UPDATE":
table_name = look_for_table(sql, "UPDATE")
object_name = look_for_table(sql, "UPDATE")
elif sql_type == "SELECT":
# Name is first table
try:
sql_type = "SELECT FROM"
table_name = look_for_table(sql, "FROM")
object_name = look_for_table(sql, "FROM")
except Exception:
table_name = ""
object_name = ""
elif sql_type in ["EXEC", "EXECUTE"]:
sql_type = "EXECUTE"
end = second_space if second_space > first_space else len(sql)
object_name = sql[first_space + 1 : end]
elif sql_type == "CALL":
first_paren = sql.find("(", first_space)
end = first_paren if first_paren > first_space else len(sql)
procedure_name = sql[first_space + 1 : end].rstrip(";")
object_name = procedure_name + "()"
else:
# No name
table_name = ""
object_name = ""

signature = " ".join(filter(bool, [sql_type, table_name]))
signature = " ".join(filter(bool, [sql_type, object_name]))
return signature


QUERY_ACTION = "query"
EXEC_ACTION = "exec"
PROCEDURE_STATEMENTS = ["EXEC", "EXECUTE", "CALL"]


def extract_action_from_signature(signature, default):
if signature.split(" ")[0] in PROCEDURE_STATEMENTS:
return EXEC_ACTION
return default


class CursorProxy(wrapt.ObjectProxy):
Expand Down Expand Up @@ -226,6 +242,7 @@ def _trace_sql(self, method, sql, params, action=QUERY_ACTION):
signature = sql_string + "()"
else:
signature = self.extract_signature(sql_string)
action = extract_action_from_signature(signature, action)

# Truncate sql_string to 10000 characters to prevent large queries from
# causing an error to APM server.
Expand Down
42 changes: 41 additions & 1 deletion tests/instrumentation/dbapi2_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import pytest

from elasticapm.instrumentation.packages.dbapi2 import Literal, extract_signature, scan, tokenize
from elasticapm.instrumentation.packages.dbapi2 import (
Literal,
extract_action_from_signature,
extract_signature,
scan,
tokenize,
)


def test_scan_simple():
Expand Down Expand Up @@ -114,3 +120,37 @@ def test_extract_signature_bytes():
actual = extract_signature(sql)
expected = "HELLO"
assert actual == expected


@pytest.mark.parametrize(
["sql", "expected"],
[
(
"EXEC AdventureWorks2022.dbo.uspGetEmployeeManagers 50;",
"EXECUTE AdventureWorks2022.dbo.uspGetEmployeeManagers",
),
("EXECUTE sp_who2", "EXECUTE sp_who2"),
("EXEC sp_updatestats @@all_schemas = 'true'", "EXECUTE sp_updatestats"),
("CALL get_car_stats_by_year(2017, @number, @min, @avg, @max);", "CALL get_car_stats_by_year()"),
("CALL get_car_stats_by_year", "CALL get_car_stats_by_year()"),
("CALL get_car_stats_by_year;", "CALL get_car_stats_by_year()"),
("CALL get_car_stats_by_year();", "CALL get_car_stats_by_year()"),
],
)
def test_extract_signature_for_procedure_call(sql, expected):
actual = extract_signature(sql)
assert actual == expected


@pytest.mark.parametrize(
["sql", "expected"],
[
("SELECT FROM table", "query"),
("EXEC sp_who", "exec"),
("EXECUTE sp_updatestats", "exec"),
("CALL me_maybe", "exec"),
],
)
def test_extract_action_from_signature(sql, expected):
actual = extract_action_from_signature(sql, "query")
assert actual == expected