package org.springframework.data.web.querydsl;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Map.Entry;
import java.util.Optional;
import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType;
import org.springframework.core.convert.ConversionService;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.data.querydsl.binding.QuerydslBinderCustomizer;
import org.springframework.data.querydsl.binding.QuerydslBindings;
import org.springframework.data.querydsl.binding.QuerydslBindingsFactory;
import org.springframework.data.querydsl.binding.QuerydslPredicate;
import org.springframework.data.querydsl.binding.QuerydslPredicateBuilder;
import org.springframework.data.util.CastUtils;
import org.springframework.data.util.ClassTypeInformation;
import org.springframework.data.util.TypeInformation;
import org.springframework.lang.Nullable;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.bind.support.WebDataBinderFactory;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.method.support.HandlerMethodArgumentResolver;
import org.springframework.web.method.support.ModelAndViewContainer;
import com.querydsl.core.BooleanBuilder;
import com.querydsl.core.types.Predicate;
public class QuerydslPredicateArgumentResolver implements HandlerMethodArgumentResolver {
private static final ResolvableType PREDICATE = ResolvableType.forClass(Predicate.class);
private static final ResolvableType OPTIONAL_OF_PREDICATE = ResolvableType.forClassWithGenerics(Optional.class,
PREDICATE);
private final QuerydslBindingsFactory bindingsFactory;
private final QuerydslPredicateBuilder predicateBuilder;
public QuerydslPredicateArgumentResolver(QuerydslBindingsFactory factory,
Optional<ConversionService> conversionService) {
this.bindingsFactory = factory;
this.predicateBuilder = new QuerydslPredicateBuilder(conversionService.orElseGet(DefaultConversionService::new),
factory.getEntityPathResolver());
}
@Override
public boolean supportsParameter(MethodParameter parameter) {
ResolvableType type = ResolvableType.forMethodParameter(parameter);
if (PREDICATE.isAssignableFrom(type) || OPTIONAL_OF_PREDICATE.isAssignableFrom(type)) {
return true;
}
if (parameter.hasParameterAnnotation(QuerydslPredicate.class)) {
throw new IllegalArgumentException(String.format("Parameter at position %s must be of type Predicate but was %s.",
parameter.getParameterIndex(), parameter.getParameterType()));
}
return false;
}
@Nullable
@Override
public Object resolveArgument(MethodParameter parameter, @Nullable ModelAndViewContainer mavContainer,
NativeWebRequest webRequest, @Nullable WebDataBinderFactory binderFactory) throws Exception {
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
for (Entry<String, String[]> entry : webRequest.getParameterMap().entrySet()) {
parameters.put(entry.getKey(), Arrays.asList(entry.getValue()));
}
Optional<QuerydslPredicate> annotation = Optional
.ofNullable(parameter.getParameterAnnotation(QuerydslPredicate.class));
TypeInformation<?> domainType = extractTypeInfo(parameter).getRequiredActualType();
Optional<Class<? extends QuerydslBinderCustomizer<?>>> bindingsAnnotation = annotation
.map(QuerydslPredicate::bindings)
.map(CastUtils::cast);
QuerydslBindings bindings = bindingsAnnotation
.map(it -> bindingsFactory.createBindingsFor(domainType, it))
.orElseGet(() -> bindingsFactory.createBindingsFor(domainType));
Predicate result = predicateBuilder.getPredicate(domainType, parameters, bindings);
if (!parameter.isOptional() && result == null) {
return new BooleanBuilder();
}
return OPTIONAL_OF_PREDICATE.isAssignableFrom(ResolvableType.forMethodParameter(parameter))
? Optional.ofNullable(result)
: result;
}
static TypeInformation<?> (MethodParameter parameter) {
Optional<QuerydslPredicate> annotation = Optional
.ofNullable(parameter.getParameterAnnotation(QuerydslPredicate.class));
return annotation.filter(it -> !Object.class.equals(it.root()))
.<TypeInformation<?>> map(it -> ClassTypeInformation.from(it.root()))
.orElseGet(() -> detectDomainType(parameter));
}
private static TypeInformation<?> detectDomainType(MethodParameter parameter) {
Method method = parameter.getMethod();
if (method == null) {
throw new IllegalArgumentException("Method parameter is not backed by a method!");
}
return detectDomainType(ClassTypeInformation.fromReturnTypeOf(method));
}
private static TypeInformation<?> detectDomainType(TypeInformation<?> source) {
if (source.getTypeArguments().isEmpty()) {
return source;
}
TypeInformation<?> actualType = source.getActualType();
if (actualType == null) {
throw new IllegalArgumentException(String.format("Could not determine domain type from %s!", source));
}
if (source != actualType) {
return detectDomainType(actualType);
}
if (source instanceof Iterable) {
return source;
}
return detectDomainType(source.getRequiredComponentType());
}
}