diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java
index 5a1594e59..5423b360a 100644
--- a/core/src/main/java/io/substrait/type/Type.java
+++ b/core/src/main/java/io/substrait/type/Type.java
@@ -11,6 +11,14 @@
@Value.Enclosing
public interface Type extends TypeExpression, ParameterizedType, NullableType, FunctionArg {
+ /**
+ * Returns a copy of the {@link Type} with the specified nullability.
+ *
+ *
This method is implemented by all concrete {@link Type} classes via Immutables code
+ * generation.
+ */
+ Type withNullable(boolean nullable);
+
static TypeCreator withNullability(boolean nullable) {
return nullable ? TypeCreator.NULLABLE : TypeCreator.REQUIRED;
}
diff --git a/core/src/main/java/io/substrait/type/TypeCreator.java b/core/src/main/java/io/substrait/type/TypeCreator.java
index 43358e505..999769cd9 100644
--- a/core/src/main/java/io/substrait/type/TypeCreator.java
+++ b/core/src/main/java/io/substrait/type/TypeCreator.java
@@ -25,11 +25,6 @@ public class TypeCreator {
public final Type INTERVAL_YEAR;
public final Type UUID;
- private static NullableSettingTypeVisitor NULLABLE_TRUE_VISITOR =
- new NullableSettingTypeVisitor(true);
- private static NullableSettingTypeVisitor NULLABLE_FALSE_VISITOR =
- new NullableSettingTypeVisitor(false);
-
protected TypeCreator(boolean nullable) {
this.nullable = nullable;
BOOLEAN = Type.Bool.builder().nullable(nullable).build();
@@ -116,161 +111,13 @@ public static TypeCreator of(boolean nullability) {
return nullability ? NULLABLE : REQUIRED;
}
+ /** Make the given type NULLABLE */
public static Type asNullable(Type type) {
- return type.nullable() ? type : type.accept(NULLABLE_TRUE_VISITOR);
+ return type.withNullable(true);
}
+ /** Make the given type NOT NULLABLE */
public static Type asNotNullable(Type type) {
- return type.nullable() ? type.accept(NULLABLE_FALSE_VISITOR) : type;
- }
-
- private static final class NullableSettingTypeVisitor
- implements TypeVisitor {
-
- private final boolean nullability;
-
- NullableSettingTypeVisitor(boolean nullability) {
- this.nullability = nullability;
- }
-
- @Override
- public Type visit(Type.Bool type) throws RuntimeException {
- return Type.Bool.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.I8 type) throws RuntimeException {
- return Type.I8.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.I16 type) throws RuntimeException {
- return Type.I16.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.I32 type) throws RuntimeException {
- return Type.I32.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.I64 type) throws RuntimeException {
- return Type.I64.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.FP32 type) throws RuntimeException {
- return Type.FP32.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.FP64 type) throws RuntimeException {
- return Type.FP64.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.Str type) throws RuntimeException {
- return Type.Str.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.Binary type) throws RuntimeException {
- return Type.Binary.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.Date type) throws RuntimeException {
- return Type.Date.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.Time type) throws RuntimeException {
- return Type.Time.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.TimestampTZ type) throws RuntimeException {
- return Type.TimestampTZ.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.Timestamp type) throws RuntimeException {
- return Type.Timestamp.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.IntervalYear type) throws RuntimeException {
- return Type.IntervalYear.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.IntervalDay type) throws RuntimeException {
- return Type.IntervalDay.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.IntervalCompound type) throws RuntimeException {
- return Type.IntervalCompound.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.UUID type) throws RuntimeException {
- return Type.UUID.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.FixedChar type) throws RuntimeException {
- return Type.FixedChar.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.VarChar type) throws RuntimeException {
- return Type.VarChar.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.FixedBinary type) throws RuntimeException {
- return Type.FixedBinary.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.Decimal type) throws RuntimeException {
- return Type.Decimal.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.PrecisionTime type) throws RuntimeException {
- return Type.PrecisionTime.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.PrecisionTimestamp type) throws RuntimeException {
- return Type.PrecisionTimestamp.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.PrecisionTimestampTZ type) throws RuntimeException {
- return Type.PrecisionTimestampTZ.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.Struct type) throws RuntimeException {
- return Type.Struct.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.ListType type) throws RuntimeException {
- return Type.ListType.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.Map type) throws RuntimeException {
- return Type.Map.builder().from(type).nullable(nullability).build();
- }
-
- @Override
- public Type visit(Type.UserDefined type) throws RuntimeException {
- return Type.UserDefined.builder().from(type).nullable(nullability).build();
- }
+ return type.withNullable(false);
}
}