Support SDXS 0.9
This commit is contained in:
parent
8ae1e4d125
commit
327ca1313d
|
@ -345,7 +345,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||||
'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
|
||||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||||
|
|
||||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B]
|
SD09_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||||
|
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
|
||||||
|
'transformer_depth': [1, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True,
|
||||||
|
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
|
||||||
|
'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
|
||||||
|
|
||||||
|
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS]
|
||||||
|
|
||||||
for unet_config in supported_models:
|
for unet_config in supported_models:
|
||||||
matches = True
|
matches = True
|
||||||
|
|
|
@ -70,8 +70,8 @@ class SD20(supported_models_base.BASE):
|
||||||
def model_type(self, state_dict, prefix=""):
|
def model_type(self, state_dict, prefix=""):
|
||||||
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
|
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
|
||||||
k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
|
k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
|
||||||
out = state_dict[k]
|
out = state_dict.get(k, None)
|
||||||
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
|
if out is not None and torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
|
||||||
return model_base.ModelType.V_PREDICTION
|
return model_base.ModelType.V_PREDICTION
|
||||||
return model_base.ModelType.EPS
|
return model_base.ModelType.EPS
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue