From 73b18afbce9c33a210bb4eea490f9f92fa960f39 Mon Sep 17 00:00:00 2001 From: lwk <3098293798@qq.com> Date: Wed, 11 Mar 2026 23:03:09 +0800 Subject: [PATCH] feat(kimi): DeepSORT reconstruction with mc_lambda & scipy interpolation --- .DS_Store | Bin 0 -> 6148 bytes README_V2.md | 194 ++++++++++++++++ deep_sort/__init__.py | 28 ++- deep_sort/deep_feature_extractor.py | 262 +++++++++++++++++++++ deep_sort/detection.py | 44 ++-- deep_sort/track.py | 151 +++++++----- deep_sort/track_interpolation.py | 306 ++++++++++++++++++++++++ deep_sort/tracker.py | 119 +++++++--- deep_sort/yolo_detector.py | 306 ++++++++++++++++++++++++ deep_sort_app_v2.py | 349 ++++++++++++++++++++++++++++ example_usage.py | 237 +++++++++++++++++++ requirements.txt | 3 + 12 files changed, 1892 insertions(+), 107 deletions(-) create mode 100644 .DS_Store create mode 100644 README_V2.md create mode 100644 deep_sort/deep_feature_extractor.py create mode 100644 deep_sort/track_interpolation.py create mode 100644 deep_sort/yolo_detector.py create mode 100644 deep_sort_app_v2.py create mode 100644 example_usage.py diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..ab45feac3f3aa2c0f1621bcaa2df3ffda29ef823 GIT binary patch literal 6148 zcmeHKJx;?g6n>^bDzJ2_m@1zDsS^`YsKSVZ)B^x*P(eybMB6<(XMi(s0cJ*ELJW)? zfE(~V+buQ?1r}7H_aggq-cR217sZZ=NOdM@o2WrVbu`9k9bJX6o!gwOI1L9YenvKE zcl-IUnDZ2cFdz&pBLn8cr!5-L8I6PQS3U2b;Amw`vPj2)I1&C-F+907n4)J8Do zUkZ-V4r7OdwCAYWEO^)|kjdp~_!hWSi6N1id$7aA= c@g|xP*fbA-vBT0LG!Xd_Ff>Rf4E!kr?|!zVv;Y7A literal 0 HcmV?d00001 diff --git a/README_V2.md b/README_V2.md new file mode 100644 index 000000000..aa89eeffc --- /dev/null +++ b/README_V2.md @@ -0,0 +1,194 @@ +# Deep SORT v2 - PyTorch版本 + +基于PyTorch的Deep SORT多目标跟踪系统升级版,支持多类别跟踪、端到端检测跟踪流程、EMA特征更新和轨迹插值。 + +## 新特性 + +### 1. PyTorch深度学习框架 +- 使用PyTorch替代TensorFlow +- 实现了基于ResNet的特征提取器 +- 支持GPU加速 + +### 2. 多类别跟踪 (Multi-Class Tracking) +- 支持同时跟踪多个目标类别 +- 类别感知的轨迹关联 +- 可配置的类别匹配权重 + +### 3. 
YOLOv8集成 +- 集成Ultralytics YOLOv8检测器 +- 支持视频文件 (.mp4) 和摄像头输入 +- 端到端"目标检测 -> 深度特征提取 -> 轨迹关联"流程 + +### 4. EMA特征更新 +- 使用指数移动平均(EMA)平滑目标特征 +- 减少检测器误差引入的错误特征 +- 提高特征表示的稳定性 + +### 5. 轨迹插值与平滑 +- 离线轨迹插值填补遮挡期间的缺失帧 +- 支持线性插值和样条插值 +- 轨迹平滑减少抖动 + +## 安装 + +```bash +# 安装依赖 +pip install -r requirements.txt + +# 下载YOLOv8模型(首次使用会自动下载) +# 可选模型: yolov8n.pt, yolov8s.pt, yolov8m.pt, yolov8l.pt, yolov8x.pt +``` + +## 快速开始 + +### 摄像头实时跟踪 + +```bash +python deep_sort_app_v2.py --source 0 +``` + +### 视频文件跟踪 + +```bash +python deep_sort_app_v2.py \ + --source path/to/video.mp4 \ + --save-video \ + --output result.mp4 +``` + +### 多类别跟踪(人和车辆) + +```bash +python deep_sort_app_v2.py \ + --source path/to/video.mp4 \ + --classes 0 2 3 5 7 \ + --save-video +``` + +类别ID对应关系(COCO数据集): +- 0: person (人) +- 2: car (汽车) +- 3: motorcycle (摩托车) +- 5: bus (公交车) +- 7: truck (卡车) + +## 参数说明 + +### 输入输出 +- `--source`: 视频文件路径或摄像头索引 (默认: 0) +- `--output`: 输出视频路径 +- `--save-video`: 保存输出视频 + +### 模型设置 +- `--model`: YOLOv8模型路径 (默认: yolov8n.pt) +- `--feature-model`: 特征提取器模型路径(可选) + +### 检测参数 +- `--conf-threshold`: 检测置信度阈值 (默认: 0.3) +- `--iou-threshold`: NMS IOU阈值 (默认: 0.45) +- `--classes`: 要跟踪的类别ID列表 + +### 跟踪参数 +- `--max-cosine-distance`: 最大余弦距离 (默认: 0.2) +- `--nn-budget`: 外观描述符库大小 (默认: 100) +- `--max-age`: 轨迹最大存活帧数 (默认: 30) +- `--n-init`: 轨迹初始化所需检测次数 (默认: 3) +- `--ema-alpha`: EMA平滑因子 (默认: 0.9) +- `--mc-lambda`: 多类别匹配权重 (默认: 0.995) + +### 后处理参数 +- `--no-interpolation`: 禁用轨迹插值 +- `--no-smoothing`: 禁用轨迹平滑 +- `--max-interpolation-gap`: 最大插值间隔 (默认: 30) + +## API使用示例 + +```python +from deep_sort import ( + Tracker, + NearestNeighborDistanceMetric, + YOLODetector, + FeatureExtractor +) + +# 创建检测器 +detector = YOLODetector( + model_path='yolov8n.pt', + conf_threshold=0.3, + classes=[0] # 只检测人 +) + +# 创建特征提取器 +feature_extractor = FeatureExtractor(device='cuda') + +# 创建跟踪器 +metric = NearestNeighborDistanceMetric("cosine", 0.2, 100) +tracker = Tracker( + metric, + max_iou_distance=0.7, + max_age=30, + n_init=3, + 
ema_alpha=0.9, # EMA平滑因子 + mc_lambda=0.995 # 多类别匹配权重 +) + +# 处理视频 +import cv2 +cap = cv2.VideoCapture('video.mp4') + +while True: + ret, frame = cap.read() + if not ret: + break + + # 检测和跟踪 + results = detector.detect_and_track(frame, tracker, feature_extractor) + + # 处理结果 + for result in results: + track_id = result['track_id'] + bbox = result['bbox'] # [x1, y1, x2, y2] + class_id = result['class_id'] + class_name = result['class_name'] + + print(f"Track {track_id}: {class_name} at {bbox}") + +cap.release() +``` + +## 项目结构 + +``` +deep_sort/ +├── deep_sort/ +│ ├── __init__.py +│ ├── detection.py # 检测类 +│ ├── track.py # 轨迹类(含EMA) +│ ├── tracker.py # 跟踪器(多类别支持) +│ ├── kalman_filter.py # 卡尔曼滤波 +│ ├── nn_matching.py # 最近邻匹配 +│ ├── linear_assignment.py # 线性分配 +│ ├── iou_matching.py # IOU匹配 +│ ├── deep_feature_extractor.py # PyTorch特征提取器 +│ ├── yolo_detector.py # YOLOv8检测器 +│ └── track_interpolation.py # 轨迹插值与平滑 +├── deep_sort_app_v2.py # 主应用 +├── example_usage.py # 使用示例 +├── requirements.txt +└── README_V2.md +``` + +## 性能优化建议 + +1. **使用GPU**: 确保CUDA可用以加速特征提取 +2. **选择合适的YOLO模型**: + - yolov8n.pt: 最快,精度较低 + - yolov8s.pt: 平衡 + - yolov8m/l/x.pt: 更慢,精度更高 +3. 
**调整跟踪参数**: + - 降低`max_age`可减少ID切换但可能丢失目标 + - 调整`ema_alpha`平衡特征更新速度 + +## 许可证 + +与原Deep SORT项目相同。 diff --git a/deep_sort/__init__.py b/deep_sort/__init__.py index 43e08fb8a..9fad05d2e 100644 --- a/deep_sort/__init__.py +++ b/deep_sort/__init__.py @@ -1 +1,27 @@ -# vim: expandtab:ts=4:sw=4 +from .detection import Detection +from .track import Track, TrackState +from .tracker import Tracker +from .kalman_filter import KalmanFilter +from .nn_matching import NearestNeighborDistanceMetric +from .deep_feature_extractor import FeatureExtractor, DeepFeatureExtractor, create_box_encoder +from .yolo_detector import YOLODetector, VideoCapture +from .track_interpolation import TrackInterpolator, TrajectorySmoother, TrackPostProcessor + +__version__ = '2.0.0' + +__all__ = [ + 'Detection', + 'Track', + 'TrackState', + 'Tracker', + 'KalmanFilter', + 'NearestNeighborDistanceMetric', + 'FeatureExtractor', + 'DeepFeatureExtractor', + 'create_box_encoder', + 'YOLODetector', + 'VideoCapture', + 'TrackInterpolator', + 'TrajectorySmoother', + 'TrackPostProcessor', +] diff --git a/deep_sort/deep_feature_extractor.py b/deep_sort/deep_feature_extractor.py new file mode 100644 index 000000000..9f90d4173 --- /dev/null +++ b/deep_sort/deep_feature_extractor.py @@ -0,0 +1,262 @@ +import torch +import torch.nn as nn +import torchvision.transforms as T +import numpy as np +import cv2 + + +class BasicBlock(nn.Module): + """ResNet Basic Block。""" + + def __init__(self, in_channels, out_channels, stride=1): + super(BasicBlock, self).__init__() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, + stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, + stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(out_channels) + + self.downsample = None + if stride != 1 or in_channels != out_channels: + self.downsample = nn.Sequential( + 
nn.Conv2d(in_channels, out_channels, kernel_size=1, + stride=stride, bias=False), + nn.BatchNorm2d(out_channels) + ) + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class DeepFeatureExtractor(nn.Module): + """ + 基于ResNet的行人重识别特征提取器。 + 输入: 128x64 的行人图像 + 输出: 128维特征向量 + """ + + def __init__(self, num_classes=751, reid_dim=128): + super(DeepFeatureExtractor, self).__init__() + self.reid_dim = reid_dim + + # 初始卷积层 + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + # ResNet层 + self.layer1 = self._make_layer(64, 64, 2, stride=1) + self.layer2 = self._make_layer(64, 128, 2, stride=2) + self.layer3 = self._make_layer(128, 256, 2, stride=2) + self.layer4 = self._make_layer(256, 512, 2, stride=1) + + # 全局平均池化 + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + + # 特征投影层 + self.fc = nn.Linear(512, reid_dim) + self.bn_fc = nn.BatchNorm1d(reid_dim) + + # 分类器(仅在训练时使用) + self.classifier = nn.Linear(reid_dim, num_classes) + + self._initialize_weights() + + def _make_layer(self, in_channels, out_channels, num_blocks, stride): + layers = [] + layers.append(BasicBlock(in_channels, out_channels, stride)) + for _ in range(1, num_blocks): + layers.append(BasicBlock(out_channels, out_channels)) + return nn.Sequential(*layers) + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + 
def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + + x = self.fc(x) + x = self.bn_fc(x) + x = nn.functional.normalize(x, p=2, dim=1) + + return x + + def extract_features(self, x): + """提取特征(推理时使用)。""" + return self.forward(x) + + +class FeatureExtractor: + """ + 特征提取器封装类,处理预处理和后处理。 + """ + + def __init__(self, model_path=None, device='cuda', image_size=(128, 64)): + self.device = torch.device(device if torch.cuda.is_available() else 'cpu') + self.image_size = image_size + + # 创建模型 + self.model = DeepFeatureExtractor(reid_dim=128) + + # 加载预训练权重(如果提供) + if model_path is not None: + self.load_weights(model_path) + + self.model.to(self.device) + self.model.eval() + + # 图像预处理 + self.transform = T.Compose([ + T.ToPILImage(), + T.Resize(image_size), + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + def load_weights(self, model_path): + """加载模型权重。""" + checkpoint = torch.load(model_path, map_location=self.device) + if 'state_dict' in checkpoint: + self.model.load_state_dict(checkpoint['state_dict']) + else: + self.model.load_state_dict(checkpoint) + print(f"Loaded weights from {model_path}") + + def save_weights(self, model_path): + """保存模型权重。""" + torch.save(self.model.state_dict(), model_path) + print(f"Saved weights to {model_path}") + + def _extract_image_patch(self, image, bbox): + """从边界框提取图像块。""" + x, y, w, h = bbox.astype(int) + + # 调整边界框以匹配目标宽高比 + target_aspect = self.image_size[1] / self.image_size[0] + new_w = target_aspect * h + x = int(x - (new_w - w) / 2) + w = int(new_w) + + # 裁剪并调整大小 + x1 = max(0, x) + y1 = max(0, y) + x2 = min(image.shape[1], x + w) + y2 = min(image.shape[0], y + h) + + if x1 >= x2 or y1 >= y2: + return None + + patch = image[y1:y2, x1:x2] + patch = cv2.resize(patch, (self.image_size[1], self.image_size[0])) + 
+ return patch + + def __call__(self, image, boxes): + """ + 提取图像中边界框的特征。 + + Parameters + ---------- + image : ndarray + BGR格式的图像。 + boxes : ndarray + 边界框数组,格式为 (x, y, w, h)。 + + Returns + ------- + ndarray + 特征矩阵,每行对应一个边界框的特征。 + """ + if len(boxes) == 0: + return np.array([]) + + patches = [] + for box in boxes: + patch = self._extract_image_patch(image, box) + if patch is None: + # 如果提取失败,使用随机噪声 + patch = np.random.randint(0, 256, (*self.image_size, 3), dtype=np.uint8) + patches.append(patch) + + # 转换为张量 + patches_tensor = torch.stack([ + self.transform(patch) for patch in patches + ]).to(self.device) + + # 提取特征 + with torch.no_grad(): + features = self.model.extract_features(patches_tensor) + + return features.cpu().numpy() + + +def create_box_encoder(model_path=None, device='cuda', batch_size=32): + """ + 创建边界框编码器。 + + Parameters + ---------- + model_path : str, optional + 模型权重路径。 + device : str + 计算设备。 + batch_size : int + 批处理大小。 + + Returns + ------- + callable + 编码器函数。 + """ + extractor = FeatureExtractor(model_path=model_path, device=device) + + def encoder(image, boxes): + if len(boxes) == 0: + return np.array([]) + + features = [] + for i in range(0, len(boxes), batch_size): + batch_boxes = boxes[i:i + batch_size] + batch_features = extractor(image, batch_boxes) + features.append(batch_features) + + return np.vstack(features) if features else np.array([]) + + return encoder diff --git a/deep_sort/detection.py b/deep_sort/detection.py index 97cd39d07..c8b98d498 100644 --- a/deep_sort/detection.py +++ b/deep_sort/detection.py @@ -1,47 +1,55 @@ -# vim: expandtab:ts=4:sw=4 import numpy as np -class Detection(object): +class Detection: """ - This class represents a bounding box detection in a single image. + 表示单张图像中的边界框检测。 Parameters ---------- tlwh : array_like - Bounding box in format `(x, y, w, h)`. + 边界框格式 `(x, y, w, h)`. confidence : float - Detector confidence score. 
+ 检测器置信度分数。 feature : array_like - A feature vector that describes the object contained in this image. + 描述图像中包含目标的特征向量。 + class_id : int + 目标类别ID,用于多类别跟踪。 + class_name : str + 目标类别名称。 Attributes ---------- tlwh : ndarray - Bounding box in format `(top left x, top left y, width, height)`. - confidence : ndarray - Detector confidence score. - feature : ndarray | NoneType - A feature vector that describes the object contained in this image. - + 边界框格式 `(top left x, top left y, width, height)`. + confidence : float + 检测器置信度分数。 + feature : ndarray | None + 描述图像中包含目标的特征向量。 + class_id : int + 目标类别ID。 + class_name : str + 目标类别名称。 """ - def __init__(self, tlwh, confidence, feature): + def __init__(self, tlwh, confidence, feature, class_id=0, class_name=None): self.tlwh = np.asarray(tlwh, dtype=np.float64) self.confidence = float(confidence) - self.feature = np.asarray(feature, dtype=np.float32) + self.feature = np.asarray(feature, dtype=np.float32) if feature is not None else None + self.class_id = int(class_id) + self.class_name = class_name def to_tlbr(self): - """Convert bounding box to format `(min x, min y, max x, max y)`, i.e., - `(top left, bottom right)`. + """将边界框转换为格式 `(min x, min y, max x, max y)`, + 即 `(top left, bottom right)`. """ ret = self.tlwh.copy() ret[2:] += ret[:2] return ret def to_xyah(self): - """Convert bounding box to format `(center x, center y, aspect ratio, - height)`, where the aspect ratio is `width / height`. + """将边界框转换为格式 `(center x, center y, aspect ratio, height)`, + 其中 aspect ratio 是 `width / height`. """ ret = self.tlwh.copy() ret[:2] += ret[2:] / 2 diff --git a/deep_sort/track.py b/deep_sort/track.py index f638e9b3e..3ffca239e 100644 --- a/deep_sort/track.py +++ b/deep_sort/track.py @@ -1,16 +1,12 @@ -# vim: expandtab:ts=4:sw=4 +import numpy as np class TrackState: """ - Enumeration type for the single target track state. Newly created tracks are - classified as `tentative` until enough evidence has been collected. 
Then, - the track state is changed to `confirmed`. Tracks that are no longer alive - are classified as `deleted` to mark them for removal from the set of active - tracks. - + 单目标跟踪状态的枚举类型。新创建的轨迹被分类为 `tentative` + 直到收集到足够的证据。然后,轨迹状态变为 `confirmed`。 + 不再活跃的轨迹被分类为 `deleted`,标记为从活跃轨迹集合中移除。 """ - Tentative = 1 Confirmed = 2 Deleted = 3 @@ -18,53 +14,59 @@ class TrackState: class Track: """ - A single target track with state space `(x, y, a, h)` and associated - velocities, where `(x, y)` is the center of the bounding box, `a` is the - aspect ratio and `h` is the height. + 单目标跟踪,状态空间为 `(x, y, a, h)` 和相关速度, + 其中 `(x, y)` 是边界框中心位置,`a` 是宽高比,`h` 是高度。 Parameters ---------- mean : ndarray - Mean vector of the initial state distribution. + 初始状态分布的均值向量。 covariance : ndarray - Covariance matrix of the initial state distribution. + 初始状态分布的协方差矩阵。 track_id : int - A unique track identifier. + 唯一的轨迹标识符。 n_init : int - Number of consecutive detections before the track is confirmed. The - track state is set to `Deleted` if a miss occurs within the first - `n_init` frames. + 轨迹被确认前的连续检测次数。如果在最初的 `n_init` 帧内发生丢失, + 轨迹状态被设置为 `Deleted`。 max_age : int - The maximum number of consecutive misses before the track state is - set to `Deleted`. + 轨迹状态被设置为 `Deleted` 前的最大连续丢失次数。 feature : Optional[ndarray] - Feature vector of the detection this track originates from. If not None, - this feature is added to the `features` cache. + 此轨迹来源检测的特征向量。如果不为 None,此特征被添加到 `features` 缓存。 + class_id : int + 目标类别ID。 + class_name : str + 目标类别名称。 + ema_alpha : float + EMA (指数移动平均) 的平滑因子。 Attributes ---------- mean : ndarray - Mean vector of the initial state distribution. + 状态分布的均值向量。 covariance : ndarray - Covariance matrix of the initial state distribution. + 状态分布的协方差矩阵。 track_id : int - A unique track identifier. + 唯一的轨迹标识符。 hits : int - Total number of measurement updates. + 测量更新的总次数。 age : int - Total number of frames since first occurance. 
+ 自首次出现以来的总帧数。 time_since_update : int - Total number of frames since last measurement update. + 自上次测量更新以来的总帧数。 state : TrackState - The current track state. + 当前轨迹状态。 features : List[ndarray] - A cache of features. On each measurement update, the associated feature - vector is added to this list. - + 特征缓存。每次测量更新时,关联的特征向量被添加到此列表。 + ema_feature : ndarray + 使用EMA更新的平滑特征表示。 + class_id : int + 目标类别ID。 + class_name : str + 目标类别名称。 """ def __init__(self, mean, covariance, track_id, n_init, max_age, - feature=None): + feature=None, class_id=0, class_name=None, ema_alpha=0.9): self.mean = mean self.covariance = covariance self.track_id = track_id @@ -74,21 +76,41 @@ def __init__(self, mean, covariance, track_id, n_init, max_age, self.state = TrackState.Tentative self.features = [] + self.ema_alpha = ema_alpha + + # 使用EMA初始化特征 if feature is not None: self.features.append(feature) + self.ema_feature = feature.copy() + else: + self.ema_feature = None self._n_init = n_init self._max_age = max_age + self.class_id = class_id + self.class_name = class_name + + # 用于轨迹插值的历史记录 + self.history = [] + self._record_state() + + def _record_state(self): + """记录当前状态到历史记录,用于轨迹插值。""" + self.history.append({ + 'mean': self.mean.copy(), + 'covariance': self.covariance.copy(), + 'age': self.age, + 'time_since_update': self.time_since_update, + 'state': self.state + }) def to_tlwh(self): - """Get current position in bounding box format `(top left x, top left y, - width, height)`. + """获取当前位置,格式为边界框 `(top left x, top left y, width, height)`. Returns ------- ndarray - The bounding box. - + 边界框。 """ ret = self.mean[:4].copy() ret[2] *= ret[3] @@ -96,71 +118,88 @@ def to_tlwh(self): return ret def to_tlbr(self): - """Get current position in bounding box format `(min x, miny, max x, - max y)`. + """获取当前位置,格式为边界框 `(min x, min y, max x, max y)`. Returns ------- ndarray - The bounding box. 
- + 边界框。 """ ret = self.to_tlwh() ret[2:] = ret[:2] + ret[2:] return ret def predict(self, kf): - """Propagate the state distribution to the current time step using a - Kalman filter prediction step. + """使用卡尔曼滤波预测步骤将状态分布传播到当前时间步。 Parameters ---------- kf : kalman_filter.KalmanFilter - The Kalman filter. - + 卡尔曼滤波器。 """ self.mean, self.covariance = kf.predict(self.mean, self.covariance) self.age += 1 self.time_since_update += 1 + self._record_state() def update(self, kf, detection): - """Perform Kalman filter measurement update step and update the feature - cache. + """执行卡尔曼滤波测量更新步骤并更新特征缓存。 Parameters ---------- kf : kalman_filter.KalmanFilter - The Kalman filter. + 卡尔曼滤波器。 detection : Detection - The associated detection. - + 关联的检测。 """ self.mean, self.covariance = kf.update( self.mean, self.covariance, detection.to_xyah()) - self.features.append(detection.feature) + + # 使用EMA更新特征 + if detection.feature is not None: + self.features.append(detection.feature) + if self.ema_feature is None: + self.ema_feature = detection.feature.copy() + else: + self.ema_feature = ( + self.ema_alpha * self.ema_feature + + (1 - self.ema_alpha) * detection.feature + ) self.hits += 1 self.time_since_update = 0 if self.state == TrackState.Tentative and self.hits >= self._n_init: self.state = TrackState.Confirmed + self._record_state() + def mark_missed(self): - """Mark this track as missed (no association at the current time step). - """ + """将此轨迹标记为丢失(当前时间步无关联)。""" if self.state == TrackState.Tentative: self.state = TrackState.Deleted elif self.time_since_update > self._max_age: self.state = TrackState.Deleted def is_tentative(self): - """Returns True if this track is tentative (unconfirmed). 
- """ + """如果此轨迹是 tentative(未确认),返回 True。""" return self.state == TrackState.Tentative def is_confirmed(self): - """Returns True if this track is confirmed.""" + """如果此轨迹已确认,返回 True。""" return self.state == TrackState.Confirmed def is_deleted(self): - """Returns True if this track is dead and should be deleted.""" + """如果此轨迹已死亡并应被删除,返回 True。""" return self.state == TrackState.Deleted + + def get_smoothed_feature(self): + """获取EMA平滑后的特征表示。 + + Returns + ------- + ndarray + EMA平滑后的特征向量。 + """ + return self.ema_feature if self.ema_feature is not None else ( + self.features[-1] if self.features else None + ) diff --git a/deep_sort/track_interpolation.py b/deep_sort/track_interpolation.py new file mode 100644 index 000000000..0946a12ac --- /dev/null +++ b/deep_sort/track_interpolation.py @@ -0,0 +1,306 @@ +import numpy as np +from scipy.interpolate import interp1d + + +class TrackInterpolator: + """ + 轨迹插值器,用于填充被遮挡期间的缺失帧。 + + 当一个目标被遮挡若干帧后又被找回时,使用插值算法填补中间缺失的帧。 + 支持线性插值和样条插值。 + + Parameters + ---------- + max_gap : int + 最大可插值的帧间隔。超过此值的间隙不会被插值。 + min_track_length : int + 执行插值所需的最小轨迹长度。 + interpolation_method : str + 插值方法,可选 'linear' 或 'cubic'。 + """ + + def __init__(self, max_gap=30, min_track_length=5, interpolation_method='linear'): + self.max_gap = max_gap + self.min_track_length = min_track_length + self.interpolation_method = interpolation_method + + def interpolate_track(self, track_history): + """ + 对单个轨迹进行插值。 + + Parameters + ---------- + track_history : list + 轨迹历史记录,每个元素包含帧索引和状态信息。 + + Returns + ------- + list + 插值后的轨迹点列表。 + """ + if len(track_history) < self.min_track_length: + return track_history + + # 提取帧索引和边界框信息 + frames = [] + bboxes = [] + + for state in track_history: + # 从mean中提取边界框 (x, y, a, h) + mean = state['mean'] + x, y, a, h = mean[0], mean[1], mean[2], mean[3] + # 转换为xywh + w = a * h + x_tl = x - w / 2 + y_tl = y - h / 2 + + frames.append(state.get('frame_idx', len(frames))) + bboxes.append([x_tl, y_tl, w, h]) + + frames = np.array(frames) + 
bboxes = np.array(bboxes) + + # 检测间隙 + gaps = np.diff(frames) + interpolated_states = [] + + current_idx = 0 + interpolated_states.append(track_history[0]) + + for i, gap in enumerate(gaps): + if gap > 1 and gap <= self.max_gap: + # 需要插值 + start_frame = frames[i] + end_frame = frames[i + 1] + start_bbox = bboxes[i] + end_bbox = bboxes[i + 1] + + # 创建插值点 + interp_frames = np.arange(start_frame + 1, end_frame) + + # 对每个维度进行插值 + interp_bboxes = np.zeros((len(interp_frames), 4)) + for dim in range(4): + if self.interpolation_method == 'linear': + f = interp1d([start_frame, end_frame], + [start_bbox[dim], end_bbox[dim]], + kind='linear') + else: # cubic + f = interp1d([start_frame, end_frame], + [start_bbox[dim], end_bbox[dim]], + kind='cubic') + interp_bboxes[:, dim] = f(interp_frames) + + # 创建插值状态 + for j, (f_idx, bbox) in enumerate(zip(interp_frames, interp_bboxes)): + x_tl, y_tl, w, h = bbox + x = x_tl + w / 2 + y = y_tl + h / 2 + a = w / h if h > 0 else 1.0 + + # 创建新的状态 + interp_state = { + 'mean': np.array([x, y, a, h, 0, 0, 0, 0]), + 'covariance': track_history[i]['covariance'].copy(), + 'frame_idx': int(f_idx), + 'interpolated': True + } + interpolated_states.append(interp_state) + + interpolated_states.append(track_history[i + 1]) + current_idx = i + 1 + + return interpolated_states + + def interpolate_tracks(self, tracks): + """ + 对多个轨迹进行插值。 + + Parameters + ---------- + tracks : list + 轨迹列表,每个轨迹是一个历史状态列表。 + + Returns + ------- + list + 插值后的轨迹列表。 + """ + interpolated_tracks = [] + for track in tracks: + interpolated_track = self.interpolate_track(track) + interpolated_tracks.append(interpolated_track) + return interpolated_tracks + + +class TrajectorySmoother: + """ + 轨迹平滑器,使用卡尔曼平滑或移动平均来平滑轨迹。 + + Parameters + ---------- + window_size : int + 移动平均窗口大小。 + method : str + 平滑方法,可选 'moving_average' 或 'savgol'。 + """ + + def __init__(self, window_size=5, method='moving_average'): + self.window_size = window_size + self.method = method + + def smooth(self, 
track_history): + """ + 平滑轨迹。 + + Parameters + ---------- + track_history : list + 轨迹历史记录。 + + Returns + ------- + list + 平滑后的轨迹。 + """ + if len(track_history) < self.window_size: + return track_history + + # 提取边界框 + bboxes = [] + for state in track_history: + mean = state['mean'] + x, y, a, h = mean[0], mean[1], mean[2], mean[3] + w = a * h + x_tl = x - w / 2 + y_tl = y - h / 2 + bboxes.append([x_tl, y_tl, w, h]) + + bboxes = np.array(bboxes) + + # 应用平滑 + if self.method == 'moving_average': + smoothed_bboxes = self._moving_average(bboxes) + elif self.method == 'savgol': + smoothed_bboxes = self._savgol_filter(bboxes) + else: + smoothed_bboxes = bboxes + + # 更新状态 + smoothed_history = [] + for i, (state, bbox) in enumerate(zip(track_history, smoothed_bboxes)): + x_tl, y_tl, w, h = bbox + x = x_tl + w / 2 + y = y_tl + h / 2 + a = w / h if h > 0 else 1.0 + + smoothed_state = state.copy() + smoothed_state['mean'] = np.array([x, y, a, h, + state['mean'][4], state['mean'][5], + state['mean'][6], state['mean'][7]]) + smoothed_history.append(smoothed_state) + + return smoothed_history + + def _moving_average(self, data): + """应用移动平均。""" + smoothed = np.copy(data) + half_window = self.window_size // 2 + + for i in range(len(data)): + start = max(0, i - half_window) + end = min(len(data), i + half_window + 1) + smoothed[i] = np.mean(data[start:end], axis=0) + + return smoothed + + def _savgol_filter(self, data): + """应用Savitzky-Golay滤波器。""" + from scipy.signal import savgol_filter + + smoothed = np.copy(data) + for dim in range(data.shape[1]): + smoothed[:, dim] = savgol_filter( + data[:, dim], + window_length=min(self.window_size, len(data) // 2 * 2 + 1), + polyorder=2 + ) + return smoothed + + +class TrackPostProcessor: + """ + 轨迹后处理器,整合插值和平滑功能。 + + Parameters + ---------- + enable_interpolation : bool + 是否启用轨迹插值。 + enable_smoothing : bool + 是否启用轨迹平滑。 + max_gap : int + 最大插值间隔。 + smooth_window : int + 平滑窗口大小。 + """ + + def __init__(self, enable_interpolation=True, 
enable_smoothing=True, + max_gap=30, smooth_window=5): + self.enable_interpolation = enable_interpolation + self.enable_smoothing = enable_smoothing + + self.interpolator = TrackInterpolator(max_gap=max_gap) if enable_interpolation else None + self.smoother = TrajectorySmoother(window_size=smooth_window) if enable_smoothing else None + + def process(self, tracks): + """ + 处理轨迹。 + + Parameters + ---------- + tracks : list + 轨迹列表。 + + Returns + ------- + list + 处理后的轨迹列表。 + """ + processed_tracks = [] + + for track in tracks: + processed_track = track + + # 插值 + if self.interpolator is not None: + processed_track = self.interpolator.interpolate_track(processed_track) + + # 平滑 + if self.smoother is not None: + processed_track = self.smoother.smooth(processed_track) + + processed_tracks.append(processed_track) + + return processed_tracks + + def process_frame_results(self, track_history_dict): + """ + 处理帧级别的跟踪结果。 + + Parameters + ---------- + track_history_dict : dict + 轨迹历史字典,键为track_id,值为历史状态列表。 + + Returns + ------- + dict + 处理后的轨迹字典。 + """ + tracks = list(track_history_dict.values()) + processed_tracks = self.process(tracks) + + return { + track_id: processed_track + for track_id, processed_track in zip(track_history_dict.keys(), processed_tracks) + } diff --git a/deep_sort/tracker.py b/deep_sort/tracker.py index de99de44e..ee19518a7 100644 --- a/deep_sort/tracker.py +++ b/deep_sort/tracker.py @@ -1,5 +1,3 @@ -# vim: expandtab:ts=4:sw=4 -from __future__ import absolute_import import numpy as np from . import kalman_filter from . import linear_assignment @@ -9,112 +7,153 @@ class Tracker: """ - This is the multi-target tracker. + 多目标跟踪器,支持多类别跟踪。 Parameters ---------- metric : nn_matching.NearestNeighborDistanceMetric - A distance metric for measurement-to-track association. + 测量到轨迹关联的距离度量。 max_age : int - Maximum number of missed misses before a track is deleted. + 轨迹被删除前的最大丢失次数。 n_init : int - Number of consecutive detections before the track is confirmed. 
The - track state is set to `Deleted` if a miss occurs within the first - `n_init` frames. + 轨迹被确认前的连续检测次数。如果在最初的 `n_init` 帧内发生丢失, + 轨迹状态被设置为 `Deleted`。 + max_iou_distance : float + IOU关联的最大距离阈值。 + ema_alpha : float + EMA特征更新的平滑因子。 + mc_lambda : float + 多类别跟踪的类别匹配权重。 Attributes ---------- metric : nn_matching.NearestNeighborDistanceMetric - The distance metric used for measurement to track association. + 用于测量到轨迹关联的距离度量。 max_age : int - Maximum number of missed misses before a track is deleted. + 轨迹被删除前的最大丢失次数。 n_init : int - Number of frames that a track remains in initialization phase. + 轨迹保持在初始化阶段的帧数。 kf : kalman_filter.KalmanFilter - A Kalman filter to filter target trajectories in image space. + 用于在图像空间中过滤目标轨迹的卡尔曼滤波器。 tracks : List[Track] - The list of active tracks at the current time step. - + 当前时间步的活跃轨迹列表。 + ema_alpha : float + EMA特征更新的平滑因子。 + mc_lambda : float + 多类别跟踪的类别匹配权重。 """ - def __init__(self, metric, max_iou_distance=0.7, max_age=30, n_init=3): + def __init__(self, metric, max_iou_distance=0.7, max_age=30, n_init=3, + ema_alpha=0.9, mc_lambda=0.995): self.metric = metric self.max_iou_distance = max_iou_distance self.max_age = max_age self.n_init = n_init + self.ema_alpha = ema_alpha + self.mc_lambda = mc_lambda self.kf = kalman_filter.KalmanFilter() self.tracks = [] self._next_id = 1 + # 用于轨迹插值的已删除轨迹缓存 + self._deleted_tracks_cache = [] + self.max_cache_size = 100 + def predict(self): - """Propagate track state distributions one time step forward. + """将轨迹状态分布向前传播一个时间步。 - This function should be called once every time step, before `update`. + 此函数应在每个时间步调用一次,在 `update` 之前。 """ for track in self.tracks: track.predict(self.kf) def update(self, detections): - """Perform measurement update and track management. + """执行测量更新和轨迹管理。 Parameters ---------- detections : List[deep_sort.detection.Detection] - A list of detections at the current time step. - + 当前时间步的检测列表。 """ - # Run matching cascade. 
+ # 运行匹配级联。 matches, unmatched_tracks, unmatched_detections = \ self._match(detections) - # Update track set. + # 更新轨迹集合。 for track_idx, detection_idx in matches: self.tracks[track_idx].update( self.kf, detections[detection_idx]) for track_idx in unmatched_tracks: self.tracks[track_idx].mark_missed() + + # 缓存已删除的轨迹用于后续插值 + deleted_tracks = [t for t in self.tracks if t.is_deleted()] + for track in deleted_tracks: + if len(track.history) > 5: # 只缓存有足够历史的轨迹 + self._deleted_tracks_cache.append(track) + # 限制缓存大小 + if len(self._deleted_tracks_cache) > self.max_cache_size: + self._deleted_tracks_cache = self._deleted_tracks_cache[-self.max_cache_size:] + + # 移除已删除的轨迹 + self.tracks = [t for t in self.tracks if not t.is_deleted()] + for detection_idx in unmatched_detections: self._initiate_track(detections[detection_idx]) - self.tracks = [t for t in self.tracks if not t.is_deleted()] - # Update distance metric. + # 更新距离度量。 active_targets = [t.track_id for t in self.tracks if t.is_confirmed()] features, targets = [], [] for track in self.tracks: if not track.is_confirmed(): continue - features += track.features - targets += [track.track_id for _ in track.features] + # 使用EMA平滑后的特征 + smoothed_feature = track.get_smoothed_feature() + if smoothed_feature is not None: + features.append(smoothed_feature) + targets.append(track.track_id) track.features = [] - self.metric.partial_fit( - np.asarray(features), np.asarray(targets), active_targets) + if len(features) > 0: + self.metric.partial_fit( + np.asarray(features), np.asarray(targets), active_targets) def _match(self, detections): + """执行轨迹和检测之间的关联。""" def gated_metric(tracks, dets, track_indices, detection_indices): features = np.array([dets[i].feature for i in detection_indices]) targets = np.array([tracks[i].track_id for i in track_indices]) cost_matrix = self.metric.distance(features, targets) + + # 多类别:添加类别匹配成本 + for i, track_idx in enumerate(track_indices): + track = tracks[track_idx] + for j, det_idx in 
enumerate(detection_indices): + det = dets[det_idx] + if track.class_id != det.class_id: + # 类别不匹配时增加成本 + cost_matrix[i, j] += (1 - self.mc_lambda) * 10 + cost_matrix = linear_assignment.gate_cost_matrix( self.kf, cost_matrix, tracks, dets, track_indices, detection_indices) return cost_matrix - # Split track set into confirmed and unconfirmed tracks. + # 将轨迹集合分为已确认和未确认的轨迹。 confirmed_tracks = [ i for i, t in enumerate(self.tracks) if t.is_confirmed()] unconfirmed_tracks = [ i for i, t in enumerate(self.tracks) if not t.is_confirmed()] - # Associate confirmed tracks using appearance features. + # 使用外观特征关联已确认的轨迹。 matches_a, unmatched_tracks_a, unmatched_detections = \ linear_assignment.matching_cascade( gated_metric, self.metric.matching_threshold, self.max_age, self.tracks, detections, confirmed_tracks) - # Associate remaining tracks together with unconfirmed tracks using IOU. + # 使用IOU关联剩余的轨迹和未确认的轨迹。 iou_track_candidates = unconfirmed_tracks + [ k for k in unmatched_tracks_a if self.tracks[k].time_since_update == 1] @@ -131,8 +170,24 @@ def gated_metric(tracks, dets, track_indices, detection_indices): return matches, unmatched_tracks, unmatched_detections def _initiate_track(self, detection): + """从检测初始化新轨迹。""" mean, covariance = self.kf.initiate(detection.to_xyah()) self.tracks.append(Track( mean, covariance, self._next_id, self.n_init, self.max_age, - detection.feature)) + detection.feature, detection.class_id, detection.class_name, + self.ema_alpha)) self._next_id += 1 + + def get_deleted_tracks_for_interpolation(self): + """获取已删除的轨迹缓存,用于轨迹插值。 + + Returns + ------- + List[Track] + 已删除的轨迹列表。 + """ + return self._deleted_tracks_cache + + def clear_deleted_tracks_cache(self): + """清空已删除轨迹缓存。""" + self._deleted_tracks_cache = [] diff --git a/deep_sort/yolo_detector.py b/deep_sort/yolo_detector.py new file mode 100644 index 000000000..b0fde92e6 --- /dev/null +++ b/deep_sort/yolo_detector.py @@ -0,0 +1,306 @@ +import numpy as np +import cv2 +from ultralytics import 
"""YOLOv8 detection front-end and video I/O helpers for Deep SORT."""

import numpy as np
import cv2
from ultralytics import YOLO


class YOLODetector:
    """
    YOLOv8 detector wrapper integrated with Deep SORT.

    Parameters
    ----------
    model_path : str
        Path or name of the YOLOv8 model.
    conf_threshold : float
        Detection confidence threshold.
    iou_threshold : float
        IOU threshold used by NMS.
    device : str, optional
        Compute device passed through to ultralytics.
    classes : list, optional
        Class ids to detect; None detects every class.

    Attributes
    ----------
    model : YOLO
        The YOLOv8 model instance.
    class_names : dict
        Mapping from class id to class name.
    """

    # COCO dataset class names (fallback when the model carries no names).
    COCO_CLASSES = {
        0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane',
        5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light',
        10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench',
        14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow',
        20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack',
        25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee',
        30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite',
        34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard',
        37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass',
        41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl',
        46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli',
        51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake',
        56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table',
        61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote',
        66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven',
        70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock',
        75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier',
        79: 'toothbrush'
    }

    def __init__(self, model_path='yolov8n.pt', conf_threshold=0.3,
                 iou_threshold=0.45, device=None, classes=None):
        self.model = YOLO(model_path)
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold
        self.classes = classes
        self.device = device

        # Prefer the names bundled with the model; fall back to COCO.
        if hasattr(self.model, 'names'):
            self.class_names = self.model.names
        else:
            self.class_names = self.COCO_CLASSES

    def detect(self, image):
        """
        Run object detection on one image.

        Parameters
        ----------
        image : ndarray
            Input image in BGR order.

        Returns
        -------
        tuple
            (boxes, confidences, class_ids, class_names)
            - boxes: array of (x, y, w, h) boxes
            - confidences: confidence array
            - class_ids: class id array
            - class_names: list of class names
        """
        results = self.model(
            image,
            conf=self.conf_threshold,
            iou=self.iou_threshold,
            classes=self.classes,
            device=self.device,
            verbose=False
        )[0]

        boxes = []
        confidences = []
        class_ids = []
        class_names = []

        if results.boxes is not None:
            det_boxes = results.boxes.xyxy.cpu().numpy()  # xyxy format
            det_confidences = results.boxes.conf.cpu().numpy()
            det_class_ids = results.boxes.cls.cpu().numpy().astype(int)

            for box, conf, cls_id in zip(det_boxes, det_confidences, det_class_ids):
                x1, y1, x2, y2 = box
                # Convert corner format to top-left + width/height.
                boxes.append([x1, y1, x2 - x1, y2 - y1])
                confidences.append(conf)
                class_ids.append(cls_id)
                class_names.append(self.class_names.get(cls_id, f"class_{cls_id}"))

        return np.array(boxes), np.array(confidences), np.array(class_ids), class_names

    def detect_and_track(self, image, tracker, feature_extractor=None):
        """
        Detect objects and update the tracker in one call.

        Parameters
        ----------
        image : ndarray
            Input image.
        tracker : Tracker
            Deep SORT tracker instance.
        feature_extractor : FeatureExtractor, optional
            Appearance feature extractor. When omitted, detections carry
            feature=None; appearance matching then degrades — prefer
            supplying an extractor.

        Returns
        -------
        list
            One dict per confirmed, currently-updated track with keys
            'track_id', 'bbox' (tlbr), 'class_id', 'class_name', 'tlwh'.
        """
        from .detection import Detection

        boxes, confidences, class_ids, class_names = self.detect(image)

        if feature_extractor is not None and len(boxes) > 0:
            features = feature_extractor(image, boxes)
        else:
            features = [None] * len(boxes)

        detections = []
        for box, conf, cls_id, cls_name, feat in zip(
                boxes, confidences, class_ids, class_names, features):
            detections.append(Detection(
                tlwh=box,
                confidence=conf,
                feature=feat,
                class_id=cls_id,
                class_name=cls_name
            ))

        tracker.predict()
        tracker.update(detections)

        results = []
        for track in tracker.tracks:
            if track.is_confirmed() and track.time_since_update == 0:
                results.append({
                    'track_id': track.track_id,
                    'bbox': track.to_tlbr(),  # tlbr for visualization
                    'class_id': track.class_id,
                    'class_name': track.class_name,
                    'tlwh': track.to_tlwh()
                })

        return results


class VideoCapture:
    """
    Video capture helper supporting both video files and cameras.
    """

    def __init__(self, source):
        """
        Parameters
        ----------
        source : str or int
            Video file path or camera index.
        """
        self.source = source
        self.cap = None
        self.frame_count = 0
        self.fps = 0
        self.width = 0
        self.height = 0

    def open(self):
        """Open the video source and read its properties."""
        self.cap = cv2.VideoCapture(self.source)
        if not self.cap.isOpened():
            raise ValueError(f"无法打开视频源: {self.source}")

        self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        self.fps = self.cap.get(cv2.CAP_PROP_FPS)
        self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        return self

    def read(self):
        """Read one frame; returns (ok, frame)."""
        if self.cap is None:
            return False, None
        return self.cap.read()

    def release(self):
        """Release the video source."""
        if self.cap is not None:
            self.cap.release()
            self.cap = None

    def __enter__(self):
        return self.open()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.release()
        return False

    def __iter__(self):
        """Support frame iteration."""
        return self

    def __next__(self):
        """Return the next frame; stop at end of stream."""
        ret, frame = self.read()
        if not ret:
            raise StopIteration
        return frame


def draw_tracks(image, tracks, show_class=True, show_id=True):
    """
    Draw tracking results onto an image.

    Parameters
    ----------
    image : ndarray
        Input image (modified in place and returned).
    tracks : list
        Tracking result dicts as produced by `detect_and_track`.
    show_class : bool
        Whether to render the class name.
    show_id : bool
        Whether to render the track id.

    Returns
    -------
    ndarray
        The annotated image.
    """
    # FIX: use a local Generator instead of np.random.seed(42), which
    # reseeded the *global* NumPy RNG on every call and clobbered caller
    # RNG state. Colors stay deterministic per class id.
    rng = np.random.default_rng(42)
    colors = rng.integers(0, 255, size=(100, 3), dtype=np.uint8)

    for track in tracks:
        track_id = track['track_id']
        bbox = track['bbox']
        class_id = track['class_id']
        class_name = track['class_name']

        x1, y1, x2, y2 = bbox.astype(int)

        # Color keyed by class id.
        color = tuple(int(c) for c in colors[class_id % len(colors)])

        cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)

        label_parts = []
        if show_id:
            label_parts.append(f"ID:{track_id}")
        if show_class:
            label_parts.append(class_name)

        if label_parts:
            label = " ".join(label_parts)
            label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
            # FIX: anchor the label at label_y (previously computed but
            # unused), so labels near the top edge are not clipped.
            label_y = max(y1, label_size[1] + 10)

            cv2.rectangle(image, (x1, label_y - label_size[1] - 10),
                          (x1 + label_size[0], label_y), color, -1)
            cv2.putText(image, label, (x1, label_y - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

    return image
#!/usr/bin/env python3
"""
Deep SORT v2 - end-to-end multi-object tracking application.

Features:
1. PyTorch deep-learning backbone
2. Multi-class tracking
3. YOLOv8 integration (video file / camera input)
4. EMA appearance-feature updates
5. Trajectory interpolation and smoothing
"""

import argparse
import os
import time
from collections import defaultdict

import cv2
import numpy as np
import torch

from deep_sort import nn_matching
from deep_sort.tracker import Tracker
from deep_sort.yolo_detector import YOLODetector, VideoCapture, draw_tracks
from deep_sort.deep_feature_extractor import FeatureExtractor
from deep_sort.track_interpolation import TrackPostProcessor


def run_tracking(
    source,
    output_path=None,
    model='yolov8n.pt',
    feature_model=None,
    conf_threshold=0.3,
    iou_threshold=0.45,
    max_cosine_distance=0.2,
    nn_budget=100,
    max_age=30,
    n_init=3,
    ema_alpha=0.9,
    mc_lambda=0.995,
    display=True,
    save_video=False,
    classes=None,
    device=None,
    enable_interpolation=True,
    enable_smoothing=True,
    max_interpolation_gap=30,
    max_iou_distance=0.7
):
    """
    Run end-to-end detection + tracking on a video source.

    Parameters
    ----------
    source : str or int
        Video file path or camera index.
    output_path : str, optional
        Output video path.
    model : str
        YOLOv8 model path.
    feature_model : str, optional
        Feature extractor model path.
    conf_threshold : float
        Detection confidence threshold.
    iou_threshold : float
        NMS IOU threshold (detector only).
    max_cosine_distance : float
        Maximum cosine distance for appearance matching.
    nn_budget : int
        Maximum size of the appearance descriptor gallery.
    max_age : int
        Maximum track age in frames.
    n_init : int
        Detections required to confirm a track.
    ema_alpha : float
        EMA smoothing factor.
    mc_lambda : float
        Multi-class matching weight.
    display : bool
        Whether to show results in a window.
    save_video : bool
        Whether to write an output video.
    classes : list, optional
        Class ids to track.
    device : str, optional
        Compute device ('cuda'/'cpu'); auto-detected via torch when None.
    enable_interpolation : bool
        Enable offline track interpolation.
    enable_smoothing : bool
        Enable offline track smoothing.
    max_interpolation_gap : int
        Maximum gap (frames) bridged by interpolation.
    max_iou_distance : float
        IOU association gate for the tracker. FIX: previously the detector's
        NMS `iou_threshold` (0.45) was reused here, conflating two unrelated
        thresholds and over-tightening IOU association.

    Returns
    -------
    dict
        Post-processed track histories keyed by track id.
    """
    # Detector.
    print(f"Loading YOLO model: {model}")
    detector = YOLODetector(
        model_path=model,
        conf_threshold=conf_threshold,
        iou_threshold=iou_threshold,
        device=device,
        classes=classes
    )

    # Feature extractor.
    # FIX: auto-detect the device with torch — the extractor is a PyTorch
    # model, and cv2.cuda.getCudaEnabledDeviceCount() only reflects OpenCV's
    # CUDA build (usually 0 for pip wheels even on CUDA machines).
    print("Initializing feature extractor...")
    fe_device = device if device else (
        'cuda' if torch.cuda.is_available() else 'cpu')
    feature_extractor = FeatureExtractor(
        model_path=feature_model,
        device=fe_device
    )

    # Tracker.
    metric = nn_matching.NearestNeighborDistanceMetric(
        "cosine", max_cosine_distance, nn_budget)
    tracker = Tracker(
        metric,
        max_iou_distance=max_iou_distance,
        max_age=max_age,
        n_init=n_init,
        ema_alpha=ema_alpha,
        mc_lambda=mc_lambda
    )

    # Offline trajectory post-processor.
    post_processor = TrackPostProcessor(
        enable_interpolation=enable_interpolation,
        enable_smoothing=enable_smoothing,
        max_gap=max_interpolation_gap
    )

    # Open the video source.
    print(f"Opening video source: {source}")
    with VideoCapture(source) as cap:
        print(f"Video info: {cap.width}x{cap.height} @ {cap.fps:.2f}fps")

        # Optional video writer.
        video_writer = None
        if save_video and output_path:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            video_writer = cv2.VideoWriter(
                output_path, fourcc, cap.fps, (cap.width, cap.height))

        # Per-track history for offline post-processing.
        track_history = defaultdict(list)
        frame_idx = 0

        # Processing statistics.
        start_time = time.time()
        processing_times = []

        print("\nStarting tracking...")
        print("Press 'q' to quit, 'p' to pause\n")

        for frame in cap:
            frame_idx += 1
            loop_start = time.time()

            # Detect and track.
            results = detector.detect_and_track(
                frame, tracker, feature_extractor)

            # Record track history.
            for result in results:
                track_id = result['track_id']
                track_history[track_id].append({
                    'frame_idx': frame_idx,
                    'bbox': result['tlwh'],
                    'class_id': result['class_id'],
                    'class_name': result['class_name']
                })

            # Draw results.
            display_frame = frame.copy()
            display_frame = draw_tracks(display_frame, results)

            # Overlay statistics (FPS averaged over the last 30 frames).
            elapsed = time.time() - loop_start
            processing_times.append(elapsed)
            avg_time = np.mean(processing_times[-30:])
            fps = 1.0 / avg_time if avg_time > 0 else 0

            info_text = [
                f"Frame: {frame_idx}",
                f"Tracks: {len(results)}",
                f"FPS: {fps:.1f}",
                f"Time: {elapsed*1000:.1f}ms"
            ]
            for i, text in enumerate(info_text):
                cv2.putText(display_frame, text, (10, 30 + i * 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

            # Save video.
            if video_writer is not None:
                video_writer.write(display_frame)

            # Display.
            if display:
                cv2.imshow('Deep SORT v2', display_frame)
                key = cv2.waitKey(1) & 0xFF
                if key == ord('q'):
                    break
                elif key == ord('p'):
                    print("Paused. Press any key to continue...")
                    cv2.waitKey(0)

            # Progress.
            if frame_idx % 30 == 0:
                print(f"Frame {frame_idx}: {len(results)} tracks, {fps:.1f} FPS")

        # Release resources; only tear down windows if any were created.
        if video_writer is not None:
            video_writer.release()
        if display:
            cv2.destroyAllWindows()

        # Offline post-processing.
        print("\nPost-processing tracks...")
        processed_tracks = post_processor.process_frame_results(track_history)

        # Summary statistics.
        total_time = time.time() - start_time
        print("\nTracking completed!")
        print(f"Total frames: {frame_idx}")
        print(f"Total time: {total_time:.2f}s")
        print(f"Average FPS: {frame_idx / total_time:.1f}")
        print(f"Unique tracks: {len(track_history)}")

        return processed_tracks


def run_on_video(
    video_path,
    output_path=None,
    **kwargs
):
    """Run tracking on a video file."""
    return run_tracking(
        source=video_path,
        output_path=output_path,
        **kwargs
    )


def run_on_camera(
    camera_id=0,
    output_path=None,
    **kwargs
):
    """Run tracking on a camera."""
    return run_tracking(
        source=camera_id,
        output_path=output_path,
        **kwargs
    )


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description='Deep SORT v2 - End-to-End Tracking')

    # Input / output.
    parser.add_argument('--source', type=str, default='0',
                        help='视频文件路径或摄像头索引 (默认: 0)')
    parser.add_argument('--output', type=str, default=None,
                        help='输出视频路径')
    parser.add_argument('--save-video', action='store_true',
                        help='保存输出视频')

    # Models.
    parser.add_argument('--model', type=str, default='yolov8n.pt',
                        help='YOLOv8模型路径 (默认: yolov8n.pt)')
    parser.add_argument('--feature-model', type=str, default=None,
                        help='特征提取器模型路径')

    # Detection.
    parser.add_argument('--conf-threshold', type=float, default=0.3,
                        help='检测置信度阈值 (默认: 0.3)')
    parser.add_argument('--iou-threshold', type=float, default=0.45,
                        help='NMS IOU阈值 (默认: 0.45)')
    parser.add_argument('--classes', type=int, nargs='+', default=None,
                        help='要跟踪的类别ID列表 (默认: 所有类别)')

    # Tracking.
    parser.add_argument('--max-cosine-distance', type=float, default=0.2,
                        help='最大余弦距离 (默认: 0.2)')
    parser.add_argument('--nn-budget', type=int, default=100,
                        help='外观描述符库大小 (默认: 100)')
    parser.add_argument('--max-age', type=int, default=30,
                        help='轨迹最大存活帧数 (默认: 30)')
    parser.add_argument('--n-init', type=int, default=3,
                        help='轨迹初始化所需检测次数 (默认: 3)')
    parser.add_argument('--ema-alpha', type=float, default=0.9,
                        help='EMA平滑因子 (默认: 0.9)')
    parser.add_argument('--mc-lambda', type=float, default=0.995,
                        help='多类别匹配权重 (默认: 0.995)')
    parser.add_argument('--max-iou-distance', type=float, default=0.7,
                        help='跟踪器IOU关联最大距离 (默认: 0.7)')

    # Post-processing.
    parser.add_argument('--no-interpolation', action='store_true',
                        help='禁用轨迹插值')
    parser.add_argument('--no-smoothing', action='store_true',
                        help='禁用轨迹平滑')
    parser.add_argument('--max-interpolation-gap', type=int, default=30,
                        help='最大插值间隔 (默认: 30)')

    # Misc.
    parser.add_argument('--no-display', action='store_true',
                        help='不显示结果')
    parser.add_argument('--device', type=str, default=None,
                        help='计算设备 (cuda/cpu)')

    return parser.parse_args()


def main():
    """CLI entry point."""
    args = parse_args()

    # Resolve the input source: an integer selects a camera, otherwise it
    # must be an existing video file.
    try:
        source = int(args.source)
        print(f"Using camera: {source}")
    except ValueError:
        source = args.source
        if not os.path.exists(source):
            raise ValueError(f"Video file not found: {source}")
        print(f"Using video file: {source}")

    run_tracking(
        source=source,
        output_path=args.output,
        model=args.model,
        feature_model=args.feature_model,
        conf_threshold=args.conf_threshold,
        iou_threshold=args.iou_threshold,
        max_cosine_distance=args.max_cosine_distance,
        nn_budget=args.nn_budget,
        max_age=args.max_age,
        n_init=args.n_init,
        ema_alpha=args.ema_alpha,
        mc_lambda=args.mc_lambda,
        display=not args.no_display,
        save_video=args.save_video,
        classes=args.classes,
        device=args.device,
        enable_interpolation=not args.no_interpolation,
        enable_smoothing=not args.no_smoothing,
        max_interpolation_gap=args.max_interpolation_gap,
        max_iou_distance=args.max_iou_distance
    )


if __name__ == '__main__':
    main()
#!/usr/bin/env python3
"""
Usage examples for the PyTorch Deep SORT v2 multi-object tracker.
"""

import cv2
from deep_sort import (
    Tracker,
    NearestNeighborDistanceMetric,
    YOLODetector,
    FeatureExtractor,
    TrackPostProcessor
)


def example_basic_tracking():
    """Basic single-class tracking on a video file."""
    print("=" * 50)
    print("示例1: 基本跟踪")
    print("=" * 50)

    # Detector restricted to the 'person' class.
    detector = YOLODetector(
        model_path='yolov8n.pt',
        conf_threshold=0.3,
        classes=[0]
    )

    # Appearance feature extractor on CPU.
    feature_extractor = FeatureExtractor(device='cpu')

    # Tracker with EMA feature smoothing and class-aware matching.
    metric = NearestNeighborDistanceMetric("cosine", 0.2, 100)
    tracker = Tracker(
        metric,
        max_iou_distance=0.7,
        max_age=30,
        n_init=3,
        ema_alpha=0.9,   # EMA smoothing factor
        mc_lambda=0.995  # multi-class matching weight
    )

    video_path = 'path/to/your/video.mp4'  # replace with your video
    cap = cv2.VideoCapture(video_path)

    # Process at most the first 100 frames.
    frame_idx = 0
    while frame_idx < 100:
        ok, frame = cap.read()
        if not ok:
            break
        frame_idx += 1

        results = detector.detect_and_track(frame, tracker, feature_extractor)

        print(f"Frame {frame_idx}: {len(results)} tracks")
        for r in results:
            print(f"  ID {r['track_id']}: {r['class_name']} at {r['bbox']}")

        # Visualization (optional)
        # ...

    cap.release()
    print("示例1完成\n")


def example_multi_class_tracking():
    """Tracking several object classes at once."""
    print("=" * 50)
    print("示例2: 多类别跟踪")
    print("=" * 50)

    wanted = [0, 2, 3, 5, 7]  # person, car, motorcycle, bus, truck

    detector = YOLODetector(
        model_path='yolov8n.pt',
        conf_threshold=0.3,
        classes=wanted
    )

    feature_extractor = FeatureExtractor(device='cpu')

    metric = NearestNeighborDistanceMetric("cosine", 0.2, 100)
    tracker = Tracker(
        metric,
        max_iou_distance=0.7,
        max_age=30,
        n_init=3,
        ema_alpha=0.9,
        mc_lambda=0.995  # enables class-aware matching
    )

    print("支持跟踪的类别:")
    for class_id, class_name in detector.class_names.items():
        if class_id in wanted:
            print(f"  {class_id}: {class_name}")

    print("\n示例2完成\n")


def example_with_interpolation():
    """Tracking followed by offline track interpolation."""
    print("=" * 50)
    print("示例3: 带轨迹插值的跟踪")
    print("=" * 50)

    post_processor = TrackPostProcessor(
        enable_interpolation=True,
        enable_smoothing=True,
        max_gap=30,       # maximum interpolation gap
        smooth_window=5   # smoothing window size
    )

    # Build synthetic track data (normally obtained from the tracker).
    from collections import defaultdict
    track_history = defaultdict(list)

    identity8 = [[1 if row == col else 0 for col in range(8)]
                 for row in range(8)]
    for track_id in range(3):
        for frame_idx in range(100):
            # Drop frames 5, 15, 25, ... to simulate occlusion gaps.
            if frame_idx % 10 != 5:
                track_history[track_id].append({
                    'frame_idx': frame_idx,
                    'bbox': [100 + track_id * 50 + frame_idx, 200, 50, 100],
                    'mean': [125 + track_id * 50 + frame_idx, 250, 0.5, 100,
                             0, 0, 0, 0],
                    'covariance': identity8,
                    'class_id': 0,
                    'class_name': 'person'
                })

    print(f"原始轨迹数: {len(track_history)}")
    for track_id, history in track_history.items():
        print(f"  Track {track_id}: {len(history)} frames")

    processed_tracks = post_processor.process_frame_results(track_history)

    print(f"\n处理后轨迹数: {len(processed_tracks)}")
    for track_id, history in processed_tracks.items():
        interpolated_count = sum(1 for h in history if h.get('interpolated', False))
        print(f"  Track {track_id}: {len(history)} frames ({interpolated_count} interpolated)")

    print("\n示例3完成\n")


def example_camera_tracking():
    """Print CLI instructions for live camera tracking."""
    print("=" * 50)
    print("示例4: 摄像头实时跟踪")
    print("=" * 50)

    print("""
    要使用摄像头进行实时跟踪,请运行:

    python deep_sort_app_v2.py --source 0 --model yolov8n.pt

    参数说明:
    --source 0          : 使用默认摄像头
    --model yolov8n.pt  : YOLOv8模型
    --classes 0         : 只跟踪人
    --conf-threshold 0.3: 置信度阈值
    --max-age 30        : 轨迹最大存活帧数
    --ema-alpha 0.9     : EMA平滑因子
    --save-video        : 保存结果视频
    --output result.mp4 : 输出视频路径
    """)

    print("示例4完成\n")


def example_video_tracking():
    """Print CLI instructions for video-file tracking."""
    print("=" * 50)
    print("示例5: 视频文件跟踪")
    print("=" * 50)

    print("""
    要对视频文件进行跟踪,请运行:

    python deep_sort_app_v2.py \\
        --source path/to/video.mp4 \\
        --model yolov8n.pt \\
        --classes 0 2 3 5 7 \\
        --conf-threshold 0.3 \\
        --iou-threshold 0.45 \\
        --max-cosine-distance 0.2 \\
        --nn-budget 100 \\
        --max-age 30 \\
        --n-init 3 \\
        --ema-alpha 0.9 \\
        --mc-lambda 0.995 \\
        --save-video \\
        --output result.mp4

    高级选项:
    --no-interpolation         : 禁用轨迹插值
    --no-smoothing             : 禁用轨迹平滑
    --max-interpolation-gap 30 : 最大插值间隔
    """)

    print("示例5完成\n")


def main():
    """Run all examples."""
    print("\n" + "=" * 50)
    print("Deep SORT v2 使用示例")
    print("=" * 50 + "\n")

    example_multi_class_tracking()
    example_with_interpolation()
    example_camera_tracking()
    example_video_tracking()

    print("=" * 50)
    print("所有示例完成!")
    print("=" * 50)


if __name__ == '__main__':
    main()