stable-diffusion-webui

Commit c7e50425, authored Jan 19, 2023 by AUTOMATIC
add progress bar to modelmerger
parent 7cfc6450
Showing 5 changed files with 40 additions and 9 deletions
javascript/ui.js      +11  -0
modules/extras.py     +15  -3
modules/progress.py    +1  -1
modules/ui.py          +8  -5
style.css              +5  -0
javascript/ui.js
@@ -172,6 +172,17 @@ function submit_img2img(){
     return res
 }
 
+function modelmerger(){
+    var id = randomId()
+    requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
+
+    gradioApp().getElementById('modelmerger_result').innerHTML = ''
+
+    var res = create_submit_args(arguments)
+    res[0] = id
+    return res
+}
+
 function ask_for_style_name(_, prompt_text, negative_prompt_text) {
     name_ = prompt('Style name:')
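The new modelmerger() helper follows the same submit pattern as the other tabs: generate a random task id, start polling the results panel for that id, and hand the id to the backend as the first submitted argument. A minimal, self-contained Python sketch of that convention (random_id and submit are illustrative names, not from the repository):

import uuid

def random_id() -> str:
    # plays the role of randomId() in the web UI's JavaScript
    return f"task({uuid.uuid4().hex[:12]})"

def submit(args: list) -> list:
    # plays the role of modelmerger() in javascript/ui.js: res[0] = id
    task_id = random_id()
    return [task_id, *args[1:]]

# the backend receives the id as its first parameter (id_task below)
print(submit(["<dummy>", "modelA.ckpt", "modelB.ckpt"]))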
modules/extras.py
@@ -274,14 +274,15 @@ def create_config(ckpt_result, config_source, a, b, c):
     shutil.copyfile(cfg, checkpoint_filename)
 
 
-def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
     shared.state.begin()
     shared.state.job = 'model-merge'
+    shared.state.job_count = 1
 
     def fail(message):
         shared.state.textinfo = message
         shared.state.end()
-        return [message, *[gr.update() for _ in range(4)]]
+        return [*[gr.update() for _ in range(4)], message]
 
     def weighted_sum(theta0, theta1, alpha):
         return ((1 - alpha) * theta0) + (alpha * theta1)
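The progress wiring leans on the shared.state job counters. A rough, self-contained sketch of that bookkeeping, simplified for illustration and not the repository's actual State class:

class State:
    # simplified stand-in for modules.shared.state
    def begin(self):
        self.job = ""
        self.job_no = 0
        self.job_count = 0
        self.sampling_step = 0
        self.sampling_steps = 0
        self.textinfo = None

    def nextjob(self):
        # one unit of work finished; reset the inner step counter
        self.job_no += 1
        self.sampling_step = 0

    def end(self):
        self.job = ""
        self.job_count = 0

state = State()
state.begin()
state.job = 'model-merge'
state.job_count = 1   # the merge pass itself; bumped later if a tertiary model is loaded

With job_count known up front, the progress endpoint can turn job_no and the per-key step counters into a single fraction (see modules/progress.py below).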
@@ -320,9 +321,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
     theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
 
     if theta_func1:
+        shared.state.job_count += 1
+
         print(f"Loading {tertiary_model_info.filename}...")
         theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
 
+        shared.state.sampling_steps = len(theta_1.keys())
         for key in tqdm.tqdm(theta_1.keys()):
             if 'model' in key:
                 if key in theta_2:
@@ -330,8 +334,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
                     theta_1[key] = theta_func1(theta_1[key], t2)
                 else:
                     theta_1[key] = torch.zeros_like(theta_1[key])
+
+            shared.state.sampling_step += 1
         del theta_2
 
+        shared.state.nextjob()
+
     shared.state.textinfo = f"Loading {primary_model_info.filename}..."
     print(f"Loading {primary_model_info.filename}...")
     theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
@@ -340,6 +348,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
     chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
 
+    shared.state.sampling_steps = len(theta_0.keys())
     for key in tqdm.tqdm(theta_0.keys()):
         if 'model' in key and key in theta_1:
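Both merge loops use the same pattern: the number of state_dict keys becomes the step total, and each processed key advances the step counter, so the web UI bar moves in lockstep with the console tqdm bar. An illustrative, self-contained version of that idea (the dictionary below is made up):

state_dict = {"model.diffusion_model.a": 1, "model.diffusion_model.b": 2,
              "first_stage_model.c": 3, "cond_stage_model.d": 4}

sampling_steps = len(state_dict.keys())        # step total, set before the loop
for sampling_step, key in enumerate(state_dict.keys(), start=1):
    # ... merge the tensors for this key ...
    print(f"{key}: {sampling_step}/{sampling_steps} = {sampling_step / sampling_steps:.0%}")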
@@ -367,6 +376,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
         if save_as_half:
             theta_0[key] = theta_0[key].half()
 
+        shared.state.sampling_step += 1
+
     # I believe this part should be discarded, but I'll leave it for now until I am sure
     for key in theta_1.keys():
         if 'model' in key and key not in theta_0:
@@ -393,6 +404,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
     output_modelname = os.path.join(ckpt_dir, filename)
 
+    shared.state.nextjob()
     shared.state.textinfo = f"Saving to {output_modelname}..."
     print(f"Saving to {output_modelname}...")
@@ -410,4 +422,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
     shared.state.textinfo = "Checkpoint saved to " + output_modelname
     shared.state.end()
 
-    return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
+    return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
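The status message moves from the first to the last position of the returned list because the Merge button's outputs are reordered in modules/ui.py: the four checkpoint dropdowns come first and the new modelmerger_result HTML panel comes last. A small sketch of the shape of that return value, with a placeholder string standing in for gr.Dropdown.update():

def merge_result(message: str, dropdown_update: str = "<Dropdown.update>") -> list:
    # old shape: [message, update, update, update, update]
    # new shape: [update, update, update, update, message]
    return [*[dropdown_update for _ in range(4)], message]

print(merge_result("Checkpoint saved to example.ckpt"))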
modules/progress.py
@@ -72,7 +72,7 @@ def progressapi(req: ProgressRequest):
     if job_count > 0:
         progress += job_no / job_count
-    if sampling_steps > 0:
+    if sampling_steps > 0 and job_count > 0:
         progress += 1 / job_count * sampling_step / sampling_steps
 
     progress = min(progress, 1)
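The changed condition guards the division by job_count. A worked example of the formula using numbers a merge could plausibly produce (job_count = 2 when a tertiary model adds a second pass, with the key loop halfway through):

job_count = 2          # primary merge pass plus a tertiary interpolation pass
job_no = 1             # the first pass already finished (nextjob() was called)
sampling_step = 500    # keys processed so far in the current loop
sampling_steps = 1000  # total keys in the current loop

progress = 0
if job_count > 0:
    progress += job_no / job_count                              # 0.5
if sampling_steps > 0 and job_count > 0:                        # the added guard avoids 1 / 0
    progress += 1 / job_count * sampling_step / sampling_steps  # 0.25

print(min(progress, 1))  # 0.75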
modules/ui.py
@@ -1208,8 +1208,9 @@ def create_ui():
                 with gr.Row():
                     modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
 
-            with gr.Column(variant='panel'):
-                submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
+            with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
+                with gr.Group(elem_id="modelmerger_results_panel"):
+                    modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
 
     with gr.Blocks(analytics_enabled=False) as train_interface:
         with gr.Row().style(equal_height=False):
@@ -1753,12 +1754,14 @@ def create_ui():
                 print("Error loading/saving model file:", file=sys.stderr)
                 print(traceback.format_exc(), file=sys.stderr)
                 modules.sd_models.list_models()  # to remove the potentially missing models from the list
-                return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)]
+                return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
             return results
 
         modelmerger_merge.click(
-            fn=modelmerger,
+            fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
             _js='modelmerger',
             inputs=[
+                dummy_component,
                 primary_model_name,
                 secondary_model_name,
                 tertiary_model_name,
@@ -1770,11 +1773,11 @@ def create_ui():
                 config_source,
             ],
             outputs=[
-                submit_result,
                 primary_model_name,
                 secondary_model_name,
                 tertiary_model_name,
                 component_dict['sd_model_checkpoint'],
+                modelmerger_result,
            ]
         )
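Wrapping modelmerger in wrap_gradio_gpu_call is what ties the Merge button to the progress machinery: the wrapper runs the merge as a tracked task and, on failure, uses extra_outputs to fill the four dropdown slots so that only the final slot carries the error message. A much-simplified sketch of that contract (this is not the repository's wrapper, only an illustration):

def wrap_call(func, extra_outputs):
    def wrapped(*args):
        try:
            return func(*args)
        except Exception as e:
            # pad the non-message outputs, put the error text in the last slot
            return [*extra_outputs(), f"Error merging checkpoints: {e}"]
    return wrapped

def merge(id_task, *model_names):
    raise RuntimeError("out of memory")   # pretend the merge failed

handler = wrap_call(merge, extra_outputs=lambda: ["<no update>"] * 4)
print(handler("task(123)", "modelA.ckpt", "modelB.ckpt"))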
style.css
@@ -737,6 +737,11 @@ footer {
     line-height: 2.4em;
 }
 
+#modelmerger_results_container{
+    margin-top: 1em;
+    overflow: visible;
+}
+
 /* The following handles localization for right-to-left (RTL) languages like Arabic.
 The rtl media type will only be activated by the logic in javascript/localization.js.
 If you change anything above, you need to make sure it is RTL compliant by just running