From 85fa6e79092a007da8a5f4dbe033836e2a9db4cc Mon Sep 17 00:00:00 2001 From: SimoRyu Date: Thu, 9 Feb 2023 20:34:22 +0000 Subject: [PATCH 1/9] hotfix : token duplicate --- lora_diffusion/lora_manager.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/lora_diffusion/lora_manager.py b/lora_diffusion/lora_manager.py index 2ef9608..02d679b 100644 --- a/lora_diffusion/lora_manager.py +++ b/lora_diffusion/lora_manager.py @@ -12,6 +12,7 @@ def lora_join(lora_safetenors: list): metadatas = [dict(safelora.metadata()) for safelora in lora_safetenors] + _total_metadata = {} total_metadata = {} total_tensor = {} total_rank = 0 @@ -24,9 +25,14 @@ def lora_join(lora_safetenors: list): assert len(set(rankset)) == 1, "Rank should be the same per model" total_rank += rankset[0] - total_metadata.update(_metadata) + _total_metadata.update(_metadata) ranklist.append(rankset[0]) + # remove metadata about tokens + for k, v in _total_metadata.items(): + if v != "": + total_metadata[k] = v + tensorkeys = set() for safelora in lora_safetenors: tensorkeys.update(safelora.keys()) @@ -57,9 +63,6 @@ def lora_join(lora_safetenors: list): print(f"Embedding {token} replaced to ") - if total_metadata.get(token, None) is not None: - del total_metadata[token] - token_size_list.append(len(tokens)) return total_tensor, total_metadata, ranklist, token_size_list From 799c17aef2a475641fb70d68a6992de4fc325ce4 Mon Sep 17 00:00:00 2001 From: SimoRyu Date: Sat, 11 Feb 2023 19:06:11 +0000 Subject: [PATCH 2/9] hotfix : background bias --- lora_diffusion/dataset.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/lora_diffusion/dataset.py b/lora_diffusion/dataset.py index c9233e8..2a46313 100644 --- a/lora_diffusion/dataset.py +++ b/lora_diffusion/dataset.py @@ -86,7 +86,16 @@ def _shuffle(lis): return random.sample(lis, len(lis)) -def _get_cutout_holes(height, width, min_holes=8, max_holes=32, min_height=16, max_height=128, min_width=16, max_width=128): +def _get_cutout_holes( + height, + width, + min_holes=8, + max_holes=32, + min_height=16, + max_height=128, + min_width=16, + max_width=128, +): holes = [] for _n in range(random.randint(min_holes, max_holes)): hole_height = random.randint(min_height, max_height) @@ -103,12 +112,13 @@ def _generate_random_mask(image): mask = zeros_like(image[:1]) holes = _get_cutout_holes(mask.shape[1], mask.shape[2]) for (x1, y1, x2, y2) in holes: - mask[:, y1:y2, x1:x2] = 1. + mask[:, y1:y2, x1:x2] = 1.0 if random.uniform(0, 1) < 0.25: - mask.fill_(1.) + mask.fill_(1.0) masked_image = image * (mask < 0.5) return mask, masked_image + class PivotalTuningDatasetCapation(Dataset): """ A dataset to prepare the instance and class images with the prompts for fine-tuning the model. @@ -274,7 +284,10 @@ def __getitem__(self, index): example["instance_images"] = self.image_transforms(instance_image) if self.train_inpainting: - example["instance_masks"], example["instance_masked_images"] = _generate_random_mask(example["instance_images"]) + ( + example["instance_masks"], + example["instance_masked_images"], + ) = _generate_random_mask(example["instance_images"]) if self.use_template: assert self.token_map is not None @@ -296,7 +309,7 @@ def __getitem__(self, index): Image.open(self.mask_path[index % self.num_instance_images]) ) * 0.5 - + 0.5 + + 1.0 ) if self.h_flip and random.random() > 0.5: @@ -321,7 +334,10 @@ def __getitem__(self, index): class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image) if self.train_inpainting: - example["class_masks"], example["class_masked_images"] = _generate_random_mask(example["class_images"]) + ( + example["class_masks"], + example["class_masked_images"], + ) = _generate_random_mask(example["class_images"]) example["class_prompt_ids"] = self.tokenizer( self.class_prompt, padding="do_not_pad", From bade0dab9af2fc8deb42714cfa2cea78828c67f7 Mon Sep 17 00:00:00 2001 From: SimoRyu Date: Sun, 12 Feb 2023 07:41:41 +0000 Subject: [PATCH 3/9] cleanup --- lora_diffusion/cli_lora_pti.py | 56 +++++++++++++++++++++------------- lora_diffusion/dataset.py | 34 --------------------- 2 files changed, 35 insertions(+), 55 deletions(-) diff --git a/lora_diffusion/cli_lora_pti.py b/lora_diffusion/cli_lora_pti.py index 83703d0..3297e38 100644 --- a/lora_diffusion/cli_lora_pti.py +++ b/lora_diffusion/cli_lora_pti.py @@ -168,12 +168,17 @@ def collate_fn(examples): return train_dataloader -def inpainting_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder): + +def inpainting_dataloader( + train_dataset, train_batch_size, tokenizer, vae, text_encoder +): def collate_fn(examples): input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] mask_values = [example["instance_masks"] for example in examples] - masked_image_values = [example["instance_masked_images"] for example in examples] + masked_image_values = [ + example["instance_masked_images"] for example in examples + ] # Concat class and instance examples for prior preservation. # We do this to avoid doing two forward passes. @@ -181,11 +186,21 @@ def collate_fn(examples): input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] mask_values += [example["class_masks"] for example in examples] - masked_image_values += [example["class_masked_images"] for example in examples] + masked_image_values += [ + example["class_masked_images"] for example in examples + ] - pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float() - mask_values = torch.stack(mask_values).to(memory_format=torch.contiguous_format).float() - masked_image_values = torch.stack(masked_image_values).to(memory_format=torch.contiguous_format).float() + pixel_values = ( + torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float() + ) + mask_values = ( + torch.stack(mask_values).to(memory_format=torch.contiguous_format).float() + ) + masked_image_values = ( + torch.stack(masked_image_values) + .to(memory_format=torch.contiguous_format) + .float() + ) input_ids = tokenizer.pad( {"input_ids": input_ids}, @@ -198,7 +213,7 @@ def collate_fn(examples): "input_ids": input_ids, "pixel_values": pixel_values, "mask_values": mask_values, - "masked_image_values": masked_image_values + "masked_image_values": masked_image_values, } if examples[0].get("mask", None) is not None: @@ -215,6 +230,7 @@ def collate_fn(examples): return train_dataloader + def loss_step( batch, unet, @@ -240,7 +256,7 @@ def loss_step( masked_image_latents = masked_image_latents * 0.18215 mask = F.interpolate( batch["mask_values"].to(dtype=weight_dtype).to(unet.device), - scale_factor=1/8 + scale_factor=1 / 8, ) noise = torch.randn_like(latents) @@ -257,7 +273,9 @@ def loss_step( noisy_latents = scheduler.add_noise(latents, noise, timesteps) if train_inpainting: - latent_model_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1) + latent_model_input = torch.cat( + [noisy_latents, mask, masked_image_latents], dim=1 + ) else: latent_model_input = noisy_latents @@ -268,7 +286,9 @@ def loss_step( batch["input_ids"].to(text_encoder.device) )[0] - model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample + model_pred = unet( + latent_model_input, timesteps, encoder_hidden_states + ).sample else: encoder_hidden_states = text_encoder( @@ -448,7 +468,11 @@ def train_inversion( # open all images in test_image_path images = [] for file in os.listdir(test_image_path): - if file.lower().endswith(".png") or file.lower().endswith(".jpg") or file.lower().endswith(".jpeg"): + if ( + file.lower().endswith(".png") + or file.lower().endswith(".jpg") + or file.lower().endswith(".jpeg") + ): images.append( Image.open(os.path.join(test_image_path, file)) ) @@ -627,18 +651,12 @@ def train( train_text_encoder: bool = True, pretrained_vae_name_or_path: str = None, revision: Optional[str] = None, - class_data_dir: Optional[str] = None, - stochastic_attribute: Optional[str] = None, perform_inversion: bool = True, use_template: Literal[None, "object", "style"] = None, train_inpainting: bool = False, placeholder_tokens: str = "", placeholder_token_at_data: Optional[str] = None, initializer_tokens: Optional[str] = None, - class_prompt: Optional[str] = None, - with_prior_preservation: bool = False, - prior_loss_weight: float = 1.0, - num_class_images: int = 100, seed: int = 42, resolution: int = 512, color_jitter: bool = True, @@ -649,7 +667,6 @@ def train( save_steps: int = 100, gradient_accumulation_steps: int = 4, gradient_checkpointing: bool = False, - mixed_precision="fp16", lora_rank: int = 4, lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"}, lora_clip_target_modules={"CLIPAttention"}, @@ -773,11 +790,8 @@ def train( train_dataset = PivotalTuningDatasetCapation( instance_data_root=instance_data_dir, - stochastic_attribute=stochastic_attribute, token_map=token_map, use_template=use_template, - class_data_root=class_data_dir if with_prior_preservation else None, - class_prompt=class_prompt, tokenizer=tokenizer, size=resolution, color_jitter=color_jitter, diff --git a/lora_diffusion/dataset.py b/lora_diffusion/dataset.py index 2a46313..77508f0 100644 --- a/lora_diffusion/dataset.py +++ b/lora_diffusion/dataset.py @@ -128,12 +128,9 @@ class PivotalTuningDatasetCapation(Dataset): def __init__( self, instance_data_root, - stochastic_attribute, tokenizer, token_map: Optional[dict] = None, use_template: Optional[str] = None, - class_data_root=None, - class_prompt=None, size=512, h_flip=True, color_jitter=False, @@ -240,18 +237,6 @@ def __init__( self._length = self.num_instance_images - if class_data_root is not None: - assert NotImplementedError, "Prior preservation is not implemented yet." - - self.class_data_root = Path(class_data_root) - self.class_data_root.mkdir(parents=True, exist_ok=True) - self.class_images_path = list(self.class_data_root.iterdir()) - self.num_class_images = len(self.class_images_path) - self._length = max(self.num_class_images, self.num_instance_images) - self.class_prompt = class_prompt - else: - self.class_data_root = None - self.h_flip = h_flip self.image_transforms = transforms.Compose( [ @@ -326,23 +311,4 @@ def __getitem__(self, index): max_length=self.tokenizer.model_max_length, ).input_ids - if self.class_data_root: - class_image = Image.open( - self.class_images_path[index % self.num_class_images] - ) - if not class_image.mode == "RGB": - class_image = class_image.convert("RGB") - example["class_images"] = self.image_transforms(class_image) - if self.train_inpainting: - ( - example["class_masks"], - example["class_masked_images"], - ) = _generate_random_mask(example["class_images"]) - example["class_prompt_ids"] = self.tokenizer( - self.class_prompt, - padding="do_not_pad", - truncation=True, - max_length=self.tokenizer.model_max_length, - ).input_ids - return example From 93800601a4d05b93240996cdced1f502cb049a3d Mon Sep 17 00:00:00 2001 From: SimoRyu Date: Sun, 12 Feb 2023 09:10:27 +0000 Subject: [PATCH 4/9] cached latent --- lora_diffusion/cli_lora_pti.py | 107 ++++++++++++++++++++++++--------- lora_diffusion/dataset.py | 5 +- 2 files changed, 79 insertions(+), 33 deletions(-) diff --git a/lora_diffusion/cli_lora_pti.py b/lora_diffusion/cli_lora_pti.py index 3297e38..86ea9c6 100644 --- a/lora_diffusion/cli_lora_pti.py +++ b/lora_diffusion/cli_lora_pti.py @@ -128,17 +128,31 @@ def get_models( ) -def text2img_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder): +@torch.no_grad() +def text2img_dataloader( + train_dataset, + train_batch_size, + tokenizer, + vae, + text_encoder, + cached_latents: bool = False, +): + + if cached_latents: + cached_latents_dataset = [] + for idx in tqdm(range(len(train_dataset))): + batch = train_dataset[idx] + # rint(batch) + latents = vae.encode( + batch["instance_images"].unsqueeze(0).to(dtype=vae.dtype).to(vae.device) + ).latent_dist.sample() + latents = latents * 0.18215 + batch["instance_images"] = latents.squeeze(0) + cached_latents_dataset.append(batch) + def collate_fn(examples): input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] - - # Concat class and instance examples for prior preservation. - # We do this to avoid doing two forward passes. - if examples[0].get("class_prompt_ids", None) is not None: - input_ids += [example["class_prompt_ids"] for example in examples] - pixel_values += [example["class_images"] for example in examples] - pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() @@ -159,12 +173,24 @@ def collate_fn(examples): return batch - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - batch_size=train_batch_size, - shuffle=True, - collate_fn=collate_fn, - ) + if cached_latents: + + train_dataloader = torch.utils.data.DataLoader( + cached_latents_dataset, + batch_size=train_batch_size, + shuffle=True, + collate_fn=collate_fn, + ) + + print("PTI : Using cached latent.") + + else: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=train_batch_size, + shuffle=True, + collate_fn=collate_fn, + ) return train_dataloader @@ -241,23 +267,30 @@ def loss_step( t_mutliplier=1.0, mixed_precision=False, mask_temperature=1.0, + cached_latents: bool = False, ): weight_dtype = torch.float32 - - latents = vae.encode( - batch["pixel_values"].to(dtype=weight_dtype).to(unet.device) - ).latent_dist.sample() - latents = latents * 0.18215 - - if train_inpainting: - masked_image_latents = vae.encode( - batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device) + if not cached_latents: + latents = vae.encode( + batch["pixel_values"].to(dtype=weight_dtype).to(unet.device) ).latent_dist.sample() - masked_image_latents = masked_image_latents * 0.18215 - mask = F.interpolate( - batch["mask_values"].to(dtype=weight_dtype).to(unet.device), - scale_factor=1 / 8, - ) + latents = latents * 0.18215 + + if train_inpainting: + masked_image_latents = vae.encode( + batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device) + ).latent_dist.sample() + masked_image_latents = masked_image_latents * 0.18215 + mask = F.interpolate( + batch["mask_values"].to(dtype=weight_dtype).to(unet.device), + scale_factor=1 / 8, + ) + else: + latents = batch["pixel_values"] + + if train_inpainting: + masked_image_latents = batch["masked_image_latents"] + mask = batch["mask_values"] noise = torch.randn_like(latents) bsz = latents.shape[0] @@ -348,6 +381,7 @@ def train_inversion( tokenizer, lr_scheduler, test_image_path: str, + cached_latents: bool, accum_iter: int = 1, log_wandb: bool = False, wandb_log_prompt_cnt: int = 10, @@ -387,6 +421,7 @@ def train_inversion( scheduler, train_inpainting=train_inpainting, mixed_precision=mixed_precision, + cached_latents=cached_latents, ) / accum_iter ) @@ -514,6 +549,7 @@ def perform_tuning( out_name: str, tokenizer, test_image_path: str, + cached_latents: bool, log_wandb: bool = False, wandb_log_prompt_cnt: int = 10, class_token: str = "person", @@ -550,6 +586,7 @@ def perform_tuning( t_mutliplier=0.8, mixed_precision=True, mask_temperature=mask_temperature, + cached_latents=cached_latents, ) loss_sum += loss.detach().item() @@ -680,6 +717,7 @@ def train( continue_inversion: bool = False, continue_inversion_lr: Optional[float] = None, use_face_segmentation_condition: bool = False, + cached_latents: bool = True, use_mask_captioned_data: bool = False, mask_temperature: float = 1.0, scale_lr: bool = False, @@ -803,12 +841,19 @@ def train( train_dataset.blur_amount = 200 if train_inpainting: + assert not cached_latents, "Cached latents not supported for inpainting" + train_dataloader = inpainting_dataloader( train_dataset, train_batch_size, tokenizer, vae, text_encoder ) else: train_dataloader = text2img_dataloader( - train_dataset, train_batch_size, tokenizer, vae, text_encoder + train_dataset, + train_batch_size, + tokenizer, + vae, + text_encoder, + cached_latents=cached_latents, ) index_no_updates = torch.arange(len(tokenizer)) != -1 @@ -827,6 +872,8 @@ def train( for param in params_to_freeze: param.requires_grad = False + if cached_latents: + vae = None # STEP 1 : Perform Inversion if perform_inversion: ti_optimizer = optim.AdamW( @@ -850,6 +897,7 @@ def train( text_encoder, train_dataloader, max_train_steps_ti, + cached_latents=cached_latents, accum_iter=gradient_accumulation_steps, scheduler=noise_scheduler, index_no_updates=index_no_updates, @@ -955,6 +1003,7 @@ def train( text_encoder, train_dataloader, max_train_steps_tuning, + cached_latents=cached_latents, scheduler=noise_scheduler, optimizer=lora_optimizers, save_steps=save_steps, diff --git a/lora_diffusion/dataset.py b/lora_diffusion/dataset.py index 77508f0..f1c28fd 100644 --- a/lora_diffusion/dataset.py +++ b/lora_diffusion/dataset.py @@ -2,14 +2,11 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple, Union -import cv2 -import numpy as np -from PIL import Image, ImageFilter +from PIL import Image from torch import zeros_like from torch.utils.data import Dataset from torchvision import transforms import glob - from .preprocess_files import face_mask_google_mediapipe OBJECT_TEMPLATE = [ From afbcea104141b1ed8d6e95007ec6324c182a44b9 Mon Sep 17 00:00:00 2001 From: SimoRyu Date: Sun, 12 Feb 2023 10:05:58 +0000 Subject: [PATCH 5/9] fix precision problem? --- lora_diffusion/cli_lora_pti.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lora_diffusion/cli_lora_pti.py b/lora_diffusion/cli_lora_pti.py index 86ea9c6..b5f6d0a 100644 --- a/lora_diffusion/cli_lora_pti.py +++ b/lora_diffusion/cli_lora_pti.py @@ -361,7 +361,9 @@ def loss_step( target = target * mask - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.sum(0).mean() + return loss From def3e3eda05b44d2a3639524ad06ea547c63804b Mon Sep 17 00:00:00 2001 From: SimoRyu Date: Sun, 12 Feb 2023 11:45:18 +0000 Subject: [PATCH 6/9] bugfix : ti --- lora_diffusion/lora_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lora_diffusion/lora_manager.py b/lora_diffusion/lora_manager.py index 02d679b..f7db965 100644 --- a/lora_diffusion/lora_manager.py +++ b/lora_diffusion/lora_manager.py @@ -23,7 +23,7 @@ def lora_join(lora_safetenors: list): if k.endswith("rank"): rankset.append(int(v)) - assert len(set(rankset)) == 1, "Rank should be the same per model" + assert len(set(rankset)) <= 1, "Rank should be the same per model" total_rank += rankset[0] _total_metadata.update(_metadata) ranklist.append(rankset[0]) From 9b6df8f01e6c580799710ced7421bd48625ce851 Mon Sep 17 00:00:00 2001 From: SimoRyu Date: Mon, 13 Feb 2023 09:03:14 +0000 Subject: [PATCH 7/9] loss --- lora_diffusion/cli_lora_pti.py | 16 +++++++++++++--- lora_diffusion/lora_manager.py | 7 +++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/lora_diffusion/cli_lora_pti.py b/lora_diffusion/cli_lora_pti.py index b5f6d0a..21b4245 100644 --- a/lora_diffusion/cli_lora_pti.py +++ b/lora_diffusion/cli_lora_pti.py @@ -361,8 +361,11 @@ def loss_step( target = target * mask - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") - loss = loss.sum(0).mean() + loss = ( + F.mse_loss(model_pred.float(), target.float(), reduction="none") + .mean([1, 2, 3]) + .mean() + ) return loss @@ -432,6 +435,13 @@ def train_inversion( loss_sum += loss.detach().item() if global_step % accum_iter == 0: + # print gradient of text encoder embedding + print( + text_encoder.get_input_embeddings() + .weight.grad[index_updates, :] + .norm(dim=-1) + .mean() + ) optimizer.step() optimizer.zero_grad() @@ -914,7 +924,7 @@ def train( wandb_log_prompt_cnt=wandb_log_prompt_cnt, class_token=class_token, train_inpainting=train_inpainting, - mixed_precision=False, + mixed_precision=True, tokenizer=tokenizer, clip_ti_decay=clip_ti_decay, ) diff --git a/lora_diffusion/lora_manager.py b/lora_diffusion/lora_manager.py index f7db965..9d8306e 100644 --- a/lora_diffusion/lora_manager.py +++ b/lora_diffusion/lora_manager.py @@ -24,6 +24,9 @@ def lora_join(lora_safetenors: list): rankset.append(int(v)) assert len(set(rankset)) <= 1, "Rank should be the same per model" + if len(rankset) == 0: + rankset = [0] + total_rank += rankset[0] _total_metadata.update(_metadata) ranklist.append(rankset[0]) @@ -119,6 +122,10 @@ def _setup(self): def tune(self, scales): + assert len(scales) == len( + self.ranklist + ), "Scale list should be the same length as ranklist" + diags = [] for scale, rank in zip(scales, self.ranklist): diags = diags + [scale] * rank From abbde8456ca360354e1366657d0ca5adab35923b Mon Sep 17 00:00:00 2001 From: SimoRyu Date: Mon, 13 Feb 2023 09:04:20 +0000 Subject: [PATCH 8/9] false mixed precision --- lora_diffusion/cli_lora_pti.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lora_diffusion/cli_lora_pti.py b/lora_diffusion/cli_lora_pti.py index 21b4245..7de4bae 100644 --- a/lora_diffusion/cli_lora_pti.py +++ b/lora_diffusion/cli_lora_pti.py @@ -924,7 +924,7 @@ def train( wandb_log_prompt_cnt=wandb_log_prompt_cnt, class_token=class_token, train_inpainting=train_inpainting, - mixed_precision=True, + mixed_precision=False, tokenizer=tokenizer, clip_ti_decay=clip_ti_decay, ) From e48cbbb938afd9414d9773771c7761208489bb71 Mon Sep 17 00:00:00 2001 From: SimoRyu Date: Mon, 13 Feb 2023 09:06:18 +0000 Subject: [PATCH 9/9] version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 3f21767..6d286b3 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="lora_diffusion", py_modules=["lora_diffusion"], - version="0.1.6", + version="0.1.7", description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.", author="Simo Ryu", packages=find_packages(),