Commit 15e89ef0 authored by AUTOMATIC's avatar AUTOMATIC

fix for unet hijack breaking the train tab

parent 789d47f8
...@@ -36,8 +36,11 @@ th = TorchHijackForUnet() ...@@ -36,8 +36,11 @@ th = TorchHijackForUnet()
# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
for y in cond.keys():
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] if isinstance(cond, dict):
for y in cond.keys():
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
with devices.autocast(): with devices.autocast():
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment