@@ -265,12 +265,12 @@ namespace Rope {
265265 int bs,
266266 float theta,
267267 const std::vector<int >& axes_dim,
268- bool yarn = false ,
269- std::vector<int > max_pe_len = {},
270- int ori_max_pe_len = 64 ,
271- bool dype = false ,
272- float current_timestep = 1 .0f ,
273- std::vector<float > ntk_factors = {}) {
268+ bool yarn = false ,
269+ std::vector<int > max_pe_len = {},
270+ std::vector< int > ori_max_pe_len = { 64 , 64 , 64 } ,
271+ bool dype = false ,
272+ float current_timestep = 1 .0f ,
273+ std::vector<float > ntk_factors = {}) {
274274 std::vector<std::vector<float >> trans_ids = transpose (ids);
275275 size_t pos_len = ids.size () / bs;
276276 int num_axes = axes_dim.size ();
@@ -292,7 +292,7 @@ namespace Rope {
292292
293293 for (int i = 0 ; i < num_axes; ++i) {
294294 std::vector<std::vector<float >> rope_emb = rope_ext (
295- trans_ids[i], axes_dim[i], theta, false , 1 .0f , ntk_factors[i], true , yarn, max_pe_len[i], ori_max_pe_len, dype, current_timestep);
295+ trans_ids[i], axes_dim[i], theta, false , 1 .0f , ntk_factors[i], true , yarn, max_pe_len[i], ori_max_pe_len[i] , dype, current_timestep);
296296
297297 for (int b = 0 ; b < bs; ++b) {
298298 for (size_t j = 0 ; j < pos_len; ++j) {
@@ -372,12 +372,31 @@ namespace Rope {
372372 bool use_ntk = false ,
373373 float current_timestep = 1 .0f ) {
374374 int base_resolution = 1024 ;
375+ int base_patches_H = -1 ;
376+ int base_patches_W = -1 ;
377+
375378 // set it via environment variable for now (TODO: arg)
379+ // could be either a single integer, or WxH
376380 const char * env_base_resolution = getenv (" FLUX_DYPE_BASE_RESOLUTION" );
377381 if (env_base_resolution != nullptr ) {
378- base_resolution = atoi (env_base_resolution);
382+ if (strchr (env_base_resolution, ' x' ) != nullptr ) {
383+ const char * x_pos = strchr (env_base_resolution, ' x' );
384+ base_patches_H = atoi (x_pos + 1 ) / 16 ;
385+ base_patches_W = atoi (env_base_resolution) / 16 ;
386+ } else {
387+ base_resolution = atoi (env_base_resolution);
388+ }
379389 }
380- int base_patches = base_resolution / 16 ;
390+ // preserve aspect ratio of the input image
391+ // base_patches_W = k*w, base_patches_H = k*h, base_patches_W*base_patches_H = base_resolution^2
392+ // => k = base_resolution / sqrt(w*h)
393+ if (base_patches_H == -1 )
394+ base_patches_H = (base_resolution * h * sqrt (1 .0f / (w * h))) / 16 ;
395+ if (base_patches_W == -1 )
396+ base_patches_W = (base_resolution * w * sqrt (1 .0f / (w * h))) / 16 ;
397+
398+ // First dim is ref image, should not need any weird rope modifications since the max pos should stay very low. 1024 is a lot
399+ std::vector<int > base_patches = {1024 , base_patches_H, base_patches_W};
381400 std::vector<std::vector<float >> ids = gen_flux_ids (h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
382401 std::vector<int > max_pos_vec = {};
383402 std::vector<float > ntk_factor_vec = {};
@@ -393,7 +412,7 @@ namespace Rope {
393412 max_pos_vec.push_back (max_pos);
394413 float ntk_factor = 1 .0f ;
395414 if (use_ntk) {
396- float base_ntk = pow ((float )max_pos / base_patches, (float )axes_dim[i] / (axes_dim[i] - 2 ));
415+ float base_ntk = pow ((float )max_pos / base_patches[i] , (float )axes_dim[i] / (axes_dim[i] - 2 ));
397416 ntk_factor = use_dype ? pow (base_ntk, 2 .0f * current_timestep * current_timestep) : base_ntk;
398417 ntk_factor = std::max (1 .0f , ntk_factor);
399418 }
0 commit comments