Commit 4b0dc206 authored by AUTOMATIC's avatar AUTOMATIC

use modelloader for #4956

parent 2a649154
import contextlib
import os import os
import sys import sys
import traceback import traceback
...@@ -11,12 +10,9 @@ from torchvision import transforms ...@@ -11,12 +10,9 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
import modules.shared as shared import modules.shared as shared
from modules import devices, paths, lowvram from modules import devices, paths, lowvram, modelloader
blip_image_eval_size = 384 blip_image_eval_size = 384
blip_local_dir = os.path.join('models', 'Interrogator')
blip_local_file = os.path.join(blip_local_dir, 'model_base_caption_capfilt_large.pth')
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
clip_model_name = 'ViT-L/14' clip_model_name = 'ViT-L/14'
Category = namedtuple("Category", ["name", "topn", "items"]) Category = namedtuple("Category", ["name", "topn", "items"])
...@@ -49,16 +45,14 @@ class InterrogateModels: ...@@ -49,16 +45,14 @@ class InterrogateModels:
def load_blip_model(self): def load_blip_model(self):
import models.blip import models.blip
if not os.path.isfile(blip_local_file): files = modelloader.load_models(
if not os.path.isdir(blip_local_dir): model_path=os.path.join(paths.models_path, "BLIP"),
os.mkdir(blip_local_dir) model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',
ext_filter=[".pth"],
download_name='model_base_caption_capfilt_large.pth',
)
print("Downloading BLIP...") blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
from requests import get as reqget
open(blip_local_file, 'wb').write(reqget(blip_model_url, allow_redirects=True).content)
print("BLIP downloaded to", blip_local_file + '.')
blip_model = models.blip.blip_decoder(pretrained=blip_local_file, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
blip_model.eval() blip_model.eval()
return blip_model return blip_model
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment