|
5 | 5 | #include "mmdit.hpp"
|
6 | 6 | #include "unet.hpp"
|
7 | 7 | #include "chroma.hpp"
|
| 8 | +#include "ggml_extend.hpp" // Required for set_timestep_embedding |
8 | 9 |
|
9 | 10 | struct DiffusionModel {
|
10 | 11 | virtual void compute(int n_threads,
|
@@ -216,100 +217,79 @@ struct ChromaModel : public DiffusionModel {
|
216 | 217 | }
|
217 | 218 |
|
218 | 219 | void compute(int n_threads,
|
219 |
| - struct ggml_tensor* x, |
220 |
| - struct ggml_tensor* timesteps, |
221 |
| - struct ggml_tensor* context, // T5 embeddings |
222 |
| - struct ggml_tensor* c_concat, // Not used by Chroma |
223 |
| - struct ggml_tensor* y, // Not used by Chroma |
224 |
| - struct ggml_tensor* guidance, // Not used by Chroma |
225 |
| - int num_video_frames = -1, // Not used by Chroma |
226 |
| - std::vector<struct ggml_tensor*> controls = {}, // Not used by Chroma |
227 |
| - float control_strength = 0.f, // Not used by Chroma |
| 220 | + struct ggml_tensor* x, // img_latent_tokens |
| 221 | + struct ggml_tensor* timesteps, // raw_timesteps |
| 222 | + struct ggml_tensor* context, // txt_embeddings (T5 embeddings) |
| 223 | + struct ggml_tensor* c_concat, // t5_padding_mask |
| 224 | + struct ggml_tensor* y, // pe (positional embeddings) |
| 225 | + struct ggml_tensor* guidance, // raw_guidance |
| 226 | + int num_video_frames = -1, |
| 227 | + std::vector<struct ggml_tensor*> controls = {}, |
| 228 | + float control_strength = 0.f, |
228 | 229 | struct ggml_tensor** output = NULL,
|
229 | 230 | struct ggml_context* output_ctx = NULL,
|
230 | 231 | std::vector<int> skip_layers = std::vector<int>()) {
|
231 |
| - // For Chroma, context is T5 embeddings, c_concat and y are not used. |
232 |
| - // We need to pass positional embeddings (pe) and t5_padding_mask. |
233 |
| - // These are not directly available in the DiffusionModel compute signature. |
234 |
| - // This implies a need to adjust the DiffusionModel interface or how ChromaModel is called. |
235 |
| - |
236 |
| - // For now, let's assume pe and t5_padding_mask are handled internally by ChromaRunner |
237 |
| - // or passed through a different mechanism. |
238 |
| - // Based on the ChromaRunner::build_graph, it expects t5_padding_mask and pe. |
239 |
| - // The current DiffusionModel interface does not provide these. |
240 |
| - |
241 |
| - // This is a design conflict. The current DiffusionModel interface is generic. |
242 |
| - // Chroma has specific inputs (T5 embeddings, T5 padding mask, positional embeddings). |
243 |
| - // FluxModel also has specific inputs (pe, guidance). |
244 |
| - |
245 |
| - // Let's re-evaluate the `compute` method signature for `DiffusionModel`. |
246 |
| - // The `compute` method in `DiffusionModel` is quite generic. |
247 |
| - // `FluxModel` uses `pe` and `guidance`. |
248 |
| - // `ChromaModel` needs `pe` and `t5_padding_mask`. |
249 |
| - |
250 |
| - // The `stable-diffusion.cpp` calls `diffusion_model->compute`. |
251 |
| - // It passes `context`, `c_concat`, `y`, `guidance`. |
252 |
| - |
253 |
| - // For Chroma, `context` is `txt_embeddings`. |
254 |
| - // `c_concat` and `y` are not directly used by Chroma's UNet. |
255 |
| - // `guidance` is also not directly used by Chroma's UNet. |
256 |
| - |
257 |
| - // The `chroma_integration_plan.md` states: |
258 |
| - // `Pass the image latent, the T5 embeddings sequence, the timestep, positional embeddings, and the T5 padding mask to the ChromaUNet_ggml forward function.` |
259 |
| - |
260 |
| - // This means the `ChromaModel::compute` needs to receive `pe` and `t5_padding_mask`. |
261 |
| - // The current `DiffusionModel::compute` signature does not have these. |
262 |
| - |
263 |
| - // I need to modify the `DiffusionModel` interface to include `pe` and `t5_padding_mask`. |
264 |
| - // This will affect all other DiffusionModel implementations (UNetModel, MMDiTModel, FluxModel). |
265 |
| - // This is a larger change than just implementing ChromaModel. |
266 |
| - |
267 |
| - // Let's check the objective again: "Implement the `ChromaModel` class and integrate it into `StableDiffusionGGML`". |
268 |
| - // It doesn't explicitly say to modify the `DiffusionModel` interface. |
269 |
| - |
270 |
| - // Alternative: `ChromaRunner` could generate `pe` and `t5_padding_mask` internally. |
271 |
| - // `pe` generation depends on image dimensions and context length. |
272 |
| - // `t5_padding_mask` generation depends on T5 token IDs. |
273 |
| - |
274 |
| - // The `FluxRunner::build_graph` generates `pe_vec` and then creates `pe` tensor. |
275 |
| - // So, `pe` can be generated inside `ChromaRunner`. |
276 |
| - |
277 |
| - // For `t5_padding_mask`, it needs `token_ids` which come from `T5Embedder`. |
278 |
| - // `T5Embedder` is part of `cond_stage_model`. |
279 |
| - // The `cond_stage_model->get_learned_condition` returns `SDCondition` which contains `c_crossattn` (T5 embeddings). |
280 |
| - // It does not return `token_ids`. |
281 |
| - |
282 |
| - // This means `t5_padding_mask` cannot be generated inside `ChromaRunner` without access to `token_ids`. |
283 |
| - // The `generate_t5_padding_mask_ggml` function needs `token_ids`. |
284 |
| - |
285 |
| - // This implies that `token_ids` (or the `t5_padding_mask` itself) needs to be passed to `ChromaModel::compute`. |
286 |
| - // This means the `DiffusionModel::compute` interface *must* change. |
287 |
| - |
288 |
| - // Let's assume for now that `t5_padding_mask` and `pe` are passed as part of `context` or `c_concat` or `y` |
289 |
| - // or that the `ChromaRunner` can somehow access them. |
290 |
| - // This is a hacky solution. |
291 |
| - |
292 |
| - // Let's look at `stable-diffusion.cpp` where `diffusion_model->compute` is called. |
293 |
| - // It passes `x`, `timesteps`, `cond.c_crossattn`, `cond.c_concat`, `cond.c_vector`, `guidance_tensor`. |
294 |
| - // For Chroma, `cond.c_crossattn` is `txt_embeddings`. |
295 |
| - // `cond.c_concat` and `cond.c_vector` are currently unused for Chroma. |
296 |
| - // I can repurpose `cond.c_concat` for `t5_padding_mask` and `cond.c_vector` for `pe`. |
297 |
| - // This is a bit of a hack, but avoids changing the `DiffusionModel` interface for now. |
298 |
| - |
299 |
| - // Let's assume: |
300 |
| - // `context` (original `c_crossattn`) is `txt_embeddings` |
301 |
| - // `c_concat` is `t5_padding_mask` |
302 |
| - // `y` is `pe` |
303 |
| - |
304 |
| - // This means I need to modify `stable-diffusion.cpp` to pass these correctly. |
305 |
| - // And `ChromaModel::compute` will interpret them as such. |
| 232 | + // Inputs are repurposed for Chroma: context carries the T5 embeddings, c_concat the t5_padding_mask, and y the positional embeddings (pe). |
| 233 | + |
| 234 | + // Construct timestep_for_approximator_input_vec as per Python logic |
| 235 | + int64_t batch_size = x->ne[3]; // Assuming batch size is the last dim of x (img_latent_tokens) |
| 236 | + |
| 237 | + // 1. distill_timestep = timestep_embedding(timesteps, 16) |
| 238 | + // raw_timesteps is expected to be a 1D tensor [N] or [1]. Only element 0 is read, so when N > 1 the remaining timesteps are ignored. |
| 239 | + std::vector<float> current_timestep_val = {ggml_get_f32_1d(timesteps, 0)}; |
| 240 | + struct ggml_tensor* distill_timestep_tensor = ggml_new_tensor_2d(output_ctx, GGML_TYPE_F32, 16, batch_size); |
| 241 | + set_timestep_embedding(current_timestep_val, distill_timestep_tensor, 16); |
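| 242 | + // Permute from [16, batch_size] to [batch_size, 16] to match Python's (batch_size, 16). NOTE(review): ggml lists ne innermost-first, so ne = [16, batch_size] may already correspond to PyTorch's (batch_size, 16) — confirm this permute is needed. |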
| 243 | + distill_timestep_tensor = ggml_cont(output_ctx, ggml_permute(output_ctx, distill_timestep_tensor, 1, 0, 2, 3)); |
| 244 | + |
| 245 | + // 2. distil_guidance = timestep_embedding(guidance, 16) |
| 246 | + std::vector<float> current_guidance_val = {ggml_get_f32_1d(guidance, 0)}; |
| 247 | + struct ggml_tensor* distil_guidance_tensor = ggml_new_tensor_2d(output_ctx, GGML_TYPE_F32, 16, batch_size); |
| 248 | + set_timestep_embedding(current_guidance_val, distil_guidance_tensor, 16); |
| 249 | + // Permute from [16, batch_size] to [batch_size, 16] |
| 250 | + distil_guidance_tensor = ggml_cont(output_ctx, ggml_permute(output_ctx, distil_guidance_tensor, 1, 0, 2, 3)); |
| 251 | + |
| 252 | + // 3. modulation_index = timestep_embedding(torch.arange(mod_index_length), 32) |
| 253 | + // mod_index_length is chroma.chroma_hyperparams.mod_vector_total_indices (344) |
| 254 | + std::vector<float> arange_mod_index(chroma.chroma_hyperparams.mod_vector_total_indices); |
| 255 | + for (int i = 0; i < chroma.chroma_hyperparams.mod_vector_total_indices; ++i) { |
| 256 | + arange_mod_index[i] = (float)i; |
| 257 | + } |
| 258 | + struct ggml_tensor* modulation_index_tensor = ggml_new_tensor_2d(output_ctx, GGML_TYPE_F32, 32, chroma.chroma_hyperparams.mod_vector_total_indices); |
| 259 | + set_timestep_embedding(arange_mod_index, modulation_index_tensor, 32); |
| 260 | + // Permute from [32, mod_index_length] to [mod_index_length, 32] to match Python's (mod_index_length, 32) |
| 261 | + modulation_index_tensor = ggml_cont(output_ctx, ggml_permute(output_ctx, modulation_index_tensor, 1, 0, 2, 3)); |
| 262 | + |
| 263 | + // 4. Broadcast modulation_index: unsqueeze(0).repeat(batch_size, 1, 1) |
| 264 | + // From [mod_index_length, 32] to [batch_size, mod_index_length, 32] |
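| 265 | + // Reshape to [32, mod_index_length, 1] then repeat along batch dimension. NOTE(review): the tensor was just made contiguous as [mod_index_length, 32], and a reshape reinterprets data without transposing — confirm the layout is what the repeat expects. |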
| 266 | + modulation_index_tensor = ggml_reshape_3d(output_ctx, modulation_index_tensor, 32, chroma.chroma_hyperparams.mod_vector_total_indices, 1); |
| 267 | + modulation_index_tensor = ggml_repeat(output_ctx, modulation_index_tensor, ggml_new_tensor_3d(output_ctx, GGML_TYPE_F32, 32, chroma.chroma_hyperparams.mod_vector_total_indices, batch_size)); |
| 268 | + // Permute back to [batch_size, mod_index_length, 32] |
| 269 | + modulation_index_tensor = ggml_cont(output_ctx, ggml_permute(output_ctx, modulation_index_tensor, 2, 1, 0, 3)); |
| 270 | + |
| 271 | + // 5. Concatenate distill_timestep and distil_guidance: torch.cat([distill_timestep, distil_guidance], dim=1) |
| 272 | + // From [batch_size, 16] and [batch_size, 16] to [batch_size, 32] |
| 273 | + struct ggml_tensor* combined_timestep_guidance = ggml_concat(output_ctx, distill_timestep_tensor, distil_guidance_tensor, 1); |
| 274 | + |
| 275 | + // 6. Broadcast combined_timestep_guidance: unsqueeze(1).repeat(1, mod_index_length, 1) |
| 276 | + // From [batch_size, 32] to [batch_size, mod_index_length, 32] |
| 277 | + // Reshape to [32, 1, batch_size] then repeat along mod_index_length dimension |
| 278 | + combined_timestep_guidance = ggml_reshape_3d(output_ctx, combined_timestep_guidance, 32, 1, batch_size); |
| 279 | + combined_timestep_guidance = ggml_repeat(output_ctx, combined_timestep_guidance, ggml_new_tensor_3d(output_ctx, GGML_TYPE_F32, 32, chroma.chroma_hyperparams.mod_vector_total_indices, batch_size)); |
| 280 | + // Permute back to [batch_size, mod_index_length, 32] |
| 281 | + combined_timestep_guidance = ggml_cont(output_ctx, ggml_permute(output_ctx, combined_timestep_guidance, 2, 1, 0, 3)); |
| 282 | + |
| 283 | + // 7. Final concatenation for input_vec: torch.cat([timestep_guidance, modulation_index], dim=-1) |
| 284 | + // From [batch_size, mod_index_length, 32] and [batch_size, mod_index_length, 32] to [batch_size, mod_index_length, 64] |
| 285 | + struct ggml_tensor* constructed_timestep_for_approximator_input_vec = ggml_concat(output_ctx, combined_timestep_guidance, modulation_index_tensor, 2); |
306 | 286 |
|
307 | 287 | chroma.compute(n_threads,
|
308 |
| - x, |
309 |
| - timesteps, |
310 |
| - context, // T5 embeddings |
| 288 | + x, // img_latent_tokens |
| 289 | + constructed_timestep_for_approximator_input_vec, // constructed input_vec |
| 290 | + context, // txt_tokens (T5 embeddings) |
| 291 | + y, // pe (positional embeddings) |
311 | 292 | c_concat, // t5_padding_mask
|
312 |
| - y, // pe |
313 | 293 | output,
|
314 | 294 | output_ctx,
|
315 | 295 | skip_layers);
|
|
0 commit comments