@@ -23,11 +23,12 @@ use arrow::array::{
2323} ;
2424use arrow:: compute:: { can_cast_types, cast} ;
2525use arrow:: datatypes:: DataType :: { Int64 , Utf8 } ;
26- use arrow:: datatypes:: { DataType , Int64Type } ;
26+ use arrow:: datatypes:: { DataType , Field , FieldRef , Int64Type } ;
2727use datafusion_common:: cast:: as_string_array;
28- use datafusion_common:: { plan_datafusion_err, DataFusionError , Result } ;
28+ use datafusion_common:: { internal_err , plan_datafusion_err, DataFusionError , Result } ;
2929use datafusion_expr:: {
30- ColumnarValue , ScalarFunctionArgs , ScalarUDFImpl , Signature , Volatility ,
30+ ColumnarValue , ReturnFieldArgs , ScalarFunctionArgs , ScalarUDFImpl , Signature ,
31+ Volatility ,
3132} ;
3233use datafusion_functions:: utils:: make_scalar_function;
3334
@@ -64,7 +65,12 @@ impl ScalarUDFImpl for SparkElt {
6465 }
6566
6667 fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
67- Ok ( Utf8 )
68+ internal_err ! ( "return_field_from_args should be used instead" )
69+ }
70+
71+ fn return_field_from_args ( & self , args : ReturnFieldArgs ) -> Result < FieldRef > {
72+ let nullable = args. arg_fields . iter ( ) . any ( |f| f. is_nullable ( ) ) ;
73+ Ok ( Arc :: new ( Field :: new ( self . name ( ) , Utf8 , nullable) ) )
6874 }
6975
7076 fn invoke_with_args ( & self , args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
@@ -248,4 +254,57 @@ mod tests {
248254 assert_eq ! ( out. data_type( ) , & Utf8 ) ;
249255 Ok ( ( ) )
250256 }
257+
258+ #[ test]
259+ fn test_elt_nullability ( ) -> Result < ( ) > {
260+ use datafusion_expr:: ReturnFieldArgs ;
261+
262+ let elt_func = SparkElt :: new ( ) ;
263+
264+ // Test with all non-nullable args - result should be non-nullable
265+ let non_nullable_idx: FieldRef = Arc :: new ( Field :: new ( "idx" , Int64 , false ) ) ;
266+ let non_nullable_v1: FieldRef = Arc :: new ( Field :: new ( "v1" , Utf8 , false ) ) ;
267+ let non_nullable_v2: FieldRef = Arc :: new ( Field :: new ( "v2" , Utf8 , false ) ) ;
268+
269+ let result = elt_func. return_field_from_args ( ReturnFieldArgs {
270+ arg_fields : & [
271+ Arc :: clone ( & non_nullable_idx) ,
272+ Arc :: clone ( & non_nullable_v1) ,
273+ Arc :: clone ( & non_nullable_v2) ,
274+ ] ,
275+ scalar_arguments : & [ None , None , None ] ,
276+ } ) ?;
277+ assert ! (
278+ !result. is_nullable( ) ,
279+ "elt should NOT be nullable when all args are non-nullable"
280+ ) ;
281+
282+ // Test with nullable index - result should be nullable
283+ let nullable_idx: FieldRef = Arc :: new ( Field :: new ( "idx" , Int64 , true ) ) ;
284+ let result = elt_func. return_field_from_args ( ReturnFieldArgs {
285+ arg_fields : & [
286+ nullable_idx,
287+ Arc :: clone ( & non_nullable_v1) ,
288+ Arc :: clone ( & non_nullable_v2) ,
289+ ] ,
290+ scalar_arguments : & [ None , None , None ] ,
291+ } ) ?;
292+ assert ! (
293+ result. is_nullable( ) ,
294+ "elt should be nullable when index is nullable"
295+ ) ;
296+
297+ // Test with nullable value - result should be nullable
298+ let nullable_v1: FieldRef = Arc :: new ( Field :: new ( "v1" , Utf8 , true ) ) ;
299+ let result = elt_func. return_field_from_args ( ReturnFieldArgs {
300+ arg_fields : & [ non_nullable_idx, nullable_v1, non_nullable_v2] ,
301+ scalar_arguments : & [ None , None , None ] ,
302+ } ) ?;
303+ assert ! (
304+ result. is_nullable( ) ,
305+ "elt should be nullable when any value is nullable"
306+ ) ;
307+
308+ Ok ( ( ) )
309+ }
251310}
0 commit comments