/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.cassandra.hadoop.cql3;
import java.io.IOException;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.net.InetAddress;
import java.nio.ByteBuffer;
import java.util.*;
import com.google.common.base.Function;
import com.google.common.base.Joiner;
import com.google.common.base.Splitter;
import com.datastax.driver.core.TypeCodec;
import org.apache.cassandra.utils.AbstractIterator;
import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.ColumnDefinitions;
import com.datastax.driver.core.ColumnMetadata;
import com.datastax.driver.core.LocalDate;
import com.datastax.driver.core.Metadata;
import com.datastax.driver.core.ResultSet;
import com.datastax.driver.core.Row;
import com.datastax.driver.core.Session;
import com.datastax.driver.core.TableMetadata;
import com.datastax.driver.core.Token;
import com.datastax.driver.core.TupleValue;
import com.datastax.driver.core.UDTValue;
import com.google.common.reflect.TypeToken;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.dht.IPartitioner;
import org.apache.cassandra.hadoop.ColumnFamilySplit;
import org.apache.cassandra.hadoop.ConfigHelper;
import org.apache.cassandra.hadoop.HadoopCompat;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.cassandra.utils.Pair;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
/**
 * <p>
 * CqlRecordReader reads the rows returned from the CQL query.
 * It uses CQL auto-paging.
 * </p>
 * <p>
 * Returns a Long as the local CQL row key, starting from 0.
 * </p>
 * {@code
 * Row is the C* java driver's CQL result set row
 * 1) the select clause must include the partition key columns (to calculate progress based on the actual CF rows processed)
 * 2) the where clause must include token(partition_key1, ... , partition_keyn) > ? and
 *    token(partition_key1, ... , partition_keyn) <= ? (in the right order)
* }
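 *
 * For illustration only, a query satisfying these constraints might look like the following
 * (the keyspace {@code ks}, table {@code cf}, columns {@code id} and {@code value} are hypothetical):
 * <pre>{@code
 * SELECT id, value FROM ks.cf WHERE token(id) > ? AND token(id) <= ?
 * }</pre>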
*/
public class CqlRecordReader extends RecordReader<Long, Row>
implements org.apache.hadoop.mapred.RecordReader<Long, Row>, AutoCloseable
{
private static final Logger logger = LoggerFactory.getLogger(CqlRecordReader.class);
private ColumnFamilySplit split;
private RowIterator rowIterator;
private Pair<Long, Row> currentRow;
private int totalRowCount; // total number of rows to fetch
private String keyspace;
private String cfName;
private String cqlQuery;
private Cluster cluster;
private Session session;
private IPartitioner partitioner;
private String inputColumns;
private String userDefinedWhereClauses;
private List<String> partitionKeys = new ArrayList<>();
// partition keys -- key aliases
private LinkedHashMap<String, Boolean> partitionBoundColumns = Maps.newLinkedHashMap();
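    // negotiated native protocol version of the connection; defaults to 1 and is set in initialize()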
protected int nativeProtocolVersion = 1;
public CqlRecordReader()
{
super();
}
@Override
public void initialize(InputSplit split, TaskAttemptContext context) throws IOException
{
this.split = (ColumnFamilySplit) split;
Configuration conf = HadoopCompat.getConfiguration(context);
totalRowCount = (this.split.getLength() < Long.MAX_VALUE)
? (int) this.split.getLength()
: ConfigHelper.getInputSplitSize(conf);
cfName = ConfigHelper.getInputColumnFamily(conf);
keyspace = ConfigHelper.getInputKeyspace(conf);
partitioner = ConfigHelper.getInputPartitioner(conf);
inputColumns = CqlConfigHelper.getInputcolumns(conf);
userDefinedWhereClauses = CqlConfigHelper.getInputWhereClauses(conf);
try
{
if (cluster != null)
return;
// create a Cluster instance
String[] locations = split.getLocations();
cluster = CqlConfigHelper.getInputCluster(locations, conf);
}
catch (Exception e)
{
throw new RuntimeException(e);
}
if (cluster != null)
session = cluster.connect(quote(keyspace));
if (session == null)
throw new RuntimeException("Can't create connection session");
//get negotiated serialization protocol
nativeProtocolVersion = cluster.getConfiguration().getProtocolOptions().getProtocolVersion().toInt();
        // If the user provides a CQL query then we will use it without validation;
        // otherwise we will fall back to building a query from inputColumns and userDefinedWhereClauses.
cqlQuery = CqlConfigHelper.getInputCql(conf);
// validate that the user hasn't tried to give us a custom query along with input columns
// and where clauses
if (StringUtils.isNotEmpty(cqlQuery) && (StringUtils.isNotEmpty(inputColumns) ||
StringUtils.isNotEmpty(userDefinedWhereClauses)))
{
throw new AssertionError("Cannot define a custom query with input columns and / or where clauses");
}
if (StringUtils.isEmpty(cqlQuery))
cqlQuery = buildQuery();
logger.trace("cqlQuery {}", cqlQuery);
rowIterator = new RowIterator();
logger.trace("created {}", rowIterator);
}
public void close()
{
if (session != null)
session.close();
if (cluster != null)
cluster.close();
}
public Long getCurrentKey()
{
return currentRow.left;
}
public Row getCurrentValue()
{
return currentRow.right;
}
public float getProgress()
{
if (!rowIterator.hasNext())
return 1.0F;
// the progress is likely to be reported slightly off the actual but close enough
float progress = ((float) rowIterator.totalRead / totalRowCount);
return progress > 1.0F ? 1.0F : progress;
}
public boolean nextKeyValue() throws IOException
{
if (!rowIterator.hasNext())
{
logger.trace("Finished scanning {} rows (estimate was: {})", rowIterator.totalRead, totalRowCount);
return false;
}
try
{
currentRow = rowIterator.next();
}
catch (Exception e)
{
            // rethrow as an IOException so the client can catch and handle it on their side
            IOException ioe = new IOException(e.getMessage());
            ioe.initCause(e);
throw ioe;
}
return true;
}
    // Because the old Hadoop API wants us to write into the key and value objects passed in,
    // while the new API hands them back to us, we need to copy the output of the new API
    // into the old one's objects. Thus, expect a small performance hit.
    // And obviously this wouldn't work for wide rows. But since ColumnFamilyInputFormat
    // and ColumnFamilyRecordReader don't support them, it should be fine for now.
public boolean next(Long key, Row value) throws IOException
{
if (nextKeyValue())
{
((WrappedRow)value).setRow(getCurrentValue());
return true;
}
return false;
}
public long getPos() throws IOException
{
return rowIterator.totalRead;
}
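    // the old Hadoop API (org.apache.hadoop.mapred.RecordReader) requires the reader to create key/value instances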
public Long createKey()
{
return Long.valueOf(0L);
}
public Row createValue()
{
return new WrappedRow();
}
/**
     * Return the native protocol version of the cluster connection
* @return serialization protocol version.
*/
public int getNativeProtocolVersion()
{
return nativeProtocolVersion;
}
    /** CQL row iterator
     * The input CQL query must satisfy:
     * 1) the select clause must include the key columns (if we use a partition-key-based row count)
     * 2) the where clause must include token(partition_key1 ... partition_keyn) > ? and
     *    token(partition_key1 ... partition_keyn) <= ?
*/
private class RowIterator extends AbstractIterator<Pair<Long, Row>>
{
private long keyId = 0L;
protected int totalRead = 0; // total number of cf rows read
protected Iterator<Row> rows;
private Map<String, ByteBuffer> previousRowKey = new HashMap<String, ByteBuffer>(); // previous CF row key
public RowIterator()
{
AbstractType type = partitioner.getTokenValidator();
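            // bind the split's start and end tokens to the two token(...) placeholders in the query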
            ResultSet rs = session.execute(cqlQuery,
                                           type.compose(type.fromString(split.getStartToken())),
                                           type.compose(type.fromString(split.getEndToken())));
for (ColumnMetadata meta : cluster.getMetadata().getKeyspace(quote(keyspace)).getTable(quote(cfName)).getPartitionKey())
partitionBoundColumns.put(meta.getName(), Boolean.TRUE);
rows = rs.iterator();
}
protected Pair<Long, Row> computeNext()
{
if (rows == null || !rows.hasNext())
return endOfData();
Row row = rows.next();
Map<String, ByteBuffer> keyColumns = new HashMap<String, ByteBuffer>(partitionBoundColumns.size());
for (String column : partitionBoundColumns.keySet())
keyColumns.put(column, row.getBytesUnsafe(column));
// increase total CF row read
if (previousRowKey.isEmpty() && !keyColumns.isEmpty())
{
previousRowKey = keyColumns;
totalRead++;
}
else
{
for (String column : partitionBoundColumns.keySet())
{
// this is not correct - but we don't seem to have easy access to better type information here
if (ByteBufferUtil.compareUnsigned(keyColumns.get(column), previousRowKey.get(column)) != 0)
{
previousRowKey = keyColumns;
totalRead++;
break;
}
}
}
            keyId++;
return Pair.create(keyId, row);
}
}
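    // Mutable Row wrapper used with the old Hadoop API: next(key, value) must fill in a pre-created
    // value object, so WrappedRow delegates every Row method to the Row set via setRow().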
private static class WrappedRow implements Row
{
private Row row;
public void setRow(Row row)
{
this.row = row;
}
@Override
public ColumnDefinitions getColumnDefinitions()
{
return row.getColumnDefinitions();
}
@Override
public boolean isNull(int i)
{
return row.isNull(i);
}
@Override
public boolean isNull(String name)
{
return row.isNull(name);
}
@Override
public Object getObject(int i)
{
return row.getObject(i);
}
@Override
public <T> T get(int i, Class<T> aClass)
{
return row.get(i, aClass);
}
@Override
public <T> T get(int i, TypeToken<T> typeToken)
{
return row.get(i, typeToken);
}
@Override
public <T> T get(int i, TypeCodec<T> typeCodec)
{
return row.get(i, typeCodec);
}
@Override
public Object getObject(String s)
{
return row.getObject(s);
}
@Override
public <T> T get(String s, Class<T> aClass)
{
return row.get(s, aClass);
}
@Override
public <T> T get(String s, TypeToken<T> typeToken)
{
return row.get(s, typeToken);
}
@Override
public <T> T get(String s, TypeCodec<T> typeCodec)
{
return row.get(s, typeCodec);
}
@Override
public boolean getBool(int i)
{
return row.getBool(i);
}
@Override
public boolean getBool(String name)
{
return row.getBool(name);
}
@Override
public short getShort(int i)
{
return row.getShort(i);
}
@Override
public short getShort(String s)
{
return row.getShort(s);
}
@Override
public byte getByte(int i)
{
return row.getByte(i);
}
@Override
public byte getByte(String s)
{
return row.getByte(s);
}
@Override
public int getInt(int i)
{
return row.getInt(i);
}
@Override
public int getInt(String name)
{
return row.getInt(name);
}
@Override
public long getLong(int i)
{
return row.getLong(i);
}
@Override
public long getLong(String name)
{
return row.getLong(name);
}
@Override
public Date getTimestamp(int i)
{
return row.getTimestamp(i);
}
@Override
public Date getTimestamp(String s)
{
return row.getTimestamp(s);
}
@Override
public LocalDate getDate(int i)
{
return row.getDate(i);
}
@Override
public LocalDate getDate(String s)
{
return row.getDate(s);
}
@Override
public long getTime(int i)
{
return row.getTime(i);
}
@Override
public long getTime(String s)
{
return row.getTime(s);
}
@Override
public float getFloat(int i)
{
return row.getFloat(i);
}
@Override
public float getFloat(String name)
{
return row.getFloat(name);
}
@Override
public double getDouble(int i)
{
return row.getDouble(i);
}
@Override
public double getDouble(String name)
{
return row.getDouble(name);
}
@Override
public ByteBuffer getBytesUnsafe(int i)
{
return row.getBytesUnsafe(i);
}
@Override
public ByteBuffer getBytesUnsafe(String name)
{
return row.getBytesUnsafe(name);
}
@Override
public ByteBuffer getBytes(int i)
{
return row.getBytes(i);
}
@Override
public ByteBuffer getBytes(String name)
{
return row.getBytes(name);
}
@Override
public String getString(int i)
{
return row.getString(i);
}
@Override
public String getString(String name)
{
return row.getString(name);
}
@Override
public BigInteger getVarint(int i)
{
return row.getVarint(i);
}
@Override
public BigInteger getVarint(String name)
{
return row.getVarint(name);
}
@Override
public BigDecimal getDecimal(int i)
{
return row.getDecimal(i);
}
@Override
public BigDecimal getDecimal(String name)
{
return row.getDecimal(name);
}
@Override
public UUID getUUID(int i)
{
return row.getUUID(i);
}
@Override
public UUID getUUID(String name)
{
return row.getUUID(name);
}
@Override
public InetAddress getInet(int i)
{
return row.getInet(i);
}
@Override
public InetAddress getInet(String name)
{
return row.getInet(name);
}
@Override
public <T> List<T> getList(int i, Class<T> elementsClass)
{
return row.getList(i, elementsClass);
}
@Override
public <T> List<T> getList(int i, TypeToken<T> typeToken)
{
return row.getList(i, typeToken);
}
@Override
public <T> List<T> getList(String name, Class<T> elementsClass)
{
return row.getList(name, elementsClass);
}
@Override
public <T> List<T> getList(String s, TypeToken<T> typeToken)
{
return row.getList(s, typeToken);
}
@Override
public <T> Set<T> getSet(int i, Class<T> elementsClass)
{
return row.getSet(i, elementsClass);
}
@Override
public <T> Set<T> getSet(int i, TypeToken<T> typeToken)
{
return row.getSet(i, typeToken);
}
@Override
public <T> Set<T> getSet(String name, Class<T> elementsClass)
{
return row.getSet(name, elementsClass);
}
@Override
public <T> Set<T> getSet(String s, TypeToken<T> typeToken)
{
return row.getSet(s, typeToken);
}
@Override
public <K, V> Map<K, V> getMap(int i, Class<K> keysClass, Class<V> valuesClass)
{
return row.getMap(i, keysClass, valuesClass);
}
@Override
public <K, V> Map<K, V> getMap(int i, TypeToken<K> typeToken, TypeToken<V> typeToken1)
{
return row.getMap(i, typeToken, typeToken1);
}
@Override
public <K, V> Map<K, V> getMap(String name, Class<K> keysClass, Class<V> valuesClass)
{
return row.getMap(name, keysClass, valuesClass);
}
@Override
public <K, V> Map<K, V> getMap(String s, TypeToken<K> typeToken, TypeToken<V> typeToken1)
{
return row.getMap(s, typeToken, typeToken1);
}
@Override
public UDTValue getUDTValue(int i)
{
return row.getUDTValue(i);
}
@Override
public UDTValue getUDTValue(String name)
{
return row.getUDTValue(name);
}
@Override
public TupleValue getTupleValue(int i)
{
return row.getTupleValue(i);
}
@Override
public TupleValue getTupleValue(String name)
{
return row.getTupleValue(name);
}
@Override
public Token getToken(int i)
{
return row.getToken(i);
}
@Override
public Token getToken(String name)
{
return row.getToken(name);
}
@Override
public Token getPartitionKeyToken()
{
return row.getPartitionKeyToken();
}
}
/**
* Build a query for the reader of the form:
*
     * SELECT * FROM ks.cf WHERE token(pk1,...pkn)>? AND token(pk1,...pkn)<=? [AND user where clauses] [ALLOW FILTERING]
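     *
     * e.g., for a hypothetical table ks.cf whose only partition key is id (and no user where clauses):
     *
     * SELECT * FROM "ks"."cf" WHERE token("id")>? AND token("id")<=?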
*/
private String buildQuery()
{
fetchKeys();
List<String> columns = getSelectColumns();
String selectColumnList = columns.size() == 0 ? "*" : makeColumnList(columns);
String partitionKeyList = makeColumnList(partitionKeys);
return String.format("SELECT %s FROM %s.%s WHERE token(%s)>? AND token(%s)<=?" + getAdditionalWhereClauses(),
selectColumnList, quote(keyspace), quote(cfName), partitionKeyList, partitionKeyList);
}
private String getAdditionalWhereClauses()
{
String whereClause = "";
        if (StringUtils.isNotEmpty(userDefinedWhereClauses))
            whereClause += " AND " + userDefinedWhereClauses + " ALLOW FILTERING";
return whereClause;
}
private List<String> getSelectColumns()
{
List<String> selectColumns = new ArrayList<>();
if (StringUtils.isNotEmpty(inputColumns))
{
// We must select all the partition keys plus any other columns the user wants
selectColumns.addAll(partitionKeys);
for (String column : Splitter.on(',').split(inputColumns))
{
if (!partitionKeys.contains(column))
selectColumns.add(column);
}
}
return selectColumns;
}
private String makeColumnList(Collection<String> columns)
{
return Joiner.on(',').join(Iterables.transform(columns, new Function<String, String>()
{
public String apply(String column)
{
return quote(column);
}
}));
}
private void fetchKeys()
{
// get CF meta data
TableMetadata tableMetadata = session.getCluster()
.getMetadata()
.getKeyspace(Metadata.quote(keyspace))
.getTable(Metadata.quote(cfName));
if (tableMetadata == null)
{
throw new RuntimeException("No table metadata found for " + keyspace + "." + cfName);
}
//Here we assume that tableMetadata.getPartitionKey() always
//returns the list of columns in order of component_index
for (ColumnMetadata partitionKey : tableMetadata.getPartitionKey())
{
partitionKeys.add(partitionKey.getName());
}
}
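    // quote a CQL identifier by wrapping it in double quotes and doubling any embedded double quotes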
private String quote(String identifier)
{
return "\"" + identifier.replaceAll("\"", "\"\"") + "\"";
}
}