@@ -76,7 +76,7 @@ class SchemaParser internal constructor(
7676 val inputObjects: MutableList <GraphQLInputObjectType > = mutableListOf ()
7777 inputObjectDefinitions.forEach {
7878 if (inputObjects.none { io -> io.name == it.name }) {
79- inputObjects.add(createInputObject(it, inputObjects))
79+ inputObjects.add(createInputObject(it, inputObjects, mutableSetOf () ))
8080 }
8181 }
8282 val interfaces = interfaceDefinitions.map { createInterfaceObject(it, inputObjects) }
@@ -155,7 +155,8 @@ class SchemaParser internal constructor(
155155 return schemaGeneratorDirectiveHelper.onObject(objectType, directiveHelperParameters)
156156 }
157157
158- private fun createInputObject (definition : InputObjectTypeDefinition , inputObjects : List <GraphQLInputObjectType >): GraphQLInputObjectType {
158+ private fun createInputObject (definition : InputObjectTypeDefinition , inputObjects : List <GraphQLInputObjectType >,
159+ referencingInputObjects : MutableSet <String >): GraphQLInputObjectType {
159160 val extensionDefinitions = inputExtensionDefinitions.filter { it.name == definition.name }
160161
161162 val builder = GraphQLInputObjectType .newInputObject()
@@ -166,14 +167,16 @@ class SchemaParser internal constructor(
166167
167168 builder.withDirectives(* buildDirectives(definition.directives, Introspection .DirectiveLocation .INPUT_OBJECT ))
168169
170+ referencingInputObjects.add(definition.name)
171+
169172 (extensionDefinitions + definition).forEach {
170173 it.inputValueDefinitions.forEach { inputDefinition ->
171174 val fieldBuilder = GraphQLInputObjectField .newInputObjectField()
172175 .name(inputDefinition.name)
173176 .definition(inputDefinition)
174177 .description(if (inputDefinition.description != null ) inputDefinition.description.content else getDocumentation(inputDefinition))
175178 .defaultValue(buildDefaultValue(inputDefinition.defaultValue))
176- .type(determineInputType(inputDefinition.type, inputObjects))
179+ .type(determineInputType(inputDefinition.type, inputObjects, referencingInputObjects ))
177180 .withDirectives(* buildDirectives(inputDefinition.directives, Introspection .DirectiveLocation .INPUT_FIELD_DEFINITION ))
178181 builder.field(fieldBuilder.build())
179182 }
@@ -280,7 +283,7 @@ class SchemaParser internal constructor(
280283 .name(argumentDefinition.name)
281284 .definition(argumentDefinition)
282285 .description(if (argumentDefinition.description != null ) argumentDefinition.description.content else getDocumentation(argumentDefinition))
283- .type(determineInputType(argumentDefinition.type, inputObjects))
286+ .type(determineInputType(argumentDefinition.type, inputObjects, setOf () ))
284287 .apply { buildDefaultValue(argumentDefinition.defaultValue)?.let { defaultValue(it) } }
285288 .withDirectives(* buildDirectives(argumentDefinition.directives, Introspection .DirectiveLocation .ARGUMENT_DEFINITION ))
286289
@@ -380,7 +383,7 @@ class SchemaParser internal constructor(
380383 is NonNullType -> GraphQLNonNull (determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
381384 is InputObjectTypeDefinition -> {
382385 log.info(" Create input object" )
383- createInputObject(typeDefinition, inputObjects)
386+ createInputObject(typeDefinition, inputObjects, mutableSetOf () )
384387 }
385388 is TypeName -> {
386389 val scalarType = customScalars[typeDefinition.name]
@@ -398,16 +401,19 @@ class SchemaParser internal constructor(
398401 else -> throw SchemaError (" Unknown type: $typeDefinition " )
399402 }
400403
401- private fun determineInputType (typeDefinition : Type <* >, inputObjects : List <GraphQLInputObjectType >) =
402- determineInputType(GraphQLInputType ::class , typeDefinition, permittedTypesForInputObject, inputObjects) as GraphQLInputType
404+ private fun determineInputType (typeDefinition : Type <* >, inputObjects : List <GraphQLInputObjectType >, referencingInputObjects : Set < String > ) =
405+ determineInputType(GraphQLInputType ::class , typeDefinition, permittedTypesForInputObject, inputObjects, referencingInputObjects ) as GraphQLInputType
403406
404- private fun <T : Any > determineInputType (expectedType : KClass <T >, typeDefinition : Type <* >, allowedTypeReferences : Set <String >, inputObjects : List <GraphQLInputObjectType >): GraphQLType =
407+ private fun <T : Any > determineInputType (expectedType : KClass <T >,
408+ typeDefinition : Type <* >, allowedTypeReferences : Set <String >,
409+ inputObjects : List <GraphQLInputObjectType >,
410+ referencingInputObjects : Set <String >): GraphQLType =
405411 when (typeDefinition) {
406412 is ListType -> GraphQLList (determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
407413 is NonNullType -> GraphQLNonNull (determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
408414 is InputObjectTypeDefinition -> {
409415 log.info(" Create input object" )
410- createInputObject(typeDefinition, inputObjects)
416+ createInputObject(typeDefinition, inputObjects, referencingInputObjects as MutableSet < String > )
411417 }
412418 is TypeName -> {
413419 val scalarType = customScalars[typeDefinition.name]
@@ -425,9 +431,14 @@ class SchemaParser internal constructor(
425431 } else {
426432 val filteredDefinitions = inputObjectDefinitions.filter { it.name == typeDefinition.name }
427433 if (filteredDefinitions.isNotEmpty()) {
428- val inputObject = createInputObject(filteredDefinitions[0 ], inputObjects)
429- (inputObjects as MutableList ).add(inputObject)
430- inputObject
434+ val referencingInputObject = referencingInputObjects.find { it == typeDefinition.name }
435+ if (referencingInputObject != null ) {
436+ GraphQLTypeReference (referencingInputObject)
437+ } else {
438+ val inputObject = createInputObject(filteredDefinitions[0 ], inputObjects, referencingInputObjects as MutableSet <String >)
439+ (inputObjects as MutableList ).add(inputObject)
440+ inputObject
441+ }
431442 } else {
432443 // todo: handle enum type
433444 GraphQLTypeReference (typeDefinition.name)
0 commit comments