from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
import torch
import collections
from comfy import model_management
import math
import logging
import comfy.sampler_helpers

def get_area_and_mult(conds, x_in, timestep_in):
    dims = tuple(x_in.shape[2:])
    area = None
    strength = 1.0

    if 'timestep_start' in conds:
        timestep_start = conds['timestep_start']
        if timestep_in[0] > timestep_start:
            return None
    if 'timestep_end' in conds:
        timestep_end = conds['timestep_end']
        if timestep_in[0] < timestep_end:
            return None
    if 'area' in conds:
        area = list(conds['area'])
    if 'strength' in conds:
        strength = conds['strength']

    input_x = x_in
    if area is not None:
        for i in range(len(dims)):
            area[i] = min(input_x.shape[i + 2] - area[len(dims) + i], area[i])
            input_x = input_x.narrow(i + 2, area[len(dims) + i], area[i])

    if 'mask' in conds:
        # Scale the mask to the size of the input
        # The mask should have been resized as we began the sampling process
        mask_strength = 1.0
        if "mask_strength" in conds:
            mask_strength = conds["mask_strength"]
        mask = conds['mask']
        assert(mask.shape[1:] == x_in.shape[2:])

        mask = mask[:input_x.shape[0]]
        if area is not None:
            for i in range(len(dims)):
                mask = mask.narrow(i + 1, area[len(dims) + i], area[i])

        mask = mask * mask_strength
        mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
    else:
        mask = torch.ones_like(input_x)
    mult = mask * strength

    if 'mask' not in conds and area is not None:
        rr = 8
        for i in range(len(dims)):
            if area[len(dims) + i] != 0:
                for t in range(rr):
                    m = mult.narrow(i + 2, t, 1)
                    m *= ((1.0 / rr) * (t + 1))
            if (area[i] + area[len(dims) + i]) < x_in.shape[i + 2]:
                for t in range(rr):
                    m = mult.narrow(i + 2, area[i] - 1 - t, 1)
                    m *= ((1.0 / rr) * (t + 1))

    conditioning = {}
    model_conds = conds["model_conds"]
    for c in model_conds:
        conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)

    control = conds.get('control', None)

    patches = None
    if 'gligen' in conds:
        gligen = conds['gligen']
        patches = {}
        gligen_type = gligen[0]
        gligen_model = gligen[1]
        if gligen_type == "position":
            gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device)
        else:
            gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device)

        patches['middle_patch'] = [gligen_patch]

    cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches'])
    return cond_obj(input_x, mult, conditioning, area, control, patches)

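# A minimal sketch of the 'area' layout used above, assuming a 2D latent.
# An area is stored as (sizes..., offsets...), i.e. (h, w, y, x) in 2D, and the
# narrow() calls crop input_x down to it. Hypothetical values for illustration:
#
#   x_in = torch.zeros(1, 4, 64, 64)               # [batch, channels, h, w]
#   area = [32, 32, 16, 16]                        # h=w=32 starting at y=x=16
#   cropped = x_in.narrow(2, 16, 32).narrow(3, 16, 32)
#   assert cropped.shape == (1, 4, 32, 32)
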
def cond_equal_size(c1, c2):
    if c1 is c2:
        return True
    if c1.keys() != c2.keys():
        return False
    for k in c1:
        if not c1[k].can_concat(c2[k]):
            return False
    return True

def can_concat_cond(c1, c2):
    if c1.input_x.shape != c2.input_x.shape:
        return False

    def objects_concatable(obj1, obj2):
        if (obj1 is None) != (obj2 is None):
            return False
        if obj1 is not None:
            if obj1 is not obj2:
                return False
        return True

    if not objects_concatable(c1.control, c2.control):
        return False

    if not objects_concatable(c1.patches, c2.patches):
        return False

    return cond_equal_size(c1.conditioning, c2.conditioning)

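# can_concat_cond() decides whether two conds may share one model forward pass:
# the inputs must have identical shapes, control and patches must be the same
# objects (or both absent), and every conditioning entry must report
# can_concat(). This is what lets calc_cond_batch() below batch positive and
# negative conds together.
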
def cond_cat(c_list):
    temp = {}
    for x in c_list:
        for k in x:
            cur = temp.get(k, [])
            cur.append(x[k])
            temp[k] = cur

    out = {}
    for k in temp:
        conds = temp[k]
        out[k] = conds[0].concat(conds[1:])

    return out

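# Sketch of what cond_cat() produces, assuming cond objects that expose a
# concat() method taking a list (object names hypothetical): values are grouped
# per key, then concatenated so the batched model call sees one object per key:
#   cond_cat([{'c_crossattn': a}, {'c_crossattn': b}]) -> {'c_crossattn': a.concat([b])}
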
def calc_cond_batch(model, conds, x_in, timestep, model_options):
    out_conds = []
    out_counts = []
    to_run = []

    for i in range(len(conds)):
        out_conds.append(torch.zeros_like(x_in))
        out_counts.append(torch.ones_like(x_in) * 1e-37)

        cond = conds[i]
        if cond is not None:
            for x in cond:
                p = get_area_and_mult(x, x_in, timestep)
                if p is None:
                    continue

                to_run += [(p, i)]

    while len(to_run) > 0:
        first = to_run[0]
        first_shape = first[0][0].shape
        to_batch_temp = []
        for x in range(len(to_run)):
            if can_concat_cond(to_run[x][0], first[0]):
                to_batch_temp += [x]

        to_batch_temp.reverse()
        to_batch = to_batch_temp[:1]

        free_memory = model_management.get_free_memory(x_in.device)
        for i in range(1, len(to_batch_temp) + 1):
            batch_amount = to_batch_temp[:len(to_batch_temp)//i]
            input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
            if model.memory_required(input_shape) < free_memory:
                to_batch = batch_amount
                break

        input_x = []
        mult = []
        c = []
        cond_or_uncond = []
        area = []
        control = None
        patches = None
        for x in to_batch:
            o = to_run.pop(x)
            p = o[0]
            input_x.append(p.input_x)
            mult.append(p.mult)
            c.append(p.conditioning)
            area.append(p.area)
            cond_or_uncond.append(o[1])
            control = p.control
            patches = p.patches

        batch_chunks = len(cond_or_uncond)
        input_x = torch.cat(input_x)
        c = cond_cat(c)
        timestep_ = torch.cat([timestep] * batch_chunks)

        if control is not None:
            c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))

        transformer_options = {}
        if 'transformer_options' in model_options:
            transformer_options = model_options['transformer_options'].copy()

        if patches is not None:
            if "patches" in transformer_options:
                cur_patches = transformer_options["patches"].copy()
                for p in patches:
                    if p in cur_patches:
                        cur_patches[p] = cur_patches[p] + patches[p]
                    else:
                        cur_patches[p] = patches[p]
                transformer_options["patches"] = cur_patches
            else:
                transformer_options["patches"] = patches

        transformer_options["cond_or_uncond"] = cond_or_uncond[:]
        transformer_options["sigmas"] = timestep

        c['transformer_options'] = transformer_options

        if 'model_function_wrapper' in model_options:
            output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
        else:
            output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)

        for o in range(batch_chunks):
            cond_index = cond_or_uncond[o]
            a = area[o]
            if a is None:
                out_conds[cond_index] += output[o] * mult[o]
                out_counts[cond_index] += mult[o]
            else:
                out_c = out_conds[cond_index]
                out_cts = out_counts[cond_index]
                dims = len(a) // 2
                for i in range(dims):
                    out_c = out_c.narrow(i + 2, a[i + dims], a[i])
                    out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
                out_c += output[o] * mult[o]
                out_cts += mult[o]

    for i in range(len(out_conds)):
        out_conds[i] /= out_counts[i]

    return out_conds

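# Rough usage sketch for calc_cond_batch(), assuming `cond`/`uncond` are
# already-converted cond lists:
#
#   out_cond, out_uncond = calc_cond_batch(model, [cond, uncond], x, timestep, model_options)
#
# One tensor shaped like x is returned per entry of the conds list, already
# divided by the mask weights accumulated in out_counts.
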
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
    logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated, please use calc_cond_batch instead.")
    return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))

def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options={}, cond=None, uncond=None):
    if "sampler_cfg_function" in model_options:
        args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
                "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
        cfg_result = x - model_options["sampler_cfg_function"](args)
    else:
        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale

    for fn in model_options.get("sampler_post_cfg_function", []):
        args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
                "sigma": timestep, "model_options": model_options, "input": x}
        cfg_result = fn(args)

    return cfg_result

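# The default branch above is plain classifier-free guidance:
#   denoised = uncond_pred + (cond_pred - uncond_pred) * cond_scale
# e.g. with cond_scale = 8.0 a prediction delta of 0.1 is amplified to 0.8
# before any sampler_post_cfg_function hooks run. (Numbers illustrative only.)
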
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
    if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
        uncond_ = None
    else:
        uncond_ = uncond

    conds = [cond, uncond_]
    out = calc_cond_batch(model, conds, x, timestep, model_options)
    return cfg_function(model, out[0], out[1], cond_scale, x, timestep, model_options=model_options, cond=cond, uncond=uncond_)

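# Note: when cond_scale is (close to) 1.0, uncond_pred cancels out of the CFG
# formula, so the uncond pass is skipped entirely (uncond_ = None) unless
# "disable_cfg1_optimization" is set in model_options.
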
class KSamplerX0Inpaint:
    def __init__(self, model, sigmas):
        self.inner_model = model
        self.sigmas = sigmas

    def __call__(self, x, sigma, denoise_mask, model_options={}, seed=None):
        if denoise_mask is not None:
            if "denoise_mask_function" in model_options:
                denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
            latent_mask = 1. - denoise_mask
            x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
        out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
        if denoise_mask is not None:
            out = out * denoise_mask + self.latent_image * latent_mask
        return out

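# Inpainting sketch: with a denoise_mask m in [0, 1], each step blends
#   x   = x * m + noise_scaling(sigma, noise, latent_image) * (1 - m)   before the model call
#   out = out * m + latent_image * (1 - m)                              after it
# so regions where m == 0 stay pinned to the (re-noised) original latent.
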
def simple_scheduler(model_sampling, steps):
    s = model_sampling
    sigs = []
    ss = len(s.sigmas) / steps
    for x in range(steps):
        sigs += [float(s.sigmas[-(1 + int(x * ss))])]
    sigs += [0.0]
    return torch.FloatTensor(sigs)

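# Example of the stride above, assuming a model_sampling table of 1000 sigmas
# and steps=10: ss == 100.0, so sigmas[-1], sigmas[-101], ..., sigmas[-901]
# are taken (highest noise first) and a final 0.0 is appended, 11 values total.
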
def ddim_scheduler(model_sampling, steps):
    s = model_sampling
    sigs = []
    ss = max(len(s.sigmas) // steps, 1)
    x = 1
    while x < len(s.sigmas):
        sigs += [float(s.sigmas[x])]
        x += ss
    sigs = sigs[::-1]
    sigs += [0.0]
    return torch.FloatTensor(sigs)

def normal_scheduler(model_sampling, steps, sgm=False, floor=False):
    s = model_sampling
    start = s.timestep(s.sigma_max)
    end = s.timestep(s.sigma_min)

    if sgm:
        timesteps = torch.linspace(start, end, steps + 1)[:-1]
    else:
        timesteps = torch.linspace(start, end, steps)

    sigs = []
    for x in range(len(timesteps)):
        ts = timesteps[x]
        sigs.append(s.sigma(ts))
    sigs += [0.0]
    return torch.FloatTensor(sigs)

def get_mask_aabb(masks):
    if masks.numel() == 0:
        # Return an empty box tensor and an empty flag tensor so callers can
        # always unpack two values.
        return torch.zeros((0, 4), device=masks.device, dtype=torch.int), torch.zeros((0,), device=masks.device, dtype=torch.bool)

    b = masks.shape[0]

    bounding_boxes = torch.zeros((b, 4), device=masks.device, dtype=torch.int)
    is_empty = torch.zeros((b,), device=masks.device, dtype=torch.bool)
    for i in range(b):
        mask = masks[i]
        if mask.numel() == 0:
            continue
        if torch.max(mask != 0) == False:
            is_empty[i] = True
            continue
        y, x = torch.where(mask)
        bounding_boxes[i, 0] = torch.min(x)
        bounding_boxes[i, 1] = torch.min(y)
        bounding_boxes[i, 2] = torch.max(x)
        bounding_boxes[i, 3] = torch.max(y)

    return bounding_boxes, is_empty

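# Tiny illustration of get_mask_aabb(), hypothetical values:
#   masks = torch.zeros(1, 8, 8); masks[0, 2:5, 3:6] = 1.0
#   boxes, empty = get_mask_aabb(masks)
#   # boxes[0] == (3, 2, 5, 4) as (min_x, min_y, max_x, max_y), empty[0] == False
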
def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
    # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
    # While we're doing this, we can also resolve the mask device and scaling for performance reasons
    for i in range(len(conditions)):
        c = conditions[i]
        if 'area' in c:
            area = c['area']
            if area[0] == "percentage":
                modified = c.copy()
                a = area[1:]
                a_len = len(a) // 2
                area = ()
                for d in range(len(dims)):
                    area += (max(1, round(a[d] * dims[d])),)
                for d in range(len(dims)):
                    area += (round(a[d + a_len] * dims[d]),)

                modified['area'] = area
                c = modified
                conditions[i] = c

        if 'mask' in c:
            mask = c['mask']
            mask = mask.to(device=device)
            modified = c.copy()
            if len(mask.shape) == len(dims):
                mask = mask.unsqueeze(0)
            if mask.shape[1:] != dims:
                mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=dims, mode='bilinear', align_corners=False).squeeze(1)

            if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
                bounds = torch.max(torch.abs(mask), dim=0).values.unsqueeze(0)
                boxes, is_empty = get_mask_aabb(bounds)
                if is_empty[0]:
                    # Use the minimum possible size for efficiency reasons. (Since the mask is all-0, this becomes a noop anyway)
                    modified['area'] = (8, 8, 0, 0)
                else:
                    box = boxes[0]
                    H, W, Y, X = (box[3] - box[1] + 1, box[2] - box[0] + 1, box[1], box[0])
                    H = max(8, H)
                    W = max(8, W)
                    area = (int(H), int(W), int(Y), int(X))
                    modified['area'] = area

            modified['mask'] = mask
            conditions[i] = modified

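# Percentage-area example, assuming a 2D latent with dims == (64, 64):
#   ("percentage", 0.5, 0.5, 0.25, 0.25) resolves to (32, 32, 16, 16),
# i.e. the first half of the tuple scales sizes by dims and the second half
# scales offsets.
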
def resolve_areas_and_cond_masks(conditions, h, w, device):
    logging.warning("WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated, please use resolve_areas_and_cond_masks_multidim instead.")
    return resolve_areas_and_cond_masks_multidim(conditions, [h, w], device)

def create_cond_with_same_area_if_none(conds, c): #TODO: handle dim != 2
    if 'area' not in c:
        return

    c_area = c['area']
    smallest = None
    for x in conds:
        if 'area' in x:
            a = x['area']
            if c_area[2] >= a[2] and c_area[3] >= a[3]:
                if a[0] + a[2] >= c_area[0] + c_area[2]:
                    if a[1] + a[3] >= c_area[1] + c_area[3]:
                        if smallest is None:
                            smallest = x
                        elif 'area' not in smallest:
                            smallest = x
                        else:
                            if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]:
                                smallest = x
        else:
            if smallest is None:
                smallest = x
    if smallest is None:
        return
    if 'area' in smallest:
        if smallest['area'] == c_area:
            return

    out = c.copy()
    out['model_conds'] = smallest['model_conds'].copy() #TODO: which fields should be copied?
    conds += [out]

def calculate_start_end_timesteps(model, conds):
    s = model.model_sampling
    for t in range(len(conds)):
        x = conds[t]

        timestep_start = None
        timestep_end = None
        if 'start_percent' in x:
            timestep_start = s.percent_to_sigma(x['start_percent'])
        if 'end_percent' in x:
            timestep_end = s.percent_to_sigma(x['end_percent'])

        if (timestep_start is not None) or (timestep_end is not None):
            n = x.copy()
            if (timestep_start is not None):
                n['timestep_start'] = timestep_start
            if (timestep_end is not None):
                n['timestep_end'] = timestep_end
            conds[t] = n

def pre_run_control(model, conds):
    s = model.model_sampling
    for t in range(len(conds)):
        x = conds[t]

        percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
        if 'control' in x:
            x['control'].pre_run(model, percent_to_timestep_function)

def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
    cond_cnets = []
    cond_other = []
    uncond_cnets = []
    uncond_other = []
    for t in range(len(conds)):
        x = conds[t]
        if 'area' not in x:
            if name in x and x[name] is not None:
                cond_cnets.append(x[name])
            else:
                cond_other.append((x, t))
    for t in range(len(uncond)):
        x = uncond[t]
        if 'area' not in x:
            if name in x and x[name] is not None:
                uncond_cnets.append(x[name])
            else:
                uncond_other.append((x, t))

    if len(uncond_cnets) > 0:
        return

    for x in range(len(cond_cnets)):
        temp = uncond_other[x % len(uncond_other)]
        o = temp[0]
        n = o.copy()
        n[name] = uncond_fill_func(cond_cnets, x)
        if name in o and o[name] is not None:
            uncond += [n]
        else:
            uncond[temp[1]] = n

def encode_model_conds(model_function, conds, noise, device, prompt_type, **kwargs):
    for t in range(len(conds)):
        x = conds[t]
        params = x.copy()
        params["device"] = device
        params["noise"] = noise

        default_width = None
        if len(noise.shape) >= 4: #TODO: 8 multiple should be set by the model
            default_width = noise.shape[3] * 8
        params["width"] = params.get("width", default_width)
        params["height"] = params.get("height", noise.shape[2] * 8)
        params["prompt_type"] = params.get("prompt_type", prompt_type)
        for k in kwargs:
            if k not in params:
                params[k] = kwargs[k]

        out = model_function(**params)
        x = x.copy()
        model_conds = x['model_conds'].copy()
        for k in out:
            model_conds[k] = out[k]
        x['model_conds'] = model_conds
        conds[t] = x
    return conds

class Sampler:
    def sample(self):
        pass

    def max_denoise(self, model_wrap, sigmas):
        max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max)
        sigma = float(sigmas[0])
        return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma

KSAMPLER_NAMES = ["euler", "euler_pp", "euler_ancestral", "euler_ancestral_pp", "heun", "heunpp2", "dpm_2", "dpm_2_ancestral",
                  "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
                  "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                  "ipndm", "ipndm_v"]

class KSAMPLER(Sampler):
    def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
        self.sampler_function = sampler_function
        self.extra_options = extra_options
        self.inpaint_options = inpaint_options

    def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
        extra_args["denoise_mask"] = denoise_mask
        model_k = KSamplerX0Inpaint(model_wrap, sigmas)
        model_k.latent_image = latent_image
        if self.inpaint_options.get("random", False): #TODO: Should this be the default?
            generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
            model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
        else:
            model_k.noise = noise

        noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas[0], noise, latent_image, self.max_denoise(model_wrap, sigmas))

        k_callback = None
        total_steps = len(sigmas) - 1
        if callback is not None:
            k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)

        samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
        samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples)
        return samples

def ksampler(sampler_name, extra_options={}, inpaint_options={}):
    if sampler_name == "dpm_fast":
        def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
            if len(sigmas) <= 1:
                return noise

            sigma_min = sigmas[-1]
            if sigma_min == 0:
                sigma_min = sigmas[-2]
            total_steps = len(sigmas) - 1
            return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
        sampler_function = dpm_fast_function
    elif sampler_name == "dpm_adaptive":
        def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable, **extra_options):
            if len(sigmas) <= 1:
                return noise

            sigma_min = sigmas[-1]
            if sigma_min == 0:
                sigma_min = sigmas[-2]
            return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable, **extra_options)
        sampler_function = dpm_adaptive_function
    else:
        sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))

    return KSAMPLER(sampler_function, extra_options, inpaint_options)

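# Usage sketch: most names resolve to the matching k-diffusion function, e.g.
# ksampler("euler") wraps k_diffusion_sampling.sample_euler, while dpm_fast and
# dpm_adaptive get shims that translate the sigma schedule into the
# (sigma_min, sigma_max, steps) style arguments those samplers expect.
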
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
    for k in conds:
        conds[k] = conds[k][:]
        resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)

    for k in conds:
        calculate_start_end_timesteps(model, conds[k])

    if hasattr(model, 'extra_conds'):
        for k in conds:
            conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)

    #make sure each cond area has an opposite one with the same area
    for k in conds:
        for c in conds[k]:
            for kk in conds:
                if k != kk:
                    create_cond_with_same_area_if_none(conds[kk], c)

    for k in conds:
        pre_run_control(model, conds[k])

    if "positive" in conds:
        positive = conds["positive"]
        for k in conds:
            if k != "positive":
                apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), conds[k], 'control', lambda cond_cnets, x: cond_cnets[x])
                apply_empty_x_to_equal_area(positive, conds[k], 'gligen', lambda cond_cnets, x: cond_cnets[x])

    return conds

class CFGGuider:
    def __init__(self, model_patcher):
        self.model_patcher = model_patcher
        self.model_options = model_patcher.model_options
        self.original_conds = {}
        self.cfg = 1.0

    def set_conds(self, positive, negative):
        self.inner_set_conds({"positive": positive, "negative": negative})

    def set_cfg(self, cfg):
        self.cfg = cfg

    def inner_set_conds(self, conds):
        for k in conds:
            self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k])

    def __call__(self, *args, **kwargs):
        return self.predict_noise(*args, **kwargs)

    def predict_noise(self, x, timestep, model_options={}, seed=None):
        return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)

    def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
        if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
            latent_image = self.inner_model.process_latent_in(latent_image)

        self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)

        extra_args = {"model_options": self.model_options, "seed": seed}

        samples = sampler.sample(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
        return self.inner_model.process_latent_out(samples.to(torch.float32))

    def sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
        if sigmas.shape[-1] == 0:
            return latent_image

        self.conds = {}
        for k in self.original_conds:
            self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))

        self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds)
        device = self.model_patcher.load_device

        if denoise_mask is not None:
            denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device)

        noise = noise.to(device)
        latent_image = latent_image.to(device)
        sigmas = sigmas.to(device)

        output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)

        comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
        del self.inner_model
        del self.conds
        del self.loaded_models
        return output

def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
    cfg_guider = CFGGuider(model)
    cfg_guider.set_conds(positive, negative)
    cfg_guider.set_cfg(cfg)
    return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)

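# Minimal end-to-end sketch, assuming `model` is a loaded ModelPatcher,
# `positive`/`negative` come from a CLIP encode and `latent` is the input latent:
#
#   sigmas = calculate_sigmas(model.get_model_object("model_sampling"), "normal", 20)
#   sampler = sampler_object("euler")
#   noise = torch.randn(latent.shape)
#   out = sample(model, noise, positive, negative, cfg=7.0, device=model.load_device,
#                sampler=sampler, sigmas=sigmas, latent_image=latent, seed=0)
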
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]

def calculate_sigmas(model_sampling, scheduler_name, steps):
    if scheduler_name == "karras":
        sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
    elif scheduler_name == "exponential":
        sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
    elif scheduler_name == "normal":
        sigmas = normal_scheduler(model_sampling, steps)
    elif scheduler_name == "simple":
        sigmas = simple_scheduler(model_sampling, steps)
    elif scheduler_name == "ddim_uniform":
        sigmas = ddim_scheduler(model_sampling, steps)
    elif scheduler_name == "sgm_uniform":
        sigmas = normal_scheduler(model_sampling, steps, sgm=True)
    else:
        logging.error("invalid scheduler {}".format(scheduler_name))
    return sigmas

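# e.g. calculate_sigmas(model_sampling, "karras", 20) returns a descending
# tensor of 21 sigmas ending in 0.0, ready to hand to a sampler.
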
def sampler_object(name):
    if name == "uni_pc":
        sampler = KSAMPLER(uni_pc.sample_unipc)
    elif name == "uni_pc_bh2":
        sampler = KSAMPLER(uni_pc.sample_unipc_bh2)
    elif name == "ddim":
        sampler = ksampler("euler", inpaint_options={"random": True})
    else:
        sampler = ksampler(name)
    return sampler

class KSampler:
    SCHEDULERS = SCHEDULER_NAMES
    SAMPLERS = SAMPLER_NAMES
    DISCARD_PENULTIMATE_SIGMA_SAMPLERS = set(('dpm_2', 'dpm_2_ancestral', 'uni_pc', 'uni_pc_bh2'))

    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
        self.model = model
        self.device = device
        if scheduler not in self.SCHEDULERS:
            scheduler = self.SCHEDULERS[0]
        if sampler not in self.SAMPLERS:
            sampler = self.SAMPLERS[0]
        self.scheduler = scheduler
        self.sampler = sampler
        self.set_steps(steps, denoise)
        self.denoise = denoise
        self.model_options = model_options

    def calculate_sigmas(self, steps):
        discard_penultimate_sigma = False
        if self.sampler in self.DISCARD_PENULTIMATE_SIGMA_SAMPLERS:
            steps += 1
            discard_penultimate_sigma = True

        sigmas = calculate_sigmas(self.model.get_model_object("model_sampling"), self.scheduler, steps)

        if discard_penultimate_sigma:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
        return sigmas

    def set_steps(self, steps, denoise=None):
        self.steps = steps
        if denoise is None or denoise > 0.9999:
            self.sigmas = self.calculate_sigmas(steps).to(self.device)
        else:
            if denoise <= 0.0:
                self.sigmas = torch.FloatTensor([])
            else:
                new_steps = int(steps / denoise)
                sigmas = self.calculate_sigmas(new_steps).to(self.device)
                self.sigmas = sigmas[-(steps + 1):]

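    # Partial-denoise example: with steps=20 and denoise=0.5 the schedule is
    # computed for new_steps=40 and only the last 21 sigmas are kept, so
    # sampling starts from mid-level noise instead of sigma_max.
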
    def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=None, last_step=None, force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
        if sigmas is None:
            sigmas = self.sigmas

        if last_step is not None and last_step < (len(sigmas) - 1):
            sigmas = sigmas[:last_step + 1]
            if force_full_denoise:
                sigmas[-1] = 0

        if start_step is not None:
            if start_step < (len(sigmas) - 1):
                sigmas = sigmas[start_step:]
            else:
                if latent_image is not None:
                    return latent_image
                else:
                    return torch.zeros_like(noise)

        sampler = sampler_object(self.sampler)

        return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)