Commit e0e80050 authored by AUTOMATIC's avatar AUTOMATIC

make StableDiffusionProcessing class not hold a reference to shared.sd_model object

parent 9991967f
...@@ -94,7 +94,7 @@ def txt2img_image_conditioning(sd_model, x, width, height): ...@@ -94,7 +94,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
return image_conditioning return image_conditioning
class StableDiffusionProcessing(): class StableDiffusionProcessing:
""" """
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
""" """
...@@ -102,7 +102,6 @@ class StableDiffusionProcessing(): ...@@ -102,7 +102,6 @@ class StableDiffusionProcessing():
if sampler_index is not None: if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids self.outpath_grids: str = outpath_grids
self.prompt: str = prompt self.prompt: str = prompt
...@@ -156,6 +155,10 @@ class StableDiffusionProcessing(): ...@@ -156,6 +155,10 @@ class StableDiffusionProcessing():
self.all_subseeds = None self.all_subseeds = None
self.iteration = 0 self.iteration = 0
@property
def sd_model(self):
return shared.sd_model
def txt2img_image_conditioning(self, x, width=None, height=None): def txt2img_image_conditioning(self, x, width=None, height=None):
self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'} self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
...@@ -236,7 +239,6 @@ class StableDiffusionProcessing(): ...@@ -236,7 +239,6 @@ class StableDiffusionProcessing():
raise NotImplementedError() raise NotImplementedError()
def close(self): def close(self):
self.sd_model = None
self.sampler = None self.sampler = None
...@@ -471,7 +473,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -471,7 +473,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_model_checkpoint': if k == 'sd_model_checkpoint':
sd_models.reload_model_weights() # make onchange call for changing SD model sd_models.reload_model_weights() # make onchange call for changing SD model
p.sd_model = shared.sd_model
if k == 'sd_vae': if k == 'sd_vae':
sd_vae.reload_vae_weights() # make onchange call for changing VAE sd_vae.reload_vae_weights() # make onchange call for changing VAE
......
...@@ -86,7 +86,6 @@ def apply_checkpoint(p, x, xs): ...@@ -86,7 +86,6 @@ def apply_checkpoint(p, x, xs):
if info is None: if info is None:
raise RuntimeError(f"Unknown checkpoint: {x}") raise RuntimeError(f"Unknown checkpoint: {x}")
modules.sd_models.reload_model_weights(shared.sd_model, info) modules.sd_models.reload_model_weights(shared.sd_model, info)
p.sd_model = shared.sd_model
def confirm_checkpoints(p, xs): def confirm_checkpoints(p, xs):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment