Unverified Commit 7ba7f4ed authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #7113 from vladmandic/interrogate

Add selector to interrogate categories
parents 7b1c7ba8 04a561c1
...@@ -2,6 +2,7 @@ import os ...@@ -2,6 +2,7 @@ import os
import sys import sys
import traceback import traceback
from collections import namedtuple from collections import namedtuple
from pathlib import Path
import re import re
import torch import torch
...@@ -20,19 +21,20 @@ Category = namedtuple("Category", ["name", "topn", "items"]) ...@@ -20,19 +21,20 @@ Category = namedtuple("Category", ["name", "topn", "items"])
re_topn = re.compile(r"\.top(\d+)\.") re_topn = re.compile(r"\.top(\d+)\.")
def category_types():
return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]
def download_default_clip_interrogate_categories(content_dir): def download_default_clip_interrogate_categories(content_dir):
print("Downloading CLIP categories...") print("Downloading CLIP categories...")
tmpdir = content_dir + "_tmp" tmpdir = content_dir + "_tmp"
category_types = ["artists", "flavors", "mediums", "movements"]
try: try:
os.makedirs(tmpdir) os.makedirs(tmpdir)
for category_type in category_types:
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt")) torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))
os.rename(tmpdir, content_dir) os.rename(tmpdir, content_dir)
except Exception as e: except Exception as e:
...@@ -51,27 +53,32 @@ class InterrogateModels: ...@@ -51,27 +53,32 @@ class InterrogateModels:
def __init__(self, content_dir): def __init__(self, content_dir):
self.loaded_categories = None self.loaded_categories = None
self.skip_categories = []
self.content_dir = content_dir self.content_dir = content_dir
self.running_on_cpu = devices.device_interrogate == torch.device("cpu") self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
def categories(self): def categories(self):
if self.loaded_categories is not None:
return self.loaded_categories
self.loaded_categories = []
if not os.path.exists(self.content_dir): if not os.path.exists(self.content_dir):
download_default_clip_interrogate_categories(self.content_dir) download_default_clip_interrogate_categories(self.content_dir)
if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
return self.loaded_categories
self.loaded_categories = []
if os.path.exists(self.content_dir): if os.path.exists(self.content_dir):
for filename in os.listdir(self.content_dir): self.skip_categories = shared.opts.interrogate_clip_skip_categories
m = re_topn.search(filename) category_types = []
for filename in Path(self.content_dir).glob('*.txt'):
category_types.append(filename.stem)
if filename.stem in self.skip_categories:
continue
m = re_topn.search(filename.stem)
topn = 1 if m is None else int(m.group(1)) topn = 1 if m is None else int(m.group(1))
with open(filename, "r", encoding="utf8") as file:
with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
lines = [x.strip() for x in file.readlines()] lines = [x.strip() for x in file.readlines()]
self.loaded_categories.append(Category(name=filename, topn=topn, items=lines)) self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))
return self.loaded_categories return self.loaded_categories
...@@ -139,6 +146,8 @@ class InterrogateModels: ...@@ -139,6 +146,8 @@ class InterrogateModels:
def rank(self, image_features, text_array, top_count=1): def rank(self, image_features, text_array, top_count=1):
import clip import clip
devices.torch_gc()
if shared.opts.interrogate_clip_dict_limit != 0: if shared.opts.interrogate_clip_dict_limit != 0:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)] text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
......
...@@ -424,6 +424,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), ...@@ -424,6 +424,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"), "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
"interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment