From 4c754b1d843aca16356d7702b4868eb02718b83e Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Sat, 18 Apr 2026 09:53:55 +0800 Subject: [PATCH] fix --- .../compression/ZstdCompressionCodec.java | 2 +- .../compression/TestCompressionCodec.java | 252 ++++++++++++++++++ .../compression/AbstractCompressionCodec.java | 10 +- 3 files changed, 261 insertions(+), 3 deletions(-) diff --git a/compression/src/main/java/org/apache/arrow/compression/ZstdCompressionCodec.java b/compression/src/main/java/org/apache/arrow/compression/ZstdCompressionCodec.java index 290723608d..ed46fe81b4 100644 --- a/compression/src/main/java/org/apache/arrow/compression/ZstdCompressionCodec.java +++ b/compression/src/main/java/org/apache/arrow/compression/ZstdCompressionCodec.java @@ -44,7 +44,7 @@ protected ArrowBuf doCompress(BufferAllocator allocator, ArrowBuf uncompressedBu long bytesWritten = Zstd.compressUnsafe( compressedBuffer.memoryAddress() + CompressionUtil.SIZE_OF_UNCOMPRESSED_LENGTH, - dstSize, + maxSize, /*src*/ uncompressedBuffer.memoryAddress(), /* srcSize= */ uncompressedBuffer.writerIndex(), /* level= */ this.compressionLevel); diff --git a/compression/src/test/java/org/apache/arrow/compression/TestCompressionCodec.java b/compression/src/test/java/org/apache/arrow/compression/TestCompressionCodec.java index b8fb4e28b9..804f518ece 100644 --- a/compression/src/test/java/org/apache/arrow/compression/TestCompressionCodec.java +++ b/compression/src/test/java/org/apache/arrow/compression/TestCompressionCodec.java @@ -38,7 +38,9 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.TimeStampMilliVector; import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; @@ -53,12 +55,15 @@ import 
org.apache.arrow.vector.ipc.ArrowStreamWriter; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -347,6 +352,253 @@ void testUnloadCompressed(CompressionUtil.CodecType codec) { }); } + /** + * Test multi-batch streaming with ZSTD compression, wide schema, VectorSchemaRoot reuse, and + * all-null columns. This reproduces the scenario from GH-1116 where the 8-byte + * uncompressed-length prefix of a compressed buffer could be incorrectly written as 0. 
+   */
+  @Test
+  void testMultiBatchZstdStreamWithWideSchemaAndAllNulls() throws Exception {
+    final int fieldCount = 100;
+    final int batchCount = 10;
+    final int rowsPerBatch = 500;
+
+    // Build a wide schema: mix of int, timestamp, and varchar fields
+    List<Field> fields = new ArrayList<>();
+    for (int i = 0; i < fieldCount; i++) {
+      switch (i % 3) {
+        case 0:
+          fields.add(Field.nullable("int_" + i, new ArrowType.Int(32, true)));
+          break;
+        case 1:
+          fields.add(
+              Field.nullable("ts_" + i, new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)));
+          break;
+        case 2:
+          fields.add(Field.nullable("str_" + i, ArrowType.Utf8.INSTANCE));
+          break;
+        default:
+          break;
+      }
+    }
+    Schema schema = new Schema(fields);
+
+    ByteArrayOutputStream out = new ByteArrayOutputStream();
+    try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
+        ArrowStreamWriter writer =
+            new ArrowStreamWriter(
+                root,
+                new DictionaryProvider.MapDictionaryProvider(),
+                Channels.newChannel(out),
+                IpcOption.DEFAULT,
+                CommonsCompressionFactory.INSTANCE,
+                CompressionUtil.CodecType.ZSTD)) {
+      writer.start();
+
+      for (int batch = 0; batch < batchCount; batch++) {
+        // Clear and reallocate — mimics the reporter's reuse pattern
+        root.clear();
+        for (FieldVector vector : root.getFieldVectors()) {
+          vector.allocateNew();
+        }
+        root.setRowCount(rowsPerBatch);
+
+        for (int col = 0; col < fieldCount; col++) {
+          FieldVector vector = root.getVector(col);
+          // Make some batches have all-null columns for certain fields
+          boolean allNull = (batch % 3 == 0) && (col % 3 == 1); // timestamps in every 3rd batch
+          switch (col % 3) {
+            case 0:
+              {
+                IntVector iv = (IntVector) vector;
+                for (int row = 0; row < rowsPerBatch; row++) {
+                  if (allNull || row % 7 == 0) {
+                    iv.setNull(row);
+                  } else {
+                    iv.setSafe(row, batch * rowsPerBatch + row);
+                  }
+                }
+                break;
+              }
+            case 1:
+              {
+                TimeStampMilliVector tv = (TimeStampMilliVector) vector;
+                for (int row = 0; row < rowsPerBatch; row++) {
+                  if (allNull || row % 5 == 0) {
+                    tv.setNull(row);
+                  } else {
+                    tv.setSafe(row, 1_700_000_000_000L + (long) batch * rowsPerBatch + row);
+                  }
+                }
+                break;
+              }
+            case 2:
+              {
+                VarCharVector sv = (VarCharVector) vector;
+                for (int row = 0; row < rowsPerBatch; row++) {
+                  if (allNull || row % 9 == 0) {
+                    sv.setNull(row);
+                  } else {
+                    sv.setSafe(row, ("val_" + batch + "_" + row).getBytes(StandardCharsets.UTF_8));
+                  }
+                }
+                break;
+              }
+            default:
+              break;
+          }
+          vector.setValueCount(rowsPerBatch);
+        }
+
+        writer.writeBatch();
+      }
+      writer.end();
+    }
+
+    // Read back and verify all batches round-trip correctly
+    try (ArrowStreamReader reader =
+        new ArrowStreamReader(
+            new ByteArrayReadableSeekableByteChannel(out.toByteArray()),
+            allocator,
+            CommonsCompressionFactory.INSTANCE)) {
+      int batchesRead = 0;
+      while (reader.loadNextBatch()) {
+        VectorSchemaRoot readRoot = reader.getVectorSchemaRoot();
+        assertEquals(rowsPerBatch, readRoot.getRowCount());
+        assertEquals(fieldCount, readRoot.getFieldVectors().size());
+
+        // Verify data values, null patterns, and all-null columns
+        for (int col = 0; col < fieldCount; col++) {
+          FieldVector vector = readRoot.getVector(col);
+          boolean allNull =
+              (batchesRead % 3 == 0) && (col % 3 == 1); // timestamps in every 3rd batch
+          if (allNull) {
+            // The key scenario: all-null columns must survive compression round-trip
+            assertEquals(
+                rowsPerBatch,
+                vector.getNullCount(),
+                "All-null column col=" + col + " batch=" + batchesRead);
+          }
+          for (int row = 0; row < rowsPerBatch; row++) {
+            switch (col % 3) {
+              case 0:
+                {
+                  IntVector iv = (IntVector) vector;
+                  if (allNull || row % 7 == 0) {
+                    assertTrue(
+                        iv.isNull(row),
+                        "Expected null at col=" + col + " row=" + row + " batch=" + batchesRead);
+                  } else {
+                    assertEquals(
+                        batchesRead * rowsPerBatch + row,
+                        iv.get(row),
+                        "Value mismatch at col=" + col + " row=" + row + " batch=" + batchesRead);
+                  }
+                  break;
+                }
+              case 1:
+                {
+                  TimeStampMilliVector tv = (TimeStampMilliVector) vector;
+                  if (allNull || row % 5 == 0) {
+                    assertTrue(
+                        tv.isNull(row),
+                        "Expected null at col=" + col + " row=" + row + " batch=" + batchesRead);
+                  } else {
+                    assertEquals(
+                        1_700_000_000_000L + (long) batchesRead * rowsPerBatch + row,
+                        tv.get(row),
+                        "Value mismatch at col=" + col + " row=" + row + " batch=" + batchesRead);
+                  }
+                  break;
+                }
+              case 2:
+                {
+                  VarCharVector sv = (VarCharVector) vector;
+                  if (allNull || row % 9 == 0) {
+                    assertTrue(
+                        sv.isNull(row),
+                        "Expected null at col=" + col + " row=" + row + " batch=" + batchesRead);
+                  } else {
+                    assertArrayEquals(
+                        ("val_" + batchesRead + "_" + row).getBytes(StandardCharsets.UTF_8),
+                        sv.get(row),
+                        "Value mismatch at col=" + col + " row=" + row + " batch=" + batchesRead);
+                  }
+                  break;
+                }
+              default:
+                break;
+            }
+          }
+        }
+        batchesRead++;
+      }
+      assertEquals(batchCount, batchesRead);
+    }
+  }
+
+  /**
+   * Test that an all-null fixed-width vector compresses and decompresses correctly. The data buffer
+   * for such a vector contains all zeros but has a non-zero writerIndex (valueCount * typeWidth).
+   * The compressed buffer's uncompressed-length prefix must reflect this non-zero size.
+   */
+  @Test
+  void testAllNullFixedWidthVectorZstdRoundTrip() throws Exception {
+    final int rowCount = 3469; // same count as the reported issue
+    final CompressionCodec codec = new ZstdCompressionCodec();
+
+    try (TimeStampMilliVector origVec =
+        new TimeStampMilliVector(
+            "ts",
+            FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)),
+            allocator)) {
+      origVec.allocateNew(rowCount);
+      // Set all values to null
+      for (int i = 0; i < rowCount; i++) {
+        origVec.setNull(i);
+      }
+      origVec.setValueCount(rowCount);
+
+      assertEquals(rowCount, origVec.getNullCount());
+
+      // Compress and decompress each buffer
+      List<ArrowBuf> origBuffers = origVec.getFieldBuffers();
+      assertEquals(2, origBuffers.size());
+
+      // The data buffer (index 1) should have non-zero writerIndex even though all values are null
+      ArrowBuf dataBuffer = origBuffers.get(1);
+      long expectedDataSize = (long) rowCount * 8; // TimestampMilli = 8 bytes per value
+      assertEquals(expectedDataSize, dataBuffer.writerIndex());
+
+      // Retain buffers before compressing since compress() closes the input buffer.
+      // This mirrors what VectorUnloader.appendNodes() does.
+      for (ArrowBuf buf : origBuffers) {
+        buf.getReferenceManager().retain();
+      }
+      List<ArrowBuf> compressedBuffers = compressBuffers(codec, origBuffers);
+      List<ArrowBuf> decompressedBuffers = deCompressBuffers(codec, compressedBuffers);
+
+      assertEquals(2, decompressedBuffers.size());
+
+      // The decompressed data buffer should have the same writerIndex as the original
+      assertEquals(expectedDataSize, decompressedBuffers.get(1).writerIndex());
+
+      // Load into a new vector and verify
+      try (TimeStampMilliVector newVec =
+          new TimeStampMilliVector(
+              "ts_new",
+              FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)),
+              allocator)) {
+        newVec.loadFieldBuffers(new ArrowFieldNode(rowCount, rowCount), decompressedBuffers);
+        assertEquals(rowCount, newVec.getValueCount());
+        for (int i = 0; i < rowCount; i++) {
+          assertTrue(newVec.isNull(i));
+        }
+      }
+      AutoCloseables.close(decompressedBuffers);
+    }
+  }
+
   void withRoot(
       CompressionUtil.CodecType codec, BiConsumer testBody) {
diff --git a/vector/src/main/java/org/apache/arrow/vector/compression/AbstractCompressionCodec.java b/vector/src/main/java/org/apache/arrow/vector/compression/AbstractCompressionCodec.java
index 58d9e4db9b..c051d06a2c 100644
--- a/vector/src/main/java/org/apache/arrow/vector/compression/AbstractCompressionCodec.java
+++ b/vector/src/main/java/org/apache/arrow/vector/compression/AbstractCompressionCodec.java
@@ -29,7 +29,14 @@ public abstract class AbstractCompressionCodec implements CompressionCodec {
 
   @Override
   public ArrowBuf compress(BufferAllocator allocator, ArrowBuf uncompressedBuffer) {
-    if (uncompressedBuffer.writerIndex() == 0L) {
+    // Capture the uncompressed length once upfront to avoid any inconsistency from
+    // re-reading writerIndex() at different points. Since the uncompressedBuffer may be
+    // a shared reference to a vector's internal buffer, reading writerIndex() only once
+    // ensures the same value is used for the empty-buffer check, compression, size
+    // comparison, and the 8-byte uncompressed-length prefix.
+    long uncompressedLength = uncompressedBuffer.writerIndex();
+
+    if (uncompressedLength == 0L) {
       // shortcut for empty buffer
       ArrowBuf compressedBuffer = allocator.buffer(CompressionUtil.SIZE_OF_UNCOMPRESSED_LENGTH);
       compressedBuffer.setLong(0, 0);
@@ -41,7 +48,6 @@ public ArrowBuf compress(BufferAllocator allocator, ArrowBuf uncompressedBuffer)
     ArrowBuf compressedBuffer = doCompress(allocator, uncompressedBuffer);
     long compressedLength =
         compressedBuffer.writerIndex() - CompressionUtil.SIZE_OF_UNCOMPRESSED_LENGTH;
-    long uncompressedLength = uncompressedBuffer.writerIndex();
 
     if (compressedLength > uncompressedLength) {
       // compressed buffer is larger, send the raw buffer