Add emb_patch support to UNetModel forward (#4779)

This commit is contained in:
Jedrzej Kosinski 2024-09-04 13:35:15 -05:00 committed by GitHub
parent f067ad15d1
commit f04229b84d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 0 deletions

View File

@ -842,6 +842,11 @@ class UNetModel(nn.Module):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb)
if "emb_patch" in transformer_patches:
patch = transformer_patches["emb_patch"]
for p in patch:
emb = p(emb, self.model_channels, transformer_options)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)