Skip to content

Custom GraphQL Test Case Mixins

Group GraphQL queries in `.graphql` files and run them from tests by operation name.

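Create a `graphql_config.yml` config file. The mixin looks for it in its own directory first, then in each parent directory up to the filesystem root.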

yaml
# URL path that every GraphQL request is POSTed to by the test client.
api_route: /localize/api/
# Glob pattern(s) locating .graphql documents, resolved relative to the
# directory that contains this config file.
documents:
  - "./tests/queries/*.graphql"

Mixins

tests/mixins/graphql_test.py

python
from functools import partial
from pathlib import Path
from typing import Any
from typing import Generator
from typing import List
from collections import namedtuple
from contextlib import contextmanager
from django.http.response import HttpResponseBase
from django.test import Client as PostClient
import yaml


# What `with_schema` hands to a test: a `query(operation, ...)` callable
# bound to one GraphQL document, plus that document's raw text.
SchemaQuery = namedtuple("SchemaQuery", "query document")
# Result of the config lookup: the directory holding the config file and
# the parsed settings read from it.
SchemaConfig = namedtuple("SchemaConfig", "root settings")
# File name searched for from this module's directory upward.
CONFIG_FILE = "graphql_config.yml"


class GraphqlSettings(dict):
    """GraphQL settings for the test configuration.

    A ``dict`` whose keys can also be read as attributes. A dotted name passed
    to ``getattr`` walks nested mappings. Missing keys, and paths that run
    through a non-mapping value, resolve to ``None`` instead of raising.

    Dunder names (``__html__``, ``__deepcopy__``, ...) are never looked up in
    the mapping. They raise ``AttributeError`` as usual, so ``hasattr`` and
    protocol probes are not fooled by a ``None`` answer.

    You can populate it with other fields as needed.
    ex:
      settings = GraphqlSettings(documents=["*.graphql"], api_route="http://example.com/graphql")
      print(settings.documents)
      print(settings.api_route)
      print(getattr(settings, "nested.key"))  # settings["nested"]["key"] or None
    """

    documents: List[str]
    api_route: str

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails. Refusing dunder
        # names keeps hasattr()/protocol checks truthful.
        if item.startswith("__") and item.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        return self._self_recurse_arre(self, item.split("."))

    def _self_recurse_arre(self, dictionary, keywords):
        """Follow ``keywords`` through nested dicts.

        :param dictionary: Current level being searched.
        :param keywords: Remaining path segments.
        :return: The value at the path, or ``None`` when any segment is absent
            or an intermediate value is not a dict.
        """
        # Non-dict intermediates (e.g. a string) mean the path cannot exist.
        if not dictionary or not isinstance(dictionary, dict):
            return None

        if not keywords:
            return dictionary

        if len(keywords) == 1:
            return dictionary.get(keywords[0])
        return self._self_recurse_arre(
            dictionary.get(keywords[0]), keywords[1:]
        )


class GraphQLTestMixins:
    """GraphQL test mixins for executing GraphQL queries in tests.

    This class locates ``.graphql`` documents through the globs configured in
    the nearest ``graphql_config.yml``, which is searched for from this
    module's directory upward. It then executes named operations from those
    documents against the configured ``api_route``.
    """

    def _iter_document_paths(self, root, documents: str | List[str]):
        """Yield each file matched by the ``documents`` glob(s) once, sorted.

        :param root: Directory the glob patterns are relative to.
        :param documents: A glob pattern or a list of glob patterns.
        :raises ValueError: if no pattern is configured.
        """
        if not documents:
            raise ValueError(f"No 'documents' glob configured in {CONFIG_FILE}.")

        if not isinstance(documents, list):
            documents = [documents]

        seen = set()
        for pattern in documents:
            # Path.glob order depends on the filesystem; sort for determinism.
            for path in sorted(root.glob(pattern)):
                # Skip directories and files already matched by an earlier pattern.
                if path.is_file() and path not in seen:
                    seen.add(path)
                    yield path

    def _iter_parse_document_glob(self, root, documents: str | List[str]):
        """Iterate over GraphQL documents, yielding ``(path, text)`` pairs."""
        for path in self._iter_document_paths(root, documents):
            yield path, path.read_text(encoding="utf-8")

    def _iter_parse_config(self, root=Path(__file__).parent, previous=None):
        """Find and parse the nearest config file, walking up from ``root``.

        :return: ``SchemaConfig`` with the config file's directory and settings.
        :raises FileNotFoundError: if neither ``root`` nor any ancestor has one.
        :raises ValueError: if the file is empty or not a YAML mapping.
        """

        # The parent of a filesystem root is itself, so the walk is exhausted.
        if previous == root:
            raise FileNotFoundError(
                f"Please create a {CONFIG_FILE} file in the root project directory"
            )

        config_file = root / CONFIG_FILE

        if config_file.is_file():
            with config_file.open(encoding="utf-8") as stream:
                config_content = yaml.safe_load(stream)
            if not config_content:
                raise ValueError(
                    f"Config file '{config_file}' is empty or invalid."
                )
            # A YAML list or scalar cannot be expanded into keyword settings.
            if not isinstance(config_content, dict):
                raise ValueError(
                    f"Config file '{config_file}' must contain a YAML mapping."
                )
            return SchemaConfig(
                root=root,
                settings=GraphqlSettings(**config_content),
            )
        return self._iter_parse_config(root=root.parent, previous=root)

    @contextmanager
    def with_schema(self, file_name) -> Generator[SchemaQuery, Any, None]:
        """Context manager yielding the operations of one GraphQL document.

        :param file_name: Document to load, matched against the configured
            ``documents`` globs by file stem (``"location"``) or by full file
            name (``"location.graphql"``).
        :return: A ``SchemaQuery`` whose ``query(operation, ...)`` runs the
            named operation from that document against ``api_route``.
        :raises FileNotFoundError: if no configured document matches.
        :raises ValueError: if several configured documents match.
        """

        config: SchemaConfig = self._iter_parse_config()
        documents = config.settings.documents
        matches = [
            path
            for path in self._iter_document_paths(config.root, documents)
            if file_name in (path.stem, path.name)
        ]
        if not matches:
            raise FileNotFoundError(
                f"No GraphQL document named {file_name!r} matches "
                f"{documents!r} under {config.root}"
            )
        if len(matches) > 1:
            raise ValueError(
                f"GraphQL document name {file_name!r} is ambiguous: "
                + ", ".join(str(path) for path in matches)
            )

        doc = matches[0].read_text(encoding="utf-8")
        # @contextmanager requires exactly one yield, so exactly one document
        # is selected above.
        yield SchemaQuery(
            query=partial(
                self.query,
                graphql_url=config.settings.api_route,
                document=doc,
            ),
            document=doc,
        )

    def query(
        self,
        operation: str,
        input: dict | None = None,
        graphql_url: str | None = None,
        document: str | None = None,
        variables: dict | None = None,
    ) -> HttpResponseBase:
        """Execute a GraphQL query with the given parameters.

        :param operation: The name of the GraphQL operation to execute.
        :type operation: str

        :param input: Input data for the query, sent as the ``input`` variable.
        :type input: dict | None

        :param graphql_url: The URL to send the GraphQL query to.
        :type graphql_url: str | None

        :param document: The GraphQL query string.
        :type document: str | None

        :param variables: Variables for the GraphQL query.
        :type variables: dict | None

        :return: The response from the DjangoClient.
        :rtype: HttpResponseBase
        """

        return self.graphql_query(
            graphql_url=graphql_url,
            query=document,
            operation_name=operation,
            input_data=input,
            variables=variables,
        )

    def graphql_query(
        self,
        graphql_url,
        query,
        operation_name=None,
        input_data=None,
        variables=None,
        headers=None,
    ) -> HttpResponseBase:
        """Send a GraphQL query to the specified URL.

        :param graphql_url: The URL to send the GraphQL query to.
        :type graphql_url: str

        :param query: The GraphQL query string.
        :type query: str

        :param operation_name: The name of the operation to execute.
        :type operation_name: str | None

        :param input_data: Input data, merged into the variables as ``input``.
            An explicit empty dict is still sent.
        :type input_data: dict | None

        :param variables: Variables for the GraphQL query. This dict is never
            modified.
        :type variables: dict | None

        :param headers: Additional headers to include in the request.
        :type headers: dict | None

        :return: The response from the DjangoClient.
        :rtype: HttpResponseBase
        """

        body: dict[str, Any] = {"query": query}
        headers = headers or {}

        if operation_name:
            body["operationName"] = operation_name

        # Work on a copy so merging "input" never mutates the caller's dict.
        merged_variables = dict(variables) if variables else {}
        if input_data is not None:
            merged_variables["input"] = input_data
        if merged_variables:
            body["variables"] = merged_variables

        client = PostClient()
        return client.post(
            path=graphql_url,
            data=body,
            content_type="application/json",
            headers=headers,
        )


if __name__ == "__main__":
    # NOTE(review): PostClient presumably needs configured Django settings
    # (DJANGO_SETTINGS_MODULE + django.setup()) before this demo runs — confirm.

    class ExampleUsage(GraphQLTestMixins):
        """Minimal demo of the mixin outside a unittest TestCase."""

        def run_test(self) -> None:
            with self.with_schema("location") as schema:
                print(f"Running test with schema: {schema.document}")
                # Single quotes inside the f-string keep this file valid before Python 3.12.
                print(f"Query content: {schema.query('GetLocations').json()}")

                # example input for a mutation
                location_input = {
                    "root": "/test_root",
                    "enabled": True,
                }

                response = schema.query("CreateLocation", input=location_input)
                # Not a TestCase, so there is no assertEqual; fail explicitly instead.
                if response.status_code != 200:
                    raise RuntimeError(
                        f"Query failed with status {response.status_code}"
                    )
                data = response.json()["data"]["createLocation"]
                print(f"Created location: {data}")

    example_test = ExampleUsage()
    example_test.run_test()

Example `.graphql` document

./tests/queries/location.graphql

graphql
# Fetch all locations. Tests read the results from
# data.locations.edges[].node in the JSON response.
query GetLocations {
  locations {
    edges {
      node {
        id
        root
        enabled
      }
    }
  }
}

# Create one location and return its stored fields.
# $input is filled from the input= argument of schema.query(...).
mutation CreateLocation($input: CreateLocationInput!) {
  createLocation(input: $input) {
    id
    root
    enabled
  }
}

Usage in unit tests

python
from django.test import TestCase
from localize.models import Location as LocationModel
from tests.mixins.graphql_test import GraphQLTestMixins
from tests.test_location_gql import LocationFactory


class LocationQueryTest(TestCase, GraphQLTestMixins):
    """Exercise the location GraphQL operations through GraphQLTestMixins."""

    @classmethod
    def setUpTestData(cls) -> None:
        # build_batch returns unsaved instances, so bulk_create performs the only
        # INSERT. With a DjangoModelFactory, create_batch would already have saved
        # them and bulk_create would insert them a second time.
        locations = LocationFactory.build_batch(5)
        LocationModel.objects.bulk_create(locations)

    def test_query_locations(self) -> None:
        """Querying locations returns every location created in setUpTestData."""

        with self.with_schema("location") as schema:
            response = schema.query("GetLocations")
            self.assertEqual(response.status_code, 200, "Query failed")
            payload = response.json()
            # A GraphQL server may report errors with HTTP 200, so check the body too.
            self.assertNotIn("errors", payload)
            self.assertEqual(
                len(payload["data"]["locations"]["edges"]),
                5,
                "Expected 5 locations in the response",
            )

    def test_create_locations(self) -> None:
        """Creating a location echoes back the submitted fields."""

        location_input = {
            "root": "/test_root",
            "enabled": True,
        }
        with self.with_schema("location") as schema:
            response = schema.query("CreateLocation", input=location_input)
            self.assertEqual(response.status_code, 200, "Query failed")
            payload = response.json()
            self.assertNotIn("errors", payload)
            data = payload["data"]["createLocation"]
            self.assertEqual(data["root"], location_input["root"])
            self.assertEqual(data["enabled"], location_input["enabled"])

References