Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions lora_diffusion/cli_lora_pti.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ def perform_tuning(
placeholder_token_ids,
placeholder_tokens,
save_path,
lr_scheduler_lora,
):

progress_bar = tqdm(range(num_steps))
Expand All @@ -430,6 +431,8 @@ def perform_tuning(

for epoch in range(math.ceil(num_steps / len(dataloader))):
for batch in dataloader:
lr_scheduler_lora.step()

optimizer.zero_grad()

loss = loss_step(
Expand All @@ -447,6 +450,11 @@ def perform_tuning(
)
optimizer.step()
progress_bar.update(1)
logs = {
"loss": loss.detach().item(),
"lr": lr_scheduler_lora.get_last_lr()[0],
}
progress_bar.set_postfix(**logs)

global_step += 1

Expand Down Expand Up @@ -504,27 +512,29 @@ def train(
color_jitter: bool = True,
train_batch_size: int = 1,
sample_batch_size: int = 1,
max_train_steps_tuning: int = 10000,
max_train_steps_ti: int = 2000,
save_steps: int = 500,
gradient_accumulation_steps: int = 1,
max_train_steps_tuning: int = 1000,
max_train_steps_ti: int = 1000,
save_steps: int = 100,
gradient_accumulation_steps: int = 4,
gradient_checkpointing: bool = False,
mixed_precision="fp16",
lora_rank: int = 4,
lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
lora_clip_target_modules={"CLIPAttention"},
clip_ti_decay: bool = True,
learning_rate_unet: float = 1e-5,
learning_rate_unet: float = 1e-4,
learning_rate_text: float = 1e-5,
learning_rate_ti: float = 5e-4,
continue_inversion: bool = True,
continue_inversion_lr: Optional[float] = None,
use_face_segmentation_condition: bool = False,
scale_lr: bool = False,
lr_scheduler: str = "constant",
lr_scheduler: str = "linear",
lr_warmup_steps: int = 0,
weight_decay_ti: float = 0.01,
weight_decay_lora: float = 0.01,
lr_scheduler_lora: str = "linear",
lr_warmup_steps_lora: int = 0,
weight_decay_ti: float = 0.00,
weight_decay_lora: float = 0.001,
use_8bit_adam: bool = False,
device="cuda:0",
extra_args: Optional[dict] = None,
Expand Down Expand Up @@ -553,7 +563,7 @@ def train(
placeholder_tokens = placeholder_tokens.split("|")
if initializer_tokens is None:
print("PTI : Initializer Token not give, random inits")
initializer_tokens = ["<rand-0.036>"] * len(placeholder_tokens)
initializer_tokens = ["<rand-0.017>"] * len(placeholder_tokens)
else:
initializer_tokens = initializer_tokens.split("|")

Expand Down Expand Up @@ -588,8 +598,7 @@ def train(
)

if gradient_checkpointing:
text_encoder.gradient_checkpointing_enable()
unet.gradient_checkpointing_enable()
unet.enable_gradient_checkpointing()

if scale_lr:
unet_lr = learning_rate_unet * gradient_accumulation_steps * train_batch_size
Expand Down Expand Up @@ -734,6 +743,13 @@ def train(

train_dataset.blur_amount = 70

lr_scheduler_lora = get_scheduler(
lr_scheduler_lora,
optimizer=lora_optimizers,
num_warmup_steps=lr_warmup_steps_lora,
num_training_steps=max_train_steps_tuning,
)

perform_tuning(
unet,
vae,
Expand All @@ -746,6 +762,7 @@ def train(
placeholder_tokens=placeholder_tokens,
placeholder_token_ids=placeholder_token_ids,
save_path=output_dir,
lr_scheduler_lora=lr_scheduler_lora,
)


Expand Down
26 changes: 13 additions & 13 deletions lora_diffusion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,15 @@
]


def image_grid(_imgs, rows = None, cols = None):
def image_grid(_imgs, rows=None, cols=None):

if rows is None and cols is None:
rows = cols = math.ceil(len(_imgs) ** 0.5)

if rows is None:
rows = math.ceil(len(_imgs) / cols)
if cols is None:
cols = math.ceil(len(_imgs) / rows)


w, h = _imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
Expand Down Expand Up @@ -176,25 +175,23 @@ def visualize_progress(
text_sclae=1.0,
num_inference_steps=50,
guidance_scale=5.0,
offset : int = 0,
limit : int = 10,
seed : int = 0
offset: int = 0,
limit: int = 10,
seed: int = 0,
):


imgs = []
if isinstance(path_alls, str):
alls = list(set(glob.glob(path_alls)))

alls.sort(key=os.path.getmtime)
else:
alls = path_alls

pipe = StableDiffusionPipeline.from_pretrained(
model_id, torch_dtype=torch.float16
).to(device)


print(f"Found {len(alls)} checkpoints")
for path in alls[offset:limit]:
print(path)
Expand All @@ -207,8 +204,11 @@ def visualize_progress(
tune_lora_scale(pipe.text_encoder, text_sclae)

torch.manual_seed(seed)
image = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).images[0]
image = pipe(
prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
).images[0]
imgs.append(image)

return imgs

3 changes: 3 additions & 0 deletions training_scripts/multivector_example.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@ lora_pti \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--scale_lr \
--learning_rate_unet=1e-4 \
--learning_rate_text=1e-5 \
--learning_rate_ti=5e-4 \
--color_jitter \
--lr_scheduler="linear" \
--lr_warmup_steps=0 \
--lr_scheduler_lora="linear" \
--lr_warmup_steps_lora=100 \
--placeholder_tokens="<s1>|<s2>" \
--use_template="style"\
--save_steps=100 \
Expand Down
13 changes: 9 additions & 4 deletions training_scripts/train_lora_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,25 @@ def __init__(
self.class_prompt = class_prompt
else:
self.class_data_root = None

img_transforms = []

img_transforms = []

if resize:
img_transforms.append(transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR))
img_transforms.append(
transforms.Resize(
size, interpolation=transforms.InterpolationMode.BILINEAR
)
)
if center_crop:
img_transforms.append(transforms.CenterCrop(size))
if color_jitter:
img_transforms.append(transforms.ColorJitter(0.2, 0.1))
if h_flip:
img_transforms.append(transforms.RandomHorizontalFlip())

self.image_transforms = transforms.Compose([*img_transforms, transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
self.image_transforms = transforms.Compose(
[*img_transforms, transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)

def __len__(self):
return self._length
Expand Down