Skip to content
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ to create `TypedColumn`s and with those a new Dataset from pieces of another usi
```kotlin
val dataset: Dataset<YourClass> = ...
val newDataset: Dataset<Pair<TypeA, TypeB>> = dataset.selectTyped(col(YourClass::colA), col(YourClass::colB))

// Alternatively, for instance when working with a Dataset<Row>
val typedDataset: Dataset<Pair<String, Int>> = otherDataset.selectTyped(col("a").`as`<String>(), col("b").`as`<Int>())
```

### Overload resolution ambiguity
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ object KSparkExtensions {

def collectAsList[T](ds: Dataset[T]): util.List[T] = JavaConverters.seqAsJavaList(ds.collect())

  /**
   * Kotlin-interop helper: wraps the result of `ds.tail(n)` in a `java.util.List`
   * via `Arrays.asList`, mirroring [[collectAsList]] above. Spark's `Dataset.tail(n)`
   * returns the last `n` rows — exposed here as a Java list so Kotlin callers avoid
   * Scala collection types.
   */
  def tailAsList[T](ds: Dataset[T], n: Int): util.List[T] = util.Arrays.asList(ds.tail(n) : _*)

def debugCodegen(df: Dataset[_]): Unit = {
import org.apache.spark.sql.execution.debug._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -647,12 +647,19 @@ operator fun Column.get(key: Any): Column = getItem(key)
fun lit(a: Any) = functions.lit(a)

/**
 * Provides a type hint about the expected return value of this column. This information can
* be used by operations such as `select` on a [Dataset] to automatically convert the
* results into the correct JVM types.
*
* ```
* val df: Dataset<Row> = ...
* val typedColumn: Dataset<Int> = df.selectTyped( col("a").`as`<Int>() )
* ```
*/
// The reified type parameter lets encoder<T>() build the Spark Encoder for T at the
// call site, so `col("a").`as`<Int>()` yields a TypedColumn usable by selectTyped.
// UNCHECKED_CAST is suppressed for the key type Any — presumably safe because the
// encoder fixes the value type T; NOTE(review): confirm against Spark's Column.as(Encoder).
@Suppress("UNCHECKED_CAST")
inline fun <reified T> Column.`as`(): TypedColumn<Any, T> = `as`(encoder<T>())


/**
* Alias for [Dataset.joinWith] which passes "left" argument
* and respects the fact that in result of left join right relation is nullable
Expand Down Expand Up @@ -809,45 +816,74 @@ fun <T> Dataset<T>.showDS(numRows: Int = 20, truncate: Boolean = true) = apply {
/**
* Returns a new Dataset by computing the given [Column] expressions for each element.
*/
@Suppress("UNCHECKED_CAST")
inline fun <reified T, reified U1> Dataset<T>.selectTyped(
c1: TypedColumn<out Any, U1>,
): Dataset<U1> = select(c1 as TypedColumn<T, U1>)

/**
* Returns a new Dataset by computing the given [Column] expressions for each element.
*/
@Suppress("UNCHECKED_CAST")
inline fun <reified T, reified U1, reified U2> Dataset<T>.selectTyped(
c1: TypedColumn<T, U1>,
c2: TypedColumn<T, U2>,
c1: TypedColumn<out Any, U1>,
c2: TypedColumn<out Any, U2>,
): Dataset<Pair<U1, U2>> =
select(c1, c2).map { Pair(it._1(), it._2()) }
select(
c1 as TypedColumn<T, U1>,
c2 as TypedColumn<T, U2>,
).map { Pair(it._1(), it._2()) }

/**
* Returns a new Dataset by computing the given [Column] expressions for each element.
*/
@Suppress("UNCHECKED_CAST")
inline fun <reified T, reified U1, reified U2, reified U3> Dataset<T>.selectTyped(
c1: TypedColumn<T, U1>,
c2: TypedColumn<T, U2>,
c3: TypedColumn<T, U3>,
c1: TypedColumn<out Any, U1>,
c2: TypedColumn<out Any, U2>,
c3: TypedColumn<out Any, U3>,
): Dataset<Triple<U1, U2, U3>> =
select(c1, c2, c3).map { Triple(it._1(), it._2(), it._3()) }
select(
c1 as TypedColumn<T, U1>,
c2 as TypedColumn<T, U2>,
c3 as TypedColumn<T, U3>,
).map { Triple(it._1(), it._2(), it._3()) }

/**
* Returns a new Dataset by computing the given [Column] expressions for each element.
*/
@Suppress("UNCHECKED_CAST")
inline fun <reified T, reified U1, reified U2, reified U3, reified U4> Dataset<T>.selectTyped(
c1: TypedColumn<T, U1>,
c2: TypedColumn<T, U2>,
c3: TypedColumn<T, U3>,
c4: TypedColumn<T, U4>,
c1: TypedColumn<out Any, U1>,
c2: TypedColumn<out Any, U2>,
c3: TypedColumn<out Any, U3>,
c4: TypedColumn<out Any, U4>,
): Dataset<Arity4<U1, U2, U3, U4>> =
select(c1, c2, c3, c4).map { Arity4(it._1(), it._2(), it._3(), it._4()) }
select(
c1 as TypedColumn<T, U1>,
c2 as TypedColumn<T, U2>,
c3 as TypedColumn<T, U3>,
c4 as TypedColumn<T, U4>,
).map { Arity4(it._1(), it._2(), it._3(), it._4()) }

/**
* Returns a new Dataset by computing the given [Column] expressions for each element.
*/
@Suppress("UNCHECKED_CAST")
inline fun <reified T, reified U1, reified U2, reified U3, reified U4, reified U5> Dataset<T>.selectTyped(
c1: TypedColumn<T, U1>,
c2: TypedColumn<T, U2>,
c3: TypedColumn<T, U3>,
c4: TypedColumn<T, U4>,
c5: TypedColumn<T, U5>,
c1: TypedColumn<out Any, U1>,
c2: TypedColumn<out Any, U2>,
c3: TypedColumn<out Any, U3>,
c4: TypedColumn<out Any, U4>,
c5: TypedColumn<out Any, U5>,
): Dataset<Arity5<U1, U2, U3, U4, U5>> =
select(c1, c2, c3, c4, c5).map { Arity5(it._1(), it._2(), it._3(), it._4(), it._5()) }

select(
c1 as TypedColumn<T, U1>,
c2 as TypedColumn<T, U2>,
c3 as TypedColumn<T, U3>,
c4 as TypedColumn<T, U4>,
c5 as TypedColumn<T, U5>,
).map { Arity5(it._1(), it._2(), it._3(), it._4(), it._5()) }

@OptIn(ExperimentalStdlibApi::class)
inline fun <reified T> schema(map: Map<String, KType> = mapOf()) = schema(typeOf<T>(), map)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,31 +339,34 @@ class ApiTest : ShouldSpec({
SomeClass(intArrayOf(1, 2, 4), 5),
)

val typedColumnA: TypedColumn<Any, IntArray> = dataset.col("a").`as`(encoder())
val newDS1WithAs: Dataset<Int> = dataset.selectTyped(
col("b").`as`<Int>(),
)
newDS1WithAs.show()

val newDS2 = dataset.selectTyped(
val newDS2: Dataset<Pair<Int, Int>> = dataset.selectTyped(
            // col(SomeClass::a), NOTE that this doesn't work on 2.4, returning a data class with an array in it
col(SomeClass::b),
col(SomeClass::b),
)
newDS2.show()

val newDS3 = dataset.selectTyped(
val newDS3: Dataset<Triple<Int, Int, Int>> = dataset.selectTyped(
col(SomeClass::b),
col(SomeClass::b),
col(SomeClass::b),
)
newDS3.show()

val newDS4 = dataset.selectTyped(
val newDS4: Dataset<Arity4<Int, Int, Int, Int>> = dataset.selectTyped(
col(SomeClass::b),
col(SomeClass::b),
col(SomeClass::b),
col(SomeClass::b),
)
newDS4.show()

val newDS5 = dataset.selectTyped(
val newDS5: Dataset<Arity5<Int, Int, Int, Int, Int>> = dataset.selectTyped(
col(SomeClass::b),
col(SomeClass::b),
col(SomeClass::b),
Expand Down
Loading