@@ -1096,38 +1096,82 @@ struct FluxCLIPEmbedder : public Conditioner {
1096
1096
std::shared_ptr<CLIPTextModelRunner> clip_l;
1097
1097
std::shared_ptr<T5Runner> t5;
1098
1098
1099
+ bool use_clip_l = false ;
1100
+ bool use_t5 = false ;
1101
+
1099
1102
FluxCLIPEmbedder(ggml_backend_t backend,
                 std::map<std::string, enum ggml_type>& tensor_types,
                 int clip_skip = -1) {
    // Build the conditioner from whatever text encoders are actually present
    // in the checkpoint (CLIP-L and/or T5-XXL). A missing encoder is tolerated
    // with a warning so partially-exported model files still load.
    if (clip_skip <= 0) {
        clip_skip = 2;  // default used when the caller does not specify one
    }

    // Detect available encoders by scanning tensor names.
    // Iterate by const reference: `auto pair` would copy the map entry
    // (including the key string) on every iteration.
    for (const auto& pair : tensor_types) {
        if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
            use_clip_l = true;
        } else if (pair.first.find("text_encoders.t5xxl") != std::string::npos) {
            use_t5 = true;
        }
        if (use_clip_l && use_t5) {
            break;  // both encoders found; no need to scan further
        }
    }

    if (!use_clip_l && !use_t5) {
        LOG_WARN("IMPORTANT NOTICE: No text encoders provided, cannot process prompts!");
        return;
    }

    if (use_clip_l) {
        clip_l = std::make_shared<CLIPTextModelRunner>(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, clip_skip, true);
    } else {
        LOG_WARN("clip_l text encoder not found! Prompt adherence might be degraded.");
    }
    if (use_t5) {
        t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
    } else {
        LOG_WARN("t5xxl text encoder not found! Prompt adherence might be degraded.");
    }
}
1108
1133
1109
1134
void set_clip_skip (int clip_skip) {
1110
- clip_l->set_clip_skip (clip_skip);
1135
+ if (use_clip_l) {
1136
+ clip_l->set_clip_skip (clip_skip);
1137
+ }
1111
1138
}
1112
1139
1113
1140
void get_param_tensors (std::map<std::string, struct ggml_tensor *>& tensors) {
1114
- clip_l->get_param_tensors (tensors, " text_encoders.clip_l.transformer.text_model" );
1115
- t5->get_param_tensors (tensors, " text_encoders.t5xxl.transformer" );
1141
+ if (use_clip_l) {
1142
+ clip_l->get_param_tensors (tensors, " text_encoders.clip_l.transformer.text_model" );
1143
+ }
1144
+ if (use_t5) {
1145
+ t5->get_param_tensors (tensors, " text_encoders.t5xxl.transformer" );
1146
+ }
1116
1147
}
1117
1148
1118
1149
void alloc_params_buffer () {
1119
- clip_l->alloc_params_buffer ();
1120
- t5->alloc_params_buffer ();
1150
+ if (use_clip_l) {
1151
+ clip_l->alloc_params_buffer ();
1152
+ }
1153
+ if (use_t5) {
1154
+ t5->alloc_params_buffer ();
1155
+ }
1121
1156
}
1122
1157
1123
1158
void free_params_buffer () {
1124
- clip_l->free_params_buffer ();
1125
- t5->free_params_buffer ();
1159
+ if (use_clip_l) {
1160
+ clip_l->free_params_buffer ();
1161
+ }
1162
+ if (use_t5) {
1163
+ t5->free_params_buffer ();
1164
+ }
1126
1165
}
1127
1166
1128
1167
size_t get_params_buffer_size () {
1129
- size_t buffer_size = clip_l->get_params_buffer_size ();
1130
- buffer_size += t5->get_params_buffer_size ();
1168
+ size_t buffer_size = 0 ;
1169
+ if (use_clip_l) {
1170
+ buffer_size += clip_l->get_params_buffer_size ();
1171
+ }
1172
+ if (use_t5) {
1173
+ buffer_size += t5->get_params_buffer_size ();
1174
+ }
1131
1175
return buffer_size;
1132
1176
}
1133
1177
@@ -1157,18 +1201,23 @@ struct FluxCLIPEmbedder : public Conditioner {
1157
1201
for (const auto & item : parsed_attention) {
1158
1202
const std::string& curr_text = item.first ;
1159
1203
float curr_weight = item.second ;
1160
-
1161
- std::vector<int > curr_tokens = clip_l_tokenizer.encode (curr_text, on_new_token_cb);
1162
- clip_l_tokens.insert (clip_l_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
1163
- clip_l_weights.insert (clip_l_weights.end (), curr_tokens.size (), curr_weight);
1164
-
1165
- curr_tokens = t5_tokenizer.Encode (curr_text, true );
1166
- t5_tokens.insert (t5_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
1167
- t5_weights.insert (t5_weights.end (), curr_tokens.size (), curr_weight);
1204
+ if (use_clip_l) {
1205
+ std::vector<int > curr_tokens = clip_l_tokenizer.encode (curr_text, on_new_token_cb);
1206
+ clip_l_tokens.insert (clip_l_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
1207
+ clip_l_weights.insert (clip_l_weights.end (), curr_tokens.size (), curr_weight);
1208
+ }
1209
+ if (use_t5) {
1210
+ std::vector<int > curr_tokens = t5_tokenizer.Encode (curr_text, true );
1211
+ t5_tokens.insert (t5_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
1212
+ t5_weights.insert (t5_weights.end (), curr_tokens.size (), curr_weight);
1213
+ }
1214
+ }
1215
+ if (use_clip_l) {
1216
+ clip_l_tokenizer.pad_tokens (clip_l_tokens, clip_l_weights, 77 , padding);
1217
+ }
1218
+ if (use_t5) {
1219
+ t5_tokenizer.pad_tokens (t5_tokens, t5_weights, max_length, padding);
1168
1220
}
1169
-
1170
- clip_l_tokenizer.pad_tokens (clip_l_tokens, clip_l_weights, 77 , padding);
1171
- t5_tokenizer.pad_tokens (t5_tokens, t5_weights, max_length, padding);
1172
1221
1173
1222
// for (int i = 0; i < clip_l_tokens.size(); i++) {
1174
1223
// std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@@ -1201,34 +1250,36 @@ struct FluxCLIPEmbedder : public Conditioner {
1201
1250
std::vector<float > hidden_states_vec;
1202
1251
1203
1252
size_t chunk_len = 256 ;
1204
- size_t chunk_count = t5_tokens.size () / chunk_len;
1253
+ size_t chunk_count = std::max (clip_l_tokens. size () > 0 ? chunk_len : 0 , t5_tokens.size () ) / chunk_len;
1205
1254
for (int chunk_idx = 0 ; chunk_idx < chunk_count; chunk_idx++) {
1206
1255
// clip_l
1207
1256
if (chunk_idx == 0 ) {
1208
- size_t chunk_len_l = 77 ;
1209
- std::vector<int > chunk_tokens (clip_l_tokens.begin (),
1210
- clip_l_tokens.begin () + chunk_len_l);
1211
- std::vector<float > chunk_weights (clip_l_weights.begin (),
1212
- clip_l_weights.begin () + chunk_len_l);
1257
+ if (use_clip_l) {
1258
+ size_t chunk_len_l = 77 ;
1259
+ std::vector<int > chunk_tokens (clip_l_tokens.begin (),
1260
+ clip_l_tokens.begin () + chunk_len_l);
1261
+ std::vector<float > chunk_weights (clip_l_weights.begin (),
1262
+ clip_l_weights.begin () + chunk_len_l);
1213
1263
1214
- auto input_ids = vector_to_ggml_tensor_i32 (work_ctx, chunk_tokens);
1215
- size_t max_token_idx = 0 ;
1264
+ auto input_ids = vector_to_ggml_tensor_i32 (work_ctx, chunk_tokens);
1265
+ size_t max_token_idx = 0 ;
1216
1266
1217
- auto it = std::find (chunk_tokens.begin (), chunk_tokens.end (), clip_l_tokenizer.EOS_TOKEN_ID );
1218
- max_token_idx = std::min<size_t >(std::distance (chunk_tokens.begin (), it), chunk_tokens.size () - 1 );
1267
+ auto it = std::find (chunk_tokens.begin (), chunk_tokens.end (), clip_l_tokenizer.EOS_TOKEN_ID );
1268
+ max_token_idx = std::min<size_t >(std::distance (chunk_tokens.begin (), it), chunk_tokens.size () - 1 );
1219
1269
1220
- clip_l->compute (n_threads,
1221
- input_ids,
1222
- 0 ,
1223
- NULL ,
1224
- max_token_idx,
1225
- true ,
1226
- &pooled,
1227
- work_ctx);
1270
+ clip_l->compute (n_threads,
1271
+ input_ids,
1272
+ 0 ,
1273
+ NULL ,
1274
+ max_token_idx,
1275
+ true ,
1276
+ &pooled,
1277
+ work_ctx);
1278
+ }
1228
1279
}
1229
1280
1230
1281
// t5
1231
- {
1282
+ if (use_t5) {
1232
1283
std::vector<int > chunk_tokens (t5_tokens.begin () + chunk_idx * chunk_len,
1233
1284
t5_tokens.begin () + (chunk_idx + 1 ) * chunk_len);
1234
1285
std::vector<float > chunk_weights (t5_weights.begin () + chunk_idx * chunk_len,
@@ -1255,8 +1306,12 @@ struct FluxCLIPEmbedder : public Conditioner {
1255
1306
float new_mean = ggml_tensor_mean (tensor);
1256
1307
ggml_tensor_scale (tensor, (original_mean / new_mean));
1257
1308
}
1309
+ } else {
1310
+ chunk_hidden_states = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 4096 , chunk_len);
1311
+ ggml_set_f32 (chunk_hidden_states, 0 .f );
1258
1312
}
1259
1313
1314
+
1260
1315
int64_t t1 = ggml_time_ms ();
1261
1316
LOG_DEBUG (" computing condition graph completed, taking %" PRId64 " ms" , t1 - t0);
1262
1317
if (force_zero_embeddings) {
@@ -1265,17 +1320,26 @@ struct FluxCLIPEmbedder : public Conditioner {
1265
1320
vec[i] = 0 ;
1266
1321
}
1267
1322
}
1268
-
1323
+
1269
1324
hidden_states_vec.insert (hidden_states_vec.end (),
1270
- (float *)chunk_hidden_states->data ,
1271
- ((float *)chunk_hidden_states->data ) + ggml_nelements (chunk_hidden_states));
1325
+ (float *)chunk_hidden_states->data ,
1326
+ ((float *)chunk_hidden_states->data ) + ggml_nelements (chunk_hidden_states));
1327
+ }
1328
+
1329
+ if (hidden_states_vec.size () > 0 ) {
1330
+ hidden_states = vector_to_ggml_tensor (work_ctx, hidden_states_vec);
1331
+ hidden_states = ggml_reshape_2d (work_ctx,
1332
+ hidden_states,
1333
+ chunk_hidden_states->ne [0 ],
1334
+ ggml_nelements (hidden_states) / chunk_hidden_states->ne [0 ]);
1335
+ } else {
1336
+ hidden_states = ggml_new_tensor_2d (work_ctx, GGML_TYPE_F32, 4096 , 256 );
1337
+ ggml_set_f32 (hidden_states, 0 .f );
1338
+ }
1339
+ if (pooled == NULL ) {
1340
+ pooled = ggml_new_tensor_1d (work_ctx, GGML_TYPE_F32, 768 );
1341
+ ggml_set_f32 (pooled, 0 .f );
1272
1342
}
1273
-
1274
- hidden_states = vector_to_ggml_tensor (work_ctx, hidden_states_vec);
1275
- hidden_states = ggml_reshape_2d (work_ctx,
1276
- hidden_states,
1277
- chunk_hidden_states->ne [0 ],
1278
- ggml_nelements (hidden_states) / chunk_hidden_states->ne [0 ]);
1279
1343
return SDCondition (hidden_states, pooled, NULL );
1280
1344
}
1281
1345
0 commit comments