diff --git a/pyproject.toml b/pyproject.toml index 3ae232e..737419e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ ads-setup = "ads.utils:setup_database" ads-cli = "ads.cli:main" [tool.setuptools] +package-dir = {"" = "src/python"} zip-safe = false include-package-data = true @@ -52,8 +53,8 @@ include-package-data = true where = ["src/python"] [tool.setuptools.package-data] -"*" = ["*.json", "*.tsv"] -ads = ["bin/*"] +"*" = ["data/*.json", "data/*.tsv"] +"ads" = ["bin/*"] [project.optional-dependencies] dev = [ diff --git a/src/python/ads/__init__.py b/src/python/ads/__init__.py index b48e0b0..12c560f 100644 --- a/src/python/ads/__init__.py +++ b/src/python/ads/__init__.py @@ -1,5 +1,6 @@ import logging - +import os +from pathlib import Path debug = False @@ -18,10 +19,21 @@ logger.addHandler(handler) logger.setLevel(level) """ - + else: logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) -from ads.models import (Affiliation, Document, Journal, Library) -from ads.client import SearchQuery \ No newline at end of file +# setup config for first time +CONFIG = Path.home() / ".ads/config.json" +if not os.path.exists(CONFIG): + from ads.settings import ADSConfig + + os.makedirs(CONFIG.parents[0], exist_ok=True) + ADSConfig().save(CONFIG) + print("Generated config at ~/.ads/config.json") + + +# namespace discovery +from ads.client import SearchQuery +from ads.models import Affiliation, Document, Journal, Library diff --git a/data/affiliations.tsv b/src/python/ads/data/affiliations.tsv similarity index 100% rename from data/affiliations.tsv rename to src/python/ads/data/affiliations.tsv diff --git a/data/affiliations_country.tsv b/src/python/ads/data/affiliations_country.tsv similarity index 100% rename from data/affiliations_country.tsv rename to src/python/ads/data/affiliations_country.tsv diff --git a/data/journals.json b/src/python/ads/data/journals.json similarity index 100% rename from data/journals.json rename to src/python/ads/data/journals.json diff --git a/src/python/ads/settings.py b/src/python/ads/settings.py index 6a03ff5..794fbe0 100644 --- a/src/python/ads/settings.py +++ b/src/python/ads/settings.py @@ -1,11 +1,12 @@ """Configuration settings for the ADS API client using Pydantic.""" -import os import json -import yaml -from typing import List, Optional +import os from pathlib import Path -from pydantic import BaseModel, Field, field_validator, ConfigDict +from typing import List, Optional + +import yaml +from pydantic import BaseModel, ConfigDict, Field, field_validator __all__ = ["ADSConfig", "get_config"] @@ -17,35 +18,27 @@ class ADSConfig(BaseModel): # API Configuration api_url: str = Field( - default='https://api.adsabs.harvard.edu/v1', - description="Base URL for the ADS API" + default="https://api.adsabs.harvard.edu/v1", + description="Base URL for the ADS API", ) # Token Configuration tokens: Optional[List[str]] = Field( default=None, - description="List of API tokens for rotation. If None, will use single token discovery." + description="List of API tokens for rotation. If None, will use single token discovery.", ) # Async Request Configuration async_limit_per_host: int = Field( - default=10, - ge=1, - le=100, - description="Maximum concurrent connections per host" + default=10, ge=1, le=100, description="Maximum concurrent connections per host" ) async_limit: int = Field( - default=100, - ge=1, - le=1000, - description="Maximum total concurrent connections" + default=100, ge=1, le=1000, description="Maximum total concurrent connections" ) async_timeout: int = Field( - default=30, - ge=1, - description="Timeout for async requests in seconds" + default=30, ge=1, description="Timeout for async requests in seconds" ) # Retry Configuration @@ -53,13 +46,13 @@ class ADSConfig(BaseModel): default=3, ge=0, le=10, - description="Maximum number of retries for failed requests" + description="Maximum number of retries for failed requests", ) retry_base_delay: float = Field( default=1.0, gt=0, - description="Base delay in seconds for exponential backoff (delay = base * 2^attempt)" + description="Base delay in seconds for exponential backoff (delay = base * 2^attempt)", ) # Query Configuration @@ -67,14 +60,14 @@ class ADSConfig(BaseModel): default=200, ge=1, le=200, - description="Maximum number of rows per API request (ADS limit is 200)" + description="Maximum number of rows per API request (ADS limit is 200)", ) # Rate limit threshold for token rotation (conservative threshold) rate_limit_threshold: int = Field( default=4500, ge=0, - description="Remaining requests threshold to trigger token rotation" + description="Remaining requests threshold to trigger token rotation", ) # Token discovery configuration (legacy support) @@ -83,15 +76,15 @@ class ADSConfig(BaseModel): "~/.ads/token", "~/.ads/dev_key", ], - description="Paths to search for API token files" + description="Paths to search for API token files", ) token_environ_vars: List[str] = Field( default_factory=lambda: ["ADS_API_TOKEN", "ADS_DEV_KEY"], - description="Environment variables to check for API tokens" + description="Environment variables to check for API tokens", ) - @field_validator('tokens', mode='before') + @field_validator("tokens", mode="before") @classmethod def validate_tokens(cls, v): """Ensure tokens is a list if provided as a single string.""" @@ -99,10 +92,10 @@ def validate_tokens(cls, v): return None if isinstance(v, str): # Split by comma if multiple tokens provided as string - return [t.strip() for t in v.split(',') if t.strip()] + return [t.strip() for t in v.split(",") if t.strip()] return v - @field_validator('token_files', mode='before') + @field_validator("token_files", mode="before") @classmethod def expand_token_files(cls, v): """Expand user paths in token files.""" @@ -126,7 +119,7 @@ def discover_tokens(self) -> Optional[List[str]]: expanded_path = os.path.expanduser(filepath) if os.path.exists(expanded_path): try: - with open(expanded_path, 'r') as f: + with open(expanded_path, "r") as f: token = f.read().strip() if token: return [token] @@ -150,10 +143,10 @@ def from_file(cls, filepath: str) -> "ADSConfig": if not path.exists(): raise FileNotFoundError(f"Config file not found: {filepath}") - with open(path, 'r') as f: - if path.suffix == '.json': + with open(path, "r") as f: + if path.suffix == ".json": data = json.load(f) - elif path.suffix in ('.yaml', '.yml'): + elif path.suffix in (".yaml", ".yml"): data = yaml.safe_load(f) else: raise ValueError(f"Unsupported config file format: {path.suffix}") @@ -175,31 +168,41 @@ def from_env(cls) -> "ADSConfig": for key, value in os.environ.items(): if key.startswith(prefix): - config_key = key[len(prefix):].lower() + config_key = key[len(prefix) :].lower() # Handle list values (comma-separated) - if config_key in ('tokens', 'token_files', 'token_environ_vars'): - env_config[config_key] = [v.strip() for v in value.split(',')] + if config_key in ("tokens", "token_files", "token_environ_vars"): + env_config[config_key] = [v.strip() for v in value.split(",")] # Handle integer values - elif config_key in ('async_limit_per_host', 'async_limit', 'async_timeout', - 'max_retries', 'max_rows_per_request', 'rate_limit_threshold'): + elif config_key in ( + "async_limit_per_host", + "async_limit", + "async_timeout", + "max_retries", + "max_rows_per_request", + "rate_limit_threshold", + ): env_config[config_key] = int(value) # Handle float values - elif config_key == 'retry_base_delay': + elif config_key == "retry_base_delay": env_config[config_key] = float(value) else: env_config[config_key] = value return cls(**env_config) + def save(self, path: os.PathLike | Path): + with open(path, "w") as f: + import json + + json.dump(ADSConfig().model_dump(), f, indent=4) + # Global configuration instance (singleton-like) _global_config: Optional[ADSConfig] = None def get_config( - config_file: Optional[str] = "~/.ads/config.json", - use_env: bool = False, - **kwargs + config_file: Optional[str] = "~/.ads/config.json", use_env: bool = False, **kwargs ) -> ADSConfig: """ Get or create the global ADSConfig instance. diff --git a/src/python/ads/utils.py b/src/python/ads/utils.py index 5dae41e..a04cc6e 100644 --- a/src/python/ads/utils.py +++ b/src/python/ads/utils.py @@ -82,7 +82,7 @@ def to_bibcode(iterable): def _get_data_path(basename=""): from ads import __path__ - return os.path.realpath(os.path.join(__path__[0], "../data", basename)) + return os.path.realpath(os.path.join(__path__[0], "data", basename)) def setup_database(): """ Set up the local database for Journals and Affiliations. """