Astroport.ONE/venv/lib/python3.11/site-packages/graphql/validation/rules/provided_required_arguments.py
2024-03-01 16:15:45 +01:00

120 lines
4.4 KiB
Python

from typing import cast, Any, Dict, List, Union
from ...error import GraphQLError
from ...language import (
DirectiveDefinitionNode,
DirectiveNode,
FieldNode,
InputValueDefinitionNode,
NonNullTypeNode,
TypeNode,
VisitorAction,
SKIP,
print_ast,
)
from ...type import GraphQLArgument, is_required_argument, is_type, specified_directives
from . import ASTValidationRule, SDLValidationContext, ValidationContext
__all__ = ["ProvidedRequiredArgumentsRule", "ProvidedRequiredArgumentsOnDirectivesRule"]
class ProvidedRequiredArgumentsOnDirectivesRule(ASTValidationRule):
"""Provided required arguments on directives
A directive is only valid if all required (non-null without a default value)
arguments have been provided.
For internal use only.
"""
context: Union[ValidationContext, SDLValidationContext]
def __init__(self, context: Union[ValidationContext, SDLValidationContext]):
super().__init__(context)
required_args_map: Dict[
str, Dict[str, Union[GraphQLArgument, InputValueDefinitionNode]]
] = {}
schema = context.schema
defined_directives = schema.directives if schema else specified_directives
for directive in cast(List, defined_directives):
required_args_map[directive.name] = {
name: arg
for name, arg in directive.args.items()
if is_required_argument(arg)
}
ast_definitions = context.document.definitions
for def_ in ast_definitions:
if isinstance(def_, DirectiveDefinitionNode):
required_args_map[def_.name.value] = {
arg.name.value: arg
for arg in filter(is_required_argument_node, def_.arguments or ())
}
self.required_args_map = required_args_map
def leave_directive(self, directive_node: DirectiveNode, *_args: Any) -> None:
# Validate on leave to allow for deeper errors to appear first.
directive_name = directive_node.name.value
required_args = self.required_args_map.get(directive_name)
if required_args:
arg_nodes = directive_node.arguments or ()
arg_node_set = {arg.name.value for arg in arg_nodes}
for arg_name in required_args:
if arg_name not in arg_node_set:
arg_type = required_args[arg_name].type
arg_type_str = (
str(arg_type)
if is_type(arg_type)
else print_ast(cast(TypeNode, arg_type))
)
self.report_error(
GraphQLError(
f"Directive '@{directive_name}' argument '{arg_name}'"
f" of type '{arg_type_str}' is required,"
" but it was not provided.",
directive_node,
)
)
class ProvidedRequiredArgumentsRule(ProvidedRequiredArgumentsOnDirectivesRule):
"""Provided required arguments
A field or directive is only valid if all required (non-null without a default
value) field arguments have been provided.
"""
context: ValidationContext
def __init__(self, context: ValidationContext):
super().__init__(context)
def leave_field(self, field_node: FieldNode, *_args: Any) -> VisitorAction:
# Validate on leave to allow for deeper errors to appear first.
field_def = self.context.get_field_def()
if not field_def:
return SKIP
arg_nodes = field_node.arguments or ()
arg_node_map = {arg.name.value: arg for arg in arg_nodes}
for arg_name, arg_def in field_def.args.items():
arg_node = arg_node_map.get(arg_name)
if not arg_node and is_required_argument(arg_def):
self.report_error(
GraphQLError(
f"Field '{field_node.name.value}' argument '{arg_name}'"
f" of type '{arg_def.type}' is required,"
" but it was not provided.",
field_node,
)
)
return None
def is_required_argument_node(arg: InputValueDefinitionNode) -> bool:
return isinstance(arg.type, NonNullTypeNode) and arg.default_value is None