Skip to content

Commit f187539

Browse files
authored
Merge pull request #1 from wesm/file-change-cpp-impl
Use unions in Java, simplify record batch deserialization, change C++ to use Message type
2 parents 7c6f7ef + e3af434 commit f187539

File tree

5 files changed

+97
-124
lines changed

5 files changed

+97
-124
lines changed

cpp/src/arrow/ipc/adapter.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,12 @@ class RecordBatchWriter : public ArrayVisitor {
129129
num_rows_, body_length, field_nodes_, buffer_meta_, &metadata_fb));
130130

131131
// Need to write 4 bytes (metadata size), the metadata, plus padding to
132-
// fall on a 64-byte offset
133-
int64_t padded_metadata_length =
134-
BitUtil::RoundUpToMultipleOf64(metadata_fb->size() + 4);
132+
// fall on an 8-byte offset
133+
int64_t padded_metadata_length = BitUtil::CeilByte(metadata_fb->size() + 4);
135134

136135
// The returned metadata size includes the length prefix, the flatbuffer,
137136
// plus padding
138-
*metadata_length = padded_metadata_length;
137+
*metadata_length = static_cast<int32_t>(padded_metadata_length);
139138

140139
// Write the flatbuffer size prefix
141140
int32_t flatbuffer_size = metadata_fb->size();
@@ -604,7 +603,9 @@ Status ReadRecordBatchMetadata(int64_t offset, int32_t metadata_length,
604603
return Status::Invalid(ss.str());
605604
}
606605

607-
*metadata = std::make_shared<RecordBatchMetadata>(buffer, sizeof(int32_t));
606+
std::shared_ptr<Message> message;
607+
RETURN_NOT_OK(Message::Open(buffer, 4, &message));
608+
*metadata = std::make_shared<RecordBatchMetadata>(message);
608609
return Status::OK();
609610
}
610611

cpp/src/arrow/ipc/metadata-internal.cc

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -320,23 +320,10 @@ Status MessageBuilder::SetRecordBatch(int32_t length, int64_t body_length,
320320
Status WriteRecordBatchMetadata(int32_t length, int64_t body_length,
321321
const std::vector<flatbuf::FieldNode>& nodes,
322322
const std::vector<flatbuf::Buffer>& buffers, std::shared_ptr<Buffer>* out) {
323-
flatbuffers::FlatBufferBuilder fbb;
324-
325-
auto batch = flatbuf::CreateRecordBatch(
326-
fbb, length, fbb.CreateVectorOfStructs(nodes), fbb.CreateVectorOfStructs(buffers));
327-
328-
fbb.Finish(batch);
329-
330-
int32_t size = fbb.GetSize();
331-
332-
auto result = std::make_shared<PoolBuffer>();
333-
RETURN_NOT_OK(result->Resize(size));
334-
335-
uint8_t* dst = result->mutable_data();
336-
memcpy(dst, fbb.GetBufferPointer(), size);
337-
338-
*out = result;
339-
return Status::OK();
323+
MessageBuilder builder;
324+
RETURN_NOT_OK(builder.SetRecordBatch(length, body_length, nodes, buffers));
325+
RETURN_NOT_OK(builder.Finish());
326+
return builder.GetBuffer(out);
340327
}
341328

342329
Status MessageBuilder::Finish() {

format/File.fbs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ struct Block {
4343

4444
/// Length of the data (this is aligned so there can be a gap between this and
4545
/// the metatdata).
46-
bodyLength: int;
46+
bodyLength: long;
4747
}
4848

4949
root_type Footer;

java/vector/src/main/java/org/apache/arrow/vector/file/ArrowBlock.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ public class ArrowBlock implements FBSerializable {
2626

2727
private final long offset;
2828
private final int metadataLength;
29-
private final int bodyLength;
29+
private final long bodyLength;
3030

31-
public ArrowBlock(long offset, int metadataLength, int bodyLength) {
31+
public ArrowBlock(long offset, int metadataLength, long bodyLength) {
3232
super();
3333
this.offset = offset;
3434
this.metadataLength = metadataLength;
@@ -43,7 +43,7 @@ public int getMetadataLength() {
4343
return metadataLength;
4444
}
4545

46-
public int getBodyLength() {
46+
public long getBodyLength() {
4747
return bodyLength;
4848
}
4949

@@ -56,7 +56,7 @@ public int writeTo(FlatBufferBuilder builder) {
5656
public int hashCode() {
5757
final int prime = 31;
5858
int result = 1;
59-
result = prime * result + bodyLength;
59+
result = prime * result + (int) (bodyLength ^ (bodyLength >>> 32));
6060
result = prime * result + metadataLength;
6161
result = prime * result + (int) (offset ^ (offset >>> 32));
6262
return result;

java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java

Lines changed: 82 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -70,33 +70,24 @@ public static int bytesToInt(byte[] bytes) {
7070
*/
7171
public static long serialize(WriteChannel out, Schema schema) throws IOException {
7272
FlatBufferBuilder builder = new FlatBufferBuilder();
73-
builder.finish(schema.getSchema(builder));
74-
ByteBuffer serializedBody = builder.dataBuffer();
75-
ByteBuffer serializedHeader =
76-
serializeHeader(MessageHeader.Schema, serializedBody.remaining());
77-
78-
long size = out.writeIntLittleEndian(serializedHeader.remaining());
79-
size += out.write(serializedHeader);
80-
size += out.write(serializedBody);
73+
int schemaOffset = schema.getSchema(builder);
74+
ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.Schema, schemaOffset, 0);
75+
long size = out.writeIntLittleEndian(serializedMessage.remaining());
76+
size += out.write(serializedMessage);
8177
return size;
8278
}
8379

8480
/**
8581
* Deserializes a schema object. Format is from serialize().
8682
*/
8783
public static Schema deserializeSchema(ReadChannel in) throws IOException {
88-
Message header = deserializeHeader(in, MessageHeader.Schema);
89-
if (header == null) {
84+
Message message = deserializeMessage(in, MessageHeader.Schema);
85+
if (message == null) {
9086
throw new IOException("Unexpected end of input. Missing schema.");
9187
}
9288

93-
// Now read the schema.
94-
ByteBuffer buffer = ByteBuffer.allocate((int)header.bodyLength());
95-
if (in.readFully(buffer) != header.bodyLength()) {
96-
throw new IOException("Unexpected end of input trying to read schema.");
97-
}
98-
buffer.rewind();
99-
return Schema.deserialize(buffer);
89+
return Schema.convertSchema((org.apache.arrow.flatbuf.Schema)
90+
message.header(new org.apache.arrow.flatbuf.Schema()));
10091
}
10192

10293
/**
@@ -106,33 +97,23 @@ public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch)
10697
throws IOException {
10798
long start = out.getCurrentPosition();
10899
int bodyLength = batch.computeBodyLength();
109-
ByteBuffer metadata = WriteChannel.serialize(batch);
110-
111-
int messageLength = 4 + metadata.remaining() + bodyLength;
112-
ByteBuffer serializedHeader =
113-
serializeHeader(MessageHeader.RecordBatch, messageLength);
114-
115-
// Compute the required alignment. This is not a great way to do it. The issue is
116-
// that we need to know the message size to serialize the message header but the
117-
// size depends on the alignment, which depends on the message header.
118-
// This will serialize the header again with the updated size alignment adjusted.
119-
// TODO: We really just want sizeof(MessageHeader) from the serializeHeader() above.
120-
// Is there a way to do this?
121-
long bufferOffset = start + 4 + serializedHeader.remaining() + 4 + metadata.remaining();
122-
if (bufferOffset % 8 != 0) {
123-
messageLength += 8 - bufferOffset % 8;
124-
serializedHeader = serializeHeader(MessageHeader.RecordBatch, messageLength);
125-
}
126100

127-
// Write message header.
128-
out.writeIntLittleEndian(serializedHeader.remaining());
129-
out.write(serializedHeader);
101+
FlatBufferBuilder builder = new FlatBufferBuilder();
102+
int batchOffset = batch.writeTo(builder);
103+
104+
ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.RecordBatch,
105+
batchOffset, bodyLength);
106+
107+
int metadataLength = serializedMessage.remaining();
130108

131-
// Write batch header. with the 4 byte little endian prefix
132-
out.writeIntLittleEndian(metadata.remaining());
133-
int metadataSize = metadata.remaining();
134-
long batchStart = out.getCurrentPosition();
135-
out.write(metadata);
109+
// Add extra padding bytes so that length prefix + metadata is a multiple
110+
// of 8 after alignment
111+
if ((metadataLength + 4) % 8 != 0) {
112+
metadataLength += 8 - (metadataLength + 4) % 8;
113+
}
114+
115+
out.writeIntLittleEndian(metadataLength);
116+
out.write(serializedMessage);
136117

137118
// Align the output to 8 byte boundary.
138119
out.align();
@@ -154,31 +135,32 @@ public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch)
154135
" != " + startPosition + layout.getSize());
155136
}
156137
}
157-
return new ArrowBlock(batchStart, metadataSize, (int)(out.getCurrentPosition() - bufferStart));
138+
// Metadata size in the Block account for the size prefix
139+
return new ArrowBlock(start, metadataLength + 4, out.getCurrentPosition() - bufferStart);
158140
}
159141

160142
/**
161143
* Deserializes a RecordBatch
162144
*/
163145
public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in,
164146
BufferAllocator alloc) throws IOException {
165-
Message header = deserializeHeader(in, MessageHeader.RecordBatch);
166-
if (header == null) return null;
167-
168-
int messageLen = (int)header.bodyLength();
169-
// Now read the buffer. This has the metadata followed by the data.
170-
ArrowBuf buffer = alloc.buffer(messageLen);
171-
long readPosition = in.getCurrentPositiion();
172-
if (in.readFully(buffer, messageLen) != messageLen) {
173-
throw new IOException("Unexpected end of input trying to read batch.");
147+
Message message = deserializeMessage(in, MessageHeader.RecordBatch);
148+
if (message == null) return null;
149+
150+
if (message.bodyLength() > Integer.MAX_VALUE) {
151+
throw new IOException("Cannot currently deserialize record batches over 2GB");
174152
}
175153

176-
// Read the length of the metadata.
177-
int metadataLen = buffer.readInt();
178-
buffer = buffer.slice(4, messageLen - 4);
179-
readPosition += 4;
180-
messageLen -= 4;
181-
return deserializeRecordBatch(buffer, readPosition, metadataLen, messageLen);
154+
RecordBatch recordBatchFB = (RecordBatch) message.header(new RecordBatch());
155+
156+
int bodyLength = (int) message.bodyLength();
157+
158+
// Now read the record batch body
159+
ArrowBuf buffer = alloc.buffer(bodyLength);
160+
if (in.readFully(buffer, bodyLength) != bodyLength) {
161+
throw new IOException("Unexpected end of input trying to read batch.");
162+
}
163+
return deserializeRecordBatch(recordBatchFB, buffer);
182164
}
183165

184166
/**
@@ -187,37 +169,39 @@ public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in,
187169
*/
188170
public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, ArrowBlock block,
189171
BufferAllocator alloc) throws IOException {
190-
long readPosition = in.getCurrentPositiion();
191-
int totalLen = block.getMetadataLength() + block.getBodyLength();
192-
if ((readPosition + block.getMetadataLength()) % 8 != 0) {
193-
// Compute padded size.
194-
totalLen += (8 - (readPosition + block.getMetadataLength()) % 8);
172+
// Metadata length contains integer prefix plus byte padding
173+
long totalLen = block.getMetadataLength() + block.getBodyLength();
174+
175+
if (totalLen > Integer.MAX_VALUE) {
176+
throw new IOException("Cannot currently deserialize record batches over 2GB");
195177
}
196178

197-
ArrowBuf buffer = alloc.buffer(totalLen);
198-
if (in.readFully(buffer, totalLen) != totalLen) {
179+
ArrowBuf buffer = alloc.buffer((int) totalLen);
180+
if (in.readFully(buffer, (int) totalLen) != totalLen) {
199181
throw new IOException("Unexpected end of input trying to read batch.");
200182
}
201183

202-
return deserializeRecordBatch(buffer, readPosition, block.getMetadataLength(), totalLen);
203-
}
184+
ArrowBuf metadataBuffer = buffer.slice(4, block.getMetadataLength() - 4);
204185

205-
// Deserializes a record batch. Buffer should start at the RecordBatch and include
206-
// all the bytes for the metadata and then data buffers.
207-
private static ArrowRecordBatch deserializeRecordBatch(
208-
ArrowBuf buffer, long readPosition, int metadataLen, int bufferLen) {
209186
// Read the metadata.
210187
RecordBatch recordBatchFB =
211-
RecordBatch.getRootAsRecordBatch(buffer.nioBuffer().asReadOnlyBuffer());
188+
RecordBatch.getRootAsRecordBatch(metadataBuffer.nioBuffer().asReadOnlyBuffer());
212189

213-
int bufferOffset = metadataLen;
214-
readPosition += bufferOffset;
215-
if (readPosition % 8 != 0) {
216-
bufferOffset += (int)(8 - readPosition % 8);
217-
}
190+
// Now read the body
191+
final ArrowBuf body = buffer.slice(block.getMetadataLength(),
192+
(int) totalLen - block.getMetadataLength());
193+
ArrowRecordBatch result = deserializeRecordBatch(recordBatchFB, body);
218194

195+
metadataBuffer.release();
196+
buffer.release();
197+
198+
return result;
199+
}
200+
201+
// Deserializes a record batch given the Flatbuffer metadata and in-memory body
202+
private static ArrowRecordBatch deserializeRecordBatch(RecordBatch recordBatchFB,
203+
ArrowBuf body) {
219204
// Now read the body
220-
final ArrowBuf body = buffer.slice(bufferOffset, bufferLen - bufferOffset);
221205
int nodesLength = recordBatchFB.nodesLength();
222206
List<ArrowFieldNode> nodes = new ArrayList<>();
223207
for (int i = 0; i < nodesLength; ++i) {
@@ -232,43 +216,44 @@ private static ArrowRecordBatch deserializeRecordBatch(
232216
}
233217
ArrowRecordBatch arrowRecordBatch =
234218
new ArrowRecordBatch(recordBatchFB.length(), nodes, buffers);
235-
buffer.release();
219+
body.release();
236220
return arrowRecordBatch;
237221
}
238222

239223
/**
240224
* Serializes a message header.
241225
*/
242-
private static ByteBuffer serializeHeader(byte headerType, int bodyLength) {
243-
FlatBufferBuilder headerBuilder = new FlatBufferBuilder();
244-
Message.startMessage(headerBuilder);
245-
Message.addHeaderType(headerBuilder, headerType);
246-
Message.addVersion(headerBuilder, MetadataVersion.V1);
247-
Message.addBodyLength(headerBuilder, bodyLength);
248-
headerBuilder.finish(Message.endMessage(headerBuilder));
249-
return headerBuilder.dataBuffer();
226+
private static ByteBuffer serializeMessage(FlatBufferBuilder builder, byte headerType,
227+
int headerOffset, int bodyLength) {
228+
Message.startMessage(builder);
229+
Message.addHeaderType(builder, headerType);
230+
Message.addHeader(builder, headerOffset);
231+
Message.addVersion(builder, MetadataVersion.V1);
232+
Message.addBodyLength(builder, bodyLength);
233+
builder.finish(Message.endMessage(builder));
234+
return builder.dataBuffer();
250235
}
251236

252-
private static Message deserializeHeader(ReadChannel in, byte headerType) throws IOException {
253-
// Read the header size. There is an i32 little endian prefix.
237+
private static Message deserializeMessage(ReadChannel in, byte headerType) throws IOException {
238+
// Read the message size. There is an i32 little endian prefix.
254239
ByteBuffer buffer = ByteBuffer.allocate(4);
255240
if (in.readFully(buffer) != 4) {
256241
return null;
257242
}
258243

259-
int headerLength = bytesToInt(buffer.array());
260-
buffer = ByteBuffer.allocate(headerLength);
261-
if (in.readFully(buffer) != headerLength) {
244+
int messageLength = bytesToInt(buffer.array());
245+
buffer = ByteBuffer.allocate(messageLength);
246+
if (in.readFully(buffer) != messageLength) {
262247
throw new IOException(
263-
"Unexpected end of stream trying to read header.");
248+
"Unexpected end of stream trying to read message.");
264249
}
265250
buffer.rewind();
266251

267-
Message header = Message.getRootAsMessage(buffer);
268-
if (header.headerType() != headerType) {
252+
Message message = Message.getRootAsMessage(buffer);
253+
if (message.headerType() != headerType) {
269254
throw new IOException("Invalid message: expecting " + headerType +
270-
". Message contained: " + header.headerType());
255+
". Message contained: " + message.headerType());
271256
}
272-
return header;
257+
return message;
273258
}
274259
}

0 commit comments

Comments
 (0)