Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion debug_toolbar/panels/sql/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
from debug_toolbar import settings as dt_settings
from debug_toolbar.utils import get_stack, get_template_info, tidy_stacktrace

try:
from psycopg2._json import Json as PostgresJson
except ImportError:
PostgresJson = None


class SQLQueryTriggered(Exception):
"""Thrown when template panel triggers a query"""
Expand Down Expand Up @@ -105,6 +110,8 @@ def _quote_params(self, params):
return [self._quote_expr(p) for p in params]

def _decode(self, param):
if PostgresJson and isinstance(param, PostgresJson):
return param.dumps(param.adapted)
# If a sequence type, decode each element separately
if isinstance(param, (tuple, list)):
return [self._decode(element) for element in param]
Expand Down Expand Up @@ -136,7 +143,6 @@ def _record(self, method, sql, params):
_params = json.dumps(self._decode(params))
except TypeError:
pass # object not JSON serializable

template_info = get_template_info()

alias = getattr(self.db, "alias", "default")
Expand Down
11 changes: 11 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,14 @@ def __repr__(self):

class Binary(models.Model):
field = models.BinaryField()


try:
from django.contrib.postgres.fields import JSONField

class PostgresJSON(models.Model):
field = JSONField()


except ImportError:
pass
30 changes: 30 additions & 0 deletions tests/panels/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@

from ..base import BaseTestCase

try:
from psycopg2._json import Json as PostgresJson
except ImportError:
PostgresJson = None

if connection.vendor == "postgresql":
from ..models import PostgresJSON as PostgresJSONModel
else:
PostgresJSONModel = None


class SQLPanelTestCase(BaseTestCase):
panel_id = "SQLPanel"
Expand Down Expand Up @@ -120,6 +130,26 @@ def test_param_conversion(self):
('["Foo", true, false]', "[10, 1]", '["2017-12-22 16:07:01"]'),
)

@unittest.skipUnless(
connection.vendor == "postgresql", "Test valid only on PostgreSQL"
)
def test_json_param_conversion(self):
self.assertEqual(len(self.panel._queries), 0)

list(PostgresJSONModel.objects.filter(field__contains={"foo": "bar"}))

response = self.panel.process_request(self.request)
self.panel.generate_stats(self.request, response)

# ensure query was logged
self.assertEqual(len(self.panel._queries), 1)
self.assertEqual(
self.panel._queries[0][1]["params"], '["{\\"foo\\": \\"bar\\"}"]',
)
self.assertIsInstance(
self.panel._queries[0][1]["raw_params"][0], PostgresJson,
)

def test_binary_param_force_text(self):
self.assertEqual(len(self.panel._queries), 0)

Expand Down
30 changes: 30 additions & 0 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from django.contrib.staticfiles.testing import StaticLiveServerTestCase
from django.core import signing
from django.core.checks import Warning, run_checks
from django.db import connection
from django.http import HttpResponse
from django.template.loader import get_template
from django.test import RequestFactory, SimpleTestCase, TestCase
Expand Down Expand Up @@ -206,6 +207,35 @@ def test_sql_explain_checks_show_toolbar(self):
)
self.assertEqual(response.status_code, 404)

@unittest.skipUnless(
connection.vendor == "postgresql", "Test valid only on PostgreSQL"
)
def test_sql_explain_postgres_json_field(self):
url = "/__debug__/sql_explain/"
base_query = (
'SELECT * FROM "tests_postgresjson" WHERE "tests_postgresjson"."field" @>'
)
query = base_query + """ '{"foo": "bar"}'"""
data = {
"sql": query,
"raw_sql": base_query + " %s",
"params": '["{\\"foo\\": \\"bar\\"}"]',
"alias": "default",
"duration": "0",
"hash": "2b7172eb2ac8e2a8d6f742f8a28342046e0d00ba",
}
response = self.client.post(url, data)
self.assertEqual(response.status_code, 200)
response = self.client.post(url, data, HTTP_X_REQUESTED_WITH="XMLHttpRequest")
self.assertEqual(response.status_code, 200)
with self.settings(INTERNAL_IPS=[]):
response = self.client.post(url, data)
self.assertEqual(response.status_code, 404)
response = self.client.post(
url, data, HTTP_X_REQUESTED_WITH="XMLHttpRequest"
)
self.assertEqual(response.status_code, 404)

def test_sql_profile_checks_show_toolbar(self):
url = "/__debug__/sql_profile/"
data = {
Expand Down