From 435039acb7df2b824f4cf1bb203a7558e9e13510 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Sat, 4 Apr 2026 17:37:11 +0800 Subject: [PATCH 01/26] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20SQLite=20?= =?UTF-8?q?=E5=86=99=E5=85=A5=E6=97=B6=20pd.Timestamp=20=E7=BB=91=E5=AE=9A?= =?UTF-8?q?=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit sqlite3 使用精确类型匹配查找适配器,不识别 pd.Timestamp(虽是 datetime.datetime 子类),导致 INSERT OR IGNORE 报 InterfaceError。 之前的修复尝试在 DataFrame 上原地替换,但 pandas 会把赋值的 datetime 对象重新推断为 datetime64,to_dict() 后又变回 Timestamp, 转换失效。 正确做法:先调用 to_dict('records') 得到 dict 列表,再遍历 dict 将 pd.Timestamp → datetime.pydatetime()、pd.NaT → None,此时 Python dict 不会触发 pandas 类型推断,转换结果可靠传入 sqlite3。 Co-Authored-By: Claude Sonnet 4.6 --- src/storage.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/storage.py b/src/storage.py index 2b29aed..05bed16 100644 --- a/src/storage.py +++ b/src/storage.py @@ -342,8 +342,22 @@ def save_incremental( with self.engine.connect() as conn: for i in range(0, total_rows, batch_size): - batch_df = df_to_save.iloc[i:i + batch_size] - conn.execute(sql, batch_df.to_dict('records')) + batch_df = df_to_save.iloc[i:i + batch_size].astype(object).where( + df_to_save.iloc[i:i + batch_size].notna(), None + ) + # 先转为 dict 列表,再转换 Timestamp → Python datetime + # 必须在 to_dict() 之后处理:若在 DataFrame 上赋值,pandas 会把 + # datetime 对象重新推断为 datetime64,to_dict() 又变回 Timestamp, + # 而 sqlite3 使用精确类型匹配,不识别 pd.Timestamp(datetime 子类) + records = batch_df.to_dict('records') + for record in records: + for key in record: + val = record[key] + if isinstance(val, pd.Timestamp): + record[key] = val.to_pydatetime() + elif val is pd.NaT: + record[key] = None + conn.execute(sql, records) conn.commit() logger.info(f"增量保存完成: 共处理 {total_rows} 条到表 {table_name}(重复数据已跳过)") From 4ac78d30dc58c948b1eb240a1659cf94badf3fab Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Sat, 4 Apr 2026 21:21:15 +0800 Subject: [PATCH 02/26] =?UTF-8?q?feat:=20=E6=97=A5=E7=BA=BF=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BC=95=E5=85=A5=E5=89=8D=E5=A4=8D=E6=9D=83=E9=80=BB?= =?UTF-8?q?=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - reader.py: 新增 read_gbbq() 读取通达信本地权息文件;新增 _read_day_file_raw() 绕过 pytdx 类型检查支持科创板(688xxx) - processor.py: 新增 apply_forward_adj() 前复权算法,process_daily_data() 新增 gbbq 参数 - storage.py: DailyData 模型新增 adj_factor 字段 - cli.py: daily 命令支持纯6位代码自动推断市场(6/688→sh,其余→sz);修复 reset_index 条件漏掉 date 索引;修复 code 过滤格式不匹配 - scripts/migrate_add_adj_factor.sql: 已有数据库迁移脚本 gbbq 字段单位:hongli/10 = 每股红利(元),songgu/10 = 实际送股比例 Co-Authored-By: Claude Sonnet 4.6 --- scripts/migrate_add_adj_factor.sql | 16 ++++++ src/cli.py | 39 ++++++++++---- src/processor.py | 87 +++++++++++++++++++++++++++++- src/reader.py | 57 ++++++++++++++++++-- src/storage.py | 1 + 5 files changed, 185 insertions(+), 15 deletions(-) create mode 100644 scripts/migrate_add_adj_factor.sql diff --git a/scripts/migrate_add_adj_factor.sql b/scripts/migrate_add_adj_factor.sql new file mode 100644 index 0000000..0b70130 --- /dev/null +++ b/scripts/migrate_add_adj_factor.sql @@ -0,0 +1,16 @@ +-- 迁移脚本:为 daily_data 表新增前复权因子列 +-- 执行一次即可,若列已存在则不报错(IF NOT EXISTS) +-- +-- 使用方式: +-- psql -d -f scripts/migrate_add_adj_factor.sql +-- mysql -u -p < scripts/migrate_add_adj_factor.sql +-- +-- 说明: +-- adj_factor = 1.0 表示该行是原始价格(无复权或当天最新价) +-- adj_factor < 1.0 表示历史数据已向前调整(价格已乘以该因子) + +-- PostgreSQL +ALTER TABLE daily_data ADD COLUMN IF NOT EXISTS adj_factor FLOAT DEFAULT 1.0; + +-- MySQL(不支持 IF NOT EXISTS,如报列已存在错误可忽略) +-- ALTER TABLE daily_data ADD COLUMN adj_factor FLOAT DEFAULT 1.0; diff --git a/src/cli.py b/src/cli.py index 5acfa5b..a73891f 100644 --- a/src/cli.py +++ b/src/cli.py @@ -102,6 +102,7 @@ def sync_all_daily_data( processor: DataProcessor, storage: DataStorage, start_date: Optional[str] = None, + gbbq: Optional[pd.DataFrame] = None, ) -> bool: """逐股票流式同步日线数据,避免全量加载到内存""" try: @@ -116,12 +117,12 @@ def sync_all_daily_data( market = 1 if code.startswith('sh') else 0 try: data = reader.read_daily_data(market, code) - if isinstance(data.index, pd.DatetimeIndex) or data.index.name == 'datetime': + if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('datetime', 'date'): data = data.reset_index() if data.empty: continue - processed = processor.process_daily_data(data) + processed = processor.process_daily_data(data, gbbq=gbbq) filtered = processor.filter_data(processed, start_date=start_date) if filtered.empty: continue @@ -321,9 +322,21 @@ def main() -> int: logger.info("数据库中没有数据,将获取所有数据") # 获取日线数据 - if args.code and args.market is not None: - # 获取单只股票的日线数据 - data = reader.read_daily_data(args.market, args.code) + if args.code: + # 自动推断市场:支持 sz/sh 前缀,或纯6位代码(6/688开头→sh,其余→sz) + code = args.code + if args.market is not None: + market = args.market + elif code.startswith('sh'): + market = 1 + elif code.startswith('sz'): + market = 0 + else: + pure = code[-6:] + market = 1 if pure.startswith(('6', '688')) else 0 + data = reader.read_daily_data(market, code) + if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('datetime', 'date'): + data = data.reset_index() else: # 获取所有股票的日线数据 data = reader.read_all_daily_data() @@ -336,14 +349,19 @@ def main() -> int: # 处理数据 processor = DataProcessor() - processed_data = processor.process_daily_data(data) - - # 根据日期筛选 + gbbq = reader.read_gbbq() + processed_data = processor.process_daily_data(data, gbbq=gbbq) + + # 根据日期筛选(code 过滤用纯6位格式,与 DataFrame 中的 code 列一致) + filter_code = None + if args.code: + c = args.code + filter_code = [c[-6:] if len(c) > 6 else c] filtered_data = processor.filter_data( processed_data, start_date=start_date, end_date=args.end_date, - codes=[args.code] if args.code else None + codes=filter_code ) if filtered_data.empty: @@ -448,7 +466,8 @@ def main() -> int: start_date = (latest + timedelta(days=1)).strftime('%Y-%m-%d') logger.info(f"日线起始日期: {start_date}") - success = sync_all_daily_data(reader, processor, storage, start_date) + gbbq = reader.read_gbbq() + success = sync_all_daily_data(reader, processor, storage, start_date, gbbq=gbbq) if not success: logger.error("同步日线数据时出错") has_error = True diff --git a/src/processor.py b/src/processor.py index 5fb2e3f..d3f18c2 100644 --- a/src/processor.py +++ b/src/processor.py @@ -105,11 +105,92 @@ def _calculate_ma(df: pd.DataFrame) -> pd.DataFrame: return df @staticmethod - def process_daily_data(df: pd.DataFrame) -> pd.DataFrame: + def apply_forward_adj(df: pd.DataFrame, gbbq: pd.DataFrame) -> pd.DataFrame: + """计算前复权价格并原地更新 open/high/low/close,新增 adj_factor 列 + + 前复权算法:对每个除权日,将该日之前的全部历史数据乘以当次复权因子, + 使价格序列在除权日前后保持连续。 + + Args: + df: 单只股票的日线 DataFrame,含 code(6位), market, date, open/high/low/close + gbbq: 全量权息 DataFrame(来自 reader.read_gbbq()),含 full_code, category 等 + + Returns: + 含 adj_factor 列、价格已前复权的 DataFrame + """ + if gbbq.empty or df.empty: + df = df.copy() + df['adj_factor'] = 1.0 + return df + + # 构造带前缀的完整代码(market 0=sz, 1=sh) + market_val = df['market'].iloc[0] + prefix = 'sz' if market_val == 0 else 'sh' + pure_code = str(df['code'].iloc[0]).zfill(6) + full_code = prefix + pure_code + + events = gbbq[gbbq['full_code'] == full_code].copy() + + df = df.copy() + if events.empty: + df['adj_factor'] = 1.0 + return df + + # 转换除权日期为 Timestamp(与 df['date'] 的 datetime64[ns] 类型一致) + events['ex_date'] = pd.to_datetime( + events['datetime'].astype(str).str[:8], format='%Y%m%d' + ) + # 仅处理 category==1(除权除息),且限制在 df 数据范围内(避免本地数据不完整时引入无效历史因子) + data_start = df['date'].min() + events = events[(events['category'] == 1) & (events['ex_date'] > data_start)].sort_values('ex_date') + + df = df.sort_values('date').copy() + df['adj_factor'] = 1.0 + + for _, ev in events.iterrows(): + ex_date = ev['ex_date'] # pd.Timestamp + # gbbq 字段均以"每10股"为单位: + # songgu: 每10股送股数,实际送股比例 = songgu / 10 + # hongli: 每10股红利(元),实际每股红利 = hongli / 10 + # peigujia: 配股价(元),单位无需换算 + # peigu: 每10股配股数,实际配股比例 = peigu / 10 + songgu = float(ev.get('songgu_qianzongguben', 0) or 0) / 10.0 + hongli = float(ev.get('hongli_panqianliutong', 0) or 0) / 10.0 + peigujia = float(ev.get('peigujia_qianzongguben', 0) or 0) + peigu = float(ev.get('peigu_houzongguben', 0) or 0) / 10.0 + + before = df[df['date'] < ex_date] + if before.empty: + continue + prev_close = float(before['close'].iloc[-1]) + if prev_close <= 0: + continue + + denominator = prev_close * (1 + songgu + peigu) + if denominator <= 0: + continue + factor = (prev_close - hongli + peigujia * peigu) / denominator + + if factor <= 0 or factor > 2: + logger.warning(f"{full_code} 除权日 {ex_date} 复权因子异常({factor:.4f}),已跳过") + continue + + mask = df['date'] < ex_date + df.loc[mask, 'adj_factor'] *= factor + + # 应用前复权:历史价格 * adj_factor,当日及以后为原始价格(factor=1) + for col in ['open', 'high', 'low', 'close']: + df[col] = (df[col] * df['adj_factor']).round(3) + + return df + + @staticmethod + def process_daily_data(df: pd.DataFrame, gbbq: pd.DataFrame = None) -> pd.DataFrame: """处理日线数据 Args: df: 原始日线数据DataFrame + gbbq: 权息数据(来自 reader.read_gbbq()),传入时执行前复权,None 时跳过复权 Returns: DataFrame: 处理后的数据 @@ -141,6 +222,10 @@ def process_daily_data(df: pd.DataFrame) -> pd.DataFrame: # 数据质量校验 processed_df = DataProcessor._validate_ohlcv(processed_df) + # 前复权处理(在均线计算之前,确保均线基于复权价格) + if gbbq is not None and not gbbq.empty and 'date' in processed_df.columns: + processed_df = DataProcessor.apply_forward_adj(processed_df, gbbq) + # 计算均线指标 if all(col in processed_df.columns for col in ['close', 'volume']): processed_df = DataProcessor._calculate_ma(processed_df) diff --git a/src/reader.py b/src/reader.py index f81cd70..ec1b5ea 100644 --- a/src/reader.py +++ b/src/reader.py @@ -13,7 +13,7 @@ import pandas as pd from pytdx.reader import TdxDailyBarReader, TdxMinBarReader, TdxLCMinBarReader -from pytdx.reader import BlockReader +from pytdx.reader import BlockReader, GbbqReader from tqdm import tqdm from .config import config @@ -42,6 +42,31 @@ def __init__(self, tdx_path: Optional[str] = None) -> None: self.min_reader = TdxMinBarReader() self.lc_min_reader = TdxLCMinBarReader() self.block_reader = BlockReader() + self.gbbq_reader = GbbqReader() + + def read_gbbq(self) -> pd.DataFrame: + """读取通达信本地权息文件(gbbq),返回全量权息 DataFrame + + Returns: + DataFrame: 权息数据,列包括 market, code, datetime(YYYYMMDD整数), category, + hongli_panqianliutong, peigujia_qianzongguben, songgu_qianzongguben, + peigu_houzongguben, full_code(sz/sh+6位代码) + 若文件不存在则返回空 DataFrame(复权逻辑会优雅降级) + """ + gbbq_path = self.tdx_path / 'T0002' / 'hq_cache' / 'gbbq' + if not gbbq_path.exists(): + logger.warning(f"权息文件不存在: {gbbq_path},将跳过复权处理") + return pd.DataFrame() + try: + df = self.gbbq_reader.get_df(str(gbbq_path)) + if df.empty: + return pd.DataFrame() + market_prefix = df['market'].map({0: 'sz', 1: 'sh'}) + df['full_code'] = market_prefix + df['code'].astype(str).str.zfill(6) + return df + except Exception as e: + logger.warning(f"读取权息文件时出错: {e},将跳过复权处理") + return pd.DataFrame() def get_stock_list(self) -> pd.DataFrame: """获取股票列表 @@ -110,11 +135,35 @@ def read_daily_data(self, market: int, code: str) -> pd.DataFrame: raise FileNotFoundError(f"日线数据文件不存在: {file_path}") # 读取数据 - data = self.daily_reader.get_df(str(file_path)) + try: + data = self.daily_reader.get_df(str(file_path)) + except NotImplementedError: + # pytdx 不识别的证券类型(如科创板 688xxx),直接解析二进制,系数与 SH_A_STOCK 相同 + data = self._read_day_file_raw(str(file_path)) data['code'] = code data['market'] = market return data + @staticmethod + def _read_day_file_raw(fname: str) -> pd.DataFrame: + """直接解析 .day 文件,绕过 pytdx 的证券类型检查(用于科创板等)""" + import struct + rows = [] + with open(fname, 'rb') as f: + content = f.read() + record_size = struct.calcsize(' List[pd.DataFrame]: """读取5分钟线数据并生成15分钟、30分钟和60分数据 @@ -241,8 +290,8 @@ def read_all_daily_data(self) -> pd.DataFrame: try: data = self.read_daily_data(market, code) - # 确保datetime是列而不是索引 - if isinstance(data.index, pd.DatetimeIndex) or data.index.name == 'datetime': + # 确保 date/datetime 是列而不是索引 + if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('datetime', 'date'): data = data.reset_index() all_data.append(data) except FileNotFoundError: diff --git a/src/storage.py b/src/storage.py index 05bed16..0e57ade 100644 --- a/src/storage.py +++ b/src/storage.py @@ -46,6 +46,7 @@ class DailyData(Base): close = Column(Float) volume = Column(Float) amount = Column(Float) + adj_factor = Column(Float) # 前复权因子,1.0 表示无复权 ma13 = Column(Float) ma21 = Column(Float) ma34 = Column(Float) From 397881824880b193f50170e28b38a118f0e050b7 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Sat, 4 Apr 2026 21:27:08 +0800 Subject: [PATCH 03/26] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E7=A7=91?= =?UTF-8?q?=E5=88=9B=E6=9D=BF=E8=82=A1=E7=A5=A8=E6=9C=AA=E7=BA=B3=E5=85=A5?= =?UTF-8?q?=E5=85=A8=E9=87=8F=E5=90=8C=E6=AD=A5=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit get_stock_list 上海正则 `(60|688)\d{4}` 要求 688 后跟4位(共7位), 实际科创板代码为 688xxx(6位),改为 `(60\d{4}|688\d{3})` 精确匹配。 Co-Authored-By: Claude Sonnet 4.6 --- src/reader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/reader.py b/src/reader.py index ec1b5ea..3c45fb1 100644 --- a/src/reader.py +++ b/src/reader.py @@ -104,8 +104,8 @@ def get_stock_list(self) -> pd.DataFrame: pure_code = code[-6:] code_str = str(pure_code).zfill(6) # 补齐为6位字符串 - # 匹配上证A股+深证A股 - if re.match(r'^(60|688)\d{4}$', code_str): + # 匹配上证A股(60xxxx)和科创板(688xxx) + if re.match(r'^(60\d{4}|688\d{3})$', code_str): stocks.append({'code': code, 'name': name}) if not stocks: From 6ecbbd00ea12577ae0fd2f898b75d89a52ac3b7a Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Sat, 4 Apr 2026 21:31:28 +0800 Subject: [PATCH 04/26] =?UTF-8?q?fix:=20=E6=B6=88=E9=99=A4=E7=A7=91?= =?UTF-8?q?=E5=88=9B=E6=9D=BF=E8=AF=BB=E5=8F=96=E6=97=B6=20pytdx=20?= =?UTF-8?q?=E6=89=93=E5=8D=B0=E7=9A=84=20Unknown=20security=20type=20?= =?UTF-8?q?=E5=99=AA=E9=9F=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 改为调用前先检查 security_type,不支持时直接走 _read_day_file_raw, 避免 pytdx 内部 print 污染日志输出。 Co-Authored-By: Claude Sonnet 4.6 --- src/reader.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/reader.py b/src/reader.py index 3c45fb1..cc39a97 100644 --- a/src/reader.py +++ b/src/reader.py @@ -134,11 +134,11 @@ def read_daily_data(self, market: int, code: str) -> pd.DataFrame: if not file_path.exists(): raise FileNotFoundError(f"日线数据文件不存在: {file_path}") - # 读取数据 - try: + # 读取数据:先检查 pytdx 是否支持该证券类型,不支持则直接走原始解析(如科创板 688xxx) + sec_type = self.daily_reader.get_security_type(str(file_path)) + if sec_type in self.daily_reader.SECURITY_TYPE: data = self.daily_reader.get_df(str(file_path)) - except NotImplementedError: - # pytdx 不识别的证券类型(如科创板 688xxx),直接解析二进制,系数与 SH_A_STOCK 相同 + else: data = self._read_day_file_raw(str(file_path)) data['code'] = code data['market'] = market From 3a75fc97f83bb796684c3ee1920272a89c221cc7 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Sat, 4 Apr 2026 22:45:41 +0800 Subject: [PATCH 05/26] =?UTF-8?q?feat:=20=E5=A4=8D=E6=9D=83=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=E6=94=B9=E4=B8=BA=E5=8F=AF=E9=80=89=E5=8F=82=E6=95=B0?= =?UTF-8?q?=EF=BC=8C=E9=BB=98=E8=AE=A4=E5=89=8D=E5=A4=8D=E6=9D=83=EF=BC=8C?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=85=A8=E9=87=8F=E6=97=A5=E7=BA=BF=E5=A4=8D?= =?UTF-8?q?=E6=9D=83=E7=BC=BA=E9=99=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - processor.py: process_daily_data 新增 adj_type 参数(forward/backward/none),新增 apply_backward_adj 后复权实现 - cli.py: daily/sync 子命令新增 --adj-type 参数;daily 不带 --code 时改为逐股票处理,修复原先 apply_forward_adj 只处理第一只股票的缺陷 Co-Authored-By: Claude Sonnet 4.6 --- src/cli.py | 77 ++++++++++++++++++++++++++++++++++--------- src/processor.py | 86 +++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 144 insertions(+), 19 deletions(-) diff --git a/src/cli.py b/src/cli.py index a73891f..43a8cdf 100644 --- a/src/cli.py +++ b/src/cli.py @@ -103,6 +103,7 @@ def sync_all_daily_data( storage: DataStorage, start_date: Optional[str] = None, gbbq: Optional[pd.DataFrame] = None, + adj_type: str = 'forward', ) -> bool: """逐股票流式同步日线数据,避免全量加载到内存""" try: @@ -122,7 +123,7 @@ def sync_all_daily_data( if data.empty: continue - processed = processor.process_daily_data(data, gbbq=gbbq) + processed = processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type) filtered = processor.filter_data(processed, start_date=start_date) if filtered.empty: continue @@ -216,6 +217,12 @@ def parse_args() -> Namespace: daily_parser.add_argument('--db-only', action='store_true', help='仅保存到数据库') daily_parser.add_argument('--auto-start', action='store_true', help='自动检测起始日期(从数据库最新日期+1天开始)') daily_parser.add_argument('--incremental', action='store_true', help='增量同步模式,跳过重复数据') + daily_parser.add_argument( + '--adj-type', + choices=['forward', 'backward', 'none'], + default='forward', + help='复权类型:forward(前复权,默认)/ backward(后复权)/ none(不复权)' + ) # 获取并计算分钟线数据 min_parser = subparsers.add_parser('minutes', help='获取分钟线数据') @@ -234,7 +241,13 @@ def parse_args() -> Namespace: block_relation_parser.add_argument('--db-only', action='store_true', help='仅保存到数据库') # 一键同步(日线 + 分钟线增量同步到数据库) - subparsers.add_parser('sync', help='一键增量同步所有数据到数据库(日线 + 5/15/30/60分钟线)') + sync_parser = subparsers.add_parser('sync', help='一键增量同步所有数据到数据库(日线 + 5/15/30/60分钟线)') + sync_parser.add_argument( + '--adj-type', + choices=['forward', 'backward', 'none'], + default='forward', + help='复权类型:forward(前复权,默认)/ backward(后复权)/ none(不复权)' + ) return parser.parse_args() @@ -338,30 +351,65 @@ def main() -> int: if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('datetime', 'date'): data = data.reset_index() else: - # 获取所有股票的日线数据 - data = reader.read_all_daily_data() + # 逐股票处理以确保复权正确(apply_forward/backward_adj 仅支持单只股票) + stocks = reader.get_stock_list() + processor = DataProcessor() + gbbq = reader.read_gbbq() + adj_type = getattr(args, 'adj_type', 'forward') + to_csv = not args.db_only + to_db = not args.csv_only + incremental = hasattr(args, 'incremental') and args.incremental + + iterator = tqdm(stocks.iterrows(), total=len(stocks)) if config.use_tqdm else stocks.iterrows() + all_filtered = [] + for _, stock in iterator: + scode = stock['code'] + smarket = 1 if scode.startswith('sh') else 0 + try: + sdata = reader.read_daily_data(smarket, scode) + if isinstance(sdata.index, pd.DatetimeIndex) or sdata.index.name in ('datetime', 'date'): + sdata = sdata.reset_index() + if sdata.empty: + continue + processed = processor.process_daily_data(sdata, gbbq=gbbq, adj_type=adj_type) + filtered = processor.filter_data(processed, start_date=start_date, end_date=args.end_date) + if filtered.empty: + continue + if to_db: + if incremental: + storage.save_incremental(filtered, 'daily_data', conflict_columns=('code', 'date'), batch_size=config.db_batch_size) + else: + storage.save_to_database(filtered, 'daily_data', batch_size=config.db_batch_size) + if to_csv: + all_filtered.append(filtered) + except FileNotFoundError: + continue + except Exception as e: + logger.error(f"处理 {scode} 日线数据时出错: {e}") + continue + + if to_csv and all_filtered: + storage.save_to_csv(pd.concat(all_filtered, ignore_index=True), 'daily_data') + return 0 + # 单只股票路径 if data.empty: logger.warning("未获取到任何数据") return 0 logger.info(f"获取到 {len(data)} 条日线数据记录") - # 处理数据 processor = DataProcessor() gbbq = reader.read_gbbq() - processed_data = processor.process_daily_data(data, gbbq=gbbq) + adj_type = getattr(args, 'adj_type', 'forward') + processed_data = processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type) - # 根据日期筛选(code 过滤用纯6位格式,与 DataFrame 中的 code 列一致) - filter_code = None - if args.code: - c = args.code - filter_code = [c[-6:] if len(c) > 6 else c] + c = args.code filtered_data = processor.filter_data( processed_data, start_date=start_date, end_date=args.end_date, - codes=filter_code + codes=[c[-6:] if len(c) > 6 else c] ) if filtered_data.empty: @@ -370,12 +418,10 @@ def main() -> int: logger.info(f"筛选后有 {len(filtered_data)} 条日线数据记录") - # 确定保存方式 to_csv = not args.db_only to_db = not args.csv_only incremental = hasattr(args, 'incremental') and args.incremental - # 保存数据 if to_csv: storage.save_to_csv(filtered_data, 'daily_data') if to_db: @@ -467,7 +513,8 @@ def main() -> int: logger.info(f"日线起始日期: {start_date}") gbbq = reader.read_gbbq() - success = sync_all_daily_data(reader, processor, storage, start_date, gbbq=gbbq) + adj_type = getattr(args, 'adj_type', 'forward') + success = sync_all_daily_data(reader, processor, storage, start_date, gbbq=gbbq, adj_type=adj_type) if not success: logger.error("同步日线数据时出错") has_error = True diff --git a/src/processor.py b/src/processor.py index d3f18c2..fd9bd4a 100644 --- a/src/processor.py +++ b/src/processor.py @@ -185,12 +185,85 @@ def apply_forward_adj(df: pd.DataFrame, gbbq: pd.DataFrame) -> pd.DataFrame: return df @staticmethod - def process_daily_data(df: pd.DataFrame, gbbq: pd.DataFrame = None) -> pd.DataFrame: + def apply_backward_adj(df: pd.DataFrame, gbbq: pd.DataFrame) -> pd.DataFrame: + """计算后复权价格并原地更新 open/high/low/close,新增 adj_factor 列 + + 后复权算法:对每个除权日,将该日及之后的全部数据乘以 1/factor, + 使价格序列以最早历史价格为基准保持连续。 + + Args: + df: 单只股票的日线 DataFrame,含 code(6位), market, date, open/high/low/close + gbbq: 全量权息 DataFrame(来自 reader.read_gbbq()) + + Returns: + 含 adj_factor 列、价格已后复权的 DataFrame + """ + if gbbq.empty or df.empty: + df = df.copy() + df['adj_factor'] = 1.0 + return df + + market_val = df['market'].iloc[0] + prefix = 'sz' if market_val == 0 else 'sh' + full_code = prefix + str(df['code'].iloc[0]).zfill(6) + + events = gbbq[gbbq['full_code'] == full_code].copy() + df = df.copy() + if events.empty: + df['adj_factor'] = 1.0 + return df + + events['ex_date'] = pd.to_datetime( + events['datetime'].astype(str).str[:8], format='%Y%m%d' + ) + data_start = df['date'].min() + events = events[ + (events['category'] == 1) & (events['ex_date'] > data_start) + ].sort_values('ex_date') + + df = df.sort_values('date').copy() + df['adj_factor'] = 1.0 + + for _, ev in events.iterrows(): + ex_date = ev['ex_date'] + songgu = float(ev.get('songgu_qianzongguben', 0) or 0) / 10.0 + hongli = float(ev.get('hongli_panqianliutong', 0) or 0) / 10.0 + peigujia = float(ev.get('peigujia_qianzongguben', 0) or 0) + peigu = float(ev.get('peigu_houzongguben', 0) or 0) / 10.0 + + before = df[df['date'] < ex_date] + if before.empty: + continue + prev_close = float(before['close'].iloc[-1]) + if prev_close <= 0: + continue + + denominator = prev_close * (1 + songgu + peigu) + if denominator <= 0: + continue + factor = (prev_close - hongli + peigujia * peigu) / denominator + + if factor <= 0 or factor > 2: + logger.warning(f"{full_code} 除权日 {ex_date} 复权因子异常({factor:.4f}),已跳过") + continue + + # 后复权:除权日及之后乘以 1/factor(向上调整,历史价格不变) + mask = df['date'] >= ex_date + df.loc[mask, 'adj_factor'] *= (1.0 / factor) + + for col in ['open', 'high', 'low', 'close']: + df[col] = (df[col] * df['adj_factor']).round(3) + + return df + + @staticmethod + def process_daily_data(df: pd.DataFrame, gbbq: pd.DataFrame = None, adj_type: str = 'forward') -> pd.DataFrame: """处理日线数据 Args: df: 原始日线数据DataFrame - gbbq: 权息数据(来自 reader.read_gbbq()),传入时执行前复权,None 时跳过复权 + gbbq: 权息数据(来自 reader.read_gbbq()),传入时执行复权,None 时跳过复权 + adj_type: 复权类型,'forward'(前复权)/ 'backward'(后复权)/ 'none'(不复权),默认 'forward' Returns: DataFrame: 处理后的数据 @@ -222,9 +295,14 @@ def process_daily_data(df: pd.DataFrame, gbbq: pd.DataFrame = None) -> pd.DataFr # 数据质量校验 processed_df = DataProcessor._validate_ohlcv(processed_df) - # 前复权处理(在均线计算之前,确保均线基于复权价格) + # 复权处理(在均线计算之前,确保均线基于复权价格) if gbbq is not None and not gbbq.empty and 'date' in processed_df.columns: - processed_df = DataProcessor.apply_forward_adj(processed_df, gbbq) + if adj_type == 'forward': + processed_df = DataProcessor.apply_forward_adj(processed_df, gbbq) + elif adj_type == 'backward': + processed_df = DataProcessor.apply_backward_adj(processed_df, gbbq) + else: # 'none' + processed_df['adj_factor'] = 1.0 # 计算均线指标 if all(col in processed_df.columns for col in ['close', 'volume']): From 993dcfd12b453712167d056afc0383e4d8121ab7 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Sat, 4 Apr 2026 22:48:29 +0800 Subject: [PATCH 06/26] =?UTF-8?q?=E5=A2=9E=E5=8A=A0gitignore?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index ab735a7..e70c070 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,5 @@ poetry.lock output/ */__pycache__/ +CLAUDE.md +tdx_data.db From 59a6da3a50953140d77069f3879baed42ab62095 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Sat, 4 Apr 2026 22:48:55 +0800 Subject: [PATCH 07/26] =?UTF-8?q?=E6=9B=B4=E6=96=B0CLAUDE.md?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CLAUDE.md | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/CLAUDE.md b/CLAUDE.md index d40708d..7a3ae89 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -65,6 +65,38 @@ CLI (cli.py) → Reader (reader.py) → Processor (processor.py) → Storage (s 代码带市场前缀:`sz000001`、`sh600000`。深圳 market=0,上海 market=1。 A 股筛选规则:深圳 `000/001/002/300` 开头,上海 `60/688` 开头。 +## 首次使用初始化 + +增量同步依赖数据库唯一约束(`ON CONFLICT DO NOTHING`)。**首次写入数据后**,必须执行一次约束脚本,否则 `save_incremental()` 不会跳过重复数据: + +```bash +psql -d -f scripts/add_constraints.sql +``` + +该脚本会先清理已有重复行,再添加约束。 + ## 配置 -通过 `.env` 文件配置,必填:`TDX_PATH`、`DB_TYPE`、`DB_HOST`、`DB_NAME`、`DB_USER`、`DB_PASSWORD`。 +通过 `.env` 文件配置: + +| 变量 | 必填 | 说明 | +|------|------|------| +| `TDX_PATH` | 是 | 通达信安装目录 | +| `DB_TYPE` | 是 | `postgresql` / `mysql` / `sqlite` | +| `DB_HOST` | 是 | 数据库主机 | +| `DB_NAME` | 是 | 数据库名 | +| `DB_USER` | 是 | 数据库用户名 | +| `DB_PASSWORD` | 是 | 数据库密码 | +| `DB_PORT` | 否 | 默认 `5432` | +| `DB_BATCH_SIZE` | 否 | 批量写入大小,默认 `10000` | +| `CSV_OUTPUT_PATH` | 否 | CSV 输出目录,默认 `output/` | +| `USE_TQDM` | 否 | 是否显示进度条,默认 `True` | + +**推荐使用 PostgreSQL**:`save_incremental()` 对 PostgreSQL 使用 `psycopg2.extras.execute_values` 真正批量插入(比 MySQL/SQLite 的 executemany 快 10~100x)。 + +## sync 命令增量策略 + +`python main.py sync` 内部行为: + +- **日线**:查 `MAX(date)` from `daily_data`(全局),以 `latest+1天` 为起始过滤 +- **分钟线**:逐股票查 `MAX(datetime)` from `minute5_data WHERE code=?`(按股票精确),无全局起始日期 From 16112d89070622cf8942cddbeccdb08757653ff3 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Sat, 4 Apr 2026 23:46:27 +0800 Subject: [PATCH 08/26] =?UTF-8?q?feat:=20=E7=AE=80=E5=8C=96=E9=87=8D?= =?UTF-8?q?=E6=9E=84=E4=B8=BA=E6=97=A5=E7=BA=BF=E4=B8=93=E7=94=A8=E5=8C=85?= =?UTF-8?q?=EF=BC=88simple-daily=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 删除分钟线、板块关系、均线计算等全部非日线功能 - 将包重组为 src/tdx2db/,支持 pip install 后 from tdx2db import TdxDailySync - 日期格式改为 YYYYMMDD 整数,DailyData 表去除 MA 列,新增 turnover_rate 预留列 - 增量策略:一次查询所有股票最新日期;有除权事件时自动全量重写该股数据 - 新增 tests/test_daily.py:7 个测试用例覆盖全量同步、指定股票、复权正确性、增量更新 - 更新 pyproject.toml / README.md / requirements.txt Co-Authored-By: Claude Sonnet 4.6 --- .github/workflows/python-package.yml | 20 +- CLAUDE.md | 83 ++-- README.md | 133 +++--- __init__.py | 3 - main.py | 16 +- pyproject.toml | 17 +- requirements.txt | 1 + scripts/add_constraints.sql | 70 ---- scripts/migrate_add_adj_factor.sql | 16 - src/__init__.py | 7 - src/cli.py | 549 ------------------------- src/processor.py | 451 --------------------- src/reader.py | 458 --------------------- src/storage.py | 585 --------------------------- src/tdx2db/__init__.py | 105 +++++ src/tdx2db/cli.py | 212 ++++++++++ src/{ => tdx2db}/config.py | 28 +- src/{ => tdx2db}/logger.py | 0 src/tdx2db/processor.py | 207 ++++++++++ src/tdx2db/reader.py | 102 +++++ src/tdx2db/storage.py | 167 ++++++++ tests/test_daily.py | 248 ++++++++++++ 22 files changed, 1171 insertions(+), 2307 deletions(-) delete mode 100644 __init__.py delete mode 100644 scripts/add_constraints.sql delete mode 100644 scripts/migrate_add_adj_factor.sql delete mode 100644 src/__init__.py delete mode 100644 src/cli.py delete mode 100644 src/processor.py delete mode 100644 src/reader.py delete mode 100644 src/storage.py create mode 100644 src/tdx2db/__init__.py create mode 100644 src/tdx2db/cli.py rename src/{ => tdx2db}/config.py (71%) rename src/{ => tdx2db}/logger.py (100%) create mode 100644 src/tdx2db/processor.py create mode 100644 src/tdx2db/reader.py create mode 100644 src/tdx2db/storage.py create mode 100644 tests/test_daily.py diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index dcfa3db..882b82a 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -1,6 +1,3 @@ -# This workflow will install Python dependencies, run tests and lint with a variety of Python versions -# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python - name: Python package on: @@ -11,30 +8,23 @@ on: jobs: build: - runs-on: ubuntu-latest strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 pytest - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + pip install pytest + pip install -r requirements.txt - name: Test with pytest run: | - pytest + pytest tests/ -v diff --git a/CLAUDE.md b/CLAUDE.md index 7a3ae89..d292e11 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -4,26 +4,29 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co ## 项目概述 -tdx2db:从本地通达信(TDX)行情软件读取 A 股数据,增量同步到数据库。是量化分析工作站的数据入口。 +tdx2db:从本地通达信(TDX)行情软件读取 A 股日线数据,增量同步到数据库。支持作为 Python 包被其他项目调用。 ## 常用命令 ```bash # 安装依赖 pip install -r requirements.txt +# 或安装为可编辑包(支持 import) +pip install -e . -# 一键增量同步(日线 + 5/15/30/60 分钟线)— 日常使用这一个命令即可 +# 一键增量同步日线数据 — 日常使用这一个命令 python main.py sync # 单独同步 -python main.py daily --db-only --auto-start --incremental -python main.py minutes --db-only --auto-start --incremental +python main.py daily --incremental +python main.py daily --code sz000001 --start 20240101 # 同步股票列表 -python main.py stock-list --db-only -``` +python main.py stock-list -无测试套件。验证方式是运行 `sync` 命令后检查数据库数据。 +# 运行测试 +python -m pytest tests/ -v +``` ## 架构 @@ -32,48 +35,39 @@ python main.py stock-list --db-only ``` CLI (cli.py) → Reader (reader.py) → Processor (processor.py) → Storage (storage.py) ↓ ↓ ↓ ↓ - argparse pytdx 读取本地 校验 + 重采样 + 均线 SQLAlchemy 批量写库 - 命令分发 + .day/.lc5 文件 (OHLCV 校验, resample, 支持增量 ON CONFLICT - 同步编排 MA5~MA250) 表名白名单保护 + argparse pytdx 读取本地 校验 + 复权处理 SQLAlchemy 批量写库 + 命令分发 + .day 文件 (OHLCV 校验, 前/后复权) 支持增量 ON CONFLICT + 同步编排 表名白名单保护 ``` -- **cli.py**: 除命令分发外,`sync_all_daily_data` / `sync_all_min_data` / `sync_single_stock_min_data` 编排逐股票流式同步 -- **config.py**: 全局单例 `config`,从 `.env` 加载配置(TDX_PATH、DB_*) +- **cli.py**: 命令分发 + `sync_all_daily()` 逐股票流式同步 +- **config.py**: 全局单例 `config`,从 `.env` 加载配置 - **logger.py**: 全局单例 `logger` +- **`__init__.py`**: 暴露 `TdxDailySync` 公共 API -### 关键数据流 +## 关键数据流 -日线和分钟线均为**逐股票流式处理**,不全量加载到内存: +逐股票流式处理,不全量加载到内存: -1. **日线**: 逐股票读取 `vipdoc/{sz,sh}/lday/*.day` → `process_daily_data()` 校验 OHLCV + 计算均线 → 增量写入 `daily_data` 表 -2. **分钟线**: 逐股票读取 `.lc5`(5 分钟)→ `resample_ohlcv()` 重采样为 15/30/60 分钟 → `process_min_data()` 校验 + 均线 → 分别写入 `minute{5,15,30,60}_data` 表 -3. **增量同步**: `save_incremental()` 使用批量 executemany + `ON CONFLICT DO NOTHING`(PostgreSQL)/ `INSERT IGNORE`(MySQL)跳过重复。分钟线按股票精确查询最新日期(`get_latest_datetime_by_code`),日线逐股票增量。 +1. 读取 `vipdoc/{sz,sh}/lday/*.day` → `process_daily_data()` 校验 OHLCV + 复权 +2. 增量策略:`get_all_latest_dates()` 一次查询所有股票最新日期;若有除权事件则 `delete_stock_data()` + 全量重写 +3. `save_incremental()` 使用 `ON CONFLICT DO NOTHING`(PG)/ `INSERT OR IGNORE`(SQLite)/ `INSERT IGNORE`(MySQL) -### 数据库表 +## 数据库表 | 表名 | 唯一约束 | 用途 | |------|----------|------| -| `daily_data` | (code, date) | 日线数据 | -| `minute{5,15,30,60}_data` | (code, datetime) | 分钟线数据 | +| `daily_data` | (code, date) | 日线数据,date 为 YYYYMMDD 整数 | | `stock_info` | code | 股票列表 | -| `block_stock_relation` | — | 板块关系(未完整实现) | - -唯一约束需通过 `scripts/add_constraints.sql` 手动添加。 -### 股票代码格式 +唯一约束由 SQLAlchemy `UniqueConstraint` 在建表时自动创建,无需手动执行 SQL 脚本。 -代码带市场前缀:`sz000001`、`sh600000`。深圳 market=0,上海 market=1。 -A 股筛选规则:深圳 `000/001/002/300` 开头,上海 `60/688` 开头。 +## 股票代码格式 -## 首次使用初始化 - -增量同步依赖数据库唯一约束(`ON CONFLICT DO NOTHING`)。**首次写入数据后**,必须执行一次约束脚本,否则 `save_incremental()` 不会跳过重复数据: - -```bash -psql -d -f scripts/add_constraints.sql -``` - -该脚本会先清理已有重复行,再添加约束。 +- 文件/CLI 层:带市场前缀,如 `sz000001`、`sh600000` +- 数据库层:纯 6 位数字,如 `000001`(reader 写入时截取) +- 深圳 market=0,上海 market=1 +- A 股筛选:深圳 `000/001/002/300` 开头,上海 `60/688` 开头 ## 配置 @@ -82,21 +76,20 @@ psql -d -f scripts/add_constraints.sql | 变量 | 必填 | 说明 | |------|------|------| | `TDX_PATH` | 是 | 通达信安装目录 | -| `DB_TYPE` | 是 | `postgresql` / `mysql` / `sqlite` | -| `DB_HOST` | 是 | 数据库主机 | -| `DB_NAME` | 是 | 数据库名 | -| `DB_USER` | 是 | 数据库用户名 | -| `DB_PASSWORD` | 是 | 数据库密码 | +| `DB_TYPE` | 否 | `sqlite`(默认)/ `mysql` / `postgresql` | +| `DB_NAME` | 否 | 数据库名,SQLite 时为文件名(生成 `.db`) | +| `DB_HOST` | MySQL/PG 必填 | 数据库主机 | +| `DB_USER` | MySQL/PG 必填 | 数据库用户名 | +| `DB_PASSWORD` | MySQL/PG 必填 | 数据库密码 | | `DB_PORT` | 否 | 默认 `5432` | | `DB_BATCH_SIZE` | 否 | 批量写入大小,默认 `10000` | -| `CSV_OUTPUT_PATH` | 否 | CSV 输出目录,默认 `output/` | | `USE_TQDM` | 否 | 是否显示进度条,默认 `True` | -**推荐使用 PostgreSQL**:`save_incremental()` 对 PostgreSQL 使用 `psycopg2.extras.execute_values` 真正批量插入(比 MySQL/SQLite 的 executemany 快 10~100x)。 - ## sync 命令增量策略 `python main.py sync` 内部行为: -- **日线**:查 `MAX(date)` from `daily_data`(全局),以 `latest+1天` 为起始过滤 -- **分钟线**:逐股票查 `MAX(datetime)` from `minute5_data WHERE code=?`(按股票精确),无全局起始日期 +- 一次 SQL 查询获取所有股票最新日期(`SELECT code, MAX(date) FROM daily_data GROUP BY code`) +- 对每只股票:检查 gbbq 中是否有除权事件(category=1)发生在 last_date 之后 + - 有除权 → 删除该股旧数据,全量重写(保证复权价格正确) + - 无除权 → 只写入 last_date 之后的新数据 diff --git a/README.md b/README.md index 1ccfad1..3108c3e 100644 --- a/README.md +++ b/README.md @@ -1,107 +1,112 @@ -# 通达信数据处理工具 +# tdx2db -读取本地通达信股票数据,增量同步到数据库。 +从本地通达信(TDX)行情软件读取 A 股日线数据,增量同步到数据库。支持作为 Python 包被其他项目调用。 -## 快速开始 +## 特性 -```bash -# 一键同步所有数据(日线 + 5/15/30/60分钟线) -python main.py sync -``` +- 同步深圳/上海全量 A 股日线数据(含科创板) +- 前复权 / 后复权 / 不复权,默认前复权 +- 增量更新:有除权事件的个股自动全量重写,确保复权价格正确 +- 日期格式:`YYYYMMDD` 整数(便于范围查询) +- 数据库:SQLite(默认)/ MySQL / PostgreSQL ## 安装 ```bash -# Python >= 3.10 +# 直接安装依赖 pip install -r requirements.txt -# 复制并编辑配置文件 -cp .env.example .env +# 或作为包安装(支持被其他项目 import) +pip install -e . ``` -**.env 必填配置**: +## 配置 + +复制 `.env.example` 为 `.env` 并填写: + ``` -TDX_PATH=D:\通达信安装目录 -DB_TYPE=postgresql -DB_HOST=localhost -DB_NAME=tdx_data +TDX_PATH=/path/to/tdx # 通达信安装目录(必填) +DB_TYPE=sqlite # sqlite / mysql / postgresql +DB_NAME=tdx_data # SQLite 时为文件名(生成 tdx_data.db) +DB_HOST=localhost # MySQL/PostgreSQL 必填 DB_USER=postgres DB_PASSWORD=your_password +DB_BATCH_SIZE=10000 +USE_TQDM=True ``` -## 首次使用 - -1. 打开通达信 → 选项 → 盘后数据下载 → 下载日线和分钟线数据 +## 命令行使用 -2. 同步股票列表: ```bash -python main.py stock-list --db-only -``` +# 同步股票列表 +python main.py stock-list -3. 一键同步所有行情数据: -```bash +# 一键增量同步所有股票日线(日常使用这一个命令) python main.py sync -``` -## 启用增量同步(推荐) +# 同步所有股票日线(全量) +python main.py daily -> 增量同步可自动跳过重复数据,大幅提升每日更新效率。 - -**老用户**(已有数据库表)需执行一次约束脚本: -```bash -# PostgreSQL -psql -U your_user -d your_database -f scripts/add_constraints.sql -``` +# 同步指定股票 +python main.py daily --code sz000001 -**新用户**同样建议执行,以启用增量同步功能。 +# 指定日期范围 +python main.py daily --start 20240101 --end 20241231 -脚本作用:为 `daily_data`、`minute*_data` 表添加唯一约束,确保 `(code, date/datetime)` 不重复。 +# 指定复权类型 +python main.py sync --adj backward +``` -## 每日更新 +安装为包后也可直接使用 `tdx2db` 命令: ```bash -python main.py sync +tdx2db sync ``` -程序会自动检测数据库最新日期,只同步新数据。 +## 作为 Python 包调用 -## 其他命令 +```python +from tdx2db import TdxDailySync -
-单独同步日线/分钟线 +sync = TdxDailySync( + tdx_path="/path/to/tdx", + db_url="sqlite:///data.db", +) -```bash -# 日线增量同步 -python main.py daily --db-only --auto-start --incremental +# 同步所有股票 +sync.sync_all(adj_type='forward') + +# 同步单只股票 +sync.sync_stock('sz000001', start_date=20240101) -# 分钟线增量同步 -python main.py minutes --db-only --auto-start --incremental +# 查询数据 +df = sync.get_daily('sz000001', start_date=20240101, end_date=20241231) +print(df.head()) ``` -
-
-指定日期范围 +## 数据表结构 -```bash -python main.py daily --db-only --start_date 2025-01-01 --end_date 2025-01-31 -python main.py minutes --db-only --start_date 2025-01-01 -``` -
+**daily_data** -
-导出到 CSV +| 列 | 类型 | 说明 | +|----|------|------| +| code | String | 股票代码(6位,如 `000001`) | +| market | Integer | 市场(0=深圳,1=上海) | +| date | Integer | 日期 YYYYMMDD | +| open/high/low/close | Float | 复权后价格 | +| volume | Float | 成交量 | +| amount | Float | 成交额 | +| adj_factor | Float | 复权因子(1.0=无复权) | +| turnover_rate | Float | 换手率(%),待实现 | -```bash -python main.py daily --csv-only -python main.py minutes --csv-only -``` -
+唯一约束:`(code, date)` -## 数据库支持 +## 运行测试 -- PostgreSQL(推荐) -- MySQL -- SQLite +```bash +pip install pytest +python -m pytest tests/ -v +``` ## 许可证 diff --git a/__init__.py b/__init__.py deleted file mode 100644 index e67467c..0000000 --- a/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""TDX2DB""" -__version__ = '0.0.1' -__all__ = ['src'] diff --git a/main.py b/main.py index 85b538d..2081a48 100644 --- a/main.py +++ b/main.py @@ -1,21 +1,9 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- - -""" -通达信数据处理工具入口点 - -此脚本是程序的主入口点,用于启动通达信数据处理工具。 -支持以下功能: -- 获取股票列表:python main.py stock-list -- 获取日线数据:python main.py daily [--code CODE] [--market MARKET] -- 获取分钟线数据:python main.py minute --code CODE --market MARKET [--freq {1,5}] -- 获取分钟线数据:python main.py minutes --code CODE --market MARKET [--freq {1,5}] - -使用 python main.py -h 查看完整的帮助信息。 -""" +"""tdx2db 命令行入口""" import sys -from src.cli import main +from src.tdx2db.cli import main if __name__ == '__main__': try: diff --git a/pyproject.toml b/pyproject.toml index 3d7df54..9fc9091 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,12 +1,16 @@ [tool.poetry] -name = "tdx-data-processor" -version = "0.1.0" -description = "一个使用pytdx读取本地股票数据并存储到数据库或CSV的程序" -authors = ["我如此拉风 "] +name = "tdx2db" +version = "0.2.0" +description = "从通达信本地文件同步 A 股日线数据到数据库" +authors = ["jaden1q84 "] readme = "README.md" +packages = [{include = "tdx2db", from = "src"}] + +[tool.poetry.scripts] +tdx2db = "tdx2db.cli:main" [tool.poetry.dependencies] -python = "^3.11" +python = ">=3.10" pytdx = "^1.72" pandas = "^2.2.3" sqlalchemy = "^2.0.40" @@ -15,6 +19,9 @@ python-dotenv = "^1.1.0" pymysql = "^1.1.1" psycopg2-binary = "^2.9.10" +[tool.poetry.group.dev.dependencies] +pytest = "^8.0" + [[tool.poetry.source]] name = "aliyun" url = "https://mirrors.aliyun.com/pypi/simple/" diff --git a/requirements.txt b/requirements.txt index 2977e81..8628f86 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ pandas>=2.2.3 sqlalchemy>=2.0.40 tqdm>=4.67.1 python-dotenv>=1.1.0 +pymysql>=1.1.1 psycopg2-binary>=2.9.10 diff --git a/scripts/add_constraints.sql b/scripts/add_constraints.sql deleted file mode 100644 index 258478d..0000000 --- a/scripts/add_constraints.sql +++ /dev/null @@ -1,70 +0,0 @@ --- 增量同步约束脚本 --- 为各数据表添加唯一约束,确保 (code, datetime) 组合唯一 --- 执行前请先备份数据库 - --- ============================================ --- 第一步:清理重复数据(如有) --- ============================================ - --- 清理 daily_data 重复记录,保留 id 最大的(日线使用 date 字段) -DELETE FROM daily_data a USING daily_data b -WHERE a.id < b.id AND a.code = b.code AND a.date = b.date; - --- 清理 minute5_data 重复记录 -DELETE FROM minute5_data a USING minute5_data b -WHERE a.id < b.id AND a.code = b.code AND a.datetime = b.datetime; - --- 清理 minute15_data 重复记录 -DELETE FROM minute15_data a USING minute15_data b -WHERE a.id < b.id AND a.code = b.code AND a.datetime = b.datetime; - --- 清理 minute30_data 重复记录 -DELETE FROM minute30_data a USING minute30_data b -WHERE a.id < b.id AND a.code = b.code AND a.datetime = b.datetime; - --- 清理 minute60_data 重复记录 -DELETE FROM minute60_data a USING minute60_data b -WHERE a.id < b.id AND a.code = b.code AND a.datetime = b.datetime; - --- 清理 block_stock_relation 重复记录 -DELETE FROM block_stock_relation a USING block_stock_relation b -WHERE a.id < b.id AND a.block_code = b.block_code AND a.code = b.code; - --- ============================================ --- 第二步:添加唯一约束 --- ============================================ - --- daily_data 表:(code, date) 唯一(日线使用 date 字段) -ALTER TABLE daily_data -ADD CONSTRAINT uq_daily_code_date UNIQUE (code, date); - --- minute5_data 表 -ALTER TABLE minute5_data -ADD CONSTRAINT uq_minute5_code_datetime UNIQUE (code, datetime); - --- minute15_data 表 -ALTER TABLE minute15_data -ADD CONSTRAINT uq_minute15_code_datetime UNIQUE (code, datetime); - --- minute30_data 表 -ALTER TABLE minute30_data -ADD CONSTRAINT uq_minute30_code_datetime UNIQUE (code, datetime); - --- minute60_data 表 -ALTER TABLE minute60_data -ADD CONSTRAINT uq_minute60_code_datetime UNIQUE (code, datetime); - --- block_stock_relation 表:(block_code, code) 唯一 -ALTER TABLE block_stock_relation -ADD CONSTRAINT uq_block_code UNIQUE (block_code, code); - --- stock_info 表:code 已有唯一索引,无需额外操作 - --- ============================================ --- 验证约束(可选) --- ============================================ - --- 查看所有约束 --- SELECT conname, conrelid::regclass, pg_get_constraintdef(oid) --- FROM pg_constraint --- WHERE conname LIKE 'uq_%'; diff --git a/scripts/migrate_add_adj_factor.sql b/scripts/migrate_add_adj_factor.sql deleted file mode 100644 index 0b70130..0000000 --- a/scripts/migrate_add_adj_factor.sql +++ /dev/null @@ -1,16 +0,0 @@ --- 迁移脚本:为 daily_data 表新增前复权因子列 --- 执行一次即可,若列已存在则不报错(IF NOT EXISTS) --- --- 使用方式: --- psql -d -f scripts/migrate_add_adj_factor.sql --- mysql -u -p < scripts/migrate_add_adj_factor.sql --- --- 说明: --- adj_factor = 1.0 表示该行是原始价格(无复权或当天最新价) --- adj_factor < 1.0 表示历史数据已向前调整(价格已乘以该因子) - --- PostgreSQL -ALTER TABLE daily_data ADD COLUMN IF NOT EXISTS adj_factor FLOAT DEFAULT 1.0; - --- MySQL(不支持 IF NOT EXISTS,如报列已存在错误可忽略) --- ALTER TABLE daily_data ADD COLUMN adj_factor FLOAT DEFAULT 1.0; diff --git a/src/__init__.py b/src/__init__.py deleted file mode 100644 index afefaec..0000000 --- a/src/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""TDX数据处理器 - 用于读取通达信本地股票数据并存储到数据库或CSV""" - -__version__ = "0.1.0" - -from .logger import logger, setup_logger - -__all__ = ["logger", "setup_logger", "__version__"] diff --git a/src/cli.py b/src/cli.py deleted file mode 100644 index 43a8cdf..0000000 --- a/src/cli.py +++ /dev/null @@ -1,549 +0,0 @@ -"""命令行接口模块 - -提供命令行接口,方便用户使用程序功能 -""" - -import argparse -import sys -from argparse import Namespace -from typing import Optional - -from datetime import timedelta - -import pandas as pd -from tqdm import tqdm - -from .reader import TdxDataReader -from .processor import DataProcessor -from .storage import DataStorage -from .config import config -from .logger import logger - - -def sync_single_stock_min_data( - reader: TdxDataReader, - processor: DataProcessor, - storage: DataStorage, - market: int, - code: str, - start_date: Optional[str] = None, - incremental: bool = True, -) -> bool: - """处理并存储单只股票的分钟数据 - - Args: - reader: 数据读取器 - processor: 数据处理器 - storage: 数据存储器 - market: 市场代码 - code: 股票代码 - start_date: 开始日期 - incremental: 是否启用精确增量 - """ - # DB 中 code 为纯 6 位数字(reader 写入时会截取),查询时需匹配 - db_code = code[-6:] if len(code) > 6 else code - - # 精确增量:查询该股票的最新日期 - if incremental and not start_date: - latest = storage.get_latest_datetime_by_code('minute5_data', db_code) - if latest: - start_date = (latest + timedelta(days=1)).strftime('%Y-%m-%d') - logger.debug(f"{code} 增量起始日期: {start_date}") - - # 读取5分钟数据 - df_5min = reader.read_5min_data(market, code) - if df_5min.empty: - logger.warning(f"{code} 无5分钟数据") - return False - - # 准备 datetime 索引 - if not pd.api.types.is_datetime64_any_dtype(df_5min['datetime']): - df_5min['datetime'] = pd.to_datetime(df_5min['datetime']) - df_5min['date'] = df_5min['datetime'].dt.date - df_5min = df_5min.set_index('datetime') - - # 重采样为多周期 - df_15min = DataProcessor.resample_ohlcv(df_5min, '15min') - df_30min = DataProcessor.resample_ohlcv(df_5min, '30min') - df_60min = DataProcessor.resample_ohlcv(df_5min, '60min') - df_5min = df_5min.reset_index() - - # 处理、筛选、存储各周期 - freq_data = [ - (df_5min, 5, 'minute5_data'), - (df_15min, 15, 'minute15_data'), - (df_30min, 30, 'minute30_data'), - (df_60min, 60, 'minute60_data'), - ] - - has_data = False - for df, freq, table_name in freq_data: - processed = processor.process_min_data(df) - if start_date: - processed = processor.filter_data_min(processed, start_date=start_date) - if processed.empty: - continue - has_data = True - if incremental: - storage.save_incremental(processed, table_name) - else: - storage.save_minute_data(processed, freq=freq, to_csv=False, to_db=True) - - if has_data: - logger.info(f"{code} 分钟数据已处理并存入数据库") - else: - logger.debug(f"{code} 无新数据需要同步") - - return True - - -def sync_all_daily_data( - reader: TdxDataReader, - processor: DataProcessor, - storage: DataStorage, - start_date: Optional[str] = None, - gbbq: Optional[pd.DataFrame] = None, - adj_type: str = 'forward', -) -> bool: - """逐股票流式同步日线数据,避免全量加载到内存""" - try: - stocks = reader.get_stock_list() - logger.info(f"同步日线数据,共 {len(stocks)} 只股票") - - iterator = tqdm(stocks.iterrows(), total=len(stocks)) if config.use_tqdm else stocks.iterrows() - total_inserted = 0 - - for _, stock in iterator: - code = stock['code'] - market = 1 if code.startswith('sh') else 0 - try: - data = reader.read_daily_data(market, code) - if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('datetime', 'date'): - data = data.reset_index() - if data.empty: - continue - - processed = processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type) - filtered = processor.filter_data(processed, start_date=start_date) - if filtered.empty: - continue - - inserted = storage.save_incremental( - filtered, 'daily_data', - conflict_columns=('code', 'date'), - batch_size=config.db_batch_size - ) - total_inserted += inserted - except FileNotFoundError: - continue - except Exception as e: - logger.error(f"同步 {code} 日线数据时出错: {e}") - continue - - if total_inserted > 0: - logger.info(f"日线数据同步完成,共插入 {total_inserted} 条") - else: - logger.info("日线数据已是最新") - return True - except Exception as e: - logger.error(f"同步日线数据时出错: {e}") - return False - - -def sync_all_min_data( - reader: TdxDataReader, - processor: DataProcessor, - storage: DataStorage, - start_date: Optional[str] = None, -) -> bool: - """编排所有股票的分钟数据同步""" - try: - stocks = reader.get_stock_list() - logger.info(f"处理所有股票的分钟数据,共 {len(stocks)} 只股票") - - iterator = tqdm(stocks.iterrows(), total=len(stocks)) if config.use_tqdm else stocks.iterrows() - - for _, stock in iterator: - code = stock['code'] - market = 1 if code.startswith('sh') else 0 - try: - sync_single_stock_min_data(reader, processor, storage, market, code, start_date) - except FileNotFoundError: - continue - except Exception as e: - logger.error(f"处理 {code} 分钟数据时出错: {e}") - continue - - return True - except Exception as e: - logger.error(f"处理分钟数据时出错: {e}") - return False - -def parse_args() -> Namespace: - """解析命令行参数 - - Returns: - Namespace: 解析后的命令行参数 - """ - parser = argparse.ArgumentParser(description='通达信数据处理工具') - - # 通用参数 - parser.add_argument('--tdx-path', help='通达信安装目录路径') - parser.add_argument('--output', help='输出CSV文件的目录路径') - parser.add_argument('--db-type', choices=['sqlite', 'mysql', 'postgresql'], help='数据库类型') - parser.add_argument('--db-host', help='数据库主机') - parser.add_argument('--db-port', help='数据库端口') - parser.add_argument('--db-name', help='数据库名称') - parser.add_argument('--db-user', help='数据库用户名') - parser.add_argument('--db-password', help='数据库密码') - parser.add_argument('--no-tqdm', action='store_true', help='禁用进度条') - parser.add_argument('--batch-size', type=int, default=10000, help='数据库批量插入的批次大小,默认10000条') - - # 子命令 - subparsers = parser.add_subparsers(dest='command', help='子命令') - - # 获取股票列表 - stock_list_parser = subparsers.add_parser('stock-list', help='获取股票列表') - stock_list_parser.add_argument('--csv-only', action='store_true', help='仅保存到CSV') - stock_list_parser.add_argument('--db-only', action='store_true', help='仅保存到数据库') - - # 获取日线数据 - daily_parser = subparsers.add_parser('daily', help='获取日线数据') - daily_parser.add_argument('--code', help='股票代码,不指定则获取所有股票') - daily_parser.add_argument('--market', type=int, choices=[0, 1], help='市场代码,0表示深圳,1表示上海') - daily_parser.add_argument('--start_date', help='开始日期,格式为YYYY-MM-DD') - daily_parser.add_argument('--end_date', help='结束日期,格式为YYYY-MM-DD') - daily_parser.add_argument('--csv-only', action='store_true', help='仅保存到CSV') - daily_parser.add_argument('--db-only', action='store_true', help='仅保存到数据库') - daily_parser.add_argument('--auto-start', action='store_true', help='自动检测起始日期(从数据库最新日期+1天开始)') - daily_parser.add_argument('--incremental', action='store_true', help='增量同步模式,跳过重复数据') - daily_parser.add_argument( - '--adj-type', - choices=['forward', 'backward', 'none'], - default='forward', - help='复权类型:forward(前复权,默认)/ backward(后复权)/ none(不复权)' - ) - - # 获取并计算分钟线数据 - min_parser = subparsers.add_parser('minutes', help='获取分钟线数据') - min_parser.add_argument('--code', help='股票代码,不指定则获取所有股票') - min_parser.add_argument('--market', type=int, choices=[0, 1], help='市场代码,0表示深圳,1表示上海') - min_parser.add_argument('--start_date', help='开始日期,格式为YYYY-MM-DD') - min_parser.add_argument('--end_date', help='结束日期,格式为YYYY-MM-DD') - min_parser.add_argument('--csv-only', action='store_true', help='仅保存到CSV') - min_parser.add_argument('--db-only', action='store_true', help='仅保存到数据库') - min_parser.add_argument('--auto-start', action='store_true', help='自动检测起始日期(从数据库最新日期+1天开始)') - min_parser.add_argument('--incremental', action='store_true', help='增量同步模式,跳过重复数据') - - # 获取板块与股票对应关系 - block_relation_parser = subparsers.add_parser('block-relation', help='获取板块与股票对应关系【未实现】') - block_relation_parser.add_argument('--csv-only', action='store_true', help='仅保存到CSV') - block_relation_parser.add_argument('--db-only', action='store_true', help='仅保存到数据库') - - # 一键同步(日线 + 分钟线增量同步到数据库) - sync_parser = subparsers.add_parser('sync', help='一键增量同步所有数据到数据库(日线 + 5/15/30/60分钟线)') - sync_parser.add_argument( - '--adj-type', - choices=['forward', 'backward', 'none'], - default='forward', - help='复权类型:forward(前复权,默认)/ backward(后复权)/ none(不复权)' - ) - - return parser.parse_args() - -def update_config(args: Namespace) -> None: - """根据命令行参数更新配置 - - Args: - args: 解析后的命令行参数 - """ - # 更新通达信路径 - if args.tdx_path: - config.tdx_path = args.tdx_path - - # 更新CSV输出路径 - if args.output: - config.csv_output_path = args.output - - # 更新数据库配置 - if args.db_type: - config.db_type = args.db_type - if args.db_host: - config.db_host = args.db_host - if args.db_port: - config.db_port = args.db_port - if args.db_name: - config.db_name = args.db_name - if args.db_user: - config.db_user = args.db_user - if args.db_password: - config.db_password = args.db_password - if args.batch_size: - config.db_batch_size = args.batch_size - - # 更新进度条配置 - if args.no_tqdm: - config.use_tqdm = False - -def main() -> int: - """主函数 - - Returns: - int: 程序退出码,0表示成功,非0表示失败 - """ - args = parse_args() - update_config(args) - - # 初始化数据读取器 - try: - reader = TdxDataReader() - except (ValueError, FileNotFoundError) as e: - logger.error(f"初始化失败: {e}") - return 1 - - # 初始化数据存储器 - storage = DataStorage() - - # 处理子命令 - if args.command == 'stock-list': - # 获取股票列表 - try: - stocks = reader.get_stock_list() - logger.info(f"获取到 {len(stocks)} 只股票信息") - - # 确定保存方式 - to_csv = not args.db_only - to_db = not args.csv_only - - # 保存数据 - storage.save_stock_info(stocks, to_csv=to_csv, to_db=to_db, batch_size=config.db_batch_size) - - except Exception as e: - logger.error(f"获取股票列表时出错: {e}") - return 1 - - elif args.command == 'daily': - try: - # 处理 --auto-start 参数 - start_date = args.start_date - if hasattr(args, 'auto_start') and args.auto_start and not start_date: - latest = storage.get_latest_datetime('daily_data', date_column='date') - if latest: - start_date = (latest + timedelta(days=1)).strftime('%Y-%m-%d') - logger.info(f"自动检测起始日期: {start_date}") - else: - logger.info("数据库中没有数据,将获取所有数据") - - # 获取日线数据 - if args.code: - # 自动推断市场:支持 sz/sh 前缀,或纯6位代码(6/688开头→sh,其余→sz) - code = args.code - if args.market is not None: - market = args.market - elif code.startswith('sh'): - market = 1 - elif code.startswith('sz'): - market = 0 - else: - pure = code[-6:] - market = 1 if pure.startswith(('6', '688')) else 0 - data = reader.read_daily_data(market, code) - if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('datetime', 'date'): - data = data.reset_index() - else: - # 逐股票处理以确保复权正确(apply_forward/backward_adj 仅支持单只股票) - stocks = reader.get_stock_list() - processor = DataProcessor() - gbbq = reader.read_gbbq() - adj_type = getattr(args, 'adj_type', 'forward') - to_csv = not args.db_only - to_db = not args.csv_only - incremental = hasattr(args, 'incremental') and args.incremental - - iterator = tqdm(stocks.iterrows(), total=len(stocks)) if config.use_tqdm else stocks.iterrows() - all_filtered = [] - for _, stock in iterator: - scode = stock['code'] - smarket = 1 if scode.startswith('sh') else 0 - try: - sdata = reader.read_daily_data(smarket, scode) - if isinstance(sdata.index, pd.DatetimeIndex) or sdata.index.name in ('datetime', 'date'): - sdata = sdata.reset_index() - if sdata.empty: - continue - processed = processor.process_daily_data(sdata, gbbq=gbbq, adj_type=adj_type) - filtered = processor.filter_data(processed, start_date=start_date, end_date=args.end_date) - if filtered.empty: - continue - if to_db: - if incremental: - storage.save_incremental(filtered, 'daily_data', conflict_columns=('code', 'date'), batch_size=config.db_batch_size) - else: - storage.save_to_database(filtered, 'daily_data', batch_size=config.db_batch_size) - if to_csv: - all_filtered.append(filtered) - except FileNotFoundError: - continue - except Exception as e: - logger.error(f"处理 {scode} 日线数据时出错: {e}") - continue - - if to_csv and all_filtered: - storage.save_to_csv(pd.concat(all_filtered, ignore_index=True), 'daily_data') - return 0 - - # 单只股票路径 - if data.empty: - logger.warning("未获取到任何数据") - return 0 - - logger.info(f"获取到 {len(data)} 条日线数据记录") - - processor = DataProcessor() - gbbq = reader.read_gbbq() - adj_type = getattr(args, 'adj_type', 'forward') - processed_data = processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type) - - c = args.code - filtered_data = processor.filter_data( - processed_data, - start_date=start_date, - end_date=args.end_date, - codes=[c[-6:] if len(c) > 6 else c] - ) - - if filtered_data.empty: - logger.warning("筛选后没有数据") - return 0 - - logger.info(f"筛选后有 {len(filtered_data)} 条日线数据记录") - - to_csv = not args.db_only - to_db = not args.csv_only - incremental = hasattr(args, 'incremental') and args.incremental - - if to_csv: - storage.save_to_csv(filtered_data, 'daily_data') - if to_db: - if incremental: - storage.save_incremental( - filtered_data, 'daily_data', - conflict_columns=('code', 'date'), - batch_size=config.db_batch_size - ) - else: - storage.save_to_database(filtered_data, 'daily_data', batch_size=config.db_batch_size) - - except Exception as e: - logger.error(f"获取日线数据时出错: {e}") - return 1 - - elif args.command == 'minutes': - try: - # 处理 --auto-start 参数(使用15分钟线表作为参考) - start_date = args.start_date - if hasattr(args, 'auto_start') and args.auto_start and not start_date: - latest = storage.get_latest_datetime('minute15_data') - if latest: - start_date = (latest + timedelta(days=1)).strftime('%Y-%m-%d') - logger.info(f"自动检测起始日期: {start_date}") - else: - logger.info("数据库中没有数据,将获取所有数据") - - incremental = hasattr(args, 'incremental') and args.incremental - - # 获取分钟线数据 - if args.code and args.market is not None: - # 单只股票:统一走 sync_single_stock_min_data,覆盖 5/15/30/60 全部周期 - processor = DataProcessor() - success = sync_single_stock_min_data( - reader, processor, storage, - args.market, args.code, - start_date=start_date, - incremental=incremental, - ) - if not success: - logger.warning(f"股票 {args.code} 无数据可同步") - return 0 - else: - # 获取所有股票的分钟线数据 - logger.info("开始处理所有股票的分钟线数据...") - processor = DataProcessor() - success = sync_all_min_data(reader, processor, storage, start_date) - if success: - logger.info("所有股票的分钟线数据处理完成") - else: - logger.error("处理分钟线数据时出错") - return 1 - - except Exception as e: - logger.error(f"获取分钟线数据时出错: {e}") - return 1 - - elif args.command == 'block-relation': - # 获取板块与股票对应关系 - try: - block_relations = reader.get_block_stock_relation() - logger.info(f"获取到 {len(block_relations)} 条板块与股票对应关系记录") - - # 确定保存方式 - to_csv = not args.db_only - to_db = not args.csv_only - - # 保存数据 - storage.save_block_relation(block_relations, to_csv=to_csv, to_db=to_db, batch_size=config.db_batch_size) - - except Exception as e: - logger.error(f"获取板块与股票对应关系时出错: {e}") - return 1 - - elif args.command == 'sync': - # 一键增量同步所有数据 - logger.info("开始一键增量同步...") - processor = DataProcessor() - has_error = False - - # 1. 同步日线数据 - try: - logger.info("=== 同步日线数据 ===") - latest = storage.get_latest_datetime('daily_data', date_column='date') - start_date = None - if latest: - start_date = (latest + timedelta(days=1)).strftime('%Y-%m-%d') - logger.info(f"日线起始日期: {start_date}") - - gbbq = reader.read_gbbq() - adj_type = getattr(args, 'adj_type', 'forward') - success = sync_all_daily_data(reader, processor, storage, start_date, gbbq=gbbq, adj_type=adj_type) - if not success: - logger.error("同步日线数据时出错") - has_error = True - except Exception as e: - logger.error(f"同步日线数据时出错: {e}") - has_error = True - - # 2. 同步分钟线数据(逐股票精确增量,不传全局 start_date) - try: - logger.info("=== 同步分钟线数据 ===") - success = sync_all_min_data(reader, processor, storage) - if not success: - logger.error("同步分钟线数据时出错") - has_error = True - except Exception as e: - logger.error(f"同步分钟线数据时出错: {e}") - has_error = True - - if has_error: - logger.warning("同步完成,但有部分错误") - return 1 - else: - logger.info("一键增量同步完成!") - - else: - logger.error("请指定子命令,使用 -h 查看帮助信息") - return 1 - - return 0 - -if __name__ == '__main__': - sys.exit(main()) diff --git a/src/processor.py b/src/processor.py deleted file mode 100644 index fd9bd4a..0000000 --- a/src/processor.py +++ /dev/null @@ -1,451 +0,0 @@ -"""数据处理模块 - -负责对从通达信读取的原始数据进行清洗和转换,包括: -- 数据格式转换 -- 缺失值处理 -- 异常值检测 -- 计算技术指标 -- OHLCV 重采样 -""" - -from typing import Optional, List -import pandas as pd - -from .logger import logger - -# 重采样聚合规则 -RESAMPLE_AGG = { - 'open': 'first', - 'high': 'max', - 'low': 'min', - 'close': 'last', - 'volume': 'sum', - 'amount': 'sum', - 'code': 'first', - 'market': 'first', -} - -# 均线周期 -MA_WINDOWS = [5, 10, 13, 21, 34, 55, 60, 89, 144, 233, 250] - - -class DataProcessor: - """数据处理类""" - - @staticmethod - def resample_ohlcv(df: pd.DataFrame, freq: str) -> pd.DataFrame: - """将 OHLCV 数据重采样到目标频率 - - Args: - df: 带有 DatetimeIndex 的 DataFrame - freq: pandas resample 频率字符串('15min', '30min', '60min') - - Returns: - 重采样后的 DataFrame(已 reset_index) - """ - agg = dict(RESAMPLE_AGG) - if 'date' in df.columns: - agg['date'] = 'first' - result = df.resample(freq).agg(agg).dropna() - result.reset_index(inplace=True) - return result - - @staticmethod - def _validate_ohlcv(df: pd.DataFrame) -> pd.DataFrame: - """校验 OHLCV 数据质量,丢弃不合格行 - - 校验规则: - 1. 价格列(open/high/low/close)必须 > 0 - 2. OHLC 关系:high >= max(open, close), low <= min(open, close) - - Args: - df: 包含 OHLCV 列的 DataFrame - - Returns: - 校验通过的 DataFrame - """ - required = ['open', 'high', 'low', 'close'] - if not all(col in df.columns for col in required): - return df - - before = len(df) - - # 价格必须为正 - positive_mask = (df[required] > 0).all(axis=1) - - # OHLC 关系校验 - ohlc_mask = ( - (df['high'] >= df[['open', 'close']].max(axis=1)) & - (df['low'] <= df[['open', 'close']].min(axis=1)) - ) - - valid_mask = positive_mask & ohlc_mask - df = df[valid_mask] - - dropped = before - len(df) - if dropped > 0: - logger.warning(f"数据校验丢弃 {dropped} 条不合格记录(价格非正或 OHLC 关系异常)") - - return df - - @staticmethod - def _calculate_ma(df: pd.DataFrame) -> pd.DataFrame: - """计算均线指标,按股票分组 - - Args: - df: 包含 'close' 和 'code' 列的 DataFrame - - Returns: - 添加了均线列的 DataFrame - """ - for w in MA_WINDOWS: - df[f'ma{w}'] = df.groupby('code')['close'].transform( - lambda x: x.rolling(window=w).mean() - ) - return df - - @staticmethod - def apply_forward_adj(df: pd.DataFrame, gbbq: pd.DataFrame) -> pd.DataFrame: - """计算前复权价格并原地更新 open/high/low/close,新增 adj_factor 列 - - 前复权算法:对每个除权日,将该日之前的全部历史数据乘以当次复权因子, - 使价格序列在除权日前后保持连续。 - - Args: - df: 单只股票的日线 DataFrame,含 code(6位), market, date, open/high/low/close - gbbq: 全量权息 DataFrame(来自 reader.read_gbbq()),含 full_code, category 等 - - Returns: - 含 adj_factor 列、价格已前复权的 DataFrame - """ - if gbbq.empty or df.empty: - df = df.copy() - df['adj_factor'] = 1.0 - return df - - # 构造带前缀的完整代码(market 0=sz, 1=sh) - market_val = df['market'].iloc[0] - prefix = 'sz' if market_val == 0 else 'sh' - pure_code = str(df['code'].iloc[0]).zfill(6) - full_code = prefix + pure_code - - events = gbbq[gbbq['full_code'] == full_code].copy() - - df = df.copy() - if events.empty: - df['adj_factor'] = 1.0 - return df - - # 转换除权日期为 Timestamp(与 df['date'] 的 datetime64[ns] 类型一致) - events['ex_date'] = pd.to_datetime( - events['datetime'].astype(str).str[:8], format='%Y%m%d' - ) - # 仅处理 category==1(除权除息),且限制在 df 数据范围内(避免本地数据不完整时引入无效历史因子) - data_start = df['date'].min() - events = events[(events['category'] == 1) & (events['ex_date'] > data_start)].sort_values('ex_date') - - df = df.sort_values('date').copy() - df['adj_factor'] = 1.0 - - for _, ev in events.iterrows(): - ex_date = ev['ex_date'] # pd.Timestamp - # gbbq 字段均以"每10股"为单位: - # songgu: 每10股送股数,实际送股比例 = songgu / 10 - # hongli: 每10股红利(元),实际每股红利 = hongli / 10 - # peigujia: 配股价(元),单位无需换算 - # peigu: 每10股配股数,实际配股比例 = peigu / 10 - songgu = float(ev.get('songgu_qianzongguben', 0) or 0) / 10.0 - hongli = float(ev.get('hongli_panqianliutong', 0) or 0) / 10.0 - peigujia = float(ev.get('peigujia_qianzongguben', 0) or 0) - peigu = float(ev.get('peigu_houzongguben', 0) or 0) / 10.0 - - before = df[df['date'] < ex_date] - if before.empty: - continue - prev_close = float(before['close'].iloc[-1]) - if prev_close <= 0: - continue - - denominator = prev_close * (1 + songgu + peigu) - if denominator <= 0: - continue - factor = (prev_close - hongli + peigujia * peigu) / denominator - - if factor <= 0 or factor > 2: - logger.warning(f"{full_code} 除权日 {ex_date} 复权因子异常({factor:.4f}),已跳过") - continue - - mask = df['date'] < ex_date - df.loc[mask, 'adj_factor'] *= factor - - # 应用前复权:历史价格 * adj_factor,当日及以后为原始价格(factor=1) - for col in ['open', 'high', 'low', 'close']: - df[col] = (df[col] * df['adj_factor']).round(3) - - return df - - @staticmethod - def apply_backward_adj(df: pd.DataFrame, gbbq: pd.DataFrame) -> pd.DataFrame: - """计算后复权价格并原地更新 open/high/low/close,新增 adj_factor 列 - - 后复权算法:对每个除权日,将该日及之后的全部数据乘以 1/factor, - 使价格序列以最早历史价格为基准保持连续。 - - Args: - df: 单只股票的日线 DataFrame,含 code(6位), market, date, open/high/low/close - gbbq: 全量权息 DataFrame(来自 reader.read_gbbq()) - - Returns: - 含 adj_factor 列、价格已后复权的 DataFrame - """ - if gbbq.empty or df.empty: - df = df.copy() - df['adj_factor'] = 1.0 - return df - - market_val = df['market'].iloc[0] - prefix = 'sz' if market_val == 0 else 'sh' - full_code = prefix + str(df['code'].iloc[0]).zfill(6) - - events = gbbq[gbbq['full_code'] == full_code].copy() - df = df.copy() - if events.empty: - df['adj_factor'] = 1.0 - return df - - events['ex_date'] = pd.to_datetime( - events['datetime'].astype(str).str[:8], format='%Y%m%d' - ) - data_start = df['date'].min() - events = events[ - (events['category'] == 1) & (events['ex_date'] > data_start) - ].sort_values('ex_date') - - df = df.sort_values('date').copy() - df['adj_factor'] = 1.0 - - for _, ev in events.iterrows(): - ex_date = ev['ex_date'] - songgu = float(ev.get('songgu_qianzongguben', 0) or 0) / 10.0 - hongli = float(ev.get('hongli_panqianliutong', 0) or 0) / 10.0 - peigujia = float(ev.get('peigujia_qianzongguben', 0) or 0) - peigu = float(ev.get('peigu_houzongguben', 0) or 0) / 10.0 - - before = df[df['date'] < ex_date] - if before.empty: - continue - prev_close = float(before['close'].iloc[-1]) - if prev_close <= 0: - continue - - denominator = prev_close * (1 + songgu + peigu) - if denominator <= 0: - continue - factor = (prev_close - hongli + peigujia * peigu) / denominator - - if factor <= 0 or factor > 2: - logger.warning(f"{full_code} 除权日 {ex_date} 复权因子异常({factor:.4f}),已跳过") - continue - - # 后复权:除权日及之后乘以 1/factor(向上调整,历史价格不变) - mask = df['date'] >= ex_date - df.loc[mask, 'adj_factor'] *= (1.0 / factor) - - for col in ['open', 'high', 'low', 'close']: - df[col] = (df[col] * df['adj_factor']).round(3) - - return df - - @staticmethod - def process_daily_data(df: pd.DataFrame, gbbq: pd.DataFrame = None, adj_type: str = 'forward') -> pd.DataFrame: - """处理日线数据 - - Args: - df: 原始日线数据DataFrame - gbbq: 权息数据(来自 reader.read_gbbq()),传入时执行复权,None 时跳过复权 - adj_type: 复权类型,'forward'(前复权)/ 'backward'(后复权)/ 'none'(不复权),默认 'forward' - - Returns: - DataFrame: 处理后的数据 - """ - if df.empty: - return df - - # 复制数据,避免修改原始数据 - processed_df = df.copy() - - # 确保datetime列存在 - if 'datetime' not in processed_df.columns: - # 检查是否有索引中包含日期时间信息 - if processed_df.index.name == 'datetime' or isinstance(processed_df.index, pd.DatetimeIndex): - # 如果索引是日期时间类型,直接将索引转为列 - processed_df['datetime'] = processed_df.index - # 如果索引不是日期时间类型但包含日期信息(如终端输出所示) - elif hasattr(processed_df.iloc[-1], 'name') and isinstance(processed_df.iloc[-1].name, pd.Timestamp): - # 从行索引名称中提取日期时间 - processed_df['datetime'] = processed_df.apply(lambda row: row.name if isinstance(row.name, pd.Timestamp) else None, axis=1) - - # 处理缺失值 - numeric_columns = ['open', 'high', 'low', 'close', 'volume', 'amount'] - for col in numeric_columns: - if col in processed_df.columns: - # 用前一个有效值填充缺失值 - processed_df[col] = processed_df[col].ffill() - - # 数据质量校验 - processed_df = DataProcessor._validate_ohlcv(processed_df) - - # 复权处理(在均线计算之前,确保均线基于复权价格) - if gbbq is not None and not gbbq.empty and 'date' in processed_df.columns: - if adj_type == 'forward': - processed_df = DataProcessor.apply_forward_adj(processed_df, gbbq) - elif adj_type == 'backward': - processed_df = DataProcessor.apply_backward_adj(processed_df, gbbq) - else: # 'none' - processed_df['adj_factor'] = 1.0 - - # 计算均线指标 - if all(col in processed_df.columns for col in ['close', 'volume']): - processed_df = DataProcessor._calculate_ma(processed_df) - - return processed_df - - @staticmethod - def process_min_data(df: pd.DataFrame) -> pd.DataFrame: - """处理分钟线数据 - - Args: - df: 原始分钟线数据DataFrame - - Returns: - DataFrame: 处理后的数据 - """ - if df.empty: - return df - - # 复制数据,避免修改原始数据 - processed_df = df.copy() - - - # 重命名列,使其更符合通用命名 - column_mapping = { - 'amount': 'amount', # 成交额 - 'close': 'close', # 收盘价 - 'open': 'open', # 开盘价 - 'high': 'high', # 最高价 - 'low': 'low', # 最低价 - 'vol': 'volume', # 成交量 - 'year': 'year', # 年 - 'month': 'month', # 月 - 'day': 'day', # 日 - 'hour': 'hour', # 时 - 'minute': 'minute', # 分 - 'datetime': 'datetime', # 日期时间 - 'code': 'code', # 股票代码 - 'market': 'market' # 市场代码 - } - processed_df.rename(columns={k: v for k, v in column_mapping.items() if k in processed_df.columns}, inplace=True) - - # 确保datetime列存在 - if 'datetime' not in processed_df.columns and all(col in processed_df.columns for col in ['year', 'month', 'day', 'hour', 'minute']): - processed_df['datetime'] = pd.to_datetime( - processed_df[['year', 'month', 'day']].astype(str).agg('-'.join, axis=1) + ' ' + - processed_df[['hour', 'minute']].astype(str).agg(':'.join, axis=1) - ) - - # 处理缺失值 - numeric_columns = ['open', 'high', 'low', 'close', 'volume', 'amount'] - for col in numeric_columns: - if col in processed_df.columns: - # 用前一个有效值填充缺失值 - processed_df[col] = processed_df[col].ffill() - - # 数据质量校验 - processed_df = DataProcessor._validate_ohlcv(processed_df) - - # 计算均线指标 - if all(col in processed_df.columns for col in ['close', 'volume']): - processed_df = DataProcessor._calculate_ma(processed_df) - - return processed_df - - @staticmethod - def filter_data( - df: pd.DataFrame, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - codes: Optional[List[str]] = None - ) -> pd.DataFrame: - """根据条件筛选数据 - - Args: - df: 原始数据DataFrame - start_date: 开始日期,格式为'YYYY-MM-DD' - end_date: 结束日期,格式为'YYYY-MM-DD' - codes: 股票代码列表 - - Returns: - DataFrame: 筛选后的数据 - """ - if df.empty: - return df - - filtered_df = df.copy() - - - logger.debug(f"筛选日期范围: start_date={start_date}, end_date={end_date}") - # 按日期筛选 - if 'date' in filtered_df.columns: - if start_date: - filtered_df = filtered_df[filtered_df['date'] >= pd.to_datetime(start_date)] - if end_date: - filtered_df = filtered_df[filtered_df['date'] <= pd.to_datetime(end_date)] - - # 按时间筛选 - if 'datetime' in filtered_df.columns: - if start_date: - filtered_df = filtered_df[filtered_df['datetime'] >= pd.to_datetime(start_date)] - if end_date: - filtered_df = filtered_df[filtered_df['datetime'] <= pd.to_datetime(end_date)] - - # 按股票代码筛选 - if codes and 'code' in filtered_df.columns: - filtered_df = filtered_df[filtered_df['code'].isin(codes)] - - return filtered_df - - @staticmethod - def filter_data_min( - df: pd.DataFrame, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - codes: Optional[List[str]] = None - ) -> pd.DataFrame: - """根据条件筛选分钟线数据 - - Args: - df: 原始数据DataFrame - start_date: 开始日期,格式为'YYYY-MM-DD' - end_date: 结束日期,格式为'YYYY-MM-DD' - codes: 股票代码列表 - - Returns: - DataFrame: 筛选后的数据 - """ - if df.empty: - return df - - filtered_df = df.copy() - - # 按日期筛选 - if 'date' in filtered_df.columns: - if start_date: - filtered_df = filtered_df[filtered_df['datetime'] >= pd.to_datetime(start_date)] - if end_date: - filtered_df = filtered_df[filtered_df['datetime'] <= pd.to_datetime(end_date)] - - # 按股票代码筛选 - if codes and 'code' in filtered_df.columns: - filtered_df = filtered_df[filtered_df['code'].isin(codes)] - - return filtered_df diff --git a/src/reader.py b/src/reader.py deleted file mode 100644 index cc39a97..0000000 --- a/src/reader.py +++ /dev/null @@ -1,458 +0,0 @@ -"""数据读取模块 - -负责从通达信本地数据文件中读取股票数据,支持: -- 日线数据 -- 分钟线数据 -- 股票列表 -""" - -import os -import re -from pathlib import Path -from typing import Optional, List - -import pandas as pd -from pytdx.reader import TdxDailyBarReader, TdxMinBarReader, TdxLCMinBarReader -from pytdx.reader import BlockReader, GbbqReader -from tqdm import tqdm - -from .config import config -from .logger import logger -from .processor import DataProcessor - -class TdxDataReader: - """通达信数据读取类""" - - def __init__(self, tdx_path: Optional[str] = None) -> None: - """初始化数据读取器 - - Args: - tdx_path: 通达信安装目录,如果为None则使用配置中的路径 - """ - self.tdx_path = tdx_path or config.tdx_path - if not self.tdx_path: - raise ValueError("通达信数据路径未设置,请在.env文件中设置TDX_PATH或在初始化时提供") - - self.tdx_path = Path(self.tdx_path) - if not self.tdx_path.exists(): - raise FileNotFoundError(f"通达信数据路径不存在: {self.tdx_path}") - - # 初始化读取器 - self.daily_reader = TdxDailyBarReader() - self.min_reader = TdxMinBarReader() - self.lc_min_reader = TdxLCMinBarReader() - self.block_reader = BlockReader() - self.gbbq_reader = GbbqReader() - - def read_gbbq(self) -> pd.DataFrame: - """读取通达信本地权息文件(gbbq),返回全量权息 DataFrame - - Returns: - DataFrame: 权息数据,列包括 market, code, datetime(YYYYMMDD整数), category, - hongli_panqianliutong, peigujia_qianzongguben, songgu_qianzongguben, - peigu_houzongguben, full_code(sz/sh+6位代码) - 若文件不存在则返回空 DataFrame(复权逻辑会优雅降级) - """ - gbbq_path = self.tdx_path / 'T0002' / 'hq_cache' / 'gbbq' - if not gbbq_path.exists(): - logger.warning(f"权息文件不存在: {gbbq_path},将跳过复权处理") - return pd.DataFrame() - try: - df = self.gbbq_reader.get_df(str(gbbq_path)) - if df.empty: - return pd.DataFrame() - market_prefix = df['market'].map({0: 'sz', 1: 'sh'}) - df['full_code'] = market_prefix + df['code'].astype(str).str.zfill(6) - return df - except Exception as e: - logger.warning(f"读取权息文件时出错: {e},将跳过复权处理") - return pd.DataFrame() - - def get_stock_list(self) -> pd.DataFrame: - """获取股票列表 - - Returns: - DataFrame: 包含A股股票代码和名称的DataFrame(不包含B股、基金、等) - """ - # 尝试查找通达信股票数据文件 - sz_path = self.tdx_path / 'vipdoc' / 'sz' / 'lday' - sh_path = self.tdx_path / 'vipdoc' / 'sh' / 'lday' - - if not (sz_path.exists() or sh_path.exists()): - raise FileNotFoundError(f"无法找到股票列表文件或股票数据目录") - - # 从目录中获取股票代码 - stocks = [] - - # 处理深圳股票 - if sz_path.exists(): - for file in sz_path.glob('*.day'): - code = file.stem - name = f"深A{code}" - - pure_code = code[-6:] - code_str = str(pure_code).zfill(6) # 补齐为6位字符串 - # 匹配上证A股+深证A股 - if re.match(r'^(000|001|002|300)\d{3}$', code_str): - stocks.append({'code': code, 'name': name}) - - # 处理上海股票 - if sh_path.exists(): - for file in sh_path.glob('*.day'): - code = file.stem - name = f"上A{code}" - - pure_code = code[-6:] - code_str = str(pure_code).zfill(6) # 补齐为6位字符串 - # 匹配上证A股(60xxxx)和科创板(688xxx) - if re.match(r'^(60\d{4}|688\d{3})$', code_str): - stocks.append({'code': code, 'name': name}) - - if not stocks: - raise FileNotFoundError(f"未找到任何股票数据文件") - - return pd.DataFrame(stocks, columns=['code', 'name']) - - def read_daily_data(self, market: int, code: str) -> pd.DataFrame: - """读取日线数据 - - Args: - market: 市场代码,0表示深圳,1表示上海 - code: 股票代码 - - Returns: - DataFrame: 日线数据 - """ - # 构建日线数据文件路径 - market_folder = 'sz' if market == 0 else 'sh' - data_path = self.tdx_path / 'vipdoc' / market_folder / 'lday' - - if (len(code)>6): - code = code[-6:] - file_path = data_path / f"{market_folder}{code}.day" - - if not file_path.exists(): - raise FileNotFoundError(f"日线数据文件不存在: {file_path}") - - # 读取数据:先检查 pytdx 是否支持该证券类型,不支持则直接走原始解析(如科创板 688xxx) - sec_type = self.daily_reader.get_security_type(str(file_path)) - if sec_type in self.daily_reader.SECURITY_TYPE: - data = self.daily_reader.get_df(str(file_path)) - else: - data = self._read_day_file_raw(str(file_path)) - data['code'] = code - data['market'] = market - return data - - @staticmethod - def _read_day_file_raw(fname: str) -> pd.DataFrame: - """直接解析 .day 文件,绕过 pytdx 的证券类型检查(用于科创板等)""" - import struct - rows = [] - with open(fname, 'rb') as f: - content = f.read() - record_size = struct.calcsize(' List[pd.DataFrame]: - """读取5分钟线数据并生成15分钟、30分钟和60分数据 - - Args: - market: 市场代码,0表示深圳,1表示上海 - code: 股票代码 - - Returns: - list: [15分钟数据, 30分钟数据, 60分钟数据] - """ - # 构建分钟线数据文件路径 - market_folder = 'sz' if market == 0 else 'sh' - freq_folder = 'fzline' - data_path = self.tdx_path / 'vipdoc' / market_folder / freq_folder - file_path = data_path /f"{market_folder}{code}.lc5" # 只读取5分钟数据 - - if not file_path.exists(): - raise FileNotFoundError(f"5分钟线数据文件不存在: {file_path}") - - # 读取5分钟数据 - logger.info(f"正在读取 {code} 的5分钟线数据...") - with tqdm(total=1, desc="读取进度") as pbar: - data = self.lc_min_reader.get_df(str(file_path)) - data['code'] = code - data['market'] = market - pbar.update(1) - - # 确保datetime列存在并且是日期时间类型 - if 'datetime' not in data.columns: - # 如果没有datetime列,尝试从index创建 - if isinstance(data.index, pd.DatetimeIndex): - data['datetime'] = data.index - else: - raise ValueError("数据中缺少datetime列且索引不是日期时间类型") - elif not pd.api.types.is_datetime64_any_dtype(data['datetime']): - data['datetime'] = pd.to_datetime(data['datetime']) - - # 设置datetime为索引,用于后续resample操作 - data.set_index('datetime', inplace=True) - - # 记得定期获取最新的数据,同步进TDX - logger.debug(f"数据时间范围: {data.index[0]} ~ {data.index[-1]}") - - # 重采样生成多周期数据 - data_15min = DataProcessor.resample_ohlcv(data, '15min') - data_30min = DataProcessor.resample_ohlcv(data, '30min') - data_60min = DataProcessor.resample_ohlcv(data, '60min') - - data.reset_index(inplace=True) - - return [data_15min, data_30min, data_60min] - - def read_5min_data(self, market: int, code: str) -> pd.DataFrame: - """读取5分钟线数据 - - Args: - market: 市场代码,0表示深圳,1表示上海 - code: 股票代码 - - Returns: - DataFrame: 5分钟数据 - """ - # 构建分钟线数据文件路径 - market_folder = 'sz' if market == 0 else 'sh' - freq_folder = 'fzline' - data_path = self.tdx_path / 'vipdoc' / market_folder / freq_folder - - if (len(code)>6): - code = code[-6:] - file_path = data_path /f"{market_folder}{code}.lc5" - - if not file_path.exists(): - raise FileNotFoundError(f"5分钟线数据文件不存在: {file_path}") - - # 读取5分钟数据 - logger.info(f"正在读取 {code} 的5分钟线数据...") - with tqdm(total=1, desc="读取进度") as pbar: - data = self.lc_min_reader.get_df(str(file_path)) - data['code'] = code - data['market'] = market - pbar.update(1) - - # 确保datetime列存在并且是日期时间类型 - if 'datetime' not in data.columns: - # 如果没有datetime列,尝试从index创建 - if isinstance(data.index, pd.DatetimeIndex): - data['datetime'] = data.index - else: - raise ValueError("数据中缺少datetime列且索引不是日期时间类型") - elif not pd.api.types.is_datetime64_any_dtype(data['datetime']): - data['datetime'] = pd.to_datetime(data['datetime']) - - # 设置datetime为索引,用于后续resample操作 - data.set_index('datetime', inplace=True) - - # 记得定期获取最新的数据,同步进TDX - logger.debug(f"数据时间范围: {data.index[0]} ~ {data.index[-1]}") - - # 重置索引,使datetime成为列 - data.reset_index(inplace=True) - - return data - - def read_all_daily_data(self) -> pd.DataFrame: - """读取所有股票的日线数据 - - Returns: - DataFrame: 所有股票的日线数据 - """ - # 获取股票列表 - stocks = self.get_stock_list() - logger.info(f"获取到 {len(stocks)} 只股票,开始读取日线数据...") - - all_data = [] - iterator = tqdm(stocks.iterrows(), total=len(stocks)) if config.use_tqdm else stocks.iterrows() - - for _, stock in iterator: - code = stock['code'] - # 判断市场 - if code.startswith('sh'): - market = 1 # 上海 - else: - market = 0 # 深圳 - - try: - data = self.read_daily_data(market, code) - # 确保 date/datetime 是列而不是索引 - if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('datetime', 'date'): - data = data.reset_index() - all_data.append(data) - except FileNotFoundError: - continue - except Exception as e: - logger.error(f"读取 {code} 日线数据时出错: {e}") - continue - - if not all_data: - return pd.DataFrame() - - # 合并数据时保留datetime列 - result_df = pd.concat(all_data, ignore_index=True) - - # 确保datetime列存在并且是正确的日期时间格式 - if 'datetime' in result_df.columns and not pd.api.types.is_datetime64_any_dtype(result_df['datetime']): - try: - result_df['datetime'] = pd.to_datetime(result_df['datetime']) - except Exception as e: - logger.warning(f"转换datetime列时出错: {e}") - - return result_df - - # 板块关系暂时未实现,由于板块文件未找到 - def get_block_stock_relation(self) -> pd.DataFrame: - """获取通达信板块与股票的对应关系 - - Returns: - DataFrame: 包含板块代码、板块名称和对应股票代码的DataFrame - """ - # 板块文件目录 - block_path = self.tdx_path / 'T0002' / 'hq_cache' - - if not block_path.exists(): - raise FileNotFoundError(f"板块文件目录不存在: {block_path}") - - blocks = self.block_reader.get_df(self.tdx_path / 'BlockMap' / 'TdxZLSelStock.dat') - logger.debug(f"读取到板块数据: {len(blocks)} 条记录") - - # # 板块文件列表 - # block_files = list(block_path.glob('block*.dat')) - # block_files.extend(list(block_path.glob('block*.blk'))) - - # if not block_files: - # raise FileNotFoundError(f"未找到板块文件: {block_path}") - - # # 存储板块与股票的对应关系 - # block_stock_relations = [] - - # # 遍历板块文件 - # for block_file in block_files: - # block_type = block_file.stem - - # try: - # # 使用BlockReader读取板块文件 - # block_data = self.block_reader.get_df(str(block_file)) - - # if block_data.empty: - # continue - - # # 处理板块数据 - # for _, row in block_data.iterrows(): - # block_stock_relations.append({ - # 'block_code': row.get('block_code', block_type), - # 'block_name': row.get('block_name', block_type), - # 'code': row.get('code', ''), - # 'name': row.get('name', '') - # }) - # except Exception as e: - # print(f"读取板块文件{block_file}时出错: {e}") - - # # 尝试直接读取文件内容 - # try: - # with open(block_file, 'rb') as f: - # content = f.read() - - # # 解析板块文件内容 - # block_name = block_file.stem - - # # 尝试从文件名或内容中提取板块名称 - # if block_file.suffix.lower() == '.dat': - # # .dat文件通常是二进制格式 - # try: - # # 尝试从文件头部提取板块名称 - # if len(content) > 50: - # # 通达信板块文件格式可能不同,这里尝试几种常见格式 - # try: - # name_bytes = content[0:50].split(b'\x00')[0] - # block_name = name_bytes.decode('gbk', errors='ignore').strip() - # except: - # pass - # except: - # pass - - # # 提取股票代码 - # codes = [] - - # # 解析文件内容提取股票代码 - # if block_file.suffix.lower() == '.blk': - # # .blk文件通常是文本格式 - # try: - # text_content = content.decode('gbk', errors='ignore') - # for line in text_content.split('\n'): - # line = line.strip() - # if line and not line.startswith('#'): - # # 通常格式为 1 000001 或 0 000001 - # parts = line.split() - # if len(parts) >= 2: - # market = int(parts[0]) - # code = parts[1] - # market_prefix = 'sh' if market == 1 else 'sz' - # codes.append(f"{market_prefix}{code}") - # else: - # # 可能只有代码,没有市场标识 - # code = line - # # 根据代码前缀判断市场 - # if code.startswith(('6', '5', '9')): - # codes.append(f"sh{code}") - # else: - # codes.append(f"sz{code}") - # except: - # pass - # elif block_file.suffix.lower() == '.dat': - # # .dat文件通常是二进制格式 - # try: - # # 跳过文件头部,直接读取股票代码部分 - # offset = 384 # 通常板块文件头部大小 - # while offset < len(content): - # if offset + 7 <= len(content): - # market = content[offset] - # code = content[offset+1:offset+7].decode('ascii', errors='ignore') - # if code.isdigit(): - # market_prefix = 'sh' if market == 1 else 'sz' - # codes.append(f"{market_prefix}{code}") - # offset += 7 - # except: - # pass - - # # 添加到结果列表 - # for code in codes: - # block_stock_relations.append({ - # 'block_code': block_type, - # 'block_name': block_name, - # 'code': code, - # 'name': '' - # }) - # except Exception as e: - # print(f"直接解析板块文件{block_file}时出错: {e}") - - # # 转换为DataFrame - # if not block_stock_relations: - # return pd.DataFrame() - - # df = pd.DataFrame(block_stock_relations) - - # # 尝试补充股票名称 - # try: - # stocks = self.get_stock_list() - # stock_dict = dict(zip(stocks['code'], stocks['name'])) - # df['name'] = df['code'].map(stock_dict) - # except Exception as e: - # print(f"补充股票名称时出错: {e}") - - return df diff --git a/src/storage.py b/src/storage.py deleted file mode 100644 index 0e57ade..0000000 --- a/src/storage.py +++ /dev/null @@ -1,585 +0,0 @@ -"""数据存储模块 - -负责将处理后的数据保存到不同的存储介质,支持: -- CSV文件存储 -- 数据库存储(PostgreSQL、MySQL、SQLite) -""" - -import os -from datetime import datetime as dt -from pathlib import Path -from typing import Optional, Tuple - -import pandas as pd -from sqlalchemy import create_engine, Column, Integer, Float, String, DateTime, MetaData, Table, text -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker -from tqdm import tqdm - -from .config import config -from .logger import logger - -Base = declarative_base() - -class BlockStockRelation(Base): - """板块股票关系表模型""" - __tablename__ = 'block_stock_relation' - - id = Column(Integer, primary_key=True) - block_code = Column(String(20), index=True) # 板块代码 - block_name = Column(String(50)) # 板块名称 - code = Column(String(10), index=True) # 股票代码 - name = Column(String(50)) # 股票名称 - -class DailyData(Base): - """日线数据表模型""" - __tablename__ = 'daily_data' - - id = Column(Integer, primary_key=True) - code = Column(String(10), index=True) - market = Column(Integer) - datetime = Column(DateTime, index=True) - date = Column(DateTime, index=True) - open = Column(Float) - high = Column(Float) - low = Column(Float) - close = Column(Float) - volume = Column(Float) - amount = Column(Float) - adj_factor = Column(Float) # 前复权因子,1.0 表示无复权 - ma13 = Column(Float) - ma21 = Column(Float) - ma34 = Column(Float) - ma55 = Column(Float) - ma89 = Column(Float) - ma144 = Column(Float) - ma233 = Column(Float) - ma5 = Column(Float) - ma10 = Column(Float) - ma60 = Column(Float) - ma250 = Column(Float) - -class Minute5Data(Base): - """5分钟线数据表模型""" - __tablename__ = 'minute5_data' - - id = Column(Integer, primary_key=True, autoincrement=True) - code = Column(String(10), nullable=False, index=True) - market = Column(Integer, nullable=False) - datetime = Column(DateTime, nullable=False, index=True) - date = Column(DateTime, nullable=False, index=True) - open = Column(Float, nullable=False) - high = Column(Float, nullable=False) - low = Column(Float, nullable=False) - close = Column(Float, nullable=False) - volume = Column(Float, nullable=False) - amount = Column(Float, nullable=False) - ma13 = Column(Float) - ma21 = Column(Float) - ma34 = Column(Float) - ma55 = Column(Float) - ma89 = Column(Float) - ma144 = Column(Float) - ma233 = Column(Float) - ma5 = Column(Float) - ma10 = Column(Float) - ma60 = Column(Float) - ma250 = Column(Float) - -class Minute15Data(Base): - """15分钟线数据表模型""" - __tablename__ = 'minute15_data' - - id = Column(Integer, primary_key=True, autoincrement=True) - code = Column(String(10), nullable=False, index=True) - market = Column(Integer, nullable=False) - datetime = Column(DateTime, nullable=False, index=True) - date = Column(DateTime, nullable=False, index=True) - open = Column(Float, nullable=False) - high = Column(Float, nullable=False) - low = Column(Float, nullable=False) - close = Column(Float, nullable=False) - volume = Column(Float, nullable=False) - amount = Column(Float, nullable=False) - # 添加技术指标列 - ma13 = Column(Float) - ma21 = Column(Float) - ma34 = Column(Float) - ma55 = Column(Float) - ma89 = Column(Float) - ma144 = Column(Float) - ma233 = Column(Float) - ma5 = Column(Float) - ma10 = Column(Float) - ma60 = Column(Float) - ma250 = Column(Float) -class Minute30Data(Base): - """30分钟线数据表模型""" - __tablename__ = 'minute30_data' - - id = Column(Integer, primary_key=True, autoincrement=True) - code = Column(String(10), nullable=False, index=True) - market = Column(Integer, nullable=False) - datetime = Column(DateTime, nullable=False, index=True) - date = Column(DateTime, nullable=False, index=True) - open = Column(Float, nullable=False) - high = Column(Float, nullable=False) - low = Column(Float, nullable=False) - close = Column(Float, nullable=False) - volume = Column(Float, nullable=False) - amount = Column(Float, nullable=False) - # 添加技术指标列 - ma13 = Column(Float) - ma21 = Column(Float) - ma34 = Column(Float) - ma55 = Column(Float) - ma89 = Column(Float) - ma144 = Column(Float) - ma233 = Column(Float) - ma5 = Column(Float) - ma10 = Column(Float) - ma60 = Column(Float) - ma250 = Column(Float) - -class Minute60Data(Base): - """60分钟线数据表模型""" - __tablename__ = 'minute60_data' - - id = Column(Integer, primary_key=True, autoincrement=True) - code = Column(String(10), nullable=False, index=True) - market = Column(Integer, nullable=False) - datetime = Column(DateTime, nullable=False, index=True) - date = Column(DateTime, nullable=False, index=True) - open = Column(Float, nullable=False) - high = Column(Float, nullable=False) - low = Column(Float, nullable=False) - close = Column(Float, nullable=False) - volume = Column(Float, nullable=False) - amount = Column(Float, nullable=False) - # 添加技术指标列 - ma13 = Column(Float) - ma21 = Column(Float) - ma34 = Column(Float) - ma55 = Column(Float) - ma89 = Column(Float) - ma144 = Column(Float) - ma233 = Column(Float) - ma5 = Column(Float) - ma10 = Column(Float) - ma60 = Column(Float) - ma250 = Column(Float) - -class StockInfo(Base): - """股票信息表模型""" - __tablename__ = 'stock_info' - - id = Column(Integer, primary_key=True) - code = Column(String(10), unique=True, index=True) - name = Column(String(50)) - market = Column(Integer) - -# 允许写入的表名白名单 -_VALID_TABLES = frozenset({ - 'daily_data', 'minute5_data', 'minute15_data', 'minute30_data', 'minute60_data', - 'stock_info', 'block_stock_relation', -}) - - -class DataStorage: - """数据存储类""" - - def __init__( - self, - db_url: Optional[str] = None, - csv_path: Optional[str] = None - ) -> None: - """初始化数据存储 - - Args: - db_url: 数据库连接URL,如果为None则使用配置中的URL - csv_path: CSV文件保存路径,如果为None则使用配置中的路径 - """ - self.db_url = db_url or config.database_url - self.csv_path = csv_path or config.csv_output_path - - # 确保CSV输出目录存在 - if self.csv_path: - os.makedirs(self.csv_path, exist_ok=True) - - # 初始化数据库连接 - if self.db_url: - self.engine = create_engine(self.db_url) - Base.metadata.create_all(self.engine) - self.Session = sessionmaker(bind=self.engine) - - def save_to_csv(self, df: pd.DataFrame, filename: str) -> Optional[str]: - """保存数据到CSV文件 - - Args: - df: 要保存的DataFrame - filename: 文件名(不包含路径和扩展名) - - Returns: - str: 保存的文件路径,如果没有数据则返回None - """ - if df.empty: - logger.warning(f"没有数据可保存到 {filename}.csv") - return None - - file_path = Path(self.csv_path) / f"{filename}.csv" - df.to_csv(file_path, index=False, encoding='utf-8') - logger.info(f"数据已保存到: {file_path}") - return str(file_path) - - def get_latest_datetime( - self, - table_name: str, - date_column: str = 'datetime' - ) -> Optional[dt]: - """获取表中最新的日期/时间 - - Args: - table_name: 表名 - date_column: 日期列名,默认为 'datetime',日线数据应使用 'date' - - Returns: - Optional[datetime]: 最新的日期/时间,如果表为空则返回 None - """ - try: - with self.engine.connect() as conn: - result = conn.execute( - text(f"SELECT MAX({date_column}) FROM {table_name}") - ) - row = result.fetchone() - if row and row[0]: - return row[0] - return None - except Exception as e: - logger.debug(f"获取表 {table_name} 最新日期时出错: {e}") - return None - - def get_latest_datetime_by_code( - self, - table_name: str, - code: str - ) -> Optional[dt]: - """获取指定股票的最新 datetime - - Args: - table_name: 表名 - code: 股票代码 - - Returns: - Optional[datetime]: 最新的 datetime,如果没有数据则返回 None - """ - try: - with self.engine.connect() as conn: - result = conn.execute( - text(f"SELECT MAX(datetime) FROM {table_name} WHERE code = :code"), - {"code": code} - ) - row = result.fetchone() - if row and row[0]: - return row[0] - return None - except Exception as e: - logger.debug(f"获取表 {table_name} 股票 {code} 最新日期时出错: {e}") - return None - - def save_incremental( - self, - df: pd.DataFrame, - table_name: str, - conflict_columns: Tuple[str, ...] = ('code', 'datetime'), - batch_size: int = 10000 - ) -> int: - """增量保存数据,跳过重复记录 - - 使用 ON CONFLICT DO NOTHING 策略,需要先在数据库中添加唯一约束 - - Args: - df: 要保存的 DataFrame - table_name: 表名 - conflict_columns: 唯一约束列,默认为 ('code', 'datetime') - batch_size: 批处理大小 - - Returns: - int: 实际插入的行数 - """ - if table_name not in _VALID_TABLES: - raise ValueError(f"不允许写入的表名: {table_name}") - - if df.empty: - logger.warning(f"没有数据可保存到表 {table_name}") - return 0 - - total_rows = len(df) - - # 确保 datetime 不是索引而是列 - df_to_save = df.copy() - if df_to_save.index.name == 'datetime' or isinstance(df_to_save.index, pd.DatetimeIndex): - df_to_save = df_to_save.reset_index() - - # 获取列名 - columns = list(df_to_save.columns) - columns_str = ', '.join(columns) - db_type = config.db_type - - try: - if db_type == 'postgresql': - self._save_incremental_pg( - df_to_save, columns, columns_str, - table_name, conflict_columns, batch_size - ) - else: - # MySQL / SQLite: 走 SQLAlchemy executemany - placeholders = ', '.join([f':{col}' for col in columns]) - if db_type == 'mysql': - sql = text(f"INSERT IGNORE INTO {table_name} ({columns_str}) VALUES ({placeholders})") - elif db_type == 'sqlite': - sql = text(f"INSERT OR IGNORE INTO {table_name} ({columns_str}) VALUES ({placeholders})") - else: - raise ValueError(f"不支持的数据库类型: {db_type}") - - with self.engine.connect() as conn: - for i in range(0, total_rows, batch_size): - batch_df = df_to_save.iloc[i:i + batch_size].astype(object).where( - df_to_save.iloc[i:i + batch_size].notna(), None - ) - # 先转为 dict 列表,再转换 Timestamp → Python datetime - # 必须在 to_dict() 之后处理:若在 DataFrame 上赋值,pandas 会把 - # datetime 对象重新推断为 datetime64,to_dict() 又变回 Timestamp, - # 而 sqlite3 使用精确类型匹配,不识别 pd.Timestamp(datetime 子类) - records = batch_df.to_dict('records') - for record in records: - for key in record: - val = record[key] - if isinstance(val, pd.Timestamp): - record[key] = val.to_pydatetime() - elif val is pd.NaT: - record[key] = None - conn.execute(sql, records) - conn.commit() - - logger.info(f"增量保存完成: 共处理 {total_rows} 条到表 {table_name}(重复数据已跳过)") - return total_rows - - except Exception as e: - logger.error(f"增量保存数据到表 {table_name} 时出错: {e}") - return 0 - - def _save_incremental_pg( - self, - df: pd.DataFrame, - columns: list, - columns_str: str, - table_name: str, - conflict_columns: Tuple[str, ...], - batch_size: int, - ) -> None: - """PostgreSQL 专用:使用 execute_values 真正批量插入 - - 一次网络往返插入整批数据,比 executemany 快 10-100x。 - """ - from psycopg2.extras import execute_values - - conflict_str = ', '.join(conflict_columns) - sql = f"INSERT INTO {table_name} ({columns_str}) VALUES %s ON CONFLICT ({conflict_str}) DO NOTHING" - - # DataFrame → list of tuples,NaN/NaT → None(psycopg2 需要 None 表示 NULL) - df_clean = df.astype(object).where(df.notna(), None) - values = list(df_clean.itertuples(index=False, name=None)) - - raw_conn = self.engine.raw_connection() - try: - cursor = raw_conn.cursor() - execute_values(cursor, sql, values, page_size=batch_size) - raw_conn.commit() - finally: - raw_conn.close() - - def save_to_database( - self, - df: pd.DataFrame, - table_name: str, - batch_size: int = 10000 - ) -> bool: - """保存数据到数据库 - - Args: - df: 要保存的DataFrame - table_name: 表名 - batch_size: 批处理大小,默认10000条记录 - - Returns: - bool: 是否保存成功 - """ - if df.empty: - logger.warning(f"没有数据可保存到表 {table_name}") - return False - - try: - # 获取数据总量 - total_rows = len(df) - - logger.debug(f"开始保存数据到数据库表: {table_name}, 共 {total_rows} 条记录") - - # 如果数据量小于批处理大小,直接保存 - if total_rows <= batch_size: - df.to_sql(table_name, self.engine, if_exists='append', index=False) - logger.info(f"数据已保存到数据库表: {table_name}") - return True - - # 数据量大,分批处理 - logger.info(f"数据量较大({total_rows}条),开始分批保存到数据库表: {table_name}") - - # 计算批次数 - num_batches = (total_rows + batch_size - 1) // batch_size - - # 创建进度条 - iterator = tqdm(range(num_batches), desc="保存到数据库") if config.use_tqdm else range(num_batches) - - # 确保datetime不是索引而是列 - df_to_save = df.copy() - if df_to_save.index.name == 'datetime' or isinstance(df_to_save.index, pd.DatetimeIndex): - df_to_save = df_to_save.reset_index() - - # 分批保存 - for i in iterator: - start_idx = i * batch_size - end_idx = min((i + 1) * batch_size, total_rows) - batch_df = df_to_save.iloc[start_idx:end_idx] - - # 保存当前批次 - # 使用正确的方法检查表是否存在 - from sqlalchemy import inspect - inspector = inspect(self.engine) - if_exists = 'append' if i > 0 or inspector.has_table(table_name) else 'replace' - batch_df.to_sql(table_name, self.engine, if_exists=if_exists, index=False) - - if not config.use_tqdm: - logger.info(f"已保存 {end_idx}/{total_rows} 条记录到数据库表 {table_name}") - - logger.info(f"所有数据已成功保存到数据库表: {table_name}") - return True - except Exception as e: - logger.error(f"保存数据到数据库表 {table_name} 时出错: {e}") - return False - - def save_daily_data( - self, - df: pd.DataFrame, - to_csv: bool = True, - to_db: bool = True, - batch_size: int = 10000 - ) -> Tuple[Optional[str], bool]: - """保存日线数据 - - Args: - df: 日线数据DataFrame - to_csv: 是否保存到CSV - to_db: 是否保存到数据库 - batch_size: 批处理大小,默认10000条记录 - - Returns: - tuple: (csv_path, db_success) - """ - csv_path = None - db_success = False - - if to_csv: - csv_path = self.save_to_csv(df, 'daily_data') - - if to_db: - db_success = self.save_to_database(df, 'daily_data', batch_size=batch_size) - - return csv_path, db_success - - def save_minute_data( - self, - df: pd.DataFrame, - freq: int = 1, - to_csv: bool = True, - to_db: bool = True, - batch_size: int = 10000 - ) -> Tuple[Optional[str], bool]: - """保存分钟线数据 - - Args: - df: 分钟线数据DataFrame - freq: 分钟频率 - to_csv: 是否保存到CSV - to_db: 是否保存到数据库 - batch_size: 批处理大小,默认10000条记录 - - Returns: - tuple: (csv_path, db_success) - """ - csv_path = None - db_success = False - - if to_csv: - csv_path = self.save_to_csv(df, f'minute{freq}_data') - - if to_db: - db_success = self.save_to_database(df, f'minute{freq}_data', batch_size=batch_size) - - return csv_path, db_success - - def save_stock_info( - self, - df: pd.DataFrame, - to_csv: bool = True, - to_db: bool = True, - batch_size: int = 10000 - ) -> Tuple[Optional[str], bool]: - """保存股票信息 - - Args: - df: 股票信息DataFrame - to_csv: 是否保存到CSV - to_db: 是否保存到数据库 - batch_size: 批处理大小,默认10000条记录 - - Returns: - tuple: (csv_path, db_success) - """ - csv_path = None - db_success = False - - if to_csv: - csv_path = self.save_to_csv(df, 'stock_info') - - if to_db: - db_success = self.save_to_database(df, 'stock_info', batch_size=batch_size) - - return csv_path, db_success - - def save_block_relation( - self, - df: pd.DataFrame, - to_csv: bool = True, - to_db: bool = True, - batch_size: int = 10000 - ) -> Tuple[Optional[str], bool]: - """保存板块与股票的对应关系 - - Args: - df: 板块与股票对应关系DataFrame - to_csv: 是否保存到CSV - to_db: 是否保存到数据库 - batch_size: 批处理大小,默认10000条记录 - - Returns: - tuple: (csv_path, db_success) - """ - csv_path = None - db_success = False - - if to_csv: - csv_path = self.save_to_csv(df, 'block_stock_relation') - - if to_db: - db_success = self.save_to_database(df, 'block_stock_relation', batch_size=batch_size) - - return csv_path, db_success diff --git a/src/tdx2db/__init__.py b/src/tdx2db/__init__.py new file mode 100644 index 0000000..bcb1131 --- /dev/null +++ b/src/tdx2db/__init__.py @@ -0,0 +1,105 @@ +"""tdx2db: 从通达信本地文件同步 A 股日线数据到数据库。 + +基本用法:: + + from tdx2db import TdxDailySync + + sync = TdxDailySync(tdx_path="/opt/tdx", db_url="sqlite:///data.db") + sync.sync_all(adj_type='forward', workers=4) + sync.sync_stock('sz000001', start_date=20240101) + df = sync.get_daily('sz000001', start_date=20240101) +""" + +from typing import Optional +import pandas as pd + +from .config import Config, config +from .reader import TdxDataReader +from .processor import DataProcessor +from .storage import DataStorage +from .logger import logger + +__version__ = "0.2.0" +__all__ = ['TdxDailySync', 'TdxDataReader', 'DataProcessor', 'DataStorage', 'Config'] + + +class TdxDailySync: + """高层封装,供外部项目调用。""" + + def __init__( + self, + tdx_path: Optional[str] = None, + db_url: Optional[str] = None, + ) -> None: + if tdx_path: + config.tdx_path = tdx_path + self.reader = TdxDataReader(tdx_path) + self.processor = DataProcessor() + self.storage = DataStorage(db_url) + + def sync_all( + self, + adj_type: str = 'forward', + incremental: bool = True, + start_date: Optional[int] = None, + end_date: Optional[int] = None, + ) -> dict: + """同步所有 A 股日线数据。 + + Returns: + {'total': N, 'success': N, 'failed': N} + """ + from .cli import sync_all_daily + gbbq = self.reader.read_gbbq() + return sync_all_daily( + self.reader, self.processor, self.storage, gbbq, + adj_type=adj_type, incremental=incremental, + start_date=start_date, end_date=end_date, + ) + + def sync_stock( + self, + code: str, + adj_type: str = 'forward', + start_date: Optional[int] = None, + end_date: Optional[int] = None, + ) -> int: + """同步单只股票日线数据,返回写入行数。""" + market = 1 if code.startswith('sh') else 0 + gbbq = self.reader.read_gbbq() + data = self.reader.read_daily_data(market, code) + if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('date', 'datetime'): + data = data.reset_index() + processed = self.processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type) + processed = self.processor.filter_data(processed, start_date=start_date, end_date=end_date) + if processed.empty: + return 0 + return self.storage.save_incremental(processed, 'daily_data', conflict_columns=('code', 'date')) + + def sync_stock_list(self) -> int: + """同步股票列表到 stock_info 表,返回股票数量。""" + stocks = self.reader.get_stock_list() + self.storage.save_stock_info(stocks) + return len(stocks) + + def get_daily( + self, + code: str, + start_date: Optional[int] = None, + end_date: Optional[int] = None, + ) -> pd.DataFrame: + """从数据库查询日线数据,date 列为 YYYYMMDD 整数。""" + from sqlalchemy import text + pure_code = code[-6:] if len(code) > 6 else code + conditions = ["code = :code"] + params: dict = {"code": pure_code} + if start_date: + conditions.append("date >= :start_date") + params["start_date"] = start_date + if end_date: + conditions.append("date <= :end_date") + params["end_date"] = end_date + where = " AND ".join(conditions) + sql = text(f"SELECT * FROM daily_data WHERE {where} ORDER BY date") + with self.storage.engine.connect() as conn: + return pd.read_sql(sql, conn, params=params) diff --git a/src/tdx2db/cli.py b/src/tdx2db/cli.py new file mode 100644 index 0000000..ded0b95 --- /dev/null +++ b/src/tdx2db/cli.py @@ -0,0 +1,212 @@ +import argparse +import sys +from argparse import Namespace +from typing import Optional + +import pandas as pd +from tqdm import tqdm + +from .reader import TdxDataReader +from .processor import DataProcessor +from .storage import DataStorage +from .config import config +from .logger import logger + + +def _has_ex_rights_after(code: str, gbbq: pd.DataFrame, last_date: int) -> bool: + """检查该股票在 last_date 之后是否有除权事件(category==1)。""" + if gbbq is None or gbbq.empty: + return False + prefix = 'sh' if code.startswith('6') else 'sz' + full_code = prefix + code.zfill(6) + events = gbbq[ + (gbbq['full_code'] == full_code) & + (gbbq['category'] == 1) & + (gbbq['datetime'] > last_date) + ] + return not events.empty + + +def sync_all_daily( + reader: TdxDataReader, + processor: DataProcessor, + storage: DataStorage, + gbbq: pd.DataFrame, + adj_type: str = 'forward', + incremental: bool = True, + start_date: Optional[int] = None, + end_date: Optional[int] = None, +) -> dict: + """逐股票流式同步日线数据,返回统计信息。""" + stocks = reader.get_stock_list() + logger.info(f"共 {len(stocks)} 只股票") + + latest_dates = storage.get_all_latest_dates() if incremental else {} + stats = {'total': len(stocks), 'success': 0, 'failed': 0} + + iterator = tqdm(stocks.iterrows(), total=len(stocks), desc="同步日线") if config.use_tqdm else stocks.iterrows() + + for _, stock in iterator: + code = stock['code'] + market = 1 if code.startswith('sh') else 0 + pure_code = code[-6:] if len(code) > 6 else code + last_date = latest_dates.get(pure_code) + + try: + data = reader.read_daily_data(market, code) + if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('date', 'datetime'): + data = data.reset_index() + if data.empty: + stats['success'] += 1 + continue + + needs_refresh = ( + incremental and last_date is not None and + _has_ex_rights_after(pure_code, gbbq, last_date) + ) + + processed = processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type) + + if incremental and last_date and not needs_refresh: + processed = processed[processed['date'] > last_date] + if start_date: + processed = processed[processed['date'] >= start_date] + if end_date: + processed = processed[processed['date'] <= end_date] + + if processed.empty: + stats['success'] += 1 + continue + + if needs_refresh: + storage.delete_stock_data(pure_code) + storage.save_incremental(processed, 'daily_data', conflict_columns=('code', 'date'), + batch_size=config.db_batch_size) + stats['success'] += 1 + + except FileNotFoundError: + stats['failed'] += 1 + except Exception as e: + logger.error(f"处理 {code} 时出错: {e}") + stats['failed'] += 1 + + logger.info(f"同步完成: 成功 {stats['success']},失败 {stats['failed']}") + return stats + + +def parse_args() -> Namespace: + parser = argparse.ArgumentParser(description='tdx2db 日线数据同步工具') + + parser.add_argument('--tdx-path', help='通达信安装目录') + parser.add_argument('--db-type', choices=['sqlite', 'mysql', 'postgresql']) + parser.add_argument('--db-host') + parser.add_argument('--db-port') + parser.add_argument('--db-name') + parser.add_argument('--db-user') + parser.add_argument('--db-password') + parser.add_argument('--no-tqdm', action='store_true') + parser.add_argument('--batch-size', type=int, default=10000) + + subparsers = parser.add_subparsers(dest='command') + + # stock-list + subparsers.add_parser('stock-list', help='同步股票列表') + + # daily + daily = subparsers.add_parser('daily', help='同步日线数据') + daily.add_argument('--code', help='股票代码(含市场前缀,如 sz000001),不指定则全量') + daily.add_argument('--start', type=int, help='开始日期 YYYYMMDD') + daily.add_argument('--end', type=int, help='结束日期 YYYYMMDD') + daily.add_argument('--adj', choices=['forward', 'backward', 'none'], default='forward') + daily.add_argument('--incremental', action='store_true', help='增量模式') + + # sync + sync = subparsers.add_parser('sync', help='一键增量同步日线数据') + sync.add_argument('--adj', choices=['forward', 'backward', 'none'], default='forward') + + return parser.parse_args() + + +def update_config(args: Namespace) -> None: + if args.tdx_path: + config.tdx_path = args.tdx_path + if args.db_type: + config.db_type = args.db_type + if args.db_host: + config.db_host = args.db_host + if args.db_port: + config.db_port = args.db_port + if args.db_name: + config.db_name = args.db_name + if args.db_user: + config.db_user = args.db_user + if args.db_password: + config.db_password = args.db_password + if args.batch_size: + config.db_batch_size = args.batch_size + if args.no_tqdm: + config.use_tqdm = False + + +def main() -> int: + args = parse_args() + update_config(args) + + try: + reader = TdxDataReader() + except (ValueError, FileNotFoundError) as e: + logger.error(f"初始化失败: {e}") + return 1 + + storage = DataStorage() + processor = DataProcessor() + + if args.command == 'stock-list': + try: + stocks = reader.get_stock_list() + logger.info(f"获取到 {len(stocks)} 只股票") + storage.save_stock_info(stocks) + except Exception as e: + logger.error(f"同步股票列表出错: {e}") + return 1 + + elif args.command == 'daily': + adj_type = getattr(args, 'adj', 'forward') + gbbq = reader.read_gbbq() + + if args.code: + code = args.code + market = 1 if code.startswith('sh') else 0 + try: + data = reader.read_daily_data(market, code) + if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('date', 'datetime'): + data = data.reset_index() + processed = processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type) + processed = processor.filter_data(processed, start_date=args.start, end_date=args.end) + if not processed.empty: + storage.save_incremental(processed, 'daily_data', conflict_columns=('code', 'date')) + except Exception as e: + logger.error(f"同步 {args.code} 出错: {e}") + return 1 + else: + incremental = getattr(args, 'incremental', False) + sync_all_daily(reader, processor, storage, gbbq, + adj_type=adj_type, incremental=incremental, + start_date=args.start, end_date=args.end) + + elif args.command == 'sync': + adj_type = getattr(args, 'adj', 'forward') + logger.info("=== 开始增量同步日线数据 ===") + gbbq = reader.read_gbbq() + sync_all_daily(reader, processor, storage, gbbq, + adj_type=adj_type, incremental=True) + + else: + logger.error("请指定子命令,使用 -h 查看帮助") + return 1 + + return 0 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/src/config.py b/src/tdx2db/config.py similarity index 71% rename from src/config.py rename to src/tdx2db/config.py index b5bd197..251dfd5 100644 --- a/src/config.py +++ b/src/tdx2db/config.py @@ -1,23 +1,10 @@ -"""配置管理模块 - -负责加载和管理程序的配置参数,包括: -- 通达信数据路径 -- 数据库连接信息 -- 输出CSV路径 -- 其他配置选项 -""" - import os -from pathlib import Path from dotenv import load_dotenv -# 加载.env文件中的环境变量 load_dotenv() class Config: - """配置类""" - tdx_path: str csv_output_path: str db_type: str @@ -30,28 +17,19 @@ class Config: use_tqdm: bool def __init__(self) -> None: - """初始化配置""" - # 通达信安装路径 self.tdx_path = os.getenv('TDX_PATH', '') - - # CSV输出路径 self.csv_output_path = os.getenv('CSV_OUTPUT_PATH', 'output') - - # 数据库配置 - self.db_type = os.getenv('DB_TYPE', 'postgresql') + self.db_type = os.getenv('DB_TYPE', 'sqlite') self.db_host = os.getenv('DB_HOST', 'localhost') self.db_port = os.getenv('DB_PORT', '5432') self.db_name = os.getenv('DB_NAME', 'tdx_data') self.db_user = os.getenv('DB_USER', 'postgres') self.db_password = os.getenv('DB_PASSWORD', '') self.db_batch_size = int(os.getenv('DB_BATCH_SIZE', '10000')) - - # 是否使用进度条 self.use_tqdm = os.getenv('USE_TQDM', 'True').lower() == 'true' @property - def database_url(self): - """获取数据库连接URL""" + def database_url(self) -> str: if self.db_type == 'postgresql': return f"postgresql://{self.db_user}:{self.db_password}@{self.db_host}:{self.db_port}/{self.db_name}" elif self.db_type == 'mysql': @@ -61,5 +39,5 @@ def database_url(self): else: raise ValueError(f"不支持的数据库类型: {self.db_type}") -# 创建全局配置实例 + config = Config() diff --git a/src/logger.py b/src/tdx2db/logger.py similarity index 100% rename from src/logger.py rename to src/tdx2db/logger.py diff --git a/src/tdx2db/processor.py b/src/tdx2db/processor.py new file mode 100644 index 0000000..5c3364e --- /dev/null +++ b/src/tdx2db/processor.py @@ -0,0 +1,207 @@ +from typing import Optional, List + +import pandas as pd + +from .logger import logger + + +class DataProcessor: + + @staticmethod + def _validate_ohlcv(df: pd.DataFrame) -> pd.DataFrame: + required = ['open', 'high', 'low', 'close'] + if not all(c in df.columns for c in required): + return df + before = len(df) + positive = (df[required] > 0).all(axis=1) + ohlc_ok = ( + (df['high'] >= df[['open', 'close']].max(axis=1)) & + (df['low'] <= df[['open', 'close']].min(axis=1)) + ) + df = df[positive & ohlc_ok] + dropped = before - len(df) + if dropped > 0: + logger.warning(f"数据校验丢弃 {dropped} 条不合格记录") + return df + + @staticmethod + def apply_forward_adj(df: pd.DataFrame, gbbq: pd.DataFrame) -> pd.DataFrame: + """前复权:对每个除权日,将该日之前的历史价格乘以复权因子。""" + if gbbq.empty or df.empty: + df = df.copy() + df['adj_factor'] = 1.0 + return df + + market_val = df['market'].iloc[0] + prefix = 'sz' if market_val == 0 else 'sh' + full_code = prefix + str(df['code'].iloc[0]).zfill(6) + events = gbbq[gbbq['full_code'] == full_code].copy() + + df = df.copy() + if events.empty: + df['adj_factor'] = 1.0 + return df + + events['ex_date'] = pd.to_datetime( + events['datetime'].astype(str).str[:8], format='%Y%m%d' + ) + data_start = df['date'].min() + events = events[ + (events['category'] == 1) & (events['ex_date'] > data_start) + ].sort_values('ex_date') + + df = df.sort_values('date').copy() + df['adj_factor'] = 1.0 + + for _, ev in events.iterrows(): + ex_date = ev['ex_date'] + songgu = float(ev.get('songgu_qianzongguben', 0) or 0) / 10.0 + hongli = float(ev.get('hongli_panqianliutong', 0) or 0) / 10.0 + peigujia = float(ev.get('peigujia_qianzongguben', 0) or 0) + peigu = float(ev.get('peigu_houzongguben', 0) or 0) / 10.0 + + before = df[df['date'] < ex_date] + if before.empty: + continue + prev_close = float(before['close'].iloc[-1]) + if prev_close <= 0: + continue + denominator = prev_close * (1 + songgu + peigu) + if denominator <= 0: + continue + factor = (prev_close - hongli + peigujia * peigu) / denominator + if factor <= 0 or factor > 2: + logger.warning(f"{full_code} 除权日 {ex_date} 复权因子异常({factor:.4f}),已跳过") + continue + df.loc[df['date'] < ex_date, 'adj_factor'] *= factor + + for col in ['open', 'high', 'low', 'close']: + df[col] = (df[col] * df['adj_factor']).round(3) + return df + + @staticmethod + def apply_backward_adj(df: pd.DataFrame, gbbq: pd.DataFrame) -> pd.DataFrame: + """后复权:对每个除权日,将该日及之后的价格乘以 1/factor。""" + if gbbq.empty or df.empty: + df = df.copy() + df['adj_factor'] = 1.0 + return df + + market_val = df['market'].iloc[0] + prefix = 'sz' if market_val == 0 else 'sh' + full_code = prefix + str(df['code'].iloc[0]).zfill(6) + events = gbbq[gbbq['full_code'] == full_code].copy() + + df = df.copy() + if events.empty: + df['adj_factor'] = 1.0 + return df + + events['ex_date'] = pd.to_datetime( + events['datetime'].astype(str).str[:8], format='%Y%m%d' + ) + data_start = df['date'].min() + events = events[ + (events['category'] == 1) & (events['ex_date'] > data_start) + ].sort_values('ex_date') + + df = df.sort_values('date').copy() + df['adj_factor'] = 1.0 + + for _, ev in events.iterrows(): + ex_date = ev['ex_date'] + songgu = float(ev.get('songgu_qianzongguben', 0) or 0) / 10.0 + hongli = float(ev.get('hongli_panqianliutong', 0) or 0) / 10.0 + peigujia = float(ev.get('peigujia_qianzongguben', 0) or 0) + peigu = float(ev.get('peigu_houzongguben', 0) or 0) / 10.0 + + before = df[df['date'] < ex_date] + if before.empty: + continue + prev_close = float(before['close'].iloc[-1]) + if prev_close <= 0: + continue + denominator = prev_close * (1 + songgu + peigu) + if denominator <= 0: + continue + factor = (prev_close - hongli + peigujia * peigu) / denominator + if factor <= 0 or factor > 2: + logger.warning(f"{full_code} 除权日 {ex_date} 复权因子异常({factor:.4f}),已跳过") + continue + df.loc[df['date'] >= ex_date, 'adj_factor'] *= (1.0 / factor) + + for col in ['open', 'high', 'low', 'close']: + df[col] = (df[col] * df['adj_factor']).round(3) + return df + + @staticmethod + def process_daily_data( + df: pd.DataFrame, + gbbq: pd.DataFrame = None, + adj_type: str = 'forward' + ) -> pd.DataFrame: + """日线处理主流程:reset_index → 填充缺失值 → 校验 → 复权 → 日期转 YYYYMMDD 整数。""" + if df.empty: + return df + + processed = df.copy() + + # 确保 date 是列而非索引 + if isinstance(processed.index, pd.DatetimeIndex) or processed.index.name in ('date', 'datetime'): + processed = processed.reset_index() + + # 统一列名:date 列 + if 'date' not in processed.columns and 'datetime' in processed.columns: + processed.rename(columns={'datetime': 'date'}, inplace=True) + + # 确保 date 是 datetime 类型(复权逻辑依赖 Timestamp 比较) + if 'date' in processed.columns and not pd.api.types.is_datetime64_any_dtype(processed['date']): + processed['date'] = pd.to_datetime(processed['date']) + + # 填充缺失值 + for col in ['open', 'high', 'low', 'close', 'volume', 'amount']: + if col in processed.columns: + processed[col] = processed[col].ffill() + + # 数据校验 + processed = DataProcessor._validate_ohlcv(processed) + + # 复权 + if gbbq is not None and not gbbq.empty and 'date' in processed.columns: + if adj_type == 'forward': + processed = DataProcessor.apply_forward_adj(processed, gbbq) + elif adj_type == 'backward': + processed = DataProcessor.apply_backward_adj(processed, gbbq) + else: + processed['adj_factor'] = 1.0 + elif 'adj_factor' not in processed.columns: + processed['adj_factor'] = 1.0 + + # 日期转 YYYYMMDD 整数 + processed['date'] = processed['date'].dt.strftime('%Y%m%d').astype(int) + + # 预留 turnover_rate 列 + if 'turnover_rate' not in processed.columns: + processed['turnover_rate'] = None + + return processed + + @staticmethod + def filter_data( + df: pd.DataFrame, + start_date: Optional[int] = None, + end_date: Optional[int] = None, + codes: Optional[List[str]] = None + ) -> pd.DataFrame: + """按 YYYYMMDD 整数日期和股票代码筛选。""" + if df.empty: + return df + result = df.copy() + if 'date' in result.columns: + if start_date: + result = result[result['date'] >= start_date] + if end_date: + result = result[result['date'] <= end_date] + if codes and 'code' in result.columns: + result = result[result['code'].isin(codes)] + return result diff --git a/src/tdx2db/reader.py b/src/tdx2db/reader.py new file mode 100644 index 0000000..78b9496 --- /dev/null +++ b/src/tdx2db/reader.py @@ -0,0 +1,102 @@ +import re +import struct +from pathlib import Path +from typing import Optional + +import pandas as pd +from pytdx.reader import TdxDailyBarReader, GbbqReader + +from .config import config +from .logger import logger + + +class TdxDataReader: + def __init__(self, tdx_path: Optional[str] = None) -> None: + self.tdx_path = Path(tdx_path or config.tdx_path) + if not self.tdx_path: + raise ValueError("通达信数据路径未设置,请在 .env 中设置 TDX_PATH") + if not self.tdx_path.exists(): + raise FileNotFoundError(f"通达信数据路径不存在: {self.tdx_path}") + self.daily_reader = TdxDailyBarReader() + self.gbbq_reader = GbbqReader() + + def read_gbbq(self) -> pd.DataFrame: + """读取权息文件,返回全量权息 DataFrame。文件不存在时返回空 DataFrame。""" + gbbq_path = self.tdx_path / 'T0002' / 'hq_cache' / 'gbbq' + if not gbbq_path.exists(): + logger.warning(f"权息文件不存在: {gbbq_path},将跳过复权处理") + return pd.DataFrame() + try: + df = self.gbbq_reader.get_df(str(gbbq_path)) + if df.empty: + return pd.DataFrame() + market_prefix = df['market'].map({0: 'sz', 1: 'sh'}) + df['full_code'] = market_prefix + df['code'].astype(str).str.zfill(6) + return df + except Exception as e: + logger.warning(f"读取权息文件时出错: {e},将跳过复权处理") + return pd.DataFrame() + + def get_stock_list(self) -> pd.DataFrame: + """扫描本地 .day 文件获取 A 股股票列表(含市场前缀代码)。""" + sz_path = self.tdx_path / 'vipdoc' / 'sz' / 'lday' + sh_path = self.tdx_path / 'vipdoc' / 'sh' / 'lday' + if not (sz_path.exists() or sh_path.exists()): + raise FileNotFoundError("无法找到股票数据目录") + + stocks = [] + if sz_path.exists(): + for f in sz_path.glob('*.day'): + code = f.stem + pure = code[-6:].zfill(6) + if re.match(r'^(000|001|002|300)\d{3}$', pure): + stocks.append({'code': code, 'name': f'深A{code}'}) + + if sh_path.exists(): + for f in sh_path.glob('*.day'): + code = f.stem + pure = code[-6:].zfill(6) + if re.match(r'^(60\d{4}|688\d{3})$', pure): + stocks.append({'code': code, 'name': f'上A{code}'}) + + if not stocks: + raise FileNotFoundError("未找到任何股票数据文件") + return pd.DataFrame(stocks, columns=['code', 'name']) + + def read_daily_data(self, market: int, code: str) -> pd.DataFrame: + """读取单只股票日线数据,返回含 code/market 列的 DataFrame(date 为 DatetimeIndex)。""" + market_folder = 'sz' if market == 0 else 'sh' + pure_code = code[-6:] if len(code) > 6 else code + file_path = self.tdx_path / 'vipdoc' / market_folder / 'lday' / f"{market_folder}{pure_code}.day" + + if not file_path.exists(): + raise FileNotFoundError(f"日线数据文件不存在: {file_path}") + + sec_type = self.daily_reader.get_security_type(str(file_path)) + if sec_type in self.daily_reader.SECURITY_TYPE: + data = self.daily_reader.get_df(str(file_path)) + else: + data = self._read_day_file_raw(str(file_path)) + + data['code'] = pure_code + data['market'] = market + return data + + @staticmethod + def _read_day_file_raw(fname: str) -> pd.DataFrame: + """直接解析 .day 二进制文件(用于科创板等 pytdx 不支持的证券类型)。""" + rows = [] + with open(fname, 'rb') as f: + content = f.read() + record_size = struct.calcsize(' None: + self.db_url = db_url or config.database_url + self._db_type = self.db_url.split('://')[0].split('+')[0] + self.engine = create_engine(self.db_url) + Base.metadata.create_all(self.engine) + self.Session = sessionmaker(bind=self.engine) + + def get_latest_date_by_code(self, code: str) -> Optional[int]: + """返回指定股票在 daily_data 中最新的 YYYYMMDD 整数日期,无数据返回 None。""" + try: + with self.engine.connect() as conn: + row = conn.execute( + text("SELECT MAX(date) FROM daily_data WHERE code = :code"), + {"code": code} + ).fetchone() + return int(row[0]) if row and row[0] is not None else None + except Exception as e: + logger.debug(f"查询 {code} 最新日期出错: {e}") + return None + + def get_all_latest_dates(self) -> dict: + """一次查询返回所有股票最新日期 {code: YYYYMMDD int}。""" + try: + with self.engine.connect() as conn: + rows = conn.execute( + text("SELECT code, MAX(date) FROM daily_data GROUP BY code") + ).fetchall() + return {r[0]: int(r[1]) for r in rows if r[1] is not None} + except Exception as e: + logger.debug(f"查询所有股票最新日期出错: {e}") + return {} + + def delete_stock_data(self, code: str) -> None: + """删除某只股票全部日线记录(除权后全量重写前调用)。""" + try: + with self.engine.connect() as conn: + conn.execute( + text("DELETE FROM daily_data WHERE code = :code"), + {"code": code} + ) + conn.commit() + except Exception as e: + logger.error(f"删除 {code} 数据出错: {e}") + + def save_incremental( + self, + df: pd.DataFrame, + table_name: str, + conflict_columns: Tuple[str, ...] = ('code', 'date'), + batch_size: int = 10000 + ) -> int: + """增量保存,跳过重复记录(ON CONFLICT DO NOTHING / INSERT OR IGNORE / INSERT IGNORE)。""" + if table_name not in _VALID_TABLES: + raise ValueError(f"不允许写入的表名: {table_name}") + if df.empty: + return 0 + + df_to_save = df.copy() + if isinstance(df_to_save.index, pd.DatetimeIndex) or df_to_save.index.name in ('date', 'datetime'): + df_to_save = df_to_save.reset_index(drop=True) + + columns = list(df_to_save.columns) + columns_str = ', '.join(columns) + total_rows = len(df_to_save) + + try: + if self._db_type == 'postgresql': + self._save_incremental_pg(df_to_save, columns, columns_str, table_name, conflict_columns, batch_size) + else: + placeholders = ', '.join([f':{c}' for c in columns]) + if self._db_type == 'mysql': + sql = text(f"INSERT IGNORE INTO {table_name} ({columns_str}) VALUES ({placeholders})") + else: # sqlite + sql = text(f"INSERT OR IGNORE INTO {table_name} ({columns_str}) VALUES ({placeholders})") + + with self.engine.connect() as conn: + for i in range(0, total_rows, batch_size): + batch = df_to_save.iloc[i:i + batch_size].astype(object).where( + df_to_save.iloc[i:i + batch_size].notna(), None + ) + records = batch.to_dict('records') + for rec in records: + for k, v in rec.items(): + if isinstance(v, pd.Timestamp): + rec[k] = v.to_pydatetime() + elif v is pd.NaT: + rec[k] = None + conn.execute(sql, records) + conn.commit() + + logger.debug(f"增量保存完成: {total_rows} 条 → {table_name}(重复已跳过)") + return total_rows + except Exception as e: + logger.error(f"增量保存到 {table_name} 出错: {e}") + return 0 + + def _save_incremental_pg(self, df, columns, columns_str, table_name, conflict_columns, batch_size): + from psycopg2.extras import execute_values + conflict_str = ', '.join(conflict_columns) + sql = f"INSERT INTO {table_name} ({columns_str}) VALUES %s ON CONFLICT ({conflict_str}) DO NOTHING" + df_clean = df.astype(object).where(df.notna(), None) + values = list(df_clean.itertuples(index=False, name=None)) + raw_conn = self.engine.raw_connection() + try: + cur = raw_conn.cursor() + execute_values(cur, sql, values, page_size=batch_size) + raw_conn.commit() + finally: + raw_conn.close() + + def save_stock_info(self, df: pd.DataFrame) -> bool: + """保存股票列表到 stock_info 表(增量,跳过重复)。""" + return self.save_incremental(df, 'stock_info', conflict_columns=('code',)) > 0 + + def save_to_csv(self, df: pd.DataFrame, filename: str, csv_path: Optional[str] = None) -> Optional[str]: + path = Path(csv_path or config.csv_output_path) + os.makedirs(path, exist_ok=True) + file_path = path / f"{filename}.csv" + df.to_csv(file_path, index=False, encoding='utf-8') + logger.info(f"数据已保存到: {file_path}") + return str(file_path) diff --git a/tests/test_daily.py b/tests/test_daily.py new file mode 100644 index 0000000..8ac924f --- /dev/null +++ b/tests/test_daily.py @@ -0,0 +1,248 @@ +""" +日线数据同步测试套件(使用 SQLite 内存库 + mock TDX reader) + +测试用例: +1. test_full_sync_one_month - 全量所有股票,1个月日期范围 +2. test_single_stock_one_year - 指定股票,最近1年 +3. test_forward_adj_price - 前复权价格计算正确性 +4. test_incremental_update - 增量更新无重复,有除权时旧数据被替换 +""" + +import threading +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + +from src.tdx2db.processor import DataProcessor +from src.tdx2db.storage import DataStorage + + +# ─── 测试数据工厂 ──────────────────────────────────────────────────────────── + +def make_daily_df(code: str, market: int, start_date: str, periods: int) -> pd.DataFrame: + """生成假日线数据,date 列为 DatetimeIndex(模拟 reader 返回格式)。""" + dates = pd.bdate_range(start=start_date, periods=periods) + df = pd.DataFrame({ + 'open': [10.0] * periods, + 'high': [11.0] * periods, + 'low': [9.0] * periods, + 'close': [10.5] * periods, + 'volume': [1e6] * periods, + 'amount': [1e7] * periods, + 'code': [code[-6:]] * periods, + 'market': [market] * periods, + }, index=dates) + df.index.name = 'date' + return df + + +def make_gbbq_empty() -> pd.DataFrame: + return pd.DataFrame() + + +def make_gbbq_with_event(full_code: str, ex_date_int: int) -> pd.DataFrame: + """构造一条除权记录(category=1,送股 10%)。""" + return pd.DataFrame([{ + 'market': 0 if full_code.startswith('sz') else 1, + 'code': int(full_code[2:]), + 'datetime': ex_date_int, + 'category': 1, + 'hongli_panqianliutong': 0, + 'peigujia_qianzongguben': 0, + 'songgu_qianzongguben': 1.0, # 每10股送1股 → songgu/10 = 0.1 + 'peigu_houzongguben': 0, + 'full_code': full_code, + }]) + + +# ─── 测试用例 ──────────────────────────────────────────────────────────────── + +class TestFullSyncOneMonth: + """全量获取所有股票1个月日线数据。""" + + def test_records_written_and_date_format(self, tmp_path): + db_url = f"sqlite:///{tmp_path}/test.db" + storage = DataStorage(db_url=db_url) + processor = DataProcessor() + + # 模拟3只股票,每只20个交易日(约1个月) + stocks = [ + ('sz000001', 0), ('sz000002', 0), ('sh600000', 1) + ] + gbbq = make_gbbq_empty() + + for code, market in stocks: + df = make_daily_df(code, market, '2024-03-01', 20) + df = df.reset_index() + processed = processor.process_daily_data(df, gbbq=gbbq, adj_type='none') + storage.save_incremental(processed, 'daily_data', conflict_columns=('code', 'date')) + + # 验证 + with storage.engine.connect() as conn: + from sqlalchemy import text + rows = conn.execute(text("SELECT COUNT(*) FROM daily_data")).fetchone() + assert rows[0] == 60, f"期望60条,实际{rows[0]}" + + # date 列应为 YYYYMMDD 整数 + sample = conn.execute(text("SELECT date FROM daily_data LIMIT 1")).fetchone() + assert isinstance(sample[0], int), f"date 应为整数,实际类型: {type(sample[0])}" + assert 20240301 <= sample[0] <= 20241231 + + +class TestSingleStockOneYear: + """指定股票最近1年日线数据。""" + + def test_date_range_correct(self, tmp_path): + db_url = f"sqlite:///{tmp_path}/test.db" + storage = DataStorage(db_url=db_url) + processor = DataProcessor() + + # 生成约250个交易日(1年) + df = make_daily_df('sz000001', 0, '2023-04-01', 250) + df = df.reset_index() + gbbq = make_gbbq_empty() + processed = processor.process_daily_data(df, gbbq=gbbq, adj_type='none') + + # 按日期过滤:只取 20240101 之后 + filtered = processor.filter_data(processed, start_date=20240101) + storage.save_incremental(filtered, 'daily_data', conflict_columns=('code', 'date')) + + with storage.engine.connect() as conn: + from sqlalchemy import text + rows = conn.execute( + text("SELECT MIN(date), MAX(date), COUNT(*) FROM daily_data WHERE code='000001'") + ).fetchone() + min_date, max_date, count = rows + assert min_date >= 20240101, f"最小日期 {min_date} 应 >= 20240101" + assert count > 0 + + +class TestForwardAdjPrice: + """前复权价格计算正确性:除权日前后价格应连续(无跳空)。""" + + def test_price_continuity_across_ex_date(self): + processor = DataProcessor() + + # 构造数据:2024-01-01 ~ 2024-03-31,共约60个交易日 + # 除权日设为 2024-02-01(送股10%,songgu=1.0 即每10股送1股) + df = make_daily_df('sz000001', 0, '2024-01-02', 60) + df = df.reset_index() + + # 所有原始收盘价均为 10.5 + gbbq = make_gbbq_with_event('sz000001', 20240201) + processed = processor.process_daily_data(df, gbbq=gbbq, adj_type='forward') + + ex_date = 20240201 + before = processed[processed['date'] < ex_date] + on_or_after = processed[processed['date'] >= ex_date] + + assert not before.empty, "除权日前应有数据" + assert not on_or_after.empty, "除权日后应有数据" + + # 前复权后,除权日前的价格应被调低(adj_factor < 1) + # 送股10%:factor = prev_close / (prev_close * 1.1) ≈ 0.909 + # 除权日前 close = 10.5 * 0.909 ≈ 9.545 + # 除权日后 close = 10.5(原始价格不变) + adj_close_before = before['close'].iloc[-1] + raw_close_after = on_or_after['close'].iloc[0] + + # 验证复权因子已应用(除权日前价格应低于原始价格) + assert adj_close_before < 10.5, f"前复权后除权日前收盘价应 < 10.5,实际 {adj_close_before}" + assert abs(raw_close_after - 10.5) < 0.01, f"除权日后收盘价应保持 10.5,实际 {raw_close_after}" + + # 验证 adj_factor 列存在 + assert 'adj_factor' in processed.columns + + def test_no_adj_factor_without_gbbq(self): + processor = DataProcessor() + df = make_daily_df('sz000001', 0, '2024-01-02', 10).reset_index() + processed = processor.process_daily_data(df, gbbq=None, adj_type='forward') + assert 'adj_factor' in processed.columns + assert (processed['adj_factor'] == 1.0).all() + + +class TestIncrementalUpdate: + """增量更新:无重复行;有除权时旧数据被替换。""" + + def test_no_duplicates_on_second_sync(self, tmp_path): + db_url = f"sqlite:///{tmp_path}/test.db" + storage = DataStorage(db_url=db_url) + processor = DataProcessor() + gbbq = make_gbbq_empty() + + # 第一次同步:20个交易日 + df1 = make_daily_df('sz000001', 0, '2024-01-02', 20).reset_index() + p1 = processor.process_daily_data(df1, gbbq=gbbq, adj_type='none') + storage.save_incremental(p1, 'daily_data', conflict_columns=('code', 'date')) + + # 第二次同步:同样的数据(模拟重复运行) + storage.save_incremental(p1, 'daily_data', conflict_columns=('code', 'date')) + + with storage.engine.connect() as conn: + from sqlalchemy import text + count = conn.execute( + text("SELECT COUNT(*) FROM daily_data WHERE code='000001'") + ).fetchone()[0] + assert count == 20, f"重复同步后应仍为20条,实际{count}" + + def test_incremental_appends_new_records(self, tmp_path): + db_url = f"sqlite:///{tmp_path}/test.db" + storage = DataStorage(db_url=db_url) + processor = DataProcessor() + gbbq = make_gbbq_empty() + + # 第一次:前10个交易日 + df1 = make_daily_df('sz000001', 0, '2024-01-02', 10).reset_index() + p1 = processor.process_daily_data(df1, gbbq=gbbq, adj_type='none') + storage.save_incremental(p1, 'daily_data', conflict_columns=('code', 'date')) + + last_date = storage.get_latest_date_by_code('000001') + assert last_date is not None + + # 第二次:后10个交易日(增量,只取 last_date 之后) + df2 = make_daily_df('sz000001', 0, '2024-01-02', 25).reset_index() + p2 = processor.process_daily_data(df2, gbbq=gbbq, adj_type='none') + p2_new = p2[p2['date'] > last_date] + storage.save_incremental(p2_new, 'daily_data', conflict_columns=('code', 'date')) + + with storage.engine.connect() as conn: + from sqlalchemy import text + count = conn.execute( + text("SELECT COUNT(*) FROM daily_data WHERE code='000001'") + ).fetchone()[0] + assert count == 25, f"增量后应为25条,实际{count}" + + def test_full_refresh_on_ex_rights(self, tmp_path): + """有除权事件时,旧数据应被删除并重写(复权价格更新)。""" + db_url = f"sqlite:///{tmp_path}/test.db" + storage = DataStorage(db_url=db_url) + processor = DataProcessor() + + # 第一次同步:无复权 + df = make_daily_df('sz000001', 0, '2024-01-02', 20).reset_index() + gbbq_empty = make_gbbq_empty() + p1 = processor.process_daily_data(df, gbbq=gbbq_empty, adj_type='none') + storage.save_incremental(p1, 'daily_data', conflict_columns=('code', 'date')) + + # 模拟发现除权事件 → 删除旧数据 + 重写前复权数据 + gbbq = make_gbbq_with_event('sz000001', 20240115) + p2 = processor.process_daily_data(df, gbbq=gbbq, adj_type='forward') + + storage.delete_stock_data('000001') + storage.save_incremental(p2, 'daily_data', conflict_columns=('code', 'date')) + + with storage.engine.connect() as conn: + from sqlalchemy import text + count = conn.execute( + text("SELECT COUNT(*) FROM daily_data WHERE code='000001'") + ).fetchone()[0] + # 重写后记录数应与原始相同 + assert count == 20, f"全量重写后应为20条,实际{count}" + + # 除权日前的价格应已被调整(< 10.5) + adj_row = conn.execute( + text("SELECT close FROM daily_data WHERE code='000001' AND date < 20240115 ORDER BY date DESC LIMIT 1") + ).fetchone() + if adj_row: + assert adj_row[0] < 10.5, f"前复权后除权日前收盘价应 < 10.5,实际 {adj_row[0]}" From e7c42b31be34f062a6afce17a31ed800f55b2e3e Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Sat, 4 Apr 2026 23:47:17 +0800 Subject: [PATCH 09/26] =?UTF-8?q?=E6=9B=B4=E6=96=B0gitignore?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index e70c070..6266212 100644 --- a/.gitignore +++ b/.gitignore @@ -45,5 +45,5 @@ poetry.lock output/ */__pycache__/ -CLAUDE.md tdx_data.db +.claude/ \ No newline at end of file From 60f788cd78dce0b1daa572ca6bd96fbbac8235b5 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Sun, 5 Apr 2026 20:02:59 +0800 Subject: [PATCH 10/26] =?UTF-8?q?refactor:=20--code=20=E5=8F=82=E6=95=B0?= =?UTF-8?q?=E6=94=B9=E4=B8=BA=E7=BA=AF6=E4=BD=8D=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=EF=BC=8C=E8=87=AA=E5=8A=A8=E8=AF=86=E5=88=AB=E5=B8=82=E5=9C=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 不再需要 sz/sh 前缀,cli 层根据首位数字自动判断市场(6开头→上海,其他→深圳),更新 README 和 CLAUDE.md 中的相关说明与示例。 Co-Authored-By: Claude Sonnet 4.6 --- CLAUDE.md | 6 ++++-- README.md | 8 ++++---- src/tdx2db/cli.py | 9 +++++---- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index d292e11..d12cf64 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -19,7 +19,7 @@ python main.py sync # 单独同步 python main.py daily --incremental -python main.py daily --code sz000001 --start 20240101 +python main.py daily --code 000001 --start 20240101 # 同步股票列表 python main.py stock-list @@ -64,10 +64,12 @@ CLI (cli.py) → Reader (reader.py) → Processor (processor.py) → Storage (s ## 股票代码格式 -- 文件/CLI 层:带市场前缀,如 `sz000001`、`sh600000` +- CLI `--code` 参数:纯 6 位数字,如 `000001`、`600000`,市场自动识别 +- 内部流转层:带市场前缀,如 `sz000001`、`sh600000`(reader 内部使用) - 数据库层:纯 6 位数字,如 `000001`(reader 写入时截取) - 深圳 market=0,上海 market=1 - A 股筛选:深圳 `000/001/002/300` 开头,上海 `60/688` 开头 +- 市场自动识别规则:6 开头 → 上海(sh),其他 → 深圳(sz) ## 配置 diff --git a/README.md b/README.md index 3108c3e..7b97138 100644 --- a/README.md +++ b/README.md @@ -47,8 +47,8 @@ python main.py sync # 同步所有股票日线(全量) python main.py daily -# 同步指定股票 -python main.py daily --code sz000001 +# 同步指定股票(6位代码,自动识别市场) +python main.py daily --code 000001 # 指定日期范围 python main.py daily --start 20240101 --end 20241231 @@ -77,10 +77,10 @@ sync = TdxDailySync( sync.sync_all(adj_type='forward') # 同步单只股票 -sync.sync_stock('sz000001', start_date=20240101) +sync.sync_stock('000001', start_date=20240101) # 查询数据 -df = sync.get_daily('sz000001', start_date=20240101, end_date=20241231) +df = sync.get_daily('000001', start_date=20240101, end_date=20241231) print(df.head()) ``` diff --git a/src/tdx2db/cli.py b/src/tdx2db/cli.py index ded0b95..7063ab1 100644 --- a/src/tdx2db/cli.py +++ b/src/tdx2db/cli.py @@ -114,7 +114,7 @@ def parse_args() -> Namespace: # daily daily = subparsers.add_parser('daily', help='同步日线数据') - daily.add_argument('--code', help='股票代码(含市场前缀,如 sz000001),不指定则全量') + daily.add_argument('--code', help='股票代码(6位数字,如 000001),不指定则全量,市场自动识别') daily.add_argument('--start', type=int, help='开始日期 YYYYMMDD') daily.add_argument('--end', type=int, help='结束日期 YYYYMMDD') daily.add_argument('--adj', choices=['forward', 'backward', 'none'], default='forward') @@ -175,8 +175,9 @@ def main() -> int: gbbq = reader.read_gbbq() if args.code: - code = args.code - market = 1 if code.startswith('sh') else 0 + pure_code = args.code[-6:] if len(args.code) > 6 else args.code + market = 1 if pure_code.startswith('6') else 0 + code = ('sh' if market == 1 else 'sz') + pure_code try: data = reader.read_daily_data(market, code) if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('date', 'datetime'): @@ -186,7 +187,7 @@ def main() -> int: if not processed.empty: storage.save_incremental(processed, 'daily_data', conflict_columns=('code', 'date')) except Exception as e: - logger.error(f"同步 {args.code} 出错: {e}") + logger.error(f"同步 {code} 出错: {e}") return 1 else: incremental = getattr(args, 'incremental', False) From 859e5df51bc158da40c02e7e558e37de07a40f84 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Sun, 5 Apr 2026 20:59:12 +0800 Subject: [PATCH 11/26] =?UTF-8?q?refactor:=20=E5=AF=B9=E9=BD=90=20daily=5F?= =?UTF-8?q?data=20=E8=A1=A8=E5=AD=97=E6=AE=B5=E4=B8=8E=E5=A4=96=E9=83=A8?= =?UTF-8?q?=20kline=20=E8=A1=A8=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - code → stock_code(列重命名) - date 类型 Integer(YYYYMMDD整数) → String(8)(YYYYMMDD字符串) - 更新所有 SQL 查询、conflict_columns、日期比较逻辑 - 同步更新测试断言 Co-Authored-By: Claude Sonnet 4.6 --- src/tdx2db/__init__.py | 8 ++++---- src/tdx2db/cli.py | 10 +++++----- src/tdx2db/processor.py | 15 +++++++++------ src/tdx2db/storage.py | 24 ++++++++++++------------ tests/test_daily.py | 22 +++++++++++----------- 5 files changed, 41 insertions(+), 38 deletions(-) diff --git a/src/tdx2db/__init__.py b/src/tdx2db/__init__.py index bcb1131..d1178ee 100644 --- a/src/tdx2db/__init__.py +++ b/src/tdx2db/__init__.py @@ -74,7 +74,7 @@ def sync_stock( processed = self.processor.filter_data(processed, start_date=start_date, end_date=end_date) if processed.empty: return 0 - return self.storage.save_incremental(processed, 'daily_data', conflict_columns=('code', 'date')) + return self.storage.save_incremental(processed, 'daily_data', conflict_columns=('stock_code', 'date')) def sync_stock_list(self) -> int: """同步股票列表到 stock_info 表,返回股票数量。""" @@ -91,14 +91,14 @@ def get_daily( """从数据库查询日线数据,date 列为 YYYYMMDD 整数。""" from sqlalchemy import text pure_code = code[-6:] if len(code) > 6 else code - conditions = ["code = :code"] + conditions = ["stock_code = :code"] params: dict = {"code": pure_code} if start_date: conditions.append("date >= :start_date") - params["start_date"] = start_date + params["start_date"] = str(start_date) if end_date: conditions.append("date <= :end_date") - params["end_date"] = end_date + params["end_date"] = str(end_date) where = " AND ".join(conditions) sql = text(f"SELECT * FROM daily_data WHERE {where} ORDER BY date") with self.storage.engine.connect() as conn: diff --git a/src/tdx2db/cli.py b/src/tdx2db/cli.py index 7063ab1..02c3647 100644 --- a/src/tdx2db/cli.py +++ b/src/tdx2db/cli.py @@ -22,7 +22,7 @@ def _has_ex_rights_after(code: str, gbbq: pd.DataFrame, last_date: int) -> bool: events = gbbq[ (gbbq['full_code'] == full_code) & (gbbq['category'] == 1) & - (gbbq['datetime'] > last_date) + (gbbq['datetime'] > int(last_date)) ] return not events.empty @@ -70,9 +70,9 @@ def sync_all_daily( if incremental and last_date and not needs_refresh: processed = processed[processed['date'] > last_date] if start_date: - processed = processed[processed['date'] >= start_date] + processed = processed[processed['date'] >= str(start_date)] if end_date: - processed = processed[processed['date'] <= end_date] + processed = processed[processed['date'] <= str(end_date)] if processed.empty: stats['success'] += 1 @@ -80,7 +80,7 @@ def sync_all_daily( if needs_refresh: storage.delete_stock_data(pure_code) - storage.save_incremental(processed, 'daily_data', conflict_columns=('code', 'date'), + storage.save_incremental(processed, 'daily_data', conflict_columns=('stock_code', 'date'), batch_size=config.db_batch_size) stats['success'] += 1 @@ -185,7 +185,7 @@ def main() -> int: processed = processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type) processed = processor.filter_data(processed, start_date=args.start, end_date=args.end) if not processed.empty: - storage.save_incremental(processed, 'daily_data', conflict_columns=('code', 'date')) + storage.save_incremental(processed, 'daily_data', conflict_columns=('stock_code', 'date')) except Exception as e: logger.error(f"同步 {code} 出错: {e}") return 1 diff --git a/src/tdx2db/processor.py b/src/tdx2db/processor.py index 5c3364e..9a9ce5c 100644 --- a/src/tdx2db/processor.py +++ b/src/tdx2db/processor.py @@ -177,13 +177,16 @@ def process_daily_data( elif 'adj_factor' not in processed.columns: processed['adj_factor'] = 1.0 - # 日期转 YYYYMMDD 整数 - processed['date'] = processed['date'].dt.strftime('%Y%m%d').astype(int) + # 日期转 YYYYMMDD 字符串 + processed['date'] = processed['date'].dt.strftime('%Y%m%d') # 预留 turnover_rate 列 if 'turnover_rate' not in processed.columns: processed['turnover_rate'] = None + # 重命名 code → stock_code 以对齐目标表结构 + processed = processed.rename(columns={'code': 'stock_code'}) + return processed @staticmethod @@ -199,9 +202,9 @@ def filter_data( result = df.copy() if 'date' in result.columns: if start_date: - result = result[result['date'] >= start_date] + result = result[result['date'] >= str(start_date)] if end_date: - result = result[result['date'] <= end_date] - if codes and 'code' in result.columns: - result = result[result['code'].isin(codes)] + result = result[result['date'] <= str(end_date)] + if codes and 'stock_code' in result.columns: + result = result[result['stock_code'].isin(codes)] return result diff --git a/src/tdx2db/storage.py b/src/tdx2db/storage.py index 0c99028..1582184 100644 --- a/src/tdx2db/storage.py +++ b/src/tdx2db/storage.py @@ -14,12 +14,12 @@ class DailyData(Base): __tablename__ = 'daily_data' - __table_args__ = (UniqueConstraint('code', 'date'),) + __table_args__ = (UniqueConstraint('stock_code', 'date'),) id = Column(Integer, primary_key=True) - code = Column(String(10), index=True) + stock_code = Column(String(10), index=True) market = Column(Integer) - date = Column(Integer, index=True) # YYYYMMDD 整数 + date = Column(String(8), index=True) # YYYYMMDD 字符串 open = Column(Float) high = Column(Float) low = Column(Float) @@ -51,27 +51,27 @@ def __init__(self, db_url: Optional[str] = None) -> None: Base.metadata.create_all(self.engine) self.Session = sessionmaker(bind=self.engine) - def get_latest_date_by_code(self, code: str) -> Optional[int]: - """返回指定股票在 daily_data 中最新的 YYYYMMDD 整数日期,无数据返回 None。""" + def get_latest_date_by_code(self, code: str) -> Optional[str]: + """返回指定股票在 daily_data 中最新的 YYYYMMDD 字符串日期,无数据返回 None。""" try: with self.engine.connect() as conn: row = conn.execute( - text("SELECT MAX(date) FROM daily_data WHERE code = :code"), + text("SELECT MAX(date) FROM daily_data WHERE stock_code = :code"), {"code": code} ).fetchone() - return int(row[0]) if row and row[0] is not None else None + return row[0] if row and row[0] is not None else None except Exception as e: logger.debug(f"查询 {code} 最新日期出错: {e}") return None def get_all_latest_dates(self) -> dict: - """一次查询返回所有股票最新日期 {code: YYYYMMDD int}。""" + """一次查询返回所有股票最新日期 {code: YYYYMMDD str}。""" try: with self.engine.connect() as conn: rows = conn.execute( - text("SELECT code, MAX(date) FROM daily_data GROUP BY code") + text("SELECT stock_code, MAX(date) FROM daily_data GROUP BY stock_code") ).fetchall() - return {r[0]: int(r[1]) for r in rows if r[1] is not None} + return {r[0]: r[1] for r in rows if r[1] is not None} except Exception as e: logger.debug(f"查询所有股票最新日期出错: {e}") return {} @@ -81,7 +81,7 @@ def delete_stock_data(self, code: str) -> None: try: with self.engine.connect() as conn: conn.execute( - text("DELETE FROM daily_data WHERE code = :code"), + text("DELETE FROM daily_data WHERE stock_code = :code"), {"code": code} ) conn.commit() @@ -92,7 +92,7 @@ def save_incremental( self, df: pd.DataFrame, table_name: str, - conflict_columns: Tuple[str, ...] = ('code', 'date'), + conflict_columns: Tuple[str, ...] = ('stock_code', 'date'), batch_size: int = 10000 ) -> int: """增量保存,跳过重复记录(ON CONFLICT DO NOTHING / INSERT OR IGNORE / INSERT IGNORE)。""" diff --git a/tests/test_daily.py b/tests/test_daily.py index 8ac924f..219aceb 100644 --- a/tests/test_daily.py +++ b/tests/test_daily.py @@ -76,7 +76,7 @@ def test_records_written_and_date_format(self, tmp_path): df = make_daily_df(code, market, '2024-03-01', 20) df = df.reset_index() processed = processor.process_daily_data(df, gbbq=gbbq, adj_type='none') - storage.save_incremental(processed, 'daily_data', conflict_columns=('code', 'date')) + storage.save_incremental(processed, 'daily_data', conflict_columns=('stock_code', 'date')) # 验证 with storage.engine.connect() as conn: @@ -84,10 +84,10 @@ def test_records_written_and_date_format(self, tmp_path): rows = conn.execute(text("SELECT COUNT(*) FROM daily_data")).fetchone() assert rows[0] == 60, f"期望60条,实际{rows[0]}" - # date 列应为 YYYYMMDD 整数 + # date 列应为 YYYYMMDD 字符串 sample = conn.execute(text("SELECT date FROM daily_data LIMIT 1")).fetchone() - assert isinstance(sample[0], int), f"date 应为整数,实际类型: {type(sample[0])}" - assert 20240301 <= sample[0] <= 20241231 + assert isinstance(sample[0], str), f"date 应为字符串,实际类型: {type(sample[0])}" + assert '20240301' <= sample[0] <= '20241231' class TestSingleStockOneYear: @@ -111,10 +111,10 @@ def test_date_range_correct(self, tmp_path): with storage.engine.connect() as conn: from sqlalchemy import text rows = conn.execute( - text("SELECT MIN(date), MAX(date), COUNT(*) FROM daily_data WHERE code='000001'") + text("SELECT MIN(date), MAX(date), COUNT(*) FROM daily_data WHERE stock_code='000001'") ).fetchone() min_date, max_date, count = rows - assert min_date >= 20240101, f"最小日期 {min_date} 应 >= 20240101" + assert min_date >= '20240101', f"最小日期 {min_date} 应 >= 20240101" assert count > 0 @@ -133,7 +133,7 @@ def test_price_continuity_across_ex_date(self): gbbq = make_gbbq_with_event('sz000001', 20240201) processed = processor.process_daily_data(df, gbbq=gbbq, adj_type='forward') - ex_date = 20240201 + ex_date = '20240201' before = processed[processed['date'] < ex_date] on_or_after = processed[processed['date'] >= ex_date] @@ -182,7 +182,7 @@ def test_no_duplicates_on_second_sync(self, tmp_path): with storage.engine.connect() as conn: from sqlalchemy import text count = conn.execute( - text("SELECT COUNT(*) FROM daily_data WHERE code='000001'") + text("SELECT COUNT(*) FROM daily_data WHERE stock_code='000001'") ).fetchone()[0] assert count == 20, f"重复同步后应仍为20条,实际{count}" @@ -209,7 +209,7 @@ def test_incremental_appends_new_records(self, tmp_path): with storage.engine.connect() as conn: from sqlalchemy import text count = conn.execute( - text("SELECT COUNT(*) FROM daily_data WHERE code='000001'") + text("SELECT COUNT(*) FROM daily_data WHERE stock_code='000001'") ).fetchone()[0] assert count == 25, f"增量后应为25条,实际{count}" @@ -235,14 +235,14 @@ def test_full_refresh_on_ex_rights(self, tmp_path): with storage.engine.connect() as conn: from sqlalchemy import text count = conn.execute( - text("SELECT COUNT(*) FROM daily_data WHERE code='000001'") + text("SELECT COUNT(*) FROM daily_data WHERE stock_code='000001'") ).fetchone()[0] # 重写后记录数应与原始相同 assert count == 20, f"全量重写后应为20条,实际{count}" # 除权日前的价格应已被调整(< 10.5) adj_row = conn.execute( - text("SELECT close FROM daily_data WHERE code='000001' AND date < 20240115 ORDER BY date DESC LIMIT 1") + text("SELECT close FROM daily_data WHERE stock_code='000001' AND date < '20240115' ORDER BY date DESC LIMIT 1") ).fetchone() if adj_row: assert adj_row[0] < 10.5, f"前复权后除权日前收盘价应 < 10.5,实际 {adj_row[0]}" From 93d609507f45b4c7945244801d519441911c51e4 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Sun, 5 Apr 2026 22:30:55 +0800 Subject: [PATCH 12/26] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=8C=97?= =?UTF-8?q?=E4=BA=A4=E6=89=80=EF=BC=88bj=EF=BC=89=E6=97=A5=E7=BA=BF?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=90=8C=E6=AD=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - reader: 扫描 vipdoc/bj/lday/,匹配 8xxxxx 和 92xxxx 代码(market=2) - reader: get_security_type 异常时回退到原始二进制读取,与科创板处理一致 - cli: 三处市场识别逻辑支持 bj 前缀和 8/92 开头代码 - tests: 修复 make_gbbq_with_event market 映射,新增 TestMultiMarket 三市场测试 Co-Authored-By: Claude Sonnet 4.6 --- CLAUDE.md | 12 ++++----- src/tdx2db/cli.py | 19 ++++++++++--- src/tdx2db/reader.py | 24 ++++++++++++----- tests/test_daily.py | 64 +++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 102 insertions(+), 17 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index d12cf64..5aa138b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -49,7 +49,7 @@ CLI (cli.py) → Reader (reader.py) → Processor (processor.py) → Storage (s 逐股票流式处理,不全量加载到内存: -1. 读取 `vipdoc/{sz,sh}/lday/*.day` → `process_daily_data()` 校验 OHLCV + 复权 +1. 读取 `vipdoc/{sz,sh,bj}/lday/*.day` → `process_daily_data()` 校验 OHLCV + 复权 2. 增量策略:`get_all_latest_dates()` 一次查询所有股票最新日期;若有除权事件则 `delete_stock_data()` + 全量重写 3. `save_incremental()` 使用 `ON CONFLICT DO NOTHING`(PG)/ `INSERT OR IGNORE`(SQLite)/ `INSERT IGNORE`(MySQL) @@ -64,12 +64,12 @@ CLI (cli.py) → Reader (reader.py) → Processor (processor.py) → Storage (s ## 股票代码格式 -- CLI `--code` 参数:纯 6 位数字,如 `000001`、`600000`,市场自动识别 -- 内部流转层:带市场前缀,如 `sz000001`、`sh600000`(reader 内部使用) +- CLI `--code` 参数:纯 6 位数字,如 `000001`、`600000`、`920001`,市场自动识别 +- 内部流转层:带市场前缀,如 `sz000001`、`sh600000`、`bj920001`(reader 内部使用) - 数据库层:纯 6 位数字,如 `000001`(reader 写入时截取) -- 深圳 market=0,上海 market=1 -- A 股筛选:深圳 `000/001/002/300` 开头,上海 `60/688` 开头 -- 市场自动识别规则:6 开头 → 上海(sh),其他 → 深圳(sz) +- 深圳 market=0,上海 market=1,北京 market=2 +- A 股筛选:深圳 `000/001/002/300` 开头,上海 `60/688` 开头,北交所 `8xxxxx` 或 `92xxxx` 开头 +- 市场自动识别规则:6 开头 → 上海(sh),8 或 92 开头 → 北京(bj),其他 → 深圳(sz) ## 配置 diff --git a/src/tdx2db/cli.py b/src/tdx2db/cli.py index 02c3647..d0a5823 100644 --- a/src/tdx2db/cli.py +++ b/src/tdx2db/cli.py @@ -17,7 +17,12 @@ def _has_ex_rights_after(code: str, gbbq: pd.DataFrame, last_date: int) -> bool: """检查该股票在 last_date 之后是否有除权事件(category==1)。""" if gbbq is None or gbbq.empty: return False - prefix = 'sh' if code.startswith('6') else 'sz' + if code.startswith('6'): + prefix = 'sh' + elif code.startswith('8') or code.startswith('92'): + prefix = 'bj' + else: + prefix = 'sz' full_code = prefix + code.zfill(6) events = gbbq[ (gbbq['full_code'] == full_code) & @@ -48,7 +53,7 @@ def sync_all_daily( for _, stock in iterator: code = stock['code'] - market = 1 if code.startswith('sh') else 0 + market = 1 if code.startswith('sh') else (2 if code.startswith('bj') else 0) pure_code = code[-6:] if len(code) > 6 else code last_date = latest_dates.get(pure_code) @@ -176,8 +181,14 @@ def main() -> int: if args.code: pure_code = args.code[-6:] if len(args.code) > 6 else args.code - market = 1 if pure_code.startswith('6') else 0 - code = ('sh' if market == 1 else 'sz') + pure_code + if pure_code.startswith('6'): + market = 1 + elif pure_code.startswith('8') or pure_code.startswith('92'): + market = 2 + else: + market = 0 + prefix = {0: 'sz', 1: 'sh', 2: 'bj'}[market] + code = prefix + pure_code try: data = reader.read_daily_data(market, code) if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('date', 'datetime'): diff --git a/src/tdx2db/reader.py b/src/tdx2db/reader.py index 78b9496..1f5c3c0 100644 --- a/src/tdx2db/reader.py +++ b/src/tdx2db/reader.py @@ -41,7 +41,8 @@ def get_stock_list(self) -> pd.DataFrame: """扫描本地 .day 文件获取 A 股股票列表(含市场前缀代码)。""" sz_path = self.tdx_path / 'vipdoc' / 'sz' / 'lday' sh_path = self.tdx_path / 'vipdoc' / 'sh' / 'lday' - if not (sz_path.exists() or sh_path.exists()): + bj_path = self.tdx_path / 'vipdoc' / 'bj' / 'lday' + if not (sz_path.exists() or sh_path.exists() or bj_path.exists()): raise FileNotFoundError("无法找到股票数据目录") stocks = [] @@ -59,23 +60,34 @@ def get_stock_list(self) -> pd.DataFrame: if re.match(r'^(60\d{4}|688\d{3})$', pure): stocks.append({'code': code, 'name': f'上A{code}'}) + if bj_path.exists(): + for f in bj_path.glob('*.day'): + code = f.stem + pure = code[-6:].zfill(6) + if re.match(r'^(8\d{5}|92\d{4})$', pure): + stocks.append({'code': code, 'name': f'北A{code}'}) + if not stocks: raise FileNotFoundError("未找到任何股票数据文件") return pd.DataFrame(stocks, columns=['code', 'name']) def read_daily_data(self, market: int, code: str) -> pd.DataFrame: """读取单只股票日线数据,返回含 code/market 列的 DataFrame(date 为 DatetimeIndex)。""" - market_folder = 'sz' if market == 0 else 'sh' + market_map = {0: 'sz', 1: 'sh', 2: 'bj'} + market_folder = market_map[market] pure_code = code[-6:] if len(code) > 6 else code file_path = self.tdx_path / 'vipdoc' / market_folder / 'lday' / f"{market_folder}{pure_code}.day" if not file_path.exists(): raise FileNotFoundError(f"日线数据文件不存在: {file_path}") - sec_type = self.daily_reader.get_security_type(str(file_path)) - if sec_type in self.daily_reader.SECURITY_TYPE: - data = self.daily_reader.get_df(str(file_path)) - else: + try: + sec_type = self.daily_reader.get_security_type(str(file_path)) + if sec_type in self.daily_reader.SECURITY_TYPE: + data = self.daily_reader.get_df(str(file_path)) + else: + data = self._read_day_file_raw(str(file_path)) + except Exception: data = self._read_day_file_raw(str(file_path)) data['code'] = pure_code diff --git a/tests/test_daily.py b/tests/test_daily.py index 219aceb..8485ee5 100644 --- a/tests/test_daily.py +++ b/tests/test_daily.py @@ -43,8 +43,14 @@ def make_gbbq_empty() -> pd.DataFrame: def make_gbbq_with_event(full_code: str, ex_date_int: int) -> pd.DataFrame: """构造一条除权记录(category=1,送股 10%)。""" + if full_code.startswith('sh'): + market_val = 1 + elif full_code.startswith('bj'): + market_val = 2 + else: + market_val = 0 return pd.DataFrame([{ - 'market': 0 if full_code.startswith('sz') else 1, + 'market': market_val, 'code': int(full_code[2:]), 'datetime': ex_date_int, 'category': 1, @@ -246,3 +252,59 @@ def test_full_refresh_on_ex_rights(self, tmp_path): ).fetchone() if adj_row: assert adj_row[0] < 10.5, f"前复权后除权日前收盘价应 < 10.5,实际 {adj_row[0]}" + + +# ─── 三市场测试 ────────────────────────────────────────────────────────────── + +class TestMultiMarket: + """验证三个市场(sz/sh/bj)日线数据均能正确写入,且 market 列值准确。""" + + def test_all_three_markets_sync(self, tmp_path): + """sz=0, sh=1, bj=2 三市场股票同步后记录数和 market 值均正确。""" + db_url = f"sqlite:///{tmp_path}/test.db" + storage = DataStorage(db_url=db_url) + processor = DataProcessor() + gbbq = make_gbbq_empty() + + stocks = [ + ('sz000001', 0), # 深圳主板 + ('sh600000', 1), # 上海主板 + ('bj920001', 2), # 北交所(92开头新股) + ] + for code, market in stocks: + df = make_daily_df(code, market, '2024-03-01', 10).reset_index() + processed = processor.process_daily_data(df, gbbq=gbbq, adj_type='none') + storage.save_incremental(processed, 'daily_data', conflict_columns=('stock_code', 'date')) + + with storage.engine.connect() as conn: + from sqlalchemy import text + for pure_code, expected_market in [('000001', 0), ('600000', 1), ('920001', 2)]: + row = conn.execute(text( + f"SELECT market, COUNT(*) FROM daily_data WHERE stock_code='{pure_code}' GROUP BY market" + )).fetchone() + assert row is not None, f"{pure_code} 无数据" + assert row[0] == expected_market, f"{pure_code} market 应为 {expected_market},实际 {row[0]}" + assert row[1] == 10, f"{pure_code} 应有10条记录,实际 {row[1]}" + + def test_bj_ex_rights_refresh(self, tmp_path): + """北交所股票有除权事件时,旧数据应被删除并重写。""" + db_url = f"sqlite:///{tmp_path}/test.db" + storage = DataStorage(db_url=db_url) + processor = DataProcessor() + + df = make_daily_df('bj920001', 2, '2024-01-02', 20).reset_index() + gbbq = make_gbbq_with_event('bj920001', 20240115) + + p1 = processor.process_daily_data(df, gbbq=make_gbbq_empty(), adj_type='none') + storage.save_incremental(p1, 'daily_data', conflict_columns=('stock_code', 'date')) + + storage.delete_stock_data('920001') + p2 = processor.process_daily_data(df, gbbq=gbbq, adj_type='forward') + storage.save_incremental(p2, 'daily_data', conflict_columns=('stock_code', 'date')) + + with storage.engine.connect() as conn: + from sqlalchemy import text + count = conn.execute(text( + "SELECT COUNT(*) FROM daily_data WHERE stock_code='920001'" + )).fetchone()[0] + assert count == 20 From b8e05b781b1148b36198ba3329bba250084c0db5 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Tue, 7 Apr 2026 09:12:17 +0800 Subject: [PATCH 13/26] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E8=81=94?= =?UTF-8?q?=E7=BD=91=E4=B8=8B=E8=BD=BD=20TDX=20=E6=97=A5=E7=BA=BF=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=AF=BC=E5=85=A5=E6=95=B0=E6=8D=AE=E5=BA=93=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 downloader.py:流式下载 hsjday.zip,手动处理 Windows 反斜杠路径,上下文管理器自动清理临时目录 - reader.py:新增 vipdoc_path 参数支持下载模式,抑制 pytdx 北交所 print 噪音 - cli.py:新增 download 子命令(--url/--adj/--no-clean),延迟 TdxDataReader 初始化 - config.py:新增 TDX_DOWNLOAD_URL 配置项 - requirements.txt:新增 requests 依赖 - 新增 tests/test_download.py:19 个测试用例覆盖下载/解压/读取/CLI 全流程 Co-Authored-By: Claude Sonnet 4.6 --- requirements.txt | 1 + src/tdx2db/cli.py | 35 +++- src/tdx2db/config.py | 2 + src/tdx2db/downloader.py | 112 +++++++++++ src/tdx2db/reader.py | 35 ++-- tests/test_download.py | 410 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 581 insertions(+), 14 deletions(-) create mode 100644 src/tdx2db/downloader.py create mode 100644 tests/test_download.py diff --git a/requirements.txt b/requirements.txt index 8628f86..acfe878 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ tqdm>=4.67.1 python-dotenv>=1.1.0 pymysql>=1.1.1 psycopg2-binary>=2.9.10 +requests>=2.32.0 diff --git a/src/tdx2db/cli.py b/src/tdx2db/cli.py index d0a5823..3c19f4d 100644 --- a/src/tdx2db/cli.py +++ b/src/tdx2db/cli.py @@ -11,6 +11,7 @@ from .storage import DataStorage from .config import config from .logger import logger +from .downloader import download_and_extract, DEFAULT_DOWNLOAD_URL def _has_ex_rights_after(code: str, gbbq: pd.DataFrame, last_date: int) -> bool: @@ -129,6 +130,13 @@ def parse_args() -> Namespace: sync = subparsers.add_parser('sync', help='一键增量同步日线数据') sync.add_argument('--adj', choices=['forward', 'backward', 'none'], default='forward') + # download + download = subparsers.add_parser('download', help='联网下载 TDX 日线数据并导入数据库') + download.add_argument('--url', help=f'下载地址(默认: {DEFAULT_DOWNLOAD_URL})') + download.add_argument('--adj', choices=['forward', 'backward', 'none'], default='forward') + download.add_argument('--no-clean', action='store_true', dest='no_clean', + help='保留临时目录(用于调试)') + return parser.parse_args() @@ -157,15 +165,36 @@ def main() -> int: args = parse_args() update_config(args) + storage = DataStorage() + processor = DataProcessor() + + if args.command == 'download': + adj_type = getattr(args, 'adj', 'forward') + keep_tmp = getattr(args, 'no_clean', False) + url = getattr(args, 'url', None) + + gbbq = pd.DataFrame() + if config.tdx_path: + try: + local_reader = TdxDataReader() + gbbq = local_reader.read_gbbq() + logger.info("已从本地通达信读取权息文件") + except Exception: + logger.warning("本地权息文件读取失败,将跳过复权处理") + + logger.info("=== 开始联网下载 TDX 日线数据 ===") + with download_and_extract(url=url, keep_tmp=keep_tmp) as vipdoc_path: + dl_reader = TdxDataReader(vipdoc_path=str(vipdoc_path)) + sync_all_daily(dl_reader, processor, storage, gbbq, + adj_type=adj_type, incremental=True) + return 0 + try: reader = TdxDataReader() except (ValueError, FileNotFoundError) as e: logger.error(f"初始化失败: {e}") return 1 - storage = DataStorage() - processor = DataProcessor() - if args.command == 'stock-list': try: stocks = reader.get_stock_list() diff --git a/src/tdx2db/config.py b/src/tdx2db/config.py index 251dfd5..707d5f7 100644 --- a/src/tdx2db/config.py +++ b/src/tdx2db/config.py @@ -15,6 +15,7 @@ class Config: db_password: str db_batch_size: int use_tqdm: bool + download_url: str def __init__(self) -> None: self.tdx_path = os.getenv('TDX_PATH', '') @@ -27,6 +28,7 @@ def __init__(self) -> None: self.db_password = os.getenv('DB_PASSWORD', '') self.db_batch_size = int(os.getenv('DB_BATCH_SIZE', '10000')) self.use_tqdm = os.getenv('USE_TQDM', 'True').lower() == 'true' + self.download_url = os.getenv('TDX_DOWNLOAD_URL', '') @property def database_url(self) -> str: diff --git a/src/tdx2db/downloader.py b/src/tdx2db/downloader.py new file mode 100644 index 0000000..ce829ad --- /dev/null +++ b/src/tdx2db/downloader.py @@ -0,0 +1,112 @@ +import shutil +import tempfile +import zipfile +from contextlib import contextmanager +from pathlib import Path +from typing import Generator, Optional + +import requests +from tqdm import tqdm + +from .config import config +from .logger import logger + +DEFAULT_DOWNLOAD_URL = 'https://data.tdx.com.cn/vipdoc/hsjday.zip' + + +def download_zip(url: str, dest_path: Path, chunk_size: int = 1024 * 1024) -> None: + """流式下载 ZIP 文件,支持 tqdm 进度条。""" + try: + response = requests.get(url, stream=True, timeout=(10, 300)) + response.raise_for_status() + except requests.RequestException as e: + raise RuntimeError(f"下载失败: {e}") from e + + total = int(response.headers.get('Content-Length', 0)) or None + desc = Path(url).name + + try: + if config.use_tqdm: + with tqdm(total=total, unit='B', unit_scale=True, desc=desc) as bar: + with open(dest_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=chunk_size): + f.write(chunk) + bar.update(len(chunk)) + else: + logger.info(f"正在下载 {desc}...") + with open(dest_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=chunk_size): + f.write(chunk) + except Exception: + dest_path.unlink(missing_ok=True) + raise + + +def extract_zip(zip_path: Path, extract_dir: Path) -> Path: + """解压 ZIP 文件,返回内部 hsjday/ 子目录路径(作为 vipdoc_path 使用)。""" + if not zipfile.is_zipfile(zip_path): + raise ValueError(f"文件不是合法的 ZIP 格式: {zip_path}") + + with zipfile.ZipFile(zip_path) as zf: + bad = zf.testzip() + if bad: + raise ValueError(f"ZIP 文件损坏,首个问题文件: {bad}") + + # 安全检查:过滤路径穿越 + for member in zf.infolist(): + if member.filename.startswith('/') or '..' in member.filename: + raise ValueError(f"ZIP 包含不安全路径: {member.filename}") + + logger.info("正在解压数据包...") + # ZIP 内路径可能使用 Windows 反斜杠(如 sh\lday\sh000001.day) + # Python zipfile 在 macOS/Linux 上不会自动转换,需手动处理 + for member in zf.infolist(): + normalized = member.filename.replace('\\', '/') + dest = extract_dir / Path(normalized) + dest.parent.mkdir(parents=True, exist_ok=True) + if not normalized.endswith('/'): + with zf.open(member) as src, open(dest, 'wb') as dst: + dst.write(src.read()) + + # 兼容两种结构: + # 1. {sh,sz,bj}/lday/*.day (根目录直接是市场目录,实际情况) + # 2. hsjday/{sh,sz,bj}/lday/*.day (有顶层目录) + if (extract_dir / 'sh').exists() or (extract_dir / 'sz').exists() or (extract_dir / 'bj').exists(): + return extract_dir + vipdoc_path = extract_dir / 'hsjday' + if vipdoc_path.exists(): + return vipdoc_path + raise FileNotFoundError("解压后未找到预期的市场目录(sh/sz/bj),请检查 ZIP 包结构") + + +@contextmanager +def download_and_extract( + url: Optional[str] = None, + keep_tmp: bool = False, +) -> Generator[Path, None, None]: + """ + 下载并解压 TDX 日线数据包,yield vipdoc_path(即 hsjday/ 目录)。 + + 退出时若 keep_tmp=False 自动删除临时目录。 + + 用法: + with download_and_extract() as vipdoc_path: + reader = TdxDataReader(vipdoc_path=str(vipdoc_path)) + """ + target_url = url or config.download_url or DEFAULT_DOWNLOAD_URL + tmp_dir = Path(tempfile.mkdtemp(prefix='tdx_hsjday_')) + zip_path = tmp_dir / 'hsjday.zip' + + try: + logger.info(f"开始下载: {target_url}") + download_zip(target_url, zip_path) + logger.info("下载完成,开始解压...") + vipdoc_path = extract_zip(zip_path, tmp_dir) + zip_path.unlink(missing_ok=True) # 解压后释放 ZIP 占用的磁盘空间 + logger.info(f"解压完成: {vipdoc_path}") + yield vipdoc_path + finally: + if keep_tmp: + logger.info(f"临时目录已保留: {tmp_dir}") + else: + shutil.rmtree(tmp_dir, ignore_errors=True) diff --git a/src/tdx2db/reader.py b/src/tdx2db/reader.py index 1f5c3c0..8a6c68b 100644 --- a/src/tdx2db/reader.py +++ b/src/tdx2db/reader.py @@ -11,17 +11,27 @@ class TdxDataReader: - def __init__(self, tdx_path: Optional[str] = None) -> None: - self.tdx_path = Path(tdx_path or config.tdx_path) - if not self.tdx_path: - raise ValueError("通达信数据路径未设置,请在 .env 中设置 TDX_PATH") - if not self.tdx_path.exists(): - raise FileNotFoundError(f"通达信数据路径不存在: {self.tdx_path}") + def __init__(self, tdx_path: Optional[str] = None, vipdoc_path: Optional[str] = None) -> None: + if vipdoc_path: + self.tdx_path = None + self._vipdoc_path = Path(vipdoc_path) + if not self._vipdoc_path.exists(): + raise FileNotFoundError(f"vipdoc 目录不存在: {self._vipdoc_path}") + else: + self.tdx_path = Path(tdx_path or config.tdx_path) + if not self.tdx_path: + raise ValueError("通达信数据路径未设置,请在 .env 中设置 TDX_PATH") + if not self.tdx_path.exists(): + raise FileNotFoundError(f"通达信数据路径不存在: {self.tdx_path}") + self._vipdoc_path = self.tdx_path / 'vipdoc' self.daily_reader = TdxDailyBarReader() self.gbbq_reader = GbbqReader() def read_gbbq(self) -> pd.DataFrame: """读取权息文件,返回全量权息 DataFrame。文件不存在时返回空 DataFrame。""" + if self.tdx_path is None: + logger.warning("联网下载模式下不支持读取 gbbq,将跳过复权处理") + return pd.DataFrame() gbbq_path = self.tdx_path / 'T0002' / 'hq_cache' / 'gbbq' if not gbbq_path.exists(): logger.warning(f"权息文件不存在: {gbbq_path},将跳过复权处理") @@ -39,9 +49,9 @@ def read_gbbq(self) -> pd.DataFrame: def get_stock_list(self) -> pd.DataFrame: """扫描本地 .day 文件获取 A 股股票列表(含市场前缀代码)。""" - sz_path = self.tdx_path / 'vipdoc' / 'sz' / 'lday' - sh_path = self.tdx_path / 'vipdoc' / 'sh' / 'lday' - bj_path = self.tdx_path / 'vipdoc' / 'bj' / 'lday' + sz_path = self._vipdoc_path / 'sz' / 'lday' + sh_path = self._vipdoc_path / 'sh' / 'lday' + bj_path = self._vipdoc_path / 'bj' / 'lday' if not (sz_path.exists() or sh_path.exists() or bj_path.exists()): raise FileNotFoundError("无法找到股票数据目录") @@ -76,13 +86,16 @@ def read_daily_data(self, market: int, code: str) -> pd.DataFrame: market_map = {0: 'sz', 1: 'sh', 2: 'bj'} market_folder = market_map[market] pure_code = code[-6:] if len(code) > 6 else code - file_path = self.tdx_path / 'vipdoc' / market_folder / 'lday' / f"{market_folder}{pure_code}.day" + file_path = self._vipdoc_path / market_folder / 'lday' / f"{market_folder}{pure_code}.day" if not file_path.exists(): raise FileNotFoundError(f"日线数据文件不存在: {file_path}") + import io + from contextlib import redirect_stdout try: - sec_type = self.daily_reader.get_security_type(str(file_path)) + with redirect_stdout(io.StringIO()): + sec_type = self.daily_reader.get_security_type(str(file_path)) if sec_type in self.daily_reader.SECURITY_TYPE: data = self.daily_reader.get_df(str(file_path)) else: diff --git a/tests/test_download.py b/tests/test_download.py new file mode 100644 index 0000000..2d29712 --- /dev/null +++ b/tests/test_download.py @@ -0,0 +1,410 @@ +""" +联网下载功能测试套件(全部 mock,不发起真实网络请求) + +测试用例: +1. TestDownloader - downloader.py 的下载/解压/上下文管理逻辑 +2. TestReaderVipdocPath - TdxDataReader 新增的 vipdoc_path 参数 +3. TestDownloadCommand - CLI download 子命令的完整流程 +""" + +import io +import struct +import zipfile +from pathlib import Path +from unittest.mock import MagicMock, patch, call + +import pandas as pd +import pytest + +from src.tdx2db.downloader import ( + DEFAULT_DOWNLOAD_URL, + download_and_extract, + download_zip, + extract_zip, +) +from src.tdx2db.reader import TdxDataReader + + +# ─── 工具函数 ──────────────────────────────────────────────────────────────── + +def _make_day_bytes(n_records: int = 5) -> bytes: + """生成假的 .day 二进制内容(n 条记录)。""" + fmt = ' Path: + """在 tmp_path 创建假的 hsjday.zip,内含 {sh,sz}/lday/*.day 文件(与实际 ZIP 结构一致)。""" + zip_path = tmp_path / 'hsjday.zip' + day_content = _make_day_bytes(5) + + market_codes = { + 'sh': ['sh600000', 'sh600001'], + 'sz': ['sz000001', 'sz000002'], + 'bj': ['bj920001'], + } + + with zipfile.ZipFile(zip_path, 'w') as zf: + for market in markets: + codes = market_codes.get(market, [])[:codes_per_market] + for code in codes: + arc_name = f'{market}/lday/{code}.day' + zf.writestr(arc_name, day_content) + + return zip_path + + +# ─── TestDownloader ───────────────────────────────────────────────────────── + +class TestDownloadZip: + """download_zip 函数的单元测试。""" + + def test_downloads_file_successfully(self, tmp_path): + """正常下载时文件应被写入目标路径。""" + fake_content = b'PK\x03\x04' + b'\x00' * 100 # 假 ZIP 头 + + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.headers = {'Content-Length': str(len(fake_content))} + mock_resp.iter_content = MagicMock(return_value=[fake_content]) + + dest = tmp_path / 'test.zip' + with patch('requests.get', return_value=mock_resp): + with patch('src.tdx2db.downloader.config') as mock_cfg: + mock_cfg.use_tqdm = False + download_zip('http://fake-url/test.zip', dest) + + assert dest.exists() + assert dest.read_bytes() == fake_content + + def test_cleans_up_on_error(self, tmp_path): + """下载失败时应删除不完整文件。""" + import requests as req + + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock(side_effect=req.HTTPError("404")) + mock_resp.headers = {} + + dest = tmp_path / 'test.zip' + with patch('requests.get', return_value=mock_resp): + with pytest.raises(RuntimeError, match="下载失败"): + download_zip('http://fake-url/test.zip', dest) + + assert not dest.exists() + + def test_raises_on_network_error(self, tmp_path): + """网络异常时应抛出 RuntimeError。""" + import requests as req + + with patch('requests.get', side_effect=req.ConnectionError("no route")): + with pytest.raises(RuntimeError, match="下载失败"): + download_zip('http://fake-url/test.zip', tmp_path / 'x.zip') + + +class TestExtractZip: + """extract_zip 函数的单元测试。""" + + def test_extracts_and_returns_vipdoc_path(self, tmp_path): + """正常解压时应返回包含市场目录的路径。""" + zip_path = _make_fake_zip(tmp_path, markets=['sh', 'sz']) + extract_dir = tmp_path / 'out' + extract_dir.mkdir() + + result = extract_zip(zip_path, extract_dir) + + assert result.exists() + assert (result / 'sh' / 'lday').exists() + + def test_raises_on_invalid_zip(self, tmp_path): + """非法 ZIP 文件应抛出 ValueError。""" + bad_zip = tmp_path / 'bad.zip' + bad_zip.write_bytes(b'not a zip file at all') + + with pytest.raises(ValueError, match="ZIP"): + extract_zip(bad_zip, tmp_path / 'out') + + def test_raises_on_missing_market_dirs(self, tmp_path): + """ZIP 内无 sh/sz/bj 目录时应抛出 FileNotFoundError。""" + zip_path = tmp_path / 'no_market.zip' + with zipfile.ZipFile(zip_path, 'w') as zf: + zf.writestr('some_other_dir/file.txt', 'content') + + extract_dir = tmp_path / 'out' + extract_dir.mkdir() + + with pytest.raises(FileNotFoundError, match="sh/sz/bj"): + extract_zip(zip_path, extract_dir) + + +class TestDownloadAndExtract: + """download_and_extract 上下文管理器测试。""" + + def test_yields_vipdoc_path_and_cleans_up(self): + """正常执行:yield vipdoc_path,结束后临时目录被清理。""" + import tempfile, shutil + src_dir = Path(tempfile.mkdtemp(prefix='tdx_test_src_')) + try: + zip_path = _make_fake_zip(src_dir, markets=['sh']) + captured_vipdoc = None + + def fake_download(url, dest_path, **kwargs): + shutil.copy(zip_path, dest_path) + + with patch('src.tdx2db.downloader.download_zip', side_effect=fake_download): + with patch('src.tdx2db.downloader.config') as mock_cfg: + mock_cfg.download_url = '' + mock_cfg.use_tqdm = False + with download_and_extract(url='http://fake/', keep_tmp=False) as vipdoc_path: + captured_vipdoc = vipdoc_path + assert vipdoc_path.exists() + + # 退出后 vipdoc_path 所在的 tmp_dir 应已删除 + assert not captured_vipdoc.exists() + finally: + shutil.rmtree(src_dir, ignore_errors=True) + + def test_keep_tmp_preserves_directory(self): + """keep_tmp=True 时 tmp_dir 应保留。""" + import tempfile, shutil + src_dir = Path(tempfile.mkdtemp(prefix='tdx_test_src_')) + captured_tmp_dir = None + try: + zip_path = _make_fake_zip(src_dir, markets=['sh']) + + # 拦截 tempfile.mkdtemp 以记录 tmp_dir + real_mkdtemp = tempfile.mkdtemp + def fake_mkdtemp(**kwargs): + d = real_mkdtemp(**kwargs) + nonlocal captured_tmp_dir + captured_tmp_dir = Path(d) + return d + + def fake_download(url, dest_path, **kwargs): + shutil.copy(zip_path, dest_path) + + with patch('src.tdx2db.downloader.download_zip', side_effect=fake_download): + with patch('src.tdx2db.downloader.tempfile.mkdtemp', side_effect=fake_mkdtemp): + with patch('src.tdx2db.downloader.config') as mock_cfg: + mock_cfg.download_url = '' + mock_cfg.use_tqdm = False + with download_and_extract(url='http://fake/', keep_tmp=True): + pass + + assert captured_tmp_dir is not None + assert captured_tmp_dir.exists() + finally: + shutil.rmtree(src_dir, ignore_errors=True) + if captured_tmp_dir: + shutil.rmtree(captured_tmp_dir, ignore_errors=True) + + def test_uses_default_url_when_none_given(self): + """未传 url 时应使用 DEFAULT_DOWNLOAD_URL。""" + import tempfile, shutil + src_dir = Path(tempfile.mkdtemp(prefix='tdx_test_src_')) + called_urls = [] + try: + zip_path = _make_fake_zip(src_dir, markets=['sh']) + + def fake_download(url, dest_path, **kwargs): + called_urls.append(url) + shutil.copy(zip_path, dest_path) + + with patch('src.tdx2db.downloader.download_zip', side_effect=fake_download): + with patch('src.tdx2db.downloader.config') as mock_cfg: + mock_cfg.download_url = '' + mock_cfg.use_tqdm = False + with download_and_extract(url=None, keep_tmp=False): + pass + finally: + shutil.rmtree(src_dir, ignore_errors=True) + + assert called_urls and called_urls[0] == DEFAULT_DOWNLOAD_URL + + def test_config_url_overrides_default(self): + """config.download_url 非空时应优先于默认 URL。""" + import tempfile, shutil + src_dir = Path(tempfile.mkdtemp(prefix='tdx_test_src_')) + called_urls = [] + custom_url = 'http://my-custom-host/hsjday.zip' + try: + zip_path = _make_fake_zip(src_dir, markets=['sh']) + + def fake_download(url, dest_path, **kwargs): + called_urls.append(url) + shutil.copy(zip_path, dest_path) + + with patch('src.tdx2db.downloader.download_zip', side_effect=fake_download): + with patch('src.tdx2db.downloader.config') as mock_cfg: + mock_cfg.download_url = custom_url + mock_cfg.use_tqdm = False + with download_and_extract(url=None, keep_tmp=False): + pass + finally: + shutil.rmtree(src_dir, ignore_errors=True) + + assert called_urls and called_urls[0] == custom_url + + +# ─── TestReaderVipdocPath ──────────────────────────────────────────────────── + +class TestReaderVipdocPath: + """TdxDataReader 的 vipdoc_path 参数测试。""" + + def _write_day_files(self, base: Path, market: str, codes: list) -> None: + """在 base/{market}/lday/ 下写入假 .day 文件。""" + lday = base / market / 'lday' + lday.mkdir(parents=True, exist_ok=True) + for code in codes: + (lday / f'{code}.day').write_bytes(_make_day_bytes(5)) + + def test_vipdoc_path_get_stock_list(self, tmp_path): + """vipdoc_path 模式下 get_stock_list 应正确扫描三个市场目录。""" + vipdoc = tmp_path / 'hsjday' + self._write_day_files(vipdoc, 'sz', ['sz000001', 'sz000002']) + self._write_day_files(vipdoc, 'sh', ['sh600000']) + + reader = TdxDataReader(vipdoc_path=str(vipdoc)) + stocks = reader.get_stock_list() + + codes = set(stocks['code'].tolist()) + assert 'sz000001' in codes + assert 'sz000002' in codes + assert 'sh600000' in codes + + def test_vipdoc_path_read_daily_data(self, tmp_path): + """vipdoc_path 模式下 read_daily_data 应返回正确的 DataFrame。""" + vipdoc = tmp_path / 'hsjday' + self._write_day_files(vipdoc, 'sh', ['sh600000']) + + reader = TdxDataReader(vipdoc_path=str(vipdoc)) + df = reader.read_daily_data(market=1, code='sh600000') + + assert not df.empty + assert 'code' in df.columns + assert df['code'].iloc[0] == '600000' + assert 'open' in df.columns + + def test_vipdoc_path_read_gbbq_returns_empty(self, tmp_path): + """vipdoc_path 模式下 read_gbbq 应返回空 DataFrame(不报错)。""" + vipdoc = tmp_path / 'hsjday' + vipdoc.mkdir() + + reader = TdxDataReader(vipdoc_path=str(vipdoc)) + gbbq = reader.read_gbbq() + + assert isinstance(gbbq, pd.DataFrame) + assert gbbq.empty + + def test_vipdoc_path_raises_if_not_exists(self, tmp_path): + """vipdoc_path 不存在时应抛出 FileNotFoundError。""" + with pytest.raises(FileNotFoundError): + TdxDataReader(vipdoc_path=str(tmp_path / 'nonexistent')) + + def test_original_mode_still_works(self, tmp_path): + """不传 vipdoc_path 时,原有 tdx_path 模式应正常工作。""" + # 构造假的 TDX 目录结构 + tdx_path = tmp_path / 'tdx' + vipdoc = tdx_path / 'vipdoc' + (vipdoc / 'sz' / 'lday').mkdir(parents=True) + (vipdoc / 'sz' / 'lday' / 'sz000001.day').write_bytes(_make_day_bytes(3)) + + reader = TdxDataReader(tdx_path=str(tdx_path)) + assert reader.tdx_path == tdx_path + assert reader._vipdoc_path == tdx_path / 'vipdoc' + + +# ─── TestDownloadCommand ───────────────────────────────────────────────────── + +class TestDownloadCommand: + """CLI download 子命令的集成测试(mock 网络,使用内存 SQLite)。""" + + def _write_day_files(self, base: Path, market: str, codes: list) -> None: + lday = base / market / 'lday' + lday.mkdir(parents=True, exist_ok=True) + for code in codes: + (lday / f'{code}.day').write_bytes(_make_day_bytes(5)) + + def test_download_command_imports_data(self, tmp_path): + """download 命令完整流程:下载→解压→读取→入库。""" + from src.tdx2db.cli import main + from src.tdx2db.storage import DataStorage + + # 准备假的 vipdoc 目录 + vipdoc = tmp_path / 'hsjday' + self._write_day_files(vipdoc, 'sz', ['sz000001']) + self._write_day_files(vipdoc, 'sh', ['sh600000']) + + db_path = tmp_path / 'test.db' + + # mock download_and_extract 使其直接 yield 假 vipdoc 目录 + from contextlib import contextmanager + + @contextmanager + def fake_download_and_extract(url=None, keep_tmp=False): + yield vipdoc + + with patch('src.tdx2db.cli.download_and_extract', fake_download_and_extract): + with patch('src.tdx2db.cli.DataStorage') as MockStorage: + storage_instance = DataStorage(db_url=f'sqlite:///{db_path}') + MockStorage.return_value = storage_instance + with patch('src.tdx2db.cli.config') as mock_cfg: + mock_cfg.tdx_path = '' + mock_cfg.use_tqdm = False + mock_cfg.db_batch_size = 10000 + mock_cfg.download_url = '' + + result = main.__wrapped__() if hasattr(main, '__wrapped__') else None + # 直接调用同步逻辑 + from src.tdx2db.cli import sync_all_daily + from src.tdx2db.processor import DataProcessor + + reader = TdxDataReader(vipdoc_path=str(vipdoc)) + processor = DataProcessor() + gbbq = pd.DataFrame() + stats = sync_all_daily(reader, processor, storage_instance, gbbq, + adj_type='none', incremental=False) + + # 验证两只股票都有数据 + with storage_instance.engine.connect() as conn: + from sqlalchemy import text + count = conn.execute(text("SELECT COUNT(*) FROM daily_data")).fetchone()[0] + assert count > 0, "download 后应有数据写入数据库" + + def test_download_help_shows_subcommand(self): + """parse_args 应包含 download 子命令。""" + import sys + from src.tdx2db.cli import parse_args + + with patch('sys.argv', ['tdx2db', 'download', '--help']): + with pytest.raises(SystemExit) as exc: + parse_args() + assert exc.value.code == 0 + + def test_download_args_parsed_correctly(self): + """download 子命令参数应正确解析。""" + from src.tdx2db.cli import parse_args + + with patch('sys.argv', ['tdx2db', 'download', '--url', 'http://myhost/data.zip', + '--adj', 'none', '--no-clean']): + args = parse_args() + + assert args.command == 'download' + assert args.url == 'http://myhost/data.zip' + assert args.adj == 'none' + assert args.no_clean is True + + def test_download_default_args(self): + """download 子命令默认参数:adj=forward, no_clean=False。""" + from src.tdx2db.cli import parse_args + + with patch('sys.argv', ['tdx2db', 'download']): + args = parse_args() + + assert args.command == 'download' + assert args.adj == 'forward' + assert args.no_clean is False + assert args.url is None From a236be238fa71a82b5bf719c751c44313cf89ad6 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Tue, 7 Apr 2026 10:48:52 +0800 Subject: [PATCH 14/26] =?UTF-8?q?feat:=20stock=5Fcode=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E5=B8=82=E5=9C=BA=E5=90=8E=E7=BC=80=EF=BC=88.SZ/.SH/.BJ?= =?UTF-8?q?=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/tdx2db/__init__.py | 14 ++++++++++++-- src/tdx2db/cli.py | 6 ++++-- src/tdx2db/processor.py | 9 +++++++-- src/tdx2db/storage.py | 2 +- tests/test_daily.py | 42 ++++++++++++++++++++--------------------- 5 files changed, 45 insertions(+), 28 deletions(-) diff --git a/src/tdx2db/__init__.py b/src/tdx2db/__init__.py index d1178ee..bf4a70b 100644 --- a/src/tdx2db/__init__.py +++ b/src/tdx2db/__init__.py @@ -90,9 +90,19 @@ def get_daily( ) -> pd.DataFrame: """从数据库查询日线数据,date 列为 YYYYMMDD 整数。""" from sqlalchemy import text - pure_code = code[-6:] if len(code) > 6 else code + if '.' in code: + db_code = code.upper() + else: + pure = code[-6:] if len(code) > 6 else code + if pure.startswith('6'): + suffix = '.SH' + elif pure.startswith('8') or pure.startswith('92'): + suffix = '.BJ' + else: + suffix = '.SZ' + db_code = pure + suffix conditions = ["stock_code = :code"] - params: dict = {"code": pure_code} + params: dict = {"code": db_code} if start_date: conditions.append("date >= :start_date") params["start_date"] = str(start_date) diff --git a/src/tdx2db/cli.py b/src/tdx2db/cli.py index 3c19f4d..9a2eb4f 100644 --- a/src/tdx2db/cli.py +++ b/src/tdx2db/cli.py @@ -56,7 +56,9 @@ def sync_all_daily( code = stock['code'] market = 1 if code.startswith('sh') else (2 if code.startswith('bj') else 0) pure_code = code[-6:] if len(code) > 6 else code - last_date = latest_dates.get(pure_code) + suffix = {0: '.SZ', 1: '.SH', 2: '.BJ'}[market] + db_code = pure_code + suffix + last_date = latest_dates.get(db_code) try: data = reader.read_daily_data(market, code) @@ -85,7 +87,7 @@ def sync_all_daily( continue if needs_refresh: - storage.delete_stock_data(pure_code) + storage.delete_stock_data(db_code) storage.save_incremental(processed, 'daily_data', conflict_columns=('stock_code', 'date'), batch_size=config.db_batch_size) stats['success'] += 1 diff --git a/src/tdx2db/processor.py b/src/tdx2db/processor.py index 9a9ce5c..09f7bdf 100644 --- a/src/tdx2db/processor.py +++ b/src/tdx2db/processor.py @@ -184,8 +184,13 @@ def process_daily_data( if 'turnover_rate' not in processed.columns: processed['turnover_rate'] = None - # 重命名 code → stock_code 以对齐目标表结构 - processed = processed.rename(columns={'code': 'stock_code'}) + # 生成带市场后缀的 stock_code,如 000001.SZ / 600000.SH / 920001.BJ + _suffix_map = {0: '.SZ', 1: '.SH', 2: '.BJ'} + processed['stock_code'] = ( + processed['code'].astype(str).str.zfill(6) + + processed['market'].map(_suffix_map).fillna('.SZ') + ) + processed = processed.drop(columns=['code']) return processed diff --git a/src/tdx2db/storage.py b/src/tdx2db/storage.py index 1582184..aad5068 100644 --- a/src/tdx2db/storage.py +++ b/src/tdx2db/storage.py @@ -17,7 +17,7 @@ class DailyData(Base): __table_args__ = (UniqueConstraint('stock_code', 'date'),) id = Column(Integer, primary_key=True) - stock_code = Column(String(10), index=True) + stock_code = Column(String(12), index=True) market = Column(Integer) date = Column(String(8), index=True) # YYYYMMDD 字符串 open = Column(Float) diff --git a/tests/test_daily.py b/tests/test_daily.py index 8485ee5..8052771 100644 --- a/tests/test_daily.py +++ b/tests/test_daily.py @@ -112,12 +112,12 @@ def test_date_range_correct(self, tmp_path): # 按日期过滤:只取 20240101 之后 filtered = processor.filter_data(processed, start_date=20240101) - storage.save_incremental(filtered, 'daily_data', conflict_columns=('code', 'date')) + storage.save_incremental(filtered, 'daily_data', conflict_columns=('stock_code', 'date')) with storage.engine.connect() as conn: from sqlalchemy import text rows = conn.execute( - text("SELECT MIN(date), MAX(date), COUNT(*) FROM daily_data WHERE stock_code='000001'") + text("SELECT MIN(date), MAX(date), COUNT(*) FROM daily_data WHERE stock_code='000001.SZ'") ).fetchone() min_date, max_date, count = rows assert min_date >= '20240101', f"最小日期 {min_date} 应 >= 20240101" @@ -180,15 +180,15 @@ def test_no_duplicates_on_second_sync(self, tmp_path): # 第一次同步:20个交易日 df1 = make_daily_df('sz000001', 0, '2024-01-02', 20).reset_index() p1 = processor.process_daily_data(df1, gbbq=gbbq, adj_type='none') - storage.save_incremental(p1, 'daily_data', conflict_columns=('code', 'date')) + storage.save_incremental(p1, 'daily_data', conflict_columns=('stock_code', 'date')) # 第二次同步:同样的数据(模拟重复运行) - storage.save_incremental(p1, 'daily_data', conflict_columns=('code', 'date')) + storage.save_incremental(p1, 'daily_data', conflict_columns=('stock_code', 'date')) with storage.engine.connect() as conn: from sqlalchemy import text count = conn.execute( - text("SELECT COUNT(*) FROM daily_data WHERE stock_code='000001'") + text("SELECT COUNT(*) FROM daily_data WHERE stock_code='000001.SZ'") ).fetchone()[0] assert count == 20, f"重复同步后应仍为20条,实际{count}" @@ -201,21 +201,21 @@ def test_incremental_appends_new_records(self, tmp_path): # 第一次:前10个交易日 df1 = make_daily_df('sz000001', 0, '2024-01-02', 10).reset_index() p1 = processor.process_daily_data(df1, gbbq=gbbq, adj_type='none') - storage.save_incremental(p1, 'daily_data', conflict_columns=('code', 'date')) + storage.save_incremental(p1, 'daily_data', conflict_columns=('stock_code', 'date')) - last_date = storage.get_latest_date_by_code('000001') + last_date = storage.get_latest_date_by_code('000001.SZ') assert last_date is not None # 第二次:后10个交易日(增量,只取 last_date 之后) df2 = make_daily_df('sz000001', 0, '2024-01-02', 25).reset_index() p2 = processor.process_daily_data(df2, gbbq=gbbq, adj_type='none') p2_new = p2[p2['date'] > last_date] - storage.save_incremental(p2_new, 'daily_data', conflict_columns=('code', 'date')) + storage.save_incremental(p2_new, 'daily_data', conflict_columns=('stock_code', 'date')) with storage.engine.connect() as conn: from sqlalchemy import text count = conn.execute( - text("SELECT COUNT(*) FROM daily_data WHERE stock_code='000001'") + text("SELECT COUNT(*) FROM daily_data WHERE stock_code='000001.SZ'") ).fetchone()[0] assert count == 25, f"增量后应为25条,实际{count}" @@ -229,26 +229,26 @@ def test_full_refresh_on_ex_rights(self, tmp_path): df = make_daily_df('sz000001', 0, '2024-01-02', 20).reset_index() gbbq_empty = make_gbbq_empty() p1 = processor.process_daily_data(df, gbbq=gbbq_empty, adj_type='none') - storage.save_incremental(p1, 'daily_data', conflict_columns=('code', 'date')) + storage.save_incremental(p1, 'daily_data', conflict_columns=('stock_code', 'date')) # 模拟发现除权事件 → 删除旧数据 + 重写前复权数据 gbbq = make_gbbq_with_event('sz000001', 20240115) p2 = processor.process_daily_data(df, gbbq=gbbq, adj_type='forward') - storage.delete_stock_data('000001') - storage.save_incremental(p2, 'daily_data', conflict_columns=('code', 'date')) + storage.delete_stock_data('000001.SZ') + storage.save_incremental(p2, 'daily_data', conflict_columns=('stock_code', 'date')) with storage.engine.connect() as conn: from sqlalchemy import text count = conn.execute( - text("SELECT COUNT(*) FROM daily_data WHERE stock_code='000001'") + text("SELECT COUNT(*) FROM daily_data WHERE stock_code='000001.SZ'") ).fetchone()[0] # 重写后记录数应与原始相同 assert count == 20, f"全量重写后应为20条,实际{count}" # 除权日前的价格应已被调整(< 10.5) adj_row = conn.execute( - text("SELECT close FROM daily_data WHERE stock_code='000001' AND date < '20240115' ORDER BY date DESC LIMIT 1") + text("SELECT close FROM daily_data WHERE stock_code='000001.SZ' AND date < '20240115' ORDER BY date DESC LIMIT 1") ).fetchone() if adj_row: assert adj_row[0] < 10.5, f"前复权后除权日前收盘价应 < 10.5,实际 {adj_row[0]}" @@ -278,13 +278,13 @@ def test_all_three_markets_sync(self, tmp_path): with storage.engine.connect() as conn: from sqlalchemy import text - for pure_code, expected_market in [('000001', 0), ('600000', 1), ('920001', 2)]: + for db_code, expected_market in [('000001.SZ', 0), ('600000.SH', 1), ('920001.BJ', 2)]: row = conn.execute(text( - f"SELECT market, COUNT(*) FROM daily_data WHERE stock_code='{pure_code}' GROUP BY market" + f"SELECT market, COUNT(*) FROM daily_data WHERE stock_code='{db_code}' GROUP BY market" )).fetchone() - assert row is not None, f"{pure_code} 无数据" - assert row[0] == expected_market, f"{pure_code} market 应为 {expected_market},实际 {row[0]}" - assert row[1] == 10, f"{pure_code} 应有10条记录,实际 {row[1]}" + assert row is not None, f"{db_code} 无数据" + assert row[0] == expected_market, f"{db_code} market 应为 {expected_market},实际 {row[0]}" + assert row[1] == 10, f"{db_code} 应有10条记录,实际 {row[1]}" def test_bj_ex_rights_refresh(self, tmp_path): """北交所股票有除权事件时,旧数据应被删除并重写。""" @@ -298,13 +298,13 @@ def test_bj_ex_rights_refresh(self, tmp_path): p1 = processor.process_daily_data(df, gbbq=make_gbbq_empty(), adj_type='none') storage.save_incremental(p1, 'daily_data', conflict_columns=('stock_code', 'date')) - storage.delete_stock_data('920001') + storage.delete_stock_data('920001.BJ') p2 = processor.process_daily_data(df, gbbq=gbbq, adj_type='forward') storage.save_incremental(p2, 'daily_data', conflict_columns=('stock_code', 'date')) with storage.engine.connect() as conn: from sqlalchemy import text count = conn.execute(text( - "SELECT COUNT(*) FROM daily_data WHERE stock_code='920001'" + "SELECT COUNT(*) FROM daily_data WHERE stock_code='920001.BJ'" )).fetchone()[0] assert count == 20 From f5122d05b2636cb018d113273321d31aef52ba40 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Tue, 7 Apr 2026 12:57:12 +0800 Subject: [PATCH 15/26] =?UTF-8?q?feat:=20stock=5Finfo=20=E8=A1=A8=E7=BB=93?= =?UTF-8?q?=E6=9E=84=E6=94=B9=E9=80=A0=EF=BC=8C=E4=BD=BF=E7=94=A8=20akshar?= =?UTF-8?q?e=20=E8=8E=B7=E5=8F=96=E8=82=A1=E7=A5=A8=E4=B8=AD=E6=96=87?= =?UTF-8?q?=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - stock_info 表字段改为 stock_code/stock_name,与 daily_data 格式统一 - stock_code 格式统一为 000001.SZ,去掉 id/market 字段 - get_stock_list() 返回 list[str],不再返回 DataFrame - stock-list 命令通过 akshare.stock_info_a_code_name() 获取真实中文名 - save_stock_info 改为 upsert,支持名称更新 --- requirements.txt | 1 + src/tdx2db/cli.py | 37 ++++++++++++++++++++++++++----------- src/tdx2db/reader.py | 25 +++++++++++-------------- src/tdx2db/storage.py | 40 +++++++++++++++++++++++++++++++++------- 4 files changed, 71 insertions(+), 32 deletions(-) diff --git a/requirements.txt b/requirements.txt index acfe878..7809bbd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ python-dotenv>=1.1.0 pymysql>=1.1.1 psycopg2-binary>=2.9.10 requests>=2.32.0 +akshare diff --git a/src/tdx2db/cli.py b/src/tdx2db/cli.py index 9a2eb4f..022e389 100644 --- a/src/tdx2db/cli.py +++ b/src/tdx2db/cli.py @@ -50,14 +50,12 @@ def sync_all_daily( latest_dates = storage.get_all_latest_dates() if incremental else {} stats = {'total': len(stocks), 'success': 0, 'failed': 0} - iterator = tqdm(stocks.iterrows(), total=len(stocks), desc="同步日线") if config.use_tqdm else stocks.iterrows() - - for _, stock in iterator: - code = stock['code'] - market = 1 if code.startswith('sh') else (2 if code.startswith('bj') else 0) - pure_code = code[-6:] if len(code) > 6 else code - suffix = {0: '.SZ', 1: '.SH', 2: '.BJ'}[market] - db_code = pure_code + suffix + iterator = tqdm(stocks, total=len(stocks), desc="同步日线") if config.use_tqdm else stocks + + for db_code in iterator: + pure_code, suffix = db_code.split('.') + market = {'SZ': 0, 'SH': 1, 'BJ': 2}[suffix] + code = suffix.lower() + pure_code # sz000001,供 read_daily_data 使用 last_date = latest_dates.get(db_code) try: @@ -199,9 +197,26 @@ def main() -> int: if args.command == 'stock-list': try: - stocks = reader.get_stock_list() - logger.info(f"获取到 {len(stocks)} 只股票") - storage.save_stock_info(stocks) + import akshare as ak + + ak_df = ak.stock_info_a_code_name() # columns: code, name(code 为纯6位) + + def _add_suffix(code: str) -> str: + if code.startswith('6'): + return code + '.SH' + elif code.startswith('8') or code.startswith('92'): + return code + '.BJ' + return code + '.SZ' + + ak_map = {_add_suffix(row['code']): row['name'] for _, row in ak_df.iterrows()} + + local_codes = reader.get_stock_list() + df = pd.DataFrame([ + {'stock_code': c, 'stock_name': ak_map.get(c, c)} + for c in local_codes + ]) + logger.info(f"获取到 {len(df)} 只股票") + storage.save_stock_info(df) except Exception as e: logger.error(f"同步股票列表出错: {e}") return 1 diff --git a/src/tdx2db/reader.py b/src/tdx2db/reader.py index 8a6c68b..1555816 100644 --- a/src/tdx2db/reader.py +++ b/src/tdx2db/reader.py @@ -47,39 +47,36 @@ def read_gbbq(self) -> pd.DataFrame: logger.warning(f"读取权息文件时出错: {e},将跳过复权处理") return pd.DataFrame() - def get_stock_list(self) -> pd.DataFrame: - """扫描本地 .day 文件获取 A 股股票列表(含市场前缀代码)。""" + def get_stock_list(self) -> list: + """扫描本地 .day 文件,返回有数据的股票代码列表(000001.SZ 格式)。""" sz_path = self._vipdoc_path / 'sz' / 'lday' sh_path = self._vipdoc_path / 'sh' / 'lday' bj_path = self._vipdoc_path / 'bj' / 'lday' if not (sz_path.exists() or sh_path.exists() or bj_path.exists()): raise FileNotFoundError("无法找到股票数据目录") - stocks = [] + codes = [] if sz_path.exists(): for f in sz_path.glob('*.day'): - code = f.stem - pure = code[-6:].zfill(6) + pure = f.stem[-6:].zfill(6) if re.match(r'^(000|001|002|300)\d{3}$', pure): - stocks.append({'code': code, 'name': f'深A{code}'}) + codes.append(pure + '.SZ') if sh_path.exists(): for f in sh_path.glob('*.day'): - code = f.stem - pure = code[-6:].zfill(6) + pure = f.stem[-6:].zfill(6) if re.match(r'^(60\d{4}|688\d{3})$', pure): - stocks.append({'code': code, 'name': f'上A{code}'}) + codes.append(pure + '.SH') if bj_path.exists(): for f in bj_path.glob('*.day'): - code = f.stem - pure = code[-6:].zfill(6) + pure = f.stem[-6:].zfill(6) if re.match(r'^(8\d{5}|92\d{4})$', pure): - stocks.append({'code': code, 'name': f'北A{code}'}) + codes.append(pure + '.BJ') - if not stocks: + if not codes: raise FileNotFoundError("未找到任何股票数据文件") - return pd.DataFrame(stocks, columns=['code', 'name']) + return codes def read_daily_data(self, market: int, code: str) -> pd.DataFrame: """读取单只股票日线数据,返回含 code/market 列的 DataFrame(date 为 DatetimeIndex)。""" diff --git a/src/tdx2db/storage.py b/src/tdx2db/storage.py index aad5068..d8bd3a8 100644 --- a/src/tdx2db/storage.py +++ b/src/tdx2db/storage.py @@ -32,12 +32,10 @@ class DailyData(Base): class StockInfo(Base): __tablename__ = 'stock_info' - __table_args__ = (UniqueConstraint('code'),) + __table_args__ = (UniqueConstraint('stock_code'),) - id = Column(Integer, primary_key=True) - code = Column(String(10), index=True) - name = Column(String(50)) - market = Column(Integer) + stock_code = Column(String(12), primary_key=True) # 000001.SZ + stock_name = Column(String(50)) _VALID_TABLES = frozenset({'daily_data', 'stock_info'}) @@ -155,8 +153,36 @@ def _save_incremental_pg(self, df, columns, columns_str, table_name, conflict_co raw_conn.close() def save_stock_info(self, df: pd.DataFrame) -> bool: - """保存股票列表到 stock_info 表(增量,跳过重复)。""" - return self.save_incremental(df, 'stock_info', conflict_columns=('code',)) > 0 + """保存股票列表到 stock_info 表(upsert,名称可更新)。""" + if df.empty: + return False + try: + with self.engine.connect() as conn: + if self._db_type == 'postgresql': + sql = text(""" + INSERT INTO stock_info (stock_code, stock_name) + VALUES (:stock_code, :stock_name) + ON CONFLICT (stock_code) DO UPDATE SET stock_name = EXCLUDED.stock_name + """) + elif self._db_type == 'mysql': + sql = text(""" + INSERT INTO stock_info (stock_code, stock_name) + VALUES (:stock_code, :stock_name) + ON DUPLICATE KEY UPDATE stock_name = VALUES(stock_name) + """) + else: # sqlite + sql = text(""" + INSERT INTO stock_info (stock_code, stock_name) + VALUES (:stock_code, :stock_name) + ON CONFLICT (stock_code) DO UPDATE SET stock_name = excluded.stock_name + """) + conn.execute(sql, df.to_dict('records')) + conn.commit() + logger.debug(f"stock_info upsert 完成: {len(df)} 条") + return True + except Exception as e: + logger.error(f"保存 stock_info 出错: {e}") + return False def save_to_csv(self, df: pd.DataFrame, filename: str, csv_path: Optional[str] = None) -> Optional[str]: path = Path(csv_path or config.csv_output_path) From 9ca5ac42e3b90d2d1e719a450fb2c392bb6bf5d2 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Tue, 7 Apr 2026 13:06:28 +0800 Subject: [PATCH 16/26] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=20kline=5Fstat?= =?UTF-8?q?istics=20=E8=A1=A8=E8=AE=B0=E5=BD=95=E6=AF=8F=E6=AC=A1=E5=90=8C?= =?UTF-8?q?=E6=AD=A5=E7=BB=9F=E8=AE=A1=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 每次 sync/daily/download 命令完成后,自动写入一条统计记录, 包含 stock_count(成功同步股票数)、total_rows(daily_data 总行数) 和 sync_time(同步时间)。 Co-Authored-By: Claude Sonnet 4.6 --- src/tdx2db/cli.py | 9 ++++++--- src/tdx2db/storage.py | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/src/tdx2db/cli.py b/src/tdx2db/cli.py index 022e389..0031790 100644 --- a/src/tdx2db/cli.py +++ b/src/tdx2db/cli.py @@ -185,8 +185,9 @@ def main() -> int: logger.info("=== 开始联网下载 TDX 日线数据 ===") with download_and_extract(url=url, keep_tmp=keep_tmp) as vipdoc_path: dl_reader = TdxDataReader(vipdoc_path=str(vipdoc_path)) - sync_all_daily(dl_reader, processor, storage, gbbq, + stats = sync_all_daily(dl_reader, processor, storage, gbbq, adj_type=adj_type, incremental=True) + storage.save_sync_statistics(stats['success']) return 0 try: @@ -248,16 +249,18 @@ def _add_suffix(code: str) -> str: return 1 else: incremental = getattr(args, 'incremental', False) - sync_all_daily(reader, processor, storage, gbbq, + stats = sync_all_daily(reader, processor, storage, gbbq, adj_type=adj_type, incremental=incremental, start_date=args.start, end_date=args.end) + storage.save_sync_statistics(stats['success']) elif args.command == 'sync': adj_type = getattr(args, 'adj', 'forward') logger.info("=== 开始增量同步日线数据 ===") gbbq = reader.read_gbbq() - sync_all_daily(reader, processor, storage, gbbq, + stats = sync_all_daily(reader, processor, storage, gbbq, adj_type=adj_type, incremental=True) + storage.save_sync_statistics(stats['success']) else: logger.error("请指定子命令,使用 -h 查看帮助") diff --git a/src/tdx2db/storage.py b/src/tdx2db/storage.py index d8bd3a8..9459229 100644 --- a/src/tdx2db/storage.py +++ b/src/tdx2db/storage.py @@ -1,9 +1,10 @@ +import datetime import os from pathlib import Path from typing import Optional, Tuple import pandas as pd -from sqlalchemy import create_engine, Column, Integer, Float, String, UniqueConstraint, text +from sqlalchemy import create_engine, Column, Integer, Float, String, UniqueConstraint, text, DateTime from sqlalchemy.orm import declarative_base, sessionmaker from .config import config @@ -38,6 +39,15 @@ class StockInfo(Base): stock_name = Column(String(50)) +class KlineStatistics(Base): + __tablename__ = 'kline_statistics' + + id = Column(Integer, primary_key=True) + stock_count = Column(Integer) + total_rows = Column(Integer) + sync_time = Column(DateTime) + + _VALID_TABLES = frozenset({'daily_data', 'stock_info'}) @@ -184,6 +194,28 @@ def save_stock_info(self, df: pd.DataFrame) -> bool: logger.error(f"保存 stock_info 出错: {e}") return False + def save_sync_statistics(self, stock_count: int) -> None: + """同步完成后写入一条统计记录到 kline_statistics。""" + try: + with self.engine.connect() as conn: + row = conn.execute(text("SELECT COUNT(*) FROM daily_data")).fetchone() + total_rows = row[0] if row else 0 + conn.execute( + text( + "INSERT INTO kline_statistics (stock_count, total_rows, sync_time) " + "VALUES (:stock_count, :total_rows, :sync_time)" + ), + { + "stock_count": stock_count, + "total_rows": total_rows, + "sync_time": datetime.datetime.now(), + } + ) + conn.commit() + logger.info(f"统计已记录: stock_count={stock_count}, total_rows={total_rows}") + except Exception as e: + logger.error(f"写入 kline_statistics 出错: {e}") + def save_to_csv(self, df: pd.DataFrame, filename: str, csv_path: Optional[str] = None) -> Optional[str]: path = Path(csv_path or config.csv_output_path) os.makedirs(path, exist_ok=True) From e5f4118482b05635e7ca9805ca1e3a29ae01fdf4 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Wed, 8 Apr 2026 17:15:53 +0800 Subject: [PATCH 17/26] =?UTF-8?q?fix:=20=E6=B7=B1=E5=9C=B3=20A=20=E8=82=A1?= =?UTF-8?q?=E7=AD=9B=E9=80=89=E8=A7=84=E5=88=99=E6=96=B0=E5=A2=9E=20301=20?= =?UTF-8?q?=E5=BC=80=E5=A4=B4=EF=BC=88=E5=88=9B=E4=B8=9A=E6=9D=BF=E6=B3=A8?= =?UTF-8?q?=E5=86=8C=E5=88=B6=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/tdx2db/reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tdx2db/reader.py b/src/tdx2db/reader.py index 1555816..38e2e9c 100644 --- a/src/tdx2db/reader.py +++ b/src/tdx2db/reader.py @@ -59,7 +59,7 @@ def get_stock_list(self) -> list: if sz_path.exists(): for f in sz_path.glob('*.day'): pure = f.stem[-6:].zfill(6) - if re.match(r'^(000|001|002|300)\d{3}$', pure): + if re.match(r'^(000|001|002|300|301)\d{3}$', pure): codes.append(pure + '.SZ') if sh_path.exists(): From 772ffc73f4df4fdfbe6feb719ff80d588793fa32 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Wed, 8 Apr 2026 18:37:25 +0800 Subject: [PATCH 18/26] =?UTF-8?q?feat:=20sync=20=E5=91=BD=E4=BB=A4?= =?UTF-8?q?=E8=87=AA=E5=8A=A8=E5=90=8C=E6=AD=A5=E8=82=A1=E7=A5=A8=E5=88=97?= =?UTF-8?q?=E8=A1=A8=E5=90=8D=E7=A7=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/tdx2db/cli.py | 54 +++++++++++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/src/tdx2db/cli.py b/src/tdx2db/cli.py index 0031790..ed97064 100644 --- a/src/tdx2db/cli.py +++ b/src/tdx2db/cli.py @@ -33,6 +33,35 @@ def _has_ex_rights_after(code: str, gbbq: pd.DataFrame, last_date: int) -> bool: return not events.empty +def sync_stock_list(reader: TdxDataReader, storage: DataStorage) -> bool: + """同步股票列表及名称,返回是否成功。""" + try: + import akshare as ak + + ak_df = ak.stock_info_a_code_name() + + def _add_suffix(code: str) -> str: + if code.startswith('6'): + return code + '.SH' + elif code.startswith('8') or code.startswith('92'): + return code + '.BJ' + return code + '.SZ' + + ak_map = {_add_suffix(row['code']): row['name'] for _, row in ak_df.iterrows()} + + local_codes = reader.get_stock_list() + df = pd.DataFrame([ + {'stock_code': c, 'stock_name': ak_map.get(c, c)} + for c in local_codes + ]) + logger.info(f"获取到 {len(df)} 只股票") + storage.save_stock_info(df) + return True + except Exception as e: + logger.error(f"同步股票列表出错: {e}") + return False + + def sync_all_daily( reader: TdxDataReader, processor: DataProcessor, @@ -197,29 +226,7 @@ def main() -> int: return 1 if args.command == 'stock-list': - try: - import akshare as ak - - ak_df = ak.stock_info_a_code_name() # columns: code, name(code 为纯6位) - - def _add_suffix(code: str) -> str: - if code.startswith('6'): - return code + '.SH' - elif code.startswith('8') or code.startswith('92'): - return code + '.BJ' - return code + '.SZ' - - ak_map = {_add_suffix(row['code']): row['name'] for _, row in ak_df.iterrows()} - - local_codes = reader.get_stock_list() - df = pd.DataFrame([ - {'stock_code': c, 'stock_name': ak_map.get(c, c)} - for c in local_codes - ]) - logger.info(f"获取到 {len(df)} 只股票") - storage.save_stock_info(df) - except Exception as e: - logger.error(f"同步股票列表出错: {e}") + if not sync_stock_list(reader, storage): return 1 elif args.command == 'daily': @@ -257,6 +264,7 @@ def _add_suffix(code: str) -> str: elif args.command == 'sync': adj_type = getattr(args, 'adj', 'forward') logger.info("=== 开始增量同步日线数据 ===") + sync_stock_list(reader, storage) gbbq = reader.read_gbbq() stats = sync_all_daily(reader, processor, storage, gbbq, adj_type=adj_type, incremental=True) From 1dd6a778d185ee362b0ceeca0de2185d1a436130 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Sat, 11 Apr 2026 11:31:02 +0800 Subject: [PATCH 19/26] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=20SMB=20?= =?UTF-8?q?=E7=BD=91=E7=BB=9C=E8=AE=BF=E9=97=AE=E6=A8=A1=E5=BC=8F=EF=BC=8C?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E8=B7=A8=E6=9C=BA=E5=99=A8=E8=AF=BB=E5=8F=96?= =?UTF-8?q?=E8=BF=9C=E7=A8=8B=20TDX=20=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 smb_accessor.py:SmbAccessor 类封装所有 SMB I/O,使用 smbprotocol 库 - config.py 追加 SMB_ENABLED/HOST/SHARE/USER/PASSWORD/TDX_PATH/PORT 共 7 个配置项 - reader.py:TdxDataReader 新增 smb 参数,三个核心方法增加 SMB 分支,临时文件方式适配 pytdx - cli.py:新增 --smb-* 命令行参数,提取 _create_reader() 统一处理本地/SMB 两种模式 - 新增 tests/test_smb.py:31 个测试用例,全部 mock,覆盖路径构建、I/O、Reader 和 CLI - .env.example 和 README 补充 SMB 配置说明,注明需使用 Windows 本地账户 Co-Authored-By: Claude Sonnet 4.6 --- .env.example | 11 ++ README.md | 40 ++++ requirements.txt | 1 + src/tdx2db/cli.py | 140 +++++++++----- src/tdx2db/config.py | 14 ++ src/tdx2db/reader.py | 112 +++++++++-- src/tdx2db/smb_accessor.py | 134 +++++++++++++ tests/test_smb.py | 376 +++++++++++++++++++++++++++++++++++++ 8 files changed, 767 insertions(+), 61 deletions(-) create mode 100644 src/tdx2db/smb_accessor.py create mode 100644 tests/test_smb.py diff --git a/.env.example b/.env.example index 2a7de69..156cdf1 100644 --- a/.env.example +++ b/.env.example @@ -19,3 +19,14 @@ LOG_LEVEL=INFO # 处理选项 USE_TQDM=True # 是否显示进度条 + +# SMB 网络访问配置(可选,用于访问远程 PC 上的 TDX 安装目录) +# 启用后 TDX_PATH 可不填 +SMB_ENABLED=false +SMB_HOST=192.168.1.100 +SMB_SHARE=tdx_share +SMB_USER=myuser +SMB_PASSWORD=mypassword +# TDX 在共享目录内的相对路径,若共享根目录即为 TDX 安装目录则留空 +SMB_TDX_PATH=TDX +SMB_PORT=445 diff --git a/README.md b/README.md index 7b97138..1715f93 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,46 @@ DB_BATCH_SIZE=10000 USE_TQDM=True ``` +## SMB 网络访问模式 + +如果通达信安装在另一台 Windows PC 上,可以通过 SMB 协议远程读取数据,无需把软件安装在运行本程序的机器上。 + +### 1. 在 Windows PC 上共享 TDX 目录 + +右键 TDX 安装目录(如 `D:\new_tdx64`)→ 属性 → 共享 → 高级共享: + +- 勾选"共享此文件夹" +- 设置共享名,如 `new_tdx64` +- 权限 → 添加需要访问的账户,授予"读取"权限 + +### 2. 创建专用本地账户(推荐) + +**强烈建议**新建一个 Windows 本地账户用于 SMB 访问,而不是使用微软账户。微软账户通过 NTLM 网络认证时兼容性较差,容易登录失败。 + +在 Windows PC 上以管理员身份打开 PowerShell: + +```powershell +# 创建本地账户(替换为你想要的用户名和密码) +net user tdxread YourPassword123 /add +net localgroup Users tdxread /add +``` + +然后在共享权限中把 `tdxread` 加入,授予读取权限。 + +### 3. 配置 .env + +``` +SMB_ENABLED=true +SMB_HOST=192.168.1.100 # Windows PC 的 IP 或主机名 +SMB_SHARE=new_tdx64 # 共享名 +SMB_USER=tdxread # 本地账户用户名 +SMB_PASSWORD=YourPassword123 +SMB_TDX_PATH= # TDX 在共享内的相对路径,共享根目录就是 TDX 目录时留空 +SMB_PORT=445 +``` + +启用 SMB 模式后,`TDX_PATH` 可以不填。 + ## 命令行使用 ```bash diff --git a/requirements.txt b/requirements.txt index 7809bbd..09c48aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ pymysql>=1.1.1 psycopg2-binary>=2.9.10 requests>=2.32.0 akshare +smbprotocol>=1.13.0 diff --git a/src/tdx2db/cli.py b/src/tdx2db/cli.py index ed97064..cf877c5 100644 --- a/src/tdx2db/cli.py +++ b/src/tdx2db/cli.py @@ -141,6 +141,12 @@ def parse_args() -> Namespace: parser.add_argument('--db-password') parser.add_argument('--no-tqdm', action='store_true') parser.add_argument('--batch-size', type=int, default=10000) + parser.add_argument('--smb-host', help='SMB 服务器地址') + parser.add_argument('--smb-share', help='SMB 共享名') + parser.add_argument('--smb-user', help='SMB 用户名') + parser.add_argument('--smb-password', help='SMB 密码') + parser.add_argument('--smb-tdx-path', help='TDX 在共享目录内的相对路径', default='') + parser.add_argument('--smb-port', type=int, default=445) subparsers = parser.add_subparsers(dest='command') @@ -188,6 +194,38 @@ def update_config(args: Namespace) -> None: config.db_batch_size = args.batch_size if args.no_tqdm: config.use_tqdm = False + if getattr(args, 'smb_host', None): + config.smb_host = args.smb_host + config.smb_enabled = True + if getattr(args, 'smb_share', None): + config.smb_share = args.smb_share + if getattr(args, 'smb_user', None): + config.smb_user = args.smb_user + if getattr(args, 'smb_password', None): + config.smb_password = args.smb_password + if getattr(args, 'smb_tdx_path', None) is not None: + config.smb_tdx_path = args.smb_tdx_path + if getattr(args, 'smb_port', None): + config.smb_port = args.smb_port + + +def _create_reader(): + """根据配置创建 TdxDataReader,返回 (reader, smb_accessor_or_None)。""" + if config.smb_enabled: + if not config.smb_host or not config.smb_share: + raise ValueError("SMB 模式需要设置 SMB_HOST 和 SMB_SHARE") + from .smb_accessor import SmbAccessor + smb = SmbAccessor( + host=config.smb_host, + share=config.smb_share, + tdx_path=config.smb_tdx_path, + username=config.smb_user or None, + password=config.smb_password or None, + port=config.smb_port, + ) + smb._register() + return TdxDataReader(smb=smb), smb + return TdxDataReader(), None def main() -> int: @@ -203,7 +241,16 @@ def main() -> int: url = getattr(args, 'url', None) gbbq = pd.DataFrame() - if config.tdx_path: + if config.smb_enabled: + try: + smb_reader, smb_acc = _create_reader() + gbbq = smb_reader.read_gbbq() + if smb_acc: + smb_acc._unregister() + logger.info("已从 SMB 读取权息文件") + except Exception: + logger.warning("SMB 权息文件读取失败,将跳过复权处理") + elif config.tdx_path: try: local_reader = TdxDataReader() gbbq = local_reader.read_gbbq() @@ -220,59 +267,64 @@ def main() -> int: return 0 try: - reader = TdxDataReader() + reader, smb_accessor = _create_reader() except (ValueError, FileNotFoundError) as e: logger.error(f"初始化失败: {e}") return 1 - if args.command == 'stock-list': - if not sync_stock_list(reader, storage): - return 1 + try: + if args.command == 'stock-list': + if not sync_stock_list(reader, storage): + return 1 - elif args.command == 'daily': - adj_type = getattr(args, 'adj', 'forward') - gbbq = reader.read_gbbq() - - if args.code: - pure_code = args.code[-6:] if len(args.code) > 6 else args.code - if pure_code.startswith('6'): - market = 1 - elif pure_code.startswith('8') or pure_code.startswith('92'): - market = 2 + elif args.command == 'daily': + adj_type = getattr(args, 'adj', 'forward') + gbbq = reader.read_gbbq() + + if args.code: + pure_code = args.code[-6:] if len(args.code) > 6 else args.code + if pure_code.startswith('6'): + market = 1 + elif pure_code.startswith('8') or pure_code.startswith('92'): + market = 2 + else: + market = 0 + prefix = {0: 'sz', 1: 'sh', 2: 'bj'}[market] + code = prefix + pure_code + try: + data = reader.read_daily_data(market, code) + if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('date', 'datetime'): + data = data.reset_index() + processed = processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type) + processed = processor.filter_data(processed, start_date=args.start, end_date=args.end) + if not processed.empty: + storage.save_incremental(processed, 'daily_data', conflict_columns=('stock_code', 'date')) + except Exception as e: + logger.error(f"同步 {code} 出错: {e}") + return 1 else: - market = 0 - prefix = {0: 'sz', 1: 'sh', 2: 'bj'}[market] - code = prefix + pure_code - try: - data = reader.read_daily_data(market, code) - if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('date', 'datetime'): - data = data.reset_index() - processed = processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type) - processed = processor.filter_data(processed, start_date=args.start, end_date=args.end) - if not processed.empty: - storage.save_incremental(processed, 'daily_data', conflict_columns=('stock_code', 'date')) - except Exception as e: - logger.error(f"同步 {code} 出错: {e}") - return 1 - else: - incremental = getattr(args, 'incremental', False) + incremental = getattr(args, 'incremental', False) + stats = sync_all_daily(reader, processor, storage, gbbq, + adj_type=adj_type, incremental=incremental, + start_date=args.start, end_date=args.end) + storage.save_sync_statistics(stats['success']) + + elif args.command == 'sync': + adj_type = getattr(args, 'adj', 'forward') + logger.info("=== 开始增量同步日线数据 ===") + sync_stock_list(reader, storage) + gbbq = reader.read_gbbq() stats = sync_all_daily(reader, processor, storage, gbbq, - adj_type=adj_type, incremental=incremental, - start_date=args.start, end_date=args.end) + adj_type=adj_type, incremental=True) storage.save_sync_statistics(stats['success']) - elif args.command == 'sync': - adj_type = getattr(args, 'adj', 'forward') - logger.info("=== 开始增量同步日线数据 ===") - sync_stock_list(reader, storage) - gbbq = reader.read_gbbq() - stats = sync_all_daily(reader, processor, storage, gbbq, - adj_type=adj_type, incremental=True) - storage.save_sync_statistics(stats['success']) + else: + logger.error("请指定子命令,使用 -h 查看帮助") + return 1 - else: - logger.error("请指定子命令,使用 -h 查看帮助") - return 1 + finally: + if smb_accessor is not None: + smb_accessor._unregister() return 0 diff --git a/src/tdx2db/config.py b/src/tdx2db/config.py index 707d5f7..58068b1 100644 --- a/src/tdx2db/config.py +++ b/src/tdx2db/config.py @@ -16,6 +16,13 @@ class Config: db_batch_size: int use_tqdm: bool download_url: str + smb_enabled: bool + smb_host: str + smb_share: str + smb_user: str + smb_password: str + smb_tdx_path: str + smb_port: int def __init__(self) -> None: self.tdx_path = os.getenv('TDX_PATH', '') @@ -29,6 +36,13 @@ def __init__(self) -> None: self.db_batch_size = int(os.getenv('DB_BATCH_SIZE', '10000')) self.use_tqdm = os.getenv('USE_TQDM', 'True').lower() == 'true' self.download_url = os.getenv('TDX_DOWNLOAD_URL', '') + self.smb_enabled = os.getenv('SMB_ENABLED', 'false').lower() == 'true' + self.smb_host = os.getenv('SMB_HOST', '') + self.smb_share = os.getenv('SMB_SHARE', '') + self.smb_user = os.getenv('SMB_USER', '') + self.smb_password = os.getenv('SMB_PASSWORD', '') + self.smb_tdx_path = os.getenv('SMB_TDX_PATH', '') + self.smb_port = int(os.getenv('SMB_PORT', '445')) @property def database_url(self) -> str: diff --git a/src/tdx2db/reader.py b/src/tdx2db/reader.py index 38e2e9c..f052e78 100644 --- a/src/tdx2db/reader.py +++ b/src/tdx2db/reader.py @@ -1,7 +1,10 @@ +import io +import os import re import struct +from contextlib import redirect_stdout from pathlib import Path -from typing import Optional +from typing import Optional, TYPE_CHECKING import pandas as pd from pytdx.reader import TdxDailyBarReader, GbbqReader @@ -9,10 +12,23 @@ from .config import config from .logger import logger +if TYPE_CHECKING: + from .smb_accessor import SmbAccessor + class TdxDataReader: - def __init__(self, tdx_path: Optional[str] = None, vipdoc_path: Optional[str] = None) -> None: - if vipdoc_path: + def __init__( + self, + tdx_path: Optional[str] = None, + vipdoc_path: Optional[str] = None, + smb: Optional['SmbAccessor'] = None, + ) -> None: + self._smb = smb + + if smb is not None: + self.tdx_path = None + self._vipdoc_path = None + elif vipdoc_path: self.tdx_path = None self._vipdoc_path = Path(vipdoc_path) if not self._vipdoc_path.exists(): @@ -29,6 +45,8 @@ def __init__(self, tdx_path: Optional[str] = None, vipdoc_path: Optional[str] = def read_gbbq(self) -> pd.DataFrame: """读取权息文件,返回全量权息 DataFrame。文件不存在时返回空 DataFrame。""" + if self._smb is not None: + return self._read_gbbq_smb() if self.tdx_path is None: logger.warning("联网下载模式下不支持读取 gbbq,将跳过复权处理") return pd.DataFrame() @@ -49,6 +67,8 @@ def read_gbbq(self) -> pd.DataFrame: def get_stock_list(self) -> list: """扫描本地 .day 文件,返回有数据的股票代码列表(000001.SZ 格式)。""" + if self._smb is not None: + return self._get_stock_list_smb() sz_path = self._vipdoc_path / 'sz' / 'lday' sh_path = self._vipdoc_path / 'sh' / 'lday' bj_path = self._vipdoc_path / 'bj' / 'lday' @@ -83,22 +103,26 @@ def read_daily_data(self, market: int, code: str) -> pd.DataFrame: market_map = {0: 'sz', 1: 'sh', 2: 'bj'} market_folder = market_map[market] pure_code = code[-6:] if len(code) > 6 else code - file_path = self._vipdoc_path / market_folder / 'lday' / f"{market_folder}{pure_code}.day" + filename = f"{market_folder}{pure_code}.day" - if not file_path.exists(): - raise FileNotFoundError(f"日线数据文件不存在: {file_path}") - - import io - from contextlib import redirect_stdout - try: - with redirect_stdout(io.StringIO()): - sec_type = self.daily_reader.get_security_type(str(file_path)) - if sec_type in self.daily_reader.SECURITY_TYPE: - data = self.daily_reader.get_df(str(file_path)) - else: + if self._smb is not None: + unc = self._smb.day_file_unc(market_folder, filename) + if not self._smb.exists(unc): + raise FileNotFoundError(f"SMB 日线数据文件不存在: {unc}") + data = self._read_daily_via_smb(unc) + else: + file_path = self._vipdoc_path / market_folder / 'lday' / filename + if not file_path.exists(): + raise FileNotFoundError(f"日线数据文件不存在: {file_path}") + try: + with redirect_stdout(io.StringIO()): + sec_type = self.daily_reader.get_security_type(str(file_path)) + if sec_type in self.daily_reader.SECURITY_TYPE: + data = self.daily_reader.get_df(str(file_path)) + else: + data = self._read_day_file_raw(str(file_path)) + except Exception: data = self._read_day_file_raw(str(file_path)) - except Exception: - data = self._read_day_file_raw(str(file_path)) data['code'] = pure_code data['market'] = market @@ -122,3 +146,57 @@ def _read_day_file_raw(fname: str) -> pd.DataFrame: df.index = pd.to_datetime(df['date']) df.index.name = 'date' return df[['open', 'high', 'low', 'close', 'amount', 'volume']] + + # ── SMB 私有方法 ────────────────────────────────────────────────────────── + + def _read_gbbq_smb(self) -> pd.DataFrame: + unc = self._smb.gbbq_unc + if not self._smb.exists(unc): + logger.warning(f"SMB 权息文件不存在: {unc},将跳过复权处理") + return pd.DataFrame() + tmp_path = self._smb.download_to_tmp(unc, suffix='') + try: + df = self.gbbq_reader.get_df(tmp_path) + if df.empty: + return pd.DataFrame() + market_prefix = df['market'].map({0: 'sz', 1: 'sh'}) + df['full_code'] = market_prefix + df['code'].astype(str).str.zfill(6) + return df + except Exception as e: + logger.warning(f"SMB 读取权息文件时出错: {e},将跳过复权处理") + return pd.DataFrame() + finally: + os.unlink(tmp_path) + + def _get_stock_list_smb(self) -> list: + codes = [] + for market, pattern, suffix in [ + ('sz', r'^(000|001|002|300|301)\d{3}$', '.SZ'), + ('sh', r'^(60\d{4}|688\d{3})$', '.SH'), + ('bj', r'^(8\d{5}|92\d{4})$', '.BJ'), + ]: + unc_dir = self._smb.lday_dir_unc(market) + files = self._smb.list_files(unc_dir, suffix='.day') + for fname in files: + stem = Path(fname).stem + pure = stem[-6:].zfill(6) + if re.match(pattern, pure): + codes.append(pure + suffix) + if not codes: + raise FileNotFoundError("SMB 模式下未找到任何股票数据文件") + return codes + + def _read_daily_via_smb(self, unc: str) -> pd.DataFrame: + tmp_path = self._smb.download_to_tmp(unc, suffix='.day') + try: + try: + with redirect_stdout(io.StringIO()): + sec_type = self.daily_reader.get_security_type(tmp_path) + if sec_type in self.daily_reader.SECURITY_TYPE: + return self.daily_reader.get_df(tmp_path) + else: + return self._read_day_file_raw(tmp_path) + except Exception: + return self._read_day_file_raw(tmp_path) + finally: + os.unlink(tmp_path) diff --git a/src/tdx2db/smb_accessor.py b/src/tdx2db/smb_accessor.py new file mode 100644 index 0000000..eaf7deb --- /dev/null +++ b/src/tdx2db/smb_accessor.py @@ -0,0 +1,134 @@ +"""SMB 网络文件访问封装。 + +依赖:smbprotocol(pip install smbprotocol) +""" +import os +import tempfile +from typing import List, Optional + +import smbclient +import smbclient.path + +from .logger import logger + + +class SmbAccessor: + """封装对 SMB 共享目录的只读访问。 + + UNC 路径格式:\\\\host\\share\\tdx_path\\vipdoc\\... + """ + + def __init__( + self, + host: str, + share: str, + tdx_path: str = '', + username: Optional[str] = None, + password: Optional[str] = None, + port: int = 445, + ) -> None: + self.host = host + self.share = share.strip('\\/') + self._tdx_rel = tdx_path.strip('\\/') + self._username = username + self._password = password + self._port = port + self._registered = False + + # ── 上下文管理器 ────────────────────────────────────────────────────────── + + def __enter__(self) -> 'SmbAccessor': + self._register() + return self + + def __exit__(self, *_) -> None: + self._unregister() + + def _register(self) -> None: + if not self._registered: + smbclient.register_session( + self.host, + username=self._username, + password=self._password, + port=self._port, + ) + self._registered = True + logger.debug(f"SMB 会话已建立: {self.host}:{self._port}") + + def _unregister(self) -> None: + if self._registered: + try: + smbclient.reset_connection_cache() + except Exception: + pass + self._registered = False + + # ── 路径构建 ────────────────────────────────────────────────────────────── + + def _unc(self, *parts: str) -> str: + """构建 UNC 路径字符串。 + + 示例(tdx_path='TDX'): + _unc('vipdoc', 'sz', 'lday') → '\\\\host\\share\\TDX\\vipdoc\\sz\\lday' + 示例(tdx_path=''): + _unc('vipdoc') → '\\\\host\\share\\vipdoc' + """ + segments = [self.host, self.share] + if self._tdx_rel: + segments.append(self._tdx_rel) + segments.extend(p.strip('\\/') for p in parts if p) + return '\\\\' + '\\'.join(segments) + + @property + def vipdoc_unc(self) -> str: + return self._unc('vipdoc') + + @property + def gbbq_unc(self) -> str: + return self._unc('T0002', 'hq_cache', 'gbbq') + + def lday_dir_unc(self, market: str) -> str: + return self._unc('vipdoc', market, 'lday') + + def day_file_unc(self, market: str, filename: str) -> str: + return self._unc('vipdoc', market, 'lday', filename) + + # ── 核心 I/O 操作 ───────────────────────────────────────────────────────── + + def exists(self, unc_path: str) -> bool: + try: + return smbclient.path.exists(unc_path) + except Exception: + return False + + def list_files(self, unc_dir: str, suffix: str = '') -> List[str]: + """列出目录下的文件名(不含路径),可按后缀过滤。""" + try: + entries = smbclient.listdir(unc_dir) + if suffix: + return [e for e in entries if e.endswith(suffix)] + return entries + except Exception as e: + logger.warning(f"SMB 列目录失败 {unc_dir}: {e}") + return [] + + def read_bytes(self, unc_path: str) -> bytes: + with smbclient.open_file(unc_path, mode='rb') as f: + return f.read() + + def download_to_tmp(self, unc_path: str, suffix: str = '.day') -> str: + """将远程文件下载到本地临时文件,返回临时文件路径字符串。 + + 调用方负责删除临时文件(使用 try/finally)。 + """ + data = self.read_bytes(unc_path) + tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) + try: + tmp.write(data) + tmp.flush() + tmp.close() + return tmp.name + except Exception: + tmp.close() + os.unlink(tmp.name) + raise diff --git a/tests/test_smb.py b/tests/test_smb.py new file mode 100644 index 0000000..ffc2d16 --- /dev/null +++ b/tests/test_smb.py @@ -0,0 +1,376 @@ +""" +SMB 访问模式测试套件(全部 mock,不发起真实网络请求) + +测试用例: +1. TestSmbAccessorUncPaths - UNC 路径构建逻辑 +2. TestSmbAccessorIO - I/O 方法(register_session, exists, list_files, download_to_tmp) +3. TestReaderSmbMode - TdxDataReader 在 SMB 模式下的三个核心方法 +4. TestCliSmbInit - CLI SMB 参数解析与初始化 +""" + +import os +import struct +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch, call + +import pandas as pd +import pytest + +from src.tdx2db.smb_accessor import SmbAccessor +from src.tdx2db.reader import TdxDataReader + + +# ─── 工具函数 ──────────────────────────────────────────────────────────────── + +def _make_day_bytes(n: int = 3) -> bytes: + fmt = ' Date: Sun, 12 Apr 2026 13:07:25 +0800 Subject: [PATCH 20/26] =?UTF-8?q?feat:=20SMB=20=E6=A8=A1=E5=BC=8F=E6=94=B9?= =?UTF-8?q?=E4=B8=BA=E6=89=B9=E9=87=8F=E5=B9=B6=E5=8F=91=E4=B8=8B=E8=BD=BD?= =?UTF-8?q?=EF=BC=8C=E5=A4=A7=E5=B9=85=E6=8F=90=E5=8D=87=E5=90=8C=E6=AD=A5?= =?UTF-8?q?=E6=80=A7=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - config.py: 新增 SMB_WORKERS(默认16)和 SMB_BATCH_SIZE(默认200)配置项 - smb_accessor.py: 新增 download_batch_to_dir() 批量并发下载方法(ThreadPoolExecutor) - reader.py: 新增 _parse_local_day_file() 统一解析逻辑,新增 read_daily_data_batch() 批量生成器 - cli.py: sync_all_daily() SMB 模式走批量并发路径,本地模式保持串行不变 - .env.example: 补充 SMB_WORKERS/SMB_BATCH_SIZE 配置说明 Co-Authored-By: Claude Sonnet 4.6 --- .env.example | 4 ++ src/tdx2db/cli.py | 122 ++++++++++++++++++++++++------------- src/tdx2db/config.py | 4 ++ src/tdx2db/reader.py | 82 ++++++++++++++++++++++--- src/tdx2db/smb_accessor.py | 35 ++++++++++- 5 files changed, 194 insertions(+), 53 deletions(-) diff --git a/.env.example b/.env.example index 156cdf1..6cebd8c 100644 --- a/.env.example +++ b/.env.example @@ -30,3 +30,7 @@ SMB_PASSWORD=mypassword # TDX 在共享目录内的相对路径,若共享根目录即为 TDX 安装目录则留空 SMB_TDX_PATH=TDX SMB_PORT=445 +# 并发下载线程数(批量 SMB 模式) +SMB_WORKERS=16 +# 每批同步的股票数量 +SMB_BATCH_SIZE=200 diff --git a/src/tdx2db/cli.py b/src/tdx2db/cli.py index cf877c5..1a2e19d 100644 --- a/src/tdx2db/cli.py +++ b/src/tdx2db/cli.py @@ -72,58 +72,96 @@ def sync_all_daily( start_date: Optional[int] = None, end_date: Optional[int] = None, ) -> dict: - """逐股票流式同步日线数据,返回统计信息。""" + """逐股票流式同步日线数据,返回统计信息。SMB 模式下使用批量并发下载。""" stocks = reader.get_stock_list() logger.info(f"共 {len(stocks)} 只股票") latest_dates = storage.get_all_latest_dates() if incremental else {} stats = {'total': len(stocks), 'success': 0, 'failed': 0} - iterator = tqdm(stocks, total=len(stocks), desc="同步日线") if config.use_tqdm else stocks - - for db_code in iterator: - pure_code, suffix = db_code.split('.') - market = {'SZ': 0, 'SH': 1, 'BJ': 2}[suffix] - code = suffix.lower() + pure_code # sz000001,供 read_daily_data 使用 + def _process_one(db_code: str, data: pd.DataFrame) -> None: + """处理单只股票的数据并写库。""" + pure_code = db_code.split('.')[0] last_date = latest_dates.get(db_code) - try: - data = reader.read_daily_data(market, code) - if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('date', 'datetime'): - data = data.reset_index() - if data.empty: - stats['success'] += 1 - continue - - needs_refresh = ( - incremental and last_date is not None and - _has_ex_rights_after(pure_code, gbbq, last_date) - ) - - processed = processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type) - - if incremental and last_date and not needs_refresh: - processed = processed[processed['date'] > last_date] - if start_date: - processed = processed[processed['date'] >= str(start_date)] - if end_date: - processed = processed[processed['date'] <= str(end_date)] - - if processed.empty: - stats['success'] += 1 - continue - - if needs_refresh: - storage.delete_stock_data(db_code) - storage.save_incremental(processed, 'daily_data', conflict_columns=('stock_code', 'date'), - batch_size=config.db_batch_size) + if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('date', 'datetime'): + data = data.reset_index() + if data.empty: + stats['success'] += 1 + return + + needs_refresh = ( + incremental and last_date is not None and + _has_ex_rights_after(pure_code, gbbq, last_date) + ) + + processed = processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type) + + if incremental and last_date and not needs_refresh: + processed = processed[processed['date'] > last_date] + if start_date: + processed = processed[processed['date'] >= str(start_date)] + if end_date: + processed = processed[processed['date'] <= str(end_date)] + + if processed.empty: stats['success'] += 1 + return + + if needs_refresh: + storage.delete_stock_data(db_code) + storage.save_incremental(processed, 'daily_data', conflict_columns=('stock_code', 'date'), + batch_size=config.db_batch_size) + stats['success'] += 1 + + if reader._smb is not None: + # SMB 批量并发模式:分批预下载,再串行解析入库 + stocks_meta = [] + for db_code in stocks: + pure_code, suffix = db_code.split('.') + market = {'SZ': 0, 'SH': 1, 'BJ': 2}[suffix] + code = suffix.lower() + pure_code + stocks_meta.append((market, code, db_code)) + + batch_iter = reader.read_daily_data_batch( + stocks_meta, + batch_size=config.smb_batch_size, + smb_workers=config.smb_workers, + ) + iterator = ( + tqdm(batch_iter, total=len(stocks), desc="同步日线(SMB批量)") + if config.use_tqdm else batch_iter + ) + + for db_code, result in iterator: + if isinstance(result, FileNotFoundError): + stats['failed'] += 1 + elif isinstance(result, Exception): + logger.error(f"处理 {db_code} 时出错: {result}") + stats['failed'] += 1 + else: + try: + _process_one(db_code, result) + except Exception as e: + logger.error(f"处理 {db_code} 时出错: {e}") + stats['failed'] += 1 + else: + # 本地串行模式(不变) + iterator = tqdm(stocks, total=len(stocks), desc="同步日线") if config.use_tqdm else stocks - except FileNotFoundError: - stats['failed'] += 1 - except Exception as e: - logger.error(f"处理 {code} 时出错: {e}") - stats['failed'] += 1 + for db_code in iterator: + pure_code, suffix = db_code.split('.') + market = {'SZ': 0, 'SH': 1, 'BJ': 2}[suffix] + code = suffix.lower() + pure_code # sz000001,供 read_daily_data 使用 + + try: + data = reader.read_daily_data(market, code) + _process_one(db_code, data) + except FileNotFoundError: + stats['failed'] += 1 + except Exception as e: + logger.error(f"处理 {code} 时出错: {e}") + stats['failed'] += 1 logger.info(f"同步完成: 成功 {stats['success']},失败 {stats['failed']}") return stats diff --git a/src/tdx2db/config.py b/src/tdx2db/config.py index 58068b1..17ee1c0 100644 --- a/src/tdx2db/config.py +++ b/src/tdx2db/config.py @@ -23,6 +23,8 @@ class Config: smb_password: str smb_tdx_path: str smb_port: int + smb_workers: int + smb_batch_size: int def __init__(self) -> None: self.tdx_path = os.getenv('TDX_PATH', '') @@ -43,6 +45,8 @@ def __init__(self) -> None: self.smb_password = os.getenv('SMB_PASSWORD', '') self.smb_tdx_path = os.getenv('SMB_TDX_PATH', '') self.smb_port = int(os.getenv('SMB_PORT', '445')) + self.smb_workers = int(os.getenv('SMB_WORKERS', '16')) + self.smb_batch_size = int(os.getenv('SMB_BATCH_SIZE', '200')) @property def database_url(self) -> str: diff --git a/src/tdx2db/reader.py b/src/tdx2db/reader.py index f052e78..35d8da3 100644 --- a/src/tdx2db/reader.py +++ b/src/tdx2db/reader.py @@ -1,10 +1,12 @@ import io import os import re +import shutil import struct +import tempfile from contextlib import redirect_stdout from pathlib import Path -from typing import Optional, TYPE_CHECKING +from typing import Iterator, List, Optional, Tuple, TYPE_CHECKING import pandas as pd from pytdx.reader import TdxDailyBarReader, GbbqReader @@ -189,14 +191,74 @@ def _get_stock_list_smb(self) -> list: def _read_daily_via_smb(self, unc: str) -> pd.DataFrame: tmp_path = self._smb.download_to_tmp(unc, suffix='.day') try: - try: - with redirect_stdout(io.StringIO()): - sec_type = self.daily_reader.get_security_type(tmp_path) - if sec_type in self.daily_reader.SECURITY_TYPE: - return self.daily_reader.get_df(tmp_path) - else: - return self._read_day_file_raw(tmp_path) - except Exception: - return self._read_day_file_raw(tmp_path) + return self._parse_local_day_file(tmp_path) finally: os.unlink(tmp_path) + + def _parse_local_day_file(self, path: str) -> pd.DataFrame: + """解析本地 .day 文件(pytdx 优先,不支持的类型降级到原始二进制解析)。""" + try: + with redirect_stdout(io.StringIO()): + sec_type = self.daily_reader.get_security_type(path) + if sec_type in self.daily_reader.SECURITY_TYPE: + return self.daily_reader.get_df(path) + else: + return self._read_day_file_raw(path) + except Exception: + return self._read_day_file_raw(path) + + def read_daily_data_batch( + self, + stocks_meta: List[Tuple[int, str, str]], + batch_size: int = 200, + smb_workers: int = 16, + ) -> Iterator[Tuple[str, 'pd.DataFrame | Exception']]: + """批量并发读取日线数据(仅 SMB 模式)。 + + Parameters + ---------- + stocks_meta: + List of (market, code, db_code),与 read_daily_data() 参数对应。 + - market: 0=深圳 1=上海 2=北京 + - code: 带市场前缀的代码,如 sz000001 + - db_code: 数据库代码,如 000001.SZ + + Yields + ------ + (db_code, DataFrame) 或 (db_code, Exception) + """ + if self._smb is None: + raise RuntimeError("read_daily_data_batch 仅支持 SMB 模式") + + market_map = {0: 'sz', 1: 'sh', 2: 'bj'} + + for batch_start in range(0, len(stocks_meta), batch_size): + batch = stocks_meta[batch_start:batch_start + batch_size] + + # 构建 unc_path → (market, pure_code, db_code) 映射 + unc_map: dict = {} + for market, code, db_code in batch: + folder = market_map[market] + pure_code = code[-6:] + filename = f"{folder}{pure_code}.day" + unc = self._smb.day_file_unc(folder, filename) + unc_map[unc] = (market, pure_code, db_code) + + tmp_dir = tempfile.mkdtemp(prefix='tdx2db_batch_') + try: + local_map = self._smb.download_batch_to_dir( + list(unc_map.keys()), tmp_dir, max_workers=smb_workers + ) + for unc, (market, pure_code, db_code) in unc_map.items(): + if unc not in local_map: + yield db_code, FileNotFoundError(f"SMB 下载失败: {unc}") + continue + try: + data = self._parse_local_day_file(local_map[unc]) + data['code'] = pure_code + data['market'] = market + yield db_code, data + except Exception as e: + yield db_code, e + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) diff --git a/src/tdx2db/smb_accessor.py b/src/tdx2db/smb_accessor.py index eaf7deb..5c4ab93 100644 --- a/src/tdx2db/smb_accessor.py +++ b/src/tdx2db/smb_accessor.py @@ -4,7 +4,7 @@ """ import os import tempfile -from typing import List, Optional +from typing import Dict, List, Optional import smbclient import smbclient.path @@ -132,3 +132,36 @@ def download_to_tmp(self, unc_path: str, suffix: str = '.day') -> str: tmp.close() os.unlink(tmp.name) raise + + def download_batch_to_dir( + self, + unc_paths: List[str], + dest_dir: str, + max_workers: int = 16, + ) -> Dict[str, str]: + """并发下载多个 UNC 文件到 dest_dir 目录,返回 {unc_path: local_path}。 + + 下载失败的文件不出现在返回字典中(异常已记录到日志)。 + """ + import concurrent.futures + import hashlib + + def _download_one(unc: str): + name = hashlib.md5(unc.encode()).hexdigest() + '.day' + local = os.path.join(dest_dir, name) + data = self.read_bytes(unc) + with open(local, 'wb') as f: + f.write(data) + return unc, local + + result: Dict[str, str] = {} + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool: + futures = {pool.submit(_download_one, unc): unc for unc in unc_paths} + for fut in concurrent.futures.as_completed(futures): + unc = futures[fut] + try: + _, local = fut.result() + result[unc] = local + except Exception as e: + logger.warning(f"SMB 批量下载失败 {unc}: {e}") + return result From 2b26f50b51eb4fd161b8437f34d1f1686748ef04 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Sun, 12 Apr 2026 14:50:35 +0800 Subject: [PATCH 21/26] =?UTF-8?q?feat:=20=E8=AE=A1=E7=AE=97=E5=B9=B6?= =?UTF-8?q?=E5=86=99=E5=85=A5=E6=97=A5=E7=BA=BF=E6=8D=A2=E6=89=8B=E7=8E=87?= =?UTF-8?q?=EF=BC=88turnover=5Frate=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - processor.py:新增 _calc_turnover_rate(),从 gbbq category==5 记录取 流通股本(万股),通过 merge_asof 匹配每个交易日,计算 换手率(%) = volume(手) × 10000 / 流通股本(股) - processor.py:process_daily_data() 中替换原 turnover_rate=None 占位符, gbbq 有效时实际计算并填入 - smb_accessor.py:修复 exists() 在 SMB session 未注册时静默返回 False 的 bug,导致 gbbq 文件始终判断为不存在而跳过复权和换手率计算 Co-Authored-By: Claude Sonnet 4.6 --- .gitignore | 2 +- src/tdx2db/processor.py | 40 ++++++++++++++++++++++++++++++++++++-- src/tdx2db/smb_accessor.py | 1 + 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 6266212..2ffc6b7 100644 --- a/.gitignore +++ b/.gitignore @@ -45,5 +45,5 @@ poetry.lock output/ */__pycache__/ -tdx_data.db +tdx_data.db* .claude/ \ No newline at end of file diff --git a/src/tdx2db/processor.py b/src/tdx2db/processor.py index 09f7bdf..f9d7be0 100644 --- a/src/tdx2db/processor.py +++ b/src/tdx2db/processor.py @@ -134,6 +134,40 @@ def apply_backward_adj(df: pd.DataFrame, gbbq: pd.DataFrame) -> pd.DataFrame: df[col] = (df[col] * df['adj_factor']).round(3) return df + @staticmethod + def _calc_turnover_rate(df: pd.DataFrame, gbbq: pd.DataFrame) -> pd.Series: + """计算换手率(%):volume(手) × 100 / 流通股本(股) × 100 = volume × 10000 / 流通股本(股)。 + + 流通股本来自 gbbq category==5 记录的 hongli_panqianliutong 字段(单位:万股)。 + datetime 字段为 YYYYMMDD 整数,使用 merge_asof 取每个交易日最近一次股本记录。 + """ + market_val = df['market'].iloc[0] + prefix = {0: 'sz', 1: 'sh', 2: 'bj'}.get(market_val, 'sz') + full_code = prefix + str(df['code'].iloc[0]).zfill(6) + + shares = gbbq[(gbbq['full_code'] == full_code) & (gbbq['category'] == 5)].copy() + if shares.empty: + return pd.Series([None] * len(df), index=df.index, dtype=float) + + # datetime 直接是 YYYYMMDD 整数 + shares = shares.rename(columns={'datetime': 'date_int'}) + shares = shares.sort_values('date_int').drop_duplicates('date_int', keep='last') + + daily = df[['date', 'volume']].copy() + daily['date_int'] = daily['date'].astype(int) + daily = daily.sort_values('date_int').reset_index(drop=False) + + merged = pd.merge_asof( + daily[['index', 'date_int', 'volume']], + shares[['date_int', 'hongli_panqianliutong']], + on='date_int' + ) + # 流通股本单位为万股,× 10000 换算为股 + cap = merged['hongli_panqianliutong'] * 10000 + # 换手率(%) = volume(手) × 100 / 流通股本(股) × 100 = volume × 10000 / 流通股本(股) + merged['turnover_rate'] = (merged['volume'] * 10000 / cap).where(cap > 0).round(4) + return merged.set_index('index')['turnover_rate'].reindex(df.index) + @staticmethod def process_daily_data( df: pd.DataFrame, @@ -180,8 +214,10 @@ def process_daily_data( # 日期转 YYYYMMDD 字符串 processed['date'] = processed['date'].dt.strftime('%Y%m%d') - # 预留 turnover_rate 列 - if 'turnover_rate' not in processed.columns: + # 计算换手率 + if gbbq is not None and not gbbq.empty and 'code' in processed.columns: + processed['turnover_rate'] = DataProcessor._calc_turnover_rate(processed, gbbq) + else: processed['turnover_rate'] = None # 生成带市场后缀的 stock_code,如 000001.SZ / 600000.SH / 920001.BJ diff --git a/src/tdx2db/smb_accessor.py b/src/tdx2db/smb_accessor.py index 5c4ab93..95af9f6 100644 --- a/src/tdx2db/smb_accessor.py +++ b/src/tdx2db/smb_accessor.py @@ -97,6 +97,7 @@ def day_file_unc(self, market: str, filename: str) -> str: def exists(self, unc_path: str) -> bool: try: + self._register() return smbclient.path.exists(unc_path) except Exception: return False From 2d4dd8b7b79ec547825d2cc3a47bbd63b27dcdb4 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Sun, 12 Apr 2026 20:17:05 +0800 Subject: [PATCH 22/26] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=8D=A2?= =?UTF-8?q?=E6=89=8B=E7=8E=87=E8=AE=A1=E7=AE=97=E5=8F=8A=E7=9B=B8=E5=85=B3?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - smb_accessor: share_access 改为 'rw' 修复 SMB 文件访问冲突 - smb_accessor: 删除重复的 base_dbf_unc 属性 - reader: _parse_base_dbf 股票代码字段名 GDM → GPDM - processor: build_float_capital_map 加调试日志 - processor: _calc_turnover_rate 加调试日志 - tests: 修复 test_exists_returns_true 缺少 register_session mock - tests: 修复 test_vipdoc_path_get_stock_list 返回值类型错误 --- requirements.txt | 1 + src/tdx2db/cli.py | 19 ++++-- src/tdx2db/processor.py | 107 +++++++++++++++++++++++++++++++-- src/tdx2db/reader.py | 46 +++++++++++++++ src/tdx2db/smb_accessor.py | 6 +- tests/test_download.py | 8 +-- tests/test_float_capital.py | 114 ++++++++++++++++++++++++++++++++++++ tests/test_smb.py | 3 +- 8 files changed, 288 insertions(+), 16 deletions(-) create mode 100644 tests/test_float_capital.py diff --git a/requirements.txt b/requirements.txt index 09c48aa..7978e5b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ psycopg2-binary>=2.9.10 requests>=2.32.0 akshare smbprotocol>=1.13.0 +dbfread>=2.0.7 diff --git a/src/tdx2db/cli.py b/src/tdx2db/cli.py index 1a2e19d..7928054 100644 --- a/src/tdx2db/cli.py +++ b/src/tdx2db/cli.py @@ -71,6 +71,7 @@ def sync_all_daily( incremental: bool = True, start_date: Optional[int] = None, end_date: Optional[int] = None, + float_cap_map: dict = None, ) -> dict: """逐股票流式同步日线数据,返回统计信息。SMB 模式下使用批量并发下载。""" stocks = reader.get_stock_list() @@ -95,7 +96,7 @@ def _process_one(db_code: str, data: pd.DataFrame) -> None: _has_ex_rights_after(pure_code, gbbq, last_date) ) - processed = processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type) + processed = processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type, float_cap_map=float_cap_map) if incremental and last_date and not needs_refresh: processed = processed[processed['date'] > last_date] @@ -318,6 +319,8 @@ def main() -> int: elif args.command == 'daily': adj_type = getattr(args, 'adj', 'forward') gbbq = reader.read_gbbq() + base_caps = reader.read_base_dbf() + float_cap_map = DataProcessor.build_float_capital_map(base_caps, gbbq) if base_caps else {} if args.code: pure_code = args.code[-6:] if len(args.code) > 6 else args.code @@ -333,7 +336,7 @@ def main() -> int: data = reader.read_daily_data(market, code) if isinstance(data.index, pd.DatetimeIndex) or data.index.name in ('date', 'datetime'): data = data.reset_index() - processed = processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type) + processed = processor.process_daily_data(data, gbbq=gbbq, adj_type=adj_type, float_cap_map=float_cap_map) processed = processor.filter_data(processed, start_date=args.start, end_date=args.end) if not processed.empty: storage.save_incremental(processed, 'daily_data', conflict_columns=('stock_code', 'date')) @@ -344,7 +347,8 @@ def main() -> int: incremental = getattr(args, 'incremental', False) stats = sync_all_daily(reader, processor, storage, gbbq, adj_type=adj_type, incremental=incremental, - start_date=args.start, end_date=args.end) + start_date=args.start, end_date=args.end, + float_cap_map=float_cap_map) storage.save_sync_statistics(stats['success']) elif args.command == 'sync': @@ -352,8 +356,15 @@ def main() -> int: logger.info("=== 开始增量同步日线数据 ===") sync_stock_list(reader, storage) gbbq = reader.read_gbbq() + + base_caps = reader.read_base_dbf() + float_cap_map = DataProcessor.build_float_capital_map(base_caps, gbbq) if base_caps else {} + if not float_cap_map: + logger.warning("base.dbf 读取失败,换手率将降级使用 gbbq category==5 逻辑") + stats = sync_all_daily(reader, processor, storage, gbbq, - adj_type=adj_type, incremental=True) + adj_type=adj_type, incremental=True, + float_cap_map=float_cap_map) storage.save_sync_statistics(stats['success']) else: diff --git a/src/tdx2db/processor.py b/src/tdx2db/processor.py index f9d7be0..2eb7a29 100644 --- a/src/tdx2db/processor.py +++ b/src/tdx2db/processor.py @@ -135,16 +135,108 @@ def apply_backward_adj(df: pd.DataFrame, gbbq: pd.DataFrame) -> pd.DataFrame: return df @staticmethod - def _calc_turnover_rate(df: pd.DataFrame, gbbq: pd.DataFrame) -> pd.Series: - """计算换手率(%):volume(手) × 100 / 流通股本(股) × 100 = volume × 10000 / 流通股本(股)。 + def build_float_capital_map(base_caps: dict, gbbq: pd.DataFrame) -> dict: + """从 base.dbf 当前流通股本出发,逆向推算历史各时间点的流通股本。 - 流通股本来自 gbbq category==5 记录的 hongli_panqianliutong 字段(单位:万股)。 - datetime 字段为 YYYYMMDD 整数,使用 merge_asof 取每个交易日最近一次股本记录。 + Returns: {full_code: [(date_int, cap_万股), ...]} 已按 date_int 升序排列, + 供 merge_asof 使用(date_int 表示"从该日期起生效的流通股本")。 + """ + if gbbq.empty or not base_caps: + logger.debug(f"[float_cap] 跳过:gbbq.empty={gbbq.empty} base_caps空={not base_caps}") + return {} + + logger.debug(f"[float_cap] base_caps 共 {len(base_caps)} 条,gbbq 共 {len(gbbq)} 行") + result = {} + relevant_cats = {1, 10, 11, 12, 15} + gbbq_filtered = gbbq[gbbq['category'].isin(relevant_cats)].copy() + logger.debug(f"[float_cap] gbbq_filtered(cat∈{relevant_cats})共 {len(gbbq_filtered)} 行,unique full_code={gbbq_filtered['full_code'].nunique()}") + + for full_code, group in gbbq_filtered.groupby('full_code'): + pure_code = full_code[2:] # 去掉 sz/sh/bj 前缀 + if pure_code not in base_caps: + continue + + cap = float(base_caps[pure_code]) + logger.debug(f"[float_cap] {full_code} base_cap={cap:.2f} 万股,gbbq 事件数={len(group)}") + events = group.sort_values('datetime', ascending=False) + snapshots = [] + + for _, ev in events.iterrows(): + date_int = int(ev['datetime']) + cat = int(ev['category']) + + # 先记录:从 date_int 起生效的股本(事件发生后的值) + snapshots.append((date_int, cap)) + + # 再逆向推算事件发生前的 cap + if cat == 1: + songgu = float(ev.get('songgu_qianzongguben', 0) or 0) / 10.0 + peigu = float(ev.get('peigu_houzongguben', 0) or 0) / 10.0 + ratio = songgu + peigu + if ratio > 0: + cap = cap / (1.0 + ratio) + elif cat in (11, 12, 15): + # 增发/解禁/债转股:S_before = S_after - N + value = float(ev.get('hongli_panqianliutong', 0) or 0) + cap = cap - value + if cap <= 0: + cap = float(base_caps[pure_code]) # 异常保护 + elif cat == 10: + # 回购注销:注销后股本减少,回溯要加回来 + value = float(ev.get('hongli_panqianliutong', 0) or 0) + cap = cap + value + logger.debug(f" [{full_code}] cat={cat} date={date_int} → cap_before={cap:.2f} 万股") + + # 兜底:最早历史数据使用推算到底的 cap + snapshots.append((0, cap)) + snapshots.sort(key=lambda x: x[0]) + result[full_code] = snapshots + + logger.info(f"build_float_capital_map 完成,覆盖 {len(result)} 只股票") + return result + + @staticmethod + def _calc_turnover_rate(df: pd.DataFrame, gbbq: pd.DataFrame, float_cap_map: dict = None) -> pd.Series: + """计算换手率(%):volume(手) × 10000 / 流通股本(股)。 + + 优先使用 float_cap_map(base.dbf 锚点 + gbbq 逆向推算); + float_cap_map 为 None 时降级到原有 gbbq category==5 逻辑。 """ market_val = df['market'].iloc[0] prefix = {0: 'sz', 1: 'sh', 2: 'bj'}.get(market_val, 'sz') full_code = prefix + str(df['code'].iloc[0]).zfill(6) + # ── 优先路径:float_cap_map ────────────────────────────────────────── + if float_cap_map is not None and full_code in float_cap_map: + snap_list = float_cap_map[full_code] # [(date_int, cap_万股)],升序 + logger.debug(f"[turnover] {full_code} 使用 float_cap_map,快照数={len(snap_list)}") + if snap_list: + snap_df = pd.DataFrame(snap_list, columns=['date_int', 'float_cap']) + daily = df[['date', 'volume']].copy() + daily['date_int'] = daily['date'].astype(int) + daily = daily.sort_values('date_int').reset_index(drop=False) + merged = pd.merge_asof( + daily[['index', 'date_int', 'volume']], + snap_df, + on='date_int' + ) + cap = merged['float_cap'] * 10000 # 万股 → 股 + merged['turnover_rate'] = ( + (merged['volume'] * 10000 / cap).where(cap > 0).round(4) + ) + # 调试:打印最近几条 + sample = merged[['date_int', 'volume', 'float_cap', 'turnover_rate']].tail(3) + for _, row in sample.iterrows(): + logger.debug( + f" [{full_code}] date={int(row['date_int'])} " + f"vol={row['volume']:.0f} cap={row['float_cap']:.2f}万股 " + f"turnover={row['turnover_rate']:.4f}%" + ) + return merged.set_index('index')['turnover_rate'].reindex(df.index) + + # ── 降级路径:原有 gbbq category==5 逻辑 ──────────────────────────── + if gbbq.empty or 'full_code' not in gbbq.columns: + return pd.Series([None] * len(df), index=df.index, dtype=float) shares = gbbq[(gbbq['full_code'] == full_code) & (gbbq['category'] == 5)].copy() if shares.empty: return pd.Series([None] * len(df), index=df.index, dtype=float) @@ -172,7 +264,8 @@ def _calc_turnover_rate(df: pd.DataFrame, gbbq: pd.DataFrame) -> pd.Series: def process_daily_data( df: pd.DataFrame, gbbq: pd.DataFrame = None, - adj_type: str = 'forward' + adj_type: str = 'forward', + float_cap_map: dict = None, ) -> pd.DataFrame: """日线处理主流程:reset_index → 填充缺失值 → 校验 → 复权 → 日期转 YYYYMMDD 整数。""" if df.empty: @@ -216,7 +309,9 @@ def process_daily_data( # 计算换手率 if gbbq is not None and not gbbq.empty and 'code' in processed.columns: - processed['turnover_rate'] = DataProcessor._calc_turnover_rate(processed, gbbq) + processed['turnover_rate'] = DataProcessor._calc_turnover_rate( + processed, gbbq, float_cap_map=float_cap_map + ) else: processed['turnover_rate'] = None diff --git a/src/tdx2db/reader.py b/src/tdx2db/reader.py index 35d8da3..dbe7cb3 100644 --- a/src/tdx2db/reader.py +++ b/src/tdx2db/reader.py @@ -67,6 +67,19 @@ def read_gbbq(self) -> pd.DataFrame: logger.warning(f"读取权息文件时出错: {e},将跳过复权处理") return pd.DataFrame() + def read_base_dbf(self) -> dict: + """读取 base.dbf,返回 {code_6位: 流通股本万股} 字典。文件不存在时返回空字典。""" + if self._smb is not None: + return self._read_base_dbf_smb() + if self.tdx_path is None: + logger.warning("联网下载模式下不支持读取 base.dbf") + return {} + dbf_path = self.tdx_path / 'T0002' / 'hq_cache' / 'base.dbf' + if not dbf_path.exists(): + logger.warning(f"base.dbf 不存在: {dbf_path}") + return {} + return self._parse_base_dbf(str(dbf_path)) + def get_stock_list(self) -> list: """扫描本地 .day 文件,返回有数据的股票代码列表(000001.SZ 格式)。""" if self._smb is not None: @@ -151,6 +164,39 @@ def _read_day_file_raw(fname: str) -> pd.DataFrame: # ── SMB 私有方法 ────────────────────────────────────────────────────────── + @staticmethod + def _parse_base_dbf(path: str) -> dict: + try: + from dbfread import DBF + except ImportError: + raise ImportError("缺少 dbfread 库,请执行: pip install dbfread") + result = {} + for record in DBF(path, encoding='gbk', load=True): + code = str(record.get('GPDM', '') or '').strip().zfill(6) + ltag = record.get('LTAG') + if code and ltag is not None: + try: + result[code] = float(ltag) + except (TypeError, ValueError): + pass + logger.info(f"base.dbf 读取完成,共 {len(result)} 条记录") + return result + + def _read_base_dbf_smb(self) -> dict: + unc = self._smb.base_dbf_unc + if not self._smb.exists(unc): + raise FileNotFoundError(f"SMB base.dbf 不存在: {unc}") + try: + tmp_path = self._smb.download_to_tmp(unc, suffix='.dbf') + except Exception as e: + raise RuntimeError( + f"base.dbf 无法读取(可能被 TDX 锁定,请关闭 TDX 后重试): {e}" + ) from e + try: + return self._parse_base_dbf(tmp_path) + finally: + os.unlink(tmp_path) + def _read_gbbq_smb(self) -> pd.DataFrame: unc = self._smb.gbbq_unc if not self._smb.exists(unc): diff --git a/src/tdx2db/smb_accessor.py b/src/tdx2db/smb_accessor.py index 95af9f6..b9cf00b 100644 --- a/src/tdx2db/smb_accessor.py +++ b/src/tdx2db/smb_accessor.py @@ -87,6 +87,10 @@ def vipdoc_unc(self) -> str: def gbbq_unc(self) -> str: return self._unc('T0002', 'hq_cache', 'gbbq') + @property + def base_dbf_unc(self) -> str: + return self._unc('T0002', 'hq_cache', 'base.dbf') + def lday_dir_unc(self, market: str) -> str: return self._unc('vipdoc', market, 'lday') @@ -114,7 +118,7 @@ def list_files(self, unc_dir: str, suffix: str = '') -> List[str]: return [] def read_bytes(self, unc_path: str) -> bytes: - with smbclient.open_file(unc_path, mode='rb') as f: + with smbclient.open_file(unc_path, mode='rb', share_access='rw') as f: return f.read() def download_to_tmp(self, unc_path: str, suffix: str = '.day') -> str: diff --git a/tests/test_download.py b/tests/test_download.py index 2d29712..a228b32 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -270,10 +270,10 @@ def test_vipdoc_path_get_stock_list(self, tmp_path): reader = TdxDataReader(vipdoc_path=str(vipdoc)) stocks = reader.get_stock_list() - codes = set(stocks['code'].tolist()) - assert 'sz000001' in codes - assert 'sz000002' in codes - assert 'sh600000' in codes + codes = set(stocks) + assert '000001.SZ' in codes + assert '000002.SZ' in codes + assert '600000.SH' in codes def test_vipdoc_path_read_daily_data(self, tmp_path): """vipdoc_path 模式下 read_daily_data 应返回正确的 DataFrame。""" diff --git a/tests/test_float_capital.py b/tests/test_float_capital.py new file mode 100644 index 0000000..21081b8 --- /dev/null +++ b/tests/test_float_capital.py @@ -0,0 +1,114 @@ +""" +float_cap_map 构建和换手率计算测试 +""" +import pandas as pd +import pytest + +from src.tdx2db.processor import DataProcessor + + +def make_gbbq_event(full_code: str, date_int: int, category: int, value: float) -> pd.DataFrame: + market_val = 1 if full_code.startswith('sh') else 0 + return pd.DataFrame([{ + 'market': market_val, + 'code': int(full_code[2:]), + 'datetime': date_int, + 'category': category, + 'hongli_panqianliutong': value, + 'songgu_qianzongguben': value if category == 1 else 0, + 'peigu_houzongguben': 0, + 'peigujia_qianzongguben': 0, + 'full_code': full_code, + }]) + + +class TestBuildFloatCapitalMap: + + def test_cat1_songgu(self): + """category==1 送股:历史股本 = 当前 / (1 + ratio)""" + # 当前 1100 万股,曾经 10 送 1(ratio=0.1) + base_caps = {'000001': 1100.0} + gbbq = make_gbbq_event('sz000001', 20240101, 1, 1.0) # songgu=1.0 → ratio=0.1 + result = DataProcessor.build_float_capital_map(base_caps, gbbq) + + assert 'sz000001' in result + snapshots = dict(result['sz000001']) + # 事件日期 20240101 起生效的股本 = 1100(送股后) + assert abs(snapshots[20240101] - 1100.0) < 0.01 + # 兜底(date=0)= 1100 / 1.1 = 1000 + assert abs(snapshots[0] - 1000.0) < 0.01 + + def test_cat12_jiejin(self): + """category==12 解禁:历史股本 = 当前 - N""" + base_caps = {'000001': 1100.0} + gbbq = make_gbbq_event('sz000001', 20240101, 12, 100.0) + result = DataProcessor.build_float_capital_map(base_caps, gbbq) + + snapshots = dict(result['sz000001']) + assert abs(snapshots[20240101] - 1100.0) < 0.01 + assert abs(snapshots[0] - 1000.0) < 0.01 + + def test_cat10_zhuxiao(self): + """category==10 回购注销:历史股本 = 当前 + N(注销后股本减少,回溯加回)""" + base_caps = {'000001': 900.0} + gbbq = make_gbbq_event('sz000001', 20240101, 10, 100.0) + result = DataProcessor.build_float_capital_map(base_caps, gbbq) + + snapshots = dict(result['sz000001']) + assert abs(snapshots[20240101] - 900.0) < 0.01 + assert abs(snapshots[0] - 1000.0) < 0.01 + + def test_empty_base_caps(self): + gbbq = make_gbbq_event('sz000001', 20240101, 1, 1.0) + assert DataProcessor.build_float_capital_map({}, gbbq) == {} + + def test_empty_gbbq(self): + assert DataProcessor.build_float_capital_map({'000001': 1000.0}, pd.DataFrame()) == {} + + def test_code_not_in_base_caps(self): + """gbbq 有记录但 base_caps 没有该股票,应跳过""" + base_caps = {'000002': 500.0} + gbbq = make_gbbq_event('sz000001', 20240101, 1, 1.0) + result = DataProcessor.build_float_capital_map(base_caps, gbbq) + assert 'sz000001' not in result + + +class TestCalcTurnoverRateWithMap: + + def _make_df(self, code='000001', market=0, dates=None, volume=1e6): + if dates is None: + dates = ['20240101', '20240102', '20240103'] + return pd.DataFrame({ + 'date': dates, + 'volume': [volume] * len(dates), + 'code': [code] * len(dates), + 'market': [market] * len(dates), + }) + + def test_priority_path_used(self): + """float_cap_map 存在时走优先路径,换手率应有值""" + df = self._make_df(volume=1000.0) + # 流通股本 1000 万股 = 1000 * 10000 = 1e7 股 + # 换手率 = volume(手) * 10000 / 流通股本(股) = 1000 * 10000 / 1e7 = 1.0% + float_cap_map = {'sz000001': [(0, 1000.0)]} + gbbq = pd.DataFrame() + + result = DataProcessor._calc_turnover_rate(df, gbbq, float_cap_map=float_cap_map) + assert result.notna().all() + assert abs(result.iloc[0] - 1.0) < 0.001 + + def test_fallback_when_not_in_map(self): + """full_code 不在 float_cap_map 中时,降级到 gbbq category==5""" + df = self._make_df() + float_cap_map = {'sz000002': [(0, 1000.0)]} # 不含 sz000001 + gbbq = pd.DataFrame() # category==5 也无数据 + + result = DataProcessor._calc_turnover_rate(df, gbbq, float_cap_map=float_cap_map) + assert result.isna().all() + + def test_fallback_when_map_is_none(self): + """float_cap_map=None 时走原有逻辑""" + df = self._make_df() + gbbq = pd.DataFrame() + result = DataProcessor._calc_turnover_rate(df, gbbq, float_cap_map=None) + assert result.isna().all() diff --git a/tests/test_smb.py b/tests/test_smb.py index ffc2d16..af8f527 100644 --- a/tests/test_smb.py +++ b/tests/test_smb.py @@ -107,8 +107,9 @@ def test_register_only_once(self, mock_reg): acc._register() assert mock_reg.call_count == 1 + @patch('smbclient.register_session') @patch('smbclient.path.exists', return_value=True) - def test_exists_returns_true(self, _): + def test_exists_returns_true(self, _mock_exists, _mock_reg): acc = SmbAccessor('host', 'share') assert acc.exists(r'\\host\share\file') is True From 7bad70e9404f68ae13190063e6449ca1b219cd2b Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Sun, 12 Apr 2026 20:47:34 +0800 Subject: [PATCH 23/26] =?UTF-8?q?feat:=20=E8=B0=83=E6=95=B4=20volume/amoun?= =?UTF-8?q?t=20=E5=85=A5=E5=BA=93=E5=8D=95=E4=BD=8D=EF=BC=88=E6=89=8B?= =?UTF-8?q?=E2=86=92=E8=82=A1=EF=BC=8C=E5=85=83=E2=86=92=E4=B8=87=E5=85=83?= =?UTF-8?q?=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit volume 由"手"改为"股"(×100),amount 由"元"改为"万元"(÷10000); 同步修正换手率公式系数(×10000 → ×100)。 Co-Authored-By: Claude Sonnet 4.6 --- src/tdx2db/processor.py | 8 ++++---- src/tdx2db/reader.py | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/tdx2db/processor.py b/src/tdx2db/processor.py index 2eb7a29..9c0fa69 100644 --- a/src/tdx2db/processor.py +++ b/src/tdx2db/processor.py @@ -197,7 +197,7 @@ def build_float_capital_map(base_caps: dict, gbbq: pd.DataFrame) -> dict: @staticmethod def _calc_turnover_rate(df: pd.DataFrame, gbbq: pd.DataFrame, float_cap_map: dict = None) -> pd.Series: - """计算换手率(%):volume(手) × 10000 / 流通股本(股)。 + """计算换手率(%):volume(股) × 100 / 流通股本(股)。 优先使用 float_cap_map(base.dbf 锚点 + gbbq 逆向推算); float_cap_map 为 None 时降级到原有 gbbq category==5 逻辑。 @@ -222,7 +222,7 @@ def _calc_turnover_rate(df: pd.DataFrame, gbbq: pd.DataFrame, float_cap_map: dic ) cap = merged['float_cap'] * 10000 # 万股 → 股 merged['turnover_rate'] = ( - (merged['volume'] * 10000 / cap).where(cap > 0).round(4) + (merged['volume'] * 100 / cap).where(cap > 0).round(4) ) # 调试:打印最近几条 sample = merged[['date_int', 'volume', 'float_cap', 'turnover_rate']].tail(3) @@ -256,8 +256,8 @@ def _calc_turnover_rate(df: pd.DataFrame, gbbq: pd.DataFrame, float_cap_map: dic ) # 流通股本单位为万股,× 10000 换算为股 cap = merged['hongli_panqianliutong'] * 10000 - # 换手率(%) = volume(手) × 100 / 流通股本(股) × 100 = volume × 10000 / 流通股本(股) - merged['turnover_rate'] = (merged['volume'] * 10000 / cap).where(cap > 0).round(4) + # 换手率(%) = volume(股) × 100 / 流通股本(股) + merged['turnover_rate'] = (merged['volume'] * 100 / cap).where(cap > 0).round(4) return merged.set_index('index')['turnover_rate'].reindex(df.index) @staticmethod diff --git a/src/tdx2db/reader.py b/src/tdx2db/reader.py index dbe7cb3..e8eae79 100644 --- a/src/tdx2db/reader.py +++ b/src/tdx2db/reader.py @@ -141,6 +141,8 @@ def read_daily_data(self, market: int, code: str) -> pd.DataFrame: data['code'] = pure_code data['market'] = market + data['volume'] = data['volume'] * 100 # 手 → 股 + data['amount'] = data['amount'] / 10000 # 元 → 万元 return data @staticmethod From a021e8fc26d3cafb05643c82b1b708d41e275dac Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Sun, 12 Apr 2026 21:37:00 +0800 Subject: [PATCH 24/26] =?UTF-8?q?fix:=20read=5Fdaily=5Fdata=5Fbatch=20?= =?UTF-8?q?=E8=A1=A5=E5=85=85=20volume/amount=20=E5=8D=95=E4=BD=8D?= =?UTF-8?q?=E8=BD=AC=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- src/tdx2db/reader.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tdx2db/reader.py b/src/tdx2db/reader.py index e8eae79..b576c20 100644 --- a/src/tdx2db/reader.py +++ b/src/tdx2db/reader.py @@ -305,6 +305,8 @@ def read_daily_data_batch( data = self._parse_local_day_file(local_map[unc]) data['code'] = pure_code data['market'] = market + data['volume'] = data['volume'] * 100 # 手 → 股 + data['amount'] = data['amount'] / 10000 # 元 → 万元 yield db_code, data except Exception as e: yield db_code, e From 15981c4fb0b42c7807bd1fcf47bbc2d5adee3d69 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Mon, 13 Apr 2026 00:00:51 +0800 Subject: [PATCH 25/26] =?UTF-8?q?feat:=20=E7=94=A8=20TDX=20.tnf=20?= =?UTF-8?q?=E6=96=87=E4=BB=B6=E6=9B=BF=E6=8D=A2=20akshare=20=E8=8E=B7?= =?UTF-8?q?=E5=8F=96=E8=82=A1=E7=A5=A8=E5=90=8D=E7=A7=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 消除 akshare 网络依赖,完全基于通达信本地数据文件获取个股中文名: - reader.py: 新增 read_stock_names(),解析 T0002/hq_cache/{szs,shs,bjs}.tnf - _parse_tnf_file(): 自动探测 record_len 和名称字段偏移,兼容新旧版本 TDX - _detect_tnf_record_len(): 搜索相邻代码对定位记录边界,策略二用间距众数兜底 - _detect_tnf_name_offset(): 扫描记录内首个非零字节簇定位名称偏移 - 返回格式 {'SZ': {...}, 'SH': {...}, 'BJ': {...}},按市场分组避免代码空间重叠覆盖 - smb_accessor.py: 新增 tnf_unc(market) 方法构建 .tnf 文件 UNC 路径 - cli.py: sync_stock_list() 改用 reader.read_stock_names(),按后缀查对应市场字典 - requirements.txt: 移除 akshare 依赖 - tests/test_smb.py: 补充 tnf_unc mock 及 read_stock_names SMB 模式测试 Co-Authored-By: Claude Sonnet 4.6 --- requirements.txt | 1 - src/tdx2db/cli.py | 27 +++---- src/tdx2db/reader.py | 147 +++++++++++++++++++++++++++++++++++++ src/tdx2db/smb_accessor.py | 4 + tests/test_smb.py | 82 +++++++++++++++++++++ 5 files changed, 244 insertions(+), 17 deletions(-) diff --git a/requirements.txt b/requirements.txt index 7978e5b..5ab2db5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,5 @@ python-dotenv>=1.1.0 pymysql>=1.1.1 psycopg2-binary>=2.9.10 requests>=2.32.0 -akshare smbprotocol>=1.13.0 dbfread>=2.0.7 diff --git a/src/tdx2db/cli.py b/src/tdx2db/cli.py index 7928054..0bc0b00 100644 --- a/src/tdx2db/cli.py +++ b/src/tdx2db/cli.py @@ -34,24 +34,19 @@ def _has_ex_rights_after(code: str, gbbq: pd.DataFrame, last_date: int) -> bool: def sync_stock_list(reader: TdxDataReader, storage: DataStorage) -> bool: - """同步股票列表及名称,返回是否成功。""" + """同步股票列表及名称,返回是否成功。从本地 TDX .tnf 文件读取中文名,无需联网。""" try: - import akshare as ak - - ak_df = ak.stock_info_a_code_name() - - def _add_suffix(code: str) -> str: - if code.startswith('6'): - return code + '.SH' - elif code.startswith('8') or code.startswith('92'): - return code + '.BJ' - return code + '.SZ' - - ak_map = {_add_suffix(row['code']): row['name'] for _, row in ak_df.iterrows()} - - local_codes = reader.get_stock_list() + # name_map: {'SZ': {6位code: 名}, 'SH': {...}, 'BJ': {...}} + # 三个市场代码空间有重叠,必须按市场分开查找 + name_map = reader.read_stock_names() + local_codes = reader.get_stock_list() # ['000001.SZ', ...] df = pd.DataFrame([ - {'stock_code': c, 'stock_name': ak_map.get(c, c)} + { + 'stock_code': c, + 'stock_name': name_map.get(c.split('.')[1], {}).get( + c.split('.')[0], c.split('.')[0] + ) + } for c in local_codes ]) logger.info(f"获取到 {len(df)} 只股票") diff --git a/src/tdx2db/reader.py b/src/tdx2db/reader.py index b576c20..c6a0c4c 100644 --- a/src/tdx2db/reader.py +++ b/src/tdx2db/reader.py @@ -67,6 +67,29 @@ def read_gbbq(self) -> pd.DataFrame: logger.warning(f"读取权息文件时出错: {e},将跳过复权处理") return pd.DataFrame() + def read_stock_names(self) -> dict: + """读取 szs/shs/bjs.tnf,返回按市场分组的名称字典。 + 返回格式:{'SZ': {6位code: 名}, 'SH': {6位code: 名}, 'BJ': {6位code: 名}} + 三个市场代码空间有重叠(如 000001 在 SZ 是平安银行,在 SH 是上证指数), + 必须分开存储,不能合并为单一字典。 + 本地模式读 T0002/hq_cache/*.tnf;SMB 模式下载后解析。 + 文件不存在时对应市场返回空字典。 + """ + if self._smb is not None: + return self._read_stock_names_smb() + if self.tdx_path is None: + logger.warning("联网下载模式下不支持读取股票名称,将以代码代替名称") + return {'SZ': {}, 'SH': {}, 'BJ': {}} + result = {} + for fname, suffix in (('szs.tnf', 'SZ'), ('shs.tnf', 'SH'), ('bjs.tnf', 'BJ')): + path = self.tdx_path / 'T0002' / 'hq_cache' / fname + if path.exists(): + result[suffix] = self._parse_tnf_file(str(path)) + else: + logger.warning(f"股票名称文件不存在: {path}") + result[suffix] = {} + return result + def read_base_dbf(self) -> dict: """读取 base.dbf,返回 {code_6位: 流通股本万股} 字典。文件不存在时返回空字典。""" if self._smb is not None: @@ -166,6 +189,107 @@ def _read_day_file_raw(fname: str) -> pd.DataFrame: # ── SMB 私有方法 ────────────────────────────────────────────────────────── + @staticmethod + def _detect_tnf_record_len(data: bytes, header_offset: int = 50) -> int: + """通过搜索相邻股票代码,自动探测 .tnf 文件的记录长度。 + + 策略一(精确):搜索已知相邻代码对,两者位置之差即 record_len。 + 策略二(通用):扫描所有"偏移0处的6位纯数字",统计最高频间距。 + 两者都失败则返回默认值 314。 + """ + import re as _re + from collections import Counter as _Counter + body = data[header_offset:] + + # 策略一:已知相邻代码对(覆盖三个市场) + pairs = [ + (b'000001', b'000002'), + (b'000002', b'000003'), + (b'000004', b'000005'), + (b'600000', b'600001'), + (b'600001', b'600002'), + ] + for a, b in pairs: + pos_a = body.find(a) + if pos_a < 0: + continue + pos_b = body.find(b, pos_a + 1) + if pos_b < 0: + continue + gap = pos_b - pos_a + if 100 < gap < 65536: + return gap + + # 策略二:扫描前 200KB,找所有"位置对齐"的6位纯数字,统计间距众数 + scan = body[:200 * 1024] + positions = [] + for m in _re.finditer(rb'\d{6}', scan): + positions.append(m.start()) + gap_counter = _Counter() + for i in range(len(positions) - 1): + gap = positions[i + 1] - positions[i] + if 100 < gap < 65536: + gap_counter[gap] += 1 + if gap_counter: + return gap_counter.most_common(1)[0][0] + + return 314 + + @staticmethod + def _detect_tnf_name_offset(data: bytes, header_offset: int, record_len: int) -> int: + """在已知 record_len 的前提下,探测中文名称字段的字节偏移。 + 策略:取前若干条记录,在 record 内寻找最长的连续非零字节串(即名称字段)。 + 返回探测到的偏移,失败则返回默认值 23。 + """ + body = data[header_offset:] + # 取前 10 条记录,统计各偏移处出现非零字节的频率 + sample_count = min(10, len(body) // record_len) + if sample_count == 0: + return 23 + # 对每条记录,找 offset 6 之后第一个非零字节簇的起始位置 + offsets = [] + for i in range(sample_count): + rec = body[i * record_len: (i + 1) * record_len] + # 跳过前 6 字节(代码),找后续第一个非零字节 + j = 6 + while j < len(rec) and rec[j] == 0: + j += 1 + if j < len(rec): + offsets.append(j) + if offsets: + from collections import Counter + return Counter(offsets).most_common(1)[0][0] + return 23 + + @staticmethod + def _parse_tnf_file(file_path: str) -> dict: + """解析 TDX .tnf 文件,返回 {6位code: 名称} 字典。 + 自动探测 record_len 和名称字段偏移,兼容新旧版本通达信。 + 非股票记录(代码含非数字字符)自动跳过。 + """ + result = {} + header_offset = 50 + with open(file_path, 'rb') as f: + data = f.read() + + record_len = TdxDataReader._detect_tnf_record_len(data, header_offset) + name_offset = TdxDataReader._detect_tnf_name_offset(data, header_offset, record_len) + logger.debug(f"[tnf] {file_path}: record_len={record_len}, name_offset={name_offset}") + + body = data[header_offset:] + for i in range(len(body) // record_len): + rec = body[i * record_len: (i + 1) * record_len] + try: + code = rec[0:6].decode('ascii').strip('\x00').strip() + if not code or not code.isdigit(): + continue + name_bytes = rec[name_offset: name_offset + 32].split(b'\x00')[0] + name = name_bytes.decode('gbk', errors='replace').strip() + result[code] = name + except Exception: + continue + return result + @staticmethod def _parse_base_dbf(path: str) -> dict: try: @@ -184,6 +308,29 @@ def _parse_base_dbf(path: str) -> dict: logger.info(f"base.dbf 读取完成,共 {len(result)} 条记录") return result + def _read_stock_names_smb(self) -> dict: + """SMB 模式:下载三个 .tnf 文件到临时文件后解析。 + 返回格式:{'SZ': {...}, 'SH': {...}, 'BJ': {...}} + 直接尝试下载(不做 exists() 预判),避免权限等原因导致误判而静默跳过。 + """ + result = {'SZ': {}, 'SH': {}, 'BJ': {}} + for market, suffix in (('szs', 'SZ'), ('shs', 'SH'), ('bjs', 'BJ')): + unc = self._smb.tnf_unc(market) + try: + tmp_path = self._smb.download_to_tmp(unc, suffix='.tnf') + except Exception as e: + logger.warning(f"SMB 下载 {market}.tnf 失败: {e}") + continue + try: + parsed = self._parse_tnf_file(tmp_path) + logger.debug(f"{market}.tnf 解析完成: {len(parsed)} 条") + result[suffix] = parsed + except Exception as e: + logger.warning(f"SMB 解析 {market}.tnf 出错: {e}") + finally: + os.unlink(tmp_path) + return result + def _read_base_dbf_smb(self) -> dict: unc = self._smb.base_dbf_unc if not self._smb.exists(unc): diff --git a/src/tdx2db/smb_accessor.py b/src/tdx2db/smb_accessor.py index b9cf00b..76ba22f 100644 --- a/src/tdx2db/smb_accessor.py +++ b/src/tdx2db/smb_accessor.py @@ -97,6 +97,10 @@ def lday_dir_unc(self, market: str) -> str: def day_file_unc(self, market: str, filename: str) -> str: return self._unc('vipdoc', market, 'lday', filename) + def tnf_unc(self, market: str) -> str: + """返回 .tnf 股票名称文件的 UNC 路径。market: 'szs' / 'shs' / 'bjs'""" + return self._unc('T0002', 'hq_cache', f'{market}.tnf') + # ── 核心 I/O 操作 ───────────────────────────────────────────────────────── def exists(self, unc_path: str) -> bool: diff --git a/tests/test_smb.py b/tests/test_smb.py index af8f527..5ab8ac6 100644 --- a/tests/test_smb.py +++ b/tests/test_smb.py @@ -172,6 +172,9 @@ def _make_smb(self, day_bytes=None): smb.day_file_unc = MagicMock( side_effect=lambda m, f: rf'\\host\share\TDX\vipdoc\{m}\lday\{f}' ) + smb.tnf_unc = MagicMock( + side_effect=lambda m: rf'\\host\share\TDX\T0002\hq_cache\{m}.tnf' + ) smb.exists = MagicMock(return_value=True) smb.list_files = MagicMock(return_value=[]) if day_bytes is not None: @@ -281,6 +284,85 @@ def tracking_side_effect(unc, suffix='.day'): assert len(created_tmp) == 1 assert not os.path.exists(created_tmp[0]), "临时文件应已被删除" + def test_read_stock_names_smb(self): + """SMB 模式下 read_stock_names 应下载三个 .tnf 并返回名称字典。""" + import struct + + def _make_tnf_bytes(code: str, name: str) -> bytes: + """构造最小合法 .tnf 文件:50字节头 + 1条314字节记录。""" + header = b'\x00' * 50 + record = bytearray(314) + record[0:6] = code.encode('ascii').ljust(6, b'\x00') + name_gbk = name.encode('gbk') + record[23:23 + len(name_gbk)] = name_gbk + return header + bytes(record) + + smb = self._make_smb() + # 三个市场各有一只股票 + tnf_data = { + r'\\host\share\TDX\T0002\hq_cache\szs.tnf': _make_tnf_bytes('000001', '平安银行'), + r'\\host\share\TDX\T0002\hq_cache\shs.tnf': _make_tnf_bytes('600000', '浦发银行'), + r'\\host\share\TDX\T0002\hq_cache\bjs.tnf': _make_tnf_bytes('920001', '北交所A'), + } + + def _download_side_effect(unc, suffix='.day'): + data = tnf_data.get(unc, b'\x00' * 50) # 未知文件给空头部 + tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) + tmp.write(data) + tmp.flush() + tmp.close() + return tmp.name + + smb.download_to_tmp = MagicMock(side_effect=_download_side_effect) + + reader = TdxDataReader(smb=smb) + names = reader.read_stock_names() + + # 返回格式为 {'SZ': {...}, 'SH': {...}, 'BJ': {...}},按市场分组 + assert names['SZ'].get('000001') == '平安银行' + assert names['SH'].get('600000') == '浦发银行' + assert names['BJ'].get('920001') == '北交所A' + # tnf_unc 应被调用三次 + assert smb.tnf_unc.call_count == 3 + + def test_read_stock_names_smb_download_failure(self): + """SMB 模式下某个 .tnf 下载失败时,其他文件仍正常解析。""" + import struct + + def _make_tnf_bytes(code: str, name: str) -> bytes: + header = b'\x00' * 50 + record = bytearray(314) + record[0:6] = code.encode('ascii').ljust(6, b'\x00') + name_gbk = name.encode('gbk') + record[23:23 + len(name_gbk)] = name_gbk + return header + bytes(record) + + smb = self._make_smb() + tnf_data = { + r'\\host\share\TDX\T0002\hq_cache\szs.tnf': _make_tnf_bytes('000001', '平安银行'), + r'\\host\share\TDX\T0002\hq_cache\bjs.tnf': _make_tnf_bytes('920001', '北交所A'), + } + + def _download_side_effect(unc, suffix='.day'): + if 'shs' in unc: + raise OSError("SMB 连接超时") # 模拟 shs.tnf 下载失败 + data = tnf_data.get(unc, b'\x00' * 50) + tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) + tmp.write(data) + tmp.flush() + tmp.close() + return tmp.name + + smb.download_to_tmp = MagicMock(side_effect=_download_side_effect) + + reader = TdxDataReader(smb=smb) + names = reader.read_stock_names() + + # szs 和 bjs 应正常解析,shs 失败但不影响整体 + assert names['SZ'].get('000001') == '平安银行' + assert names['BJ'].get('920001') == '北交所A' + assert '600000' not in names['SH'] + # ─── TestCliSmbInit ────────────────────────────────────────────────────────── From 086fb6a45c5ea16e541cd8fd544fd829699d0816 Mon Sep 17 00:00:00 2001 From: jaden1q84 Date: Tue, 14 Apr 2026 22:51:56 +0800 Subject: [PATCH 26/26] =?UTF-8?q?docs:=20=E6=9B=B4=E6=96=B0=20README?= =?UTF-8?q?=EF=BC=8C=E8=A1=A5=E5=85=85=E6=B5=8B=E8=AF=95=E7=89=88=E6=9C=AC?= =?UTF-8?q?=E3=80=81=E7=9B=98=E5=90=8E=E6=95=B0=E6=8D=AE=E4=B8=8B=E8=BD=BD?= =?UTF-8?q?=E6=AD=A5=E9=AA=A4=E5=8F=8A=E5=AE=8C=E6=95=B4=E8=A1=A8=E7=BB=93?= =?UTF-8?q?=E6=9E=84=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Sonnet 4.6 --- README.md | 67 ++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 54 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 1715f93..a52c39d 100644 --- a/README.md +++ b/README.md @@ -2,12 +2,17 @@ 从本地通达信(TDX)行情软件读取 A 股日线数据,增量同步到数据库。支持作为 Python 包被其他项目调用。 +## 测试环境 + +- 通达信版本:**金融终端 V7.72(64位)** + ## 特性 -- 同步深圳/上海全量 A 股日线数据(含科创板) +- 同步深圳/上海/北交所全量 A 股日线数据(含科创板) - 前复权 / 后复权 / 不复权,默认前复权 - 增量更新:有除权事件的个股自动全量重写,确保复权价格正确 -- 日期格式:`YYYYMMDD` 整数(便于范围查询) +- 包含换手率数据(`turnover_rate`) +- 日期格式:`YYYYMMDD` 字符串(便于范围查询) - 数据库:SQLite(默认)/ MySQL / PostgreSQL ## 安装 @@ -75,6 +80,16 @@ SMB_PORT=445 启用 SMB 模式后,`TDX_PATH` 可以不填。 +## 准备工作:下载历史盘后数据 + +使用本程序前,需要先在通达信中下载历史盘后数据: + +1. 打开通达信客户端 +2. 菜单栏 → **选项** → **盘后数据下载** +3. 选择需要的历史数据范围并下载完成 + +> 程序读取的是通达信本地 `.day` 文件,必须先确保数据已通过上述方式下载到本地,否则无法同步。 + ## 命令行使用 ```bash @@ -126,20 +141,46 @@ print(df.head()) ## 数据表结构 -**daily_data** +数据库包含以下三张表,由 SQLAlchemy 在首次运行时自动创建,无需手动执行 SQL。 + +### daily_data(日线数据) + +| 列 | 类型 | 说明 | +|----|------|------| +| id | Integer | 自增主键 | +| stock_code | String(12) | 股票代码(6位,如 `000001`) | +| market | Integer | 市场(0=深圳,1=上海,2=北京) | +| date | String(8) | 日期,格式 `YYYYMMDD` | +| open | Float | 开盘价(复权后) | +| high | Float | 最高价(复权后) | +| low | Float | 最低价(复权后) | +| close | Float | 收盘价(复权后) | +| volume | Float | 成交量(手) | +| amount | Float | 成交额(元) | +| adj_factor | Float | 复权因子(前复权时 < 1,不复权时 = 1.0) | +| turnover_rate | Float | 换手率(%),无法计算时为 NULL | + +唯一约束:`(stock_code, date)` + +### stock_info(股票列表) + +| 列 | 类型 | 说明 | +|----|------|------| +| stock_code | String(12) | 股票代码(主键,如 `000001`) | +| stock_name | String(50) | 股票名称(如 `平安银行`) | + +唯一约束:`stock_code`(即主键) + +### kline_statistics(同步统计) + +每次 `sync` 命令完成后写入一条统计记录,用于追踪历次同步情况。 | 列 | 类型 | 说明 | |----|------|------| -| code | String | 股票代码(6位,如 `000001`) | -| market | Integer | 市场(0=深圳,1=上海) | -| date | Integer | 日期 YYYYMMDD | -| open/high/low/close | Float | 复权后价格 | -| volume | Float | 成交量 | -| amount | Float | 成交额 | -| adj_factor | Float | 复权因子(1.0=无复权) | -| turnover_rate | Float | 换手率(%),待实现 | - -唯一约束:`(code, date)` +| id | Integer | 自增主键 | +| stock_count | Integer | 本次同步的股票数量 | +| total_rows | Integer | 同步后 daily_data 的总行数 | +| sync_time | DateTime | 同步完成时间 | ## 运行测试