Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion params/demo/fastsam.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ plane_filter_params: [3.0, 1.0, 0.2]
semantics: 'dino'
yolo_imgsz: [256, 256]
depth_scale: 1000.0
max_depth: 7.5
max_depth: 7.5
frame_descriptor: 'dino-gem'
6 changes: 5 additions & 1 deletion params/demo/submap_align.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,8 @@ force_rm_upside_down: True
use_object_bottom_middle: True
cosine_min: 0.5
cosine_max: 0.7
semantics_dim: 768
semantics_dim: 768

submap_descriptor: 'stacked_frame_descriptors'
frame_descriptor_dist: 10.0
submap_descriptor_thresh: 0.8
3 changes: 2 additions & 1 deletion params/demo_no_gpu/fastsam.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ plane_filter_params: [3.0, 1.0, 0.2]
semantics: 'none'
yolo_imgsz: [256, 256]
depth_scale: 1000.0
max_depth: 7.5
max_depth: 7.5
frame_descriptor: 'none'
25 changes: 20 additions & 5 deletions roman/align/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class SubmapAlignResults:
clipper_angle_mat: np.array
clipper_dist_mat: np.array
clipper_num_associations: np.array
similarity_mat: np.array
submap_yaw_diff_mat: np.array
associated_objs_mat: np.array
T_ij_mat: np.array
Expand Down Expand Up @@ -52,17 +53,25 @@ def time_to_secs_nsecs(t, as_dict=False):
return {'seconds': seconds, 'nanoseconds': nanoseconds}

def plot_align_results(results: SubmapAlignResults, dpi=500):


show_sim = results.similarity_mat is not None

# if no ground truth, can only show number of associations
if None in results.submap_io.input_gt_pose_yaml:
fig, ax = plt.subplots(1, 1, figsize=(4,4), dpi=dpi)
mp = ax.imshow(results.clipper_num_associations, cmap='viridis', vmin=0)
fig, ax = plt.subplots(2 if show_sim else 1, 1, figsize=(8 if show_sim else 4, 4), dpi=dpi)
ax = np.array(ax).reshape(-1, 1)
mp = ax[0, 0].imshow(results.clipper_num_associations, cmap='viridis', vmin=0)
fig.colorbar(mp, fraction=0.04, pad=0.04)
ax.set_title("Number of Associations")

if show_sim:
mp = ax[1, 0].imshow(results.similarity_mat, cmap='viridis', vmin=0.0, vmax=1.0)
fig.colorbar(mp, fraction=0.04, pad=0.04)
ax[1, 0].set_title("Similarity Score")

fig.suptitle(f"{results.submap_io.run_name}: {results.submap_io.robot_names[0]}, {results.submap_io.robot_names[1]}")
return


fig, ax = plt.subplots(3, 2, figsize=(8, 12), dpi=dpi)
fig.subplots_adjust(wspace=.3)
fig.suptitle(f"{results.submap_io.run_name}: {results.submap_io.robot_names[0]}, {results.submap_io.robot_names[1]}")
Expand Down Expand Up @@ -95,14 +104,20 @@ def plot_align_results(results: SubmapAlignResults, dpi=500):
mp = ax[2, 0].imshow(results.clipper_num_associations, cmap='viridis', vmin=0)
fig.colorbar(mp, fraction=0.04, pad=0.04)
ax[2, 0].set_title("Number of Associations")

if show_sim:
mp = ax[2, 1].imshow(results.similarity_mat, cmap='viridis', vmin=0.0, vmax=1.0)
fig.colorbar(mp, fraction=0.04, pad=0.04)
ax[2, 1].set_title("Similarity Score")

for i in range(len(ax)):
for j in range(len(ax[i])):
ax[i,j].set_xlabel("submap index (robot 2)")
ax[i,j].set_ylabel("submap index (robot 1)")
ax[i,j].grid(False)

fig.delaxes(ax[2, 1])
if not show_sim:
fig.delaxes(ax[2, 1])

def save_submap_align_results(results: SubmapAlignResults, submaps, roman_maps: List[ROMANMap]):
plot_align_results(results)
Expand Down
6 changes: 4 additions & 2 deletions roman/align/submap_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def submap_align(sm_params: SubmapAlignParams, sm_io: SubmapAlignInputOutput):
clipper_angle_mat = np.zeros((len(submaps[0]), len(submaps[1])))*np.nan
clipper_dist_mat = np.zeros((len(submaps[0]), len(submaps[1])))*np.nan
clipper_num_associations = np.zeros((len(submaps[0]), len(submaps[1])))*np.nan
similarity_mat = np.zeros((len(submaps[0]), len(submaps[1])))*np.nan
robots_nearby_mat = np.zeros((len(submaps[0]), len(submaps[1])))*np.nan
clipper_percent_associations = np.zeros((len(submaps[0]), len(submaps[1])))*np.nan
submap_yaw_diff_mat = np.zeros((len(submaps[0]), len(submaps[1])))*np.nan
Expand Down Expand Up @@ -128,8 +129,7 @@ def submap_align(sm_params: SubmapAlignParams, sm_io: SubmapAlignInputOutput):
submap_yaw_diff_mat[i, j] = np.abs(np.rad2deg(relative_yaw_angle))

if sm_params.submap_descriptor is not None:
submap_sim = np.sum(submap_i.descriptor * submap_j.descriptor) / \
(np.linalg.norm(submap_i.descriptor) * np.linalg.norm(submap_j.descriptor))
submap_sim = Submap.similarity(submap_i, submap_j)
else:
submap_sim = np.inf # always try to register object maps if no descriptor is used

Expand Down Expand Up @@ -184,6 +184,7 @@ def submap_align(sm_params: SubmapAlignParams, sm_io: SubmapAlignInputOutput):
clipper_dist_mat[i, j] = np.nan

clipper_num_associations[i, j] = len(associations)
similarity_mat[i, j] = submap_sim
clipper_percent_associations[i, j] = len(associations) / np.mean([len(submap_i), len(submap_j)]) if np.mean([len(submap_i), len(submap_j)]) > 0 else 0.0

T_ij_mat[i, j] = T_ij
Expand All @@ -198,6 +199,7 @@ def submap_align(sm_params: SubmapAlignParams, sm_io: SubmapAlignInputOutput):
clipper_angle_mat=clipper_angle_mat,
clipper_dist_mat=clipper_dist_mat,
clipper_num_associations=clipper_num_associations,
similarity_mat=similarity_mat if sm_params.submap_descriptor is not None else None,
submap_yaw_diff_mat=submap_yaw_diff_mat,
T_ij_mat=T_ij_mat,
T_ij_hat_mat=T_ij_hat_mat,
Expand Down
59 changes: 50 additions & 9 deletions roman/map/fastsam_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def from_params(cls, params: FastSAMParams, depth_cam_params: CameraParams):
allow_tblr_edges=[True, True, True, True],
area_bounds=[img_area / (params.min_mask_len_div**2), img_area / (params.max_mask_len_div**2)],
semantics=params.semantics,
frame_descriptor=params.frame_descriptor,
triangle_ignore_masks=params.triangle_ignore_masks
)

Expand All @@ -132,6 +133,7 @@ def setup_filtering(self,
allow_tblr_edges = [True, True, True, True],
keep_mask_minimal_intersection=0.3,
semantics: str = None,
frame_descriptor: str = None,
triangle_ignore_masks=None
):
"""
Expand Down Expand Up @@ -176,6 +178,9 @@ def setup_filtering(self,
else:
raise ValueError(f"Invalid semantics option: {semantics}. Choose from 'clip', 'dino', or 'none'.")
self.semantic_patches_shape = None
self.frame_descriptor_type = frame_descriptor
if frame_descriptor is not None:
assert self.semantics.lower() == 'dino', "Frame descriptor only supported with DINO semantics."

if triangle_ignore_masks is not None:
self.constant_ignore_mask = np.zeros((self.depth_cam_params.height, self.depth_cam_params.width), dtype=np.uint8)
Expand Down Expand Up @@ -235,16 +240,16 @@ def setup_rgbd_params(
self.erosion_element = None
self.plane_filter_params = plane_filter_params

def run(self, t, pose, img, depth_data=None, plot=False):
def run(self, t, pose, img, depth_data=None):
"""
Takes and image and returns filtered FastSAM masks as Observations.

Args:
img (cv image): camera image
plot (bool, optional): Returns plots if true. Defaults to False.

Returns:
self.observations (list): list of Observations
frame_descriptor (np.ndarray): semantic descriptor of the frame if frame_descriptor is not None, else None
"""
self.observations = []

Expand Down Expand Up @@ -274,14 +279,20 @@ def run(self, t, pose, img, depth_data=None, plot=False):
img_rgb = cv.cvtColor(img, cv.COLOR_BGR2RGB)
preprocessed = self.semantics_preprocess(images=img_rgb, return_tensors="pt").to(self.device)
dino_output = self.semantics_model(**preprocessed)
dino_features = self.get_per_pixel_features(
dino_output_patches = self.get_output_patches(
model_output=dino_output.last_hidden_state,
img_shape=img.shape,
feature_dim=dino_shape
)
dino_features = self.get_per_pixel_features(
model_output_patches=dino_output_patches,
img_shape=img.shape
)
dino_features = self.unapply_rotation(dino_features)



frame_descriptor = None
if self.frame_descriptor_type is not None:
frame_descriptor = self.get_frame_descriptor(dino_output_patches)

for mask in masks:

Expand Down Expand Up @@ -393,7 +404,7 @@ def run(self, t, pose, img, depth_data=None, plot=False):
else:
self.observations.append(Observation(t, pose, mask, mask_downsampled, ptcld))

return self.observations
return self.observations, frame_descriptor
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will break things downstream in roman_ros, so before we merge this PR, we should make updates there too.


def apply_rotation(self, img, unrotate=False):
if self.rotate_img is None:
Expand Down Expand Up @@ -577,10 +588,9 @@ def mask_bounding_box(self, mask):

return (min_col, min_row, max_col, max_row,)

def get_per_pixel_features(self, model_output: ArrayLike, img_shape: ArrayLike,
feature_dim: int) -> ArrayLike:
def get_output_patches(self, model_output: ArrayLike, img_shape: ArrayLike, feature_dim: int) -> ArrayLike:
"""
Extract (Dino) per-pixel features
Extract (Dino) output patches

Args:
model_output (ArrayLike): Last hidden state of (Dino) model
Expand All @@ -601,6 +611,19 @@ def get_per_pixel_features(self, model_output: ArrayLike, img_shape: ArrayLike,

model_output_patches = model_output_flat_patches.reshape(self.semantic_patches_shape)

return model_output_patches # 1 x h x w x feature_dim

def get_per_pixel_features(self, model_output_patches: ArrayLike, img_shape: ArrayLike) -> ArrayLike:
"""
Extract (Dino) per-pixel features

Args:
model_output_patches (ArrayLike): Reshaped (Dino) output patches
img_shape (ArrayLike): Original image shape

Returns:
ArrayLike: Reshaped (Dino) output
"""
# interpolate the feature map to match the size of the original image
per_pixel_features = torch.nn.functional.interpolate(
model_output_patches.permute(0, 3, 1, 2), # permute to be batch, channels, height, width
Expand All @@ -613,3 +636,21 @@ def get_per_pixel_features(self, model_output: ArrayLike, img_shape: ArrayLike,

return per_pixel_features # h x w x feature_dim

def get_frame_descriptor(self, dino_features: torch.Tensor) -> np.ndarray:
    """Pool per-patch DINO features into a single global frame descriptor.

    The pooling scheme is selected by ``self.frame_descriptor_type``:
    'dino-gap' (global average pooling — implemented as a sum, which is
    equivalent up to the final L2 normalization), 'dino-gmp' (global max
    pooling), or 'dino-gem' (generalized-mean pooling with p=3, using a
    signed cube root so negative feature values stay well-defined).

    Args:
        dino_features (torch.Tensor): DINO output patches with feature_dim
            as the last dimension; all leading dimensions are flattened
            before pooling.

    Returns:
        np.ndarray: L2-normalized frame descriptor of length feature_dim
            (all zeros if the pooled descriptor has zero norm).

    Raises:
        ValueError: if ``self.frame_descriptor_type`` is not one of
            'dino-gap', 'dino-gmp', or 'dino-gem'.
    """
    with torch.no_grad():  # prevent autograd graph buildup (memory leak)
        # collapse all patch dimensions into one: (num_patches, feature_dim)
        dino_features_flat = dino_features.view(-1, dino_features.shape[-1])
        if self.frame_descriptor_type == 'dino-gap':
            frame_descriptor = torch.sum(dino_features_flat, dim=0)
        elif self.frame_descriptor_type == 'dino-gmp':
            frame_descriptor = torch.max(dino_features_flat, dim=0).values
        elif self.frame_descriptor_type == 'dino-gem':
            cubed_descriptor = torch.mean(dino_features_flat ** 3, dim=0)
            frame_descriptor = torch.sign(cubed_descriptor) * \
                (torch.abs(cubed_descriptor).clamp(min=1e-12) ** (1.0 / 3)) # avoid NaN from negative or zero root
        else:
            raise ValueError(f"frame descriptor must be one of 'dino-gap', 'dino-gmp', or 'dino-gem'.")

        # clamp the norm so an all-zero descriptor yields zeros instead of
        # NaNs (a plain /= torch.norm(...) divides by zero in that case)
        frame_descriptor = frame_descriptor / torch.norm(frame_descriptor).clamp(min=1e-12)

    return frame_descriptor.detach().cpu().numpy()

Loading