Skip to content

Commit 288d574

Browse files
committed
Flux dype
1 parent 347710f commit 288d574

File tree

2 files changed

+212
-4
lines changed

2 files changed

+212
-4
lines changed

flux.hpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,7 +1266,8 @@ namespace Flux {
12661266
set_backend_tensor_data(mod_index_arange, mod_index_arange_vec.data());
12671267
}
12681268
y = to_backend(y);
1269-
1269+
float current_timestep = ggml_get_f32_1d(timesteps, 0);
1270+
LOG_DEBUG("current_timestep %f", current_timestep);
12701271
timesteps = to_backend(timesteps);
12711272
if (flux_params.guidance_embed || flux_params.is_chroma) {
12721273
guidance = to_backend(guidance);
@@ -1275,6 +1276,22 @@ namespace Flux {
12751276
ref_latents[i] = to_backend(ref_latents[i]);
12761277
}
12771278

1279+
// get use_yarn and use_dype from env for now (TODO: add args)
1280+
bool use_yarn = false;
1281+
bool use_dype = false;
1282+
char* use_yarn_env = getenv("USE_YARN");
1283+
if (use_yarn_env != nullptr) {
1284+
if (strcmp(use_yarn_env, "OFF") != 0) {
1285+
use_yarn = true;
1286+
char* use_dype_env = getenv("USE_DYPE");
1287+
if (use_dype_env != nullptr) {
1288+
if (strcmp(use_dype_env, "OFF") != 0) {
1289+
use_dype = true;
1290+
}
1291+
}
1292+
}
1293+
}
1294+
12781295
pe_vec = Rope::gen_flux_pe(x->ne[1],
12791296
x->ne[0],
12801297
flux_params.patch_size,
@@ -1283,7 +1300,10 @@ namespace Flux {
12831300
ref_latents,
12841301
increase_ref_index,
12851302
flux_params.theta,
1286-
flux_params.axes_dim);
1303+
flux_params.axes_dim,
1304+
use_yarn,
1305+
use_dype,
1306+
current_timestep);
12871307
int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2;
12881308
// LOG_DEBUG("pos_len %d", pos_len);
12891309
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len);

rope.hpp

Lines changed: 190 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,138 @@ namespace Rope {
7171
return result;
7272
}
7373

74+
// Evaluates
//     dim * ln(max_position_embeddings / (num_rotations * 2 * pi)) / (2 * ln(base))
// and returns it as a float (the result is fractional; callers round it).
//
// NOTE(review): this appears to correspond to YaRN's "correction dimension"
// (the rotary dimension index that completes `num_rotations` turns over
// `max_position_embeddings` positions) - confirm against the reference.
//
// Declared `inline` because it is defined in a header: without it, including
// this header from more than one translation unit produces duplicate symbols.
//
// num_rotations           rotation count; must be > 0 for a finite result.
// dim                     rotary dimension of the axis.
// base                    RoPE base (theta); must be > 0 and != 1.
// max_position_embeddings reference context length; must be > 0.
inline float find_correction_factor(float num_rotations, int dim, float base, float max_position_embeddings) {
    return (dim * std::log(max_position_embeddings / (num_rotations * 2 * 3.14159265358979323846))) / (2 * std::log(base));
}
77+
78+
std::pair<int, int> find_correction_range(float low_ratio, float high_ratio, int dim, float base, float ori_max_pe_len) {
79+
float low = std::floor(find_correction_factor(low_ratio, dim, base, ori_max_pe_len));
80+
float high = std::ceil(find_correction_factor(high_ratio, dim, base, ori_max_pe_len));
81+
return {std::max(0, static_cast<int>(low)), std::min(dim - 1, static_cast<int>(high))};
82+
}
83+
84+
// Builds a ramp of `dim` values: 0 for indices <= min, 1 for indices >= max,
// and a linear interpolation (i - min) / (max - min) in between.
//
// When min == max the width of the ramp is taken as 0.001 instead of 0 so the
// division is well defined; the result is then a step that is 0 up to and
// including index `min` and 1 afterwards. (The bounds are ints, so the width
// must be handled in floating point: adding 0.001f to an int is a no-op.)
//
// Returns an empty vector when dim <= 0.
inline std::vector<float> linear_ramp_mask(int min, int max, int dim) {
    if (dim <= 0) {
        return {};
    }

    // Ramp width as a float; never exactly zero.
    const float width = (min == max) ? 0.001f : static_cast<float>(max - min);

    std::vector<float> ramp(static_cast<size_t>(dim));
    for (int i = 0; i < dim; ++i) {
        const float t = static_cast<float>(i - min) / width;
        ramp[i]       = std::max(0.0f, std::min(1.0f, t));
    }
    return ramp;
}
94+
95+
// Returns the rescaled RoPE base used for the NTK frequencies:
//     base * scale^(dim / (dim - 2))
//
// dim == 2 would divide by zero and yield an infinite base. In that case there
// is a single frequency whose exponent is 0 (see rope_ext), so the base has no
// effect on it and the unscaled base is returned instead.
//
// Declared `inline` because it is defined in a header.
inline float find_newbase_ntk(int dim, float base, float scale) {
    if (dim == 2) {
        return base;
    }
    return base * std::pow(scale, static_cast<float>(dim) / (dim - 2));
}
98+
99+
// Builds rotary position-embedding entries for one axis, optionally with
// YaRN-style frequency blending and DyPE timestep modulation.
//
// pos                     positions along this axis.
// dim                     rotary dimension for this axis; must be even.
// theta                   RoPE base.
// use_real                accepted for signature compatibility; not used.
// linear_factor           divides every frequency (non-YaRN path only).
// ntk_factor              multiplies theta (non-YaRN path only).
// repeat_interleave_real  accepted for signature compatibility; not used.
// yarn                    enable the blended-frequency path.
// max_pe_len              current maximum position; the blended path is only
//                         taken when yarn && max_pe_len > ori_max_pe_len.
// ori_max_pe_len          reference maximum position.
// dype                    raise the beta/gamma constants to 2 * t^2.
// current_timestep        t used by the dype modulation.
//
// Returns pos.size() rows of 2 * dim floats. For each frequency j the row holds
// [cos, -sin, sin, cos] of pos[i] * freqs[j] at offsets 4j .. 4j+3.
__STATIC_INLINE__ std::vector<std::vector<float>> rope_ext(
    const std::vector<float>& pos,
    int dim,
    float theta                 = 10000.0f,
    bool use_real               = false,
    float linear_factor         = 1.0f,
    float ntk_factor            = 1.0f,
    bool repeat_interleave_real = true,
    bool yarn                   = false,
    int max_pe_len              = -1,
    int ori_max_pe_len          = 64,
    bool dype                   = false,
    float current_timestep      = 1.0f) {
    assert(dim % 2 == 0);
    (void)use_real;                // unused, kept for signature compatibility
    (void)repeat_interleave_real;  // unused, kept for signature compatibility

    const int half_dim         = dim / 2;
    const bool use_yarn_blend  = yarn && max_pe_len > ori_max_pe_len;

    std::vector<float> freqs(half_dim);

    if (use_yarn_blend) {
        const float scale = std::max(1.0f, static_cast<float>(max_pe_len) / ori_max_pe_len);

        float beta_0  = 1.25f;
        float beta_1  = 0.75f;
        float gamma_0 = 16.0f;
        float gamma_1 = 2.0f;

        if (dype) {
            // Timestep-dependent exponent applied to all four constants.
            const float k_t = 2.0f * current_timestep * current_timestep;
            beta_0          = std::pow(beta_0, k_t);
            beta_1          = std::pow(beta_1, k_t);
            gamma_0         = std::pow(gamma_0, k_t);
            gamma_1         = std::pow(gamma_1, k_t);
        }

        const float new_base = find_newbase_ntk(dim, theta, scale);

        // NOTE(review): the upstream YaRN/DyPE Python code appears to blend with
        // (1 - linear_ramp_mask(...)); here the ramp is used un-inverted.
        // Behavior is kept as-is - confirm against the reference implementation.
        const auto [low, high]             = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len);
        const auto mask                    = linear_ramp_mask(low, high, half_dim);
        const auto [low_gamma, high_gamma] = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len);
        const auto mask_gamma              = linear_ramp_mask(low_gamma, high_gamma, half_dim);

        for (int i = 0; i < half_dim; ++i) {
            const float exponent    = static_cast<float>(i) / half_dim;
            const float freq_base   = 1.0f / std::pow(theta, exponent);
            const float freq_linear = 1.0f / (std::pow(theta, exponent) * scale);
            const float freq_ntk    = 1.0f / std::pow(new_base, exponent);

            // First blend: linear-interpolated vs. NTK-rescaled frequencies.
            const float blended = freq_linear * (1.0f - mask[i]) + freq_ntk * mask[i];
            // Second blend: result vs. the unscaled base frequencies.
            freqs[i] = blended * (1.0f - mask_gamma[i]) + freq_base * mask_gamma[i];
        }
    } else {
        const float theta_ntk = theta * ntk_factor;
        for (int i = 0; i < half_dim; ++i) {
            const float exponent = static_cast<float>(i) / half_dim;
            freqs[i]             = 1.0f / std::pow(theta_ntk, exponent) / linear_factor;
        }
    }

    // Rotation entries for every (position, frequency) pair.
    std::vector<std::vector<float>> result(pos.size(), std::vector<float>(half_dim * 4));
    for (size_t i = 0; i < pos.size(); ++i) {
        std::vector<float>& row = result[i];
        for (int j = 0; j < half_dim; ++j) {
            const float angle = pos[i] * freqs[j];
            const float c     = std::cos(angle);
            const float s     = std::sin(angle);
            row[4 * j]        = c;   // cos
            row[4 * j + 1]    = -s;  // -sin
            row[4 * j + 2]    = s;   // sin
            row[4 * j + 3]    = c;   // cos
        }
    }

    return result;
}
205+
74206
// Generate IDs for image patches and text
75207
__STATIC_INLINE__ std::vector<std::vector<float>> gen_txt_ids(int bs, int context_len) {
76208
return std::vector<std::vector<float>>(bs * context_len, std::vector<float>(3, 0.0));
@@ -151,6 +283,45 @@ namespace Rope {
151283
return flatten(emb);
152284
}
153285

286+
std::vector<float> embed_nd_ext(
287+
const std::vector<std::vector<float>>& ids,
288+
int bs,
289+
float theta,
290+
const std::vector<int>& axes_dim,
291+
bool yarn = false,
292+
int max_pe_len = -1,
293+
int ori_max_pe_len = 64,
294+
bool dype = false,
295+
float current_timestep = 1.0f) {
296+
std::vector<std::vector<float>> trans_ids = transpose(ids);
297+
size_t pos_len = ids.size() / bs;
298+
int num_axes = axes_dim.size();
299+
300+
int emb_dim = 0;
301+
for (int d : axes_dim) {
302+
emb_dim += d;
303+
}
304+
305+
std::vector<std::vector<float>> emb(bs * pos_len, std::vector<float>(emb_dim * 2, 0.0f));
306+
int offset = 0;
307+
308+
for (int i = 0; i < num_axes; ++i) {
309+
std::vector<std::vector<float>> rope_emb = rope_ext(
310+
trans_ids[i], axes_dim[i], theta, false, 1.0f, 1.0f, true, yarn, max_pe_len, ori_max_pe_len, dype, current_timestep);
311+
312+
for (int b = 0; b < bs; ++b) {
313+
for (size_t j = 0; j < pos_len; ++j) {
314+
for (size_t k = 0; k < rope_emb[j].size(); ++k) {
315+
emb[b * pos_len + j][offset + k] = rope_emb[j][k];
316+
}
317+
}
318+
}
319+
offset += static_cast<int>(axes_dim[i] * 2);
320+
}
321+
322+
return flatten(emb);
323+
}
324+
154325
__STATIC_INLINE__ std::vector<std::vector<float>> gen_refs_ids(int patch_size,
155326
int bs,
156327
const std::vector<ggml_tensor*>& ref_latents,
@@ -210,9 +381,26 @@ namespace Rope {
210381
const std::vector<ggml_tensor*>& ref_latents,
211382
bool increase_ref_index,
212383
int theta,
213-
const std::vector<int>& axes_dim) {
384+
const std::vector<int>& axes_dim,
385+
bool use_yarn = false,
386+
bool use_dype = false,
387+
float current_timestep = 1.0f) {
388+
const int base_patches = 1024 / 16;
214389
std::vector<std::vector<float>> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
215-
return embed_nd(ids, bs, theta, axes_dim);
390+
float max_pos_f = 0.0f;
391+
for (const auto& row : ids) {
392+
for (float val : row) {
393+
if (val > max_pos_f) {
394+
max_pos_f = val;
395+
}
396+
}
397+
}
398+
int max_pos = static_cast<int>(max_pos_f) + 1;
399+
if (use_yarn && max_pos > base_patches) {
400+
return embed_nd_ext(ids, bs, theta, axes_dim, true, max_pos, base_patches, use_dype, current_timestep);
401+
} else {
402+
return embed_nd(ids, bs, theta, axes_dim);
403+
}
216404
}
217405

218406
__STATIC_INLINE__ std::vector<std::vector<float>> gen_qwen_image_ids(int h,

0 commit comments

Comments
 (0)