from copy import copy, deepcopy
from typing import (
    Any,
    Collection,
    Dict,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    Union,
    cast,
)

from ..error import GraphQLError
from ..language import ast, OperationType
from ..pyutils import inspect, is_collection, is_description
from .definition import (
    GraphQLAbstractType,
    GraphQLInterfaceType,
    GraphQLInputObjectType,
    GraphQLNamedType,
    GraphQLObjectType,
    GraphQLUnionType,
    GraphQLType,
    GraphQLWrappingType,
    get_named_type,
    is_input_object_type,
    is_interface_type,
    is_object_type,
    is_union_type,
    is_wrapping_type,
)
from .directives import GraphQLDirective, specified_directives, is_directive
from .introspection import introspection_types

try:
    from typing import TypedDict
except ImportError:  # Python < 3.8
    from typing_extensions import TypedDict

__all__ = ["GraphQLSchema", "GraphQLSchemaKwargs", "is_schema", "assert_schema"]


TypeMap = Dict[str, GraphQLNamedType]


class InterfaceImplementations(NamedTuple):
    """The object types and interface types implementing one interface."""

    objects: List[GraphQLObjectType]
    interfaces: List[GraphQLInterfaceType]


class GraphQLSchemaKwargs(TypedDict, total=False):
    """Keyword arguments as returned by ``GraphQLSchema.to_kwargs()``."""

    query: Optional[GraphQLObjectType]
    mutation: Optional[GraphQLObjectType]
    subscription: Optional[GraphQLObjectType]
    types: Optional[Tuple[GraphQLNamedType, ...]]
    directives: Tuple[GraphQLDirective, ...]
    description: Optional[str]
    extensions: Dict[str, Any]
    ast_node: Optional[ast.SchemaDefinitionNode]
    extension_ast_nodes: Tuple[ast.SchemaExtensionNode, ...]
    assume_valid: bool


class GraphQLSchema:
    """Schema Definition

    A Schema is created by supplying the root types of each type of operation,
    query and mutation (optional). A schema definition is then supplied to the
    validator and executor. Schemas should be considered immutable once they are
    created. If you want to modify a schema, modify the result of the ``to_kwargs()``
    method and recreate the schema.

    Example::

        MyAppSchema = GraphQLSchema(
          query=MyAppQueryRootType,
          mutation=MyAppMutationRootType)

    Note: When the schema is constructed, by default only the types that are
    reachable by traversing the root types are included, other types must be
    explicitly referenced.

    Example::

        character_interface = GraphQLInterfaceType('Character', ...)

        human_type = GraphQLObjectType(
            'Human', interfaces=[character_interface], ...)

        droid_type = GraphQLObjectType(
            'Droid', interfaces=[character_interface], ...)

        schema = GraphQLSchema(
            query=GraphQLObjectType('Query',
                fields={'hero': GraphQLField(character_interface, ...)}),
            ...
            # Since this schema references only the `Character` interface it's
            # necessary to explicitly list the types that implement it if
            # you want them to be included in the final schema.
            types=[human_type, droid_type])

    Note: If a collection of ``directives`` is provided to GraphQLSchema, that will
    be the exact set of directives represented and allowed. If ``directives`` is not
    provided, then a default set of the specified directives (e.g. @include and
    @skip) will be used. If you wish to provide *additional* directives to these
    specified directives, you must explicitly declare them. Example::

        MyAppSchema = GraphQLSchema(
          ...
          directives=[*specified_directives, my_custom_directive])
    """

    query_type: Optional[GraphQLObjectType]
    mutation_type: Optional[GraphQLObjectType]
    subscription_type: Optional[GraphQLObjectType]
    type_map: TypeMap
    directives: Tuple[GraphQLDirective, ...]
    description: Optional[str]
    extensions: Dict[str, Any]
    ast_node: Optional[ast.SchemaDefinitionNode]
    extension_ast_nodes: Tuple[ast.SchemaExtensionNode, ...]

    _implementations_map: Dict[str, InterfaceImplementations]
    _sub_type_map: Dict[str, Set[str]]
    _validation_errors: Optional[List[GraphQLError]]

    def __init__(
        self,
        query: Optional[GraphQLObjectType] = None,
        mutation: Optional[GraphQLObjectType] = None,
        subscription: Optional[GraphQLObjectType] = None,
        types: Optional[Collection[GraphQLNamedType]] = None,
        directives: Optional[Collection[GraphQLDirective]] = None,
        description: Optional[str] = None,
        extensions: Optional[Dict[str, Any]] = None,
        ast_node: Optional[ast.SchemaDefinitionNode] = None,
        extension_ast_nodes: Optional[Collection[ast.SchemaExtensionNode]] = None,
        assume_valid: bool = False,
    ) -> None:
        """Initialize GraphQL schema.

        If this schema was built from a source known to be valid, then it may be
        marked with ``assume_valid`` to avoid an additional type system validation.

        Raises a ``TypeError`` if one of the arguments has an unexpected type, if a
        collected type has no name, or if two different collected types share the
        same name.
        """
        # An empty error list marks the schema as already validated,
        # while None means that validation has not been run yet.
        self._validation_errors = [] if assume_valid else None

        # Check for common mistakes during construction to produce clear and early
        # error messages, but we leave the specific tests for the validation.
        if query and not isinstance(query, GraphQLType):
            raise TypeError("Expected query to be a GraphQL type.")
        if mutation and not isinstance(mutation, GraphQLType):
            raise TypeError("Expected mutation to be a GraphQL type.")
        if subscription and not isinstance(subscription, GraphQLType):
            raise TypeError("Expected subscription to be a GraphQL type.")
        if types is None:
            types = []
        else:
            if not is_collection(types) or not all(
                isinstance(type_, GraphQLType) for type_ in types
            ):
                raise TypeError(
                    "Schema types must be specified as a collection of GraphQL types."
                )
        if directives is not None:
            # noinspection PyUnresolvedReferences
            if not is_collection(directives):
                raise TypeError("Schema directives must be a collection.")
            if not isinstance(directives, tuple):
                directives = tuple(directives)
        if description is not None and not is_description(description):
            raise TypeError("Schema description must be a string.")
        if extensions is None:
            extensions = {}
        elif not isinstance(extensions, dict) or not all(
            isinstance(key, str) for key in extensions
        ):
            raise TypeError("Schema extensions must be a dictionary with string keys.")
        if ast_node and not isinstance(ast_node, ast.SchemaDefinitionNode):
            raise TypeError("Schema AST node must be a SchemaDefinitionNode.")
        if extension_ast_nodes:
            if not is_collection(extension_ast_nodes) or not all(
                isinstance(node, ast.SchemaExtensionNode)
                for node in extension_ast_nodes
            ):
                raise TypeError(
                    "Schema extension AST nodes must be specified"
                    " as a collection of SchemaExtensionNode instances."
                )
            if not isinstance(extension_ast_nodes, tuple):
                extension_ast_nodes = tuple(extension_ast_nodes)
        else:
            extension_ast_nodes = ()

        self.description = description
        self.extensions = extensions
        self.ast_node = ast_node
        self.extension_ast_nodes = extension_ast_nodes

        self.query_type = query
        self.mutation_type = mutation
        self.subscription_type = subscription
        # Provide specified directives (e.g. @include and @skip) by default
        self.directives = specified_directives if directives is None else directives

        # To preserve the order of the user-provided types, we first add all of
        # them to the set of "collected" types, so that `collect_referenced_types`
        # ignores them when they are referenced by other types.
        if types:
            all_referenced_types = TypeSet.with_initial_types(types)
            collect_referenced_types = all_referenced_types.collect_referenced_types
            for type_ in types:
                # When we are ready to process this type, we remove it from "collected"
                # types and then add it together with all dependent types in the correct
                # position. We use pop() with a default instead of del, because a
                # wrapping type object that is listed more than once has already been
                # removed on its first occurrence and is never added back (only its
                # named type is), so del would raise a KeyError here.
                all_referenced_types.pop(type_, None)
                collect_referenced_types(type_)
        else:
            all_referenced_types = TypeSet()
            collect_referenced_types = all_referenced_types.collect_referenced_types

        if query:
            collect_referenced_types(query)
        if mutation:
            collect_referenced_types(mutation)
        if subscription:
            collect_referenced_types(subscription)

        for directive in self.directives:
            # Directives are not validated until validate_schema() is called.
            if is_directive(directive):
                for arg in directive.args.values():
                    collect_referenced_types(arg.type)
        collect_referenced_types(introspection_types["__Schema"])

        # Storing the resulting map for reference by the schema.
        type_map: TypeMap = {}
        self.type_map = type_map

        # Cache used by is_sub_type(), filled lazily per abstract type name.
        self._sub_type_map = {}

        # Keep track of all implementations by interface name.
        implementations_map: Dict[str, InterfaceImplementations] = {}
        self._implementations_map = implementations_map

        for named_type in all_referenced_types:
            if not named_type:
                continue

            type_name = getattr(named_type, "name", None)
            if not type_name:
                raise TypeError(
                    "One of the provided types for building the Schema"
                    " is missing a name.",
                )
            if type_name in type_map:
                raise TypeError(
                    "Schema must contain uniquely named types"
                    f" but contains multiple types named '{type_name}'."
                )
            type_map[type_name] = named_type

            if is_interface_type(named_type):
                named_type = cast(GraphQLInterfaceType, named_type)
                # Store implementations by interface.
                for iface in named_type.interfaces:
                    if is_interface_type(iface):
                        iface = cast(GraphQLInterfaceType, iface)
                        if iface.name in implementations_map:
                            implementations = implementations_map[iface.name]
                        else:
                            implementations = implementations_map[
                                iface.name
                            ] = InterfaceImplementations(objects=[], interfaces=[])

                        implementations.interfaces.append(named_type)
            elif is_object_type(named_type):
                named_type = cast(GraphQLObjectType, named_type)
                # Store implementations by objects.
                for iface in named_type.interfaces:
                    if is_interface_type(iface):
                        iface = cast(GraphQLInterfaceType, iface)
                        if iface.name in implementations_map:
                            implementations = implementations_map[iface.name]
                        else:
                            implementations = implementations_map[
                                iface.name
                            ] = InterfaceImplementations(objects=[], interfaces=[])

                        implementations.objects.append(named_type)

    def to_kwargs(self) -> GraphQLSchemaKwargs:
        """Get the keyword arguments that can be used to recreate this schema."""
        return GraphQLSchemaKwargs(
            query=self.query_type,
            mutation=self.mutation_type,
            subscription=self.subscription_type,
            types=tuple(self.type_map.values()) or None,
            directives=self.directives,
            description=self.description,
            extensions=self.extensions,
            ast_node=self.ast_node,
            extension_ast_nodes=self.extension_ast_nodes,
            assume_valid=self._validation_errors is not None,
        )

    def __copy__(self) -> "GraphQLSchema":  # pragma: no cover
        """Create a new schema from the keyword arguments of this schema."""
        return self.__class__(**self.to_kwargs())

    def __deepcopy__(self, memo_: Dict) -> "GraphQLSchema":
        """Create a copy of this schema that uses copies of its types and directives.

        Introspection types and specified scalar types are not part of the copied
        type map, and specified directives are reused as they are. All other types
        and directives are copied, and the references between them are remapped so
        that the copies refer to each other instead of to the originals.
        """
        from ..type import (
            is_introspection_type,
            is_specified_scalar_type,
            is_specified_directive,
        )

        type_map: TypeMap = {
            name: copy(type_)
            for name, type_ in self.type_map.items()
            if not is_introspection_type(type_) and not is_specified_scalar_type(type_)
        }
        types = type_map.values()
        for type_ in types:
            remap_named_type(type_, type_map)
        directives: List[GraphQLDirective] = []
        for directive in self.directives:
            if not is_specified_directive(directive):
                directive = copy(directive)
                # The argument types of the copied directive must refer to the copied
                # types. Otherwise both the original and the copy of such a type would
                # be collected by the new schema, which is rejected as a name clash.
                remap_directive(directive, type_map)
            directives.append(directive)
        return self.__class__(
            self.query_type and cast(GraphQLObjectType, type_map[self.query_type.name]),
            self.mutation_type
            and cast(GraphQLObjectType, type_map[self.mutation_type.name]),
            self.subscription_type
            and cast(GraphQLObjectType, type_map[self.subscription_type.name]),
            types,
            directives,
            self.description,
            extensions=deepcopy(self.extensions),
            ast_node=deepcopy(self.ast_node),
            extension_ast_nodes=deepcopy(self.extension_ast_nodes),
            assume_valid=True,
        )

    def get_root_type(self, operation: OperationType) -> Optional[GraphQLObjectType]:
        """Get the root type of this schema for the given operation type."""
        return getattr(self, f"{operation.value}_type")

    def get_type(self, name: str) -> Optional[GraphQLNamedType]:
        """Get the type with the given name, or None if there is no such type."""
        return self.type_map.get(name)

    def get_possible_types(
        self, abstract_type: GraphQLAbstractType
    ) -> List[GraphQLObjectType]:
        """Get list of all possible concrete types for given abstract type."""
        return (
            cast(GraphQLUnionType, abstract_type).types
            if is_union_type(abstract_type)
            else self.get_implementations(
                cast(GraphQLInterfaceType, abstract_type)
            ).objects
        )

    def get_implementations(
        self, interface_type: GraphQLInterfaceType
    ) -> InterfaceImplementations:
        """Get the implementations recorded for the given interface type.

        If nothing was recorded under the name of the interface, a new and empty
        ``InterfaceImplementations`` tuple is returned.
        """
        return self._implementations_map.get(
            interface_type.name, InterfaceImplementations(objects=[], interfaces=[])
        )

    def is_sub_type(
        self,
        abstract_type: GraphQLAbstractType,
        maybe_sub_type: GraphQLNamedType,
    ) -> bool:
        """Check whether a type is a subtype of a given abstract type."""
        types = self._sub_type_map.get(abstract_type.name)
        if types is None:
            # Compute the names of all subtypes once and cache them
            # under the name of the abstract type.
            types = set()
            add = types.add
            if is_union_type(abstract_type):
                for type_ in cast(GraphQLUnionType, abstract_type).types:
                    add(type_.name)
            else:
                implementations = self.get_implementations(
                    cast(GraphQLInterfaceType, abstract_type)
                )
                for type_ in implementations.objects:
                    add(type_.name)
                for type_ in implementations.interfaces:
                    add(type_.name)
            self._sub_type_map[abstract_type.name] = types
        return maybe_sub_type.name in types

    def get_directive(self, name: str) -> Optional[GraphQLDirective]:
        """Get the first directive with the given name, or None if there is none."""
        for directive in self.directives:
            if directive.name == name:
                return directive
        return None

    @property
    def validation_errors(self) -> Optional[List[GraphQLError]]:
        """The stored validation errors (None unless they have been set)."""
        return self._validation_errors


class TypeSet(Dict[GraphQLNamedType, None]):
    """An ordered set of types that can be collected starting from initial types."""

    @classmethod
    def with_initial_types(cls, types: Collection[GraphQLType]) -> "TypeSet":
        """Create a type set that already contains the given types."""
        return cast(TypeSet, super().fromkeys(types))

    def collect_referenced_types(self, type_: GraphQLType) -> None:
        """Recursive function supplementing the type starting from an initial type."""
        named_type = get_named_type(type_)

        if named_type in self:
            return

        # Add the type before descending into its references,
        # so that circular references terminate.
        self[named_type] = None

        collect_referenced_types = self.collect_referenced_types
        if is_union_type(named_type):
            named_type = cast(GraphQLUnionType, named_type)
            for member_type in named_type.types:
                collect_referenced_types(member_type)
        elif is_object_type(named_type) or is_interface_type(named_type):
            named_type = cast(
                Union[GraphQLObjectType, GraphQLInterfaceType], named_type
            )
            for interface_type in named_type.interfaces:
                collect_referenced_types(interface_type)

            for field in named_type.fields.values():
                collect_referenced_types(field.type)
                for arg in field.args.values():
                    collect_referenced_types(arg.type)
        elif is_input_object_type(named_type):
            named_type = cast(GraphQLInputObjectType, named_type)
            for field in named_type.fields.values():
                collect_referenced_types(field.type)


def is_schema(schema: Any) -> bool:
    """Test if the given value is a GraphQL schema."""
    return isinstance(schema, GraphQLSchema)


def assert_schema(schema: Any) -> GraphQLSchema:
    """Return the given value if it is a GraphQL schema, else raise a TypeError."""
    if not is_schema(schema):
        raise TypeError(f"Expected {inspect(schema)} to be a GraphQL schema.")
    return cast(GraphQLSchema, schema)


def remapped_type(type_: GraphQLType, type_map: TypeMap) -> GraphQLType:
    """Get a copy of the given type that uses this type map."""
    if is_wrapping_type(type_):
        type_ = cast(GraphQLWrappingType, type_)
        return type_.__class__(remapped_type(type_.of_type, type_map))
    type_ = cast(GraphQLNamedType, type_)
    # Types that are not in the type map are returned unchanged.
    return type_map.get(type_.name, type_)


def remap_named_type(type_: GraphQLNamedType, type_map: TypeMap) -> None:
    """Change all references in the given named type to use this type map."""
    if is_union_type(type_):
        type_ = cast(GraphQLUnionType, type_)
        type_.types = [
            type_map.get(member_type.name, member_type) for member_type in type_.types
        ]
    elif is_object_type(type_) or is_interface_type(type_):
        type_ = cast(Union[GraphQLObjectType, GraphQLInterfaceType], type_)
        type_.interfaces = [
            type_map.get(interface_type.name, interface_type)
            for interface_type in type_.interfaces
        ]
        fields = type_.fields
        for field_name, field in fields.items():
            field = copy(field)
            field.type = remapped_type(field.type, type_map)
            args = field.args
            for arg_name, arg in args.items():
                arg = copy(arg)
                arg.type = remapped_type(arg.type, type_map)
                args[arg_name] = arg
            fields[field_name] = field
    elif is_input_object_type(type_):
        type_ = cast(GraphQLInputObjectType, type_)
        fields = type_.fields
        for field_name, field in fields.items():
            field = copy(field)
            field.type = remapped_type(field.type, type_map)
            fields[field_name] = field


def remap_directive(directive: GraphQLDirective, type_map: TypeMap) -> None:
    """Change all references in the given directive to use this type map."""
    remapped_args = {}
    for arg_name, arg in directive.args.items():
        arg = copy(arg)
        arg.type = remapped_type(arg.type, type_map)
        remapped_args[arg_name] = arg
    # Assign a new dictionary instead of changing the existing one in place,
    # so that the change can never show up in another directive object.
    directive.args = remapped_args