Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 154 additions & 60 deletions graphql-apt/src/main/java/feign/graphql/apt/GraphqlSchemaProcessor.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,16 @@

import com.google.auto.service.AutoService;
import com.squareup.javapoet.TypeName;
import feign.graphql.GraphqlField;
import feign.graphql.GraphqlQuery;
import feign.graphql.GraphqlSchema;
import feign.graphql.Scalar;
import feign.graphql.Toggle;
import graphql.language.Document;
import graphql.language.Field;
import graphql.language.FieldDefinition;
import graphql.language.ListType;
import graphql.language.NonNullType;
import graphql.language.ObjectTypeDefinition;
import graphql.language.OperationDefinition;
import graphql.language.SelectionSet;
import graphql.language.Type;
import graphql.language.VariableDefinition;
import graphql.parser.Parser;
import graphql.schema.GraphQLSchema;
Expand All @@ -39,6 +37,8 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.Supplier;
import javax.annotation.processing.AbstractProcessor;
import javax.annotation.processing.Filer;
import javax.annotation.processing.Messager;
Expand All @@ -48,10 +48,13 @@
import javax.annotation.processing.SupportedAnnotationTypes;
import javax.annotation.processing.SupportedSourceVersion;
import javax.lang.model.SourceVersion;
import javax.lang.model.element.Element;
import javax.lang.model.element.ExecutableElement;
import javax.lang.model.element.PackageElement;
import javax.lang.model.element.TypeElement;
import javax.lang.model.type.DeclaredType;
import javax.lang.model.type.MirroredTypeException;
import javax.lang.model.type.MirroredTypesException;
import javax.lang.model.type.TypeKind;
import javax.lang.model.type.TypeMirror;
import javax.tools.Diagnostic;
Expand Down Expand Up @@ -117,6 +120,9 @@ private void processInterface(TypeElement typeElement) {
var validator = new QueryValidator(messager);
var generator = new TypeGenerator(filer, messager, registry, typeMapper, targetPackage);

var classFieldAnnotations = extractFieldAnnotations(typeElement);
var classConfig = resolveClassConfig(schemaAnnotation, classFieldAnnotations);

for (var enclosed : typeElement.getEnclosedElements()) {
if (!(enclosed instanceof ExecutableElement method)) {
continue;
Expand All @@ -126,6 +132,9 @@ private void processInterface(TypeElement typeElement) {
continue;
}

var methodConfig = resolveMethodConfig(method, queryAnnotation, classConfig);
generator.setAnnotationConfig(methodConfig);

processMethod(
method,
queryAnnotation,
Expand Down Expand Up @@ -242,9 +251,9 @@ private void processMethod(
if (rootType != null) {
var rootField = findRootField(operation.getSelectionSet());
if (rootField != null && rootField.getSelectionSet() != null) {
var rootFieldDef = findFieldDefinition(rootType, rootField.getName());
var rootFieldDef = GraphqlTypeMapper.findFieldDefinition(rootType, rootField.getName());
if (rootFieldDef != null) {
var fieldTypeName = unwrapTypeName(rootFieldDef.getType());
var fieldTypeName = GraphqlTypeMapper.unwrapTypeName(rootFieldDef.getType());
var fieldObjectType =
registry.getType(fieldTypeName, ObjectTypeDefinition.class).orElse(null);
if (fieldObjectType != null) {
Expand Down Expand Up @@ -297,77 +306,39 @@ private Field findRootField(SelectionSet selectionSet) {
return null;
}

private FieldDefinition findFieldDefinition(ObjectTypeDefinition typeDef, String fieldName) {
for (var fd : typeDef.getFieldDefinitions()) {
if (fd.getName().equals(fieldName)) {
return fd;
}
}
return null;
}

private ObjectTypeDefinition getRootType(
OperationDefinition operation, TypeDefinitionRegistry registry) {
var rootTypeName =
var operationName =
switch (operation.getOperation()) {
case MUTATION ->
registry
.schemaDefinition()
.flatMap(
sd ->
sd.getOperationTypeDefinitions().stream()
.filter(otd -> otd.getName().equals("mutation"))
.findFirst())
.map(otd -> otd.getTypeName().getName())
.orElse("Mutation");
case SUBSCRIPTION ->
registry
.schemaDefinition()
.flatMap(
sd ->
sd.getOperationTypeDefinitions().stream()
.filter(otd -> otd.getName().equals("subscription"))
.findFirst())
.map(otd -> otd.getTypeName().getName())
.orElse("Subscription");
default ->
registry
.schemaDefinition()
.flatMap(
sd ->
sd.getOperationTypeDefinitions().stream()
.filter(otd -> otd.getName().equals("query"))
.findFirst())
.map(otd -> otd.getTypeName().getName())
.orElse("Query");
case MUTATION -> "mutation";
case SUBSCRIPTION -> "subscription";
default -> "query";
};
var fallback = Character.toUpperCase(operationName.charAt(0)) + operationName.substring(1);
var rootTypeName =
registry
.schemaDefinition()
.flatMap(
sd ->
sd.getOperationTypeDefinitions().stream()
.filter(otd -> otd.getName().equals(operationName))
.findFirst())
.map(otd -> otd.getTypeName().getName())
.orElse(fallback);
return registry.getType(rootTypeName, ObjectTypeDefinition.class).orElse(null);
}

private String findGraphqlInputType(
String javaParamTypeName, List<VariableDefinition> variableDefs) {
for (var varDef : variableDefs) {
var graphqlTypeName = unwrapTypeName(varDef.getType());
var graphqlTypeName = GraphqlTypeMapper.unwrapTypeName(varDef.getType());
if (graphqlTypeName.equals(javaParamTypeName)) {
return graphqlTypeName;
}
}
return javaParamTypeName;
}

private String unwrapTypeName(Type<?> type) {
if (type instanceof NonNullType nullType) {
return unwrapTypeName(nullType.getType());
}
if (type instanceof ListType listType) {
return unwrapTypeName(listType.getType());
}
if (type instanceof graphql.language.TypeName name) {
return name.getName();
}
return "String";
}

private static final Set<String> JAVA_BUILT_INS =
Set.of(
"String",
Expand Down Expand Up @@ -434,6 +405,129 @@ private TypeMirror unwrapListTypeMirror(TypeMirror typeMirror) {
return typeMirror;
}

private TypeAnnotationConfig resolveClassConfig(
GraphqlSchema annotation,
Map<String, TypeAnnotationConfig.FieldAnnotations> classFieldAnnotations) {
var fqns = extractClassFqns(annotation::typeAnnotations);
var rawAnnotations = annotation.rawTypeAnnotations();
var usesFqns = extractClassFqns(annotation::uses);
var nonNullFqns = extractClassFqns(annotation::nonNullTypeAnnotations);
var nonNullRaw = annotation.nonNullRawTypeAnnotations();
var config =
TypeAnnotationConfig.resolve(
fqns,
rawAnnotations,
annotation.useOptional(),
classFieldAnnotations,
nonNullFqns,
nonNullRaw);
if (usesFqns.isEmpty()) {
return config;
}
var mergedImports = new TreeSet<>(config.imports());
for (var fqn : usesFqns) {
if (!fqn.startsWith("java.lang.")) {
mergedImports.add(fqn);
}
}
return new TypeAnnotationConfig(
mergedImports,
config.annotations(),
config.useOptional(),
config.fieldAnnotations(),
config.nonNullAnnotations());
}

private static List<String> extractClassFqns(Supplier<Class<?>[]> accessor) {
try {
var classes = accessor.get();
return java.util.Arrays.stream(classes).map(Class::getCanonicalName).toList();
} catch (MirroredTypesException e) {
return e.getTypeMirrors().stream().map(TypeMirror::toString).toList();
}
}

private TypeAnnotationConfig resolveMethodConfig(
ExecutableElement method, GraphqlQuery annotation, TypeAnnotationConfig classConfig) {
var methodFqns = extractClassFqns(annotation::typeAnnotations);
var methodRaw = annotation.rawTypeAnnotations();
var methodToggle = annotation.useOptional();

var useOptional =
methodToggle == Toggle.INHERIT ? classConfig.useOptional() : methodToggle == Toggle.TRUE;

var methodFieldAnnotations = extractFieldAnnotations(method);
var fieldAnnotations =
TypeAnnotationConfig.FieldAnnotations.merge(
classConfig.fieldAnnotations(), methodFieldAnnotations);

var methodNonNullFqns = extractClassFqns(annotation::nonNullTypeAnnotations);
var methodNonNullRaw = annotation.nonNullRawTypeAnnotations();
boolean hasMethodNonNull = !methodNonNullFqns.isEmpty() || methodNonNullRaw.length > 0;
var resolvedNonNull = hasMethodNonNull ? null : classConfig.nonNullAnnotations();

boolean hasMethodAnnotations = !methodFqns.isEmpty() || methodRaw.length > 0;
if (!hasMethodAnnotations && !hasMethodNonNull) {
if (useOptional == classConfig.useOptional()
&& fieldAnnotations.equals(classConfig.fieldAnnotations())) {
return classConfig;
}
var mergedImports = new TreeSet<>(classConfig.imports());
for (var fa : fieldAnnotations.values()) {
mergedImports.addAll(fa.imports());
}
return new TypeAnnotationConfig(
mergedImports,
classConfig.annotations(),
useOptional,
fieldAnnotations,
classConfig.nonNullAnnotations());
}

var nonNullFqns = hasMethodNonNull ? methodNonNullFqns : List.<String>of();
var nonNullRaw = hasMethodNonNull ? methodNonNullRaw : new String[0];
var config =
TypeAnnotationConfig.resolve(
methodFqns, methodRaw, useOptional, fieldAnnotations, nonNullFqns, nonNullRaw);

if (resolvedNonNull != null && !resolvedNonNull.isEmpty()) {
var mergedImports = new TreeSet<>(config.imports());
mergedImports.addAll(classConfig.imports());
return new TypeAnnotationConfig(
mergedImports, config.annotations(), useOptional, fieldAnnotations, resolvedNonNull);
}

return config;
}

private Map<String, TypeAnnotationConfig.FieldAnnotations> extractFieldAnnotations(
Element method) {
var fieldAnnotations = new HashMap<String, TypeAnnotationConfig.FieldAnnotations>();
var graphqlFields = method.getAnnotationsByType(GraphqlField.class);
for (var gf : graphqlFields) {
var fqns = extractClassFqns(gf::typeAnnotations);
var typeOverride = extractFieldTypeOverride(gf);
var resolved =
TypeAnnotationConfig.FieldAnnotations.resolve(
fqns, gf.rawTypeAnnotations(), typeOverride);
fieldAnnotations.put(gf.name(), resolved);
}
return fieldAnnotations;
}

private String extractFieldTypeOverride(GraphqlField annotation) {
try {
var cls = annotation.type();
if (cls == Void.class) {
return null;
}
return cls.getCanonicalName();
} catch (MirroredTypeException e) {
var fqn = e.getTypeMirror().toString();
return "java.lang.Void".equals(fqn) ? null : fqn;
}
}

private String getPackageName(TypeElement typeElement) {
var enclosing = typeElement.getEnclosingElement();
while (enclosing != null && !(enclosing instanceof PackageElement)) {
Expand Down
42 changes: 40 additions & 2 deletions graphql-apt/src/main/java/feign/graphql/apt/GraphqlTypeMapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.ParameterizedTypeName;
import com.squareup.javapoet.TypeName;
import graphql.language.FieldDefinition;
import graphql.language.ListType;
import graphql.language.NonNullType;
import graphql.language.ObjectTypeDefinition;
import graphql.language.Type;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class GraphqlTypeMapper {

Expand All @@ -44,11 +47,24 @@ public GraphqlTypeMapper(String targetPackage, Map<String, TypeName> customScala
}

public TypeName map(Type<?> type) {
return map(type, false);
}

public TypeName map(Type<?> type, boolean useOptional) {
boolean nullable = !(type instanceof NonNullType);
var mapped = mapInner(type);
if (useOptional && nullable) {
return ParameterizedTypeName.get(ClassName.get(Optional.class), mapped);
}
return mapped;
}

private TypeName mapInner(Type<?> type) {
if (type instanceof NonNullType nullType) {
return map(nullType.getType());
return mapInner(nullType.getType());
}
if (type instanceof ListType listType) {
var elementType = map(listType.getType());
var elementType = mapInner(listType.getType());
return ParameterizedTypeName.get(ClassName.get(List.class), elementType);
}
if (type instanceof graphql.language.TypeName name) {
Expand All @@ -72,4 +88,26 @@ private TypeName mapScalarOrNamed(String name) {
public boolean isScalar(String name) {
return BUILT_IN_SCALARS.containsKey(name) || customScalars.containsKey(name);
}

static String unwrapTypeName(Type<?> type) {
if (type instanceof NonNullType nullType) {
return unwrapTypeName(nullType.getType());
}
if (type instanceof ListType listType) {
return unwrapTypeName(listType.getType());
}
if (type instanceof graphql.language.TypeName name) {
return name.getName();
}
return "String";
}

static FieldDefinition findFieldDefinition(ObjectTypeDefinition typeDef, String fieldName) {
for (var fd : typeDef.getFieldDefinitions()) {
if (fd.getName().equals(fieldName)) {
return fd;
}
}
return null;
}
}
Loading