diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py
index 6cc4d26f29b4..b57fd1ef4ae1 100644
--- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py
+++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py
@@ -28,6 +28,14 @@ def cosine_distance(image_embeds, text_embeds):
     normalized_text_embeds = nn.functional.normalize(text_embeds)
     return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
 
+def jaccard_distance(image_embeds, text_embeds, eps=1e-8):
+    # Continuous Jaccard (Tanimoto) similarity matrix between the two embedding
+    # sets, scaled by 2; eps is a small positive guard so the denominator
+    # |a|^2 + |b|^2 - a.b can never be exactly zero.
+    scaler = torch.matmul(image_embeds, text_embeds.t())
+    image_square = image_embeds.pow(2).sum(dim=-1, keepdim=True)
+    text_square = text_embeds.pow(2).sum(dim=-1, keepdim=True)
+    return (scaler / (image_square + text_square.transpose(0, 1) - scaler + eps)) * 2
+
 
 class StableDiffusionSafetyChecker(PreTrainedModel):
     config_class = CLIPConfig
@@ -52,8 +60,8 @@ def forward(self, clip_input, images):
         image_embeds = self.visual_projection(pooled_output)
 
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
-        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
+        special_cos_dist = jaccard_distance(image_embeds, self.special_care_embeds).cpu().float().numpy()
+        cos_dist = jaccard_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
 
         result = []
         batch_size = image_embeds.shape[0]
@@ -103,8 +111,8 @@ def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor)
         pooled_output = self.vision_model(clip_input)[1]  # pooled_output
         image_embeds = self.visual_projection(pooled_output)
 
-        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
-        cos_dist = cosine_distance(image_embeds, self.concept_embeds)
+        special_cos_dist = jaccard_distance(image_embeds, self.special_care_embeds)
+        cos_dist = jaccard_distance(image_embeds, self.concept_embeds)
 
         # increase this value to create a stronger `nsfw` filter
         # at the cost of increasing the possibility of filtering benign images