/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package io.undertow.websockets.spi;

import io.undertow.UndertowLogger;
import io.undertow.io.IoCallback;
import io.undertow.io.Sender;
import io.undertow.security.api.SecurityContext;
import io.undertow.security.idm.Account;
import io.undertow.server.HttpServerExchange;
import io.undertow.server.HttpUpgradeListener;
import io.undertow.server.session.SessionConfig;
import io.undertow.server.session.SessionManager;
import io.undertow.util.AttachmentKey;
import io.undertow.util.HeaderMap;
import io.undertow.util.HttpString;
import io.undertow.websockets.core.WebSocketChannel;
import org.xnio.ChannelListener;
import org.xnio.FinishedIoFuture;
import org.xnio.FutureResult;
import org.xnio.IoFuture;
import org.xnio.IoUtils;
import org.xnio.OptionMap;
import io.undertow.connector.ByteBufferPool;
import io.undertow.connector.PooledByteBuffer;
import org.xnio.channels.StreamSourceChannel;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

// Author: Stuart Douglas
/** * @author Stuart Douglas */
public class AsyncWebSocketHttpServerExchange implements WebSocketHttpExchange { private final HttpServerExchange exchange; private Sender sender; private final Set<WebSocketChannel> peerConnections; public AsyncWebSocketHttpServerExchange(final HttpServerExchange exchange, Set<WebSocketChannel> peerConnections) { this.exchange = exchange; this.peerConnections = peerConnections; } @Override public <T> void putAttachment(final AttachmentKey<T> key, final T value) { exchange.putAttachment(key, value); } @Override public <T> T getAttachment(final AttachmentKey<T> key) { return exchange.getAttachment(key); } @Override public String getRequestHeader(final String headerName) { return exchange.getRequestHeaders().getFirst(HttpString.tryFromString(headerName)); } @Override public Map<String, List<String>> getRequestHeaders() { Map<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); for (final HttpString header : exchange.getRequestHeaders().getHeaderNames()) { headers.put(header.toString(), new ArrayList<>(exchange.getRequestHeaders().get(header))); } return Collections.unmodifiableMap(headers); } @Override public String getResponseHeader(final String headerName) { return exchange.getResponseHeaders().getFirst(HttpString.tryFromString(headerName)); } @Override public Map<String, List<String>> getResponseHeaders() { Map<String, List<String>> headers = new HashMap<>(); for (final HttpString header : exchange.getResponseHeaders().getHeaderNames()) { headers.put(header.toString(), new ArrayList<>(exchange.getResponseHeaders().get(header))); } return Collections.unmodifiableMap(headers); } @Override public void setResponseHeaders(final Map<String, List<String>> headers) { HeaderMap map = exchange.getRequestHeaders(); map.clear(); for (Map.Entry<String, List<String>> header : headers.entrySet()) { map.addAll(HttpString.tryFromString(header.getKey()), header.getValue()); } } @Override public void setResponseHeader(final String headerName, final String 
headerValue) { exchange.getResponseHeaders().put(HttpString.tryFromString(headerName), headerValue); } @Override public void upgradeChannel(final HttpUpgradeListener upgradeCallback) { exchange.upgradeChannel(upgradeCallback); } @Override public IoFuture<Void> sendData(final ByteBuffer data) { if (sender == null) { this.sender = exchange.getResponseSender(); } final FutureResult<Void> future = new FutureResult<>(); sender.send(data, new IoCallback() { @Override public void onComplete(final HttpServerExchange exchange, final Sender sender) { future.setResult(null); } @Override public void onException(final HttpServerExchange exchange, final Sender sender, final IOException exception) { UndertowLogger.REQUEST_IO_LOGGER.ioException(exception); future.setException(exception); } }); return future.getIoFuture(); } @Override public IoFuture<byte[]> readRequestData() { final ByteArrayOutputStream data = new ByteArrayOutputStream(); final PooledByteBuffer pooled = exchange.getConnection().getByteBufferPool().allocate(); final ByteBuffer buffer = pooled.getBuffer(); final StreamSourceChannel channel = exchange.getRequestChannel(); int res; for (; ; ) { try { res = channel.read(buffer); if (res == -1) { return new FinishedIoFuture<>(data.toByteArray()); } else if (res == 0) { //callback final FutureResult<byte[]> future = new FutureResult<>(); channel.getReadSetter().set(new ChannelListener<StreamSourceChannel>() { @Override public void handleEvent(final StreamSourceChannel channel) { int res; try { res = channel.read(buffer); if (res == -1) { future.setResult(data.toByteArray()); channel.suspendReads(); return; } else if (res == 0) { return; } else { buffer.flip(); while (buffer.hasRemaining()) { data.write(buffer.get()); } buffer.clear(); } } catch (IOException e) { future.setException(e); } } }); channel.resumeReads(); return future.getIoFuture(); } else { buffer.flip(); while (buffer.hasRemaining()) { data.write(buffer.get()); } buffer.clear(); } } catch (IOException e) { 
final FutureResult<byte[]> future = new FutureResult<>(); future.setException(e); return future.getIoFuture(); } } } @Override public void endExchange() { exchange.endExchange(); } @Override public void close() { try { exchange.endExchange(); } finally { IoUtils.safeClose(exchange.getConnection()); } } @Override public String getRequestScheme() { return exchange.getRequestScheme(); } @Override public String getRequestURI() { String q = exchange.getQueryString(); if (q == null || q.isEmpty()) { return exchange.getRequestURI(); } else { return exchange.getRequestURI() + "?" + q; } } @Override public ByteBufferPool getBufferPool() { return exchange.getConnection().getByteBufferPool(); } @Override public String getQueryString() { return exchange.getQueryString(); } @Override public Object getSession() { SessionManager sm = exchange.getAttachment(SessionManager.ATTACHMENT_KEY); SessionConfig sessionCookieConfig = exchange.getAttachment(SessionConfig.ATTACHMENT_KEY); if(sm != null && sessionCookieConfig != null) { return sm.getSession(exchange, sessionCookieConfig); } return null; } @Override public Map<String, List<String>> getRequestParameters() { Map<String, List<String>> params = new HashMap<>(); for (Map.Entry<String, Deque<String>> param : exchange.getQueryParameters().entrySet()) { params.put(param.getKey(), new ArrayList<>(param.getValue())); } return params; } @Override public Principal getUserPrincipal() { SecurityContext sc = exchange.getSecurityContext(); if(sc == null) { return null; } Account authenticatedAccount = sc.getAuthenticatedAccount(); if(authenticatedAccount == null) { return null; } return authenticatedAccount.getPrincipal(); } @Override public boolean isUserInRole(String role) { SecurityContext sc = exchange.getSecurityContext(); if(sc == null) { return false; } Account authenticatedAccount = sc.getAuthenticatedAccount(); if(authenticatedAccount == null) { return false; } return authenticatedAccount.getRoles().contains(role); } @Override public 
Set<WebSocketChannel> getPeerConnections() { return peerConnections; } @Override public OptionMap getOptions() { return exchange.getConnection().getUndertowOptions(); } }