Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/databao_cli/commands/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def initialize_agent_from_dce(project_path: Path, model: str | None, temperature

agent = create_agent(domain=_domain, llm_config=llm_config)

num_sources = len(agent.sources.dbs) + len(agent.sources.dfs)
num_sources = len(agent.sources.dbs) + len(agent.sources.dfs) + len(agent.sources.dbts)
click.echo(f"Connected to {num_sources} data source(s)")
return agent

Expand Down
6 changes: 5 additions & 1 deletion src/databao_cli/ui/components/sidebar.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,9 @@ def render_sources_info() -> None:

dbs = agent.dbs
dfs = agent.dfs
dbts = agent.sources.dbts

if not dbs and not dfs:
if not dbs and not dfs and not dbts:
st.caption("No sources configured")
return

Expand All @@ -96,6 +97,9 @@ def render_sources_info() -> None:
for name in dfs:
st.markdown(f"📊 **{name}** (DataFrame)")

for name in dbts:
st.markdown(f"🔧 **{name}** (dbt)")


def render_executor_selector() -> None:
"""Render executor type selector."""
Expand Down
8 changes: 7 additions & 1 deletion src/databao_cli/ui/pages/context_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,9 @@ def _render_sources(agent: Agent) -> None:
"""Render connected data sources from the active agent."""
dbs = agent.dbs
dfs = agent.dfs
dbts = agent.sources.dbts

if not dbs and not dfs:
if not dbs and not dfs and not dbts:
st.caption("No data sources configured in this project.")
return

Expand All @@ -145,3 +146,8 @@ def _render_sources(agent: Agent) -> None:
st.markdown("**DataFrames:**")
for name in dfs:
st.markdown(f"📊 **{name}**")

if dbts:
st.markdown("**dbt Projects:**")
for name in dbts:
st.markdown(f"🔧 **{name}**")
2 changes: 1 addition & 1 deletion src/databao_cli/ui/pages/welcome.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def render_welcome_page() -> None:
st.metric("Active Chats", num_chats)

with stat_col2:
num_sources = len(agent.dbs) + len(agent.dfs) if agent else 0
num_sources = (len(agent.sources.dbs) + len(agent.sources.dfs) + len(agent.sources.dbts)) if agent else 0
st.metric("Data Sources", num_sources)

with stat_col3:
Expand Down
53 changes: 53 additions & 0 deletions tests/test_ask.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from unittest.mock import MagicMock, patch

from click.testing import CliRunner

from databao_cli.__main__ import cli
Expand All @@ -10,3 +12,54 @@ def test_ask_help() -> None:

assert result.exit_code == 0
assert "Chat with the Databao agent" in result.output


def _make_mock_agent(num_dbs: int = 0, num_dfs: int = 0, num_dbts: int = 0) -> MagicMock:
"""Create a mock agent with the given number of sources."""
agent = MagicMock()
agent.sources.dbs = {f"db{i}": MagicMock() for i in range(num_dbs)}
agent.sources.dfs = {f"df{i}": MagicMock() for i in range(num_dfs)}
agent.sources.dbts = {f"dbt{i}": MagicMock() for i in range(num_dbts)}
return agent


class TestSourceCounting:
"""Tests that dbt sources are included in the source count (DBA-126)."""

@patch("databao_cli.commands.ask.create_agent")
@patch("databao_cli.commands.ask.create_domain")
@patch("databao_cli.commands.ask.databao_project_status")
def test_dbt_only_sources_counted(
self,
mock_status: MagicMock,
mock_domain: MagicMock,
mock_create_agent: MagicMock,
) -> None:
"""When only dbt sources exist, count should be > 0."""
from databao_cli.commands.ask import initialize_agent_from_dce
from databao_cli.ui.project_utils import DatabaoProjectStatus

mock_status.return_value = DatabaoProjectStatus.VALID
mock_create_agent.return_value = _make_mock_agent(num_dbts=1)

agent = initialize_agent_from_dce(project_path=MagicMock(), model=None, temperature=0.0)
assert len(agent.sources.dbs) + len(agent.sources.dfs) + len(agent.sources.dbts) == 1

@patch("databao_cli.commands.ask.create_agent")
@patch("databao_cli.commands.ask.create_domain")
@patch("databao_cli.commands.ask.databao_project_status")
def test_mixed_sources_counted(
self,
mock_status: MagicMock,
mock_domain: MagicMock,
mock_create_agent: MagicMock,
) -> None:
"""All source types should be summed in the count."""
from databao_cli.commands.ask import initialize_agent_from_dce
from databao_cli.ui.project_utils import DatabaoProjectStatus

mock_status.return_value = DatabaoProjectStatus.VALID
mock_create_agent.return_value = _make_mock_agent(num_dbs=2, num_dfs=1, num_dbts=1)

agent = initialize_agent_from_dce(project_path=MagicMock(), model=None, temperature=0.0)
assert len(agent.sources.dbs) + len(agent.sources.dfs) + len(agent.sources.dbts) == 4
Loading