Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ packages = [
[tool.poetry.dependencies]
python = ">=3.10,<4.0"
litellm = ">=1.65.1,<2.0.0"
fastapi = ">=0.95.0"
loguru = ">=0.7.3,<0.8.0"
cachetools = ">=5.5.2,<6.0.0"
ollama = ">=0.4.7,<0.5.0"
Expand Down
23 changes: 23 additions & 0 deletions readme-web.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

#### Step 2: Start the Backend

```bash
# Install Python dependencies
pip install -r api/requirements.txt

# Start the API server
python -m web.api.main
```

#### Step 3: Start the Frontend

```bash
npm run dev
```

#### Step 4: Debug the Frontend

```bash
NODE_OPTIONS='--inspect' npm run dev
```
20 changes: 20 additions & 0 deletions src/starfish/data_factory/storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ async def list_projects(self, limit: Optional[int] = None, offset: Optional[int]
"""List available projects."""
pass

@abstractmethod
async def delete_project(self, project_id: str) -> None:
    """Delete the project identified by *project_id* from the backing store.

    The base class does not specify behavior for a missing project;
    see the concrete storage implementation for details.
    """
    pass

# Master Job methods
@abstractmethod
async def log_master_job_start(self, job_data: GenerationMasterJob) -> None:
Expand Down Expand Up @@ -173,6 +178,21 @@ async def list_execution_jobs_by_master_id_and_config_hash(self, master_job_id:
"""Retrieve execution job details by master job id and config hash."""
pass

@abstractmethod
async def save_dataset(self, project_id: str, dataset_name: str, dataset_data: Dict[str, Any]) -> str:
    """Persist *dataset_data* under *dataset_name* for the given project.

    Returns a storage reference for the saved dataset (the local backend
    returns a file path).
    """
    pass

@abstractmethod
async def get_dataset(self, project_id: str, dataset_name: str) -> Dict[str, Any]:
    """Return the dataset previously saved as *dataset_name* for the project."""
    pass

@abstractmethod
async def list_datasets(self, project_id: str) -> List[Dict[str, Any]]:
    """Return summaries of all datasets stored for *project_id*."""
    pass


from .registry import Registry

Expand Down
16 changes: 16 additions & 0 deletions src/starfish/data_factory/storage/in_memory/in_memory_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,19 @@ async def list_record_metadata(self, master_job_uuid: str, job_uuid: str) -> Lis
async def list_execution_jobs_by_master_id_and_config_hash(self, master_job_id: str, config_hash: str, job_status: str) -> Optional[GenerationJob]:
    """Retrieve execution job details by master job id and config hash.

    In-memory backend: no-op stub that always returns None (consistent with
    the Optional return type).
    """
    pass

async def save_dataset(self, project_id: str, dataset_name: str, dataset_data: Dict[str, Any]) -> str:
    """Save a dataset.

    NOTE(review): no-op stub — returns None despite the declared ``str``
    return type, so datasets handed to the in-memory backend are silently
    dropped. Confirm this is acceptable or implement real storage.
    """
    pass

async def get_dataset(self, project_id: str, dataset_name: str) -> Dict[str, Any]:
    """Get a dataset.

    NOTE(review): no-op stub — returns None despite the declared
    ``Dict[str, Any]`` return type; callers must tolerate None.
    """
    pass

async def list_datasets(self, project_id: str) -> List[Dict[str, Any]]:
    """List datasets for a project.

    NOTE(review): no-op stub — returns None rather than an empty list;
    callers iterating the result will fail.
    """
    pass

async def delete_project(self, project_id: str) -> None:
    """Delete a project.

    NOTE(review): no-op stub — deletions are silently ignored by the
    in-memory backend.
    """
    pass
26 changes: 26 additions & 0 deletions src/starfish/data_factory/storage/local/data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
CONFIGS_DIR = "configs"
DATA_DIR = "data"
ASSOCIATIONS_DIR = "associations"
DATASETS_DIR = "datasets"


class FileSystemDataHandler:
Expand All @@ -31,6 +32,7 @@ def __init__(self, data_base_path: str):
self.config_path = os.path.join(self.data_base_path, CONFIGS_DIR)
self.record_data_path = os.path.join(self.data_base_path, DATA_DIR)
self.assoc_path = os.path.join(self.data_base_path, ASSOCIATIONS_DIR)
self.datasets_path = os.path.join(self.data_base_path, DATASETS_DIR)
# TODO: Consider locks if implementing JSONL appends for associations

async def ensure_base_dirs(self):
Expand Down Expand Up @@ -85,6 +87,30 @@ def generate_request_config_path_impl(self, master_job_id: str) -> str:
path = os.path.join(self.config_path, f"{master_job_id}.request.json")
return path # Return absolute path as the reference

async def save_dataset_impl(self, project_id: str, dataset_name: str, dataset_data: Dict[str, Any]):
    """Write *dataset_data* as JSON to ``<datasets_path>/<project_id>/<dataset_name>.json``.

    Returns the path of the written file, which callers use as the
    dataset's storage reference.
    NOTE(review): assumes ``_save_json_file`` creates the project
    subdirectory if missing — confirm.
    """
    dataset_file = os.path.join(self.datasets_path, project_id, f"{dataset_name}.json")
    await self._save_json_file(dataset_file, dataset_data)
    return dataset_file

async def get_dataset_impl(self, project_id: str, dataset_name: str) -> Dict[str, Any]:
    """Read back the JSON document stored for *dataset_name* under the project."""
    dataset_file = os.path.join(self.datasets_path, project_id, f"{dataset_name}.json")
    return await self._read_json_file(dataset_file)

async def list_datasets_impl(self, project_id: str) -> list[Dict[str, Any]]:
    """List all datasets saved under a project.

    Returns one summary dict per ``*.json`` file in the project's dataset
    directory, in sorted filename order (stable ids across calls), or an
    empty list when the project has no dataset directory yet.

    File names are expected to follow ``<name>__<created_at>.json``; files
    without the ``__`` separator get an empty ``created_at`` instead of
    raising IndexError (the original crashed on such files).
    """
    project_dir = os.path.join(self.datasets_path, project_id)
    try:
        files = await aio_os.listdir(project_dir)
    except FileNotFoundError:
        # No datasets have been saved for this project yet.
        return []

    datasets = []
    # Sort for a deterministic listing; enumerate only the JSON files so
    # ids are contiguous even when stray files are present.
    for idx, filename in enumerate(sorted(f for f in files if f.endswith(".json"))):
        dataset = await self._read_json_file(os.path.join(project_dir, filename))
        stem = filename[:-5]  # drop the ".json" suffix
        parts = stem.split("__")
        created_at = parts[1] if len(parts) > 1 else ""
        datasets.append(
            {
                "id": idx,
                "name": stem,
                "created_at": created_at,
                "record_count": len(dataset),
                "data": dataset,
                "status": "completed",
            }
        )
    return datasets

async def get_request_config_impl(self, config_ref: str) -> Dict[str, Any]:
    """Load the request-config JSON document referenced by *config_ref*."""
    return await self._read_json_file(config_ref)  # Assumes ref is absolute path

Expand Down
14 changes: 13 additions & 1 deletion src/starfish/data_factory/storage/local/local_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,11 @@ async def save_project(self, project_data: Project) -> None:
async def get_project(self, project_id: str) -> Optional[Project]:
    """Fetch a single project by id (``None`` when absent), via the metadata handler."""
    return await self._metadata_handler.get_project_impl(project_id)

async def delete_project(self, project_id: str) -> None:
    """Delete the project's row from the metadata store.

    NOTE(review): only the Projects row is removed; datasets and job
    records for the project are not cleaned up here — confirm intended.
    """
    await self._metadata_handler.delete_project_impl(project_id)

async def list_projects(self, limit: Optional[int] = None, offset: Optional[int] = None) -> List[Project]:
    """List projects, ordered by name, with optional pagination.

    NOTE(review): this delegates to ``list_projects_impl_data_template``,
    which filters to rows where ``template_name IS NOT NULL`` — projects
    created without a template no longer appear in this listing. Confirm
    that all callers of ``list_projects`` expect this narrowing.
    """
    # Diff residue removed: the superseded delegation to list_projects_impl
    # made the second return unreachable.
    return await self._metadata_handler.list_projects_impl_data_template(limit, offset)

async def log_master_job_start(self, job_data: GenerationMasterJob) -> None:
    """Record the start of a master job via the metadata handler."""
    await self._metadata_handler.log_master_job_start_impl(job_data)
Expand Down Expand Up @@ -154,6 +157,15 @@ async def list_execution_jobs_by_master_id_and_config_hash(self, master_job_id:
async def list_record_metadata(self, master_job_uuid: str, job_uuid: str) -> List[Record]:
    """List record metadata for the given master-job / execution-job pair."""
    return await self._metadata_handler.list_record_metadata_impl(master_job_uuid, job_uuid)

async def save_dataset(self, project_id: str, dataset_name: str, dataset_data: Dict[str, Any]) -> str:
    """Persist a dataset through the data handler; returns the stored file's path."""
    return await self._data_handler.save_dataset_impl(project_id, dataset_name, dataset_data)

async def get_dataset(self, project_id: str, dataset_name: str) -> Dict[str, Any]:
    """Load a previously saved dataset through the data handler."""
    return await self._data_handler.get_dataset_impl(project_id, dataset_name)

async def list_datasets(self, project_id: str) -> List[Dict[str, Any]]:
    """List summaries of the datasets stored for *project_id*."""
    return await self._data_handler.list_datasets_impl(project_id)


@register_storage("local")
def create_local_storage(storage_uri: str, data_storage_uri_override: Optional[str] = None) -> LocalStorage:
Expand Down
37 changes: 34 additions & 3 deletions src/starfish/data_factory/storage/local/metadata_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,17 +225,28 @@ async def batch_save_execution_jobs(self, jobs: List[GenerationJob]):

async def save_project_impl(self, project_data: Project):
    """Insert or overwrite the row for *project_data* in the Projects table.

    Uses ``INSERT OR REPLACE`` keyed on project_id, so an existing row is
    fully rewritten — including ``created_when``, which is taken from
    *project_data* rather than preserved from the stored row.

    (This span contained merged diff residue — both the old 5-column and
    new 6-column statements; only the 6-column version is kept.)
    """
    sql = """
    INSERT OR REPLACE INTO Projects (project_id, name, template_name, description, created_when, updated_when)
    VALUES (?, ?, ?, ?, ?, ?);
    """
    params = (
        project_data.project_id,
        project_data.name,
        project_data.template_name,
        project_data.description,
        project_data.created_when,
        project_data.updated_when,
    )
    await self._execute_sql(sql, params)

async def get_project_impl(self, project_id: str) -> Optional[Project]:
    """Fetch one project row by primary key and convert it to a Project model."""
    sql = "SELECT * FROM Projects WHERE project_id = ?"
    row = await self._fetchone_sql(sql, (project_id,))
    return _row_to_pydantic(Project, row)

async def delete_project_impl(self, project_id: str) -> None:
    """Delete the project's row from the Projects table.

    NOTE(review): no cascading cleanup of jobs/records referencing this
    project happens here — confirm callers handle that.
    """
    sql = "DELETE FROM Projects WHERE project_id = ?"
    await self._execute_sql(sql, (project_id,))

async def list_projects_impl(self, limit: Optional[int], offset: Optional[int]) -> List[Project]:
sql = "SELECT * FROM Projects ORDER BY name"
params: List[Any] = []
Expand All @@ -256,6 +267,26 @@ async def list_projects_impl(self, limit: Optional[int], offset: Optional[int])
rows = await self._fetchall_sql(sql, tuple(params))
return [_row_to_pydantic(Project, row) for row in rows]

async def list_projects_impl_data_template(self, limit: Optional[int], offset: Optional[int]) -> List[Project]:
    """List projects created from a template (``template_name IS NOT NULL``),
    ordered by name, with optional LIMIT/OFFSET pagination.

    NOTE: largely duplicates ``list_projects_impl`` apart from the WHERE
    clause — consider sharing the pagination logic.
    """
    sql = "SELECT * FROM Projects WHERE template_name IS NOT NULL ORDER BY name"
    params: List[Any] = []

    # SQLite requires a LIMIT clause whenever OFFSET is used.
    if offset is not None:
        # A negative LIMIT means "no limit" in SQLite, so an offset without
        # an explicit limit no longer silently truncates at 1000 rows as the
        # previous hard-coded "LIMIT 1000" did.
        sql += " LIMIT ? OFFSET ?"
        params.extend([limit if limit is not None else -1, offset])
    elif limit is not None:
        sql += " LIMIT ?"
        params.append(limit)

    rows = await self._fetchall_sql(sql, tuple(params))
    return [_row_to_pydantic(Project, row) for row in rows]

async def log_master_job_start_impl(self, job_data: GenerationMasterJob):
data_dict = _serialize_pydantic_for_db(job_data)
cols = ", ".join(data_dict.keys())
Expand Down
1 change: 1 addition & 0 deletions src/starfish/data_factory/storage/local/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
CREATE TABLE IF NOT EXISTS Projects (
project_id TEXT PRIMARY KEY,
name TEXT NOT NULL,
template_name TEXT,
description TEXT,
created_when TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_when TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
Expand Down
1 change: 1 addition & 0 deletions src/starfish/data_factory/storage/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class Project(BaseModel):

project_id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique project identifier.")
name: str = Field(..., description="User-friendly project name.")
template_name: Optional[str] = Field(None, description="template name.")
description: Optional[str] = Field(None, description="Optional description.")
created_when: datetime.datetime = Field(default_factory=utc_now)
updated_when: datetime.datetime = Field(default_factory=utc_now)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,17 @@ class GenerateByTopicInput(BaseModel):
dependencies=[],
input_example="""{
"user_instruction": "Generate Q&A pairs about machine learning concepts",
"num_records": 100,
"records_per_topic": 5,
"num_records": 4,
"records_per_topic": 2,
"topics": [
"supervised learning",
"unsupervised learning",
{"reinforcement learning": 3}, # This means generate 3 records for this topic
"neural networks",
],
"topic_model_name": "openai/gpt-4",
"topic_model_name": "openai/gpt-4.1-mini",
"topic_model_kwargs": {"temperature": 0.7},
"generation_model_name": "openai/gpt-4",
"generation_model_name": "openai/gpt-4.1-mini",
"generation_model_kwargs": {"temperature": 0.8, "max_tokens": 200},
"output_schema": [
{"name": "question", "type": "str"},
Expand Down
41 changes: 41 additions & 0 deletions web/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.

# dependencies
/node_modules
/.pnp
.pnp.js
.yarn/install-state.gz

# testing
/coverage

# next.js
/.next/
/out/

# production
/build

# misc
.DS_Store
*.pem

# debug
npm-debug.log*
yarn-debug.log*
yarn-error.log*

# local env files
.env*.local

# vercel
.vercel

# typescript
*.tsbuildinfo
next-env.d.ts

# amplify
.amplify
amplify_outputs*
amplifyconfiguration*
4 changes: 4 additions & 0 deletions web/CODE_OF_CONDUCT.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.
1 change: 1 addition & 0 deletions web/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Empty file to make this a package
51 changes: 51 additions & 0 deletions web/api/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager

from starfish.common.logger import get_logger

logger = get_logger(__name__)

from starfish.common.env_loader import load_env_file
from web.api.storage import setup_storage, close_storage


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: open storage before serving, release it on shutdown.

    The shutdown call is wrapped in try/finally so storage is closed even
    when the application body exits with an error; previously
    ``close_storage()`` only ran on a clean shutdown.
    """
    await setup_storage()
    try:
        yield
    finally:
        # Always release storage resources, clean shutdown or not.
        await close_storage()


# Import routers
# NOTE(review): routers are imported before load_env_file() runs below —
# any environment variables the router modules read at import time will not
# see values from .env. Confirm the ordering is intentional.
from .routers import template, dataset, project

# Resolve the repository root and load .env from there.
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.normpath(os.path.join(current_dir, "..", ".."))  # Go up two levels from web/api/
env_path = os.path.join(root_dir, ".env")
load_env_file(env_path=env_path)

# Initialize FastAPI app
# NOTE(review): the description mentions "streaming chat completions" but the
# mounted routers are template/dataset/project — the string looks stale
# (runtime text, left untouched here).
app = FastAPI(title="Streaming API", lifespan=lifespan, description="API for streaming chat completions")

# Configure CORS
# NOTE(review): wildcard origins combined with allow_credentials=True —
# browsers reject Access-Control-Allow-Origin "*" on credentialed requests;
# tighten allow_origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# Include routers
app.include_router(template.router)
app.include_router(dataset.router)
app.include_router(project.router)


def get_adalflow_default_root_path():
    """Return the default adalflow root directory (``~/.adalflow``),
    expanded for the current user."""
    home_relative = os.path.join("~", ".adalflow")
    return os.path.expanduser(home_relative)
Loading
Loading