Django SQL Query Debugging Mixin
debug_sql_queries_mixin.py
python
from django.test.utils import CaptureQueriesContext
from django.db import connections
from django.db import DEFAULT_DB_ALIAS
class _DebugSQLQueries(CaptureQueriesContext):
    """Context manager that prints the SQL queries run inside its block.

    Builds on ``CaptureQueriesContext``: when the ``with`` block finishes
    without raising, it prints a header with the number of captured queries
    followed by one numbered line per statement. If the block raises, nothing
    is printed and ``__exit__`` returns ``None``, so the exception propagates.
    """

    def __init__(self, connection) -> None:
        # No extra state; delegate straight to CaptureQueriesContext.
        super().__init__(connection)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        if exc_type is None:
            report = [f"{len(self)} queries executed."]
            report.extend(
                f"{position}. {entry['sql']}"
                for position, entry in enumerate(self.captured_queries, start=1)
            )
            # One print of newline-joined lines emits the same bytes as one
            # print per line.
            print("\n".join(report))
class _AssertNumDBConnections(CaptureQueriesContext):
    """Context manager asserting the number of database connections on exit.

    When the ``with`` block finishes normally, ``len(connections.all())`` is
    compared against ``num`` and an ``AssertionError`` is raised on mismatch.
    If the block raises, the check is skipped so the block's own exception
    propagates instead of being replaced by an unrelated assertion failure.

    NOTE(review): ``connections.all()`` appears to enumerate every configured
    database alias rather than only the connections used inside the block;
    confirm this is the intended metric.
    """

    def __init__(self, connection, num: int) -> None:
        super().__init__(connection)
        # Expected value of ``len(connections.all())`` when the block exits.
        self.num = num

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        if exc_type is not None:
            # Don't mask the real failure with a connection-count assertion.
            return
        current_connections = len(connections.all())
        if current_connections != self.num:
            # Explicit raise: a bare ``assert`` is removed under ``python -O``.
            raise AssertionError(
                f"Expected {self.num} database connections, "
                f"but got {current_connections}."
            )
class _AssertNumQueriesByType(CaptureQueriesContext):
    """Context manager asserting how many queries of one SQL type are executed.

    A captured statement is counted when its normalized text starts with
    ``query_type``. Normalization does three things:

    - strips surrounding whitespace;
    - upper-cases the text;
    - removes the ``"<n> times: "`` prefix that Django's debug cursor writes
      into the query log for ``executemany`` calls.

    Without that last step, batched INSERT/UPDATE statements would never be
    classified. Each captured log entry counts once.

    The check only runs when the ``with`` block exits normally. If the block
    raises, its exception propagates unchanged.
    """

    # Separator between the repetition count and the statement in
    # ``executemany`` query-log entries, as it appears after upper-casing.
    _MANY_SEPARATOR = " TIMES: "

    def __init__(self, connection, num: int, query_type: str) -> None:
        super().__init__(connection)
        # Expected number of matching statements.
        self.num = num
        # Keyword prefix to match (e.g. "SELECT"), normalized like the SQL.
        self.query_type = query_type.strip().upper()

    @classmethod
    def _normalize(cls, sql: str) -> str:
        """Return *sql* trimmed, upper-cased and without an executemany prefix."""
        text = sql.strip().upper()
        count, separator, statement = text.partition(cls._MANY_SEPARATOR)
        # Only treat it as the executemany prefix when the part before the
        # separator is a bare count ("3") or Django's "?" placeholder, so a
        # literal " times: " inside ordinary SQL is left alone.
        if separator and (count.isdigit() or count == "?"):
            return statement.lstrip()
        return text

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        if exc_type is not None:
            return
        query_count = sum(
            1
            for query in self.captured_queries
            if self._normalize(query["sql"]).startswith(self.query_type)
        )
        if query_count != self.num:
            # Explicit raise: a bare ``assert`` is removed under ``python -O``.
            raise AssertionError(
                f"Expected {self.num} {self.query_type} queries, "
                f"but got {query_count}."
            )
class SQLTestMixins:
    """SQL test mixins for asserting database query patterns.

    Every helper returns a context manager built around the connection for
    the ``using`` alias (default ``DEFAULT_DB_ALIAS``). The query-type
    helpers only see queries issued on that connection inside the ``with``
    block.

    Example usage:

        class MyTestCase(TestCase, SQLTestMixins):
            def test_my_function(self):
                # Test that exactly 2 SELECT queries are executed
                with self.assertNumSelectQueries(2):
                    MyModel.objects.all().count()
                    MyModel.objects.filter(id=1).first()

                # Test that exactly 1 DELETE query is executed
                with self.assertNumDeleteQueries(1):
                    MyModel.objects.filter(id=1).delete()

                # Test multiple query types together
                with (
                    self.assertNumSelectQueries(1),
                    self.assertNumInsertQueries(1),
                    self.assertNumUpdateQueries(0),
                ):
                    obj = MyModel.objects.create(name="test")
                    MyModel.objects.filter(id=obj.id).first()

                # Print the queries run in a block (debugging aid, asserts nothing)
                with self.displaySQLQueries():
                    MyModel.objects.count()
    """

    def displaySQLQueries(self, using=DEFAULT_DB_ALIAS) -> _DebugSQLQueries:
        """Print the SQL queries executed inside the returned context manager.

        Nothing is asserted. When the block exits without raising, the query
        count and each captured statement are printed to stdout.
        """
        return _DebugSQLQueries(connections[using])

    def assertNumDBConnections(
        self, num: int, using=DEFAULT_DB_ALIAS
    ) -> _AssertNumDBConnections:
        """Assert the number of database connections used in a test."""
        return _AssertNumDBConnections(connections[using], num)

    def assertNumQueriesByType(
        self, num: int, query_type: str, using=DEFAULT_DB_ALIAS
    ) -> _AssertNumQueriesByType:
        """Assert that exactly ``num`` queries starting with ``query_type`` run.

        ``query_type`` is matched case-insensitively against the leading
        keyword of each captured statement (e.g. ``"SELECT"``, ``"INSERT"``).
        """
        return _AssertNumQueriesByType(connections[using], num, query_type)

    def assertNumSelectQueries(
        self, num: int, using=DEFAULT_DB_ALIAS
    ) -> _AssertNumQueriesByType:
        """Assert the number of SELECT queries executed."""
        return self.assertNumQueriesByType(num, "SELECT", using=using)

    def assertNumInsertQueries(
        self, num: int, using=DEFAULT_DB_ALIAS
    ) -> _AssertNumQueriesByType:
        """Assert the number of INSERT queries executed."""
        return self.assertNumQueriesByType(num, "INSERT", using=using)

    def assertNumUpdateQueries(
        self, num: int, using=DEFAULT_DB_ALIAS
    ) -> _AssertNumQueriesByType:
        """Assert the number of UPDATE queries executed."""
        return self.assertNumQueriesByType(num, "UPDATE", using=using)

    def assertNumDeleteQueries(
        self, num: int, using=DEFAULT_DB_ALIAS
    ) -> _AssertNumQueriesByType:
        """Assert the number of DELETE queries executed."""
        return self.assertNumQueriesByType(num, "DELETE", using=using)

Usage
python
from django.test import TestCase
from myapp.models import Location as LocationModel
from tests.factories import LocationFactory
from tests.mixins.graphql_test import GraphQLTestMixins
from tests.mixins.sql_test import SQLTestMixins
class LocationSQLQueryTestCase(TestCase, GraphQLTestMixins, SQLTestMixins):
    """Test SQL queries for location."""

    @classmethod
    def setUpTestData(cls) -> None:
        """Create the Location rows shared by every test in this class."""
        # NOTE(review): if LocationFactory is a DjangoModelFactory,
        # ``create_batch`` already saves the rows and the ``bulk_create``
        # below would re-insert the same primary keys. In that case, drop
        # ``bulk_create`` or switch to ``build_batch``. Confirm the factory type.
        cls.location = LocationFactory.create_batch(5)
        LocationModel.objects.bulk_create(cls.location)

    def test_location_basic_sql_query(self) -> None:
        """Test SQL queries for location."""
        # Alternatives for this block:
        #   with self.assertNumQueries(1):  -> assert the total query count
        #   with self.displaySQLQueries():  -> print every executed query
        with self.assertNumDBConnections(1):
            with self.with_schema("dataloader_location") as schema:
                response = schema.query("DataLoaderLocations")
        self.assertEqual(response.status_code, 200, "Query failed")