2024-11-22 13:44:42 +00:00
|
|
|
from typing import Tuple, Union
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from .dual_conv3d import DualConv3d
|
|
|
|
from .causal_conv3d import CausalConv3d
|
2024-11-22 22:17:11 +00:00
|
|
|
import comfy.ops
|
|
|
|
# Namespace of layer classes (Conv2d / Conv3d) used by the conv factories in
# this module. NOTE(review): the name ``disable_weight_init`` suggests these
# variants skip PyTorch's default parameter initialization (weights presumably
# loaded from a checkpoint afterwards) -- confirm against comfy.ops.
ops = comfy.ops.disable_weight_init
|
2024-11-22 13:44:42 +00:00
|
|
|
|
|
|
|
def make_conv_nd(
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
    bias=True,
    causal=False,
):
    """Create a convolution module whose class is selected by ``dims``.

    Selection:
        * ``2``      -> ``ops.Conv2d``
        * ``3``      -> ``CausalConv3d`` if ``causal`` is truthy, else ``ops.Conv3d``
        * ``(2, 1)`` -> ``DualConv3d``

    ``causal`` is only consulted when ``dims == 3``; it has no effect for the
    other selectors. For ``(2, 1)`` the ``dilation`` and ``groups`` arguments
    are NOT forwarded to ``DualConv3d``.

    Args:
        dims: Dimensionality selector (``2``, ``3`` or ``(2, 1)``).
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size, passed through unchanged.
        stride: Convolution stride, passed through unchanged.
        padding: Padding, passed through unchanged (also to ``CausalConv3d``).
        dilation: Dilation; forwarded for ``2`` and ``3`` only.
        groups: Group count; forwarded for ``2`` and ``3`` only.
        bias: Whether the layer has a learnable bias.
        causal: Select ``CausalConv3d`` instead of ``ops.Conv3d`` for ``dims == 3``.

    Returns:
        The constructed convolution module.

    Raises:
        ValueError: If ``dims`` matches none of the selectors above.
    """
    # Keyword arguments that every backend in this factory receives.
    base_kwargs = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=bias,
    )

    if dims == 2:
        layer_cls = ops.Conv2d
    elif dims == 3:
        layer_cls = CausalConv3d if causal else ops.Conv3d
    elif dims == (2, 1):
        # NOTE(review): dilation/groups are deliberately left out here, matching
        # the existing call; confirm whether DualConv3d supports them before
        # forwarding them.
        return DualConv3d(**base_kwargs)
    else:
        raise ValueError(f"unsupported dimensions: {dims}")

    # The 2-D and 3-D backends additionally take dilation and groups.
    return layer_cls(dilation=dilation, groups=groups, **base_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def make_linear_nd(
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    bias=True,
):
    """Build a pointwise (``kernel_size=1``) convolution used as a linear layer.

    A kernel of size 1 applies the same channel-mixing linear map at every
    position, so this acts as a per-position linear projection.

    Args:
        dims: Dimensionality selector. ``2`` builds ``ops.Conv2d``; ``3`` and
            ``(2, 1)`` (the same selectors ``make_conv_nd`` accepts) both build
            ``ops.Conv3d``. Presumably ``(2, 1)`` maps to a plain 3-D conv
            because a 1x1x1 kernel has nothing to factorize -- confirm.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        bias: Whether the layer has a learnable bias.

    Returns:
        The constructed pointwise convolution module.

    Raises:
        ValueError: If ``dims`` is not ``2``, ``3`` or ``(2, 1)``.
    """
    if dims == 2:
        return ops.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
        )
    if dims in (3, (2, 1)):
        return ops.Conv3d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
        )
    raise ValueError(f"unsupported dimensions: {dims}")
|