Skip to content

有关trans loss 的问题 #46

@artofstate

Description

@artofstate

target['track_query_hs_embeds'] = prev_out['hs_embeds'][b_i, prev_out_ind_final]

你好,我发现prepare_track_queries_and_targets()函数中, 只有track_query_boxes 这部分保存的内容是经过detach的,但是track_query_hs_embeds在经过prop后才做了detach,用于与初始query合并,而 trans loss 中用到的确实detach之前的特征,这样不会有梯度问题吗?
track_query_info[b_i]['track_query_hs_embeds'] = track_query_updated.clone().detach()
if get_trans_loss:
pred = self.head.reg_branches[-1](track_query_updated).sigmoid() # (num_prop, 2*num_pts)
pred_scores = self.head.cls_branches[-1](track_query_updated)
assert list(pred.shape) == [N, 2*num_points]

我尝试了一阶段的stream训练的方式,但是加上trans loss的话会报错,假如在prepare_track_queries_and_targets函数中track_query_boxes和track_query_hs_embeds都存储为detach形式的话,是否需要先单帧训练,保证拿到的都是比较好的信息?

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions