Use common function for casting weights to input.
This commit is contained in:
parent
79040635da
commit
25853d0be8
|
@ -9,6 +9,7 @@ from einops import rearrange
|
|||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
import math
|
||||
import comfy.ops
|
||||
|
||||
class FourierFeatures(nn.Module):
|
||||
def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
|
||||
|
@ -18,7 +19,7 @@ class FourierFeatures(nn.Module):
|
|||
[out_features // 2, in_features], dtype=dtype, device=device))
|
||||
|
||||
def forward(self, input):
|
||||
f = 2 * math.pi * input @ self.weight.T.to(dtype=input.dtype, device=input.device)
|
||||
f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input)
|
||||
return torch.cat([f.cos(), f.sin()], dim=-1)
|
||||
|
||||
# norms
|
||||
|
@ -38,9 +39,9 @@ class LayerNorm(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
beta = self.beta
|
||||
if self.beta is not None:
|
||||
beta = beta.to(dtype=x.dtype, device=x.device)
|
||||
return F.layer_norm(x, x.shape[-1:], weight=self.gamma.to(dtype=x.dtype, device=x.device), bias=beta)
|
||||
if beta is not None:
|
||||
beta = comfy.ops.cast_to_input(beta, x)
|
||||
return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta)
|
||||
|
||||
class GLU(nn.Module):
|
||||
def __init__(
|
||||
|
@ -123,7 +124,9 @@ class RotaryEmbedding(nn.Module):
|
|||
scale_base = 512,
|
||||
interpolation_factor = 1.,
|
||||
base = 10000,
|
||||
base_rescale_factor = 1.
|
||||
base_rescale_factor = 1.,
|
||||
dtype=None,
|
||||
device=None,
|
||||
):
|
||||
super().__init__()
|
||||
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||
|
@ -131,8 +134,8 @@ class RotaryEmbedding(nn.Module):
|
|||
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||
base *= base_rescale_factor ** (dim / (dim - 2))
|
||||
|
||||
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer('inv_freq', inv_freq)
|
||||
# inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))
|
||||
|
||||
assert interpolation_factor >= 1.
|
||||
self.interpolation_factor = interpolation_factor
|
||||
|
@ -161,14 +164,14 @@ class RotaryEmbedding(nn.Module):
|
|||
|
||||
t = t / self.interpolation_factor
|
||||
|
||||
freqs = torch.einsum('i , j -> i j', t, self.inv_freq.to(dtype=dtype, device=device))
|
||||
freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t))
|
||||
freqs = torch.cat((freqs, freqs), dim = -1)
|
||||
|
||||
if self.scale is None:
|
||||
return freqs, 1.
|
||||
|
||||
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
|
||||
scale = self.scale.to(dtype=dtype, device=device) ** rearrange(power, 'n -> n 1')
|
||||
scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
|
||||
scale = torch.cat((scale, scale), dim = -1)
|
||||
|
||||
return freqs, scale
|
||||
|
@ -568,7 +571,7 @@ class ContinuousTransformer(nn.Module):
|
|||
self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()
|
||||
|
||||
if rotary_pos_emb:
|
||||
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
|
||||
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
|
||||
else:
|
||||
self.rotary_pos_emb = None
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.ops
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
@ -427,7 +428,7 @@ class MMDiT(nn.Module):
|
|||
max_dim = max(h, w)
|
||||
|
||||
cur_dim = self.h_max
|
||||
pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)
|
||||
pos_encoding = comfy.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x)
|
||||
|
||||
if max_dim > cur_dim:
|
||||
pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
|
||||
|
@ -455,7 +456,7 @@ class MMDiT(nn.Module):
|
|||
t = timestep
|
||||
|
||||
c = self.cond_seq_linear(c_seq) # B, T_c, D
|
||||
c = torch.cat([self.register_tokens.to(device=c.device, dtype=c.dtype).repeat(c.size(0), 1, 1), c], dim=1)
|
||||
c = torch.cat([comfy.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)
|
||||
|
||||
global_cond = self.t_embedder(t, x.dtype) # B, D
|
||||
|
||||
|
|
|
@ -19,14 +19,7 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
class Linear(torch.nn.Linear):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
class Conv2d(torch.nn.Conv2d):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
import comfy.ops
|
||||
|
||||
class OptimizedAttention(nn.Module):
|
||||
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
|
||||
|
@ -78,13 +71,13 @@ class GlobalResponseNorm(nn.Module):
|
|||
"from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
|
||||
def __init__(self, dim, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
|
||||
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
|
||||
self.gamma = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
|
||||
self.beta = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x):
|
||||
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
||||
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
||||
return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x
|
||||
return comfy.ops.cast_to_input(self.gamma, x) * (x * Nx) + comfy.ops.cast_to_input(self.beta, x) + x
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
|
|
|
@ -4,6 +4,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
|
||||
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
from torch.utils import checkpoint
|
||||
|
@ -234,7 +235,7 @@ class HunYuanDiT(nn.Module):
|
|||
|
||||
if self.use_style_cond:
|
||||
# Here we use a default learned embedder layer for future extension.
|
||||
self.style_embedder = nn.Embedding(1, hidden_size, dtype=dtype, device=device)
|
||||
self.style_embedder = operations.Embedding(1, hidden_size, dtype=dtype, device=device)
|
||||
self.extra_in_dim += hidden_size
|
||||
|
||||
# Text embedding for `add`
|
||||
|
@ -321,7 +322,7 @@ class HunYuanDiT(nn.Module):
|
|||
b_t5, l_t5, c_t5 = text_states_t5.shape
|
||||
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
|
||||
|
||||
padding = self.text_embedding_padding.to(text_states)
|
||||
padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
|
||||
|
||||
text_states[:,-self.text_len:] = torch.where(text_states_mask[:,-self.text_len:].unsqueeze(2), text_states[:,-self.text_len:], padding[:self.text_len])
|
||||
text_states_t5[:,-self.text_len_t5:] = torch.where(text_states_t5_mask[:,-self.text_len_t5:].unsqueeze(2), text_states_t5[:,-self.text_len_t5:], padding[self.text_len:])
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.attention import optimized_attention #TODO
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.ops
|
||||
|
||||
class AttentionPool(nn.Module):
|
||||
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
|
||||
|
@ -19,7 +19,7 @@ class AttentionPool(nn.Module):
|
|||
x = x[:,:self.positional_embedding.shape[0] - 1]
|
||||
x = x.permute(1, 0, 2) # NLC -> LNC
|
||||
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
|
||||
x = x + self.positional_embedding[:, None, :].to(dtype=x.dtype, device=x.device) # (L+1)NC
|
||||
x = x + comfy.ops.cast_to_input(self.positional_embedding[:, None, :], x) # (L+1)NC
|
||||
|
||||
q = self.q_proj(x[:1])
|
||||
k = self.k_proj(x)
|
||||
|
|
|
@ -8,6 +8,7 @@ import torch.nn as nn
|
|||
from .. import attention
|
||||
from einops import rearrange, repeat
|
||||
from .util import timestep_embedding
|
||||
import comfy.ops
|
||||
|
||||
def default(x, y):
|
||||
if x is not None:
|
||||
|
@ -926,7 +927,7 @@ class MMDiT(nn.Module):
|
|||
context = self.context_processor(context)
|
||||
|
||||
hw = x.shape[-2:]
|
||||
x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
|
||||
x = self.x_embedder(x) + comfy.ops.cast_to_input(self.cropped_pos_embed(hw, device=x.device), x)
|
||||
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
||||
if y is not None and self.y_embedder is not None:
|
||||
y = self.y_embedder(y) # (N, D)
|
||||
|
|
25
comfy/ops.py
25
comfy/ops.py
|
@ -19,14 +19,17 @@
|
|||
import torch
|
||||
import comfy.model_management
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False):
|
||||
return weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
|
||||
def cast_bias_weight(s, input):
|
||||
bias = None
|
||||
non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
|
||||
if s.bias is not None:
|
||||
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
bias = cast_to_input(s.bias, input, non_blocking=non_blocking)
|
||||
if s.bias_function is not None:
|
||||
bias = s.bias_function(bias)
|
||||
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
|
||||
weight = cast_to_input(s.weight, input, non_blocking=non_blocking)
|
||||
if s.weight_function is not None:
|
||||
weight = s.weight_function(weight)
|
||||
return weight, bias
|
||||
|
@ -168,6 +171,21 @@ class disable_weight_init:
|
|||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
self.bias = None
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.comfy_cast_weights:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def conv_nd(s, dims, *args, **kwargs):
|
||||
if dims == 2:
|
||||
|
@ -202,3 +220,6 @@ class manual_cast(disable_weight_init):
|
|||
|
||||
class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class Embedding(disable_weight_init.Embedding):
|
||||
comfy_cast_weights = True
|
||||
|
|
Loading…
Reference in New Issue