Commit 3b0127e6 authored by Fampai's avatar Fampai

Merge branch 'master' of...

Merge branch 'master' of https://github.com/AUTOMATIC1111/stable-diffusion-webui into TI_optimizations
parents 006756f9 9b384dfb
...@@ -29,3 +29,5 @@ notification.mp3 ...@@ -29,3 +29,5 @@ notification.mp3
/textual_inversion /textual_inversion
.vscode .vscode
/extensions /extensions
/test/stdout.txt
/test/stderr.txt
...@@ -128,10 +128,12 @@ def prepare_enviroment(): ...@@ -128,10 +128,12 @@ def prepare_enviroment():
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
sys.argv += shlex.split(commandline_args) sys.argv += shlex.split(commandline_args)
test_argv = [x for x in sys.argv if x != '--tests']
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test') sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers') sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
sys.argv, update_check = extract_arg(sys.argv, '--update-check') sys.argv, update_check = extract_arg(sys.argv, '--update-check')
sys.argv, run_tests = extract_arg(sys.argv, '--tests')
xformers = '--xformers' in sys.argv xformers = '--xformers' in sys.argv
deepdanbooru = '--deepdanbooru' in sys.argv deepdanbooru = '--deepdanbooru' in sys.argv
ngrok = '--ngrok' in sys.argv ngrok = '--ngrok' in sys.argv
...@@ -194,6 +196,26 @@ def prepare_enviroment(): ...@@ -194,6 +196,26 @@ def prepare_enviroment():
print("Exiting because of --exit argument") print("Exiting because of --exit argument")
exit(0) exit(0)
if run_tests:
tests(test_argv)
exit(0)
def tests(argv):
if "--api" not in argv:
argv.append("--api")
print(f"Launching Web UI in another process for testing with arguments: {' '.join(argv[1:])}")
with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
proc = subprocess.Popen([sys.executable, *argv], stdout=stdout, stderr=stderr)
import test.server_poll
test.server_poll.run_tests()
print(f"Stopping Web UI process with id {proc.pid}")
proc.kill()
def start_webui(): def start_webui():
print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}") print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
......
This diff is collapsed.
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
import time
import uvicorn import uvicorn
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, Depends, HTTPException
import modules.shared as shared import modules.shared as shared
from modules import devices
from modules.api.models import * from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo from modules.extras import run_extras, run_pnginfo
def upscaler_to_index(name: str): def upscaler_to_index(name: str):
try: try:
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
except: except:
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
def setUpscalers(req: dict): def setUpscalers(req: dict):
reqDict = vars(req) reqDict = vars(req)
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1) reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
...@@ -23,6 +28,7 @@ def setUpscalers(req: dict): ...@@ -23,6 +28,7 @@ def setUpscalers(req: dict):
reqDict.pop('upscaler_2') reqDict.pop('upscaler_2')
return reqDict return reqDict
class Api: class Api:
def __init__(self, app, queue_lock): def __init__(self, app, queue_lock):
self.router = APIRouter() self.router = APIRouter()
...@@ -33,15 +39,16 @@ class Api: ...@@ -33,15 +39,16 @@ class Api:
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index) sampler_index = sampler_to_index(txt2imgreq.sampler_index)
if sampler_index is None: if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found") raise HTTPException(status_code=404, detail="Sampler not found")
populate = txt2imgreq.copy(update={ # Override __init__ params populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model, "sd_model": shared.sd_model,
"sampler_index": sampler_index[0], "sampler_index": sampler_index[0],
"do_not_save_samples": True, "do_not_save_samples": True,
"do_not_save_grid": True "do_not_save_grid": True
...@@ -49,34 +56,39 @@ class Api: ...@@ -49,34 +56,39 @@ class Api:
) )
p = StableDiffusionProcessingTxt2Img(**vars(populate)) p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param # Override object param
shared.state.begin()
with self.queue_lock: with self.queue_lock:
processed = process_images(p) processed = process_images(p)
shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images)) b64images = list(map(encode_pil_to_base64, processed.images))
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index) sampler_index = sampler_to_index(img2imgreq.sampler_index)
if sampler_index is None: if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found") raise HTTPException(status_code=404, detail="Sampler not found")
init_images = img2imgreq.init_images init_images = img2imgreq.init_images
if init_images is None: if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found") raise HTTPException(status_code=404, detail="Init image not found")
mask = img2imgreq.mask mask = img2imgreq.mask
if mask: if mask:
mask = decode_base64_to_image(mask) mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model, "sd_model": shared.sd_model,
"sampler_index": sampler_index[0], "sampler_index": sampler_index[0],
"do_not_save_samples": True, "do_not_save_samples": True,
"do_not_save_grid": True, "do_not_save_grid": True,
"mask": mask "mask": mask
} }
) )
...@@ -88,16 +100,20 @@ class Api: ...@@ -88,16 +100,20 @@ class Api:
imgs = [img] * p.batch_size imgs = [img] * p.batch_size
p.init_images = imgs p.init_images = imgs
# Override object param
shared.state.begin()
with self.queue_lock: with self.queue_lock:
processed = process_images(p) processed = process_images(p)
shared.state.end()
b64images = list(map(encode_pil_to_base64, processed.images)) b64images = list(map(encode_pil_to_base64, processed.images))
if (not img2imgreq.include_init_images): if (not img2imgreq.include_init_images):
img2imgreq.init_images = None img2imgreq.init_images = None
img2imgreq.mask = None img2imgreq.mask = None
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js()) return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
def extras_single_image_api(self, req: ExtrasSingleImageRequest): def extras_single_image_api(self, req: ExtrasSingleImageRequest):
...@@ -125,7 +141,7 @@ class Api: ...@@ -125,7 +141,7 @@ class Api:
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict) result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
def pnginfoapi(self, req: PNGInfoRequest): def pnginfoapi(self, req: PNGInfoRequest):
if(not req.image.strip()): if(not req.image.strip()):
return PNGInfoResponse(info="") return PNGInfoResponse(info="")
...@@ -134,6 +150,32 @@ class Api: ...@@ -134,6 +150,32 @@ class Api:
return PNGInfoResponse(info=result[1]) return PNGInfoResponse(info=result[1])
def progressapi(self, req: ProgressRequest = Depends()):
# copy from check_progress_call of ui.py
if shared.state.job_count == 0:
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
# avoid dividing zero
progress = 0.01
if shared.state.job_count > 0:
progress += shared.state.job_no / shared.state.job_count
if shared.state.sampling_steps > 0:
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
time_since_start = time.time() - shared.state.time_start
eta = (time_since_start/progress)
eta_relative = eta-time_since_start
progress = min(progress, 1)
current_image = None
if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image)
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
def launch(self, server_name, port): def launch(self, server_name, port):
self.app.include_router(self.router) self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port) uvicorn.run(self.app, host=server_name, port=port)
...@@ -52,17 +52,17 @@ class PydanticModelGenerator: ...@@ -52,17 +52,17 @@ class PydanticModelGenerator:
# field_type = str if not overrides.get(k) else overrides[k]["type"] # field_type = str if not overrides.get(k) else overrides[k]["type"]
# print(k, v.annotation, v.default) # print(k, v.annotation, v.default)
field_type = v.annotation field_type = v.annotation
return Optional[field_type] return Optional[field_type]
def merge_class_params(class_): def merge_class_params(class_):
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_))) all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
parameters = {} parameters = {}
for classes in all_classes: for classes in all_classes:
parameters = {**parameters, **inspect.signature(classes.__init__).parameters} parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
return parameters return parameters
self._model_name = model_name self._model_name = model_name
self._class_data = merge_class_params(class_instance) self._class_data = merge_class_params(class_instance)
self._model_def = [ self._model_def = [
...@@ -74,11 +74,11 @@ class PydanticModelGenerator: ...@@ -74,11 +74,11 @@ class PydanticModelGenerator:
) )
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
] ]
for fields in additional_fields: for fields in additional_fields:
self._model_def.append(ModelDef( self._model_def.append(ModelDef(
field=underscore(fields["key"]), field=underscore(fields["key"]),
field_alias=fields["key"], field_alias=fields["key"],
field_type=fields["type"], field_type=fields["type"],
field_value=fields["default"], field_value=fields["default"],
field_exclude=fields["exclude"] if "exclude" in fields else False)) field_exclude=fields["exclude"] if "exclude" in fields else False))
...@@ -95,15 +95,15 @@ class PydanticModelGenerator: ...@@ -95,15 +95,15 @@ class PydanticModelGenerator:
DynamicModel.__config__.allow_population_by_field_name = True DynamicModel.__config__.allow_population_by_field_name = True
DynamicModel.__config__.allow_mutation = True DynamicModel.__config__.allow_mutation = True
return DynamicModel return DynamicModel
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img", "StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img, StableDiffusionProcessingTxt2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}] [{"key": "sampler_index", "type": str, "default": "Euler"}]
).generate_model() ).generate_model()
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img", "StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img, StableDiffusionProcessingImg2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}] [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
).generate_model() ).generate_model()
...@@ -155,4 +155,13 @@ class PNGInfoRequest(BaseModel): ...@@ -155,4 +155,13 @@ class PNGInfoRequest(BaseModel):
image: str = Field(title="Image", description="The base64 encoded PNG image") image: str = Field(title="Image", description="The base64 encoded PNG image")
class PNGInfoResponse(BaseModel): class PNGInfoResponse(BaseModel):
info: str = Field(title="Image info", description="A string with all the info the image had") info: str = Field(title="Image info", description="A string with all the info the image had")
\ No newline at end of file
class ProgressRequest(BaseModel):
skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
class ProgressResponse(BaseModel):
progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
eta_relative: float = Field(title="ETA in secs")
state: dict = Field(title="State", description="The current state snapshot")
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
...@@ -66,6 +66,7 @@ def integrate_settings_paste_fields(component_dict): ...@@ -66,6 +66,7 @@ def integrate_settings_paste_fields(component_dict):
settings_map = { settings_map = {
'sd_hypernetwork': 'Hypernet', 'sd_hypernetwork': 'Hypernet',
'sd_hypernetwork_strength': 'Hypernet strength',
'CLIP_stop_at_last_layers': 'Clip skip', 'CLIP_stop_at_last_layers': 'Clip skip',
'sd_model_checkpoint': 'Model hash', 'sd_model_checkpoint': 'Model hash',
} }
......
...@@ -209,13 +209,16 @@ def list_hypernetworks(path): ...@@ -209,13 +209,16 @@ def list_hypernetworks(path):
res = {} res = {}
for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
name = os.path.splitext(os.path.basename(filename))[0] name = os.path.splitext(os.path.basename(filename))[0]
res[name] = filename # Prevent a hypothetical "None.pt" from being listed.
if name != "None":
res[name] = filename
return res return res
def load_hypernetwork(filename): def load_hypernetwork(filename):
path = shared.hypernetworks.get(filename, None) path = shared.hypernetworks.get(filename, None)
if path is not None: # Prevent any file named "None.pt" from being loaded.
if path is not None and filename != "None":
print(f"Loading hypernetwork {filename}") print(f"Loading hypernetwork {filename}")
try: try:
shared.loaded_hypernetwork = Hypernetwork() shared.loaded_hypernetwork = Hypernetwork()
...@@ -332,7 +335,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log ...@@ -332,7 +335,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
# images allows training previews to have infotext. Importing it at the top causes a circular import problem. # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images from modules import images
assert hypernetwork_name, 'hypernetwork not selected' save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
path = shared.hypernetworks.get(hypernetwork_name, None) path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork() shared.loaded_hypernetwork = Hypernetwork()
...@@ -358,39 +363,44 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log ...@@ -358,39 +363,44 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
else: else:
images_dir = None images_dir = None
hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint()
ititial_step = hypernetwork.step or 0
if ititial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
if unload: if unload:
shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu)
hypernetwork = shared.loaded_hypernetwork
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True
size = len(ds.indexes) size = len(ds.indexes)
loss_dict = defaultdict(lambda : deque(maxlen = 1024)) loss_dict = defaultdict(lambda : deque(maxlen = 1024))
losses = torch.zeros((size,)) losses = torch.zeros((size,))
previous_mean_losses = [0] previous_mean_losses = [0]
previous_mean_loss = 0 previous_mean_loss = 0
print("Mean loss of {} elements".format(size)) print("Mean loss of {} elements".format(size))
last_saved_file = "<none>" weights = hypernetwork.weights()
last_saved_image = "<none>" for weight in weights:
forced_filename = "<none>" weight.requires_grad = True
ititial_step = hypernetwork.step or 0
if ititial_step > steps:
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc... # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
steps_without_grad = 0 steps_without_grad = 0
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar: for i, entries in pbar:
hypernetwork.step = i + ititial_step hypernetwork.step = i + ititial_step
...@@ -443,9 +453,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log ...@@ -443,9 +453,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint. # Before saving, change name to match current checkpoint.
hypernetwork.name = f'{hypernetwork_name}-{steps_done}' hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
hypernetwork.save(last_saved_file) save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{previous_mean_loss:.7f}", "loss": f"{previous_mean_loss:.7f}",
...@@ -506,13 +516,23 @@ Last saved image: {html.escape(last_saved_image)}<br/> ...@@ -506,13 +516,23 @@ Last saved image: {html.escape(last_saved_image)}<br/>
""" """
report_statistics(loss_dict) report_statistics(loss_dict)
checkpoint = sd_models.select_checkpoint()
hypernetwork.sd_checkpoint = checkpoint.hash filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
hypernetwork.sd_checkpoint_name = checkpoint.model_name save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
# Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
hypernetwork.name = hypernetwork_name
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(filename)
return hypernetwork, filename return hypernetwork, filename
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
old_hypernetwork_name = hypernetwork.name
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
try:
hypernetwork.sd_checkpoint = checkpoint.hash
hypernetwork.sd_checkpoint_name = checkpoint.model_name
hypernetwork.name = hypernetwork_name
hypernetwork.save(filename)
except:
hypernetwork.sd_checkpoint = old_sd_checkpoint
hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
hypernetwork.name = old_hypernetwork_name
raise
...@@ -396,6 +396,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration ...@@ -396,6 +396,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch), "Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
...@@ -686,15 +687,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -686,15 +687,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
image_conditioning = self.txt2img_image_conditioning(x)
# GC now before running the next img2img to prevent running out of memory # GC now before running the next img2img to prevent running out of memory
x = None x = None
devices.torch_gc() devices.torch_gc()
image_conditioning = self.img2img_image_conditioning(
decoded_samples,
samples,
decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
)
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning) samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
return samples return samples
......
...@@ -144,9 +144,38 @@ class State: ...@@ -144,9 +144,38 @@ class State:
self.sampling_step = 0 self.sampling_step = 0
self.current_image_sampling_step = 0 self.current_image_sampling_step = 0
def get_job_timestamp(self): def dict(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp? obj = {
"skipped": self.skipped,
"interrupted": self.skipped,
"job": self.job,
"job_count": self.job_count,
"job_no": self.job_no,
"sampling_step": self.sampling_step,
"sampling_steps": self.sampling_steps,
}
return obj
def begin(self):
self.sampling_step = 0
self.job_count = -1
self.job_no = 0
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
self.current_latent = None
self.current_image = None
self.current_image_sampling_step = 0
self.skipped = False
self.interrupted = False
self.textinfo = None
devices.torch_gc()
def end(self):
self.job = ""
self.job_count = 0
devices.torch_gc()
state = State() state = State()
......
...@@ -42,6 +42,8 @@ class PersonalizedBase(Dataset): ...@@ -42,6 +42,8 @@ class PersonalizedBase(Dataset):
self.lines = lines self.lines = lines
assert data_root, 'dataset directory not specified' assert data_root, 'dataset directory not specified'
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
cond_model = shared.sd_model.cond_stage_model cond_model = shared.sd_model.cond_stage_model
......
...@@ -4,30 +4,37 @@ import tqdm ...@@ -4,30 +4,37 @@ import tqdm
class LearnScheduleIterator: class LearnScheduleIterator:
def __init__(self, learn_rate, max_steps, cur_step=0): def __init__(self, learn_rate, max_steps, cur_step=0):
""" """
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000 specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
""" """
pairs = learn_rate.split(',') pairs = learn_rate.split(',')
self.rates = [] self.rates = []
self.it = 0 self.it = 0
self.maxit = 0 self.maxit = 0
for i, pair in enumerate(pairs): try:
tmp = pair.split(':') for i, pair in enumerate(pairs):
if len(tmp) == 2: if not pair.strip():
step = int(tmp[1]) continue
if step > cur_step: tmp = pair.split(':')
self.rates.append((float(tmp[0]), min(step, max_steps))) if len(tmp) == 2:
self.maxit += 1 step = int(tmp[1])
if step > max_steps: if step > cur_step:
self.rates.append((float(tmp[0]), min(step, max_steps)))
self.maxit += 1
if step > max_steps:
return
elif step == -1:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return return
elif step == -1: else:
self.rates.append((float(tmp[0]), max_steps)) self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1 self.maxit += 1
return return
else: assert self.rates
self.rates.append((float(tmp[0]), max_steps)) except (ValueError, AssertionError):
self.maxit += 1 raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
return
def __iter__(self): def __iter__(self):
return self return self
......
...@@ -119,7 +119,7 @@ class EmbeddingDatabase: ...@@ -119,7 +119,7 @@ class EmbeddingDatabase:
vec = emb.detach().to(devices.device, dtype=torch.float32) vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name) embedding = Embedding(vec, name)
embedding.step = data.get('step', None) embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('hash', None) embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
self.register_embedding(embedding, shared.sd_model) self.register_embedding(embedding, shared.sd_model)
...@@ -204,9 +204,30 @@ def write_loss(log_directory, filename, step, epoch_len, values): ...@@ -204,9 +204,30 @@ def write_loss(log_directory, filename, step, epoch_len, values):
**values, **values,
}) })
def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected"
assert learn_rate, "Learning rate is empty or 0"
assert isinstance(batch_size, int), "Batch size must be integer"
assert batch_size > 0, "Batch size must be positive"
assert data_root, "Dataset directory is empty"
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
assert template_file, "Prompt template file is empty"
assert os.path.isfile(template_file), "Prompt template file doesn't exist"
assert steps, "Max steps is empty or 0"
assert isinstance(steps, int), "Max steps must be integer"
assert steps > 0 , "Max steps must be positive"
assert isinstance(save_model_every, int), "Save {name} must be integer"
assert save_model_every >= 0 , "Save {name} must be positive or 0"
assert isinstance(create_image_every, int), "Create image must be integer"
assert create_image_every >= 0 , "Create image must be positive or 0"
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected' save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
shared.state.textinfo = "Initializing textual inversion training..." shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps shared.state.job_count = steps
...@@ -233,19 +254,30 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -233,19 +254,30 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
os.makedirs(images_embeds_dir, exist_ok=True) os.makedirs(images_embeds_dir, exist_ok=True)
else: else:
images_embeds_dir = None images_embeds_dir = None
cond_model = shared.sd_model.cond_stage_model cond_model = shared.sd_model.cond_stage_model
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
checkpoint = sd_models.select_checkpoint()
ititial_step = embedding.step or 0
if ititial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
if unload: if unload:
shared.sd_model.first_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu)
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
losses = torch.zeros((32,)) losses = torch.zeros((32,))
...@@ -254,13 +286,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -254,13 +286,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
forced_filename = "<none>" forced_filename = "<none>"
embedding_yet_to_be_embedded = False embedding_yet_to_be_embedded = False
ititial_step = embedding.step or 0
if ititial_step > steps:
return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, entries in pbar: for i, entries in pbar:
embedding.step = i + ititial_step embedding.step = i + ititial_step
...@@ -293,9 +318,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -293,9 +318,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if embedding_dir is not None and steps_done % save_embedding_every == 0: if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint. # Before saving, change name to match current checkpoint.
embedding.name = f'{embedding_name}-{steps_done}' embedding_name_every = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
embedding.save(last_saved_file) save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
...@@ -382,14 +407,26 @@ Last saved image: {html.escape(last_saved_image)}<br/> ...@@ -382,14 +407,26 @@ Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
""" """
checkpoint = sd_models.select_checkpoint() filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
embedding.sd_checkpoint = checkpoint.hash
embedding.sd_checkpoint_name = checkpoint.model_name
embedding.cached_checksum = None
# Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
embedding.name = embedding_name
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt')
embedding.save(filename)
return embedding, filename return embedding, filename
def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
try:
embedding.sd_checkpoint = checkpoint.hash
embedding.sd_checkpoint_name = checkpoint.model_name
if remove_cached_checksum:
embedding.cached_checksum = None
embedding.name = embedding_name
embedding.save(filename)
except:
embedding.sd_checkpoint = old_sd_checkpoint
embedding.sd_checkpoint_name = old_sd_checkpoint_name
embedding.name = old_embedding_name
embedding.cached_checksum = old_cached_checksum
raise
import unittest
class TestExtrasWorking(unittest.TestCase):
def setUp(self):
self.url_img2img = "http://localhost:7860/sdapi/v1/extra-single-image"
self.simple_extras = {
"resize_mode": 0,
"show_extras_results": True,
"gfpgan_visibility": 0,
"codeformer_visibility": 0,
"codeformer_weight": 0,
"upscaling_resize": 2,
"upscaling_resize_w": 512,
"upscaling_resize_h": 512,
"upscaling_crop": True,
"upscaler_1": "None",
"upscaler_2": "None",
"extras_upscaler_2_visibility": 0,
"image": ""
}
class TestExtrasCorrectness(unittest.TestCase):
pass
if __name__ == "__main__":
unittest.main()
import unittest
import requests
from gradio.processing_utils import encode_pil_to_base64
from PIL import Image
class TestImg2ImgWorking(unittest.TestCase):
def setUp(self):
self.url_img2img = "http://localhost:7860/sdapi/v1/img2img"
self.simple_img2img = {
"init_images": [encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png"))],
"resize_mode": 0,
"denoising_strength": 0.75,
"mask": None,
"mask_blur": 4,
"inpainting_fill": 0,
"inpaint_full_res": False,
"inpaint_full_res_padding": 0,
"inpainting_mask_invert": 0,
"prompt": "example prompt",
"styles": [],
"seed": -1,
"subseed": -1,
"subseed_strength": 0,
"seed_resize_from_h": -1,
"seed_resize_from_w": -1,
"batch_size": 1,
"n_iter": 1,
"steps": 3,
"cfg_scale": 7,
"width": 64,
"height": 64,
"restore_faces": False,
"tiling": False,
"negative_prompt": "",
"eta": 0,
"s_churn": 0,
"s_tmax": 0,
"s_tmin": 0,
"s_noise": 1,
"override_settings": {},
"sampler_index": "Euler a",
"include_init_images": False
}
def test_img2img_simple_performed(self):
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
def test_inpainting_masked_performed(self):
self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
class TestImg2ImgCorrectness(unittest.TestCase):
pass
if __name__ == "__main__":
unittest.main()
import unittest
import requests
import time
def run_tests():
timeout_threshold = 240
start_time = time.time()
while time.time()-start_time < timeout_threshold:
try:
requests.head("http://localhost:7860/")
break
except requests.exceptions.ConnectionError:
pass
if time.time()-start_time < timeout_threshold:
suite = unittest.TestLoader().discover('', pattern='*_test.py')
result = unittest.TextTestRunner(verbosity=2).run(suite)
else:
print("Launch unsuccessful")
import unittest
import requests
class TestTxt2ImgWorking(unittest.TestCase):
def setUp(self):
self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img"
self.simple_txt2img = {
"enable_hr": False,
"denoising_strength": 0,
"firstphase_width": 0,
"firstphase_height": 0,
"prompt": "example prompt",
"styles": [],
"seed": -1,
"subseed": -1,
"subseed_strength": 0,
"seed_resize_from_h": -1,
"seed_resize_from_w": -1,
"batch_size": 1,
"n_iter": 1,
"steps": 3,
"cfg_scale": 7,
"width": 64,
"height": 64,
"restore_faces": False,
"tiling": False,
"negative_prompt": "",
"eta": 0,
"s_churn": 0,
"s_tmax": 0,
"s_tmin": 0,
"s_noise": 1,
"sampler_index": "Euler a"
}
def test_txt2img_simple_performed(self):
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
def test_txt2img_with_negative_prompt_performed(self):
self.simple_txt2img["negative_prompt"] = "example negative prompt"
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
def test_txt2img_not_square_image_performed(self):
self.simple_txt2img["height"] = 128
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
def test_txt2img_with_hrfix_performed(self):
self.simple_txt2img["enable_hr"] = True
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
def test_txt2img_with_restore_faces_performed(self):
self.simple_txt2img["restore_faces"] = True
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
def test_txt2img_with_tiling_faces_performed(self):
self.simple_txt2img["tiling"] = True
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
def test_txt2img_with_vanilla_sampler_performed(self):
self.simple_txt2img["sampler_index"] = "PLMS"
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
def test_txt2img_multiple_batches_performed(self):
self.simple_txt2img["n_iter"] = 2
self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
class TestTxt2ImgCorrectness(unittest.TestCase):
pass
if __name__ == "__main__":
unittest.main()
...@@ -46,26 +46,13 @@ def wrap_queued_call(func): ...@@ -46,26 +46,13 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None): def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs): def f(*args, **kwargs):
devices.torch_gc()
shared.state.begin()
shared.state.sampling_step = 0
shared.state.job_count = -1
shared.state.job_no = 0
shared.state.job_timestamp = shared.state.get_job_timestamp()
shared.state.current_latent = None
shared.state.current_image = None
shared.state.current_image_sampling_step = 0
shared.state.skipped = False
shared.state.interrupted = False
shared.state.textinfo = None
with queue_lock: with queue_lock:
res = func(*args, **kwargs) res = func(*args, **kwargs)
shared.state.job = "" shared.state.end()
shared.state.job_count = 0
devices.torch_gc()
return res return res
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment