
Merged · Changes from all commits
175 changes: 125 additions & 50 deletions lora_diffusion/cli_lora_pti.py
@@ -128,17 +128,31 @@ def get_models(
    )


-def text2img_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder):
+@torch.no_grad()
+def text2img_dataloader(
+    train_dataset,
+    train_batch_size,
+    tokenizer,
+    vae,
+    text_encoder,
+    cached_latents: bool = False,
+):
+
+    if cached_latents:
+        cached_latents_dataset = []
+        for idx in tqdm(range(len(train_dataset))):
+            batch = train_dataset[idx]
+            latents = vae.encode(
+                batch["instance_images"].unsqueeze(0).to(dtype=vae.dtype).to(vae.device)
+            ).latent_dist.sample()
+            latents = latents * 0.18215
+            batch["instance_images"] = latents.squeeze(0)
+            cached_latents_dataset.append(batch)

    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if examples[0].get("class_prompt_ids", None) is not None:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]

        pixel_values = torch.stack(pixel_values)
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

@@ -159,33 +173,60 @@ def collate_fn(examples):

        return batch

-    train_dataloader = torch.utils.data.DataLoader(
-        train_dataset,
-        batch_size=train_batch_size,
-        shuffle=True,
-        collate_fn=collate_fn,
-    )
+    if cached_latents:
+
+        train_dataloader = torch.utils.data.DataLoader(
+            cached_latents_dataset,
+            batch_size=train_batch_size,
+            shuffle=True,
+            collate_fn=collate_fn,
+        )
+
+        print("PTI : Using cached latents.")
+
+    else:
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_size=train_batch_size,
+            shuffle=True,
+            collate_fn=collate_fn,
+        )

    return train_dataloader
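The cached branch works because a plain Python list supports len() and indexing, so torch.utils.data.DataLoader accepts cached_latents_dataset directly and the existing collate_fn keeps operating on the cached items. A minimal sketch of the caching idea, with illustrative names, assuming a diffusers AutoencoderKL and items keyed like the PR's dataset:

import torch

@torch.no_grad()
def precompute_latents(dataset, vae, scaling_factor=0.18215):
    # encode every image once with the frozen VAE; training then never
    # needs the VAE again, only the cached latents
    cached = []
    for idx in range(len(dataset)):
        item = dict(dataset[idx])  # shallow copy of the example dict
        posterior = vae.encode(
            item["instance_images"]
            .unsqueeze(0)
            .to(dtype=vae.dtype, device=vae.device)
        ).latent_dist
        item["instance_images"] = (scaling_factor * posterior.sample()).squeeze(0)
        cached.append(item)
    return cached  # a list of dicts that DataLoader can iterate as-is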

-def inpainting_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder):
+def inpainting_dataloader(
+    train_dataset, train_batch_size, tokenizer, vae, text_encoder
+):
    def collate_fn(examples):
        input_ids = [example["instance_prompt_ids"] for example in examples]
        pixel_values = [example["instance_images"] for example in examples]
        mask_values = [example["instance_masks"] for example in examples]
-        masked_image_values = [example["instance_masked_images"] for example in examples]
+        masked_image_values = [
+            example["instance_masked_images"] for example in examples
+        ]

        # Concat class and instance examples for prior preservation.
        # We do this to avoid doing two forward passes.
        if examples[0].get("class_prompt_ids", None) is not None:
            input_ids += [example["class_prompt_ids"] for example in examples]
            pixel_values += [example["class_images"] for example in examples]
            mask_values += [example["class_masks"] for example in examples]
-            masked_image_values += [example["class_masked_images"] for example in examples]
+            masked_image_values += [
+                example["class_masked_images"] for example in examples
+            ]

-        pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
-        mask_values = torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
-        masked_image_values = torch.stack(masked_image_values).to(memory_format=torch.contiguous_format).float()
+        pixel_values = (
+            torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
+        )
+        mask_values = (
+            torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
+        )
+        masked_image_values = (
+            torch.stack(masked_image_values)
+            .to(memory_format=torch.contiguous_format)
+            .float()
+        )

        input_ids = tokenizer.pad(
            {"input_ids": input_ids},

@@ -198,7 +239,7 @@ def collate_fn(examples):
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "mask_values": mask_values,
-            "masked_image_values": masked_image_values
+            "masked_image_values": masked_image_values,
        }

if examples[0].get("mask", None) is not None:
Expand All @@ -215,6 +256,7 @@ def collate_fn(examples):

return train_dataloader


def loss_step(
batch,
unet,
@@ -225,23 +267,30 @@ def loss_step(
    t_mutliplier=1.0,
    mixed_precision=False,
    mask_temperature=1.0,
+    cached_latents: bool = False,
):
    weight_dtype = torch.float32

-    latents = vae.encode(
-        batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
-    ).latent_dist.sample()
-    latents = latents * 0.18215
-
-    if train_inpainting:
-        masked_image_latents = vae.encode(
-            batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
-        ).latent_dist.sample()
-        masked_image_latents = masked_image_latents * 0.18215
-        mask = F.interpolate(
-            batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
-            scale_factor=1/8
-        )
+    if not cached_latents:
+        latents = vae.encode(
+            batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
+        ).latent_dist.sample()
+        latents = latents * 0.18215
+
+        if train_inpainting:
+            masked_image_latents = vae.encode(
+                batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
+            ).latent_dist.sample()
+            masked_image_latents = masked_image_latents * 0.18215
+            mask = F.interpolate(
+                batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
+                scale_factor=1 / 8,
+            )
+    else:
+        latents = batch["pixel_values"]
+
+        if train_inpainting:
+            masked_image_latents = batch["masked_image_latents"]
+            mask = batch["mask_values"]

    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
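In the non-cached inpainting path, scale_factor=1 / 8 downsamples the pixel-space mask to the latent grid of an SD-style VAE (a 512 x 512 image maps to 64 x 64 latents). A quick shape check:

import torch
import torch.nn.functional as F

mask = torch.rand(1, 1, 512, 512)            # pixel-space mask
latent_mask = F.interpolate(mask, scale_factor=1 / 8)
print(latent_mask.shape)                     # torch.Size([1, 1, 64, 64])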
@@ -257,7 +306,9 @@ def loss_step(
    noisy_latents = scheduler.add_noise(latents, noise, timesteps)

    if train_inpainting:
-        latent_model_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
+        latent_model_input = torch.cat(
+            [noisy_latents, mask, masked_image_latents], dim=1
+        )
    else:
        latent_model_input = noisy_latents

@@ -268,7 +319,9 @@ def loss_step(
                batch["input_ids"].to(text_encoder.device)
            )[0]

-            model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
+            model_pred = unet(
+                latent_model_input, timesteps, encoder_hidden_states
+            ).sample
    else:

        encoder_hidden_states = text_encoder(
@@ -308,7 +361,12 @@ def loss_step(

        target = target * mask

-    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+    loss = (
+        F.mse_loss(model_pred.float(), target.float(), reduction="none")
+        .mean([1, 2, 3])
+        .mean()
+    )

    return loss
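The reworked loss averages over (C, H, W) per sample before averaging over the batch. For a batch of same-shaped samples this is numerically identical to reduction="mean"; the per-sample form just makes the example-level weighting explicit, which is the quantity the mask reweighting acts on. A sketch verifying the equivalence:

import torch
import torch.nn.functional as F

pred, target = torch.randn(4, 4, 64, 64), torch.randn(4, 4, 64, 64)
a = F.mse_loss(pred, target, reduction="mean")
b = F.mse_loss(pred, target, reduction="none").mean([1, 2, 3]).mean()
assert torch.allclose(a, b)  # equal for uniformly shaped batches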


@@ -328,6 +386,7 @@ def train_inversion(
    tokenizer,
    lr_scheduler,
    test_image_path: str,
+    cached_latents: bool,
    accum_iter: int = 1,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
@@ -367,6 +426,7 @@ def train_inversion(
                        scheduler,
                        train_inpainting=train_inpainting,
                        mixed_precision=mixed_precision,
+                        cached_latents=cached_latents,
                    )
                    / accum_iter
                )
@@ -375,6 +435,13 @@ def train_inversion(
            loss_sum += loss.detach().item()

            if global_step % accum_iter == 0:
+                # print gradient of text encoder embedding
+                print(
+                    text_encoder.get_input_embeddings()
+                    .weight.grad[index_updates, :]
+                    .norm(dim=-1)
+                    .mean()
+                )
                optimizer.step()
                optimizer.zero_grad()
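The added print is a debugging aid: index_updates marks the embedding rows of the newly added placeholder tokens, so the logged value is the mean gradient norm of exactly the rows that inversion trains. A toy version of the same selection:

import torch

emb = torch.nn.Embedding(100, 16)
index_updates = torch.zeros(100, dtype=torch.bool)
index_updates[-2:] = True  # pretend the last two rows are new tokens

emb(torch.arange(100)).sum().backward()
print(emb.weight.grad[index_updates, :].norm(dim=-1).mean())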

@@ -448,7 +515,11 @@ def train_inversion(
                # open all images in test_image_path
                images = []
                for file in os.listdir(test_image_path):
-                    if file.lower().endswith(".png") or file.lower().endswith(".jpg") or file.lower().endswith(".jpeg"):
+                    if (
+                        file.lower().endswith(".png")
+                        or file.lower().endswith(".jpg")
+                        or file.lower().endswith(".jpeg")
+                    ):
                        images.append(
                            Image.open(os.path.join(test_image_path, file))
                        )
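Minor style note: str.endswith also accepts a tuple, so the three checks could collapse into one, reusing the module's existing os and PIL.Image imports:

images = [
    Image.open(os.path.join(test_image_path, file))
    for file in os.listdir(test_image_path)
    if file.lower().endswith((".png", ".jpg", ".jpeg"))
]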
@@ -490,6 +561,7 @@ def perform_tuning(
    out_name: str,
    tokenizer,
    test_image_path: str,
+    cached_latents: bool,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    class_token: str = "person",
@@ -526,6 +598,7 @@ def perform_tuning(
                t_mutliplier=0.8,
                mixed_precision=True,
                mask_temperature=mask_temperature,
+                cached_latents=cached_latents,
            )
            loss_sum += loss.detach().item()

@@ -627,18 +700,12 @@ def train(
    train_text_encoder: bool = True,
    pretrained_vae_name_or_path: str = None,
    revision: Optional[str] = None,
-    class_data_dir: Optional[str] = None,
-    stochastic_attribute: Optional[str] = None,
    perform_inversion: bool = True,
    use_template: Literal[None, "object", "style"] = None,
    train_inpainting: bool = False,
    placeholder_tokens: str = "",
    placeholder_token_at_data: Optional[str] = None,
    initializer_tokens: Optional[str] = None,
-    class_prompt: Optional[str] = None,
-    with_prior_preservation: bool = False,
-    prior_loss_weight: float = 1.0,
-    num_class_images: int = 100,
    seed: int = 42,
    resolution: int = 512,
    color_jitter: bool = True,
@@ -649,7 +716,6 @@ def train(
    save_steps: int = 100,
    gradient_accumulation_steps: int = 4,
    gradient_checkpointing: bool = False,
-    mixed_precision="fp16",
    lora_rank: int = 4,
    lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
    lora_clip_target_modules={"CLIPAttention"},
@@ -663,6 +729,7 @@ def train(
    continue_inversion: bool = False,
    continue_inversion_lr: Optional[float] = None,
    use_face_segmentation_condition: bool = False,
+    cached_latents: bool = True,
    use_mask_captioned_data: bool = False,
    mask_temperature: float = 1.0,
    scale_lr: bool = False,
@@ -773,11 +840,8 @@ def train(

    train_dataset = PivotalTuningDatasetCapation(
        instance_data_root=instance_data_dir,
-        stochastic_attribute=stochastic_attribute,
        token_map=token_map,
        use_template=use_template,
-        class_data_root=class_data_dir if with_prior_preservation else None,
-        class_prompt=class_prompt,
        tokenizer=tokenizer,
        size=resolution,
        color_jitter=color_jitter,
@@ -789,12 +853,19 @@ def train(
    train_dataset.blur_amount = 200

    if train_inpainting:
+        assert not cached_latents, "Cached latents not supported for inpainting"
+
        train_dataloader = inpainting_dataloader(
            train_dataset, train_batch_size, tokenizer, vae, text_encoder
        )
    else:
        train_dataloader = text2img_dataloader(
-            train_dataset, train_batch_size, tokenizer, vae, text_encoder
+            train_dataset,
+            train_batch_size,
+            tokenizer,
+            vae,
+            text_encoder,
+            cached_latents=cached_latents,
        )

    index_no_updates = torch.arange(len(tokenizer)) != -1

@@ -813,6 +884,8 @@ def train(
    for param in params_to_freeze:
        param.requires_grad = False

+    if cached_latents:
+        vae = None
    # STEP 1 : Perform Inversion
    if perform_inversion:
        ti_optimizer = optim.AdamW(
@@ -836,6 +909,7 @@ def train(
            text_encoder,
            train_dataloader,
            max_train_steps_ti,
+            cached_latents=cached_latents,
            accum_iter=gradient_accumulation_steps,
            scheduler=noise_scheduler,
            index_no_updates=index_no_updates,
@@ -941,6 +1015,7 @@ def train(
        text_encoder,
        train_dataloader,
        max_train_steps_tuning,
+        cached_latents=cached_latents,
        scheduler=noise_scheduler,
        optimizer=lora_optimizers,
        save_steps=save_steps,
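Since cached_latents defaults to True, callers get latent caching unless they opt out. A hypothetical direct call to train, where the values are illustrative and parameters not shown in this diff (pretrained_model_name_or_path, output_dir) are assumed to keep their usual names in the rest of the file:

from lora_diffusion.cli_lora_pti import train

train(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    instance_data_dir="./data/my_subject",
    output_dir="./out",
    placeholder_tokens="<s1>",
    initializer_tokens="person",
    train_inpainting=False,  # caching is asserted off for inpainting
    cached_latents=True,     # pre-encode once, then the VAE is set to None
)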