Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 8079703

Browse files
committed
conditionner: make text encoders optional for Flux
1 parent c13cf04 commit 8079703

File tree

2 files changed

+115
-50
lines changed

2 files changed

+115
-50
lines changed

conditioner.hpp

+114-50
Original file line numberDiff line numberDiff line change
@@ -1096,38 +1096,82 @@ struct FluxCLIPEmbedder : public Conditioner {
10961096
std::shared_ptr<CLIPTextModelRunner> clip_l;
10971097
std::shared_ptr<T5Runner> t5;
10981098

1099+
bool use_clip_l = false;
1100+
bool use_t5 = false;
1101+
10991102
FluxCLIPEmbedder(ggml_backend_t backend,
11001103
std::map<std::string, enum ggml_type>& tensor_types,
11011104
int clip_skip = -1) {
11021105
if (clip_skip <= 0) {
11031106
clip_skip = 2;
11041107
}
1105-
clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, true);
1106-
t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
1108+
1109+
for (auto pair : tensor_types) {
1110+
if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
1111+
use_clip_l = true;
1112+
} else if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
1113+
use_t5 = true;
1114+
}
1115+
}
1116+
1117+
if (!use_clip_l && !use_t5) {
1118+
LOG_WARN("IMPORTANT NOTICE: No text encoders provided, cannot process prompts!");
1119+
return;
1120+
}
1121+
1122+
if (use_clip_l) {
1123+
clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, true);
1124+
} else {
1125+
LOG_WARN("clip_l text encoder not found! Prompt adherence might be degraded.");
1126+
}
1127+
if (use_t5) {
1128+
t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
1129+
} else {
1130+
LOG_WARN("t5xxl text encoder not found! Prompt adherence might be degraded.");
1131+
}
11071132
}
11081133

11091134
void set_clip_skip(int clip_skip) {
1110-
clip_l->set_clip_skip(clip_skip);
1135+
if (use_clip_l) {
1136+
clip_l->set_clip_skip(clip_skip);
1137+
}
11111138
}
11121139

11131140
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
1114-
clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
1115-
t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
1141+
if (use_clip_l) {
1142+
clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model");
1143+
}
1144+
if (use_t5) {
1145+
t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
1146+
}
11161147
}
11171148

11181149
void alloc_params_buffer() {
1119-
clip_l->alloc_params_buffer();
1120-
t5->alloc_params_buffer();
1150+
if (use_clip_l) {
1151+
clip_l->alloc_params_buffer();
1152+
}
1153+
if (use_t5) {
1154+
t5->alloc_params_buffer();
1155+
}
11211156
}
11221157

11231158
void free_params_buffer() {
1124-
clip_l->free_params_buffer();
1125-
t5->free_params_buffer();
1159+
if (use_clip_l) {
1160+
clip_l->free_params_buffer();
1161+
}
1162+
if (use_t5) {
1163+
t5->free_params_buffer();
1164+
}
11261165
}
11271166

11281167
size_t get_params_buffer_size() {
1129-
size_t buffer_size = clip_l->get_params_buffer_size();
1130-
buffer_size += t5->get_params_buffer_size();
1168+
size_t buffer_size = 0;
1169+
if (use_clip_l) {
1170+
buffer_size += clip_l->get_params_buffer_size();
1171+
}
1172+
if (use_t5) {
1173+
buffer_size += t5->get_params_buffer_size();
1174+
}
11311175
return buffer_size;
11321176
}
11331177

@@ -1157,18 +1201,23 @@ struct FluxCLIPEmbedder : public Conditioner {
11571201
for (const auto& item : parsed_attention) {
11581202
const std::string& curr_text = item.first;
11591203
float curr_weight = item.second;
1160-
1161-
std::vector<int> curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
1162-
clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
1163-
clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
1164-
1165-
curr_tokens = t5_tokenizer.Encode(curr_text, true);
1166-
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
1167-
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
1204+
if (use_clip_l) {
1205+
std::vector<int> curr_tokens = clip_l_tokenizer.encode(curr_text, on_new_token_cb);
1206+
clip_l_tokens.insert(clip_l_tokens.end(), curr_tokens.begin(), curr_tokens.end());
1207+
clip_l_weights.insert(clip_l_weights.end(), curr_tokens.size(), curr_weight);
1208+
}
1209+
if (use_t5) {
1210+
std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
1211+
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
1212+
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
1213+
}
1214+
}
1215+
if (use_clip_l) {
1216+
clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding);
1217+
}
1218+
if (use_t5) {
1219+
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding);
11681220
}
1169-
1170-
clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding);
1171-
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding);
11721221

11731222
// for (int i = 0; i < clip_l_tokens.size(); i++) {
11741223
// std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@@ -1201,34 +1250,36 @@ struct FluxCLIPEmbedder : public Conditioner {
12011250
std::vector<float> hidden_states_vec;
12021251

12031252
size_t chunk_len = 256;
1204-
size_t chunk_count = t5_tokens.size() / chunk_len;
1253+
size_t chunk_count = std::max(clip_l_tokens.size() > 0 ? chunk_len : 0, t5_tokens.size()) / chunk_len;
12051254
for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
12061255
// clip_l
12071256
if (chunk_idx == 0) {
1208-
size_t chunk_len_l = 77;
1209-
std::vector<int> chunk_tokens(clip_l_tokens.begin(),
1210-
clip_l_tokens.begin() + chunk_len_l);
1211-
std::vector<float> chunk_weights(clip_l_weights.begin(),
1212-
clip_l_weights.begin() + chunk_len_l);
1257+
if (use_clip_l) {
1258+
size_t chunk_len_l = 77;
1259+
std::vector<int> chunk_tokens(clip_l_tokens.begin(),
1260+
clip_l_tokens.begin() + chunk_len_l);
1261+
std::vector<float> chunk_weights(clip_l_weights.begin(),
1262+
clip_l_weights.begin() + chunk_len_l);
12131263

1214-
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
1215-
size_t max_token_idx = 0;
1264+
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
1265+
size_t max_token_idx = 0;
12161266

1217-
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
1218-
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
1267+
auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID);
1268+
max_token_idx = std::min<size_t>(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1);
12191269

1220-
clip_l->compute(n_threads,
1221-
input_ids,
1222-
0,
1223-
NULL,
1224-
max_token_idx,
1225-
true,
1226-
&pooled,
1227-
work_ctx);
1270+
clip_l->compute(n_threads,
1271+
input_ids,
1272+
0,
1273+
NULL,
1274+
max_token_idx,
1275+
true,
1276+
&pooled,
1277+
work_ctx);
1278+
}
12281279
}
12291280

12301281
// t5
1231-
{
1282+
if (use_t5) {
12321283
std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
12331284
t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
12341285
std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
@@ -1255,8 +1306,12 @@ struct FluxCLIPEmbedder : public Conditioner {
12551306
float new_mean = ggml_tensor_mean(tensor);
12561307
ggml_tensor_scale(tensor, (original_mean / new_mean));
12571308
}
1309+
} else {
1310+
chunk_hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, chunk_len);
1311+
ggml_set_f32(chunk_hidden_states, 0.f);
12581312
}
12591313

1314+
12601315
int64_t t1 = ggml_time_ms();
12611316
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
12621317
if (force_zero_embeddings) {
@@ -1265,17 +1320,26 @@ struct FluxCLIPEmbedder : public Conditioner {
12651320
vec[i] = 0;
12661321
}
12671322
}
1268-
1323+
12691324
hidden_states_vec.insert(hidden_states_vec.end(),
1270-
(float*)chunk_hidden_states->data,
1271-
((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
1325+
(float*)chunk_hidden_states->data,
1326+
((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
1327+
}
1328+
1329+
if (hidden_states_vec.size() > 0) {
1330+
hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
1331+
hidden_states = ggml_reshape_2d(work_ctx,
1332+
hidden_states,
1333+
chunk_hidden_states->ne[0],
1334+
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
1335+
} else {
1336+
hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
1337+
ggml_set_f32(hidden_states, 0.f);
1338+
}
1339+
if (pooled == NULL) {
1340+
pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768);
1341+
ggml_set_f32(pooled, 0.f);
12721342
}
1273-
1274-
hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
1275-
hidden_states = ggml_reshape_2d(work_ctx,
1276-
hidden_states,
1277-
chunk_hidden_states->ne[0],
1278-
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
12791343
return SDCondition(hidden_states, pooled, NULL);
12801344
}
12811345

stable-diffusion.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ class StableDiffusionGGML {
314314
clip_backend = backend;
315315
bool use_t5xxl = false;
316316
if (sd_version_is_dit(version)) {
317+
// TODO: check if t5 is actually loaded?
317318
use_t5xxl = true;
318319
}
319320
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {

0 commit comments

Comments (0)