diff --git a/.env.example b/.env.example index 2a7de69..6cebd8c 100644 --- a/.env.example +++ b/.env.example @@ -19,3 +19,18 @@ LOG_LEVEL=INFO # 处理选项 USE_TQDM=True # 是否显示进度条 + +# SMB 网络访问配置(可选,用于访问远程 PC 上的 TDX 安装目录) +# 启用后 TDX_PATH 可不填 +SMB_ENABLED=false +SMB_HOST=192.168.1.100 +SMB_SHARE=tdx_share +SMB_USER=myuser +SMB_PASSWORD=mypassword +# TDX 在共享目录内的相对路径,若共享根目录即为 TDX 安装目录则留空 +SMB_TDX_PATH=TDX +SMB_PORT=445 +# 并发下载线程数(批量 SMB 模式) +SMB_WORKERS=16 +# 每批同步的股票数量 +SMB_BATCH_SIZE=200 diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index dcfa3db..882b82a 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -1,6 +1,3 @@ -# This workflow will install Python dependencies, run tests and lint with a variety of Python versions -# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python - name: Python package on: @@ -11,30 +8,23 @@ on: jobs: build: - runs-on: ubuntu-latest strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 pytest - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + pip install pytest + pip install -r requirements.txt - name: Test with pytest run: | - pytest + pytest tests/ -v diff --git a/.gitignore b/.gitignore index ab735a7..2ffc6b7 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,5 @@ poetry.lock output/ */__pycache__/ +tdx_data.db* +.claude/ \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index d40708d..5aa138b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -4,26 +4,29 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co ## 项目概述 -tdx2db:从本地通达信(TDX)行情软件读取 A 股数据,增量同步到数据库。是量化分析工作站的数据入口。 +tdx2db:从本地通达信(TDX)行情软件读取 A 股日线数据,增量同步到数据库。支持作为 Python 包被其他项目调用。 ## 常用命令 ```bash # 安装依赖 pip install -r requirements.txt +# 或安装为可编辑包(支持 import) +pip install -e . 
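+
+# 可编辑安装后的自检(示意命令;__version__ 定义在 src/tdx2db/__init__.py)
+python -c "import tdx2db; print(tdx2db.__version__)"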
-# 一键增量同步(日线 + 5/15/30/60 分钟线)— 日常使用这一个命令即可 +# 一键增量同步日线数据 — 日常使用这一个命令 python main.py sync # 单独同步 -python main.py daily --db-only --auto-start --incremental -python main.py minutes --db-only --auto-start --incremental +python main.py daily --incremental +python main.py daily --code 000001 --start 20240101 # 同步股票列表 -python main.py stock-list --db-only -``` +python main.py stock-list -无测试套件。验证方式是运行 `sync` 命令后检查数据库数据。 +# 运行测试 +python -m pytest tests/ -v +``` ## 架构 @@ -32,39 +35,63 @@ python main.py stock-list --db-only ``` CLI (cli.py) → Reader (reader.py) → Processor (processor.py) → Storage (storage.py) ↓ ↓ ↓ ↓ - argparse pytdx 读取本地 校验 + 重采样 + 均线 SQLAlchemy 批量写库 - 命令分发 + .day/.lc5 文件 (OHLCV 校验, resample, 支持增量 ON CONFLICT - 同步编排 MA5~MA250) 表名白名单保护 + argparse pytdx 读取本地 校验 + 复权处理 SQLAlchemy 批量写库 + 命令分发 + .day 文件 (OHLCV 校验, 前/后复权) 支持增量 ON CONFLICT + 同步编排 表名白名单保护 ``` -- **cli.py**: 除命令分发外,`sync_all_daily_data` / `sync_all_min_data` / `sync_single_stock_min_data` 编排逐股票流式同步 -- **config.py**: 全局单例 `config`,从 `.env` 加载配置(TDX_PATH、DB_*) +- **cli.py**: 命令分发 + `sync_all_daily()` 逐股票流式同步 +- **config.py**: 全局单例 `config`,从 `.env` 加载配置 - **logger.py**: 全局单例 `logger` +- **`__init__.py`**: 暴露 `TdxDailySync` 公共 API -### 关键数据流 +## 关键数据流 -日线和分钟线均为**逐股票流式处理**,不全量加载到内存: +逐股票流式处理,不全量加载到内存: -1. **日线**: 逐股票读取 `vipdoc/{sz,sh}/lday/*.day` → `process_daily_data()` 校验 OHLCV + 计算均线 → 增量写入 `daily_data` 表 -2. **分钟线**: 逐股票读取 `.lc5`(5 分钟)→ `resample_ohlcv()` 重采样为 15/30/60 分钟 → `process_min_data()` 校验 + 均线 → 分别写入 `minute{5,15,30,60}_data` 表 -3. **增量同步**: `save_incremental()` 使用批量 executemany + `ON CONFLICT DO NOTHING`(PostgreSQL)/ `INSERT IGNORE`(MySQL)跳过重复。分钟线按股票精确查询最新日期(`get_latest_datetime_by_code`),日线逐股票增量。 +1. 读取 `vipdoc/{sz,sh,bj}/lday/*.day` → `process_daily_data()` 校验 OHLCV + 复权 +2. 增量策略:`get_all_latest_dates()` 一次查询所有股票最新日期;若有除权事件则 `delete_stock_data()` + 全量重写 +3. 
`save_incremental()` 使用 `ON CONFLICT DO NOTHING`(PG)/ `INSERT OR IGNORE`(SQLite)/ `INSERT IGNORE`(MySQL) -### 数据库表 +## 数据库表 | 表名 | 唯一约束 | 用途 | |------|----------|------| -| `daily_data` | (code, date) | 日线数据 | -| `minute{5,15,30,60}_data` | (code, datetime) | 分钟线数据 | +| `daily_data` | (code, date) | 日线数据,date 为 YYYYMMDD 整数 | | `stock_info` | code | 股票列表 | -| `block_stock_relation` | — | 板块关系(未完整实现) | -唯一约束需通过 `scripts/add_constraints.sql` 手动添加。 +唯一约束由 SQLAlchemy `UniqueConstraint` 在建表时自动创建,无需手动执行 SQL 脚本。 -### 股票代码格式 +## 股票代码格式 -代码带市场前缀:`sz000001`、`sh600000`。深圳 market=0,上海 market=1。 -A 股筛选规则:深圳 `000/001/002/300` 开头,上海 `60/688` 开头。 +- CLI `--code` 参数:纯 6 位数字,如 `000001`、`600000`、`920001`,市场自动识别 +- 内部流转层:带市场前缀,如 `sz000001`、`sh600000`、`bj920001`(reader 内部使用) +- 数据库层:纯 6 位数字,如 `000001`(reader 写入时截取) +- 深圳 market=0,上海 market=1,北京 market=2 +- A 股筛选:深圳 `000/001/002/300` 开头,上海 `60/688` 开头,北交所 `8xxxxx` 或 `92xxxx` 开头 +- 市场自动识别规则:6 开头 → 上海(sh),8 或 92 开头 → 北京(bj),其他 → 深圳(sz) ## 配置 -通过 `.env` 文件配置,必填:`TDX_PATH`、`DB_TYPE`、`DB_HOST`、`DB_NAME`、`DB_USER`、`DB_PASSWORD`。 +通过 `.env` 文件配置: + +| 变量 | 必填 | 说明 | +|------|------|------| +| `TDX_PATH` | 是 | 通达信安装目录 | +| `DB_TYPE` | 否 | `sqlite`(默认)/ `mysql` / `postgresql` | +| `DB_NAME` | 否 | 数据库名,SQLite 时为文件名(生成 `.db`) | +| `DB_HOST` | MySQL/PG 必填 | 数据库主机 | +| `DB_USER` | MySQL/PG 必填 | 数据库用户名 | +| `DB_PASSWORD` | MySQL/PG 必填 | 数据库密码 | +| `DB_PORT` | 否 | 默认 `5432` | +| `DB_BATCH_SIZE` | 否 | 批量写入大小,默认 `10000` | +| `USE_TQDM` | 否 | 是否显示进度条,默认 `True` | + +## sync 命令增量策略 + +`python main.py sync` 内部行为: + +- 一次 SQL 查询获取所有股票最新日期(`SELECT code, MAX(date) FROM daily_data GROUP BY code`) +- 对每只股票:检查 gbbq 中是否有除权事件(category=1)发生在 last_date 之后 + - 有除权 → 删除该股旧数据,全量重写(保证复权价格正确) + - 无除权 → 只写入 last_date 之后的新数据 diff --git a/README.md b/README.md index 1ccfad1..a52c39d 100644 --- a/README.md +++ b/README.md @@ -1,107 +1,193 @@ -# 通达信数据处理工具 +# tdx2db -读取本地通达信股票数据,增量同步到数据库。 +从本地通达信(TDX)行情软件读取 A 股日线数据,增量同步到数据库。支持作为 Python 包被其他项目调用。 -## 快速开始 +## 测试环境 -```bash -# 一键同步所有数据(日线 + 5/15/30/60分钟线) -python main.py sync -``` +- 通达信版本:**金融终端 V7.72(64位)** + +## 特性 + +- 同步深圳/上海/北交所全量 A 股日线数据(含科创板) +- 前复权 / 后复权 / 不复权,默认前复权 +- 增量更新:有除权事件的个股自动全量重写,确保复权价格正确 +- 包含换手率数据(`turnover_rate`) +- 日期格式:`YYYYMMDD` 字符串(便于范围查询) +- 数据库:SQLite(默认)/ MySQL / PostgreSQL ## 安装 ```bash -# Python >= 3.10 +# 直接安装依赖 pip install -r requirements.txt -# 复制并编辑配置文件 -cp .env.example .env +# 或作为包安装(支持被其他项目 import) +pip install -e . ``` -**.env 必填配置**: +## 配置 + +复制 `.env.example` 为 `.env` 并填写: + ``` -TDX_PATH=D:\通达信安装目录 -DB_TYPE=postgresql -DB_HOST=localhost -DB_NAME=tdx_data +TDX_PATH=/path/to/tdx # 通达信安装目录(必填) +DB_TYPE=sqlite # sqlite / mysql / postgresql +DB_NAME=tdx_data # SQLite 时为文件名(生成 tdx_data.db) +DB_HOST=localhost # MySQL/PostgreSQL 必填 DB_USER=postgres DB_PASSWORD=your_password +DB_BATCH_SIZE=10000 +USE_TQDM=True ``` -## 首次使用 +## SMB 网络访问模式 -1. 打开通达信 → 选项 → 盘后数据下载 → 下载日线和分钟线数据 +如果通达信安装在另一台 Windows PC 上,可以通过 SMB 协议远程读取数据,无需把软件安装在运行本程序的机器上。 -2. 同步股票列表: -```bash -python main.py stock-list --db-only -``` +### 1. 在 Windows PC 上共享 TDX 目录 -3. 一键同步所有行情数据: -```bash -python main.py sync +右键 TDX 安装目录(如 `D:\new_tdx64`)→ 属性 → 共享 → 高级共享: + +- 勾选"共享此文件夹" +- 设置共享名,如 `new_tdx64` +- 权限 → 添加需要访问的账户,授予"读取"权限 + +### 2. 创建专用本地账户(推荐) + +**强烈建议**新建一个 Windows 本地账户用于 SMB 访问,而不是使用微软账户。微软账户通过 NTLM 网络认证时兼容性较差,容易登录失败。 + +在 Windows PC 上以管理员身份打开 PowerShell: + +```powershell +# 创建本地账户(替换为你想要的用户名和密码) +net user tdxread YourPassword123 /add +net localgroup Users tdxread /add ``` -## 启用增量同步(推荐) +然后在共享权限中把 `tdxread` 加入,授予读取权限。 -> 增量同步可自动跳过重复数据,大幅提升每日更新效率。 +### 3. 
配置 .env -**老用户**(已有数据库表)需执行一次约束脚本: -```bash -# PostgreSQL -psql -U your_user -d your_database -f scripts/add_constraints.sql ``` +SMB_ENABLED=true +SMB_HOST=192.168.1.100 # Windows PC 的 IP 或主机名 +SMB_SHARE=new_tdx64 # 共享名 +SMB_USER=tdxread # 本地账户用户名 +SMB_PASSWORD=YourPassword123 +SMB_TDX_PATH= # TDX 在共享内的相对路径,共享根目录就是 TDX 目录时留空 +SMB_PORT=445 +``` + +启用 SMB 模式后,`TDX_PATH` 可以不填。 + +## 准备工作:下载历史盘后数据 -**新用户**同样建议执行,以启用增量同步功能。 +使用本程序前,需要先在通达信中下载历史盘后数据: -脚本作用:为 `daily_data`、`minute*_data` 表添加唯一约束,确保 `(code, date/datetime)` 不重复。 +1. 打开通达信客户端 +2. 菜单栏 → **选项** → **盘后数据下载** +3. 选择需要的历史数据范围并下载完成 -## 每日更新 +> 程序读取的是通达信本地 `.day` 文件,必须先确保数据已通过上述方式下载到本地,否则无法同步。 + +## 命令行使用 ```bash -python main.py sync -``` +# 同步股票列表 +python main.py stock-list -程序会自动检测数据库最新日期,只同步新数据。 +# 一键增量同步所有股票日线(日常使用这一个命令) +python main.py sync -## 其他命令 +# 同步所有股票日线(全量) +python main.py daily -
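+# 说明:sync 走增量策略,daily 不带 --incremental 时为全量同步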
-单独同步日线/分钟线 +# 同步指定股票(6位代码,自动识别市场) +python main.py daily --code 000001 -```bash -# 日线增量同步 -python main.py daily --db-only --auto-start --incremental +# 指定日期范围 +python main.py daily --start 20240101 --end 20241231 -# 分钟线增量同步 -python main.py minutes --db-only --auto-start --incremental +# 指定复权类型 +python main.py sync --adj backward ``` -
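+
+`--adj` 的取值与包内 `adj_type` 参数对应:`forward`(前复权,默认)、`backward`(后复权);不复权对应的取值本文未列出,可用 `-h` 查看。示意:
+
+```bash
+python main.py sync --adj forward   # 显式指定前复权,与默认行为一致
+python main.py sync -h              # 查看 --adj 的全部可选值
+```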
-
-指定日期范围 +安装为包后也可直接使用 `tdx2db` 命令: ```bash -python main.py daily --db-only --start_date 2025-01-01 --end_date 2025-01-31 -python main.py minutes --db-only --start_date 2025-01-01 +tdx2db sync ``` -
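+
+`tdx2db` 入口由 `pyproject.toml` 中 `[tool.poetry.scripts]` 的 `tdx2db = "tdx2db.cli:main"` 注册,与 `python main.py` 调用的是同一个 `main()`,子命令完全一致:
+
+```bash
+tdx2db stock-list
+tdx2db daily --code 000001 --start 20240101
+```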
-
-导出到 CSV +## 作为 Python 包调用 -```bash -python main.py daily --csv-only -python main.py minutes --csv-only +```python +from tdx2db import TdxDailySync + +sync = TdxDailySync( + tdx_path="/path/to/tdx", + db_url="sqlite:///data.db", +) + +# 同步所有股票 +sync.sync_all(adj_type='forward') + +# 同步单只股票 +sync.sync_stock('000001', start_date=20240101) + +# 查询数据 +df = sync.get_daily('000001', start_date=20240101, end_date=20241231) +print(df.head()) ``` -
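+
+日常定时任务中常见的写法是「先增量同步、再接 pandas 计算」(示意代码;`ma20` 为本地计算的临时列,并非库中字段):
+
+```python
+from tdx2db import TdxDailySync
+
+sync = TdxDailySync(tdx_path="/path/to/tdx", db_url="sqlite:///data.db")
+sync.sync_all()                               # 默认前复权 + 增量更新
+df = sync.get_daily('000001', start_date=20240101)
+df['ma20'] = df['close'].rolling(20).mean()   # 在复权后收盘价上计算 20 日均线
+print(df.tail())
+```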
-## 数据库支持 +## 数据表结构 + +数据库包含以下三张表,由 SQLAlchemy 在首次运行时自动创建,无需手动执行 SQL。 + +### daily_data(日线数据) + +| 列 | 类型 | 说明 | +|----|------|------| +| id | Integer | 自增主键 | +| stock_code | String(12) | 股票代码(6位,如 `000001`) | +| market | Integer | 市场(0=深圳,1=上海,2=北京) | +| date | String(8) | 日期,格式 `YYYYMMDD` | +| open | Float | 开盘价(复权后) | +| high | Float | 最高价(复权后) | +| low | Float | 最低价(复权后) | +| close | Float | 收盘价(复权后) | +| volume | Float | 成交量(手) | +| amount | Float | 成交额(元) | +| adj_factor | Float | 复权因子(前复权时 < 1,不复权时 = 1.0) | +| turnover_rate | Float | 换手率(%),无法计算时为 NULL | -- PostgreSQL(推荐) -- MySQL -- SQLite +唯一约束:`(stock_code, date)` + +### stock_info(股票列表) + +| 列 | 类型 | 说明 | +|----|------|------| +| stock_code | String(12) | 股票代码(主键,如 `000001`) | +| stock_name | String(50) | 股票名称(如 `平安银行`) | + +唯一约束:`stock_code`(即主键) + +### kline_statistics(同步统计) + +每次 `sync` 命令完成后写入一条统计记录,用于追踪历次同步情况。 + +| 列 | 类型 | 说明 | +|----|------|------| +| id | Integer | 自增主键 | +| stock_count | Integer | 本次同步的股票数量 | +| total_rows | Integer | 同步后 daily_data 的总行数 | +| sync_time | DateTime | 同步完成时间 | + +## 运行测试 + +```bash +pip install pytest +python -m pytest tests/ -v +``` ## 许可证 diff --git a/__init__.py b/__init__.py deleted file mode 100644 index e67467c..0000000 --- a/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""TDX2DB""" -__version__ = '0.0.1' -__all__ = ['src'] diff --git a/main.py b/main.py index 85b538d..2081a48 100644 --- a/main.py +++ b/main.py @@ -1,21 +1,9 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- - -""" -通达信数据处理工具入口点 - -此脚本是程序的主入口点,用于启动通达信数据处理工具。 -支持以下功能: -- 获取股票列表:python main.py stock-list -- 获取日线数据:python main.py daily [--code CODE] [--market MARKET] -- 获取分钟线数据:python main.py minute --code CODE --market MARKET [--freq {1,5}] -- 获取分钟线数据:python main.py minutes --code CODE --market MARKET [--freq {1,5}] - -使用 python main.py -h 查看完整的帮助信息。 -""" +"""tdx2db 命令行入口""" import sys -from src.cli import main +from src.tdx2db.cli import main if __name__ == '__main__': try: diff --git a/pyproject.toml b/pyproject.toml index 3d7df54..9fc9091 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,12 +1,16 @@ [tool.poetry] -name = "tdx-data-processor" -version = "0.1.0" -description = "一个使用pytdx读取本地股票数据并存储到数据库或CSV的程序" -authors = ["我如此拉风 "] +name = "tdx2db" +version = "0.2.0" +description = "从通达信本地文件同步 A 股日线数据到数据库" +authors = ["jaden1q84 "] readme = "README.md" +packages = [{include = "tdx2db", from = "src"}] + +[tool.poetry.scripts] +tdx2db = "tdx2db.cli:main" [tool.poetry.dependencies] -python = "^3.11" +python = ">=3.10" pytdx = "^1.72" pandas = "^2.2.3" sqlalchemy = "^2.0.40" @@ -15,6 +19,9 @@ python-dotenv = "^1.1.0" pymysql = "^1.1.1" psycopg2-binary = "^2.9.10" +[tool.poetry.group.dev.dependencies] +pytest = "^8.0" + [[tool.poetry.source]] name = "aliyun" url = "https://mirrors.aliyun.com/pypi/simple/" diff --git a/requirements.txt b/requirements.txt index 2977e81..5ab2db5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,8 @@ pandas>=2.2.3 sqlalchemy>=2.0.40 tqdm>=4.67.1 python-dotenv>=1.1.0 +pymysql>=1.1.1 psycopg2-binary>=2.9.10 +requests>=2.32.0 +smbprotocol>=1.13.0 +dbfread>=2.0.7 diff --git a/scripts/add_constraints.sql b/scripts/add_constraints.sql deleted file mode 100644 index 258478d..0000000 --- a/scripts/add_constraints.sql +++ /dev/null @@ -1,70 +0,0 @@ --- 增量同步约束脚本 --- 为各数据表添加唯一约束,确保 (code, datetime) 组合唯一 --- 执行前请先备份数据库 - --- ============================================ --- 第一步:清理重复数据(如有) --- ============================================ - --- 清理 daily_data 重复记录,保留 id 最大的(日线使用 date 字段) 
-DELETE FROM daily_data a USING daily_data b -WHERE a.id < b.id AND a.code = b.code AND a.date = b.date; - --- 清理 minute5_data 重复记录 -DELETE FROM minute5_data a USING minute5_data b -WHERE a.id < b.id AND a.code = b.code AND a.datetime = b.datetime; - --- 清理 minute15_data 重复记录 -DELETE FROM minute15_data a USING minute15_data b -WHERE a.id < b.id AND a.code = b.code AND a.datetime = b.datetime; - --- 清理 minute30_data 重复记录 -DELETE FROM minute30_data a USING minute30_data b -WHERE a.id < b.id AND a.code = b.code AND a.datetime = b.datetime; - --- 清理 minute60_data 重复记录 -DELETE FROM minute60_data a USING minute60_data b -WHERE a.id < b.id AND a.code = b.code AND a.datetime = b.datetime; - --- 清理 block_stock_relation 重复记录 -DELETE FROM block_stock_relation a USING block_stock_relation b -WHERE a.id < b.id AND a.block_code = b.block_code AND a.code = b.code; - --- ============================================ --- 第二步:添加唯一约束 --- ============================================ - --- daily_data 表:(code, date) 唯一(日线使用 date 字段) -ALTER TABLE daily_data -ADD CONSTRAINT uq_daily_code_date UNIQUE (code, date); - --- minute5_data 表 -ALTER TABLE minute5_data -ADD CONSTRAINT uq_minute5_code_datetime UNIQUE (code, datetime); - --- minute15_data 表 -ALTER TABLE minute15_data -ADD CONSTRAINT uq_minute15_code_datetime UNIQUE (code, datetime); - --- minute30_data 表 -ALTER TABLE minute30_data -ADD CONSTRAINT uq_minute30_code_datetime UNIQUE (code, datetime); - --- minute60_data 表 -ALTER TABLE minute60_data -ADD CONSTRAINT uq_minute60_code_datetime UNIQUE (code, datetime); - --- block_stock_relation 表:(block_code, code) 唯一 -ALTER TABLE block_stock_relation -ADD CONSTRAINT uq_block_code UNIQUE (block_code, code); - --- stock_info 表:code 已有唯一索引,无需额外操作 - --- ============================================ --- 验证约束(可选) --- ============================================ - --- 查看所有约束 --- SELECT conname, conrelid::regclass, pg_get_constraintdef(oid) --- FROM pg_constraint --- WHERE conname LIKE 'uq_%'; diff --git a/src/__init__.py b/src/__init__.py deleted file mode 100644 index afefaec..0000000 --- a/src/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""TDX数据处理器 - 用于读取通达信本地股票数据并存储到数据库或CSV""" - -__version__ = "0.1.0" - -from .logger import logger, setup_logger - -__all__ = ["logger", "setup_logger", "__version__"] diff --git a/src/cli.py b/src/cli.py deleted file mode 100644 index 5acfa5b..0000000 --- a/src/cli.py +++ /dev/null @@ -1,483 +0,0 @@ -"""命令行接口模块 - -提供命令行接口,方便用户使用程序功能 -""" - -import argparse -import sys -from argparse import Namespace -from typing import Optional - -from datetime import timedelta - -import pandas as pd -from tqdm import tqdm - -from .reader import TdxDataReader -from .processor import DataProcessor -from .storage import DataStorage -from .config import config -from .logger import logger - - -def sync_single_stock_min_data( - reader: TdxDataReader, - processor: DataProcessor, - storage: DataStorage, - market: int, - code: str, - start_date: Optional[str] = None, - incremental: bool = True, -) -> bool: - """处理并存储单只股票的分钟数据 - - Args: - reader: 数据读取器 - processor: 数据处理器 - storage: 数据存储器 - market: 市场代码 - code: 股票代码 - start_date: 开始日期 - incremental: 是否启用精确增量 - """ - # DB 中 code 为纯 6 位数字(reader 写入时会截取),查询时需匹配 - db_code = code[-6:] if len(code) > 6 else code - - # 精确增量:查询该股票的最新日期 - if incremental and not start_date: - latest = storage.get_latest_datetime_by_code('minute5_data', db_code) - if latest: - start_date = (latest + timedelta(days=1)).strftime('%Y-%m-%d') - logger.debug(f"{code} 增量起始日期: {start_date}") - - # 
读取5分钟数据 - df_5min = reader.read_5min_data(market, code) - if df_5min.empty: - logger.warning(f"{code} 无5分钟数据") - return False - - # 准备 datetime 索引 - if not pd.api.types.is_datetime64_any_dtype(df_5min['datetime']): - df_5min['datetime'] = pd.to_datetime(df_5min['datetime']) - df_5min['date'] = df_5min['datetime'].dt.date - df_5min = df_5min.set_index('datetime') - - # 重采样为多周期 - df_15min = DataProcessor.resample_ohlcv(df_5min, '15min') - df_30min = DataProcessor.resample_ohlcv(df_5min, '30min') - df_60min = DataProcessor.resample_ohlcv(df_5min, '60min') - df_5min = df_5min.reset_index() - - # 处理、筛选、存储各周期 - freq_data = [ - (df_5min, 5, 'minute5_data'), - (df_15min, 15, 'minute15_data'), - (df_30min, 30, 'minute30_data'), - (df_60min, 60, 'minute60_data'), - ] - - has_data = False - for df, freq, table_name in freq_data: - processed = processor.process_min_data(df) - if start_date: - processed = processor.filter_data_min(processed, start_date=start_date) - if processed.empty: - continue - has_data = True - if incremental: - storage.save_incremental(processed, table_name) - else: - storage.save_minute_data(processed, freq=freq, to_csv=False, to_db=True) - - if has_data: - logger.info(f"{code} 分钟数据已处理并存入数据库") - else: - logger.debug(f"{code} 无新数据需要同步") - - return True - - -def sync_all_daily_data( - reader: TdxDataReader, - processor: DataProcessor, - storage: DataStorage, - start_date: Optional[str] = None, -) -> bool: - """逐股票流式同步日线数据,避免全量加载到内存""" - try: - stocks = reader.get_stock_list() - logger.info(f"同步日线数据,共 {len(stocks)} 只股票") - - iterator = tqdm(stocks.iterrows(), total=len(stocks)) if config.use_tqdm else stocks.iterrows() - total_inserted = 0 - - for _, stock in iterator: - code = stock['code'] - market = 1 if code.startswith('sh') else 0 - try: - data = reader.read_daily_data(market, code) - if isinstance(data.index, pd.DatetimeIndex) or data.index.name == 'datetime': - data = data.reset_index() - if data.empty: - continue - - processed = processor.process_daily_data(data) - filtered = processor.filter_data(processed, start_date=start_date) - if filtered.empty: - continue - - inserted = storage.save_incremental( - filtered, 'daily_data', - conflict_columns=('code', 'date'), - batch_size=config.db_batch_size - ) - total_inserted += inserted - except FileNotFoundError: - continue - except Exception as e: - logger.error(f"同步 {code} 日线数据时出错: {e}") - continue - - if total_inserted > 0: - logger.info(f"日线数据同步完成,共插入 {total_inserted} 条") - else: - logger.info("日线数据已是最新") - return True - except Exception as e: - logger.error(f"同步日线数据时出错: {e}") - return False - - -def sync_all_min_data( - reader: TdxDataReader, - processor: DataProcessor, - storage: DataStorage, - start_date: Optional[str] = None, -) -> bool: - """编排所有股票的分钟数据同步""" - try: - stocks = reader.get_stock_list() - logger.info(f"处理所有股票的分钟数据,共 {len(stocks)} 只股票") - - iterator = tqdm(stocks.iterrows(), total=len(stocks)) if config.use_tqdm else stocks.iterrows() - - for _, stock in iterator: - code = stock['code'] - market = 1 if code.startswith('sh') else 0 - try: - sync_single_stock_min_data(reader, processor, storage, market, code, start_date) - except FileNotFoundError: - continue - except Exception as e: - logger.error(f"处理 {code} 分钟数据时出错: {e}") - continue - - return True - except Exception as e: - logger.error(f"处理分钟数据时出错: {e}") - return False - -def parse_args() -> Namespace: - """解析命令行参数 - - Returns: - Namespace: 解析后的命令行参数 - """ - parser = argparse.ArgumentParser(description='通达信数据处理工具') - - # 通用参数 - 
parser.add_argument('--tdx-path', help='通达信安装目录路径') - parser.add_argument('--output', help='输出CSV文件的目录路径') - parser.add_argument('--db-type', choices=['sqlite', 'mysql', 'postgresql'], help='数据库类型') - parser.add_argument('--db-host', help='数据库主机') - parser.add_argument('--db-port', help='数据库端口') - parser.add_argument('--db-name', help='数据库名称') - parser.add_argument('--db-user', help='数据库用户名') - parser.add_argument('--db-password', help='数据库密码') - parser.add_argument('--no-tqdm', action='store_true', help='禁用进度条') - parser.add_argument('--batch-size', type=int, default=10000, help='数据库批量插入的批次大小,默认10000条') - - # 子命令 - subparsers = parser.add_subparsers(dest='command', help='子命令') - - # 获取股票列表 - stock_list_parser = subparsers.add_parser('stock-list', help='获取股票列表') - stock_list_parser.add_argument('--csv-only', action='store_true', help='仅保存到CSV') - stock_list_parser.add_argument('--db-only', action='store_true', help='仅保存到数据库') - - # 获取日线数据 - daily_parser = subparsers.add_parser('daily', help='获取日线数据') - daily_parser.add_argument('--code', help='股票代码,不指定则获取所有股票') - daily_parser.add_argument('--market', type=int, choices=[0, 1], help='市场代码,0表示深圳,1表示上海') - daily_parser.add_argument('--start_date', help='开始日期,格式为YYYY-MM-DD') - daily_parser.add_argument('--end_date', help='结束日期,格式为YYYY-MM-DD') - daily_parser.add_argument('--csv-only', action='store_true', help='仅保存到CSV') - daily_parser.add_argument('--db-only', action='store_true', help='仅保存到数据库') - daily_parser.add_argument('--auto-start', action='store_true', help='自动检测起始日期(从数据库最新日期+1天开始)') - daily_parser.add_argument('--incremental', action='store_true', help='增量同步模式,跳过重复数据') - - # 获取并计算分钟线数据 - min_parser = subparsers.add_parser('minutes', help='获取分钟线数据') - min_parser.add_argument('--code', help='股票代码,不指定则获取所有股票') - min_parser.add_argument('--market', type=int, choices=[0, 1], help='市场代码,0表示深圳,1表示上海') - min_parser.add_argument('--start_date', help='开始日期,格式为YYYY-MM-DD') - min_parser.add_argument('--end_date', help='结束日期,格式为YYYY-MM-DD') - min_parser.add_argument('--csv-only', action='store_true', help='仅保存到CSV') - min_parser.add_argument('--db-only', action='store_true', help='仅保存到数据库') - min_parser.add_argument('--auto-start', action='store_true', help='自动检测起始日期(从数据库最新日期+1天开始)') - min_parser.add_argument('--incremental', action='store_true', help='增量同步模式,跳过重复数据') - - # 获取板块与股票对应关系 - block_relation_parser = subparsers.add_parser('block-relation', help='获取板块与股票对应关系【未实现】') - block_relation_parser.add_argument('--csv-only', action='store_true', help='仅保存到CSV') - block_relation_parser.add_argument('--db-only', action='store_true', help='仅保存到数据库') - - # 一键同步(日线 + 分钟线增量同步到数据库) - subparsers.add_parser('sync', help='一键增量同步所有数据到数据库(日线 + 5/15/30/60分钟线)') - - return parser.parse_args() - -def update_config(args: Namespace) -> None: - """根据命令行参数更新配置 - - Args: - args: 解析后的命令行参数 - """ - # 更新通达信路径 - if args.tdx_path: - config.tdx_path = args.tdx_path - - # 更新CSV输出路径 - if args.output: - config.csv_output_path = args.output - - # 更新数据库配置 - if args.db_type: - config.db_type = args.db_type - if args.db_host: - config.db_host = args.db_host - if args.db_port: - config.db_port = args.db_port - if args.db_name: - config.db_name = args.db_name - if args.db_user: - config.db_user = args.db_user - if args.db_password: - config.db_password = args.db_password - if args.batch_size: - config.db_batch_size = args.batch_size - - # 更新进度条配置 - if args.no_tqdm: - config.use_tqdm = False - -def main() -> int: - """主函数 - - Returns: - int: 程序退出码,0表示成功,非0表示失败 - """ - args = 
parse_args() - update_config(args) - - # 初始化数据读取器 - try: - reader = TdxDataReader() - except (ValueError, FileNotFoundError) as e: - logger.error(f"初始化失败: {e}") - return 1 - - # 初始化数据存储器 - storage = DataStorage() - - # 处理子命令 - if args.command == 'stock-list': - # 获取股票列表 - try: - stocks = reader.get_stock_list() - logger.info(f"获取到 {len(stocks)} 只股票信息") - - # 确定保存方式 - to_csv = not args.db_only - to_db = not args.csv_only - - # 保存数据 - storage.save_stock_info(stocks, to_csv=to_csv, to_db=to_db, batch_size=config.db_batch_size) - - except Exception as e: - logger.error(f"获取股票列表时出错: {e}") - return 1 - - elif args.command == 'daily': - try: - # 处理 --auto-start 参数 - start_date = args.start_date - if hasattr(args, 'auto_start') and args.auto_start and not start_date: - latest = storage.get_latest_datetime('daily_data', date_column='date') - if latest: - start_date = (latest + timedelta(days=1)).strftime('%Y-%m-%d') - logger.info(f"自动检测起始日期: {start_date}") - else: - logger.info("数据库中没有数据,将获取所有数据") - - # 获取日线数据 - if args.code and args.market is not None: - # 获取单只股票的日线数据 - data = reader.read_daily_data(args.market, args.code) - else: - # 获取所有股票的日线数据 - data = reader.read_all_daily_data() - - if data.empty: - logger.warning("未获取到任何数据") - return 0 - - logger.info(f"获取到 {len(data)} 条日线数据记录") - - # 处理数据 - processor = DataProcessor() - processed_data = processor.process_daily_data(data) - - # 根据日期筛选 - filtered_data = processor.filter_data( - processed_data, - start_date=start_date, - end_date=args.end_date, - codes=[args.code] if args.code else None - ) - - if filtered_data.empty: - logger.warning("筛选后没有数据") - return 0 - - logger.info(f"筛选后有 {len(filtered_data)} 条日线数据记录") - - # 确定保存方式 - to_csv = not args.db_only - to_db = not args.csv_only - incremental = hasattr(args, 'incremental') and args.incremental - - # 保存数据 - if to_csv: - storage.save_to_csv(filtered_data, 'daily_data') - if to_db: - if incremental: - storage.save_incremental( - filtered_data, 'daily_data', - conflict_columns=('code', 'date'), - batch_size=config.db_batch_size - ) - else: - storage.save_to_database(filtered_data, 'daily_data', batch_size=config.db_batch_size) - - except Exception as e: - logger.error(f"获取日线数据时出错: {e}") - return 1 - - elif args.command == 'minutes': - try: - # 处理 --auto-start 参数(使用15分钟线表作为参考) - start_date = args.start_date - if hasattr(args, 'auto_start') and args.auto_start and not start_date: - latest = storage.get_latest_datetime('minute15_data') - if latest: - start_date = (latest + timedelta(days=1)).strftime('%Y-%m-%d') - logger.info(f"自动检测起始日期: {start_date}") - else: - logger.info("数据库中没有数据,将获取所有数据") - - incremental = hasattr(args, 'incremental') and args.incremental - - # 获取分钟线数据 - if args.code and args.market is not None: - # 单只股票:统一走 sync_single_stock_min_data,覆盖 5/15/30/60 全部周期 - processor = DataProcessor() - success = sync_single_stock_min_data( - reader, processor, storage, - args.market, args.code, - start_date=start_date, - incremental=incremental, - ) - if not success: - logger.warning(f"股票 {args.code} 无数据可同步") - return 0 - else: - # 获取所有股票的分钟线数据 - logger.info("开始处理所有股票的分钟线数据...") - processor = DataProcessor() - success = sync_all_min_data(reader, processor, storage, start_date) - if success: - logger.info("所有股票的分钟线数据处理完成") - else: - logger.error("处理分钟线数据时出错") - return 1 - - except Exception as e: - logger.error(f"获取分钟线数据时出错: {e}") - return 1 - - elif args.command == 'block-relation': - # 获取板块与股票对应关系 - try: - block_relations = reader.get_block_stock_relation() - logger.info(f"获取到 
{len(block_relations)} 条板块与股票对应关系记录") - - # 确定保存方式 - to_csv = not args.db_only - to_db = not args.csv_only - - # 保存数据 - storage.save_block_relation(block_relations, to_csv=to_csv, to_db=to_db, batch_size=config.db_batch_size) - - except Exception as e: - logger.error(f"获取板块与股票对应关系时出错: {e}") - return 1 - - elif args.command == 'sync': - # 一键增量同步所有数据 - logger.info("开始一键增量同步...") - processor = DataProcessor() - has_error = False - - # 1. 同步日线数据 - try: - logger.info("=== 同步日线数据 ===") - latest = storage.get_latest_datetime('daily_data', date_column='date') - start_date = None - if latest: - start_date = (latest + timedelta(days=1)).strftime('%Y-%m-%d') - logger.info(f"日线起始日期: {start_date}") - - success = sync_all_daily_data(reader, processor, storage, start_date) - if not success: - logger.error("同步日线数据时出错") - has_error = True - except Exception as e: - logger.error(f"同步日线数据时出错: {e}") - has_error = True - - # 2. 同步分钟线数据(逐股票精确增量,不传全局 start_date) - try: - logger.info("=== 同步分钟线数据 ===") - success = sync_all_min_data(reader, processor, storage) - if not success: - logger.error("同步分钟线数据时出错") - has_error = True - except Exception as e: - logger.error(f"同步分钟线数据时出错: {e}") - has_error = True - - if has_error: - logger.warning("同步完成,但有部分错误") - return 1 - else: - logger.info("一键增量同步完成!") - - else: - logger.error("请指定子命令,使用 -h 查看帮助信息") - return 1 - - return 0 - -if __name__ == '__main__': - sys.exit(main()) diff --git a/src/processor.py b/src/processor.py deleted file mode 100644 index 5fb2e3f..0000000 --- a/src/processor.py +++ /dev/null @@ -1,288 +0,0 @@ -"""数据处理模块 - -负责对从通达信读取的原始数据进行清洗和转换,包括: -- 数据格式转换 -- 缺失值处理 -- 异常值检测 -- 计算技术指标 -- OHLCV 重采样 -""" - -from typing import Optional, List -import pandas as pd - -from .logger import logger - -# 重采样聚合规则 -RESAMPLE_AGG = { - 'open': 'first', - 'high': 'max', - 'low': 'min', - 'close': 'last', - 'volume': 'sum', - 'amount': 'sum', - 'code': 'first', - 'market': 'first', -} - -# 均线周期 -MA_WINDOWS = [5, 10, 13, 21, 34, 55, 60, 89, 144, 233, 250] - - -class DataProcessor: - """数据处理类""" - - @staticmethod - def resample_ohlcv(df: pd.DataFrame, freq: str) -> pd.DataFrame: - """将 OHLCV 数据重采样到目标频率 - - Args: - df: 带有 DatetimeIndex 的 DataFrame - freq: pandas resample 频率字符串('15min', '30min', '60min') - - Returns: - 重采样后的 DataFrame(已 reset_index) - """ - agg = dict(RESAMPLE_AGG) - if 'date' in df.columns: - agg['date'] = 'first' - result = df.resample(freq).agg(agg).dropna() - result.reset_index(inplace=True) - return result - - @staticmethod - def _validate_ohlcv(df: pd.DataFrame) -> pd.DataFrame: - """校验 OHLCV 数据质量,丢弃不合格行 - - 校验规则: - 1. 价格列(open/high/low/close)必须 > 0 - 2. 
OHLC 关系:high >= max(open, close), low <= min(open, close) - - Args: - df: 包含 OHLCV 列的 DataFrame - - Returns: - 校验通过的 DataFrame - """ - required = ['open', 'high', 'low', 'close'] - if not all(col in df.columns for col in required): - return df - - before = len(df) - - # 价格必须为正 - positive_mask = (df[required] > 0).all(axis=1) - - # OHLC 关系校验 - ohlc_mask = ( - (df['high'] >= df[['open', 'close']].max(axis=1)) & - (df['low'] <= df[['open', 'close']].min(axis=1)) - ) - - valid_mask = positive_mask & ohlc_mask - df = df[valid_mask] - - dropped = before - len(df) - if dropped > 0: - logger.warning(f"数据校验丢弃 {dropped} 条不合格记录(价格非正或 OHLC 关系异常)") - - return df - - @staticmethod - def _calculate_ma(df: pd.DataFrame) -> pd.DataFrame: - """计算均线指标,按股票分组 - - Args: - df: 包含 'close' 和 'code' 列的 DataFrame - - Returns: - 添加了均线列的 DataFrame - """ - for w in MA_WINDOWS: - df[f'ma{w}'] = df.groupby('code')['close'].transform( - lambda x: x.rolling(window=w).mean() - ) - return df - - @staticmethod - def process_daily_data(df: pd.DataFrame) -> pd.DataFrame: - """处理日线数据 - - Args: - df: 原始日线数据DataFrame - - Returns: - DataFrame: 处理后的数据 - """ - if df.empty: - return df - - # 复制数据,避免修改原始数据 - processed_df = df.copy() - - # 确保datetime列存在 - if 'datetime' not in processed_df.columns: - # 检查是否有索引中包含日期时间信息 - if processed_df.index.name == 'datetime' or isinstance(processed_df.index, pd.DatetimeIndex): - # 如果索引是日期时间类型,直接将索引转为列 - processed_df['datetime'] = processed_df.index - # 如果索引不是日期时间类型但包含日期信息(如终端输出所示) - elif hasattr(processed_df.iloc[-1], 'name') and isinstance(processed_df.iloc[-1].name, pd.Timestamp): - # 从行索引名称中提取日期时间 - processed_df['datetime'] = processed_df.apply(lambda row: row.name if isinstance(row.name, pd.Timestamp) else None, axis=1) - - # 处理缺失值 - numeric_columns = ['open', 'high', 'low', 'close', 'volume', 'amount'] - for col in numeric_columns: - if col in processed_df.columns: - # 用前一个有效值填充缺失值 - processed_df[col] = processed_df[col].ffill() - - # 数据质量校验 - processed_df = DataProcessor._validate_ohlcv(processed_df) - - # 计算均线指标 - if all(col in processed_df.columns for col in ['close', 'volume']): - processed_df = DataProcessor._calculate_ma(processed_df) - - return processed_df - - @staticmethod - def process_min_data(df: pd.DataFrame) -> pd.DataFrame: - """处理分钟线数据 - - Args: - df: 原始分钟线数据DataFrame - - Returns: - DataFrame: 处理后的数据 - """ - if df.empty: - return df - - # 复制数据,避免修改原始数据 - processed_df = df.copy() - - - # 重命名列,使其更符合通用命名 - column_mapping = { - 'amount': 'amount', # 成交额 - 'close': 'close', # 收盘价 - 'open': 'open', # 开盘价 - 'high': 'high', # 最高价 - 'low': 'low', # 最低价 - 'vol': 'volume', # 成交量 - 'year': 'year', # 年 - 'month': 'month', # 月 - 'day': 'day', # 日 - 'hour': 'hour', # 时 - 'minute': 'minute', # 分 - 'datetime': 'datetime', # 日期时间 - 'code': 'code', # 股票代码 - 'market': 'market' # 市场代码 - } - processed_df.rename(columns={k: v for k, v in column_mapping.items() if k in processed_df.columns}, inplace=True) - - # 确保datetime列存在 - if 'datetime' not in processed_df.columns and all(col in processed_df.columns for col in ['year', 'month', 'day', 'hour', 'minute']): - processed_df['datetime'] = pd.to_datetime( - processed_df[['year', 'month', 'day']].astype(str).agg('-'.join, axis=1) + ' ' + - processed_df[['hour', 'minute']].astype(str).agg(':'.join, axis=1) - ) - - # 处理缺失值 - numeric_columns = ['open', 'high', 'low', 'close', 'volume', 'amount'] - for col in numeric_columns: - if col in processed_df.columns: - # 用前一个有效值填充缺失值 - processed_df[col] = processed_df[col].ffill() - - # 数据质量校验 - processed_df = 
DataProcessor._validate_ohlcv(processed_df) - - # 计算均线指标 - if all(col in processed_df.columns for col in ['close', 'volume']): - processed_df = DataProcessor._calculate_ma(processed_df) - - return processed_df - - @staticmethod - def filter_data( - df: pd.DataFrame, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - codes: Optional[List[str]] = None - ) -> pd.DataFrame: - """根据条件筛选数据 - - Args: - df: 原始数据DataFrame - start_date: 开始日期,格式为'YYYY-MM-DD' - end_date: 结束日期,格式为'YYYY-MM-DD' - codes: 股票代码列表 - - Returns: - DataFrame: 筛选后的数据 - """ - if df.empty: - return df - - filtered_df = df.copy() - - - logger.debug(f"筛选日期范围: start_date={start_date}, end_date={end_date}") - # 按日期筛选 - if 'date' in filtered_df.columns: - if start_date: - filtered_df = filtered_df[filtered_df['date'] >= pd.to_datetime(start_date)] - if end_date: - filtered_df = filtered_df[filtered_df['date'] <= pd.to_datetime(end_date)] - - # 按时间筛选 - if 'datetime' in filtered_df.columns: - if start_date: - filtered_df = filtered_df[filtered_df['datetime'] >= pd.to_datetime(start_date)] - if end_date: - filtered_df = filtered_df[filtered_df['datetime'] <= pd.to_datetime(end_date)] - - # 按股票代码筛选 - if codes and 'code' in filtered_df.columns: - filtered_df = filtered_df[filtered_df['code'].isin(codes)] - - return filtered_df - - @staticmethod - def filter_data_min( - df: pd.DataFrame, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - codes: Optional[List[str]] = None - ) -> pd.DataFrame: - """根据条件筛选分钟线数据 - - Args: - df: 原始数据DataFrame - start_date: 开始日期,格式为'YYYY-MM-DD' - end_date: 结束日期,格式为'YYYY-MM-DD' - codes: 股票代码列表 - - Returns: - DataFrame: 筛选后的数据 - """ - if df.empty: - return df - - filtered_df = df.copy() - - # 按日期筛选 - if 'date' in filtered_df.columns: - if start_date: - filtered_df = filtered_df[filtered_df['datetime'] >= pd.to_datetime(start_date)] - if end_date: - filtered_df = filtered_df[filtered_df['datetime'] <= pd.to_datetime(end_date)] - - # 按股票代码筛选 - if codes and 'code' in filtered_df.columns: - filtered_df = filtered_df[filtered_df['code'].isin(codes)] - - return filtered_df diff --git a/src/reader.py b/src/reader.py deleted file mode 100644 index f81cd70..0000000 --- a/src/reader.py +++ /dev/null @@ -1,409 +0,0 @@ -"""数据读取模块 - -负责从通达信本地数据文件中读取股票数据,支持: -- 日线数据 -- 分钟线数据 -- 股票列表 -""" - -import os -import re -from pathlib import Path -from typing import Optional, List - -import pandas as pd -from pytdx.reader import TdxDailyBarReader, TdxMinBarReader, TdxLCMinBarReader -from pytdx.reader import BlockReader -from tqdm import tqdm - -from .config import config -from .logger import logger -from .processor import DataProcessor - -class TdxDataReader: - """通达信数据读取类""" - - def __init__(self, tdx_path: Optional[str] = None) -> None: - """初始化数据读取器 - - Args: - tdx_path: 通达信安装目录,如果为None则使用配置中的路径 - """ - self.tdx_path = tdx_path or config.tdx_path - if not self.tdx_path: - raise ValueError("通达信数据路径未设置,请在.env文件中设置TDX_PATH或在初始化时提供") - - self.tdx_path = Path(self.tdx_path) - if not self.tdx_path.exists(): - raise FileNotFoundError(f"通达信数据路径不存在: {self.tdx_path}") - - # 初始化读取器 - self.daily_reader = TdxDailyBarReader() - self.min_reader = TdxMinBarReader() - self.lc_min_reader = TdxLCMinBarReader() - self.block_reader = BlockReader() - - def get_stock_list(self) -> pd.DataFrame: - """获取股票列表 - - Returns: - DataFrame: 包含A股股票代码和名称的DataFrame(不包含B股、基金、等) - """ - # 尝试查找通达信股票数据文件 - sz_path = self.tdx_path / 'vipdoc' / 'sz' / 'lday' - sh_path = self.tdx_path / 'vipdoc' / 'sh' / 'lday' - - if not 
(sz_path.exists() or sh_path.exists()): - raise FileNotFoundError(f"无法找到股票列表文件或股票数据目录") - - # 从目录中获取股票代码 - stocks = [] - - # 处理深圳股票 - if sz_path.exists(): - for file in sz_path.glob('*.day'): - code = file.stem - name = f"深A{code}" - - pure_code = code[-6:] - code_str = str(pure_code).zfill(6) # 补齐为6位字符串 - # 匹配上证A股+深证A股 - if re.match(r'^(000|001|002|300)\d{3}$', code_str): - stocks.append({'code': code, 'name': name}) - - # 处理上海股票 - if sh_path.exists(): - for file in sh_path.glob('*.day'): - code = file.stem - name = f"上A{code}" - - pure_code = code[-6:] - code_str = str(pure_code).zfill(6) # 补齐为6位字符串 - # 匹配上证A股+深证A股 - if re.match(r'^(60|688)\d{4}$', code_str): - stocks.append({'code': code, 'name': name}) - - if not stocks: - raise FileNotFoundError(f"未找到任何股票数据文件") - - return pd.DataFrame(stocks, columns=['code', 'name']) - - def read_daily_data(self, market: int, code: str) -> pd.DataFrame: - """读取日线数据 - - Args: - market: 市场代码,0表示深圳,1表示上海 - code: 股票代码 - - Returns: - DataFrame: 日线数据 - """ - # 构建日线数据文件路径 - market_folder = 'sz' if market == 0 else 'sh' - data_path = self.tdx_path / 'vipdoc' / market_folder / 'lday' - - if (len(code)>6): - code = code[-6:] - file_path = data_path / f"{market_folder}{code}.day" - - if not file_path.exists(): - raise FileNotFoundError(f"日线数据文件不存在: {file_path}") - - # 读取数据 - data = self.daily_reader.get_df(str(file_path)) - data['code'] = code - data['market'] = market - return data - - def read_min_data(self, market: int, code: str) -> List[pd.DataFrame]: - """读取5分钟线数据并生成15分钟、30分钟和60分数据 - - Args: - market: 市场代码,0表示深圳,1表示上海 - code: 股票代码 - - Returns: - list: [15分钟数据, 30分钟数据, 60分钟数据] - """ - # 构建分钟线数据文件路径 - market_folder = 'sz' if market == 0 else 'sh' - freq_folder = 'fzline' - data_path = self.tdx_path / 'vipdoc' / market_folder / freq_folder - file_path = data_path /f"{market_folder}{code}.lc5" # 只读取5分钟数据 - - if not file_path.exists(): - raise FileNotFoundError(f"5分钟线数据文件不存在: {file_path}") - - # 读取5分钟数据 - logger.info(f"正在读取 {code} 的5分钟线数据...") - with tqdm(total=1, desc="读取进度") as pbar: - data = self.lc_min_reader.get_df(str(file_path)) - data['code'] = code - data['market'] = market - pbar.update(1) - - # 确保datetime列存在并且是日期时间类型 - if 'datetime' not in data.columns: - # 如果没有datetime列,尝试从index创建 - if isinstance(data.index, pd.DatetimeIndex): - data['datetime'] = data.index - else: - raise ValueError("数据中缺少datetime列且索引不是日期时间类型") - elif not pd.api.types.is_datetime64_any_dtype(data['datetime']): - data['datetime'] = pd.to_datetime(data['datetime']) - - # 设置datetime为索引,用于后续resample操作 - data.set_index('datetime', inplace=True) - - # 记得定期获取最新的数据,同步进TDX - logger.debug(f"数据时间范围: {data.index[0]} ~ {data.index[-1]}") - - # 重采样生成多周期数据 - data_15min = DataProcessor.resample_ohlcv(data, '15min') - data_30min = DataProcessor.resample_ohlcv(data, '30min') - data_60min = DataProcessor.resample_ohlcv(data, '60min') - - data.reset_index(inplace=True) - - return [data_15min, data_30min, data_60min] - - def read_5min_data(self, market: int, code: str) -> pd.DataFrame: - """读取5分钟线数据 - - Args: - market: 市场代码,0表示深圳,1表示上海 - code: 股票代码 - - Returns: - DataFrame: 5分钟数据 - """ - # 构建分钟线数据文件路径 - market_folder = 'sz' if market == 0 else 'sh' - freq_folder = 'fzline' - data_path = self.tdx_path / 'vipdoc' / market_folder / freq_folder - - if (len(code)>6): - code = code[-6:] - file_path = data_path /f"{market_folder}{code}.lc5" - - if not file_path.exists(): - raise FileNotFoundError(f"5分钟线数据文件不存在: {file_path}") - - # 读取5分钟数据 - logger.info(f"正在读取 {code} 的5分钟线数据...") - with tqdm(total=1, 
desc="读取进度") as pbar: - data = self.lc_min_reader.get_df(str(file_path)) - data['code'] = code - data['market'] = market - pbar.update(1) - - # 确保datetime列存在并且是日期时间类型 - if 'datetime' not in data.columns: - # 如果没有datetime列,尝试从index创建 - if isinstance(data.index, pd.DatetimeIndex): - data['datetime'] = data.index - else: - raise ValueError("数据中缺少datetime列且索引不是日期时间类型") - elif not pd.api.types.is_datetime64_any_dtype(data['datetime']): - data['datetime'] = pd.to_datetime(data['datetime']) - - # 设置datetime为索引,用于后续resample操作 - data.set_index('datetime', inplace=True) - - # 记得定期获取最新的数据,同步进TDX - logger.debug(f"数据时间范围: {data.index[0]} ~ {data.index[-1]}") - - # 重置索引,使datetime成为列 - data.reset_index(inplace=True) - - return data - - def read_all_daily_data(self) -> pd.DataFrame: - """读取所有股票的日线数据 - - Returns: - DataFrame: 所有股票的日线数据 - """ - # 获取股票列表 - stocks = self.get_stock_list() - logger.info(f"获取到 {len(stocks)} 只股票,开始读取日线数据...") - - all_data = [] - iterator = tqdm(stocks.iterrows(), total=len(stocks)) if config.use_tqdm else stocks.iterrows() - - for _, stock in iterator: - code = stock['code'] - # 判断市场 - if code.startswith('sh'): - market = 1 # 上海 - else: - market = 0 # 深圳 - - try: - data = self.read_daily_data(market, code) - # 确保datetime是列而不是索引 - if isinstance(data.index, pd.DatetimeIndex) or data.index.name == 'datetime': - data = data.reset_index() - all_data.append(data) - except FileNotFoundError: - continue - except Exception as e: - logger.error(f"读取 {code} 日线数据时出错: {e}") - continue - - if not all_data: - return pd.DataFrame() - - # 合并数据时保留datetime列 - result_df = pd.concat(all_data, ignore_index=True) - - # 确保datetime列存在并且是正确的日期时间格式 - if 'datetime' in result_df.columns and not pd.api.types.is_datetime64_any_dtype(result_df['datetime']): - try: - result_df['datetime'] = pd.to_datetime(result_df['datetime']) - except Exception as e: - logger.warning(f"转换datetime列时出错: {e}") - - return result_df - - # 板块关系暂时未实现,由于板块文件未找到 - def get_block_stock_relation(self) -> pd.DataFrame: - """获取通达信板块与股票的对应关系 - - Returns: - DataFrame: 包含板块代码、板块名称和对应股票代码的DataFrame - """ - # 板块文件目录 - block_path = self.tdx_path / 'T0002' / 'hq_cache' - - if not block_path.exists(): - raise FileNotFoundError(f"板块文件目录不存在: {block_path}") - - blocks = self.block_reader.get_df(self.tdx_path / 'BlockMap' / 'TdxZLSelStock.dat') - logger.debug(f"读取到板块数据: {len(blocks)} 条记录") - - # # 板块文件列表 - # block_files = list(block_path.glob('block*.dat')) - # block_files.extend(list(block_path.glob('block*.blk'))) - - # if not block_files: - # raise FileNotFoundError(f"未找到板块文件: {block_path}") - - # # 存储板块与股票的对应关系 - # block_stock_relations = [] - - # # 遍历板块文件 - # for block_file in block_files: - # block_type = block_file.stem - - # try: - # # 使用BlockReader读取板块文件 - # block_data = self.block_reader.get_df(str(block_file)) - - # if block_data.empty: - # continue - - # # 处理板块数据 - # for _, row in block_data.iterrows(): - # block_stock_relations.append({ - # 'block_code': row.get('block_code', block_type), - # 'block_name': row.get('block_name', block_type), - # 'code': row.get('code', ''), - # 'name': row.get('name', '') - # }) - # except Exception as e: - # print(f"读取板块文件{block_file}时出错: {e}") - - # # 尝试直接读取文件内容 - # try: - # with open(block_file, 'rb') as f: - # content = f.read() - - # # 解析板块文件内容 - # block_name = block_file.stem - - # # 尝试从文件名或内容中提取板块名称 - # if block_file.suffix.lower() == '.dat': - # # .dat文件通常是二进制格式 - # try: - # # 尝试从文件头部提取板块名称 - # if len(content) > 50: - # # 通达信板块文件格式可能不同,这里尝试几种常见格式 - # try: - # name_bytes = 
content[0:50].split(b'\x00')[0] - # block_name = name_bytes.decode('gbk', errors='ignore').strip() - # except: - # pass - # except: - # pass - - # # 提取股票代码 - # codes = [] - - # # 解析文件内容提取股票代码 - # if block_file.suffix.lower() == '.blk': - # # .blk文件通常是文本格式 - # try: - # text_content = content.decode('gbk', errors='ignore') - # for line in text_content.split('\n'): - # line = line.strip() - # if line and not line.startswith('#'): - # # 通常格式为 1 000001 或 0 000001 - # parts = line.split() - # if len(parts) >= 2: - # market = int(parts[0]) - # code = parts[1] - # market_prefix = 'sh' if market == 1 else 'sz' - # codes.append(f"{market_prefix}{code}") - # else: - # # 可能只有代码,没有市场标识 - # code = line - # # 根据代码前缀判断市场 - # if code.startswith(('6', '5', '9')): - # codes.append(f"sh{code}") - # else: - # codes.append(f"sz{code}") - # except: - # pass - # elif block_file.suffix.lower() == '.dat': - # # .dat文件通常是二进制格式 - # try: - # # 跳过文件头部,直接读取股票代码部分 - # offset = 384 # 通常板块文件头部大小 - # while offset < len(content): - # if offset + 7 <= len(content): - # market = content[offset] - # code = content[offset+1:offset+7].decode('ascii', errors='ignore') - # if code.isdigit(): - # market_prefix = 'sh' if market == 1 else 'sz' - # codes.append(f"{market_prefix}{code}") - # offset += 7 - # except: - # pass - - # # 添加到结果列表 - # for code in codes: - # block_stock_relations.append({ - # 'block_code': block_type, - # 'block_name': block_name, - # 'code': code, - # 'name': '' - # }) - # except Exception as e: - # print(f"直接解析板块文件{block_file}时出错: {e}") - - # # 转换为DataFrame - # if not block_stock_relations: - # return pd.DataFrame() - - # df = pd.DataFrame(block_stock_relations) - - # # 尝试补充股票名称 - # try: - # stocks = self.get_stock_list() - # stock_dict = dict(zip(stocks['code'], stocks['name'])) - # df['name'] = df['code'].map(stock_dict) - # except Exception as e: - # print(f"补充股票名称时出错: {e}") - - return df diff --git a/src/storage.py b/src/storage.py deleted file mode 100644 index 2b29aed..0000000 --- a/src/storage.py +++ /dev/null @@ -1,570 +0,0 @@ -"""数据存储模块 - -负责将处理后的数据保存到不同的存储介质,支持: -- CSV文件存储 -- 数据库存储(PostgreSQL、MySQL、SQLite) -""" - -import os -from datetime import datetime as dt -from pathlib import Path -from typing import Optional, Tuple - -import pandas as pd -from sqlalchemy import create_engine, Column, Integer, Float, String, DateTime, MetaData, Table, text -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker -from tqdm import tqdm - -from .config import config -from .logger import logger - -Base = declarative_base() - -class BlockStockRelation(Base): - """板块股票关系表模型""" - __tablename__ = 'block_stock_relation' - - id = Column(Integer, primary_key=True) - block_code = Column(String(20), index=True) # 板块代码 - block_name = Column(String(50)) # 板块名称 - code = Column(String(10), index=True) # 股票代码 - name = Column(String(50)) # 股票名称 - -class DailyData(Base): - """日线数据表模型""" - __tablename__ = 'daily_data' - - id = Column(Integer, primary_key=True) - code = Column(String(10), index=True) - market = Column(Integer) - datetime = Column(DateTime, index=True) - date = Column(DateTime, index=True) - open = Column(Float) - high = Column(Float) - low = Column(Float) - close = Column(Float) - volume = Column(Float) - amount = Column(Float) - ma13 = Column(Float) - ma21 = Column(Float) - ma34 = Column(Float) - ma55 = Column(Float) - ma89 = Column(Float) - ma144 = Column(Float) - ma233 = Column(Float) - ma5 = Column(Float) - ma10 = Column(Float) - ma60 = Column(Float) - ma250 = 
Column(Float) - -class Minute5Data(Base): - """5分钟线数据表模型""" - __tablename__ = 'minute5_data' - - id = Column(Integer, primary_key=True, autoincrement=True) - code = Column(String(10), nullable=False, index=True) - market = Column(Integer, nullable=False) - datetime = Column(DateTime, nullable=False, index=True) - date = Column(DateTime, nullable=False, index=True) - open = Column(Float, nullable=False) - high = Column(Float, nullable=False) - low = Column(Float, nullable=False) - close = Column(Float, nullable=False) - volume = Column(Float, nullable=False) - amount = Column(Float, nullable=False) - ma13 = Column(Float) - ma21 = Column(Float) - ma34 = Column(Float) - ma55 = Column(Float) - ma89 = Column(Float) - ma144 = Column(Float) - ma233 = Column(Float) - ma5 = Column(Float) - ma10 = Column(Float) - ma60 = Column(Float) - ma250 = Column(Float) - -class Minute15Data(Base): - """15分钟线数据表模型""" - __tablename__ = 'minute15_data' - - id = Column(Integer, primary_key=True, autoincrement=True) - code = Column(String(10), nullable=False, index=True) - market = Column(Integer, nullable=False) - datetime = Column(DateTime, nullable=False, index=True) - date = Column(DateTime, nullable=False, index=True) - open = Column(Float, nullable=False) - high = Column(Float, nullable=False) - low = Column(Float, nullable=False) - close = Column(Float, nullable=False) - volume = Column(Float, nullable=False) - amount = Column(Float, nullable=False) - # 添加技术指标列 - ma13 = Column(Float) - ma21 = Column(Float) - ma34 = Column(Float) - ma55 = Column(Float) - ma89 = Column(Float) - ma144 = Column(Float) - ma233 = Column(Float) - ma5 = Column(Float) - ma10 = Column(Float) - ma60 = Column(Float) - ma250 = Column(Float) -class Minute30Data(Base): - """30分钟线数据表模型""" - __tablename__ = 'minute30_data' - - id = Column(Integer, primary_key=True, autoincrement=True) - code = Column(String(10), nullable=False, index=True) - market = Column(Integer, nullable=False) - datetime = Column(DateTime, nullable=False, index=True) - date = Column(DateTime, nullable=False, index=True) - open = Column(Float, nullable=False) - high = Column(Float, nullable=False) - low = Column(Float, nullable=False) - close = Column(Float, nullable=False) - volume = Column(Float, nullable=False) - amount = Column(Float, nullable=False) - # 添加技术指标列 - ma13 = Column(Float) - ma21 = Column(Float) - ma34 = Column(Float) - ma55 = Column(Float) - ma89 = Column(Float) - ma144 = Column(Float) - ma233 = Column(Float) - ma5 = Column(Float) - ma10 = Column(Float) - ma60 = Column(Float) - ma250 = Column(Float) - -class Minute60Data(Base): - """60分钟线数据表模型""" - __tablename__ = 'minute60_data' - - id = Column(Integer, primary_key=True, autoincrement=True) - code = Column(String(10), nullable=False, index=True) - market = Column(Integer, nullable=False) - datetime = Column(DateTime, nullable=False, index=True) - date = Column(DateTime, nullable=False, index=True) - open = Column(Float, nullable=False) - high = Column(Float, nullable=False) - low = Column(Float, nullable=False) - close = Column(Float, nullable=False) - volume = Column(Float, nullable=False) - amount = Column(Float, nullable=False) - # 添加技术指标列 - ma13 = Column(Float) - ma21 = Column(Float) - ma34 = Column(Float) - ma55 = Column(Float) - ma89 = Column(Float) - ma144 = Column(Float) - ma233 = Column(Float) - ma5 = Column(Float) - ma10 = Column(Float) - ma60 = Column(Float) - ma250 = Column(Float) - -class StockInfo(Base): - """股票信息表模型""" - __tablename__ = 'stock_info' - - id = Column(Integer, 
primary_key=True) - code = Column(String(10), unique=True, index=True) - name = Column(String(50)) - market = Column(Integer) - -# 允许写入的表名白名单 -_VALID_TABLES = frozenset({ - 'daily_data', 'minute5_data', 'minute15_data', 'minute30_data', 'minute60_data', - 'stock_info', 'block_stock_relation', -}) - - -class DataStorage: - """数据存储类""" - - def __init__( - self, - db_url: Optional[str] = None, - csv_path: Optional[str] = None - ) -> None: - """初始化数据存储 - - Args: - db_url: 数据库连接URL,如果为None则使用配置中的URL - csv_path: CSV文件保存路径,如果为None则使用配置中的路径 - """ - self.db_url = db_url or config.database_url - self.csv_path = csv_path or config.csv_output_path - - # 确保CSV输出目录存在 - if self.csv_path: - os.makedirs(self.csv_path, exist_ok=True) - - # 初始化数据库连接 - if self.db_url: - self.engine = create_engine(self.db_url) - Base.metadata.create_all(self.engine) - self.Session = sessionmaker(bind=self.engine) - - def save_to_csv(self, df: pd.DataFrame, filename: str) -> Optional[str]: - """保存数据到CSV文件 - - Args: - df: 要保存的DataFrame - filename: 文件名(不包含路径和扩展名) - - Returns: - str: 保存的文件路径,如果没有数据则返回None - """ - if df.empty: - logger.warning(f"没有数据可保存到 {filename}.csv") - return None - - file_path = Path(self.csv_path) / f"{filename}.csv" - df.to_csv(file_path, index=False, encoding='utf-8') - logger.info(f"数据已保存到: {file_path}") - return str(file_path) - - def get_latest_datetime( - self, - table_name: str, - date_column: str = 'datetime' - ) -> Optional[dt]: - """获取表中最新的日期/时间 - - Args: - table_name: 表名 - date_column: 日期列名,默认为 'datetime',日线数据应使用 'date' - - Returns: - Optional[datetime]: 最新的日期/时间,如果表为空则返回 None - """ - try: - with self.engine.connect() as conn: - result = conn.execute( - text(f"SELECT MAX({date_column}) FROM {table_name}") - ) - row = result.fetchone() - if row and row[0]: - return row[0] - return None - except Exception as e: - logger.debug(f"获取表 {table_name} 最新日期时出错: {e}") - return None - - def get_latest_datetime_by_code( - self, - table_name: str, - code: str - ) -> Optional[dt]: - """获取指定股票的最新 datetime - - Args: - table_name: 表名 - code: 股票代码 - - Returns: - Optional[datetime]: 最新的 datetime,如果没有数据则返回 None - """ - try: - with self.engine.connect() as conn: - result = conn.execute( - text(f"SELECT MAX(datetime) FROM {table_name} WHERE code = :code"), - {"code": code} - ) - row = result.fetchone() - if row and row[0]: - return row[0] - return None - except Exception as e: - logger.debug(f"获取表 {table_name} 股票 {code} 最新日期时出错: {e}") - return None - - def save_incremental( - self, - df: pd.DataFrame, - table_name: str, - conflict_columns: Tuple[str, ...] 
= ('code', 'datetime'), - batch_size: int = 10000 - ) -> int: - """增量保存数据,跳过重复记录 - - 使用 ON CONFLICT DO NOTHING 策略,需要先在数据库中添加唯一约束 - - Args: - df: 要保存的 DataFrame - table_name: 表名 - conflict_columns: 唯一约束列,默认为 ('code', 'datetime') - batch_size: 批处理大小 - - Returns: - int: 实际插入的行数 - """ - if table_name not in _VALID_TABLES: - raise ValueError(f"不允许写入的表名: {table_name}") - - if df.empty: - logger.warning(f"没有数据可保存到表 {table_name}") - return 0 - - total_rows = len(df) - - # 确保 datetime 不是索引而是列 - df_to_save = df.copy() - if df_to_save.index.name == 'datetime' or isinstance(df_to_save.index, pd.DatetimeIndex): - df_to_save = df_to_save.reset_index() - - # 获取列名 - columns = list(df_to_save.columns) - columns_str = ', '.join(columns) - db_type = config.db_type - - try: - if db_type == 'postgresql': - self._save_incremental_pg( - df_to_save, columns, columns_str, - table_name, conflict_columns, batch_size - ) - else: - # MySQL / SQLite: 走 SQLAlchemy executemany - placeholders = ', '.join([f':{col}' for col in columns]) - if db_type == 'mysql': - sql = text(f"INSERT IGNORE INTO {table_name} ({columns_str}) VALUES ({placeholders})") - elif db_type == 'sqlite': - sql = text(f"INSERT OR IGNORE INTO {table_name} ({columns_str}) VALUES ({placeholders})") - else: - raise ValueError(f"不支持的数据库类型: {db_type}") - - with self.engine.connect() as conn: - for i in range(0, total_rows, batch_size): - batch_df = df_to_save.iloc[i:i + batch_size] - conn.execute(sql, batch_df.to_dict('records')) - conn.commit() - - logger.info(f"增量保存完成: 共处理 {total_rows} 条到表 {table_name}(重复数据已跳过)") - return total_rows - - except Exception as e: - logger.error(f"增量保存数据到表 {table_name} 时出错: {e}") - return 0 - - def _save_incremental_pg( - self, - df: pd.DataFrame, - columns: list, - columns_str: str, - table_name: str, - conflict_columns: Tuple[str, ...], - batch_size: int, - ) -> None: - """PostgreSQL 专用:使用 execute_values 真正批量插入 - - 一次网络往返插入整批数据,比 executemany 快 10-100x。 - """ - from psycopg2.extras import execute_values - - conflict_str = ', '.join(conflict_columns) - sql = f"INSERT INTO {table_name} ({columns_str}) VALUES %s ON CONFLICT ({conflict_str}) DO NOTHING" - - # DataFrame → list of tuples,NaN/NaT → None(psycopg2 需要 None 表示 NULL) - df_clean = df.astype(object).where(df.notna(), None) - values = list(df_clean.itertuples(index=False, name=None)) - - raw_conn = self.engine.raw_connection() - try: - cursor = raw_conn.cursor() - execute_values(cursor, sql, values, page_size=batch_size) - raw_conn.commit() - finally: - raw_conn.close() - - def save_to_database( - self, - df: pd.DataFrame, - table_name: str, - batch_size: int = 10000 - ) -> bool: - """保存数据到数据库 - - Args: - df: 要保存的DataFrame - table_name: 表名 - batch_size: 批处理大小,默认10000条记录 - - Returns: - bool: 是否保存成功 - """ - if df.empty: - logger.warning(f"没有数据可保存到表 {table_name}") - return False - - try: - # 获取数据总量 - total_rows = len(df) - - logger.debug(f"开始保存数据到数据库表: {table_name}, 共 {total_rows} 条记录") - - # 如果数据量小于批处理大小,直接保存 - if total_rows <= batch_size: - df.to_sql(table_name, self.engine, if_exists='append', index=False) - logger.info(f"数据已保存到数据库表: {table_name}") - return True - - # 数据量大,分批处理 - logger.info(f"数据量较大({total_rows}条),开始分批保存到数据库表: {table_name}") - - # 计算批次数 - num_batches = (total_rows + batch_size - 1) // batch_size - - # 创建进度条 - iterator = tqdm(range(num_batches), desc="保存到数据库") if config.use_tqdm else range(num_batches) - - # 确保datetime不是索引而是列 - df_to_save = df.copy() - if df_to_save.index.name == 'datetime' or isinstance(df_to_save.index, pd.DatetimeIndex): - df_to_save = 
df_to_save.reset_index()
-
-            # 分批保存
-            for i in iterator:
-                start_idx = i * batch_size
-                end_idx = min((i + 1) * batch_size, total_rows)
-                batch_df = df_to_save.iloc[start_idx:end_idx]
-
-                # 保存当前批次
-                # 使用正确的方法检查表是否存在
-                from sqlalchemy import inspect
-                inspector = inspect(self.engine)
-                if_exists = 'append' if i > 0 or inspector.has_table(table_name) else 'replace'
-                batch_df.to_sql(table_name, self.engine, if_exists=if_exists, index=False)
-
-                if not config.use_tqdm:
-                    logger.info(f"已保存 {end_idx}/{total_rows} 条记录到数据库表 {table_name}")
-
-            logger.info(f"所有数据已成功保存到数据库表: {table_name}")
-            return True
-        except Exception as e:
-            logger.error(f"保存数据到数据库表 {table_name} 时出错: {e}")
-            return False
-
-    def save_daily_data(
-        self,
-        df: pd.DataFrame,
-        to_csv: bool = True,
-        to_db: bool = True,
-        batch_size: int = 10000
-    ) -> Tuple[Optional[str], bool]:
-        """保存日线数据
-
-        Args:
-            df: 日线数据DataFrame
-            to_csv: 是否保存到CSV
-            to_db: 是否保存到数据库
-            batch_size: 批处理大小,默认10000条记录
-
-        Returns:
-            tuple: (csv_path, db_success)
-        """
-        csv_path = None
-        db_success = False
-
-        if to_csv:
-            csv_path = self.save_to_csv(df, 'daily_data')
-
-        if to_db:
-            db_success = self.save_to_database(df, 'daily_data', batch_size=batch_size)
-
-        return csv_path, db_success
-
-    def save_minute_data(
-        self,
-        df: pd.DataFrame,
-        freq: int = 1,
-        to_csv: bool = True,
-        to_db: bool = True,
-        batch_size: int = 10000
-    ) -> Tuple[Optional[str], bool]:
-        """保存分钟线数据
-
-        Args:
-            df: 分钟线数据DataFrame
-            freq: 分钟频率
-            to_csv: 是否保存到CSV
-            to_db: 是否保存到数据库
-            batch_size: 批处理大小,默认10000条记录
-
-        Returns:
-            tuple: (csv_path, db_success)
-        """
-        csv_path = None
-        db_success = False
-
-        if to_csv:
-            csv_path = self.save_to_csv(df, f'minute{freq}_data')
-
-        if to_db:
-            db_success = self.save_to_database(df, f'minute{freq}_data', batch_size=batch_size)
-
-        return csv_path, db_success
-
-    def save_stock_info(
-        self,
-        df: pd.DataFrame,
-        to_csv: bool = True,
-        to_db: bool = True,
-        batch_size: int = 10000
-    ) -> Tuple[Optional[str], bool]:
-        """保存股票信息
-
-        Args:
-            df: 股票信息DataFrame
-            to_csv: 是否保存到CSV
-            to_db: 是否保存到数据库
-            batch_size: 批处理大小,默认10000条记录
-
-        Returns:
-            tuple: (csv_path, db_success)
-        """
-        csv_path = None
-        db_success = False
-
-        if to_csv:
-            csv_path = self.save_to_csv(df, 'stock_info')
-
-        if to_db:
-            db_success = self.save_to_database(df, 'stock_info', batch_size=batch_size)
-
-        return csv_path, db_success
-
-    def save_block_relation(
-        self,
-        df: pd.DataFrame,
-        to_csv: bool = True,
-        to_db: bool = True,
-        batch_size: int = 10000
-    ) -> Tuple[Optional[str], bool]:
-        """保存板块与股票的对应关系
-
-        Args:
-            df: 板块与股票对应关系DataFrame
-            to_csv: 是否保存到CSV
-            to_db: 是否保存到数据库
-            batch_size: 批处理大小,默认10000条记录
-
-        Returns:
-            tuple: (csv_path, db_success)
-        """
-        csv_path = None
-        db_success = False
-
-        if to_csv:
-            csv_path = self.save_to_csv(df, 'block_stock_relation')
-
-        if to_db:
-            db_success = self.save_to_database(df, 'block_stock_relation', batch_size=batch_size)
-
-        return csv_path, db_success
diff --git a/src/tdx2db/__init__.py b/src/tdx2db/__init__.py
new file mode 100644
index 0000000..bf4a70b
--- /dev/null
+++ b/src/tdx2db/__init__.py
@@ -0,0 +1,115 @@
+"""tdx2db: 从通达信本地文件同步 A 股日线数据到数据库。
+
+基本用法::
+
+    from tdx2db import TdxDailySync
+
+    sync = TdxDailySync(tdx_path="/opt/tdx", db_url="sqlite:///data.db")
+    sync.sync_all(adj_type='forward')
+    sync.sync_stock('sz000001', start_date=20240101)
+    df = sync.get_daily('sz000001', start_date=20240101)
+"""
+
+from typing import Optional
+import pandas as pd
+
+from .config import Config, config
+from .reader import TdxDataReader
+from .processor import DataProcessor
+from .storage import DataStorage
+from .logger import logger
+
+__version__ = "0.2.0"
+__all__ = ['TdxDailySync', 'TdxDataReader', 'DataProcessor', 'DataStorage', 'Config']
+
+
+class TdxDailySync:
+    """高层封装,供外部项目调用。"""
+
+    def __init__(
+        self,
+        tdx_path: Optional[str] = None,
+        db_url: Optional[str] = None,
+    ) -> None:
+        if tdx_path:
+            config.tdx_path = tdx_path
+        self.reader = TdxDataReader(tdx_path)
+        self.processor = DataProcessor()
+        self.storage = DataStorage(db_url)
+
+    def sync_all(
+        self,
+        adj_type: str = 'forward',
+        incremental: bool = True,
+        start_date: Optional[int] = None,
+        end_date: Optional[int] = None,
+    ) -> dict:
+        """同步所有 A 股日线数据。
+
+        Returns:
+            {'total': N, 'success': N, 'failed': N}
+        """
+        from .cli import sync_all_daily
+        gbbq = self.reader.read_gbbq()
+        return sync_all_daily(
+            self.reader, self.processor, self.storage, gbbq,
+            adj_type=adj_type, incremental=incremental,
+            start_date=start_date, end_date=end_date,
+        )
+
+    def sync_stock(
+        self,
+        code: str,
+        adj_type: str = 'forward',
+        start_date: Optional[int] = None,
+        end_date: Optional[int] = None,
+    ) -> int:
+        """同步单只股票日线数据,返回写入行数。"""
+        # 市场识别与 cli.py 保持一致:支持 sz/sh/bj 前缀及裸 6 位代码
+        pure_code = code[-6:] if len(code) > 6 else code
+        if code.startswith('sh') or pure_code.startswith('6'):
+            market = 1
+        elif code.startswith('bj') or pure_code.startswith('8') or pure_code.startswith('92'):
+            market = 2
+        else:
+            market = 0
+        gbbq = self.reader.read_gbbq()
+        data = self.reader.read_daily_data(market, code)
+        if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('date', 'datetime'):
+            data = data.reset_index()
+        processed = self.processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type)
+        processed = self.processor.filter_data(processed, start_date=start_date, end_date=end_date)
+        if processed.empty:
+            return 0
+        return self.storage.save_incremental(processed, 'daily_data', conflict_columns=('stock_code', 'date'))
+
+    def sync_stock_list(self) -> int:
+        """同步股票列表(含名称)到 stock_info 表,返回股票数量。"""
+        # save_stock_info 需要 DataFrame,这里复用 cli.sync_stock_list 的组装逻辑
+        from .cli import sync_stock_list
+        sync_stock_list(self.reader, self.storage)
+        return len(self.reader.get_stock_list())
+
+    def get_daily(
+        self,
+        code: str,
+        start_date: Optional[int] = None,
+        end_date: Optional[int] = None,
+    ) -> pd.DataFrame:
+        """从数据库查询日线数据,date 列为 YYYYMMDD 字符串。"""
+        from sqlalchemy import text
+        if '.'
in code: + db_code = code.upper() + else: + pure = code[-6:] if len(code) > 6 else code + if pure.startswith('6'): + suffix = '.SH' + elif pure.startswith('8') or pure.startswith('92'): + suffix = '.BJ' + else: + suffix = '.SZ' + db_code = pure + suffix + conditions = ["stock_code = :code"] + params: dict = {"code": db_code} + if start_date: + conditions.append("date >= :start_date") + params["start_date"] = str(start_date) + if end_date: + conditions.append("date <= :end_date") + params["end_date"] = str(end_date) + where = " AND ".join(conditions) + sql = text(f"SELECT * FROM daily_data WHERE {where} ORDER BY date") + with self.storage.engine.connect() as conn: + return pd.read_sql(sql, conn, params=params) diff --git a/src/tdx2db/cli.py b/src/tdx2db/cli.py new file mode 100644 index 0000000..0bc0b00 --- /dev/null +++ b/src/tdx2db/cli.py @@ -0,0 +1,377 @@ +import argparse +import sys +from argparse import Namespace +from typing import Optional + +import pandas as pd +from tqdm import tqdm + +from .reader import TdxDataReader +from .processor import DataProcessor +from .storage import DataStorage +from .config import config +from .logger import logger +from .downloader import download_and_extract, DEFAULT_DOWNLOAD_URL + + +def _has_ex_rights_after(code: str, gbbq: pd.DataFrame, last_date: int) -> bool: + """检查该股票在 last_date 之后是否有除权事件(category==1)。""" + if gbbq is None or gbbq.empty: + return False + if code.startswith('6'): + prefix = 'sh' + elif code.startswith('8') or code.startswith('92'): + prefix = 'bj' + else: + prefix = 'sz' + full_code = prefix + code.zfill(6) + events = gbbq[ + (gbbq['full_code'] == full_code) & + (gbbq['category'] == 1) & + (gbbq['datetime'] > int(last_date)) + ] + return not events.empty + + +def sync_stock_list(reader: TdxDataReader, storage: DataStorage) -> bool: + """同步股票列表及名称,返回是否成功。从本地 TDX .tnf 文件读取中文名,无需联网。""" + try: + # name_map: {'SZ': {6位code: 名}, 'SH': {...}, 'BJ': {...}} + # 三个市场代码空间有重叠,必须按市场分开查找 + name_map = reader.read_stock_names() + local_codes = reader.get_stock_list() # ['000001.SZ', ...] 
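+        # 名称查找示意:'000001.SZ' → name_map['SZ'].get('000001') → '平安银行';
+        # 同一代码在 name_map['SH'] 中则是上证指数(见 reader.read_stock_names 说明),
+        # 因此必须先按后缀选市场子字典,再按 6 位纯代码取名;取不到时退回纯代码兜底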
+ df = pd.DataFrame([ + { + 'stock_code': c, + 'stock_name': name_map.get(c.split('.')[1], {}).get( + c.split('.')[0], c.split('.')[0] + ) + } + for c in local_codes + ]) + logger.info(f"获取到 {len(df)} 只股票") + storage.save_stock_info(df) + return True + except Exception as e: + logger.error(f"同步股票列表出错: {e}") + return False + + +def sync_all_daily( + reader: TdxDataReader, + processor: DataProcessor, + storage: DataStorage, + gbbq: pd.DataFrame, + adj_type: str = 'forward', + incremental: bool = True, + start_date: Optional[int] = None, + end_date: Optional[int] = None, + float_cap_map: dict = None, +) -> dict: + """逐股票流式同步日线数据,返回统计信息。SMB 模式下使用批量并发下载。""" + stocks = reader.get_stock_list() + logger.info(f"共 {len(stocks)} 只股票") + + latest_dates = storage.get_all_latest_dates() if incremental else {} + stats = {'total': len(stocks), 'success': 0, 'failed': 0} + + def _process_one(db_code: str, data: pd.DataFrame) -> None: + """处理单只股票的数据并写库。""" + pure_code = db_code.split('.')[0] + last_date = latest_dates.get(db_code) + + if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('date', 'datetime'): + data = data.reset_index() + if data.empty: + stats['success'] += 1 + return + + needs_refresh = ( + incremental and last_date is not None and + _has_ex_rights_after(pure_code, gbbq, last_date) + ) + + processed = processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type, float_cap_map=float_cap_map) + + if incremental and last_date and not needs_refresh: + processed = processed[processed['date'] > last_date] + if start_date: + processed = processed[processed['date'] >= str(start_date)] + if end_date: + processed = processed[processed['date'] <= str(end_date)] + + if processed.empty: + stats['success'] += 1 + return + + if needs_refresh: + storage.delete_stock_data(db_code) + storage.save_incremental(processed, 'daily_data', conflict_columns=('stock_code', 'date'), + batch_size=config.db_batch_size) + stats['success'] += 1 + + if reader._smb is not None: + # SMB 批量并发模式:分批预下载,再串行解析入库 + stocks_meta = [] + for db_code in stocks: + pure_code, suffix = db_code.split('.') + market = {'SZ': 0, 'SH': 1, 'BJ': 2}[suffix] + code = suffix.lower() + pure_code + stocks_meta.append((market, code, db_code)) + + batch_iter = reader.read_daily_data_batch( + stocks_meta, + batch_size=config.smb_batch_size, + smb_workers=config.smb_workers, + ) + iterator = ( + tqdm(batch_iter, total=len(stocks), desc="同步日线(SMB批量)") + if config.use_tqdm else batch_iter + ) + + for db_code, result in iterator: + if isinstance(result, FileNotFoundError): + stats['failed'] += 1 + elif isinstance(result, Exception): + logger.error(f"处理 {db_code} 时出错: {result}") + stats['failed'] += 1 + else: + try: + _process_one(db_code, result) + except Exception as e: + logger.error(f"处理 {db_code} 时出错: {e}") + stats['failed'] += 1 + else: + # 本地串行模式(不变) + iterator = tqdm(stocks, total=len(stocks), desc="同步日线") if config.use_tqdm else stocks + + for db_code in iterator: + pure_code, suffix = db_code.split('.') + market = {'SZ': 0, 'SH': 1, 'BJ': 2}[suffix] + code = suffix.lower() + pure_code # sz000001,供 read_daily_data 使用 + + try: + data = reader.read_daily_data(market, code) + _process_one(db_code, data) + except FileNotFoundError: + stats['failed'] += 1 + except Exception as e: + logger.error(f"处理 {code} 时出错: {e}") + stats['failed'] += 1 + + logger.info(f"同步完成: 成功 {stats['success']},失败 {stats['failed']}") + return stats + + +def parse_args() -> Namespace: + parser = argparse.ArgumentParser(description='tdx2db 日线数据同步工具') + + 
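+    # 全局参数需写在子命令之前,且仅在显式传入时覆盖 .env 配置(见 update_config)。
+    # 调用示意(假设入口脚本为 main.py):python main.py --no-tqdm --db-type sqlite sync
+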
parser.add_argument('--tdx-path', help='通达信安装目录')
+    parser.add_argument('--db-type', choices=['sqlite', 'mysql', 'postgresql'])
+    parser.add_argument('--db-host')
+    parser.add_argument('--db-port')
+    parser.add_argument('--db-name')
+    parser.add_argument('--db-user')
+    parser.add_argument('--db-password')
+    parser.add_argument('--no-tqdm', action='store_true')
+    # 默认 None:未显式传入时不覆盖 .env 中的 DB_BATCH_SIZE
+    parser.add_argument('--batch-size', type=int, default=None)
+    parser.add_argument('--smb-host', help='SMB 服务器地址')
+    parser.add_argument('--smb-share', help='SMB 共享名')
+    parser.add_argument('--smb-user', help='SMB 用户名')
+    parser.add_argument('--smb-password', help='SMB 密码')
+    # 默认 None:未显式传入时不覆盖 .env 中的 SMB_TDX_PATH / SMB_PORT
+    parser.add_argument('--smb-tdx-path', help='TDX 在共享目录内的相对路径', default=None)
+    parser.add_argument('--smb-port', type=int, default=None)
+
+    subparsers = parser.add_subparsers(dest='command')
+
+    # stock-list
+    subparsers.add_parser('stock-list', help='同步股票列表')
+
+    # daily
+    daily = subparsers.add_parser('daily', help='同步日线数据')
+    daily.add_argument('--code', help='股票代码(6位数字,如 000001),不指定则全量,市场自动识别')
+    daily.add_argument('--start', type=int, help='开始日期 YYYYMMDD')
+    daily.add_argument('--end', type=int, help='结束日期 YYYYMMDD')
+    daily.add_argument('--adj', choices=['forward', 'backward', 'none'], default='forward')
+    daily.add_argument('--incremental', action='store_true', help='增量模式')
+
+    # sync
+    sync = subparsers.add_parser('sync', help='一键增量同步日线数据')
+    sync.add_argument('--adj', choices=['forward', 'backward', 'none'], default='forward')
+
+    # download
+    download = subparsers.add_parser('download', help='联网下载 TDX 日线数据并导入数据库')
+    download.add_argument('--url', help=f'下载地址(默认: {DEFAULT_DOWNLOAD_URL})')
+    download.add_argument('--adj', choices=['forward', 'backward', 'none'], default='forward')
+    download.add_argument('--no-clean', action='store_true', dest='no_clean',
+                          help='保留临时目录(用于调试)')
+
+    return parser.parse_args()
+
+
+def update_config(args: Namespace) -> None:
+    if args.tdx_path:
+        config.tdx_path = args.tdx_path
+    if args.db_type:
+        config.db_type = args.db_type
+    if args.db_host:
+        config.db_host = args.db_host
+    if args.db_port:
+        config.db_port = args.db_port
+    if args.db_name:
+        config.db_name = args.db_name
+    if args.db_user:
+        config.db_user = args.db_user
+    if args.db_password:
+        config.db_password = args.db_password
+    if args.batch_size:
+        config.db_batch_size = args.batch_size
+    if args.no_tqdm:
+        config.use_tqdm = False
+    if getattr(args, 'smb_host', None):
+        config.smb_host = args.smb_host
+        config.smb_enabled = True
+    if getattr(args, 'smb_share', None):
+        config.smb_share = args.smb_share
+    if getattr(args, 'smb_user', None):
+        config.smb_user = args.smb_user
+    if getattr(args, 'smb_password', None):
+        config.smb_password = args.smb_password
+    if getattr(args, 'smb_tdx_path', None) is not None:
+        config.smb_tdx_path = args.smb_tdx_path
+    if getattr(args, 'smb_port', None):
+        config.smb_port = args.smb_port
+
+
+def _create_reader():
+    """根据配置创建 TdxDataReader,返回 (reader, smb_accessor_or_None)。"""
+    if config.smb_enabled:
+        if not config.smb_host or not config.smb_share:
+            raise ValueError("SMB 模式需要设置 SMB_HOST 和 SMB_SHARE")
+        from .smb_accessor import SmbAccessor
+        smb = SmbAccessor(
+            host=config.smb_host,
+            share=config.smb_share,
+            tdx_path=config.smb_tdx_path,
+            username=config.smb_user or None,
+            password=config.smb_password or None,
+            port=config.smb_port,
+        )
+        smb._register()
+        return TdxDataReader(smb=smb), smb
+    return TdxDataReader(), None
+
+
+def main() -> int:
+    args = parse_args()
+    update_config(args)
+
+    storage = DataStorage()
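+    # 说明:无参构造的 DataStorage 使用 config.database_url,
+    # 并在 __init__ 中通过 Base.metadata.create_all 自动建表,首次运行无需手动初始化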
processor = DataProcessor() + + if args.command == 'download': + adj_type = getattr(args, 'adj', 'forward') + keep_tmp = getattr(args, 'no_clean', False) + url = getattr(args, 'url', None) + + gbbq = pd.DataFrame() + if config.smb_enabled: + try: + smb_reader, smb_acc = _create_reader() + gbbq = smb_reader.read_gbbq() + if smb_acc: + smb_acc._unregister() + logger.info("已从 SMB 读取权息文件") + except Exception: + logger.warning("SMB 权息文件读取失败,将跳过复权处理") + elif config.tdx_path: + try: + local_reader = TdxDataReader() + gbbq = local_reader.read_gbbq() + logger.info("已从本地通达信读取权息文件") + except Exception: + logger.warning("本地权息文件读取失败,将跳过复权处理") + + logger.info("=== 开始联网下载 TDX 日线数据 ===") + with download_and_extract(url=url, keep_tmp=keep_tmp) as vipdoc_path: + dl_reader = TdxDataReader(vipdoc_path=str(vipdoc_path)) + stats = sync_all_daily(dl_reader, processor, storage, gbbq, + adj_type=adj_type, incremental=True) + storage.save_sync_statistics(stats['success']) + return 0 + + try: + reader, smb_accessor = _create_reader() + except (ValueError, FileNotFoundError) as e: + logger.error(f"初始化失败: {e}") + return 1 + + try: + if args.command == 'stock-list': + if not sync_stock_list(reader, storage): + return 1 + + elif args.command == 'daily': + adj_type = getattr(args, 'adj', 'forward') + gbbq = reader.read_gbbq() + base_caps = reader.read_base_dbf() + float_cap_map = DataProcessor.build_float_capital_map(base_caps, gbbq) if base_caps else {} + + if args.code: + pure_code = args.code[-6:] if len(args.code) > 6 else args.code + if pure_code.startswith('6'): + market = 1 + elif pure_code.startswith('8') or pure_code.startswith('92'): + market = 2 + else: + market = 0 + prefix = {0: 'sz', 1: 'sh', 2: 'bj'}[market] + code = prefix + pure_code + try: + data = reader.read_daily_data(market, code) + if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('date', 'datetime'): + data = data.reset_index() + processed = processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type, float_cap_map=float_cap_map) + processed = processor.filter_data(processed, start_date=args.start, end_date=args.end) + if not processed.empty: + storage.save_incremental(processed, 'daily_data', conflict_columns=('stock_code', 'date')) + except Exception as e: + logger.error(f"同步 {code} 出错: {e}") + return 1 + else: + incremental = getattr(args, 'incremental', False) + stats = sync_all_daily(reader, processor, storage, gbbq, + adj_type=adj_type, incremental=incremental, + start_date=args.start, end_date=args.end, + float_cap_map=float_cap_map) + storage.save_sync_statistics(stats['success']) + + elif args.command == 'sync': + adj_type = getattr(args, 'adj', 'forward') + logger.info("=== 开始增量同步日线数据 ===") + sync_stock_list(reader, storage) + gbbq = reader.read_gbbq() + + base_caps = reader.read_base_dbf() + float_cap_map = DataProcessor.build_float_capital_map(base_caps, gbbq) if base_caps else {} + if not float_cap_map: + logger.warning("base.dbf 读取失败,换手率将降级使用 gbbq category==5 逻辑") + + stats = sync_all_daily(reader, processor, storage, gbbq, + adj_type=adj_type, incremental=True, + float_cap_map=float_cap_map) + storage.save_sync_statistics(stats['success']) + + else: + logger.error("请指定子命令,使用 -h 查看帮助") + return 1 + + finally: + if smb_accessor is not None: + smb_accessor._unregister() + + return 0 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/src/config.py b/src/tdx2db/config.py similarity index 60% rename from src/config.py rename to src/tdx2db/config.py index b5bd197..17ee1c0 100644 --- a/src/config.py +++ 
b/src/tdx2db/config.py @@ -1,23 +1,10 @@ -"""配置管理模块 - -负责加载和管理程序的配置参数,包括: -- 通达信数据路径 -- 数据库连接信息 -- 输出CSV路径 -- 其他配置选项 -""" - import os -from pathlib import Path from dotenv import load_dotenv -# 加载.env文件中的环境变量 load_dotenv() class Config: - """配置类""" - tdx_path: str csv_output_path: str db_type: str @@ -28,30 +15,41 @@ class Config: db_password: str db_batch_size: int use_tqdm: bool + download_url: str + smb_enabled: bool + smb_host: str + smb_share: str + smb_user: str + smb_password: str + smb_tdx_path: str + smb_port: int + smb_workers: int + smb_batch_size: int def __init__(self) -> None: - """初始化配置""" - # 通达信安装路径 self.tdx_path = os.getenv('TDX_PATH', '') - - # CSV输出路径 self.csv_output_path = os.getenv('CSV_OUTPUT_PATH', 'output') - - # 数据库配置 - self.db_type = os.getenv('DB_TYPE', 'postgresql') + self.db_type = os.getenv('DB_TYPE', 'sqlite') self.db_host = os.getenv('DB_HOST', 'localhost') self.db_port = os.getenv('DB_PORT', '5432') self.db_name = os.getenv('DB_NAME', 'tdx_data') self.db_user = os.getenv('DB_USER', 'postgres') self.db_password = os.getenv('DB_PASSWORD', '') self.db_batch_size = int(os.getenv('DB_BATCH_SIZE', '10000')) - - # 是否使用进度条 self.use_tqdm = os.getenv('USE_TQDM', 'True').lower() == 'true' + self.download_url = os.getenv('TDX_DOWNLOAD_URL', '') + self.smb_enabled = os.getenv('SMB_ENABLED', 'false').lower() == 'true' + self.smb_host = os.getenv('SMB_HOST', '') + self.smb_share = os.getenv('SMB_SHARE', '') + self.smb_user = os.getenv('SMB_USER', '') + self.smb_password = os.getenv('SMB_PASSWORD', '') + self.smb_tdx_path = os.getenv('SMB_TDX_PATH', '') + self.smb_port = int(os.getenv('SMB_PORT', '445')) + self.smb_workers = int(os.getenv('SMB_WORKERS', '16')) + self.smb_batch_size = int(os.getenv('SMB_BATCH_SIZE', '200')) @property - def database_url(self): - """获取数据库连接URL""" + def database_url(self) -> str: if self.db_type == 'postgresql': return f"postgresql://{self.db_user}:{self.db_password}@{self.db_host}:{self.db_port}/{self.db_name}" elif self.db_type == 'mysql': @@ -61,5 +59,5 @@ def database_url(self): else: raise ValueError(f"不支持的数据库类型: {self.db_type}") -# 创建全局配置实例 + config = Config() diff --git a/src/tdx2db/downloader.py b/src/tdx2db/downloader.py new file mode 100644 index 0000000..ce829ad --- /dev/null +++ b/src/tdx2db/downloader.py @@ -0,0 +1,112 @@ +import shutil +import tempfile +import zipfile +from contextlib import contextmanager +from pathlib import Path +from typing import Generator, Optional + +import requests +from tqdm import tqdm + +from .config import config +from .logger import logger + +DEFAULT_DOWNLOAD_URL = 'https://data.tdx.com.cn/vipdoc/hsjday.zip' + + +def download_zip(url: str, dest_path: Path, chunk_size: int = 1024 * 1024) -> None: + """流式下载 ZIP 文件,支持 tqdm 进度条。""" + try: + response = requests.get(url, stream=True, timeout=(10, 300)) + response.raise_for_status() + except requests.RequestException as e: + raise RuntimeError(f"下载失败: {e}") from e + + total = int(response.headers.get('Content-Length', 0)) or None + desc = Path(url).name + + try: + if config.use_tqdm: + with tqdm(total=total, unit='B', unit_scale=True, desc=desc) as bar: + with open(dest_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=chunk_size): + f.write(chunk) + bar.update(len(chunk)) + else: + logger.info(f"正在下载 {desc}...") + with open(dest_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=chunk_size): + f.write(chunk) + except Exception: + dest_path.unlink(missing_ok=True) + raise + + +def extract_zip(zip_path: Path, extract_dir: 
Path) -> Path: + """解压 ZIP 文件,返回内部 hsjday/ 子目录路径(作为 vipdoc_path 使用)。""" + if not zipfile.is_zipfile(zip_path): + raise ValueError(f"文件不是合法的 ZIP 格式: {zip_path}") + + with zipfile.ZipFile(zip_path) as zf: + bad = zf.testzip() + if bad: + raise ValueError(f"ZIP 文件损坏,首个问题文件: {bad}") + + # 安全检查:过滤路径穿越 + for member in zf.infolist(): + if member.filename.startswith('/') or '..' in member.filename: + raise ValueError(f"ZIP 包含不安全路径: {member.filename}") + + logger.info("正在解压数据包...") + # ZIP 内路径可能使用 Windows 反斜杠(如 sh\lday\sh000001.day) + # Python zipfile 在 macOS/Linux 上不会自动转换,需手动处理 + for member in zf.infolist(): + normalized = member.filename.replace('\\', '/') + dest = extract_dir / Path(normalized) + dest.parent.mkdir(parents=True, exist_ok=True) + if not normalized.endswith('/'): + with zf.open(member) as src, open(dest, 'wb') as dst: + dst.write(src.read()) + + # 兼容两种结构: + # 1. {sh,sz,bj}/lday/*.day (根目录直接是市场目录,实际情况) + # 2. hsjday/{sh,sz,bj}/lday/*.day (有顶层目录) + if (extract_dir / 'sh').exists() or (extract_dir / 'sz').exists() or (extract_dir / 'bj').exists(): + return extract_dir + vipdoc_path = extract_dir / 'hsjday' + if vipdoc_path.exists(): + return vipdoc_path + raise FileNotFoundError("解压后未找到预期的市场目录(sh/sz/bj),请检查 ZIP 包结构") + + +@contextmanager +def download_and_extract( + url: Optional[str] = None, + keep_tmp: bool = False, +) -> Generator[Path, None, None]: + """ + 下载并解压 TDX 日线数据包,yield vipdoc_path(即 hsjday/ 目录)。 + + 退出时若 keep_tmp=False 自动删除临时目录。 + + 用法: + with download_and_extract() as vipdoc_path: + reader = TdxDataReader(vipdoc_path=str(vipdoc_path)) + """ + target_url = url or config.download_url or DEFAULT_DOWNLOAD_URL + tmp_dir = Path(tempfile.mkdtemp(prefix='tdx_hsjday_')) + zip_path = tmp_dir / 'hsjday.zip' + + try: + logger.info(f"开始下载: {target_url}") + download_zip(target_url, zip_path) + logger.info("下载完成,开始解压...") + vipdoc_path = extract_zip(zip_path, tmp_dir) + zip_path.unlink(missing_ok=True) # 解压后释放 ZIP 占用的磁盘空间 + logger.info(f"解压完成: {vipdoc_path}") + yield vipdoc_path + finally: + if keep_tmp: + logger.info(f"临时目录已保留: {tmp_dir}") + else: + shutil.rmtree(tmp_dir, ignore_errors=True) diff --git a/src/logger.py b/src/tdx2db/logger.py similarity index 100% rename from src/logger.py rename to src/tdx2db/logger.py diff --git a/src/tdx2db/processor.py b/src/tdx2db/processor.py new file mode 100644 index 0000000..9c0fa69 --- /dev/null +++ b/src/tdx2db/processor.py @@ -0,0 +1,346 @@ +from typing import Optional, List + +import pandas as pd + +from .logger import logger + + +class DataProcessor: + + @staticmethod + def _validate_ohlcv(df: pd.DataFrame) -> pd.DataFrame: + required = ['open', 'high', 'low', 'close'] + if not all(c in df.columns for c in required): + return df + before = len(df) + positive = (df[required] > 0).all(axis=1) + ohlc_ok = ( + (df['high'] >= df[['open', 'close']].max(axis=1)) & + (df['low'] <= df[['open', 'close']].min(axis=1)) + ) + df = df[positive & ohlc_ok] + dropped = before - len(df) + if dropped > 0: + logger.warning(f"数据校验丢弃 {dropped} 条不合格记录") + return df + + @staticmethod + def apply_forward_adj(df: pd.DataFrame, gbbq: pd.DataFrame) -> pd.DataFrame: + """前复权:对每个除权日,将该日之前的历史价格乘以复权因子。""" + if gbbq.empty or df.empty: + df = df.copy() + df['adj_factor'] = 1.0 + return df + + market_val = df['market'].iloc[0] + prefix = 'sz' if market_val == 0 else 'sh' + full_code = prefix + str(df['code'].iloc[0]).zfill(6) + events = gbbq[gbbq['full_code'] == full_code].copy() + + df = df.copy() + if events.empty: + df['adj_factor'] = 1.0 + return df + + 
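+        # 下面的因子公式示意(gbbq 字段按"每 10 股"计,故先除以 10):
+        #   factor = (前收盘 - 每股红利 + 配股价 × 每股配股数)
+        #            / (前收盘 × (1 + 每股送股数 + 每股配股数))
+        # 例(假设数值,与 tests/test_daily.py 一致):前收 10.5,每 10 股送 1 股
+        # (songgu_qianzongguben=1.0 → 0.1),无红利无配股:
+        #   factor = 10.5 / (10.5 × 1.1) ≈ 0.909,除权日前的价格整体乘以 0.909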
events['ex_date'] = pd.to_datetime( + events['datetime'].astype(str).str[:8], format='%Y%m%d' + ) + data_start = df['date'].min() + events = events[ + (events['category'] == 1) & (events['ex_date'] > data_start) + ].sort_values('ex_date') + + df = df.sort_values('date').copy() + df['adj_factor'] = 1.0 + + for _, ev in events.iterrows(): + ex_date = ev['ex_date'] + songgu = float(ev.get('songgu_qianzongguben', 0) or 0) / 10.0 + hongli = float(ev.get('hongli_panqianliutong', 0) or 0) / 10.0 + peigujia = float(ev.get('peigujia_qianzongguben', 0) or 0) + peigu = float(ev.get('peigu_houzongguben', 0) or 0) / 10.0 + + before = df[df['date'] < ex_date] + if before.empty: + continue + prev_close = float(before['close'].iloc[-1]) + if prev_close <= 0: + continue + denominator = prev_close * (1 + songgu + peigu) + if denominator <= 0: + continue + factor = (prev_close - hongli + peigujia * peigu) / denominator + if factor <= 0 or factor > 2: + logger.warning(f"{full_code} 除权日 {ex_date} 复权因子异常({factor:.4f}),已跳过") + continue + df.loc[df['date'] < ex_date, 'adj_factor'] *= factor + + for col in ['open', 'high', 'low', 'close']: + df[col] = (df[col] * df['adj_factor']).round(3) + return df + + @staticmethod + def apply_backward_adj(df: pd.DataFrame, gbbq: pd.DataFrame) -> pd.DataFrame: + """后复权:对每个除权日,将该日及之后的价格乘以 1/factor。""" + if gbbq.empty or df.empty: + df = df.copy() + df['adj_factor'] = 1.0 + return df + + market_val = df['market'].iloc[0] + prefix = 'sz' if market_val == 0 else 'sh' + full_code = prefix + str(df['code'].iloc[0]).zfill(6) + events = gbbq[gbbq['full_code'] == full_code].copy() + + df = df.copy() + if events.empty: + df['adj_factor'] = 1.0 + return df + + events['ex_date'] = pd.to_datetime( + events['datetime'].astype(str).str[:8], format='%Y%m%d' + ) + data_start = df['date'].min() + events = events[ + (events['category'] == 1) & (events['ex_date'] > data_start) + ].sort_values('ex_date') + + df = df.sort_values('date').copy() + df['adj_factor'] = 1.0 + + for _, ev in events.iterrows(): + ex_date = ev['ex_date'] + songgu = float(ev.get('songgu_qianzongguben', 0) or 0) / 10.0 + hongli = float(ev.get('hongli_panqianliutong', 0) or 0) / 10.0 + peigujia = float(ev.get('peigujia_qianzongguben', 0) or 0) + peigu = float(ev.get('peigu_houzongguben', 0) or 0) / 10.0 + + before = df[df['date'] < ex_date] + if before.empty: + continue + prev_close = float(before['close'].iloc[-1]) + if prev_close <= 0: + continue + denominator = prev_close * (1 + songgu + peigu) + if denominator <= 0: + continue + factor = (prev_close - hongli + peigujia * peigu) / denominator + if factor <= 0 or factor > 2: + logger.warning(f"{full_code} 除权日 {ex_date} 复权因子异常({factor:.4f}),已跳过") + continue + df.loc[df['date'] >= ex_date, 'adj_factor'] *= (1.0 / factor) + + for col in ['open', 'high', 'low', 'close']: + df[col] = (df[col] * df['adj_factor']).round(3) + return df + + @staticmethod + def build_float_capital_map(base_caps: dict, gbbq: pd.DataFrame) -> dict: + """从 base.dbf 当前流通股本出发,逆向推算历史各时间点的流通股本。 + + Returns: {full_code: [(date_int, cap_万股), ...]} 已按 date_int 升序排列, + 供 merge_asof 使用(date_int 表示"从该日期起生效的流通股本")。 + """ + if gbbq.empty or not base_caps: + logger.debug(f"[float_cap] 跳过:gbbq.empty={gbbq.empty} base_caps空={not base_caps}") + return {} + + logger.debug(f"[float_cap] base_caps 共 {len(base_caps)} 条,gbbq 共 {len(gbbq)} 行") + result = {} + relevant_cats = {1, 10, 11, 12, 15} + gbbq_filtered = gbbq[gbbq['category'].isin(relevant_cats)].copy() + logger.debug(f"[float_cap] gbbq_filtered(cat∈{relevant_cats})共 
{len(gbbq_filtered)} 行,unique full_code={gbbq_filtered['full_code'].nunique()}") + + for full_code, group in gbbq_filtered.groupby('full_code'): + pure_code = full_code[2:] # 去掉 sz/sh/bj 前缀 + if pure_code not in base_caps: + continue + + cap = float(base_caps[pure_code]) + logger.debug(f"[float_cap] {full_code} base_cap={cap:.2f} 万股,gbbq 事件数={len(group)}") + events = group.sort_values('datetime', ascending=False) + snapshots = [] + + for _, ev in events.iterrows(): + date_int = int(ev['datetime']) + cat = int(ev['category']) + + # 先记录:从 date_int 起生效的股本(事件发生后的值) + snapshots.append((date_int, cap)) + + # 再逆向推算事件发生前的 cap + if cat == 1: + songgu = float(ev.get('songgu_qianzongguben', 0) or 0) / 10.0 + peigu = float(ev.get('peigu_houzongguben', 0) or 0) / 10.0 + ratio = songgu + peigu + if ratio > 0: + cap = cap / (1.0 + ratio) + elif cat in (11, 12, 15): + # 增发/解禁/债转股:S_before = S_after - N + value = float(ev.get('hongli_panqianliutong', 0) or 0) + cap = cap - value + if cap <= 0: + cap = float(base_caps[pure_code]) # 异常保护 + elif cat == 10: + # 回购注销:注销后股本减少,回溯要加回来 + value = float(ev.get('hongli_panqianliutong', 0) or 0) + cap = cap + value + logger.debug(f" [{full_code}] cat={cat} date={date_int} → cap_before={cap:.2f} 万股") + + # 兜底:最早历史数据使用推算到底的 cap + snapshots.append((0, cap)) + snapshots.sort(key=lambda x: x[0]) + result[full_code] = snapshots + + logger.info(f"build_float_capital_map 完成,覆盖 {len(result)} 只股票") + return result + + @staticmethod + def _calc_turnover_rate(df: pd.DataFrame, gbbq: pd.DataFrame, float_cap_map: dict = None) -> pd.Series: + """计算换手率(%):volume(股) × 100 / 流通股本(股)。 + + 优先使用 float_cap_map(base.dbf 锚点 + gbbq 逆向推算); + float_cap_map 为 None 时降级到原有 gbbq category==5 逻辑。 + """ + market_val = df['market'].iloc[0] + prefix = {0: 'sz', 1: 'sh', 2: 'bj'}.get(market_val, 'sz') + full_code = prefix + str(df['code'].iloc[0]).zfill(6) + + # ── 优先路径:float_cap_map ────────────────────────────────────────── + if float_cap_map is not None and full_code in float_cap_map: + snap_list = float_cap_map[full_code] # [(date_int, cap_万股)],升序 + logger.debug(f"[turnover] {full_code} 使用 float_cap_map,快照数={len(snap_list)}") + if snap_list: + snap_df = pd.DataFrame(snap_list, columns=['date_int', 'float_cap']) + daily = df[['date', 'volume']].copy() + daily['date_int'] = daily['date'].astype(int) + daily = daily.sort_values('date_int').reset_index(drop=False) + merged = pd.merge_asof( + daily[['index', 'date_int', 'volume']], + snap_df, + on='date_int' + ) + cap = merged['float_cap'] * 10000 # 万股 → 股 + merged['turnover_rate'] = ( + (merged['volume'] * 100 / cap).where(cap > 0).round(4) + ) + # 调试:打印最近几条 + sample = merged[['date_int', 'volume', 'float_cap', 'turnover_rate']].tail(3) + for _, row in sample.iterrows(): + logger.debug( + f" [{full_code}] date={int(row['date_int'])} " + f"vol={row['volume']:.0f} cap={row['float_cap']:.2f}万股 " + f"turnover={row['turnover_rate']:.4f}%" + ) + return merged.set_index('index')['turnover_rate'].reindex(df.index) + + # ── 降级路径:原有 gbbq category==5 逻辑 ──────────────────────────── + if gbbq.empty or 'full_code' not in gbbq.columns: + return pd.Series([None] * len(df), index=df.index, dtype=float) + shares = gbbq[(gbbq['full_code'] == full_code) & (gbbq['category'] == 5)].copy() + if shares.empty: + return pd.Series([None] * len(df), index=df.index, dtype=float) + + # datetime 直接是 YYYYMMDD 整数 + shares = shares.rename(columns={'datetime': 'date_int'}) + shares = shares.sort_values('date_int').drop_duplicates('date_int', keep='last') + + daily = df[['date', 
'volume']].copy() + daily['date_int'] = daily['date'].astype(int) + daily = daily.sort_values('date_int').reset_index(drop=False) + + merged = pd.merge_asof( + daily[['index', 'date_int', 'volume']], + shares[['date_int', 'hongli_panqianliutong']], + on='date_int' + ) + # 流通股本单位为万股,× 10000 换算为股 + cap = merged['hongli_panqianliutong'] * 10000 + # 换手率(%) = volume(股) × 100 / 流通股本(股) + merged['turnover_rate'] = (merged['volume'] * 100 / cap).where(cap > 0).round(4) + return merged.set_index('index')['turnover_rate'].reindex(df.index) + + @staticmethod + def process_daily_data( + df: pd.DataFrame, + gbbq: pd.DataFrame = None, + adj_type: str = 'forward', + float_cap_map: dict = None, + ) -> pd.DataFrame: + """日线处理主流程:reset_index → 填充缺失值 → 校验 → 复权 → 日期转 YYYYMMDD 整数。""" + if df.empty: + return df + + processed = df.copy() + + # 确保 date 是列而非索引 + if isinstance(processed.index, pd.DatetimeIndex) or processed.index.name in ('date', 'datetime'): + processed = processed.reset_index() + + # 统一列名:date 列 + if 'date' not in processed.columns and 'datetime' in processed.columns: + processed.rename(columns={'datetime': 'date'}, inplace=True) + + # 确保 date 是 datetime 类型(复权逻辑依赖 Timestamp 比较) + if 'date' in processed.columns and not pd.api.types.is_datetime64_any_dtype(processed['date']): + processed['date'] = pd.to_datetime(processed['date']) + + # 填充缺失值 + for col in ['open', 'high', 'low', 'close', 'volume', 'amount']: + if col in processed.columns: + processed[col] = processed[col].ffill() + + # 数据校验 + processed = DataProcessor._validate_ohlcv(processed) + + # 复权 + if gbbq is not None and not gbbq.empty and 'date' in processed.columns: + if adj_type == 'forward': + processed = DataProcessor.apply_forward_adj(processed, gbbq) + elif adj_type == 'backward': + processed = DataProcessor.apply_backward_adj(processed, gbbq) + else: + processed['adj_factor'] = 1.0 + elif 'adj_factor' not in processed.columns: + processed['adj_factor'] = 1.0 + + # 日期转 YYYYMMDD 字符串 + processed['date'] = processed['date'].dt.strftime('%Y%m%d') + + # 计算换手率 + if gbbq is not None and not gbbq.empty and 'code' in processed.columns: + processed['turnover_rate'] = DataProcessor._calc_turnover_rate( + processed, gbbq, float_cap_map=float_cap_map + ) + else: + processed['turnover_rate'] = None + + # 生成带市场后缀的 stock_code,如 000001.SZ / 600000.SH / 920001.BJ + _suffix_map = {0: '.SZ', 1: '.SH', 2: '.BJ'} + processed['stock_code'] = ( + processed['code'].astype(str).str.zfill(6) + + processed['market'].map(_suffix_map).fillna('.SZ') + ) + processed = processed.drop(columns=['code']) + + return processed + + @staticmethod + def filter_data( + df: pd.DataFrame, + start_date: Optional[int] = None, + end_date: Optional[int] = None, + codes: Optional[List[str]] = None + ) -> pd.DataFrame: + """按 YYYYMMDD 整数日期和股票代码筛选。""" + if df.empty: + return df + result = df.copy() + if 'date' in result.columns: + if start_date: + result = result[result['date'] >= str(start_date)] + if end_date: + result = result[result['date'] <= str(end_date)] + if codes and 'stock_code' in result.columns: + result = result[result['stock_code'].isin(codes)] + return result diff --git a/src/tdx2db/reader.py b/src/tdx2db/reader.py new file mode 100644 index 0000000..c6a0c4c --- /dev/null +++ b/src/tdx2db/reader.py @@ -0,0 +1,461 @@ +import io +import os +import re +import shutil +import struct +import tempfile +from contextlib import redirect_stdout +from pathlib import Path +from typing import Iterator, List, Optional, Tuple, TYPE_CHECKING + +import pandas as pd +from pytdx.reader 
import TdxDailyBarReader, GbbqReader + +from .config import config +from .logger import logger + +if TYPE_CHECKING: + from .smb_accessor import SmbAccessor + + +class TdxDataReader: + def __init__( + self, + tdx_path: Optional[str] = None, + vipdoc_path: Optional[str] = None, + smb: Optional['SmbAccessor'] = None, + ) -> None: + self._smb = smb + + if smb is not None: + self.tdx_path = None + self._vipdoc_path = None + elif vipdoc_path: + self.tdx_path = None + self._vipdoc_path = Path(vipdoc_path) + if not self._vipdoc_path.exists(): + raise FileNotFoundError(f"vipdoc 目录不存在: {self._vipdoc_path}") + else: + self.tdx_path = Path(tdx_path or config.tdx_path) + if not self.tdx_path: + raise ValueError("通达信数据路径未设置,请在 .env 中设置 TDX_PATH") + if not self.tdx_path.exists(): + raise FileNotFoundError(f"通达信数据路径不存在: {self.tdx_path}") + self._vipdoc_path = self.tdx_path / 'vipdoc' + self.daily_reader = TdxDailyBarReader() + self.gbbq_reader = GbbqReader() + + def read_gbbq(self) -> pd.DataFrame: + """读取权息文件,返回全量权息 DataFrame。文件不存在时返回空 DataFrame。""" + if self._smb is not None: + return self._read_gbbq_smb() + if self.tdx_path is None: + logger.warning("联网下载模式下不支持读取 gbbq,将跳过复权处理") + return pd.DataFrame() + gbbq_path = self.tdx_path / 'T0002' / 'hq_cache' / 'gbbq' + if not gbbq_path.exists(): + logger.warning(f"权息文件不存在: {gbbq_path},将跳过复权处理") + return pd.DataFrame() + try: + df = self.gbbq_reader.get_df(str(gbbq_path)) + if df.empty: + return pd.DataFrame() + market_prefix = df['market'].map({0: 'sz', 1: 'sh'}) + df['full_code'] = market_prefix + df['code'].astype(str).str.zfill(6) + return df + except Exception as e: + logger.warning(f"读取权息文件时出错: {e},将跳过复权处理") + return pd.DataFrame() + + def read_stock_names(self) -> dict: + """读取 szs/shs/bjs.tnf,返回按市场分组的名称字典。 + 返回格式:{'SZ': {6位code: 名}, 'SH': {6位code: 名}, 'BJ': {6位code: 名}} + 三个市场代码空间有重叠(如 000001 在 SZ 是平安银行,在 SH 是上证指数), + 必须分开存储,不能合并为单一字典。 + 本地模式读 T0002/hq_cache/*.tnf;SMB 模式下载后解析。 + 文件不存在时对应市场返回空字典。 + """ + if self._smb is not None: + return self._read_stock_names_smb() + if self.tdx_path is None: + logger.warning("联网下载模式下不支持读取股票名称,将以代码代替名称") + return {'SZ': {}, 'SH': {}, 'BJ': {}} + result = {} + for fname, suffix in (('szs.tnf', 'SZ'), ('shs.tnf', 'SH'), ('bjs.tnf', 'BJ')): + path = self.tdx_path / 'T0002' / 'hq_cache' / fname + if path.exists(): + result[suffix] = self._parse_tnf_file(str(path)) + else: + logger.warning(f"股票名称文件不存在: {path}") + result[suffix] = {} + return result + + def read_base_dbf(self) -> dict: + """读取 base.dbf,返回 {code_6位: 流通股本万股} 字典。文件不存在时返回空字典。""" + if self._smb is not None: + return self._read_base_dbf_smb() + if self.tdx_path is None: + logger.warning("联网下载模式下不支持读取 base.dbf") + return {} + dbf_path = self.tdx_path / 'T0002' / 'hq_cache' / 'base.dbf' + if not dbf_path.exists(): + logger.warning(f"base.dbf 不存在: {dbf_path}") + return {} + return self._parse_base_dbf(str(dbf_path)) + + def get_stock_list(self) -> list: + """扫描本地 .day 文件,返回有数据的股票代码列表(000001.SZ 格式)。""" + if self._smb is not None: + return self._get_stock_list_smb() + sz_path = self._vipdoc_path / 'sz' / 'lday' + sh_path = self._vipdoc_path / 'sh' / 'lday' + bj_path = self._vipdoc_path / 'bj' / 'lday' + if not (sz_path.exists() or sh_path.exists() or bj_path.exists()): + raise FileNotFoundError("无法找到股票数据目录") + + codes = [] + if sz_path.exists(): + for f in sz_path.glob('*.day'): + pure = f.stem[-6:].zfill(6) + if re.match(r'^(000|001|002|300|301)\d{3}$', pure): + codes.append(pure + '.SZ') + + if sh_path.exists(): + for f in sh_path.glob('*.day'): + pure = 
f.stem[-6:].zfill(6)
+                if re.match(r'^(60\d{4}|688\d{3})$', pure):
+                    codes.append(pure + '.SH')
+
+        if bj_path.exists():
+            for f in bj_path.glob('*.day'):
+                pure = f.stem[-6:].zfill(6)
+                if re.match(r'^(8\d{5}|92\d{4})$', pure):
+                    codes.append(pure + '.BJ')
+
+        if not codes:
+            raise FileNotFoundError("未找到任何股票数据文件")
+        return codes
+
+    def read_daily_data(self, market: int, code: str) -> pd.DataFrame:
+        """读取单只股票日线数据,返回含 code/market 列的 DataFrame(date 为 DatetimeIndex)。"""
+        market_map = {0: 'sz', 1: 'sh', 2: 'bj'}
+        market_folder = market_map[market]
+        pure_code = code[-6:] if len(code) > 6 else code
+        filename = f"{market_folder}{pure_code}.day"
+
+        if self._smb is not None:
+            unc = self._smb.day_file_unc(market_folder, filename)
+            if not self._smb.exists(unc):
+                raise FileNotFoundError(f"SMB 日线数据文件不存在: {unc}")
+            data = self._read_daily_via_smb(unc)
+        else:
+            file_path = self._vipdoc_path / market_folder / 'lday' / filename
+            if not file_path.exists():
+                raise FileNotFoundError(f"日线数据文件不存在: {file_path}")
+            try:
+                with redirect_stdout(io.StringIO()):
+                    sec_type = self.daily_reader.get_security_type(str(file_path))
+                if sec_type in self.daily_reader.SECURITY_TYPE:
+                    data = self.daily_reader.get_df(str(file_path))
+                else:
+                    data = self._read_day_file_raw(str(file_path))
+            except Exception:
+                data = self._read_day_file_raw(str(file_path))
+
+        data['code'] = pure_code
+        data['market'] = market
+        data['volume'] = data['volume'] * 100   # 手 → 股
+        data['amount'] = data['amount'] / 10000  # 元 → 万元
+        return data
+
+    @staticmethod
+    def _read_day_file_raw(fname: str) -> pd.DataFrame:
+        """直接解析 .day 二进制文件(用于科创板等 pytdx 不支持的证券类型)。"""
+        rows = []
+        with open(fname, 'rb') as f:
+            content = f.read()
+        # 通达信 .day 记录为 32 字节:date(YYYYMMDD 整数)、open/high/low/close(单位为分)、
+        # amount(float,元)、volume、保留字段
+        record_size = struct.calcsize('<IIIIIfII')
+        for i in range(len(content) // record_size):
+            date, o, h, l, c, amount, volume, _ = struct.unpack(
+                '<IIIIIfII', content[i * record_size:(i + 1) * record_size]
+            )
+            rows.append({
+                'date': pd.to_datetime(str(date), format='%Y%m%d'),
+                'open': o / 100.0,
+                'high': h / 100.0,
+                'low': l / 100.0,
+                'close': c / 100.0,
+                'amount': amount,
+                'volume': volume,
+            })
+        df = pd.DataFrame(rows)
+        if not df.empty:
+            df = df.set_index('date')
+        return df
+
+    @staticmethod
+    def _detect_tnf_record_len(data: bytes, header_offset: int) -> int:
+        """通过搜索相邻股票代码,自动探测 .tnf 文件的记录长度。
+
+        策略一(精确):搜索已知相邻代码对,两者位置之差即 record_len。
+        策略二(通用):扫描所有"偏移0处的6位纯数字",统计最高频间距。
+        两者都失败则返回默认值 314。
+        """
+        import re as _re
+        from collections import Counter as _Counter
+        body = data[header_offset:]
+
+        # 策略一:已知相邻代码对(覆盖三个市场)
+        pairs = [
+            (b'000001', b'000002'),
+            (b'000002', b'000003'),
+            (b'000004', b'000005'),
+            (b'600000', b'600001'),
+            (b'600001', b'600002'),
+        ]
+        for a, b in pairs:
+            pos_a = body.find(a)
+            if pos_a < 0:
+                continue
+            pos_b = body.find(b, pos_a + 1)
+            if pos_b < 0:
+                continue
+            gap = pos_b - pos_a
+            if 100 < gap < 65536:
+                return gap
+
+        # 策略二:扫描前 200KB,找所有"位置对齐"的6位纯数字,统计间距众数
+        scan = body[:200 * 1024]
+        positions = []
+        for m in _re.finditer(rb'\d{6}', scan):
+            positions.append(m.start())
+        gap_counter = _Counter()
+        for i in range(len(positions) - 1):
+            gap = positions[i + 1] - positions[i]
+            if 100 < gap < 65536:
+                gap_counter[gap] += 1
+        if gap_counter:
+            return gap_counter.most_common(1)[0][0]
+
+        return 314
+
+    @staticmethod
+    def _detect_tnf_name_offset(data: bytes, header_offset: int, record_len: int) -> int:
+        """在已知 record_len 的前提下,探测中文名称字段的字节偏移。
+
+        策略:取前若干条记录,在 record 内寻找最长的连续非零字节串(即名称字段)。
+        返回探测到的偏移,失败则返回默认值 23。
+        """
+        body = data[header_offset:]
+        # 取前 10 条记录,统计各偏移处出现非零字节的频率
+        sample_count = min(10, len(body) // record_len)
+        if sample_count == 0:
+            return 23
+        # 对每条记录,找 offset 6 之后第一个非零字节簇的起始位置
+        offsets = []
+        for i in range(sample_count):
+            rec = body[i * record_len: (i + 1) * record_len]
+            # 跳过前 6 字节(代码),找后续第一个非零字节
+            j = 6
+            while j < len(rec) and rec[j] == 0:
+                j += 1
+            if j < len(rec):
+                offsets.append(j)
+        if offsets:
+            from collections import Counter
+            return Counter(offsets).most_common(1)[0][0]
+        return 23
+
+    @staticmethod
+    def _parse_tnf_file(file_path: str) -> dict:
+        """解析 TDX .tnf 文件,返回
{6位code: 名称} 字典。 + 自动探测 record_len 和名称字段偏移,兼容新旧版本通达信。 + 非股票记录(代码含非数字字符)自动跳过。 + """ + result = {} + header_offset = 50 + with open(file_path, 'rb') as f: + data = f.read() + + record_len = TdxDataReader._detect_tnf_record_len(data, header_offset) + name_offset = TdxDataReader._detect_tnf_name_offset(data, header_offset, record_len) + logger.debug(f"[tnf] {file_path}: record_len={record_len}, name_offset={name_offset}") + + body = data[header_offset:] + for i in range(len(body) // record_len): + rec = body[i * record_len: (i + 1) * record_len] + try: + code = rec[0:6].decode('ascii').strip('\x00').strip() + if not code or not code.isdigit(): + continue + name_bytes = rec[name_offset: name_offset + 32].split(b'\x00')[0] + name = name_bytes.decode('gbk', errors='replace').strip() + result[code] = name + except Exception: + continue + return result + + @staticmethod + def _parse_base_dbf(path: str) -> dict: + try: + from dbfread import DBF + except ImportError: + raise ImportError("缺少 dbfread 库,请执行: pip install dbfread") + result = {} + for record in DBF(path, encoding='gbk', load=True): + code = str(record.get('GPDM', '') or '').strip().zfill(6) + ltag = record.get('LTAG') + if code and ltag is not None: + try: + result[code] = float(ltag) + except (TypeError, ValueError): + pass + logger.info(f"base.dbf 读取完成,共 {len(result)} 条记录") + return result + + def _read_stock_names_smb(self) -> dict: + """SMB 模式:下载三个 .tnf 文件到临时文件后解析。 + 返回格式:{'SZ': {...}, 'SH': {...}, 'BJ': {...}} + 直接尝试下载(不做 exists() 预判),避免权限等原因导致误判而静默跳过。 + """ + result = {'SZ': {}, 'SH': {}, 'BJ': {}} + for market, suffix in (('szs', 'SZ'), ('shs', 'SH'), ('bjs', 'BJ')): + unc = self._smb.tnf_unc(market) + try: + tmp_path = self._smb.download_to_tmp(unc, suffix='.tnf') + except Exception as e: + logger.warning(f"SMB 下载 {market}.tnf 失败: {e}") + continue + try: + parsed = self._parse_tnf_file(tmp_path) + logger.debug(f"{market}.tnf 解析完成: {len(parsed)} 条") + result[suffix] = parsed + except Exception as e: + logger.warning(f"SMB 解析 {market}.tnf 出错: {e}") + finally: + os.unlink(tmp_path) + return result + + def _read_base_dbf_smb(self) -> dict: + unc = self._smb.base_dbf_unc + if not self._smb.exists(unc): + raise FileNotFoundError(f"SMB base.dbf 不存在: {unc}") + try: + tmp_path = self._smb.download_to_tmp(unc, suffix='.dbf') + except Exception as e: + raise RuntimeError( + f"base.dbf 无法读取(可能被 TDX 锁定,请关闭 TDX 后重试): {e}" + ) from e + try: + return self._parse_base_dbf(tmp_path) + finally: + os.unlink(tmp_path) + + def _read_gbbq_smb(self) -> pd.DataFrame: + unc = self._smb.gbbq_unc + if not self._smb.exists(unc): + logger.warning(f"SMB 权息文件不存在: {unc},将跳过复权处理") + return pd.DataFrame() + tmp_path = self._smb.download_to_tmp(unc, suffix='') + try: + df = self.gbbq_reader.get_df(tmp_path) + if df.empty: + return pd.DataFrame() + market_prefix = df['market'].map({0: 'sz', 1: 'sh'}) + df['full_code'] = market_prefix + df['code'].astype(str).str.zfill(6) + return df + except Exception as e: + logger.warning(f"SMB 读取权息文件时出错: {e},将跳过复权处理") + return pd.DataFrame() + finally: + os.unlink(tmp_path) + + def _get_stock_list_smb(self) -> list: + codes = [] + for market, pattern, suffix in [ + ('sz', r'^(000|001|002|300|301)\d{3}$', '.SZ'), + ('sh', r'^(60\d{4}|688\d{3})$', '.SH'), + ('bj', r'^(8\d{5}|92\d{4})$', '.BJ'), + ]: + unc_dir = self._smb.lday_dir_unc(market) + files = self._smb.list_files(unc_dir, suffix='.day') + for fname in files: + stem = Path(fname).stem + pure = stem[-6:].zfill(6) + if re.match(pattern, pure): + codes.append(pure + suffix) + 
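+        # 过滤规则与本地 get_stock_list() 保持一致,仅保留 A 股代码段;
+        # 示意:'000001' 命中 ^(000|001|002|300|301)\d{3}$ 被保留,
+        # 而形如 399001 的指数代码不匹配任何模式,自动跳过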
if not codes: + raise FileNotFoundError("SMB 模式下未找到任何股票数据文件") + return codes + + def _read_daily_via_smb(self, unc: str) -> pd.DataFrame: + tmp_path = self._smb.download_to_tmp(unc, suffix='.day') + try: + return self._parse_local_day_file(tmp_path) + finally: + os.unlink(tmp_path) + + def _parse_local_day_file(self, path: str) -> pd.DataFrame: + """解析本地 .day 文件(pytdx 优先,不支持的类型降级到原始二进制解析)。""" + try: + with redirect_stdout(io.StringIO()): + sec_type = self.daily_reader.get_security_type(path) + if sec_type in self.daily_reader.SECURITY_TYPE: + return self.daily_reader.get_df(path) + else: + return self._read_day_file_raw(path) + except Exception: + return self._read_day_file_raw(path) + + def read_daily_data_batch( + self, + stocks_meta: List[Tuple[int, str, str]], + batch_size: int = 200, + smb_workers: int = 16, + ) -> Iterator[Tuple[str, 'pd.DataFrame | Exception']]: + """批量并发读取日线数据(仅 SMB 模式)。 + + Parameters + ---------- + stocks_meta: + List of (market, code, db_code),与 read_daily_data() 参数对应。 + - market: 0=深圳 1=上海 2=北京 + - code: 带市场前缀的代码,如 sz000001 + - db_code: 数据库代码,如 000001.SZ + + Yields + ------ + (db_code, DataFrame) 或 (db_code, Exception) + """ + if self._smb is None: + raise RuntimeError("read_daily_data_batch 仅支持 SMB 模式") + + market_map = {0: 'sz', 1: 'sh', 2: 'bj'} + + for batch_start in range(0, len(stocks_meta), batch_size): + batch = stocks_meta[batch_start:batch_start + batch_size] + + # 构建 unc_path → (market, pure_code, db_code) 映射 + unc_map: dict = {} + for market, code, db_code in batch: + folder = market_map[market] + pure_code = code[-6:] + filename = f"{folder}{pure_code}.day" + unc = self._smb.day_file_unc(folder, filename) + unc_map[unc] = (market, pure_code, db_code) + + tmp_dir = tempfile.mkdtemp(prefix='tdx2db_batch_') + try: + local_map = self._smb.download_batch_to_dir( + list(unc_map.keys()), tmp_dir, max_workers=smb_workers + ) + for unc, (market, pure_code, db_code) in unc_map.items(): + if unc not in local_map: + yield db_code, FileNotFoundError(f"SMB 下载失败: {unc}") + continue + try: + data = self._parse_local_day_file(local_map[unc]) + data['code'] = pure_code + data['market'] = market + data['volume'] = data['volume'] * 100 # 手 → 股 + data['amount'] = data['amount'] / 10000 # 元 → 万元 + yield db_code, data + except Exception as e: + yield db_code, e + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) diff --git a/src/tdx2db/smb_accessor.py b/src/tdx2db/smb_accessor.py new file mode 100644 index 0000000..76ba22f --- /dev/null +++ b/src/tdx2db/smb_accessor.py @@ -0,0 +1,176 @@ +"""SMB 网络文件访问封装。 + +依赖:smbprotocol(pip install smbprotocol) +""" +import os +import tempfile +from typing import Dict, List, Optional + +import smbclient +import smbclient.path + +from .logger import logger + + +class SmbAccessor: + """封装对 SMB 共享目录的只读访问。 + + UNC 路径格式:\\\\host\\share\\tdx_path\\vipdoc\\... 
+ """ + + def __init__( + self, + host: str, + share: str, + tdx_path: str = '', + username: Optional[str] = None, + password: Optional[str] = None, + port: int = 445, + ) -> None: + self.host = host + self.share = share.strip('\\/') + self._tdx_rel = tdx_path.strip('\\/') + self._username = username + self._password = password + self._port = port + self._registered = False + + # ── 上下文管理器 ────────────────────────────────────────────────────────── + + def __enter__(self) -> 'SmbAccessor': + self._register() + return self + + def __exit__(self, *_) -> None: + self._unregister() + + def _register(self) -> None: + if not self._registered: + smbclient.register_session( + self.host, + username=self._username, + password=self._password, + port=self._port, + ) + self._registered = True + logger.debug(f"SMB 会话已建立: {self.host}:{self._port}") + + def _unregister(self) -> None: + if self._registered: + try: + smbclient.reset_connection_cache() + except Exception: + pass + self._registered = False + + # ── 路径构建 ────────────────────────────────────────────────────────────── + + def _unc(self, *parts: str) -> str: + """构建 UNC 路径字符串。 + + 示例(tdx_path='TDX'): + _unc('vipdoc', 'sz', 'lday') → '\\\\host\\share\\TDX\\vipdoc\\sz\\lday' + 示例(tdx_path=''): + _unc('vipdoc') → '\\\\host\\share\\vipdoc' + """ + segments = [self.host, self.share] + if self._tdx_rel: + segments.append(self._tdx_rel) + segments.extend(p.strip('\\/') for p in parts if p) + return '\\\\' + '\\'.join(segments) + + @property + def vipdoc_unc(self) -> str: + return self._unc('vipdoc') + + @property + def gbbq_unc(self) -> str: + return self._unc('T0002', 'hq_cache', 'gbbq') + + @property + def base_dbf_unc(self) -> str: + return self._unc('T0002', 'hq_cache', 'base.dbf') + + def lday_dir_unc(self, market: str) -> str: + return self._unc('vipdoc', market, 'lday') + + def day_file_unc(self, market: str, filename: str) -> str: + return self._unc('vipdoc', market, 'lday', filename) + + def tnf_unc(self, market: str) -> str: + """返回 .tnf 股票名称文件的 UNC 路径。market: 'szs' / 'shs' / 'bjs'""" + return self._unc('T0002', 'hq_cache', f'{market}.tnf') + + # ── 核心 I/O 操作 ───────────────────────────────────────────────────────── + + def exists(self, unc_path: str) -> bool: + try: + self._register() + return smbclient.path.exists(unc_path) + except Exception: + return False + + def list_files(self, unc_dir: str, suffix: str = '') -> List[str]: + """列出目录下的文件名(不含路径),可按后缀过滤。""" + try: + entries = smbclient.listdir(unc_dir) + if suffix: + return [e for e in entries if e.endswith(suffix)] + return entries + except Exception as e: + logger.warning(f"SMB 列目录失败 {unc_dir}: {e}") + return [] + + def read_bytes(self, unc_path: str) -> bytes: + with smbclient.open_file(unc_path, mode='rb', share_access='rw') as f: + return f.read() + + def download_to_tmp(self, unc_path: str, suffix: str = '.day') -> str: + """将远程文件下载到本地临时文件,返回临时文件路径字符串。 + + 调用方负责删除临时文件(使用 try/finally)。 + """ + data = self.read_bytes(unc_path) + tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) + try: + tmp.write(data) + tmp.flush() + tmp.close() + return tmp.name + except Exception: + tmp.close() + os.unlink(tmp.name) + raise + + def download_batch_to_dir( + self, + unc_paths: List[str], + dest_dir: str, + max_workers: int = 16, + ) -> Dict[str, str]: + """并发下载多个 UNC 文件到 dest_dir 目录,返回 {unc_path: local_path}。 + + 下载失败的文件不出现在返回字典中(异常已记录到日志)。 + """ + import concurrent.futures + import hashlib + + def _download_one(unc: str): + name = hashlib.md5(unc.encode()).hexdigest() + '.day' + local = 
os.path.join(dest_dir, name) + data = self.read_bytes(unc) + with open(local, 'wb') as f: + f.write(data) + return unc, local + + result: Dict[str, str] = {} + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool: + futures = {pool.submit(_download_one, unc): unc for unc in unc_paths} + for fut in concurrent.futures.as_completed(futures): + unc = futures[fut] + try: + _, local = fut.result() + result[unc] = local + except Exception as e: + logger.warning(f"SMB 批量下载失败 {unc}: {e}") + return result diff --git a/src/tdx2db/storage.py b/src/tdx2db/storage.py new file mode 100644 index 0000000..9459229 --- /dev/null +++ b/src/tdx2db/storage.py @@ -0,0 +1,225 @@ +import datetime +import os +from pathlib import Path +from typing import Optional, Tuple + +import pandas as pd +from sqlalchemy import create_engine, Column, Integer, Float, String, UniqueConstraint, text, DateTime +from sqlalchemy.orm import declarative_base, sessionmaker + +from .config import config +from .logger import logger + +Base = declarative_base() + + +class DailyData(Base): + __tablename__ = 'daily_data' + __table_args__ = (UniqueConstraint('stock_code', 'date'),) + + id = Column(Integer, primary_key=True) + stock_code = Column(String(12), index=True) + market = Column(Integer) + date = Column(String(8), index=True) # YYYYMMDD 字符串 + open = Column(Float) + high = Column(Float) + low = Column(Float) + close = Column(Float) + volume = Column(Float) + amount = Column(Float) + adj_factor = Column(Float) + turnover_rate = Column(Float) # 换手率(%),暂时为 NULL + + +class StockInfo(Base): + __tablename__ = 'stock_info' + __table_args__ = (UniqueConstraint('stock_code'),) + + stock_code = Column(String(12), primary_key=True) # 000001.SZ + stock_name = Column(String(50)) + + +class KlineStatistics(Base): + __tablename__ = 'kline_statistics' + + id = Column(Integer, primary_key=True) + stock_count = Column(Integer) + total_rows = Column(Integer) + sync_time = Column(DateTime) + + +_VALID_TABLES = frozenset({'daily_data', 'stock_info'}) + + +class DataStorage: + def __init__(self, db_url: Optional[str] = None) -> None: + self.db_url = db_url or config.database_url + self._db_type = self.db_url.split('://')[0].split('+')[0] + self.engine = create_engine(self.db_url) + Base.metadata.create_all(self.engine) + self.Session = sessionmaker(bind=self.engine) + + def get_latest_date_by_code(self, code: str) -> Optional[str]: + """返回指定股票在 daily_data 中最新的 YYYYMMDD 字符串日期,无数据返回 None。""" + try: + with self.engine.connect() as conn: + row = conn.execute( + text("SELECT MAX(date) FROM daily_data WHERE stock_code = :code"), + {"code": code} + ).fetchone() + return row[0] if row and row[0] is not None else None + except Exception as e: + logger.debug(f"查询 {code} 最新日期出错: {e}") + return None + + def get_all_latest_dates(self) -> dict: + """一次查询返回所有股票最新日期 {code: YYYYMMDD str}。""" + try: + with self.engine.connect() as conn: + rows = conn.execute( + text("SELECT stock_code, MAX(date) FROM daily_data GROUP BY stock_code") + ).fetchall() + return {r[0]: r[1] for r in rows if r[1] is not None} + except Exception as e: + logger.debug(f"查询所有股票最新日期出错: {e}") + return {} + + def delete_stock_data(self, code: str) -> None: + """删除某只股票全部日线记录(除权后全量重写前调用)。""" + try: + with self.engine.connect() as conn: + conn.execute( + text("DELETE FROM daily_data WHERE stock_code = :code"), + {"code": code} + ) + conn.commit() + except Exception as e: + logger.error(f"删除 {code} 数据出错: {e}") + + def save_incremental( + self, + df: pd.DataFrame, + table_name: str, + 
conflict_columns: Tuple[str, ...] = ('stock_code', 'date'), + batch_size: int = 10000 + ) -> int: + """增量保存,跳过重复记录(ON CONFLICT DO NOTHING / INSERT OR IGNORE / INSERT IGNORE)。""" + if table_name not in _VALID_TABLES: + raise ValueError(f"不允许写入的表名: {table_name}") + if df.empty: + return 0 + + df_to_save = df.copy() + if isinstance(df_to_save.index, pd.DatetimeIndex) or df_to_save.index.name in ('date', 'datetime'): + df_to_save = df_to_save.reset_index(drop=True) + + columns = list(df_to_save.columns) + columns_str = ', '.join(columns) + total_rows = len(df_to_save) + + try: + if self._db_type == 'postgresql': + self._save_incremental_pg(df_to_save, columns, columns_str, table_name, conflict_columns, batch_size) + else: + placeholders = ', '.join([f':{c}' for c in columns]) + if self._db_type == 'mysql': + sql = text(f"INSERT IGNORE INTO {table_name} ({columns_str}) VALUES ({placeholders})") + else: # sqlite + sql = text(f"INSERT OR IGNORE INTO {table_name} ({columns_str}) VALUES ({placeholders})") + + with self.engine.connect() as conn: + for i in range(0, total_rows, batch_size): + batch = df_to_save.iloc[i:i + batch_size].astype(object).where( + df_to_save.iloc[i:i + batch_size].notna(), None + ) + records = batch.to_dict('records') + for rec in records: + for k, v in rec.items(): + if isinstance(v, pd.Timestamp): + rec[k] = v.to_pydatetime() + elif v is pd.NaT: + rec[k] = None + conn.execute(sql, records) + conn.commit() + + logger.debug(f"增量保存完成: {total_rows} 条 → {table_name}(重复已跳过)") + return total_rows + except Exception as e: + logger.error(f"增量保存到 {table_name} 出错: {e}") + return 0 + + def _save_incremental_pg(self, df, columns, columns_str, table_name, conflict_columns, batch_size): + from psycopg2.extras import execute_values + conflict_str = ', '.join(conflict_columns) + sql = f"INSERT INTO {table_name} ({columns_str}) VALUES %s ON CONFLICT ({conflict_str}) DO NOTHING" + df_clean = df.astype(object).where(df.notna(), None) + values = list(df_clean.itertuples(index=False, name=None)) + raw_conn = self.engine.raw_connection() + try: + cur = raw_conn.cursor() + execute_values(cur, sql, values, page_size=batch_size) + raw_conn.commit() + finally: + raw_conn.close() + + def save_stock_info(self, df: pd.DataFrame) -> bool: + """保存股票列表到 stock_info 表(upsert,名称可更新)。""" + if df.empty: + return False + try: + with self.engine.connect() as conn: + if self._db_type == 'postgresql': + sql = text(""" + INSERT INTO stock_info (stock_code, stock_name) + VALUES (:stock_code, :stock_name) + ON CONFLICT (stock_code) DO UPDATE SET stock_name = EXCLUDED.stock_name + """) + elif self._db_type == 'mysql': + sql = text(""" + INSERT INTO stock_info (stock_code, stock_name) + VALUES (:stock_code, :stock_name) + ON DUPLICATE KEY UPDATE stock_name = VALUES(stock_name) + """) + else: # sqlite + sql = text(""" + INSERT INTO stock_info (stock_code, stock_name) + VALUES (:stock_code, :stock_name) + ON CONFLICT (stock_code) DO UPDATE SET stock_name = excluded.stock_name + """) + conn.execute(sql, df.to_dict('records')) + conn.commit() + logger.debug(f"stock_info upsert 完成: {len(df)} 条") + return True + except Exception as e: + logger.error(f"保存 stock_info 出错: {e}") + return False + + def save_sync_statistics(self, stock_count: int) -> None: + """同步完成后写入一条统计记录到 kline_statistics。""" + try: + with self.engine.connect() as conn: + row = conn.execute(text("SELECT COUNT(*) FROM daily_data")).fetchone() + total_rows = row[0] if row else 0 + conn.execute( + text( + "INSERT INTO kline_statistics (stock_count, total_rows, 
sync_time) " + "VALUES (:stock_count, :total_rows, :sync_time)" + ), + { + "stock_count": stock_count, + "total_rows": total_rows, + "sync_time": datetime.datetime.now(), + } + ) + conn.commit() + logger.info(f"统计已记录: stock_count={stock_count}, total_rows={total_rows}") + except Exception as e: + logger.error(f"写入 kline_statistics 出错: {e}") + + def save_to_csv(self, df: pd.DataFrame, filename: str, csv_path: Optional[str] = None) -> Optional[str]: + path = Path(csv_path or config.csv_output_path) + os.makedirs(path, exist_ok=True) + file_path = path / f"{filename}.csv" + df.to_csv(file_path, index=False, encoding='utf-8') + logger.info(f"数据已保存到: {file_path}") + return str(file_path) diff --git a/tests/test_daily.py b/tests/test_daily.py new file mode 100644 index 0000000..8052771 --- /dev/null +++ b/tests/test_daily.py @@ -0,0 +1,310 @@ +""" +日线数据同步测试套件(使用 SQLite 内存库 + mock TDX reader) + +测试用例: +1. test_full_sync_one_month - 全量所有股票,1个月日期范围 +2. test_single_stock_one_year - 指定股票,最近1年 +3. test_forward_adj_price - 前复权价格计算正确性 +4. test_incremental_update - 增量更新无重复,有除权时旧数据被替换 +""" + +import threading +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + +from src.tdx2db.processor import DataProcessor +from src.tdx2db.storage import DataStorage + + +# ─── 测试数据工厂 ──────────────────────────────────────────────────────────── + +def make_daily_df(code: str, market: int, start_date: str, periods: int) -> pd.DataFrame: + """生成假日线数据,date 列为 DatetimeIndex(模拟 reader 返回格式)。""" + dates = pd.bdate_range(start=start_date, periods=periods) + df = pd.DataFrame({ + 'open': [10.0] * periods, + 'high': [11.0] * periods, + 'low': [9.0] * periods, + 'close': [10.5] * periods, + 'volume': [1e6] * periods, + 'amount': [1e7] * periods, + 'code': [code[-6:]] * periods, + 'market': [market] * periods, + }, index=dates) + df.index.name = 'date' + return df + + +def make_gbbq_empty() -> pd.DataFrame: + return pd.DataFrame() + + +def make_gbbq_with_event(full_code: str, ex_date_int: int) -> pd.DataFrame: + """构造一条除权记录(category=1,送股 10%)。""" + if full_code.startswith('sh'): + market_val = 1 + elif full_code.startswith('bj'): + market_val = 2 + else: + market_val = 0 + return pd.DataFrame([{ + 'market': market_val, + 'code': int(full_code[2:]), + 'datetime': ex_date_int, + 'category': 1, + 'hongli_panqianliutong': 0, + 'peigujia_qianzongguben': 0, + 'songgu_qianzongguben': 1.0, # 每10股送1股 → songgu/10 = 0.1 + 'peigu_houzongguben': 0, + 'full_code': full_code, + }]) + + +# ─── 测试用例 ──────────────────────────────────────────────────────────────── + +class TestFullSyncOneMonth: + """全量获取所有股票1个月日线数据。""" + + def test_records_written_and_date_format(self, tmp_path): + db_url = f"sqlite:///{tmp_path}/test.db" + storage = DataStorage(db_url=db_url) + processor = DataProcessor() + + # 模拟3只股票,每只20个交易日(约1个月) + stocks = [ + ('sz000001', 0), ('sz000002', 0), ('sh600000', 1) + ] + gbbq = make_gbbq_empty() + + for code, market in stocks: + df = make_daily_df(code, market, '2024-03-01', 20) + df = df.reset_index() + processed = processor.process_daily_data(df, gbbq=gbbq, adj_type='none') + storage.save_incremental(processed, 'daily_data', conflict_columns=('stock_code', 'date')) + + # 验证 + with storage.engine.connect() as conn: + from sqlalchemy import text + rows = conn.execute(text("SELECT COUNT(*) FROM daily_data")).fetchone() + assert rows[0] == 60, f"期望60条,实际{rows[0]}" + + # date 列应为 YYYYMMDD 字符串 + sample = conn.execute(text("SELECT date FROM daily_data LIMIT 1")).fetchone() + assert isinstance(sample[0], 
str), f"date 应为字符串,实际类型: {type(sample[0])}" + assert '20240301' <= sample[0] <= '20241231' + + +class TestSingleStockOneYear: + """指定股票最近1年日线数据。""" + + def test_date_range_correct(self, tmp_path): + db_url = f"sqlite:///{tmp_path}/test.db" + storage = DataStorage(db_url=db_url) + processor = DataProcessor() + + # 生成约250个交易日(1年) + df = make_daily_df('sz000001', 0, '2023-04-01', 250) + df = df.reset_index() + gbbq = make_gbbq_empty() + processed = processor.process_daily_data(df, gbbq=gbbq, adj_type='none') + + # 按日期过滤:只取 20240101 之后 + filtered = processor.filter_data(processed, start_date=20240101) + storage.save_incremental(filtered, 'daily_data', conflict_columns=('stock_code', 'date')) + + with storage.engine.connect() as conn: + from sqlalchemy import text + rows = conn.execute( + text("SELECT MIN(date), MAX(date), COUNT(*) FROM daily_data WHERE stock_code='000001.SZ'") + ).fetchone() + min_date, max_date, count = rows + assert min_date >= '20240101', f"最小日期 {min_date} 应 >= 20240101" + assert count > 0 + + +class TestForwardAdjPrice: + """前复权价格计算正确性:除权日前后价格应连续(无跳空)。""" + + def test_price_continuity_across_ex_date(self): + processor = DataProcessor() + + # 构造数据:2024-01-01 ~ 2024-03-31,共约60个交易日 + # 除权日设为 2024-02-01(送股10%,songgu=1.0 即每10股送1股) + df = make_daily_df('sz000001', 0, '2024-01-02', 60) + df = df.reset_index() + + # 所有原始收盘价均为 10.5 + gbbq = make_gbbq_with_event('sz000001', 20240201) + processed = processor.process_daily_data(df, gbbq=gbbq, adj_type='forward') + + ex_date = '20240201' + before = processed[processed['date'] < ex_date] + on_or_after = processed[processed['date'] >= ex_date] + + assert not before.empty, "除权日前应有数据" + assert not on_or_after.empty, "除权日后应有数据" + + # 前复权后,除权日前的价格应被调低(adj_factor < 1) + # 送股10%:factor = prev_close / (prev_close * 1.1) ≈ 0.909 + # 除权日前 close = 10.5 * 0.909 ≈ 9.545 + # 除权日后 close = 10.5(原始价格不变) + adj_close_before = before['close'].iloc[-1] + raw_close_after = on_or_after['close'].iloc[0] + + # 验证复权因子已应用(除权日前价格应低于原始价格) + assert adj_close_before < 10.5, f"前复权后除权日前收盘价应 < 10.5,实际 {adj_close_before}" + assert abs(raw_close_after - 10.5) < 0.01, f"除权日后收盘价应保持 10.5,实际 {raw_close_after}" + + # 验证 adj_factor 列存在 + assert 'adj_factor' in processed.columns + + def test_no_adj_factor_without_gbbq(self): + processor = DataProcessor() + df = make_daily_df('sz000001', 0, '2024-01-02', 10).reset_index() + processed = processor.process_daily_data(df, gbbq=None, adj_type='forward') + assert 'adj_factor' in processed.columns + assert (processed['adj_factor'] == 1.0).all() + + +class TestIncrementalUpdate: + """增量更新:无重复行;有除权时旧数据被替换。""" + + def test_no_duplicates_on_second_sync(self, tmp_path): + db_url = f"sqlite:///{tmp_path}/test.db" + storage = DataStorage(db_url=db_url) + processor = DataProcessor() + gbbq = make_gbbq_empty() + + # 第一次同步:20个交易日 + df1 = make_daily_df('sz000001', 0, '2024-01-02', 20).reset_index() + p1 = processor.process_daily_data(df1, gbbq=gbbq, adj_type='none') + storage.save_incremental(p1, 'daily_data', conflict_columns=('stock_code', 'date')) + + # 第二次同步:同样的数据(模拟重复运行) + storage.save_incremental(p1, 'daily_data', conflict_columns=('stock_code', 'date')) + + with storage.engine.connect() as conn: + from sqlalchemy import text + count = conn.execute( + text("SELECT COUNT(*) FROM daily_data WHERE stock_code='000001.SZ'") + ).fetchone()[0] + assert count == 20, f"重复同步后应仍为20条,实际{count}" + + def test_incremental_appends_new_records(self, tmp_path): + db_url = f"sqlite:///{tmp_path}/test.db" + storage = DataStorage(db_url=db_url) + processor = 
DataProcessor() + gbbq = make_gbbq_empty() + + # 第一次:前10个交易日 + df1 = make_daily_df('sz000001', 0, '2024-01-02', 10).reset_index() + p1 = processor.process_daily_data(df1, gbbq=gbbq, adj_type='none') + storage.save_incremental(p1, 'daily_data', conflict_columns=('stock_code', 'date')) + + last_date = storage.get_latest_date_by_code('000001.SZ') + assert last_date is not None + + # 第二次:后10个交易日(增量,只取 last_date 之后) + df2 = make_daily_df('sz000001', 0, '2024-01-02', 25).reset_index() + p2 = processor.process_daily_data(df2, gbbq=gbbq, adj_type='none') + p2_new = p2[p2['date'] > last_date] + storage.save_incremental(p2_new, 'daily_data', conflict_columns=('stock_code', 'date')) + + with storage.engine.connect() as conn: + from sqlalchemy import text + count = conn.execute( + text("SELECT COUNT(*) FROM daily_data WHERE stock_code='000001.SZ'") + ).fetchone()[0] + assert count == 25, f"增量后应为25条,实际{count}" + + def test_full_refresh_on_ex_rights(self, tmp_path): + """有除权事件时,旧数据应被删除并重写(复权价格更新)。""" + db_url = f"sqlite:///{tmp_path}/test.db" + storage = DataStorage(db_url=db_url) + processor = DataProcessor() + + # 第一次同步:无复权 + df = make_daily_df('sz000001', 0, '2024-01-02', 20).reset_index() + gbbq_empty = make_gbbq_empty() + p1 = processor.process_daily_data(df, gbbq=gbbq_empty, adj_type='none') + storage.save_incremental(p1, 'daily_data', conflict_columns=('stock_code', 'date')) + + # 模拟发现除权事件 → 删除旧数据 + 重写前复权数据 + gbbq = make_gbbq_with_event('sz000001', 20240115) + p2 = processor.process_daily_data(df, gbbq=gbbq, adj_type='forward') + + storage.delete_stock_data('000001.SZ') + storage.save_incremental(p2, 'daily_data', conflict_columns=('stock_code', 'date')) + + with storage.engine.connect() as conn: + from sqlalchemy import text + count = conn.execute( + text("SELECT COUNT(*) FROM daily_data WHERE stock_code='000001.SZ'") + ).fetchone()[0] + # 重写后记录数应与原始相同 + assert count == 20, f"全量重写后应为20条,实际{count}" + + # 除权日前的价格应已被调整(< 10.5) + adj_row = conn.execute( + text("SELECT close FROM daily_data WHERE stock_code='000001.SZ' AND date < '20240115' ORDER BY date DESC LIMIT 1") + ).fetchone() + if adj_row: + assert adj_row[0] < 10.5, f"前复权后除权日前收盘价应 < 10.5,实际 {adj_row[0]}" + + +# ─── 三市场测试 ────────────────────────────────────────────────────────────── + +class TestMultiMarket: + """验证三个市场(sz/sh/bj)日线数据均能正确写入,且 market 列值准确。""" + + def test_all_three_markets_sync(self, tmp_path): + """sz=0, sh=1, bj=2 三市场股票同步后记录数和 market 值均正确。""" + db_url = f"sqlite:///{tmp_path}/test.db" + storage = DataStorage(db_url=db_url) + processor = DataProcessor() + gbbq = make_gbbq_empty() + + stocks = [ + ('sz000001', 0), # 深圳主板 + ('sh600000', 1), # 上海主板 + ('bj920001', 2), # 北交所(92开头新股) + ] + for code, market in stocks: + df = make_daily_df(code, market, '2024-03-01', 10).reset_index() + processed = processor.process_daily_data(df, gbbq=gbbq, adj_type='none') + storage.save_incremental(processed, 'daily_data', conflict_columns=('stock_code', 'date')) + + with storage.engine.connect() as conn: + from sqlalchemy import text + for db_code, expected_market in [('000001.SZ', 0), ('600000.SH', 1), ('920001.BJ', 2)]: + row = conn.execute(text( + f"SELECT market, COUNT(*) FROM daily_data WHERE stock_code='{db_code}' GROUP BY market" + )).fetchone() + assert row is not None, f"{db_code} 无数据" + assert row[0] == expected_market, f"{db_code} market 应为 {expected_market},实际 {row[0]}" + assert row[1] == 10, f"{db_code} 应有10条记录,实际 {row[1]}" + + def test_bj_ex_rights_refresh(self, tmp_path): + """北交所股票有除权事件时,旧数据应被删除并重写。""" + db_url = 
f"sqlite:///{tmp_path}/test.db" + storage = DataStorage(db_url=db_url) + processor = DataProcessor() + + df = make_daily_df('bj920001', 2, '2024-01-02', 20).reset_index() + gbbq = make_gbbq_with_event('bj920001', 20240115) + + p1 = processor.process_daily_data(df, gbbq=make_gbbq_empty(), adj_type='none') + storage.save_incremental(p1, 'daily_data', conflict_columns=('stock_code', 'date')) + + storage.delete_stock_data('920001.BJ') + p2 = processor.process_daily_data(df, gbbq=gbbq, adj_type='forward') + storage.save_incremental(p2, 'daily_data', conflict_columns=('stock_code', 'date')) + + with storage.engine.connect() as conn: + from sqlalchemy import text + count = conn.execute(text( + "SELECT COUNT(*) FROM daily_data WHERE stock_code='920001.BJ'" + )).fetchone()[0] + assert count == 20 diff --git a/tests/test_download.py b/tests/test_download.py new file mode 100644 index 0000000..a228b32 --- /dev/null +++ b/tests/test_download.py @@ -0,0 +1,410 @@ +""" +联网下载功能测试套件(全部 mock,不发起真实网络请求) + +测试用例: +1. TestDownloader - downloader.py 的下载/解压/上下文管理逻辑 +2. TestReaderVipdocPath - TdxDataReader 新增的 vipdoc_path 参数 +3. TestDownloadCommand - CLI download 子命令的完整流程 +""" + +import io +import struct +import zipfile +from pathlib import Path +from unittest.mock import MagicMock, patch, call + +import pandas as pd +import pytest + +from src.tdx2db.downloader import ( + DEFAULT_DOWNLOAD_URL, + download_and_extract, + download_zip, + extract_zip, +) +from src.tdx2db.reader import TdxDataReader + + +# ─── 工具函数 ──────────────────────────────────────────────────────────────── + +def _make_day_bytes(n_records: int = 5) -> bytes: + """生成假的 .day 二进制内容(n 条记录)。""" + fmt = ' Path: + """在 tmp_path 创建假的 hsjday.zip,内含 {sh,sz}/lday/*.day 文件(与实际 ZIP 结构一致)。""" + zip_path = tmp_path / 'hsjday.zip' + day_content = _make_day_bytes(5) + + market_codes = { + 'sh': ['sh600000', 'sh600001'], + 'sz': ['sz000001', 'sz000002'], + 'bj': ['bj920001'], + } + + with zipfile.ZipFile(zip_path, 'w') as zf: + for market in markets: + codes = market_codes.get(market, [])[:codes_per_market] + for code in codes: + arc_name = f'{market}/lday/{code}.day' + zf.writestr(arc_name, day_content) + + return zip_path + + +# ─── TestDownloader ───────────────────────────────────────────────────────── + +class TestDownloadZip: + """download_zip 函数的单元测试。""" + + def test_downloads_file_successfully(self, tmp_path): + """正常下载时文件应被写入目标路径。""" + fake_content = b'PK\x03\x04' + b'\x00' * 100 # 假 ZIP 头 + + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.headers = {'Content-Length': str(len(fake_content))} + mock_resp.iter_content = MagicMock(return_value=[fake_content]) + + dest = tmp_path / 'test.zip' + with patch('requests.get', return_value=mock_resp): + with patch('src.tdx2db.downloader.config') as mock_cfg: + mock_cfg.use_tqdm = False + download_zip('http://fake-url/test.zip', dest) + + assert dest.exists() + assert dest.read_bytes() == fake_content + + def test_cleans_up_on_error(self, tmp_path): + """下载失败时应删除不完整文件。""" + import requests as req + + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock(side_effect=req.HTTPError("404")) + mock_resp.headers = {} + + dest = tmp_path / 'test.zip' + with patch('requests.get', return_value=mock_resp): + with pytest.raises(RuntimeError, match="下载失败"): + download_zip('http://fake-url/test.zip', dest) + + assert not dest.exists() + + def test_raises_on_network_error(self, tmp_path): + """网络异常时应抛出 RuntimeError。""" + import requests as req + + with patch('requests.get', 
side_effect=req.ConnectionError("no route")): + with pytest.raises(RuntimeError, match="下载失败"): + download_zip('http://fake-url/test.zip', tmp_path / 'x.zip') + + +class TestExtractZip: + """extract_zip 函数的单元测试。""" + + def test_extracts_and_returns_vipdoc_path(self, tmp_path): + """正常解压时应返回包含市场目录的路径。""" + zip_path = _make_fake_zip(tmp_path, markets=['sh', 'sz']) + extract_dir = tmp_path / 'out' + extract_dir.mkdir() + + result = extract_zip(zip_path, extract_dir) + + assert result.exists() + assert (result / 'sh' / 'lday').exists() + + def test_raises_on_invalid_zip(self, tmp_path): + """非法 ZIP 文件应抛出 ValueError。""" + bad_zip = tmp_path / 'bad.zip' + bad_zip.write_bytes(b'not a zip file at all') + + with pytest.raises(ValueError, match="ZIP"): + extract_zip(bad_zip, tmp_path / 'out') + + def test_raises_on_missing_market_dirs(self, tmp_path): + """ZIP 内无 sh/sz/bj 目录时应抛出 FileNotFoundError。""" + zip_path = tmp_path / 'no_market.zip' + with zipfile.ZipFile(zip_path, 'w') as zf: + zf.writestr('some_other_dir/file.txt', 'content') + + extract_dir = tmp_path / 'out' + extract_dir.mkdir() + + with pytest.raises(FileNotFoundError, match="sh/sz/bj"): + extract_zip(zip_path, extract_dir) + + +class TestDownloadAndExtract: + """download_and_extract 上下文管理器测试。""" + + def test_yields_vipdoc_path_and_cleans_up(self): + """正常执行:yield vipdoc_path,结束后临时目录被清理。""" + import tempfile, shutil + src_dir = Path(tempfile.mkdtemp(prefix='tdx_test_src_')) + try: + zip_path = _make_fake_zip(src_dir, markets=['sh']) + captured_vipdoc = None + + def fake_download(url, dest_path, **kwargs): + shutil.copy(zip_path, dest_path) + + with patch('src.tdx2db.downloader.download_zip', side_effect=fake_download): + with patch('src.tdx2db.downloader.config') as mock_cfg: + mock_cfg.download_url = '' + mock_cfg.use_tqdm = False + with download_and_extract(url='http://fake/', keep_tmp=False) as vipdoc_path: + captured_vipdoc = vipdoc_path + assert vipdoc_path.exists() + + # 退出后 vipdoc_path 所在的 tmp_dir 应已删除 + assert not captured_vipdoc.exists() + finally: + shutil.rmtree(src_dir, ignore_errors=True) + + def test_keep_tmp_preserves_directory(self): + """keep_tmp=True 时 tmp_dir 应保留。""" + import tempfile, shutil + src_dir = Path(tempfile.mkdtemp(prefix='tdx_test_src_')) + captured_tmp_dir = None + try: + zip_path = _make_fake_zip(src_dir, markets=['sh']) + + # 拦截 tempfile.mkdtemp 以记录 tmp_dir + real_mkdtemp = tempfile.mkdtemp + def fake_mkdtemp(**kwargs): + d = real_mkdtemp(**kwargs) + nonlocal captured_tmp_dir + captured_tmp_dir = Path(d) + return d + + def fake_download(url, dest_path, **kwargs): + shutil.copy(zip_path, dest_path) + + with patch('src.tdx2db.downloader.download_zip', side_effect=fake_download): + with patch('src.tdx2db.downloader.tempfile.mkdtemp', side_effect=fake_mkdtemp): + with patch('src.tdx2db.downloader.config') as mock_cfg: + mock_cfg.download_url = '' + mock_cfg.use_tqdm = False + with download_and_extract(url='http://fake/', keep_tmp=True): + pass + + assert captured_tmp_dir is not None + assert captured_tmp_dir.exists() + finally: + shutil.rmtree(src_dir, ignore_errors=True) + if captured_tmp_dir: + shutil.rmtree(captured_tmp_dir, ignore_errors=True) + + def test_uses_default_url_when_none_given(self): + """未传 url 时应使用 DEFAULT_DOWNLOAD_URL。""" + import tempfile, shutil + src_dir = Path(tempfile.mkdtemp(prefix='tdx_test_src_')) + called_urls = [] + try: + zip_path = _make_fake_zip(src_dir, markets=['sh']) + + def fake_download(url, dest_path, **kwargs): + called_urls.append(url) + shutil.copy(zip_path, 
dest_path) + + with patch('src.tdx2db.downloader.download_zip', side_effect=fake_download): + with patch('src.tdx2db.downloader.config') as mock_cfg: + mock_cfg.download_url = '' + mock_cfg.use_tqdm = False + with download_and_extract(url=None, keep_tmp=False): + pass + finally: + shutil.rmtree(src_dir, ignore_errors=True) + + assert called_urls and called_urls[0] == DEFAULT_DOWNLOAD_URL + + def test_config_url_overrides_default(self): + """config.download_url 非空时应优先于默认 URL。""" + import tempfile, shutil + src_dir = Path(tempfile.mkdtemp(prefix='tdx_test_src_')) + called_urls = [] + custom_url = 'http://my-custom-host/hsjday.zip' + try: + zip_path = _make_fake_zip(src_dir, markets=['sh']) + + def fake_download(url, dest_path, **kwargs): + called_urls.append(url) + shutil.copy(zip_path, dest_path) + + with patch('src.tdx2db.downloader.download_zip', side_effect=fake_download): + with patch('src.tdx2db.downloader.config') as mock_cfg: + mock_cfg.download_url = custom_url + mock_cfg.use_tqdm = False + with download_and_extract(url=None, keep_tmp=False): + pass + finally: + shutil.rmtree(src_dir, ignore_errors=True) + + assert called_urls and called_urls[0] == custom_url + + +# ─── TestReaderVipdocPath ──────────────────────────────────────────────────── + +class TestReaderVipdocPath: + """TdxDataReader 的 vipdoc_path 参数测试。""" + + def _write_day_files(self, base: Path, market: str, codes: list) -> None: + """在 base/{market}/lday/ 下写入假 .day 文件。""" + lday = base / market / 'lday' + lday.mkdir(parents=True, exist_ok=True) + for code in codes: + (lday / f'{code}.day').write_bytes(_make_day_bytes(5)) + + def test_vipdoc_path_get_stock_list(self, tmp_path): + """vipdoc_path 模式下 get_stock_list 应正确扫描三个市场目录。""" + vipdoc = tmp_path / 'hsjday' + self._write_day_files(vipdoc, 'sz', ['sz000001', 'sz000002']) + self._write_day_files(vipdoc, 'sh', ['sh600000']) + + reader = TdxDataReader(vipdoc_path=str(vipdoc)) + stocks = reader.get_stock_list() + + codes = set(stocks) + assert '000001.SZ' in codes + assert '000002.SZ' in codes + assert '600000.SH' in codes + + def test_vipdoc_path_read_daily_data(self, tmp_path): + """vipdoc_path 模式下 read_daily_data 应返回正确的 DataFrame。""" + vipdoc = tmp_path / 'hsjday' + self._write_day_files(vipdoc, 'sh', ['sh600000']) + + reader = TdxDataReader(vipdoc_path=str(vipdoc)) + df = reader.read_daily_data(market=1, code='sh600000') + + assert not df.empty + assert 'code' in df.columns + assert df['code'].iloc[0] == '600000' + assert 'open' in df.columns + + def test_vipdoc_path_read_gbbq_returns_empty(self, tmp_path): + """vipdoc_path 模式下 read_gbbq 应返回空 DataFrame(不报错)。""" + vipdoc = tmp_path / 'hsjday' + vipdoc.mkdir() + + reader = TdxDataReader(vipdoc_path=str(vipdoc)) + gbbq = reader.read_gbbq() + + assert isinstance(gbbq, pd.DataFrame) + assert gbbq.empty + + def test_vipdoc_path_raises_if_not_exists(self, tmp_path): + """vipdoc_path 不存在时应抛出 FileNotFoundError。""" + with pytest.raises(FileNotFoundError): + TdxDataReader(vipdoc_path=str(tmp_path / 'nonexistent')) + + def test_original_mode_still_works(self, tmp_path): + """不传 vipdoc_path 时,原有 tdx_path 模式应正常工作。""" + # 构造假的 TDX 目录结构 + tdx_path = tmp_path / 'tdx' + vipdoc = tdx_path / 'vipdoc' + (vipdoc / 'sz' / 'lday').mkdir(parents=True) + (vipdoc / 'sz' / 'lday' / 'sz000001.day').write_bytes(_make_day_bytes(3)) + + reader = TdxDataReader(tdx_path=str(tdx_path)) + assert reader.tdx_path == tdx_path + assert reader._vipdoc_path == tdx_path / 'vipdoc' + + +# ─── TestDownloadCommand 
──────────────────────────────────────────────────────
+
+class TestDownloadCommand:
+    """CLI download 子命令的集成测试(mock 网络,使用 tmp_path 下的临时 SQLite 库)。"""
+
+    def _write_day_files(self, base: Path, market: str, codes: list) -> None:
+        lday = base / market / 'lday'
+        lday.mkdir(parents=True, exist_ok=True)
+        for code in codes:
+            (lday / f'{code}.day').write_bytes(_make_day_bytes(5))
+
+    def test_download_command_imports_data(self, tmp_path):
+        """download 命令完整流程:下载→解压→读取→入库。"""
+        from src.tdx2db.storage import DataStorage
+
+        # 准备假的 vipdoc 目录
+        vipdoc = tmp_path / 'hsjday'
+        self._write_day_files(vipdoc, 'sz', ['sz000001'])
+        self._write_day_files(vipdoc, 'sh', ['sh600000'])
+
+        db_path = tmp_path / 'test.db'
+
+        # mock download_and_extract 使其直接 yield 假 vipdoc 目录
+        from contextlib import contextmanager
+
+        @contextmanager
+        def fake_download_and_extract(url=None, keep_tmp=False):
+            yield vipdoc
+
+        with patch('src.tdx2db.cli.download_and_extract', fake_download_and_extract):
+            with patch('src.tdx2db.cli.DataStorage') as MockStorage:
+                storage_instance = DataStorage(db_url=f'sqlite:///{db_path}')
+                MockStorage.return_value = storage_instance
+                with patch('src.tdx2db.cli.config') as mock_cfg:
+                    mock_cfg.tdx_path = ''
+                    mock_cfg.use_tqdm = False
+                    mock_cfg.db_batch_size = 10000
+                    mock_cfg.download_url = ''
+
+                    # 直接调用同步逻辑(等价于 download 命令内部的编排)
+                    from src.tdx2db.cli import sync_all_daily
+                    from src.tdx2db.processor import DataProcessor
+
+                    reader = TdxDataReader(vipdoc_path=str(vipdoc))
+                    processor = DataProcessor()
+                    gbbq = pd.DataFrame()
+                    stats = sync_all_daily(reader, processor, storage_instance, gbbq,
+                                           adj_type='none', incremental=False)
+
+        # 验证两只股票都有数据
+        with storage_instance.engine.connect() as conn:
+            from sqlalchemy import text
+            count = conn.execute(text("SELECT COUNT(*) FROM daily_data")).fetchone()[0]
+            assert count > 0, "download 后应有数据写入数据库"
+
+    def test_download_help_shows_subcommand(self):
+        """parse_args 应包含 download 子命令。"""
+        from src.tdx2db.cli import parse_args
+
+        with patch('sys.argv', ['tdx2db', 'download', '--help']):
+            with pytest.raises(SystemExit) as exc:
+                parse_args()
+            assert exc.value.code == 0
+
+    def test_download_args_parsed_correctly(self):
+        """download 子命令参数应正确解析。"""
+        from src.tdx2db.cli import parse_args
+
+        with patch('sys.argv', ['tdx2db', 'download', '--url', 'http://myhost/data.zip',
+                                '--adj', 'none', '--no-clean']):
+            args = parse_args()
+
+        assert args.command == 'download'
+        assert args.url == 'http://myhost/data.zip'
+        assert args.adj == 'none'
+        assert args.no_clean is True
+
+    def test_download_default_args(self):
+        """download 子命令默认参数:adj=forward, no_clean=False。"""
+        from src.tdx2db.cli import parse_args
+
+        with patch('sys.argv', ['tdx2db', 'download']):
+            args = parse_args()
+
+        assert args.command == 'download'
+        assert args.adj == 'forward'
+        assert args.no_clean is False
+        assert args.url is None
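到此,download 相关的下载、解压与 CLI 编排都有测试覆盖;把这些接口串起来,downloader 的典型用法大致如下(示意代码,依据上方测试体现的接口写出,打印内容仅作演示):

```python
from src.tdx2db.downloader import download_and_extract
from src.tdx2db.reader import TdxDataReader

# 上下文管理器:下载 ZIP → 解压到临时目录 → yield vipdoc 路径 → 退出时清理临时目录
# url=None 时优先取 config.download_url,为空则回退到 DEFAULT_DOWNLOAD_URL
with download_and_extract(url=None, keep_tmp=False) as vipdoc_path:
    reader = TdxDataReader(vipdoc_path=str(vipdoc_path))
    print(reader.get_stock_list()[:5])   # 形如 ['000001.SZ', ...]
# with 结束后临时目录已删除;keep_tmp=True 可保留以便调试
```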
diff --git a/tests/test_float_capital.py b/tests/test_float_capital.py
new file mode 100644
index 0000000..21081b8
--- /dev/null
+++ b/tests/test_float_capital.py
@@ -0,0 +1,114 @@
+"""
+float_cap_map 构建和换手率计算测试
+"""
+import pandas as pd
+
+from src.tdx2db.processor import DataProcessor
+
+
+def make_gbbq_event(full_code: str, date_int: int, category: int, value: float) -> pd.DataFrame:
+    market_val = 1 if full_code.startswith('sh') else 0
+    return pd.DataFrame([{
+        'market': market_val,
+        'code': int(full_code[2:]),
+        'datetime': date_int,
+        'category': category,
+        'hongli_panqianliutong': value,
+        'songgu_qianzongguben': value if category == 1 else 0,
+        'peigu_houzongguben': 0,
+        'peigujia_qianzongguben': 0,
+        'full_code': full_code,
+    }])
+
+
+class TestBuildFloatCapitalMap:
+
+    def test_cat1_songgu(self):
+        """category==1 送股:历史股本 = 当前 / (1 + ratio)"""
+        # 当前 1100 万股,曾经 10 送 1(ratio=0.1)
+        base_caps = {'000001': 1100.0}
+        gbbq = make_gbbq_event('sz000001', 20240101, 1, 1.0)  # songgu=1.0 → ratio=0.1
+        result = DataProcessor.build_float_capital_map(base_caps, gbbq)
+
+        assert 'sz000001' in result
+        snapshots = dict(result['sz000001'])
+        # 事件日期 20240101 起生效的股本 = 1100(送股后)
+        assert abs(snapshots[20240101] - 1100.0) < 0.01
+        # 兜底(date=0)= 1100 / 1.1 = 1000
+        assert abs(snapshots[0] - 1000.0) < 0.01
+
+    def test_cat12_jiejin(self):
+        """category==12 解禁:历史股本 = 当前 - N"""
+        base_caps = {'000001': 1100.0}
+        gbbq = make_gbbq_event('sz000001', 20240101, 12, 100.0)
+        result = DataProcessor.build_float_capital_map(base_caps, gbbq)
+
+        snapshots = dict(result['sz000001'])
+        assert abs(snapshots[20240101] - 1100.0) < 0.01
+        assert abs(snapshots[0] - 1000.0) < 0.01
+
+    def test_cat10_zhuxiao(self):
+        """category==10 回购注销:历史股本 = 当前 + N(注销后股本减少,回溯加回)"""
+        base_caps = {'000001': 900.0}
+        gbbq = make_gbbq_event('sz000001', 20240101, 10, 100.0)
+        result = DataProcessor.build_float_capital_map(base_caps, gbbq)
+
+        snapshots = dict(result['sz000001'])
+        assert abs(snapshots[20240101] - 900.0) < 0.01
+        assert abs(snapshots[0] - 1000.0) < 0.01
+
+    def test_empty_base_caps(self):
+        gbbq = make_gbbq_event('sz000001', 20240101, 1, 1.0)
+        assert DataProcessor.build_float_capital_map({}, gbbq) == {}
+
+    def test_empty_gbbq(self):
+        assert DataProcessor.build_float_capital_map({'000001': 1000.0}, pd.DataFrame()) == {}
+
+    def test_code_not_in_base_caps(self):
+        """gbbq 有记录但 base_caps 没有该股票,应跳过"""
+        base_caps = {'000002': 500.0}
+        gbbq = make_gbbq_event('sz000001', 20240101, 1, 1.0)
+        result = DataProcessor.build_float_capital_map(base_caps, gbbq)
+        assert 'sz000001' not in result
+
+
+class TestCalcTurnoverRateWithMap:
+
+    def _make_df(self, code='000001', market=0, dates=None, volume=1e6):
+        if dates is None:
+            dates = ['20240101', '20240102', '20240103']
+        return pd.DataFrame({
+            'date': dates,
+            'volume': [volume] * len(dates),
+            'code': [code] * len(dates),
+            'market': [market] * len(dates),
+        })
+
+    def test_priority_path_used(self):
+        """float_cap_map 存在时走优先路径,换手率应有值"""
+        df = self._make_df(volume=1000.0)
+        # 流通股本 1000 万股 = 1000 * 10000 = 1e7 股
+        # 换手率(%) = volume(手) × 100(股/手) / 流通股本(股) × 100
+        #           = 1000 × 100 / 1e7 × 100 = 1.0
+        float_cap_map = {'sz000001': [(0, 1000.0)]}
+        gbbq = pd.DataFrame()
+
+        result = DataProcessor._calc_turnover_rate(df, gbbq, float_cap_map=float_cap_map)
+        assert result.notna().all()
+        assert abs(result.iloc[0] - 1.0) < 0.001
+
+    def test_fallback_when_not_in_map(self):
+        """full_code 不在 float_cap_map 中时,降级到 gbbq category==5"""
+        df = self._make_df()
+        float_cap_map = {'sz000002': [(0, 1000.0)]}  # 不含 sz000001
+        gbbq = pd.DataFrame()  # category==5 也无数据
+
+        result = DataProcessor._calc_turnover_rate(df, gbbq, float_cap_map=float_cap_map)
+        assert result.isna().all()
+
+    def test_fallback_when_map_is_none(self):
+        """float_cap_map=None 时走原有逻辑"""
+        df = self._make_df()
+        gbbq = pd.DataFrame()
+        result = DataProcessor._calc_turnover_rate(df, gbbq, float_cap_map=None)
+        assert result.isna().all()
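把上面两类测试连起来看:换手率的算术口径是 1000 手 × 100 股/手 ÷ 1e7 股 × 100% = 1%;而 float_cap_map 中每只股票的 (生效日, 流通股本) 快照按日期取值的方式,大致如下(示意代码,bisect 查找为假定实现,仅说明快照列表的语义):

```python
import bisect

def cap_on(snapshots: list[tuple[int, float]], date_int: int) -> float:
    """snapshots 为按日期升序的 (生效日, 流通股本/万股) 列表,date=0 为兜底。
    返回 date_int 当日生效的股本:取生效日 <= date_int 的最后一条。"""
    dates = [d for d, _ in snapshots]
    i = bisect.bisect_right(dates, date_int) - 1
    return snapshots[i][1]

snapshots = [(0, 1000.0), (20240101, 1100.0)]   # 10 送 1 之前 1000 万股,之后 1100 万股
assert cap_on(snapshots, 20231231) == 1000.0    # 除权日前用历史股本
assert cap_on(snapshots, 20240102) == 1100.0    # 除权日起用送股后的股本
```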
diff --git a/tests/test_smb.py b/tests/test_smb.py
new file mode 100644
index 0000000..5ab8ac6
--- /dev/null
+++ b/tests/test_smb.py
@@ -0,0 +1,459 @@
+"""
+SMB 访问模式测试套件(全部 mock,不发起真实网络请求)
+
+测试用例:
+1. TestSmbAccessorUncPaths - UNC 路径构建逻辑
+2. TestSmbAccessorIO - I/O 方法(register_session, exists, list_files, download_to_tmp)
+3. TestReaderSmbMode - TdxDataReader 在 SMB 模式下的三个核心方法
+4. TestCliSmbInit - CLI SMB 参数解析与初始化
+"""
+
+import os
+import struct
+import tempfile
+from pathlib import Path
+from unittest.mock import MagicMock, patch, call
+
+import pandas as pd
+import pytest
+
+from src.tdx2db.smb_accessor import SmbAccessor
+from src.tdx2db.reader import TdxDataReader
+
+
+# ─── 工具函数 ────────────────────────────────────────────────────────────────
+
+def _make_day_bytes(n: int = 3) -> bytes:
+    # 通达信 .day 记录:date、open、high、low、close(分)、amount(float)、volume、保留位,共 32 字节
+    fmt = '<IIIIIfII'
+    return b''.join(
+        struct.pack(fmt, 20240102 + i, 1000, 1100, 900, 1050, 1.0e7, 1000000, 0)
+        for i in range(n)
+    )
+
+
+# ─── TestReaderSmbMode ───────────────────────────────────────────────────────
+
+class TestReaderSmbMode:
+    """TdxDataReader 在 SMB 模式下的核心方法测试。"""
+
+    def _make_smb(self) -> MagicMock:
+        """构造 mock 的 SmbAccessor(tnf_unc 的参数约定为依据本文件用法的假定)。"""
+        smb = MagicMock()
+        # tnf_unc 按文件名拼出 UNC 路径,与各测试内 tnf_data 的键保持一致
+        smb.tnf_unc = MagicMock(
+            side_effect=lambda name: rf'\\host\share\TDX\T0002\hq_cache\{name}'
+        )
+        return smb
+
+    def test_read_stock_names_smb(self):
+        """SMB 模式下 read_stock_names 应分别解析三个市场的 .tnf 文件。"""
+
+        def _make_tnf_bytes(code: str, name: str) -> bytes:
+            """构造最小合法 .tnf 文件:50字节头 + 1条314字节记录。"""
+            header = b'\x00' * 50
+            record = bytearray(314)
+            record[0:6] = code.encode('ascii').ljust(6, b'\x00')
+            name_gbk = name.encode('gbk')
+            record[23:23 + len(name_gbk)] = name_gbk
+            return header + bytes(record)
+
+        smb = self._make_smb()
+        # 三个市场各有一只股票
+        tnf_data = {
+            r'\\host\share\TDX\T0002\hq_cache\szs.tnf': _make_tnf_bytes('000001', '平安银行'),
+            r'\\host\share\TDX\T0002\hq_cache\shs.tnf': _make_tnf_bytes('600000', '浦发银行'),
+            r'\\host\share\TDX\T0002\hq_cache\bjs.tnf': _make_tnf_bytes('920001', '北交所A'),
+        }
+
+        def _download_side_effect(unc, suffix='.day'):
+            data = tnf_data.get(unc, b'\x00' * 50)  # 未知文件给空头部
+            tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
+            tmp.write(data)
+            tmp.flush()
+            tmp.close()
+            return tmp.name
+
+        smb.download_to_tmp = MagicMock(side_effect=_download_side_effect)
+
+        reader = TdxDataReader(smb=smb)
+        names = reader.read_stock_names()
+
+        # 返回格式为 {'SZ': {...}, 'SH': {...}, 'BJ': {...}},按市场分组
+        assert names['SZ'].get('000001') == '平安银行'
+        assert names['SH'].get('600000') == '浦发银行'
+        assert names['BJ'].get('920001') == '北交所A'
+        # tnf_unc 应被调用三次
+        assert smb.tnf_unc.call_count == 3
+
+    def test_read_stock_names_smb_download_failure(self):
+        """SMB 模式下某个 .tnf 下载失败时,其他文件仍正常解析。"""
+
+        def _make_tnf_bytes(code: str, name: str) -> bytes:
+            header = b'\x00' * 50
+            record = bytearray(314)
+            record[0:6] = code.encode('ascii').ljust(6, b'\x00')
+            name_gbk = name.encode('gbk')
+            record[23:23 + len(name_gbk)] = name_gbk
+            return header + bytes(record)
+
+        smb = self._make_smb()
+        tnf_data = {
+            r'\\host\share\TDX\T0002\hq_cache\szs.tnf': _make_tnf_bytes('000001', '平安银行'),
+            r'\\host\share\TDX\T0002\hq_cache\bjs.tnf': _make_tnf_bytes('920001', '北交所A'),
+        }
+
+        def _download_side_effect(unc, suffix='.day'):
+            if 'shs' in unc:
+                raise OSError("SMB 连接超时")  # 模拟 shs.tnf 下载失败
+            data = tnf_data.get(unc, b'\x00' * 50)
+            tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
+            tmp.write(data)
+            tmp.flush()
+            tmp.close()
+            return tmp.name
+
+        smb.download_to_tmp = MagicMock(side_effect=_download_side_effect)
+
+        reader = TdxDataReader(smb=smb)
+        names = reader.read_stock_names()
+
+        # szs 和 bjs 应正常解析,shs 失败但不影响整体
+        assert names['SZ'].get('000001') == '平安银行'
+        assert names['BJ'].get('920001') == '北交所A'
+        assert '600000' not in names['SH']
+
+
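下面的 CLI 测试围绕 `_create_reader` 的分支行为展开;按其断言,可以把该函数的分支逻辑概括成如下骨架(示意代码,SmbAccessor 的构造参数名为假定,并非 cli.py 原文):

```python
from src.tdx2db.reader import TdxDataReader
from src.tdx2db.smb_accessor import SmbAccessor

def create_reader_sketch(cfg):
    """按 TestCliSmbInit 的断言勾勒 _create_reader 的两个分支(示意)。"""
    if cfg.smb_enabled:
        if not cfg.smb_host:
            # 对应 test_create_reader_smb_missing_host_raises
            raise ValueError("SMB 模式需要配置 SMB_HOST")
        smb = SmbAccessor(                     # 构造参数名为假定
            host=cfg.smb_host, share=cfg.smb_share,
            user=cfg.smb_user, password=cfg.smb_password,
            tdx_path=cfg.smb_tdx_path, port=cfg.smb_port,
        )
        return TdxDataReader(smb=smb), smb     # 对应 assert reader._smb is smb_acc
    # 本地模式:读 cfg.tdx_path,SMB 访问器为 None
    return TdxDataReader(tdx_path=cfg.tdx_path), None
```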
+# ─── TestCliSmbInit ──────────────────────────────────────────────────────────
+
+class TestCliSmbInit:
+    """CLI SMB 参数解析与初始化。"""
+
+    def test_smb_args_parsed_correctly(self):
+        from src.tdx2db.cli import parse_args
+        with patch('sys.argv', [
+            'tdx2db', '--smb-host', '192.168.1.1',
+            '--smb-share', 'tdx_share',
+            '--smb-user', 'admin',
+            '--smb-password', 'secret',
+            '--smb-tdx-path', 'TDX',
+            '--smb-port', '445',
+            'sync',
+        ]):
+            args = parse_args()
+
+        assert args.smb_host == '192.168.1.1'
+        assert args.smb_share == 'tdx_share'
+        assert args.smb_user == 'admin'
+        assert args.smb_password == 'secret'
+        assert args.smb_tdx_path == 'TDX'
+        assert args.smb_port == 445
+
+    def test_update_config_sets_smb_enabled(self):
+        from src.tdx2db.cli import update_config, parse_args
+        from src.tdx2db.config import Config
+
+        with patch('sys.argv', ['tdx2db', '--smb-host', '10.0.0.1', '--smb-share', 'share', 'sync']):
+            args = parse_args()
+
+        cfg = Config()
+        with patch('src.tdx2db.cli.config', cfg):
+            update_config(args)
+
+        assert cfg.smb_enabled is True
+        assert cfg.smb_host == '10.0.0.1'
+        assert cfg.smb_share == 'share'
+
+    def test_create_reader_smb_mode(self):
+        """_create_reader 在 SMB 模式下应创建 SmbAccessor 并注入 TdxDataReader。"""
+        from src.tdx2db.cli import _create_reader
+        from src.tdx2db.config import Config
+
+        cfg = Config()
+        cfg.smb_enabled = True
+        cfg.smb_host = '192.168.1.1'
+        cfg.smb_share = 'share'
+        cfg.smb_user = 'user'
+        cfg.smb_password = 'pass'
+        cfg.smb_tdx_path = 'TDX'
+        cfg.smb_port = 445
+
+        with patch('src.tdx2db.cli.config', cfg):
+            with patch('src.tdx2db.smb_accessor.smbclient.register_session'):
+                reader, smb_acc = _create_reader()
+
+        assert smb_acc is not None
+        assert reader._smb is smb_acc
+
+    def test_create_reader_smb_missing_host_raises(self):
+        """SMB 模式下缺少 smb_host 应抛出 ValueError。"""
+        from src.tdx2db.cli import _create_reader
+        from src.tdx2db.config import Config
+
+        cfg = Config()
+        cfg.smb_enabled = True
+        cfg.smb_host = ''
+        cfg.smb_share = 'share'
+
+        with patch('src.tdx2db.cli.config', cfg):
+            with pytest.raises(ValueError, match="SMB_HOST"):
+                _create_reader()
+
+    def test_create_reader_local_mode(self, tmp_path):
+        """非 SMB 模式下应创建本地 TdxDataReader。"""
+        from src.tdx2db.cli import _create_reader
+        from src.tdx2db.config import Config
+
+        tdx_path = tmp_path / 'tdx'
+        (tdx_path / 'vipdoc').mkdir(parents=True)
+
+        cfg = Config()
+        cfg.smb_enabled = False
+        cfg.tdx_path = str(tdx_path)
+
+        with patch('src.tdx2db.cli.config', cfg), \
+             patch('src.tdx2db.reader.config', cfg):
+            reader, smb_acc = _create_reader()
+
+        assert smb_acc is None
+        assert reader._smb is None
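附:这些测试共同勾勒出包级 API 的最小用法——不经 CLI,直接组合 reader / processor / storage 完成一次单只股票的增量同步(示意代码,路径与参数为演示占位;完整编排见 cli.py 的 sync_all_daily):

```python
from src.tdx2db.processor import DataProcessor
from src.tdx2db.reader import TdxDataReader
from src.tdx2db.storage import DataStorage

reader = TdxDataReader(tdx_path=r'C:\new_tdx')     # 或 vipdoc_path=... / smb=...
processor = DataProcessor()
storage = DataStorage(db_url='sqlite:///tdx_data.db')

gbbq = reader.read_gbbq()                          # 除权除息表(vipdoc_path 模式下为空表)
latest = storage.get_all_latest_dates()            # 一次查出所有股票的最新日期

df = reader.read_daily_data(market=0, code='sz000001')
# 与测试用法一致:reader 返回 DatetimeIndex,先 reset_index 再处理
out = processor.process_daily_data(df.reset_index(), gbbq=gbbq, adj_type='forward')
last = latest.get('000001.SZ')
if last:                                           # 增量:只写最新日期之后的记录
    out = out[out['date'] > last]
storage.save_incremental(out, 'daily_data', conflict_columns=('stock_code', 'date'))
storage.save_sync_statistics(stock_count=1)
```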