Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit b5f4932

Browse files
committed
refactor: add some sd version helper functions
1 parent 1c168d9 commit b5f4932

File tree

5 files changed

+59
-38
lines changed

5 files changed

+59
-38
lines changed

conditioner.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ struct Conditioner {
4444
// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
4545
struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
4646
SDVersion version = VERSION_SD1;
47-
PMVersion pm_version = VERSION_1;
47+
PMVersion pm_version = PM_VERSION_1;
4848
CLIPTokenizer tokenizer;
4949
ggml_type wtype;
5050
std::shared_ptr<CLIPTextModelRunner> text_model;
@@ -60,7 +60,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
6060
ggml_type wtype,
6161
const std::string& embd_dir,
6262
SDVersion version = VERSION_SD1,
63-
PMVersion pv = VERSION_1,
63+
PMVersion pv = PM_VERSION_1,
6464
int clip_skip = -1)
6565
: version(version), pm_version(pv), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) {
6666
if (clip_skip <= 0) {
@@ -270,7 +270,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
270270
std::vector<int> clean_input_ids_tmp;
271271
for (uint32_t i = 0; i < class_token_index[0]; i++)
272272
clean_input_ids_tmp.push_back(clean_input_ids[i]);
273-
for (uint32_t i = 0; i < (pm_version == VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++)
273+
for (uint32_t i = 0; i < (pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++)
274274
clean_input_ids_tmp.push_back(class_token);
275275
for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++)
276276
clean_input_ids_tmp.push_back(clean_input_ids[i]);
@@ -286,7 +286,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
286286
// weights.insert(weights.begin(), 1.0);
287287

288288
tokenizer.pad_tokens(tokens, weights, max_length, padding);
289-
int offset = pm_version == VERSION_2 ? 2 * num_input_imgs : num_input_imgs;
289+
int offset = pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs;
290290
for (uint32_t i = 0; i < tokens.size(); i++) {
291291
// if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs
292292
if (class_idx + 1 <= i && i < class_idx + 1 + offset) // photomaker V2 has num_tokens(=2)*num_input_imgs

model.h

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,30 @@ enum SDVersion {
3131
VERSION_COUNT,
3232
};
3333

34+
static inline bool sd_version_is_flux(SDVersion version) {
35+
if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
36+
return true;
37+
}
38+
return false;
39+
}
40+
41+
static inline bool sd_version_is_sd3(SDVersion version) {
42+
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
43+
return true;
44+
}
45+
return false;
46+
}
47+
48+
static inline bool sd_version_is_dit(SDVersion version) {
49+
if (sd_version_is_flux(version) || sd_version_is_sd3(version)) {
50+
return true;
51+
}
52+
return false;
53+
}
54+
3455
enum PMVersion {
35-
VERSION_1,
36-
VERSION_2,
56+
PM_VERSION_1,
57+
PM_VERSION_2,
3758
};
3859

3960
struct TensorStorage {

pmid.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionMo
608608
struct PhotoMakerIDEncoder : public GGMLRunner {
609609
public:
610610
SDVersion version = VERSION_SDXL;
611-
PMVersion pm_version = VERSION_1;
611+
PMVersion pm_version = PM_VERSION_1;
612612
PhotoMakerIDEncoderBlock id_encoder;
613613
PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock id_encoder2;
614614
float style_strength;
@@ -623,14 +623,14 @@ struct PhotoMakerIDEncoder : public GGMLRunner {
623623
std::vector<float> zeros_right;
624624

625625
public:
626-
PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, PMVersion pm_v = VERSION_1, float sty = 20.f)
626+
PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, PMVersion pm_v = PM_VERSION_1, float sty = 20.f)
627627
: GGMLRunner(backend, wtype),
628628
version(version),
629629
pm_version(pm_v),
630630
style_strength(sty) {
631-
if (pm_version == VERSION_1) {
631+
if (pm_version == PM_VERSION_1) {
632632
id_encoder.init(params_ctx, wtype);
633-
} else if (pm_version == VERSION_2) {
633+
} else if (pm_version == PM_VERSION_2) {
634634
id_encoder2.init(params_ctx, wtype);
635635
}
636636
}
@@ -644,9 +644,9 @@ struct PhotoMakerIDEncoder : public GGMLRunner {
644644
}
645645

646646
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
647-
if (pm_version == VERSION_1)
647+
if (pm_version == PM_VERSION_1)
648648
id_encoder.get_param_tensors(tensors, prefix);
649-
else if (pm_version == VERSION_2)
649+
else if (pm_version == PM_VERSION_2)
650650
id_encoder2.get_param_tensors(tensors, prefix);
651651
}
652652

@@ -734,14 +734,14 @@ struct PhotoMakerIDEncoder : public GGMLRunner {
734734
}
735735
}
736736
struct ggml_tensor* updated_prompt_embeds = NULL;
737-
if (pm_version == VERSION_1)
737+
if (pm_version == PM_VERSION_1)
738738
updated_prompt_embeds = id_encoder.forward(ctx0,
739739
id_pixel_values_d,
740740
prompt_embeds_d,
741741
class_tokens_mask_d,
742742
class_tokens_mask_pos,
743743
left, right);
744-
else if (pm_version == VERSION_2)
744+
else if (pm_version == PM_VERSION_2)
745745
updated_prompt_embeds = id_encoder2.forward(ctx0,
746746
id_pixel_values_d,
747747
prompt_embeds_d,

stable-diffusion.cpp

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -286,9 +286,9 @@ class StableDiffusionGGML {
286286
"try specifying SDXL VAE FP16 Fix with the --vae parameter. "
287287
"You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors");
288288
}
289-
} else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
289+
} else if (sd_version_is_sd3(version)) {
290290
scale_factor = 1.5305f;
291-
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
291+
} else if (sd_version_is_flux(version)) {
292292
scale_factor = 0.3611;
293293
// TODO: shift_factor
294294
}
@@ -309,7 +309,7 @@ class StableDiffusionGGML {
309309
} else {
310310
clip_backend = backend;
311311
bool use_t5xxl = false;
312-
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
312+
if (sd_version_is_dit(version)) {
313313
use_t5xxl = true;
314314
}
315315
if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) {
@@ -323,18 +323,18 @@ class StableDiffusionGGML {
323323
if (diffusion_flash_attn) {
324324
LOG_INFO("Using flash attention in the diffusion model");
325325
}
326-
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
326+
if (sd_version_is_sd3(version)) {
327327
if (diffusion_flash_attn) {
328328
LOG_WARN("flash attention in this diffusion model is currently unsupported!");
329329
}
330330
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, conditioner_wtype);
331331
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model_wtype, version);
332-
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
332+
} else if (sd_version_is_flux(version)) {
333333
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, conditioner_wtype);
334334
diffusion_model = std::make_shared<FluxModel>(backend, diffusion_model_wtype, version, diffusion_flash_attn);
335335
} else {
336336
if (id_embeddings_path.find("v2") != std::string::npos) {
337-
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version, VERSION_2);
337+
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version, PM_VERSION_2);
338338
} else {
339339
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, conditioner_wtype, embeddings_path, version);
340340
}
@@ -373,7 +373,7 @@ class StableDiffusionGGML {
373373
}
374374

375375
if (id_embeddings_path.find("v2") != std::string::npos) {
376-
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend, model_wtype, version, VERSION_2);
376+
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend, model_wtype, version, PM_VERSION_2);
377377
LOG_INFO("using PhotoMaker Version 2");
378378
} else {
379379
pmid_model = std::make_shared<PhotoMakerIDEncoder>(backend, model_wtype, version);
@@ -527,10 +527,10 @@ class StableDiffusionGGML {
527527
is_using_v_parameterization = true;
528528
}
529529

530-
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
530+
if (sd_version_is_sd3(version)) {
531531
LOG_INFO("running in FLOW mode");
532532
denoiser = std::make_shared<DiscreteFlowDenoiser>();
533-
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
533+
} else if (sd_version_is_flux(version)) {
534534
LOG_INFO("running in Flux FLOW mode");
535535
float shift = 1.15f;
536536
if (version == VERSION_FLUX_SCHNELL) {
@@ -804,7 +804,7 @@ class StableDiffusionGGML {
804804
out_uncond = ggml_dup_tensor(work_ctx, x);
805805
}
806806
if (has_skiplayer) {
807-
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
807+
if (sd_version_is_dit(version)) {
808808
out_skip = ggml_dup_tensor(work_ctx, x);
809809
} else {
810810
has_skiplayer = false;
@@ -995,9 +995,9 @@ class StableDiffusionGGML {
995995
if (use_tiny_autoencoder) {
996996
C = 4;
997997
} else {
998-
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B) {
998+
if (sd_version_is_sd3(version)) {
999999
C = 32;
1000-
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
1000+
} else if (sd_version_is_flux(version)) {
10011001
C = 32;
10021002
}
10031003
}
@@ -1214,7 +1214,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
12141214
}
12151215
// preprocess input id images
12161216
std::vector<sd_image_t*> input_id_images;
1217-
bool pmv2 = sd_ctx->sd->pmid_model->get_version() == VERSION_2;
1217+
bool pmv2 = sd_ctx->sd->pmid_model->get_version() == PM_VERSION_2;
12181218
if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) {
12191219
std::vector<std::string> img_files = get_files_from_dir(input_id_images_path);
12201220
for (std::string img_file : img_files) {
@@ -1343,9 +1343,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
13431343
// Sample
13441344
std::vector<struct ggml_tensor*> final_latents; // collect latents to decode
13451345
int C = 4;
1346-
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
1346+
if (sd_version_is_sd3(sd_ctx->sd->version)) {
13471347
C = 16;
1348-
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
1348+
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
13491349
C = 16;
13501350
}
13511351
int W = width / 8;
@@ -1464,10 +1464,10 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
14641464

14651465
struct ggml_init_params params;
14661466
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
1467-
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
1467+
if (sd_version_is_sd3(sd_ctx->sd->version)) {
14681468
params.mem_size *= 3;
14691469
}
1470-
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
1470+
if (sd_version_is_flux(sd_ctx->sd->version)) {
14711471
params.mem_size *= 4;
14721472
}
14731473
if (sd_ctx->sd->stacked_id) {
@@ -1490,17 +1490,17 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
14901490
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps);
14911491

14921492
int C = 4;
1493-
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
1493+
if (sd_version_is_sd3(sd_ctx->sd->version)) {
14941494
C = 16;
1495-
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
1495+
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
14961496
C = 16;
14971497
}
14981498
int W = width / 8;
14991499
int H = height / 8;
15001500
ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
1501-
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
1501+
if (sd_version_is_sd3(sd_ctx->sd->version)) {
15021502
ggml_set_f32(init_latent, 0.0609f);
1503-
} else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
1503+
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
15041504
ggml_set_f32(init_latent, 0.1159f);
15051505
} else {
15061506
ggml_set_f32(init_latent, 0.f);
@@ -1567,10 +1567,10 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
15671567

15681568
struct ggml_init_params params;
15691569
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
1570-
if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B || sd_ctx->sd->version == VERSION_SD3_5_2B) {
1570+
if (sd_version_is_sd3(sd_ctx->sd->version)) {
15711571
params.mem_size *= 2;
15721572
}
1573-
if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL || sd_ctx->sd->version == VERSION_FLUX_LITE) {
1573+
if (sd_version_is_flux(sd_ctx->sd->version)) {
15741574
params.mem_size *= 3;
15751575
}
15761576
if (sd_ctx->sd->stacked_id) {

vae.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ class AutoencodingEngine : public GGMLBlock {
459459
bool use_video_decoder = false,
460460
SDVersion version = VERSION_SD1)
461461
: decode_only(decode_only), use_video_decoder(use_video_decoder) {
462-
if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_5_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
462+
if (sd_version_is_dit(version)) {
463463
dd_config.z_channels = 16;
464464
use_quant = false;
465465
}

0 commit comments

Comments
 (0)