
Commit 2eea43c

feat: temp
1 parent 9566897 commit 2eea43c

File tree

1 file changed: +68 -88 lines


diffusion_model.hpp

Lines changed: 68 additions & 88 deletions
@@ -5,6 +5,7 @@
 #include "mmdit.hpp"
 #include "unet.hpp"
 #include "chroma.hpp"
+#include "ggml_extend.hpp"  // Required for set_timestep_embedding
 
 struct DiffusionModel {
     virtual void compute(int n_threads,
@@ -216,100 +217,79 @@ struct ChromaModel : public DiffusionModel {
     }
 
     void compute(int n_threads,
-                 struct ggml_tensor* x,
-                 struct ggml_tensor* timesteps,
-                 struct ggml_tensor* context,                     // T5 embeddings
-                 struct ggml_tensor* c_concat,                    // Not used by Chroma
-                 struct ggml_tensor* y,                           // Not used by Chroma
-                 struct ggml_tensor* guidance,                    // Not used by Chroma
-                 int num_video_frames = -1,                       // Not used by Chroma
-                 std::vector<struct ggml_tensor*> controls = {},  // Not used by Chroma
-                 float control_strength = 0.f,                    // Not used by Chroma
+                 struct ggml_tensor* x,          // img_latent_tokens
+                 struct ggml_tensor* timesteps,  // raw_timesteps
+                 struct ggml_tensor* context,    // txt_embeddings (T5 embeddings)
+                 struct ggml_tensor* c_concat,   // t5_padding_mask
+                 struct ggml_tensor* y,          // pe (positional embeddings)
+                 struct ggml_tensor* guidance,   // raw_guidance
+                 int num_video_frames = -1,
+                 std::vector<struct ggml_tensor*> controls = {},
+                 float control_strength = 0.f,
                  struct ggml_tensor** output = NULL,
                  struct ggml_context* output_ctx = NULL,
                  std::vector<int> skip_layers = std::vector<int>()) {
-        // For Chroma, context is T5 embeddings, c_concat and y are not used.
-        // We need to pass positional embeddings (pe) and t5_padding_mask.
-        // These are not directly available in the DiffusionModel compute signature.
-        // This implies a need to adjust the DiffusionModel interface or how ChromaModel is called.
-
-        // For now, let's assume pe and t5_padding_mask are handled internally by ChromaRunner
-        // or passed through a different mechanism.
-        // Based on the ChromaRunner::build_graph, it expects t5_padding_mask and pe.
-        // The current DiffusionModel interface does not provide these.
-
-        // This is a design conflict. The current DiffusionModel interface is generic.
-        // Chroma has specific inputs (T5 embeddings, T5 padding mask, positional embeddings).
-        // FluxModel also has specific inputs (pe, guidance).
-
-        // Let's re-evaluate the `compute` method signature for `DiffusionModel`.
-        // The `compute` method in `DiffusionModel` is quite generic.
-        // `FluxModel` uses `pe` and `guidance`.
-        // `ChromaModel` needs `pe` and `t5_padding_mask`.
-
-        // The `stable-diffusion.cpp` calls `diffusion_model->compute`.
-        // It passes `context`, `c_concat`, `y`, `guidance`.
-
-        // For Chroma, `context` is `txt_embeddings`.
-        // `c_concat` and `y` are not directly used by Chroma's UNet.
-        // `guidance` is also not directly used by Chroma's UNet.
-
-        // The `chroma_integration_plan.md` states:
-        // `Pass the image latent, the T5 embeddings sequence, the timestep, positional embeddings, and the T5 padding mask to the ChromaUNet_ggml forward function.`
-
-        // This means the `ChromaModel::compute` needs to receive `pe` and `t5_padding_mask`.
-        // The current `DiffusionModel::compute` signature does not have these.
-
-        // I need to modify the `DiffusionModel` interface to include `pe` and `t5_padding_mask`.
-        // This will affect all other DiffusionModel implementations (UNetModel, MMDiTModel, FluxModel).
-        // This is a larger change than just implementing ChromaModel.
-
-        // Let's check the objective again: "Implement the `ChromaModel` class and integrate it into `StableDiffusionGGML`".
-        // It doesn't explicitly say to modify the `DiffusionModel` interface.
-
-        // Alternative: `ChromaRunner` could generate `pe` and `t5_padding_mask` internally.
-        // `pe` generation depends on image dimensions and context length.
-        // `t5_padding_mask` generation depends on T5 token IDs.
-
-        // The `FluxRunner::build_graph` generates `pe_vec` and then creates `pe` tensor.
-        // So, `pe` can be generated inside `ChromaRunner`.
-
-        // For `t5_padding_mask`, it needs `token_ids` which come from `T5Embedder`.
-        // `T5Embedder` is part of `cond_stage_model`.
-        // The `cond_stage_model->get_learned_condition` returns `SDCondition` which contains `c_crossattn` (T5 embeddings).
-        // It does not return `token_ids`.
-
-        // This means `t5_padding_mask` cannot be generated inside `ChromaRunner` without access to `token_ids`.
-        // The `generate_t5_padding_mask_ggml` function needs `token_ids`.
-
-        // This implies that `token_ids` (or the `t5_padding_mask` itself) needs to be passed to `ChromaModel::compute`.
-        // This means the `DiffusionModel::compute` interface *must* change.
-
-        // Let's assume for now that `t5_padding_mask` and `pe` are passed as part of `context` or `c_concat` or `y`
-        // or that the `ChromaRunner` can somehow access them.
-        // This is a hacky solution.
-
-        // Let's look at `stable-diffusion.cpp` where `diffusion_model->compute` is called.
-        // It passes `x`, `timesteps`, `cond.c_crossattn`, `cond.c_concat`, `cond.c_vector`, `guidance_tensor`.
-        // For Chroma, `cond.c_crossattn` is `txt_embeddings`.
-        // `cond.c_concat` and `cond.c_vector` are currently unused for Chroma.
-        // I can repurpose `cond.c_concat` for `t5_padding_mask` and `cond.c_vector` for `pe`.
-        // This is a bit of a hack, but avoids changing the `DiffusionModel` interface for now.
-
-        // Let's assume:
-        // `context` (original `c_crossattn`) is `txt_embeddings`
-        // `c_concat` is `t5_padding_mask`
-        // `y` is `pe`
-
-        // This means I need to modify `stable-diffusion.cpp` to pass these correctly.
-        // And `ChromaModel::compute` will interpret them as such.
+        // ... existing comments about repurposing inputs ...
+
+        // Construct timestep_for_approximator_input_vec as per Python logic
+        int64_t batch_size = x->ne[3];  // Assuming batch size is the last dim of x (img_latent_tokens)
+
+        // 1. distill_timestep = timestep_embedding(timesteps, 16)
+        // raw_timesteps is expected to be a 1D tensor [N] or [1]. Extract the single float value.
+        std::vector<float> current_timestep_val = {ggml_get_f32_1d(timesteps, 0)};
+        struct ggml_tensor* distill_timestep_tensor = ggml_new_tensor_2d(output_ctx, GGML_TYPE_F32, 16, batch_size);
+        set_timestep_embedding(current_timestep_val, distill_timestep_tensor, 16);
+        // Permute from [16, batch_size] to [batch_size, 16] to match Python's (batch_size, 16)
+        distill_timestep_tensor = ggml_cont(output_ctx, ggml_permute(output_ctx, distill_timestep_tensor, 1, 0, 2, 3));
+
+        // 2. distil_guidance = timestep_embedding(guidance, 16)
+        std::vector<float> current_guidance_val = {ggml_get_f32_1d(guidance, 0)};
+        struct ggml_tensor* distil_guidance_tensor = ggml_new_tensor_2d(output_ctx, GGML_TYPE_F32, 16, batch_size);
+        set_timestep_embedding(current_guidance_val, distil_guidance_tensor, 16);
+        // Permute from [16, batch_size] to [batch_size, 16]
+        distil_guidance_tensor = ggml_cont(output_ctx, ggml_permute(output_ctx, distil_guidance_tensor, 1, 0, 2, 3));
+
+        // 3. modulation_index = timestep_embedding(torch.arange(mod_index_length), 32)
+        // mod_index_length is chroma.chroma_hyperparams.mod_vector_total_indices (344)
+        std::vector<float> arange_mod_index(chroma.chroma_hyperparams.mod_vector_total_indices);
+        for (int i = 0; i < chroma.chroma_hyperparams.mod_vector_total_indices; ++i) {
+            arange_mod_index[i] = (float)i;
+        }
+        struct ggml_tensor* modulation_index_tensor = ggml_new_tensor_2d(output_ctx, GGML_TYPE_F32, 32, chroma.chroma_hyperparams.mod_vector_total_indices);
+        set_timestep_embedding(arange_mod_index, modulation_index_tensor, 32);
+        // Permute from [32, mod_index_length] to [mod_index_length, 32] to match Python's (mod_index_length, 32)
+        modulation_index_tensor = ggml_cont(output_ctx, ggml_permute(output_ctx, modulation_index_tensor, 1, 0, 2, 3));
+
+        // 4. Broadcast modulation_index: unsqueeze(0).repeat(batch_size, 1, 1)
+        // From [mod_index_length, 32] to [batch_size, mod_index_length, 32]
+        // Reshape to [32, mod_index_length, 1] then repeat along batch dimension
+        modulation_index_tensor = ggml_reshape_3d(output_ctx, modulation_index_tensor, 32, chroma.chroma_hyperparams.mod_vector_total_indices, 1);
+        modulation_index_tensor = ggml_repeat(output_ctx, modulation_index_tensor, ggml_new_tensor_3d(output_ctx, GGML_TYPE_F32, 32, chroma.chroma_hyperparams.mod_vector_total_indices, batch_size));
+        // Permute back to [batch_size, mod_index_length, 32]
+        modulation_index_tensor = ggml_cont(output_ctx, ggml_permute(output_ctx, modulation_index_tensor, 2, 1, 0, 3));
+
+        // 5. Concatenate distill_timestep and distil_guidance: torch.cat([distill_timestep, distil_guidance], dim=1)
+        // From [batch_size, 16] and [batch_size, 16] to [batch_size, 32]
+        struct ggml_tensor* combined_timestep_guidance = ggml_concat(output_ctx, distill_timestep_tensor, distil_guidance_tensor, 1);
+
+        // 6. Broadcast combined_timestep_guidance: unsqueeze(1).repeat(1, mod_index_length, 1)
+        // From [batch_size, 32] to [batch_size, mod_index_length, 32]
+        // Reshape to [32, 1, batch_size] then repeat along mod_index_length dimension
+        combined_timestep_guidance = ggml_reshape_3d(output_ctx, combined_timestep_guidance, 32, 1, batch_size);
+        combined_timestep_guidance = ggml_repeat(output_ctx, combined_timestep_guidance, ggml_new_tensor_3d(output_ctx, GGML_TYPE_F32, 32, chroma.chroma_hyperparams.mod_vector_total_indices, batch_size));
+        // Permute back to [batch_size, mod_index_length, 32]
+        combined_timestep_guidance = ggml_cont(output_ctx, ggml_permute(output_ctx, combined_timestep_guidance, 2, 1, 0, 3));
+
+        // 7. Final concatenation for input_vec: torch.cat([timestep_guidance, modulation_index], dim=-1)
+        // From [batch_size, mod_index_length, 32] and [batch_size, mod_index_length, 32] to [batch_size, mod_index_length, 64]
+        struct ggml_tensor* constructed_timestep_for_approximator_input_vec = ggml_concat(output_ctx, combined_timestep_guidance, modulation_index_tensor, 2);
 
         chroma.compute(n_threads,
-                       x,
-                       timesteps,
-                       context,   // T5 embeddings
+                       x,         // img_latent_tokens
+                       constructed_timestep_for_approximator_input_vec,  // constructed input_vec
+                       context,   // txt_tokens (T5 embeddings)
+                       y,         // pe (positional embeddings)
                        c_concat,  // t5_padding_mask
-                       y,         // pe
                        output,
                        output_ctx,
                        skip_layers);
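
Note: the added block mirrors, step by step, the Python-side construction of the approximator input vector described in its comments. Below is a minimal standalone sketch of the intended shape flow. It is an assumption-laden illustration, not the project's implementation: it uses plain std::vector instead of ggml, assumes the usual cos/sin sinusoidal timestep embedding (the exact numerics of set_timestep_embedding in ggml_extend.hpp may differ), and hard-codes mod_vector_total_indices = 344 as stated in the comments.

#include <cmath>
#include <cstdio>
#include <vector>

// Sinusoidal embedding of a single scalar into `dim` values
// (cos half followed by sin half, max_period 10000) -- assumed convention.
static std::vector<float> timestep_embedding(float t, int dim, float max_period = 10000.f) {
    int half = dim / 2;
    std::vector<float> emb(dim);
    for (int i = 0; i < half; ++i) {
        float freq    = std::exp(-std::log(max_period) * (float)i / (float)half);
        emb[i]        = std::cos(t * freq);
        emb[half + i] = std::sin(t * freq);
    }
    return emb;
}

int main() {
    const int batch_size       = 1;
    const int mod_index_length = 344;  // chroma_hyperparams.mod_vector_total_indices

    const float timestep = 1.f;  // placeholder raw timestep
    const float guidance = 0.f;  // placeholder raw guidance

    // Steps 1/2/5: embed timestep and guidance at dim 16, concat -> 32 values per batch element.
    std::vector<float> ts_g = timestep_embedding(timestep, 16);
    std::vector<float> g    = timestep_embedding(guidance, 16);
    ts_g.insert(ts_g.end(), g.begin(), g.end());

    // Steps 3/4/6/7: embed each modulation index at dim 32, broadcast the
    // timestep+guidance vector over all indices, concat on the last axis.
    // Result: [batch_size, mod_index_length, 64], stored flat.
    std::vector<float> input_vec;
    input_vec.reserve((size_t)batch_size * mod_index_length * 64);
    for (int b = 0; b < batch_size; ++b) {
        for (int idx = 0; idx < mod_index_length; ++idx) {
            std::vector<float> mod = timestep_embedding((float)idx, 32);
            input_vec.insert(input_vec.end(), ts_g.begin(), ts_g.end());  // 32 values
            input_vec.insert(input_vec.end(), mod.begin(), mod.end());    // 32 values
        }
    }

    printf("input_vec shape: [%d, %d, 64] -> %zu floats\n",
           batch_size, mod_index_length, input_vec.size());
    return 0;
}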

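On the repurposed arguments: the deleted design comments settle on keeping the generic DiffusionModel::compute signature and letting the Chroma path reinterpret two of its slots (c_concat as the T5 padding mask, y as the positional embeddings). The following is a toy caller-side sketch of that mapping only; it is not the actual stable-diffusion.cpp call site, and every tensor shape in it is an illustrative placeholder.

#include <cstdio>
#include "ggml.h"

// Stand-in with the same conditioning slots as DiffusionModel::compute.
static void chroma_slots(struct ggml_tensor* context,    // cond.c_crossattn -> txt_embeddings
                         struct ggml_tensor* c_concat,   // repurposed       -> t5_padding_mask
                         struct ggml_tensor* y) {        // repurposed       -> pe
    printf("txt_embeddings : [%lld, %lld]\n", (long long)context->ne[0], (long long)context->ne[1]);
    printf("t5_padding_mask: [%lld]\n", (long long)c_concat->ne[0]);
    printf("pe             : [%lld, %lld]\n", (long long)y->ne[0], (long long)y->ne[1]);
}

int main() {
    struct ggml_init_params params = {/*mem_size*/ 16 * 1024 * 1024, /*mem_buffer*/ NULL, /*no_alloc*/ false};
    struct ggml_context* ctx = ggml_init(params);

    // Illustrative shapes only (token count and hidden size are placeholders).
    struct ggml_tensor* txt_embeddings  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 256);
    struct ggml_tensor* t5_padding_mask = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 256);
    struct ggml_tensor* pe              = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 512);

    chroma_slots(txt_embeddings, t5_padding_mask, pe);

    ggml_free(ctx);
    return 0;
}
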
0 commit comments
