Skip to content

Commit 5e6c77e

Browse files
committed
Working dype + NTK
1 parent 1d27a27 commit 5e6c77e

File tree

2 files changed

+73
-61
lines changed

2 files changed

+73
-61
lines changed

flux.hpp

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,19 +1276,27 @@ namespace Flux {
12761276
ref_latents[i] = to_backend(ref_latents[i]);
12771277
}
12781278

1279-
// get use_yarn and use_dype from env for now (TODO: add args)
1280-
bool use_yarn = false;
1281-
bool use_dype = false;
1282-
char* use_yarn_env = getenv("USE_YARN");
1283-
if (use_yarn_env != nullptr) {
1284-
if (strcmp(use_yarn_env, "OFF") != 0) {
1279+
// get use_yarn, use_ntk and use_dype from env for now (TODO: add args)
1280+
// Env value could be one of yarn, dy_yarn, ntk or dy_ntk, (anything else means disabled)
1281+
const char* env_value = getenv("FLUX_ROPE");
1282+
bool use_yarn = false;
1283+
bool use_dype = false;
1284+
bool use_ntk = false;
1285+
if (env_value != nullptr) {
1286+
if (strcmp(env_value, "YARN") == 0) {
1287+
LOG_DEBUG("Using YARN RoPE");
12851288
use_yarn = true;
1286-
char* use_dype_env = getenv("USE_DYPE");
1287-
if (use_dype_env != nullptr) {
1288-
if (strcmp(use_dype_env, "OFF") != 0) {
1289-
use_dype = true;
1290-
}
1291-
}
1289+
} else if (strcmp(env_value, "DY_YARN") == 0) {
1290+
LOG_DEBUG("Using DY YARN RoPE");
1291+
use_yarn = true;
1292+
use_dype = true;
1293+
} else if (strcmp(env_value, "NTK") == 0) {
1294+
LOG_DEBUG("Using NTK RoPE");
1295+
use_ntk = true;
1296+
} else if (strcmp(env_value, "DY_NTK") == 0) {
1297+
LOG_DEBUG("Using DY NTK RoPE");
1298+
use_ntk = true;
1299+
use_dype = true;
12921300
}
12931301
}
12941302

@@ -1303,6 +1311,7 @@ namespace Flux {
13031311
flux_params.axes_dim,
13041312
use_yarn,
13051313
use_dype,
1314+
use_ntk,
13061315
current_timestep);
13071316
int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
13081317
// LOG_DEBUG("pos_len %d", pos_len);

rope.hpp

Lines changed: 52 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ namespace Rope {
7878
std::pair<int, int> find_correction_range(float low_ratio, float high_ratio, int dim, float base, float ori_max_pe_len) {
7979
float low = std::floor(find_correction_factor(low_ratio, dim, base, ori_max_pe_len));
8080
float high = std::ceil(find_correction_factor(high_ratio, dim, base, ori_max_pe_len));
81-
return {std::max(0, static_cast<int>(low)), std::min(dim - 1, static_cast<int>(high))};
81+
return {std::max(0, static_cast<int>(low)), std::min(dim / 2, static_cast<int>(high))};
8282
}
8383

8484
std::vector<float> linear_ramp_mask(int min, int max, int dim) {
@@ -92,10 +92,6 @@ namespace Rope {
9292
return ramp;
9393
}
9494

95-
float find_newbase_ntk(int dim, float base, float scale) {
96-
return base * std::pow(scale, static_cast<float>(dim) / (dim - 2));
97-
}
98-
9995
__STATIC_INLINE__ std::vector<std::vector<float>> rope_ext(
10096
const std::vector<float>& pos,
10197
int dim,
@@ -112,66 +108,47 @@ namespace Rope {
112108
assert(dim % 2 == 0);
113109
int half_dim = dim / 2;
114110

115-
// Compute scale for YARN
116-
float scale = 1.0f;
117-
if (yarn && max_pe_len > ori_max_pe_len) {
118-
scale = std::max(1.0f, static_cast<float>(max_pe_len) / ori_max_pe_len);
119-
}
120-
121111
// Compute frequencies
122112
std::vector<float> freqs_base(half_dim);
123113
std::vector<float> freqs_linear(half_dim);
124114
std::vector<float> freqs_ntk(half_dim);
125115
std::vector<float> freqs(half_dim);
126116

127-
for (int i = 0; i < half_dim; ++i) {
128-
float exponent = static_cast<float>(i) / half_dim;
129-
freqs_base[i] = 1.0f / std::pow(theta, exponent);
130-
if (yarn && max_pe_len > ori_max_pe_len) {
131-
freqs_linear[i] = 1.0f / std::pow(theta, exponent) / scale;
132-
float new_base = 1.0f / std::pow(theta, exponent / scale); // Simplified for YARN
133-
freqs_ntk[i] = 1.0f / std::pow(new_base, exponent);
134-
}
135-
}
136-
137-
// YARN interpolation
138117
if (yarn && max_pe_len > ori_max_pe_len) {
139118
float beta_0 = 1.25f;
140119
float beta_1 = 0.75f;
141120
float gamma_0 = 16.0f;
142121
float gamma_1 = 2.0f;
143122

123+
float scale = std::max(1.0f, static_cast<float>(max_pe_len) / ori_max_pe_len);
124+
// d,t,s
125+
float new_base = theta * std::pow(scale, half_dim / (half_dim - 1));
126+
for (int i = 0; i < half_dim; ++i) {
127+
float exponent = static_cast<float>(i) / half_dim;
128+
freqs_base[i] = 1.0f / std::pow(theta, exponent);
129+
freqs_linear[i] = 1.0f / (scale * std::pow(theta, exponent));
130+
freqs_ntk[i] = 1.0f / std::pow(new_base, exponent);
131+
}
132+
144133
if (dype) {
145134
beta_0 = std::pow(beta_0, 2.0f * current_timestep * current_timestep);
146135
beta_1 = std::pow(beta_1, 2.0f * current_timestep * current_timestep);
147136
gamma_0 = std::pow(gamma_0, 2.0f * current_timestep * current_timestep);
148137
gamma_1 = std::pow(gamma_1, 2.0f * current_timestep * current_timestep);
149138
}
150139

151-
// Compute freqs_linear and freqs_ntk
152-
for (int i = 0; i < half_dim; ++i) {
153-
float exponent = static_cast<float>(i) / half_dim;
154-
freqs_linear[i] = 1.0f / (std::pow(theta, exponent) * scale);
155-
}
156-
157-
float new_base = find_newbase_ntk(dim, theta, scale);
158-
for (int i = 0; i < half_dim; ++i) {
159-
float exponent = static_cast<float>(i) / half_dim;
160-
freqs_ntk[i] = 1.0f / std::pow(new_base, exponent);
161-
}
162-
163140
// Apply correction range and linear ramp mask
164141
auto [low, high] = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len);
165142
auto mask = linear_ramp_mask(low, high, half_dim);
166143
for (int i = 0; i < half_dim; ++i) {
167-
freqs[i] = freqs_linear[i] * (1.0f - mask[i]) + freqs_ntk[i] * mask[i];
144+
freqs[i] = freqs_linear[i] * mask[i] + freqs_ntk[i] * (1.0f - mask[i]);
168145
}
169146

170147
// Apply gamma correction
171148
auto [low_gamma, high_gamma] = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len);
172149
auto mask_gamma = linear_ramp_mask(low_gamma, high_gamma, half_dim);
173150
for (int i = 0; i < half_dim; ++i) {
174-
freqs[i] = freqs[i] * (1.0f - mask_gamma[i]) + freqs_base[i] * mask_gamma[i];
151+
freqs[i] = freqs[i] * mask_gamma[i] + freqs_base[i] * (1.0f - mask_gamma[i]);
175152
}
176153
} else {
177154
float theta_ntk = theta * ntk_factor;
@@ -288,15 +265,23 @@ namespace Rope {
288265
int bs,
289266
float theta,
290267
const std::vector<int>& axes_dim,
291-
bool yarn = false,
292-
int max_pe_len = -1,
293-
int ori_max_pe_len = 64,
294-
bool dype = false,
295-
float current_timestep = 1.0f) {
268+
bool yarn = false,
269+
std::vector<int> max_pe_len = {},
270+
int ori_max_pe_len = 64,
271+
bool dype = false,
272+
float current_timestep = 1.0f,
273+
std::vector<float> ntk_factors = {}) {
296274
std::vector<std::vector<float>> trans_ids = transpose(ids);
297275
size_t pos_len = ids.size() / bs;
298276
int num_axes = axes_dim.size();
299277

278+
if (ntk_factors.size() == 0) {
279+
ntk_factors = std::vector<float>(num_axes, 1.0f);
280+
}
281+
if (max_pe_len.size() == 0) {
282+
max_pe_len = std::vector<int>(num_axes, -1);
283+
}
284+
300285
int emb_dim = 0;
301286
for (int d : axes_dim) {
302287
emb_dim += d;
@@ -307,7 +292,7 @@ namespace Rope {
307292

308293
for (int i = 0; i < num_axes; ++i) {
309294
std::vector<std::vector<float>> rope_emb = rope_ext(
310-
trans_ids[i], axes_dim[i], theta, false, 1.0f, 1.0f, true, yarn, max_pe_len, ori_max_pe_len, dype, current_timestep);
295+
trans_ids[i], axes_dim[i], theta, false, 1.0f, ntk_factors[i], true, yarn, max_pe_len[i], ori_max_pe_len, dype, current_timestep);
311296

312297
for (int b = 0; b < bs; ++b) {
313298
for (size_t j = 0; j < pos_len; ++j) {
@@ -384,20 +369,38 @@ namespace Rope {
384369
const std::vector<int>& axes_dim,
385370
bool use_yarn = false,
386371
bool use_dype = false,
372+
bool use_ntk = false,
387373
float current_timestep = 1.0f) {
388-
const int base_patches = 1024 / 16;
374+
int base_resolution = 1024;
375+
// set it via environment variable for now (TODO: arg)
376+
const char* env_base_resolution = getenv("FLUX_DYPE_BASE_RESOLUTION");
377+
if (env_base_resolution != nullptr) {
378+
base_resolution = atoi(env_base_resolution);
379+
}
380+
int base_patches = base_resolution / 16;
389381
std::vector<std::vector<float>> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
390-
float max_pos_f = 0.0f;
391-
for (const auto& row : ids) {
392-
for (float val : row) {
382+
std::vector<int> max_pos_vec = {};
383+
std::vector<float> ntk_factor_vec = {};
384+
for (int i = 0; i < axes_dim.size(); i++) {
385+
float max_pos_f = 0.0f;
386+
for (const auto& row : ids) {
387+
float val = row[i];
393388
if (val > max_pos_f) {
394389
max_pos_f = val;
395390
}
396391
}
392+
int max_pos = static_cast<int>(max_pos_f) + 1;
393+
max_pos_vec.push_back(max_pos);
394+
float ntk_factor = 1.0f;
395+
if (use_ntk) {
396+
float base_ntk = pow((float)max_pos / base_patches, (float)axes_dim[i] / (axes_dim[i] - 2));
397+
ntk_factor = use_dype ? pow(base_ntk, 2.0f * current_timestep * current_timestep) : base_ntk;
398+
ntk_factor = std::max(1.0f, ntk_factor);
399+
}
400+
ntk_factor_vec.push_back(ntk_factor);
397401
}
398-
int max_pos = static_cast<int>(max_pos_f) + 1;
399-
if (use_yarn && max_pos > base_patches) {
400-
return embed_nd_ext(ids, bs, theta, axes_dim, true, max_pos, base_patches, use_dype, current_timestep);
402+
if (use_yarn || use_ntk) {
403+
return embed_nd_ext(ids, bs, theta, axes_dim, use_yarn, max_pos_vec, base_patches, use_dype, current_timestep, ntk_factor_vec);
401404
} else {
402405
return embed_nd(ids, bs, theta, axes_dim);
403406
}

0 commit comments

Comments
 (0)