Remove unused code and torchdiffeq dependency.
This commit is contained in:
parent
1141029a4a
commit
c910b4a01c
|
@ -3,7 +3,6 @@ import math
|
||||||
from scipy import integrate
|
from scipy import integrate
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torchdiffeq import odeint
|
|
||||||
import torchsde
|
import torchsde
|
||||||
from tqdm.auto import trange, tqdm
|
from tqdm.auto import trange, tqdm
|
||||||
|
|
||||||
|
@ -287,30 +286,6 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
|
|
||||||
extra_args = {} if extra_args is None else extra_args
|
|
||||||
s_in = x.new_ones([x.shape[0]])
|
|
||||||
v = torch.randint_like(x, 2) * 2 - 1
|
|
||||||
fevals = 0
|
|
||||||
def ode_fn(sigma, x):
|
|
||||||
nonlocal fevals
|
|
||||||
with torch.enable_grad():
|
|
||||||
x = x[0].detach().requires_grad_()
|
|
||||||
denoised = model(x, sigma * s_in, **extra_args)
|
|
||||||
d = to_d(x, sigma, denoised)
|
|
||||||
fevals += 1
|
|
||||||
grad = torch.autograd.grad((d * v).sum(), x)[0]
|
|
||||||
d_ll = (v * grad).flatten(1).sum(1)
|
|
||||||
return d.detach(), d_ll
|
|
||||||
x_min = x, x.new_zeros([x.shape[0]])
|
|
||||||
t = x.new_tensor([sigma_min, sigma_max])
|
|
||||||
sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
|
|
||||||
latent, delta_ll = sol[0][-1], sol[1][-1]
|
|
||||||
ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
|
|
||||||
return ll_prior + delta_ll, {'fevals': fevals}
|
|
||||||
|
|
||||||
|
|
||||||
class PIDStepSizeController:
|
class PIDStepSizeController:
|
||||||
"""A PID controller for ODE adaptive step size control."""
|
"""A PID controller for ODE adaptive step size control."""
|
||||||
def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
|
def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
torch
|
torch
|
||||||
torchdiffeq
|
|
||||||
torchsde
|
torchsde
|
||||||
einops
|
einops
|
||||||
transformers>=4.25.1
|
transformers>=4.25.1
|
||||||
|
|
Loading…
Reference in New Issue