diff --git a/pkg-r/R/TblSqlSource.R b/pkg-r/R/TblSqlSource.R index 51e2d007..91903d90 100644 --- a/pkg-r/R/TblSqlSource.R +++ b/pkg-r/R/TblSqlSource.R @@ -117,10 +117,18 @@ TblSqlSource <- R6::R6Class( #' Execute a SQL query and return results #' #' @param query SQL query string to execute - #' @return A data frame containing query results - execute_query = function(query) { + #' @param collect If `TRUE` (default), collects the results into a local data frame + #' using [dplyr::collect()]. If `FALSE`, returns a lazy SQL + #' tibble. + #' @return A data frame (if `collect = TRUE`) or a lazy SQL tibble (if + #' `collect = FALSE`) + execute_query = function(query, collect = TRUE) { sql_query <- self$prep_query(query) - dplyr::tbl(private$conn, dplyr::sql(sql_query)) + result <- dplyr::tbl(private$conn, dplyr::sql(sql_query)) + if (collect) { + result <- dplyr::collect(result) + } + result }, #' @description diff --git a/pkg-r/man/TblSqlSource.Rd b/pkg-r/man/TblSqlSource.Rd index 71e98a6a..78f54fb2 100644 --- a/pkg-r/man/TblSqlSource.Rd +++ b/pkg-r/man/TblSqlSource.Rd @@ -117,18 +117,23 @@ A string containing schema information formatted for LLM prompts \subsection{Method \code{execute_query()}}{ Execute a SQL query and return results \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{TblSqlSource$execute_query(query)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{TblSqlSource$execute_query(query, collect = TRUE)}\if{html}{\out{
}} } \subsection{Arguments}{ \if{html}{\out{
}} \describe{ \item{\code{query}}{SQL query string to execute} + +\item{\code{collect}}{If \code{TRUE} (default), collects the results into a local data frame +using \code{\link[dplyr:compute]{dplyr::collect()}}. If \code{FALSE}, returns a lazy SQL +tibble.} } \if{html}{\out{
}} } \subsection{Returns}{ -A data frame containing query results +A data frame (if \code{collect = TRUE}) or a lazy SQL tibble (if +\code{collect = FALSE}) } } \if{html}{\out{
}} diff --git a/pkg-r/tests/testthat/helper-fixtures.R b/pkg-r/tests/testthat/helper-fixtures.R index d800b010..111ed784 100644 --- a/pkg-r/tests/testthat/helper-fixtures.R +++ b/pkg-r/tests/testthat/helper-fixtures.R @@ -117,7 +117,8 @@ local_tbl_sql_source <- function( DBI::dbWriteTable(conn, table_name, data, overwrite = TRUE) tbl <- dplyr::tbl(conn, table_name) - tbl <- tbl_transform(tbl) + tbl <- tbl_transform(tbl) |> + dplyr::compute(name = table_name) TblSqlSource$new(tbl, table_name) } diff --git a/pkg-r/tests/testthat/test-TblSqlSource.R b/pkg-r/tests/testthat/test-TblSqlSource.R index d1abcedf..45cebce3 100644 --- a/pkg-r/tests/testthat/test-TblSqlSource.R +++ b/pkg-r/tests/testthat/test-TblSqlSource.R @@ -22,10 +22,13 @@ describe("TblSqlSource$new()", { }) }) - it("returns lazy tibble from execute_query()", { + it("returns lazy tibble from execute_query() when collect = FALSE", { source <- local_tbl_sql_source() - result <- source$execute_query("SELECT * FROM test_table WHERE value > 25") + result <- source$execute_query( + "SELECT * FROM test_table WHERE value > 25", + collect = FALSE + ) expect_s3_class(result, "tbl_sql") expect_s3_class(result, "tbl_lazy") @@ -35,6 +38,31 @@ describe("TblSqlSource$new()", { expect_equal(collected$value, c(30, 40, 50)) }) + it("keeps execute_query() result lazy (tbl_sql, tbl_lazy) when collect = FALSE", { + source <- local_tbl_sql_source() + + result <- source$execute_query( + "SELECT * FROM test_table WHERE value > 25", + collect = FALSE + ) + expect_s3_class(result, "tbl_sql") + expect_s3_class(result, "tbl_lazy") + }) + + it("returns data frame from execute_query() when collect = TRUE", { + source <- local_tbl_sql_source() + + result <- source$execute_query( + "SELECT * FROM test_table WHERE value > 25", + collect = TRUE + ) + expect_s3_class(result, "data.frame") + expect_false(inherits(result, "tbl_sql")) + expect_false(inherits(result, "tbl_lazy")) + expect_equal(nrow(result), 3) + expect_equal(result$value, c(30, 40, 50)) + }) + 
it("returns data frame from test_query()", { source <- local_tbl_sql_source() @@ -59,7 +87,7 @@ describe("TblSqlSource with transformed tbl (CTE mode)", { ) # CTE should be used since tbl is transformed - result <- source$execute_query("SELECT * FROM test_table") + result <- source$execute_query("SELECT * FROM test_table", collect = FALSE) collected <- dplyr::collect(result) expect_equal(nrow(collected), 3) expect_true(all(collected$value > 20)) @@ -191,7 +219,8 @@ describe("TblSqlSource edge cases - Category B: Column Naming Issues", { # SELECT with explicit duplicate column names from JOIN # DuckDB allows duplicate names but tibble rejects them on collect result <- source$execute_query( - "SELECT table_a.id, table_b.id FROM table_a JOIN table_b ON table_a.id = table_b.id" + "SELECT table_a.id, table_b.id FROM table_a JOIN table_b ON table_a.id = table_b.id", + collect = FALSE ) expect_error( dplyr::collect(result), @@ -272,7 +301,8 @@ describe("TblSqlSource edge cases - Category B: Column Naming Issues", { # SELECT * from JOIN produces duplicate 'id' columns # tibble rejects duplicate names on collect result <- source$execute_query( - "SELECT * FROM table_a JOIN table_b ON table_a.id = table_b.id" + "SELECT * FROM table_a JOIN table_b ON table_a.id = table_b.id", + collect = FALSE ) expect_error( dplyr::collect(result),