@@ -181,13 +181,9 @@ struct Approximator_ggml : public UnaryBlock {
181
181
blocks[" layers." + std::to_string (i)] = std::shared_ptr<GGMLBlock>(new MLPEmbedder (hidden_dim, hidden_dim));
182
182
blocks[" norms." + std::to_string (i)] = std::shared_ptr<GGMLBlock>(new RMSNorm (hidden_dim));
183
183
}
184
- blocks[" out_proj" ] = std::shared_ptr<GGMLBlock>(new Linear (hidden_dim, out_dim));
184
+ blocks[" out_proj" ] = std::shared_ptr<GGMLBlock>(new Linear (hidden_dim , out_dim));
185
185
}
186
186
187
- void init_params (struct ggml_context * ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = " " ) override {
188
- // Rely on the base class to initialize nested blocks
189
- UnaryBlock::init_params (ctx, tensor_types, prefix);
190
- }
191
187
192
188
struct ggml_tensor * forward (struct ggml_context * ctx, struct ggml_tensor * timestep) {
193
189
// Implement forward pass based on the pseudo-code in the plan (Phase 2.2)
@@ -243,10 +239,6 @@ struct QKNorm : public GGMLBlock {
243
239
return norm->forward (ctx, x);
244
240
}
245
241
246
- void init_params (struct ggml_context * ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = " " ) override {
247
- // Rely on the base class to initialize nested blocks
248
- GGMLBlock::init_params (ctx, tensor_types, prefix);
249
- }
250
242
};
251
243
252
244
// Based on the plan (Phase 1.2 and 2.5)
@@ -262,12 +254,6 @@ struct SelfAttention : public GGMLBlock {
262
254
blocks[" norm" ] = std::shared_ptr<GGMLBlock>(new QKNorm (head_dim));
263
255
blocks[" proj" ] = std::shared_ptr<GGMLBlock>(new Linear (dim, dim));
264
256
}
265
-
266
- void init_params (struct ggml_context * ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = " " ) override {
267
- // Rely on the base class to initialize nested blocks
268
- GGMLBlock::init_params (ctx, tensor_types, prefix);
269
- }
270
-
271
257
std::vector<struct ggml_tensor *> pre_attention (struct ggml_context * ctx, struct ggml_tensor * x) {
272
258
auto qkv_proj = std::dynamic_pointer_cast<Linear>(blocks[" qkv" ]);
273
259
auto norm = std::dynamic_pointer_cast<QKNorm>(blocks[" norm" ]);
@@ -288,7 +274,7 @@ struct SelfAttention : public GGMLBlock {
288
274
auto v = ggml_view_3d (ctx, qkv_split, dim, L, N, qkv_split->nb [1 ], qkv_split->nb [2 ], offset * 2 ); // [dim, L, N]
289
275
290
276
// Reshape q, k, v for QKNorm and ggml_nn_attention_ext
291
- // QKNorm expects [..., dim], ggml_nn_attention_ext expects [d_head, L, N*n_head] for q, k and [d_head, L, n_head, N] for v
277
+ // QKNorm expects [..., dim], ggml_nn_attention_ext expects [d_head, L, N*n_head] or similar depending on implementation
292
278
// Reshape q and k to [d_head, L, N*n_head], and v to [d_head, L, n_head, N].
293
279
294
280
auto q_reshaped = ggml_reshape_3d (ctx, q, head_dim, L, N * num_heads); // [dim, L, N] -> [d_head, L, N*n_head]
@@ -470,10 +456,6 @@ struct SingleStreamBlock_ggml : public GGMLBlock {
470
456
// Modulation block is created and called in the forward pass based on the plan
471
457
}
472
458
473
- void init_params (struct ggml_context * ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = " " ) override {
474
- GGMLBlock::init_params (ctx, tensor_types, prefix);
475
- }
476
-
477
459
struct ggml_tensor * forward (struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * pe, struct ModulationOut & mod, struct ggml_tensor * attn_mask = NULL ) {
478
460
// x: [N, L, hidden_size]
479
461
// pe: Positional embeddings (shape TBD)
@@ -573,10 +555,6 @@ struct DoubleStreamBlock_ggml : public GGMLBlock { // DoubleStreamBlock forward
573
555
blocks[" txt_mlp.2" ] = std::shared_ptr<GGMLBlock>(new Linear (mlp_hidden_dim, hidden_size));
574
556
}
575
557
576
- void init_params (struct ggml_context * ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = " " ) override {
577
- GGMLBlock::init_params (ctx, tensor_types, prefix);
578
- }
579
-
580
558
std::pair<struct ggml_tensor *, struct ggml_tensor *> forward (struct ggml_context * ctx, struct ggml_tensor * img, struct ggml_tensor * txt, struct ggml_tensor * pe, const std::vector<ModulationOut>& img_mods, const std::vector<ModulationOut>& txt_mods, struct ggml_tensor * attn_mask = NULL ) {
581
559
auto img_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks[" img_norm1" ]);
582
560
auto img_attn = std::dynamic_pointer_cast<SelfAttention>(blocks[" img_attn" ]);
@@ -665,9 +643,6 @@ struct LastLayer_ggml : public GGMLBlock {
665
643
}
666
644
667
645
668
- void init_params (struct ggml_context * ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = " " ) override {
669
- GGMLBlock::init_params (ctx, tensor_types, prefix);
670
- }
671
646
672
647
struct ggml_tensor * forward (struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * shift, struct ggml_tensor * scale) {
673
648
auto norm_final = std::dynamic_pointer_cast<LayerNorm>(blocks[" norm_final" ]);
@@ -718,7 +693,7 @@ struct ChromaUNet_ggml : public GGMLBlock {
718
693
: hidden_size(hidden_size), num_heads(num_heads), mlp_ratio(mlp_ratio),
719
694
depth (depth), single_depth(single_depth), in_channels(in_channels),
720
695
out_channels(out_channels), flash_attn(flash_attn) {
721
- blocks[" approximator " ] = std::shared_ptr<GGMLBlock>(new Approximator_ggml (1 , hidden_size * 6 , hidden_size)); // out_dim is hidden_size * 6 for 2 sets of scale/shift/gate
696
+ blocks[" distilled_guidance_layer " ] = std::shared_ptr<GGMLBlock>(new Approximator_ggml (64 , hidden_size * 6 , hidden_size)); // out_dim is hidden_size * 6 for 2 sets of scale/shift/gate
722
697
723
698
blocks[" img_in" ] = std::shared_ptr<GGMLBlock>(new Linear (in_channels, hidden_size, true ));
724
699
blocks[" txt_in" ] = std::shared_ptr<GGMLBlock>(new Linear (hidden_size, hidden_size, true )); // T5 embeddings are already hidden_size
@@ -734,9 +709,6 @@ struct ChromaUNet_ggml : public GGMLBlock {
734
709
blocks[" final_layer" ] = std::shared_ptr<GGMLBlock>(new LastLayer_ggml (hidden_size, 1 , out_channels)); // patch_size is 1 for Chroma
735
710
}
736
711
737
- void init_params (struct ggml_context * ctx, std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix = " " ) override {
738
- GGMLBlock::init_params (ctx, tensor_types, prefix);
739
- }
740
712
741
713
struct ggml_tensor * forward (struct ggml_context * ctx,
742
714
struct ggml_tensor * img_latent,
@@ -745,7 +717,7 @@ struct ChromaUNet_ggml : public GGMLBlock {
745
717
struct ggml_tensor * pe,
746
718
struct ggml_tensor * t5_padding_mask,
747
719
std::vector<int > skip_layers = std::vector<int >()) {
748
- auto approximator = std::dynamic_pointer_cast<Approximator_ggml>(blocks[" approximator " ]);
720
+ auto approximator = std::dynamic_pointer_cast<Approximator_ggml>(blocks[" distilled_guidance_layer " ]);
749
721
auto img_in = std::dynamic_pointer_cast<Linear>(blocks[" img_in" ]);
750
722
auto txt_in = std::dynamic_pointer_cast<Linear>(blocks[" txt_in" ]);
751
723
auto final_layer = std::dynamic_pointer_cast<LastLayer_ggml>(blocks[" final_layer" ]);
@@ -804,10 +776,6 @@ struct ChromaUNet_ggml : public GGMLBlock {
804
776
// This means approx_output needs to be split into one set of (scale, shift, gate).
805
777
// This is inconsistent with the DoubleStreamBlock.
806
778
807
- // Let's assume the `vec` for SingleStreamBlock is also from the Approximator.
808
- // And we need a separate Modulation instance for it.
809
- // This means the Approximator output needs to be even larger.
810
-
811
779
// Re-reading the plan:
812
780
// Approximator output: `conditioning signal (Tensor of shape [batch_size, 1, out_dim])`, likely split into scale, shift, gate.
813
781
// SingleStreamBlock: `mod = vec`
@@ -817,19 +785,6 @@ struct ChromaUNet_ggml : public GGMLBlock {
817
785
// Let's assume `approx_output` is the `vec` and it contains all necessary modulation parameters.
818
786
// And the `Modulation` class will be used to extract them.
819
787
820
- // Let's create a single Modulation instance for SingleStreamBlock.
821
- // And pass the relevant part of approx_output to it.
822
- // This means approx_output needs to be split into one set of (scale, shift, gate).
823
-
824
- // This is getting complicated. Let's simplify.
825
- // The `FluxModel` passes a single `vec` to all blocks.
826
- // Let's assume `approx_output` is the single `vec` for all blocks.
827
- // And each block will extract what it needs.
828
-
829
- // For SingleStreamBlock, it needs one ModulationOut.
830
- // So, we need to extract one ModulationOut from `approx_output`.
831
- // This means `approx_output` should contain at least `hidden_size * 3`.
832
-
833
788
// Let's assume `approx_output` is `hidden_size * 6` (for 2 sets of modulations).
834
789
// And `SingleStreamBlock` will use the first set.
835
790
@@ -909,7 +864,8 @@ struct ChromaRunner : public GGMLRunner {
909
864
ChromaRunner (
910
865
ggml_backend_t backend,
911
866
std::map<std::string, enum ggml_type>& tensor_types,
912
- bool flash_attn
867
+ const std::string prefix = " " ,
868
+ bool flash_attn = false
913
869
) :
914
870
GGMLRunner (backend),
915
871
chroma_params ({
@@ -921,7 +877,7 @@ struct ChromaRunner : public GGMLRunner {
921
877
chroma_params.in_channels, chroma_params.out_channels, chroma_params.flash_attn
922
878
)
923
879
{
924
- chroma_unet.init_params (params_ctx, tensor_types);
880
+ chroma_unet.init (params_ctx, tensor_types,prefix );
925
881
}
926
882
927
883
std::string get_desc () override {
0 commit comments