Commit 6d805b66 authored by AUTOMATIC's avatar AUTOMATIC

make CLIP interrogator download original text files if the directory does not exist

remove random artist built-in extension (to re-added as a normal extension on demand)
remove artists.csv (but what does it mean????????????????????)
make interrogate buttons show Loading... when you click them
parent 40ff6db5
...@@ -49,7 +49,6 @@ A browser interface based on Gradio library for Stable Diffusion. ...@@ -49,7 +49,6 @@ A browser interface based on Gradio library for Stable Diffusion.
- Running arbitrary python code from UI (must run with --allow-code to enable) - Running arbitrary python code from UI (must run with --allow-code to enable)
- Mouseover hints for most UI elements - Mouseover hints for most UI elements
- Possible to change defaults/mix/max/step values for UI elements via text config - Possible to change defaults/mix/max/step values for UI elements via text config
- Random artist button
- Tiling support, a checkbox to create images that can be tiled like textures - Tiling support, a checkbox to create images that can be tiled like textures
- Progress bar and live image generation preview - Progress bar and live image generation preview
- Negative prompt, an extra text field that allows you to list what you don't want to see in generated image - Negative prompt, an extra text field that allows you to list what you don't want to see in generated image
......
This source diff could not be displayed because it is too large. You can view the blob instead.
import random
from modules import script_callbacks, shared
import gradio as gr
art_symbol = '\U0001f3a8' # 🎨
global_prompt = None
related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
def roll_artist(prompt):
allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
return prompt + ", " + artist.name if prompt != '' else artist.name
def add_roll_button(prompt):
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
roll.click(
fn=roll_artist,
_js="update_txt2img_tokens",
inputs=[
prompt,
],
outputs=[
prompt,
]
)
def after_component(component, **kwargs):
global global_prompt
elem_id = kwargs.get('elem_id', None)
if elem_id not in related_ids:
return
if elem_id == "txt2img_prompt":
global_prompt = component
elif elem_id == "txt2img_clear_prompt":
add_roll_button(global_prompt)
elif elem_id == "img2img_prompt":
global_prompt = component
elif elem_id == "img2img_clear_prompt":
add_roll_button(global_prompt)
script_callbacks.on_after_component(after_component)
...@@ -14,7 +14,6 @@ titles = { ...@@ -14,7 +14,6 @@ titles = {
"Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result", "Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time", "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
"\u{1f3a8}": "Add a random artist to the prompt.",
"\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.", "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
"\u{1f4c2}": "Open images output directory", "\u{1f4c2}": "Open images output directory",
"\u{1f4be}": "Save style", "\u{1f4be}": "Save style",
......
...@@ -126,8 +126,6 @@ class Api: ...@@ -126,8 +126,6 @@ class Api:
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem]) self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse) self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
...@@ -390,12 +388,6 @@ class Api: ...@@ -390,12 +388,6 @@ class Api:
return styleList return styleList
def get_artists_categories(self):
return shared.artist_db.cats
def get_artists(self):
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
def get_embeddings(self): def get_embeddings(self):
db = sd_hijack.model_hijack.embedding_db db = sd_hijack.model_hijack.embedding_db
......
import os.path
import csv
from collections import namedtuple
Artist = namedtuple("Artist", ['name', 'weight', 'category'])
class ArtistsDatabase:
def __init__(self, filename):
self.cats = set()
self.artists = []
if not os.path.exists(filename):
return
with open(filename, "r", newline='', encoding="utf8") as file:
reader = csv.DictReader(file)
for row in reader:
artist = Artist(row["artist"], float(row["score"]), row["category"])
self.artists.append(artist)
self.cats.add(artist.category)
def categories(self):
return sorted(self.cats)
...@@ -5,12 +5,13 @@ from collections import namedtuple ...@@ -5,12 +5,13 @@ from collections import namedtuple
import re import re
import torch import torch
import torch.hub
from torchvision import transforms from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared import modules.shared as shared
from modules import devices, paths, lowvram, modelloader from modules import devices, paths, lowvram, modelloader, errors
blip_image_eval_size = 384 blip_image_eval_size = 384
clip_model_name = 'ViT-L/14' clip_model_name = 'ViT-L/14'
...@@ -20,27 +21,59 @@ Category = namedtuple("Category", ["name", "topn", "items"]) ...@@ -20,27 +21,59 @@ Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.") re_topn = re.compile(r"\.top(\d+)\.")
def download_default_clip_interrogate_categories(content_dir):
print("Downloading CLIP categories...")
tmpdir = content_dir + "_tmp"
try:
os.makedirs(tmpdir)
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt"))
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))
os.rename(tmpdir, content_dir)
except Exception as e:
errors.display(e, "downloading default CLIP interrogate categories")
finally:
if os.path.exists(tmpdir):
os.remove(tmpdir)
class InterrogateModels: class InterrogateModels:
blip_model = None blip_model = None
clip_model = None clip_model = None
clip_preprocess = None clip_preprocess = None
categories = None
dtype = None dtype = None
running_on_cpu = None running_on_cpu = None
def __init__(self, content_dir): def __init__(self, content_dir):
self.categories = [] self.loaded_categories = None
self.content_dir = content_dir
self.running_on_cpu = devices.device_interrogate == torch.device("cpu") self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
if os.path.exists(content_dir): def categories(self):
for filename in os.listdir(content_dir): if self.loaded_categories is not None:
return self.loaded_categories
self.loaded_categories = []
if not os.path.exists(self.content_dir):
download_default_clip_interrogate_categories(self.content_dir)
if os.path.exists(self.content_dir):
for filename in os.listdir(self.content_dir):
m = re_topn.search(filename) m = re_topn.search(filename)
topn = 1 if m is None else int(m.group(1)) topn = 1 if m is None else int(m.group(1))
with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file: with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
lines = [x.strip() for x in file.readlines()] lines = [x.strip() for x in file.readlines()]
self.categories.append(Category(name=filename, topn=topn, items=lines)) self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))
return self.loaded_categories
def load_blip_model(self): def load_blip_model(self):
import models.blip import models.blip
...@@ -139,7 +172,6 @@ class InterrogateModels: ...@@ -139,7 +172,6 @@ class InterrogateModels:
shared.state.begin() shared.state.begin()
shared.state.job = 'interrogate' shared.state.job = 'interrogate'
try: try:
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu() lowvram.send_everything_to_cpu()
devices.torch_gc() devices.torch_gc()
...@@ -159,12 +191,7 @@ class InterrogateModels: ...@@ -159,12 +191,7 @@ class InterrogateModels:
image_features /= image_features.norm(dim=-1, keepdim=True) image_features /= image_features.norm(dim=-1, keepdim=True)
if shared.opts.interrogate_use_builtin_artists: for name, topn, items in self.categories():
artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
res += ", " + artist[0]
for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn) matches = self.rank(image_features, items, top_count=topn)
for match, score in matches: for match, score in matches:
if shared.opts.interrogate_return_ranks: if shared.opts.interrogate_return_ranks:
......
...@@ -9,7 +9,6 @@ from PIL import Image ...@@ -9,7 +9,6 @@ from PIL import Image
import gradio as gr import gradio as gr
import tqdm import tqdm
import modules.artists
import modules.interrogate import modules.interrogate
import modules.memmon import modules.memmon
import modules.styles import modules.styles
...@@ -254,8 +253,6 @@ class State: ...@@ -254,8 +253,6 @@ class State:
state = State() state = State()
state.server_start = time.time() state.server_start = time.time()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
styles_filename = cmd_opts.styles_file styles_filename = cmd_opts.styles_file
prompt_styles = modules.styles.StyleDatabase(styles_filename) prompt_styles = modules.styles.StyleDatabase(styles_filename)
...@@ -408,7 +405,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { ...@@ -408,7 +405,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
})) }))
options_templates.update(options_section(('compatibility', "Compatibility"), { options_templates.update(options_section(('compatibility', "Compatibility"), {
...@@ -419,7 +415,6 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { ...@@ -419,7 +415,6 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
options_templates.update(options_section(('interrogate', "Interrogate Options"), { options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
"interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."), "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
......
...@@ -228,17 +228,17 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di ...@@ -228,17 +228,17 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di
left, _ = os.path.splitext(filename) left, _ = os.path.splitext(filename)
print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a')) print(interrogation_function(img), file=open(os.path.join(ii_output_dir, left + ".txt"), 'a'))
return [gr_show(True), None] return [gr.update(), None]
def interrogate(image): def interrogate(image):
prompt = shared.interrogator.interrogate(image.convert("RGB")) prompt = shared.interrogator.interrogate(image.convert("RGB"))
return gr_show(True) if prompt is None else prompt return gr.update() if prompt is None else prompt
def interrogate_deepbooru(image): def interrogate_deepbooru(image):
prompt = deepbooru.model.tag(image) prompt = deepbooru.model.tag(image)
return gr_show(True) if prompt is None else prompt return gr.update() if prompt is None else prompt
def create_seed_inputs(target_interface): def create_seed_inputs(target_interface):
...@@ -1039,19 +1039,18 @@ def create_ui(): ...@@ -1039,19 +1039,18 @@ def create_ui():
init_img_inpaint, init_img_inpaint,
], ],
outputs=[img2img_prompt, dummy_component], outputs=[img2img_prompt, dummy_component],
show_progress=False,
) )
img2img_prompt.submit(**img2img_args) img2img_prompt.submit(**img2img_args)
submit.click(**img2img_args) submit.click(**img2img_args)
img2img_interrogate.click( img2img_interrogate.click(
fn=lambda *args : process_interrogate(interrogate, *args), fn=lambda *args: process_interrogate(interrogate, *args),
**interrogate_args, **interrogate_args,
) )
img2img_deepbooru.click( img2img_deepbooru.click(
fn=lambda *args : process_interrogate(interrogate_deepbooru, *args), fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
**interrogate_args, **interrogate_args,
) )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment