Skip to content

Commit 2ba7195

Browse files
authored
model: LFM2-VL fixes (#17577)
* Adjust to pytorch * Add antialiasing upscale * Increase number of patches to 1024 * Handle default marker insertion for LFM2 * Switch to flag * Reformat * Cuda implementation of antialias kernel * Change placement in ops.cpp * consistent float literals * Pad only for LFM2 * Address PR feedback * Rollback default marker placement changes * Fallback to CPU implementation for antialias implementation of upscale
1 parent 7f8ef50 commit 2ba7195

File tree

12 files changed

+162
-13
lines changed

12 files changed

+162
-13
lines changed

ggml/include/ggml.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2148,7 +2148,8 @@ extern "C" {
21482148
};
21492149

21502150
enum ggml_scale_flag {
2151-
GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8)
2151+
GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8),
2152+
GGML_SCALE_FLAG_ANTIALIAS = (1 << 9),
21522153
};
21532154

21542155
// interpolate

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2500,6 +2500,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
25002500
if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) {
25012501
return false;
25022502
}
2503+
if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {
2504+
return false;
2505+
}
25032506
return true;
25042507
}
25052508
case GGML_OP_POOL_2D:

ggml/src/ggml-cpu/ops.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7420,6 +7420,65 @@ static void ggml_compute_forward_upscale_f32(
74207420
}
74217421
}
74227422
}
7423+
} else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {
7424+
// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
7425+
// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
7426+
auto triangle_filter = [](float x) -> float {
7427+
return std::max(1.0f - fabsf(x), 0.0f);
7428+
};
7429+
7430+
// support and invscale, minimum 1 pixel for bilinear
7431+
const float support1 = std::max(1.0f, 1.0f / sf1);
7432+
const float invscale1 = 1.0f / support1;
7433+
const float support0 = std::max(1.0f, 1.0f / sf0);
7434+
const float invscale0 = 1.0f / support0;
7435+
7436+
for (int64_t i3 = 0; i3 < ne3; i3++) {
7437+
const int64_t i03 = i3 / sf3;
7438+
for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7439+
const int64_t i02 = i2 / sf2;
7440+
for (int64_t i1 = 0; i1 < ne1; i1++) {
7441+
const float y = ((float) i1 + pixel_offset) / sf1;
7442+
for (int64_t i0 = 0; i0 < ne0; i0++) {
7443+
const float x = ((float) i0 + pixel_offset) / sf0;
7444+
7445+
// the range of source pixels that contribute
7446+
const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);
7447+
const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);
7448+
const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);
7449+
const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);
7450+
7451+
// bilinear filter with antialiasing
7452+
float val = 0.0f;
7453+
float total_weight = 0.0f;
7454+
7455+
for (int64_t sy = y_min; sy < y_max; sy++) {
7456+
const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
7457+
7458+
for (int64_t sx = x_min; sx < x_max; sx++) {
7459+
const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
7460+
const float weight = weight_x * weight_y;
7461+
7462+
if (weight <= 0.0f) {
7463+
continue;
7464+
}
7465+
7466+
const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);
7467+
val += pixel * weight;
7468+
total_weight += weight;
7469+
}
7470+
}
7471+
7472+
if (total_weight > 0.0f) {
7473+
val /= total_weight;
7474+
}
7475+
7476+
float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7477+
*dst_ptr = val;
7478+
}
7479+
}
7480+
}
7481+
}
74237482
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
74247483
for (int64_t i3 = 0; i3 < ne3; i3++) {
74257484
const int64_t i03 = i3 / sf3;

ggml/src/ggml-cuda/upscale.cu

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,76 @@ static __global__ void upscale_f32_bilinear(const float * x, float * dst,
8181
dst[index] = result;
8282
}
8383

84+
// Antialiased bilinear resampling kernel.
// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
//
// One thread produces one destination element. The flat thread index is decomposed into
// (i10, i11, i12, i13) destination coordinates; dims 2 and 3 are mapped to the source by
// dividing by sf2/sf3, while dims 0 and 1 are filtered with a separable triangle (tent)
// kernel whose support grows to 1/sf when downscaling (sf < 1) and is at least 1 pixel.
//
// src0               source data, addressed with the byte strides nb00..nb03
// dst                destination data, written contiguously at the flat element index
// ne00_src, ne01_src source extent along dims 0 and 1 (clamps the filter window)
// ne1x_dst           destination extent along dims 0..3
// sf0..sf3           per-dimension scale factors (destination / source)
// pixel_offset       offset applied when mapping destination indices to source coordinates
static __global__ void upscale_f32_bilinear_antialias(const float * src0, float * dst,
        const int nb00, const int nb01, const int nb02, const int nb03,
        const int ne00_src, const int ne01_src,
        const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
        const float sf0, const float sf1, const float sf2, const float sf3,
        const float pixel_offset) {
    // widen before multiplying: the products below overflow 32-bit arithmetic for large tensors
    const int64_t index = threadIdx.x + (int64_t) blockIdx.x * blockDim.x;

    const int64_t ne10 = ne10_dst;
    const int64_t ne11 = ne11_dst;
    const int64_t ne12 = ne12_dst;
    const int64_t ne13 = ne13_dst;

    const int64_t dst_total_elements = ne10 * ne11 * ne12 * ne13;

    if (index >= dst_total_elements) {
        return;
    }

    const int i10_dst = (int) (index % ne10);
    const int i11_dst = (int) ((index / ne10) % ne11);
    const int i12_dst = (int) ((index / (ne10 * ne11)) % ne12);
    const int i13_dst = (int) (index / (ne10 * ne11 * ne12));

    const int i02_src = (int)(i12_dst / sf2);
    const int i03_src = (int)(i13_dst / sf3);

    // continuous source coordinates of this destination element
    const float y = ((float)i11_dst + pixel_offset) / sf1;
    const float x = ((float)i10_dst + pixel_offset) / sf0;

    // support and invscale, minimum 1 pixel for bilinear
    const float support1  = max(1.0f / sf1, 1.0f);
    const float invscale1 = 1.0f / support1;
    const float support0  = max(1.0f / sf0, 1.0f);
    const float invscale0 = 1.0f / support0;

    // the half-open range [min, max) of source pixels that contribute, clamped to the source extent
    const int64_t x_min = max(int64_t(0), int64_t(x - support0 + pixel_offset));
    const int64_t x_max = min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset));
    const int64_t y_min = max(int64_t(0), int64_t(y - support1 + pixel_offset));
    const int64_t y_max = min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset));

    // base of the source plane for (i02_src, i03_src); byte offsets are computed in 64 bits
    const char * src_plane = (const char *) src0 + (int64_t) i02_src * nb02 + (int64_t) i03_src * nb03;

    // triangle (tent) filter: 1 at the center, linearly falling to 0 at distance 1
    auto triangle_filter = [](float t) -> float {
        return max(1.0f - fabsf(t), 0.0f);
    };

    // bilinear filter with antialiasing
    float val          = 0.0f;
    float total_weight = 0.0f;

    for (int64_t sy = y_min; sy < y_max; sy++) {
        const float  weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
        const char * src_row  = src_plane + sy * nb01;

        for (int64_t sx = x_min; sx < x_max; sx++) {
            const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
            const float weight   = weight_x * weight_y;

            if (weight <= 0.0f) {
                continue;
            }

            const float pixel = *(const float *)(src_row + sx * nb00);
            val          += pixel * weight;
            total_weight += weight;
        }
    }

    // normalize so that windows clipped at the borders still sum to 1
    if (total_weight > 0.0f) {
        val /= total_weight;
    }

    dst[index] = val;
}
153+
84154
namespace bicubic_interpolation {
85155
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
86156
__device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
@@ -161,11 +231,15 @@ static void upscale_f32_bilinear_cuda(const float * x, float * dst,
161231
const int ne00_src, const int ne01_src,
162232
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
163233
const float sf0, const float sf1, const float sf2, const float sf3,
164-
const float pixel_offset, cudaStream_t stream) {
234+
const float pixel_offset, bool antialias, cudaStream_t stream) {
165235
const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
166236
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
167237

168-
upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
238+
if (antialias) {
239+
upscale_f32_bilinear_antialias<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
240+
} else {
241+
upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
242+
}
169243
}
170244

171245
static void upscale_f32_bicubic_cuda(const float * x, float * dst,
@@ -207,9 +281,10 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
207281
if (mode == GGML_SCALE_MODE_NEAREST) {
208282
upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
209283
} else if (mode == GGML_SCALE_MODE_BILINEAR) {
284+
const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);
210285
upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
211286
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
212-
sf0, sf1, sf2, sf3, pixel_offset, stream);
287+
sf0, sf1, sf2, sf3, pixel_offset, antialias, stream);
213288
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
214289
upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
215290
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -894,7 +894,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
894894
case GGML_OP_POOL_1D:
895895
return false;
896896
case GGML_OP_UPSCALE:
897-
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
897+
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
898898
case GGML_OP_POOL_2D:
899899
return op->src[0]->type == GGML_TYPE_F32;
900900
case GGML_OP_PAD:

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3086,8 +3086,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
30863086
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
30873087
case GGML_OP_UPSCALE: {
30883088
ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF);
3089+
const bool antialias = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & GGML_SCALE_FLAG_ANTIALIAS);
30893090
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
3090-
(mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR);
3091+
(mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR) && !antialias;
30913092
}
30923093
case GGML_OP_CONV_2D:
30933094
return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4597,7 +4597,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
45974597
case GGML_OP_IM2COL:
45984598
return true;
45994599
case GGML_OP_UPSCALE:
4600-
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
4600+
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
46014601
case GGML_OP_SUM:
46024602
case GGML_OP_SUM_ROWS:
46034603
case GGML_OP_MEAN:

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14113,6 +14113,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1411314113
}
1411414114
return true;
1411514115
case GGML_OP_UPSCALE:
14116+
return op->src[0]->type == GGML_TYPE_F32 && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
1411614117
case GGML_OP_ACC:
1411714118
return op->src[0]->type == GGML_TYPE_F32;
1411814119
case GGML_OP_CONCAT:

ggml/src/ggml.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4891,6 +4891,8 @@ static struct ggml_tensor * ggml_interpolate_impl(
48914891
int64_t ne3,
48924892
uint32_t mode) {
48934893
GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT);
4894+
// TODO: implement antialias for modes other than bilinear
4895+
GGML_ASSERT(!(mode & GGML_SCALE_FLAG_ANTIALIAS) || (mode & 0xFF) == GGML_SCALE_MODE_BILINEAR);
48944896

48954897
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
48964898

tests/test-backend-ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7660,7 +7660,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
76607660
// test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {i, 2, 1, 3}, rand() % i + 1));
76617661
//}
76627662

7663-
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) {
7663+
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC, ggml_scale_mode(GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS)}) {
76647664
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
76657665
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
76667666
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode));

0 commit comments

Comments
 (0)