package org.jboss.resteasy.plugins.providers.jaxb;
import org.jboss.resteasy.annotations.providers.jaxb.DoNotUseJAXBProvider;
import org.jboss.resteasy.plugins.providers.jaxb.i18n.LogMessages;
import org.jboss.resteasy.plugins.providers.jaxb.i18n.Messages;
import org.jboss.resteasy.spi.util.FindAnnotation;
import org.xml.sax.InputSource;
import javax.ws.rs.Consumes;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.ext.ContextResolver;
import javax.ws.rs.ext.Provider;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBElement;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Unmarshaller;
import javax.xml.bind.annotation.XmlRegistry;
import javax.xml.bind.annotation.XmlRootElement;
import javax.xml.bind.annotation.XmlType;
import javax.xml.transform.sax.SAXSource;
import javax.xml.transform.stream.StreamSource;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
/**
 * JAX-RS entity provider for JAXB classes annotated with {@link XmlType} but <em>not</em> with
 * {@link XmlRootElement}.
 * <p>
 * On write, each entity is wrapped in a {@link JAXBElement} produced by a {@code create*} method of
 * the type's {@link XmlRegistry}-annotated ObjectFactory, then handed to the superclass for
 * marshalling. On read, any {@link JAXBElement} returned by the unmarshaller is unwrapped to its value.
 */
@Provider
@Produces({"application/xml", "application/*+xml", "text/xml", "text/*+xml"})
@Consumes({"application/xml", "application/*+xml", "text/xml", "text/*+xml"})
public class JAXBXmlTypeProvider extends AbstractJAXBProvider<Object>
{
   /**
    * Class-name suffix of a JAXB ObjectFactory. Not referenced inside this class; presumably used by
    * subclasses or other callers (NOTE(review): confirm before removing).
    */
   protected static final String OBJECT_FACTORY_NAME = ".ObjectFactory";

   /**
    * Wraps {@code t} in a {@link JAXBElement} via its ObjectFactory and delegates marshalling to the
    * superclass.
    *
    * @throws JAXBMarshalException if no valid ObjectFactory or matching {@code create} method exists
    */
   @Override
   public void writeTo(Object t,
                       Class<?> type,
                       Type genericType,
                       Annotation[] annotations,
                       MediaType mediaType,
                       MultivaluedMap<String, Object> httpHeaders,
                       OutputStream entityStream) throws IOException
   {
      LogMessages.LOGGER.debugf("Provider : %s, Method : writeTo", getClass().getName());
      JAXBElement<?> result = wrapInJAXBElement(t, type);
      super.writeTo(result, type, genericType, annotations, mediaType, httpHeaders, entityStream);
   }

   /**
    * Unmarshals an entity of {@code type} from {@code entityStream}.
    * <p>
    * A user-registered {@link JAXBContext} resolver is preferred; otherwise the
    * {@link JAXBContextFinder} is consulted. When the media type carries no charset, the stream is
    * decoded as UTF-8. When {@link #needsSecurity()} is true, parsing goes through a
    * {@link SecureUnmarshaller}.
    *
    * @return the unmarshalled object, unwrapped from its {@link JAXBElement} if necessary
    * @throws JAXBUnmarshalException if no context can be found or unmarshalling fails
    */
   @Override
   public Object readFrom(Class<Object> type, Type genericType, Annotation[] annotations, MediaType mediaType,
                          MultivaluedMap<String, String> httpHeaders, InputStream entityStream) throws IOException
   {
      try
      {
         LogMessages.LOGGER.debugf("Provider : %s, Method : readFrom", getClass().getName());
         JAXBContext jaxb = getJAXBContext(type, mediaType);
         if (jaxb == null)
         {
            jaxb = getJAXBContextFinder(type, annotations, mediaType);
         }
         Unmarshaller unmarshaller = jaxb.createUnmarshaller();
         unmarshaller = decorateUnmarshaller(type, annotations, mediaType, unmarshaller);

         Object obj;
         if (needsSecurity())
         {
            // Without a declared charset, decode explicitly as UTF-8 instead of relying on parser sniffing.
            InputSource input = (getCharset(mediaType) == null)
                  ? new InputSource(new InputStreamReader(entityStream, StandardCharsets.UTF_8))
                  : new InputSource(entityStream);
            unmarshaller = new SecureUnmarshaller(unmarshaller, isDisableExternalEntities(),
                  isEnableSecureProcessingFeature(), isDisableDTDs());
            obj = unmarshaller.unmarshal(new SAXSource(input));
         }
         else
         {
            // Give the StreamSource a single input: either a UTF-8 reader or the raw byte stream.
            StreamSource source = (getCharset(mediaType) == null)
                  ? new StreamSource(new InputStreamReader(entityStream, StandardCharsets.UTF_8))
                  : new StreamSource(entityStream);
            obj = unmarshaller.unmarshal(source);
         }

         return (obj instanceof JAXBElement) ? ((JAXBElement<?>) obj).getValue() : obj;
      }
      catch (JAXBException e)
      {
         throw new JAXBUnmarshalException(e);
      }
   }

   /**
    * Looks up a user-supplied {@link JAXBContext} through a {@link ContextResolver}.
    *
    * @return the resolved context, or {@code null} when no resolver is registered for {@code mediaType}
    * @throws JAXBUnmarshalException if a resolver exists but yields no context for {@code type}
    */
   private JAXBContext getJAXBContext(Class<Object> type, MediaType mediaType) throws IOException
   {
      LogMessages.LOGGER.debugf("Provider : %s, Method : getJAXBContext", getClass().getName());
      ContextResolver<JAXBContext> resolver = providers.getContextResolver(JAXBContext.class, mediaType);
      if (resolver == null)
      {
         return null;
      }
      JAXBContext context = resolver.getContext(type);
      if (context == null)
      {
         throw new JAXBUnmarshalException(Messages.MESSAGES.couldNotFindUsersJAXBContext(mediaType));
      }
      return context;
   }

   /**
    * Obtains a cached XmlType {@link JAXBContext} from the registered {@link JAXBContextFinder}.
    *
    * @throws JAXBUnmarshalException if no finder resolver/finder is available or context creation fails
    */
   private JAXBContext getJAXBContextFinder(Class<Object> type, Annotation[] annotations, MediaType mediaType) throws IOException
   {
      try
      {
         LogMessages.LOGGER.debugf("Provider : %s, Method : getJAXBContextFinder", getClass().getName());
         ContextResolver<JAXBContextFinder> resolver = providers.getContextResolver(JAXBContextFinder.class, mediaType);
         // Guard the resolver itself: dereferencing a missing resolver would otherwise surface as an NPE.
         JAXBContextFinder finder = (resolver == null) ? null : resolver.getContext(type);
         if (finder == null)
         {
            throw new JAXBUnmarshalException(Messages.MESSAGES.couldNotFindJAXBContextFinder(mediaType));
         }
         return finder.findCacheXmlTypeContext(mediaType, annotations, type);
      }
      catch (JAXBException e)
      {
         throw new JAXBUnmarshalException(e);
      }
   }

   /**
    * Accepts types annotated with {@link XmlType} but not {@link XmlRootElement}, unless
    * {@link DoNotUseJAXBProvider} is present or the media type is ignored.
    */
   @Override
   protected boolean isReadWritable(Class<?> type,
                                    Type genericType,
                                    Annotation[] annotations,
                                    MediaType mediaType)
   {
      boolean xmlTypeOnly = type.isAnnotationPresent(XmlType.class)
            && !type.isAnnotationPresent(XmlRootElement.class);
      return xmlTypeOnly
            && FindAnnotation.findAnnotation(type, annotations, DoNotUseJAXBProvider.class) == null
            && !IgnoredMediaTypes.ignored(type, annotations, mediaType);
   }

   /**
    * Instantiates the {@link XmlRegistry}-annotated ObjectFactory associated with {@code type}.
    *
    * @throws JAXBMarshalException if no valid registry class is found or it cannot be instantiated
    */
   public static Object findObjectFactory(Class<?> type)
   {
      try
      {
         Class<?> factoryClass = AbstractJAXBContextFinder.findDefaultObjectFactoryClass(type);
         if (factoryClass == null || !factoryClass.isAnnotationPresent(XmlRegistry.class))
         {
            throw new JAXBMarshalException(Messages.MESSAGES.validXmlRegistryCouldNotBeLocated());
         }
         // Class.newInstance() is deprecated: it rethrows constructor checked exceptions undeclared.
         return factoryClass.getDeclaredConstructor().newInstance();
      }
      catch (InvocationTargetException e)
      {
         // Report the constructor's own failure, consistent with wrapInJAXBElement.
         throw new JAXBMarshalException(e.getCause());
      }
      catch (ReflectiveOperationException e)
      {
         throw new JAXBMarshalException(e);
      }
   }

   /**
    * Wraps {@code t} in a {@link JAXBElement} by invoking the first single-argument {@code create*}
    * method of {@code type}'s ObjectFactory whose parameter type equals {@code type}.
    *
    * @throws JAXBMarshalException if no such method exists or invoking it fails
    */
   public static JAXBElement<?> wrapInJAXBElement(Object t, Class<?> type)
   {
      try
      {
         final Object factory = findObjectFactory(type);
         Method[] methods;
         if (System.getSecurityManager() == null)
         {
            methods = factory.getClass().getDeclaredMethods();
         }
         else
         {
            methods = AccessController.doPrivileged(new PrivilegedExceptionAction<Method[]>()
            {
               @Override
               public Method[] run() throws Exception
               {
                  return factory.getClass().getDeclaredMethods();
               }
            });
         }
         for (Method candidate : methods)
         {
            Class<?>[] params = candidate.getParameterTypes();
            if (params.length == 1 && params[0].equals(type) && candidate.getName().startsWith("create"))
            {
               return JAXBElement.class.cast(candidate.invoke(factory, t));
            }
         }
         throw new JAXBMarshalException(Messages.MESSAGES.createMethodNotFound(type));
      }
      catch (IllegalArgumentException | IllegalAccessException | PrivilegedActionException e)
      {
         throw new JAXBMarshalException(e);
      }
      catch (InvocationTargetException e)
      {
         throw new JAXBMarshalException(e.getCause());
      }
   }
}