From 36d311e04ea7d36711e1fff66b21bcc7f9360a7b Mon Sep 17 00:00:00 2001 From: Eitan Mosenkis Date: Wed, 20 Nov 2024 22:45:48 +0200 Subject: [PATCH 1/6] Don't style dict keys passed to runner. The runner & server can't be expected to receive snake cased keys if the schema uses camel case, etc. --- turms/plugins/funcs.py | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/turms/plugins/funcs.py b/turms/plugins/funcs.py index 7c62bc3..dbf57c5 100644 --- a/turms/plugins/funcs.py +++ b/turms/plugins/funcs.py @@ -466,7 +466,7 @@ def input_type_to_dict(input_type: GraphQLInputObjectType, registry: ClassRegist for value_key, value in input_type.fields.items(): field_name = registry.generate_node_name(value_key) - keys.append(ast.Constant(value=field_name)) + keys.append(ast.Constant(value=value_key)) values.append(ast.Name(id=field_name, ctx=ast.Load())) @@ -499,28 +499,16 @@ def generate_variable_dict( - keys.append( - ast.Constant( - value=registry.generate_parameter_name(v.variable.name.value) - ) - ) - values.append( - input_type_to_dict(type, registry) - ) - elif plugin_config.argument_key_is_styled: - keys.append( - ast.Constant( - value=registry.generate_parameter_name(v.variable.name.value) - ) - ) + keys.append(ast.Constant(value=v.variable.name.value)) + values.append(input_type_to_dict(type, registry)) else: keys.append(ast.Constant(value=v.variable.name.value)) - values.append( - ast.Name( - id=registry.generate_parameter_name(v.variable.name.value), - ctx=ast.Load(), + values.append( + ast.Name( + id=registry.generate_parameter_name(v.variable.name.value), + ctx=ast.Load(), + ) ) - ) return ast.Dict(keys=keys, values=values) From fdfe8e7de421d8b56ad99dbead51c968c3ea1fd1 Mon Sep 17 00:00:00 2001 From: Eitan Mosenkis Date: Thu, 12 Dec 2024 15:04:38 +0200 Subject: [PATCH 2/6] Support fragments on unions. --- turms/plugins/fragments.py | 345 +++++++++++++++++++++++++++---------- turms/registry.py | 21 ++- 2 files changed, 271 insertions(+), 95 deletions(-) diff --git a/turms/plugins/fragments.py b/turms/plugins/fragments.py index d879794..3ae4bac 100644 --- a/turms/plugins/fragments.py +++ b/turms/plugins/fragments.py @@ -1,13 +1,27 @@ import ast +import logging +from collections import defaultdict, deque from typing import List, Optional -from pydantic_settings import SettingsConfigDict -from turms.config import GeneratorConfig +from graphql import ( + FieldNode, + FragmentDefinitionNode, + FragmentSpreadNode, + GraphQLInterfaceType, + GraphQLObjectType, + GraphQLUnionType, + InlineFragmentNode, + SelectionSetNode, + language, +) from graphql.utilities.build_client_schema import GraphQLSchema -from turms.recurse import type_field_node -from turms.plugins.base import Plugin, PluginConfig +from graphql.utilities.type_info import get_field_def from pydantic import Field -from graphql.language.ast import FragmentDefinitionNode +from pydantic_settings import SettingsConfigDict + +from turms.config import GeneratorConfig, GraphQLTypes +from turms.plugins.base import Plugin, PluginConfig +from turms.recurse import type_field_node from turms.registry import ClassRegistry from turms.utils import ( generate_generic_typename_field, @@ -18,23 +32,13 @@ non_typename_fields, parse_documents, ) -from graphql import ( - FieldNode, - FragmentSpreadNode, - GraphQLInterfaceType, - GraphQLObjectType, - InlineFragmentNode, - SelectionSetNode, - language, -) -from turms.config import GraphQLTypes -from graphql import parse, print_ast -from graphql.language.ast import ( - DocumentNode, OperationDefinitionNode, FragmentDefinitionNode, FragmentSpreadNode -) -from collections import defaultdict, deque -def find_fragment_dependencies_recursive(selection_set: SelectionSetNode, fragment_definitions, visited): +logger = logging.getLogger(__name__) + + +def find_fragment_dependencies_recursive( + selection_set: SelectionSetNode, fragment_definitions, visited +): """Recursively find all fragment dependencies within a selection set.""" dependencies = set() if selection_set is None: @@ -50,43 +54,52 @@ def find_fragment_dependencies_recursive(selection_set: SelectionSetNode, fragme visited.add(spread_name) # Prevent cycles in recursion fragment = fragment_definitions[spread_name] dependencies.update( - find_fragment_dependencies_recursive(fragment.selection_set, fragment_definitions, visited) + find_fragment_dependencies_recursive( + fragment.selection_set, fragment_definitions, visited + ) ) # If it's a field with a nested selection set, dive deeper elif isinstance(selection, FieldNode) and selection.selection_set: dependencies.update( - find_fragment_dependencies_recursive(selection.selection_set, fragment_definitions, visited) + find_fragment_dependencies_recursive( + selection.selection_set, fragment_definitions, visited + ) ) - return dependencies + def build_recursive_dependency_graph(document): """Build a dependency graph for fragments, accounting for deep nested fragment spreads.""" fragment_definitions = { - definition.name.value: definition for definition in document.definitions + definition.name.value: definition + for definition in document.definitions if isinstance(definition, FragmentDefinitionNode) } dependencies = defaultdict(set) - + # Populate the dependency graph with deeply nested fragment dependencies for fragment_name, fragment in fragment_definitions.items(): visited = set() # Track visited fragments to avoid cyclic dependencies - dependencies[fragment_name] = find_fragment_dependencies_recursive(fragment.selection_set, fragment_definitions, visited) - + dependencies[fragment_name] = find_fragment_dependencies_recursive( + fragment.selection_set, fragment_definitions, visited + ) + return dependencies def topological_sort(dependency_graph): """Perform a topological sort on fragments based on recursive dependencies.""" sorted_fragments = [] - no_dependency_fragments = deque([frag for frag, deps in dependency_graph.items() if not deps]) + no_dependency_fragments = deque( + [frag for frag, deps in dependency_graph.items() if not deps] + ) resolved = set(no_dependency_fragments) - + while no_dependency_fragments: fragment = no_dependency_fragments.popleft() sorted_fragments.append(fragment) - + # Remove this fragment from other fragments' dependencies for frag, deps in dependency_graph.items(): if fragment in deps: @@ -94,18 +107,13 @@ def topological_sort(dependency_graph): if not deps and frag not in resolved: no_dependency_fragments.append(frag) resolved.add(frag) - - # Add any remaining fragments that may have been missed if they were independent - sorted_fragments.extend(frag for frag in dependency_graph if frag not in sorted_fragments) - - return sorted_fragments - - -from graphql.utilities.type_info import get_field_def -import logging + # Add any remaining fragments that may have been missed if they were independent + sorted_fragments.extend( + frag for frag in dependency_graph if frag not in sorted_fragments + ) -logger = logging.getLogger(__name__) + return sorted_fragments class FragmentsPluginConfig(PluginConfig): @@ -148,7 +156,6 @@ def get_implementing_types(type: GraphQLInterfaceType, client_schema: GraphQLSch return implementing_types - def generate_fragment( f: FragmentDefinitionNode, client_schema: GraphQLSchema, @@ -169,15 +176,12 @@ def generate_fragment( registry.register_fragment_type(f.name.value, type) - - - if isinstance(type, GraphQLInterfaceType): implementing_types = client_schema.get_implementations(type) - + mother_class_fields = [] - base_fragment_name = registry.style_fragment_class(f.name.value) + base_fragment_name = registry.style_fragment_class(f.name.value) additional_bases = get_additional_bases_for_type(type.name, config, registry) if type.description and plugin_config.add_documentation: @@ -189,30 +193,29 @@ def generate_fragment( mother_class_name = base_fragment_name + "Base" - - implementing_class_base_classes = { - } - + implementing_class_base_classes = {} inline_fragment_fields = {} - - for sub_node in sub_nodes: if isinstance(sub_node, FragmentSpreadNode): # Spread nodes are like inheritance? try: # We are dealing with a fragment that is an interface - implementation_map = registry.get_interface_fragment_implementations(sub_node.name.value) + implementation_map = ( + registry.get_interface_fragment_implementations( + sub_node.name.value + ) + ) for k, v in implementation_map.items(): implementing_class_base_classes.setdefault(k, []).append(v) except KeyError: x = registry.get_fragment_type(sub_node.name.value) - implementing_class_base_classes.setdefault(x, []).append(registry.inherit_fragment(sub_node.name.value)) - - + implementing_class_base_classes.setdefault(x, []).append( + registry.inherit_fragment(sub_node.name.value) + ) if isinstance(sub_node, FieldNode): @@ -236,10 +239,10 @@ def generate_fragment( ) ) - mother_class = ast.ClassDef( mother_class_name, - bases=additional_bases + get_interface_bases(config, registry) , # Todo: fill with base + bases=additional_bases + + get_interface_bases(config, registry), # Todo: fill with base decorator_list=[], keywords=[], body=mother_class_fields if mother_class_fields else [ast.Pass()], @@ -249,57 +252,58 @@ def generate_fragment( catch_class = ast.ClassDef( catch_class_name, - bases=[ast.Name(id=mother_class_name, ctx=ast.Load())], # Todo: fill with base + bases=[ + ast.Name(id=mother_class_name, ctx=ast.Load()) + ], # Todo: fill with base decorator_list=[], keywords=[], - body=[generate_generic_typename_field(registry, config)] + mother_class_fields, + body=[generate_generic_typename_field(registry, config)] + + mother_class_fields, ) - - tree.append(mother_class) tree.append(catch_class) - - - implementaionMap = {} + implementationMap = {} for i in implementing_types.objects: class_name = f"{base_fragment_name}{i.name}" - - ast_base_nodes = [ast.Name(id=x, ctx=ast.Load()) for x in implementing_class_base_classes.get(i, [])] - implementaionMap[i.name] = class_name + ast_base_nodes = [ + ast.Name(id=x, ctx=ast.Load()) + for x in implementing_class_base_classes.get(i, []) + ] + implementationMap[i.name] = class_name inline_fields = inline_fragment_fields.get(i, []) implementing_class = ast.ClassDef( - class_name, - bases=ast_base_nodes + [ast.Name(id=mother_class_name, ctx=ast.Load())] + get_interface_bases(config, registry), # Todo: fill with base - decorator_list=[], - keywords=[], - body=[generate_typename_field(i.name, registry, config)] + inline_fields, + class_name, + bases=ast_base_nodes + + [ast.Name(id=mother_class_name, ctx=ast.Load())] + + get_interface_bases(config, registry), # Todo: fill with base + decorator_list=[], + keywords=[], + body=[generate_typename_field(i.name, registry, config)] + + inline_fields, ) tree.append(implementing_class) - - registry.register_interface_fragment_implementations(f.name.value, implementaionMap) - + registry.register_interface_fragment_implementations( + f.name.value, implementationMap + ) return tree - elif isinstance(type, GraphQLObjectType): additional_bases = get_additional_bases_for_type( f.type_condition.name.value, config, registry ) if type.description and plugin_config.add_documentation: - fields.append( - ast.Expr(value=ast.Constant(value=type.description)) - ) + fields.append(ast.Expr(value=ast.Constant(value=type.description))) fields += [generate_typename_field(type.name, registry, config)] @@ -314,7 +318,9 @@ def generate_fragment( if isinstance(field, FragmentSpreadNode): try: - implementationMap = registry.get_interface_fragment_implementations(field.name.value) + implementationMap = registry.get_interface_fragment_implementations( + field.name.value + ) if type.name in implementationMap: additional_bases = [ ast.Name( @@ -323,7 +329,9 @@ def generate_fragment( ) ] + additional_bases else: - raise Exception(f"Could not find implementation for {type.name} in {implementationMap}") + raise Exception( + f"Could not find implementation for {type.name} in {implementationMap}" + ) except KeyError: additional_bases = [ ast.Name( @@ -359,13 +367,169 @@ def generate_fragment( ) return tree + elif isinstance(type, GraphQLUnionType): + mother_class_fields = [] + base_fragment_name = registry.style_fragment_class(f.name.value) + additional_bases = get_additional_bases_for_type(type.name, config, registry) + + if type.description and plugin_config.add_documentation: + mother_class_fields.append( + ast.Expr(value=ast.Constant(value=type.description)) + ) + + sub_nodes = non_typename_fields(f) + + mother_class_name = base_fragment_name # + "Base" + + implementing_class_base_classes = {} + + inline_fragment_fields = {} + + for sub_node in sub_nodes: + + if isinstance(sub_node, FragmentSpreadNode): + # Spread nodes are like inheritance? + try: + # We are dealing with a fragment that is a union + implementation_map = registry.get_union_fragment_implementations( + sub_node.name.value + ) + for k, v in implementation_map.items(): + implementing_class_base_classes.setdefault(k, []).append(v) + + except KeyError: + x = registry.get_fragment_type(sub_node.name.value) + implementing_class_base_classes.setdefault(x, []).append( + registry.inherit_fragment(sub_node.name.value) + ) + + elif isinstance(sub_node, FieldNode): + raise AssertionError("Union types should not have fields") + + elif isinstance(sub_node, InlineFragmentNode): + on_type_name = sub_node.type_condition.name.value + + fields = [] + for field in sub_node.selection_set.selections: + + if field.name.value == "__typename": + continue + + if isinstance(field, FragmentSpreadNode): + try: + implementationMap = ( + registry.get_interface_fragment_implementations( + field.name.value + ) + ) + if type.name in implementationMap: + additional_bases = [ + ast.Name( + id=implementationMap[type.name], + ctx=ast.Load(), + ) + ] + additional_bases + else: + raise Exception( + f"Could not find implementation for {type.name} in {implementationMap}" + ) + except KeyError: + additional_bases = [ + ast.Name( + id=registry.inherit_fragment(field.name.value), + ctx=ast.Load(), + ) + ] + additional_bases # needs to be prepended (MRO) + continue + + field_definition = get_field_def( + client_schema, client_schema.get_type(on_type_name), field + ) + assert ( + field_definition + ), f"Couldn't find field definition for {on_type_name}.{field.name.value}" + + fields += type_field_node( + field, + name, + field_definition, + client_schema, + config, + tree, + registry, + ) + + inline_fragment_fields.setdefault(on_type_name, []).extend(fields) + else: + raise AssertionError(f"Unknown node type: {type(sub_node)}") + + implementationMap = {} + + for i in type.types: + + class_name = f"{base_fragment_name}{i.name}" + + ast_base_nodes = [ + ast.Name(id=x, ctx=ast.Load()) + for x in implementing_class_base_classes.get(i, []) + ] + implementationMap[i.name] = class_name + + inline_fields = inline_fragment_fields.get(i.name, []) + + implementing_class = ast.ClassDef( + class_name, + bases=ast_base_nodes + + get_fragment_bases(config, plugin_config, registry), + decorator_list=[], + keywords=[], + body=[generate_typename_field(i.name, registry, config)] + + inline_fields, + ) + + tree.append(implementing_class) + + registry.register_union_fragment_implementations( + f.name.value, implementationMap + ) + + registry.register_import("typing.TypeAlias") + registry.register_import("typing.Union") + mother_class = ast.AnnAssign( + target=ast.Name(id=base_fragment_name, ctx=ast.Load()), + annotation=ast.Name(id="TypeAlias", ctx=ast.Load()), + value=ast.Subscript( + value=ast.Name(id="Union", ctx=ast.Load()), + slice=ast.Tuple( + elts=[ + ast.Name(id=f"{base_fragment_name}{i.name}", ctx=ast.Load()) + for i in type.types + ], + ctx=ast.Load(), + ), + ), + simple=1, + ) + tree.append(mother_class) + + return tree + + def reorder_definitions(document, sorted_fragments): """Reorder document definitions to place fragments in dependency order.""" - fragment_definitions = {defn.name.value: defn for defn in document.definitions if isinstance(defn, FragmentDefinitionNode)} - + fragment_definitions = { + defn.name.value: defn + for defn in document.definitions + if isinstance(defn, FragmentDefinitionNode) + } + # Order fragments according to the topologically sorted order - ordered_fragments = [fragment_definitions[name] for name in sorted_fragments if name in fragment_definitions] - + ordered_fragments = [ + fragment_definitions[name] + for name in sorted_fragments + if name in fragment_definitions + ] + # Combine operations and ordered fragments return ordered_fragments @@ -401,16 +565,11 @@ def generate_ast( # Find dependencies and sort fragments topologically fragment_dependencies = build_recursive_dependency_graph(documents) - - sorted_fragments = topological_sort(fragment_dependencies) - + sorted_fragments = topological_sort(fragment_dependencies) ordered_fragments = reorder_definitions(documents, sorted_fragments) - - - for fragment in ordered_fragments: plugin_tree += generate_fragment( fragment, client_schema, config, self.config, registry diff --git a/turms/registry.py b/turms/registry.py index 02d99c1..8c8d2d1 100644 --- a/turms/registry.py +++ b/turms/registry.py @@ -102,12 +102,12 @@ def __init__( self.mutation_class_map = {} self.registered_interfaces_fragments = {} - + self.registered_union_fragments = {} self.forward_references = set() self.fragment_type_map = {} - self.interfacefragments_class_map = {} self.interfacefragments_impl_map = {} + self.unionfragments_impl_map = {} self.log = log def style_inputtype_class(self, typename: str): @@ -282,6 +282,14 @@ def register_interface_fragment_implementations(self, fragmentname: str, impleme def get_interface_fragment_implementations(self, fragmentname: str): return self.interfacefragments_impl_map[fragmentname] + + + def register_union_fragment_implementations(self, fragmentname: str, implementationMap: Dict[str, str]): + self.unionfragments_impl_map[fragmentname] = implementationMap + + + def get_union_fragment_implementations(self, fragmentname: str): + return self.unionfragments_impl_map[fragmentname] def get_fragment_type(self, fragmentname: str): @@ -310,6 +318,15 @@ def reference_interface_fragment(self, typename: str, parent: str, allow_forward def register_interface_fragment(self, typename: str, ast: ast.AST): self.registered_interfaces_fragments[typename] = ast + + def is_union_fragment(self, typename: str): + return typename in self.registered_union_fragments + + def reference_union_fragment(self, typename: str, parent: str, allow_forward=True) -> ast.AST: + return self.registered_union_fragments[typename] + + def register_union_fragment(self, typename: str, ast: ast.AST): + self.registered_union_fragments[typename] = ast def inherit_fragment(self, typename: str, allow_forward=True) -> ast.AST: if typename not in self.fragment_class_map: From d0ba8306fa8a789d68c20b6aa7d9c9dc76114aba Mon Sep 17 00:00:00 2001 From: Eitan Mosenkis Date: Thu, 12 Dec 2024 15:49:35 +0200 Subject: [PATCH 3/6] Support unions in more places. --- turms/utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/turms/utils.py b/turms/utils.py index 9974640..e150e87 100644 --- a/turms/utils.py +++ b/turms/utils.py @@ -16,6 +16,7 @@ GraphQLObjectType, GraphQLOutputType, GraphQLScalarType, + GraphQLUnionType, IntValueNode, ListTypeNode, NamedTypeNode, @@ -701,7 +702,7 @@ def recurse_outputtype_annotation( else: return registry.reference_scalar(type.name) - if isinstance(type, GraphQLObjectType) or isinstance(type, GraphQLInterfaceType): + if isinstance(type, (GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType)): assert overwrite_final, "Needs to be set" if optional: registry.register_import("typing.Optional") @@ -713,7 +714,7 @@ def recurse_outputtype_annotation( else: return ast.Name(id=overwrite_final, ctx=ast.Load()) - raise NotImplementedError("oisnosin") # pragma: no cover + raise NotImplementedError(f"recurse over {type.__class__.__name__}") # pragma: no cover def recurse_outputtype_label( @@ -762,7 +763,7 @@ def recurse_outputtype_label( else: return registry.reference_scalar(type.name).id - if isinstance(type, GraphQLObjectType) or isinstance(type, GraphQLInterfaceType): + if isinstance(type, (GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType)): assert overwrite_final, "Needs to be set" if optional: return "Optional[" + overwrite_final + "]" From cb5902336e152598749f3ece836d9dc9904c5522 Mon Sep 17 00:00:00 2001 From: Eitan Mosenkis Date: Fri, 13 Dec 2024 14:54:00 +0200 Subject: [PATCH 4/6] Add docstring. --- turms/plugins/fragments.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/turms/plugins/fragments.py b/turms/plugins/fragments.py index 3ae4bac..bdfb970 100644 --- a/turms/plugins/fragments.py +++ b/turms/plugins/fragments.py @@ -368,15 +368,9 @@ def generate_fragment( return tree elif isinstance(type, GraphQLUnionType): - mother_class_fields = [] base_fragment_name = registry.style_fragment_class(f.name.value) additional_bases = get_additional_bases_for_type(type.name, config, registry) - if type.description and plugin_config.add_documentation: - mother_class_fields.append( - ast.Expr(value=ast.Constant(value=type.description)) - ) - sub_nodes = non_typename_fields(f) mother_class_name = base_fragment_name # + "Base" @@ -442,8 +436,10 @@ def generate_fragment( ] + additional_bases # needs to be prepended (MRO) continue + field_type = client_schema.get_type(on_type_name) + field_definition = get_field_def( - client_schema, client_schema.get_type(on_type_name), field + client_schema, field_type, field ) assert ( field_definition @@ -510,8 +506,14 @@ def generate_fragment( ), simple=1, ) + tree.append(mother_class) + if type.description and plugin_config.add_documentation: + tree.append( + ast.Expr(value=ast.Constant(value=type.description)) + ) + return tree From bf6078ee13527a827734ac5fdb110f9790aeea4b Mon Sep 17 00:00:00 2001 From: Eitan Mosenkis Date: Fri, 13 Dec 2024 15:13:55 +0200 Subject: [PATCH 5/6] Fix reference finding for fragments on unions. --- turms/referencer.py | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/turms/referencer.py b/turms/referencer.py index e605465..48e418d 100644 --- a/turms/referencer.py +++ b/turms/referencer.py @@ -1,29 +1,31 @@ from typing import Dict, Set -from graphql.utilities.build_client_schema import GraphQLSchema -from graphql.language.ast import DocumentNode, FieldNode + from graphql import ( + FragmentDefinitionNode, GraphQLEnumType, + GraphQLInputField, + GraphQLInputObjectType, + GraphQLInterfaceType, GraphQLList, GraphQLNonNull, GraphQLObjectType, - GraphQLInputObjectType, GraphQLScalarType, ListTypeNode, NamedTypeNode, NonNullTypeNode, OperationDefinitionNode, - FragmentDefinitionNode, - GraphQLInterfaceType, - GraphQLInputField, -) -from graphql.type.definition import ( - GraphQLType, - GraphQLUnionType, ) from graphql.language.ast import ( + DocumentNode, + FieldNode, FragmentSpreadNode, InlineFragmentNode, ) +from graphql.type.definition import ( + GraphQLType, + GraphQLUnionType, +) +from graphql.utilities.build_client_schema import GraphQLSchema class ReferenceRegistry: @@ -214,6 +216,19 @@ def create_reference_registry_from_documents( schema, registry, ) + elif isinstance(selection, InlineFragmentNode): + sub_type = schema.get_type(selection.type_condition.name.value) + for sub_selection in selection.selection_set.selections: + if isinstance(sub_selection, FieldNode): + if sub_selection.name.value == "__typename": + continue + sub_this_type = sub_type.fields[sub_selection.name.value] + recurse_find_references( + sub_selection, + sub_this_type.type, + schema, + registry, + ) for operation in operations.values(): type = schema.get_root_type(operation.operation) From c8c03394c10f7b9b44df3f3c002e154c41748c87 Mon Sep 17 00:00:00 2001 From: Eitan Mosenkis Date: Sun, 15 Dec 2024 15:11:28 +0200 Subject: [PATCH 6/6] Tests and bug fixes. --- tests/documents/unions/test.graphql | 45 ++++++ tests/schemas/union.graphql | 18 ++- tests/test_referencer.py | 10 ++ tests/test_unions.py | 40 +++++- tests/utils.py | 10 +- turms/plugins/fragments.py | 206 +++++++++++----------------- turms/referencer.py | 2 +- turms/registry.py | 24 ++-- turms/utils.py | 22 +-- 9 files changed, 212 insertions(+), 165 deletions(-) diff --git a/tests/documents/unions/test.graphql b/tests/documents/unions/test.graphql index aa265e7..eee588f 100644 --- a/tests/documents/unions/test.graphql +++ b/tests/documents/unions/test.graphql @@ -1,10 +1,55 @@ query Nana { hallo { + __typename ... on Bar { nana } ... on Foo { forward + blip } } } + +query Nana2 { + hallo { + ...BazFragment + } +} + +query Nana3 { + hallo { + ...VeryNestedFragment + } +} + +query Nana4 { + hallo { + ...DelegatingFragment + } +} + +fragment BazFragment on Element { + ... on Baz { + __typename + bloop + } +} + +fragment VeryNestedFragment on Element { + __typename + ...BazFragment + ... on Bar { + nana + } +} + +fragment DelegateFragment on Bar { + nana +} + +fragment DelegatingFragment on Element { + ... on Bar { + ...DelegateFragment + } +} diff --git a/tests/schemas/union.graphql b/tests/schemas/union.graphql index 281d1d0..d180d96 100644 --- a/tests/schemas/union.graphql +++ b/tests/schemas/union.graphql @@ -1,7 +1,18 @@ +enum TestEnum1 { + A + B +} + +enum TestEnum2 { + C + D +} + "This is foo" type Foo { "This is a forward ref" forward: String! + blip: TestEnum1 } type Bar { @@ -9,7 +20,12 @@ type Bar { nana: Int! } -union Element = Foo | Bar +type Baz { + bloop: TestEnum2! +} + +"This is a union" +union Element = Foo | Bar | Baz type Query { hallo: Element diff --git a/tests/test_referencer.py b/tests/test_referencer.py index 597c234..b620adf 100644 --- a/tests/test_referencer.py +++ b/tests/test_referencer.py @@ -1,5 +1,6 @@ from turms.referencer import create_reference_registry_from_documents from turms.utils import parse_documents + from .utils import build_relative_glob @@ -19,3 +20,12 @@ def test_referencer_countries(countries_schema): assert ( "StringQueryOperatorInput" not in z.inputs ), "StringQueryOperatorInput should be skipped" + + +def test_referencer_enum_in_union_fragment(union_schema): + + x = build_relative_glob("/documents/unions/*.graphql") + docs = parse_documents(union_schema, x) + z = create_reference_registry_from_documents(union_schema, docs) + assert "TestEnum1" in z.enums, "TestEnum1 should be referenced (in operation)" + assert "TestEnum2" in z.enums, "TestEnum2 should be referenced (in fragment)" diff --git a/tests/test_unions.py b/tests/test_unions.py index e8eb774..573455f 100644 --- a/tests/test_unions.py +++ b/tests/test_unions.py @@ -1,19 +1,19 @@ -from .utils import build_relative_glob, unit_test_with from turms.config import GeneratorConfig -from turms.run import generate_ast from turms.plugins.enums import EnumsPlugin -from turms.plugins.inputs import InputsPlugin from turms.plugins.fragments import FragmentsPlugin -from turms.plugins.operations import OperationsPlugin from turms.plugins.funcs import ( - FunctionDefinition, FuncsPlugin, FuncsPluginConfig, + FunctionDefinition, ) -from turms.stylers.snake_case import SnakeCaseStyler -from turms.stylers.capitalize import CapitalizeStyler +from turms.plugins.inputs import InputsPlugin +from turms.plugins.operations import OperationsPlugin from turms.run import generate_ast +from turms.stylers.capitalize import CapitalizeStyler +from turms.stylers.snake_case import SnakeCaseStyler + +from .utils import build_relative_glob, unit_test_with def test_nested_input_funcs(union_schema): @@ -52,3 +52,29 @@ def test_nested_input_funcs(union_schema): generated_ast, 'Nana(hallo={"__typename": "Foo","forward": "yes"}).hallo.forward', ) + +def test_fields_in_inline_fragment_on_union(union_schema): + config = GeneratorConfig( + documents=build_relative_glob("/documents/unions/*.graphql"), + ) + generated_ast = generate_ast( + config, + union_schema, + stylers=[CapitalizeStyler(), SnakeCaseStyler()], + plugins=[EnumsPlugin(), InputsPlugin(), FragmentsPlugin(), OperationsPlugin()], + ) + + unit_test_with( + generated_ast, + """ + assert Nana(hallo={'__typename': 'Foo', 'blip': 'A', 'forward': 'yes'}).hallo.blip == 'A' + assert Nana(hallo={'__typename': 'Bar', 'nana': 1}).hallo.nana == 1 + assert Nana2(hallo={'__typename': 'Baz', 'bloop': 'C'}).hallo.bloop == 'C' + assert Nana3(hallo={'__typename': 'Baz', 'bloop': 'C'}).hallo.bloop == 'C' + assert Nana3(hallo={'__typename': 'Bar', 'nana': 1}).hallo.nana == 1 + assert Nana4(hallo={'__typename': 'Bar', 'nana': 1}).hallo.nana == 1 + """, + ) + + + diff --git a/tests/utils.py b/tests/utils.py index 7154457..17f3776 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,10 +1,12 @@ -import subprocess -import os import ast +import os +import subprocess import sys +import tempfile +from textwrap import dedent from typing import List + from turms.run import write_code_to_file -import tempfile DIR_NAME = os.path.dirname(os.path.realpath(__file__)) @@ -65,7 +67,7 @@ def parse_to_code(tree: List[ast.AST]) -> str: def unit_test_with(generated_ast: List[ast.AST], test_string: str): - added_code = ast.parse(test_string).body + added_code = ast.parse(dedent(test_string)).body # We need to unparse before otherwise there might be complaints with missing lineno parsed_code = parse_to_code(generated_ast + added_code) diff --git a/turms/plugins/fragments.py b/turms/plugins/fragments.py index bdfb970..4ecc7a6 100644 --- a/turms/plugins/fragments.py +++ b/turms/plugins/fragments.py @@ -36,9 +36,7 @@ logger = logging.getLogger(__name__) -def find_fragment_dependencies_recursive( - selection_set: SelectionSetNode, fragment_definitions, visited -): +def find_fragment_dependencies_recursive(selection_set: SelectionSetNode, fragment_definitions, visited): """Recursively find all fragment dependencies within a selection set.""" dependencies = set() if selection_set is None: @@ -54,52 +52,43 @@ def find_fragment_dependencies_recursive( visited.add(spread_name) # Prevent cycles in recursion fragment = fragment_definitions[spread_name] dependencies.update( - find_fragment_dependencies_recursive( - fragment.selection_set, fragment_definitions, visited - ) + find_fragment_dependencies_recursive(fragment.selection_set, fragment_definitions, visited) ) # If it's a field with a nested selection set, dive deeper elif isinstance(selection, FieldNode) and selection.selection_set: dependencies.update( - find_fragment_dependencies_recursive( - selection.selection_set, fragment_definitions, visited - ) + find_fragment_dependencies_recursive(selection.selection_set, fragment_definitions, visited) ) - return dependencies + return dependencies def build_recursive_dependency_graph(document): """Build a dependency graph for fragments, accounting for deep nested fragment spreads.""" fragment_definitions = { - definition.name.value: definition - for definition in document.definitions + definition.name.value: definition for definition in document.definitions if isinstance(definition, FragmentDefinitionNode) } dependencies = defaultdict(set) - + # Populate the dependency graph with deeply nested fragment dependencies for fragment_name, fragment in fragment_definitions.items(): visited = set() # Track visited fragments to avoid cyclic dependencies - dependencies[fragment_name] = find_fragment_dependencies_recursive( - fragment.selection_set, fragment_definitions, visited - ) - + dependencies[fragment_name] = find_fragment_dependencies_recursive(fragment.selection_set, fragment_definitions, visited) + return dependencies def topological_sort(dependency_graph): """Perform a topological sort on fragments based on recursive dependencies.""" sorted_fragments = [] - no_dependency_fragments = deque( - [frag for frag, deps in dependency_graph.items() if not deps] - ) + no_dependency_fragments = deque([frag for frag, deps in dependency_graph.items() if not deps]) resolved = set(no_dependency_fragments) - + while no_dependency_fragments: fragment = no_dependency_fragments.popleft() sorted_fragments.append(fragment) - + # Remove this fragment from other fragments' dependencies for frag, deps in dependency_graph.items(): if fragment in deps: @@ -107,12 +96,10 @@ def topological_sort(dependency_graph): if not deps and frag not in resolved: no_dependency_fragments.append(frag) resolved.add(frag) - + # Add any remaining fragments that may have been missed if they were independent - sorted_fragments.extend( - frag for frag in dependency_graph if frag not in sorted_fragments - ) - + sorted_fragments.extend(frag for frag in dependency_graph if frag not in sorted_fragments) + return sorted_fragments @@ -156,6 +143,7 @@ def get_implementing_types(type: GraphQLInterfaceType, client_schema: GraphQLSch return implementing_types + def generate_fragment( f: FragmentDefinitionNode, client_schema: GraphQLSchema, @@ -176,12 +164,15 @@ def generate_fragment( registry.register_fragment_type(f.name.value, type) + + + if isinstance(type, GraphQLInterfaceType): implementing_types = client_schema.get_implementations(type) - + mother_class_fields = [] - base_fragment_name = registry.style_fragment_class(f.name.value) + base_fragment_name = registry.style_fragment_class(f.name.value) additional_bases = get_additional_bases_for_type(type.name, config, registry) if type.description and plugin_config.add_documentation: @@ -193,29 +184,30 @@ def generate_fragment( mother_class_name = base_fragment_name + "Base" - implementing_class_base_classes = {} + + implementing_class_base_classes = { + } + inline_fragment_fields = {} + + for sub_node in sub_nodes: if isinstance(sub_node, FragmentSpreadNode): # Spread nodes are like inheritance? try: # We are dealing with a fragment that is an interface - implementation_map = ( - registry.get_interface_fragment_implementations( - sub_node.name.value - ) - ) + implementation_map = registry.get_interface_fragment_implementations(sub_node.name.value) for k, v in implementation_map.items(): implementing_class_base_classes.setdefault(k, []).append(v) except KeyError: x = registry.get_fragment_type(sub_node.name.value) - implementing_class_base_classes.setdefault(x, []).append( - registry.inherit_fragment(sub_node.name.value) - ) + implementing_class_base_classes.setdefault(x, []).append(registry.inherit_fragment(sub_node.name.value)) + + if isinstance(sub_node, FieldNode): @@ -239,10 +231,10 @@ def generate_fragment( ) ) + mother_class = ast.ClassDef( mother_class_name, - bases=additional_bases - + get_interface_bases(config, registry), # Todo: fill with base + bases=additional_bases + get_interface_bases(config, registry) , # Todo: fill with base decorator_list=[], keywords=[], body=mother_class_fields if mother_class_fields else [ast.Pass()], @@ -252,58 +244,57 @@ def generate_fragment( catch_class = ast.ClassDef( catch_class_name, - bases=[ - ast.Name(id=mother_class_name, ctx=ast.Load()) - ], # Todo: fill with base + bases=[ast.Name(id=mother_class_name, ctx=ast.Load())], # Todo: fill with base decorator_list=[], keywords=[], - body=[generate_generic_typename_field(registry, config)] - + mother_class_fields, + body=[generate_generic_typename_field(registry, config)] + mother_class_fields, ) + + tree.append(mother_class) tree.append(catch_class) - implementationMap = {} + + + implementaionMap = {} for i in implementing_types.objects: class_name = f"{base_fragment_name}{i.name}" - ast_base_nodes = [ - ast.Name(id=x, ctx=ast.Load()) - for x in implementing_class_base_classes.get(i, []) - ] - implementationMap[i.name] = class_name + + ast_base_nodes = [ast.Name(id=x, ctx=ast.Load()) for x in implementing_class_base_classes.get(i, [])] + implementaionMap[i.name] = class_name inline_fields = inline_fragment_fields.get(i, []) implementing_class = ast.ClassDef( - class_name, - bases=ast_base_nodes - + [ast.Name(id=mother_class_name, ctx=ast.Load())] - + get_interface_bases(config, registry), # Todo: fill with base - decorator_list=[], - keywords=[], - body=[generate_typename_field(i.name, registry, config)] - + inline_fields, + class_name, + bases=ast_base_nodes + [ast.Name(id=mother_class_name, ctx=ast.Load())] + get_interface_bases(config, registry), # Todo: fill with base + decorator_list=[], + keywords=[], + body=[generate_typename_field(i.name, registry, config)] + inline_fields, ) tree.append(implementing_class) - registry.register_interface_fragment_implementations( - f.name.value, implementationMap - ) + + registry.register_interface_fragment_implementations(f.name.value, implementaionMap) + return tree + elif isinstance(type, GraphQLObjectType): additional_bases = get_additional_bases_for_type( f.type_condition.name.value, config, registry ) if type.description and plugin_config.add_documentation: - fields.append(ast.Expr(value=ast.Constant(value=type.description))) + fields.append( + ast.Expr(value=ast.Constant(value=type.description)) + ) fields += [generate_typename_field(type.name, registry, config)] @@ -318,9 +309,7 @@ def generate_fragment( if isinstance(field, FragmentSpreadNode): try: - implementationMap = registry.get_interface_fragment_implementations( - field.name.value - ) + implementationMap = registry.get_interface_fragment_implementations(field.name.value) if type.name in implementationMap: additional_bases = [ ast.Name( @@ -329,9 +318,7 @@ def generate_fragment( ) ] + additional_bases else: - raise Exception( - f"Could not find implementation for {type.name} in {implementationMap}" - ) + raise Exception(f"Could not find implementation for {type.name} in {implementationMap}") except KeyError: additional_bases = [ ast.Name( @@ -375,29 +362,24 @@ def generate_fragment( mother_class_name = base_fragment_name # + "Base" - implementing_class_base_classes = {} + member_class_base_classes = {} inline_fragment_fields = {} for sub_node in sub_nodes: if isinstance(sub_node, FragmentSpreadNode): - # Spread nodes are like inheritance? - try: - # We are dealing with a fragment that is a union - implementation_map = registry.get_union_fragment_implementations( - sub_node.name.value - ) - for k, v in implementation_map.items(): - implementing_class_base_classes.setdefault(k, []).append(v) - - except KeyError: - x = registry.get_fragment_type(sub_node.name.value) - implementing_class_base_classes.setdefault(x, []).append( - registry.inherit_fragment(sub_node.name.value) - ) + # Spread nodes are like inheritance + # We are dealing with a fragment that is a union + implementation_map = registry.get_union_fragment_members( + sub_node.name.value + ) + for k, v in implementation_map.items(): + member_class_base_classes.setdefault(k, []).append(v) elif isinstance(sub_node, FieldNode): + if sub_node.name.value == "__typename": + continue raise AssertionError("Union types should not have fields") elif isinstance(sub_node, InlineFragmentNode): @@ -410,30 +392,9 @@ def generate_fragment( continue if isinstance(field, FragmentSpreadNode): - try: - implementationMap = ( - registry.get_interface_fragment_implementations( - field.name.value - ) - ) - if type.name in implementationMap: - additional_bases = [ - ast.Name( - id=implementationMap[type.name], - ctx=ast.Load(), - ) - ] + additional_bases - else: - raise Exception( - f"Could not find implementation for {type.name} in {implementationMap}" - ) - except KeyError: - additional_bases = [ - ast.Name( - id=registry.inherit_fragment(field.name.value), - ctx=ast.Load(), - ) - ] + additional_bases # needs to be prepended (MRO) + member_class_base_classes.setdefault(on_type_name, []).append( + field.name.value + ) continue field_type = client_schema.get_type(on_type_name) @@ -467,7 +428,7 @@ def generate_fragment( ast_base_nodes = [ ast.Name(id=x, ctx=ast.Load()) - for x in implementing_class_base_classes.get(i, []) + for x in member_class_base_classes.get(i.name, []) ] implementationMap[i.name] = class_name @@ -485,15 +446,13 @@ def generate_fragment( tree.append(implementing_class) - registry.register_union_fragment_implementations( + registry.register_union_fragment_members( f.name.value, implementationMap ) - registry.register_import("typing.TypeAlias") registry.register_import("typing.Union") - mother_class = ast.AnnAssign( - target=ast.Name(id=base_fragment_name, ctx=ast.Load()), - annotation=ast.Name(id="TypeAlias", ctx=ast.Load()), + mother_class = ast.Assign( + targets=[ast.Name(id=base_fragment_name, ctx=ast.Load())], value=ast.Subscript( value=ast.Name(id="Union", ctx=ast.Load()), slice=ast.Tuple( @@ -519,18 +478,10 @@ def generate_fragment( def reorder_definitions(document, sorted_fragments): """Reorder document definitions to place fragments in dependency order.""" - fragment_definitions = { - defn.name.value: defn - for defn in document.definitions - if isinstance(defn, FragmentDefinitionNode) - } + fragment_definitions = {defn.name.value: defn for defn in document.definitions if isinstance(defn, FragmentDefinitionNode)} # Order fragments according to the topologically sorted order - ordered_fragments = [ - fragment_definitions[name] - for name in sorted_fragments - if name in fragment_definitions - ] + ordered_fragments = [fragment_definitions[name] for name in sorted_fragments if name in fragment_definitions] # Combine operations and ordered fragments return ordered_fragments @@ -567,11 +518,16 @@ def generate_ast( # Find dependencies and sort fragments topologically fragment_dependencies = build_recursive_dependency_graph(documents) - + sorted_fragments = topological_sort(fragment_dependencies) + + ordered_fragments = reorder_definitions(documents, sorted_fragments) + + + for fragment in ordered_fragments: plugin_tree += generate_fragment( fragment, client_schema, config, self.config, registry diff --git a/turms/referencer.py b/turms/referencer.py index 48e418d..345e2f5 100644 --- a/turms/referencer.py +++ b/turms/referencer.py @@ -81,7 +81,7 @@ def recurse_find_references( continue field_type = sub_sub_node_type.fields[sub_sub_node.name.value] - return recurse_find_references( + recurse_find_references( sub_sub_node, field_type.type, client_schema, diff --git a/turms/registry.py b/turms/registry.py index 8c8d2d1..7feee48 100644 --- a/turms/registry.py +++ b/turms/registry.py @@ -1,7 +1,8 @@ import ast +from keyword import iskeyword from typing import Dict, List + from turms.config import GeneratorConfig, LogFunction -from keyword import iskeyword from turms.errors import ( NoEnumFound, NoInputTypeFound, @@ -107,7 +108,7 @@ def __init__( self.fragment_type_map = {} self.interfacefragments_impl_map = {} - self.unionfragments_impl_map = {} + self.unionfragment_members_map = {} self.log = log def style_inputtype_class(self, typename: str): @@ -282,14 +283,14 @@ def register_interface_fragment_implementations(self, fragmentname: str, impleme def get_interface_fragment_implementations(self, fragmentname: str): return self.interfacefragments_impl_map[fragmentname] - - def register_union_fragment_implementations(self, fragmentname: str, implementationMap: Dict[str, str]): - self.unionfragments_impl_map[fragmentname] = implementationMap + def register_union_fragment_members(self, fragmentname: str, membersMap: Dict[str, str]): + self.unionfragment_members_map[fragmentname] = membersMap - def get_union_fragment_implementations(self, fragmentname: str): - return self.unionfragments_impl_map[fragmentname] + + def get_union_fragment_members(self, fragmentname: str): + return self.unionfragment_members_map[fragmentname] def get_fragment_type(self, fragmentname: str): @@ -318,15 +319,6 @@ def reference_interface_fragment(self, typename: str, parent: str, allow_forward def register_interface_fragment(self, typename: str, ast: ast.AST): self.registered_interfaces_fragments[typename] = ast - - def is_union_fragment(self, typename: str): - return typename in self.registered_union_fragments - - def reference_union_fragment(self, typename: str, parent: str, allow_forward=True) -> ast.AST: - return self.registered_union_fragments[typename] - - def register_union_fragment(self, typename: str, ast: ast.AST): - self.registered_union_fragments[typename] = ast def inherit_fragment(self, typename: str, allow_forward=True) -> ast.AST: if typename not in self.fragment_class_map: diff --git a/turms/utils.py b/turms/utils.py index e150e87..e39bec3 100644 --- a/turms/utils.py +++ b/turms/utils.py @@ -1,16 +1,14 @@ +import ast import glob import re from typing import List, Optional, Set, Union -from turms.config import GeneratorConfig -from turms.errors import GenerationError -from graphql.utilities.build_client_schema import GraphQLSchema -from graphql.language.ast import DocumentNode, FieldNode, NameNode -from graphql.error.graphql_error import GraphQLError + from graphql import ( BooleanValueNode, FloatValueNode, FragmentDefinitionNode, GraphQLEnumType, + GraphQLInterfaceType, GraphQLList, GraphQLNonNull, GraphQLObjectType, @@ -29,19 +27,21 @@ parse, print_ast, validate, - GraphQLInterfaceType, ) -import ast -from turms.registry import ClassRegistry +from graphql.error.graphql_error import GraphQLError +from graphql.language.ast import DocumentNode, FieldNode, NameNode +from graphql.utilities.build_client_schema import GraphQLSchema + +from turms.config import GeneratorConfig from turms.errors import ( GenerationError, NoEnumFound, NoInputTypeFound, NoScalarFound, ) -from .config import GraphQLTypes -import re +from turms.registry import ClassRegistry +from .config import GraphQLTypes commentline_regex = re.compile(r"^.*#(.*)") @@ -488,7 +488,7 @@ def auto_add_typename_field_to_fragment_str(fragment_str: str) -> str: for fragment in x.definitions: if isinstance(fragment, FragmentDefinitionNode): selections = list(fragment.selection_set.selections) - if not any(field.name.value == "__typename" for field in selections): + if not any(isinstance(field, FieldNode) and field.name.value == "__typename" for field in selections): selections.append( FieldNode( name=NameNode(value="__typename"),