Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,16 @@ ads-setup = "ads.utils:setup_database"
ads-cli = "ads.cli:main"

[tool.setuptools]
package-dir = {"" = "src/python"}
zip-safe = false
include-package-data = true

[tool.setuptools.packages.find]
where = ["src/python"]

[tool.setuptools.package-data]
"*" = ["*.json", "*.tsv"]
ads = ["bin/*"]
"*" = ["data/*.json", "data/*.tsv"]
"ads" = ["bin/*"]

[project.optional-dependencies]
dev = [
Expand Down
20 changes: 16 additions & 4 deletions src/python/ads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

import os
from pathlib import Path

debug = False

Expand All @@ -18,10 +19,21 @@
logger.addHandler(handler)
logger.setLevel(level)
"""

else:
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

from ads.models import (Affiliation, Document, Journal, Library)
from ads.client import SearchQuery
# First-run bootstrap: make sure a default config file exists at
# ~/.ads/config.json so later config loading has something to read.
CONFIG = Path.home() / ".ads/config.json"
if not CONFIG.exists():
    from ads.settings import ADSConfig

    try:
        CONFIG.parent.mkdir(parents=True, exist_ok=True)
        ADSConfig().save(CONFIG)
    except OSError as exc:
        # This runs at import time: an unwritable or read-only home directory
        # must not make `import ads` itself fail, so warn and carry on.
        import warnings

        warnings.warn(f"Could not create default config at {CONFIG}: {exc}")
    else:
        print("Generated config at ~/.ads/config.json")


# namespace discovery
from ads.client import SearchQuery
from ads.models import Affiliation, Document, Journal, Library
File renamed without changes.
File renamed without changes.
83 changes: 43 additions & 40 deletions src/python/ads/settings.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Configuration settings for the ADS API client using Pydantic."""

import os
import json
import yaml
from typing import List, Optional
import os
from pathlib import Path
from pydantic import BaseModel, Field, field_validator, ConfigDict
from typing import List, Optional

import yaml
from pydantic import BaseModel, ConfigDict, Field, field_validator

__all__ = ["ADSConfig", "get_config"]

Expand All @@ -17,64 +18,56 @@ class ADSConfig(BaseModel):

# API Configuration
api_url: str = Field(
default='https://api.adsabs.harvard.edu/v1',
description="Base URL for the ADS API"
default="https://api.adsabs.harvard.edu/v1",
description="Base URL for the ADS API",
)

# Token Configuration
tokens: Optional[List[str]] = Field(
default=None,
description="List of API tokens for rotation. If None, will use single token discovery."
description="List of API tokens for rotation. If None, will use single token discovery.",
)

# Async Request Configuration
async_limit_per_host: int = Field(
default=10,
ge=1,
le=100,
description="Maximum concurrent connections per host"
default=10, ge=1, le=100, description="Maximum concurrent connections per host"
)

async_limit: int = Field(
default=100,
ge=1,
le=1000,
description="Maximum total concurrent connections"
default=100, ge=1, le=1000, description="Maximum total concurrent connections"
)

async_timeout: int = Field(
default=30,
ge=1,
description="Timeout for async requests in seconds"
default=30, ge=1, description="Timeout for async requests in seconds"
)

# Retry Configuration
max_retries: int = Field(
default=3,
ge=0,
le=10,
description="Maximum number of retries for failed requests"
description="Maximum number of retries for failed requests",
)

retry_base_delay: float = Field(
default=1.0,
gt=0,
description="Base delay in seconds for exponential backoff (delay = base * 2^attempt)"
description="Base delay in seconds for exponential backoff (delay = base * 2^attempt)",
)

# Query Configuration
max_rows_per_request: int = Field(
default=200,
ge=1,
le=200,
description="Maximum number of rows per API request (ADS limit is 200)"
description="Maximum number of rows per API request (ADS limit is 200)",
)

# Rate limit threshold for token rotation (conservative threshold)
rate_limit_threshold: int = Field(
default=4500,
ge=0,
description="Remaining requests threshold to trigger token rotation"
description="Remaining requests threshold to trigger token rotation",
)

# Token discovery configuration (legacy support)
Expand All @@ -83,26 +76,26 @@ class ADSConfig(BaseModel):
"~/.ads/token",
"~/.ads/dev_key",
],
description="Paths to search for API token files"
description="Paths to search for API token files",
)

token_environ_vars: List[str] = Field(
default_factory=lambda: ["ADS_API_TOKEN", "ADS_DEV_KEY"],
description="Environment variables to check for API tokens"
description="Environment variables to check for API tokens",
)

@field_validator('tokens', mode='before')
@field_validator("tokens", mode="before")
@classmethod
def validate_tokens(cls, v):
"""Ensure tokens is a list if provided as a single string."""
if v is None:
return None
if isinstance(v, str):
# Split by comma if multiple tokens provided as string
return [t.strip() for t in v.split(',') if t.strip()]
return [t.strip() for t in v.split(",") if t.strip()]
return v

@field_validator('token_files', mode='before')
@field_validator("token_files", mode="before")
@classmethod
def expand_token_files(cls, v):
"""Expand user paths in token files."""
Expand All @@ -126,7 +119,7 @@ def discover_tokens(self) -> Optional[List[str]]:
expanded_path = os.path.expanduser(filepath)
if os.path.exists(expanded_path):
try:
with open(expanded_path, 'r') as f:
with open(expanded_path, "r") as f:
token = f.read().strip()
if token:
return [token]
Expand All @@ -150,10 +143,10 @@ def from_file(cls, filepath: str) -> "ADSConfig":
if not path.exists():
raise FileNotFoundError(f"Config file not found: {filepath}")

with open(path, 'r') as f:
if path.suffix == '.json':
with open(path, "r") as f:
if path.suffix == ".json":
data = json.load(f)
elif path.suffix in ('.yaml', '.yml'):
elif path.suffix in (".yaml", ".yml"):
data = yaml.safe_load(f)
else:
raise ValueError(f"Unsupported config file format: {path.suffix}")
Expand All @@ -175,31 +168,41 @@ def from_env(cls) -> "ADSConfig":

for key, value in os.environ.items():
if key.startswith(prefix):
config_key = key[len(prefix):].lower()
config_key = key[len(prefix) :].lower()
# Handle list values (comma-separated)
if config_key in ('tokens', 'token_files', 'token_environ_vars'):
env_config[config_key] = [v.strip() for v in value.split(',')]
if config_key in ("tokens", "token_files", "token_environ_vars"):
env_config[config_key] = [v.strip() for v in value.split(",")]
# Handle integer values
elif config_key in ('async_limit_per_host', 'async_limit', 'async_timeout',
'max_retries', 'max_rows_per_request', 'rate_limit_threshold'):
elif config_key in (
"async_limit_per_host",
"async_limit",
"async_timeout",
"max_retries",
"max_rows_per_request",
"rate_limit_threshold",
):
env_config[config_key] = int(value)
# Handle float values
elif config_key == 'retry_base_delay':
elif config_key == "retry_base_delay":
env_config[config_key] = float(value)
else:
env_config[config_key] = value

return cls(**env_config)

def save(self, path: os.PathLike | Path):
with open(path, "w") as f:
import json

json.dump(ADSConfig().model_dump(), f, indent=4)


# Global configuration instance (singleton-like)
_global_config: Optional[ADSConfig] = None


def get_config(
config_file: Optional[str] = "~/.ads/config.json",
use_env: bool = False,
**kwargs
config_file: Optional[str] = "~/.ads/config.json", use_env: bool = False, **kwargs
) -> ADSConfig:
"""
Get or create the global ADSConfig instance.
Expand Down
2 changes: 1 addition & 1 deletion src/python/ads/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def to_bibcode(iterable):

def _get_data_path(basename=""):
from ads import __path__
return os.path.realpath(os.path.join(__path__[0], "../data", basename))
return os.path.realpath(os.path.join(__path__[0], "data", basename))

def setup_database():
""" Set up the local database for Journals and Affiliations. """
Expand Down