@@ -81,6 +81,76 @@ static __global__ void upscale_f32_bilinear(const float * x, float * dst,
8181 dst[index] = result;
8282}
8383
// Bilinear upscale/downscale with antialiasing; matches
// F.interpolate(..., mode="bilinear", align_corners=False, antialias=True).
// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
// One thread per destination element; nb00..nb03 are byte strides of src0.
static __global__ void upscale_f32_bilinear_antialias(const float * src0, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne00_src, const int ne01_src,
        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
        const float sf0, const float sf1, const float sf2, const float sf3,
        const float pixel_offset) {
    const int64_t dst_idx   = threadIdx.x + blockIdx.x * blockDim.x;
    const int64_t dst_count = ne10_dst * ne11_dst * ne12_dst * ne13_dst;

    if (dst_idx >= dst_count) {
        return;
    }

    // unravel the flat index into the four destination coordinates
    const int i10_dst =  dst_idx                                   % ne10_dst;
    const int i11_dst = (dst_idx /  ne10_dst)                      % ne11_dst;
    const int i12_dst = (dst_idx / (ne10_dst * ne11_dst))          % ne12_dst;
    const int i13_dst =  dst_idx / (ne10_dst * ne11_dst * ne12_dst);

    // dims 2/3 are not filtered, only remapped through their scale factors
    const int i02_src = (int)(i12_dst / sf2);
    const int i03_src = (int)(i13_dst / sf3);

    // center of this destination pixel in source coordinates
    const float cx = ((float)i10_dst + pixel_offset) / sf0;
    const float cy = ((float)i11_dst + pixel_offset) / sf1;

    // filter support and its reciprocal; at least 1 source pixel for bilinear
    const float support0  = max(1.0f / sf0, 1.0f);
    const float support1  = max(1.0f / sf1, 1.0f);
    const float invscale0 = 1.0f / support0;
    const float invscale1 = 1.0f / support1;

    // clamped range of source pixels that contribute to this output
    const int64_t x_min = max(int64_t(0),        int64_t(cx - support0 + pixel_offset));
    const int64_t x_max = min(int64_t(ne00_src), int64_t(cx + support0 + pixel_offset));
    const int64_t y_min = max(int64_t(0),        int64_t(cy - support1 + pixel_offset));
    const int64_t y_max = min(int64_t(ne01_src), int64_t(cy + support1 + pixel_offset));

    float acc  = 0.0f;
    float norm = 0.0f;  // running sum of weights, renormalizes near the borders

    for (int64_t sy = y_min; sy < y_max; sy++) {
        // triangle (tent) filter in y
        const float wy = max(1.0f - fabsf((sy - cy + pixel_offset) * invscale1), 0.0f);

        for (int64_t sx = x_min; sx < x_max; sx++) {
            // triangle (tent) filter in x
            const float wx = max(1.0f - fabsf((sx - cx + pixel_offset) * invscale0), 0.0f);
            const float w  = wx * wy;

            if (w <= 0.0f) {
                continue;
            }

            const float pixel = *(const float *)((const char *)src0 + sx*nb00 + sy*nb01 + i02_src*nb02 + i03_src*nb03);
            acc  += pixel * w;
            norm += w;
        }
    }

    if (norm > 0.0f) {
        acc /= norm;
    }

    dst[dst_idx] = acc;
}
153+
84154namespace bicubic_interpolation {
85155// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
86156__device__ const float a = -0 .75f ; // use alpha = -0.75 (same as PyTorch)
@@ -161,11 +231,15 @@ static void upscale_f32_bilinear_cuda(const float * x, float * dst,
161231 const int ne00_src, const int ne01_src,
162232 const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
163233 const float sf0, const float sf1, const float sf2, const float sf3,
164- const float pixel_offset, cudaStream_t stream) {
234+ const float pixel_offset, bool antialias, cudaStream_t stream) {
165235 const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
166236 const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1 ) / CUDA_UPSCALE_BLOCK_SIZE;
167237
168- upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0 ,stream>>> (x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
238+ if (antialias) {
239+ upscale_f32_bilinear_antialias<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0 ,stream>>> (x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
240+ } else {
241+ upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0 ,stream>>> (x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
242+ }
169243}
170244
171245static void upscale_f32_bicubic_cuda (const float * x, float * dst,
@@ -207,9 +281,10 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
207281 if (mode == GGML_SCALE_MODE_NEAREST) {
208282 upscale_f32_cuda (src0_d, dst_d, src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ], dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ], sf0, sf1, sf2, sf3, stream);
209283 } else if (mode == GGML_SCALE_MODE_BILINEAR) {
284+ const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);
210285 upscale_f32_bilinear_cuda (src0_d, dst_d, src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ],
211286 src0->ne [0 ], src0->ne [1 ], dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ],
212- sf0, sf1, sf2, sf3, pixel_offset, stream);
287+ sf0, sf1, sf2, sf3, pixel_offset, antialias, stream);
213288 } else if (mode == GGML_SCALE_MODE_BICUBIC) {
214289 upscale_f32_bicubic_cuda (src0_d, dst_d, src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ],
215290 src0->ne [0 ], src0->ne [1 ], dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ],
0 commit comments