

[pull] master from comfyanonymous:master #116

Merged · 4 commits · Jun 14, 2025
2 changes: 2 additions & 0 deletions comfy/comfy_types/node_typing.py
@@ -37,6 +37,8 @@ class IO(StrEnum):
     CONTROL_NET = "CONTROL_NET"
     VAE = "VAE"
     MODEL = "MODEL"
+    LORA_MODEL = "LORA_MODEL"
+    LOSS_MAP = "LOSS_MAP"
     CLIP_VISION = "CLIP_VISION"
     CLIP_VISION_OUTPUT = "CLIP_VISION_OUTPUT"
     STYLE_MODEL = "STYLE_MODEL"
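The two new IO members give node definitions socket types for passing a trainable LoRA object and a loss dictionary between nodes. A hedged sketch of how a node might reference them; the node class itself is hypothetical and not part of this PR:

```python
from comfy.comfy_types.node_typing import IO


class ApplyTrainedLoraExample:
    """Hypothetical node: consumes a LORA_MODEL socket and returns a MODEL."""

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "model": (IO.MODEL, {}),
            "lora_model": (IO.LORA_MODEL, {}),
        }}

    RETURN_TYPES = (IO.MODEL,)
    FUNCTION = "apply"
    CATEGORY = "example"

    def apply(self, model, lora_model):
        # IO is a StrEnum, so IO.LORA_MODEL compares equal to the plain
        # string "LORA_MODEL"; the socket type is ultimately just that string.
        return (model,)
```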
6 changes: 3 additions & 3 deletions comfy/ldm/modules/attention.py
@@ -753,7 +753,7 @@ def forward(self, x, context=None, transformer_options={}):
             for p in patch:
                 n = p(n, extra_options)
 
-        x += n
+        x = n + x
         if "middle_patch" in transformer_patches:
             patch = transformer_patches["middle_patch"]
             for p in patch:
@@ -793,12 +793,12 @@ def forward(self, x, context=None, transformer_options={}):
             for p in patch:
                 n = p(n, extra_options)
 
-        x += n
+        x = n + x
         if self.is_res:
             x_skip = x
         x = self.ff(self.norm3(x))
         if self.is_res:
-            x += x_skip
+            x = x_skip + x
 
         return x
 
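Switching from the in-place `x += n` to `x = n + x` matters once gradients flow through these activations, as they do on the LoRA training path this PR enables: mutating a tensor that an earlier op saved for backward invalidates the autograd graph. A minimal, standalone sketch of that failure mode (not ComfyUI code; `exp` is just an op that saves its output for backward):

```python
import torch

w = torch.randn(8, requires_grad=True)

# Out-of-place addition: safe to backpropagate through.
x = torch.exp(w)
x = torch.randn(8) + x
x.sum().backward()  # works

# In-place addition: exp() saved its output for backward, and += mutates it.
w.grad = None
x = torch.exp(w)
x += torch.randn(8)
try:
    x.sum().backward()
except RuntimeError as err:
    print("in-place add broke autograd:", err)
```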
19 changes: 11 additions & 8 deletions comfy/model_patcher.py
@@ -17,23 +17,26 @@
 """
 
 from __future__ import annotations
-from typing import Optional, Callable
-import torch
+
+import collections
 import copy
 import inspect
 import logging
-import uuid
-import collections
 import math
+import uuid
+from typing import Callable, Optional
+
+import torch
+
-import comfy.utils
 import comfy.float
-import comfy.model_management
-import comfy.lora
 import comfy.hooks
+import comfy.lora
+import comfy.model_management
 import comfy.patcher_extension
-from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
+import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
+from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
 
 
 def string_to_seed(data):
     crc = 0xFFFFFFFF
23 changes: 22 additions & 1 deletion comfy/sd.py
@@ -1081,7 +1081,28 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     return (model_patcher, clip, vae, clipvision)
 
 
-def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
+def load_diffusion_model_state_dict(sd, model_options={}):
+    """
+    Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
+
+    Args:
+        sd (dict): State dictionary containing model weights and configuration
+        model_options (dict, optional): Additional options for model loading. Supports:
+            - dtype: Override model data type
+            - custom_operations: Custom model operations
+            - fp8_optimizations: Enable FP8 optimizations
+
+    Returns:
+        ModelPatcher: A wrapped model instance that handles device management and weight loading.
+        Returns None if the model configuration cannot be detected.
+
+    The function:
+    1. Detects and handles different model formats (regular, diffusers, mmdit)
+    2. Configures model dtype based on parameters and device capabilities
+    3. Handles weight conversion and device placement
+    4. Manages model optimization settings
+    5. Loads weights and returns a device-managed model instance
+    """
     dtype = model_options.get("dtype", None)
 
     #Allow loading unets from checkpoint files
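A minimal usage sketch of the documented entry point; the checkpoint path and dtype choice are illustrative, not from this PR:

```python
import torch

import comfy.sd
import comfy.utils

# Load a diffusion-model-only checkpoint (regular or diffusers layout).
sd = comfy.utils.load_torch_file("models/diffusion_models/example.safetensors")
model_patcher = comfy.sd.load_diffusion_model_state_dict(
    sd, model_options={"dtype": torch.float16}
)
if model_patcher is None:
    raise RuntimeError("could not detect a model config from this state dict")
```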
8 changes: 7 additions & 1 deletion comfy/weight_adapter/__init__.py
@@ -1,4 +1,4 @@
-from .base import WeightAdapterBase
+from .base import WeightAdapterBase, WeightAdapterTrainBase
 from .lora import LoRAAdapter
 from .loha import LoHaAdapter
 from .lokr import LoKrAdapter
@@ -15,3 +15,9 @@
     OFTAdapter,
     BOFTAdapter,
 ]
+
+__all__ = [
+    "WeightAdapterBase",
+    "WeightAdapterTrainBase",
+    "adapters"
+] + [a.__name__ for a in adapters]
35 changes: 33 additions & 2 deletions comfy/weight_adapter/base.py
@@ -12,12 +12,20 @@ class WeightAdapterBase:
     weights: list[torch.Tensor]
 
     @classmethod
-    def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
+    def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]:
         raise NotImplementedError
 
     def to_train(self) -> "WeightAdapterTrainBase":
         raise NotImplementedError
 
+    @classmethod
+    def create_train(cls, weight, *args) -> "WeightAdapterTrainBase":
+        """
+        weight: The original weight tensor to be modified.
+        *args: Additional arguments for configuration, such as rank, alpha etc.
+        """
+        raise NotImplementedError
+
     def calculate_weight(
         self,
         weight,
@@ -33,10 +41,22 @@ def calculate_weight(
 
 
 class WeightAdapterTrainBase(nn.Module):
+    # We follow the scheme of PR #7032
     def __init__(self):
         super().__init__()
 
-    # [TODO] Collaborate with LoRA training PR #7032
+    def __call__(self, w):
+        """
+        w: The original weight tensor to be modified.
+        """
+        raise NotImplementedError
+
+    def passive_memory_usage(self):
+        raise NotImplementedError("passive_memory_usage is not implemented")
+
+    def move_to(self, device):
+        self.to(device)
+        return self.passive_memory_usage()
 
 
 def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
@@ -102,3 +122,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
     padded_tensor[new_slices] = tensor[orig_slices]
 
     return padded_tensor
+
+
+def tucker_weight_from_conv(up, down, mid):
+    up = up.reshape(up.size(0), up.size(1))
+    down = down.reshape(down.size(0), down.size(1))
+    return torch.einsum("m n ..., i m, n j -> i j ...", mid, up, down)
+
+
+def tucker_weight(wa, wb, t):
+    temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
+    return torch.einsum("i j ..., i r -> r j ...", temp, wa)
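A self-contained shape check of the einsum used by `tucker_weight_from_conv`, which rebuilds a full convolution kernel from a Tucker-style decomposition; the dimensions below are illustrative:

```python
import torch

out_ch, in_ch, rank, k = 64, 32, 8, 3
up = torch.randn(out_ch, rank, 1, 1)    # flattened to (out_ch, rank)
down = torch.randn(rank, in_ch, 1, 1)   # flattened to (rank, in_ch)
mid = torch.randn(rank, rank, k, k)     # core tensor carrying the k x k kernel

up2 = up.reshape(up.size(0), up.size(1))
down2 = down.reshape(down.size(0), down.size(1))
full = torch.einsum("m n ..., i m, n j -> i j ...", mid, up2, down2)
assert full.shape == (out_ch, in_ch, k, k)
```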
66 changes: 65 additions & 1 deletion comfy/weight_adapter/lora.py
@@ -3,7 +3,56 @@
 
 import torch
 import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
+from .base import (
+    WeightAdapterBase,
+    WeightAdapterTrainBase,
+    weight_decompose,
+    pad_tensor_to_shape,
+    tucker_weight_from_conv,
+)
 
 
+class LoraDiff(WeightAdapterTrainBase):
+    def __init__(self, weights):
+        super().__init__()
+        mat1, mat2, alpha, mid, dora_scale, reshape = weights
+        out_dim, rank = mat1.shape[0], mat1.shape[1]
+        rank, in_dim = mat2.shape[0], mat2.shape[1]
+        if mid is not None:
+            convdim = mid.ndim - 2
+            layer = (
+                torch.nn.Conv1d,
+                torch.nn.Conv2d,
+                torch.nn.Conv3d
+            )[convdim]
+        else:
+            layer = torch.nn.Linear
+        self.lora_up = layer(rank, out_dim, bias=False)
+        self.lora_down = layer(in_dim, rank, bias=False)
+        self.lora_up.weight.data.copy_(mat1)
+        self.lora_down.weight.data.copy_(mat2)
+        if mid is not None:
+            self.lora_mid = layer(mid, rank, bias=False)
+            self.lora_mid.weight.data.copy_(mid)
+        else:
+            self.lora_mid = None
+        self.rank = rank
+        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
+
+    def __call__(self, w):
+        org_dtype = w.dtype
+        if self.lora_mid is None:
+            diff = self.lora_up.weight @ self.lora_down.weight
+        else:
+            diff = tucker_weight_from_conv(
+                self.lora_up.weight, self.lora_down.weight, self.lora_mid.weight
+            )
+        scale = self.alpha / self.rank
+        weight = w + scale * diff.reshape(w.shape)
+        return weight.to(org_dtype)
+
+    def passive_memory_usage(self):
+        return sum(param.numel() * param.element_size() for param in self.parameters())
+
+
 class LoRAAdapter(WeightAdapterBase):
@@ -13,6 +62,21 @@ def __init__(self, loaded_keys, weights):
         self.loaded_keys = loaded_keys
         self.weights = weights
 
+    @classmethod
+    def create_train(cls, weight, rank=1, alpha=1.0):
+        out_dim = weight.shape[0]
+        in_dim = weight.shape[1:].numel()
+        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
+        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.kaiming_uniform_(mat1, a=5**0.5)
+        torch.nn.init.constant_(mat2, 0.0)
+        return LoraDiff(
+            (mat1, mat2, alpha, None, None, None)
+        )
+
+    def to_train(self):
+        return LoraDiff(self.weights)
+
     @classmethod
     def load(
         cls,
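A minimal sketch of the new training path: `create_train` builds a `LoraDiff` whose down projection starts at zero, so the adapted weight is initially identical to the original. The rank, alpha, and layer shape below are illustrative:

```python
import torch

from comfy.weight_adapter.lora import LoRAAdapter

base_weight = torch.randn(320, 768)          # e.g. a Linear layer weight
lora = LoRAAdapter.create_train(base_weight, rank=8, alpha=8.0)

adapted = lora(base_weight)                  # w + (alpha / rank) * (up @ down)
assert torch.allclose(adapted, base_weight)  # lora_down is zero-initialised

# Only the low-rank factors carry gradients during training.
trainable = [p for p in lora.parameters() if p.requires_grad]
print(lora.passive_memory_usage(), "bytes of adapter weights")
```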
2 changes: 1 addition & 1 deletion comfy_config/types.py
@@ -90,4 +90,4 @@ class PyProjectSettings(BaseSettings):
 
     tool: dict = Field(default_factory=dict)
 
-    model_config = SettingsConfigDict()
+    model_config = SettingsConfigDict(extra='allow')
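With `extra='allow'`, unknown keys found in a project's pyproject data are kept on the settings object instead of being silently dropped under pydantic's default `extra` handling. A standalone sketch, not ComfyUI code:

```python
from pydantic_settings import BaseSettings, SettingsConfigDict


class ExampleSettings(BaseSettings):
    model_config = SettingsConfigDict(extra="allow")
    name: str = "unnamed"


settings = ExampleSettings(**{"name": "demo", "custom_key": 123})
print(settings.custom_key)  # 123 -- preserved rather than discarded
```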