diff --git a/nodes.py b/nodes.py index 9c541eba..a230f725 100644 --- a/nodes.py +++ b/nodes.py @@ -232,8 +232,9 @@ class ConditioningZeroOut: c = [] for t in conditioning: d = t[1].copy() - if "pooled_output" in d: - d["pooled_output"] = torch.zeros_like(d["pooled_output"]) + pooled_output = d.get("pooled_output", None) + if pooled_output is not None: + d["pooled_output"] = torch.zeros_like(pooled_output) n = [torch.zeros_like(t[0]), d] c.append(n) return (c, )