Commit 37d7ffb4 authored by MalumaDev's avatar MalumaDev

fix to tokens lenght, addend embs generator, add new features to edit the...

fix to tokens lenght, addend embs generator, add new features to edit the embedding before the generation using text
parent bb57f30c
import itertools
import os
from pathlib import Path
import html
import gc
import gradio as gr
import torch
from PIL import Image
from modules import shared
from modules.shared import device, aesthetic_embeddings
from transformers import CLIPModel, CLIPProcessor
from tqdm.auto import tqdm
def get_all_images_in_folder(folder):
return [os.path.join(folder, f) for f in os.listdir(folder) if
os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)]
def check_is_valid_image_file(filename):
return filename.lower().endswith(('.png', '.jpg', '.jpeg'))
def batched(dataset, total, n=1):
for ndx in range(0, total, n):
yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]
def iter_to_batched(iterable, n=1):
it = iter(iterable)
while True:
chunk = tuple(itertools.islice(it, n))
if not chunk:
return
yield chunk
def generate_imgs_embd(name, folder, batch_size):
# clipModel = CLIPModel.from_pretrained(
# shared.sd_model.cond_stage_model.clipModel.name_or_path
# )
model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path).to(device)
processor = CLIPProcessor.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path)
with torch.no_grad():
embs = []
for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
desc=f"Generating embeddings for {name}"):
if shared.state.interrupted:
break
inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device)
outputs = model.get_image_features(**inputs).cpu()
embs.append(torch.clone(outputs))
inputs.to("cpu")
del inputs, outputs
embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
# The generated embedding will be located here
path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
torch.save(embs, path)
model = model.cpu()
del model
del processor
del embs
gc.collect()
torch.cuda.empty_cache()
res = f"""
Done generating embedding for {name}!
Hypernetwork saved to {html.escape(path)}
"""
shared.update_aesthetic_embeddings()
return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding",
value=sorted(aesthetic_embeddings.keys())[0] if len(
aesthetic_embeddings) > 0 else None), res, ""
This diff is collapsed.
This diff is collapsed.
...@@ -95,6 +95,10 @@ loaded_hypernetwork = None ...@@ -95,6 +95,10 @@ loaded_hypernetwork = None
aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
def update_aesthetic_embeddings():
global aesthetic_embeddings
aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
def reload_hypernetworks(): def reload_hypernetworks():
global hypernetworks global hypernetworks
......
...@@ -13,7 +13,11 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: ...@@ -13,7 +13,11 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
aesthetic_lr=0, aesthetic_lr=0,
aesthetic_weight=0, aesthetic_steps=0, aesthetic_weight=0, aesthetic_steps=0,
aesthetic_imgs=None, aesthetic_imgs=None,
aesthetic_slerp=False, *args): aesthetic_slerp=False,
aesthetic_imgs_text="",
aesthetic_slerp_angle=0.15,
aesthetic_text_negative=False,
*args):
p = StableDiffusionProcessingTxt2Img( p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
...@@ -47,7 +51,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: ...@@ -47,7 +51,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
processed = modules.scripts.scripts_txt2img.run(p, *args) processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is None: if processed is None:
processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp) processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp,aesthetic_imgs_text,
aesthetic_slerp_angle,
aesthetic_text_negative)
shared.total_tqdm.clear() shared.total_tqdm.clear()
......
...@@ -41,6 +41,7 @@ from modules import prompt_parser ...@@ -41,6 +41,7 @@ from modules import prompt_parser
from modules.images import save_image from modules.images import save_image
import modules.textual_inversion.ui import modules.textual_inversion.ui
import modules.hypernetworks.ui import modules.hypernetworks.ui
import modules.aesthetic_clip
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init() mimetypes.init()
...@@ -449,7 +450,7 @@ def create_toprow(is_img2img): ...@@ -449,7 +450,7 @@ def create_toprow(is_img2img):
with gr.Row(): with gr.Row():
negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2) negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2)
with gr.Column(scale=1, elem_id="roll_col"): with gr.Column(scale=1, elem_id="roll_col"):
sh = gr.Button(elem_id="sh", visible=True) sh = gr.Button(elem_id="sh", visible=True)
with gr.Column(scale=1, elem_id="style_neg_col"): with gr.Column(scale=1, elem_id="style_neg_col"):
prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
...@@ -536,9 +537,13 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -536,9 +537,13 @@ def create_ui(wrap_gradio_gpu_call):
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
with gr.Group(): with gr.Group():
aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.0001")
aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.7) aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=50) aesthetic_steps = gr.Slider(minimum=0, maximum=256, step=1, label="Aesthetic steps", value=5)
with gr.Row():
aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="")
aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1)
aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None) aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None)
aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
...@@ -617,7 +622,10 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -617,7 +622,10 @@ def create_ui(wrap_gradio_gpu_call):
aesthetic_weight, aesthetic_weight,
aesthetic_steps, aesthetic_steps,
aesthetic_imgs, aesthetic_imgs,
aesthetic_slerp aesthetic_slerp,
aesthetic_imgs_text,
aesthetic_slerp_angle,
aesthetic_text_negative
] + custom_inputs, ] + custom_inputs,
outputs=[ outputs=[
txt2img_gallery, txt2img_gallery,
...@@ -721,7 +729,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -721,7 +729,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row(): with gr.Row():
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False) inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32) inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=1024, step=4, value=32)
with gr.TabItem('Batch img2img', id='batch'): with gr.TabItem('Batch img2img', id='batch'):
hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
...@@ -1071,6 +1079,17 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1071,6 +1079,17 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column(): with gr.Column():
create_embedding = gr.Button(value="Create embedding", variant='primary') create_embedding = gr.Button(value="Create embedding", variant='primary')
with gr.Tab(label="Create images embedding"):
new_embedding_name_ae = gr.Textbox(label="Name")
process_src_ae = gr.Textbox(label='Source directory')
batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256)
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
create_embedding_ae = gr.Button(value="Create images embedding", variant='primary')
with gr.Tab(label="Create hypernetwork"): with gr.Tab(label="Create hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
...@@ -1139,7 +1158,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1139,7 +1158,7 @@ def create_ui(wrap_gradio_gpu_call):
fn=modules.textual_inversion.ui.create_embedding, fn=modules.textual_inversion.ui.create_embedding,
inputs=[ inputs=[
new_embedding_name, new_embedding_name,
initialization_text, process_src,
nvpt, nvpt,
], ],
outputs=[ outputs=[
...@@ -1149,6 +1168,20 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1149,6 +1168,20 @@ def create_ui(wrap_gradio_gpu_call):
] ]
) )
create_embedding_ae.click(
fn=modules.aesthetic_clip.generate_imgs_embd,
inputs=[
new_embedding_name_ae,
process_src_ae,
batch_ae
],
outputs=[
aesthetic_imgs,
ti_output,
ti_outcome,
]
)
create_hypernetwork.click( create_hypernetwork.click(
fn=modules.hypernetworks.ui.create_hypernetwork, fn=modules.hypernetworks.ui.create_hypernetwork,
inputs=[ inputs=[
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment