Skip to content

Django test mixins for debugging SQL queries

debug_sql_queries_mixin.py

python
from typing import Optional
from django.test.utils import CaptureQueriesContext
from django.db import connections
from django.db import DEFAULT_DB_ALIAS


class _DebugSQLQueries(CaptureQueriesContext):
    """Context manager that prints the SQL executed inside its block.

    Wraps Django's ``CaptureQueriesContext``. When the block exits without
    an exception, it prints the total number of captured queries followed
    by a 1-based numbered list of their SQL text. When the block raises,
    nothing is printed and the exception propagates unchanged.
    """

    def __exit__(self, exc_type, exc_value, traceback):
        # Let the base class finish capturing before we inspect the queries.
        super().__exit__(exc_type, exc_value, traceback)
        if exc_type is None:
            self._print_report()

    def _print_report(self) -> None:
        """Print the query count, then each captured statement numbered from 1."""
        total = len(self)
        print(f"{total} queries executed.")
        for position, entry in enumerate(self.captured_queries, 1):
            print(f"{position}. {entry['sql']}")


class _AssertNumDBConnections(CaptureQueriesContext):
    """Context manager asserting the number of database connections on exit.

    After the wrapped block finishes cleanly, compares
    ``len(connections.all())`` with the expected ``num`` and raises
    ``AssertionError`` on a mismatch. When the block itself raised, the
    check is skipped so the original error is not masked by a secondary
    ``AssertionError``.

    NOTE(review): ``connections.all()`` appears to enumerate a connection
    wrapper for every configured alias, not only the aliases the block
    actually queried — confirm this is the intended measure.
    """

    def __init__(self, connection, num: int) -> None:
        super().__init__(connection)
        self.num = num  # expected connection count checked in __exit__

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        # Don't replace an exception raised inside the block with our own
        # assertion failure; behave like _DebugSQLQueries and bail out.
        if exc_type is not None:
            return
        current_connections = len(connections.all())
        # Raise explicitly instead of using ``assert`` so the check is not
        # stripped when Python runs with -O.
        if current_connections != self.num:
            raise AssertionError(
                f"Expected {self.num} database connections, "
                f"but got {current_connections}."
            )



class SQLTestMixins:
    """Test-case mixin providing SQL debugging context managers."""

    def assertNumberOfQueries(
        self, using=DEFAULT_DB_ALIAS
    ) -> _DebugSQLQueries:
        """Return a context manager that prints the SQL run on ``using``.

        Despite its name, this performs no assertion. The returned
        ``_DebugSQLQueries`` prints the number of captured queries and each
        statement's SQL when the block exits without an exception. The name
        is kept for backward compatibility; ``debug_sql_queries`` is an
        alias whose name describes the behavior.

        :param using: database alias whose queries are captured.
        """
        return _DebugSQLQueries(connections[using])

    def debug_sql_queries(
        self, using=DEFAULT_DB_ALIAS
    ) -> _DebugSQLQueries:
        """Alias of ``assertNumberOfQueries`` named for what it does.

        :param using: database alias whose queries are captured and printed.
        """
        return self.assertNumberOfQueries(using)

    def assertNumDBConnections(
        self, num, using=DEFAULT_DB_ALIAS
    ) -> _AssertNumDBConnections:
        """Return a context manager that checks the connection count on exit.

        The returned ``_AssertNumDBConnections`` compares
        ``len(connections.all())`` with ``num`` when its block exits.

        :param num: expected number of database connections.
        :param using: database alias the capture context is bound to.
        """
        return _AssertNumDBConnections(connections[using], num)

Usage

python
from django.test import TestCase

from myapp.models import Location as LocationModel
from tests.factories import LocationFactory
from tests.mixins.graphql_test import GraphQLTestMixins
from tests.mixins.sql_test import SQLTestMixins


class LocationSQLQueryTestCase(TestCase, GraphQLTestMixins, SQLTestMixins):
    """Test the database behaviour of the ``DataLoaderLocations`` query."""

    @classmethod
    def setUpTestData(cls) -> None:
        """Insert five Location rows shared by every test in this class."""
        # build_batch() yields unsaved instances, so bulk_create() does the
        # single INSERT. The previous create_batch() presumably already
        # saved them (assuming LocationFactory is a factory_boy
        # DjangoModelFactory — confirm), and bulk-inserting rows whose
        # primary keys are already set would likely hit a duplicate-key
        # error. bulk_create() returns the created objects as a list.
        cls.location = LocationModel.objects.bulk_create(
            LocationFactory.build_batch(5)
        )

    def test_location_basic_sql_query(self) -> None:
        """Run the DataLoaderLocations query and expect HTTP 200."""
        # To print the SQL this query issues, wrap it in
        # ``with self.assertNumberOfQueries():``. Once the count is known,
        # pin it with Django's ``with self.assertNumQueries(n):``.
        with self.assertNumDBConnections(1):
            with self.with_schema("dataloader_location") as schema:
                response = schema.query("DataLoaderLocations")
                self.assertEqual(response.status_code, 200, "Query failed")

References