2929
3030from ._utils import (
3131 _FillTypeJIT ,
32+ _get_cvcuda_interp ,
3233 _get_kernel ,
3334 _import_cvcuda ,
3435 _is_cvcuda_available ,
@@ -2530,36 +2531,14 @@ def elastic_video(
25302531 return elastic_image (video , displacement , interpolation = interpolation , fill = fill )
25312532
25322533
2533- if CVCUDA_AVAILABLE :
2534- _cvcuda_interp = {
2535- InterpolationMode .BILINEAR : cvcuda .Interp .LINEAR ,
2536- "bilinear" : cvcuda .Interp .LINEAR ,
2537- "linear" : cvcuda .Interp .LINEAR ,
2538- 2 : cvcuda .Interp .LINEAR ,
2539- InterpolationMode .BICUBIC : cvcuda .Interp .CUBIC ,
2540- "bicubic" : cvcuda .Interp .CUBIC ,
2541- 3 : cvcuda .Interp .CUBIC ,
2542- InterpolationMode .NEAREST : cvcuda .Interp .NEAREST ,
2543- "nearest" : cvcuda .Interp .NEAREST ,
2544- 0 : cvcuda .Interp .NEAREST ,
2545- InterpolationMode .BOX : cvcuda .Interp .BOX ,
2546- "box" : cvcuda .Interp .BOX ,
2547- 4 : cvcuda .Interp .BOX ,
2548- InterpolationMode .HAMMING : cvcuda .Interp .HAMMING ,
2549- "hamming" : cvcuda .Interp .HAMMING ,
2550- 5 : cvcuda .Interp .HAMMING ,
2551- InterpolationMode .LANCZOS : cvcuda .Interp .LANCZOS ,
2552- "lanczos" : cvcuda .Interp .LANCZOS ,
2553- 1 : cvcuda .Interp .LANCZOS ,
2554- }
2555-
2556-
2557- def _elastic_cvcuda (
2534+ def _elastic_image_cvcuda (
25582535 image : "cvcuda.Tensor" ,
25592536 displacement : torch .Tensor ,
25602537 interpolation : Union [InterpolationMode , int ] = InterpolationMode .BILINEAR ,
25612538 fill : _FillTypeJIT = None ,
25622539) -> "cvcuda.Tensor" :
2540+ cvcuda = _import_cvcuda ()
2541+
25632542 if not isinstance (displacement , torch .Tensor ):
25642543 raise TypeError ("Argument displacement should be a Tensor" )
25652544
@@ -2578,9 +2557,7 @@ def _elastic_cvcuda(
25782557 elif num_channels == 1 and input_dtype != cvcuda .Type .F32 :
25792558 raise ValueError (f"cvcuda.remap requires float32 dtype for 1-channel images, but got { input_dtype } " )
25802559
2581- interp = _cvcuda_interp .get (interpolation , cvcuda .Interp .LINEAR )
2582- if interp is None :
2583- raise ValueError (f"Invalid interpolation mode: { interpolation } " )
2560+ interp = _get_cvcuda_interp (interpolation )
25842561
25852562 # Build normalized grid: identity + displacement
25862563 # _create_identity_grid returns (1, H, W, 2) with values in [-1, 1]
@@ -2627,7 +2604,7 @@ def _elastic_cvcuda(
26272604
26282605
26292606if CVCUDA_AVAILABLE :
2630- _elastic_cvcuda = _register_kernel_internal (elastic , cvcuda .Tensor )(_elastic_cvcuda )
2607+ _register_kernel_internal (elastic , _import_cvcuda () .Tensor )(_elastic_image_cvcuda )
26312608
26322609
26332610def center_crop (inpt : torch .Tensor , output_size : list [int ]) -> torch .Tensor :
0 commit comments