2023-06-14 15:17:59 +00:00
|
|
|
import torch
|
2023-06-15 00:13:08 +00:00
|
|
|
from contextlib import contextmanager
|
2023-06-14 15:17:59 +00:00
|
|
|
|
|
|
|
class Linear(torch.nn.Module):
    """Drop-in stand-in for ``torch.nn.Linear`` whose parameters start uninitialized.

    ``weight`` (shape ``(out_features, in_features)``) and, optionally, ``bias``
    (shape ``(out_features,)``) are allocated with ``torch.empty`` and no
    initializer is ever run on them. Their contents are therefore arbitrary
    until real values are written in. Presumably this is done by loading a
    checkpoint; confirm with callers.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        super().__init__()

        def new_param(*shape):
            # torch.empty skips filling memory; that is the whole point of this class.
            return torch.nn.Parameter(torch.empty(shape, device=device, dtype=dtype))

        self.in_features = in_features
        self.out_features = out_features
        self.weight = new_param(out_features, in_features)
        if not bias:
            # Register an explicit None slot (rather than a plain attribute) so
            # the module layout matches torch.nn.Linear(bias=False).
            self.register_parameter('bias', None)
        else:
            self.bias = new_param(out_features)

    def forward(self, input):
        """Apply ``input @ weight.T + bias`` (bias skipped when it is None)."""
        weight, bias = self.weight, self.bias
        return torch.nn.functional.linear(input, weight, bias)
|
2023-06-14 23:46:08 +00:00
|
|
|
|
|
|
|
class Conv2d(torch.nn.Conv2d):
    """``torch.nn.Conv2d`` whose parameter initialization is disabled.

    ``reset_parameters`` is the hook ``torch.nn`` layers use to fill their
    weights with initial values. Overriding it with a no-op leaves the
    parameters as allocated. Presumably they are overwritten later by loaded
    weights; confirm with callers.
    """

    def reset_parameters(self):
        # Deliberately do nothing: skip the default initialization.
        pass
|
2023-06-15 00:13:08 +00:00
|
|
|
|
2023-08-18 06:46:11 +00:00
|
|
|
def conv_nd(dims, *args, **kwargs):
    """Build a convolution layer for ``dims`` spatial dimensions.

    Only ``dims == 2`` is supported. Remaining arguments are forwarded
    unchanged to this module's ``Conv2d``.

    Raises:
        ValueError: for any other ``dims``.
    """
    if dims != 2:
        raise ValueError(f"unsupported dimensions: {dims}")
    return Conv2d(*args, **kwargs)
|
2023-06-15 00:13:08 +00:00
|
|
|
|
|
|
|
@contextmanager
def use_comfy_ops(device=None, dtype=None):  # Admittedly a hack; no cleaner approach found yet.
    """Temporarily route ``torch.nn.Linear(...)`` to this module's ``Linear``.

    While the context is active, ``torch.nn.Linear`` is replaced by a factory
    that builds the uninitialized ``Linear`` defined in this module. A non-None
    ``device`` / ``dtype`` given here overrides whatever the constructor call
    passes. The original ``torch.nn.Linear`` is restored on exit, including
    when the body raises.

    Caveats: the replacement is a plain function, not a class, so
    ``isinstance(x, torch.nn.Linear)`` or subclassing ``torch.nn.Linear``
    inside the context will not behave as usual. The patch mutates the global
    ``torch.nn`` module, so it is visible to all code running meanwhile.
    Only ``Linear`` is patched.
    """
    saved_linear = torch.nn.Linear
    forced_device, forced_dtype = device, dtype

    def _patched_linear(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
        # A forced value from the context always wins over the per-call argument.
        device = device if forced_device is None else forced_device
        dtype = dtype if forced_dtype is None else forced_dtype
        return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)

    torch.nn.Linear = _patched_linear
    try:
        yield
    finally:
        torch.nn.Linear = saved_linear
|