Skip to content

Commit 5f4ccac

Browse files
committed
base_resolution with desired aspect ratio
1 parent 26ed7b4 commit 5f4ccac

File tree

1 file changed

+29
-10
lines changed

1 file changed

+29
-10
lines changed

rope.hpp

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -265,12 +265,12 @@ namespace Rope {
265265
int bs,
266266
float theta,
267267
const std::vector<int>& axes_dim,
268-
bool yarn = false,
269-
std::vector<int> max_pe_len = {},
270-
int ori_max_pe_len = 64,
271-
bool dype = false,
272-
float current_timestep = 1.0f,
273-
std::vector<float> ntk_factors = {}) {
268+
bool yarn = false,
269+
std::vector<int> max_pe_len = {},
270+
std::vector<int> ori_max_pe_len = {64, 64, 64},
271+
bool dype = false,
272+
float current_timestep = 1.0f,
273+
std::vector<float> ntk_factors = {}) {
274274
std::vector<std::vector<float>> trans_ids = transpose(ids);
275275
size_t pos_len = ids.size() / bs;
276276
int num_axes = axes_dim.size();
@@ -292,7 +292,7 @@ namespace Rope {
292292

293293
for (int i = 0; i < num_axes; ++i) {
294294
std::vector<std::vector<float>> rope_emb = rope_ext(
295-
trans_ids[i], axes_dim[i], theta, false, 1.0f, ntk_factors[i], true, yarn, max_pe_len[i], ori_max_pe_len, dype, current_timestep);
295+
trans_ids[i], axes_dim[i], theta, false, 1.0f, ntk_factors[i], true, yarn, max_pe_len[i], ori_max_pe_len[i], dype, current_timestep);
296296

297297
for (int b = 0; b < bs; ++b) {
298298
for (size_t j = 0; j < pos_len; ++j) {
@@ -372,12 +372,31 @@ namespace Rope {
372372
bool use_ntk = false,
373373
float current_timestep = 1.0f) {
374374
int base_resolution = 1024;
375+
int base_patches_H = -1;
376+
int base_patches_W = -1;
377+
375378
// Set via the FLUX_DYPE_BASE_RESOLUTION environment variable for now (TODO: make this an argument)
379+
// The value can be either a single integer (e.g. 1024) or WIDTHxHEIGHT separated by a lowercase 'x' (e.g. 1280x720)
376380
const char* env_base_resolution = getenv("FLUX_DYPE_BASE_RESOLUTION");
377381
if (env_base_resolution != nullptr) {
378-
base_resolution = atoi(env_base_resolution);
382+
if (strchr(env_base_resolution, 'x') != nullptr) {
383+
const char* x_pos = strchr(env_base_resolution, 'x');
384+
base_patches_H = atoi(x_pos + 1) / 16;
385+
base_patches_W = atoi(env_base_resolution) / 16;
386+
} else {
387+
base_resolution = atoi(env_base_resolution);
388+
}
379389
}
380-
int base_patches = base_resolution / 16;
390+
// preserve aspect ratio of the input image
391+
// base_patches_W = k*w, base_patches_H = k*h, base_patches_W*base_patches_H = (base_resolution/16)^2
392+
// => k = base_resolution / (16 * sqrt(w*h))
393+
if (base_patches_H == -1)
394+
base_patches_H = (base_resolution * h * sqrt(1.0f / (w * h))) / 16;
395+
if (base_patches_W == -1)
396+
base_patches_W = (base_resolution * w * sqrt(1.0f / (w * h))) / 16;
397+
398+
// The first dim is the ref image index; it should not need any special RoPE modifications since its max pos should stay very low, so a base of 1024 is more than enough.
399+
std::vector<int> base_patches = {1024, base_patches_H, base_patches_W};
381400
std::vector<std::vector<float>> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
382401
std::vector<int> max_pos_vec = {};
383402
std::vector<float> ntk_factor_vec = {};
@@ -393,7 +412,7 @@ namespace Rope {
393412
max_pos_vec.push_back(max_pos);
394413
float ntk_factor = 1.0f;
395414
if (use_ntk) {
396-
float base_ntk = pow((float)max_pos / base_patches, (float)axes_dim[i] / (axes_dim[i] - 2));
415+
float base_ntk = pow((float)max_pos / base_patches[i], (float)axes_dim[i] / (axes_dim[i] - 2));
397416
ntk_factor = use_dype ? pow(base_ntk, 2.0f * current_timestep * current_timestep) : base_ntk;
398417
ntk_factor = std::max(1.0f, ntk_factor);
399418
}

0 commit comments

Comments
 (0)