@@ -70,33 +70,24 @@ public static int bytesToInt(byte[] bytes) {
7070 */
7171 public static long serialize (WriteChannel out , Schema schema ) throws IOException {
7272 FlatBufferBuilder builder = new FlatBufferBuilder ();
73- builder .finish (schema .getSchema (builder ));
74- ByteBuffer serializedBody = builder .dataBuffer ();
75- ByteBuffer serializedHeader =
76- serializeHeader (MessageHeader .Schema , serializedBody .remaining ());
77-
78- long size = out .writeIntLittleEndian (serializedHeader .remaining ());
79- size += out .write (serializedHeader );
80- size += out .write (serializedBody );
73+ int schemaOffset = schema .getSchema (builder );
74+ ByteBuffer serializedMessage = serializeMessage (builder , MessageHeader .Schema , schemaOffset , 0 );
75+ long size = out .writeIntLittleEndian (serializedMessage .remaining ());
76+ size += out .write (serializedMessage );
8177 return size ;
8278 }
8379
8480 /**
8581 * Deserializes a schema object. Format is from serialize().
8682 */
8783 public static Schema deserializeSchema (ReadChannel in ) throws IOException {
88- Message header = deserializeHeader (in , MessageHeader .Schema );
89- if (header == null ) {
84+ Message message = deserializeMessage (in , MessageHeader .Schema );
85+ if (message == null ) {
9086 throw new IOException ("Unexpected end of input. Missing schema." );
9187 }
9288
93- // Now read the schema.
94- ByteBuffer buffer = ByteBuffer .allocate ((int )header .bodyLength ());
95- if (in .readFully (buffer ) != header .bodyLength ()) {
96- throw new IOException ("Unexpected end of input trying to read schema." );
97- }
98- buffer .rewind ();
99- return Schema .deserialize (buffer );
89+ return Schema .convertSchema ((org .apache .arrow .flatbuf .Schema )
90+ message .header (new org .apache .arrow .flatbuf .Schema ()));
10091 }
10192
10293 /**
@@ -106,33 +97,23 @@ public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch)
10697 throws IOException {
10798 long start = out .getCurrentPosition ();
10899 int bodyLength = batch .computeBodyLength ();
109- ByteBuffer metadata = WriteChannel .serialize (batch );
110-
111- int messageLength = 4 + metadata .remaining () + bodyLength ;
112- ByteBuffer serializedHeader =
113- serializeHeader (MessageHeader .RecordBatch , messageLength );
114-
115- // Compute the required alignment. This is not a great way to do it. The issue is
116- // that we need to know the message size to serialize the message header but the
117- // size depends on the alignment, which depends on the message header.
118- // This will serialize the header again with the updated size alignment adjusted.
119- // TODO: We really just want sizeof(MessageHeader) from the serializeHeader() above.
120- // Is there a way to do this?
121- long bufferOffset = start + 4 + serializedHeader .remaining () + 4 + metadata .remaining ();
122- if (bufferOffset % 8 != 0 ) {
123- messageLength += 8 - bufferOffset % 8 ;
124- serializedHeader = serializeHeader (MessageHeader .RecordBatch , messageLength );
125- }
126100
127- // Write message header.
128- out .writeIntLittleEndian (serializedHeader .remaining ());
129- out .write (serializedHeader );
101+ FlatBufferBuilder builder = new FlatBufferBuilder ();
102+ int batchOffset = batch .writeTo (builder );
103+
104+ ByteBuffer serializedMessage = serializeMessage (builder , MessageHeader .RecordBatch ,
105+ batchOffset , bodyLength );
106+
107+ int metadataLength = serializedMessage .remaining ();
130108
131- // Write batch header. with the 4 byte little endian prefix
132- out .writeIntLittleEndian (metadata .remaining ());
133- int metadataSize = metadata .remaining ();
134- long batchStart = out .getCurrentPosition ();
135- out .write (metadata );
109+ // Add extra padding bytes so that length prefix + metadata is a multiple
110+ // of 8 after alignment
111+ if ((metadataLength + 4 ) % 8 != 0 ) {
112+ metadataLength += 8 - (metadataLength + 4 ) % 8 ;
113+ }
114+
115+ out .writeIntLittleEndian (metadataLength );
116+ out .write (serializedMessage );
136117
137118 // Align the output to 8 byte boundary.
138119 out .align ();
@@ -154,31 +135,32 @@ public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch)
154135 " != " + startPosition + layout .getSize ());
155136 }
156137 }
157- return new ArrowBlock (batchStart , metadataSize , (int )(out .getCurrentPosition () - bufferStart ));
138+ // Metadata size in the Block accounts for the size prefix
139+ return new ArrowBlock (start , metadataLength + 4 , out .getCurrentPosition () - bufferStart );
158140 }
159141
160142 /**
161143 * Deserializes a RecordBatch
162144 */
163145 public static ArrowRecordBatch deserializeRecordBatch (ReadChannel in ,
164146 BufferAllocator alloc ) throws IOException {
165- Message header = deserializeHeader (in , MessageHeader .RecordBatch );
166- if (header == null ) return null ;
167-
168- int messageLen = (int )header .bodyLength ();
169- // Now read the buffer. This has the metadata followed by the data.
170- ArrowBuf buffer = alloc .buffer (messageLen );
171- long readPosition = in .getCurrentPositiion ();
172- if (in .readFully (buffer , messageLen ) != messageLen ) {
173- throw new IOException ("Unexpected end of input trying to read batch." );
147+ Message message = deserializeMessage (in , MessageHeader .RecordBatch );
148+ if (message == null ) return null ;
149+
150+ if (message .bodyLength () > Integer .MAX_VALUE ) {
151+ throw new IOException ("Cannot currently deserialize record batches over 2GB" );
174152 }
175153
176- // Read the length of the metadata.
177- int metadataLen = buffer .readInt ();
178- buffer = buffer .slice (4 , messageLen - 4 );
179- readPosition += 4 ;
180- messageLen -= 4 ;
181- return deserializeRecordBatch (buffer , readPosition , metadataLen , messageLen );
154+ RecordBatch recordBatchFB = (RecordBatch ) message .header (new RecordBatch ());
155+
156+ int bodyLength = (int ) message .bodyLength ();
157+
158+ // Now read the record batch body
159+ ArrowBuf buffer = alloc .buffer (bodyLength );
160+ if (in .readFully (buffer , bodyLength ) != bodyLength ) {
161+ throw new IOException ("Unexpected end of input trying to read batch." );
162+ }
163+ return deserializeRecordBatch (recordBatchFB , buffer );
182164 }
183165
184166 /**
@@ -187,37 +169,39 @@ public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in,
187169 */
188170 public static ArrowRecordBatch deserializeRecordBatch (ReadChannel in , ArrowBlock block ,
189171 BufferAllocator alloc ) throws IOException {
190- long readPosition = in . getCurrentPositiion ();
191- int totalLen = block .getMetadataLength () + block .getBodyLength ();
192- if (( readPosition + block . getMetadataLength ()) % 8 != 0 ) {
193- // Compute padded size.
194- totalLen += ( 8 - ( readPosition + block . getMetadataLength ()) % 8 );
172+ // Metadata length contains integer prefix plus byte padding
173+ long totalLen = block .getMetadataLength () + block .getBodyLength ();
174+
175+ if ( totalLen > Integer . MAX_VALUE ) {
176+ throw new IOException ( "Cannot currently deserialize record batches over 2GB" );
195177 }
196178
197- ArrowBuf buffer = alloc .buffer (totalLen );
198- if (in .readFully (buffer , totalLen ) != totalLen ) {
179+ ArrowBuf buffer = alloc .buffer (( int ) totalLen );
180+ if (in .readFully (buffer , ( int ) totalLen ) != totalLen ) {
199181 throw new IOException ("Unexpected end of input trying to read batch." );
200182 }
201183
202- return deserializeRecordBatch (buffer , readPosition , block .getMetadataLength (), totalLen );
203- }
184+ ArrowBuf metadataBuffer = buffer .slice (4 , block .getMetadataLength () - 4 );
204185
205- // Deserializes a record batch. Buffer should start at the RecordBatch and include
206- // all the bytes for the metadata and then data buffers.
207- private static ArrowRecordBatch deserializeRecordBatch (
208- ArrowBuf buffer , long readPosition , int metadataLen , int bufferLen ) {
209186 // Read the metadata.
210187 RecordBatch recordBatchFB =
211- RecordBatch .getRootAsRecordBatch (buffer .nioBuffer ().asReadOnlyBuffer ());
188+ RecordBatch .getRootAsRecordBatch (metadataBuffer .nioBuffer ().asReadOnlyBuffer ());
212189
213- int bufferOffset = metadataLen ;
214- readPosition += bufferOffset ;
215- if (readPosition % 8 != 0 ) {
216- bufferOffset += (int )(8 - readPosition % 8 );
217- }
190+ // Now read the body
191+ final ArrowBuf body = buffer .slice (block .getMetadataLength (),
192+ (int ) totalLen - block .getMetadataLength ());
193+ ArrowRecordBatch result = deserializeRecordBatch (recordBatchFB , body );
218194
195+ metadataBuffer .release ();
196+ buffer .release ();
197+
198+ return result ;
199+ }
200+
201+ // Deserializes a record batch given the Flatbuffer metadata and in-memory body
202+ private static ArrowRecordBatch deserializeRecordBatch (RecordBatch recordBatchFB ,
203+ ArrowBuf body ) {
219204 // Now read the body
220- final ArrowBuf body = buffer .slice (bufferOffset , bufferLen - bufferOffset );
221205 int nodesLength = recordBatchFB .nodesLength ();
222206 List <ArrowFieldNode > nodes = new ArrayList <>();
223207 for (int i = 0 ; i < nodesLength ; ++i ) {
@@ -232,43 +216,44 @@ private static ArrowRecordBatch deserializeRecordBatch(
232216 }
233217 ArrowRecordBatch arrowRecordBatch =
234218 new ArrowRecordBatch (recordBatchFB .length (), nodes , buffers );
235- buffer .release ();
219+ body .release ();
236220 return arrowRecordBatch ;
237221 }
238222
239223 /**
240224 * Serializes a message header.
241225 */
242- private static ByteBuffer serializeHeader (byte headerType , int bodyLength ) {
243- FlatBufferBuilder headerBuilder = new FlatBufferBuilder ();
244- Message .startMessage (headerBuilder );
245- Message .addHeaderType (headerBuilder , headerType );
246- Message .addVersion (headerBuilder , MetadataVersion .V1 );
247- Message .addBodyLength (headerBuilder , bodyLength );
248- headerBuilder .finish (Message .endMessage (headerBuilder ));
249- return headerBuilder .dataBuffer ();
226+ private static ByteBuffer serializeMessage (FlatBufferBuilder builder , byte headerType ,
227+ int headerOffset , int bodyLength ) {
228+ Message .startMessage (builder );
229+ Message .addHeaderType (builder , headerType );
230+ Message .addHeader (builder , headerOffset );
231+ Message .addVersion (builder , MetadataVersion .V1 );
232+ Message .addBodyLength (builder , bodyLength );
233+ builder .finish (Message .endMessage (builder ));
234+ return builder .dataBuffer ();
250235 }
251236
252- private static Message deserializeHeader (ReadChannel in , byte headerType ) throws IOException {
253- // Read the header size. There is an i32 little endian prefix.
237+ private static Message deserializeMessage (ReadChannel in , byte headerType ) throws IOException {
238+ // Read the message size. There is an i32 little endian prefix.
254239 ByteBuffer buffer = ByteBuffer .allocate (4 );
255240 if (in .readFully (buffer ) != 4 ) {
256241 return null ;
257242 }
258243
259- int headerLength = bytesToInt (buffer .array ());
260- buffer = ByteBuffer .allocate (headerLength );
261- if (in .readFully (buffer ) != headerLength ) {
244+ int messageLength = bytesToInt (buffer .array ());
245+ buffer = ByteBuffer .allocate (messageLength );
246+ if (in .readFully (buffer ) != messageLength ) {
262247 throw new IOException (
263- "Unexpected end of stream trying to read header ." );
248+ "Unexpected end of stream trying to read message ." );
264249 }
265250 buffer .rewind ();
266251
267- Message header = Message .getRootAsMessage (buffer );
268- if (header .headerType () != headerType ) {
252+ Message message = Message .getRootAsMessage (buffer );
253+ if (message .headerType () != headerType ) {
269254 throw new IOException ("Invalid message: expecting " + headerType +
270- ". Message contained: " + header .headerType ());
255+ ". Message contained: " + message .headerType ());
271256 }
272- return header ;
257+ return message ;
273258 }
274259}
0 commit comments