@@ -71,6 +71,138 @@ namespace Rope {
7171 return result;
7272 }
7373
/// Evaluates
///     dim * ln(max_position_embeddings / (num_rotations * 2 * pi)) / (2 * ln(base)).
///
/// The result is a fractional value; callers in this file floor/ceil it and use
/// it as a rotary-dimension index (see find_correction_range).
/// NOTE(review): in YaRN-style scaling this is presumably the dimension whose
/// wavelength completes `num_rotations` turns over the original context length
/// -- confirm against the reference implementation.
///
/// @param num_rotations            Rotation count; must be > 0 for a finite result.
/// @param dim                      Rotary embedding dimension.
/// @param base                     RoPE base (theta); base == 1 makes the denominator 0.
/// @param max_position_embeddings  Position count the logarithm is taken against.
/// @return The value of the formula above, narrowed to float.
///
/// Declared `inline` because it is defined at namespace scope next to helpers
/// marked __STATIC_INLINE__; without it, including this file from more than one
/// translation unit would produce duplicate-symbol link errors.
inline float find_correction_factor(float num_rotations, int dim, float base, float max_position_embeddings) {
    constexpr double kPi = 3.14159265358979323846;
    // The numerator is evaluated in double (kPi promotes it); the denominator
    // uses the float overload of std::log, exactly as before.
    const double numerator = dim * std::log(max_position_embeddings / (num_rotations * 2 * kPi));
    return static_cast<float>(numerator / (2 * std::log(base)));
}
77+
78+ std::pair<int , int > find_correction_range (float low_ratio, float high_ratio, int dim, float base, float ori_max_pe_len) {
79+ float low = std::floor (find_correction_factor (low_ratio, dim, base, ori_max_pe_len));
80+ float high = std::ceil (find_correction_factor (high_ratio, dim, base, ori_max_pe_len));
81+ return {std::max (0 , static_cast <int >(low)), std::min (dim - 1 , static_cast <int >(high))};
82+ }
83+
/// Builds a clamped linear ramp of length `dim`:
///     ramp[i] = clamp((i - min) / (max - min), 0, 1).
///
/// When `min == max` the span is replaced by 0.001 so the division is well
/// defined; the ramp then becomes a step that is 0 for i <= min and 1 for
/// i > min.
///
/// The previous version tried to apply that guard with `max += 0.001f` on an
/// `int`, which truncates straight back to the old value. The division was
/// therefore by zero, and the NaN produced at i == min was clamped to 1
/// instead of 0.
///
/// @param min  Index at which the ramp starts rising from 0.
/// @param max  Index at which the ramp reaches 1.
/// @param dim  Number of entries to produce.
/// @return Vector of `dim` floats in [0, 1].
///
/// Declared `inline` because it is defined at namespace scope next to helpers
/// marked __STATIC_INLINE__; without it, including this file from more than one
/// translation unit would produce duplicate-symbol link errors.
inline std::vector<float> linear_ramp_mask(int min, int max, int dim) {
    // Keep the span in float so the degenerate-range epsilon actually survives.
    const float span = (min == max) ? 0.001f : static_cast<float>(max - min);

    std::vector<float> ramp(dim);
    for (int i = 0; i < dim; ++i) {
        const float t = static_cast<float>(i - min) / span;
        ramp[i]       = std::max(0.0f, std::min(1.0f, t));
    }
    return ramp;
}
94+
/// Returns `base * scale^(dim / (dim - 2))`.
///
/// @param dim    Rotary embedding dimension. NOTE: dim == 2 makes the exponent
///               a division by zero (infinite exponent); callers are expected
///               to pass dim > 2.
/// @param base   Original RoPE base (theta).
/// @param scale  Context-length scale factor.
/// @return The rescaled base.
///
/// Declared `inline` because it is defined at namespace scope next to helpers
/// marked __STATIC_INLINE__; without it, including this file from more than one
/// translation unit would produce duplicate-symbol link errors.
inline float find_newbase_ntk(int dim, float base, float scale) {
    const float exponent = static_cast<float>(dim) / (dim - 2);
    return base * std::pow(scale, exponent);
}
98+
99+ __STATIC_INLINE__ std::vector<std::vector<float >> rope_ext (
100+ const std::vector<float >& pos,
101+ int dim,
102+ float theta = 10000 .0f ,
103+ bool use_real = false ,
104+ float linear_factor = 1 .0f ,
105+ float ntk_factor = 1 .0f ,
106+ bool repeat_interleave_real = true ,
107+ bool yarn = false ,
108+ int max_pe_len = -1 ,
109+ int ori_max_pe_len = 64 ,
110+ bool dype = false ,
111+ float current_timestep = 1 .0f ) {
112+ assert (dim % 2 == 0 );
113+ int half_dim = dim / 2 ;
114+
115+ // Compute scale for YARN
116+ float scale = 1 .0f ;
117+ if (yarn && max_pe_len > ori_max_pe_len) {
118+ scale = std::max (1 .0f , static_cast <float >(max_pe_len) / ori_max_pe_len);
119+ }
120+
121+ // Compute frequencies
122+ std::vector<float > freqs_base (half_dim);
123+ std::vector<float > freqs_linear (half_dim);
124+ std::vector<float > freqs_ntk (half_dim);
125+ std::vector<float > freqs (half_dim);
126+
127+ for (int i = 0 ; i < half_dim; ++i) {
128+ float exponent = static_cast <float >(i) / half_dim;
129+ freqs_base[i] = 1 .0f / std::pow (theta, exponent);
130+ if (yarn && max_pe_len > ori_max_pe_len) {
131+ freqs_linear[i] = 1 .0f / std::pow (theta, exponent) / scale;
132+ float new_base = 1 .0f / std::pow (theta, exponent / scale); // Simplified for YARN
133+ freqs_ntk[i] = 1 .0f / std::pow (new_base, exponent);
134+ }
135+ }
136+
137+ // YARN interpolation
138+ if (yarn && max_pe_len > ori_max_pe_len) {
139+ float beta_0 = 1 .25f ;
140+ float beta_1 = 0 .75f ;
141+ float gamma_0 = 16 .0f ;
142+ float gamma_1 = 2 .0f ;
143+
144+ if (dype) {
145+ beta_0 = std::pow (beta_0, 2 .0f * current_timestep * current_timestep);
146+ beta_1 = std::pow (beta_1, 2 .0f * current_timestep * current_timestep);
147+ gamma_0 = std::pow (gamma_0, 2 .0f * current_timestep * current_timestep);
148+ gamma_1 = std::pow (gamma_1, 2 .0f * current_timestep * current_timestep);
149+ }
150+
151+ // Compute freqs_linear and freqs_ntk
152+ for (int i = 0 ; i < half_dim; ++i) {
153+ float exponent = static_cast <float >(i) / half_dim;
154+ freqs_linear[i] = 1 .0f / (std::pow (theta, exponent) * scale);
155+ }
156+
157+ float new_base = find_newbase_ntk (dim, theta, scale);
158+ for (int i = 0 ; i < half_dim; ++i) {
159+ float exponent = static_cast <float >(i) / half_dim;
160+ freqs_ntk[i] = 1 .0f / std::pow (new_base, exponent);
161+ }
162+
163+ // Apply correction range and linear ramp mask
164+ auto [low, high] = find_correction_range (beta_0, beta_1, dim, theta, ori_max_pe_len);
165+ auto mask = linear_ramp_mask (low, high, half_dim);
166+ for (int i = 0 ; i < half_dim; ++i) {
167+ freqs[i] = freqs_linear[i] * (1 .0f - mask[i]) + freqs_ntk[i] * mask[i];
168+ }
169+
170+ // Apply gamma correction
171+ auto [low_gamma, high_gamma] = find_correction_range (gamma_0, gamma_1, dim, theta, ori_max_pe_len);
172+ auto mask_gamma = linear_ramp_mask (low_gamma, high_gamma, half_dim);
173+ for (int i = 0 ; i < half_dim; ++i) {
174+ freqs[i] = freqs[i] * (1 .0f - mask_gamma[i]) + freqs_base[i] * mask_gamma[i];
175+ }
176+ } else {
177+ float theta_ntk = theta * ntk_factor;
178+ for (int i = 0 ; i < half_dim; ++i) {
179+ float exponent = static_cast <float >(i) / half_dim;
180+ freqs[i] = 1 .0f / std::pow (theta_ntk, exponent) / linear_factor;
181+ }
182+ }
183+
184+ // Outer product of pos and freqs
185+ std::vector<std::vector<float >> freqs_outer (pos.size (), std::vector<float >(half_dim));
186+ for (size_t i = 0 ; i < pos.size (); ++i) {
187+ for (int j = 0 ; j < half_dim; ++j) {
188+ freqs_outer[i][j] = pos[i] * freqs[j];
189+ }
190+ }
191+
192+ std::vector<std::vector<float >> result;
193+ result.resize (pos.size (), std::vector<float >(half_dim * 4 ));
194+ for (size_t i = 0 ; i < pos.size (); ++i) {
195+ for (int j = 0 ; j < half_dim; ++j) {
196+ result[i][4 * j] = std::cos (freqs_outer[i][j]); // cos
197+ result[i][4 * j + 1 ] = -std::sin (freqs_outer[i][j]); // -sin
198+ result[i][4 * j + 2 ] = std::sin (freqs_outer[i][j]); // sin
199+ result[i][4 * j + 3 ] = std::cos (freqs_outer[i][j]); // cos
200+ }
201+ }
202+
203+ return result;
204+ }
205+
74206 // Generate IDs for image patches and text
75207 __STATIC_INLINE__ std::vector<std::vector<float >> gen_txt_ids (int bs, int context_len) {
76208 return std::vector<std::vector<float >>(bs * context_len, std::vector<float >(3 , 0.0 ));
@@ -151,6 +283,45 @@ namespace Rope {
151283 return flatten (emb);
152284 }
153285
286+ std::vector<float > embed_nd_ext (
287+ const std::vector<std::vector<float >>& ids,
288+ int bs,
289+ float theta,
290+ const std::vector<int >& axes_dim,
291+ bool yarn = false ,
292+ int max_pe_len = -1 ,
293+ int ori_max_pe_len = 64 ,
294+ bool dype = false ,
295+ float current_timestep = 1 .0f ) {
296+ std::vector<std::vector<float >> trans_ids = transpose (ids);
297+ size_t pos_len = ids.size () / bs;
298+ int num_axes = axes_dim.size ();
299+
300+ int emb_dim = 0 ;
301+ for (int d : axes_dim) {
302+ emb_dim += d;
303+ }
304+
305+ std::vector<std::vector<float >> emb (bs * pos_len, std::vector<float >(emb_dim * 2 , 0 .0f ));
306+ int offset = 0 ;
307+
308+ for (int i = 0 ; i < num_axes; ++i) {
309+ std::vector<std::vector<float >> rope_emb = rope_ext (
310+ trans_ids[i], axes_dim[i], theta, false , 1 .0f , 1 .0f , true , yarn, max_pe_len, ori_max_pe_len, dype, current_timestep);
311+
312+ for (int b = 0 ; b < bs; ++b) {
313+ for (size_t j = 0 ; j < pos_len; ++j) {
314+ for (size_t k = 0 ; k < rope_emb[j].size (); ++k) {
315+ emb[b * pos_len + j][offset + k] = rope_emb[j][k];
316+ }
317+ }
318+ }
319+ offset += static_cast <int >(axes_dim[i] * 2 );
320+ }
321+
322+ return flatten (emb);
323+ }
324+
154325 __STATIC_INLINE__ std::vector<std::vector<float >> gen_refs_ids (int patch_size,
155326 int bs,
156327 const std::vector<ggml_tensor*>& ref_latents,
@@ -210,9 +381,26 @@ namespace Rope {
210381 const std::vector<ggml_tensor*>& ref_latents,
211382 bool increase_ref_index,
212383 int theta,
213- const std::vector<int >& axes_dim) {
384+ const std::vector<int >& axes_dim,
385+ bool use_yarn = false ,
386+ bool use_dype = false ,
387+ float current_timestep = 1 .0f ) {
388+ const int base_patches = 1024 / 16 ;
214389 std::vector<std::vector<float >> ids = gen_flux_ids (h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
215- return embed_nd (ids, bs, theta, axes_dim);
390+ float max_pos_f = 0 .0f ;
391+ for (const auto & row : ids) {
392+ for (float val : row) {
393+ if (val > max_pos_f) {
394+ max_pos_f = val;
395+ }
396+ }
397+ }
398+ int max_pos = static_cast <int >(max_pos_f) + 1 ;
399+ if (use_yarn && max_pos > base_patches) {
400+ return embed_nd_ext (ids, bs, theta, axes_dim, true , max_pos, base_patches, use_dype, current_timestep);
401+ } else {
402+ return embed_nd (ids, bs, theta, axes_dim);
403+ }
216404 }
217405
218406 __STATIC_INLINE__ std::vector<std::vector<float >> gen_qwen_image_ids (int h,
0 commit comments