@@ -7402,9 +7402,27 @@ static void ggml_compute_forward_upscale_f32(
74027402 sf1 = ne1 > 1 && ne01 > 1 ? (float )(ne1 - 1 ) / (ne01 - 1 ) : sf1;
74037403 }
74047404
7405- // Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
7406- // https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
7407- if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {
7405+ if (mode == GGML_SCALE_MODE_NEAREST) {
7406+ for (int64_t i3 = 0 ; i3 < ne3; i3++) {
7407+ const int64_t i03 = i3 / sf3;
7408+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7409+ const int64_t i02 = i2 / sf2;
7410+ for (int64_t i1 = 0 ; i1 < ne1; i1++) {
7411+ const int64_t i01 = i1 / sf1;
7412+ for (int64_t i0 = 0 ; i0 < ne0; i0++) {
7413+ const int64_t i00 = i0 / sf0;
7414+
7415+ const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7416+ float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7417+
7418+ *y = *x;
7419+ }
7420+ }
7421+ }
7422+ }
7423+ } else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {
7424+ // Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
7425+ // https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
74087426 auto triangle_filter = [](float x) -> float {
74097427 return std::max (1 .0f - fabsf (x), 0 .f );
74107428 };
@@ -7461,24 +7479,6 @@ static void ggml_compute_forward_upscale_f32(
74617479 }
74627480 }
74637481 }
7464- } else if (mode == GGML_SCALE_MODE_NEAREST) {
7465- for (int64_t i3 = 0 ; i3 < ne3; i3++) {
7466- const int64_t i03 = i3 / sf3;
7467- for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7468- const int64_t i02 = i2 / sf2;
7469- for (int64_t i1 = 0 ; i1 < ne1; i1++) {
7470- const int64_t i01 = i1 / sf1;
7471- for (int64_t i0 = 0 ; i0 < ne0; i0++) {
7472- const int64_t i00 = i0 / sf0;
7473-
7474- const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7475- float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7476-
7477- *y = *x;
7478- }
7479- }
7480- }
7481- }
74827482 } else if (mode == GGML_SCALE_MODE_BILINEAR) {
74837483 for (int64_t i3 = 0 ; i3 < ne3; i3++) {
74847484 const int64_t i03 = i3 / sf3;
0 commit comments