Commit c7e50425 authored by AUTOMATIC's avatar AUTOMATIC

add progress bar to modelmerger

parent 7cfc6450
...@@ -172,6 +172,17 @@ function submit_img2img(){ ...@@ -172,6 +172,17 @@ function submit_img2img(){
return res return res
} }
function modelmerger(){
var id = randomId()
requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
gradioApp().getElementById('modelmerger_result').innerHTML = ''
var res = create_submit_args(arguments)
res[0] = id
return res
}
function ask_for_style_name(_, prompt_text, negative_prompt_text) { function ask_for_style_name(_, prompt_text, negative_prompt_text) {
name_ = prompt('Style name:') name_ = prompt('Style name:')
......
...@@ -274,14 +274,15 @@ def create_config(ckpt_result, config_source, a, b, c): ...@@ -274,14 +274,15 @@ def create_config(ckpt_result, config_source, a, b, c):
shutil.copyfile(cfg, checkpoint_filename) shutil.copyfile(cfg, checkpoint_filename)
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source): def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
shared.state.begin() shared.state.begin()
shared.state.job = 'model-merge' shared.state.job = 'model-merge'
shared.state.job_count = 1
def fail(message): def fail(message):
shared.state.textinfo = message shared.state.textinfo = message
shared.state.end() shared.state.end()
return [message, *[gr.update() for _ in range(4)]] return [*[gr.update() for _ in range(4)], message]
def weighted_sum(theta0, theta1, alpha): def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1) return ((1 - alpha) * theta0) + (alpha * theta1)
...@@ -320,9 +321,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam ...@@ -320,9 +321,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
if theta_func1: if theta_func1:
shared.state.job_count += 1
print(f"Loading {tertiary_model_info.filename}...") print(f"Loading {tertiary_model_info.filename}...")
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
shared.state.sampling_steps = len(theta_1.keys())
for key in tqdm.tqdm(theta_1.keys()): for key in tqdm.tqdm(theta_1.keys()):
if 'model' in key: if 'model' in key:
if key in theta_2: if key in theta_2:
...@@ -330,8 +334,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam ...@@ -330,8 +334,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
theta_1[key] = theta_func1(theta_1[key], t2) theta_1[key] = theta_func1(theta_1[key], t2)
else: else:
theta_1[key] = torch.zeros_like(theta_1[key]) theta_1[key] = torch.zeros_like(theta_1[key])
shared.state.sampling_step += 1
del theta_2 del theta_2
shared.state.nextjob()
shared.state.textinfo = f"Loading {primary_model_info.filename}..." shared.state.textinfo = f"Loading {primary_model_info.filename}..."
print(f"Loading {primary_model_info.filename}...") print(f"Loading {primary_model_info.filename}...")
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
...@@ -340,6 +348,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam ...@@ -340,6 +348,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
shared.state.sampling_steps = len(theta_0.keys())
for key in tqdm.tqdm(theta_0.keys()): for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1: if 'model' in key and key in theta_1:
...@@ -367,6 +376,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam ...@@ -367,6 +376,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
if save_as_half: if save_as_half:
theta_0[key] = theta_0[key].half() theta_0[key] = theta_0[key].half()
shared.state.sampling_step += 1
# I believe this part should be discarded, but I'll leave it for now until I am sure # I believe this part should be discarded, but I'll leave it for now until I am sure
for key in theta_1.keys(): for key in theta_1.keys():
if 'model' in key and key not in theta_0: if 'model' in key and key not in theta_0:
...@@ -393,6 +404,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam ...@@ -393,6 +404,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
output_modelname = os.path.join(ckpt_dir, filename) output_modelname = os.path.join(ckpt_dir, filename)
shared.state.nextjob()
shared.state.textinfo = f"Saving to {output_modelname}..." shared.state.textinfo = f"Saving to {output_modelname}..."
print(f"Saving to {output_modelname}...") print(f"Saving to {output_modelname}...")
...@@ -410,4 +422,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam ...@@ -410,4 +422,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
shared.state.textinfo = "Checkpoint saved to " + output_modelname shared.state.textinfo = "Checkpoint saved to " + output_modelname
shared.state.end() shared.state.end()
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
...@@ -72,7 +72,7 @@ def progressapi(req: ProgressRequest): ...@@ -72,7 +72,7 @@ def progressapi(req: ProgressRequest):
if job_count > 0: if job_count > 0:
progress += job_no / job_count progress += job_no / job_count
if sampling_steps > 0: if sampling_steps > 0 and job_count > 0:
progress += 1 / job_count * sampling_step / sampling_steps progress += 1 / job_count * sampling_step / sampling_steps
progress = min(progress, 1) progress = min(progress, 1)
......
...@@ -1208,8 +1208,9 @@ def create_ui(): ...@@ -1208,8 +1208,9 @@ def create_ui():
with gr.Row(): with gr.Row():
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
with gr.Column(variant='panel'): with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) with gr.Group(elem_id="modelmerger_results_panel"):
modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
with gr.Blocks(analytics_enabled=False) as train_interface: with gr.Blocks(analytics_enabled=False) as train_interface:
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
...@@ -1753,12 +1754,14 @@ def create_ui(): ...@@ -1753,12 +1754,14 @@ def create_ui():
print("Error loading/saving model file:", file=sys.stderr) print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
modules.sd_models.list_models() # to remove the potentially missing models from the list modules.sd_models.list_models() # to remove the potentially missing models from the list
return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)] return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results return results
modelmerger_merge.click( modelmerger_merge.click(
fn=modelmerger, fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
_js='modelmerger',
inputs=[ inputs=[
dummy_component,
primary_model_name, primary_model_name,
secondary_model_name, secondary_model_name,
tertiary_model_name, tertiary_model_name,
...@@ -1770,11 +1773,11 @@ def create_ui(): ...@@ -1770,11 +1773,11 @@ def create_ui():
config_source, config_source,
], ],
outputs=[ outputs=[
submit_result,
primary_model_name, primary_model_name,
secondary_model_name, secondary_model_name,
tertiary_model_name, tertiary_model_name,
component_dict['sd_model_checkpoint'], component_dict['sd_model_checkpoint'],
modelmerger_result,
] ]
) )
......
...@@ -737,6 +737,11 @@ footer { ...@@ -737,6 +737,11 @@ footer {
line-height: 2.4em; line-height: 2.4em;
} }
#modelmerger_results_container{
margin-top: 1em;
overflow: visible;
}
/* The following handles localization for right-to-left (RTL) languages like Arabic. /* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js. The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running If you change anything above, you need to make sure it is RTL compliant by just running
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment