Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 9b1d90b

Browse files
authored
fix: improve clip text_projection support (leejet#397)
1 parent 65fa646 commit 9b1d90b

File tree

2 files changed

+40
-47
lines changed

2 files changed

+40
-47
lines changed

clip.hpp

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -711,8 +711,12 @@ class CLIPTextModel : public GGMLBlock {
711711
if (return_pooled) {
712712
auto text_projection = params["text_projection"];
713713
ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx);
714-
pooled = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, text_projection)), pooled);
715-
return pooled;
714+
if (text_projection != NULL) {
715+
pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL);
716+
} else {
717+
LOG_DEBUG("Missing text_projection matrix, assuming identity...");
718+
}
719+
return pooled; // [hidden_size, 1, 1]
716720
}
717721

718722
return x; // [N, n_token, hidden_size]

conditioner.hpp

Lines changed: 34 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -798,21 +798,17 @@ struct SD3CLIPEmbedder : public Conditioner {
798798
}
799799

800800
if (chunk_idx == 0) {
801-
// auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
802-
// max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
803-
// clip_l->compute(n_threads,
804-
// input_ids,
805-
// 0,
806-
// NULL,
807-
// max_token_idx,
808-
// true,
809-
// &pooled_l,
810-
// work_ctx);
811-
812-
// clip_l.transformer.text_model.text_projection no in file, ignore
813-
// TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection
814-
pooled_l = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
815-
ggml_set_f32(pooled_l, 0.f);
801+
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
802+
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
803+
clip_l->compute(n_threads,
804+
input_ids,
805+
0,
806+
NULL,
807+
max_token_idx,
808+
true,
809+
&pooled_l,
810+
work_ctx);
811+
816812
}
817813
}
818814

@@ -852,21 +848,17 @@ struct SD3CLIPEmbedder : public Conditioner {
852848
}
853849

854850
if (chunk_idx == 0) {
855-
// auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
856-
// max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
857-
// clip_g->compute(n_threads,
858-
// input_ids,
859-
// 0,
860-
// NULL,
861-
// max_token_idx,
862-
// true,
863-
// &pooled_g,
864-
// work_ctx);
865-
// clip_l.transformer.text_model.text_projection no in file, ignore pooled_g too
866-
867-
// TODO: fix pooled_g
868-
pooled_g = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1280);
869-
ggml_set_f32(pooled_g, 0.f);
851+
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID);
852+
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
853+
clip_g->compute(n_threads,
854+
input_ids,
855+
0,
856+
NULL,
857+
max_token_idx,
858+
true,
859+
&pooled_g,
860+
work_ctx);
861+
870862
}
871863
}
872864

@@ -1104,21 +1096,18 @@ struct FluxCLIPEmbedder : public Conditioner {
11041096
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
11051097
size_t max_token_idx = 0;
11061098

1107-
// auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
1108-
// max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
1109-
// clip_l->compute(n_threads,
1110-
// input_ids,
1111-
// 0,
1112-
// NULL,
1113-
// max_token_idx,
1114-
// true,
1115-
// &pooled,
1116-
// work_ctx);
1117-
1118-
// clip_l.transformer.text_model.text_projection no in file, ignore
1119-
// TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection
1120-
pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
1121-
ggml_set_f32(pooled, 0.f);
1099+
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
1100+
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
1101+
1102+
clip_l->compute(n_threads,
1103+
input_ids,
1104+
0,
1105+
NULL,
1106+
max_token_idx,
1107+
true,
1108+
&pooled,
1109+
work_ctx);
1110+
11221111
}
11231112

11241113
// t5

0 commit comments

Comments (0)