Commit 9bb20be0 authored by AUTOMATIC

memory optimization for CLIP interrogator

changed default cfg_scale to a higher value
parent ab0a79cd
modules/interrogate.py

@@ -11,7 +11,7 @@ from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 
 import modules.shared as shared
-from modules import devices, paths
+from modules import devices, paths, lowvram
 
 blip_image_eval_size = 384
 blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
@@ -75,19 +75,28 @@ class InterrogateModels:
         self.dtype = next(self.clip_model.parameters()).dtype
 
-    def unload(self):
+    def send_clip_to_ram(self):
         if not shared.opts.interrogate_keep_models_in_memory:
             if self.clip_model is not None:
                 self.clip_model = self.clip_model.to(devices.cpu)
 
+    def send_blip_to_ram(self):
+        if not shared.opts.interrogate_keep_models_in_memory:
             if self.blip_model is not None:
                 self.blip_model = self.blip_model.to(devices.cpu)
 
-            devices.torch_gc()
+    def unload(self):
+        self.send_clip_to_ram()
+        self.send_blip_to_ram()
+
+        devices.torch_gc()
 
     def rank(self, image_features, text_array, top_count=1):
         import clip
 
+        if shared.opts.interrogate_clip_dict_limit != 0:
+            text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
+
         top_count = min(top_count, len(text_array))
         text_tokens = clip.tokenize([text for text in text_array]).to(shared.device)
         text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
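
The new interrogate_clip_dict_limit guard matters because rank() tokenizes and encodes the entire wordlist in one batch, so the length of the artist/flavor text files translates directly into the VRAM used for text features. A standalone illustration of what the guard does (the flavors list and limit value below are made up for the example; only the slicing mirrors the commit):

flavors = [f"flavor {i}" for i in range(10000)]  # stand-in for a large wordlist file
limit = 1500                                     # default of the new option; 0 means no limit
if limit != 0:
    flavors = flavors[0:int(limit)]
assert len(flavors) == 1500                      # only the first 1500 lines get tokenized and ranked
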
@@ -117,16 +126,24 @@ class InterrogateModels:
         res = None
 
         try:
+
+            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+                lowvram.send_everything_to_cpu()
+                devices.torch_gc()
+
             self.load()
 
             caption = self.generate_caption(pil_image)
+            self.send_blip_to_ram()
+            devices.torch_gc()
 
             res = caption
 
-            images = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
+            cilp_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)
 
             precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
             with torch.no_grad(), precision_scope("cuda"):
-                image_features = self.clip_model.encode_image(images).type(self.dtype)
+                image_features = self.clip_model.encode_image(cilp_image).type(self.dtype)
 
                 image_features /= image_features.norm(dim=-1, keepdim=True)
@@ -146,4 +163,5 @@ class InterrogateModels:
 
         self.unload()
 
+        res += "<error>"
         return res
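
Taken together, the interrogate.py hunks give interrogation a staged VRAM budget: optionally evict the Stable Diffusion model first (--lowvram/--medvram), caption with BLIP, push BLIP back to system RAM, then run the CLIP image encoder in the space BLIP released. A minimal standalone sketch of that pattern, assuming a CUDA device and caller-supplied models (none of the names below come from the commit):

import gc
import torch

gpu = torch.device("cuda")
cpu = torch.device("cpu")

def free_vram():
    # Roughly what devices.torch_gc() does: drop dead references, then release cached blocks.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def caption_then_encode(blip_model, caption_fn, clip_model, clip_preprocess, pil_image):
    # Stage 1: caption with BLIP, then move it to system RAM so its VRAM becomes reusable.
    blip_model.to(gpu)
    caption = caption_fn(blip_model, pil_image)
    blip_model.to(cpu)
    free_vram()

    # Stage 2: encode the image with CLIP in the memory BLIP just released.
    clip_model.to(gpu)
    with torch.no_grad():
        image = clip_preprocess(pil_image).unsqueeze(0).to(gpu)
        features = clip_model.encode_image(image)
        features /= features.norm(dim=-1, keepdim=True)
    clip_model.to(cpu)
    free_vram()

    return caption, features
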
modules/lowvram.py

@@ -5,6 +5,16 @@ module_in_gpu = None
 cpu = torch.device("cpu")
 device = gpu = get_optimal_device()
 
 
+def send_everything_to_cpu():
+    global module_in_gpu
+
+    if module_in_gpu is not None:
+        module_in_gpu.to(cpu)
+
+    module_in_gpu = None
+
+
 def setup_for_low_vram(sd_model, use_medvram):
     parents = {}
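
send_everything_to_cpu() pairs with the swapping that lowvram.py already performs: under --lowvram/--medvram, at most one large module is resident on the GPU at a time and it is tracked in module_in_gpu, so the interrogator can evict it before loading BLIP and CLIP. A simplified sketch of that single-residency idea (the hook registration below is illustrative; the real setup_for_low_vram wires up the Stable Diffusion model's own submodules):

import torch

cpu = torch.device("cpu")
gpu = torch.device("cuda")
module_in_gpu = None  # the one module currently allowed to occupy the GPU

def keep_single_resident(module: torch.nn.Module):
    # Before a wrapped module runs, swap out whatever is on the GPU and swap this one in.
    def pre_hook(mod, _inputs):
        global module_in_gpu
        if module_in_gpu is not mod:
            if module_in_gpu is not None:
                module_in_gpu.to(cpu)
            mod.to(gpu)
            module_in_gpu = mod
    module.register_forward_pre_hook(pre_hook)

def send_everything_to_cpu():
    # Same shape as the function added in this commit: evict the resident module, if any.
    global module_in_gpu
    if module_in_gpu is not None:
        module_in_gpu.to(cpu)
    module_in_gpu = None
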
modules/shared.py

@@ -132,6 +132,7 @@ class Options:
         "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
         "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum descripton length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
         "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum descripton length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
+        "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
     }
 
     def __init__(self):
modules/ui.py

@@ -270,7 +270,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
                     batch_count = gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count', value=1)
                     batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
 
-                cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.0)
+                cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
 
                 with gr.Group():
                     height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)

@@ -413,7 +413,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
                     batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1)
 
                 with gr.Group():
-                    cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='CFG Scale', value=7.0)
+                    cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0)
                     denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75)
                     denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, visible=False)
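
Note on the ui.py change: the diff raises the CFG Scale slider's maximum from 15.0 to 30.0 while the default value stays 7.0. CFG Scale is the classifier-free guidance weight, which the samplers apply roughly as in the standard formula below (a sketch for context, not code from this commit):

import torch

def apply_cfg(eps_uncond: torch.Tensor, eps_cond: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # Classifier-free guidance: move the denoiser's prediction away from the
    # unconditional output and toward the prompt-conditioned one.
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)
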