diff --git a/pkg-r/R/TblSqlSource.R b/pkg-r/R/TblSqlSource.R
index 51e2d007..91903d90 100644
--- a/pkg-r/R/TblSqlSource.R
+++ b/pkg-r/R/TblSqlSource.R
@@ -117,10 +117,18 @@ TblSqlSource <- R6::R6Class(
#' Execute a SQL query and return results
#'
#' @param query SQL query string to execute
- #' @return A data frame containing query results
- execute_query = function(query) {
+ #' @param collect If `TRUE` (default), collects the results into a local data frame
+ #' using [dplyr::collect()]. If `FALSE`, returns a lazy SQL
+ #' tibble.
+ #' @return A data frame (if `collect = TRUE`) or a lazy SQL tibble (if
+ #' `collect = FALSE`)
+ execute_query = function(query, collect = TRUE) {
sql_query <- self$prep_query(query)
- dplyr::tbl(private$conn, dplyr::sql(sql_query))
+ lazy_tbl <- dplyr::tbl(private$conn, dplyr::sql(sql_query))
+ if (!collect) {
+ return(lazy_tbl)
+ }
+ dplyr::collect(lazy_tbl)
},
#' @description
diff --git a/pkg-r/man/TblSqlSource.Rd b/pkg-r/man/TblSqlSource.Rd
index 71e98a6a..78f54fb2 100644
--- a/pkg-r/man/TblSqlSource.Rd
+++ b/pkg-r/man/TblSqlSource.Rd
@@ -117,18 +117,23 @@ A string containing schema information formatted for LLM prompts
\subsection{Method \code{execute_query()}}{
Execute a SQL query and return results
\subsection{Usage}{
-\if{html}{\out{<div class="r">}}\preformatted{TblSqlSource$execute_query(query)}\if{html}{\out{</div>}}
+\if{html}{\out{<div class="r">}}\preformatted{TblSqlSource$execute_query(query, collect = TRUE)}\if{html}{\out{</div>}}
}
\subsection{Arguments}{
\if{html}{\out{<div class="arguments">}}
\describe{
\item{\code{query}}{SQL query string to execute}
+
+\item{\code{collect}}{If \code{TRUE} (default), collects the results into a local data frame
+using \code{\link[dplyr:compute]{dplyr::collect()}}. If \code{FALSE}, returns a lazy SQL
+tibble.}
}
\if{html}{\out{</div>}}
}
\subsection{Returns}{
-A data frame containing query results
+A data frame (if \code{collect = TRUE}) or a lazy SQL tibble (if
+\code{collect = FALSE})
}
}
\if{html}{\out{<hr>}}
diff --git a/pkg-r/tests/testthat/helper-fixtures.R b/pkg-r/tests/testthat/helper-fixtures.R
index d800b010..111ed784 100644
--- a/pkg-r/tests/testthat/helper-fixtures.R
+++ b/pkg-r/tests/testthat/helper-fixtures.R
@@ -117,7 +117,8 @@ local_tbl_sql_source <- function(
DBI::dbWriteTable(conn, table_name, data, overwrite = TRUE)
tbl <- dplyr::tbl(conn, table_name)
- tbl <- tbl_transform(tbl)
+ tbl <- tbl_transform(tbl) |>
+ dplyr::compute("test_table")
TblSqlSource$new(tbl, table_name)
}
diff --git a/pkg-r/tests/testthat/test-TblSqlSource.R b/pkg-r/tests/testthat/test-TblSqlSource.R
index d1abcedf..45cebce3 100644
--- a/pkg-r/tests/testthat/test-TblSqlSource.R
+++ b/pkg-r/tests/testthat/test-TblSqlSource.R
@@ -22,10 +22,13 @@ describe("TblSqlSource$new()", {
})
})
- it("returns lazy tibble from execute_query()", {
+ it("returns lazy tibble from execute_query() when collect = FALSE", {
source <- local_tbl_sql_source()
- result <- source$execute_query("SELECT * FROM test_table WHERE value > 25")
+ result <- source$execute_query(
+ "SELECT * FROM test_table WHERE value > 25",
+ collect = FALSE
+ )
expect_s3_class(result, "tbl_sql")
expect_s3_class(result, "tbl_lazy")
@@ -35,6 +38,31 @@ describe("TblSqlSource$new()", {
expect_equal(collected$value, c(30, 40, 50))
})
+ it("returns data frame from execute_query() by default", {
+ source <- local_tbl_sql_source()
+
+ # No `collect` argument: the default (TRUE) must yield a local data frame
+ result <- source$execute_query(
+ "SELECT * FROM test_table WHERE value > 25"
+ )
+ expect_s3_class(result, "data.frame")
+ expect_false(inherits(result, "tbl_lazy"))
+ })
+
+ it("returns data frame from execute_query() when collect = TRUE", {
+ source <- local_tbl_sql_source()
+
+ result <- source$execute_query(
+ "SELECT * FROM test_table WHERE value > 25",
+ collect = TRUE
+ )
+ expect_s3_class(result, "data.frame")
+ expect_false(inherits(result, "tbl_sql"))
+ expect_false(inherits(result, "tbl_lazy"))
+ expect_equal(nrow(result), 3)
+ expect_equal(result$value, c(30, 40, 50))
+ })
+
it("returns data frame from test_query()", {
source <- local_tbl_sql_source()
@@ -59,7 +87,7 @@ describe("TblSqlSource with transformed tbl (CTE mode)", {
)
# CTE should be used since tbl is transformed
- result <- source$execute_query("SELECT * FROM test_table")
+ result <- source$execute_query("SELECT * FROM test_table", collect = FALSE)
collected <- dplyr::collect(result)
expect_equal(nrow(collected), 3)
expect_true(all(collected$value > 20))
@@ -191,7 +219,8 @@ describe("TblSqlSource edge cases - Category B: Column Naming Issues", {
# SELECT with explicit duplicate column names from JOIN
# DuckDB allows duplicate names but tibble rejects them on collect
result <- source$execute_query(
- "SELECT table_a.id, table_b.id FROM table_a JOIN table_b ON table_a.id = table_b.id"
+ "SELECT table_a.id, table_b.id FROM table_a JOIN table_b ON table_a.id = table_b.id",
+ collect = FALSE
)
expect_error(
dplyr::collect(result),
@@ -272,7 +301,8 @@ describe("TblSqlSource edge cases - Category B: Column Naming Issues", {
# SELECT * from JOIN produces duplicate 'id' columns
# tibble rejects duplicate names on collect
result <- source$execute_query(
- "SELECT * FROM table_a JOIN table_b ON table_a.id = table_b.id"
+ "SELECT * FROM table_a JOIN table_b ON table_a.id = table_b.id",
+ collect = FALSE
)
expect_error(
dplyr::collect(result),