/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package io.undertow.websockets.core.protocol.version07;

import io.undertow.UndertowLogger;
import io.undertow.server.protocol.framed.SendFrameHeader;
import io.undertow.util.ImmediatePooledByteBuffer;
import io.undertow.websockets.core.StreamSinkFrameChannel;
import io.undertow.websockets.core.WebSocketFrameType;
import io.undertow.websockets.core.WebSocketMessages;
import io.undertow.websockets.extensions.ExtensionFunction;
import io.undertow.websockets.extensions.NoopExtensionFunction;
import io.undertow.connector.PooledByteBuffer;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.ThreadLocalRandom;

// StreamSinkFrameChannel implementation for writing WebSocket frames on WebSocketVersion.V08 connections.
// Author: Norman Maurer
/** * {@link StreamSinkFrameChannel} implementation for writing WebSocket Frames on {@link io.undertow.websockets.core.WebSocketVersion#V08} connections * * @author <a href="mailto:nmaurer@redhat.com">Norman Maurer</a> */
public abstract class WebSocket07FrameSinkChannel extends StreamSinkFrameChannel { private final Masker masker; private volatile boolean dataWritten = false; protected final ExtensionFunction extensionFunction; protected WebSocket07FrameSinkChannel(WebSocket07Channel wsChannel, WebSocketFrameType type) { super(wsChannel, type); if(wsChannel.isClient()) { masker = new Masker(0); } else { masker = null; } /* Checks if there are negotiated extensions that need to modify RSV bits */ if (wsChannel.areExtensionsSupported() && (type == WebSocketFrameType.TEXT || type == WebSocketFrameType.BINARY)) { extensionFunction = wsChannel.getExtensionFunction(); setRsv(extensionFunction.writeRsv(0)); } else { extensionFunction = NoopExtensionFunction.INSTANCE; setRsv(0); } } @Override protected void handleFlushComplete(boolean finalFrame) { dataWritten = true; // TODO not sure we need to do this as the key was set when it was last used // if(masker != null) { // masker.setMaskingKey(maskingKey); // } } private byte opCode() { if(dataWritten) { return WebSocket07Channel.OPCODE_CONT; } switch (getType()) { case CONTINUATION: return WebSocket07Channel.OPCODE_CONT; case TEXT: return WebSocket07Channel.OPCODE_TEXT; case BINARY: return WebSocket07Channel.OPCODE_BINARY; case CLOSE: return WebSocket07Channel.OPCODE_CLOSE; case PING: return WebSocket07Channel.OPCODE_PING; case PONG: return WebSocket07Channel.OPCODE_PONG; default: throw WebSocketMessages.MESSAGES.unsupportedFrameType(getType()); } } @Override protected SendFrameHeader createFrameHeader() { byte b0 = 0; //if writes are shutdown this is the final fragment if (isFinalFrameQueued()) { b0 |= 1 << 7; // set FIN } /* Known extensions (i.e. compression) should not modify RSV bit on continuation bit. */ byte opCode = opCode(); int rsv = opCode == WebSocket07Channel.OPCODE_CONT ? 
0 : getRsv(); b0 |= (rsv & 7) << 4; b0 |= opCode & 0xf; final ByteBuffer header = ByteBuffer.allocate(14); byte maskKey = 0; if(masker != null) { maskKey |= 1 << 7; } long payloadSize = getBuffer().remaining(); if (payloadSize > 125 && opCode == WebSocket07Channel.OPCODE_PING) { throw WebSocketMessages.MESSAGES.invalidPayloadLengthForPing(payloadSize); } if (payloadSize <= 125) { header.put(b0); header.put((byte)((payloadSize | maskKey) & 0xFF)); } else if (payloadSize <= 0xFFFF) { header.put(b0); header.put((byte) ((126 | maskKey) & 0xFF)); header.put((byte) (payloadSize >>> 8 & 0xFF)); header.put((byte) (payloadSize & 0xFF)); } else { header.put(b0); header.put((byte) ((127 | maskKey) & 0xFF)); header.putLong(payloadSize); } if(masker != null) { int maskingKey = ThreadLocalRandom.current().nextInt(); //generate a new key for this frame header.put((byte)((maskingKey >> 24) & 0xFF)); header.put((byte)((maskingKey >> 16) & 0xFF)); header.put((byte)((maskingKey >> 8) & 0xFF)); header.put((byte)((maskingKey & 0xFF))); masker.setMaskingKey(maskingKey); //do any required masking ByteBuffer buf = getBuffer(); masker.beforeWrite(buf, buf.position(), buf.remaining()); } header.flip(); return new SendFrameHeader(0, new ImmediatePooledByteBuffer(header)); } @Override protected PooledByteBuffer preWriteTransform(PooledByteBuffer body) { try { return super.preWriteTransform(extensionFunction.transformForWrite(body, this, this.isFinalFrameQueued())); } catch (IOException e) { UndertowLogger.REQUEST_IO_LOGGER.ioException(e); markBroken(); throw new RuntimeException(e); } } }