@@ -178,7 +178,7 @@ def xfm_points(points, matrix):
178178 pos_grads .min ().item (),
179179 pos_grads .max ().item (),
180180 )
181- print (f"JAX rasterization (eval + grad): { (end_time - start_time )* 1000 } ms" )
181+ print (f"JAX rasterization (eval + grad): { (end_time - start_time ) * 1000 } ms" )
182182
183183 # save viz
184184 b .viz .get_depth_image (rast_out [0 ][:, :, 2 ]).save ("img_jax.png" )
@@ -229,7 +229,7 @@ def xfm_points(points, matrix):
229229 pos_grads .min ().item (),
230230 pos_grads .max ().item (),
231231 )
232- print (f"Torch rasterization (eval + grad): { (end_time - start_time )* 1000 } ms" )
232+ print (f"Torch rasterization (eval + grad): { (end_time - start_time ) * 1000 } ms" )
233233
234234 # save viz
235235 b .viz .get_depth_image (jnp .array (rast_out [0 ][:, :, 2 ].cpu ())).save ("img_torch.png" )
@@ -278,7 +278,7 @@ def xfm_points(points, matrix):
278278 print (
279279 f"JAX BWD (sum, min, max): g_attr={ g_attr .sum ().item (), g_attr .min ().item (), g_attr .max ().item ()} \n g_rast={ g_rast .sum ().item (), g_rast .min ().item (), g_rast .max ().item ()} "
280280 )
281- print (f"JAX interpolation: { (end_time - start_time )* 1000 } ms" )
281+ print (f"JAX interpolation: { (end_time - start_time ) * 1000 } ms" )
282282
283283 # save viz
284284 b .viz .get_depth_image (gb_pos [0 ][:, :, 2 ]).save ("interpolate_jax.png" )
@@ -316,7 +316,7 @@ def xfm_points(points, matrix):
316316 print (
317317 f"TORCH BWD (sum, min, max): g_attr={ g_attr .sum ().item (), g_attr .min ().item (), g_attr .max ().item ()} \n g_rast={ g_rast .sum ().item (), g_rast .min ().item (), g_rast .max ().item ()} "
318318 )
319- print (f"Torch interpolation: { (end_time - start_time )* 1000 } ms" )
319+ print (f"Torch interpolation: { (end_time - start_time ) * 1000 } ms" )
320320
321321 # save viz
322322 b .viz .get_depth_image (jnp .array (gb_pos [0 ][:, :, 2 ].cpu ())).save (
0 commit comments