Skip to content

SQL Mixin Debug Django

debug_sql_queries_mixin.py

python
from django.test.utils import CaptureQueriesContext
from django.db import connections
from django.db import DEFAULT_DB_ALIAS


class _DebugSQLQueries(CaptureQueriesContext):
    """Context manager that prints the SQL queries captured in its body.

    Builds on ``CaptureQueriesContext``. When the ``with`` block finishes
    without an exception, the number of captured queries is printed,
    followed by a numbered listing of each query's SQL text. When the block
    raises, nothing is printed.
    """

    def __init__(self, connection) -> None:
        """Bind the capture to ``connection``."""
        super().__init__(connection)

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop capturing, then print the queries if the body succeeded."""
        super().__exit__(exc_type, exc_value, traceback)

        # Only report on a clean exit from the ``with`` block.
        if exc_type is None:
            executed = len(self)
            print(f"{executed} queries executed.")
            numbered = enumerate(self.captured_queries, start=1)
            for i, query in numbered:
                print(f"{i}. {query['sql']}")


class _AssertNumDBConnections(CaptureQueriesContext):
    """Context manager asserting the number of database connections.

    On a clean exit from the ``with`` block, ``len(connections.all())`` is
    compared against the expected ``num`` and an ``AssertionError`` is
    raised on a mismatch. If the block itself raised, the check is skipped
    so the original exception propagates instead of being replaced by an
    unrelated assertion failure.

    The queries captured by the base class are not inspected here.

    NOTE(review): ``connections.all()`` presumably yields one connection
    wrapper per configured alias rather than only the connections that are
    currently open — confirm this is the quantity the check should measure.
    """

    def __init__(self, connection, num: int) -> None:
        """Store the connection to capture on and the expected count.

        Args:
            connection: Connection handed to ``CaptureQueriesContext``.
            num: Expected value of ``len(connections.all())`` on exit.
        """
        super().__init__(connection)
        self.num = num

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop capturing and, on a clean exit, check the connection count.

        Raises:
            AssertionError: If the body succeeded and the number of
                connections differs from ``self.num``.
        """
        super().__exit__(exc_type, exc_value, traceback)
        if exc_type is not None:
            # The body already failed; do not mask its exception with an
            # AssertionError raised from here.
            return

        current_connections = len(connections.all())
        assert current_connections == self.num, (
            f"Expected {self.num} database connections, "
            f"but got {current_connections}."
        )


class _AssertNumQueriesByType(CaptureQueriesContext):
    """Context manager checking how many queries start with a keyword.

    A captured query counts when its SQL text, stripped of surrounding
    whitespace and upper-cased, begins with the upper-cased ``query_type``.
    The check only runs when the ``with`` block exits without an exception.
    """

    def __init__(self, connection, num: int, query_type: str) -> None:
        """Store the expected count and the normalised keyword.

        Args:
            connection: Connection handed to ``CaptureQueriesContext``.
            num: Number of matching queries expected.
            query_type: Leading SQL keyword to match; stored upper-cased.
        """
        super().__init__(connection)
        self.num = num
        self.query_type = query_type.upper()

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop capturing and compare the matching-query count to ``num``.

        Raises:
            AssertionError: If the body succeeded and the number of
                matching queries differs from ``self.num``.
        """
        super().__exit__(exc_type, exc_value, traceback)

        # Skip the check entirely when the body raised.
        if exc_type is None:
            # Tally the queries whose normalised SQL starts with the keyword.
            query_count = sum(
                1
                for query in self.captured_queries
                if query["sql"].strip().upper().startswith(self.query_type)
            )

            assert query_count == self.num, (
                f"Expected {self.num} {self.query_type} queries, "
                f"but got {query_count}."
            )


class SQLTestMixins:
    """Mixin adding SQL-inspection context managers to a test case.

    Every helper returns a context manager bound to the connection selected
    by ``using`` (``DEFAULT_DB_ALIAS`` unless overridden).

    Example usage:
        class MyTestCase(TestCase, SQLTestMixins):
            def test_my_function(self):
                # Exactly two SELECT statements are expected
                with self.assertNumSelectQueries(2):
                    MyModel.objects.all().count()
                    MyModel.objects.filter(id=1).first()

                # Exactly one DELETE statement is expected
                with self.assertNumDeleteQueries(1):
                    MyModel.objects.filter(id=1).delete()

                # Several query types can be checked in one ``with``
                with (
                    self.assertNumSelectQueries(1),
                    self.assertNumInsertQueries(1),
                    self.assertNumUpdateQueries(0)
                ):
                    obj = MyModel.objects.create(name="test")
                    MyModel.objects.filter(id=obj.id).first()
    """

    def displaySQLQueries(self, using=DEFAULT_DB_ALIAS) -> _DebugSQLQueries:
        """Return a context manager that prints the captured SQL queries."""
        return _DebugSQLQueries(connections[using])

    def assertNumDBConnections(
        self, num, using=DEFAULT_DB_ALIAS
    ) -> _AssertNumDBConnections:
        """Return a context manager asserting ``num`` database connections."""
        connection = connections[using]
        return _AssertNumDBConnections(connection, num)

    def assertNumSelectQueries(
        self, num, using=DEFAULT_DB_ALIAS
    ) -> _AssertNumQueriesByType:
        """Return a context manager asserting ``num`` SELECT queries."""
        connection = connections[using]
        return _AssertNumQueriesByType(connection, num, "SELECT")

    def assertNumInsertQueries(
        self, num, using=DEFAULT_DB_ALIAS
    ) -> _AssertNumQueriesByType:
        """Return a context manager asserting ``num`` INSERT queries."""
        connection = connections[using]
        return _AssertNumQueriesByType(connection, num, "INSERT")

    def assertNumUpdateQueries(
        self, num, using=DEFAULT_DB_ALIAS
    ) -> _AssertNumQueriesByType:
        """Return a context manager asserting ``num`` UPDATE queries."""
        connection = connections[using]
        return _AssertNumQueriesByType(connection, num, "UPDATE")

    def assertNumDeleteQueries(
        self, num, using=DEFAULT_DB_ALIAS
    ) -> _AssertNumQueriesByType:
        """Return a context manager asserting ``num`` DELETE queries."""
        connection = connections[using]
        return _AssertNumQueriesByType(connection, num, "DELETE")

Usage

python
from django.test import TestCase

from myapp.models import Location as LocationModel
from tests.factories import LocationFactory
from tests.mixins.graphql_test import GraphQLTestMixins
from tests.mixins.sql_test import SQLTestMixins


class LocationSQLQueryTestCase(TestCase, GraphQLTestMixins, SQLTestMixins):
    """SQL-level checks for the location GraphQL query."""

    @classmethod
    def setUpTestData(cls) -> None:
        """Build a batch of five locations and bulk-insert them."""
        # NOTE(review): if ``create_batch`` already saves its instances, the
        # ``bulk_create`` call below would insert them a second time —
        # confirm against ``LocationFactory`` (``build_batch`` may have been
        # intended).
        cls.location = LocationFactory.create_batch(5)
        LocationModel.objects.bulk_create(cls.location)

    def test_location_basic_sql_query(self) -> None:
        """Run the ``DataLoaderLocations`` query and expect status 200."""
        # Planned follow-up: switch to ``self.assertNumQueries(1)`` so this
        # becomes a real query-count test. While debugging, the
        # ``self.displaySQLQueries()`` helper from ``SQLTestMixins`` prints
        # the executed SQL.
        with self.assertNumDBConnections(1), self.with_schema(
            "dataloader_location"
        ) as schema:
            response = schema.query("DataLoaderLocations")
            self.assertEqual(response.status_code, 200, "Query failed")

References