@@ -78,7 +78,7 @@ namespace Rope {
// Computes the {low, high} integer index pair for a correction range.
//   low  = floor(find_correction_factor(low_ratio,  dim, base, ori_max_pe_len)), clamped to >= 0
//   high = ceil (find_correction_factor(high_ratio, dim, base, ori_max_pe_len)), clamped to an upper limit
// Callers in this file unpack the pair and pass it to linear_ramp_mask().
7878 std::pair<int , int > find_correction_range (float low_ratio, float high_ratio, int dim, float base, float ori_max_pe_len) {
7979 float low = std::floor (find_correction_factor (low_ratio, dim, base, ori_max_pe_len));
8080 float high = std::ceil (find_correction_factor (high_ratio, dim, base, ori_max_pe_len));
// The patch changes the upper clamp: the removed line limits high to dim - 1,
// the added line limits it to dim / 2.
// NOTE(review): no check that low <= high after clamping — confirm linear_ramp_mask tolerates that.
81- return {std::max (0 , static_cast <int >(low)), std::min (dim - 1 , static_cast <int >(high))};
81+ return {std::max (0 , static_cast <int >(low)), std::min (dim / 2 , static_cast <int >(high))};
8282 }
8383
8484 std::vector<float > linear_ramp_mask (int min, int max, int dim) {
@@ -92,10 +92,6 @@ namespace Rope {
9292 return ramp;
9393 }
9494
// Helper removed by this patch.
// Returns base * scale^(dim / (dim - 2)). The static_cast makes the exponent a
// floating-point division rather than an integer one.
95- float find_newbase_ntk (int dim, float base, float scale) {
96- return base * std::pow (scale, static_cast <float >(dim) / (dim - 2 ));
97- }
98-
9995 __STATIC_INLINE__ std::vector<std::vector<float >> rope_ext (
10096 const std::vector<float >& pos,
10197 int dim,
@@ -112,66 +108,47 @@ namespace Rope {
112108 assert (dim % 2 == 0 );
113109 int half_dim = dim / 2 ;
114110
115- // Compute scale for YARN
116- float scale = 1 .0f ;
117- if (yarn && max_pe_len > ori_max_pe_len) {
118- scale = std::max (1 .0f , static_cast <float >(max_pe_len) / ori_max_pe_len);
119- }
120-
121111 // Compute frequencies
122112 std::vector<float > freqs_base (half_dim);
123113 std::vector<float > freqs_linear (half_dim);
124114 std::vector<float > freqs_ntk (half_dim);
125115 std::vector<float > freqs (half_dim);
126116
127- for (int i = 0 ; i < half_dim; ++i) {
128- float exponent = static_cast <float >(i) / half_dim;
129- freqs_base[i] = 1 .0f / std::pow (theta, exponent);
130- if (yarn && max_pe_len > ori_max_pe_len) {
131- freqs_linear[i] = 1 .0f / std::pow (theta, exponent) / scale;
132- float new_base = 1 .0f / std::pow (theta, exponent / scale); // Simplified for YARN
133- freqs_ntk[i] = 1 .0f / std::pow (new_base, exponent);
134- }
135- }
136-
137- // YARN interpolation
138117 if (yarn && max_pe_len > ori_max_pe_len) {
139118 float beta_0 = 1 .25f ;
140119 float beta_1 = 0 .75f ;
141120 float gamma_0 = 16 .0f ;
142121 float gamma_1 = 2 .0f ;
143122
123+ float scale = std::max(1.0f, static_cast<float>(max_pe_len) / ori_max_pe_len);
124+ // Rescaled base: theta * scale^(half_dim / (half_dim - 1)). The exponent must be a floating-point division; as an int division it truncates to 1 for every half_dim > 2 and divides by zero when half_dim == 1.
125+ float new_base = theta * std::pow(scale, static_cast<float>(half_dim) / static_cast<float>(half_dim - 1));
126+ for (int i = 0; i < half_dim; ++i) {
127+ float exponent = static_cast<float>(i) / half_dim;  // i / half_dim, in [0, 1)
128+ freqs_base[i] = 1.0f / std::pow(theta, exponent);  // frequencies with the unmodified base
129+ freqs_linear[i] = 1.0f / (scale * std::pow(theta, exponent));  // base frequencies divided by scale
130+ freqs_ntk[i] = 1.0f / std::pow(new_base, exponent);  // frequencies with the rescaled base
131+ }
132+
144133 if (dype) {
145134 beta_0 = std::pow (beta_0, 2 .0f * current_timestep * current_timestep);
146135 beta_1 = std::pow (beta_1, 2 .0f * current_timestep * current_timestep);
147136 gamma_0 = std::pow (gamma_0, 2 .0f * current_timestep * current_timestep);
148137 gamma_1 = std::pow (gamma_1, 2 .0f * current_timestep * current_timestep);
149138 }
150139
151- // Compute freqs_linear and freqs_ntk
152- for (int i = 0 ; i < half_dim; ++i) {
153- float exponent = static_cast <float >(i) / half_dim;
154- freqs_linear[i] = 1 .0f / (std::pow (theta, exponent) * scale);
155- }
156-
157- float new_base = find_newbase_ntk (dim, theta, scale);
158- for (int i = 0 ; i < half_dim; ++i) {
159- float exponent = static_cast <float >(i) / half_dim;
160- freqs_ntk[i] = 1 .0f / std::pow (new_base, exponent);
161- }
162-
163140 // Apply correction range and linear ramp mask
164141 auto [low, high] = find_correction_range (beta_0, beta_1, dim, theta, ori_max_pe_len);
165142 auto mask = linear_ramp_mask (low, high, half_dim);
166143 for (int i = 0 ; i < half_dim; ++i) {
167- freqs[i] = freqs_linear[i] * ( 1 . 0f - mask[i]) + freqs_ntk[i] * mask[i];
144+ freqs[i] = freqs_linear[i] * mask[i] + freqs_ntk[i] * ( 1 . 0f - mask[i]) ;
168145 }
169146
170147 // Apply gamma correction
171148 auto [low_gamma, high_gamma] = find_correction_range (gamma_0, gamma_1, dim, theta, ori_max_pe_len);
172149 auto mask_gamma = linear_ramp_mask (low_gamma, high_gamma, half_dim);
173150 for (int i = 0 ; i < half_dim; ++i) {
174- freqs[i] = freqs[i] * ( 1 . 0f - mask_gamma[i]) + freqs_base[i] * mask_gamma[i];
151+ freqs[i] = freqs[i] * mask_gamma[i] + freqs_base[i] * ( 1 . 0f - mask_gamma[i]) ;
175152 }
176153 } else {
177154 float theta_ntk = theta * ntk_factor;
@@ -288,15 +265,23 @@ namespace Rope {
288265 int bs,
289266 float theta,
290267 const std::vector<int >& axes_dim,
291- bool yarn = false ,
292- int max_pe_len = -1 ,
293- int ori_max_pe_len = 64 ,
294- bool dype = false ,
295- float current_timestep = 1 .0f ) {
268+ bool yarn = false ,
269+ std::vector<int > max_pe_len = {},
270+ int ori_max_pe_len = 64 ,
271+ bool dype = false ,
272+ float current_timestep = 1 .0f ,
273+ std::vector<float > ntk_factors = {}) {
296274 std::vector<std::vector<float >> trans_ids = transpose (ids);
297275 size_t pos_len = ids.size () / bs;
298276 int num_axes = axes_dim.size ();
299277
278+ if (ntk_factors.size () == 0 ) {
279+ ntk_factors = std::vector<float >(num_axes, 1 .0f );
280+ }
281+ if (max_pe_len.size () == 0 ) {
282+ max_pe_len = std::vector<int >(num_axes, -1 );
283+ }
284+
300285 int emb_dim = 0 ;
301286 for (int d : axes_dim) {
302287 emb_dim += d;
@@ -307,7 +292,7 @@ namespace Rope {
307292
308293 for (int i = 0 ; i < num_axes; ++i) {
309294 std::vector<std::vector<float >> rope_emb = rope_ext (
310- trans_ids[i], axes_dim[i], theta, false , 1 .0f , 1 . 0f , true , yarn, max_pe_len, ori_max_pe_len, dype, current_timestep);
295+ trans_ids[i], axes_dim[i], theta, false , 1 .0f , ntk_factors[i] , true , yarn, max_pe_len[i] , ori_max_pe_len, dype, current_timestep);
311296
312297 for (int b = 0 ; b < bs; ++b) {
313298 for (size_t j = 0 ; j < pos_len; ++j) {
@@ -384,20 +369,38 @@ namespace Rope {
384369 const std::vector<int >& axes_dim,
385370 bool use_yarn = false ,
386371 bool use_dype = false ,
372+ bool use_ntk = false ,
387373 float current_timestep = 1 .0f ) {
388- const int base_patches = 1024 / 16 ;
374+ int base_resolution = 1024 ;
375+ // set it via environment variable for now (TODO: arg)
376+ const char * env_base_resolution = getenv (" FLUX_DYPE_BASE_RESOLUTION" );
377+ if (env_base_resolution != nullptr ) {
378+ base_resolution = atoi (env_base_resolution);
379+ }
380+ int base_patches = base_resolution / 16 ;
389381 std::vector<std::vector<float >> ids = gen_flux_ids (h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
390- float max_pos_f = 0 .0f ;
391- for (const auto & row : ids) {
392- for (float val : row) {
382+ std::vector<int > max_pos_vec = {};
383+ std::vector<float > ntk_factor_vec = {};
384+ for (int i = 0 ; i < axes_dim.size (); i++) {
385+ float max_pos_f = 0 .0f ;
386+ for (const auto & row : ids) {
387+ float val = row[i];
393388 if (val > max_pos_f) {
394389 max_pos_f = val;
395390 }
396391 }
392+ int max_pos = static_cast <int >(max_pos_f) + 1 ;
393+ max_pos_vec.push_back (max_pos);
394+ float ntk_factor = 1 .0f ;
395+ if (use_ntk) {
396+ float base_ntk = pow ((float )max_pos / base_patches, (float )axes_dim[i] / (axes_dim[i] - 2 ));
397+ ntk_factor = use_dype ? pow (base_ntk, 2 .0f * current_timestep * current_timestep) : base_ntk;
398+ ntk_factor = std::max (1 .0f , ntk_factor);
399+ }
400+ ntk_factor_vec.push_back (ntk_factor);
397401 }
398- int max_pos = static_cast <int >(max_pos_f) + 1 ;
399- if (use_yarn && max_pos > base_patches) {
400- return embed_nd_ext (ids, bs, theta, axes_dim, true , max_pos, base_patches, use_dype, current_timestep);
402+ if (use_yarn || use_ntk) {
403+ return embed_nd_ext (ids, bs, theta, axes_dim, use_yarn, max_pos_vec, base_patches, use_dype, current_timestep, ntk_factor_vec);
401404 } else {
402405 return embed_nd (ids, bs, theta, axes_dim);
403406 }
0 commit comments