Commit aea5b251 authored by AUTOMATIC's avatar AUTOMATIC

save parameters for images when using the Save button.

parent 5eb9d1ae
...@@ -100,7 +100,7 @@ class StableDiffusionProcessing: ...@@ -100,7 +100,7 @@ class StableDiffusionProcessing:
class Processed: class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0): def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
self.images = images_list self.images = images_list
self.prompt = p.prompt self.prompt = p.prompt
self.negative_prompt = p.negative_prompt self.negative_prompt = p.negative_prompt
...@@ -139,6 +139,7 @@ class Processed: ...@@ -139,6 +139,7 @@ class Processed:
self.all_prompts = all_prompts or [self.prompt] self.all_prompts = all_prompts or [self.prompt]
self.all_seeds = all_seeds or [self.seed] self.all_seeds = all_seeds or [self.seed]
self.all_subseeds = all_subseeds or [self.subseed] self.all_subseeds = all_subseeds or [self.subseed]
self.infotexts = infotexts or [info]
def js(self): def js(self):
obj = { obj = {
...@@ -165,6 +166,7 @@ class Processed: ...@@ -165,6 +166,7 @@ class Processed:
"denoising_strength": self.denoising_strength, "denoising_strength": self.denoising_strength,
"extra_generation_params": self.extra_generation_params, "extra_generation_params": self.extra_generation_params,
"index_of_first_image": self.index_of_first_image, "index_of_first_image": self.index_of_first_image,
"infotexts": self.infotexts,
} }
return json.dumps(obj) return json.dumps(obj)
...@@ -322,6 +324,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -322,6 +324,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if os.path.exists(cmd_opts.embeddings_dir): if os.path.exists(cmd_opts.embeddings_dir):
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model) model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
infotexts = []
output_images = [] output_images = []
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope) ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
...@@ -404,6 +407,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -404,6 +407,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if opts.samples_save and not p.do_not_save_samples: if opts.samples_save and not p.do_not_save_samples:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p) images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
infotexts.append(infotext(n, i))
output_images.append(image) output_images.append(image)
state.nextjob() state.nextjob()
...@@ -416,6 +420,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -416,6 +420,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
grid = images.image_grid(output_images, p.batch_size) grid = images.image_grid(output_images, p.batch_size)
if opts.return_grid: if opts.return_grid:
infotexts.insert(0, infotext())
output_images.insert(0, grid) output_images.insert(0, grid)
index_of_first_image = 1 index_of_first_image = 1
...@@ -423,7 +428,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -423,7 +428,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc() devices.torch_gc()
return Processed(p, output_images, all_seeds[0], infotext(), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image) return Processed(p, output_images, all_seeds[0], infotext(), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
......
...@@ -143,6 +143,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" ...@@ -143,6 +143,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"), "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"), "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
})) }))
options_templates.update(options_section(('saving-paths', "Paths for saving"), { options_templates.update(options_section(('saving-paths', "Paths for saving"), {
...@@ -180,7 +181,6 @@ options_templates.update(options_section(('face-restoration', "Face restoration" ...@@ -180,7 +181,6 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}), "face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"), "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
"save_selected_only": OptionInfo(False, "When using 'Save' button, only save a single selected image"),
})) }))
options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('system', "System"), {
......
...@@ -12,7 +12,7 @@ import traceback ...@@ -12,7 +12,7 @@ import traceback
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image, PngImagePlugin
import gradio as gr import gradio as gr
import gradio.utils import gradio.utils
...@@ -97,10 +97,11 @@ def save_files(js_data, images, index): ...@@ -97,10 +97,11 @@ def save_files(js_data, images, index):
filenames = [] filenames = []
data = json.loads(js_data) data = json.loads(js_data)
if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
if index > -1 and opts.save_selected_only and (index > 0 or not opts.return_grid): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
images = [images[index]] images = [images[index]]
data["seed"] += (index - 1 if opts.return_grid else index) infotexts = [data["infotexts"][index]]
else:
infotexts = data["infotexts"]
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
at_start = file.tell() == 0 at_start = file.tell() == 0
...@@ -116,8 +117,11 @@ def save_files(js_data, images, index): ...@@ -116,8 +117,11 @@ def save_files(js_data, images, index):
if filedata.startswith("data:image/png;base64,"): if filedata.startswith("data:image/png;base64,"):
filedata = filedata[len("data:image/png;base64,"):] filedata = filedata[len("data:image/png;base64,"):]
with open(filepath, "wb") as imgfile: pnginfo = PngImagePlugin.PngInfo()
imgfile.write(base64.decodebytes(filedata.encode('utf-8'))) pnginfo.add_text('parameters', infotexts[i])
image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
image.save(filepath, quality=opts.jpeg_quality, pnginfo=pnginfo)
filenames.append(filename) filenames.append(filename)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment