Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion second/core/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def _points_count_convex_polygon_3d_jit(points,
return ret


@numba.jit
@numba.jit(forceobj=True)
def points_in_convex_polygon_jit(points, polygon, clockwise=True):
"""check points is in 2d convex polygons. True when point in polygon
Args:
Expand Down
24 changes: 11 additions & 13 deletions second/data/nuscenes_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self,
"pedestrian": "pedestrian",
} # we only eval these classes in kitti
self.version = self._metadata["version"]
self.eval_version = "cvpr_2019"
self.eval_version = "detection_cvpr_2019"
self._with_velocity = False

def __len__(self):
Expand All @@ -82,9 +82,8 @@ def __len__(self):
def ground_truth_annotations(self):
if "gt_boxes" not in self._nusc_infos[0]:
return None
from nuscenes.eval.detection.config import eval_detection_configs
cls_range_map = eval_detection_configs[self.
eval_version]["class_range"]
from nuscenes.eval.detection.config import config_factory
cls_range_map = config_factory(self.eval_version).class_range
gt_annos = []
for info in self._nusc_infos:
gt_names = info["gt_names"]
Expand Down Expand Up @@ -329,7 +328,7 @@ def evaluation_nusc(self, detections, output_dir):
box.velocity = np.array([*velocity, 0.0])
boxes = _lidar_nusc_box_to_global(
token2info[det["metadata"]["token"]], boxes,
mapped_class_names, "cvpr_2019")
mapped_class_names, "detection_cvpr_2019")
for i, box in enumerate(boxes):
name = mapped_class_names[box.label]
velocity = box.velocity[:2].tolist()
Expand Down Expand Up @@ -545,16 +544,16 @@ def _second_det_to_nusc_box(detection):
return box_list


def _lidar_nusc_box_to_global(info, boxes, classes, eval_version="cvpr_2019"):
def _lidar_nusc_box_to_global(info, boxes, classes, eval_version="detection_cvpr_2019"):
import pyquaternion
box_list = []
for box in boxes:
# Move box to ego vehicle coord system
box.rotate(pyquaternion.Quaternion(info['lidar2ego_rotation']))
box.translate(np.array(info['lidar2ego_translation']))
from nuscenes.eval.detection.config import eval_detection_configs
from nuscenes.eval.detection.config import config_factory
# filter det in ego.
cls_range_map = eval_detection_configs[eval_version]["class_range"]
cls_range_map = config_factory(eval_version).class_range
radius = np.linalg.norm(box.center[:2], 2)
det_range = cls_range_map[classes[box.label]]
if radius > det_range:
Expand Down Expand Up @@ -784,9 +783,8 @@ def get_box_mean(info_path, class_name="vehicle.car",
eval_version="cvpr_2019"):
with open(info_path, 'rb') as f:
nusc_infos = pickle.load(f)["infos"]
from nuscenes.eval.detection.config import eval_detection_configs
cls_range_map = eval_detection_configs[eval_version]["class_range"]

from nuscenes.eval.detection.config import config_factory
cls_range_map = config_factory(self.eval_version).class_range
gt_boxes_list = []
gt_vels_list = []
for info in nusc_infos:
Expand Down Expand Up @@ -867,8 +865,8 @@ def render_nusc_result(nusc, results, sample_token):
def cluster_trailer_box(info_path, class_name="bus"):
with open(info_path, 'rb') as f:
nusc_infos = pickle.load(f)["infos"]
from nuscenes.eval.detection.config import eval_detection_configs
cls_range_map = eval_detection_configs["cvpr_2019"]["class_range"]
from nuscenes.eval.detection.config import config_factory
cls_range_map = config_factory(self.eval_version).class_range
gt_boxes_list = []
for info in nusc_infos:
gt_boxes = info["gt_boxes"]
Expand Down
11 changes: 10 additions & 1 deletion second/pytorch/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def example_convert_to_torch(example, dtype=torch.float32,
v, dtype=torch.int32, device=device)
elif k in ["anchors_mask"]:
example_torch[k] = torch.tensor(
v, dtype=torch.uint8, device=device)
v, dtype=torch.bool, device=device)
elif k == "calib":
calib = {}
for k1, v1 in v.items():
Expand Down Expand Up @@ -409,6 +409,10 @@ def train(config_path,
model_logging.log_text("Evaluation {}".format(k), global_step)
model_logging.log_text(v, global_step)
model_logging.log_metrics(result_dict["detail"], global_step)
with open(result_path_step / "result_dict.json" , "w") as f:
json.dump(result_dict, f)
with open(result_path_step / "result_dict.pkl", 'wb') as f:
pickle.dump(result_dict, f)
with open(result_path_step / "result.pkl", 'wb') as f:
pickle.dump(detections, f)
net.train()
Expand All @@ -418,6 +422,7 @@ def train(config_path,
if step >= total_step:
break
except Exception as e:
print(f"Exception in step {step}, {str(e)}")
print(json.dumps(example["metadata"], indent=2))
model_logging.log_text(str(e), step)
model_logging.log_text(json.dumps(example["metadata"], indent=2), step)
Expand Down Expand Up @@ -540,6 +545,10 @@ def evaluate(config_path,
result_dict = eval_dataset.dataset.evaluation(detections,
str(result_path_step))
if result_dict is not None:
with open(result_path_step / "result_dict.json" , "w") as f:
json.dump(result_dict, f)
with open(result_path_step / "result_dict.pkl", 'wb') as f:
pickle.dump(result_dict, f)
for k, v in result_dict["results"].items():
print("Evaluation {}".format(k))
print(v)
Expand Down
6 changes: 4 additions & 2 deletions second/utils/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def bev_box_overlap(boxes, qboxes, criterion=-1, stable=True):
return riou


@numba.jit(nopython=True, parallel=True)
@numba.jit(nopython=True) # removed parallel=True as it produces warning
def box3d_overlap_kernel(boxes,
qboxes,
rinc,
Expand Down Expand Up @@ -701,7 +701,9 @@ def do_coco_style_eval(gt_annos,
min_overlaps = np.zeros([10, *overlap_ranges.shape[1:]])
for i in range(overlap_ranges.shape[1]):
for j in range(overlap_ranges.shape[2]):
min_overlaps[:, i, j] = np.linspace(*overlap_ranges[:, i, j])
a, b, c = overlap_ranges[:, i, j] #extracting the three numbers
min_overlaps[:, i, j] = np.linspace(a, b, int(c)) #casting to integer
# min_overlaps[:, i, j] = np.linspace(*overlap_ranges[:, i, j])
mAP_bbox, mAP_bev, mAP_3d, mAP_aos = do_eval_v2(
gt_annos,
dt_annos,
Expand Down
2 changes: 1 addition & 1 deletion torchplus/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@ def torch_to_np_dtype(ttype):
torch.float16: np.dtype(np.float64),
torch.int32: np.dtype(np.int32),
torch.int64: np.dtype(np.int64),
torch.uint8: np.dtype(np.uint8),
torch.bool: np.dtype(np.uint8),
}
return type_map[ttype]