Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 9c03a7d

Browse files
Fix DDIMInverseScheduler (#5145)
* fix ddim inverse scheduler * update test of ddim inverse scheduler * update test of pix2pix_zero * update test of diffedit * fix typo --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent 1d3120f commit 9c03a7d

File tree

4 files changed

+14
-16
lines changed

4 files changed

+14
-16
lines changed

‎src/diffusers/schedulers/scheduling_ddim_inverse.py‎

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,6 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
288288
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'."
289289
)
290290

291-
# Roll timesteps array by one to reflect reversed origin and destination semantics for each step
292-
timesteps = np.roll(timesteps, 1)
293-
timesteps[0] = int(timesteps[1] - step_ratio)
294291
self.timesteps = torch.from_numpy(timesteps).to(device)
295292

296293
def step(
@@ -335,7 +332,8 @@ def step(
335332
336333
"""
337334
# 1. get previous step value (=t+1)
338-
prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
335+
prev_timestep = timestep
336+
timestep = min(timestep - self.config.num_train_timesteps // self.num_inference_steps, self.num_train_timesteps-1)
339337

340338
# 2. compute alphas, betas
341339
# change original implementation to exactly match noise levels for analogous forward process

‎tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def test_stable_diffusion_pix2pix_zero_inversion(self):
229229
image = sd_pipe.invert(**inputs).images
230230
image_slice = image[0, -3:, -3:, -1]
231231
assert image.shape == (1, 32, 32, 3)
232-
expected_slice = np.array([0.4823, 0.4783, 0.5638, 0.5201, 0.5247, 0.5644, 0.5029, 0.5404, 0.5062])
232+
expected_slice = np.array([0.4732, 0.4630, 0.5722, 0.5103, 0.5140, 0.5622, 0.5104, 0.5390, 0.5020])
233233

234234
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
235235

@@ -244,7 +244,7 @@ def test_stable_diffusion_pix2pix_zero_inversion_batch(self):
244244
image = sd_pipe.invert(**inputs).images
245245
image_slice = image[1, -3:, -3:, -1]
246246
assert image.shape == (2, 32, 32, 3)
247-
expected_slice = np.array([0.6446, 0.5232, 0.4914, 0.4441, 0.4654, 0.5546, 0.4650, 0.4938, 0.5044])
247+
expected_slice = np.array([0.6046, 0.5400, 0.4902, 0.4448, 0.4694, 0.5498, 0.4857, 0.5073, 0.5089])
248248

249249
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
250250

‎tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def test_inversion(self):
257257

258258
self.assertEqual(image.shape, (2, 32, 32, 3))
259259
expected_slice = np.array(
260-
[0.5150, 0.5134, 0.5043, 0.5376, 0.4694, 0.5105, 0.5015, 0.4407, 0.4799],
260+
[0.5160, 0.5115, 0.5060, 0.5456, 0.4704, 0.5060, 0.5019, 0.4405, 0.4726],
261261
)
262262
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
263263
self.assertLessEqual(max_diff, 1e-3)

‎tests/schedulers/test_scheduler_ddim_inverse.py‎

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_steps_offset(self):
5151
scheduler_config = self.get_scheduler_config(steps_offset=1)
5252
scheduler = scheduler_class(**scheduler_config)
5353
scheduler.set_timesteps(5)
54-
assert torch.equal(scheduler.timesteps, torch.LongTensor([-199, 1, 201, 401, 601]))
54+
assert torch.equal(scheduler.timesteps, torch.LongTensor([ 1, 201, 401, 601, 801]))
5555

5656
def test_betas(self):
5757
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
@@ -104,32 +104,32 @@ def test_full_loop_no_noise(self):
104104
result_sum = torch.sum(torch.abs(sample))
105105
result_mean = torch.mean(torch.abs(sample))
106106

107-
assert abs(result_sum.item() - 509.1079) < 1e-2
108-
assert abs(result_mean.item() - 0.6629) < 1e-3
107+
assert abs(result_sum.item() - 671.6816) < 1e-2
108+
assert abs(result_mean.item() - 0.8746) < 1e-3
109109

110110
def test_full_loop_with_v_prediction(self):
111111
sample = self.full_loop(prediction_type="v_prediction")
112112

113113
result_sum = torch.sum(torch.abs(sample))
114114
result_mean = torch.mean(torch.abs(sample))
115115

116-
assert abs(result_sum.item() - 1029.129) < 1e-2
117-
assert abs(result_mean.item() - 1.3400) < 1e-3
116+
assert abs(result_sum.item() - 1394.2185) < 1e-2
117+
assert abs(result_mean.item() - 1.8154) < 1e-3
118118

119119
def test_full_loop_with_set_alpha_to_one(self):
120120
# We specify different beta, so that the first alpha is 0.99
121121
sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
122122
result_sum = torch.sum(torch.abs(sample))
123123
result_mean = torch.mean(torch.abs(sample))
124124

125-
assert abs(result_sum.item() - 259.8116) < 1e-2
126-
assert abs(result_mean.item() - 0.3383) < 1e-3
125+
assert abs(result_sum.item() - 539.9622) < 1e-2
126+
assert abs(result_mean.item() - 0.7031) < 1e-3
127127

128128
def test_full_loop_with_no_set_alpha_to_one(self):
129129
# We specify different beta, so that the first alpha is 0.99
130130
sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
131131
result_sum = torch.sum(torch.abs(sample))
132132
result_mean = torch.mean(torch.abs(sample))
133133

134-
assert abs(result_sum.item() - 239.055) < 1e-2
135-
assert abs(result_mean.item() - 0.3113) < 1e-3
134+
assert abs(result_sum.item() - 542.6722) < 1e-2
135+
assert abs(result_mean.item() - 0.7066) < 1e-3

0 commit comments

Comments
 (0)