44import graphql .PublicApi ;
55import graphql .schema .GraphQLEnumType ;
66import graphql .schema .GraphQLFieldDefinition ;
7+ import graphql .schema .GraphQLFieldsContainer ;
78import graphql .schema .GraphQLImplementingType ;
89import graphql .schema .GraphQLInputObjectField ;
910import graphql .schema .GraphQLInputObjectType ;
1011import graphql .schema .GraphQLInterfaceType ;
11- import graphql .schema .GraphQLNamedSchemaElement ;
1212import graphql .schema .GraphQLNamedType ;
1313import graphql .schema .GraphQLObjectType ;
1414import graphql .schema .GraphQLSchema ;
2424
2525import java .util .ArrayList ;
2626import java .util .HashSet ;
27+ import java .util .LinkedHashSet ;
2728import java .util .List ;
2829import java .util .Map ;
2930import java .util .Objects ;
@@ -45,7 +46,9 @@ public class FieldVisibilitySchemaTransformation {
4546 private final Runnable afterTransformationHook ;
4647
4748 public FieldVisibilitySchemaTransformation (VisibleFieldPredicate visibleFieldPredicate ) {
48- this (visibleFieldPredicate , () -> {}, () -> {});
49+ this (visibleFieldPredicate , () -> {
50+ }, () -> {
51+ });
4952 }
5053
5154 public FieldVisibilitySchemaTransformation (VisibleFieldPredicate visibleFieldPredicate ,
@@ -155,40 +158,85 @@ private static class FieldRemovalVisitor extends GraphQLTypeVisitorStub {
155158 private final VisibleFieldPredicate visibilityPredicate ;
156159 private final Set <GraphQLType > removedTypes ;
157160
161+ private final Set <GraphQLFieldDefinition > fieldDefinitionsToActuallyRemove = new LinkedHashSet <>();
162+ private final Set <GraphQLInputObjectField > inputObjectFieldsToDelete = new LinkedHashSet <>();
163+
158164 private FieldRemovalVisitor (VisibleFieldPredicate visibilityPredicate ,
159165 Set <GraphQLType > removedTypes ) {
160166 this .visibilityPredicate = visibilityPredicate ;
161167 this .removedTypes = removedTypes ;
162168 }
163169
164170 @ Override
165- public TraversalControl visitGraphQLFieldDefinition (GraphQLFieldDefinition definition ,
166- TraverserContext <GraphQLSchemaElement > context ) {
167- return visitField (definition , context );
171+ public TraversalControl visitGraphQLObjectType (GraphQLObjectType objectType , TraverserContext <GraphQLSchemaElement > context ) {
172+ return visitFieldsContainer (objectType , context );
168173 }
169174
170175 @ Override
171- public TraversalControl visitGraphQLInputObjectField (GraphQLInputObjectField definition ,
172- TraverserContext <GraphQLSchemaElement > context ) {
173- return visitField (definition , context );
176+ public TraversalControl visitGraphQLInterfaceType (GraphQLInterfaceType objectType , TraverserContext <GraphQLSchemaElement > context ) {
177+ return visitFieldsContainer (objectType , context );
174178 }
175179
176- private TraversalControl visitField (GraphQLNamedSchemaElement element ,
177- TraverserContext <GraphQLSchemaElement > context ) {
178-
179- VisibleFieldPredicateEnvironment environment = new VisibleFieldPredicateEnvironmentImpl (
180- element , context .getParentNode ());
181- if (!visibilityPredicate .isVisible (environment )) {
182- deleteNode (context );
180+ private TraversalControl visitFieldsContainer (GraphQLFieldsContainer fieldsContainer , TraverserContext <GraphQLSchemaElement > context ) {
181+ boolean allFieldsDeleted = true ;
182+ for (GraphQLFieldDefinition fieldDefinition : fieldsContainer .getFieldDefinitions ()) {
183+ VisibleFieldPredicateEnvironment environment = new VisibleFieldPredicateEnvironmentImpl (
184+ fieldDefinition , fieldsContainer );
185+ if (!visibilityPredicate .isVisible (environment )) {
186+ fieldDefinitionsToActuallyRemove .add (fieldDefinition );
187+ removedTypes .add (fieldDefinition .getType ());
188+ } else {
189+ allFieldsDeleted = false ;
190+ }
191+ }
192+ if (allFieldsDeleted ) {
193+ // we are deleting the whole interface type because all fields are supposed to be deleted
194+ return deleteNode (context );
195+ } else {
196+ return TraversalControl .CONTINUE ;
197+ }
198+ }
183199
184- if (element instanceof GraphQLFieldDefinition ) {
185- removedTypes .add (((GraphQLFieldDefinition ) element ).getType ());
186- } else if (element instanceof GraphQLInputObjectField ) {
187- removedTypes .add (((GraphQLInputObjectField ) element ).getType ());
200+ @ Override
201+ public TraversalControl visitGraphQLInputObjectType (GraphQLInputObjectType inputObjectType , TraverserContext <GraphQLSchemaElement > context ) {
202+ boolean allFieldsDeleted = true ;
203+ for (GraphQLInputObjectField inputField : inputObjectType .getFieldDefinitions ()) {
204+ VisibleFieldPredicateEnvironment environment = new VisibleFieldPredicateEnvironmentImpl (
205+ inputField , inputObjectType );
206+ if (!visibilityPredicate .isVisible (environment )) {
207+ inputObjectFieldsToDelete .add (inputField );
208+ removedTypes .add (inputField .getType ());
209+ } else {
210+ allFieldsDeleted = false ;
188211 }
189212 }
213+ if (allFieldsDeleted ) {
214+ // we are deleting the whole input object type because all fields are supposed to be deleted
215+ return deleteNode (context );
216+ } else {
217+ return TraversalControl .CONTINUE ;
218+ }
190219
191- return TraversalControl .CONTINUE ;
220+ }
221+
222+ @ Override
223+ public TraversalControl visitGraphQLFieldDefinition (GraphQLFieldDefinition definition ,
224+ TraverserContext <GraphQLSchemaElement > context ) {
225+ if (fieldDefinitionsToActuallyRemove .contains (definition )) {
226+ return deleteNode (context );
227+ } else {
228+ return TraversalControl .CONTINUE ;
229+ }
230+ }
231+
232+ @ Override
233+ public TraversalControl visitGraphQLInputObjectField (GraphQLInputObjectField definition ,
234+ TraverserContext <GraphQLSchemaElement > context ) {
235+ if (inputObjectFieldsToDelete .contains (definition )) {
236+ return deleteNode (context );
237+ } else {
238+ return TraversalControl .CONTINUE ;
239+ }
192240 }
193241 }
194242
@@ -216,12 +264,12 @@ public TraversalControl visitGraphQLInterfaceType(GraphQLInterfaceType node,
216264 public TraversalControl visitGraphQLType (GraphQLSchemaElement node ,
217265 TraverserContext <GraphQLSchemaElement > context ) {
218266 if (observedBeforeTransform .contains (node ) &&
219- !observedAfterTransform .contains (node ) &&
220- (node instanceof GraphQLObjectType ||
221- node instanceof GraphQLEnumType ||
222- node instanceof GraphQLInputObjectType ||
223- node instanceof GraphQLInterfaceType ||
224- node instanceof GraphQLUnionType )) {
267+ !observedAfterTransform .contains (node ) &&
268+ (node instanceof GraphQLObjectType ||
269+ node instanceof GraphQLEnumType ||
270+ node instanceof GraphQLInputObjectType ||
271+ node instanceof GraphQLInterfaceType ||
272+ node instanceof GraphQLUnionType )) {
225273
226274 return deleteNode (context );
227275 }
0 commit comments