Astroport.ONE/venv/lib/python3.11/site-packages/graphql/validation/rules/possible_type_extensions.py
2024-03-01 16:15:45 +01:00

110 lines
3.6 KiB
Python

import re
from functools import partial
from typing import Any, Optional
from ...error import GraphQLError
from ...language import TypeDefinitionNode, TypeExtensionNode
from ...pyutils import did_you_mean, inspect, suggestion_list
from ...type import (
is_enum_type,
is_input_object_type,
is_interface_type,
is_object_type,
is_scalar_type,
is_union_type,
)
from . import SDLValidationContext, SDLValidationRule
__all__ = ["PossibleTypeExtensionsRule"]
class PossibleTypeExtensionsRule(SDLValidationRule):
"""Possible type extension
A type extension is only valid if the type is defined and has the same kind.
"""
def __init__(self, context: SDLValidationContext):
super().__init__(context)
self.schema = context.schema
self.defined_types = {
def_.name.value: def_
for def_ in context.document.definitions
if isinstance(def_, TypeDefinitionNode)
}
def check_extension(self, node: TypeExtensionNode, *_args: Any) -> None:
schema = self.schema
type_name = node.name.value
def_node = self.defined_types.get(type_name)
existing_type = schema.get_type(type_name) if schema else None
expected_kind: Optional[str]
if def_node:
expected_kind = def_kind_to_ext_kind(def_node.kind)
elif existing_type:
expected_kind = type_to_ext_kind(existing_type)
else:
expected_kind = None
if expected_kind:
if expected_kind != node.kind:
kind_str = extension_kind_to_type_name(node.kind)
self.report_error(
GraphQLError(
f"Cannot extend non-{kind_str} type '{type_name}'.",
[def_node, node] if def_node else node,
)
)
else:
all_type_names = list(self.defined_types)
if self.schema:
all_type_names.extend(self.schema.type_map)
suggested_types = suggestion_list(type_name, all_type_names)
self.report_error(
GraphQLError(
f"Cannot extend type '{type_name}' because it is not defined."
+ did_you_mean(suggested_types),
node.name,
)
)
enter_scalar_type_extension = enter_object_type_extension = check_extension
enter_interface_type_extension = enter_union_type_extension = check_extension
enter_enum_type_extension = enter_input_object_type_extension = check_extension
def_kind_to_ext_kind = partial(re.compile("(?<=_type_)definition$").sub, "extension")
def type_to_ext_kind(type_: Any) -> str:
if is_scalar_type(type_):
return "scalar_type_extension"
if is_object_type(type_):
return "object_type_extension"
if is_interface_type(type_):
return "interface_type_extension"
if is_union_type(type_):
return "union_type_extension"
if is_enum_type(type_):
return "enum_type_extension"
if is_input_object_type(type_):
return "input_object_type_extension"
# Not reachable. All possible types have been considered.
raise TypeError(f"Unexpected type: {inspect(type_)}.")
_type_names_for_extension_kinds = {
"scalar_type_extension": "scalar",
"object_type_extension": "object",
"interface_type_extension": "interface",
"union_type_extension": "union",
"enum_type_extension": "enum",
"input_object_type_extension": "input object",
}
def extension_kind_to_type_name(kind: str) -> str:
return _type_names_for_extension_kinds.get(kind, "unknown type")