
RuntimeError: group_norm in VAE decode for SDXL Masked Img2Img (even with ControlNets disabled & FP32 VAE/Latents) #11424

Open
Haniubub opened this issue Apr 26, 2025 · 1 comment
Labels
bug Something isn't working

Comments

@Haniubub

Describe the bug

When using the `StableDiffusionXLControlNetPipeline` for masked image-to-image generation, a persistent `RuntimeError` occurs during the final VAE decoding step, specifically inside `torch.nn.functional.group_norm`.

The error occurs even under the following simplified conditions:

- The SDXL Refiner stage is completely bypassed.
- All ControlNets are disabled by setting `controlnet_conditioning_scale=[0.0, 0.0, 0.0]` when calling the pipeline.
- A standalone SDXL VAE (`stabilityai/sdxl-vae`) is loaded separately in full `torch.float32` precision.
- The input latents (output from the base pipeline with `output_type="latent"`) are confirmed to contain no NaN or Inf values via `torch.isnan().any()` / `torch.isinf().any()`.
- The input latents are explicitly cast to `torch.float32` before being passed to the `vae.decode()` method.
- Debug prints confirm both the VAE model and the input latents are `torch.float32` at the time of the `vae.decode()` call, and the latent shape is `[4, 128, 128]` (after selecting the first batch element).
The specific error is:

```
RuntimeError: Expected weight to be a vector of size equal to the number of channels in input, but got weight of shape [512] and input of shape [512, 128, 128]
```
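
For context on what this message means: `F.group_norm` expects its input in `(N, C, *)` layout, so when it receives a 3D tensor of shape `[512, 128, 128]` it reads `N=512, C=128`, and the affine weight of length 512 no longer matches the channel count. A minimal standalone sketch, independent of diffusers and with shapes chosen to mirror the traceback, reproduces the same error:

```python
import torch

# GroupNorm as used in the VAE decoder's mid block: 32 groups over 512 channels
norm = torch.nn.GroupNorm(num_groups=32, num_channels=512)

# 4D (N, C, H, W) input: 512 channels match the affine weight of shape [512] -> OK
norm(torch.randn(1, 512, 128, 128))

# 3D input: interpreted as (N=512, C=128, L=128); the 128 channels no longer
# match the affine weight of shape [512] -> same RuntimeError as in the traceback
norm(torch.randn(512, 128, 128))
```

This is consistent with the `[4, 128, 128]` latent above having lost its batch axis before reaching the decoder.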

Expected behavior
The `vae.decode()` call should decode the float32 latent tensor (which contains no NaNs or Infs) with the float32 VAE without raising a `RuntimeError`, especially since the ControlNets were disabled for this test run.

Additional context
- Tested with both `stabilityai/sdxl-vae` and `madebyollin/sdxl-vae-fp16-fix`, loaded in FP32; the error persists.
- The error occurs even when the refiner pipeline is completely bypassed.
- The error occurs even when `controlnet_conditioning_scale` is forced to `[0.0, 0.0, 0.0]`.
- The `controlnet_aux` warning `UserWarning: The module 'mediapipe' is not installed...` appears unrelated.

Reproduction

Model Loading (model_loader.py excerpt):

```python
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
import torch

def load_models():
    vae_model_id = "stabilityai/sdxl-vae"  # Also tested with "madebyollin/sdxl-vae-fp16-fix"
    base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
    controlnet_model_ids = [  # ControlNets are loaded but disabled via scale later
        "diffusers/controlnet-depth-sdxl-1.0",
        "diffusers/controlnet-canny-sdxl-1.0",
        "thibaud/controlnet-openpose-sdxl-1.0",
    ]

    # Load standalone VAE in FP32
    print(f"Loading Standalone VAE from {vae_model_id}...")
    vae = AutoencoderKL.from_pretrained(vae_model_id)  # No torch_dtype -> loads in FP32
    vae.to("cuda")
    print("Standalone VAE (FP32) on CUDA.")

    # Load ControlNets in FP16
    print("Loading ControlNet models...")
    controlnets = []
    for model_id in controlnet_model_ids:
        controlnet = ControlNetModel.from_pretrained(model_id, torch_dtype=torch.float16)
        controlnets.append(controlnet)
    print(f"{len(controlnets)} ControlNet models loaded.")

    # Load base pipeline in FP16 (does not receive the external VAE)
    print(f"Loading SDXL Base Pipeline from {base_model_id}...")
    base_pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        base_model_id, controlnet=controlnets,
        torch_dtype=torch.float16, variant="fp16", use_safetensors=True,
    )
    base_pipe.to("cuda")
    print("Base Pipeline on CUDA.")

    # Refiner is loaded but bypassed in the inference code
    # ... (refiner loading code omitted for brevity, also moved to CUDA) ...
    refiner_pipe = None  # Placeholder as it's bypassed

    # Enable xformers if available
    try:
        base_pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        print("Xformers not available or failed for base_pipe")
    # (same call for refiner_pipe when it is actually used)

    print("All models loaded.")
    # Return the base pipe and the separately loaded VAE
    return base_pipe, refiner_pipe, vae  # refiner_pipe is ignored later
```
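
(An aside on composition rather than the error itself: if the standalone FP32 VAE is ultimately meant to replace the pipeline's built-in FP16 VAE, the usual diffusers pattern is to pass it to `from_pretrained` so the pipeline's own encode/decode paths use it. A sketch under that assumption; note that mixing an FP32 VAE into an otherwise FP16 pipeline needs care with dtypes:)

```python
# Sketch: hand the standalone VAE to the pipeline at load time; explicitly
# passed components are used as-is instead of being loaded from the checkpoint.
base_pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_id, controlnet=controlnets, vae=vae,
    torch_dtype=torch.float16, variant="fp16", use_safetensors=True,
)
```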

Inference Logic (inference.py excerpt - Refiner skipped, CN disabled):

```python
# Inside the generate_image function:
# ... (load original_image, mask_image_pil) ...

num_loaded_nets = len(base_pipe.controlnet.nets)  # Should be 3

# Force-disable ControlNets for this test run
scales = [0.0] * num_loaded_nets
print("DEBUG: Forcing all ControlNet scales to 0.0!")

# Create placeholder conditioning images
conditioning_image_list = [original_image.resize((1024, 1024))] * num_loaded_nets
image_list = [original_image.resize((1024, 1024))] * num_loaded_nets

print("Running BASE Pipeline (ControlNets Disabled)...")

# Base pipeline (latent output)
num_inference_steps_base = 30
base_latents = base_pipe(
    prompt=prompt, negative_prompt=negative_prompt, image=image_list,
    mask_image=mask_image_pil, strength=strength,
    controlnet_conditioning_image=conditioning_image_list,
    controlnet_conditioning_scale=scales,  # Should be [0.0, 0.0, 0.0]
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps_base,
    output_type="latent",
).images
print("Base Pipeline finished.")

# === REFINER STEP IS SKIPPED FOR THIS TEST ===

# === VAE decoding using the standalone VAE ===
print("Starting final VAE decoding (from BASE Latents)...")

# Check for NaN/Inf
print(f"DEBUG Base Latents Shape (raw): {base_latents.shape}")  # e.g., torch.Size([1, 4, 128, 128])
print(f"DEBUG Base Latents Dtype (raw): {base_latents.dtype}")  # e.g., torch.float16
print(f"DEBUG Base Latents Has NaN?: {torch.isnan(base_latents).any()}")  # Prints False
print(f"DEBUG Base Latents Has Inf?: {torch.isinf(base_latents).any()}")  # Prints False

# Scale and convert latents to FP32
# NOTE: indexing with [0] drops the batch axis ([1, 4, 128, 128] -> [4, 128, 128])
latents_to_decode = (base_latents[0] / vae.config.scaling_factor).to(torch.float32)

print(f" - VAE Dtype: {vae.dtype}")                          # Prints torch.float32
print(f" - Latents Input Dtype: {latents_to_decode.dtype}")  # Prints torch.float32
print(f" - Latents Input Shape: {latents_to_decode.shape}")  # Prints torch.Size([4, 128, 128])

# VAE is already FP32
with torch.no_grad():
    # THIS LINE CAUSES THE ERROR:
    output_image_raw = vae.decode(latents_to_decode, return_dict=False)[0]

# ... (Post-processing) ...
```
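
If the root cause is indeed the dropped batch axis (which the `[512, 128, 128]` shape in the traceback is consistent with, since the decoder's `conv_in` would map an unbatched `[4, 128, 128]` latent to an unbatched `[512, 128, 128]` feature map), a sketch of the same decode step with the 4D layout preserved would look like this — unverified against this exact setup:

```python
# Sketch: keep the (N, C, H, W) batch axis instead of indexing it away with [0]
latents_to_decode = (base_latents / vae.config.scaling_factor).to(torch.float32)
# equivalently: base_latents[0:1] or base_latents[0].unsqueeze(0)

with torch.no_grad():
    output_image_raw = vae.decode(latents_to_decode, return_dict=False)[0]

print(output_image_raw.shape)  # expected: torch.Size([1, 3, 1024, 1024]) for 128x128 latents
```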

Logs

```
Traceback (most recent call last):
  File "/workspace/inference.py", line 118, in generate_image
    output_image_raw = vae.decode(latents_to_decode, return_dict=False)[0]
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.venv/lib/python3.11/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.venv/lib/python3.11/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 323, in decode
    decoded = self._decode(z).sample
              ^^^^^^^^^^^^^^^
  File "/workspace/.venv/lib/python3.11/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 294, in _decode
    dec = self.decoder(z)
          ^^^^^^^^^^^^^^^
  File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.venv/lib/python3.11/site-packages/diffusers/models/autoencoders/vae.py", line 300, in forward
    sample = self.mid_block(sample, latent_embeds)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.venv/lib/python3.11/site-packages/diffusers/models/unets/unet_2d_blocks.py", line 737, in forward
    hidden_states = self.resnets[0](hidden_states, temb)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.venv/lib/python3.11/site-packages/diffusers/models/resnet.py", line 327, in forward
    hidden_states = self.norm1(hidden_states)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/modules/normalization.py", line 313, in forward
    return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.venv/lib/python3.11/site-packages/torch/nn/functional.py", line 2965, in group_norm
    return torch.group_norm(
           ^^^^^^^^^^^^^^^^^
RuntimeError: Expected weight to be a vector of size equal to the number of channels in input, but got weight of shape [512] and input of shape [512, 128, 128]
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!!!!!!!! ERROR AT /generate !!!!!!!!
[The same traceback is then re-raised through the FastAPI /generate handler; the frames are identical to the above, preceded by:]
  File "/workspace/main.py", line 75, in generate
    img_base64 = generate_image(
  File "/workspace/inference.py", line 142, in generate_image
    raise e
RuntimeError: Expected weight to be a vector of size equal to the number of channels in input, but got weight of shape [512] and input of shape [512, 128, 128]
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
INFO:     100.64.0.27:58290 - "POST /generate HTTP/1.1" 500 Internal Server Error
```

System Info

- Platform: RunPod Cloud (Community Cloud)
- GPU: 1x NVIDIA RTX 4090 (24 GB VRAM); also tested on an H100 with the same result
- vCPU: 32
- RAM: 62 GB
- Disks: Container Disk 80 GB, Volume Disk 200 GB (/workspace)
- OS: Ubuntu 22.04 (via Docker image runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04)
- Python: 3.10.12 (within .venv; base image provides 3.10)
- CUDA Toolkit (system/base image): 11.8.0 (per the base image name)
- Key library versions (installed via pip):
  - torch: 2.6.0+cu124 (requires CUDA driver >= 12.1)
  - diffusers: 0.33.1
  - transformers: 4.51.3
  - accelerate: 1.6.0
  - xformers: 0.0.29.post3
  - controlnet-aux: 0.0.9
  - fastapi: 0.115.12
  - uvicorn: 0.34.1
  - numpy: 2.2.4
  - Pillow: 11.2.1
  - opencv-python: 4.11.0.86

Who can help?

No response

Haniubub added the `bug` (Something isn't working) label on Apr 26, 2025
@DN6 (Collaborator) commented May 6, 2025

@Haniubub Could you please provide a minimal code example in a single code block to reproduce this issue? It is a bit difficult to identify the issue the way it is currently presented.
