Django SQL Debugging Mixins
debug_sql_queries_mixin.py
python
from typing import Optional
from django.test.utils import CaptureQueriesContext
from django.db import connections
from django.db import DEFAULT_DB_ALIAS
class _DebugSQLQueries(CaptureQueriesContext):
    """Context manager that prints the SQL queries run inside its block.

    It is constructed with the database connection to watch, exactly like
    ``CaptureQueriesContext``. On a clean exit it prints the total number of
    captured queries, then a 1-based numbered list of their SQL text. If the
    block raised, nothing is printed and the exception propagates unchanged.
    """

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        if exc_type is None:
            report = [f"{len(self)} queries executed."]
            report.extend(
                f"{position}. {entry['sql']}"
                for position, entry in enumerate(self.captured_queries, start=1)
            )
            for line in report:
                print(line)
class _AssertNumDBConnections(CaptureQueriesContext):
    """Context manager that checks the database connection count on exit.

    When the block exits without an exception, ``len(connections.all())`` is
    compared with ``num``. On a mismatch an ``AssertionError`` is raised.
    Queries on ``connection`` are also captured through the
    ``CaptureQueriesContext`` base class.
    """

    def __init__(self, connection, num: int) -> None:
        super().__init__(connection)
        # Expected connection count, checked in __exit__.
        self.num = num

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        # If the block itself failed, let that exception propagate instead of
        # replacing it with an unrelated connection-count AssertionError.
        if exc_type is not None:
            return
        # NOTE(review): in Django, connections.all() appears to yield one
        # wrapper per configured database alias, not only the connections
        # opened or used inside this block -- confirm this is the intended
        # measure.
        current_connections = len(connections.all())
        if current_connections != self.num:
            # Raise explicitly: a bare ``assert`` is stripped under ``python -O``.
            raise AssertionError(
                f"Expected {self.num} database connections, "
                f"but got {current_connections}."
            )
class SQLTestMixins:
    """Test-case mixin providing SQL debugging context managers.

    Each helper returns a context manager bound to the connection named by
    ``using``. The default is Django's ``DEFAULT_DB_ALIAS``.
    """

    def assertNumberOfQueries(
        self, using=DEFAULT_DB_ALIAS
    ) -> _DebugSQLQueries:
        """Return a context manager that prints the queries run on ``using``.

        Despite its name, this does not assert a query count. On a clean
        exit, the returned ``_DebugSQLQueries`` prints how many queries ran
        on the ``using`` connection, followed by each SQL statement.
        """
        return _DebugSQLQueries(connections[using])

    def assertNumDBConnections(
        self, num, using=DEFAULT_DB_ALIAS
    ) -> _AssertNumDBConnections:
        """Return a context manager that checks the database connection count.

        Args:
            num: Expected connection count, checked when the block exits.
            using: Alias of the connection passed to the context manager.
        """
        target = connections[using]
        return _AssertNumDBConnections(target, num)
Usage
python
from django.test import TestCase
from myapp.models import Location as LocationModel
from tests.factories import LocationFactory
from tests.mixins.graphql_test import GraphQLTestMixins
from tests.mixins.sql_test import SQLTestMixins
class LocationSQLQueryTestCase(TestCase, GraphQLTestMixins, SQLTestMixins):
    """Test SQL behaviour of the ``DataLoaderLocations`` GraphQL query."""

    @classmethod
    def setUpTestData(cls) -> None:
        """Insert five ``Location`` rows for the tests in this class."""
        # ``build_batch`` returns unsaved instances, so ``bulk_create`` inserts
        # them in a single statement. ``create_batch`` already saves them
        # (assuming LocationFactory is a factory_boy DjangoModelFactory --
        # TODO confirm). Bulk-inserting rows that already have primary keys
        # would then fail with a duplicate-key IntegrityError.
        cls.location = LocationFactory.build_batch(5)
        LocationModel.objects.bulk_create(cls.location)

    def test_location_basic_sql_query(self) -> None:
        """Run the query and check the connection count and HTTP status."""
        # Other helpers that can wrap the query:
        #   self.assertNumQueries(n)      -- Django: fail unless exactly n queries run
        #   self.assertNumberOfQueries()  -- SQLTestMixins: print the queries run
        with self.assertNumDBConnections(1):
            with self.with_schema("dataloader_location") as schema:
                response = schema.query("DataLoaderLocations")
        self.assertEqual(response.status_code, 200, "Query failed")