Unverified Commit 01f2ed68 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #5065 from JaySmithWpg/vram-leak

#3449 - VRAM leak when switching to/from inpainting checkpoint
parents 151e2cc6 c833d5bf
from collections import namedtuple from collections import namedtuple, deque
import numpy as np import numpy as np
from math import floor from math import floor
import torch import torch
...@@ -344,18 +344,28 @@ class CFGDenoiser(torch.nn.Module): ...@@ -344,18 +344,28 @@ class CFGDenoiser(torch.nn.Module):
class TorchHijack: class TorchHijack:
def __init__(self, kdiff_sampler): def __init__(self, sampler_noises):
self.kdiff_sampler = kdiff_sampler # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
# implementation.
self.sampler_noises = deque(sampler_noises)
def __getattr__(self, item): def __getattr__(self, item):
if item == 'randn_like': if item == 'randn_like':
return self.kdiff_sampler.randn_like return self.randn_like
if hasattr(torch, item): if hasattr(torch, item):
return getattr(torch, item) return getattr(torch, item)
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
def randn_like(self, x):
if self.sampler_noises:
noise = self.sampler_noises.popleft()
if noise.shape == x.shape:
return noise
return torch.randn_like(x)
class KDiffusionSampler: class KDiffusionSampler:
def __init__(self, funcname, sd_model): def __init__(self, funcname, sd_model):
...@@ -367,7 +377,6 @@ class KDiffusionSampler: ...@@ -367,7 +377,6 @@ class KDiffusionSampler:
self.extra_params = sampler_extra_params.get(funcname, []) self.extra_params = sampler_extra_params.get(funcname, [])
self.model_wrap_cfg = CFGDenoiser(self.model_wrap) self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
self.sampler_noises = None self.sampler_noises = None
self.sampler_noise_index = 0
self.stop_at = None self.stop_at = None
self.eta = None self.eta = None
self.default_eta = 1.0 self.default_eta = 1.0
...@@ -400,26 +409,14 @@ class KDiffusionSampler: ...@@ -400,26 +409,14 @@ class KDiffusionSampler:
def number_of_needed_noises(self, p): def number_of_needed_noises(self, p):
return p.steps return p.steps
def randn_like(self, x):
noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None
if noise is not None and x.shape == noise.shape:
res = noise
else:
res = torch.randn_like(x)
self.sampler_noise_index += 1
return res
def initialize(self, p): def initialize(self, p):
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap.step = 0 self.model_wrap.step = 0
self.sampler_noise_index = 0
self.eta = p.eta or opts.eta_ancestral self.eta = p.eta or opts.eta_ancestral
if self.sampler_noises is not None: if self.sampler_noises is not None:
k_diffusion.sampling.torch = TorchHijack(self) k_diffusion.sampling.torch = TorchHijack(self.sampler_noises)
extra_params_kwargs = {} extra_params_kwargs = {}
for param_name in self.extra_params: for param_name in self.extra_params:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment