diff --git a/DESCRIPTION b/DESCRIPTION index 3ee2e8d..88f68d9 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -18,7 +18,8 @@ Imports: knitr, duckdbfs, dplyr, - utils + utils, + purrr Suggests: testthat (>= 3.0.0) Config/testthat/edition: 3 diff --git a/R/agent.R b/R/agent.R index 5f46977..b69f6be 100644 --- a/R/agent.R +++ b/R/agent.R @@ -44,7 +44,8 @@ create_prompt <- function(con = duckdbfs::cached_connection(), Pay attention to the schema of the table : -Also pay close attention to how data is represented in each column, as seen by the HEAD of table : +Also pay close attention to how data is represented in each column, +as seen by the HEAD (ommitting large list-type columns, if any) of table : ", .open = "<", .close = ">") @@ -62,11 +63,20 @@ Also pay close attention to how data is represented in each column, as seen by t # render table info as markdown tables: tbl_head_md <- function(table_name, con) { - dplyr::tbl(con, table_name) |> - utils::head() |> - dplyr::collect() |> - knitr::kable() |> - paste(collapse = "\n") + x <- dplyr::tbl(con, table_name) + + # drop non-list types. + # backend doesn't support tidyselect predicates + types <- dplyr::collect(utils::head(x,1)) |> + purrr::map_lgl(function(x) class(x)[[1]] != "list") + keep <- names(types)[types] + + x |> + dplyr::select(dplyr::all_of(keep)) |> + utils::head() |> + dplyr::collect() |> + knitr::kable() |> + paste(collapse = "\n") } tbl_schema_md <- function(table_name, con) { diff --git a/README.Rmd b/README.Rmd index 9bf33d0..042aed7 100644 --- a/README.Rmd +++ b/README.Rmd @@ -45,7 +45,7 @@ system_prompt = create_prompt() ```{r } -tracts_agent <- ellmer::chat_vllm( +agent <- ellmer::chat_vllm( base_url = "https://llm.cirrus.carlboettiger.info/v1/", model = "kosbu/Llama-3.3-70B-Instruct-AWQ", api_key = Sys.getenv("CIRRUS_LLM_KEY"), @@ -53,7 +53,7 @@ tracts_agent <- ellmer::chat_vllm( api_args = list(temperature = 0) ) -resp <- tracts_agent$chat("Yolo County") +resp <- agent$chat("Yolo County") agent_query(resp) ``` diff --git a/README.md b/README.md index abb0554..b28fd08 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ system_prompt = create_prompt() ``` ``` r -tracts_agent <- ellmer::chat_vllm( +agent <- ellmer::chat_vllm( base_url = "https://llm.cirrus.carlboettiger.info/v1/", model = "kosbu/Llama-3.3-70B-Instruct-AWQ", api_key = Sys.getenv("CIRRUS_LLM_KEY"), @@ -42,7 +42,7 @@ tracts_agent <- ellmer::chat_vllm( api_args = list(temperature = 0) ) -resp <- tracts_agent$chat("Yolo County") +resp <- agent$chat("Yolo County") #> { #> "query": "CREATE OR REPLACE VIEW yolo_county AS SELECT * FROM censustracts #> WHERE COUNTY = 'Yolo County'", @@ -55,15 +55,15 @@ agent_query(resp) #> # Database: DuckDB v1.1.3 [unknown@Linux 6.9.3-76060903-generic:R 4.4.2/:memory:] #> STATE COUNTY FIPS h6 #>     -#>  1 California Yolo County 06113010102 862832b0fffffff -#>  2 California Yolo County 06113010313 862832b07ffffff -#>  3 California Yolo County 06113010401 8628304afffffff -#>  4 California Yolo County 06113010401 8628304c7ffffff -#>  5 California Yolo County 06113010401 8628304d7ffffff -#>  6 California Yolo County 06113010401 8628304dfffffff -#>  7 California Yolo County 06113010401 8628304f7ffffff -#>  8 California Yolo County 06113010401 862830417ffffff -#>  9 California Yolo County 06113010401 86283041fffffff -#> 10 California Yolo County 06113010401 862832b2fffffff +#>  1 California Yolo County 06113010102 862832B0FFFFFFF +#>  2 California Yolo County 06113010313 862832B07FFFFFF +#>  3 California Yolo County 06113010401 8628304AFFFFFFF +#>  4 California Yolo County 06113010401 8628304C7FFFFFF +#>  5 California Yolo County 06113010401 8628304D7FFFFFF +#>  6 California Yolo County 06113010401 8628304DFFFFFFF +#>  7 California Yolo County 06113010401 8628304F7FFFFFF +#>  8 California Yolo County 06113010401 862830417FFFFFF +#>  9 California Yolo County 06113010401 86283041FFFFFFF +#> 10 California Yolo County 06113010401 862832B2FFFFFFF #> # ℹ more rows ```