Commit 5d8c59ee authored by yfszzx's avatar yfszzx
parents 763b893f d41ac174
# See https://github.com/actions/starter-workflows/blob/1067f16ad8a1eac328834e4b0ae24f7d206f810d/ci/pylint.yml for original reference file
name: Run Linting/Formatting on Pull Requests
on:
- push
- pull_request
# See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#onpull_requestpull_request_targetbranchesbranches-ignore for syntax docs
# if you want to filter out branches, delete the `- pull_request` and uncomment these lines :
# pull_request:
# branches:
# - master
# branches-ignore:
# - development
jobs:
lint:
runs-on: ubuntu-latest
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: 3.10.6
- name: Install PyLint
run: |
python -m pip install --upgrade pip
pip install pylint
# This lets PyLint check to see if it can resolve imports
- name: Install dependencies
run : |
export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
python launch.py
- name: Analysing the code with pylint
run: |
pylint $(git ls-files '*.py')
# See https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html
[MESSAGES CONTROL]
disable=C,R,W,E,I
...@@ -43,7 +43,7 @@ function dropReplaceImage( imgWrap, files ) { ...@@ -43,7 +43,7 @@ function dropReplaceImage( imgWrap, files ) {
window.document.addEventListener('dragover', e => { window.document.addEventListener('dragover', e => {
const target = e.composedPath()[0]; const target = e.composedPath()[0];
const imgWrap = target.closest('[data-testid="image"]'); const imgWrap = target.closest('[data-testid="image"]');
if ( !imgWrap && target.placeholder != "Prompt") { if ( !imgWrap && target.placeholder.indexOf("Prompt") == -1) {
return; return;
} }
e.stopPropagation(); e.stopPropagation();
...@@ -53,7 +53,7 @@ window.document.addEventListener('dragover', e => { ...@@ -53,7 +53,7 @@ window.document.addEventListener('dragover', e => {
window.document.addEventListener('drop', e => { window.document.addEventListener('drop', e => {
const target = e.composedPath()[0]; const target = e.composedPath()[0];
if (target.placeholder === "Prompt") { if (target.placeholder.indexOf("Prompt") == -1) {
return; return;
} }
const imgWrap = target.closest('[data-testid="image"]'); const imgWrap = target.closest('[data-testid="image"]');
......
...@@ -2,6 +2,8 @@ addEventListener('keydown', (event) => { ...@@ -2,6 +2,8 @@ addEventListener('keydown', (event) => {
let target = event.originalTarget || event.composedPath()[0]; let target = event.originalTarget || event.composedPath()[0];
if (!target.hasAttribute("placeholder")) return; if (!target.hasAttribute("placeholder")) return;
if (!target.placeholder.toLowerCase().includes("prompt")) return; if (!target.placeholder.toLowerCase().includes("prompt")) return;
if (! (event.metaKey || event.ctrlKey)) return;
let plus = "ArrowUp" let plus = "ArrowUp"
let minus = "ArrowDown" let minus = "ArrowDown"
......
...@@ -16,6 +16,8 @@ titles = { ...@@ -16,6 +16,8 @@ titles = {
"\u{1f3a8}": "Add a random artist to the prompt.", "\u{1f3a8}": "Add a random artist to the prompt.",
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.", "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
"\u{1f4c2}": "Open images output directory", "\u{1f4c2}": "Open images output directory",
"\u{1f4be}": "Save style",
"\u{1f4cb}": "Apply selected styles to current prompt",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
......
...@@ -2,7 +2,7 @@ window.onload = (function(){ ...@@ -2,7 +2,7 @@ window.onload = (function(){
window.addEventListener('drop', e => { window.addEventListener('drop', e => {
const target = e.composedPath()[0]; const target = e.composedPath()[0];
const idx = selected_gallery_index(); const idx = selected_gallery_index();
if (target.placeholder != "Prompt") return; if (target.placeholder.indexOf("Prompt") == -1) return;
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image"; let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
......
// code related to showing and updating progressbar shown as the image is being made // code related to showing and updating progressbar shown as the image is being made
global_progressbars = {} global_progressbars = {}
galleries = {}
galleryObservers = {}
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
var progressbar = gradioApp().getElementById(id_progressbar) var progressbar = gradioApp().getElementById(id_progressbar)
...@@ -31,13 +33,24 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip ...@@ -31,13 +33,24 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
preview.style.width = gallery.clientWidth + "px" preview.style.width = gallery.clientWidth + "px"
preview.style.height = gallery.clientHeight + "px" preview.style.height = gallery.clientHeight + "px"
//only watch gallery if there is a generation process going on
check_gallery(id_gallery);
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
if(!progressDiv){ if(!progressDiv){
if (skip) { if (skip) {
skip.style.display = "none" skip.style.display = "none"
} }
interrupt.style.display = "none" interrupt.style.display = "none"
//disconnect observer once generation finished, so user can close selected image if they want
if (galleryObservers[id_gallery]) {
galleryObservers[id_gallery].disconnect();
galleries[id_gallery] = null;
}
} }
} }
window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500) window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
...@@ -46,6 +59,28 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip ...@@ -46,6 +59,28 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
} }
} }
function check_gallery(id_gallery){
let gallery = gradioApp().getElementById(id_gallery)
// if gallery has no change, no need to setting up observer again.
if (gallery && galleries[id_gallery] !== gallery){
galleries[id_gallery] = gallery;
if(galleryObservers[id_gallery]){
galleryObservers[id_gallery].disconnect();
}
let prevSelectedIndex = selected_gallery_index();
galleryObservers[id_gallery] = new MutationObserver(function (){
let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item')
let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2')
if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) {
//automatically re-open previously selected index (if exists)
galleryButtons[prevSelectedIndex].click();
showGalleryImage();
}
})
galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false })
}
}
onUiUpdate(function(){ onUiUpdate(function(){
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
......
...@@ -141,7 +141,7 @@ function submit_img2img(){ ...@@ -141,7 +141,7 @@ function submit_img2img(){
function ask_for_style_name(_, prompt_text, negative_prompt_text) { function ask_for_style_name(_, prompt_text, negative_prompt_text) {
name_ = prompt('Style name:') name_ = prompt('Style name:')
return name_ === null ? [null, null, null]: [name_, prompt_text, negative_prompt_text] return [name_, prompt_text, negative_prompt_text]
} }
...@@ -187,12 +187,10 @@ onUiUpdate(function(){ ...@@ -187,12 +187,10 @@ onUiUpdate(function(){
if (!txt2img_textarea) { if (!txt2img_textarea) {
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea"); txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button")); txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
txt2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "txt2img_generate"));
} }
if (!img2img_textarea) { if (!img2img_textarea) {
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea"); img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button")); img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
img2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "img2img_generate"));
} }
}) })
...@@ -220,14 +218,6 @@ function update_token_counter(button_id) { ...@@ -220,14 +218,6 @@ function update_token_counter(button_id) {
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
} }
function submit_prompt(event, generate_button_id) {
if (event.altKey && event.keyCode === 13) {
event.preventDefault();
gradioApp().getElementById(generate_button_id).click();
return;
}
}
function restart_reload(){ function restart_reload(){
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>'; document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
setTimeout(function(){location.reload()},2000) setTimeout(function(){location.reload()},2000)
......
...@@ -9,6 +9,7 @@ import platform ...@@ -9,6 +9,7 @@ import platform
dir_repos = "repositories" dir_repos = "repositories"
python = sys.executable python = sys.executable
git = os.environ.get('GIT', "git") git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
def extract_arg(args, name): def extract_arg(args, name):
...@@ -57,7 +58,8 @@ def run_python(code, desc=None, errdesc=None): ...@@ -57,7 +58,8 @@ def run_python(code, desc=None, errdesc=None):
def run_pip(args, desc=None): def run_pip(args, desc=None):
return run(f'"{python}" -m pip {args} --prefer-binary', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}") index_url_line = f' --index-url {index_url}' if index_url != '' else ''
return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
def check_run_python(code): def check_run_python(code):
...@@ -102,6 +104,7 @@ def prepare_enviroment(): ...@@ -102,6 +104,7 @@ def prepare_enviroment():
args = shlex.split(commandline_args) args = shlex.split(commandline_args)
args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test') args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
args, reinstall_xformers = extract_arg(args, '--reinstall-xformers')
xformers = '--xformers' in args xformers = '--xformers' in args
deepdanbooru = '--deepdanbooru' in args deepdanbooru = '--deepdanbooru' in args
ngrok = '--ngrok' in args ngrok = '--ngrok' in args
...@@ -126,9 +129,9 @@ def prepare_enviroment(): ...@@ -126,9 +129,9 @@ def prepare_enviroment():
if not is_installed("clip"): if not is_installed("clip"):
run_pip(f"install {clip_package}", "clip") run_pip(f"install {clip_package}", "clip")
if not is_installed("xformers") and xformers and platform.python_version().startswith("3.10"): if (not is_installed("xformers") or reinstall_xformers) and xformers and platform.python_version().startswith("3.10"):
if platform.system() == "Windows": if platform.system() == "Windows":
run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/c/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers") run_pip("install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers")
elif platform.system() == "Linux": elif platform.system() == "Linux":
run_pip("install xformers", "xformers") run_pip("install xformers", "xformers")
......
...@@ -102,7 +102,7 @@ def get_deepbooru_tags_model(): ...@@ -102,7 +102,7 @@ def get_deepbooru_tags_model():
tags = dd.project.load_tags_from_project(model_path) tags = dd.project.load_tags_from_project(model_path)
model = dd.project.load_model_from_project( model = dd.project.load_model_from_project(
model_path, compile_model=True model_path, compile_model=False
) )
return model, tags return model, tags
......
...@@ -182,7 +182,21 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): ...@@ -182,7 +182,21 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
return self.to_out(out) return self.to_out(out)
def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): def stack_conds(conds):
if len(conds) == 1:
return torch.stack(conds)
# same as in reconstruct_multicond_batch
token_count = max([x.shape[0] for x in conds])
for i in range(len(conds)):
if conds[i].shape[0] != token_count:
last_vector = conds[i][-1:]
last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
conds[i] = torch.vstack([conds[i], last_vector_repeated])
return torch.stack(conds)
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert hypernetwork_name, 'hypernetwork not selected' assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None) path = shared.hypernetworks.get(hypernetwork_name, None)
...@@ -211,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, ...@@ -211,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True) ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
if unload: if unload:
shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.cond_stage_model.to(devices.cpu)
...@@ -235,7 +249,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, ...@@ -235,7 +249,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entry in pbar: for i, entries in pbar:
hypernetwork.step = i + ititial_step hypernetwork.step = i + ititial_step
scheduler.apply(optimizer, hypernetwork.step) scheduler.apply(optimizer, hypernetwork.step)
...@@ -246,26 +260,29 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, ...@@ -246,26 +260,29 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
break break
with torch.autocast("cuda"): with torch.autocast("cuda"):
cond = entry.cond.to(devices.device) c = stack_conds([entry.cond for entry in entries]).to(devices.device)
x = entry.latent.to(devices.device) # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), cond)[0] x = torch.stack([entry.latent for entry in entries]).to(devices.device)
loss = shared.sd_model(x, c)[0]
del x del x
del cond del c
losses[hypernetwork.step % losses.shape[0]] = loss.item() losses[hypernetwork.step % losses.shape[0]] = loss.item()
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
mean_loss = losses.mean()
pbar.set_description(f"loss: {losses.mean():.7f}") if torch.isnan(mean_loss):
raise RuntimeError("Loss diverged.")
pbar.set_description(f"loss: {mean_loss:.7f}")
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
hypernetwork.save(last_saved_file) hypernetwork.save(last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{losses.mean():.7f}", "loss": f"{mean_loss:.7f}",
"learn_rate": scheduler.learn_rate "learn_rate": scheduler.learn_rate
}) })
...@@ -292,7 +309,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, ...@@ -292,7 +309,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
p.width = preview_width p.width = preview_width
p.height = preview_height p.height = preview_height
else: else:
p.prompt = entry.cond_text p.prompt = entries[0].cond_text
p.steps = 20 p.steps = 20
preview_text = p.prompt preview_text = p.prompt
...@@ -313,9 +330,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, ...@@ -313,9 +330,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
shared.state.textinfo = f""" shared.state.textinfo = f"""
<p> <p>
Loss: {losses.mean():.7f}<br/> Loss: {mean_loss:.7f}<br/>
Step: {hypernetwork.step}<br/> Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entry.cond_text)}<br/> Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/> Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
......
...@@ -197,14 +197,16 @@ def delete_image(delete_num, tabname, name, page_index, filenames, image_index): ...@@ -197,14 +197,16 @@ def delete_image(delete_num, tabname, name, page_index, filenames, image_index):
return new_file_list, 1 return new_file_list, 1
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
if tabname == "txt2img": if opts.outdir_samples != "":
dir_name = opts.outdir_samples
elif tabname == "txt2img":
dir_name = opts.outdir_txt2img_samples dir_name = opts.outdir_txt2img_samples
elif tabname == "img2img": elif tabname == "img2img":
dir_name = opts.outdir_img2img_samples dir_name = opts.outdir_img2img_samples
elif tabname == "extras": elif tabname == "extras":
dir_name = opts.outdir_extras_samples dir_name = opts.outdir_extras_samples
d = dir_name.split("/") d = dir_name.split("/")
dir_name = d[0] dir_name = "/" if dir_name.startswith("/") else d[0]
for p in d[1:]: for p in d[1:]:
dir_name = os.path.join(dir_name, p) dir_name = os.path.join(dir_name, p)
......
...@@ -140,7 +140,7 @@ class Processed: ...@@ -140,7 +140,7 @@ class Processed:
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0] self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.all_prompts = all_prompts or [self.prompt] self.all_prompts = all_prompts or [self.prompt]
...@@ -528,7 +528,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -528,7 +528,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
firstphase_height_truncated = int(scale * self.height) firstphase_height_truncated = int(scale * self.height)
else: else:
self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
width_ratio = self.width / self.firstphase_width width_ratio = self.width / self.firstphase_width
height_ratio = self.height / self.firstphase_height height_ratio = self.height / self.firstphase_height
...@@ -540,6 +539,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -540,6 +539,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
firstphase_width_truncated = self.firstphase_height * self.width / self.height firstphase_width_truncated = self.firstphase_height * self.width / self.height
firstphase_height_truncated = self.firstphase_height firstphase_height_truncated = self.firstphase_height
self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
...@@ -557,11 +557,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -557,11 +557,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
decoded_samples = decode_first_stage(self.sd_model, samples) if opts.use_scale_latent_for_hires_fix:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")
else: else:
decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
batch_images = [] batch_images = []
...@@ -578,7 +578,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -578,7 +578,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
decoded_samples = decoded_samples.to(shared.device) decoded_samples = decoded_samples.to(shared.device)
decoded_samples = 2. * decoded_samples - 1. decoded_samples = 2. * decoded_samples - 1.
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
shared.state.nextjob() shared.state.nextjob()
......
...@@ -24,7 +24,7 @@ def apply_optimizations(): ...@@ -24,7 +24,7 @@ def apply_optimizations():
ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.model.nonlinearity = silu
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)): if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.") print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
......
import glob import collections
import os.path import os.path
import sys import sys
from collections import namedtuple from collections import namedtuple
...@@ -15,6 +15,7 @@ model_path = os.path.abspath(os.path.join(models_path, model_dir)) ...@@ -15,6 +15,7 @@ model_path = os.path.abspath(os.path.join(models_path, model_dir))
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
checkpoints_list = {} checkpoints_list = {}
checkpoints_loaded = collections.OrderedDict()
try: try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
...@@ -132,41 +133,45 @@ def load_model_weights(model, checkpoint_info): ...@@ -132,41 +133,45 @@ def load_model_weights(model, checkpoint_info):
checkpoint_file = checkpoint_info.filename checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash sd_model_hash = checkpoint_info.hash
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") if checkpoint_info not in checkpoints_loaded:
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
if "global_step" in pl_sd: sd = get_state_dict_from_checkpoint(pl_sd)
print(f"Global Step: {pl_sd['global_step']}") model.load_state_dict(sd, strict=False)
sd = get_state_dict_from_checkpoint(pl_sd) if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
model.load_state_dict(sd, strict=False) if not shared.cmd_opts.no_half:
model.half()
if shared.cmd_opts.opt_channelslast: devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
model.to(memory_format=torch.channels_last) devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
if not shared.cmd_opts.no_half: vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
model.half()
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 vae_file = shared.cmd_opts.vae_path
vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" if os.path.exists(vae_file):
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
model.first_stage_model.load_state_dict(vae_dict)
if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: model.first_stage_model.to(devices.dtype_vae)
vae_file = shared.cmd_opts.vae_path
if os.path.exists(vae_file): checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
print(f"Loading VAE weights from: {vae_file}") while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False) # LRU
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) else:
print(f"Loading weights [{sd_model_hash}] from cache")
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} checkpoints_loaded.move_to_end(checkpoint_info)
model.load_state_dict(checkpoints_loaded[checkpoint_info])
model.first_stage_model.load_state_dict(vae_dict)
model.first_stage_model.to(devices.dtype_vae)
model.sd_model_hash = sd_model_hash model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file model.sd_model_checkpoint = checkpoint_file
...@@ -205,6 +210,7 @@ def reload_model_weights(sd_model, info=None): ...@@ -205,6 +210,7 @@ def reload_model_weights(sd_model, info=None):
return return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config: if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
checkpoints_loaded.clear()
shared.sd_model = load_model() shared.sd_model = load_model()
return shared.sd_model return shared.sd_model
......
...@@ -218,6 +218,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { ...@@ -218,6 +218,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
"use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
})) }))
options_templates.update(options_section(('face-restoration', "Face restoration"), { options_templates.update(options_section(('face-restoration', "Face restoration"), {
...@@ -242,6 +243,7 @@ options_templates.update(options_section(('training', "Training"), { ...@@ -242,6 +243,7 @@ options_templates.update(options_section(('training', "Training"), {
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
...@@ -255,7 +257,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { ...@@ -255,7 +257,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"filter_nsfw": OptionInfo(False, "Filter NSFW content"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
})) }))
options_templates.update(options_section(('interrogate', "Interrogate Options"), { options_templates.update(options_section(('interrogate', "Interrogate Options"), {
...@@ -283,6 +284,7 @@ options_templates.update(options_section(('ui', "User interface"), { ...@@ -283,6 +284,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
})) }))
options_templates.update(options_section(('sampler-params', "Sampler parameters"), { options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
......
...@@ -24,11 +24,12 @@ class DatasetEntry: ...@@ -24,11 +24,12 @@ class DatasetEntry:
class PersonalizedBase(Dataset): class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token self.placeholder_token = placeholder_token
self.batch_size = batch_size
self.width = width self.width = width
self.height = height self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.flip = transforms.RandomHorizontalFlip(p=flip_p)
...@@ -78,13 +79,14 @@ class PersonalizedBase(Dataset): ...@@ -78,13 +79,14 @@ class PersonalizedBase(Dataset):
if include_cond: if include_cond:
entry.cond_text = self.create_text(filename_text) entry.cond_text = self.create_text(filename_text)
entry.cond = cond_model([entry.cond_text]).to(devices.cpu) entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
self.dataset.append(entry) self.dataset.append(entry)
self.length = len(self.dataset) * repeats assert len(self.dataset) > 1, "No images have been found in the dataset."
self.length = len(self.dataset) * repeats // batch_size
self.initial_indexes = np.arange(self.length) % len(self.dataset) self.initial_indexes = np.arange(len(self.dataset))
self.indexes = None self.indexes = None
self.shuffle() self.shuffle()
...@@ -101,13 +103,19 @@ class PersonalizedBase(Dataset): ...@@ -101,13 +103,19 @@ class PersonalizedBase(Dataset):
return self.length return self.length
def __getitem__(self, i): def __getitem__(self, i):
if i % len(self.dataset) == 0: res = []
self.shuffle()
index = self.indexes[i % len(self.indexes)] for j in range(self.batch_size):
entry = self.dataset[index] position = i * self.batch_size + j
if position % len(self.indexes) == 0:
self.shuffle()
if entry.cond is None: index = self.indexes[position % len(self.indexes)]
entry.cond_text = self.create_text(entry.filename_text) entry = self.dataset[index]
return entry if entry.cond is None:
entry.cond_text = self.create_text(entry.filename_text)
res.append(entry)
return res
...@@ -88,9 +88,9 @@ class EmbeddingDatabase: ...@@ -88,9 +88,9 @@ class EmbeddingDatabase:
data = [] data = []
if filename.upper().endswith('.PNG'): if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path) embed_image = Image.open(path)
if 'sd-ti-embedding' in embed_image.text: if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding']) data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
name = data.get('name', name) name = data.get('name', name)
else: else:
...@@ -199,7 +199,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): ...@@ -199,7 +199,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
}) })
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected' assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..." shared.state.textinfo = "Initializing textual inversion training..."
...@@ -231,7 +231,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini ...@@ -231,7 +231,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
hijack = sd_hijack.model_hijack hijack = sd_hijack.model_hijack
...@@ -242,6 +242,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini ...@@ -242,6 +242,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
last_saved_file = "<none>" last_saved_file = "<none>"
last_saved_image = "<none>" last_saved_image = "<none>"
embedding_yet_to_be_embedded = False
ititial_step = embedding.step or 0 ititial_step = embedding.step or 0
if ititial_step > steps: if ititial_step > steps:
...@@ -251,7 +252,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini ...@@ -251,7 +252,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, entry in pbar: for i, entries in pbar:
embedding.step = i + ititial_step embedding.step = i + ititial_step
scheduler.apply(optimizer, embedding.step) scheduler.apply(optimizer, embedding.step)
...@@ -262,10 +263,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini ...@@ -262,10 +263,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
break break
with torch.autocast("cuda"): with torch.autocast("cuda"):
c = cond_model([entry.cond_text]) c = cond_model([entry.cond_text for entry in entries])
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
x = entry.latent.to(devices.device) loss = shared.sd_model(x, c)[0]
loss = shared.sd_model(x.unsqueeze(0), c)[0]
del x del x
losses[embedding.step % losses.shape[0]] = loss.item() losses[embedding.step % losses.shape[0]] = loss.item()
...@@ -282,6 +282,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini ...@@ -282,6 +282,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file) embedding.save(last_saved_file)
embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
"loss": f"{losses.mean():.7f}", "loss": f"{losses.mean():.7f}",
...@@ -307,7 +308,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini ...@@ -307,7 +308,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
p.width = preview_width p.width = preview_width
p.height = preview_height p.height = preview_height
else: else:
p.prompt = entry.cond_text p.prompt = entries[0].cond_text
p.steps = 20 p.steps = 20
p.width = training_width p.width = training_width
p.height = training_height p.height = training_height
...@@ -319,7 +320,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini ...@@ -319,7 +320,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.current_image = image shared.state.current_image = image
if save_image_with_stored_embedding and os.path.exists(last_saved_file): if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png') last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
...@@ -328,15 +329,22 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini ...@@ -328,15 +329,22 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
info.add_text("sd-ti-embedding", embedding_to_b64(data)) info.add_text("sd-ti-embedding", embedding_to_b64(data))
title = "<{}>".format(data.get('name', '???')) title = "<{}>".format(data.get('name', '???'))
try:
vectorSize = list(data['string_to_param'].values())[0].shape[0]
except Exception as e:
vectorSize = '?'
checkpoint = sd_models.select_checkpoint() checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash) footer_mid = '[{}]'.format(checkpoint.hash)
footer_right = '{}'.format(embedding.step) footer_right = '{}v {}s'.format(vectorSize, embedding.step)
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data) captioned_image = insert_image_data_embed(captioned_image, data)
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
embedding_yet_to_be_embedded = False
image.save(last_saved_image) image.save(last_saved_image)
...@@ -348,7 +356,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini ...@@ -348,7 +356,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
<p> <p>
Loss: {losses.mean():.7f}<br/> Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/> Step: {embedding.step}<br/>
Last prompt: {html.escape(entry.cond_text)}<br/> Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/> Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
......
This diff is collapsed.
...@@ -2,7 +2,7 @@ transformers==4.19.2 ...@@ -2,7 +2,7 @@ transformers==4.19.2
diffusers==0.3.0 diffusers==0.3.0
basicsr==1.4.2 basicsr==1.4.2
gfpgan==1.3.8 gfpgan==1.3.8
gradio==3.4.1 gradio==3.5
numpy==1.23.3 numpy==1.23.3
Pillow==9.2.0 Pillow==9.2.0
realesrgan==0.3.0 realesrgan==0.3.0
......
...@@ -50,9 +50,9 @@ document.addEventListener("DOMContentLoaded", function() { ...@@ -50,9 +50,9 @@ document.addEventListener("DOMContentLoaded", function() {
document.addEventListener('keydown', function(e) { document.addEventListener('keydown', function(e) {
var handled = false; var handled = false;
if (e.key !== undefined) { if (e.key !== undefined) {
if((e.key == "Enter" && (e.metaKey || e.ctrlKey))) handled = true; if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
} else if (e.keyCode !== undefined) { } else if (e.keyCode !== undefined) {
if((e.keyCode == 13 && (e.metaKey || e.ctrlKey))) handled = true; if((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
} }
if (handled) { if (handled) {
button = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
......
import copy
import math import math
import os import os
import sys import sys
import traceback import traceback
import shlex
import modules.scripts as scripts import modules.scripts as scripts
import gradio as gr import gradio as gr
...@@ -10,6 +12,75 @@ from modules.processing import Processed, process_images ...@@ -10,6 +12,75 @@ from modules.processing import Processed, process_images
from PIL import Image from PIL import Image
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
def process_string_tag(tag):
return tag
def process_int_tag(tag):
return int(tag)
def process_float_tag(tag):
return float(tag)
def process_boolean_tag(tag):
return True if (tag == "true") else False
prompt_tags = {
"sd_model": None,
"outpath_samples": process_string_tag,
"outpath_grids": process_string_tag,
"prompt_for_display": process_string_tag,
"prompt": process_string_tag,
"negative_prompt": process_string_tag,
"styles": process_string_tag,
"seed": process_int_tag,
"subseed_strength": process_float_tag,
"subseed": process_int_tag,
"seed_resize_from_h": process_int_tag,
"seed_resize_from_w": process_int_tag,
"sampler_index": process_int_tag,
"batch_size": process_int_tag,
"n_iter": process_int_tag,
"steps": process_int_tag,
"cfg_scale": process_float_tag,
"width": process_int_tag,
"height": process_int_tag,
"restore_faces": process_boolean_tag,
"tiling": process_boolean_tag,
"do_not_save_samples": process_boolean_tag,
"do_not_save_grid": process_boolean_tag
}
def cmdargs(line):
args = shlex.split(line)
pos = 0
res = {}
while pos < len(args):
arg = args[pos]
assert arg.startswith("--"), f'must start with "--": {arg}'
tag = arg[2:]
func = prompt_tags.get(tag, None)
assert func, f'unknown commandline option: {arg}'
assert pos+1 < len(args), f'missing argument for command line option {arg}'
val = args[pos+1]
res[tag] = func(val)
pos += 2
return res
class Script(scripts.Script): class Script(scripts.Script):
def title(self): def title(self):
return "Prompts from file or textbox" return "Prompts from file or textbox"
...@@ -32,26 +103,48 @@ class Script(scripts.Script): ...@@ -32,26 +103,48 @@ class Script(scripts.Script):
return [ gr.Checkbox.update(visible = True), gr.File.update(visible = not checkbox_txt), gr.TextArea.update(visible = checkbox_txt) ] return [ gr.Checkbox.update(visible = True), gr.File.update(visible = not checkbox_txt), gr.TextArea.update(visible = checkbox_txt) ]
def run(self, p, checkbox_txt, data: bytes, prompt_txt: str): def run(self, p, checkbox_txt, data: bytes, prompt_txt: str):
if (checkbox_txt): if checkbox_txt:
lines = [x.strip() for x in prompt_txt.splitlines()] lines = [x.strip() for x in prompt_txt.splitlines()]
else: else:
lines = [x.strip() for x in data.decode('utf8', errors='ignore').split("\n")] lines = [x.strip() for x in data.decode('utf8', errors='ignore').split("\n")]
lines = [x for x in lines if len(x) > 0] lines = [x for x in lines if len(x) > 0]
img_count = len(lines) * p.n_iter
batch_count = math.ceil(img_count / p.batch_size)
loop_count = math.ceil(batch_count / p.n_iter)
print(f"Will process {img_count} images in {batch_count} batches.")
p.do_not_save_grid = True p.do_not_save_grid = True
state.job_count = batch_count job_count = 0
jobs = []
for line in lines:
if "--" in line:
try:
args = cmdargs(line)
except Exception:
print(f"Error parsing line [line] as commandline:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
args = {"prompt": line}
else:
args = {"prompt": line}
n_iter = args.get("n_iter", 1)
if n_iter != 1:
job_count += n_iter
else:
job_count += 1
jobs.append(args)
print(f"Will process {len(lines)} lines in {job_count} jobs.")
state.job_count = job_count
images = [] images = []
for loop_no in range(loop_count): for n, args in enumerate(jobs):
state.job = f"{loop_no + 1} out of {loop_count}" state.job = f"{state.job_no + 1} out of {state.job_count}"
p.prompt = lines[loop_no*p.batch_size:(loop_no+1)*p.batch_size] * p.n_iter
proc = process_images(p) copy_p = copy.copy(p)
for k, v in args.items():
setattr(copy_p, k, v)
proc = process_images(copy_p)
images += proc.images images += proc.images
return Processed(p, images, p.seed, "") return Processed(p, images, p.seed, "")
...@@ -12,7 +12,7 @@ import gradio as gr ...@@ -12,7 +12,7 @@ import gradio as gr
from modules import images from modules import images
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.processing import process_images, Processed, get_correct_sampler from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
import modules.sd_samplers import modules.sd_samplers
...@@ -354,6 +354,9 @@ class Script(scripts.Script): ...@@ -354,6 +354,9 @@ class Script(scripts.Script):
else: else:
total_steps = p.steps * len(xs) * len(ys) total_steps = p.steps * len(xs) * len(ys)
if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
total_steps *= 2
print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})") print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})")
shared.total_tqdm.updateTotal(total_steps * p.n_iter) shared.total_tqdm.updateTotal(total_steps * p.n_iter)
......
...@@ -115,7 +115,7 @@ ...@@ -115,7 +115,7 @@
padding: 0.4em 0; padding: 0.4em 0;
} }
#roll, #paste{ #roll, #paste, #style_create, #style_apply{
min-width: 2em; min-width: 2em;
min-height: 2em; min-height: 2em;
max-width: 2em; max-width: 2em;
...@@ -126,14 +126,14 @@ ...@@ -126,14 +126,14 @@
margin: 0.1em 0; margin: 0.1em 0;
} }
#style_apply, #style_create, #interrogate{ #interrogate_col{
margin: 0.75em 0.25em 0.25em 0.25em; min-width: 0 !important;
min-width: 5em; max-width: 8em !important;
} }
#interrogate, #deepbooru{
#style_apply, #style_create, #deepbooru{ margin: 0em 0.25em 0.9em 0.25em;
margin: 0.75em 0.25em 0.25em 0.25em; min-width: 8em;
min-width: 5em; max-width: 8em;
} }
#style_pos_col, #style_neg_col{ #style_pos_col, #style_neg_col{
...@@ -167,10 +167,6 @@ button{ ...@@ -167,10 +167,6 @@ button{
align-self: stretch !important; align-self: stretch !important;
} }
#img2maskimg .h-60{
height: 30rem;
}
.overflow-hidden, .gr-panel{ .overflow-hidden, .gr-panel{
overflow: visible !important; overflow: visible !important;
} }
...@@ -241,13 +237,6 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s ...@@ -241,13 +237,6 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s
margin: 0; margin: 0;
} }
.gr-panel div.flex-col div.justify-between div{
position: absolute;
top: -0.1em;
right: 1em;
padding: 0 0.5em;
}
#settings .gr-panel div.flex-col div.justify-between div{ #settings .gr-panel div.flex-col div.justify-between div{
position: relative; position: relative;
z-index: 200; z-index: 200;
...@@ -320,6 +309,8 @@ input[type="range"]{ ...@@ -320,6 +309,8 @@ input[type="range"]{
height: 100%; height: 100%;
overflow: auto; overflow: auto;
background-color: rgba(20, 20, 20, 0.95); background-color: rgba(20, 20, 20, 0.95);
user-select: none;
-webkit-user-select: none;
} }
.modalControls { .modalControls {
...@@ -443,10 +434,6 @@ input[type="range"]{ ...@@ -443,10 +434,6 @@ input[type="range"]{
--tw-bg-opacity: 0 !important; --tw-bg-opacity: 0 !important;
} }
#img2img_image div.h-60{
height: 480px;
}
#context-menu{ #context-menu{
z-index:9999; z-index:9999;
position:absolute; position:absolute;
...@@ -521,3 +508,11 @@ canvas[key="mask"] { ...@@ -521,3 +508,11 @@ canvas[key="mask"] {
.row.gr-compact{ .row.gr-compact{
overflow: visible; overflow: visible;
} }
#img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img,
img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img
{
height: 480px !important;
max-height: 480px !important;
min-height: 480px !important;
}
...@@ -82,8 +82,8 @@ then ...@@ -82,8 +82,8 @@ then
clone_dir="${PWD##*/}" clone_dir="${PWD##*/}"
fi fi
# Check prequisites # Check prerequisites
for preq in git python3 for preq in "${GIT}" "${python_cmd}"
do do
if ! hash "${preq}" &>/dev/null if ! hash "${preq}" &>/dev/null
then then
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment