diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5c1ca8b70..88062e7be 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -239,6 +239,7 @@ set(DUCKDB_SRC_FILES
  src/duckdb/ub_src_transaction.cpp
  src/duckdb/src/verification/copied_statement_verifier.cpp
  src/duckdb/src/verification/deserialized_statement_verifier.cpp
+ src/duckdb/src/verification/explain_statement_verifier.cpp
  src/duckdb/src/verification/external_statement_verifier.cpp
  src/duckdb/src/verification/fetch_row_verifier.cpp
  src/duckdb/src/verification/no_operator_caching_verifier.cpp
@@ -345,46 +346,46 @@ set(DUCKDB_SRC_FILES
  src/duckdb/third_party/zstd/dict/fastcover.cpp
  src/duckdb/third_party/zstd/dict/zdict.cpp
  src/duckdb/extension/core_functions/core_functions_extension.cpp
- src/duckdb/extension/core_functions/function_list.cpp
  src/duckdb/extension/core_functions/lambda_functions.cpp
- src/duckdb/ub_extension_core_functions_scalar_map.cpp
+ src/duckdb/extension/core_functions/function_list.cpp
+ src/duckdb/ub_extension_core_functions_scalar_date.cpp
  src/duckdb/ub_extension_core_functions_scalar_operators.cpp
- src/duckdb/ub_extension_core_functions_scalar_struct.cpp
+ src/duckdb/ub_extension_core_functions_scalar_list.cpp
+ src/duckdb/ub_extension_core_functions_scalar_array.cpp
  src/duckdb/ub_extension_core_functions_scalar_random.cpp
- src/duckdb/ub_extension_core_functions_scalar_string.cpp
- src/duckdb/ub_extension_core_functions_scalar_blob.cpp
- src/duckdb/ub_extension_core_functions_scalar_union.cpp
  src/duckdb/ub_extension_core_functions_scalar_bit.cpp
- src/duckdb/ub_extension_core_functions_scalar_array.cpp
+ src/duckdb/ub_extension_core_functions_scalar_enum.cpp
  src/duckdb/ub_extension_core_functions_scalar_math.cpp
+ src/duckdb/ub_extension_core_functions_scalar_string.cpp
+ src/duckdb/ub_extension_core_functions_scalar_union.cpp
+ src/duckdb/ub_extension_core_functions_scalar_map.cpp
  src/duckdb/ub_extension_core_functions_scalar_debug.cpp
- src/duckdb/ub_extension_core_functions_scalar_enum.cpp
+ src/duckdb/ub_extension_core_functions_scalar_blob.cpp
+ src/duckdb/ub_extension_core_functions_scalar_struct.cpp
  src/duckdb/ub_extension_core_functions_scalar_generic.cpp
- src/duckdb/ub_extension_core_functions_scalar_date.cpp
- src/duckdb/ub_extension_core_functions_scalar_list.cpp
  src/duckdb/ub_extension_core_functions_aggregate_nested.cpp
  src/duckdb/ub_extension_core_functions_aggregate_regression.cpp
  src/duckdb/ub_extension_core_functions_aggregate_algebraic.cpp
- src/duckdb/ub_extension_core_functions_aggregate_distributive.cpp
  src/duckdb/ub_extension_core_functions_aggregate_holistic.cpp
- src/duckdb/extension/parquet/parquet_writer.cpp
+ src/duckdb/ub_extension_core_functions_aggregate_distributive.cpp
+ src/duckdb/extension/parquet/parquet_metadata.cpp
+ src/duckdb/extension/parquet/serialize_parquet.cpp
  src/duckdb/extension/parquet/geo_parquet.cpp
- src/duckdb/extension/parquet/column_reader.cpp
  src/duckdb/extension/parquet/parquet_float16.cpp
- src/duckdb/extension/parquet/parquet_metadata.cpp
- src/duckdb/extension/parquet/parquet_reader.cpp
- src/duckdb/extension/parquet/zstd_file_system.cpp
- src/duckdb/extension/parquet/parquet_timestamp.cpp
+ src/duckdb/extension/parquet/parquet_statistics.cpp
+ src/duckdb/extension/parquet/parquet_file_metadata_cache.cpp
+ src/duckdb/extension/parquet/parquet_crypto.cpp
+ src/duckdb/extension/parquet/column_reader.cpp
+ src/duckdb/extension/parquet/parquet_writer.cpp
  src/duckdb/extension/parquet/parquet_multi_file_info.cpp
+ src/duckdb/extension/parquet/parquet_reader.cpp
  src/duckdb/extension/parquet/column_writer.cpp
+ src/duckdb/extension/parquet/parquet_timestamp.cpp
+ src/duckdb/extension/parquet/zstd_file_system.cpp
  src/duckdb/extension/parquet/parquet_extension.cpp
- src/duckdb/extension/parquet/parquet_crypto.cpp
- src/duckdb/extension/parquet/parquet_file_metadata_cache.cpp
- src/duckdb/extension/parquet/serialize_parquet.cpp
- src/duckdb/extension/parquet/parquet_statistics.cpp
- src/duckdb/ub_extension_parquet_decoder.cpp
  src/duckdb/ub_extension_parquet_reader.cpp
  src/duckdb/ub_extension_parquet_reader_variant.cpp
+ src/duckdb/ub_extension_parquet_decoder.cpp
  src/duckdb/ub_extension_parquet_writer.cpp
  src/duckdb/third_party/parquet/parquet_types.cpp
  src/duckdb/third_party/thrift/thrift/protocol/TProtocol.cpp
@@ -424,31 +425,31 @@ set(DUCKDB_SRC_FILES
  src/duckdb/third_party/brotli/enc/metablock.cpp
  src/duckdb/third_party/brotli/enc/static_dict.cpp
  src/duckdb/third_party/brotli/enc/utf8_util.cpp
- src/duckdb/extension/icu/./icu-list-range.cpp
+ src/duckdb/extension/icu/./icu-table-range.cpp
  src/duckdb/extension/icu/./icu_extension.cpp
- src/duckdb/extension/icu/./icu-strptime.cpp
- src/duckdb/extension/icu/./icu-current.cpp
- src/duckdb/extension/icu/./icu-makedate.cpp
+ src/duckdb/extension/icu/./icu-datetrunc.cpp
  src/duckdb/extension/icu/./icu-datesub.cpp
- src/duckdb/extension/icu/./icu-timezone.cpp
- src/duckdb/extension/icu/./icu-timebucket.cpp
- src/duckdb/extension/icu/./icu-table-range.cpp
  src/duckdb/extension/icu/./icu-datefunc.cpp
- src/duckdb/extension/icu/./icu-datepart.cpp
- src/duckdb/extension/icu/./icu-datetrunc.cpp
  src/duckdb/extension/icu/./icu-dateadd.cpp
+ src/duckdb/extension/icu/./icu-datepart.cpp
+ src/duckdb/extension/icu/./icu-timezone.cpp
+ src/duckdb/extension/icu/./icu-current.cpp
+ src/duckdb/extension/icu/./icu-strptime.cpp
+ src/duckdb/extension/icu/./icu-makedate.cpp
+ src/duckdb/extension/icu/./icu-timebucket.cpp
+ src/duckdb/extension/icu/./icu-list-range.cpp
  src/duckdb/ub_extension_icu_third_party_icu_common.cpp
  src/duckdb/ub_extension_icu_third_party_icu_i18n.cpp
  src/duckdb/extension/icu/third_party/icu/stubdata/stubdata.cpp
- src/duckdb/extension/json/json_common.cpp
- src/duckdb/extension/json/json_extension.cpp
  src/duckdb/extension/json/json_multi_file_info.cpp
- src/duckdb/extension/json/json_scan.cpp
+ src/duckdb/extension/json/json_reader.cpp
+ src/duckdb/extension/json/serialize_json.cpp
+ src/duckdb/extension/json/json_extension.cpp
  src/duckdb/extension/json/json_enums.cpp
+ src/duckdb/extension/json/json_scan.cpp
  src/duckdb/extension/json/json_functions.cpp
- src/duckdb/extension/json/json_reader.cpp
+ src/duckdb/extension/json/json_common.cpp
  src/duckdb/extension/json/json_deserializer.cpp
- src/duckdb/extension/json/serialize_json.cpp
  src/duckdb/extension/json/json_serializer.cpp
  src/duckdb/ub_extension_json_json_functions.cpp)
diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp
index 1a774759d..d2bdfbe54 100644
--- a/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp
+++ b/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp
@@ -455,7 +455,7 @@ unique_ptr<FunctionData> BindDecimalArgMinMax(ClientContext &context, AggregateFunction &function,
        break;
    }

-   auto cast_cost = CastFunctionSet::Get(context).ImplicitCastCost(by_type, by_types[i]);
+   auto cast_cost = CastFunctionSet::ImplicitCastCost(context, by_type, by_types[i]);
    if (cast_cost < 0) {
        continue;
    }
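The hunk above (and the identical one in can_implicitly_cast.cpp further down) is a mechanical migration: cast-cost lookup moves from the instance method CastFunctionSet::Get(context).ImplicitCastCost(source, target) to a static CastFunctionSet::ImplicitCastCost(context, source, target) that takes the ClientContext explicitly. A minimal sketch of the new call shape, assuming only what the hunks themselves show:

    // negative cost means no implicit cast exists; >= 0 is the cost of the cast
    auto cost = duckdb::CastFunctionSet::ImplicitCastCost(context, source_type, target_type);
    if (cost < 0) {
        // reject this candidate, as BindDecimalArgMinMax does above
    }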
diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp
index d642e7a9f..bfea19644 100644
--- a/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp
+++ b/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp
@@ -1,6 +1,7 @@
 #include "core_functions/aggregate/distributive_functions.hpp"
 #include "core_functions/aggregate/sum_helpers.hpp"
 #include "duckdb/common/exception.hpp"
+#include "duckdb/common/bignum.hpp"
 #include "duckdb/common/types/decimal.hpp"
 #include "duckdb/planner/expression/bound_aggregate_expression.hpp"
 #include "duckdb/common/serializer/deserializer.hpp"
@@ -211,6 +212,63 @@ unique_ptr<FunctionData> BindDecimalSum(ClientContext &context, AggregateFunction &function,
    return nullptr;
 }

+struct BignumState {
+   bool is_set;
+   BignumIntermediate value;
+};
+
+struct BignumOperation {
+   template <class STATE>
+   static void Initialize(STATE &state) {
+       state.is_set = false;
+   }
+
+   template <class INPUT_TYPE, class STATE, class OP>
+   static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
+                                 idx_t count) {
+       for (idx_t i = 0; i < count; i++) {
+           Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
+       }
+   }
+
+   template <class INPUT_TYPE, class STATE, class OP>
+   static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
+       if (!state.is_set) {
+           state.is_set = true;
+           state.value.Initialize(unary_input.input.allocator);
+       }
+       BignumIntermediate rhs(input);
+       state.value.AddInPlace(unary_input.input.allocator, rhs);
+   }
+
+   template <class STATE, class OP>
+   static void Combine(const STATE &source, STATE &target, AggregateInputData &input) {
+       if (!source.is_set) {
+           return;
+       }
+       if (!target.is_set) {
+           target.value = source.value;
+           target.is_set = true;
+           return;
+       }
+       target.value.AddInPlace(input.allocator, source.value);
+       target.is_set = true;
+   }
+
+   template <class TARGET_TYPE, class STATE>
+   static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) {
+       if (!state.is_set) {
+           finalize_data.ReturnNull();
+       } else {
+           target = state.value.ToBignum(finalize_data.input.allocator);
+       }
+   }
+
+   static bool IgnoreNull() {
+       return true;
+   }
+};
+
 } // namespace

 AggregateFunctionSet SumFun::GetFunctions() {
@@ -226,6 +284,8 @@ AggregateFunctionSet SumFun::GetFunctions() {
    sum.AddFunction(GetSumAggregate(PhysicalType::INT128));
    sum.AddFunction(AggregateFunction::UnaryAggregate<SumState<double>, double, double, NumericSumOperation>(
        LogicalType::DOUBLE, LogicalType::DOUBLE));
+   sum.AddFunction(AggregateFunction::UnaryAggregate<BignumState, bignum_t, bignum_t, BignumOperation>(
+       LogicalType::BIGNUM, LogicalType::BIGNUM));
    return sum;
 }
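BignumOperation above follows DuckDB's unary aggregate operation interface (Initialize/Operation/Combine/Finalize) to give SUM an arbitrary-precision BIGNUM overload. A hedged usage sketch through the C++ client API — the ::BIGNUM cast syntax and the table expression are illustrative assumptions, not taken from the patch:

    #include "duckdb.hpp"

    int main() {
        duckdb::DuckDB db(nullptr);
        duckdb::Connection con(db);
        // summing as BIGNUM cannot overflow the way a fixed-width integer sum can
        auto result = con.Query("SELECT sum(x::BIGNUM) FROM range(1000000) t(x)");
        result->Print();
        return 0;
    }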
diff --git a/src/duckdb/extension/core_functions/aggregate/nested/list.cpp b/src/duckdb/extension/core_functions/aggregate/nested/list.cpp
index 0b2218792..5771e14eb 100644
--- a/src/duckdb/extension/core_functions/aggregate/nested/list.cpp
+++ b/src/duckdb/extension/core_functions/aggregate/nested/list.cpp
@@ -182,14 +182,6 @@ void ListCombineFunction(Vector &states_vector, Vector &combined, AggregateInputData &input,

 unique_ptr<FunctionData> ListBindFunction(ClientContext &context, AggregateFunction &function,
                                           vector<unique_ptr<Expression>> &arguments) {
-   D_ASSERT(arguments.size() == 1);
-   D_ASSERT(function.arguments.size() == 1);
-
-   if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) {
-       function.arguments[0] = LogicalTypeId::UNKNOWN;
-       function.return_type = LogicalType::SQLNULL;
-       return nullptr;
-   }

    function.return_type = LogicalType::LIST(arguments[0]->return_type);
    return make_uniq<ListBindData>(function.return_type);
@@ -198,10 +190,10 @@ unique_ptr<FunctionData> ListBindFunction(ClientContext &context, AggregateFunction &function,
 } // namespace

 AggregateFunction ListFun::GetFunction() {
-   auto func =
-       AggregateFunction({LogicalType::ANY}, LogicalTypeId::LIST, AggregateFunction::StateSize<ListAggState>,
-                         AggregateFunction::StateInitialize<ListAggState, ListFunction>, ListUpdateFunction,
-                         ListCombineFunction, ListFinalize, nullptr, ListBindFunction, nullptr, nullptr, nullptr);
+   auto func = AggregateFunction(
+       {LogicalType::TEMPLATE("T")}, LogicalType::LIST(LogicalType::TEMPLATE("T")),
+       AggregateFunction::StateSize<ListAggState>, AggregateFunction::StateInitialize<ListAggState, ListFunction>,
+       ListUpdateFunction, ListCombineFunction, ListFinalize, nullptr, ListBindFunction, nullptr, nullptr, nullptr);
    return func;
 }
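This is the recurring shape of the patch: where a bind callback existed only to compute argument and return types, the signature is now declared with LogicalType::TEMPLATE type variables that the binder resolves per call site, and reusing the same variable name ties its occurrences together. A hedged sketch of declaring such a function (MyListIdentity and its body are illustrative, not part of the patch):

    // identity over LIST(T): the same "T" on both sides forces the argument and
    // return element types to resolve to one concrete type at bind time
    static void MyListIdentity(duckdb::DataChunk &args, duckdb::ExpressionState &, duckdb::Vector &result) {
        result.Reference(args.data[0]);
    }

    static duckdb::ScalarFunction GetMyListIdentity() {
        auto element = duckdb::LogicalType::TEMPLATE("T");
        return duckdb::ScalarFunction({duckdb::LogicalType::LIST(element)},
                                      duckdb::LogicalType::LIST(element), MyListIdentity);
    }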
diff --git a/src/duckdb/extension/core_functions/function_list.cpp b/src/duckdb/extension/core_functions/function_list.cpp
index 6a53a0317..f34b59188 100644
--- a/src/duckdb/extension/core_functions/function_list.cpp
+++ b/src/duckdb/extension/core_functions/function_list.cpp
@@ -243,14 +243,14 @@ static const StaticFunctionDefinition core_functions[] = {
    DUCKDB_SCALAR_FUNCTION_SET(ListInnerProductFun),
    DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListNegativeDotProductFun),
    DUCKDB_SCALAR_FUNCTION_SET(ListNegativeInnerProductFun),
-   DUCKDB_SCALAR_FUNCTION_ALIAS(ListPackFun),
+   DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListPackFun),
    DUCKDB_SCALAR_FUNCTION_SET(ListReduceFun),
    DUCKDB_SCALAR_FUNCTION_SET(ListReverseSortFun),
    DUCKDB_SCALAR_FUNCTION_SET(ListSliceFun),
    DUCKDB_SCALAR_FUNCTION_SET(ListSortFun),
    DUCKDB_SCALAR_FUNCTION(ListTransformFun),
    DUCKDB_SCALAR_FUNCTION(ListUniqueFun),
-   DUCKDB_SCALAR_FUNCTION(ListValueFun),
+   DUCKDB_SCALAR_FUNCTION_SET(ListValueFun),
    DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ListaggFun),
    DUCKDB_SCALAR_FUNCTION(LnFun),
    DUCKDB_SCALAR_FUNCTION_SET(LogFun),
@@ -264,7 +264,7 @@ static const StaticFunctionDefinition core_functions[] = {
    DUCKDB_SCALAR_FUNCTION_SET(MakeTimestampFun),
    DUCKDB_SCALAR_FUNCTION_SET(MakeTimestampMsFun),
    DUCKDB_SCALAR_FUNCTION_SET(MakeTimestampNsFun),
-   DUCKDB_SCALAR_FUNCTION(MapFun),
+   DUCKDB_SCALAR_FUNCTION_SET(MapFun),
    DUCKDB_SCALAR_FUNCTION(MapConcatFun),
    DUCKDB_SCALAR_FUNCTION(MapEntriesFun),
    DUCKDB_SCALAR_FUNCTION(MapExtractFun),
diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_sort_tree.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_sort_tree.hpp
index 120a14abb..1762ed90f 100644
--- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_sort_tree.hpp
+++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_sort_tree.hpp
@@ -320,8 +320,8 @@ struct QuantileSortTree {
    vector<column_t> sort_idx(1, 0);
    const auto count = partition.count;

-   index_tree = make_uniq<WindowIndexTree>(partition.context, order_bys, sort_idx, count);
-   auto index_state = index_tree->GetLocalState();
+   index_tree = make_uniq<WindowIndexTree>(partition.context.client, order_bys, sort_idx, count);
+   auto index_state = index_tree->GetLocalState(partition.context);
    auto &local_state = index_state->Cast<WindowIndexTreeLocalState>();

    // Build the indirection array by scanning the valid indices
@@ -338,12 +338,12 @@ struct QuantileSortTree {
                filter_sel[filtered++] = i;
            }
        }
-       local_state.SinkChunk(sort, row_idx, filter_sel, filtered);
+       local_state.Sink(partition.context, sort, row_idx, filter_sel, filtered);
    } else {
-       local_state.SinkChunk(sort, row_idx, nullptr, 0);
+       local_state.Sink(partition.context, sort, row_idx, nullptr, 0);
    }
 }
-local_state.Sort();
+local_state.Finalize(partition.context);
 }

 inline idx_t SelectNth(const SubFrames &frames, size_t n) const {
diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/sum_helpers.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/sum_helpers.hpp
index 189f63301..fadf94e47 100644
--- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/sum_helpers.hpp
+++ b/src/duckdb/extension/core_functions/include/core_functions/aggregate/sum_helpers.hpp
@@ -30,6 +30,7 @@ struct SumState {

    void Initialize() {
        this->isset = false;
+       this->value = 0;
    }

    void Combine(const SumState<T> &other) {
@@ -182,7 +183,6 @@ struct BaseSumOperation {
        STATEOP::template AddValues(state, count);
        ADDOP::template AddConstant(state, input, count);
    }
-
    static bool IgnoreNull() {
        return true;
    }
diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/list_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/list_functions.hpp
index 2d52ab0e7..96ffd39c2 100644
--- a/src/duckdb/extension/core_functions/include/core_functions/scalar/list_functions.hpp
+++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/list_functions.hpp
@@ -98,7 +98,7 @@ struct ListValueFun {
    static constexpr const char *Example = "list_value(4, 5, 6)";
    static constexpr const char *Categories = "list";

-   static ScalarFunction GetFunction();
+   static ScalarFunctionSet GetFunctions();
 };

 struct ListPackFun {
diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/map_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/map_functions.hpp
index 462804635..c96055e0a 100644
--- a/src/duckdb/extension/core_functions/include/core_functions/scalar/map_functions.hpp
+++ b/src/duckdb/extension/core_functions/include/core_functions/scalar/map_functions.hpp
@@ -32,7 +32,7 @@ struct MapFun {
    static constexpr const char *Example = "map(['key1', 'key2'], ['val1', 'val2'])";
    static constexpr const char *Categories = "";

-   static ScalarFunction GetFunction();
+   static ScalarFunctionSet GetFunctions();
 };

 struct MapEntriesFun {
diff --git a/src/duckdb/extension/core_functions/scalar/date/date_part.cpp b/src/duckdb/extension/core_functions/scalar/date/date_part.cpp
index 1484155c8..7ced59dcb 100644
--- a/src/duckdb/extension/core_functions/scalar/date/date_part.cpp
+++ b/src/duckdb/extension/core_functions/scalar/date/date_part.cpp
@@ -1438,26 +1438,6 @@ double DatePart::JulianDayOperator::Operation(date_t input) {
    return double(Date::ExtractJulianDay(input));
 }

-template <>
-double DatePart::JulianDayOperator::Operation(interval_t input) {
-   throw NotImplementedException("interval units \"julian\" not recognized");
-}
-
-template <>
-double DatePart::JulianDayOperator::Operation(dtime_t input) {
-   throw NotImplementedException("\"time\" units \"julian\" not recognized");
-}
-
-template <>
-double DatePart::JulianDayOperator::Operation(dtime_ns_t input) {
-   return JulianDayOperator::Operation(input.time());
-}
-
-template <>
-double DatePart::JulianDayOperator::Operation(dtime_tz_t input) {
-   return JulianDayOperator::Operation(input.time());
-}
-
 template <>
 void DatePart::StructOperator::Operation(bigint_vec &bigint_values, double_vec &double_values, const dtime_t &input,
                                          const idx_t idx, const part_mask_t mask)
diff --git a/src/duckdb/extension/core_functions/scalar/generic/can_implicitly_cast.cpp b/src/duckdb/extension/core_functions/scalar/generic/can_implicitly_cast.cpp
index de8efadb3..1f28c8da8 100644
--- a/src/duckdb/extension/core_functions/scalar/generic/can_implicitly_cast.cpp
+++ b/src/duckdb/extension/core_functions/scalar/generic/can_implicitly_cast.cpp
@@ -9,7 +9,7 @@ namespace duckdb {
 namespace {

 bool CanCastImplicitly(ClientContext &context, const LogicalType &source, const LogicalType &target) {
-   return CastFunctionSet::Get(context).ImplicitCastCost(source, target) >= 0;
+   return CastFunctionSet::ImplicitCastCost(context, source, target) >= 0;
 }

 void CanCastImplicitlyFunction(DataChunk &args, ExpressionState &state, Vector &result) {
diff --git a/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp b/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp
index aabc780b3..7670c39c6 100644
--- a/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp
+++ b/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp
@@ -51,9 +51,11 @@ unique_ptr<FunctionData> CurrentSettingBind(ClientContext &context, ScalarFunction &bound_function,
    auto key = StringUtil::Lower(StringValue::Get(key_val));
    Value val;
    if (!context.TryGetCurrentSetting(key, val)) {
-       Catalog::AutoloadExtensionByConfigName(context, key);
+       auto extension_name = Catalog::AutoloadExtensionByConfigName(context, key);
        // If autoloader didn't throw, the config is now available
-       context.TryGetCurrentSetting(key, val);
+       if (!context.TryGetCurrentSetting(key, val)) {
+           throw InternalException("Extension %s did not provide the '%s' config setting", extension_name, key);
+       }
    }

    bound_function.return_type = val.type();
diff --git a/src/duckdb/extension/core_functions/scalar/list/flatten.cpp b/src/duckdb/extension/core_functions/scalar/list/flatten.cpp
index b66a629b2..97b3d625f 100644
--- a/src/duckdb/extension/core_functions/scalar/list/flatten.cpp
+++ b/src/duckdb/extension/core_functions/scalar/list/flatten.cpp
@@ -142,51 +142,6 @@ void ListFlattenFunction(DataChunk &args, ExpressionState &, Vector &result) {
    }
 }

-unique_ptr<FunctionData> ListFlattenBind(ClientContext &context, ScalarFunction &bound_function,
-                                         vector<unique_ptr<Expression>> &arguments) {
-   D_ASSERT(bound_function.arguments.size() == 1);
-
-   if (arguments[0]->return_type.id() == LogicalTypeId::ARRAY) {
-       auto child_type = ArrayType::GetChildType(arguments[0]->return_type);
-       if (child_type.id() == LogicalTypeId::ARRAY) {
-           child_type = LogicalType::LIST(ArrayType::GetChildType(child_type));
-       }
-       arguments[0] =
-           BoundCastExpression::AddCastToType(context, std::move(arguments[0]), LogicalType::LIST(child_type));
-   } else if (arguments[0]->return_type.id() == LogicalTypeId::LIST) {
-       auto child_type = ListType::GetChildType(arguments[0]->return_type);
-       if (child_type.id() == LogicalTypeId::ARRAY) {
-           child_type = LogicalType::LIST(ArrayType::GetChildType(child_type));
-           arguments[0] =
-               BoundCastExpression::AddCastToType(context, std::move(arguments[0]), LogicalType::LIST(child_type));
-       }
-   }
-
-   auto &input_type = arguments[0]->return_type;
-   bound_function.arguments[0] = input_type;
-   if (input_type.IsUnknown()) {
-       bound_function.arguments[0] = LogicalType::UNKNOWN;
-       bound_function.return_type = LogicalType::UNKNOWN;
-       return nullptr;
-   }
-   D_ASSERT(input_type.id() == LogicalTypeId::LIST);
-
-   auto child_type = ListType::GetChildType(input_type);
-   if (child_type.id() == LogicalType::SQLNULL) {
-       bound_function.return_type = input_type;
-       return make_uniq<VariableReturnBindData>(bound_function.return_type);
-   }
-   if (child_type.IsUnknown()) {
-       bound_function.arguments[0] = LogicalType::UNKNOWN;
-       bound_function.return_type = LogicalType::UNKNOWN;
-       return nullptr;
-   }
-   D_ASSERT(child_type.id() == LogicalTypeId::LIST);
-
-   bound_function.return_type = child_type;
-   return make_uniq<VariableReturnBindData>(bound_function.return_type);
-}
-
 unique_ptr<BaseStatistics> ListFlattenStats(ClientContext &context, FunctionStatisticsInput &input) {
    auto &child_stats = input.child_stats;
    auto &list_child_stats = ListStats::GetChildStats(child_stats[0]);
@@ -198,8 +153,9 @@ unique_ptr<BaseStatistics> ListFlattenStats(ClientContext &context, FunctionStatisticsInput &input) {
 } // namespace

 ScalarFunction ListFlattenFun::GetFunction() {
-   return ScalarFunction({LogicalType::LIST(LogicalType::LIST(LogicalType::ANY))}, LogicalType::LIST(LogicalType::ANY),
-                         ListFlattenFunction, ListFlattenBind, nullptr, ListFlattenStats);
+   return ScalarFunction({LogicalType::LIST(LogicalType::LIST(LogicalType::TEMPLATE("T")))},
+                         LogicalType::LIST(LogicalType::TEMPLATE("T")), ListFlattenFunction, nullptr, nullptr,
+                         ListFlattenStats);
 }

 } // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp b/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp
index 0d218d90f..8be7134ab 100644
--- a/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp
+++ b/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp
@@ -501,28 +501,6 @@ unique_ptr<FunctionData> ListAggregateBind(ClientContext &context, ScalarFunction &bound_function,
    return ListAggregatesBind<true>(context, bound_function, arguments);
 }

-unique_ptr<FunctionData> ListDistinctBind(ClientContext &context, ScalarFunction &bound_function,
-                                          vector<unique_ptr<Expression>> &arguments) {
-
-   D_ASSERT(bound_function.arguments.size() == 1);
-   D_ASSERT(arguments.size() == 1);
-
-   arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0]));
-   bound_function.return_type = arguments[0]->return_type;
-
-   return ListAggregatesBind<>(context, bound_function, arguments);
-}
-
-unique_ptr<FunctionData> ListUniqueBind(ClientContext &context, ScalarFunction &bound_function,
-                                        vector<unique_ptr<Expression>> &arguments) {
-
-   D_ASSERT(bound_function.arguments.size() == 1);
-   D_ASSERT(arguments.size() == 1);
-   bound_function.return_type = LogicalType::UBIGINT;
-
-   return ListAggregatesBind<>(context, bound_function, arguments);
-}
-
 } // namespace

 ScalarFunction ListAggregateFun::GetFunction() {
@@ -538,13 +516,14 @@ ScalarFunction ListAggregateFun::GetFunction() {
 }

 ScalarFunction ListDistinctFun::GetFunction() {
-   return ScalarFunction({LogicalType::LIST(LogicalType::ANY)}, LogicalType::LIST(LogicalType::ANY),
-                         ListDistinctFunction, ListDistinctBind, nullptr, nullptr, ListAggregatesInitLocalState);
+   return ScalarFunction({LogicalType::LIST(LogicalType::TEMPLATE("T"))},
+                         LogicalType::LIST(LogicalType::TEMPLATE("T")), ListDistinctFunction,
+                         ListAggregatesBind<false>, nullptr, nullptr, ListAggregatesInitLocalState);
 }

 ScalarFunction ListUniqueFun::GetFunction() {
    return ScalarFunction({LogicalType::LIST(LogicalType::ANY)}, LogicalType::UBIGINT, ListUniqueFunction,
-                         ListUniqueBind, nullptr, nullptr, ListAggregatesInitLocalState);
+                         ListAggregatesBind<false>, nullptr, nullptr, ListAggregatesInitLocalState);
 }

 } // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp b/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp
index c2ce6f02d..51b4980cd 100644
--- a/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp
+++ b/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp
@@ -6,52 +6,6 @@

 namespace duckdb {

-static unique_ptr<FunctionData> ListHasAnyOrAllBind(ClientContext &context, ScalarFunction &bound_function,
-                                                    vector<unique_ptr<Expression>> &arguments) {
-
-   arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0]));
-   arguments[1] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[1]));
-
-   const auto &lhs_type = arguments[0]->return_type;
-   const auto &rhs_type = arguments[1]->return_type;
-
-   if (lhs_type.IsUnknown() && rhs_type.IsUnknown()) {
-       bound_function.arguments[0] = rhs_type;
-       bound_function.arguments[1] = lhs_type;
-       bound_function.return_type = LogicalType::UNKNOWN;
-       return nullptr;
-   }
-
-   // LHS and RHS must have the same input type.
-   // Thus, we can proceed binding, even if we only know the type of one of the arguments.
-
-   if (lhs_type.IsUnknown() || rhs_type.IsUnknown()) {
-       bound_function.arguments[0] = lhs_type.IsUnknown() ? rhs_type : lhs_type;
-       bound_function.arguments[1] = rhs_type.IsUnknown() ? lhs_type : rhs_type;
-       return nullptr;
-   }
-
-   // Ensure the lists have the same child type, else throw.
-
-   bound_function.arguments[0] = lhs_type;
-   bound_function.arguments[1] = rhs_type;
-
-   const auto &lhs_child = ListType::GetChildType(bound_function.arguments[0]);
-   const auto &rhs_child = ListType::GetChildType(bound_function.arguments[1]);
-
-   if (lhs_child != LogicalType::SQLNULL && rhs_child != LogicalType::SQLNULL && lhs_child != rhs_child) {
-       LogicalType common_child;
-       if (!LogicalType::TryGetMaxLogicalType(context, lhs_child, rhs_child, common_child)) {
-           throw BinderException("'%s' cannot compare lists of different types: '%s' and '%s'", bound_function.name,
-                                 lhs_child.ToString(), rhs_child.ToString());
-       }
-       bound_function.arguments[0] = LogicalType::LIST(common_child);
-       bound_function.arguments[1] = LogicalType::LIST(common_child);
-   }
-
-   return nullptr;
-}
-
 static void ListHasAnyFunction(DataChunk &args, ExpressionState &, Vector &result) {

    auto &l_vec = args.data[0];
@@ -213,14 +167,14 @@ static void ListHasAllFunction(DataChunk &args, ExpressionState &state, Vector &result) {
 }

 ScalarFunction ListHasAnyFun::GetFunction() {
-   ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LIST(LogicalType::ANY)}, LogicalType::BOOLEAN,
-                      ListHasAnyFunction, ListHasAnyOrAllBind);
+   ScalarFunction fun({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::LIST(LogicalType::TEMPLATE("T"))},
+                      LogicalType::BOOLEAN, ListHasAnyFunction);
    return fun;
 }

 ScalarFunction ListHasAllFun::GetFunction() {
-   ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LIST(LogicalType::ANY)}, LogicalType::BOOLEAN,
-                      ListHasAllFunction, ListHasAnyOrAllBind);
+   ScalarFunction fun({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::LIST(LogicalType::TEMPLATE("T"))},
+                      LogicalType::BOOLEAN, ListHasAllFunction);
    return fun;
 }
diff --git a/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp b/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp
index a12cb7ca6..1263500c9 100644
--- a/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp
+++ b/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp
@@ -300,8 +300,8 @@ static unique_ptr<FunctionData> ListGradeUpBind(ClientContext &context, ScalarFunction &bound_function,
        null_order = GetOrder(context, *arguments[2]);
    }
    auto &config = DBConfig::GetConfig(context);
-   order = config.ResolveOrder(order);
-   null_order = config.ResolveNullOrder(order, null_order);
+   order = config.ResolveOrder(context, order);
+   null_order = config.ResolveNullOrder(context, order, null_order);

    arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0]));

@@ -326,8 +326,8 @@ static unique_ptr<FunctionData> ListNormalSortBind(ClientContext &context, ScalarFunction &bound_function,
        null_order = GetOrder(context, *arguments[2]);
    }
    auto &config = DBConfig::GetConfig(context);
-   order = config.ResolveOrder(order);
-   null_order = config.ResolveNullOrder(order, null_order);
+   order = config.ResolveOrder(context, order);
+   null_order = config.ResolveNullOrder(context, order, null_order);
    return ListSortBind(context, bound_function, arguments, order, null_order);
 }

@@ -340,7 +340,7 @@ static unique_ptr<FunctionData> ListReverseSortBind(ClientContext &context, ScalarFunction &bound_function,
        null_order = GetOrder(context, *arguments[1]);
    }
    auto &config = DBConfig::GetConfig(context);
-   order = config.ResolveOrder(order);
+   order = config.ResolveOrder(context, order);
    switch (order) {
    case OrderType::ASCENDING:
        order = OrderType::DESCENDING;
@@ -351,7 +351,7 @@ static unique_ptr<FunctionData> ListReverseSortBind(ClientContext &context, ScalarFunction &bound_function,
    default:
        throw InternalException("Unexpected order type in list reverse sort");
    }
-   null_order = config.ResolveNullOrder(order, null_order);
+   null_order = config.ResolveNullOrder(context, order, null_order);
    return ListSortBind(context, bound_function, arguments, order, null_order);
 }
diff --git a/src/duckdb/extension/core_functions/scalar/list/list_value.cpp b/src/duckdb/extension/core_functions/scalar/list/list_value.cpp
index ff67a6a09..cec76fe89 100644
--- a/src/duckdb/extension/core_functions/scalar/list/list_value.cpp
+++ b/src/duckdb/extension/core_functions/scalar/list/list_value.cpp
@@ -262,37 +262,29 @@ void ListValueFunction(DataChunk &args, ExpressionState &state, Vector &result) {
    ListVector::SetListSize(result, column_count * args.size());
 }

-template <bool IS_UNPIVOT>
-unique_ptr<FunctionData> ListValueBind(ClientContext &context, ScalarFunction &bound_function,
-                                       vector<unique_ptr<Expression>> &arguments) {
+unique_ptr<FunctionData> UnpivotBind(ClientContext &context, ScalarFunction &bound_function,
+                                     vector<unique_ptr<Expression>> &arguments) {
    // collect names and deconflict, construct return type
    LogicalType child_type =
        arguments.empty() ? LogicalType::SQLNULL : ExpressionBinder::GetExpressionReturnType(*arguments[0]);
    for (idx_t i = 1; i < arguments.size(); i++) {
        auto arg_type = ExpressionBinder::GetExpressionReturnType(*arguments[i]);
        if (!LogicalType::TryGetMaxLogicalType(context, child_type, arg_type, child_type)) {
-           if (IS_UNPIVOT) {
-               string list_arguments = "Full list: ";
-               idx_t error_index = list_arguments.size();
-               for (idx_t k = 0; k < arguments.size(); k++) {
-                   if (k > 0) {
-                       list_arguments += ", ";
-                   }
-                   if (k == i) {
-                       error_index = list_arguments.size();
-                   }
-                   list_arguments += arguments[k]->ToString() + " " + arguments[k]->return_type.ToString();
+           string list_arguments = "Full list: ";
+           idx_t error_index = list_arguments.size();
+           for (idx_t k = 0; k < arguments.size(); k++) {
+               if (k > 0) {
+                   list_arguments += ", ";
                }
-               auto error =
-                   StringUtil::Format("Cannot unpivot columns of types %s and %s - an explicit cast is required",
-                                      child_type.ToString(), arg_type.ToString());
-               throw BinderException(arguments[i]->GetQueryLocation(),
-                                     QueryErrorContext::Format(list_arguments, error, error_index, false));
-           } else {
-               throw BinderException(arguments[i]->GetQueryLocation(),
-                                     "Cannot create a list of types %s and %s - an explicit cast is required",
-                                     child_type.ToString(), arg_type.ToString());
+               if (k == i) {
+                   error_index = list_arguments.size();
+               }
+               list_arguments += arguments[k]->ToString() + " " + arguments[k]->return_type.ToString();
            }
+           auto error = StringUtil::Format("Cannot unpivot columns of types %s and %s - an explicit cast is required",
+                                           child_type.ToString(), arg_type.ToString());
+           throw BinderException(arguments[i]->GetQueryLocation(),
+                                 QueryErrorContext::Format(list_arguments, error, error_index, false));
        }
    }
    child_type = LogicalType::NormalizeType(child_type);
@@ -316,19 +308,31 @@ unique_ptr<BaseStatistics> ListValueStats(ClientContext &context, FunctionStatisticsInput &input) {

 } // namespace

-ScalarFunction ListValueFun::GetFunction() {
-   // the arguments and return types are actually set in the binder function
-   ScalarFunction fun("list_value", {}, LogicalTypeId::LIST, ListValueFunction, ListValueBind<false>, nullptr,
-                      ListValueStats);
-   fun.varargs = LogicalType::ANY;
-   fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
-   return fun;
+ScalarFunctionSet ListValueFun::GetFunctions() {
+
+   ScalarFunctionSet set("list_value");
+
+   // Overload for 0 arguments, which returns an empty list.
+   ScalarFunction empty_fun({}, LogicalType::LIST(LogicalType::SQLNULL), ListValueFunction, nullptr, nullptr,
+                            ListValueStats);
+   set.AddFunction(empty_fun);
+
+   // Overload for 1 + N arguments, which returns a list of the arguments.
+   auto element_type = LogicalType::TEMPLATE("T");
+   ScalarFunction value_fun({element_type}, LogicalType::LIST(element_type), ListValueFunction, nullptr, nullptr,
+                            ListValueStats);
+   value_fun.varargs = element_type;
+   value_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
+   set.AddFunction(value_fun);
+
+   return set;
 }

 ScalarFunction UnpivotListFun::GetFunction() {
-   auto fun = ListValueFun::GetFunction();
-   fun.name = "unpivot_list";
-   fun.bind = ListValueBind<true>;
+   ScalarFunction fun("unpivot_list", {}, LogicalTypeId::LIST, ListValueFunction, UnpivotBind, nullptr,
+                      ListValueStats);
+   fun.varargs = LogicalTypeId::ANY;
+   fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
    return fun;
 }
diff --git a/src/duckdb/extension/core_functions/scalar/map/map.cpp b/src/duckdb/extension/core_functions/scalar/map/map.cpp
index b83a4a081..ab9bea1bb 100644
--- a/src/duckdb/extension/core_functions/scalar/map/map.cpp
+++ b/src/duckdb/extension/core_functions/scalar/map/map.cpp
@@ -172,52 +172,24 @@ static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) {
    result.Verify(row_count);
 }

-static unique_ptr<FunctionData> MapBind(ClientContext &, ScalarFunction &bound_function,
-                                        vector<unique_ptr<Expression>> &arguments) {
+ScalarFunctionSet MapFun::GetFunctions() {

-   if (arguments.size() != 2 && !arguments.empty()) {
-       MapVector::EvalMapInvalidReason(MapInvalidReason::INVALID_PARAMS);
-   }
-
-   bool is_null = false;
-   if (arguments.empty()) {
-       is_null = true;
-   }
-   if (!is_null) {
-       auto key_id = arguments[0]->return_type.id();
-       auto value_id = arguments[1]->return_type.id();
-       if (key_id == LogicalTypeId::SQLNULL || value_id == LogicalTypeId::SQLNULL) {
-           is_null = true;
-       }
-   }
+   ScalarFunction empty_func({}, LogicalType::MAP(LogicalType::SQLNULL, LogicalType::SQLNULL), MapFunction);
+   BaseScalarFunction::SetReturnsError(empty_func);

-   if (is_null) {
-       bound_function.return_type = LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL);
-       return make_uniq<VariableReturnBindData>(bound_function.return_type);
-   }
+   auto key_type = LogicalType::TEMPLATE("K");
+   auto val_type = LogicalType::TEMPLATE("V");
+   ScalarFunction value_func({LogicalType::LIST(key_type), LogicalType::LIST(val_type)},
+                             LogicalType::MAP(key_type, val_type), MapFunction);
+   BaseScalarFunction::SetReturnsError(value_func);
+   value_func.null_handling = FunctionNullHandling::SPECIAL_HANDLING;

-   // bind a MAP with key-value pairs
-   D_ASSERT(arguments.size() == 2);
-   if (arguments[0]->return_type.id() != LogicalTypeId::LIST) {
-       MapVector::EvalMapInvalidReason(MapInvalidReason::INVALID_PARAMS);
-   }
-   if (arguments[1]->return_type.id() != LogicalTypeId::LIST) {
-       MapVector::EvalMapInvalidReason(MapInvalidReason::INVALID_PARAMS);
-   }
+   ScalarFunctionSet set;

-   auto key_type = ListType::GetChildType(arguments[0]->return_type);
-   auto value_type = ListType::GetChildType(arguments[1]->return_type);
-
-   bound_function.return_type = LogicalType::MAP(key_type, value_type);
-   return make_uniq<VariableReturnBindData>(bound_function.return_type);
-}
+   set.AddFunction(empty_func);
+   set.AddFunction(value_func);

-ScalarFunction MapFun::GetFunction() {
-   ScalarFunction fun({}, LogicalTypeId::MAP, MapFunction, MapBind);
-   fun.varargs = LogicalType::ANY;
-   BaseScalarFunction::SetReturnsError(fun);
-   fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
-   return fun;
+   return set;
 }

 } // namespace duckdb
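map() and list_value() now register overload sets — a zero-argument overload with a concrete NULL-typed result plus a templated overload — where a single varargs ANY function with a bind callback used to dispatch on argument count. A hedged sketch of exercising both overloads through the C++ client API (the SQL literals are illustrative; list_value(4, 5, 6) is the Example string from the function header above):

    #include "duckdb.hpp"

    void Demo(duckdb::Connection &con) {
        con.Query("SELECT map()")->Print();                   // empty overload: MAP(NULL, NULL)
        con.Query("SELECT map(['a', 'b'], [1, 2])")->Print(); // templated overload: K=VARCHAR, V=INTEGER
        con.Query("SELECT list_value(4, 5, 6)")->Print();     // varargs unify to one element type T
    }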
diff --git a/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp b/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp
index 487fd75fa..06af34e66 100644
--- a/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp
+++ b/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp
@@ -28,51 +28,15 @@ static void MapEntriesFunction(DataChunk &args, ExpressionState &state, Vector &result) {
    result.Verify(count);
 }

-static LogicalType CreateReturnType(const LogicalType &map) {
-   auto &key_type = MapType::KeyType(map);
-   auto &value_type = MapType::ValueType(map);
-
-   child_list_t<LogicalType> child_types;
-   child_types.push_back(make_pair("key", key_type));
-   child_types.push_back(make_pair("value", value_type));
-
-   auto row_type = LogicalType::STRUCT(child_types);
-   return LogicalType::LIST(row_type);
-}
-
-static unique_ptr<FunctionData> MapEntriesBind(ClientContext &context, ScalarFunction &bound_function,
-                                               vector<unique_ptr<Expression>> &arguments) {
-   if (arguments.size() != 1) {
-       throw InvalidInputException("Too many arguments provided, only expecting a single map");
-   }
-   auto &map = arguments[0]->return_type;
-
-   if (map.id() == LogicalTypeId::UNKNOWN) {
-       // Prepared statement
-       bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN);
-       bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL);
-       return nullptr;
-   }
-
-   if (map.id() == LogicalTypeId::SQLNULL) {
-       // Input is NULL, output is STRUCT(NULL, NULL)[]
-       auto map_type = LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL);
-       bound_function.return_type = CreateReturnType(map_type);
-       return make_uniq<VariableReturnBindData>(bound_function.return_type);
-   }
+ScalarFunction MapEntriesFun::GetFunction() {

-   if (map.id() != LogicalTypeId::MAP) {
-       throw InvalidInputException("The provided argument is not a map");
-   }
-   bound_function.return_type = CreateReturnType(map);
-   return make_uniq<VariableReturnBindData>(bound_function.return_type);
-}
+   auto key_type = LogicalType::TEMPLATE("K");
+   auto val_type = LogicalType::TEMPLATE("V");
+   auto map_type = LogicalType::MAP(key_type, val_type);
+   auto row_type = LogicalType::STRUCT({{"key", key_type}, {"value", val_type}});

-ScalarFunction MapEntriesFun::GetFunction() {
-   //! the arguments and return types are actually set in the binder function
-   ScalarFunction fun({}, LogicalTypeId::LIST, MapEntriesFunction, MapEntriesBind);
+   ScalarFunction fun({map_type}, LogicalType::LIST(row_type), MapEntriesFunction);
    fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
-   fun.varargs = LogicalType::ANY;
    return fun;
 }
diff --git a/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp b/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp
index 5fd61fab4..fcea0b133 100644
--- a/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp
+++ b/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp
@@ -6,52 +6,12 @@

 namespace duckdb {

-template <bool EXTRACT_VALUE>
-static unique_ptr<FunctionData> MapExtractBind(ClientContext &, ScalarFunction &bound_function,
-                                               vector<unique_ptr<Expression>> &arguments) {
-   if (arguments.size() != 2) {
-       throw BinderException("MAP_EXTRACT must have exactly two arguments");
-   }
-
-   const auto &map_type = arguments[0]->return_type;
-   const auto &input_type = arguments[1]->return_type;
-
-   if (map_type.id() == LogicalTypeId::SQLNULL) {
-       bound_function.return_type = EXTRACT_VALUE ? LogicalTypeId::SQLNULL : LogicalType::LIST(LogicalTypeId::SQLNULL);
-       return make_uniq<VariableReturnBindData>(bound_function.return_type);
-   }
-
-   if (map_type.id() != LogicalTypeId::MAP) {
-       throw BinderException("'%s' can only operate on MAPs", bound_function.name);
-   }
-   auto &value_type = MapType::ValueType(map_type);
-
-   //! Here we have to construct the List Type that will be returned
-   bound_function.return_type = EXTRACT_VALUE ? value_type : LogicalType::LIST(value_type);
-   const auto &key_type = MapType::KeyType(map_type);
-   if (key_type.id() != LogicalTypeId::SQLNULL && input_type.id() != LogicalTypeId::SQLNULL) {
-       bound_function.arguments[1] = MapType::KeyType(map_type);
-   }
-   return make_uniq<VariableReturnBindData>(bound_function.return_type);
-}
-
 static void MapExtractValueFunc(DataChunk &args, ExpressionState &state, Vector &result) {
    const auto count = args.size();

    auto &map_vec = args.data[0];
    auto &arg_vec = args.data[1];

-   const auto map_is_null = map_vec.GetType().id() == LogicalTypeId::SQLNULL;
-   const auto arg_is_null = arg_vec.GetType().id() == LogicalTypeId::SQLNULL;
-
-   if (map_is_null || arg_is_null) {
-       // Short-circuit if either the map or the arg is NULL
-       result.SetVectorType(VectorType::CONSTANT_VECTOR);
-       ConstantVector::SetNull(result, true);
-       result.Verify(count);
-       return;
-   }
-
    auto &key_vec = MapVector::GetKeys(map_vec);
    auto &val_vec = MapVector::GetValues(map_vec);

@@ -100,18 +60,6 @@ static void MapExtractListFunc(DataChunk &args, ExpressionState &state, Vector &result) {
    auto &map_vec = args.data[0];
    auto &arg_vec = args.data[1];

-   const auto map_is_null = map_vec.GetType().id() == LogicalTypeId::SQLNULL;
-   const auto arg_is_null = arg_vec.GetType().id() == LogicalTypeId::SQLNULL;
-
-   if (map_is_null || arg_is_null) {
-       // Short-circuit if either the map or the arg is NULL
-       ListVector::SetListSize(result, 0);
-       result.SetVectorType(VectorType::CONSTANT_VECTOR);
-       ConstantVector::GetData<list_entry_t>(result)[0] = {0, 0};
-       result.Verify(count);
-       return;
-   }
-
    auto &key_vec = MapVector::GetKeys(map_vec);
    auto &val_vec = MapVector::GetValues(map_vec);

@@ -144,7 +92,7 @@ static void MapExtractListFunc(DataChunk &args, ExpressionState &state, Vector &result) {
        const auto pos_idx = pos_format.sel->get_index(row_idx);
        if (!pos_format.validity.RowIsValid(pos_idx)) {
-           // We didnt find the key in the map, so return emptyl ist
+           // We didnt find the key in the map, so return empty list
            out_list.offset = offset;
            out_list.length = 0;
            continue;
@@ -166,17 +114,20 @@ static void MapExtractListFunc(DataChunk &args, ExpressionState &state, Vector &result) {
 }

 ScalarFunction MapExtractValueFun::GetFunction() {
-   ScalarFunction fun({LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, MapExtractValueFunc,
-                      MapExtractBind<true>);
-   fun.varargs = LogicalType::ANY;
+   auto key_type = LogicalType::TEMPLATE("K");
+   auto val_type = LogicalType::TEMPLATE("V");
+
+   ScalarFunction fun({LogicalType::MAP(key_type, val_type), key_type}, val_type, MapExtractValueFunc);
    fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
    return fun;
 }

 ScalarFunction MapExtractFun::GetFunction() {
-   ScalarFunction fun({LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, MapExtractListFunc,
-                      MapExtractBind<false>);
-   fun.varargs = LogicalType::ANY;
+   auto key_type = LogicalType::TEMPLATE("K");
+   auto val_type = LogicalType::TEMPLATE("V");
+
+   ScalarFunction fun({LogicalType::MAP(key_type, val_type), key_type}, LogicalType::LIST(val_type),
+                      MapExtractListFunc);
    fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
    return fun;
 }
diff --git a/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp b/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp
index edbe1d4fb..2344b9a6e 100644
--- a/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp
+++ b/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp
@@ -19,40 +19,15 @@ static void MapFromEntriesFunction(DataChunk &args, ExpressionState &state, Vector &result) {
    }
 }

-static unique_ptr<FunctionData> MapFromEntriesBind(ClientContext &context, ScalarFunction &bound_function,
-                                                   vector<unique_ptr<Expression>> &arguments) {
-   if (arguments.size() != 1) {
-       throw InvalidInputException("The input argument must be a list of structs.");
-   }
-   auto &list = arguments[0]->return_type;
-
-   if (list.id() == LogicalTypeId::UNKNOWN) {
-       bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN);
-       bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL);
-       return nullptr;
-   }
-
-   if (list.id() != LogicalTypeId::LIST) {
-       throw InvalidInputException("The provided argument is not a list of structs");
-   }
-   auto &elem_type = ListType::GetChildType(list);
-   if (elem_type.id() != LogicalTypeId::STRUCT) {
-       throw InvalidInputException("The elements of the list must be structs");
-   }
-   auto &children = StructType::GetChildTypes(elem_type);
-   if (children.size() != 2) {
-       throw InvalidInputException("The provided struct type should only contain 2 fields, a key and a value");
-   }
-
-   bound_function.return_type = LogicalType::MAP(elem_type);
-   return make_uniq<VariableReturnBindData>(bound_function.return_type);
-}
-
 ScalarFunction MapFromEntriesFun::GetFunction() {
-   //! the arguments and return types are actually set in the binder function
-   ScalarFunction fun({}, LogicalTypeId::MAP, MapFromEntriesFunction, MapFromEntriesBind);
+   auto key_type = LogicalType::TEMPLATE("K");
+   auto val_type = LogicalType::TEMPLATE("V");
+   auto map_type = LogicalType::MAP(key_type, val_type);
+   auto row_type = LogicalType::STRUCT({{"", key_type}, {"", val_type}});
+
+   ScalarFunction fun({LogicalType::LIST(row_type)}, map_type, MapFromEntriesFunction);
    fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING;
-   fun.varargs = LogicalType::ANY;
+   BaseScalarFunction::SetReturnsError(fun);
    return fun;
 }
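map_from_entries now declares its input as LIST(STRUCT(K, V)) with unnamed fields, so any list of two-field structs binds directly where the removed callback used to validate the shape by hand. A hedged usage sketch (the struct-literal field names are illustrative; only the two-field shape matters):

    con.Query("SELECT map_from_entries([{'k': 'a', 'v': 1}, {'k': 'b', 'v': 2}])")->Print();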
diff --git a/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp b/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp
index 6d99a353e..eec32a0a6 100644
--- a/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp
+++ b/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp
@@ -51,61 +51,26 @@ static void MapValuesFunction(DataChunk &args, ExpressionState &state, Vector &result) {
    MapKeyValueFunction(args, state, result, MapVector::GetValues);
 }

-static unique_ptr<FunctionData> MapKeyValueBind(ClientContext &context, ScalarFunction &bound_function,
-                                                vector<unique_ptr<Expression>> &arguments,
-                                                const LogicalType &(*type_func)(const LogicalType &)) {
-   if (arguments.size() != 1) {
-       throw InvalidInputException("Too many arguments provided, only expecting a single map");
-   }
-   auto &map = arguments[0]->return_type;
-
-   if (map.id() == LogicalTypeId::UNKNOWN) {
-       // Prepared statement
-       bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN);
-       bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL);
-       return nullptr;
-   }
-
-   if (map.id() == LogicalTypeId::SQLNULL) {
-       // Input is NULL, output is NULL[]
-       bound_function.return_type = LogicalType::LIST(LogicalTypeId::SQLNULL);
-       return make_uniq<VariableReturnBindData>(bound_function.return_type);
-   }
-
-   if (map.id() != LogicalTypeId::MAP) {
-       throw InvalidInputException("The provided argument is not a map");
-   }
-
-   auto &type = type_func(map);
-
-   bound_function.return_type = LogicalType::LIST(type);
-   return make_uniq<VariableReturnBindData>(bound_function.return_type);
-}
-
-static unique_ptr<FunctionData> MapKeysBind(ClientContext &context, ScalarFunction &bound_function,
-                                            vector<unique_ptr<Expression>> &arguments) {
-   return MapKeyValueBind(context, bound_function, arguments, MapType::KeyType);
-}
-
-static unique_ptr<FunctionData> MapValuesBind(ClientContext &context, ScalarFunction &bound_function,
-                                              vector<unique_ptr<Expression>> &arguments) {
-   return MapKeyValueBind(context, bound_function, arguments, MapType::ValueType);
-}
-
 ScalarFunction MapKeysFun::GetFunction() {
    //! the arguments and return types are actually set in the binder function
-   ScalarFunction function({}, LogicalTypeId::LIST, MapKeysFunction, MapKeysBind);
+   auto key_type = LogicalType::TEMPLATE("K");
+   auto val_type = LogicalType::TEMPLATE("V");
+
+   ScalarFunction function({LogicalType::MAP(key_type, val_type)}, LogicalType::LIST(key_type), MapKeysFunction);
    function.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
+   BaseScalarFunction::SetReturnsError(function);
-   function.varargs = LogicalType::ANY;
    return function;
 }

 ScalarFunction MapValuesFun::GetFunction() {
-   ScalarFunction function({}, LogicalTypeId::LIST, MapValuesFunction, MapValuesBind);
+   auto key_type = LogicalType::TEMPLATE("K");
+   auto val_type = LogicalType::TEMPLATE("V");
+
+   ScalarFunction function({LogicalType::MAP(key_type, val_type)}, LogicalType::LIST(val_type), MapValuesFunction);
    function.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
+   BaseScalarFunction::SetReturnsError(function);
-   function.varargs = LogicalType::ANY;
    return function;
 }
diff --git a/src/duckdb/extension/core_functions/scalar/math/numeric.cpp b/src/duckdb/extension/core_functions/scalar/math/numeric.cpp
index dd8a1a9ab..d6ae71bc0 100644
--- a/src/duckdb/extension/core_functions/scalar/math/numeric.cpp
+++ b/src/duckdb/extension/core_functions/scalar/math/numeric.cpp
@@ -1,16 +1,10 @@
 #include "duckdb/common/operator/decimal_cast_operators.hpp"
-#include "duckdb/common/vector_operations/vector_operations.hpp"
-#include "duckdb/common/algorithm.hpp"
 #include "duckdb/common/likely.hpp"
 #include "duckdb/common/operator/abs.hpp"
 #include "duckdb/common/operator/multiply.hpp"
-#include "duckdb/common/operator/numeric_binary_operators.hpp"
 #include "duckdb/common/types/bit.hpp"
 #include "duckdb/common/types/cast_helpers.hpp"
 #include "duckdb/common/types/hugeint.hpp"
-#include "duckdb/common/types/uhugeint.hpp"
-#include "duckdb/common/types/validity_mask.hpp"
-#include "duckdb/common/types/vector.hpp"
 #include "duckdb/common/vector_operations/unary_executor.hpp"
 #include "core_functions/scalar/math_functions.hpp"
 #include "duckdb/execution/expression_executor.hpp"
@@ -18,9 +12,6 @@
 #include
 #include
-#include
-#include
-#include

 namespace duckdb {

@@ -474,6 +465,116 @@ ScalarFunctionSet FloorFun::GetFunctions() {
 //===--------------------------------------------------------------------===//
 namespace {

+struct RoundPrecisionFunctionData : public FunctionData {
+   explicit RoundPrecisionFunctionData(int32_t target_scale) : target_scale(target_scale) {
+   }
+
+   int32_t target_scale;
+
+   unique_ptr<FunctionData> Copy() const override {
+       return make_uniq<RoundPrecisionFunctionData>(target_scale);
+   }
+
+   bool Equals(const FunctionData &other_p) const override {
+       auto &other = other_p.Cast<RoundPrecisionFunctionData>();
+       return target_scale == other.target_scale;
+   }
+};
+
+template <class OP, class T, class POWERS_OF_TEN_CLASS>
+static void GenericRoundPrecisionDecimal(DataChunk &input, ExpressionState &state, Vector &result) {
+   OP::template Operation<T, POWERS_OF_TEN_CLASS>(input, state, result);
+}
+
+template <class NEGATIVE_OP, class POSITIVE_OP>
+static unique_ptr<FunctionData> BindDecimalRoundPrecision(ClientContext &context, ScalarFunction &bound_function,
+                                                          vector<unique_ptr<Expression>> &arguments) {
+   auto &decimal_type = arguments[0]->return_type;
+   if (arguments[1]->HasParameter()) {
+       throw ParameterNotResolvedException();
+   }
+   auto fname = StringUtil::Upper(bound_function.name);
+   if (!arguments[1]->IsFoldable()) {
+       throw NotImplementedException("%s(DECIMAL, INTEGER) with non-constant precision is not supported", fname);
+   }
+   Value val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]).DefaultCastAs(LogicalType::INTEGER);
+   if (val.IsNull()) {
+       throw NotImplementedException("%s(DECIMAL, INTEGER) with non-constant precision is not supported", fname);
+   }
+   // our new precision becomes the round value
+   // e.g. ROUND(DECIMAL(18,3), 1) -> DECIMAL(18,1)
+   // but ONLY if the round value is positive
+   // if it is negative the scale becomes zero
+   // i.e. ROUND(DECIMAL(18,3), -1) -> DECIMAL(18,0)
+   int32_t round_value = IntegerValue::Get(val);
+   uint8_t target_scale;
+   auto width = DecimalType::GetWidth(decimal_type);
+   auto scale = DecimalType::GetScale(decimal_type);
+   if (round_value < 0) {
+       target_scale = 0;
+       switch (decimal_type.InternalType()) {
+       case PhysicalType::INT16:
+           bound_function.function = GenericRoundPrecisionDecimal<NEGATIVE_OP, int16_t, NumericHelper>;
+           break;
+       case PhysicalType::INT32:
+           bound_function.function = GenericRoundPrecisionDecimal<NEGATIVE_OP, int32_t, NumericHelper>;
+           break;
+       case PhysicalType::INT64:
+           bound_function.function = GenericRoundPrecisionDecimal<NEGATIVE_OP, int64_t, NumericHelper>;
+           break;
+       default:
+           bound_function.function = GenericRoundPrecisionDecimal<NEGATIVE_OP, hugeint_t, Hugeint>;
+           break;
+       }
+   } else {
+       if (round_value >= (int32_t)scale) {
+           // if round_value is bigger than or equal to scale we do nothing
+           bound_function.function = ScalarFunction::NopFunction;
+           target_scale = scale;
+       } else {
+           target_scale = NumericCast<uint8_t>(round_value);
+           switch (decimal_type.InternalType()) {
+           case PhysicalType::INT16:
+               bound_function.function = GenericRoundPrecisionDecimal<POSITIVE_OP, int16_t, NumericHelper>;
+               break;
+           case PhysicalType::INT32:
+               bound_function.function = GenericRoundPrecisionDecimal<POSITIVE_OP, int32_t, NumericHelper>;
+               break;
+           case PhysicalType::INT64:
+               bound_function.function = GenericRoundPrecisionDecimal<POSITIVE_OP, int64_t, NumericHelper>;
+               break;
+           default:
+               bound_function.function = GenericRoundPrecisionDecimal<POSITIVE_OP, hugeint_t, Hugeint>;
+               break;
+           }
+       }
+   }
+   bound_function.arguments[0] = decimal_type;
+   bound_function.return_type = LogicalType::DECIMAL(width, target_scale);
+   return make_uniq<RoundPrecisionFunctionData>(round_value);
+}
+
+struct TruncOperatorPrecision {
+   template <class TA, class TB, class TR>
+   static inline TR Operation(TA input, TB precision) {
+       double trunc_value;
+       if (precision < 0) {
+           double modifier = std::pow(10, -TA(precision));
+           trunc_value = (std::trunc(input / modifier)) * modifier;
+           if (std::isinf(trunc_value) || std::isnan(trunc_value)) {
+               return input;
+           }
+       } else {
+           double modifier = std::pow(10, TA(precision));
+           trunc_value = (std::trunc(input * modifier)) / modifier;
+           if (std::isinf(trunc_value) || std::isnan(trunc_value)) {
+               return input;
+           }
+       }
+       return LossyNumericCast<TR>(trunc_value);
+   }
+};
+
 struct TruncOperator {
    // Integer truncation is a NOP
    template <class TA, class TR>
@@ -493,40 +594,134 @@ struct TruncDecimalOperator {
    }
 };

+struct TruncDecimalNegativePrecisionOperator {
+   template <class T, class POWERS_OF_TEN_CLASS>
+   static void Operation(DataChunk &input, ExpressionState &state, Vector &result) {
+       auto &func_expr = state.expr.Cast<BoundFunctionExpression>();
+       auto &info = func_expr.bind_info->Cast<RoundPrecisionFunctionData>();
+       auto source_scale = DecimalType::GetScale(func_expr.children[0]->return_type);
+       auto width = DecimalType::GetWidth(func_expr.children[0]->return_type);
+       if (info.target_scale <= -int32_t(width - source_scale)) {
+           // scale too big for width
+           result.SetVectorType(VectorType::CONSTANT_VECTOR);
+           result.SetValue(0, Value::INTEGER(0));
+           return;
+       }
+       T divide_power_of_ten =
+           UnsafeNumericCast<T>(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[-info.target_scale + source_scale]);
+       T multiply_power_of_ten = UnsafeNumericCast<T>(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[-info.target_scale]);
+
+       UnaryExecutor::Execute<T, T>(input.data[0], result, input.size(), [&](T input) {
+           return UnsafeNumericCast<T>(input / divide_power_of_ten * multiply_power_of_ten);
+       });
+   }
+};
+
+struct TruncDecimalPositivePrecisionOperator {
+   template <class T, class POWERS_OF_TEN_CLASS>
+   static void Operation(DataChunk &input, ExpressionState &state, Vector &result) {
+       auto &func_expr = state.expr.Cast<BoundFunctionExpression>();
+       auto &info = func_expr.bind_info->Cast<RoundPrecisionFunctionData>();
+       auto source_scale = DecimalType::GetScale(func_expr.children[0]->return_type);
+       T power_of_ten = UnsafeNumericCast<T>(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[source_scale - info.target_scale]);
+       UnaryExecutor::Execute<T, T>(input.data[0], result, input.size(),
+                                    [&](T input) { return UnsafeNumericCast<T>(input / power_of_ten); });
+   }
+};
+
+struct TruncIntegerOperator {
+   template <class TA, class TB, class TR>
+   static inline TR Operation(TA input, TB precision) {
+       if (precision < 0) {
+           // Do all the arithmetic at higher precision
+           using POWERS_OF_TEN_CLASS = typename DecimalCastTraits<TA>::POWERS_OF_TEN_CLASS;
+           if (precision <= -POWERS_OF_TEN_CLASS::CACHED_POWERS_OF_TEN) {
+               return 0;
+           }
+           const auto power_of_ten = POWERS_OF_TEN_CLASS::POWERS_OF_TEN[-precision];
+           auto result = input;
+           result /= power_of_ten;
+           if (result) {
+               return UnsafeNumericCast<TR>(result * power_of_ten);
+           } else {
+               return 0;
+           }
+       } else {
+           // Truncating integers to higher precision is a NOP
+           return input;
+       }
+   }
+};
+
 } // namespace

 ScalarFunctionSet TruncFun::GetFunctions() {
    ScalarFunctionSet trunc;
    for (auto &type : LogicalType::Numeric()) {
-       scalar_function_t func = nullptr;
+       scalar_function_t trunc_func = nullptr;
+       scalar_function_t trunc_prec_func = nullptr;
        bind_scalar_function_t bind_func = nullptr;
+       bind_scalar_function_t bind_prec_func = nullptr;
        // Truncation of integers gets generated by some tools (e.g., Tableau/JDBC:Postgres)
        switch (type.id()) {
        case LogicalTypeId::FLOAT:
-           func = ScalarFunction::UnaryFunction<float, float, TruncOperator>;
+           trunc_func = ScalarFunction::UnaryFunction<float, float, TruncOperator>;
+           trunc_prec_func = ScalarFunction::BinaryFunction<float, int32_t, float, TruncOperatorPrecision>;
            break;
        case LogicalTypeId::DOUBLE:
-           func = ScalarFunction::UnaryFunction<double, double, TruncOperator>;
+           trunc_func = ScalarFunction::UnaryFunction<double, double, TruncOperator>;
+           trunc_prec_func = ScalarFunction::BinaryFunction<double, int32_t, double, TruncOperatorPrecision>;
            break;
        case LogicalTypeId::DECIMAL:
            bind_func = BindGenericRoundFunctionDecimal<TruncDecimalOperator>;
+           bind_prec_func =
+               BindDecimalRoundPrecision<TruncDecimalNegativePrecisionOperator, TruncDecimalPositivePrecisionOperator>;
            break;
        case LogicalTypeId::TINYINT:
+           trunc_func = ScalarFunction::NopFunction;
+           trunc_prec_func = ScalarFunction::BinaryFunction<int8_t, int32_t, int8_t, TruncIntegerOperator>;
+           break;
        case LogicalTypeId::SMALLINT:
+           trunc_func = ScalarFunction::NopFunction;
+           trunc_prec_func = ScalarFunction::BinaryFunction<int16_t, int32_t, int16_t, TruncIntegerOperator>;
+           break;
        case LogicalTypeId::INTEGER:
+           trunc_func = ScalarFunction::NopFunction;
+           trunc_prec_func = ScalarFunction::BinaryFunction<int32_t, int32_t, int32_t, TruncIntegerOperator>;
+           break;
        case LogicalTypeId::BIGINT:
+           trunc_func = ScalarFunction::NopFunction;
+           trunc_prec_func = ScalarFunction::BinaryFunction<int64_t, int32_t, int64_t, TruncIntegerOperator>;
+           break;
        case LogicalTypeId::HUGEINT:
+           trunc_func = ScalarFunction::NopFunction;
+           trunc_prec_func = ScalarFunction::BinaryFunction<hugeint_t, int32_t, hugeint_t, TruncIntegerOperator>;
+           break;
        case LogicalTypeId::UTINYINT:
+           trunc_func = ScalarFunction::NopFunction;
+           trunc_prec_func = ScalarFunction::BinaryFunction<uint8_t, int32_t, uint8_t, TruncIntegerOperator>;
+           break;
        case LogicalTypeId::USMALLINT:
+           trunc_func = ScalarFunction::NopFunction;
+           trunc_prec_func = ScalarFunction::BinaryFunction<uint16_t, int32_t, uint16_t, TruncIntegerOperator>;
+           break;
        case LogicalTypeId::UINTEGER:
+           trunc_func = ScalarFunction::NopFunction;
+           trunc_prec_func = ScalarFunction::BinaryFunction<uint32_t, int32_t, uint32_t, TruncIntegerOperator>;
+           break;
        case LogicalTypeId::UBIGINT:
+           trunc_func = ScalarFunction::NopFunction;
+           trunc_prec_func = ScalarFunction::BinaryFunction<uint64_t, int32_t, uint64_t, TruncIntegerOperator>;
+           break;
        case LogicalTypeId::UHUGEINT:
-           func = ScalarFunction::NopFunction;
+           trunc_func = ScalarFunction::NopFunction;
+           trunc_prec_func = ScalarFunction::BinaryFunction<uhugeint_t, int32_t, uhugeint_t, TruncIntegerOperator>;
            break;
        default:
            throw InternalException("Unimplemented numeric type for function \"trunc\"");
        }
-       trunc.AddFunction(ScalarFunction({type}, type, func, bind_func));
+       trunc.AddFunction(ScalarFunction({type}, type, trunc_func, bind_func));
+       trunc.AddFunction(ScalarFunction({type, LogicalType::INTEGER}, type, trunc_prec_func, bind_prec_func));
    }
    return trunc;
 }
BindDecimalRoundPrecision(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - auto &decimal_type = arguments[0]->return_type; - if (arguments[1]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!arguments[1]->IsFoldable()) { - throw NotImplementedException("ROUND(DECIMAL, INTEGER) with non-constant precision is not supported"); - } - Value val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]).DefaultCastAs(LogicalType::INTEGER); - if (val.IsNull()) { - throw NotImplementedException("ROUND(DECIMAL, INTEGER) with non-constant precision is not supported"); + UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + if (input < 0) { + input -= addition; + } else { + input += addition; + } + return UnsafeNumericCast(input / divide_power_of_ten * multiply_power_of_ten); + }); } - // our new precision becomes the round value - // e.g. ROUND(DECIMAL(18,3), 1) -> DECIMAL(18,1) - // but ONLY if the round value is positive - // if it is negative the scale becomes zero - // i.e. ROUND(DECIMAL(18,3), -1) -> DECIMAL(18,0) - int32_t round_value = IntegerValue::Get(val); - uint8_t target_scale; - auto width = DecimalType::GetWidth(decimal_type); - auto scale = DecimalType::GetScale(decimal_type); - if (round_value < 0) { - target_scale = 0; - switch (decimal_type.InternalType()) { - case PhysicalType::INT16: - bound_function.function = DecimalRoundNegativePrecisionFunction; - break; - case PhysicalType::INT32: - bound_function.function = DecimalRoundNegativePrecisionFunction; - break; - case PhysicalType::INT64: - bound_function.function = DecimalRoundNegativePrecisionFunction; - break; - default: - bound_function.function = DecimalRoundNegativePrecisionFunction; - break; - } - } else { - if (round_value >= (int32_t)scale) { - // if round_value is bigger than or equal to scale we do nothing - bound_function.function = ScalarFunction::NopFunction; - target_scale = scale; - } else { - target_scale = NumericCast(round_value); - switch (decimal_type.InternalType()) { - case PhysicalType::INT16: - bound_function.function = DecimalRoundPositivePrecisionFunction; - break; - case PhysicalType::INT32: - bound_function.function = DecimalRoundPositivePrecisionFunction; - break; - case PhysicalType::INT64: - bound_function.function = DecimalRoundPositivePrecisionFunction; - break; - default: - bound_function.function = DecimalRoundPositivePrecisionFunction; - break; +}; + +struct DecimalRoundPositivePrecisionOperator { + template + static void Operation(DataChunk &input, ExpressionState &state, Vector &result) { + auto &func_expr = state.expr.Cast(); + auto &info = func_expr.bind_info->Cast(); + auto source_scale = DecimalType::GetScale(func_expr.children[0]->return_type); + T power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[source_scale - info.target_scale]); + T addition = power_of_ten / 2; + UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { + if (input < 0) { + input -= addition; + } else { + input += addition; } - } + return UnsafeNumericCast(input / power_of_ten); + }); } - bound_function.arguments[0] = decimal_type; - bound_function.return_type = LogicalType::DECIMAL(width, target_scale); - return make_uniq(round_value); -} +}; ScalarFunctionSet RoundFun::GetFunctions() { ScalarFunctionSet round; @@ -763,7 +881,8 @@ ScalarFunctionSet RoundFun::GetFunctions() { break; case LogicalTypeId::DECIMAL: bind_func = BindGenericRoundFunctionDecimal; - bind_prec_func = BindDecimalRoundPrecision; + 
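// Editor's sketch: the two decimal operators above implement round-half-away-from-zero
// on the scaled integer representation by adding (or subtracting) half of the divisor
// before integer division. Minimal model with illustrative names:
#include <cstdint>
#include <cstdio>

static int64_t RoundScaled(int64_t input, int64_t divide_power_of_ten) {
	const int64_t addition = divide_power_of_ten / 2;
	if (input < 0) {
		input -= addition;
	} else {
		input += addition;
	}
	return input / divide_power_of_ten;
}

int main() {
	// DECIMAL(9,3) stores 1.250 as 1250; rounding to scale 1 divides by 10^2
	printf("%lld\n", (long long)RoundScaled(1250, 100));  // 13, i.e. 1.3
	printf("%lld\n", (long long)RoundScaled(-1250, 100)); // -13: half away from zero
	return 0;
}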
bind_prec_func = + BindDecimalRoundPrecision; break; case LogicalTypeId::TINYINT: round_func = ScalarFunction::NopFunction; @@ -790,7 +909,7 @@ ScalarFunctionSet RoundFun::GetFunctions() { // no round for integral numbers continue; } - throw InternalException("Unimplemented numeric type for function \"floor\""); + throw InternalException("Unimplemented numeric type for function \"round\""); } round.AddFunction(ScalarFunction({type}, type, round_func, bind_func)); round.AddFunction(ScalarFunction({type, LogicalType::INTEGER}, type, round_prec_func, bind_prec_func)); diff --git a/src/duckdb/extension/core_functions/scalar/string/hex.cpp b/src/duckdb/extension/core_functions/scalar/string/hex.cpp index cbf541e1b..d3d6eee7b 100644 --- a/src/duckdb/extension/core_functions/scalar/string/hex.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/hex.cpp @@ -393,7 +393,7 @@ ScalarFunctionSet HexFun::GetFunctions() { to_hex.AddFunction( ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, ToHexFunction)); to_hex.AddFunction( - ScalarFunction({LogicalType::VARINT}, LogicalType::VARCHAR, ToHexFunction)); + ScalarFunction({LogicalType::BIGNUM}, LogicalType::VARCHAR, ToHexFunction)); to_hex.AddFunction( ScalarFunction({LogicalType::BLOB}, LogicalType::VARCHAR, ToHexFunction)); to_hex.AddFunction( @@ -419,7 +419,7 @@ ScalarFunctionSet BinFun::GetFunctions() { to_binary.AddFunction( ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, ToBinaryFunction)); to_binary.AddFunction( - ScalarFunction({LogicalType::VARINT}, LogicalType::VARCHAR, ToBinaryFunction)); + ScalarFunction({LogicalType::BIGNUM}, LogicalType::VARCHAR, ToBinaryFunction)); to_binary.AddFunction(ScalarFunction({LogicalType::UBIGINT}, LogicalType::VARCHAR, ToBinaryFunction)); to_binary.AddFunction( diff --git a/src/duckdb/extension/core_functions/scalar/string/printf.cpp b/src/duckdb/extension/core_functions/scalar/string/printf.cpp index 1db25b0df..1ec8ae2cd 100644 --- a/src/duckdb/extension/core_functions/scalar/string/printf.cpp +++ b/src/duckdb/extension/core_functions/scalar/string/printf.cpp @@ -41,6 +41,12 @@ unique_ptr BindPrintfFunction(ClientContext &context, ScalarFuncti case LogicalTypeId::UBIGINT: bound_function.arguments.emplace_back(LogicalType::UBIGINT); break; + case LogicalTypeId::HUGEINT: + bound_function.arguments.emplace_back(LogicalType::HUGEINT); + break; + case LogicalTypeId::UHUGEINT: + bound_function.arguments.emplace_back(LogicalType::UHUGEINT); + break; case LogicalTypeId::FLOAT: case LogicalTypeId::DOUBLE: bound_function.arguments.emplace_back(LogicalType::DOUBLE); @@ -146,6 +152,16 @@ static void PrintfFunction(DataChunk &args, ExpressionState &state, Vector &resu format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); break; } + case LogicalTypeId::HUGEINT: { + auto arg_data = FlatVector::GetData(col); + format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); + break; + } + case LogicalTypeId::UHUGEINT: { + auto arg_data = FlatVector::GetData(col); + format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); + break; + } case LogicalTypeId::DOUBLE: { auto arg_data = FlatVector::GetData(col); format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); diff --git a/src/duckdb/extension/core_functions/scalar/string/repeat.cpp b/src/duckdb/extension/core_functions/scalar/string/repeat.cpp index 154634f94..2bfceae03 100644 --- a/src/duckdb/extension/core_functions/scalar/string/repeat.cpp +++ 
b/src/duckdb/extension/core_functions/scalar/string/repeat.cpp @@ -31,21 +31,6 @@ static void RepeatFunction(DataChunk &args, ExpressionState &, Vector &result) { }); } -unique_ptr RepeatBindFunction(ClientContext &, ScalarFunction &bound_function, - vector> &arguments) { - switch (arguments[0]->return_type.id()) { - case LogicalTypeId::UNKNOWN: - throw ParameterNotResolvedException(); - case LogicalTypeId::LIST: - break; - default: - throw NotImplementedException("repeat(list, count) requires a list as parameter"); - } - bound_function.arguments[0] = arguments[0]->return_type; - bound_function.return_type = arguments[0]->return_type; - return nullptr; -} - static void RepeatListFunction(DataChunk &args, ExpressionState &, Vector &result) { auto &list_vector = args.data[0]; auto &cnt_vector = args.data[1]; @@ -79,8 +64,8 @@ ScalarFunctionSet RepeatFun::GetFunctions() { for (const auto &type : {LogicalType::VARCHAR, LogicalType::BLOB}) { repeat.AddFunction(ScalarFunction({type, LogicalType::BIGINT}, type, RepeatFunction)); } - repeat.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::BIGINT}, - LogicalType::LIST(LogicalType::ANY), RepeatListFunction, RepeatBindFunction)); + repeat.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::BIGINT}, + LogicalType::LIST(LogicalType::TEMPLATE("T")), RepeatListFunction)); for (auto &func : repeat.functions) { BaseScalarFunction::SetReturnsError(func); } diff --git a/src/duckdb/extension/icu/icu-timezone.cpp b/src/duckdb/extension/icu/icu-timezone.cpp index 3daed28b0..86b8b6033 100644 --- a/src/duckdb/extension/icu/icu-timezone.cpp +++ b/src/duckdb/extension/icu/icu-timezone.cpp @@ -9,6 +9,7 @@ #include "include/icu-datefunc.hpp" #include "duckdb/transaction/meta_transaction.hpp" #include "duckdb/common/operator/cast_operators.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -163,7 +164,7 @@ struct ICUFromNaiveTimestamp : public ICUDateFunc { if (!input.context) { throw InternalException("Missing context for TIMESTAMP to TIMESTAMPTZ cast."); } - if (input.context->config.disable_timestamptz_casts) { + if (DBConfig::GetSetting(*input.context)) { throw BinderException("Casting from TIMESTAMP to TIMESTAMP WITH TIME ZONE without an explicit time zone " "has been disabled - use \"AT TIME ZONE ...\""); } @@ -250,7 +251,7 @@ struct ICUToNaiveTimestamp : public ICUDateFunc { if (!input.context) { throw InternalException("Missing context for TIMESTAMPTZ to TIMESTAMP cast."); } - if (input.context->config.disable_timestamptz_casts) { + if (DBConfig::GetSetting(*input.context)) { throw BinderException("Casting from TIMESTAMP WITH TIME ZONE to TIMESTAMP without an explicit time zone " "has been disabled - use \"AT TIME ZONE ...\""); } diff --git a/src/duckdb/extension/json/include/json_functions.hpp b/src/duckdb/extension/json/include/json_functions.hpp index 0b98ac30f..dafc2c635 100644 --- a/src/duckdb/extension/json/include/json_functions.hpp +++ b/src/duckdb/extension/json/include/json_functions.hpp @@ -75,9 +75,9 @@ class JSONFunctions { optional_ptr data); static TableFunction GetReadJSONTableFunction(shared_ptr function_info); static CopyFunction GetJSONCopyFunction(); - static void RegisterSimpleCastFunctions(CastFunctionSet &casts); - static void RegisterJSONCreateCastFunctions(CastFunctionSet &casts); - static void RegisterJSONTransformCastFunctions(CastFunctionSet &casts); + static void RegisterSimpleCastFunctions(ExtensionLoader &loader); + static void 
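// Editor's sketch: the icu-timezone hunks above replace a direct config-field read
// (context->config.disable_timestamptz_casts) with a typed setting lookup. A toy model
// of that pattern (hypothetical types, not the real DBConfig::GetSetting API):
#include <cstdio>
#include <string>
#include <unordered_map>

struct DisableTimestamptzCastsSetting {
	static constexpr const char *Name = "disable_timestamptz_casts";
};

struct MiniConfig {
	std::unordered_map<std::string, bool> bool_settings;

	// The tag struct carries the setting's key, so call sites cannot typo it
	template <class SETTING>
	bool GetSetting() const {
		auto it = bool_settings.find(SETTING::Name);
		return it != bool_settings.end() && it->second;
	}
};

int main() {
	MiniConfig config;
	config.bool_settings["disable_timestamptz_casts"] = true;
	if (config.GetSetting<DisableTimestamptzCastsSetting>()) {
		printf("cast disabled - use \"AT TIME ZONE ...\"\n");
	}
	return 0;
}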
RegisterJSONCreateCastFunctions(ExtensionLoader &loader); + static void RegisterJSONTransformCastFunctions(ExtensionLoader &loader); private: // Scalar functions diff --git a/src/duckdb/extension/json/include/json_reader.hpp b/src/duckdb/extension/json/include/json_reader.hpp index 9435ec5ca..de75af996 100644 --- a/src/duckdb/extension/json/include/json_reader.hpp +++ b/src/duckdb/extension/json/include/json_reader.hpp @@ -46,7 +46,7 @@ struct JSONBufferHandle { struct JSONFileHandle { public: - JSONFileHandle(unique_ptr file_handle, Allocator &allocator); + JSONFileHandle(QueryContext context, unique_ptr file_handle, Allocator &allocator); bool IsOpen() const; void Close(); @@ -74,6 +74,8 @@ struct JSONFileHandle { idx_t ReadFromCache(char *&pointer, idx_t &size, atomic &position); private: + QueryContext context; + //! The JSON file handle unique_ptr file_handle; Allocator &allocator; diff --git a/src/duckdb/extension/json/json_extension.cpp b/src/duckdb/extension/json/json_extension.cpp index d2dcdbffa..e4ca49e13 100644 --- a/src/duckdb/extension/json/json_extension.cpp +++ b/src/duckdb/extension/json/json_extension.cpp @@ -1,22 +1,17 @@ #include "json_extension.hpp" -#include "include/json_extension.hpp" + +#include "json_common.hpp" +#include "json_functions.hpp" #include "duckdb/catalog/catalog_entry/macro_catalog_entry.hpp" #include "duckdb/catalog/default/default_functions.hpp" -#include "duckdb/common/string_util.hpp" #include "duckdb/function/copy_function.hpp" -#include "duckdb/parser/expression/constant_expression.hpp" -#include "duckdb/parser/expression/function_expression.hpp" -#include "duckdb/parser/parsed_data/create_pragma_function_info.hpp" -#include "duckdb/parser/parsed_data/create_type_info.hpp" -#include "duckdb/parser/tableref/table_function_ref.hpp" -#include "json_common.hpp" -#include "json_functions.hpp" #include "duckdb/main/extension/extension_loader.hpp" +#include "duckdb/parser/expression/function_expression.hpp" namespace duckdb { -static DefaultMacro json_macros[] = { +static const DefaultMacro JSON_MACROS[] = { {DEFAULT_SCHEMA, "json_group_array", {"x", nullptr}, @@ -37,18 +32,14 @@ static DefaultMacro json_macros[] = { {nullptr, nullptr, {nullptr}, {{nullptr, nullptr}}, nullptr}}; static void LoadInternal(ExtensionLoader &loader) { - // auto &db_instance = *db.instance; - // JSON type auto json_type = LogicalType::JSON(); loader.RegisterType(LogicalType::JSON_TYPE_NAME, std::move(json_type)); // JSON casts - // TODO: Register these properly using the extension loader - auto &db_instance = loader.GetDatabaseInstance(); - JSONFunctions::RegisterSimpleCastFunctions(DBConfig::GetConfig(db_instance).GetCastFunctions()); - JSONFunctions::RegisterJSONCreateCastFunctions(DBConfig::GetConfig(db_instance).GetCastFunctions()); - JSONFunctions::RegisterJSONTransformCastFunctions(DBConfig::GetConfig(db_instance).GetCastFunctions()); + JSONFunctions::RegisterSimpleCastFunctions(loader); + JSONFunctions::RegisterJSONCreateCastFunctions(loader); + JSONFunctions::RegisterJSONTransformCastFunctions(loader); // JSON scalar functions for (auto &fun : JSONFunctions::GetScalarFunctions()) { @@ -66,8 +57,8 @@ static void LoadInternal(ExtensionLoader &loader) { } // JSON replacement scan - auto &config = DBConfig::GetConfig(db_instance); - config.replacement_scans.emplace_back(JSONFunctions::ReadJSONReplacement); + DBConfig::GetConfig(loader.GetDatabaseInstance()) + .replacement_scans.emplace_back(JSONFunctions::ReadJSONReplacement); // JSON copy function auto copy_fun = 
JSONFunctions::GetJSONCopyFunction(); @@ -80,8 +71,8 @@ static void LoadInternal(ExtensionLoader &loader) { loader.RegisterFunction(copy_fun); // JSON macro's - for (idx_t index = 0; json_macros[index].name != nullptr; index++) { - auto info = DefaultFunctionGenerator::CreateInternalMacroInfo(json_macros[index]); + for (idx_t index = 0; JSON_MACROS[index].name != nullptr; index++) { + auto info = DefaultFunctionGenerator::CreateInternalMacroInfo(JSON_MACROS[index]); loader.RegisterFunction(*info); } } diff --git a/src/duckdb/extension/json/json_functions.cpp b/src/duckdb/extension/json/json_functions.cpp index 0bed1f802..2d09828c3 100644 --- a/src/duckdb/extension/json/json_functions.cpp +++ b/src/duckdb/extension/json/json_functions.cpp @@ -259,20 +259,147 @@ static bool CastVarcharToJSON(Vector &source, Vector &result, idx_t count, CastP return success; } -void JSONFunctions::RegisterSimpleCastFunctions(CastFunctionSet &casts) { +static bool CastJSONListToVarchar(Vector &source, Vector &result, idx_t count, CastParameters &) { + UnifiedVectorFormat child_format; + ListVector::GetEntry(source).ToUnifiedFormat(ListVector::GetListSize(source), child_format); + const auto input_jsons = UnifiedVectorFormat::GetData(child_format); + + static constexpr char const *NULL_STRING = "NULL"; + static constexpr idx_t NULL_STRING_LENGTH = 4; + + UnaryExecutor::Execute( + source, result, count, + [&](const list_entry_t &input) { + // Compute len (start with [] and ,) + idx_t len = 2; + len += input.length == 0 ? 0 : (input.length - 1) * 2; + for (idx_t json_idx = input.offset; json_idx < input.offset + input.length; json_idx++) { + const auto sel_json_idx = child_format.sel->get_index(json_idx); + if (child_format.validity.RowIsValid(sel_json_idx)) { + len += input_jsons[sel_json_idx].GetSize(); + } else { + len += NULL_STRING_LENGTH; + } + } + + // Allocate string + auto res = StringVector::EmptyString(result, len); + auto ptr = res.GetDataWriteable(); + + // Populate string + *ptr++ = '['; + for (idx_t json_idx = input.offset; json_idx < input.offset + input.length; json_idx++) { + const auto sel_json_idx = child_format.sel->get_index(json_idx); + if (child_format.validity.RowIsValid(sel_json_idx)) { + auto &input_json = input_jsons[sel_json_idx]; + memcpy(ptr, input_json.GetData(), input_json.GetSize()); + ptr += input_json.GetSize(); + } else { + memcpy(ptr, NULL_STRING, NULL_STRING_LENGTH); + ptr += NULL_STRING_LENGTH; + } + if (json_idx != input.offset + input.length - 1) { + *ptr++ = ','; + *ptr++ = ' '; + } + } + *ptr = ']'; + + res.Finalize(); + return res; + }, + FunctionErrors::CANNOT_ERROR); + return true; +} + +static bool CastVarcharToJSONList(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { + auto &lstate = parameters.local_state->Cast(); + lstate.json_allocator->Reset(); + auto alc = lstate.json_allocator->GetYYAlc(); + + bool success = true; + UnaryExecutor::ExecuteWithNulls( + source, result, count, [&](const string_t &input, ValidityMask &mask, idx_t idx) -> list_entry_t { + // Figure out if the cast can succeed + yyjson_read_err error; + const auto doc = JSONCommon::ReadDocumentUnsafe(input.GetDataWriteable(), input.GetSize(), + JSONCommon::READ_FLAG, alc, &error); + if (!doc || !unsafe_yyjson_is_arr(doc->root)) { + mask.SetInvalid(idx); + if (success) { + if (!doc) { + HandleCastError::AssignError( + JSONCommon::FormatParseError(input.GetDataWriteable(), input.GetSize(), error), parameters); + } else if (!unsafe_yyjson_is_arr(doc->root)) { + auto 
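// Editor's sketch: CastJSONListToVarchar above sizes its output exactly before writing:
// one pass computes the total length, a second pass fills a single allocation, and NULL
// children render as a bare NULL token. Standalone model of that two-pass join:
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

static std::string JoinJsonList(const std::vector<const char *> &items) {
	static const char *NULL_STRING = "NULL";
	size_t len = 2; // '[' and ']'
	if (!items.empty()) {
		len += (items.size() - 1) * 2; // ", " separators
	}
	for (const char *item : items) {
		len += strlen(item ? item : NULL_STRING);
	}
	std::string out;
	out.reserve(len); // stands in for StringVector::EmptyString(result, len)
	out += '[';
	for (size_t i = 0; i < items.size(); i++) {
		out += items[i] ? items[i] : NULL_STRING;
		if (i + 1 != items.size()) {
			out += ", ";
		}
	}
	out += ']';
	return out;
}

int main() {
	std::vector<const char *> items = {"{\"a\": 1}", nullptr, "[2, 3]"};
	printf("%s\n", JoinJsonList(items).c_str()); // [{"a": 1}, NULL, [2, 3]]
	return 0;
}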
truncated_input = + input.GetSize() > 50 ? string(input.GetData(), 47) + "..." : input.GetString(); + HandleCastError::AssignError( + StringUtil::Format("Cannot cast to list of JSON. Input \"%s\"", truncated_input), + parameters); + } + success = false; + } + return {}; + } + + auto current_size = ListVector::GetListSize(result); + const auto arr_len = unsafe_yyjson_get_len(doc->root); + const auto new_size = current_size + arr_len; + + // Grow list if needed + if (ListVector::GetListCapacity(result) < new_size) { + ListVector::Reserve(result, new_size); + } + + // Populate list + const auto result_jsons = FlatVector::GetData(ListVector::GetEntry(result)); + size_t arr_idx, max; + yyjson_val *val; + yyjson_arr_foreach(doc->root, arr_idx, max, val) { + result_jsons[current_size + arr_idx] = JSONCommon::WriteVal(val, alc); + } + + // Update size + ListVector::SetListSize(result, current_size + arr_len); + + return {current_size, arr_len}; + }); + + JSONAllocator::AddBuffer(ListVector::GetEntry(result), alc); + return success; +} + +void JSONFunctions::RegisterSimpleCastFunctions(ExtensionLoader &loader) { + auto &db = loader.GetDatabaseInstance(); + // JSON to VARCHAR is basically free - casts.RegisterCastFunction(LogicalType::JSON(), LogicalType::VARCHAR, DefaultCasts::ReinterpretCast, 1); + loader.RegisterCastFunction(LogicalType::JSON(), LogicalType::VARCHAR, DefaultCasts::ReinterpretCast, 1); // VARCHAR to JSON requires a parse so it's not free. Let's make it 1 more than a cast to STRUCT - auto varchar_to_json_cost = casts.ImplicitCastCost(LogicalType::SQLNULL, LogicalTypeId::STRUCT) + 1; + const auto varchar_to_json_cost = + CastFunctionSet::ImplicitCastCost(db, LogicalType::SQLNULL, LogicalTypeId::STRUCT) + 1; BoundCastInfo varchar_to_json_info(CastVarcharToJSON, nullptr, JSONFunctionLocalState::InitCastLocalState); - casts.RegisterCastFunction(LogicalType::VARCHAR, LogicalType::JSON(), std::move(varchar_to_json_info), - varchar_to_json_cost); + loader.RegisterCastFunction(LogicalType::VARCHAR, LogicalType::JSON(), std::move(varchar_to_json_info), + varchar_to_json_cost); // Register NULL to JSON with a different cost than NULL to VARCHAR so the binder can disambiguate functions - auto null_to_json_cost = casts.ImplicitCastCost(LogicalType::SQLNULL, LogicalTypeId::VARCHAR) + 1; - casts.RegisterCastFunction(LogicalType::SQLNULL, LogicalType::JSON(), DefaultCasts::TryVectorNullCast, - null_to_json_cost); + const auto null_to_json_cost = + CastFunctionSet::ImplicitCastCost(db, LogicalType::SQLNULL, LogicalTypeId::VARCHAR) + 1; + loader.RegisterCastFunction(LogicalType::SQLNULL, LogicalType::JSON(), DefaultCasts::TryVectorNullCast, + null_to_json_cost); + + // JSON[] to VARCHAR (this needs a special case otherwise the cast will escape quotes) + const auto json_list_to_varchar_cost = + CastFunctionSet::ImplicitCastCost(db, LogicalType::LIST(LogicalType::JSON()), LogicalTypeId::VARCHAR) - 1; + loader.RegisterCastFunction(LogicalType::LIST(LogicalType::JSON()), LogicalTypeId::VARCHAR, CastJSONListToVarchar, + json_list_to_varchar_cost); + + // VARCHAR to JSON[] (also needs a special case otherwise get a VARCHAR -> VARCHAR[] cast first) + const auto varchar_to_json_list_cost = + CastFunctionSet::ImplicitCastCost(db, LogicalType::VARCHAR, LogicalType::LIST(LogicalType::JSON())) - 1; + BoundCastInfo varchar_to_json_list_info(CastVarcharToJSONList, nullptr, JSONFunctionLocalState::InitCastLocalState); + loader.RegisterCastFunction(LogicalType::VARCHAR, LogicalType::LIST(LogicalType::JSON()), + 
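// Editor's sketch: the failure path above trims long inputs before embedding them in
// the cast error message. The 50-byte cutoff, modeled standalone:
#include <cstdio>
#include <string>

static std::string TruncateForError(const std::string &input) {
	return input.size() > 50 ? input.substr(0, 47) + "..." : input;
}

int main() {
	const std::string small = "[1, 2]";
	const std::string big(60, 'x');
	printf("%s\n", TruncateForError(small).c_str()); // short inputs pass through
	printf("%zu\n", TruncateForError(big).size());   // 50 (47 chars plus "...")
	return 0;
}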
std::move(varchar_to_json_list_info), varchar_to_json_list_cost); } } // namespace duckdb diff --git a/src/duckdb/extension/json/json_functions/json_create.cpp b/src/duckdb/extension/json/json_functions/json_create.cpp index 560265a08..1271a09fb 100644 --- a/src/duckdb/extension/json/json_functions/json_create.cpp +++ b/src/duckdb/extension/json/json_functions/json_create.cpp @@ -67,7 +67,7 @@ static LogicalType GetJSONType(StructNames &const_struct_names, const LogicalTyp case LogicalTypeId::TIMESTAMP_MS: case LogicalTypeId::TIMESTAMP_SEC: case LogicalTypeId::UUID: - case LogicalTypeId::VARINT: + case LogicalTypeId::BIGNUM: case LogicalTypeId::DECIMAL: return type; case LogicalTypeId::LIST: @@ -585,7 +585,7 @@ static void CreateValues(const StructNames &names, yyjson_mut_doc *doc, yyjson_m TemplatedCreateValues(doc, vals, string_vector, count); break; } - case LogicalTypeId::VARINT: { + case LogicalTypeId::BIGNUM: { Vector string_vector(LogicalTypeId::VARCHAR, count); VectorOperations::DefaultCast(value_v, string_vector, count); CreateRawValues(doc, vals, string_vector, count); @@ -607,6 +607,7 @@ static void CreateValues(const StructNames &names, yyjson_mut_doc *doc, yyjson_m case LogicalTypeId::UNKNOWN: case LogicalTypeId::ANY: case LogicalTypeId::USER: + case LogicalTypeId::TEMPLATE: case LogicalTypeId::CHAR: case LogicalTypeId::STRING_LITERAL: case LogicalTypeId::INTEGER_LITERAL: @@ -792,7 +793,7 @@ BoundCastInfo AnyToJSONCastBind(BindCastInput &input, const LogicalType &source, return BoundCastInfo(AnyToJSONCast, std::move(cast_data), JSONFunctionLocalState::InitCastLocalState); } -void JSONFunctions::RegisterJSONCreateCastFunctions(CastFunctionSet &casts) { +void JSONFunctions::RegisterJSONCreateCastFunctions(ExtensionLoader &loader) { // Anything can be cast to JSON for (const auto &type : LogicalType::AllTypes()) { LogicalType source_type; @@ -819,9 +820,9 @@ void JSONFunctions::RegisterJSONCreateCastFunctions(CastFunctionSet &casts) { source_type = type; } // We prefer going to JSON over going to VARCHAR if a function can do either - const auto source_to_json_cost = - MaxValue(casts.ImplicitCastCost(source_type, LogicalType::VARCHAR) - 1, 0); - casts.RegisterCastFunction(source_type, LogicalType::JSON(), AnyToJSONCastBind, source_to_json_cost); + const auto source_to_json_cost = MaxValue( + CastFunctionSet::ImplicitCastCost(loader.GetDatabaseInstance(), source_type, LogicalType::VARCHAR) - 1, 0); + loader.RegisterCastFunction(source_type, LogicalType::JSON(), AnyToJSONCastBind, source_to_json_cost); } } diff --git a/src/duckdb/extension/json/json_functions/json_transform.cpp b/src/duckdb/extension/json/json_functions/json_transform.cpp index e40d226e1..1da02936f 100644 --- a/src/duckdb/extension/json/json_functions/json_transform.cpp +++ b/src/duckdb/extension/json/json_functions/json_transform.cpp @@ -1028,7 +1028,7 @@ BoundCastInfo JSONToAnyCastBind(BindCastInput &input, const LogicalType &source, return BoundCastInfo(JSONToAnyCast, nullptr, JSONFunctionLocalState::InitCastLocalState); } -void JSONFunctions::RegisterJSONTransformCastFunctions(CastFunctionSet &casts) { +void JSONFunctions::RegisterJSONTransformCastFunctions(ExtensionLoader &loader) { // JSON can be cast to anything for (const auto &type : LogicalType::AllTypes()) { LogicalType target_type; @@ -1055,8 +1055,9 @@ void JSONFunctions::RegisterJSONTransformCastFunctions(CastFunctionSet &casts) { target_type = type; } // Going from JSON to another type has the same cost as going from VARCHAR to that type - const auto 
json_to_target_cost = casts.ImplicitCastCost(LogicalType::VARCHAR, target_type); - casts.RegisterCastFunction(LogicalType::JSON(), target_type, JSONToAnyCastBind, json_to_target_cost); + const auto json_to_target_cost = + CastFunctionSet::ImplicitCastCost(loader.GetDatabaseInstance(), LogicalType::VARCHAR, target_type); + loader.RegisterCastFunction(LogicalType::JSON(), target_type, JSONToAnyCastBind, json_to_target_cost); } } diff --git a/src/duckdb/extension/json/json_functions/read_json.cpp b/src/duckdb/extension/json/json_functions/read_json.cpp index ed41d3a81..f5e9bd405 100644 --- a/src/duckdb/extension/json/json_functions/read_json.cpp +++ b/src/duckdb/extension/json/json_functions/read_json.cpp @@ -162,15 +162,16 @@ void JSONScan::AutoDetect(ClientContext &context, MultiFileBindData &bind_data, AutoDetectState auto_detect_state(context, bind_data, files, date_format_map); const auto num_threads = NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads()); const auto files_per_task = (file_count + num_threads - 1) / num_threads; - const auto num_tasks = file_count / files_per_task; + const auto num_tasks = (file_count + files_per_task - 1) / files_per_task; vector task_nodes(num_tasks); // Same idea as in union_by_name.hpp TaskExecutor executor(context); for (idx_t task_idx = 0; task_idx < num_tasks; task_idx++) { const auto file_idx_start = task_idx * files_per_task; - auto task = make_uniq(executor, auto_detect_state, task_nodes[task_idx], file_idx_start, - file_idx_start + files_per_task); + const auto file_idx_end = MinValue(file_idx_start + files_per_task, file_count); + auto task = + make_uniq(executor, auto_detect_state, task_nodes[task_idx], file_idx_start, file_idx_end); executor.ScheduleTask(std::move(task)); } executor.WorkOnTasks(); diff --git a/src/duckdb/extension/json/json_reader.cpp b/src/duckdb/extension/json/json_reader.cpp index 1eb11dbc1..b52026a4e 100644 --- a/src/duckdb/extension/json/json_reader.cpp +++ b/src/duckdb/extension/json/json_reader.cpp @@ -14,10 +14,10 @@ JSONBufferHandle::JSONBufferHandle(JSONReader &reader, idx_t buffer_index_p, idx buffer_size(buffer_size_p), buffer_start(buffer_start_p) { } -JSONFileHandle::JSONFileHandle(unique_ptr file_handle_p, Allocator &allocator_p) - : file_handle(std::move(file_handle_p)), allocator(allocator_p), can_seek(file_handle->CanSeek()), - file_size(file_handle->GetFileSize()), read_position(0), requested_reads(0), actual_reads(0), - last_read_requested(false), cached_size(0) { +JSONFileHandle::JSONFileHandle(QueryContext context_p, unique_ptr file_handle_p, Allocator &allocator_p) + : context(context_p), file_handle(std::move(file_handle_p)), allocator(allocator_p), + can_seek(file_handle->CanSeek()), file_size(file_handle->GetFileSize()), read_position(0), requested_reads(0), + actual_reads(0), last_read_requested(false), cached_size(0) { } bool JSONFileHandle::IsOpen() const { @@ -95,7 +95,7 @@ void JSONFileHandle::ReadAtPosition(char *pointer, idx_t size, idx_t position, } if (size != 0) { auto &handle = override_handle ? 
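// Editor's note on the read_json.cpp hunk above: with floor division, 10 files across
// 4 threads produced files_per_task = 3 but only 3 tasks, silently skipping the last
// file. Ceiling division plus a clamped end index covers every file:
#include <algorithm>
#include <cstdio>

int main() {
	const size_t file_count = 10, num_threads = 4;
	const size_t files_per_task = (file_count + num_threads - 1) / num_threads;  // 3
	const size_t old_num_tasks = file_count / files_per_task;                    // 3 -> 9 files covered
	const size_t num_tasks = (file_count + files_per_task - 1) / files_per_task; // 4 -> all 10 covered
	printf("old=%zu new=%zu\n", old_num_tasks, num_tasks);
	for (size_t task_idx = 0; task_idx < num_tasks; task_idx++) {
		const size_t file_idx_start = task_idx * files_per_task;
		const size_t file_idx_end = std::min(file_idx_start + files_per_task, file_count);
		printf("task %zu: files [%zu, %zu)\n", task_idx, file_idx_start, file_idx_end);
	}
	return 0;
}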
*override_handle.get() : *file_handle.get(); - handle.Read(pointer, size, position); + handle.Read(context, pointer, size, position); } const auto incremented_actual_reads = ++actual_reads; @@ -184,7 +184,8 @@ void JSONReader::OpenJSONFile() { if (!IsOpen()) { auto &fs = FileSystem::GetFileSystem(context); auto regular_file_handle = fs.OpenFile(file, FileFlags::FILE_FLAGS_READ | options.compression); - file_handle = make_uniq(std::move(regular_file_handle), BufferAllocator::Get(context)); + file_handle = make_uniq(QueryContext(context), std::move(regular_file_handle), + BufferAllocator::Get(context)); } Reset(); } diff --git a/src/duckdb/extension/parquet/include/column_reader.hpp b/src/duckdb/extension/parquet/include/column_reader.hpp index ed223c4ba..79259875b 100644 --- a/src/duckdb/extension/parquet/include/column_reader.hpp +++ b/src/duckdb/extension/parquet/include/column_reader.hpp @@ -21,13 +21,11 @@ #include "decoder/delta_length_byte_array_decoder.hpp" #include "decoder/delta_byte_array_decoder.hpp" #include "parquet_column_schema.hpp" -#ifndef DUCKDB_AMALGAMATION #include "duckdb/common/operator/cast_operators.hpp" #include "duckdb/common/types/string_type.hpp" #include "duckdb/common/types/vector.hpp" #include "duckdb/common/types/vector_cache.hpp" -#endif namespace duckdb { class ParquetReader; diff --git a/src/duckdb/extension/parquet/include/parquet_crypto.hpp b/src/duckdb/extension/parquet/include/parquet_crypto.hpp index bf6848065..7261a3bc8 100644 --- a/src/duckdb/extension/parquet/include/parquet_crypto.hpp +++ b/src/duckdb/extension/parquet/include/parquet_crypto.hpp @@ -10,10 +10,7 @@ #include "parquet_types.h" #include "duckdb/common/encryption_state.hpp" - -#ifndef DUCKDB_AMALGAMATION #include "duckdb/storage/object_cache.hpp" -#endif namespace duckdb { diff --git a/src/duckdb/extension/parquet/include/parquet_statistics.hpp b/src/duckdb/extension/parquet/include/parquet_statistics.hpp index fc53fa328..cb05dae3b 100644 --- a/src/duckdb/extension/parquet/include/parquet_statistics.hpp +++ b/src/duckdb/extension/parquet/include/parquet_statistics.hpp @@ -9,9 +9,7 @@ #pragma once #include "duckdb.hpp" -#ifndef DUCKDB_AMALGAMATION #include "duckdb/storage/statistics/base_statistics.hpp" -#endif #include "parquet_types.h" #include "resizable_buffer.hpp" diff --git a/src/duckdb/extension/parquet/include/parquet_writer.hpp b/src/duckdb/extension/parquet/include/parquet_writer.hpp index b5292aad1..a2bfc3a80 100644 --- a/src/duckdb/extension/parquet/include/parquet_writer.hpp +++ b/src/duckdb/extension/parquet/include/parquet_writer.hpp @@ -10,6 +10,7 @@ #include "duckdb.hpp" #include "duckdb/common/common.hpp" +#include "duckdb/common/optional_idx.hpp" #include "duckdb/common/encryption_state.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/mutex.hpp" @@ -81,7 +82,7 @@ class ParquetWriter { ParquetWriter(ClientContext &context, FileSystem &fs, string file_name, vector types, vector names, duckdb_parquet::CompressionCodec::type codec, ChildFieldIDs field_ids, const vector> &kv_metadata, - shared_ptr encryption_config, idx_t dictionary_size_limit, + shared_ptr encryption_config, optional_idx dictionary_size_limit, idx_t string_dictionary_page_size_limit, bool enable_bloom_filters, double bloom_filter_false_positive_ratio, int64_t compression_level, bool debug_use_openssl, ParquetVersion parquet_version); @@ -117,7 +118,7 @@ class ParquetWriter { idx_t FileSize() { return total_written; } - idx_t DictionarySizeLimit() const { + optional_idx 
DictionarySizeLimit() const { return dictionary_size_limit; } idx_t StringDictionaryPageSizeLimit() const { @@ -166,7 +167,7 @@ class ParquetWriter { duckdb_parquet::CompressionCodec::type codec; ChildFieldIDs field_ids; shared_ptr encryption_config; - idx_t dictionary_size_limit; + optional_idx dictionary_size_limit; idx_t string_dictionary_page_size_limit; bool enable_bloom_filters; double bloom_filter_false_positive_ratio; diff --git a/src/duckdb/extension/parquet/include/reader/row_number_column_reader.hpp b/src/duckdb/extension/parquet/include/reader/row_number_column_reader.hpp index 3057eaf39..fcbe4bdf4 100644 --- a/src/duckdb/extension/parquet/include/reader/row_number_column_reader.hpp +++ b/src/duckdb/extension/parquet/include/reader/row_number_column_reader.hpp @@ -8,9 +8,7 @@ #pragma once -#ifndef DUCKDB_AMALGAMATION #include "duckdb/common/limits.hpp" -#endif #include "column_reader.hpp" #include "reader/templated_column_reader.hpp" diff --git a/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp b/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp index 796d7696c..a7c717709 100644 --- a/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp +++ b/src/duckdb/extension/parquet/include/reader/variant/variant_binary_decoder.hpp @@ -1,7 +1,8 @@ #pragma once #include "duckdb/common/types/string_type.hpp" -#include "yyjson.hpp" +#include "duckdb/common/types/value.hpp" +#include "reader/variant/variant_value.hpp" using namespace duckdb_yyjson; @@ -130,23 +131,20 @@ struct VariantDecodeResult { class VariantBinaryDecoder { public: - explicit VariantBinaryDecoder(ClientContext &context); + VariantBinaryDecoder() = delete; public: - yyjson_mut_val *Decode(yyjson_mut_doc *doc, const VariantMetadata &metadata, const_data_ptr_t data); + static VariantValue Decode(const VariantMetadata &metadata, const_data_ptr_t data); public: - yyjson_mut_val *PrimitiveTypeDecode(yyjson_mut_doc *doc, const VariantMetadata &metadata, - const VariantValueMetadata &value_metadata, const_data_ptr_t data); - yyjson_mut_val *ShortStringDecode(yyjson_mut_doc *doc, const VariantMetadata &metadata, - const VariantValueMetadata &value_metadata, const_data_ptr_t data); - yyjson_mut_val *ObjectDecode(yyjson_mut_doc *doc, const VariantMetadata &metadata, - const VariantValueMetadata &value_metadata, const_data_ptr_t data); - yyjson_mut_val *ArrayDecode(yyjson_mut_doc *doc, const VariantMetadata &metadata, - const VariantValueMetadata &value_metadata, const_data_ptr_t data); - -public: - ClientContext &context; + static VariantValue PrimitiveTypeDecode(const VariantMetadata &metadata, const VariantValueMetadata &value_metadata, + const_data_ptr_t data); + static VariantValue ShortStringDecode(const VariantMetadata &metadata, const VariantValueMetadata &value_metadata, + const_data_ptr_t data); + static VariantValue ObjectDecode(const VariantMetadata &metadata, const VariantValueMetadata &value_metadata, + const_data_ptr_t data); + static VariantValue ArrayDecode(const VariantMetadata &metadata, const VariantValueMetadata &value_metadata, + const_data_ptr_t data); }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp b/src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp new file mode 100644 index 000000000..27ece7d70 --- /dev/null +++ b/src/duckdb/extension/parquet/include/reader/variant/variant_shredded_conversion.hpp @@ -0,0 +1,23 @@ 
+#pragma once + +#include "reader/variant/variant_value.hpp" +#include "reader/variant/variant_binary_decoder.hpp" + +namespace duckdb { + +class VariantShreddedConversion { +public: + VariantShreddedConversion() = delete; + +public: + static vector Convert(Vector &metadata, Vector &group, idx_t offset, idx_t length, idx_t total_size, + bool is_field = false); + static vector ConvertShreddedLeaf(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, + idx_t length, idx_t total_size); + static vector ConvertShreddedArray(Vector &metadata, Vector &value, Vector &typed_value, idx_t offset, + idx_t length, idx_t total_size); + static vector ConvertShreddedObject(Vector &metadata, Vector &value, Vector &typed_value, + idx_t offset, idx_t length, idx_t total_size); +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/reader/variant/variant_value.hpp b/src/duckdb/extension/parquet/include/reader/variant/variant_value.hpp new file mode 100644 index 000000000..a4c38ede7 --- /dev/null +++ b/src/duckdb/extension/parquet/include/reader/variant/variant_value.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include "duckdb/common/map.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/types/value.hpp" + +#include "yyjson.hpp" + +using namespace duckdb_yyjson; + +namespace duckdb { + +enum class VariantValueType : uint8_t { PRIMITIVE, OBJECT, ARRAY, MISSING }; + +struct VariantValue { +public: + VariantValue() : value_type(VariantValueType::MISSING) { + } + explicit VariantValue(VariantValueType type) : value_type(type) { + } + explicit VariantValue(Value &&val) : value_type(VariantValueType::PRIMITIVE), primitive_value(std::move(val)) { + } + // Delete copy constructor and copy assignment operator + VariantValue(const VariantValue &) = delete; + VariantValue &operator=(const VariantValue &) = delete; + + // Default move constructor and move assignment operator + VariantValue(VariantValue &&) noexcept = default; + VariantValue &operator=(VariantValue &&) noexcept = default; + +public: + bool IsNull() const { + return value_type == VariantValueType::PRIMITIVE && primitive_value.IsNull(); + } + bool IsMissing() const { + return value_type == VariantValueType::MISSING; + } + +public: + void AddChild(const string &key, VariantValue &&val); + void AddItem(VariantValue &&val); + +public: + yyjson_mut_val *ToJSON(ClientContext &context, yyjson_mut_doc *doc) const; + +public: + VariantValueType value_type; + //! FIXME: how can we get a deterministic child order for a partially shredded object? 
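// Editor's sketch: VariantValue (added above) is deliberately move-only — decoded
// variants form a tree, and deleting the copy operations guarantees children are
// transferred rather than deep-copied; storing object children in a map also yields
// key-sorted, deterministic iteration. A condensed model of that ownership pattern
// (illustrative, not the full class):
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

enum class NodeType : uint8_t { PRIMITIVE, OBJECT, ARRAY, MISSING };

struct Node {
	Node() : type(NodeType::MISSING) {
	}
	explicit Node(std::string val) : type(NodeType::PRIMITIVE), primitive(std::move(val)) {
	}
	Node(const Node &) = delete; // no accidental deep copies
	Node &operator=(const Node &) = delete;
	Node(Node &&) noexcept = default; // moves are cheap and allowed
	Node &operator=(Node &&) noexcept = default;

	void AddChild(const std::string &key, Node &&val) {
		type = NodeType::OBJECT;
		children.emplace(key, std::move(val)); // std::map keeps keys ordered
	}
	void AddItem(Node &&val) {
		type = NodeType::ARRAY;
		items.push_back(std::move(val));
	}

	NodeType type;
	std::map<std::string, Node> children;
	std::vector<Node> items;
	std::string primitive;
};

int main() {
	Node arr;
	arr.AddItem(Node("1"));
	Node obj;
	obj.AddChild("a", std::move(arr)); // contents are moved, never copied
	return 0;
}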
+ map object_children; + vector array_items; + Value primitive_value; +}; + +} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/reader/variant_column_reader.hpp b/src/duckdb/extension/parquet/include/reader/variant_column_reader.hpp index 107a84b41..78670b14a 100644 --- a/src/duckdb/extension/parquet/include/reader/variant_column_reader.hpp +++ b/src/duckdb/extension/parquet/include/reader/variant_column_reader.hpp @@ -35,6 +35,10 @@ class VariantColumnReader : public ColumnReader { idx_t GroupRowsAvailable() override; uint64_t TotalCompressedSize() override; void RegisterPrefetch(ThriftFileTransport &transport, bool allow_merge) override; + +protected: + idx_t metadata_reader_idx; + idx_t value_reader_idx; }; } // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/resizable_buffer.hpp b/src/duckdb/extension/parquet/include/resizable_buffer.hpp index be00a0a7a..7452ddd8f 100644 --- a/src/duckdb/extension/parquet/include/resizable_buffer.hpp +++ b/src/duckdb/extension/parquet/include/resizable_buffer.hpp @@ -9,9 +9,7 @@ #pragma once #include "duckdb.hpp" -#ifndef DUCKDB_AMALGAMATION #include "duckdb/common/allocator.hpp" -#endif #include diff --git a/src/duckdb/extension/parquet/include/thrift_tools.hpp b/src/duckdb/extension/parquet/include/thrift_tools.hpp index ce3170345..d58422991 100644 --- a/src/duckdb/extension/parquet/include/thrift_tools.hpp +++ b/src/duckdb/extension/parquet/include/thrift_tools.hpp @@ -13,11 +13,9 @@ #include "thrift/transport/TBufferTransports.h" #include "duckdb.hpp" -#ifndef DUCKDB_AMALGAMATION #include "duckdb/storage/caching_file_system.hpp" #include "duckdb/common/file_system.hpp" #include "duckdb/common/allocator.hpp" -#endif namespace duckdb { @@ -154,7 +152,7 @@ class ThriftFileTransport : public duckdb_apache::thrift::transport::TVirtualTra memcpy(buf, prefetch_buffer_fallback->buffer_ptr + location - prefetch_buffer_fallback->location, len); } else { // No prefetch, do a regular (non-caching) read - file_handle.GetFileHandle().Read(buf, len, location); + file_handle.GetFileHandle().Read(context, buf, len, location); } location += len; @@ -213,6 +211,8 @@ class ThriftFileTransport : public duckdb_apache::thrift::transport::TVirtualTra } private: + QueryContext context; + CachingFileHandle &file_handle; idx_t location; idx_t size; diff --git a/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp b/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp index 392c2d815..c035bba43 100644 --- a/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp +++ b/src/duckdb/extension/parquet/include/writer/templated_column_writer.hpp @@ -20,8 +20,9 @@ namespace duckdb { template static void TemplatedWritePlain(Vector &col, ColumnWriterStatistics *stats, const idx_t chunk_start, const idx_t chunk_end, const ValidityMask &mask, WriteStream &ser) { - static constexpr bool COPY_DIRECTLY_FROM_VECTOR = - ALL_VALID && std::is_same::value && std::is_arithmetic::value; + static constexpr bool COPY_DIRECTLY_FROM_VECTOR = ALL_VALID && std::is_same::value && + std::is_arithmetic::value && + std::is_same::value; const auto *const ptr = FlatVector::GetData(col); @@ -67,7 +68,9 @@ class StandardColumnWriterState : public PrimitiveColumnWriterState { public: StandardColumnWriterState(ParquetWriter &writer, duckdb_parquet::RowGroup &row_group, idx_t col_idx) : PrimitiveColumnWriterState(writer, row_group, col_idx), - dictionary(BufferAllocator::Get(writer.GetContext()), 
writer.DictionarySizeLimit(), + dictionary(BufferAllocator::Get(writer.GetContext()), + writer.DictionarySizeLimit().IsValid() ? writer.DictionarySizeLimit().GetIndex() + : NumericCast(row_group.num_rows) / 5, writer.StringDictionaryPageSizeLimit()), encoding(duckdb_parquet::Encoding::PLAIN) { } diff --git a/src/duckdb/extension/parquet/include/zstd_file_system.hpp b/src/duckdb/extension/parquet/include/zstd_file_system.hpp index 5b132bc8a..15a2e5887 100644 --- a/src/duckdb/extension/parquet/include/zstd_file_system.hpp +++ b/src/duckdb/extension/parquet/include/zstd_file_system.hpp @@ -9,9 +9,7 @@ #pragma once #include "duckdb.hpp" -#ifndef DUCKDB_AMALGAMATION #include "duckdb/common/compressed_file_system.hpp" -#endif namespace duckdb { diff --git a/src/duckdb/extension/parquet/parquet_crypto.cpp b/src/duckdb/extension/parquet/parquet_crypto.cpp index d68570941..07321e6ac 100644 --- a/src/duckdb/extension/parquet/parquet_crypto.cpp +++ b/src/duckdb/extension/parquet/parquet_crypto.cpp @@ -3,12 +3,10 @@ #include "mbedtls_wrapper.hpp" #include "thrift_tools.hpp" -#ifndef DUCKDB_AMALGAMATION #include "duckdb/common/exception/conversion_exception.hpp" #include "duckdb/common/helper.hpp" #include "duckdb/common/types/blob.hpp" #include "duckdb/storage/arena_allocator.hpp" -#endif namespace duckdb { diff --git a/src/duckdb/extension/parquet/parquet_extension.cpp b/src/duckdb/extension/parquet/parquet_extension.cpp index fff277c47..8744e8c89 100644 --- a/src/duckdb/extension/parquet/parquet_extension.cpp +++ b/src/duckdb/extension/parquet/parquet_extension.cpp @@ -48,6 +48,8 @@ #include "duckdb/storage/table/row_group.hpp" #include "duckdb/common/multi_file/multi_file_function.hpp" #include "duckdb/common/primitive_dictionary.hpp" +#include "duckdb/logging/log_manager.hpp" +#include "duckdb/main/settings.hpp" #include "parquet_multi_file_info.hpp" namespace duckdb { @@ -215,12 +217,10 @@ struct ParquetWriteBindData : public TableFunctionData { bool debug_use_openssl = true; //! After how many distinct values should we abandon dictionary compression and bloom filters? - idx_t dictionary_size_limit = row_group_size / 20; - - void SetToDefaultDictionarySizeLimit() { - // This depends on row group size so we should "reset" if the row group size is changed - dictionary_size_limit = row_group_size / 20; - } + //! Defaults to 1/5th of the row group size if unset (in templated_column_writer.hpp) + //! This needs to be set dynamically because row groups can be much smaller than "row_group_size" set here, + //! e.g., due to less data or row_group_size_bytes + optional_idx dictionary_size_limit; //! 
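// Editor's sketch: dictionary_size_limit becomes an optional value — when the user
// never sets it, the writer derives the limit from the rows actually in the row group
// rather than the configured row_group_size. Model of the fallback (OptionalIdx is a
// hypothetical stand-in for duckdb's optional_idx):
#include <cstdint>
#include <cstdio>

struct OptionalIdx {
	uint64_t value = UINT64_MAX; // sentinel: "unset"
	bool IsValid() const {
		return value != UINT64_MAX;
	}
	uint64_t GetIndex() const {
		return value;
	}
};

static uint64_t EffectiveDictLimit(OptionalIdx limit, int64_t row_group_num_rows) {
	// unset -> 1/5th of the rows in this row group
	return limit.IsValid() ? limit.GetIndex() : static_cast<uint64_t>(row_group_num_rows) / 5;
}

int main() {
	printf("%llu\n", (unsigned long long)EffectiveDictLimit({}, 10000));    // 2000
	printf("%llu\n", (unsigned long long)EffectiveDictLimit({512}, 10000)); // 512
	return 0;
}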
This is huge but we grow it starting from 1 MB idx_t string_dictionary_page_size_limit = PrimitiveColumnWriter::MAX_UNCOMPRESSED_DICT_PAGE_SIZE; @@ -274,7 +274,6 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun D_ASSERT(names.size() == sql_types.size()); bool row_group_size_bytes_set = false; bool compression_level_set = false; - bool dictionary_size_limit_set = false; auto bind_data = make_uniq(); for (auto &option : input.info.options) { const auto loption = StringUtil::Lower(option.first); @@ -284,9 +283,6 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun } if (loption == "row_group_size" || loption == "chunk_size") { bind_data->row_group_size = option.second[0].GetValue(); - if (!dictionary_size_limit_set) { - bind_data->SetToDefaultDictionarySizeLimit(); - } } else if (loption == "row_group_size_bytes") { auto roption = option.second[0]; if (roption.GetTypeMutable().id() == LogicalTypeId::VARCHAR) { @@ -363,7 +359,6 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun throw BinderException("dictionary_size_limit must be greater than 0 or 0 to disable"); } bind_data->dictionary_size_limit = val; - dictionary_size_limit_set = true; } else if (loption == "string_dictionary_page_size_limit") { auto val = option.second[0].GetValue(); if (val > PrimitiveColumnWriter::MAX_UNCOMPRESSED_DICT_PAGE_SIZE || val == 0) { @@ -412,7 +407,7 @@ static unique_ptr ParquetWriteBind(ClientContext &context, CopyFun } } if (row_group_size_bytes_set) { - if (DBConfig::GetConfig(context).options.preserve_insertion_order) { + if (DBConfig::GetSetting(context)) { throw BinderException("ROW_GROUP_SIZE_BYTES does not work while preserving insertion order. Use \"SET " "preserve_insertion_order=false;\" to disable preserving insertion order."); } @@ -684,8 +679,8 @@ static unique_ptr ParquetCopyDeserialize(Deserializer &deserialize 110, "row_groups_per_file", default_value.row_groups_per_file); data->debug_use_openssl = deserializer.ReadPropertyWithExplicitDefault(111, "debug_use_openssl", default_value.debug_use_openssl); - data->dictionary_size_limit = deserializer.ReadPropertyWithExplicitDefault( - 112, "dictionary_size_limit", default_value.dictionary_size_limit); + data->dictionary_size_limit = + deserializer.ReadPropertyWithExplicitDefault(112, "dictionary_size_limit", optional_idx()); data->bloom_filter_false_positive_ratio = deserializer.ReadPropertyWithExplicitDefault( 113, "bloom_filter_false_positive_ratio", default_value.bloom_filter_false_positive_ratio); data->parquet_version = @@ -937,7 +932,7 @@ static void LoadInternal(ExtensionLoader &loader) { auto &config = DBConfig::GetConfig(db_instance); config.replacement_scans.emplace_back(ParquetScanReplacement); config.AddExtensionOption("binary_as_string", "In Parquet files, interpret binary data as a string.", - LogicalType::BOOLEAN); + LogicalType::BOOLEAN, Value(false)); config.AddExtensionOption("disable_parquet_prefetching", "Disable the prefetching mechanism in Parquet", LogicalType::BOOLEAN, Value(false)); config.AddExtensionOption("prefetch_all_parquet_files", @@ -950,6 +945,9 @@ static void LoadInternal(ExtensionLoader &loader) { "enable_geoparquet_conversion", "Attempt to decode/encode geometry data in/as GeoParquet files if the spatial extension is present.", LogicalType::BOOLEAN, Value::BOOLEAN(true)); + config.AddExtensionOption("variant_legacy_encoding", + "Enables the Parquet reader to identify a Variant structurally.", LogicalType::BOOLEAN, + Value::BOOLEAN(false)); } void 
ParquetExtension::Load(ExtensionLoader &loader) { diff --git a/src/duckdb/extension/parquet/parquet_float16.cpp b/src/duckdb/extension/parquet/parquet_float16.cpp index 968b6533b..8a07d7c6c 100644 --- a/src/duckdb/extension/parquet/parquet_float16.cpp +++ b/src/duckdb/extension/parquet/parquet_float16.cpp @@ -1,9 +1,6 @@ #include "parquet_float16.hpp" #include "duckdb.hpp" -#ifndef DUCKDB_AMALGAMATION - -#endif namespace duckdb { diff --git a/src/duckdb/extension/parquet/parquet_metadata.cpp b/src/duckdb/extension/parquet/parquet_metadata.cpp index 0e4dd5ab7..d160d7197 100644 --- a/src/duckdb/extension/parquet/parquet_metadata.cpp +++ b/src/duckdb/extension/parquet/parquet_metadata.cpp @@ -4,13 +4,11 @@ #include -#ifndef DUCKDB_AMALGAMATION #include "duckdb/common/multi_file/multi_file_reader.hpp" #include "duckdb/common/types/blob.hpp" #include "duckdb/common/types/column/column_data_collection.hpp" #include "duckdb/planner/filter/constant_filter.hpp" #include "duckdb/main/config.hpp" -#endif namespace duckdb { @@ -200,6 +198,9 @@ void ParquetMetaDataOperatorData::BindMetaData(vector &return_types names.emplace_back("max_is_exact"); return_types.emplace_back(LogicalType::BOOLEAN); + + names.emplace_back("row_group_compressed_bytes"); + return_types.emplace_back(LogicalType::BIGINT); } static Value ConvertParquetStats(const LogicalType &type, const ParquetColumnSchema &schema_ele, bool stats_is_set, @@ -350,6 +351,11 @@ void ParquetMetaDataOperatorData::LoadRowGroupMetadata(ClientContext &context, c current_chunk.SetValue(27, count, ParquetElementBoolean(stats.is_max_value_exact, stats.__isset.is_max_value_exact)); + // row_group_compressed_bytes + current_chunk.SetValue( + 28, count, + ParquetElementBigint(row_group.total_compressed_size, row_group.__isset.total_compressed_size)); + count++; if (count >= STANDARD_VECTOR_SIZE) { current_chunk.SetCardinality(count); diff --git a/src/duckdb/extension/parquet/parquet_reader.cpp b/src/duckdb/extension/parquet/parquet_reader.cpp index ca7f88588..16010ced3 100644 --- a/src/duckdb/extension/parquet/parquet_reader.cpp +++ b/src/duckdb/extension/parquet/parquet_reader.cpp @@ -29,6 +29,7 @@ #include "duckdb/optimizer/statistics_propagator.hpp" #include "duckdb/planner/table_filter_state.hpp" #include "duckdb/common/multi_file/multi_file_reader.hpp" +#include "duckdb/logging/log_manager.hpp" #include #include @@ -47,8 +48,8 @@ using duckdb_parquet::SchemaElement; using duckdb_parquet::Statistics; using duckdb_parquet::Type; -static unique_ptr CreateThriftFileProtocol(CachingFileHandle &file_handle, - bool prefetch_mode) { +static unique_ptr +CreateThriftFileProtocol(QueryContext context, CachingFileHandle &file_handle, bool prefetch_mode) { auto transport = duckdb_base_std::make_shared(file_handle, prefetch_mode); return make_uniq>(std::move(transport)); } @@ -91,7 +92,7 @@ static shared_ptr LoadMetadata(ClientContext &context, Allocator &allocator, CachingFileHandle &file_handle, const shared_ptr &encryption_config, const EncryptionUtil &encryption_util, optional_idx footer_size) { - auto file_proto = CreateThriftFileProtocol(file_handle, false); + auto file_proto = CreateThriftFileProtocol(QueryContext(context), file_handle, false); auto &transport = reinterpret_cast(*file_proto->getTransport()); auto file_size = transport.GetSize(); if (file_size < 12) { @@ -517,22 +518,35 @@ static bool IsVariantType(const SchemaElement &root, const vectorname != "metadata") { return false; } - if (value.name != "value") { + if (value->name != "value") { 
return false; } //! Verify types - if (metadata.parquet_type != duckdb_parquet::Type::BYTE_ARRAY) { + if (metadata->parquet_type != duckdb_parquet::Type::BYTE_ARRAY) { return false; } - if (value.parquet_type != duckdb_parquet::Type::BYTE_ARRAY) { + if (value->parquet_type != duckdb_parquet::Type::BYTE_ARRAY) { return false; } if (children.size() == 3) { @@ -540,7 +554,6 @@ static bool IsVariantType(const SchemaElement &root, const vector 0) { // check if the parent node of this is a map auto &p_ele = file_meta_data->schema[this_idx - 1]; @@ -776,7 +793,8 @@ ParquetOptions::ParquetOptions(ClientContext &context) { Value lookup_value; if (context.TryGetCurrentSetting("binary_as_string", lookup_value)) { binary_as_string = lookup_value.GetValue(); - } else if (context.TryGetCurrentSetting("variant_legacy_encoding", lookup_value)) { + } + if (context.TryGetCurrentSetting("variant_legacy_encoding", lookup_value)) { variant_legacy_encoding = lookup_value.GetValue(); } } @@ -805,7 +823,7 @@ ParquetReader::ParquetReader(ClientContext &context_p, OpenFileInfo file_p, Parq shared_ptr metadata_p) : BaseFileReader(std::move(file_p)), fs(CachingFileSystem::Get(context_p)), allocator(BufferAllocator::Get(context_p)), parquet_options(std::move(parquet_options_p)) { - file_handle = fs.OpenFile(file, FileFlags::FILE_FLAGS_READ); + file_handle = fs.OpenFile(QueryContext(context_p), file, FileFlags::FILE_FLAGS_READ); if (!file_handle->CanSeek()) { throw NotImplementedException( "Reading parquet files from a FIFO stream is not supported and cannot be efficiently supported since " @@ -986,8 +1004,8 @@ const ParquetRowGroup &ParquetReader::GetGroup(ParquetReaderScanState &state) { } uint64_t ParquetReader::GetGroupCompressedSize(ParquetReaderScanState &state) { - auto &group = GetGroup(state); - auto total_compressed_size = group.total_compressed_size; + const auto &group = GetGroup(state); + int64_t total_compressed_size = group.__isset.total_compressed_size ? 
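// Editor's note on the ParquetOptions hunk above: it fixes an else-if, under which
// variant_legacy_encoding was only consulted when the binary_as_string lookup failed,
// even though the two settings are independent. Minimal reproduction of the bug:
#include <cstdio>

static bool TryGetSetting(const char * /*name*/, bool &out, bool present) {
	out = true; // pretend the setting, when present, is true
	return present;
}

int main() {
	bool binary_as_string = false, variant_legacy_encoding = false, lookup = false;
	if (TryGetSetting("binary_as_string", lookup, true)) {
		binary_as_string = lookup;
	}
	if (TryGetSetting("variant_legacy_encoding", lookup, true)) { // was: else if -> never ran
		variant_legacy_encoding = lookup;
	}
	printf("%d %d\n", binary_as_string, variant_legacy_encoding); // now 1 1, was 1 0
	return 0;
}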
group.total_compressed_size : 0; idx_t calc_compressed_size = 0; @@ -1202,7 +1220,7 @@ void ParquetReader::InitializeScan(ClientContext &context, ParquetReaderScanStat state.prefetch_mode = false; } - state.file_handle = fs.OpenFile(file, flags); + state.file_handle = fs.OpenFile(QueryContext(context), file, flags); } state.adaptive_filter.reset(); state.scan_filters.clear(); @@ -1213,7 +1231,7 @@ void ParquetReader::InitializeScan(ClientContext &context, ParquetReaderScanStat } } - state.thrift_file_proto = CreateThriftFileProtocol(*state.file_handle, state.prefetch_mode); + state.thrift_file_proto = CreateThriftFileProtocol(QueryContext(context), *state.file_handle, state.prefetch_mode); state.root_reader = CreateReader(context); state.define_buf.resize(allocator, STANDARD_VECTOR_SIZE); state.repeat_buf.resize(allocator, STANDARD_VECTOR_SIZE); diff --git a/src/duckdb/extension/parquet/parquet_statistics.cpp b/src/duckdb/extension/parquet/parquet_statistics.cpp index 3f4043e5a..5f7d93718 100644 --- a/src/duckdb/extension/parquet/parquet_statistics.cpp +++ b/src/duckdb/extension/parquet/parquet_statistics.cpp @@ -8,15 +8,12 @@ #include "reader/string_column_reader.hpp" #include "reader/struct_column_reader.hpp" #include "zstd/common/xxhash.hpp" - -#ifndef DUCKDB_AMALGAMATION #include "duckdb/common/types/blob.hpp" #include "duckdb/common/types/time.hpp" #include "duckdb/common/types/value.hpp" #include "duckdb/storage/statistics/struct_stats.hpp" #include "duckdb/planner/filter/constant_filter.hpp" #include "reader/uuid_column_reader.hpp" -#endif namespace duckdb { diff --git a/src/duckdb/extension/parquet/parquet_timestamp.cpp b/src/duckdb/extension/parquet/parquet_timestamp.cpp index 60953a6b6..892eb3f5a 100644 --- a/src/duckdb/extension/parquet/parquet_timestamp.cpp +++ b/src/duckdb/extension/parquet/parquet_timestamp.cpp @@ -1,11 +1,9 @@ #include "parquet_timestamp.hpp" #include "duckdb.hpp" -#ifndef DUCKDB_AMALGAMATION #include "duckdb/common/types/date.hpp" #include "duckdb/common/types/time.hpp" #include "duckdb/common/types/timestamp.hpp" -#endif namespace duckdb { diff --git a/src/duckdb/extension/parquet/parquet_writer.cpp b/src/duckdb/extension/parquet/parquet_writer.cpp index 3771e2c59..205a7d05c 100644 --- a/src/duckdb/extension/parquet/parquet_writer.cpp +++ b/src/duckdb/extension/parquet/parquet_writer.cpp @@ -5,8 +5,6 @@ #include "parquet_crypto.hpp" #include "parquet_timestamp.hpp" #include "resizable_buffer.hpp" - -#ifndef DUCKDB_AMALGAMATION #include "duckdb/common/file_system.hpp" #include "duckdb/common/serializer/buffered_file_writer.hpp" #include "duckdb/common/serializer/deserializer.hpp" @@ -19,7 +17,6 @@ #include "duckdb/parser/parsed_data/create_copy_function_info.hpp" #include "duckdb/parser/parsed_data/create_table_function_info.hpp" #include "duckdb/common/types/blob.hpp" -#endif namespace duckdb { @@ -344,10 +341,10 @@ class ParquetStatsAccumulator { ParquetWriter::ParquetWriter(ClientContext &context, FileSystem &fs, string file_name_p, vector types_p, vector names_p, CompressionCodec::type codec, ChildFieldIDs field_ids_p, const vector> &kv_metadata, - shared_ptr encryption_config_p, idx_t dictionary_size_limit_p, - idx_t string_dictionary_page_size_limit_p, bool enable_bloom_filters_p, - double bloom_filter_false_positive_ratio_p, int64_t compression_level_p, - bool debug_use_openssl_p, ParquetVersion parquet_version) + shared_ptr encryption_config_p, + optional_idx dictionary_size_limit_p, idx_t string_dictionary_page_size_limit_p, + bool 
@@ -1202,7 +1220,7 @@ void ParquetReader::InitializeScan(ClientContext &context, ParquetReaderScanStat
 			state.prefetch_mode = false;
 		}
 
-		state.file_handle = fs.OpenFile(file, flags);
+		state.file_handle = fs.OpenFile(QueryContext(context), file, flags);
 	}
 	state.adaptive_filter.reset();
 	state.scan_filters.clear();
@@ -1213,7 +1231,7 @@ void ParquetReader::InitializeScan(ClientContext &context, ParquetReaderScanStat
 		}
 	}
 
-	state.thrift_file_proto = CreateThriftFileProtocol(*state.file_handle, state.prefetch_mode);
+	state.thrift_file_proto = CreateThriftFileProtocol(QueryContext(context), *state.file_handle, state.prefetch_mode);
 	state.root_reader = CreateReader(context);
 	state.define_buf.resize(allocator, STANDARD_VECTOR_SIZE);
 	state.repeat_buf.resize(allocator, STANDARD_VECTOR_SIZE);
diff --git a/src/duckdb/extension/parquet/parquet_statistics.cpp b/src/duckdb/extension/parquet/parquet_statistics.cpp
index 3f4043e5a..5f7d93718 100644
--- a/src/duckdb/extension/parquet/parquet_statistics.cpp
+++ b/src/duckdb/extension/parquet/parquet_statistics.cpp
@@ -8,15 +8,12 @@
 #include "reader/string_column_reader.hpp"
 #include "reader/struct_column_reader.hpp"
 #include "zstd/common/xxhash.hpp"
-
-#ifndef DUCKDB_AMALGAMATION
 #include "duckdb/common/types/blob.hpp"
 #include "duckdb/common/types/time.hpp"
 #include "duckdb/common/types/value.hpp"
 #include "duckdb/storage/statistics/struct_stats.hpp"
 #include "duckdb/planner/filter/constant_filter.hpp"
 #include "reader/uuid_column_reader.hpp"
-#endif
 
 namespace duckdb {
diff --git a/src/duckdb/extension/parquet/parquet_timestamp.cpp b/src/duckdb/extension/parquet/parquet_timestamp.cpp
index 60953a6b6..892eb3f5a 100644
--- a/src/duckdb/extension/parquet/parquet_timestamp.cpp
+++ b/src/duckdb/extension/parquet/parquet_timestamp.cpp
@@ -1,11 +1,9 @@
 #include "parquet_timestamp.hpp"
 #include "duckdb.hpp"
-
-#ifndef DUCKDB_AMALGAMATION
 #include "duckdb/common/types/date.hpp"
 #include "duckdb/common/types/time.hpp"
 #include "duckdb/common/types/timestamp.hpp"
-#endif
 
 namespace duckdb {
diff --git a/src/duckdb/extension/parquet/parquet_writer.cpp b/src/duckdb/extension/parquet/parquet_writer.cpp
index 3771e2c59..205a7d05c 100644
--- a/src/duckdb/extension/parquet/parquet_writer.cpp
+++ b/src/duckdb/extension/parquet/parquet_writer.cpp
@@ -5,8 +5,6 @@
 #include "parquet_crypto.hpp"
 #include "parquet_timestamp.hpp"
 #include "resizable_buffer.hpp"
-
-#ifndef DUCKDB_AMALGAMATION
 #include "duckdb/common/file_system.hpp"
 #include "duckdb/common/serializer/buffered_file_writer.hpp"
 #include "duckdb/common/serializer/deserializer.hpp"
@@ -19,7 +17,6 @@
 #include "duckdb/parser/parsed_data/create_copy_function_info.hpp"
 #include "duckdb/parser/parsed_data/create_table_function_info.hpp"
 #include "duckdb/common/types/blob.hpp"
-#endif
 
 namespace duckdb {
@@ -344,10 +341,10 @@ class ParquetStatsAccumulator {
 ParquetWriter::ParquetWriter(ClientContext &context, FileSystem &fs, string file_name_p, vector<LogicalType> types_p,
                              vector<string> names_p, CompressionCodec::type codec, ChildFieldIDs field_ids_p,
                              const vector<pair<string, string>> &kv_metadata,
-                             shared_ptr<ParquetEncryptionConfig> encryption_config_p, idx_t dictionary_size_limit_p,
-                             idx_t string_dictionary_page_size_limit_p, bool enable_bloom_filters_p,
-                             double bloom_filter_false_positive_ratio_p, int64_t compression_level_p,
-                             bool debug_use_openssl_p, ParquetVersion parquet_version)
+                             shared_ptr<ParquetEncryptionConfig> encryption_config_p,
+                             optional_idx dictionary_size_limit_p, idx_t string_dictionary_page_size_limit_p,
+                             bool enable_bloom_filters_p, double bloom_filter_false_positive_ratio_p,
+                             int64_t compression_level_p, bool debug_use_openssl_p, ParquetVersion parquet_version)
     : context(context), file_name(std::move(file_name_p)), sql_types(std::move(types_p)),
       column_names(std::move(names_p)), codec(codec), field_ids(std::move(field_ids_p)),
       encryption_config(std::move(encryption_config_p)), dictionary_size_limit(dictionary_size_limit_p),
@@ -543,12 +540,24 @@ void ParquetWriter::FlushRowGroup(PreparedRowGroup &prepared) {
 	// let's make sure all offsets are ay-okay
 	ValidateColumnOffsets(file_name, writer->GetTotalWritten(), row_group);
 
-	// append the row group to the file meta data
+	row_group.total_compressed_size = NumericCast<int64_t>(writer->GetTotalWritten()) - row_group.file_offset;
+	row_group.__isset.total_compressed_size = true;
+
+	if (encryption_config) {
+		auto row_group_ordinal = num_row_groups.load();
+		if (row_group_ordinal > std::numeric_limits<int16_t>::max()) {
+			throw InvalidInputException("RowGroup ordinal exceeds 32767 when encryption enabled");
+		}
+		row_group.ordinal = NumericCast<int16_t>(row_group_ordinal);
+		row_group.__isset.ordinal = true;
+	}
+
+	// append the row group to the file metadata
 	file_meta_data.row_groups.push_back(row_group);
 	file_meta_data.num_rows += row_group.num_rows;
 	total_written = writer->GetTotalWritten();
-	num_row_groups++;
+	++num_row_groups;
 }
 
 void ParquetWriter::Flush(ColumnDataCollection &buffer) {
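FlushRowGroup above now records the row-group ordinal when encryption is enabled, and the ordinal is a 16-bit integer in the Parquet Thrift schema, hence the explicit 32767 bound before the narrowing cast. A standalone sketch of the checked-narrowing pattern (hand-rolled check with standard exceptions; DuckDB's own NumericCast performs an equivalent validation):

    #include <cstdint>
    #include <iostream>
    #include <limits>
    #include <stdexcept>

    // Throwing narrow-cast, shaped like the ordinal check in FlushRowGroup.
    static int16_t CheckedOrdinal(uint64_t row_group_ordinal) {
    	if (row_group_ordinal > static_cast<uint64_t>(std::numeric_limits<int16_t>::max())) {
    		throw std::out_of_range("RowGroup ordinal exceeds 32767 when encryption enabled");
    	}
    	return static_cast<int16_t>(row_group_ordinal);
    }

    int main() {
    	std::cout << CheckedOrdinal(42) << '\n'; // 42
    	try {
    		CheckedOrdinal(40000);
    	} catch (const std::out_of_range &e) {
    		std::cout << e.what() << '\n';
    	}
    }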
diff --git a/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp b/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp
index 74b6057aa..eacff5501 100644
--- a/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp
+++ b/src/duckdb/extension/parquet/reader/variant/variant_binary_decoder.cpp
@@ -114,9 +114,6 @@ VariantValueMetadata VariantValueMetadata::FromHeaderByte(uint8_t byte) {
 	return result;
 }
 
-VariantBinaryDecoder::VariantBinaryDecoder(ClientContext &context) : context(context) {
-}
-
 template <class T>
 static T DecodeDecimal(const_data_ptr_t data, uint8_t &scale, uint8_t &width) {
 	scale = Load<uint8_t>(data);
@@ -143,39 +140,42 @@ hugeint_t DecodeDecimal(const_data_ptr_t data, uint8_t &scale, uint8_t &width) {
 	return result;
 }
 
-yyjson_mut_val *VariantBinaryDecoder::PrimitiveTypeDecode(yyjson_mut_doc *doc, const VariantMetadata &metadata,
-                                                          const VariantValueMetadata &value_metadata,
-                                                          const_data_ptr_t data) {
+VariantValue VariantBinaryDecoder::PrimitiveTypeDecode(const VariantMetadata &metadata,
+                                                       const VariantValueMetadata &value_metadata,
+                                                       const_data_ptr_t data) {
 	switch (value_metadata.primitive_type) {
 	case VariantPrimitiveType::NULL_TYPE: {
-		return yyjson_mut_null(doc);
+		return VariantValue(Value());
 	}
 	case VariantPrimitiveType::BOOLEAN_TRUE: {
-		return yyjson_mut_true(doc);
+		return VariantValue(Value::BOOLEAN(true));
 	}
 	case VariantPrimitiveType::BOOLEAN_FALSE: {
-		return yyjson_mut_false(doc);
+		return VariantValue(Value::BOOLEAN(false));
 	}
 	case VariantPrimitiveType::INT8: {
 		auto value = Load<int8_t>(data);
-		return yyjson_mut_int(doc, value);
+		return VariantValue(Value::TINYINT(value));
 	}
 	case VariantPrimitiveType::INT16: {
 		auto value = Load<int16_t>(data);
-		return yyjson_mut_int(doc, value);
+		return VariantValue(Value::SMALLINT(value));
 	}
 	case VariantPrimitiveType::INT32: {
 		auto value = Load<int32_t>(data);
-		return yyjson_mut_int(doc, value);
+		return VariantValue(Value::INTEGER(value));
 	}
 	case VariantPrimitiveType::INT64: {
 		auto value = Load<int64_t>(data);
-		return yyjson_mut_int(doc, value);
+		return VariantValue(Value::BIGINT(value));
 	}
 	case VariantPrimitiveType::DOUBLE: {
-		double value;
-		memcpy(&value, data, sizeof(double));
-		return yyjson_mut_real(doc, value);
+		double value = Load<double>(data);
+		return VariantValue(Value::DOUBLE(value));
+	}
+	case VariantPrimitiveType::FLOAT: {
+		float value = Load<float>(data);
+		return VariantValue(Value::FLOAT(value));
 	}
 	case VariantPrimitiveType::DECIMAL4: {
 		uint8_t scale;
@@ -183,7 +183,7 @@ yyjson_mut_val *VariantBinaryDecoder::PrimitiveTypeDecode(yyjson_mut_doc *doc, c
 		auto value = DecodeDecimal<int32_t>(data, scale, width);
 
 		auto value_str = Decimal::ToString(value, width, scale);
-		return yyjson_mut_strcpy(doc, value_str.c_str());
+		return VariantValue(Value(value_str));
 	}
 	case VariantPrimitiveType::DECIMAL8: {
 		uint8_t scale;
@@ -191,7 +191,7 @@ yyjson_mut_val *VariantBinaryDecoder::PrimitiveTypeDecode(yyjson_mut_doc *doc, c
 		auto value = DecodeDecimal<int64_t>(data, scale, width);
 
 		auto value_str = Decimal::ToString(value, width, scale);
-		return yyjson_mut_strcpy(doc, value_str.c_str());
+		return VariantValue(Value(value_str));
 	}
 	case VariantPrimitiveType::DECIMAL16: {
 		uint8_t scale;
@@ -199,21 +199,17 @@ yyjson_mut_val *VariantBinaryDecoder::PrimitiveTypeDecode(yyjson_mut_doc *doc, c
 		auto value = DecodeDecimal<hugeint_t>(data, scale, width);
 
 		auto value_str = Decimal::ToString(value, width, scale);
-		return yyjson_mut_strcpy(doc, value_str.c_str());
+		return VariantValue(Value(value_str));
 	}
 	case VariantPrimitiveType::DATE: {
 		date_t value;
 		value.days = Load<int32_t>(data);
-		auto value_str = Date::ToString(value);
-		return yyjson_mut_strcpy(doc, value_str.c_str());
+		return VariantValue(Value::DATE(value));
 	}
 	case VariantPrimitiveType::TIMESTAMP_MICROS: {
-		timestamp_tz_t micros_tz_ts;
-		micros_tz_ts.value = Load<int64_t>(data);
-
-		auto value = Value::TIMESTAMPTZ(micros_tz_ts);
-		auto value_str = value.CastAs(context, LogicalType::VARCHAR).GetValue<string>();
-		return yyjson_mut_strcpy(doc, value_str.c_str());
+		timestamp_tz_t micros_ts_tz;
+		micros_ts_tz.value = Load<int64_t>(data);
+		return VariantValue(Value::TIMESTAMPTZ(micros_ts_tz));
 	}
 	case VariantPrimitiveType::TIMESTAMP_NTZ_MICROS: {
 		timestamp_t micros_ts;
@@ -221,12 +217,7 @@ yyjson_mut_val *VariantBinaryDecoder::PrimitiveTypeDecode(yyjson_mut_doc *doc, c
 
 		auto value = Value::TIMESTAMP(micros_ts);
 		auto value_str = value.ToString();
-		return yyjson_mut_strcpy(doc, value_str.c_str());
-	}
-	case VariantPrimitiveType::FLOAT: {
-		float value;
-		memcpy(&value, data, sizeof(float));
-		return yyjson_mut_real(doc, value);
+		return VariantValue(Value(value_str));
 	}
 	case VariantPrimitiveType::BINARY: {
 		//! Follow the JSON serialization guide by converting BINARY to Base64:
@@ -234,7 +225,7 @@ yyjson_mut_val *VariantBinaryDecoder::PrimitiveTypeDecode(yyjson_mut_doc *doc, c
 		auto size = Load<uint32_t>(data);
 		auto string_data = reinterpret_cast<const char *>(data + sizeof(uint32_t));
 		auto base64_string = Blob::ToBase64(string_t(string_data, size));
-		return yyjson_mut_strncpy(doc, base64_string.c_str(), base64_string.size());
+		return VariantValue(Value(base64_string));
 	}
 	case VariantPrimitiveType::STRING: {
 		auto size = Load<uint32_t>(data);
@@ -242,36 +233,20 @@ yyjson_mut_val *VariantBinaryDecoder::PrimitiveTypeDecode(yyjson_mut_doc *doc, c
 		if (!Utf8Proc::IsValid(string_data, size)) {
 			throw InternalException("Can't decode Variant short-string, string isn't valid UTF8");
 		}
-		return yyjson_mut_strncpy(doc, string_data, size);
+		return VariantValue(Value(string(string_data, size)));
 	}
 	case VariantPrimitiveType::TIME_NTZ_MICROS: {
 		dtime_t micros_time;
 		micros_time.micros = Load<int64_t>(data);
-		auto value_str = Time::ToString(micros_time);
-		return yyjson_mut_strcpy(doc, value_str.c_str());
+		return VariantValue(Value::TIME(micros_time));
 	}
 	case VariantPrimitiveType::TIMESTAMP_NANOS: {
 		timestamp_ns_t nanos_ts;
 		nanos_ts.value = Load<int64_t>(data);
 
-		//! Convert the nanos timestamp to a micros timestamp
-		date_t out_date;
-		dtime_t out_time;
-		int32_t out_nanos;
-		Timestamp::Convert(nanos_ts, out_date, out_time, out_nanos);
-		auto micros_ts = Timestamp::FromDatetime(out_date, out_time);
-
-		//! Turn the micros timestamp into a micros_tz timestamp and serialize it
-		timestamp_tz_t micros_tz_ts(micros_ts.value);
-		auto value = Value::TIMESTAMPTZ(micros_tz_ts);
-		auto value_str = value.CastAs(context, LogicalType::VARCHAR).GetValue<string>();
-
-		if (StringUtil::Contains(value_str, "+")) {
-			//! Don't attempt this for NaN/Inf timestamps
-			auto parts = StringUtil::Split(value_str, '+');
-			value_str = StringUtil::Format("%s%s+%s", parts[0], to_string(out_nanos), parts[1]);
-		}
-		return yyjson_mut_strcpy(doc, value_str.c_str());
+		//! Convert the nanos timestamp to a micros timestamp (not lossless)
+		auto micros_ts = Timestamp::FromEpochNanoSeconds(nanos_ts.value);
+		return VariantValue(Value::TIMESTAMPTZ(timestamp_tz_t(micros_ts)));
 	}
 	case VariantPrimitiveType::TIMESTAMP_NTZ_NANOS: {
 		timestamp_ns_t nanos_ts;
@@ -279,12 +254,12 @@ yyjson_mut_val *VariantBinaryDecoder::PrimitiveTypeDecode(yyjson_mut_doc *doc, c
 
 		auto value = Value::TIMESTAMPNS(nanos_ts);
 		auto value_str = value.ToString();
-		return yyjson_mut_strcpy(doc, value_str.c_str());
+		return VariantValue(Value(value_str));
 	}
 	case VariantPrimitiveType::UUID: {
 		auto uuid_value = UUIDValueConversion::ReadParquetUUID(data);
 		auto value_str = UUID::ToString(uuid_value);
-		return yyjson_mut_strcpy(doc, value_str.c_str());
+		return VariantValue(Value(value_str));
 	}
 	default:
 		throw NotImplementedException("Variant PrimitiveTypeDecode not implemented for type (%d)",
@@ -292,20 +267,20 @@ yyjson_mut_val *VariantBinaryDecoder::PrimitiveTypeDecode(yyjson_mut_doc *doc, c
 	}
 }
 
-yyjson_mut_val *VariantBinaryDecoder::ShortStringDecode(yyjson_mut_doc *doc, const VariantMetadata &metadata,
-                                                        const VariantValueMetadata &value_metadata,
-                                                        const_data_ptr_t data) {
+VariantValue VariantBinaryDecoder::ShortStringDecode(const VariantMetadata &metadata,
+                                                     const VariantValueMetadata &value_metadata,
+                                                     const_data_ptr_t data) {
 	D_ASSERT(value_metadata.string_size < 64);
 	auto string_data = reinterpret_cast<const char *>(data);
 	if (!Utf8Proc::IsValid(string_data, value_metadata.string_size)) {
 		throw InternalException("Can't decode Variant short-string, string isn't valid UTF8");
 	}
-	return yyjson_mut_strncpy(doc, string_data, value_metadata.string_size);
+	return VariantValue(Value(string(string_data, value_metadata.string_size)));
 }
 
-yyjson_mut_val *VariantBinaryDecoder::ObjectDecode(yyjson_mut_doc *doc, const VariantMetadata &metadata,
-                                                   const VariantValueMetadata &value_metadata, const_data_ptr_t data) {
-	auto obj = yyjson_mut_obj(doc);
+VariantValue VariantBinaryDecoder::ObjectDecode(const VariantMetadata &metadata,
+                                                const VariantValueMetadata &value_metadata, const_data_ptr_t data) {
+	VariantValue ret(VariantValueType::OBJECT);
 	auto field_offset_size = value_metadata.field_offset_size;
 	auto field_id_size = value_metadata.field_id_size;
@@ -329,17 +304,18 @@ yyjson_mut_val *VariantBinaryDecoder::ObjectDecode(yyjson_mut_doc *doc, const Va
 		auto field_id = ReadVariableLengthLittleEndian(field_id_size, field_ids);
 		auto next_offset = ReadVariableLengthLittleEndian(field_offset_size, field_offsets);
-		auto value = Decode(doc, metadata, values + last_offset);
+		auto value = Decode(metadata, values + last_offset);
 		auto &key = metadata.strings[field_id];
-		yyjson_mut_obj_add_val(doc, obj, key.c_str(), value);
+
+		ret.AddChild(key, std::move(value));
 		last_offset = next_offset;
 	}
-	return obj;
+	return ret;
 }
 
-yyjson_mut_val *VariantBinaryDecoder::ArrayDecode(yyjson_mut_doc *doc, const VariantMetadata &metadata,
-                                                  const VariantValueMetadata &value_metadata, const_data_ptr_t data) {
-	auto arr = yyjson_mut_arr(doc);
+VariantValue VariantBinaryDecoder::ArrayDecode(const VariantMetadata &metadata,
+                                               const VariantValueMetadata &value_metadata, const_data_ptr_t data) {
+	VariantValue ret(VariantValueType::ARRAY);
 	auto field_offset_size = value_metadata.field_offset_size;
 	auto is_large = value_metadata.is_large;
@@ -360,30 +336,28 @@ yyjson_mut_val *VariantBinaryDecoder::ArrayDecode(yyjson_mut_doc *doc, const Var
 	for (idx_t i = 0; i < num_elements; i++) {
 		auto next_offset = ReadVariableLengthLittleEndian(field_offset_size, field_offsets);
-		auto value = Decode(doc, metadata, values + last_offset);
-		yyjson_mut_arr_add_val(arr, value);
+		ret.AddItem(Decode(metadata, values + last_offset));
 		last_offset = next_offset;
 	}
-	return arr;
+	return ret;
 }
 
-yyjson_mut_val *VariantBinaryDecoder::Decode(yyjson_mut_doc *doc, const VariantMetadata &variant_metadata,
-                                             const_data_ptr_t data) {
+VariantValue VariantBinaryDecoder::Decode(const VariantMetadata &variant_metadata, const_data_ptr_t data) {
 	auto value_metadata = VariantValueMetadata::FromHeaderByte(data[0]);
 	data++;
 
 	switch (value_metadata.basic_type) {
 	case VariantBasicType::PRIMITIVE: {
-		return PrimitiveTypeDecode(doc, variant_metadata, value_metadata, data);
+		return PrimitiveTypeDecode(variant_metadata, value_metadata, data);
 	}
 	case VariantBasicType::SHORT_STRING: {
-		return ShortStringDecode(doc, variant_metadata, value_metadata, data);
+		return ShortStringDecode(variant_metadata, value_metadata, data);
 	}
 	case VariantBasicType::OBJECT: {
-		return ObjectDecode(doc, variant_metadata, value_metadata, data);
+		return ObjectDecode(variant_metadata, value_metadata, data);
 	}
 	case VariantBasicType::ARRAY: {
-		return ArrayDecode(doc, variant_metadata, value_metadata, data);
+		return ArrayDecode(variant_metadata, value_metadata, data);
 	}
 	default:
 		throw InternalException("Unexpected value for VariantBasicType");
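ObjectDecode and ArrayDecode above read field ids and offsets with `ReadVariableLengthLittleEndian`, where the byte width (1 to 4) comes from the value header rather than a fixed type. A standalone sketch of reading such a variable-width little-endian unsigned integer (hypothetical helper name; a real implementation would also bounds-check against the buffer end):

    #include <cstdint>
    #include <iostream>

    // Read a little-endian unsigned integer of 'width' bytes (1..4) and advance ptr.
    static uint32_t ReadVarWidthLE(uint8_t width, const uint8_t *&ptr) {
    	uint32_t result = 0;
    	for (uint8_t i = 0; i < width; i++) {
    		result |= static_cast<uint32_t>(ptr[i]) << (8 * i);
    	}
    	ptr += width;
    	return result;
    }

    int main() {
    	const uint8_t buf[] = {0x2A, 0x01, 0x02}; // 42, then 0x0201 as a 2-byte field
    	const uint8_t *cursor = buf;
    	std::cout << ReadVarWidthLE(1, cursor) << ' ' << ReadVarWidthLE(2, cursor) << '\n'; // 42 513
    }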
diff --git a/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp b/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp
new file mode 100644
index 000000000..8278eb740
--- /dev/null
+++ b/src/duckdb/extension/parquet/reader/variant/variant_shredded_conversion.cpp
@@ -0,0 +1,565 @@
+#include "reader/variant/variant_shredded_conversion.hpp"
+#include "column_reader.hpp"
+#include "utf8proc_wrapper.hpp"
+
+#include "duckdb/common/types/timestamp.hpp"
+#include "duckdb/common/types/decimal.hpp"
+#include "duckdb/common/types/uuid.hpp"
+#include "duckdb/common/types/time.hpp"
+#include "duckdb/common/types/date.hpp"
+#include "duckdb/common/types/blob.hpp"
+
+namespace duckdb {
+
+template <class T>
+struct ConvertShreddedValue {
+	static VariantValue Convert(T val);
+	static VariantValue ConvertDecimal(T val, uint8_t width, uint8_t scale) {
+		throw InternalException("ConvertShreddedValue::ConvertDecimal not implemented for type");
+	}
+	static VariantValue ConvertBlob(T val) {
+		throw InternalException("ConvertShreddedValue::ConvertBlob not implemented for type");
+	}
+};
+
+//! boolean
+template <>
+VariantValue ConvertShreddedValue<bool>::Convert(bool val) {
+	return VariantValue(Value::BOOLEAN(val));
+}
+//! int8
+template <>
+VariantValue ConvertShreddedValue<int8_t>::Convert(int8_t val) {
+	return VariantValue(Value::TINYINT(val));
+}
+//! int16
+template <>
+VariantValue ConvertShreddedValue<int16_t>::Convert(int16_t val) {
+	return VariantValue(Value::SMALLINT(val));
+}
+//! int32
+template <>
+VariantValue ConvertShreddedValue<int32_t>::Convert(int32_t val) {
+	return VariantValue(Value::INTEGER(val));
+}
+//! int64
+template <>
+VariantValue ConvertShreddedValue<int64_t>::Convert(int64_t val) {
+	return VariantValue(Value::BIGINT(val));
+}
+//! float
+template <>
+VariantValue ConvertShreddedValue<float>::Convert(float val) {
+	return VariantValue(Value::FLOAT(val));
+}
+//! double
+template <>
+VariantValue ConvertShreddedValue<double>::Convert(double val) {
+	return VariantValue(Value::DOUBLE(val));
+}
+//! decimal4/decimal8/decimal16
+template <>
+VariantValue ConvertShreddedValue<int32_t>::ConvertDecimal(int32_t val, uint8_t width, uint8_t scale) {
+	auto value_str = Decimal::ToString(val, width, scale);
+	return VariantValue(Value(value_str));
+}
+template <>
+VariantValue ConvertShreddedValue<int64_t>::ConvertDecimal(int64_t val, uint8_t width, uint8_t scale) {
+	auto value_str = Decimal::ToString(val, width, scale);
+	return VariantValue(Value(value_str));
+}
+template <>
+VariantValue ConvertShreddedValue<hugeint_t>::ConvertDecimal(hugeint_t val, uint8_t width, uint8_t scale) {
+	auto value_str = Decimal::ToString(val, width, scale);
+	return VariantValue(Value(value_str));
+}
+//! date
+template <>
+VariantValue ConvertShreddedValue<date_t>::Convert(date_t val) {
+	return VariantValue(Value::DATE(val));
+}
+//! time
+template <>
+VariantValue ConvertShreddedValue<dtime_t>::Convert(dtime_t val) {
+	return VariantValue(Value::TIME(val));
+}
+//! timestamptz(6)
+template <>
+VariantValue ConvertShreddedValue<timestamp_tz_t>::Convert(timestamp_tz_t val) {
+	return VariantValue(Value::TIMESTAMPTZ(val));
+}
+////! timestamptz(9)
+// template <>
+// VariantValue ConvertShreddedValue<timestamp_ns_tz_t>::Convert(timestamp_ns_tz_t val) {
+//	return VariantValue(Value::TIMESTAMPNS_TZ(val));
+//}
+//! timestampntz(6)
+template <>
+VariantValue ConvertShreddedValue<timestamp_t>::Convert(timestamp_t val) {
+	return VariantValue(Value::TIMESTAMP(val));
+}
+//! timestampntz(9)
+template <>
+VariantValue ConvertShreddedValue<timestamp_ns_t>::Convert(timestamp_ns_t val) {
+	return VariantValue(Value::TIMESTAMPNS(val));
+}
+//! binary
+template <>
+VariantValue ConvertShreddedValue<string_t>::ConvertBlob(string_t val) {
+	return VariantValue(Value(Blob::ToBase64(val)));
+}
+//! string
+template <>
+VariantValue ConvertShreddedValue<string_t>::Convert(string_t val) {
+	if (!Utf8Proc::IsValid(val.GetData(), val.GetSize())) {
+		throw InternalException("Can't decode Variant string, it isn't valid UTF8");
+	}
+	return VariantValue(Value(val.GetString()));
+}
+//! uuid
+template <>
+VariantValue ConvertShreddedValue<hugeint_t>::Convert(hugeint_t val) {
+	return VariantValue(Value(UUID::ToString(val)));
+}
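The decimal specializations above all render the raw integer through `Decimal::ToString` using the column's width and scale. A standalone sketch of that scaling arithmetic (hand-rolled formatter for illustration only; it ignores negative values and width validation for brevity):

    #include <cstdint>
    #include <iostream>
    #include <string>

    // Format an unscaled decimal integer with the given scale, e.g. 12345/2 -> "123.45".
    static std::string DecimalToString(int64_t unscaled, uint8_t scale) {
    	std::string digits = std::to_string(unscaled);
    	if (scale == 0) {
    		return digits;
    	}
    	while (digits.size() <= scale) {
    		digits.insert(digits.begin(), '0'); // pad so the decimal point fits
    	}
    	digits.insert(digits.size() - scale, 1, '.');
    	return digits;
    }

    int main() {
    	std::cout << DecimalToString(12345, 2) << '\n'; // 123.45
    	std::cout << DecimalToString(7, 3) << '\n';     // 0.007
    }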
+
+template <class T, class OP, LogicalTypeId TYPE_ID>
+vector<VariantValue> ConvertTypedValues(Vector &vec, Vector &metadata, Vector &blob, idx_t offset, idx_t length,
+                                        idx_t total_size) {
+	UnifiedVectorFormat metadata_format;
+	metadata.ToUnifiedFormat(length, metadata_format);
+	auto metadata_data = metadata_format.GetData<string_t>(metadata_format);
+
+	UnifiedVectorFormat typed_format;
+	vec.ToUnifiedFormat(total_size, typed_format);
+	auto data = typed_format.GetData<T>(typed_format);
+
+	UnifiedVectorFormat value_format;
+	blob.ToUnifiedFormat(total_size, value_format);
+	auto value_data = value_format.GetData<string_t>(value_format);
+
+	auto &validity = typed_format.validity;
+	auto &value_validity = value_format.validity;
+	auto &type = vec.GetType();
+
+	//! Values only used for Decimal conversion
+	uint8_t width;
+	uint8_t scale;
+	if (TYPE_ID == LogicalTypeId::DECIMAL) {
+		type.GetDecimalProperties(width, scale);
+	}
+
+	vector<VariantValue> ret(length);
+	if (validity.AllValid()) {
+		for (idx_t i = 0; i < length; i++) {
+			auto index = typed_format.sel->get_index(i + offset);
+			if (TYPE_ID == LogicalTypeId::DECIMAL) {
+				ret[i] = OP::ConvertDecimal(data[index], width, scale);
+			} else if (TYPE_ID == LogicalTypeId::BLOB) {
+				ret[i] = OP::ConvertBlob(data[index]);
+			} else {
+				ret[i] = OP::Convert(data[index]);
+			}
+		}
+	} else {
+		for (idx_t i = 0; i < length; i++) {
+			auto typed_index = typed_format.sel->get_index(i + offset);
+			auto value_index = value_format.sel->get_index(i + offset);
+			if (validity.RowIsValid(typed_index)) {
+				//! This is a leaf, partially shredded values aren't possible here
+				D_ASSERT(!value_validity.RowIsValid(value_index));
+				if (TYPE_ID == LogicalTypeId::DECIMAL) {
+					ret[i] = OP::ConvertDecimal(data[typed_index], width, scale);
+				} else if (TYPE_ID == LogicalTypeId::BLOB) {
+					ret[i] = OP::ConvertBlob(data[typed_index]);
+				} else {
+					ret[i] = OP::Convert(data[typed_index]);
+				}
+			} else if (value_validity.RowIsValid(value_index)) {
+				auto metadata_value = metadata_data[metadata_format.sel->get_index(i)];
+				VariantMetadata variant_metadata(metadata_value);
+				ret[i] = VariantBinaryDecoder::Decode(variant_metadata,
+				                                      const_data_ptr_cast(value_data[value_index].GetData()));
+			}
+		}
+	}
+	return ret;
+}
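ConvertTypedValues above follows a common DuckDB vectorization pattern: take a fast path when the validity mask says every row is valid, otherwise consult per-row validity on both the typed column and the binary fallback column. A standalone sketch of that two-path shape, with plain arrays standing in for DuckDB's UnifiedVectorFormat (this sketch does not use the DuckDB API):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    struct Column {
    	std::vector<int> data;
    	std::vector<bool> valid; // per-row validity mask
    	bool AllValid() const {
    		for (bool v : valid) {
    			if (!v) {
    				return false;
    			}
    		}
    		return true;
    	}
    };

    static std::vector<int> Convert(const Column &typed, const Column &fallback) {
    	std::vector<int> out(typed.data.size());
    	if (typed.AllValid()) {
    		// Fast path: no per-row validity checks needed.
    		for (size_t i = 0; i < typed.data.size(); i++) {
    			out[i] = typed.data[i];
    		}
    	} else {
    		// Slow path: fall back to the second column where the typed one is NULL.
    		for (size_t i = 0; i < typed.data.size(); i++) {
    			out[i] = typed.valid[i] ? typed.data[i] : fallback.data[i];
    		}
    	}
    	return out;
    }

    int main() {
    	Column typed {{1, 0, 3}, {true, false, true}};
    	Column fallback {{9, 2, 9}, {true, true, true}};
    	for (int v : Convert(typed, fallback)) {
    		std::cout << v << ' '; // 1 2 3
    	}
    	std::cout << '\n';
    }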
+
+vector<VariantValue> VariantShreddedConversion::ConvertShreddedLeaf(Vector &metadata, Vector &value,
+                                                                    Vector &typed_value, idx_t offset, idx_t length,
+                                                                    idx_t total_size) {
+	D_ASSERT(!typed_value.GetType().IsNested());
+	vector<VariantValue> result;
+
+	auto &type = typed_value.GetType();
+	switch (type.id()) {
+	//! boolean
+	case LogicalTypeId::BOOLEAN: {
+		return ConvertTypedValues<bool, ConvertShreddedValue<bool>, LogicalTypeId::BOOLEAN>(
+		    typed_value, metadata, value, offset, length, total_size);
+	}
+	//! int8
+	case LogicalTypeId::TINYINT: {
+		return ConvertTypedValues<int8_t, ConvertShreddedValue<int8_t>, LogicalTypeId::TINYINT>(
+		    typed_value, metadata, value, offset, length, total_size);
+	}
+	//! int16
+	case LogicalTypeId::SMALLINT: {
+		return ConvertTypedValues<int16_t, ConvertShreddedValue<int16_t>, LogicalTypeId::SMALLINT>(
+		    typed_value, metadata, value, offset, length, total_size);
+	}
+	//! int32
+	case LogicalTypeId::INTEGER: {
+		return ConvertTypedValues<int32_t, ConvertShreddedValue<int32_t>, LogicalTypeId::INTEGER>(
+		    typed_value, metadata, value, offset, length, total_size);
+	}
+	//! int64
+	case LogicalTypeId::BIGINT: {
+		return ConvertTypedValues<int64_t, ConvertShreddedValue<int64_t>, LogicalTypeId::BIGINT>(
+		    typed_value, metadata, value, offset, length, total_size);
+	}
+	//! float
+	case LogicalTypeId::FLOAT: {
+		return ConvertTypedValues<float, ConvertShreddedValue<float>, LogicalTypeId::FLOAT>(
+		    typed_value, metadata, value, offset, length, total_size);
+	}
+	//! double
+	case LogicalTypeId::DOUBLE: {
+		return ConvertTypedValues<double, ConvertShreddedValue<double>, LogicalTypeId::DOUBLE>(
+		    typed_value, metadata, value, offset, length, total_size);
+	}
+	//! decimal4/decimal8/decimal16
+	case LogicalTypeId::DECIMAL: {
+		auto physical_type = type.InternalType();
+		switch (physical_type) {
+		case PhysicalType::INT32: {
+			return ConvertTypedValues<int32_t, ConvertShreddedValue<int32_t>, LogicalTypeId::DECIMAL>(
+			    typed_value, metadata, value, offset, length, total_size);
+		}
+		case PhysicalType::INT64: {
+			return ConvertTypedValues<int64_t, ConvertShreddedValue<int64_t>, LogicalTypeId::DECIMAL>(
+			    typed_value, metadata, value, offset, length, total_size);
+		}
+		case PhysicalType::INT128: {
+			return ConvertTypedValues<hugeint_t, ConvertShreddedValue<hugeint_t>, LogicalTypeId::DECIMAL>(
+			    typed_value, metadata, value, offset, length, total_size);
+		}
+		default:
+			throw NotImplementedException("Decimal with PhysicalType (%s) not implemented for shredded Variant",
+			                              EnumUtil::ToString(physical_type));
+		}
+	}
+	//! date
+	case LogicalTypeId::DATE: {
+		return ConvertTypedValues<date_t, ConvertShreddedValue<date_t>, LogicalTypeId::DATE>(
+		    typed_value, metadata, value, offset, length, total_size);
+	}
+	//! time
+	case LogicalTypeId::TIME: {
+		return ConvertTypedValues<dtime_t, ConvertShreddedValue<dtime_t>, LogicalTypeId::TIME>(
+		    typed_value, metadata, value, offset, length, total_size);
+	}
+	//! timestamptz(6) (timestamptz(9) not implemented in DuckDB)
+	case LogicalTypeId::TIMESTAMP_TZ: {
+		return ConvertTypedValues<timestamp_tz_t, ConvertShreddedValue<timestamp_tz_t>, LogicalTypeId::TIMESTAMP_TZ>(
+		    typed_value, metadata, value, offset, length, total_size);
+	}
+	//! timestampntz(6)
+	case LogicalTypeId::TIMESTAMP: {
+		return ConvertTypedValues<timestamp_t, ConvertShreddedValue<timestamp_t>, LogicalTypeId::TIMESTAMP>(
+		    typed_value, metadata, value, offset, length, total_size);
+	}
+	//! timestampntz(9)
+	case LogicalTypeId::TIMESTAMP_NS: {
+		return ConvertTypedValues<timestamp_ns_t, ConvertShreddedValue<timestamp_ns_t>, LogicalTypeId::TIMESTAMP_NS>(
+		    typed_value, metadata, value, offset, length, total_size);
+	}
+	//! binary
+	case LogicalTypeId::BLOB: {
+		return ConvertTypedValues<string_t, ConvertShreddedValue<string_t>, LogicalTypeId::BLOB>(
+		    typed_value, metadata, value, offset, length, total_size);
+	}
+	//! string
+	case LogicalTypeId::VARCHAR: {
+		return ConvertTypedValues<string_t, ConvertShreddedValue<string_t>, LogicalTypeId::VARCHAR>(
+		    typed_value, metadata, value, offset, length, total_size);
+	}
+	//! uuid
+	case LogicalTypeId::UUID: {
+		return ConvertTypedValues<hugeint_t, ConvertShreddedValue<hugeint_t>, LogicalTypeId::UUID>(
+		    typed_value, metadata, value, offset, length, total_size);
+	}
+	default:
+		throw NotImplementedException("Variant shredding on type: '%s' is not implemented", type.ToString());
+	}
+}
+
+namespace {
+
+struct ShreddedVariantField {
+public:
+	explicit ShreddedVariantField(const string &field_name) : field_name(field_name) {
+	}
+
+public:
+	string field_name;
+	//! Values for the field, for all rows
+	vector<VariantValue> values;
+};
+
+} // namespace
+
+template <bool IS_REQUIRED>
+static vector<VariantValue> ConvertBinaryEncoding(Vector &metadata, Vector &value, idx_t offset, idx_t length,
+                                                  idx_t total_size) {
+	UnifiedVectorFormat value_format;
+	value.ToUnifiedFormat(total_size, value_format);
+	auto value_data = value_format.GetData<string_t>(value_format);
+	auto &validity = value_format.validity;
+
+	UnifiedVectorFormat metadata_format;
+	metadata.ToUnifiedFormat(length, metadata_format);
+	auto metadata_data = metadata_format.GetData<string_t>(metadata_format);
+	auto metadata_validity = metadata_format.validity;
+
+	vector<VariantValue> ret(length);
+	if (IS_REQUIRED) {
+		for (idx_t i = 0; i < length; i++) {
+			auto index = value_format.sel->get_index(i + offset);
+
+			// Variant itself is NULL
+			if (!validity.RowIsValid(index) && !metadata_validity.RowIsValid(metadata_format.sel->get_index(i))) {
+				ret[i] = VariantValue(Value());
+				continue;
+			}
+
+			D_ASSERT(validity.RowIsValid(index));
+			auto &metadata_value = metadata_data[metadata_format.sel->get_index(i)];
+			VariantMetadata variant_metadata(metadata_value);
+			auto binary_value = value_data[index].GetData();
+			ret[i] = VariantBinaryDecoder::Decode(variant_metadata, const_data_ptr_cast(binary_value));
+		}
+	} else {
+		//! Even though 'typed_value' is not present, 'value' is allowed to contain NULLs because we're scanning an
+		//! Object's shredded field.
+		//! When 'value' is null for a row, that means the Object does not contain this field for that row.
+		for (idx_t i = 0; i < length; i++) {
+			auto index = value_format.sel->get_index(i + offset);
+			if (validity.RowIsValid(index)) {
+				auto &metadata_value = metadata_data[metadata_format.sel->get_index(i)];
+				VariantMetadata variant_metadata(metadata_value);
+				auto binary_value = value_data[index].GetData();
+				ret[i] = VariantBinaryDecoder::Decode(variant_metadata, const_data_ptr_cast(binary_value));
+			}
+		}
+	}
+	return ret;
+}
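ConvertBinaryEncoding above is instantiated with a compile-time IS_REQUIRED flag: at the top level a NULL `value` (with NULL metadata) means the whole Variant is NULL, while for an object field a NULL `value` means the field is simply absent from that row. A standalone sketch of using a bool template parameter to split those semantics without a runtime flag at each call site (simplified types, not the DuckDB API):

    #include <iostream>
    #include <optional>
    #include <string>
    #include <vector>

    enum class Kind { NULL_VALUE, MISSING, VALUE };

    template <bool IS_REQUIRED>
    static std::vector<Kind> DecodeColumn(const std::vector<std::optional<std::string>> &rows) {
    	std::vector<Kind> out;
    	for (const auto &row : rows) {
    		if (row.has_value()) {
    			out.push_back(Kind::VALUE);
    		} else if (IS_REQUIRED) {
    			out.push_back(Kind::NULL_VALUE); // top level: the variant itself is NULL
    		} else {
    			out.push_back(Kind::MISSING); // object field: absent on this row
    		}
    	}
    	return out;
    }

    int main() {
    	std::vector<std::optional<std::string>> rows = {std::nullopt, std::string("x")};
    	std::cout << (DecodeColumn<true>(rows)[0] == Kind::NULL_VALUE) << '\n'; // 1
    	std::cout << (DecodeColumn<false>(rows)[0] == Kind::MISSING) << '\n';   // 1
    }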
+
+static VariantValue ConvertPartiallyShreddedObject(vector<ShreddedVariantField> &shredded_fields,
+                                                   const UnifiedVectorFormat &metadata_format,
+                                                   const UnifiedVectorFormat &value_format, idx_t i, idx_t offset) {
+	auto ret = VariantValue(VariantValueType::OBJECT);
+	auto index = value_format.sel->get_index(i + offset);
+	auto value_data = value_format.GetData<string_t>(value_format);
+	auto metadata_data = metadata_format.GetData<string_t>(metadata_format);
+	auto &value_validity = value_format.validity;
+
+	for (idx_t field_index = 0; field_index < shredded_fields.size(); field_index++) {
+		auto &shredded_field = shredded_fields[field_index];
+		auto &field_value = shredded_field.values[i];
+
+		if (field_value.IsMissing()) {
+			//! This field is missing from the value, skip it
+			continue;
+		}
+		ret.AddChild(shredded_field.field_name, std::move(field_value));
+	}
+
+	if (value_validity.RowIsValid(index)) {
+		//! Object is partially shredded, decode the object and merge the values
+		auto &metadata_value = metadata_data[metadata_format.sel->get_index(i)];
+		VariantMetadata variant_metadata(metadata_value);
+		auto binary_value = value_data[index].GetData();
+		auto unshredded = VariantBinaryDecoder::Decode(variant_metadata, const_data_ptr_cast(binary_value));
+		if (unshredded.value_type != VariantValueType::OBJECT) {
+			throw InvalidInputException("Partially shredded objects have to encode Object Variants in the 'value'");
+		}
+		for (auto &item : unshredded.object_children) {
+			ret.AddChild(item.first, std::move(item.second));
+		}
+	}
+	return ret;
+}
+
+vector<VariantValue> VariantShreddedConversion::ConvertShreddedObject(Vector &metadata, Vector &value,
+                                                                      Vector &typed_value, idx_t offset, idx_t length,
+                                                                      idx_t total_size) {
+	auto &type = typed_value.GetType();
+	D_ASSERT(type.id() == LogicalTypeId::STRUCT);
+	auto &fields = StructType::GetChildTypes(type);
+	auto &entries = StructVector::GetEntries(typed_value);
+	D_ASSERT(entries.size() == fields.size());
+
+	//! 'value'
+	UnifiedVectorFormat value_format;
+	value.ToUnifiedFormat(total_size, value_format);
+	auto value_data = value_format.GetData<string_t>(value_format);
+	auto &validity = value_format.validity;
+
+	//! 'metadata'
+	UnifiedVectorFormat metadata_format;
+	metadata.ToUnifiedFormat(length, metadata_format);
+	auto metadata_data = metadata_format.GetData<string_t>(metadata_format);
+
+	//! 'typed_value'
+	UnifiedVectorFormat typed_format;
+	typed_value.ToUnifiedFormat(total_size, typed_format);
+	auto &typed_validity = typed_format.validity;
+
+	//! Process all fields to get the shredded field values
+	vector<ShreddedVariantField> shredded_fields;
+	shredded_fields.reserve(fields.size());
+	for (idx_t i = 0; i < fields.size(); i++) {
+		auto &field = fields[i];
+		auto &field_name = field.first;
+		auto &field_vec = *entries[i];
+
+		shredded_fields.emplace_back(field_name);
+		auto &shredded_field = shredded_fields.back();
+		shredded_field.values = Convert(metadata, field_vec, offset, length, total_size, true);
+	}
+
+	vector<VariantValue> ret(length);
+	if (typed_validity.AllValid()) {
+		for (idx_t i = 0; i < length; i++) {
+			ret[i] = ConvertPartiallyShreddedObject(shredded_fields, metadata_format, value_format, i, offset);
+		}
+	} else {
+		//! For some of the rows, the value is not an object
+		for (idx_t i = 0; i < length; i++) {
+			auto typed_index = typed_format.sel->get_index(i + offset);
+			auto value_index = value_format.sel->get_index(i + offset);
+			if (typed_validity.RowIsValid(typed_index)) {
+				ret[i] = ConvertPartiallyShreddedObject(shredded_fields, metadata_format, value_format, i, offset);
+			} else {
+				//! The value on this row is not an object, and guaranteed to be present
+				D_ASSERT(validity.RowIsValid(value_index));
+				auto &metadata_value = metadata_data[metadata_format.sel->get_index(i)];
+				VariantMetadata variant_metadata(metadata_value);
+				auto binary_value = value_data[value_index].GetData();
+				ret[i] = VariantBinaryDecoder::Decode(variant_metadata, const_data_ptr_cast(binary_value));
+				if (ret[i].value_type == VariantValueType::OBJECT) {
+					throw InvalidInputException(
+					    "When 'typed_value' for a shredded Object is NULL, 'value' can not contain an Object value");
+				}
+			}
+		}
+	}
+	return ret;
+}
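ConvertPartiallyShreddedObject above assembles one object per row from two sources: fields that were shredded into their own columns, plus any remaining fields still carried in the binary `value` blob. A standalone sketch of that merge, with string maps standing in for the shredded columns and the decoded binary object:

    #include <iostream>
    #include <map>
    #include <string>

    using Object = std::map<std::string, std::string>;

    // Shredded fields win their slots; leftover unshredded fields are merged in.
    static Object MergePartiallyShredded(const Object &shredded, const Object &unshredded) {
    	Object result = shredded;
    	for (const auto &kv : unshredded) {
    		result.emplace(kv.first, kv.second); // emplace keeps existing (shredded) keys
    	}
    	return result;
    }

    int main() {
    	Object shredded {{"id", "42"}, {"name", "duck"}};
    	Object unshredded {{"extra", "true"}};
    	for (const auto &kv : MergePartiallyShredded(shredded, unshredded)) {
    		std::cout << kv.first << '=' << kv.second << '\n'; // extra=true, id=42, name=duck
    	}
    }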
+
+vector<VariantValue> VariantShreddedConversion::ConvertShreddedArray(Vector &metadata, Vector &value,
+                                                                     Vector &typed_value, idx_t offset, idx_t length,
+                                                                     idx_t total_size) {
+	auto &child = ListVector::GetEntry(typed_value);
+	auto list_size = ListVector::GetListSize(typed_value);
+
+	//! 'value'
+	UnifiedVectorFormat value_format;
+	value.ToUnifiedFormat(total_size, value_format);
+	auto value_data = value_format.GetData<string_t>(value_format);
+
+	//! 'metadata'
+	UnifiedVectorFormat metadata_format;
+	metadata.ToUnifiedFormat(length, metadata_format);
+	auto metadata_data = metadata_format.GetData<string_t>(metadata_format);
+
+	//! 'typed_value'
+	UnifiedVectorFormat list_format;
+	typed_value.ToUnifiedFormat(total_size, list_format);
+	auto list_data = list_format.GetData<list_entry_t>(list_format);
+	auto &validity = list_format.validity;
+	auto &value_validity = value_format.validity;
+
+	vector<VariantValue> ret(length);
+	if (validity.AllValid()) {
+		//! We can be sure that none of the values are binary encoded
+		for (idx_t i = 0; i < length; i++) {
+			auto typed_index = list_format.sel->get_index(i + offset);
+			//! FIXME: next 4 lines duplicated below
+			auto entry = list_data[typed_index];
+			Vector child_metadata(metadata.GetValue(i));
+			ret[i] = VariantValue(VariantValueType::ARRAY);
+			ret[i].array_items = Convert(child_metadata, child, entry.offset, entry.length, list_size);
+		}
+	} else {
+		for (idx_t i = 0; i < length; i++) {
+			auto typed_index = list_format.sel->get_index(i + offset);
+			auto value_index = value_format.sel->get_index(i + offset);
+			if (validity.RowIsValid(typed_index)) {
+				//! FIXME: next 4 lines duplicate
+				auto entry = list_data[typed_index];
+				Vector child_metadata(metadata.GetValue(i));
+				ret[i] = VariantValue(VariantValueType::ARRAY);
+				ret[i].array_items = Convert(child_metadata, child, entry.offset, entry.length, list_size);
+			} else if (value_validity.RowIsValid(value_index)) {
+				auto metadata_value = metadata_data[metadata_format.sel->get_index(i)];
+				VariantMetadata variant_metadata(metadata_value);
+				ret[i] = VariantBinaryDecoder::Decode(variant_metadata,
+				                                      const_data_ptr_cast(value_data[value_index].GetData()));
+			}
+		}
+	}
+	return ret;
+}
+
+vector<VariantValue> VariantShreddedConversion::Convert(Vector &metadata, Vector &group, idx_t offset, idx_t length,
+                                                        idx_t total_size, bool is_field) {
+	D_ASSERT(group.GetType().id() == LogicalTypeId::STRUCT);
+
+	auto &group_entries = StructVector::GetEntries(group);
+	auto &group_type_children = StructType::GetChildTypes(group.GetType());
+	D_ASSERT(group_type_children.size() == group_entries.size());
+
+	//! From the spec:
+	//! The Parquet columns used to store variant metadata and values must be accessed by name, not by position.
+	optional_ptr<Vector> value;
+	optional_ptr<Vector> typed_value;
+	for (idx_t i = 0; i < group_entries.size(); i++) {
+		auto &name = group_type_children[i].first;
+		auto &vec = group_entries[i];
+		if (name == "value") {
+			value = vec.get();
+		} else if (name == "typed_value") {
+			typed_value = vec.get();
+		} else {
+			throw InvalidInputException("Variant group can only contain 'value'/'typed_value', not: %s", name);
+		}
+	}
+	if (!value) {
+		throw InvalidInputException("Required column 'value' not found in Variant group");
+	}
+
+	if (typed_value) {
+		auto &type = typed_value->GetType();
+		vector<VariantValue> ret;
+		if (type.id() == LogicalTypeId::STRUCT) {
+			return ConvertShreddedObject(metadata, *value, *typed_value, offset, length, total_size);
+		} else if (type.id() == LogicalTypeId::LIST) {
+			return ConvertShreddedArray(metadata, *value, *typed_value, offset, length, total_size);
+		} else {
+			return ConvertShreddedLeaf(metadata, *value, *typed_value, offset, length, total_size);
+		}
+	} else {
+		if (is_field) {
+			return ConvertBinaryEncoding<false>(metadata, *value, offset, length, total_size);
+		} else {
+			//! Only 'value' is present, we can assume this to be 'required', so it can't contain NULLs
+			return ConvertBinaryEncoding<true>(metadata, *value, offset, length, total_size);
+		}
+	}
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/extension/parquet/reader/variant/variant_value.cpp b/src/duckdb/extension/parquet/reader/variant/variant_value.cpp
new file mode 100644
index 000000000..0ac213469
--- /dev/null
+++ b/src/duckdb/extension/parquet/reader/variant/variant_value.cpp
@@ -0,0 +1,85 @@
+#include "reader/variant/variant_value.hpp"
+
+namespace duckdb {
+
+void VariantValue::AddChild(const string &key, VariantValue &&val) {
+	D_ASSERT(value_type == VariantValueType::OBJECT);
+	object_children.emplace(key, std::move(val));
+}
+
+void VariantValue::AddItem(VariantValue &&val) {
+	D_ASSERT(value_type == VariantValueType::ARRAY);
+	array_items.push_back(std::move(val));
+}
+
+yyjson_mut_val *VariantValue::ToJSON(ClientContext &context, yyjson_mut_doc *doc) const {
+	switch (value_type) {
+	case VariantValueType::PRIMITIVE: {
+		if (primitive_value.IsNull()) {
+			return yyjson_mut_null(doc);
+		}
+		switch (primitive_value.type().id()) {
+		case LogicalTypeId::BOOLEAN: {
+			if (primitive_value.GetValue<bool>()) {
+				return yyjson_mut_true(doc);
+			} else {
+				return yyjson_mut_false(doc);
+			}
+		}
+		case LogicalTypeId::TINYINT:
+			return yyjson_mut_int(doc, primitive_value.GetValue<int8_t>());
+		case LogicalTypeId::SMALLINT:
+			return yyjson_mut_int(doc, primitive_value.GetValue<int16_t>());
+		case LogicalTypeId::INTEGER:
+			return yyjson_mut_int(doc, primitive_value.GetValue<int32_t>());
+		case LogicalTypeId::BIGINT:
+			return yyjson_mut_int(doc, primitive_value.GetValue<int64_t>());
+		case LogicalTypeId::FLOAT:
+			return yyjson_mut_real(doc, primitive_value.GetValue<float>());
+		case LogicalTypeId::DOUBLE:
+			return yyjson_mut_real(doc, primitive_value.GetValue<double>());
+		case LogicalTypeId::DATE:
+		case LogicalTypeId::TIME:
+		case LogicalTypeId::VARCHAR: {
+			auto value_str = primitive_value.ToString();
+			return yyjson_mut_strncpy(doc, value_str.c_str(), value_str.size());
+		}
+		case LogicalTypeId::TIMESTAMP: {
+			auto value_str = primitive_value.ToString();
+			return yyjson_mut_strncpy(doc, value_str.c_str(), value_str.size());
+		}
+		case LogicalTypeId::TIMESTAMP_TZ: {
+			auto value_str = primitive_value.CastAs(context, LogicalType::VARCHAR).GetValue<string>();
+			return yyjson_mut_strncpy(doc, value_str.c_str(), value_str.size());
+		}
+		case LogicalTypeId::TIMESTAMP_NS: {
+			auto value_str = primitive_value.CastAs(context, LogicalType::VARCHAR).GetValue<string>();
+			return yyjson_mut_strncpy(doc, value_str.c_str(), value_str.size());
+		}
+		default:
+			throw InternalException("Unexpected primitive type: %s", primitive_value.type().ToString());
+		}
+	}
+	case VariantValueType::OBJECT: {
+		auto obj = yyjson_mut_obj(doc);
+		for (const auto &it : object_children) {
+			auto &key = it.first;
+			auto value = it.second.ToJSON(context, doc);
+			yyjson_mut_obj_add_val(doc, obj, key.c_str(), value);
+		}
+		return obj;
+	}
+	case VariantValueType::ARRAY: {
+		auto arr = yyjson_mut_arr(doc);
+		for (auto &item : array_items) {
+			auto value = item.ToJSON(context, doc);
+			yyjson_mut_arr_add_val(arr, value);
+		}
+		return arr;
+	}
+	default:
+		throw InternalException("Can't serialize this VariantValue type to JSON");
+	}
+}
+
+} // namespace duckdb
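VariantValue::ToJSON above renders the decoded tree with yyjson's mutable document API, which the column reader then serializes once per row. A small sketch of that build-then-write flow, using only the yyjson calls that appear in this patch (error handling elided; the header path is an assumption, as DuckDB vendors yyjson under third_party):

    #include <cstdio>
    #include <cstdlib>
    #include "yyjson.h" // vendored by DuckDB in the real build

    int main() {
    	yyjson_mut_doc *doc = yyjson_mut_doc_new(nullptr);
    	// Build {"id": 42, "tags": ["a", null]}
    	yyjson_mut_val *obj = yyjson_mut_obj(doc);
    	yyjson_mut_obj_add_val(doc, obj, "id", yyjson_mut_int(doc, 42));
    	yyjson_mut_val *arr = yyjson_mut_arr(doc);
    	yyjson_mut_arr_add_val(arr, yyjson_mut_strcpy(doc, "a"));
    	yyjson_mut_arr_add_val(arr, yyjson_mut_null(doc));
    	yyjson_mut_obj_add_val(doc, obj, "tags", arr);

    	size_t len;
    	char *json = yyjson_mut_val_write_opts(obj, YYJSON_WRITE_ALLOW_INF_AND_NAN, nullptr, &len, nullptr);
    	std::printf("%s\n", json); // {"id":42,"tags":["a",null]}
    	free(json);
    	yyjson_mut_doc_free(doc);
    }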
diff --git a/src/duckdb/extension/parquet/reader/variant_column_reader.cpp b/src/duckdb/extension/parquet/reader/variant_column_reader.cpp
index 80f9dfedd..402bcbb07 100644
--- a/src/duckdb/extension/parquet/reader/variant_column_reader.cpp
+++ b/src/duckdb/extension/parquet/reader/variant_column_reader.cpp
@@ -1,5 +1,6 @@
 #include "reader/variant_column_reader.hpp"
 #include "reader/variant/variant_binary_decoder.hpp"
+#include "reader/variant/variant_shredded_conversion.hpp"
 
 namespace duckdb {
 
@@ -11,6 +12,16 @@ VariantColumnReader::VariantColumnReader(ClientContext &context, ParquetReader &
                                          vector<unique_ptr<ColumnReader>> child_readers_p)
     : ColumnReader(reader, schema), context(context), child_readers(std::move(child_readers_p)) {
 	D_ASSERT(Type().InternalType() == PhysicalType::VARCHAR);
+
+	if (child_readers[0]->Schema().name == "metadata" && child_readers[1]->Schema().name == "value") {
+		metadata_reader_idx = 0;
+		value_reader_idx = 1;
+	} else if (child_readers[1]->Schema().name == "metadata" && child_readers[0]->Schema().name == "value") {
+		metadata_reader_idx = 1;
+		value_reader_idx = 0;
+	} else {
+		throw InternalException("The Variant column must have 'metadata' and 'value' as the first two columns");
+	}
 }
 
 ColumnReader &VariantColumnReader::GetChildReader(idx_t child_idx) {
@@ -31,52 +42,73 @@ void VariantColumnReader::InitializeRead(idx_t row_group_idx_p, const vector<Col
+static LogicalType GetIntermediateGroupType(optional_ptr<ColumnReader> typed_value) {
+	child_list_t<LogicalType> children;
+	children.emplace_back("value", LogicalType::BLOB);
+	if (typed_value) {
+		children.emplace_back("typed_value", typed_value->Type());
 	}
+	return LogicalType::STRUCT(std::move(children));
+}
 
+idx_t VariantColumnReader::Read(uint64_t num_values, data_ptr_t define_out, data_ptr_t repeat_out, Vector &result) {
 	if (pending_skips > 0) {
 		throw InternalException("VariantColumnReader cannot have pending skips");
 	}
 
+	optional_ptr<ColumnReader> typed_value_reader = child_readers.size() == 3 ? child_readers[2].get() : nullptr;
 	// If the child reader values are all valid, "define_out" may not be initialized at all
 	// So, we just initialize them to all be valid beforehand
 	std::fill_n(define_out, num_values, MaxDefine());
 
 	optional_idx read_count;
-	Vector value_intermediate(LogicalType::BLOB, num_values);
 	Vector metadata_intermediate(LogicalType::BLOB, num_values);
-	auto metadata_values = child_readers[0]->Read(num_values, define_out, repeat_out, metadata_intermediate);
-	auto value_values = child_readers[1]->Read(num_values, define_out, repeat_out, value_intermediate);
+	Vector intermediate_group(GetIntermediateGroupType(typed_value_reader), num_values);
+	auto &group_entries = StructVector::GetEntries(intermediate_group);
+	auto &value_intermediate = *group_entries[0];
+
+	auto metadata_values =
+	    child_readers[metadata_reader_idx]->Read(num_values, define_out, repeat_out, metadata_intermediate);
+	auto value_values = child_readers[value_reader_idx]->Read(num_values, define_out, repeat_out, value_intermediate);
+
+	D_ASSERT(child_readers[metadata_reader_idx]->Schema().name == "metadata");
+	D_ASSERT(child_readers[value_reader_idx]->Schema().name == "value");
+
 	if (metadata_values != value_values) {
 		throw InvalidInputException(
-		    "The unshredded Variant column did not contain the same amount of values for 'metadata' and 'value'");
+		    "The Variant column did not contain the same amount of values for 'metadata' and 'value'");
 	}
-	VariantBinaryDecoder decoder(context);
 
 	auto result_data = FlatVector::GetData<string_t>(result);
-	auto metadata_intermediate_data = FlatVector::GetData<string_t>(metadata_intermediate);
-	auto value_intermediate_data = FlatVector::GetData<string_t>(value_intermediate);
-
-	auto metadata_validity = FlatVector::Validity(metadata_intermediate);
-	auto value_validity = FlatVector::Validity(value_intermediate);
-	for (idx_t i = 0; i < num_values; i++) {
-		if (!metadata_validity.RowIsValid(i) || !value_validity.RowIsValid(i)) {
-			throw InvalidInputException("The Variant 'metadata' and 'value' columns can not produce NULL values");
+	auto &result_validity = FlatVector::Validity(result);
+
+	vector<VariantValue> conversion_result;
+	if (typed_value_reader) {
+		auto typed_values = typed_value_reader->Read(num_values, define_out, repeat_out, *group_entries[1]);
+		if (typed_values != value_values) {
+			throw InvalidInputException(
+			    "The shredded Variant column did not contain the same amount of values for 'typed_value' and 'value'");
+		}
+	}
+	conversion_result =
+	    VariantShreddedConversion::Convert(metadata_intermediate, intermediate_group, 0, num_values, num_values);
+
+	for (idx_t i = 0; i < conversion_result.size(); i++) {
+		auto &variant = conversion_result[i];
+		if (variant.IsNull()) {
+			result_validity.SetInvalid(i);
+			continue;
 		}
-		VariantMetadata variant_metadata(metadata_intermediate_data[i]);
-		auto value_data = reinterpret_cast<const_data_ptr_t>(value_intermediate_data[i].GetData());
 
+		//! Write the result to a string
 		VariantDecodeResult decode_result;
 		decode_result.doc = yyjson_mut_doc_new(nullptr);
+		auto json_val = variant.ToJSON(context, decode_result.doc);
 
-		auto val = decoder.Decode(decode_result.doc, variant_metadata, value_data);
-
-		//! Write the result to a string
 		size_t len;
-		decode_result.data = yyjson_mut_val_write_opts(val, YYJSON_WRITE_ALLOW_INF_AND_NAN, nullptr, &len, nullptr);
+		decode_result.data =
+		    yyjson_mut_val_write_opts(json_val, YYJSON_WRITE_ALLOW_INF_AND_NAN, nullptr, &len, nullptr);
 		if (!decode_result.data) {
 			throw InvalidInputException("Could not serialize the JSON to string, yyjson failed");
 		}
diff --git a/src/duckdb/src/catalog/catalog.cpp b/src/duckdb/src/catalog/catalog.cpp
index bb895d4d7..eaf0a61e0 100644
--- a/src/duckdb/src/catalog/catalog.cpp
+++ b/src/duckdb/src/catalog/catalog.cpp
@@ -39,6 +39,7 @@
 #include "duckdb/function/built_in_functions.hpp"
 #include "duckdb/catalog/similar_catalog_entry.hpp"
 #include "duckdb/storage/database_size.hpp"
+#include "duckdb/main/settings.hpp"
 
 #include <algorithm>
 
 namespace duckdb {
@@ -534,14 +535,14 @@ bool Catalog::TryAutoLoad(ClientContext &context, const string &original_name) n
 	return false;
 }
 
-void Catalog::AutoloadExtensionByConfigName(ClientContext &context, const string &configuration_name) {
+string Catalog::AutoloadExtensionByConfigName(ClientContext &context, const string &configuration_name) {
 #ifndef DUCKDB_DISABLE_EXTENSION_LOAD
 	auto &dbconfig = DBConfig::GetConfig(context);
 	if (dbconfig.options.autoload_known_extensions) {
 		auto extension_name = ExtensionHelper::FindExtensionInEntries(configuration_name, EXTENSION_SETTINGS);
 		if (ExtensionHelper::CanAutoloadExtension(extension_name)) {
 			ExtensionHelper::AutoLoadExtension(context, extension_name);
-			return;
+			return extension_name;
 		}
 	}
 #endif
@@ -649,9 +650,7 @@ CatalogException Catalog::CreateMissingEntryException(CatalogEntryRetriever &ret
                                                       const reference_set_t<SchemaCatalogEntry> &schemas) {
 	auto &context = retriever.GetContext();
 	auto entries = SimilarEntriesInSchemas(context, lookup_info, schemas);
-
-	auto &config = DBConfig::GetConfig(context);
-	auto max_schema_count = config.GetSetting(context);
+	auto max_schema_count = DBConfig::GetSetting(context);
 	reference_set_t<SchemaCatalogEntry> unseen_schemas;
 	auto &db_manager = DatabaseManager::Get(context);
@@ -1140,6 +1139,16 @@ vector<reference<SchemaCatalogEntry>> Catalog::GetAllSchemas(ClientContext &cont
 	return result;
 }
 
+vector<reference<CatalogEntry>> Catalog::GetAllEntries(ClientContext &context, CatalogType catalog_type) {
+	vector<reference<CatalogEntry>> result;
+	auto schemas = GetAllSchemas(context);
+	for (const auto &schema_ref : schemas) {
+		auto &schema = schema_ref.get();
+		schema.Scan(context, catalog_type, [&](CatalogEntry &entry) { result.push_back(entry); });
+	}
+	return result;
+}
+
 void Catalog::Alter(CatalogTransaction transaction, AlterInfo &info) {
 	if (transaction.HasContext()) {
 		CatalogEntryRetriever retriever(transaction.GetContext());
diff --git a/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp
index b9345b529..a1f23aa7f 100644
--- a/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp
+++ b/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp
@@ -855,12 +855,12 @@ unique_ptr<CatalogEntry> DuckTableEntry::RenameField(ClientContext &context, Ren
 	if (!ColumnExists(info.column_path[0])) {
 		throw CatalogException("Cannot rename field from column \"%s\" - it does not exist", info.column_path[0]);
 	}
+
 	// follow the path
 	auto &col = GetColumn(info.column_path[0]);
 	auto res = RenameFieldFromStruct(col.Type(), info.column_path, info.new_name, 1);
 	if (res.error.HasError()) {
 		res.error.Throw();
-		return nullptr;
 	}
 
 	// construct the struct remapping expression
@@ -871,7 +871,6 @@ unique_ptr<CatalogEntry> DuckTableEntry::RenameField(ClientContext &context, Ren
 	children.push_back(make_uniq<ConstantExpression>(Value()));
 
 	auto function = make_uniq<FunctionExpression>("remap_struct", std::move(children));
-
 	ChangeColumnTypeInfo change_column_type(info.GetAlterEntryData(), info.column_path[0], std::move(res.new_type),
 	                                        std::move(function));
 	return ChangeColumnType(context, change_column_type);
diff --git a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp
index 0261d5b2f..22a173fd8 100644
--- a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp
+++ b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp
@@ -247,21 +247,21 @@ void LogicalUpdate::BindExtraColumns(TableCatalogEntry &table, LogicalGet &get,
 		}
 	}
 	if (found_column_count > 0 && found_column_count != bound_columns.size()) {
-		// columns in this CHECK constraint were referenced, but not all were part of the UPDATE
+		// columns that were required are not all part of the UPDATE
 		// add them to the scan and update set
-		for (auto &check_column_id : bound_columns) {
-			if (found_columns.find(check_column_id) != found_columns.end()) {
+		for (auto &physical_id : bound_columns) {
+			if (found_columns.find(physical_id) != found_columns.end()) {
 				// column is already projected
 				continue;
 			}
 			// column is not projected yet: project it by adding the clause "i=i" to the set of updated columns
-			auto &column = table.GetColumns().GetColumn(check_column_id);
+			auto &column = table.GetColumns().GetColumn(physical_id);
 			update.expressions.push_back(make_uniq<BoundColumnRefExpression>(
 			    column.Type(), ColumnBinding(proj.table_index, proj.expressions.size())));
 			proj.expressions.push_back(make_uniq<BoundColumnRefExpression>(
 			    column.Type(), ColumnBinding(get.table_index, get.GetColumnIds().size())));
-			get.AddColumnId(check_column_id.index);
-			update.columns.push_back(check_column_id);
+			get.AddColumnId(column.Logical().index);
+			update.columns.push_back(physical_id);
 		}
 	}
 }
diff --git a/src/duckdb/src/common/adbc/adbc.cpp b/src/duckdb/src/common/adbc/adbc.cpp
index 3da768021..054eaaf0f 100644
--- a/src/duckdb/src/common/adbc/adbc.cpp
+++ b/src/duckdb/src/common/adbc/adbc.cpp
@@ -8,12 +8,7 @@
 #include "duckdb/common/arrow/arrow_wrapper.hpp"
 #include "duckdb/common/arrow/nanoarrow/nanoarrow.hpp"
-#include "duckdb/main/capi/capi_internal.hpp"
-
-#ifndef DUCKDB_AMALGAMATION
"duckdb/main/connection.hpp" -#endif - #include "duckdb/common/adbc/options.h" #include "duckdb/common/adbc/single_batch_array_stream.hpp" #include "duckdb/function/table/arrow.hpp" @@ -21,6 +16,8 @@ #include #include +#include "duckdb/main/prepared_statement_data.hpp" + // We must leak the symbols of the init function AdbcStatusCode duckdb_adbc_init(int version, void *driver, struct AdbcError *error) { if (!driver) { @@ -62,7 +59,6 @@ enum class IngestionMode { CREATE = 0, APPEND = 1 }; struct DuckDBAdbcStatementWrapper { duckdb_connection connection; - duckdb_arrow result; duckdb_prepared_statement statement; char *ingestion_table_name; char *db_schema; @@ -72,6 +68,10 @@ struct DuckDBAdbcStatementWrapper { uint64_t plan_length; }; +struct DuckDBAdbcStreamWrapper { + duckdb_result result; +}; + static AdbcStatusCode QueryInternal(struct AdbcConnection *connection, struct ArrowArrayStream *out, const char *query, struct AdbcError *error) { AdbcStatement statement; @@ -533,8 +533,31 @@ static int get_schema(struct ArrowArrayStream *stream, struct ArrowSchema *out) if (!stream || !stream->private_data || !out) { return DuckDBError; } - return duckdb_query_arrow_schema(static_cast(stream->private_data), - reinterpret_cast(&out)); + auto result_wrapper = static_cast(stream->private_data); + auto count = duckdb_column_count(&result_wrapper->result); + std::vector types(count); + + std::vector owned_names(count); + duckdb::vector names(count); + for (idx_t i = 0; i < count; i++) { + types[i] = duckdb_column_logical_type(&result_wrapper->result, i); + auto column_name = duckdb_column_name(&result_wrapper->result, i); + owned_names.emplace_back(column_name); + names[i] = owned_names.back().c_str(); + } + + auto arrow_options = duckdb_result_get_arrow_options(&result_wrapper->result); + + auto res = duckdb_to_arrow_schema(arrow_options, &types[0], names.data(), count, out); + duckdb_destroy_arrow_options(&arrow_options); + for (auto &type : types) { + duckdb_destroy_logical_type(&type); + } + if (res) { + duckdb_destroy_error_data(&res); + return DuckDBError; + } + return DuckDBSuccess; } static int get_next(struct ArrowArrayStream *stream, struct ArrowArray *out) { @@ -542,28 +565,39 @@ static int get_next(struct ArrowArrayStream *stream, struct ArrowArray *out) { return DuckDBError; } out->release = nullptr; + auto result_wrapper = static_cast(stream->private_data); + auto duckdb_chunk = duckdb_fetch_chunk(result_wrapper->result); + if (!duckdb_chunk) { + return DuckDBSuccess; + } + auto arrow_options = duckdb_result_get_arrow_options(&result_wrapper->result); + + auto conversion_success = duckdb_data_chunk_to_arrow(arrow_options, duckdb_chunk, out); + duckdb_destroy_arrow_options(&arrow_options); + duckdb_destroy_data_chunk(&duckdb_chunk); - return duckdb_query_arrow_array(static_cast(stream->private_data), - reinterpret_cast(&out)); + if (conversion_success) { + duckdb_destroy_error_data(&conversion_success); + return DuckDBError; + } + return DuckDBSuccess; } void release(struct ArrowArrayStream *stream) { if (!stream || !stream->release) { return; } - if (stream->private_data) { - duckdb_destroy_arrow(reinterpret_cast(&stream->private_data)); - stream->private_data = nullptr; + auto result_wrapper = reinterpret_cast(stream->private_data); + if (result_wrapper) { + duckdb_destroy_result(&result_wrapper->result); } + free(stream->private_data); + stream->private_data = nullptr; stream->release = nullptr; } const char *get_last_error(struct ArrowArrayStream *stream) { - if (!stream) { - return 
 
 const char *get_last_error(struct ArrowArrayStream *stream) {
-	if (!stream) {
-		return nullptr;
-	}
 	return nullptr;
-	// return duckdb_query_arrow_error(stream);
 }
 
 // this is an evil hack, normally we would need a stream factory here, but it's probably much easier if the adbc clients
@@ -605,44 +639,65 @@ AdbcStatusCode Ingest(duckdb_connection connection, const char *table_name, cons
 		return ADBC_STATUS_INVALID_ARGUMENT;
 	}
 
-	auto cconn = reinterpret_cast<duckdb::Connection *>(connection);
+	duckdb::ArrowSchemaWrapper arrow_schema_wrapper;
+	ConvertedSchemaWrapper out_types;
 
-	auto arrow_scan =
-	    cconn->TableFunction("arrow_scan", {duckdb::Value::POINTER(reinterpret_cast<uintptr_t>(input)),
-	                                        duckdb::Value::POINTER(reinterpret_cast<uintptr_t>(stream_produce)),
-	                                        duckdb::Value::POINTER(reinterpret_cast<uintptr_t>(stream_schema))});
-	try {
-		switch (ingestion_mode) {
-		case IngestionMode::CREATE:
-			if (schema) {
-				arrow_scan->Create(schema, table_name, temporary);
-			} else {
-				arrow_scan->Create(table_name, temporary);
-			}
-			break;
-		case IngestionMode::APPEND: {
-			arrow_scan->CreateView("temp_adbc_view", true, true);
-			std::string query = "insert into ";
-			if (schema) {
-				query += duckdb::KeywordHelper::WriteOptionallyQuoted(schema) + ".";
-			}
-			query += duckdb::KeywordHelper::WriteOptionallyQuoted(table_name);
-			query += " select * from temp_adbc_view";
-			auto result = cconn->Query(query);
-			break;
-		}
-		}
-		// After creating a table, the arrow array stream is released. Hence we must set it as released to avoid
-		// double-releasing it
-		input->release = nullptr;
-	} catch (std::exception &ex) {
-		if (error) {
-			duckdb::ErrorData parsed_error(ex);
-			error->message = strdup(parsed_error.RawMessage().c_str());
+	input->get_schema(input, &arrow_schema_wrapper.arrow_schema);
+	auto res = duckdb_schema_from_arrow(connection, &arrow_schema_wrapper.arrow_schema, out_types.GetPtr());
+	if (res) {
+		SetError(error, duckdb_error_data_message(res));
+		duckdb_destroy_error_data(&res);
+		return ADBC_STATUS_INTERNAL;
+	}
+
+	auto &d_converted_schema = *reinterpret_cast(out_types.Get());
+	auto types = d_converted_schema.GetTypes();
+	auto names = d_converted_schema.GetNames();
+
+	if (ingestion_mode == IngestionMode::CREATE) {
+		// We must construct the create table SQL query
+		std::ostringstream create_table;
+		create_table << "CREATE TABLE ";
+		if (schema) {
+			create_table << schema << ".";
 		}
+		create_table << table_name << " (";
+		for (idx_t i = 0; i < types.size(); i++) {
+			create_table << names[i] << " ";
+			create_table << types[i].ToString();
+			if (i + 1 < types.size()) {
+				create_table << ", ";
+			}
+		}
+		create_table << ");";
+		duckdb_result result;
+		if (duckdb_query(connection, create_table.str().c_str(), &result) == DuckDBError) {
+			SetError(error, duckdb_result_error(&result));
+			duckdb_destroy_result(&result);
+			return ADBC_STATUS_INTERNAL;
+		}
+		duckdb_destroy_result(&result);
+	}
+	AppenderWrapper appender(connection, schema, table_name);
+	if (!appender.Valid()) {
 		return ADBC_STATUS_INTERNAL;
-	} catch (...) {
-		return ADBC_STATUS_INTERNAL;
+	}
+	duckdb::ArrowArrayWrapper arrow_array_wrapper;
+
+	input->get_next(input, &arrow_array_wrapper.arrow_array);
+	while (arrow_array_wrapper.arrow_array.release) {
+		DataChunkWrapper out_chunk;
+		auto res = duckdb_data_chunk_from_arrow(connection, &arrow_array_wrapper.arrow_array, out_types.Get(),
+		                                        &out_chunk.chunk);
+		if (res) {
+			SetError(error, duckdb_error_data_message(res));
+			duckdb_destroy_error_data(&res);
+		}
+		if (duckdb_append_data_chunk(appender.Get(), out_chunk.chunk) != DuckDBSuccess) {
+			return ADBC_STATUS_INTERNAL;
+		}
+		arrow_array_wrapper = duckdb::ArrowArrayWrapper();
+		input->get_next(input, &arrow_array_wrapper.arrow_array);
 	}
 	return ADBC_STATUS_OK;
 }
@@ -675,7 +730,6 @@ AdbcStatusCode StatementNew(struct AdbcConnection *connection, struct AdbcStatem
 
 	statement_wrapper->connection = conn_wrapper->connection;
 	statement_wrapper->statement = nullptr;
-	statement_wrapper->result = nullptr;
 	statement_wrapper->ingestion_stream.release = nullptr;
 	statement_wrapper->ingestion_table_name = nullptr;
 	statement_wrapper->db_schema = nullptr;
@@ -694,10 +748,6 @@ AdbcStatusCode StatementRelease(struct AdbcStatement *statement, struct AdbcErro
 		duckdb_destroy_prepare(&wrapper->statement);
 		wrapper->statement = nullptr;
 	}
-	if (wrapper->result) {
-		duckdb_destroy_arrow(&wrapper->result);
-		wrapper->result = nullptr;
-	}
 	if (wrapper->ingestion_stream.release) {
 		wrapper->ingestion_stream.release(&wrapper->ingestion_stream);
 		wrapper->ingestion_stream.release = nullptr;
@@ -732,35 +782,44 @@ AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, stru
 	auto wrapper = static_cast<DuckDBAdbcStatementWrapper *>(statement->private_data);
 	// TODO: we might want to cache this, but then we need to return a deep copy anyway, so I'm not sure if that
 	// would be worth the extra management
-	auto res = duckdb_prepared_arrow_schema(wrapper->statement, reinterpret_cast<duckdb_arrow_schema *>(&schema));
-	if (res != DuckDBSuccess) {
+
+	auto prepared_wrapper = reinterpret_cast<PreparedStatementWrapper *>(wrapper->statement);
+	if (!prepared_wrapper || !prepared_wrapper->statement || !prepared_wrapper->statement->data) {
+		SetError(error, "Invalid prepared statement wrapper");
 		return ADBC_STATUS_INVALID_ARGUMENT;
 	}
-	return ADBC_STATUS_OK;
-}
+	auto count = prepared_wrapper->statement->data->properties.parameter_count;
+	if (count == 0) {
+		count = 1;
+	}
+	std::vector<duckdb_logical_type> types(count);
+	std::vector<std::string> owned_names(count);
+	duckdb::vector<const char *> names(count);
 
-AdbcStatusCode GetPreparedParameters(duckdb_connection connection, duckdb::unique_ptr<duckdb::QueryResult> &result,
-                                     ArrowArrayStream *input, AdbcError *error) {
-
-	auto cconn = reinterpret_cast<duckdb::Connection *>(connection);
-
-	try {
-		auto arrow_scan =
-		    cconn->TableFunction("arrow_scan", {duckdb::Value::POINTER(reinterpret_cast<uintptr_t>(input)),
-		                                        duckdb::Value::POINTER(reinterpret_cast<uintptr_t>(stream_produce)),
-		                                        duckdb::Value::POINTER(reinterpret_cast<uintptr_t>(stream_schema))});
-		result = arrow_scan->Execute();
-		// After creating a table, the arrow array stream is released. Hence we must set it as released to avoid
-		// double-releasing it
-		input->release = nullptr;
-	} catch (std::exception &ex) {
-		if (error) {
-			::duckdb::ErrorData parsed_error(ex);
-			error->message = strdup(parsed_error.RawMessage().c_str());
-		}
-		return ADBC_STATUS_INTERNAL;
-	} catch (...) {
{ - return ADBC_STATUS_INTERNAL; + for (idx_t i = 0; i < count; i++) { + // FIXME: we don't support named parameters yet, but when we do, this needs to be updated + // Every prepared parameter type is UNKNOWN, which we need to map to NULL according to the spec of + // 'AdbcStatementGetParameterSchema' + types[i] = duckdb_create_logical_type(DUCKDB_TYPE_SQLNULL); + auto column_name = std::to_string(i); + owned_names.emplace_back(column_name); + names[i] = owned_names.back().c_str(); + } + + duckdb_arrow_options arrow_options; + duckdb_connection_get_arrow_options(wrapper->connection, &arrow_options); + + auto res = duckdb_to_arrow_schema(arrow_options, &types[0], names.data(), count, schema); + + for (auto &type : types) { + duckdb_destroy_logical_type(&type); + } + duckdb_destroy_arrow_options(&arrow_options); + + if (res) { + SetError(error, duckdb_error_data_message(res)); + duckdb_destroy_error_data(&res); + return ADBC_STATUS_INVALID_ARGUMENT; } return ADBC_STATUS_OK; } @@ -772,7 +831,6 @@ static AdbcStatusCode IngestToTableFromBoundStream(DuckDBAdbcStatementWrapper *s // Take the input stream from the statement auto stream = statement->ingestion_stream; - statement->ingestion_stream.release = nullptr; // Ingest into a table from the bound stream return Ingest(statement->connection, statement->ingestion_table_name, statement->db_schema, &stream, error, @@ -802,34 +860,61 @@ AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct Arr if (has_stream && to_table) { return IngestToTableFromBoundStream(wrapper, error); } + auto stream_wrapper = static_cast(malloc(sizeof(DuckDBAdbcStreamWrapper))); if (has_stream) { // A stream was bound to the statement, use that to bind parameters - duckdb::unique_ptr result; ArrowArrayStream stream = wrapper->ingestion_stream; - wrapper->ingestion_stream.release = nullptr; - auto adbc_res = GetPreparedParameters(wrapper->connection, result, &stream, error); - if (adbc_res != ADBC_STATUS_OK) { - return adbc_res; - } - if (!result) { - return ADBC_STATUS_INVALID_ARGUMENT; + ConvertedSchemaWrapper out_types; + duckdb::ArrowSchemaWrapper arrow_schema_wrapper; + stream.get_schema(&stream, &arrow_schema_wrapper.arrow_schema); + try { + auto res = + duckdb_schema_from_arrow(wrapper->connection, &arrow_schema_wrapper.arrow_schema, out_types.GetPtr()); + if (res) { + SetError(error, duckdb_error_data_message(res)); + duckdb_destroy_error_data(&res); + } + } catch (...) 
{ + free(stream_wrapper); + return ADBC_STATUS_INTERNAL; } - duckdb::unique_ptr chunk; auto prepared_statement_params = reinterpret_cast(wrapper->statement)->statement->named_param_map.size(); - while ((chunk = result->Fetch()) != nullptr) { + duckdb::ArrowArrayWrapper arrow_array_wrapper; + + stream.get_next(&stream, &arrow_array_wrapper.arrow_array); + + while (arrow_array_wrapper.arrow_array.release) { + // This is a valid arrow array, let's make it into a data chunk + DataChunkWrapper out_chunk; + auto res_conv = duckdb_data_chunk_from_arrow(wrapper->connection, &arrow_array_wrapper.arrow_array, + out_types.Get(), &out_chunk.chunk); + if (res_conv) { + SetError(error, duckdb_error_data_message(res_conv)); + duckdb_destroy_error_data(&res_conv); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (!out_chunk.chunk) { + SetError(error, "Please provide a non-empty chunk to be bound"); + free(stream_wrapper); + return ADBC_STATUS_INVALID_ARGUMENT; + } + auto chunk = reinterpret_cast(out_chunk.chunk); if (chunk->size() == 0) { SetError(error, "Please provide a non-empty chunk to be bound"); + free(stream_wrapper); return ADBC_STATUS_INVALID_ARGUMENT; } if (chunk->size() != 1) { // TODO: add support for binding multiple rows SetError(error, "Binding multiple rows at once is not supported yet"); + free(stream_wrapper); return ADBC_STATUS_NOT_IMPLEMENTED; } if (chunk->ColumnCount() > prepared_statement_params) { SetError(error, "Input data has more column than prepared statement has parameters"); + free(stream_wrapper); return ADBC_STATUS_INVALID_ARGUMENT; } duckdb_clear_bindings(wrapper->statement); @@ -839,34 +924,35 @@ AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct Arr auto res = duckdb_bind_value(wrapper->statement, 1 + col_idx, duck_val); if (res != DuckDBSuccess) { SetError(error, duckdb_prepare_error(wrapper->statement)); + free(stream_wrapper); return ADBC_STATUS_INVALID_ARGUMENT; } } - - auto res = duckdb_execute_prepared_arrow(wrapper->statement, &wrapper->result); + auto res = duckdb_execute_prepared(wrapper->statement, &stream_wrapper->result); if (res != DuckDBSuccess) { - SetError(error, duckdb_query_arrow_error(wrapper->result)); + SetError(error, duckdb_result_error(&stream_wrapper->result)); + free(stream_wrapper); return ADBC_STATUS_INVALID_ARGUMENT; } + // Recreate wrappers for next iteration + arrow_array_wrapper = duckdb::ArrowArrayWrapper(); + stream.get_next(&stream, &arrow_array_wrapper.arrow_array); } } else { - auto res = duckdb_execute_prepared_arrow(wrapper->statement, &wrapper->result); + auto res = duckdb_execute_prepared(wrapper->statement, &stream_wrapper->result); if (res != DuckDBSuccess) { - SetError(error, duckdb_query_arrow_error(wrapper->result)); + SetError(error, duckdb_result_error(&stream_wrapper->result)); return ADBC_STATUS_INVALID_ARGUMENT; } } if (out) { - out->private_data = wrapper->result; + // We pass ownership of the statement private data to our stream + out->private_data = stream_wrapper; out->get_schema = get_schema; out->get_next = get_next; out->release = release; out->get_last_error = get_last_error; - - // because we handed out the stream pointer its no longer our responsibility to destroy it in - // AdbcStatementRelease, this is now done in release() - wrapper->result = nullptr; } return ADBC_STATUS_OK; @@ -1309,19 +1395,34 @@ AdbcStatusCode ConnectionGetObjects(struct AdbcConnection *connection, int depth ), constraints AS ( SELECT - table_catalog, - table_schema, + database_name AS table_catalog, + 
schema_name AS table_schema, table_name, - LIST( - { - constraint_name: constraint_name, - constraint_type: constraint_type, - constraint_column_names: []::VARCHAR[], - constraint_column_usage: []::STRUCT(fk_catalog VARCHAR, fk_db_schema VARCHAR, fk_table VARCHAR, fk_column_name VARCHAR)[], - } - ) table_constraints - FROM information_schema.table_constraints - GROUP BY table_catalog, table_schema, table_name + LIST({ + constraint_name: constraint_name, + constraint_type: constraint_type, + constraint_column_names: constraint_column_names, + constraint_column_usage: list_transform( + referenced_column_names, + lambda name: { + fk_catalog: database_name, + fk_db_schema: schema_name, + fk_table: referenced_table, + fk_column_name: name, + } + ) + }) table_constraints + FROM duckdb_constraints() + WHERE + constraint_type NOT IN ('NOT NULL') AND + list_has_any( + constraint_column_names, + list_filter( + constraint_column_names, + lambda name: name LIKE '%s' + ) + ) + GROUP BY database_name, schema_name, table_name ), tables AS ( SELECT @@ -1365,8 +1466,8 @@ AdbcStatusCode ConnectionGetObjects(struct AdbcConnection *connection, int depth WHERE catalog_name LIKE '%s' GROUP BY catalog_name )", - column_name_filter, table_name_filter, table_type_condition, - db_schema_filter, catalog_filter); + column_name_filter, column_name_filter, table_name_filter, + table_type_condition, db_schema_filter, catalog_filter); break; default: SetError(error, "Invalid value of Depth"); diff --git a/src/duckdb/src/common/adbc/driver_manager.cpp b/src/duckdb/src/common/adbc/driver_manager.cpp index 9ac932380..45fb8c24d 100644 --- a/src/duckdb/src/common/adbc/driver_manager.cpp +++ b/src/duckdb/src/common/adbc/driver_manager.cpp @@ -51,9 +51,9 @@ void GetWinError(std::string *buffer) { DWORD rc = GetLastError(); LPVOID message; - FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, - /*lpSource=*/nullptr, rc, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - reinterpret_cast(&message), /*nSize=*/0, /*Arguments=*/nullptr); + FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + /*lpSource=*/nullptr, rc, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + reinterpret_cast(&message), /*nSize=*/0, /*Arguments=*/nullptr); (*buffer) += '('; (*buffer) += std::to_string(rc); diff --git a/src/duckdb/src/common/arrow/appender/append_data.cpp b/src/duckdb/src/common/arrow/appender/append_data.cpp new file mode 100644 index 000000000..06ccbc1ad --- /dev/null +++ b/src/duckdb/src/common/arrow/appender/append_data.cpp @@ -0,0 +1,29 @@ +#include "duckdb/common/arrow/appender/append_data.hpp" + +namespace duckdb { + +void ArrowAppendData::AppendValidity(UnifiedVectorFormat &format, idx_t from, idx_t to) { + // resize the buffer, filling the validity buffer with all valid values + idx_t size = to - from; + ResizeValidity(GetValidityBuffer(), row_count + size); + if (format.validity.AllValid()) { + // if all values are valid we don't need to do anything else + return; + } + + // otherwise we iterate through the validity mask + auto validity_data = (uint8_t *)GetValidityBuffer().data(); + uint8_t current_bit; + idx_t current_byte; + GetBitPosition(row_count, current_byte, current_bit); + for (idx_t i = from; i < to; i++) { + auto source_idx = format.sel->get_index(i); + // append the validity mask + if (!format.validity.RowIsValid(source_idx)) { + SetNull(validity_data, current_byte, current_bit); + } + NextBit(current_byte, current_bit); 
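// (Editor's note, not part of the upstream diff: Arrow validity bitmaps pack
// one bit per row, least-significant bit first within each byte, with 1 = valid
// and 0 = null. The helpers used above presumably track byte = row / 8 and
// bit = row % 8, so SetNull clears the current row's bit and NextBit advances
// the position, spilling into the next byte every eight rows.)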
+ } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/appender/bool_data.cpp b/src/duckdb/src/common/arrow/appender/bool_data.cpp index 78befb603..798b71e44 100644 --- a/src/duckdb/src/common/arrow/appender/bool_data.cpp +++ b/src/duckdb/src/common/arrow/appender/bool_data.cpp @@ -6,7 +6,6 @@ namespace duckdb { void ArrowBoolData::Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { auto byte_count = (capacity + 7) / 8; result.GetMainBuffer().reserve(byte_count); - (void)AppendValidity; // silence a compiler warning about unused static function } void ArrowBoolData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { @@ -16,24 +15,24 @@ void ArrowBoolData::Append(ArrowAppendData &append_data, Vector &input, idx_t fr auto &main_buffer = append_data.GetMainBuffer(); auto &validity_buffer = append_data.GetValidityBuffer(); // we initialize both the validity and the bit set to 1's - ResizeValidity(validity_buffer, append_data.row_count + size); - ResizeValidity(main_buffer, append_data.row_count + size); + ArrowAppendData::ResizeValidity(validity_buffer, append_data.row_count + size); + ArrowAppendData::ResizeValidity(main_buffer, append_data.row_count + size); auto data = UnifiedVectorFormat::GetData(format); auto result_data = main_buffer.GetData(); auto validity_data = validity_buffer.GetData(); uint8_t current_bit; idx_t current_byte; - GetBitPosition(append_data.row_count, current_byte, current_bit); + ArrowAppendData::GetBitPosition(append_data.row_count, current_byte, current_bit); for (idx_t i = from; i < to; i++) { auto source_idx = format.sel->get_index(i); // append the validity mask if (!format.validity.RowIsValid(source_idx)) { - SetNull(append_data, validity_data, current_byte, current_bit); + append_data.SetNull(validity_data, current_byte, current_bit); } else if (!data[source_idx]) { - UnsetBit(result_data, current_byte, current_bit); + ArrowAppendData::UnsetBit(result_data, current_byte, current_bit); } - NextBit(current_byte, current_bit); + ArrowAppendData::NextBit(current_byte, current_bit); } append_data.row_count += size; } diff --git a/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp b/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp index 172144fd3..a8cbc16d9 100644 --- a/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp +++ b/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp @@ -18,7 +18,7 @@ void ArrowFixedSizeListData::Append(ArrowAppendData &append_data, Vector &input, UnifiedVectorFormat format; input.ToUnifiedFormat(input_size, format); idx_t size = to - from; - AppendValidity(append_data, format, from, to); + append_data.AppendValidity(format, from, to); input.Flatten(input_size); auto array_size = ArrayType::GetSize(input.GetType()); auto &child_vector = ArrayVector::GetEntry(input); diff --git a/src/duckdb/src/common/arrow/appender/struct_data.cpp b/src/duckdb/src/common/arrow/appender/struct_data.cpp index b2afa62d1..28cee72a9 100644 --- a/src/duckdb/src/common/arrow/appender/struct_data.cpp +++ b/src/duckdb/src/common/arrow/appender/struct_data.cpp @@ -18,7 +18,7 @@ void ArrowStructData::Append(ArrowAppendData &append_data, Vector &input, idx_t UnifiedVectorFormat format; input.ToUnifiedFormat(input_size, format); idx_t size = to - from; - AppendValidity(append_data, format, from, to); + append_data.AppendValidity(format, from, to); // append the children of the struct auto &children = StructVector::GetEntries(input); for 
(idx_t child_idx = 0; child_idx < children.size(); child_idx++) { diff --git a/src/duckdb/src/common/arrow/appender/union_data.cpp b/src/duckdb/src/common/arrow/appender/union_data.cpp index 1e9f4f432..4ca4ebf67 100644 --- a/src/duckdb/src/common/arrow/appender/union_data.cpp +++ b/src/duckdb/src/common/arrow/appender/union_data.cpp @@ -14,7 +14,6 @@ void ArrowUnionData::Initialize(ArrowAppendData &result, const LogicalType &type auto child_buffer = ArrowAppender::InitializeChild(child.second, capacity, result.options); result.child_data.push_back(std::move(child_buffer)); } - (void)AppendValidity; // silence a compiler warning about unused static function } void ArrowUnionData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { diff --git a/src/duckdb/src/common/arrow/arrow_appender.cpp b/src/duckdb/src/common/arrow/arrow_appender.cpp index 344464e8e..8b455b585 100644 --- a/src/duckdb/src/common/arrow/arrow_appender.cpp +++ b/src/duckdb/src/common/arrow/arrow_appender.cpp @@ -215,21 +215,21 @@ static void InitializeFunctionPointers(ArrowAppendData &append_data, const Logic case LogicalTypeId::DECIMAL: switch (type.InternalType()) { case PhysicalType::INT16: - if (append_data.options.arrow_output_version > 14) { + if (append_data.options.arrow_output_version > ArrowFormatVersion::V1_4) { InitializeAppenderForType>(append_data); } else { InitializeAppenderForType>(append_data); } break; case PhysicalType::INT32: - if (append_data.options.arrow_output_version > 14) { + if (append_data.options.arrow_output_version > ArrowFormatVersion::V1_4) { InitializeAppenderForType>(append_data); } else { InitializeAppenderForType>(append_data); } break; case PhysicalType::INT64: - if (append_data.options.arrow_output_version > 14) { + if (append_data.options.arrow_output_version > ArrowFormatVersion::V1_4) { InitializeAppenderForType>(append_data); } else { InitializeAppenderForType>(append_data); @@ -245,9 +245,9 @@ static void InitializeFunctionPointers(ArrowAppendData &append_data, const Logic case LogicalTypeId::VARCHAR: case LogicalTypeId::BLOB: case LogicalTypeId::BIT: - case LogicalTypeId::VARINT: + case LogicalTypeId::BIGNUM: if ((append_data.options.produce_arrow_string_view || type.id() != LogicalTypeId::VARCHAR) && - append_data.options.arrow_output_version >= 14) { + append_data.options.arrow_output_version >= ArrowFormatVersion::V1_4) { InitializeAppenderForType(append_data); } else { if (append_data.options.arrow_offset_size == ArrowOffsetSize::LARGE) { @@ -285,7 +285,8 @@ static void InitializeFunctionPointers(ArrowAppendData &append_data, const Logic InitializeAppenderForType(append_data); break; case LogicalTypeId::LIST: { - if (append_data.options.arrow_use_list_view && append_data.options.arrow_output_version >= 14) { + if (append_data.options.arrow_use_list_view && + append_data.options.arrow_output_version >= ArrowFormatVersion::V1_4) { if (append_data.options.arrow_offset_size == ArrowOffsetSize::LARGE) { InitializeAppenderForType>(append_data); } else { diff --git a/src/duckdb/src/common/arrow/arrow_converter.cpp b/src/duckdb/src/common/arrow/arrow_converter.cpp index 8cded261e..d5acf3698 100644 --- a/src/duckdb/src/common/arrow/arrow_converter.cpp +++ b/src/duckdb/src/common/arrow/arrow_converter.cpp @@ -39,6 +39,8 @@ static void ReleaseDuckDBArrowSchema(ArrowSchema *schema) { } schema->release = nullptr; auto holder = static_cast(schema->private_data); + schema->private_data = nullptr; + delete holder; } @@ -173,7 +175,7 @@ void
SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, co if (options.arrow_lossless_conversion) { SetArrowExtension(root_holder, child, type, context); } else { - if (options.produce_arrow_string_view && options.arrow_output_version >= 14) { + if (options.produce_arrow_string_view && options.arrow_output_version >= ArrowFormatVersion::V1_4) { // List views are only introduced in arrow format v1.4 child.format = "vu"; } else { @@ -187,7 +189,7 @@ void SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, co break; } case LogicalTypeId::VARCHAR: - if (options.produce_arrow_string_view && options.arrow_output_version >= 14) { + if (options.produce_arrow_string_view && options.arrow_output_version >= ArrowFormatVersion::V1_4) { // List views are only introduced in arrow format v1.4 child.format = "vu"; } else { @@ -235,7 +237,7 @@ void SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, co break; case LogicalTypeId::DECIMAL: { uint8_t width, scale, bit_width; - if (options.arrow_output_version <= 14) { + if (options.arrow_output_version <= ArrowFormatVersion::V1_4) { // Before version 1.4 all decimals were int128 bit_width = 128; } else { @@ -266,7 +268,7 @@ void SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, co break; } case LogicalTypeId::BLOB: - if (options.arrow_output_version >= 14) { + if (options.arrow_output_version >= ArrowFormatVersion::V1_4) { // Views are only introduced in arrow format v1.4 child.format = "vz"; } else if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { @@ -279,7 +281,7 @@ void SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, co if (options.arrow_lossless_conversion) { SetArrowExtension(root_holder, child, type, context); } else { - if (options.arrow_output_version >= 14) { + if (options.arrow_output_version >= ArrowFormatVersion::V1_4) { // Views are only introduced in arrow format v1.4 child.format = "vz"; } else if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { @@ -292,7 +294,7 @@ void SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, co break; } case LogicalTypeId::LIST: { - if (options.arrow_use_list_view && options.arrow_output_version >= 14) { + if (options.arrow_use_list_view && options.arrow_output_version >= ArrowFormatVersion::V1_4) { // List views are only introduced in arrow format v1.4 if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { child.format = "+vL"; diff --git a/src/duckdb/src/common/arrow/arrow_type_extension.cpp b/src/duckdb/src/common/arrow/arrow_type_extension.cpp index 9e50e7c25..93979cd36 100644 --- a/src/duckdb/src/common/arrow/arrow_type_extension.cpp +++ b/src/duckdb/src/common/arrow/arrow_type_extension.cpp @@ -316,16 +316,16 @@ struct ArrowBit { } }; -struct ArrowVarint { +struct ArrowBignum { static unique_ptr GetType(const ArrowSchema &schema, const ArrowSchemaMetadata &schema_metadata) { const auto format = string(schema.format); if (format == "z") { - return make_uniq(LogicalType::VARINT, make_uniq(ArrowVariableSizeType::NORMAL)); + return make_uniq(LogicalType::BIGNUM, make_uniq(ArrowVariableSizeType::NORMAL)); } else if (format == "Z") { - return make_uniq(LogicalType::VARINT, + return make_uniq(LogicalType::BIGNUM, make_uniq(ArrowVariableSizeType::SUPER_SIZE)); } - throw InvalidInputException("Arrow extension type \"%s\" not supported for Varint", format.c_str()); + throw InvalidInputException("Arrow extension type \"%s\" not supported for Bignum", format.c_str()); } 
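// Editor's sketch, not part of the upstream diff: the two formats accepted by
// GetType above are Arrow's binary ("z", 32-bit offsets) and large binary
// ("Z", 64-bit offsets), which is why they map onto NORMAL and SUPER_SIZE.
// A hypothetical helper that makes the mapping explicit (the name
// IsSuperSized is illustrative only):
static bool IsSuperSized(const string &format) {
	// Large binary ("Z") carries 64-bit offsets and needs the SUPER_SIZE path.
	return format == "Z";
}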
static void PopulateSchema(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &schema, const LogicalType &type, @@ -387,7 +387,7 @@ void ArrowTypeExtensionSet::Initialize(const DBConfig &config) { config.RegisterArrowExtension({"DuckDB", "bit", &ArrowBit::PopulateSchema, &ArrowBit::GetType, make_shared_ptr(LogicalType::BIT), nullptr, nullptr}); - config.RegisterArrowExtension({"DuckDB", "varint", &ArrowVarint::PopulateSchema, &ArrowVarint::GetType, - make_shared_ptr(LogicalType::VARINT), nullptr, nullptr}); + config.RegisterArrowExtension({"DuckDB", "bignum", &ArrowBignum::PopulateSchema, &ArrowBignum::GetType, + make_shared_ptr(LogicalType::BIGNUM), nullptr, nullptr}); } } // namespace duckdb diff --git a/src/duckdb/src/common/bignum.cpp b/src/duckdb/src/common/bignum.cpp new file mode 100644 index 000000000..2bf228586 --- /dev/null +++ b/src/duckdb/src/common/bignum.cpp @@ -0,0 +1,362 @@ +#include "duckdb/common/bignum.hpp" +#include "duckdb/common/types/bignum.hpp" +#include + +namespace duckdb { +void PrintBits(const char value) { + for (int i = 7; i >= 0; --i) { + std::cout << ((value >> i) & 1); + } +} + +void bignum_t::Print() const { + auto ptr = data.GetData(); + auto length = data.GetSize(); + for (idx_t i = 0; i < length; ++i) { + PrintBits(ptr[i]); + std::cout << " "; + } + std::cout << '\n'; +} + +void BignumIntermediate::Print() const { + for (idx_t i = 0; i < size; ++i) { + PrintBits(static_cast(data[i])); + std::cout << " "; + } + std::cout << '\n'; +} + +BignumIntermediate::BignumIntermediate(const bignum_t &value) { + is_negative = (value.data.GetData()[0] & 0x80) == 0; + data = reinterpret_cast(value.data.GetDataWriteable() + Bignum::BIGNUM_HEADER_SIZE); + size = static_cast(value.data.GetSize()) - Bignum::BIGNUM_HEADER_SIZE; +} + +BignumIntermediate::BignumIntermediate(uint8_t *value, idx_t ptr_size) { + is_negative = (value[0] & 0x80) == 0; + data = value + Bignum::BIGNUM_HEADER_SIZE; + size = static_cast(ptr_size) - Bignum::BIGNUM_HEADER_SIZE; +} + +uint8_t BignumIntermediate::GetAbsoluteByte(int64_t index) const { + if (index < 0) { + // byte-extension + return 0; + } + return is_negative ? static_cast(~data[index]) : static_cast(data[index]); +} + +AbsoluteNumberComparison BignumIntermediate::IsAbsoluteBigger(const BignumIntermediate &rhs) const { + idx_t actual_start_pos = GetStartDataPos(); + idx_t actual_size = size - actual_start_pos; + + idx_t rhs_actual_start_pos = rhs.GetStartDataPos(); + idx_t rhs_actual_size = rhs.size - rhs_actual_start_pos; + + // we have opposing signs, gotta do a bunch of checks to figure out who is the biggest + // check sizes + if (actual_size > rhs_actual_size) { + return GREATER; + } + if (actual_size < rhs_actual_size) { + return SMALLER; + } else { + // they have the same size then + idx_t target_idx = actual_start_pos; + idx_t source_idx = rhs_actual_start_pos; + while (target_idx < size) { + auto data_byte = GetAbsoluteByte(static_cast(target_idx)); + auto rhs_byte = rhs.GetAbsoluteByte(static_cast(source_idx)); + if (data_byte > rhs_byte) { + return GREATER; + } else if (data_byte < rhs_byte) { + return SMALLER; + } + target_idx++; + source_idx++; + } + } + // If we got here, the values are equal. 
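// (Editor's illustration, not part of the upstream diff: for two non-negative
// operands of equal trimmed size, e.g. {0x01, 0xFF} vs {0x02, 0x00}, the loop
// above returns SMALLER at the first differing byte because 0x01 < 0x02; the
// remaining bytes are never inspected.)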
+ return EQUAL; +} + +bool BignumIntermediate::IsMSBSet() const { + if (is_negative) { + return (data[0] & 0x80) == 0; + } + return (data[0] & 0x80) != 0; +} +void BignumIntermediate::Initialize(ArenaAllocator &allocator) { + is_negative = false; + size = 1; + data = allocator.Allocate(size); + // initialize the data + data[0] = 0; +} + +uint32_t BignumIntermediate::GetStartDataPos(data_ptr_t data, idx_t size, bool is_negative) { + uint8_t non_initialized = is_negative ? 0xFF : 0x00; + uint32_t actual_start = 0; + for (idx_t i = 0; i < size; ++i) { + if (data[i] == non_initialized) { + actual_start++; + } else { + break; + } + } + return actual_start; +} + +uint32_t BignumIntermediate::GetStartDataPos() const { + return GetStartDataPos(data, size, is_negative); +} + +void BignumIntermediate::Reallocate(ArenaAllocator &allocator, idx_t min_size) { + if (min_size < size) { + return; + } + uint32_t new_size = size; + while (new_size <= min_size) { + new_size *= 2; + } + auto new_data = allocator.Allocate(new_size); + // Then we initialize to 0's until we have valid data again + memset(new_data, is_negative ? 0xFF : 0x00, new_size - size); + // Copy the old data to the new data + memcpy(new_data + new_size - size, data, size); + // Set size and pointer + data = new_data; + size = new_size; +} + +idx_t BignumIntermediate::Trim(data_ptr_t data, uint32_t &size, bool is_negative) { + auto actual_start = GetStartDataPos(data, size, is_negative); + if (actual_start == 0) { + return 0; + } + // This bad-boy is wearing shoe lifts, time to prune it. + D_ASSERT(actual_start <= size); + size -= actual_start; + if (size == 0) { + // Always keep at least one byte + actual_start = 0; + size++; + } + memmove(data, data + actual_start, size); + return actual_start; +} + +void BignumIntermediate::Trim() { + Trim(data, size, is_negative); +} + +bool BignumIntermediate::OverOrUnderflow(data_ptr_t data, idx_t size, bool is_negative) { + if (size <= Bignum::MAX_DATA_SIZE) { + return false; + } + // variable that stores a fully unset byte can safely be ignored + uint8_t byte_to_compare = is_negative ? 0xFF : 0x00; + // we will basically check if any byte has any set bit up to Bignum::MAX_DATA_SIZE, if so, that's an under/overflow + idx_t data_pos = 0; + for (idx_t i = size; i > Bignum::MAX_DATA_SIZE; i--) { + if (data[data_pos++] != byte_to_compare) { + return true; + } + } + return false; +} + +bool BignumIntermediate::OverOrUnderflow() const { + return OverOrUnderflow(data, size, is_negative); +} + +bignum_t BignumIntermediate::ToBignum(ArenaAllocator &allocator) { + // This must be trimmed before transforming + Trim(); + bignum_t result; + uint32_t bignum_size = Bignum::BIGNUM_HEADER_SIZE + size; + auto ptr = reinterpret_cast(allocator.Allocate(bignum_size)); + // Set Header + Bignum::SetHeader(ptr, size, is_negative); + // Copy data + memcpy(ptr + Bignum::BIGNUM_HEADER_SIZE, data, size); + result.data = string_t(ptr, bignum_size); + return result; +} + +void BignumAddition(data_ptr_t result, int64_t result_end, bool is_target_absolute_bigger, + const BignumIntermediate &lhs, const BignumIntermediate &rhs) { + bool is_result_negative = is_target_absolute_bigger ? 
lhs.is_negative : rhs.is_negative; + + int64_t i_target = lhs.size - 1; // last byte index in target + int64_t i_source = rhs.size - 1; // last byte index in source + int64_t i_result = result_end - 1; // last byte index in result + + // Carry for addition + uint16_t carry = 0; + uint16_t borrow = 0; + // Add bytes from right to left + while (i_result >= 0) { + // If the numbers are negative, we bit flip them + uint8_t target_byte = lhs.GetAbsoluteByte(i_target); + uint8_t source_byte = rhs.GetAbsoluteByte(i_source); + // Add bytes and carry + uint16_t sum; + if (lhs.is_negative == rhs.is_negative) { + sum = static_cast(target_byte) + static_cast(source_byte) + carry; + carry = (sum >> 8) & 0xFF; + } else { + if (is_target_absolute_bigger) { + sum = static_cast(target_byte) - static_cast(source_byte) - borrow; + borrow = sum > static_cast(target_byte) ? 1 : 0; + } else { + sum = static_cast(source_byte) - static_cast(target_byte) - borrow; + borrow = sum > static_cast(source_byte) ? 1 : 0; + } + } + uint8_t result_byte = static_cast(sum & 0xFF); + // If the result is not positive, we must flip the bits again + result[i_result] = is_result_negative ? ~result_byte : result_byte; + i_target--; + i_source--; + i_result--; + } + + if (is_result_negative != lhs.is_negative) { + // If we are flipping the sign we must be sure that we are flipping all extra bits from our target + for (int64_t i = 0; i < result_end - rhs.size; ++i) { + result[i] = is_result_negative ? 0xFF : 0x00; + } + } +} + +string_t BignumIntermediate::Negate(Vector &result_vector) const { + + auto target = StringVector::EmptyString(result_vector, size + Bignum::BIGNUM_HEADER_SIZE); + auto ptr = target.GetDataWriteable(); + + if (!is_negative && size == 1 && data[0] == 0x00) { + // If we have a zero, we just do a copy + Bignum::SetHeader(ptr, size, is_negative); + for (idx_t i = 0; i < size; ++i) { + ptr[i + Bignum::BIGNUM_HEADER_SIZE] = static_cast(data[i]); + } + } else { + // Otherwise, we set the header with a flip on the signal + Bignum::SetHeader(ptr, size, !is_negative); + for (idx_t i = 0; i < size; ++i) { + // And flip all the data bits + ptr[i + Bignum::BIGNUM_HEADER_SIZE] = static_cast(~data[i]); + } + } + + return target; +} + +void BignumIntermediate::NegateInPlace() { + if (!is_negative && size == 1 && data[0] == 0x00) { + // this is a zero, there is no negation + return; + } + is_negative = !is_negative; + for (size_t i = 0; i < size; i++) { + data[i] = ~data[i]; // flip each byte of the pointer + } +} + +string ProduceOverUnderFlowError(bool is_result_negative, idx_t actual_start, idx_t data_size) { + // We must throw an error, usually we should print the numbers, but I have a feeling that it won't be possible + // here. + std::ostringstream error; + if (is_result_negative) { + error << "Underflow "; + } else { + error << "Overflow "; + } + error << "in Bignum Operation. A Bignum can hold max " << Bignum::MAX_DATA_SIZE + << " data bytes. 
Current bignum has " << data_size - actual_start << " bytes."; + return error.str(); +} + +string_t BignumIntermediate::Add(Vector &result_vector, const BignumIntermediate &lhs, const BignumIntermediate &rhs) { + const bool same_sign = lhs.is_negative == rhs.is_negative; + const uint32_t actual_size = lhs.size - lhs.GetStartDataPos(); + const uint32_t actual_rhs_size = rhs.size - rhs.GetStartDataPos(); + uint32_t result_size = actual_size; + if (actual_size < actual_rhs_size || (same_sign && (lhs.IsMSBSet() || (rhs.IsMSBSet() && lhs.size == rhs.size)))) { + result_size = actual_size < actual_rhs_size ? actual_rhs_size + 1 : actual_size + 1; + } + bool is_target_absolute_bigger = true; + if (result_size == 0) { + result_size++; + } + result_size += Bignum::BIGNUM_HEADER_SIZE; + if (lhs.is_negative != rhs.is_negative) { + auto is_absolute_bigger = lhs.IsAbsoluteBigger(rhs); + if (is_absolute_bigger == EQUAL) { + // We set this value to 0 + auto target = StringVector::EmptyString(result_vector, result_size); + auto target_data = target.GetDataWriteable(); + Bignum::SetHeader(target_data, 1, false); + target_data[Bignum::BIGNUM_HEADER_SIZE] = 0; + return target; + + } else if (is_absolute_bigger == SMALLER) { + is_target_absolute_bigger = false; + } + } + + auto target = StringVector::EmptyString(result_vector, result_size); + auto result_size_data = result_size - Bignum::BIGNUM_HEADER_SIZE; + + auto target_data = target.GetDataWriteable(); + BignumAddition(reinterpret_cast(target_data + Bignum::BIGNUM_HEADER_SIZE), result_size_data, + is_target_absolute_bigger, lhs, rhs); + bool is_result_negative = is_target_absolute_bigger ? lhs.is_negative : rhs.is_negative; + if (OverOrUnderflow(reinterpret_cast(target_data + Bignum::BIGNUM_HEADER_SIZE), result_size_data, + is_result_negative)) { + auto actual_start = GetStartDataPos(reinterpret_cast(target_data + Bignum::BIGNUM_HEADER_SIZE), + result_size_data, is_result_negative); + throw OutOfRangeException(ProduceOverUnderFlowError(is_result_negative, actual_start, result_size_data)); + } + Trim(reinterpret_cast(target_data + Bignum::BIGNUM_HEADER_SIZE), result_size_data, is_result_negative); + Bignum::SetHeader(target_data, result_size_data, is_result_negative); + target.SetSizeAndFinalize(result_size_data + Bignum::BIGNUM_HEADER_SIZE); + return target; +} +void BignumIntermediate::AddInPlace(ArenaAllocator &allocator, const BignumIntermediate &rhs) { + const bool same_sign = is_negative == rhs.is_negative; + idx_t actual_size = size - GetStartDataPos(); + idx_t actual_rhs_size = rhs.size - rhs.GetStartDataPos(); + if (actual_size < actual_rhs_size || (same_sign && (IsMSBSet() || (rhs.IsMSBSet() && size == rhs.size)))) { + // We must reallocate + idx_t min_size = actual_size < actual_rhs_size ? actual_rhs_size + 1 : size + 1; + Reallocate(allocator, min_size); + } + bool is_target_absolute_bigger = true; + if (rhs.is_negative != is_negative) { + auto is_absolute_bigger = IsAbsoluteBigger(rhs); + if (is_absolute_bigger == EQUAL) { + // We set this value to 0 + *this = BignumIntermediate(); + Initialize(allocator); + return; + } else if (is_absolute_bigger == SMALLER) { + is_target_absolute_bigger = false; + } + } + + bool is_result_negative = is_target_absolute_bigger ? 
is_negative : rhs.is_negative; + BignumAddition(data, size, is_target_absolute_bigger, *this, rhs); + if (is_result_negative != is_negative) { + is_negative = is_result_negative; + } + if (OverOrUnderflow()) { + // We must throw an error, usually we should print the numbers, but I have a feeling that it won't be possible + // here. + throw OutOfRangeException(ProduceOverUnderFlowError(is_result_negative, GetStartDataPos(), size)); + } +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/enum_util.cpp b/src/duckdb/src/common/enum_util.cpp index 5257e9902..27c5e7472 100644 --- a/src/duckdb/src/common/enum_util.cpp +++ b/src/duckdb/src/common/enum_util.cpp @@ -15,8 +15,10 @@ #include "duckdb/common/box_renderer.hpp" #include "duckdb/common/enums/access_mode.hpp" #include "duckdb/common/enums/aggregate_handling.hpp" +#include "duckdb/common/enums/arrow_format_version.hpp" #include "duckdb/common/enums/catalog_lookup_behavior.hpp" #include "duckdb/common/enums/catalog_type.hpp" +#include "duckdb/common/enums/checkpoint_abort.hpp" #include "duckdb/common/enums/compression_type.hpp" #include "duckdb/common/enums/copy_overwrite_mode.hpp" #include "duckdb/common/enums/cte_materialize.hpp" @@ -44,6 +46,7 @@ #include "duckdb/common/enums/optimizer_type.hpp" #include "duckdb/common/enums/order_preservation_type.hpp" #include "duckdb/common/enums/order_type.hpp" +#include "duckdb/common/enums/ordinality_request_type.hpp" #include "duckdb/common/enums/output_type.hpp" #include "duckdb/common/enums/pending_execution_result.hpp" #include "duckdb/common/enums/physical_operator_type.hpp" @@ -112,11 +115,10 @@ #include "duckdb/function/table/arrow/enum/arrow_type_info_type.hpp" #include "duckdb/function/table/arrow/enum/arrow_variable_size_type.hpp" #include "duckdb/function/table_function.hpp" +#include "duckdb/function/window/window_merge_sort_tree.hpp" #include "duckdb/logging/logging.hpp" #include "duckdb/main/appender.hpp" #include "duckdb/main/capi/capi_internal.hpp" -#include "duckdb/main/client_properties.hpp" -#include "duckdb/main/config.hpp" #include "duckdb/main/error_manager.hpp" #include "duckdb/main/extension.hpp" #include "duckdb/main/extension_helper.hpp" @@ -124,7 +126,7 @@ #include "duckdb/main/query_profiler.hpp" #include "duckdb/main/query_result.hpp" #include "duckdb/main/secret/secret.hpp" -#include "duckdb/main/settings.hpp" +#include "duckdb/main/setting_info.hpp" #include "duckdb/parallel/interrupt.hpp" #include "duckdb/parallel/meta_pipeline.hpp" #include "duckdb/parallel/task.hpp" @@ -161,6 +163,7 @@ #include "duckdb/storage/statistics/base_statistics.hpp" #include "duckdb/storage/table/chunk_info.hpp" #include "duckdb/storage/table/column_segment.hpp" +#include "duckdb/storage/table/table_index_list.hpp" #include "duckdb/storage/temporary_file_manager.hpp" #include "duckdb/verification/statement_verifier.hpp" @@ -189,19 +192,20 @@ const StringUtil::EnumStringLiteral *GetARTHandlingResultValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(ARTHandlingResult::CONTINUE), "CONTINUE" }, { static_cast(ARTHandlingResult::SKIP), "SKIP" }, - { static_cast(ARTHandlingResult::YIELD), "YIELD" } + { static_cast(ARTHandlingResult::YIELD), "YIELD" }, + { static_cast(ARTHandlingResult::NONE), "NONE" } }; return values; } template<> const char* EnumUtil::ToChars(ARTHandlingResult value) { - return StringUtil::EnumToString(GetARTHandlingResultValues(), 3, "ARTHandlingResult", static_cast(value)); + return 
StringUtil::EnumToString(GetARTHandlingResultValues(), 4, "ARTHandlingResult", static_cast(value)); } template<> ARTHandlingResult EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetARTHandlingResultValues(), 3, "ARTHandlingResult", value)); + return static_cast(StringUtil::StringToEnum(GetARTHandlingResultValues(), 4, "ARTHandlingResult", value)); } const StringUtil::EnumStringLiteral *GetARTScanHandlingValues() { @@ -481,6 +485,25 @@ AppenderType EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetAppenderTypeValues(), 2, "AppenderType", value)); } +const StringUtil::EnumStringLiteral *GetArrowArrayPhysicalTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(ArrowArrayPhysicalType::DICTIONARY_ENCODED), "DICTIONARY_ENCODED" }, + { static_cast(ArrowArrayPhysicalType::RUN_END_ENCODED), "RUN_END_ENCODED" }, + { static_cast(ArrowArrayPhysicalType::DEFAULT), "DEFAULT" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(ArrowArrayPhysicalType value) { + return StringUtil::EnumToString(GetArrowArrayPhysicalTypeValues(), 3, "ArrowArrayPhysicalType", static_cast(value)); +} + +template<> +ArrowArrayPhysicalType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetArrowArrayPhysicalTypeValues(), 3, "ArrowArrayPhysicalType", value)); +} + const StringUtil::EnumStringLiteral *GetArrowDateTimeTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(ArrowDateTimeType::MILLISECONDS), "MILLISECONDS" }, @@ -504,6 +527,28 @@ ArrowDateTimeType EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetArrowDateTimeTypeValues(), 7, "ArrowDateTimeType", value)); } +const StringUtil::EnumStringLiteral *GetArrowFormatVersionValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(ArrowFormatVersion::V1_0), "1.0" }, + { static_cast(ArrowFormatVersion::V1_1), "1.1" }, + { static_cast(ArrowFormatVersion::V1_2), "1.2" }, + { static_cast(ArrowFormatVersion::V1_3), "1.3" }, + { static_cast(ArrowFormatVersion::V1_4), "1.4" }, + { static_cast(ArrowFormatVersion::V1_5), "1.5" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(ArrowFormatVersion value) { + return StringUtil::EnumToString(GetArrowFormatVersionValues(), 6, "ArrowFormatVersion", static_cast(value)); +} + +template<> +ArrowFormatVersion EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetArrowFormatVersionValues(), 6, "ArrowFormatVersion", value)); +} + const StringUtil::EnumStringLiteral *GetArrowOffsetSizeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(ArrowOffsetSize::REGULAR), "REGULAR" }, @@ -1726,19 +1771,20 @@ const StringUtil::EnumStringLiteral *GetExtraTypeInfoTypeValues() { { static_cast(ExtraTypeInfoType::AGGREGATE_STATE_TYPE_INFO), "AGGREGATE_STATE_TYPE_INFO" }, { static_cast(ExtraTypeInfoType::ARRAY_TYPE_INFO), "ARRAY_TYPE_INFO" }, { static_cast(ExtraTypeInfoType::ANY_TYPE_INFO), "ANY_TYPE_INFO" }, - { static_cast(ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO), "INTEGER_LITERAL_TYPE_INFO" } + { static_cast(ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO), "INTEGER_LITERAL_TYPE_INFO" }, + { static_cast(ExtraTypeInfoType::TEMPLATE_TYPE_INFO), "TEMPLATE_TYPE_INFO" } }; return values; } template<> const char* EnumUtil::ToChars(ExtraTypeInfoType value) { - return 
StringUtil::EnumToString(GetExtraTypeInfoTypeValues(), 12, "ExtraTypeInfoType", static_cast(value)); + return StringUtil::EnumToString(GetExtraTypeInfoTypeValues(), 13, "ExtraTypeInfoType", static_cast(value)); } template<> ExtraTypeInfoType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetExtraTypeInfoTypeValues(), 12, "ExtraTypeInfoType", value)); + return static_cast(StringUtil::StringToEnum(GetExtraTypeInfoTypeValues(), 13, "ExtraTypeInfoType", value)); } const StringUtil::EnumStringLiteral *GetFileBufferTypeValues() { @@ -2106,6 +2152,25 @@ IndexAppendMode EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetIndexAppendModeValues(), 3, "IndexAppendMode", value)); } +const StringUtil::EnumStringLiteral *GetIndexBindStateValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(IndexBindState::UNBOUND), "UNBOUND" }, + { static_cast(IndexBindState::BINDING), "BINDING" }, + { static_cast(IndexBindState::BOUND), "BOUND" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(IndexBindState value) { + return StringUtil::EnumToString(GetIndexBindStateValues(), 3, "IndexBindState", static_cast(value)); +} + +template<> +IndexBindState EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetIndexBindStateValues(), 3, "IndexBindState", value)); +} + const StringUtil::EnumStringLiteral *GetIndexConstraintTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(IndexConstraintType::NONE), "NONE" }, @@ -2456,6 +2521,7 @@ const StringUtil::EnumStringLiteral *GetLogicalTypeIdValues() { { static_cast(LogicalTypeId::UNKNOWN), "UNKNOWN" }, { static_cast(LogicalTypeId::ANY), "ANY" }, { static_cast(LogicalTypeId::USER), "USER" }, + { static_cast(LogicalTypeId::TEMPLATE), "TEMPLATE" }, { static_cast(LogicalTypeId::BOOLEAN), "BOOLEAN" }, { static_cast(LogicalTypeId::TINYINT), "TINYINT" }, { static_cast(LogicalTypeId::SMALLINT), "SMALLINT" }, @@ -2484,7 +2550,7 @@ const StringUtil::EnumStringLiteral *GetLogicalTypeIdValues() { { static_cast(LogicalTypeId::BIT), "BIT" }, { static_cast(LogicalTypeId::STRING_LITERAL), "STRING_LITERAL" }, { static_cast(LogicalTypeId::INTEGER_LITERAL), "INTEGER_LITERAL" }, - { static_cast(LogicalTypeId::VARINT), "VARINT" }, + { static_cast(LogicalTypeId::BIGNUM), "BIGNUM" }, { static_cast(LogicalTypeId::UHUGEINT), "UHUGEINT" }, { static_cast(LogicalTypeId::HUGEINT), "HUGEINT" }, { static_cast(LogicalTypeId::POINTER), "POINTER" }, @@ -2505,12 +2571,12 @@ const StringUtil::EnumStringLiteral *GetLogicalTypeIdValues() { template<> const char* EnumUtil::ToChars(LogicalTypeId value) { - return StringUtil::EnumToString(GetLogicalTypeIdValues(), 48, "LogicalTypeId", static_cast(value)); + return StringUtil::EnumToString(GetLogicalTypeIdValues(), 49, "LogicalTypeId", static_cast(value)); } template<> LogicalTypeId EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetLogicalTypeIdValues(), 48, "LogicalTypeId", value)); + return static_cast(StringUtil::StringToEnum(GetLogicalTypeIdValues(), 49, "LogicalTypeId", value)); } const StringUtil::EnumStringLiteral *GetLookupResultTypeValues() { @@ -2977,10 +3043,10 @@ const StringUtil::EnumStringLiteral *GetOrderByNullTypeValues() { { static_cast(OrderByNullType::INVALID), "INVALID" }, { static_cast(OrderByNullType::ORDER_DEFAULT), "ORDER_DEFAULT" }, { static_cast(OrderByNullType::ORDER_DEFAULT), "DEFAULT" }, - { 
static_cast(OrderByNullType::NULLS_FIRST), "NULLS_FIRST" }, { static_cast(OrderByNullType::NULLS_FIRST), "NULLS FIRST" }, - { static_cast(OrderByNullType::NULLS_LAST), "NULLS_LAST" }, - { static_cast(OrderByNullType::NULLS_LAST), "NULLS LAST" } + { static_cast(OrderByNullType::NULLS_FIRST), "NULLS_FIRST" }, + { static_cast(OrderByNullType::NULLS_LAST), "NULLS LAST" }, + { static_cast(OrderByNullType::NULLS_LAST), "NULLS_LAST" } }; return values; } @@ -3037,6 +3103,24 @@ OrderType EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetOrderTypeValues(), 7, "OrderType", value)); } +const StringUtil::EnumStringLiteral *GetOrdinalityTypeValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(OrdinalityType::WITHOUT_ORDINALITY), "WITHOUT_ORDINALITY" }, + { static_cast(OrdinalityType::WITH_ORDINALITY), "WITH_ORDINALITY" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(OrdinalityType value) { + return StringUtil::EnumToString(GetOrdinalityTypeValues(), 2, "OrdinalityType", static_cast(value)); +} + +template<> +OrdinalityType EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetOrdinalityTypeValues(), 2, "OrdinalityType", value)); +} + const StringUtil::EnumStringLiteral *GetOutputStreamValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(OutputStream::STREAM_STDOUT), "STREAM_STDOUT" }, @@ -3839,19 +3923,20 @@ SettingScope EnumUtil::FromString(const char *value) { const StringUtil::EnumStringLiteral *GetShowTypeValues() { static constexpr StringUtil::EnumStringLiteral values[] { { static_cast(ShowType::SUMMARY), "SUMMARY" }, - { static_cast(ShowType::DESCRIBE), "DESCRIBE" } + { static_cast(ShowType::DESCRIBE), "DESCRIBE" }, + { static_cast(ShowType::SHOW_FROM), "SHOW_FROM" } }; return values; } template<> const char* EnumUtil::ToChars(ShowType value) { - return StringUtil::EnumToString(GetShowTypeValues(), 2, "ShowType", static_cast(value)); + return StringUtil::EnumToString(GetShowTypeValues(), 3, "ShowType", static_cast(value)); } template<> ShowType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetShowTypeValues(), 2, "ShowType", value)); + return static_cast(StringUtil::StringToEnum(GetShowTypeValues(), 3, "ShowType", value)); } const StringUtil::EnumStringLiteral *GetSimplifiedTokenTypeValues() { @@ -4675,6 +4760,7 @@ const StringUtil::EnumStringLiteral *GetVerificationTypeValues() { { static_cast(VerificationType::NO_OPERATOR_CACHING), "NO_OPERATOR_CACHING" }, { static_cast(VerificationType::PREPARED), "PREPARED" }, { static_cast(VerificationType::EXTERNAL), "EXTERNAL" }, + { static_cast(VerificationType::EXPLAIN), "EXPLAIN" }, { static_cast(VerificationType::FETCH_ROW_AS_SCAN), "FETCH_ROW_AS_SCAN" }, { static_cast(VerificationType::INVALID), "INVALID" } }; @@ -4683,12 +4769,12 @@ const StringUtil::EnumStringLiteral *GetVerificationTypeValues() { template<> const char* EnumUtil::ToChars(VerificationType value) { - return StringUtil::EnumToString(GetVerificationTypeValues(), 10, "VerificationType", static_cast(value)); + return StringUtil::EnumToString(GetVerificationTypeValues(), 11, "VerificationType", static_cast(value)); } template<> VerificationType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetVerificationTypeValues(), 10, "VerificationType", value)); + return static_cast(StringUtil::StringToEnum(GetVerificationTypeValues(), 11, 
"VerificationType", value)); } const StringUtil::EnumStringLiteral *GetVerifyExistenceTypeValues() { @@ -4820,5 +4906,26 @@ WindowExcludeMode EnumUtil::FromString(const char *value) { return static_cast(StringUtil::StringToEnum(GetWindowExcludeModeValues(), 4, "WindowExcludeMode", value)); } +const StringUtil::EnumStringLiteral *GetWindowMergeSortStageValues() { + static constexpr StringUtil::EnumStringLiteral values[] { + { static_cast(WindowMergeSortStage::INIT), "INIT" }, + { static_cast(WindowMergeSortStage::COMBINE), "COMBINE" }, + { static_cast(WindowMergeSortStage::FINALIZE), "FINALIZE" }, + { static_cast(WindowMergeSortStage::SORTED), "SORTED" }, + { static_cast(WindowMergeSortStage::FINISHED), "FINISHED" } + }; + return values; +} + +template<> +const char* EnumUtil::ToChars(WindowMergeSortStage value) { + return StringUtil::EnumToString(GetWindowMergeSortStageValues(), 5, "WindowMergeSortStage", static_cast(value)); +} + +template<> +WindowMergeSortStage EnumUtil::FromString(const char *value) { + return static_cast(StringUtil::StringToEnum(GetWindowMergeSortStageValues(), 5, "WindowMergeSortStage", value)); +} + } diff --git a/src/duckdb/src/common/error_data.cpp b/src/duckdb/src/common/error_data.cpp index de6685ce9..2ddf94af6 100644 --- a/src/duckdb/src/common/error_data.cpp +++ b/src/duckdb/src/common/error_data.cpp @@ -91,6 +91,17 @@ const ExceptionType &ErrorData::Type() const { return this->type; } +void ErrorData::Merge(const ErrorData &other) { + if (!other.HasError()) { + return; + } + if (!HasError()) { + *this = other; + return; + } + final_message += "\n\n" + other.Message(); +} + bool ErrorData::operator==(const ErrorData &other) const { if (initialized != other.initialized) { return false; diff --git a/src/duckdb/src/common/exception_format_value.cpp b/src/duckdb/src/common/exception_format_value.cpp index ddef4e10c..a77ab7f38 100644 --- a/src/duckdb/src/common/exception_format_value.cpp +++ b/src/duckdb/src/common/exception_format_value.cpp @@ -15,6 +15,9 @@ ExceptionFormatValue::ExceptionFormatValue(double dbl_val) ExceptionFormatValue::ExceptionFormatValue(int64_t int_val) : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_INTEGER), int_val(int_val) { } +ExceptionFormatValue::ExceptionFormatValue(idx_t uint_val) + : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_INTEGER), int_val(Hugeint::Convert(uint_val)) { +} ExceptionFormatValue::ExceptionFormatValue(hugeint_t huge_val) : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING), str_val(Hugeint::ToString(huge_val)) { } @@ -68,6 +71,10 @@ ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(char *value) { return ExceptionFormatValue(string(value)); } template <> +ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(idx_t value) { + return ExceptionFormatValue(value); +} +template <> ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(hugeint_t value) { return ExceptionFormatValue(value); } diff --git a/src/duckdb/src/common/extra_type_info.cpp b/src/duckdb/src/common/extra_type_info.cpp index b9986f258..1d3160814 100644 --- a/src/duckdb/src/common/extra_type_info.cpp +++ b/src/duckdb/src/common/extra_type_info.cpp @@ -488,4 +488,23 @@ shared_ptr IntegerLiteralTypeInfo::Copy() const { return make_shared_ptr(*this); } +//===--------------------------------------------------------------------===// +// Template Type Info +//===--------------------------------------------------------------------===// +TemplateTypeInfo::TemplateTypeInfo() : 
ExtraTypeInfo(ExtraTypeInfoType::TEMPLATE_TYPE_INFO) { +} + +TemplateTypeInfo::TemplateTypeInfo(string name_p) + : ExtraTypeInfo(ExtraTypeInfoType::TEMPLATE_TYPE_INFO), name(std::move(name_p)) { +} + +bool TemplateTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { + auto &other = other_p->Cast(); + return name == other.name; +} + +shared_ptr TemplateTypeInfo::Copy() const { + return make_shared_ptr(*this); +} + } // namespace duckdb diff --git a/src/duckdb/src/common/file_buffer.cpp b/src/duckdb/src/common/file_buffer.cpp index cd2f627b6..8e108ddc1 100644 --- a/src/duckdb/src/common/file_buffer.cpp +++ b/src/duckdb/src/common/file_buffer.cpp @@ -103,9 +103,9 @@ void FileBuffer::Resize(BlockManager &block_manager) { ResizeInternal(block_manager.GetBlockSize(), block_manager.GetBlockHeaderSize()); } -void FileBuffer::Read(FileHandle &handle, uint64_t location) { +void FileBuffer::Read(QueryContext context, FileHandle &handle, uint64_t location) { D_ASSERT(type != FileBufferType::TINY_BUFFER); - handle.Read(internal_buffer, internal_size, location); + handle.Read(context, internal_buffer, internal_size, location); } void FileBuffer::Write(QueryContext context, FileHandle &handle, const uint64_t location) { diff --git a/src/duckdb/src/common/file_system.cpp b/src/duckdb/src/common/file_system.cpp index f661b2de5..a24d92206 100644 --- a/src/duckdb/src/common/file_system.cpp +++ b/src/duckdb/src/common/file_system.cpp @@ -14,6 +14,7 @@ #include "duckdb/main/extension_helper.hpp" #include "duckdb/common/windows_util.hpp" #include "duckdb/common/operator/multiply.hpp" +#include "duckdb/logging/log_manager.hpp" #include #include @@ -703,6 +704,11 @@ void FileHandle::Read(void *buffer, idx_t nr_bytes, idx_t location) { file_system.Read(*this, buffer, UnsafeNumericCast(nr_bytes), location); } +void FileHandle::Read(QueryContext context, void *buffer, idx_t nr_bytes, idx_t location) { + // FIXME: Add profiling. + file_system.Read(*this, buffer, UnsafeNumericCast(nr_bytes), location); +} + void FileHandle::Write(QueryContext context, void *buffer, idx_t nr_bytes, idx_t location) { // FIXME: Add profiling. 
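// (Editor's note, not part of the upstream diff: like the new Read overload
// above, this Write overload currently ignores the QueryContext and simply
// forwards to the FileSystem; the FIXME suggests the context parameter is
// being plumbed through now so per-query I/O profiling can be attached later.)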
file_system.Write(*this, buffer, UnsafeNumericCast(nr_bytes), location); diff --git a/src/duckdb/src/common/local_file_system.cpp b/src/duckdb/src/common/local_file_system.cpp index f1672aa8d..1f246d241 100644 --- a/src/duckdb/src/common/local_file_system.cpp +++ b/src/duckdb/src/common/local_file_system.cpp @@ -10,6 +10,7 @@ #include "duckdb/main/client_context.hpp" #include "duckdb/main/database.hpp" #include "duckdb/logging/file_system_logger.hpp" +#include "duckdb/logging/log_manager.hpp" #include #include @@ -589,7 +590,7 @@ timestamp_t LocalFileSystem::GetLastModifiedTime(FileHandle &handle) { throw IOException("Failed to get last modified time for file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, strerror(errno)); } - return Timestamp::FromTimeT(s.st_mtime); + return Timestamp::FromEpochSeconds(s.st_mtime); } FileType LocalFileSystem::GetFileType(FileHandle &handle) { @@ -794,12 +795,18 @@ std::string LocalFileSystem::GetLastErrorAsString() { if (errorMessageID == 0) return std::string(); // No error message has been recorded - LPSTR messageBuffer = nullptr; - idx_t size = - FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, - NULL, errorMessageID, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&messageBuffer, 0, NULL); + LPWSTR messageBuffer = nullptr; + idx_t size = FormatMessageW( + FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, + errorMessageID, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPWSTR)&messageBuffer, 0, NULL); - std::string message(messageBuffer, size); + if (size == 0) { + return std::string(); + } + + // Convert wide string to UTF-8 + std::wstring wideMessage(messageBuffer, size); + std::string message = WindowsUtil::UnicodeToUTF8(wideMessage.c_str()); // Free the buffer. 
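// (Editor's note, not part of the upstream diff: with
// FORMAT_MESSAGE_ALLOCATE_BUFFER, FormatMessageW allocates the message buffer
// via LocalAlloc, so the LocalFree below is the matching deallocation; that
// pairing is unchanged by the LPSTR -> LPWSTR switch.)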
LocalFree(messageBuffer); diff --git a/src/duckdb/src/common/multi_file/multi_file_column_mapper.cpp b/src/duckdb/src/common/multi_file/multi_file_column_mapper.cpp index 6f9b22695..c4537738d 100644 --- a/src/duckdb/src/common/multi_file/multi_file_column_mapper.cpp +++ b/src/duckdb/src/common/multi_file/multi_file_column_mapper.cpp @@ -1091,11 +1091,12 @@ unique_ptr MultiFileColumnMapper::CreateFilters(map(global_type, local_id); + expr = make_uniq(global_type, local_id); } } return result; diff --git a/src/duckdb/src/common/sort/partition_state.cpp b/src/duckdb/src/common/sort/partition_state.cpp index 51726df9c..2a0a65895 100644 --- a/src/duckdb/src/common/sort/partition_state.cpp +++ b/src/duckdb/src/common/sort/partition_state.cpp @@ -89,7 +89,7 @@ PartitionGlobalSinkState::PartitionGlobalSinkState(ClientContext &context, GenerateOrderings(partitions, orders, partition_bys, order_bys, partition_stats); memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context); - external = ClientConfig::GetConfig(context).GetSetting(context); + external = ClientConfig::GetConfig(context).force_external; const auto thread_pages = PreviousPowerOfTwo(memory_per_thread / (4 * buffer_manager.GetBlockAllocSize())); while (max_bits < 10 && (thread_pages >> max_bits) > 1) { diff --git a/src/duckdb/src/common/sorting/hashed_sort.cpp b/src/duckdb/src/common/sorting/hashed_sort.cpp new file mode 100644 index 000000000..613d4675b --- /dev/null +++ b/src/duckdb/src/common/sorting/hashed_sort.cpp @@ -0,0 +1,481 @@ +#include "duckdb/common/sorting/hashed_sort.hpp" +#include "duckdb/common/radix_partitioning.hpp" +#include "duckdb/parallel/thread_context.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// HashedSortGroup +//===--------------------------------------------------------------------===// +HashedSortGroup::HashedSortGroup(ClientContext &context, const Orders &orders, const Types &input_types, + idx_t group_idx) + : group_idx(group_idx), tasks_completed(0) { + vector projection_map; + sort = make_uniq(context, orders, input_types, projection_map); + sort_global = sort->GetGlobalSinkState(context); +} + +//===--------------------------------------------------------------------===// +// HashedSortGlobalSinkState +//===--------------------------------------------------------------------===// +void HashedSortGlobalSinkState::GenerateOrderings(Orders &partitions, Orders &orders, + const vector> &partition_bys, + const Orders &order_bys, + const vector> &partition_stats) { + + // we sort by both 1) partition by expression list and 2) order by expressions + const auto partition_cols = partition_bys.size(); + for (idx_t prt_idx = 0; prt_idx < partition_cols; prt_idx++) { + auto &pexpr = partition_bys[prt_idx]; + + if (partition_stats.empty() || !partition_stats[prt_idx]) { + orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(), nullptr); + } else { + orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(), + partition_stats[prt_idx]->ToUnique()); + } + partitions.emplace_back(orders.back().Copy()); + } + + for (const auto &order : order_bys) { + orders.emplace_back(order.Copy()); + } +} + +HashedSortGlobalSinkState::HashedSortGlobalSinkState(ClientContext &context, + const vector> &partition_bys, + const vector &order_bys, + const Types &input_types, + const vector> &partition_stats, + idx_t estimated_cardinality) + : 
context(context), buffer_manager(BufferManager::GetBufferManager(context)), allocator(Allocator::Get(context)), + fixed_bits(0), payload_types(input_types), max_bits(1), count(0) { + + GenerateOrderings(partitions, orders, partition_bys, order_bys, partition_stats); + + // The payload prefix is the same as the input schema + for (column_t i = 0; i < payload_types.size(); ++i) { + scan_ids.emplace_back(i); + } + + // We have to compute ordering expressions ourselves and materialise them. + // To do this, we scan the orders and generate extra payload columns that we can reference. + for (auto &order : orders) { + auto &expr = *order.expression; + if (expr.GetExpressionClass() == ExpressionClass::BOUND_REF) { + auto &ref = expr.Cast<BoundReferenceExpression>(); + sort_ids.emplace_back(ref.index); + continue; + } + + // Real expression - replace with a ref and save the expression + auto saved = std::move(order.expression); + const auto type = saved->return_type; + const auto idx = payload_types.size(); + order.expression = make_uniq<BoundReferenceExpression>(type, idx); + sort_ids.emplace_back(idx); + payload_types.emplace_back(type); + sort_exprs.emplace_back(std::move(saved)); + } + + const auto memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context); + const auto thread_pages = PreviousPowerOfTwo(memory_per_thread / (4 * buffer_manager.GetBlockAllocSize())); + while (max_bits < 8 && (thread_pages >> max_bits) > 1) { + ++max_bits; + } + + grouping_types_ptr = make_shared_ptr(); + if (!orders.empty()) { + if (partitions.empty()) { + // Sort early into a dedicated hash group if we only sort. + grouping_types_ptr->Initialize(payload_types, TupleDataValidityType::CAN_HAVE_NULL_VALUES); + auto new_group = make_uniq<HashedSortGroup>(context, orders, payload_types, idx_t(0)); + hash_groups.emplace_back(std::move(new_group)); + } else { + auto types = payload_types; + types.push_back(LogicalType::HASH); + grouping_types_ptr->Initialize(types, TupleDataValidityType::CAN_HAVE_NULL_VALUES); + Rehash(estimated_cardinality); + } + } +} + +unique_ptr HashedSortGlobalSinkState::CreatePartition(idx_t new_bits) const { + const auto hash_col_idx = payload_types.size(); + return make_uniq(buffer_manager, grouping_types_ptr, new_bits, hash_col_idx); +} + +void HashedSortGlobalSinkState::Rehash(idx_t cardinality) { + // Have we started to combine? Then just live with it. + if (fixed_bits) { + return; + } + // Is the average partition size too large? + const idx_t partition_size = DEFAULT_ROW_GROUP_SIZE; + const auto bits = grouping_data ? grouping_data->GetRadixBits() : 0; + auto new_bits = bits ? bits : 4; + while (new_bits < max_bits && (cardinality / RadixPartitioning::NumberOfPartitions(new_bits)) > partition_size) { + ++new_bits; + } + + // Repartition the grouping data + if (new_bits != bits) { + grouping_data = CreatePartition(new_bits); + } +} + +void HashedSortGlobalSinkState::SyncLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { + // We are done if the local_partition is right sized.
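The Rehash heuristic above grows the radix bit count one bit at a time (doubling the partition count) until the average partition is no larger than one row group. A standalone sketch of that sizing rule, assuming DuckDB's DEFAULT_ROW_GROUP_SIZE of 122880 rows and the 4-bit starting point from the code above; the helper names are illustrative, not DuckDB API:

#include <cstdint>
#include <iostream>

// Mirrors RadixPartitioning::NumberOfPartitions: 2^bits partitions.
static uint64_t NumberOfPartitions(uint64_t bits) {
	return uint64_t(1) << bits;
}

// Sizing rule from Rehash: grow the bit count, capped at max_bits, until the
// average partition holds at most one row group's worth of rows.
static uint64_t PickRadixBits(uint64_t cardinality, uint64_t current_bits, uint64_t max_bits) {
	const uint64_t partition_size = 122880; // DEFAULT_ROW_GROUP_SIZE
	uint64_t new_bits = current_bits ? current_bits : 4;
	while (new_bits < max_bits && cardinality / NumberOfPartitions(new_bits) > partition_size) {
		++new_bits;
	}
	return new_bits;
}

int main() {
	// 10M rows, fresh sink, 8-bit cap: 10M / 2^6 > 122880, but 10M / 2^7 <= 122880
	std::cout << PickRadixBits(10000000, 0, 8) << "\n"; // prints 7
}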
+ const auto new_bits = grouping_data->GetRadixBits(); + if (local_partition->GetRadixBits() == new_bits) { + return; + } + + // If the local partition is now too small, flush it and reallocate + auto new_partition = CreatePartition(new_bits); + local_partition->FlushAppendState(*local_append); + local_partition->Repartition(context, *new_partition); + + local_partition = std::move(new_partition); + local_append = make_uniq(); + local_partition->InitializeAppendState(*local_append); +} + +void HashedSortGlobalSinkState::UpdateLocalPartition(GroupingPartition &local_partition, + GroupingAppend &partition_append) { + // Make sure grouping_data doesn't change under us. + lock_guard guard(lock); + + if (!local_partition) { + local_partition = CreatePartition(grouping_data->GetRadixBits()); + partition_append = make_uniq(); + local_partition->InitializeAppendState(*partition_append); + return; + } + + // Grow the groups if they are too big + Rehash(count); + + // Sync local partition to have the same bit count + SyncLocalPartition(local_partition, partition_append); +} + +void HashedSortGlobalSinkState::CombineLocalPartition(GroupingPartition &local_partition, + GroupingAppend &local_append) { + if (!local_partition) { + return; + } + local_partition->FlushAppendState(*local_append); + + // Make sure grouping_data doesn't change under us. + // Combine has an internal mutex, so this is single-threaded anyway. + lock_guard guard(lock); + SyncLocalPartition(local_partition, local_append); + fixed_bits = true; + + // We now know the number of hash_groups (some may be empty) + auto &groups = local_partition->GetPartitions(); + if (hash_groups.empty()) { + hash_groups.resize(groups.size()); + } + + // Create missing HashedSortGroups inside the mutex + for (idx_t group_idx = 0; group_idx < groups.size(); ++group_idx) { + auto &hash_group = hash_groups[group_idx]; + if (hash_group) { + continue; + } + + auto &group_data = groups[group_idx]; + if (group_data->Count()) { + hash_group = make_uniq(context, orders, payload_types, group_idx); + } + } +} + +void HashedSortGlobalSinkState::Finalize(ClientContext &context, InterruptState &interrupt_state) { + // OVER() + if (unsorted) { + return; + } + + // OVER(...) 
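SyncLocalPartition and UpdateLocalPartition above keep every thread's local partition at the global radix width by flushing and repartitioning whenever the global width grows. A minimal standalone sketch of that repartitioning step, assuming a simplified bucket-by-low-bits scheme (DuckDB's RadixPartitioning extracts bits differently; the types here are illustrative):

#include <cstdint>
#include <utility>
#include <vector>

// When the global radix bit count grows, every locally buffered row is
// reassigned by re-reading its hash; rows only move between in-memory buckets.
using Row = std::pair<uint64_t /*hash*/, int64_t /*payload*/>;

static std::vector<std::vector<Row>> Repartition(const std::vector<std::vector<Row>> &old_buckets,
                                                 uint64_t new_bits) {
	std::vector<std::vector<Row>> new_buckets(size_t(1) << new_bits);
	const uint64_t mask = (uint64_t(1) << new_bits) - 1;
	for (auto &bucket : old_buckets) {
		for (auto &row : bucket) {
			new_buckets[row.first & mask].push_back(row); // same hash, finer bucket
		}
	}
	return new_buckets;
}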
+ D_ASSERT(!hash_groups.empty()); + for (auto &hash_group : hash_groups) { + if (!hash_group) { + continue; + } + OperatorSinkFinalizeInput finalize {*hash_group->sort_global, interrupt_state}; + hash_group->sort->Finalize(context, finalize); + } +} + +bool HashedSortGlobalSinkState::HasMergeTasks() const { + return (!hash_groups.empty()); +} + +//===--------------------------------------------------------------------===// +// HashedSortLocalSinkState +//===--------------------------------------------------------------------===// +HashedSortLocalSinkState::HashedSortLocalSinkState(ExecutionContext &context, HashedSortGlobalSinkState &gstate) + : gstate(gstate), allocator(Allocator::Get(context.client)), hash_exec(context.client), sort_exec(context.client) { + + vector group_types; + for (idx_t prt_idx = 0; prt_idx < gstate.partitions.size(); prt_idx++) { + auto &pexpr = *gstate.partitions[prt_idx].expression.get(); + group_types.push_back(pexpr.return_type); + hash_exec.AddExpression(pexpr); + } + sort_col_count = gstate.orders.size() + group_types.size(); + + vector sort_types; + for (const auto &expr : gstate.sort_exprs) { + sort_types.emplace_back(expr->return_type); + sort_exec.AddExpression(*expr); + } + sort_chunk.Initialize(context.client, sort_types); + + if (sort_col_count) { + auto payload_types = gstate.payload_types; + if (!group_types.empty()) { + // OVER(PARTITION BY...) + group_chunk.Initialize(allocator, group_types); + payload_types.emplace_back(LogicalType::HASH); + } else { + // OVER(ORDER BY...) + for (idx_t ord_idx = 0; ord_idx < gstate.orders.size(); ord_idx++) { + auto &pexpr = *gstate.orders[ord_idx].expression.get(); + group_types.push_back(pexpr.return_type); + hash_exec.AddExpression(pexpr); + } + group_chunk.Initialize(allocator, group_types); + + // Single partition + auto &sort = *gstate.hash_groups[0]->sort; + sort_local = sort.GetLocalSinkState(context); + } + // OVER(...) + payload_chunk.Initialize(allocator, payload_types); + } +} + +void HashedSortLocalSinkState::Hash(DataChunk &input_chunk, Vector &hash_vector) { + const auto count = input_chunk.size(); + D_ASSERT(group_chunk.ColumnCount() > 0); + + // OVER(PARTITION BY...) 
(hash grouping) + group_chunk.Reset(); + hash_exec.Execute(input_chunk, group_chunk); + VectorOperations::Hash(group_chunk.data[0], hash_vector, count); + for (idx_t prt_idx = 1; prt_idx < group_chunk.ColumnCount(); ++prt_idx) { + VectorOperations::CombineHash(hash_vector, group_chunk.data[prt_idx], count); + } +} + +void HashedSortLocalSinkState::Sink(ExecutionContext &context, DataChunk &input_chunk) { + gstate.count += input_chunk.size(); + + // Window::Sink: + // PartitionedTupleData::Append + // Sort::Sink + // ColumnDataCollection::Append + + // OVER() + if (sort_col_count == 0) { + if (!unsorted) { + unsorted = make_uniq(context.client, gstate.payload_types); + unsorted->InitializeAppend(unsorted_append); + } + unsorted->Append(unsorted_append, input_chunk); + return; + } + + // Payload prefix is the input data + payload_chunk.Reset(); + for (column_t i = 0; i < input_chunk.ColumnCount(); ++i) { + payload_chunk.data[i].Reference(input_chunk.data[i]); + } + + // Compute any sort columns that are not references and append them to the end of the payload + if (!gstate.sort_exprs.empty()) { + sort_chunk.Reset(); + sort_exec.Execute(input_chunk, sort_chunk); + for (column_t i = 0; i < sort_chunk.ColumnCount(); ++i) { + payload_chunk.data[input_chunk.ColumnCount() + i].Reference(sort_chunk.data[i]); + } + } + payload_chunk.SetCardinality(input_chunk); + + // OVER(ORDER BY...) + if (sort_local) { + auto &hash_group = *gstate.hash_groups[0]; + OperatorSinkInput input {*hash_group.sort_global, *sort_local, interrupt}; + hash_group.sort->Sink(context, payload_chunk, input); + return; + } + + // OVER(PARTITION BY...) + auto &hash_vector = payload_chunk.data.back(); + Hash(input_chunk, hash_vector); + for (idx_t col_idx = 0; col_idx < input_chunk.ColumnCount(); ++col_idx) { + payload_chunk.data[col_idx].Reference(input_chunk.data[col_idx]); + } + + gstate.UpdateLocalPartition(local_grouping, grouping_append); + local_grouping->Append(*grouping_append, payload_chunk); +} + +void HashedSortLocalSinkState::Combine(ExecutionContext &context) { + // Window::Combine: + // Sort::Sink then Sort::Combine (per hash partition) + // Sort::Combine + // ColumnDataCollection::Combine + + // OVER() + if (sort_col_count == 0) { + // Only one partition again, so need a global lock. + lock_guard glock(gstate.lock); + if (gstate.unsorted) { + if (unsorted) { + gstate.unsorted->Combine(*unsorted); + unsorted.reset(); + } + } else { + gstate.unsorted = std::move(unsorted); + } + return; + } + + // OVER(ORDER BY...) + if (sort_local) { + auto &hash_group = *gstate.hash_groups[0]; + auto &sort = *hash_group.sort; + OperatorSinkCombineInput input {*hash_group.sort_global, *sort_local, interrupt}; + sort.Combine(context, input); + sort_local.reset(); + return; + } + + // OVER(PARTITION BY...) 
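The Hash method above hashes the first PARTITION BY column and folds each remaining column into the same hash vector. A scalar sketch of the same idea, assuming a boost-style mixer in place of VectorOperations::CombineHash (which operates on whole vectors at a time); names are illustrative:

#include <cstdint>
#include <functional>
#include <string>
#include <vector>

static uint64_t CombineHash(uint64_t seed, uint64_t value) {
	return seed ^ (value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2));
}

static uint64_t HashPartitionKeys(const std::vector<std::string> &keys) {
	// Precondition: keys is non-empty, mirroring the D_ASSERT in Hash above.
	uint64_t h = std::hash<std::string> {}(keys[0]);
	for (size_t i = 1; i < keys.size(); ++i) {
		h = CombineHash(h, std::hash<std::string> {}(keys[i]));
	}
	return h;
}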
+ if (!local_grouping) { + return; + } + + // Flush our data and lock the bit count + gstate.CombineLocalPartition(local_grouping, grouping_append); + + // Don't scan the hash column + vector column_ids; + for (column_t i = 0; i < gstate.payload_types.size(); ++i) { + column_ids.emplace_back(i); + } + + // Loop over the partitions and add them to each hash group's global sort state + TupleDataScanState scan_state; + DataChunk chunk; + auto &partitions = local_grouping->GetPartitions(); + for (hash_t hash_bin = 0; hash_bin < partitions.size(); ++hash_bin) { + auto &partition = *partitions[hash_bin]; + if (!partition.Count()) { + continue; + } + + partition.InitializeScan(scan_state, column_ids, TupleDataPinProperties::DESTROY_AFTER_DONE); + if (chunk.data.empty()) { + partition.InitializeScanChunk(scan_state, chunk); + } + + auto &hash_group = *gstate.hash_groups[hash_bin]; + auto &sort = *hash_group.sort; + sort_local = sort.GetLocalSinkState(context); + OperatorSinkInput sink {*hash_group.sort_global, *sort_local, interrupt}; + while (partition.Scan(scan_state, chunk)) { + sort.Sink(context, chunk, sink); + } + + OperatorSinkCombineInput combine {*hash_group.sort_global, *sort_local, interrupt}; + sort.Combine(context, combine); + } +} + +//===--------------------------------------------------------------------===// +// HashedSortMaterializeTask +//===--------------------------------------------------------------------===// +class HashedSortMaterializeTask : public ExecutorTask { +public: + HashedSortMaterializeTask(Pipeline &pipeline, shared_ptr event, const PhysicalOperator &op, + HashedSortGroup &hash_group, idx_t tasks_scheduled, + optional_ptr callback); + + TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; + + string TaskType() const override { + return "HashedSortMaterializeTask"; + } + +private: + Pipeline &pipeline; + HashedSortGroup &hash_group; + const idx_t tasks_scheduled; + optional_ptr callback; +}; + +HashedSortMaterializeTask::HashedSortMaterializeTask(Pipeline &pipeline, shared_ptr event, + const PhysicalOperator &op, HashedSortGroup &hash_group, + idx_t tasks_scheduled, optional_ptr callback) + : ExecutorTask(pipeline.GetClientContext(), std::move(event), op), pipeline(pipeline), hash_group(hash_group), + tasks_scheduled(tasks_scheduled), callback(callback) { +} + +TaskExecutionResult HashedSortMaterializeTask::ExecuteTask(TaskExecutionMode mode) { + ExecutionContext execution(pipeline.GetClientContext(), *thread_context, &pipeline); + auto &sort = *hash_group.sort; + auto &sort_global = *hash_group.sort_source; + auto sort_local = sort.GetLocalSourceState(execution, sort_global); + InterruptState interrupt((weak_ptr(shared_from_this()))); + OperatorSourceInput input {sort_global, *sort_local, interrupt}; + sort.MaterializeColumnData(execution, input); + if (++hash_group.tasks_completed == tasks_scheduled && callback) { + hash_group.sorted = sort.GetColumnData(input); + callback->OnSortedGroup(hash_group); + } + + event->FinishTask(); + return TaskExecutionResult::TASK_FINISHED; +} + +//===--------------------------------------------------------------------===// +// HashedSortMaterializeEvent +//===--------------------------------------------------------------------===// +HashedSortMaterializeEvent::HashedSortMaterializeEvent(HashedSortGlobalSinkState &gstate, Pipeline &pipeline, + const PhysicalOperator &op, HashedSortCallback *callback) + : BasePipelineEvent(pipeline), gstate(gstate), op(op), callback(callback) { +} + +void 
HashedSortMaterializeEvent::Schedule() { + auto &client = pipeline->GetClientContext(); + + // Schedule as many tasks per hash group as the sort will allow + auto &ts = TaskScheduler::GetScheduler(client); + const auto num_threads = NumericCast(ts.NumberOfThreads()); + + vector> merge_tasks; + for (auto &hash_group : gstate.hash_groups) { + if (!hash_group) { + continue; + } + auto &sort = *hash_group->sort; + auto &global_sink = *hash_group->sort_global; + hash_group->sort_source = sort.GetGlobalSourceState(client, global_sink); + const auto tasks_scheduled = MinValue(num_threads, hash_group->sort_source->MaxThreads()); + for (idx_t t = 0; t < tasks_scheduled; ++t) { + merge_tasks.emplace_back(make_uniq(*pipeline, shared_from_this(), op, + *hash_group, tasks_scheduled, callback)); + } + } + + SetTasks(std::move(merge_tasks)); +} + +} // namespace duckdb diff --git a/src/duckdb/src/common/sorting/sort.cpp b/src/duckdb/src/common/sorting/sort.cpp index 4d3ccd6a2..2159878ff 100644 --- a/src/duckdb/src/common/sorting/sort.cpp +++ b/src/duckdb/src/common/sorting/sort.cpp @@ -4,6 +4,7 @@ #include "duckdb/common/sorting/sort_key.hpp" #include "duckdb/common/sorting/sorted_run.hpp" #include "duckdb/common/sorting/sorted_run_merger.hpp" +#include "duckdb/common/types/batched_data_collection.hpp" #include "duckdb/common/types/row/tuple_data_collection.hpp" #include "duckdb/function/create_sort_key.hpp" #include "duckdb/function/function_binder.hpp" @@ -383,6 +384,9 @@ class SortGlobalSourceState : public GlobalSourceState { //! Sorted run merger and associated global state SortedRunMerger merger; unique_ptr merger_global_state; + + //! Materialized column data (optional) + unique_ptr column_data; }; class SortLocalSourceState : public LocalSourceState { @@ -432,4 +436,78 @@ OperatorPartitionData Sort::GetPartitionData(ExecutionContext &context, DataChun partition_info); } +//===--------------------------------------------------------------------===// +// Non-Standard Interface +//===--------------------------------------------------------------------===// +SourceResultType Sort::MaterializeColumnData(ExecutionContext &context, OperatorSourceInput &input) const { + auto &gstate = input.global_state.Cast(); + + // Derive output types + vector types; + types.resize(output_projection_columns.size()); + for (auto &opc : output_projection_columns) { + const auto &type = opc.is_payload ? 
payload_layout->GetTypes()[opc.layout_col_idx] + : StructType::GetChildType(decode_sort_key->return_type, opc.layout_col_idx); + types[opc.output_col_idx] = type; + } + + // Initialize scan chunk + DataChunk chunk; + chunk.Initialize(context.client, types); + + // Initialize local output collection + auto local_column_data = make_uniq(context.client, types, true); + + while (true) { + // Check for interrupts since this could be a long-running task + if (context.client.interrupted.load(std::memory_order_relaxed)) { + throw InterruptException(); + } + // Scan a chunk + chunk.Reset(); + GetData(context, chunk, input); + if (chunk.size() == 0) { + break; + } + // Append to the output collection + const auto batch_index = + GetPartitionData(context, chunk, input.global_state, input.local_state, OperatorPartitionInfo()) + .batch_index; + local_column_data->Append(chunk, batch_index); + } + + // Merge into global output collection + auto guard = gstate.Lock(); + if (!gstate.column_data) { + gstate.column_data = std::move(local_column_data); + } else { + gstate.column_data->Merge(*local_column_data); + } + + // Return type indicates whether materialization is done + const auto progress_data = GetProgress(context.client, input.global_state); + return progress_data.done == progress_data.total ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; +} + +unique_ptr Sort::GetColumnData(OperatorSourceInput &input) const { + auto &gstate = input.global_state.Cast(); + auto guard = gstate.Lock(); + return gstate.column_data->FetchCollection(); +} + +SourceResultType Sort::MaterializeSortedRun(ExecutionContext &context, OperatorSourceInput &input) const { + auto &gstate = input.global_state.Cast(); + if (gstate.merger.total_count == 0) { + return SourceResultType::FINISHED; + } + auto &lstate = input.local_state.Cast(); + OperatorSourceInput merger_input {*gstate.merger_global_state, *lstate.merger_local_state, input.interrupt_state}; + return gstate.merger.MaterializeMerge(context, merger_input); +} + +unique_ptr Sort::GetSortedRun(GlobalSourceState &global_state) { + auto &gstate = global_state.Cast(); + return gstate.merger.GetMaterialized(gstate); +} + } // namespace duckdb diff --git a/src/duckdb/src/common/sorting/sorted_run.cpp b/src/duckdb/src/common/sorting/sorted_run.cpp index 4ad91ce5d..0554990fa 100644 --- a/src/duckdb/src/common/sorting/sorted_run.cpp +++ b/src/duckdb/src/common/sorting/sorted_run.cpp @@ -14,7 +14,7 @@ SortedRun::SortedRun(ClientContext &context_p, shared_ptr key_l : context(context_p), key_data(make_uniq(BufferManager::GetBufferManager(context), std::move(key_layout))), payload_data( - payload_layout->ColumnCount() != 0 + payload_layout && payload_layout->ColumnCount() != 0 ? make_uniq(BufferManager::GetBufferManager(context), std::move(payload_layout)) : nullptr), is_index_sort(is_index_sort_p), finalized(false) { @@ -24,6 +24,15 @@ SortedRun::SortedRun(ClientContext &context_p, shared_ptr key_l } } +unique_ptr SortedRun::CreateRunForMaterialization() const { + auto res = make_uniq(context, key_data->GetLayoutPtr(), + payload_data ? 
payload_data->GetLayoutPtr() : nullptr, is_index_sort); + res->key_append_state.pin_state.properties = TupleDataPinProperties::UNPIN_AFTER_DONE; + res->payload_append_state.pin_state.properties = TupleDataPinProperties::UNPIN_AFTER_DONE; + res->finalized = true; + return res; +} + SortedRun::~SortedRun() { } diff --git a/src/duckdb/src/common/sorting/sorted_run_merger.cpp b/src/duckdb/src/common/sorting/sorted_run_merger.cpp index b99889937..eb879edc5 100644 --- a/src/duckdb/src/common/sorting/sorted_run_merger.cpp +++ b/src/duckdb/src/common/sorting/sorted_run_merger.cpp @@ -100,7 +100,7 @@ class SortedRunMergerLocalState : public LocalSourceState { //! Whether this thread has finished the work it has been assigned bool TaskFinished() const; //! Do the work this thread has been assigned - void ExecuteTask(SortedRunMergerGlobalState &gstate, DataChunk &chunk); + void ExecuteTask(SortedRunMergerGlobalState &gstate, optional_ptr chunk); private: //! Computes upper partition boundaries using K-way Merge Path @@ -127,6 +127,11 @@ class SortedRunMergerLocalState : public LocalSourceState { template void TemplatedScanPartition(SortedRunMergerGlobalState &gstate, DataChunk &chunk); + //! Materialize the merge + void MaterializePartition(SortedRunMergerGlobalState &gstate); + template + unique_ptr TemplatedMaterializePartition(SortedRunMergerGlobalState &gstate); + public: //! Types for templating const BlockIteratorStateType iterator_state_type; @@ -275,6 +280,9 @@ class SortedRunMergerGlobalState : public GlobalSourceState { mutex destroy_lock; idx_t destroy_partition_idx; + + mutex materialized_partition_lock; + vector> materialized_partitions; }; //===--------------------------------------------------------------------===// @@ -320,7 +328,7 @@ bool SortedRunMergerLocalState::TaskFinished() const { } } -void SortedRunMergerLocalState::ExecuteTask(SortedRunMergerGlobalState &gstate, DataChunk &chunk) { +void SortedRunMergerLocalState::ExecuteTask(SortedRunMergerGlobalState &gstate, optional_ptr chunk) { D_ASSERT(task != SortedRunMergerTask::FINISHED); switch (task) { case SortedRunMergerTask::COMPUTE_BOUNDARIES: @@ -336,8 +344,12 @@ void SortedRunMergerLocalState::ExecuteTask(SortedRunMergerGlobalState &gstate, task = SortedRunMergerTask::SCAN_PARTITION; break; case SortedRunMergerTask::SCAN_PARTITION: - ScanPartition(gstate, chunk); - if (chunk.size() == 0) { + if (chunk) { + ScanPartition(gstate, *chunk); + } else { + MaterializePartition(gstate); + } + if (!chunk || chunk->size() == 0) { gstate.DestroyScannedData(); gstate.partitions[partition_idx.GetIndex()]->scanned = true; gstate.total_scanned += merged_partition_count; @@ -763,6 +775,104 @@ void SortedRunMergerLocalState::TemplatedScanPartition(SortedRunMergerGlobalStat chunk.SetCardinality(count); } +void SortedRunMergerLocalState::MaterializePartition(SortedRunMergerGlobalState &gstate) { + unique_ptr sorted_run; + switch (sort_key_type) { + case SortKeyType::NO_PAYLOAD_FIXED_8: + sorted_run = TemplatedMaterializePartition(gstate); + break; + case SortKeyType::NO_PAYLOAD_FIXED_16: + sorted_run = TemplatedMaterializePartition(gstate); + break; + case SortKeyType::NO_PAYLOAD_FIXED_24: + sorted_run = TemplatedMaterializePartition(gstate); + break; + case SortKeyType::NO_PAYLOAD_FIXED_32: + sorted_run = TemplatedMaterializePartition(gstate); + break; + case SortKeyType::NO_PAYLOAD_VARIABLE_32: + sorted_run = TemplatedMaterializePartition(gstate); + break; + case SortKeyType::PAYLOAD_FIXED_16: + sorted_run = 
TemplatedMaterializePartition<SortKeyType::PAYLOAD_FIXED_16>(gstate); + break; + case SortKeyType::PAYLOAD_FIXED_24: + sorted_run = TemplatedMaterializePartition<SortKeyType::PAYLOAD_FIXED_24>(gstate); + break; + case SortKeyType::PAYLOAD_FIXED_32: + sorted_run = TemplatedMaterializePartition<SortKeyType::PAYLOAD_FIXED_32>(gstate); + break; + case SortKeyType::PAYLOAD_VARIABLE_32: + sorted_run = TemplatedMaterializePartition<SortKeyType::PAYLOAD_VARIABLE_32>(gstate); + break; + default: + throw NotImplementedException("SortedRunMergerLocalState::MaterializePartition for %s", + EnumUtil::ToString(sort_key_type)); + } + + // Add to global state + lock_guard<mutex> guard(gstate.materialized_partition_lock); + if (gstate.materialized_partitions.size() <= partition_idx.GetIndex()) { + gstate.materialized_partitions.resize(partition_idx.GetIndex() + 1); + } + gstate.materialized_partitions[partition_idx.GetIndex()] = std::move(sorted_run); +} + +template <SortKeyType SORT_KEY_TYPE> + unique_ptr<SortedRun> SortedRunMergerLocalState::TemplatedMaterializePartition(SortedRunMergerGlobalState &gstate) { + using SORT_KEY = SortKey<SORT_KEY_TYPE>; + const auto merged_partition_keys = reinterpret_cast<SORT_KEY *>(merged_partition.get()) + merged_partition_index; + + TupleDataChunkState key_data_input; + const auto key_locations = FlatVector::GetData<data_ptr_t>(key_data_input.row_locations); + const auto key_heap_locations = FlatVector::GetData<data_ptr_t>(key_data_input.heap_locations); + const auto key_heap_sizes = FlatVector::GetData<idx_t>(key_data_input.heap_sizes); + + TupleDataChunkState payload_data_input; + const auto payload_locations = FlatVector::GetData<data_ptr_t>(payload_data_input.row_locations); + + auto sorted_run = gstate.merger.sorted_runs[0]->CreateRunForMaterialization(); + + while (merged_partition_index < merged_partition_count) { + const auto count = MinValue<idx_t>(merged_partition_count - merged_partition_index, STANDARD_VECTOR_SIZE); + for (idx_t i = 0; i < count; i++) { + auto &key = merged_partition_keys[merged_partition_index + i]; + key_locations[i] = data_ptr_cast(&key); + if (!SORT_KEY::CONSTANT_SIZE) { + key_heap_locations[i] = key.GetData(); + key_heap_sizes[i] = key.GetHeapSize(); + } + if (SORT_KEY::HAS_PAYLOAD) { + payload_locations[i] = key.GetPayload(); + } + } + + sorted_run->key_append_state.chunk_state.heap_sizes.Reference(key_data_input.heap_sizes); + sorted_run->key_data->Build(sorted_run->key_append_state.pin_state, sorted_run->key_append_state.chunk_state, 0, + count); + sorted_run->key_data->CopyRows(sorted_run->key_append_state.chunk_state, key_data_input, + *FlatVector::IncrementalSelectionVector(), count); + + if (SORT_KEY::HAS_PAYLOAD) { + if (!sorted_run->payload_data->GetLayout().AllConstant()) { + sorted_run->payload_data->FindHeapPointers(payload_data_input, count); + } + sorted_run->payload_append_state.chunk_state.heap_sizes.Reference(key_data_input.heap_sizes); + sorted_run->payload_data->Build(sorted_run->payload_append_state.pin_state, + sorted_run->payload_append_state.chunk_state, 0, count); + sorted_run->payload_data->CopyRows(sorted_run->payload_append_state.chunk_state, payload_data_input, + *FlatVector::IncrementalSelectionVector(), count); + } + + merged_partition_index += count; + } + + sorted_run->key_data->FinalizePinState(sorted_run->key_append_state.pin_state); + if (sorted_run->payload_data) { + sorted_run->payload_data->FinalizePinState(sorted_run->payload_append_state.pin_state); + } + + return sorted_run; +} + //===--------------------------------------------------------------------===// // Sorted Run Merger //===--------------------------------------------------------------------===// @@ -791,7 +901,7 @@ SourceResultType SortedRunMerger::GetData(ExecutionContext &, DataChunk &chunk, while (chunk.size() == 0) { if (!lstate.TaskFinished() ||
gstate.AssignTask(lstate)) { - lstate.ExecuteTask(gstate, chunk); + lstate.ExecuteTask(gstate, &chunk); } else { break; } @@ -816,4 +926,44 @@ ProgressData SortedRunMerger::GetProgress(ClientContext &, GlobalSourceState &gs return res; } +//===--------------------------------------------------------------------===// +// Non-Standard Interface +//===--------------------------------------------------------------------===// +SourceResultType SortedRunMerger::MaterializeMerge(ExecutionContext &, OperatorSourceInput &input) const { + auto &gstate = input.global_state.Cast(); + auto &lstate = input.local_state.Cast(); + + while (true) { + if (!lstate.TaskFinished() || gstate.AssignTask(lstate)) { + lstate.ExecuteTask(gstate, nullptr); + } else { + break; + } + } + + if (gstate.total_scanned == total_count) { + // This signals that the data has been fully materialized + return SourceResultType::FINISHED; + } + // This signals that no more tasks are left, but that the data has not yet been fully materialized + return SourceResultType::HAVE_MORE_OUTPUT; +} + +unique_ptr SortedRunMerger::GetMaterialized(GlobalSourceState &global_state) { + auto &gstate = global_state.Cast(); + if (gstate.materialized_partitions.empty()) { + D_ASSERT(total_count == 0); + return nullptr; + } + auto &target = *gstate.materialized_partitions[0]; + for (idx_t i = 1; i < gstate.materialized_partitions.size(); i++) { + auto &source = *gstate.materialized_partitions[i]; + target.key_data->Combine(*source.key_data); + if (target.payload_data) { + target.payload_data->Combine(*source.payload_data); + } + } + return std::move(gstate.materialized_partitions[0]); +} + } // namespace duckdb diff --git a/src/duckdb/src/common/tree_renderer/text_tree_renderer.cpp b/src/duckdb/src/common/tree_renderer/text_tree_renderer.cpp index 6eae451ad..251736dd4 100644 --- a/src/duckdb/src/common/tree_renderer/text_tree_renderer.cpp +++ b/src/duckdb/src/common/tree_renderer/text_tree_renderer.cpp @@ -160,6 +160,41 @@ string AdjustTextForRendering(string source, idx_t max_render_width) { return string(half_spaces + extra_left_space, ' ') + source + string(half_spaces, ' '); } +string TextTreeRenderer::FormatNumber(const string &input) { + if (config.decimal_separator == '\0' && config.thousand_separator == '\0') { + // no thousand separator + return input; + } + // first check how many digits there are (preceding any decimal point) + idx_t character_count = 0; + for (auto c : input) { + if (!StringUtil::CharacterIsDigit(c)) { + break; + } + character_count++; + } + // find the position of the first thousand separator + idx_t separator_position = character_count % 3 == 0 ? 3 : character_count % 3; + // now add the thousand separators + string result; + for (idx_t c = 0; c < character_count; c++) { + if (c == separator_position && config.thousand_separator != '\0') { + result += config.thousand_separator; + separator_position += 3; + } + result += input[c]; + } + // add any remaining characters + for (idx_t c = character_count; c < input.size(); c++) { + if (input[c] == '.' 
&& config.decimal_separator != '\0') { + result += config.decimal_separator; + } else { + result += input[c]; + } + } + return result; +} + void TextTreeRenderer::RenderBoxContent(RenderTree &root, std::ostream &ss, idx_t y) { // we first need to figure out how high our boxes are going to be vector> extra_info; @@ -246,7 +281,7 @@ void TextTreeRenderer::RenderBoxContent(RenderTree &root, std::ostream &ss, idx_ if (render_y + 1 == extra_height && render_text.empty()) { auto entry = node->extra_text.find(RenderTreeNode::CARDINALITY); if (entry != node->extra_text.end()) { - render_text = entry->second + " Rows"; + render_text = FormatNumber(entry->second) + " row" + (entry->second == "1" ? "" : "s"); } } if (render_y == extra_height && render_text.empty()) { @@ -257,14 +292,16 @@ void TextTreeRenderer::RenderBoxContent(RenderTree &root, std::ostream &ss, idx_ // we only render estimated cardinality if there is no real cardinality auto entry = node->extra_text.find(RenderTreeNode::ESTIMATED_CARDINALITY); if (entry != node->extra_text.end()) { - render_text = "~" + entry->second + " Rows"; + render_text = + "~" + FormatNumber(entry->second) + " row" + (entry->second == "1" ? "" : "s"); } } if (node->extra_text.find(RenderTreeNode::CARDINALITY) == node->extra_text.end()) { // we only render estimated cardinality if there is no real cardinality auto entry = node->extra_text.find(RenderTreeNode::ESTIMATED_CARDINALITY); if (entry != node->extra_text.end()) { - render_text = "~" + entry->second + " Rows"; + render_text = + "~" + FormatNumber(entry->second) + " row" + (entry->second == "1" ? "" : "s"); } } } diff --git a/src/duckdb/src/common/types.cpp b/src/duckdb/src/common/types.cpp index 25d75e377..08e033c99 100644 --- a/src/duckdb/src/common/types.cpp +++ b/src/duckdb/src/common/types.cpp @@ -25,6 +25,7 @@ #include "duckdb/main/database_manager.hpp" #include "duckdb/parser/keyword_helper.hpp" #include "duckdb/parser/parser.hpp" +#include "duckdb/main/settings.hpp" #include @@ -110,11 +111,11 @@ PhysicalType LogicalType::GetInternalType() { width, DecimalType::MaxWidth()); } } + case LogicalTypeId::BIGNUM: case LogicalTypeId::VARCHAR: case LogicalTypeId::CHAR: case LogicalTypeId::BLOB: case LogicalTypeId::BIT: - case LogicalTypeId::VARINT: return PhysicalType::VARCHAR; case LogicalTypeId::INTERVAL: return PhysicalType::INTERVAL; @@ -151,6 +152,7 @@ PhysicalType LogicalType::GetInternalType() { case LogicalTypeId::UNKNOWN: case LogicalTypeId::STRING_LITERAL: case LogicalTypeId::INTEGER_LITERAL: + case LogicalTypeId::TEMPLATE: return PhysicalType::INVALID; case LogicalTypeId::USER: return PhysicalType::UNKNOWN; @@ -203,7 +205,7 @@ constexpr const LogicalTypeId LogicalType::VARCHAR; constexpr const LogicalTypeId LogicalType::BLOB; constexpr const LogicalTypeId LogicalType::BIT; -constexpr const LogicalTypeId LogicalType::VARINT; +constexpr const LogicalTypeId LogicalType::BIGNUM; constexpr const LogicalTypeId LogicalType::INTERVAL; constexpr const LogicalTypeId LogicalType::ROW_TYPE; @@ -241,7 +243,7 @@ const vector LogicalType::AllTypes() { LogicalType::BOOLEAN, LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER, LogicalType::BIGINT, LogicalType::DATE, LogicalType::TIMESTAMP, LogicalType::DOUBLE, LogicalType::FLOAT, LogicalType::VARCHAR, LogicalType::BLOB, LogicalType::BIT, - LogicalType::VARINT, LogicalType::INTERVAL, LogicalType::HUGEINT, LogicalTypeId::DECIMAL, + LogicalType::BIGNUM, LogicalType::INTERVAL, LogicalType::HUGEINT, LogicalTypeId::DECIMAL, LogicalType::UTINYINT, 
LogicalType::USMALLINT, LogicalType::UINTEGER, LogicalType::UBIGINT, LogicalType::UHUGEINT, LogicalType::TIME, LogicalTypeId::LIST, LogicalTypeId::STRUCT, LogicalType::TIME_TZ, LogicalType::TIMESTAMP_TZ, LogicalTypeId::MAP, LogicalTypeId::UNION, @@ -514,6 +516,12 @@ string LogicalType::ToString() const { case LogicalTypeId::SQLNULL: { return "\"NULL\""; } + case LogicalTypeId::TEMPLATE: { + if (!type_info_) { + return "T"; + } + return TemplateType::GetName(*this); + } default: return EnumUtil::ToString(id_); } @@ -735,6 +743,7 @@ bool LogicalType::IsComplete() const { case LogicalTypeId::INVALID: case LogicalTypeId::UNKNOWN: case LogicalTypeId::ANY: + case LogicalTypeId::TEMPLATE: return true; // These are incomplete by default case LogicalTypeId::LIST: case LogicalTypeId::MAP: @@ -781,6 +790,10 @@ bool LogicalType::IsComplete() const { }); } +bool LogicalType::IsTemplated() const { + return TypeVisitor::Contains(*this, LogicalTypeId::TEMPLATE); +} + bool LogicalType::SupportsRegularUpdate() const { switch (id()) { case LogicalTypeId::LIST: @@ -1224,7 +1237,7 @@ struct ForceGetTypeOperation { bool LogicalType::TryGetMaxLogicalType(ClientContext &context, const LogicalType &left, const LogicalType &right, LogicalType &result) { - if (DBConfig::GetConfig(context).options.old_implicit_casting) { + if (DBConfig::GetSetting(context)) { result = LogicalType::ForceMaxLogicalType(left, right); return true; } @@ -1237,6 +1250,7 @@ static idx_t GetLogicalTypeScore(const LogicalType &type) { case LogicalTypeId::SQLNULL: case LogicalTypeId::UNKNOWN: case LogicalTypeId::ANY: + case LogicalTypeId::TEMPLATE: case LogicalTypeId::STRING_LITERAL: case LogicalTypeId::INTEGER_LITERAL: return 0; @@ -1302,7 +1316,7 @@ static idx_t GetLogicalTypeScore(const LogicalType &type) { return 101; case LogicalTypeId::UUID: return 102; - case LogicalTypeId::VARINT: + case LogicalTypeId::BIGNUM: return 103; // nested types case LogicalTypeId::STRUCT: @@ -1934,6 +1948,22 @@ LogicalType LogicalType::INTEGER_LITERAL(const Value &constant) { // NOLINT return LogicalType(LogicalTypeId::INTEGER_LITERAL, std::move(type_info)); } +//===--------------------------------------------------------------------===// +// Template Type +//===--------------------------------------------------------------------===// +LogicalType LogicalType::TEMPLATE(const string &name) { + D_ASSERT(!name.empty()); + auto type_info = make_shared_ptr(name); + return LogicalType(LogicalTypeId::TEMPLATE, std::move(type_info)); +} + +const string &TemplateType::GetName(const LogicalType &type) { + D_ASSERT(type.id() == LogicalTypeId::TEMPLATE); + auto info = type.AuxInfo(); + D_ASSERT(info->type == ExtraTypeInfoType::TEMPLATE_TYPE_INFO); + return info->Cast().name; +} + //===--------------------------------------------------------------------===// // Logical Type //===--------------------------------------------------------------------===// diff --git a/src/duckdb/src/common/types/varint.cpp b/src/duckdb/src/common/types/bignum.cpp similarity index 74% rename from src/duckdb/src/common/types/varint.cpp rename to src/duckdb/src/common/types/bignum.cpp index 6c0c1019f..41bdfc538 100644 --- a/src/duckdb/src/common/types/varint.cpp +++ b/src/duckdb/src/common/types/bignum.cpp @@ -1,4 +1,5 @@ -#include "duckdb/common/types/varint.hpp" +#include "duckdb/common/bignum.hpp" +#include "duckdb/common/types/bignum.hpp" #include "duckdb/common/exception/conversion_exception.hpp" #include "duckdb/common/numeric_utils.hpp" #include "duckdb/common/typedefs.hpp" @@ -6,48 +7,55 
@@ namespace duckdb { -void Varint::Verify(const string_t &input) { +void Bignum::Verify(const bignum_t &input) { #ifdef DEBUG // Size must be >= 4 - idx_t varint_bytes = input.GetSize(); - if (varint_bytes < 4) { - throw InternalException("Varint number of bytes is invalid, current number of bytes is %d", varint_bytes); + idx_t bignum_bytes = input.data.GetSize(); + if (bignum_bytes < 4) { + throw InternalException("Bignum number of bytes is invalid, current number of bytes is %d", bignum_bytes); } // Bytes in header must quantify the number of data bytes - auto varint_ptr = input.GetData(); - bool is_negative = (varint_ptr[0] & 0x80) == 0; + auto bignum_ptr = input.data.GetData(); + bool is_negative = (bignum_ptr[0] & 0x80) == 0; uint32_t number_of_bytes = 0; + if (bignum_bytes == 4 && is_negative) { + // There is only one invalid value, which is -0 + if (bignum_ptr[3] == static_cast(0xFF)) { + throw InternalException("Bignum value -0 is not allowed in the Bignum specification."); + } + } + char mask = 0x7F; if (is_negative) { - number_of_bytes |= static_cast(~varint_ptr[0] & mask) << 16 & 0xFF0000; - number_of_bytes |= static_cast(~varint_ptr[1]) << 8 & 0xFF00; + number_of_bytes |= static_cast(~bignum_ptr[0] & mask) << 16 & 0xFF0000; + number_of_bytes |= static_cast(~bignum_ptr[1]) << 8 & 0xFF00; ; - number_of_bytes |= static_cast(~varint_ptr[2]) & 0xFF; + number_of_bytes |= static_cast(~bignum_ptr[2]) & 0xFF; } else { - number_of_bytes |= static_cast(varint_ptr[0] & mask) << 16 & 0xFF0000; - number_of_bytes |= static_cast(varint_ptr[1]) << 8 & 0xFF00; - number_of_bytes |= static_cast(varint_ptr[2]) & 0xFF; + number_of_bytes |= static_cast(bignum_ptr[0] & mask) << 16 & 0xFF0000; + number_of_bytes |= static_cast(bignum_ptr[1]) << 8 & 0xFF00; + number_of_bytes |= static_cast(bignum_ptr[2]) & 0xFF; } - if (number_of_bytes != varint_bytes - 3) { - throw InternalException("The number of bytes set in the Varint header: %d bytes. Does not " - "match the number of bytes encountered as the varint data: %d bytes.", - number_of_bytes, varint_bytes - 3); + if (number_of_bytes != bignum_bytes - 3) { + throw InternalException("The number of bytes set in the Bignum header: %d bytes. 
Does not " + "match the number of bytes encountered as the bignum data: %d bytes.", + number_of_bytes, bignum_bytes - 3); } // No bytes between 4 and end can be 0, unless total size == 4 - if (varint_bytes > 4) { + if (bignum_bytes > 4) { if (is_negative) { - if (static_cast(~varint_ptr[3]) == 0) { - throw InternalException("Invalid top data bytes set to 0 for VARINT values"); + if (static_cast(~bignum_ptr[3]) == 0) { + throw InternalException("Invalid top data bytes set to 0 for BIGNUM values"); } } else { - if (varint_ptr[3] == 0) { - throw InternalException("Invalid top data bytes set to 0 for VARINT values"); + if (bignum_ptr[3] == 0) { + throw InternalException("Invalid top data bytes set to 0 for BIGNUM values"); } } } #endif } -void Varint::SetHeader(char *blob, uint64_t number_of_bytes, bool is_negative) { +void Bignum::SetHeader(char *blob, uint64_t number_of_bytes, bool is_negative) { uint32_t header = static_cast(number_of_bytes); // Set MSBit of 3rd byte header |= 0x00800000; @@ -62,35 +70,36 @@ void Varint::SetHeader(char *blob, uint64_t number_of_bytes, bool is_negative) { } // Creates a blob representing the value 0 -string_t Varint::InitializeVarintZero(Vector &result) { - uint32_t blob_size = 1 + VARINT_HEADER_SIZE; +bignum_t Bignum::InitializeBignumZero(Vector &result) { + uint32_t blob_size = 1 + BIGNUM_HEADER_SIZE; auto blob = StringVector::EmptyString(result, blob_size); auto writable_blob = blob.GetDataWriteable(); SetHeader(writable_blob, 1, false); writable_blob[3] = 0; blob.Finalize(); - return blob; + const bignum_t result_bignum(blob); + return result_bignum; } -string Varint::InitializeVarintZero() { - uint32_t blob_size = 1 + VARINT_HEADER_SIZE; +string Bignum::InitializeBignumZero() { + uint32_t blob_size = 1 + BIGNUM_HEADER_SIZE; string result(blob_size, '0'); SetHeader(&result[0], 1, false); result[3] = 0; return result; } -int Varint::CharToDigit(char c) { +int Bignum::CharToDigit(char c) { return c - '0'; } -char Varint::DigitToChar(int digit) { +char Bignum::DigitToChar(int digit) { // FIXME: this would be the proper solution: // return UnsafeNumericCast(digit + '0'); return static_cast(digit + '0'); } -bool Varint::VarcharFormatting(const string_t &value, idx_t &start_pos, idx_t &end_pos, bool &is_negative, +bool Bignum::VarcharFormatting(const string_t &value, idx_t &start_pos, idx_t &end_pos, bool &is_negative, bool &is_zero) { // If it's empty we error if (value.Empty()) { @@ -153,7 +162,7 @@ bool Varint::VarcharFormatting(const string_t &value, idx_t &start_pos, idx_t &e return true; } -void Varint::GetByteArray(vector &byte_array, bool &is_negative, const string_t &blob) { +void Bignum::GetByteArray(vector &byte_array, bool &is_negative, const string_t &blob) { if (blob.GetSize() < 4) { throw InvalidInputException("Invalid blob size."); } @@ -173,10 +182,10 @@ void Varint::GetByteArray(vector &byte_array, bool &is_negative, const } } -string Varint::FromByteArray(uint8_t *data, idx_t size, bool is_negative) { - string result(VARINT_HEADER_SIZE + size, '0'); +string Bignum::FromByteArray(uint8_t *data, idx_t size, bool is_negative) { + string result(BIGNUM_HEADER_SIZE + size, '0'); SetHeader(&result[0], size, is_negative); - uint8_t *result_data = reinterpret_cast(&result[VARINT_HEADER_SIZE]); + uint8_t *result_data = reinterpret_cast(&result[BIGNUM_HEADER_SIZE]); if (is_negative) { for (idx_t i = 0; i < size; i++) { result_data[i] = ~data[i]; @@ -190,11 +199,11 @@ string Varint::FromByteArray(uint8_t *data, idx_t size, bool is_negative) { } // 
Following CPython and Knuth (TAOCP, Volume 2 (3rd edn), section 4.4, Method 1b). -string Varint::VarIntToVarchar(const string_t &blob) { +string Bignum::BignumToVarchar(const bignum_t &blob) { string decimal_string; vector byte_array; bool is_negative; - GetByteArray(byte_array, is_negative, blob); + GetByteArray(byte_array, is_negative, blob.data); vector digits; // Rounding byte_array to digit_bytes multiple size, so that we can process every digit_bytes bytes // at a time without if check in the for loop @@ -244,21 +253,21 @@ string Varint::VarIntToVarchar(const string_t &blob) { return decimal_string; } -string Varint::VarcharToVarInt(const string_t &value) { +string Bignum::VarcharToBignum(const string_t &value) { idx_t start_pos, end_pos; bool is_negative, is_zero; if (!VarcharFormatting(value, start_pos, end_pos, is_negative, is_zero)) { - throw ConversionException("Could not convert string \'%s\' to Varint", value.GetString()); + throw ConversionException("Could not convert string \'%s\' to Bignum", value.GetString()); } if (is_zero) { // Return Value 0 - return InitializeVarintZero(); + return InitializeBignumZero(); } auto int_value_char = value.GetData(); idx_t actual_size = end_pos - start_pos; // we initalize result with space for our header - string result(VARINT_HEADER_SIZE, '0'); + string result(BIGNUM_HEADER_SIZE, '0'); unsafe_vector digits; // The max number a uint64_t can represent is 18.446.744.073.709.551.615 @@ -302,24 +311,24 @@ string Varint::VarcharToVarInt(const string_t &value) { result.push_back(static_cast(remainder)); } } - std::reverse(result.begin() + VARINT_HEADER_SIZE, result.end()); - // Set header after we know the size of the varint - SetHeader(&result[0], result.size() - VARINT_HEADER_SIZE, is_negative); + std::reverse(result.begin() + BIGNUM_HEADER_SIZE, result.end()); + // Set header after we know the size of the bignum + SetHeader(&result[0], result.size() - BIGNUM_HEADER_SIZE, is_negative); return result; } -bool Varint::VarintToDouble(const string_t &blob, double &result, bool &strict) { +bool Bignum::BignumToDouble(const bignum_t &blob, double &result, bool &strict) { result = 0; - if (blob.GetSize() < 4) { + if (blob.data.GetSize() < 4) { throw InvalidInputException("Invalid blob size."); } - auto blob_ptr = blob.GetData(); + auto blob_ptr = blob.data.GetData(); // Determine if the number is negative bool is_negative = (blob_ptr[0] & 0x80) == 0; idx_t byte_pos = 0; - for (idx_t i = blob.GetSize() - 1; i > 2; i--) { + for (idx_t i = blob.data.GetSize() - 1; i > 2; i--) { if (is_negative) { result += static_cast(~blob_ptr[i]) * pow(256, static_cast(byte_pos)); } else { @@ -333,7 +342,7 @@ bool Varint::VarintToDouble(const string_t &blob, double &result, bool &strict) } if (!std::isfinite(result)) { // We throw an error - throw ConversionException("Could not convert varint '%s' to Double", VarIntToVarchar(blob)); + throw ConversionException("Could not convert bignum '%s' to Double", BignumToVarchar(blob)); } return true; } diff --git a/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp b/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp index 4a46c6734..eff1186b0 100644 --- a/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp +++ b/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp @@ -3,6 +3,7 @@ #include "duckdb/common/radix_partitioning.hpp" #include "duckdb/common/types/row/tuple_data_iterator.hpp" #include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/common/printer.hpp" namespace duckdb { diff --git 
a/src/duckdb/src/common/types/row/tuple_data_collection.cpp b/src/duckdb/src/common/types/row/tuple_data_collection.cpp index ca6061ee7..ffd4a2b4c 100644 --- a/src/duckdb/src/common/types/row/tuple_data_collection.cpp +++ b/src/duckdb/src/common/types/row/tuple_data_collection.cpp @@ -54,6 +54,10 @@ void TupleDataCollection::GetAllColumnIDs(vector &column_ids) { GetAllColumnIDsInternal(column_ids, layout.ColumnCount()); } +shared_ptr TupleDataCollection::GetLayoutPtr() const { + return layout_ptr; +} + const TupleDataLayout &TupleDataCollection::GetLayout() const { return layout; } @@ -118,11 +122,17 @@ void TupleDataCollection::DestroyChunks(const idx_t chunk_idx_begin, const idx_t } if (!layout.AllConstant()) { + if (chunk_begin.heap_block_ids.Empty()) { + return; + } const auto heap_block_begin = chunk_begin.heap_block_ids.Start(); if (chunk_idx_end == ChunkCount()) { segment.allocator->DestroyHeapBlocks(heap_block_begin, segment.allocator->HeapBlockCount()); } else { auto &chunk_end = segment.chunks[chunk_idx_end]; + if (chunk_end.heap_block_ids.Empty()) { + return; + } const auto heap_block_end = chunk_end.heap_block_ids.Start(); segment.allocator->DestroyHeapBlocks(heap_block_begin, heap_block_end); } diff --git a/src/duckdb/src/common/types/value.cpp b/src/duckdb/src/common/types/value.cpp index e379f97d4..54e703707 100644 --- a/src/duckdb/src/common/types/value.cpp +++ b/src/duckdb/src/common/types/value.cpp @@ -23,7 +23,7 @@ #include "duckdb/common/types/cast_helpers.hpp" #include "duckdb/function/cast/cast_function_set.hpp" #include "duckdb/main/error_manager.hpp" -#include "duckdb/common/types/varint.hpp" +#include "duckdb/common/types/bignum.hpp" #include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/serializer/deserializer.hpp" @@ -276,8 +276,8 @@ Value Value::MinimumValue(const LogicalType &type) { } case LogicalTypeId::ENUM: return Value::ENUM(0, type); - case LogicalTypeId::VARINT: - return Value::VARINT(Varint::VarcharToVarInt( + case LogicalTypeId::BIGNUM: + return Value::BIGNUM(Bignum::VarcharToBignum( "-179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540" "4589535143824642343213268894641827684675467035375169860499105765512820762454900903893289440758685084551339" "42304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368")); @@ -368,8 +368,8 @@ Value Value::MaximumValue(const LogicalType &type) { auto enum_size = EnumType::GetSize(type); return Value::ENUM(enum_size - (enum_size ? 
1 : 0), type); } - case LogicalTypeId::VARINT: - return Value::VARINT(Varint::VarcharToVarInt( + case LogicalTypeId::BIGNUM: + return Value::BIGNUM(Bignum::VarcharToBignum( "1797693134862315708145274237317043567980705675258449965989174768031572607800285387605895586327668781715404" "5895351438246423432132688946418276846754670353751698604991057655128207624549009038932894407586850845513394" "2304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368")); @@ -891,12 +891,12 @@ Value Value::BLOB(const_data_ptr_t data, idx_t len) { return result; } -Value Value::VARINT(const_data_ptr_t data, idx_t len) { - return VARINT(string(const_char_ptr_cast(data), len)); +Value Value::BIGNUM(const_data_ptr_t data, idx_t len) { + return BIGNUM(string(const_char_ptr_cast(data), len)); } -Value Value::VARINT(const string &data) { - Value result(LogicalType::VARINT); +Value Value::BIGNUM(const string &data) { + Value result(LogicalType::BIGNUM); result.is_null = false; result.value_info_ = make_shared_ptr(data); return result; @@ -1487,6 +1487,11 @@ DUCKDB_API string_t Value::GetValueUnsafe() const { return string_t(StringValue::Get(*this)); } +template <> +DUCKDB_API bignum_t Value::GetValueUnsafe() const { + return bignum_t(StringValue::Get(*this)); +} + template <> float Value::GetValueUnsafe() const { D_ASSERT(type_.InternalType() == PhysicalType::FLOAT); diff --git a/src/duckdb/src/common/types/vector.cpp b/src/duckdb/src/common/types/vector.cpp index 4edfab4d9..3beae89eb 100644 --- a/src/duckdb/src/common/types/vector.cpp +++ b/src/duckdb/src/common/types/vector.cpp @@ -15,7 +15,7 @@ #include "duckdb/common/types/sel_cache.hpp" #include "duckdb/common/types/value.hpp" #include "duckdb/common/types/value_map.hpp" -#include "duckdb/common/types/varint.hpp" +#include "duckdb/common/types/bignum.hpp" #include "duckdb/common/types/vector_cache.hpp" #include "duckdb/common/uhugeint.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" @@ -147,10 +147,9 @@ void Vector::ReferenceAndSetType(const Vector &other) { void Vector::Reinterpret(const Vector &other) { vector_type = other.vector_type; -#ifdef DEBUG auto &this_type = GetType(); auto &other_type = other.GetType(); - +#ifdef DEBUG auto type_is_same = other_type == this_type; bool this_is_nested = this_type.IsNested(); bool other_is_nested = other_type.IsNested(); @@ -163,7 +162,7 @@ void Vector::Reinterpret(const Vector &other) { D_ASSERT((not_nested && type_size_equal) || type_is_same); #endif AssignSharedPointer(buffer, other.buffer); - if (vector_type == VectorType::DICTIONARY_VECTOR) { + if (vector_type == VectorType::DICTIONARY_VECTOR && other_type != this_type) { Vector new_vector(GetType(), nullptr); new_vector.Reinterpret(DictionaryVector::Child(other)); auxiliary = make_shared_ptr(std::move(new_vector)); @@ -590,7 +589,7 @@ Value Vector::GetValueInternal(const Vector &v_p, idx_t index_p) { int64_t start, increment; SequenceVector::GetSequence(*vector, start, increment); return Value::Numeric(vector->GetType(), - static_cast(start + static_cast(increment) * index)); + start + static_cast(static_cast(increment) * index)); } default: throw InternalException("Unimplemented vector type for Vector::GetValue"); @@ -711,9 +710,9 @@ Value Vector::GetValueInternal(const Vector &v_p, idx_t index_p) { auto str = reinterpret_cast(data)[index]; return Value::BLOB(const_data_ptr_cast(str.GetData()), str.GetSize()); } - case LogicalTypeId::VARINT: { - auto str = reinterpret_cast(data)[index]; - return 
Value::VARINT(const_data_ptr_cast(str.GetData()), str.GetSize()); + case LogicalTypeId::BIGNUM: { + auto str = reinterpret_cast(data)[index]; + return Value::BIGNUM(const_data_ptr_cast(str.data.GetData()), str.data.GetSize()); } case LogicalTypeId::AGGREGATE_STATE: { auto str = reinterpret_cast(data)[index]; @@ -840,7 +839,7 @@ string Vector::ToString(idx_t count) const { int64_t start, increment; SequenceVector::GetSequence(*this, start, increment); for (idx_t i = 0; i < count; i++) { - retval += to_string(static_cast(start + static_cast(increment) * i)) + + retval += to_string(start + static_cast(static_cast(increment) * i)) + (i == count - 1 ? "" : ", "); } break; @@ -1607,7 +1606,7 @@ void Vector::Verify(Vector &vector_p, const SelectionVector &sel_p, idx_t count) } } - if (type.id() == LogicalTypeId::VARINT) { + if (type.id() == LogicalTypeId::BIGNUM) { switch (vtype) { case VectorType::FLAT_VECTOR: { auto &validity = FlatVector::Validity(*vector); @@ -1615,7 +1614,7 @@ void Vector::Verify(Vector &vector_p, const SelectionVector &sel_p, idx_t count) for (idx_t i = 0; i < count; i++) { auto oidx = sel->get_index(i); if (validity.RowIsValid(oidx)) { - Varint::Verify(strings[oidx]); + Bignum::Verify(static_cast(strings[oidx])); } } } break; diff --git a/src/duckdb/src/common/value_operations/comparison_operations.cpp b/src/duckdb/src/common/value_operations/comparison_operations.cpp index 78dd5244e..455b6a72a 100644 --- a/src/duckdb/src/common/value_operations/comparison_operations.cpp +++ b/src/duckdb/src/common/value_operations/comparison_operations.cpp @@ -65,38 +65,11 @@ inline bool ValuePositionComparator::Final(const Value &lhs, return ValueOperations::NotDistinctFrom(lhs, rhs); } -// Non-strict inequalities must use strict comparisons for Definite -template <> -bool ValuePositionComparator::Definite(const Value &lhs, const Value &rhs) { - return !ValuePositionComparator::Definite(lhs, rhs); -} - template <> bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { return ValueOperations::DistinctGreaterThan(lhs, rhs); } -template <> -bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { - return !ValuePositionComparator::Final(lhs, rhs); -} - -template <> -bool ValuePositionComparator::Definite(const Value &lhs, const Value &rhs) { - return !ValuePositionComparator::Definite(rhs, lhs); -} - -template <> -bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { - return !ValuePositionComparator::Final(rhs, lhs); -} - -// Strict inequalities just use strict for both Definite and Final -template <> -bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { - return ValuePositionComparator::Final(rhs, lhs); -} - template static bool TemplatedBooleanOperation(const Value &left, const Value &right) { const auto &left_type = left.type(); diff --git a/src/duckdb/src/common/vector_operations/is_distinct_from.cpp b/src/duckdb/src/common/vector_operations/is_distinct_from.cpp index 054ef2173..e57f9738d 100644 --- a/src/duckdb/src/common/vector_operations/is_distinct_from.cpp +++ b/src/duckdb/src/common/vector_operations/is_distinct_from.cpp @@ -446,61 +446,6 @@ idx_t PositionComparator::Final(Vector &left, Vector &righ return VectorOperations::NestedNotEquals(left, right, &sel, count, true_sel, false_sel, null_mask); } -// Non-strict inequalities must use strict comparisons for Definite -template <> -idx_t PositionComparator::Definite(Vector &left, Vector &right, - const SelectionVector &sel, idx_t count, - optional_ptr 
true_sel, - SelectionVector &false_sel, - optional_ptr null_mask) { - return VectorOperations::DistinctGreaterThan(right, left, &sel, count, true_sel, &false_sel, null_mask); -} - -template <> -idx_t PositionComparator::Final(Vector &left, Vector &right, const SelectionVector &sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return VectorOperations::DistinctGreaterThanEquals(right, left, &sel, count, true_sel, false_sel, null_mask); -} - -template <> -idx_t PositionComparator::Definite(Vector &left, Vector &right, - const SelectionVector &sel, idx_t count, - optional_ptr true_sel, - SelectionVector &false_sel, - optional_ptr null_mask) { - return VectorOperations::DistinctGreaterThan(left, right, &sel, count, true_sel, &false_sel, null_mask); -} - -template <> -idx_t PositionComparator::Final(Vector &left, Vector &right, - const SelectionVector &sel, idx_t count, - optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return VectorOperations::DistinctGreaterThanEquals(left, right, &sel, count, true_sel, false_sel, null_mask); -} - -// Strict inequalities just use strict for both Definite and Final -template <> -idx_t PositionComparator::Final(Vector &left, Vector &right, const SelectionVector &sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return VectorOperations::DistinctGreaterThan(right, left, &sel, count, true_sel, false_sel, null_mask); -} - -template <> -idx_t PositionComparator::Final(Vector &left, Vector &right, - const SelectionVector &sel, idx_t count, - optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - // DistinctGreaterThan has NULLs last - return VectorOperations::DistinctGreaterThan(right, left, &sel, count, true_sel, false_sel, null_mask); -} - template <> idx_t PositionComparator::Final(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, optional_ptr true_sel, diff --git a/src/duckdb/src/execution/column_binding_resolver.cpp b/src/duckdb/src/execution/column_binding_resolver.cpp index 8c031d4fe..d9730eb85 100644 --- a/src/duckdb/src/execution/column_binding_resolver.cpp +++ b/src/duckdb/src/execution/column_binding_resolver.cpp @@ -112,21 +112,22 @@ void ColumnBindingResolver::VisitOperator(LogicalOperator &op) { //! 
We want to execute the normal path, but also add a dummy 'excluded' binding if there is a // ON CONFLICT DO UPDATE clause auto &insert_op = op.Cast(); - if (insert_op.action_type != OnConflictAction::THROW) { + if (insert_op.on_conflict_info.action_type != OnConflictAction::THROW) { // Get the bindings from the children VisitOperatorChildren(op); auto column_count = insert_op.table.GetColumns().PhysicalColumnCount(); - auto dummy_bindings = LogicalOperator::GenerateColumnBindings(insert_op.excluded_table_index, column_count); + auto dummy_bindings = + LogicalOperator::GenerateColumnBindings(insert_op.on_conflict_info.excluded_table_index, column_count); // Now insert our dummy bindings at the start of the bindings, // so the first 'column_count' indices of the chunk are reserved for our 'excluded' columns bindings.insert(bindings.begin(), dummy_bindings.begin(), dummy_bindings.end()); // TODO: fill types in too (clearing skips type checks) types.clear(); - if (insert_op.on_conflict_condition) { - VisitExpression(&insert_op.on_conflict_condition); + if (insert_op.on_conflict_info.on_conflict_condition) { + VisitExpression(&insert_op.on_conflict_info.on_conflict_condition); } - if (insert_op.do_update_condition) { - VisitExpression(&insert_op.do_update_condition); + if (insert_op.on_conflict_info.do_update_condition) { + VisitExpression(&insert_op.on_conflict_info.do_update_condition); } VisitOperatorExpressions(op); bindings = op.GetColumnBindings(); diff --git a/src/duckdb/src/execution/expression_executor.cpp b/src/duckdb/src/execution/expression_executor.cpp index 9b46463bc..98daf3017 100644 --- a/src/duckdb/src/execution/expression_executor.cpp +++ b/src/duckdb/src/execution/expression_executor.cpp @@ -4,12 +4,12 @@ #include "duckdb/execution/execution_context.hpp" #include "duckdb/storage/statistics/base_statistics.hpp" #include "duckdb/planner/expression/list.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { ExpressionExecutor::ExpressionExecutor(ClientContext &context) : context(&context) { - auto &config = DBConfig::GetConfig(context); - debug_vector_verification = config.options.debug_verify_vector; + debug_vector_verification = DBConfig::GetSetting(context); } ExpressionExecutor::ExpressionExecutor(ClientContext &context, const Expression *expression) diff --git a/src/duckdb/src/execution/index/art/art.cpp b/src/duckdb/src/execution/index/art/art.cpp index 6d1136b22..38c7de719 100644 --- a/src/duckdb/src/execution/index/art/art.cpp +++ b/src/duckdb/src/execution/index/art/art.cpp @@ -35,7 +35,7 @@ struct ARTIndexScanState : public IndexScanState { ExpressionType expressions[2]; bool checked = false; //! All scanned row IDs. - unsafe_vector row_ids; + set row_ids; }; //===--------------------------------------------------------------------===// @@ -207,6 +207,8 @@ unique_ptr ART::TryInitializeScan(const Expression &expr, const high_comparison_type = between.upper_inclusive ? ExpressionType::COMPARE_LESSTHANOREQUALTO : ExpressionType::COMPARE_LESSTHAN; } + // FIXME: add another if...else... to match rewritten BETWEEN, + // i.e., WHERE i BETWEEN 50 AND 1502 is rewritten to CONJUNCTION_AND. // We cannot use an index scan. 
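The hunk above folds the loose ON CONFLICT members of the insert operator into a single on_conflict_info object. A minimal sketch of that grouping, with stand-in types; the names follow this diff, but the exact member set in DuckDB is assumed:

#include <cstdint>
#include <memory>

// Stand-ins for the DuckDB types the diff refers to.
enum class OnConflictAction : uint8_t { THROW, NOTHING, UPDATE };
struct Expression {};

// Hypothetical grouping mirroring insert_op.on_conflict_info above.
struct OnConflictInfo {
	// THROW means there is no DO UPDATE / DO NOTHING clause, so the binding
	// resolver skips the dummy 'excluded' bindings.
	OnConflictAction action_type = OnConflictAction::THROW;
	// Table index used to generate the 'excluded' pseudo-table bindings.
	uint64_t excluded_table_index = 0;
	// Optional clause conditions; empty when the query does not specify them.
	std::unique_ptr<Expression> on_conflict_condition;
	std::unique_ptr<Expression> do_update_condition;
};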
if (equal_value.IsNull() && low_value.IsNull() && high_value.IsNull()) { @@ -495,7 +497,7 @@ bool ART::Construct(unsafe_vector &keys, unsafe_vector &row_ids, } #ifdef DEBUG - unsafe_vector row_ids_debug; + set row_ids_debug; Iterator it(*this); it.FindMinimum(tree); ARTKey empty_key = ARTKey(); @@ -682,7 +684,7 @@ void ART::Erase(Node &node, reference key, idx_t depth, reference< // This is the root node, which can be a leaf with possible prefix nodes. if (next.get().GetType() == NType::LEAF_INLINED) { if (next.get().GetRowId() == row_id.get().GetRowId()) { - Node::Free(*this, node); + Node::FreeTree(*this, node); } return; } @@ -761,7 +763,7 @@ void ART::Erase(Node &node, reference key, idx_t depth, reference< // Point and range lookups //===--------------------------------------------------------------------===// -bool ART::SearchEqual(ARTKey &key, idx_t max_count, unsafe_vector &row_ids) { +bool ART::SearchEqual(ARTKey &key, idx_t max_count, set &row_ids) { auto leaf = ARTOperator::Lookup(*this, tree, key, 0); if (!leaf) { return true; @@ -773,7 +775,7 @@ bool ART::SearchEqual(ARTKey &key, idx_t max_count, unsafe_vector &row_id return it.Scan(empty_key, max_count, row_ids, false); } -bool ART::SearchGreater(ARTKey &key, bool equal, idx_t max_count, unsafe_vector &row_ids) { +bool ART::SearchGreater(ARTKey &key, bool equal, idx_t max_count, set &row_ids) { if (!tree.HasMetadata()) { return true; } @@ -791,7 +793,7 @@ bool ART::SearchGreater(ARTKey &key, bool equal, idx_t max_count, unsafe_vector< return it.Scan(ARTKey(), max_count, row_ids, false); } -bool ART::SearchLess(ARTKey &upper_bound, bool equal, idx_t max_count, unsafe_vector &row_ids) { +bool ART::SearchLess(ARTKey &upper_bound, bool equal, idx_t max_count, set &row_ids) { if (!tree.HasMetadata()) { return true; } @@ -810,7 +812,7 @@ bool ART::SearchLess(ARTKey &upper_bound, bool equal, idx_t max_count, unsafe_ve } bool ART::SearchCloseRange(ARTKey &lower_bound, ARTKey &upper_bound, bool left_equal, bool right_equal, idx_t max_count, - unsafe_vector &row_ids) { + set &row_ids) { // Find the first node that satisfies the left predicate. 
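These ART search routines now collect matches into an ordered set of row IDs rather than an append-only vector, so results come back sorted and de-duplicated even when several probing paths reach the same leaf. A small self-contained sketch of the collection step, assuming row_t is a 64-bit integer; the max_count bail-out mirrors Iterator::Scan:

#include <cstddef>
#include <cstdint>
#include <set>

using row_t = int64_t; // assumption: DuckDB's row identifier type

// Returns false once the result would exceed max_count, like the scan's
// early exit; duplicate inserts are no-ops, and iteration is ascending.
static bool CollectRowId(std::set<row_t> &row_ids, row_t row_id, std::size_t max_count) {
	if (row_ids.size() + 1 > max_count) {
		return false;
	}
	row_ids.insert(row_id);
	return true;
}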
Iterator it(*this); @@ -823,7 +825,7 @@ bool ART::SearchCloseRange(ARTKey &lower_bound, ARTKey &upper_bound, bool left_e return it.Scan(upper_bound, max_count, row_ids, right_equal); } -bool ART::Scan(IndexScanState &state, const idx_t max_count, unsafe_vector &row_ids) { +bool ART::Scan(IndexScanState &state, const idx_t max_count, set &row_ids) { auto &scan_state = state.Cast(); D_ASSERT(scan_state.values[0].type().InternalType() == types[0]); ArenaAllocator arena_allocator(Allocator::Get(db)); @@ -962,27 +964,27 @@ void ART::VerifyLeaf(const Node &leaf, const ARTKey &key, optional_ptr dele Iterator it(*this); it.FindMinimum(leaf); ARTKey empty_key = ARTKey(); - unsafe_vector row_ids; + set row_ids; auto success = it.Scan(empty_key, 2, row_ids, false); if (!success || row_ids.size() != 2) { throw InternalException("VerifyLeaf expects exactly two row IDs to be scanned"); } - if (!deleted_leaf) { - if (manager.AddHit(i, row_ids[0]) || manager.AddSecondHit(i, row_ids[1])) { - conflict_idx = i; + if (deleted_leaf) { + auto deleted_row_id = deleted_leaf->GetRowId(); + for (const auto row_id : row_ids) { + if (deleted_row_id == row_id) { + return; + } } - return; - } - - auto deleted_row_id = deleted_leaf->GetRowId(); - if (deleted_row_id == row_ids[0] || deleted_row_id == row_ids[1]) { - return; } - if (manager.AddHit(i, row_ids[0]) || manager.AddSecondHit(i, row_ids[1])) { + auto row_id_it = row_ids.begin(); + if (manager.AddHit(i, *row_id_it)) { conflict_idx = i; } + row_id_it++; + manager.AddSecondHit(i, *row_id_it); } void ART::VerifyConstraint(DataChunk &chunk, IndexAppendInfo &info, ConflictManager &manager) { @@ -1262,7 +1264,7 @@ void ART::Vacuum(IndexLock &state) { break; } default: - throw InternalException("invalid node type for Vacuum: %s", EnumUtil::ToString(type)); + throw InternalException("invalid node type for Vacuum: %d", type); } const auto idx = Node::GetAllocatorIdx(type); @@ -1301,14 +1303,14 @@ void ART::InitializeMerge(Node &node, unsafe_vector &upper_bounds) { auto handler = [&upper_bounds](Node &node) { const auto type = node.GetType(); if (node.GetType() == NType::LEAF_INLINED) { - return ARTHandlingResult::CONTINUE; + return ARTHandlingResult::NONE; } if (type == NType::LEAF) { throw InternalException("deprecated ART storage in InitializeMerge"); } const auto idx = Node::GetAllocatorIdx(type); node.IncreaseBufferId(upper_bounds[idx]); - return ARTHandlingResult::CONTINUE; + return ARTHandlingResult::NONE; }; ARTScanner scanner(*this, handler, node); diff --git a/src/duckdb/src/execution/index/art/art_merger.cpp b/src/duckdb/src/execution/index/art/art_merger.cpp index 472bbfcb8..70781cbfb 100644 --- a/src/duckdb/src/execution/index/art/art_merger.cpp +++ b/src/duckdb/src/execution/index/art/art_merger.cpp @@ -144,7 +144,7 @@ void ARTMerger::MergeLeaves(NodeEntry &entry) { for (idx_t i = 0; i < bytes.size(); i++) { Node::InsertChild(art, entry.left, bytes[i]); } - Node::Free(art, entry.right); + Node::FreeNode(art, entry.right); } NodeChildren ARTMerger::ExtractChildren(Node &node) { @@ -177,7 +177,7 @@ void ARTMerger::MergeNodes(NodeEntry &entry) { auto children = ExtractChildren(entry.right); // As long as the arena is valid, // the copied-out nodes (and their references) are valid. - Node::Free(art, entry.right); + Node::FreeNode(art, entry.right); // First, we iterate and insert children. // This might grow the node, so we need to do it prior to Emplace. 
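In MergeNodes above, the right node can be released with the shell-only FreeNode because ExtractChildren has already copied its child pointers into the merger's arena; a recursive free at that point would destroy subtrees the left node is about to adopt. A toy illustration of the two free flavours this diff separates (not DuckDB's node layout):

#include <vector>

// Toy node standing in for an ART internal node that owns its children.
struct ToyNode {
	std::vector<ToyNode *> children;
};

// FreeNode-style: release only this node's allocation. Safe when ownership
// of the children has already been transferred elsewhere.
static void FreeNodeShell(ToyNode *node) {
	delete node;
}

// FreeTree-style: release the node and everything reachable below it.
static void FreeSubtree(ToyNode *node) {
	if (!node) {
		return;
	}
	for (auto *child : node->children) {
		FreeSubtree(child);
	}
	delete node;
}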
@@ -300,8 +300,7 @@ void ARTMerger::MergePrefixes(NodeEntry &entry) { // Free the right prefix, but keep the reference to its child alive. // Then, iterate on the left and right (reduced) child. auto r_child = *r_prefix.ptr; - r_prefix.ptr->Clear(); - Node::Free(art, entry.right); + Node::FreeNode(art, entry.right); entry.right = r_child; auto depth = entry.depth + l_prefix.data[count]; diff --git a/src/duckdb/src/execution/index/art/base_leaf.cpp b/src/duckdb/src/execution/index/art/base_leaf.cpp index a0d88e88d..a694ca3b5 100644 --- a/src/duckdb/src/execution/index/art/base_leaf.cpp +++ b/src/duckdb/src/execution/index/art/base_leaf.cpp @@ -76,15 +76,16 @@ void Node7Leaf::DeleteByte(ART &art, Node &node, Node &prefix, const uint8_t byt auto remainder = UnsafeNumericCast(row_id.GetRowId()) & AND_LAST_BYTE; remainder |= UnsafeNumericCast(n7.key[0]); - n7.count--; - Node::Free(art, node); - + // Free the prefix (nodes) and inline the remainder. if (prefix.GetType() == NType::PREFIX) { - Node::Free(art, prefix); + Node::FreeTree(art, prefix); Leaf::New(prefix, UnsafeNumericCast(remainder)); - } else { - Leaf::New(node, UnsafeNumericCast(remainder)); + return; } + + // Free the Node7Leaf and inline the remainder. + Node::FreeNode(art, node); + Leaf::New(node, UnsafeNumericCast(remainder)); } } @@ -98,8 +99,7 @@ void Node7Leaf::ShrinkNode15Leaf(ART &art, Node &node7_leaf, Node &node15_leaf) n7.key[i] = n15.key[i]; } - n15.count = 0; - Node::Free(art, node15_leaf); + Node::FreeNode(art, node15_leaf); } //===--------------------------------------------------------------------===// @@ -139,8 +139,7 @@ void Node15Leaf::GrowNode7Leaf(ART &art, Node &node15_leaf, Node &node7_leaf) { n15.key[i] = n7.key[i]; } - n7.count = 0; - Node::Free(art, node7_leaf); + Node::FreeNode(art, node7_leaf); } void Node15Leaf::ShrinkNode256Leaf(ART &art, Node &node15_leaf, Node &node256_leaf) { @@ -156,7 +155,7 @@ void Node15Leaf::ShrinkNode256Leaf(ART &art, Node &node15_leaf, Node &node256_le } } - Node::Free(art, node256_leaf); + Node::FreeNode(art, node256_leaf); } } // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/base_node.cpp b/src/duckdb/src/execution/index/art/base_node.cpp index 3e6713d2b..95e69716f 100644 --- a/src/duckdb/src/execution/index/art/base_node.cpp +++ b/src/duckdb/src/execution/index/art/base_node.cpp @@ -43,7 +43,7 @@ NodeHandle> BaseNode::DeleteChildIntern } // Free the child and decrease the count. - Node::Free(art, n.children[child_pos]); + Node::FreeTree(art, n.children[child_pos]); n.count--; // Possibly move children backwards. @@ -89,13 +89,12 @@ void Node4::DeleteChild(ART &art, Node &node, Node &prefix, const uint8_t byte, } // Compress one-way nodes. 
- n.count--; child = n.children[0]; remainder = n.key[0]; } auto old_status = node.GetGateStatus(); - Node::Free(art, node); + Node::FreeNode(art, node); Prefix::Concat(art, prefix, remainder, old_status, child, status); } @@ -113,9 +112,8 @@ void Node4::ShrinkNode16(ART &art, Node &node4, Node &node16) { n4.key[i] = n16.key[i]; n4.children[i] = n16.children[i]; } - n16.count = 0; } - Node::Free(art, node16); + Node::FreeNode(art, node16); } //===--------------------------------------------------------------------===// @@ -165,9 +163,8 @@ void Node16::GrowNode4(ART &art, Node &node16, Node &node4) { n16.key[i] = n4.key[i]; n16.children[i] = n4.children[i]; } - n4.count = 0; } - Node::Free(art, node4); + Node::FreeNode(art, node4); } void Node16::ShrinkNode48(ART &art, Node &node16, Node &node48) { @@ -187,9 +184,8 @@ void Node16::ShrinkNode48(ART &art, Node &node16, Node &node48) { n16.count++; } } - n48.count = 0; } - Node::Free(art, node48); + Node::FreeNode(art, node48); } } // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/iterator.cpp b/src/duckdb/src/execution/index/art/iterator.cpp index 1c138e1d3..1a88b7262 100644 --- a/src/duckdb/src/execution/index/art/iterator.cpp +++ b/src/duckdb/src/execution/index/art/iterator.cpp @@ -42,7 +42,7 @@ bool IteratorKey::GreaterThan(const ARTKey &key, const bool equal, const uint8_t // Iterator //===--------------------------------------------------------------------===// -bool Iterator::Scan(const ARTKey &upper_bound, const idx_t max_count, unsafe_vector &row_ids, const bool equal) { +bool Iterator::Scan(const ARTKey &upper_bound, const idx_t max_count, set &row_ids, const bool equal) { bool has_next; do { // An empty upper bound indicates that no upper bound exists. @@ -59,7 +59,7 @@ bool Iterator::Scan(const ARTKey &upper_bound, const idx_t max_count, unsafe_vec if (row_ids.size() + 1 > max_count) { return false; } - row_ids.push_back(last_leaf.GetRowId()); + row_ids.insert(last_leaf.GetRowId()); break; case NType::LEAF: if (!Leaf::DeprecatedGetRowIds(art, last_leaf, row_ids, max_count)) { @@ -76,7 +76,7 @@ bool Iterator::Scan(const ARTKey &upper_bound, const idx_t max_count, unsafe_vec } row_id[ROW_ID_SIZE - 1] = byte; ARTKey key(&row_id[0], ROW_ID_SIZE); - row_ids.push_back(key.GetRowId()); + row_ids.insert(key.GetRowId()); if (byte == NumericLimits::Maximum()) { break; } diff --git a/src/duckdb/src/execution/index/art/leaf.cpp b/src/duckdb/src/execution/index/art/leaf.cpp index 949318ce5..f6c7751d6 100644 --- a/src/duckdb/src/execution/index/art/leaf.cpp +++ b/src/duckdb/src/execution/index/art/leaf.cpp @@ -98,7 +98,7 @@ void Leaf::TransformToNested(ART &art, Node &node) { } root.SetGateStatus(GateStatus::GATE_SET); - Node::Free(art, node); + DeprecatedFree(art, node); node = root; } @@ -111,17 +111,17 @@ void Leaf::TransformToDeprecated(ART &art, Node &node) { } // Collect all row IDs and free the nested leaf. - unsafe_vector row_ids; + set row_ids; Iterator it(art); it.FindMinimum(node); ARTKey empty_key = ARTKey(); it.Scan(empty_key, NumericLimits().Maximum(), row_ids, false); - Node::Free(art, node); + Node::FreeTree(art, node); D_ASSERT(row_ids.size() > 1); // Create the deprecated leaves. 
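TransformToDeprecated (continued below) now drains the ordered set with an iterator instead of indexing a vector with copy_count. The drain pattern in isolation, filling fixed-capacity segments from a set; the capacity constant is illustrative, not the real deprecated-leaf capacity:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <set>
#include <vector>

static std::vector<std::vector<int64_t>> DrainIntoSegments(const std::set<int64_t> &row_ids) {
	constexpr std::size_t CAPACITY = 4; // stand-in for the per-leaf capacity
	std::vector<std::vector<int64_t>> segments;
	auto it = row_ids.begin();
	std::size_t remaining = row_ids.size();
	while (remaining) {
		const std::size_t count = std::min(CAPACITY, remaining);
		std::vector<int64_t> segment;
		for (std::size_t i = 0; i < count; i++) {
			segment.push_back(*it++); // advance the iterator; no random access
		}
		segments.push_back(std::move(segment));
		remaining -= count;
	}
	return segments;
}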
idx_t remaining = row_ids.size(); - idx_t copy_count = 0; + auto row_ids_it = row_ids.begin(); reference ref(node); while (remaining) { ref.get() = Node::GetAllocator(art, LEAF).New(); @@ -132,10 +132,9 @@ void Leaf::TransformToDeprecated(ART &art, Node &node) { leaf.count = UnsafeNumericCast(min); for (uint8_t i = 0; i < leaf.count; i++) { - leaf.row_ids[i] = row_ids[copy_count + i]; + leaf.row_ids[i] = *row_ids_it; + row_ids_it++; } - - copy_count += leaf.count; remaining -= leaf.count; ref = leaf.ptr; @@ -149,17 +148,16 @@ void Leaf::TransformToDeprecated(ART &art, Node &node) { void Leaf::DeprecatedFree(ART &art, Node &node) { D_ASSERT(node.GetType() == LEAF); - Node next; while (node.HasMetadata()) { next = Node::Ref(art, node, LEAF).ptr; - Node::GetAllocator(art, LEAF).Free(node); + Node::FreeNode(art, node); node = next; } node.Clear(); } -bool Leaf::DeprecatedGetRowIds(ART &art, const Node &node, unsafe_vector &row_ids, const idx_t max_count) { +bool Leaf::DeprecatedGetRowIds(ART &art, const Node &node, set &row_ids, const idx_t max_count) { D_ASSERT(node.GetType() == LEAF); reference ref(node); @@ -170,7 +168,7 @@ bool Leaf::DeprecatedGetRowIds(ART &art, const Node &node, unsafe_vector return false; } for (uint8_t i = 0; i < leaf.count; i++) { - row_ids.push_back(leaf.row_ids[i]); + row_ids.insert(leaf.row_ids[i]); } ref = leaf.ptr; } diff --git a/src/duckdb/src/execution/index/art/node.cpp b/src/duckdb/src/execution/index/art/node.cpp index 950204629..29d5c8cf9 100644 --- a/src/duckdb/src/execution/index/art/node.cpp +++ b/src/duckdb/src/execution/index/art/node.cpp @@ -44,44 +44,45 @@ void Node::New(ART &art, Node &node, NType type) { Node256::New(art, node); break; default: - throw InternalException("Invalid node type for New: %s.", EnumUtil::ToString(type)); + throw InternalException("Invalid node type for New: %d.", type); } } -void Node::Free(ART &art, Node &node) { - if (!node.HasMetadata()) { - return node.Clear(); - } +void Node::FreeNode(ART &art, Node &node) { + D_ASSERT(node.HasMetadata()); + GetAllocator(art, node.GetType()).Free(node); + node.Clear(); +} - // Free the children. 
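The FreeTree that replaces Node::Free in the next hunk routes the recursive free through an ARTScanner with a handler instead of direct recursion. A minimal stand-in with an explicit stack shows the same shape; the toy node is not DuckDB's, and ARTScanner's exact traversal order is assumed:

#include <vector>

struct StackNode {
	std::vector<StackNode *> children;
};

// Non-recursive subtree free: children are queued before the parent is
// released, so freed memory is never revisited, and deep trees cannot
// overflow the call stack.
static void FreeTreeIterative(StackNode *root) {
	if (!root) {
		return;
	}
	std::vector<StackNode *> stack {root};
	while (!stack.empty()) {
		StackNode *node = stack.back();
		stack.pop_back();
		for (auto *child : node->children) {
			stack.push_back(child);
		}
		delete node; // shell-only free, like Node::FreeNode
	}
}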
- auto type = node.GetType(); - switch (type) { - case NType::PREFIX: - return Prefix::Free(art, node); - case NType::LEAF: - return Leaf::DeprecatedFree(art, node); - case NType::NODE_4: - Node4::Free(art, node); - break; - case NType::NODE_16: - Node16::Free(art, node); - break; - case NType::NODE_48: - Node48::Free(art, node); - break; - case NType::NODE_256: - Node256::Free(art, node); - break; - case NType::LEAF_INLINED: - return node.Clear(); - case NType::NODE_7_LEAF: - case NType::NODE_15_LEAF: - case NType::NODE_256_LEAF: - break; - } +void Node::FreeTree(ART &art, Node &node) { + auto handler = [&art](Node &node) { + const auto type = node.GetType(); + switch (type) { + case NType::LEAF_INLINED: + node.Clear(); + return ARTHandlingResult::NONE; + case NType::LEAF: + Leaf::DeprecatedFree(art, node); + return ARTHandlingResult::NONE; + case NType::NODE_7_LEAF: + case NType::NODE_15_LEAF: + case NType::NODE_256_LEAF: + case NType::PREFIX: + case NType::NODE_4: + case NType::NODE_16: + case NType::NODE_48: + case NType::NODE_256: + break; + default: + throw InternalException("invalid node type for Free: %d", type); + } - GetAllocator(art, type).Free(node); - node.Clear(); + FreeNode(art, node); + return ARTHandlingResult::NONE; + }; + + ARTScanner scanner(art, handler, node); + scanner.Scan(handler); } //===--------------------------------------------------------------------===// @@ -113,7 +114,7 @@ uint8_t Node::GetAllocatorIdx(const NType type) { case NType::NODE_256_LEAF: return 8; default: - throw InternalException("Invalid node type for GetAllocatorIdx: %s.", EnumUtil::ToString(type)); + throw InternalException("Invalid node type for GetAllocatorIdx: %d.", type); } } @@ -135,7 +136,7 @@ void Node::ReplaceChild(const ART &art, const uint8_t byte, const Node child) co case NType::NODE_256: return Ref(art, *this, type).ReplaceChild(byte, child); default: - throw InternalException("Invalid node type for ReplaceChild: %s.", EnumUtil::ToString(type)); + throw InternalException("Invalid node type for ReplaceChild: %d.", type); } } @@ -159,7 +160,7 @@ void Node::InsertChild(ART &art, Node &node, const uint8_t byte, const Node chil case NType::NODE_256_LEAF: return Node256Leaf::InsertByte(art, node, byte); default: - throw InternalException("Invalid node type for InsertChild: %s.", EnumUtil::ToString(type)); + throw InternalException("Invalid node type for InsertChild: %d.", type); } } @@ -188,7 +189,7 @@ void Node::DeleteChild(ART &art, Node &node, Node &prefix, const uint8_t byte, c case NType::NODE_256_LEAF: return Node256Leaf::DeleteByte(art, node, byte); default: - throw InternalException("Invalid node type for DeleteChild: %s.", EnumUtil::ToString(type)); + throw InternalException("Invalid node type for DeleteChild: %d.", type); } } @@ -212,7 +213,7 @@ unsafe_optional_ptr GetChildInternal(ART &art, NODE &node, const uint8_t b return Node256::GetChild(Node::Ref(art, node, type), byte); } default: - throw InternalException("Invalid node type for GetChildInternal: %s.", EnumUtil::ToString(type)); + throw InternalException("Invalid node type for GetChildInternal: %d.", type); } } @@ -239,7 +240,7 @@ unsafe_optional_ptr GetNextChildInternal(ART &art, NODE &node, uint8_t &by case NType::NODE_256: return Node256::GetNextChild(Node::Ref(art, node, type), byte); default: - throw InternalException("Invalid node type for GetNextChildInternal: %s.", EnumUtil::ToString(type)); + throw InternalException("Invalid node type for GetNextChildInternal: %d.", type); } } @@ -259,7 +260,7 @@ bool 
Node::HasByte(ART &art, uint8_t &byte) const { case NType::NODE_256_LEAF: return Ref(art, *this, NType::NODE_256_LEAF).HasByte(byte); default: - throw InternalException("Invalid node type for GetNextByte: %s.", EnumUtil::ToString(type)); + throw InternalException("Invalid node type for GetNextByte: %d.", type); } } @@ -275,7 +276,7 @@ bool Node::GetNextByte(ART &art, uint8_t &byte) const { case NType::NODE_256_LEAF: return Ref(art, *this, NType::NODE_256_LEAF).GetNextByte(byte); default: - throw InternalException("Invalid node type for GetNextByte: %s.", EnumUtil::ToString(type)); + throw InternalException("Invalid node type for GetNextByte: %d.", type); } } @@ -300,7 +301,7 @@ idx_t GetCapacity(NType type) { case NType::NODE_256: return Node256::CAPACITY; default: - throw InternalException("Invalid node type for GetCapacity: %s.", EnumUtil::ToString(type)); + throw InternalException("Invalid node type for GetCapacity: %d.", type); } } @@ -382,7 +383,7 @@ void Node::TransformToDeprecated(ART &art, Node &node, case NType::NODE_256: return TransformToDeprecatedInternal(art, InMemoryRef(art, node, type), deprecated_prefix_allocator); default: - throw InternalException("invalid node type for TransformToDeprecated: %s", EnumUtil::ToString(type)); + throw InternalException("invalid node type for TransformToDeprecated: %d", type); } } @@ -471,7 +472,7 @@ void Node::VerifyAllocations(ART &art, unordered_map &node_count break; } default: - throw InternalException("invalid node type for VerifyAllocations: %s", EnumUtil::ToString(type)); + throw InternalException("invalid node type for VerifyAllocations: %d", type); } node_counts[GetAllocatorIdx(type)]++; return result; diff --git a/src/duckdb/src/execution/index/art/node256.cpp b/src/duckdb/src/execution/index/art/node256.cpp index f08717e13..f5ff96643 100644 --- a/src/duckdb/src/execution/index/art/node256.cpp +++ b/src/duckdb/src/execution/index/art/node256.cpp @@ -4,75 +4,49 @@ namespace duckdb { -Node256 &Node256::New(ART &art, Node &node) { - node = Node::GetAllocator(art, NODE_256).New(); - node.SetMetadata(static_cast(NODE_256)); - auto &n256 = Node::Ref(art, node, NODE_256); - - n256.count = 0; - for (uint16_t i = 0; i < CAPACITY; i++) { - n256.children[i].Clear(); - } - - return n256; -} - -void Node256::Free(ART &art, Node &node) { - auto &n256 = Node::Ref(art, node, NODE_256); - if (!n256.count) { - return; - } - - Iterator(n256, [&](Node &child) { Node::Free(art, child); }); -} - void Node256::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { - auto &n256 = Node::Ref(art, node, NODE_256); - n256.count++; - n256.children[byte] = child; + NodeHandle handle(art, node); + auto &n = handle.Get(); + n.count++; + n.children[byte] = child; } void Node256::DeleteChild(ART &art, Node &node, const uint8_t byte) { - auto &n256 = Node::Ref(art, node, NODE_256); + { + NodeHandle handle(art, node); + auto &n = handle.Get(); - // Free the child and decrease the count. - Node::Free(art, n256.children[byte]); - n256.count--; + // Free the child and decrease the count. + Node::FreeTree(art, n.children[byte]); + n.count--; - // Shrink to Node48. 
- if (n256.count <= SHRINK_THRESHOLD) { - auto node256 = node; - Node48::ShrinkNode256(art, node, node256); + if (n.count >= SHRINK_THRESHOLD) { + return; + } } -} - -void Node256::ReplaceChild(const uint8_t byte, const Node child) { - D_ASSERT(count > SHRINK_THRESHOLD); - auto status = children[byte].GetGateStatus(); - children[byte] = child; - if (status == GateStatus::GATE_SET && child.HasMetadata()) { - children[byte].SetGateStatus(status); - } + // Shrink to Node48. + auto node256 = node; + Node48::ShrinkNode256(art, node, node256); } -Node256 &Node256::GrowNode48(ART &art, Node &node256, Node &node48) { - auto &n48 = Node::Ref(art, node48, NType::NODE_48); - auto &n256 = New(art, node256); - node256.SetGateStatus(node48.GetGateStatus()); +void Node256::GrowNode48(ART &art, Node &node256, Node &node48) { + { + NodeHandle n48_handle(art, node48); + auto &n48 = n48_handle.Get(); - n256.count = n48.count; - for (uint16_t i = 0; i < CAPACITY; i++) { - if (n48.child_index[i] != Node48::EMPTY_MARKER) { - n256.children[i] = n48.children[n48.child_index[i]]; - } else { - n256.children[i].Clear(); + auto n256_handle = New(art, node256); + auto &n256 = n256_handle.Get(); + node256.SetGateStatus(node48.GetGateStatus()); + + n256.count = n48.count; + for (uint16_t i = 0; i < CAPACITY; i++) { + if (n48.child_index[i] != Node48::EMPTY_MARKER) { + n256.children[i] = n48.children[n48.child_index[i]]; + } } } - - n48.count = 0; - Node::Free(art, node48); - return n256; + Node::FreeNode(art, node48); } } // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/node256_leaf.cpp b/src/duckdb/src/execution/index/art/node256_leaf.cpp index d735fcc77..14810b5fe 100644 --- a/src/duckdb/src/execution/index/art/node256_leaf.cpp +++ b/src/duckdb/src/execution/index/art/node256_leaf.cpp @@ -92,9 +92,8 @@ void Node256Leaf::GrowNode15Leaf(ART &art, Node &node256_leaf, Node &node15_leaf for (uint8_t i = 0; i < n15.count; i++) { mask.SetValid(n15.key[i]); } - n15.count = 0; } - Node::Free(art, node15_leaf); + Node::FreeNode(art, node15_leaf); } } // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/node48.cpp b/src/duckdb/src/execution/index/art/node48.cpp index f9ad0460c..0e8d640d4 100644 --- a/src/duckdb/src/execution/index/art/node48.cpp +++ b/src/duckdb/src/execution/index/art/node48.cpp @@ -5,126 +5,94 @@ namespace duckdb { -Node48 &Node48::New(ART &art, Node &node) { - node = Node::GetAllocator(art, NODE_48).New(); - node.SetMetadata(static_cast(NODE_48)); - auto &n48 = Node::Ref(art, node, NODE_48); - - n48.count = 0; - for (uint16_t i = 0; i < Node256::CAPACITY; i++) { - n48.child_index[i] = EMPTY_MARKER; - } - for (uint8_t i = 0; i < CAPACITY; i++) { - n48.children[i].Clear(); - } - - return n48; -} - -void Node48::Free(ART &art, Node &node) { - auto &n48 = Node::Ref(art, node, NODE_48); - if (!n48.count) { - return; - } - - Iterator(n48, [&](Node &child) { Node::Free(art, child); }); -} - void Node48::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { - auto &n48 = Node::Ref(art, node, NODE_48); - - // The node is full. Grow to Node256. - if (n48.count == CAPACITY) { - auto node48 = node; - Node256::GrowNode48(art, node, node48); - Node256::InsertChild(art, node, byte, child); - return; - } - - // Still space. Insert the child. - uint8_t child_pos = n48.count; - if (n48.children[child_pos].HasMetadata()) { - // Find an empty position in the node list. 
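A shape that recurs through these node files: all mutation happens inside a block that owns a NodeHandle, and any resize (GrowNode48, ShrinkNode256) or FreeNode runs only after that block closes. Presumably the handle pins the node's buffer, so freeing or reallocating while it is alive would invalidate the reference it hands out. A schematic with a toy handle, not the real API:

// Toy stand-in for NodeHandle: pins a buffer for its lifetime and hands out
// a reference into it.
struct ToyHandle {
	explicit ToyHandle(int &count) : count(count) { /* pin */ }
	~ToyHandle() { /* unpin */ }
	int &Get() { return count; }
	int &count;
};

static constexpr int SHRINK_THRESHOLD = 12; // illustrative value

// Mirrors the reworked DeleteChild flow: mutate under the pin, return early
// if no resize is needed, and only shrink once the handle is gone.
static void DeleteChildSketch(int &node_count, void (*shrink)(int &)) {
	{
		ToyHandle handle(node_count);
		handle.Get()--; // free the child, decrease the count
		if (handle.Get() >= SHRINK_THRESHOLD) {
			return;
		}
	} // pin released here
	shrink(node_count); // safe: may move or free the underlying buffer
}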
- child_pos = 0; - while (n48.children[child_pos].HasMetadata()) { - child_pos++; + { + NodeHandle handle(art, node); + auto &n = handle.Get(); + + if (n.count != CAPACITY) { + // Still space. Insert the child. + // Find an empty position in the node list. + auto child_pos = n.count; + if (n.children[child_pos].HasMetadata()) { + child_pos = 0; + while (n.children[child_pos].HasMetadata()) { + child_pos++; + } + } + + n.children[child_pos] = child; + n.child_index[byte] = child_pos; + n.count++; + return; } } - n48.children[child_pos] = child; - n48.child_index[byte] = child_pos; - n48.count++; + // The node is full. + // Grow to Node256. + auto node48 = node; + Node256::GrowNode48(art, node, node48); + Node256::InsertChild(art, node, byte, child); } void Node48::DeleteChild(ART &art, Node &node, const uint8_t byte) { - auto &n48 = Node::Ref(art, node, NODE_48); + { + NodeHandle handle(art, node); + auto &n = handle.Get(); - // Free the child and decrease the count. - Node::Free(art, n48.children[n48.child_index[byte]]); - n48.child_index[byte] = EMPTY_MARKER; - n48.count--; + // Free the child and decrease the count. + Node::FreeTree(art, n.children[n.child_index[byte]]); + n.child_index[byte] = EMPTY_MARKER; + n.count--; - // Shrink to Node16. - if (n48.count < SHRINK_THRESHOLD) { - auto node48 = node; - Node16::ShrinkNode48(art, node, node48); + if (n.count >= SHRINK_THRESHOLD) { + return; + } } -} - -void Node48::ReplaceChild(const uint8_t byte, const Node child) { - D_ASSERT(count >= SHRINK_THRESHOLD); - auto status = children[child_index[byte]].GetGateStatus(); - children[child_index[byte]] = child; - if (status == GateStatus::GATE_SET && child.HasMetadata()) { - children[child_index[byte]].SetGateStatus(status); - } + // Shrink to Node16. + auto node48 = node; + Node16::ShrinkNode48(art, node, node48); } -Node48 &Node48::GrowNode16(ART &art, Node &node48, Node &node16) { - auto &n16 = Node::Ref(art, node16, NType::NODE_16); - auto &n48 = New(art, node48); - node48.SetGateStatus(node16.GetGateStatus()); +void Node48::GrowNode16(ART &art, Node &node48, Node &node16) { + { + NodeHandle n16_handle(art, node16); + auto &n16 = n16_handle.Get(); - n48.count = n16.count; - for (uint16_t i = 0; i < Node256::CAPACITY; i++) { - n48.child_index[i] = EMPTY_MARKER; - } - for (uint8_t i = 0; i < n16.count; i++) { - n48.child_index[n16.key[i]] = i; - n48.children[i] = n16.children[i]; - } - for (uint8_t i = n16.count; i < CAPACITY; i++) { - n48.children[i].Clear(); - } + auto n48_handle = New(art, node48); + auto &n48 = n48_handle.Get(); + node48.SetGateStatus(node16.GetGateStatus()); - n16.count = 0; - Node::Free(art, node16); - return n48; + n48.count = n16.count; + for (uint8_t i = 0; i < n16.count; i++) { + n48.child_index[n16.key[i]] = i; + n48.children[i] = n16.children[i]; + } + } + Node::FreeNode(art, node16); } -Node48 &Node48::ShrinkNode256(ART &art, Node &node48, Node &node256) { - auto &n48 = New(art, node48); - auto &n256 = Node::Ref(art, node256, NType::NODE_256); - node48.SetGateStatus(node256.GetGateStatus()); - - n48.count = 0; - for (uint16_t i = 0; i < Node256::CAPACITY; i++) { - if (!n256.children[i].HasMetadata()) { - n48.child_index[i] = EMPTY_MARKER; - continue; +void Node48::ShrinkNode256(ART &art, Node &node48, Node &node256) { + { + auto n48_handle = New(art, node48); + auto &n48 = n48_handle.Get(); + node48.SetGateStatus(node256.GetGateStatus()); + + NodeHandle n256_handle(art, node256); + auto &n256 = n256_handle.Get(); + + n48.count = 0; + for (uint16_t i = 0; i < 
Node256::CAPACITY; i++) { + if (!n256.children[i].HasMetadata()) { + continue; + } + n48.child_index[i] = n48.count; + n48.children[n48.count] = n256.children[i]; + n48.count++; } - n48.child_index[i] = n48.count; - n48.children[n48.count] = n256.children[i]; - n48.count++; - } - for (uint8_t i = n48.count; i < CAPACITY; i++) { - n48.children[i].Clear(); } - - n256.count = 0; - Node::Free(art, node256); - return n48; + Node::FreeNode(art, node256); } } // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/prefix.cpp b/src/duckdb/src/execution/index/art/prefix.cpp index e39556609..316e56b2d 100644 --- a/src/duckdb/src/execution/index/art/prefix.cpp +++ b/src/duckdb/src/execution/index/art/prefix.cpp @@ -48,10 +48,9 @@ uint8_t Prefix::GetByte(const ART &art, const Node &node, const uint8_t pos) { return prefix.data[pos]; } -Prefix Prefix::NewInternal(ART &art, Node &node, const data_ptr_t data, const uint8_t count, const idx_t offset, - const NType type) { - node = Node::GetAllocator(art, type).New(); - node.SetMetadata(static_cast(type)); +Prefix Prefix::NewInternal(ART &art, Node &node, const data_ptr_t data, const uint8_t count, const idx_t offset) { + node = Node::GetAllocator(art, PREFIX).New(); + node.SetMetadata(static_cast(PREFIX)); Prefix prefix(art, node, true); prefix.data[Count(art)] = count; @@ -59,6 +58,7 @@ Prefix Prefix::NewInternal(ART &art, Node &node, const data_ptr_t data, const ui D_ASSERT(count); memcpy(prefix.data, data + offset, count); } + prefix.ptr->Clear(); return prefix; } @@ -68,7 +68,7 @@ void Prefix::New(ART &art, reference &ref, const ARTKey &key, const idx_t while (count) { auto min = MinValue(UnsafeNumericCast(Count(art)), count); auto this_count = UnsafeNumericCast(min); - auto prefix = NewInternal(art, ref, key.data, this_count, offset + depth, PREFIX); + auto prefix = NewInternal(art, ref, key.data, this_count, offset + depth); ref = *prefix.ptr; offset += this_count; @@ -76,20 +76,6 @@ void Prefix::New(ART &art, reference &ref, const ARTKey &key, const idx_t } } -void Prefix::Free(ART &art, Node &node) { - Node next; - - while (node.HasMetadata() && node.GetType() == PREFIX) { - Prefix prefix(art, node, true); - next = *prefix.ptr; - Node::GetAllocator(art, PREFIX).Free(node); - node = next; - } - - Node::Free(art, node); - node.Clear(); -} - void Prefix::Concat(ART &art, Node &parent, uint8_t byte, const GateStatus old_status, const Node &child, const GateStatus status) { D_ASSERT(!parent.IsAnyLeaf()); @@ -107,14 +93,24 @@ void Prefix::Concat(ART &art, Node &parent, uint8_t byte, const GateStatus old_s } if (status == GateStatus::GATE_SET && child.GetType() == NType::LEAF_INLINED) { + // Inside gates, inlined leaves are not prefixed. auto row_id = child.GetRowId(); - Free(art, parent); + // We free the prefix (chain) until we reach the deleted Node4. + // Then, we move the row ID up. 
+ auto current = parent; + while (current.HasMetadata()) { + D_ASSERT(current.GetType() == NType::PREFIX); + Prefix prefix(art, current, true); + auto next = *prefix.ptr; + Node::FreeNode(art, current); + current = next; + } Leaf::New(parent, row_id); return; } if (parent.GetType() != PREFIX) { - auto prefix = NewInternal(art, parent, &byte, 1, 0, PREFIX); + auto prefix = NewInternal(art, parent, &byte, 1, 0); if (child.GetType() == PREFIX) { prefix.Append(art, child); } else { @@ -173,8 +169,7 @@ void Prefix::Reduce(ART &art, Node &node, const idx_t pos) { Prefix prefix(art, node); if (pos == idx_t(prefix.data[Count(art)] - 1)) { auto next = *prefix.ptr; - prefix.ptr->Clear(); - Node::Free(art, node); + Node::FreeNode(art, node); node = next; return; } @@ -216,7 +211,7 @@ GateStatus Prefix::Split(ART &art, reference &node, Node &child, const uin // Create a new prefix and // 1. copy the remaining bytes of this prefix. // 2. append remaining prefix nodes. - auto new_prefix = NewInternal(art, child, nullptr, 0, 0, PREFIX); + auto new_prefix = NewInternal(art, child, nullptr, 0, 0); new_prefix.data[Count(art)] = prefix.data[Count(art)] - pos - 1; memcpy(new_prefix.data, prefix.data + pos + 1, new_prefix.data[Count(art)]); @@ -243,8 +238,7 @@ GateStatus Prefix::Split(ART &art, reference &node, Node &child, const uin // No bytes left before the split, free this node. if (pos == 0) { auto old_status = node.get().GetGateStatus(); - prefix.ptr->Clear(); - Node::Free(art, node); + Node::FreeNode(art, node); return old_status; } @@ -305,8 +299,7 @@ void Prefix::TransformToDeprecated(ART &art, Node &node, unsafe_unique_ptrClear(); - Node::Free(art, current_node); + Node::FreeNode(art, current_node); current_node = *new_prefix.ptr; } @@ -321,7 +314,7 @@ Prefix Prefix::Append(ART &art, const uint8_t byte) { return *this; } - auto prefix = NewInternal(art, *ptr, nullptr, 0, 0, PREFIX); + auto prefix = NewInternal(art, *ptr, nullptr, 0, 0); return prefix.Append(art, byte); } @@ -341,7 +334,7 @@ void Prefix::Append(ART &art, Node other) { } *prefix.ptr = *other_prefix.ptr; - Node::GetAllocator(art, PREFIX).Free(other); + Node::FreeNode(art, other); other = *prefix.ptr; } } @@ -364,14 +357,14 @@ void Prefix::ConcatGate(ART &art, Node &parent, uint8_t byte, const Node &child) } else if (child.GetType() == PREFIX) { // At least one more row ID in this gate. - auto prefix = NewInternal(art, new_prefix, &byte, 1, 0, PREFIX); + auto prefix = NewInternal(art, new_prefix, &byte, 1, 0); prefix.ptr->Clear(); prefix.Append(art, child); new_prefix.SetGateStatus(GateStatus::GATE_SET); } else { // At least one more row ID in this gate. - auto prefix = NewInternal(art, new_prefix, &byte, 1, 0, PREFIX); + auto prefix = NewInternal(art, new_prefix, &byte, 1, 0); *prefix.ptr = child; new_prefix.SetGateStatus(GateStatus::GATE_SET); } @@ -386,7 +379,7 @@ void Prefix::ConcatGate(ART &art, Node &parent, uint8_t byte, const Node &child) void Prefix::ConcatChildIsGate(ART &art, Node &parent, uint8_t byte, const Node &child) { // Create a new prefix and point it to the gate. 
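With Prefix::Free removed, its call sites walk the chain themselves: a prefix is a singly linked list of fixed-size segments, and each step must read the ptr to the next segment before FreeNode releases the current one. The same loop over a plain linked list:

// Sketch of the chain-free loop in the Concat hunk above: save the link
// before the segment is released.
struct Segment {
	Segment *next = nullptr;
};

static void FreeChain(Segment *current) {
	while (current) {
		Segment *next = current->next; // read the link first
		delete current;                // FreeNode-style shell release
		current = next;
	}
}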
if (parent.GetType() != PREFIX) { - auto prefix = NewInternal(art, parent, &byte, 1, 0, PREFIX); + auto prefix = NewInternal(art, parent, &byte, 1, 0); *prefix.ptr = child; return; } diff --git a/src/duckdb/src/execution/index/bound_index.cpp b/src/duckdb/src/execution/index/bound_index.cpp index 2c1de0efc..f3a4bd0e3 100644 --- a/src/duckdb/src/execution/index/bound_index.cpp +++ b/src/duckdb/src/execution/index/bound_index.cpp @@ -154,4 +154,31 @@ string BoundIndex::AppendRowError(DataChunk &input, idx_t index) { return error; } +void BoundIndex::ApplyBufferedAppends(ColumnDataCollection &buffered_appends) { + IndexAppendInfo index_append_info(IndexAppendMode::INSERT_DUPLICATES, nullptr); + + ColumnDataScanState state; + buffered_appends.InitializeScan(state); + + DataChunk scan_chunk; + buffered_appends.InitializeScanChunk(scan_chunk); + + auto append_types = scan_chunk.GetTypes(); + append_types.pop_back(); + DataChunk append_chunk; + append_chunk.InitializeEmpty(append_types); + + while (buffered_appends.Scan(state, scan_chunk)) { + for (idx_t i = 0; i < append_chunk.ColumnCount(); i++) { + append_chunk.data[i].Reference(scan_chunk.data[i]); + } + append_chunk.SetCardinality(scan_chunk.size()); + + auto error = Append(append_chunk, scan_chunk.data.back(), index_append_info); + if (error.HasError()) { + throw InternalException("error while applying buffered appends: " + error.Message()); + } + } +} + } // namespace duckdb diff --git a/src/duckdb/src/execution/index/unbound_index.cpp b/src/duckdb/src/execution/index/unbound_index.cpp index 86f224fee..63d54a15c 100644 --- a/src/duckdb/src/execution/index/unbound_index.cpp +++ b/src/duckdb/src/execution/index/unbound_index.cpp @@ -1,15 +1,13 @@ #include "duckdb/execution/index/unbound_index.hpp" + +#include "duckdb/common/types/column/column_data_collection.hpp" #include "duckdb/parser/parsed_data/create_index_info.hpp" -#include "duckdb/storage/table_io_manager.hpp" #include "duckdb/storage/block_manager.hpp" #include "duckdb/storage/index_storage_info.hpp" +#include "duckdb/storage/table_io_manager.hpp" namespace duckdb { -//------------------------------------------------------------------------------- -// Unbound index -//------------------------------------------------------------------------------- - UnboundIndex::UnboundIndex(unique_ptr create_info, IndexStorageInfo storage_info_p, TableIOManager &table_io_manager, AttachedDatabase &db) : Index(create_info->Cast().column_ids, table_io_manager, db), create_info(std::move(create_info)), @@ -20,7 +18,7 @@ UnboundIndex::UnboundIndex(unique_ptr create_info, IndexStorageInfo auto &info = storage_info.allocator_infos[info_idx]; for (idx_t buffer_idx = 0; buffer_idx < info.buffer_ids.size(); buffer_idx++) { if (info.buffer_ids[buffer_idx] > idx_t(MAX_ROW_ID)) { - throw InternalException("Found invalid buffer ID in UnboundIndex constructor"); + throw InternalException("found invalid buffer ID in UnboundIndex constructor"); } } } @@ -37,4 +35,23 @@ void UnboundIndex::CommitDrop() { } } +void UnboundIndex::BufferChunk(DataChunk &chunk, Vector &row_ids) { + auto types = chunk.GetTypes(); + types.push_back(LogicalType::ROW_TYPE); + + if (!buffered_appends) { + auto &allocator = Allocator::Get(db); + buffered_appends = make_uniq(allocator, types); + } + + DataChunk combined_chunk; + combined_chunk.InitializeEmpty(types); + for (idx_t i = 0; i < chunk.ColumnCount(); i++) { + combined_chunk.data[i].Reference(chunk.data[i]); + } + combined_chunk.data.back().Reference(row_ids); + 
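The BufferChunk/ApplyBufferedAppends pair above implements deferred index maintenance: while the index is unbound, every appended chunk is stored with its row IDs as one extra trailing column of a ColumnDataCollection; once the index is bound, the collection is scanned and the trailing column is split off again. The column-layout trick in a self-contained sketch; plain vectors stand in for DataChunk, and a copy replaces the zero-copy Vector::Reference used in the diff:

#include <cstdint>
#include <vector>

// Column-major stand-in for a DataChunk.
struct ToyChunk {
	std::vector<std::vector<int64_t>> columns;
};

// BufferChunk-style: widen the payload by one row-id column before buffering.
static ToyChunk WidenWithRowIds(const ToyChunk &chunk, const std::vector<int64_t> &row_ids) {
	ToyChunk combined = chunk;           // the real code references, not copies
	combined.columns.push_back(row_ids); // row IDs become the last column
	return combined;
}

// ApplyBufferedAppends-style: split the trailing row-id column off on replay.
static void SplitOnReplay(const ToyChunk &buffered, ToyChunk &payload, std::vector<int64_t> &row_ids) {
	payload.columns.assign(buffered.columns.begin(), buffered.columns.end() - 1);
	row_ids = buffered.columns.back();
}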
combined_chunk.SetCardinality(chunk.size()); + buffered_appends->Append(combined_chunk); +} + } // namespace duckdb diff --git a/src/duckdb/src/execution/join_hashtable.cpp b/src/duckdb/src/execution/join_hashtable.cpp index 589ad0e61..cfa845a88 100644 --- a/src/duckdb/src/execution/join_hashtable.cpp +++ b/src/duckdb/src/execution/join_hashtable.cpp @@ -6,6 +6,8 @@ #include "duckdb/execution/ht_entry.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/storage/buffer_manager.hpp" +#include "duckdb/main/settings.hpp" +#include "duckdb/logging/log_manager.hpp" namespace duckdb { @@ -19,7 +21,7 @@ JoinHashTable::SharedState::SharedState() } JoinHashTable::ProbeState::ProbeState() - : SharedState(), ht_offsets_v(LogicalType::UBIGINT), hashes_dense_v(LogicalType::HASH), + : SharedState(), ht_offsets_and_salts_v(LogicalType::UBIGINT), hashes_dense_v(LogicalType::HASH), non_empty_sel(STANDARD_VECTOR_SIZE) { } @@ -109,8 +111,7 @@ JoinHashTable::JoinHashTable(ClientContext &context_p, const PhysicalOperator &o memset(dead_end.get(), 0, layout_ptr->GetRowWidth()); if (join_type == JoinType::SINGLE) { - auto &config = ClientConfig::GetConfig(context); - single_join_error_on_multiple_rows = config.scalar_subquery_error_on_multiple_rows; + single_join_error_on_multiple_rows = DBConfig::GetSetting(context); } InitializePartitionMasks(); @@ -168,18 +169,21 @@ static void AddPointerToCompare(JoinHashTable::ProbeState &state, const ht_entry idx_t row_ht_offset, idx_t &keys_to_compare_count, const idx_t &row_index) { const auto row_ptr_insert_to = FlatVector::GetData(pointers_result_v); - const auto ht_offsets = FlatVector::GetData(state.ht_offsets_v); + const auto ht_offsets_and_salts = FlatVector::GetData(state.ht_offsets_and_salts_v); state.keys_to_compare_sel.set_index(keys_to_compare_count, row_index); row_ptr_insert_to[row_index] = entry.GetPointer(); + + // If the key does not match, we have to continue linear probing, so we need to store the ht_offset and the + // salt for this element, indexed by row_index. We can't recompute the offset from the hash, as we may already + // have taken some linear probing steps by the time we arrive here.
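The renamed ht_offsets_and_salts_v packs both pieces of resumable probe state into one 64-bit word: the slot offset in the low bits (capacity is capped below 2^48 in AllocatePointerTable, later in this file's hunks) and the 16-bit salt in the high bits, so advancing a probe is a single IncrementAndWrap over the combined mask. A worked sketch under that assumed layout:

#include <cassert>
#include <cstdint>

// Assumed layout, consistent with the 48-bit capacity check in this diff:
// high 16 bits carry the salt, low 48 bits carry the slot offset.
static constexpr uint64_t SALT_MASK = 0xFFFF000000000000ULL;

static uint64_t Pack(uint64_t offset, uint64_t hash) {
	return offset | (hash & SALT_MASK);
}

// One linear-probing step: bump the offset, wrap at the table size, keep the
// salt. This works because capacity is a power of two below 2^48, so a carry
// out of the offset bits is masked away before it can corrupt the salt.
static uint64_t NextSlot(uint64_t offset_and_salt, uint64_t bitmask) {
	return (offset_and_salt + 1) & (bitmask | SALT_MASK);
}

int main() {
	const uint64_t bitmask = (1ULL << 10) - 1;         // capacity 1024
	uint64_t w = Pack(bitmask, 0xABCD123456789ABCULL); // last slot, some salt
	w = NextSlot(w, bitmask);
	assert((w & ~SALT_MASK) == 0);                     // offset wrapped to 0
	assert((w & SALT_MASK) == 0xABCD000000000000ULL);  // salt preserved
	return 0;
}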
+ ht_offsets_and_salts[row_index] = row_ht_offset | entry.GetSaltWithNulls(); keys_to_compare_count += 1; } template static idx_t ProbeForPointersInternal(JoinHashTable::ProbeState &state, JoinHashTable &ht, ht_entry_t *entries, - Vector &hashes_v, Vector &pointers_result_v, const SelectionVector *row_sel, - idx_t &count) { + Vector &pointers_result_v, const SelectionVector *row_sel, idx_t &count) { auto hashes_dense = FlatVector::GetData(state.hashes_dense_v); @@ -187,11 +191,11 @@ static idx_t ProbeForPointersInternal(JoinHashTable::ProbeState &state, JoinHash for (idx_t i = 0; i < count; i++) { - auto row_hash = hashes_dense[i]; // hashes has been flattened before -> always access dense + auto row_hash = hashes_dense[i]; // hashes have been flattened before -> always access dense auto row_ht_offset = row_hash & ht.bitmask; if (USE_SALTS) { - // increment the ht_offset of the entry as long as next entry is occupied and salt does not match + // increment the ht_offset of the entry as long as the next entry is occupied and salt does not match while (true) { const ht_entry_t entry = entries[row_ht_offset]; const bool occupied = entry.IsOccupied(); @@ -211,7 +215,7 @@ static idx_t ProbeForPointersInternal(JoinHashTable::ProbeState &state, JoinHash break; } - // full and salt does not match -> continue probing + // full and salt do not match -> continue probing IncrementAndWrap(row_ht_offset, ht.bitmask); } } else { @@ -235,14 +239,12 @@ static idx_t ProbeForPointersInternal(JoinHashTable::ProbeState &state, JoinHash /// -> match, add to compare sel and increase found count template static idx_t ProbeForPointers(JoinHashTable::ProbeState &state, JoinHashTable &ht, ht_entry_t *entries, - Vector &hashes_v, Vector &pointers_result_v, const SelectionVector *row_sel, idx_t count, + Vector &pointers_result_v, const SelectionVector *row_sel, idx_t count, const bool has_row_sel) { if (has_row_sel) { - return ProbeForPointersInternal(state, ht, entries, hashes_v, pointers_result_v, row_sel, - count); + return ProbeForPointersInternal(state, ht, entries, pointers_result_v, row_sel, count); } else { - return ProbeForPointersInternal(state, ht, entries, hashes_v, pointers_result_v, row_sel, - count); + return ProbeForPointersInternal(state, ht, entries, pointers_result_v, row_sel, count); } } @@ -254,14 +256,10 @@ static void GetRowPointersInternal(DataChunk &keys, TupleDataChunkState &key_sta ht_entry_t *entries, Vector &pointers_result_v, SelectionVector &match_sel, bool has_row_sel) { - // in case of a hash collision, we need this information to correctly retrieve the salt of this hash - bool uses_unified = false; - UnifiedVectorFormat hashes_unified_v; - // densify hashes: If there is no sel, flatten the hashes, else densify via UnifiedVectorFormat if (has_row_sel) { + UnifiedVectorFormat hashes_unified_v; hashes_v.ToUnifiedFormat(count, hashes_unified_v); - uses_unified = true; auto hashes_unified = UnifiedVectorFormat::GetData(hashes_unified_v); auto hashes_dense = FlatVector::GetData(state.hashes_dense_v); @@ -282,8 +280,8 @@ static void GetRowPointersInternal(DataChunk &keys, TupleDataChunkState &key_sta idx_t elements_to_probe_count = count; do { - const idx_t keys_to_compare_count = ProbeForPointers(state, ht, entries, hashes_v, pointers_result_v, - row_sel, elements_to_probe_count, has_row_sel); + const idx_t keys_to_compare_count = ProbeForPointers(state, ht, entries, pointers_result_v, row_sel, + elements_to_probe_count, has_row_sel); // if there are no keys to compare, we are done if 
(keys_to_compare_count == 0) { @@ -305,37 +303,21 @@ static void GetRowPointersInternal(DataChunk &keys, TupleDataChunkState &key_sta match_count++; } - // Linear probing for collisions: Move to the next entry in the HT - auto hashes_unified = uses_unified ? UnifiedVectorFormat::GetData(hashes_unified_v) : nullptr; - auto hashes_dense = FlatVector::GetData(state.hashes_dense_v); - auto ht_offsets = FlatVector::GetData(state.ht_offsets_v); + const auto ht_offsets_and_salts = FlatVector::GetData(state.ht_offsets_and_salts_v); + const auto hashes_dense = FlatVector::GetData(state.hashes_dense_v); + // For all the non-matches, increment the offset to continue probing but keep the salt intact for (idx_t i = 0; i < keys_no_match_count; i++) { const auto row_index = state.keys_no_match_sel.get_index(i); - // The ProbeForPointers function calculates the ht_offset from the hash; therefore, we have to write the - // new offset into the hashes_v; otherwise the next iteration will start at the old position. This might - // seem as an overhead but assures that the first call of ProbeForPointers is optimized as conceding - // calls are unlikely (Max 1-(65535/65536)^VectorSize = 3.1%) - auto ht_offset = ht_offsets[row_index]; - IncrementAndWrap(ht_offset, ht.bitmask); - - // Get original hash from unified vector format to extract the salt if hashes_dense was populated that way - hash_t hash; - if (hashes_unified) { - const auto uvf_index = hashes_unified_v.sel->get_index(row_index); - hash = hashes_unified[uvf_index]; - } else { - hash = hashes_dense[row_index]; - } - - const auto offset_and_salt = ht_offset | (hash & ht_entry_t::SALT_MASK); - - hashes_dense[i] = offset_and_salt; // populate dense again + auto ht_offset_and_salt = ht_offsets_and_salts[row_index]; + IncrementAndWrap(ht_offset_and_salt, ht.bitmask | ht_entry_t::SALT_MASK); + hashes_dense[i] = ht_offset_and_salt; // populate dense again } // in the next iteration, we have a selection vector with the keys that do not match row_sel = &state.keys_no_match_sel; has_row_sel = true; + elements_to_probe_count = keys_no_match_count; } while (DUCKDB_UNLIKELY(keys_no_match_count > 0)); @@ -746,6 +728,11 @@ void JoinHashTable::AllocatePointerTable() { capacity = PointerTableCapacity(Count()); D_ASSERT(IsPowerOfTwo(capacity)); + constexpr uint64_t MAX_HASHTABLE_CAPACITY = (1ULL << 48) - 1; + if (capacity >= MAX_HASHTABLE_CAPACITY) { + throw InternalException("Hashtable capacity exceeds 48-bit limit (2^48 - 1)"); + } + if (hash_map.get()) { // There is already a hash map auto current_capacity = hash_map.GetSize() / sizeof(ht_entry_t); diff --git a/src/duckdb/src/execution/operator/aggregate/physical_window.cpp b/src/duckdb/src/execution/operator/aggregate/physical_window.cpp index 09a585b4a..a86ce6ae4 100644 --- a/src/duckdb/src/execution/operator/aggregate/physical_window.cpp +++ b/src/duckdb/src/execution/operator/aggregate/physical_window.cpp @@ -1,6 +1,8 @@ #include "duckdb/execution/operator/aggregate/physical_window.hpp" -#include "duckdb/common/sort/partition_state.hpp" +#include "duckdb/common/sorting/hashed_sort.hpp" +#include "duckdb/common/types/row/tuple_data_collection.hpp" +#include "duckdb/common/types/row/tuple_data_iterator.hpp" #include "duckdb/function/window/window_aggregate_function.hpp" #include "duckdb/function/window/window_executor.hpp" #include "duckdb/function/window/window_rank_function.hpp" @@ -8,8 +10,7 @@ #include "duckdb/function/window/window_shared_expressions.hpp" #include
"duckdb/function/window/window_value_function.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" -// -#include +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -37,8 +38,8 @@ struct WindowSourceTask { class WindowHashGroup { public: - using HashGroupPtr = unique_ptr; - using OrderMasks = PartitionGlobalHashGroup::OrderMasks; + using HashGroupPtr = unique_ptr; + using OrderMasks = HashedSortGroup::OrderMasks; using ExecutorGlobalStatePtr = unique_ptr; using ExecutorGlobalStates = vector; using ExecutorLocalStatePtr = unique_ptr; @@ -46,10 +47,13 @@ class WindowHashGroup { using ThreadLocalStates = vector; using Task = WindowSourceTask; using TaskPtr = optional_ptr; + using ScannerPtr = unique_ptr; - WindowHashGroup(WindowGlobalSinkState &gstate, const idx_t hash_bin_p); + WindowHashGroup(WindowGlobalSinkState &gsink, const idx_t hash_bin_p); - ExecutorGlobalStates &Initialize(WindowGlobalSinkState &gstate); + void ComputeMasks(ValidityMask &partition_mask, OrderMasks &order_masks); + + ExecutorGlobalStates &Initialize(ClientContext &client); // The total number of tasks we will execute (SINK, FINALIZE, GETDATA per thread) inline idx_t GetTaskCount() const { @@ -62,20 +66,9 @@ class WindowHashGroup { // Set up the task parameters idx_t InitTasks(idx_t per_thread); - // Scan all of the blocks during the build phase - unique_ptr GetBuildScanner(idx_t block_idx) const { - if (!rows) { - return nullptr; - } - return make_uniq(*rows, *heap, layout, external, block_idx, false); - } - - // Scan a single block during the evaluate phase - unique_ptr GetEvaluateScanner(idx_t block_idx) const { - // Second pass can flush - D_ASSERT(rows); - return make_uniq(*rows, *heap, layout, external, block_idx, true); - } + // Scan all of the chunks, starting at a given point + ScannerPtr GetScanner(const idx_t begin_idx) const; + void UpdateScanner(ScannerPtr &scanner, idx_t begin_idx) const; // The processing stage for this group WindowGroupStage GetStage() const { @@ -114,7 +107,7 @@ class WindowHashGroup { task.thread_idx = next_task % group_threads; task.group_idx = hash_bin; task.begin_idx = task.thread_idx * per_thread; - task.max_idx = rows->blocks.size(); + task.max_idx = rows->ChunkCount(); task.end_idx = MinValue(task.begin_idx + per_thread, task.max_idx); ++next_task; return true; @@ -123,23 +116,22 @@ class WindowHashGroup { return false; } + //! The shared global state from sinking + WindowGlobalSinkState &gsink; //! The hash partition data HashGroupPtr hash_group; //! The size of the group idx_t count = 0; //! The number of blocks in the group idx_t blocks = 0; - unique_ptr rows; - unique_ptr heap; - RowLayout layout; + unique_ptr rows; + TupleDataLayout layout; //! The partition boundary mask ValidityMask partition_mask; //! The order boundary mask OrderMasks order_masks; //! The fully materialised data collection unique_ptr collection; - //! External paging - bool external; // The processing stage for this group atomic stage; //! The function global states for this hash group @@ -165,75 +157,69 @@ class WindowHashGroup { std::atomic completed; //! 
The output ordering batch index this hash group starts at idx_t batch_base; - -private: - void MaterializeSortedData(); }; -class WindowPartitionGlobalSinkState; - class WindowGlobalSinkState : public GlobalSinkState { public: + using WindowHashGroupPtr = unique_ptr; using ExecutorPtr = unique_ptr; using Executors = vector; - WindowGlobalSinkState(const PhysicalWindow &op, ClientContext &context); + class Callback : public HashedSortCallback { + public: + explicit Callback(GlobalSinkState &gsink) : gsink(gsink) { + } - //! Parent operator - const PhysicalWindow &op; - //! Execution context - ClientContext &context; - //! The partitioned sunk data - unique_ptr global_partition; - //! The execution functions - Executors executors; - //! The shared expressions library - WindowSharedExpressions shared; -}; + void OnSortedGroup(HashedSortGroup &hash_group) override { + gsink.Cast().OnSortedGroup(hash_group); + } -class WindowPartitionGlobalSinkState : public PartitionGlobalSinkState { -public: - using WindowHashGroupPtr = unique_ptr; + GlobalSinkState &gsink; + }; - WindowPartitionGlobalSinkState(WindowGlobalSinkState &gsink, const BoundWindowExpression &wexpr) - : PartitionGlobalSinkState(gsink.context, wexpr.partitions, wexpr.orders, gsink.op.children[0].get().GetTypes(), - wexpr.partitions_stats, gsink.op.estimated_cardinality), - gsink(gsink) { - } - ~WindowPartitionGlobalSinkState() override = default; + WindowGlobalSinkState(const PhysicalWindow &op, ClientContext &context); - void OnBeginMerge() override { - PartitionGlobalSinkState::OnBeginMerge(); - window_hash_groups.resize(hash_groups.size()); + void Finalize(ClientContext &context, InterruptState &interrupt_state) { + global_partition->Finalize(context, interrupt_state); + window_hash_groups.resize(global_partition->hash_groups.size()); } - void OnSortedPartition(const idx_t group_idx) override { - PartitionGlobalSinkState::OnSortedPartition(group_idx); - window_hash_groups[group_idx] = make_uniq(gsink, group_idx); + void OnSortedGroup(HashedSortGroup &hash_group) { + window_hash_groups[hash_group.group_idx] = make_uniq(*this, hash_group.group_idx); } - //! Operator global sink state - WindowGlobalSinkState &gsink; + //! Parent operator + const PhysicalWindow &op; + //! Client context + ClientContext &client; + //! The partitioned sunk data + unique_ptr global_partition; + //! The callback for completed hash groups + Callback callback; //! The sorted hash groups vector window_hash_groups; + //! The execution functions + Executors executors; + //! 
The shared expressions library
+	WindowSharedExpressions shared;
 };
 
 // Per-thread sink state
 class WindowLocalSinkState : public LocalSinkState {
 public:
-	WindowLocalSinkState(ClientContext &context, const WindowGlobalSinkState &gstate)
-	    : local_partition(context, *gstate.global_partition) {
+	WindowLocalSinkState(ExecutionContext &context, const WindowGlobalSinkState &gstate)
+	    : local_group(context, *gstate.global_partition) {
 	}
 
-	void Sink(DataChunk &input_chunk) {
-		local_partition.Sink(input_chunk);
+	void Sink(ExecutionContext &context, DataChunk &input_chunk) {
+		local_group.Sink(context, input_chunk);
 	}
 
-	void Combine() {
-		local_partition.Combine();
+	void Combine(ExecutionContext &context) {
+		local_group.Combine(context);
 	}
 
-	PartitionLocalSinkState local_partition;
+	HashedSortLocalSinkState local_group;
 };
 
 // this implements a sorted window functions variant
@@ -259,55 +245,57 @@ PhysicalWindow::PhysicalWindow(PhysicalPlan &physical_plan, vector<LogicalType>
 	}
 }
 
-static unique_ptr<WindowExecutor> WindowExecutorFactory(BoundWindowExpression &wexpr, ClientContext &context,
+static unique_ptr<WindowExecutor> WindowExecutorFactory(BoundWindowExpression &wexpr, ClientContext &client,
                                                         WindowSharedExpressions &shared, WindowAggregationMode mode) {
 	switch (wexpr.GetExpressionType()) {
 	case ExpressionType::WINDOW_AGGREGATE:
-		return make_uniq<WindowAggregateExecutor>(wexpr, context, shared, mode);
+		return make_uniq<WindowAggregateExecutor>(wexpr, client, shared, mode);
 	case ExpressionType::WINDOW_ROW_NUMBER:
-		return make_uniq<WindowRowNumberExecutor>(wexpr, context, shared);
+		return make_uniq<WindowRowNumberExecutor>(wexpr, shared);
 	case ExpressionType::WINDOW_RANK_DENSE:
-		return make_uniq<WindowDenseRankExecutor>(wexpr, context, shared);
+		return make_uniq<WindowDenseRankExecutor>(wexpr, shared);
 	case ExpressionType::WINDOW_RANK:
-		return make_uniq<WindowRankExecutor>(wexpr, context, shared);
+		return make_uniq<WindowRankExecutor>(wexpr, shared);
 	case ExpressionType::WINDOW_PERCENT_RANK:
-		return make_uniq<WindowPercentRankExecutor>(wexpr, context, shared);
+		return make_uniq<WindowPercentRankExecutor>(wexpr, shared);
 	case ExpressionType::WINDOW_CUME_DIST:
-		return make_uniq<WindowCumeDistExecutor>(wexpr, context, shared);
+		return make_uniq<WindowCumeDistExecutor>(wexpr, shared);
 	case ExpressionType::WINDOW_NTILE:
-		return make_uniq<WindowNtileExecutor>(wexpr, context, shared);
+		return make_uniq<WindowNtileExecutor>(wexpr, shared);
 	case ExpressionType::WINDOW_LEAD:
 	case ExpressionType::WINDOW_LAG:
-		return make_uniq<WindowLeadLagExecutor>(wexpr, context, shared);
+		return make_uniq<WindowLeadLagExecutor>(wexpr, shared);
 	case ExpressionType::WINDOW_FILL:
-		return make_uniq<WindowFillExecutor>(wexpr, context, shared);
+		return make_uniq<WindowFillExecutor>(wexpr, shared);
 	case ExpressionType::WINDOW_FIRST_VALUE:
-		return make_uniq<WindowFirstValueExecutor>(wexpr, context, shared);
+		return make_uniq<WindowFirstValueExecutor>(wexpr, shared);
 	case ExpressionType::WINDOW_LAST_VALUE:
-		return make_uniq<WindowLastValueExecutor>(wexpr, context, shared);
+		return make_uniq<WindowLastValueExecutor>(wexpr, shared);
 	case ExpressionType::WINDOW_NTH_VALUE:
-		return make_uniq<WindowNthValueExecutor>(wexpr, context, shared);
+		return make_uniq<WindowNthValueExecutor>(wexpr, shared);
 		break;
 	default:
 		throw InternalException("Window aggregate type %s", ExpressionTypeToString(wexpr.GetExpressionType()));
 	}
 }
 
-WindowGlobalSinkState::WindowGlobalSinkState(const PhysicalWindow &op, ClientContext &context)
-    : op(op), context(context) {
+WindowGlobalSinkState::WindowGlobalSinkState(const PhysicalWindow &op, ClientContext &client)
+    : op(op), client(client), callback(*this) {
 	D_ASSERT(op.select_list[op.order_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW);
 	auto &wexpr = op.select_list[op.order_idx]->Cast<BoundWindowExpression>();
 
-	const auto mode = DBConfig::GetConfig(context).options.window_mode;
+	const auto mode = DBConfig::GetSetting<DebugWindowModeSetting>(client);
 	for (idx_t expr_idx = 0; expr_idx < op.select_list.size(); ++expr_idx) {
 		D_ASSERT(op.select_list[expr_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW);
 		auto &wexpr = op.select_list[expr_idx]->Cast<BoundWindowExpression>();
-		auto wexec = WindowExecutorFactory(wexpr, context, shared, mode);
+		auto wexec = WindowExecutorFactory(wexpr, client, shared, mode);
 		executors.emplace_back(std::move(wexec));
 	}
 
-	global_partition = make_uniq<WindowPartitionGlobalSinkState>(*this, wexpr);
+	global_partition =
+	    make_uniq<HashedSortGlobalSinkState>(client, wexpr.partitions, wexpr.orders, op.children[0].get().GetTypes(),
+	                                         wexpr.partitions_stats, op.estimated_cardinality);
 }
 
 //===--------------------------------------------------------------------===//
@@ -316,51 +304,60 @@ WindowGlobalSinkState::WindowGlobalSinkState(const PhysicalWindow &op, ClientCon
 SinkResultType PhysicalWindow::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const {
 	auto &lstate = input.local_state.Cast<WindowLocalSinkState>();
-	lstate.Sink(chunk);
+	lstate.Sink(context, chunk);
 	return SinkResultType::NEED_MORE_INPUT;
 }
 
 SinkCombineResultType PhysicalWindow::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const {
 	auto &lstate = input.local_state.Cast<WindowLocalSinkState>();
-	lstate.Combine();
+	lstate.Combine(context);
 	return SinkCombineResultType::FINISHED;
 }
 
 unique_ptr<LocalSinkState> PhysicalWindow::GetLocalSinkState(ExecutionContext &context) const {
 	auto &gstate = sink_state->Cast<WindowGlobalSinkState>();
-	return make_uniq<WindowLocalSinkState>(context.client, gstate);
+	return make_uniq<WindowLocalSinkState>(context, gstate);
 }
 
-unique_ptr<GlobalSinkState> PhysicalWindow::GetGlobalSinkState(ClientContext &context) const {
-	return make_uniq<WindowGlobalSinkState>(*this, context);
+unique_ptr<GlobalSinkState> PhysicalWindow::GetGlobalSinkState(ClientContext &client) const {
+	return make_uniq<WindowGlobalSinkState>(*this, client);
 }
 
-SinkFinalizeType PhysicalWindow::Finalize(Pipeline &pipeline, Event &event, ClientContext &context,
+SinkFinalizeType PhysicalWindow::Finalize(Pipeline &pipeline, Event &event, ClientContext &client,
                                           OperatorSinkFinalizeInput &input) const {
-	auto &state = input.global_state.Cast<WindowGlobalSinkState>();
+	auto &gsink = input.global_state.Cast<WindowGlobalSinkState>();
+	auto &gpart = *gsink.global_partition;
 
 	// Did we get any data?
-	if (!state.global_partition->count) {
+	if (!gpart.count) {
 		return SinkFinalizeType::NO_OUTPUT_POSSIBLE;
 	}
 
-	// Do we have any sorting to schedule?
-	if (state.global_partition->rows) {
-		D_ASSERT(!state.global_partition->grouping_data);
-		return state.global_partition->rows->count ? SinkFinalizeType::READY : SinkFinalizeType::NO_OUTPUT_POSSIBLE;
+	// OVER()
+	if (gpart.unsorted) {
+		// We need to construct the single WindowHashGroup here because the sort tasks will not be run.
+		D_ASSERT(!gpart.grouping_data);
+		if (!gpart.unsorted->Count()) {
+			return SinkFinalizeType::NO_OUTPUT_POSSIBLE;
+		}
+
+		gsink.window_hash_groups.emplace_back(make_uniq<WindowHashGroup>(gsink, idx_t(0)));
+		return SinkFinalizeType::READY;
 	}
 
+	gsink.Finalize(client, input.interrupt_state);
+
 	// Find the first group to sort
-	if (!state.global_partition->HasMergeTasks()) {
+	if (!gsink.global_partition->HasMergeTasks()) {
 		// Empty input!
		return SinkFinalizeType::NO_OUTPUT_POSSIBLE;
 	}
 
 	// Schedule all the sorts for maximum thread utilisation
-	auto new_event = make_shared_ptr<PartitionMergeEvent>(*state.global_partition, pipeline, *this);
-	event.InsertEvent(std::move(new_event));
+	auto sort_event = make_shared_ptr<HashedSortMaterializeEvent>(gpart, pipeline, *this, &gsink.callback);
+	event.InsertEvent(std::move(sort_event));
 
 	return SinkFinalizeType::READY;
 }
@@ -370,16 +367,13 @@ SinkFinalizeType PhysicalWindow::Finalize(Pipeline &pipeline, Event &event, Clie
 //===--------------------------------------------------------------------===//
 class WindowGlobalSourceState : public GlobalSourceState {
 public:
-	using ScannerPtr = unique_ptr<RowDataCollectionScanner>;
+	using ScannerPtr = unique_ptr<WindowCollectionChunkScanner>;
 	using Task = WindowSourceTask;
 	using TaskPtr = optional_ptr<Task>;
 	using PartitionBlock = std::pair<idx_t, idx_t>;
 
 	WindowGlobalSourceState(ClientContext &context_p, WindowGlobalSinkState &gsink_p);
 
-	//! Build task list
-	void CreateTaskList();
-
 	//! Are there any more tasks?
 	bool HasMoreTasks() const {
 		return !stopped && started < total_tasks;
@@ -391,7 +385,7 @@ class WindowGlobalSourceState : public GlobalSourceState {
 	bool TryNextTask(TaskPtr &task, Task &task_local);
 
 	//! Context for executing computations
-	ClientContext &context;
+	ClientContext &client;
 	//! All the sunk data
 	WindowGlobalSinkState &gsink;
 	//! The total number of blocks to process;
@@ -421,68 +415,53 @@ class WindowGlobalSourceState : public GlobalSourceState {
 	}
 
 protected:
+	//! Build task list
+	void CreateTaskList();
 	//! Finish a task
 	void FinishTask(TaskPtr task);
 };
 
-WindowGlobalSourceState::WindowGlobalSourceState(ClientContext &context_p, WindowGlobalSinkState &gsink_p)
-    : context(context_p), gsink(gsink_p), next_group(0), locals(0), started(0), finished(0), stopped(false),
-      returned(0) {
-	auto &gpart = gsink.global_partition;
-	auto &window_hash_groups = gsink.global_partition->window_hash_groups;
+WindowGlobalSourceState::WindowGlobalSourceState(ClientContext &client, WindowGlobalSinkState &gsink_p)
    : client(client), gsink(gsink_p), next_group(0), locals(0), started(0), finished(0), stopped(false), returned(0) {
+	auto &window_hash_groups = gsink.window_hash_groups;
 
-	if (window_hash_groups.empty()) {
-		// OVER()
-		if (gpart->rows && !gpart->rows->blocks.empty()) {
-			// We need to construct the single WindowHashGroup here because the sort tasks will not be run.
-			window_hash_groups.emplace_back(make_uniq<WindowHashGroup>(gsink, idx_t(0)));
-			total_blocks = gpart->rows->blocks.size();
+	for (auto &window_hash_group : window_hash_groups) {
+		if (!window_hash_group) {
+			continue;
 		}
-	} else {
-		idx_t batch_base = 0;
-		for (auto &window_hash_group : window_hash_groups) {
-			if (!window_hash_group) {
-				continue;
-			}
-			auto &rows = window_hash_group->rows;
-			if (!rows) {
-				continue;
-			}
-
-			const auto block_count = window_hash_group->rows->blocks.size();
-			window_hash_group->batch_base = batch_base;
-			batch_base += block_count;
+		auto &rows = window_hash_group->rows;
+		if (!rows) {
+			continue;
 		}
-		total_blocks = batch_base;
-	}
-}
 
-void WindowGlobalSourceState::CreateTaskList() {
-	// Check whether we have a task list outside the mutex.
-	if (started.load()) {
-		return;
+		const auto block_count = window_hash_group->rows->ChunkCount();
+		window_hash_group->batch_base = total_blocks;
+		total_blocks += block_count;
 	}
 
-	auto guard = Lock();
-
-	auto &window_hash_groups = gsink.global_partition->window_hash_groups;
-	if (!partition_blocks.empty()) {
-		return;
-	}
+	CreateTaskList();
+}
 
+void WindowGlobalSourceState::CreateTaskList() {
 	// Sort the groups from largest to smallest
+	auto &window_hash_groups = gsink.window_hash_groups;
 	if (window_hash_groups.empty()) {
 		return;
 	}
 
 	for (idx_t group_idx = 0; group_idx < window_hash_groups.size(); ++group_idx) {
 		auto &window_hash_group = window_hash_groups[group_idx];
-		partition_blocks.emplace_back(window_hash_group->rows->blocks.size(), group_idx);
+		if (!window_hash_group) {
+			continue;
+		}
+		partition_blocks.emplace_back(window_hash_group->rows->ChunkCount(), group_idx);
 	}
 	std::sort(partition_blocks.begin(), partition_blocks.end(), std::greater<PartitionBlock>());
 
 	//	Schedule the largest group on as many threads as possible
-	const auto threads = locals.load();
+	auto &ts = TaskScheduler::GetScheduler(client);
+	const auto threads = NumericCast<idx_t>(ts.NumberOfThreads());
+
 	const auto &max_block = partition_blocks.front();
 	const auto per_thread = (max_block.first + threads - 1) / threads;
 	if (!per_thread) {
@@ -495,59 +474,20 @@ void WindowGlobalSourceState::CreateTaskList() {
 	}
 }
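The scheduling heuristic in CreateTaskList is easy to see in isolation: groups are ordered largest-first, and the largest group is split so that every available thread gets a share. A minimal standalone sketch of that splitting logic (plain C++, independent of DuckDB's WindowHashGroup and TaskScheduler types):

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
	using PartitionBlock = std::pair<size_t, size_t>; // (chunk_count, group_idx)
	std::vector<PartitionBlock> partition_blocks {{3, 0}, {17, 1}, {9, 2}};
	const size_t threads = 4;

	// Sort the groups from largest to smallest, as CreateTaskList does.
	std::sort(partition_blocks.begin(), partition_blocks.end(), std::greater<PartitionBlock>());

	// Split the largest group so all threads can work on it (ceiling division).
	const auto &max_block = partition_blocks.front();
	const size_t per_thread = (max_block.first + threads - 1) / threads;

	for (const auto &block : partition_blocks) {
		// Each group yields ceil(chunk_count / per_thread) tasks of up to per_thread chunks.
		for (size_t begin = 0; begin < block.first; begin += per_thread) {
			const size_t end = std::min(begin + per_thread, block.first);
			std::printf("group %zu: chunks [%zu, %zu)\n", block.second, begin, end);
		}
	}
	return 0;
}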
-void WindowHashGroup::MaterializeSortedData() {
-	auto &global_sort_state = *hash_group->global_sort;
-	if (global_sort_state.sorted_blocks.empty()) {
-		return;
-	}
-
-	// scan the sorted row data
-	D_ASSERT(global_sort_state.sorted_blocks.size() == 1);
-	auto &sb = *global_sort_state.sorted_blocks[0];
-
-	// Free up some memory before allocating more
-	sb.radix_sorting_data.clear();
-	sb.blob_sorting_data = nullptr;
-
-	// Move the sorting row blocks into our RDCs
-	auto &buffer_manager = global_sort_state.buffer_manager;
-	auto &sd = *sb.payload_data;
-
-	// Data blocks are required
-	D_ASSERT(!sd.data_blocks.empty());
-	auto &block = sd.data_blocks[0];
-	rows = make_uniq<RowDataCollection>(buffer_manager, block->capacity, block->entry_size);
-	rows->blocks = std::move(sd.data_blocks);
-	rows->count = std::accumulate(rows->blocks.begin(), rows->blocks.end(), idx_t(0),
-	                              [&](idx_t c, const unique_ptr<RowDataBlock> &b) { return c + b->count; });
-
-	// Heap blocks are optional, but we want both for iteration.
-	if (!sd.heap_blocks.empty()) {
-		auto &block = sd.heap_blocks[0];
-		heap = make_uniq<RowDataCollection>(buffer_manager, block->capacity, block->entry_size);
-		heap->blocks = std::move(sd.heap_blocks);
-		hash_group.reset();
-	} else {
-		heap = make_uniq<RowDataCollection>(buffer_manager, buffer_manager.GetBlockSize(), 1U, true);
-	}
-	heap->count = std::accumulate(heap->blocks.begin(), heap->blocks.end(), idx_t(0),
-	                              [&](idx_t c, const unique_ptr<RowDataBlock> &b) { return c + b->count; });
-}
-
-WindowHashGroup::WindowHashGroup(WindowGlobalSinkState &gstate, const idx_t hash_bin_p)
-    : count(0), blocks(0), stage(WindowGroupStage::SINK), hash_bin(hash_bin_p), sunk(0), finalized(0), completed(0),
-      batch_base(0) {
+WindowHashGroup::WindowHashGroup(WindowGlobalSinkState &gsink, const idx_t hash_bin_p)
+    : gsink(gsink), count(0), blocks(0), stage(WindowGroupStage::SINK), hash_bin(hash_bin_p), sunk(0), finalized(0),
+      completed(0), batch_base(0) {
 	// There are three types of partitions:
 	// 1. No partition (no sorting)
 	// 2. One partition (sorting, but no hashing)
 	// 3. Multiple partitions (sorting and hashing)

 	//	How big is the partition?
-	auto &gpart = *gstate.global_partition;
-	layout.Initialize(gpart.payload_types);
+	auto &gpart = *gsink.global_partition;
+	layout.Initialize(gpart.payload_types, TupleDataValidityType::CAN_HAVE_NULL_VALUES);
 	if (hash_bin < gpart.hash_groups.size() && gpart.hash_groups[hash_bin]) {
-		count = gpart.hash_groups[hash_bin]->count;
-	} else if (gpart.rows && !hash_bin) {
+		count = gpart.hash_groups[hash_bin]->sorted->Count();
+	} else if (gpart.unsorted && !hash_bin) {
 		count = gpart.count;
 	} else {
 		return;
@@ -557,7 +497,7 @@ WindowHashGroup::WindowHashGroup(WindowGlobalSinkState &gstate, const idx_t hash
 	partition_mask.Initialize(count);
 	partition_mask.SetAllInvalid(count);
 
-	const auto &executors = gstate.executors;
+	const auto &executors = gsink.executors;
 	for (auto &wexec : executors) {
 		auto &wexpr = wexec->wexpr;
 		auto &order_mask = order_masks[wexpr.partitions.size() + wexpr.orders.size()];
@@ -569,41 +509,126 @@ WindowHashGroup::WindowHashGroup(WindowGlobalSinkState &gstate, const idx_t hash
 	}
 
 	// Scan the sorted data into new Collections
-	external = gpart.external;
-	if (gpart.rows && !hash_bin) {
+	if (gpart.unsorted && !hash_bin) {
 		//	Simple mask
 		partition_mask.SetValidUnsafe(0);
 		for (auto &order_mask : order_masks) {
 			order_mask.second.SetValidUnsafe(0);
 		}
-		//	No partition - align the heap blocks with the row blocks
-		rows = gpart.rows->CloneEmpty(gpart.rows->keep_pinned);
-		heap = gpart.strings->CloneEmpty(gpart.strings->keep_pinned);
-		RowDataCollectionScanner::AlignHeapBlocks(*rows, *heap, *gpart.rows, *gpart.strings, layout);
-		external = true;
+		//	No partition - take ownership of the accumulated data
		rows = std::move(gpart.unsorted);
 	} else if (hash_bin < gpart.hash_groups.size()) {
 		// Overwrite the collections with the sorted data
 		D_ASSERT(gpart.hash_groups[hash_bin].get());
 		hash_group = std::move(gpart.hash_groups[hash_bin]);
-		hash_group->ComputeMasks(partition_mask, order_masks);
-		external = hash_group->global_sort->external;
-		MaterializeSortedData();
+		rows = std::move(hash_group->sorted);
+		ComputeMasks(partition_mask, order_masks);
 	}
 
 	if (rows) {
-		blocks = rows->blocks.size();
+		blocks = rows->ChunkCount();
 	}
 
 	// Set up the collection for any fully materialised data
-	const auto &shared = WindowSharedExpressions::GetSortedExpressions(gstate.shared.coll_shared);
+	const auto &shared = WindowSharedExpressions::GetSortedExpressions(gsink.shared.coll_shared);
 	vector<LogicalType> types;
 	for (auto &expr : shared) {
 		types.emplace_back(expr->return_type);
 	}
-	auto &buffer_manager = BufferManager::GetBufferManager(gstate.context);
+	auto &buffer_manager = BufferManager::GetBufferManager(gsink.client);
 	collection = make_uniq<WindowCollection>(buffer_manager, count, types);
 }
 
+unique_ptr<WindowCollectionChunkScanner> WindowHashGroup::GetScanner(const idx_t begin_idx) const {
+	if (!rows) {
+		return nullptr;
+	}
+
+	auto &scan_ids = gsink.global_partition->scan_ids;
+	return make_uniq<WindowCollectionChunkScanner>(*rows, scan_ids, begin_idx);
+}
+
+void WindowHashGroup::UpdateScanner(ScannerPtr &scanner, idx_t begin_idx) const {
+	if (!scanner || &scanner->collection != rows.get()) {
+		scanner.reset();
+		scanner = GetScanner(begin_idx);
+	} else {
+		scanner->Seek(begin_idx);
+	}
+}
+
+void WindowHashGroup::ComputeMasks(ValidityMask &partition_mask, OrderMasks &order_masks) {
+	D_ASSERT(count > 0);
+
+	//	Set up the partition compare structs
+	auto &partitions = gsink.global_partition->partitions;
+	partition_mask.SetValidUnsafe(0);
+	const auto key_count = partitions.size();
+
+	// Set up the order data structures
+	auto &collection = *rows;
+	auto &scan_cols = gsink.global_partition->sort_ids;
+	WindowCollectionChunkScanner scanner(collection, scan_cols, 0);
+	unordered_map<idx_t, DataChunk> prefixes;
+	for (auto &order_mask : order_masks) {
+		order_mask.second.SetValidUnsafe(0);
+		D_ASSERT(order_mask.first >= partitions.size());
+		auto order_type = scanner.PrefixStructType(order_mask.first, partitions.size());
+		vector<LogicalType> types(2, order_type);
+		auto &keys = prefixes[order_mask.first];
+		// We can't use InitializeEmpty here because it doesn't set up all of the STRUCT internals...
+		keys.Initialize(collection.GetAllocator(), types);
+	}
+
+	//	TODO: Parallelise on mask entry boundaries
+	const idx_t block_begin = 0;
+	const auto block_end = collection.ChunkCount();
+	WindowDeltaScanner(collection, block_begin, block_end, scan_cols, key_count,
+	                   [&](const idx_t row_idx, DataChunk &prev, DataChunk &curr, const idx_t ndistinct,
+	                       SelectionVector &distinct, const SelectionVector &matching) {
+		                   //	Process the partition boundaries
+		                   for (idx_t i = 0; i < ndistinct; ++i) {
+			                   const idx_t curr_index = row_idx + distinct.get_index(i);
+			                   partition_mask.SetValidUnsafe(curr_index);
+			                   for (auto &order_mask : order_masks) {
+				                   order_mask.second.SetValidUnsafe(curr_index);
+			                   }
+		                   }
+
+		                   //	Process the peers with each partition
+		                   const auto count = MinValue(prev.size(), curr.size());
+		                   const auto nmatch = count - ndistinct;
+		                   if (!nmatch) {
+			                   return;
+		                   }
+
+		                   for (auto &order_mask : order_masks) {
+			                   //	If there are no order columns, then all the partition elements are peers and we are
+			                   //	done
+			                   if (partitions.size() == order_mask.first) {
+				                   continue;
+			                   }
+			                   auto &prefix = prefixes[order_mask.first];
+			                   prefix.Reset();
+			                   auto &order_prev = prefix.data[0];
+			                   auto &order_curr = prefix.data[1];
+			                   scanner.ReferenceStructColumns(prev, order_prev, order_mask.first, partitions.size());
+			                   scanner.ReferenceStructColumns(curr, order_curr, order_mask.first, partitions.size());
+			                   if (ndistinct) {
+				                   prefix.Slice(matching, nmatch);
+			                   } else {
+				                   prefix.SetCardinality(nmatch);
+			                   }
+			                   const auto m = VectorOperations::DistinctFrom(order_curr, order_prev, nullptr, nmatch,
+			                                                                 &distinct, nullptr);
+			                   for (idx_t i = 0; i < m; ++i) {
+				                   const idx_t curr_index = row_idx + matching.get_index(distinct.get_index(i));
+				                   order_mask.second.SetValidUnsafe(curr_index);
+			                   }
		                   }
	                   });
+}
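The mask computation above marks row 0 and every row whose partition key differs from its predecessor; order masks additionally mark rows whose order-by prefix changes within a partition. A simplified, self-contained illustration of the same boundary-marking idea over plain vectors (DuckDB's vectorised DistinctFrom comparison is replaced by direct element comparisons):

#include <cstdio>
#include <utility>
#include <vector>

int main() {
	// Sorted input: (partition key, order key) pairs.
	std::vector<std::pair<int, int>> rows {{1, 10}, {1, 10}, {1, 20}, {2, 5}, {2, 5}};

	std::vector<bool> partition_mask(rows.size(), false);
	std::vector<bool> order_mask(rows.size(), false);

	// Row 0 always starts a partition and a peer group.
	partition_mask[0] = order_mask[0] = true;

	for (size_t i = 1; i < rows.size(); ++i) {
		if (rows[i].first != rows[i - 1].first) {
			// New partition implies a new peer group as well.
			partition_mask[i] = order_mask[i] = true;
		} else if (rows[i].second != rows[i - 1].second) {
			// Same partition, but the order key changed: peer-group boundary.
			order_mask[i] = true;
		}
	}

	for (size_t i = 0; i < rows.size(); ++i) {
		std::printf("row %zu: partition=%d peer=%d\n", i, int(partition_mask[i]), int(order_mask[i]));
	}
	return 0;
}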
 // Per-thread scan state
 class WindowLocalSourceState : public LocalSourceState {
 public:
@@ -624,7 +649,7 @@ class WindowLocalSourceState : public LocalSourceState {
 	//! Assign the next task
 	bool TryAssignTask();
 	//! Execute a step in the current task
-	void ExecuteTask(DataChunk &chunk);
+	void ExecuteTask(ExecutionContext &context, DataChunk &chunk);
 
 	//! The shared source state
 	WindowGlobalSourceState &gsource;
@@ -637,16 +662,14 @@ class WindowLocalSourceState : public LocalSourceState {
 	//! The current source being processed
 	optional_ptr<WindowHashGroup> window_hash_group;
 	//! The scan cursor
-	unique_ptr<RowDataCollectionScanner> scanner;
-	//! Buffer for the inputs
-	DataChunk input_chunk;
+	unique_ptr<WindowCollectionChunkScanner> scanner;
 	//! Buffer for window results
 	DataChunk output_chunk;
 
protected:
-	void Sink();
-	void Finalize();
-	void GetData(DataChunk &chunk);
+	void Sink(ExecutionContext &context);
+	void Finalize(ExecutionContext &context);
+	void GetData(ExecutionContext &context, DataChunk &chunk);
 
 	//! Storage and evaluation for the fully materialised data
 	unique_ptr<WindowBuilder> builder;
@@ -664,13 +687,13 @@ class WindowLocalSourceState : public LocalSourceState {
 
 idx_t WindowHashGroup::InitTasks(idx_t per_thread_p) {
 	per_thread = per_thread_p;
-	group_threads = (rows->blocks.size() + per_thread - 1) / per_thread;
+	group_threads = (rows->ChunkCount() + per_thread - 1) / per_thread;
 	thread_states.resize(GetThreadCount());
 
 	return GetTaskCount();
 }
 
-WindowHashGroup::ExecutorGlobalStates &WindowHashGroup::Initialize(WindowGlobalSinkState &gsink) {
+WindowHashGroup::ExecutorGlobalStates &WindowHashGroup::Initialize(ClientContext &client) {
 	//	Single-threaded building as this is mostly memory allocation
 	lock_guard<mutex> gestate_guard(lock);
 	const auto &executors = gsink.executors;
@@ -682,13 +705,13 @@ WindowHashGroup::ExecutorGlobalStates &WindowHashGroup::Initialize(WindowGlobalS
 	for (auto &wexec : executors) {
 		auto &wexpr = wexec->wexpr;
 		auto &order_mask = order_masks[wexpr.partitions.size() + wexpr.orders.size()];
-		gestates.emplace_back(wexec->GetGlobalState(count, partition_mask, order_mask));
+		gestates.emplace_back(wexec->GetGlobalState(client, count, partition_mask, order_mask));
 	}
 
 	return gestates;
 }
 
-void WindowLocalSourceState::Sink() {
+void WindowLocalSourceState::Sink(ExecutionContext &context) {
 	D_ASSERT(task);
 	D_ASSERT(task->stage == WindowGroupStage::SINK);
 
@@ -697,67 +720,60 @@ void WindowLocalSourceState::Sink() {
 
 	// Create the global state for each function
 	// These can be large so we defer building them until we are ready.
-	auto &gestates = window_hash_group->Initialize(gsink);
+	auto &gestates = window_hash_group->Initialize(context.client);
 
 	//	Set up the local states
 	auto &local_states = window_hash_group->thread_states.at(task->thread_idx);
 	if (local_states.empty()) {
 		for (idx_t w = 0; w < executors.size(); ++w) {
-			local_states.emplace_back(executors[w]->GetLocalState(*gestates[w]));
+			local_states.emplace_back(executors[w]->GetLocalState(context, *gestates[w]));
 		}
 	}
 
 	//	First pass over the input without flushing
+	scanner = window_hash_group->GetScanner(task->begin_idx);
+	if (!scanner) {
+		return;
+	}
 	for (; task->begin_idx < task->end_idx; ++task->begin_idx) {
-		scanner = window_hash_group->GetBuildScanner(task->begin_idx);
-		if (!scanner) {
+		const idx_t input_idx = scanner->Scanned();
+		if (!scanner->Scan()) {
 			break;
 		}
-		while (true) {
-			//	TODO: Try to align on validity mask boundaries by starting ragged?
-			idx_t input_idx = scanner->Scanned();
-			input_chunk.Reset();
-			scanner->Scan(input_chunk);
-			if (input_chunk.size() == 0) {
-				break;
-			}
+		auto &input_chunk = scanner->chunk;
 
-			// Compute fully materialised expressions
-			if (coll_chunk.data.empty()) {
-				coll_chunk.SetCardinality(input_chunk);
-			} else {
-				coll_chunk.Reset();
-				coll_exec.Execute(input_chunk, coll_chunk);
-				auto collection = window_hash_group->collection.get();
-				if (!builder || &builder->collection != collection) {
-					builder = make_uniq<WindowBuilder>(*collection);
-				}
-
-				builder->Sink(coll_chunk, input_idx);
+		// Compute fully materialised expressions
+		if (coll_chunk.data.empty()) {
+			coll_chunk.SetCardinality(input_chunk);
+		} else {
+			coll_chunk.Reset();
+			coll_exec.Execute(input_chunk, coll_chunk);
+			auto collection = window_hash_group->collection.get();
+			if (!builder || &builder->collection != collection) {
+				builder = make_uniq<WindowBuilder>(*collection);
 			}

-			// Compute sink expressions
-			if (sink_chunk.data.empty()) {
-				sink_chunk.SetCardinality(input_chunk);
-			} else {
-				sink_chunk.Reset();
-				sink_exec.Execute(input_chunk, sink_chunk);
-			}
+			builder->Sink(coll_chunk, input_idx);
+		}
 
-			for (idx_t w = 0; w < executors.size(); ++w) {
-				executors[w]->Sink(sink_chunk, coll_chunk, input_idx, *gestates[w], *local_states[w]);
-			}
+		// Compute sink expressions
+		if (sink_chunk.data.empty()) {
+			sink_chunk.SetCardinality(input_chunk);
+		} else {
+			sink_chunk.Reset();
+			sink_exec.Execute(input_chunk, sink_chunk);
+		}
 
-			window_hash_group->sunk += input_chunk.size();
+		for (idx_t w = 0; w < executors.size(); ++w) {
+			executors[w]->Sink(context, sink_chunk, coll_chunk, input_idx, *gestates[w], *local_states[w]);
 		}
 
-		//	External scanning assumes all blocks are swizzled.
-		scanner->SwizzleBlock(task->begin_idx);
-		scanner.reset();
+		window_hash_group->sunk += input_chunk.size();
 	}
+	scanner.reset();
 }
 
-void WindowLocalSourceState::Finalize() {
+void WindowLocalSourceState::Finalize(ExecutionContext &context) {
 	D_ASSERT(task);
 	D_ASSERT(task->stage == WindowGroupStage::FINALIZE);
@@ -774,7 +790,7 @@
 	auto &gestates = window_hash_group->gestates;
 	auto &local_states = window_hash_group->thread_states.at(task->thread_idx);
 	for (idx_t w = 0; w < executors.size(); ++w) {
-		executors[w]->Finalize(*gestates[w], *local_states[w], window_hash_group->collection);
+		executors[w]->Finalize(context, *gestates[w], *local_states[w], window_hash_group->collection);
 	}
 
 	//	Mark this range as done
@@ -783,13 +799,11 @@
 }
 
 WindowLocalSourceState::WindowLocalSourceState(WindowGlobalSourceState &gsource)
-    : gsource(gsource), batch_index(0), coll_exec(gsource.context), sink_exec(gsource.context),
-      eval_exec(gsource.context) {
+    : gsource(gsource), batch_index(0), coll_exec(gsource.client), sink_exec(gsource.client),
+      eval_exec(gsource.client) {
 	auto &gsink = gsource.gsink;
 	auto &global_partition = *gsink.global_partition;
 
-	input_chunk.Initialize(global_partition.allocator, global_partition.payload_types);
-
 	vector<LogicalType> output_types;
 	for (auto &wexec : gsink.executors) {
 		auto &wexpr = wexec->wexpr;
@@ -815,9 +829,8 @@ bool WindowGlobalSourceState::TryNextTask(TaskPtr &task, Task &task_local) {
 	}
 
 	//	Run through the active groups looking for one that can assign a task
-	auto &gpart = *gsink.global_partition;
 	for (const auto &group_idx : active_groups) {
-		auto &window_hash_group = gpart.window_hash_groups[group_idx];
+		auto &window_hash_group = gsink.window_hash_groups[group_idx];
 		if (window_hash_group->TryPrepareNextStage()) {
 			UnblockTasks(guard);
 		}
@@ -833,7 +846,7 @@ bool WindowGlobalSourceState::TryNextTask(TaskPtr &task, Task &task_local) {
 	const auto group_idx = partition_blocks[next_group++].second;
 	active_groups.emplace_back(group_idx);
 
-	auto &window_hash_group = gpart.window_hash_groups[group_idx];
+	auto &window_hash_group = gsink.window_hash_groups[group_idx];
 	if (window_hash_group->TryPrepareNextStage()) {
 		UnblockTasks(guard);
 	}
@@ -857,9 +870,8 @@ void WindowGlobalSourceState::FinishTask(TaskPtr task) {
 		return;
 	}
 
-	auto &gpart = *gsink.global_partition;
 	const auto group_idx = task->group_idx;
-	auto &finished_hash_group = gpart.window_hash_groups[group_idx];
+	auto &finished_hash_group = gsink.window_hash_groups[group_idx];
 	D_ASSERT(finished_hash_group);
 
 	if (++finished_hash_group->completed >= finished_hash_group->GetTaskCount()) {
@@ -886,25 +898,25 @@ bool WindowLocalSourceState::TryAssignTask() {
 	return gsource.TryNextTask(task, task_local);
 }
 
-void WindowLocalSourceState::ExecuteTask(DataChunk &result) {
+void WindowLocalSourceState::ExecuteTask(ExecutionContext &context, DataChunk &result) {
 	auto &gsink = gsource.gsink;
 
 	//	Update the hash group
-	window_hash_group = gsink.global_partition->window_hash_groups[task->group_idx].get();
+	window_hash_group = gsink.window_hash_groups[task->group_idx].get();
 
 	// Process the new state
 	switch (task->stage) {
 	case WindowGroupStage::SINK:
-		Sink();
+		Sink(context);
 		D_ASSERT(TaskFinished());
 		break;
 	case WindowGroupStage::FINALIZE:
-		Finalize();
+		Finalize(context);
 		D_ASSERT(TaskFinished());
 		break;
 	case WindowGroupStage::GETDATA:
 		D_ASSERT(!TaskFinished());
-		GetData(result);
+		GetData(context, result);
 		break;
 	default:
 		throw InternalException("Invalid window source state.");
@@ -916,17 +928,15 @@
 	}
 }
 
-void WindowLocalSourceState::GetData(DataChunk &result) {
+void WindowLocalSourceState::GetData(ExecutionContext &context, DataChunk &result) {
 	D_ASSERT(window_hash_group->GetStage() == WindowGroupStage::GETDATA);
 
-	if (!scanner || !scanner->Remaining()) {
-		scanner = window_hash_group->GetEvaluateScanner(task->begin_idx);
-		batch_index = window_hash_group->batch_base + task->begin_idx;
-	}
+	window_hash_group->UpdateScanner(scanner, task->begin_idx);
+	batch_index = window_hash_group->batch_base + task->begin_idx;
 
 	const auto position = scanner->Scanned();
-	input_chunk.Reset();
-	scanner->Scan(input_chunk);
+	auto &input_chunk = scanner->chunk;
	scanner->Scan();
 
 	const auto &executors = gsource.gsink.executors;
 	auto &gestates = window_hash_group->gestates;
@@ -943,7 +953,7 @@
 			eval_chunk.Reset();
 			eval_exec.Execute(input_chunk, eval_chunk);
 		}
-		executor.Evaluate(position, eval_chunk, result, lstate, gstate);
+		executor.Evaluate(context, position, eval_chunk, result, lstate, gstate);
 	}
 	output_chunk.SetCardinality(input_chunk);
 	output_chunk.Verify();
@@ -957,10 +967,8 @@
 		result.data[out_idx++].Reference(output_chunk.data[col_idx]);
 	}
 
-	// If we done with this block, move to the next one
-	if (!scanner->Remaining()) {
-		++task->begin_idx;
-	}
+	// Move to the next chunk
+	++task->begin_idx;
 
 	result.Verify();
 }
@@ -971,9 +979,9 @@ unique_ptr<LocalSourceState> PhysicalWindow::GetLocalSourceState(ExecutionContex
 	return make_uniq<WindowLocalSourceState>(gsource);
 }
 
-unique_ptr<GlobalSourceState> PhysicalWindow::GetGlobalSourceState(ClientContext &context) const {
+unique_ptr<GlobalSourceState> PhysicalWindow::GetGlobalSourceState(ClientContext &client) const {
 	auto &gsink = sink_state->Cast<WindowGlobalSinkState>();
-	return make_uniq<WindowGlobalSourceState>(context, gsink);
+	return make_uniq<WindowGlobalSourceState>(client, gsink);
 }
 
 bool PhysicalWindow::SupportsPartitioning(const OperatorPartitionInfo &partition_info) const {
@@ -1001,7 +1009,7 @@ OrderPreservationType PhysicalWindow::SourceOrder() const {
 	return OrderPreservationType::FIXED_ORDER;
 }
 
-ProgressData PhysicalWindow::GetProgress(ClientContext &context, GlobalSourceState &gsource_p) const {
+ProgressData PhysicalWindow::GetProgress(ClientContext &client, GlobalSourceState &gsource_p) const {
 	auto &gsource = gsource_p.Cast<WindowGlobalSourceState>();
 	const auto returned = gsource.returned.load();
 
@@ -1032,12 +1040,10 @@ SourceResultType PhysicalWindow::GetData(ExecutionContext &context, DataChunk &c
 	auto &gsource = input.global_state.Cast<WindowGlobalSourceState>();
 	auto &lsource = input.local_state.Cast<WindowLocalSourceState>();
 
-	gsource.CreateTaskList();
-
 	while (gsource.HasUnfinishedTasks() && chunk.size() == 0) {
 		if (!lsource.TaskFinished() || lsource.TryAssignTask()) {
 			try {
-				lsource.ExecuteTask(chunk);
+				lsource.ExecuteTask(context, chunk);
 			} catch (...) {
 				gsource.stopped = true;
 				throw;
diff --git a/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp b/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp
index 575d69d7d..bcb8306e6 100644
--- a/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp
+++ b/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp
@@ -31,12 +31,12 @@ CSVBuffer::CSVBuffer(CSVFileHandle &file_handle, ClientContext &context, idx_t b
 	last_buffer = file_handle.FinishedReading();
 }
 
-shared_ptr<CSVBuffer> CSVBuffer::Next(CSVFileHandle &file_handle, idx_t buffer_size, bool &has_seaked) const {
-	if (has_seaked) {
+shared_ptr<CSVBuffer> CSVBuffer::Next(CSVFileHandle &file_handle, idx_t buffer_size, bool &has_seeked) const {
+	if (has_seeked) {
 		// This means that at some point a reload was done, and we are currently on the incorrect position in our file
 		// handle
 		file_handle.Seek(global_csv_start + actual_buffer_size);
-		has_seaked = false;
+		has_seeked = false;
 	}
 	auto next_csv_buffer =
 	    make_shared_ptr<CSVBuffer>(file_handle, context, buffer_size, global_csv_start + actual_buffer_size, buffer_idx + 1);
@@ -68,7 +68,7 @@ void CSVBuffer::Reload(CSVFileHandle &file_handle) {
 
 shared_ptr<CSVBufferHandle> CSVBuffer::Pin(CSVFileHandle &file_handle, bool &has_seeked) {
 	auto &buffer_manager = BufferManager::GetBufferManager(context);
-	if (!is_pipe && block->IsUnloaded()) {
+	if (!block || (!is_pipe && block->IsUnloaded())) {
 		// We have to reload it from disk
 		block = nullptr;
 		Reload(file_handle);
diff --git a/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp b/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp
index 461c4150c..41df74ff0 100644
--- a/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp
+++ b/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp
@@ -504,8 +504,9 @@ void StringValueResult::Reset() {
 	}
 	// We keep a reference to the buffer from our current iteration if it already exists
 	shared_ptr<CSVBufferHandle> cur_buffer;
-	if (buffer_handles.find(iterator.GetBufferIdx()) != buffer_handles.end()) {
-		cur_buffer = buffer_handles[iterator.GetBufferIdx()];
+	auto handle_iter = buffer_handles.find(iterator.GetBufferIdx());
+	if (handle_iter != buffer_handles.end()) {
+		cur_buffer = handle_iter->second;
 	}
 	buffer_handles.clear();
 	idx_t actual_size = 0;
diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp
index 629a7fb2f..fc8dc9385 100644
--- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp
+++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp
@@ -394,6 +394,18 @@ void CSVSniffer::AnalyzeDialectCandidate(unique_ptr<ColumnCountScanner> scanner,
 	if (stats.best_consistent_rows == consistent_rows && num_cols >= max_columns_found) {
 		// If both have not been escaped, this might get solved later on.
 		sniffing_state_machine.dialect_options.num_cols = num_cols;
+		if (options.dialect_options.skip_rows.IsSetByUser()) {
+			// If skip rows are set by the user, and we found dirty notes, we only accept it if either
+			// null_padding or ignore_errors is set
+			if (dirty_notes != 0 && !options.null_padding && !options.ignore_errors.GetValue()) {
+				return;
+			}
+			sniffing_state_machine.dialect_options.skip_rows = options.dialect_options.skip_rows.GetValue();
+		} else if (!options.null_padding) {
+			sniffing_state_machine.dialect_options.skip_rows = dirty_notes;
+		}
+		sniffing_state_machine.dialect_options.num_cols = num_cols;
+		lines_sniffed = sniffed_column_counts.result_position;
 		successful_candidates.emplace_back(std::move(scanner));
 		max_columns_found = num_cols;
 		return;
@@ -495,7 +507,7 @@ void CSVSniffer::RefineCandidates() {
 		return;
 	}
 	if (candidates.size() == 1 || candidates[0]->FinishedFile()) {
-		// Only one candidate nothing to refine or all candidates already checked
+		// Only one candidate, nothing to refine, or all candidates already checked
 		return;
 	}
 
@@ -612,7 +624,8 @@ void CSVSniffer::DetectDialect() {
 		if (all_fail_max_line_size) {
 			error = line_error;
 		} else {
-			error = CSVError::SniffingError(options, dialect_candidates.Print(), max_columns_found_error, set_columns);
+			error = CSVError::SniffingError(options, dialect_candidates.Print(), max_columns_found_error, set_columns,
+			                                false);
 		}
 		error_handler->Error(error, true);
 	}
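The acceptance rule added to AnalyzeDialectCandidate above reduces to a small predicate: a user-supplied skip_rows is honoured, but a candidate that still has "dirty" leading rows is rejected unless null_padding or ignore_errors can absorb them. A hedged standalone restatement (parameter names mirror the reader options, but this is not DuckDB's API):

#include <cstdio>

// Returns true when the dialect candidate may be kept, mirroring the rule above.
static bool AcceptCandidate(bool skip_rows_set_by_user, size_t dirty_notes, bool null_padding, bool ignore_errors) {
	if (skip_rows_set_by_user) {
		// The user fixed the skip count: dirty leading rows are only tolerable
		// if null_padding or ignore_errors is enabled.
		return dirty_notes == 0 || null_padding || ignore_errors;
	}
	// Otherwise the sniffer is free to treat the dirty rows as rows to skip.
	return true;
}

int main() {
	std::printf("%d\n", AcceptCandidate(true, 3, false, false));  // 0: rejected
	std::printf("%d\n", AcceptCandidate(true, 3, true, false));   // 1: accepted via null_padding
	std::printf("%d\n", AcceptCandidate(false, 3, false, false)); // 1: dirty rows become skip_rows
	return 0;
}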
diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_detection.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_detection.cpp
index f3ee76182..390c5d095 100644
--- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_detection.cpp
+++ b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_detection.cpp
@@ -425,7 +425,7 @@ void CSVSniffer::DetectTypes() {
 	idx_t min_varchar_cols = max_columns_found + 1;
 	idx_t min_errors = NumericLimits<idx_t>::Maximum();
 	vector<LogicalType> return_types;
-	// check which info candidate leads to minimum amount of non-varchar columns...
+	// check which info candidate leads to the minimum number of non-varchar columns...
 	for (auto &candidate_cc : candidates) {
 		auto &sniffing_state_machine = candidate_cc->GetStateMachine();
 		unordered_map<idx_t, vector<LogicalType>> info_sql_types_candidates;
@@ -441,7 +441,7 @@ void CSVSniffer::DetectTypes() {
 		// Reset candidate for parsing
 		auto candidate = candidate_cc->UpgradeToStringValueScanner();
 		SetUserDefinedDateTimeFormat(*candidate->state_machine);
-		// Parse chunk and read csv with info candidate
+		// Parse chunk and read csv with info-candidate
 		auto &data_chunk = candidate->ParseChunk().ToChunk();
 		if (candidate->error_handler->AnyErrors() && !candidate->error_handler->HasError(MAXIMUM_LINE_SIZE) &&
 		    !candidate->state_machine->options.ignore_errors.GetValue()) {
@@ -502,7 +502,7 @@ void CSVSniffer::DetectTypes() {
 	}
 	if (!best_candidate) {
 		DialectCandidates dialect_candidates(options.dialect_options.state_machine_options);
-		auto error = CSVError::SniffingError(options, dialect_candidates.Print(), max_columns_found, set_columns);
+		auto error = CSVError::SniffingError(options, dialect_candidates.Print(), max_columns_found, set_columns, true);
 		error_handler->Error(error, true);
 	}
 	// Assert that it's all good at this point.
diff --git a/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp b/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp
index d574acc49..7fd64d889 100644
--- a/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp
+++ b/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp
@@ -438,12 +438,17 @@ CSVError CSVError::HeaderSniffingError(const CSVReaderOptions &options, const ve
 }
 
 CSVError CSVError::SniffingError(const CSVReaderOptions &options, const string &search_space, idx_t max_columns_found,
-                                 SetColumns &set_columns) {
+                                 SetColumns &set_columns, bool type_detection) {
 	std::ostringstream error;
 	// 1. Which file
 	error << "Error when sniffing file \"" << options.file_path << "\"." << '\n';
 	// 2. What's the error
-	error << "It was not possible to automatically detect the CSV Parsing dialect/types" << '\n';
+	error << "It was not possible to automatically detect the CSV parsing ";
+	if (type_detection) {
+		error << "types" << '\n';
+	} else {
+		error << "dialect" << '\n';
+	}
 	// 2. What was the search space?
error << "The search space used was:" << '\n'; diff --git a/src/duckdb/src/execution/operator/helper/physical_reset.cpp b/src/duckdb/src/execution/operator/helper/physical_reset.cpp index b6a219a48..b1751ab5e 100644 --- a/src/duckdb/src/execution/operator/helper/physical_reset.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_reset.cpp @@ -32,9 +32,11 @@ SourceResultType PhysicalReset::GetData(ExecutionContext &context, DataChunk &ch // check if this is an extra extension variable auto entry = config.extension_parameters.find(name); if (entry == config.extension_parameters.end()) { - Catalog::AutoloadExtensionByConfigName(context.client, name); + auto extension_name = Catalog::AutoloadExtensionByConfigName(context.client, name); entry = config.extension_parameters.find(name); - D_ASSERT(entry != config.extension_parameters.end()); + if (entry == config.extension_parameters.end()) { + throw InvalidInputException("Extension parameter %s was not found after autoloading", name); + } } ResetExtensionVariable(context, config, entry->second); return SourceResultType::FINISHED; @@ -45,12 +47,28 @@ SourceResultType PhysicalReset::GetData(ExecutionContext &context, DataChunk &ch if (variable_scope == SetScope::AUTOMATIC) { if (option->set_local) { variable_scope = SetScope::SESSION; - } else { - D_ASSERT(option->set_global); + } else if (option->set_global) { variable_scope = SetScope::GLOBAL; + } else { + variable_scope = option->default_scope; } } + if (option->default_value) { + if (option->set_callback) { + SettingCallbackInfo info(context.client, variable_scope); + auto parameter_type = DBConfig::ParseLogicalType(option->parameter_type); + Value reset_val = Value(option->default_value).CastAs(context.client, parameter_type); + option->set_callback(info, reset_val); + } + if (variable_scope == SetScope::SESSION) { + auto &client_config = ClientConfig::GetConfig(context.client); + client_config.set_variables.erase(name); + } else { + config.ResetGenericOption(name); + } + return SourceResultType::FINISHED; + } switch (variable_scope) { case SetScope::GLOBAL: { if (!option->set_global) { diff --git a/src/duckdb/src/execution/operator/helper/physical_set.cpp b/src/duckdb/src/execution/operator/helper/physical_set.cpp index 4a321a52e..e8362ad9c 100644 --- a/src/duckdb/src/execution/operator/helper/physical_set.cpp +++ b/src/duckdb/src/execution/operator/helper/physical_set.cpp @@ -6,20 +6,27 @@ namespace duckdb { +void PhysicalSet::SetGenericVariable(ClientContext &context, const string &name, SetScope scope, Value target_value) { + if (scope == SetScope::GLOBAL) { + auto &config = DBConfig::GetConfig(context); + config.SetOption(name, std::move(target_value)); + } else { + auto &client_config = ClientConfig::GetConfig(context); + client_config.set_variables[name] = std::move(target_value); + } +} + void PhysicalSet::SetExtensionVariable(ClientContext &context, ExtensionOption &extension_option, const string &name, SetScope scope, const Value &value) { - auto &config = DBConfig::GetConfig(context); auto &target_type = extension_option.type; Value target_value = value.CastAs(context, target_type); if (extension_option.set_function) { extension_option.set_function(context, scope, target_value); } - if (scope == SetScope::GLOBAL) { - config.SetOption(name, std::move(target_value)); - } else { - auto &client_config = ClientConfig::GetConfig(context); - client_config.set_variables[name] = std::move(target_value); + if (scope == SetScope::AUTOMATIC) { + scope = extension_option.default_scope; } 
diff --git a/src/duckdb/src/execution/operator/helper/physical_set.cpp b/src/duckdb/src/execution/operator/helper/physical_set.cpp
index 4a321a52e..e8362ad9c 100644
--- a/src/duckdb/src/execution/operator/helper/physical_set.cpp
+++ b/src/duckdb/src/execution/operator/helper/physical_set.cpp
@@ -6,20 +6,27 @@
 
 namespace duckdb {
 
+void PhysicalSet::SetGenericVariable(ClientContext &context, const string &name, SetScope scope, Value target_value) {
+	if (scope == SetScope::GLOBAL) {
+		auto &config = DBConfig::GetConfig(context);
+		config.SetOption(name, std::move(target_value));
+	} else {
+		auto &client_config = ClientConfig::GetConfig(context);
+		client_config.set_variables[name] = std::move(target_value);
+	}
+}
+
 void PhysicalSet::SetExtensionVariable(ClientContext &context, ExtensionOption &extension_option, const string &name,
                                        SetScope scope, const Value &value) {
-	auto &config = DBConfig::GetConfig(context);
 	auto &target_type = extension_option.type;
 	Value target_value = value.CastAs(context, target_type);
 	if (extension_option.set_function) {
 		extension_option.set_function(context, scope, target_value);
 	}
-	if (scope == SetScope::GLOBAL) {
-		config.SetOption(name, std::move(target_value));
-	} else {
-		auto &client_config = ClientConfig::GetConfig(context);
-		client_config.set_variables[name] = std::move(target_value);
+	if (scope == SetScope::AUTOMATIC) {
+		scope = extension_option.default_scope;
 	}
+	SetGenericVariable(context, name, scope, std::move(target_value));
 }
 
 SourceResultType PhysicalSet::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const {
@@ -31,9 +38,11 @@ SourceResultType PhysicalSet::GetData(ExecutionContext &context, DataChunk &chun
 		// check if this is an extra extension variable
 		auto entry = config.extension_parameters.find(name);
 		if (entry == config.extension_parameters.end()) {
-			Catalog::AutoloadExtensionByConfigName(context.client, name);
+			auto extension_name = Catalog::AutoloadExtensionByConfigName(context.client, name);
 			entry = config.extension_parameters.find(name);
-			D_ASSERT(entry != config.extension_parameters.end());
+			if (entry == config.extension_parameters.end()) {
+				throw InvalidInputException("Extension parameter %s was not found after autoloading", name);
+			}
 		}
 		SetExtensionVariable(context.client, entry->second, name, scope, value);
 		return SourceResultType::FINISHED;
@@ -42,20 +51,28 @@ SourceResultType PhysicalSet::GetData(ExecutionContext &context, DataChunk &chun
 	if (variable_scope == SetScope::AUTOMATIC) {
 		if (option->set_local) {
 			variable_scope = SetScope::SESSION;
-		} else {
-			D_ASSERT(option->set_global);
+		} else if (option->set_global) {
 			variable_scope = SetScope::GLOBAL;
+		} else {
+			variable_scope = option->default_scope;
 		}
 	}
 
 	Value input_val = value.CastAs(context.client, DBConfig::ParseLogicalType(option->parameter_type));
+	if (option->default_value) {
+		if (option->set_callback) {
+			SettingCallbackInfo info(context.client, variable_scope);
+			option->set_callback(info, input_val);
+		}
+		SetGenericVariable(context.client, option->name, variable_scope, std::move(input_val));
+		return SourceResultType::FINISHED;
+	}
 	switch (variable_scope) {
 	case SetScope::GLOBAL: {
 		if (!option->set_global) {
 			throw CatalogException("option \"%s\" cannot be set globally", name);
 		}
 		auto &db = DatabaseInstance::GetDatabase(context.client);
-		auto &config = DBConfig::GetConfig(context.client);
 		config.SetOption(&db, *option, input_val);
 		break;
 	}
diff --git a/src/duckdb/src/execution/operator/helper/physical_transaction.cpp b/src/duckdb/src/execution/operator/helper/physical_transaction.cpp
index a8ad410db..385964952 100644
--- a/src/duckdb/src/execution/operator/helper/physical_transaction.cpp
+++ b/src/duckdb/src/execution/operator/helper/physical_transaction.cpp
@@ -7,6 +7,7 @@
 #include "duckdb/main/valid_checker.hpp"
 #include "duckdb/transaction/meta_transaction.hpp"
 #include "duckdb/transaction/transaction_manager.hpp"
+#include "duckdb/main/settings.hpp"
 
 namespace duckdb {
 
@@ -28,11 +29,10 @@ SourceResultType PhysicalTransaction::GetData(ExecutionContext &context, DataChu
 		// prevent it from being closed after this query, hence
 		// preserving the transaction context for the next query
 		client.transaction.SetAutoCommit(false);
-		auto &config = DBConfig::GetConfig(context.client);
 		if (info->modifier == TransactionModifierType::TRANSACTION_READ_ONLY) {
 			client.transaction.SetReadOnly();
 		}
-		if (config.options.immediate_transaction_mode) {
+		if (DBConfig::GetSetting<ImmediateTransactionModeSetting>(context.client)) {
 			// if immediate transaction mode is enabled then start all transactions immediately
 			auto databases = DatabaseManager::Get(client).GetDatabases(client);
 			for (auto db : databases) {
diff --git a/src/duckdb/src/execution/operator/join/physical_hash_join.cpp b/src/duckdb/src/execution/operator/join/physical_hash_join.cpp
index 9aa97a42b..9513bded8 100644
--- a/src/duckdb/src/execution/operator/join/physical_hash_join.cpp
+++ b/src/duckdb/src/execution/operator/join/physical_hash_join.cpp
@@ -23,6 +23,8 @@
 #include "duckdb/planner/table_filter.hpp"
 #include "duckdb/storage/buffer_manager.hpp"
 #include "duckdb/storage/temporary_memory_manager.hpp"
+#include "duckdb/main/settings.hpp"
+#include "duckdb/logging/log_manager.hpp"
 
 namespace duckdb {
 
@@ -148,7 +150,7 @@ class HashJoinGlobalSinkState : public GlobalSinkState {
 			                                 NumericStats::Max(*op.join_stats[1]));
 		}
 		// For external hash join
-		external = ClientConfig::GetConfig(context).GetSetting<DebugForceExternalSetting>(context);
+		external = ClientConfig::GetConfig(context).force_external;
 		// Set probe types
 		probe_types = op.children[0].get().GetTypes();
 		probe_types.emplace_back(LogicalType::HASH);
@@ -761,7 +763,7 @@ unique_ptr<DataChunk> JoinFilterPushdownInfo::Finalize(ClientContext &context, o
 		return final_min_max; // There are no table sources into which we can push down filters
 	}
 
-	auto dynamic_or_filter_threshold = ClientConfig::GetSetting<DynamicOrFilterThresholdSetting>(context);
+	auto dynamic_or_filter_threshold = DBConfig::GetSetting<DynamicOrFilterThresholdSetting>(context);
 	// create a filter for each of the aggregates
 	for (idx_t filter_idx = 0; filter_idx < join_condition.size(); filter_idx++) {
 		const auto cmp = op.conditions[join_condition[filter_idx]].comparison;
diff --git a/src/duckdb/src/execution/operator/join/physical_iejoin.cpp b/src/duckdb/src/execution/operator/join/physical_iejoin.cpp
index ad9896f66..90aba4722 100644
--- a/src/duckdb/src/execution/operator/join/physical_iejoin.cpp
+++ b/src/duckdb/src/execution/operator/join/physical_iejoin.cpp
@@ -23,9 +23,7 @@ PhysicalIEJoin::PhysicalIEJoin(PhysicalPlan &physical_plan, LogicalComparisonJoi
                                PhysicalOperator &right, vector<JoinCondition> cond, JoinType join_type,
                                idx_t estimated_cardinality, unique_ptr<JoinFilterPushdownInfo> pushdown_info)
     : PhysicalRangeJoin(physical_plan, op, PhysicalOperatorType::IE_JOIN, left, right, std::move(cond), join_type,
-                        estimated_cardinality) {
-
-	filter_pushdown = std::move(pushdown_info);
+                        estimated_cardinality, std::move(pushdown_info)) {
 
 	// 1. let L1 (resp. L2) be the array of column X (resp. Y)
 	D_ASSERT(conditions.size() >= 2);
diff --git a/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp b/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp
index 437fa1175..1bd48ab62 100644
--- a/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp
+++ b/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp
@@ -20,9 +20,7 @@ PhysicalPiecewiseMergeJoin::PhysicalPiecewiseMergeJoin(PhysicalPlan &physical_pl
                                                        idx_t estimated_cardinality,
                                                        unique_ptr<JoinFilterPushdownInfo> pushdown_info_p)
     : PhysicalRangeJoin(physical_plan, op, PhysicalOperatorType::PIECEWISE_MERGE_JOIN, left, right, std::move(cond),
-                        join_type, estimated_cardinality) {
-
-	filter_pushdown = std::move(pushdown_info_p);
+                        join_type, estimated_cardinality, std::move(pushdown_info_p)) {
 
 	for (auto &join_cond : conditions) {
 		D_ASSERT(join_cond.left->return_type == join_cond.right->return_type);
@@ -289,8 +287,8 @@ class PiecewiseMergeJoinState : public CachingOperatorState {
 };
 
 unique_ptr<OperatorState> PhysicalPiecewiseMergeJoin::GetOperatorState(ExecutionContext &context) const {
-	auto &config = ClientConfig::GetConfig(context.client);
-	return make_uniq<PiecewiseMergeJoinState>(context.client, *this, config.force_external);
+	bool force_external = ClientConfig::GetConfig(context.client).force_external;
+	return make_uniq<PiecewiseMergeJoinState>(context.client, *this, force_external);
 }
 
 static inline idx_t SortedBlockNotNull(const idx_t base, const idx_t count, const idx_t not_null) {
diff --git a/src/duckdb/src/execution/operator/join/physical_range_join.cpp b/src/duckdb/src/execution/operator/join/physical_range_join.cpp
index a11a402ee..4fefafbd4 100644
--- a/src/duckdb/src/execution/operator/join/physical_range_join.cpp
+++ b/src/duckdb/src/execution/operator/join/physical_range_join.cpp
@@ -7,6 +7,7 @@
 #include "duckdb/common/sort/sort.hpp"
 #include "duckdb/common/types/validity_mask.hpp"
 #include "duckdb/common/types/vector.hpp"
+#include "duckdb/common/unordered_map.hpp"
 #include "duckdb/common/vector_operations/vector_operations.hpp"
 #include "duckdb/execution/expression_executor.hpp"
 #include "duckdb/main/client_context.hpp"
@@ -63,8 +64,7 @@ PhysicalRangeJoin::GlobalSortedTable::GlobalSortedTable(ClientContext &context,
     : op(op_p), global_sort_state(context, orders, payload_layout), has_null(0), count(0), memory_per_thread(0) {
 
 	// Set external (can be forced with the PRAGMA)
-	auto &config = ClientConfig::GetConfig(context);
-	global_sort_state.external = config.force_external;
+	global_sort_state.external = ClientConfig::GetConfig(context).force_external;
 	memory_per_thread = PhysicalRangeJoin::GetMaxThreadMemory(context);
 }
 
@@ -167,12 +167,15 @@ void PhysicalRangeJoin::GlobalSortedTable::Finalize(Pipeline &pipeline, Event &e
 
 PhysicalRangeJoin::PhysicalRangeJoin(PhysicalPlan &physical_plan, LogicalComparisonJoin &op, PhysicalOperatorType type,
                                      PhysicalOperator &left, PhysicalOperator &right, vector<JoinCondition> cond,
-                                     JoinType join_type, idx_t estimated_cardinality)
+                                     JoinType join_type, idx_t estimated_cardinality,
+                                     unique_ptr<JoinFilterPushdownInfo> pushdown_info)
     : PhysicalComparisonJoin(physical_plan, op, type, std::move(cond), join_type, estimated_cardinality) {
+	filter_pushdown = std::move(pushdown_info);
 	// Reorder the conditions so that ranges are at the front.
 	// TODO: use stats to improve the choice?
 	// TODO: Prefer fixed length types?
	if (conditions.size() > 1) {
+		unordered_map<idx_t, idx_t> cond_idx;
 		vector<JoinCondition> conditions_p(conditions.size());
 		std::swap(conditions_p, conditions);
 		idx_t range_position = 0;
@@ -184,12 +187,21 @@ PhysicalRangeJoin::PhysicalRangeJoin(PhysicalPlan &physical_plan, LogicalCompari
 			case ExpressionType::COMPARE_GREATERTHAN:
 			case ExpressionType::COMPARE_GREATERTHANOREQUALTO:
 				conditions[range_position++] = std::move(conditions_p[i]);
+				cond_idx[i] = range_position - 1;
 				break;
 			default:
 				conditions[--other_position] = std::move(conditions_p[i]);
+				cond_idx[i] = other_position;
 				break;
 			}
 		}
+		if (filter_pushdown) {
+			for (auto &idx : filter_pushdown->join_condition) {
+				if (cond_idx.find(idx) != cond_idx.end()) {
+					idx = cond_idx[idx];
+				}
+			}
+		}
 	}
 
 	children.push_back(left);
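Because the constructor above moves range conditions to the front (and everything else to the back), any filter-pushdown entry that refers to a condition by position would silently point at the wrong condition afterwards; cond_idx records the old-to-new mapping so those references can be rewritten. The same pattern in miniature, outside DuckDB's types:

#include <cstdio>
#include <unordered_map>
#include <vector>

int main() {
	// true = range condition, false = other; mimic the reorder in the constructor.
	std::vector<bool> is_range {false, true, false, true};
	std::vector<size_t> pushdown_refs {1, 2}; // condition positions referenced elsewhere

	std::unordered_map<size_t, size_t> cond_idx; // old position -> new position
	size_t range_position = 0;
	size_t other_position = is_range.size();
	for (size_t i = 0; i < is_range.size(); ++i) {
		if (is_range[i]) {
			cond_idx[i] = range_position++;
		} else {
			cond_idx[i] = --other_position;
		}
	}

	// Rewrite the stale references, exactly as the filter_pushdown fix-up does.
	for (auto &ref : pushdown_refs) {
		auto entry = cond_idx.find(ref);
		if (entry != cond_idx.end()) {
			ref = entry->second;
		}
	}

	for (auto ref : pushdown_refs) {
		std::printf("%zu\n", ref); // 0 and 2 for this input
	}
	return 0;
}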
diff --git a/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp b/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp
index 88d1cb32d..d3fa408bb 100644
--- a/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp
+++ b/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp
@@ -150,6 +150,7 @@ class BatchInsertGlobalState : public GlobalSinkState {
 	    : memory_manager(context, minimum_memory_per_thread), table(table), insert_count(0),
 	      optimistically_written(false), minimum_memory_per_thread(minimum_memory_per_thread) {
 		row_group_size = table.GetStorage().GetRowGroupSize();
+		table.GetStorage().BindIndexes(context);
 	}
 
 	BatchMemoryManager memory_manager;
diff --git a/src/duckdb/src/execution/operator/persistent/physical_copy_database.cpp b/src/duckdb/src/execution/operator/persistent/physical_copy_database.cpp
index 7116f965a..9b16d476c 100644
--- a/src/duckdb/src/execution/operator/persistent/physical_copy_database.cpp
+++ b/src/duckdb/src/execution/operator/persistent/physical_copy_database.cpp
@@ -81,10 +81,12 @@ SourceResultType PhysicalCopyDatabase::GetData(ExecutionContext &context, DataCh
 			storage_info.options.emplace("v1_0_0_storage", false);
 			auto unbound_index = make_uniq<UnboundIndex>(create_index_info.Copy(), storage_info,
 			                                             data_table.GetTableIOManager(), catalog.GetAttached());
-			data_table.AddIndex(std::move(unbound_index));
+
+			// We add unbound indexes, so we immediately bind them.
+			// Otherwise, WAL serialization fails due to unbound indexes.
 			auto &data_table_info = *data_table.GetDataTableInfo();
-			data_table_info.GetIndexes().InitializeIndexes(context.client, data_table_info);
+			data_table_info.BindIndexes(context.client);
 		}
 
 	return SourceResultType::FINISHED;
diff --git a/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp b/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp
index 88848b3f9..5958a9885 100644
--- a/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp
+++ b/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp
@@ -8,6 +8,7 @@
 #include "duckdb/common/value_operations/value_operations.hpp"
 #include "duckdb/common/vector_operations/vector_operations.hpp"
 #include "duckdb/planner/operator/logical_copy_to_file.hpp"
+#include "duckdb/main/settings.hpp"
 
 #include <algorithm>
 
@@ -50,7 +51,7 @@ class CopyToFunctionGlobalState : public GlobalSinkState {
 	explicit CopyToFunctionGlobalState(ClientContext &context)
 	    : initialized(false), rows_copied(0), last_file_offset(0),
 	      file_write_lock_if_rotating(make_uniq<StorageLock>()) {
-		max_open_files = ClientConfig::GetConfig(context).partitioned_write_max_open_files;
+		max_open_files = DBConfig::GetSetting<PartitionedWriteMaxOpenFilesSetting>(context);
 	}
 
 	StorageLock lock;
@@ -252,11 +253,14 @@ string PhysicalCopyToFile::GetTrimmedPath(ClientContext &context) const {
 
 class CopyToFunctionLocalState : public LocalSinkState {
 public:
-	explicit CopyToFunctionLocalState(unique_ptr<LocalFunctionData> local_state) : local_state(std::move(local_state)) {
+	explicit CopyToFunctionLocalState(ClientContext &context, unique_ptr<LocalFunctionData> local_state)
+	    : local_state(std::move(local_state)) {
+		partitioned_write_flush_threshold = DBConfig::GetSetting<PartitionedWriteFlushThresholdSetting>(context);
 	}
 	unique_ptr<GlobalFunctionData> global_state;
 	unique_ptr<LocalFunctionData> local_state;
 	idx_t total_rows_copied = 0;
+	idx_t partitioned_write_flush_threshold;
 
 	//! Buffers the tuples in partitions before writing
 	unique_ptr<HivePartitionedColumnData> part_buffer;
@@ -281,7 +285,7 @@ class CopyToFunctionLocalState : public LocalSinkState {
 		}
 		part_buffer->Append(*part_buffer_append_state, chunk);
 		append_count += chunk.size();
-		if (append_count >= ClientConfig::GetConfig(context.client).partitioned_write_flush_threshold) {
+		if (append_count >= partitioned_write_flush_threshold) {
 			// flush all cached partitions
 			FlushPartitions(context, op, g);
 		}
@@ -370,11 +374,12 @@ unique_ptr<LocalSinkState> PhysicalCopyToFile::GetLocalSinkState(ExecutionContex
 	if (partition_output) {
 		auto &g = sink_state->Cast<CopyToFunctionGlobalState>();
 
-		auto state = make_uniq<CopyToFunctionLocalState>(nullptr);
+		auto state = make_uniq<CopyToFunctionLocalState>(context.client, nullptr);
 		state->InitializeAppendState(context.client, *this, g);
 		return std::move(state);
 	}
-	auto res = make_uniq<CopyToFunctionLocalState>(function.copy_to_initialize_local(context, *bind_data));
+	auto res =
+	    make_uniq<CopyToFunctionLocalState>(context.client, function.copy_to_initialize_local(context, *bind_data));
 	return std::move(res);
 }
diff --git a/src/duckdb/src/execution/operator/persistent/physical_delete.cpp b/src/duckdb/src/execution/operator/persistent/physical_delete.cpp
index 2e01d66af..13458c923 100644
--- a/src/duckdb/src/execution/operator/persistent/physical_delete.cpp
+++ b/src/duckdb/src/execution/operator/persistent/physical_delete.cpp
@@ -39,6 +39,7 @@ class DeleteGlobalState : public GlobalSinkState {
 	mutex delete_lock;
 	idx_t deleted_count;
 	ColumnDataCollection return_collection;
+	unordered_set<row_t> deleted_row_ids;
 	LocalAppendState delete_index_append_state;
 	bool has_unique_indexes;
 };
@@ -147,12 +148,40 @@ SinkResultType PhysicalDelete::Sink(ExecutionContext &context, DataChunk &chunk,
 		});
 	}
 
+	auto deleted_count = table.Delete(*l_state.delete_state, context.client, row_ids, chunk.size());
+	g_state.deleted_count += deleted_count;
+
 	// Append the return_chunk to the return collection.
 	if (return_chunk) {
+		// Rows can be duplicated, so we get the chunk indexes for new row id values.
+		map<row_t, idx_t> new_row_ids_deleted;
+		auto flat_ids = FlatVector::GetData<row_t>(row_ids);
+		for (idx_t i = 0; i < chunk.size(); i++) {
+			// If the row has not been deleted previously
+			// and is not a duplicate within the current chunk,
+			// then we add it to new_row_ids_deleted.
+			auto row_id = flat_ids[i];
+			auto already_deleted = g_state.deleted_row_ids.find(row_id) != g_state.deleted_row_ids.end();
+			auto newly_deleted = new_row_ids_deleted.find(row_id) != new_row_ids_deleted.end();
+			if (!already_deleted && !newly_deleted) {
+				new_row_ids_deleted[row_id] = i;
+				g_state.deleted_row_ids.insert(row_id);
+			}
+		}
+
+		D_ASSERT(new_row_ids_deleted.size() == deleted_count);
+		if (deleted_count < l_state.delete_chunk.size()) {
+			SelectionVector delete_sel(0, deleted_count);
+			idx_t chunk_index = 0;
+			for (auto &row_id_to_chunk_index : new_row_ids_deleted) {
+				delete_sel.set_index(chunk_index, row_id_to_chunk_index.second);
+				chunk_index++;
+			}
+			l_state.delete_chunk.Slice(delete_sel, deleted_count);
+		}
 		g_state.return_collection.Append(l_state.delete_chunk);
 	}
 
-	g_state.deleted_count += table.Delete(*l_state.delete_state, context.client, row_ids, chunk.size());
-
 	return SinkResultType::NEED_MORE_INPUT;
 }
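The effect of the new bookkeeping above: the delete call ignores row ids that were already deleted, so RETURNING must drop exactly those duplicates from delete_chunk. A standalone sketch of the keep-first-occurrence selection (std containers standing in for DuckDB's SelectionVector and DataChunk):

#include <cstdio>
#include <map>
#include <set>
#include <vector>

int main() {
	std::set<long> already_deleted {42};      // deleted by an earlier chunk
	std::vector<long> row_ids {7, 42, 7, 99}; // current chunk, with duplicates

	// Keep the chunk index of the first occurrence of each newly deleted row id.
	std::map<long, size_t> new_row_ids_deleted;
	for (size_t i = 0; i < row_ids.size(); ++i) {
		const auto row_id = row_ids[i];
		if (already_deleted.count(row_id) == 0 && new_row_ids_deleted.count(row_id) == 0) {
			new_row_ids_deleted[row_id] = i;
			already_deleted.insert(row_id);
		}
	}

	// The selection that would be applied to delete_chunk: indexes 0 and 3 survive.
	for (const auto &entry : new_row_ids_deleted) {
		std::printf("row id %ld at chunk index %zu\n", entry.first, entry.second);
	}
	return 0;
}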
Register which rows have already caused a conflict @@ -367,7 +368,7 @@ static map> CheckDistinctness(DataChunk &input, ConflictInf auto &column_ids = info.column_ids; if (column_ids.empty()) { for (auto index : matched_indexes) { - auto &index_column_ids = index->GetColumnIdSet(); + auto &index_column_ids = index.get().GetColumnIdSet(); PrepareSortKeys(input, sort_keys, index_column_ids); vector> columns; for (auto &idx : index_column_ids) { @@ -423,7 +424,7 @@ static void VerifyOnConflictCondition(ExecutionContext &context, DataChunk &comb auto &indexes = local_storage.GetIndexes(context.client, data_table); auto storage = local_storage.GetStorage(data_table); - DataTable::VerifyUniqueIndexes(indexes, storage, tuples, nullptr); + data_table.VerifyUniqueIndexes(indexes, storage, tuples, nullptr); throw InternalException("VerifyUniqueIndexes was expected to throw but didn't"); } @@ -445,7 +446,7 @@ static idx_t HandleInsertConflicts(TableCatalogEntry &table, ExecutionContext &c data_table.VerifyAppendConstraints(constraint_state, context.client, tuples, storage, &conflict_manager); } else { auto &indexes = local_storage.GetIndexes(context.client, data_table); - DataTable::VerifyUniqueIndexes(indexes, storage, tuples, &conflict_manager); + data_table.VerifyUniqueIndexes(indexes, storage, tuples, &conflict_manager); } if (!conflict_manager.HasConflicts()) { @@ -529,32 +530,30 @@ idx_t PhysicalInsert::OnConflictHandling(TableCatalogEntry &table, ExecutionCont } ConflictInfo conflict_info(conflict_target); + reference_set_t matching_indexes; - auto &global_indexes = data_table.GetDataTableInfo()->GetIndexes(); - auto &local_indexes = local_storage.GetIndexes(context.client, data_table); - - unordered_set matching_indexes; if (conflict_info.column_ids.empty()) { + auto &global_indexes = data_table.GetDataTableInfo()->GetIndexes(); // We care about every index that applies to the table if no ON CONFLICT (...) 
target is given global_indexes.Scan([&](Index &index) { if (!index.IsUnique()) { return false; } + D_ASSERT(index.IsBound()); if (conflict_info.ConflictTargetMatches(index)) { - D_ASSERT(index.IsBound()); - auto &bound_index = index.Cast(); - matching_indexes.insert(&bound_index); + matching_indexes.insert(index); } return false; }); + auto &local_indexes = local_storage.GetIndexes(context.client, data_table); local_indexes.Scan([&](Index &index) { if (!index.IsUnique()) { return false; } + D_ASSERT(index.IsBound()); if (conflict_info.ConflictTargetMatches(index)) { - D_ASSERT(index.IsBound()); auto &bound_index = index.Cast(); - matching_indexes.insert(&bound_index); + matching_indexes.insert(bound_index); } return false; }); diff --git a/src/duckdb/src/execution/operator/persistent/physical_merge_into.cpp b/src/duckdb/src/execution/operator/persistent/physical_merge_into.cpp index 4d568ff51..04a5f3dca 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_merge_into.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_merge_into.cpp @@ -6,9 +6,10 @@ namespace duckdb { PhysicalMergeInto::PhysicalMergeInto(PhysicalPlan &physical_plan, vector types, map>> actions_p, - idx_t row_id_index, optional_idx source_marker, bool parallel_p) + idx_t row_id_index, optional_idx source_marker, bool parallel_p, + bool return_chunk_p) : PhysicalOperator(physical_plan, PhysicalOperatorType::MERGE_INTO, std::move(types), 1), - row_id_index(row_id_index), source_marker(source_marker), parallel(parallel_p) { + row_id_index(row_id_index), source_marker(source_marker), parallel(parallel_p), return_chunk(return_chunk_p) { map ranges; for (auto &entry : actions_p) { @@ -396,15 +397,119 @@ SinkFinalizeType PhysicalMergeInto::Finalize(Pipeline &pipeline, Event &event, C //===--------------------------------------------------------------------===// // Source //===--------------------------------------------------------------------===// +class MergeGlobalSourceState : public GlobalSourceState { +public: + explicit MergeGlobalSourceState(ClientContext &context, const PhysicalMergeInto &op) { + if (!op.return_chunk) { + return; + } + auto &g = op.sink_state->Cast(); + for (idx_t i = 0; i < op.actions.size(); i++) { + auto &action = *op.actions[i]; + unique_ptr global_state; + if (action.op) { + // assign the global sink state + action.op->sink_state = std::move(g.sink_states[i]); + // initialize the global source state + global_state = action.op->GetGlobalSourceState(context); + } + global_states.push_back(std::move(global_state)); + } + } + + vector> global_states; +}; + +class MergeLocalSourceState : public LocalSourceState { +public: + explicit MergeLocalSourceState(ExecutionContext &context, const PhysicalMergeInto &op, + MergeGlobalSourceState &gstate) { + if (!op.return_chunk) { + return; + } + for (idx_t i = 0; i < op.actions.size(); i++) { + auto &action = *op.actions[i]; + unique_ptr local_state; + if (action.op) { + local_state = action.op->GetLocalSourceState(context, *gstate.global_states[i]); + } + local_states.push_back(std::move(local_state)); + } + vector scan_types; + for (idx_t c = 0; c < op.types.size() - 1; c++) { + scan_types.emplace_back(op.types[c]); + } + scan_chunk.Initialize(context.client, scan_types); + } + + DataChunk scan_chunk; + vector> local_states; + idx_t index = 0; +}; + unique_ptr PhysicalMergeInto::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(); + return make_uniq(context, *this); +} + +unique_ptr 
PhysicalMergeInto::GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const { + return make_uniq<MergeLocalSourceState>(context, *this, gstate.Cast<MergeGlobalSourceState>()); } SourceResultType PhysicalMergeInto::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { auto &g = sink_state->Cast<MergeIntoGlobalState>(); - chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(NumericCast<int64_t>(g.merged_count.load()))); + if (!return_chunk) { + chunk.SetCardinality(1); + chunk.SetValue(0, 0, Value::BIGINT(NumericCast<int64_t>(g.merged_count.load()))); + return SourceResultType::FINISHED; + } + auto &gstate = input.global_state.Cast<MergeGlobalSourceState>(); + auto &lstate = input.local_state.Cast<MergeLocalSourceState>(); + chunk.Reset(); + for (; lstate.index < actions.size(); lstate.index++) { + auto &action = *actions[lstate.index]; + if (!action.op) { + // no action to scan from + continue; + } + auto &child_gstate = *gstate.global_states[lstate.index]; + auto &child_lstate = *lstate.local_states[lstate.index]; + OperatorSourceInput source_input {child_gstate, child_lstate, input.interrupt_state}; + + auto result = action.op->GetData(context, lstate.scan_chunk, source_input); + if (lstate.scan_chunk.size() > 0) { + // construct the result chunk + for (idx_t c = 0; c < lstate.scan_chunk.ColumnCount(); c++) { + chunk.data[c].Reference(lstate.scan_chunk.data[c]); + } + // set the merge action + string merge_action_name; + switch (action.action_type) { + case MergeActionType::MERGE_UPDATE: + merge_action_name = "UPDATE"; + break; + case MergeActionType::MERGE_INSERT: + merge_action_name = "INSERT"; + break; + case MergeActionType::MERGE_DELETE: + merge_action_name = "DELETE"; + break; + default: + throw InternalException("Unsupported merge action for RETURNING"); + } + Value merge_action(merge_action_name); + chunk.data.back().Reference(merge_action); + chunk.SetCardinality(lstate.scan_chunk.size()); + } + + if (result != SourceResultType::FINISHED) { + return result; + } + if (chunk.size() != 0) { + return SourceResultType::HAVE_MORE_OUTPUT; + } + } return SourceResultType::FINISHED; } diff --git a/src/duckdb/src/execution/operator/persistent/physical_update.cpp b/src/duckdb/src/execution/operator/persistent/physical_update.cpp index 3fecd2947..f96dba699 100644 --- a/src/duckdb/src/execution/operator/persistent/physical_update.cpp +++ b/src/duckdb/src/execution/operator/persistent/physical_update.cpp @@ -162,8 +162,8 @@ SinkResultType PhysicalUpdate::Sink(ExecutionContext &context, DataChunk &chunk, lock_guard<mutex> glock(g_state.lock); for (idx_t i = 0; i < update_chunk.size(); i++) { auto row_id = row_id_data[i]; - if (g_state.updated_rows.find(row_id) == g_state.updated_rows.end()) { - g_state.updated_rows.insert(row_id); + const auto is_new = g_state.updated_rows.insert(row_id).second; + if (is_new) { sel.set_index(update_count++, i); } } diff --git a/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp b/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp index 79cdaa0e5..f464bfb18 100644 --- a/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp +++ b/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp @@ -11,6 +11,7 @@ class TableInOutLocalState : public OperatorState { idx_t row_index; bool new_row; DataChunk input_chunk; + idx_t current_ordinality_idx = 1; }; class TableInOutGlobalState : public GlobalOperatorState { @@ -61,6 +62,15 @@ unique_ptr<GlobalOperatorState> PhysicalTableInOutFunction::GetGlobalOperatorSta return std::move(result); } +void 
PhysicalTableInOutFunction::SetOrdinality(DataChunk &chunk, const optional_idx &ordinality_column_idx, + const idx_t &ordinality_idx, const idx_t &ordinality) { + D_ASSERT(ordinality_column_idx.IsValid()); + if (ordinality > 0) { + constexpr idx_t step = 1; + chunk.data[ordinality_column_idx.GetIndex()].Sequence(static_cast<int64_t>(ordinality_idx), step, ordinality); + } +} + OperatorResultType PhysicalTableInOutFunction::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, GlobalOperatorState &gstate_p, OperatorState &state_p) const { auto &gstate = gstate_p.Cast<TableInOutGlobalState>(); @@ -68,7 +78,13 @@ OperatorResultType PhysicalTableInOutFunction::Execute(ExecutionContext &context TableFunctionInput data(bind_data.get(), state.local_state.get(), gstate.global_state.get()); if (projected_input.empty()) { // straightforward case - no need to project input - return function.in_out_function(context, data, input, chunk); + auto result = function.in_out_function(context, data, input, chunk); + if (this->ordinality_idx.IsValid()) { + const idx_t ordinality = chunk.size(); + SetOrdinality(chunk, this->ordinality_idx, state.current_ordinality_idx, ordinality); + state.current_ordinality_idx += ordinality; + } + return result; } // when project_input is set we execute the input function row-by-row if (state.new_row) { @@ -87,6 +103,7 @@ OperatorResultType PhysicalTableInOutFunction::Execute(ExecutionContext &context state.input_chunk.SetCardinality(1); state.row_index++; state.new_row = false; + state.current_ordinality_idx = 1; } // set up the output data in "chunk" D_ASSERT(chunk.ColumnCount() > projected_input.size()); @@ -98,6 +115,11 @@ OperatorResultType PhysicalTableInOutFunction::Execute(ExecutionContext &context ConstantVector::Reference(chunk.data[target_idx], input.data[source_idx], state.row_index - 1, 1); } auto result = function.in_out_function(context, data, state.input_chunk, chunk); + if (this->ordinality_idx.IsValid()) { + const idx_t ordinality = chunk.size(); + SetOrdinality(chunk, this->ordinality_idx, state.current_ordinality_idx, ordinality); + state.current_ordinality_idx += ordinality; + } if (result == OperatorResultType::FINISHED) { return result; } diff --git a/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp index c631dcc09..e9f66bea4 100644 --- a/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp +++ b/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp @@ -110,6 +110,7 @@ SourceResultType PhysicalTableScan::GetData(ExecutionContext &context, DataChunk function.in_out_function_final(context, data, chunk); } switch (function.in_out_function(context, data, g_state.input_chunk, chunk)) { + case OperatorResultType::BLOCKED: { auto guard = g_state.Lock(); return g_state.BlockSource(guard, input.interrupt_state); diff --git a/src/duckdb/src/execution/operator/schema/physical_attach.cpp b/src/duckdb/src/execution/operator/schema/physical_attach.cpp index d871f9ef5..a097712b8 100644 --- a/src/duckdb/src/execution/operator/schema/physical_attach.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_attach.cpp @@ -17,7 +17,8 @@ SourceResultType PhysicalAttach::GetData(ExecutionContext &context, DataChunk &c OperatorSourceInput &input) const { // parse the options auto &config = DBConfig::GetConfig(context.client); - AttachOptions options(info, config.options.access_mode); + // construct the options + AttachOptions options(info->options, config.options.access_mode); // get the 
name and path of the database auto &name = info->name; @@ -67,8 +68,7 @@ SourceResultType PhysicalAttach::GetData(ExecutionContext &context, DataChunk &c auto attached_db = db_manager.AttachDatabase(context.client, *info, options); //! Initialize the database. - const auto storage_options = info->GetStorageOptions(); - attached_db->Initialize(context.client, storage_options); + attached_db->Initialize(context.client); if (!options.default_table.name.empty()) { attached_db->GetCatalog().SetDefaultTable(options.default_table.schema, options.default_table.name); } diff --git a/src/duckdb/src/execution/operator/schema/physical_drop.cpp b/src/duckdb/src/execution/operator/schema/physical_drop.cpp index cfb6841a8..7c9cbf933 100644 --- a/src/duckdb/src/execution/operator/schema/physical_drop.cpp +++ b/src/duckdb/src/execution/operator/schema/physical_drop.cpp @@ -17,8 +17,9 @@ SourceResultType PhysicalDrop::GetData(ExecutionContext &context, DataChunk &chu case CatalogType::PREPARED_STATEMENT: { // DEALLOCATE silently ignores errors auto &statements = ClientData::Get(context.client).prepared_statements; - if (statements.find(info->name) != statements.end()) { - statements.erase(info->name); + auto stmt_iter = statements.find(info->name); + if (stmt_iter != statements.end()) { + statements.erase(stmt_iter); } break; } diff --git a/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp b/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp index c8c4077a6..5759583c5 100644 --- a/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp @@ -13,6 +13,7 @@ #include "duckdb/planner/expression/bound_constant_expression.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -274,9 +275,10 @@ PhysicalOperator &PhysicalPlanGenerator::PlanAsOfJoin(LogicalComparisonJoin &op) } D_ASSERT(asof_idx < op.conditions.size()); - auto &config = ClientConfig::GetConfig(context); - if (!config.force_asof_iejoin) { - if (op.children[0]->has_estimated_cardinality && lhs_cardinality < config.asof_loop_join_threshold) { + bool force_asof_join = DBConfig::GetSetting(context); + if (!force_asof_join) { + idx_t asof_join_threshold = DBConfig::GetSetting(context); + if (op.children[0]->has_estimated_cardinality && lhs_cardinality < asof_join_threshold) { auto result = PlanAsOfLoopJoin(op, left, right); if (result) { return *result; diff --git a/src/duckdb/src/execution/physical_plan/plan_comparison_join.cpp b/src/duckdb/src/execution/physical_plan/plan_comparison_join.cpp index 6d6ba2a74..fb499d2d3 100644 --- a/src/duckdb/src/execution/physical_plan/plan_comparison_join.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_comparison_join.cpp @@ -14,6 +14,7 @@ #include "duckdb/planner/expression_iterator.hpp" #include "duckdb/planner/operator/logical_comparison_join.hpp" #include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -53,10 +54,9 @@ PhysicalOperator &PhysicalPlanGenerator::PlanComparisonJoin(LogicalComparisonJoi default: break; } - auto &client_config = ClientConfig::GetConfig(context); - // TODO: Extend PWMJ to handle all comparisons and projection maps - const auto prefer_range_joins = client_config.prefer_range_joins && can_iejoin; + bool prefer_range_joins = DBConfig::GetSetting(context); + prefer_range_joins = prefer_range_joins && can_iejoin; if 
(has_equality && !prefer_range_joins) { // Equality join with small number of keys : possible perfect join optimization auto &join = Make(op, left, right, std::move(op.conditions), op.join_type, @@ -67,15 +67,16 @@ PhysicalOperator &PhysicalPlanGenerator::PlanComparisonJoin(LogicalComparisonJoi } D_ASSERT(op.left_projection_map.empty()); - if (left.estimated_cardinality <= client_config.nested_loop_join_threshold || - right.estimated_cardinality <= client_config.nested_loop_join_threshold) { + idx_t nested_loop_join_threshold = DBConfig::GetSetting(context); + if (left.estimated_cardinality <= nested_loop_join_threshold || + right.estimated_cardinality <= nested_loop_join_threshold) { can_iejoin = false; can_merge = false; } if (can_merge && can_iejoin) { - if (left.estimated_cardinality <= client_config.merge_join_threshold || - right.estimated_cardinality <= client_config.merge_join_threshold) { + idx_t merge_join_threshold = DBConfig::GetSetting(context); + if (left.estimated_cardinality <= merge_join_threshold || right.estimated_cardinality <= merge_join_threshold) { can_iejoin = false; } } diff --git a/src/duckdb/src/execution/physical_plan/plan_get.cpp b/src/duckdb/src/execution/physical_plan/plan_get.cpp index a319ea996..8a446e678 100644 --- a/src/duckdb/src/execution/physical_plan/plan_get.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_get.cpp @@ -77,6 +77,8 @@ PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalGet &op) { Make(op.types, op.function, std::move(op.bind_data), column_ids, op.estimated_cardinality, std::move(op.projected_input)); table_in_out.children.push_back(child); + auto &cast_table_in_out = table_in_out.Cast(); + cast_table_in_out.ordinality_idx = op.ordinality_idx; return table_in_out; } diff --git a/src/duckdb/src/execution/physical_plan/plan_insert.cpp b/src/duckdb/src/execution/physical_plan/plan_insert.cpp index 66822ac04..ab8df5393 100644 --- a/src/duckdb/src/execution/physical_plan/plan_insert.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_insert.cpp @@ -8,6 +8,7 @@ #include "duckdb/parallel/task_scheduler.hpp" #include "duckdb/catalog/duck_catalog.hpp" #include "duckdb/planner/expression/bound_reference_expression.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -33,8 +34,6 @@ OrderPreservationType PhysicalPlanGenerator::OrderPreservationRecursive(Physical } bool PhysicalPlanGenerator::PreserveInsertionOrder(ClientContext &context, PhysicalOperator &plan) { - auto &config = DBConfig::GetConfig(context); - auto preservation_type = OrderPreservationRecursive(plan); if (preservation_type == OrderPreservationType::FIXED_ORDER) { // always need to maintain preservation order @@ -45,7 +44,7 @@ bool PhysicalPlanGenerator::PreserveInsertionOrder(ClientContext &context, Physi return false; } // preserve insertion order - check flags - if (!config.options.preserve_insertion_order) { + if (!DBConfig::GetSetting(context)) { // preserving insertion order is disabled by config return false; } @@ -108,11 +107,11 @@ PhysicalOperator &DuckCatalog::PlanInsert(ClientContext &context, PhysicalPlanGe parallel_streaming_insert = false; use_batch_index = false; } - if (op.action_type != OnConflictAction::THROW) { + if (op.on_conflict_info.action_type != OnConflictAction::THROW) { // We don't support ON CONFLICT clause in batch insertion operation currently use_batch_index = false; } - if (op.action_type == OnConflictAction::UPDATE) { + if (op.on_conflict_info.action_type == OnConflictAction::UPDATE) { // When we potentially need to perform 
updates, we have to check that row is not updated twice // that currently needs to be done for every chunk, which would add a huge bottleneck to parallelized insertion parallel_streaming_insert = false; @@ -128,11 +127,12 @@ PhysicalOperator &DuckCatalog::PlanInsert(ClientContext &context, PhysicalPlanGe } auto &insert = planner.Make( - op.types, op.table, std::move(op.bound_constraints), std::move(op.expressions), std::move(op.set_columns), - std::move(op.set_types), op.estimated_cardinality, op.return_chunk, - parallel_streaming_insert && num_threads > 1, op.action_type, std::move(op.on_conflict_condition), - std::move(op.do_update_condition), std::move(op.on_conflict_filter), std::move(op.columns_to_fetch), - op.update_is_del_and_insert); + op.types, op.table, std::move(op.bound_constraints), std::move(op.expressions), + std::move(op.on_conflict_info.set_columns), std::move(op.on_conflict_info.set_types), op.estimated_cardinality, + op.return_chunk, parallel_streaming_insert && num_threads > 1, op.on_conflict_info.action_type, + std::move(op.on_conflict_info.on_conflict_condition), std::move(op.on_conflict_info.do_update_condition), + std::move(op.on_conflict_info.on_conflict_filter), std::move(op.on_conflict_info.columns_to_fetch), + op.on_conflict_info.update_is_del_and_insert); insert.children.push_back(*plan); return insert; } diff --git a/src/duckdb/src/execution/physical_plan/plan_merge_into.cpp b/src/duckdb/src/execution/physical_plan/plan_merge_into.cpp index 80edab918..c4d06b0a6 100644 --- a/src/duckdb/src/execution/physical_plan/plan_merge_into.cpp +++ b/src/duckdb/src/execution/physical_plan/plan_merge_into.cpp @@ -19,23 +19,31 @@ unique_ptr PlanMergeIntoAction(ClientContext &context, Logica for (auto &constraint : op.bound_constraints) { bound_constraints.push_back(constraint->Copy()); } + auto return_types = op.types; + if (op.return_chunk) { + // for RETURNING, the last column is the merge_action - this is added in the merge itself + return_types.pop_back(); + } + auto cardinality = op.EstimateCardinality(context); switch (action.action_type) { case MergeActionType::MERGE_UPDATE: { vector> defaults; for (auto &def : op.bound_defaults) { defaults.push_back(def->Copy()); } - result->op = planner.Make(op.types, op.table, op.table.GetStorage(), std::move(action.columns), - std::move(action.expressions), std::move(defaults), - std::move(bound_constraints), 1ULL, false); + result->op = + planner.Make(std::move(return_types), op.table, op.table.GetStorage(), + std::move(action.columns), std::move(action.expressions), std::move(defaults), + std::move(bound_constraints), cardinality, op.return_chunk); auto &cast_update = result->op->Cast(); cast_update.update_is_del_and_insert = action.update_is_del_and_insert; break; } case MergeActionType::MERGE_DELETE: { - result->op = planner.Make(op.types, op.table, op.table.GetStorage(), - std::move(bound_constraints), op.row_id_start, 1ULL, false); + result->op = + planner.Make(std::move(return_types), op.table, op.table.GetStorage(), + std::move(bound_constraints), op.row_id_start, cardinality, op.return_chunk); break; } case MergeActionType::MERGE_INSERT: { @@ -45,10 +53,11 @@ unique_ptr PlanMergeIntoAction(ClientContext &context, Logica unordered_set on_conflict_filter; vector columns_to_fetch; - result->op = planner.Make( - op.types, op.table, std::move(bound_constraints), std::move(set_expressions), std::move(set_columns), - std::move(set_types), 1ULL, false, true, OnConflictAction::THROW, nullptr, nullptr, - 
std::move(on_conflict_filter), std::move(columns_to_fetch), false); + result->op = planner.Make(std::move(return_types), op.table, std::move(bound_constraints), + std::move(set_expressions), std::move(set_columns), + std::move(set_types), cardinality, op.return_chunk, !op.return_chunk, + OnConflictAction::THROW, nullptr, nullptr, + std::move(on_conflict_filter), std::move(columns_to_fetch), false); // transform expressions if required if (!action.column_index_map.empty()) { vector> new_expressions; @@ -100,17 +109,17 @@ PhysicalOperator &DuckCatalog::PlanMergeInto(ClientContext &context, PhysicalPla actions.emplace(entry.first, std::move(planned_actions)); } - bool parallel = append_count <= 1; + bool parallel = append_count <= 1 && !op.return_chunk; - auto &result = - planner.Make(op.types, std::move(actions), op.row_id_start, op.source_marker, parallel); + auto &result = planner.Make(op.types, std::move(actions), op.row_id_start, op.source_marker, + parallel, op.return_chunk); result.children.push_back(plan); return result; } PhysicalOperator &Catalog::PlanMergeInto(ClientContext &context, PhysicalPlanGenerator &planner, LogicalMergeInto &op, PhysicalOperator &plan) { - throw NotImplementedException("Database does not support merge into"); + throw NotImplementedException("Database type \"%s\" does not support MERGE INTO or ON CONFLICT", GetName()); } PhysicalOperator &PhysicalPlanGenerator::CreatePlan(LogicalMergeInto &op) { diff --git a/src/duckdb/src/execution/physical_plan_generator.cpp b/src/duckdb/src/execution/physical_plan_generator.cpp index 7595963e0..de7d1a093 100644 --- a/src/duckdb/src/execution/physical_plan_generator.cpp +++ b/src/duckdb/src/execution/physical_plan_generator.cpp @@ -10,6 +10,7 @@ #include "duckdb/planner/operator/logical_extension_operator.hpp" #include "duckdb/planner/operator/list.hpp" #include "duckdb/execution/operator/helper/physical_verify_vector.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -56,11 +57,10 @@ unique_ptr PhysicalPlanGenerator::PlanInternal(LogicalOperator &op physical_plan->SetRoot(CreatePlan(op)); physical_plan->Root().estimated_cardinality = op.estimated_cardinality; - auto &config = DBConfig::GetConfig(context); - if (config.options.debug_verify_vector != DebugVectorVerification::NONE) { - if (config.options.debug_verify_vector != DebugVectorVerification::DICTIONARY_EXPRESSION) { - physical_plan->SetRoot( - Make(physical_plan->Root(), config.options.debug_verify_vector)); + auto debug_verify_vector = DBConfig::GetSetting(context); + if (debug_verify_vector != DebugVectorVerification::NONE) { + if (debug_verify_vector != DebugVectorVerification::DICTIONARY_EXPRESSION) { + physical_plan->SetRoot(Make(physical_plan->Root(), debug_verify_vector)); } } return std::move(physical_plan); diff --git a/src/duckdb/src/function/aggregate/distributive/minmax.cpp b/src/duckdb/src/function/aggregate/distributive/minmax.cpp index 72fecbce7..ce5ef12af 100644 --- a/src/duckdb/src/function/aggregate/distributive/minmax.cpp +++ b/src/duckdb/src/function/aggregate/distributive/minmax.cpp @@ -12,6 +12,7 @@ #include "duckdb/planner/expression.hpp" #include "duckdb/planner/expression/bound_comparison_expression.hpp" #include "duckdb/planner/expression_binder.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -334,7 +335,7 @@ unique_ptr BindMinMax(ClientContext &context, AggregateFunction &f vector> &arguments) { if (arguments[0]->return_type.id() == LogicalTypeId::VARCHAR) { auto str_collation = 
StringType::GetCollation(arguments[0]->return_type); - if (!str_collation.empty() || !DBConfig::GetConfig(context).options.collation.empty()) { + if (!str_collation.empty() || !DBConfig::GetSetting(context).empty()) { // If aggr function is min/max and uses collations, replace bound_function with arg_min/arg_max // to make sure the result's correctness. string function_name = function.name == "min" ? "arg_min" : "arg_max"; diff --git a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp index cd9f17cf3..8c697c5cf 100644 --- a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp +++ b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp @@ -1,5 +1,5 @@ #include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/sorting/sort.hpp" #include "duckdb/common/types/column/column_data_collection.hpp" #include "duckdb/common/types/list_segment.hpp" #include "duckdb/function/aggregate_function.hpp" @@ -7,8 +7,9 @@ #include "duckdb/planner/expression/bound_window_expression.hpp" #include "duckdb/storage/buffer_manager.hpp" #include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/parser/expression_map.hpp" +#include "duckdb/parallel/thread_context.hpp" namespace duckdb { @@ -18,32 +19,62 @@ struct SortedAggregateBindData : public FunctionData { using BindInfoPtr = unique_ptr; using OrderBys = vector; - SortedAggregateBindData(ClientContext &context_p, Expressions &children, AggregateFunction &aggregate, + SortedAggregateBindData(ClientContext &context, Expressions &children, AggregateFunction &aggregate, BindInfoPtr &bind_info, OrderBys &order_bys) - : context(context_p), function(aggregate), bind_info(std::move(bind_info)), - threshold(ClientConfig::GetConfig(context).ordered_aggregate_threshold), - external(ClientConfig::GetConfig(context).force_external) { - arg_types.reserve(children.size()); - arg_funcs.reserve(children.size()); + : context(context), function(aggregate), bind_info(std::move(bind_info)), + threshold(ClientConfig::GetConfig(context).ordered_aggregate_threshold) { + + // Describe the arguments. for (const auto &child : children) { - arg_types.emplace_back(child->return_type); - ListSegmentFunctions funcs; - GetSegmentDataFunctions(funcs, arg_types.back()); - arg_funcs.emplace_back(std::move(funcs)); + buffered_cols.emplace_back(buffered_cols.size()); + buffered_types.emplace_back(child->return_type); + + // Column 0 in the sort data is the group number + scan_cols.emplace_back(buffered_cols.size()); } - sort_types.reserve(order_bys.size()); - sort_funcs.reserve(order_bys.size()); - for (auto &order : order_bys) { - orders.emplace_back(order.Copy()); - sort_types.emplace_back(order.expression->return_type); - ListSegmentFunctions funcs; - GetSegmentDataFunctions(funcs, sort_types.back()); - sort_funcs.emplace_back(std::move(funcs)); + scan_types = buffered_types; + + // The first sort column is the group number. It is prefixed onto the buffered data + sort_types.emplace_back(LogicalType::USMALLINT); + orders.emplace_back(BoundOrderByNode(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, + make_uniq(sort_types.back(), 0U))); + + // Determine whether we are sorted on all the arguments. + // Even if we are not, we want to share inputs for sorting. 
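// [Annotation, not from the original patch] The loop below rewrites each ORDER BY key as a
// BoundReferenceExpression into the buffered layout: column 0 holds the group number, so a key
// that equals argument arg_idx is redirected to column arg_idx + 1 and shares that argument's
// storage instead of being buffered twice. A key that matches no argument is appended as an
// extra buffered column and clears sorted_on_args.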
+ for (idx_t ord_idx = 0; ord_idx < order_bys.size(); ++ord_idx) { + auto order = order_bys[ord_idx].Copy(); + bool matched = false; + const auto &type = order.expression->return_type; + + for (idx_t arg_idx = 0; arg_idx < children.size(); ++arg_idx) { + auto &child = children[arg_idx]; + if (child->Equals(*order.expression)) { + order.expression = make_uniq(type, arg_idx + 1); + matched = true; + break; + } + } + + if (!matched) { + sorted_on_args = false; + buffered_cols.emplace_back(children.size() + ord_idx); + buffered_types.emplace_back(type); + order.expression = make_uniq(type, buffered_cols.size()); + } + + orders.emplace_back(std::move(order)); } - sorted_on_args = (children.size() == order_bys.size()); - for (size_t i = 0; sorted_on_args && i < children.size(); ++i) { - sorted_on_args = children[i]->Equals(*order_bys[i].expression); + + // Look up all the linked list functions we need + for (auto &type : buffered_types) { + ListSegmentFunctions funcs; + GetSegmentDataFunctions(funcs, type); + buffered_funcs.emplace_back(std::move(funcs)); + sort_types.emplace_back(type); } + + // Only scan the argument columns after sorting + sort = make_uniq(context, orders, sort_types, scan_cols); } SortedAggregateBindData(ClientContext &context, BoundAggregateExpression &expr) @@ -55,15 +86,17 @@ struct SortedAggregateBindData : public FunctionData { } SortedAggregateBindData(const SortedAggregateBindData &other) - : context(other.context), function(other.function), arg_types(other.arg_types), arg_funcs(other.arg_funcs), - sort_types(other.sort_types), sort_funcs(other.sort_funcs), sorted_on_args(other.sorted_on_args), - threshold(other.threshold), external(other.external) { + : context(other.context), function(other.function), sort_types(other.sort_types), scan_cols(other.scan_cols), + scan_types(other.scan_types), buffered_cols(other.buffered_cols), buffered_types(other.buffered_types), + buffered_funcs(other.buffered_funcs), sorted_on_args(other.sorted_on_args), threshold(other.threshold) { if (other.bind_info) { bind_info = other.bind_info->Copy(); } for (auto &order : other.orders) { orders.emplace_back(order.Copy()); } + + sort = make_uniq(context, orders, sort_types, scan_cols); } unique_ptr Copy() const override { @@ -95,18 +128,30 @@ struct SortedAggregateBindData : public FunctionData { ClientContext &context; AggregateFunction function; - vector arg_types; unique_ptr bind_info; - vector arg_funcs; + //! The sort expressions (all references as the expressions have been computed) vector orders; + //! The types of the sunk columns vector sort_types; - vector sort_funcs; - bool sorted_on_args; + //! The sorted columns that have the arguments + vector scan_cols; + //! The types of the sunk columns + vector scan_types; + //! The shared sort specification + unique_ptr sort; + + //! The mapping from inputs to buffered columns + vector buffered_cols; + //! The schema of the buffered data + vector buffered_types; + //! The linked list functions for the buffered data + vector buffered_funcs; + //! Can we just use the inputs for sorting? + bool sorted_on_args = true; //! 
The sort flush threshold const idx_t threshold; - const bool external; }; struct SortedAggregateState { @@ -128,10 +173,7 @@ struct SortedAggregateState { } inline void InitializeLinkedLists(const SortedAggregateBindData &order_bind) { - InitializeLinkedList(sort_linked, order_bind.sort_types); - if (!order_bind.sorted_on_args) { - InitializeLinkedList(arg_linked, order_bind.arg_types); - } + InitializeLinkedList(input_linked, order_bind.buffered_types); } static inline void InitializeChunk(Allocator &allocator, unique_ptr &chunk, @@ -145,10 +187,7 @@ struct SortedAggregateState { void InitializeChunks(const SortedAggregateBindData &order_bind) { // Lazy instantiation of the buffer chunks auto &allocator = BufferManager::GetBufferManager(order_bind.context).GetBufferAllocator(); - InitializeChunk(allocator, sort_chunk, order_bind.sort_types); - if (!order_bind.sorted_on_args) { - InitializeChunk(allocator, arg_chunk, order_bind.arg_types); - } + InitializeChunk(allocator, input_chunk, order_bind.buffered_types); } static inline void FlushLinkedList(const LinkedChunkFunctions &funcs, LinkedLists &linked, DataChunk &chunk) { @@ -161,34 +200,19 @@ struct SortedAggregateState { void FlushLinkedLists(const SortedAggregateBindData &order_bind) { InitializeChunks(order_bind); - FlushLinkedList(order_bind.sort_funcs, sort_linked, *sort_chunk); - if (arg_chunk) { - FlushLinkedList(order_bind.arg_funcs, arg_linked, *arg_chunk); - } + FlushLinkedList(order_bind.buffered_funcs, input_linked, *input_chunk); } void InitializeCollections(const SortedAggregateBindData &order_bind) { - ordering = make_uniq(order_bind.context, order_bind.sort_types); - ordering_append = make_uniq(); - ordering->InitializeAppend(*ordering_append); - - if (!order_bind.sorted_on_args) { - arguments = make_uniq(order_bind.context, order_bind.arg_types); - arguments_append = make_uniq(); - arguments->InitializeAppend(*arguments_append); - } + input_collection = make_uniq(order_bind.context, order_bind.buffered_types); + input_append = make_uniq(); + input_collection->InitializeAppend(*input_append); } void FlushChunks(const SortedAggregateBindData &order_bind) { - D_ASSERT(sort_chunk); - ordering->Append(*ordering_append, *sort_chunk); - sort_chunk->Reset(); - - if (arguments) { - D_ASSERT(arg_chunk); - arguments->Append(*arguments_append, *arg_chunk); - arg_chunk->Reset(); - } + D_ASSERT(input_chunk); + input_collection->Append(*input_append, *input_chunk); + input_chunk->Reset(); } void Resize(const SortedAggregateBindData &order_bind, idx_t n) { @@ -199,11 +223,11 @@ struct SortedAggregateState { InitializeLinkedLists(order_bind); } - if (count > LIST_CAPACITY && !sort_chunk && !ordering) { + if (count > LIST_CAPACITY && !input_chunk && !input_collection) { FlushLinkedLists(order_bind); } - if (count > CHUNK_CAPACITY && !ordering) { + if (count > CHUNK_CAPACITY && !input_collection) { InitializeCollections(order_bind); FlushChunks(order_bind); } @@ -244,61 +268,43 @@ struct SortedAggregateState { } } - void Update(const AggregateInputData &aggr_input_data, DataChunk &sort_input, DataChunk &arg_input) { + void Update(const AggregateInputData &aggr_input_data, DataChunk &input) { const auto &order_bind = aggr_input_data.bind_data->Cast(); - Resize(order_bind, count + sort_input.size()); + Resize(order_bind, count + input.size()); sel.Initialize(nullptr); - nsel = sort_input.size(); + nsel = input.size(); - if (ordering) { + if (input_collection) { // Using collections - ordering->Append(*ordering_append, sort_input); - if 
(arguments) { - arguments->Append(*arguments_append, arg_input); - } - } else if (sort_chunk) { + input_collection->Append(*input_append, input); + } else if (input_chunk) { // Still using data chunks - sort_chunk->Append(sort_input); - if (arg_chunk) { - arg_chunk->Append(arg_input); - } + input_chunk->Append(input); } else { // Still using linked lists - LinkedAppend(order_bind.sort_funcs, aggr_input_data.allocator, sort_input, sort_linked, sel, nsel); - if (!arg_linked.empty()) { - LinkedAppend(order_bind.arg_funcs, aggr_input_data.allocator, arg_input, arg_linked, sel, nsel); - } + LinkedAppend(order_bind.buffered_funcs, aggr_input_data.allocator, input, input_linked, sel, nsel); } nsel = 0; offset = 0; } - void UpdateSlice(const AggregateInputData &aggr_input_data, DataChunk &sort_input, DataChunk &arg_input) { + void UpdateSlice(const AggregateInputData &aggr_input_data, DataChunk &input) { const auto &order_bind = aggr_input_data.bind_data->Cast(); Resize(order_bind, count + nsel); - if (ordering) { + if (input_collection) { // Using collections - D_ASSERT(sort_chunk); - sort_chunk->Slice(sort_input, sel, nsel); - if (arg_chunk) { - arg_chunk->Slice(arg_input, sel, nsel); - } + D_ASSERT(input_chunk); + input_chunk->Slice(input, sel, nsel); FlushChunks(order_bind); - } else if (sort_chunk) { + } else if (input_chunk) { // Still using data chunks - sort_chunk->Append(sort_input, true, &sel, nsel); - if (arg_chunk) { - arg_chunk->Append(arg_input, true, &sel, nsel); - } + input_chunk->Append(input, true, &sel, nsel); } else { // Still using linked lists - LinkedAppend(order_bind.sort_funcs, aggr_input_data.allocator, sort_input, sort_linked, sel, nsel); - if (!arg_linked.empty()) { - LinkedAppend(order_bind.arg_funcs, aggr_input_data.allocator, arg_input, arg_linked, sel, nsel); - } + LinkedAppend(order_bind.buffered_funcs, aggr_input_data.allocator, input, input_linked, sel, nsel); } nsel = 0; @@ -308,16 +314,12 @@ struct SortedAggregateState { void Swap(SortedAggregateState &other) { std::swap(count, other.count); - std::swap(arguments, other.arguments); - std::swap(arguments_append, other.arguments_append); - std::swap(ordering, other.ordering); - std::swap(ordering_append, other.ordering_append); + std::swap(input_collection, other.input_collection); + std::swap(input_append, other.input_append); - std::swap(sort_chunk, other.sort_chunk); - std::swap(arg_chunk, other.arg_chunk); + std::swap(input_chunk, other.input_chunk); - std::swap(sort_linked, other.sort_linked); - std::swap(arg_linked, other.arg_linked); + std::swap(input_linked, other.input_linked); } void Absorb(const SortedAggregateBindData &order_bind, SortedAggregateState &other) { @@ -333,46 +335,31 @@ struct SortedAggregateState { // 3x3 matrix. // We can simplify the logic a bit because the target is already set for the final capacity - if (!sort_chunk) { + if (!input_chunk) { // If the combined count is still linked lists, // then just move the pointers. // Note that this assumes ArenaAllocator is shared and the memory will not vanish under us. 
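// [Annotation, not from the original patch] Each side of the combine sits in one of the three
// buffering tiers (linked lists, then a DataChunk, then a ColumnDataCollection), hence the
// "3x3 matrix" mentioned above. Resize() has already promoted the target to the tier that fits
// the combined count, so the source only ever needs to be flushed upward (FlushLinkedLists)
// before it is absorbed, appended, or combined below.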
- LinkedAbsorb(other.sort_linked, sort_linked); - if (!arg_linked.empty()) { - LinkedAbsorb(other.arg_linked, arg_linked); - } + LinkedAbsorb(other.input_linked, input_linked); other.Reset(); return; } - if (!other.sort_chunk) { + if (!other.input_chunk) { other.FlushLinkedLists(order_bind); } - if (!ordering) { + if (!input_collection) { // Still using chunks, which means the source is using chunks or lists - D_ASSERT(sort_chunk); - D_ASSERT(other.sort_chunk); - sort_chunk->Append(*other.sort_chunk); - if (arg_chunk) { - D_ASSERT(other.arg_chunk); - arg_chunk->Append(*other.arg_chunk); - } + D_ASSERT(input_chunk); + D_ASSERT(other.input_chunk); + input_chunk->Append(*other.input_chunk); } else { // Using collections, so source could be using anything. - if (other.ordering) { - ordering->Combine(*other.ordering); - if (arguments) { - D_ASSERT(other.arguments); - arguments->Combine(*other.arguments); - } + if (other.input_collection) { + input_collection->Combine(*other.input_collection); } else { - ordering->Append(*other.sort_chunk); - if (arguments) { - D_ASSERT(other.arg_chunk); - arguments->Append(*other.arg_chunk); - } + input_collection->Append(*other.input_chunk); } } @@ -381,43 +368,30 @@ struct SortedAggregateState { } void PrefixSortBuffer(DataChunk &prefixed) { - for (column_t col_idx = 0; col_idx < sort_chunk->ColumnCount(); ++col_idx) { - prefixed.data[col_idx + 1].Reference(sort_chunk->data[col_idx]); + for (column_t col_idx = 0; col_idx < input_chunk->ColumnCount(); ++col_idx) { + prefixed.data[col_idx + 1].Reference(input_chunk->data[col_idx]); } - prefixed.SetCardinality(*sort_chunk); + prefixed.SetCardinality(*input_chunk); } - void Finalize(const SortedAggregateBindData &order_bind, DataChunk &prefixed, LocalSortState &local_sort) { - if (arguments) { + void Finalize(const SortedAggregateBindData &order_bind, DataChunk &prefixed, ExecutionContext &context, + OperatorSinkInput &sink) { + auto &sort = *order_bind.sort; + if (input_collection) { ColumnDataScanState sort_state; - ordering->InitializeScan(sort_state); - ColumnDataScanState arg_state; - arguments->InitializeScan(arg_state); - for (sort_chunk->Reset(); ordering->Scan(sort_state, *sort_chunk); sort_chunk->Reset()) { + input_collection->InitializeScan(sort_state); + for (input_chunk->Reset(); input_collection->Scan(sort_state, *input_chunk); input_chunk->Reset()) { PrefixSortBuffer(prefixed); - arg_chunk->Reset(); - arguments->Scan(arg_state, *arg_chunk); - local_sort.SinkChunk(prefixed, *arg_chunk); - } - } else if (ordering) { - ColumnDataScanState sort_state; - ordering->InitializeScan(sort_state); - for (sort_chunk->Reset(); ordering->Scan(sort_state, *sort_chunk); sort_chunk->Reset()) { - PrefixSortBuffer(prefixed); - local_sort.SinkChunk(prefixed, *sort_chunk); + sort.Sink(context, prefixed, sink); } } else { // Force chunks so we can sort - if (!sort_chunk) { + if (!input_chunk) { FlushLinkedLists(order_bind); } PrefixSortBuffer(prefixed); - if (arg_chunk) { - local_sort.SinkChunk(prefixed, *arg_chunk); - } else { - local_sort.SinkChunk(prefixed, *sort_chunk); - } + sort.Sink(context, prefixed, sink); } Reset(); @@ -425,30 +399,19 @@ struct SortedAggregateState { void Reset() { // Release all memory - ordering.reset(); - arguments.reset(); - - sort_chunk.reset(); - arg_chunk.reset(); - - sort_linked.clear(); - arg_linked.clear(); + input_collection.reset(); + input_chunk.reset(); + input_linked.clear(); count = 0; } idx_t count; - unique_ptr arguments; - unique_ptr arguments_append; - unique_ptr ordering; 
- unique_ptr ordering_append; - - unique_ptr sort_chunk; - unique_ptr arg_chunk; - - LinkedLists sort_linked; - LinkedLists arg_linked; + unique_ptr input_collection; + unique_ptr input_append; + unique_ptr input_chunk; + LinkedLists input_linked; // Selection for scattering SelectionVector sel; @@ -468,33 +431,26 @@ struct SortedAggregateFunction { } static void ProjectInputs(Vector inputs[], const SortedAggregateBindData &order_bind, idx_t input_count, - idx_t count, DataChunk &arg_input, DataChunk &sort_input) { - idx_t col = 0; + idx_t count, DataChunk &buffered) { - if (!order_bind.sorted_on_args) { - arg_input.InitializeEmpty(order_bind.arg_types); - for (auto &dst : arg_input.data) { - dst.Reference(inputs[col++]); - } - arg_input.SetCardinality(count); + // Only reference the buffered columns + buffered.InitializeEmpty(order_bind.buffered_types); + const auto &buffered_cols = order_bind.buffered_cols; + for (idx_t b = 0; b < buffered_cols.size(); ++b) { + D_ASSERT(buffered_cols[b] < input_count); + buffered.data[b].Reference(inputs[buffered_cols[b]]); } - - sort_input.InitializeEmpty(order_bind.sort_types); - for (auto &dst : sort_input.data) { - dst.Reference(inputs[col++]); - } - sort_input.SetCardinality(count); + buffered.SetCardinality(count); } static void SimpleUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state, idx_t count) { const auto order_bind = aggr_input_data.bind_data->Cast(); DataChunk arg_input; - DataChunk sort_input; - ProjectInputs(inputs, order_bind, input_count, count, arg_input, sort_input); + ProjectInputs(inputs, order_bind, input_count, count, arg_input); const auto order_state = reinterpret_cast(state); - order_state->Update(aggr_input_data, sort_input, arg_input); + order_state->Update(aggr_input_data, arg_input); } static void ScatterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states, @@ -506,8 +462,7 @@ struct SortedAggregateFunction { // Append the arguments to the two sub-collections const auto &order_bind = aggr_input_data.bind_data->Cast(); DataChunk arg_inputs; - DataChunk sort_inputs; - ProjectInputs(inputs, order_bind, input_count, count, arg_inputs, sort_inputs); + ProjectInputs(inputs, order_bind, input_count, count, arg_inputs); // We have to scatter the chunks one at a time // so build a selection vector for each one. 
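Aside (annotation, not part of the patch): the sink-side state above keeps the existing three-tier buffering, now over a single merged set of input_* members: rows sit in linked lists until LIST_CAPACITY, in one contiguous chunk until CHUNK_CAPACITY, and finally in an append-only collection. A minimal self-contained sketch of that promotion scheme, with toy thresholds and STL containers standing in for DuckDB's structures:

#include <cstddef>
#include <cstdio>
#include <forward_list>
#include <utility>
#include <vector>

class TieredBuffer {
public:
	void Append(int value) {
		count++;
		if (count > LIST_CAPACITY && chunk.empty() && collection.empty()) {
			// promote tier 1 -> 2: flush the linked rows into a contiguous chunk
			// (this toy does not preserve row order; the real linked lists do)
			for (int v : linked) {
				chunk.push_back(v);
			}
			linked.clear();
		}
		if (count > CHUNK_CAPACITY && collection.empty()) {
			// promote tier 2 -> 3: hand the chunk over to the collection
			collection.push_back(std::move(chunk));
			chunk.clear();
		}
		if (!collection.empty()) {
			collection.back().push_back(value);
		} else if (count > LIST_CAPACITY) {
			chunk.push_back(value);
		} else {
			linked.push_front(value); // cheapest tier: no bulk allocation yet
		}
	}

	size_t Count() const {
		return count;
	}

private:
	static constexpr size_t LIST_CAPACITY = 4;   // hypothetical threshold
	static constexpr size_t CHUNK_CAPACITY = 16; // hypothetical threshold

	size_t count = 0;
	std::forward_list<int> linked;            // tier 1
	std::vector<int> chunk;                   // tier 2
	std::vector<std::vector<int>> collection; // tier 3
};

int main() {
	TieredBuffer buf;
	for (int i = 0; i < 40; i++) {
		buf.Append(i);
	}
	std::printf("buffered %zu values\n", buf.Count());
	return 0;
}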
@@ -545,7 +500,7 @@ struct SortedAggregateFunction { continue; } - order_state->UpdateSlice(aggr_input_data, sort_inputs, arg_inputs); + order_state->UpdateSlice(aggr_input_data, arg_inputs); } } @@ -565,15 +520,13 @@ struct SortedAggregateFunction { static void Finalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, const idx_t offset) { auto &order_bind = aggr_input_data.bind_data->Cast(); - auto &context = order_bind.context; - RowLayout payload_layout; - payload_layout.Initialize(order_bind.arg_types); + auto &client = order_bind.context; - auto &buffer_allocator = BufferManager::GetBufferManager(order_bind.context).GetBufferAllocator(); - DataChunk chunk; - chunk.Initialize(buffer_allocator, order_bind.arg_types); + auto &buffer_allocator = BufferManager::GetBufferManager(client).GetBufferAllocator(); + DataChunk scanned; + scanned.Initialize(buffer_allocator, order_bind.scan_types); DataChunk sliced; - sliced.Initialize(buffer_allocator, order_bind.arg_types); + sliced.Initialize(buffer_allocator, order_bind.scan_types); // Reusable inner state auto &aggr = order_bind.function; @@ -598,21 +551,15 @@ struct SortedAggregateFunction { state_unprocessed[i] = sdata[i]->count; } - // Sort the input payloads on (state_idx ASC, orders) - vector orders; - orders.emplace_back(BoundOrderByNode(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, - make_uniq(Value::USMALLINT(0)))); - for (const auto &order : order_bind.orders) { - orders.emplace_back(order.Copy()); - } - - auto global_sort = make_uniq(context, orders, payload_layout); - global_sort->external = order_bind.external; - auto local_sort = make_uniq(); - local_sort->Initialize(*global_sort, global_sort->buffer_manager); + ThreadContext thread(client); + ExecutionContext context(client, thread, nullptr); + InterruptState interrupt; + auto &sort = order_bind.sort; + auto global_sink = sort->GetGlobalSinkState(client); + auto local_sink = sort->GetLocalSinkState(context); DataChunk prefixed; - prefixed.Initialize(buffer_allocator, global_sort->sort_layout.logical_types); + prefixed.Initialize(buffer_allocator, order_bind.sort_types); // Go through the states accumulating values to sort until we hit the sort threshold idx_t unsorted_count = 0; @@ -622,7 +569,8 @@ struct SortedAggregateFunction { auto state = sdata[finalized]; prefixed.Reset(); prefixed.data[0].Reference(Value::USMALLINT(UnsafeNumericCast(finalized))); - state->Finalize(order_bind, prefixed, *local_sort); + OperatorSinkInput sink {*global_sink, *local_sink, interrupt}; + state->Finalize(order_bind, prefixed, context, sink); unsorted_count += state_unprocessed[finalized]; // Go to the next aggregate unless this is the last one @@ -638,24 +586,26 @@ struct SortedAggregateFunction { } // Sort all the data - global_sort->AddLocalState(*local_sort); - global_sort->PrepareMergePhase(); - while (global_sort->sorted_blocks.size() > 1) { - global_sort->InitializeMergeRound(); - MergeSorter merge_sorter(*global_sort, global_sort->buffer_manager); - merge_sorter.PerformInMergeRound(); - global_sort->CompleteMergeRound(false); - } + OperatorSinkCombineInput combine {*global_sink, *local_sink, interrupt}; + order_bind.sort->Combine(context, combine); + + OperatorSinkFinalizeInput finalize_input {*global_sink, interrupt}; + order_bind.sort->Finalize(client, finalize_input); + + auto global_source = sort->GetGlobalSourceState(client, *global_sink); + auto local_source = sort->GetLocalSourceState(context, *global_source); - auto scanner = 
make_uniq(*global_sort); initialize(aggr, agg_state.data()); - while (scanner->Remaining()) { - chunk.Reset(); - scanner->Scan(chunk); + for (;;) { + OperatorSourceInput source {*global_source, *local_source, interrupt}; + scanned.Reset(); + if (sort->GetData(context, scanned, source) == SourceResultType::FINISHED) { + break; + } idx_t consumed = 0; // Distribute the scanned chunk to the aggregates - while (consumed < chunk.size()) { + while (consumed < scanned.size()) { // Find the next aggregate that needs data for (; !state_unprocessed[sorted]; ++sorted) { // Finalize a single value at the next offset @@ -667,9 +617,9 @@ struct SortedAggregateFunction { initialize(aggr, agg_state.data()); } - const auto input_count = MinValue(state_unprocessed[sorted], chunk.size() - consumed); - for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { - sliced.data[col_idx].Slice(chunk.data[col_idx], consumed, consumed + input_count); + const auto input_count = MinValue(state_unprocessed[sorted], scanned.size() - consumed); + for (column_t col_idx = 0; col_idx < scanned.ColumnCount(); ++col_idx) { + sliced.data[col_idx].Slice(scanned.data[col_idx], consumed, consumed + input_count); } sliced.SetCardinality(input_count); @@ -702,11 +652,8 @@ struct SortedAggregateFunction { } // Create a new sort - scanner.reset(); - global_sort = make_uniq(context, orders, payload_layout); - global_sort->external = order_bind.external; - local_sort = make_uniq(); - local_sort->Initialize(*global_sort, global_sort->buffer_manager); + global_sink = sort->GetGlobalSinkState(client); + local_sink = sort->GetLocalSinkState(context); unsorted_count = 0; } diff --git a/src/duckdb/src/function/cast/varint_casts.cpp b/src/duckdb/src/function/cast/bignum_casts.cpp similarity index 65% rename from src/duckdb/src/function/cast/varint_casts.cpp rename to src/duckdb/src/function/cast/bignum_casts.cpp index 1705bc3bf..428fa4d49 100644 --- a/src/duckdb/src/function/cast/varint_casts.cpp +++ b/src/duckdb/src/function/cast/bignum_casts.cpp @@ -1,13 +1,14 @@ #include "duckdb/function/cast/default_casts.hpp" #include "duckdb/common/operator/cast_operators.hpp" #include "duckdb/function/cast/vector_cast_helpers.hpp" -#include "duckdb/common/types/varint.hpp" +#include "duckdb/common/types/bignum.hpp" #include +#include "duckdb/common/bignum.hpp" namespace duckdb { template -static string_t IntToVarInt(Vector &result, T int_value) { +static bignum_t IntToBignum(Vector &result, T int_value) { // Determine if the number is negative bool is_negative = int_value < 0; // Determine the number of data bytes @@ -28,13 +29,13 @@ static string_t IntToVarInt(Vector &result, T int_value) { data_byte_size = static_cast(std::ceil(std::log2(abs_value) / 8.0)); } - uint32_t blob_size = data_byte_size + Varint::VARINT_HEADER_SIZE; + uint32_t blob_size = data_byte_size + Bignum::BIGNUM_HEADER_SIZE; auto blob = StringVector::EmptyString(result, blob_size); auto writable_blob = blob.GetDataWriteable(); - Varint::SetHeader(writable_blob, data_byte_size, is_negative); + Bignum::SetHeader(writable_blob, data_byte_size, is_negative); // Add data bytes to the blob, starting off after header bytes - idx_t wb_idx = Varint::VARINT_HEADER_SIZE; + idx_t wb_idx = Bignum::BIGNUM_HEADER_SIZE; for (int i = static_cast(data_byte_size) - 1; i >= 0; --i) { if (is_negative) { writable_blob[wb_idx++] = static_cast(~(abs_value >> i * 8 & 0xFF)); @@ -43,11 +44,12 @@ static string_t IntToVarInt(Vector &result, T int_value) { } } blob.Finalize(); - return blob; + 
bignum_t result_bignum(blob); + return result_bignum; } template <> -string_t HugeintCastToVarInt::Operation(uhugeint_t int_value, Vector &result) { +bignum_t HugeintCastToBignum::Operation(uhugeint_t int_value, Vector &result) { uint32_t data_byte_size; if (int_value.upper != NumericLimits::Maximum()) { data_byte_size = @@ -70,13 +72,13 @@ string_t HugeintCastToVarInt::Operation(uhugeint_t int_value, Vector &result) { if (data_byte_size == 0) { data_byte_size++; } - uint32_t blob_size = data_byte_size + Varint::VARINT_HEADER_SIZE; + uint32_t blob_size = data_byte_size + Bignum::BIGNUM_HEADER_SIZE; auto blob = StringVector::EmptyString(result, blob_size); auto writable_blob = blob.GetDataWriteable(); - Varint::SetHeader(writable_blob, data_byte_size, false); + Bignum::SetHeader(writable_blob, data_byte_size, false); // Add data bytes to the blob, starting off after header bytes - idx_t wb_idx = Varint::VARINT_HEADER_SIZE; + idx_t wb_idx = Bignum::BIGNUM_HEADER_SIZE; for (int i = static_cast(upper_byte_size) - 1; i >= 0; --i) { writable_blob[wb_idx++] = static_cast(int_value.upper >> i * 8 & 0xFF); } @@ -84,11 +86,12 @@ string_t HugeintCastToVarInt::Operation(uhugeint_t int_value, Vector &result) { writable_blob[wb_idx++] = static_cast(int_value.lower >> i * 8 & 0xFF); } blob.Finalize(); - return blob; + bignum_t result_bignum(blob); + return result_bignum; } template <> -string_t HugeintCastToVarInt::Operation(hugeint_t int_value, Vector &result) { +bignum_t HugeintCastToBignum::Operation(hugeint_t int_value, Vector &result) { // Determine if the number is negative bool is_negative = int_value.upper >> 63 & 1; if (is_negative) { @@ -98,12 +101,12 @@ string_t HugeintCastToVarInt::Operation(hugeint_t int_value, Vector &result) { uhugeint_t u_int_value {0x8000000000000000, 0}; auto cast_value = Operation(u_int_value, result); // We have to do all the bit flipping. 
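// [Annotation, not from the original patch] This branch handles the most negative hugeint,
// whose magnitude cannot be obtained by negating in-place: it reuses the unsigned encoding of
// 2^127 ({0x8000000000000000, 0}) and then ones'-complements the payload bytes while rewriting
// the header so the blob reads as negative.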
- auto writable_value_ptr = cast_value.GetDataWriteable(); - Varint::SetHeader(writable_value_ptr, cast_value.GetSize() - Varint::VARINT_HEADER_SIZE, is_negative); - for (idx_t i = Varint::VARINT_HEADER_SIZE; i < cast_value.GetSize(); i++) { + auto writable_value_ptr = cast_value.data.GetDataWriteable(); + Bignum::SetHeader(writable_value_ptr, cast_value.data.GetSize() - Bignum::BIGNUM_HEADER_SIZE, is_negative); + for (idx_t i = Bignum::BIGNUM_HEADER_SIZE; i < cast_value.data.GetSize(); i++) { writable_value_ptr[i] = static_cast(~writable_value_ptr[i]); } - cast_value.Finalize(); + cast_value.data.Finalize(); return cast_value; } int_value = -int_value; @@ -134,13 +137,13 @@ string_t HugeintCastToVarInt::Operation(hugeint_t int_value, Vector &result) { if (data_byte_size == 0) { data_byte_size++; } - uint32_t blob_size = data_byte_size + Varint::VARINT_HEADER_SIZE; + uint32_t blob_size = data_byte_size + Bignum::BIGNUM_HEADER_SIZE; auto blob = StringVector::EmptyString(result, blob_size); auto writable_blob = blob.GetDataWriteable(); - Varint::SetHeader(writable_blob, data_byte_size, is_negative); + Bignum::SetHeader(writable_blob, data_byte_size, is_negative); // Add data bytes to the blob, starting off after header bytes - idx_t wb_idx = Varint::VARINT_HEADER_SIZE; + idx_t wb_idx = Bignum::BIGNUM_HEADER_SIZE; for (int i = static_cast(upper_byte_size) - 1; i >= 0; --i) { if (is_negative) { writable_blob[wb_idx++] = static_cast(~(abs_value_upper >> i * 8 & 0xFF)); @@ -156,30 +159,30 @@ string_t HugeintCastToVarInt::Operation(hugeint_t int_value, Vector &result) { } } blob.Finalize(); - return blob; + bignum_t result_bignum(blob); + return result_bignum; } -// Varchar to Varint -// TODO: This is a slow quadratic algorithm, we can still optimize it further. 
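Aside (illustrative only): a back-of-the-envelope check of the encoding used throughout this file, skipping the zero and boundary cases the real casts handle separately. A BIGNUM blob is a small header followed by big-endian magnitude bytes, and each magnitude byte is stored ones'-complemented when the value is negative:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
	int64_t value = -300;
	uint64_t magnitude = static_cast<uint64_t>(-value);
	// 300 = 0x012C, so ceil(log2(300) / 8) = ceil(8.23 / 8) = 2 payload bytes
	auto payload_bytes = static_cast<uint32_t>(std::ceil(std::log2(magnitude) / 8.0));
	std::printf("payload bytes: %u\n", payload_bytes);
	// big-endian payload, ones'-complemented because the value is negative
	for (int i = static_cast<int>(payload_bytes) - 1; i >= 0; --i) {
		auto byte = static_cast<uint8_t>(~((magnitude >> (i * 8)) & 0xFF));
		std::printf("%02x ", byte); // prints "fe d3": ~0x01 = 0xFE, ~0x2C = 0xD3
	}
	std::printf("\n");
	return 0;
}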
+// Varchar to Bignum template <> -bool TryCastToVarInt::Operation(string_t input_value, string_t &result_value, Vector &result, +bool TryCastToBignum::Operation(string_t input_value, bignum_t &result_value, Vector &result, CastParameters ¶meters) { - auto blob_string = Varint::VarcharToVarInt(input_value); + auto blob_string = Bignum::VarcharToBignum(input_value); uint32_t blob_size = static_cast(blob_string.size()); - result_value = StringVector::EmptyString(result, blob_size); - auto writable_blob = result_value.GetDataWriteable(); + result_value = bignum_t(StringVector::EmptyString(result, blob_size)); + auto writable_blob = result_value.data.GetDataWriteable(); // Write string_blob into blob for (idx_t i = 0; i < blob_string.size(); i++) { writable_blob[i] = blob_string[i]; } - result_value.Finalize(); + result_value.data.Finalize(); return true; } template -static bool DoubleToVarInt(T double_value, string_t &result_value, Vector &result) { +static bool DoubleToBignum(T double_value, bignum_t &result_value, Vector &result) { // Check if we can cast it if (!std::isfinite(double_value)) { // We can't cast inf -inf nan @@ -192,7 +195,7 @@ static bool DoubleToVarInt(T double_value, string_t &result_value, Vector &resul if (abs_value == 0) { // Return Value 0 - result_value = Varint::InitializeVarintZero(result); + result_value = Bignum::InitializeBignumZero(result); return true; } vector value; @@ -208,73 +211,102 @@ static bool DoubleToVarInt(T double_value, string_t &result_value, Vector &resul } } uint32_t data_byte_size = static_cast(value.size()); - uint32_t blob_size = data_byte_size + Varint::VARINT_HEADER_SIZE; - result_value = StringVector::EmptyString(result, blob_size); - auto writable_blob = result_value.GetDataWriteable(); - Varint::SetHeader(writable_blob, data_byte_size, is_negative); + uint32_t blob_size = data_byte_size + Bignum::BIGNUM_HEADER_SIZE; + result_value.data = StringVector::EmptyString(result, blob_size); + auto writable_blob = result_value.data.GetDataWriteable(); + Bignum::SetHeader(writable_blob, data_byte_size, is_negative); // Add data bytes to the blob, starting off after header bytes idx_t blob_string_idx = value.size() - 1; - for (idx_t i = Varint::VARINT_HEADER_SIZE; i < blob_size; i++) { + for (idx_t i = Bignum::BIGNUM_HEADER_SIZE; i < blob_size; i++) { writable_blob[i] = value[blob_string_idx--]; } - result_value.Finalize(); + result_value.data.Finalize(); return true; } template <> -bool TryCastToVarInt::Operation(double double_value, string_t &result_value, Vector &result, +bool TryCastToBignum::Operation(double double_value, bignum_t &result_value, Vector &result, CastParameters ¶meters) { - return DoubleToVarInt(double_value, result_value, result); + return DoubleToBignum(double_value, result_value, result); } template <> -bool TryCastToVarInt::Operation(float double_value, string_t &result_value, Vector &result, +bool TryCastToBignum::Operation(float double_value, bignum_t &result_value, Vector &result, CastParameters ¶meters) { - return DoubleToVarInt(double_value, result_value, result); + return DoubleToBignum(double_value, result_value, result); } -BoundCastInfo Varint::NumericToVarintCastSwitch(const LogicalType &source) { +BoundCastInfo Bignum::NumericToBignumCastSwitch(const LogicalType &source) { // now switch on the result type switch (source.id()) { case LogicalTypeId::TINYINT: - return BoundCastInfo(&VectorCastHelpers::StringCast); + return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::UTINYINT: - return 
BoundCastInfo(&VectorCastHelpers::StringCast); + return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::SMALLINT: - return BoundCastInfo(&VectorCastHelpers::StringCast); + return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::USMALLINT: - return BoundCastInfo(&VectorCastHelpers::StringCast); + return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::INTEGER: - return BoundCastInfo(&VectorCastHelpers::StringCast); + return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::UINTEGER: - return BoundCastInfo(&VectorCastHelpers::StringCast); + return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::BIGINT: - return BoundCastInfo(&VectorCastHelpers::StringCast); + return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::UBIGINT: - return BoundCastInfo(&VectorCastHelpers::StringCast); + return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::UHUGEINT: - return BoundCastInfo(&VectorCastHelpers::StringCast); + return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::FLOAT: - return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); + return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); case LogicalTypeId::DOUBLE: - return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); + return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); case LogicalTypeId::HUGEINT: - return BoundCastInfo(&VectorCastHelpers::StringCast); + return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::DECIMAL: default: return DefaultCasts::TryVectorNullCast; } } -BoundCastInfo DefaultCasts::VarintCastSwitch(BindCastInput &input, const LogicalType &source, +BoundCastInfo DefaultCasts::BignumCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - D_ASSERT(source.id() == LogicalTypeId::VARINT); + D_ASSERT(source.id() == LogicalTypeId::BIGNUM); // now switch on the result type switch (target.id()) { + case LogicalTypeId::TINYINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::UTINYINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + + case LogicalTypeId::SMALLINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + + case LogicalTypeId::USMALLINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + + case LogicalTypeId::INTEGER: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + + case LogicalTypeId::UINTEGER: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + + case LogicalTypeId::BIGINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + + case LogicalTypeId::UBIGINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + + case LogicalTypeId::HUGEINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + + case LogicalTypeId::UHUGEINT: + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + case LogicalTypeId::VARCHAR: - return BoundCastInfo(&VectorCastHelpers::StringCast); + return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::DOUBLE: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); + return BoundCastInfo(&VectorCastHelpers::TryCastLoop); default: return TryVectorNullCast; } diff --git a/src/duckdb/src/function/cast/cast_function_set.cpp b/src/duckdb/src/function/cast/cast_function_set.cpp index 7735e5cd3..3c05591ed 100644 --- a/src/duckdb/src/function/cast/cast_function_set.cpp +++ b/src/duckdb/src/function/cast/cast_function_set.cpp @@ -1,5 +1,7 @@ #include 
"duckdb/function/cast/cast_function_set.hpp" +#include "duckdb/main/settings.hpp" + #include "duckdb/common/pair.hpp" #include "duckdb/common/types/type_map.hpp" #include "duckdb/function/cast_rules.hpp" @@ -161,7 +163,8 @@ struct MapCastInfo : public BindCastInfo { type_id_map_t>>> casts; }; -int64_t CastFunctionSet::ImplicitCastCost(const LogicalType &source, const LogicalType &target) { +int64_t CastFunctionSet::ImplicitCastCost(optional_ptr context, const LogicalType &source, + const LogicalType &target) { // check if a cast has been registered if (map_info) { auto entry = map_info->GetEntry(source, target); @@ -171,14 +174,29 @@ int64_t CastFunctionSet::ImplicitCastCost(const LogicalType &source, const Logic } // if not, fallback to the default implicit cast rules auto score = CastRules::ImplicitCast(source, target); - if (score < 0 && config && config->options.old_implicit_casting) { - if (source.id() != LogicalTypeId::BLOB && target.id() == LogicalTypeId::VARCHAR) { + if (score < 0 && source.id() != LogicalTypeId::BLOB && target.id() == LogicalTypeId::VARCHAR) { + bool old_implicit_casting = false; + if (context) { + old_implicit_casting = DBConfig::GetSetting(*context); + } else if (config) { + old_implicit_casting = DBConfig::GetSetting(*config); + } + if (old_implicit_casting) { score = 149; } } return score; } +int64_t CastFunctionSet::ImplicitCastCost(ClientContext &context, const LogicalType &source, + const LogicalType &target) { + return CastFunctionSet::Get(context).ImplicitCastCost(&context, source, target); +} + +int64_t CastFunctionSet::ImplicitCastCost(DatabaseInstance &db, const LogicalType &source, const LogicalType &target) { + return CastFunctionSet::Get(db).ImplicitCastCost(nullptr, source, target); +} + static BoundCastInfo MapCastFunction(BindCastInput &input, const LogicalType &source, const LogicalType &target) { D_ASSERT(input.info); auto &map_info = input.info->Cast(); diff --git a/src/duckdb/src/function/cast/default_casts.cpp b/src/duckdb/src/function/cast/default_casts.cpp index 14a66a949..819dd3523 100644 --- a/src/duckdb/src/function/cast/default_casts.cpp +++ b/src/duckdb/src/function/cast/default_casts.cpp @@ -156,8 +156,8 @@ BoundCastInfo DefaultCasts::GetDefaultCastFunction(BindCastInput &input, const L return EnumCastSwitch(input, source, target); case LogicalTypeId::ARRAY: return ArrayCastSwitch(input, source, target); - case LogicalTypeId::VARINT: - return VarintCastSwitch(input, source, target); + case LogicalTypeId::BIGNUM: + return BignumCastSwitch(input, source, target); case LogicalTypeId::AGGREGATE_STATE: return AggregateStateToBlobCast; default: diff --git a/src/duckdb/src/function/cast/numeric_casts.cpp b/src/duckdb/src/function/cast/numeric_casts.cpp index bdb999ffc..eb13fc220 100644 --- a/src/duckdb/src/function/cast/numeric_casts.cpp +++ b/src/duckdb/src/function/cast/numeric_casts.cpp @@ -2,7 +2,7 @@ #include "duckdb/function/cast/vector_cast_helpers.hpp" #include "duckdb/common/operator/string_cast.hpp" #include "duckdb/common/operator/numeric_cast.hpp" -#include "duckdb/common/types/varint.hpp" +#include "duckdb/common/types/bignum.hpp" namespace duckdb { @@ -42,8 +42,8 @@ static BoundCastInfo InternalNumericCastSwitch(const LogicalType &source, const return BoundCastInfo(&VectorCastHelpers::StringCast); case LogicalTypeId::BIT: return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::VARINT: - return Varint::NumericToVarintCastSwitch(source); + case LogicalTypeId::BIGNUM: + return 
Bignum::NumericToBignumCastSwitch(source); default: return DefaultCasts::TryVectorNullCast; } diff --git a/src/duckdb/src/function/cast/string_cast.cpp b/src/duckdb/src/function/cast/string_cast.cpp index 02ae3830b..6ebfbf857 100644 --- a/src/duckdb/src/function/cast/string_cast.cpp +++ b/src/duckdb/src/function/cast/string_cast.cpp @@ -5,7 +5,7 @@ #include "duckdb/common/pair.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/function/cast/bound_cast_data.hpp" -#include "duckdb/common/types/varint.hpp" +#include "duckdb/common/types/bignum.hpp" namespace duckdb { @@ -513,8 +513,8 @@ BoundCastInfo DefaultCasts::StringCastSwitch(BindCastInput &input, const Logical MapBoundCastData::BindMapToMapCast( input, LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR), target), InitMapCastLocalState); - case LogicalTypeId::VARINT: - return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); + case LogicalTypeId::BIGNUM: + return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); default: return VectorStringCastNumericSwitch(input, source, target); } diff --git a/src/duckdb/src/function/cast/union_casts.cpp b/src/duckdb/src/function/cast/union_casts.cpp index 7f1ee738b..65f018d2b 100644 --- a/src/duckdb/src/function/cast/union_casts.cpp +++ b/src/duckdb/src/function/cast/union_casts.cpp @@ -22,7 +22,7 @@ static unique_ptr BindToUnionCast(BindCastInput &input, const Log for (idx_t member_idx = 0; member_idx < UnionType::GetMemberCount(target); member_idx++) { auto member_type = UnionType::GetMemberType(target, member_idx); auto member_name = UnionType::GetMemberName(target, member_idx); - auto member_cast_cost = input.function_set.ImplicitCastCost(source, member_type); + auto member_cast_cost = input.function_set.ImplicitCastCost(nullptr, source, member_type); if (member_cast_cost != -1) { auto member_cast_info = input.GetCastFunction(source, member_type); candidates.emplace_back(member_idx, member_name, member_type, member_cast_cost, diff --git a/src/duckdb/src/function/cast_rules.cpp b/src/duckdb/src/function/cast_rules.cpp index 951ecc935..5ab775eea 100644 --- a/src/duckdb/src/function/cast_rules.cpp +++ b/src/duckdb/src/function/cast_rules.cpp @@ -18,6 +18,8 @@ static int64_t TargetTypeCost(const LogicalType &type) { return 104; case LogicalTypeId::DECIMAL: return 105; + case LogicalTypeId::BIGNUM: + return 106; case LogicalTypeId::TIMESTAMP_NS: return 119; case LogicalTypeId::TIMESTAMP: @@ -38,6 +40,9 @@ static int64_t TargetTypeCost(const LogicalType &type) { return 160; case LogicalTypeId::ANY: return int64_t(AnyType::GetCastScore(type)); + case LogicalTypeId::TEMPLATE: + // we can cast anything to a template type, but prefer to cast to anything else! 
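// For scale (reading off the other costs in this switch): concrete targets
// land roughly in the 100-160 range, e.g. DECIMAL 105, BIGNUM 106,
// TIMESTAMP_NS 119, default 110. The 1000000 returned below therefore makes
// overload resolution bind a template parameter only when no concrete-typed
// candidate is viable.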
+ return 1000000; default: return 110; } @@ -52,6 +57,7 @@ static int64_t ImplicitCastTinyint(const LogicalType &to) { case LogicalTypeId::FLOAT: case LogicalTypeId::DOUBLE: case LogicalTypeId::DECIMAL: + case LogicalTypeId::BIGNUM: return TargetTypeCost(to); default: return -1; @@ -66,6 +72,7 @@ static int64_t ImplicitCastSmallint(const LogicalType &to) { case LogicalTypeId::FLOAT: case LogicalTypeId::DOUBLE: case LogicalTypeId::DECIMAL: + case LogicalTypeId::BIGNUM: return TargetTypeCost(to); default: return -1; @@ -79,6 +86,7 @@ static int64_t ImplicitCastInteger(const LogicalType &to) { case LogicalTypeId::FLOAT: case LogicalTypeId::DOUBLE: case LogicalTypeId::DECIMAL: + case LogicalTypeId::BIGNUM: return TargetTypeCost(to); default: return -1; @@ -91,6 +99,7 @@ static int64_t ImplicitCastBigint(const LogicalType &to) { case LogicalTypeId::DOUBLE: case LogicalTypeId::HUGEINT: case LogicalTypeId::DECIMAL: + case LogicalTypeId::BIGNUM: return TargetTypeCost(to); default: return -1; @@ -110,6 +119,7 @@ static int64_t ImplicitCastUTinyint(const LogicalType &to) { case LogicalTypeId::FLOAT: case LogicalTypeId::DOUBLE: case LogicalTypeId::DECIMAL: + case LogicalTypeId::BIGNUM: return TargetTypeCost(to); default: return -1; @@ -127,6 +137,7 @@ static int64_t ImplicitCastUSmallint(const LogicalType &to) { case LogicalTypeId::FLOAT: case LogicalTypeId::DOUBLE: case LogicalTypeId::DECIMAL: + case LogicalTypeId::BIGNUM: return TargetTypeCost(to); default: return -1; @@ -143,6 +154,7 @@ static int64_t ImplicitCastUInteger(const LogicalType &to) { case LogicalTypeId::FLOAT: case LogicalTypeId::DOUBLE: case LogicalTypeId::DECIMAL: + case LogicalTypeId::BIGNUM: return TargetTypeCost(to); default: return -1; @@ -156,6 +168,7 @@ static int64_t ImplicitCastUBigint(const LogicalType &to) { case LogicalTypeId::UHUGEINT: case LogicalTypeId::HUGEINT: case LogicalTypeId::DECIMAL: + case LogicalTypeId::BIGNUM: return TargetTypeCost(to); default: return -1; @@ -164,6 +177,7 @@ static int64_t ImplicitCastUBigint(const LogicalType &to) { static int64_t ImplicitCastFloat(const LogicalType &to) { switch (to.id()) { + case LogicalTypeId::BIGNUM: case LogicalTypeId::DOUBLE: return TargetTypeCost(to); default: @@ -173,6 +187,9 @@ static int64_t ImplicitCastFloat(const LogicalType &to) { static int64_t ImplicitCastDouble(const LogicalType &to) { switch (to.id()) { + + case LogicalTypeId::BIGNUM: + return TargetTypeCost(to); default: return -1; } @@ -193,6 +210,7 @@ static int64_t ImplicitCastHugeint(const LogicalType &to) { case LogicalTypeId::FLOAT: case LogicalTypeId::DOUBLE: case LogicalTypeId::DECIMAL: + case LogicalTypeId::BIGNUM: return TargetTypeCost(to); default: return -1; @@ -204,6 +222,7 @@ static int64_t ImplicitCastUhugeint(const LogicalType &to) { case LogicalTypeId::FLOAT: case LogicalTypeId::DOUBLE: case LogicalTypeId::DECIMAL: + case LogicalTypeId::BIGNUM: return TargetTypeCost(to); default: return -1; @@ -274,7 +293,7 @@ static int64_t ImplicitCastTimestamp(const LogicalType &to) { } } -static int64_t ImplicitCastVarint(const LogicalType &to) { +static int64_t ImplicitCastBignum(const LogicalType &to) { switch (to.id()) { case LogicalTypeId::DOUBLE: return TargetTypeCost(to); @@ -333,7 +352,11 @@ bool LogicalTypeIsValid(const LogicalType &type) { } int64_t CastRules::ImplicitCast(const LogicalType &from, const LogicalType &to) { - if (from.id() == LogicalTypeId::SQLNULL || to.id() == LogicalTypeId::ANY) { + if (from.id() == LogicalTypeId::SQLNULL && to.id() == LogicalTypeId::TEMPLATE) { + // Prefer 
the TEMPLATE type for NULL casts, as it is the most generic + return 5; + } + if (from.id() == LogicalTypeId::SQLNULL || to.id() == LogicalTypeId::ANY || to.id() == LogicalTypeId::TEMPLATE) { // NULL expression can be cast to anything return TargetTypeCost(to); } @@ -575,8 +598,8 @@ int64_t CastRules::ImplicitCast(const LogicalType &from, const LogicalType &to) return ImplicitCastTimestampNS(to); case LogicalTypeId::TIMESTAMP: return ImplicitCastTimestamp(to); - case LogicalTypeId::VARINT: - return ImplicitCastVarint(to); + case LogicalTypeId::BIGNUM: + return ImplicitCastBignum(to); default: return -1; } diff --git a/src/duckdb/src/function/function_binder.cpp b/src/duckdb/src/function/function_binder.cpp index 96f52c659..c9ad55d92 100644 --- a/src/duckdb/src/function/function_binder.cpp +++ b/src/duckdb/src/function/function_binder.cpp @@ -3,6 +3,7 @@ #include "duckdb/catalog/catalog.hpp" #include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" #include "duckdb/common/limits.hpp" +#include "duckdb/common/type_visitor.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/function/aggregate_function.hpp" #include "duckdb/function/cast_rules.hpp" @@ -33,7 +34,7 @@ optional_idx FunctionBinder::BindVarArgsFunctionCost(const SimpleFunction &func, // arguments match: do nothing continue; } - int64_t cast_cost = CastFunctionSet::Get(context).ImplicitCastCost(arguments[i], arg_type); + int64_t cast_cost = CastFunctionSet::ImplicitCastCost(context, arguments[i], arg_type); if (cast_cost >= 0) { // we can implicitly cast, add the cost to the total cost cost += idx_t(cast_cost); @@ -61,7 +62,7 @@ optional_idx FunctionBinder::BindFunctionCost(const SimpleFunction &func, const has_parameter = true; continue; } - int64_t cast_cost = CastFunctionSet::Get(context).ImplicitCastCost(arguments[i], func.arguments[i]); + int64_t cast_cost = CastFunctionSet::ImplicitCastCost(context, arguments[i], func.arguments[i]); if (cast_cost >= 0) { // we can implicitly cast, add the cost to the total cost cost += idx_t(cast_cost); @@ -234,7 +235,7 @@ optional_idx FunctionBinder::BindFunction(const string &name, TableFunctionSet & enum class LogicalTypeComparisonResult : uint8_t { IDENTICAL_TYPE, TARGET_IS_ANY, DIFFERENT_TYPES }; -LogicalTypeComparisonResult RequiresCast(const LogicalType &source_type, const LogicalType &target_type) { +static LogicalTypeComparisonResult RequiresCast(const LogicalType &source_type, const LogicalType &target_type) { if (target_type.id() == LogicalTypeId::ANY) { return LogicalTypeComparisonResult::TARGET_IS_ANY; } @@ -250,7 +251,7 @@ LogicalTypeComparisonResult RequiresCast(const LogicalType &source_type, const L return LogicalTypeComparisonResult::DIFFERENT_TYPES; } -bool TypeRequiresPrepare(const LogicalType &type) { +static bool TypeRequiresPrepare(const LogicalType &type) { if (type.id() == LogicalTypeId::ANY) { return true; } @@ -260,7 +261,7 @@ bool TypeRequiresPrepare(const LogicalType &type) { return false; } -LogicalType PrepareTypeForCastRecursive(const LogicalType &type) { +static LogicalType PrepareTypeForCastRecursive(const LogicalType &type) { if (type.id() == LogicalTypeId::ANY) { return AnyType::GetTargetType(type); } @@ -270,7 +271,7 @@ LogicalType PrepareTypeForCastRecursive(const LogicalType &type) { return type; } -void PrepareTypeForCast(LogicalType &type) { +static void PrepareTypeForCast(LogicalType &type) { if (!TypeRequiresPrepare(type)) { return; } @@ -354,11 +355,11 @@ unique_ptr 
FunctionBinder::BindScalarFunction(ScalarFunctionCatalogE return BindScalarFunction(bound_function, std::move(children), is_operator, binder); } -bool RequiresCollationPropagation(const LogicalType &type) { +static bool RequiresCollationPropagation(const LogicalType &type) { return type.id() == LogicalTypeId::VARCHAR && !type.HasAlias(); } -string ExtractCollation(const vector<unique_ptr<Expression>> &children) { +static string ExtractCollation(const vector<unique_ptr<Expression>> &children) { string collation; for (auto &arg : children) { if (!RequiresCollationPropagation(arg->return_type)) { @@ -375,7 +376,8 @@ string ExtractCollation(const vector<unique_ptr<Expression>> &children) { return collation; } -void PropagateCollations(ClientContext &, ScalarFunction &bound_function, vector<unique_ptr<Expression>> &children) { +static void PropagateCollations(ClientContext &, ScalarFunction &bound_function, + vector<unique_ptr<Expression>> &children) { if (!RequiresCollationPropagation(bound_function.return_type)) { // we only need to propagate if the function returns a varchar return; @@ -390,8 +392,8 @@ void PropagateCollations(ClientContext &, ScalarFunction &bound_function, vector bound_function.return_type = std::move(collation_type); } -void PushCollations(ClientContext &context, ScalarFunction &bound_function, vector<unique_ptr<Expression>> &children, - CollationType type) { +static void PushCollations(ClientContext &context, ScalarFunction &bound_function, + vector<unique_ptr<Expression>> &children, CollationType type) { auto collation = ExtractCollation(children); if (collation.empty()) { // no collation to push @@ -413,8 +415,8 @@ void PushCollations(ClientContext &context, ScalarFunction &bound_function, vect } } -void HandleCollations(ClientContext &context, ScalarFunction &bound_function, - vector<unique_ptr<Expression>> &children) { +static void HandleCollations(ClientContext &context, ScalarFunction &bound_function, + vector<unique_ptr<Expression>> &children) { switch (bound_function.collation_handling) { case FunctionCollationHandling::IGNORE_COLLATIONS: // explicitly ignoring collation handling @@ -431,9 +433,224 @@ void HandleCollations(ClientContext &context, ScalarFunction &bound_function, } } +static void InferTemplateType(ClientContext &context, const LogicalType &source, const LogicalType &target, + case_insensitive_map_t<vector<LogicalType>> &bindings, const Expression &current_expr, + const BaseScalarFunction &function) { + + if (target.id() == LogicalTypeId::UNKNOWN || target.id() == LogicalTypeId::SQLNULL) { + // If the actual type is unknown, we cannot infer anything more. + // Therefore, we map all remaining templates in the source to UNKNOWN or SQLNULL, if not already inferred to + // be something else. + + // This might seem a bit strange: why not just leave the binding unset and error out later when we try to + // substitute all templates? Well, this is how bindings for most nested functions already work: they simply + // propagate the UNKNOWN/SQLNULL. The binder will later check for UNKNOWN/SQLNULL in return types and, if it + // finds one, insert a dummy cast to INT32 so that the function can be executed without errors (and just + // return NULLs). + + TypeVisitor::Contains(source, [&](const LogicalType &child) { + if (child.id() == LogicalTypeId::TEMPLATE) { + const auto index = TemplateType::GetName(child); + if (bindings.find(index) == bindings.end()) { + // not found, add the binding + bindings[index] = {target.id()}; + } + } + return false; // continue visiting + }); + return; + } + + // If the source is a template type, we bind it, or try to unify its existing binding with the target type. 
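// A worked example of the bookkeeping below (hypothetical types): if T was
// first bound from a TINYINT argument and a later argument is INTEGER,
// TryGetMaxLogicalType promotes the binding and the vector records the chain:
//   bindings["T"] == {TINYINT}                    // first occurrence
//   bindings["T"] == {TINYINT, INTEGER, INTEGER}  // ..., target, promoted
// The (target, promoted) pairs appended after the first entry are exactly what
// the error path below replays as "by promoting 'x' + 'y'" steps.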
+ if (source.id() == LogicalTypeId::TEMPLATE) { + const auto &index = TemplateType::GetName(source); + auto it = bindings.find(index); + if (it == bindings.end()) { + // not found, add the binding + bindings[index] = {target}; + return; + } + if (it->second.back() == target) { + // already bound to the same type + return; + } + + // Try to unify (promote) the type candidates + LogicalType result; + if (LogicalType::TryGetMaxLogicalType(context, it->second.back(), target, result)) { + // Type unification was successful + if (it->second.back() != result) { + // update the binding + it->second.push_back(target); + it->second.push_back(std::move(result)); // Push the new promoted type + } + return; + } + + // If we reach here, it means the types are incompatible + string msg = + StringUtil::Format("Cannot deduce template type '%s' in function: '%s'\nType '%s' was inferred to be:\n", + TemplateType::GetName(source), function.ToString(), TemplateType::GetName(source)); + const auto &steps = it->second; + + for (idx_t i = 0; i < steps.size(); i += 2) { + if (i == 0) { + // The first entry is the type inferred from the first occurrence + msg += StringUtil::Format(" - '%s', from first occurrence\n", steps[i].ToString()); + } else { + msg += StringUtil::Format(" - '%s', by promoting '%s' + '%s'\n", steps[i].ToString(), + steps[i - 2].ToString(), steps[i - 1].ToString()); + } + } + msg += StringUtil::Format(" - '%s', which is incompatible with previously inferred type!", target.ToString()); + throw BinderException(current_expr.GetQueryLocation(), msg); + } + + // Otherwise, recurse downwards into nested types, and try to infer nested type members. + // This only works if the source and target types are completely defined (excluding templates), + // i.e. they have aux info. + if (!(source.IsNested() && target.IsNested() && source.AuxInfo() && target.AuxInfo())) { + return; + } + + switch (source.id()) { + case LogicalTypeId::LIST: + case LogicalTypeId::ARRAY: { + if ((source.id() == LogicalTypeId::ARRAY || source.id() == LogicalTypeId::LIST) && + (target.id() == LogicalTypeId::LIST || target.id() == LogicalTypeId::ARRAY)) { + + const auto &source_child = + source.id() == LogicalTypeId::LIST ? ListType::GetChildType(source) : ArrayType::GetChildType(source); + const auto &target_child = + target.id() == LogicalTypeId::LIST ? ListType::GetChildType(target) : ArrayType::GetChildType(target); + InferTemplateType(context, source_child, target_child, bindings, current_expr, function); + } + } break; + case LogicalTypeId::MAP: { + // Map is only implicitly castable to map, so we only need to handle this case here. + if (target.id() == LogicalTypeId::MAP) { + const auto &source_key = MapType::KeyType(source); + const auto &source_val = MapType::ValueType(source); + const auto &target_key = MapType::KeyType(target); + const auto &target_val = MapType::ValueType(target); + + InferTemplateType(context, source_key, target_key, bindings, current_expr, function); + InferTemplateType(context, source_val, target_val, bindings, current_expr, function); + } + } break; + case LogicalTypeId::UNION: { + // TODO: Support union types with template member types. + throw NotImplementedException("Union types cannot infer templated member types yet!"); + } break; + case LogicalTypeId::STRUCT: { + // Structs are only implicitly castable to structs, so we only need to handle this case here. 
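// Illustrative (hypothetical types): an unnamed struct source such as
// STRUCT(TEMPLATE("T"), VARCHAR) matched against a target
// STRUCT(i INTEGER, s VARCHAR) infers T -> INTEGER positionally from the
// first member; a named source struct instead falls through to the
// NotImplementedException below.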
+ if (target.id() == LogicalTypeId::STRUCT && StructType::IsUnnamed(source)) { + const auto &source_children = StructType::GetChildTypes(source); + const auto &target_children = StructType::GetChildTypes(target); + + const auto common_children = MinValue(source_children.size(), target_children.size()); + for (idx_t i = 0; i < common_children; i++) { + const auto &source_child_type = source_children[i].second; + const auto &target_child_type = target_children[i].second; + InferTemplateType(context, source_child_type, target_child_type, bindings, current_expr, function); + } + } else { + // TODO: Support named structs with template child types. + throw NotImplementedException("Named structs cannot infer templated child types yet!"); + } + } break; + default: + break; // no template type to infer + } +} + +static void SubstituteTemplateType(LogicalType &type, case_insensitive_map_t<vector<LogicalType>> &bindings, + const string &function_name) { + + // Replace all template types with their bound concrete types. + type = TypeVisitor::VisitReplace(type, [&](const LogicalType &t) -> LogicalType { + if (t.id() == LogicalTypeId::TEMPLATE) { + const auto index = TemplateType::GetName(t); + auto it = bindings.find(index); + if (it != bindings.end()) { + // found a binding, return the concrete type + return LogicalType::NormalizeType(it->second.back()); + } + + // If we reach here, the template type was not bound to any concrete type. + // We don't throw an error here, but give users a chance to handle unresolved template types later in the + // "bind_scalar_function_t" callback. We then throw an error if the template type is still not bound + // in the "CheckTemplateTypesResolved" method afterwards. + } + return t; + }); +} + +void FunctionBinder::ResolveTemplateTypes(BaseScalarFunction &bound_function, + const vector<unique_ptr<Expression>> &children) { + case_insensitive_map_t<vector<LogicalType>> bindings; + vector<reference<LogicalType>> to_substitute; + + // First, we need to infer the template types from the children. + for (idx_t i = 0; i < bound_function.arguments.size(); i++) { + auto &param = bound_function.arguments[i]; + + // If the parameter is not templated, we can skip it. + if (param.IsTemplated()) { + auto actual = ExpressionBinder::GetExpressionReturnType(*children[i]); + InferTemplateType(context, param, actual, bindings, *children[i], bound_function); + + to_substitute.emplace_back(param); + } + } + + // If the function has a templated varargs, we need to infer its type too + if (bound_function.varargs.IsTemplated()) { + // All remaining children are considered varargs. + for (idx_t i = bound_function.arguments.size(); i < children.size(); i++) { + auto actual = ExpressionBinder::GetExpressionReturnType(*children[i]); + InferTemplateType(context, bound_function.varargs, actual, bindings, *children[i], bound_function); + } + to_substitute.emplace_back(bound_function.varargs); + } + + // If the return type is templated, we need to substitute it as well + if (bound_function.return_type.IsTemplated()) { + to_substitute.emplace_back(bound_function.return_type); + } + + // Finally, substitute all template types in the bound function with their concrete types. 
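// End-to-end sketch (using the templated list_contains signature added later
// in this diff): for a call like list_contains([1, 2, 3], 2), the declared
// signature {LIST(TEMPLATE("T")), TEMPLATE("T")} -> BOOLEAN resolves as:
//   inference:    T -> INTEGER (from both arguments, trivially unified)
//   substitution: arguments become {LIST(INTEGER), INTEGER}; BOOLEAN is left
//                 untouched, since the return type holds no template.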
+ for (auto &templated_type : to_substitute) { + SubstituteTemplateType(templated_type, bindings, bound_function.name); + } +} + +static void VerifyTemplateType(const LogicalType &type, const string &function_name) { + TypeVisitor::Contains(type, [&](const LogicalType &type) { + if (type.id() == LogicalTypeId::TEMPLATE) { + const auto msg = + "Function '%s' has a template parameter type '%s' that could not be resolved to a concrete type"; + throw BinderException(msg, function_name, TemplateType::GetName(type)); + } + return false; // continue visiting + }); +} + +// Verify that all template types are bound to concrete types. +void FunctionBinder::CheckTemplateTypesResolved(const BaseScalarFunction &bound_function) { + for (const auto &arg : bound_function.arguments) { + VerifyTemplateType(arg, bound_function.name); + } + VerifyTemplateType(bound_function.varargs, bound_function.name); + VerifyTemplateType(bound_function.return_type, bound_function.name); +} + unique_ptr<Expression> FunctionBinder::BindScalarFunction(ScalarFunction bound_function, vector<unique_ptr<Expression>> children, bool is_operator, optional_ptr<Binder> binder) { + + // Attempt to resolve template types before we call the "Bind" callback. + ResolveTemplateTypes(bound_function, children); + unique_ptr<FunctionData> bind_info; if (bound_function.bind) { @@ -448,11 +665,15 @@ unique_ptr<Expression> FunctionBinder::BindScalarFunction(ScalarFunction bound_f bind_info = bound_function.bind_extended(bind_input, bound_function, children); } + // After the "bind" callback, we verify that all template types are bound to concrete types. + CheckTemplateTypesResolved(bound_function); + if (bound_function.get_modified_databases && binder) { auto &properties = binder->GetStatementProperties(); FunctionModifiedDatabasesInput input(bind_info, properties); bound_function.get_modified_databases(context, input); } + HandleCollations(context, bound_function, children); // check if we need to add casts to the children @@ -477,6 +698,9 @@ unique_ptr<BoundAggregateExpression> FunctionBinder::BindAggregateFunction(Aggre vector<unique_ptr<Expression>> children, unique_ptr<Expression> filter, AggregateType aggr_type) { + + ResolveTemplateTypes(bound_function, children); + unique_ptr<FunctionData> bind_info; if (bound_function.bind) { bind_info = bound_function.bind(context, bound_function, children); @@ -484,6 +708,8 @@ unique_ptr<BoundAggregateExpression> FunctionBinder::BindAggregateFunction(Aggre children.resize(MinValue(bound_function.arguments.size(), children.size())); } + CheckTemplateTypesResolved(bound_function); + // check if we need to add casts to the children CastToFunctionArguments(bound_function, children); diff --git a/src/duckdb/src/function/function_list.cpp b/src/duckdb/src/function/function_list.cpp index 08ec07a24..ad1a3185a 100644 --- a/src/duckdb/src/function/function_list.cpp +++ b/src/duckdb/src/function/function_list.cpp @@ -53,6 +53,7 @@ static const StaticFunctionDefinition function[] = { DUCKDB_SCALAR_FUNCTION_SET(InternalCompressIntegralUintegerFun), DUCKDB_SCALAR_FUNCTION_SET(InternalCompressIntegralUsmallintFun), DUCKDB_SCALAR_FUNCTION_SET(InternalCompressIntegralUtinyintFun), + DUCKDB_SCALAR_FUNCTION(InternalCompressStringHugeintFun), DUCKDB_SCALAR_FUNCTION(InternalCompressStringUbigintFun), DUCKDB_SCALAR_FUNCTION(InternalCompressStringUhugeintFun), DUCKDB_SCALAR_FUNCTION(InternalCompressStringUintegerFun), diff --git a/src/duckdb/src/function/pragma/pragma_functions.cpp b/src/duckdb/src/function/pragma/pragma_functions.cpp index f03da603f..ad1a3ec3a 100644 --- a/src/duckdb/src/function/pragma/pragma_functions.cpp +++ b/src/duckdb/src/function/pragma/pragma_functions.cpp @@ 
-12,6 +12,7 @@ #include "duckdb/storage/buffer_manager.hpp" #include "duckdb/storage/storage_manager.hpp" #include "duckdb/common/encryption_functions.hpp" +#include "duckdb/logging/log_manager.hpp" #include diff --git a/src/duckdb/src/function/pragma/pragma_queries.cpp b/src/duckdb/src/function/pragma/pragma_queries.cpp index 921395a03..9107b8c01 100644 --- a/src/duckdb/src/function/pragma/pragma_queries.cpp +++ b/src/duckdb/src/function/pragma/pragma_queries.cpp @@ -19,19 +19,32 @@ static string PragmaTableInfo(ClientContext &context, const FunctionParameters & KeywordHelper::WriteQuoted(parameters.values[0].ToString(), '\'')); } -string PragmaShowTables() { +string PragmaShowTables(const string &database, const string &schema) { + string where_clause = ""; + vector where_conditions; + if (!database.empty()) { + where_conditions.push_back(StringUtil::Format("lower(database_name) = lower(%s)", SQLString(database))); + } + if (!schema.empty()) { + where_conditions.push_back(StringUtil::Format("lower(schema_name) = lower(%s)", SQLString(schema))); + } + if (where_conditions.empty()) { + where_conditions.push_back("in_search_path(database_name, schema_name)"); + } + where_clause = "WHERE " + StringUtil::Join(where_conditions, " AND "); + // clang-format off - return R"EOF( + string query = R"EOF( with "tables" as ( SELECT table_name as "name" FROM duckdb_tables - where in_search_path(database_name, schema_name) + )EOF" + where_clause + R"EOF( ), "views" as ( SELECT view_name as "name" FROM duckdb_views - where in_search_path(database_name, schema_name) + )EOF" + where_clause + R"EOF( ), db_objects as ( SELECT "name" FROM "tables" @@ -41,6 +54,8 @@ string PragmaShowTables() { SELECT "name" FROM db_objects ORDER BY "name";)EOF"; + + return query; // clang-format on } diff --git a/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp b/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp index 21c71a5e3..740924397 100644 --- a/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp +++ b/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp @@ -228,6 +228,11 @@ ScalarFunction CMIntegralCompressFun::GetFunction(const LogicalType &input_type, GetIntegralCompressFunctionInputSwitch(input_type, result_type), CMUtils::Bind); result.serialize = CMIntegralSerialize; result.deserialize = CMIntegralDeserialize; +#if defined(D_ASSERT_IS_ENABLED) + result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; // Can only throw runtime error when assertions are enabled +#else + result.errors = FunctionErrors::CANNOT_ERROR; +#endif return result; } diff --git a/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp b/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp index 36175bedb..204a36c66 100644 --- a/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp +++ b/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp @@ -38,8 +38,9 @@ inline RESULT_TYPE StringCompressInternal(const string_t &input) { TemplatedReverseMemCpy(result_ptr + REMAINDER, const_data_ptr_cast(input.GetPrefix())); memset(result_ptr, '\0', REMAINDER); } else { - const auto remainder = sizeof(RESULT_TYPE) - input.GetSize(); - ReverseMemCpy(result_ptr + remainder, data_ptr_cast(input.GetPointer()), input.GetSize()); + const auto size = MinValue(sizeof(RESULT_TYPE), input.GetSize()); + const auto remainder = sizeof(RESULT_TYPE) - size; + 
ReverseMemCpy(result_ptr + remainder, data_ptr_cast(input.GetPointer()), size); memset(result_ptr, '\0', remainder); } result_ptr[0] = UnsafeNumericCast(input.GetSize()); @@ -97,6 +98,9 @@ scalar_function_t GetStringCompressFunctionSwitch(const LogicalType &result_type return GetStringCompressFunction(result_type); case LogicalTypeId::UHUGEINT: return GetStringCompressFunction(result_type); + case LogicalTypeId::HUGEINT: + // Never generated, only for backwards compatibility + return GetStringCompressFunction(result_type); default: throw InternalException("Unexpected type in GetStringCompressFunctionSwitch"); } @@ -238,6 +242,11 @@ ScalarFunction CMStringCompressFun::GetFunction(const LogicalType &result_type) GetStringCompressFunctionSwitch(result_type), CMUtils::Bind); result.serialize = CMStringCompressSerialize; result.deserialize = CMStringCompressDeserialize; +#if defined(D_ASSERT_IS_ENABLED) + result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; // Can only throw runtime error when assertions are enabled +#else + result.errors = FunctionErrors::CANNOT_ERROR; +#endif return result; } @@ -266,6 +275,11 @@ ScalarFunction InternalCompressStringUbigintFun::GetFunction() { return CMStringCompressFun::GetFunction(LogicalType(LogicalTypeId::UBIGINT)); } +ScalarFunction InternalCompressStringHugeintFun::GetFunction() { + // We never generate this, but it's needed for backwards compatibility + return CMStringCompressFun::GetFunction(LogicalType(LogicalTypeId::HUGEINT)); +} + ScalarFunction InternalCompressStringUhugeintFun::GetFunction() { return CMStringCompressFun::GetFunction(LogicalType(LogicalTypeId::UHUGEINT)); } diff --git a/src/duckdb/src/function/scalar/list/contains_or_position.cpp b/src/duckdb/src/function/scalar/list/contains_or_position.cpp index 7bdad3db1..064bd4b00 100644 --- a/src/duckdb/src/function/scalar/list/contains_or_position.cpp +++ b/src/duckdb/src/function/scalar/list/contains_or_position.cpp @@ -26,62 +26,14 @@ static void ListSearchFunction(DataChunk &input, ExpressionState &state, Vector } } -static unique_ptr ListSearchBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(bound_function.arguments.size() == 2); - - // If the first argument is an array, cast it to a list - arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - - const auto &list = arguments[0]->return_type; - const auto &value = arguments[1]->return_type; - - if (list.id() == LogicalTypeId::SQLNULL) { - bound_function.arguments[0] = LogicalTypeId::UNKNOWN; - bound_function.arguments[1] = LogicalTypeId::UNKNOWN; - bound_function.return_type = LogicalType::SQLNULL; - return make_uniq(bound_function.return_type); - } - - if (list.IsUnknown() && value.IsUnknown()) { - bound_function.arguments[0] = list; - bound_function.arguments[1] = value; - return nullptr; - } - - if (list.IsUnknown()) { - // Only the list type is unknown. - // We can infer its type from the type of the value. - bound_function.arguments[0] = LogicalType::LIST(value); - bound_function.arguments[1] = value; - } else if (value.IsUnknown()) { - // Only the value type is unknown. - // We can infer its type from the child type of the list. 
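// Note on this removal: the hand-rolled unification in ListSearchBind is made
// redundant by the {LIST(TEMPLATE("T")), TEMPLATE("T")} signatures registered
// below; FunctionBinder::ResolveTemplateTypes now performs the same
// TryGetMaxLogicalType-based unification generically for every templated
// function, so the per-function bind callback can be dropped.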
- bound_function.arguments[0] = list; - bound_function.arguments[1] = ListType::GetChildType(list); - } else { - LogicalType max_child_type; - if (!LogicalType::TryGetMaxLogicalType(context, ListType::GetChildType(list), value, max_child_type)) { - throw BinderException( - "%s: Cannot match element of type '%s' in a list of type '%s' - an explicit cast is required", - bound_function.name, value.ToString(), list.ToString()); - } - - bound_function.arguments[0] = LogicalType::LIST(max_child_type); - bound_function.arguments[1] = max_child_type; - } - - return make_uniq(bound_function.return_type); -} - ScalarFunction ListContainsFun::GetFunction() { - return ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::ANY}, LogicalType::BOOLEAN, - ListSearchFunction, ListSearchBind); + return ScalarFunction({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::TEMPLATE("T")}, + LogicalType::BOOLEAN, ListSearchFunction); } ScalarFunction ListPositionFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::ANY}, LogicalType::INTEGER, - ListSearchFunction, ListSearchBind); + auto fun = ScalarFunction({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::TEMPLATE("T")}, + LogicalType::INTEGER, ListSearchFunction); fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; return fun; } diff --git a/src/duckdb/src/function/scalar/list/list_extract.cpp b/src/duckdb/src/function/scalar/list/list_extract.cpp index 058ce5f42..fd79249d9 100644 --- a/src/duckdb/src/function/scalar/list/list_extract.cpp +++ b/src/duckdb/src/function/scalar/list/list_extract.cpp @@ -137,14 +137,7 @@ static unique_ptr ListExtractBind(ClientContext &context, ScalarFu vector> &arguments) { D_ASSERT(bound_function.arguments.size() == 2); arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - - D_ASSERT(LogicalTypeId::LIST == arguments[0]->return_type.id()); - // list extract returns the child type of the list as return type - auto child_type = ListType::GetChildType(arguments[0]->return_type); - - bound_function.return_type = child_type; - bound_function.arguments[0] = LogicalType::LIST(child_type); - return make_uniq(bound_function.return_type); + return nullptr; } static unique_ptr ListExtractStats(ClientContext &context, FunctionStatisticsInput &input) { @@ -160,8 +153,8 @@ ScalarFunctionSet ListExtractFun::GetFunctions() { ScalarFunctionSet list_extract_set("list_extract"); // the arguments and return types are actually set in the binder function - ScalarFunction lfun({LogicalType::LIST(LogicalType::ANY), LogicalType::BIGINT}, LogicalType::ANY, - ListExtractFunction, ListExtractBind, nullptr, ListExtractStats); + ScalarFunction lfun({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::BIGINT}, + LogicalType::TEMPLATE("T"), ListExtractFunction, ListExtractBind, nullptr, ListExtractStats); ScalarFunction sfun({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, ListExtractFunction); BaseScalarFunction::SetReturnsError(lfun); @@ -175,8 +168,8 @@ ScalarFunctionSet ArrayExtractFun::GetFunctions() { ScalarFunctionSet array_extract_set("array_extract"); // the arguments and return types are actually set in the binder function - ScalarFunction lfun({LogicalType::LIST(LogicalType::ANY), LogicalType::BIGINT}, LogicalType::ANY, - ListExtractFunction, ListExtractBind, nullptr, ListExtractStats); + ScalarFunction lfun({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::BIGINT}, + 
LogicalType::TEMPLATE("T"), ListExtractFunction, ListExtractBind, nullptr, ListExtractStats); ScalarFunction sfun({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, ListExtractFunction); diff --git a/src/duckdb/src/function/scalar/list/list_select.cpp b/src/duckdb/src/function/scalar/list/list_select.cpp index a6f2bfb96..ebaef993c 100644 --- a/src/duckdb/src/function/scalar/list/list_select.cpp +++ b/src/duckdb/src/function/scalar/list/list_select.cpp @@ -148,39 +148,19 @@ void ListSelectFunction(DataChunk &args, ExpressionState &state, Vector &result) result.SetVectorType(args.AllConstant() ? VectorType::CONSTANT_VECTOR : VectorType::FLAT_VECTOR); } -unique_ptr ListSelectBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(bound_function.arguments.size() == 2); - - // If the first argument is an array, cast it to a list - arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - - LogicalType child_type; - if (arguments[0]->return_type == LogicalTypeId::UNKNOWN || arguments[1]->return_type == LogicalTypeId::UNKNOWN) { - bound_function.arguments[0] = LogicalTypeId::UNKNOWN; - bound_function.return_type = LogicalType::SQLNULL; - return make_uniq(bound_function.return_type); - } - - D_ASSERT(LogicalTypeId::LIST == arguments[0]->return_type.id() || - LogicalTypeId::SQLNULL == arguments[0]->return_type.id()); - - bound_function.return_type = arguments[0]->return_type; - return make_uniq(bound_function.return_type); -} - } // namespace + ScalarFunction ListWhereFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::LIST(LogicalTypeId::ANY), LogicalType::LIST(LogicalType::BOOLEAN)}, - LogicalType::LIST(LogicalTypeId::ANY), ListSelectFunction, - ListSelectBind); + auto fun = + ScalarFunction({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::LIST(LogicalType::BOOLEAN)}, + LogicalType::LIST(LogicalType::TEMPLATE("T")), ListSelectFunction); return fun; } ScalarFunction ListSelectFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::LIST(LogicalTypeId::ANY), LogicalType::LIST(LogicalType::BIGINT)}, - LogicalType::LIST(LogicalTypeId::ANY), ListSelectFunction, - ListSelectBind); + auto fun = + ScalarFunction({LogicalType::LIST(LogicalType::TEMPLATE("T")), LogicalType::LIST(LogicalType::BIGINT)}, + LogicalType::LIST(LogicalType::TEMPLATE("T")), ListSelectFunction); return fun; } diff --git a/src/duckdb/src/function/scalar/map/map_contains.cpp b/src/duckdb/src/function/scalar/map/map_contains.cpp index ebc6edb5f..823fb9c1b 100644 --- a/src/duckdb/src/function/scalar/map/map_contains.cpp +++ b/src/duckdb/src/function/scalar/map/map_contains.cpp @@ -18,38 +18,12 @@ static void MapContainsFunction(DataChunk &input, ExpressionState &state, Vector } } -static unique_ptr MapContainsBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(bound_function.arguments.size() == 2); - - const auto &map = arguments[0]->return_type; - const auto &key = arguments[1]->return_type; - - if (map.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - - if (key.id() == LogicalTypeId::UNKNOWN) { - // Infer the argument type from the map type - bound_function.arguments[0] = map; - bound_function.arguments[1] = MapType::KeyType(map); - } else { - LogicalType max_child_type; - if (!LogicalType::TryGetMaxLogicalType(context, MapType::KeyType(map), key, max_child_type)) { - throw BinderException( - "%s: Cannot match element of type '%s' in a map of 
type '%s' - an explicit cast is required", - bound_function.name, key.ToString(), map.ToString()); - } - - bound_function.arguments[0] = LogicalType::MAP(max_child_type, MapType::ValueType(map)); - bound_function.arguments[1] = max_child_type; - } - return nullptr; -} - ScalarFunction MapContainsFun::GetFunction() { - ScalarFunction fun("map_contains", {LogicalType::MAP(LogicalType::ANY, LogicalType::ANY), LogicalType::ANY}, - LogicalType::BOOLEAN, MapContainsFunction, MapContainsBind); + auto key_type = LogicalType::TEMPLATE("K"); + auto val_type = LogicalType::TEMPLATE("V"); + + ScalarFunction fun("map_contains", {LogicalType::MAP(key_type, val_type), key_type}, LogicalType::BOOLEAN, + MapContainsFunction); return fun; } diff --git a/src/duckdb/src/function/scalar/operator/arithmetic.cpp b/src/duckdb/src/function/scalar/operator/arithmetic.cpp index 563432713..1ed6095ed 100644 --- a/src/duckdb/src/function/scalar/operator/arithmetic.cpp +++ b/src/duckdb/src/function/scalar/operator/arithmetic.cpp @@ -1,4 +1,5 @@ #include "duckdb/common/enum_util.hpp" +#include "duckdb/common/bignum.hpp" #include "duckdb/common/operator/add.hpp" #include "duckdb/common/operator/interpolate.hpp" #include "duckdb/common/operator/multiply.hpp" @@ -311,6 +312,40 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &type) { } } +void BignumAdd(DataChunk &args, ExpressionState &state, Vector &result) { + auto &allocator = state.GetAllocator(); + ArenaAllocator arena(allocator); + BinaryExecutor::Execute(args.data[0], args.data[1], result, args.size(), + [&](bignum_t a, bignum_t b) { + const BignumIntermediate lhs(a); + const BignumIntermediate rhs(b); + return BignumIntermediate::Add(result, lhs, rhs); + }); +} + +void BignumSubtract(DataChunk &args, ExpressionState &state, Vector &result) { + auto &allocator = state.GetAllocator(); + ArenaAllocator arena(allocator); + BinaryExecutor::Execute( + args.data[0], args.data[1], result, args.size(), [&](bignum_t a, bignum_t b) { + const BignumIntermediate lhs(a); + BignumIntermediate rhs(b); + rhs.NegateInPlace(); + auto result_value = BignumIntermediate::Add(result, lhs, rhs); + rhs.NegateInPlace(); + return result_value; + }); +} + +void BignumNegate(DataChunk &args, ExpressionState &state, Vector &result) { + auto &allocator = state.GetAllocator(); + ArenaAllocator arena(allocator); + UnaryExecutor::Execute(args.data[0], result, args.size(), [&](bignum_t a) { + const BignumIntermediate lhs(a); + return lhs.Negate(result); + }); +} + ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const LogicalType &right_type) { if (left_type.IsNumeric() && left_type.id() == right_type.id()) { if (left_type.id() == LogicalTypeId::DECIMAL) { @@ -336,6 +371,14 @@ ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const Logi } switch (left_type.id()) { + case LogicalTypeId::BIGNUM: + if (right_type.id() == LogicalTypeId::BIGNUM) { + ScalarFunction function("+", {left_type, right_type}, LogicalType::BIGNUM, BignumAdd); + BaseScalarFunction::SetReturnsError(function); + return function; + } + break; + case LogicalTypeId::DATE: if (right_type.id() == LogicalTypeId::INTEGER) { ScalarFunction function("+", {left_type, right_type}, LogicalType::DATE, @@ -477,6 +520,9 @@ ScalarFunctionSet OperatorAddFun::GetFunctions() { // we can add lists together add.AddFunction(ListConcatFun::GetFunction()); + // we can add bignums together + add.AddFunction(AddFunction::GetFunction(LogicalType::BIGNUM, LogicalType::BIGNUM)); + return add; } @@ -629,6 
+675,9 @@ ScalarFunction SubtractFunction::GetFunction(const LogicalType &type) { } else if (type.id() == LogicalTypeId::DECIMAL) { ScalarFunction func("-", {type}, type, nullptr, DecimalNegateBind, nullptr, NegateBindStatistics); return func; + } else if (type.id() == LogicalTypeId::BIGNUM) { + ScalarFunction func("-", {type}, LogicalType::BIGNUM, BignumNegate); + return func; } else { D_ASSERT(type.IsNumeric()); ScalarFunction func("-", {type}, type, ScalarFunction::GetScalarUnaryFunction(type), nullptr, @@ -664,6 +713,10 @@ ScalarFunction SubtractFunction::GetFunction(const LogicalType &left_type, const } switch (left_type.id()) { + case LogicalTypeId::BIGNUM: { + ScalarFunction function("-", {left_type, right_type}, left_type, BignumSubtract); + return function; + } case LogicalTypeId::DATE: if (right_type.id() == LogicalTypeId::DATE) { ScalarFunction function("-", {left_type, right_type}, LogicalType::BIGINT, @@ -741,6 +794,8 @@ ScalarFunctionSet OperatorSubtractFun::GetFunctions() { // binary subtract function "a - b", subtracts b from a subtract.AddFunction(SubtractFunction::GetFunction(type, type)); } + subtract.AddFunction(SubtractFunction::GetFunction(LogicalType::BIGNUM)); + subtract.AddFunction(SubtractFunction::GetFunction(LogicalType::BIGNUM, LogicalType::BIGNUM)); // we can subtract dates from each other subtract.AddFunction(SubtractFunction::GetFunction(LogicalType::DATE, LogicalType::DATE)); // we can subtract integers from dates diff --git a/src/duckdb/src/function/scalar/system/write_log.cpp b/src/duckdb/src/function/scalar/system/write_log.cpp index bcd55bdc5..b7649f8f5 100644 --- a/src/duckdb/src/function/scalar/system/write_log.cpp +++ b/src/duckdb/src/function/scalar/system/write_log.cpp @@ -2,7 +2,7 @@ #include "duckdb/execution/expression_executor.hpp" #include "duckdb/main/client_data.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" - +#include "duckdb/logging/log_manager.hpp" #include "utf8proc.hpp" namespace duckdb { diff --git a/src/duckdb/src/function/table/arrow.cpp b/src/duckdb/src/function/table/arrow.cpp index 676787ae0..d41f63b2c 100644 --- a/src/duckdb/src/function/table/arrow.cpp +++ b/src/duckdb/src/function/table/arrow.cpp @@ -18,23 +18,32 @@ namespace duckdb { -void ArrowTableFunction::PopulateArrowTableType(DBConfig &config, ArrowTableType &arrow_table, - const ArrowSchemaWrapper &schema_p, vector<string> &names, - vector<LogicalType> &return_types) { - for (idx_t col_idx = 0; col_idx < static_cast<idx_t>(schema_p.arrow_schema.n_children); col_idx++) { - auto &schema = *schema_p.arrow_schema.children[col_idx]; +void ArrowTableFunction::PopulateArrowTableSchema(DBConfig &config, ArrowTableSchema &arrow_table, + const ArrowSchema &arrow_schema) { + vector<string> names; + // We first gather the column names and deduplicate them + for (idx_t col_idx = 0; col_idx < static_cast<idx_t>(arrow_schema.n_children); col_idx++) { + const auto &schema = *arrow_schema.children[col_idx]; if (!schema.release) { throw InvalidInputException("arrow_scan: released schema passed"); } - auto arrow_type = ArrowType::GetArrowLogicalType(config, schema); - return_types.emplace_back(arrow_type->GetDuckType(true)); - arrow_table.AddColumn(col_idx, std::move(arrow_type)); auto name = string(schema.name); if (name.empty()) { name = string("v") + to_string(col_idx); } names.push_back(name); } + QueryResult::DeduplicateColumns(names); + + // We do a second iteration to figure out the arrow types and already set their deduplicated names + for (idx_t col_idx = 0; col_idx < 
static_cast(arrow_schema.n_children); col_idx++) { + auto &schema = *arrow_schema.children[col_idx]; + if (!schema.release) { + throw InvalidInputException("arrow_scan: released schema passed"); + } + auto arrow_type = ArrowType::GetArrowLogicalType(config, schema); + arrow_table.AddColumn(col_idx, std::move(arrow_type), names[col_idx]); + } } unique_ptr ArrowTableFunction::ArrowScanBindDumb(ClientContext &context, TableFunctionBindInput &input, @@ -69,8 +78,9 @@ unique_ptr ArrowTableFunction::ArrowScanBind(ClientContext &contex auto &data = *res; stream_factory_get_schema(reinterpret_cast(stream_factory_ptr), data.schema_root.arrow_schema); - PopulateArrowTableType(DBConfig::GetConfig(context), res->arrow_table, data.schema_root, names, return_types); - QueryResult::DeduplicateColumns(names); + PopulateArrowTableSchema(DBConfig::GetConfig(context), res->arrow_table, data.schema_root.arrow_schema); + names = res->arrow_table.GetNames(); + return_types = res->arrow_table.GetTypes(); res->all_types = return_types; if (return_types.empty()) { throw InvalidInputException("Provided table/dataframe must have at least one column"); diff --git a/src/duckdb/src/function/table/arrow/arrow_array_scan_state.cpp b/src/duckdb/src/function/table/arrow/arrow_array_scan_state.cpp index bbb1bf3e2..fd9d14dc7 100644 --- a/src/duckdb/src/function/table/arrow/arrow_array_scan_state.cpp +++ b/src/duckdb/src/function/table/arrow/arrow_array_scan_state.cpp @@ -4,15 +4,14 @@ namespace duckdb { -ArrowArrayScanState::ArrowArrayScanState(ArrowScanLocalState &state, ClientContext &context) - : state(state), context(context) { +ArrowArrayScanState::ArrowArrayScanState(ClientContext &context) : context(context) { arrow_dictionary = nullptr; } ArrowArrayScanState &ArrowArrayScanState::GetChild(idx_t child_idx) { auto it = children.find(child_idx); if (it == children.end()) { - auto child_p = make_uniq(state, context); + auto child_p = make_uniq(context); auto &child = *child_p; child.owned_data = owned_data; children.emplace(child_idx, std::move(child_p)); diff --git a/src/duckdb/src/function/table/arrow/arrow_duck_schema.cpp b/src/duckdb/src/function/table/arrow/arrow_duck_schema.cpp index 86e63ff36..ab5611c13 100644 --- a/src/duckdb/src/function/table/arrow/arrow_duck_schema.cpp +++ b/src/duckdb/src/function/table/arrow/arrow_duck_schema.cpp @@ -8,15 +8,27 @@ namespace duckdb { -void ArrowTableType::AddColumn(idx_t index, shared_ptr type) { +void ArrowTableSchema::AddColumn(idx_t index, shared_ptr type, const string &name) { D_ASSERT(arrow_convert_data.find(index) == arrow_convert_data.end()); + if (index >= types.size()) { + types.resize(index + 1); + column_names.resize(index + 1); + } + types[index] = type->GetDuckType(true); + column_names[index] = name; arrow_convert_data.emplace(std::make_pair(index, std::move(type))); } -const arrow_column_map_t &ArrowTableType::GetColumns() const { +const arrow_column_map_t &ArrowTableSchema::GetColumns() const { return arrow_convert_data; } +vector &ArrowTableSchema::GetTypes() { + return types; +} +vector &ArrowTableSchema::GetNames() { + return column_names; +} void ArrowType::SetDictionary(unique_ptr dictionary) { D_ASSERT(!this->dictionary_type); dictionary_type = std::move(dictionary); @@ -380,6 +392,16 @@ bool ArrowType::HasExtension() const { return extension_data.get() != nullptr; } +ArrowArrayPhysicalType ArrowType::GetPhysicalType() const { + if (HasDictionary()) { + return ArrowArrayPhysicalType::DICTIONARY_ENCODED; + } + if (RunEndEncoded()) { + return 
ArrowArrayPhysicalType::RUN_END_ENCODED; + } + return ArrowArrayPhysicalType::DEFAULT; +} + unique_ptr<ArrowType> ArrowType::GetTypeFromSchema(DBConfig &config, ArrowSchema &schema) { auto format = string(schema.format); // Let's first figure out if this type is an extension type diff --git a/src/duckdb/src/function/table/arrow_conversion.cpp b/src/duckdb/src/function/table/arrow_conversion.cpp index 047034033..e194852f0 100644 --- a/src/duckdb/src/function/table/arrow_conversion.cpp +++ b/src/duckdb/src/function/table/arrow_conversion.cpp @@ -13,22 +13,6 @@ namespace duckdb { -namespace { - -enum class ArrowArrayPhysicalType : uint8_t { DICTIONARY_ENCODED, RUN_END_ENCODED, DEFAULT }; - -ArrowArrayPhysicalType GetArrowArrayPhysicalType(const ArrowType &type) { - if (type.HasDictionary()) { - return ArrowArrayPhysicalType::DICTIONARY_ENCODED; - } - if (type.RunEndEncoded()) { - return ArrowArrayPhysicalType::RUN_END_ENCODED; - } - return ArrowArrayPhysicalType::DEFAULT; -} - -} // namespace - #if STANDARD_VECTOR_SIZE > 64 static void ShiftRight(unsigned char *ar, int size, int shift) { int carry = 0; @@ -42,7 +26,7 @@ static void ShiftRight(unsigned char *ar, int size, int shift) { } #endif -idx_t GetEffectiveOffset(const ArrowArray &array, int64_t parent_offset, const ArrowScanLocalState &state, +idx_t GetEffectiveOffset(const ArrowArray &array, int64_t parent_offset, idx_t chunk_offset, int64_t nested_offset = -1) { if (nested_offset != -1) { // The parent of this array is a list @@ -52,7 +36,7 @@ idx_t GetEffectiveOffset(const ArrowArray &array, int64_t parent_offset, const A // Parent offset is set in the case of a struct, it applies to all child arrays // 'chunk_offset' is how much of the chunk we've already scanned, in case the chunk size exceeds // STANDARD_VECTOR_SIZE - return UnsafeNumericCast<idx_t>(array.offset + parent_offset) + state.chunk_offset; + return UnsafeNumericCast<idx_t>(array.offset + parent_offset) + chunk_offset; } template <class T> @@ -60,7 +44,7 @@ T *ArrowBufferData(ArrowArray &array, idx_t buffer_idx) { return (T *)array.buffers[buffer_idx]; // NOLINT } -static void GetValidityMask(ValidityMask &mask, ArrowArray &array, const ArrowScanLocalState &scan_state, idx_t size, +static void GetValidityMask(ValidityMask &mask, ArrowArray &array, idx_t chunk_offset, idx_t size, int64_t parent_offset, int64_t nested_offset = -1, bool add_null = false) { // In certain cases we don't need to, or cannot, copy arrow's validity mask to duckdb. // // 1. null_count != 0, i.e. the array actually contains null values // 2. n_buffers > 0, meaning the array's arrow type is not `null` // 3. 
the validity buffer (the first buffer) is not a nullptr if (array.null_count != 0 && array.n_buffers > 0 && array.buffers[0]) { - auto bit_offset = GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + auto bit_offset = GetEffectiveOffset(array, parent_offset, chunk_offset, nested_offset); mask.EnsureWritable(); #if STANDARD_VECTOR_SIZE > 64 auto n_bitmask_bytes = (size + 8 - 1) / 8; @@ -107,26 +91,13 @@ static void GetValidityMask(ValidityMask &mask, ArrowArray &array, const ArrowSc } } -static void SetValidityMask(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, idx_t size, - int64_t parent_offset, int64_t nested_offset, bool add_null = false) { +void ArrowToDuckDBConversion::SetValidityMask(Vector &vector, ArrowArray &array, idx_t chunk_offset, idx_t size, + int64_t parent_offset, int64_t nested_offset, bool add_null) { D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR); auto &mask = FlatVector::Validity(vector); - GetValidityMask(mask, array, scan_state, size, parent_offset, nested_offset, add_null); + GetValidityMask(mask, array, chunk_offset, size, parent_offset, nested_offset, add_null); } -static void ColumnArrowToDuckDBRunEndEncoded(Vector &vector, const ArrowArray &array, ArrowArrayScanState &array_state, - idx_t size, const ArrowType &arrow_type, int64_t nested_offset = -1, - ValidityMask *parent_mask = nullptr, uint64_t parent_offset = 0); - -static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, idx_t size, - const ArrowType &arrow_type, int64_t nested_offset = -1, - ValidityMask *parent_mask = nullptr, uint64_t parent_offset = 0, - bool ignore_extensions = false); - -static void ColumnArrowToDuckDBDictionary(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, - idx_t size, const ArrowType &arrow_type, int64_t nested_offset = -1, - const ValidityMask *parent_mask = nullptr, uint64_t parent_offset = 0); - namespace { struct ArrowListOffsetData { @@ -223,15 +194,13 @@ static ArrowListOffsetData ConvertArrowListOffsets(Vector &vector, ArrowArray &a } } -static void ArrowToDuckDBList(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, idx_t size, - const ArrowType &arrow_type, int64_t nested_offset, const ValidityMask *parent_mask, - int64_t parent_offset) { - auto &scan_state = array_state.state; - +static void ArrowToDuckDBList(Vector &vector, ArrowArray &array, idx_t chunk_offset, ArrowArrayScanState &array_state, + idx_t size, const ArrowType &arrow_type, int64_t nested_offset, + const ValidityMask *parent_mask, int64_t parent_offset) { auto &list_info = arrow_type.GetTypeInfo(); - SetValidityMask(vector, array, scan_state, size, parent_offset, nested_offset); + ArrowToDuckDBConversion::SetValidityMask(vector, array, chunk_offset, size, parent_offset, nested_offset); - auto effective_offset = GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + auto effective_offset = GetEffectiveOffset(array, parent_offset, chunk_offset, nested_offset); auto list_data = ConvertArrowListOffsets(vector, array, size, arrow_type, effective_offset); auto &start_offset = list_data.start_offset; auto &list_size = list_data.list_size; @@ -239,8 +208,8 @@ static void ArrowToDuckDBList(Vector &vector, ArrowArray &array, ArrowArrayScanS ListVector::Reserve(vector, list_size); ListVector::SetListSize(vector, list_size); auto &child_vector = ListVector::GetEntry(vector); - SetValidityMask(child_vector, *array.children[0], scan_state, list_size, array.offset, 
- NumericCast(start_offset)); + ArrowToDuckDBConversion::SetValidityMask(child_vector, *array.children[0], chunk_offset, list_size, array.offset, + NumericCast(start_offset)); auto &list_mask = FlatVector::Validity(vector); if (parent_mask) { //! Since this List is owned by a struct we must guarantee their validity map matches on Null @@ -258,45 +227,47 @@ static void ArrowToDuckDBList(Vector &vector, ArrowArray &array, ArrowArrayScanS if (list_size == 0 && start_offset == 0) { D_ASSERT(!child_array.dictionary); - ColumnArrowToDuckDB(child_vector, child_array, child_state, list_size, child_type, -1); + ArrowToDuckDBConversion::ColumnArrowToDuckDB(child_vector, child_array, chunk_offset, child_state, list_size, + child_type, -1); return; } - auto array_physical_type = GetArrowArrayPhysicalType(child_type); + auto array_physical_type = child_type.GetPhysicalType(); switch (array_physical_type) { case ArrowArrayPhysicalType::DICTIONARY_ENCODED: // TODO: add support for offsets - ColumnArrowToDuckDBDictionary(child_vector, child_array, child_state, list_size, child_type, - NumericCast(start_offset)); + ArrowToDuckDBConversion::ColumnArrowToDuckDBDictionary(child_vector, child_array, chunk_offset, child_state, + list_size, child_type, + NumericCast(start_offset)); break; case ArrowArrayPhysicalType::RUN_END_ENCODED: - ColumnArrowToDuckDBRunEndEncoded(child_vector, child_array, child_state, list_size, child_type, - NumericCast(start_offset)); + ArrowToDuckDBConversion::ColumnArrowToDuckDBRunEndEncoded(child_vector, child_array, chunk_offset, child_state, + list_size, child_type, + NumericCast(start_offset)); break; case ArrowArrayPhysicalType::DEFAULT: - ColumnArrowToDuckDB(child_vector, child_array, child_state, list_size, child_type, - NumericCast(start_offset)); + ArrowToDuckDBConversion::ColumnArrowToDuckDB(child_vector, child_array, chunk_offset, child_state, list_size, + child_type, NumericCast(start_offset)); break; default: throw NotImplementedException("ArrowArrayPhysicalType not recognized"); } } -static void ArrowToDuckDBArray(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, idx_t size, - const ArrowType &arrow_type, int64_t nested_offset, const ValidityMask *parent_mask, - int64_t parent_offset) { +static void ArrowToDuckDBArray(Vector &vector, ArrowArray &array, idx_t chunk_offset, ArrowArrayScanState &array_state, + idx_t size, const ArrowType &arrow_type, int64_t nested_offset, + const ValidityMask *parent_mask, int64_t parent_offset) { auto &array_info = arrow_type.GetTypeInfo(); - auto &scan_state = array_state.state; auto array_size = array_info.FixedSize(); auto child_count = array_size * size; - auto child_offset = GetEffectiveOffset(array, parent_offset, scan_state, nested_offset) * array_size; + auto child_offset = GetEffectiveOffset(array, parent_offset, chunk_offset, nested_offset) * array_size; - SetValidityMask(vector, array, scan_state, size, parent_offset, nested_offset); + ArrowToDuckDBConversion::SetValidityMask(vector, array, chunk_offset, size, parent_offset, nested_offset); auto &child_vector = ArrayVector::GetEntry(vector); - SetValidityMask(child_vector, *array.children[0], scan_state, child_count, array.offset, - NumericCast(child_offset)); + ArrowToDuckDBConversion::SetValidityMask(child_vector, *array.children[0], chunk_offset, child_count, array.offset, + NumericCast(child_offset)); auto &array_mask = FlatVector::Validity(vector); if (parent_mask) { @@ -327,14 +298,16 @@ static void ArrowToDuckDBArray(Vector &vector, ArrowArray &array, 
ArrowArrayScan auto &child_type = array_info.GetChild(); if (child_count == 0 && child_offset == 0) { D_ASSERT(!child_array.dictionary); - ColumnArrowToDuckDB(child_vector, child_array, child_state, child_count, child_type, -1); + ArrowToDuckDBConversion::ColumnArrowToDuckDB(child_vector, child_array, chunk_offset, child_state, child_count, + child_type, -1); } else { if (child_array.dictionary) { - ColumnArrowToDuckDBDictionary(child_vector, child_array, child_state, child_count, child_type, - NumericCast(child_offset)); + ArrowToDuckDBConversion::ColumnArrowToDuckDBDictionary(child_vector, child_array, chunk_offset, child_state, + child_count, child_type, + NumericCast(child_offset)); } else { - ColumnArrowToDuckDB(child_vector, child_array, child_state, child_count, child_type, - NumericCast(child_offset)); + ArrowToDuckDBConversion::ColumnArrowToDuckDB(child_vector, child_array, chunk_offset, child_state, + child_count, child_type, NumericCast(child_offset)); } } } @@ -401,22 +374,22 @@ static void SetVectorStringView(Vector &vector, idx_t size, ArrowArray &array, i } } -static void DirectConversion(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, - int64_t nested_offset, uint64_t parent_offset) { +static void DirectConversion(Vector &vector, ArrowArray &array, idx_t chunk_offset, int64_t nested_offset, + uint64_t parent_offset) { auto internal_type = GetTypeIdSize(vector.GetType().InternalType()); auto data_ptr = ArrowBufferData(array, 1) + - internal_type * GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); + internal_type * GetEffectiveOffset(array, NumericCast(parent_offset), chunk_offset, nested_offset); FlatVector::SetData(vector, data_ptr); } template -static void TimeConversion(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, - int64_t nested_offset, int64_t parent_offset, idx_t size, int64_t conversion) { +static void TimeConversion(Vector &vector, ArrowArray &array, idx_t chunk_offset, int64_t nested_offset, + int64_t parent_offset, idx_t size, int64_t conversion) { auto tgt_ptr = FlatVector::GetData(vector); auto &validity_mask = FlatVector::Validity(vector); - auto src_ptr = - static_cast(array.buffers[1]) + GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + auto src_ptr = static_cast(array.buffers[1]) + + GetEffectiveOffset(array, parent_offset, chunk_offset, nested_offset); for (idx_t row = 0; row < size; row++) { if (!validity_mask.RowIsValid(row)) { continue; @@ -427,12 +400,12 @@ static void TimeConversion(Vector &vector, ArrowArray &array, const ArrowScanLoc } } -static void UUIDConversion(Vector &vector, const ArrowArray &array, const ArrowScanLocalState &scan_state, - int64_t nested_offset, int64_t parent_offset, idx_t size) { +static void UUIDConversion(Vector &vector, const ArrowArray &array, idx_t chunk_offset, int64_t nested_offset, + int64_t parent_offset, idx_t size) { auto tgt_ptr = FlatVector::GetData(vector); auto &validity_mask = FlatVector::Validity(vector); auto src_ptr = static_cast(array.buffers[1]) + - GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + GetEffectiveOffset(array, parent_offset, chunk_offset, nested_offset); for (idx_t row = 0; row < size; row++) { if (!validity_mask.RowIsValid(row)) { continue; @@ -444,12 +417,12 @@ static void UUIDConversion(Vector &vector, const ArrowArray &array, const ArrowS } } -static void TimestampTZConversion(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, - int64_t 
nested_offset, int64_t parent_offset, idx_t size, int64_t conversion) { +static void TimestampTZConversion(Vector &vector, ArrowArray &array, idx_t chunk_offset, int64_t nested_offset, + int64_t parent_offset, idx_t size, int64_t conversion) { auto tgt_ptr = FlatVector::GetData(vector); auto &validity_mask = FlatVector::Validity(vector); auto src_ptr = - ArrowBufferData(array, 1) + GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + ArrowBufferData(array, 1) + GetEffectiveOffset(array, parent_offset, chunk_offset, nested_offset); for (idx_t row = 0; row < size; row++) { if (!validity_mask.RowIsValid(row)) { continue; @@ -460,11 +433,11 @@ static void TimestampTZConversion(Vector &vector, ArrowArray &array, const Arrow } } -static void IntervalConversionUs(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, - int64_t nested_offset, int64_t parent_offset, idx_t size, int64_t conversion) { +static void IntervalConversionUs(Vector &vector, ArrowArray &array, idx_t chunk_offset, int64_t nested_offset, + int64_t parent_offset, idx_t size, int64_t conversion) { auto tgt_ptr = FlatVector::GetData(vector); auto src_ptr = - ArrowBufferData(array, 1) + GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + ArrowBufferData(array, 1) + GetEffectiveOffset(array, parent_offset, chunk_offset, nested_offset); for (idx_t row = 0; row < size; row++) { tgt_ptr[row].days = 0; tgt_ptr[row].months = 0; @@ -474,11 +447,11 @@ static void IntervalConversionUs(Vector &vector, ArrowArray &array, const ArrowS } } -static void IntervalConversionMonths(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, - int64_t nested_offset, int64_t parent_offset, idx_t size) { +static void IntervalConversionMonths(Vector &vector, ArrowArray &array, idx_t chunk_offset, int64_t nested_offset, + int64_t parent_offset, idx_t size) { auto tgt_ptr = FlatVector::GetData(vector); auto src_ptr = - ArrowBufferData(array, 1) + GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + ArrowBufferData(array, 1) + GetEffectiveOffset(array, parent_offset, chunk_offset, nested_offset); for (idx_t row = 0; row < size; row++) { tgt_ptr[row].days = 0; tgt_ptr[row].micros = 0; @@ -486,11 +459,11 @@ static void IntervalConversionMonths(Vector &vector, ArrowArray &array, const Ar } } -static void IntervalConversionMonthDayNanos(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, +static void IntervalConversionMonthDayNanos(Vector &vector, ArrowArray &array, idx_t chunk_offset, int64_t nested_offset, int64_t parent_offset, idx_t size) { auto tgt_ptr = FlatVector::GetData(vector); - auto src_ptr = - ArrowBufferData(array, 1) + GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); + auto src_ptr = ArrowBufferData(array, 1) + + GetEffectiveOffset(array, parent_offset, chunk_offset, nested_offset); for (idx_t row = 0; row < size; row++) { tgt_ptr[row].days = src_ptr[row].days; tgt_ptr[row].micros = src_ptr[row].nanoseconds / Interval::NANOS_PER_MICRO; @@ -664,9 +637,11 @@ static void FlattenRunEndsSwitch(Vector &result, ArrowRunEndEncodingState &run_e } } -static void ColumnArrowToDuckDBRunEndEncoded(Vector &vector, const ArrowArray &array, ArrowArrayScanState &array_state, - idx_t size, const ArrowType &arrow_type, int64_t nested_offset, - ValidityMask *parent_mask, uint64_t parent_offset) { +void ArrowToDuckDBConversion::ColumnArrowToDuckDBRunEndEncoded(Vector &vector, const ArrowArray &array, + idx_t chunk_offset, 
ArrowArrayScanState &array_state, + idx_t size, const ArrowType &arrow_type, + int64_t nested_offset, ValidityMask *parent_mask, + uint64_t parent_offset) { // Scan the 'run_ends' array D_ASSERT(array.n_children == 2); auto &run_ends_array = *array.children[0]; @@ -677,7 +652,6 @@ static void ColumnArrowToDuckDBRunEndEncoded(Vector &vector, const ArrowArray &a auto &values_type = struct_info.GetChild(1); D_ASSERT(vector.GetType() == values_type.GetDuckType()); - auto &scan_state = array_state.state; if (vector.GetBuffer()) { vector.GetBuffer()->SetAuxiliaryData(make_uniq(array_state.owned_data)); } @@ -692,14 +666,16 @@ static void ColumnArrowToDuckDBRunEndEncoded(Vector &vector, const ArrowArray &a run_end_encoding.run_ends = make_uniq(run_ends_type.GetDuckType(), compressed_size); run_end_encoding.values = make_uniq(values_type.GetDuckType(), compressed_size); - ColumnArrowToDuckDB(*run_end_encoding.run_ends, run_ends_array, array_state, compressed_size, run_ends_type); + ArrowToDuckDBConversion::ColumnArrowToDuckDB(*run_end_encoding.run_ends, run_ends_array, chunk_offset, + array_state, compressed_size, run_ends_type); auto &values = *run_end_encoding.values; - SetValidityMask(values, values_array, scan_state, compressed_size, NumericCast(parent_offset), - nested_offset); - ColumnArrowToDuckDB(values, values_array, array_state, compressed_size, values_type); + ArrowToDuckDBConversion::SetValidityMask(values, values_array, chunk_offset, compressed_size, + NumericCast(parent_offset), nested_offset); + ArrowToDuckDBConversion::ColumnArrowToDuckDB(values, values_array, chunk_offset, array_state, compressed_size, + values_type); } - idx_t scan_offset = GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); + idx_t scan_offset = GetEffectiveOffset(array, NumericCast(parent_offset), chunk_offset, nested_offset); auto physical_type = run_ends_type.GetDuckType().InternalType(); switch (physical_type) { case PhysicalType::INT16: @@ -717,7 +693,7 @@ static void ColumnArrowToDuckDBRunEndEncoded(Vector &vector, const ArrowArray &a } template void ConvertDecimal(SRC src_ptr, Vector &vector, ArrowArray &array, idx_t size, int64_t nested_offset, - uint64_t parent_offset, ArrowScanLocalState &scan_state, ValidityMask &val_mask, + uint64_t parent_offset, idx_t chunk_offset, ValidityMask &val_mask, DecimalBitWidth arrow_bit_width) { switch (vector.GetType().InternalType()) { @@ -737,7 +713,7 @@ void ConvertDecimal(SRC src_ptr, Vector &vector, ArrowArray &array, idx_t size, FlatVector::SetData(vector, ArrowBufferData(array, 1) + GetTypeIdSize(vector.GetType().InternalType()) * GetEffectiveOffset(array, NumericCast(parent_offset), - scan_state, nested_offset)); + chunk_offset, nested_offset)); } else { auto tgt_ptr = FlatVector::GetData(vector); for (idx_t row = 0; row < size; row++) { @@ -755,7 +731,7 @@ void ConvertDecimal(SRC src_ptr, Vector &vector, ArrowArray &array, idx_t size, FlatVector::SetData(vector, ArrowBufferData(array, 1) + GetTypeIdSize(vector.GetType().InternalType()) * GetEffectiveOffset(array, NumericCast(parent_offset), - scan_state, nested_offset)); + chunk_offset, nested_offset)); } else { auto tgt_ptr = FlatVector::GetData(vector); for (idx_t row = 0; row < size; row++) { @@ -773,7 +749,7 @@ void ConvertDecimal(SRC src_ptr, Vector &vector, ArrowArray &array, idx_t size, FlatVector::SetData(vector, ArrowBufferData(array, 1) + GetTypeIdSize(vector.GetType().InternalType()) * GetEffectiveOffset(array, NumericCast(parent_offset), - scan_state, nested_offset)); + 
chunk_offset, nested_offset)); } else { auto tgt_ptr = FlatVector::GetData(vector); for (idx_t row = 0; row < size; row++) { @@ -793,17 +769,18 @@ void ConvertDecimal(SRC src_ptr, Vector &vector, ArrowArray &array, idx_t size, } } -static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, idx_t size, - const ArrowType &arrow_type, int64_t nested_offset, ValidityMask *parent_mask, - uint64_t parent_offset, bool ignore_extensions) { - auto &scan_state = array_state.state; +void ArrowToDuckDBConversion::ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, idx_t chunk_offset, + ArrowArrayScanState &array_state, idx_t size, + const ArrowType &arrow_type, int64_t nested_offset, + ValidityMask *parent_mask, uint64_t parent_offset, + bool ignore_extensions) { D_ASSERT(!array.dictionary); if (!ignore_extensions && arrow_type.HasExtension()) { if (arrow_type.extension_data->arrow_to_duckdb) { // Convert the storage and then call the cast function Vector input_data(arrow_type.extension_data->GetInternalType()); - ColumnArrowToDuckDB(input_data, array, array_state, size, arrow_type, nested_offset, parent_mask, - parent_offset, /*ignore_extensions*/ true); + ColumnArrowToDuckDB(input_data, array, chunk_offset, array_state, size, arrow_type, nested_offset, + parent_mask, parent_offset, /*ignore_extensions*/ true); arrow_type.extension_data->arrow_to_duckdb(array_state.context, input_data, vector, size); return; } @@ -820,7 +797,7 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca //! Arrow bit-packs boolean values //! Lets first figure out where we are in the source array auto effective_offset = - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); + GetEffectiveOffset(array, NumericCast(parent_offset), chunk_offset, nested_offset); auto src_ptr = ArrowBufferData(array, 1) + effective_offset / 8; auto tgt_ptr = (uint8_t *)FlatVector::GetData(vector); int src_pos = 0; @@ -856,15 +833,15 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca case LogicalTypeId::TIMESTAMP_MS: case LogicalTypeId::TIMESTAMP_NS: case LogicalTypeId::TIME_TZ: { - DirectConversion(vector, array, scan_state, nested_offset, parent_offset); + DirectConversion(vector, array, chunk_offset, nested_offset, parent_offset); break; } case LogicalTypeId::UUID: - UUIDConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size); + UUIDConversion(vector, array, chunk_offset, nested_offset, NumericCast(parent_offset), size); break; case LogicalTypeId::BLOB: case LogicalTypeId::BIT: - case LogicalTypeId::VARINT: + case LogicalTypeId::BIGNUM: case LogicalTypeId::VARCHAR: { auto &string_info = arrow_type.GetTypeInfo(); auto size_type = string_info.GetSizeType(); @@ -872,29 +849,29 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca case ArrowVariableSizeType::SUPER_SIZE: { auto cdata = ArrowBufferData(array, 2); auto offsets = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); + GetEffectiveOffset(array, NumericCast(parent_offset), chunk_offset, nested_offset); SetVectorString(vector, size, cdata, offsets); break; } case ArrowVariableSizeType::NORMAL: { auto cdata = ArrowBufferData(array, 2); auto offsets = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); + GetEffectiveOffset(array, NumericCast(parent_offset), chunk_offset, 
nested_offset); SetVectorString(vector, size, cdata, offsets); break; } case ArrowVariableSizeType::VIEW: { SetVectorStringView( vector, size, array, - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset)); + GetEffectiveOffset(array, NumericCast(parent_offset), chunk_offset, nested_offset)); break; } case ArrowVariableSizeType::FIXED_SIZE: { - SetValidityMask(vector, array, scan_state, size, NumericCast(parent_offset), nested_offset); + SetValidityMask(vector, array, chunk_offset, size, NumericCast(parent_offset), nested_offset); auto fixed_size = string_info.FixedSize(); // Have to check validity mask before setting this up - idx_t offset = - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset) * fixed_size; + idx_t offset = GetEffectiveOffset(array, NumericCast(parent_offset), chunk_offset, nested_offset) * + fixed_size; auto cdata = ArrowBufferData(array, 1); auto blob_len = fixed_size; auto result = FlatVector::GetData(vector); @@ -916,13 +893,13 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca auto precision = datetime_info.GetDateTimeType(); switch (precision) { case ArrowDateTimeType::DAYS: { - DirectConversion(vector, array, scan_state, nested_offset, parent_offset); + DirectConversion(vector, array, chunk_offset, nested_offset, parent_offset); break; } case ArrowDateTimeType::MILLISECONDS: { //! convert date from nanoseconds to days auto src_ptr = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); + GetEffectiveOffset(array, NumericCast(parent_offset), chunk_offset, nested_offset); auto tgt_ptr = FlatVector::GetData(vector); for (idx_t row = 0; row < size; row++) { tgt_ptr[row] = date_t(UnsafeNumericCast(static_cast(src_ptr[row]) / @@ -940,24 +917,24 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca auto precision = datetime_info.GetDateTimeType(); switch (precision) { case ArrowDateTimeType::SECONDS: { - TimeConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, - 1000000); + TimeConversion(vector, array, chunk_offset, nested_offset, NumericCast(parent_offset), + size, 1000000); break; } case ArrowDateTimeType::MILLISECONDS: { - TimeConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, - 1000); + TimeConversion(vector, array, chunk_offset, nested_offset, NumericCast(parent_offset), + size, 1000); break; } case ArrowDateTimeType::MICROSECONDS: { - TimeConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, - 1); + TimeConversion(vector, array, chunk_offset, nested_offset, NumericCast(parent_offset), + size, 1); break; } case ArrowDateTimeType::NANOSECONDS: { auto tgt_ptr = FlatVector::GetData(vector); auto src_ptr = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); + GetEffectiveOffset(array, NumericCast(parent_offset), chunk_offset, nested_offset); for (idx_t row = 0; row < size; row++) { tgt_ptr[row].micros = src_ptr[row] / 1000; } @@ -973,23 +950,23 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca auto precision = datetime_info.GetDateTimeType(); switch (precision) { case ArrowDateTimeType::SECONDS: { - TimestampTZConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, + TimestampTZConversion(vector, array, chunk_offset, nested_offset, NumericCast(parent_offset), size, 
1000000); break; } case ArrowDateTimeType::MILLISECONDS: { - TimestampTZConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, + TimestampTZConversion(vector, array, chunk_offset, nested_offset, NumericCast(parent_offset), size, 1000); break; } case ArrowDateTimeType::MICROSECONDS: { - DirectConversion(vector, array, scan_state, nested_offset, parent_offset); + DirectConversion(vector, array, chunk_offset, nested_offset, parent_offset); break; } case ArrowDateTimeType::NANOSECONDS: { auto tgt_ptr = FlatVector::GetData(vector); auto src_ptr = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); + GetEffectiveOffset(array, NumericCast(parent_offset), chunk_offset, nested_offset); for (idx_t row = 0; row < size; row++) { tgt_ptr[row].value = src_ptr[row] / 1000; } @@ -1005,25 +982,25 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca auto precision = datetime_info.GetDateTimeType(); switch (precision) { case ArrowDateTimeType::SECONDS: { - IntervalConversionUs(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, + IntervalConversionUs(vector, array, chunk_offset, nested_offset, NumericCast(parent_offset), size, 1000000); break; } case ArrowDateTimeType::DAYS: case ArrowDateTimeType::MILLISECONDS: { - IntervalConversionUs(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, + IntervalConversionUs(vector, array, chunk_offset, nested_offset, NumericCast(parent_offset), size, 1000); break; } case ArrowDateTimeType::MICROSECONDS: { - IntervalConversionUs(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, + IntervalConversionUs(vector, array, chunk_offset, nested_offset, NumericCast(parent_offset), size, 1); break; } case ArrowDateTimeType::NANOSECONDS: { auto tgt_ptr = FlatVector::GetData(vector); auto src_ptr = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); + GetEffectiveOffset(array, NumericCast(parent_offset), chunk_offset, nested_offset); for (idx_t row = 0; row < size; row++) { tgt_ptr[row].micros = src_ptr[row] / 1000; tgt_ptr[row].days = 0; @@ -1032,12 +1009,12 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca break; } case ArrowDateTimeType::MONTHS: { - IntervalConversionMonths(vector, array, scan_state, nested_offset, NumericCast(parent_offset), + IntervalConversionMonths(vector, array, chunk_offset, nested_offset, NumericCast(parent_offset), size); break; } case ArrowDateTimeType::MONTH_DAY_NANO: { - IntervalConversionMonthDayNanos(vector, array, scan_state, nested_offset, + IntervalConversionMonthDayNanos(vector, array, chunk_offset, nested_offset, NumericCast(parent_offset), size); break; } @@ -1054,22 +1031,25 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca switch (bit_width) { case DecimalBitWidth::DECIMAL_32: { auto src_ptr = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); - ConvertDecimal(src_ptr, vector, array, size, nested_offset, parent_offset, scan_state, val_mask, bit_width); + GetEffectiveOffset(array, NumericCast(parent_offset), chunk_offset, nested_offset); + ConvertDecimal(src_ptr, vector, array, size, nested_offset, parent_offset, chunk_offset, val_mask, + bit_width); break; } case DecimalBitWidth::DECIMAL_64: { auto src_ptr = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, 
NumericCast(parent_offset), scan_state, nested_offset); - ConvertDecimal(src_ptr, vector, array, size, nested_offset, parent_offset, scan_state, val_mask, bit_width); + GetEffectiveOffset(array, NumericCast(parent_offset), chunk_offset, nested_offset); + ConvertDecimal(src_ptr, vector, array, size, nested_offset, parent_offset, chunk_offset, val_mask, + bit_width); break; } case DecimalBitWidth::DECIMAL_128: { auto src_ptr = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); - ConvertDecimal(src_ptr, vector, array, size, nested_offset, parent_offset, scan_state, val_mask, bit_width); + GetEffectiveOffset(array, NumericCast(parent_offset), chunk_offset, nested_offset); + ConvertDecimal(src_ptr, vector, array, size, nested_offset, parent_offset, chunk_offset, val_mask, + bit_width); break; } default: @@ -1078,17 +1058,17 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca break; } case LogicalTypeId::LIST: { - ArrowToDuckDBList(vector, array, array_state, size, arrow_type, nested_offset, parent_mask, + ArrowToDuckDBList(vector, array, chunk_offset, array_state, size, arrow_type, nested_offset, parent_mask, NumericCast(parent_offset)); break; } case LogicalTypeId::ARRAY: { - ArrowToDuckDBArray(vector, array, array_state, size, arrow_type, nested_offset, parent_mask, + ArrowToDuckDBArray(vector, array, chunk_offset, array_state, size, arrow_type, nested_offset, parent_mask, NumericCast(parent_offset)); break; } case LogicalTypeId::MAP: { - ArrowToDuckDBList(vector, array, array_state, size, arrow_type, nested_offset, parent_mask, + ArrowToDuckDBList(vector, array, chunk_offset, array_state, size, arrow_type, nested_offset, parent_mask, NumericCast(parent_offset)); ArrowToDuckDBMapVerify(vector, size); break; @@ -1104,7 +1084,8 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca auto &child_type = struct_info.GetChild(child_idx); auto &child_state = array_state.GetChild(child_idx); - SetValidityMask(child_entry, child_array, scan_state, size, array.offset, nested_offset); + ArrowToDuckDBConversion::SetValidityMask(child_entry, child_array, chunk_offset, size, array.offset, + nested_offset); if (!struct_validity_mask.AllValid()) { auto &child_validity_mark = FlatVector::Validity(child_entry); for (idx_t i = 0; i < size; i++) { @@ -1114,19 +1095,21 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca } } - auto array_physical_type = GetArrowArrayPhysicalType(child_type); + auto array_physical_type = child_type.GetPhysicalType(); switch (array_physical_type) { case ArrowArrayPhysicalType::DICTIONARY_ENCODED: - ColumnArrowToDuckDBDictionary(child_entry, child_array, child_state, size, child_type, nested_offset, - &struct_validity_mask, NumericCast(array.offset)); + ArrowToDuckDBConversion::ColumnArrowToDuckDBDictionary( + child_entry, child_array, chunk_offset, child_state, size, child_type, nested_offset, + &struct_validity_mask, NumericCast(array.offset)); break; case ArrowArrayPhysicalType::RUN_END_ENCODED: - ColumnArrowToDuckDBRunEndEncoded(child_entry, child_array, child_state, size, child_type, nested_offset, - &struct_validity_mask, NumericCast(array.offset)); + ColumnArrowToDuckDBRunEndEncoded(child_entry, child_array, chunk_offset, child_state, size, child_type, + nested_offset, &struct_validity_mask, + NumericCast(array.offset)); break; case ArrowArrayPhysicalType::DEFAULT: - ColumnArrowToDuckDB(child_entry, child_array, child_state, 
size, child_type, nested_offset, - &struct_validity_mask, NumericCast(array.offset), false); + ColumnArrowToDuckDB(child_entry, child_array, chunk_offset, child_state, size, child_type, + nested_offset, &struct_validity_mask, NumericCast(array.offset), false); break; default: throw NotImplementedException("ArrowArrayPhysicalType not recognized"); @@ -1148,19 +1131,22 @@ static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArraySca auto &child_state = array_state.GetChild(child_idx); auto &child_type = union_info.GetChild(child_idx); - SetValidityMask(child, child_array, scan_state, size, NumericCast(parent_offset), nested_offset); - auto array_physical_type = GetArrowArrayPhysicalType(child_type); + ArrowToDuckDBConversion::SetValidityMask(child, child_array, chunk_offset, size, + NumericCast(parent_offset), nested_offset); + auto array_physical_type = child_type.GetPhysicalType(); switch (array_physical_type) { case ArrowArrayPhysicalType::DICTIONARY_ENCODED: - ColumnArrowToDuckDBDictionary(child, child_array, child_state, size, child_type); + ArrowToDuckDBConversion::ColumnArrowToDuckDBDictionary(child, child_array, chunk_offset, child_state, + size, child_type); break; case ArrowArrayPhysicalType::RUN_END_ENCODED: - ColumnArrowToDuckDBRunEndEncoded(child, child_array, child_state, size, child_type); + ArrowToDuckDBConversion::ColumnArrowToDuckDBRunEndEncoded(child, child_array, chunk_offset, child_state, + size, child_type); break; case ArrowArrayPhysicalType::DEFAULT: - ColumnArrowToDuckDB(child, child_array, child_state, size, child_type, nested_offset, &validity_mask, - false); + ArrowToDuckDBConversion::ColumnArrowToDuckDB(child, child_array, chunk_offset, child_state, size, + child_type, nested_offset, &validity_mask, false); break; default: throw NotImplementedException("ArrowArrayPhysicalType not recognized"); @@ -1321,33 +1307,34 @@ static bool CanContainNull(const ArrowArray &array, const ValidityMask *parent_m return !parent_mask->AllValid(); } -static void ColumnArrowToDuckDBDictionary(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, - idx_t size, const ArrowType &arrow_type, int64_t nested_offset, - const ValidityMask *parent_mask, uint64_t parent_offset) { +void ArrowToDuckDBConversion::ColumnArrowToDuckDBDictionary(Vector &vector, ArrowArray &array, idx_t chunk_offset, + ArrowArrayScanState &array_state, idx_t size, + const ArrowType &arrow_type, int64_t nested_offset, + const ValidityMask *parent_mask, uint64_t parent_offset) { if (vector.GetBuffer()) { vector.GetBuffer()->SetAuxiliaryData(make_uniq(array_state.owned_data)); } D_ASSERT(arrow_type.HasDictionary()); - auto &scan_state = array_state.state; const bool has_nulls = CanContainNull(array, parent_mask); if (array_state.CacheOutdated(array.dictionary)) { //! 
We need to set the dictionary data for this column
 		auto base_vector = make_uniq<Vector>(vector.GetType(), NumericCast<idx_t>(array.dictionary->length));
-		SetValidityMask(*base_vector, *array.dictionary, scan_state, NumericCast<idx_t>(array.dictionary->length), 0, 0,
-		                has_nulls);
+		ArrowToDuckDBConversion::SetValidityMask(*base_vector, *array.dictionary, chunk_offset,
+		                                         NumericCast<idx_t>(array.dictionary->length), 0, 0, has_nulls);
 		auto &dictionary_type = arrow_type.GetDictionary();
-		auto arrow_physical_type = GetArrowArrayPhysicalType(dictionary_type);
+		auto arrow_physical_type = dictionary_type.GetPhysicalType();
 		switch (arrow_physical_type) {
 		case ArrowArrayPhysicalType::DICTIONARY_ENCODED:
-			ColumnArrowToDuckDBDictionary(*base_vector, *array.dictionary, array_state,
+			ColumnArrowToDuckDBDictionary(*base_vector, *array.dictionary, chunk_offset, array_state,
 			                              NumericCast<idx_t>(array.dictionary->length), dictionary_type);
 			break;
 		case ArrowArrayPhysicalType::RUN_END_ENCODED:
-			ColumnArrowToDuckDBRunEndEncoded(*base_vector, *array.dictionary, array_state,
+			ColumnArrowToDuckDBRunEndEncoded(*base_vector, *array.dictionary, chunk_offset, array_state,
 			                                 NumericCast<idx_t>(array.dictionary->length), dictionary_type);
 			break;
 		case ArrowArrayPhysicalType::DEFAULT:
-			ColumnArrowToDuckDB(*base_vector, *array.dictionary, array_state,
+			ColumnArrowToDuckDB(*base_vector, *array.dictionary, chunk_offset, array_state,
 			                    NumericCast<idx_t>(array.dictionary->length), dictionary_type);
 			break;
 		default:
@@ -1359,12 +1346,12 @@ static void ColumnArrowToDuckDBDictionary(Vector &vector, ArrowArray &array, Arr
 	//! Get Pointer to Indices of Dictionary
 	auto indices = ArrowBufferData<data_t>(array, 1) +
 	               GetTypeIdSize(offset_type.InternalType()) *
-	                   GetEffectiveOffset(array, NumericCast<int64_t>(parent_offset), scan_state, nested_offset);
+	                   GetEffectiveOffset(array, NumericCast<int64_t>(parent_offset), chunk_offset, nested_offset);
 	SelectionVector sel;
 	if (has_nulls) {
 		ValidityMask indices_validity;
-		GetValidityMask(indices_validity, array, scan_state, size, NumericCast<int64_t>(parent_offset));
+		GetValidityMask(indices_validity, array, chunk_offset, size, NumericCast<int64_t>(parent_offset));
 		if (parent_mask && !parent_mask->AllValid()) {
 			auto &struct_validity_mask = *parent_mask;
 			for (idx_t i = 0; i < size; i++) {
@@ -1427,19 +1414,22 @@ void ArrowTableFunction::ArrowToDuckDB(ArrowScanLocalState &scan_state, const ar
 		if (!array_state.owned_data) {
 			array_state.owned_data = scan_state.chunk;
 		}
-
-		auto array_physical_type = GetArrowArrayPhysicalType(arrow_type);
+		auto array_physical_type = arrow_type.GetPhysicalType();
 		switch (array_physical_type) {
 		case ArrowArrayPhysicalType::DICTIONARY_ENCODED:
-			ColumnArrowToDuckDBDictionary(output.data[idx], array, array_state, output.size(), arrow_type);
+			ArrowToDuckDBConversion::ColumnArrowToDuckDBDictionary(output.data[idx], array, scan_state.chunk_offset,
+			                                                       array_state, output.size(), arrow_type);
 			break;
 		case ArrowArrayPhysicalType::RUN_END_ENCODED:
-			ColumnArrowToDuckDBRunEndEncoded(output.data[idx], array, array_state, output.size(), arrow_type);
+			ArrowToDuckDBConversion::ColumnArrowToDuckDBRunEndEncoded(output.data[idx], array, scan_state.chunk_offset,
+			                                                          array_state, output.size(), arrow_type);
 			break;
 		case ArrowArrayPhysicalType::DEFAULT:
-			SetValidityMask(output.data[idx], array, scan_state, output.size(), parent_array.offset, -1);
-			ColumnArrowToDuckDB(output.data[idx], array, array_state, output.size(), arrow_type);
+			ArrowToDuckDBConversion::SetValidityMask(output.data[idx], array, scan_state.chunk_offset, output.size(),
+			                                         parent_array.offset, -1);
+			ArrowToDuckDBConversion::ColumnArrowToDuckDB(output.data[idx], array, scan_state.chunk_offset, array_state,
+			                                             output.size(), arrow_type);
 			break;
 		default:
 			throw NotImplementedException("ArrowArrayPhysicalType not recognized");
diff --git a/src/duckdb/src/function/table/read_file.cpp b/src/duckdb/src/function/table/read_file.cpp
index 7ea4b9700..6be5618e5 100644
--- a/src/duckdb/src/function/table/read_file.cpp
+++ b/src/duckdb/src/function/table/read_file.cpp
@@ -138,7 +138,7 @@ static void ReadFileExecute(ClientContext &context, TableFunctionInput &input, D
 		if (FileSystem::IsRemoteFile(file.path)) {
 			flags |= FileFlags::FILE_FLAGS_DIRECT_IO;
 		}
-		file_handle = fs.OpenFile(file, flags);
+		file_handle = fs.OpenFile(QueryContext(context), file, flags);
 	}
 
 	for (idx_t col_idx = 0; col_idx < state.column_ids.size(); col_idx++) {
diff --git a/src/duckdb/src/function/table/system/duckdb_approx_database_count.cpp b/src/duckdb/src/function/table/system/duckdb_approx_database_count.cpp
new file mode 100644
index 000000000..228f38eb2
--- /dev/null
+++ b/src/duckdb/src/function/table/system/duckdb_approx_database_count.cpp
@@ -0,0 +1,43 @@
+#include "duckdb/function/table/system_functions.hpp"
+#include "duckdb/main/database_manager.hpp"
+
+namespace duckdb {
+
+struct DuckDBApproxDatabaseCountData : public GlobalTableFunctionState {
+	DuckDBApproxDatabaseCountData() : count(0), finished(false) {
+	}
+	idx_t count;
+	bool finished;
+};
+
+static unique_ptr<FunctionData> DuckDBApproxDatabaseCountBind(ClientContext &context, TableFunctionBindInput &input,
+                                                              vector<LogicalType> &return_types,
+                                                              vector<string> &names) {
+	names.emplace_back("approx_count");
+	return_types.emplace_back(LogicalType::UBIGINT);
+	return nullptr;
+}
+
+unique_ptr<GlobalTableFunctionState> DuckDBApproxDatabaseCountInit(ClientContext &context,
+                                                                   TableFunctionInitInput &input) {
+	auto result = make_uniq<DuckDBApproxDatabaseCountData>();
+	result->count = DatabaseManager::Get(context).ApproxDatabaseCount();
+	return std::move(result);
+}
+
+void DuckDBApproxDatabaseCountFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
+	auto &data = data_p.global_state->Cast<DuckDBApproxDatabaseCountData>();
+	if (data.finished) {
+		return;
+	}
+	output.SetValue(0, 0, Value::UBIGINT(data.count));
+	output.SetCardinality(1);
+	data.finished = true;
+}
+
+void DuckDBApproxDatabaseCountFun::RegisterFunction(BuiltinFunctions &set) {
+	set.AddFunction(TableFunction("duckdb_approx_database_count", {}, DuckDBApproxDatabaseCountFunction,
+	                              DuckDBApproxDatabaseCountBind, DuckDBApproxDatabaseCountInit));
+}
+
+} // namespace duckdb
diff --git a/src/duckdb/src/function/table/system/duckdb_settings.cpp b/src/duckdb/src/function/table/system/duckdb_settings.cpp
index b7c3a56b8..3ec34d908 100644
--- a/src/duckdb/src/function/table/system/duckdb_settings.cpp
+++ b/src/duckdb/src/function/table/system/duckdb_settings.cpp
@@ -7,10 +7,11 @@ namespace duckdb {
 
 struct DuckDBSettingValue {
 	string name;
-	string value;
+	Value value;
 	string description;
 	string input_type;
 	string scope;
+	vector<Value> aliases;
 };
 
 struct DuckDBSettingsData : public GlobalTableFunctionState {
@@ -38,12 +39,21 @@ static unique_ptr<FunctionData> DuckDBSettingsBind(ClientContext &context, Table
 	names.emplace_back("scope");
 	return_types.emplace_back(LogicalType::VARCHAR);
 
+	names.emplace_back("aliases");
+	return_types.emplace_back(LogicalType::LIST(LogicalType::VARCHAR));
+
 	return nullptr;
 }
 
 unique_ptr<GlobalTableFunctionState> DuckDBSettingsInit(ClientContext &context, TableFunctionInitInput &input) {
 	auto result = make_uniq<DuckDBSettingsData>();
 
+	unordered_map<idx_t, vector<Value>> aliases;
+	for (idx_t i = 0; i < DBConfig::GetAliasCount(); i++) {
+		auto alias =
DBConfig::GetAliasByIndex(i); + aliases[alias->option_index].emplace_back(alias->alias); + } + auto &config = DBConfig::GetConfig(context); auto options_count = DBConfig::GetOptionCount(); for (idx_t i = 0; i < options_count; i++) { @@ -52,25 +62,36 @@ unique_ptr DuckDBSettingsInit(ClientContext &context, DuckDBSettingValue value; auto scope = option->set_global ? SettingScope::GLOBAL : SettingScope::LOCAL; value.name = option->name; - value.value = option->get_setting(context).ToString(); + if (option->get_setting) { + value.value = option->get_setting(context); + } else { + auto lookup_result = context.TryGetCurrentSetting(value.name, value.value); + if (lookup_result) { + scope = lookup_result.GetScope(); + } else { + value.value = option->default_value; + } + } value.description = option->description; value.input_type = option->parameter_type; value.scope = EnumUtil::ToString(scope); + auto entry = aliases.find(i); + if (entry != aliases.end()) { + value.aliases = std::move(entry->second); + } result->settings.push_back(std::move(value)); } for (auto &ext_param : config.extension_parameters) { Value setting_val; - string setting_str_val; auto scope = SettingScope::GLOBAL; auto lookup_result = context.TryGetCurrentSetting(ext_param.first, setting_val); if (lookup_result) { - setting_str_val = setting_val.ToString(); scope = lookup_result.GetScope(); } DuckDBSettingValue value; value.name = ext_param.first; - value.value = std::move(setting_str_val); + value.value = std::move(setting_val); value.description = ext_param.second.description; value.input_type = ext_param.second.type.ToString(); value.scope = EnumUtil::ToString(scope); @@ -96,13 +117,15 @@ void DuckDBSettingsFunction(ClientContext &context, TableFunctionInput &data_p, // name, LogicalType::VARCHAR output.SetValue(0, count, Value(entry.name)); // value, LogicalType::VARCHAR - output.SetValue(1, count, Value(entry.value)); + output.SetValue(1, count, entry.value.CastAs(context, LogicalType::VARCHAR)); // description, LogicalType::VARCHAR output.SetValue(2, count, Value(entry.description)); // input_type, LogicalType::VARCHAR output.SetValue(3, count, Value(entry.input_type)); // scope, LogicalType::VARCHAR output.SetValue(4, count, Value(entry.scope)); + // aliases, LogicalType::VARCHAR[] + output.SetValue(5, count, Value::LIST(LogicalType::VARCHAR, std::move(entry.aliases))); count++; } output.SetCardinality(count); diff --git a/src/duckdb/src/function/table/system/test_all_types.cpp b/src/duckdb/src/function/table/system/test_all_types.cpp index 0fbcd884b..0523524ea 100644 --- a/src/duckdb/src/function/table/system/test_all_types.cpp +++ b/src/duckdb/src/function/table/system/test_all_types.cpp @@ -8,6 +8,8 @@ #include #include +#include "duckdb/common/types/bignum.hpp" + namespace duckdb { struct TestAllTypesData : public GlobalTableFunctionState { @@ -18,7 +20,7 @@ struct TestAllTypesData : public GlobalTableFunctionState { idx_t offset; }; -vector TestAllTypesFun::GetTestTypes(bool use_large_enum) { +vector TestAllTypesFun::GetTestTypes(bool use_large_enum, bool use_large_bignum) { vector result; // scalar types/numerics result.emplace_back(LogicalType::BOOLEAN, "bool"); @@ -32,7 +34,24 @@ vector TestAllTypesFun::GetTestTypes(bool use_large_enum) { result.emplace_back(LogicalType::USMALLINT, "usmallint"); result.emplace_back(LogicalType::UINTEGER, "uint"); result.emplace_back(LogicalType::UBIGINT, "ubigint"); - result.emplace_back(LogicalType::VARINT, "varint"); + if (use_large_bignum) { + string data; + idx_t 
total_data_size = Bignum::BIGNUM_HEADER_SIZE + Bignum::MAX_DATA_SIZE; + data.resize(total_data_size); + // Let's set our header + Bignum::SetHeader(&data[0], Bignum::MAX_DATA_SIZE, false); + // Set all our other bits + memset(&data[Bignum::BIGNUM_HEADER_SIZE], 0xFF, Bignum::MAX_DATA_SIZE); + auto max = Value::BIGNUM(data); + // Let's set our header + Bignum::SetHeader(&data[0], Bignum::MAX_DATA_SIZE, true); + // Set all our other bits + memset(&data[Bignum::BIGNUM_HEADER_SIZE], 0x00, Bignum::MAX_DATA_SIZE); + auto min = Value::BIGNUM(data); + result.emplace_back(LogicalType::BIGNUM, "bignum", min, max); + } else { + result.emplace_back(LogicalType::BIGNUM, "bignum"); + } result.emplace_back(LogicalType::DATE, "date"); result.emplace_back(LogicalType::TIME, "time"); result.emplace_back(LogicalType::TIMESTAMP, "timestamp"); @@ -298,11 +317,16 @@ static unique_ptr TestAllTypesBind(ClientContext &context, TableFu vector &return_types, vector &names) { auto result = make_uniq(); bool use_large_enum = false; + bool use_large_bignum = false; auto entry = input.named_parameters.find("use_large_enum"); if (entry != input.named_parameters.end()) { use_large_enum = BooleanValue::Get(entry->second); } - result->test_types = TestAllTypesFun::GetTestTypes(use_large_enum); + entry = input.named_parameters.find("use_large_bignum"); + if (entry != input.named_parameters.end()) { + use_large_bignum = BooleanValue::Get(entry->second); + } + result->test_types = TestAllTypesFun::GetTestTypes(use_large_enum, use_large_bignum); for (auto &test_type : result->test_types) { return_types.push_back(test_type.type); names.push_back(test_type.name); @@ -346,6 +370,7 @@ void TestAllTypesFunction(ClientContext &context, TableFunctionInput &data_p, Da void TestAllTypesFun::RegisterFunction(BuiltinFunctions &set) { TableFunction test_all_types("test_all_types", {}, TestAllTypesFunction, TestAllTypesBind, TestAllTypesInit); test_all_types.named_parameters["use_large_enum"] = LogicalType::BOOLEAN; + test_all_types.named_parameters["use_large_bignum"] = LogicalType::BOOLEAN; set.AddFunction(test_all_types); } diff --git a/src/duckdb/src/function/table/system_functions.cpp b/src/duckdb/src/function/table/system_functions.cpp index 6aa864ac5..9f1561b22 100644 --- a/src/duckdb/src/function/table/system_functions.cpp +++ b/src/duckdb/src/function/table/system_functions.cpp @@ -18,6 +18,7 @@ void BuiltinFunctions::RegisterSQLiteFunctions() { PragmaDatabaseSize::RegisterFunction(*this); PragmaUserAgent::RegisterFunction(*this); + DuckDBApproxDatabaseCountFun::RegisterFunction(*this); DuckDBColumnsFun::RegisterFunction(*this); DuckDBConstraintsFun::RegisterFunction(*this); DuckDBDatabasesFun::RegisterFunction(*this); diff --git a/src/duckdb/src/function/table/table_scan.cpp b/src/duckdb/src/function/table/table_scan.cpp index 08f8ae4c1..5fea970c0 100644 --- a/src/duckdb/src/function/table/table_scan.cpp +++ b/src/duckdb/src/function/table/table_scan.cpp @@ -25,7 +25,7 @@ #include "duckdb/planner/expression/bound_comparison_expression.hpp" #include "duckdb/planner/filter/conjunction_filter.hpp" #include "duckdb/common/types/value_map.hpp" - +#include "duckdb/main/settings.hpp" #include namespace duckdb { @@ -109,14 +109,19 @@ class TableScanGlobalState : public GlobalTableFunctionState { class DuckIndexScanState : public TableScanGlobalState { public: DuckIndexScanState(ClientContext &context, const FunctionData *bind_data_p) - : TableScanGlobalState(context, bind_data_p), next_batch_index(0), finished(false) { + : 
TableScanGlobalState(context, bind_data_p), next_batch_index(0), arena(Allocator::Get(context)),
+	      row_ids(nullptr), row_id_count(0), finished(false) {
 	}
 
 	//! The batch index of the next Sink.
 	//! Also determines the offset of the next chunk. I.e., offset = next_batch_index * STANDARD_VECTOR_SIZE.
 	atomic<idx_t> next_batch_index;
-	//! The total scanned row IDs.
-	unsafe_vector<row_t> row_ids;
+	//! The arena allocator containing the memory of the row IDs.
+	ArenaAllocator arena;
+	//! A pointer to the row IDs.
+	row_t *row_ids;
+	//! The number of scanned row IDs.
+	idx_t row_id_count;
 	//! The column IDs of the to-be-scanned columns.
 	vector<StorageIndex> column_ids;
 	//! True, if no more row IDs must be scanned.
@@ -154,7 +159,6 @@ class DuckIndexScanState : public TableScanGlobalState {
 		auto &storage = duck_table.GetStorage();
 		auto &l_state = data_p.local_state->Cast<TableScanLocalState>();
 
-		auto row_id_count = row_ids.size();
 		idx_t scan_count = 0;
 		idx_t offset = 0;
 
@@ -173,7 +177,7 @@ class DuckIndexScanState : public TableScanGlobalState {
 		}
 
 		if (scan_count != 0) {
-			auto row_id_data = (data_ptr_t)&row_ids[0 + offset]; // NOLINT - this is not pretty
+			auto row_id_data = reinterpret_cast<data_ptr_t>(row_ids + offset);
 			Vector local_vector(LogicalType::ROW_TYPE, row_id_data);
 
 			if (CanRemoveFilterColumns()) {
@@ -198,13 +202,12 @@ class DuckIndexScanState : public TableScanGlobalState {
 	}
 
 	double TableScanProgress(ClientContext &context, const FunctionData *bind_data_p) const override {
-		auto total_rows = row_ids.size();
-		if (total_rows == 0) {
+		if (row_id_count == 0) {
 			return 100;
 		}
 
 		auto scanned_rows = next_batch_index * STANDARD_VECTOR_SIZE;
-		auto percentage = 100 * (static_cast<double>(scanned_rows) / static_cast<double>(total_rows));
+		auto percentage = 100 * (static_cast<double>(scanned_rows) / static_cast<double>(row_id_count));
 		return percentage > 100 ? 100 : percentage;
 	}
 
@@ -339,26 +341,20 @@ unique_ptr<GlobalTableFunctionState> DuckTableScanInitGlobal(ClientContext &cont
 }
 
 unique_ptr<GlobalTableFunctionState> DuckIndexScanInitGlobal(ClientContext &context, TableFunctionInitInput &input,
-                                                             const TableScanBindData &bind_data,
-                                                             unsafe_vector<row_t> &row_ids) {
+                                                             const TableScanBindData &bind_data, set<row_t> &row_ids) {
 	auto g_state = make_uniq<DuckIndexScanState>(context, input.bind_data.get());
+	g_state->finished = row_ids.empty() ? true : false;
+
 	if (!row_ids.empty()) {
-		// Duplicate-eliminate row IDs.
-		unordered_set<row_t> row_id_set;
-		auto it = row_ids.begin();
-		while (it != row_ids.end()) {
-			if (row_id_set.find(*it) == row_id_set.end()) {
-				row_id_set.insert(*it++);
-				continue;
-			}
-			// Found a duplicate.
-			it = row_ids.erase(it);
-		}
+		auto row_id_ptr = g_state->arena.AllocateAligned(row_ids.size() * sizeof(row_t));
+		g_state->row_ids = reinterpret_cast<row_t *>(row_id_ptr);
+		g_state->row_id_count = row_ids.size();
 
-		std::sort(row_ids.begin(), row_ids.end());
-		g_state->row_ids = std::move(row_ids);
+		idx_t row_id_count = 0;
+		for (const auto row_id : row_ids) {
+			g_state->row_ids[row_id_count++] = row_id;
+		}
 	}
-	g_state->finished = g_state->row_ids.empty() ? true : false;
 
 	auto &duck_table = bind_data.table.Cast<DuckTableEntry>();
 	if (input.CanRemoveFilterColumns()) {
@@ -482,7 +478,7 @@ vector<unique_ptr<Expression>> ExtractFilterExpressions(const ColumnDefinition &
 }
 
 bool TryScanIndex(ART &art, const ColumnList &column_list, TableFunctionInitInput &input, TableFilterSet &filter_set,
-                  idx_t max_count, unsafe_vector<row_t> &row_ids) {
+                  idx_t max_count, set<row_t> &row_ids) {
 	// FIXME: No support for index scans on compound ARTs.
 	// See note above on multi-filter support.
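	// Review annotation, not part of the patch: row_ids is now an ordered set<row_t>,
	// so the IDs handed to DuckIndexScanInitGlobal are sorted and duplicate-free by
	// construction; that is what allows the old unordered_set dedup + std::sort pass
	// to be deleted above. The scan below then only needs to record matches, roughly
	// row_ids.insert(row_id) (the inserting code sits outside this hunk).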
if (art.unbound_expressions.size() > 1) { @@ -569,9 +565,8 @@ unique_ptr TableScanInitGlobal(ClientContext &context, return DuckTableScanInitGlobal(context, input, storage, bind_data); } - auto &db_config = DBConfig::GetConfig(context); - auto scan_percentage = db_config.GetSetting(context); - auto scan_max_count = db_config.GetSetting(context); + auto scan_percentage = DBConfig::GetSetting(context); + auto scan_max_count = DBConfig::GetSetting(context); auto total_rows = storage.GetTotalRows(); auto total_rows_from_percentage = LossyNumericCast(double(total_rows) * scan_percentage); @@ -579,9 +574,15 @@ unique_ptr TableScanInitGlobal(ClientContext &context, auto &column_list = duck_table.GetColumns(); bool index_scan = false; - unsafe_vector row_ids; + set row_ids; - info->GetIndexes().BindAndScan(context, *info, [&](ART &art) { + info->BindIndexes(context, ART::TYPE_NAME); + info->GetIndexes().Scan([&](Index &index) { + if (index.GetIndexType() != ART::TYPE_NAME) { + return false; + } + D_ASSERT(index.IsBound()); + auto &art = index.Cast(); index_scan = TryScanIndex(art, column_list, input, filter_set, max_count, row_ids); return index_scan; }); diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp index 82ca7e67b..976b568b1 100644 --- a/src/duckdb/src/function/table/version/pragma_version.cpp +++ b/src/duckdb/src/function/table/version/pragma_version.cpp @@ -1,5 +1,5 @@ #ifndef DUCKDB_PATCH_VERSION -#define DUCKDB_PATCH_VERSION "0-dev1655" +#define DUCKDB_PATCH_VERSION "0-dev2825" #endif #ifndef DUCKDB_MINOR_VERSION #define DUCKDB_MINOR_VERSION 4 @@ -8,10 +8,10 @@ #define DUCKDB_MAJOR_VERSION 1 #endif #ifndef DUCKDB_VERSION -#define DUCKDB_VERSION "v1.4.0-dev1655" +#define DUCKDB_VERSION "v1.4.0-dev2825" #endif #ifndef DUCKDB_SOURCE_ID -#define DUCKDB_SOURCE_ID "59382ca326" +#define DUCKDB_SOURCE_ID "3483d12aab" #endif #include "duckdb/function/table/system_functions.hpp" #include "duckdb/main/database.hpp" diff --git a/src/duckdb/src/function/window/window_aggregate_function.cpp b/src/duckdb/src/function/window/window_aggregate_function.cpp index 07fa8f3ec..20a512d9d 100644 --- a/src/duckdb/src/function/window/window_aggregate_function.cpp +++ b/src/duckdb/src/function/window/window_aggregate_function.cpp @@ -16,8 +16,9 @@ namespace duckdb { //===--------------------------------------------------------------------===// class WindowAggregateExecutorGlobalState : public WindowExecutorGlobalState { public: - WindowAggregateExecutorGlobalState(const WindowAggregateExecutor &executor, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask); + WindowAggregateExecutorGlobalState(ClientContext &client, const WindowAggregateExecutor &executor, + const idx_t payload_count, const ValidityMask &partition_mask, + const ValidityMask &order_mask); // aggregate global state unique_ptr gsink; @@ -48,19 +49,19 @@ static BoundWindowExpression &SimplifyWindowedAggregate(BoundWindowExpression &w return wexpr; } -WindowAggregateExecutor::WindowAggregateExecutor(BoundWindowExpression &wexpr, ClientContext &context, +WindowAggregateExecutor::WindowAggregateExecutor(BoundWindowExpression &wexpr, ClientContext &client, WindowSharedExpressions &shared, WindowAggregationMode mode) - : WindowExecutor(SimplifyWindowedAggregate(wexpr, context), context, shared), mode(mode) { + : WindowExecutor(SimplifyWindowedAggregate(wexpr, client), shared), mode(mode) { // Force naive for SEPARATE mode or for 
(currently!) unsupported functionality - if (!ClientConfig::GetConfig(context).enable_optimizer || mode == WindowAggregationMode::SEPARATE) { + if (!ClientConfig::GetConfig(client).enable_optimizer || mode == WindowAggregationMode::SEPARATE) { aggregator = make_uniq(*this, shared); } else if (WindowDistinctAggregator::CanAggregate(wexpr)) { // build a merge sort tree // see https://dl.acm.org/doi/pdf/10.1145/3514221.3526184 - aggregator = make_uniq(wexpr, shared, context); + aggregator = make_uniq(wexpr, shared, client); } else if (WindowConstantAggregator::CanAggregate(wexpr)) { - aggregator = make_uniq(wexpr, shared, context); + aggregator = make_uniq(wexpr, shared, client); } else if (WindowCustomAggregator::CanAggregate(wexpr, mode)) { aggregator = make_uniq(wexpr, shared); } else if (WindowSegmentTree::CanAggregate(wexpr)) { @@ -80,25 +81,28 @@ WindowAggregateExecutor::WindowAggregateExecutor(BoundWindowExpression &wexpr, C } } -WindowAggregateExecutorGlobalState::WindowAggregateExecutorGlobalState(const WindowAggregateExecutor &executor, +WindowAggregateExecutorGlobalState::WindowAggregateExecutorGlobalState(ClientContext &client, + const WindowAggregateExecutor &executor, const idx_t group_count, const ValidityMask &partition_mask, const ValidityMask &order_mask) - : WindowExecutorGlobalState(executor, group_count, partition_mask, order_mask), + : WindowExecutorGlobalState(client, executor, group_count, partition_mask, order_mask), filter_ref(executor.filter_ref.get()) { - gsink = executor.aggregator->GetGlobalState(executor.context, group_count, partition_mask); + gsink = executor.aggregator->GetGlobalState(client, group_count, partition_mask); } -unique_ptr WindowAggregateExecutor::GetGlobalState(const idx_t payload_count, +unique_ptr WindowAggregateExecutor::GetGlobalState(ClientContext &client, + const idx_t payload_count, const ValidityMask &partition_mask, const ValidityMask &order_mask) const { - return make_uniq(*this, payload_count, partition_mask, order_mask); + return make_uniq(client, *this, payload_count, partition_mask, order_mask); } -class WindowAggregateExecutorLocalState : public WindowExecutorBoundsState { +class WindowAggregateExecutorLocalState : public WindowExecutorBoundsLocalState { public: - WindowAggregateExecutorLocalState(const WindowExecutorGlobalState &gstate, const WindowAggregator &aggregator) - : WindowExecutorBoundsState(gstate), filter_executor(gstate.executor.context) { + WindowAggregateExecutorLocalState(ExecutionContext &context, const WindowExecutorGlobalState &gstate, + const WindowAggregator &aggregator) + : WindowExecutorBoundsLocalState(context, gstate), filter_executor(gstate.client) { auto &gastate = gstate.Cast(); aggregator_state = aggregator.GetLocalState(*gastate.gsink); @@ -121,12 +125,13 @@ class WindowAggregateExecutorLocalState : public WindowExecutorBoundsState { }; unique_ptr -WindowAggregateExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { - return make_uniq(gstate, *aggregator); +WindowAggregateExecutor::GetLocalState(ExecutionContext &context, const WindowExecutorGlobalState &gstate) const { + return make_uniq(context, gstate, *aggregator); } -void WindowAggregateExecutor::Sink(DataChunk &sink_chunk, DataChunk &coll_chunk, const idx_t input_idx, - WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate) const { +void WindowAggregateExecutor::Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, + const idx_t input_idx, WindowExecutorGlobalState &gstate, + 
WindowExecutorLocalState &lstate) const { auto &gastate = gstate.Cast(); auto &lastate = lstate.Cast(); auto &filter_sel = lastate.filter_sel; @@ -142,9 +147,9 @@ void WindowAggregateExecutor::Sink(DataChunk &sink_chunk, DataChunk &coll_chunk, D_ASSERT(aggregator); auto &gestate = *gastate.gsink; auto &lestate = *lastate.aggregator_state; - aggregator->Sink(gestate, lestate, sink_chunk, coll_chunk, input_idx, filtering, filtered); + aggregator->Sink(context, gestate, lestate, sink_chunk, coll_chunk, input_idx, filtering, filtered); - WindowExecutor::Sink(sink_chunk, coll_chunk, input_idx, gstate, lstate); + WindowExecutor::Sink(context, sink_chunk, coll_chunk, input_idx, gstate, lstate); } static void ApplyWindowStats(const WindowBoundary &boundary, FrameDelta &delta, BaseStatistics *base, bool is_start) { @@ -210,9 +215,9 @@ static void ApplyWindowStats(const WindowBoundary &boundary, FrameDelta &delta, } } -void WindowAggregateExecutor::Finalize(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - CollectionPtr collection) const { - WindowExecutor::Finalize(gstate, lstate, collection); +void WindowAggregateExecutor::Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, CollectionPtr collection) const { + WindowExecutor::Finalize(context, gstate, lstate, collection); auto &gastate = gstate.Cast(); auto &gsink = gastate.gsink; @@ -234,12 +239,12 @@ void WindowAggregateExecutor::Finalize(WindowExecutorGlobalState &gstate, Window ApplyWindowStats(wexpr.end, stats[1], base, false); auto &lastate = lstate.Cast(); - aggregator->Finalize(*gsink, *lastate.aggregator_state, collection, stats); + aggregator->Finalize(context, *gsink, *lastate.aggregator_state, collection, stats); } -void WindowAggregateExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx) const { +void WindowAggregateExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, + idx_t count, idx_t row_idx) const { auto &gastate = gstate.Cast(); auto &lastate = lstate.Cast(); auto &gsink = gastate.gsink; @@ -247,7 +252,7 @@ void WindowAggregateExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate auto &agg_state = *lastate.aggregator_state; - aggregator->Evaluate(*gsink, agg_state, lastate.bounds, result, count, row_idx); + aggregator->Evaluate(context, *gsink, agg_state, lastate.bounds, result, count, row_idx); } } // namespace duckdb diff --git a/src/duckdb/src/function/window/window_aggregator.cpp b/src/duckdb/src/function/window/window_aggregator.cpp index 197b89e38..10a3d53a8 100644 --- a/src/duckdb/src/function/window/window_aggregator.cpp +++ b/src/duckdb/src/function/window/window_aggregator.cpp @@ -36,16 +36,16 @@ unique_ptr WindowAggregator::GetGlobalState(ClientContext return make_uniq(context, *this, group_count); } -void WindowAggregatorLocalState::Sink(WindowAggregatorGlobalState &gastate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx) { +void WindowAggregatorLocalState::Sink(ExecutionContext &context, WindowAggregatorGlobalState &gastate, + DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx) { } -void WindowAggregator::Sink(WindowAggregatorState &gstate, WindowAggregatorState &lstate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx, optional_ptr filter_sel, - idx_t filtered) { +void 
WindowAggregator::Sink(ExecutionContext &context, WindowAggregatorState &gstate, WindowAggregatorState &lstate, + DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, + optional_ptr filter_sel, idx_t filtered) { auto &gastate = gstate.Cast(); auto &lastate = lstate.Cast(); - lastate.Sink(gastate, sink_chunk, coll_chunk, input_idx); + lastate.Sink(context, gastate, sink_chunk, coll_chunk, input_idx); if (filter_sel) { auto &filter_mask = gastate.filter_mask; for (idx_t f = 0; f < filtered; ++f) { @@ -71,18 +71,19 @@ void WindowAggregatorLocalState::InitSubFrames(SubFrames &frames, const WindowEx frames.resize(nframes, {0, 0}); } -void WindowAggregatorLocalState::Finalize(WindowAggregatorGlobalState &gastate, CollectionPtr collection) { +void WindowAggregatorLocalState::Finalize(ExecutionContext &context, WindowAggregatorGlobalState &gastate, + CollectionPtr collection) { // Prepare to scan if (!cursor) { cursor = make_uniq(*collection, gastate.aggregator.child_idx); } } -void WindowAggregator::Finalize(WindowAggregatorState &gstate, WindowAggregatorState &lstate, CollectionPtr collection, - const FrameStats &stats) { +void WindowAggregator::Finalize(ExecutionContext &context, WindowAggregatorState &gstate, WindowAggregatorState &lstate, + CollectionPtr collection, const FrameStats &stats) { auto &gasink = gstate.Cast(); auto &lastate = lstate.Cast(); - lastate.Finalize(gasink, collection); + lastate.Finalize(context, gasink, collection); } } // namespace duckdb diff --git a/src/duckdb/src/function/window/window_collection.cpp b/src/duckdb/src/function/window/window_collection.cpp index 0dee0cc84..f24de2237 100644 --- a/src/duckdb/src/function/window/window_collection.cpp +++ b/src/duckdb/src/function/window/window_collection.cpp @@ -143,4 +143,34 @@ WindowCursor::WindowCursor(const WindowCollection &paged, column_t col_idx) : WindowCursor(paged, vector(1, col_idx)) { } +LogicalType WindowCollectionChunkScanner::PrefixStructType(column_t end, column_t begin) { + child_list_t partition_children; + for (auto c = begin; c < end; ++c) { + auto name = std::to_string(c); + auto type = chunk.data[c].GetType(); + std::pair child {name, type}; + partition_children.emplace_back(child); + } + // For single children, don't build a struct - compare will be slow + if (partition_children.size() == 1) { + return partition_children[0].second; + } + return LogicalType::STRUCT(partition_children); +} + +void WindowCollectionChunkScanner::ReferenceStructColumns(DataChunk &chunk, Vector &vec, column_t end, column_t begin) { + // Check for single column + const auto width = end - begin; + if (width == 1) { + vec.Reference(chunk.data[begin]); + return; + } + + auto &entries = StructVector::GetEntries(vec); + D_ASSERT(width == entries.size()); + for (column_t i = 0; i < entries.size(); ++i) { + entries[i]->Reference(chunk.data[begin + i]); + } +} + } // namespace duckdb diff --git a/src/duckdb/src/function/window/window_constant_aggregator.cpp b/src/duckdb/src/function/window/window_constant_aggregator.cpp index 7ae1784b1..ca51b5be2 100644 --- a/src/duckdb/src/function/window/window_constant_aggregator.cpp +++ b/src/duckdb/src/function/window/window_constant_aggregator.cpp @@ -81,8 +81,8 @@ class WindowConstantAggregatorLocalState : public WindowAggregatorLocalState { ~WindowConstantAggregatorLocalState() override { } - void Sink(DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, optional_ptr filter_sel, - idx_t filtered); + void Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk 
&coll_chunk, idx_t input_idx, + optional_ptr filter_sel, idx_t filtered); void Combine(WindowConstantAggregatorGlobalState &gstate); public: @@ -206,16 +206,16 @@ unique_ptr WindowConstantAggregator::GetGlobalState(Clien return make_uniq(context, *this, group_count, partition_mask); } -void WindowConstantAggregator::Sink(WindowAggregatorState &gsink, WindowAggregatorState &lstate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx, optional_ptr filter_sel, - idx_t filtered) { +void WindowConstantAggregator::Sink(ExecutionContext &context, WindowAggregatorState &gsink, + WindowAggregatorState &lstate, DataChunk &sink_chunk, DataChunk &coll_chunk, + idx_t input_idx, optional_ptr filter_sel, idx_t filtered) { auto &lastate = lstate.Cast(); - lastate.Sink(sink_chunk, coll_chunk, input_idx, filter_sel, filtered); + lastate.Sink(context, sink_chunk, coll_chunk, input_idx, filter_sel, filtered); } -void WindowConstantAggregatorLocalState::Sink(DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t row, - optional_ptr filter_sel, idx_t filtered) { +void WindowConstantAggregatorLocalState::Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, + idx_t row, optional_ptr filter_sel, idx_t filtered) { auto &partition_offsets = gstate.partition_offsets; const auto &aggr = gstate.aggr; const auto chunk_begin = row; @@ -298,8 +298,9 @@ void WindowConstantAggregatorLocalState::Sink(DataChunk &sink_chunk, DataChunk & } } -void WindowConstantAggregator::Finalize(WindowAggregatorState &gstate, WindowAggregatorState &lstate, - CollectionPtr collection, const FrameStats &stats) { +void WindowConstantAggregator::Finalize(ExecutionContext &context, WindowAggregatorState &gstate, + WindowAggregatorState &lstate, CollectionPtr collection, + const FrameStats &stats) { auto &gastate = gstate.Cast(); auto &lastate = lstate.Cast(); @@ -317,8 +318,9 @@ unique_ptr WindowConstantAggregator::GetLocalState(const return make_uniq(gstate.Cast()); } -void WindowConstantAggregator::Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, - const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const { +void WindowConstantAggregator::Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, + WindowAggregatorState &lstate, const DataChunk &bounds, Vector &result, + idx_t count, idx_t row_idx) const { auto &gasink = gsink.Cast(); const auto &partition_offsets = gasink.partition_offsets; const auto &results = *gasink.results; diff --git a/src/duckdb/src/function/window/window_custom_aggregator.cpp b/src/duckdb/src/function/window/window_custom_aggregator.cpp index 8416e3031..70f288969 100644 --- a/src/duckdb/src/function/window/window_custom_aggregator.cpp +++ b/src/duckdb/src/function/window/window_custom_aggregator.cpp @@ -1,5 +1,6 @@ #include "duckdb/function/window/window_custom_aggregator.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" +#include "duckdb/common/enums/window_aggregation_mode.hpp" namespace duckdb { @@ -30,10 +31,10 @@ WindowCustomAggregator::WindowCustomAggregator(const BoundWindowExpression &wexp WindowCustomAggregator::~WindowCustomAggregator() { } -class WindowCustomAggregatorState : public WindowAggregatorLocalState { +class WindowCustomAggregatorLocalState : public WindowAggregatorLocalState { public: - WindowCustomAggregatorState(const AggregateObject &aggr, const WindowExcludeMode exclude_mode); - ~WindowCustomAggregatorState() override; + WindowCustomAggregatorLocalState(const AggregateObject &aggr, 
const WindowExcludeMode exclude_mode); + ~WindowCustomAggregatorLocalState() override; public: //! The aggregate function @@ -51,8 +52,7 @@ class WindowCustomAggregatorGlobalState : public WindowAggregatorGlobalState { explicit WindowCustomAggregatorGlobalState(ClientContext &context, const WindowCustomAggregator &aggregator, idx_t group_count) : WindowAggregatorGlobalState(context, aggregator, group_count), context(context) { - - gcstate = make_uniq(aggr, aggregator.exclude_mode); + gcstate = make_uniq(aggr, aggregator.exclude_mode); } //! Buffer manager for paging custom accelerator data @@ -60,13 +60,13 @@ class WindowCustomAggregatorGlobalState : public WindowAggregatorGlobalState { //! Traditional packed filter mask for API ValidityMask filter_packed; //! Data pointer that contains a single local state, used for global custom window execution state - unique_ptr gcstate; + unique_ptr gcstate; //! Partition description for custom window APIs unique_ptr partition_input; }; -WindowCustomAggregatorState::WindowCustomAggregatorState(const AggregateObject &aggr, - const WindowExcludeMode exclude_mode) +WindowCustomAggregatorLocalState::WindowCustomAggregatorLocalState(const AggregateObject &aggr, + const WindowExcludeMode exclude_mode) : aggr(aggr), state(aggr.function.state_size(aggr.function)), statef(Value::POINTER(CastPointerToValue(state.data()))), frames(3, {0, 0}) { // if we have a frame-by-frame method, share the single state @@ -75,7 +75,7 @@ WindowCustomAggregatorState::WindowCustomAggregatorState(const AggregateObject & InitSubFrames(frames, exclude_mode); } -WindowCustomAggregatorState::~WindowCustomAggregatorState() { +WindowCustomAggregatorLocalState::~WindowCustomAggregatorLocalState() { if (aggr.function.destructor) { AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); aggr.function.destructor(statef, aggr_input_data, 1); @@ -87,8 +87,9 @@ unique_ptr WindowCustomAggregator::GetGlobalState(ClientC return make_uniq(context, *this, group_count); } -void WindowCustomAggregator::Finalize(WindowAggregatorState &gstate, WindowAggregatorState &lstate, - CollectionPtr collection, const FrameStats &stats) { +void WindowCustomAggregator::Finalize(ExecutionContext &context, WindowAggregatorState &gstate, + WindowAggregatorState &lstate, CollectionPtr collection, + const FrameStats &stats) { // Single threaded Finalize for now auto &gcsink = gstate.Cast(); lock_guard gestate_guard(gcsink.lock); @@ -96,7 +97,7 @@ void WindowCustomAggregator::Finalize(WindowAggregatorState &gstate, WindowAggre return; } - WindowAggregator::Finalize(gstate, lstate, collection, stats); + WindowAggregator::Finalize(context, gstate, lstate, collection, stats); auto inputs = collection->inputs.get(); const auto count = collection->size(); @@ -109,7 +110,7 @@ void WindowCustomAggregator::Finalize(WindowAggregatorState &gstate, WindowAggre filter_mask.Pack(filter_packed, filter_mask.Capacity()); gcsink.partition_input = - make_uniq(gcsink.context, inputs, count, child_idx, all_valids, filter_packed, stats); + make_uniq(context, inputs, count, child_idx, all_valids, filter_packed, stats); if (aggr.function.window_init) { auto &gcstate = *gcsink.gcstate; @@ -122,12 +123,13 @@ void WindowCustomAggregator::Finalize(WindowAggregatorState &gstate, WindowAggre } unique_ptr WindowCustomAggregator::GetLocalState(const WindowAggregatorState &gstate) const { - return make_uniq(aggr, exclude_mode); + return make_uniq(aggr, exclude_mode); } -void WindowCustomAggregator::Evaluate(const WindowAggregatorState 
&gsink, WindowAggregatorState &lstate, - const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const { - auto &lcstate = lstate.Cast(); +void WindowCustomAggregator::Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, + WindowAggregatorState &lstate, const DataChunk &bounds, Vector &result, + idx_t count, idx_t row_idx) const { + auto &lcstate = lstate.Cast(); auto &frames = lcstate.frames; const_data_ptr_t gstate_p = nullptr; auto &gcsink = gsink.Cast(); diff --git a/src/duckdb/src/function/window/window_distinct_aggregator.cpp b/src/duckdb/src/function/window/window_distinct_aggregator.cpp index 0f98160d4..a868335e7 100644 --- a/src/duckdb/src/function/window/window_distinct_aggregator.cpp +++ b/src/duckdb/src/function/window/window_distinct_aggregator.cpp @@ -1,11 +1,10 @@ #include "duckdb/function/window/window_distinct_aggregator.hpp" -#include "duckdb/common/sort/partition_state.hpp" -#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/sorting/sort.hpp" #include "duckdb/execution/merge_sort_tree.hpp" #include "duckdb/function/window/window_aggregate_states.hpp" #include "duckdb/planner/bound_result_modifier.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" #include @@ -13,6 +12,8 @@ namespace duckdb { +enum class WindowDistinctSortStage : uint8_t { INIT, COMBINE, FINALIZE, SORTED, FINISHED }; + //===--------------------------------------------------------------------===// // WindowDistinctAggregator //===--------------------------------------------------------------------===// @@ -39,7 +40,7 @@ class WindowDistinctSortTree : public MergeSortTree { using ZippedTuple = std::tuple; using ZippedElements = vector; - explicit WindowDistinctSortTree(WindowDistinctAggregatorGlobalState &gdastate, idx_t count) : gdastate(gdastate) { + WindowDistinctSortTree(WindowDistinctAggregatorGlobalState &gdastate, idx_t count) : gdastate(gdastate) { // Set up for parallel build build_level = 0; build_complete = 0; @@ -59,31 +60,31 @@ class WindowDistinctSortTree : public MergeSortTree { class WindowDistinctAggregatorGlobalState : public WindowAggregatorGlobalState { public: - using GlobalSortStatePtr = unique_ptr; - using LocalSortStatePtr = unique_ptr; using ZippedTuple = WindowDistinctSortTree::ZippedTuple; using ZippedElements = WindowDistinctSortTree::ZippedElements; WindowDistinctAggregatorGlobalState(ClientContext &context, const WindowDistinctAggregator &aggregator, idx_t group_count); - //! Compute the block starts - void MeasurePayloadBlocks(); //! Create a new local sort - optional_ptr InitializeLocalSort() const; + optional_ptr InitializeLocalSort(ExecutionContext &context) const; - //! Patch up the previous index block boundaries - void PatchPrevIdcs(); - bool TryPrepareNextStage(WindowDistinctAggregatorLocalState &lstate); + ArenaAllocator &CreateTreeAllocator() const { + lock_guard tree_lock(lock); + tree_allocators.emplace_back(make_uniq(Allocator::DefaultAllocator())); + return *tree_allocators.back(); + } - // Single threaded sorting for now - ClientContext &context; - idx_t memory_per_thread; + bool TryPrepareNextStage(WindowDistinctAggregatorLocalState &lstate); + //! The tree allocators. + //! We need to hold onto them for the tree lifetime, + //! not the lifetime of the local state that constructed part of the tree + mutable vector> tree_allocators; //! 
Finalize guard mutable mutex lock; //! Finalize stage - atomic stage; + atomic stage; //! Tasks launched idx_t total_tasks = 0; //! Tasks assigned @@ -91,20 +92,18 @@ class WindowDistinctAggregatorGlobalState : public WindowAggregatorGlobalState { //! Tasks landed mutable atomic tasks_completed; - //! The sorted payload data types (partition index) - vector payload_types; //! The aggregate arguments + partition index vector sort_types; //! Sorting operations - GlobalSortStatePtr global_sort; - //! Local sort set - mutable vector local_sorts; - //! The block starts (the scanner doesn't know this) plus the total count - vector block_starts; - - //! The block boundary seconds - mutable ZippedElements seconds; + vector sort_cols; + unique_ptr sort; + unique_ptr global_sink; + //! Local sort sets + mutable vector> local_sinks; + //! The resulting sorted data + unique_ptr sorted; + //! The MST with the distinct back pointers mutable MergeSortTree zipped_tree; //! The merge sort tree for the aggregate. @@ -116,35 +115,29 @@ vector levels_flat_start; }; -WindowDistinctAggregatorGlobalState::WindowDistinctAggregatorGlobalState(ClientContext &context, +WindowDistinctAggregatorGlobalState::WindowDistinctAggregatorGlobalState(ClientContext &client, const WindowDistinctAggregator &aggregator, idx_t group_count) - : WindowAggregatorGlobalState(context, aggregator, group_count), context(aggregator.context), - stage(PartitionSortStage::INIT), tasks_assigned(0), tasks_completed(0), merge_sort_tree(*this, group_count), - levels_flat_native(aggr) { - payload_types.emplace_back(LogicalType::UBIGINT); + : WindowAggregatorGlobalState(client, aggregator, group_count), stage(WindowDistinctSortStage::INIT), + tasks_assigned(0), tasks_completed(0), merge_sort_tree(*this, group_count), levels_flat_native(aggr) { // 1: function ComputePrevIdcs(𝑖𝑛) // 2: sorted ← [] // We sort the aggregate arguments and use the partition index as a tie-breaker. // TODO: Use a hash table? 
sort_types = aggregator.arg_types; - for (const auto &type : payload_types) { - sort_types.emplace_back(type); - } + sort_types.emplace_back(LogicalType::UBIGINT); + // All expressions will be precomputed for sharing, so we just need to reference the arguments vector orders; for (const auto &type : sort_types) { - auto expr = make_uniq(Value(type)); + auto expr = make_uniq(type, orders.size()); orders.emplace_back(BoundOrderByNode(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, std::move(expr))); + sort_cols.emplace_back(sort_cols.size()); } - RowLayout payload_layout; - payload_layout.Initialize(payload_types); - - global_sort = make_uniq(context, orders, payload_layout); - - memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context); + sort = make_uniq(client, orders, sort_types, sort_cols); + global_sink = sort->GetGlobalSinkState(client); // 6: prevIdcs ← [] // 7: prevIdcs[0] ← “-” @@ -176,14 +169,13 @@ WindowDistinctAggregatorGlobalState::WindowDistinctAggregatorGlobalState(ClientC } } -optional_ptr WindowDistinctAggregatorGlobalState::InitializeLocalSort() const { +optional_ptr WindowDistinctAggregatorGlobalState::InitializeLocalSort(ExecutionContext &context) const { lock_guard local_sort_guard(lock); - auto local_sort = make_uniq(); - local_sort->Initialize(*global_sort, global_sort->buffer_manager); + auto local_sink = sort->GetLocalSinkState(context); ++tasks_assigned; - local_sorts.emplace_back(std::move(local_sort)); + local_sinks.emplace_back(std::move(local_sink)); - return local_sorts.back().get(); + return local_sinks.back().get(); } class WindowDistinctAggregatorLocalState : public WindowAggregatorLocalState { @@ -194,18 +186,20 @@ class WindowDistinctAggregatorLocalState : public WindowAggregatorLocalState { statef.Destroy(); } - void Sink(DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, optional_ptr filter_sel, - idx_t filtered); - void Finalize(WindowAggregatorGlobalState &gastate, CollectionPtr collection) override; + void Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, + optional_ptr filter_sel, idx_t filtered); + void Finalize(ExecutionContext &context, WindowAggregatorGlobalState &gastate, CollectionPtr collection) override; void Sorted(); - void ExecuteTask(); - void Evaluate(const WindowDistinctAggregatorGlobalState &gdstate, const DataChunk &bounds, Vector &result, - idx_t count, idx_t row_idx); + void ExecuteTask(ExecutionContext &context, WindowDistinctAggregatorGlobalState &gdstate); + void Evaluate(ExecutionContext &context, const WindowDistinctAggregatorGlobalState &gdstate, + const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx); + //! The thread-local allocator for building the tree + ArenaAllocator &tree_allocator; //! Thread-local sorting data - optional_ptr local_sort; + optional_ptr local_sink; //! Finalize stage - PartitionSortStage stage = PartitionSortStage::INIT; + WindowDistinctSortStage stage = WindowDistinctSortStage::INIT; //! Finalize scan block index idx_t block_idx; //! Thread-local tree aggregation @@ -220,9 +214,9 @@ class WindowDistinctAggregatorLocalState : public WindowAggregatorLocalState { void FlushStates(); //! The aggregator we are working with - const WindowDistinctAggregatorGlobalState &gastate; + const WindowDistinctAggregatorGlobalState &gdstate; + //! The sort input chunk DataChunk sort_chunk; - DataChunk payload_chunk; //! Reused result state container for the window functions WindowAggregateStates statef; //! 
A vector of pointers to "state", used for buffering intermediate aggregates @@ -236,16 +230,15 @@ class WindowDistinctAggregatorLocalState : public WindowAggregatorLocalState { }; WindowDistinctAggregatorLocalState::WindowDistinctAggregatorLocalState( - const WindowDistinctAggregatorGlobalState &gastate) - : update_v(LogicalType::POINTER), source_v(LogicalType::POINTER), target_v(LogicalType::POINTER), gastate(gastate), - statef(gastate.aggr), statep(LogicalType::POINTER), statel(LogicalType::POINTER), flush_count(0) { - InitSubFrames(frames, gastate.aggregator.exclude_mode); - payload_chunk.Initialize(Allocator::DefaultAllocator(), gastate.payload_types); + const WindowDistinctAggregatorGlobalState &gdstate) + : tree_allocator(gdstate.CreateTreeAllocator()), update_v(LogicalType::POINTER), source_v(LogicalType::POINTER), + target_v(LogicalType::POINTER), gdstate(gdstate), statef(gdstate.aggr), statep(LogicalType::POINTER), + statel(LogicalType::POINTER), flush_count(0) { + InitSubFrames(frames, gdstate.aggregator.exclude_mode); - sort_chunk.Initialize(Allocator::DefaultAllocator(), gastate.sort_types); - sort_chunk.data.back().Reference(payload_chunk.data[0]); + sort_chunk.Initialize(Allocator::DefaultAllocator(), gdstate.sort_types); - gastate.locals++; + gdstate.locals++; } unique_ptr WindowDistinctAggregator::GetGlobalState(ClientContext &context, idx_t group_count, @@ -253,190 +246,159 @@ unique_ptr WindowDistinctAggregator::GetGlobalState(Clien return make_uniq(context, *this, group_count); } -void WindowDistinctAggregator::Sink(WindowAggregatorState &gsink, WindowAggregatorState &lstate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx, optional_ptr filter_sel, - idx_t filtered) { - WindowAggregator::Sink(gsink, lstate, sink_chunk, coll_chunk, input_idx, filter_sel, filtered); +void WindowDistinctAggregator::Sink(ExecutionContext &context, WindowAggregatorState &gsink, + WindowAggregatorState &lstate, DataChunk &sink_chunk, DataChunk &coll_chunk, + idx_t input_idx, optional_ptr filter_sel, idx_t filtered) { + WindowAggregator::Sink(context, gsink, lstate, sink_chunk, coll_chunk, input_idx, filter_sel, filtered); auto &ldstate = lstate.Cast(); - ldstate.Sink(sink_chunk, coll_chunk, input_idx, filter_sel, filtered); + ldstate.Sink(context, sink_chunk, coll_chunk, input_idx, filter_sel, filtered); } -void WindowDistinctAggregatorLocalState::Sink(DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, - optional_ptr filter_sel, idx_t filtered) { +void WindowDistinctAggregatorLocalState::Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, + idx_t input_idx, optional_ptr filter_sel, + idx_t filtered) { // 3: for i ← 0 to in.size do // 4: sorted[i] ← (in[i], i) const auto count = sink_chunk.size(); - payload_chunk.Reset(); - auto &sorted_vec = payload_chunk.data[0]; + sort_chunk.Reset(); + auto &sorted_vec = sort_chunk.data.back(); auto sorted = FlatVector::GetData(sorted_vec); std::iota(sorted, sorted + count, input_idx); // Our arguments are being fully materialised, // but we also need them as sort keys. 
- auto &child_idx = gastate.aggregator.child_idx; + auto &child_idx = gdstate.aggregator.child_idx; for (column_t c = 0; c < child_idx.size(); ++c) { sort_chunk.data[c].Reference(coll_chunk.data[child_idx[c]]); } - sort_chunk.data.back().Reference(sorted_vec); sort_chunk.SetCardinality(sink_chunk); - payload_chunk.SetCardinality(sort_chunk); // Apply FILTER clause, if any if (filter_sel) { sort_chunk.Slice(*filter_sel, filtered); - payload_chunk.Slice(*filter_sel, filtered); } - if (!local_sort) { - local_sort = gastate.InitializeLocalSort(); + if (!local_sink) { + local_sink = gdstate.InitializeLocalSort(context); } - local_sort->SinkChunk(sort_chunk, payload_chunk); - - if (local_sort->SizeInBytes() > gastate.memory_per_thread) { - local_sort->Sort(*gastate.global_sort, true); - } + InterruptState interrupt_state; + OperatorSinkInput sink {*gdstate.global_sink, *local_sink, interrupt_state}; + gdstate.sort->Sink(context, sort_chunk, sink); } -void WindowDistinctAggregatorLocalState::Finalize(WindowAggregatorGlobalState &gastate, CollectionPtr collection) { - WindowAggregatorLocalState::Finalize(gastate, collection); +void WindowDistinctAggregatorLocalState::Finalize(ExecutionContext &context, WindowAggregatorGlobalState &gastate, + CollectionPtr collection) { + WindowAggregatorLocalState::Finalize(context, gastate, collection); //! Input data chunk, used for leaf segment aggregation leaves.Initialize(Allocator::DefaultAllocator(), cursor->chunk.GetTypes()); sel.Initialize(); } -void WindowDistinctAggregatorLocalState::ExecuteTask() { - auto &global_sort = *gastate.global_sort; +void WindowDistinctAggregatorLocalState::ExecuteTask(ExecutionContext &context, + WindowDistinctAggregatorGlobalState &gdstate) { + PostIncrement> on_done(gdstate.tasks_completed); + switch (stage) { - case PartitionSortStage::SCAN: - global_sort.AddLocalState(*gastate.local_sorts[block_idx]); + case WindowDistinctSortStage::COMBINE: { + auto &local_sink = *gdstate.local_sinks[block_idx]; + InterruptState interrupt_state; + OperatorSinkCombineInput combine {*gdstate.global_sink, local_sink, interrupt_state}; + gdstate.sort->Combine(context, combine); break; - case PartitionSortStage::MERGE: { - MergeSorter merge_sorter(global_sort, global_sort.buffer_manager); - merge_sorter.PerformInMergeRound(); + } + case WindowDistinctSortStage::FINALIZE: { + // 5: Sort sorted lexicographically increasing + auto &sort = *gdstate.sort; + InterruptState interrupt; + OperatorSinkFinalizeInput finalize {*gdstate.global_sink, interrupt}; + sort.Finalize(context.client, finalize); + auto sort_global = sort.GetGlobalSourceState(context.client, *gdstate.global_sink); + auto sort_local = sort.GetLocalSourceState(context, *sort_global); + OperatorSourceInput source {*sort_global, *sort_local, interrupt}; + sort.MaterializeColumnData(context, source); + gdstate.sorted = sort.GetColumnData(source); break; } - case PartitionSortStage::SORTED: + case WindowDistinctSortStage::SORTED: Sorted(); break; default: break; } - - ++gastate.tasks_completed; -} - -void WindowDistinctAggregatorGlobalState::MeasurePayloadBlocks() { - const auto &blocks = global_sort->sorted_blocks[0]->payload_data->data_blocks; - idx_t count = 0; - for (const auto &block : blocks) { - block_starts.emplace_back(count); - count += block->count; - } - block_starts.emplace_back(count); } bool WindowDistinctAggregatorGlobalState::TryPrepareNextStage(WindowDistinctAggregatorLocalState &lstate) { lock_guard stage_guard(lock); switch (stage.load()) { - case 
PartitionSortStage::INIT: - // 5: Sort sorted lexicographically increasing - total_tasks = local_sorts.size(); + case WindowDistinctSortStage::INIT: + total_tasks = local_sinks.size(); tasks_assigned = 0; tasks_completed = 0; - lstate.stage = stage = PartitionSortStage::SCAN; + lstate.stage = stage = WindowDistinctSortStage::COMBINE; lstate.block_idx = tasks_assigned++; return true; - case PartitionSortStage::SCAN: - // Process all the local sorts + case WindowDistinctSortStage::COMBINE: if (tasks_assigned < total_tasks) { - lstate.stage = PartitionSortStage::SCAN; + lstate.stage = WindowDistinctSortStage::COMBINE; lstate.block_idx = tasks_assigned++; return true; } else if (tasks_completed < tasks_assigned) { return false; } - global_sort->PrepareMergePhase(); - if (!(global_sort->sorted_blocks.size() / 2)) { - if (global_sort->sorted_blocks.empty()) { - lstate.stage = stage = PartitionSortStage::FINISHED; - return true; - } - MeasurePayloadBlocks(); - seconds.resize(block_starts.size() - 1); - total_tasks = seconds.size(); - tasks_completed = 0; - tasks_assigned = 0; - lstate.stage = stage = PartitionSortStage::SORTED; - lstate.block_idx = tasks_assigned++; - return true; - } - global_sort->InitializeMergeRound(); - lstate.stage = stage = PartitionSortStage::MERGE; - total_tasks = locals; - tasks_assigned = 1; + // All combines are done, so move on to materialising the sorted data (1 task) + total_tasks = 1; tasks_completed = 0; + tasks_assigned = 0; + lstate.stage = stage = WindowDistinctSortStage::FINALIZE; + lstate.block_idx = tasks_assigned++; return true; - case PartitionSortStage::MERGE: - if (tasks_assigned < total_tasks) { - lstate.stage = PartitionSortStage::MERGE; - ++tasks_assigned; - return true; - } else if (tasks_completed < tasks_assigned) { + case WindowDistinctSortStage::FINALIZE: + if (tasks_completed < tasks_assigned) { + // Wait for the single task to finish return false; } - global_sort->CompleteMergeRound(true); - if (!(global_sort->sorted_blocks.size() / 2)) { - MeasurePayloadBlocks(); - seconds.resize(block_starts.size() - 1); - total_tasks = seconds.size(); - tasks_completed = 0; - tasks_assigned = 0; - lstate.stage = stage = PartitionSortStage::SORTED; - lstate.block_idx = tasks_assigned++; - return true; - } - global_sort->InitializeMergeRound(); - lstate.stage = PartitionSortStage::MERGE; - total_tasks = locals; - tasks_assigned = 1; + // Move on to building the tree in parallel + total_tasks = local_sinks.size(); tasks_completed = 0; + tasks_assigned = 0; + lstate.stage = stage = WindowDistinctSortStage::SORTED; + lstate.block_idx = tasks_assigned++; return true; - case PartitionSortStage::SORTED: + case WindowDistinctSortStage::SORTED: if (tasks_assigned < total_tasks) { - lstate.stage = PartitionSortStage::SORTED; + lstate.stage = WindowDistinctSortStage::SORTED; lstate.block_idx = tasks_assigned++; return true; } else if (tasks_completed < tasks_assigned) { - lstate.stage = PartitionSortStage::FINISHED; + lstate.stage = WindowDistinctSortStage::FINISHED; // Sleep while other tasks finish return false; } - // Last task patches the boundaries - PatchPrevIdcs(); break; default: break; } - lstate.stage = stage = PartitionSortStage::FINISHED; + lstate.stage = stage = WindowDistinctSortStage::FINISHED; return true; } -void WindowDistinctAggregator::Finalize(WindowAggregatorState &gsink, WindowAggregatorState &lstate, - CollectionPtr collection, const FrameStats &stats) { +void WindowDistinctAggregator::Finalize(ExecutionContext &context, WindowAggregatorState 
&gsink, + WindowAggregatorState &lstate, CollectionPtr collection, + const FrameStats &stats) { auto &gdsink = gsink.Cast(); auto &ldstate = lstate.Cast(); - ldstate.Finalize(gdsink, collection); + ldstate.Finalize(context, gdsink, collection); // Sort, merge and build the tree in parallel - while (gdsink.stage.load() != PartitionSortStage::FINISHED) { + while (gdsink.stage.load() != WindowDistinctSortStage::FINISHED) { if (gdsink.TryPrepareNextStage(ldstate)) { - ldstate.ExecuteTask(); + ldstate.ExecuteTask(context, gdsink); } else { std::this_thread::yield(); } @@ -452,89 +414,71 @@ void WindowDistinctAggregator::Finalize(WindowAggregatorState &gsink, WindowAggr void WindowDistinctAggregatorLocalState::Sorted() { using ZippedTuple = WindowDistinctAggregatorGlobalState::ZippedTuple; - auto &global_sort = gastate.global_sort; - auto &prev_idcs = gastate.zipped_tree.LowestLevel(); - auto &aggregator = gastate.aggregator; - auto &scan_chunk = payload_chunk; + auto &collection = *gdstate.sorted; + auto &prev_idcs = gdstate.zipped_tree.LowestLevel(); + auto &aggregator = gdstate.aggregator; - auto scanner = make_uniq(*global_sort, block_idx); - const auto in_size = gastate.block_starts.at(block_idx + 1); - scanner->Scan(scan_chunk); - idx_t scan_idx = 0; + // Find our chunk range + const auto block_begin = (block_idx * collection.ChunkCount()) / gdstate.total_tasks; + const auto block_end = ((block_idx + 1) * collection.ChunkCount()) / gdstate.total_tasks; - auto *input_idx = FlatVector::GetData(scan_chunk.data[0]); - idx_t i = 0; + const auto &scan_cols = gdstate.sort_cols; + const auto key_count = aggregator.arg_types.size(); - SBIterator curr(*global_sort, ExpressionType::COMPARE_LESSTHAN); - SBIterator prev(*global_sort, ExpressionType::COMPARE_LESSTHAN); - auto prefix_layout = global_sort->sort_layout.GetPrefixComparisonLayout(aggregator.arg_types.size()); + // Setting up the first row is a bit tricky - we have to scan the first value ourselves + WindowCollectionChunkScanner scanner(collection, scan_cols, block_begin ? 
block_begin - 1 : 0); + auto &scanned = scanner.chunk; + if (!scanner.Scan()) { + return; + } - const auto block_begin = gastate.block_starts.at(block_idx); + idx_t prev_i = 0; if (!block_begin) { // First block, so set up initial sentinel - i = input_idx[scan_idx++]; - prev_idcs[i] = ZippedTuple(0, i); - std::get<0>(gastate.seconds[block_idx]) = i; + auto input_idx = FlatVector::GetData(scanned.data.back()); + prev_i = input_idx[0]; + prev_idcs[prev_i] = ZippedTuple(0, prev_i); } else { // Move to the end of the previous block // so we can record the comparison result for the first row - curr.SetIndex(block_begin - 1); - prev.SetIndex(block_begin - 1); - scan_idx = 0; - std::get<0>(gastate.seconds[block_idx]) = input_idx[scan_idx]; + auto input_idx = FlatVector::GetData(scanned.data.back()); + auto scan_idx = scanned.size() - 1; + prev_i = input_idx[scan_idx]; } // 8: for i ← 1 to in.size do - for (++curr; curr.GetIndex() < in_size; ++curr, ++prev) { - // Scan second one chunk at a time - // Note the scan is one behind the iterators - if (scan_idx >= scan_chunk.size()) { - scan_chunk.Reset(); - scanner->Scan(scan_chunk); - scan_idx = 0; - input_idx = FlatVector::GetData(scan_chunk.data[0]); - } - auto second = i; - i = input_idx[scan_idx++]; + WindowDeltaScanner(collection, block_begin, block_end, scan_cols, key_count, + [&](const idx_t row_idx, DataChunk &prev, DataChunk &curr, const idx_t ndistinct, + SelectionVector &distinct, const SelectionVector &matching) { + const auto count = MinValue(prev.size(), curr.size()); + + // The input index has probably been sliced. + UnifiedVectorFormat input_format; + curr.data.back().ToUnifiedFormat(count, input_format); + auto input_idx = UnifiedVectorFormat::GetData(input_format); + + const auto nmatch = count - ndistinct; + // 9: if sorted[i].first == sorted[i-1].first then + // 10: prevIdcs[i] ← sorted[i-1].second + for (idx_t j = 0; j < nmatch; ++j) { + auto scan_idx = matching.get_index(j); + auto i = input_idx[input_format.sel->get_index(scan_idx)]; + auto second = scan_idx ? input_idx[input_format.sel->get_index(scan_idx - 1)] : prev_i; + prev_idcs[i] = ZippedTuple(second + 1, i); + } + // 11: else + // 12: prevIdcs[i] ← “-” + for (idx_t j = 0; j < ndistinct; ++j) { + auto scan_idx = distinct.get_index(j); + auto i = input_idx[input_format.sel->get_index(scan_idx)]; + prev_idcs[i] = ZippedTuple(0, i); + } + + // Remember the last input_idx of this chunk. + prev_i = input_idx[input_format.sel->get_index(count - 1)]; + }); - int lt = 0; - if (prefix_layout.all_constant) { - lt = FastMemcmp(prev.entry_ptr, curr.entry_ptr, prefix_layout.comparison_size); - } else { - lt = Comparators::CompareTuple(prev.scan, curr.scan, prev.entry_ptr, curr.entry_ptr, prefix_layout, - prev.external); - } - - // 9: if sorted[i].first == sorted[i-1].first then - // 10: prevIdcs[i] ← sorted[i-1].second - // 11: else - // 12: prevIdcs[i] ← “-” - if (!lt) { - prev_idcs[i] = ZippedTuple(second + 1, i); - } else { - prev_idcs[i] = ZippedTuple(0, i); - } - } - - // Save the last value of i for patching up the block boundaries - std::get<1>(gastate.seconds[block_idx]) = i; -} - -void WindowDistinctAggregatorGlobalState::PatchPrevIdcs() { // 13: return prevIdcs - - // Patch up the indices at block boundaries - // (We don't need to patch block 0.) 
- auto &prev_idcs = zipped_tree.LowestLevel(); - for (idx_t block_idx = 1; block_idx < seconds.size(); ++block_idx) { - // We only need to patch if the first index in the block - // was a back link to the previous block (10:) - auto i = std::get<0>(seconds.at(block_idx)); - if (std::get<0>(prev_idcs[i])) { - auto second = std::get<1>(seconds.at(block_idx - 1)); - prev_idcs[i] = ZippedTuple(second + 1, i); - } - } } bool WindowDistinctSortTree::TryNextRun(idx_t &level_idx, idx_t &run_idx) { @@ -593,7 +537,6 @@ void WindowDistinctSortTree::Build(WindowDistinctAggregatorLocalState &ldastate) void WindowDistinctSortTree::BuildRun(idx_t level_nr, idx_t run_idx, WindowDistinctAggregatorLocalState &ldastate) { auto &aggr = gdastate.aggr; - auto &allocator = gdastate.allocator; auto &inputs = ldastate.cursor->chunk; auto &levels_flat_native = gdastate.levels_flat_native; @@ -601,7 +544,7 @@ void WindowDistinctSortTree::BuildRun(idx_t level_nr, idx_t run_idx, WindowDisti auto &leaves = ldastate.leaves; auto &sel = ldastate.sel; - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); + AggregateInputData aggr_input_data(aggr.GetFunctionData(), ldastate.tree_allocator); //! The states to update auto &update_v = ldastate.update_v; @@ -697,15 +640,16 @@ void WindowDistinctAggregatorLocalState::FlushStates() { return; } - const auto &aggr = gastate.aggr; - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); + const auto &aggr = gdstate.aggr; + AggregateInputData aggr_input_data(aggr.GetFunctionData(), tree_allocator); statel.Verify(flush_count); aggr.function.combine(statel, statep, aggr_input_data, flush_count); flush_count = 0; } -void WindowDistinctAggregatorLocalState::Evaluate(const WindowDistinctAggregatorGlobalState &gdstate, +void WindowDistinctAggregatorLocalState::Evaluate(ExecutionContext &context, + const WindowDistinctAggregatorGlobalState &gdstate, const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) { auto ldata = FlatVector::GetData(statel); auto pdata = FlatVector::GetData(statep); @@ -750,15 +694,17 @@ void WindowDistinctAggregatorLocalState::Evaluate(const WindowDistinctAggregator } unique_ptr WindowDistinctAggregator::GetLocalState(const WindowAggregatorState &gstate) const { - return make_uniq(gstate.Cast()); + auto &gdstate = gstate.Cast(); + return make_uniq(gdstate); } -void WindowDistinctAggregator::Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, - const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const { +void WindowDistinctAggregator::Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, + WindowAggregatorState &lstate, const DataChunk &bounds, Vector &result, + idx_t count, idx_t row_idx) const { const auto &gdstate = gsink.Cast(); auto &ldstate = lstate.Cast(); - ldstate.Evaluate(gdstate, bounds, result, count, row_idx); + ldstate.Evaluate(context, gdstate, bounds, result, count, row_idx); } } // namespace duckdb diff --git a/src/duckdb/src/function/window/window_executor.cpp b/src/duckdb/src/function/window/window_executor.cpp index 7329dca2d..90d1b6569 100644 --- a/src/duckdb/src/function/window/window_executor.cpp +++ b/src/duckdb/src/function/window/window_executor.cpp @@ -7,17 +7,18 @@ namespace duckdb { //===--------------------------------------------------------------------===// -// WindowExecutorBoundsState +// WindowExecutorBoundsLocalState //===--------------------------------------------------------------------===// 
-WindowExecutorBoundsState::WindowExecutorBoundsState(const WindowExecutorGlobalState &gstate) - : WindowExecutorLocalState(gstate), partition_mask(gstate.partition_mask), order_mask(gstate.order_mask), +WindowExecutorBoundsLocalState::WindowExecutorBoundsLocalState(ExecutionContext &context, + const WindowExecutorGlobalState &gstate) + : WindowExecutorLocalState(context, gstate), partition_mask(gstate.partition_mask), order_mask(gstate.order_mask), state(gstate.executor.wexpr, gstate.payload_count) { vector bounds_types(8, LogicalType(LogicalTypeId::UBIGINT)); - bounds.Initialize(Allocator::Get(gstate.executor.context), bounds_types); + bounds.Initialize(Allocator::Get(gstate.client), bounds_types); } -void WindowExecutorBoundsState::UpdateBounds(WindowExecutorGlobalState &gstate, idx_t row_idx, DataChunk &eval_chunk, - optional_ptr range) { +void WindowExecutorBoundsLocalState::UpdateBounds(WindowExecutorGlobalState &gstate, idx_t row_idx, + DataChunk &eval_chunk, optional_ptr range) { // Evaluate the row-level arguments WindowInputExpression boundary_start(eval_chunk, gstate.executor.boundary_start_idx); WindowInputExpression boundary_end(eval_chunk, gstate.executor.boundary_end_idx); @@ -29,8 +30,8 @@ void WindowExecutorBoundsState::UpdateBounds(WindowExecutorGlobalState &gstate, //===--------------------------------------------------------------------===// // WindowExecutor //===--------------------------------------------------------------------===// -WindowExecutor::WindowExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared) - : wexpr(wexpr), context(context), +WindowExecutor::WindowExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) + : wexpr(wexpr), range_expr((WindowBoundariesState::HasPrecedingRange(wexpr) || WindowBoundariesState::HasFollowingRange(wexpr)) ? 
wexpr.orders[0].expression.get() : nullptr) { @@ -46,57 +47,62 @@ bool WindowExecutor::IgnoreNulls() const { return wexpr.ignore_nulls; } -void WindowExecutor::Evaluate(idx_t row_idx, DataChunk &eval_chunk, Vector &result, WindowExecutorLocalState &lstate, - WindowExecutorGlobalState &gstate) const { - auto &lbstate = lstate.Cast(); +void WindowExecutor::Evaluate(ExecutionContext &context, idx_t row_idx, DataChunk &eval_chunk, Vector &result, + WindowExecutorLocalState &lstate, WindowExecutorGlobalState &gstate) const { + auto &lbstate = lstate.Cast(); lbstate.UpdateBounds(gstate, row_idx, eval_chunk, lstate.range_cursor); const auto count = eval_chunk.size(); - EvaluateInternal(gstate, lstate, eval_chunk, result, count, row_idx); + EvaluateInternal(context, gstate, lstate, eval_chunk, result, count, row_idx); result.Verify(count); } -WindowExecutorGlobalState::WindowExecutorGlobalState(const WindowExecutor &executor, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask) - : executor(executor), payload_count(payload_count), partition_mask(partition_mask), order_mask(order_mask) { +WindowExecutorGlobalState::WindowExecutorGlobalState(ClientContext &client, const WindowExecutor &executor, + const idx_t payload_count, const ValidityMask &partition_mask, + const ValidityMask &order_mask) + : client(client), executor(executor), payload_count(payload_count), partition_mask(partition_mask), + order_mask(order_mask) { for (const auto &child : executor.wexpr.children) { arg_types.emplace_back(child->return_type); } } -WindowExecutorLocalState::WindowExecutorLocalState(const WindowExecutorGlobalState &gstate) { +WindowExecutorLocalState::WindowExecutorLocalState(ExecutionContext &context, const WindowExecutorGlobalState &gstate) { } -void WindowExecutorLocalState::Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, - idx_t input_idx) { +void WindowExecutorLocalState::Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, + DataChunk &coll_chunk, idx_t input_idx) { } -void WindowExecutorLocalState::Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) { +void WindowExecutorLocalState::Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, + CollectionPtr collection) { const auto range_idx = gstate.executor.range_idx; if (range_idx != DConstants::INVALID_INDEX) { range_cursor = make_uniq(*collection, range_idx); } } -unique_ptr WindowExecutor::GetGlobalState(const idx_t payload_count, +unique_ptr WindowExecutor::GetGlobalState(ClientContext &client, const idx_t payload_count, const ValidityMask &partition_mask, const ValidityMask &order_mask) const { - return make_uniq(*this, payload_count, partition_mask, order_mask); + return make_uniq(client, *this, payload_count, partition_mask, order_mask); } -unique_ptr WindowExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { - return make_uniq(gstate); +unique_ptr WindowExecutor::GetLocalState(ExecutionContext &context, + const WindowExecutorGlobalState &gstate) const { + return make_uniq(context, gstate); } -void WindowExecutor::Sink(DataChunk &sink_chunk, DataChunk &coll_chunk, const idx_t input_idx, - WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate) const { - lstate.Sink(gstate, sink_chunk, coll_chunk, input_idx); +void WindowExecutor::Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, + const idx_t input_idx, WindowExecutorGlobalState &gstate, + 
WindowExecutorLocalState &lstate) const { + lstate.Sink(context, gstate, sink_chunk, coll_chunk, input_idx); } -void WindowExecutor::Finalize(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - CollectionPtr collection) const { - lstate.Finalize(gstate, collection); +void WindowExecutor::Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, CollectionPtr collection) const { + lstate.Finalize(context, gstate, collection); } } // namespace duckdb diff --git a/src/duckdb/src/function/window/window_index_tree.cpp b/src/duckdb/src/function/window/window_index_tree.cpp index 5aed5b092..78c70c60b 100644 --- a/src/duckdb/src/function/window/window_index_tree.cpp +++ b/src/duckdb/src/function/window/window_index_tree.cpp @@ -14,37 +14,46 @@ WindowIndexTree::WindowIndexTree(ClientContext &context, const BoundOrderModifie : WindowIndexTree(context, order_bys.orders, sort_idx, count) { } -unique_ptr WindowIndexTree::GetLocalState() { - return make_uniq(*this); +unique_ptr WindowIndexTree::GetLocalState(ExecutionContext &context) { + return make_uniq(context, *this); } -WindowIndexTreeLocalState::WindowIndexTreeLocalState(WindowIndexTree &index_tree) - : WindowMergeSortTreeLocalState(index_tree), index_tree(index_tree) { +WindowIndexTreeLocalState::WindowIndexTreeLocalState(ExecutionContext &context, WindowIndexTree &index_tree) + : WindowMergeSortTreeLocalState(context, index_tree), index_tree(index_tree) { } void WindowIndexTreeLocalState::BuildLeaves() { - auto &global_sort = *index_tree.global_sort; - if (global_sort.sorted_blocks.empty()) { + auto &collection = *window_tree.sorted; + if (!collection.Count()) { return; } - PayloadScanner scanner(global_sort, build_task); - idx_t row_idx = index_tree.block_starts[build_task]; - for (;;) { - payload_chunk.Reset(); - scanner.Scan(payload_chunk); + // Find our chunk range + const auto block_begin = (build_task * collection.ChunkCount()) / window_tree.total_tasks; + const auto block_end = ((build_task + 1) * collection.ChunkCount()) / window_tree.total_tasks; + + // Scan the index column (the last one) + vector index_ids(1, window_tree.scan_cols.size() - 1); + WindowCollectionChunkScanner scanner(collection, index_ids, block_begin); + auto &payload_chunk = scanner.chunk; + + idx_t row_idx = scanner.Scanned(); + for (auto block_curr = block_begin; block_curr < block_end; ++block_curr) { + if (!scanner.Scan()) { + break; + } const auto count = payload_chunk.size(); if (count == 0) { break; } auto &indices = payload_chunk.data[0]; - if (index_tree.mst32) { - auto &sorted = index_tree.mst32->LowestLevel(); - auto data = FlatVector::GetDataUnsafe(indices); + if (window_tree.mst32) { + auto &sorted = window_tree.mst32->LowestLevel(); + auto data = FlatVector::GetData(indices); std::copy(data, data + count, sorted.data() + row_idx); } else { - auto &sorted = index_tree.mst64->LowestLevel(); - auto data = FlatVector::GetDataUnsafe(indices); + auto &sorted = window_tree.mst64->LowestLevel(); + auto data = FlatVector::GetData(indices); std::copy(data, data + count, sorted.data() + row_idx); } row_idx += count; diff --git a/src/duckdb/src/function/window/window_merge_sort_tree.cpp b/src/duckdb/src/function/window/window_merge_sort_tree.cpp index 812c945fc..9ffd269bb 100644 --- a/src/duckdb/src/function/window/window_merge_sort_tree.cpp +++ b/src/duckdb/src/function/window/window_merge_sort_tree.cpp @@ -1,17 +1,16 @@ #include "duckdb/function/window/window_merge_sort_tree.hpp" -#include 
"duckdb/planner/expression/bound_constant_expression.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" #include #include namespace duckdb { -WindowMergeSortTree::WindowMergeSortTree(ClientContext &context, const vector &orders, - const vector &sort_idx, const idx_t count, bool unique) - : context(context), memory_per_thread(PhysicalOperator::GetMaxThreadMemory(context)), sort_idx(sort_idx), - build_stage(PartitionSortStage::INIT), tasks_completed(0) { +WindowMergeSortTree::WindowMergeSortTree(ClientContext &client, const vector &orders_p, + const vector &order_idx, const idx_t count, bool unique) + : order_idx(order_idx), build_stage(WindowMergeSortStage::INIT), tasks_completed(0) { // Sort the unfiltered indices by the orders - const auto force_external = ClientConfig::GetConfig(context).force_external; + const auto force_external = ClientConfig::GetConfig(client).force_external; LogicalType index_type; if (count < std::numeric_limits::max() && !force_external) { index_type = LogicalType::INTEGER; @@ -21,105 +20,107 @@ WindowMergeSortTree::WindowMergeSortTree(ClientContext &context, const vector(); } - vector payload_types; - payload_types.emplace_back(index_type); + vector orders; + for (const auto &order_p : orders_p) { + auto order = order_p.Copy(); + const auto &type = order.expression->return_type; + scan_types.emplace_back(type); + order.expression = make_uniq(type, orders.size()); + orders.emplace_back(std::move(order)); + scan_cols.emplace_back(scan_cols.size()); + key_cols.emplace_back(key_cols.size()); + } - RowLayout payload_layout; - payload_layout.Initialize(payload_types); + // Also track the index type + scan_types.emplace_back(index_type); + scan_cols.emplace_back(scan_cols.size()); + // If the caller wants disambiguation, also sort by the index column if (unique) { - vector unique_orders; - for (const auto &order : orders) { - unique_orders.emplace_back(order.Copy()); - } - auto unique_expr = make_uniq(Value(index_type)); + auto unique_expr = make_uniq(scan_types.back(), orders.size()); const auto order_type = OrderType::ASCENDING; const auto order_by_type = OrderByNullType::NULLS_LAST; - unique_orders.emplace_back(BoundOrderByNode(order_type, order_by_type, std::move(unique_expr))); - global_sort = make_uniq(context, unique_orders, payload_layout); - } else { - global_sort = make_uniq(context, orders, payload_layout); + orders.emplace_back(BoundOrderByNode(order_type, order_by_type, std::move(unique_expr))); + key_cols.emplace_back(key_cols.size()); } - global_sort->external = force_external; + + sort = make_uniq(client, orders, scan_types, scan_cols); + + global_sink = sort->GetGlobalSinkState(client); } -optional_ptr WindowMergeSortTree::AddLocalSort() { +optional_ptr WindowMergeSortTree::InitializeLocalSort(ExecutionContext &context) const { lock_guard local_sort_guard(lock); - auto local_sort = make_uniq(); - local_sort->Initialize(*global_sort, global_sort->buffer_manager); - local_sorts.emplace_back(std::move(local_sort)); + auto local_sink = sort->GetLocalSinkState(context); + local_sinks.emplace_back(std::move(local_sink)); - return local_sorts.back().get(); + return local_sinks.back().get(); } -WindowMergeSortTreeLocalState::WindowMergeSortTreeLocalState(WindowMergeSortTree &window_tree) +WindowMergeSortTreeLocalState::WindowMergeSortTreeLocalState(ExecutionContext &context, + WindowMergeSortTree &window_tree) : window_tree(window_tree) { - sort_chunk.Initialize(window_tree.context, window_tree.global_sort->sort_layout.logical_types); - 
payload_chunk.Initialize(window_tree.context, window_tree.global_sort->payload_layout.GetTypes()); - local_sort = window_tree.AddLocalSort(); + sort_chunk.Initialize(context.client, window_tree.scan_types); + local_sink = window_tree.InitializeLocalSort(context); } -void WindowMergeSortTreeLocalState::SinkChunk(DataChunk &chunk, const idx_t row_idx, - optional_ptr filter_sel, idx_t filtered) { +void WindowMergeSortTreeLocalState::Sink(ExecutionContext &context, DataChunk &chunk, const idx_t row_idx, + optional_ptr filter_sel, idx_t filtered) { // Sequence the payload column - auto &indices = payload_chunk.data[0]; - payload_chunk.SetCardinality(chunk); - indices.Sequence(int64_t(row_idx), 1, payload_chunk.size()); - - // Reference the sort columns - auto &sort_idx = window_tree.sort_idx; - for (column_t c = 0; c < sort_idx.size(); ++c) { - sort_chunk.data[c].Reference(chunk.data[sort_idx[c]]); - } - // Add the row numbers if we are uniquifying - if (sort_idx.size() < sort_chunk.ColumnCount()) { - sort_chunk.data[sort_idx.size()].Reference(indices); + sort_chunk.Reset(); + auto &indices = sort_chunk.data.back(); + indices.Sequence(int64_t(row_idx), 1, chunk.size()); + + // Reference the ORDER BY columns + auto &order_idx = window_tree.order_idx; + for (column_t c = 0; c < order_idx.size(); ++c) { + sort_chunk.data[c].Reference(chunk.data[order_idx[c]]); } sort_chunk.SetCardinality(chunk); // Apply FILTER clause, if any if (filter_sel) { sort_chunk.Slice(*filter_sel, filtered); - payload_chunk.Slice(*filter_sel, filtered); } - local_sort->SinkChunk(sort_chunk, payload_chunk); - - // Flush if we have too much data - if (local_sort->SizeInBytes() > window_tree.memory_per_thread) { - local_sort->Sort(*window_tree.global_sort, true); - } + InterruptState interrupt; + OperatorSinkInput sink {*window_tree.global_sink, *local_sink, interrupt}; + window_tree.sort->Sink(context, sort_chunk, sink); } -void WindowMergeSortTreeLocalState::ExecuteSortTask() { +void WindowMergeSortTreeLocalState::ExecuteSortTask(ExecutionContext &context) { + PostIncrement> on_completed(window_tree.tasks_completed); + switch (build_stage) { - case PartitionSortStage::SCAN: - window_tree.global_sort->AddLocalState(*window_tree.local_sorts[build_task]); + case WindowMergeSortStage::COMBINE: { + auto &local_sink = *window_tree.local_sinks[build_task]; + InterruptState interrupt_state; + OperatorSinkCombineInput combine {*window_tree.global_sink, local_sink, interrupt_state}; + window_tree.sort->Combine(context, combine); break; - case PartitionSortStage::MERGE: { - auto &global_sort = *window_tree.global_sort; - MergeSorter merge_sorter(global_sort, global_sort.buffer_manager); - merge_sorter.PerformInMergeRound(); + } + case WindowMergeSortStage::FINALIZE: { + auto &sort = *window_tree.sort; + InterruptState interrupt; + OperatorSinkFinalizeInput finalize {*window_tree.global_sink, interrupt}; + sort.Finalize(context.client, finalize); + auto sort_global = sort.GetGlobalSourceState(context.client, *window_tree.global_sink); + auto sort_local = sort.GetLocalSourceState(context, *sort_global); + OperatorSourceInput source {*sort_global, *sort_local, interrupt}; + sort.MaterializeColumnData(context, source); + window_tree.sorted = sort.GetColumnData(source); break; } - case PartitionSortStage::SORTED: + case WindowMergeSortStage::SORTED: BuildLeaves(); break; default: break; } - - ++window_tree.tasks_completed; } idx_t WindowMergeSortTree::MeasurePayloadBlocks() { - const auto &blocks = 
global_sort->sorted_blocks[0]->payload_data->data_blocks; - idx_t count = 0; - for (const auto &block : blocks) { - block_starts.emplace_back(count); - count += block->count; - } - block_starts.emplace_back(count); + const auto count = sorted->Count(); // Allocate the leaves. if (mst32) { @@ -133,130 +134,78 @@ idx_t WindowMergeSortTree::MeasurePayloadBlocks() { return count; } -void WindowMergeSortTreeLocalState::BuildLeaves() { - auto &global_sort = *window_tree.global_sort; - if (global_sort.sorted_blocks.empty()) { - return; - } - - PayloadScanner scanner(global_sort, build_task); - idx_t row_idx = window_tree.block_starts[build_task]; - for (;;) { - payload_chunk.Reset(); - scanner.Scan(payload_chunk); - const auto count = payload_chunk.size(); - if (count == 0) { - break; - } - auto &indices = payload_chunk.data[0]; - if (window_tree.mst32) { - auto &sorted = window_tree.mst32->LowestLevel(); - auto data = FlatVector::GetData(indices); - std::copy(data, data + count, sorted.data() + row_idx); - } else { - auto &sorted = window_tree.mst64->LowestLevel(); - auto data = FlatVector::GetData(indices); - std::copy(data, data + count, sorted.data() + row_idx); - } - row_idx += count; - } -} - -void WindowMergeSortTree::CleanupSort() { - global_sort.reset(); - local_sorts.clear(); +void WindowMergeSortTree::Finished() { + global_sink.reset(); + local_sinks.clear(); + sorted.reset(); } bool WindowMergeSortTree::TryPrepareSortStage(WindowMergeSortTreeLocalState &lstate) { lock_guard stage_guard(lock); switch (build_stage.load()) { - case PartitionSortStage::INIT: - total_tasks = local_sorts.size(); + case WindowMergeSortStage::INIT: + total_tasks = local_sinks.size(); tasks_assigned = 0; tasks_completed = 0; - lstate.build_stage = build_stage = PartitionSortStage::SCAN; + lstate.build_stage = build_stage = WindowMergeSortStage::COMBINE; lstate.build_task = tasks_assigned++; return true; - case PartitionSortStage::SCAN: + case WindowMergeSortStage::COMBINE: // Process all the local sorts if (tasks_assigned < total_tasks) { - lstate.build_stage = PartitionSortStage::SCAN; + lstate.build_stage = WindowMergeSortStage::COMBINE; lstate.build_task = tasks_assigned++; return true; } else if (tasks_completed < tasks_assigned) { return false; } - global_sort->PrepareMergePhase(); - if (!(global_sort->sorted_blocks.size() / 2)) { - if (global_sort->sorted_blocks.empty()) { - lstate.build_stage = build_stage = PartitionSortStage::FINISHED; - return true; - } - MeasurePayloadBlocks(); - total_tasks = block_starts.size() - 1; - tasks_completed = 0; - tasks_assigned = 0; - lstate.build_stage = build_stage = PartitionSortStage::SORTED; - lstate.build_task = tasks_assigned++; - return true; - } - global_sort->InitializeMergeRound(); - lstate.build_stage = build_stage = PartitionSortStage::MERGE; - total_tasks = local_sorts.size(); - tasks_assigned = 1; + // All combines are done, so move on to materialising the sorted data (1 task) + total_tasks = 1; tasks_completed = 0; + tasks_assigned = 0; + lstate.build_stage = build_stage = WindowMergeSortStage::FINALIZE; + lstate.build_task = tasks_assigned++; return true; - case PartitionSortStage::MERGE: - if (tasks_assigned < total_tasks) { - lstate.build_stage = PartitionSortStage::MERGE; - ++tasks_assigned; - return true; - } else if (tasks_completed < tasks_assigned) { + case WindowMergeSortStage::FINALIZE: + if (tasks_completed < tasks_assigned) { + // Wait for the single task to finish return false; } - global_sort->CompleteMergeRound(true); - if 
(!(global_sort->sorted_blocks.size() / 2)) { - MeasurePayloadBlocks(); - total_tasks = block_starts.size() - 1; - tasks_completed = 0; - tasks_assigned = 0; - lstate.build_stage = build_stage = PartitionSortStage::SORTED; - lstate.build_task = tasks_assigned++; - return true; - } - global_sort->InitializeMergeRound(); - lstate.build_stage = PartitionSortStage::MERGE; - total_tasks = local_sorts.size(); - tasks_assigned = 1; + // Move on to building the tree in parallel + MeasurePayloadBlocks(); + total_tasks = local_sinks.size(); tasks_completed = 0; + tasks_assigned = 0; + lstate.build_stage = build_stage = WindowMergeSortStage::SORTED; + lstate.build_task = tasks_assigned++; return true; - case PartitionSortStage::SORTED: + case WindowMergeSortStage::SORTED: if (tasks_assigned < total_tasks) { - lstate.build_stage = PartitionSortStage::SORTED; + lstate.build_stage = WindowMergeSortStage::SORTED; lstate.build_task = tasks_assigned++; return true; } else if (tasks_completed < tasks_assigned) { - lstate.build_stage = PartitionSortStage::FINISHED; + lstate.build_stage = WindowMergeSortStage::FINISHED; // Sleep while other tasks finish return false; } - CleanupSort(); + Finished(); break; - default: + case WindowMergeSortStage::FINISHED: break; } - lstate.build_stage = build_stage = PartitionSortStage::FINISHED; + lstate.build_stage = build_stage = WindowMergeSortStage::FINISHED; return true; } -void WindowMergeSortTreeLocalState::Sort() { +void WindowMergeSortTreeLocalState::Finalize(ExecutionContext &context) { // Sort, merge and build the tree in parallel - while (window_tree.build_stage.load() != PartitionSortStage::FINISHED) { + while (window_tree.build_stage.load() != WindowMergeSortStage::FINISHED) { if (window_tree.TryPrepareSortStage(*this)) { - ExecuteSortTask(); + ExecuteSortTask(context); } else { std::this_thread::yield(); } diff --git a/src/duckdb/src/function/window/window_naive_aggregator.cpp b/src/duckdb/src/function/window/window_naive_aggregator.cpp index 5d7c74971..35ba22cbd 100644 --- a/src/duckdb/src/function/window/window_naive_aggregator.cpp +++ b/src/duckdb/src/function/window/window_naive_aggregator.cpp @@ -1,7 +1,8 @@ #include "duckdb/function/window/window_naive_aggregator.hpp" -#include "duckdb/common/sort/sort.hpp" +#include "duckdb/common/sorting/sort.hpp" #include "duckdb/function/window/window_collection.hpp" #include "duckdb/function/window/window_shared_expressions.hpp" +#include "duckdb/planner/expression/bound_reference_expression.hpp" #include "duckdb/planner/expression/bound_window_expression.hpp" #include "duckdb/function/window/window_aggregate_function.hpp" @@ -21,38 +22,38 @@ WindowNaiveAggregator::WindowNaiveAggregator(const WindowAggregateExecutor &exec WindowNaiveAggregator::~WindowNaiveAggregator() { } -class WindowNaiveState : public WindowAggregatorLocalState { +class WindowNaiveLocalState : public WindowAggregatorLocalState { public: struct HashRow { - explicit HashRow(WindowNaiveState &state) : state(state) { + explicit HashRow(WindowNaiveLocalState &state) : state(state) { } inline size_t operator()(const idx_t &i) const { return state.Hash(i); } - WindowNaiveState &state; + WindowNaiveLocalState &state; }; struct EqualRow { - explicit EqualRow(WindowNaiveState &state) : state(state) { + explicit EqualRow(WindowNaiveLocalState &state) : state(state) { } inline bool operator()(const idx_t &lhs, const idx_t &rhs) const { return state.KeyEqual(lhs, rhs); } - WindowNaiveState &state; + WindowNaiveLocalState &state; }; using RowSet = 
std::unordered_set<idx_t, HashRow, EqualRow>; - explicit WindowNaiveState(const WindowNaiveAggregator &gsink); + explicit WindowNaiveLocalState(const WindowNaiveAggregator &gsink); - void Finalize(WindowAggregatorGlobalState &gastate, CollectionPtr collection) override; + void Finalize(ExecutionContext &context, WindowAggregatorGlobalState &gastate, CollectionPtr collection) override; - void Evaluate(const WindowAggregatorGlobalState &gsink, const DataChunk &bounds, Vector &result, idx_t count, - idx_t row_idx); + void Evaluate(ExecutionContext &context, const WindowAggregatorGlobalState &gsink, const DataChunk &bounds, + Vector &result, idx_t count, idx_t row_idx); protected: //! Flush the accumulated intermediate states into the result states @@ -85,19 +86,19 @@ class WindowNaiveState : public WindowAggregatorLocalState { unique_ptr<WindowCursor> comparer; //! The state used for scanning ORDER BY values from the collection + unique_ptr<Sort> sort; + //! The order by collection unique_ptr<WindowCursor> arg_orderer; //! Reusable sort key chunk - DataChunk orderby_sort; + DataChunk orderby_sink; //! Reusable sort payload chunk - DataChunk orderby_payload; + DataChunk orderby_scan; //! Reusable sort key slicer SelectionVector orderby_sel; - //! Reusable payload layout. - RowLayout payload_layout; }; -WindowNaiveState::WindowNaiveState(const WindowNaiveAggregator &aggregator_p) - : aggregator(aggregator_p), state(aggregator.state_size * STANDARD_VECTOR_SIZE), statef(LogicalType::POINTER), +WindowNaiveLocalState::WindowNaiveLocalState(const WindowNaiveAggregator &aggregator) + : aggregator(aggregator), state(aggregator.state_size * STANDARD_VECTOR_SIZE), statef(LogicalType::POINTER), statep((LogicalType::POINTER)), flush_count(0), hashes(LogicalType::HASH) { InitSubFrames(frames, aggregator.exclude_mode); @@ -113,17 +114,11 @@ WindowNaiveState::WindowNaiveState(const WindowNaiveAggregator &aggregator_p) fdata[i] = state_ptr; state_ptr += aggregator.state_size; } - - // Initialise any ORDER BY data - if (!aggregator.arg_order_idx.empty() && !arg_orderer) { - orderby_payload.Initialize(Allocator::DefaultAllocator(), {LogicalType::UBIGINT}); - payload_layout.Initialize(orderby_payload.GetTypes()); - orderby_sel.Initialize(); - } } -void WindowNaiveState::Finalize(WindowAggregatorGlobalState &gastate, CollectionPtr collection) { - WindowAggregatorLocalState::Finalize(gastate, collection); +void WindowNaiveLocalState::Finalize(ExecutionContext &context, WindowAggregatorGlobalState &gastate, + CollectionPtr collection) { + WindowAggregatorLocalState::Finalize(context, gastate, collection); // Set up the comparison scanner just in case if (!comparer) { @@ -133,17 +128,35 @@ void WindowNaiveState::Finalize(WindowAggregatorGlobalState &gastate, Collection // Set up the argument ORDER BY scanner if needed if (!aggregator.arg_order_idx.empty() && !arg_orderer) { arg_orderer = make_uniq<WindowCursor>(*collection, aggregator.arg_order_idx); - orderby_sort.Initialize(BufferAllocator::Get(gastate.context), arg_orderer->chunk.GetTypes()); + auto input_types = arg_orderer->chunk.GetTypes(); + input_types.emplace_back(LogicalType::UBIGINT); + orderby_sink.Initialize(BufferAllocator::Get(gastate.client), input_types); + + // The sort expressions have already been computed, so we just need to reference them + vector<BoundOrderByNode> orders; + for (const auto &order_by : aggregator.wexpr.arg_orders) { + auto order = order_by.Copy(); + const auto &type = order.expression->return_type; + order.expression = make_uniq<BoundReferenceExpression>(type, orders.size()); + orders.emplace_back(std::move(order)); + } + + // We only want the
row numbers + vector<idx_t> projection_map(1, input_types.size() - 1); + orderby_scan.Initialize(BufferAllocator::Get(gastate.client), {input_types.back()}); + sort = make_uniq<Sort>(context.client, orders, input_types, projection_map); + + orderby_sel.Initialize(); } // Initialise the chunks const auto types = cursor->chunk.GetTypes(); if (leaves.ColumnCount() == 0 && !types.empty()) { - leaves.Initialize(BufferAllocator::Get(gastate.context), types); + leaves.Initialize(BufferAllocator::Get(context.client), types); } } -void WindowNaiveState::FlushStates(const WindowAggregatorGlobalState &gsink) { +void WindowNaiveLocalState::FlushStates(const WindowAggregatorGlobalState &gsink) { if (!flush_count) { return; } @@ -158,7 +171,7 @@ void WindowNaiveState::FlushStates(const WindowAggregatorGlobalState &gsink) { flush_count = 0; } -size_t WindowNaiveState::Hash(idx_t rid) { +size_t WindowNaiveLocalState::Hash(idx_t rid) { D_ASSERT(cursor->RowIsVisible(rid)); auto s = cursor->RowOffset(rid); auto &scanned = cursor->chunk; @@ -169,7 +182,7 @@ size_t WindowNaiveState::Hash(idx_t rid) { return *FlatVector::GetData<hash_t>(hashes); } -bool WindowNaiveState::KeyEqual(const idx_t &lidx, const idx_t &ridx) { +bool WindowNaiveLocalState::KeyEqual(const idx_t &lidx, const idx_t &ridx) { // One of the indices will be scanned, so make it the left one auto lhs = lidx; auto rhs = ridx; @@ -206,8 +219,8 @@ bool WindowNaiveState::KeyEqual(const idx_t &lidx, const idx_t &ridx) { return true; } -void WindowNaiveState::Evaluate(const WindowAggregatorGlobalState &gsink, const DataChunk &bounds, Vector &result, - idx_t count, idx_t row_idx) { +void WindowNaiveLocalState::Evaluate(ExecutionContext &context, const WindowAggregatorGlobalState &gsink, + const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) { const auto &aggr = gsink.aggr; auto &filter_mask = gsink.filter_mask; const auto types = cursor->chunk.GetTypes(); @@ -228,14 +241,13 @@ void WindowNaiveState::Evaluate(const WindowAggregatorGlobalState &gsink, const // Sort the input rows by the argument if (arg_orderer) { - auto &context = aggregator.executor.context; - auto &orders = aggregator.wexpr.arg_orders; - GlobalSortState global_sort(context, orders, payload_layout); - LocalSortState local_sort; - local_sort.Initialize(global_sort, global_sort.buffer_manager); + auto global_sink = sort->GetGlobalSinkState(context.client); + auto local_sink = sort->GetLocalSinkState(context); + InterruptState interrupt; + OperatorSinkInput sink {*global_sink, *local_sink, interrupt}; idx_t orderby_count = 0; - auto orderby_row = FlatVector::GetData<idx_t>(orderby_payload.data[0]); + auto orderby_row = FlatVector::GetData<idx_t>(orderby_sink.data.back()); for (const auto &frame : frames) { for (auto f = frame.start; f < frame.end; ++f) { // FILTER before the ORDER BY @@ -245,43 +257,45 @@ void WindowNaiveState::Evaluate(const WindowAggregatorGlobalState &gsink, const if (!arg_orderer->RowIsVisible(f) || orderby_count >= STANDARD_VECTOR_SIZE) { if (orderby_count) { - orderby_sort.Reference(arg_orderer->chunk); - orderby_sort.Slice(orderby_sel, orderby_count); - orderby_payload.SetCardinality(orderby_count); - local_sort.SinkChunk(orderby_sort, orderby_payload); + for (column_t c = 0; c < arg_orderer->chunk.ColumnCount(); ++c) { + orderby_sink.data[c].Reference(arg_orderer->chunk.data[c]); + } + orderby_sink.Slice(orderby_sel, orderby_count); + sort->Sink(context, orderby_sink, sink); + orderby_sink.Reset(); } orderby_count = 0; arg_orderer->Seek(f); + // Fill in the row numbers + for (idx_t
i = 0; i < arg_orderer->chunk.size(); ++i) { + orderby_row[i] = arg_orderer->state.current_row_index + i; + } } - orderby_row[orderby_count] = f; orderby_sel.set_index(orderby_count++, arg_orderer->RowOffset(f)); } } if (orderby_count) { - orderby_sort.Reference(arg_orderer->chunk); - orderby_sort.Slice(orderby_sel, orderby_count); - orderby_payload.SetCardinality(orderby_count); - local_sort.SinkChunk(orderby_sort, orderby_payload); + for (column_t c = 0; c < arg_orderer->chunk.ColumnCount(); ++c) { + orderby_sink.data[c].Reference(arg_orderer->chunk.data[c]); + } + orderby_sink.Slice(orderby_sel, orderby_count); + sort->Sink(context, orderby_sink, sink); + orderby_sink.Reset(); } - global_sort.AddLocalState(local_sort); - if (global_sort.sorted_blocks.empty()) { - return; - } - global_sort.PrepareMergePhase(); - while (global_sort.sorted_blocks.size() > 1) { - global_sort.InitializeMergeRound(); - MergeSorter merge_sorter(global_sort, global_sort.buffer_manager); - merge_sorter.PerformInMergeRound(); - global_sort.CompleteMergeRound(false); - } + OperatorSinkCombineInput combine {*global_sink, *local_sink, interrupt}; + sort->Combine(context, combine); + + OperatorSinkFinalizeInput finalize {*global_sink, interrupt}; + sort->Finalize(context.client, finalize); - PayloadScanner scanner(global_sort); - while (scanner.Remaining()) { - orderby_payload.Reset(); - scanner.Scan(orderby_payload); - orderby_row = FlatVector::GetData(orderby_payload.data[0]); - for (idx_t i = 0; i < orderby_payload.size(); ++i) { + auto global_source = sort->GetGlobalSourceState(context.client, *global_sink); + auto local_source = sort->GetLocalSourceState(context, *global_source); + OperatorSourceInput source {*global_source, *local_source, interrupt}; + orderby_scan.Reset(); + for (; SourceResultType::FINISHED != sort->GetData(context, orderby_scan, source); orderby_scan.Reset()) { + orderby_row = FlatVector::GetData(orderby_scan.data[0]); + for (idx_t i = 0; i < orderby_scan.size(); ++i) { const auto f = orderby_row[i]; // Seek to the current position if (!cursor->RowIsVisible(f)) { @@ -347,14 +361,15 @@ void WindowNaiveState::Evaluate(const WindowAggregatorGlobalState &gsink, const } unique_ptr WindowNaiveAggregator::GetLocalState(const WindowAggregatorState &gstate) const { - return make_uniq(*this); + return make_uniq(*this); } -void WindowNaiveAggregator::Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, - const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const { +void WindowNaiveAggregator::Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, + WindowAggregatorState &lstate, const DataChunk &bounds, Vector &result, + idx_t count, idx_t row_idx) const { const auto &gnstate = gsink.Cast(); - auto &lnstate = lstate.Cast(); - lnstate.Evaluate(gnstate, bounds, result, count, row_idx); + auto &lnstate = lstate.Cast(); + lnstate.Evaluate(context, gnstate, bounds, result, count, row_idx); } } // namespace duckdb diff --git a/src/duckdb/src/function/window/window_rank_function.cpp b/src/duckdb/src/function/window/window_rank_function.cpp index 5b269a91b..dbd2f98a4 100644 --- a/src/duckdb/src/function/window/window_rank_function.cpp +++ b/src/duckdb/src/function/window/window_rank_function.cpp @@ -10,9 +10,9 @@ namespace duckdb { //===--------------------------------------------------------------------===// class WindowPeerGlobalState : public WindowExecutorGlobalState { public: - WindowPeerGlobalState(const WindowPeerExecutor &executor, const idx_t 
payload_count, + WindowPeerGlobalState(ClientContext &client, const WindowPeerExecutor &executor, const idx_t payload_count, const ValidityMask &partition_mask, const ValidityMask &order_mask) - : WindowExecutorGlobalState(executor, payload_count, partition_mask, order_mask) { + : WindowExecutorGlobalState(client, executor, payload_count, partition_mask, order_mask) { if (!executor.arg_order_idx.empty()) { use_framing = true; @@ -20,10 +20,9 @@ class WindowPeerGlobalState : public WindowExecutorGlobalState { // (and the optimizer is enabled), then we can just use the partition ordering. auto &wexpr = executor.wexpr; auto &arg_orders = executor.wexpr.arg_orders; - const auto optimize = ClientConfig::GetConfig(executor.context).enable_optimizer; + const auto optimize = ClientConfig::GetConfig(client).enable_optimizer; if (!optimize || BoundWindowExpression::GetSharedOrders(wexpr.orders, arg_orders) != arg_orders.size()) { - token_tree = - make_uniq(executor.context, arg_orders, executor.arg_order_idx, payload_count); + token_tree = make_uniq(client, arg_orders, executor.arg_order_idx, payload_count); } } } @@ -39,20 +38,20 @@ class WindowPeerGlobalState : public WindowExecutorGlobalState { // WindowPeerLocalState //===--------------------------------------------------------------------===// // Base class for non-aggregate functions that use peer boundaries -class WindowPeerLocalState : public WindowExecutorBoundsState { +class WindowPeerLocalState : public WindowExecutorBoundsLocalState { public: - explicit WindowPeerLocalState(const WindowPeerGlobalState &gpstate) - : WindowExecutorBoundsState(gpstate), gpstate(gpstate) { + WindowPeerLocalState(ExecutionContext &context, const WindowPeerGlobalState &gpstate) + : WindowExecutorBoundsLocalState(context, gpstate), gpstate(gpstate) { if (gpstate.token_tree) { - local_tree = gpstate.token_tree->GetLocalState(); + local_tree = gpstate.token_tree->GetLocalState(context); } } //! Accumulate the secondary sort values - void Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, - idx_t input_idx) override; + void Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, + DataChunk &coll_chunk, idx_t input_idx) override; //! 
Finish the sinking and prepare to scan - void Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) override; + void Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, CollectionPtr collection) override; void NextRank(idx_t partition_begin, idx_t peer_begin, idx_t row_idx); @@ -66,22 +65,23 @@ class WindowPeerLocalState : public WindowExecutorBoundsState { unique_ptr local_tree; }; -void WindowPeerLocalState::Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, - idx_t input_idx) { - WindowExecutorBoundsState::Sink(gstate, sink_chunk, coll_chunk, input_idx); +void WindowPeerLocalState::Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, + DataChunk &coll_chunk, idx_t input_idx) { + WindowExecutorBoundsLocalState::Sink(context, gstate, sink_chunk, coll_chunk, input_idx); if (local_tree) { auto &local_tokens = local_tree->Cast(); - local_tokens.SinkChunk(sink_chunk, input_idx, nullptr, 0); + local_tokens.Sink(context, sink_chunk, input_idx, nullptr, 0); } } -void WindowPeerLocalState::Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) { - WindowExecutorBoundsState::Finalize(gstate, collection); +void WindowPeerLocalState::Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, + CollectionPtr collection) { + WindowExecutorBoundsLocalState::Finalize(context, gstate, collection); if (local_tree) { auto &local_tokens = local_tree->Cast(); - local_tokens.Sort(); + local_tokens.Finalize(context); local_tokens.window_tree.Build(); } } @@ -102,35 +102,36 @@ void WindowPeerLocalState::NextRank(idx_t partition_begin, idx_t peer_begin, idx //===--------------------------------------------------------------------===// // WindowPeerExecutor //===--------------------------------------------------------------------===// -WindowPeerExecutor::WindowPeerExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowExecutor(wexpr, context, shared) { +WindowPeerExecutor::WindowPeerExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) + : WindowExecutor(wexpr, shared) { for (const auto &order : wexpr.arg_orders) { arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); } } -unique_ptr WindowPeerExecutor::GetGlobalState(const idx_t payload_count, +unique_ptr WindowPeerExecutor::GetGlobalState(ClientContext &client, + const idx_t payload_count, const ValidityMask &partition_mask, const ValidityMask &order_mask) const { - return make_uniq(*this, payload_count, partition_mask, order_mask); + return make_uniq(client, *this, payload_count, partition_mask, order_mask); } -unique_ptr WindowPeerExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { - return make_uniq(gstate.Cast()); +unique_ptr WindowPeerExecutor::GetLocalState(ExecutionContext &context, + const WindowExecutorGlobalState &gstate) const { + return make_uniq(context, gstate.Cast()); } //===--------------------------------------------------------------------===// // WindowRankExecutor //===--------------------------------------------------------------------===// -WindowRankExecutor::WindowRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowPeerExecutor(wexpr, context, shared) { +WindowRankExecutor::WindowRankExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) + : WindowPeerExecutor(wexpr, shared) { } -void 
WindowRankExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const { +void WindowRankExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, + idx_t count, idx_t row_idx) const { auto &gpeer = gstate.Cast(); auto &lpeer = lstate.Cast(); auto rdata = FlatVector::GetData(result); @@ -168,14 +169,13 @@ void WindowRankExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, Win //===--------------------------------------------------------------------===// // WindowDenseRankExecutor //===--------------------------------------------------------------------===// -WindowDenseRankExecutor::WindowDenseRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowPeerExecutor(wexpr, context, shared) { +WindowDenseRankExecutor::WindowDenseRankExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) + : WindowPeerExecutor(wexpr, shared) { } -void WindowDenseRankExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx) const { +void WindowDenseRankExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, + idx_t count, idx_t row_idx) const { auto &lpeer = lstate.Cast(); auto &order_mask = gstate.order_mask; @@ -231,9 +231,8 @@ void WindowDenseRankExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate //===--------------------------------------------------------------------===// // WindowPercentRankExecutor //===--------------------------------------------------------------------===// -WindowPercentRankExecutor::WindowPercentRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowPeerExecutor(wexpr, context, shared) { +WindowPercentRankExecutor::WindowPercentRankExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) + : WindowPeerExecutor(wexpr, shared) { } static inline double PercentRank(const idx_t begin, const idx_t end, const uint64_t rank) { @@ -241,9 +240,9 @@ static inline double PercentRank(const idx_t begin, const idx_t end, const uint6 return denom > 0 ? 
((double)rank - 1) / denom : 0; } -void WindowPercentRankExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx) const { +void WindowPercentRankExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, + Vector &result, idx_t count, idx_t row_idx) const { auto &gpeer = gstate.Cast(); auto &lpeer = lstate.Cast(); auto rdata = FlatVector::GetData(result); @@ -284,9 +283,8 @@ void WindowPercentRankExecutor::EvaluateInternal(WindowExecutorGlobalState &gsta //===--------------------------------------------------------------------===// // WindowCumeDistExecutor //===--------------------------------------------------------------------===// -WindowCumeDistExecutor::WindowCumeDistExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowPeerExecutor(wexpr, context, shared) { +WindowCumeDistExecutor::WindowCumeDistExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) + : WindowPeerExecutor(wexpr, shared) { } static inline double CumeDist(const idx_t begin, const idx_t end, const idx_t peer_end) { @@ -295,8 +293,9 @@ static inline double CumeDist(const idx_t begin, const idx_t end, const idx_t pe return denom > 0 ? (num / denom) : 0; } -void WindowCumeDistExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const { +void WindowCumeDistExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, + idx_t count, idx_t row_idx) const { auto &gpeer = gstate.Cast(); auto &lpeer = lstate.Cast(); auto rdata = FlatVector::GetData(result); diff --git a/src/duckdb/src/function/window/window_rownumber_function.cpp b/src/duckdb/src/function/window/window_rownumber_function.cpp index b96e818d4..de52242b5 100644 --- a/src/duckdb/src/function/window/window_rownumber_function.cpp +++ b/src/duckdb/src/function/window/window_rownumber_function.cpp @@ -10,9 +10,10 @@ namespace duckdb { //===--------------------------------------------------------------------===// class WindowRowNumberGlobalState : public WindowExecutorGlobalState { public: - WindowRowNumberGlobalState(const WindowRowNumberExecutor &executor, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask) - : WindowExecutorGlobalState(executor, payload_count, partition_mask, order_mask), + WindowRowNumberGlobalState(ClientContext &client, const WindowRowNumberExecutor &executor, + const idx_t payload_count, const ValidityMask &partition_mask, + const ValidityMask &order_mask) + : WindowExecutorGlobalState(client, executor, payload_count, partition_mask, order_mask), ntile_idx(executor.ntile_idx) { if (!executor.arg_order_idx.empty()) { use_framing = true; @@ -21,12 +22,12 @@ class WindowRowNumberGlobalState : public WindowExecutorGlobalState { // then we can just use the partition ordering. 
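
[For intuition: the two helpers above reduce to simple ratios. Their denominator lines fall outside the hunk context, so the worked values below assume the usual definitions, i.e. denom = end - begin - 1 for PERCENT_RANK, and num = peer_end - begin with denom = end - begin for CUME_DIST; the example function is not part of the patch.]

// A worked example under the assumptions stated above.
static void PeerRatioExample() {
	// PERCENT_RANK: rank 3 within the 4-row partition [0, 4)
	const double pr = PercentRank(0, 4, 3); // (3 - 1) / 3.0 == 0.666...
	// CUME_DIST: the peer group ends at row 3 of the same partition
	const double cd = CumeDist(0, 4, 3); // 3.0 / 4.0 == 0.75
	(void)pr;
	(void)cd;
}
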
auto &wexpr = executor.wexpr; auto &arg_orders = executor.wexpr.arg_orders; - const auto optimize = ClientConfig::GetConfig(executor.context).enable_optimizer; + const auto optimize = ClientConfig::GetConfig(client).enable_optimizer; if (!optimize || BoundWindowExpression::GetSharedOrders(wexpr.orders, arg_orders) != arg_orders.size()) { // "The ROW_NUMBER function can be computed by disambiguating duplicate elements based on their // position in the input data, such that two elements never compare as equal." - token_tree = make_uniq(executor.context, executor.wexpr.arg_orders, - executor.arg_order_idx, payload_count, true); + token_tree = make_uniq(client, executor.wexpr.arg_orders, executor.arg_order_idx, + payload_count, true); } } } @@ -44,20 +45,20 @@ class WindowRowNumberGlobalState : public WindowExecutorGlobalState { //===--------------------------------------------------------------------===// // WindowRowNumberLocalState //===--------------------------------------------------------------------===// -class WindowRowNumberLocalState : public WindowExecutorBoundsState { +class WindowRowNumberLocalState : public WindowExecutorBoundsLocalState { public: - explicit WindowRowNumberLocalState(const WindowRowNumberGlobalState &grstate) - : WindowExecutorBoundsState(grstate), grstate(grstate) { + explicit WindowRowNumberLocalState(ExecutionContext &context, const WindowRowNumberGlobalState &grstate) + : WindowExecutorBoundsLocalState(context, grstate), grstate(grstate) { if (grstate.token_tree) { - local_tree = grstate.token_tree->GetLocalState(); + local_tree = grstate.token_tree->GetLocalState(context); } } //! Accumulate the secondary sort values - void Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, - idx_t input_idx) override; + void Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, + DataChunk &coll_chunk, idx_t input_idx) override; //! Finish the sinking and prepare to scan - void Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) override; + void Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, CollectionPtr collection) override; //! 
The corresponding global peer state const WindowRowNumberGlobalState &grstate; @@ -65,22 +66,23 @@ class WindowRowNumberLocalState : public WindowExecutorBoundsState { unique_ptr local_tree; }; -void WindowRowNumberLocalState::Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, - idx_t input_idx) { - WindowExecutorBoundsState::Sink(gstate, sink_chunk, coll_chunk, input_idx); +void WindowRowNumberLocalState::Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, + DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx) { + WindowExecutorBoundsLocalState::Sink(context, gstate, sink_chunk, coll_chunk, input_idx); if (local_tree) { auto &local_tokens = local_tree->Cast(); - local_tokens.SinkChunk(sink_chunk, input_idx, nullptr, 0); + local_tokens.Sink(context, sink_chunk, input_idx, nullptr, 0); } } -void WindowRowNumberLocalState::Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) { - WindowExecutorBoundsState::Finalize(gstate, collection); +void WindowRowNumberLocalState::Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, + CollectionPtr collection) { + WindowExecutorBoundsLocalState::Finalize(context, gstate, collection); if (local_tree) { auto &local_tokens = local_tree->Cast(); - local_tokens.Sort(); + local_tokens.Finalize(context); local_tokens.window_tree.Build(); } } @@ -88,29 +90,29 @@ void WindowRowNumberLocalState::Finalize(WindowExecutorGlobalState &gstate, Coll //===--------------------------------------------------------------------===// // WindowRowNumberExecutor //===--------------------------------------------------------------------===// -WindowRowNumberExecutor::WindowRowNumberExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowExecutor(wexpr, context, shared) { +WindowRowNumberExecutor::WindowRowNumberExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) + : WindowExecutor(wexpr, shared) { for (const auto &order : wexpr.arg_orders) { arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); } } -unique_ptr WindowRowNumberExecutor::GetGlobalState(const idx_t payload_count, +unique_ptr WindowRowNumberExecutor::GetGlobalState(ClientContext &client, + const idx_t payload_count, const ValidityMask &partition_mask, const ValidityMask &order_mask) const { - return make_uniq(*this, payload_count, partition_mask, order_mask); + return make_uniq(client, *this, payload_count, partition_mask, order_mask); } unique_ptr -WindowRowNumberExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { - return make_uniq(gstate.Cast()); +WindowRowNumberExecutor::GetLocalState(ExecutionContext &context, const WindowExecutorGlobalState &gstate) const { + return make_uniq(context, gstate.Cast()); } -void WindowRowNumberExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx) const { +void WindowRowNumberExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, + idx_t count, idx_t row_idx) const { auto &grstate = gstate.Cast(); auto &lrstate = lstate.Cast(); auto rdata = FlatVector::GetData(result); @@ -140,16 +142,16 @@ void WindowRowNumberExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate //===--------------------------------------------------------------------===// // WindowNtileExecutor 
//===--------------------------------------------------------------------===// -WindowNtileExecutor::WindowNtileExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowRowNumberExecutor(wexpr, context, shared) { +WindowNtileExecutor::WindowNtileExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) + : WindowRowNumberExecutor(wexpr, shared) { // NTILE has one argument ntile_idx = shared.RegisterEvaluate(wexpr.children[0]); } -void WindowNtileExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const { +void WindowNtileExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, + idx_t count, idx_t row_idx) const { auto &grstate = gstate.Cast(); auto &lrstate = lstate.Cast(); auto partition_begin = FlatVector::GetData(lrstate.bounds.data[PARTITION_BEGIN]); diff --git a/src/duckdb/src/function/window/window_segment_tree.cpp b/src/duckdb/src/function/window/window_segment_tree.cpp index 9168a4f22..31d4f8597 100644 --- a/src/duckdb/src/function/window/window_segment_tree.cpp +++ b/src/duckdb/src/function/window/window_segment_tree.cpp @@ -131,23 +131,23 @@ class WindowSegmentTreePart { vector right_stack; }; -class WindowSegmentTreeState : public WindowAggregatorLocalState { +class WindowSegmentTreeLocalState : public WindowAggregatorLocalState { public: - WindowSegmentTreeState() { + WindowSegmentTreeLocalState() { } - void Finalize(WindowAggregatorGlobalState &gastate, CollectionPtr collection) override; - void Evaluate(const WindowSegmentTreeGlobalState &gsink, const DataChunk &bounds, Vector &result, idx_t count, - idx_t row_idx); + void Finalize(ExecutionContext &context, WindowAggregatorGlobalState &gastate, CollectionPtr collection) override; + void Evaluate(ExecutionContext &context, const WindowSegmentTreeGlobalState &gsink, const DataChunk &bounds, + Vector &result, idx_t count, idx_t row_idx); //! The left (default) segment tree part unique_ptr part; //! 
The right segment tree part (for EXCLUDE) unique_ptr<WindowSegmentTreePart> right_part; }; -void WindowSegmentTree::Finalize(WindowAggregatorState &gsink, WindowAggregatorState &lstate, CollectionPtr collection, - const FrameStats &stats) { - WindowAggregator::Finalize(gsink, lstate, collection, stats); +void WindowSegmentTree::Finalize(ExecutionContext &context, WindowAggregatorState &gsink, WindowAggregatorState &lstate, + CollectionPtr collection, const FrameStats &stats) { + WindowAggregator::Finalize(context, gsink, lstate, collection, stats); auto &gasink = gsink.Cast<WindowSegmentTreeGlobalState>(); ++gasink.finalized; @@ -188,7 +188,7 @@ unique_ptr<WindowAggregatorState> WindowSegmentTree::GetGlobalState(ClientContex } unique_ptr<WindowAggregatorState> WindowSegmentTree::GetLocalState(const WindowAggregatorState &gstate) const { - return make_uniq<WindowSegmentTreeState>(); + return make_uniq<WindowSegmentTreeLocalState>(); } void WindowSegmentTreePart::FlushStates(bool combining) { @@ -335,8 +335,9 @@ WindowSegmentTreeGlobalState::WindowSegmentTreeGlobalState(ClientContext &contex } } -void WindowSegmentTreeState::Finalize(WindowAggregatorGlobalState &gastate, CollectionPtr collection) { - WindowAggregatorLocalState::Finalize(gastate, collection); +void WindowSegmentTreeLocalState::Finalize(ExecutionContext &context, WindowAggregatorGlobalState &gastate, + CollectionPtr collection) { + WindowAggregatorLocalState::Finalize(context, gastate, collection); // Single part for constructing the tree auto &gstate = gastate.Cast<WindowSegmentTreeGlobalState>(); @@ -390,15 +391,16 @@ void WindowSegmentTreeState::Finalize(WindowAggregatorGlobalState &gastate, Coll } } -void WindowSegmentTree::Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, - const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const { +void WindowSegmentTree::Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, + WindowAggregatorState &lstate, const DataChunk &bounds, Vector &result, idx_t count, + idx_t row_idx) const { const auto &gtstate = gsink.Cast<WindowSegmentTreeGlobalState>(); - auto &ltstate = lstate.Cast<WindowSegmentTreeState>(); - ltstate.Evaluate(gtstate, bounds, result, count, row_idx); + auto &ltstate = lstate.Cast<WindowSegmentTreeLocalState>(); + ltstate.Evaluate(context, gtstate, bounds, result, count, row_idx); } -void WindowSegmentTreeState::Evaluate(const WindowSegmentTreeGlobalState &gtstate, const DataChunk &bounds, - Vector &result, idx_t count, idx_t row_idx) { +void WindowSegmentTreeLocalState::Evaluate(ExecutionContext &context, const WindowSegmentTreeGlobalState &gtstate, + const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) { auto window_begin = FlatVector::GetData<const idx_t>(bounds.data[FRAME_BEGIN]); auto window_end = FlatVector::GetData<const idx_t>(bounds.data[FRAME_END]); auto peer_begin = FlatVector::GetData<const idx_t>(bounds.data[PEER_BEGIN]); diff --git a/src/duckdb/src/function/window/window_token_tree.cpp b/src/duckdb/src/function/window/window_token_tree.cpp index c0a290aba..bcad348cf 100644 --- a/src/duckdb/src/function/window/window_token_tree.cpp +++ b/src/duckdb/src/function/window/window_token_tree.cpp @@ -4,8 +4,8 @@ namespace duckdb { class WindowTokenTreeLocalState : public WindowMergeSortTreeLocalState { public: - explicit WindowTokenTreeLocalState(WindowTokenTree &token_tree) - : WindowMergeSortTreeLocalState(token_tree), token_tree(token_tree) { + WindowTokenTreeLocalState(ExecutionContext &context, WindowTokenTree &token_tree) + : WindowMergeSortTreeLocalState(context, token_tree), token_tree(token_tree) { } //!
Process sorted leaf data void BuildLeaves() override; @@ -14,40 +36,36 @@ class WindowTokenTreeLocalState : public WindowMergeSortTreeLocalState { }; void WindowTokenTreeLocalState::BuildLeaves() { - auto &global_sort = *token_tree.global_sort; - if (global_sort.sorted_blocks.empty()) { - return; - } - - // Scan the sort keys and note deltas - SBIterator curr(global_sort, ExpressionType::COMPARE_LESSTHAN); - SBIterator prev(global_sort, ExpressionType::COMPARE_LESSTHAN); - const auto &sort_layout = global_sort.sort_layout; + // Find our chunk range + auto &collection = *token_tree.sorted; + const auto block_begin = (build_task * collection.ChunkCount()) / token_tree.total_tasks; + const auto block_end = ((build_task + 1) * collection.ChunkCount()) / token_tree.total_tasks; - const auto block_begin = token_tree.block_starts.at(build_task); - const auto block_end = token_tree.block_starts.at(build_task + 1); auto &deltas = token_tree.deltas; if (!block_begin) { // First block, so set up initial delta deltas[0] = 0; - } else { - // Move to the to end of the previous block - // so we can record the comparison result for the first row - curr.SetIndex(block_begin - 1); - prev.SetIndex(block_begin - 1); } - for (++curr; curr.GetIndex() < block_end; ++curr, ++prev) { - int lt = 0; - if (sort_layout.all_constant) { - lt = FastMemcmp(prev.entry_ptr, curr.entry_ptr, sort_layout.comparison_size); - } else { - lt = Comparators::CompareTuple(prev.scan, curr.scan, prev.entry_ptr, curr.entry_ptr, sort_layout, - prev.external); - } - - deltas[curr.GetIndex()] = (lt != 0); - } + const auto &scan_cols = token_tree.key_cols; + const auto key_count = scan_cols.size(); + WindowDeltaScanner(collection, block_begin, block_end, scan_cols, key_count, + [&](const idx_t row_idx, DataChunk &prev, DataChunk &curr, const idx_t ndistinct, + SelectionVector &distinct, const SelectionVector &matching) { + // Same as previous - token delta is 0 + const auto count = MinValue(prev.size(), curr.size()); + const auto nmatch = count - ndistinct; + for (idx_t j = 0; j < nmatch; ++j) { + auto scan_idx = matching.get_index(j); + deltas[scan_idx + row_idx] = 0; + } + + // Different value - token delta is 1 + for (idx_t j = 0; j < ndistinct; ++j) { + auto scan_idx = distinct.get_index(j); + deltas[scan_idx + row_idx] = 1; + } + }); } idx_t WindowTokenTree::MeasurePayloadBlocks() { @@ -60,18 +56,22 @@ template <typename T> static void BuildTokens(WindowTokenTree &token_tree, vector<T> &tokens) { - PayloadScanner scanner(*token_tree.global_sort); - DataChunk payload_chunk; - payload_chunk.Initialize(token_tree.context, token_tree.global_sort->payload_layout.GetTypes()); + auto &collection = *token_tree.sorted; + if (!collection.Count()) { + return; + } + // Scan the index column + vector<column_t> scan_ids(1, token_tree.scan_cols.size() - 1); + WindowCollectionChunkScanner scanner(collection, scan_ids, 0); + auto &payload_chunk = scanner.chunk; + const T *row_idx = nullptr; idx_t i = 0; T token = 0; for (auto &d : token_tree.deltas) { if (i >= payload_chunk.size()) { - payload_chunk.Reset(); - scanner.Scan(payload_chunk); - if (!payload_chunk.size()) { + if (!scanner.Scan()) { break; } row_idx = FlatVector::GetDataUnsafe<T>(payload_chunk.data[0]); @@ -83,23 +83,27 @@ static void BuildTokens(WindowTokenTree &token_tree, vector<T> &tokens) { } } -unique_ptr<WindowAggregatorState> WindowTokenTree::GetLocalState() { - return make_uniq<WindowTokenTreeLocalState>(*this); +unique_ptr<WindowAggregatorState> WindowTokenTree::GetLocalState(ExecutionContext &context) { + return make_uniq<WindowTokenTreeLocalState>(context,
*this); } -void WindowTokenTree::CleanupSort() { +void WindowTokenTree::Finished() { // Convert the deltas to tokens if (mst64) { BuildTokens(*this, mst64->LowestLevel()); } else { BuildTokens(*this, mst32->LowestLevel()); } - + /* + for (const auto &d : deltas) { + Printer::Print(StringUtil::Format("%lld", d)); + } + */ // Deallocate memory vector empty; deltas.swap(empty); - WindowMergeSortTree::CleanupSort(); + WindowMergeSortTree::Finished(); } template diff --git a/src/duckdb/src/function/window/window_value_function.cpp b/src/duckdb/src/function/window/window_value_function.cpp index 0aafbbc39..9e7f0f514 100644 --- a/src/duckdb/src/function/window/window_value_function.cpp +++ b/src/duckdb/src/function/window/window_value_function.cpp @@ -19,14 +19,14 @@ namespace duckdb { class WindowValueGlobalState : public WindowExecutorGlobalState { public: using WindowCollectionPtr = unique_ptr; - WindowValueGlobalState(const WindowValueExecutor &executor, const idx_t payload_count, + WindowValueGlobalState(ClientContext &client, const WindowValueExecutor &executor, const idx_t payload_count, const ValidityMask &partition_mask, const ValidityMask &order_mask) - : WindowExecutorGlobalState(executor, payload_count, partition_mask, order_mask), ignore_nulls(&all_valid), - child_idx(executor.child_idx) { + : WindowExecutorGlobalState(client, executor, payload_count, partition_mask, order_mask), + ignore_nulls(&all_valid), child_idx(executor.child_idx) { if (!executor.arg_order_idx.empty()) { - value_tree = make_uniq(executor.context, executor.wexpr.arg_orders, executor.arg_order_idx, - payload_count); + value_tree = + make_uniq(client, executor.wexpr.arg_orders, executor.arg_order_idx, payload_count); } } @@ -54,14 +54,14 @@ class WindowValueGlobalState : public WindowExecutorGlobalState { //===--------------------------------------------------------------------===// //! A class representing the state of the first_value, last_value and nth_value functions -class WindowValueLocalState : public WindowExecutorBoundsState { +class WindowValueLocalState : public WindowExecutorBoundsLocalState { public: - explicit WindowValueLocalState(const WindowValueGlobalState &gvstate) - : WindowExecutorBoundsState(gvstate), gvstate(gvstate) { + WindowValueLocalState(ExecutionContext &context, const WindowValueGlobalState &gvstate) + : WindowExecutorBoundsLocalState(context, gvstate), gvstate(gvstate) { WindowAggregatorLocalState::InitSubFrames(frames, gvstate.executor.wexpr.exclude_clause); if (gvstate.value_tree) { - local_value = gvstate.value_tree->GetLocalState(); + local_value = gvstate.value_tree->GetLocalState(context); if (gvstate.executor.IgnoreNulls()) { sort_nulls.Initialize(); } @@ -69,10 +69,10 @@ class WindowValueLocalState : public WindowExecutorBoundsState { } //! Accumulate the secondary sort values - void Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, - idx_t input_idx) override; + void Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, + DataChunk &coll_chunk, idx_t input_idx) override; //! Finish the sinking and prepare to scan - void Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) override; + void Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, CollectionPtr collection) override; //! 
The corresponding global value state const WindowValueGlobalState &gvstate; @@ -87,9 +87,9 @@ class WindowValueLocalState : public WindowExecutorBoundsState { unique_ptr cursor; }; -void WindowValueLocalState::Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, - idx_t input_idx) { - WindowExecutorBoundsState::Sink(gstate, sink_chunk, coll_chunk, input_idx); +void WindowValueLocalState::Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, + DataChunk &coll_chunk, idx_t input_idx) { + WindowExecutorBoundsLocalState::Sink(context, gstate, sink_chunk, coll_chunk, input_idx); if (local_value) { idx_t filtered = 0; @@ -114,16 +114,17 @@ void WindowValueLocalState::Sink(WindowExecutorGlobalState &gstate, DataChunk &s } auto &value_state = local_value->Cast(); - value_state.SinkChunk(sink_chunk, input_idx, filter_sel, filtered); + value_state.Sink(context, sink_chunk, input_idx, filter_sel, filtered); } } -void WindowValueLocalState::Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) { - WindowExecutorBoundsState::Finalize(gstate, collection); +void WindowValueLocalState::Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, + CollectionPtr collection) { + WindowExecutorBoundsLocalState::Finalize(context, gstate, collection); if (local_value) { auto &value_state = local_value->Cast(); - value_state.Sort(); + value_state.Finalize(context); value_state.index_tree.Build(); } @@ -136,9 +137,8 @@ void WindowValueLocalState::Finalize(WindowExecutorGlobalState &gstate, Collecti //===--------------------------------------------------------------------===// // WindowValueExecutor //===--------------------------------------------------------------------===// -WindowValueExecutor::WindowValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowExecutor(wexpr, context, shared) { +WindowValueExecutor::WindowValueExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared) + : WindowExecutor(wexpr, shared) { for (const auto &order : wexpr.arg_orders) { arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); @@ -157,23 +157,25 @@ WindowValueExecutor::WindowValueExecutor(BoundWindowExpression &wexpr, ClientCon default_idx = shared.RegisterEvaluate(wexpr.default_expr); } -unique_ptr WindowValueExecutor::GetGlobalState(const idx_t payload_count, +unique_ptr WindowValueExecutor::GetGlobalState(ClientContext &client, + const idx_t payload_count, const ValidityMask &partition_mask, const ValidityMask &order_mask) const { - return make_uniq(*this, payload_count, partition_mask, order_mask); + return make_uniq(client, *this, payload_count, partition_mask, order_mask); } -void WindowValueExecutor::Finalize(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - CollectionPtr collection) const { +void WindowValueExecutor::Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, CollectionPtr collection) const { auto &gvstate = gstate.Cast(); gvstate.Finalize(collection); - WindowExecutor::Finalize(gstate, lstate, collection); + WindowExecutor::Finalize(context, gstate, lstate, collection); } -unique_ptr WindowValueExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { +unique_ptr WindowValueExecutor::GetLocalState(ExecutionContext &context, + const WindowExecutorGlobalState &gstate) const { const auto &gvstate = gstate.Cast(); - return make_uniq(gvstate); + 
return make_uniq(context, gvstate); } //===--------------------------------------------------------------------===// @@ -196,9 +198,10 @@ unique_ptr WindowValueExecutor::GetLocalState(const Wi class WindowLeadLagGlobalState : public WindowValueGlobalState { public: - explicit WindowLeadLagGlobalState(const WindowValueExecutor &executor, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask) - : WindowValueGlobalState(executor, payload_count, partition_mask, order_mask) { + explicit WindowLeadLagGlobalState(ClientContext &client, const WindowValueExecutor &executor, + const idx_t payload_count, const ValidityMask &partition_mask, + const ValidityMask &order_mask) + : WindowValueGlobalState(client, executor, payload_count, partition_mask, order_mask) { if (value_tree) { use_framing = true; @@ -207,14 +210,13 @@ class WindowLeadLagGlobalState : public WindowValueGlobalState { // then we can just use the partition ordering. auto &wexpr = executor.wexpr; auto &arg_orders = executor.wexpr.arg_orders; - const auto optimize = ClientConfig::GetConfig(executor.context).enable_optimizer; + const auto optimize = ClientConfig::GetConfig(client).enable_optimizer; if (!optimize || BoundWindowExpression::GetSharedOrders(wexpr.orders, arg_orders) != arg_orders.size()) { // "The ROW_NUMBER function can be computed by disambiguating duplicate elements based on their // position in the input data, such that two elements never compare as equal." // Note: If the user specifies an partial secondary sort, the disambiguation will use the // partition's row numbers, not the secondary sort's row numbers. - row_tree = make_uniq(executor.context, arg_orders, executor.arg_order_idx, - payload_count, true); + row_tree = make_uniq(client, arg_orders, executor.arg_order_idx, payload_count, true); } else { // The value_tree is cheap to construct, so we just get rid of it if we now discover we don't need it. value_tree.reset(); @@ -234,41 +236,43 @@ class WindowLeadLagGlobalState : public WindowValueGlobalState { //===--------------------------------------------------------------------===// class WindowLeadLagLocalState : public WindowValueLocalState { public: - explicit WindowLeadLagLocalState(const WindowLeadLagGlobalState &gstate) : WindowValueLocalState(gstate) { + explicit WindowLeadLagLocalState(ExecutionContext &context, const WindowLeadLagGlobalState &gstate) + : WindowValueLocalState(context, gstate) { if (gstate.row_tree) { - local_row = gstate.row_tree->GetLocalState(); + local_row = gstate.row_tree->GetLocalState(context); } } //! Accumulate the secondary sort values - void Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, - idx_t input_idx) override; + void Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, + DataChunk &coll_chunk, idx_t input_idx) override; //! Finish the sinking and prepare to scan - void Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) override; + void Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, CollectionPtr collection) override; //! 
 	//! The optional sorting state for the secondary sort row mapping
 	unique_ptr<WindowAggregatorState> local_row;
 };
 
-void WindowLeadLagLocalState::Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk,
-                                   idx_t input_idx) {
-	WindowValueLocalState::Sink(gstate, sink_chunk, coll_chunk, input_idx);
+void WindowLeadLagLocalState::Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, DataChunk &sink_chunk,
+                                   DataChunk &coll_chunk, idx_t input_idx) {
+	WindowValueLocalState::Sink(context, gstate, sink_chunk, coll_chunk, input_idx);
 	if (local_row) {
 		idx_t filtered = 0;
 		optional_ptr<SelectionVector> filter_sel;
 		auto &row_state = local_row->Cast<WindowMergeSortTreeLocalState>();
-		row_state.SinkChunk(sink_chunk, input_idx, filter_sel, filtered);
+		row_state.Sink(context, sink_chunk, input_idx, filter_sel, filtered);
 	}
 }
 
-void WindowLeadLagLocalState::Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) {
-	WindowValueLocalState::Finalize(gstate, collection);
+void WindowLeadLagLocalState::Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate,
+                                       CollectionPtr collection) {
+	WindowValueLocalState::Finalize(context, gstate, collection);
 	if (local_row) {
 		auto &row_state = local_row->Cast<WindowMergeSortTreeLocalState>();
-		row_state.Sort();
+		row_state.Finalize(context);
 		row_state.window_tree.Build();
 	}
 }
@@ -276,25 +280,26 @@ void WindowLeadLagLocalState::Finalize(WindowExecutorGlobalState &gstate, Collec
 //===--------------------------------------------------------------------===//
 // WindowLeadLagExecutor
 //===--------------------------------------------------------------------===//
-WindowLeadLagExecutor::WindowLeadLagExecutor(BoundWindowExpression &wexpr, ClientContext &context,
-                                             WindowSharedExpressions &shared)
-    : WindowValueExecutor(wexpr, context, shared) {
+WindowLeadLagExecutor::WindowLeadLagExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared)
+    : WindowValueExecutor(wexpr, shared) {
 }
 
-unique_ptr<WindowExecutorGlobalState> WindowLeadLagExecutor::GetGlobalState(const idx_t payload_count,
+unique_ptr<WindowExecutorGlobalState> WindowLeadLagExecutor::GetGlobalState(ClientContext &client,
+                                                                            const idx_t payload_count,
                                                                             const ValidityMask &partition_mask,
                                                                             const ValidityMask &order_mask) const {
-	return make_uniq<WindowLeadLagGlobalState>(*this, payload_count, partition_mask, order_mask);
+	return make_uniq<WindowLeadLagGlobalState>(client, *this, payload_count, partition_mask, order_mask);
 }
 
 unique_ptr<WindowExecutorLocalState>
-WindowLeadLagExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const {
+WindowLeadLagExecutor::GetLocalState(ExecutionContext &context, const WindowExecutorGlobalState &gstate) const {
 	const auto &glstate = gstate.Cast<WindowLeadLagGlobalState>();
-	return make_uniq<WindowLeadLagLocalState>(glstate);
+	return make_uniq<WindowLeadLagLocalState>(context, glstate);
 }
 
-void WindowLeadLagExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate,
-                                             DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const {
+void WindowLeadLagExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate,
+                                             WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result,
+                                             idx_t count, idx_t row_idx) const {
 	auto &glstate = gstate.Cast<WindowLeadLagGlobalState>();
 	auto &llstate = lstate.Cast<WindowLeadLagLocalState>();
 	auto &cursor = *llstate.cursor;
@@ -445,14 +450,13 @@ void WindowLeadLagExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate,
 	}
 }
 
-WindowFirstValueExecutor::WindowFirstValueExecutor(BoundWindowExpression &wexpr, ClientContext &context,
-                                                   WindowSharedExpressions &shared)
-    : WindowValueExecutor(wexpr, context, shared) {
+WindowFirstValueExecutor::WindowFirstValueExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared)
+    : WindowValueExecutor(wexpr, shared) {
 }
 
-void WindowFirstValueExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate,
-                                                DataChunk &eval_chunk, Vector &result, idx_t count,
-                                                idx_t row_idx) const {
+void WindowFirstValueExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate,
+                                                WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result,
+                                                idx_t count, idx_t row_idx) const {
 	auto &gvstate = gstate.Cast<WindowValueGlobalState>();
 	auto &lvstate = lstate.Cast<WindowValueLocalState>();
 	auto &cursor = *lvstate.cursor;
@@ -496,14 +500,13 @@ void WindowFirstValueExecutor::EvaluateInternal(WindowExecutorGlobalStat
 	});
 }
 
-WindowLastValueExecutor::WindowLastValueExecutor(BoundWindowExpression &wexpr, ClientContext &context,
-                                                 WindowSharedExpressions &shared)
-    : WindowValueExecutor(wexpr, context, shared) {
+WindowLastValueExecutor::WindowLastValueExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared)
+    : WindowValueExecutor(wexpr, shared) {
 }
 
-void WindowLastValueExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate,
-                                               DataChunk &eval_chunk, Vector &result, idx_t count,
-                                               idx_t row_idx) const {
+void WindowLastValueExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate,
+                                               WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result,
+                                               idx_t count, idx_t row_idx) const {
 	auto &gvstate = gstate.Cast<WindowValueGlobalState>();
 	auto &lvstate = lstate.Cast<WindowValueLocalState>();
 	auto &cursor = *lvstate.cursor;
@@ -557,13 +560,13 @@ void WindowLastValueExecutor::EvaluateInternal(WindowExecutorGlobalState
 	});
 }
 
-WindowNthValueExecutor::WindowNthValueExecutor(BoundWindowExpression &wexpr, ClientContext &context,
-                                               WindowSharedExpressions &shared)
-    : WindowValueExecutor(wexpr, context, shared) {
+WindowNthValueExecutor::WindowNthValueExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared)
+    : WindowValueExecutor(wexpr, shared) {
 }
 
-void WindowNthValueExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate,
-                                              DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const {
+void WindowNthValueExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate,
+                                              WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result,
+                                              idx_t count, idx_t row_idx) const {
 	auto &gvstate = gstate.Cast<WindowValueGlobalState>();
 	auto &lvstate = lstate.Cast<WindowValueLocalState>();
 	auto &cursor = *lvstate.cursor;
@@ -845,9 +848,8 @@ static fill_value_t GetFillValueFunction(const LogicalType &type) {
 	}
 }
 
-WindowFillExecutor::WindowFillExecutor(BoundWindowExpression &wexpr, ClientContext &context,
-                                       WindowSharedExpressions &shared)
-    : WindowValueExecutor(wexpr, context, shared) {
+WindowFillExecutor::WindowFillExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared)
+    : WindowValueExecutor(wexpr, shared) {
 
 	// We need the sort values for interpolation, so either use the range or the secondary ordering expression
 	if (arg_order_idx.empty()) {
@@ -877,9 +879,10 @@ static void WindowFillCopy(WindowCursor &cursor, Vector &result, idx_t count, id
 
 class WindowFillGlobalState : public WindowLeadLagGlobalState {
 public:
-	explicit WindowFillGlobalState(const WindowFillExecutor &executor, const idx_t payload_count,
+	explicit WindowFillGlobalState(ClientContext &client, const WindowFillExecutor &executor, const idx_t payload_count,
 	                               const ValidityMask &partition_mask, const ValidityMask &order_mask)
-	    : WindowLeadLagGlobalState(executor, payload_count, partition_mask, order_mask), order_idx(executor.order_idx) {
+	    : WindowLeadLagGlobalState(client, executor, payload_count, partition_mask, order_mask),
+	      order_idx(executor.order_idx) {
 	}
 
 	//! Collection index of the secondary sort values
@@ -888,18 +891,20 @@ class WindowFillGlobalState : public WindowLeadLagGlobalState {
 
 class WindowFillLocalState : public WindowLeadLagLocalState {
 public:
-	explicit WindowFillLocalState(const WindowLeadLagGlobalState &gvstate) : WindowLeadLagLocalState(gvstate) {
+	WindowFillLocalState(ExecutionContext &context, const WindowLeadLagGlobalState &gvstate)
+	    : WindowLeadLagLocalState(context, gvstate) {
 	}
 
 	//! Finish the sinking and prepare to scan
-	void Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) override;
+	void Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, CollectionPtr collection) override;
 
 	//! Cursor for the secondary sort values
	unique_ptr<WindowCursor> order_cursor;
 };
 
-void WindowFillLocalState::Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) {
-	WindowLeadLagLocalState::Finalize(gstate, collection);
+void WindowFillLocalState::Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate,
+                                    CollectionPtr collection) {
+	WindowLeadLagLocalState::Finalize(context, gstate, collection);
 
 	// Prepare to scan
 	auto &gfstate = gvstate.Cast<WindowFillGlobalState>();
@@ -908,19 +913,22 @@ void WindowFillLocalState::Finalize(WindowExecutorGlobalState &gstate, Collectio
 	}
 }
 
-unique_ptr<WindowExecutorGlobalState> WindowFillExecutor::GetGlobalState(const idx_t payload_count,
+unique_ptr<WindowExecutorGlobalState> WindowFillExecutor::GetGlobalState(ClientContext &client,
+                                                                         const idx_t payload_count,
                                                                          const ValidityMask &partition_mask,
                                                                          const ValidityMask &order_mask) const {
-	return make_uniq<WindowFillGlobalState>(*this, payload_count, partition_mask, order_mask);
+	return make_uniq<WindowFillGlobalState>(client, *this, payload_count, partition_mask, order_mask);
 }
 
-unique_ptr<WindowExecutorLocalState> WindowFillExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const {
+unique_ptr<WindowExecutorLocalState> WindowFillExecutor::GetLocalState(ExecutionContext &context,
+                                                                       const WindowExecutorGlobalState &gstate) const {
 	const auto &gfstate = gstate.Cast<WindowFillGlobalState>();
-	return make_uniq<WindowFillLocalState>(gfstate);
+	return make_uniq<WindowFillLocalState>(context, gfstate);
 }
 
-void WindowFillExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate,
-                                          DataChunk &, Vector &result, idx_t count, idx_t row_idx) const {
+void WindowFillExecutor::EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate,
+                                          WindowExecutorLocalState &lstate, DataChunk &, Vector &result, idx_t count,
+                                          idx_t row_idx) const {
 	auto &lfstate = lstate.Cast<WindowFillLocalState>();
 	auto &cursor = *lfstate.cursor;
diff --git a/src/duckdb/src/include/duckdb.h b/src/duckdb/src/include/duckdb.h
index 2737575c5..14ce20012 100644
--- a/src/duckdb/src/include/duckdb.h
+++ b/src/duckdb/src/include/duckdb.h
@@ -132,14 +132,16 @@ typedef enum DUCKDB_TYPE {
 	DUCKDB_TYPE_TIMESTAMP_TZ = 31,
 	// enum type, only useful as logical type
 	DUCKDB_TYPE_ANY = 34,
-	// duckdb_varint
-	DUCKDB_TYPE_VARINT = 35,
+	// duckdb_bignum
+	DUCKDB_TYPE_BIGNUM = 35,
 	// enum type, only useful as logical type
 	DUCKDB_TYPE_SQLNULL = 36,
 	// enum type, only useful as logical type
 	DUCKDB_TYPE_STRING_LITERAL = 37,
 	// enum type, only useful as logical type
 	DUCKDB_TYPE_INTEGER_LITERAL = 38,
+	// duckdb_time_ns (nanoseconds)
+	DUCKDB_TYPE_TIME_NS = 39,
 } duckdb_type;
 
 //! An enum over the returned state of different functions.
@@ -290,6 +292,11 @@ typedef struct {
 	int32_t micros;
 } duckdb_time_struct;
 
+//! TIME_NS is stored as nanoseconds since 00:00:00.
+typedef struct {
+	int64_t nanos;
+} duckdb_time_ns;
+
 //! TIME_TZ is stored as 40 bits for the int64_t microseconds, and 24 bits for the int32_t offset.
 //! Use the `duckdb_from_time_tz` function to extract individual information.
 typedef struct {
@@ -448,14 +455,14 @@ typedef struct {
 	idx_t size;
 } duckdb_bit;
 
-//! VARINTs are composed of a byte pointer, a size, and an `is_negative` bool.
+//! BIGNUMs are composed of a byte pointer, a size, and an `is_negative` bool.
 //! The absolute value of the number is stored in `data` in little endian format.
 //! You must free `data` with `duckdb_free`.
 typedef struct {
 	uint8_t *data;
 	idx_t size;
 	bool is_negative;
-} duckdb_varint;
+} duckdb_bignum;
 
 //! A query result consists of a pointer to its internal data.
 //! Must be freed with 'duckdb_destroy_result'.
@@ -564,6 +571,12 @@ typedef struct _duckdb_error_data {
 	void *internal_ptr;
 } * duckdb_error_data;
 
+//! Holds a bound expression.
+//! Must be destroyed with `duckdb_destroy_expression`.
+typedef struct _duckdb_expression {
+	void *internal_ptr;
+} * duckdb_expression;
+
 //===--------------------------------------------------------------------===//
 // C API extension information
 //===--------------------------------------------------------------------===//
@@ -702,6 +715,14 @@ typedef void (*duckdb_replacement_callback_t)(duckdb_replacement_scan_info info,
 // Arrow-related types
 //===--------------------------------------------------------------------===//
 
+//! Forward declare Arrow structs.
+//! It is important to notice that these structs are not defined by DuckDB but are actually Arrow external objects.
+//! They're defined by the C Data Interface Arrow spec: https://arrow.apache.org/docs/format/CDataInterface.html
+
+struct ArrowArray;
+
+struct ArrowSchema;
+
 //! Holds an arrow query result. Must be destroyed with `duckdb_destroy_arrow`.
 typedef struct _duckdb_arrow {
 	void *internal_ptr;
 } * duckdb_arrow;
@@ -717,11 +738,25 @@ typedef struct _duckdb_arrow_schema {
 	void *internal_ptr;
 } * duckdb_arrow_schema;
 
-//! Holds an arrow array. Remember to release the respective ArrowArray object.
+//! Holds an arrow converted schema (i.e., duckdb::ArrowTableSchema).
+//! In practice, this object holds the information necessary to do proper conversion between Arrow types and DuckDB
+//! types. Check duckdb/function/table/arrow/arrow_duck_schema.hpp for more details! Must be destroyed with
+//! `duckdb_destroy_arrow_converted_schema`.
+typedef struct _duckdb_arrow_converted_schema {
+	void *internal_ptr;
+} * duckdb_arrow_converted_schema;
+
+//! Holds an arrow array. Remember to release the respective ArrowArray object.
 typedef struct _duckdb_arrow_array {
 	void *internal_ptr;
 } * duckdb_arrow_array;
 
+//! The arrow options used when transforming DuckDB schemas and data chunks into Arrow schemas and arrays.
+//! Used in `duckdb_to_arrow_schema` and `duckdb_data_chunk_to_arrow`.
+typedef struct _duckdb_arrow_options {
+	void *internal_ptr;
+} * duckdb_arrow_options;
+
 //===--------------------------------------------------------------------===//
 // DuckDB extension access
 //===--------------------------------------------------------------------===//
@@ -855,6 +890,14 @@ Retrieves the client context of the connection.
 DUCKDB_C_API void duckdb_connection_get_client_context(duckdb_connection connection,
                                                        duckdb_client_context *out_context);
 
+/*!
+Retrieves the arrow options of the connection.
+
+* @param connection The connection.
+* @param out_arrow_options The connection's arrow options. Must be destroyed with `duckdb_destroy_arrow_options`.
+*/
+DUCKDB_C_API void duckdb_connection_get_arrow_options(duckdb_connection connection,
+                                                      duckdb_arrow_options *out_arrow_options);
+
 /*!
 Returns the connection id of the client context.
 
@@ -870,6 +913,13 @@ Destroys the client context and deallocates its memory.
 */
 DUCKDB_C_API void duckdb_destroy_client_context(duckdb_client_context *context);
 
+/*!
+Destroys the arrow options and deallocates its memory.
+
+* @param arrow_options The arrow options to destroy.
+*/
+DUCKDB_C_API void duckdb_destroy_arrow_options(duckdb_arrow_options *arrow_options);
+
 /*!
 Returns the version of the linked DuckDB, with a version postfix for dev versions
 
@@ -1070,6 +1120,15 @@ Returns `NULL` if the column is out of range.
 */
 DUCKDB_C_API duckdb_logical_type duckdb_column_logical_type(duckdb_result *result, idx_t col);
 
+/*!
+Returns the arrow options associated with the given result. These options define how Arrow arrays/schemas
+should be produced.
+* @param result The result object to fetch arrow options from.
+* @return The arrow options associated with the given result. This must be destroyed with
+`duckdb_destroy_arrow_options`.
+*/
+DUCKDB_C_API duckdb_arrow_options duckdb_result_get_arrow_options(duckdb_result *result);
+
 /*!
 Returns the number of columns present in the result object.
 
@@ -2205,12 +2264,12 @@ Creates a value from a uhugeint
 
 DUCKDB_C_API duckdb_value duckdb_create_uhugeint(duckdb_uhugeint input);
 
 /*!
-Creates a VARINT value from a duckdb_varint
+Creates a BIGNUM value from a duckdb_bignum
 
-* @param input The duckdb_varint value
+* @param input The duckdb_bignum value
 * @return The value. This must be destroyed with `duckdb_destroy_value`.
 */
-DUCKDB_C_API duckdb_value duckdb_create_varint(duckdb_varint input);
+DUCKDB_C_API duckdb_value duckdb_create_bignum(duckdb_bignum input);
 
 /*!
 Creates a DECIMAL value from a duckdb_decimal
@@ -2252,6 +2311,14 @@ Creates a value from a time
 */
 DUCKDB_C_API duckdb_value duckdb_create_time(duckdb_time input);
 
+/*!
+Creates a value from a time_ns
+
+* @param input The time_ns value
+* @return The value. This must be destroyed with `duckdb_destroy_value`.
+*/
+DUCKDB_C_API duckdb_value duckdb_create_time_ns(duckdb_time_ns input);
+
 /*!
 Creates a value from a time_tz.
 Not to be confused with `duckdb_create_time_tz`, which creates a duckdb_time_tz_t.
@@ -2423,13 +2490,13 @@ Returns the uhugeint value of the given value.
 DUCKDB_C_API duckdb_uhugeint duckdb_get_uhugeint(duckdb_value val);
 
 /*!
-Returns the duckdb_varint value of the given value.
+Returns the duckdb_bignum value of the given value.
 The `data` field must be destroyed with `duckdb_free`.
 
-* @param val A duckdb_value containing a VARINT
-* @return A duckdb_varint. The `data` field must be destroyed with `duckdb_free`.
+* @param val A duckdb_value containing a BIGNUM
+* @return A duckdb_bignum. The `data` field must be destroyed with `duckdb_free`.
 */
-DUCKDB_C_API duckdb_varint duckdb_get_varint(duckdb_value val);
+DUCKDB_C_API duckdb_bignum duckdb_get_bignum(duckdb_value val);
 
 /*!
 Returns the duckdb_decimal value of the given value.
@@ -2471,6 +2538,14 @@ Returns the time value of the given value.
 */
 DUCKDB_C_API duckdb_time duckdb_get_time(duckdb_value val);
 
+/*!
+Returns the time_ns value of the given value.
+
+* @param val A duckdb_value containing a time_ns
+* @return A duckdb_time_ns, or MinValue if the value cannot be converted
+*/
+DUCKDB_C_API duckdb_time_ns duckdb_get_time_ns(duckdb_value val);
+
 /*!
 Returns the time_tz value of the given value.
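// Illustrative sketch (not part of this patch): round-tripping the new TIME_NS and
// BIGNUM values through the C API declared in the hunks above. Assumes a linked
// duckdb library; only functions declared in this header are used.
static void time_ns_and_bignum_roundtrip(void) {
	duckdb_time_ns t = {42000000000}; // 00:00:42, stored as nanoseconds since 00:00:00
	duckdb_value tv = duckdb_create_time_ns(t);
	duckdb_time_ns t2 = duckdb_get_time_ns(tv); // t2.nanos == t.nanos
	duckdb_destroy_value(&tv);

	uint8_t digits[] = {200};
	duckdb_bignum b = {digits, 1, false}; // +200: little-endian magnitude plus sign flag
	duckdb_value bv = duckdb_create_bignum(b);
	duckdb_bignum b2 = duckdb_get_bignum(bv);
	duckdb_free(b2.data); // the `data` field of the returned struct must be freed by the caller
	duckdb_destroy_value(&bv);
	(void)t2;
}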
@@ -3525,6 +3600,23 @@ If the set is incomplete or a function with this name already exists
 DuckDBError
 */
 DUCKDB_C_API duckdb_state duckdb_register_scalar_function_set(duckdb_connection con, duckdb_scalar_function_set set);
 
+/*!
+Returns the number of input arguments of the scalar function.
+
+* @param info The bind info.
+* @return The number of input arguments.
+*/
+DUCKDB_C_API idx_t duckdb_scalar_function_bind_get_argument_count(duckdb_bind_info info);
+
+/*!
+Returns the input argument at index of the scalar function.
+
+* @param info The bind info.
+* @param index The argument index.
+* @return The input argument at index. Must be destroyed with `duckdb_destroy_expression`.
+*/
+DUCKDB_C_API duckdb_expression duckdb_scalar_function_bind_get_argument(duckdb_bind_info info, idx_t index);
+
 //===--------------------------------------------------------------------===//
 // Selection Vector Interface
 //===--------------------------------------------------------------------===//
@@ -3840,6 +3932,14 @@ Retrieves the extra info of the function as set in `duckdb_table_function_set_ex
 */
 DUCKDB_C_API void *duckdb_bind_get_extra_info(duckdb_bind_info info);
 
+/*!
+Retrieves the client context of the bind info of a table function.
+
+* @param info The bind info object of the table function.
+* @param out_context The client context of the bind info. Must be destroyed with `duckdb_destroy_client_context`.
+*/
+DUCKDB_C_API void duckdb_table_function_get_client_context(duckdb_bind_info info, duckdb_client_context *out_context);
+
 /*!
 Adds a result column to the output of the table function.
 
@@ -4485,6 +4585,66 @@ DUCKDB_C_API char *duckdb_table_description_get_column_name(duckdb_table_descrip
 // Arrow Interface
 //===--------------------------------------------------------------------===//
 
+/*!
+Transforms a DuckDB Schema into an Arrow Schema.
+
+* @param arrow_options The Arrow settings used to produce arrow.
+* @param types The DuckDB logical types for each column in the schema.
+* @param names The names for each column in the schema.
+* @param column_count The number of columns that exist in the schema.
+* @param out_schema The resulting arrow schema. Must be destroyed with `out_schema->release(out_schema)`.
+* @return The error data. Must be destroyed with `duckdb_destroy_error_data`.
+*/
+DUCKDB_C_API duckdb_error_data duckdb_to_arrow_schema(duckdb_arrow_options arrow_options, duckdb_logical_type *types,
+                                                      const char **names, idx_t column_count,
+                                                      struct ArrowSchema *out_schema);
+
+/*!
+Transforms a DuckDB data chunk into an Arrow array.
+
+* @param arrow_options The Arrow settings used to produce arrow.
+* @param chunk The DuckDB data chunk to convert.
+* @param out_arrow_array The output Arrow structure that will hold the converted data. Must be released with
+`out_arrow_array->release(out_arrow_array)`.
+* @return The error data. Must be destroyed with `duckdb_destroy_error_data`.
+*/
+DUCKDB_C_API duckdb_error_data duckdb_data_chunk_to_arrow(duckdb_arrow_options arrow_options, duckdb_data_chunk chunk,
+                                                          struct ArrowArray *out_arrow_array);
+
+/*!
+Transforms an Arrow Schema into a DuckDB Schema.
+
+* @param connection The connection to get the transformation settings from.
+* @param schema The input Arrow schema. Must be released with `schema->release(schema)`.
+* @param out_types The Arrow converted schema with extra information about the arrow types. Must be destroyed with
+`duckdb_destroy_arrow_converted_schema`.
+* @return The error data. Must be destroyed with `duckdb_destroy_error_data`.
+*/
+DUCKDB_C_API duckdb_error_data duckdb_schema_from_arrow(duckdb_connection connection, struct ArrowSchema *schema,
+                                                        duckdb_arrow_converted_schema *out_types);
+
+/*!
+Transforms an Arrow array into a DuckDB data chunk. The data chunk will retain ownership of the underlying Arrow data.
+
+* @param connection The connection to get the transformation settings from.
+* @param arrow_array The input Arrow array. Data ownership is passed on to DuckDB's DataChunk; the input object
+does not need to be released and no longer owns the data.
+* @param converted_schema The Arrow converted schema with extra information about the arrow types.
+* @param out_chunk The resulting DuckDB data chunk. Must be destroyed by `duckdb_destroy_data_chunk`.
+* @return The error data. Must be destroyed with `duckdb_destroy_error_data`.
+*/
+DUCKDB_C_API duckdb_error_data duckdb_data_chunk_from_arrow(duckdb_connection connection,
+                                                            struct ArrowArray *arrow_array,
+                                                            duckdb_arrow_converted_schema converted_schema,
+                                                            duckdb_data_chunk *out_chunk);
+
+/*!
+Destroys the arrow converted schema and de-allocates all memory allocated for it.
+
+* @param arrow_converted_schema The arrow converted schema to destroy.
+*/
+DUCKDB_C_API void duckdb_destroy_arrow_converted_schema(duckdb_arrow_converted_schema *arrow_converted_schema);
+
 #ifndef DUCKDB_API_NO_DEPRECATED
 /*!
 **DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
@@ -4889,6 +5049,44 @@ Destroys the cast function object.
 */
 DUCKDB_C_API void duckdb_destroy_cast_function(duckdb_cast_function *cast_function);
 
+//===--------------------------------------------------------------------===//
+// Expression Interface
+//===--------------------------------------------------------------------===//
+
+/*!
+Destroys the expression and de-allocates its memory.
+
+* @param expr A pointer to the expression.
+*/
+DUCKDB_C_API void duckdb_destroy_expression(duckdb_expression *expr);
+
+/*!
+Returns the return type of an expression.
+
+* @param expr The expression.
+* @return The return type. Must be destroyed with `duckdb_destroy_logical_type`.
+*/
+DUCKDB_C_API duckdb_logical_type duckdb_expression_return_type(duckdb_expression expr);
+
+/*!
+Returns whether the expression is foldable into a value or not.
+
+* @param expr The expression.
+* @return True, if the expression is foldable, else false.
+*/
+DUCKDB_C_API bool duckdb_expression_is_foldable(duckdb_expression expr);
+
+/*!
+Folds an expression, creating a folded value.
+
+* @param context The client context.
+* @param expr The expression. Must be foldable.
+* @param out_value The folded value, if folding was successful. Must be destroyed with `duckdb_destroy_value`.
+* @return The error data. Must be destroyed with `duckdb_destroy_error_data`.
+*/
+DUCKDB_C_API duckdb_error_data duckdb_expression_fold(duckdb_client_context context, duckdb_expression expr,
+                                                      duckdb_value *out_value);
+
 #endif
 
 #ifdef __cplusplus
diff --git a/src/duckdb/src/include/duckdb/catalog/catalog.hpp b/src/duckdb/src/include/duckdb/catalog/catalog.hpp
index c36d683e0..fcf76bbbc 100644
--- a/src/duckdb/src/include/duckdb/catalog/catalog.hpp
+++ b/src/duckdb/src/include/duckdb/catalog/catalog.hpp
@@ -44,6 +44,7 @@ struct MetadataBlockInfo;
 
 class AttachedDatabase;
 class ClientContext;
+class QueryContext;
 class Transaction;
 
 class AggregateFunctionCatalogEntry;
@@ -368,12 +369,14 @@ class Catalog {
 	                                 const string &catalog_name);
 
 	DUCKDB_API static vector<reference<SchemaCatalogEntry>> GetAllSchemas(ClientContext &context);
+	static vector<reference<CatalogEntry>> GetAllEntries(ClientContext &context, CatalogType catalog_type);
+
 	virtual void Verify();
 
 	static CatalogException UnrecognizedConfigurationError(ClientContext &context, const string &name);
 
 	//! Autoload the extension required for `configuration_name` or throw a CatalogException
-	static void AutoloadExtensionByConfigName(ClientContext &context, const string &configuration_name);
+	static string AutoloadExtensionByConfigName(ClientContext &context, const string &configuration_name);
 	//! Autoload the extension required for `function_name` or throw a CatalogException
 	static bool AutoLoadExtensionByCatalogEntry(DatabaseInstance &db, CatalogType type, const string &entry_name);
 	DUCKDB_API static bool TryAutoLoad(ClientContext &context, const string &extension_name) noexcept;
diff --git a/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp b/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp
index ca580544e..c52f2c4e6 100644
--- a/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp
+++ b/src/duckdb/src/include/duckdb/catalog/default/builtin_types/types.hpp
@@ -83,7 +83,7 @@ static constexpr const builtin_type_array BUILTIN_TYPES{{
 	{"union", LogicalTypeId::UNION},
 	{"bit", LogicalTypeId::BIT},
 	{"bitstring", LogicalTypeId::BIT},
-	{"varint", LogicalTypeId::VARINT},
+	{"bignum", LogicalTypeId::BIGNUM},
 	{"boolean", LogicalTypeId::BOOLEAN},
 	{"bool", LogicalTypeId::BOOLEAN},
 	{"logical", LogicalTypeId::BOOLEAN},
diff --git a/src/duckdb/src/include/duckdb/common/adbc/adbc.hpp b/src/duckdb/src/include/duckdb/common/adbc/adbc.hpp
index 365de229a..391744f48 100644
--- a/src/duckdb/src/include/duckdb/common/adbc/adbc.hpp
+++ b/src/duckdb/src/include/duckdb/common/adbc/adbc.hpp
@@ -10,10 +10,78 @@
 
 #include "duckdb/common/adbc/adbc.h"
 
+#include "duckdb/main/capi/capi_internal.hpp"
+
 #include <string>
 
 namespace duckdb_adbc {
 
+class AppenderWrapper {
+public:
+	AppenderWrapper(duckdb_connection conn, const char *schema, const char *table) : appender(nullptr) {
+		if (duckdb_appender_create(conn, schema, table, &appender) != DuckDBSuccess) {
+			appender = nullptr;
+		}
+	}
+	~AppenderWrapper() {
+		if (appender) {
+			duckdb_appender_destroy(&appender);
+		}
+	}
+
+	duckdb_appender Get() const {
+		return appender;
+	}
+	bool Valid() const {
+		return appender != nullptr;
+	}
+
+private:
+	duckdb_appender appender;
+};
+
+class DataChunkWrapper {
+public:
+	DataChunkWrapper() : chunk(nullptr) {
+	}
+
+	~DataChunkWrapper() {
+		if (chunk) {
+			duckdb_destroy_data_chunk(&chunk);
+		}
+	}
+
+	explicit operator duckdb_data_chunk() const {
+		return chunk;
+	}
+
+	duckdb_data_chunk chunk;
+};
+
+class ConvertedSchemaWrapper {
+public:
+	ConvertedSchemaWrapper() : schema(nullptr) {
+	}
+	~ConvertedSchemaWrapper() {
+		if (schema) {
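+			// The wrapper may never have been filled in via GetPtr(), so only destroy
+			// a converted schema that was actually produced.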
duckdb_destroy_arrow_converted_schema(&schema); + } + } + duckdb_arrow_converted_schema *GetPtr() { + return &schema; + } + + explicit operator duckdb_arrow_converted_schema() const { + return schema; + } + duckdb_arrow_converted_schema Get() const { + return schema; + } + +private: + duckdb_arrow_converted_schema schema; +}; + AdbcStatusCode DatabaseNew(struct AdbcDatabase *database, struct AdbcError *error); AdbcStatusCode DatabaseSetOption(struct AdbcDatabase *database, const char *key, const char *value, @@ -90,4 +158,3 @@ void InitializeADBCError(AdbcError *error); //! This method should only be called when the string is guaranteed to not be NULL void SetError(struct AdbcError *error, const std::string &message); -// void SetError(struct AdbcError *error, const char *message); diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/append_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/append_data.hpp index c136100ae..4a1e594d0 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/append_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/append_data.hpp @@ -42,6 +42,7 @@ struct ArrowAppendData { arrow_buffers.resize(3); } +public: //! Getters for the Buffers ArrowBuffer &GetValidityBuffer() { return arrow_buffers[0]; @@ -63,6 +64,36 @@ struct ArrowAppendData { return arrow_buffers[3]; } +public: + static void GetBitPosition(idx_t row_idx, idx_t ¤t_byte, uint8_t ¤t_bit) { + current_byte = row_idx / 8; + current_bit = row_idx % 8; + } + + static void UnsetBit(uint8_t *data, idx_t current_byte, uint8_t current_bit) { + data[current_byte] &= ~((uint64_t)1 << current_bit); + } + + static void NextBit(idx_t ¤t_byte, uint8_t ¤t_bit) { + current_bit++; + if (current_bit == 8) { + current_byte++; + current_bit = 0; + } + } + + static void ResizeValidity(ArrowBuffer &buffer, idx_t row_count) { + auto byte_count = (row_count + 7) / 8; + buffer.resize(byte_count, 0xFF); + } + + void SetNull(uint8_t *validity_data, idx_t current_byte, uint8_t current_bit) { + UnsetBit(validity_data, current_byte, current_bit); + null_count++; + } + void AppendValidity(UnifiedVectorFormat &format, idx_t from, idx_t to); + +public: idx_t row_count = 0; idx_t null_count = 0; @@ -93,59 +124,4 @@ struct ArrowAppendData { vector arrow_buffers; }; -//===--------------------------------------------------------------------===// -// Append Helper Functions -//===--------------------------------------------------------------------===// - -static void GetBitPosition(idx_t row_idx, idx_t ¤t_byte, uint8_t ¤t_bit) { - current_byte = row_idx / 8; - current_bit = row_idx % 8; -} - -static void UnsetBit(uint8_t *data, idx_t current_byte, uint8_t current_bit) { - data[current_byte] &= ~((uint64_t)1 << current_bit); -} - -static void NextBit(idx_t ¤t_byte, uint8_t ¤t_bit) { - current_bit++; - if (current_bit == 8) { - current_byte++; - current_bit = 0; - } -} - -static void ResizeValidity(ArrowBuffer &buffer, idx_t row_count) { - auto byte_count = (row_count + 7) / 8; - buffer.resize(byte_count, 0xFF); -} - -static void SetNull(ArrowAppendData &append_data, uint8_t *validity_data, idx_t current_byte, uint8_t current_bit) { - UnsetBit(validity_data, current_byte, current_bit); - append_data.null_count++; -} - -static void AppendValidity(ArrowAppendData &append_data, UnifiedVectorFormat &format, idx_t from, idx_t to) { - // resize the buffer, filling the validity buffer with all valid values - idx_t size = to - from; - ResizeValidity(append_data.GetValidityBuffer(), 
append_data.row_count + size); - if (format.validity.AllValid()) { - // if all values are valid we don't need to do anything else - return; - } - - // otherwise we iterate through the validity mask - auto validity_data = (uint8_t *)append_data.GetValidityBuffer().data(); - uint8_t current_bit; - idx_t current_byte; - GetBitPosition(append_data.row_count, current_byte, current_bit); - for (idx_t i = from; i < to; i++) { - auto source_idx = format.sel->get_index(i); - // append the validity mask - if (!format.validity.RowIsValid(source_idx)) { - SetNull(append_data, validity_data, current_byte, current_bit); - } - NextBit(current_byte, current_bit); - } -} - } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/enum_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/enum_data.hpp index a5d066abb..20cb1a1b0 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/enum_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/enum_data.hpp @@ -35,7 +35,7 @@ struct ArrowEnumData : public ArrowScalarBaseData { auto &main_buffer = append_data.GetMainBuffer(); auto &aux_buffer = append_data.GetAuxBuffer(); // resize the validity mask and set up the validity buffer for iteration - ResizeValidity(append_data.GetValidityBuffer(), append_data.row_count + size); + ArrowAppendData::ResizeValidity(append_data.GetValidityBuffer(), append_data.row_count + size); // resize the offset buffer - the offset buffer holds the offsets into the child array main_buffer.resize(main_buffer.size() + sizeof(int32_t) * (size + 1)); diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/list_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/list_data.hpp index 21274da48..627e5fbbc 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/list_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/list_data.hpp @@ -28,7 +28,7 @@ struct ArrowListData { input.ToUnifiedFormat(input_size, format); idx_t size = to - from; vector child_indices; - AppendValidity(append_data, format, from, to); + append_data.AppendValidity(format, from, to); AppendOffsets(append_data, format, from, to, child_indices); // append the child vector of the list diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/list_view_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/list_view_data.hpp index f46b316dd..f326b7648 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/list_view_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/list_view_data.hpp @@ -30,7 +30,7 @@ struct ArrowListViewData { input.ToUnifiedFormat(input_size, format); idx_t size = to - from; vector child_indices; - AppendValidity(append_data, format, from, to); + append_data.AppendValidity(format, from, to); AppendListMetadata(append_data, format, from, to, child_indices); // append the child vector of the list diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/map_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/map_data.hpp index 7f82f401f..b73b0016b 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/map_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/map_data.hpp @@ -39,7 +39,7 @@ struct ArrowMapData { UnifiedVectorFormat format; input.ToUnifiedFormat(input_size, format); idx_t size = to - from; - AppendValidity(append_data, format, from, to); + append_data.AppendValidity(format, from, to); vector child_indices; ArrowListData::AppendOffsets(append_data, format, from, 
to, child_indices); diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/scalar_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/scalar_data.hpp index a6729b460..e28c002ee 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/scalar_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/scalar_data.hpp @@ -96,7 +96,7 @@ struct ArrowScalarBaseData { input.ToUnifiedFormat(input_size, format); // append the validity mask - AppendValidity(append_data, format, from, to); + append_data.AppendValidity(format, from, to); // append the main data auto &main_buffer = append_data.GetMainBuffer(); diff --git a/src/duckdb/src/include/duckdb/common/arrow/appender/varchar_data.hpp b/src/duckdb/src/include/duckdb/common/arrow/appender/varchar_data.hpp index 45e9c14a9..87e45f5dc 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/appender/varchar_data.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/appender/varchar_data.hpp @@ -59,7 +59,7 @@ struct ArrowVarcharData { auto &aux_buffer = append_data.GetAuxBuffer(); // resize the validity mask and set up the validity buffer for iteration - ResizeValidity(validity_buffer, append_data.row_count + size); + ArrowAppendData::ResizeValidity(validity_buffer, append_data.row_count + size); auto validity_data = (uint8_t *)validity_buffer.data(); // resize the offset buffer - the offset buffer holds the offsets into the child array @@ -80,8 +80,8 @@ struct ArrowVarcharData { if (!format.validity.RowIsValid(source_idx)) { uint8_t current_bit; idx_t current_byte; - GetBitPosition(append_data.row_count + i - from, current_byte, current_bit); - SetNull(append_data, validity_data, current_byte, current_bit); + ArrowAppendData::GetBitPosition(append_data.row_count + i - from, current_byte, current_bit); + append_data.SetNull(validity_data, current_byte, current_bit); offset_data[offset_idx] = last_offset; continue; } @@ -141,7 +141,7 @@ struct ArrowVarcharToStringViewData { auto &validity_buffer = append_data.GetValidityBuffer(); auto &aux_buffer = append_data.GetAuxBuffer(); // resize the validity mask and set up the validity buffer for iteration - ResizeValidity(validity_buffer, append_data.row_count + size); + ArrowAppendData::ResizeValidity(validity_buffer, append_data.row_count + size); auto validity_data = (uint8_t *)validity_buffer.data(); main_buffer.resize(main_buffer.size() + sizeof(arrow_string_view_t) * (size)); @@ -155,8 +155,8 @@ struct ArrowVarcharToStringViewData { // Null value uint8_t current_bit; idx_t current_byte; - GetBitPosition(result_idx, current_byte, current_bit); - SetNull(append_data, validity_data, current_byte, current_bit); + ArrowAppendData::GetBitPosition(result_idx, current_byte, current_bit); + append_data.SetNull(validity_data, current_byte, current_bit); // We have to set these bytes to 0, for some reason arrow_data[result_idx] = arrow_string_view_t(0, ""); continue; diff --git a/src/duckdb/src/include/duckdb/common/arrow/arrow_wrapper.hpp b/src/duckdb/src/include/duckdb/common/arrow/arrow_wrapper.hpp index 0639da539..9d0fd49f5 100644 --- a/src/duckdb/src/include/duckdb/common/arrow/arrow_wrapper.hpp +++ b/src/duckdb/src/include/duckdb/common/arrow/arrow_wrapper.hpp @@ -33,6 +33,16 @@ class ArrowArrayWrapper { ArrowArrayWrapper(ArrowArrayWrapper &&other) noexcept : arrow_array(other.arrow_array) { other.arrow_array.release = nullptr; } + ArrowArrayWrapper &operator=(ArrowArrayWrapper &&other) noexcept { + if (this != &other) { + if (arrow_array.release) { + 
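+				// A non-null release callback means this wrapper still owns its Arrow buffers;
+				// release them before overwriting, otherwise they would leak.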
arrow_array.release(&arrow_array); + } + arrow_array = other.arrow_array; + other.arrow_array.release = nullptr; + } + return *this; + } ~ArrowArrayWrapper(); }; diff --git a/src/duckdb/src/include/duckdb/common/bignum.hpp b/src/duckdb/src/include/duckdb/common/bignum.hpp new file mode 100644 index 000000000..cefddf23e --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/bignum.hpp @@ -0,0 +1,85 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/bignum.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/winapi.hpp" +#include "duckdb/common/string.hpp" +#include +#include "duckdb/common/types/string_type.hpp" +#include "duckdb/storage/arena_allocator.hpp" + +namespace duckdb { + +struct bignum_t { + string_t data; + + bignum_t() : data() { + } + + explicit bignum_t(const string_t &data) : data(data) { + } + + bignum_t(const bignum_t &rhs) = default; + bignum_t(bignum_t &&other) = default; + bignum_t &operator=(const bignum_t &rhs) = default; + bignum_t &operator=(bignum_t &&rhs) = default; + + void Print() const; +}; + +enum AbsoluteNumberComparison : uint8_t { + // If number is equal + EQUAL = 0, + // If compared number is greater + GREATER = 1, + // If compared number is smaller + SMALLER = 2, +}; + +struct BignumIntermediate { + BignumIntermediate() : is_negative(false), size(0), data(nullptr) {}; + explicit BignumIntermediate(const bignum_t &value); + BignumIntermediate(uint8_t *value, idx_t size); + void Print() const; + //! Information on the header + bool is_negative; + uint32_t size; + //! The actual data + data_ptr_t data; + //! If the absolute number is bigger than the absolute rhs + //! 1 = true, 0 = equal, -1 = false + AbsoluteNumberComparison IsAbsoluteBigger(const BignumIntermediate &rhs) const; + //! Get the absolute value of a byte + uint8_t GetAbsoluteByte(int64_t index) const; + //! If the most significant bit of the first byte is set. + bool IsMSBSet() const; + //! Initializes our bignum to 0 and 1 byte + void Initialize(ArenaAllocator &allocator); + //! If necessary, we reallocate our intermediate to the next power of 2. + void Reallocate(ArenaAllocator &allocator, idx_t min_size); + static uint32_t GetStartDataPos(data_ptr_t data, idx_t size, bool is_negative); + uint32_t GetStartDataPos() const; + //! In case we have unnecessary extra 0's or 1's in our bignum we trim them + static idx_t Trim(data_ptr_t data, uint32_t &size, bool is_negative); + void Trim(); + //! Add a BignumIntermediate to another BignumIntermediate, equivalent of a += + void AddInPlace(ArenaAllocator &allocator, const BignumIntermediate &rhs); + //! Adds two BignumIntermediates and returns a string_t result, equivalent of a + + static string_t Add(Vector &result, const BignumIntermediate &lhs, const BignumIntermediate &rhs); + //! Negates a value, e.g., -x + string_t Negate(Vector &result_vector) const; + void NegateInPlace(); + //! Exports to a bignum, either arena allocated + bignum_t ToBignum(ArenaAllocator &allocator); + //! 
Check if an over/underflow has occurred + static bool OverOrUnderflow(data_ptr_t data, idx_t size, bool is_negative); + bool OverOrUnderflow() const; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/encryption_functions.hpp b/src/duckdb/src/include/duckdb/common/encryption_functions.hpp index a52d633e9..8106eee7f 100644 --- a/src/duckdb/src/include/duckdb/common/encryption_functions.hpp +++ b/src/duckdb/src/include/duckdb/common/encryption_functions.hpp @@ -4,10 +4,7 @@ #include "duckdb/common/types/value.hpp" #include "duckdb/common/encryption_state.hpp" #include "duckdb/common/encryption_key_manager.hpp" - -#ifndef DUCKDB_AMALGAMATION #include "duckdb/storage/object_cache.hpp" -#endif namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/common/encryption_key_manager.hpp b/src/duckdb/src/include/duckdb/common/encryption_key_manager.hpp index 24905fe29..55c3aed75 100644 --- a/src/duckdb/src/include/duckdb/common/encryption_key_manager.hpp +++ b/src/duckdb/src/include/duckdb/common/encryption_key_manager.hpp @@ -12,10 +12,7 @@ #include "duckdb/common/types/value.hpp" #include "duckdb/common/unordered_map.hpp" #include "duckdb/common/types.hpp" - -#ifndef DUCKDB_AMALGAMATION #include "duckdb/storage/object_cache.hpp" -#endif namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/common/enum_util.hpp b/src/duckdb/src/include/duckdb/common/enum_util.hpp index 98eb64bbc..173329238 100644 --- a/src/duckdb/src/include/duckdb/common/enum_util.hpp +++ b/src/duckdb/src/include/duckdb/common/enum_util.hpp @@ -64,8 +64,12 @@ enum class AlterViewType : uint8_t; enum class AppenderType : uint8_t; +enum class ArrowArrayPhysicalType : uint8_t; + enum class ArrowDateTimeType : uint8_t; +enum class ArrowFormatVersion : uint8_t; + enum class ArrowOffsetSize : uint8_t; enum class ArrowTypeInfoType : uint8_t; @@ -202,6 +206,8 @@ enum class HTTPStatusCode : uint16_t; enum class IndexAppendMode : uint8_t; +enum class IndexBindState : uint8_t; + enum class IndexConstraintType : uint8_t; enum class InsertColumnOrder : uint8_t; @@ -276,6 +282,8 @@ enum class OrderPreservationType : uint8_t; enum class OrderType : uint8_t; +enum class OrdinalityType : uint8_t; + enum class OutputStream : uint8_t; enum class ParseInfoType : uint8_t; @@ -430,6 +438,8 @@ enum class WindowBoundary : uint8_t; enum class WindowExcludeMode : uint8_t; +enum class WindowMergeSortStage : uint8_t; + template<> const char* EnumUtil::ToChars(ARTConflictType value); @@ -479,9 +489,15 @@ const char* EnumUtil::ToChars(AlterViewType value); template<> const char* EnumUtil::ToChars(AppenderType value); +template<> +const char* EnumUtil::ToChars(ArrowArrayPhysicalType value); + template<> const char* EnumUtil::ToChars(ArrowDateTimeType value); +template<> +const char* EnumUtil::ToChars(ArrowFormatVersion value); + template<> const char* EnumUtil::ToChars(ArrowOffsetSize value); @@ -686,6 +702,9 @@ const char* EnumUtil::ToChars(HTTPStatusCode value); template<> const char* EnumUtil::ToChars(IndexAppendMode value); +template<> +const char* EnumUtil::ToChars(IndexBindState value); + template<> const char* EnumUtil::ToChars(IndexConstraintType value); @@ -797,6 +816,9 @@ const char* EnumUtil::ToChars(OrderPreservationType value template<> const char* EnumUtil::ToChars(OrderType value); +template<> +const char* EnumUtil::ToChars(OrdinalityType value); + template<> const char* EnumUtil::ToChars(OutputStream value); @@ -1028,6 +1050,9 @@ const char* EnumUtil::ToChars(WindowBoundary value); template<> const char* 
EnumUtil::ToChars<WindowExcludeMode>(WindowExcludeMode value);
 
+template<>
+const char* EnumUtil::ToChars<WindowMergeSortStage>(WindowMergeSortStage value);
+
 template<>
 ARTConflictType EnumUtil::FromString<ARTConflictType>(const char *value);
 
@@ -1077,9 +1102,15 @@ AlterViewType EnumUtil::FromString(const char *value);
 template<>
 AppenderType EnumUtil::FromString<AppenderType>(const char *value);
 
+template<>
+ArrowArrayPhysicalType EnumUtil::FromString<ArrowArrayPhysicalType>(const char *value);
+
 template<>
 ArrowDateTimeType EnumUtil::FromString<ArrowDateTimeType>(const char *value);
 
+template<>
+ArrowFormatVersion EnumUtil::FromString<ArrowFormatVersion>(const char *value);
+
 template<>
 ArrowOffsetSize EnumUtil::FromString<ArrowOffsetSize>(const char *value);
 
@@ -1284,6 +1315,9 @@ HTTPStatusCode EnumUtil::FromString(const char *value);
 template<>
 IndexAppendMode EnumUtil::FromString<IndexAppendMode>(const char *value);
 
+template<>
+IndexBindState EnumUtil::FromString<IndexBindState>(const char *value);
+
 template<>
 IndexConstraintType EnumUtil::FromString<IndexConstraintType>(const char *value);
 
@@ -1395,6 +1429,9 @@ OrderPreservationType EnumUtil::FromString(const char *va
 template<>
 OrderType EnumUtil::FromString<OrderType>(const char *value);
 
+template<>
+OrdinalityType EnumUtil::FromString<OrdinalityType>(const char *value);
+
 template<>
 OutputStream EnumUtil::FromString<OutputStream>(const char *value);
 
@@ -1626,5 +1663,8 @@ WindowBoundary EnumUtil::FromString(const char *value);
 template<>
 WindowExcludeMode EnumUtil::FromString<WindowExcludeMode>(const char *value);
 
+template<>
+WindowMergeSortStage EnumUtil::FromString<WindowMergeSortStage>(const char *value);
+
 }
diff --git a/src/duckdb/src/include/duckdb/common/enums/arrow_format_version.hpp b/src/duckdb/src/include/duckdb/common/enums/arrow_format_version.hpp
new file mode 100644
index 000000000..58deb43b6
--- /dev/null
+++ b/src/duckdb/src/include/duckdb/common/enums/arrow_format_version.hpp
@@ -0,0 +1,33 @@
+//===----------------------------------------------------------------------===//
+// DuckDB
+//
+// duckdb/common/enums/arrow_format_version.hpp
+//
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include "duckdb/common/constants.hpp"
+
+namespace duckdb {
+
+enum class ArrowOffsetSize : uint8_t { REGULAR, LARGE };
+
+enum class ArrowFormatVersion : uint8_t {
+	//! Base Version
+	V1_0 = 10,
+	//! Added 256-bit Decimal type.
+	V1_1 = 11,
+	//! Added MonthDayNano interval type.
+	V1_2 = 12,
+	//! Added Run-End Encoded Layout.
+	V1_3 = 13,
+	//! Added Variable-size Binary View Layout and the associated BinaryView and Utf8View types.
+	//! Added ListView Layout and the associated ListView and LargeListView types. Added Variadic buffers.
+	V1_4 = 14,
+	//! Expanded Decimal type bit widths to allow 32-bit and 64-bit types.
+	V1_5 = 15
+};
+
+} // namespace duckdb
diff --git a/src/duckdb/src/include/duckdb/common/enums/checkpoint_abort.hpp b/src/duckdb/src/include/duckdb/common/enums/checkpoint_abort.hpp
new file mode 100644
index 000000000..321fc25b6
--- /dev/null
+++ b/src/duckdb/src/include/duckdb/common/enums/checkpoint_abort.hpp
@@ -0,0 +1,22 @@
+//===----------------------------------------------------------------------===//
+// DuckDB
+//
+// duckdb/common/enums/checkpoint_abort.hpp
+//
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include "duckdb/common/common.hpp"
+
+namespace duckdb {
+
+enum class CheckpointAbort : uint8_t {
+	NO_ABORT = 0,
+	DEBUG_ABORT_BEFORE_TRUNCATE = 1,
+	DEBUG_ABORT_BEFORE_HEADER = 2,
+	DEBUG_ABORT_AFTER_FREE_LIST_WRITE = 3
+};
+
+} // namespace duckdb
diff --git a/src/duckdb/src/include/duckdb/common/enums/checkpoint_type.hpp b/src/duckdb/src/include/duckdb/common/enums/checkpoint_type.hpp
index 9ad3b4529..7610b1abe 100644
--- a/src/duckdb/src/include/duckdb/common/enums/checkpoint_type.hpp
+++ b/src/duckdb/src/include/duckdb/common/enums/checkpoint_type.hpp
@@ -32,7 +32,9 @@ enum class CheckpointType {
 	FULL_CHECKPOINT,
 	//! Concurrent checkpoints write committed data to disk but do less clean-up
 	//! They can be run even when active transactions need to read old data
-	CONCURRENT_CHECKPOINT
+	CONCURRENT_CHECKPOINT,
+	//! Only run vacuum - this can be triggered for in-memory tables
+	VACUUM_ONLY
 };
 
 } // namespace duckdb
diff --git a/src/duckdb/src/include/duckdb/common/enums/ordinality_request_type.hpp b/src/duckdb/src/include/duckdb/common/enums/ordinality_request_type.hpp
new file mode 100644
index 000000000..c2234dea8
--- /dev/null
+++ b/src/duckdb/src/include/duckdb/common/enums/ordinality_request_type.hpp
@@ -0,0 +1,16 @@
+//===----------------------------------------------------------------------===//
+// DuckDB
+//
+// duckdb/common/enums/ordinality_request_type.hpp
+//
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+#include "duckdb/common/typedefs.hpp"
+
+namespace duckdb {
+
+enum class OrdinalityType : uint8_t { WITHOUT_ORDINALITY = 0, WITH_ORDINALITY = 1 };
+
+} // namespace duckdb
diff --git a/src/duckdb/src/include/duckdb/common/error_data.hpp b/src/duckdb/src/include/duckdb/common/error_data.hpp
index 7f7c65f41..388c12d5f 100644
--- a/src/duckdb/src/include/duckdb/common/error_data.hpp
+++ b/src/duckdb/src/include/duckdb/common/error_data.hpp
@@ -22,7 +22,7 @@ class ErrorData {
 	//! From std::exception
 	DUCKDB_API ErrorData(const std::exception &ex); // NOLINT: allow implicit construction from exception
 	//! From a raw string and exception type
-	DUCKDB_API explicit ErrorData(ExceptionType type, const string &raw_message);
+	DUCKDB_API ErrorData(ExceptionType type, const string &raw_message);
 	//! From a raw string
 	DUCKDB_API explicit ErrorData(const string &raw_message);
 
@@ -38,6 +38,7 @@ class ErrorData {
 	DUCKDB_API const string &RawMessage() const {
 		return raw_message;
 	}
+	DUCKDB_API void Merge(const ErrorData &other);
 
 	DUCKDB_API bool operator==(const ErrorData &other) const;
 	//! Returns true, if this error data contains an exception, else false.
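// Illustrative sketch (not from this patch) of how a producer could gate layout
// choices on the new ArrowFormatVersion enum; the helper name is hypothetical.
#include "duckdb/common/enums/arrow_format_version.hpp"

namespace duckdb {
// Utf8View/BinaryView string layouts only exist from Arrow format version 1.4 onwards.
static bool SupportsStringView(ArrowFormatVersion version) {
	return version >= ArrowFormatVersion::V1_4;
}
} // namespace duckdb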
diff --git a/src/duckdb/src/include/duckdb/common/exception_format_value.hpp b/src/duckdb/src/include/duckdb/common/exception_format_value.hpp index 34be78c09..3f5db0a43 100644 --- a/src/duckdb/src/include/duckdb/common/exception_format_value.hpp +++ b/src/duckdb/src/include/duckdb/common/exception_format_value.hpp @@ -49,6 +49,7 @@ enum class ExceptionFormatValueType : uint8_t { struct ExceptionFormatValue { DUCKDB_API ExceptionFormatValue(double dbl_val); // NOLINT DUCKDB_API ExceptionFormatValue(int64_t int_val); // NOLINT + DUCKDB_API ExceptionFormatValue(idx_t uint_val); // NOLINT DUCKDB_API ExceptionFormatValue(string str_val); // NOLINT DUCKDB_API ExceptionFormatValue(hugeint_t hg_val); // NOLINT DUCKDB_API ExceptionFormatValue(uhugeint_t uhg_val); // NOLINT @@ -56,7 +57,7 @@ struct ExceptionFormatValue { ExceptionFormatValueType type; double dbl_val = 0; - int64_t int_val = 0; + hugeint_t int_val = 0; string str_val; public: @@ -86,6 +87,8 @@ DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const ch template <> DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(char *value); template <> +DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(idx_t value); +template <> DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(hugeint_t value); template <> DUCKDB_API ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(uhugeint_t value); diff --git a/src/duckdb/src/include/duckdb/common/extra_type_info.hpp b/src/duckdb/src/include/duckdb/common/extra_type_info.hpp index d97491326..d5e35ee96 100644 --- a/src/duckdb/src/include/duckdb/common/extra_type_info.hpp +++ b/src/duckdb/src/include/duckdb/common/extra_type_info.hpp @@ -27,7 +27,8 @@ enum class ExtraTypeInfoType : uint8_t { AGGREGATE_STATE_TYPE_INFO = 8, ARRAY_TYPE_INFO = 9, ANY_TYPE_INFO = 10, - INTEGER_LITERAL_TYPE_INFO = 11 + INTEGER_LITERAL_TYPE_INFO = 11, + TEMPLATE_TYPE_INFO = 12 }; struct ExtraTypeInfo { @@ -259,4 +260,22 @@ struct IntegerLiteralTypeInfo : public ExtraTypeInfo { IntegerLiteralTypeInfo(); }; +struct TemplateTypeInfo : public ExtraTypeInfo { + + explicit TemplateTypeInfo(string name_p); + + // The name of the template, e.g. `T`, or `KEY_TYPE`. Used to distinguish between different template types within + // the same function. The binder tries to resolve all templates with the same name to the same concrete type. + string name; + +public: + void Serialize(Serializer &serializer) const override; + static shared_ptr Deserialize(Deserializer &source); + shared_ptr Copy() const override; + +protected: + bool EqualsInternal(ExtraTypeInfo *other_p) const override; + TemplateTypeInfo(); +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/file_buffer.hpp b/src/duckdb/src/include/duckdb/common/file_buffer.hpp index 0080da022..f330854c2 100644 --- a/src/duckdb/src/include/duckdb/common/file_buffer.hpp +++ b/src/duckdb/src/include/duckdb/common/file_buffer.hpp @@ -45,7 +45,7 @@ class FileBuffer { public: //! Read into the FileBuffer from the location. - void Read(FileHandle &handle, uint64_t location); + void Read(QueryContext context, FileHandle &handle, uint64_t location); //! Write the FileBuffer to the location. 
void Write(QueryContext context, FileHandle &handle, const uint64_t location); diff --git a/src/duckdb/src/include/duckdb/common/file_opener.hpp b/src/duckdb/src/include/duckdb/common/file_opener.hpp index 349dfbd92..546fa038d 100644 --- a/src/duckdb/src/include/duckdb/common/file_opener.hpp +++ b/src/duckdb/src/include/duckdb/common/file_opener.hpp @@ -10,13 +10,14 @@ #include "duckdb/common/string.hpp" #include "duckdb/common/winapi.hpp" -#include "duckdb/main/settings.hpp" +#include "duckdb/main/setting_info.hpp" namespace duckdb { struct CatalogTransaction; class SecretManager; class ClientContext; +class HTTPUtil; class Value; class Logger; diff --git a/src/duckdb/src/include/duckdb/common/file_system.hpp b/src/duckdb/src/include/duckdb/common/file_system.hpp index 62059ca94..97e7691d2 100644 --- a/src/duckdb/src/include/duckdb/common/file_system.hpp +++ b/src/duckdb/src/include/duckdb/common/file_system.hpp @@ -68,6 +68,7 @@ struct FileHandle { // Read at [nr_bytes] bytes into [buffer]. // File offset will not be changed. DUCKDB_API void Read(void *buffer, idx_t nr_bytes, idx_t location); + DUCKDB_API void Read(QueryContext context, void *buffer, idx_t nr_bytes, idx_t location); DUCKDB_API void Write(QueryContext context, void *buffer, idx_t nr_bytes, idx_t location); DUCKDB_API void Seek(idx_t location); DUCKDB_API void Reset(); diff --git a/src/duckdb/src/include/duckdb/common/helper.hpp b/src/duckdb/src/include/duckdb/common/helper.hpp index 747ebbf22..d5fb4b465 100644 --- a/src/duckdb/src/include/duckdb/common/helper.hpp +++ b/src/duckdb/src/include/duckdb/common/helper.hpp @@ -58,7 +58,7 @@ struct TemplatedUniqueIf }; template -inline +inline typename TemplatedUniqueIf::templated_unique_single_t make_uniq(ARGS&&... args) // NOLINT: mimic std style { @@ -66,7 +66,7 @@ make_uniq(ARGS&&... args) // NOLINT: mimic std style } template -inline +inline shared_ptr make_shared_ptr(ARGS&&... args) // NOLINT: mimic std style { @@ -74,7 +74,7 @@ make_shared_ptr(ARGS&&... args) // NOLINT: mimic std style } template -inline +inline typename TemplatedUniqueIf::templated_unique_single_t make_unsafe_uniq(ARGS&&... args) // NOLINT: mimic std style { @@ -274,4 +274,17 @@ void DynamicCastCheck(const SRC *source) { #endif } +//! 
Used to increment counters that need to be exception-proof +template +class PostIncrement { +public: + explicit PostIncrement(T &t) : t(t) { + } + ~PostIncrement() { + ++t; + } +private: + T &t; +}; + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp index 172601c71..534b979f8 100644 --- a/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp +++ b/src/duckdb/src/include/duckdb/common/multi_file/multi_file_function.hpp @@ -181,11 +181,11 @@ class MultiFileFunction : public TableFunction { std::move(file_options), std::move(options), std::move(interface)); } - static unique_ptr MultiFileBindCopy(ClientContext &context, CopyInfo &info, + static unique_ptr MultiFileBindCopy(ClientContext &context, CopyFromFunctionBindInput &input, vector &expected_names, vector &expected_types) { auto multi_file_reader = MultiFileReader::CreateDefault("COPY"); - vector paths = {info.file_path}; + vector paths = {input.info.file_path}; auto file_list = multi_file_reader->CreateFileList(context, paths); auto interface = OP::InitializeInterface(context, *multi_file_reader, *file_list); @@ -193,7 +193,7 @@ class MultiFileFunction : public TableFunction { auto options = interface->InitializeOptions(context, nullptr); MultiFileOptions file_options; - for (auto &option : info.options) { + for (auto &option : input.info.options) { auto loption = StringUtil::Lower(option.first); if (interface->ParseCopyOption(context, loption, option.second, *options, expected_names, expected_types)) { continue; diff --git a/src/duckdb/src/include/duckdb/common/sorting/hashed_sort.hpp b/src/duckdb/src/include/duckdb/common/sorting/hashed_sort.hpp new file mode 100644 index 000000000..fb30ea687 --- /dev/null +++ b/src/duckdb/src/include/duckdb/common/sorting/hashed_sort.hpp @@ -0,0 +1,162 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/common/sorting/hashed_sort.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/radix_partitioning.hpp" +#include "duckdb/common/sorting/sort.hpp" +#include "duckdb/parallel/base_pipeline_event.hpp" + +namespace duckdb { + +// Formerly PartitionGlobalHashGroup +class HashedSortGroup { +public: + using Orders = vector; + using Types = vector; + using OrderMasks = unordered_map; + + HashedSortGroup(ClientContext &context, const Orders &orders, const Types &input_types, idx_t group_idx); + + const idx_t group_idx; + + // Sink + unique_ptr sort; + unique_ptr sort_global; + + // Source + atomic tasks_completed; + unique_ptr sort_source; + unique_ptr sorted; +}; + +// Formerly PartitionGlobalSinkState +class HashedSortGlobalSinkState { +public: + using HashGroupPtr = unique_ptr; + using Orders = vector; + using Types = vector; + + using GroupingPartition = unique_ptr; + using GroupingAppend = unique_ptr; + + static void GenerateOrderings(Orders &partitions, Orders &orders, + const vector> &partition_bys, const Orders &order_bys, + const vector> &partitions_stats); + + HashedSortGlobalSinkState(ClientContext &context, const vector> &partition_bys, + const vector &order_bys, const Types &payload_types, + const vector> &partitions_stats, idx_t estimated_cardinality); + + bool HasMergeTasks() const; + + // OVER(PARTITION BY...) 
(hash grouping) + unique_ptr CreatePartition(idx_t new_bits) const; + void UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &partition_append); + void CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); + void Finalize(ClientContext &context, InterruptState &interrupt_state); + + //! System and query state + ClientContext &context; + BufferManager &buffer_manager; + Allocator &allocator; + mutex lock; + + // OVER(PARTITION BY...) (hash grouping) + GroupingPartition grouping_data; + //! Payload plus hash column + shared_ptr grouping_types_ptr; + //! The number of radix bits if this partition is being synced with another + idx_t fixed_bits; + + // OVER(...) (sorting) + Orders partitions; + Orders orders; + Types payload_types; + vector hash_groups; + // Input columns in the sorted output + vector scan_ids; + // Key columns in the sorted output + vector sort_ids; + // Key columns that must be computed + vector> sort_exprs; + + // OVER() (no sorting) + unique_ptr unsorted; + + // Threading + idx_t max_bits; + atomic count; + +private: + void Rehash(idx_t cardinality); + void SyncLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append); +}; + +// Formerly PartitionLocalSinkState +class HashedSortLocalSinkState { +public: + using LocalSortStatePtr = unique_ptr; + using GroupingPartition = unique_ptr; + using GroupingAppend = unique_ptr; + + HashedSortLocalSinkState(ExecutionContext &context, HashedSortGlobalSinkState &gstate); + + //! Global state + HashedSortGlobalSinkState &gstate; + Allocator &allocator; + + //! Shared expression evaluation + ExpressionExecutor hash_exec; + ExpressionExecutor sort_exec; + DataChunk group_chunk; + DataChunk sort_chunk; + DataChunk payload_chunk; + size_t sort_col_count; + + //! Compute the hash values + void Hash(DataChunk &input_chunk, Vector &hash_vector); + //! Sink an input chunk + void Sink(ExecutionContext &context, DataChunk &input_chunk); + //! Merge the state into the global state. + void Combine(ExecutionContext &context); + + // OVER(PARTITION BY...) (hash grouping) + GroupingPartition local_grouping; + GroupingAppend grouping_append; + + // OVER(ORDER BY...) 
(only sorting) + LocalSortStatePtr sort_local; + InterruptState interrupt; + + // OVER() (no sorting) + unique_ptr unsorted; + ColumnDataAppendState unsorted_append; +}; + +class HashedSortCallback { +public: + virtual ~HashedSortCallback() = default; + virtual void OnSortedGroup(HashedSortGroup &hash_group) = 0; +}; + +// Formerly PartitionMergeEvent +class HashedSortMaterializeEvent : public BasePipelineEvent { +public: + HashedSortMaterializeEvent(HashedSortGlobalSinkState &gstate, Pipeline &pipeline, const PhysicalOperator &op, + HashedSortCallback *callback); + + HashedSortGlobalSinkState &gstate; + const PhysicalOperator &op; + optional_ptr callback; + +public: + void Schedule() override; +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sorting/sort.hpp b/src/duckdb/src/include/duckdb/common/sorting/sort.hpp index edf45dfcc..597b8261b 100644 --- a/src/duckdb/src/include/duckdb/common/sorting/sort.hpp +++ b/src/duckdb/src/include/duckdb/common/sorting/sort.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/sorting/sorted_run.hpp" #include "duckdb/common/types/row/tuple_data_layout.hpp" #include "duckdb/execution/physical_operator_states.hpp" #include "duckdb/common/sorting/sort_projection_column.hpp" @@ -69,6 +70,16 @@ class Sort { OperatorPartitionData GetPartitionData(ExecutionContext &context, DataChunk &chunk, GlobalSourceState &gstate, LocalSourceState &lstate, const OperatorPartitionInfo &partition_info) const; ProgressData GetProgress(ClientContext &context, GlobalSourceState &gstate) const; + +public: + //===--------------------------------------------------------------------===// + // Non-Standard Interface + //===--------------------------------------------------------------------===// + SourceResultType MaterializeColumnData(ExecutionContext &context, OperatorSourceInput &input) const; + unique_ptr GetColumnData(OperatorSourceInput &input) const; + + SourceResultType MaterializeSortedRun(ExecutionContext &context, OperatorSourceInput &input) const; + unique_ptr GetSortedRun(GlobalSourceState &global_state); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp b/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp index cabea9146..fe0d67e32 100644 --- a/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp +++ b/src/duckdb/src/include/duckdb/common/sorting/sorted_run.hpp @@ -21,7 +21,7 @@ class SortedRun { public: SortedRun(ClientContext &context, shared_ptr key_layout, shared_ptr payload_layout, bool is_index_sort); - + unique_ptr CreateRunForMaterialization() const; ~SortedRun(); public: diff --git a/src/duckdb/src/include/duckdb/common/sorting/sorted_run_merger.hpp b/src/duckdb/src/include/duckdb/common/sorting/sorted_run_merger.hpp index 2007c6c11..21a56df83 100644 --- a/src/duckdb/src/include/duckdb/common/sorting/sorted_run_merger.hpp +++ b/src/duckdb/src/include/duckdb/common/sorting/sorted_run_merger.hpp @@ -40,6 +40,13 @@ class SortedRunMerger { LocalSourceState &lstate, const OperatorPartitionInfo &partition_info) const; ProgressData GetProgress(ClientContext &context, GlobalSourceState &gstate) const; +public: + //===--------------------------------------------------------------------===// + // Non-Standard Interface + //===--------------------------------------------------------------------===// + SourceResultType MaterializeMerge(ExecutionContext &context, OperatorSourceInput &input) const; + unique_ptr GetMaterialized(GlobalSourceState &global_state); + public: const 
Expression &decode_sort_key; shared_ptr key_layout; diff --git a/src/duckdb/src/include/duckdb/common/tree_renderer/text_tree_renderer.hpp b/src/duckdb/src/include/duckdb/common/tree_renderer/text_tree_renderer.hpp index 1bf43191a..fe1b1e2b7 100644 --- a/src/duckdb/src/include/duckdb/common/tree_renderer/text_tree_renderer.hpp +++ b/src/duckdb/src/include/duckdb/common/tree_renderer/text_tree_renderer.hpp @@ -37,6 +37,10 @@ struct TextTreeRendererConfig { idx_t max_extra_lines = 30; bool detailed = false; + // Formatting options + char thousand_separator = ','; + char decimal_separator = '.'; + #ifndef DUCKDB_ASCII_TREE_RENDERER const char *LTCORNER = "\342\224\214"; // NOLINT "┌"; const char *RTCORNER = "\342\224\220"; // NOLINT "┐"; @@ -115,6 +119,7 @@ class TextTreeRenderer : public TreeRenderer { void SplitUpExtraInfo(const InsertionOrderPreservingMap &extra_info, vector &result, idx_t max_lines); void SplitStringBuffer(const string &source, vector &result); + string FormatNumber(const string &input); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/type_util.hpp b/src/duckdb/src/include/duckdb/common/type_util.hpp index ef490763b..40a3eb872 100644 --- a/src/duckdb/src/include/duckdb/common/type_util.hpp +++ b/src/duckdb/src/include/duckdb/common/type_util.hpp @@ -17,6 +17,7 @@ #include "duckdb/common/types/double_na_equal.hpp" namespace duckdb { +struct bignum_t; //! Returns the PhysicalType for the given type template @@ -67,7 +68,8 @@ PhysicalType GetTypeId() { return PhysicalType::FLOAT; } else if (std::is_same() || std::is_same()) { return PhysicalType::DOUBLE; - } else if (std::is_same() || std::is_same() || std::is_same()) { + } else if (std::is_same() || std::is_same() || std::is_same() || + std::is_same()) { return PhysicalType::VARCHAR; } else if (std::is_same()) { return PhysicalType::INTERVAL; diff --git a/src/duckdb/src/include/duckdb/common/type_visitor.hpp b/src/duckdb/src/include/duckdb/common/type_visitor.hpp index 9296a75db..d8c2b381a 100644 --- a/src/duckdb/src/include/duckdb/common/type_visitor.hpp +++ b/src/duckdb/src/include/duckdb/common/type_visitor.hpp @@ -26,6 +26,9 @@ template inline LogicalType TypeVisitor::VisitReplace(const LogicalType &type, F &&func) { switch (type.id()) { case LogicalTypeId::STRUCT: { + if (!type.AuxInfo()) { + return func(type); + } auto children = StructType::GetChildTypes(type); for (auto &child : children) { child.second = VisitReplace(child.second, func); @@ -33,6 +36,9 @@ inline LogicalType TypeVisitor::VisitReplace(const LogicalType &type, F &&func) return func(LogicalType::STRUCT(children)); } case LogicalTypeId::UNION: { + if (!type.AuxInfo()) { + return func(type); + } auto children = UnionType::CopyMemberTypes(type); for (auto &child : children) { child.second = VisitReplace(child.second, func); @@ -40,16 +46,25 @@ inline LogicalType TypeVisitor::VisitReplace(const LogicalType &type, F &&func) return func(LogicalType::UNION(children)); } case LogicalTypeId::LIST: { - auto child = ListType::GetChildType(type); + if (!type.AuxInfo()) { + return func(type); + } + const auto &child = ListType::GetChildType(type); return func(LogicalType::LIST(VisitReplace(child, func))); } case LogicalTypeId::ARRAY: { - auto child = ArrayType::GetChildType(type); + if (!type.AuxInfo()) { + return func(type); + } + const auto &child = ArrayType::GetChildType(type); return func(LogicalType::ARRAY(VisitReplace(child, func), ArrayType::GetSize(type))); } case LogicalTypeId::MAP: { - auto key = MapType::KeyType(type); - auto 
value = MapType::ValueType(type);
+		if (!type.AuxInfo()) {
+			return func(type);
+		}
+		const auto &key = MapType::KeyType(type);
+		const auto &value = MapType::ValueType(type);
 		return func(LogicalType::MAP(VisitReplace(key, func), VisitReplace(value, func)));
 	}
 	default:
@@ -64,6 +79,9 @@ inline bool TypeVisitor::Contains(const LogicalType &type, F &&predicate) {
 	}
 	switch (type.id()) {
 	case LogicalTypeId::STRUCT: {
+		if (!type.AuxInfo()) {
+			return false;
+		}
 		for (const auto &child : StructType::GetChildTypes(type)) {
 			if (Contains(child.second, predicate)) {
 				return true;
@@ -72,17 +90,29 @@ inline bool TypeVisitor::Contains(const LogicalType &type, F &&predicate)
 		return false;
 	}
 	case LogicalTypeId::UNION:
-		for (const auto &child : UnionType::CopyMemberTypes(type)) {
-			if (Contains(child.second, predicate)) {
+		if (!type.AuxInfo()) {
+			return false;
+		}
+		for (idx_t i = 0; i < UnionType::GetMemberCount(type); i++) {
+			if (Contains(UnionType::GetMemberType(type, i), predicate)) {
 				return true;
 			}
 		}
 		return false;
 	case LogicalTypeId::LIST:
+		if (!type.AuxInfo()) {
+			return false;
+		}
 		return Contains(ListType::GetChildType(type), predicate);
 	case LogicalTypeId::ARRAY:
+		if (!type.AuxInfo()) {
+			return false;
+		}
 		return Contains(ArrayType::GetChildType(type), predicate);
 	case LogicalTypeId::MAP:
+		if (!type.AuxInfo()) {
+			return false;
+		}
 		return Contains(MapType::KeyType(type), predicate) || Contains(MapType::ValueType(type), predicate);
 	default:
 		return false;
diff --git a/src/duckdb/src/include/duckdb/common/types.hpp b/src/duckdb/src/include/duckdb/common/types.hpp
index ec47ac788..2efc5641b 100644
--- a/src/duckdb/src/include/duckdb/common/types.hpp
+++ b/src/duckdb/src/include/duckdb/common/types.hpp
@@ -186,6 +186,15 @@ enum class LogicalTypeId : uint8_t {
 	UNKNOWN = 2, /* unknown type, used for parameter expressions */
 	ANY = 3,     /* ANY type, used for functions that accept any type as parameter */
 	USER = 4,    /* A User Defined Type (e.g., ENUMs before the binder) */
+
+	// A "template" type functions as a "placeholder" type for function arguments and return types.
+	// Templates only exist during the binding phase, in the scope of a function, and are replaced with concrete types
+	// before execution. When defining a template, you provide a name to distinguish it from other template types;
+	// differently named templates signal to the binder that they don't need to resolve to the same concrete type.
+	// Two templates with the same name are always resolved to the same concrete type.
+	TEMPLATE = 5,
+
 	BOOLEAN = 10,
 	TINYINT = 11,
 	SMALLINT = 12,
@@ -214,7 +223,7 @@ enum class LogicalTypeId : uint8_t {
 	BIT = 36,
 	STRING_LITERAL = 37,  /* string literals, used for constant strings - only exists while binding */
 	INTEGER_LITERAL = 38, /* integer literals, used for constant integers - only exists while binding */
-	VARINT = 39,
+	BIGNUM = 39,
 	UHUGEINT = 49,
 	HUGEINT = 50,
 	POINTER = 51,
@@ -355,6 +364,7 @@ struct LogicalType {
 	DUCKDB_API bool IsValid() const;
 	DUCKDB_API bool IsComplete() const;
+	DUCKDB_API bool IsTemplated() const;
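// --- editor's note: illustrative sketch, not part of the patch ---
// The TEMPLATE placeholder introduced above lets a function signature name a
// type variable once and reuse it: every occurrence of the same name must
// resolve to one concrete type during binding. The function name and callback
// below are hypothetical, shown only to make the semantics concrete.
//
//   ScalarFunction fun("first_element_sketch",
//                      {LogicalType::LIST(LogicalType::TEMPLATE("T"))}, // argument: LIST(T)
//                      LogicalType::TEMPLATE("T"),                      // return type: T
//                      FirstElementSketchFun);                          // hypothetical callback
//
// Binding the argument to LIST(INTEGER) resolves "T" to INTEGER for the
// return type as well; a second template named "U" could resolve to a
// different concrete type.
// --- end editor's note ---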
 	//! True, if this type supports in-place updates.
 	bool SupportsRegularUpdate() const;
@@ -395,7 +405,7 @@ struct LogicalType {
 	static constexpr const LogicalTypeId ANY = LogicalTypeId::ANY;
 	static constexpr const LogicalTypeId BLOB = LogicalTypeId::BLOB;
 	static constexpr const LogicalTypeId BIT = LogicalTypeId::BIT;
-	static constexpr const LogicalTypeId VARINT = LogicalTypeId::VARINT;
+	static constexpr const LogicalTypeId BIGNUM = LogicalTypeId::BIGNUM;
 	static constexpr const LogicalTypeId INTERVAL = LogicalTypeId::INTERVAL;
 	static constexpr const LogicalTypeId HUGEINT = LogicalTypeId::HUGEINT;
@@ -421,6 +431,7 @@ struct LogicalType {
 	DUCKDB_API static LogicalType ENUM(Vector &ordered_data, idx_t size); // NOLINT
 	// ANY but with special rules (default is LogicalType::ANY, 5)
 	DUCKDB_API static LogicalType ANY_PARAMS(LogicalType target, idx_t cast_score = 5); // NOLINT
+	DUCKDB_API static LogicalType TEMPLATE(const string &name); // NOLINT
 	//! Integer literal of the specified value
 	DUCKDB_API static LogicalType INTEGER_LITERAL(const Value &constant); // NOLINT
 	// DEPRECATED - provided for backwards compatibility
@@ -523,6 +534,11 @@ struct IntegerLiteral {
 	DUCKDB_API static bool FitsInType(const LogicalType &type, const LogicalType &target);
 };
 
+struct TemplateType {
+	// Get the name of the template type
+	DUCKDB_API static const string &GetName(const LogicalType &type);
+};
+
 // **DEPRECATED**: Use EnumUtil directly instead.
 DUCKDB_API string LogicalTypeIdToString(LogicalTypeId type);
diff --git a/src/duckdb/src/include/duckdb/common/types/bignum.hpp b/src/duckdb/src/include/duckdb/common/types/bignum.hpp
new file mode 100644
index 000000000..cdf1d6ccf
--- /dev/null
+++ b/src/duckdb/src/include/duckdb/common/types/bignum.hpp
@@ -0,0 +1,159 @@
+//===----------------------------------------------------------------------===//
+//                         DuckDB
+//
+// duckdb/common/types/bignum.hpp
+//
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include "duckdb/common/common.hpp"
+#include "duckdb/common/limits.hpp"
+#include "duckdb/common/string_util.hpp"
+#include "duckdb/common/types.hpp"
+#include "duckdb/common/winapi.hpp"
+#include "duckdb/function/cast/default_casts.hpp"
+#include "duckdb/common/bignum.hpp"
+
+namespace duckdb {
+using digit_t = uint32_t;
+using twodigit_t = uint64_t;
+
+//! The Bignum class is a static class that holds helper functions for the Bignum type.
+class Bignum {
+public:
+	//! This is the maximum number of bytes a BIGNUM can have in its data part,
+	//! i.e., 2^(8*3-1) - 1.
+	DUCKDB_API static constexpr uint32_t MAX_DATA_SIZE = 8388607;
+	//! Header size of a Bignum is always 3 bytes.
+	DUCKDB_API static constexpr uint8_t BIGNUM_HEADER_SIZE = 3;
+	//! Max(e such that 10**e fits in a digit_t)
+	DUCKDB_API static constexpr uint8_t DECIMAL_SHIFT = 9;
+	//! 10 ** DECIMAL_SHIFT
+	DUCKDB_API static constexpr digit_t DECIMAL_BASE = 1000000000;
+	//! Bytes of a digit_t
+	DUCKDB_API static constexpr uint8_t DIGIT_BYTES = sizeof(digit_t);
+	//! Bits of a digit_t
+	DUCKDB_API static constexpr uint8_t DIGIT_BITS = DIGIT_BYTES * 8;
+	//! Verifies that a Bignum is valid, i.e., that it has 3 header bytes, the header correctly
+	//! represents the number of data bytes, and the data bytes have no leading zero bytes.
+	DUCKDB_API static void Verify(const bignum_t &input);
+
+	//! Sets the header of a bignum (i.e., char* blob), depending on the number of bytes the bignum needs
+	//! and whether it is a negative number
+	DUCKDB_API static void SetHeader(char *blob, uint64_t number_of_bytes, bool is_negative);
+	//! Initializes and returns a blob with value 0, allocated in Vector& result
+	DUCKDB_API static bignum_t InitializeBignumZero(Vector &result);
+	DUCKDB_API static string InitializeBignumZero();
+
+	//! Switch case of to-Bignum conversion
+	DUCKDB_API static BoundCastInfo NumericToBignumCastSwitch(const LogicalType &source);
+
+	//! ----------------------------------- Varchar Cast ----------------------------------- //
+	//! Function to prepare a varchar for conversion. We trim zeros, check for negative values, and so on.
+	//! Returns false if this is an invalid varchar
+	DUCKDB_API static bool VarcharFormatting(const string_t &value, idx_t &start_pos, idx_t &end_pos,
+	                                         bool &is_negative, bool &is_zero);
+
+	//! Converts a char to a Digit
+	DUCKDB_API static int CharToDigit(char c);
+	//! Converts a Digit to a char
+	DUCKDB_API static char DigitToChar(int digit);
+	//! Function to convert a string_t into a vector of bytes
+	DUCKDB_API static void GetByteArray(vector<uint8_t> &byte_array, bool &is_negative, const string_t &blob);
+	//! Function to create a BIGNUM blob from a byte array containing the absolute value, plus an is_negative bool
+	DUCKDB_API static string FromByteArray(uint8_t *data, idx_t size, bool is_negative);
+	//! Function to convert BIGNUM blob to a VARCHAR
+	DUCKDB_API static string BignumToVarchar(const bignum_t &blob);
+	//! Function to convert Varchar to BIGNUM blob
+	DUCKDB_API static string VarcharToBignum(const string_t &value);
+	//! ----------------------------------- Double Cast ----------------------------------- //
+	DUCKDB_API static bool BignumToDouble(const bignum_t &blob, double &result, bool &strict);
+	template <class T>
+	static bool BignumToInt(const bignum_t &blob, T &result, bool &strict) {
+		auto data_byte_size = blob.data.GetSize() - BIGNUM_HEADER_SIZE;
+		auto data = blob.data.GetData();
+		bool is_negative = (data[0] & 0x80) == 0;
+
+		uhugeint_t abs_value = 0;
+		for (idx_t i = 0; i < data_byte_size; ++i) {
+			uint8_t byte = static_cast<uint8_t>(data[Bignum::BIGNUM_HEADER_SIZE + i]);
+			if (is_negative) {
+				byte = ~byte;
+			}
+			abs_value = (abs_value << 8) | byte;
+		}
+
+		if (is_negative) {
+			if (abs_value > static_cast<uhugeint_t>(std::numeric_limits<T>::max()) + 1) {
+				throw OutOfRangeException("Negative bignum too small for type");
+			}
+			result = static_cast<T>(-static_cast<hugeint_t>(abs_value));
+		} else {
+			if (abs_value > static_cast<uhugeint_t>(std::numeric_limits<T>::max())) {
+				throw OutOfRangeException("Positive bignum too large for type");
+			}
+			result = static_cast<T>(abs_value);
+		}
+		return true;
+	}
+};
+
+//! ----------------------------------- (u)Integral Cast ----------------------------------- //
+struct IntCastToBignum {
+	template <class SRC>
+	static inline bignum_t Operation(SRC input, Vector &result) {
+		return IntToBignum(result, input);
+	}
+};
+
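// --- editor's note: minimal standalone sketch, not part of the patch ---
// Decoding mirrors Bignum::BignumToInt above, assuming the same layout: a
// 3-byte header whose first bit is set for non-negative values, followed by
// big-endian payload bytes that are stored complemented for negative values.
// The function name below is hypothetical.
#include <cstdint>
#include <vector>

inline int64_t DecodeBignumSketch(const std::vector<uint8_t> &blob) {
	const std::size_t header_size = 3; // mirrors Bignum::BIGNUM_HEADER_SIZE
	const bool is_negative = (blob[0] & 0x80) == 0;
	uint64_t abs_value = 0;
	for (std::size_t i = header_size; i < blob.size(); i++) {
		uint8_t byte = blob[i];
		if (is_negative) {
			byte = static_cast<uint8_t>(~byte); // undo the complement of negative payloads
		}
		abs_value = (abs_value << 8) | byte;
	}
	return is_negative ? -static_cast<int64_t>(abs_value) : static_cast<int64_t>(abs_value);
}
// Under these assumptions, {0x80, 0x00, 0x01, 0x2A} decodes to 42.
// --- end editor's note ---
+//!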
----------------------------------- (u)HugeInt Cast ----------------------------------- // +struct HugeintCastToBignum { + template + static inline bignum_t Operation(SRC input, Vector &result) { + throw InternalException("Unsupported type for cast to BIGNUM"); + } +}; + +struct TryCastToBignum { + template + static inline bool Operation(SRC input, DST &result, Vector &result_vector, CastParameters ¶meters) { + throw InternalException("Unsupported type for try cast to BIGNUM"); + } +}; + +template <> +DUCKDB_API bool TryCastToBignum::Operation(double double_value, bignum_t &result_value, Vector &result, + CastParameters ¶meters); + +template <> +DUCKDB_API bool TryCastToBignum::Operation(float float_value, bignum_t &result_value, Vector &result, + CastParameters ¶meters); + +template <> +DUCKDB_API bool TryCastToBignum::Operation(string_t input_value, bignum_t &result_value, Vector &result, + CastParameters ¶meters); + +struct BignumCastToVarchar { + template + DUCKDB_API static inline string_t Operation(SRC input, Vector &result) { + return StringVector::AddStringOrBlob(result, Bignum::BignumToVarchar(input)); + } +}; + +struct BignumToDoubleCast { + template + DUCKDB_API static inline bool Operation(SRC input, DST &result, bool strict = false) { + return Bignum::BignumToDouble(input, result, strict); + } +}; + +struct BignumToIntCast { + template + DUCKDB_API static inline bool Operation(SRC input, DST &result, bool strict = false) { + return Bignum::BignumToInt(input, result, strict); + } +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/common/types/hugeint.hpp b/src/duckdb/src/include/duckdb/common/types/hugeint.hpp index 072d9282d..3720bf844 100644 --- a/src/duckdb/src/include/duckdb/common/types/hugeint.hpp +++ b/src/duckdb/src/include/duckdb/common/types/hugeint.hpp @@ -66,7 +66,7 @@ class Hugeint { inline static hugeint_t Multiply(hugeint_t lhs, hugeint_t rhs) { hugeint_t result; if (!TryMultiply(lhs, rhs, result)) { - throw OutOfRangeException("Overflow in HUGEINT multiplication: %s + %s", lhs.ToString(), rhs.ToString()); + throw OutOfRangeException("Overflow in HUGEINT multiplication: %s * %s", lhs.ToString(), rhs.ToString()); } return result; } @@ -77,12 +77,12 @@ class Hugeint { inline static hugeint_t Divide(hugeint_t lhs, hugeint_t rhs) { // No division by zero if (rhs == 0) { - throw OutOfRangeException("Division of HUGEINT by zero!"); + throw OutOfRangeException("Division of HUGEINT by zero: %s / %s", lhs.ToString(), rhs.ToString()); } // division only has one reason to overflow: MINIMUM / -1 if (lhs == NumericLimits::Minimum() && rhs == -1) { - throw OutOfRangeException("Overflow in HUGEINT division: %s + %s", lhs.ToString(), rhs.ToString()); + throw OutOfRangeException("Overflow in HUGEINT division: %s / %s", lhs.ToString(), rhs.ToString()); } return Divide(lhs, rhs); } @@ -91,12 +91,12 @@ class Hugeint { inline static hugeint_t Modulo(hugeint_t lhs, hugeint_t rhs) { // No division by zero if (rhs == 0) { - throw OutOfRangeException("Modulo of HUGEINT by zero: %s + %s", lhs.ToString(), rhs.ToString()); + throw OutOfRangeException("Modulo of HUGEINT by zero: %s %% %s", lhs.ToString(), rhs.ToString()); } // division only has one reason to overflow: MINIMUM / -1 if (lhs == NumericLimits::Minimum() && rhs == -1) { - throw OutOfRangeException("Overflow in HUGEINT modulo: %s + %s", lhs.ToString(), rhs.ToString()); + throw OutOfRangeException("Overflow in HUGEINT modulo: %s %% %s", lhs.ToString(), rhs.ToString()); } return Modulo(lhs, rhs); } @@ -116,7 
+116,7 @@ class Hugeint { template inline static hugeint_t Subtract(hugeint_t lhs, hugeint_t rhs) { if (!TrySubtractInPlace(lhs, rhs)) { - throw OutOfRangeException("Underflow in HUGEINT addition: %s - %s", lhs.ToString(), rhs.ToString()); + throw OutOfRangeException("Underflow in HUGEINT subtraction: %s - %s", lhs.ToString(), rhs.ToString()); } return lhs; } diff --git a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp index 57de39b57..d759341ee 100644 --- a/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp +++ b/src/duckdb/src/include/duckdb/common/types/row/tuple_data_collection.hpp @@ -57,6 +57,8 @@ class TupleDataCollection { unique_ptr CreateUnique() const; public: + //! Get the layout (shared pointer) of the rows + shared_ptr GetLayoutPtr() const; //! The layout of the stored rows const TupleDataLayout &GetLayout() const; //! How many tuples fit per block diff --git a/src/duckdb/src/include/duckdb/common/types/uhugeint.hpp b/src/duckdb/src/include/duckdb/common/types/uhugeint.hpp index 14cf6e7be..b1e6ac8eb 100644 --- a/src/duckdb/src/include/duckdb/common/types/uhugeint.hpp +++ b/src/duckdb/src/include/duckdb/common/types/uhugeint.hpp @@ -67,7 +67,7 @@ class Uhugeint { inline static uhugeint_t Multiply(uhugeint_t lhs, uhugeint_t rhs) { uhugeint_t result; if (!TryMultiply(lhs, rhs, result)) { - throw OutOfRangeException("Overflow in UHUGEINT multiplication!: %s + %s", lhs.ToString(), rhs.ToString()); + throw OutOfRangeException("Overflow in UHUGEINT multiplication: %s * %s", lhs.ToString(), rhs.ToString()); } return result; } @@ -78,7 +78,7 @@ class Uhugeint { inline static uhugeint_t Divide(uhugeint_t lhs, uhugeint_t rhs) { // division between two same-size unsigned integers can only go wrong with division by zero if (rhs == 0) { - throw OutOfRangeException("Division of UHUGEINT by zero!"); + throw OutOfRangeException("Division of UHUGEINT by zero: %s / %s", lhs.ToString(), rhs.ToString()); } return Divide(lhs, rhs); } @@ -86,7 +86,7 @@ class Uhugeint { template inline static uhugeint_t Modulo(uhugeint_t lhs, uhugeint_t rhs) { if (rhs == 0) { - throw OutOfRangeException("Modulo of UHUGEINT by zero!"); + throw OutOfRangeException("Modulo of UHUGEINT by zero: %s %% %s", lhs.ToString(), rhs.ToString()); } return Modulo(lhs, rhs); } @@ -106,7 +106,7 @@ class Uhugeint { template inline static uhugeint_t Subtract(uhugeint_t lhs, uhugeint_t rhs) { if (!TrySubtractInPlace(lhs, rhs)) { - throw OutOfRangeException("Underflow in HUGEINT addition: %s - %s", lhs.ToString(), rhs.ToString()); + throw OutOfRangeException("Underflow in HUGEINT subtraction: %s - %s", lhs.ToString(), rhs.ToString()); } return lhs; } diff --git a/src/duckdb/src/include/duckdb/common/types/value.hpp b/src/duckdb/src/include/duckdb/common/types/value.hpp index a91841e0b..5e4dafbe9 100644 --- a/src/duckdb/src/include/duckdb/common/types/value.hpp +++ b/src/duckdb/src/include/duckdb/common/types/value.hpp @@ -193,8 +193,8 @@ class Value { //! Creates a bitstring by casting a specified string to a bitstring DUCKDB_API static Value BIT(const_data_ptr_t data, idx_t len); DUCKDB_API static Value BIT(const string &data); - DUCKDB_API static Value VARINT(const_data_ptr_t data, idx_t len); - DUCKDB_API static Value VARINT(const string &data); + DUCKDB_API static Value BIGNUM(const_data_ptr_t data, idx_t len); + DUCKDB_API static Value BIGNUM(const string &data); //! 
Creates an aggregate state DUCKDB_API static Value AGGREGATE_STATE(const LogicalType &type, const_data_ptr_t data, idx_t len); // NOLINT diff --git a/src/duckdb/src/include/duckdb/common/types/varint.hpp b/src/duckdb/src/include/duckdb/common/types/varint.hpp deleted file mode 100644 index e0b99cbab..000000000 --- a/src/duckdb/src/include/duckdb/common/types/varint.hpp +++ /dev/null @@ -1,120 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/varint.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/common.hpp" -#include "duckdb/common/limits.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/common/winapi.hpp" -#include "duckdb/function/cast/default_casts.hpp" - -namespace duckdb { -using digit_t = uint32_t; -using twodigit_t = uint64_t; - -//! The Varint class is a static class that holds helper functions for the Varint type. -class Varint { -public: - //! Header size of a Varint is always 3 bytes. - DUCKDB_API static constexpr uint8_t VARINT_HEADER_SIZE = 3; - //! Max(e such that 10**e fits in a digit_t) - DUCKDB_API static constexpr uint8_t DECIMAL_SHIFT = 9; - //! 10 ** DECIMAL_SHIFT - DUCKDB_API static constexpr digit_t DECIMAL_BASE = 1000000000; - //! Bytes of a digit_t - DUCKDB_API static constexpr uint8_t DIGIT_BYTES = sizeof(digit_t); - //! Bits of a digit_t - DUCKDB_API static constexpr uint8_t DIGIT_BITS = DIGIT_BYTES * 8; - //! Verifies if a Varint is valid. i.e., if it has 3 header bytes. The header correctly represents the number of - //! data bytes, and the data bytes has no leading zero bytes. - DUCKDB_API static void Verify(const string_t &input); - - //! Sets the header of a varint (i.e., char* blob), depending on the number of bytes that varint needs and if it's a - //! negative number - DUCKDB_API static void SetHeader(char *blob, uint64_t number_of_bytes, bool is_negative); - //! Initializes and returns a blob with value 0, allocated in Vector& result - DUCKDB_API static string_t InitializeVarintZero(Vector &result); - DUCKDB_API static string InitializeVarintZero(); - - //! Switch Case of To Varint Convertion - DUCKDB_API static BoundCastInfo NumericToVarintCastSwitch(const LogicalType &source); - - //! ----------------------------------- Varchar Cast ----------------------------------- // - //! Function to prepare a varchar for conversion. We trim zero's, check for negative values, and what-not - //! Returns false if this is an invalid varchar - DUCKDB_API static bool VarcharFormatting(const string_t &value, idx_t &start_pos, idx_t &end_pos, bool &is_negative, - bool &is_zero); - - //! Converts a char to a Digit - DUCKDB_API static int CharToDigit(char c); - //! Converts a Digit to a char - DUCKDB_API static char DigitToChar(int digit); - //! Function to convert a string_t into a vector of bytes - DUCKDB_API static void GetByteArray(vector &byte_array, bool &is_negative, const string_t &blob); - //! Function to create a VARINT blob from a byte array containing the absolute value, plus an is_negative bool - DUCKDB_API static string FromByteArray(uint8_t *data, idx_t size, bool is_negative); - //! Function to convert VARINT blob to a VARCHAR - DUCKDB_API static string VarIntToVarchar(const string_t &blob); - //! Function to convert Varchar to VARINT blob - DUCKDB_API static string VarcharToVarInt(const string_t &value); - //! 
----------------------------------- Double Cast ----------------------------------- // - DUCKDB_API static bool VarintToDouble(const string_t &blob, double &result, bool &strict); -}; - -//! ----------------------------------- (u)Integral Cast ----------------------------------- // -struct IntCastToVarInt { - template - static inline string_t Operation(SRC input, Vector &result) { - return IntToVarInt(result, input); - } -}; - -//! ----------------------------------- (u)HugeInt Cast ----------------------------------- // -struct HugeintCastToVarInt { - template - static inline string_t Operation(SRC input, Vector &result) { - throw InternalException("Unsupported type for cast to VARINT"); - } -}; - -struct TryCastToVarInt { - template - static inline bool Operation(SRC input, DST &result, Vector &result_vector, CastParameters ¶meters) { - throw InternalException("Unsupported type for try cast to VARINT"); - } -}; - -template <> -DUCKDB_API bool TryCastToVarInt::Operation(double double_value, string_t &result_value, Vector &result, - CastParameters ¶meters); - -template <> -DUCKDB_API bool TryCastToVarInt::Operation(float float_value, string_t &result_value, Vector &result, - CastParameters ¶meters); - -template <> -DUCKDB_API bool TryCastToVarInt::Operation(string_t input_value, string_t &result_value, Vector &result, - CastParameters ¶meters); - -struct VarIntCastToVarchar { - template - DUCKDB_API static inline string_t Operation(SRC input, Vector &result) { - return StringVector::AddStringOrBlob(result, Varint::VarIntToVarchar(input)); - } -}; - -struct VarintToDoubleCast { - template - DUCKDB_API static inline bool Operation(SRC input, DST &result, bool strict = false) { - return Varint::VarintToDouble(input, result, strict); - } -}; - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/ht_entry.hpp b/src/duckdb/src/include/duckdb/execution/ht_entry.hpp index e61460ae4..bcbe583c1 100644 --- a/src/duckdb/src/include/duckdb/execution/ht_entry.hpp +++ b/src/duckdb/src/include/duckdb/execution/ht_entry.hpp @@ -79,6 +79,10 @@ struct ht_entry_t { // NOLINT return ExtractSalt(value); } + inline hash_t GetSaltWithNulls() const { + return value & SALT_MASK; + } + inline void SetSalt(const hash_t &salt) { // Shouldn't be occupied when we set this D_ASSERT(!IsOccupied()); diff --git a/src/duckdb/src/include/duckdb/execution/index/art/art.hpp b/src/duckdb/src/include/duckdb/execution/index/art/art.hpp index 60ad2ac71..34a097400 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/art.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/art.hpp @@ -16,7 +16,7 @@ namespace duckdb { enum class VerifyExistenceType : uint8_t { APPEND = 0, APPEND_FK = 1, DELETE_FK = 2 }; enum class ARTConflictType : uint8_t { NO_CONFLICT = 0, CONSTRAINT = 1, TRANSACTION = 2 }; -enum class ARTHandlingResult : uint8_t { CONTINUE = 0, SKIP = 1, YIELD = 2 }; +enum class ARTHandlingResult : uint8_t { CONTINUE = 0, SKIP = 1, YIELD = 2, NONE = 3 }; class ConflictManager; class ARTKey; @@ -72,7 +72,7 @@ class ART : public BoundIndex { unique_ptr TryInitializeScan(const Expression &expr, const Expression &filter_expr); //! Perform a lookup on the ART, fetching up to max_count row IDs. //! If all row IDs were fetched, it return true, else false. - bool Scan(IndexScanState &state, idx_t max_count, unsafe_vector &row_ids); + bool Scan(IndexScanState &state, idx_t max_count, set &row_ids); //! Appends data to the locked index. 
ErrorData Append(IndexLock &l, DataChunk &chunk, Vector &row_ids) override; @@ -124,11 +124,11 @@ class ART : public BoundIndex { void VerifyBuffers(IndexLock &l) override; private: - bool SearchEqual(ARTKey &key, idx_t max_count, unsafe_vector &row_ids); - bool SearchGreater(ARTKey &key, bool equal, idx_t max_count, unsafe_vector &row_ids); - bool SearchLess(ARTKey &upper_bound, bool equal, idx_t max_count, unsafe_vector &row_ids); + bool SearchEqual(ARTKey &key, idx_t max_count, set &row_ids); + bool SearchGreater(ARTKey &key, bool equal, idx_t max_count, set &row_ids); + bool SearchLess(ARTKey &upper_bound, bool equal, idx_t max_count, set &row_ids); bool SearchCloseRange(ARTKey &lower_bound, ARTKey &upper_bound, bool left_equal, bool right_equal, idx_t max_count, - unsafe_vector &row_ids); + set &row_ids); string GenerateErrorKeyName(DataChunk &input, idx_t row); string GenerateConstraintErrorMessage(VerifyExistenceType verify_type, const string &key_name); diff --git a/src/duckdb/src/include/duckdb/execution/index/art/art_scanner.hpp b/src/duckdb/src/include/duckdb/execution/index/art/art_scanner.hpp index 26a9918ba..874c3125f 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/art_scanner.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/art_scanner.hpp @@ -71,7 +71,7 @@ class ARTScanner { break; } default: - throw InternalException("invalid node type for ART ARTScanner: %s", EnumUtil::ToString(type)); + throw InternalException("invalid node type for ART ARTScanner: %d", type); } } } @@ -80,9 +80,11 @@ class ARTScanner { template void Emplace(FUNC &&handler, NODE &node) { if (HANDLING == ARTScanHandling::EMPLACE) { - if (handler(node) == ARTHandlingResult::SKIP) { + auto result = handler(node); + if (result == ARTHandlingResult::SKIP) { return; } + D_ASSERT(result == ARTHandlingResult::CONTINUE); } s.emplace(node); } diff --git a/src/duckdb/src/include/duckdb/execution/index/art/base_node.hpp b/src/duckdb/src/include/duckdb/execution/index/art/base_node.hpp index 98096cafb..f06977e9a 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/base_node.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/base_node.hpp @@ -39,18 +39,15 @@ class BaseNode { NodeHandle handle(art, node); auto &n = handle.Get(); + // Reset the node (count). n.count = 0; - return handle; - } - - //! Free the children of the node. - static void Free(ART &art, Node &node) { - NodeHandle handle(art, node); - auto &n = handle.Get(); - - for (uint8_t i = 0; i < n.count; i++) { - Node::Free(art, n.children[i]); + // Zero-initialize the node. + for (uint8_t i = 0; i < CAPACITY; i++) { + n.key[i] = 0; + n.children[i].Clear(); } + + return handle; } //! Replace the child at byte. @@ -105,7 +102,6 @@ class BaseNode { children_ptr[i] = children[i]; } - count = 0; return NodeChildren(bytes, children_ptr); } diff --git a/src/duckdb/src/include/duckdb/execution/index/art/iterator.hpp b/src/duckdb/src/include/duckdb/execution/index/art/iterator.hpp index 977cc7791..c5907f820 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/iterator.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/iterator.hpp @@ -65,7 +65,7 @@ class Iterator { public: //! Scans the tree, starting at the current top node on the stack, and ending at upper_bound. //! If upper_bound is the empty ARTKey, than there is no upper bound. 
- bool Scan(const ARTKey &upper_bound, const idx_t max_count, unsafe_vector &row_ids, const bool equal); + bool Scan(const ARTKey &upper_bound, const idx_t max_count, set &row_ids, const bool equal); //! Finds the minimum (leaf) of the current subtree. void FindMinimum(const Node &node); //! Finds the lower bound of the ART and adds the nodes to the stack. Returns false, if the lower diff --git a/src/duckdb/src/include/duckdb/execution/index/art/leaf.hpp b/src/duckdb/src/include/duckdb/execution/index/art/leaf.hpp index 01196c68e..30efdba0a 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/leaf.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/leaf.hpp @@ -54,7 +54,7 @@ class Leaf { static void DeprecatedFree(ART &art, Node &node); //! Fills the row_ids vector with the row IDs of this linked list of leaves. //! Never pushes more than max_count row IDs. - static bool DeprecatedGetRowIds(ART &art, const Node &node, unsafe_vector &row_ids, const idx_t max_count); + static bool DeprecatedGetRowIds(ART &art, const Node &node, set &row_ids, const idx_t max_count); //! Vacuums the linked list of leaves. static void DeprecatedVacuum(ART &art, Node &node); //! Returns the string representation of the linked list of leaves, if only_verify is true. diff --git a/src/duckdb/src/include/duckdb/execution/index/art/node.hpp b/src/duckdb/src/include/duckdb/execution/index/art/node.hpp index 851d4f898..764a3ce59 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/node.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/node.hpp @@ -52,8 +52,10 @@ class Node : public IndexPointer { public: //! Get a new pointer to a node and initialize it. static void New(ART &art, Node &node, const NType type); + //! Free the node. + static void FreeNode(ART &art, Node &node); //! Free the node and its children. - static void Free(ART &art, Node &node); + static void FreeTree(ART &art, Node &node); //! Get a reference to the allocator. static FixedSizeAllocator &GetAllocator(const ART &art, const NType type); diff --git a/src/duckdb/src/include/duckdb/execution/index/art/node256.hpp b/src/duckdb/src/include/duckdb/execution/index/art/node256.hpp index ec3b2827d..d3c2362b2 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/node256.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/node256.hpp @@ -33,21 +33,41 @@ class Node256 { Node children[CAPACITY]; public: - //! Get a new Node256 and initialize it. - static Node256 &New(ART &art, Node &node); - //! Free the node and its children. - static void Free(ART &art, Node &node); + //! Get a new Node256 handle and initialize the Node256. + static NodeHandle New(ART &art, Node &node) { + node = Node::GetAllocator(art, NODE_256).New(); + node.SetMetadata(static_cast(NODE_256)); + + NodeHandle handle(art, node); + auto &n = handle.Get(); + + // Reset the node (count and children). + n.count = 0; + for (uint16_t i = 0; i < CAPACITY; i++) { + n.children[i].Clear(); + } + + return handle; + } //! Insert a child at byte. static void InsertChild(ART &art, Node &node, const uint8_t byte, const Node child); //! Delete the child at byte. static void DeleteChild(ART &art, Node &node, const uint8_t byte); //! Replace the child at byte. 
- void ReplaceChild(const uint8_t byte, const Node child); + void ReplaceChild(const uint8_t byte, const Node child) { + D_ASSERT(count > SHRINK_THRESHOLD); + auto status = children[byte].GetGateStatus(); + children[byte] = child; + if (status == GateStatus::GATE_SET && child.HasMetadata()) { + children[byte].SetGateStatus(status); + } + } public: template static void Iterator(NODE &n, F &&lambda) { + D_ASSERT(n.count); for (idx_t i = 0; i < CAPACITY; i++) { if (n.children[i].HasMetadata()) { lambda(n.children[i]); @@ -92,11 +112,10 @@ class Node256 { } } - count = 0; return NodeChildren(bytes, children_ptr); } private: - static Node256 &GrowNode48(ART &art, Node &node256, Node &node48); + static void GrowNode48(ART &art, Node &node256, Node &node48); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/index/art/node48.hpp b/src/duckdb/src/include/duckdb/execution/index/art/node48.hpp index 81142c938..6fee57ffd 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/node48.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/node48.hpp @@ -38,21 +38,45 @@ class Node48 { Node children[CAPACITY]; public: - //! Get a new Node48 and initialize it. - static Node48 &New(ART &art, Node &node); - //! Free the node and its children. - static void Free(ART &art, Node &node); + //! Get a new Node48 handle and initialize the Node48. + static NodeHandle New(ART &art, Node &node) { + node = Node::GetAllocator(art, NODE_48).New(); + node.SetMetadata(static_cast(NODE_48)); + + NodeHandle handle(art, node); + auto &n = handle.Get(); + + // Reset the node (count and child_index). + n.count = 0; + for (uint16_t i = 0; i < Node256::CAPACITY; i++) { + n.child_index[i] = EMPTY_MARKER; + } + // Zero-initialize the node. + for (uint8_t i = 0; i < CAPACITY; i++) { + n.children[i].Clear(); + } + + return handle; + } //! Insert a child at byte. static void InsertChild(ART &art, Node &node, const uint8_t byte, const Node child); //! Delete the child at byte. static void DeleteChild(ART &art, Node &node, const uint8_t byte); //! Replace the child at byte. - void ReplaceChild(const uint8_t byte, const Node child); + void ReplaceChild(const uint8_t byte, const Node child) { + D_ASSERT(count >= SHRINK_THRESHOLD); + auto status = children[child_index[byte]].GetGateStatus(); + children[child_index[byte]] = child; + if (status == GateStatus::GATE_SET && child.HasMetadata()) { + children[child_index[byte]].SetGateStatus(status); + } + } public: template static void Iterator(NODE &n, F &&lambda) { + D_ASSERT(n.count); for (idx_t i = 0; i < Node256::CAPACITY; i++) { if (n.child_index[i] != EMPTY_MARKER) { lambda(n.children[n.child_index[i]]); @@ -97,12 +121,11 @@ class Node48 { } } - count = 0; return NodeChildren(bytes, children_ptr); } private: - static Node48 &GrowNode16(ART &art, Node &node48, Node &node16); - static Node48 &ShrinkNode256(ART &art, Node &node48, Node &node256); + static void GrowNode16(ART &art, Node &node48, Node &node16); + static void ShrinkNode256(ART &art, Node &node48, Node &node256); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/index/art/prefix.hpp b/src/duckdb/src/include/duckdb/execution/index/art/prefix.hpp index 644f789f6..530a888b6 100644 --- a/src/duckdb/src/include/duckdb/execution/index/art/prefix.hpp +++ b/src/duckdb/src/include/duckdb/execution/index/art/prefix.hpp @@ -47,12 +47,7 @@ class Prefix { //! Get a new list of prefix nodes. The node reference holds the child of the last prefix node. 
 	static void New(ART &art, reference<Node> &ref, const ARTKey &key, const idx_t depth, idx_t count);
 
-	//! Free the prefix and its child.
-	static void Free(ART &art, Node &node);
-
-	//! Concatenates parent -> byte -> child. Special-handling, if
-	//! 1. the byte was in a gate node.
-	//! 2. the byte was in PREFIX_INLINED.
+	//! Concatenates parent -> byte -> child.
 	static void Concat(ART &art, Node &parent, uint8_t byte, const GateStatus old_status, const Node &child,
 	                   const GateStatus status);
@@ -80,8 +75,7 @@ class Prefix {
 	static void TransformToDeprecated(ART &art, Node &node, unsafe_unique_ptr &allocator);
 
 private:
-	static Prefix NewInternal(ART &art, Node &node, const data_ptr_t data, const uint8_t count, const idx_t offset,
-	                          const NType type);
+	static Prefix NewInternal(ART &art, Node &node, const data_ptr_t data, const uint8_t count, const idx_t offset);
 
 	static Prefix GetTail(ART &art, const Node &node);
diff --git a/src/duckdb/src/include/duckdb/execution/index/bound_index.hpp b/src/duckdb/src/include/duckdb/execution/index/bound_index.hpp
index 2f4254862..ddde9825e 100644
--- a/src/duckdb/src/include/duckdb/execution/index/bound_index.hpp
+++ b/src/duckdb/src/include/duckdb/execution/index/bound_index.hpp
@@ -60,19 +60,18 @@ class BoundIndex : public Index {
 	//! The index constraint type
 	IndexConstraintType index_constraint_type;
 
+	vector<unique_ptr<Expression>> unbound_expressions;
+
 public:
 	bool IsBound() const override {
 		return true;
 	}
-
 	const string &GetIndexType() const override {
 		return index_type;
 	}
-
 	const string &GetIndexName() const override {
 		return name;
 	}
-
 	IndexConstraintType GetConstraintType() const override {
 		return index_constraint_type;
 	}
@@ -156,7 +155,7 @@ class BoundIndex : public Index {
 	virtual string GetConstraintViolationMessage(VerifyExistenceType verify_type, idx_t failed_index,
 	                                             DataChunk &input) = 0;
 
-	vector<unique_ptr<Expression>> unbound_expressions;
+	void ApplyBufferedAppends(ColumnDataCollection &buffered_appends);
 
 protected:
 	//! Lock used for any changes to the index
diff --git a/src/duckdb/src/include/duckdb/execution/index/unbound_index.hpp b/src/duckdb/src/include/duckdb/execution/index/unbound_index.hpp
index 1c7f859ea..3e56e864c 100644
--- a/src/duckdb/src/include/duckdb/execution/index/unbound_index.hpp
+++ b/src/duckdb/src/include/duckdb/execution/index/unbound_index.hpp
@@ -8,56 +8,61 @@
 
 #pragma once
 
-#include "duckdb/storage/index.hpp"
 #include "duckdb/parser/parsed_data/create_index_info.hpp"
+#include "duckdb/storage/index.hpp"
 
 namespace duckdb {
 
+class ColumnDataCollection;
+
 class UnboundIndex final : public Index {
 private:
-	// The create info of the index
+	//! The CreateInfo of the index.
 	unique_ptr<CreateInfo> create_info;
-
-	// The serialized storage info of the index
+	//! The serialized storage information of the index.
 	IndexStorageInfo storage_info;
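// --- editor's note: illustrative sketch, not part of the patch ---
// The buffered appends below let WAL replay stash rows targeting an index
// that is not bound yet; BoundIndex::ApplyBufferedAppends (declared above)
// replays them once the index has been bound. The helper below is
// hypothetical and only sketches how the two calls fit together.
//
//   void AppendDuringReplaySketch(Index &index, DataChunk &chunk, Vector &row_ids) {
//       if (!index.IsBound()) {
//           // No bound index yet (e.g., an extension index type): defer the rows.
//           index.Cast<UnboundIndex>().BufferChunk(chunk, row_ids);
//           return;
//       }
//       // A bound index can apply the rows immediately via BoundIndex::Append.
//   }
// --- end editor's note ---
+	//! Buffer for WAL replay appends.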
+ unique_ptr buffered_appends; public: UnboundIndex(unique_ptr create_info, IndexStorageInfo storage_info, TableIOManager &table_io_manager, AttachedDatabase &db); +public: bool IsBound() const override { return false; } - const string &GetIndexType() const override { return GetCreateInfo().index_type; } - const string &GetIndexName() const override { return GetCreateInfo().index_name; } - IndexConstraintType GetConstraintType() const override { return GetCreateInfo().constraint_type; } - const CreateIndexInfo &GetCreateInfo() const { return create_info->Cast(); } - const IndexStorageInfo &GetStorageInfo() const { return storage_info; } - const vector> &GetParsedExpressions() const { return GetCreateInfo().parsed_expressions; } - const string &GetTableName() const { return GetCreateInfo().table; } void CommitDrop() override; + + void BufferChunk(DataChunk &chunk, Vector &row_ids); + bool HasBufferedAppends() const { + return buffered_appends != nullptr; + } + ColumnDataCollection &GetBufferedAppends() const { + return *buffered_appends; + } }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp b/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp index a4392e86b..0cd90d0b3 100644 --- a/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp +++ b/src/duckdb/src/include/duckdb/execution/join_hashtable.hpp @@ -149,7 +149,7 @@ class JoinHashTable { struct ProbeState : SharedState { ProbeState(); - Vector ht_offsets_v; + Vector ht_offsets_and_salts_v; Vector hashes_dense_v; SelectionVector non_empty_sel; }; diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp index 82bd81466..30ba0abc5 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_error.hpp @@ -69,7 +69,7 @@ class CSVError { const string ¤t_path); //! Produces an error message for a dialect sniffing error. static CSVError SniffingError(const CSVReaderOptions &options, const string &search_space, idx_t max_columns_found, - SetColumns &set_columns); + SetColumns &set_columns, bool type_detection); //! Produces an error message for a header sniffing error. static CSVError HeaderSniffingError(const CSVReaderOptions &options, const vector &best_header_row, idx_t column_count, const string &delimiter); diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state.hpp index 017ccfe39..bd278aa2e 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state.hpp @@ -12,7 +12,7 @@ namespace duckdb { -//! All States of CSV Parsing +//! All States of CSV parsing enum class CSVState : uint8_t { STANDARD = 0, //! Regular unquoted field state DELIMITER = 1, //! State after encountering a field separator (e.g., ;) - This is always the last delimiter byte diff --git a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state_machine.hpp b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state_machine.hpp index 71eb1c11e..45aeaad9b 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state_machine.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/csv_scanner/csv_state_machine.hpp @@ -113,7 +113,7 @@ struct CSVStates { //! 
The STA indicates the current state of parsing based on both the current and preceding characters. //! This reveals whether we are dealing with a Field, a New Line, a Delimiter, and so forth. //! The STA's creation depends on the provided quote, character, and delimiter options for that state machine. -//! The motivation behind implementing an STA is to remove branching in regular CSV Parsing by predicting and detecting +//! The motivation behind implementing an STA is to remove branching in regular CSV parsing by predicting and detecting //! the states. Note: The State Machine is currently utilized solely in the CSV Sniffer. class CSVStateMachine { public: diff --git a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp index fbc97d4db..c7f2fb038 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/helper/physical_set.hpp @@ -40,6 +40,8 @@ class PhysicalSet : public PhysicalOperator { static void SetExtensionVariable(ClientContext &context, ExtensionOption &extension_option, const string &name, SetScope scope, const Value &value); + static void SetGenericVariable(ClientContext &context, const string &name, SetScope scope, Value target_value); + public: const string name; const Value value; diff --git a/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp b/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp index 0634641f1..4ee6ef557 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/join/physical_range_join.hpp @@ -94,7 +94,7 @@ class PhysicalRangeJoin : public PhysicalComparisonJoin { public: PhysicalRangeJoin(PhysicalPlan &physical_plan, LogicalComparisonJoin &op, PhysicalOperatorType type, PhysicalOperator &left, PhysicalOperator &right, vector cond, JoinType join_type, - idx_t estimated_cardinality); + idx_t estimated_cardinality, unique_ptr pushdown_info); // Projection mappings using ProjectionMapping = vector; diff --git a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_merge_into.hpp b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_merge_into.hpp index f3798d877..9cbba99c1 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_merge_into.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/persistent/physical_merge_into.hpp @@ -39,7 +39,7 @@ class PhysicalMergeInto : public PhysicalOperator { public: PhysicalMergeInto(PhysicalPlan &physical_plan, vector types, map>> actions, idx_t row_id_index, - optional_idx source_marker, bool parallel); + optional_idx source_marker, bool parallel, bool return_chunk); //! 
List of all actions vector> actions; @@ -50,10 +50,13 @@ class PhysicalMergeInto : public PhysicalOperator { idx_t row_id_index; optional_idx source_marker; bool parallel; + bool return_chunk; public: // Source interface unique_ptr GetGlobalSourceState(ClientContext &context) const override; + unique_ptr GetLocalSourceState(ExecutionContext &context, + GlobalSourceState &gstate) const override; SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override; bool IsSource() const override { diff --git a/src/duckdb/src/include/duckdb/execution/operator/projection/physical_tableinout_function.hpp b/src/duckdb/src/include/duckdb/execution/operator/projection/physical_tableinout_function.hpp index 86b32f623..5a5ca7722 100644 --- a/src/duckdb/src/include/duckdb/execution/operator/projection/physical_tableinout_function.hpp +++ b/src/duckdb/src/include/duckdb/execution/operator/projection/physical_tableinout_function.hpp @@ -26,6 +26,8 @@ class PhysicalTableInOutFunction : public PhysicalOperator { public: unique_ptr GetOperatorState(ExecutionContext &context) const override; unique_ptr GetGlobalOperatorState(ClientContext &context) const override; + static void SetOrdinality(DataChunk &chunk, const optional_idx &ordinality_column_idx, const idx_t &ordinality_idx, + const idx_t &ordinality); OperatorResultType Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, GlobalOperatorState &gstate, OperatorState &state) const override; OperatorFinalizeResultType FinalExecute(ExecutionContext &context, DataChunk &chunk, GlobalOperatorState &gstate, @@ -41,6 +43,9 @@ class PhysicalTableInOutFunction : public PhysicalOperator { InsertionOrderPreservingMap ParamsToString() const override; + //! Information for WITH ORDINALITY + optional_idx ordinality_idx; + private: //! The table function TableFunction function; diff --git a/src/duckdb/src/include/duckdb/function/aggregate_function.hpp b/src/duckdb/src/include/duckdb/function/aggregate_function.hpp index e63533d5c..799d6c3b1 100644 --- a/src/duckdb/src/include/duckdb/function/aggregate_function.hpp +++ b/src/duckdb/src/include/duckdb/function/aggregate_function.hpp @@ -35,13 +35,13 @@ using FrameStats = array; //! but the row count will still be valid class ColumnDataCollection; struct WindowPartitionInput { - WindowPartitionInput(ClientContext &context, const ColumnDataCollection *inputs, idx_t count, + WindowPartitionInput(ExecutionContext &context, const ColumnDataCollection *inputs, idx_t count, vector &column_ids, vector &all_valid, const ValidityMask &filter_mask, const FrameStats &stats) : context(context), inputs(inputs), count(count), column_ids(column_ids), all_valid(all_valid), filter_mask(filter_mask), stats(stats) { } - ClientContext &context; + ExecutionContext &context; const ColumnDataCollection *inputs; idx_t count; vector column_ids; diff --git a/src/duckdb/src/include/duckdb/function/cast/cast_function_set.hpp b/src/duckdb/src/include/duckdb/function/cast/cast_function_set.hpp index d019592ac..afdba020f 100644 --- a/src/duckdb/src/include/duckdb/function/cast/cast_function_set.hpp +++ b/src/duckdb/src/include/duckdb/function/cast/cast_function_set.hpp @@ -52,7 +52,12 @@ class CastFunctionSet { GetCastFunctionInput &input); //! Returns the implicit cast cost of casting from source -> target //! 
-1 means an implicit cast is not possible
-	DUCKDB_API int64_t ImplicitCastCost(const LogicalType &source, const LogicalType &target);
+	DUCKDB_API int64_t ImplicitCastCost(optional_ptr<ClientContext> context, const LogicalType &source,
+	                                    const LogicalType &target);
+	DUCKDB_API static int64_t ImplicitCastCost(ClientContext &context, const LogicalType &source,
+	                                           const LogicalType &target);
+	DUCKDB_API static int64_t ImplicitCastCost(DatabaseInstance &db, const LogicalType &source,
+	                                           const LogicalType &target);
 	//! Register a new cast function from source to target
 	DUCKDB_API void RegisterCastFunction(const LogicalType &source, const LogicalType &target, BoundCastInfo function,
 	                                     int64_t implicit_cast_cost = -1);
diff --git a/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp b/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp
index 21295b338..96ec9a4e9 100644
--- a/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp
+++ b/src/duckdb/src/include/duckdb/function/cast/default_casts.hpp
@@ -169,7 +169,7 @@ struct DefaultCasts {
 	                                       const LogicalType &target);
 	static BoundCastInfo UnionCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target);
 	static BoundCastInfo UUIDCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target);
-	static BoundCastInfo VarintCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target);
+	static BoundCastInfo BignumCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target);
 	static BoundCastInfo ImplicitToUnionCast(BindCastInput &input, const LogicalType &source,
 	                                         const LogicalType &target);
 };
diff --git a/src/duckdb/src/include/duckdb/function/cast/vector_cast_helpers.hpp b/src/duckdb/src/include/duckdb/function/cast/vector_cast_helpers.hpp
index ad84e340f..1c50c3414 100644
--- a/src/duckdb/src/include/duckdb/function/cast/vector_cast_helpers.hpp
+++ b/src/duckdb/src/include/duckdb/function/cast/vector_cast_helpers.hpp
@@ -145,11 +145,10 @@ struct VectorCastHelpers {
 		return TemplatedTryCastLoop>(source, result, count, parameters);
 	}
 
-	template
+	template
 	static bool StringCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
 		D_ASSERT(result.GetType().InternalType() == PhysicalType::VARCHAR);
-		UnaryExecutor::GenericExecute>(source, result, count,
-		                               (void *)&result);
+		UnaryExecutor::GenericExecute>(source, result, count, (void *)&result);
 		return true;
 	}
diff --git a/src/duckdb/src/include/duckdb/function/copy_function.hpp b/src/duckdb/src/include/duckdb/function/copy_function.hpp
index b3244cce0..a86625050 100644
--- a/src/duckdb/src/include/duckdb/function/copy_function.hpp
+++ b/src/duckdb/src/include/duckdb/function/copy_function.hpp
@@ -76,6 +76,14 @@ struct CopyFunctionBindInput {
 	string file_extension;
 };
 
+struct CopyFromFunctionBindInput {
+	explicit CopyFromFunctionBindInput(const CopyInfo &info_p, TableFunction &tf_p) : info(info_p), tf(tf_p) {
+	}
+
+	const CopyInfo &info;
+	TableFunction &tf;
+};
+
 struct CopyToSelectInput {
 	ClientContext &context;
 	case_insensitive_map_t> &options;
@@ -102,7 +110,7 @@
 typedef void (*copy_to_serialize_t)(Serializer &serializer, const FunctionData &bind_data);
 
typedef unique_ptr<FunctionData> (*copy_to_deserialize_t)(Deserializer &deserializer, CopyFunction &function);
 
-typedef unique_ptr<FunctionData> (*copy_from_bind_t)(ClientContext &context, CopyInfo &info,
+typedef unique_ptr<FunctionData> (*copy_from_bind_t)(ClientContext &context, CopyFromFunctionBindInput &info,
                                                      vector<string> &expected_names,
                                                      vector<LogicalType> &expected_types);
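// --- editor's note: illustrative sketch, not part of the patch ---
// The static ImplicitCastCost overloads above make the cast cost reachable
// without first fetching a CastFunctionSet instance; -1 still means that no
// implicit cast is possible. A hypothetical call site:
//
//   auto cost = CastFunctionSet::ImplicitCastCost(context, LogicalType::INTEGER, LogicalType::BIGINT);
//   if (cost < 0) {
//       // no implicit cast possible between these types
//   }
// --- end editor's note ---
typedef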
CopyFunctionExecutionMode (*copy_to_execution_mode_t)(bool preserve_insertion_order, bool supports_batch_index); diff --git a/src/duckdb/src/include/duckdb/function/function_binder.hpp b/src/duckdb/src/include/duckdb/function/function_binder.hpp index 021746aed..6eba740ab 100644 --- a/src/duckdb/src/include/duckdb/function/function_binder.hpp +++ b/src/duckdb/src/include/duckdb/function/function_binder.hpp @@ -76,6 +76,9 @@ class FunctionBinder { //! Cast a set of expressions to the arguments of this function void CastToFunctionArguments(SimpleFunction &function, vector> &children); + void ResolveTemplateTypes(BaseScalarFunction &bound_function, const vector> &children); + void CheckTemplateTypesResolved(const BaseScalarFunction &bound_function); + private: optional_idx BindVarArgsFunctionCost(const SimpleFunction &func, const vector &arguments); optional_idx BindFunctionCost(const SimpleFunction &func, const vector &arguments); diff --git a/src/duckdb/src/include/duckdb/function/function_serialization.hpp b/src/duckdb/src/include/duckdb/function/function_serialization.hpp index 203afb981..e0a2c58c9 100644 --- a/src/duckdb/src/include/duckdb/function/function_serialization.hpp +++ b/src/duckdb/src/include/duckdb/function/function_serialization.hpp @@ -143,6 +143,12 @@ class FunctionSerializer { bind_data = FunctionDeserialize(deserializer, function); deserializer.Unset(); } else { + + FunctionBinder binder(context); + + // Resolve templates + binder.ResolveTemplateTypes(function, children); + if (function.bind) { try { bind_data = function.bind(context, function, children); @@ -152,7 +158,10 @@ class FunctionSerializer { error.RawMessage()); } } - FunctionBinder binder(context); + + // Verify that all templates are bound to concrete types. + binder.CheckTemplateTypesResolved(function); + binder.CastToFunctionArguments(function, children); } diff --git a/src/duckdb/src/include/duckdb/function/pragma/pragma_functions.hpp b/src/duckdb/src/include/duckdb/function/pragma/pragma_functions.hpp index 9cef9078a..a1dd91f65 100644 --- a/src/duckdb/src/include/duckdb/function/pragma/pragma_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/pragma/pragma_functions.hpp @@ -21,7 +21,7 @@ struct PragmaFunctions { static void RegisterFunction(BuiltinFunctions &set); }; -string PragmaShowTables(); +string PragmaShowTables(const string &catalog = "", const string &schema = ""); string PragmaShowTablesExpanded(); string PragmaShowDatabases(); string PragmaShowVariables(); diff --git a/src/duckdb/src/include/duckdb/function/scalar/compressed_materialization_functions.hpp b/src/duckdb/src/include/duckdb/function/scalar/compressed_materialization_functions.hpp index ac042a906..f1289cc60 100644 --- a/src/duckdb/src/include/duckdb/function/scalar/compressed_materialization_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/scalar/compressed_materialization_functions.hpp @@ -105,6 +105,16 @@ struct InternalCompressStringUhugeintFun { static ScalarFunction GetFunction(); }; +struct InternalCompressStringHugeintFun { + static constexpr const char *Name = "__internal_compress_string_hugeint"; + static constexpr const char *Parameters = ""; + static constexpr const char *Description = ""; + static constexpr const char *Example = ""; + static constexpr const char *Categories = ""; + + static ScalarFunction GetFunction(); +}; + struct InternalDecompressIntegralSmallintFun { static constexpr const char *Name = "__internal_decompress_integral_smallint"; static constexpr const char *Parameters = ""; diff 
--git a/src/duckdb/src/include/duckdb/function/table/arrow.hpp b/src/duckdb/src/include/duckdb/function/table/arrow.hpp index f286732f2..a596c86e9 100644 --- a/src/duckdb/src/include/duckdb/function/table/arrow.hpp +++ b/src/duckdb/src/include/duckdb/function/table/arrow.hpp @@ -65,7 +65,7 @@ struct ArrowScanFunctionData : public TableFunctionData { //! The (optional) dependency of this function (used in Python for example) shared_ptr dependency; //! Arrow table data - ArrowTableType arrow_table; + ArrowTableSchema arrow_table; //! Whether projection pushdown is enabled on the scan bool projection_pushdown_enabled = true; }; @@ -89,10 +89,9 @@ struct ArrowRunEndEncodingState { struct ArrowScanLocalState; struct ArrowArrayScanState { public: - explicit ArrowArrayScanState(ArrowScanLocalState &state, ClientContext &context); + explicit ArrowArrayScanState(ClientContext &context); public: - ArrowScanLocalState &state; // Hold ownership over the Arrow Arrays owned by DuckDB to allow for zero-copy shared_ptr owned_data; unordered_map> children; @@ -153,7 +152,7 @@ struct ArrowScanLocalState : public LocalTableFunctionState { ArrowArrayScanState &GetState(idx_t child_idx) { auto it = array_states.find(child_idx); if (it == array_states.end()) { - auto child_p = make_uniq(*this, context); + auto child_p = make_uniq(context); auto &child = *child_p; array_states.emplace(child_idx, std::move(child_p)); return child; @@ -181,6 +180,26 @@ struct ArrowScanGlobalState : public GlobalTableFunctionState { } }; +struct ArrowToDuckDBConversion { + static void SetValidityMask(Vector &vector, ArrowArray &array, idx_t chunk_offset, idx_t size, + int64_t parent_offset, int64_t nested_offset, bool add_null = false); + + static void ColumnArrowToDuckDBRunEndEncoded(Vector &vector, const ArrowArray &array, idx_t chunk_offset, + ArrowArrayScanState &array_state, idx_t size, + const ArrowType &arrow_type, int64_t nested_offset = -1, + ValidityMask *parent_mask = nullptr, uint64_t parent_offset = 0); + + static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, idx_t chunk_offset, + ArrowArrayScanState &array_state, idx_t size, const ArrowType &arrow_type, + int64_t nested_offset = -1, ValidityMask *parent_mask = nullptr, + uint64_t parent_offset = 0, bool ignore_extensions = false); + + static void ColumnArrowToDuckDBDictionary(Vector &vector, ArrowArray &array, idx_t chunk_offset, + ArrowArrayScanState &array_state, idx_t size, const ArrowType &arrow_type, + int64_t nested_offset = -1, const ValidityMask *parent_mask = nullptr, + uint64_t parent_offset = 0); +}; + struct ArrowTableFunction { public: static void RegisterFunction(BuiltinFunctions &set); @@ -214,9 +233,8 @@ struct ArrowTableFunction { //! Scan Function static void ArrowScanFunction(ClientContext &context, TableFunctionInput &data, DataChunk &output); - static void PopulateArrowTableType(DBConfig &config, ArrowTableType &arrow_table, - const ArrowSchemaWrapper &schema_p, vector &names, - vector &return_types); + static void PopulateArrowTableSchema(DBConfig &config, ArrowTableSchema &arrow_table, + const ArrowSchema &arrow_schema); protected: //! Defines Maximum Number of Threads @@ -230,9 +248,6 @@ struct ArrowTableFunction { //! -----Utility Functions:----- //! Gets Arrow Table's Cardinality static unique_ptr ArrowScanCardinality(ClientContext &context, const FunctionData *bind_data); - //! 
Gets the progress on the table scan, used for Progress Bars - static double ArrowProgress(ClientContext &context, const FunctionData *bind_data, - const GlobalTableFunctionState *global_state); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/table/arrow/arrow_duck_schema.hpp b/src/duckdb/src/include/duckdb/function/table/arrow/arrow_duck_schema.hpp index 7d92a3437..44472ea63 100644 --- a/src/duckdb/src/include/duckdb/function/table/arrow/arrow_duck_schema.hpp +++ b/src/duckdb/src/include/duckdb/function/table/arrow/arrow_duck_schema.hpp @@ -96,6 +96,8 @@ class ArrowType { bool HasExtension() const; + ArrowArrayPhysicalType GetPhysicalType() const; + //! The Arrow Type Extension data, if any shared_ptr extension_data; @@ -114,13 +116,17 @@ class ArrowType { using arrow_column_map_t = unordered_map>; -struct ArrowTableType { +struct ArrowTableSchema { public: - void AddColumn(idx_t index, shared_ptr type); + void AddColumn(idx_t index, shared_ptr type, const string &name); const arrow_column_map_t &GetColumns() const; + vector &GetTypes(); + vector &GetNames(); private: arrow_column_map_t arrow_convert_data; + vector types; + vector column_names; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/table/arrow/arrow_type_info.hpp b/src/duckdb/src/include/duckdb/function/table/arrow/arrow_type_info.hpp index d2e419646..822f60343 100644 --- a/src/duckdb/src/include/duckdb/function/table/arrow/arrow_type_info.hpp +++ b/src/duckdb/src/include/duckdb/function/table/arrow/arrow_type_info.hpp @@ -21,6 +21,8 @@ namespace duckdb { class ArrowType; +enum class ArrowArrayPhysicalType : uint8_t { DICTIONARY_ENCODED, RUN_END_ENCODED, DEFAULT }; + struct ArrowTypeInfo { public: explicit ArrowTypeInfo() : type() { diff --git a/src/duckdb/src/include/duckdb/function/table/system_functions.hpp b/src/duckdb/src/include/duckdb/function/table/system_functions.hpp index 64b8b5202..624fcabc2 100644 --- a/src/duckdb/src/include/duckdb/function/table/system_functions.hpp +++ b/src/duckdb/src/include/duckdb/function/table/system_functions.hpp @@ -47,6 +47,10 @@ struct DuckDBSchemasFun { static void RegisterFunction(BuiltinFunctions &set); }; +struct DuckDBApproxDatabaseCountFun { + static void RegisterFunction(BuiltinFunctions &set); +}; + struct DuckDBColumnsFun { static void RegisterFunction(BuiltinFunctions &set); }; @@ -164,7 +168,7 @@ struct TestType { struct TestAllTypesFun { static void RegisterFunction(BuiltinFunctions &set); - static vector GetTestTypes(bool large_enum = false); + static vector GetTestTypes(bool large_enum = false, bool large_bignum = false); }; struct TestVectorTypesFun { diff --git a/src/duckdb/src/include/duckdb/function/window/window_aggregate_function.hpp b/src/duckdb/src/include/duckdb/function/window/window_aggregate_function.hpp index efc782b8c..df6ec5745 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_aggregate_function.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_aggregate_function.hpp @@ -16,17 +16,19 @@ namespace duckdb { class WindowAggregateExecutor : public WindowExecutor { public: - WindowAggregateExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared, + WindowAggregateExecutor(BoundWindowExpression &wexpr, ClientContext &client, WindowSharedExpressions &shared, WindowAggregationMode mode); - void Sink(DataChunk &sink_chunk, DataChunk &coll_chunk, const idx_t input_idx, WindowExecutorGlobalState &gstate, - WindowExecutorLocalState &lstate) const 
override; - void Finalize(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, + void Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, const idx_t input_idx, + WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate) const override; + void Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, CollectionPtr collection) const override; - unique_ptr GetGlobalState(const idx_t payload_count, const ValidityMask &partition_mask, + unique_ptr GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask) const override; - unique_ptr GetLocalState(const WindowExecutorGlobalState &gstate) const override; + unique_ptr GetLocalState(ExecutionContext &context, + const WindowExecutorGlobalState &gstate) const override; const WindowAggregationMode mode; @@ -37,8 +39,9 @@ class WindowAggregateExecutor : public WindowExecutor { unique_ptr filter_ref; protected: - void EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, DataChunk &eval_chunk, - Vector &result, idx_t count, idx_t row_idx) const override; + void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, + idx_t row_idx) const override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/window/window_aggregator.hpp b/src/duckdb/src/include/duckdb/function/window/window_aggregator.hpp index 1caa6326d..82fd00bd9 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_aggregator.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_aggregator.hpp @@ -110,19 +110,20 @@ class WindowAggregator { virtual ~WindowAggregator(); // Threading states - virtual unique_ptr GetGlobalState(ClientContext &context, idx_t group_count, + virtual unique_ptr GetGlobalState(ClientContext &client, idx_t group_count, const ValidityMask &partition_mask) const; virtual unique_ptr GetLocalState(const WindowAggregatorState &gstate) const = 0; // Build - virtual void Sink(WindowAggregatorState &gstate, WindowAggregatorState &lstate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx, optional_ptr filter_sel, idx_t filtered); - virtual void Finalize(WindowAggregatorState &gstate, WindowAggregatorState &lstate, CollectionPtr collection, - const FrameStats &stats); + virtual void Sink(ExecutionContext &context, WindowAggregatorState &gstate, WindowAggregatorState &lstate, + DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, + optional_ptr filter_sel, idx_t filtered); + virtual void Finalize(ExecutionContext &context, WindowAggregatorState &gstate, WindowAggregatorState &lstate, + CollectionPtr collection, const FrameStats &stats); // Probe - virtual void Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, const DataChunk &bounds, - Vector &result, idx_t count, idx_t row_idx) const = 0; + virtual void Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, WindowAggregatorState &lstate, + const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const = 0; //! 
The window function const BoundWindowExpression &wexpr; @@ -142,8 +143,8 @@ class WindowAggregator { class WindowAggregatorGlobalState : public WindowAggregatorState { public: - WindowAggregatorGlobalState(ClientContext &context, const WindowAggregator &aggregator_p, idx_t group_count) - : context(context), aggregator(aggregator_p), aggr(aggregator.wexpr), locals(0), finalized(0) { + WindowAggregatorGlobalState(ClientContext &client, const WindowAggregator &aggregator_p, idx_t group_count) + : client(client), aggregator(aggregator_p), aggr(aggregator.wexpr), locals(0), finalized(0) { if (aggr.filter) { // Start with all invalid and set the ones that pass @@ -153,8 +154,8 @@ class WindowAggregatorGlobalState : public WindowAggregatorState { } } - //! The context we are in - ClientContext &context; + //! The client we are in + ClientContext &client; //! The aggregator data const WindowAggregator &aggregator; @@ -184,8 +185,9 @@ class WindowAggregatorLocalState : public WindowAggregatorState { WindowAggregatorLocalState() { } - void Sink(WindowAggregatorGlobalState &gastate, DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t row_idx); - virtual void Finalize(WindowAggregatorGlobalState &gastate, CollectionPtr collection); + void Sink(ExecutionContext &context, WindowAggregatorGlobalState &gastate, DataChunk &sink_chunk, + DataChunk &coll_chunk, idx_t row_idx); + virtual void Finalize(ExecutionContext &context, WindowAggregatorGlobalState &gastate, CollectionPtr collection); //! The state used for reading the collection unique_ptr cursor; diff --git a/src/duckdb/src/include/duckdb/function/window/window_collection.hpp b/src/duckdb/src/include/duckdb/function/window/window_collection.hpp index d946a1d79..95cf0534f 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_collection.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_collection.hpp @@ -143,4 +143,146 @@ class WindowCursor { DataChunk chunk; }; +class WindowCollectionChunkScanner { +public: + WindowCollectionChunkScanner(ColumnDataCollection &collection, const vector &scan_ids, + const idx_t begin_idx) + : collection(collection), curr_idx(0) { + collection.InitializeScan(state, scan_ids); + collection.InitializeScanChunk(state, chunk); + + Seek(begin_idx); + } + + void Seek(idx_t begin_idx) { + idx_t chunk_idx; + idx_t seg_idx; + idx_t row_idx; + for (; curr_idx > begin_idx; --curr_idx) { + collection.PrevScanIndex(state, chunk_idx, seg_idx, row_idx); + } + for (; curr_idx < begin_idx; ++curr_idx) { + collection.NextScanIndex(state, chunk_idx, seg_idx, row_idx); + } + } + + bool Scan() { + const auto result = collection.Scan(state, chunk); + ++curr_idx; + return result; + } + + idx_t Scanned() const { + return state.next_row_index; + } + + //! Return a struct type for comparing keys + LogicalType PrefixStructType(column_t end, column_t begin = 0); + //! Reference the chunk into a struct vector matching the keys + static void ReferenceStructColumns(DataChunk &chunk, Vector &vec, column_t end, column_t begin = 0); + + ColumnDataCollection &collection; + ColumnDataScanState state; + DataChunk chunk; + idx_t curr_idx; +}; + +template +static void WindowDeltaScanner(ColumnDataCollection &collection, idx_t block_begin, idx_t block_end, + const vector &scan_cols, const idx_t key_count, OP operation) { + + // Stop if there is no work to do + if (!collection.Count()) { + return; + } + + // Start back one to get the overlap + idx_t block_curr = block_begin ? 
block_begin - 1 : 0; + + // Scan the sort columns + WindowCollectionChunkScanner scanner(collection, scan_cols, block_curr); + auto &scanned = scanner.chunk; + + // Shifted buffer for the next values + DataChunk next; + collection.InitializeScanChunk(scanner.state, next); + + // Delay buffer for the previous row + DataChunk delayed; + collection.InitializeScanChunk(scanner.state, delayed); + + // Only compare the key arguments. + const auto key_type = scanner.PrefixStructType(key_count); + Vector compare_curr(key_type); + Vector compare_prev(key_type); + + bool boundary_compare = (block_begin > 0); + idx_t row_idx = 1; + if (!scanner.Scan()) { + return; + } + + // Process chunks offset by 1 + SelectionVector next_sel(1, STANDARD_VECTOR_SIZE); + SelectionVector distinct(STANDARD_VECTOR_SIZE); + SelectionVector matching(STANDARD_VECTOR_SIZE); + + // In order to reuse the verbose `distinct from` logic for both the main vector comparisons + // and single element boundary comparisons, we alternate between single element compares + // and count-1 compares. + while (block_curr < block_end) { + // Compare the current to the previous. + DataChunk *curr = nullptr; + DataChunk *prev = nullptr; + + idx_t count = 0; + if (boundary_compare) { + // Save the last row of the scanned chunk + count = 1; + sel_t last = UnsafeNumericCast<sel_t>(scanned.size() - 1); + SelectionVector sel(&last); + delayed.Reset(); + scanned.Copy(delayed, sel, count); + prev = &delayed; + + // Try to read the next chunk + ++block_curr; + row_idx = scanner.Scanned(); + if (block_curr >= block_end || !scanner.Scan()) { + break; + } + curr = &scanned; + } else { + // Compare the [1..size) values with the [0..size-1) values + count = scanned.size() - 1; + if (!count) { + // 1 row scanned, so just skip the rest of the loop. + boundary_compare = true; + continue; + } + prev = &scanned; + + // Slice the current back one into the previous + next.Slice(scanned, next_sel, count); + curr = &next; + } + + // Reference the comparison prefix as a struct to simplify the compares. + scanner.ReferenceStructColumns(*prev, compare_prev, key_count); + scanner.ReferenceStructColumns(*curr, compare_curr, key_count); + + const auto ndistinct = + VectorOperations::DistinctFrom(compare_curr, compare_prev, nullptr, count, &distinct, &matching); + + // If n is 0, neither SV has been filled in? + auto match_sel = ndistinct ? &matching : FlatVector::IncrementalSelectionVector(); + + operation(row_idx, *prev, *curr, ndistinct, distinct, *match_sel); + + // Transition between comparison ranges.
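// The loop alternates between two comparison modes: within a chunk, rows [1..size) are compared
// against rows [0..size-1); at a chunk boundary, the saved last row of the previous chunk is
// compared against the first row of the freshly scanned chunk. The flip below switches modes.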
+ boundary_compare = !boundary_compare; + row_idx += count; + } +} + } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/window/window_constant_aggregator.hpp b/src/duckdb/src/include/duckdb/function/window/window_constant_aggregator.hpp index c3ca63f36..f15572d9c 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_constant_aggregator.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_constant_aggregator.hpp @@ -16,7 +16,7 @@ class WindowConstantAggregator : public WindowAggregator { public: static bool CanAggregate(const BoundWindowExpression &wexpr); - static BoundWindowExpression &RebindAggregate(ClientContext &context, BoundWindowExpression &wexpr); + static BoundWindowExpression &RebindAggregate(ClientContext &client, BoundWindowExpression &wexpr); WindowConstantAggregator(BoundWindowExpression &wexpr, WindowSharedExpressions &shared, ClientContext &context); ~WindowConstantAggregator() override { @@ -24,15 +24,15 @@ class WindowConstantAggregator : public WindowAggregator { unique_ptr GetGlobalState(ClientContext &context, idx_t group_count, const ValidityMask &partition_mask) const override; - void Sink(WindowAggregatorState &gstate, WindowAggregatorState &lstate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx, optional_ptr filter_sel, + void Sink(ExecutionContext &context, WindowAggregatorState &gstate, WindowAggregatorState &lstate, + DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, optional_ptr filter_sel, idx_t filtered) override; - void Finalize(WindowAggregatorState &gstate, WindowAggregatorState &lstate, CollectionPtr collection, - const FrameStats &stats) override; + void Finalize(ExecutionContext &context, WindowAggregatorState &gstate, WindowAggregatorState &lstate, + CollectionPtr collection, const FrameStats &stats) override; unique_ptr GetLocalState(const WindowAggregatorState &gstate) const override; - void Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, const DataChunk &bounds, - Vector &result, idx_t count, idx_t row_idx) const override; + void Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, WindowAggregatorState &lstate, + const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/window/window_custom_aggregator.hpp b/src/duckdb/src/include/duckdb/function/window/window_custom_aggregator.hpp index ab664ca59..cbee084fd 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_custom_aggregator.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_custom_aggregator.hpp @@ -21,12 +21,12 @@ class WindowCustomAggregator : public WindowAggregator { unique_ptr GetGlobalState(ClientContext &context, idx_t group_count, const ValidityMask &partition_mask) const override; - void Finalize(WindowAggregatorState &gstate, WindowAggregatorState &lstate, CollectionPtr collection, - const FrameStats &stats) override; + void Finalize(ExecutionContext &context, WindowAggregatorState &gstate, WindowAggregatorState &lstate, + CollectionPtr collection, const FrameStats &stats) override; unique_ptr GetLocalState(const WindowAggregatorState &gstate) const override; - void Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, const DataChunk &bounds, - Vector &result, idx_t count, idx_t row_idx) const override; + void Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, WindowAggregatorState &lstate, + 
const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/window/window_distinct_aggregator.hpp b/src/duckdb/src/include/duckdb/function/window/window_distinct_aggregator.hpp index 7fe9040c0..839150e52 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_distinct_aggregator.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_distinct_aggregator.hpp @@ -17,20 +17,21 @@ class WindowDistinctAggregator : public WindowAggregator { static bool CanAggregate(const BoundWindowExpression &wexpr); WindowDistinctAggregator(const BoundWindowExpression &wexpr, WindowSharedExpressions &shared, - ClientContext &context); + ClientContext &client); // Build - unique_ptr GetGlobalState(ClientContext &context, idx_t group_count, + unique_ptr GetGlobalState(ClientContext &client, idx_t group_count, const ValidityMask &partition_mask) const override; - void Sink(WindowAggregatorState &gsink, WindowAggregatorState &lstate, DataChunk &sink_chunk, DataChunk &coll_chunk, - idx_t input_idx, optional_ptr filter_sel, idx_t filtered) override; - void Finalize(WindowAggregatorState &gstate, WindowAggregatorState &lstate, CollectionPtr collection, - const FrameStats &stats) override; + void Sink(ExecutionContext &context, WindowAggregatorState &gsink, WindowAggregatorState &lstate, + DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, optional_ptr filter_sel, + idx_t filtered) override; + void Finalize(ExecutionContext &context, WindowAggregatorState &gstate, WindowAggregatorState &lstate, + CollectionPtr collection, const FrameStats &stats) override; // Evaluate unique_ptr GetLocalState(const WindowAggregatorState &gstate) const override; - void Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, const DataChunk &bounds, - Vector &result, idx_t count, idx_t row_idx) const override; + void Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, WindowAggregatorState &lstate, + const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const override; //! 
Context for sorting ClientContext &context; diff --git a/src/duckdb/src/include/duckdb/function/window/window_executor.hpp b/src/duckdb/src/include/duckdb/function/window/window_executor.hpp index af67bdc5f..1489b4c0a 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_executor.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_executor.hpp @@ -41,9 +41,10 @@ class WindowExecutorGlobalState : public WindowExecutorState { public: using CollectionPtr = optional_ptr; - WindowExecutorGlobalState(const WindowExecutor &executor, const idx_t payload_count, + WindowExecutorGlobalState(ClientContext &client, const WindowExecutor &executor, const idx_t payload_count, const ValidityMask &partition_mask, const ValidityMask &order_mask); + ClientContext &client; const WindowExecutor &executor; const idx_t payload_count; @@ -56,19 +57,20 @@ class WindowExecutorLocalState : public WindowExecutorState { public: using CollectionPtr = optional_ptr; - explicit WindowExecutorLocalState(const WindowExecutorGlobalState &gstate); + WindowExecutorLocalState(ExecutionContext &context, const WindowExecutorGlobalState &gstate); - virtual void Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx); - virtual void Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection); + virtual void Sink(ExecutionContext &context, WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, + DataChunk &coll_chunk, idx_t input_idx); + virtual void Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, CollectionPtr collection); //! The state used for reading the range collection unique_ptr range_cursor; }; -class WindowExecutorBoundsState : public WindowExecutorLocalState { +class WindowExecutorBoundsLocalState : public WindowExecutorLocalState { public: - explicit WindowExecutorBoundsState(const WindowExecutorGlobalState &gstate); - ~WindowExecutorBoundsState() override { + WindowExecutorBoundsLocalState(ExecutionContext &context, const WindowExecutorGlobalState &gstate); + ~WindowExecutorBoundsLocalState() override { } virtual void UpdateBounds(WindowExecutorGlobalState &gstate, idx_t row_idx, DataChunk &eval_chunk, @@ -85,28 +87,29 @@ class WindowExecutor { public: using CollectionPtr = optional_ptr; - WindowExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared); + WindowExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); virtual ~WindowExecutor() { } virtual bool IgnoreNulls() const; - virtual unique_ptr - GetGlobalState(const idx_t payload_count, const ValidityMask &partition_mask, const ValidityMask &order_mask) const; - virtual unique_ptr GetLocalState(const WindowExecutorGlobalState &gstate) const; + virtual unique_ptr GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, + const ValidityMask &order_mask) const; + virtual unique_ptr GetLocalState(ExecutionContext &context, + const WindowExecutorGlobalState &gstate) const; - virtual void Sink(DataChunk &sink_chunk, DataChunk &coll_chunk, const idx_t input_idx, + virtual void Sink(ExecutionContext &context, DataChunk &sink_chunk, DataChunk &coll_chunk, const idx_t input_idx, WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate) const; - virtual void Finalize(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - CollectionPtr collection) const; + virtual void Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, + 
WindowExecutorLocalState &lstate, CollectionPtr collection) const; - void Evaluate(idx_t row_idx, DataChunk &eval_chunk, Vector &result, WindowExecutorLocalState &lstate, - WindowExecutorGlobalState &gstate) const; + void Evaluate(ExecutionContext &context, idx_t row_idx, DataChunk &eval_chunk, Vector &result, + WindowExecutorLocalState &lstate, WindowExecutorGlobalState &gstate) const; // The function const BoundWindowExpression &wexpr; - ClientContext &context; // evaluate frame expressions, if needed column_t boundary_start_idx = DConstants::INVALID_INDEX; @@ -117,8 +120,9 @@ class WindowExecutor { column_t range_idx = DConstants::INVALID_INDEX; protected: - virtual void EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const = 0; + virtual void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, + idx_t row_idx) const = 0; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/window/window_index_tree.hpp b/src/duckdb/src/include/duckdb/function/window/window_index_tree.hpp index e95b52274..e131bba1b 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_index_tree.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_index_tree.hpp @@ -16,7 +16,7 @@ class WindowIndexTree; class WindowIndexTreeLocalState : public WindowMergeSortTreeLocalState { public: - explicit WindowIndexTreeLocalState(WindowIndexTree &index_tree); + WindowIndexTreeLocalState(ExecutionContext &context, WindowIndexTree &index_tree); //! Process sorted leaf data void BuildLeaves() override; @@ -33,7 +33,7 @@ class WindowIndexTree : public WindowMergeSortTree { const idx_t count); ~WindowIndexTree() override = default; - unique_ptr GetLocalState() override; + unique_ptr GetLocalState(ExecutionContext &context) override; //! Find the Nth index in the set of subframes //! Returns {nth index, 0} or {nth offset, overflow} diff --git a/src/duckdb/src/include/duckdb/function/window/window_merge_sort_tree.hpp b/src/duckdb/src/include/duckdb/function/window/window_merge_sort_tree.hpp index 6faecbbf8..faa0f30b7 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_merge_sort_tree.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_merge_sort_tree.hpp @@ -12,85 +12,88 @@ #include "duckdb/planner/bound_result_modifier.hpp" #include "duckdb/function/window/window_aggregator.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/sort/partition_state.hpp" +#include "duckdb/common/sorting/sort.hpp" namespace duckdb { +enum class WindowMergeSortStage : uint8_t { INIT, COMBINE, FINALIZE, SORTED, FINISHED }; + class WindowMergeSortTree; class WindowMergeSortTreeLocalState : public WindowAggregatorState { public: - explicit WindowMergeSortTreeLocalState(WindowMergeSortTree &index_tree); + WindowMergeSortTreeLocalState(ExecutionContext &context, WindowMergeSortTree &index_tree); //! Add a chunk to the local sort - void SinkChunk(DataChunk &chunk, const idx_t row_idx, optional_ptr filter_sel, idx_t filtered); + void Sink(ExecutionContext &context, DataChunk &chunk, const idx_t row_idx, + optional_ptr filter_sel, idx_t filtered); //! Sort the data - void Sort(); + void Finalize(ExecutionContext &context); //! Process sorted leaf data virtual void BuildLeaves() = 0; //! The index tree we are building WindowMergeSortTree &window_tree; //! 
Thread-local sorting data - optional_ptr local_sort; - //! Buffer for the sort keys + optional_ptr local_sink; + //! Buffer for the sort data DataChunk sort_chunk; - //! Buffer for the payload data - DataChunk payload_chunk; //! Build stage - PartitionSortStage build_stage = PartitionSortStage::INIT; + WindowMergeSortStage build_stage = WindowMergeSortStage::INIT; //! Build task number idx_t build_task; private: - void ExecuteSortTask(); + void ExecuteSortTask(ExecutionContext &context); }; class WindowMergeSortTree { public: - using GlobalSortStatePtr = unique_ptr; - using LocalSortStatePtr = unique_ptr; + using GlobalSortStatePtr = unique_ptr; + using LocalSortStatePtr = unique_ptr; WindowMergeSortTree(ClientContext &context, const vector &orders, - const vector &sort_idx, const idx_t count, bool unique = false); + const vector &order_idx, const idx_t count, bool unique = false); virtual ~WindowMergeSortTree() = default; - virtual unique_ptr GetLocalState() = 0; + virtual unique_ptr GetLocalState(ExecutionContext &context) = 0; //! Make a local sort for a thread - optional_ptr AddLocalSort(); + optional_ptr InitializeLocalSort(ExecutionContext &context) const; //! Thread-safe post-sort cleanup - virtual void CleanupSort(); + virtual void Finished(); //! Sort state machine bool TryPrepareSortStage(WindowMergeSortTreeLocalState &lstate); //! Build the MST in parallel from the sorted data void Build(); - //! The query context - ClientContext &context; - //! Thread memory limit - const idx_t memory_per_thread; //! The column indices for sorting - const vector sort_idx; + const vector order_idx; + //! The sorted data schema + vector scan_types; + vector scan_cols; + //! The sort key columns + vector key_cols; + //! The sort specification + unique_ptr sort; //! The sorted data - GlobalSortStatePtr global_sort; + GlobalSortStatePtr global_sink; + //! The resulting sorted data + unique_ptr sorted; //! Finalize guard - mutex lock; + mutable mutex lock; //! Local sort set - vector local_sorts; + mutable vector local_sinks; //! Finalize stage - atomic build_stage; + atomic build_stage; //! Tasks launched idx_t total_tasks = 0; //! Tasks launched idx_t tasks_assigned = 0; //! Tasks landed atomic tasks_completed; - //! The block starts (the scanner doesn't know this) plus the total count - vector block_starts; // Merge sort trees for various sizes // Smaller is probably not worth the effort. diff --git a/src/duckdb/src/include/duckdb/function/window/window_naive_aggregator.hpp b/src/duckdb/src/include/duckdb/function/window/window_naive_aggregator.hpp index 1fdf497ee..0af8d691c 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_naive_aggregator.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_naive_aggregator.hpp @@ -21,8 +21,8 @@ class WindowNaiveAggregator : public WindowAggregator { ~WindowNaiveAggregator() override; unique_ptr GetLocalState(const WindowAggregatorState &gstate) const override; - void Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, const DataChunk &bounds, - Vector &result, idx_t count, idx_t row_idx) const override; + void Evaluate(ExecutionContext &context, const WindowAggregatorState &gsink, WindowAggregatorState &lstate, + const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const override; //! 
The parent executor const WindowAggregateExecutor &executor; diff --git a/src/duckdb/src/include/duckdb/function/window/window_rank_function.hpp b/src/duckdb/src/include/duckdb/function/window/window_rank_function.hpp index 9afb7e970..7127fd04d 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_rank_function.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_rank_function.hpp @@ -14,11 +14,13 @@ namespace duckdb { class WindowPeerExecutor : public WindowExecutor { public: - WindowPeerExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared); + WindowPeerExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); - unique_ptr GetGlobalState(const idx_t payload_count, const ValidityMask &partition_mask, + unique_ptr GetGlobalState(ClientContext &context, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask) const override; - unique_ptr GetLocalState(const WindowExecutorGlobalState &gstate) const override; + unique_ptr GetLocalState(ExecutionContext &context, + const WindowExecutorGlobalState &gstate) const override; //! The column indices of any ORDER BY argument expressions vector arg_order_idx; @@ -26,38 +28,42 @@ class WindowPeerExecutor : public WindowExecutor { class WindowRankExecutor : public WindowPeerExecutor { public: - WindowRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared); + WindowRankExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); protected: - void EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, DataChunk &eval_chunk, - Vector &result, idx_t count, idx_t row_idx) const override; + void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, + idx_t row_idx) const override; }; class WindowDenseRankExecutor : public WindowPeerExecutor { public: - WindowDenseRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared); + WindowDenseRankExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); protected: - void EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, DataChunk &eval_chunk, - Vector &result, idx_t count, idx_t row_idx) const override; + void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, + idx_t row_idx) const override; }; class WindowPercentRankExecutor : public WindowPeerExecutor { public: - WindowPercentRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared); + WindowPercentRankExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); protected: - void EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, DataChunk &eval_chunk, - Vector &result, idx_t count, idx_t row_idx) const override; + void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, + idx_t row_idx) const override; }; class WindowCumeDistExecutor : public WindowPeerExecutor { public: - WindowCumeDistExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared); + WindowCumeDistExecutor(BoundWindowExpression &wexpr, 
WindowSharedExpressions &shared); protected: - void EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, DataChunk &eval_chunk, - Vector &result, idx_t count, idx_t row_idx) const override; + void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, + idx_t row_idx) const override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/window/window_rownumber_function.hpp b/src/duckdb/src/include/duckdb/function/window/window_rownumber_function.hpp index be9bf4713..c387143ca 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_rownumber_function.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_rownumber_function.hpp @@ -14,11 +14,13 @@ namespace duckdb { class WindowRowNumberExecutor : public WindowExecutor { public: - WindowRowNumberExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared); + WindowRowNumberExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); - unique_ptr GetGlobalState(const idx_t payload_count, const ValidityMask &partition_mask, + unique_ptr GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask) const override; - unique_ptr GetLocalState(const WindowExecutorGlobalState &gstate) const override; + unique_ptr GetLocalState(ExecutionContext &context, + const WindowExecutorGlobalState &gstate) const override; //! The evaluation index of the NTILE column column_t ntile_idx = DConstants::INVALID_INDEX; @@ -26,18 +28,20 @@ class WindowRowNumberExecutor : public WindowExecutor { vector arg_order_idx; protected: - void EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, DataChunk &eval_chunk, - Vector &result, idx_t count, idx_t row_idx) const override; + void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, + idx_t row_idx) const override; }; // NTILE is just scaled ROW_NUMBER class WindowNtileExecutor : public WindowRowNumberExecutor { public: - WindowNtileExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared); + WindowNtileExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); protected: - void EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, DataChunk &eval_chunk, - Vector &result, idx_t count, idx_t row_idx) const override; + void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, + idx_t row_idx) const override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/window/window_segment_tree.hpp b/src/duckdb/src/include/duckdb/function/window/window_segment_tree.hpp index 0b574b959..c571533b6 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_segment_tree.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_segment_tree.hpp @@ -21,11 +21,11 @@ class WindowSegmentTree : public WindowAggregator { unique_ptr GetGlobalState(ClientContext &context, idx_t group_count, const ValidityMask &partition_mask) const override; unique_ptr GetLocalState(const WindowAggregatorState &gstate) const override; - void Finalize(WindowAggregatorState &gstate, 
WindowAggregatorState &lstate, CollectionPtr collection, - const FrameStats &stats) override; + void Finalize(ExecutionContext &context, WindowAggregatorState &gstate, WindowAggregatorState &lstate, + CollectionPtr collection, const FrameStats &stats) override; - void Evaluate(const WindowAggregatorState &gstate, WindowAggregatorState &lstate, const DataChunk &bounds, - Vector &result, idx_t count, idx_t row_idx) const override; + void Evaluate(ExecutionContext &context, const WindowAggregatorState &gstate, WindowAggregatorState &lstate, + const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/function/window/window_token_tree.hpp b/src/duckdb/src/include/duckdb/function/window/window_token_tree.hpp index 4df4dab8f..e6cd41869 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_token_tree.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_token_tree.hpp @@ -15,19 +15,19 @@ namespace duckdb { // Builds a merge sort tree that uses integer tokens for the comparison values instead of the sort keys. class WindowTokenTree : public WindowMergeSortTree { public: - WindowTokenTree(ClientContext &context, const vector &orders, const vector &sort_idx, + WindowTokenTree(ClientContext &context, const vector &orders, const vector &order_idx, const idx_t count, bool unique = false) - : WindowMergeSortTree(context, orders, sort_idx, count, unique) { + : WindowMergeSortTree(context, orders, order_idx, count, unique) { } - WindowTokenTree(ClientContext &context, const BoundOrderModifier &order_bys, const vector &sort_idx, + WindowTokenTree(ClientContext &context, const BoundOrderModifier &order_bys, const vector &order_idx, const idx_t count, bool unique = false) - : WindowTokenTree(context, order_bys.orders, sort_idx, count, unique) { + : WindowTokenTree(context, order_bys.orders, order_idx, count, unique) { } - unique_ptr GetLocalState() override; + unique_ptr GetLocalState(ExecutionContext &context) override; //! Thread-safe post-sort cleanup - void CleanupSort() override; + void Finished() override; //! 
Find the rank of the row within the range idx_t Rank(const idx_t lower, const idx_t upper, const idx_t row_idx) const; diff --git a/src/duckdb/src/include/duckdb/function/window/window_value_function.hpp b/src/duckdb/src/include/duckdb/function/window/window_value_function.hpp index 580ac226e..6076ebdd0 100644 --- a/src/duckdb/src/include/duckdb/function/window/window_value_function.hpp +++ b/src/duckdb/src/include/duckdb/function/window/window_value_function.hpp @@ -15,14 +15,16 @@ namespace duckdb { // Base class for non-aggregate functions that have a payload class WindowValueExecutor : public WindowExecutor { public: - WindowValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared); + WindowValueExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); - void Finalize(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, + void Finalize(ExecutionContext &context, WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, CollectionPtr collection) const override; - unique_ptr GetGlobalState(const idx_t payload_count, const ValidityMask &partition_mask, + unique_ptr GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask) const override; - unique_ptr GetLocalState(const WindowExecutorGlobalState &gstate) const override; + unique_ptr GetLocalState(ExecutionContext &context, + const WindowExecutorGlobalState &gstate) const override; //! The column index of the value column column_t child_idx = DConstants::INVALID_INDEX; @@ -38,63 +40,72 @@ class WindowValueExecutor : public WindowExecutor { class WindowLeadLagExecutor : public WindowValueExecutor { public: - WindowLeadLagExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared); + WindowLeadLagExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); - unique_ptr GetGlobalState(const idx_t payload_count, const ValidityMask &partition_mask, + unique_ptr GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask) const override; - unique_ptr GetLocalState(const WindowExecutorGlobalState &gstate) const override; + unique_ptr GetLocalState(ExecutionContext &context, + const WindowExecutorGlobalState &gstate) const override; protected: - void EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, DataChunk &eval_chunk, - Vector &result, idx_t count, idx_t row_idx) const override; + void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, + idx_t row_idx) const override; }; class WindowFirstValueExecutor : public WindowValueExecutor { public: - WindowFirstValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared); + WindowFirstValueExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); protected: - void EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, DataChunk &eval_chunk, - Vector &result, idx_t count, idx_t row_idx) const override; + void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, + idx_t row_idx) const override; }; class WindowLastValueExecutor : public WindowValueExecutor { public: - 
WindowLastValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared); + WindowLastValueExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); protected: - void EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, DataChunk &eval_chunk, - Vector &result, idx_t count, idx_t row_idx) const override; + void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, + idx_t row_idx) const override; }; class WindowNthValueExecutor : public WindowValueExecutor { public: - WindowNthValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared); + WindowNthValueExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); protected: - void EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, DataChunk &eval_chunk, - Vector &result, idx_t count, idx_t row_idx) const override; + void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, + idx_t row_idx) const override; }; class WindowFillExecutor : public WindowValueExecutor { public: - WindowFillExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared); + WindowFillExecutor(BoundWindowExpression &wexpr, WindowSharedExpressions &shared); //! Never ignore nulls (that's the point!) bool IgnoreNulls() const override { return false; } - unique_ptr GetGlobalState(const idx_t payload_count, const ValidityMask &partition_mask, + unique_ptr GetGlobalState(ClientContext &client, const idx_t payload_count, + const ValidityMask &partition_mask, const ValidityMask &order_mask) const override; - unique_ptr GetLocalState(const WindowExecutorGlobalState &gstate) const override; + unique_ptr GetLocalState(ExecutionContext &context, + const WindowExecutorGlobalState &gstate) const override; //! 
Secondary order collection index idx_t order_idx = DConstants::INVALID_INDEX; protected: - void EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, DataChunk &eval_chunk, - Vector &result, idx_t count, idx_t row_idx) const override; + void EvaluateInternal(ExecutionContext &context, WindowExecutorGlobalState &gstate, + WindowExecutorLocalState &lstate, DataChunk &eval_chunk, Vector &result, idx_t count, + idx_t row_idx) const override; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/logging/log_manager.hpp b/src/duckdb/src/include/duckdb/logging/log_manager.hpp index 0c6e15b75..252721d5f 100644 --- a/src/duckdb/src/include/duckdb/logging/log_manager.hpp +++ b/src/duckdb/src/include/duckdb/logging/log_manager.hpp @@ -1,7 +1,7 @@ //===----------------------------------------------------------------------===// // DuckDB // -// duckdb/logging/log_storage.hpp +// duckdb/logging/log_manager.hpp // // //===----------------------------------------------------------------------===// diff --git a/src/duckdb/src/include/duckdb/main/attached_database.hpp b/src/duckdb/src/include/duckdb/main/attached_database.hpp index b92e0dbfd..d20726365 100644 --- a/src/duckdb/src/include/duckdb/main/attached_database.hpp +++ b/src/duckdb/src/include/duckdb/main/attached_database.hpp @@ -13,7 +13,6 @@ #include "duckdb/common/mutex.hpp" #include "duckdb/main/config.hpp" #include "duckdb/catalog/catalog_entry.hpp" -#include "duckdb/storage/storage_options.hpp" namespace duckdb { class Catalog; @@ -38,7 +37,7 @@ struct AttachOptions { //! Constructor for databases we attach outside of the ATTACH DATABASE statement. explicit AttachOptions(const DBConfigOptions &options); //! Constructor for databases we attach when using ATTACH DATABASE. - AttachOptions(const unique_ptr &info, const AccessMode default_access_mode); + AttachOptions(const unordered_map &options, const AccessMode default_access_mode); //! Defaults to the access mode configured in the DBConfig, unless specified otherwise. AccessMode access_mode; @@ -63,7 +62,7 @@ class AttachedDatabase : public CatalogEntry { ~AttachedDatabase() override; //! Initializes the catalog and storage of the attached database. - void Initialize(optional_ptr context = nullptr, StorageOptions options = StorageOptions()); + void Initialize(optional_ptr context = nullptr); void FinalizeLoad(optional_ptr context); void Close(); diff --git a/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp b/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp index 941a63561..1784ce990 100644 --- a/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp +++ b/src/duckdb/src/include/duckdb/main/capi/capi_internal.hpp @@ -42,6 +42,11 @@ struct CClientContextWrapper { ClientContext &context; }; +struct CClientArrowOptionsWrapper { + explicit CClientArrowOptionsWrapper(ClientProperties &properties) : properties(properties) {}; + ClientProperties properties; +}; + struct PreparedStatementWrapper { //! 
Map of name -> values case_insensitive_map_t values; @@ -77,6 +82,10 @@ struct ErrorDataWrapper { ErrorData error_data; }; +struct ExpressionWrapper { + unique_ptr expr; +}; + enum class CAPIResultSetType : uint8_t { CAPI_RESULT_TYPE_NONE = 0, CAPI_RESULT_TYPE_MATERIALIZED, diff --git a/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp b/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp index 4ba922f42..9b0183850 100644 --- a/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp +++ b/src/duckdb/src/include/duckdb/main/capi/extension_api.hpp @@ -136,7 +136,7 @@ typedef struct { duckdb_value (*duckdb_create_timestamp)(duckdb_timestamp input); duckdb_value (*duckdb_create_interval)(duckdb_interval input); duckdb_value (*duckdb_create_blob)(const uint8_t *data, idx_t length); - duckdb_value (*duckdb_create_varint)(duckdb_varint input); + duckdb_value (*duckdb_create_bignum)(duckdb_bignum input); duckdb_value (*duckdb_create_decimal)(duckdb_decimal input); duckdb_value (*duckdb_create_bit)(duckdb_bit input); duckdb_value (*duckdb_create_uuid)(duckdb_uhugeint input); @@ -160,7 +160,7 @@ typedef struct { duckdb_interval (*duckdb_get_interval)(duckdb_value val); duckdb_logical_type (*duckdb_get_value_type)(duckdb_value val); duckdb_blob (*duckdb_get_blob)(duckdb_value val); - duckdb_varint (*duckdb_get_varint)(duckdb_value val); + duckdb_bignum (*duckdb_get_bignum)(duckdb_value val); duckdb_decimal (*duckdb_get_decimal)(duckdb_value val); duckdb_bit (*duckdb_get_bit)(duckdb_value val); duckdb_uhugeint (*duckdb_get_uuid)(duckdb_value val); @@ -472,6 +472,18 @@ typedef struct { duckdb_state (*duckdb_append_default_to_chunk)(duckdb_appender appender, duckdb_data_chunk chunk, idx_t col, idx_t row); duckdb_error_data (*duckdb_appender_error_data)(duckdb_appender appender); + // New arrow interface functions + + duckdb_error_data (*duckdb_to_arrow_schema)(duckdb_arrow_options arrow_options, duckdb_logical_type *types, + const char **names, idx_t column_count, struct ArrowSchema *out_schema); + duckdb_error_data (*duckdb_data_chunk_to_arrow)(duckdb_arrow_options arrow_options, duckdb_data_chunk chunk, + struct ArrowArray *out_arrow_array); + duckdb_error_data (*duckdb_schema_from_arrow)(duckdb_connection connection, struct ArrowSchema *schema, + duckdb_arrow_converted_schema *out_types); + duckdb_error_data (*duckdb_data_chunk_from_arrow)(duckdb_connection connection, struct ArrowArray *arrow_array, + duckdb_arrow_converted_schema converted_schema, + duckdb_data_chunk *out_chunk); + void (*duckdb_destroy_arrow_converted_schema)(duckdb_arrow_converted_schema *arrow_converted_schema); // New functions for duckdb error data duckdb_error_data (*duckdb_create_error_data)(duckdb_error_type type, const char *message); @@ -479,12 +491,24 @@ typedef struct { duckdb_error_type (*duckdb_error_data_error_type)(duckdb_error_data error_data); const char *(*duckdb_error_data_message)(duckdb_error_data error_data); bool (*duckdb_error_data_has_error)(duckdb_error_data error_data); + // API to create and manipulate expressions + + void (*duckdb_destroy_expression)(duckdb_expression *expr); + duckdb_logical_type (*duckdb_expression_return_type)(duckdb_expression expr); + bool (*duckdb_expression_is_foldable)(duckdb_expression expr); + duckdb_error_data (*duckdb_expression_fold)(duckdb_client_context context, duckdb_expression expr, + duckdb_value *out_value); // New functions around the client context idx_t (*duckdb_client_context_get_connection_id)(duckdb_client_context context); void 
(*duckdb_destroy_client_context)(duckdb_client_context *context); void (*duckdb_connection_get_client_context)(duckdb_connection connection, duckdb_client_context *out_context); duckdb_value (*duckdb_get_table_names)(duckdb_connection connection, const char *query, bool qualified); + void (*duckdb_connection_get_arrow_options)(duckdb_connection connection, duckdb_arrow_options *out_arrow_options); + void (*duckdb_destroy_arrow_options)(duckdb_arrow_options *arrow_options); + // New query execution functions + + duckdb_arrow_options (*duckdb_result_get_arrow_options)(duckdb_result *result); // New functions around scalar function binding void (*duckdb_scalar_function_set_bind)(duckdb_scalar_function scalar_function, duckdb_scalar_function_bind_t bind); @@ -494,14 +518,21 @@ typedef struct { duckdb_delete_callback_t destroy); void *(*duckdb_scalar_function_get_bind_data)(duckdb_function_info info); void *(*duckdb_scalar_function_bind_get_extra_info)(duckdb_bind_info info); + idx_t (*duckdb_scalar_function_bind_get_argument_count)(duckdb_bind_info info); + duckdb_expression (*duckdb_scalar_function_bind_get_argument)(duckdb_bind_info info, idx_t index); // New string functions that are added char *(*duckdb_value_to_string)(duckdb_value value); + // New functions around table function binding + + void (*duckdb_table_function_get_client_context)(duckdb_bind_info info, duckdb_client_context *out_context); // New value functions that are added duckdb_value (*duckdb_create_map_value)(duckdb_logical_type map_type, duckdb_value *keys, duckdb_value *values, idx_t entry_count); duckdb_value (*duckdb_create_union_value)(duckdb_logical_type union_type, idx_t tag_index, duckdb_value value); + duckdb_value (*duckdb_create_time_ns)(duckdb_time_ns input); + duckdb_time_ns (*duckdb_get_time_ns)(duckdb_value val); // API to create and manipulate vector types duckdb_vector (*duckdb_create_vector)(duckdb_logical_type type, idx_t capacity); @@ -635,7 +666,7 @@ inline duckdb_ext_api_v1 CreateAPIv1() { result.duckdb_create_timestamp = duckdb_create_timestamp; result.duckdb_create_interval = duckdb_create_interval; result.duckdb_create_blob = duckdb_create_blob; - result.duckdb_create_varint = duckdb_create_varint; + result.duckdb_create_bignum = duckdb_create_bignum; result.duckdb_create_decimal = duckdb_create_decimal; result.duckdb_create_bit = duckdb_create_bit; result.duckdb_create_uuid = duckdb_create_uuid; @@ -659,7 +690,7 @@ inline duckdb_ext_api_v1 CreateAPIv1() { result.duckdb_get_interval = duckdb_get_interval; result.duckdb_get_value_type = duckdb_get_value_type; result.duckdb_get_blob = duckdb_get_blob; - result.duckdb_get_varint = duckdb_get_varint; + result.duckdb_get_bignum = duckdb_get_bignum; result.duckdb_get_decimal = duckdb_get_decimal; result.duckdb_get_bit = duckdb_get_bit; result.duckdb_get_uuid = duckdb_get_uuid; @@ -930,24 +961,41 @@ inline duckdb_ext_api_v1 CreateAPIv1() { result.duckdb_destroy_instance_cache = duckdb_destroy_instance_cache; result.duckdb_append_default_to_chunk = duckdb_append_default_to_chunk; result.duckdb_appender_error_data = duckdb_appender_error_data; + result.duckdb_to_arrow_schema = duckdb_to_arrow_schema; + result.duckdb_data_chunk_to_arrow = duckdb_data_chunk_to_arrow; + result.duckdb_schema_from_arrow = duckdb_schema_from_arrow; + result.duckdb_data_chunk_from_arrow = duckdb_data_chunk_from_arrow; + result.duckdb_destroy_arrow_converted_schema = duckdb_destroy_arrow_converted_schema; result.duckdb_create_error_data = duckdb_create_error_data; 
result.duckdb_destroy_error_data = duckdb_destroy_error_data; result.duckdb_error_data_error_type = duckdb_error_data_error_type; result.duckdb_error_data_message = duckdb_error_data_message; result.duckdb_error_data_has_error = duckdb_error_data_has_error; + result.duckdb_destroy_expression = duckdb_destroy_expression; + result.duckdb_expression_return_type = duckdb_expression_return_type; + result.duckdb_expression_is_foldable = duckdb_expression_is_foldable; + result.duckdb_expression_fold = duckdb_expression_fold; result.duckdb_client_context_get_connection_id = duckdb_client_context_get_connection_id; result.duckdb_destroy_client_context = duckdb_destroy_client_context; result.duckdb_connection_get_client_context = duckdb_connection_get_client_context; result.duckdb_get_table_names = duckdb_get_table_names; + result.duckdb_connection_get_arrow_options = duckdb_connection_get_arrow_options; + result.duckdb_destroy_arrow_options = duckdb_destroy_arrow_options; + result.duckdb_result_get_arrow_options = duckdb_result_get_arrow_options; result.duckdb_scalar_function_set_bind = duckdb_scalar_function_set_bind; result.duckdb_scalar_function_bind_set_error = duckdb_scalar_function_bind_set_error; result.duckdb_scalar_function_get_client_context = duckdb_scalar_function_get_client_context; result.duckdb_scalar_function_set_bind_data = duckdb_scalar_function_set_bind_data; result.duckdb_scalar_function_get_bind_data = duckdb_scalar_function_get_bind_data; result.duckdb_scalar_function_bind_get_extra_info = duckdb_scalar_function_bind_get_extra_info; + result.duckdb_scalar_function_bind_get_argument_count = duckdb_scalar_function_bind_get_argument_count; + result.duckdb_scalar_function_bind_get_argument = duckdb_scalar_function_bind_get_argument; result.duckdb_value_to_string = duckdb_value_to_string; + result.duckdb_table_function_get_client_context = duckdb_table_function_get_client_context; result.duckdb_create_map_value = duckdb_create_map_value; result.duckdb_create_union_value = duckdb_create_union_value; + result.duckdb_create_time_ns = duckdb_create_time_ns; + result.duckdb_get_time_ns = duckdb_get_time_ns; result.duckdb_create_vector = duckdb_create_vector; result.duckdb_destroy_vector = duckdb_destroy_vector; result.duckdb_slice_vector = duckdb_slice_vector; diff --git a/src/duckdb/src/include/duckdb/main/client_config.hpp b/src/duckdb/src/include/duckdb/main/client_config.hpp index bad9e04e0..04c622894 100644 --- a/src/duckdb/src/include/duckdb/main/client_config.hpp +++ b/src/duckdb/src/include/duckdb/main/client_config.hpp @@ -55,9 +55,6 @@ struct ClientConfig { //! The wait time before showing the progress bar int wait_time = 2000; - //! Preserve identifier case while parsing. - //! If false, all unquoted identifiers are lower-cased (e.g. "MyTable" -> "mytable"). - bool preserve_identifier_case = true; //! The maximum expression depth limit in the parser idx_t max_expression_depth = 1000; @@ -77,14 +74,8 @@ struct ClientConfig { bool verify_parallelism = false; //! Force out-of-core computation for operators that support it, used for testing bool force_external = false; - //! Force disable cross product generation when hyper graph isn't connected, used for testing - bool force_no_cross_product = false; - //! Force use of IEJoin to implement AsOfJoin, used for testing - bool force_asof_iejoin = false; //! Force use of fetch row instead of scan, used for testing bool force_fetch_row = false; - //! 
Use range joins for inequalities, even if there are equality predicates - bool prefer_range_joins = false; //! If this context should also try to use the available replacement scans //! True by default bool use_replacement_scans = true; @@ -93,16 +84,6 @@ struct ClientConfig { idx_t perfect_ht_threshold = 12; //! The maximum number of rows to accumulate before sorting ordered aggregates. idx_t ordered_aggregate_threshold = (idx_t(1) << 18); - //! The number of rows to accumulate before flushing during a partitioned write - idx_t partitioned_write_flush_threshold = idx_t(1) << idx_t(19); - //! The amount of rows we can keep open before we close and flush them during a partitioned write - idx_t partitioned_write_max_open_files = idx_t(100); - //! The maximum number of rows on either table to choose a nested loop join - idx_t nested_loop_join_threshold = 5; - //! The maximum number of rows on either table to choose a merge join over an IE join - idx_t merge_join_threshold = 1000; - //! The maximum number of rows to use the nested loop join implementation - idx_t asof_loop_join_threshold = 64; //! The maximum amount of memory to keep buffered in a streaming query result. Default: 1mb. idx_t streaming_buffer_size = 1000000; @@ -113,28 +94,8 @@ struct ClientConfig { //! The explain output type used when none is specified (default: PHYSICAL_ONLY) ExplainOutputType explain_output_type = ExplainOutputType::PHYSICAL_ONLY; - //! The maximum amount of pivot columns - idx_t pivot_limit = 100000; - - //! The threshold at which we switch from using filtered aggregates to LIST with a dedicated pivot operator - idx_t pivot_filter_threshold = 20; - - //! The maximum amount of OR filters we generate dynamically from a hash join - idx_t dynamic_or_filter_threshold = 50; - - //! The maximum amount of rows in the LIMIT/SAMPLE for which we trigger late materialization - idx_t late_materialization_max_rows = 50; - - //! Whether the "/" division operator defaults to integer division or floating point division - bool integer_division = false; - //! When a scalar subquery returns multiple rows - return a random row instead of returning an error - bool scalar_subquery_error_on_multiple_rows = true; //! Use IEE754-compliant floating point operations (returning NAN instead of errors/NULL) bool ieee_floating_point_ops = true; - //! Allow ordering by non-integer literals - ordering by such literals has no effect - bool order_by_non_integer_literal = false; - //! Disable casting from timestamp => timestamptz (naïve timestamps) - bool disable_timestamptz_casts = false; //! If DEFAULT or ENABLE_SINGLE_ARROW, it is possible to use the deprecated single arrow operator (->) for lambda //! functions. Otherwise, DISABLE_SINGLE_ARROW. 
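Each ClientConfig member removed above reappears later in this patch as an autogenerated setting struct with DefaultValue/DefaultScope constants, so call sites read these values through the generic settings framework rather than through struct fields. A minimal, hedged sketch of the migrated read path, assuming PivotLimitSetting and IntegerDivisionSetting keep idx_t and bool as their RETURN_TYPE (the wrapper functions here are invented for illustration):

```cpp
#include "duckdb/main/config.hpp"
#include "duckdb/main/settings.hpp"

namespace duckdb {

// Before this patch: ClientConfig::GetConfig(context).pivot_limit
static idx_t GetPivotLimit(ClientContext &context) {
	// Resolved by setting name and default through the templated accessor.
	return DBConfig::GetSetting<PivotLimitSetting>(context);
}

// Before this patch: ClientConfig::GetConfig(context).integer_division
static bool UsesIntegerDivision(ClientContext &context) {
	return DBConfig::GetSetting<IntegerDivisionSetting>(context);
}

} // namespace duckdb
```

The enum-returning overload (see the std::enable_if pair added to DBConfig further down) covers settings such as default_order, whose RETURN_TYPE becomes OrderType in this patch.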
LambdaSyntax lambda_syntax = LambdaSyntax::DEFAULT; @@ -165,26 +126,11 @@ struct ClientConfig { static ClientConfig &GetConfig(ClientContext &context); static const ClientConfig &GetConfig(const ClientContext &context); - bool AnyVerification() { - return query_verification_enabled || verify_external || verify_serializer || verify_fetch_row; - } + bool AnyVerification() const; - void SetUserVariable(const string &name, Value value) { - user_variables[name] = std::move(value); - } - - bool GetUserVariable(const string &name, Value &result) { - auto entry = user_variables.find(name); - if (entry == user_variables.end()) { - return false; - } - result = entry->second; - return true; - } - - void ResetUserVariable(const string &name) { - user_variables.erase(name); - } + void SetUserVariable(const string &name, Value value); + bool GetUserVariable(const string &name, Value &result); + void ResetUserVariable(const string &name); template static typename OP::RETURN_TYPE GetSetting(const ClientContext &context) { diff --git a/src/duckdb/src/include/duckdb/main/client_context.hpp b/src/duckdb/src/include/duckdb/main/client_context.hpp index 3e925248f..15beb4699 100644 --- a/src/duckdb/src/include/duckdb/main/client_context.hpp +++ b/src/duckdb/src/include/duckdb/main/client_context.hpp @@ -24,7 +24,6 @@ #include "duckdb/main/external_dependencies.hpp" #include "duckdb/main/pending_query_result.hpp" #include "duckdb/main/prepared_statement.hpp" -#include "duckdb/main/settings.hpp" #include "duckdb/main/stream_query_result.hpp" #include "duckdb/main/table_description.hpp" #include "duckdb/planner/expression/bound_parameter_data.hpp" @@ -191,7 +190,7 @@ class ClientContext : public enable_shared_from_this { bool requires_valid_transaction = true); //! Equivalent to CURRENT_SETTING(key) SQL function. - DUCKDB_API SettingLookupResult TryGetCurrentSetting(const std::string &key, Value &result) const; + DUCKDB_API SettingLookupResult TryGetCurrentSetting(const string &key, Value &result) const; //! Returns the parser options for this client context DUCKDB_API ParserOptions GetParserOptions() const; diff --git a/src/duckdb/src/include/duckdb/main/client_properties.hpp b/src/duckdb/src/include/duckdb/main/client_properties.hpp index 348e4c90b..f7a7f7a2c 100644 --- a/src/duckdb/src/include/duckdb/main/client_properties.hpp +++ b/src/duckdb/src/include/duckdb/main/client_properties.hpp @@ -10,27 +10,10 @@ #include "duckdb/common/string.hpp" #include "duckdb/common/types.hpp" +#include "duckdb/common/enums/arrow_format_version.hpp" namespace duckdb { -enum class ArrowOffsetSize : uint8_t { REGULAR, LARGE }; - -enum ArrowFormatVersion : uint8_t { - //! Base Version - V1_0 = 10, - //! Added 256-bit Decimal type. - V1_1 = 11, - //! Added MonthDayNano interval type. - V1_2 = 12, - //! Added Run-End Encoded Layout. - V1_3 = 13, - //! Added Variable-size Binary View Layout and the associated BinaryView and Utf8View types. - //! Added ListView Layout and the associated ListView and LargeListView types. Added Variadic buffers. - V1_4 = 14, - //! Expanded Decimal type bit widths to allow 32-bit and 64-bit types. - V1_5 = 15 -}; - //! 
A set of properties from the client context that can be used to interpret the query result struct ClientProperties { ClientProperties(string time_zone_p, const ArrowOffsetSize arrow_offset_size_p, const bool arrow_use_list_view_p, @@ -47,7 +30,7 @@ struct ClientProperties { bool arrow_use_list_view = false; bool produce_arrow_string_view = false; bool arrow_lossless_conversion = false; - ArrowFormatVersion arrow_output_version = V1_0; + ArrowFormatVersion arrow_output_version = ArrowFormatVersion::V1_0; optional_ptr client_context; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/config.hpp b/src/duckdb/src/include/duckdb/main/config.hpp index 8d0044332..2a3219ec3 100644 --- a/src/duckdb/src/include/duckdb/main/config.hpp +++ b/src/duckdb/src/include/duckdb/main/config.hpp @@ -21,7 +21,6 @@ #include "duckdb/common/enums/optimizer_type.hpp" #include "duckdb/common/enums/order_type.hpp" #include "duckdb/common/enums/set_scope.hpp" -#include "duckdb/common/enums/window_aggregation_mode.hpp" #include "duckdb/common/file_system.hpp" #include "duckdb/common/set.hpp" #include "duckdb/common/types/value.hpp" @@ -30,16 +29,14 @@ #include "duckdb/execution/index/index_type_set.hpp" #include "duckdb/function/cast/default_casts.hpp" #include "duckdb/function/replacement_scan.hpp" -#include "duckdb/main/client_properties.hpp" #include "duckdb/optimizer/optimizer_extension.hpp" #include "duckdb/parser/parsed_data/create_info.hpp" #include "duckdb/parser/parser_extension.hpp" #include "duckdb/planner/operator_extension.hpp" #include "duckdb/storage/compression/bitpacking.hpp" #include "duckdb/function/encoding_function.hpp" +#include "duckdb/main/setting_info.hpp" #include "duckdb/logging/log_manager.hpp" -#include "duckdb/common/enums/debug_vector_verification.hpp" -#include "duckdb/logging/logging.hpp" namespace duckdb { @@ -62,46 +59,7 @@ class HTTPUtil; struct CompressionFunctionSet; struct DatabaseCacheEntry; struct DBConfig; - -enum class CheckpointAbort : uint8_t { - NO_ABORT = 0, - DEBUG_ABORT_BEFORE_TRUNCATE = 1, - DEBUG_ABORT_BEFORE_HEADER = 2, - DEBUG_ABORT_AFTER_FREE_LIST_WRITE = 3 -}; - -typedef void (*set_global_function_t)(DatabaseInstance *db, DBConfig &config, const Value ¶meter); -typedef void (*set_local_function_t)(ClientContext &context, const Value ¶meter); -typedef void (*reset_global_function_t)(DatabaseInstance *db, DBConfig &config); -typedef void (*reset_local_function_t)(ClientContext &context); -typedef Value (*get_setting_function_t)(const ClientContext &context); - -struct ConfigurationOption { - const char *name; - const char *description; - const char *parameter_type; - set_global_function_t set_global; - set_local_function_t set_local; - reset_global_function_t reset_global; - reset_local_function_t reset_local; - get_setting_function_t get_setting; -}; - -typedef void (*set_option_callback_t)(ClientContext &context, SetScope scope, Value ¶meter); - -struct ExtensionOption { - // NOLINTNEXTLINE: work around bug in clang-tidy - ExtensionOption(string description_p, LogicalType type_p, set_option_callback_t set_function_p, - Value default_value_p) - : description(std::move(description_p)), type(std::move(type_p)), set_function(set_function_p), - default_value(std::move(default_value_p)) { - } - - string description; - LogicalType type; - set_option_callback_t set_function; - Value default_value; -}; +struct SettingLookupResult; class SerializationCompatibility { public: @@ -175,14 +133,8 @@ struct DBConfigOptions { bool 
buffer_manager_track_eviction_timestamps = false; //! Whether or not to allow printing unredacted secrets bool allow_unredacted_secrets = false; - //! The collation type of the database - string collation = string(); - //! The order type used when none is specified (default: ASC) - OrderType default_order_type = OrderType::ASCENDING; //! Disables invalidating the database instance when encountering a fatal error. bool disable_database_invalidation = false; - //! NULL ordering used when none is specified (default: NULLS LAST) - DefaultOrderByNullType default_null_order = DefaultOrderByNullType::NULLS_LAST; //! enable COPY and related commands bool enable_external_access = true; //! Whether or not the global http metadata cache is used @@ -199,8 +151,6 @@ struct DBConfigOptions { bool checkpoint_on_shutdown = true; //! Serialize the metadata on checkpoint with compatibility for a given DuckDB version. SerializationCompatibility serialization_compatibility = SerializationCompatibility::Default(); - //! Debug flag that decides when a checkpoing should be aborted. Only used for testing purposes. - CheckpointAbort checkpoint_abort = CheckpointAbort::NO_ABORT; //! Initialize the database with the standard set of DuckDB functions //! You should probably not touch this unless you know what you are doing bool initialize_default_database = true; @@ -214,20 +164,6 @@ struct DBConfigOptions { set disabled_compression_methods; //! Force a specific bitpacking mode to be used when using the bitpacking compression method BitpackingMode force_bitpacking_mode = BitpackingMode::AUTO; - //! Debug setting for window aggregation mode: (window, combine, separate) - WindowAggregationMode window_mode = WindowAggregationMode::WINDOW; - //! Whether preserving insertion order should be preserved - bool preserve_insertion_order = true; - //! Whether Arrow Arrays use Large or Regular buffers - ArrowOffsetSize arrow_offset_size = ArrowOffsetSize::REGULAR; - //! Whether LISTs should produce Arrow ListViews - bool arrow_use_list_view = false; - //! For DuckDB types without an obvious corresponding Arrow type, export to an Arrow extension type instead of a - //! more portable but less efficient format. For example, UUIDs are exported to UTF-8 (string) when false, and - //! arrow.uuid type when true. - bool arrow_lossless_conversion = false; - //! Whether when producing arrow objects we produce string_views or regular strings - bool produce_arrow_string_views = false; //! Database configuration variables as controlled by SET case_insensitive_map_t set_variables; //! Database configuration variable default values; @@ -238,16 +174,6 @@ struct DBConfigOptions { bool allow_unsigned_extensions = false; //! Whether community extensions should be loaded bool allow_community_extensions = true; - //! Whether extensions with missing metadata should be loaded - bool allow_extensions_metadata_mismatch = false; - //! Enable emitting FSST Vectors - bool enable_fsst_vectors = false; - //! Enable VIEWs to create dependencies - bool enable_view_dependencies = false; - //! Enable macros to create dependencies - bool enable_macro_dependencies = false; - //! Start transactions immediately in all attached databases - instead of lazily when a database is referenced - bool immediate_transaction_mode = false; //! Debug setting - how to initialize blocks in the storage layer when allocating DebugInitialize debug_initialize = DebugInitialize::NO_INITIALIZE; //! 
The set of user-provided options @@ -268,34 +194,16 @@ struct DBConfigOptions { string duckdb_api; //! Metadata from DuckDB callers string custom_user_agent; - //! Use old implicit casting style (i.e. allow everything to be implicitly casted to VARCHAR) - bool old_implicit_casting = false; //! By default, WAL is encrypted for encrypted databases bool wal_encryption = true; //! Encrypt the temp files bool temp_file_encryption = false; //! The default block allocation size for new duckdb database files (new as-in, they do not yet exist). - idx_t default_block_alloc_size = DUCKDB_BLOCK_ALLOC_SIZE; + idx_t default_block_alloc_size = DEFAULT_BLOCK_ALLOC_SIZE; //! The default block header size for new duckdb database files. idx_t default_block_header_size = DUCKDB_BLOCK_HEADER_STORAGE_SIZE; //! Whether or not to abort if a serialization exception is thrown during WAL playback (when reading truncated WAL) bool abort_on_wal_failure = false; - //! The index_scan_percentage sets a threshold for index scans. - //! If fewer than MAX(index_scan_max_count, index_scan_percentage * total_row_count) - //! rows match, we perform an index scan instead of a table scan. - double index_scan_percentage = 0.001; - //! The index_scan_max_count sets a threshold for index scans. - //! If fewer than MAX(index_scan_max_count, index_scan_percentage * total_row_count) - //! rows match, we perform an index scan instead of a table scan. - idx_t index_scan_max_count = STANDARD_VECTOR_SIZE; - //! The maximum number of schemas we will look through for "did you mean..." style errors in the catalog - idx_t catalog_error_max_schemas = 100; - //! Whether or not to always write to the WAL file, even if this is not required - bool debug_skip_checkpoint_on_commit = false; - //! Vector verification to enable (debug setting only) - DebugVectorVerification debug_verify_vector = DebugVectorVerification::NONE; - //! The maximum amount of vacuum tasks to schedule during a checkpoint - idx_t max_vacuum_tasks = 100; //! Paths that are explicitly allowed, even if enable_external_access is false unordered_set allowed_paths; //! Directories that are explicitly allowed, even if enable_external_access is false @@ -304,8 +212,6 @@ struct DBConfigOptions { LogConfig log_config = LogConfig(); //! Whether to enable external file caching using CachingFileSystem bool enable_external_file_cache = true; - //! Output version of arrow depending on the format version - ArrowFormatVersion arrow_output_version = V1_0; //! Partially process tasks before rescheduling - allows for more scheduler fairness between separate queries #ifdef DUCKDB_ALTERNATIVE_VERIFY bool scheduler_process_partial = true; @@ -314,8 +220,6 @@ struct DBConfigOptions { #endif //! Whether to pin threads to cores (linux only, default AUTOMATIC: on when there are more than 64 cores) ThreadPinMode pin_threads = ThreadPinMode::AUTO; - //! Enable the Parquet reader to identify a Variant group structurally - bool variant_legacy_encoding = false; bool operator==(const DBConfigOptions &other) const; }; @@ -330,7 +234,7 @@ struct DBConfig { DUCKDB_API DBConfig(const case_insensitive_map_t &config_dict, bool read_only); DUCKDB_API ~DBConfig(); - mutex config_lock; + mutable mutex config_lock; //! 
Replacement table scans are automatically attempted when a table name cannot be found in the schema
	vector<ReplacementScan> replacement_scans;
@@ -378,22 +282,27 @@ struct DBConfig {
	DUCKDB_API static const DBConfig &GetConfig(const DatabaseInstance &db);
	DUCKDB_API static vector<ConfigurationOption> GetOptions();
	DUCKDB_API static idx_t GetOptionCount();
+	DUCKDB_API static idx_t GetAliasCount();
	DUCKDB_API static vector<string> GetOptionNames();
	DUCKDB_API static bool IsInMemoryDatabase(const char *database_path);
	DUCKDB_API void AddExtensionOption(const string &name, string description, LogicalType parameter,
-	                                   const Value &default_value = Value(), set_option_callback_t function = nullptr);
+	                                   const Value &default_value = Value(), set_option_callback_t function = nullptr,
+	                                   SetScope default_scope = SetScope::SESSION);
	//! Fetch an option by index. Returns a pointer to the option, or nullptr if out of range
	DUCKDB_API static optional_ptr<const ConfigurationOption> GetOptionByIndex(idx_t index);
+	//! Fetch an alias by index, or nullptr if out of range
+	DUCKDB_API static optional_ptr<const ConfigurationAlias> GetAliasByIndex(idx_t index);
	//! Fetch an option by name. Returns a pointer to the option, or nullptr if none exists.
	DUCKDB_API static optional_ptr<const ConfigurationOption> GetOptionByName(const string &name);
	DUCKDB_API void SetOption(const ConfigurationOption &option, const Value &value);
-	DUCKDB_API void SetOption(DatabaseInstance *db, const ConfigurationOption &option, const Value &value);
+	DUCKDB_API void SetOption(optional_ptr<DatabaseInstance> db, const ConfigurationOption &option, const Value &value);
	DUCKDB_API void SetOptionByName(const string &name, const Value &value);
	DUCKDB_API void SetOptionsByName(const case_insensitive_map_t<Value> &values);
-	DUCKDB_API void ResetOption(DatabaseInstance *db, const ConfigurationOption &option);
+	DUCKDB_API void ResetOption(optional_ptr<DatabaseInstance> db, const ConfigurationOption &option);
	DUCKDB_API void SetOption(const string &name, Value value);
	DUCKDB_API void ResetOption(const string &name);
+	DUCKDB_API void ResetGenericOption(const string &name);
	static LogicalType ParseLogicalType(const string &type);
	DUCKDB_API void CheckLock(const string &name);
@@ -430,19 +339,28 @@ struct DBConfig {
	void SetDefaultMaxMemory();
	void SetDefaultTempDirectory();
-	OrderType ResolveOrder(OrderType order_type) const;
-	OrderByNullType ResolveNullOrder(OrderType order_type, OrderByNullType null_type) const;
+	OrderType ResolveOrder(ClientContext &context, OrderType order_type) const;
+	OrderByNullType ResolveNullOrder(ClientContext &context, OrderType order_type, OrderByNullType null_type) const;
	const string UserAgent() const;
-	template <class OP>
-	typename OP::RETURN_TYPE GetSetting(const ClientContext &context) {
-		std::lock_guard<mutex> lock(config_lock);
-		return OP::GetSetting(context).template GetValue<typename OP::RETURN_TYPE>();
+	SettingLookupResult TryGetCurrentSetting(const string &key, Value &result) const;
+
+	template <class OP, class SOURCE>
+	static typename std::enable_if<std::is_enum<typename OP::RETURN_TYPE>::value, typename OP::RETURN_TYPE>::type
+	GetSetting(const SOURCE &source) {
+		return EnumUtil::FromString<typename OP::RETURN_TYPE>(
+		    GetSettingInternal(source, OP::Name, OP::DefaultValue).ToString());
+	}
+
+	template <class OP, class SOURCE>
+	static typename std::enable_if<!std::is_enum<typename OP::RETURN_TYPE>::value, typename OP::RETURN_TYPE>::type
+	GetSetting(const SOURCE &source) {
+		return GetSettingInternal(source, OP::Name, OP::DefaultValue).template GetValue<typename OP::RETURN_TYPE>();
+	}
	template <class OP>
-	Value GetSettingValue(const ClientContext &context) {
-		std::lock_guard<mutex> lock(config_lock);
+	Value GetSettingValue(const ClientContext &context) const {
+		lock_guard<mutex> lock(config_lock);
		return OP::GetSetting(context);
	}
@@ -451,6 +369,11 @@ struct DBConfig {
	void AddAllowedPath(const
string &path); string SanitizeAllowedPath(const string &path) const; +private: + static Value GetSettingInternal(const DatabaseInstance &db, const char *setting, const char *default_value); + static Value GetSettingInternal(const DBConfig &config, const char *setting, const char *default_value); + static Value GetSettingInternal(const ClientContext &context, const char *setting, const char *default_value); + private: unique_ptr compression_functions; unique_ptr encoding_functions; diff --git a/src/duckdb/src/include/duckdb/main/database.hpp b/src/duckdb/src/include/duckdb/main/database.hpp index c993ea767..7ab56b476 100644 --- a/src/duckdb/src/include/duckdb/main/database.hpp +++ b/src/duckdb/src/include/duckdb/main/database.hpp @@ -12,7 +12,6 @@ #include "duckdb/main/capi/extension_api.hpp" #include "duckdb/main/config.hpp" #include "duckdb/main/extension.hpp" -#include "duckdb/main/settings.hpp" #include "duckdb/main/valid_checker.hpp" #include "duckdb/main/extension/extension_loader.hpp" #include "duckdb/main/extension_manager.hpp" diff --git a/src/duckdb/src/include/duckdb/main/database_file_opener.hpp b/src/duckdb/src/include/duckdb/main/database_file_opener.hpp index 950f1f11e..7eacbbcfe 100644 --- a/src/duckdb/src/include/duckdb/main/database_file_opener.hpp +++ b/src/duckdb/src/include/duckdb/main/database_file_opener.hpp @@ -12,6 +12,7 @@ #include "duckdb/common/opener_file_system.hpp" #include "duckdb/main/config.hpp" #include "duckdb/main/database.hpp" +#include "duckdb/logging/log_manager.hpp" namespace duckdb { class DatabaseInstance; diff --git a/src/duckdb/src/include/duckdb/main/database_manager.hpp b/src/duckdb/src/include/duckdb/main/database_manager.hpp index fe221a344..0194f527a 100644 --- a/src/duckdb/src/include/duckdb/main/database_manager.hpp +++ b/src/duckdb/src/include/duckdb/main/database_manager.hpp @@ -72,6 +72,11 @@ class DatabaseManager { const optional_idx max_db_count = optional_idx()); //! Scans the catalog set and returns each committed database entry vector> GetDatabases(); + //! Returns the approximate count of attached databases. + idx_t ApproxDatabaseCount() { + lock_guard path_lock(db_paths_lock); + return db_paths_to_name.size(); + } //! Removes all databases from the catalog set. This is necessary for the database instance's destructor, //! as the database manager has to be alive when destroying the catalog set objects. void ResetDatabases(unique_ptr &scheduler); @@ -98,8 +103,6 @@ class DatabaseManager { vector GetAttachedDatabasePaths(); private: - //! Returns a database with a specified path - optional_ptr GetDatabaseFromPath(ClientContext &context, const string &path); void CheckPathConflict(ClientContext &context, const string &path); private: @@ -121,7 +124,7 @@ class DatabaseManager { //! A set containing all attached database path //! This allows to attach many databases efficiently, and to avoid attaching the //! 
same file path twice - case_insensitive_set_t db_paths; + case_insensitive_map_t db_paths_to_name; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/main/extension_entries.hpp b/src/duckdb/src/include/duckdb/main/extension_entries.hpp index 7d8791a56..34588351a 100644 --- a/src/duckdb/src/include/duckdb/main/extension_entries.hpp +++ b/src/duckdb/src/include/duckdb/main/extension_entries.hpp @@ -150,10 +150,13 @@ static constexpr ExtensionFunctionEntry EXTENSION_FUNCTIONS[] = { {"delta_scan", "delta", CatalogType::TABLE_FUNCTION_ENTRY}, {"drop_fts_index", "fts", CatalogType::PRAGMA_FUNCTION_ENTRY}, {"dsdgen", "tpcds", CatalogType::TABLE_FUNCTION_ENTRY}, + {"ducklake_add_data_files", "ducklake", CatalogType::TABLE_FUNCTION_ENTRY}, {"ducklake_cleanup_old_files", "ducklake", CatalogType::TABLE_FUNCTION_ENTRY}, {"ducklake_expire_snapshots", "ducklake", CatalogType::TABLE_FUNCTION_ENTRY}, + {"ducklake_flush_inlined_data", "ducklake", CatalogType::TABLE_FUNCTION_ENTRY}, {"ducklake_list_files", "ducklake", CatalogType::TABLE_FUNCTION_ENTRY}, {"ducklake_merge_adjacent_files", "ducklake", CatalogType::TABLE_FUNCTION_ENTRY}, + {"ducklake_options", "ducklake", CatalogType::TABLE_FUNCTION_ENTRY}, {"ducklake_set_option", "ducklake", CatalogType::TABLE_FUNCTION_ENTRY}, {"ducklake_snapshots", "ducklake", CatalogType::TABLE_FUNCTION_ENTRY}, {"ducklake_table_changes", "ducklake", CatalogType::TABLE_MACRO_ENTRY}, @@ -1013,6 +1016,7 @@ static constexpr ExtensionEntry EXTENSION_SETTINGS[] = { {"http_retry_backoff", "httpfs"}, {"http_retry_wait_ms", "httpfs"}, {"http_timeout", "httpfs"}, + {"httpfs_client_implementation", "httpfs"}, {"mysql_bit1_as_boolean", "mysql_scanner"}, {"mysql_debug_show_queries", "mysql_scanner"}, {"mysql_experimental_filter_pushdown", "mysql_scanner"}, @@ -1027,11 +1031,13 @@ static constexpr ExtensionEntry EXTENSION_SETTINGS[] = { {"pg_pages_per_task", "postgres_scanner"}, {"pg_use_binary_copy", "postgres_scanner"}, {"pg_use_ctid_scan", "postgres_scanner"}, + {"pg_use_text_protocol", "postgres_scanner"}, {"prefetch_all_parquet_files", "parquet"}, {"s3_access_key_id", "httpfs"}, {"s3_endpoint", "httpfs"}, {"s3_kms_key_id", "httpfs"}, {"s3_region", "httpfs"}, + {"s3_requester_pays", "httpfs"}, {"s3_secret_access_key", "httpfs"}, {"s3_session_token", "httpfs"}, {"s3_uploader_max_filesize", "httpfs"}, @@ -1047,18 +1053,13 @@ static constexpr ExtensionEntry EXTENSION_SETTINGS[] = { {"ui_polling_interval", "ui"}, {"ui_remote_url", "ui"}, {"unsafe_enable_version_guessing", "iceberg"}, + {"variant_legacy_encoding", "parquet"}, }; // END_OF_EXTENSION_SETTINGS static constexpr ExtensionEntry EXTENSION_SECRET_TYPES[] = { - {"aws", "httpfs"}, - {"azure", "azure"}, - {"gcs", "httpfs"}, - {"huggingface", "httpfs"}, - {"iceberg", "iceberg"}, - {"mysql", "mysql_scanner"}, - {"postgres", "postgres_scanner"}, - {"r2", "httpfs"}, - {"s3", "httpfs"}, + {"aws", "httpfs"}, {"azure", "azure"}, {"ducklake", "ducklake"}, {"gcs", "httpfs"}, + {"huggingface", "httpfs"}, {"iceberg", "iceberg"}, {"mysql", "mysql_scanner"}, {"postgres", "postgres_scanner"}, + {"r2", "httpfs"}, {"s3", "httpfs"}, }; // END_OF_EXTENSION_SECRET_TYPES // Note: these are currently hardcoded in scripts/generate_extensions_function.py diff --git a/src/duckdb/src/include/duckdb/main/secret/secret.hpp b/src/duckdb/src/include/duckdb/main/secret/secret.hpp index 40a9dab64..ed8034413 100644 --- a/src/duckdb/src/include/duckdb/main/secret/secret.hpp +++ b/src/duckdb/src/include/duckdb/main/secret/secret.hpp @@ -12,6 +12,7 @@ 
#include "duckdb/common/named_parameter_map.hpp" #include "duckdb/common/serializer/deserializer.hpp" #include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/main/setting_info.hpp" namespace duckdb { class BaseSecret; diff --git a/src/duckdb/src/include/duckdb/main/secret/secret_manager.hpp b/src/duckdb/src/include/duckdb/main/secret/secret_manager.hpp index 4366d6cd9..80fb46044 100644 --- a/src/duckdb/src/include/duckdb/main/secret/secret_manager.hpp +++ b/src/duckdb/src/include/duckdb/main/secret/secret_manager.hpp @@ -20,6 +20,24 @@ class SecretManager; struct DBConfig; class SchemaCatalogEntry; +//! A Secret Entry in the secret manager +struct SecretEntry { +public: + explicit SecretEntry(unique_ptr secret) : secret(secret != nullptr ? secret->Clone() : nullptr) { + } + SecretEntry(const SecretEntry &other) + : persist_type(other.persist_type), storage_mode(other.storage_mode), + secret((other.secret != nullptr) ? other.secret->Clone() : nullptr) { + } + + //! Whether the secret is persistent + SecretPersistType persist_type; + //! The storage backend of the secret + string storage_mode; + //! The secret pointer + unique_ptr secret; +}; + //! Return value of a Secret Lookup struct SecretMatch { public: @@ -52,24 +70,6 @@ struct SecretMatch { int64_t score; }; -//! A Secret Entry in the secret manager -struct SecretEntry { -public: - explicit SecretEntry(unique_ptr secret) : secret(secret != nullptr ? secret->Clone() : nullptr) { - } - SecretEntry(const SecretEntry &other) - : persist_type(other.persist_type), storage_mode(other.storage_mode), - secret((other.secret != nullptr) ? other.secret->Clone() : nullptr) { - } - - //! Whether the secret is persistent - SecretPersistType persist_type; - //! The storage backend of the secret - string storage_mode; - //! The secret pointer - unique_ptr secret; -}; - struct SecretManagerConfig { static constexpr const bool DEFAULT_ALLOW_PERSISTENT_SECRETS = true; //! The default persistence type for secrets diff --git a/src/duckdb/src/include/duckdb/main/setting_info.hpp b/src/duckdb/src/include/duckdb/main/setting_info.hpp new file mode 100644 index 000000000..7968c4d8f --- /dev/null +++ b/src/duckdb/src/include/duckdb/main/setting_info.hpp @@ -0,0 +1,109 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/main/setting_info.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/common.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/enums/set_scope.hpp" + +namespace duckdb { +class ClientContext; +class DatabaseInstance; +struct DBConfig; + +const string GetDefaultUserAgent(); + +enum class SettingScope : uint8_t { + //! Setting is from the global Setting scope + GLOBAL, + //! Setting is from the local Setting scope + LOCAL, + //! Setting was not fetched from settings, but it was fetched from a secret instead + SECRET, + //! 
The setting was not found or invalid in some other way
+	INVALID
+};
+
+struct SettingLookupResult {
+public:
+	SettingLookupResult() : scope(SettingScope::INVALID) {
+	}
+	explicit SettingLookupResult(SettingScope scope) : scope(scope) {
+		D_ASSERT(scope != SettingScope::INVALID);
+	}
+
+public:
+	operator bool() { // NOLINT: allow implicit conversion to bool
+		return scope != SettingScope::INVALID;
+	}
+
+public:
+	SettingScope GetScope() {
+		D_ASSERT(scope != SettingScope::INVALID);
+		return scope;
+	}
+
+private:
+	SettingScope scope = SettingScope::INVALID;
+};
+
+struct SettingCallbackInfo {
+	explicit SettingCallbackInfo(ClientContext &context, SetScope scope);
+	explicit SettingCallbackInfo(DBConfig &config, optional_ptr<DatabaseInstance> db);
+
+	DBConfig &config;
+	optional_ptr<DatabaseInstance> db;
+	optional_ptr<ClientContext> context;
+	SetScope scope;
+};
+
+typedef void (*set_callback_t)(SettingCallbackInfo &info, Value &parameter);
+typedef void (*set_global_function_t)(DatabaseInstance *db, DBConfig &config, const Value &parameter);
+typedef void (*set_local_function_t)(ClientContext &context, const Value &parameter);
+typedef void (*reset_global_function_t)(DatabaseInstance *db, DBConfig &config);
+typedef void (*reset_local_function_t)(ClientContext &context);
+typedef Value (*get_setting_function_t)(const ClientContext &context);
+
+struct ConfigurationOption {
+	const char *name;
+	const char *description;
+	const char *parameter_type;
+	set_global_function_t set_global;
+	set_local_function_t set_local;
+	reset_global_function_t reset_global;
+	reset_local_function_t reset_local;
+	get_setting_function_t get_setting;
+	SetScope default_scope;
+	const char *default_value;
+	set_callback_t set_callback;
+};
+
+struct ConfigurationAlias {
+	const char *alias;
+	idx_t option_index;
+};
+
+typedef void (*set_option_callback_t)(ClientContext &context, SetScope scope, Value &parameter);
+
+struct ExtensionOption {
+	// NOLINTNEXTLINE: work around bug in clang-tidy
+	ExtensionOption(string description_p, LogicalType type_p, set_option_callback_t set_function_p,
+	                Value default_value_p, SetScope default_scope_p)
+	    : description(std::move(description_p)), type(std::move(type_p)), set_function(set_function_p),
+	      default_value(std::move(default_value_p)), default_scope(default_scope_p) {
+	}
+
+	string description;
+	LogicalType type;
+	set_option_callback_t set_function;
+	Value default_value;
+	SetScope default_scope;
+};
+
+} // namespace duckdb
diff --git a/src/duckdb/src/include/duckdb/main/settings.hpp b/src/duckdb/src/include/duckdb/main/settings.hpp
index 4096f2319..e0a45c165 100644
--- a/src/duckdb/src/include/duckdb/main/settings.hpp
+++ b/src/duckdb/src/include/duckdb/main/settings.hpp
@@ -8,50 +8,18 @@
 #pragma once
-#include "duckdb/common/common.hpp"
-#include "duckdb/common/types/value.hpp"
 #include "duckdb/main/config.hpp"
+#include "duckdb/main/setting_info.hpp"
+#include "duckdb/common/enums/access_mode.hpp"
+#include "duckdb/common/enums/checkpoint_abort.hpp"
+#include "duckdb/common/enums/debug_vector_verification.hpp"
+#include "duckdb/common/enums/window_aggregation_mode.hpp"
+#include "duckdb/common/enums/order_type.hpp"
+#include "duckdb/common/enums/output_type.hpp"
+#include "duckdb/common/enums/thread_pin_mode.hpp"
+#include "duckdb/common/enums/arrow_format_version.hpp"
namespace duckdb {
-class ClientContext;
-class DatabaseInstance;
-struct DBConfig;
-
-const string GetDefaultUserAgent();
-
-enum class SettingScope : uint8_t {
-	//! Setting is from the global Setting scope
-	GLOBAL,
-	//!
Setting is from the local Setting scope - LOCAL, - //! Setting was not fetched from settings, but it was fetched from a secret instead - SECRET, - //! The setting was not found or invalid in some other way - INVALID -}; - -struct SettingLookupResult { -public: - SettingLookupResult() : scope(SettingScope::INVALID) { - } - explicit SettingLookupResult(SettingScope scope) : scope(scope) { - D_ASSERT(scope != SettingScope::INVALID); - } - -public: - operator bool() { // NOLINT: allow implicit conversion to bool - return scope != SettingScope::INVALID; - } - -public: - SettingScope GetScope() { - D_ASSERT(scope != SettingScope::INVALID); - return scope; - } - -private: - SettingScope scope = SettingScope::INVALID; -}; //===----------------------------------------------------------------------===// // This code is autogenerated from 'update_settings_header_file.py'. @@ -122,9 +90,8 @@ struct AllowExtensionsMetadataMismatchSetting { static constexpr const char *Name = "allow_extensions_metadata_mismatch"; static constexpr const char *Description = "Allow to load extensions with not compatible metadata"; static constexpr const char *InputType = "BOOLEAN"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; struct AllowPersistentSecretsSetting { @@ -190,9 +157,8 @@ struct ArrowLargeBufferSizeSetting { static constexpr const char *Description = "Whether Arrow buffers for strings, blobs, uuids and bits should be exported using large buffers"; static constexpr const char *InputType = "BOOLEAN"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; struct ArrowLosslessConversionSetting { @@ -202,9 +168,8 @@ struct ArrowLosslessConversionSetting { "Whenever a DuckDB type does not have a clear native or canonical extension match in Arrow, export the types " "with a duckdb.type_name extension name."; static constexpr const char *InputType = "BOOLEAN"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; struct ArrowOutputListViewSetting { @@ -213,20 +178,19 @@ struct ArrowOutputListViewSetting { static constexpr const char *Description = "Whether export to Arrow format should use ListView as the physical layout for LIST columns"; static constexpr const char *InputType = "BOOLEAN"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; struct ArrowOutputVersionSetting { - using RETURN_TYPE = string; + using RETURN_TYPE = ArrowFormatVersion; static constexpr const char *Name = "arrow_output_version"; static constexpr const char 
*Description = "Whether strings should be produced by DuckDB in Utf8View format instead of Utf8"; static constexpr const char *InputType = "VARCHAR"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "1.0"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; + static void OnSet(SettingCallbackInfo &info, Value &input); }; struct AsofLoopJoinThresholdSetting { @@ -235,9 +199,8 @@ struct AsofLoopJoinThresholdSetting { static constexpr const char *Description = "The maximum number of rows we need on the left side of an ASOF join to use a nested loop join"; static constexpr const char *InputType = "UBIGINT"; - static void SetLocal(ClientContext &context, const Value ¶meter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "64"; + static constexpr SetScope DefaultScope = SetScope::SESSION; }; struct AutoinstallExtensionRepositorySetting { @@ -279,9 +242,8 @@ struct CatalogErrorMaxSchemasSetting { static constexpr const char *Description = "The maximum number of schemas the system will scan for \"did you mean...\" style errors in the catalog"; static constexpr const char *InputType = "UBIGINT"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "100"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; struct CheckpointThresholdSetting { @@ -330,9 +292,8 @@ struct DebugAsofIejoinSetting { static constexpr const char *Name = "debug_asof_iejoin"; static constexpr const char *Description = "DEBUG SETTING: force use of IEJoin to implement AsOf joins"; static constexpr const char *InputType = "BOOLEAN"; - static void SetLocal(ClientContext &context, const Value ¶meter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::SESSION; }; struct DebugCheckpointAbortSetting { @@ -341,9 +302,9 @@ struct DebugCheckpointAbortSetting { static constexpr const char *Description = "DEBUG SETTING: trigger an abort while checkpointing for testing purposes"; static constexpr const char *InputType = "VARCHAR"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "NONE"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; + static void OnSet(SettingCallbackInfo &info, Value &input); }; struct DebugForceExternalSetting { @@ -363,9 +324,8 @@ struct DebugForceNoCrossProductSetting { static constexpr const char *Description = "DEBUG SETTING: Force disable cross product generation when hyper graph isn't connected, used for testing"; static constexpr const char *InputType = "BOOLEAN"; - static void SetLocal(ClientContext &context, const Value ¶meter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "false"; + static constexpr 
SetScope DefaultScope = SetScope::SESSION; }; struct DebugSkipCheckpointOnCommitSetting { @@ -373,9 +333,8 @@ struct DebugSkipCheckpointOnCommitSetting { static constexpr const char *Name = "debug_skip_checkpoint_on_commit"; static constexpr const char *Description = "DEBUG SETTING: skip checkpointing on commit"; static constexpr const char *InputType = "BOOLEAN"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; struct DebugVerifyVectorSetting { @@ -383,9 +342,9 @@ struct DebugVerifyVectorSetting { static constexpr const char *Name = "debug_verify_vector"; static constexpr const char *Description = "DEBUG SETTING: enable vector verification"; static constexpr const char *InputType = "VARCHAR"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "NONE"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; + static void OnSet(SettingCallbackInfo &info, Value &input); }; struct DebugWindowModeSetting { @@ -393,9 +352,9 @@ struct DebugWindowModeSetting { static constexpr const char *Name = "debug_window_mode"; static constexpr const char *Description = "DEBUG SETTING: switch window mode to use"; static constexpr const char *InputType = "VARCHAR"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "WINDOW"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; + static void OnSet(SettingCallbackInfo &info, Value &input); }; struct DefaultBlockSizeSetting { @@ -414,11 +373,9 @@ struct DefaultCollationSetting { static constexpr const char *Name = "default_collation"; static constexpr const char *Description = "The collation setting used when none is specified"; static constexpr const char *InputType = "VARCHAR"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static void SetLocal(ClientContext &context, const Value ¶meter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = ""; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; + static void OnSet(SettingCallbackInfo &info, Value &input); }; struct DefaultNullOrderSetting { @@ -426,19 +383,19 @@ struct DefaultNullOrderSetting { static constexpr const char *Name = "default_null_order"; static constexpr const char *Description = "NULL ordering used when none is specified (NULLS_FIRST or NULLS_LAST)"; static constexpr const char *InputType = "VARCHAR"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "NULLS_LAST"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; + static void OnSet(SettingCallbackInfo &info, Value &input); }; struct 
DefaultOrderSetting { - using RETURN_TYPE = string; + using RETURN_TYPE = OrderType; static constexpr const char *Name = "default_order"; static constexpr const char *Description = "The order type used when none is specified (ASC or DESC)"; static constexpr const char *InputType = "VARCHAR"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "ASCENDING"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; + static void OnSet(SettingCallbackInfo &info, Value &input); }; struct DefaultSecretStorageSetting { @@ -470,9 +427,8 @@ struct DisableTimestamptzCastsSetting { static constexpr const char *Name = "disable_timestamptz_casts"; static constexpr const char *Description = "Disable casting from timestamp to timestamptz "; static constexpr const char *InputType = "BOOLEAN"; - static void SetLocal(ClientContext &context, const Value ¶meter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::SESSION; }; struct DisabledCompressionMethodsSetting { @@ -531,9 +487,8 @@ struct DynamicOrFilterThresholdSetting { static constexpr const char *Description = "The maximum amount of OR filters we generate dynamically from a hash join"; static constexpr const char *InputType = "UBIGINT"; - static void SetLocal(ClientContext &context, const Value ¶meter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "50"; + static constexpr SetScope DefaultScope = SetScope::SESSION; }; struct EnableExternalAccessSetting { @@ -566,9 +521,8 @@ struct EnableFSSTVectorsSetting { static constexpr const char *Description = "Allow scans on FSST compressed segments to emit compressed vectors to utilize late decompression"; static constexpr const char *InputType = "BOOLEAN"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; struct EnableHTTPLoggingSetting { @@ -607,9 +561,8 @@ struct EnableMacroDependenciesSetting { static constexpr const char *Description = "Enable created MACROs to create dependencies on the referenced objects (such as tables)"; static constexpr const char *InputType = "BOOLEAN"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; struct EnableObjectCacheSetting { @@ -617,9 +570,8 @@ struct EnableObjectCacheSetting { static constexpr const char *Name = "enable_object_cache"; static constexpr const char *Description = "[PLACEHOLDER] Legacy setting - does nothing"; static constexpr const char *InputType = "BOOLEAN"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const 
ClientContext &context); + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; struct EnableProfilingSetting { @@ -663,9 +615,8 @@ struct EnableViewDependenciesSetting { static constexpr const char *Description = "Enable created VIEWs to create dependencies on the referenced objects (such as tables)"; static constexpr const char *InputType = "BOOLEAN"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; struct EnabledLogTypes { @@ -818,9 +769,8 @@ struct ImmediateTransactionModeSetting { static constexpr const char *Description = "Whether transactions should be started lazily when needed, or immediately when BEGIN TRANSACTION is called"; static constexpr const char *InputType = "BOOLEAN"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; struct IndexScanMaxCountSetting { @@ -830,9 +780,8 @@ struct IndexScanMaxCountSetting { "The maximum index scan count sets a threshold for index scans. If fewer than MAX(index_scan_max_count, " "index_scan_percentage * total_row_count) rows match, we perform an index scan instead of a table scan."; static constexpr const char *InputType = "UBIGINT"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "2048"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; struct IndexScanPercentageSetting { @@ -842,10 +791,9 @@ struct IndexScanPercentageSetting { "The index scan percentage sets a threshold for index scans. 
If fewer than MAX(index_scan_max_count, " "index_scan_percentage * total_row_count) rows match, we perform an index scan instead of a table scan."; static constexpr const char *InputType = "DOUBLE"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static bool OnGlobalSet(DatabaseInstance *db, DBConfig &config, const Value &input); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "0.001"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; + static void OnSet(SettingCallbackInfo &info, Value &input); }; struct IntegerDivisionSetting { @@ -854,9 +802,8 @@ struct IntegerDivisionSetting { static constexpr const char *Description = "Whether or not the / operator defaults to integer division, or to floating point division"; static constexpr const char *InputType = "BOOLEAN"; - static void SetLocal(ClientContext &context, const Value ¶meter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::SESSION; }; struct LambdaSyntaxSetting { @@ -876,9 +823,8 @@ struct LateMaterializationMaxRowsSetting { static constexpr const char *Description = "The maximum amount of rows in the LIMIT/SAMPLE for which we trigger late materialization"; static constexpr const char *InputType = "UBIGINT"; - static void SetLocal(ClientContext &context, const Value ¶meter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "50"; + static constexpr SetScope DefaultScope = SetScope::SESSION; }; struct LockConfigurationSetting { @@ -970,9 +916,8 @@ struct MaxVacuumTasksSetting { static constexpr const char *Name = "max_vacuum_tasks"; static constexpr const char *Description = "The maximum vacuum tasks to schedule during a checkpoint."; static constexpr const char *InputType = "UBIGINT"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "100"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; struct MergeJoinThresholdSetting { @@ -980,9 +925,8 @@ struct MergeJoinThresholdSetting { static constexpr const char *Name = "merge_join_threshold"; static constexpr const char *Description = "The maximum number of rows on either table to choose a merge join"; static constexpr const char *InputType = "UBIGINT"; - static void SetLocal(ClientContext &context, const Value ¶meter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "1000"; + static constexpr SetScope DefaultScope = SetScope::SESSION; }; struct NestedLoopJoinThresholdSetting { @@ -991,9 +935,8 @@ struct NestedLoopJoinThresholdSetting { static constexpr const char *Description = "The maximum number of rows on either table to choose a nested loop join"; static constexpr const char *InputType = "UBIGINT"; - static void SetLocal(ClientContext &context, const Value ¶meter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "5"; + static 
constexpr SetScope DefaultScope = SetScope::SESSION; }; struct OldImplicitCastingSetting { @@ -1001,9 +944,8 @@ struct OldImplicitCastingSetting { static constexpr const char *Name = "old_implicit_casting"; static constexpr const char *Description = "Allow implicit casting to/from VARCHAR"; static constexpr const char *InputType = "BOOLEAN"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value ¶meter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; struct OrderByNonIntegerLiteralSetting { @@ -1012,9 +954,8 @@ struct OrderByNonIntegerLiteralSetting { static constexpr const char *Description = "Allow ordering by non-integer literals - ordering by such literals has no effect."; static constexpr const char *InputType = "BOOLEAN"; - static void SetLocal(ClientContext &context, const Value ¶meter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::SESSION; }; struct OrderedAggregateThresholdSetting { @@ -1034,9 +975,8 @@ struct PartitionedWriteFlushThresholdSetting { static constexpr const char *Description = "The threshold in number of rows after which we flush a thread state when writing using PARTITION_BY"; static constexpr const char *InputType = "UBIGINT"; - static void SetLocal(ClientContext &context, const Value ¶meter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "524288"; + static constexpr SetScope DefaultScope = SetScope::SESSION; }; struct PartitionedWriteMaxOpenFilesSetting { @@ -1045,9 +985,8 @@ struct PartitionedWriteMaxOpenFilesSetting { static constexpr const char *Description = "The maximum amount of files the system can keep open before flushing to disk when writing using PARTITION_BY"; static constexpr const char *InputType = "UBIGINT"; - static void SetLocal(ClientContext &context, const Value ¶meter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "100"; + static constexpr SetScope DefaultScope = SetScope::SESSION; }; struct PasswordSetting { @@ -1087,9 +1026,8 @@ struct PivotFilterThresholdSetting { static constexpr const char *Description = "The threshold to switch from using filtered aggregates to LIST with a dedicated pivot operator"; static constexpr const char *InputType = "UBIGINT"; - static void SetLocal(ClientContext &context, const Value ¶meter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "20"; + static constexpr SetScope DefaultScope = SetScope::SESSION; }; struct PivotLimitSetting { @@ -1097,9 +1035,8 @@ struct PivotLimitSetting { static constexpr const char *Name = "pivot_limit"; static constexpr const char *Description = "The maximum number of pivot columns in a pivot statement"; static constexpr const char *InputType = "UBIGINT"; - static void SetLocal(ClientContext &context, const Value ¶meter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "100000"; + 
static constexpr SetScope DefaultScope = SetScope::SESSION; }; struct PreferRangeJoinsSetting { @@ -1107,9 +1044,8 @@ struct PreferRangeJoinsSetting { static constexpr const char *Name = "prefer_range_joins"; static constexpr const char *Description = "Force use of range joins with mixed predicates"; static constexpr const char *InputType = "BOOLEAN"; - static void SetLocal(ClientContext &context, const Value &parameter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::SESSION; }; struct PreserveIdentifierCaseSetting { @@ -1118,9 +1054,8 @@ struct PreserveIdentifierCaseSetting { static constexpr const char *Description = "Whether or not to preserve the identifier case, instead of always lowercasing all non-quoted identifiers"; static constexpr const char *InputType = "BOOLEAN"; - static void SetLocal(ClientContext &context, const Value &parameter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "true"; + static constexpr SetScope DefaultScope = SetScope::SESSION; }; struct PreserveInsertionOrderSetting { @@ -1130,20 +1065,18 @@ struct PreserveInsertionOrderSetting { "Whether or not to preserve insertion order. If set to false the system is allowed to re-order any results " "that do not contain ORDER BY clauses."; static constexpr const char *InputType = "BOOLEAN"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &parameter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "true"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; struct ProduceArrowStringViewSetting { using RETURN_TYPE = bool; static constexpr const char *Name = "produce_arrow_string_view"; static constexpr const char *Description = - "Whether strings should be produced by DuckDB in Utf8View format instead of Utf8"; + "Whether Arrow strings should be produced by DuckDB in Utf8View format instead of Utf8"; static constexpr const char *InputType = "BOOLEAN"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &parameter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "false"; + static constexpr SetScope DefaultScope = SetScope::GLOBAL; }; struct ProfileOutputSetting { @@ -1194,9 +1127,8 @@ struct ScalarSubqueryErrorOnMultipleRowsSetting { static constexpr const char *Description = "When a scalar subquery returns multiple rows - return a random row instead of returning an error."; static constexpr const char *InputType = "BOOLEAN"; - static void SetLocal(ClientContext &context, const Value &parameter); - static void ResetLocal(ClientContext &context); - static Value GetSetting(const ClientContext &context); + static constexpr const char *DefaultValue = "true"; + static constexpr SetScope DefaultScope = SetScope::SESSION; }; struct SchedulerProcessPartialSetting { @@ -1303,16 +1235,6 @@ struct UsernameSetting { static Value GetSetting(const ClientContext &context); };
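The settings hunks above all apply the same refactor: the per-setting SetLocal/SetGlobal, ResetLocal/ResetGlobal, and GetSetting boilerplate is dropped in favour of declarative DefaultValue and DefaultScope constants, plus an optional OnSet validation hook (as on IndexScanPercentageSetting). A minimal before/after sketch with a hypothetical ExampleSetting; it assumes the generic set/reset/get plumbing lives in shared code keyed on these constants:

// Sketch only; ExampleSetting is hypothetical, not part of this patch.
// Context: declarations as they would appear in duckdb/main/settings.hpp.
// Old style: every setting declares its own accessor boilerplate.
struct ExampleSettingOld {
	using RETURN_TYPE = idx_t;
	static constexpr const char *Name = "example_setting";
	static constexpr const char *InputType = "UBIGINT";
	static void SetLocal(ClientContext &context, const Value &parameter);
	static void ResetLocal(ClientContext &context);
	static Value GetSetting(const ClientContext &context);
};
// New style: the default value and scope are plain data; generic code
// performs set/reset/get and calls the optional OnSet hook for validation.
struct ExampleSettingNew {
	using RETURN_TYPE = idx_t;
	static constexpr const char *Name = "example_setting";
	static constexpr const char *InputType = "UBIGINT";
	static constexpr const char *DefaultValue = "42";
	static constexpr SetScope DefaultScope = SetScope::SESSION;
	static void OnSet(SettingCallbackInfo &info, Value &input); // optional
};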
-struct VariantLegacyEncodingSetting { - using RETURN_TYPE = bool; - static constexpr const char *Name = "variant_legacy_encoding"; - static constexpr const char *Description = "Enables the Parquet reader to identify a Variant structurally."; - static constexpr const char *InputType = "BOOLEAN"; - static void SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &parameter); - static void ResetGlobal(DatabaseInstance *db, DBConfig &config); - static Value GetSetting(const ClientContext &context); -}; - struct WalEncryptionSetting { using RETURN_TYPE = bool; static constexpr const char *Name = "wal_encryption"; diff --git a/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp b/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp index 0a2588c6a..552753692 100644 --- a/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/common_table_expression_info.hpp @@ -8,7 +8,7 @@ #pragma once -#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/common/vector.hpp" #include "duckdb/common/enums/cte_materialize.hpp" namespace duckdb { @@ -24,6 +24,8 @@ struct CommonTableExpressionInfo { void Serialize(Serializer &serializer) const; static unique_ptr Deserialize(Deserializer &deserializer); unique_ptr Copy(); + + ~CommonTableExpressionInfo(); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_info.hpp index 25864c3e8..b206e9143 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/alter_table_info.hpp @@ -134,14 +134,17 @@ struct RenameFieldInfo : public AlterTableInfo { RenameFieldInfo(AlterEntryData data, vector column_path, string new_name_p); ~RenameFieldInfo() override; - //! Path to source field + //! Path to source field. vector column_path; - //! Column new name + //! New name of the column (field). string new_name; public: unique_ptr Copy() const override; string ToString() const override; + string GetColumnName() const override { + return column_path[0]; + } void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); @@ -201,16 +204,19 @@ struct AddFieldInfo : public AlterTableInfo { AddFieldInfo(AlterEntryData data, vector column_path, ColumnDefinition new_field, bool if_field_not_exists); ~AddFieldInfo() override; - //! The path to the struct + //! Path to source field. vector column_path; - //! New field to add to the struct + //! New field to add. ColumnDefinition new_field; - //! Whether or not an error should be thrown if the field exist + //! Whether or not an error should be thrown if the field does not exist. bool if_field_not_exists; public: unique_ptr Copy() const override; string ToString() const override; + string GetColumnName() const override { + return column_path[0]; + } void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); @@ -253,16 +259,20 @@ struct RemoveFieldInfo : public AlterTableInfo { RemoveFieldInfo(AlterEntryData data, vector column_path, bool if_column_exists, bool cascade); ~RemoveFieldInfo() override; - //! The path to the field to remove + //! Path to source field. vector column_path; - //! Whether or not an error should be thrown if the column does not exist + //! Whether or not an error should be thrown if the column does not exist. bool if_column_exists; - //! Whether or not the column should be removed if a dependency conflict arises (used by GENERATED columns) + //!
Whether or not the column should be removed if a dependency conflict arises (used by GENERATED columns). bool cascade; public: unique_ptr Copy() const override; string ToString() const override; + string GetColumnName() const override { + return column_path[0]; + } + void Serialize(Serializer &serializer) const override; static unique_ptr Deserialize(Deserializer &deserializer); diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/attach_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/attach_info.hpp index 835dd9a4c..3f3504722 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/attach_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/attach_info.hpp @@ -9,11 +9,10 @@ #pragma once #include "duckdb/parser/parsed_data/parse_info.hpp" -#include "duckdb/common/vector.hpp" #include "duckdb/common/unordered_map.hpp" #include "duckdb/common/types/value.hpp" #include "duckdb/common/enums/on_create_conflict.hpp" -#include "duckdb/storage/storage_options.hpp" +#include "duckdb/parser/parsed_expression.hpp" namespace duckdb { @@ -30,13 +29,13 @@ struct AttachInfo : public ParseInfo { //! The path to the attached database string path; //! Set of (key, value) options + case_insensitive_map_t> parsed_options; + //! Set of bound (key, value) options unordered_map options; //! What to do on create conflict OnCreateConflict on_conflict = OnCreateConflict::ERROR_ON_CONFLICT; public: - //! Returns the storage options - StorageOptions GetStorageOptions() const; //! Copies this AttachInfo and returns an unique pointer to the new AttachInfo. unique_ptr Copy() const; string ToString() const; diff --git a/src/duckdb/src/include/duckdb/parser/parsed_data/copy_info.hpp b/src/duckdb/src/include/duckdb/parser/parsed_data/copy_info.hpp index d3ed005db..b3e290754 100644 --- a/src/duckdb/src/include/duckdb/parser/parsed_data/copy_info.hpp +++ b/src/duckdb/src/include/duckdb/parser/parsed_data/copy_info.hpp @@ -24,8 +24,7 @@ struct CopyInfo : public ParseInfo { static constexpr const ParseInfoType TYPE = ParseInfoType::COPY_INFO; public: - CopyInfo() : ParseInfo(TYPE), catalog(INVALID_CATALOG), schema(DEFAULT_SCHEMA), is_format_auto_detected(true) { - } + CopyInfo(); //! The catalog name to copy to/from string catalog; @@ -44,13 +43,14 @@ struct CopyInfo : public ParseInfo { //! The file path to copy to/from string file_path; //! Set of (key, value) options + case_insensitive_map_t> parsed_options; + //! Set of (key, value) options case_insensitive_map_t> options; //! The SQL statement used instead of a table when copying data out to a file unique_ptr select_statement; public: - static string CopyOptionsToString(const string &format, bool is_format_auto_detected, - const case_insensitive_map_t> &options); + string CopyOptionsToString() const; public: unique_ptr Copy() const; diff --git a/src/duckdb/src/include/duckdb/parser/statement/merge_into_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/merge_into_statement.hpp index 8f69686c1..554cdb126 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/merge_into_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/merge_into_statement.hpp @@ -53,6 +53,9 @@ class MergeIntoStatement : public SQLStatement { map>> actions; + //! keep track of optional returningList if statement contains a RETURNING keyword + vector> returning_list; + //! 
CTEs CommonTableExpressionMap cte_map; diff --git a/src/duckdb/src/include/duckdb/parser/statement/select_statement.hpp b/src/duckdb/src/include/duckdb/parser/statement/select_statement.hpp index d0fd8b4d5..5d7295ca3 100644 --- a/src/duckdb/src/include/duckdb/parser/statement/select_statement.hpp +++ b/src/duckdb/src/include/duckdb/parser/statement/select_statement.hpp @@ -8,15 +8,12 @@ #pragma once -#include "duckdb/common/unordered_map.hpp" -#include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/query_node.hpp" #include "duckdb/parser/sql_statement.hpp" #include "duckdb/parser/tableref.hpp" -#include "duckdb/parser/query_node.hpp" namespace duckdb { -class QueryNode; class Serializer; class Deserializer; diff --git a/src/duckdb/src/include/duckdb/parser/tableref/joinref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/joinref.hpp index 1b9f8f988..6aa6e8015 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/joinref.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/joinref.hpp @@ -43,6 +43,8 @@ class JoinRef : public TableRef { vector> duplicate_eliminated_columns; //! If we have duplicate eliminated columns if the delim is flipped bool delim_flipped = false; + //! Whether or not this is an implicit cross join + bool is_implicit = false; public: string ToString() const override; diff --git a/src/duckdb/src/include/duckdb/parser/tableref/showref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/showref.hpp index fdd98e21b..77a37d444 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/showref.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/showref.hpp @@ -16,7 +16,7 @@ namespace duckdb { -enum class ShowType : uint8_t { SUMMARY, DESCRIBE }; +enum class ShowType : uint8_t { SUMMARY, DESCRIBE, SHOW_FROM }; //! Represents a SHOW/DESCRIBE/SUMMARIZE statement class ShowRef : public TableRef { @@ -28,6 +28,10 @@ class ShowRef : public TableRef { //! The table name (if any) string table_name; + //! The catalog name (if any) + string catalog_name; + //! The schema name (if any) + string schema_name; //! The QueryNode of select query (if any) unique_ptr query; //! Whether or not we are requesting a summary or a describe diff --git a/src/duckdb/src/include/duckdb/parser/tableref/table_function_ref.hpp b/src/duckdb/src/include/duckdb/parser/tableref/table_function_ref.hpp index 51c4fd8e3..c76377991 100644 --- a/src/duckdb/src/include/duckdb/parser/tableref/table_function_ref.hpp +++ b/src/duckdb/src/include/duckdb/parser/tableref/table_function_ref.hpp @@ -12,6 +12,7 @@ #include "duckdb/parser/tableref.hpp" #include "duckdb/common/vector.hpp" #include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/common/enums/ordinality_request_type.hpp" namespace duckdb { //! Represents a Table producing function @@ -27,6 +28,9 @@ class TableFunctionRef : public TableRef { // if the function takes a subquery as argument its in here unique_ptr subquery; + //! 
Whether or not WITH ORDINALITY has been invoked + OrdinalityType with_ordinality = OrdinalityType::WITHOUT_ORDINALITY; + public: string ToString() const override; diff --git a/src/duckdb/src/include/duckdb/parser/transformer.hpp b/src/duckdb/src/include/duckdb/parser/transformer.hpp index 37cb0f20f..cb9ad500a 100644 --- a/src/duckdb/src/include/duckdb/parser/transformer.hpp +++ b/src/duckdb/src/include/duckdb/parser/transformer.hpp @@ -376,8 +376,6 @@ class Transformer { unique_ptr TransformMacroFunction(duckdb_libpgquery::PGFunctionDefinition &function); - void ParseGenericOptionListEntry(case_insensitive_map_t> &result_options, string &name, - duckdb_libpgquery::PGNode *arg); vector TransformNameList(duckdb_libpgquery::PGList &list); public: diff --git a/src/duckdb/src/include/duckdb/planner/bind_context.hpp b/src/duckdb/src/include/duckdb/planner/bind_context.hpp index a9a905b3e..d9c20dd1d 100644 --- a/src/duckdb/src/include/duckdb/planner/bind_context.hpp +++ b/src/duckdb/src/include/duckdb/planner/bind_context.hpp @@ -97,6 +97,9 @@ class BindContext { vector &bound_column_ids, TableCatalogEntry &entry, bool add_row_id = true); void AddBaseTable(idx_t index, const string &alias, const vector &names, const vector &types, vector &bound_column_ids, const string &table_name); + void AddBaseTable(idx_t index, const string &alias, const vector &names, const vector &types, + vector &bound_column_ids, TableCatalogEntry &entry, + virtual_column_map_t virtual_columns); //! Adds a call to a table function with the given alias to the BindContext. void AddTableFunction(idx_t index, const string &alias, const vector &names, const vector &types, vector &bound_column_ids, diff --git a/src/duckdb/src/include/duckdb/planner/binder.hpp b/src/duckdb/src/include/duckdb/planner/binder.hpp index 19bd0a290..ab4bd78b8 100644 --- a/src/duckdb/src/include/duckdb/planner/binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/binder.hpp @@ -58,6 +58,7 @@ class BoundAtClause; struct CreateInfo; struct BoundCreateTableInfo; +struct BoundOnConflictInfo; struct CommonTableExpressionInfo; struct BoundParameterMap; struct BoundPragmaInfo; @@ -65,6 +66,7 @@ struct BoundLimitNode; struct EntryLookupInfo; struct PivotColumnEntry; struct UnpivotEntry; +struct CopyInfo; template class IndexVector; @@ -206,9 +208,6 @@ class Binder : public enable_shared_from_this { TableCatalogEntry &table, vector &columns, vector> &update_expressions, vector> &projection_expressions); - void BindDoUpdateSetExpressions(const string &table_alias, LogicalInsert &insert, UpdateSetInfo &set_info, - TableCatalogEntry &table, TableStorageInfo &storage_info); - void BindOnConflictClause(LogicalInsert &insert, TableCatalogEntry &table, InsertStatement &stmt); void BindVacuumTable(LogicalVacuum &vacuum, unique_ptr &root); @@ -278,12 +277,18 @@ class Binder : public enable_shared_from_this { idx_t unnamed_subquery_index = 1; //! Statement properties StatementProperties prop; + //! Root binder + Binder &root_binder; + //! Binder depth + idx_t depth; private: //! Get the root binder (binder with no parent) Binder &GetRootBinder(); //! Determine the depth of the binder idx_t GetBinderDepth() const; + //! Increase the depth of the binder + void IncreaseDepth(); //! Bind the expressions of generated columns to check for errors void BindGeneratedColumns(BoundCreateTableInfo &info); //! 
Bind the default values of the columns of a table @@ -333,7 +338,8 @@ class Binder : public enable_shared_from_this { void BindRowIdColumns(TableCatalogEntry &table, LogicalGet &get, vector> &expressions); BoundStatement BindReturning(vector> returning_list, TableCatalogEntry &table, const string &alias, idx_t update_table_index, - unique_ptr child_operator, BoundStatement result); + unique_ptr child_operator, + virtual_column_map_t virtual_columns = virtual_column_map_t()); unique_ptr BindTableMacro(FunctionExpression &function, TableMacroCatalogEntry &macro_func, idx_t depth); @@ -403,6 +409,7 @@ class Binder : public enable_shared_from_this { BoundStatement BindCopyTo(CopyStatement &stmt, CopyToType copy_to_type); BoundStatement BindCopyFrom(CopyStatement &stmt); + void BindCopyOptions(CopyInfo &info); void PrepareModifiers(OrderBinder &order_binder, QueryNode &statement, BoundQueryNode &result); void BindModifiers(BoundQueryNode &result, idx_t table_index, const vector &names, @@ -467,6 +474,9 @@ class Binder : public enable_shared_from_this { vector &named_column_map, vector &expected_types, IndexVector &column_index_map); void TryReplaceDefaultExpression(unique_ptr &expr, const ColumnDefinition &column); + void ExpandDefaultInValuesList(InsertStatement &stmt, TableCatalogEntry &table, + optional_ptr values_list, + const vector &named_column_map); unique_ptr BindMergeAction(LogicalMergeInto &merge_into, TableCatalogEntry &table, LogicalGet &get, idx_t proj_index, vector> &expressions, @@ -474,6 +484,8 @@ class Binder : public enable_shared_from_this { const vector &source_aliases, const vector &source_names); + unique_ptr GenerateMergeInto(InsertStatement &stmt, TableCatalogEntry &table); + private: Binder(ClientContext &context, shared_ptr parent, BinderType binder_type); }; diff --git a/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp b/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp index b7c263a92..bd75aac19 100644 --- a/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp +++ b/src/duckdb/src/include/duckdb/planner/bound_tokens.hpp @@ -34,6 +34,7 @@ class BoundConjunctionExpression; class BoundConstantExpression; class BoundDefaultExpression; class BoundFunctionExpression; +class BoundLambdaRefExpression; class BoundOperatorExpression; class BoundParameterExpression; class BoundReferenceExpression; diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder.hpp index f180ed25d..601b9b652 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder.hpp @@ -147,13 +147,6 @@ class ExpressionBinder { //! Enables special-handling of lambda parameters during macro replacement by tracking them in the lambda_params //! vector. void ReplaceMacroParametersInLambda(FunctionExpression &function, vector> &lambda_params); - //! Recursively qualifies column references in ON CONFLICT DO UPDATE SET expressions. - void DoUpdateSetQualify(unique_ptr &expr, const string &table_name, - vector> &lambda_params); - //! Enables special-handling of lambda parameters during ON CONFLICT TO UPDATE SET qualification by tracking them in - //! the lambda_params vector.
- void DoUpdateSetQualifyInLambda(FunctionExpression &function, const string &table_name, - vector> &lambda_params); static LogicalType GetExpressionReturnType(const Expression &expr); diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/projection_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/projection_binder.hpp index 66451cc8b..8d2a9b3e6 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/projection_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/projection_binder.hpp @@ -18,7 +18,7 @@ class ColumnAliasBinder; class ProjectionBinder : public ExpressionBinder { public: ProjectionBinder(Binder &binder, ClientContext &context, idx_t proj_index, - vector> &proj_expressions); + vector> &proj_expressions, string clause); protected: BindResult BindExpression(unique_ptr &expr_ptr, idx_t depth, @@ -31,6 +31,7 @@ class ProjectionBinder : public ExpressionBinder { private: idx_t proj_index; vector> &proj_expressions; + string clause; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/expression_binder/table_function_binder.hpp b/src/duckdb/src/include/duckdb/planner/expression_binder/table_function_binder.hpp index 9e14a852e..cc20dd5ed 100644 --- a/src/duckdb/src/include/duckdb/planner/expression_binder/table_function_binder.hpp +++ b/src/duckdb/src/include/duckdb/planner/expression_binder/table_function_binder.hpp @@ -15,7 +15,8 @@ namespace duckdb { //! The table function binder can bind standard table function parameters (i.e., non-table-in-out functions) class TableFunctionBinder : public ExpressionBinder { public: - TableFunctionBinder(Binder &binder, ClientContext &context, string table_function_name = string()); + TableFunctionBinder(Binder &binder, ClientContext &context, string table_function_name = string(), + string clause = "Table function"); protected: BindResult BindLambdaReference(LambdaRefExpression &expr, idx_t depth); @@ -26,6 +27,7 @@ class TableFunctionBinder : public ExpressionBinder { private: string table_function_name; + string clause; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_get.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_get.hpp index 395224bc8..aebc52b96 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_get.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_get.hpp @@ -58,6 +58,8 @@ class LogicalGet : public LogicalOperator { ExtraOperatorInfo extra_info; //! Contains a reference to dynamically generated table filters (through e.g. a join up in the tree) shared_ptr dynamic_filters; + //! Information for WITH ORDINALITY + optional_idx ordinality_idx; string GetName() const override; InsertionOrderPreservingMap ParamsToString() const override; diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_insert.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_insert.hpp index c20c33a54..e184b135c 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_insert.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_insert.hpp @@ -17,28 +17,8 @@ class TableCatalogEntry; class Index; -//! LogicalInsert represents an insertion of data into a base table -class LogicalInsert : public LogicalOperator { -public: - static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_INSERT; - -public: - LogicalInsert(TableCatalogEntry &table, idx_t table_index); - - vector>> insert_values; - //! 
The insertion map ([table_index -> index in result, or DConstants::INVALID_INDEX if not specified]) - physical_index_vector_t column_index_map; - //! The expected types for the INSERT statement (obtained from the column types) - vector expected_types; - //! The base table to insert into - TableCatalogEntry &table; - idx_t table_index; - //! if returning option is used, return actual chunk to projection - bool return_chunk; - //! The default statements used by the table - vector> bound_defaults; - //! The constraints used by the table - vector> bound_constraints; +struct BoundOnConflictInfo { + BoundOnConflictInfo(); //! Which action to take on conflict OnConflictAction action_type; @@ -62,6 +42,32 @@ class LogicalInsert : public LogicalOperator { vector source_columns; //! True, if the INSERT OR REPLACE requires delete + insert. bool update_is_del_and_insert; +}; + +//! LogicalInsert represents an insertion of data into a base table +class LogicalInsert : public LogicalOperator { +public: + static constexpr const LogicalOperatorType TYPE = LogicalOperatorType::LOGICAL_INSERT; + +public: + LogicalInsert(TableCatalogEntry &table, idx_t table_index); + + vector>> insert_values; + //! The insertion map ([table_index -> index in result, or DConstants::INVALID_INDEX if not specified]) + physical_index_vector_t column_index_map; + //! The expected types for the INSERT statement (obtained from the column types) + vector expected_types; + //! The base table to insert into + TableCatalogEntry &table; + idx_t table_index; + //! if returning option is used, return actual chunk to projection + bool return_chunk; + //! The default statements used by the table + vector> bound_defaults; + //! The constraints used by the table + vector> bound_constraints; + + BoundOnConflictInfo on_conflict_info; public: void Serialize(Serializer &serializer) const override; diff --git a/src/duckdb/src/include/duckdb/planner/operator/logical_merge_into.hpp b/src/duckdb/src/include/duckdb/planner/operator/logical_merge_into.hpp index 959b66ffe..657fec212 100644 --- a/src/duckdb/src/include/duckdb/planner/operator/logical_merge_into.hpp +++ b/src/duckdb/src/include/duckdb/planner/operator/logical_merge_into.hpp @@ -52,6 +52,8 @@ class LogicalMergeInto : public LogicalOperator { optional_idx source_marker; //! Bound constraints vector> bound_constraints; + //! Whether or not to return the input data + bool return_chunk = false; map>> actions; @@ -60,6 +62,7 @@ class LogicalMergeInto : public LogicalOperator { static unique_ptr Deserialize(Deserializer &deserializer); idx_t EstimateCardinality(ClientContext &context) override; + vector GetTableIndex() const override; protected: vector GetColumnBindings() override; diff --git a/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp b/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp index ecc9964ad..619e89a5a 100644 --- a/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/buffer_manager.hpp @@ -129,7 +129,7 @@ class BufferManager { //! Write a temporary file buffer. virtual void WriteTemporaryBuffer(MemoryTag tag, block_id_t block_id, FileBuffer &buffer); //! Read a temporary buffer. - virtual unique_ptr ReadTemporaryBuffer(MemoryTag tag, BlockHandle &block, + virtual unique_ptr ReadTemporaryBuffer(QueryContext context, MemoryTag tag, BlockHandle &block, unique_ptr buffer); //! Delete the temporary file containing the block. virtual void DeleteTemporaryFile(BlockHandle &block);
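The logical_insert.hpp hunk above extracts every ON CONFLICT related field out of LogicalInsert into the new BoundOnConflictInfo struct, which LogicalInsert now carries as on_conflict_info. A hedged sketch of how a call site would change; the member names come from the hunk, while the surrounding function is hypothetical:

// Hypothetical inspection helper; only the member names are from this patch.
// Assumes duckdb/planner/operator/logical_insert.hpp is included.
void InspectOnConflict(LogicalInsert &insert) {
	// Before: the fields lived directly on LogicalInsert (insert.action_type, ...).
	// After: the same state is grouped under insert.on_conflict_info.
	auto &info = insert.on_conflict_info;
	if (info.action_type != OnConflictAction::THROW) {
		// INSERT ... ON CONFLICT DO NOTHING / DO UPDATE, or INSERT OR REPLACE
		bool del_and_insert = info.update_is_del_and_insert;
		(void)del_and_insert; // placeholder for real handling
	}
}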
diff --git a/src/duckdb/src/include/duckdb/storage/caching_file_system.hpp b/src/duckdb/src/include/duckdb/storage/caching_file_system.hpp index adef8a924..99e949418 100644 --- a/src/duckdb/src/include/duckdb/storage/caching_file_system.hpp +++ b/src/duckdb/src/include/duckdb/storage/caching_file_system.hpp @@ -12,12 +12,14 @@ #include "duckdb/common/file_open_flags.hpp" #include "duckdb/common/open_file_info.hpp" #include "duckdb/common/shared_ptr.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/storage/storage_lock.hpp" #include "duckdb/storage/external_file_cache.hpp" namespace duckdb { class ClientContext; +class QueryContext; class BufferHandle; class FileOpenFlags; class FileSystem; @@ -31,8 +33,8 @@ struct CachingFileHandle { using CachedFile = ExternalFileCache::CachedFile; public: - DUCKDB_API CachingFileHandle(CachingFileSystem &caching_file_system, const OpenFileInfo &path, FileOpenFlags flags, - CachedFile &cached_file); + DUCKDB_API CachingFileHandle(QueryContext context, CachingFileSystem &caching_file_system, const OpenFileInfo &path, + FileOpenFlags flags, CachedFile &cached_file); DUCKDB_API ~CachingFileHandle(); public: @@ -73,6 +75,8 @@ struct CachingFileHandle { idx_t location, bool actually_read); private: + QueryContext context; + //! The client caching file system that was used to create this CachingFileHandle CachingFileSystem &caching_file_system; //! The DB external file cache @@ -111,6 +115,8 @@ class CachingFileSystem { DUCKDB_API static CachingFileSystem Get(ClientContext &context); DUCKDB_API unique_ptr OpenFile(const OpenFileInfo &path, FileOpenFlags flags); + DUCKDB_API unique_ptr OpenFile(QueryContext context, const OpenFileInfo &path, + FileOpenFlags flags); private: //! The Client FileSystem (needs to be client-specific so we can do, e.g., HTTPFS profiling)
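buffer_manager.hpp and caching_file_system.hpp apply the same mechanical change: a QueryContext is threaded through the I/O entry points (a new first parameter on ReadTemporaryBuffer above, a new OpenFile overload here) so reads can be attributed to the query that issued them. A sketch of a caller, assuming QueryContext can be constructed from a ClientContext as its other uses in this patch suggest; the helper and its arguments are illustrative:

// Hypothetical caller; only CachingFileSystem::Get and the new OpenFile
// overload come from this patch.
unique_ptr<CachingFileHandle> OpenForQuery(ClientContext &context, const string &path) {
	auto caching_fs = CachingFileSystem::Get(context);
	// The QueryContext rides along so the read can be profiled and
	// attributed per query rather than only per database instance.
	return caching_fs.OpenFile(QueryContext(context), OpenFileInfo(path),
	                           FileFlags::FILE_FLAGS_READ);
}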
diff --git a/src/duckdb/src/include/duckdb/storage/checkpoint/row_group_writer.hpp b/src/duckdb/src/include/duckdb/storage/checkpoint/row_group_writer.hpp index d52313902..f86632c1b 100644 --- a/src/duckdb/src/include/duckdb/storage/checkpoint/row_group_writer.hpp +++ b/src/duckdb/src/include/duckdb/storage/checkpoint/row_group_writer.hpp @@ -9,6 +9,7 @@ #pragma once #include "duckdb/storage/checkpoint_manager.hpp" +#include "duckdb/common/serializer/memory_stream.hpp" namespace duckdb { struct ColumnCheckpointState; @@ -22,16 +23,22 @@ class SegmentStatistics; // Writes data for an entire row group. class RowGroupWriter { public: - RowGroupWriter(TableCatalogEntry &table, PartialBlockManager &partial_block_manager) - : table(table), partial_block_manager(partial_block_manager) { - } + RowGroupWriter(TableCatalogEntry &table, PartialBlockManager &partial_block_manager); virtual ~RowGroupWriter() { } - CompressionType GetColumnCompressionType(idx_t i); + const vector &GetCompressionTypes() const { + return compression_types; + } virtual CheckpointType GetCheckpointType() const = 0; - virtual MetadataWriter &GetPayloadWriter() = 0; + virtual WriteStream &GetPayloadWriter() = 0; + virtual MetaBlockPointer GetMetaBlockPointer() = 0; + virtual optional_ptr GetMetadataManager() = 0; + virtual void StartWritingColumns(vector &column_metadata) { + } + virtual void FinishWritingColumns() { + } PartialBlockManager &GetPartialBlockManager() { return partial_block_manager; @@ -40,6 +47,7 @@ class RowGroupWriter { protected: TableCatalogEntry &table; PartialBlockManager &partial_block_manager; + vector compression_types; }; // Writes data for an entire row group. @@ -50,7 +58,11 @@ class SingleFileRowGroupWriter : public RowGroupWriter { public: CheckpointType GetCheckpointType() const override; - MetadataWriter &GetPayloadWriter() override; + WriteStream &GetPayloadWriter() override; + MetaBlockPointer GetMetaBlockPointer() override; + optional_ptr GetMetadataManager() override; + void StartWritingColumns(vector &column_metadata) override; + void FinishWritingColumns() override; private: //! Underlying writer object diff --git a/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_reader.hpp b/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_reader.hpp index 20ca6d4d3..ebf8672e5 100644 --- a/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_reader.hpp +++ b/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_reader.hpp @@ -16,7 +16,7 @@ struct BoundCreateTableInfo; //!
The table data reader is responsible for reading the data of a table from the block manager class TableDataReader { public: - TableDataReader(MetadataReader &reader, BoundCreateTableInfo &info); + TableDataReader(MetadataReader &reader, BoundCreateTableInfo &info, MetaBlockPointer table_pointer); void ReadTableData(); diff --git a/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_writer.hpp b/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_writer.hpp index 4abc37c66..0860eb7c9 100644 --- a/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_writer.hpp +++ b/src/duckdb/src/include/duckdb/storage/checkpoint/table_data_writer.hpp @@ -29,11 +29,13 @@ class TableDataWriter { public: void WriteTableData(Serializer &metadata_serializer); + virtual void WriteUnchangedTable(MetaBlockPointer pointer, idx_t total_rows) = 0; virtual void FinalizeTable(const TableStatistics &global_stats, DataTableInfo *info, Serializer &serializer) = 0; virtual unique_ptr GetRowGroupWriter(RowGroup &row_group) = 0; virtual void AddRowGroup(RowGroupPointer &&row_group_pointer, unique_ptr writer); virtual CheckpointType GetCheckpointType() const = 0; + virtual MetadataManager &GetMetadataManager() = 0; DatabaseInstance &GetDatabase(); unique_ptr CreateTaskExecutor(); @@ -51,14 +53,20 @@ class SingleFileTableDataWriter : public TableDataWriter { MetadataWriter &table_data_writer); public: + void WriteUnchangedTable(MetaBlockPointer pointer, idx_t total_rows) override; void FinalizeTable(const TableStatistics &global_stats, DataTableInfo *info, Serializer &serializer) override; unique_ptr GetRowGroupWriter(RowGroup &row_group) override; CheckpointType GetCheckpointType() const override; + MetadataManager &GetMetadataManager() override; private: SingleFileCheckpointWriter &checkpoint_manager; //! Writes the actual table data MetadataWriter &table_data_writer; + //! The root pointer, if we are re-using metadata of the table + MetaBlockPointer existing_pointer; + optional_idx existing_rows; + vector existing_pointers; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/checkpoint_manager.hpp b/src/duckdb/src/include/duckdb/storage/checkpoint_manager.hpp index 374547e4b..318d93abf 100644 --- a/src/duckdb/src/include/duckdb/storage/checkpoint_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/checkpoint_manager.hpp @@ -33,6 +33,7 @@ class CheckpointWriter { //! The database AttachedDatabase &db; + virtual void CreateCheckpoint() = 0; virtual MetadataManager &GetMetadataManager() = 0; virtual MetadataWriter &GetMetadataWriter() = 0; virtual unique_ptr GetTableDataWriter(TableCatalogEntry &table) = 0; @@ -100,9 +101,7 @@ class SingleFileCheckpointWriter final : public CheckpointWriter { SingleFileCheckpointWriter(QueryContext context, AttachedDatabase &db, BlockManager &block_manager, CheckpointType checkpoint_type); - //! Checkpoint the current state of the WAL and flush it to the main storage. This should be called BEFORE any - //! connection is available because right now the checkpointing cannot be done online. 
(TODO) - void CreateCheckpoint(); + void CreateCheckpoint() override; MetadataWriter &GetMetadataWriter() override; MetadataManager &GetMetadataManager() override; diff --git a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/packed_data.hpp b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/packed_data.hpp index 932719ded..dad35baaa 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/packed_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/chimp/algorithm/packed_data.hpp @@ -1,14 +1,16 @@ //===----------------------------------------------------------------------===// // DuckDB // -// duckdb/storage/compression/chimp/packed_data.hpp +// duckdb/storage/compression/chimp/algorithm/packed_data.hpp // // //===----------------------------------------------------------------------===// #pragma once +#include "duckdb/common/common.hpp" #include "duckdb/storage/compression/chimp/algorithm/chimp_utils.hpp" +#include "duckdb/common/assert.hpp" #include "duckdb.h" namespace duckdb { diff --git a/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/compression.hpp b/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/compression.hpp index 4e5df07b9..1bc613a0a 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/compression.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/dict_fsst/compression.hpp @@ -47,7 +47,7 @@ struct DictFSSTCompressionState : public CompressionState { DictionaryAppendState TryEncode(); bool CompressInternal(UnifiedVectorFormat &vector_format, const string_t &str, bool is_null, - EncodedInput &encoded_input, const idx_t i, idx_t count); + EncodedInput &encoded_input, const idx_t i, idx_t count, bool fail_on_no_space); void Compress(Vector &scan_vector, idx_t count); void FinalizeCompress(); void Flush(bool final); diff --git a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp index 05089982e..b523600e3 100644 --- a/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp +++ b/src/duckdb/src/include/duckdb/storage/compression/patas/patas_scan.hpp @@ -1,7 +1,7 @@ //===----------------------------------------------------------------------===// // DuckDB // -// duckdb/common/storage/compression/chimp/chimp_scan.hpp +// duckdb/common/storage/compression/patas/patas_scan.hpp // // //===----------------------------------------------------------------------===// @@ -10,6 +10,7 @@ #include "duckdb/storage/compression/chimp/chimp.hpp" #include "duckdb/storage/compression/chimp/algorithm/chimp_utils.hpp" +#include "duckdb/storage/compression/chimp/algorithm/packed_data.hpp" #include "duckdb/storage/compression/chimp/algorithm/byte_reader.hpp" #include "duckdb/storage/compression/patas/shared.hpp" #include "duckdb/storage/compression/patas/algorithm/patas.hpp" diff --git a/src/duckdb/src/include/duckdb/storage/data_pointer.hpp b/src/duckdb/src/include/duckdb/storage/data_pointer.hpp index 43f1a1ee1..0ae44d7f3 100644 --- a/src/duckdb/src/include/duckdb/storage/data_pointer.hpp +++ b/src/duckdb/src/include/duckdb/storage/data_pointer.hpp @@ -71,6 +71,11 @@ struct RowGroupPointer { vector data_pointers; //! Data pointers to the delete information of the row group (if any) vector deletes_pointers; + //! Whether or not we have all metadata blocks defined in the pointer + bool has_metadata_blocks = false; + //! 
Metadata blocks of the columns that are not mentioned in "data_pointers" + //! This is often empty - but can be set for wide columns with a lot of metadata + vector extra_metadata_blocks; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/data_table.hpp b/src/duckdb/src/include/duckdb/storage/data_table.hpp index 8149e0f61..f81d18e48 100644 --- a/src/duckdb/src/include/duckdb/storage/data_table.hpp +++ b/src/duckdb/src/include/duckdb/storage/data_table.hpp @@ -247,7 +247,7 @@ class DataTable : public enable_shared_from_this { shared_ptr &GetDataTableInfo(); - void InitializeIndexes(ClientContext &context); + void BindIndexes(ClientContext &context); bool HasIndexes() const; bool HasUniqueIndexes() const; bool HasForeignKeyIndex(const vector &keys, ForeignKeyType type); @@ -263,9 +263,9 @@ class DataTable : public enable_shared_from_this { idx_t GetRowGroupSize() const; - static void VerifyUniqueIndexes(TableIndexList &indexes, optional_ptr storage, DataChunk &chunk, - optional_ptr manager); - + //! Verify any unique indexes using optional delete indexes in the local storage. + void VerifyUniqueIndexes(TableIndexList &indexes, optional_ptr storage, DataChunk &chunk, + optional_ptr manager); //! AddIndex initializes an index and adds it to the table's index list. //! It is either empty, or initialized via its index storage information. void AddIndex(const ColumnList &columns, const vector &column_indexes, const IndexConstraintType type, diff --git a/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp b/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp index 28110eceb..ee39fee68 100644 --- a/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/metadata/metadata_manager.hpp @@ -11,6 +11,7 @@ #include "duckdb/common/common.hpp" #include "duckdb/storage/block.hpp" #include "duckdb/storage/block_manager.hpp" +#include "duckdb/common/atomic.hpp" #include "duckdb/common/set.hpp" #include "duckdb/storage/buffer/buffer_handle.hpp" @@ -19,9 +20,18 @@ class DatabaseInstance; struct MetadataBlockInfo; struct MetadataBlock { + MetadataBlock(); + // disable copy constructors + MetadataBlock(const MetadataBlock &other) = delete; + MetadataBlock &operator=(const MetadataBlock &) = delete; + //! enable move constructors + DUCKDB_API MetadataBlock(MetadataBlock &&other) noexcept; + DUCKDB_API MetadataBlock &operator=(MetadataBlock &&) noexcept; + shared_ptr block; block_id_t block_id; vector free_blocks; + atomic dirty; void Write(WriteStream &sink); static MetadataBlock Read(ReadStream &source); diff --git a/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp b/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp index 1acb60fe2..6ccc2d645 100644 --- a/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp +++ b/src/duckdb/src/include/duckdb/storage/metadata/metadata_reader.hpp @@ -32,6 +32,9 @@ class MetadataReader : public ReadStream { MetadataManager &GetMetadataManager() { return manager; } + //! Gets a list of all remaining blocks to be read by this metadata reader - consumes all blocks + //! 
If "last_block" is specified, we stop when reaching that block + vector GetRemainingBlocks(MetaBlockPointer last_block = MetaBlockPointer()); private: data_ptr_t BasePtr(); diff --git a/src/duckdb/src/include/duckdb/storage/metadata/metadata_writer.hpp b/src/duckdb/src/include/duckdb/storage/metadata/metadata_writer.hpp index 206451b77..4bb0f9804 100644 --- a/src/duckdb/src/include/duckdb/storage/metadata/metadata_writer.hpp +++ b/src/duckdb/src/include/duckdb/storage/metadata/metadata_writer.hpp @@ -30,6 +30,7 @@ class MetadataWriter : public WriteStream { MetadataManager &GetManager() { return manager; } + void SetWrittenPointers(optional_ptr> written_pointers); protected: virtual MetadataHandle NextHandle(); diff --git a/src/duckdb/src/include/duckdb/storage/partial_block_manager.hpp b/src/duckdb/src/include/duckdb/storage/partial_block_manager.hpp index 8bd67e239..4e901fb0f 100644 --- a/src/duckdb/src/include/duckdb/storage/partial_block_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/partial_block_manager.hpp @@ -60,6 +60,7 @@ struct PartialBlock { public: //! Add regions that need zero-initialization to avoid leaking memory void AddUninitializedRegion(const idx_t start, const idx_t end); + virtual void AddSegmentToTail(ColumnData &data, ColumnSegment &segment, uint32_t offset_in_block); //! Flush the block to disk and zero-initialize any free space and uninitialized regions virtual void Flush(QueryContext context, const idx_t free_space_left) = 0; void FlushInternal(const idx_t free_space_left); @@ -85,7 +86,7 @@ struct PartialBlockAllocation { unique_ptr partial_block; }; -enum class PartialBlockType { FULL_CHECKPOINT, APPEND_TO_TABLE }; +enum class PartialBlockType { FULL_CHECKPOINT, APPEND_TO_TABLE, IN_MEMORY_CHECKPOINT }; //! Enables sharing blocks across some scope. Scope is whatever we want to share //! blocks across. It may be an entire checkpoint or just a single row group. @@ -122,6 +123,9 @@ class PartialBlockManager { //! Flush any remaining partial blocks to disk void FlushPartialBlocks(); + unique_ptr CreatePartialBlock(ColumnData &data, ColumnSegment &segment, PartialBlockState state, + BlockManager &block_manager); + unique_lock GetLock() { return unique_lock(partial_block_lock); } diff --git a/src/duckdb/src/include/duckdb/storage/single_file_block_manager.hpp b/src/duckdb/src/include/duckdb/storage/single_file_block_manager.hpp index e3ac95c0a..0765d041c 100644 --- a/src/duckdb/src/include/duckdb/storage/single_file_block_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/single_file_block_manager.hpp @@ -65,7 +65,7 @@ class SingleFileBlockManager : public BlockManager { void CreateNewDatabase(QueryContext context); //! Loads an existing database. We pass the provided block allocation size as a parameter //! to detect inconsistencies with the file header. - void LoadExistingDatabase(); + void LoadExistingDatabase(QueryContext context); //! 
Creates a new Block using the specified block_id and returns a pointer unique_ptr ConvertBlock(block_id_t block_id, FileBuffer &source_buffer) override; @@ -124,7 +124,8 @@ class SingleFileBlockManager : public BlockManager { void CheckChecksum(FileBuffer &block, uint64_t location, uint64_t delta, bool skip_block_header = false) const; void CheckChecksum(data_ptr_t start_ptr, uint64_t delta, bool skip_block_header = false) const; - void ReadAndChecksum(FileBuffer &handle, uint64_t location, bool skip_block_header = false) const; + void ReadAndChecksum(QueryContext context, FileBuffer &handle, uint64_t location, + bool skip_block_header = false) const; void ChecksumAndWrite(QueryContext context, FileBuffer &handle, uint64_t location, bool skip_block_header = false) const; diff --git a/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp b/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp index 4636a85dc..cadb346fb 100644 --- a/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/standard_buffer_manager.hpp @@ -135,7 +135,7 @@ class StandardBufferManager : public BufferManager { //! Write a temporary buffer to disk void WriteTemporaryBuffer(MemoryTag tag, block_id_t block_id, FileBuffer &buffer) final; //! Read a temporary buffer from disk - unique_ptr ReadTemporaryBuffer(MemoryTag tag, BlockHandle &block, + unique_ptr ReadTemporaryBuffer(QueryContext context, MemoryTag tag, BlockHandle &block, unique_ptr buffer = nullptr) final; //! Get the path of the temporary buffer string GetTemporaryPath(block_id_t id); diff --git a/src/duckdb/src/include/duckdb/storage/storage_extension.hpp b/src/duckdb/src/include/duckdb/storage/storage_extension.hpp index 62d5a8028..7770df5fc 100644 --- a/src/duckdb/src/include/duckdb/storage/storage_extension.hpp +++ b/src/duckdb/src/include/duckdb/storage/storage_extension.hpp @@ -25,10 +25,10 @@ struct StorageExtensionInfo { } }; -typedef unique_ptr (*attach_function_t)(StorageExtensionInfo *storage_info, ClientContext &context, - AttachedDatabase &db, const string &name, AttachInfo &info, - AccessMode access_mode); -typedef unique_ptr (*create_transaction_manager_t)(StorageExtensionInfo *storage_info, +typedef unique_ptr (*attach_function_t)(optional_ptr storage_info, + ClientContext &context, AttachedDatabase &db, const string &name, + AttachInfo &info, AttachOptions &options); +typedef unique_ptr (*create_transaction_manager_t)(optional_ptr storage_info, AttachedDatabase &db, Catalog &catalog); class StorageExtension { diff --git a/src/duckdb/src/include/duckdb/storage/storage_info.hpp b/src/duckdb/src/include/duckdb/storage/storage_info.hpp index c08f97115..f7a078330 100644 --- a/src/duckdb/src/include/duckdb/storage/storage_info.hpp +++ b/src/duckdb/src/include/duckdb/storage/storage_info.hpp @@ -16,6 +16,7 @@ namespace duckdb { struct FileHandle; +class QueryContext; //! The standard row group size #define DEFAULT_ROW_GROUP_SIZE 122880ULL @@ -26,10 +27,6 @@ struct FileHandle; //! The default block allocation size. #define DEFAULT_BLOCK_ALLOC_SIZE 262144ULL -//! The configurable block allocation size. -#ifndef DUCKDB_BLOCK_ALLOC_SIZE -#define DUCKDB_BLOCK_ALLOC_SIZE DEFAULT_BLOCK_ALLOC_SIZE -#endif //! The default block header size. #define DEFAULT_BLOCK_HEADER_STORAGE_SIZE 8ULL //! The default block header size. 
@@ -110,7 +107,7 @@ struct MainHeader { static constexpr uint64_t AES_IV_LEN = 16; static constexpr uint64_t AES_TAG_LEN = 16; - static void CheckMagicBytes(FileHandle &handle); + static void CheckMagicBytes(QueryContext context, FileHandle &handle); string LibraryGitDesc() { return string(char_ptr_cast(library_git_desc), 0, MAX_VERSION_SIZE); @@ -197,8 +194,5 @@ struct DatabaseHeader { #if (DEFAULT_BLOCK_ALLOC_SIZE & (DEFAULT_BLOCK_ALLOC_SIZE - 1) != 0) #error The default block allocation size must be a power of two #endif -#if (DUCKDB_BLOCK_ALLOC_SIZE & (DUCKDB_BLOCK_ALLOC_SIZE - 1) != 0) -#error The duckdb block allocation size must be a power of two -#endif } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/storage_manager.hpp b/src/duckdb/src/include/duckdb/storage/storage_manager.hpp index fa9b3b40e..ea60405b5 100644 --- a/src/duckdb/src/include/duckdb/storage/storage_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/storage_manager.hpp @@ -61,7 +61,7 @@ struct CheckpointOptions { //! StorageManager is responsible for managing the physical storage of a persistent database. class StorageManager { public: - StorageManager(AttachedDatabase &db, string path, bool read_only); + StorageManager(AttachedDatabase &db, string path, const AttachOptions &options); virtual ~StorageManager(); public: @@ -71,10 +71,10 @@ class StorageManager { //! Initialize a database or load an existing database from the database file path. The block_alloc_size is //! either set, or invalid. If invalid, then DuckDB defaults to the default_block_alloc_size (DBConfig), //! or the file's block allocation size, if it is an existing database. - void Initialize(QueryContext context, StorageOptions &options); + void Initialize(QueryContext context); DatabaseInstance &GetDatabase(); - AttachedDatabase &GetAttached() { + AttachedDatabase &GetAttached() const { return db; } @@ -93,8 +93,8 @@ class StorageManager { return load_complete; } //! The path to the WAL, derived from the database file path - string GetWALPath(); - bool InMemory(); + string GetWALPath() const; + bool InMemory() const; virtual bool AutomaticCheckpoint(idx_t estimated_wal_bytes) = 0; virtual unique_ptr GenStorageCommitState(WriteAheadLog &wal) = 0; @@ -115,9 +115,18 @@ class StorageManager { D_ASSERT(HasStorageVersion()); return storage_version.GetIndex(); } + void AddInMemoryChange(idx_t size) { + in_memory_change_size += size; + } + void ResetInMemoryChange() { + in_memory_change_size = 0; + } + bool CompressionIsEnabled() const { + return storage_options.compress_in_memory == CompressInMemory::COMPRESS; + } protected: - virtual void LoadDatabase(QueryContext context, StorageOptions &options) = 0; + virtual void LoadDatabase(QueryContext context) = 0; protected: //! The attached database managed by this storage manager. @@ -133,6 +142,10 @@ class StorageManager { bool load_complete = false; //! The serialization compatibility version when reading and writing from this database optional_idx storage_version; + //! Estimated size of changes for determining automatic checkpointing on in-memory databases + atomic in_memory_change_size; + //! 
Storage options passed in through configuration + StorageOptions storage_options; public: template @@ -151,7 +164,7 @@ class StorageManager { class SingleFileStorageManager : public StorageManager { public: SingleFileStorageManager() = delete; - SingleFileStorageManager(AttachedDatabase &db, string path, bool read_only); + SingleFileStorageManager(AttachedDatabase &db, string path, const AttachOptions &options); //! The BlockManager to read from and write to blocks (meta data and data). unique_ptr block_manager; @@ -169,6 +182,8 @@ class SingleFileStorageManager : public StorageManager { BlockManager &GetBlockManager() override; protected: - void LoadDatabase(QueryContext context, StorageOptions &options) override; + void LoadDatabase(QueryContext context) override; + + unique_ptr CreateCheckpointWriter(QueryContext context, CheckpointOptions options); }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/storage_options.hpp b/src/duckdb/src/include/duckdb/storage/storage_options.hpp index 61c99151b..a3ff65acc 100644 --- a/src/duckdb/src/include/duckdb/storage/storage_options.hpp +++ b/src/duckdb/src/include/duckdb/storage/storage_options.hpp @@ -10,9 +10,12 @@ #include "duckdb/common/common.hpp" #include "duckdb/common/optional_idx.hpp" +#include "duckdb/common/types/value.hpp" namespace duckdb { +enum class CompressInMemory { AUTOMATIC, COMPRESS, DO_NOT_COMPRESS }; + struct StorageOptions { //! The allocation size of blocks for this attached database file (if any) optional_idx block_alloc_size; @@ -23,6 +26,8 @@ struct StorageOptions { //! Block header size (only used for encryption) optional_idx block_header_size; + CompressInMemory compress_in_memory = CompressInMemory::AUTOMATIC; + //! Whether the database is encrypted bool encryption = false; //! Encryption algorithm (default = GCM) @@ -30,6 +35,8 @@ struct StorageOptions { //! encryption key //! 
FIXME: change to a unique_ptr in the future shared_ptr user_key; + + void Initialize(const unordered_map &options); }; inline void ClearUserKey(shared_ptr const &encryption_key) { diff --git a/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp index 05d48e9c0..abc9577a3 100644 --- a/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/array_column_data.hpp @@ -61,6 +61,7 @@ class ArrayColumnData : public ColumnData { unique_ptr Checkpoint(RowGroup &row_group, ColumnCheckpointInfo &info) override; bool IsPersistent() override; + bool HasAnyChanges() const override; PersistentColumnData Serialize() override; void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; diff --git a/src/duckdb/src/include/duckdb/storage/table/column_checkpoint_state.hpp b/src/duckdb/src/include/duckdb/storage/table/column_checkpoint_state.hpp index 25504a3e8..9e5f7f98b 100644 --- a/src/duckdb/src/include/duckdb/storage/table/column_checkpoint_state.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/column_checkpoint_state.hpp @@ -87,7 +87,7 @@ struct PartialBlockForCheckpoint : public PartialBlock { bool IsFlushed(); void Flush(QueryContext context, const idx_t free_space_left) override; void Merge(PartialBlock &other, idx_t offset, idx_t other_size) override; - void AddSegmentToTail(ColumnData &data, ColumnSegment &segment, uint32_t offset_in_block); + void AddSegmentToTail(ColumnData &data, ColumnSegment &segment, uint32_t offset_in_block) override; void Clear() override; }; diff --git a/src/duckdb/src/include/duckdb/storage/table/column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/column_data.hpp index 49a95c7ed..60f99962b 100644 --- a/src/duckdb/src/include/duckdb/storage/table/column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/column_data.hpp @@ -102,6 +102,11 @@ class ColumnData { //! Whether or not the column has any updates bool HasUpdates() const; bool HasChanges(idx_t start_row, idx_t end_row) const; + //! Whether or not the column has changes at this level + bool HasChanges() const; + + //! Whether or not the column has ANY changes, including in child columns + virtual bool HasAnyChanges() const; //! 
Whether or not we can scan an entire vector virtual ScanVectorType GetVectorScanType(ColumnScanState &state, idx_t scan_count, Vector &result); @@ -214,6 +219,7 @@ class ColumnData { void FetchUpdateRow(TransactionData transaction, row_t row_id, Vector &result, idx_t result_idx); void UpdateInternal(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, idx_t update_count, Vector &base_vector); + idx_t FetchUpdateData(row_t *row_ids, Vector &base_vector); idx_t GetVectorCount(idx_t vector_index) const; diff --git a/src/duckdb/src/include/duckdb/storage/table/column_data_checkpointer.hpp b/src/duckdb/src/include/duckdb/storage/table/column_data_checkpointer.hpp index aec2efdba..f31e77521 100644 --- a/src/duckdb/src/include/duckdb/storage/table/column_data_checkpointer.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/column_data_checkpointer.hpp @@ -22,9 +22,10 @@ struct ColumnDataCheckpointData { ColumnDataCheckpointData() { } ColumnDataCheckpointData(ColumnCheckpointState &checkpoint_state, ColumnData &col_data, DatabaseInstance &db, - RowGroup &row_group, ColumnCheckpointInfo &checkpoint_info) + RowGroup &row_group, ColumnCheckpointInfo &checkpoint_info, + StorageManager &storage_manager) : checkpoint_state(checkpoint_state), col_data(col_data), db(db), row_group(row_group), - checkpoint_info(checkpoint_info) { + checkpoint_info(checkpoint_info), storage_manager(storage_manager) { } public: @@ -34,6 +35,7 @@ struct ColumnDataCheckpointData { RowGroup &GetRowGroup(); ColumnCheckpointState &GetCheckpointState(); DatabaseInstance &GetDatabase(); + StorageManager &GetStorageManager(); private: optional_ptr checkpoint_state; @@ -41,6 +43,7 @@ struct ColumnDataCheckpointData { optional_ptr db; optional_ptr row_group; optional_ptr checkpoint_info; + optional_ptr storage_manager; }; struct CheckpointAnalyzeResult { diff --git a/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp b/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp index 0af6d6fd0..61b2c0d4f 100644 --- a/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/column_segment.hpp @@ -99,12 +99,13 @@ class ColumnSegment : public SegmentBase { //! Revert an append made to this segment void RevertAppend(idx_t start_row); - //! Convert a transient in-memory segment into a persistent segment blocked by an on-disk block. + //! Convert a transient in-memory segment to a persistent segment backed by an on-disk block. //! Only used during checkpointing. - void ConvertToPersistent(QueryContext context, optional_ptr block_manager, block_id_t block_id); + void ConvertToPersistent(QueryContext context, optional_ptr block_manager, const block_id_t block_id); //! Updates pointers to refer to the given block and offset. This is only used //! when sharing a block among segments. This is invoked only AFTER the block is written. void MarkAsPersistent(shared_ptr block, uint32_t offset_in_block); + void SetBlock(shared_ptr block, uint32_t offset); //! 
Gets a data pointer from a persistent column segment DataPointer GetDataPointer(); diff --git a/src/duckdb/src/include/duckdb/storage/table/data_table_info.hpp b/src/duckdb/src/include/duckdb/storage/table/data_table_info.hpp index fbc2e443b..9511cb18a 100644 --- a/src/duckdb/src/include/duckdb/storage/table/data_table_info.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/data_table_info.hpp @@ -23,9 +23,9 @@ struct DataTableInfo { public: DataTableInfo(AttachedDatabase &db, shared_ptr table_io_manager_p, string schema, string table); - //! Initialize any unknown indexes whose types might now be present after an extension load, optionally throwing an - //! exception if an index can't be initialized - void InitializeIndexes(ClientContext &context, const char *index_type = nullptr); + //! Bind unknown indexes throwing an exception if binding fails. + //! Only binds the specified index type, or all, if nullptr. + void BindIndexes(ClientContext &context, const char *index_type = nullptr); //! Whether or not the table is temporary bool IsTemporary() const; diff --git a/src/duckdb/src/include/duckdb/storage/table/in_memory_checkpoint.hpp b/src/duckdb/src/include/duckdb/storage/table/in_memory_checkpoint.hpp new file mode 100644 index 000000000..68896f664 --- /dev/null +++ b/src/duckdb/src/include/duckdb/storage/table/in_memory_checkpoint.hpp @@ -0,0 +1,92 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/storage/table/in_memory_checkpoint.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/storage/checkpoint/row_group_writer.hpp" +#include "duckdb/storage/checkpoint/table_data_writer.hpp" +#include "duckdb/storage/checkpoint_manager.hpp" + +namespace duckdb { + +class InMemoryCheckpointer final : public CheckpointWriter { +public: + InMemoryCheckpointer(QueryContext context, AttachedDatabase &db, BlockManager &block_manager, + StorageManager &storage_manager, CheckpointType checkpoint_type); + + void CreateCheckpoint() override; + + MetadataWriter &GetMetadataWriter() override; + MetadataManager &GetMetadataManager() override; + unique_ptr GetTableDataWriter(TableCatalogEntry &table) override; + optional_ptr GetClientContext() const { + return context; + } + CheckpointType GetCheckpointType() const { + return checkpoint_type; + } + PartialBlockManager &GetPartialBlockManager() { + return partial_block_manager; + } + +public: + void WriteTable(TableCatalogEntry &table, Serializer &serializer) override; + +private: + optional_ptr context; + PartialBlockManager partial_block_manager; + StorageManager &storage_manager; + CheckpointType checkpoint_type; +}; + +class InMemoryTableDataWriter : public TableDataWriter { +public: + InMemoryTableDataWriter(InMemoryCheckpointer &checkpoint_manager, TableCatalogEntry &table); + +public: + void WriteUnchangedTable(MetaBlockPointer pointer, idx_t total_rows) override; + void FinalizeTable(const TableStatistics &global_stats, DataTableInfo *info, Serializer &serializer) override; + unique_ptr GetRowGroupWriter(RowGroup &row_group) override; + CheckpointType GetCheckpointType() const override; + MetadataManager &GetMetadataManager() override; + +private: + InMemoryCheckpointer &checkpoint_manager; +}; + +class InMemoryRowGroupWriter : public RowGroupWriter { +public: + InMemoryRowGroupWriter(TableCatalogEntry &table, PartialBlockManager &partial_block_manager, + InMemoryCheckpointer &checkpoint_manager); + +public: + 
CheckpointType GetCheckpointType() const override; + WriteStream &GetPayloadWriter() override; + MetaBlockPointer GetMetaBlockPointer() override; + optional_ptr GetMetadataManager() override; + +private: + //! Underlying writer object + InMemoryCheckpointer &checkpoint_manager; + // Nop metadata writer + MemoryStream metadata_writer; +}; + +struct InMemoryPartialBlock : public PartialBlock { +public: + InMemoryPartialBlock(ColumnData &data, ColumnSegment &segment, PartialBlockState state, + BlockManager &block_manager); + ~InMemoryPartialBlock() override; + +public: + void Flush(QueryContext context, const idx_t free_space_left) override; + void Merge(PartialBlock &other, idx_t offset, idx_t other_size) override; + void AddSegmentToTail(ColumnData &data, ColumnSegment &segment, uint32_t offset_in_block) override; + void Clear() override; +}; +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp index f80371f4f..c8e75d136 100644 --- a/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/list_column_data.hpp @@ -59,6 +59,7 @@ class ListColumnData : public ColumnData { unique_ptr Checkpoint(RowGroup &row_group, ColumnCheckpointInfo &info) override; bool IsPersistent() override; + bool HasAnyChanges() const override; PersistentColumnData Serialize() override; void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; diff --git a/src/duckdb/src/include/duckdb/storage/table/persistent_table_data.hpp b/src/duckdb/src/include/duckdb/storage/table/persistent_table_data.hpp index f5eb6d9aa..fa806478c 100644 --- a/src/duckdb/src/include/duckdb/storage/table/persistent_table_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/persistent_table_data.hpp @@ -22,6 +22,7 @@ class PersistentTableData { explicit PersistentTableData(idx_t column_count); ~PersistentTableData(); + MetaBlockPointer base_table_pointer; TableStatistics table_stats; idx_t total_rows; idx_t row_group_count; diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group.hpp index 0a56f1290..129996d59 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_group.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_group.hpp @@ -65,6 +65,7 @@ struct RowGroupWriteInfo { struct RowGroupWriteData { vector> states; vector statistics; + vector existing_pointers; }; class RowGroup : public SegmentBase { @@ -92,6 +93,12 @@ class RowGroup : public SegmentBase { RowGroupCollection &GetCollection() { return collection.get(); } + //! Returns the list of meta block pointers used by the columns + vector GetColumnPointers(); + //! Returns the list of meta block pointers used by the deletes + const vector &GetDeletesPointers() const { + return deletes_pointers; + } BlockManager &GetBlockManager(); DataTableInfo &GetTableInfo(); @@ -106,6 +113,7 @@ class RowGroup : public SegmentBase { void CommitDropColumn(const idx_t column_index); void InitializeEmpty(const vector &types); + bool HasChanges() const; //! 
Initialize a scan over this row_group bool InitializeScan(CollectionScanState &state); @@ -209,6 +217,8 @@ class RowGroup : public SegmentBase { vector column_pointers; unique_ptr[]> is_loaded; vector deletes_pointers; + bool has_metadata_blocks = false; + vector extra_metadata_blocks; atomic deletes_is_loaded; atomic allocation_size; }; diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp index e9bf0c949..769242606 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_group_collection.hpp @@ -167,6 +167,8 @@ class RowGroupCollection { TableStatistics stats; //! Allocation size, only tracked for appends atomic allocation_size; + //! Root metadata pointer, if the collection is loaded from disk + MetaBlockPointer metadata_pointer; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/row_group_segment_tree.hpp b/src/duckdb/src/include/duckdb/storage/table/row_group_segment_tree.hpp index 948107531..d93610138 100644 --- a/src/duckdb/src/include/duckdb/storage/table/row_group_segment_tree.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/row_group_segment_tree.hpp @@ -23,6 +23,10 @@ class RowGroupSegmentTree : public SegmentTree { void Initialize(PersistentTableData &data); + MetaBlockPointer GetRootPointer() const { + return root_pointer; + } + protected: unique_ptr LoadSegment() override; @@ -30,6 +34,7 @@ class RowGroupSegmentTree : public SegmentTree { idx_t current_row_group; idx_t max_row_group; unique_ptr reader; + MetaBlockPointer root_pointer; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp b/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp index 647ba06ea..16e297bae 100644 --- a/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/segment_tree.hpp @@ -38,7 +38,7 @@ class SegmentTree { //! Locks the segment tree. All methods to the segment tree either lock the segment tree, or take an already //! obtained lock. - SegmentLock Lock() { + SegmentLock Lock() const { return SegmentLock(node_lock); } @@ -76,12 +76,15 @@ class SegmentTree { auto l = Lock(); return ReferenceSegments(l); } + const vector> &ReferenceLoadedSegments(SegmentLock &l) const { + return nodes; + } idx_t GetSegmentCount() { auto l = Lock(); return GetSegmentCount(l); } - idx_t GetSegmentCount(SegmentLock &l) { + idx_t GetSegmentCount(SegmentLock &l) const { return nodes.size(); } //! Gets a pointer to the nth segment. Negative numbers start from the back. @@ -252,6 +255,10 @@ class SegmentTree { return SegmentIterationHelper(*this); } + SegmentIterationHelper Segments(SegmentLock &l) { + return SegmentIterationHelper(*this, l); + } + void Reinitialize() { if (nodes.empty()) { return; @@ -274,37 +281,42 @@ class SegmentTree { return nullptr; } + T *GetRootSegmentInternal() const { + return nodes.empty() ? nullptr : nodes[0].node.get(); + } + private: //! The nodes in the tree, can be binary searched vector> nodes; //! Lock to access or modify the nodes - mutex node_lock; + mutable mutex node_lock; private: - T *GetRootSegmentInternal() { - return nodes.empty() ? 
nullptr : nodes[0].node.get(); - } - class SegmentIterationHelper { public: explicit SegmentIterationHelper(SegmentTree &tree) : tree(tree) { } + SegmentIterationHelper(SegmentTree &tree, SegmentLock &l) : tree(tree), lock(l) { } private: SegmentTree &tree; + optional_ptr<SegmentLock> lock; private: class SegmentIterator { public: - SegmentIterator(SegmentTree &tree_p, T *current_p) : tree(tree_p), current(current_p) { + SegmentIterator(SegmentTree &tree_p, T *current_p, optional_ptr<SegmentLock> lock) + : tree(tree_p), current(current_p), lock(lock) { } SegmentTree &tree; T *current; + optional_ptr<SegmentLock> lock; public: void Next() { - current = tree.GetNextSegment(current); + current = lock ? tree.GetNextSegment(*lock, current) : tree.GetNextSegment(current); } SegmentIterator &operator++() { @@ -322,10 +334,11 @@ class SegmentTree { public: SegmentIterator begin() { // NOLINT: match stl API - return SegmentIterator(tree, tree.GetRootSegment()); + auto root = lock ? tree.GetRootSegment(*lock) : tree.GetRootSegment(); + return SegmentIterator(tree, root, lock); } SegmentIterator end() { // NOLINT: match stl API - return SegmentIterator(tree, nullptr); + return SegmentIterator(tree, nullptr, lock); } }; diff --git a/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp index 9bad48ffa..48ac6ccb7 100644 --- a/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/standard_column_data.hpp @@ -65,6 +65,7 @@ class StandardColumnData : public ColumnData { vector<ColumnSegmentInfo> &result) override; bool IsPersistent() override; + bool HasAnyChanges() const override; PersistentColumnData Serialize() override; void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; diff --git a/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp b/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp index b3fd8f4d0..d05436bfc 100644 --- a/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/struct_column_data.hpp @@ -59,6 +59,7 @@ class StructColumnData : public ColumnData { unique_ptr<ColumnCheckpointState> Checkpoint(RowGroup &row_group, ColumnCheckpointInfo &info) override; bool IsPersistent() override; + bool HasAnyChanges() const override; PersistentColumnData Serialize() override; void InitializeColumn(PersistentColumnData &column_data, BaseStatistics &target_stats) override; diff --git a/src/duckdb/src/include/duckdb/storage/table/table_index_list.hpp b/src/duckdb/src/include/duckdb/storage/table/table_index_list.hpp index c062032cc..05715cec4 100644 --- a/src/duckdb/src/include/duckdb/storage/table/table_index_list.hpp +++ b/src/duckdb/src/include/duckdb/storage/table/table_index_list.hpp @@ -20,47 +20,32 @@ class LocalTableStorage; struct IndexStorageInfo; struct DataTableInfo; +//! IndexBindState transitions an index through its binding phases, preventing lock order inversion. +enum class IndexBindState : uint8_t { UNBOUND, BINDING, BOUND }; + +//! IndexEntry pairs the index with an atomic bind state to ensure correct binding. +struct IndexEntry { + explicit IndexEntry(unique_ptr<Index> index); + atomic<IndexBindState> bind_state; + unique_ptr<Index> index; +}; + class TableIndexList { public: - //! Scan the indexes, invoking the callback method for every entry. + //! Scan the index entries, invoking the callback method for every entry.
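For illustration, a minimal sketch (editor example, not part of the patch) of how a caller drives Scan; index_list stands for any TableIndexList, and the callback receives each entry's Index and returns true to stop early:

    idx_t index_count = 0;
    index_list.Scan([&](Index &index) {
        index_count++; // inspect the index here
        return false;  // false continues the scan, true stops it
    });

Binding itself goes through Bind() below; the atomic bind_state lets an entry move UNBOUND -> BINDING -> BOUND without holding index_entries_lock for the whole bind, which is presumably how the lock order inversion is avoided.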
template <class T> void Scan(T &&callback) { - lock_guard<mutex> lock(indexes_lock); - for (auto &index : indexes) { - if (callback(*index)) { + lock_guard<mutex> lock(index_entries_lock); + for (auto &entry : index_entries) { + if (callback(*entry->index)) { break; } } } - //! Scan the indexes, invoking the callback method for every bound entry of type T. - template <class T, class FUNC> - void ScanBound(FUNC &&callback) { - lock_guard<mutex> lock(indexes_lock); - for (auto &index : indexes) { - if (index->IsBound() && T::TYPE_NAME == index->GetIndexType()) { - if (callback(index->Cast<T>())) { - break; - } - } - } - } - - // Bind any unbound indexes of type T and invoke the callback method. - template <class T, class FUNC> - void BindAndScan(ClientContext &context, DataTableInfo &table_info, FUNC &&callback) { - // FIXME: optimize this by only looping through the indexes once without re-acquiring the lock. - InitializeIndexes(context, table_info, T::TYPE_NAME); - ScanBound<T>(callback); - } - - //! Returns a reference to the indexes. - const vector<unique_ptr<Index>> &Indexes() const { - return indexes; - } - //! Adds an index to the list of indexes. + //! Adds an index entry to the list of index entries. void AddIndex(unique_ptr<Index> index); - //! Removes an index from the list of indexes. + //! Removes an index entry from the list of index entries. void RemoveIndex(const string &name); //! Removes all remaining memory of an index after dropping the catalog entry. void CommitDrop(const string &name); @@ -68,12 +53,11 @@ bool NameIsUnique(const string &name); //! Returns an optional pointer to the index matching the name. optional_ptr<Index> Find(const string &name); - //! Initializes unknown indexes that are possibly present after an extension load, optionally throwing an exception - //! on failure. - void InitializeIndexes(ClientContext &context, DataTableInfo &table_info, const char *index_type = nullptr); - //! Returns true, if there are no indexes in this list. + //! Binds unbound indexes possibly present after loading an extension. + void Bind(ClientContext &context, DataTableInfo &table_info, const char *index_type = nullptr); + //! Returns true if there are no index entries. bool Empty(); - //! Returns the number of indexes in this list. + //! Returns the number of index entries. idx_t Count(); //! Overwrite this list with the other list. void Move(TableIndexList &other); @@ -82,16 +66,16 @@ //! Verify a foreign key constraint. void VerifyForeignKey(optional_ptr<LocalTableStorage> storage, const vector<PhysicalIndex> &fk_keys, DataChunk &chunk, ConflictManager &conflict_manager); - //! Get the combined column ids of the indexes in this list. + //! Get the combined column ids of the indexes. unordered_set<column_t> GetRequiredColumns(); - //! Serialize all indexes of this table. + //! Serialize all indexes of the table. vector<IndexStorageInfo> SerializeToDisk(QueryContext context, const case_insensitive_map_t<Value> &options); private: - //! A lock to prevent any concurrent changes to the indexes. - mutex indexes_lock; - //! Indexes associated with the table. - vector<unique_ptr<Index>> indexes; + //! A lock to prevent any concurrent changes to the index entries. + mutex index_entries_lock; + //! The index entries of the table.
+ vector<unique_ptr<IndexEntry>> index_entries; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/storage/temporary_file_manager.hpp b/src/duckdb/src/include/duckdb/storage/temporary_file_manager.hpp index b14baa0c3..679655f4d 100644 --- a/src/duckdb/src/include/duckdb/storage/temporary_file_manager.hpp +++ b/src/duckdb/src/include/duckdb/storage/temporary_file_manager.hpp @@ -140,7 +140,7 @@ class TemporaryFileHandle { void EraseBlockIndex(block_id_t block_index); //! Read/Write temporary buffers at given positions in this file (potentially compressed) - unique_ptr<FileBuffer> ReadTemporaryBuffer(const TemporaryFileIndex &index_in_file, + unique_ptr<FileBuffer> ReadTemporaryBuffer(QueryContext context, const TemporaryFileIndex &index_in_file, unique_ptr<FileBuffer> reusable_buffer) const; void WriteTemporaryBuffer(FileBuffer &buffer, idx_t block_index, AllocatedData &compressed_buffer) const; @@ -280,10 +280,11 @@ class TemporaryFileManager { }; //! Create/Read/Update/Delete operations for temporary buffers - void WriteTemporaryBuffer(block_id_t block_id, FileBuffer &buffer); + idx_t WriteTemporaryBuffer(block_id_t block_id, FileBuffer &buffer); bool HasTemporaryBuffer(block_id_t block_id); - unique_ptr<FileBuffer> ReadTemporaryBuffer(block_id_t id, unique_ptr<FileBuffer> reusable_buffer); - void DeleteTemporaryBuffer(block_id_t id); + unique_ptr<FileBuffer> ReadTemporaryBuffer(QueryContext context, block_id_t id, + unique_ptr<FileBuffer> reusable_buffer); + idx_t DeleteTemporaryBuffer(block_id_t id); bool IsEncrypted() const; //! Get the list of temporary files and their sizes diff --git a/src/duckdb/src/include/duckdb/transaction/meta_transaction.hpp b/src/duckdb/src/include/duckdb/transaction/meta_transaction.hpp index ca6b2c920..583f3f57b 100644 --- a/src/duckdb/src/include/duckdb/transaction/meta_transaction.hpp +++ b/src/duckdb/src/include/duckdb/transaction/meta_transaction.hpp @@ -22,6 +22,17 @@ class AttachedDatabase; class ClientContext; class Transaction; +enum class TransactionState { UNCOMMITTED, COMMITTED, ROLLED_BACK }; + +struct TransactionReference { + explicit TransactionReference(Transaction &transaction_p) + : state(TransactionState::UNCOMMITTED), transaction(transaction_p) { + } + + TransactionState state; + Transaction &transaction; +}; + //! The MetaTransaction manages multiple transactions for different attached databases class MetaTransaction { public: @@ -68,7 +79,7 @@ class MetaTransaction { //! Lock to prevent all_transactions and transactions from getting out of sync mutex lock; //! The set of active transactions for each database - reference_map_t<AttachedDatabase, reference<Transaction>> transactions; + reference_map_t<AttachedDatabase, TransactionReference> transactions; //! The set of transactions in order of when they were started vector<reference<AttachedDatabase>> all_transactions; //!
The database we are modifying - we can only modify one database per transaction diff --git a/src/duckdb/src/include/duckdb/verification/explain_statement_verifier.hpp b/src/duckdb/src/include/duckdb/verification/explain_statement_verifier.hpp new file mode 100644 index 000000000..cbf52de17 --- /dev/null +++ b/src/duckdb/src/include/duckdb/verification/explain_statement_verifier.hpp @@ -0,0 +1,23 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/verification/explain_statement_verifier.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/verification/statement_verifier.hpp" + +namespace duckdb { + +class ExplainStatementVerifier : public StatementVerifier { +public: + explicit ExplainStatementVerifier(unique_ptr statement_p, + optional_ptr> parameters); + static unique_ptr Create(const SQLStatement &statement_p, + optional_ptr> parameters); +}; + +} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp b/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp index 63d52393e..a60abf187 100644 --- a/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp +++ b/src/duckdb/src/include/duckdb/verification/statement_verifier.hpp @@ -23,6 +23,7 @@ enum class VerificationType : uint8_t { NO_OPERATOR_CACHING, PREPARED, EXTERNAL, + EXPLAIN, FETCH_ROW_AS_SCAN, INVALID @@ -38,6 +39,7 @@ class StatementVerifier { optional_ptr> values); virtual ~StatementVerifier() noexcept; +public: //! Check whether expressions in this verifier and the other verifier match void CheckExpressions(const StatementVerifier &other) const; //! Check whether expressions within this verifier match @@ -52,13 +54,6 @@ class StatementVerifier { string CompareResults(const StatementVerifier &other); public: - const VerificationType type; - const string name; - unique_ptr statement; - optional_ptr> parameters; - const vector> &select_list; - unique_ptr materialized_result; - virtual bool RequireEquality() const { return true; } @@ -78,6 +73,18 @@ class StatementVerifier { virtual bool ForceFetchRow() const { return false; } + +public: + const VerificationType type; + const string name; + unique_ptr statement; + optional_ptr select_statement; + optional_ptr> parameters; + const vector> &select_list; + unique_ptr materialized_result; + +private: + const vector> empty_select_list = {}; }; } // namespace duckdb diff --git a/src/duckdb/src/include/duckdb_extension.h b/src/duckdb/src/include/duckdb_extension.h index 9462e1247..ca760e552 100644 --- a/src/duckdb/src/include/duckdb_extension.h +++ b/src/duckdb/src/include/duckdb_extension.h @@ -199,7 +199,7 @@ typedef struct { duckdb_value (*duckdb_create_timestamp)(duckdb_timestamp input); duckdb_value (*duckdb_create_interval)(duckdb_interval input); duckdb_value (*duckdb_create_blob)(const uint8_t *data, idx_t length); - duckdb_value (*duckdb_create_varint)(duckdb_varint input); + duckdb_value (*duckdb_create_bignum)(duckdb_bignum input); duckdb_value (*duckdb_create_decimal)(duckdb_decimal input); duckdb_value (*duckdb_create_bit)(duckdb_bit input); duckdb_value (*duckdb_create_uuid)(duckdb_uhugeint input); @@ -223,7 +223,7 @@ typedef struct { duckdb_interval (*duckdb_get_interval)(duckdb_value val); duckdb_logical_type (*duckdb_get_value_type)(duckdb_value val); duckdb_blob (*duckdb_get_blob)(duckdb_value val); - duckdb_varint (*duckdb_get_varint)(duckdb_value val); + duckdb_bignum 
(*duckdb_get_bignum)(duckdb_value val); duckdb_decimal (*duckdb_get_decimal)(duckdb_value val); duckdb_bit (*duckdb_get_bit)(duckdb_value val); duckdb_uhugeint (*duckdb_get_uuid)(duckdb_value val); @@ -543,6 +543,20 @@ typedef struct { duckdb_error_data (*duckdb_appender_error_data)(duckdb_appender appender); #endif +// New arrow interface functions +#ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE + duckdb_error_data (*duckdb_to_arrow_schema)(duckdb_arrow_options arrow_options, duckdb_logical_type *types, + const char **names, idx_t column_count, struct ArrowSchema *out_schema); + duckdb_error_data (*duckdb_data_chunk_to_arrow)(duckdb_arrow_options arrow_options, duckdb_data_chunk chunk, + struct ArrowArray *out_arrow_array); + duckdb_error_data (*duckdb_schema_from_arrow)(duckdb_connection connection, struct ArrowSchema *schema, + duckdb_arrow_converted_schema *out_types); + duckdb_error_data (*duckdb_data_chunk_from_arrow)(duckdb_connection connection, struct ArrowArray *arrow_array, + duckdb_arrow_converted_schema converted_schema, + duckdb_data_chunk *out_chunk); + void (*duckdb_destroy_arrow_converted_schema)(duckdb_arrow_converted_schema *arrow_converted_schema); +#endif + // New functions for duckdb error data #ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE duckdb_error_data (*duckdb_create_error_data)(duckdb_error_type type, const char *message); @@ -552,12 +566,28 @@ typedef struct { bool (*duckdb_error_data_has_error)(duckdb_error_data error_data); #endif +// API to create and manipulate expressions +#ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE + void (*duckdb_destroy_expression)(duckdb_expression *expr); + duckdb_logical_type (*duckdb_expression_return_type)(duckdb_expression expr); + bool (*duckdb_expression_is_foldable)(duckdb_expression expr); + duckdb_error_data (*duckdb_expression_fold)(duckdb_client_context context, duckdb_expression expr, + duckdb_value *out_value); +#endif + // New functions around the client context #ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE idx_t (*duckdb_client_context_get_connection_id)(duckdb_client_context context); void (*duckdb_destroy_client_context)(duckdb_client_context *context); void (*duckdb_connection_get_client_context)(duckdb_connection connection, duckdb_client_context *out_context); duckdb_value (*duckdb_get_table_names)(duckdb_connection connection, const char *query, bool qualified); + void (*duckdb_connection_get_arrow_options)(duckdb_connection connection, duckdb_arrow_options *out_arrow_options); + void (*duckdb_destroy_arrow_options)(duckdb_arrow_options *arrow_options); +#endif + +// New query execution functions +#ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE + duckdb_arrow_options (*duckdb_result_get_arrow_options)(duckdb_result *result); #endif // New functions around scalar function binding @@ -569,6 +599,8 @@ typedef struct { duckdb_delete_callback_t destroy); void *(*duckdb_scalar_function_get_bind_data)(duckdb_function_info info); void *(*duckdb_scalar_function_bind_get_extra_info)(duckdb_bind_info info); + idx_t (*duckdb_scalar_function_bind_get_argument_count)(duckdb_bind_info info); + duckdb_expression (*duckdb_scalar_function_bind_get_argument)(duckdb_bind_info info, idx_t index); #endif // New string functions that are added @@ -576,11 +608,18 @@ typedef struct { char *(*duckdb_value_to_string)(duckdb_value value); #endif +// New functions around table function binding +#ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE + void (*duckdb_table_function_get_client_context)(duckdb_bind_info info, duckdb_client_context *out_context); 
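// A short usage sketch (editor illustration, not part of this header) of the new
// scalar-function bind inspection and expression-folding entries declared above.
// How the duckdb_client_context is obtained is assumed here, for example via
// duckdb_connection_get_client_context on the owning connection.
//
//   static void inspect_bind_arguments(duckdb_client_context ctx, duckdb_bind_info info) {
//       idx_t arg_count = duckdb_scalar_function_bind_get_argument_count(info);
//       for (idx_t i = 0; i < arg_count; i++) {
//           duckdb_expression expr = duckdb_scalar_function_bind_get_argument(info, i);
//           if (duckdb_expression_is_foldable(expr)) {
//               duckdb_value folded = NULL;
//               duckdb_error_data err = duckdb_expression_fold(ctx, expr, &folded);
//               if (err) {
//                   duckdb_destroy_error_data(&err); // folding failed; skip this argument
//               } else {
//                   char *text = duckdb_value_to_string(folded);
//                   // ... use the folded constant, e.g. to specialize the bind data ...
//                   duckdb_free(text);
//                   duckdb_destroy_value(&folded);
//               }
//           }
//           duckdb_destroy_expression(&expr);
//       }
//   }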
+#endif + // New value functions that are added #ifdef DUCKDB_EXTENSION_API_VERSION_UNSTABLE duckdb_value (*duckdb_create_map_value)(duckdb_logical_type map_type, duckdb_value *keys, duckdb_value *values, idx_t entry_count); duckdb_value (*duckdb_create_union_value)(duckdb_logical_type union_type, idx_t tag_index, duckdb_value value); + duckdb_value (*duckdb_create_time_ns)(duckdb_time_ns input); + duckdb_time_ns (*duckdb_get_time_ns)(duckdb_value val); #endif // API to create and manipulate vector types @@ -712,7 +751,7 @@ typedef struct { #define duckdb_create_int64 duckdb_ext_api.duckdb_create_int64 #define duckdb_create_hugeint duckdb_ext_api.duckdb_create_hugeint #define duckdb_create_uhugeint duckdb_ext_api.duckdb_create_uhugeint -#define duckdb_create_varint duckdb_ext_api.duckdb_create_varint +#define duckdb_create_bignum duckdb_ext_api.duckdb_create_bignum #define duckdb_create_decimal duckdb_ext_api.duckdb_create_decimal #define duckdb_create_float duckdb_ext_api.duckdb_create_float #define duckdb_create_double duckdb_ext_api.duckdb_create_double @@ -739,7 +778,7 @@ typedef struct { #define duckdb_get_uint64 duckdb_ext_api.duckdb_get_uint64 #define duckdb_get_hugeint duckdb_ext_api.duckdb_get_hugeint #define duckdb_get_uhugeint duckdb_ext_api.duckdb_get_uhugeint -#define duckdb_get_varint duckdb_ext_api.duckdb_get_varint +#define duckdb_get_bignum duckdb_ext_api.duckdb_get_bignum #define duckdb_get_decimal duckdb_ext_api.duckdb_get_decimal #define duckdb_get_float duckdb_ext_api.duckdb_get_float #define duckdb_get_double duckdb_ext_api.duckdb_get_double @@ -1020,6 +1059,13 @@ typedef struct { #define duckdb_appender_error_data duckdb_ext_api.duckdb_appender_error_data #define duckdb_append_default_to_chunk duckdb_ext_api.duckdb_append_default_to_chunk +// Version unstable_new_arrow_functions +#define duckdb_to_arrow_schema duckdb_ext_api.duckdb_to_arrow_schema +#define duckdb_data_chunk_to_arrow duckdb_ext_api.duckdb_data_chunk_to_arrow +#define duckdb_schema_from_arrow duckdb_ext_api.duckdb_schema_from_arrow +#define duckdb_data_chunk_from_arrow duckdb_ext_api.duckdb_data_chunk_from_arrow +#define duckdb_destroy_arrow_converted_schema duckdb_ext_api.duckdb_destroy_arrow_converted_schema + // Version unstable_new_error_data_functions #define duckdb_create_error_data duckdb_ext_api.duckdb_create_error_data #define duckdb_destroy_error_data duckdb_ext_api.duckdb_destroy_error_data @@ -1027,24 +1073,42 @@ typedef struct { #define duckdb_error_data_message duckdb_ext_api.duckdb_error_data_message #define duckdb_error_data_has_error duckdb_ext_api.duckdb_error_data_has_error +// Version unstable_new_expression_functions +#define duckdb_destroy_expression duckdb_ext_api.duckdb_destroy_expression +#define duckdb_expression_return_type duckdb_ext_api.duckdb_expression_return_type +#define duckdb_expression_is_foldable duckdb_ext_api.duckdb_expression_is_foldable +#define duckdb_expression_fold duckdb_ext_api.duckdb_expression_fold + // Version unstable_new_open_connect_functions #define duckdb_connection_get_client_context duckdb_ext_api.duckdb_connection_get_client_context +#define duckdb_connection_get_arrow_options duckdb_ext_api.duckdb_connection_get_arrow_options #define duckdb_client_context_get_connection_id duckdb_ext_api.duckdb_client_context_get_connection_id #define duckdb_destroy_client_context duckdb_ext_api.duckdb_destroy_client_context +#define duckdb_destroy_arrow_options duckdb_ext_api.duckdb_destroy_arrow_options #define duckdb_get_table_names 
duckdb_ext_api.duckdb_get_table_names +// Version unstable_new_query_execution_functions +#define duckdb_result_get_arrow_options duckdb_ext_api.duckdb_result_get_arrow_options + // Version unstable_new_scalar_function_functions -#define duckdb_scalar_function_set_bind duckdb_ext_api.duckdb_scalar_function_set_bind -#define duckdb_scalar_function_set_bind_data duckdb_ext_api.duckdb_scalar_function_set_bind_data -#define duckdb_scalar_function_bind_set_error duckdb_ext_api.duckdb_scalar_function_bind_set_error -#define duckdb_scalar_function_bind_get_extra_info duckdb_ext_api.duckdb_scalar_function_bind_get_extra_info -#define duckdb_scalar_function_get_bind_data duckdb_ext_api.duckdb_scalar_function_get_bind_data -#define duckdb_scalar_function_get_client_context duckdb_ext_api.duckdb_scalar_function_get_client_context +#define duckdb_scalar_function_set_bind duckdb_ext_api.duckdb_scalar_function_set_bind +#define duckdb_scalar_function_set_bind_data duckdb_ext_api.duckdb_scalar_function_set_bind_data +#define duckdb_scalar_function_bind_set_error duckdb_ext_api.duckdb_scalar_function_bind_set_error +#define duckdb_scalar_function_bind_get_extra_info duckdb_ext_api.duckdb_scalar_function_bind_get_extra_info +#define duckdb_scalar_function_get_bind_data duckdb_ext_api.duckdb_scalar_function_get_bind_data +#define duckdb_scalar_function_get_client_context duckdb_ext_api.duckdb_scalar_function_get_client_context +#define duckdb_scalar_function_bind_get_argument_count duckdb_ext_api.duckdb_scalar_function_bind_get_argument_count +#define duckdb_scalar_function_bind_get_argument duckdb_ext_api.duckdb_scalar_function_bind_get_argument // Version unstable_new_string_functions #define duckdb_value_to_string duckdb_ext_api.duckdb_value_to_string +// Version unstable_new_table_function_functions +#define duckdb_table_function_get_client_context duckdb_ext_api.duckdb_table_function_get_client_context + // Version unstable_new_value_functions +#define duckdb_create_time_ns duckdb_ext_api.duckdb_create_time_ns +#define duckdb_get_time_ns duckdb_ext_api.duckdb_get_time_ns #define duckdb_create_map_value duckdb_ext_api.duckdb_create_map_value #define duckdb_create_union_value duckdb_ext_api.duckdb_create_union_value diff --git a/src/duckdb/src/main/attached_database.cpp b/src/duckdb/src/main/attached_database.cpp index b5ee4a3a0..cbd5e04d3 100644 --- a/src/duckdb/src/main/attached_database.cpp +++ b/src/duckdb/src/main/attached_database.cpp @@ -21,10 +21,10 @@ AttachOptions::AttachOptions(const DBConfigOptions &options) : access_mode(options.access_mode), db_type(options.database_type) { } -AttachOptions::AttachOptions(const unique_ptr &info, const AccessMode default_access_mode) +AttachOptions::AttachOptions(const unordered_map &attach_options, const AccessMode default_access_mode) : access_mode(default_access_mode) { - for (auto &entry : info->options) { + for (auto &entry : attach_options) { if (entry.first == "readonly" || entry.first == "read_only") { // Extract the read access mode. @@ -58,8 +58,7 @@ AttachOptions::AttachOptions(const unique_ptr &info, const AccessMod default_table = QualifiedName::Parse(StringValue::Get(entry.second.DefaultCastAs(LogicalType::VARCHAR))); continue; } - - options[entry.first] = entry.second; + options.emplace(entry.first, entry.second); } } @@ -75,7 +74,9 @@ AttachedDatabase::AttachedDatabase(DatabaseInstance &db, AttachedDatabaseType ty // This database does not have storage, or uses temporary_objects for in-memory storage. 
D_ASSERT(type == AttachedDatabaseType::TEMP_DATABASE || type == AttachedDatabaseType::SYSTEM_DATABASE); if (type == AttachedDatabaseType::TEMP_DATABASE) { - storage = make_uniq(*this, string(IN_MEMORY_PATH), false); + unordered_map options; + AttachOptions attach_options(options, AccessMode::READ_WRITE); + storage = make_uniq(*this, string(IN_MEMORY_PATH), attach_options); } catalog = make_uniq(*this); @@ -92,28 +93,9 @@ AttachedDatabase::AttachedDatabase(DatabaseInstance &db, Catalog &catalog_p, str } else { type = AttachedDatabaseType::READ_WRITE_DATABASE; } - for (auto &entry : options.options) { - if (StringUtil::CIEquals(entry.first, "block_size")) { - continue; - } - if (StringUtil::CIEquals(entry.first, "encryption_key")) { - continue; - } - if (StringUtil::CIEquals(entry.first, "encryption_cipher")) { - continue; - } - if (StringUtil::CIEquals(entry.first, "row_group_size")) { - continue; - } - if (StringUtil::CIEquals(entry.first, "storage_version")) { - continue; - } - throw BinderException("Unrecognized option for attach \"%s\"", entry.first); - } // We create the storage after the catalog to guarantee we allow extensions to instantiate the DuckCatalog. catalog = make_uniq(*this); - auto read_only = options.access_mode == AccessMode::READ_ONLY; - storage = make_uniq(*this, std::move(file_path_p), read_only); + storage = make_uniq(*this, std::move(file_path_p), options); transaction_manager = make_uniq(*this); internal = true; } @@ -128,15 +110,14 @@ AttachedDatabase::AttachedDatabase(DatabaseInstance &db, Catalog &catalog_p, Sto type = AttachedDatabaseType::READ_WRITE_DATABASE; } - StorageExtensionInfo *storage_info = storage_extension->storage_info.get(); - catalog = storage_extension->attach(storage_info, context, *this, name, info, options.access_mode); + optional_ptr storage_info = storage_extension->storage_info.get(); + catalog = storage_extension->attach(storage_info, context, *this, name, info, options); if (!catalog) { throw InternalException("AttachedDatabase - attach function did not return a catalog"); } if (catalog->IsDuckCatalog()) { // The attached database uses the DuckCatalog. 
- auto read_only = options.access_mode == AccessMode::READ_ONLY; - storage = make_uniq(*this, info.path, read_only); + storage = make_uniq(*this, info.path, options); } transaction_manager = storage_extension->create_transaction_manager(storage_info, *this, *catalog); if (!transaction_manager) { @@ -166,25 +147,31 @@ bool AttachedDatabase::NameIsReserved(const string &name) { return name == DEFAULT_SCHEMA || name == TEMP_CATALOG || name == SYSTEM_CATALOG; } +static string RemoveQueryParams(const string &name) { + auto vec = StringUtil::Split(name, "?"); + D_ASSERT(!vec.empty()); + return vec[0]; +} + string AttachedDatabase::ExtractDatabaseName(const string &dbpath, FileSystem &fs) { if (dbpath.empty() || dbpath == IN_MEMORY_PATH) { return "memory"; } - auto name = fs.ExtractBaseName(dbpath); + auto name = RemoveQueryParams(fs.ExtractBaseName(dbpath)); if (NameIsReserved(name)) { name += "_db"; } return name; } -void AttachedDatabase::Initialize(optional_ptr context, StorageOptions options) { +void AttachedDatabase::Initialize(optional_ptr context) { if (IsSystem()) { catalog->Initialize(context, true); } else { catalog->Initialize(context, false); } if (storage) { - storage->Initialize(QueryContext(context), options); + storage->Initialize(QueryContext(context)); } } diff --git a/src/duckdb/src/main/capi/arrow-c.cpp b/src/duckdb/src/main/capi/arrow-c.cpp index ab567393e..204f092ef 100644 --- a/src/duckdb/src/main/capi/arrow-c.cpp +++ b/src/duckdb/src/main/capi/arrow-c.cpp @@ -3,9 +3,11 @@ #include "duckdb/function/table/arrow.hpp" #include "duckdb/main/capi/capi_internal.hpp" #include "duckdb/main/prepared_statement_data.hpp" +#include "fmt/format.h" using duckdb::ArrowConverter; using duckdb::ArrowResultWrapper; +using duckdb::CClientArrowOptionsWrapper; using duckdb::Connection; using duckdb::DataChunk; using duckdb::LogicalType; @@ -14,6 +16,146 @@ using duckdb::PreparedStatementWrapper; using duckdb::QueryResult; using duckdb::QueryResultType; +duckdb_error_data duckdb_to_arrow_schema(duckdb_arrow_options arrow_options, duckdb_logical_type *types, + const char **names, idx_t column_count, struct ArrowSchema *out_schema) { + + if (!types || !names || !arrow_options || !out_schema) { + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, "Invalid argument(s) to duckdb_to_arrow_schema"); + } + duckdb::vector schema_types; + duckdb::vector schema_names; + for (idx_t i = 0; i < column_count; i++) { + schema_names.emplace_back(names[i]); + schema_types.emplace_back(*reinterpret_cast(types[i])); + } + const auto arrow_options_wrapper = reinterpret_cast(arrow_options); + try { + ArrowConverter::ToArrowSchema(out_schema, schema_types, schema_names, arrow_options_wrapper->properties); + } catch (const duckdb::Exception &ex) { + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, ex.what()); + } catch (const std::exception &ex) { + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, ex.what()); + } catch (...) 
{ + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, "Unknown error occurred during conversion"); + } + return nullptr; +} + +duckdb_error_data duckdb_data_chunk_to_arrow(duckdb_arrow_options arrow_options, duckdb_data_chunk chunk, + struct ArrowArray *out_arrow_array) { + if (!arrow_options || !chunk || !out_arrow_array) { + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, + "Invalid argument(s) to duckdb_data_chunk_to_arrow"); + } + auto dchunk = reinterpret_cast<DataChunk *>(chunk); + auto arrow_options_wrapper = reinterpret_cast<CClientArrowOptionsWrapper *>(arrow_options); + auto extension_type_cast = duckdb::ArrowTypeExtensionData::GetExtensionTypes( + *arrow_options_wrapper->properties.client_context, dchunk->GetTypes()); + + try { + ArrowConverter::ToArrowArray(*dchunk, out_arrow_array, arrow_options_wrapper->properties, extension_type_cast); + } catch (const duckdb::Exception &ex) { + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, ex.what()); + } catch (const std::exception &ex) { + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, ex.what()); + } catch (...) { + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, "Unknown error occurred during conversion"); + } + return nullptr; +} + +duckdb_error_data duckdb_schema_from_arrow(duckdb_connection connection, struct ArrowSchema *schema, + duckdb_arrow_converted_schema *out_types) { + if (!connection || !out_types || !schema) { + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, + "Invalid argument(s) to duckdb_schema_from_arrow"); + } + duckdb::vector<duckdb::string> names; + const auto conn = reinterpret_cast<Connection *>(connection); + auto arrow_table = duckdb::make_uniq(); + try { + duckdb::vector<LogicalType> return_types; + duckdb::ArrowTableFunction::PopulateArrowTableSchema(duckdb::DBConfig::GetConfig(*conn->context), *arrow_table, *schema); + } catch (const duckdb::Exception &ex) { + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, ex.what()); + } catch (const std::exception &ex) { + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, ex.what()); + } catch (...) { + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, "Unknown error occurred during conversion"); + } + *out_types = reinterpret_cast<duckdb_arrow_converted_schema>(arrow_table.release()); + return nullptr; +} + +duckdb_error_data duckdb_data_chunk_from_arrow(duckdb_connection connection, struct ArrowArray *arrow_array, + duckdb_arrow_converted_schema converted_schema, + duckdb_data_chunk *out_chunk) { + if (!connection || !converted_schema || !out_chunk || !arrow_array) { + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, + "Invalid argument(s) to duckdb_data_chunk_from_arrow"); + } + auto arrow_table = reinterpret_cast(converted_schema); + auto conn = reinterpret_cast<Connection *>(connection); + auto &types = arrow_table->GetTypes(); + + auto dchunk = duckdb::make_uniq<DataChunk>(); + dchunk->Initialize(duckdb::Allocator::DefaultAllocator(), types, duckdb::NumericCast<idx_t>(arrow_array->length)); + + auto &arrow_types = arrow_table->GetColumns(); + dchunk->SetCardinality(duckdb::NumericCast<idx_t>(arrow_array->length)); + for (idx_t i = 0; i < dchunk->ColumnCount(); i++) { + auto &parent_array = *arrow_array; + auto &array = parent_array.children[i]; + auto arrow_type = arrow_types.at(i); + auto array_physical_type = arrow_type->GetPhysicalType(); + auto array_state = duckdb::make_uniq(*conn->context); + // We need to make sure that our chunk will hold the ownership + array_state->owned_data = duckdb::make_shared_ptr(); + array_state->owned_data->arrow_array = *arrow_array; + // We set it to nullptr to effectively transfer the ownership + arrow_array->release = nullptr; + try { + switch (array_physical_type) { + case duckdb::ArrowArrayPhysicalType::DICTIONARY_ENCODED: + duckdb::ArrowToDuckDBConversion::ColumnArrowToDuckDBDictionary(dchunk->data[i], *array, 0, *array_state, + dchunk->size(), *arrow_type); + break; + case duckdb::ArrowArrayPhysicalType::RUN_END_ENCODED: + duckdb::ArrowToDuckDBConversion::ColumnArrowToDuckDBRunEndEncoded( + dchunk->data[i], *array, 0, *array_state, dchunk->size(), *arrow_type); + break; + case duckdb::ArrowArrayPhysicalType::DEFAULT: + duckdb::ArrowToDuckDBConversion::SetValidityMask(dchunk->data[i], *array, 0, dchunk->size(), + parent_array.offset, -1); + + duckdb::ArrowToDuckDBConversion::ColumnArrowToDuckDB(dchunk->data[i], *array, 0, *array_state, + dchunk->size(), *arrow_type); + break; + default: + return duckdb_create_error_data(DUCKDB_ERROR_NOT_IMPLEMENTED, + "Only Default Physical Types are currently supported"); + } + } catch (const duckdb::Exception &ex) { + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, ex.what()); + } catch (const std::exception &ex) { + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, ex.what()); + } catch (...)
{ + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, "Unknown error occurred during conversion"); + } + } + *out_chunk = reinterpret_cast(dchunk.release()); + return nullptr; +} + +void duckdb_destroy_arrow_converted_schema(duckdb_arrow_converted_schema *arrow_converted_schema) { + if (arrow_converted_schema && *arrow_converted_schema) { + auto converted_schema = reinterpret_cast(*arrow_converted_schema); + delete converted_schema; + *arrow_converted_schema = nullptr; + } +} + duckdb_state duckdb_query_arrow(duckdb_connection connection, const char *query, duckdb_arrow *out_result) { Connection *conn = (Connection *)connection; auto wrapper = new ArrowResultWrapper(); diff --git a/src/duckdb/src/main/capi/config-c.cpp b/src/duckdb/src/main/capi/config-c.cpp index 68605056d..fbb734dcf 100644 --- a/src/duckdb/src/main/capi/config-c.cpp +++ b/src/duckdb/src/main/capi/config-c.cpp @@ -23,7 +23,8 @@ duckdb_state duckdb_create_config(duckdb_config *out_config) { } size_t duckdb_config_count() { - return DBConfig::GetOptionCount() + duckdb::ExtensionHelper::ArraySize(duckdb::EXTENSION_SETTINGS); + return DBConfig::GetOptionCount() + DBConfig::GetAliasCount() + + duckdb::ExtensionHelper::ArraySize(duckdb::EXTENSION_SETTINGS); } duckdb_state duckdb_get_config_flag(size_t index, const char **out_name, const char **out_description) { @@ -37,9 +38,23 @@ duckdb_state duckdb_get_config_flag(size_t index, const char **out_name, const c } return DuckDBSuccess; } + // alias + index -= DBConfig::GetOptionCount(); + auto alias = DBConfig::GetAliasByIndex(index); + if (alias) { + if (out_name) { + *out_name = alias->alias; + } + option = DBConfig::GetOptionByIndex(alias->option_index); + if (out_description) { + *out_description = option->description; + } + return DuckDBSuccess; + } + index -= DBConfig::GetAliasCount(); - // extension index? 
- auto entry = duckdb::ExtensionHelper::GetArrayEntry(duckdb::EXTENSION_SETTINGS, index - DBConfig::GetOptionCount()); + // extension index + auto entry = duckdb::ExtensionHelper::GetArrayEntry(duckdb::EXTENSION_SETTINGS, index); if (!entry) { return DuckDBError; } diff --git a/src/duckdb/src/main/capi/duckdb-c.cpp b/src/duckdb/src/main/capi/duckdb-c.cpp index a1d7657a4..43205cdf0 100644 --- a/src/duckdb/src/main/capi/duckdb-c.cpp +++ b/src/duckdb/src/main/capi/duckdb-c.cpp @@ -1,5 +1,6 @@ #include "duckdb/main/capi/capi_internal.hpp" +using duckdb::CClientArrowOptionsWrapper; using duckdb::CClientContextWrapper; using duckdb::Connection; using duckdb::DatabaseWrapper; @@ -151,6 +152,16 @@ void duckdb_connection_get_client_context(duckdb_connection connection, duckdb_c *out_context = reinterpret_cast(wrapper); } +void duckdb_connection_get_arrow_options(duckdb_connection connection, duckdb_arrow_options *out_arrow_options) { + if (!connection || !out_arrow_options) { + return; + } + Connection *conn = reinterpret_cast(connection); + auto client_properties = conn->context->GetClientProperties(); + auto wrapper = new CClientArrowOptionsWrapper(client_properties); + *out_arrow_options = reinterpret_cast(wrapper); +} + idx_t duckdb_client_context_get_connection_id(duckdb_client_context context) { auto wrapper = reinterpret_cast(context); return wrapper->context.GetConnectionId(); @@ -164,6 +175,14 @@ void duckdb_destroy_client_context(duckdb_client_context *context) { } } +void duckdb_destroy_arrow_options(duckdb_arrow_options *arrow_options) { + if (arrow_options && *arrow_options) { + auto wrapper = reinterpret_cast(*arrow_options); + delete wrapper; + *arrow_options = nullptr; + } +} + duckdb_state duckdb_query(duckdb_connection connection, const char *query, duckdb_result *out) { Connection *conn = reinterpret_cast(connection); auto result = conn->Query(query); diff --git a/src/duckdb/src/main/capi/duckdb_value-c.cpp b/src/duckdb/src/main/capi/duckdb_value-c.cpp index a74bae4f6..9c91df242 100644 --- a/src/duckdb/src/main/capi/duckdb_value-c.cpp +++ b/src/duckdb/src/main/capi/duckdb_value-c.cpp @@ -5,7 +5,7 @@ #include "duckdb/common/types/string_type.hpp" #include "duckdb/common/types/uuid.hpp" #include "duckdb/common/types/value.hpp" -#include "duckdb/common/types/varint.hpp" +#include "duckdb/common/types/bignum.hpp" #include "duckdb/main/capi/capi_internal.hpp" using duckdb::LogicalTypeId; @@ -119,16 +119,16 @@ duckdb_uhugeint duckdb_get_uhugeint(duckdb_value val) { auto res = CAPIGetValue(val); return {res.lower, res.upper}; } -duckdb_value duckdb_create_varint(duckdb_varint input) { +duckdb_value duckdb_create_bignum(duckdb_bignum input) { return WrapValue(new duckdb::Value( - duckdb::Value::VARINT(duckdb::Varint::FromByteArray(input.data, input.size, input.is_negative)))); + duckdb::Value::BIGNUM(duckdb::Bignum::FromByteArray(input.data, input.size, input.is_negative)))); } -duckdb_varint duckdb_get_varint(duckdb_value val) { - auto v = UnwrapValue(val).DefaultCastAs(duckdb::LogicalType::VARINT); +duckdb_bignum duckdb_get_bignum(duckdb_value val) { + auto v = UnwrapValue(val).DefaultCastAs(duckdb::LogicalType::BIGNUM); auto &str = duckdb::StringValue::Get(v); duckdb::vector byte_array; bool is_negative; - duckdb::Varint::GetByteArray(byte_array, is_negative, duckdb::string_t(str)); + duckdb::Bignum::GetByteArray(byte_array, is_negative, duckdb::string_t(str)); auto size = byte_array.size(); auto data = reinterpret_cast(malloc(size)); memcpy(data, byte_array.data(), size); @@ -186,6 
+186,12 @@ duckdb_value duckdb_create_time_tz_value(duckdb_time_tz input) { duckdb_time_tz duckdb_get_time_tz(duckdb_value val) { return {CAPIGetValue(val).bits}; } +duckdb_value duckdb_create_time_ns(duckdb_time_ns input) { + return CAPICreateValue(duckdb::dtime_ns_t(input.nanos)); +} +duckdb_time_ns duckdb_get_time_ns(duckdb_value val) { + return {CAPIGetValue(val).micros}; +} duckdb_value duckdb_create_timestamp(duckdb_timestamp input) { duckdb::timestamp_t ts(input.micros); @@ -352,7 +358,7 @@ duckdb_value duckdb_create_list_value(duckdb_logical_type type, duckdb_value *va } unwrapped_values.push_back(UnwrapValue(value)); } - duckdb::Value *list_value = new duckdb::Value; + auto list_value = new duckdb::Value; try { *list_value = duckdb::Value::LIST(logical_type, std::move(unwrapped_values)); } catch (...) { diff --git a/src/duckdb/src/main/capi/expression-c.cpp b/src/duckdb/src/main/capi/expression-c.cpp new file mode 100644 index 000000000..c416e03e5 --- /dev/null +++ b/src/duckdb/src/main/capi/expression-c.cpp @@ -0,0 +1,57 @@ +#include "duckdb/main/capi/capi_internal.hpp" + +#include "duckdb/execution/expression_executor.hpp" + +using duckdb::CClientContextWrapper; +using duckdb::ExpressionWrapper; + +void duckdb_destroy_expression(duckdb_expression *expr) { + if (!expr || !*expr) { + return; + } + auto wrapper = reinterpret_cast(*expr); + delete wrapper; + *expr = nullptr; +} + +duckdb_logical_type duckdb_expression_return_type(duckdb_expression expr) { + if (!expr) { + return nullptr; + } + auto wrapper = reinterpret_cast(expr); + auto logical_type = new duckdb::LogicalType(wrapper->expr->return_type); + return reinterpret_cast(logical_type); +} + +bool duckdb_expression_is_foldable(duckdb_expression expr) { + if (!expr) { + return false; + } + auto wrapper = reinterpret_cast(expr); + return wrapper->expr->IsFoldable(); +} + +duckdb_error_data duckdb_expression_fold(duckdb_client_context context, duckdb_expression expr, + duckdb_value *out_value) { + if (!expr || !duckdb_expression_is_foldable(expr)) { + return nullptr; + } + + auto value = new duckdb::Value; + try { + auto context_wrapper = reinterpret_cast(context); + auto expr_wrapper = reinterpret_cast(expr); + *value = duckdb::ExpressionExecutor::EvaluateScalar(context_wrapper->context, *expr_wrapper->expr); + *out_value = reinterpret_cast(value); + } catch (const duckdb::Exception &ex) { + delete value; + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, ex.what()); + } catch (const std::exception &ex) { + delete value; + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, ex.what()); + } catch (...) 
{ + delete value; + return duckdb_create_error_data(DUCKDB_ERROR_INVALID_INPUT, "unknown error occurred during folding"); + } + return nullptr; +} diff --git a/src/duckdb/src/main/capi/helper-c.cpp b/src/duckdb/src/main/capi/helper-c.cpp index 55aecc19f..c936bab59 100644 --- a/src/duckdb/src/main/capi/helper-c.cpp +++ b/src/duckdb/src/main/capi/helper-c.cpp @@ -74,8 +74,8 @@ LogicalTypeId LogicalTypeIdFromC(const duckdb_type type) { return LogicalTypeId::TIMESTAMP_TZ; case DUCKDB_TYPE_ANY: return LogicalTypeId::ANY; - case DUCKDB_TYPE_VARINT: - return LogicalTypeId::VARINT; + case DUCKDB_TYPE_BIGNUM: + return LogicalTypeId::BIGNUM; case DUCKDB_TYPE_SQLNULL: return LogicalTypeId::SQLNULL; case DUCKDB_TYPE_STRING_LITERAL: @@ -140,8 +140,8 @@ duckdb_type LogicalTypeIdToC(const LogicalTypeId type) { return DUCKDB_TYPE_BLOB; case LogicalTypeId::BIT: return DUCKDB_TYPE_BIT; - case LogicalTypeId::VARINT: - return DUCKDB_TYPE_VARINT; + case LogicalTypeId::BIGNUM: + return DUCKDB_TYPE_BIGNUM; case LogicalTypeId::INTERVAL: return DUCKDB_TYPE_INTERVAL; case LogicalTypeId::DECIMAL: diff --git a/src/duckdb/src/main/capi/result-c.cpp b/src/duckdb/src/main/capi/result-c.cpp index f2ab54e37..247b9f67f 100644 --- a/src/duckdb/src/main/capi/result-c.cpp +++ b/src/duckdb/src/main/capi/result-c.cpp @@ -443,6 +443,18 @@ duckdb_logical_type duckdb_column_logical_type(duckdb_result *result, idx_t col) return reinterpret_cast(new duckdb::LogicalType(result_data.result->types[col])); } +duckdb_arrow_options duckdb_result_get_arrow_options(duckdb_result *result) { + if (!result) { + return nullptr; + } + auto &result_data = *(reinterpret_cast(result->internal_data)); + if (!result_data.result) { + return nullptr; + } + auto arrow_options_wrapper = new duckdb::CClientArrowOptionsWrapper(result_data.result->client_properties); + return reinterpret_cast(arrow_options_wrapper); +} + idx_t duckdb_column_count(duckdb_result *result) { if (!result) { return 0; diff --git a/src/duckdb/src/main/capi/scalar_function-c.cpp b/src/duckdb/src/main/capi/scalar_function-c.cpp index f1189532f..eba21dbfc 100644 --- a/src/duckdb/src/main/capi/scalar_function-c.cpp +++ b/src/duckdb/src/main/capi/scalar_function-c.cpp @@ -148,6 +148,7 @@ void CAPIScalarFunction(DataChunk &input, ExpressionState &state, Vector &result } // namespace duckdb +using duckdb::ExpressionWrapper; using duckdb::GetCScalarFunction; using duckdb::GetCScalarFunctionBindInfo; using duckdb::GetCScalarFunctionInfo; @@ -270,6 +271,24 @@ void duckdb_scalar_function_bind_set_error(duckdb_bind_info info, const char *er bind_info.success = false; } +idx_t duckdb_scalar_function_bind_get_argument_count(duckdb_bind_info info) { + if (!info) { + return 0; + } + auto &bind_info = GetCScalarFunctionBindInfo(info); + return bind_info.arguments.size(); +} + +duckdb_expression duckdb_scalar_function_bind_get_argument(duckdb_bind_info info, idx_t index) { + if (!info || index >= duckdb_scalar_function_bind_get_argument_count(info)) { + return nullptr; + } + auto &bind_info = GetCScalarFunctionBindInfo(info); + auto wrapper = new ExpressionWrapper(); + wrapper->expr = bind_info.arguments[index]->Copy(); + return reinterpret_cast(wrapper); +} + void duckdb_scalar_function_set_extra_info(duckdb_scalar_function function, void *extra_info, duckdb_delete_callback_t destroy) { if (!function || !extra_info) { diff --git a/src/duckdb/src/main/capi/table_function-c.cpp b/src/duckdb/src/main/capi/table_function-c.cpp index b0e284d3a..268d9a04a 100644 --- 
a/src/duckdb/src/main/capi/table_function-c.cpp +++ b/src/duckdb/src/main/capi/table_function-c.cpp @@ -368,6 +368,15 @@ void *duckdb_bind_get_extra_info(duckdb_bind_info info) { return bind_info.function_info.extra_info; } +void duckdb_table_function_get_client_context(duckdb_bind_info info, duckdb_client_context *out_context) { + if (!info || !out_context) { + return; + } + auto &bind_info = GetCTableFunctionBindInfo(info); + auto wrapper = new duckdb::CClientContextWrapper(bind_info.context); + *out_context = reinterpret_cast(wrapper); +} + void duckdb_bind_add_result_column(duckdb_bind_info info, const char *name, duckdb_logical_type type) { if (!info || !name || !type) { return; diff --git a/src/duckdb/src/main/client_config.cpp b/src/duckdb/src/main/client_config.cpp index c3350a6be..97df5c6a9 100644 --- a/src/duckdb/src/main/client_config.cpp +++ b/src/duckdb/src/main/client_config.cpp @@ -4,6 +4,27 @@ namespace duckdb { +bool ClientConfig::AnyVerification() const { + return query_verification_enabled || verify_external || verify_serializer || verify_fetch_row; +} + +void ClientConfig::SetUserVariable(const string &name, Value value) { + user_variables[name] = std::move(value); +} + +bool ClientConfig::GetUserVariable(const string &name, Value &result) { + auto entry = user_variables.find(name); + if (entry == user_variables.end()) { + return false; + } + result = entry->second; + return true; +} + +void ClientConfig::ResetUserVariable(const string &name) { + user_variables.erase(name); +} + void ClientConfig::SetDefaultStreamingBufferSize() { auto memory = FileSystem::GetAvailableMemory(); auto default_size = ClientConfig().streaming_buffer_size; diff --git a/src/duckdb/src/main/client_context.cpp b/src/duckdb/src/main/client_context.cpp index ab489c4df..c4d98192c 100644 --- a/src/duckdb/src/main/client_context.cpp +++ b/src/duckdb/src/main/client_context.cpp @@ -44,6 +44,8 @@ #include "duckdb/transaction/transaction_context.hpp" #include "duckdb/transaction/transaction_manager.hpp" #include "duckdb/logging/log_type.hpp" +#include "duckdb/logging/log_manager.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -822,8 +824,7 @@ unique_ptr ClientContext::PendingStatementOrPreparedStatemen statement = statement->Copy(); } #endif - // check if we are on AutoCommit. In this case we should start a transaction. 
- if (statement && config.AnyVerification()) { + if (statement && config.query_verification_enabled) { // query verification is enabled // create a copy of the statement, and use the copy // this way we verify that the copy correctly copies all properties @@ -985,6 +986,12 @@ unique_ptr ClientContext::Query(const string &query, bool allow_str } else { current_result = ExecutePendingQueryInternal(*lock, *pending_query); } + if (current_result->HasError()) { + // Reset the interrupted flag, this was set by the task that found the error + // Next statements should not be bothered by that interruption + interrupted = false; + return current_result; + } // now append the result to the list of results if (!last_result || !last_had_result) { // first result of the query @@ -1001,12 +1008,6 @@ unique_ptr ClientContext::Query(const string &query, bool allow_str last_result = last_result->next.get(); } D_ASSERT(last_result); - if (last_result->HasError()) { - // Reset the interrupted flag, this was set by the task that found the error - // Next statements should not be bothered by that interruption - interrupted = false; - break; - } } return result; } @@ -1349,7 +1350,7 @@ SettingLookupResult ClientContext::TryGetCurrentSetting(const std::string &key, // first check the built-in settings auto &db_config = DBConfig::GetConfig(*this); auto option = db_config.GetOptionByName(key); - if (option) { + if (option && option->get_setting) { result = option->get_setting(*this); return SettingLookupResult(SettingScope::LOCAL); } @@ -1370,8 +1371,8 @@ SettingLookupResult ClientContext::TryGetCurrentSetting(const std::string &key, ParserOptions ClientContext::GetParserOptions() const { auto &client_config = ClientConfig::GetConfig(*this); ParserOptions options; - options.preserve_identifier_case = client_config.preserve_identifier_case; - options.integer_division = client_config.integer_division; + options.preserve_identifier_case = DBConfig::GetSetting(*this); + options.integer_division = DBConfig::GetSetting(*this); options.max_expression_depth = client_config.max_expression_depth; options.extensions = &DBConfig::GetConfig(*this).parser_extensions; return options; @@ -1384,12 +1385,20 @@ ClientProperties ClientContext::GetClientProperties() { if (TryGetCurrentSetting("TimeZone", result)) { timezone = result.ToString(); } + ArrowOffsetSize arrow_offset_size = ArrowOffsetSize::REGULAR; + if (DBConfig::GetSetting(*this)) { + arrow_offset_size = ArrowOffsetSize::LARGE; + } + bool arrow_use_list_view = DBConfig::GetSetting(*this); + bool arrow_lossless_conversion = DBConfig::GetSetting(*this); + bool arrow_use_string_view = DBConfig::GetSetting(*this); + auto arrow_format_version = DBConfig::GetSetting(*this); return {timezone, - db->config.options.arrow_offset_size, - db->config.options.arrow_use_list_view, - db->config.options.produce_arrow_string_views, - db->config.options.arrow_lossless_conversion, - db->config.options.arrow_output_version, + arrow_offset_size, + arrow_use_list_view, + arrow_use_string_view, + arrow_lossless_conversion, + arrow_format_version, this}; } diff --git a/src/duckdb/src/main/client_context_file_opener.cpp b/src/duckdb/src/main/client_context_file_opener.cpp index 24ed2aa66..44397f079 100644 --- a/src/duckdb/src/main/client_context_file_opener.cpp +++ b/src/duckdb/src/main/client_context_file_opener.cpp @@ -4,6 +4,7 @@ #include "duckdb/main/database.hpp" #include "duckdb/common/file_opener.hpp" #include "duckdb/main/client_context.hpp" +#include "duckdb/logging/log_manager.hpp" 
namespace duckdb { diff --git a/src/duckdb/src/main/client_data.cpp b/src/duckdb/src/main/client_data.cpp index 3348435b9..1348c0b09 100644 --- a/src/duckdb/src/main/client_data.cpp +++ b/src/duckdb/src/main/client_data.cpp @@ -155,9 +155,9 @@ class ClientBufferManager : public BufferManager { void WriteTemporaryBuffer(MemoryTag tag, block_id_t block_id, FileBuffer &buffer) override { return buffer_manager.WriteTemporaryBuffer(tag, block_id, buffer); } - unique_ptr ReadTemporaryBuffer(MemoryTag tag, BlockHandle &block, + unique_ptr ReadTemporaryBuffer(QueryContext context, MemoryTag tag, BlockHandle &block, unique_ptr buffer) override { - return buffer_manager.ReadTemporaryBuffer(tag, block, std::move(buffer)); + return buffer_manager.ReadTemporaryBuffer(context, tag, block, std::move(buffer)); } void DeleteTemporaryFile(BlockHandle &block) override { return buffer_manager.DeleteTemporaryFile(block); diff --git a/src/duckdb/src/main/client_verify.cpp b/src/duckdb/src/main/client_verify.cpp index d3fb4f380..05b190b07 100644 --- a/src/duckdb/src/main/client_verify.cpp +++ b/src/duckdb/src/main/client_verify.cpp @@ -148,12 +148,20 @@ ErrorData ClientContext::VerifyQuery(ClientContextLock &lock, const string &quer auto original_named_param_map = statement_copy_for_explain->named_param_map; auto explain_stmt = make_uniq(std::move(statement_copy_for_explain)); explain_stmt->named_param_map = original_named_param_map; - try { - RunStatementInternal(lock, explain_q, std::move(explain_stmt), false, parameters, false); - } catch (std::exception &ex) { // LCOV_EXCL_START - ErrorData error(ex); - interrupted = false; - return ErrorData("EXPLAIN failed but query did not (" + error.RawMessage() + ")"); + + auto explain_statement_verifier = + StatementVerifier::Create(VerificationType::EXPLAIN, *explain_stmt, parameters); + const auto explain_failed = explain_statement_verifier->Run( + *this, explain_q, + [&](const string &q, unique_ptr s, + optional_ptr> params) { + return RunStatementInternal(lock, q, std::move(s), false, params, false); + }); + + if (explain_failed) { // LCOV_EXCL_START + const auto &explain_error = explain_statement_verifier->materialized_result->error; + return ErrorData(explain_error.Type(), StringUtil::Format("Query succeeded but EXPLAIN failed with: %s", + explain_error.RawMessage())); } // LCOV_EXCL_STOP #ifdef DUCKDB_VERIFY_BOX_RENDERER diff --git a/src/duckdb/src/main/config.cpp b/src/duckdb/src/main/config.cpp index 85493d06a..1c9c1dae1 100644 --- a/src/duckdb/src/main/config.cpp +++ b/src/duckdb/src/main/config.cpp @@ -22,100 +22,97 @@ namespace duckdb { bool DBConfigOptions::debug_print_bindings = false; #endif -#define DUCKDB_GLOBAL(_PARAM) \ +#define DUCKDB_SETTING(_PARAM) \ { \ - _PARAM::Name, _PARAM::Description, _PARAM::InputType, _PARAM::SetGlobal, nullptr, _PARAM::ResetGlobal, \ - nullptr, _PARAM::GetSetting \ + _PARAM::Name, _PARAM::Description, _PARAM::InputType, nullptr, nullptr, nullptr, nullptr, nullptr, \ + _PARAM::DefaultScope, _PARAM::DefaultValue, nullptr \ } -#define DUCKDB_GLOBAL_ALIAS(_ALIAS, _PARAM) \ +#define DUCKDB_SETTING_CALLBACK(_PARAM) \ { \ - _ALIAS, _PARAM::Description, _PARAM::InputType, _PARAM::SetGlobal, nullptr, _PARAM::ResetGlobal, nullptr, \ - _PARAM::GetSetting \ + _PARAM::Name, _PARAM::Description, _PARAM::InputType, nullptr, nullptr, nullptr, nullptr, nullptr, \ + _PARAM::DefaultScope, _PARAM::DefaultValue, _PARAM::OnSet \ } - -#define DUCKDB_LOCAL(_PARAM) \ +#define DUCKDB_GLOBAL(_PARAM) \ { \ - _PARAM::Name, _PARAM::Description, 
_PARAM::InputType, nullptr, _PARAM::SetLocal, nullptr, _PARAM::ResetLocal, \ - _PARAM::GetSetting \ + _PARAM::Name, _PARAM::Description, _PARAM::InputType, _PARAM::SetGlobal, nullptr, _PARAM::ResetGlobal, \ + nullptr, _PARAM::GetSetting, SetScope::AUTOMATIC, nullptr, nullptr \ } -#define DUCKDB_LOCAL_ALIAS(_ALIAS, _PARAM) \ +#define DUCKDB_LOCAL(_PARAM) \ { \ - _ALIAS, _PARAM::Description, _PARAM::InputType, nullptr, _PARAM::SetLocal, nullptr, _PARAM::ResetLocal, \ - _PARAM::GetSetting \ + _PARAM::Name, _PARAM::Description, _PARAM::InputType, nullptr, _PARAM::SetLocal, nullptr, _PARAM::ResetLocal, \ + _PARAM::GetSetting, SetScope::AUTOMATIC, nullptr, nullptr \ } - #define DUCKDB_GLOBAL_LOCAL(_PARAM) \ { \ _PARAM::Name, _PARAM::Description, _PARAM::InputType, _PARAM::SetGlobal, _PARAM::SetLocal, \ - _PARAM::ResetGlobal, _PARAM::ResetLocal, _PARAM::GetSetting \ - } -#define DUCKDB_GLOBAL_LOCAL_ALIAS(_ALIAS, _PARAM) \ - { \ - _ALIAS, _PARAM::Description, _PARAM::InputType, _PARAM::SetGlobal, _PARAM::SetLocal, _PARAM::ResetGlobal, \ - _PARAM::ResetLocal, _PARAM::GetSetting \ + _PARAM::ResetGlobal, _PARAM::ResetLocal, _PARAM::GetSetting, SetScope::AUTOMATIC, nullptr, nullptr \ } #define FINAL_SETTING \ - { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr } + { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, SetScope::AUTOMATIC, nullptr, nullptr } + +#define DUCKDB_SETTING_ALIAS(_ALIAS, _SETTING_INDEX) \ + { _ALIAS, _SETTING_INDEX } +#define FINAL_ALIAS \ + { nullptr, 0 } static const ConfigurationOption internal_options[] = { + DUCKDB_GLOBAL(AccessModeSetting), DUCKDB_GLOBAL(AllocatorBackgroundThreadsSetting), DUCKDB_GLOBAL(AllocatorBulkDeallocationFlushThresholdSetting), DUCKDB_GLOBAL(AllocatorFlushThresholdSetting), DUCKDB_GLOBAL(AllowCommunityExtensionsSetting), - DUCKDB_GLOBAL(AllowExtensionsMetadataMismatchSetting), + DUCKDB_SETTING(AllowExtensionsMetadataMismatchSetting), DUCKDB_GLOBAL(AllowPersistentSecretsSetting), DUCKDB_GLOBAL(AllowUnredactedSecretsSetting), DUCKDB_GLOBAL(AllowUnsignedExtensionsSetting), DUCKDB_GLOBAL(AllowedDirectoriesSetting), DUCKDB_GLOBAL(AllowedPathsSetting), - DUCKDB_GLOBAL(ArrowLargeBufferSizeSetting), - DUCKDB_GLOBAL(ArrowLosslessConversionSetting), - DUCKDB_GLOBAL(ArrowOutputListViewSetting), - DUCKDB_GLOBAL(ArrowOutputVersionSetting), - DUCKDB_LOCAL(AsofLoopJoinThresholdSetting), + DUCKDB_SETTING(ArrowLargeBufferSizeSetting), + DUCKDB_SETTING(ArrowLosslessConversionSetting), + DUCKDB_SETTING(ArrowOutputListViewSetting), + DUCKDB_SETTING_CALLBACK(ArrowOutputVersionSetting), + DUCKDB_SETTING(AsofLoopJoinThresholdSetting), DUCKDB_GLOBAL(AutoinstallExtensionRepositorySetting), DUCKDB_GLOBAL(AutoinstallKnownExtensionsSetting), DUCKDB_GLOBAL(AutoloadKnownExtensionsSetting), - DUCKDB_GLOBAL(CatalogErrorMaxSchemasSetting), + DUCKDB_SETTING(CatalogErrorMaxSchemasSetting), DUCKDB_GLOBAL(CheckpointThresholdSetting), - DUCKDB_GLOBAL_ALIAS("wal_autocheckpoint", CheckpointThresholdSetting), DUCKDB_GLOBAL(CustomExtensionRepositorySetting), DUCKDB_LOCAL(CustomProfilingSettingsSetting), DUCKDB_GLOBAL(CustomUserAgentSetting), - DUCKDB_LOCAL(DebugAsofIejoinSetting), - DUCKDB_GLOBAL(DebugCheckpointAbortSetting), + DUCKDB_SETTING(DebugAsofIejoinSetting), + DUCKDB_SETTING_CALLBACK(DebugCheckpointAbortSetting), DUCKDB_LOCAL(DebugForceExternalSetting), - DUCKDB_LOCAL(DebugForceNoCrossProductSetting), - DUCKDB_GLOBAL(DebugSkipCheckpointOnCommitSetting), - DUCKDB_GLOBAL(DebugVerifyVectorSetting), - DUCKDB_GLOBAL(DebugWindowModeSetting), + 
DUCKDB_SETTING(DebugForceNoCrossProductSetting), + DUCKDB_SETTING(DebugSkipCheckpointOnCommitSetting), + DUCKDB_SETTING_CALLBACK(DebugVerifyVectorSetting), + DUCKDB_SETTING_CALLBACK(DebugWindowModeSetting), DUCKDB_GLOBAL(DefaultBlockSizeSetting), - DUCKDB_GLOBAL_LOCAL(DefaultCollationSetting), - DUCKDB_GLOBAL(DefaultNullOrderSetting), - DUCKDB_GLOBAL_ALIAS("null_order", DefaultNullOrderSetting), - DUCKDB_GLOBAL(DefaultOrderSetting), + DUCKDB_SETTING_CALLBACK(DefaultCollationSetting), + DUCKDB_SETTING_CALLBACK(DefaultNullOrderSetting), + DUCKDB_SETTING_CALLBACK(DefaultOrderSetting), DUCKDB_GLOBAL(DefaultSecretStorageSetting), DUCKDB_GLOBAL(DisableDatabaseInvalidationSetting), - DUCKDB_LOCAL(DisableTimestamptzCastsSetting), + DUCKDB_SETTING(DisableTimestamptzCastsSetting), DUCKDB_GLOBAL(DisabledCompressionMethodsSetting), DUCKDB_GLOBAL(DisabledFilesystemsSetting), DUCKDB_GLOBAL(DisabledLogTypes), DUCKDB_GLOBAL(DisabledOptimizersSetting), DUCKDB_GLOBAL(DuckDBAPISetting), - DUCKDB_LOCAL(DynamicOrFilterThresholdSetting), + DUCKDB_SETTING(DynamicOrFilterThresholdSetting), DUCKDB_GLOBAL(EnableExternalAccessSetting), DUCKDB_GLOBAL(EnableExternalFileCacheSetting), - DUCKDB_GLOBAL(EnableFSSTVectorsSetting), + DUCKDB_SETTING(EnableFSSTVectorsSetting), DUCKDB_LOCAL(EnableHTTPLoggingSetting), DUCKDB_GLOBAL(EnableHTTPMetadataCacheSetting), DUCKDB_GLOBAL(EnableLogging), - DUCKDB_GLOBAL(EnableMacroDependenciesSetting), - DUCKDB_GLOBAL(EnableObjectCacheSetting), + DUCKDB_SETTING(EnableMacroDependenciesSetting), + DUCKDB_SETTING(EnableObjectCacheSetting), DUCKDB_LOCAL(EnableProfilingSetting), DUCKDB_LOCAL(EnableProgressBarSetting), DUCKDB_LOCAL(EnableProgressBarPrintSetting), - DUCKDB_GLOBAL(EnableViewDependenciesSetting), + DUCKDB_SETTING(EnableViewDependenciesSetting), DUCKDB_GLOBAL(EnabledLogTypes), DUCKDB_LOCAL(ErrorsAsJSONSetting), DUCKDB_LOCAL(ExplainOutputSetting), @@ -130,12 +127,12 @@ static const ConfigurationOption internal_options[] = { DUCKDB_GLOBAL(HTTPProxyPasswordSetting), DUCKDB_GLOBAL(HTTPProxyUsernameSetting), DUCKDB_LOCAL(IEEEFloatingPointOpsSetting), - DUCKDB_GLOBAL(ImmediateTransactionModeSetting), - DUCKDB_GLOBAL(IndexScanMaxCountSetting), - DUCKDB_GLOBAL(IndexScanPercentageSetting), - DUCKDB_LOCAL(IntegerDivisionSetting), + DUCKDB_SETTING(ImmediateTransactionModeSetting), + DUCKDB_SETTING(IndexScanMaxCountSetting), + DUCKDB_SETTING_CALLBACK(IndexScanPercentageSetting), + DUCKDB_SETTING(IntegerDivisionSetting), DUCKDB_LOCAL(LambdaSyntaxSetting), - DUCKDB_LOCAL(LateMaterializationMaxRowsSetting), + DUCKDB_SETTING(LateMaterializationMaxRowsSetting), DUCKDB_GLOBAL(LockConfigurationSetting), DUCKDB_LOCAL(LogQueryPathSetting), DUCKDB_GLOBAL(LoggingLevel), @@ -143,31 +140,29 @@ static const ConfigurationOption internal_options[] = { DUCKDB_GLOBAL(LoggingStorage), DUCKDB_LOCAL(MaxExpressionDepthSetting), DUCKDB_GLOBAL(MaxMemorySetting), - DUCKDB_GLOBAL_ALIAS("memory_limit", MaxMemorySetting), DUCKDB_GLOBAL(MaxTempDirectorySizeSetting), - DUCKDB_GLOBAL(MaxVacuumTasksSetting), - DUCKDB_LOCAL(MergeJoinThresholdSetting), - DUCKDB_LOCAL(NestedLoopJoinThresholdSetting), - DUCKDB_GLOBAL(OldImplicitCastingSetting), - DUCKDB_LOCAL(OrderByNonIntegerLiteralSetting), + DUCKDB_SETTING(MaxVacuumTasksSetting), + DUCKDB_SETTING(MergeJoinThresholdSetting), + DUCKDB_SETTING(NestedLoopJoinThresholdSetting), + DUCKDB_SETTING(OldImplicitCastingSetting), + DUCKDB_SETTING(OrderByNonIntegerLiteralSetting), DUCKDB_LOCAL(OrderedAggregateThresholdSetting), - DUCKDB_LOCAL(PartitionedWriteFlushThresholdSetting), - 
DUCKDB_LOCAL(PartitionedWriteMaxOpenFilesSetting), + DUCKDB_SETTING(PartitionedWriteFlushThresholdSetting), + DUCKDB_SETTING(PartitionedWriteMaxOpenFilesSetting), DUCKDB_GLOBAL(PasswordSetting), DUCKDB_LOCAL(PerfectHtThresholdSetting), DUCKDB_GLOBAL(PinThreadsSetting), - DUCKDB_LOCAL(PivotFilterThresholdSetting), - DUCKDB_LOCAL(PivotLimitSetting), - DUCKDB_LOCAL(PreferRangeJoinsSetting), - DUCKDB_LOCAL(PreserveIdentifierCaseSetting), - DUCKDB_GLOBAL(PreserveInsertionOrderSetting), - DUCKDB_GLOBAL(ProduceArrowStringViewSetting), + DUCKDB_SETTING(PivotFilterThresholdSetting), + DUCKDB_SETTING(PivotLimitSetting), + DUCKDB_SETTING(PreferRangeJoinsSetting), + DUCKDB_SETTING(PreserveIdentifierCaseSetting), + DUCKDB_SETTING(PreserveInsertionOrderSetting), + DUCKDB_SETTING(ProduceArrowStringViewSetting), DUCKDB_LOCAL(ProfileOutputSetting), - DUCKDB_LOCAL_ALIAS("profiling_output", ProfileOutputSetting), DUCKDB_LOCAL(ProfilingCoverageSetting), DUCKDB_LOCAL(ProfilingModeSetting), DUCKDB_LOCAL(ProgressBarTimeSetting), - DUCKDB_LOCAL(ScalarSubqueryErrorOnMultipleRowsSetting), + DUCKDB_SETTING(ScalarSubqueryErrorOnMultipleRowsSetting), DUCKDB_GLOBAL(SchedulerProcessPartialSetting), DUCKDB_LOCAL(SchemaSetting), DUCKDB_LOCAL(SearchPathSetting), @@ -177,14 +172,19 @@ static const ConfigurationOption internal_options[] = { DUCKDB_GLOBAL(TempDirectorySetting), DUCKDB_GLOBAL(TempFileEncryptionSetting), DUCKDB_GLOBAL(ThreadsSetting), - DUCKDB_GLOBAL_ALIAS("worker_threads", ThreadsSetting), DUCKDB_GLOBAL(UsernameSetting), - DUCKDB_GLOBAL_ALIAS("user", UsernameSetting), - DUCKDB_GLOBAL(VariantLegacyEncodingSetting), DUCKDB_GLOBAL(WalEncryptionSetting), DUCKDB_GLOBAL(ZstdMinStringLengthSetting), FINAL_SETTING}; +static const ConfigurationAlias setting_aliases[] = {DUCKDB_SETTING_ALIAS("memory_limit", 82), + DUCKDB_SETTING_ALIAS("null_order", 33), + DUCKDB_SETTING_ALIAS("profiling_output", 101), + DUCKDB_SETTING_ALIAS("user", 115), + DUCKDB_SETTING_ALIAS("wal_autocheckpoint", 20), + DUCKDB_SETTING_ALIAS("worker_threads", 114), + FINAL_ALIAS}; + vector DBConfig::GetOptions() { vector options; for (idx_t index = 0; internal_options[index].name; index++) { @@ -193,29 +193,45 @@ vector DBConfig::GetOptions() { return options; } +SettingCallbackInfo::SettingCallbackInfo(ClientContext &context_p, SetScope scope) + : config(DBConfig::GetConfig(context_p)), db(context_p.db.get()), context(context_p), scope(scope) { +} + +SettingCallbackInfo::SettingCallbackInfo(DBConfig &config, optional_ptr db) + : config(config), db(db), context(nullptr), scope(SetScope::GLOBAL) { +} + idx_t DBConfig::GetOptionCount() { - idx_t count = 0; - for (idx_t index = 0; internal_options[index].name; index++) { - count++; - } - return count; + return sizeof(internal_options) / sizeof(ConfigurationOption) - 1; +} + +idx_t DBConfig::GetAliasCount() { + return sizeof(setting_aliases) / sizeof(ConfigurationAlias) - 1; } -vector DBConfig::GetOptionNames() { +vector DBConfig::GetOptionNames() { vector names; - for (idx_t i = 0, option_count = DBConfig::GetOptionCount(); i < option_count; i++) { - names.emplace_back(DBConfig::GetOptionByIndex(i)->name); + for (idx_t index = 0; internal_options[index].name; index++) { + names.emplace_back(internal_options[index].name); + } + for (idx_t index = 0; setting_aliases[index].alias; index++) { + names.emplace_back(setting_aliases[index].alias); } return names; } optional_ptr DBConfig::GetOptionByIndex(idx_t target_index) { - for (idx_t index = 0; internal_options[index].name; index++) { - if (index == 
target_index) { - return internal_options + index; - } + if (target_index >= GetOptionCount()) { + return nullptr; } - return nullptr; + return internal_options + target_index; +} + +optional_ptr DBConfig::GetAliasByIndex(idx_t target_index) { + if (target_index >= GetAliasCount()) { + return nullptr; + } + return setting_aliases + target_index; } optional_ptr DBConfig::GetOptionByName(const string &name) { @@ -226,6 +242,12 @@ optional_ptr DBConfig::GetOptionByName(const string & return internal_options + index; } } + for (idx_t index = 0; setting_aliases[index].alias; index++) { + D_ASSERT(StringUtil::Lower(internal_options[index].name) == string(internal_options[index].name)); + if (setting_aliases[index].alias == lname) { + return GetOptionByIndex(setting_aliases[index].option_index); + } + } return nullptr; } @@ -261,23 +283,37 @@ void DBConfig::SetOptionsByName(const case_insensitive_map_t &values) { } } -void DBConfig::SetOption(DatabaseInstance *db, const ConfigurationOption &option, const Value &value) { +void DBConfig::SetOption(optional_ptr db, const ConfigurationOption &option, const Value &value) { lock_guard l(config_lock); + Value input = value.DefaultCastAs(ParseLogicalType(option.parameter_type)); + if (option.default_value) { + // generic option + if (option.set_callback) { + SettingCallbackInfo info(*this, db); + option.set_callback(info, input); + } + options.set_variables.emplace(option.name, std::move(input)); + return; + } if (!option.set_global) { throw InvalidInputException("Could not set option \"%s\" as a global option", option.name); } D_ASSERT(option.reset_global); - Value input = value.DefaultCastAs(ParseLogicalType(option.parameter_type)); - option.set_global(db, *this, input); + option.set_global(db.get(), *this, input); } -void DBConfig::ResetOption(DatabaseInstance *db, const ConfigurationOption &option) { +void DBConfig::ResetOption(optional_ptr db, const ConfigurationOption &option) { lock_guard l(config_lock); + if (option.default_value) { + // generic option + options.set_variables.erase(option.name); + return; + } if (!option.reset_global) { throw InternalException("Could not reset option \"%s\" as a global option", option.name); } D_ASSERT(option.set_global); - option.reset_global(db, *this); + option.reset_global(db.get(), *this); } void DBConfig::SetOption(const string &name, Value value) { @@ -299,6 +335,11 @@ void DBConfig::ResetOption(const string &name) { } } +void DBConfig::ResetGenericOption(const string &name) { + lock_guard l(config_lock); + options.set_variables.erase(name); +} + LogicalType DBConfig::ParseLogicalType(const string &type) { if (StringUtil::EndsWith(type, "[]")) { // list - recurse @@ -390,9 +431,9 @@ LogicalType DBConfig::ParseLogicalType(const string &type) { } void DBConfig::AddExtensionOption(const string &name, string description, LogicalType parameter, - const Value &default_value, set_option_callback_t function) { - extension_parameters.insert( - make_pair(name, ExtensionOption(std::move(description), std::move(parameter), function, default_value))); + const Value &default_value, set_option_callback_t function, SetScope default_scope) { + extension_parameters.insert(make_pair( + name, ExtensionOption(std::move(description), std::move(parameter), function, default_value, default_scope))); // copy over unrecognized options, if they match the new extension option auto iter = options.unrecognized_options.find(name); if (iter != options.unrecognized_options.end()) { @@ -635,18 +676,57 @@ bool DBConfig::operator!=(const 
DBConfig &other) { return !(other.options == options); } -OrderType DBConfig::ResolveOrder(OrderType order_type) const { +OrderType DBConfig::ResolveOrder(ClientContext &context, OrderType order_type) const { if (order_type != OrderType::ORDER_DEFAULT) { return order_type; } - return options.default_order_type; + return GetSetting(context); +} + +Value DBConfig::GetSettingInternal(const ClientContext &context, const char *setting, const char *default_value) { + Value result_val; + if (context.TryGetCurrentSetting(setting, result_val)) { + return result_val; + } + return Value(default_value); +} + +Value DBConfig::GetSettingInternal(const DBConfig &config, const char *setting, const char *default_value) { + Value result_val; + if (config.TryGetCurrentSetting(setting, result_val)) { + return result_val; + } + return Value(default_value); +} + +Value DBConfig::GetSettingInternal(const DatabaseInstance &db, const char *setting, const char *default_value) { + return GetSettingInternal(DBConfig::GetConfig(db), setting, default_value); +} + +SettingLookupResult DBConfig::TryGetCurrentSetting(const string &key, Value &result) const { + const auto &global_config_map = options.set_variables; + + auto global_value = global_config_map.find(key); + if (global_value != global_config_map.end()) { + result = global_value->second; + return SettingLookupResult(SettingScope::GLOBAL); + } + auto option = GetOptionByName(key); + if (option && option->default_value) { + auto input_type = ParseLogicalType(option->parameter_type); + result = Value(option->default_value).DefaultCastAs(input_type); + return SettingLookupResult(SettingScope::GLOBAL); + } + return SettingLookupResult(); } -OrderByNullType DBConfig::ResolveNullOrder(OrderType order_type, OrderByNullType null_type) const { +OrderByNullType DBConfig::ResolveNullOrder(ClientContext &context, OrderType order_type, + OrderByNullType null_type) const { if (null_type != OrderByNullType::ORDER_DEFAULT) { return null_type; } - switch (options.default_null_order) { + auto null_order = GetSetting(context); + switch (null_order) { case DefaultOrderByNullType::NULLS_FIRST: return OrderByNullType::NULLS_FIRST; case DefaultOrderByNullType::NULLS_LAST: diff --git a/src/duckdb/src/main/database.cpp b/src/duckdb/src/main/database.cpp index 493e56516..75b08e22d 100644 --- a/src/duckdb/src/main/database.cpp +++ b/src/duckdb/src/main/database.cpp @@ -503,18 +503,10 @@ bool DuckDB::ExtensionIsLoaded(const std::string &name) { return instance->ExtensionIsLoaded(name); } -SettingLookupResult DatabaseInstance::TryGetCurrentSetting(const std::string &key, Value &result) const { +SettingLookupResult DatabaseInstance::TryGetCurrentSetting(const string &key, Value &result) const { // check the session values auto &db_config = DBConfig::GetConfig(*this); - const auto &global_config_map = db_config.options.set_variables; - - auto global_value = global_config_map.find(key); - bool found_global_value = global_value != global_config_map.end(); - if (!found_global_value) { - return SettingLookupResult(); - } - result = global_value->second; - return SettingLookupResult(SettingScope::GLOBAL); + return db_config.TryGetCurrentSetting(key, result); } shared_ptr DatabaseInstance::GetEncryptionUtil() const { diff --git a/src/duckdb/src/main/database_manager.cpp b/src/duckdb/src/main/database_manager.cpp index 88e2c6dc6..f3b1599a2 100644 --- a/src/duckdb/src/main/database_manager.cpp +++ b/src/duckdb/src/main/database_manager.cpp @@ -30,8 +30,8 @@ void 
DatabaseManager::InitializeSystemCatalog() { } void DatabaseManager::FinalizeStartup() { - auto databases = GetDatabases(); - for (auto &db : databases) { + auto dbs = GetDatabases(); + for (auto &db : dbs) { db.get().FinalizeLoad(nullptr); } } @@ -113,40 +113,34 @@ void DatabaseManager::DetachDatabase(ClientContext &context, const string &name, } } -optional_ptr DatabaseManager::GetDatabaseFromPath(ClientContext &context, const string &path) { - auto database_list = GetDatabases(context); - for (auto &db_ref : database_list) { - auto &db = db_ref.get(); - if (db.IsSystem()) { - continue; - } - auto &catalog = Catalog::GetCatalog(db); - if (catalog.InMemory()) { - continue; - } - auto db_path = catalog.GetDBPath(); - if (StringUtil::CIEquals(path, db_path)) { - return &db; - } - } - return nullptr; -} - void DatabaseManager::CheckPathConflict(ClientContext &context, const string &path) { - // ensure that we did not already attach a database with the same path - bool path_exists; + // Ensure that we did not already attach a database with the same path. + string db_name = ""; { lock_guard path_lock(db_paths_lock); - path_exists = db_paths.find(path) != db_paths.end(); - } - if (path_exists) { - // check that the database is actually still attached - auto entry = GetDatabaseFromPath(context, path); - if (entry) { - throw BinderException("Unique file handle conflict: Database \"%s\" is already attached with path \"%s\", ", - entry->name, path); + auto it = db_paths_to_name.find(path); + if (it != db_paths_to_name.end()) { + db_name = it->second; } } + if (db_name.empty()) { + return; + } + + // Check against the catalog set. + auto entry = GetDatabase(context, db_name); + if (!entry) { + return; + } + if (entry->IsSystem()) { + return; + } + auto &catalog = Catalog::GetCatalog(*entry); + if (catalog.InMemory()) { + return; + } + throw BinderException("Unique file handle conflict: Database \"%s\" is already attached with path \"%s\", ", + db_name, path); } void DatabaseManager::InsertDatabasePath(ClientContext &context, const string &path, const string &name) { @@ -156,7 +150,7 @@ void DatabaseManager::InsertDatabasePath(ClientContext &context, const string &p CheckPathConflict(context, path); lock_guard path_lock(db_paths_lock); - db_paths.insert(path); + db_paths_to_name[path] = name; } void DatabaseManager::EraseDatabasePath(const string &path) { @@ -164,17 +158,17 @@ void DatabaseManager::EraseDatabasePath(const string &path) { return; } lock_guard path_lock(db_paths_lock); - auto path_it = db_paths.find(path); - if (path_it != db_paths.end()) { - db_paths.erase(path_it); + auto path_it = db_paths_to_name.find(path); + if (path_it != db_paths_to_name.end()) { + db_paths_to_name.erase(path_it); } } vector DatabaseManager::GetAttachedDatabasePaths() { lock_guard path_lock(db_paths_lock); vector paths; - for (auto &path : db_paths) { - paths.push_back(path); + for (auto &entry : db_paths_to_name) { + paths.push_back(entry.first); } return paths; } diff --git a/src/duckdb/src/main/extension/extension_helper.cpp b/src/duckdb/src/main/extension/extension_helper.cpp index a852b00da..74add5379 100644 --- a/src/duckdb/src/main/extension/extension_helper.cpp +++ b/src/duckdb/src/main/extension/extension_helper.cpp @@ -10,6 +10,7 @@ #include "duckdb/main/database.hpp" #include "duckdb/main/extension.hpp" #include "duckdb/main/extension_install_info.hpp" +#include "duckdb/main/settings.hpp" // Note that c++ preprocessor doesn't have a nice way to clean this up so we need to set the defines we use to 
false // explicitly when they are undefined @@ -212,9 +213,8 @@ bool ExtensionHelper::TryAutoLoadExtension(ClientContext &context, const string auto &dbconfig = DBConfig::GetConfig(context); try { if (dbconfig.options.autoinstall_known_extensions) { - auto &config = DBConfig::GetConfig(context); auto autoinstall_repo = ExtensionRepository::GetRepositoryByUrl( - StringValue::Get(config.GetSetting(context))); + StringValue::Get(DBConfig::GetConfig(context).options.autoinstall_extension_repo)); ExtensionInstallOptions options; options.repository = autoinstall_repo; ExtensionHelper::InstallExtension(context, extension_name, options); @@ -260,8 +260,6 @@ static ExtensionUpdateResult UpdateExtensionInternal(ClientContext &context, Dat ExtensionUpdateResult result; result.extension_name = extension_name; - auto &config = DBConfig::GetConfig(db); - if (!fs.FileExists(full_extension_path)) { result.tag = ExtensionUpdateResultTag::NOT_INSTALLED; return result; @@ -277,7 +275,7 @@ static ExtensionUpdateResult UpdateExtensionInternal(ClientContext &context, Dat // Parse the version of the extension before updating auto ext_binary_handle = fs.OpenFile(full_extension_path, FileOpenFlags::FILE_FLAGS_READ); auto parsed_metadata = ExtensionHelper::ParseExtensionMetaData(*ext_binary_handle); - if (!parsed_metadata.AppearsValid() && !config.options.allow_extensions_metadata_mismatch) { + if (!parsed_metadata.AppearsValid() && !DBConfig::GetSetting(context)) { throw IOException( "Failed to update extension: '%s', the metadata of the extension appears invalid! To resolve this, either " "reinstall the extension using 'FORCE INSTALL %s', manually remove the file '%s', or enable '" @@ -433,23 +431,6 @@ ExtensionLoadResult ExtensionHelper::LoadExtension(DuckDB &db, const std::string ExtensionLoadResult ExtensionHelper::LoadExtensionInternal(DuckDB &db, const std::string &extension, bool initial_load) { -#ifdef DUCKDB_TEST_REMOTE_INSTALL - if (!initial_load && StringUtil::Contains(DUCKDB_TEST_REMOTE_INSTALL, extension)) { - Connection con(db); - auto result = con.Query("INSTALL " + extension); - if (result->HasError()) { - result->Print(); - return ExtensionLoadResult::EXTENSION_UNKNOWN; - } - result = con.Query("LOAD " + extension); - if (result->HasError()) { - result->Print(); - return ExtensionLoadResult::EXTENSION_UNKNOWN; - } - return ExtensionLoadResult::LOADED_EXTENSION; - } -#endif - #ifdef DUCKDB_EXTENSIONS_TEST_WITH_LOADABLE // Note: weird comma's are on purpose to do easy string contains on a list of extension names if (!initial_load && StringUtil::Contains(DUCKDB_EXTENSIONS_TEST_WITH_LOADABLE, "," + extension + ",")) { diff --git a/src/duckdb/src/main/extension/extension_install.cpp b/src/duckdb/src/main/extension/extension_install.cpp index 7c7d87472..ca02a979c 100644 --- a/src/duckdb/src/main/extension/extension_install.cpp +++ b/src/duckdb/src/main/extension/extension_install.cpp @@ -10,7 +10,7 @@ #include "duckdb/main/extension_install_info.hpp" #include "duckdb/main/secret/secret.hpp" #include "duckdb/main/secret/secret_manager.hpp" - +#include "duckdb/main/settings.hpp" #include "duckdb/common/windows_undefs.hpp" #include @@ -224,7 +224,7 @@ static void CheckExtensionMetadataOnInstall(DatabaseInstance &db, void *in_buffe auto metadata_mismatch_error = parsed_metadata.GetInvalidMetadataError(); - if (!metadata_mismatch_error.empty() && !db.config.options.allow_extensions_metadata_mismatch) { + if (!metadata_mismatch_error.empty() && !DBConfig::GetSetting(db)) { throw IOException("Failed to install 
'%s'\n%s", extension_name, metadata_mismatch_error); } @@ -488,7 +488,7 @@ unique_ptr ExtensionHelper::InstallExtensionInternal(Datab if (fs.FileExists(local_extension_path) && !options.force_install) { // File exists: throw error if origin mismatches - if (options.throw_on_origin_mismatch && !db.config.options.allow_extensions_metadata_mismatch && + if (options.throw_on_origin_mismatch && !DBConfig::GetSetting(db) && fs.FileExists(local_extension_path + ".info")) { ThrowErrorOnMismatchingExtensionOrigin(fs, local_extension_path, extension_name, extension, options.repository); diff --git a/src/duckdb/src/main/extension/extension_load.cpp b/src/duckdb/src/main/extension/extension_load.cpp index 671bf1611..88eeeda37 100644 --- a/src/duckdb/src/main/extension/extension_load.cpp +++ b/src/duckdb/src/main/extension/extension_load.cpp @@ -8,6 +8,7 @@ #include "duckdb/main/error_manager.hpp" #include "duckdb/main/extension_helper.hpp" #include "duckdb/main/extension_manager.hpp" +#include "duckdb/main/settings.hpp" #include "mbedtls_wrapper.hpp" #ifndef DUCKDB_NO_THREADS @@ -404,7 +405,7 @@ bool ExtensionHelper::TryInitialLoad(DatabaseInstance &db, FileSystem &fs, const if (!signature_valid) { throw IOException(db.config.error_manager->FormatException(ErrorType::UNSIGNED_EXTENSION, filename)); } - } else if (!db.config.options.allow_extensions_metadata_mismatch) { + } else if (!DBConfig::GetSetting(db)) { if (!metadata_mismatch_error.empty()) { // Unsigned extensions AND configuration allowing unsigned extensions, loading allowed, mainly for // debugging purposes diff --git a/src/duckdb/src/main/extension_manager.cpp b/src/duckdb/src/main/extension_manager.cpp index 2d911feff..033f32ea1 100644 --- a/src/duckdb/src/main/extension_manager.cpp +++ b/src/duckdb/src/main/extension_manager.cpp @@ -2,6 +2,7 @@ #include "duckdb/main/database.hpp" #include "duckdb/planner/extension_callback.hpp" #include "duckdb/main/extension_helper.hpp" +#include "duckdb/logging/log_manager.hpp" namespace duckdb { diff --git a/src/duckdb/src/main/query_result.cpp b/src/duckdb/src/main/query_result.cpp index 70a1b6a85..a20f9a87a 100644 --- a/src/duckdb/src/main/query_result.cpp +++ b/src/duckdb/src/main/query_result.cpp @@ -62,7 +62,7 @@ QueryResult::QueryResult(QueryResultType type, StatementType statement_type, Sta QueryResult::QueryResult(QueryResultType type, ErrorData error) : BaseQueryResult(type, std::move(error)), - client_properties("UTC", ArrowOffsetSize::REGULAR, false, false, false, V1_0, nullptr) { + client_properties("UTC", ArrowOffsetSize::REGULAR, false, false, false, ArrowFormatVersion::V1_0, nullptr) { } QueryResult::~QueryResult() { diff --git a/src/duckdb/src/main/settings/autogenerated_settings.cpp b/src/duckdb/src/main/settings/autogenerated_settings.cpp index 3c867895e..1139ddb4d 100644 --- a/src/duckdb/src/main/settings/autogenerated_settings.cpp +++ b/src/duckdb/src/main/settings/autogenerated_settings.cpp @@ -78,22 +78,6 @@ Value AllowCommunityExtensionsSetting::GetSetting(const ClientContext &context) return Value::BOOLEAN(config.options.allow_community_extensions); } -//===----------------------------------------------------------------------===// -// Allow Extensions Metadata Mismatch -//===----------------------------------------------------------------------===// -void AllowExtensionsMetadataMismatchSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.allow_extensions_metadata_mismatch = input.GetValue(); -} - -void
AllowExtensionsMetadataMismatchSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.allow_extensions_metadata_mismatch = DBConfig().options.allow_extensions_metadata_mismatch; -} - -Value AllowExtensionsMetadataMismatchSetting::GetSetting(const ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BOOLEAN(config.options.allow_extensions_metadata_mismatch); -} - //===----------------------------------------------------------------------===// // Allow Unredacted Secrets //===----------------------------------------------------------------------===// @@ -138,67 +122,11 @@ Value AllowUnsignedExtensionsSetting::GetSetting(const ClientContext &context) { return Value::BOOLEAN(config.options.allow_unsigned_extensions); } -//===----------------------------------------------------------------------===// -// Arrow Large Buffer Size -//===----------------------------------------------------------------------===// -void ArrowLargeBufferSizeSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.arrow_offset_size = DBConfig().options.arrow_offset_size; -} - -//===----------------------------------------------------------------------===// -// Arrow Lossless Conversion -//===----------------------------------------------------------------------===// -void ArrowLosslessConversionSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.arrow_lossless_conversion = input.GetValue(); -} - -void ArrowLosslessConversionSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.arrow_lossless_conversion = DBConfig().options.arrow_lossless_conversion; -} - -Value ArrowLosslessConversionSetting::GetSetting(const ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BOOLEAN(config.options.arrow_lossless_conversion); -} - -//===----------------------------------------------------------------------===// -// Arrow Output List View -//===----------------------------------------------------------------------===// -void ArrowOutputListViewSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.arrow_use_list_view = input.GetValue(); -} - -void ArrowOutputListViewSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.arrow_use_list_view = DBConfig().options.arrow_use_list_view; -} - -Value ArrowOutputListViewSetting::GetSetting(const ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BOOLEAN(config.options.arrow_use_list_view); -} - //===----------------------------------------------------------------------===// // Arrow Output Version //===----------------------------------------------------------------------===// -void ArrowOutputVersionSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.arrow_output_version = DBConfig().options.arrow_output_version; -} - -//===----------------------------------------------------------------------===// -// Asof Loop Join Threshold -//===----------------------------------------------------------------------===// -void AsofLoopJoinThresholdSetting::SetLocal(ClientContext &context, const Value &input) { - auto &config = ClientConfig::GetConfig(context); - config.asof_loop_join_threshold = input.GetValue(); -} - -void AsofLoopJoinThresholdSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).asof_loop_join_threshold = 
ClientConfig().asof_loop_join_threshold; -} - -Value AsofLoopJoinThresholdSetting::GetSetting(const ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - return Value::UBIGINT(config.asof_loop_join_threshold); +void ArrowOutputVersionSetting::OnSet(SettingCallbackInfo &info, Value &parameter) { + EnumUtil::FromString(StringValue::Get(parameter)); } //===----------------------------------------------------------------------===// @@ -249,22 +177,6 @@ Value AutoloadKnownExtensionsSetting::GetSetting(const ClientContext &context) { return Value::BOOLEAN(config.options.autoload_known_extensions); } -//===----------------------------------------------------------------------===// -// Catalog Error Max Schemas -//===----------------------------------------------------------------------===// -void CatalogErrorMaxSchemasSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.catalog_error_max_schemas = input.GetValue(); -} - -void CatalogErrorMaxSchemasSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.catalog_error_max_schemas = DBConfig().options.catalog_error_max_schemas; -} - -Value CatalogErrorMaxSchemasSetting::GetSetting(const ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::UBIGINT(config.options.catalog_error_max_schemas); -} - //===----------------------------------------------------------------------===// // Checkpoint Threshold //===----------------------------------------------------------------------===// @@ -296,38 +208,11 @@ Value CustomUserAgentSetting::GetSetting(const ClientContext &context) { return Value(config.options.custom_user_agent); } -//===----------------------------------------------------------------------===// -// Debug Asof Iejoin -//===----------------------------------------------------------------------===// -void DebugAsofIejoinSetting::SetLocal(ClientContext &context, const Value &input) { - auto &config = ClientConfig::GetConfig(context); - config.force_asof_iejoin = input.GetValue(); -} - -void DebugAsofIejoinSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).force_asof_iejoin = ClientConfig().force_asof_iejoin; -} - -Value DebugAsofIejoinSetting::GetSetting(const ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - return Value::BOOLEAN(config.force_asof_iejoin); -} - //===----------------------------------------------------------------------===// // Debug Checkpoint Abort //===----------------------------------------------------------------------===// -void DebugCheckpointAbortSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - auto str_input = StringUtil::Upper(input.GetValue()); - config.options.checkpoint_abort = EnumUtil::FromString(str_input); -} - -void DebugCheckpointAbortSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.checkpoint_abort = DBConfig().options.checkpoint_abort; -} - -Value DebugCheckpointAbortSetting::GetSetting(const ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value(StringUtil::Lower(EnumUtil::ToString(config.options.checkpoint_abort))); +void DebugCheckpointAbortSetting::OnSet(SettingCallbackInfo &info, Value &parameter) { + EnumUtil::FromString(StringValue::Get(parameter)); } //===----------------------------------------------------------------------===// @@ -347,90 +232,18 @@ Value DebugForceExternalSetting::GetSetting(const ClientContext &context) {
return Value::BOOLEAN(config.force_external); } -//===----------------------------------------------------------------------===// -// Debug Force No Cross Product -//===----------------------------------------------------------------------===// -void DebugForceNoCrossProductSetting::SetLocal(ClientContext &context, const Value &input) { - auto &config = ClientConfig::GetConfig(context); - config.force_no_cross_product = input.GetValue(); -} - -void DebugForceNoCrossProductSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).force_no_cross_product = ClientConfig().force_no_cross_product; -} - -Value DebugForceNoCrossProductSetting::GetSetting(const ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - return Value::BOOLEAN(config.force_no_cross_product); -} - -//===----------------------------------------------------------------------===// -// Debug Skip Checkpoint On Commit -//===----------------------------------------------------------------------===// -void DebugSkipCheckpointOnCommitSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.debug_skip_checkpoint_on_commit = input.GetValue(); -} - -void DebugSkipCheckpointOnCommitSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.debug_skip_checkpoint_on_commit = DBConfig().options.debug_skip_checkpoint_on_commit; -} - -Value DebugSkipCheckpointOnCommitSetting::GetSetting(const ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BOOLEAN(config.options.debug_skip_checkpoint_on_commit); -} - //===----------------------------------------------------------------------===// // Debug Verify Vector //===----------------------------------------------------------------------===// -void DebugVerifyVectorSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - auto str_input = StringUtil::Upper(input.GetValue()); - config.options.debug_verify_vector = EnumUtil::FromString(str_input); -} - -void DebugVerifyVectorSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.debug_verify_vector = DBConfig().options.debug_verify_vector; -} - -Value DebugVerifyVectorSetting::GetSetting(const ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value(StringUtil::Lower(EnumUtil::ToString(config.options.debug_verify_vector))); +void DebugVerifyVectorSetting::OnSet(SettingCallbackInfo &info, Value &parameter) { + EnumUtil::FromString(StringValue::Get(parameter)); } //===----------------------------------------------------------------------===// // Debug Window Mode //===----------------------------------------------------------------------===// -void DebugWindowModeSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - auto str_input = StringUtil::Upper(input.GetValue()); - config.options.window_mode = EnumUtil::FromString(str_input); -} - -void DebugWindowModeSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.window_mode = DBConfig().options.window_mode; -} - -Value DebugWindowModeSetting::GetSetting(const ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value(StringUtil::Lower(EnumUtil::ToString(config.options.window_mode))); -} - -//===----------------------------------------------------------------------===// -// Default Null Order -//===----------------------------------------------------------------------===// -void
DefaultNullOrderSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.default_null_order = DBConfig().options.default_null_order; -} - -Value DefaultNullOrderSetting::GetSetting(const ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value(StringUtil::Lower(EnumUtil::ToString(config.options.default_null_order))); -} - -//===----------------------------------------------------------------------===// -// Default Order -//===----------------------------------------------------------------------===// -void DefaultOrderSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.default_order_type = DBConfig().options.default_order_type; +void DebugWindowModeSetting::OnSet(SettingCallbackInfo &info, Value &parameter) { + EnumUtil::FromString(StringValue::Get(parameter)); } //===----------------------------------------------------------------------===// @@ -455,40 +268,6 @@ Value DisableDatabaseInvalidationSetting::GetSetting(const ClientContext &contex return Value::BOOLEAN(config.options.disable_database_invalidation); } -//===----------------------------------------------------------------------===// -// Disable Timestamptz Casts -//===----------------------------------------------------------------------===// -void DisableTimestamptzCastsSetting::SetLocal(ClientContext &context, const Value &input) { - auto &config = ClientConfig::GetConfig(context); - config.disable_timestamptz_casts = input.GetValue(); -} - -void DisableTimestamptzCastsSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).disable_timestamptz_casts = ClientConfig().disable_timestamptz_casts; -} - -Value DisableTimestamptzCastsSetting::GetSetting(const ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - return Value::BOOLEAN(config.disable_timestamptz_casts); -} - -//===----------------------------------------------------------------------===// -// Dynamic Or Filter Threshold -//===----------------------------------------------------------------------===// -void DynamicOrFilterThresholdSetting::SetLocal(ClientContext &context, const Value &input) { - auto &config = ClientConfig::GetConfig(context); - config.dynamic_or_filter_threshold = input.GetValue(); -} - -void DynamicOrFilterThresholdSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).dynamic_or_filter_threshold = ClientConfig().dynamic_or_filter_threshold; -} - -Value DynamicOrFilterThresholdSetting::GetSetting(const ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - return Value::UBIGINT(config.dynamic_or_filter_threshold); -} - //===----------------------------------------------------------------------===// // Enable External Access //===----------------------------------------------------------------------===// @@ -511,22 +290,6 @@ Value EnableExternalAccessSetting::GetSetting(const ClientContext &context) { return Value::BOOLEAN(config.options.enable_external_access); } -//===----------------------------------------------------------------------===// -// Enable F S S T Vectors -//===----------------------------------------------------------------------===// -void EnableFSSTVectorsSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.enable_fsst_vectors = input.GetValue(); -} - -void EnableFSSTVectorsSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.enable_fsst_vectors = DBConfig().options.enable_fsst_vectors; -}
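//===----------------------------------------------------------------------===//
// Example (sketch, not part of the patch): settings migrated to
// DUCKDB_SETTING_CALLBACK validate their input in OnSet before the value is
// stored, as in the DebugCheckpointAbort, DebugVerifyVector and DebugWindowMode
// hunks above, which round-trip the string through EnumUtil::FromString.
// Assumes an in-memory database and case-sensitive enum names.
//===----------------------------------------------------------------------===//
#include "duckdb.hpp"

int main() {
    duckdb::DuckDB db(nullptr);
    duckdb::Connection con(db);
    auto ok = con.Query("SET debug_window_mode = 'WINDOW'");    // a valid WindowAggregationMode name
    auto bad = con.Query("SET debug_window_mode = 'NO_SUCH'");  // EnumUtil::FromString throws in OnSet
    // ok->HasError() is false, bad->HasError() is true; on error the stored value is unchanged
    return 0;
}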
- -Value EnableFSSTVectorsSetting::GetSetting(const ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BOOLEAN(config.options.enable_fsst_vectors); -} - //===----------------------------------------------------------------------===// // Enable H T T P Metadata Cache //===----------------------------------------------------------------------===// @@ -543,22 +306,6 @@ Value EnableHTTPMetadataCacheSetting::GetSetting(const ClientContext &context) { return Value::BOOLEAN(config.options.http_metadata_cache_enable); } -//===----------------------------------------------------------------------===// -// Enable Macro Dependencies -//===----------------------------------------------------------------------===// -void EnableMacroDependenciesSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.enable_macro_dependencies = input.GetValue(); -} - -void EnableMacroDependenciesSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.enable_macro_dependencies = DBConfig().options.enable_macro_dependencies; -} - -Value EnableMacroDependenciesSetting::GetSetting(const ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BOOLEAN(config.options.enable_macro_dependencies); -} - //===----------------------------------------------------------------------===// // Enable Progress Bar //===----------------------------------------------------------------------===// @@ -582,22 +329,6 @@ Value EnableProgressBarSetting::GetSetting(const ClientContext &context) { return Value::BOOLEAN(config.enable_progress_bar); } -//===----------------------------------------------------------------------===// -// Enable View Dependencies -//===----------------------------------------------------------------------===// -void EnableViewDependenciesSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.enable_view_dependencies = input.GetValue(); -} - -void EnableViewDependenciesSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.enable_view_dependencies = DBConfig().options.enable_view_dependencies; -} - -Value EnableViewDependenciesSetting::GetSetting(const ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BOOLEAN(config.options.enable_view_dependencies); -} - //===----------------------------------------------------------------------===// // Errors As J S O N //===----------------------------------------------------------------------===// @@ -748,91 +479,6 @@ Value IEEEFloatingPointOpsSetting::GetSetting(const ClientContext &context) { return Value::BOOLEAN(config.ieee_floating_point_ops); } -//===----------------------------------------------------------------------===// -// Immediate Transaction Mode -//===----------------------------------------------------------------------===// -void ImmediateTransactionModeSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.immediate_transaction_mode = input.GetValue(); -} - -void ImmediateTransactionModeSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.immediate_transaction_mode = DBConfig().options.immediate_transaction_mode; -} - -Value ImmediateTransactionModeSetting::GetSetting(const ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BOOLEAN(config.options.immediate_transaction_mode); -} - 
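//===----------------------------------------------------------------------===//
// Example (sketch, not part of the patch): a migrated global such as
// immediate_transaction_mode no longer has a dedicated DBConfigOptions setter;
// DBConfig::SetOption casts the value to the declared parameter type and stores
// it in options.set_variables. A programmatic round trip:
//===----------------------------------------------------------------------===//
#include "duckdb.hpp"

int main() {
    duckdb::DBConfig config;
    config.SetOptionByName("immediate_transaction_mode", duckdb::Value::BOOLEAN(true));
    duckdb::DuckDB db(nullptr, &config);
    duckdb::Connection con(db);
    // The stored value is visible through the regular setting lookup:
    con.Query("SELECT current_setting('immediate_transaction_mode')")->Print(); // true
    return 0;
}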
-//===----------------------------------------------------------------------===// -// Index Scan Max Count -//===----------------------------------------------------------------------===// -void IndexScanMaxCountSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.index_scan_max_count = input.GetValue(); -} - -void IndexScanMaxCountSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.index_scan_max_count = DBConfig().options.index_scan_max_count; -} - -Value IndexScanMaxCountSetting::GetSetting(const ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::UBIGINT(config.options.index_scan_max_count); -} - -//===----------------------------------------------------------------------===// -// Index Scan Percentage -//===----------------------------------------------------------------------===// -void IndexScanPercentageSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - if (!OnGlobalSet(db, config, input)) { - return; - } - config.options.index_scan_percentage = input.GetValue(); -} - -void IndexScanPercentageSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.index_scan_percentage = DBConfig().options.index_scan_percentage; -} - -Value IndexScanPercentageSetting::GetSetting(const ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::DOUBLE(config.options.index_scan_percentage); -} - -//===----------------------------------------------------------------------===// -// Integer Division -//===----------------------------------------------------------------------===// -void IntegerDivisionSetting::SetLocal(ClientContext &context, const Value &input) { - auto &config = ClientConfig::GetConfig(context); - config.integer_division = input.GetValue(); -} - -void IntegerDivisionSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).integer_division = ClientConfig().integer_division; -} - -Value IntegerDivisionSetting::GetSetting(const ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - return Value::BOOLEAN(config.integer_division); -} - -//===----------------------------------------------------------------------===// -// Late Materialization Max Rows -//===----------------------------------------------------------------------===// -void LateMaterializationMaxRowsSetting::SetLocal(ClientContext &context, const Value &input) { - auto &config = ClientConfig::GetConfig(context); - config.late_materialization_max_rows = input.GetValue(); -} - -void LateMaterializationMaxRowsSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).late_materialization_max_rows = ClientConfig().late_materialization_max_rows; -} - -Value LateMaterializationMaxRowsSetting::GetSetting(const ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - return Value::UBIGINT(config.late_materialization_max_rows); -} - //===----------------------------------------------------------------------===// // Lock Configuration //===----------------------------------------------------------------------===// @@ -866,89 +512,6 @@ Value MaxExpressionDepthSetting::GetSetting(const ClientContext &context) { return Value::UBIGINT(config.max_expression_depth); } -//===----------------------------------------------------------------------===// -// Max Vacuum Tasks -//===----------------------------------------------------------------------===// -void 
MaxVacuumTasksSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.max_vacuum_tasks = input.GetValue(); -} - -void MaxVacuumTasksSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.max_vacuum_tasks = DBConfig().options.max_vacuum_tasks; -} - -Value MaxVacuumTasksSetting::GetSetting(const ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::UBIGINT(config.options.max_vacuum_tasks); -} - -//===----------------------------------------------------------------------===// -// Merge Join Threshold -//===----------------------------------------------------------------------===// -void MergeJoinThresholdSetting::SetLocal(ClientContext &context, const Value &input) { - auto &config = ClientConfig::GetConfig(context); - config.merge_join_threshold = input.GetValue(); -} - -void MergeJoinThresholdSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).merge_join_threshold = ClientConfig().merge_join_threshold; -} - -Value MergeJoinThresholdSetting::GetSetting(const ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - return Value::UBIGINT(config.merge_join_threshold); -} - -//===----------------------------------------------------------------------===// -// Nested Loop Join Threshold -//===----------------------------------------------------------------------===// -void NestedLoopJoinThresholdSetting::SetLocal(ClientContext &context, const Value &input) { - auto &config = ClientConfig::GetConfig(context); - config.nested_loop_join_threshold = input.GetValue(); -} - -void NestedLoopJoinThresholdSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).nested_loop_join_threshold = ClientConfig().nested_loop_join_threshold; -} - -Value NestedLoopJoinThresholdSetting::GetSetting(const ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - return Value::UBIGINT(config.nested_loop_join_threshold); -} - -//===----------------------------------------------------------------------===// -// Old Implicit Casting -//===----------------------------------------------------------------------===// -void OldImplicitCastingSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) { - config.options.old_implicit_casting = input.GetValue(); -} - -void OldImplicitCastingSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) { - config.options.old_implicit_casting = DBConfig().options.old_implicit_casting; -} - -Value OldImplicitCastingSetting::GetSetting(const ClientContext &context) { - auto &config = DBConfig::GetConfig(context); - return Value::BOOLEAN(config.options.old_implicit_casting); -} - -//===----------------------------------------------------------------------===// -// Order By Non Integer Literal -//===----------------------------------------------------------------------===// -void OrderByNonIntegerLiteralSetting::SetLocal(ClientContext &context, const Value &input) { - auto &config = ClientConfig::GetConfig(context); - config.order_by_non_integer_literal = input.GetValue(); -} - -void OrderByNonIntegerLiteralSetting::ResetLocal(ClientContext &context) { - ClientConfig::GetConfig(context).order_by_non_integer_literal = ClientConfig().order_by_non_integer_literal; -} - -Value OrderByNonIntegerLiteralSetting::GetSetting(const ClientContext &context) { - auto &config = ClientConfig::GetConfig(context); - return Value::BOOLEAN(config.order_by_non_integer_literal); -} - 
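//===----------------------------------------------------------------------===//
// Example (sketch, not part of the patch): for the locals migrated above (e.g.
// order_by_non_integer_literal) the hand-written SetLocal/ResetLocal pair is
// replaced by a declared default scope; SET, current_setting() and RESET behave
// as before. Assumes an in-memory database.
//===----------------------------------------------------------------------===//
#include "duckdb.hpp"

int main() {
    duckdb::DuckDB db(nullptr);
    duckdb::Connection con(db);
    con.Query("SET order_by_non_integer_literal = true");
    con.Query("SELECT current_setting('order_by_non_integer_literal')")->Print(); // true
    con.Query("RESET order_by_non_integer_literal"); // falls back to the declared default
    return 0;
}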
 //===----------------------------------------------------------------------===//
 // Ordered Aggregate Threshold
 //===----------------------------------------------------------------------===//
@@ -969,41 +532,6 @@ Value OrderedAggregateThresholdSetting::GetSetting(const ClientContext &context) {
 	return Value::UBIGINT(config.ordered_aggregate_threshold);
 }
 
-//===----------------------------------------------------------------------===//
-// Partitioned Write Flush Threshold
-//===----------------------------------------------------------------------===//
-void PartitionedWriteFlushThresholdSetting::SetLocal(ClientContext &context, const Value &input) {
-	auto &config = ClientConfig::GetConfig(context);
-	config.partitioned_write_flush_threshold = input.GetValue<uint64_t>();
-}
-
-void PartitionedWriteFlushThresholdSetting::ResetLocal(ClientContext &context) {
-	ClientConfig::GetConfig(context).partitioned_write_flush_threshold =
-	    ClientConfig().partitioned_write_flush_threshold;
-}
-
-Value PartitionedWriteFlushThresholdSetting::GetSetting(const ClientContext &context) {
-	auto &config = ClientConfig::GetConfig(context);
-	return Value::UBIGINT(config.partitioned_write_flush_threshold);
-}
-
-//===----------------------------------------------------------------------===//
-// Partitioned Write Max Open Files
-//===----------------------------------------------------------------------===//
-void PartitionedWriteMaxOpenFilesSetting::SetLocal(ClientContext &context, const Value &input) {
-	auto &config = ClientConfig::GetConfig(context);
-	config.partitioned_write_max_open_files = input.GetValue<uint64_t>();
-}
-
-void PartitionedWriteMaxOpenFilesSetting::ResetLocal(ClientContext &context) {
-	ClientConfig::GetConfig(context).partitioned_write_max_open_files = ClientConfig().partitioned_write_max_open_files;
-}
-
-Value PartitionedWriteMaxOpenFilesSetting::GetSetting(const ClientContext &context) {
-	auto &config = ClientConfig::GetConfig(context);
-	return Value::UBIGINT(config.partitioned_write_max_open_files);
-}
-
 //===----------------------------------------------------------------------===//
 // Perfect Ht Threshold
 //===----------------------------------------------------------------------===//
@@ -1028,124 +556,6 @@ Value PinThreadsSetting::GetSetting(const ClientContext &context) {
 	return Value(StringUtil::Lower(EnumUtil::ToString(config.options.pin_threads)));
 }
 
-//===----------------------------------------------------------------------===//
-// Pivot Filter Threshold
-//===----------------------------------------------------------------------===//
-void PivotFilterThresholdSetting::SetLocal(ClientContext &context, const Value &input) {
-	auto &config = ClientConfig::GetConfig(context);
-	config.pivot_filter_threshold = input.GetValue<uint64_t>();
-}
-
-void PivotFilterThresholdSetting::ResetLocal(ClientContext &context) {
-	ClientConfig::GetConfig(context).pivot_filter_threshold = ClientConfig().pivot_filter_threshold;
-}
-
-Value PivotFilterThresholdSetting::GetSetting(const ClientContext &context) {
-	auto &config = ClientConfig::GetConfig(context);
-	return Value::UBIGINT(config.pivot_filter_threshold);
-}
-
-//===----------------------------------------------------------------------===//
-// Pivot Limit
-//===----------------------------------------------------------------------===//
-void PivotLimitSetting::SetLocal(ClientContext &context, const Value &input) {
-	auto &config = ClientConfig::GetConfig(context);
-	config.pivot_limit = input.GetValue<uint64_t>();
-}
-
-void PivotLimitSetting::ResetLocal(ClientContext &context) {
-	ClientConfig::GetConfig(context).pivot_limit = ClientConfig().pivot_limit;
-}
-
-Value PivotLimitSetting::GetSetting(const ClientContext &context) {
-	auto &config = ClientConfig::GetConfig(context);
-	return Value::UBIGINT(config.pivot_limit);
-}
-
-//===----------------------------------------------------------------------===//
-// Prefer Range Joins
-//===----------------------------------------------------------------------===//
-void PreferRangeJoinsSetting::SetLocal(ClientContext &context, const Value &input) {
-	auto &config = ClientConfig::GetConfig(context);
-	config.prefer_range_joins = input.GetValue<bool>();
-}
-
-void PreferRangeJoinsSetting::ResetLocal(ClientContext &context) {
-	ClientConfig::GetConfig(context).prefer_range_joins = ClientConfig().prefer_range_joins;
-}
-
-Value PreferRangeJoinsSetting::GetSetting(const ClientContext &context) {
-	auto &config = ClientConfig::GetConfig(context);
-	return Value::BOOLEAN(config.prefer_range_joins);
-}
-
-//===----------------------------------------------------------------------===//
-// Preserve Identifier Case
-//===----------------------------------------------------------------------===//
-void PreserveIdentifierCaseSetting::SetLocal(ClientContext &context, const Value &input) {
-	auto &config = ClientConfig::GetConfig(context);
-	config.preserve_identifier_case = input.GetValue<bool>();
-}
-
-void PreserveIdentifierCaseSetting::ResetLocal(ClientContext &context) {
-	ClientConfig::GetConfig(context).preserve_identifier_case = ClientConfig().preserve_identifier_case;
-}
-
-Value PreserveIdentifierCaseSetting::GetSetting(const ClientContext &context) {
-	auto &config = ClientConfig::GetConfig(context);
-	return Value::BOOLEAN(config.preserve_identifier_case);
-}
-
-//===----------------------------------------------------------------------===//
-// Preserve Insertion Order
-//===----------------------------------------------------------------------===//
-void PreserveInsertionOrderSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) {
-	config.options.preserve_insertion_order = input.GetValue<bool>();
-}
-
-void PreserveInsertionOrderSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) {
-	config.options.preserve_insertion_order = DBConfig().options.preserve_insertion_order;
-}
-
-Value PreserveInsertionOrderSetting::GetSetting(const ClientContext &context) {
-	auto &config = DBConfig::GetConfig(context);
-	return Value::BOOLEAN(config.options.preserve_insertion_order);
-}
-
-//===----------------------------------------------------------------------===//
-// Produce Arrow String View
-//===----------------------------------------------------------------------===//
-void ProduceArrowStringViewSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) {
-	config.options.produce_arrow_string_views = input.GetValue<bool>();
-}
-
-void ProduceArrowStringViewSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) {
-	config.options.produce_arrow_string_views = DBConfig().options.produce_arrow_string_views;
-}
-
-Value ProduceArrowStringViewSetting::GetSetting(const ClientContext &context) {
-	auto &config = DBConfig::GetConfig(context);
-	return Value::BOOLEAN(config.options.produce_arrow_string_views);
-}
-
-//===----------------------------------------------------------------------===//
-// Scalar Subquery Error On Multiple Rows
-//===----------------------------------------------------------------------===//
-void ScalarSubqueryErrorOnMultipleRowsSetting::SetLocal(ClientContext &context, const Value &input) {
-	auto &config = ClientConfig::GetConfig(context);
-	config.scalar_subquery_error_on_multiple_rows = input.GetValue<bool>();
-}
-
-void ScalarSubqueryErrorOnMultipleRowsSetting::ResetLocal(ClientContext &context) {
-	ClientConfig::GetConfig(context).scalar_subquery_error_on_multiple_rows =
-	    ClientConfig().scalar_subquery_error_on_multiple_rows;
-}
-
-Value ScalarSubqueryErrorOnMultipleRowsSetting::GetSetting(const ClientContext &context) {
-	auto &config = ClientConfig::GetConfig(context);
-	return Value::BOOLEAN(config.scalar_subquery_error_on_multiple_rows);
-}
-
 //===----------------------------------------------------------------------===//
 // Scheduler Process Partial
 //===----------------------------------------------------------------------===//
@@ -1162,22 +572,6 @@ Value SchedulerProcessPartialSetting::GetSetting(const ClientContext &context) {
 	return Value::BOOLEAN(config.options.scheduler_process_partial);
 }
 
-//===----------------------------------------------------------------------===//
-// Variant Legacy Encoding
-//===----------------------------------------------------------------------===//
-void VariantLegacyEncodingSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) {
-	config.options.variant_legacy_encoding = input.GetValue<bool>();
-}
-
-void VariantLegacyEncodingSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) {
-	config.options.variant_legacy_encoding = DBConfig().options.variant_legacy_encoding;
-}
-
-Value VariantLegacyEncodingSetting::GetSetting(const ClientContext &context) {
-	auto &config = DBConfig::GetConfig(context);
-	return Value::BOOLEAN(config.options.variant_legacy_encoding);
-}
-
 //===----------------------------------------------------------------------===//
 // Wal Encryption
 //===----------------------------------------------------------------------===//
diff --git a/src/duckdb/src/main/settings/custom_settings.cpp b/src/duckdb/src/main/settings/custom_settings.cpp
index aa166f702..2a78df2d3 100644
--- a/src/duckdb/src/main/settings/custom_settings.cpp
+++ b/src/duckdb/src/main/settings/custom_settings.cpp
@@ -285,71 +285,6 @@ Value AllowedPathsSetting::GetSetting(const ClientContext &context) {
 	return Value::LIST(LogicalType::VARCHAR, std::move(allowed_paths));
 }
 
-//===----------------------------------------------------------------------===//
-// Arrow Large Buffer Size
-//===----------------------------------------------------------------------===//
-void ArrowLargeBufferSizeSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) {
-	auto export_large_buffers_arrow = input.GetValue<bool>();
-	config.options.arrow_offset_size = export_large_buffers_arrow ? ArrowOffsetSize::LARGE : ArrowOffsetSize::REGULAR;
-}
-
-Value ArrowLargeBufferSizeSetting::GetSetting(const ClientContext &context) {
-	auto &config = DBConfig::GetConfig(context);
-	bool export_large_buffers_arrow = config.options.arrow_offset_size == ArrowOffsetSize::LARGE;
-	return Value::BOOLEAN(export_large_buffers_arrow);
-}
-
-//===----------------------------------------------------------------------===//
-// Arrow Output Format Version
-//===----------------------------------------------------------------------===//
-void ArrowOutputVersionSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) {
-	auto arrow_version = input.ToString();
-	if (arrow_version == "1.0") {
-		config.options.arrow_output_version = V1_0;
-	} else if (arrow_version == "1.1") {
-		config.options.arrow_output_version = V1_1;
-	} else if (arrow_version == "1.2") {
-		config.options.arrow_output_version = V1_2;
-	} else if (arrow_version == "1.3") {
-		config.options.arrow_output_version = V1_3;
-	} else if (arrow_version == "1.4") {
-		config.options.arrow_output_version = V1_4;
-	} else if (arrow_version == "1.5") {
-		config.options.arrow_output_version = V1_5;
-	} else {
-		throw NotImplementedException("Unrecognized parameter for option arrow_output_version, expected either "
-		                              "\'1.0\', \'1.1\', \'1.2\', \'1.3\', \'1.4\', \'1.5\'");
-	}
-}
-
-Value ArrowOutputVersionSetting::GetSetting(const ClientContext &context) {
-	auto &config = DBConfig::GetConfig(context);
-	string arrow_version;
-	switch (config.options.arrow_output_version) {
-	case V1_0:
-		arrow_version = "1.0";
-		break;
-	case V1_1:
-		arrow_version = "1.1";
-		break;
-	case V1_2:
-		arrow_version = "1.2";
-		break;
-	case V1_3:
-		arrow_version = "1.3";
-		break;
-	case V1_4:
-		arrow_version = "1.4";
-		break;
-	case V1_5:
-		arrow_version = "1.5";
-		break;
-	default:
-		throw InternalException("Unrecognized arrow output version");
-	}
-	return Value(arrow_version);
-}
-
 //===----------------------------------------------------------------------===//
 // Checkpoint Threshold
 //===----------------------------------------------------------------------===//
@@ -495,48 +430,27 @@ Value DefaultBlockSizeSetting::GetSetting(const ClientContext &context) {
 //===----------------------------------------------------------------------===//
 // Default Collation
 //===----------------------------------------------------------------------===//
-void DefaultCollationSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) {
-	auto parameter = StringUtil::Lower(input.ToString());
-	config.options.collation = parameter;
-}
-
-void DefaultCollationSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) {
-	config.options.collation = DBConfig().options.collation;
-}
-
-void DefaultCollationSetting::SetLocal(ClientContext &context, const Value &input) {
-	auto parameter = input.ToString();
-	// bind the collation to verify that it exists
-	ExpressionBinder::TestCollation(context, parameter);
-	auto &config = DBConfig::GetConfig(context);
-	config.options.collation = parameter;
-}
-
-void DefaultCollationSetting::ResetLocal(ClientContext &context) {
-	auto &config = DBConfig::GetConfig(context);
-	config.options.collation = DBConfig().options.collation;
-}
-
-Value DefaultCollationSetting::GetSetting(const ClientContext &context) {
-	auto &config = DBConfig::GetConfig(context);
-	return Value(config.options.collation);
+void DefaultCollationSetting::OnSet(SettingCallbackInfo &info, Value &input) {
+	if (info.context) {
+		ExpressionBinder::TestCollation(*info.context, input.ToString());
+	}
 }
 
 //===----------------------------------------------------------------------===//
 // Default Null Order
 //===----------------------------------------------------------------------===//
-void DefaultNullOrderSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) {
+void DefaultNullOrderSetting::OnSet(SettingCallbackInfo &, Value &input) {
 	auto parameter = StringUtil::Lower(input.ToString());
 
 	if (parameter == "nulls_first" || parameter == "nulls first" || parameter == "null first" || parameter == "first") {
-		config.options.default_null_order = DefaultOrderByNullType::NULLS_FIRST;
+		input = Value("NULLS_FIRST");
 	} else if (parameter == "nulls_last" || parameter == "nulls last" || parameter == "null last" ||
 	           parameter == "last") {
-		config.options.default_null_order = DefaultOrderByNullType::NULLS_LAST;
+		input = Value("NULLS_LAST");
 	} else if (parameter == "nulls_first_on_asc_last_on_desc" || parameter == "sqlite" || parameter == "mysql") {
-		config.options.default_null_order = DefaultOrderByNullType::NULLS_FIRST_ON_ASC_LAST_ON_DESC;
+		input = Value("NULLS_FIRST_ON_ASC_LAST_ON_DESC");
 	} else if (parameter == "nulls_last_on_asc_first_on_desc" || parameter == "postgres") {
-		config.options.default_null_order = DefaultOrderByNullType::NULLS_LAST_ON_ASC_FIRST_ON_DESC;
+		input = Value("NULLS_LAST_ON_ASC_FIRST_ON_DESC");
 	} else {
 		throw ParserException("Unrecognized parameter for option NULL_ORDER \"%s\", expected either NULLS FIRST, NULLS "
 		                      "LAST, SQLite, MySQL or Postgres",
@@ -547,30 +461,18 @@ void DefaultNullOrderSetting::SetGlobal(DatabaseInstance *db, DBConfig &config,
 //===----------------------------------------------------------------------===//
 // Default Order
 //===----------------------------------------------------------------------===//
-void DefaultOrderSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) {
+void DefaultOrderSetting::OnSet(SettingCallbackInfo &, Value &input) {
 	auto parameter = StringUtil::Lower(input.ToString());
 	if (parameter == "ascending" || parameter == "asc") {
-		config.options.default_order_type = OrderType::ASCENDING;
+		input = Value("ASC");
 	} else if (parameter == "descending" || parameter == "desc") {
-		config.options.default_order_type = OrderType::DESCENDING;
+		input = Value("DESC");
 	} else {
 		throw InvalidInputException("Unrecognized parameter for option DEFAULT_ORDER \"%s\". Expected ASC or DESC.",
 		                            parameter);
 	}
 }
 
-Value DefaultOrderSetting::GetSetting(const ClientContext &context) {
-	auto &config = DBConfig::GetConfig(context);
-	switch (config.options.default_order_type) {
-	case OrderType::ASCENDING:
-		return "asc";
-	case OrderType::DESCENDING:
-		return "desc";
-	default:
-		throw InternalException("Unknown order type setting");
-	}
-}
-
 //===----------------------------------------------------------------------===//
 // Default Secret Storage
 //===----------------------------------------------------------------------===//
@@ -879,19 +781,6 @@ void DisabledLogTypes::ResetGlobal(DatabaseInstance *db_p, DBConfig &config) {
 	db.GetLogManager().SetDisabledLogTypes(set);
 }
 
-//===----------------------------------------------------------------------===//
-// Enable Object Cache
-//===----------------------------------------------------------------------===//
-void EnableObjectCacheSetting::SetGlobal(DatabaseInstance *db, DBConfig &config, const Value &input) {
-}
-
-void EnableObjectCacheSetting::ResetGlobal(DatabaseInstance *db, DBConfig &config) {
-}
-
-Value EnableObjectCacheSetting::GetSetting(const ClientContext &context) {
-	return Value();
-}
-
 //===----------------------------------------------------------------------===//
 // Enable Profiling
 //===----------------------------------------------------------------------===//
@@ -1157,12 +1046,11 @@ Value HTTPLoggingOutputSetting::GetSetting(const ClientContext &context) {
 //===----------------------------------------------------------------------===//
 // Index Scan Percentage
 //===----------------------------------------------------------------------===//
-bool IndexScanPercentageSetting::OnGlobalSet(DatabaseInstance *db, DBConfig &config, const Value &input) {
+void IndexScanPercentageSetting::OnSet(SettingCallbackInfo &, Value &input) {
 	auto index_scan_percentage = input.GetValue<double>();
 	if (index_scan_percentage < 0 || index_scan_percentage > 1.0) {
 		throw InvalidInputException("the index scan percentage must be within [0, 1]");
 	}
-	return true;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/src/duckdb/src/optimizer/build_probe_side_optimizer.cpp b/src/duckdb/src/optimizer/build_probe_side_optimizer.cpp
index bada196f5..2eea10358 100644
--- a/src/duckdb/src/optimizer/build_probe_side_optimizer.cpp
+++ b/src/duckdb/src/optimizer/build_probe_side_optimizer.cpp
@@ -13,6 +13,7 @@
 #include "duckdb/optimizer/optimizer.hpp"
 #include "duckdb/planner/operator/logical_cross_product.hpp"
 #include "duckdb/planner/operator/logical_projection.hpp"
+#include "duckdb/main/settings.hpp"
 
 namespace duckdb {
 
@@ -229,8 +230,9 @@ void BuildProbeSideOptimizer::VisitOperator(LogicalOperator &op) {
 		// if the conditions have no equality, do not flip the children.
 		// There is no physical join operator (yet) that can do an inequality right_semi/anti join.
 		idx_t has_range = 0;
+		bool prefer_range_joins = DBConfig::GetSetting<PreferRangeJoinsSetting>(context);
 		if (op.type == LogicalOperatorType::LOGICAL_ANY_JOIN ||
-		    (op.Cast<LogicalComparisonJoin>().HasEquality(has_range) && !context.config.prefer_range_joins)) {
+		    (op.Cast<LogicalComparisonJoin>().HasEquality(has_range) && !prefer_range_joins)) {
 			TryFlipJoinChildren(join);
 		}
 		break;
diff --git a/src/duckdb/src/optimizer/column_lifetime_analyzer.cpp b/src/duckdb/src/optimizer/column_lifetime_analyzer.cpp
index 799622d2a..6e00355e0 100644
--- a/src/duckdb/src/optimizer/column_lifetime_analyzer.cpp
+++ b/src/duckdb/src/optimizer/column_lifetime_analyzer.cpp
@@ -11,6 +11,7 @@
 #include "duckdb/planner/operator/logical_filter.hpp"
 #include "duckdb/planner/operator/logical_order.hpp"
 #include "duckdb/planner/operator/logical_projection.hpp"
+#include "duckdb/main/settings.hpp"
 
 namespace duckdb {
 
@@ -80,7 +81,8 @@ void ColumnLifetimeAnalyzer::VisitOperator(LogicalOperator &op) {
 
 		// FIXME: for now, we only push into the projection map for equality (hash) joins
 		idx_t has_range = 0;
-		if (!comp_join.HasEquality(has_range) || optimizer.context.config.prefer_range_joins) {
+		bool prefer_range_joins = DBConfig::GetSetting<PreferRangeJoinsSetting>(optimizer.context);
+		if (!comp_join.HasEquality(has_range) || prefer_range_joins) {
 			return;
 		}
 
diff --git a/src/duckdb/src/optimizer/join_order/cardinality_estimator.cpp b/src/duckdb/src/optimizer/join_order/cardinality_estimator.cpp
index 2539bf37a..6dd086ffc 100644
--- a/src/duckdb/src/optimizer/join_order/cardinality_estimator.cpp
+++ b/src/duckdb/src/optimizer/join_order/cardinality_estimator.cpp
@@ -355,7 +355,7 @@ DenomInfo CardinalityEstimator::GetDenominator(JoinRelationSet &set) {
 	auto denom_multiplier = 1.0 + static_cast<double>(unused_edge_tdoms.size());
 
 	// It's possible cross-products were added and are not present in the filters in the relation_2_tdom
-	// structures. When that's the case, merge all remaining subgraphs.
+	// structures. When that's the case, merge all remaining subgraphs as if they are connected by a cross product
 	if (subgraphs.size() > 1) {
 		auto final_subgraph = subgraphs.at(0);
 		for (auto merge_with = subgraphs.begin() + 1; merge_with != subgraphs.end(); merge_with++) {
@@ -367,6 +367,23 @@ DenomInfo CardinalityEstimator::GetDenominator(JoinRelationSet &set) {
 			final_subgraph.denom *= merge_with->denom;
 		}
 	}
+	if (!subgraphs.empty()) {
+		// Some relations are connected by cross products and will not end up in a subgraph
+		// Check and make sure all relations were considered, if not, they are connected to the graph by cross products
+		auto &returning_subgraph = subgraphs.at(0);
+		if (returning_subgraph.relations->count != set.count) {
+			for (idx_t rel_index = 0; rel_index < set.count; rel_index++) {
+				auto relation_id = set.relations[rel_index];
+				auto &rel = set_manager.GetJoinRelation(relation_id);
+				if (!JoinRelationSet::IsSubset(*returning_subgraph.relations, rel)) {
+					returning_subgraph.numerator_relations =
+					    &set_manager.Union(*returning_subgraph.numerator_relations, rel);
+					returning_subgraph.relations = &set_manager.Union(*returning_subgraph.relations, rel);
+				}
+			}
+		}
+	}
+	// can happen if a table has cardinality 0, a tdom is set to 0, or if a cross product is used.
 	if (subgraphs.empty() || subgraphs.at(0).denom == 0) {
 		// denominator is 1 and numerators are a cross product of cardinalities.
@@ -377,7 +394,6 @@ DenomInfo CardinalityEstimator::GetDenominator(JoinRelationSet &set) {
 
 template <>
 double CardinalityEstimator::EstimateCardinalityWithSet<double>(JoinRelationSet &new_set) {
-
 	if (relation_set_2_cardinality.find(new_set.ToString()) != relation_set_2_cardinality.end()) {
 		return relation_set_2_cardinality[new_set.ToString()].cardinality_before_filters;
 	}
diff --git a/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp b/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp
index 54f846615..fc282aba9 100644
--- a/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp
+++ b/src/duckdb/src/optimizer/join_order/plan_enumerator.cpp
@@ -3,6 +3,7 @@
 #include "duckdb/main/client_context.hpp"
 #include "duckdb/optimizer/join_order/join_node.hpp"
 #include "duckdb/optimizer/join_order/query_graph_manager.hpp"
+#include "duckdb/main/settings.hpp"
 
 #include <cmath>
 
@@ -68,7 +69,7 @@ static vector<vector<idx_t>> GetAllNeighborSets(vector<idx_t> neighbors)
 			count += 1;
 		}
 	}
-	D_ASSERT(count == static_cast<idx_t>(std::pow(2, neighbors.size()) - 1));
+	D_ASSERT(count == static_cast<idx_t>(std::pow(2, neighbors.size() - 1)));
 }
 #endif
 	return ret;
 }
@@ -470,7 +471,7 @@ void PlanEnumerator::InitLeafPlans() {
 // Moerkotte and Thomas Neumannn, see that paper for additional info/documentation bonus slides:
 // https://db.in.tum.de/teaching/ws1415/queryopt/chapter3.pdf?lang=de
 void PlanEnumerator::SolveJoinOrder() {
-	bool force_no_cross_product = query_graph_manager.context.config.force_no_cross_product;
+	bool force_no_cross_product = DBConfig::GetSetting<ForceNoCrossProductSetting>(query_graph_manager.context);
 	// first try to solve the join order exactly
 	if (query_graph_manager.relation_manager.NumRelations() >= THRESHOLD_TO_SWAP_TO_APPROXIMATE) {
 		SolveJoinOrderApproximately();
diff --git a/src/duckdb/src/optimizer/late_materialization.cpp b/src/duckdb/src/optimizer/late_materialization.cpp
index 19c1c01f2..a144df188 100644
--- a/src/duckdb/src/optimizer/late_materialization.cpp
+++ b/src/duckdb/src/optimizer/late_materialization.cpp
@@ -14,11 +14,12 @@
 #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp"
 #include "duckdb/main/client_config.hpp"
 #include "duckdb/main/config.hpp"
+#include "duckdb/main/settings.hpp"
 
 namespace duckdb {
 
 LateMaterialization::LateMaterialization(Optimizer &optimizer) : optimizer(optimizer) {
-	max_row_count = ClientConfig::GetConfig(optimizer.context).late_materialization_max_rows;
+	max_row_count = DBConfig::GetSetting<LateMaterializationMaxRowsSetting>(optimizer.context);
 }
 
 vector<idx_t> LateMaterialization::GetOrInsertRowIds(LogicalGet &get) {
@@ -411,8 +412,7 @@ bool LateMaterialization::TryLateMaterialization(unique_ptr<LogicalOperator> &op
 }
 
 bool LateMaterialization::OptimizeLargeLimit(LogicalLimit &limit, idx_t limit_val, bool has_offset) {
-	auto &config = DBConfig::GetConfig(optimizer.context);
-	if (!has_offset && !config.options.preserve_insertion_order) {
+	if (!has_offset && !DBConfig::GetSetting<PreserveInsertionOrderSetting>(optimizer.context)) {
 		// we avoid optimizing large limits if preserve insertion order is false
 		// since the limit is executed in parallel anyway
 		return false;
diff --git a/src/duckdb/src/optimizer/optimizer.cpp b/src/duckdb/src/optimizer/optimizer.cpp
index 93ec543ee..34d2f44bb 100644
--- a/src/duckdb/src/optimizer/optimizer.cpp
+++ b/src/duckdb/src/optimizer/optimizer.cpp
@@ -104,6 +104,7 @@ void Optimizer::RunBuiltInOptimizers() {
 	case LogicalOperatorType::LOGICAL_TRANSACTION:
 	case LogicalOperatorType::LOGICAL_PRAGMA:
 	case LogicalOperatorType::LOGICAL_SET:
+	case LogicalOperatorType::LOGICAL_ATTACH:
 	case LogicalOperatorType::LOGICAL_UPDATE_EXTENSIONS:
 	case LogicalOperatorType::LOGICAL_CREATE_SECRET:
 	case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR:
diff --git a/src/duckdb/src/parallel/pipeline.cpp b/src/duckdb/src/parallel/pipeline.cpp
index 0072d291e..bc511539a 100644
--- a/src/duckdb/src/parallel/pipeline.cpp
+++ b/src/duckdb/src/parallel/pipeline.cpp
@@ -12,6 +12,7 @@
 #include "duckdb/parallel/pipeline_event.hpp"
 #include "duckdb/parallel/pipeline_executor.hpp"
 #include "duckdb/parallel/task_scheduler.hpp"
+#include "duckdb/main/settings.hpp"
 
 namespace duckdb {
 
@@ -134,7 +135,6 @@ bool Pipeline::ScheduleParallel(shared_ptr<Event> &event) {
 }
 
 bool Pipeline::IsOrderDependent() const {
-	auto &config = DBConfig::GetConfig(executor.context);
 	if (source) {
 		auto source_order = source->SourceOrder();
 		if (source_order == OrderPreservationType::FIXED_ORDER) {
@@ -153,7 +153,7 @@ bool Pipeline::IsOrderDependent() const {
 			return true;
 		}
 	}
-	if (!config.options.preserve_insertion_order) {
+	if (!DBConfig::GetSetting<PreserveInsertionOrderSetting>(executor.context)) {
 		return false;
 	}
 	if (sink && sink->SinkOrderDependent()) {
diff --git a/src/duckdb/src/parallel/task_scheduler.cpp b/src/duckdb/src/parallel/task_scheduler.cpp
index 9fda080f4..af2951947 100644
--- a/src/duckdb/src/parallel/task_scheduler.cpp
+++ b/src/duckdb/src/parallel/task_scheduler.cpp
@@ -419,7 +419,7 @@ idx_t TaskScheduler::GetEstimatedCPUId() {
 	/* Other oses most likely use tpidr_el0 instead */
 	uintptr_t c;
 	asm volatile("mrs %x0, tpidrro_el0" : "=r"(c)::"memory");
-	return (idx_t)(c & (1 << 3) - 1);
+	return (idx_t)(c & ((1 << 3) - 1));
 #else
 #ifndef DUCKDB_NO_THREADS
 	// fallback to thread id
diff --git a/src/duckdb/src/parallel/thread_context.cpp b/src/duckdb/src/parallel/thread_context.cpp
index 1e90cc516..eb3d619ea 100644
--- a/src/duckdb/src/parallel/thread_context.cpp
+++ b/src/duckdb/src/parallel/thread_context.cpp
@@ -2,6 +2,7 @@
 #include "duckdb/main/client_context.hpp"
 #include "duckdb/logging/logger.hpp"
 #include "duckdb/main/database.hpp"
+#include "duckdb/logging/log_manager.hpp"
 
 namespace duckdb {
 
@@ -28,7 +29,7 @@ ThreadContext::ThreadContext(ClientContext &context) : profiler(context) {
 			log_context.transaction_id = query_id;
 		}
 	}
-	logger = context.db->GetLogManager().CreateLogger(log_context, true);
+	logger = LogManager::Get(context).CreateLogger(log_context, true);
 }
 
 ThreadContext::~ThreadContext() {
diff --git a/src/duckdb/src/parser/parsed_data/attach_info.cpp b/src/duckdb/src/parser/parsed_data/attach_info.cpp
index c4e1548c4..333c27c30 100644
--- a/src/duckdb/src/parser/parsed_data/attach_info.cpp
+++ b/src/duckdb/src/parser/parsed_data/attach_info.cpp
@@ -7,56 +7,14 @@
 
 namespace duckdb {
 
-StorageOptions AttachInfo::GetStorageOptions() const {
-	StorageOptions storage_options;
-	string storage_version_user_provided = "";
-	for (auto &entry : options) {
-		if (entry.first == "block_size") {
-			// Extract the block allocation size. This is NOT the actual memory available on a block (block_size),
-			// even though the corresponding option we expose to the user is called "block_size".
-			storage_options.block_alloc_size = entry.second.GetValue<uint64_t>();
-		} else if (entry.first == "encryption_key") {
-			// check the type of the key
-			auto type = entry.second.type();
-			if (type.id() != LogicalTypeId::VARCHAR) {
-				throw BinderException("\"%s\" is not a valid key. A key must be of type VARCHAR",
-				                      entry.second.ToString());
-			} else if (entry.second.GetValue<string>().empty()) {
-				throw BinderException("Not a valid key. A key cannot be empty");
-			}
-			storage_options.user_key =
-			    make_shared_ptr<string>(StringValue::Get(entry.second.DefaultCastAs(LogicalType::BLOB)));
-			storage_options.block_header_size = DEFAULT_ENCRYPTION_BLOCK_HEADER_SIZE;
-			storage_options.encryption = true;
-		} else if (entry.first == "encryption_cipher") {
-			throw BinderException("\"%s\" is not a valid cipher. Only AES GCM is supported.", entry.second.ToString());
-		} else if (entry.first == "row_group_size") {
-			storage_options.row_group_size = entry.second.GetValue<uint64_t>();
-		} else if (entry.first == "storage_version") {
-			storage_version_user_provided = entry.second.ToString();
-			storage_options.storage_version =
-			    SerializationCompatibility::FromString(entry.second.ToString()).serialization_version;
-		}
-	}
-	if (storage_options.encryption && (!storage_options.storage_version.IsValid() ||
-	                                   storage_options.storage_version.GetIndex() <
-	                                       SerializationCompatibility::FromString("v1.4.0").serialization_version)) {
-		if (!storage_version_user_provided.empty()) {
-			throw InvalidInputException(
-			    "Explicit provided STORAGE_VERSION (\"%s\") and ENCRYPTION_KEY (storage >= v1.4.0) are not compatible",
-			    storage_version_user_provided);
-		}
-		// set storage version to v1.4.0
-		storage_options.storage_version = SerializationCompatibility::FromString("v1.4.0").serialization_version;
-	}
-	return storage_options;
-}
-
 unique_ptr<AttachInfo> AttachInfo::Copy() const {
 	auto result = make_uniq<AttachInfo>();
 	result->name = name;
 	result->path = path;
 	result->options = options;
+	for (auto &entry : parsed_options) {
+		result->parsed_options[entry.first] = entry.second->Copy();
+	}
 	result->on_conflict = on_conflict;
 	return result;
 }
@@ -69,13 +27,16 @@ string AttachInfo::ToString() const {
 	} else if (on_conflict == OnCreateConflict::REPLACE_ON_CONFLICT) {
 		result += " OR REPLACE";
 	}
-	result += " DATABASE";
+	result += " DATABASE ";
 	result += KeywordHelper::WriteQuoted(path, '\'');
 	if (!name.empty()) {
 		result += " AS " + KeywordHelper::WriteOptionallyQuoted(name);
 	}
-	if (!options.empty()) {
+	if (!parsed_options.empty() || !options.empty()) {
 		vector<string> stringified;
+		for (auto &opt : parsed_options) {
+			stringified.push_back(StringUtil::Format("%s %s", opt.first, opt.second->ToString()));
+		}
 		for (auto &opt : options) {
 			stringified.push_back(StringUtil::Format("%s %s", opt.first, opt.second.ToSQLString()));
 		}
diff --git a/src/duckdb/src/parser/parsed_data/copy_info.cpp b/src/duckdb/src/parser/parsed_data/copy_info.cpp
index 9ba206579..98ebf8aa5 100644
--- a/src/duckdb/src/parser/parsed_data/copy_info.cpp
+++ b/src/duckdb/src/parser/parsed_data/copy_info.cpp
@@ -3,6 +3,10 @@
 
 namespace duckdb {
 
+CopyInfo::CopyInfo()
+    : ParseInfo(TYPE), catalog(INVALID_CATALOG), schema(DEFAULT_SCHEMA), is_from(false), is_format_auto_detected(true) {
+}
+
 unique_ptr<CopyInfo> CopyInfo::Copy() const {
 	auto result = make_uniq<CopyInfo>();
 	result->catalog = catalog;
@@ -13,6 +17,9 @@ unique_ptr<CopyInfo> CopyInfo::Copy() const {
 	result->is_from = is_from;
 	result->format = format;
 	result->is_format_auto_detected = is_format_auto_detected;
+	for (auto &entry : parsed_options) {
+		result->parsed_options[entry.first] = entry.second ? entry.second->Copy() : nullptr;
+	}
 	result->options = options;
 	if (select_statement) {
 		result->select_statement = select_statement->Copy();
@@ -20,11 +27,10 @@ unique_ptr<CopyInfo> CopyInfo::Copy() const {
 	return result;
 }
 
-string CopyInfo::CopyOptionsToString(const string &format, bool is_format_auto_detected,
-                                     const case_insensitive_map_t<vector<Value>> &options) {
+string CopyInfo::CopyOptionsToString() const {
 	// We only output the format if there is a format, and it was manually set.
 	const bool output_format = !format.empty() && !is_format_auto_detected;
-	if (!output_format && options.empty()) {
+	if (!output_format && options.empty() && parsed_options.empty()) {
 		return string();
 	}
 	string result;
@@ -34,6 +40,15 @@ string CopyInfo::CopyOptionsToString(const string &format, bool is_format_auto_d
 	if (!format.empty() && !is_format_auto_detected) {
 		stringified.push_back(StringUtil::Format(" FORMAT %s", format));
 	}
+	for (auto &opt : parsed_options) {
+		auto &name = opt.first;
+		auto &expr = opt.second;
+		string option_string = name;
+		if (expr) {
+			option_string += " " + expr->ToString();
+		}
+		stringified.push_back(option_string);
+	}
 	for (auto &opt : options) {
 		auto &name = opt.first;
 		auto &values = opt.second;
@@ -84,8 +99,6 @@ string CopyInfo::ToString() const {
 		D_ASSERT(!select_statement);
 		result += TablePartToString();
 		result += " FROM";
-		result += StringUtil::Format(" %s", SQLString(file_path));
-		result += CopyOptionsToString(format, is_format_auto_detected, options);
 	} else {
 		if (select_statement) {
 			// COPY (select-node) TO ...
@@ -94,9 +107,9 @@ string CopyInfo::ToString() const {
 			result += TablePartToString();
 		}
 		result += " TO ";
-		result += StringUtil::Format("%s", SQLString(file_path));
-		result += CopyOptionsToString(format, is_format_auto_detected, options);
 	}
+	result += StringUtil::Format(" %s", SQLString(file_path));
+	result += CopyOptionsToString();
 	result += ";";
 	return result;
 }
diff --git a/src/duckdb/src/parser/query_node.cpp b/src/duckdb/src/parser/query_node.cpp
index d68e1169d..490d4d061 100644
--- a/src/duckdb/src/parser/query_node.cpp
+++ b/src/duckdb/src/parser/query_node.cpp
@@ -7,6 +7,7 @@
 #include "duckdb/common/limits.hpp"
 #include "duckdb/common/serializer/serializer.hpp"
 #include "duckdb/common/serializer/deserializer.hpp"
+#include "duckdb/parser/statement/select_statement.hpp"
 
 namespace duckdb {
 
diff --git a/src/duckdb/src/parser/query_node/select_node.cpp b/src/duckdb/src/parser/query_node/select_node.cpp
index 3c3e4537c..71c228c40 100644
--- a/src/duckdb/src/parser/query_node/select_node.cpp
+++ b/src/duckdb/src/parser/query_node/select_node.cpp
@@ -106,6 +106,7 @@ string SelectNode::ToString() const {
 		}
 		result += ")";
 	}
+
 	return result + ResultModifiersToString();
 }
 
diff --git a/src/duckdb/src/parser/statement/export_statement.cpp b/src/duckdb/src/parser/statement/export_statement.cpp
index 560c117ac..ad28634a0 100644
--- a/src/duckdb/src/parser/statement/export_statement.cpp
+++ b/src/duckdb/src/parser/statement/export_statement.cpp
@@ -24,10 +24,8 @@ string ExportStatement::ToString() const {
 	}
 	auto &path = info->file_path;
 	D_ASSERT(info->is_from == false);
-	auto &options = info->options;
-	auto &format = info->format;
 	result += StringUtil::Format(" '%s'", path);
-	result += CopyInfo::CopyOptionsToString(format, info->is_format_auto_detected, options);
+	result += info->CopyOptionsToString();
 	result += ";";
 	return result;
 }
diff --git a/src/duckdb/src/parser/statement/merge_into_statement.cpp b/src/duckdb/src/parser/statement/merge_into_statement.cpp
index 0b6ebcd95..0943f000b 100644
--- a/src/duckdb/src/parser/statement/merge_into_statement.cpp
+++ b/src/duckdb/src/parser/statement/merge_into_statement.cpp
@@ -16,6 +16,9 @@ MergeIntoStatement::MergeIntoStatement(const MergeIntoStatement &other) : SQLSta
 			action_list.push_back(action->Copy());
 		}
 	}
+	for (auto &entry : other.returning_list) {
+		returning_list.push_back(entry->Copy());
+	}
 	cte_map = other.cte_map.Copy();
 }
 
@@ -60,6 +63,20 @@ string MergeIntoStatement::ToString() const {
 			result += action->ToString();
 		}
 	}
+	if (!returning_list.empty()) {
+		result += " RETURNING ";
+		for (idx_t i = 0; i < returning_list.size(); i++) {
+			if (i > 0) {
+				result += ", ";
+			}
+			auto column = returning_list[i]->ToString();
+			if (!returning_list[i]->GetAlias().empty()) {
+				column +=
+				    StringUtil::Format(" AS %s", KeywordHelper::WriteOptionallyQuoted(returning_list[i]->GetAlias()));
+			}
+			result += column;
+		}
+	}
 	return result;
 }
 
diff --git a/src/duckdb/src/parser/tableref/joinref.cpp b/src/duckdb/src/parser/tableref/joinref.cpp
index 9e9a70dda..a36af4e46 100644
--- a/src/duckdb/src/parser/tableref/joinref.cpp
+++ b/src/duckdb/src/parser/tableref/joinref.cpp
@@ -22,7 +22,7 @@ string JoinRef::ToString() const {
 		result += EnumUtil::ToString(type) + " JOIN ";
 		break;
 	case JoinRefType::CROSS:
-		result += ", ";
+		result += is_implicit ? ", " : "CROSS JOIN ";
 		break;
 	case JoinRefType::POSITIONAL:
 		result += "POSITIONAL JOIN ";
@@ -82,6 +82,7 @@ unique_ptr<TableRef> JoinRef::Copy() {
 	for (auto &col : duplicate_eliminated_columns) {
 		copy->duplicate_eliminated_columns.emplace_back(col->Copy());
 	}
+	copy->is_implicit = is_implicit;
 	return std::move(copy);
 }
 
diff --git a/src/duckdb/src/parser/tableref/showref.cpp b/src/duckdb/src/parser/tableref/showref.cpp
index 950defb31..6a40719d3 100644
--- a/src/duckdb/src/parser/tableref/showref.cpp
+++ b/src/duckdb/src/parser/tableref/showref.cpp
@@ -1,4 +1,5 @@
 #include "duckdb/parser/tableref/showref.hpp"
+#include "duckdb/parser/keyword_helper.hpp"
 
 namespace duckdb {
 
@@ -9,6 +10,17 @@ string ShowRef::ToString() const {
 	string result;
 	if (show_type == ShowType::SUMMARY) {
 		result += "SUMMARIZE ";
+	} else if (show_type == ShowType::SHOW_FROM) {
+		result += "SHOW TABLES FROM ";
+		string name = "";
+		if (!catalog_name.empty()) {
+			name += KeywordHelper::WriteOptionallyQuoted(catalog_name, '"');
+			if (!schema_name.empty()) {
+				name += ".";
+			}
+		}
+		name += KeywordHelper::WriteOptionallyQuoted(schema_name, '"');
+		result += name;
 	} else {
 		result += "DESCRIBE ";
 	}
@@ -38,6 +50,8 @@ bool ShowRef::Equals(const TableRef &other_p) const {
 
 unique_ptr<TableRef> ShowRef::Copy() {
 	auto copy = make_uniq<ShowRef>();
+	copy->catalog_name = catalog_name;
+	copy->schema_name = schema_name;
 	copy->table_name = table_name;
 	copy->query = query ? query->Copy() : nullptr;
 	copy->show_type = show_type;
diff --git a/src/duckdb/src/parser/tableref/table_function.cpp b/src/duckdb/src/parser/tableref/table_function.cpp
index 29a6da9bc..547eacc2d 100644
--- a/src/duckdb/src/parser/tableref/table_function.cpp
+++ b/src/duckdb/src/parser/tableref/table_function.cpp
@@ -9,7 +9,11 @@ TableFunctionRef::TableFunctionRef() : TableRef(TableReferenceType::TABLE_FUNCTI
 }
 
 string TableFunctionRef::ToString() const {
-	return BaseToString(function->ToString(), column_name_alias);
+	auto result = function->ToString();
+	if (with_ordinality == OrdinalityType::WITH_ORDINALITY) {
+		result += " WITH ORDINALITY";
+	}
+	return BaseToString(result, column_name_alias);
 }
 
 bool TableFunctionRef::Equals(const TableRef &other_p) const {
@@ -25,6 +29,7 @@ unique_ptr<TableRef> TableFunctionRef::Copy() {
 
 	copy->function = function->Copy();
 	copy->column_name_alias = column_name_alias;
+	copy->with_ordinality = with_ordinality;
 
 	CopyProperties(*copy);
 
 	return std::move(copy);
diff --git a/src/duckdb/src/parser/transform/helpers/transform_cte.cpp b/src/duckdb/src/parser/transform/helpers/transform_cte.cpp
index 69f6286b4..d505804ab 100644
--- a/src/duckdb/src/parser/transform/helpers/transform_cte.cpp
+++ b/src/duckdb/src/parser/transform/helpers/transform_cte.cpp
@@ -20,6 +20,9 @@ unique_ptr<CommonTableExpressionInfo> CommonTableExpressionInfo::Copy() {
 	return result;
 }
 
+CommonTableExpressionInfo::~CommonTableExpressionInfo() {
+}
+
 void Transformer::ExtractCTEsRecursive(CommonTableExpressionMap &cte_map) {
 	for (auto &cte_entry : stored_cte_map) {
 		for (auto &entry : cte_entry->map) {
diff --git a/src/duckdb/src/parser/transform/helpers/transform_typename.cpp b/src/duckdb/src/parser/transform/helpers/transform_typename.cpp
index a7beb3ba2..e99716be6 100644
--- a/src/duckdb/src/parser/transform/helpers/transform_typename.cpp
+++ b/src/duckdb/src/parser/transform/helpers/transform_typename.cpp
@@ -56,11 +56,6 @@ vector<Value> Transformer::TransformTypeModifiers(duckdb_libpgquery::PGTypeName
 	vector<Value> type_mods;
 	if (type_name.typmods) {
 		for (auto node = type_name.typmods->head; node; node = node->next) {
-			if (type_mods.size() > 9) {
-				const auto &name =
-				    *PGPointerCast<duckdb_libpgquery::PGValue>(type_name.names->tail->data.ptr_value)->val.str;
-				throw ParserException("'%s': a maximum of 9 type modifiers is allowed", name);
-			}
 			const auto &const_val = *PGPointerCast<duckdb_libpgquery::PGAConst>(node->data.ptr_value);
 			if (const_val.type != duckdb_libpgquery::T_PGAConst) {
 				throw ParserException("Expected a constant as type modifier");
@@ -68,6 +63,10 @@ vector<Value> Transformer::TransformTypeModifiers(duckdb_libpgquery::PGTypeName
 			const auto const_expr = TransformValue(const_val.val);
 			type_mods.push_back(std::move(const_expr->value));
 		}
+		if (type_mods.size() > 9) {
+			const auto name = PGPointerCast<duckdb_libpgquery::PGValue>(type_name.names->tail->data.ptr_value)->val.str;
+			throw ParserException("'%s': a maximum of 9 type modifiers is allowed", name);
+		}
 	}
 	return type_mods;
 }
diff --git a/src/duckdb/src/parser/transform/statement/transform_attach.cpp b/src/duckdb/src/parser/transform/statement/transform_attach.cpp
index 8c7632f74..1827795b2 100644
--- a/src/duckdb/src/parser/transform/statement/transform_attach.cpp
+++ b/src/duckdb/src/parser/transform/statement/transform_attach.cpp
@@ -16,13 +16,13 @@ unique_ptr<AttachStatement> Transformer::TransformAttach(duckdb_libpgquery::PGAt
 		duckdb_libpgquery::PGListCell *cell;
 		for_each_cell(cell, stmt.options->head) {
 			auto def_elem = PGPointerCast<duckdb_libpgquery::PGDefElem>(cell->data.ptr_value);
-			Value val;
+			unique_ptr<ParsedExpression> expr;
 			if (def_elem->arg) {
-				val = TransformValue(*PGPointerCast<duckdb_libpgquery::PGAConst>(def_elem->arg))->value;
+				expr = TransformExpression(def_elem->arg);
 			} else {
-				val = Value::BOOLEAN(true);
+				expr = make_uniq<ConstantExpression>(Value::BOOLEAN(true));
 			}
-			info->options[StringUtil::Lower(def_elem->defname)] = std::move(val);
+			info->parsed_options[StringUtil::Lower(def_elem->defname)] = std::move(expr);
 		}
 	}
 	result->info = std::move(info);
diff --git a/src/duckdb/src/parser/transform/statement/transform_copy.cpp b/src/duckdb/src/parser/transform/statement/transform_copy.cpp
index 7dd2a164f..d43c8e8cd 100644
--- a/src/duckdb/src/parser/transform/statement/transform_copy.cpp
+++ b/src/duckdb/src/parser/transform/statement/transform_copy.cpp
@@ -11,47 +11,6 @@
 
 namespace duckdb {
 
-void Transformer::ParseGenericOptionListEntry(case_insensitive_map_t<vector<Value>> &result_options, string &name,
-                                              duckdb_libpgquery::PGNode *arg) {
-	// otherwise
-	if (result_options.find(name) != result_options.end()) {
-		throw ParserException("Unexpected duplicate option \"%s\"", name);
-	}
-	if (!arg) {
-		result_options[name] = vector<Value>();
-		return;
-	}
-	switch (arg->type) {
-	case duckdb_libpgquery::T_PGList: {
-		auto column_list = PGPointerCast<duckdb_libpgquery::PGList>(arg);
-		for (auto c = column_list->head; c != nullptr; c = lnext(c)) {
-			auto target = PGPointerCast<duckdb_libpgquery::PGResTarget>(c->data.ptr_value);
-			result_options[name].push_back(Value(target->name));
-		}
-		break;
-	}
-	case duckdb_libpgquery::T_PGAStar:
-		result_options[name].push_back(Value("*"));
-		break;
-	case duckdb_libpgquery::T_PGFuncCall: {
-		auto func_call = PGPointerCast<duckdb_libpgquery::PGFuncCall>(arg);
-		auto func_expr = TransformFuncCall(*func_call);
-
-		Value value;
-		if (!Transformer::ConstructConstantFromExpression(*func_expr, value)) {
-			throw ParserException("Unsupported expression in option list: %s", func_expr->ToString());
-		}
-		result_options[name].push_back(std::move(value));
-		break;
-	}
-	default: {
-		auto val = PGPointerCast<duckdb_libpgquery::PGAConst>(arg);
-		result_options[name].push_back(TransformValue(*val)->value);
-		break;
-	}
-	}
-}
-
 void Transformer::TransformCopyOptions(CopyInfo &info, optional_ptr<duckdb_libpgquery::PGList> options) {
 	if (!options) {
 		return;
@@ -61,20 +20,15 @@ void Transformer::TransformCopyOptions(CopyInfo &info, optional_ptr<duckdb_libpg
 	duckdb_libpgquery::PGListCell *cell;
 	for_each_cell(cell, options->head) {
 		auto def_elem = PGPointerCast<duckdb_libpgquery::PGDefElem>(cell->data.ptr_value);
-		if (StringUtil::Lower(def_elem->defname) == "format") {
-			// format specifier: interpret this option
-			auto format_val = PGPointerCast<duckdb_libpgquery::PGValue>(def_elem->arg);
-			if (!format_val || format_val->type != duckdb_libpgquery::T_PGString) {
-				throw ParserException("Unsupported parameter type for FORMAT: expected e.g. FORMAT 'csv', 'parquet'");
-			}
-			info.format = StringUtil::Lower(format_val->val.str);
-			info.is_format_auto_detected = false;
-			continue;
-		}
-
-		// The rest ends up in the options
 		string name = def_elem->defname;
-		ParseGenericOptionListEntry(info.options, name, def_elem->arg);
+		if (info.parsed_options.find(name) != info.parsed_options.end()) {
+			throw ParserException("Unexpected duplicate option \"%s\"", name);
+		}
+		unique_ptr<ParsedExpression> expr;
+		if (def_elem->arg) {
+			expr = TransformExpression(*def_elem->arg);
+		}
+		info.parsed_options[name] = std::move(expr);
 	}
 }
 
diff --git a/src/duckdb/src/parser/transform/statement/transform_create_type.cpp b/src/duckdb/src/parser/transform/statement/transform_create_type.cpp
index 084dd43bc..125d0f0dc 100644
--- a/src/duckdb/src/parser/transform/statement/transform_create_type.cpp
+++ b/src/duckdb/src/parser/transform/statement/transform_create_type.cpp
@@ -28,7 +28,6 @@ Vector Transformer::PGListToVector(optional_ptr<duckdb_libpgquery::PGList> colum
 		}
 
 		auto entry_value = string(entry_value_node.val.str);
-		D_ASSERT(!entry_value.empty());
 		result_ptr[size++] = StringVector::AddStringOrBlob(result, entry_value);
 	}
 	return result;
diff --git a/src/duckdb/src/parser/transform/statement/transform_merge_into.cpp b/src/duckdb/src/parser/transform/statement/transform_merge_into.cpp
index 464c1bb51..634b58b92 100644
--- a/src/duckdb/src/parser/transform/statement/transform_merge_into.cpp
+++ b/src/duckdb/src/parser/transform/statement/transform_merge_into.cpp
@@ -102,6 +102,9 @@ unique_ptr<SQLStatement> Transformer::TransformMergeInto(duckdb_libpgquery::PGMe
 	for (auto &entry : unconditional_actions) {
 		result->actions[entry.first].push_back(std::move(entry.second));
 	}
+	if (stmt.returningList) {
+		TransformExpressionList(*stmt.returningList, result->returning_list);
+	}
 	return std::move(result);
 }
 
diff --git a/src/duckdb/src/parser/transform/statement/transform_pragma.cpp b/src/duckdb/src/parser/transform/statement/transform_pragma.cpp
index c0ff7b8cb..327233157 100644
--- a/src/duckdb/src/parser/transform/statement/transform_pragma.cpp
+++ b/src/duckdb/src/parser/transform/statement/transform_pragma.cpp
@@ -52,7 +52,7 @@ unique_ptr<SQLStatement> Transformer::TransformPragma(duckdb_libpgquery::PGPragm
 		throw ParserException("PRAGMA statement with assignment should contain exactly one parameter");
 	}
 	if (!info.named_parameters.empty()) {
-		throw InternalException("PRAGMA statement with assignment cannot have named parameters");
+		throw ParserException("PRAGMA statement with assignment cannot have named parameters");
 	}
 	// SQLite does not distinguish between:
 	// "PRAGMA table_info='integers'"
diff --git a/src/duckdb/src/parser/transform/statement/transform_show.cpp b/src/duckdb/src/parser/transform/statement/transform_show.cpp
index 648dbb9b7..ebd0ebd8d 100644
--- a/src/duckdb/src/parser/transform/statement/transform_show.cpp
+++ b/src/duckdb/src/parser/transform/statement/transform_show.cpp
@@ -15,8 +15,23 @@ unique_ptr<SelectNode> Transformer::TransformShow(duckdb_libpgquery::PGVariableSh
 	select_node->select_list.push_back(make_uniq<StarExpression>());
 	auto showref = make_uniq<ShowRef>();
 	if (stmt.set) {
-		// describing a set (e.g. SHOW ALL TABLES) - push it in the table name
-		showref->table_name = stmt.set;
+		if (std::string(stmt.set) == "__show_tables_from_database") {
+			showref->show_type = ShowType::SHOW_FROM;
+			auto qualified_name = TransformQualifiedName(*stmt.relation);
+			if (!IsInvalidCatalog(qualified_name.catalog)) {
+				throw ParserException("Expected \"SHOW TABLES FROM database\", \"SHOW TABLES FROM schema\", or "
+				                      "\"SHOW TABLES FROM database.schema\"");
+			}
+			if (qualified_name.schema.empty()) {
+				showref->schema_name = qualified_name.name;
+			} else {
+				showref->catalog_name = qualified_name.schema;
+				showref->schema_name = qualified_name.name;
+			}
+		} else {
+			// describing a set (e.g. SHOW ALL TABLES) - push it in the table name
+			showref->table_name = stmt.set;
+		}
 	} else if (!stmt.relation->schemaname) {
 		// describing an unqualified relation - check if this is a "special" relation
 		string table_name = StringUtil::Lower(stmt.relation->relname);
@@ -24,7 +39,7 @@ unique_ptr<SelectNode> Transformer::TransformShow(duckdb_libpgquery::PGVariableSh
 			showref->table_name = "\"" + std::move(table_name) + "\"";
 		}
 	}
-	if (showref->table_name.empty()) {
+	if (showref->table_name.empty() && showref->show_type != ShowType::SHOW_FROM) {
 		// describing a single relation
 		// wrap the relation in a "SELECT * FROM [table_name]" query
 		auto show_select_node = make_uniq<SelectNode>();
@@ -34,7 +49,10 @@ unique_ptr<SelectNode> Transformer::TransformShow(duckdb_libpgquery::PGVariableSh
 		showref->query = std::move(show_select_node);
 	}
 
-	showref->show_type = stmt.is_summary ? ShowType::SUMMARY : ShowType::DESCRIBE;
+	// If the show type is set to default, check if summary
+	if (showref->show_type == ShowType::DESCRIBE) {
+		showref->show_type = stmt.is_summary ? ShowType::SUMMARY : ShowType::DESCRIBE;
+	}
 	select_node->from_table = std::move(showref);
 	return std::move(select_node);
 }
diff --git a/src/duckdb/src/parser/transform/tableref/transform_from.cpp b/src/duckdb/src/parser/transform/tableref/transform_from.cpp
index 84ec8ea51..db515e7c1 100644
--- a/src/duckdb/src/parser/transform/tableref/transform_from.cpp
+++ b/src/duckdb/src/parser/transform/tableref/transform_from.cpp
@@ -10,8 +10,9 @@ unique_ptr<TableRef> Transformer::TransformFrom(optional_ptr<duckdb_libpgquery::
 	if (root->length > 1) {
-		// Cross Product
+		// Implicit Cross Product
 		auto result = make_uniq<JoinRef>(JoinRefType::CROSS);
+		result->is_implicit = true;
 		JoinRef *cur_root = result.get();
 		idx_t list_size = 0;
 		for (auto node = root->head; node != nullptr; node = node->next) {
@@ -24,6 +25,7 @@ unique_ptr<TableRef> Transformer::TransformFrom(optional_ptr<duckdb_libpgquery::
 				auto result = make_uniq<JoinRef>(JoinRefType::CROSS);
+				result->is_implicit = true;
 				result->left = std::move(old_res);
 				result->right = std::move(next);
 				cur_root = result.get();
diff --git a/src/duckdb/src/parser/transform/tableref/transform_table_function.cpp b/src/duckdb/src/parser/transform/tableref/transform_table_function.cpp
index 43dccc743..686c7dbe3 100644
--- a/src/duckdb/src/parser/transform/tableref/transform_table_function.cpp
+++ b/src/duckdb/src/parser/transform/tableref/transform_table_function.cpp
@@ -5,9 +5,6 @@
 namespace duckdb {
 
 unique_ptr<TableRef> Transformer::TransformRangeFunction(duckdb_libpgquery::PGRangeFunction &root) {
-	if (root.ordinality) {
-		throw NotImplementedException("WITH ORDINALITY not implemented");
-	}
 	if (root.is_rowsfrom) {
 		throw NotImplementedException("ROWS FROM() not implemented");
 	}
@@ -24,6 +21,9 @@ unique_ptr<TableRef> Transformer::TransformRangeFunction(duckdb_libpgquery::PGRa
 	}
 	// transform the function call
 	auto result = make_uniq<TableFunctionRef>();
+	if (root.ordinality) {
+		result->with_ordinality = OrdinalityType::WITH_ORDINALITY;
+	}
 	switch (call_tree->type) {
 	case duckdb_libpgquery::T_PGFuncCall: {
 		auto func_call = PGPointerCast<duckdb_libpgquery::PGFuncCall>(call_tree.get());
diff --git a/src/duckdb/src/planner/bind_context.cpp b/src/duckdb/src/planner/bind_context.cpp
index 7f1382dfc..b6e5df81f 100644
--- a/src/duckdb/src/planner/bind_context.cpp
+++ b/src/duckdb/src/planner/bind_context.cpp
@@ -625,6 +625,13 @@ void BindContext::AddBinding(unique_ptr<Binding> binding) {
 	bindings_list.push_back(std::move(binding));
 }
 
+void BindContext::AddBaseTable(idx_t index, const string &alias, const vector<string> &names,
+                               const vector<LogicalType> &types, vector<ColumnIndex> &bound_column_ids,
+                               TableCatalogEntry &entry, virtual_column_map_t virtual_columns) {
+	AddBinding(
+	    make_uniq<TableBinding>(alias, types, names, bound_column_ids, &entry, index, std::move(virtual_columns)));
+}
+
 void BindContext::AddBaseTable(idx_t index, const string &alias, const vector<string> &names,
                                const vector<LogicalType> &types, vector<ColumnIndex> &bound_column_ids,
                                TableCatalogEntry &entry, bool add_virtual_columns) {
@@ -632,8 +639,7 @@ void BindContext::AddBaseTable(idx_t index, const string &alias, const vector<string> &names,
-	AddBinding(
-	    make_uniq<TableBinding>(alias, types, names, bound_column_ids, &entry, index, std::move(virtual_columns)));
+	AddBaseTable(index, alias, names, types, bound_column_ids, entry, std::move(virtual_columns));
 }
 
 void BindContext::AddBaseTable(idx_t index, const string &alias, const vector<string> &names,
diff --git a/src/duckdb/src/planner/binder.cpp b/src/duckdb/src/planner/binder.cpp
index 5006b5c2b..1246f5a25 100644
--- a/src/duckdb/src/planner/binder.cpp
+++ b/src/duckdb/src/planner/binder.cpp
@@ -29,36 +29,31 @@
 namespace duckdb {
 
 Binder &Binder::GetRootBinder() {
-	reference<Binder> root = *this;
-	while (root.get().parent) {
-		root = *root.get().parent;
-	}
-	return root.get();
+	return root_binder;
 }
 
 idx_t Binder::GetBinderDepth() const {
-	const_reference<Binder> root = *this;
-	idx_t depth = 1;
-	while (root.get().parent) {
-		depth++;
-		root = *root.get().parent;
-	}
 	return depth;
 }
 
-shared_ptr<Binder> Binder::CreateBinder(ClientContext &context, optional_ptr<Binder> parent, BinderType binder_type) {
-	auto depth = parent ? parent->GetBinderDepth() : 0;
+void Binder::IncreaseDepth() {
+	depth++;
 	if (depth > context.config.max_expression_depth) {
 		throw BinderException("Max expression depth limit of %lld exceeded. Use \"SET max_expression_depth TO x\" to "
 		                      "increase the maximum expression depth.",
 		                      context.config.max_expression_depth);
 	}
+}
+
+shared_ptr<Binder> Binder::CreateBinder(ClientContext &context, optional_ptr<Binder> parent, BinderType binder_type) {
 	return shared_ptr<Binder>(new Binder(context, parent ? parent->shared_from_this() : nullptr, binder_type));
 }
 
 Binder::Binder(ClientContext &context, shared_ptr<Binder> parent_p, BinderType binder_type)
     : context(context), bind_context(*this), parent(std::move(parent_p)), bound_tables(0), binder_type(binder_type),
-      entry_retriever(context) {
+      entry_retriever(context), root_binder(parent ? parent->GetRootBinder() : *this),
+      depth(parent ? parent->GetBinderDepth() : 1) {
+	IncreaseDepth();
 	if (parent) {
 		entry_retriever.Inherit(parent->entry_retriever);
@@ -547,10 +542,10 @@ void VerifyNotExcluded(const ParsedExpression &root_expr) {
 
 BoundStatement Binder::BindReturning(vector<unique_ptr<ParsedExpression>> returning_list, TableCatalogEntry &table,
                                      const string &alias, idx_t update_table_index,
-                                     unique_ptr<LogicalOperator> child_operator, BoundStatement result) {
+                                     unique_ptr<LogicalOperator> child_operator, virtual_column_map_t virtual_columns) {
 
 	vector<LogicalType> types;
-	vector<std::string> names;
+	vector<string> names;
 
 	auto binder = Binder::CreateBinder(context);
 
@@ -565,12 +560,14 @@ BoundStatement Binder::BindReturning(vector<unique_ptr<ParsedExpression>> return
 		column_count++;
 	}
 
-	binder->bind_context.AddBaseTable(update_table_index, alias, names, types, bound_columns, table, false);
+	binder->bind_context.AddBaseTable(update_table_index, alias, names, types, bound_columns, table,
+	                                  std::move(virtual_columns));
 
 	ReturningBinder returning_binder(*binder, context);
 
 	vector<unique_ptr<Expression>> projection_expressions;
 	LogicalType result_type;
 	vector<unique_ptr<ParsedExpression>> new_returning_list;
+	BoundStatement result;
 	binder->ExpandStarExpressions(returning_list, new_returning_list);
 	for (auto &returning_expr : new_returning_list) {
 		VerifyNotExcluded(*returning_expr);
diff --git a/src/duckdb/src/planner/binder/expression/bind_aggregate_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_aggregate_expression.cpp
index beba67f83..6496fcc68 100644
--- a/src/duckdb/src/planner/binder/expression/bind_aggregate_expression.cpp
+++ b/src/duckdb/src/planner/binder/expression/bind_aggregate_expression.cpp
@@ -18,6 +18,7 @@
 #include "duckdb/planner/expression_binder/base_select_binder.hpp"
 #include "duckdb/planner/expression_iterator.hpp"
 #include "duckdb/planner/query_node/bound_select_node.hpp"
+#include "duckdb/main/settings.hpp"
 
 namespace duckdb {
 
@@ -139,8 +140,7 @@ BindResult BaseSelectBinder::BindAggregate(FunctionExpression &aggr, AggregateFu
 			auto &config = DBConfig::GetConfig(context);
 			const auto &order = aggr.order_bys->orders[0];
-			const auto sense =
-			    (order.type == OrderType::ORDER_DEFAULT) ? config.options.default_order_type : order.type;
+			const auto sense = config.ResolveOrder(context, order.type);
 			negate_fractions = (sense == OrderType::DESCENDING);
 		}
 	}
@@ -160,8 +160,8 @@ BindResult BaseSelectBinder::BindAggregate(FunctionExpression &aggr, AggregateFu
 			if (order.expression->GetExpressionType() == ExpressionType::VALUE_CONSTANT) {
 				auto &const_expr = order.expression->Cast<ConstantExpression>();
 				if (!const_expr.value.type().IsIntegral()) {
-					auto &config = ClientConfig::GetConfig(context);
-					if (!config.order_by_non_integer_literal) {
+					auto order_by_non_integer_literal = DBConfig::GetSetting<OrderByNonIntegerLiteralSetting>(context);
+					if (!order_by_non_integer_literal) {
 						throw BinderException(
 						    *order.expression,
 						    "ORDER BY non-integer literal has no effect.\n* SET order_by_non_integer_literal=true to "
@@ -272,8 +272,8 @@ BindResult BaseSelectBinder::BindAggregate(FunctionExpression &aggr, AggregateFu
 		for (auto &order : aggr.order_bys->orders) {
 			auto &order_expr = BoundExpression::GetExpression(*order.expression);
 			PushCollation(context, order_expr, order_expr->return_type);
-			const auto sense = config.ResolveOrder(order.type);
-			const auto null_order = config.ResolveNullOrder(sense, order.null_order);
+			const auto sense = config.ResolveOrder(context, order.type);
+			const auto null_order = config.ResolveNullOrder(context, sense, order.null_order);
 			order_bys->orders.emplace_back(sense, null_order, std::move(order_expr));
 		}
 	}
diff --git a/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp b/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp
index 1f69fc565..46f92b3b5 100644
--- a/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp
+++ b/src/duckdb/src/planner/binder/expression/bind_window_expression.cpp
@@ -321,13 +321,13 @@ BindResult BaseSelectBinder::BindWindow(WindowExpression &window, idx_t depth) {
 	LogicalType start_type = LogicalType::BIGINT;
 	if (window.start == WindowBoundary::EXPR_PRECEDING_RANGE) {
 		D_ASSERT(window.orders.size() == 1);
-		range_sense = config.ResolveOrder(window.orders[0].type);
+		range_sense = config.ResolveOrder(context, window.orders[0].type);
 		const auto range_name = (range_sense == OrderType::ASCENDING) ? "-" : "+";
 		start_type = BindRangeExpression(context, range_name, window.start_expr, window.orders[0].expression);
 	} else if (window.start == WindowBoundary::EXPR_FOLLOWING_RANGE) {
 		D_ASSERT(window.orders.size() == 1);
-		range_sense = config.ResolveOrder(window.orders[0].type);
+		range_sense = config.ResolveOrder(context, window.orders[0].type);
 		const auto range_name = (range_sense == OrderType::ASCENDING) ? "+" : "-";
 		start_type = BindRangeExpression(context, range_name, window.start_expr, window.orders[0].expression);
 	}
@@ -335,13 +335,13 @@ BindResult BaseSelectBinder::BindWindow(WindowExpression &window, idx_t depth) {
 	LogicalType end_type = LogicalType::BIGINT;
 	if (window.end == WindowBoundary::EXPR_PRECEDING_RANGE) {
 		D_ASSERT(window.orders.size() == 1);
-		range_sense = config.ResolveOrder(window.orders[0].type);
+		range_sense = config.ResolveOrder(context, window.orders[0].type);
 		const auto range_name = (range_sense == OrderType::ASCENDING) ? "-" : "+";
 		end_type = BindRangeExpression(context, range_name, window.end_expr, window.orders[0].expression);
 	} else if (window.end == WindowBoundary::EXPR_FOLLOWING_RANGE) {
 		D_ASSERT(window.orders.size() == 1);
-		range_sense = config.ResolveOrder(window.orders[0].type);
+		range_sense = config.ResolveOrder(context, window.orders[0].type);
 		const auto range_name = (range_sense == OrderType::ASCENDING) ? "+" : "-";
 		end_type = BindRangeExpression(context, range_name, window.end_expr, window.orders[0].expression);
 	}
@@ -368,16 +368,16 @@ BindResult BaseSelectBinder::BindWindow(WindowExpression &window, idx_t depth) {
 	}
 
 	for (auto &order : window.orders) {
-		auto type = config.ResolveOrder(order.type);
-		auto null_order = config.ResolveNullOrder(type, order.null_order);
+		auto type = config.ResolveOrder(context, order.type);
+		auto null_order = config.ResolveNullOrder(context, type, order.null_order);
 		auto expression = GetExpression(order.expression);
 		result->orders.emplace_back(type, null_order, std::move(expression));
 	}
 
 	// Argument orders are just like arguments, not frames
 	for (auto &order : window.arg_orders) {
-		auto type = config.ResolveOrder(order.type);
-		auto null_order = config.ResolveNullOrder(type, order.null_order);
+		auto type = config.ResolveOrder(context, order.type);
+		auto null_order = config.ResolveNullOrder(context, type, order.null_order);
 		auto expression = GetExpression(order.expression);
 		result->arg_orders.emplace_back(type, null_order, std::move(expression));
 	}
diff --git a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp
index 126ae001f..4f52dfc4a 100644
--- a/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp
+++ b/src/duckdb/src/planner/binder/query_node/bind_select_node.cpp
@@ -165,8 +165,8 @@ void Binder::PrepareModifiers(OrderBinder &order_binder, QueryNode &statement, B
 			if (star.exclude_list.empty() && star.replace_list.empty() && !star.expr) {
 				// ORDER BY ALL
 				// replace the order list with the all elements in the SELECT list
-				auto order_type = config.ResolveOrder(order.orders[0].type);
-				auto null_order = config.ResolveNullOrder(order_type, order.orders[0].null_order);
+				auto order_type = config.ResolveOrder(context, order.orders[0].type);
+				auto null_order = config.ResolveNullOrder(context, order_type, order.orders[0].null_order);
 				auto constant_expr = make_uniq<ConstantExpression>(Value("ALL"));
 				bound_order->orders.emplace_back(order_type, null_order, std::move(constant_expr));
 				bound_modifier = std::move(bound_order);
@@ -193,8 +193,8 @@ void Binder::PrepareModifiers(OrderBinder &order_binder, QueryNode &statement, B
 				vector<unique_ptr<ParsedExpression>> sort_key_parameters;
 				for (auto &order_node : order.orders) {
 					sort_key_parameters.push_back(std::move(order_node.expression));
-					auto type = config.ResolveOrder(order_node.type);
-					auto null_order = config.ResolveNullOrder(type, order_node.null_order);
+					auto type = config.ResolveOrder(context, order_node.type);
+					auto null_order = config.ResolveNullOrder(context, type, order_node.null_order);
 					string sort_param = EnumUtil::ToString(type) + " " + EnumUtil::ToString(null_order);
 					sort_key_parameters.push_back(make_uniq<ConstantExpression>(Value(sort_param)));
 				}
@@ -207,8 +207,8 @@ void Binder::PrepareModifiers(OrderBinder &order_binder, QueryNode &statement, B
 				vector<unique_ptr<ParsedExpression>> order_list;
 				order_binders[0].get().ExpandStarExpression(std::move(order_node.expression), order_list);
 
-				auto type = config.ResolveOrder(order_node.type);
-				auto null_order = config.ResolveNullOrder(type, order_node.null_order);
+				auto type = config.ResolveOrder(context, order_node.type);
+				auto null_order = config.ResolveNullOrder(context, type, order_node.null_order);
 				for (auto &order_expr : order_list) {
 					auto bound_expr = BindOrderExpression(order_binder, std::move(order_expr));
 					if (!bound_expr) {
diff --git a/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp b/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp
index 7627996ba..16f0c8fdd 100644
--- a/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp
+++ b/src/duckdb/src/planner/binder/query_node/plan_subquery.cpp
@@ -19,6 +19,7 @@
 #include "duckdb/planner/operator/logical_dependent_join.hpp"
 #include "duckdb/planner/subquery/recursive_dependent_join_planner.hpp"
 #include "duckdb/function/scalar/generic_functions.hpp"
+#include "duckdb/main/settings.hpp"
 
 namespace duckdb {
 
@@ -78,8 +79,7 @@ static unique_ptr<Expression> PlanUncorrelatedSubquery(Binder &binder, BoundSubq
 	D_ASSERT(bindings.size() == 1);
 	idx_t table_idx = bindings[0].table_index;
 
-	auto &config = ClientConfig::GetConfig(binder.context);
-	bool error_on_multiple_rows = config.scalar_subquery_error_on_multiple_rows;
+	bool error_on_multiple_rows = DBConfig::GetSetting<ScalarSubqueryErrorOnMultipleRowsSetting>(binder.context);
 
 	// we push an aggregate that returns the FIRST element
 	vector<unique_ptr<Expression>> expressions;
@@ -372,6 +372,7 @@ unique_ptr<Expression> Binder::PlanSubquery(BoundSubqueryExpression &expr, uniqu
 	} else {
 		result_expression = PlanCorrelatedSubquery(*this, expr, root, std::move(plan));
 	}
+	IncreaseDepth();
 	// finally, we recursively plan the nested subqueries (if there are any)
 	if (sub_binder->has_unplanned_dependent_joins) {
 		RecursiveDependentJoinPlanner plan(*this);
diff --git a/src/duckdb/src/planner/binder/statement/bind_attach.cpp b/src/duckdb/src/planner/binder/statement/bind_attach.cpp
index 480324161..0e8655d2f 100644
--- a/src/duckdb/src/planner/binder/statement/bind_attach.cpp
+++ b/src/duckdb/src/planner/binder/statement/bind_attach.cpp
@@ -3,6 +3,8 @@
 #include "duckdb/parser/tableref/table_function_ref.hpp"
 #include "duckdb/planner/tableref/bound_table_function.hpp"
 #include "duckdb/planner/operator/logical_simple.hpp"
+#include "duckdb/planner/expression_binder/table_function_binder.hpp"
+#include "duckdb/execution/expression_executor.hpp"
 
 namespace duckdb {
 
@@ -11,6 +13,19 @@ BoundStatement Binder::Bind(AttachStatement &stmt) {
 	result.types = {LogicalType::BOOLEAN};
 	result.names = {"Success"};
 
+	// bind the options
+	TableFunctionBinder option_binder(*this, context, "Attach", "Attach parameter");
+	unordered_map<string, Value> kv_options;
+	for (auto &entry : stmt.info->parsed_options) {
+		auto bound_expr = option_binder.Bind(entry.second);
+		auto val = ExpressionExecutor::EvaluateScalar(context, *bound_expr);
+		if (val.IsNull()) {
+			throw BinderException("NULL is not supported as a valid option for ATTACH option \"" + entry.first + "\"");
+		}
+		stmt.info->options[entry.first] = std::move(val);
+	}
+	stmt.info->parsed_options.clear();
+
 	result.plan = make_uniq<LogicalSimple>(LogicalOperatorType::LOGICAL_ATTACH, std::move(stmt.info));
 
 	auto &properties = GetStatementProperties();
diff --git a/src/duckdb/src/planner/binder/statement/bind_copy.cpp b/src/duckdb/src/planner/binder/statement/bind_copy.cpp
index 21d20d8ff..0f2a77f2e 100644
--- a/src/duckdb/src/planner/binder/statement/bind_copy.cpp
+++ b/src/duckdb/src/planner/binder/statement/bind_copy.cpp
@@ -19,8 +19,8 @@
 #include "duckdb/planner/operator/logical_get.hpp"
 #include "duckdb/planner/operator/logical_insert.hpp"
 #include "duckdb/planner/operator/logical_projection.hpp"
-
-#include <algorithm>
+#include "duckdb/planner/expression_binder/table_function_binder.hpp"
+#include "duckdb/common/algorithm.hpp"
 
 #include "duckdb/main/extension_entries.hpp"
 
@@ -383,9 +383,9 @@ BoundStatement Binder::BindCopyFrom(CopyStatement &stmt) {
 			expected_names.push_back(col.Name());
 		}
 	}
-
+	CopyFromFunctionBindInput input(*stmt.info, copy_function.function.copy_from_function);
 	auto function_data =
-
copy_function.function.copy_from_bind(context, *stmt.info, expected_names, bound_insert.expected_types); + copy_function.function.copy_from_bind(context, input, expected_names, bound_insert.expected_types); auto get = make_uniq(GenerateTableIndex(), copy_function.function.copy_from_function, std::move(function_data), bound_insert.expected_types, expected_names); for (idx_t i = 0; i < bound_insert.expected_types.size(); i++) { @@ -396,7 +396,56 @@ BoundStatement Binder::BindCopyFrom(CopyStatement &stmt) { return result; } +vector BindCopyOption(ClientContext &context, TableFunctionBinder &option_binder, const string &name, + unique_ptr &expr) { + vector result; + if (!expr) { + return result; + } + if (expr->type == ExpressionType::STAR) { + auto &star = expr->Cast(); + // for compatibility with previous copy implementation - turn a raw * into a * string literal + if (star.relation_name.empty() && star.exclude_list.empty() && star.replace_list.empty() && + star.rename_list.empty() && !star.expr && !star.columns) { + result.push_back("*"); + return result; + } + } + auto bound_expr = option_binder.Bind(expr); + auto val = ExpressionExecutor::EvaluateScalar(context, *bound_expr); + if (val.IsNull()) { + throw BinderException("NULL is not supported as a valid option for COPY option \"" + name + "\""); + } + if (val.type().id() == LogicalTypeId::STRUCT && StructType::IsUnnamed(val.type())) { + // unpack unnamed structs into a list of options + return StructValue::GetChildren(val); + } + result.push_back(std::move(val)); + return result; +} + +void Binder::BindCopyOptions(CopyInfo &info) { + TableFunctionBinder option_binder(*this, context, "Copy", "Copy options"); + for (auto &entry : info.parsed_options) { + auto inputs = BindCopyOption(context, option_binder, entry.first, entry.second); + if (StringUtil::Lower(entry.first) == "format") { + // format specifier: interpret this option + if (inputs.size() != 1 || inputs[0].type().id() != LogicalTypeId::VARCHAR) { + throw ParserException("Unsupported parameter type for FORMAT: expected e.g. 
FORMAT 'csv', 'parquet'"); + } + info.format = StringUtil::Lower(inputs[0].ToString()); + info.is_format_auto_detected = false; + continue; + } + info.options[entry.first] = std::move(inputs); + } + info.parsed_options.clear(); +} + BoundStatement Binder::Bind(CopyStatement &stmt, CopyToType copy_to_type) { + // bind the copy options + BindCopyOptions(*stmt.info); + if (!stmt.info->is_from && !stmt.info->select_statement) { // copy table into file without a query // generate SELECT * FROM table; diff --git a/src/duckdb/src/planner/binder/statement/bind_create.cpp b/src/duckdb/src/planner/binder/statement/bind_create.cpp index c44eb0992..c450a47df 100644 --- a/src/duckdb/src/planner/binder/statement/bind_create.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_create.cpp @@ -43,6 +43,7 @@ #include "duckdb/storage/storage_extension.hpp" #include "duckdb/common/extension_type_info.hpp" #include "duckdb/common/type_visitor.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -155,8 +156,7 @@ void Binder::BindCreateViewInfo(CreateViewInfo &base) { auto &dependencies = base.dependencies; auto &catalog = Catalog::GetCatalog(context, base.catalog); - auto &db_config = DBConfig::GetConfig(context); - bool should_create_dependencies = db_config.GetSetting(context); + bool should_create_dependencies = DBConfig::GetSetting(context); if (should_create_dependencies) { view_binder->SetCatalogLookupCallback([&dependencies, &catalog](CatalogEntry &entry) { if (&catalog != &entry.ParentCatalog()) { @@ -186,7 +186,6 @@ SchemaCatalogEntry &Binder::BindCreateFunctionInfo(CreateInfo &info) { auto &dependencies = base.dependencies; auto &catalog = Catalog::GetCatalog(context, info.catalog); - auto &db_config = DBConfig::GetConfig(context); // try to bind each of the included functions unordered_set positional_parameters; for (auto &function : base.macros) { @@ -230,7 +229,7 @@ SchemaCatalogEntry &Binder::BindCreateFunctionInfo(CreateInfo &info) { BoundSelectNode sel_node; BoundGroupInformation group_info; SelectBinder binder(*this, context, sel_node, group_info); - bool should_create_dependencies = db_config.GetSetting(context); + bool should_create_dependencies = DBConfig::GetSetting(context); if (should_create_dependencies) { binder.SetCatalogLookupCallback([&dependencies, &catalog](CatalogEntry &entry) { diff --git a/src/duckdb/src/planner/binder/statement/bind_create_table.cpp b/src/duckdb/src/planner/binder/statement/bind_create_table.cpp index 686b99676..7973de7a3 100644 --- a/src/duckdb/src/planner/binder/statement/bind_create_table.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_create_table.cpp @@ -35,8 +35,8 @@ static void CreateColumnDependencyManager(BoundCreateTableInfo &info) { } } -static void VerifyCompressionType(optional_ptr storage_manager, DBConfig &config, - BoundCreateTableInfo &info) { +static void VerifyCompressionType(ClientContext &context, optional_ptr storage_manager, + DBConfig &config, BoundCreateTableInfo &info) { auto &base = info.base->Cast(); for (auto &col : base.columns.Logical()) { auto compression_type = col.CompressionType(); @@ -45,7 +45,15 @@ static void VerifyCompressionType(optional_ptr storage_manager, "and only has decompress support", CompressionTypeToString(compression_type)); } - const auto &logical_type = col.GetType(); + auto logical_type = col.GetType(); + if (logical_type.id() == LogicalTypeId::USER && logical_type.HasAlias()) { + // Resolve user type if possible + const auto type_entry = Catalog::GetEntry( + context, INVALID_CATALOG, 
INVALID_SCHEMA, logical_type.GetAlias(), OnEntryNotFound::RETURN_NULL); + if (type_entry) { + logical_type = type_entry->user_type; + } + } auto physical_type = logical_type.InternalType(); if (compression_type == CompressionType::COMPRESSION_AUTO) { continue; @@ -630,7 +638,7 @@ unique_ptr Binder::BindCreateTableInfo(unique_ptrtype != TableReferenceType::BASE_TABLE) { @@ -84,8 +82,9 @@ BoundStatement Binder::Bind(DeleteStatement &stmt) { unique_ptr del_as_logicaloperator = std::move(del); return BindReturning(std::move(stmt.returning_list), table, stmt.table->alias, update_table_index, - std::move(del_as_logicaloperator), std::move(result)); + std::move(del_as_logicaloperator)); } + BoundStatement result; result.plan = std::move(del); result.names = {"Count"}; result.types = {LogicalType::BIGINT}; diff --git a/src/duckdb/src/planner/binder/statement/bind_export.cpp b/src/duckdb/src/planner/binder/statement/bind_export.cpp index 68b5b80d2..7573f0c05 100644 --- a/src/duckdb/src/planner/binder/statement/bind_export.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_export.cpp @@ -163,6 +163,9 @@ BoundStatement Binder::Bind(ExportStatement &stmt) { result.types = {LogicalType::BOOLEAN}; result.names = {"Success"}; + // bind copy options + BindCopyOptions(*stmt.info); + // lookup the format in the catalog auto ©_function = Catalog::GetEntry(context, INVALID_CATALOG, DEFAULT_SCHEMA, stmt.info->format); diff --git a/src/duckdb/src/planner/binder/statement/bind_insert.cpp b/src/duckdb/src/planner/binder/statement/bind_insert.cpp index a3cc9369d..dd04baa5f 100644 --- a/src/duckdb/src/planner/binder/statement/bind_insert.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_insert.cpp @@ -1,7 +1,10 @@ #include "duckdb/catalog/catalog.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" +#include "duckdb/parser/expression/conjunction_expression.hpp" #include "duckdb/parser/expression/constant_expression.hpp" #include "duckdb/parser/expression/function_expression.hpp" #include "duckdb/parser/statement/insert_statement.hpp" +#include "duckdb/parser/statement/merge_into_statement.hpp" #include "duckdb/parser/query_node/select_node.hpp" #include "duckdb/parser/tableref/expressionlistref.hpp" #include "duckdb/planner/binder.hpp" @@ -25,6 +28,8 @@ #include "duckdb/parser/parsed_expression_iterator.hpp" #include "duckdb/storage/table_storage_info.hpp" #include "duckdb/parser/tableref/basetableref.hpp" +#include "duckdb/parser/tableref/emptytableref.hpp" +#include "duckdb/parser/tableref/subqueryref.hpp" namespace duckdb { @@ -54,8 +59,47 @@ void Binder::TryReplaceDefaultExpression(unique_ptr &expr, con expr = ExpandDefaultExpression(column); } -void ExpressionBinder::DoUpdateSetQualifyInLambda(FunctionExpression &function, const string &table_name, - vector> &lambda_params) { +void Binder::ExpandDefaultInValuesList(InsertStatement &stmt, TableCatalogEntry &table, + optional_ptr values_list, + const vector &named_column_map) { + if (!values_list) { + return; + } + idx_t expected_columns = stmt.columns.empty() ? 
table.GetColumns().PhysicalColumnCount() : stmt.columns.size(); + + // special case: check if we are inserting from a VALUES statement + if (values_list) { + auto &expr_list = values_list->Cast(); + expr_list.expected_types.resize(expected_columns); + expr_list.expected_names.resize(expected_columns); + + D_ASSERT(!expr_list.values.empty()); + CheckInsertColumnCountMismatch(expected_columns, expr_list.values[0].size(), !stmt.columns.empty(), + table.name.c_str()); + + // VALUES list! + for (idx_t col_idx = 0; col_idx < expected_columns; col_idx++) { + D_ASSERT(named_column_map.size() >= col_idx); + auto &table_col_idx = named_column_map[col_idx]; + + // set the expected types as the types for the INSERT statement + auto &column = table.GetColumn(table_col_idx); + expr_list.expected_types[col_idx] = column.Type(); + expr_list.expected_names[col_idx] = column.Name(); + + // now replace any DEFAULT values with the corresponding default expression + for (idx_t list_idx = 0; list_idx < expr_list.values.size(); list_idx++) { + TryReplaceDefaultExpression(expr_list.values[list_idx][col_idx], column); + } + } + } +} + +void DoUpdateSetQualify(unique_ptr &expr, const string &table_name, + vector> &lambda_params); + +void DoUpdateSetQualifyInLambda(FunctionExpression &function, const string &table_name, + vector> &lambda_params) { for (auto &child : function.children) { if (child->GetExpressionClass() != ExpressionClass::LAMBDA) { @@ -96,8 +140,8 @@ void ExpressionBinder::DoUpdateSetQualifyInLambda(FunctionExpression &function, } } -void ExpressionBinder::DoUpdateSetQualify(unique_ptr &expr, const string &table_name, - vector> &lambda_params) { +void DoUpdateSetQualify(unique_ptr &expr, const string &table_name, + vector> &lambda_params) { // We avoid ambiguity with EXCLUDED columns by qualifying all column references. 
switch (expr->GetExpressionClass()) { @@ -135,85 +179,6 @@ void ExpressionBinder::DoUpdateSetQualify(unique_ptr &expr, co *expr, [&](unique_ptr &child) { DoUpdateSetQualify(child, table_name, lambda_params); }); } -// Replace binding.table_index with 'dest' if it's 'source' -void ReplaceColumnBindings(Expression &expr, idx_t source, idx_t dest) { - if (expr.GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { - auto &bound_columnref = expr.Cast(); - if (bound_columnref.binding.table_index == source) { - bound_columnref.binding.table_index = dest; - } - } - ExpressionIterator::EnumerateChildren( - expr, [&](unique_ptr &child) { ReplaceColumnBindings(*child, source, dest); }); -} - -void Binder::BindDoUpdateSetExpressions(const string &table_alias, LogicalInsert &insert, UpdateSetInfo &set_info, - TableCatalogEntry &table, TableStorageInfo &storage_info) { - D_ASSERT(insert.children.size() == 1); - - vector logical_column_ids; - vector column_names; - D_ASSERT(set_info.columns.size() == set_info.expressions.size()); - - for (idx_t i = 0; i < set_info.columns.size(); i++) { - auto &colname = set_info.columns[i]; - auto &expr = set_info.expressions[i]; - if (!table.ColumnExists(colname)) { - throw BinderException("Referenced update column %s not found in table!", colname); - } - auto &column = table.GetColumn(colname); - if (column.Generated()) { - throw BinderException("Cant update column \"%s\" because it is a generated column!", column.Name()); - } - if (std::find(insert.set_columns.begin(), insert.set_columns.end(), column.Physical()) != - insert.set_columns.end()) { - throw BinderException("Multiple assignments to same column \"%s\"", colname); - } - - if (!column.Type().SupportsRegularUpdate()) { - insert.update_is_del_and_insert = true; - } - - insert.set_columns.push_back(column.Physical()); - logical_column_ids.push_back(column.Oid()); - insert.set_types.push_back(column.Type()); - column_names.push_back(colname); - if (expr->GetExpressionType() == ExpressionType::VALUE_DEFAULT) { - expr = ExpandDefaultExpression(column); - } - - // Qualify and bind the ON CONFLICT DO UPDATE SET expression. - UpdateBinder update_binder(*this, context); - update_binder.target_type = column.Type(); - - // Avoid ambiguity between existing table columns and EXCLUDED columns. - vector> lambda_params; - update_binder.DoUpdateSetQualify(expr, table_alias, lambda_params); - - auto bound_expr = update_binder.Bind(expr); - D_ASSERT(bound_expr); - insert.expressions.push_back(std::move(bound_expr)); - } - - // Figure out which columns are indexed on - unordered_set indexed_columns; - for (auto &index : storage_info.index_info) { - for (auto &column_id : index.column_set) { - indexed_columns.insert(column_id); - } - } - - // If any column targeted by a SET expression has an index, then - // we need to rewrite this to an DELETE + INSERT. 
- for (idx_t i = 0; i < logical_column_ids.size(); i++) { - auto &column = logical_column_ids[i]; - if (indexed_columns.count(column)) { - insert.update_is_del_and_insert = true; - break; - } - } -} - unique_ptr CreateSetInfoForReplace(TableCatalogEntry &table, InsertStatement &insert, TableStorageInfo &storage_info) { auto set_info = make_uniq(); @@ -257,47 +222,129 @@ unique_ptr CreateSetInfoForReplace(TableCatalogEntry &table, Inse return set_info; } -vector GetColumnsToFetch(const TableBinding &binding) { - auto &bound_columns = binding.GetBoundColumnIds(); - vector result; - for (auto &col : bound_columns) { - result.push_back(col.GetPrimaryIndex()); - } - return result; -} - -void Binder::BindOnConflictClause(LogicalInsert &insert, TableCatalogEntry &table, InsertStatement &stmt) { - if (!stmt.on_conflict_info) { - insert.action_type = OnConflictAction::THROW; - return; - } - D_ASSERT(stmt.table_ref->type == TableReferenceType::BASE_TABLE); +void Binder::BindInsertColumnList(TableCatalogEntry &table, vector &columns, bool default_values, + vector &named_column_map, vector &expected_types, + IndexVector &column_index_map) { + if (!columns.empty() || default_values) { + // insertion statement specifies column list - // visit the table reference - auto bound_table = Bind(*stmt.table_ref); - if (bound_table->type != TableReferenceType::BASE_TABLE) { - throw BinderException("Can only update base table!"); + // create a mapping of (list index) -> (column index) + case_insensitive_map_t column_name_map; + for (idx_t i = 0; i < columns.size(); i++) { + auto entry = column_name_map.insert(make_pair(columns[i], i)); + if (!entry.second) { + throw BinderException("Duplicate column name \"%s\" in INSERT", columns[i]); + } + auto column_index = table.GetColumnIndex(columns[i]); + if (column_index.index == COLUMN_IDENTIFIER_ROW_ID) { + throw BinderException("Cannot explicitly insert values into rowid column"); + } + auto &col = table.GetColumn(column_index); + if (col.Generated()) { + throw BinderException("Cannot insert into a generated column"); + } + expected_types.push_back(col.Type()); + named_column_map.push_back(column_index); + } + for (auto &col : table.GetColumns().Physical()) { + auto entry = column_name_map.find(col.Name()); + if (entry == column_name_map.end()) { + // column not specified, set index to DConstants::INVALID_INDEX + column_index_map.push_back(DConstants::INVALID_INDEX); + } else { + // column was specified, set to the index + column_index_map.push_back(entry->second); + } + } + } else { + // insert by position and no columns specified - insertion into all columns of the table + // intentionally don't populate 'column_index_map' as an indication of this + for (auto &col : table.GetColumns().Physical()) { + named_column_map.push_back(col.Logical()); + expected_types.push_back(col.Type()); + } } +} - auto &table_ref = stmt.table_ref->Cast(); - const string &table_alias = !table_ref.alias.empty() ? table_ref.alias : table_ref.table_name; +unique_ptr Binder::GenerateMergeInto(InsertStatement &stmt, TableCatalogEntry &table) { + D_ASSERT(stmt.on_conflict_info); - auto &on_conflict = *stmt.on_conflict_info; - D_ASSERT(on_conflict.action_type != OnConflictAction::THROW); - insert.action_type = on_conflict.action_type; + auto &on_conflict_info = *stmt.on_conflict_info; + auto merge_into = make_uniq(); + // set up the target table + string table_name = !stmt.table_ref->alias.empty() ? 
stmt.table_ref->alias : stmt.table; + merge_into->target = std::move(stmt.table_ref); - // obtain the table storage info auto storage_info = table.GetStorageInfo(context); - auto &columns = table.GetColumns(); - if (!on_conflict.indexed_columns.empty()) { - // Bind the ON CONFLICT () + // set up the columns on which to join + vector distinct_on_columns; + if (on_conflict_info.indexed_columns.empty()) { + // When omitting the conflict target, we derive the join columns from the primary key/unique constraints + // traverse the primary key/unique constraints - // create a mapping of (list index) -> (column index) + vector> join_conditions; + // We check if there are any constraints on the table, if there aren't we throw an error. + idx_t found_matching_indexes = 0; + for (auto &index : storage_info.index_info) { + if (!index.is_unique) { + continue; + } + + vector> and_children; + auto &indexed_columns = index.column_set; + for (auto &column : columns.Physical()) { + if (!indexed_columns.count(column.Physical().index)) { + continue; + } + auto lhs = make_uniq(column.Name(), table_name); + auto rhs = make_uniq(column.Name(), "excluded"); + auto new_condition = + make_uniq(ExpressionType::COMPARE_EQUAL, std::move(lhs), std::move(rhs)); + and_children.push_back(std::move(new_condition)); + distinct_on_columns.push_back(column.Name()); + } + if (and_children.empty()) { + continue; + } + unique_ptr condition; + if (and_children.size() == 1) { + condition = std::move(and_children[0]); + } else { + // AND together + condition = make_uniq(ExpressionType::CONJUNCTION_AND, std::move(and_children)); + } + join_conditions.push_back(std::move(condition)); + found_matching_indexes++; + } + unique_ptr join_condition; + if (join_conditions.size() == 1) { + join_condition = std::move(join_conditions[0]); + } else { + // OR together + join_condition = + make_uniq(ExpressionType::CONJUNCTION_OR, std::move(join_conditions)); + } + merge_into->join_condition = std::move(join_condition); + + if (!found_matching_indexes) { + throw BinderException("There are no UNIQUE/PRIMARY KEY constraints that refer to this table, specify ON " + "CONFLICT columns manually"); + } else if (found_matching_indexes != 1) { + if (on_conflict_info.action_type != OnConflictAction::NOTHING) { + // When no conflict target is provided, and the action type is UPDATE, + // we only allow the operation when only a single Index exists + throw BinderException("Conflict target has to be provided for a DO UPDATE operation when the table has " + "multiple UNIQUE/PRIMARY KEY constraints"); + } + } + } else { + // when on conflict columns are explicitly provided - use them directly + // first figure out if there is an index on the columns or not case_insensitive_map_t specified_columns; - for (idx_t i = 0; i < on_conflict.indexed_columns.size(); i++) { - specified_columns[on_conflict.indexed_columns[i]] = i; - auto column_index = table.GetColumnIndex(on_conflict.indexed_columns[i]); + for (idx_t i = 0; i < on_conflict_info.indexed_columns.size(); i++) { + specified_columns[on_conflict_info.indexed_columns[i]] = i; + auto column_index = table.GetColumnIndex(on_conflict_info.indexed_columns[i]); if (column_index.index == COLUMN_IDENTIFIER_ROW_ID) { throw BinderException("Cannot specify ROWID as ON CONFLICT target"); } @@ -306,11 +353,12 @@ void Binder::BindOnConflictClause(LogicalInsert &insert, TableCatalogEntry &tabl throw BinderException("Cannot specify a generated column as ON CONFLICT target"); } } + unordered_set on_conflict_filter; for (auto &col 
: columns.Physical()) { auto entry = specified_columns.find(col.Name()); if (entry != specified_columns.end()) { // column was specified, set to the index - insert.on_conflict_filter.insert(col.Physical().index); + on_conflict_filter.insert(col.Physical().index); } } bool index_references_columns = false; @@ -318,7 +366,7 @@ void Binder::BindOnConflictClause(LogicalInsert &insert, TableCatalogEntry &tabl if (!index.is_unique) { continue; } - bool index_matches = insert.on_conflict_filter == index.column_set; + bool index_matches = on_conflict_filter == index.column_set; if (index_matches) { index_references_columns = true; break; @@ -330,194 +378,127 @@ void Binder::BindOnConflictClause(LogicalInsert &insert, TableCatalogEntry &tabl throw BinderException("The specified columns as conflict target are not referenced by a UNIQUE/PRIMARY KEY " "CONSTRAINT or INDEX"); } - } else { - // When omitting the conflict target, the ON CONFLICT applies to every UNIQUE/PRIMARY KEY on the table + distinct_on_columns = on_conflict_info.indexed_columns; + merge_into->using_columns = std::move(on_conflict_info.indexed_columns); + } - // We check if there are any constraints on the table, if there aren't we throw an error. - idx_t found_matching_indexes = 0; - for (auto &index : storage_info.index_info) { - if (!index.is_unique) { - continue; - } - auto &indexed_columns = index.column_set; - bool matches = false; - for (auto &column : table.GetColumns().Physical()) { - if (indexed_columns.count(column.Physical().index)) { - matches = true; - break; - } + // expand any default values + auto values_list = stmt.GetValuesList(); + if (values_list) { + vector named_column_map; + if (stmt.columns.empty()) { + for (auto &col : table.GetColumns().Physical()) { + named_column_map.push_back(col.Logical()); } - found_matching_indexes += matches; - } - - if (!found_matching_indexes) { - throw BinderException( - "There are no UNIQUE/PRIMARY KEY Indexes that refer to this table, ON CONFLICT is a no-op"); - } else if (found_matching_indexes != 1) { - if (insert.action_type != OnConflictAction::NOTHING) { - // When no conflict target is provided, and the action type is UPDATE, - // we only allow the operation when only a single Index exists - throw BinderException("Conflict target has to be provided for a DO UPDATE operation when the table has " - "multiple UNIQUE/PRIMARY KEY constraints"); + } else { + for (auto &col_name : stmt.columns) { + auto &col = table.GetColumn(col_name); + named_column_map.push_back(col.Logical()); } } + ExpandDefaultInValuesList(stmt, table, values_list, named_column_map); } + // set up the data source + unique_ptr source; + if (stmt.select_statement) { + source = make_uniq(std::move(stmt.select_statement), "excluded"); + } else { + source = make_uniq(); + } + if (stmt.column_order == InsertColumnOrder::INSERT_BY_POSITION) { + // if we are inserting by position add the columns of the target table as an alias to the source + if (!stmt.columns.empty() || stmt.default_values) { + // we are not emitting all columns - set the column set as the set of aliases + source->column_name_alias = stmt.columns; + + // now push another subquery that adds the default columns + auto select_stmt = make_uniq(); + auto select_node = make_uniq(); + unordered_set set_columns; + for (auto &set_col : stmt.columns) { + set_columns.insert(set_col); + } - // add the 'excluded' dummy table binding - AddTableName("excluded"); - // add a bind context entry for it - auto excluded_index = GenerateTableIndex(); - 
insert.excluded_table_index = excluded_index; - vector table_column_names; - vector table_column_types; - for (auto &col : columns.Physical()) { - table_column_names.push_back(col.Name()); - table_column_types.push_back(col.Type()); - } - bind_context.AddGenericBinding(excluded_index, "excluded", table_column_names, table_column_types); - - if (on_conflict.condition) { - WhereBinder where_binder(*this, context); - - // Avoid ambiguity between existing table columns and EXCLUDED columns. - vector> lambda_params; - where_binder.DoUpdateSetQualify(on_conflict.condition, table_alias, lambda_params); - - // Bind the ON CONFLICT ... WHERE clause. - auto condition = where_binder.Bind(on_conflict.condition); - insert.on_conflict_condition = std::move(condition); - } + for (auto &column : columns.Physical()) { + auto &name = column.Name(); + unique_ptr expr; + if (set_columns.find(name) == set_columns.end()) { + // column is not specified - at the default value + if (column.HasDefaultValue()) { + expr = column.DefaultValue().Copy(); + } else { + expr = make_uniq(Value(column.Type())); + } + } else { + // column is specified - add a reference to it + expr = make_uniq(name); + } + select_node->select_list.push_back(std::move(expr)); + } + select_node->from_table = std::move(source); + select_stmt->node = std::move(select_node); - optional_idx projection_index; - reference>> insert_child_operators = insert.children; - while (!projection_index.IsValid()) { - if (insert_child_operators.get().empty()) { - // No further children to visit - break; + source = make_uniq(std::move(select_stmt), "excluded"); } - auto ¤t_child = insert_child_operators.get()[0]; - auto table_indices = current_child->GetTableIndex(); - if (table_indices.empty()) { - // This operator does not have a table index to refer to, we have to visit its children - insert_child_operators = current_child->children; - continue; + // push all columns of the table as an alias + for (auto &column : columns.Physical()) { + source->column_name_alias.push_back(column.Name()); } - projection_index = table_indices[0]; - } - if (!projection_index.IsValid()) { - throw InternalException("Could not locate a table_index from the children of the insert"); } - - ErrorData unused; - auto original_binding = bind_context.GetBinding(table_alias, unused); - D_ASSERT(original_binding && !unused.HasError()); - - auto table_index = original_binding->index; - - // Replace any column bindings to refer to the projection table_index, rather than the source table - if (insert.on_conflict_condition) { - ReplaceColumnBindings(*insert.on_conflict_condition, table_index, projection_index.GetIndex()); - } - - if (insert.action_type == OnConflictAction::REPLACE) { - D_ASSERT(on_conflict.set_info == nullptr); - on_conflict.set_info = CreateSetInfoForReplace(table, stmt, storage_info); - insert.action_type = OnConflictAction::UPDATE; - } - if (on_conflict.set_info && on_conflict.set_info->columns.empty()) { - // if we are doing INSERT OR REPLACE on a table with no columns outside of the primary key column - // convert to INSERT OR IGNORE - insert.action_type = OnConflictAction::NOTHING; - } - if (insert.action_type == OnConflictAction::NOTHING) { - if (!insert.on_conflict_condition) { - return; - } - // Get the column_ids we need to fetch later on from the conflicting tuples - // of the original table, to execute the expressions - D_ASSERT(original_binding->binding_type == BindingType::TABLE); - auto &table_binding = original_binding->Cast(); - insert.columns_to_fetch = 
GetColumnsToFetch(table_binding); - return; + // push DISTINCT ON(unique_columns) + auto distinct_stmt = make_uniq(); + auto select_node = make_uniq(); + auto distinct = make_uniq(); + for (auto &col : distinct_on_columns) { + distinct->distinct_on_targets.push_back(make_uniq(col)); } + select_node->modifiers.push_back(std::move(distinct)); + select_node->select_list.push_back(make_uniq()); + select_node->from_table = std::move(source); + distinct_stmt->node = std::move(select_node); + source = make_uniq(std::move(distinct_stmt), "excluded"); - D_ASSERT(on_conflict.set_info); - auto &set_info = *on_conflict.set_info; - D_ASSERT(set_info.columns.size() == set_info.expressions.size()); - - if (set_info.condition) { - WhereBinder where_binder(*this, context); - - // Avoid ambiguity between existing table columns and EXCLUDED columns. - vector> lambda_params; - where_binder.DoUpdateSetQualify(set_info.condition, table_alias, lambda_params); + merge_into->source = std::move(source); - // Bind the SET ... WHERE clause. - auto condition = where_binder.Bind(set_info.condition); - insert.do_update_condition = std::move(condition); + if (on_conflict_info.action_type == OnConflictAction::REPLACE) { + D_ASSERT(!on_conflict_info.set_info); + on_conflict_info.set_info = CreateSetInfoForReplace(table, stmt, storage_info); + on_conflict_info.action_type = OnConflictAction::UPDATE; } + // now set up the merge actions + // first set up the base (insert) action when not matched + auto insert_action = make_uniq(); + insert_action->action_type = MergeActionType::MERGE_INSERT; + insert_action->column_order = stmt.column_order; - BindDoUpdateSetExpressions(table_alias, insert, set_info, table, storage_info); + merge_into->actions[MergeActionCondition::WHEN_NOT_MATCHED_BY_TARGET].push_back(std::move(insert_action)); - // Get the column_ids we need to fetch later on from the conflicting tuples - // of the original table, to execute the expressions - D_ASSERT(original_binding->binding_type == BindingType::TABLE); - auto &table_binding = original_binding->Cast(); - insert.columns_to_fetch = GetColumnsToFetch(table_binding); - - // Replace the column bindings to refer to the child operator - for (auto &expr : insert.expressions) { - // Change the non-excluded column references to refer to the projection index - ReplaceColumnBindings(*expr, table_index, projection_index.GetIndex()); - } - // Do the same for the (optional) DO UPDATE condition - if (insert.do_update_condition) { - ReplaceColumnBindings(*insert.do_update_condition, table_index, projection_index.GetIndex()); + if (on_conflict_info.condition) { + throw BinderException("ON CONFLICT WHERE clause is only supported in DO UPDATE SET ... 
WHERE ...\nThe WHERE " + "clause after the conflict columns is used for partial indexes which are not supported."); } -} - -void Binder::BindInsertColumnList(TableCatalogEntry &table, vector &columns, bool default_values, - vector &named_column_map, vector &expected_types, - IndexVector &column_index_map) { - if (!columns.empty() || default_values) { - // insertion statement specifies column list - - // create a mapping of (list index) -> (column index) - case_insensitive_map_t column_name_map; - for (idx_t i = 0; i < columns.size(); i++) { - auto entry = column_name_map.insert(make_pair(columns[i], i)); - if (!entry.second) { - throw BinderException("Duplicate column name \"%s\" in INSERT", columns[i]); - } - auto column_index = table.GetColumnIndex(columns[i]); - if (column_index.index == COLUMN_IDENTIFIER_ROW_ID) { - throw BinderException("Cannot explicitly insert values into rowid column"); - } - auto &col = table.GetColumn(column_index); - if (col.Generated()) { - throw BinderException("Cannot insert into a generated column"); - } - expected_types.push_back(col.Type()); - named_column_map.push_back(column_index); - } - for (auto &col : table.GetColumns().Physical()) { - auto entry = column_name_map.find(col.Name()); - if (entry == column_name_map.end()) { - // column not specified, set index to DConstants::INVALID_INDEX - column_index_map.push_back(DConstants::INVALID_INDEX); - } else { - // column was specified, set to the index - column_index_map.push_back(entry->second); - } + if (on_conflict_info.action_type == OnConflictAction::UPDATE) { + // when doing UPDATE set up the when matched action + auto update_action = make_uniq(); + update_action->action_type = MergeActionType::MERGE_UPDATE; + for (auto &col : on_conflict_info.set_info->expressions) { + vector> lambda_params; + DoUpdateSetQualify(col, table_name, lambda_params); } - } else { - // insert by position and no columns specified - insertion into all columns of the table - // intentionally don't populate 'column_index_map' as an indication of this - for (auto &col : table.GetColumns().Physical()) { - named_column_map.push_back(col.Logical()); - expected_types.push_back(col.Type()); + if (on_conflict_info.set_info->condition) { + vector> lambda_params; + DoUpdateSetQualify(on_conflict_info.set_info->condition, table_name, lambda_params); + update_action->condition = std::move(on_conflict_info.set_info->condition); } + update_action->update_info = std::move(on_conflict_info.set_info); + + merge_into->actions[MergeActionCondition::WHEN_MATCHED].push_back(std::move(update_action)); } + + // move over extra properties + merge_into->cte_map = std::move(stmt.cte_map); + merge_into->returning_list = std::move(stmt.returning_list); + return merge_into; } BoundStatement Binder::Bind(InsertStatement &stmt) { @@ -527,6 +508,11 @@ BoundStatement Binder::Bind(InsertStatement &stmt) { BindSchemaOrCatalog(stmt.catalog, stmt.schema); auto &table = Catalog::GetEntry(context, stmt.catalog, stmt.schema, stmt.table); + if (stmt.on_conflict_info) { + // generate a MERGE INTO statement and bind it instead + auto merge_into = GenerateMergeInto(stmt, table); + return Bind(*merge_into); + } if (!table.temporary) { // inserting into a non-temporary table: alters underlying database auto &properties = GetStatementProperties(); @@ -575,33 +561,7 @@ BoundStatement Binder::Bind(InsertStatement &stmt) { } // Exclude the generated columns from this amount idx_t expected_columns = stmt.columns.empty() ? 
table.GetColumns().PhysicalColumnCount() : stmt.columns.size(); - - // special case: check if we are inserting from a VALUES statement - if (values_list) { - auto &expr_list = values_list->Cast(); - expr_list.expected_types.resize(expected_columns); - expr_list.expected_names.resize(expected_columns); - - D_ASSERT(!expr_list.values.empty()); - CheckInsertColumnCountMismatch(expected_columns, expr_list.values[0].size(), !stmt.columns.empty(), - table.name.c_str()); - - // VALUES list! - for (idx_t col_idx = 0; col_idx < expected_columns; col_idx++) { - D_ASSERT(named_column_map.size() >= col_idx); - auto &table_col_idx = named_column_map[col_idx]; - - // set the expected types as the types for the INSERT statement - auto &column = table.GetColumn(table_col_idx); - expr_list.expected_types[col_idx] = column.Type(); - expr_list.expected_names[col_idx] = column.Name(); - - // now replace any DEFAULT values with the corresponding default expression - for (idx_t list_idx = 0; list_idx < expr_list.values.size(); list_idx++) { - TryReplaceDefaultExpression(expr_list.values[list_idx][col_idx], column); - } - } - } + ExpandDefaultInValuesList(stmt, table, values_list, named_column_map); // parse select statement and add to logical plan unique_ptr root; @@ -619,20 +579,16 @@ BoundStatement Binder::Bind(InsertStatement &stmt) { } else { root = make_uniq(GenerateTableIndex()); } - insert->AddChild(std::move(root)); - - BindOnConflictClause(*insert, table, stmt); + insert->AddChild(std::move(root)); if (!stmt.returning_list.empty()) { insert->return_chunk = true; - result.types.clear(); - result.names.clear(); auto insert_table_index = GenerateTableIndex(); insert->table_index = insert_table_index; unique_ptr index_as_logicaloperator = std::move(insert); return BindReturning(std::move(stmt.returning_list), table, stmt.table_ref ? 
stmt.table_ref->alias : string(), - insert_table_index, std::move(index_as_logicaloperator), std::move(result)); + insert_table_index, std::move(index_as_logicaloperator)); } D_ASSERT(result.types.size() == result.names.size()); diff --git a/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp b/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp index b92b00fd4..6dc983d0f 100644 --- a/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_merge_into.cpp @@ -41,7 +41,7 @@ unique_ptr Binder::BindMergeAction(LogicalMergeInto &merge auto result = make_uniq(); result->action_type = action.action_type; if (action.condition) { - ProjectionBinder proj_binder(*this, context, proj_index, expressions); + ProjectionBinder proj_binder(*this, context, proj_index, expressions, "WHERE clause"); proj_binder.target_type = LogicalType::BOOLEAN; auto cond = proj_binder.Bind(action.condition); result->condition = std::move(cond); @@ -74,7 +74,7 @@ unique_ptr Binder::BindMergeAction(LogicalMergeInto &merge // construct a dummy projection and update LogicalProjection proj(proj_index, std::move(expressions)); LogicalUpdate update(table); - update.return_chunk = false; + update.return_chunk = merge_into.return_chunk; update.columns = std::move(result->columns); update.expressions = std::move(result->expressions); update.bound_defaults = std::move(merge_into.bound_defaults); @@ -135,7 +135,7 @@ unique_ptr Binder::BindMergeAction(LogicalMergeInto &merge case MergeActionType::MERGE_ERROR: { // bind the error message (if any) for (auto &expr : action.expressions) { - ProjectionBinder proj_binder(*this, context, proj_index, expressions); + ProjectionBinder proj_binder(*this, context, proj_index, expressions, "Error Message"); proj_binder.target_type = LogicalType::VARCHAR; auto error_msg = proj_binder.Bind(expr); result->expressions.push_back(std::move(error_msg)); @@ -171,9 +171,9 @@ void RewriteMergeBindings(LogicalOperator &op, const vector &sour } BoundStatement Binder::Bind(MergeIntoStatement &stmt) { - // bind the target table auto target_binder = Binder::CreateBinder(context, this); + string table_alias = stmt.target->alias; auto bound_table = target_binder->Bind(*stmt.target); if (bound_table->type != TableReferenceType::BASE_TABLE) { throw BinderException("Can only merge into base tables!"); @@ -235,8 +235,10 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { auto &get = root->children[inverted ? 
0 : 1]->Cast(); auto merge_into = make_uniq(table); - merge_into->table_index = GenerateTableIndex(); + if (!stmt.returning_list.empty()) { + merge_into->return_chunk = true; + } // bind table constraints/default values in case these are referenced auto &catalog_name = table.ParentCatalog().GetName(); @@ -305,10 +307,21 @@ BoundStatement Binder::Bind(MergeIntoStatement &stmt) { merge_into->AddChild(std::move(proj)); + if (!stmt.returning_list.empty()) { + auto merge_table_index = merge_into->table_index; + unique_ptr index_as_logicaloperator = std::move(merge_into); + + // add the merge_action virtual column + virtual_column_map_t virtual_columns; + virtual_columns.insert(make_pair(VIRTUAL_COLUMN_START, TableColumn("merge_action", LogicalType::VARCHAR))); + return BindReturning(std::move(stmt.returning_list), table, table_alias, merge_table_index, + std::move(index_as_logicaloperator), std::move(virtual_columns)); + } + BoundStatement result; + result.plan = std::move(merge_into); result.names = {"Count"}; result.types = {LogicalType::BIGINT}; - result.plan = std::move(merge_into); auto &properties = GetStatementProperties(); properties.allow_stream_result = false; diff --git a/src/duckdb/src/planner/binder/statement/bind_update.cpp b/src/duckdb/src/planner/binder/statement/bind_update.cpp index 03d33b5be..650b23b89 100644 --- a/src/duckdb/src/planner/binder/statement/bind_update.cpp +++ b/src/duckdb/src/planner/binder/statement/bind_update.cpp @@ -106,7 +106,6 @@ void Binder::BindRowIdColumns(TableCatalogEntry &table, LogicalGet &get, vector< } BoundStatement Binder::Bind(UpdateStatement &stmt) { - BoundStatement result; unique_ptr root; // visit the table reference @@ -184,9 +183,10 @@ BoundStatement Binder::Bind(UpdateStatement &stmt) { unique_ptr update_as_logicaloperator = std::move(update); return BindReturning(std::move(stmt.returning_list), table, stmt.table->alias, update_table_index, - std::move(update_as_logicaloperator), std::move(result)); + std::move(update_as_logicaloperator)); } + BoundStatement result; result.names = {"Count"}; result.types = {LogicalType::BIGINT}; result.plan = std::move(update); diff --git a/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp b/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp index e1352d6b3..2eb211530 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_pivot.cpp @@ -21,6 +21,7 @@ #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" #include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" #include "duckdb/main/query_result.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -542,11 +543,10 @@ unique_ptr Binder::BindPivot(PivotRef &ref, vector(context); if (total_pivots >= pivot_limit) { throw BinderException(ref, "Pivot column limit of %llu exceeded. 
Use SET pivot_limit=X to increase the limit.", - client_config.pivot_limit); + pivot_limit); } // construct the required pivot values recursively @@ -565,7 +565,8 @@ unique_ptr Binder::BindPivot(PivotRef &ref, vector filtered aggregates are faster when there are FEW pivot values // -> LIST is faster when there are MANY pivot values // we switch dynamically based on the number of pivots to compute - if (pivot_values.size() <= client_config.pivot_filter_threshold) { + auto pivot_filter_threshold = DBConfig::GetSetting(context); + if (pivot_values.size() <= pivot_filter_threshold) { // use a set of filtered aggregates pivot_node = PivotFilteredAggregate(context, ref, std::move(all_columns), handled_columns, std::move(pivot_values)); diff --git a/src/duckdb/src/planner/binder/tableref/bind_showref.cpp b/src/duckdb/src/planner/binder/tableref/bind_showref.cpp index 78803e395..b23456cab 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_showref.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_showref.cpp @@ -9,6 +9,10 @@ #include "duckdb/planner/operator/logical_get.hpp" #include "duckdb/planner/operator/logical_projection.hpp" #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" +#include "duckdb/catalog/catalog.hpp" +#include "duckdb/catalog/catalog_search_path.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/main/client_context.hpp" namespace duckdb { @@ -151,6 +155,32 @@ unique_ptr Binder::BindShowTable(ShowRef &ref) { sql = PragmaShowDatabases(); } else if (lname == "\"tables\"") { sql = PragmaShowTables(); + } else if (ref.show_type == ShowType::SHOW_FROM) { + auto catalog_name = ref.catalog_name; + auto schema_name = ref.schema_name; + + // Check for unqualified name, promote schema to catalog if unambiguous, and set schema_name to empty if so + Binder::BindSchemaOrCatalog(catalog_name, schema_name); + + // If fully qualified, check if the schema exists + if (!catalog_name.empty() && !schema_name.empty()) { + auto schema_entry = Catalog::GetSchema(context, catalog_name, schema_name, OnEntryNotFound::RETURN_NULL); + if (!schema_entry) { + throw CatalogException("SHOW TABLES FROM: No catalog + schema named \"%s.%s\" found.", catalog_name, + schema_name); + } + } else if (catalog_name.empty() && !schema_name.empty()) { + // We have a schema name, use default catalog + auto &client_data = ClientData::Get(context); + auto &default_entry = client_data.catalog_search_path->GetDefault(); + catalog_name = default_entry.catalog; + auto schema_entry = Catalog::GetSchema(context, catalog_name, schema_name, OnEntryNotFound::RETURN_NULL); + if (!schema_entry) { + throw CatalogException("SHOW TABLES FROM: No catalog + schema named \"%s.%s\" found.", catalog_name, + schema_name); + } + } + sql = PragmaShowTables(catalog_name, schema_name); } else if (lname == "\"variables\"") { sql = PragmaShowVariables(); } else if (lname == "__show_tables_expanded") { diff --git a/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp b/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp index 905998854..0c6e1e0aa 100644 --- a/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp +++ b/src/duckdb/src/planner/binder/tableref/bind_table_function.cpp @@ -21,6 +21,10 @@ #include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" #include "duckdb/function/table/read_csv.hpp" +#include "duckdb/planner/expression/bound_window_expression.hpp" +#include "duckdb/planner/operator/logical_window.hpp" +#include "duckdb/planner/operator/logical_projection.hpp" + 
namespace duckdb { enum class TableFunctionBindType { STANDARD_TABLE_FUNCTION, TABLE_IN_OUT_FUNCTION, TABLE_PARAMETER_FUNCTION }; @@ -191,12 +195,14 @@ unique_ptr Binder::BindTableFunctionInternal(TableFunction &tab vector input_table_names) { auto function_name = GetAlias(ref); auto &column_name_alias = ref.column_name_alias; - auto bind_index = GenerateTableIndex(); // perform the binding unique_ptr bind_data; vector return_types; vector return_names; + auto constexpr ordinality_name = "ordinality"; + string ordinality_column_name = ordinality_name; + idx_t ordinality_column_id; if (table_function.bind || table_function.bind_replace || table_function.bind_operator) { TableFunctionBindInput bind_input(parameters, named_parameters, input_table_types, input_table_names, table_function.function_info.get(), this, table_function, ref); @@ -236,6 +242,27 @@ unique_ptr Binder::BindTableFunctionInternal(TableFunction &tab table_function.name); } bind_data = table_function.bind(context, bind_input, return_types, return_names); + if (ref.with_ordinality == OrdinalityType::WITH_ORDINALITY) { + // check if column name 'ordinality' already exists and if so, replace it iteratively until free name is + // found + case_insensitive_set_t ci_return_names; + idx_t ordinality_name_suffix = 0; + for (auto &n : return_names) { + ci_return_names.insert(n); + } + for (auto &n : column_name_alias) { + ci_return_names.insert(n); + } + while (ci_return_names.find(ordinality_column_name) != ci_return_names.end()) { + ordinality_column_name = ordinality_name + to_string(ordinality_name_suffix++); + } + if (!correlated_columns.empty()) { + return_types.emplace_back(LogicalType::BIGINT); + return_names.emplace_back(ordinality_column_name); + D_ASSERT(return_names.size() == return_types.size()); + ordinality_column_id = return_types.size() - 1; + } + } } else { throw InvalidInputException("Cannot call function \"%s\" directly - it has no bind function", table_function.name); @@ -270,11 +297,55 @@ unique_ptr Binder::BindTableFunctionInternal(TableFunction &tab get->named_parameters = named_parameters; get->input_table_types = input_table_types; get->input_table_names = input_table_names; + if (ref.with_ordinality == OrdinalityType::WITH_ORDINALITY && !correlated_columns.empty()) { + get->ordinality_idx = ordinality_column_id; + } if (table_function.in_out_function) { for (idx_t i = 0; i < return_types.size(); i++) { get->AddColumnId(i); } } + + if (ref.with_ordinality == OrdinalityType::WITH_ORDINALITY && correlated_columns.empty()) { + auto window_index = GenerateTableIndex(); + auto window = make_uniq(window_index); + auto row_number = + make_uniq(ExpressionType::WINDOW_ROW_NUMBER, LogicalType::BIGINT, nullptr, nullptr); + row_number->start = WindowBoundary::UNBOUNDED_PRECEDING; + row_number->end = WindowBoundary::CURRENT_ROW_ROWS; + if (return_names.size() < column_name_alias.size()) { + row_number->alias = column_name_alias[return_names.size()]; + } else { + row_number->alias = ordinality_column_name; + } + window->expressions.push_back(std::move(row_number)); + for (idx_t i = 0; i < return_types.size(); i++) { + get->AddColumnId(i); + } + window->children.push_back(std::move(get)); + + vector> select_list; + for (idx_t i = 0; i < return_types.size(); i++) { + auto expression = make_uniq(return_types[i], ColumnBinding(bind_index, i)); + select_list.push_back(std::move(expression)); + } + select_list.push_back(make_uniq(LogicalType::BIGINT, ColumnBinding(window_index, 0))); + + auto projection_index = 
GenerateTableIndex(); + auto projection = make_uniq(projection_index, std::move(select_list)); + + projection->children.push_back(std::move(window)); + if (return_names.size() < column_name_alias.size()) { + return_names.push_back(column_name_alias[return_names.size()]); + } else { + return_names.push_back(ordinality_column_name); + } + + return_types.push_back(LogicalType::BIGINT); + bind_context.AddGenericBinding(projection_index, function_name, return_names, return_types); + return std::move(projection); + } + // now add the table function to the bind context so its columns can be bound bind_context.AddTableFunction(bind_index, function_name, return_names, return_types, get->GetMutableColumnIds(), get->GetTable().get(), std::move(virtual_columns)); diff --git a/src/duckdb/src/planner/binder/tableref/plan_table_function.cpp b/src/duckdb/src/planner/binder/tableref/plan_table_function.cpp index 3be48aa6c..6c2f9957a 100644 --- a/src/duckdb/src/planner/binder/tableref/plan_table_function.cpp +++ b/src/duckdb/src/planner/binder/tableref/plan_table_function.cpp @@ -6,7 +6,21 @@ namespace duckdb { unique_ptr Binder::CreatePlan(BoundTableFunction &ref) { if (ref.subquery) { auto child_node = CreatePlan(*ref.subquery); - ref.get->children.push_back(std::move(child_node)); + + reference node = *ref.get; + + while (!node.get().children.empty()) { + D_ASSERT(node.get().children.size() == 1); + if (node.get().children.size() != 1) { + throw InternalException( + "Binder::CreatePlan: linear path expected, but found node with %d children", + node.get().children.size()); + } + node = *node.get().children[0]; + } + + D_ASSERT(node.get().type == LogicalOperatorType::LOGICAL_GET); + node.get().children.push_back(std::move(child_node)); } return std::move(ref.get); } diff --git a/src/duckdb/src/planner/collation_binding.cpp b/src/duckdb/src/planner/collation_binding.cpp index aabde7f55..1ddefb9a8 100644 --- a/src/duckdb/src/planner/collation_binding.cpp +++ b/src/duckdb/src/planner/collation_binding.cpp @@ -3,6 +3,7 @@ #include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" #include "duckdb/main/config.hpp" +#include "duckdb/main/settings.hpp" #include "duckdb/catalog/catalog.hpp" #include "duckdb/function/function_binder.hpp" @@ -18,7 +19,7 @@ bool PushVarcharCollation(ClientContext &context, unique_ptr &source auto str_collation = StringType::GetCollation(sql_type); string collation; if (str_collation.empty()) { - collation = DBConfig::GetConfig(context).options.collation; + collation = DBConfig::GetSetting(context); } else { collation = str_collation; } diff --git a/src/duckdb/src/planner/expression_binder/order_binder.cpp b/src/duckdb/src/planner/expression_binder/order_binder.cpp index af78102b8..6b093eb8f 100644 --- a/src/duckdb/src/planner/expression_binder/order_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/order_binder.cpp @@ -14,6 +14,7 @@ #include "duckdb/planner/expression_binder/select_bind_state.hpp" #include "duckdb/main/client_config.hpp" #include "duckdb/common/pair.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -59,8 +60,9 @@ optional_idx OrderBinder::TryGetProjectionReference(ParsedExpression &expr) cons // non-integral expression // ORDER BY has no effect // this is disabled by default (matching Postgres) - but we can control this with a setting - auto &config = ClientConfig::GetConfig(binders[0].get().context); - if (!config.order_by_non_integer_literal) { + auto 
order_by_non_integer_literal = + DBConfig::GetSetting(binders[0].get().context); + if (!order_by_non_integer_literal) { throw BinderException(expr, "%s non-integer literal has no effect.\n* SET " "order_by_non_integer_literal=true to allow this behavior.", diff --git a/src/duckdb/src/planner/expression_binder/projection_binder.cpp b/src/duckdb/src/planner/expression_binder/projection_binder.cpp index f5c0a6c68..331ed9315 100644 --- a/src/duckdb/src/planner/expression_binder/projection_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/projection_binder.cpp @@ -4,8 +4,9 @@ namespace duckdb { ProjectionBinder::ProjectionBinder(Binder &binder, ClientContext &context, idx_t proj_index_p, - vector> &proj_expressions_p) - : ExpressionBinder(binder, context), proj_index(proj_index_p), proj_expressions(proj_expressions_p) { + vector> &proj_expressions_p, string clause_p) + : ExpressionBinder(binder, context), proj_index(proj_index_p), proj_expressions(proj_expressions_p), + clause(std::move(clause_p)) { } BindResult ProjectionBinder::BindColumnRef(unique_ptr &expr_ptr, idx_t depth, bool root_expression) { @@ -13,6 +14,9 @@ BindResult ProjectionBinder::BindColumnRef(unique_ptr &expr_pt if (result.HasError()) { return result; } + if (result.expression->GetExpressionClass() == ExpressionClass::BOUND_LAMBDA_REF) { + return result; + } // we have successfully bound a column - push it into the projection and emit a reference auto proj_ref = make_uniq(result.expression->return_type, ColumnBinding(proj_index, proj_expressions.size())); @@ -25,9 +29,9 @@ BindResult ProjectionBinder::BindExpression(unique_ptr &expr_p auto &expr = *expr_ptr; switch (expr.GetExpressionClass()) { case ExpressionClass::DEFAULT: - return BindUnsupportedExpression(expr, depth, "Clause cannot contain DEFAULT clause"); + return BindUnsupportedExpression(expr, depth, clause + " cannot contain DEFAULT clause"); case ExpressionClass::WINDOW: - return BindUnsupportedExpression(expr, depth, "Clause cannot contain window functions!"); + return BindUnsupportedExpression(expr, depth, clause + " cannot contain window functions!"); case ExpressionClass::COLUMN_REF: return BindColumnRef(expr_ptr, depth, root_expression); default: @@ -36,7 +40,7 @@ BindResult ProjectionBinder::BindExpression(unique_ptr &expr_p } string ProjectionBinder::UnsupportedAggregateMessage() { - return "Clause cannot contain aggregate functions"; + return clause + " cannot contain aggregate functions"; } } // namespace duckdb diff --git a/src/duckdb/src/planner/expression_binder/table_function_binder.cpp b/src/duckdb/src/planner/expression_binder/table_function_binder.cpp index 3c3c1491c..198bd072b 100644 --- a/src/duckdb/src/planner/expression_binder/table_function_binder.cpp +++ b/src/duckdb/src/planner/expression_binder/table_function_binder.cpp @@ -6,8 +6,10 @@ namespace duckdb { -TableFunctionBinder::TableFunctionBinder(Binder &binder, ClientContext &context, string table_function_name_p) - : ExpressionBinder(binder, context), table_function_name(std::move(table_function_name_p)) { +TableFunctionBinder::TableFunctionBinder(Binder &binder, ClientContext &context, string table_function_name_p, + string clause_p) + : ExpressionBinder(binder, context), table_function_name(std::move(table_function_name_p)), + clause(std::move(clause_p)) { } BindResult TableFunctionBinder::BindLambdaReference(LambdaRefExpression &expr, idx_t depth) { @@ -67,18 +69,18 @@ BindResult TableFunctionBinder::BindExpression(unique_ptr &exp case ExpressionClass::COLUMN_REF: return 
BindColumnReference(expr_ptr, depth, root_expression); case ExpressionClass::SUBQUERY: - throw BinderException("Table function cannot contain subqueries"); + throw BinderException(clause + " cannot contain subqueries"); case ExpressionClass::DEFAULT: - return BindResult("Table function cannot contain DEFAULT clause"); + return BindResult(clause + " cannot contain DEFAULT clause"); case ExpressionClass::WINDOW: - return BindResult("Table function cannot contain window functions!"); + return BindResult(clause + " cannot contain window functions!"); default: return ExpressionBinder::BindExpression(expr_ptr, depth); } } string TableFunctionBinder::UnsupportedAggregateMessage() { - return "Table function cannot contain aggregates!"; + return clause + " cannot contain aggregates!"; } } // namespace duckdb diff --git a/src/duckdb/src/planner/logical_operator_visitor.cpp b/src/duckdb/src/planner/logical_operator_visitor.cpp index 716b719fb..5e96a5bbb 100644 --- a/src/duckdb/src/planner/logical_operator_visitor.cpp +++ b/src/duckdb/src/planner/logical_operator_visitor.cpp @@ -132,11 +132,11 @@ void LogicalOperatorVisitor::EnumerateExpressions(LogicalOperator &op, } case LogicalOperatorType::LOGICAL_INSERT: { auto &insert = op.Cast(); - if (insert.on_conflict_condition) { - callback(&insert.on_conflict_condition); + if (insert.on_conflict_info.on_conflict_condition) { + callback(&insert.on_conflict_info.on_conflict_condition); } - if (insert.do_update_condition) { - callback(&insert.do_update_condition); + if (insert.on_conflict_info.do_update_condition) { + callback(&insert.on_conflict_info.do_update_condition); } break; } diff --git a/src/duckdb/src/planner/operator/logical_get.cpp b/src/duckdb/src/planner/operator/logical_get.cpp index f3c98224d..79d029f08 100644 --- a/src/duckdb/src/planner/operator/logical_get.cpp +++ b/src/duckdb/src/planner/operator/logical_get.cpp @@ -220,6 +220,7 @@ void LogicalGet::Serialize(Serializer &serializer) const { serializer.WriteProperty(210, "projected_input", projected_input); serializer.WritePropertyWithDefault(211, "column_indexes", column_ids); serializer.WritePropertyWithDefault(212, "extra_info", extra_info, ExtraOperatorInfo {}); + serializer.WritePropertyWithDefault(213, "ordinality_idx", ordinality_idx); } unique_ptr LogicalGet::Deserialize(Deserializer &deserializer) { @@ -248,6 +249,8 @@ unique_ptr LogicalGet::Deserialize(Deserializer &deserializer) } deserializer.ReadProperty(210, "projected_input", result->projected_input); deserializer.ReadPropertyWithDefault(211, "column_indexes", result->column_ids); + result->extra_info = deserializer.ReadPropertyWithExplicitDefault(212, "extra_info", {}); + deserializer.ReadPropertyWithDefault(213, "ordinality_idx", result->ordinality_idx); if (!legacy_column_ids.empty()) { if (!result->column_ids.empty()) { throw SerializationException( @@ -257,7 +260,6 @@ unique_ptr LogicalGet::Deserialize(Deserializer &deserializer) result->column_ids.emplace_back(col_id); } } - result->extra_info = deserializer.ReadPropertyWithExplicitDefault(212, "extra_info", {}); auto &context = deserializer.Get(); virtual_column_map_t virtual_columns; if (!has_serialize) { @@ -272,10 +274,13 @@ unique_ptr LogicalGet::Deserialize(Deserializer &deserializer) throw InternalException("Table function \"%s\" has neither bind nor (de)serialize", function.name); } bind_data = function.bind(context, input, bind_return_types, bind_names); + if (result->ordinality_idx.IsValid()) { + auto ordinality_pos = bind_return_types.begin() + 
NumericCast(result->ordinality_idx.GetIndex()); + bind_return_types.emplace(ordinality_pos, LogicalType::BIGINT); + } if (function.get_virtual_columns) { virtual_columns = function.get_virtual_columns(context, bind_data.get()); } - for (auto &col_id : result->column_ids) { if (col_id.IsVirtualColumn()) { auto idx = col_id.GetPrimaryIndex(); diff --git a/src/duckdb/src/planner/operator/logical_insert.cpp b/src/duckdb/src/planner/operator/logical_insert.cpp index 10f115697..bab6820fe 100644 --- a/src/duckdb/src/planner/operator/logical_insert.cpp +++ b/src/duckdb/src/planner/operator/logical_insert.cpp @@ -6,9 +6,12 @@ namespace duckdb { +BoundOnConflictInfo::BoundOnConflictInfo() : action_type(OnConflictAction::THROW), update_is_del_and_insert(false) { +} + LogicalInsert::LogicalInsert(TableCatalogEntry &table, idx_t table_index) - : LogicalOperator(LogicalOperatorType::LOGICAL_INSERT), table(table), table_index(table_index), return_chunk(false), - action_type(OnConflictAction::THROW), update_is_del_and_insert(false) { + : LogicalOperator(LogicalOperatorType::LOGICAL_INSERT), table(table), table_index(table_index), + return_chunk(false) { } LogicalInsert::LogicalInsert(ClientContext &context, const unique_ptr table_info) diff --git a/src/duckdb/src/planner/operator/logical_merge_into.cpp b/src/duckdb/src/planner/operator/logical_merge_into.cpp index f08d6cfb9..fa99f69ab 100644 --- a/src/duckdb/src/planner/operator/logical_merge_into.cpp +++ b/src/duckdb/src/planner/operator/logical_merge_into.cpp @@ -17,15 +17,27 @@ LogicalMergeInto::LogicalMergeInto(ClientContext &context, const unique_ptr LogicalMergeInto::GetTableIndex() const { + return vector {table_index}; } vector LogicalMergeInto::GetColumnBindings() { + if (return_chunk) { + return GenerateColumnBindings(table_index, table.GetTypes().size() + 1); + } return {ColumnBinding(0, 0)}; } void LogicalMergeInto::ResolveTypes() { - types.emplace_back(LogicalType::BIGINT); + if (return_chunk) { + types = table.GetTypes(); + types.push_back(LogicalType::VARCHAR); + } else { + types.emplace_back(LogicalType::BIGINT); + } } } // namespace duckdb diff --git a/src/duckdb/src/planner/planner.cpp b/src/duckdb/src/planner/planner.cpp index e55296e6f..9f1cafa4d 100644 --- a/src/duckdb/src/planner/planner.cpp +++ b/src/duckdb/src/planner/planner.cpp @@ -47,10 +47,11 @@ void Planner::CreatePlan(SQLStatement &statement) { this->names = bound_statement.names; this->types = bound_statement.types; - this->plan = FlattenDependentJoins::DecorrelateIndependent(*binder, std::move(bound_statement.plan)); - + this->plan = std::move(bound_statement.plan); auto max_tree_depth = ClientConfig::GetConfig(context).max_expression_depth; CheckTreeDepth(*plan, max_tree_depth); + + this->plan = FlattenDependentJoins::DecorrelateIndependent(*binder, std::move(this->plan)); } catch (const std::exception &ex) { ErrorData error(ex); this->plan = nullptr; diff --git a/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp b/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp index dafd519dc..34ed61deb 100644 --- a/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp +++ b/src/duckdb/src/planner/subquery/flatten_dependent_join.cpp @@ -610,6 +610,8 @@ unique_ptr FlattenDependentJoins::PushDownDependentJoinInternal // recurse into left children plan->children[0] = DecorrelateIndependent(binder, std::move(plan->children[0])); + // Similar to the LOGICAL_COMPARISON_JOIN + delim_offset += plan->children[0]->GetColumnBindings().size(); return plan; } // both sides 
have correlation diff --git a/src/duckdb/src/storage/buffer/block_handle.cpp b/src/duckdb/src/storage/buffer/block_handle.cpp index fa103784a..2494cace9 100644 --- a/src/duckdb/src/storage/buffer/block_handle.cpp +++ b/src/duckdb/src/storage/buffer/block_handle.cpp @@ -1,6 +1,7 @@ #include "duckdb/storage/buffer/block_handle.hpp" #include "duckdb/common/file_buffer.hpp" +#include "duckdb/main/client_context.hpp" #include "duckdb/storage/block.hpp" #include "duckdb/storage/block_manager.hpp" #include "duckdb/storage/buffer/buffer_handle.hpp" @@ -150,7 +151,8 @@ BufferHandle BlockHandle::Load(unique_ptr reusable_buffer) { buffer = std::move(block); } else { if (MustWriteToTemporaryFile()) { - buffer = block_manager.buffer_manager.ReadTemporaryBuffer(tag, *this, std::move(reusable_buffer)); + buffer = block_manager.buffer_manager.ReadTemporaryBuffer(QueryContext(), tag, *this, + std::move(reusable_buffer)); } else { return BufferHandle(); // Destroyed upon unpin/evict, so there is no temp buffer to read } diff --git a/src/duckdb/src/storage/buffer_manager.cpp b/src/duckdb/src/storage/buffer_manager.cpp index e855fa58a..cbafc4e4a 100644 --- a/src/duckdb/src/storage/buffer_manager.cpp +++ b/src/duckdb/src/storage/buffer_manager.cpp @@ -114,7 +114,7 @@ void BufferManager::WriteTemporaryBuffer(MemoryTag tag, block_id_t block_id, Fil throw NotImplementedException("This type of BufferManager does not support 'WriteTemporaryBuffer"); } -unique_ptr BufferManager::ReadTemporaryBuffer(MemoryTag tag, BlockHandle &block, +unique_ptr BufferManager::ReadTemporaryBuffer(QueryContext context, MemoryTag tag, BlockHandle &block, unique_ptr buffer) { throw NotImplementedException("This type of BufferManager does not support 'ReadTemporaryBuffer"); } diff --git a/src/duckdb/src/storage/caching_file_system.cpp b/src/duckdb/src/storage/caching_file_system.cpp index 019a45365..19a032708 100644 --- a/src/duckdb/src/storage/caching_file_system.cpp +++ b/src/duckdb/src/storage/caching_file_system.cpp @@ -22,13 +22,21 @@ CachingFileSystem CachingFileSystem::Get(ClientContext &context) { } unique_ptr CachingFileSystem::OpenFile(const OpenFileInfo &path, FileOpenFlags flags) { - return make_uniq(*this, path, flags, external_file_cache.GetOrCreateCachedFile(path.path)); + return make_uniq(QueryContext(), *this, path, flags, + external_file_cache.GetOrCreateCachedFile(path.path)); } -CachingFileHandle::CachingFileHandle(CachingFileSystem &caching_file_system_p, const OpenFileInfo &path_p, - FileOpenFlags flags_p, CachedFile &cached_file_p) - : caching_file_system(caching_file_system_p), external_file_cache(caching_file_system.external_file_cache), - path(path_p), flags(flags_p), validate(true), cached_file(cached_file_p), position(0) { +unique_ptr CachingFileSystem::OpenFile(QueryContext context, const OpenFileInfo &path, + FileOpenFlags flags) { + return make_uniq(context, *this, path, flags, + external_file_cache.GetOrCreateCachedFile(path.path)); +} + +CachingFileHandle::CachingFileHandle(QueryContext context, CachingFileSystem &caching_file_system_p, + const OpenFileInfo &path_p, FileOpenFlags flags_p, CachedFile &cached_file_p) + : context(context), caching_file_system(caching_file_system_p), + external_file_cache(caching_file_system.external_file_cache), path(path_p), flags(flags_p), validate(true), + cached_file(cached_file_p), position(0) { if (path.extended_info) { const auto &open_options = path.extended_info->options; const auto validate_entry = open_options.find("validate_external_file_cache"); @@ -76,7 +84,7 
@@ BufferHandle CachingFileHandle::Read(data_ptr_t &buffer, const idx_t nr_bytes, c if (!external_file_cache.IsEnabled()) { result = external_file_cache.GetBufferManager().Allocate(MemoryTag::EXTERNAL_FILE_CACHE, nr_bytes); buffer = result.Ptr(); - GetFileHandle().Read(buffer, nr_bytes, location); + GetFileHandle().Read(context, buffer, nr_bytes, location); return result; } @@ -101,7 +109,7 @@ BufferHandle CachingFileHandle::Read(data_ptr_t &buffer, const idx_t nr_bytes, c if (ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, nr_bytes, location, false) <= 1) { ReadAndCopyInterleaved(overlapping_ranges, new_file_range, buffer, nr_bytes, location, true); } else { - GetFileHandle().Read(buffer, nr_bytes, location); + GetFileHandle().Read(context, buffer, nr_bytes, location); } } @@ -351,7 +359,7 @@ idx_t CachingFileHandle::ReadAndCopyInterleaved(const vectorlocation - current_location; D_ASSERT(bytes_to_read < remaining_bytes); if (actually_read) { - GetFileHandle().Read(buffer + buffer_offset, bytes_to_read, current_location); + GetFileHandle().Read(context, buffer + buffer_offset, bytes_to_read, current_location); } current_location += bytes_to_read; remaining_bytes -= bytes_to_read; @@ -385,7 +393,7 @@ idx_t CachingFileHandle::ReadAndCopyInterleaved(const vector SingleFileRowGroupWriter::GetMetadataManager() { + return table_data_writer.GetManager(); +} + +void SingleFileRowGroupWriter::StartWritingColumns(vector &column_metadata) { + table_data_writer.SetWrittenPointers(column_metadata); +} + +void SingleFileRowGroupWriter::FinishWritingColumns() { + table_data_writer.SetWrittenPointers(nullptr); +} + } // namespace duckdb diff --git a/src/duckdb/src/storage/checkpoint/table_data_reader.cpp b/src/duckdb/src/storage/checkpoint/table_data_reader.cpp index 5f76227f3..0bb28cb00 100644 --- a/src/duckdb/src/storage/checkpoint/table_data_reader.cpp +++ b/src/duckdb/src/storage/checkpoint/table_data_reader.cpp @@ -10,8 +10,10 @@ namespace duckdb { -TableDataReader::TableDataReader(MetadataReader &reader, BoundCreateTableInfo &info) : reader(reader), info(info) { +TableDataReader::TableDataReader(MetadataReader &reader, BoundCreateTableInfo &info, MetaBlockPointer table_pointer) + : reader(reader), info(info) { info.data = make_uniq(info.Base().columns.LogicalColumnCount()); + info.data->base_table_pointer = table_pointer; } void TableDataReader::ReadTableData() { diff --git a/src/duckdb/src/storage/checkpoint/table_data_writer.cpp b/src/duckdb/src/storage/checkpoint/table_data_writer.cpp index a937250e4..0bde4acb3 100644 --- a/src/duckdb/src/storage/checkpoint/table_data_writer.cpp +++ b/src/duckdb/src/storage/checkpoint/table_data_writer.cpp @@ -7,6 +7,7 @@ #include "duckdb/parallel/task_scheduler.hpp" #include "duckdb/storage/table/column_checkpoint_state.hpp" #include "duckdb/storage/table/table_statistics.hpp" +#include "duckdb/storage/metadata/metadata_reader.hpp" namespace duckdb { @@ -53,33 +54,56 @@ CheckpointType SingleFileTableDataWriter::GetCheckpointType() const { return checkpoint_manager.GetCheckpointType(); } +MetadataManager &SingleFileTableDataWriter::GetMetadataManager() { + return checkpoint_manager.GetMetadataManager(); +} + +void SingleFileTableDataWriter::WriteUnchangedTable(MetaBlockPointer pointer, idx_t total_rows) { + existing_pointer = pointer; + existing_rows = total_rows; +} + void SingleFileTableDataWriter::FinalizeTable(const TableStatistics &global_stats, DataTableInfo *info, Serializer &serializer) { + MetaBlockPointer pointer; + idx_t total_rows; 
+ if (!existing_pointer.IsValid()) { + // write the metadata + // store the current position in the metadata writer + // this is where the row groups for this table start + pointer = table_data_writer.GetMetaBlockPointer(); + + // Serialize statistics as a single unit + BinarySerializer stats_serializer(table_data_writer, serializer.GetOptions()); + stats_serializer.Begin(); + global_stats.Serialize(stats_serializer); + stats_serializer.End(); + + // now start writing the row group pointers to disk + table_data_writer.Write(row_group_pointers.size()); + total_rows = 0; + for (auto &row_group_pointer : row_group_pointers) { + auto row_group_count = row_group_pointer.row_start + row_group_pointer.tuple_count; + if (row_group_count > total_rows) { + total_rows = row_group_count; + } - // store the current position in the metadata writer - // this is where the row groups for this table start - auto pointer = table_data_writer.GetMetaBlockPointer(); - - // Serialize statistics as a single unit - BinarySerializer stats_serializer(table_data_writer, serializer.GetOptions()); - stats_serializer.Begin(); - global_stats.Serialize(stats_serializer); - stats_serializer.End(); - - // now start writing the row group pointers to disk - table_data_writer.Write(row_group_pointers.size()); - idx_t total_rows = 0; - for (auto &row_group_pointer : row_group_pointers) { - auto row_group_count = row_group_pointer.row_start + row_group_pointer.tuple_count; - if (row_group_count > total_rows) { - total_rows = row_group_count; + // Each RowGroup is its own unit + BinarySerializer row_group_serializer(table_data_writer, serializer.GetOptions()); + row_group_serializer.Begin(); + RowGroup::Serialize(row_group_pointer, row_group_serializer); + row_group_serializer.End(); } - - // Each RowGroup is its own unit - BinarySerializer row_group_serializer(table_data_writer, serializer.GetOptions()); - row_group_serializer.Begin(); - RowGroup::Serialize(row_group_pointer, row_group_serializer); - row_group_serializer.End(); + } else { + // we have existing metadata and the table is unchanged - write a pointer to the existing metadata + pointer = existing_pointer; + total_rows = existing_rows.GetIndex(); + + // label the blocks as used again to prevent them from being freed + auto &metadata_manager = checkpoint_manager.GetMetadataManager(); + MetadataReader reader(metadata_manager, pointer); + auto blocks = reader.GetRemainingBlocks(); + metadata_manager.ClearModifiedBlocks(blocks); } // Now begin the metadata as a unit diff --git a/src/duckdb/src/storage/checkpoint_manager.cpp b/src/duckdb/src/storage/checkpoint_manager.cpp index 9dd675773..8618b38eb 100644 --- a/src/duckdb/src/storage/checkpoint_manager.cpp +++ b/src/duckdb/src/storage/checkpoint_manager.cpp @@ -31,6 +31,8 @@ #include "duckdb/transaction/meta_transaction.hpp" #include "duckdb/transaction/transaction_manager.hpp" #include "duckdb/catalog/dependency_manager.hpp" +#include "duckdb/common/serializer/memory_stream.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -129,7 +131,6 @@ static catalog_entry_vector_t GetCatalogEntries(vector(); if (storage_manager.InMemory()) { return; @@ -201,8 +202,8 @@ void SingleFileCheckpointWriter::CreateCheckpoint() { wal->WriteCheckpoint(meta_block); wal->Flush(); } - - if (config.options.checkpoint_abort == CheckpointAbort::DEBUG_ABORT_BEFORE_HEADER) { + auto debug_checkpoint_abort = DBConfig::GetSetting(db.GetDatabase()); + if (debug_checkpoint_abort == CheckpointAbort::DEBUG_ABORT_BEFORE_HEADER) { throw 
FatalException("Checkpoint aborted before header write because of PRAGMA checkpoint_abort flag"); } @@ -241,7 +242,7 @@ void SingleFileCheckpointWriter::CreateCheckpoint() { block_manager.VerifyBlocks(verify_block_usage_count); #endif - if (config.options.checkpoint_abort == CheckpointAbort::DEBUG_ABORT_BEFORE_TRUNCATE) { + if (debug_checkpoint_abort == CheckpointAbort::DEBUG_ABORT_BEFORE_TRUNCATE) { throw FatalException("Checkpoint aborted before truncate because of PRAGMA checkpoint_abort flag"); } @@ -594,7 +595,7 @@ void CheckpointReader::ReadTableData(CatalogTransaction transaction, Deserialize auto &reader = dynamic_cast(binary_deserializer.GetStream()); MetadataReader table_data_reader(reader.GetMetadataManager(), table_pointer); - TableDataReader data_reader(table_data_reader, bound_info); + TableDataReader data_reader(table_data_reader, bound_info, table_pointer); data_reader.ReadTableData(); bound_info.data->total_rows = total_rows; diff --git a/src/duckdb/src/storage/compression/dict_fsst/compression.cpp b/src/duckdb/src/storage/compression/dict_fsst/compression.cpp index dff33da74..580a5cfc5 100644 --- a/src/duckdb/src/storage/compression/dict_fsst/compression.cpp +++ b/src/duckdb/src/storage/compression/dict_fsst/compression.cpp @@ -305,12 +305,14 @@ static inline bool RequiresHigherBitWidth(bitpacking_width_t bitwidth, uint32_t } template -static inline bool AddLookup(DictFSSTCompressionState &state, idx_t lookup, const bool recalculate_indices_space) { +static inline bool AddLookup(DictFSSTCompressionState &state, idx_t lookup, const bool recalculate_indices_space, + bool fail_on_no_space) { D_ASSERT(lookup != DConstants::INVALID_INDEX); //! This string exists in the dictionary idx_t new_dictionary_indices_space = state.dictionary_indices_space; - if (APPEND_STATE != DictionaryAppendState::ENCODED_ALL_UNIQUE && recalculate_indices_space) { + auto get_bitpacking_size = APPEND_STATE != DictionaryAppendState::ENCODED_ALL_UNIQUE && recalculate_indices_space; + if (get_bitpacking_size) { new_dictionary_indices_space = BitpackingPrimitives::GetRequiredSize(state.tuple_count + 1, state.dictionary_indices_width); } @@ -336,6 +338,12 @@ static inline bool AddLookup(DictFSSTCompressionState &state, idx_t lookup, cons available_space -= FSST_SYMBOL_TABLE_SIZE; } if (required_space > available_space) { + if (fail_on_no_space) { + throw FatalException("AddLookup in DictFSST failed: required: %d, available: %d, indices: %d, bitpacking: " + "%b, dict offset: %d, str length: %d", + required_space, available_space, new_dictionary_indices_space, get_bitpacking_size, + state.dictionary_offset, state.string_lengths_space); + } return false; } @@ -349,7 +357,7 @@ static inline bool AddLookup(DictFSSTCompressionState &state, idx_t lookup, cons template static inline bool AddToDictionary(DictFSSTCompressionState &state, const string_t &str, - const bool recalculate_indices_space) { + const bool recalculate_indices_space, bool fail_on_no_space) { uint32_t str_len = UnsafeNumericCast(str.GetSize()); if (APPEND_STATE == DictionaryAppendState::ENCODED) { //! We delay encoding of new entries. 
@@ -413,6 +421,12 @@ static inline bool AddToDictionary(DictFSSTCompressionState &state, const string available_space -= FSST_SYMBOL_TABLE_SIZE; } if (required_space > available_space) { + if (fail_on_no_space) { + throw FatalException("AddToDictionary in DictFSST failed: required: %d, available: %d, dict offset + " + "str_len: %d, new str length: %d, new dict indices: %d", + required_space, available_space, state.dictionary_offset + str_len, + new_string_lengths_space, new_dictionary_indices_space); + } return false; } @@ -422,7 +436,7 @@ static inline bool AddToDictionary(DictFSSTCompressionState &state, const string if (str.IsInlined()) { state.dictionary_encoding_buffer.push_back(str); } else { - state.dictionary_encoding_buffer.push_back(state.uncompressed_dictionary_copy.AddString(str)); + state.dictionary_encoding_buffer.push_back(state.uncompressed_dictionary_copy.AddBlob(str)); } if (!state.to_encode_string_sum) { //! As specified in fsst.h @@ -461,7 +475,8 @@ static inline bool AddToDictionary(DictFSSTCompressionState &state, const string } bool DictFSSTCompressionState::CompressInternal(UnifiedVectorFormat &vector_format, const string_t &str, bool is_null, - EncodedInput &encoded_input, const idx_t i, idx_t count) { + EncodedInput &encoded_input, const idx_t i, idx_t count, + bool fail_on_no_space) { auto strings = UnifiedVectorFormat::GetData(vector_format); idx_t lookup = DConstants::INVALID_INDEX; @@ -484,17 +499,21 @@ bool DictFSSTCompressionState::CompressInternal(UnifiedVectorFormat &vector_form case DictionaryAppendState::REGULAR: { if (append_state == DictionaryAppendState::REGULAR) { if (lookup != DConstants::INVALID_INDEX) { - return AddLookup(*this, lookup, recalculate_indices_space); + return AddLookup(*this, lookup, recalculate_indices_space, + fail_on_no_space); } else { //! This string does not exist in the dictionary, add it - return AddToDictionary(*this, str, recalculate_indices_space); + return AddToDictionary(*this, str, recalculate_indices_space, + fail_on_no_space); } } else { if (lookup != DConstants::INVALID_INDEX) { - return AddLookup(*this, lookup, recalculate_indices_space); + return AddLookup(*this, lookup, recalculate_indices_space, + fail_on_no_space); } else { //! This string does not exist in the dictionary, add it - return AddToDictionary(*this, str, recalculate_indices_space); + return AddToDictionary(*this, str, recalculate_indices_space, + fail_on_no_space); } } } @@ -505,10 +524,12 @@ bool DictFSSTCompressionState::CompressInternal(UnifiedVectorFormat &vector_form bool fits; if (lookup != DConstants::INVALID_INDEX) { - fits = AddLookup(*this, lookup, recalculate_indices_space); + fits = + AddLookup(*this, lookup, recalculate_indices_space, fail_on_no_space); } else { //! Not in the dictionary, add it - fits = AddToDictionary(*this, str, recalculate_indices_space); + fits = AddToDictionary(*this, str, recalculate_indices_space, + fail_on_no_space); } if (fits) { return fits; @@ -523,10 +544,12 @@ bool DictFSSTCompressionState::CompressInternal(UnifiedVectorFormat &vector_form // we flush these and try again to see if the size went down enough FlushEncodingBuffer(); if (lookup != DConstants::INVALID_INDEX) { - return AddLookup(*this, lookup, recalculate_indices_space); + return AddLookup(*this, lookup, recalculate_indices_space, + fail_on_no_space); } else { //! 
Not in the dictionary, add it - return AddToDictionary(*this, str, recalculate_indices_space); + return AddToDictionary(*this, str, recalculate_indices_space, + fail_on_no_space); } } case DictionaryAppendState::ENCODED_ALL_UNIQUE: { @@ -535,8 +558,7 @@ bool DictFSSTCompressionState::CompressInternal(UnifiedVectorFormat &vector_form #ifdef DEBUG auto temp_decoder = alloca(sizeof(duckdb_fsst_decoder_t)); - duckdb_fsst_import((duckdb_fsst_decoder_t *)temp_decoder, fsst_serialized_symbol_table.get()); - + duckdb_fsst_import(reinterpret_cast(temp_decoder), fsst_serialized_symbol_table.get()); vector decompress_buffer; #endif @@ -589,12 +611,12 @@ bool DictFSSTCompressionState::CompressInternal(UnifiedVectorFormat &vector_form //! Verify that we can decompress the string auto &uncompressed_str = strings[encoded_input.offset + j]; decompress_buffer.resize(uncompressed_str.GetSize() + 1 + 100); - auto decoded_std_string = - FSSTPrimitives::DecompressValue((void *)temp_decoder, (const char *)compressed_ptrs[j], - (idx_t)compressed_sizes[j], decompress_buffer); + auto decoded_std_string = FSSTPrimitives::DecompressValue( + (void *)temp_decoder, reinterpret_cast(compressed_ptrs[j]), + (idx_t)compressed_sizes[j], decompress_buffer); D_ASSERT(decoded_std_string.size() == uncompressed_str.GetSize()); - string_t decompressed_string((const char *)decompress_buffer.data(), + string_t decompressed_string(reinterpret_cast(decompress_buffer.data()), UnsafeNumericCast(uncompressed_str.GetSize())); D_ASSERT(decompressed_string == uncompressed_str); #endif @@ -615,14 +637,15 @@ bool DictFSSTCompressionState::CompressInternal(UnifiedVectorFormat &vector_form compressed_string.GetSize(), decompress_buffer); D_ASSERT(decoded_std_string.size() == uncompressed_string.GetSize()); - string_t decompressed_string((const char *)decompress_buffer.data(), + string_t decompressed_string(reinterpret_cast(decompress_buffer.data()), UnsafeNumericCast(uncompressed_string.GetSize())); D_ASSERT(decompressed_string == uncompressed_string); } #endif auto &string = encoded_input.data[i - encoded_input.offset]; - return AddToDictionary(*this, string, recalculate_indices_space); + return AddToDictionary(*this, string, recalculate_indices_space, + fail_on_no_space); } }; throw InternalException("Unreachable"); @@ -820,22 +843,22 @@ void DictFSSTCompressionState::Compress(Vector &scan_vector, idx_t count) { auto &str = strings[idx]; auto is_null = !vector_format.validity.RowIsValid(idx); do { - if (CompressInternal(vector_format, str, is_null, encoded_input, i, count)) { + if (CompressInternal(vector_format, str, is_null, encoded_input, i, count, false)) { break; } if (append_state == DictionaryAppendState::REGULAR) { append_state = TryEncode(); D_ASSERT(append_state != DictionaryAppendState::REGULAR); - if (CompressInternal(vector_format, str, is_null, encoded_input, i, count)) { + if (CompressInternal(vector_format, str, is_null, encoded_input, i, count, false)) { break; } } Flush(false); encoded_input.data.clear(); encoded_input.offset = 0; - if (!CompressInternal(vector_format, str, is_null, encoded_input, i, count)) { - throw FatalException("Compressing directly after Flush doesn't fit"); + if (!CompressInternal(vector_format, str, is_null, encoded_input, i, count, true)) { + throw FatalException("Compressing directly after Flush doesn't fit - expected to throw earlier!"); } } while (false); if (!is_null) { diff --git a/src/duckdb/src/storage/compression/dict_fsst/decompression.cpp 
b/src/duckdb/src/storage/compression/dict_fsst/decompression.cpp index fd6640e68..0546096bb 100644 --- a/src/duckdb/src/storage/compression/dict_fsst/decompression.cpp +++ b/src/duckdb/src/storage/compression/dict_fsst/decompression.cpp @@ -219,7 +219,6 @@ bool CompressedStringScanState::AllowDictionaryScan(idx_t scan_count) { void CompressedStringScanState::ScanToDictionaryVector(ColumnSegment &segment, Vector &result, idx_t result_offset, idx_t start, idx_t scan_count) { - D_ASSERT(start % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE == 0); D_ASSERT(scan_count == STANDARD_VECTOR_SIZE); D_ASSERT(result_offset == 0); diff --git a/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp b/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp index 24919388d..afd335dab 100644 --- a/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp +++ b/src/duckdb/src/storage/compression/fixed_size_uncompressed.cpp @@ -72,9 +72,12 @@ void UncompressedCompressState::CreateEmptySegment(idx_t row_start) { info.GetBlockManager()); if (type.InternalType() == PhysicalType::VARCHAR) { auto &state = compressed_segment->GetSegmentState()->Cast(); - auto &partial_block_manager = checkpoint_data.GetCheckpointState().GetPartialBlockManager(); - state.block_manager = partial_block_manager.GetBlockManager(); - state.overflow_writer = make_uniq(partial_block_manager); + auto &storage_manager = checkpoint_data.GetStorageManager(); + if (!storage_manager.InMemory()) { + auto &partial_block_manager = checkpoint_data.GetCheckpointState().GetPartialBlockManager(); + state.block_manager = partial_block_manager.GetBlockManager(); + state.overflow_writer = make_uniq(partial_block_manager); + } } current_segment = std::move(compressed_segment); current_segment->InitializeAppend(append_state); @@ -84,8 +87,10 @@ void UncompressedCompressState::FlushSegment(idx_t segment_size) { auto &state = checkpoint_data.GetCheckpointState(); if (current_segment->type.InternalType() == PhysicalType::VARCHAR) { auto &segment_state = current_segment->GetSegmentState()->Cast(); - segment_state.overflow_writer->Flush(); - segment_state.overflow_writer.reset(); + if (segment_state.overflow_writer) { + segment_state.overflow_writer->Flush(); + segment_state.overflow_writer.reset(); + } } append_state.child_appends.clear(); append_state.append_state.reset(); diff --git a/src/duckdb/src/storage/compression/fsst.cpp b/src/duckdb/src/storage/compression/fsst.cpp index 800079a80..cbb3b3ac7 100644 --- a/src/duckdb/src/storage/compression/fsst.cpp +++ b/src/duckdb/src/storage/compression/fsst.cpp @@ -8,8 +8,10 @@ #include "duckdb/storage/checkpoint/write_overflow_strings_to_disk.hpp" #include "duckdb/storage/string_uncompressed.hpp" #include "duckdb/storage/table/column_data_checkpointer.hpp" +#include "duckdb/main/settings.hpp" #include "fsst.h" + #include "miniz_wrapper.hpp" namespace duckdb { @@ -644,8 +646,7 @@ void FSSTStorage::StringScanPartial(ColumnSegment &segment, ColumnScanState &sta bool enable_fsst_vectors; if (ALLOW_FSST_VECTORS) { - auto &config = DBConfig::GetConfig(segment.db); - enable_fsst_vectors = config.options.enable_fsst_vectors; + enable_fsst_vectors = DBConfig::GetSetting(segment.db); } else { enable_fsst_vectors = false; } diff --git a/src/duckdb/src/storage/compression/rle.cpp b/src/duckdb/src/storage/compression/rle.cpp index efac9d52a..57ebaf1fa 100644 --- a/src/duckdb/src/storage/compression/rle.cpp +++ b/src/duckdb/src/storage/compression/rle.cpp @@ -134,7 +134,7 @@ struct RLECompressState : 
public CompressionState { idx_t MaxRLECount() { auto entry_size = sizeof(T) + sizeof(rle_count_t); - return (info.GetBlockSize() - RLEConstants::RLE_HEADER_SIZE) / entry_size; + return AlignValueFloor((info.GetBlockSize() - RLEConstants::RLE_HEADER_SIZE) / entry_size); } RLECompressState(ColumnDataCheckpointData &checkpoint_data_p, const CompressionInfo &info) diff --git a/src/duckdb/src/storage/data_table.cpp b/src/duckdb/src/storage/data_table.cpp index cdd98cebb..9bd57d9ba 100644 --- a/src/duckdb/src/storage/data_table.cpp +++ b/src/duckdb/src/storage/data_table.cpp @@ -9,6 +9,7 @@ #include "duckdb/common/types/constraint_conflict_info.hpp" #include "duckdb/common/vector_operations/vector_operations.hpp" #include "duckdb/execution/expression_executor.hpp" +#include "duckdb/execution/index/unbound_index.hpp" #include "duckdb/main/attached_database.hpp" #include "duckdb/main/client_context.hpp" #include "duckdb/parser/constraints/list.hpp" @@ -37,8 +38,8 @@ DataTableInfo::DataTableInfo(AttachedDatabase &db, shared_ptr ta : db(db), table_io_manager(std::move(table_io_manager_p)), schema(std::move(schema)), table(std::move(table)) { } -void DataTableInfo::InitializeIndexes(ClientContext &context, const char *index_type) { - indexes.InitializeIndexes(context, *this, index_type); +void DataTableInfo::BindIndexes(ClientContext &context, const char *index_type) { + indexes.Bind(context, *this, index_type); } bool DataTableInfo::IsTemporary() const { @@ -98,7 +99,8 @@ DataTable::DataTable(ClientContext &context, DataTable &parent, idx_t removed_co column_definitions.emplace_back(column_def.Copy()); } - info->InitializeIndexes(context); + // Bind all indexes. + info->BindIndexes(context); // first check if there are any indexes that exist that point to the removed column info->indexes.Scan([&](Index &index) { @@ -145,7 +147,9 @@ DataTable::DataTable(ClientContext &context, DataTable &parent, BoundConstraint for (const auto &index_info : parent.info->index_storage_infos) { info->index_storage_infos.push_back(IndexStorageInfo(index_info.name)); } - info->InitializeIndexes(context); + + // Bind all indexes. + info->BindIndexes(context); auto &local_storage = LocalStorage::Get(context, db); lock_guard parent_lock(parent.append_lock); @@ -163,6 +167,7 @@ DataTable::DataTable(ClientContext &context, DataTable &parent, BoundConstraint DataTable::DataTable(ClientContext &context, DataTable &parent, idx_t changed_idx, const LogicalType &target_type, const vector &bound_columns, Expression &cast_expr) : db(parent.db), info(parent.info), version(DataTableVersion::MAIN_TABLE) { + auto &local_storage = LocalStorage::Get(context, db); // prevent any tuples from being added to the parent lock_guard lock(append_lock); @@ -170,7 +175,8 @@ DataTable::DataTable(ClientContext &context, DataTable &parent, idx_t changed_id column_definitions.emplace_back(column_def.Copy()); } - info->InitializeIndexes(context); + // Bind all indexes. 
+ info->BindIndexes(context); // first check if there are any indexes that exist that point to the changed column info->indexes.Scan([&](Index &index) { @@ -187,7 +193,7 @@ DataTable::DataTable(ClientContext &context, DataTable &parent, idx_t changed_id // set up the statistics for the table // the column that had its type changed will have the new statistics computed during conversion - this->row_groups = parent.row_groups->AlterType(context, changed_idx, target_type, bound_columns, cast_expr); + row_groups = parent.row_groups->AlterType(context, changed_idx, target_type, bound_columns, cast_expr); // scan the original table, and fill the new column with the transformed value local_storage.ChangeType(parent, *this, changed_idx, target_type, bound_columns, cast_expr); @@ -311,8 +317,8 @@ shared_ptr &DataTable::GetDataTableInfo() { return info; } -void DataTable::InitializeIndexes(ClientContext &context) { - info->InitializeIndexes(context); +void DataTable::BindIndexes(ClientContext &context) { + info->BindIndexes(context); } bool DataTable::HasIndexes() const { @@ -677,13 +683,15 @@ void DataTable::VerifyUniqueIndexes(TableIndexList &indexes, optional_ptr manager) { // Verify the constraint without a conflict manager. if (!manager) { - return indexes.ScanBound([&](ART &art) { - if (!art.IsUnique()) { + return indexes.Scan([&](Index &index) { + if (!index.IsUnique() || index.GetIndexType() != ART::TYPE_NAME) { return false; } - + D_ASSERT(index.IsBound()); + auto &art = index.Cast(); if (storage) { auto delete_index = storage->delete_indexes.Find(art.GetIndexName()); + D_ASSERT(!delete_index || delete_index->IsBound()); IndexAppendInfo index_append_info(IndexAppendMode::DEFAULT, delete_index); art.VerifyAppend(chunk, index_append_info, nullptr); } else { @@ -698,16 +706,18 @@ void DataTable::VerifyUniqueIndexes(TableIndexList &indexes, optional_ptrGetConflictInfo(); // Find all indexes matching the conflict target. 
- indexes.ScanBound([&](ART &art) { - if (!art.IsUnique()) { + indexes.Scan([&](Index &index) { + if (!index.IsUnique() || index.GetIndexType() != ART::TYPE_NAME) { return false; } - if (!conflict_info.ConflictTargetMatches(art)) { + if (!conflict_info.ConflictTargetMatches(index)) { return false; } - + D_ASSERT(index.IsBound()); + auto &art = index.Cast(); if (storage) { auto delete_index = storage->delete_indexes.Find(art.GetIndexName()); + D_ASSERT(!delete_index || delete_index->IsBound()); manager->AddIndex(art, delete_index); } else { manager->AddIndex(art, nullptr); @@ -727,16 +737,18 @@ void DataTable::VerifyUniqueIndexes(TableIndexList &indexes, optional_ptrSetMode(ConflictManagerMode::THROW); - indexes.ScanBound([&](ART &art) { - if (!art.IsUnique()) { + indexes.Scan([&](Index &index) { + if (!index.IsUnique() || index.GetIndexType() != ART::TYPE_NAME) { return false; } - if (manager->IndexMatches(art)) { + if (manager->IndexMatches(index.Cast())) { return false; } - + D_ASSERT(index.IsBound()); + auto &art = index.Cast(); if (storage) { auto delete_index = storage->delete_indexes.Find(art.GetIndexName()); + D_ASSERT(!delete_index || delete_index->IsBound()); IndexAppendInfo index_append_info(IndexAppendMode::DEFAULT, delete_index); art.VerifyAppend(chunk, index_append_info, *manager); } else { @@ -1132,7 +1144,7 @@ void DataTable::RevertAppend(DuckTransaction &transaction, idx_t start_row, idx_ row_data[i] = NumericCast(current_row_base + i); } info->indexes.Scan([&](Index &index) { - // We cant add to unbound indexes anyways, so there is no need to revert them + // We cannot add to unbound indexes, so there is no need to revert them. if (index.IsBound()) { index.Cast().Delete(chunk, row_identifiers); } @@ -1161,35 +1173,39 @@ void DataTable::RevertAppend(DuckTransaction &transaction, idx_t start_row, idx_ //===--------------------------------------------------------------------===// ErrorData DataTable::AppendToIndexes(TableIndexList &indexes, optional_ptr delete_indexes, DataChunk &chunk, row_t row_start, const IndexAppendMode index_append_mode) { - ErrorData error; if (indexes.Empty()) { - return error; + return ErrorData(); } - // first generate the vector of row identifiers + // Generate the vector of row identifiers. Vector row_ids(LogicalType::ROW_TYPE); VectorOperations::GenerateSequence(row_ids, chunk.size(), row_start, 1); - vector already_appended; + vector> already_appended; bool append_failed = false; - // now append the entries to the indices - indexes.Scan([&](Index &index_to_append) { - if (!index_to_append.IsBound()) { - throw InternalException("unbound index in DataTable::AppendToIndexes"); + + // Append the entries to the indexes. + ErrorData error; + indexes.Scan([&](Index &index) { + if (!index.IsBound()) { + auto &unbound_index = index.Cast(); + unbound_index.BufferChunk(chunk, row_ids); + return false; } - auto &index = index_to_append.Cast(); + + auto &bound_index = index.Cast(); // Find the matching delete index. 
optional_ptr delete_index; - if (index.IsUnique()) { + if (bound_index.IsUnique()) { if (delete_indexes) { - delete_index = delete_indexes->Find(index.name); + delete_index = delete_indexes->Find(bound_index.name); } } try { IndexAppendInfo index_append_info(index_append_mode, delete_index); - error = index.Append(chunk, row_ids, index_append_info); + error = bound_index.Append(chunk, row_ids, index_append_info); } catch (std::exception &ex) { error = ErrorData(ex); } @@ -1198,15 +1214,15 @@ ErrorData DataTable::AppendToIndexes(TableIndexList &indexes, optional_ptrDelete(chunk, row_ids); + // Constraint violation: remove any appended entries from previous indexes (if any). + for (auto index : already_appended) { + index.get().Delete(chunk, row_ids); } } return error; @@ -1295,8 +1311,8 @@ void DataTable::VerifyDeleteConstraints(optional_ptr storage, unique_ptr DataTable::InitializeDelete(TableCatalogEntry &table, ClientContext &context, const vector> &bound_constraints) { - // initialize indexes (if any) - info->InitializeIndexes(context); + // Bind all indexes. + info->BindIndexes(context); auto binder = Binder::CreateBinder(context); vector types; @@ -1460,9 +1476,8 @@ void DataTable::VerifyUpdateConstraints(ConstraintState &state, ClientContext &c unique_ptr DataTable::InitializeUpdate(TableCatalogEntry &table, ClientContext &context, const vector> &bound_constraints) { - // check that there are no unknown indexes - info->InitializeIndexes(context); - + // Bind all indexes. + info->BindIndexes(context); auto result = make_uniq(); result->constraint_state = InitializeConstraintState(table, bound_constraints); return result; diff --git a/src/duckdb/src/storage/local_storage.cpp b/src/duckdb/src/storage/local_storage.cpp index bc1197424..f4adf9419 100644 --- a/src/duckdb/src/storage/local_storage.cpp +++ b/src/duckdb/src/storage/local_storage.cpp @@ -1,16 +1,17 @@ #include "duckdb/transaction/local_storage.hpp" -#include "duckdb/execution/index/art/art.hpp" -#include "duckdb/storage/table/append_state.hpp" -#include "duckdb/storage/write_ahead_log.hpp" + #include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/storage/table/row_group.hpp" -#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/execution/index/art/art.hpp" +#include "duckdb/execution/index/unbound_index.hpp" #include "duckdb/planner/table_filter.hpp" #include "duckdb/storage/partial_block_manager.hpp" - +#include "duckdb/storage/table/append_state.hpp" #include "duckdb/storage/table/column_checkpoint_state.hpp" -#include "duckdb/storage/table_io_manager.hpp" +#include "duckdb/storage/table/row_group.hpp" #include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/storage/table_io_manager.hpp" +#include "duckdb/storage/write_ahead_log.hpp" +#include "duckdb/transaction/duck_transaction.hpp" namespace duckdb { @@ -24,11 +25,18 @@ LocalTableStorage::LocalTableStorage(ClientContext &context, DataTable &table) row_groups = make_shared_ptr(data_table_info, io_manager, types, MAX_ROW_ID, 0); row_groups->InitializeEmpty(); - data_table_info->GetIndexes().BindAndScan(context, *data_table_info, [&](ART &art) { - auto constraint_type = art.GetConstraintType(); - if (constraint_type == IndexConstraintType::NONE) { + data_table_info->GetIndexes().Scan([&](Index &index) { + auto constraint = index.GetConstraintType(); + if (constraint == IndexConstraintType::NONE) { return false; } + if (index.GetIndexType() != ART::TYPE_NAME) { + return false; + } + if (!index.IsBound()) { + return false; + } 
+ auto &art = index.Cast(); // UNIQUE constraint. vector> expressions; @@ -39,13 +47,15 @@ LocalTableStorage::LocalTableStorage(ClientContext &context, DataTable &table) } // Create a delete index and a local index. - auto delete_index = make_uniq(art.GetIndexName(), constraint_type, art.GetColumnIds(), - art.table_io_manager, std::move(delete_expressions), art.db); + auto &name = art.GetIndexName(); + auto &io_manager = art.table_io_manager; + auto delete_index = + make_uniq(name, constraint, art.GetColumnIds(), io_manager, std::move(delete_expressions), art.db); delete_indexes.AddIndex(std::move(delete_index)); - auto index = make_uniq(art.GetIndexName(), constraint_type, art.GetColumnIds(), art.table_io_manager, - std::move(expressions), art.db); - append_indexes.AddIndex(std::move(index)); + auto append_index = + make_uniq(name, constraint, art.GetColumnIds(), io_manager, std::move(expressions), art.db); + append_indexes.AddIndex(std::move(append_index)); return false; }); } @@ -114,7 +124,9 @@ idx_t LocalTableStorage::EstimatedSize() { // get the index size idx_t index_sizes = 0; append_indexes.Scan([&](Index &index) { - D_ASSERT(index.IsBound()); + if (!index.IsBound()) { + return false; + } index_sizes += index.Cast().GetInMemorySize(); return false; }); @@ -381,13 +393,11 @@ bool LocalStorage::NextParallelScan(ClientContext &context, DataTable &table, Pa } void LocalStorage::InitializeAppend(LocalAppendState &state, DataTable &table) { - table.InitializeIndexes(context); state.storage = &table_manager.GetOrCreateStorage(context, table); state.storage->row_groups->InitializeAppend(TransactionData(transaction), state.append_state); } void LocalStorage::InitializeStorage(LocalAppendState &state, DataTable &table) { - table.InitializeIndexes(context); state.storage = &table_manager.GetOrCreateStorage(context, table); } @@ -396,12 +406,13 @@ void LocalTableStorage::AppendToDeleteIndexes(Vector &row_ids, DataChunk &delete return; } - delete_indexes.ScanBound([&](ART &art) { - if (!art.IsUnique()) { + delete_indexes.Scan([&](Index &index) { + D_ASSERT(index.IsBound()); + if (!index.IsUnique()) { return false; } IndexAppendInfo index_append_info(IndexAppendMode::IGNORE_DUPLICATES, nullptr); - auto result = art.Cast().Append(delete_chunk, row_ids, index_append_info); + auto result = index.Cast().Append(delete_chunk, row_ids, index_append_info); if (result.HasError()) { throw InternalException("unexpected constraint violation on delete ART: ", result.Message()); } @@ -516,10 +527,9 @@ void LocalStorage::Flush(DataTable &table, LocalTableStorage &storage, optional_ storage.Rollback(); return; } - idx_t append_count = storage.row_groups->GetTotalRows() - storage.deleted_rows; - table.InitializeIndexes(context); - const idx_t row_group_size = storage.row_groups->GetRowGroupSize(); + auto append_count = storage.row_groups->GetTotalRows() - storage.deleted_rows; + const auto row_group_size = storage.row_groups->GetRowGroupSize(); TableAppendState append_state; table.AppendLock(append_state); @@ -529,9 +539,7 @@ void LocalStorage::Flush(DataTable &table, LocalTableStorage &storage, optional_ // table is currently empty OR we are bulk appending: move over the storage directly // first flush any outstanding blocks storage.FlushBlocks(); - // now append to the indexes (if there are any) - // FIXME: we should be able to merge the transaction-local index directly into the main table index - // as long we just rewrite some row-ids + // Append to the indexes. 
if (table.HasIndexes()) { storage.AppendToIndexes(transaction, append_state, false); } diff --git a/src/duckdb/src/storage/metadata/metadata_manager.cpp b/src/duckdb/src/storage/metadata/metadata_manager.cpp index 824c5bb40..e988ab99a 100644 --- a/src/duckdb/src/storage/metadata/metadata_manager.cpp +++ b/src/duckdb/src/storage/metadata/metadata_manager.cpp @@ -9,6 +9,28 @@ namespace duckdb { +MetadataBlock::MetadataBlock() : block_id(INVALID_BLOCK), dirty(false) { +} + +MetadataBlock::MetadataBlock(MetadataBlock &&other) noexcept : dirty(false) { + std::swap(block, other.block); + std::swap(block_id, other.block_id); + std::swap(free_blocks, other.free_blocks); + auto dirty_val = dirty.load(); + dirty = other.dirty.load(); + other.dirty = dirty_val; +} + +MetadataBlock &MetadataBlock::operator=(MetadataBlock &&other) noexcept { + std::swap(block, other.block); + std::swap(block_id, other.block_id); + std::swap(free_blocks, other.free_blocks); + auto dirty_val = dirty.load(); + dirty = other.dirty.load(); + other.dirty = dirty_val; + return *this; +} + MetadataManager::MetadataManager(BlockManager &block_manager, BufferManager &buffer_manager) : block_manager(block_manager), buffer_manager(buffer_manager) { } @@ -37,6 +59,8 @@ MetadataHandle MetadataManager::AllocateHandle() { MetadataPointer pointer; pointer.block_index = UnsafeNumericCast(free_block); auto &block = blocks[free_block]; + // the block is now dirty + block.dirty = true; if (block.block->BlockId() < MAXIMUM_BLOCK) { // this block is a disk-backed block, yet we are planning to write to it // we need to convert it into a transient block before we can write to it @@ -55,6 +79,13 @@ MetadataHandle MetadataManager::AllocateHandle() { MetadataHandle MetadataManager::Pin(const MetadataPointer &pointer) { D_ASSERT(pointer.index < METADATA_BLOCK_COUNT); auto &block = blocks[UnsafeNumericCast(pointer.block_index)]; +#ifdef DEBUG + for (auto &free_block : block.free_blocks) { + if (free_block == pointer.index) { + throw InternalException("Pinning block %d.%d but it is marked as a free block", block.block_id, free_block); + } + } +#endif MetadataHandle handle; handle.pointer.block_index = pointer.block_index; @@ -74,6 +105,7 @@ void MetadataManager::ConvertToTransient(MetadataBlock &metadata_block) { // copy the data to the transient block memcpy(new_buffer.Ptr(), old_buffer.Ptr(), block_manager.GetBlockSize()); metadata_block.block = std::move(new_block); + metadata_block.dirty = true; // unregister the old block block_manager.UnregisterBlock(metadata_block.block_id); @@ -89,6 +121,7 @@ block_id_t MetadataManager::AllocateNewBlock() { for (idx_t i = 0; i < METADATA_BLOCK_COUNT; i++) { new_block.free_blocks.push_back(NumericCast(METADATA_BLOCK_COUNT - i - 1)); } + new_block.dirty = true; // zero-initialize the handle memset(handle.Ptr(), 0, block_manager.GetBlockSize()); AddBlock(std::move(new_block)); @@ -109,6 +142,9 @@ void MetadataManager::AddAndRegisterBlock(MetadataBlock block) { if (block.block) { throw InternalException("Calling AddAndRegisterBlock on block that already exists"); } + if (block.block_id >= MAXIMUM_BLOCK) { + throw InternalException("AddAndRegisterBlock called with a transient block id"); + } block.block = block_manager.RegisterBlock(block.block_id); AddBlock(std::move(block), true); } @@ -145,7 +181,7 @@ MetadataPointer MetadataManager::RegisterDiskPointer(MetaBlockPointer pointer) { auto block_id = pointer.GetBlockId(); MetadataBlock block; block.block_id = block_id; - AddAndRegisterBlock(block); + 
AddAndRegisterBlock(std::move(block)); return FromDiskPointer(pointer); } @@ -181,6 +217,12 @@ void MetadataManager::Flush() { for (auto &kv : blocks) { auto &block = kv.second; + if (!block.dirty) { + if (block.block->BlockId() >= MAXIMUM_BLOCK) { + throw InternalException("Transient blocks must always be marked as dirty"); + } + continue; + } auto handle = buffer_manager.Pin(block.block); // zero-initialize the few leftover bytes memset(handle.Ptr() + total_metadata_size, 0, block_manager.GetBlockSize() - total_metadata_size); @@ -189,11 +231,13 @@ void MetadataManager::Flush() { // Convert the temporary block to a persistent block. block.block = block_manager.ConvertToPersistent(QueryContext(), kv.first, std::move(block.block), std::move(handle)); - continue; + } else { + // Already a persistent block, so we only need to write it. + D_ASSERT(block.block->BlockId() == block.block_id); + block_manager.Write(QueryContext(), handle.GetFileBuffer(), block.block_id); } - // Already a persistent block, so we only need to write it. - D_ASSERT(block.block->BlockId() == block.block_id); - block_manager.Write(QueryContext(), handle.GetFileBuffer(), block.block_id); + // the block is no longer dirty + block.dirty = false; } } @@ -299,8 +343,6 @@ void MetadataManager::ClearModifiedBlocks(const vector &pointe throw InternalException("ClearModifiedBlocks - Block id %llu not found in modified_blocks", block_id); } auto &modified_list = entry->second; - // verify the block has been modified - D_ASSERT(modified_list && (1ULL << block_index)); // unset the bit modified_list &= ~(1ULL << block_index); } diff --git a/src/duckdb/src/storage/metadata/metadata_reader.cpp b/src/duckdb/src/storage/metadata/metadata_reader.cpp index 833a6e656..060ec24ee 100644 --- a/src/duckdb/src/storage/metadata/metadata_reader.cpp +++ b/src/duckdb/src/storage/metadata/metadata_reader.cpp @@ -7,7 +7,6 @@ MetadataReader::MetadataReader(MetadataManager &manager, MetaBlockPointer pointe : manager(manager), type(type), next_pointer(FromDiskPointer(pointer)), has_next_block(true), read_pointers(read_pointers_p), index(0), offset(0), next_offset(pointer.offset), capacity(0) { if (read_pointers) { - D_ASSERT(read_pointers->empty()); read_pointers->push_back(pointer); } } @@ -47,9 +46,25 @@ void MetadataReader::ReadData(data_ptr_t buffer, idx_t read_size) { } MetaBlockPointer MetadataReader::GetMetaBlockPointer() { + if (capacity == 0) { + throw InternalException("GetMetaBlockPointer called but there is no active pointer"); + } return manager.GetDiskPointer(block.pointer, UnsafeNumericCast(offset)); } +vector MetadataReader::GetRemainingBlocks(MetaBlockPointer last_block) { + vector result; + while (has_next_block) { + auto next_block_pointer = manager.GetDiskPointer(next_pointer, UnsafeNumericCast(next_offset)); + if (last_block.IsValid() && next_block_pointer.block_pointer == last_block.block_pointer) { + break; + } + result.push_back(next_block_pointer); + ReadNextBlock(); + } + return result; +} + void MetadataReader::ReadNextBlock() { if (!has_next_block) { throw IOException("No more data remaining in MetadataReader"); diff --git a/src/duckdb/src/storage/metadata/metadata_writer.cpp b/src/duckdb/src/storage/metadata/metadata_writer.cpp index a059d7cca..69d8ea87e 100644 --- a/src/duckdb/src/storage/metadata/metadata_writer.cpp +++ b/src/duckdb/src/storage/metadata/metadata_writer.cpp @@ -30,6 +30,13 @@ MetaBlockPointer MetadataWriter::GetMetaBlockPointer() { return manager.GetDiskPointer(block.pointer, UnsafeNumericCast(offset)); } 
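The dirty flag introduced on MetadataBlock above turns MetadataManager::Flush into an incremental write: blocks are marked dirty when allocated, converted to transient, or written to; clean persistent blocks are skipped entirely; and a clean transient block is treated as an internal error. A simplified standalone sketch of that protocol, using stand-in types rather than DuckDB's classes:

// Hedged sketch of the dirty-flag flush protocol; types are illustrative.
#include <atomic>
#include <map>
#include <stdexcept>

struct Block {
	long block_id = -1;
	bool transient = true; // not yet backed by a persistent on-disk block id
	std::atomic<bool> dirty {false};
};

void Flush(std::map<long, Block> &blocks) {
	for (auto &kv : blocks) {
		auto &block = kv.second;
		if (!block.dirty) {
			if (block.transient) {
				throw std::logic_error("Transient blocks must always be marked as dirty");
			}
			continue; // clean persistent block: skip the write entirely
		}
		// ... pin the buffer, then either convert the transient block to a
		// persistent one or rewrite the existing persistent block ...
		block.transient = false;
		block.dirty = false; // the block stays clean until the next mutation
	}
}

int main() {
	std::map<long, Block> blocks;
	blocks[0].block_id = 0;
	blocks[0].transient = false; // clean persistent block: Flush skips it
	blocks[1].dirty = true;      // dirty transient block: Flush writes it
	Flush(blocks);
	return 0;
}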
+void MetadataWriter::SetWrittenPointers(optional_ptr> written_pointers_p) { + written_pointers = written_pointers_p; + if (written_pointers && capacity > 0) { + written_pointers->push_back(manager.GetDiskPointer(current_pointer)); + } +} + MetadataHandle MetadataWriter::NextHandle() { return manager.AllocateHandle(); } diff --git a/src/duckdb/src/storage/open_file_storage_extension.cpp b/src/duckdb/src/storage/open_file_storage_extension.cpp index 125f65b26..1daf5b5e6 100644 --- a/src/duckdb/src/storage/open_file_storage_extension.cpp +++ b/src/duckdb/src/storage/open_file_storage_extension.cpp @@ -42,9 +42,9 @@ class OpenFileDefaultGenerator : public DefaultGenerator { string file; }; -unique_ptr OpenFileStorageAttach(StorageExtensionInfo *storage_info, ClientContext &context, +unique_ptr OpenFileStorageAttach(optional_ptr storage_info, ClientContext &context, AttachedDatabase &db, const string &name, AttachInfo &info, - AccessMode access_mode) { + AttachOptions &attach_options) { auto file = info.path; // open an in-memory database info.path = ":memory:"; @@ -66,7 +66,7 @@ unique_ptr OpenFileStorageAttach(StorageExtensionInfo *storage_info, Cl return std::move(catalog); } -unique_ptr OpenFileStorageTransactionManager(StorageExtensionInfo *storage_info, +unique_ptr OpenFileStorageTransactionManager(optional_ptr storage_info, AttachedDatabase &db, Catalog &catalog) { return make_uniq(db); } diff --git a/src/duckdb/src/storage/partial_block_manager.cpp b/src/duckdb/src/storage/partial_block_manager.cpp index f0f330be1..27fe86cd3 100644 --- a/src/duckdb/src/storage/partial_block_manager.cpp +++ b/src/duckdb/src/storage/partial_block_manager.cpp @@ -1,4 +1,6 @@ #include "duckdb/storage/partial_block_manager.hpp" +#include "duckdb/storage/table/in_memory_checkpoint.hpp" +#include "duckdb/storage/table/column_checkpoint_state.hpp" namespace duckdb { @@ -15,6 +17,10 @@ void PartialBlock::AddUninitializedRegion(idx_t start, idx_t end) { uninitialized_regions.push_back({start, end}); } +void PartialBlock::AddSegmentToTail(ColumnData &data, ColumnSegment &segment, uint32_t offset_in_block) { + throw InternalException("PartialBlock::AddSegmentToTail not supported for this block type"); +} + void PartialBlock::FlushInternal(const idx_t free_space_left) { // ensure that we do not leak any data @@ -105,6 +111,14 @@ bool PartialBlockManager::GetPartialBlock(idx_t segment_size, unique_ptr PartialBlockManager::CreatePartialBlock(ColumnData &column_data, ColumnSegment &segment, + PartialBlockState state, BlockManager &block_manager) { + if (partial_block_type == PartialBlockType::IN_MEMORY_CHECKPOINT) { + return make_uniq(column_data, segment, state, block_manager); + } + return make_uniq(column_data, segment, state, block_manager); +} + void PartialBlockManager::RegisterPartialBlock(PartialBlockAllocation allocation) { auto &state = allocation.partial_block->state; D_ASSERT(partial_block_type != PartialBlockType::FULL_CHECKPOINT || state.block_id >= 0); @@ -120,6 +134,7 @@ void PartialBlockManager::RegisterPartialBlock(PartialBlockAllocation allocation // check if the block is STILL partially filled after adding the segment_size if (new_space_left >= block_manager.GetBlockSize() - max_partial_block_size) { // the block is still partially filled: add it to the partially_filled_blocks list + D_ASSERT(allocation.partial_block->state.offset > 0); partially_filled_blocks.insert(make_pair(new_space_left, std::move(allocation.partial_block))); } } diff --git 
a/src/duckdb/src/storage/serialization/serialize_logical_operator.cpp b/src/duckdb/src/storage/serialization/serialize_logical_operator.cpp index 06d526706..83c37ec4f 100644 --- a/src/duckdb/src/storage/serialization/serialize_logical_operator.cpp +++ b/src/duckdb/src/storage/serialization/serialize_logical_operator.cpp @@ -540,18 +540,18 @@ void LogicalInsert::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault(204, "table_index", table_index); serializer.WritePropertyWithDefault(205, "return_chunk", return_chunk); serializer.WritePropertyWithDefault>>(206, "bound_defaults", bound_defaults); - serializer.WriteProperty(207, "action_type", action_type); - serializer.WritePropertyWithDefault>(208, "expected_set_types", expected_set_types); - serializer.WritePropertyWithDefault>(209, "on_conflict_filter", on_conflict_filter); - serializer.WritePropertyWithDefault>(210, "on_conflict_condition", on_conflict_condition); - serializer.WritePropertyWithDefault>(211, "do_update_condition", do_update_condition); - serializer.WritePropertyWithDefault>(212, "set_columns", set_columns); - serializer.WritePropertyWithDefault>(213, "set_types", set_types); - serializer.WritePropertyWithDefault(214, "excluded_table_index", excluded_table_index); - serializer.WritePropertyWithDefault>(215, "columns_to_fetch", columns_to_fetch); - serializer.WritePropertyWithDefault>(216, "source_columns", source_columns); + serializer.WriteProperty(207, "action_type", on_conflict_info.action_type); + serializer.WritePropertyWithDefault>(208, "expected_set_types", on_conflict_info.expected_set_types); + serializer.WritePropertyWithDefault>(209, "on_conflict_filter", on_conflict_info.on_conflict_filter); + serializer.WritePropertyWithDefault>(210, "on_conflict_condition", on_conflict_info.on_conflict_condition); + serializer.WritePropertyWithDefault>(211, "do_update_condition", on_conflict_info.do_update_condition); + serializer.WritePropertyWithDefault>(212, "set_columns", on_conflict_info.set_columns); + serializer.WritePropertyWithDefault>(213, "set_types", on_conflict_info.set_types); + serializer.WritePropertyWithDefault(214, "excluded_table_index", on_conflict_info.excluded_table_index); + serializer.WritePropertyWithDefault>(215, "columns_to_fetch", on_conflict_info.columns_to_fetch); + serializer.WritePropertyWithDefault>(216, "source_columns", on_conflict_info.source_columns); serializer.WritePropertyWithDefault>>(217, "expressions", expressions); - serializer.WritePropertyWithDefault(218, "update_is_del_and_insert", update_is_del_and_insert, false); + serializer.WritePropertyWithDefault(218, "update_is_del_and_insert", on_conflict_info.update_is_del_and_insert, false); } unique_ptr LogicalInsert::Deserialize(Deserializer &deserializer) { @@ -563,18 +563,18 @@ unique_ptr LogicalInsert::Deserialize(Deserializer &deserialize deserializer.ReadPropertyWithDefault(204, "table_index", result->table_index); deserializer.ReadPropertyWithDefault(205, "return_chunk", result->return_chunk); deserializer.ReadPropertyWithDefault>>(206, "bound_defaults", result->bound_defaults); - deserializer.ReadProperty(207, "action_type", result->action_type); - deserializer.ReadPropertyWithDefault>(208, "expected_set_types", result->expected_set_types); - deserializer.ReadPropertyWithDefault>(209, "on_conflict_filter", result->on_conflict_filter); - deserializer.ReadPropertyWithDefault>(210, "on_conflict_condition", result->on_conflict_condition); - deserializer.ReadPropertyWithDefault>(211, "do_update_condition", 
result->do_update_condition); - deserializer.ReadPropertyWithDefault>(212, "set_columns", result->set_columns); - deserializer.ReadPropertyWithDefault>(213, "set_types", result->set_types); - deserializer.ReadPropertyWithDefault(214, "excluded_table_index", result->excluded_table_index); - deserializer.ReadPropertyWithDefault>(215, "columns_to_fetch", result->columns_to_fetch); - deserializer.ReadPropertyWithDefault>(216, "source_columns", result->source_columns); + deserializer.ReadProperty(207, "action_type", result->on_conflict_info.action_type); + deserializer.ReadPropertyWithDefault>(208, "expected_set_types", result->on_conflict_info.expected_set_types); + deserializer.ReadPropertyWithDefault>(209, "on_conflict_filter", result->on_conflict_info.on_conflict_filter); + deserializer.ReadPropertyWithDefault>(210, "on_conflict_condition", result->on_conflict_info.on_conflict_condition); + deserializer.ReadPropertyWithDefault>(211, "do_update_condition", result->on_conflict_info.do_update_condition); + deserializer.ReadPropertyWithDefault>(212, "set_columns", result->on_conflict_info.set_columns); + deserializer.ReadPropertyWithDefault>(213, "set_types", result->on_conflict_info.set_types); + deserializer.ReadPropertyWithDefault(214, "excluded_table_index", result->on_conflict_info.excluded_table_index); + deserializer.ReadPropertyWithDefault>(215, "columns_to_fetch", result->on_conflict_info.columns_to_fetch); + deserializer.ReadPropertyWithDefault>(216, "source_columns", result->on_conflict_info.source_columns); deserializer.ReadPropertyWithDefault>>(217, "expressions", result->expressions); - deserializer.ReadPropertyWithExplicitDefault(218, "update_is_del_and_insert", result->update_is_del_and_insert, false); + deserializer.ReadPropertyWithExplicitDefault(218, "update_is_del_and_insert", result->on_conflict_info.update_is_del_and_insert, false); return std::move(result); } @@ -616,6 +616,7 @@ void LogicalMergeInto::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault(203, "row_id_start", row_id_start); serializer.WriteProperty(204, "source_marker", source_marker); serializer.WritePropertyWithDefault>>>(205, "actions", actions); + serializer.WritePropertyWithDefault(206, "return_chunk", return_chunk); } unique_ptr LogicalMergeInto::Deserialize(Deserializer &deserializer) { @@ -626,6 +627,7 @@ unique_ptr LogicalMergeInto::Deserialize(Deserializer &deserial deserializer.ReadPropertyWithDefault(203, "row_id_start", result->row_id_start); deserializer.ReadProperty(204, "source_marker", result->source_marker); deserializer.ReadPropertyWithDefault>>>(205, "actions", result->actions); + deserializer.ReadPropertyWithDefault(206, "return_chunk", result->return_chunk); return std::move(result); } diff --git a/src/duckdb/src/storage/serialization/serialize_tableref.cpp b/src/duckdb/src/storage/serialization/serialize_tableref.cpp index 97e5c3fda..69af270d3 100644 --- a/src/duckdb/src/storage/serialization/serialize_tableref.cpp +++ b/src/duckdb/src/storage/serialization/serialize_tableref.cpp @@ -138,6 +138,9 @@ void JoinRef::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault>(205, "using_columns", using_columns); serializer.WritePropertyWithDefault(206, "delim_flipped", delim_flipped); serializer.WritePropertyWithDefault>>(207, "duplicate_eliminated_columns", duplicate_eliminated_columns); + if (serializer.ShouldSerialize(6)) { + serializer.WritePropertyWithDefault(208, "is_implicit", is_implicit, true); + } } unique_ptr 
JoinRef::Deserialize(Deserializer &deserializer) { @@ -150,6 +153,7 @@ unique_ptr JoinRef::Deserialize(Deserializer &deserializer) { deserializer.ReadPropertyWithDefault>(205, "using_columns", result->using_columns); deserializer.ReadPropertyWithDefault(206, "delim_flipped", result->delim_flipped); deserializer.ReadPropertyWithDefault>>(207, "duplicate_eliminated_columns", result->duplicate_eliminated_columns); + deserializer.ReadPropertyWithExplicitDefault(208, "is_implicit", result->is_implicit, true); return std::move(result); } @@ -181,6 +185,8 @@ void ShowRef::Serialize(Serializer &serializer) const { serializer.WritePropertyWithDefault(200, "table_name", table_name); serializer.WritePropertyWithDefault>(201, "query", query); serializer.WriteProperty(202, "show_type", show_type); + serializer.WritePropertyWithDefault(203, "catalog_name", catalog_name); + serializer.WritePropertyWithDefault(204, "schema_name", schema_name); } unique_ptr ShowRef::Deserialize(Deserializer &deserializer) { @@ -188,6 +194,8 @@ unique_ptr ShowRef::Deserialize(Deserializer &deserializer) { deserializer.ReadPropertyWithDefault(200, "table_name", result->table_name); deserializer.ReadPropertyWithDefault>(201, "query", result->query); deserializer.ReadProperty(202, "show_type", result->show_type); + deserializer.ReadPropertyWithDefault(203, "catalog_name", result->catalog_name); + deserializer.ReadPropertyWithDefault(204, "schema_name", result->schema_name); return std::move(result); } @@ -208,12 +216,14 @@ void TableFunctionRef::Serialize(Serializer &serializer) const { TableRef::Serialize(serializer); serializer.WritePropertyWithDefault>(200, "function", function); serializer.WritePropertyWithDefault>(201, "column_name_alias", column_name_alias); + serializer.WritePropertyWithDefault(202, "with_ordinality", with_ordinality, OrdinalityType::WITHOUT_ORDINALITY); } unique_ptr TableFunctionRef::Deserialize(Deserializer &deserializer) { auto result = duckdb::unique_ptr(new TableFunctionRef()); deserializer.ReadPropertyWithDefault>(200, "function", result->function); deserializer.ReadPropertyWithDefault>(201, "column_name_alias", result->column_name_alias); + deserializer.ReadPropertyWithExplicitDefault(202, "with_ordinality", result->with_ordinality, OrdinalityType::WITHOUT_ORDINALITY); return std::move(result); } diff --git a/src/duckdb/src/storage/serialization/serialize_types.cpp b/src/duckdb/src/storage/serialization/serialize_types.cpp index 6f5e811ef..453961009 100644 --- a/src/duckdb/src/storage/serialization/serialize_types.cpp +++ b/src/duckdb/src/storage/serialization/serialize_types.cpp @@ -56,6 +56,9 @@ shared_ptr ExtraTypeInfo::Deserialize(Deserializer &deserializer) case ExtraTypeInfoType::STRUCT_TYPE_INFO: result = StructTypeInfo::Deserialize(deserializer); break; + case ExtraTypeInfoType::TEMPLATE_TYPE_INFO: + result = TemplateTypeInfo::Deserialize(deserializer); + break; case ExtraTypeInfoType::USER_TYPE_INFO: result = UserTypeInfo::Deserialize(deserializer); break; @@ -189,6 +192,17 @@ shared_ptr StructTypeInfo::Deserialize(Deserializer &deserializer return std::move(result); } +void TemplateTypeInfo::Serialize(Serializer &serializer) const { + ExtraTypeInfo::Serialize(serializer); + serializer.WritePropertyWithDefault(200, "name", name); +} + +shared_ptr TemplateTypeInfo::Deserialize(Deserializer &deserializer) { + auto result = duckdb::shared_ptr(new TemplateTypeInfo()); + deserializer.ReadPropertyWithDefault(200, "name", result->name); + return std::move(result); +} + void 
UserTypeInfo::Serialize(Serializer &serializer) const { ExtraTypeInfo::Serialize(serializer); serializer.WritePropertyWithDefault(200, "user_type_name", user_type_name); diff --git a/src/duckdb/src/storage/single_file_block_manager.cpp b/src/duckdb/src/storage/single_file_block_manager.cpp index a9c8873b7..cfcbd43a7 100644 --- a/src/duckdb/src/storage/single_file_block_manager.cpp +++ b/src/duckdb/src/storage/single_file_block_manager.cpp @@ -15,6 +15,7 @@ #include "duckdb/storage/metadata/metadata_reader.hpp" #include "duckdb/storage/metadata/metadata_writer.hpp" #include "duckdb/storage/storage_manager.hpp" +#include "duckdb/main/settings.hpp" #include "mbedtls_wrapper.hpp" #include @@ -120,12 +121,12 @@ void MainHeader::Write(WriteStream &ser) { } } -void MainHeader::CheckMagicBytes(FileHandle &handle) { +void MainHeader::CheckMagicBytes(QueryContext context, FileHandle &handle) { data_t magic_bytes[MAGIC_BYTE_SIZE]; if (handle.GetFileSize() < MainHeader::MAGIC_BYTE_SIZE + MainHeader::MAGIC_BYTE_OFFSET) { throw IOException("The file \"%s\" exists, but it is not a valid DuckDB database file!", handle.path); } - handle.Read(magic_bytes, MainHeader::MAGIC_BYTE_SIZE, MainHeader::MAGIC_BYTE_OFFSET); + handle.Read(context, magic_bytes, MainHeader::MAGIC_BYTE_SIZE, MainHeader::MAGIC_BYTE_OFFSET); if (memcmp(magic_bytes, MainHeader::MAGIC_BYTES, MainHeader::MAGIC_BYTE_SIZE) != 0) { throw IOException("The file \"%s\" exists, but it is not a valid DuckDB database file!", handle.path); } @@ -427,7 +428,7 @@ void SingleFileBlockManager::CreateNewDatabase(QueryContext context) { max_block = 0; } -void SingleFileBlockManager::LoadExistingDatabase() { +void SingleFileBlockManager::LoadExistingDatabase(QueryContext context) { auto flags = GetFileFlags(false); // open the RDBMS handle @@ -438,9 +439,9 @@ void SingleFileBlockManager::LoadExistingDatabase() { throw IOException("Cannot open database \"%s\" in read-only mode: database does not exist", path); } - MainHeader::CheckMagicBytes(*handle); + MainHeader::CheckMagicBytes(context, *handle); // otherwise, we check the metadata of the file - ReadAndChecksum(header_buffer, 0, true); + ReadAndChecksum(context, header_buffer, 0, true); uint64_t delta = 0; if (GetBlockHeaderSize() > DEFAULT_BLOCK_HEADER_STORAGE_SIZE) { @@ -472,11 +473,11 @@ void SingleFileBlockManager::LoadExistingDatabase() { // read the database headers from disk DatabaseHeader h1; - ReadAndChecksum(header_buffer, Storage::FILE_HEADER_SIZE); + ReadAndChecksum(context, header_buffer, Storage::FILE_HEADER_SIZE); h1 = DeserializeDatabaseHeader(main_header, header_buffer.buffer); DatabaseHeader h2; - ReadAndChecksum(header_buffer, Storage::FILE_HEADER_SIZE * 2ULL); + ReadAndChecksum(context, header_buffer, Storage::FILE_HEADER_SIZE * 2ULL); h2 = DeserializeDatabaseHeader(main_header, header_buffer.buffer); // check the header with the highest iteration count @@ -538,9 +539,10 @@ void SingleFileBlockManager::CheckChecksum(FileBuffer &block, uint64_t location, } } -void SingleFileBlockManager::ReadAndChecksum(FileBuffer &block, uint64_t location, bool skip_block_header) const { +void SingleFileBlockManager::ReadAndChecksum(QueryContext context, FileBuffer &block, uint64_t location, + bool skip_block_header) const { // read the buffer from disk - block.Read(*handle, location); + block.Read(context, *handle, location); //! 
calculate delta header bytes (if any) uint64_t delta = GetBlockHeaderSize() - Storage::DEFAULT_BLOCK_HEADER_SIZE; @@ -860,7 +862,7 @@ void SingleFileBlockManager::ReadBlock(data_ptr_t internal_buffer, uint64_t bloc void SingleFileBlockManager::ReadBlock(Block &block, bool skip_block_header) const { // read the buffer from disk auto location = GetBlockLocation(block.id); - block.Read(*handle, location); + block.Read(QueryContext(), *handle, location); //! calculate delta header bytes (if any) uint64_t delta = GetBlockHeaderSize() - Storage::DEFAULT_BLOCK_HEADER_SIZE; @@ -876,7 +878,7 @@ void SingleFileBlockManager::ReadBlock(Block &block, bool skip_block_header) con void SingleFileBlockManager::Read(Block &block) { D_ASSERT(block.id >= 0); D_ASSERT(std::find(free_list.begin(), free_list.end(), block.id) == free_list.end()); - ReadAndChecksum(block, GetBlockLocation(block.id)); + ReadAndChecksum(QueryContext(), block, GetBlockLocation(block.id)); } void SingleFileBlockManager::ReadBlocks(FileBuffer &buffer, block_id_t start_block, idx_t block_count) { @@ -885,7 +887,7 @@ void SingleFileBlockManager::ReadBlocks(FileBuffer &buffer, block_id_t start_blo // read the buffer from disk auto location = GetBlockLocation(start_block); - buffer.Read(*handle, location); + buffer.Read(QueryContext(), *handle, location); // for each of the blocks - verify the checksum auto ptr = buffer.InternalBuffer(); @@ -1018,8 +1020,8 @@ void SingleFileBlockManager::WriteHeader(QueryContext context, DatabaseHeader he header.block_count = NumericCast(max_block); header.serialization_compatibility = options.storage_version.GetIndex(); - auto &config = DBConfig::Get(db); - if (config.options.checkpoint_abort == CheckpointAbort::DEBUG_ABORT_AFTER_FREE_LIST_WRITE) { + auto debug_checkpoint_abort = DBConfig::GetSetting(db.GetDatabase()); + if (debug_checkpoint_abort == CheckpointAbort::DEBUG_ABORT_AFTER_FREE_LIST_WRITE) { throw FatalException("Checkpoint aborted after free list write because of PRAGMA checkpoint_abort flag"); } diff --git a/src/duckdb/src/storage/standard_buffer_manager.cpp b/src/duckdb/src/storage/standard_buffer_manager.cpp index f9a8b3baf..d2b1b54ba 100644 --- a/src/duckdb/src/storage/standard_buffer_manager.cpp +++ b/src/duckdb/src/storage/standard_buffer_manager.cpp @@ -496,8 +496,8 @@ void StandardBufferManager::WriteTemporaryBuffer(MemoryTag tag, block_id_t block // Append to a few grouped files. 
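// [editor's sketch] The recurring change in the read paths above (CheckMagicBytes,
// ReadAndChecksum, ReadBlock, ReadBlocks) is threading a QueryContext parameter
// down to the file I/O layer; call sites that have no client context pass a
// default-constructed, empty QueryContext(). Simplified illustration with
// hypothetical types:
//
//     struct QueryContextSketch {
//         ClientContextSketch *client = nullptr; // an empty context is valid
//     };
//
//     void ReadAndChecksumSketch(QueryContextSketch context, FileHandleSketch &handle,
//                                FileBufferSketch &block, uint64_t location) {
//         // the context flows down so lower layers can access client state when
//         // it exists (and simply ignore it when it does not)
//         block.Read(context, handle, location);
//         // ... verify the checksum (elided) ...
//     }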
if (buffer.AllocSize() == GetBlockAllocSize()) { - evicted_data_per_tag[uint8_t(tag)] += GetBlockAllocSize(); - temporary_directory.handle->GetTempFile().WriteTemporaryBuffer(block_id, buffer); + idx_t eviction_size = temporary_directory.handle->GetTempFile().WriteTemporaryBuffer(block_id, buffer); + evicted_data_per_tag[uint8_t(tag)] += eviction_size; return; } @@ -534,14 +534,15 @@ void StandardBufferManager::WriteTemporaryBuffer(MemoryTag tag, block_id_t block buffer.Write(QueryContext(), *handle, offset); } -unique_ptr StandardBufferManager::ReadTemporaryBuffer(MemoryTag tag, BlockHandle &block, +unique_ptr StandardBufferManager::ReadTemporaryBuffer(QueryContext context, MemoryTag tag, + BlockHandle &block, unique_ptr reusable_buffer) { D_ASSERT(!temporary_directory.path.empty()); D_ASSERT(temporary_directory.handle.get()); auto id = block.BlockId(); if (temporary_directory.handle->GetTempFile().HasTemporaryBuffer(id)) { // This is a block that was offloaded to a regular .tmp file, the file contains blocks of a fixed size - return temporary_directory.handle->GetTempFile().ReadTemporaryBuffer(id, std::move(reusable_buffer)); + return temporary_directory.handle->GetTempFile().ReadTemporaryBuffer(context, id, std::move(reusable_buffer)); } // This block contains data of variable size so we need to open it and read it to get its size. @@ -550,8 +551,8 @@ unique_ptr StandardBufferManager::ReadTemporaryBuffer(MemoryTag tag, auto path = GetTemporaryPath(id); auto &fs = FileSystem::GetFileSystem(db); auto handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_READ); - handle->Read(&block_size, sizeof(idx_t), 0); - handle->Read(&block_header_size, sizeof(idx_t), sizeof(idx_t)); + handle->Read(context, &block_size, sizeof(idx_t), 0); + handle->Read(context, &block_header_size, sizeof(idx_t), sizeof(idx_t)); idx_t offset = sizeof(idx_t) * 2; @@ -562,15 +563,15 @@ unique_ptr StandardBufferManager::ReadTemporaryBuffer(MemoryTag tag, // encrypted //! Read nonce and tag from file. uint8_t encryption_metadata[DEFAULT_ENCRYPTED_BUFFER_HEADER_SIZE]; - handle->Read(encryption_metadata, DEFAULT_ENCRYPTED_BUFFER_HEADER_SIZE, offset); + handle->Read(context, encryption_metadata, DEFAULT_ENCRYPTED_BUFFER_HEADER_SIZE, offset); //! Read and decrypt the buffer. 
- buffer->Read(*handle, offset + DEFAULT_ENCRYPTED_BUFFER_HEADER_SIZE); + buffer->Read(context, *handle, offset + DEFAULT_ENCRYPTED_BUFFER_HEADER_SIZE); EncryptionEngine::DecryptTemporaryBuffer(GetDatabase(), buffer->InternalBuffer(), buffer->AllocSize(), encryption_metadata); } else { // unencrypted: read the data directly - buffer->Read(*handle, offset); + buffer->Read(context, *handle, offset); } handle.reset(); @@ -596,8 +597,8 @@ void StandardBufferManager::DeleteTemporaryFile(BlockHandle &block) { // check if we should delete the file from the shared pool of files, or from the general file system if (temporary_directory.handle->GetTempFile().HasTemporaryBuffer(id)) { - evicted_data_per_tag[uint8_t(block.GetMemoryTag())] -= GetBlockAllocSize(); - temporary_directory.handle->GetTempFile().DeleteTemporaryBuffer(id); + idx_t eviction_size = temporary_directory.handle->GetTempFile().DeleteTemporaryBuffer(id); + evicted_data_per_tag[uint8_t(block.GetMemoryTag())] -= eviction_size; return; } diff --git a/src/duckdb/src/storage/statistics/string_stats.cpp b/src/duckdb/src/storage/statistics/string_stats.cpp index 4a8e629d3..e7d232692 100644 --- a/src/duckdb/src/storage/statistics/string_stats.cpp +++ b/src/duckdb/src/storage/statistics/string_stats.cpp @@ -161,7 +161,7 @@ void StringStats::Update(BaseStatistics &stats, const string_t &value) { } if (stats.GetType().id() == LogicalTypeId::VARCHAR && !string_data.has_unicode) { auto unicode = Utf8Proc::Analyze(const_char_ptr_cast(data), size); - if (unicode == UnicodeType::UNICODE) { + if (unicode == UnicodeType::UTF8) { string_data.has_unicode = true; } else if (unicode == UnicodeType::INVALID) { throw ErrorManager::InvalidUnicodeError(string(const_char_ptr_cast(data), size), @@ -291,7 +291,7 @@ void StringStats::Verify(const BaseStatistics &stats, Vector &vector, const Sele } if (stats.GetType().id() == LogicalTypeId::VARCHAR && !string_data.has_unicode) { auto unicode = Utf8Proc::Analyze(data, len); - if (unicode == UnicodeType::UNICODE) { + if (unicode == UnicodeType::UTF8) { throw InternalException("Statistics mismatch: string value contains unicode, but statistics says it " "shouldn't.\nStatistics: %s\nVector: %s", stats.ToString(), vector.ToString(count)); diff --git a/src/duckdb/src/storage/storage_manager.cpp b/src/duckdb/src/storage/storage_manager.cpp index 62f25d7ff..89e8fbd1f 100644 --- a/src/duckdb/src/storage/storage_manager.cpp +++ b/src/duckdb/src/storage/storage_manager.cpp @@ -12,14 +12,65 @@ #include "duckdb/storage/single_file_block_manager.hpp" #include "duckdb/storage/storage_extension.hpp" #include "duckdb/storage/table/column_data.hpp" +#include "duckdb/storage/table/in_memory_checkpoint.hpp" #include "mbedtls_wrapper.hpp" namespace duckdb { using SHA256State = duckdb_mbedtls::MbedTlsWrapper::SHA256State; -StorageManager::StorageManager(AttachedDatabase &db, string path_p, bool read_only) - : db(db), path(std::move(path_p)), read_only(read_only) { +void StorageOptions::Initialize(const unordered_map &options) { + string storage_version_user_provided = ""; + for (auto &entry : options) { + if (entry.first == "block_size") { + // Extract the block allocation size. This is NOT the actual memory available on a block (block_size), + // even though the corresponding option we expose to the user is called "block_size". 
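// [editor's note] The options handled in this loop come from the ATTACH statement's
// option list, e.g. ATTACH 'file.db' (BLOCK_SIZE 16384, ROW_GROUP_SIZE 122880,
// STORAGE_VERSION 'v1.4.0', ENCRYPTION_KEY 'quack') -- the values shown are
// illustrative only. Each option is matched by name; anything unrecognized falls
// through to the BinderException at the end of the loop, so a misspelled option
// fails loudly instead of being silently ignored.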
+ block_alloc_size = entry.second.GetValue(); + } else if (entry.first == "encryption_key") { + // check the type of the key + auto type = entry.second.type(); + if (type.id() != LogicalTypeId::VARCHAR) { + throw BinderException("\"%s\" is not a valid key. A key must be of type VARCHAR", + entry.second.ToString()); + } else if (entry.second.GetValue().empty()) { + throw BinderException("Not a valid key. A key cannot be empty"); + } + user_key = make_shared_ptr(StringValue::Get(entry.second.DefaultCastAs(LogicalType::BLOB))); + block_header_size = DEFAULT_ENCRYPTION_BLOCK_HEADER_SIZE; + encryption = true; + } else if (entry.first == "encryption_cipher") { + throw BinderException("\"%s\" is not a valid cipher. Only AES GCM is supported.", entry.second.ToString()); + } else if (entry.first == "row_group_size") { + row_group_size = entry.second.GetValue(); + } else if (entry.first == "storage_version") { + storage_version_user_provided = entry.second.ToString(); + storage_version = SerializationCompatibility::FromString(entry.second.ToString()).serialization_version; + } else if (entry.first == "compress") { + if (entry.second.DefaultCastAs(LogicalType::BOOLEAN).GetValue()) { + compress_in_memory = CompressInMemory::COMPRESS; + } else { + compress_in_memory = CompressInMemory::DO_NOT_COMPRESS; + } + } else { + throw BinderException("Unrecognized option for attach \"%s\"", entry.first); + } + } + if (encryption && + (!storage_version.IsValid() || + storage_version.GetIndex() < SerializationCompatibility::FromString("v1.4.0").serialization_version)) { + if (!storage_version_user_provided.empty()) { + throw InvalidInputException( + "Explicit provided STORAGE_VERSION (\"%s\") and ENCRYPTION_KEY (storage >= v1.4.0) are not compatible", + storage_version_user_provided); + } + // set storage version to v1.4.0 + storage_version = SerializationCompatibility::FromString("v1.4.0").serialization_version; + } +} + +StorageManager::StorageManager(AttachedDatabase &db, string path_p, const AttachOptions &options) + : db(db), path(std::move(path_p)), read_only(options.access_mode == AccessMode::READ_ONLY), + in_memory_change_size(0) { if (path.empty()) { path = IN_MEMORY_PATH; @@ -27,6 +78,8 @@ StorageManager::StorageManager(AttachedDatabase &db, string path_p, bool read_on } auto &fs = FileSystem::Get(db); path = fs.ExpandPath(path); + + storage_options.Initialize(options.options); } StorageManager::~StorageManager() { @@ -48,7 +101,7 @@ ObjectCache &ObjectCache::GetObjectCache(ClientContext &context) { } idx_t StorageManager::GetWALSize() { - return wal->GetWALSize(); + return InMemory() ? 
in_memory_change_size.load() : wal->GetWALSize(); } optional_ptr StorageManager::GetWAL() { @@ -62,7 +115,7 @@ void StorageManager::ResetWAL() { wal->Delete(); } -string StorageManager::GetWALPath() { +string StorageManager::GetWALPath() const { // we append the ".wal" **before** a question mark in case of GET parameters // but only if we are not in a windows long path (which starts with \\?\) std::size_t question_mark_pos = std::string::npos; @@ -78,22 +131,22 @@ string StorageManager::GetWALPath() { return wal_path; } -bool StorageManager::InMemory() { +bool StorageManager::InMemory() const { D_ASSERT(!path.empty()); return path == IN_MEMORY_PATH; } -void StorageManager::Initialize(QueryContext context, StorageOptions &options) { +void StorageManager::Initialize(QueryContext context) { bool in_memory = InMemory(); if (in_memory && read_only) { throw CatalogException("Cannot launch in-memory database in read-only mode!"); } // Create or load the database from disk, if not in-memory mode. - LoadDatabase(context, options); + LoadDatabase(context); - if (options.encryption) { - ClearUserKey(options.user_key); + if (storage_options.encryption) { + ClearUserKey(storage_options.user_key); } } @@ -121,18 +174,23 @@ class SingleFileTableIOManager : public TableIOManager { } }; -SingleFileStorageManager::SingleFileStorageManager(AttachedDatabase &db, string path, bool read_only) - : StorageManager(db, std::move(path), read_only) { +SingleFileStorageManager::SingleFileStorageManager(AttachedDatabase &db, string path, const AttachOptions &options) + : StorageManager(db, std::move(path), options) { } -void SingleFileStorageManager::LoadDatabase(QueryContext context, StorageOptions &storage_options) { - +void SingleFileStorageManager::LoadDatabase(QueryContext context) { if (InMemory()) { block_manager = make_uniq(BufferManager::GetBufferManager(db), DEFAULT_BLOCK_ALLOC_SIZE, DEFAULT_BLOCK_HEADER_STORAGE_SIZE); table_io_manager = make_uniq(*block_manager, DEFAULT_ROW_GROUP_SIZE); + // in-memory databases can always use the latest storage version + storage_version = GetSerializationVersion("latest"); + load_complete = true; return; } + if (storage_options.compress_in_memory != CompressInMemory::AUTOMATIC) { + throw InvalidInputException("COMPRESS can only be set for in-memory databases"); + } auto &fs = FileSystem::Get(db); auto &config = DBConfig::Get(db); @@ -229,7 +287,7 @@ void SingleFileStorageManager::LoadDatabase(QueryContext context, StorageOptions // We'll construct the SingleFileBlockManager with the default block allocation size, // and later adjust it when reading the file header. 
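// [editor's note] Condensed control flow of the LoadDatabase() changes above
// (hedged summary, details elided):
//
//     if (InMemory()) {
//         // no file: use an in-memory block manager, pin the storage version to
//         // "latest" (there is no on-disk format to stay compatible with), mark
//         // the load as complete, and return early
//     }
//     if (storage_options.compress_in_memory != CompressInMemory::AUTOMATIC) {
//         // COMPRESS is only meaningful for in-memory databases: reject it here
//     }
//     // ... otherwise create a new database file, or load an existing one (below)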
auto sf_block_manager = make_uniq(db, path, options); - sf_block_manager->LoadExistingDatabase(); + sf_block_manager->LoadExistingDatabase(context); block_manager = std::move(sf_block_manager); table_io_manager = make_uniq(*block_manager, row_group_size); @@ -392,8 +450,16 @@ bool SingleFileStorageManager::IsCheckpointClean(MetaBlockPointer checkpoint_id) return block_manager->IsRootBlock(checkpoint_id); } +unique_ptr SingleFileStorageManager::CreateCheckpointWriter(QueryContext context, + CheckpointOptions options) { + if (InMemory()) { + return make_uniq(context, db, *block_manager, *this, options.type); + } + return make_uniq(context, db, *block_manager, options.type); +} + void SingleFileStorageManager::CreateCheckpoint(QueryContext context, CheckpointOptions options) { - if (InMemory() || read_only || !load_complete) { + if (read_only || !load_complete) { return; } if (db.GetStorageExtension()) { @@ -403,14 +469,14 @@ void SingleFileStorageManager::CreateCheckpoint(QueryContext context, Checkpoint if (GetWALSize() > 0 || config.options.force_checkpoint || options.action == CheckpointAction::ALWAYS_CHECKPOINT) { // we only need to checkpoint if there is anything in the WAL try { - SingleFileCheckpointWriter checkpointer(context, db, *block_manager, options.type); - checkpointer.CreateCheckpoint(); + auto checkpointer = CreateCheckpointWriter(context, options); + checkpointer->CreateCheckpoint(); } catch (std::exception &ex) { ErrorData error(ex); throw FatalException("Failed to create checkpoint because of error: %s", error.RawMessage()); } } - if (options.wal_action == CheckpointWALAction::DELETE_WAL) { + if (!InMemory() && options.wal_action == CheckpointWALAction::DELETE_WAL) { ResetWAL(); } diff --git a/src/duckdb/src/storage/table/array_column_data.cpp b/src/duckdb/src/storage/table/array_column_data.cpp index c4c556984..05964339a 100644 --- a/src/duckdb/src/storage/table/array_column_data.cpp +++ b/src/duckdb/src/storage/table/array_column_data.cpp @@ -313,6 +313,10 @@ bool ArrayColumnData::IsPersistent() { return validity.IsPersistent() && child_column->IsPersistent(); } +bool ArrayColumnData::HasAnyChanges() const { + return child_column->HasAnyChanges() || validity.HasAnyChanges(); +} + PersistentColumnData ArrayColumnData::Serialize() { PersistentColumnData persistent_data(PhysicalType::ARRAY); persistent_data.child_columns.push_back(validity.Serialize()); diff --git a/src/duckdb/src/storage/table/column_checkpoint_state.cpp b/src/duckdb/src/storage/table/column_checkpoint_state.cpp index 1bd809e50..213338d97 100644 --- a/src/duckdb/src/storage/table/column_checkpoint_state.cpp +++ b/src/duckdb/src/storage/table/column_checkpoint_state.cpp @@ -25,7 +25,7 @@ unique_ptr ColumnCheckpointState::GetStatistics() { PartialBlockForCheckpoint::PartialBlockForCheckpoint(ColumnData &data, ColumnSegment &segment, PartialBlockState state, BlockManager &block_manager) : PartialBlock(state, block_manager, segment.block) { - AddSegmentToTail(data, segment, 0); + PartialBlockForCheckpoint::AddSegmentToTail(data, segment, 0); } PartialBlockForCheckpoint::~PartialBlockForCheckpoint() { @@ -126,28 +126,32 @@ void ColumnCheckpointState::FlushSegmentInternal(unique_ptr segme return; } // LCOV_EXCL_STOP - // merge the segment stats into the global stats + // Merge the segment statistics into the global statistics. 
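// [editor's sketch] The CreateCheckpointWriter() hunk above turns the checkpoint
// entry point into a small factory (simplified; the real constructors take more
// arguments):
//
//     unique_ptr<CheckpointWriter> CreateCheckpointWriter(CheckpointOptions options) {
//         if (InMemory()) {
//             // runs the checkpoint machinery purely to compact/compress
//             // in-memory data; the serialized output is discarded
//             return make_uniq<InMemoryCheckpointer>(/*...*/);
//         }
//         return make_uniq<SingleFileCheckpointWriter>(/*...*/);
//     }
//
// This is also why CreateCheckpoint() no longer early-returns for in-memory
// databases, and why WAL deletion is now gated on !InMemory().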
global_stats->Merge(segment->stats.statistics); - // get the buffer of the segment and pin it - auto &db = column_data.GetDatabase(); - auto &buffer_manager = BufferManager::GetBufferManager(db); block_id_t block_id = INVALID_BLOCK; uint32_t offset_in_block = 0; unique_lock partial_block_lock; - if (!segment->stats.statistics.IsConstant()) { + if (segment->stats.statistics.IsConstant()) { + // Constant block. + segment->ConvertToPersistent(partial_block_manager.GetClientContext(), nullptr, INVALID_BLOCK); + + } else if (segment_size != 0) { + // Non-constant block with data that has to go to disk. + auto &db = column_data.GetDatabase(); + auto &buffer_manager = BufferManager::GetBufferManager(db); partial_block_lock = partial_block_manager.GetLock(); - // non-constant block - auto allocation = partial_block_manager.GetBlockAllocation(NumericCast(segment_size)); + auto cast_segment_size = NumericCast(segment_size); + auto allocation = partial_block_manager.GetBlockAllocation(cast_segment_size); block_id = allocation.state.block_id; offset_in_block = allocation.state.offset; if (allocation.partial_block) { // Use an existing block. D_ASSERT(offset_in_block > 0); - auto &pstate = allocation.partial_block->Cast(); + auto &pstate = *allocation.partial_block; // pin the source block auto old_handle = buffer_manager.Pin(segment->block); // pin the target block @@ -164,13 +168,17 @@ void ColumnCheckpointState::FlushSegmentInternal(unique_ptr segme segment->Resize(block_size); } D_ASSERT(offset_in_block == 0); - allocation.partial_block = make_uniq(column_data, *segment, allocation.state, - *allocation.block_manager); + allocation.partial_block = partial_block_manager.CreatePartialBlock(column_data, *segment, allocation.state, + *allocation.block_manager); } // Writer will decide whether to reuse this block. partial_block_manager.RegisterPartialBlock(std::move(allocation)); + } else { - segment->ConvertToPersistent(partial_block_manager.GetClientContext(), nullptr, INVALID_BLOCK); + // Empty segment, which does not have to go to disk. + // We still need to change its type to persistent, because we need to write its metadata. 
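// [editor's note] After this restructuring, FlushSegmentInternal() distinguishes
// three cases (summary of the control flow in this hunk):
//
//     constant statistics   -> no block is written at all; only the stats survive
//     segment_size != 0     -> the data is appended to a (possibly shared)
//                              partial block and registered for writing
//     empty segment (here)  -> nothing to write, but the segment is still
//                              flipped to PERSISTENT so its metadata can be
//                              serialized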
+ segment->segment_type = ColumnSegmentType::PERSISTENT; + segment->block.reset(); } // construct the data pointer diff --git a/src/duckdb/src/storage/table/column_data.cpp b/src/duckdb/src/storage/table/column_data.cpp index 1366aadb9..082f42f61 100644 --- a/src/duckdb/src/storage/table/column_data.cpp +++ b/src/duckdb/src/storage/table/column_data.cpp @@ -77,6 +77,29 @@ bool ColumnData::HasChanges(idx_t start_row, idx_t end_row) const { return false; } +bool ColumnData::HasChanges() const { + auto l = data.Lock(); + auto &nodes = data.ReferenceLoadedSegments(l); + for (idx_t segment_idx = 0; segment_idx < nodes.size(); segment_idx++) { + auto segment = nodes[segment_idx].node.get(); + if (segment->segment_type == ColumnSegmentType::TRANSIENT) { + // transient segment: always need to write to disk + return true; + } + // persistent segment; check if there were any updates or deletions in this segment + idx_t start_row_idx = segment->start - start; + idx_t end_row_idx = start_row_idx + segment->count; + if (HasChanges(start_row_idx, end_row_idx)) { + return true; + } + } + return false; +} + +bool ColumnData::HasAnyChanges() const { + return HasChanges(); +} + void ColumnData::ClearUpdates() { lock_guard update_guard(update_lock); updates.reset(); @@ -542,13 +565,18 @@ void ColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, FetchUpdateRow(transaction, row_id, result, result_idx); } +idx_t ColumnData::FetchUpdateData(row_t *row_ids, Vector &base_vector) { + ColumnScanState state; + auto fetch_count = ColumnData::Fetch(state, row_ids[0], base_vector); + base_vector.Flatten(fetch_count); + return fetch_count; +} + void ColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, idx_t update_count) { Vector base_vector(type); - ColumnScanState state; - auto fetch_count = Fetch(state, row_ids[0], base_vector); + FetchUpdateData(row_ids, base_vector); - base_vector.Flatten(fetch_count); UpdateInternal(transaction, column_index, update_vector, row_ids, update_count, base_vector); } diff --git a/src/duckdb/src/storage/table/column_data_checkpointer.cpp b/src/duckdb/src/storage/table/column_data_checkpointer.cpp index eb20b13d8..08362da3f 100644 --- a/src/duckdb/src/storage/table/column_data_checkpointer.cpp +++ b/src/duckdb/src/storage/table/column_data_checkpointer.cpp @@ -5,6 +5,7 @@ #include "duckdb/storage/data_table.hpp" #include "duckdb/parser/column_definition.hpp" #include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/logging/log_manager.hpp" #include "duckdb/main/database.hpp" namespace duckdb { @@ -38,6 +39,10 @@ ColumnCheckpointState &ColumnDataCheckpointData::GetCheckpointState() { return *checkpoint_state; } +StorageManager &ColumnDataCheckpointData::GetStorageManager() { + return *storage_manager; +} + //! 
ColumnDataCheckpointer static Vector CreateIntermediateVector(vector> &states) { @@ -328,8 +333,8 @@ void ColumnDataCheckpointer::WriteToDisk() { auto &checkpoint_state = checkpoint_states[i]; auto &col_data = checkpoint_state.get().column_data; - checkpoint_data[i] = - ColumnDataCheckpointData(checkpoint_state, col_data, col_data.GetDatabase(), row_group, checkpoint_info); + checkpoint_data[i] = ColumnDataCheckpointData(checkpoint_state, col_data, col_data.GetDatabase(), row_group, + checkpoint_info, storage_manager); compression_states[i] = function->init_compression(checkpoint_data[i], std::move(analyze_state)); } @@ -357,21 +362,7 @@ void ColumnDataCheckpointer::WriteToDisk() { } bool ColumnDataCheckpointer::HasChanges(ColumnData &col_data) { - auto &nodes = col_data.data.ReferenceSegments(); - for (idx_t segment_idx = 0; segment_idx < nodes.size(); segment_idx++) { - auto segment = nodes[segment_idx].node.get(); - if (segment->segment_type == ColumnSegmentType::TRANSIENT) { - // transient segment: always need to write to disk - return true; - } - // persistent segment; check if there were any updates or deletions in this segment - idx_t start_row_idx = segment->start - row_group.start; - idx_t end_row_idx = start_row_idx + segment->count; - if (col_data.HasChanges(start_row_idx, end_row_idx)) { - return true; - } - } - return false; + return col_data.HasChanges(); } void ColumnDataCheckpointer::WritePersistentSegments(ColumnCheckpointState &state) { diff --git a/src/duckdb/src/storage/table/column_segment.cpp b/src/duckdb/src/storage/table/column_segment.cpp index cb1227854..347463fbe 100644 --- a/src/duckdb/src/storage/table/column_segment.cpp +++ b/src/duckdb/src/storage/table/column_segment.cpp @@ -35,10 +35,8 @@ unique_ptr ColumnSegment::CreatePersistentSegment(DatabaseInstanc optional_ptr function; shared_ptr block; - if (block_id == INVALID_BLOCK) { - function = config.GetCompressionFunction(CompressionType::COMPRESSION_CONSTANT, type.InternalType()); - } else { - function = config.GetCompressionFunction(compression_type, type.InternalType()); + function = config.GetCompressionFunction(compression_type, type.InternalType()); + if (block_id != INVALID_BLOCK) { block = block_manager.RegisterBlock(block_id); } @@ -224,36 +222,38 @@ void ColumnSegment::RevertAppend(idx_t start_row) { // Convert To Persistent //===--------------------------------------------------------------------===// void ColumnSegment::ConvertToPersistent(QueryContext context, optional_ptr block_manager, - block_id_t block_id_p) { + const block_id_t block_id_p) { D_ASSERT(segment_type == ColumnSegmentType::TRANSIENT); segment_type = ColumnSegmentType::PERSISTENT; - block_id = block_id_p; offset = 0; - if (block_id == INVALID_BLOCK) { - // Constant block: no need to write anything to disk besides the stats. - // Set the compression function to constant. - D_ASSERT(stats.statistics.IsConstant()); - auto &config = DBConfig::GetConfig(db); - function = *config.GetCompressionFunction(CompressionType::COMPRESSION_CONSTANT, type.InternalType()); - // Reset the block buffer. - block.reset(); + if (block_id != INVALID_BLOCK) { + D_ASSERT(!stats.statistics.IsConstant()); + // Non-constant block: write the block to disk. + // The block data already exists in memory, so we alter the metadata, + // which ensures that the buffer points to an on-disk block. + block = block_manager->ConvertToPersistent(context, block_id, std::move(block)); return; } - // Non-constant block: write the block to disk. 
- // The data for the block already exists in-memory of our block. - // Instead of copying the data, we alter the metadata so that the buffer points to an on-disk block. - D_ASSERT(!stats.statistics.IsConstant()); - block = block_manager->ConvertToPersistent(context, block_id, std::move(block)); + // Constant block: no need to write anything to disk besides the stats (metadata). + // I.e., we do not need to write an actual block. + // Thus, we set the compression function to constant and reset the block buffer. + D_ASSERT(stats.statistics.IsConstant()); + auto &config = DBConfig::GetConfig(db); + function = *config.GetCompressionFunction(CompressionType::COMPRESSION_CONSTANT, type.InternalType()); + block.reset(); } void ColumnSegment::MarkAsPersistent(shared_ptr block_p, uint32_t offset_p) { D_ASSERT(segment_type == ColumnSegmentType::TRANSIENT); - segment_type = ColumnSegmentType::PERSISTENT; - block_id = block_p->BlockId(); + SetBlock(std::move(block_p), offset_p); +} + +void ColumnSegment::SetBlock(shared_ptr block_p, uint32_t offset_p) { + segment_type = ColumnSegmentType::PERSISTENT; offset = offset_p; block = std::move(block_p); } diff --git a/src/duckdb/src/storage/table/in_memory_checkpoint.cpp b/src/duckdb/src/storage/table/in_memory_checkpoint.cpp new file mode 100644 index 000000000..911c86747 --- /dev/null +++ b/src/duckdb/src/storage/table/in_memory_checkpoint.cpp @@ -0,0 +1,136 @@ +#include "duckdb/storage/table/in_memory_checkpoint.hpp" +#include "duckdb/common/serializer/binary_serializer.hpp" +#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" +#include "duckdb/catalog/duck_catalog.hpp" + +namespace duckdb { + +//===--------------------------------------------------------------------===// +// In-Memory Checkpoint Writer +//===--------------------------------------------------------------------===// +InMemoryCheckpointer::InMemoryCheckpointer(QueryContext context, AttachedDatabase &db, BlockManager &block_manager, + StorageManager &storage_manager, CheckpointType checkpoint_type) + : CheckpointWriter(db), context(context.GetClientContext()), + partial_block_manager(context, block_manager, PartialBlockType::IN_MEMORY_CHECKPOINT), + storage_manager(storage_manager), checkpoint_type(checkpoint_type) { +} + +void InMemoryCheckpointer::CreateCheckpoint() { + vector> schemas; + // we scan the set of committed schemas + auto &catalog = Catalog::GetCatalog(db).Cast(); + catalog.ScanSchemas([&](SchemaCatalogEntry &entry) { schemas.push_back(entry); }); + + vector> tables; + for (const auto &schema_ref : schemas) { + auto &schema = schema_ref.get(); + schema.Scan(CatalogType::TABLE_ENTRY, [&](CatalogEntry &entry) { + if (entry.type == CatalogType::TABLE_ENTRY) { + tables.push_back(entry.Cast()); + } + }); + } + + for (auto &table : tables) { + MemoryStream write_stream; + BinarySerializer serializer(write_stream); + + WriteTable(table, serializer); + } + storage_manager.ResetInMemoryChange(); +} + +MetadataWriter &InMemoryCheckpointer::GetMetadataWriter() { + throw InternalException("Unsupported method GetMetadataWriter for InMemoryCheckpointer"); +} +MetadataManager &InMemoryCheckpointer::GetMetadataManager() { + throw InternalException("Unsupported method GetMetadataManager for InMemoryCheckpointer"); +} +unique_ptr InMemoryCheckpointer::GetTableDataWriter(TableCatalogEntry &table) { + throw InternalException("Unsupported method GetTableDataWriter for InMemoryCheckpointer"); +} + +void InMemoryCheckpointer::WriteTable(TableCatalogEntry &table, Serializer &serializer) { + 
InMemoryTableDataWriter data_writer(*this, table); + + // Write the table data + auto table_lock = table.GetStorage().GetCheckpointLock(); + table.GetStorage().Checkpoint(data_writer, serializer); + // flush any partial blocks BEFORE releasing the table lock + // flushing partial blocks updates where data lives and is not thread-safe + partial_block_manager.FlushPartialBlocks(); +} + +InMemoryRowGroupWriter::InMemoryRowGroupWriter(TableCatalogEntry &table, PartialBlockManager &partial_block_manager, + InMemoryCheckpointer &checkpoint_manager) + : RowGroupWriter(table, partial_block_manager), checkpoint_manager(checkpoint_manager) { +} + +CheckpointType InMemoryRowGroupWriter::GetCheckpointType() const { + return checkpoint_manager.GetCheckpointType(); +} + +WriteStream &InMemoryRowGroupWriter::GetPayloadWriter() { + return metadata_writer; +} + +MetaBlockPointer InMemoryRowGroupWriter::GetMetaBlockPointer() { + return MetaBlockPointer(); +} + +optional_ptr InMemoryRowGroupWriter::GetMetadataManager() { + return nullptr; +} + +InMemoryTableDataWriter::InMemoryTableDataWriter(InMemoryCheckpointer &checkpoint_manager, TableCatalogEntry &table) + : TableDataWriter(table, checkpoint_manager.GetClientContext()), checkpoint_manager(checkpoint_manager) { +} + +void InMemoryTableDataWriter::WriteUnchangedTable(MetaBlockPointer pointer, idx_t total_rows) { +} + +void InMemoryTableDataWriter::FinalizeTable(const TableStatistics &global_stats, DataTableInfo *info, + Serializer &serializer) { + // nop: no need to write anything +} + +unique_ptr InMemoryTableDataWriter::GetRowGroupWriter(RowGroup &row_group) { + return make_uniq(table, checkpoint_manager.GetPartialBlockManager(), checkpoint_manager); +} + +CheckpointType InMemoryTableDataWriter::GetCheckpointType() const { + return checkpoint_manager.GetCheckpointType(); +} + +MetadataManager &InMemoryTableDataWriter::GetMetadataManager() { + return checkpoint_manager.GetMetadataManager(); +} + +InMemoryPartialBlock::InMemoryPartialBlock(ColumnData &data, ColumnSegment &segment, PartialBlockState state, + BlockManager &block_manager) + : PartialBlock(state, block_manager, segment.block) { + AddSegmentToTail(data, segment, 0); +} + +InMemoryPartialBlock::~InMemoryPartialBlock() { +} + +void InMemoryPartialBlock::Flush(QueryContext context, const idx_t free_space_left) { + Clear(); +} + +void InMemoryPartialBlock::Merge(PartialBlock &other_p, idx_t offset, idx_t other_size) { + auto &other = other_p.Cast(); + other.Clear(); +} + +void InMemoryPartialBlock::AddSegmentToTail(ColumnData &data, ColumnSegment &segment, uint32_t offset_in_block) { + segment.SetBlock(block_handle, offset_in_block); +} + +void InMemoryPartialBlock::Clear() { + uninitialized_regions.clear(); + block_handle.reset(); +} + +} // namespace duckdb diff --git a/src/duckdb/src/storage/table/list_column_data.cpp b/src/duckdb/src/storage/table/list_column_data.cpp index 2a898fcdb..7685d16ca 100644 --- a/src/duckdb/src/storage/table/list_column_data.cpp +++ b/src/duckdb/src/storage/table/list_column_data.cpp @@ -373,6 +373,10 @@ bool ListColumnData::IsPersistent() { return ColumnData::IsPersistent() && validity.IsPersistent() && child_column->IsPersistent(); } +bool ListColumnData::HasAnyChanges() const { + return ColumnData::HasAnyChanges() || validity.HasAnyChanges() || child_column->HasAnyChanges(); +} + PersistentColumnData ListColumnData::Serialize() { auto persistent_data = ColumnData::Serialize(); persistent_data.child_columns.push_back(validity.Serialize()); diff --git 
a/src/duckdb/src/storage/table/row_group.cpp b/src/duckdb/src/storage/table/row_group.cpp index eef8461fe..a56d60c21 100644 --- a/src/duckdb/src/storage/table/row_group.cpp +++ b/src/duckdb/src/storage/table/row_group.cpp @@ -1,28 +1,25 @@ #include "duckdb/storage/table/row_group.hpp" -#include "duckdb/common/types/vector.hpp" + #include "duckdb/common/exception.hpp" -#include "duckdb/storage/table/column_data.hpp" -#include "duckdb/storage/table/column_checkpoint_state.hpp" -#include "duckdb/storage/table/update_segment.hpp" -#include "duckdb/storage/table_storage_info.hpp" -#include "duckdb/planner/table_filter.hpp" +#include "duckdb/common/serializer/binary_serializer.hpp" +#include "duckdb/common/serializer/deserializer.hpp" +#include "duckdb/common/serializer/serializer.hpp" +#include "duckdb/common/types/vector.hpp" +#include "duckdb/execution/adaptive_filter.hpp" #include "duckdb/execution/expression_executor.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/planner/table_filter.hpp" #include "duckdb/storage/checkpoint/table_data_writer.hpp" #include "duckdb/storage/metadata/metadata_reader.hpp" -#include "duckdb/transaction/duck_transaction_manager.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/main/attached_database.hpp" -#include "duckdb/transaction/duck_transaction.hpp" #include "duckdb/storage/table/append_state.hpp" -#include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/storage/table/column_checkpoint_state.hpp" +#include "duckdb/storage/table/column_data.hpp" #include "duckdb/storage/table/row_version_manager.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" -#include "duckdb/common/serializer/binary_serializer.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" -#include "duckdb/planner/filter/struct_filter.hpp" -#include "duckdb/planner/filter/optional_filter.hpp" -#include "duckdb/execution/adaptive_filter.hpp" +#include "duckdb/storage/table/scan_state.hpp" +#include "duckdb/storage/table/update_segment.hpp" +#include "duckdb/storage/table_storage_info.hpp" +#include "duckdb/transaction/duck_transaction.hpp" +#include "duckdb/transaction/duck_transaction_manager.hpp" namespace duckdb { @@ -46,6 +43,8 @@ RowGroup::RowGroup(RowGroupCollection &collection_p, RowGroupPointer pointer) } this->deletes_pointers = std::move(pointer.deletes_pointers); this->deletes_is_loaded = false; + this->has_metadata_blocks = pointer.has_metadata_blocks; + this->extra_metadata_blocks = std::move(pointer.extra_metadata_blocks); Verify(); } @@ -69,8 +68,13 @@ RowGroup::RowGroup(RowGroupCollection &collection_p, PersistentRowGroupData &dat void RowGroup::MoveToCollection(RowGroupCollection &collection_p, idx_t new_start) { this->collection = collection_p; this->start = new_start; - for (auto &column : GetColumns()) { - column->SetStart(new_start); + for (idx_t c = 0; c < columns.size(); c++) { + if (is_loaded && !is_loaded[c]) { + // we only need to set the column start position if it is already loaded + // if it is not loaded - we will set the correct start position upon loading + continue; + } + columns[c]->SetStart(new_start); } if (!HasUnloadedDeletes()) { auto vinfo = GetVersionInfo(); @@ -1009,10 +1013,52 @@ bool RowGroup::HasUnloadedDeletes() const { return !deletes_is_loaded; } -RowGroupWriteData RowGroup::WriteToDisk(RowGroupWriter &writer) { - vector compression_types; - compression_types.reserve(columns.size()); +vector RowGroup::GetColumnPointers() { + if (has_metadata_blocks) { + 
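// [editor's note] Hedged summary of the new bookkeeping: a row group now records
// every metadata block holding its column metadata, either directly from the file
// (has_metadata_blocks / extra_metadata_blocks, the fast path below) or, for files
// written before this change, by walking the metadata linked list. An unchanged
// row group can then re-use these pointers at checkpoint time instead of
// re-serializing all of its columns (see WriteToDisk below).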
// we have the column metadata from the file itself - no need to deserialize metadata to fetch it + // read it from "column_pointers" and "extra_metadata_blocks" + auto result = column_pointers; + for (auto &block_pointer : extra_metadata_blocks) { + result.emplace_back(block_pointer, 0); + } + return result; + } + vector result; + if (column_pointers.empty()) { + // no pointers + return result; + } + // column_pointers stores the beginning of each column + // if columns are big - they may span multiple metadata blocks + // we need to figure out all blocks that this row group points to + // to do so, we follow the linked list in the metadata blocks + auto &metadata_manager = GetCollection().GetMetadataManager(); + idx_t last_idx = column_pointers.size() - 1; + if (column_pointers.size() > 1) { + // for all but the last column pointer - we can just follow the linked list until we reach the last column + MetadataReader reader(metadata_manager, column_pointers[0]); + auto last_pointer = column_pointers[last_idx]; + result = reader.GetRemainingBlocks(last_pointer); + } + // for the last column we need to deserialize the column - because we don't know where it stops + auto &types = GetCollection().GetTypes(); + MetadataReader reader(metadata_manager, column_pointers[last_idx], &result); + ColumnData::Deserialize(GetBlockManager(), GetTableInfo(), last_idx, start, reader, types[last_idx]); + return result; +} +RowGroupWriteData RowGroup::WriteToDisk(RowGroupWriter &writer) { + if (!column_pointers.empty() && !HasChanges()) { + // we have existing metadata and the row group has not been changed + // re-use previous metadata + RowGroupWriteData result; + result.existing_pointers = GetColumnPointers(); + return result; + } + auto &compression_types = writer.GetCompressionTypes(); + if (columns.size() != compression_types.size()) { + throw InternalException("RowGroup::WriteToDisk - mismatch in column count vs compression types"); + } for (idx_t column_idx = 0; column_idx < GetColumnCount(); column_idx++) { auto &column = GetColumn(column_idx); if (column.count != this->count) { @@ -1020,8 +1066,6 @@ RowGroupWriteData RowGroup::WriteToDisk(RowGroupWriter &writer) { "group has %llu rows, column has %llu)", column_idx, this->count.load(), column.count.load()); } - auto compression_type = writer.GetColumnCompressionType(column_idx); - compression_types.push_back(compression_type); } RowGroupWriteInfo info(writer.GetPartialBlockManager(), compression_types, writer.GetCheckpointType()); @@ -1032,22 +1076,38 @@ RowGroupPointer RowGroup::Checkpoint(RowGroupWriteData write_data, RowGroupWrite TableStatistics &global_stats) { RowGroupPointer row_group_pointer; - auto lock = global_stats.GetLock(); - for (idx_t column_idx = 0; column_idx < GetColumnCount(); column_idx++) { - global_stats.GetStats(*lock, column_idx).Statistics().Merge(write_data.statistics[column_idx]); - } - + auto metadata_manager = writer.GetMetadataManager(); // construct the row group pointer and write the column meta data to disk - D_ASSERT(write_data.states.size() == columns.size()); row_group_pointer.row_start = start; row_group_pointer.tuple_count = count; + if (!write_data.existing_pointers.empty()) { + // we are re-using the previous metadata + row_group_pointer.data_pointers = column_pointers; + row_group_pointer.has_metadata_blocks = has_metadata_blocks; + row_group_pointer.extra_metadata_blocks = extra_metadata_blocks; + row_group_pointer.deletes_pointers = deletes_pointers; +
metadata_manager->ClearModifiedBlocks(write_data.existing_pointers); + metadata_manager->ClearModifiedBlocks(deletes_pointers); + return row_group_pointer; + } + D_ASSERT(write_data.states.size() == columns.size()); + { + auto lock = global_stats.GetLock(); + for (idx_t column_idx = 0; column_idx < GetColumnCount(); column_idx++) { + global_stats.GetStats(*lock, column_idx).Statistics().Merge(write_data.statistics[column_idx]); + } + } + vector column_metadata; + unordered_set metadata_blocks; + writer.StartWritingColumns(column_metadata); for (auto &state : write_data.states) { // get the current position of the table data writer auto &data_writer = writer.GetPayloadWriter(); - auto pointer = data_writer.GetMetaBlockPointer(); + auto pointer = writer.GetMetaBlockPointer(); // store the stats and the data pointers in the row group pointers row_group_pointer.data_pointers.push_back(pointer); + metadata_blocks.insert(pointer.block_pointer); // Write pointers to the column segments. // @@ -1059,11 +1119,43 @@ RowGroupPointer RowGroup::Checkpoint(RowGroupWriteData write_data, RowGroupWrite persistent_data.Serialize(serializer); serializer.End(); } - row_group_pointer.deletes_pointers = CheckpointDeletes(writer.GetPayloadWriter().GetManager()); + writer.FinishWritingColumns(); + + row_group_pointer.has_metadata_blocks = true; + for (auto &column_pointer : column_metadata) { + auto entry = metadata_blocks.find(column_pointer.block_pointer); + if (entry != metadata_blocks.end()) { + // this metadata block is already stored in "data_pointers" - no need to duplicate it + continue; + } + // this metadata block is not stored - add it to the extra metadata blocks + row_group_pointer.extra_metadata_blocks.push_back(column_pointer.block_pointer); + } + if (metadata_manager) { + row_group_pointer.deletes_pointers = CheckpointDeletes(*metadata_manager); + } Verify(); return row_group_pointer; } +bool RowGroup::HasChanges() const { + if (version_info.load()) { + // we have deletes + return true; + } + // check if any of the columns have changes + // avoid loading unloaded columns - unloaded columns can never have changes + for (idx_t c = 0; c < columns.size(); c++) { + if (is_loaded && !is_loaded[c]) { + continue; + } + if (columns[c]->HasAnyChanges()) { + return true; + } + } + return false; +} + bool RowGroup::IsPersistent() const { for (auto &column : columns) { if (!column->IsPersistent()) { @@ -1105,6 +1197,10 @@ void RowGroup::Serialize(RowGroupPointer &pointer, Serializer &serializer) { serializer.WriteProperty(101, "tuple_count", pointer.tuple_count); serializer.WriteProperty(102, "data_pointers", pointer.data_pointers); serializer.WriteProperty(103, "delete_pointers", pointer.deletes_pointers); + if (serializer.ShouldSerialize(6)) { + serializer.WriteProperty(104, "has_metadata_blocks", pointer.has_metadata_blocks); + serializer.WritePropertyWithDefault(105, "extra_metadata_blocks", pointer.extra_metadata_blocks); + } } RowGroupPointer RowGroup::Deserialize(Deserializer &deserializer) { @@ -1113,6 +1209,8 @@ RowGroupPointer RowGroup::Deserialize(Deserializer &deserializer) { result.tuple_count = deserializer.ReadProperty(101, "tuple_count"); result.data_pointers = deserializer.ReadProperty>(102, "data_pointers"); result.deletes_pointers = deserializer.ReadProperty>(103, "delete_pointers"); + result.has_metadata_blocks = deserializer.ReadPropertyWithExplicitDefault(104, "has_metadata_blocks", false); + result.extra_metadata_blocks = deserializer.ReadPropertyWithDefault>(105, 
"extra_metadata_blocks"); return result; } diff --git a/src/duckdb/src/storage/table/row_group_collection.cpp b/src/duckdb/src/storage/table/row_group_collection.cpp index 57287f3b7..ae3531efc 100644 --- a/src/duckdb/src/storage/table/row_group_collection.cpp +++ b/src/duckdb/src/storage/table/row_group_collection.cpp @@ -1,4 +1,5 @@ #include "duckdb/storage/table/row_group_collection.hpp" + #include "duckdb/common/serializer/binary_deserializer.hpp" #include "duckdb/execution/expression_executor.hpp" #include "duckdb/execution/index/bound_index.hpp" @@ -15,6 +16,7 @@ #include "duckdb/storage/table/row_group_segment_tree.hpp" #include "duckdb/storage/table/scan_state.hpp" #include "duckdb/storage/table_storage_info.hpp" +#include "duckdb/main/settings.hpp" namespace duckdb { @@ -33,6 +35,7 @@ void RowGroupSegmentTree::Initialize(PersistentTableData &data) { max_row_group = data.row_group_count; finished_loading = false; reader = make_uniq(collection.GetMetadataManager(), data.block_pointer); + root_pointer = data.block_pointer; } unique_ptr RowGroupSegmentTree::LoadSegment() { @@ -95,6 +98,7 @@ void RowGroupCollection::Initialize(PersistentTableData &data) { this->total_rows = data.total_rows; row_groups->Initialize(data); stats.Initialize(types, data); + metadata_pointer = data.base_table_pointer; } void RowGroupCollection::Initialize(PersistentCollectionData &data) { @@ -940,9 +944,10 @@ class VacuumTask : public BaseCheckpointTask { void RowGroupCollection::InitializeVacuumState(CollectionCheckpointState &checkpoint_state, VacuumState &state, vector> &segments) { - bool is_full_checkpoint = checkpoint_state.writer.GetCheckpointType() == CheckpointType::FULL_CHECKPOINT; + auto checkpoint_type = checkpoint_state.writer.GetCheckpointType(); + bool vacuum_is_allowed = checkpoint_type != CheckpointType::CONCURRENT_CHECKPOINT; // currently we can only vacuum deletes if we are doing a full checkpoint and there are no indexes - state.can_vacuum_deletes = info->GetIndexes().Empty() && is_full_checkpoint; + state.can_vacuum_deletes = info->GetIndexes().Empty() && vacuum_is_allowed; if (!state.can_vacuum_deletes) { return; } @@ -1048,12 +1053,11 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl try { // schedule tasks idx_t total_vacuum_tasks = 0; - auto &config = DBConfig::GetConfig(writer.GetDatabase()); - + auto max_vacuum_tasks = DBConfig::GetSetting(writer.GetDatabase()); for (idx_t segment_idx = 0; segment_idx < segments.size(); segment_idx++) { auto &entry = segments[segment_idx]; - auto vacuum_tasks = ScheduleVacuumTasks(checkpoint_state, vacuum_state, segment_idx, - total_vacuum_tasks < config.options.max_vacuum_tasks); + auto vacuum_tasks = + ScheduleVacuumTasks(checkpoint_state, vacuum_state, segment_idx, total_vacuum_tasks < max_vacuum_tasks); if (vacuum_tasks) { // vacuum tasks were scheduled - don't schedule a checkpoint task yet total_vacuum_tasks++; @@ -1065,8 +1069,10 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &gl } // schedule a checkpoint task for this row group entry.node->MoveToCollection(*this, vacuum_state.row_start); - auto checkpoint_task = GetCheckpointTask(checkpoint_state, segment_idx); - checkpoint_state.executor->ScheduleTask(std::move(checkpoint_task)); + if (writer.GetCheckpointType() != CheckpointType::VACUUM_ONLY) { + auto checkpoint_task = GetCheckpointTask(checkpoint_state, segment_idx); + checkpoint_state.executor->ScheduleTask(std::move(checkpoint_task)); + } vacuum_state.row_start += 
@@ -1048,12 +1053,11 @@ void RowGroupCollection::Checkpoint(TableDataWriter &writer, TableStatistics &global_stats) {
 	try {
 		// schedule tasks
 		idx_t total_vacuum_tasks = 0;
-		auto &config = DBConfig::GetConfig(writer.GetDatabase());
-
+		auto max_vacuum_tasks = DBConfig::GetSetting<MaxVacuumTasksSetting>(writer.GetDatabase());
 		for (idx_t segment_idx = 0; segment_idx < segments.size(); segment_idx++) {
 			auto &entry = segments[segment_idx];
-			auto vacuum_tasks = ScheduleVacuumTasks(checkpoint_state, vacuum_state, segment_idx,
-			                                        total_vacuum_tasks < config.options.max_vacuum_tasks);
+			auto vacuum_tasks =
+			    ScheduleVacuumTasks(checkpoint_state, vacuum_state, segment_idx, total_vacuum_tasks < max_vacuum_tasks);
 			if (vacuum_tasks) {
 				// vacuum tasks were scheduled - don't schedule a checkpoint task yet
 				total_vacuum_tasks++;
@@ -1065,8 +1069,10 @@
 			}
 			// schedule a checkpoint task for this row group
 			entry.node->MoveToCollection(*this, vacuum_state.row_start);
-			auto checkpoint_task = GetCheckpointTask(checkpoint_state, segment_idx);
-			checkpoint_state.executor->ScheduleTask(std::move(checkpoint_task));
+			if (writer.GetCheckpointType() != CheckpointType::VACUUM_ONLY) {
+				auto checkpoint_task = GetCheckpointTask(checkpoint_state, segment_idx);
+				checkpoint_state.executor->ScheduleTask(std::move(checkpoint_task));
+			}
 			vacuum_state.row_start += entry.node->count;
 		}
 	} catch (const std::exception &e) {
@@ -1079,6 +1085,39 @@
 	checkpoint_state.executor->WorkOnTasks();
 
 	// no errors - finalize the row groups
+	// if the table already exists on disk - check if all row groups have stayed the same
+	if (metadata_pointer.IsValid()) {
+		bool table_has_changes = false;
+		for (idx_t segment_idx = 0; segment_idx < segments.size(); segment_idx++) {
+			auto &entry = segments[segment_idx];
+			if (!entry.node) {
+				table_has_changes = true;
+				break;
+			}
+			auto &write_state = checkpoint_state.write_data[segment_idx];
+			if (write_state.existing_pointers.empty()) {
+				table_has_changes = true;
+				break;
+			}
+		}
+		if (!table_has_changes) {
+			// table is unmodified and already exists on disk
+			// we can directly re-use the metadata pointer
+			// mark all blocks associated with row groups as still being in-use
+			auto &metadata_manager = writer.GetMetadataManager();
+			for (idx_t segment_idx = 0; segment_idx < segments.size(); segment_idx++) {
+				auto &entry = segments[segment_idx];
+				auto &row_group = *entry.node;
+				auto &write_state = checkpoint_state.write_data[segment_idx];
+				metadata_manager.ClearModifiedBlocks(write_state.existing_pointers);
+				metadata_manager.ClearModifiedBlocks(row_group.GetDeletesPointers());
+				row_groups->AppendSegment(l, std::move(entry.node));
+			}
+			writer.WriteUnchangedTable(metadata_pointer, total_rows.load());
+			return;
+		}
+	}
+	idx_t new_total_rows = 0;
 	for (idx_t segment_idx = 0; segment_idx < segments.size(); segment_idx++) {
 		auto &entry = segments[segment_idx];
@@ -1087,6 +1126,13 @@
 			continue;
 		}
 		auto &row_group = *entry.node;
+		if (!checkpoint_state.writers[segment_idx]) {
+			// row group was not checkpointed - this can happen if compression is disabled for in-memory tables
+			D_ASSERT(writer.GetCheckpointType() == CheckpointType::VACUUM_ONLY);
+			row_groups->AppendSegment(l, std::move(entry.node));
+			new_total_rows += row_group.count;
+			continue;
+		}
 		auto row_group_writer = std::move(checkpoint_state.writers[segment_idx]);
 		if (!row_group_writer) {
 			throw InternalException("Missing row group writer for index %llu", segment_idx);
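The fast path above generalizes the per-row-group reuse to the whole table: when every row group reported existing pointers, the previous root metadata pointer is written out again via WriteUnchangedTable. A simplified model of that decision, with stand-in types rather than the DuckDB classes:

#include <vector>

struct RowGroupState {
	bool exists = true;                 // entry.node was not vacuumed away
	bool has_existing_pointers = false; // write_data.existing_pointers is non-empty
};

bool TableUnchanged(const std::vector<RowGroupState> &groups, bool on_disk) {
	if (!on_disk) { // no previous metadata pointer to fall back on
		return false;
	}
	for (auto &g : groups) {
		if (!g.exists || !g.has_existing_pointers) {
			return false;
		}
	}
	return true; // caller may re-use the previous root pointer
}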
diff --git a/src/duckdb/src/storage/table/row_version_manager.cpp b/src/duckdb/src/storage/table/row_version_manager.cpp
index a3b47a56e..df4e463da 100644
--- a/src/duckdb/src/storage/table/row_version_manager.cpp
+++ b/src/duckdb/src/storage/table/row_version_manager.cpp
@@ -221,7 +221,7 @@ vector<MetaBlockPointer> RowVersionManager::Checkpoint(MetadataManager &manager)
 		// we can write the current pointer as-is
 		// ensure the blocks we are pointing to are not marked as free
 		manager.ClearModifiedBlocks(storage_pointers);
-		// return the root pointer
+		// return the current set of pointers
 		return storage_pointers;
 	}
 	// first count how many ChunkInfo's we need to deserialize
diff --git a/src/duckdb/src/storage/table/standard_column_data.cpp b/src/duckdb/src/storage/table/standard_column_data.cpp
index 50062a45c..266161a77 100644
--- a/src/duckdb/src/storage/table/standard_column_data.cpp
+++ b/src/duckdb/src/storage/table/standard_column_data.cpp
@@ -154,8 +154,15 @@ idx_t StandardColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) {
 
 void StandardColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids,
                                 idx_t update_count) {
-	ColumnData::Update(transaction, column_index, update_vector, row_ids, update_count);
-	validity.Update(transaction, column_index, update_vector, row_ids, update_count);
+	Vector base_vector(type);
+	auto standard_fetch = FetchUpdateData(row_ids, base_vector);
+	auto validity_fetch = validity.FetchUpdateData(row_ids, base_vector);
+	if (standard_fetch != validity_fetch) {
+		throw InternalException("Unaligned fetch in validity and main column data for update");
+	}
+
+	UpdateInternal(transaction, column_index, update_vector, row_ids, update_count, base_vector);
+	validity.UpdateInternal(transaction, column_index, update_vector, row_ids, update_count, base_vector);
 }
 
 void StandardColumnData::UpdateColumn(TransactionData transaction, const vector<column_t> &column_path,
@@ -271,6 +278,10 @@ bool StandardColumnData::IsPersistent() {
 	return ColumnData::IsPersistent() && validity.IsPersistent();
 }
 
+bool StandardColumnData::HasAnyChanges() const {
+	return ColumnData::HasAnyChanges() || validity.HasAnyChanges();
+}
+
 PersistentColumnData StandardColumnData::Serialize() {
 	auto persistent_data = ColumnData::Serialize();
 	persistent_data.child_columns.push_back(validity.Serialize());
diff --git a/src/duckdb/src/storage/table/struct_column_data.cpp b/src/duckdb/src/storage/table/struct_column_data.cpp
index 2e8f94954..a19f94707 100644
--- a/src/duckdb/src/storage/table/struct_column_data.cpp
+++ b/src/duckdb/src/storage/table/struct_column_data.cpp
@@ -328,6 +328,18 @@ bool StructColumnData::IsPersistent() {
 	return true;
 }
 
+bool StructColumnData::HasAnyChanges() const {
+	if (validity.HasAnyChanges()) {
+		return true;
+	}
+	for (auto &child_col : sub_columns) {
+		if (child_col->HasAnyChanges()) {
+			return true;
+		}
+	}
+	return false;
+}
+
 PersistentColumnData StructColumnData::Serialize() {
 	PersistentColumnData persistent_data(PhysicalType::STRUCT);
 	persistent_data.child_columns.push_back(validity.Serialize());
diff --git a/src/duckdb/src/storage/table/update_segment.cpp b/src/duckdb/src/storage/table/update_segment.cpp
index 39463ba89..8056907bc 100644
--- a/src/duckdb/src/storage/table/update_segment.cpp
+++ b/src/duckdb/src/storage/table/update_segment.cpp
@@ -575,10 +575,8 @@ void UpdateSegment::RollbackUpdate(UpdateInfo &info) {
 // Cleanup Update
 //===--------------------------------------------------------------------===//
 void UpdateSegment::CleanupUpdateInternal(const StorageLockKey &lock, UpdateInfo &info) {
-	D_ASSERT(info.HasPrev());
-	auto prev = info.prev;
-	{
-		auto pin = prev.Pin();
+	if (info.HasPrev()) {
+		auto pin = info.prev.Pin();
 		auto &prev_info = UpdateInfo::Get(pin);
 		prev_info.next = info.next;
 	}
@@ -586,7 +584,7 @@
 		auto next = info.next;
 		auto next_pin = next.Pin();
 		auto &next_info = UpdateInfo::Get(next_pin);
-		next_info.prev = prev;
+		next_info.prev = info.prev;
 	}
 }
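The CleanupUpdateInternal change above replaces an assertion with a conditional, so unlinking now tolerates a node at the head of the undo chain that has no predecessor. The same shape as a generic doubly-linked unlink, sketched with plain pointers instead of pinned buffer handles:

struct Node {
	Node *prev = nullptr;
	Node *next = nullptr;
};

void Unlink(Node &n) {
	if (n.prev) { // head of the chain has no predecessor - skip, don't assert
		n.prev->next = n.next;
	}
	if (n.next) {
		n.next->prev = n.prev;
	}
}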
diff --git a/src/duckdb/src/storage/table_index_list.cpp b/src/duckdb/src/storage/table_index_list.cpp
index e75cdea24..9292073db 100644
--- a/src/duckdb/src/storage/table_index_list.cpp
+++ b/src/duckdb/src/storage/table_index_list.cpp
@@ -12,41 +12,50 @@
 namespace duckdb {
 
+IndexEntry::IndexEntry(unique_ptr<Index> index_p) : index(std::move(index_p)) {
+	if (index->IsBound()) {
+		bind_state = IndexBindState::BOUND;
+	} else {
+		bind_state = IndexBindState::UNBOUND;
+	}
+}
+
 void TableIndexList::AddIndex(unique_ptr<Index> index) {
 	D_ASSERT(index);
-	lock_guard<mutex> lock(indexes_lock);
-	indexes.push_back(std::move(index));
+	lock_guard<mutex> lock(index_entries_lock);
+	auto index_entry = make_uniq<IndexEntry>(std::move(index));
+	index_entries.push_back(std::move(index_entry));
 }
 
 void TableIndexList::RemoveIndex(const string &name) {
-	lock_guard<mutex> lock(indexes_lock);
-
-	for (idx_t i = 0; i < indexes.size(); i++) {
-		auto &index = indexes[i];
-		if (index->GetIndexName() == name) {
-			indexes.erase_at(i);
-			break;
+	lock_guard<mutex> lock(index_entries_lock);
+	for (idx_t i = 0; i < index_entries.size(); i++) {
+		auto &index = *index_entries[i]->index;
+		if (index.GetIndexName() == name) {
+			index_entries.erase_at(i);
+			return;
 		}
 	}
 }
 
 void TableIndexList::CommitDrop(const string &name) {
-	lock_guard<mutex> lock(indexes_lock);
-
-	for (auto &index : indexes) {
-		if (index->GetIndexName() == name) {
-			index->CommitDrop();
+	lock_guard<mutex> lock(index_entries_lock);
+	for (auto &entry : index_entries) {
+		auto &index = *entry->index;
+		if (index.GetIndexName() == name) {
+			index.CommitDrop();
+			return;
 		}
 	}
 }
 
 bool TableIndexList::NameIsUnique(const string &name) {
-	lock_guard<mutex> lock(indexes_lock);
-	// Only covers PK, FK, and UNIQUE indexes.
-	for (const auto &index : indexes) {
-		if (index->IsPrimary() || index->IsForeign() || index->IsUnique()) {
-			if (index->GetIndexName() == name) {
+	lock_guard<mutex> lock(index_entries_lock);
+	for (auto &entry : index_entries) {
+		auto &index = *entry->index;
+		if (index.IsPrimary() || index.IsForeign() || index.IsUnique()) {
+			if (index.GetIndexName() == name) {
 				return false;
 			}
 		}
@@ -55,21 +64,26 @@ bool TableIndexList::NameIsUnique(const string &name) {
 }
 
 optional_ptr<BoundIndex> TableIndexList::Find(const string &name) {
-	for (auto &index : indexes) {
-		if (index->GetIndexName() == name) {
-			return index->Cast<BoundIndex>();
+	for (auto &entry : index_entries) {
+		auto &index = *entry->index;
+		if (index.GetIndexName() == name) {
+			if (!index.IsBound()) {
+				throw InternalException("cannot return an unbound index in TableIndexList::Find");
+			}
+			return index.Cast<BoundIndex>();
 		}
 	}
 	return nullptr;
 }
 
-void TableIndexList::InitializeIndexes(ClientContext &context, DataTableInfo &table_info, const char *index_type) {
-	// Fast path: do we have any unbound indexes?
+void TableIndexList::Bind(ClientContext &context, DataTableInfo &table_info, const char *index_type) {
+	// Early-out if we have no unbound indexes.
 	bool needs_binding = false;
 	{
-		lock_guard<mutex> lock(indexes_lock);
-		for (auto &index : indexes) {
-			if (!index->IsBound() && (index_type == nullptr || index->GetIndexType() == index_type)) {
+		lock_guard<mutex> lock(index_entries_lock);
+		for (auto &entry : index_entries) {
+			auto &index = *entry->index;
+			if (!index.IsBound() && (index_type == nullptr || index.GetIndexType() == index_type)) {
 				needs_binding = true;
 				break;
 			}
@@ -93,39 +107,74 @@
 		column_names.push_back(col.Name());
 	}
 
-	lock_guard<mutex> lock(indexes_lock);
-	for (auto &index : indexes) {
-		if (!index->IsBound() && (index_type == nullptr || index->GetIndexType() == index_type)) {
-			// Create a binder to bind this index.
-			auto binder = Binder::CreateBinder(context);
+	unique_lock<mutex> lock(index_entries_lock);
+	// Busy-spin trying to bind all indexes.
+	while (true) {
+		optional_ptr<IndexEntry> index_entry;
+		for (auto &entry : index_entries) {
+			auto &index = *entry->index;
+			if (!index.IsBound() && (index_type == nullptr || index.GetIndexType() == index_type)) {
+				index_entry = entry.get();
+				break;
+			}
+		}
+		if (!index_entry) {
+			// We bound all indexes.
+			break;
+		}
+		if (index_entry->bind_state == IndexBindState::BINDING) {
+			// Another thread is binding the index.
+			// Lock and unlock the index entries so that the other thread can commit its changes.
+			lock.unlock();
+			lock.lock();
+			continue;
+
+		} else if (index_entry->bind_state == IndexBindState::UNBOUND) {
+			// We are the thread that'll bind the index.
+			index_entry->bind_state = IndexBindState::BINDING;
+			lock.unlock();
+
+		} else {
+			throw InternalException("index entry bind state cannot be BOUND here");
+		}
 
-			// Add the table to the binder.
-			vector<ColumnIndex> dummy_column_ids;
-			binder->bind_context.AddBaseTable(0, string(), column_names, column_types, dummy_column_ids, table);
+		// Create a binder to bind this index.
+		auto binder = Binder::CreateBinder(context);
 
-			// Create an IndexBinder to bind the index
-			IndexBinder idx_binder(*binder, context);
+		// Add the table to the binder.
+		vector<ColumnIndex> dummy_column_ids;
+		binder->bind_context.AddBaseTable(0, string(), column_names, column_types, dummy_column_ids, table);
 
-			// Replace the unbound index with a bound index.
-			auto bound_idx = idx_binder.BindIndex(index->Cast<UnboundIndex>());
-			index = std::move(bound_idx);
+		// Create an IndexBinder to bind the index
+		IndexBinder idx_binder(*binder, context);
+
+		// Apply any outstanding appends and replace the unbound index with a bound index.
+		auto &unbound_index = index_entry->index->Cast<UnboundIndex>();
+		auto bound_idx = idx_binder.BindIndex(unbound_index);
+		if (unbound_index.HasBufferedAppends()) {
+			bound_idx->ApplyBufferedAppends(unbound_index.GetBufferedAppends());
 		}
+
+		// Commit the bound index to the index entry.
+		lock.lock();
+		index_entry->bind_state = IndexBindState::BOUND;
+		index_entry->index = std::move(bound_idx);
 	}
 }
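The Bind loop above is a claim-and-publish protocol: a thread claims an UNBOUND entry by moving it to BINDING, drops the lock while it performs the expensive bind, then retakes the lock to publish BOUND; other threads cycle the lock until the owner commits. A condensed sketch with stand-in types rather than the DuckDB classes:

#include <mutex>

enum class BindState { UNBOUND, BINDING, BOUND };

struct Entry {
	BindState state = BindState::UNBOUND;
};

void EnsureBound(Entry &entry, std::mutex &m) {
	std::unique_lock<std::mutex> lock(m);
	while (entry.state != BindState::BOUND) {
		if (entry.state == BindState::BINDING) {
			// another thread owns the bind - yield the lock so it can publish
			lock.unlock();
			lock.lock();
			continue;
		}
		entry.state = BindState::BINDING; // claim the entry under the lock
		lock.unlock();
		// ... expensive bind work happens here, outside the lock ...
		lock.lock();
		entry.state = BindState::BOUND; // publish the result under the lock
	}
}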
 
 bool TableIndexList::Empty() {
-	lock_guard<mutex> lock(indexes_lock);
-	return indexes.empty();
+	lock_guard<mutex> lock(index_entries_lock);
+	return index_entries.empty();
 }
 
 idx_t TableIndexList::Count() {
-	lock_guard<mutex> lock(indexes_lock);
-	return indexes.size();
+	lock_guard<mutex> lock(index_entries_lock);
+	return index_entries.size();
 }
 
 void TableIndexList::Move(TableIndexList &other) {
-	D_ASSERT(indexes.empty());
-	indexes = std::move(other.indexes);
+	D_ASSERT(index_entries.empty());
+	index_entries = std::move(other.index_entries);
 }
 
 bool IsForeignKeyIndex(const vector<PhysicalIndex> &fk_keys, Index &index, ForeignKeyType fk_type) {
@@ -154,9 +203,10 @@ bool IsForeignKeyIndex(const vector<PhysicalIndex> &fk_keys, Index &index, ForeignKeyType fk_type) {
 
 optional_ptr<Index> TableIndexList::FindForeignKeyIndex(const vector<PhysicalIndex> &fk_keys,
                                                         const ForeignKeyType fk_type) {
-	for (auto &index_elem : indexes) {
-		if (IsForeignKeyIndex(fk_keys, *index_elem, fk_type)) {
-			return index_elem;
+	for (auto &entry : index_entries) {
+		auto &index = *entry->index;
+		if (IsForeignKeyIndex(fk_keys, index, fk_type)) {
+			return index;
 		}
 	}
 	return nullptr;
@@ -182,10 +232,11 @@ void TableIndexList::VerifyForeignKey(optional_ptr<LocalTableStorage> storage, c
 }
 
 unordered_set<column_t> TableIndexList::GetRequiredColumns() {
-	lock_guard<mutex> lock(indexes_lock);
+	lock_guard<mutex> lock(index_entries_lock);
 	unordered_set<column_t> column_ids;
-	for (auto &index : indexes) {
-		for (auto col_id : index->GetColumnIds()) {
+	for (auto &entry : index_entries) {
+		auto &index = *entry->index;
+		for (auto col_id : index.GetColumnIds()) {
 			column_ids.insert(col_id);
 		}
 	}
@@ -195,15 +246,16 @@ unordered_set<column_t> TableIndexList::GetRequiredColumns() {
 
 vector<IndexStorageInfo> TableIndexList::SerializeToDisk(QueryContext context,
                                                          const case_insensitive_map_t<Value> &options) {
 	vector<IndexStorageInfo> infos;
-	for (auto &index : indexes) {
-		if (index->IsBound()) {
-			auto info = index->Cast<BoundIndex>().SerializeToDisk(context, options);
+	for (auto &entry : index_entries) {
+		auto &index = *entry->index;
+		if (index.IsBound()) {
+			auto info = index.Cast<BoundIndex>().SerializeToDisk(context, options);
 			D_ASSERT(info.IsValid() && !info.name.empty());
 			infos.push_back(info);
 			continue;
 		}
-		auto info = index->Cast<UnboundIndex>().GetStorageInfo();
+		auto info = index.Cast<UnboundIndex>().GetStorageInfo();
 		D_ASSERT(!info.name.empty());
 		infos.push_back(info);
 	}
diff --git a/src/duckdb/src/storage/temporary_file_manager.cpp b/src/duckdb/src/storage/temporary_file_manager.cpp
index 925e87e8e..b8ab5a7b0 100644
--- a/src/duckdb/src/storage/temporary_file_manager.cpp
+++ b/src/duckdb/src/storage/temporary_file_manager.cpp
@@ -210,7 +210,8 @@ TemporaryFileIndex TemporaryFileHandle::TryGetBlockIndex(idx_t block_header_size
 	return TemporaryFileIndex(identifier, block_index, block_header_size);
 }
 
-unique_ptr<FileBuffer> TemporaryFileHandle::ReadTemporaryBuffer(const TemporaryFileIndex &index_in_file,
+unique_ptr<FileBuffer> TemporaryFileHandle::ReadTemporaryBuffer(QueryContext context,
+                                                                const TemporaryFileIndex &index_in_file,
                                                                 unique_ptr<FileBuffer> reusable_buffer) const {
 	auto &buffer_manager = BufferManager::GetBufferManager(db);
 	auto block_index = index_in_file.block_index.GetIndex();
@@ -237,13 +238,13 @@
 	if (IsEncrypted()) {
 		uint8_t encryption_metadata[DEFAULT_ENCRYPTED_BUFFER_HEADER_SIZE];
 		//! Read nonce and tag.
-		handle->Read(encryption_metadata, DEFAULT_ENCRYPTED_BUFFER_HEADER_SIZE, read_position);
+		handle->Read(context, encryption_metadata, DEFAULT_ENCRYPTED_BUFFER_HEADER_SIZE, read_position);
 		//! Read the encrypted compressed buffer.
-		handle->Read(read_buffer, read_size, read_position + DEFAULT_ENCRYPTED_BUFFER_HEADER_SIZE);
+		handle->Read(context, read_buffer, read_size, read_position + DEFAULT_ENCRYPTED_BUFFER_HEADER_SIZE);
 		//! Decrypt the compressed buffer.
 		EncryptionEngine::DecryptTemporaryBuffer(db, read_buffer, read_size, encryption_metadata);
 	} else {
-		handle->Read(read_buffer, read_size, read_position);
+		handle->Read(context, read_buffer, read_size, read_position);
 	}
 
 	if (is_uncompressed) {
@@ -496,7 +497,7 @@ TemporaryFileManager::~TemporaryFileManager() {
 TemporaryFileManager::TemporaryFileManagerLock::TemporaryFileManagerLock(mutex &mutex) : lock(mutex) {
 }
 
-void TemporaryFileManager::WriteTemporaryBuffer(block_id_t block_id, FileBuffer &buffer) {
+idx_t TemporaryFileManager::WriteTemporaryBuffer(block_id_t block_id, FileBuffer &buffer) {
 	// We group DEFAULT_BLOCK_ALLOC_SIZE blocks into the same file.
 	D_ASSERT(buffer.AllocSize() == BufferManager::GetBufferManager(db).GetBlockAllocSize());
@@ -539,6 +540,7 @@
 	handle->WriteTemporaryBuffer(buffer, index.block_index.GetIndex(), compressed_buffer);
 
 	compression_adaptivity.Update(compression_result.level, time_before_ns);
+	return static_cast<idx_t>(compression_result.size);
 }
 
 TemporaryFileManager::CompressionResult
@@ -643,7 +645,7 @@ bool TemporaryFileManager::IsEncrypted() const {
 	return db.config.options.temp_file_encryption;
 }
 
-unique_ptr<FileBuffer> TemporaryFileManager::ReadTemporaryBuffer(block_id_t id,
+unique_ptr<FileBuffer> TemporaryFileManager::ReadTemporaryBuffer(QueryContext context, block_id_t id,
                                                                  unique_ptr<FileBuffer> reusable_buffer) {
 	TemporaryFileIndex index;
 	optional_ptr<TemporaryFileHandle> handle;
@@ -654,7 +656,7 @@
 	}
 
 	// before the reusable buffer is given,
-	auto buffer = handle->ReadTemporaryBuffer(index, std::move(reusable_buffer));
+	auto buffer = handle->ReadTemporaryBuffer(context, index, std::move(reusable_buffer));
 	{
 		// remove the block (and potentially erase the temp file)
 		TemporaryFileManagerLock lock(manager_lock);
@@ -663,11 +665,12 @@
 	return buffer;
 }
 
-void TemporaryFileManager::DeleteTemporaryBuffer(block_id_t id) {
+idx_t TemporaryFileManager::DeleteTemporaryBuffer(block_id_t id) {
 	TemporaryFileManagerLock lock(manager_lock);
 	auto index = GetTempBlockIndex(lock, id);
 	auto handle = GetFileHandle(lock, index.identifier);
 	EraseUsedBlock(lock, id, *handle, index);
+	return static_cast<idx_t>(index.identifier.size);
 }
 
 vector<TemporaryFileInformation> TemporaryFileManager::GetTemporaryFiles() {
diff --git a/src/duckdb/src/transaction/duck_transaction.cpp b/src/duckdb/src/transaction/duck_transaction.cpp
index cb1b46b32..f27a43c83 100644
--- a/src/duckdb/src/transaction/duck_transaction.cpp
+++ b/src/duckdb/src/transaction/duck_transaction.cpp
@@ -161,7 +161,9 @@ bool DuckTransaction::ChangesMade() {
 }
 
 UndoBufferProperties DuckTransaction::GetUndoProperties() {
-	return undo_buffer.GetProperties();
+	auto properties = undo_buffer.GetProperties();
+	properties.estimated_size += storage->EstimatedSize();
+	return properties;
 }
 
 bool DuckTransaction::AutomaticCheckpoint(AttachedDatabase &db, const UndoBufferProperties &properties) {
@@ -176,7 +178,7 @@
 		return false;
 	}
 	auto &storage_manager = db.GetStorageManager();
-	return storage_manager.AutomaticCheckpoint(storage->EstimatedSize() + properties.estimated_size);
+	return storage_manager.AutomaticCheckpoint(properties.estimated_size);
 }
 
 bool DuckTransaction::ShouldWriteToWAL(AttachedDatabase &db) {
diff --git a/src/duckdb/src/transaction/duck_transaction_manager.cpp b/src/duckdb/src/transaction/duck_transaction_manager.cpp
index c25bc903f..50c4db6b4 100644
--- a/src/duckdb/src/transaction/duck_transaction_manager.cpp
+++ b/src/duckdb/src/transaction/duck_transaction_manager.cpp
@@ -14,6 +14,7 @@
 #include "duckdb/main/attached_database.hpp"
 #include "duckdb/main/database_manager.hpp"
 #include "duckdb/transaction/meta_transaction.hpp"
+#include "duckdb/main/settings.hpp"
 
 namespace duckdb {
 
@@ -102,18 +103,17 @@ DuckTransactionManager::CanCheckpoint(DuckTransaction &transaction, unique_ptr(db.GetDatabase())) {
 		return CheckpointDecision("checkpointing on commit disabled through configuration");
 	}
 	// try to lock the checkpoint lock
@@ -156,15 +156,17 @@
DuckTransactionManager::CanCheckpoint(DuckTransaction &transaction, unique_ptr MetaTransaction::TryGetTransaction(AttachedDatabase &d
 	if (entry == transactions.end()) {
 		return nullptr;
 	} else {
-		return &entry->second.get();
+		return &entry->second.transaction;
 	}
 }
@@ -63,12 +63,12 @@ Transaction &MetaTransaction::GetTransaction(AttachedDatabase &db) {
 		VerifyAllTransactionsUnique(db, all_transactions);
 #endif
 		all_transactions.push_back(db);
-		transactions.insert(make_pair(reference<AttachedDatabase>(db), reference<Transaction>(new_transaction)));
+		transactions.insert(make_pair(reference<AttachedDatabase>(db), TransactionReference(new_transaction)));
 		return new_transaction;
 	} else {
-		D_ASSERT(entry->second.get().active_query == active_query);
-		return entry->second;
+		D_ASSERT(entry->second.transaction.active_query == active_query);
+		return entry->second.transaction;
 	}
 }
@@ -115,34 +115,60 @@ ErrorData MetaTransaction::Commit() {
 		if (entry == transactions.end()) {
 			throw InternalException("Could not find transaction corresponding to database in MetaTransaction");
 		}
+
 #ifdef DEBUG
 		auto already_committed = committed_tx.insert(db).second == false;
 		if (already_committed) {
 			throw InternalException("All databases inside all_transactions should be unique, invariant broken!");
 		}
 #endif
+
 		auto &transaction_manager = db.GetTransactionManager();
-		auto &transaction = entry->second.get();
-		if (!error.HasError()) {
-			// commit
-			error = transaction_manager.CommitTransaction(context, transaction);
-		} else {
-			// we have encountered an error previously - roll back subsequent entries
-			transaction_manager.RollbackTransaction(transaction);
+		auto &transaction_ref = entry->second;
+		if (transaction_ref.state != TransactionState::UNCOMMITTED) {
+			continue;
+		}
+		auto &transaction = transaction_ref.transaction;
+		try {
+			if (!error.HasError()) {
+				// Commit the transaction.
+				error = transaction_manager.CommitTransaction(context, transaction);
+				transaction_ref.state = error.HasError() ? TransactionState::ROLLED_BACK : TransactionState::COMMITTED;
+			} else {
+				// Rollback due to previous error.
+				transaction_manager.RollbackTransaction(transaction);
+				transaction_ref.state = TransactionState::ROLLED_BACK;
+			}
+		} catch (std::exception &ex) {
+			error.Merge(ErrorData(ex));
+			transaction_ref.state = TransactionState::ROLLED_BACK;
		}
 	}
 	return error;
 }
 
 void MetaTransaction::Rollback() {
-	// rollback transactions in reverse order
+	// Rollback all transactions in reverse order.
+	ErrorData error;
 	for (idx_t i = all_transactions.size(); i > 0; i--) {
 		auto &db = all_transactions[i - 1].get();
 		auto &transaction_manager = db.GetTransactionManager();
 		auto entry = transactions.find(db);
 		D_ASSERT(entry != transactions.end());
-		auto &transaction = entry->second.get();
-		transaction_manager.RollbackTransaction(transaction);
+		auto &transaction_ref = entry->second;
+		if (transaction_ref.state != TransactionState::UNCOMMITTED) {
+			continue;
+		}
+		try {
+			auto &transaction = transaction_ref.transaction;
+			transaction_manager.RollbackTransaction(transaction);
+		} catch (std::exception &ex) {
+			error.Merge(ErrorData(ex));
+		}
+		transaction_ref.state = TransactionState::ROLLED_BACK;
+	}
+	if (error.HasError()) {
+		error.Throw();
 	}
 }
@@ -153,24 +179,25 @@ idx_t MetaTransaction::GetActiveQuery() {
 void MetaTransaction::SetActiveQuery(transaction_t query_number) {
 	active_query = query_number;
 	for (auto &entry : transactions) {
-		entry.second.get().active_query = query_number;
+		entry.second.transaction.active_query = query_number;
 	}
 }
 
 void MetaTransaction::ModifyDatabase(AttachedDatabase &db) {
-	if (db.IsSystem() || db.IsTemporary()) {
-		// we can always modify the system and temp databases
-		return;
-	}
 	if (IsReadOnly()) {
 		throw TransactionException("Cannot write to database \"%s\" - transaction is launched in read-only mode",
 		                           db.GetName());
 	}
+	auto &transaction = GetTransaction(db);
+	if (transaction.IsReadOnly()) {
+		transaction.SetReadWrite();
+	}
+	if (db.IsSystem() || db.IsTemporary()) {
+		// we can always modify the system and temp databases
+		return;
+	}
 	if (!modified_database) {
 		modified_database = &db;
-
-		auto &transaction = GetTransaction(db);
-		transaction.SetReadWrite();
 		return;
 	}
 	if (&db != modified_database.get()) {
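The Commit and Rollback loops above attach a state to every per-database transaction so that each one is finalized exactly once, even when an exception escapes midway. A reduced model of that bookkeeping, with stand-in types rather than the DuckDB API:

#include <vector>

enum class TxState { UNCOMMITTED, COMMITTED, ROLLED_BACK };

struct Tx {
	TxState state = TxState::UNCOMMITTED;
};

void RollbackAll(std::vector<Tx> &txs) {
	// reverse order, mirroring MetaTransaction::Rollback
	for (auto it = txs.rbegin(); it != txs.rend(); ++it) {
		if (it->state != TxState::UNCOMMITTED) {
			continue; // already finalized - never commit or roll back twice
		}
		// ... the real code rolls back here and collects any errors ...
		it->state = TxState::ROLLED_BACK;
	}
}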
"duckdb/verification/unoptimized_statement_verifier.hpp" #include "duckdb/verification/no_operator_caching_verifier.hpp" #include "duckdb/verification/fetch_row_verifier.hpp" +#include "duckdb/verification/explain_statement_verifier.hpp" namespace duckdb { StatementVerifier::StatementVerifier(VerificationType type, string name, unique_ptr statement_p, optional_ptr> parameters_p) - : type(type), name(std::move(name)), - statement(unique_ptr_cast(std::move(statement_p))), parameters(parameters_p), - select_list(statement->node->GetSelectList()) { + : type(type), name(std::move(name)), statement(std::move(statement_p)), + select_statement(statement->type == StatementType::SELECT_STATEMENT ? &statement->Cast() + : nullptr), + parameters(parameters_p), + select_list(select_statement ? select_statement->node->GetSelectList() : empty_select_list) { } StatementVerifier::StatementVerifier(unique_ptr statement_p, @@ -47,6 +50,8 @@ StatementVerifier::Create(VerificationType type, const SQLStatement &statement_p return PreparedStatementVerifier::Create(statement_p, parameters); case VerificationType::EXTERNAL: return ExternalStatementVerifier::Create(statement_p, parameters); + case VerificationType::EXPLAIN: + return ExplainStatementVerifier::Create(statement_p, parameters); case VerificationType::FETCH_ROW_AS_SCAN: return FetchRowVerifier::Create(statement_p, parameters); case VerificationType::INVALID: @@ -60,8 +65,8 @@ void StatementVerifier::CheckExpressions(const StatementVerifier &other) const { D_ASSERT(type == VerificationType::ORIGINAL); // Check equality - if (other.RequireEquality()) { - D_ASSERT(statement->Equals(*other.statement)); + if (other.RequireEquality() && select_statement) { + D_ASSERT(select_statement->Equals(*other.select_statement)); } #ifdef DEBUG diff --git a/src/duckdb/third_party/brotli/common/shared_dictionary.cpp b/src/duckdb/third_party/brotli/common/shared_dictionary.cpp index 3b951c989..f74774586 100644 --- a/src/duckdb/third_party/brotli/common/shared_dictionary.cpp +++ b/src/duckdb/third_party/brotli/common/shared_dictionary.cpp @@ -60,9 +60,9 @@ static BROTLI_BOOL ReadUint16(const uint8_t* encoded, size_t size, size_t* pos, return BROTLI_TRUE; } -/* Reads a varint into a uint32_t, and returns error if it's too large */ +/* Reads a bignum into a uint32_t, and returns error if it's too large */ /* Returns BROTLI_TRUE on success, BROTLI_FALSE on failure. 
diff --git a/src/duckdb/third_party/brotli/common/shared_dictionary.cpp b/src/duckdb/third_party/brotli/common/shared_dictionary.cpp
index 3b951c989..f74774586 100644
--- a/src/duckdb/third_party/brotli/common/shared_dictionary.cpp
+++ b/src/duckdb/third_party/brotli/common/shared_dictionary.cpp
@@ -60,9 +60,9 @@ static BROTLI_BOOL ReadUint16(const uint8_t* encoded, size_t size, size_t* pos,
   return BROTLI_TRUE;
 }
 
 /* Reads a varint into a uint32_t, and returns error if it's too large */
 /* Returns BROTLI_TRUE on success, BROTLI_FALSE on failure. */
 static BROTLI_BOOL ReadVarint32(const uint8_t* encoded, size_t size,
                                 size_t* pos, uint32_t* result) {
   int num = 0;
   uint8_t byte;
@@ -248,7 +248,7 @@ static BROTLI_BOOL DryParseDictionary(const uint8_t* encoded,
     pos += 2;  /* LZ77_DICTIONARY_LENGTH */
     if (!ReadVarint32(encoded, size, &pos, &chunk_size)) return BROTLI_FALSE;
     if (chunk_size != 0) {
       /* This limitation is not specified but the 32-bit Brotli decoder for now */
       if (chunk_size > 1073741823) return BROTLI_FALSE;
@@ -284,7 +284,7 @@ static BROTLI_BOOL ParseDictionary(const uint8_t* encoded, size_t size,
     pos += 2;  /* LZ77_DICTIONARY_LENGTH */
     if (!ReadVarint32(encoded, size, &pos, &chunk_size)) return BROTLI_FALSE;
     if (chunk_size != 0) {
       if (pos + chunk_size > size) return BROTLI_FALSE;
       dict->prefix_size[dict->num_prefix] = chunk_size;
diff --git a/src/duckdb/third_party/fmt/include/fmt/core.h b/src/duckdb/third_party/fmt/include/fmt/core.h
index fe95196b3..ed555c6c7 100644
--- a/src/duckdb/third_party/fmt/include/fmt/core.h
+++ b/src/duckdb/third_party/fmt/include/fmt/core.h
@@ -8,6 +8,9 @@
 #ifndef FMT_CORE_H_
 #define FMT_CORE_H_
 
+#include "duckdb/common/hugeint.hpp"
+#include "duckdb/common/uhugeint.hpp"
+
 #include <cstdio>  // std::FILE
 #include
 #include
@@ -231,19 +234,16 @@
 using std_string_view = std::experimental::basic_string_view<Char>;
 template <typename T> struct std_string_view {};
 #endif
+using int128_t = duckdb::hugeint_t;
+using uint128_t = duckdb::uhugeint_t;
+
 #ifdef FMT_USE_INT128
 // Do nothing.
 #elif defined(__SIZEOF_INT128__)
 #  define FMT_USE_INT128 1
-using int128_t = __int128_t;
-using uint128_t = __uint128_t;
 #else
 #  define FMT_USE_INT128 0
 #endif
-#if !FMT_USE_INT128
-struct int128_t {};
-struct uint128_t {};
-#endif
 
 // Casts a nonnegative integer to unsigned.
 template <typename Int>
@@ -1004,16 +1004,10 @@ FMT_CONSTEXPR auto visit_format_arg(Visitor&& vis,
     return vis(arg.value_.long_long_value);
   case internal::ulong_long_type:
     return vis(arg.value_.ulong_long_value);
-#if FMT_USE_INT128
   case internal::int128_type:
     return vis(arg.value_.int128_value);
   case internal::uint128_type:
    return vis(arg.value_.uint128_value);
-#else
-  case internal::int128_type:
-  case internal::uint128_type:
-    break;
-#endif
   case internal::bool_type:
     return vis(arg.value_.bool_value);
   case internal::char_type:
diff --git a/src/duckdb/third_party/fmt/include/fmt/format-inl.h b/src/duckdb/third_party/fmt/include/fmt/format-inl.h
index 1755025ae..b92c914b5 100644
--- a/src/duckdb/third_party/fmt/include/fmt/format-inl.h
+++ b/src/duckdb/third_party/fmt/include/fmt/format-inl.h
@@ -629,7 +629,11 @@ class bigint {
     int num_bigits = static_cast<int>(bigits_.size());
     int num_result_bigits = 2 * num_bigits;
     bigits_.resize(num_result_bigits);
-    using accumulator_t = conditional_t<FMT_USE_INT128, uint128_t, accumulator>;
+#if FMT_USE_INT128
+    using accumulator_t = __uint128_t;
+#else
+    using accumulator_t = accumulator;
+#endif
     auto sum = accumulator_t();
     for (int bigit_index = 0; bigit_index < num_bigits; ++bigit_index) {
       // Compute bigit at position bigit_index of the result by adding
diff --git a/src/duckdb/third_party/fmt/include/fmt/format.h b/src/duckdb/third_party/fmt/include/fmt/format.h
index 67190586d..4c5163010 100644
--- a/src/duckdb/third_party/fmt/include/fmt/format.h
+++ b/src/duckdb/third_party/fmt/include/fmt/format.h
@@ -34,6 +34,7 @@
 #define FMT_FORMAT_H_
 
 #include "duckdb/common/exception.hpp"
+#include "duckdb/common/limits.hpp"
 #include "duckdb/original/std/memory.hpp"
 #include "fmt/core.h"
 
@@ -254,9 +255,21 @@ inline fallback_uintptr to_uintptr(const void* p) {
 template <typename T> constexpr T max_value() {
   return (std::numeric_limits<T>::max)();
 }
+template <> constexpr int128_t max_value<int128_t>() {
+  return duckdb::NumericLimits<int128_t>::Maximum();
+}
+template <> constexpr uint128_t max_value<uint128_t>() {
+  return duckdb::NumericLimits<uint128_t>::Maximum();
+}
 template <typename T> constexpr int num_bits() {
   return std::numeric_limits<T>::digits;
 }
+template <> constexpr int num_bits<int128_t>() {
+  return 127;
+}
+template <> constexpr int num_bits<uint128_t>() {
+  return 128;
+}
 template <> constexpr int num_bits<fallback_uintptr>() {
   return static_cast<int>(sizeof(void*) * std::numeric_limits<unsigned char>::digits);
@@ -691,11 +704,11 @@ namespace internal {
 
 // Returns true if value is negative, false otherwise.
 // Same as `value < 0` but doesn't produce warnings if T is an unsigned type.
-template <typename T, FMT_ENABLE_IF(std::numeric_limits<T>::is_signed)>
+template <typename T, FMT_ENABLE_IF(std::numeric_limits<T>::is_signed || std::is_same<T, int128_t>::value)>
 FMT_CONSTEXPR bool is_negative(T value) {
   return value < 0;
 }
-template <typename T, FMT_ENABLE_IF(!std::numeric_limits<T>::is_signed)>
+template <typename T, FMT_ENABLE_IF(!std::numeric_limits<T>::is_signed && !std::is_same<T, int128_t>::value)>
 FMT_CONSTEXPR bool is_negative(T) {
   return false;
 }
@@ -703,9 +716,10 @@ FMT_CONSTEXPR bool is_negative(T) {
 
 // Smallest of uint32_t, uint64_t, uint128_t that is large enough to
 // represent all values of T.
 template <typename T>
-using uint32_or_64_or_128_t = conditional_t<
-    std::numeric_limits<T>::digits <= 32, uint32_t,
-    conditional_t<std::numeric_limits<T>::digits <= 64, uint64_t, uint128_t>>;
+using uint32_or_64_or_128_t =
+    conditional_t<std::is_same<T, int128_t>::value || std::is_same<T, uint128_t>::value, uint128_t,
+                  conditional_t<std::numeric_limits<T>::digits <= 32, uint32_t,
+                                conditional_t<std::numeric_limits<T>::digits <= 64, uint64_t, uint128_t>>>;
 
 // Static data is placed in this class template for the header-only config.
 template <typename T = void> struct FMT_EXTERN_TEMPLATE_API basic_data {
@@ -755,7 +769,6 @@ inline int count_digits(uint64_t n) {
 }
 #endif
 
-#if FMT_USE_INT128
 inline int count_digits(uint128_t n) {
   int count = 1;
   for (;;) {
@@ -770,7 +783,6 @@ inline int count_digits(uint128_t n) {
     count += 4;
   }
 }
-#endif
 
 // Counts the number of digits in n. BITS = log2(radix).
 template <unsigned BITS, typename UInt> inline int count_digits(UInt n) {
@@ -842,7 +854,7 @@ inline Char* format_decimal(Char* buffer, UInt value, int num_digits,
     add_thousands_sep(buffer);
   }
   if (value < 10) {
-    *--buffer = static_cast<Char>('0' + value);
+    *--buffer = static_cast<Char>('0' + static_cast<unsigned>(value));
     return end;
   }
   auto index = static_cast<unsigned>(value * 2);
@@ -881,7 +893,7 @@ inline Char* format_uint(Char* buffer, UInt value, int num_digits,
   Char* end = buffer;
   do {
     const char* digits = upper ? "0123456789ABCDEF" : data::hex_digits;
-    unsigned digit = (value & ((1 << BASE_BITS) - 1));
+    unsigned digit = (static_cast<unsigned>(value) & ((1 << BASE_BITS) - 1));
    *--buffer = static_cast<Char>(BASE_BITS < 4 ? static_cast<char>('0' + digit) : digits[digit]);
   } while ((value >>= BASE_BITS) != 0);
@@ -1455,7 +1467,7 @@ template <typename Range> class basic_writer {
     if (is_negative(value)) {
       prefix[0] = '-';
       ++prefix_size;
-      abs_value = 0 - abs_value;
+      abs_value = -abs_value;
     } else if (specs.sign != sign::none && specs.sign != sign::minus) {
       prefix[0] = specs.sign == sign::plus ? '+' : ' ';
       ++prefix_size;
@@ -1651,10 +1663,8 @@ template <typename Range> class basic_writer {
   void write(unsigned long value) { write_decimal(value); }
   void write(unsigned long long value) { write_decimal(value); }
 
-#if FMT_USE_INT128
   void write(int128_t value) { write_decimal(value); }
   void write(uint128_t value) { write_decimal(value); }
-#endif
 
   template <typename T, typename Spec>
   void write_int(T value, const Spec& spec) {
@@ -1977,13 +1987,13 @@ template <typename ErrorHandler> class width_checker {
   explicit FMT_CONSTEXPR width_checker(ErrorHandler& eh) : handler_(eh) {}
 
   template <typename T, FMT_ENABLE_IF(is_integer<T>::value)>
-  FMT_CONSTEXPR unsigned long long operator()(T value) {
+  FMT_CONSTEXPR uint64_t operator()(T value) {
    if (is_negative(value)) handler_.on_error("negative width");
-    return static_cast<unsigned long long>(value);
+    return static_cast<uint64_t>(value);
  }
 
   template <typename T, FMT_ENABLE_IF(!is_integer<T>::value)>
-  FMT_CONSTEXPR unsigned long long operator()(T) {
+  FMT_CONSTEXPR uint64_t operator()(T) {
     handler_.on_error("width is not integer");
     return 0;
   }
@@ -1997,13 +2007,13 @@ template <typename ErrorHandler> class precision_checker {
   explicit FMT_CONSTEXPR precision_checker(ErrorHandler& eh) : handler_(eh) {}
 
   template <typename T, FMT_ENABLE_IF(is_integer<T>::value)>
-  FMT_CONSTEXPR unsigned long long operator()(T value) {
+  FMT_CONSTEXPR uint64_t operator()(T value) {
     if (is_negative(value)) handler_.on_error("negative precision");
-    return static_cast<unsigned long long>(value);
+    return static_cast<uint64_t>(value);
   }
 
   template <typename T, FMT_ENABLE_IF(!is_integer<T>::value)>
-  FMT_CONSTEXPR unsigned long long operator()(T) {
+  FMT_CONSTEXPR uint64_t operator()(T) {
     handler_.on_error("precision is not integer");
     return 0;
   }
@@ -2128,7 +2138,7 @@ template <typename Handler> class specs_checker : public Handler {
 template