Skip to content

Django dataloader

As a response to the N+1 query problem, the strawberry library provides a DataLoader class that can be used to batch and cache database queries.

The main DataLoader specification from graphql can be found here.

Example

Mixin Setup

Because we want to display the SQL queries executed under the hood, we need to create a mixin that will be used in our GraphQL test cases.

debug_sql_queries_mixin.py

python
from typing import Optional
from django.test.utils import CaptureQueriesContext
from django.db import connections
from django.db import DEFAULT_DB_ALIAS


class _DebugSQLQueries(CaptureQueriesContext):
    """Context manager that prints the SQL queries captured inside its block.

    Builds on ``CaptureQueriesContext``: capturing is delegated to the parent
    class, and on a clean exit the number of captured queries is printed,
    followed by one numbered line per query.
    """

    def __init__(self, connection) -> None:
        super().__init__(connection)

    def __exit__(self, exc_type, exc_value, traceback):
        """Finish capturing, then print the report unless an exception escaped.

        Always returns ``None``, so an exception raised inside the block is
        never suppressed.
        """
        super().__exit__(exc_type, exc_value, traceback)
        if exc_type is None:
            queries = self.captured_queries
            print(f"{len(queries)} queries executed.")
            for position, entry in enumerate(queries, start=1):
                print(f"{position}. {entry['sql']}")


class SQLTestMixins:
    """Mixin providing a helper that prints the SQL queries a block runs."""

    def debug_sql_queries(
        self,
        using=DEFAULT_DB_ALIAS,
    ) -> _DebugSQLQueries:
        """Return a context manager that prints the SQL run on a connection.

        Args:
            using: Alias of the database connection to observe; defaults to
                ``DEFAULT_DB_ALIAS``.

        Returns:
            A ``_DebugSQLQueries`` instance bound to ``connections[using]``.
        """
        return _DebugSQLQueries(connections[using])

DataLoader Setup

We will use strawberry for our GraphQL implementation, and we will create a DataLoader for our Location model.

python
import strawberry
from strawberry import relay
from strawberry_django import connection
from strawberry_django.relay import DjangoListConnection
from myapp.object_types import LocationType
from myapp.dataloaders.location import location_loader


@strawberry.type
class LocationQuery:
    """Query type for locations."""

    # location: LocationType = relay.node() # basic one, we keep it to compare
    locations: DjangoListConnection[LocationType] = connection()

    @strawberry.field
    async def location(self, id: strawberry.ID) -> LocationType:
        """Fetch a single location by ID."""
        # ``relay.from_base64`` splits the global ID into its parts; the
        # element at index 1 is the key handed to the loader.
        primary_key = relay.from_base64(id)[1]
        # Going through the shared loader instead of querying directly lets
        # several ``location`` fields in one operation be resolved together.
        return await location_loader.load(primary_key)

Unittest query

We will use this query to test our DataLoader implementation.

graphql
# Field selection reused by every aliased `location` lookup in the query.
fragment LocationFields on LocationType {
  id
  root
  acceptPriority
  acceptRegexp
  enabled
}

# Requests two locations in a single operation through aliased `location`
# fields, so both lookups are resolved during the same execution.
query DataLoaderLocations {
  # Global ID is base64 for "LocationType:1".
  first: location(id: "TG9jYXRpb25UeXBlOjE=") {
    ...LocationFields
  }
  # Global ID is base64 for "LocationType:2".
  second: location(id: "TG9jYXRpb25UeXBlOjI=") {
    ...LocationFields
  }
}

Without a DataLoader, this query results in 2 queries to the database, one for each location.

python
from django.test import TestCase

from myapp.models import Location as LocationModel
from tests.factories import LocationFactory
from tests.mixins.graphql_test import GraphQLTestMixins
from tests.mixins.sql_test import SQLTestMixins


class LocationSQLQueryTestCase(TestCase, GraphQLTestMixins, SQLTestMixins):
    """Test SQL queries for location."""

    @classmethod
    def setUpTestData(cls) -> None:
        """Prepare class-level data: a batch of 5 locations from the factory.

        The batch is kept on ``cls.location`` and passed to ``bulk_create``.
        """
        # NOTE(review): if LocationFactory is a DjangoModelFactory,
        # create_batch() already saves the rows and bulk_create() would try
        # to insert them a second time -- confirm; build_batch() may be the
        # intended call.
        batch = LocationFactory.create_batch(5)
        cls.location = batch
        LocationModel.objects.bulk_create(batch)

    def test_location_basic_sql_query(self) -> None:
        """Run the ``DataLoaderLocations`` query and print the SQL it issues."""
        # with self.assertNumQueries(1): after, we will use this as a unittest
        sql_capture = self.debug_sql_queries()
        with sql_capture, self.with_schema("dataloader_location") as schema:
            result = schema.query("DataLoaderLocations")
            self.assertEqual(result.status_code, 200, "Query failed")

You should see:

1 queries executed.

  1. SELECT "myapp_location"."id", "myapp_location"."root", "myapp_location"."accept_priority", "myapp_location"."accept_regexp", "myapp_location"."enabled" FROM "myapp_location" WHERE "myapp_location"."id" IN (1, 2)

Instead of this:

2 queries executed.

  1. SELECT "myapp_location"."id", "myapp_location"."root", "myapp_location"."accept_priority", "myapp_location"."accept_regexp", "myapp_location"."enabled" FROM "myapp_location" WHERE "myapp_location"."id" = 1 LIMIT 21
  2. SELECT "myapp_location"."id", "myapp_location"."root", "myapp_location"."accept_priority", "myapp_location"."accept_regexp", "myapp_location"."enabled" FROM "myapp_location" WHERE "myapp_location"."id" = 2 LIMIT 21

Conclusion

So the DataLoader batches the individual lookups and fetches all the requested locations in a single query. For nested queries, it also avoids refetching through caching: once a key has been loaded, subsequent loads of the same key reuse the cached result instead of querying the database again.

References