Skip to content

Commit 040d4d6

Browse files
authored
[avro] add avro support for vector type (#7449)
1 parent 6a53de8 commit 040d4d6

6 files changed

Lines changed: 83 additions & 4 deletions

File tree

paimon-core/src/test/java/org/apache/paimon/append/AppendOnlyWriterTest.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import org.apache.paimon.table.FileStoreTableFactory;
5353
import org.apache.paimon.types.BlobType;
5454
import org.apache.paimon.types.DataType;
55+
import org.apache.paimon.types.DataTypes;
5556
import org.apache.paimon.types.IntType;
5657
import org.apache.paimon.types.RowType;
5758
import org.apache.paimon.types.VarCharType;
@@ -578,7 +579,7 @@ public void testNonSpillable() throws Exception {
578579
writer.close();
579580
}
580581

581-
/* @Test // TODO this can be enabled after avro supports vector
582+
@Test
582583
public void testVectorStoreSameFormatUsesRowDataWriter() throws Exception {
583584
RowType vectorStoreSchema =
584585
RowType.builder()
@@ -599,7 +600,7 @@ public void testVectorStoreSameFormatUsesRowDataWriter() throws Exception {
599600
assertThat(increment.newFilesIncrement().newFiles()).hasSize(1);
600601
DataFileMeta meta = increment.newFilesIncrement().newFiles().get(0);
601602
assertThat(meta.fileName()).doesNotContain(".vector");
602-
} */
603+
}
603604

604605
private SimpleColStats initStats(Integer min, Integer max, long nullCount) {
605606
return new SimpleColStats(min, max, nullCount);

paimon-format/src/main/java/org/apache/paimon/format/avro/AvroSchemaConverter.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.apache.paimon.types.RowType;
3131
import org.apache.paimon.types.TimeType;
3232
import org.apache.paimon.types.TimestampType;
33+
import org.apache.paimon.types.VectorType;
3334

3435
import org.apache.avro.LogicalTypes;
3536
import org.apache.avro.Schema;
@@ -260,8 +261,11 @@ public static Schema convertToSchema(
260261
}
261262
return nullable ? nullableSchema(map) : map;
262263
case ARRAY:
263-
ArrayType arrayType = (ArrayType) dataType;
264-
DataType elementType = arrayType.getElementType();
264+
case VECTOR:
265+
DataType elementType =
266+
dataType.getTypeRoot() == DataTypeRoot.ARRAY
267+
? ((ArrayType) dataType).getElementType()
268+
: ((VectorType) dataType).getElementType();
265269

266270
ArrayBuilder<Schema> arrayBuilder = SchemaBuilder.builder().array();
267271

paimon-format/src/main/java/org/apache/paimon/format/avro/AvroSchemaVisitor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.paimon.types.DataTypes;
2525
import org.apache.paimon.types.MapType;
2626
import org.apache.paimon.types.RowType;
27+
import org.apache.paimon.types.VectorType;
2728

2829
import org.apache.avro.LogicalType;
2930
import org.apache.avro.LogicalTypes;
@@ -53,6 +54,8 @@ default T visit(Schema schema, DataType type) {
5354
if (type instanceof MapType) {
5455
MapType mapType = (MapType) type;
5556
return visitArrayMap(schema, mapType.getKeyType(), mapType.getValueType());
57+
} else if (type instanceof VectorType) {
58+
return visitArrayVector(schema, ((VectorType) type).getElementType());
5659
} else {
5760
return visitArray(
5861
schema, type == null ? null : ((ArrayType) type).getElementType());
@@ -155,6 +158,8 @@ default T primitive(Schema primitive, DataType type) {
155158

156159
T visitArray(Schema schema, DataType elementType);
157160

161+
T visitArrayVector(Schema schema, DataType elementType);
162+
158163
T visitArrayMap(Schema schema, DataType keyType, DataType valueType);
159164

160165
T visitMap(Schema schema, DataType valueType);

paimon-format/src/main/java/org/apache/paimon/format/avro/FieldReaderFactory.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package org.apache.paimon.format.avro;
2020

2121
import org.apache.paimon.data.BinaryString;
22+
import org.apache.paimon.data.BinaryVector;
2223
import org.apache.paimon.data.Blob;
2324
import org.apache.paimon.data.BlobDescriptor;
2425
import org.apache.paimon.data.Decimal;
@@ -167,6 +168,12 @@ public FieldReader visitArray(Schema schema, @Nullable DataType elementType) {
167168
return new ArrayReader(elementReader);
168169
}
169170

171+
/**
 * Builds the reader used when an Avro array schema is paired with a vector-typed field.
 *
 * <p>The element reader is resolved by visiting the schema's element type together with
 * {@code elementType}; it is then wrapped in an {@code ArrayVectorReader}, which also
 * receives {@code elementType}.
 *
 * @param schema Avro schema whose {@code getElementType()} drives the element reader
 * @param elementType element type of the vector, forwarded unchanged to the reader
 * @return a new {@code ArrayVectorReader}
 */
@Override
172+
public FieldReader visitArrayVector(Schema schema, @Nullable DataType elementType) {
173+
FieldReader elementReader = visit(schema.getElementType(), elementType);
174+
// NOTE(review): elementType is declared @Nullable but is passed on without a check;
// confirm that BinaryVector.fromInternalArray accepts a null element type.
return new ArrayVectorReader(elementReader, elementType);
175+
}
176+
170177
@Override
171178
public FieldReader visitArrayMap(Schema schema, DataType keyType, DataType valueType) {
172179
RowReader entryReader =
@@ -461,6 +468,22 @@ public void skip(Decoder decoder) throws IOException {
461468
}
462469
}
463470

471+
/**
 * Field reader for vector values. Decoding is delegated to the inherited {@code ArrayReader}
 * logic; the decoded array is then converted with {@code BinaryVector.fromInternalArray}.
 */
private static class ArrayVectorReader extends ArrayReader {
472+
473+
// Element type handed to BinaryVector.fromInternalArray on every read.
private final DataType elementType;
474+
475+
private ArrayVectorReader(FieldReader elementReader, DataType elementType) {
476+
super(elementReader);
477+
this.elementType = elementType;
478+
}
479+
480+
/**
 * Reads one value through {@code ArrayReader.read} and converts the result to a vector.
 *
 * @param decoder decoder passed through to the parent reader
 * @param reuse reuse object passed through to the parent reader unchanged
 * @return the result of {@code BinaryVector.fromInternalArray(array, elementType)}
 * @throws IOException declared by the signature; presumably propagated from the parent read
 */
@Override
481+
public Object read(Decoder decoder, Object reuse) throws IOException {
482+
// NOTE(review): assumes ArrayReader.read always returns a GenericArray, and that it
// tolerates whatever callers pass as reuse — verify against ArrayReader.
GenericArray array = (GenericArray) super.read(decoder, reuse);
483+
return BinaryVector.fromInternalArray(array, elementType);
484+
}
485+
}
486+
464487
private static class ArrayMapReader implements FieldReader {
465488

466489
private final RowReader entryReader;

paimon-format/src/main/java/org/apache/paimon/format/avro/FieldWriterFactory.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.apache.paimon.data.InternalArray;
2828
import org.apache.paimon.data.InternalMap;
2929
import org.apache.paimon.data.InternalRow;
30+
import org.apache.paimon.data.InternalVector;
3031
import org.apache.paimon.types.DataField;
3132
import org.apache.paimon.types.DataType;
3233
import org.apache.paimon.types.DataTypeRoot;
@@ -197,6 +198,22 @@ public FieldWriter visitArray(Schema schema, DataType elementType) {
197198
};
198199
}
199200

201+
/**
 * Builds the writer used when an Avro array schema is paired with a vector-typed field.
 *
 * <p>The returned writer fetches the vector at {@code index} from the container and emits it
 * through the encoder's array calls: array start, item count equal to the vector size, one
 * {@code startItem()} plus element write per position, then array end.
 *
 * @param schema Avro schema whose {@code getElementType()} drives the element writer
 * @param elementType element type of the vector, used to resolve the element writer
 * @return a writer that encodes the vector element by element
 */
@Override
202+
public FieldWriter visitArrayVector(Schema schema, DataType elementType) {
203+
FieldWriter elementWriter = visit(schema.getElementType(), elementType);
204+
return (container, index, encoder) -> {
205+
InternalVector vector = container.getVector(index);
206+
encoder.writeArrayStart();
207+
int numElements = vector.size();
208+
// The item count is announced once up front, before any item is written.
encoder.setItemCount(numElements);
209+
for (int i = 0; i < numElements; i += 1) {
210+
encoder.startItem();
211+
// The vector itself is handed to the element writer as the container to read from.
elementWriter.write(vector, i, encoder);
212+
}
213+
encoder.writeArrayEnd();
214+
};
215+
}
216+
200217
@Override
201218
public FieldWriter visitArrayMap(Schema schema, DataType keyType, DataType valueType) {
202219
RowWriter entryWriter =

paimon-format/src/test/java/org/apache/paimon/format/avro/AvroFormatReadWriteTest.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,18 @@
1818

1919
package org.apache.paimon.format.avro;
2020

21+
import org.apache.paimon.data.BinaryVector;
22+
import org.apache.paimon.data.GenericRow;
2123
import org.apache.paimon.format.FileFormat;
2224
import org.apache.paimon.format.FileFormatFactory;
2325
import org.apache.paimon.format.FormatReadWriteTest;
2426
import org.apache.paimon.options.Options;
27+
import org.apache.paimon.types.DataField;
28+
import org.apache.paimon.types.DataTypes;
29+
import org.apache.paimon.types.RowType;
30+
31+
import java.util.ArrayList;
32+
import java.util.List;
2533

2634
/** An avro {@link FormatReadWriteTest}. */
2735
public class AvroFormatReadWriteTest extends FormatReadWriteTest {
@@ -34,4 +42,25 @@ protected AvroFormatReadWriteTest() {
3442
protected FileFormat fileFormat() {
3543
return new AvroFileFormat(new FileFormatFactory.FormatContext(new Options(), 1024, 1024));
3644
}
45+
46+
/**
 * Extends the parent's full-types row type with one extra vector field.
 *
 * <p>The new field is named {@code "embed"}, has type {@code VECTOR(3, FLOAT)}, and gets an id
 * one greater than the largest id among the parent's fields. Nullability of the row type is
 * copied from the parent's row type.
 *
 * @return the parent's row type plus the appended vector field
 */
@Override
47+
protected RowType rowTypeForFullTypesTest() {
48+
RowType rowWithoutVector = super.rowTypeForFullTypesTest();
49+
// Copy the field list so the new field is added to a fresh list rather than the parent's.
List<DataField> fields = new ArrayList<>(rowWithoutVector.getFields());
50+
// Optional.get() throws NoSuchElementException if the parent row type has no fields.
int vectorFieldId = fields.stream().map(DataField::id).max(Integer::compare).get() + 1;
51+
fields.add(new DataField(vectorFieldId, "embed", DataTypes.VECTOR(3, DataTypes.FLOAT())));
52+
return new RowType(rowWithoutVector.isNullable(), fields);
53+
}
54+
55+
/**
 * Builds the expected row for the full-types test with a trailing vector value.
 *
 * <p>Every field of the parent's expected row is copied into a new row that is one field
 * longer; the last position holds a {@code BinaryVector} built from {@code {1.0f, 2.0f, 3.0f}}.
 *
 * @return the parent's expected row plus the appended vector value
 */
@Override
56+
protected GenericRow expectedRowForFullTypesTest() {
57+
// Three elements, matching the length 3 passed to DataTypes.VECTOR in rowTypeForFullTypesTest.
float[] vector = new float[] {1.0f, 2.0f, 3.0f};
58+
GenericRow rowWithoutVector = super.expectedRowForFullTypesTest();
59+
GenericRow row = new GenericRow(rowWithoutVector.getFieldCount() + 1);
60+
for (int i = 0; i < rowWithoutVector.getFieldCount(); ++i) {
61+
row.setField(i, rowWithoutVector.getField(i));
62+
}
63+
// The vector occupies the slot right after the last copied field.
row.setField(rowWithoutVector.getFieldCount(), BinaryVector.fromPrimitiveArray(vector));
64+
return row;
65+
}
3766
}

0 commit comments

Comments
 (0)