Unverified Commit 2f90496b authored by Keavon Chambers's avatar Keavon Chambers Committed by GitHub

Merge branch 'master' into cors-regex

parents a258fd60 47a44c7e
// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes
let txt2img_gallery, img2img_gallery, modal = undefined;
onUiUpdate(function(){
if (!txt2img_gallery) {
txt2img_gallery = attachGalleryListeners("txt2img")
}
if (!img2img_gallery) {
img2img_gallery = attachGalleryListeners("img2img")
}
if (!modal) {
modal = gradioApp().getElementById('lightboxModal')
modalObserver.observe(modal, { attributes : true, attributeFilter : ['style'] });
}
});
let modalObserver = new MutationObserver(function(mutations) {
mutations.forEach(function(mutationRecord) {
let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img')
gradioApp().getElementById(selectedTab+"_generation_info_button").click()
});
});
function attachGalleryListeners(tab_name) {
gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
gallery?.addEventListener('keydown', (e) => {
if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
gradioApp().getElementById(tab_name+"_generation_info_button").click()
});
return gallery;
}
...@@ -105,22 +105,26 @@ def version_check(commit): ...@@ -105,22 +105,26 @@ def version_check(commit):
print("version check failed", e) print("version check failed", e)
def run_extension_installer(extension_dir):
path_installer = os.path.join(extension_dir, "install.py")
if not os.path.isfile(path_installer):
return
try:
env = os.environ.copy()
env['PYTHONPATH'] = os.path.abspath(".")
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
except Exception as e:
print(e, file=sys.stderr)
def run_extensions_installers(): def run_extensions_installers():
if not os.path.isdir(dir_extensions): if not os.path.isdir(dir_extensions):
return return
for dirname_extension in os.listdir(dir_extensions): for dirname_extension in os.listdir(dir_extensions):
path_installer = os.path.join(dir_extensions, dirname_extension, "install.py") run_extension_installer(os.path.join(dir_extensions, dirname_extension))
if not os.path.isfile(path_installer):
continue
try:
env = os.environ.copy()
env['PYTHONPATH'] = os.path.abspath(".")
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {dirname_extension}", custom_env=env))
except Exception as e:
print(e, file=sys.stderr)
def prepare_enviroment(): def prepare_enviroment():
...@@ -135,9 +139,9 @@ def prepare_enviroment(): ...@@ -135,9 +139,9 @@ def prepare_enviroment():
xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl') xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl')
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/CompVis/stable-diffusion.git") stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/CompVis/stable-diffusion.git")
taming_transformers_repo = os.environ.get('TAMING_REANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git") taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
codeformer_repo = os.environ.get('CODEFORMET_REPO', 'https://github.com/sczhou/CodeFormer.git') codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc") stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -170,14 +170,15 @@ class ProgressResponse(BaseModel): ...@@ -170,14 +170,15 @@ class ProgressResponse(BaseModel):
class InterrogateRequest(BaseModel): class InterrogateRequest(BaseModel):
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
model: str = Field(default="clip", title="Model", description="The interrogate model used.")
class InterrogateResponse(BaseModel): class InterrogateResponse(BaseModel):
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.") caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
fields = {} fields = {}
for key, value in opts.data.items(): for key, metadata in opts.data_labels.items():
metadata = opts.data_labels.get(key) value = opts.data.get(key)
optType = opts.typemap.get(type(value), type(value)) optType = opts.typemap.get(type(metadata.default), type(value))
if (metadata is not None): if (metadata is not None):
fields.update({key: (Optional[optType], Field( fields.update({key: (Optional[optType], Field(
......
...@@ -3,16 +3,27 @@ import contextlib ...@@ -3,16 +3,27 @@ import contextlib
import torch import torch
from modules import errors from modules import errors
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu") # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
# check `getattr` and try it for compatibility
def has_mps() -> bool:
if not getattr(torch, 'has_mps', False):
return False
try:
torch.zeros(1).to(torch.device("mps"))
return True
except Exception:
return False
def extract_device_id(args, name): def extract_device_id(args, name):
for x in range(len(args)): for x in range(len(args)):
if name in args[x]: return args[x+1] if name in args[x]:
return args[x + 1]
return None return None
def get_optimal_device(): def get_optimal_device():
if torch.cuda.is_available(): if torch.cuda.is_available():
from modules import shared from modules import shared
...@@ -25,7 +36,7 @@ def get_optimal_device(): ...@@ -25,7 +36,7 @@ def get_optimal_device():
else: else:
return torch.device("cuda") return torch.device("cuda")
if has_mps: if has_mps():
return torch.device("mps") return torch.device("mps")
return cpu return cpu
...@@ -45,10 +56,12 @@ def enable_tf32(): ...@@ -45,10 +56,12 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32") errors.run(enable_tf32, "Enabling TF32")
cpu = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
dtype = torch.float16 dtype = torch.float16
dtype_vae = torch.float16 dtype_vae = torch.float16
def randn(seed, shape): def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
if device.type == 'mps': if device.type == 'mps':
...@@ -82,6 +95,11 @@ def autocast(disable=False): ...@@ -82,6 +95,11 @@ def autocast(disable=False):
return torch.autocast("cuda") return torch.autocast("cuda")
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor def mps_contiguous(input_tensor, device):
def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device) return input_tensor.contiguous() if device.type == 'mps' else input_tensor
def mps_contiguous_to(input_tensor, device):
return mps_contiguous(input_tensor, device).to(device)
...@@ -6,7 +6,6 @@ import git ...@@ -6,7 +6,6 @@ import git
from modules import paths, shared from modules import paths, shared
extensions = [] extensions = []
extensions_dir = os.path.join(paths.script_path, "extensions") extensions_dir = os.path.join(paths.script_path, "extensions")
...@@ -66,9 +65,12 @@ class Extension: ...@@ -66,9 +65,12 @@ class Extension:
self.can_update = False self.can_update = False
self.status = "latest" self.status = "latest"
def pull(self): def fetch_and_reset_hard(self):
repo = git.Repo(self.path) repo = git.Repo(self.path)
repo.remotes.origin.pull() # Fix: `error: Your local changes to the following files would be overwritten by merge`,
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
repo.git.fetch('--all')
repo.git.reset('--hard', 'origin')
def list_extensions(): def list_extensions():
...@@ -84,3 +86,4 @@ def list_extensions(): ...@@ -84,3 +86,4 @@ def list_extensions():
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions) extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
extensions.append(extension) extensions.append(extension)
...@@ -73,6 +73,7 @@ def integrate_settings_paste_fields(component_dict): ...@@ -73,6 +73,7 @@ def integrate_settings_paste_fields(component_dict):
'sd_hypernetwork': 'Hypernet', 'sd_hypernetwork': 'Hypernet',
'sd_hypernetwork_strength': 'Hypernet strength', 'sd_hypernetwork_strength': 'Hypernet strength',
'CLIP_stop_at_last_layers': 'Clip skip', 'CLIP_stop_at_last_layers': 'Clip skip',
'inpainting_mask_weight': 'Conditional mask weight',
'sd_model_checkpoint': 'Model hash', 'sd_model_checkpoint': 'Model hash',
} }
settings_paste_fields = [ settings_paste_fields = [
......
...@@ -12,7 +12,7 @@ import torch ...@@ -12,7 +12,7 @@ import torch
import tqdm import tqdm
from einops import rearrange, repeat from einops import rearrange, repeat
from ldm.util import default from ldm.util import default
from modules import devices, processing, sd_models, shared from modules import devices, processing, sd_models, shared, sd_samplers
from modules.textual_inversion import textual_inversion from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum from torch import einsum
...@@ -535,7 +535,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log ...@@ -535,7 +535,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
p.prompt = preview_prompt p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt p.negative_prompt = preview_negative_prompt
p.steps = preview_steps p.steps = preview_steps
p.sampler_index = preview_sampler_index p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
p.cfg_scale = preview_cfg_scale p.cfg_scale = preview_cfg_scale
p.seed = preview_seed p.seed = preview_seed
p.width = preview_width p.width = preview_width
......
...@@ -303,7 +303,7 @@ class FilenameGenerator: ...@@ -303,7 +303,7 @@ class FilenameGenerator:
'width': lambda self: self.image.width, 'width': lambda self: self.image.width,
'height': lambda self: self.image.height, 'height': lambda self: self.image.height,
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False), 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False), 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash), 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'), 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>] 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
......
...@@ -6,7 +6,7 @@ import traceback ...@@ -6,7 +6,7 @@ import traceback
import numpy as np import numpy as np
from PIL import Image, ImageOps, ImageChops from PIL import Image, ImageOps, ImageChops
from modules import devices from modules import devices, sd_samplers
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state from modules.shared import opts, state
import modules.shared as shared import modules.shared as shared
...@@ -99,7 +99,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro ...@@ -99,7 +99,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
seed_resize_from_h=seed_resize_from_h, seed_resize_from_h=seed_resize_from_h,
seed_resize_from_w=seed_resize_from_w, seed_resize_from_w=seed_resize_from_w,
seed_enable_extras=seed_enable_extras, seed_enable_extras=seed_enable_extras,
sampler_index=sampler_index, sampler_index=sd_samplers.samplers_for_img2img[sampler_index].name,
batch_size=batch_size, batch_size=batch_size,
n_iter=n_iter, n_iter=n_iter,
steps=steps, steps=steps,
......
from pyngrok import ngrok, conf, exception from pyngrok import ngrok, conf, exception
def connect(token, port, region): def connect(token, port, region):
account = None
if token == None: if token == None:
token = 'None' token = 'None'
else:
if ':' in token:
# token = authtoken:username:password
account = token.split(':')[1] + ':' + token.split(':')[-1]
token = token.split(':')[0]
config = conf.PyngrokConfig( config = conf.PyngrokConfig(
auth_token=token, region=region auth_token=token, region=region
) )
try: try:
public_url = ngrok.connect(port, pyngrok_config=config).public_url if account == None:
public_url = ngrok.connect(port, pyngrok_config=config).public_url
else:
public_url = ngrok.connect(port, pyngrok_config=config, auth=account).public_url
except exception.PyngrokNgrokError: except exception.PyngrokNgrokError:
print(f'Invalid ngrok authtoken, ngrok connection aborted.\n' print(f'Invalid ngrok authtoken, ngrok connection aborted.\n'
f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken') f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment