Commit d717eb07 authored by Greg Fuller's avatar Greg Fuller

Interrogate: add option to include ranks in output

Since the UI also allows users to specify ranks, it can be useful to show people what ranks are being returned by interrogate

This can also give much better results when feeding the interrogate results back into either img2img or txt2img, especially when trying to generate a specific character or scene for which you have a similar concept image

Testing Steps:

Launch Webui with command line arg: --deepdanbooru
Navigate to img2img tab, use interrogate DeepBooru, verify tags appears as before. Use "Interrogate CLIP", verify prompt appears as before
Navigate to Settings tab, enable new option, click "apply settings"
Navigate to img2img, Interrogate DeepBooru again, verify that weights appear and are properly formatted. Note that "Interrogate CLIP" prompt is still unchanged
In my testing, this change has no effect to "Interrogate CLIP", as it seems to generate a sentence-structured caption, and not a set of tags.

(reproduce changes from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2149/commits/6ed4faac46c45ca7353f228aca9b436bbaba7bc7)
parent 6be32b31
...@@ -3,7 +3,7 @@ from concurrent.futures import ProcessPoolExecutor ...@@ -3,7 +3,7 @@ from concurrent.futures import ProcessPoolExecutor
from multiprocessing import get_context from multiprocessing import get_context
def _load_tf_and_return_tags(pil_image, threshold): def _load_tf_and_return_tags(pil_image, threshold, include_ranks):
import deepdanbooru as dd import deepdanbooru as dd
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
...@@ -52,12 +52,16 @@ def _load_tf_and_return_tags(pil_image, threshold): ...@@ -52,12 +52,16 @@ def _load_tf_and_return_tags(pil_image, threshold):
if result_dict[tag] >= threshold: if result_dict[tag] >= threshold:
if tag.startswith("rating:"): if tag.startswith("rating:"):
continue continue
result_tags_out.append(tag) tag_formatted = tag.replace('_', ' ').replace(':', ' ')
if include_ranks:
result_tags_out.append(f'({tag_formatted}:{result_dict[tag]})')
else:
result_tags_out.append(tag_formatted)
result_tags_print.append(f'{result_dict[tag]} {tag}') result_tags_print.append(f'{result_dict[tag]} {tag}')
print('\n'.join(sorted(result_tags_print, reverse=True))) print('\n'.join(sorted(result_tags_print, reverse=True)))
return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') return ', '.join(result_tags_out)
def subprocess_init_no_cuda(): def subprocess_init_no_cuda():
...@@ -65,9 +69,9 @@ def subprocess_init_no_cuda(): ...@@ -65,9 +69,9 @@ def subprocess_init_no_cuda():
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
def get_deepbooru_tags(pil_image, threshold=0.5): def get_deepbooru_tags(pil_image, threshold=0.5, include_ranks=False):
context = get_context('spawn') context = get_context('spawn')
with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor: with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor:
f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, ) f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, include_ranks)
ret = f.result() # will rethrow any exceptions ret = f.result() # will rethrow any exceptions
return ret return ret
\ No newline at end of file
...@@ -123,7 +123,7 @@ class InterrogateModels: ...@@ -123,7 +123,7 @@ class InterrogateModels:
return caption[0] return caption[0]
def interrogate(self, pil_image): def interrogate(self, pil_image, include_ranks=False):
res = None res = None
try: try:
...@@ -156,7 +156,10 @@ class InterrogateModels: ...@@ -156,7 +156,10 @@ class InterrogateModels:
for name, topn, items in self.categories: for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn) matches = self.rank(image_features, items, top_count=topn)
for match, score in matches: for match, score in matches:
res += ", " + match if include_ranks:
res += ", " + match
else:
res += f", ({match}:{score})"
except Exception: except Exception:
print(f"Error interrogating", file=sys.stderr) print(f"Error interrogating", file=sys.stderr)
......
...@@ -251,6 +251,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { ...@@ -251,6 +251,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
options_templates.update(options_section(('interrogate', "Interrogate Options"), { options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
"interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
......
...@@ -311,13 +311,12 @@ def apply_styles(prompt, prompt_neg, style1_name, style2_name): ...@@ -311,13 +311,12 @@ def apply_styles(prompt, prompt_neg, style1_name, style2_name):
def interrogate(image): def interrogate(image):
prompt = shared.interrogator.interrogate(image) prompt = shared.interrogator.interrogate(image, include_ranks=opts.interrogate_return_ranks)
return gr_show(True) if prompt is None else prompt return gr_show(True) if prompt is None else prompt
def interrogate_deepbooru(image): def interrogate_deepbooru(image):
prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold) prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold, opts.interrogate_return_ranks)
return gr_show(True) if prompt is None else prompt return gr_show(True) if prompt is None else prompt
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment