From 9187a09483514c9d363bf064481c4614563951b0 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Sun, 4 May 2025 03:26:20 -0700
Subject: [PATCH] Change cosmos and hydit models to use the native RMSNorm.
 (#7934)

---
 comfy/ldm/cosmos/blocks.py | 11 +++++------
 comfy/ldm/cosmos/model.py  |  4 +---
 comfy/ldm/hydit/models.py  |  4 ++--
 3 files changed, 8 insertions(+), 11 deletions(-)

diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py
index 84fd6d839bc..a12f892d21d 100644
--- a/comfy/ldm/cosmos/blocks.py
+++ b/comfy/ldm/cosmos/blocks.py
@@ -23,7 +23,6 @@
 from einops.layers.torch import Rearrange
 from torch import nn
 
-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
 from comfy.ldm.modules.attention import optimized_attention
 
 
@@ -37,11 +36,11 @@ def apply_rotary_pos_emb(
     return t_out
 
 
-def get_normalization(name: str, channels: int, weight_args={}):
+def get_normalization(name: str, channels: int, weight_args={}, operations=None):
     if name == "I":
         return nn.Identity()
     elif name == "R":
-        return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
+        return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
     else:
         raise ValueError(f"Normalization {name} not found")
 
@@ -120,15 +119,15 @@ def __init__(
 
         self.to_q = nn.Sequential(
             operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[0], norm_dim),
+            get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
         )
         self.to_k = nn.Sequential(
             operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[1], norm_dim),
+            get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
         )
         self.to_v = nn.Sequential(
             operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
-            get_normalization(qkv_norm[2], norm_dim),
+            get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
         )
 
         self.to_out = nn.Sequential(
diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py
index 06d0baef3f1..4836e0b69e8 100644
--- a/comfy/ldm/cosmos/model.py
+++ b/comfy/ldm/cosmos/model.py
@@ -27,8 +27,6 @@
 from enum import Enum
 import logging
 
-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
-
 from .blocks import (
     FinalLayer,
     GeneralDITTransformerBlock,
@@ -195,7 +193,7 @@ def __init__(
 
         if self.affline_emb_norm:
             logging.debug("Building affine embedding normalization layer")
-            self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
+            self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
         else:
             self.affline_norm = nn.Identity()
 
diff --git a/comfy/ldm/hydit/models.py b/comfy/ldm/hydit/models.py
index 359f6a9651c..5ba2b76e0ca 100644
--- a/comfy/ldm/hydit/models.py
+++ b/comfy/ldm/hydit/models.py
@@ -3,7 +3,7 @@
 import torch.nn as nn
 
 import comfy.ops
-from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
+from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed
 from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
 from torch.utils import checkpoint
 
@@ -51,7 +51,7 @@ def __init__(self,
         if norm_type == "layer":
             norm_layer = operations.LayerNorm
         elif norm_type == "rms":
-            norm_layer = RMSNorm
+            norm_layer = operations.RMSNorm
         else:
             raise ValueError(f"Unknown norm_type: {norm_type}")
 
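
Note (illustration, not part of the patch): the change threads the operations
namespace through to normalization construction, so RMSNorm instances are
built from the same swappable op set as the Linear layers instead of the old
mmdit helper class. Below is a minimal sketch of the new call shape. It
assumes torch >= 2.4, which ships nn.RMSNorm natively, and uses torch.nn as a
hypothetical stand-in for ComfyUI's comfy.ops namespace; the real code passes
one of the comfy.ops class sets here.

import torch
from torch import nn

operations = nn  # hypothetical stand-in for a comfy.ops class set

# Same constructor shape as the patched get_normalization()/affline_norm paths:
norm = operations.RMSNorm(512, elementwise_affine=True, eps=1e-6,
                          device="cpu", dtype=torch.float32)

x = torch.randn(2, 16, 512)
print(norm(x).shape)  # torch.Size([2, 16, 512])

Routing construction through an operations namespace is what lets callers
substitute weight-casting or no-init variants of each layer without touching
the model code, which appears to be the point of making RMSNorm go through
the same indirection as Linear and LayerNorm already do.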