Commit 8fb9c57e authored by AUTOMATIC's avatar AUTOMATIC

add half() supporrt for CLIP interrogation

parent d97c6f22
...@@ -14,3 +14,9 @@ def get_optimal_device(): ...@@ -14,3 +14,9 @@ def get_optimal_device():
return torch.device("mps") return torch.device("mps")
return cpu return cpu
def torch_gc():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from modules import processing, shared, images from modules import processing, shared, images, devices
from modules.shared import opts from modules.shared import opts
import modules.gfpgan_model import modules.gfpgan_model
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
...@@ -11,7 +11,7 @@ cached_images = {} ...@@ -11,7 +11,7 @@ cached_images = {}
def run_extras(image, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility): def run_extras(image, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
processing.torch_gc() devices.torch_gc()
image = image.convert("RGB") image = image.convert("RGB")
info = "" info = ""
......
...@@ -3,6 +3,7 @@ import cv2 ...@@ -3,6 +3,7 @@ import cv2
import numpy as np import numpy as np
from PIL import Image, ImageOps, ImageChops from PIL import Image, ImageOps, ImageChops
from modules import devices
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state from modules.shared import opts, state
import modules.shared as shared import modules.shared as shared
...@@ -131,7 +132,7 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init ...@@ -131,7 +132,7 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init
upscaler = shared.sd_upscalers[upscaler_index] upscaler = shared.sd_upscalers[upscaler_index]
img = upscaler.upscale(init_img, init_img.width * 2, init_img.height * 2) img = upscaler.upscale(init_img, init_img.width * 2, init_img.height * 2)
processing.torch_gc() devices.torch_gc()
grid = images.split_grid(img, tile_w=width, tile_h=height, overlap=upscale_overlap) grid = images.split_grid(img, tile_w=width, tile_h=height, overlap=upscale_overlap)
......
import contextlib
import os import os
import sys import sys
import traceback import traceback
...@@ -6,7 +7,6 @@ import re ...@@ -6,7 +7,6 @@ import re
import torch import torch
from PIL import Image
from torchvision import transforms from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
...@@ -26,6 +26,7 @@ class InterrogateModels: ...@@ -26,6 +26,7 @@ class InterrogateModels:
clip_model = None clip_model = None
clip_preprocess = None clip_preprocess = None
categories = None categories = None
dtype = None
def __init__(self, content_dir): def __init__(self, content_dir):
self.categories = [] self.categories = []
...@@ -60,14 +61,20 @@ class InterrogateModels: ...@@ -60,14 +61,20 @@ class InterrogateModels:
def load(self): def load(self):
if self.blip_model is None: if self.blip_model is None:
self.blip_model = self.load_blip_model() self.blip_model = self.load_blip_model()
if not shared.cmd_opts.no_half:
self.blip_model = self.blip_model.half()
self.blip_model = self.blip_model.to(shared.device) self.blip_model = self.blip_model.to(shared.device)
if self.clip_model is None: if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model() self.clip_model, self.clip_preprocess = self.load_clip_model()
if not shared.cmd_opts.no_half:
self.clip_model = self.clip_model.half()
self.clip_model = self.clip_model.to(shared.device) self.clip_model = self.clip_model.to(shared.device)
self.dtype = next(self.clip_model.parameters()).dtype
def unload(self): def unload(self):
if not shared.opts.interrogate_keep_models_in_memory: if not shared.opts.interrogate_keep_models_in_memory:
if self.clip_model is not None: if self.clip_model is not None:
...@@ -76,14 +83,14 @@ class InterrogateModels: ...@@ -76,14 +83,14 @@ class InterrogateModels:
if self.blip_model is not None: if self.blip_model is not None:
self.blip_model = self.blip_model.to(devices.cpu) self.blip_model = self.blip_model.to(devices.cpu)
devices.torch_gc()
def rank(self, image_features, text_array, top_count=1): def rank(self, image_features, text_array, top_count=1):
import clip import clip
top_count = min(top_count, len(text_array)) top_count = min(top_count, len(text_array))
text_tokens = clip.tokenize([text for text in text_array]).cuda() text_tokens = clip.tokenize([text for text in text_array]).to(shared.device)
with torch.no_grad(): text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
text_features = self.clip_model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True) text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = torch.zeros((1, len(text_array))).to(shared.device) similarity = torch.zeros((1, len(text_array))).to(shared.device)
...@@ -94,13 +101,12 @@ class InterrogateModels: ...@@ -94,13 +101,12 @@ class InterrogateModels:
top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1) top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)] return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
def generate_caption(self, pil_image): def generate_caption(self, pil_image):
gpu_image = transforms.Compose([ gpu_image = transforms.Compose([
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])(pil_image).unsqueeze(0).to(shared.device) ])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
with torch.no_grad(): with torch.no_grad():
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length) caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
...@@ -116,22 +122,23 @@ class InterrogateModels: ...@@ -116,22 +122,23 @@ class InterrogateModels:
caption = self.generate_caption(pil_image) caption = self.generate_caption(pil_image)
res = caption res = caption
images = self.clip_preprocess(pil_image).unsqueeze(0).to(shared.device) images = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
with torch.no_grad(): precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
image_features = self.clip_model.encode_image(images).float() with torch.no_grad(), precision_scope("cuda"):
image_features = self.clip_model.encode_image(images).type(self.dtype)
image_features /= image_features.norm(dim=-1, keepdim=True) image_features /= image_features.norm(dim=-1, keepdim=True)
if shared.opts.interrogate_use_builtin_artists: if shared.opts.interrogate_use_builtin_artists:
artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0] artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
res += ", " + artist[0] res += ", " + artist[0]
for name, topn, items in self.categories: for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn) matches = self.rank(image_features, items, top_count=topn)
for match, score in matches: for match, score in matches:
res += ", " + match res += ", " + match
except Exception: except Exception:
print(f"Error interrogating", file=sys.stderr) print(f"Error interrogating", file=sys.stderr)
......
...@@ -10,6 +10,7 @@ from PIL import Image, ImageFilter, ImageOps ...@@ -10,6 +10,7 @@ from PIL import Image, ImageFilter, ImageOps
import random import random
import modules.sd_hijack import modules.sd_hijack
from modules import devices
from modules.sd_hijack import model_hijack from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
...@@ -23,11 +24,6 @@ opt_C = 4 ...@@ -23,11 +24,6 @@ opt_C = 4
opt_f = 8 opt_f = 8
def torch_gc():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
class StableDiffusionProcessing: class StableDiffusionProcessing:
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style="None", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None): def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style="None", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
...@@ -157,7 +153,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -157,7 +153,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
assert p.prompt is not None assert p.prompt is not None
torch_gc() devices.torch_gc()
fix_seed(p) fix_seed(p)
...@@ -258,7 +254,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -258,7 +254,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
if p.restore_faces: if p.restore_faces:
torch_gc() devices.torch_gc()
x_sample = modules.face_restoration.restore_faces(x_sample) x_sample = modules.face_restoration.restore_faces(x_sample)
...@@ -297,7 +293,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -297,7 +293,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if opts.grid_save: if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename) images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
torch_gc() devices.torch_gc()
return Processed(p, output_images, all_seeds[0], infotext()) return Processed(p, output_images, all_seeds[0], infotext())
......
...@@ -4,7 +4,7 @@ import modules.scripts as scripts ...@@ -4,7 +4,7 @@ import modules.scripts as scripts
import gradio as gr import gradio as gr
from PIL import Image, ImageDraw from PIL import Image, ImageDraw
from modules import images, processing from modules import images, processing, devices
from modules.processing import Processed, process_images from modules.processing import Processed, process_images
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
...@@ -77,7 +77,7 @@ class Script(scripts.Script): ...@@ -77,7 +77,7 @@ class Script(scripts.Script):
mask.height - down - (mask_blur//2 if down > 0 else 0) mask.height - down - (mask_blur//2 if down > 0 else 0)
), fill="black") ), fill="black")
processing.torch_gc() devices.torch_gc()
grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels) grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels)
grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels) grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment