diff --git a/clip.hpp b/clip.hpp index 9bdbc8072..2ba12da60 100644 --- a/clip.hpp +++ b/clip.hpp @@ -891,6 +891,10 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { LOG_ERROR("embedding '%s' failed", embd_name.c_str()); return false; } + if (std::find(readed_embeddings.begin(), readed_embeddings.end(), embd_name) != readed_embeddings.end()) { + LOG_DEBUG("embedding already read in: %s", embd_name.c_str()); + return true; + } struct ggml_init_params params; params.mem_size = 10 * 1024 * 1024; // max for custom embeddings 10 MB params.mem_buffer = NULL; diff --git a/lora.hpp b/lora.hpp index 734635b66..7eb42e100 100644 --- a/lora.hpp +++ b/lora.hpp @@ -33,6 +33,7 @@ struct LoraModel : public GGMLModule { return model_loader.get_params_mem_size(NULL); } + bool load_from_file() { LOG_INFO("loading LoRA from '%s'", file_path.c_str()); @@ -55,6 +56,7 @@ struct LoraModel : public GGMLModule { auto real = lora_tensors[name]; *dst_tensor = real; } + return true; }; @@ -64,6 +66,7 @@ struct LoraModel : public GGMLModule { dry_run = false; model_loader.load_tensors(on_new_tensor_cb, backend); + LOG_DEBUG("finished loaded lora"); return true; } diff --git a/stable-diffusion.h b/stable-diffusion.h index 99eba4330..5d1476f85 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -92,8 +92,10 @@ enum sd_log_level_t { }; typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data); +typedef void (*sd_progress_cb_t)(int step,int steps,float time, void* data); SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data); +SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data); SD_API int32_t get_num_physical_cores(); SD_API const char* sd_get_system_info(); diff --git a/util.cpp b/util.cpp index a4e74a3e3..94b731457 100644 --- a/util.cpp +++ b/util.cpp @@ -161,6 +161,9 @@ int32_t get_num_physical_cores() { return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; } +static sd_progress_cb_t sd_progress_cb = NULL; +void* sd_progress_cb_data = NULL; + std::u32string utf8_to_utf32(const std::string& utf8_str) { std::wstring_convert, char32_t> converter; return converter.from_bytes(utf8_str); @@ -205,6 +208,10 @@ std::string path_join(const std::string& p1, const std::string& p2) { } void pretty_progress(int step, int steps, float time) { + if (sd_progress_cb) { + sd_progress_cb(step,steps,time, sd_progress_cb_data); + return; + } if (step == 0) { return; } @@ -248,8 +255,9 @@ std::string trim(const std::string& s) { return rtrim(ltrim(s)); } -static sd_log_cb_t sd_log_cb = NULL; -void* sd_log_cb_data = NULL; +static sd_log_cb_t sd_log_cb = NULL; +void* sd_log_cb_data = NULL; + #define LOG_BUFFER_SIZE 1024 @@ -286,7 +294,10 @@ void sd_set_log_callback(sd_log_cb_t cb, void* data) { sd_log_cb = cb; sd_log_cb_data = data; } - +void sd_set_progress_callback(sd_progress_cb_t cb, void* data) { + sd_progress_cb = cb; + sd_progress_cb_data = data; +} const char* sd_get_system_info() { static char buffer[1024]; std::stringstream ss; diff --git a/vae.hpp b/vae.hpp index 32b610e5c..0c5752095 100644 --- a/vae.hpp +++ b/vae.hpp @@ -6,7 +6,7 @@ /*================================================== AutoEncoderKL ===================================================*/ -#define VAE_GRAPH_SIZE 10240 +#define VAE_GRAPH_SIZE 20480 class ResnetBlock : public UnaryBlock { protected: