Astroport.ONE/venv/lib/python3.11/site-packages/graphql/type/validate.py
2024-03-01 16:15:45 +01:00

610 lines
23 KiB
Python

from operator import attrgetter, itemgetter
from typing import (
Any,
Collection,
Dict,
List,
Optional,
Set,
Tuple,
Union,
cast,
)
from ..error import GraphQLError
from ..pyutils import inspect
from ..language import (
DirectiveNode,
InputValueDefinitionNode,
NamedTypeNode,
Node,
OperationType,
SchemaDefinitionNode,
SchemaExtensionNode,
)
from .definition import (
GraphQLEnumType,
GraphQLInputField,
GraphQLInputObjectType,
GraphQLInterfaceType,
GraphQLObjectType,
GraphQLUnionType,
is_enum_type,
is_input_object_type,
is_input_type,
is_interface_type,
is_named_type,
is_non_null_type,
is_object_type,
is_output_type,
is_union_type,
is_required_argument,
is_required_input_field,
)
from ..utilities.type_comparators import is_equal_type, is_type_sub_type_of
from .directives import is_directive, GraphQLDeprecatedDirective
from .introspection import is_introspection_type
from .schema import GraphQLSchema, assert_schema
# Public names of this module (also what ``from ... import *`` exports).
__all__ = ["validate_schema", "assert_valid_schema"]
def validate_schema(schema: GraphQLSchema) -> List[GraphQLError]:
    """Validate a GraphQL schema.

    Implements the "Type Validation" sub-sections of the specification's "Type System"
    section.

    Validation runs synchronously. The returned list holds every error that was found
    and is empty when the schema is valid. The list is cached on the schema
    (``schema._validation_errors``), so later calls return the same list object
    without validating again.
    """
    # First make sure the provided value is in fact a GraphQLSchema.
    assert_schema(schema)

    # Reuse the outcome of an earlier validation run, if there was one.
    # noinspection PyProtectedMember
    cached = schema._validation_errors
    if cached is not None:
        return cached

    context = SchemaValidationContext(schema)
    for run_check in (
        context.validate_root_types,
        context.validate_directives,
        context.validate_types,
    ):
        run_check()

    # Store the results before returning, so this schema is never validated twice.
    schema._validation_errors = context.errors
    return context.errors
def assert_valid_schema(schema: GraphQLSchema) -> None:
    """Utility function which asserts a schema is valid.

    Throws a TypeError if the schema is invalid. The exception message contains the
    message of every validation error, separated by blank lines.
    """
    messages = [error.message for error in validate_schema(schema)]
    if messages:
        raise TypeError("\n\n".join(messages))
class SchemaValidationContext:
    """Utility class providing a context for schema validation.

    Holds the schema being validated and collects, in ``errors``, every
    ``GraphQLError`` reported by the ``validate_*`` methods.
    """

    errors: List[GraphQLError]
    schema: GraphQLSchema

    def __init__(self, schema: GraphQLSchema):
        self.errors = []
        self.schema = schema

    def report_error(
        self,
        message: str,
        nodes: Union[Optional[Node], Collection[Optional[Node]]] = None,
    ) -> None:
        """Append a ``GraphQLError`` carrying ``message`` to ``errors``.

        ``nodes`` may be ``None``, a single AST node, or an iterable of nodes that
        may contain ``None`` entries. Falsy entries are dropped from an iterable.
        """
        if nodes and not isinstance(nodes, Node):
            # Callers pass optional AST nodes; keep only those that exist.
            nodes = [node for node in nodes if node]
        nodes = cast(Optional[Collection[Node]], nodes)
        self.errors.append(GraphQLError(message, nodes))

    def validate_root_types(self) -> None:
        """Check that a query root type exists and that all root types are objects."""
        schema = self.schema

        query_type = schema.query_type
        if not query_type:
            self.report_error("Query root type must be provided.", schema.ast_node)
        elif not is_object_type(query_type):
            self.report_error(
                f"Query root type must be Object type, it cannot be {query_type}.",
                get_operation_type_node(schema, OperationType.QUERY)
                or query_type.ast_node,
            )

        mutation_type = schema.mutation_type
        if mutation_type and not is_object_type(mutation_type):
            self.report_error(
                "Mutation root type must be Object type if provided,"
                f" it cannot be {mutation_type}.",
                get_operation_type_node(schema, OperationType.MUTATION)
                or mutation_type.ast_node,
            )

        subscription_type = schema.subscription_type
        if subscription_type and not is_object_type(subscription_type):
            self.report_error(
                "Subscription root type must be Object type if provided,"
                f" it cannot be {subscription_type}.",
                get_operation_type_node(schema, OperationType.SUBSCRIPTION)
                or subscription_type.ast_node,
            )

    def validate_directives(self) -> None:
        """Validate each directive of the schema, including its arguments."""
        directives = self.schema.directives
        for directive in directives:
            # Ensure all directives are in fact GraphQL directives.
            if not is_directive(directive):
                self.report_error(
                    f"Expected directive but got: {inspect(directive)}.",
                    getattr(directive, "ast_node", None),
                )
                continue

            # Ensure they are named correctly.
            self.validate_name(directive)

            # Ensure the arguments are valid.
            for arg_name, arg in directive.args.items():
                # Ensure they are named correctly.
                self.validate_name(arg, arg_name)

                # Ensure the type is an input type.
                if not is_input_type(arg.type):
                    self.report_error(
                        f"The type of @{directive.name}({arg_name}:)"
                        f" must be Input Type but got: {inspect(arg.type)}.",
                        arg.ast_node,
                    )

                if is_required_argument(arg) and arg.deprecation_reason is not None:
                    self.report_error(
                        f"Required argument @{directive.name}({arg_name}:)"
                        " cannot be deprecated.",
                        [
                            get_deprecated_directive_node(arg.ast_node),
                            arg.ast_node and arg.ast_node.type,
                        ],
                    )

    def validate_name(self, node: Any, name: Optional[str] = None) -> None:
        """Report an error if the name of ``node`` starts with ``__``.

        ``name`` is used instead of ``node.name`` when it is given. Objects that
        lack a ``name`` or ``ast_node`` attribute are skipped silently.
        Introspection types are excluded by the caller (``validate_types``).
        """
        try:
            if not name:
                name = node.name
            name = cast(str, name)
            ast_node = node.ast_node
        except AttributeError:  # pragma: no cover
            pass
        else:
            if name.startswith("__"):
                self.report_error(
                    f"Name {name!r} must not begin with '__',"
                    " which is reserved by GraphQL introspection.",
                    ast_node,
                )

    def validate_types(self) -> None:
        """Validate every type in the schema's type map according to its kind."""
        validate_input_object_circular_refs = InputObjectCircularRefsValidator(self)
        for type_ in self.schema.type_map.values():
            # Ensure all provided types are in fact GraphQL type.
            if not is_named_type(type_):
                # The value is not a named type, so it may not have an
                # ``ast_node`` attribute at all. Look it up defensively, as
                # validate_directives does for non-directives.
                self.report_error(
                    f"Expected GraphQL named type but got: {inspect(type_)}.",
                    getattr(type_, "ast_node", None),
                )
                continue

            # Ensure it is named correctly (excluding introspection types).
            if not is_introspection_type(type_):
                self.validate_name(type_)

            if is_object_type(type_):
                type_ = cast(GraphQLObjectType, type_)
                # Ensure fields are valid
                self.validate_fields(type_)
                # Ensure objects implement the interfaces they claim to.
                self.validate_interfaces(type_)
            elif is_interface_type(type_):
                type_ = cast(GraphQLInterfaceType, type_)
                # Ensure fields are valid.
                self.validate_fields(type_)
                # Ensure interfaces implement the interfaces they claim to.
                self.validate_interfaces(type_)
            elif is_union_type(type_):
                type_ = cast(GraphQLUnionType, type_)
                # Ensure Unions include valid member types.
                self.validate_union_members(type_)
            elif is_enum_type(type_):
                type_ = cast(GraphQLEnumType, type_)
                # Ensure Enums have valid values.
                self.validate_enum_values(type_)
            elif is_input_object_type(type_):
                type_ = cast(GraphQLInputObjectType, type_)
                # Ensure Input Object fields are valid.
                self.validate_input_fields(type_)
                # Ensure Input Objects do not contain non-nullable circular references
                validate_input_object_circular_refs(type_)

    def validate_fields(
        self, type_: Union[GraphQLObjectType, GraphQLInterfaceType]
    ) -> None:
        """Validate the fields and field arguments of an object or interface type."""
        fields = type_.fields

        # Objects and Interfaces both must define one or more fields.
        if not fields:
            self.report_error(
                f"Type {type_.name} must define one or more fields.",
                [type_.ast_node, *type_.extension_ast_nodes],
            )

        for field_name, field in fields.items():
            # Ensure they are named correctly.
            self.validate_name(field, field_name)

            # Ensure the type is an output type
            if not is_output_type(field.type):
                self.report_error(
                    f"The type of {type_.name}.{field_name}"
                    f" must be Output Type but got: {inspect(field.type)}.",
                    field.ast_node and field.ast_node.type,
                )

            # Ensure the arguments are valid.
            for arg_name, arg in field.args.items():
                # Ensure they are named correctly.
                self.validate_name(arg, arg_name)

                # Ensure the type is an input type.
                if not is_input_type(arg.type):
                    self.report_error(
                        f"The type of {type_.name}.{field_name}({arg_name}:)"
                        f" must be Input Type but got: {inspect(arg.type)}.",
                        arg.ast_node and arg.ast_node.type,
                    )

                if is_required_argument(arg) and arg.deprecation_reason is not None:
                    self.report_error(
                        f"Required argument {type_.name}.{field_name}({arg_name}:)"
                        " cannot be deprecated.",
                        [
                            get_deprecated_directive_node(arg.ast_node),
                            arg.ast_node and arg.ast_node.type,
                        ],
                    )

    def validate_interfaces(
        self, type_: Union[GraphQLObjectType, GraphQLInterfaceType]
    ) -> None:
        """Check the interfaces that ``type_`` declares it implements.

        Each entry must be an interface type other than ``type_`` itself and must
        be listed only once. Each valid interface is then checked with
        ``validate_type_implements_ancestors`` and
        ``validate_type_implements_interface``.
        """
        iface_type_names: Set[str] = set()
        for iface in type_.interfaces:
            if not is_interface_type(iface):
                self.report_error(
                    f"Type {type_.name} must only implement Interface"
                    f" types, it cannot implement {inspect(iface)}.",
                    get_all_implements_interface_nodes(type_, iface),
                )
                continue

            if type_ is iface:
                self.report_error(
                    f"Type {type_.name} cannot implement itself"
                    " because it would create a circular reference.",
                    get_all_implements_interface_nodes(type_, iface),
                )
                # Nothing more to check for a self-reference: comparing the type
                # against itself would only repeat work.
                continue

            if iface.name in iface_type_names:
                self.report_error(
                    f"Type {type_.name} can only implement {iface.name} once.",
                    get_all_implements_interface_nodes(type_, iface),
                )
                continue

            iface_type_names.add(iface.name)

            self.validate_type_implements_ancestors(type_, iface)
            self.validate_type_implements_interface(type_, iface)

    def validate_type_implements_interface(
        self,
        type_: Union[GraphQLObjectType, GraphQLInterfaceType],
        iface: GraphQLInterfaceType,
    ) -> None:
        """Check that ``type_`` correctly provides every field declared by ``iface``.

        Field types must be subtypes of the interface field types (covariant).
        Argument types must be equal (invariant). Arguments that the interface
        does not declare must not be required.
        """
        type_fields, iface_fields = type_.fields, iface.fields

        # Assert each interface field is implemented.
        for field_name, iface_field in iface_fields.items():
            type_field = type_fields.get(field_name)

            # Assert interface field exists on object.
            if not type_field:
                self.report_error(
                    f"Interface field {iface.name}.{field_name}"
                    f" expected but {type_.name} does not provide it.",
                    [
                        iface_field.ast_node,
                        type_.ast_node,
                        *type_.extension_ast_nodes,
                    ],
                )
                continue

            # Assert interface field type is satisfied by type field type, by being
            # a valid subtype (covariant).
            if not is_type_sub_type_of(self.schema, type_field.type, iface_field.type):
                self.report_error(
                    f"Interface field {iface.name}.{field_name}"
                    f" expects type {iface_field.type}"
                    f" but {type_.name}.{field_name}"
                    f" is type {type_field.type}.",
                    [
                        iface_field.ast_node and iface_field.ast_node.type,
                        type_field.ast_node and type_field.ast_node.type,
                    ],
                )

            # Assert each interface field arg is implemented.
            for arg_name, iface_arg in iface_field.args.items():
                type_arg = type_field.args.get(arg_name)

                # Assert interface field arg exists on object field.
                if not type_arg:
                    self.report_error(
                        "Interface field argument"
                        f" {iface.name}.{field_name}({arg_name}:)"
                        f" expected but {type_.name}.{field_name}"
                        " does not provide it.",
                        [iface_arg.ast_node, type_field.ast_node],
                    )
                    continue

                # Assert interface field arg type matches object field arg type
                # (invariant).
                if not is_equal_type(iface_arg.type, type_arg.type):
                    self.report_error(
                        "Interface field argument"
                        f" {iface.name}.{field_name}({arg_name}:)"
                        f" expects type {iface_arg.type}"
                        f" but {type_.name}.{field_name}({arg_name}:)"
                        f" is type {type_arg.type}.",
                        [
                            iface_arg.ast_node and iface_arg.ast_node.type,
                            type_arg.ast_node and type_arg.ast_node.type,
                        ],
                    )

            # Assert additional arguments must not be required.
            for arg_name, type_arg in type_field.args.items():
                iface_arg = iface_field.args.get(arg_name)
                if not iface_arg and is_required_argument(type_arg):
                    self.report_error(
                        f"Object field {type_.name}.{field_name} includes"
                        f" required argument {arg_name} that is missing from"
                        f" the Interface field {iface.name}.{field_name}.",
                        [type_arg.ast_node, iface_field.ast_node],
                    )

    def validate_type_implements_ancestors(
        self,
        type_: Union[GraphQLObjectType, GraphQLInterfaceType],
        iface: GraphQLInterfaceType,
    ) -> None:
        """Check that ``type_`` also lists every interface that ``iface`` implements.

        A missing interface that is ``type_`` itself is reported as a circular
        reference.
        """
        type_interfaces, iface_interfaces = type_.interfaces, iface.interfaces
        for transitive in iface_interfaces:
            if transitive not in type_interfaces:
                self.report_error(
                    f"Type {type_.name} cannot implement {iface.name}"
                    " because it would create a circular reference."
                    if transitive is type_
                    else f"Type {type_.name} must implement {transitive.name}"
                    f" because it is implemented by {iface.name}.",
                    get_all_implements_interface_nodes(iface, transitive)
                    + get_all_implements_interface_nodes(type_, iface),
                )

    def validate_union_members(self, union: GraphQLUnionType) -> None:
        """Check that a union has members, all object types, each listed once."""
        member_types = union.types

        if not member_types:
            self.report_error(
                f"Union type {union.name} must define one or more member types.",
                [union.ast_node, *union.extension_ast_nodes],
            )

        included_type_names: Set[str] = set()
        for member_type in member_types:
            if is_object_type(member_type):
                if member_type.name in included_type_names:
                    self.report_error(
                        f"Union type {union.name} can only include type"
                        f" {member_type.name} once.",
                        get_union_member_type_nodes(union, member_type.name),
                    )
                else:
                    included_type_names.add(member_type.name)
            else:
                self.report_error(
                    f"Union type {union.name} can only include Object types,"
                    f" it cannot include {inspect(member_type)}.",
                    get_union_member_type_nodes(union, str(member_type)),
                )

    def validate_enum_values(self, enum_type: GraphQLEnumType) -> None:
        """Check that an enum defines at least one value and that value names are valid."""
        enum_values = enum_type.values

        if not enum_values:
            self.report_error(
                f"Enum type {enum_type.name} must define one or more values.",
                [enum_type.ast_node, *enum_type.extension_ast_nodes],
            )

        for value_name, enum_value in enum_values.items():
            # Ensure valid name.
            self.validate_name(enum_value, value_name)

    def validate_input_fields(self, input_obj: GraphQLInputObjectType) -> None:
        """Validate the fields of an input object type.

        The type must have at least one field. Field names must be valid, field
        types must be input types, and required fields must not be deprecated.
        """
        fields = input_obj.fields

        if not fields:
            self.report_error(
                f"Input Object type {input_obj.name}"
                " must define one or more fields.",
                [input_obj.ast_node, *input_obj.extension_ast_nodes],
            )

        # Ensure the arguments are valid
        for field_name, field in fields.items():
            # Ensure they are named correctly.
            self.validate_name(field, field_name)

            # Ensure the type is an input type.
            if not is_input_type(field.type):
                self.report_error(
                    f"The type of {input_obj.name}.{field_name}"
                    f" must be Input Type but got: {inspect(field.type)}.",
                    field.ast_node.type if field.ast_node else None,
                )

            if is_required_input_field(field) and field.deprecation_reason is not None:
                self.report_error(
                    f"Required input field {input_obj.name}.{field_name}"
                    " cannot be deprecated.",
                    [
                        get_deprecated_directive_node(field.ast_node),
                        field.ast_node and field.ast_node.type,
                    ],
                )
def get_operation_type_node(
    schema: GraphQLSchema, operation: OperationType
) -> Optional[Node]:
    """Find the AST node naming the root type of ``operation`` in the schema.

    The schema definition node is searched first, then the schema extension nodes.
    The type node of the first matching operation type definition is returned, or
    ``None`` when there is no match.
    """
    candidates: List[Optional[Union[SchemaDefinitionNode, SchemaExtensionNode]]] = [
        schema.ast_node,
        *(schema.extension_ast_nodes or ()),
    ]
    for schema_node in candidates:
        if not schema_node:
            continue
        for operation_type_node in schema_node.operation_types or ():
            if operation_type_node.operation == operation:
                return operation_type_node.type
    return None
class InputObjectCircularRefsValidator:
    """Modified copy of algorithm from validation.rules.NoFragmentCycles

    Finds input object types that refer back to themselves through a chain of
    non-null input object fields, and reports each such cycle to the
    validation context. One instance is called once per input object type.
    The traversal state is kept on the instance, so each type is explored only
    once across all calls.
    """

    def __init__(self, context: SchemaValidationContext):
        self.context = context
        # Names of input object types already explored. Keeps the traversal
        # O(N) and stops the same cycle from being reported more than once.
        self.visited_types: Set[str] = set()
        # (field name, field) pairs along the current DFS path. Used to build
        # the error message and the error locations.
        self.field_path: List[Tuple[str, GraphQLInputField]] = []
        # Position in field_path at which each type on the current path starts.
        self.field_path_index_by_type_name: Dict[str, int] = {}

    def __call__(self, input_obj: GraphQLInputObjectType) -> None:
        """Detect cycles recursively.

        This is a plain depth-first search. It does not stop at the first cycle
        but keeps exploring, so that every cycle is reported.
        """
        type_name = input_obj.name
        if type_name in self.visited_types:
            return
        self.visited_types.add(type_name)
        self.field_path_index_by_type_name[type_name] = len(self.field_path)

        for field_name, field in input_obj.fields.items():
            # Only follow non-null fields whose inner type is an input object.
            if not is_non_null_type(field.type):
                continue
            if not is_input_object_type(field.type.of_type):
                continue

            target = cast(GraphQLInputObjectType, field.type.of_type)
            # Look the target up before extending the path: a hit means the
            # target is already on the current path, so this field closes a cycle.
            cycle_start = self.field_path_index_by_type_name.get(target.name)

            self.field_path.append((field_name, field))
            if cycle_start is not None:
                self._report_cycle(target, self.field_path[cycle_start:])
            else:
                self(target)
            self.field_path.pop()

        del self.field_path_index_by_type_name[type_name]

    def _report_cycle(
        self,
        input_obj: GraphQLInputObjectType,
        cycle_path: List[Tuple[str, GraphQLInputField]],
    ) -> None:
        """Report the cycle back to ``input_obj`` formed by the fields in ``cycle_path``."""
        dotted_path = ".".join(map(itemgetter(0), cycle_path))
        cycle_fields = map(itemgetter(1), cycle_path)
        field_nodes = list(map(attrgetter("ast_node"), cycle_fields))
        self.context.report_error(
            f"Cannot reference Input Object '{input_obj.name}'"
            " within itself through a series of non-null fields:"
            f" '{dotted_path}'.",
            cast(Collection[Node], field_nodes),
        )
def get_all_implements_interface_nodes(
    type_: Union[GraphQLObjectType, GraphQLInterfaceType], iface: GraphQLInterfaceType
) -> List[NamedTypeNode]:
    """Return every ``implements`` reference to ``iface`` in the AST of ``type_``.

    The definition node of ``type_`` (when present) is searched first, followed by
    its extension nodes. A fresh list is returned, so callers may concatenate results.
    """
    definition_node = type_.ast_node
    searched_nodes = list(type_.extension_ast_nodes)
    if definition_node is not None:
        searched_nodes.insert(0, definition_node)

    matches: List[NamedTypeNode] = []
    for searched in searched_nodes:
        for iface_node in searched.interfaces or ():
            if iface_node.name.value == iface.name:
                matches.append(iface_node)
    return matches
def get_union_member_type_nodes(
    union: GraphQLUnionType, type_name: str
) -> List[NamedTypeNode]:
    """Return the member type references named ``type_name`` in the AST of ``union``.

    The union's definition node (when present) is searched first, then its
    extension nodes. Nodes without member types are skipped.
    """
    definition = union.ast_node
    searched = union.extension_ast_nodes
    if definition is not None:
        searched = [definition, *searched]  # type: ignore
    return [
        member_node
        for searched_node in searched
        for member_node in (searched_node.types or ())
        if member_node.name.value == type_name
    ]
def get_deprecated_directive_node(
    definition_node: Optional[Union[InputValueDefinitionNode]],
) -> Optional[DirectiveNode]:
    """Return the first ``@deprecated`` directive applied to ``definition_node``.

    Returns ``None`` when there is no definition node, when it has no directives,
    or when none of its directives is named like ``GraphQLDeprecatedDirective``.
    """
    directives = definition_node and definition_node.directives
    if not directives:
        return None  # pragma: no cover
    deprecated_name = GraphQLDeprecatedDirective.name
    return next(
        (
            directive
            for directive in directives
            if directive.name.value == deprecated_name
        ),
        None,
    )