77from  genjax .incremental  import  Diff , NoChange , UnknownChange 
88
99import  bayes3d  as  b 
10- import  bayes3d .scene_graph 
1110
1211from  .genjax_distributions  import  (
1312    contact_params_uniform ,
@@ -128,14 +127,14 @@ def get_far_plane(trace):
128127
129128
130129def  add_object (trace , key , obj_id , parent , face_parent , face_child ):
131-     N  =  get_indices (trace ).shape [0 ] +  1 
130+     N  =  b . get_indices (trace ).shape [0 ] +  1 
132131    choices  =  trace .get_choices ()
133132    choices [f"parent_{ N - 1 }  ] =  parent 
134133    choices [f"id_{ N - 1 }  ] =  obj_id 
135134    choices [f"face_parent_{ N - 1 }  ] =  face_parent 
136135    choices [f"face_child_{ N - 1 }  ] =  face_child 
137136    choices [f"contact_params_{ N - 1 }  ] =  jnp .zeros (3 )
138-     return  model .importance (key , choices , (jnp .arange (N ), * trace .get_args ()[1 :]))[0 ]
137+     return  model .importance (key , choices , (jnp .arange (N ), * trace .get_args ()[1 :]))[1 ]
139138
140139
141140add_object_jit  =  jax .jit (add_object )
@@ -152,7 +151,7 @@ def print_trace(trace):
152151
153152
154153def  viz_trace_meshcat (trace , colors = None ):
155-     b .clear_visualizer ()
154+     b .clear ()
156155    b .show_cloud (
157156        "1" , b .apply_transform_jit (trace ["image" ].reshape (- 1 , 3 ), trace ["camera_pose" ])
158157    )
@@ -224,14 +223,14 @@ def enumerator(trace, key, *args):
224223            key ,
225224            chm_builder (addresses , args , chm_args ),
226225            argdiff_f (trace ),
227-         )[0 ]
226+         )[2 ]
228227
229228    def  enumerator_with_weight (trace , key , * args ):
230229        return  trace .update (
231230            key ,
232231            chm_builder (addresses , args , chm_args ),
233232            argdiff_f (trace ),
234-         )[0 : 2 ]
233+         )[1 : 3 ]
235234
236235    def  enumerator_score (trace , key , * args ):
237236        return  enumerator (trace , key , * args ).get_score ()
@@ -302,4 +301,4 @@ def update_address(trace, key, address, value):
302301        key ,
303302        genjax .choice_map ({address : value }),
304303        tuple (map (lambda  v : Diff (v , UnknownChange ), trace .args )),
305-     )[0 ]
304+     )[2 ]
0 commit comments