From 25f72df87779a4a567eae7b2e5d90af4cf94bcbb Mon Sep 17 00:00:00 2001 From: Ivan Chesnov Date: Tue, 7 Oct 2025 14:27:53 +0300 Subject: [PATCH 1/2] GH-810: added support of ExtensionType for UnionVector --- .../src/main/codegen/includes/vv_imports.ftl | 2 + .../templates/AbstractFieldReader.java | 2 +- .../templates/AbstractFieldWriter.java | 4 +- .../main/codegen/templates/BaseReader.java | 6 +-- .../main/codegen/templates/BaseWriter.java | 4 +- .../main/codegen/templates/ComplexCopier.java | 10 +++-- .../main/codegen/templates/NullReader.java | 2 +- .../codegen/templates/PromotableWriter.java | 15 +++---- .../codegen/templates/UnionListWriter.java | 10 +++-- .../main/codegen/templates/UnionReader.java | 30 +++++++++++++- .../main/codegen/templates/UnionVector.java | 37 +++++++++++++++++ .../main/codegen/templates/UnionWriter.java | 37 +++++++++++++++-- .../apache/arrow/vector/BaseValueVector.java | 6 +-- .../arrow/vector/ExtensionTypeVector.java | 3 ++ .../org/apache/arrow/vector/NullVector.java | 6 +-- .../org/apache/arrow/vector/ValueVector.java | 7 ++-- .../complex/AbstractContainerVector.java | 6 +-- .../BaseLargeRepeatedValueViewVector.java | 1 + .../arrow/vector/complex/LargeListVector.java | 6 +-- .../vector/complex/LargeListViewVector.java | 6 +-- .../arrow/vector/complex/ListVector.java | 6 +-- .../arrow/vector/complex/ListViewVector.java | 6 +-- .../complex/impl/AbstractBaseReader.java | 4 +- ...Factory.java => ExtensionTypeFactory.java} | 19 +++++---- .../complex/impl/UnionExtensionWriter.java | 8 ++-- .../complex/impl/UnionLargeListReader.java | 2 +- .../arrow/vector/extension/OpaqueVector.java | 6 +++ .../apache/arrow/vector/TestListVector.java | 35 ++++++++-------- .../apache/arrow/vector/TestMapVector.java | 27 ++++++------ .../org/apache/arrow/vector/UuidVector.java | 7 ++++ .../complex/impl/TestComplexCopier.java | 41 ++++++++++--------- .../complex/impl/TestPromotableWriter.java | 18 ++++---- ...uidWriterFactory.java => UuidFactory.java} | 30 +++++++++++++- .../vector/complex/impl/UuidWriterImpl.java | 3 +- .../complex/writer/TestComplexWriter.java | 37 ++++++++++++----- .../vector/types/pojo/TestExtensionType.java | 6 +++ 36 files changed, 317 insertions(+), 138 deletions(-) rename vector/src/main/java/org/apache/arrow/vector/complex/impl/{ExtensionTypeWriterFactory.java => ExtensionTypeFactory.java} (72%) rename vector/src/test/java/org/apache/arrow/vector/complex/impl/{UuidWriterFactory.java => UuidFactory.java} (54%) diff --git a/vector/src/main/codegen/includes/vv_imports.ftl b/vector/src/main/codegen/includes/vv_imports.ftl index 2bbcecc85..25cfa5802 100644 --- a/vector/src/main/codegen/includes/vv_imports.ftl +++ b/vector/src/main/codegen/includes/vv_imports.ftl @@ -43,6 +43,8 @@ import org.apache.arrow.vector.util.JsonStringArrayList; import java.util.Arrays; import java.util.Random; import java.util.List; +import java.util.HashMap; +import java.util.Map; import java.io.Closeable; import java.io.InputStream; diff --git a/vector/src/main/codegen/templates/AbstractFieldReader.java b/vector/src/main/codegen/templates/AbstractFieldReader.java index c7c5b4d78..5f03a3ddd 100644 --- a/vector/src/main/codegen/templates/AbstractFieldReader.java +++ b/vector/src/main/codegen/templates/AbstractFieldReader.java @@ -109,7 +109,7 @@ public void copyAsField(String name, ${name}Writer writer) { - public void copyAsValue(StructWriter writer, ExtensionTypeWriterFactory writerFactory) { + public void copyAsValue(StructWriter writer, ExtensionTypeFactory writerFactory) { fail("CopyAsValue StructWriter"); } diff --git a/vector/src/main/codegen/templates/AbstractFieldWriter.java b/vector/src/main/codegen/templates/AbstractFieldWriter.java index ae5b97fae..6af10cba3 100644 --- a/vector/src/main/codegen/templates/AbstractFieldWriter.java +++ b/vector/src/main/codegen/templates/AbstractFieldWriter.java @@ -110,10 +110,10 @@ public void endEntry() { public void write(ExtensionHolder var1) { this.fail("ExtensionType"); } - public void writeExtension(Object var1) { + public void writeExtension(Object var1, ExtensionType var2) { this.fail("ExtensionType"); } - public void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory var1) { + public void addExtensionTypeWriterFactory(ExtensionTypeFactory var1, ExtensionType var2) { this.fail("ExtensionType"); } diff --git a/vector/src/main/codegen/templates/BaseReader.java b/vector/src/main/codegen/templates/BaseReader.java index 4c6f49ab9..4c2b032d7 100644 --- a/vector/src/main/codegen/templates/BaseReader.java +++ b/vector/src/main/codegen/templates/BaseReader.java @@ -49,7 +49,7 @@ public interface RepeatedStructReader extends StructReader{ boolean next(); int size(); void copyAsValue(StructWriter writer); - void copyAsValue(StructWriter writer, ExtensionTypeWriterFactory writerFactory); + void copyAsValue(StructWriter writer, ExtensionTypeFactory writerFactory); } public interface ListReader extends BaseReader{ @@ -60,7 +60,7 @@ public interface RepeatedListReader extends ListReader{ boolean next(); int size(); void copyAsValue(ListWriter writer); - void copyAsValue(ListWriter writer, ExtensionTypeWriterFactory writerFactory); + void copyAsValue(ListWriter writer, ExtensionTypeFactory writerFactory); } public interface MapReader extends BaseReader{ @@ -71,7 +71,7 @@ public interface RepeatedMapReader extends MapReader{ boolean next(); int size(); void copyAsValue(MapWriter writer); - void copyAsValue(MapWriter writer, ExtensionTypeWriterFactory writerFactory); + void copyAsValue(MapWriter writer, ExtensionTypeFactory writerFactory); } public interface ScalarReader extends diff --git a/vector/src/main/codegen/templates/BaseWriter.java b/vector/src/main/codegen/templates/BaseWriter.java index 78da7fddc..cb39973be 100644 --- a/vector/src/main/codegen/templates/BaseWriter.java +++ b/vector/src/main/codegen/templates/BaseWriter.java @@ -122,14 +122,14 @@ public interface ExtensionWriter extends BaseWriter { * * @param value the extension type value to write */ - void writeExtension(Object value); + void writeExtension(Object value, ExtensionType extensionType); /** * Adds the given extension type factory. This factory allows configuring writer implementations for specific ExtensionTypeVector. * * @param factory the extension type factory to add */ - void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory factory); + void addExtensionTypeWriterFactory(ExtensionTypeFactory factory, ExtensionType extensionType); } public interface ScalarWriter extends diff --git a/vector/src/main/codegen/templates/ComplexCopier.java b/vector/src/main/codegen/templates/ComplexCopier.java index 4df5478f4..dd0b6a1df 100644 --- a/vector/src/main/codegen/templates/ComplexCopier.java +++ b/vector/src/main/codegen/templates/ComplexCopier.java @@ -19,6 +19,7 @@ import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.complex.writer.FieldWriter; import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType; <@pp.dropOutputFile /> <@pp.changeOutputFile name="/org/apache/arrow/vector/complex/impl/ComplexCopier.java" /> @@ -45,11 +46,11 @@ public static void copy(FieldReader input, FieldWriter output) { writeValue(input, output, null); } - public static void copy(FieldReader input, FieldWriter output, ExtensionTypeWriterFactory extensionTypeWriterFactory) { + public static void copy(FieldReader input, FieldWriter output, ExtensionTypeFactory extensionTypeWriterFactory) { writeValue(input, output, extensionTypeWriterFactory); } - private static void writeValue(FieldReader reader, FieldWriter writer, ExtensionTypeWriterFactory extensionTypeWriterFactory) { + private static void writeValue(FieldReader reader, FieldWriter writer, ExtensionTypeFactory extensionTypeWriterFactory) { final MinorType mt = reader.getMinorType(); switch (mt) { @@ -120,9 +121,10 @@ private static void writeValue(FieldReader reader, FieldWriter writer, Extension } if (reader.isSet()) { Object value = reader.readObject(); + ExtensionType extensionType = (ExtensionType) reader.getField().getType(); if (value != null) { - writer.addExtensionTypeWriterFactory(extensionTypeWriterFactory); - writer.writeExtension(value); + writer.addExtensionTypeWriterFactory(extensionTypeWriterFactory, extensionType); + writer.writeExtension(value, extensionType); } } else { writer.writeNull(); diff --git a/vector/src/main/codegen/templates/NullReader.java b/vector/src/main/codegen/templates/NullReader.java index 052963347..2b55993ee 100644 --- a/vector/src/main/codegen/templates/NullReader.java +++ b/vector/src/main/codegen/templates/NullReader.java @@ -86,7 +86,7 @@ public void read(int arrayIndex, Nullable${name}Holder holder){ } - public void copyAsValue(StructWriter writer, ExtensionTypeWriterFactory writerFactory){} + public void copyAsValue(StructWriter writer, ExtensionTypeFactory writerFactory){} public void read(ExtensionHolder holder) { holder.isSet = 0; } diff --git a/vector/src/main/codegen/templates/PromotableWriter.java b/vector/src/main/codegen/templates/PromotableWriter.java index d22eb00b2..3c422951d 100644 --- a/vector/src/main/codegen/templates/PromotableWriter.java +++ b/vector/src/main/codegen/templates/PromotableWriter.java @@ -15,7 +15,8 @@ * limitations under the License. */ -<@pp.dropOutputFile /> +import org.apache.arrow.vector.complex.impl.ExtensionTypeFactory; +import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType;<@pp.dropOutputFile /> <@pp.changeOutputFile name="/org/apache/arrow/vector/complex/impl/PromotableWriter.java" /> <#include "/@includes/license.ftl" /> @@ -541,17 +542,13 @@ public void writeLargeVarChar(String value) { } @Override - public void writeExtension(Object value) { - getWriter(MinorType.EXTENSIONTYPE).writeExtension(value); + public void writeExtension(Object value, ExtensionType extensionType) { + getWriter(MinorType.EXTENSIONTYPE, extensionType).writeExtension(value, extensionType); } @Override - public void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory factory) { - getWriter(MinorType.EXTENSIONTYPE).addExtensionTypeWriterFactory(factory); - } - - public void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory factory, ArrowType arrowType) { - getWriter(MinorType.EXTENSIONTYPE, arrowType).addExtensionTypeWriterFactory(factory); + public void addExtensionTypeWriterFactory(ExtensionTypeFactory var1, ExtensionType var2) { + getWriter(MinorType.EXTENSIONTYPE, var2).addExtensionTypeWriterFactory(var1, var2); } @Override diff --git a/vector/src/main/codegen/templates/UnionListWriter.java b/vector/src/main/codegen/templates/UnionListWriter.java index 3c41ac72b..88af1059b 100644 --- a/vector/src/main/codegen/templates/UnionListWriter.java +++ b/vector/src/main/codegen/templates/UnionListWriter.java @@ -16,6 +16,7 @@ */ import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.vector.complex.impl.ExtensionTypeFactory; import org.apache.arrow.vector.complex.writer.Decimal256Writer; import org.apache.arrow.vector.complex.writer.DecimalWriter; import org.apache.arrow.vector.holders.Decimal256Holder; @@ -24,6 +25,7 @@ import java.lang.UnsupportedOperationException; import java.math.BigDecimal; +import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType; <@pp.dropOutputFile /> <#list ["List", "ListView", "LargeList", "LargeListView"] as listName> @@ -336,14 +338,14 @@ public void writeNull() { } @Override - public void writeExtension(Object value) { - writer.writeExtension(value); + public void writeExtension(Object value, ExtensionType extensionType) { + writer.writeExtension(value, extensionType); writer.setPosition(writer.idx() + 1); } @Override - public void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory var1) { - writer.addExtensionTypeWriterFactory(var1, extensionType); + public void addExtensionTypeWriterFactory(ExtensionTypeFactory var1, ExtensionType var2) { + writer.addExtensionTypeWriterFactory(var1, var2); } public void write(ExtensionHolder var1) { diff --git a/vector/src/main/codegen/templates/UnionReader.java b/vector/src/main/codegen/templates/UnionReader.java index 96ad3e1b9..bc7e06ce4 100644 --- a/vector/src/main/codegen/templates/UnionReader.java +++ b/vector/src/main/codegen/templates/UnionReader.java @@ -16,6 +16,8 @@ */ +import java.util.HashMap; +import java.util.Map; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.Field; @@ -43,9 +45,12 @@ public class UnionReader extends AbstractFieldReader { private BaseReader[] readers = new BaseReader[NUM_SUPPORTED_TYPES]; public UnionVector data; - + private ExtensionTypeFactory extensionTypeFactory; + private Map extensionTypeReaders = new HashMap<>(); + public UnionReader(UnionVector data) { this.data = data; + this.extensionTypeFactory = data.getExtensionTypeFactory(); } public MinorType getMinorType() { @@ -79,6 +84,10 @@ public void read(int index, UnionHolder holder) { } private FieldReader getReaderForIndex(int index) { + return getReaderForIndex(index, null); + } + + private FieldReader getReaderForIndex(int index, ExtensionType extensionType) { int typeValue = data.getTypeValue(index); FieldReader reader = (FieldReader) readers[typeValue]; if (reader != null) { @@ -95,6 +104,11 @@ private FieldReader getReaderForIndex(int index) { return (FieldReader) getListView(); case MAP: return (FieldReader) getMap(); + case EXTENSIONTYPE: + if(extensionType == null) { + throw new IllegalStateException("Cannot read extension type without extensionType"); + } + return (FieldReader) getExtension(extensionType); <#list vv.types as type> <#list type.minor as minor> <#assign name = minor.class?cap_first /> @@ -214,6 +228,20 @@ public void copyAsValue(${name}Writer writer){ + public void read(ExtensionHolder holder){ + getReaderForIndex(idx(), extensionTypeFactory.getExtensionTypeByHolder(holder)).read(holder); + } + + private ExtensionReader getExtension(ExtensionType arrowType) { + ExtensionReader extensionReader = extensionTypeReaders.get(arrowType); + if (extensionReader == null) { + extensionReader = extensionTypeFactory.getReaderImpl(data.getExtensionTypeVector(arrowType)); + extensionReader.setPosition(idx()); + extensionTypeReaders.put(arrowType, extensionReader); + } + return extensionReader; + } + @Override public void copyAsValue(ListWriter writer) { ComplexCopier.copy(this, (FieldWriter) writer); diff --git a/vector/src/main/codegen/templates/UnionVector.java b/vector/src/main/codegen/templates/UnionVector.java index 67efdf60f..906bfa5ee 100644 --- a/vector/src/main/codegen/templates/UnionVector.java +++ b/vector/src/main/codegen/templates/UnionVector.java @@ -104,6 +104,8 @@ public class UnionVector extends AbstractContainerVector implements FieldVector private ValueVector singleVector; private int typeBufferAllocationSizeInBytes; + + private ExtensionTypeFactory extensionTypeFactory; private final FieldType fieldType; private final Field[] typeIds = new Field[Byte.MAX_VALUE + 1]; @@ -325,6 +327,26 @@ public StructVector getStruct() { + + private ExtensionTypeVector extensionTypeVector; + + public T getExtensionTypeVector(ExtensionType type) { + return getExtensionTypeVector(null, type); + } + + public T getExtensionTypeVector(String name, ExtensionType type) { + if (extensionTypeVector == null) { + int vectorCount = internalStruct.size(); + extensionTypeVector = addOrGet(name, MinorType.EXTENSIONTYPE, type, extensionTypeFactory.getVectorClass(type)); + if (internalStruct.size() > vectorCount) { + extensionTypeVector.allocateNew(); + if (callBack != null) { + callBack.doWork(); + } + } + } + return (T) extensionTypeVector; + } public ListVector getList() { if (listVector == null) { @@ -725,6 +747,8 @@ public ValueVector getVectorByType(int typeId, ArrowType arrowType) { return getListView(); case MAP: return getMap(name, arrowType); + case EXTENSIONTYPE: + return getExtensionTypeVector(name, (ExtensionType) arrowType); default: throw new UnsupportedOperationException("Cannot support type: " + MinorType.values()[typeId]); } @@ -847,6 +871,11 @@ public void setSafe(int index, Nullable${name}Holder holder) { + public void setSafe(int index, ExtensionHolder holder) { + setType(index, MinorType.EXTENSIONTYPE); + getExtensionTypeVector(null).setSafe(index, holder); + } + public void setType(int index, MinorType type) { while (index >= getTypeBufferValueCapacity()) { reallocTypeBuffer(); @@ -929,4 +958,12 @@ public void setInitialCapacity(int valueCount, double density) { public void setNull(int index) { throw new UnsupportedOperationException("The method setNull() is not supported on UnionVector."); } + + public void setExtensionTypeFactory(ExtensionTypeFactory extensionTypeFactory) { + this.extensionTypeFactory = extensionTypeFactory; + } + + public ExtensionTypeFactory getExtensionTypeFactory() { + return extensionTypeFactory; + } } diff --git a/vector/src/main/codegen/templates/UnionWriter.java b/vector/src/main/codegen/templates/UnionWriter.java index 272edab17..5a8e7e704 100644 --- a/vector/src/main/codegen/templates/UnionWriter.java +++ b/vector/src/main/codegen/templates/UnionWriter.java @@ -16,8 +16,11 @@ */ import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.complex.impl.ExtensionTypeFactory; import org.apache.arrow.vector.complex.impl.NullableStructWriterFactory; import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType; <@pp.dropOutputFile /> <@pp.changeOutputFile name="/org/apache/arrow/vector/complex/impl/UnionWriter.java" /> @@ -47,8 +50,12 @@ public class UnionWriter extends AbstractFieldWriter implements FieldWriter { protected UnionListWriter listWriter; protected UnionListViewWriter listViewWriter; protected UnionMapWriter mapWriter; + protected UnionExtensionWriter extensionWriter; protected List writers = new java.util.ArrayList<>(); + private Map extensionTypeWriters = new HashMap<>(); + protected final NullableStructWriterFactory nullableStructWriterFactory; + private ExtensionTypeFactory extensionTypeFactory; public UnionWriter(UnionVector vector) { this(vector, NullableStructWriterFactory.getNullableStructWriterFactoryInstance()); @@ -57,6 +64,7 @@ public UnionWriter(UnionVector vector) { public UnionWriter(UnionVector vector, NullableStructWriterFactory nullableStructWriterFactory) { data = vector; this.nullableStructWriterFactory = nullableStructWriterFactory; + this.extensionTypeFactory = data.getExtensionTypeFactory(); } /** @@ -213,8 +221,24 @@ public MapWriter asMap(ArrowType arrowType) { return getMapWriter(arrowType); } - private ExtensionWriter getExtensionWriter(ArrowType arrowType) { - throw new UnsupportedOperationException("ExtensionTypes are not supported yet."); + private ExtensionWriter getExtensionWriter(ExtensionType arrowType) { + if (extensionWriter == null) { + extensionWriter = new UnionExtensionWriter(data.getExtensionTypeVector(arrowType)); + extensionWriter.addExtensionTypeWriterFactory(extensionTypeFactory, arrowType); + extensionWriter.setPosition(idx()); + } + return extensionWriter; + } + + public ExtensionWriter asExtension(ExtensionType arrowType) { + data.setType(idx(), MinorType.EXTENSIONTYPE); + return getExtensionWriter(arrowType); + } + + public void writeExtension(Object value, ExtensionType extensionType) { + data.setType(idx(), MinorType.EXTENSIONTYPE); + getExtensionWriter(extensionType).setPosition(idx()); + getExtensionWriter(extensionType).writeExtension(value, extensionType); } BaseWriter getWriter(MinorType minorType) { @@ -232,7 +256,7 @@ BaseWriter getWriter(MinorType minorType, ArrowType arrowType) { case MAP: return getMapWriter(arrowType); case EXTENSIONTYPE: - return getExtensionWriter(arrowType); + return getExtensionWriter((ExtensionType) arrowType); <#list vv.types as type> <#list type.minor as minor> <#assign name = minor.class?cap_first /> @@ -392,7 +416,7 @@ public void write(${name}Holder holder) { - + public void writeNull() { } @@ -480,6 +504,11 @@ public ExtensionWriter extension(String name, ArrowType arrowType) { return getStructWriter().extension(name, arrowType); } + @Override + public void addExtensionTypeWriterFactory(ExtensionTypeFactory var1, ExtensionType var2) { + this.extensionTypeFactory = var1; + } + <#list vv.types as type><#list type.minor as minor> <#assign lowerName = minor.class?uncap_first /> <#if lowerName == "int" ><#assign lowerName = "integer" /> diff --git a/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java b/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java index cc57cde29..54de53ebb 100644 --- a/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/BaseValueVector.java @@ -22,7 +22,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.ReferenceManager; import org.apache.arrow.util.Preconditions; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; +import org.apache.arrow.vector.complex.impl.ExtensionTypeFactory; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.util.DataSizeRoundingUtil; import org.apache.arrow.vector.util.TransferPair; @@ -263,13 +263,13 @@ public void copyFromSafe(int fromIndex, int thisIndex, ValueVector from) { @Override public void copyFrom( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { + int fromIndex, int thisIndex, ValueVector from, ExtensionTypeFactory writerFactory) { throw new UnsupportedOperationException(); } @Override public void copyFromSafe( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { + int fromIndex, int thisIndex, ValueVector from, ExtensionTypeFactory writerFactory) { throw new UnsupportedOperationException(); } diff --git a/vector/src/main/java/org/apache/arrow/vector/ExtensionTypeVector.java b/vector/src/main/java/org/apache/arrow/vector/ExtensionTypeVector.java index 3762fecd0..ac99eeeb6 100644 --- a/vector/src/main/java/org/apache/arrow/vector/ExtensionTypeVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/ExtensionTypeVector.java @@ -24,6 +24,7 @@ import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.compare.VectorVisitor; import org.apache.arrow.vector.complex.reader.FieldReader; +import org.apache.arrow.vector.holders.ExtensionHolder; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.Field; @@ -287,4 +288,6 @@ public BufferAllocator getAllocator() { public OUT accept(VectorVisitor visitor, IN value) { return visitor.visit(this, value); } + + public abstract void setSafe(int index, ExtensionHolder holder); } diff --git a/vector/src/main/java/org/apache/arrow/vector/NullVector.java b/vector/src/main/java/org/apache/arrow/vector/NullVector.java index 0d6dab283..3b4a6a0e5 100644 --- a/vector/src/main/java/org/apache/arrow/vector/NullVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/NullVector.java @@ -27,7 +27,7 @@ import org.apache.arrow.memory.util.hash.ArrowBufHasher; import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.compare.VectorVisitor; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; +import org.apache.arrow.vector.complex.impl.ExtensionTypeFactory; import org.apache.arrow.vector.complex.impl.NullReader; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; @@ -332,13 +332,13 @@ public void copyFromSafe(int fromIndex, int thisIndex, ValueVector from) { @Override public void copyFrom( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { + int fromIndex, int thisIndex, ValueVector from, ExtensionTypeFactory writerFactory) { throw new UnsupportedOperationException(); } @Override public void copyFromSafe( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { + int fromIndex, int thisIndex, ValueVector from, ExtensionTypeFactory writerFactory) { throw new UnsupportedOperationException(); } diff --git a/vector/src/main/java/org/apache/arrow/vector/ValueVector.java b/vector/src/main/java/org/apache/arrow/vector/ValueVector.java index e0628c2ee..e4f8d289d 100644 --- a/vector/src/main/java/org/apache/arrow/vector/ValueVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/ValueVector.java @@ -22,7 +22,7 @@ import org.apache.arrow.memory.OutOfMemoryException; import org.apache.arrow.memory.util.hash.ArrowBufHasher; import org.apache.arrow.vector.compare.VectorVisitor; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; +import org.apache.arrow.vector.complex.impl.ExtensionTypeFactory; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.Field; @@ -319,8 +319,7 @@ public interface ValueVector extends Closeable, Iterable { * @param from source vector * @param writerFactory the extension type writer factory to use for copying extension type values */ - void copyFrom( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory); + void copyFrom(int fromIndex, int thisIndex, ValueVector from, ExtensionTypeFactory writerFactory); /** * Same as {@link #copyFrom(int, int, ValueVector)} except that it handles the case when the @@ -332,7 +331,7 @@ void copyFrom( * @param writerFactory the extension type writer factory to use for copying extension type values */ void copyFromSafe( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory); + int fromIndex, int thisIndex, ValueVector from, ExtensionTypeFactory writerFactory); /** * Accept a generic {@link VectorVisitor} and return the result. diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java b/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java index 429f9884b..034fffc5b 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java @@ -21,7 +21,7 @@ import org.apache.arrow.vector.DensityAwareVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; +import org.apache.arrow.vector.complex.impl.ExtensionTypeFactory; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.ArrowType.FixedSizeList; @@ -154,13 +154,13 @@ public void copyFromSafe(int fromIndex, int thisIndex, ValueVector from) { @Override public void copyFrom( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { + int fromIndex, int thisIndex, ValueVector from, ExtensionTypeFactory writerFactory) { throw new UnsupportedOperationException(); } @Override public void copyFromSafe( - int fromIndex, int thisIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { + int fromIndex, int thisIndex, ValueVector from, ExtensionTypeFactory writerFactory) { throw new UnsupportedOperationException(); } diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java b/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java index fac3f86bb..ac644862c 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java @@ -347,6 +347,7 @@ protected int getMaxViewEndChildVectorByIndex(int index) { * Initialize the data vector (and execute callback) if it hasn't already been done, returns the * data vector. */ + @SuppressWarnings("unchecked") public AddOrGetResult addOrGetVector(FieldType fieldType) { boolean created = false; if (vector instanceof NullVector) { diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java b/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java index 48c8127e2..d31e95d0c 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java @@ -49,7 +49,7 @@ import org.apache.arrow.vector.ZeroVector; import org.apache.arrow.vector.compare.VectorVisitor; import org.apache.arrow.vector.complex.impl.ComplexCopier; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; +import org.apache.arrow.vector.complex.impl.ExtensionTypeFactory; import org.apache.arrow.vector.complex.impl.UnionLargeListReader; import org.apache.arrow.vector.complex.impl.UnionLargeListWriter; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -497,7 +497,7 @@ public void copyFrom(int inIndex, int outIndex, ValueVector from) { */ @Override public void copyFrom( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { + int inIndex, int outIndex, ValueVector from, ExtensionTypeFactory writerFactory) { Preconditions.checkArgument(this.getMinorType() == from.getMinorType()); FieldReader in = from.getReader(); in.setPosition(inIndex); @@ -517,7 +517,7 @@ public void copyFrom( */ @Override public void copyFromSafe( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { + int inIndex, int outIndex, ValueVector from, ExtensionTypeFactory writerFactory) { copyFrom(inIndex, outIndex, from, writerFactory); } diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java b/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java index 992a66444..7fa3c2252 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java @@ -41,7 +41,7 @@ import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.ZeroVector; import org.apache.arrow.vector.compare.VectorVisitor; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; +import org.apache.arrow.vector.complex.impl.ExtensionTypeFactory; import org.apache.arrow.vector.complex.impl.UnionLargeListViewReader; import org.apache.arrow.vector.complex.impl.UnionLargeListViewWriter; import org.apache.arrow.vector.complex.impl.UnionListReader; @@ -349,14 +349,14 @@ public void copyFrom(int inIndex, int outIndex, ValueVector from) { @Override public void copyFromSafe( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { + int inIndex, int outIndex, ValueVector from, ExtensionTypeFactory writerFactory) { throw new UnsupportedOperationException( "LargeListViewVector does not support copyFromSafe operation yet."); } @Override public void copyFrom( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { + int inIndex, int outIndex, ValueVector from, ExtensionTypeFactory writerFactory) { throw new UnsupportedOperationException( "LargeListViewVector does not support copyFrom operation yet."); } diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java b/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java index 89549257c..5b41d26a4 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java @@ -42,7 +42,7 @@ import org.apache.arrow.vector.ZeroVector; import org.apache.arrow.vector.compare.VectorVisitor; import org.apache.arrow.vector.complex.impl.ComplexCopier; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; +import org.apache.arrow.vector.complex.impl.ExtensionTypeFactory; import org.apache.arrow.vector.complex.impl.UnionListReader; import org.apache.arrow.vector.complex.impl.UnionListWriter; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -415,7 +415,7 @@ public void copyFrom(int inIndex, int outIndex, ValueVector from) { */ @Override public void copyFromSafe( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { + int inIndex, int outIndex, ValueVector from, ExtensionTypeFactory writerFactory) { copyFrom(inIndex, outIndex, from, writerFactory); } @@ -430,7 +430,7 @@ public void copyFromSafe( */ @Override public void copyFrom( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { + int inIndex, int outIndex, ValueVector from, ExtensionTypeFactory writerFactory) { Preconditions.checkArgument(this.getMinorType() == from.getMinorType()); FieldReader in = from.getReader(); in.setPosition(inIndex); diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java b/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java index 278424042..4d8a1b53e 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java @@ -42,7 +42,7 @@ import org.apache.arrow.vector.ZeroVector; import org.apache.arrow.vector.compare.VectorVisitor; import org.apache.arrow.vector.complex.impl.ComplexCopier; -import org.apache.arrow.vector.complex.impl.ExtensionTypeWriterFactory; +import org.apache.arrow.vector.complex.impl.ExtensionTypeFactory; import org.apache.arrow.vector.complex.impl.UnionListViewReader; import org.apache.arrow.vector.complex.impl.UnionListViewWriter; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -341,7 +341,7 @@ public void copyFromSafe(int inIndex, int outIndex, ValueVector from) { @Override public void copyFromSafe( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { + int inIndex, int outIndex, ValueVector from, ExtensionTypeFactory writerFactory) { copyFrom(inIndex, outIndex, from, writerFactory); } @@ -357,7 +357,7 @@ public void copyFrom(int inIndex, int outIndex, ValueVector from) { @Override public void copyFrom( - int inIndex, int outIndex, ValueVector from, ExtensionTypeWriterFactory writerFactory) { + int inIndex, int outIndex, ValueVector from, ExtensionTypeFactory writerFactory) { Preconditions.checkArgument(this.getMinorType() == from.getMinorType()); FieldReader in = from.getReader(); in.setPosition(inIndex); diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/impl/AbstractBaseReader.java b/vector/src/main/java/org/apache/arrow/vector/complex/impl/AbstractBaseReader.java index bf074ecb9..af7f6489c 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/impl/AbstractBaseReader.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/impl/AbstractBaseReader.java @@ -117,12 +117,12 @@ public void copyAsValue(MapWriter writer) { } @Override - public void copyAsValue(ListWriter writer, ExtensionTypeWriterFactory writerFactory) { + public void copyAsValue(ListWriter writer, ExtensionTypeFactory writerFactory) { ComplexCopier.copy(this, (FieldWriter) writer, writerFactory); } @Override - public void copyAsValue(MapWriter writer, ExtensionTypeWriterFactory writerFactory) { + public void copyAsValue(MapWriter writer, ExtensionTypeFactory writerFactory) { ComplexCopier.copy(this, (FieldWriter) writer, writerFactory); } } diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/impl/ExtensionTypeWriterFactory.java b/vector/src/main/java/org/apache/arrow/vector/complex/impl/ExtensionTypeFactory.java similarity index 72% rename from vector/src/main/java/org/apache/arrow/vector/complex/impl/ExtensionTypeWriterFactory.java rename to vector/src/main/java/org/apache/arrow/vector/complex/impl/ExtensionTypeFactory.java index 09f0314c5..5e6051371 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/impl/ExtensionTypeWriterFactory.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/impl/ExtensionTypeFactory.java @@ -17,15 +17,12 @@ package org.apache.arrow.vector.complex.impl; import org.apache.arrow.vector.ExtensionTypeVector; +import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.complex.writer.FieldWriter; +import org.apache.arrow.vector.holders.ExtensionHolder; +import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType; -/** - * A factory interface for creating instances of {@link ExtensionTypeWriter}. This factory allows - * configuring writer implementations for specific {@link ExtensionTypeVector}. - * - * @param the type of writer implementation for a specific {@link ExtensionTypeVector}. - */ -public interface ExtensionTypeWriterFactory { +public interface ExtensionTypeFactory { /** * Returns an instance of the writer implementation for the given {@link ExtensionTypeVector}. @@ -34,5 +31,11 @@ public interface ExtensionTypeWriterFactory { * returned. * @return an instance of the writer implementation for the given {@link ExtensionTypeVector}. */ - T getWriterImpl(ExtensionTypeVector vector); + FieldWriter getWriterImpl(ExtensionTypeVector vector); + + Class getVectorClass(ExtensionType extensionType); + + FieldReader getReaderImpl(ExtensionTypeVector vector); + + ExtensionType getExtensionTypeByHolder(ExtensionHolder holder); } diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionExtensionWriter.java b/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionExtensionWriter.java index 4219069cb..4f5373371 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionExtensionWriter.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionExtensionWriter.java @@ -19,6 +19,7 @@ import org.apache.arrow.vector.ExtensionTypeVector; import org.apache.arrow.vector.complex.writer.FieldWriter; import org.apache.arrow.vector.holders.ExtensionHolder; +import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType; import org.apache.arrow.vector.types.pojo.Field; public class UnionExtensionWriter extends AbstractFieldWriter { @@ -55,12 +56,13 @@ public void close() throws Exception { } @Override - public void writeExtension(Object var1) { - this.writer.writeExtension(var1); + public void writeExtension(Object var1, ExtensionType var2) { + this.writer.writeExtension(var1, var2); } @Override - public void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory factory) { + public void addExtensionTypeWriterFactory( + ExtensionTypeFactory factory, ExtensionType extensionType) { this.writer = factory.getWriterImpl(vector); this.writer.setPosition(idx()); } diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionLargeListReader.java b/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionLargeListReader.java index a9104cb0d..1d94bfa84 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionLargeListReader.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/impl/UnionLargeListReader.java @@ -106,7 +106,7 @@ public void copyAsValue(UnionLargeListWriter writer) { ComplexCopier.copy(this, (FieldWriter) writer); } - public void copyAsValue(UnionLargeListWriter writer, ExtensionTypeWriterFactory writerFactory) { + public void copyAsValue(UnionLargeListWriter writer, ExtensionTypeFactory writerFactory) { ComplexCopier.copy(this, (FieldWriter) writer, writerFactory); } } diff --git a/vector/src/main/java/org/apache/arrow/vector/extension/OpaqueVector.java b/vector/src/main/java/org/apache/arrow/vector/extension/OpaqueVector.java index 00eb9a984..7c663c0d7 100644 --- a/vector/src/main/java/org/apache/arrow/vector/extension/OpaqueVector.java +++ b/vector/src/main/java/org/apache/arrow/vector/extension/OpaqueVector.java @@ -21,6 +21,7 @@ import org.apache.arrow.vector.ExtensionTypeVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueIterableVector; +import org.apache.arrow.vector.holders.ExtensionHolder; import org.apache.arrow.vector.types.pojo.Field; /** @@ -46,6 +47,11 @@ public Object getObject(int index) { return getUnderlyingVector().getObject(index); } + @Override + public void setSafe(int index, ExtensionHolder holder) { + throw new UnsupportedOperationException(); + } + @Override public int hashCode(int index) { return hashCode(index, null); diff --git a/vector/src/test/java/org/apache/arrow/vector/TestListVector.java b/vector/src/test/java/org/apache/arrow/vector/TestListVector.java index c6c7c5c86..281ad180b 100644 --- a/vector/src/test/java/org/apache/arrow/vector/TestListVector.java +++ b/vector/src/test/java/org/apache/arrow/vector/TestListVector.java @@ -36,7 +36,7 @@ import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.impl.UnionListReader; import org.apache.arrow.vector.complex.impl.UnionListWriter; -import org.apache.arrow.vector.complex.impl.UuidWriterFactory; +import org.apache.arrow.vector.complex.impl.UuidFactory; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.complex.writer.BaseWriter.ExtensionWriter; import org.apache.arrow.vector.holder.UuidHolder; @@ -1208,7 +1208,8 @@ public void testGetTransferPairWithField() { @Test public void testListVectorWithExtensionType() throws Exception { - final FieldType type = FieldType.nullable(new UuidType()); + UuidType uuidType = new UuidType(); + final FieldType type = FieldType.nullable(uuidType); try (final ListVector inVector = new ListVector("list", allocator, type, null)) { UnionListWriter writer = inVector.getWriter(); writer.allocate(); @@ -1216,10 +1217,10 @@ public void testListVectorWithExtensionType() throws Exception { UUID u1 = UUID.randomUUID(); UUID u2 = UUID.randomUUID(); writer.startList(); - ExtensionWriter extensionWriter = writer.extension(new UuidType()); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter.writeExtension(u1); - extensionWriter.writeExtension(u2); + ExtensionWriter extensionWriter = writer.extension(uuidType); + extensionWriter.addExtensionTypeWriterFactory(new UuidFactory(), uuidType); + extensionWriter.writeExtension(u1, uuidType); + extensionWriter.writeExtension(u2, uuidType); writer.endList(); writer.setValueCount(1); @@ -1236,7 +1237,8 @@ public void testListVectorWithExtensionType() throws Exception { @Test public void testListVectorReaderForExtensionType() throws Exception { - final FieldType type = FieldType.nullable(new UuidType()); + UuidType uuidType = new UuidType(); + final FieldType type = FieldType.nullable(uuidType); try (final ListVector inVector = new ListVector("list", allocator, type, null)) { UnionListWriter writer = inVector.getWriter(); writer.allocate(); @@ -1244,10 +1246,10 @@ public void testListVectorReaderForExtensionType() throws Exception { UUID u1 = UUID.randomUUID(); UUID u2 = UUID.randomUUID(); writer.startList(); - ExtensionWriter extensionWriter = writer.extension(new UuidType()); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter.writeExtension(u1); - extensionWriter.writeExtension(u2); + ExtensionWriter extensionWriter = writer.extension(uuidType); + extensionWriter.addExtensionTypeWriterFactory(new UuidFactory(), uuidType); + extensionWriter.writeExtension(u1, uuidType); + extensionWriter.writeExtension(u2, uuidType); writer.endList(); writer.setValueCount(1); @@ -1273,6 +1275,7 @@ public void testListVectorReaderForExtensionType() throws Exception { @Test public void testCopyFromForExtensionType() throws Exception { + UuidType uuidType = new UuidType(); try (ListVector inVector = ListVector.empty("input", allocator); ListVector outVector = ListVector.empty("output", allocator)) { UnionListWriter writer = inVector.getWriter(); @@ -1281,10 +1284,10 @@ public void testCopyFromForExtensionType() throws Exception { UUID u1 = UUID.randomUUID(); UUID u2 = UUID.randomUUID(); writer.startList(); - ExtensionWriter extensionWriter = writer.extension(new UuidType()); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter.writeExtension(u1); - extensionWriter.writeExtension(u2); + ExtensionWriter extensionWriter = writer.extension(uuidType); + extensionWriter.addExtensionTypeWriterFactory(new UuidFactory(), uuidType); + extensionWriter.writeExtension(u1, uuidType); + extensionWriter.writeExtension(u2, uuidType); extensionWriter.writeNull(); writer.endList(); @@ -1292,7 +1295,7 @@ public void testCopyFromForExtensionType() throws Exception { // copy values from input to output outVector.allocateNew(); - outVector.copyFrom(0, 0, inVector, new UuidWriterFactory()); + outVector.copyFrom(0, 0, inVector, new UuidFactory()); outVector.setValueCount(1); UnionListReader reader = outVector.getReader(); diff --git a/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java b/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java index 1a1810d0f..8d3fa118a 100644 --- a/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java +++ b/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java @@ -34,7 +34,7 @@ import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.impl.UnionMapReader; import org.apache.arrow.vector.complex.impl.UnionMapWriter; -import org.apache.arrow.vector.complex.impl.UuidWriterFactory; +import org.apache.arrow.vector.complex.impl.UuidFactory; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.complex.writer.BaseWriter.ExtensionWriter; import org.apache.arrow.vector.complex.writer.BaseWriter.ListWriter; @@ -1272,6 +1272,7 @@ public void testMapTypeReturnsSupportedMapWriter() { @Test public void testMapVectorWithExtensionType() throws Exception { + UuidType uuidType = new UuidType(); try (final MapVector inVector = MapVector.empty("map", allocator, false)) { inVector.allocateNew(); UnionMapWriter writer = inVector.getWriter(); @@ -1281,15 +1282,15 @@ public void testMapVectorWithExtensionType() throws Exception { writer.startMap(); writer.startEntry(); writer.key().bigInt().writeBigInt(0); - ExtensionWriter extensionWriter = writer.value().extension(new UuidType()); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter.writeExtension(u1); + ExtensionWriter extensionWriter = writer.value().extension(uuidType); + extensionWriter.addExtensionTypeWriterFactory(new UuidFactory(), uuidType); + extensionWriter.writeExtension(u1, uuidType); writer.endEntry(); writer.startEntry(); writer.key().bigInt().writeBigInt(1); extensionWriter = writer.value().extension(new UuidType()); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter.writeExtension(u2); + extensionWriter.addExtensionTypeWriterFactory(new UuidFactory(), new UuidType()); + extensionWriter.writeExtension(u2, uuidType); writer.endEntry(); writer.endMap(); @@ -1315,6 +1316,8 @@ public void testMapVectorWithExtensionType() throws Exception { @Test public void testCopyFromForExtensionType() throws Exception { + UuidType uuidType = new UuidType(); + try (final MapVector inVector = MapVector.empty("in", allocator, false); final MapVector outVector = MapVector.empty("out", allocator, false)) { inVector.allocateNew(); @@ -1325,21 +1328,21 @@ public void testCopyFromForExtensionType() throws Exception { writer.startMap(); writer.startEntry(); writer.key().bigInt().writeBigInt(0); - ExtensionWriter extensionWriter = writer.value().extension(new UuidType()); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter.writeExtension(u1); + ExtensionWriter extensionWriter = writer.value().extension(uuidType); + extensionWriter.addExtensionTypeWriterFactory(new UuidFactory(), uuidType); + extensionWriter.writeExtension(u1, uuidType); writer.endEntry(); writer.startEntry(); writer.key().bigInt().writeBigInt(1); extensionWriter = writer.value().extension(new UuidType()); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter.writeExtension(u2); + extensionWriter.addExtensionTypeWriterFactory(new UuidFactory(), new UuidType()); + extensionWriter.writeExtension(u2, uuidType); writer.endEntry(); writer.endMap(); writer.setValueCount(1); outVector.allocateNew(); - outVector.copyFrom(0, 0, inVector, new UuidWriterFactory()); + outVector.copyFrom(0, 0, inVector, new UuidFactory()); outVector.setValueCount(1); UnionMapReader mapReader = outVector.getReader(); diff --git a/vector/src/test/java/org/apache/arrow/vector/UuidVector.java b/vector/src/test/java/org/apache/arrow/vector/UuidVector.java index 72ba4aa55..3f03fc027 100644 --- a/vector/src/test/java/org/apache/arrow/vector/UuidVector.java +++ b/vector/src/test/java/org/apache/arrow/vector/UuidVector.java @@ -23,6 +23,7 @@ import org.apache.arrow.vector.complex.impl.UuidReaderImpl; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.holder.UuidHolder; +import org.apache.arrow.vector.holders.ExtensionHolder; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.UuidType; @@ -49,6 +50,12 @@ public UUID getObject(int index) { return new UUID(bb.getLong(), bb.getLong()); } + @Override + public void setSafe(int index, ExtensionHolder holder) { + UuidHolder uuidHolder = (UuidHolder) holder; + setSafe(index, uuidHolder.value); + } + @Override public int hashCode(int index) { return hashCode(index, null); diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestComplexCopier.java b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestComplexCopier.java index 738e8905e..1ac89a31c 100644 --- a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestComplexCopier.java +++ b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestComplexCopier.java @@ -851,6 +851,7 @@ public void testCopyMapVectorWithMapValue() { @Test public void testCopyListVectorWithExtensionType() { + UuidType uuidType = new UuidType(); try (ListVector from = ListVector.empty("v", allocator); ListVector to = ListVector.empty("v", allocator)) { @@ -860,10 +861,10 @@ public void testCopyListVectorWithExtensionType() { for (int i = 0; i < COUNT; i++) { listWriter.setPosition(i); listWriter.startList(); - ExtensionWriter extensionWriter = listWriter.extension(new UuidType()); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter.writeExtension(UUID.randomUUID()); - extensionWriter.writeExtension(UUID.randomUUID()); + ExtensionWriter extensionWriter = listWriter.extension(uuidType); + extensionWriter.addExtensionTypeWriterFactory(new UuidFactory(), uuidType); + extensionWriter.writeExtension(UUID.randomUUID(), uuidType); + extensionWriter.writeExtension(UUID.randomUUID(), uuidType); listWriter.endList(); } from.setValueCount(COUNT); @@ -874,7 +875,7 @@ public void testCopyListVectorWithExtensionType() { for (int i = 0; i < COUNT; i++) { in.setPosition(i); out.setPosition(i); - ComplexCopier.copy(in, out, new UuidWriterFactory()); + ComplexCopier.copy(in, out, new UuidFactory()); } to.setValueCount(COUNT); @@ -886,6 +887,7 @@ public void testCopyListVectorWithExtensionType() { @Test public void testCopyMapVectorWithExtensionType() { + UuidType uuidType = new UuidType(); try (final MapVector from = MapVector.empty("v", allocator, false); final MapVector to = MapVector.empty("v", allocator, false)) { @@ -896,12 +898,12 @@ public void testCopyMapVectorWithExtensionType() { mapWriter.setPosition(i); mapWriter.startMap(); mapWriter.startEntry(); - ExtensionWriter extensionKeyWriter = mapWriter.key().extension(new UuidType()); - extensionKeyWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionKeyWriter.writeExtension(UUID.randomUUID()); - ExtensionWriter extensionValueWriter = mapWriter.value().extension(new UuidType()); - extensionValueWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionValueWriter.writeExtension(UUID.randomUUID()); + ExtensionWriter extensionKeyWriter = mapWriter.key().extension(uuidType); + extensionKeyWriter.addExtensionTypeWriterFactory(new UuidFactory(), uuidType); + extensionKeyWriter.writeExtension(UUID.randomUUID(), uuidType); + ExtensionWriter extensionValueWriter = mapWriter.value().extension(uuidType); + extensionValueWriter.addExtensionTypeWriterFactory(new UuidFactory(), uuidType); + extensionValueWriter.writeExtension(UUID.randomUUID(), uuidType); mapWriter.endEntry(); mapWriter.endMap(); } @@ -914,7 +916,7 @@ public void testCopyMapVectorWithExtensionType() { for (int i = 0; i < COUNT; i++) { in.setPosition(i); out.setPosition(i); - ComplexCopier.copy(in, out, new UuidWriterFactory()); + ComplexCopier.copy(in, out, new UuidFactory()); } to.setValueCount(COUNT); @@ -925,6 +927,7 @@ public void testCopyMapVectorWithExtensionType() { @Test public void testCopyStructVectorWithExtensionType() { + UuidType uuidType = new UuidType(); try (final StructVector from = StructVector.empty("v", allocator); final StructVector to = StructVector.empty("v", allocator)) { @@ -934,12 +937,12 @@ public void testCopyStructVectorWithExtensionType() { for (int i = 0; i < COUNT; i++) { structWriter.setPosition(i); structWriter.start(); - ExtensionWriter extensionWriter1 = structWriter.extension("timestamp1", new UuidType()); - extensionWriter1.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter1.writeExtension(UUID.randomUUID()); - ExtensionWriter extensionWriter2 = structWriter.extension("timestamp2", new UuidType()); - extensionWriter2.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter2.writeExtension(UUID.randomUUID()); + ExtensionWriter extensionWriter1 = structWriter.extension("timestamp1", uuidType); + extensionWriter1.addExtensionTypeWriterFactory(new UuidFactory(), uuidType); + extensionWriter1.writeExtension(UUID.randomUUID(), uuidType); + ExtensionWriter extensionWriter2 = structWriter.extension("timestamp2", uuidType); + extensionWriter2.addExtensionTypeWriterFactory(new UuidFactory(), uuidType); + extensionWriter2.writeExtension(UUID.randomUUID(), uuidType); structWriter.end(); } @@ -951,7 +954,7 @@ public void testCopyStructVectorWithExtensionType() { for (int i = 0; i < COUNT; i++) { in.setPosition(i); out.setPosition(i); - ComplexCopier.copy(in, out, new UuidWriterFactory()); + ComplexCopier.copy(in, out, new UuidFactory()); } to.setValueCount(COUNT); diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java index 7b8b1f9ef..31f82a389 100644 --- a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java +++ b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java @@ -782,21 +782,22 @@ public void testPromoteToUnionFromDecimal() throws Exception { @Test public void testExtensionType() throws Exception { + UuidType uuidType = new UuidType(); try (final NonNullableStructVector container = NonNullableStructVector.empty(EMPTY_SCHEMA_PATH, allocator); final UuidVector v = - container.addOrGet("uuid", FieldType.nullable(new UuidType()), UuidVector.class); + container.addOrGet("uuid", FieldType.nullable(uuidType), UuidVector.class); final PromotableWriter writer = new PromotableWriter(v, container)) { UUID u1 = UUID.randomUUID(); UUID u2 = UUID.randomUUID(); container.allocateNew(); container.setValueCount(1); - writer.addExtensionTypeWriterFactory(new UuidWriterFactory()); + writer.addExtensionTypeWriterFactory(new UuidFactory(), uuidType); writer.setPosition(0); - writer.writeExtension(u1); + writer.writeExtension(u1, uuidType); writer.setPosition(1); - writer.writeExtension(u2); + writer.writeExtension(u2, uuidType); container.setValueCount(2); @@ -808,20 +809,21 @@ public void testExtensionType() throws Exception { @Test public void testExtensionTypeForList() throws Exception { + UuidType uuidType = new UuidType(); try (final ListVector container = ListVector.empty(EMPTY_SCHEMA_PATH, allocator); final UuidVector v = - (UuidVector) container.addOrGetVector(FieldType.nullable(new UuidType())).getVector(); + (UuidVector) container.addOrGetVector(FieldType.nullable(uuidType)).getVector(); final PromotableWriter writer = new PromotableWriter(v, container)) { UUID u1 = UUID.randomUUID(); UUID u2 = UUID.randomUUID(); container.allocateNew(); container.setValueCount(1); - writer.addExtensionTypeWriterFactory(new UuidWriterFactory()); + writer.addExtensionTypeWriterFactory(new UuidFactory(), uuidType); writer.setPosition(0); - writer.writeExtension(u1); + writer.writeExtension(u1, uuidType); writer.setPosition(1); - writer.writeExtension(u2); + writer.writeExtension(u2, uuidType); container.setValueCount(2); diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/impl/UuidWriterFactory.java b/vector/src/test/java/org/apache/arrow/vector/complex/impl/UuidFactory.java similarity index 54% rename from vector/src/test/java/org/apache/arrow/vector/complex/impl/UuidWriterFactory.java rename to vector/src/test/java/org/apache/arrow/vector/complex/impl/UuidFactory.java index 1b1bf4e6e..8f4df7d80 100644 --- a/vector/src/test/java/org/apache/arrow/vector/complex/impl/UuidWriterFactory.java +++ b/vector/src/test/java/org/apache/arrow/vector/complex/impl/UuidFactory.java @@ -18,8 +18,12 @@ import org.apache.arrow.vector.ExtensionTypeVector; import org.apache.arrow.vector.UuidVector; +import org.apache.arrow.vector.holder.UuidHolder; +import org.apache.arrow.vector.holders.ExtensionHolder; +import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType; +import org.apache.arrow.vector.types.pojo.UuidType; -public class UuidWriterFactory implements ExtensionTypeWriterFactory { +public class UuidFactory implements ExtensionTypeFactory { @Override public AbstractFieldWriter getWriterImpl(ExtensionTypeVector extensionTypeVector) { @@ -28,4 +32,28 @@ public AbstractFieldWriter getWriterImpl(ExtensionTypeVector extensionTypeVector } return null; } + + @Override + public Class getVectorClass(ExtensionType extensionType) { + if (extensionType instanceof UuidType) { + return UuidVector.class; + } + throw new UnsupportedOperationException("Unsupported extension type " + extensionType); + } + + @Override + public ExtensionType getExtensionTypeByHolder(ExtensionHolder holder) { + if (holder instanceof UuidHolder) { + return new UuidType(); + } + return null; + } + + @Override + public AbstractFieldReader getReaderImpl(ExtensionTypeVector extensionTypeVector) { + if (extensionTypeVector instanceof UuidVector) { + return new UuidReaderImpl((UuidVector) extensionTypeVector); + } + return null; + } } diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/impl/UuidWriterImpl.java b/vector/src/test/java/org/apache/arrow/vector/complex/impl/UuidWriterImpl.java index 68029b1df..35396286f 100644 --- a/vector/src/test/java/org/apache/arrow/vector/complex/impl/UuidWriterImpl.java +++ b/vector/src/test/java/org/apache/arrow/vector/complex/impl/UuidWriterImpl.java @@ -21,6 +21,7 @@ import org.apache.arrow.vector.UuidVector; import org.apache.arrow.vector.holder.UuidHolder; import org.apache.arrow.vector.holders.ExtensionHolder; +import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType; public class UuidWriterImpl extends AbstractExtensionTypeWriter { @@ -29,7 +30,7 @@ public UuidWriterImpl(UuidVector vector) { } @Override - public void writeExtension(Object value) { + public void writeExtension(Object value, ExtensionType extensionType) { UUID uuid = (UUID) value; ByteBuffer bb = ByteBuffer.allocate(16); bb.putLong(uuid.getMostSignificantBits()); diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java b/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java index f374eb41e..f9109487a 100644 --- a/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java +++ b/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java @@ -66,7 +66,7 @@ import org.apache.arrow.vector.complex.impl.UnionMapReader; import org.apache.arrow.vector.complex.impl.UnionReader; import org.apache.arrow.vector.complex.impl.UnionWriter; -import org.apache.arrow.vector.complex.impl.UuidWriterFactory; +import org.apache.arrow.vector.complex.impl.UuidFactory; import org.apache.arrow.vector.complex.reader.BaseReader.StructReader; import org.apache.arrow.vector.complex.reader.BigIntReader; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -1101,25 +1101,27 @@ public void testListViewMapType() { @Test public void simpleUnion() throws Exception { List bufs = new ArrayList(); + UUID uuid = UUID.randomUUID(); UnionVector vector = new UnionVector("union", allocator, /* field type */ null, /* call-back */ null); + vector.setExtensionTypeFactory(new UuidFactory()); UnionWriter unionWriter = new UnionWriter(vector); unionWriter.allocate(); for (int i = 0; i < COUNT; i++) { unionWriter.setPosition(i); - if (i % 5 == 0) { + if (i % 6 == 0) { unionWriter.writeInt(i); - } else if (i % 5 == 1) { + } else if (i % 6 == 1) { TimeStampMilliTZHolder holder = new TimeStampMilliTZHolder(); holder.value = (long) i; holder.timezone = "AsdfTimeZone"; unionWriter.write(holder); - } else if (i % 5 == 2) { + } else if (i % 6 == 2) { DurationHolder holder = new DurationHolder(); holder.value = (long) i; holder.unit = TimeUnit.NANOSECOND; unionWriter.write(holder); - } else if (i % 5 == 3) { + } else if (i % 6 == 3) { FixedSizeBinaryHolder holder = new FixedSizeBinaryHolder(); ArrowBuf buf = allocator.buffer(4); buf.setInt(0, i); @@ -1127,6 +1129,13 @@ public void simpleUnion() throws Exception { holder.buffer = buf; unionWriter.write(holder); bufs.add(buf); + } else if (i % 6 == 4) { + UuidHolder holder = new UuidHolder(); + ByteBuffer bb = ByteBuffer.allocate(16); + bb.putLong(uuid.getMostSignificantBits()); + bb.putLong(uuid.getLeastSignificantBits()); + holder.value = bb.array(); + unionWriter.asExtension(new UuidType()).writeExtension(uuid, new UuidType()); } else { unionWriter.writeFloat4((float) i); } @@ -1135,23 +1144,29 @@ public void simpleUnion() throws Exception { UnionReader unionReader = new UnionReader(vector); for (int i = 0; i < COUNT; i++) { unionReader.setPosition(i); - if (i % 5 == 0) { + if (i % 6 == 0) { assertEquals(i, unionReader.readInteger().intValue()); - } else if (i % 5 == 1) { + } else if (i % 6 == 1) { NullableTimeStampMilliTZHolder holder = new NullableTimeStampMilliTZHolder(); unionReader.read(holder); assertEquals(i, holder.value); assertEquals("AsdfTimeZone", holder.timezone); - } else if (i % 5 == 2) { + } else if (i % 6 == 2) { NullableDurationHolder holder = new NullableDurationHolder(); unionReader.read(holder); assertEquals(i, holder.value); assertEquals(TimeUnit.NANOSECOND, holder.unit); - } else if (i % 5 == 3) { + } else if (i % 6 == 3) { NullableFixedSizeBinaryHolder holder = new NullableFixedSizeBinaryHolder(); unionReader.read(holder); assertEquals(i, holder.buffer.getInt(0)); assertEquals(4, holder.byteWidth); + } else if (i % 6 == 4) { + UuidHolder holder = new UuidHolder(); + unionReader.read(holder); + ByteBuffer bb = ByteBuffer.wrap(holder.value); + UUID actualUuid = new UUID(bb.getLong(), bb.getLong()); + assertEquals(uuid, actualUuid); } else { assertEquals((float) i, unionReader.readFloat(), 1e-12); } @@ -2511,8 +2526,8 @@ public void extensionWriterReader() throws Exception { { ExtensionWriter extensionWriter = rootWriter.extension("uuid1", new UuidType()); extensionWriter.setPosition(0); - extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); - extensionWriter.writeExtension(u1); + extensionWriter.addExtensionTypeWriterFactory(new UuidFactory(), new UuidType()); + extensionWriter.writeExtension(u1, new UuidType()); } // read StructReader rootReader = new SingleStructReaderImpl(parent).reader("root"); diff --git a/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java b/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java index d24708d66..65496cb36 100644 --- a/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java +++ b/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java @@ -47,6 +47,7 @@ import org.apache.arrow.vector.compare.Range; import org.apache.arrow.vector.compare.RangeEqualsVisitor; import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.holders.ExtensionHolder; import org.apache.arrow.vector.ipc.ArrowFileReader; import org.apache.arrow.vector.ipc.ArrowFileWriter; import org.apache.arrow.vector.types.FloatingPointPrecision; @@ -369,6 +370,11 @@ public int hashCode(int index, ArrowBufHasher hasher) { return getUnderlyingVector().getObject(index); } + @Override + public void setSafe(int index, ExtensionHolder holder) { + throw new UnsupportedOperationException(); + } + public void set(int index, float latitude, float longitude) { getUnderlyingVector().getChild("Latitude", Float4Vector.class).set(index, latitude); getUnderlyingVector().getChild("Longitude", Float4Vector.class).set(index, longitude); From 8f4f795aa0e1d1856d783da268c6033991675b73 Mon Sep 17 00:00:00 2001 From: Ivan Chesnov Date: Tue, 7 Oct 2025 14:36:24 +0300 Subject: [PATCH 2/2] GH-810: updated descriptions --- .../complex/impl/ExtensionTypeFactory.java | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/impl/ExtensionTypeFactory.java b/vector/src/main/java/org/apache/arrow/vector/complex/impl/ExtensionTypeFactory.java index 5e6051371..d79d5c986 100644 --- a/vector/src/main/java/org/apache/arrow/vector/complex/impl/ExtensionTypeFactory.java +++ b/vector/src/main/java/org/apache/arrow/vector/complex/impl/ExtensionTypeFactory.java @@ -22,6 +22,11 @@ import org.apache.arrow.vector.holders.ExtensionHolder; import org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType; +/** + * A factory interface that allows configuring writer implementations for specific {@link + * ExtensionTypeVector}, get the vector class for a given {@link ExtensionType}, and get the reader + * implementation for a given {@link ExtensionTypeVector}. + */ public interface ExtensionTypeFactory { /** @@ -33,9 +38,29 @@ public interface ExtensionTypeFactory { */ FieldWriter getWriterImpl(ExtensionTypeVector vector); + /** + * Returns the vector class for the given {@link ExtensionType}. + * + * @param extensionType the {@link ExtensionType} for which the vector class is to be returned. + * @return the vector class for the given {@link ExtensionType}. + */ Class getVectorClass(ExtensionType extensionType); + /** + * Returns an instance of the reader implementation for the given {@link ExtensionTypeVector}. + * + * @param vector the {@link ExtensionTypeVector} for which the reader implementation is to be + * returned. + * @return an instance of the reader implementation for the given {@link ExtensionTypeVector}. + */ FieldReader getReaderImpl(ExtensionTypeVector vector); + /** + * Returns the {@link ExtensionType} for the given {@link ExtensionHolder}. + * + * @param holder the {@link ExtensionHolder} for which the {@link ExtensionType} is to be + * returned. + * @return the {@link ExtensionType} for the given {@link ExtensionHolder}. + */ ExtensionType getExtensionTypeByHolder(ExtensionHolder holder); }