Astroport.ONE/venv/lib/python3.11/site-packages/graphql/validation/rules/unique_directives_per_location.py
2024-03-01 16:15:45 +01:00

86 lines
3.2 KiB
Python

from collections import defaultdict
from typing import Any, Dict, List, Union, cast
from ...error import GraphQLError
from ...language import (
DirectiveDefinitionNode,
DirectiveNode,
Node,
SchemaDefinitionNode,
SchemaExtensionNode,
TypeDefinitionNode,
TypeExtensionNode,
is_type_definition_node,
is_type_extension_node,
)
from ...type import specified_directives
from . import ASTValidationRule, SDLValidationContext, ValidationContext
# Public API of this module: the only name exported by a star-import.
__all__ = ["UniqueDirectivesPerLocationRule"]
class UniqueDirectivesPerLocationRule(ASTValidationRule):
    """Unique directive names per location

    Reports an error whenever a directive that is not repeatable is applied
    more than once at the same location of a GraphQL document.

    Directives on schema definitions/extensions are tracked together, and
    directives on a type definition are tracked together with those on the
    extensions of the type with the same name. Every other node is checked
    on its own.

    See https://spec.graphql.org/draft/#sec-Directives-Are-Unique-Per-Location
    """

    context: Union[ValidationContext, SDLValidationContext]

    def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
        """Collect which directive names may appear only once per location.

        The directives of the context's schema are used when a schema is
        available; otherwise the specified directives are used. Directive
        definitions found in the document being validated are applied
        afterwards, so they take precedence for equal names.
        """
        super().__init__(context)

        schema = context.schema
        known_directives = (
            schema.directives if schema else cast(List, specified_directives)
        )

        # directive name -> True when the directive is NOT repeatable
        uniqueness: Dict[str, bool] = {
            known.name: not known.is_repeatable for known in known_directives
        }
        # Definitions in the document override the entries gathered above.
        uniqueness.update(
            (definition.name.value, not definition.repeatable)
            for definition in context.document.definitions
            if isinstance(definition, DirectiveDefinitionNode)
        )
        self.unique_directive_map = uniqueness

        # Directives already seen on any schema definition or extension node.
        self.schema_directives: Dict[str, DirectiveNode] = {}
        # Directives already seen per type name (definition and extensions).
        self.type_directives_map: Dict[str, Dict[str, DirectiveNode]] = defaultdict(
            dict
        )

    def enter(self, node: Node, *_args: Any) -> None:
        """Check the directives attached to ``node`` for forbidden repeats.

        Instead of enumerating every node kind that can carry directives,
        this hook inspects any entered node for a non-empty ``directives``
        attribute and returns immediately when there is none.
        """
        attached = getattr(node, "directives", None)
        if not attached:
            return
        node_directives = cast(List[DirectiveNode], attached)

        # Pick the bookkeeping dict matching the kind of location.
        already_seen: Dict[str, DirectiveNode]
        if isinstance(node, (SchemaDefinitionNode, SchemaExtensionNode)):
            already_seen = self.schema_directives
        elif is_type_definition_node(node) or is_type_extension_node(node):
            typed_node = cast(Union[TypeDefinitionNode, TypeExtensionNode], node)
            already_seen = self.type_directives_map[typed_node.name.value]
        else:
            # Any other location is only compared against itself.
            already_seen = {}

        for current in node_directives:
            name = current.name.value

            # Repeatable or unknown directives are not restricted here.
            if not self.unique_directive_map.get(name):
                continue

            if name not in already_seen:
                already_seen[name] = current
                continue

            # Report both the first occurrence and the offending repeat.
            self.report_error(
                GraphQLError(
                    f"The directive '@{name}'"
                    " can only be used once at this location.",
                    [already_seen[name], current],
                )
            )