Administrator / stable-diffusion-webui / Commits / 8662b5e5

Commit 8662b5e5, authored Nov 19, 2022 by Muhammad Rizqi Nur
Merge branch 'a1111' into vae-fix-none

Parents: 45dca056, ff35ae9a
Showing 15 changed files with 66 additions and 72 deletions
modules/api/api.py                               +9  -17
modules/extensions.py                            +5  -2
modules/hypernetworks/hypernetwork.py            +2  -2
modules/images.py                                +1  -1
modules/img2img.py                               +2  -2
modules/processing.py                            +13 -17
modules/sd_models.py                             +2  -8
modules/sd_samplers.py                           +10 -3
modules/sd_vae.py                                +7  -6
modules/textual_inversion/textual_inversion.py   +2  -2
modules/txt2img.py                               +2  -1
modules/ui.py                                    +1  -1
modules/ui_extensions.py                         +2  -2
scripts/img2imgalt.py                            +2  -2
scripts/xy_grid.py                               +6  -6
modules/api/api.py

@@ -6,9 +6,9 @@ from threading import Lock
 from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
 from fastapi import APIRouter, Depends, FastAPI, HTTPException
 import modules.shared as shared
+from modules import sd_samplers
 from modules.api.models import *
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
-from modules.sd_samplers import all_samplers
 from modules.extras import run_extras, run_pnginfo
 from PIL import PngImagePlugin
 from modules.sd_models import checkpoints_list

@@ -25,8 +25,12 @@ def upscaler_to_index(name: str):
         raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")

-sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
+
+def validate_sampler_name(name):
+    config = sd_samplers.all_samplers_map.get(name, None)
+    if config is None:
+        raise HTTPException(status_code=404, detail="Sampler not found")
+
+    return name

 def setUpscalers(req: dict):
     reqDict = vars(req)

@@ -82,14 +86,9 @@ class Api:
         self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])

     def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
-        sampler_index = sampler_to_index(txt2imgreq.sampler_index)
-
-        if sampler_index is None:
-            raise HTTPException(status_code=404, detail="Sampler not found")
-
         populate = txt2imgreq.copy(update={ # Override __init__ params
             "sd_model": shared.sd_model,
-            "sampler_index": sampler_index[0],
+            "sampler_name": validate_sampler_name(txt2imgreq.sampler_index),
             "do_not_save_samples": True,
             "do_not_save_grid": True
             }

@@ -109,12 +108,6 @@ class Api:
         return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())

     def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
-        sampler_index = sampler_to_index(img2imgreq.sampler_index)
-
-        if sampler_index is None:
-            raise HTTPException(status_code=404, detail="Sampler not found")
-
         init_images = img2imgreq.init_images
         if init_images is None:
             raise HTTPException(status_code=404, detail="Init image not found")

@@ -123,10 +116,9 @@ class Api:
         if mask:
             mask = decode_base64_to_image(mask)

         populate = img2imgreq.copy(update={ # Override __init__ params
             "sd_model": shared.sd_model,
-            "sampler_index": sampler_index[0],
+            "sampler_name": validate_sampler_name(img2imgreq.sampler_index),
             "do_not_save_samples": True,
             "do_not_save_grid": True,
             "mask": mask

@@ -272,7 +264,7 @@ class Api:
         return vars(shared.cmd_opts)

     def get_samplers(self):
-        return [{"name": sampler[0], "aliases": sampler[2], "options": sampler[3]} for sampler in all_samplers]
+        return [{"name": sampler[0], "aliases": sampler[2], "options": sampler[3]} for sampler in sd_samplers.all_samplers]

     def get_upscalers(self):
         upscalers = []
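API-side summary of this change: the old sampler_to_index lambda scanned all_samplers for a case-insensitive name match and returned an (index, config) pair, while the new validate_sampler_name does a direct dictionary lookup and returns the name itself or raises a 404. The request field is still called sampler_index, but its value is now passed through validate_sampler_name, i.e. treated as a sampler name. A minimal standalone sketch of the new contract, using a stand-in sampler table and a plain exception instead of FastAPI's HTTPException:

# Sketch of the validate_sampler_name() behaviour shown above; the sampler table and
# error type are simplified stand-ins, not the project's real objects.
all_samplers_map = {"Euler": "euler-config", "DDIM": "ddim-config", "PLMS": "plms-config"}

def validate_sampler_name(name):
    config = all_samplers_map.get(name, None)
    if config is None:
        raise ValueError("Sampler not found")  # the real code raises HTTPException(status_code=404)
    return name

print(validate_sampler_name("DDIM"))        # -> "DDIM"
# validate_sampler_name("NoSuchSampler")    # -> raises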
modules/extensions.py

@@ -65,9 +65,12 @@ class Extension:
         self.can_update = False
         self.status = "latest"

-    def pull(self):
+    def fetch_and_reset_hard(self):
         repo = git.Repo(self.path)
-        repo.remotes.origin.pull()
+        # Fix: `error: Your local changes to the following files would be overwritten by merge`,
+        # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
+        repo.git.fetch('--all')
+        repo.git.reset('--hard', 'origin')


 def list_extensions():
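As the new comment explains, a plain pull could fail with "Your local changes to the following files would be overwritten by merge" when only permission bits changed locally (755 instead of 644 under WSL2 Docker); fetching and hard-resetting discards such local changes instead of trying to merge over them. A standalone sketch of the same GitPython calls, with a hypothetical path:

import git  # GitPython

def fetch_and_reset_hard(path: str) -> None:
    # Discard local modifications and move the checkout to the remote's state,
    # mirroring Extension.fetch_and_reset_hard() above.
    repo = git.Repo(path)
    repo.git.fetch('--all')
    repo.git.reset('--hard', 'origin')

# fetch_and_reset_hard("extensions/some-extension")  # hypothetical extension checkout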
modules/hypernetworks/hypernetwork.py

@@ -12,7 +12,7 @@ import torch
 import tqdm
 from einops import rearrange, repeat
 from ldm.util import default
-from modules import devices, processing, sd_models, shared
+from modules import devices, processing, sd_models, shared, sd_samplers
 from modules.textual_inversion import textual_inversion
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 from torch import einsum

@@ -535,7 +535,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
                 p.prompt = preview_prompt
                 p.negative_prompt = preview_negative_prompt
                 p.steps = preview_steps
-                p.sampler_index = preview_sampler_index
+                p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
                 p.cfg_scale = preview_cfg_scale
                 p.seed = preview_seed
                 p.width = preview_width
modules/images.py

@@ -303,7 +303,7 @@ class FilenameGenerator:
         'width': lambda self: self.image.width,
         'height': lambda self: self.image.height,
         'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
-        'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
+        'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
         'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
         'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
         'datetime': lambda self, *args: self.datetime(*args),  # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
modules/img2img.py

@@ -6,7 +6,7 @@ import traceback
 import numpy as np
 from PIL import Image, ImageOps, ImageChops
-from modules import devices
+from modules import devices, sd_samplers
 from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, state
 import modules.shared as shared

@@ -99,7 +99,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
         seed_resize_from_h=seed_resize_from_h,
         seed_resize_from_w=seed_resize_from_w,
         seed_enable_extras=seed_enable_extras,
-        sampler_index=sampler_index,
+        sampler_index=sd_samplers.samplers_for_img2img[sampler_index].name,
         batch_size=batch_size,
         n_iter=n_iter,
         steps=steps,
modules/processing.py

@@ -2,6 +2,7 @@ import json
 import math
 import os
 import sys
+import warnings
 import torch
 import numpy as np

@@ -66,19 +67,15 @@ def apply_overlay(image, paste_loc, index, overlays):
     return image


-def get_correct_sampler(p):
-    if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
-        return sd_samplers.samplers
-    elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
-        return sd_samplers.samplers_for_img2img
-    elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
-        return sd_samplers.samplers
-
 class StableDiffusionProcessing():
     """
     The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
     """
-    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None):
+    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, sampler_index: int = None):
+        if sampler_index is not None:
+            warnings.warn("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name")
+
         self.sd_model = sd_model
         self.outpath_samples: str = outpath_samples
         self.outpath_grids: str = outpath_grids

@@ -91,7 +88,7 @@ class StableDiffusionProcessing():
         self.subseed_strength: float = subseed_strength
         self.seed_resize_from_h: int = seed_resize_from_h
         self.seed_resize_from_w: int = seed_resize_from_w
-        self.sampler_index: int = sampler_index
+        self.sampler_name: str = sampler_name
         self.batch_size: int = batch_size
         self.n_iter: int = n_iter
         self.steps: int = steps

@@ -210,8 +207,7 @@ class Processed:
         self.info = info
         self.width = p.width
         self.height = p.height
-        self.sampler_index = p.sampler_index
-        self.sampler = sd_samplers.samplers[p.sampler_index].name
+        self.sampler_name = p.sampler_name
         self.cfg_scale = p.cfg_scale
         self.steps = p.steps
         self.batch_size = p.batch_size

@@ -256,8 +252,7 @@ class Processed:
            "subseed_strength": self.subseed_strength,
            "width": self.width,
            "height": self.height,
-           "sampler_index": self.sampler_index,
-           "sampler": self.sampler,
+           "sampler_name": self.sampler_name,
            "cfg_scale": self.cfg_scale,
            "steps": self.steps,
            "batch_size": self.batch_size,

@@ -384,7 +379,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
     generation_params = {
         "Steps": p.steps,
-        "Sampler": get_correct_sampler(p)[p.sampler_index].name,
+        "Sampler": p.sampler_name,
         "CFG scale": p.cfg_scale,
         "Seed": all_seeds[index],
         "Face restoration": (opts.face_restoration_model if p.restore_faces else None),

@@ -399,6 +394,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
         "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
         "Denoising strength": getattr(p, 'denoising_strength', None),
+        "Inpainting strength": (None if getattr(p, 'denoising_strength', None) is None else getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)),
         "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
         "Clip skip": None if clip_skip <= 1 else clip_skip,
         "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,

@@ -645,7 +641,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f

     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
-        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
+        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

         if not self.enable_hr:
             x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

@@ -706,7 +702,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         shared.state.nextjob()

-        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
+        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

         noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

@@ -743,7 +739,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
         self.image_conditioning = None

     def init(self, all_prompts, all_seeds, all_subseeds):
-        self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
+        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
         crop_region = None

         if self.image_mask is not None:
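Callers now construct processing objects with sampler_name; the trailing sampler_index: int = None parameter is kept only as a shim that warns and is otherwise ignored. A stripped-down standalone sketch of that shim pattern (a toy class, not the real StableDiffusionProcessing):

import warnings

class ProcessingSketch:
    # Toy stand-in for the new StableDiffusionProcessing signature: sampler_name is the
    # real parameter; sampler_index only triggers a warning and is ignored.
    def __init__(self, sampler_name: str = None, sampler_index: int = None):
        if sampler_index is not None:
            warnings.warn("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name")
        self.sampler_name = sampler_name

p = ProcessingSketch(sampler_name="Euler")   # new-style call
q = ProcessingSketch(sampler_index=3)        # warns; the index is not used
print(p.sampler_name, q.sampler_name)        # Euler None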
modules/sd_models.py

@@ -165,16 +165,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
     cache_enabled = shared.opts.sd_checkpoint_cache > 0

-    if cache_enabled:
-        sd_vae.restore_base_vae(model)
-
-    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
-
     if cache_enabled and checkpoint_info in checkpoints_loaded:
         # use checkpoint cache
-        vae_name = sd_vae.get_filename(vae_file) if vae_file else None
-        vae_message = f" with {vae_name} VAE" if vae_name else ""
-        print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
+        print(f"Loading weights [{sd_model_hash}] from cache")
         model.load_state_dict(checkpoints_loaded[checkpoint_info])
     else:
         # load from file

@@ -222,6 +215,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
     sd_vae.delete_base_vae()
     sd_vae.clear_loaded_vae()
+    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
     sd_vae.load_vae(model, vae_file)
modules/sd_samplers.py

@@ -46,16 +46,23 @@ all_samplers = [
     SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
     SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
 ]
+all_samplers_map = {x.name: x for x in all_samplers}

 samplers = []
 samplers_for_img2img = []


-def create_sampler_with_index(list_of_configs, index, model):
-    config = list_of_configs[index]
+def create_sampler(name, model):
+    if name is not None:
+        config = all_samplers_map.get(name, None)
+    else:
+        config = all_samplers[0]
+
+    assert config is not None, f'bad sampler name: {name}'
+
     sampler = config.constructor(model)
     sampler.config = config
     return sampler
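create_sampler(name, model) replaces create_sampler_with_index: a None name falls back to the first entry of all_samplers, an unknown name trips the assert, and the config is resolved through the new all_samplers_map. A standalone sketch with toy sampler entries; the SamplerData field names follow the ones used in this diff (name, constructor, aliases, options), and the real function additionally attaches sampler.config = config:

from collections import namedtuple

# Toy reimplementation of the create_sampler() lookup above.
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])

all_samplers = [
    SamplerData('Euler', lambda model: f"<Euler sampler on {model}>", [], {}),
    SamplerData('DDIM', lambda model: f"<DDIM sampler on {model}>", [], {}),
]
all_samplers_map = {x.name: x for x in all_samplers}

def create_sampler(name, model):
    if name is not None:
        config = all_samplers_map.get(name, None)
    else:
        config = all_samplers[0]
    assert config is not None, f'bad sampler name: {name}'
    return config.constructor(model)

print(create_sampler('DDIM', 'toy-model'))  # looked up by name
print(create_sampler(None, 'toy-model'))    # falls back to the first sampler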
modules/sd_vae.py

@@ -95,7 +95,7 @@ def get_vae_from_settings(vae_file="auto"):
     # if VAE selected but not found, fallback to auto
     if vae_file not in default_vae_values and not os.path.isfile(vae_file):
         vae_file = "auto"
-        print("Selected VAE doesn't exist")
+        print(f"Selected VAE doesn't exist: {vae_file}")
     return vae_file

@@ -105,15 +105,15 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"):
     # if vae_file argument is provided, it takes priority, but not saved
     if vae_file and vae_file not in default_vae_list:
         if not os.path.isfile(vae_file):
+            print(f"VAE provided as function argument doesn't exist: {vae_file}")
             vae_file = "auto"
-            print("VAE provided as function argument doesn't exist")
     # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
     if first_load and shared.cmd_opts.vae_path is not None:
         if os.path.isfile(shared.cmd_opts.vae_path):
             vae_file = shared.cmd_opts.vae_path
             shared.opts.data['sd_vae'] = get_filename(vae_file)
         else:
-            print("VAE provided as command line argument doesn't exist")
+            print(f"VAE provided as command line argument doesn't exist: {vae_file}")
     # fallback to selector in settings, if vae selector not set to act as default fallback
     if not shared.opts.sd_vae_as_default:
         vae_file = get_vae_from_settings(vae_file)

@@ -121,20 +121,20 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"):
     if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
         if os.path.isfile(shared.cmd_opts.vae_path):
             vae_file = shared.cmd_opts.vae_path
-            print("Using VAE provided as command line argument")
+            print(f"Using VAE provided as command line argument: {vae_file}")
     # if still not found, try look for ".vae.pt" beside model
     model_path = os.path.splitext(checkpoint_file)[0]
     if vae_file == "auto":
         vae_file_try = model_path + ".vae.pt"
         if os.path.isfile(vae_file_try):
             vae_file = vae_file_try
-            print("Using VAE found beside selected model")
+            print(f"Using VAE found similar to selected model: {vae_file}")
     # if still not found, try look for ".vae.ckpt" beside model
     if vae_file == "auto":
         vae_file_try = model_path + ".vae.ckpt"
         if os.path.isfile(vae_file_try):
             vae_file = vae_file_try
-            print("Using VAE found beside selected model")
+            print(f"Using VAE found similar to selected model: {vae_file}")
     # No more fallbacks for auto
     if vae_file == "auto":
         vae_file = None

@@ -150,6 +150,7 @@ def load_vae(model, vae_file=None):
     # save_settings = False
     if vae_file:
+        assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
         print(f"Loading VAE weights from: {vae_file}")
         store_base_vae(model)
         vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
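The messages above all come from resolve_vae(), which walks a fixed fallback chain: the explicit vae_file argument, then --vae-path on first load, then the settings value when sd_vae_as_default is not set, then a .vae.pt and a .vae.ckpt file next to the checkpoint, and finally None. A small hypothetical helper mirroring just the beside-the-model steps (the paths in the usage comment are made up):

import os

def find_vae_beside_model(checkpoint_file: str):
    # Hypothetical helper mirroring the ".vae.pt" / ".vae.ckpt" steps of resolve_vae();
    # returns the first existing candidate next to the checkpoint, or None.
    model_path = os.path.splitext(checkpoint_file)[0]
    for ext in (".vae.pt", ".vae.ckpt"):
        candidate = model_path + ext
        if os.path.isfile(candidate):
            return candidate
    return None

# find_vae_beside_model("models/Stable-diffusion/mymodel.ckpt")
# -> "models/Stable-diffusion/mymodel.vae.pt" if that file exists, else the .vae.ckpt, else None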
modules/textual_inversion/textual_inversion.py

@@ -10,7 +10,7 @@ import csv
 from PIL import Image, PngImagePlugin

-from modules import shared, devices, sd_hijack, processing, sd_models, images
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
 import modules.textual_inversion.dataset
 from modules.textual_inversion.learn_schedule import LearnRateScheduler

@@ -345,7 +345,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
                 p.prompt = preview_prompt
                 p.negative_prompt = preview_negative_prompt
                 p.steps = preview_steps
-                p.sampler_index = preview_sampler_index
+                p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
                 p.cfg_scale = preview_cfg_scale
                 p.seed = preview_seed
                 p.width = preview_width
modules/txt2img.py

 import modules.scripts
+from modules import sd_samplers
 from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
     StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, cmd_opts

@@ -21,7 +22,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
         seed_resize_from_h=seed_resize_from_h,
         seed_resize_from_w=seed_resize_from_w,
         seed_enable_extras=seed_enable_extras,
-        sampler_index=sampler_index,
+        sampler_name=sd_samplers.samplers[sampler_index].name,
         batch_size=batch_size,
         n_iter=n_iter,
         steps=steps,
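The Gradio dropdown still yields a numeric sampler_index; the call site now converts it to a name once at this boundary via sd_samplers.samplers[sampler_index].name. A toy sketch of that conversion (stand-in sampler list, not the project's):

from collections import namedtuple

# Stand-in for sd_samplers.samplers: an ordered list of objects with a .name attribute.
Sampler = namedtuple('Sampler', ['name'])
samplers = [Sampler('Euler'), Sampler('DDIM'), Sampler('PLMS')]

def sampler_name_from_index(sampler_index: int) -> str:
    # Same conversion as sd_samplers.samplers[sampler_index].name in the call above.
    return samplers[sampler_index].name

print(sampler_name_from_index(2))  # -> "PLMS"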
modules/ui.py

@@ -142,7 +142,7 @@ def save_files(js_data, images, do_make_zip, index):
                 filenames.append(os.path.basename(txt_fullfn))
                 fullfns.append(txt_fullfn)

-        writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
+        writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])

         # Make Zip
         if do_make_zip:
modules/ui_extensions.py

@@ -36,9 +36,9 @@ def apply_and_restart(disable_list, update_list):
             continue

         try:
-            ext.pull()
+            ext.fetch_and_reset_hard()
         except Exception:
-            print(f"Error pulling updates for {ext.name}:", file=sys.stderr)
+            print(f"Error getting updates for {ext.name}:", file=sys.stderr)
             print(traceback.format_exc(), file=sys.stderr)

     shared.opts.disabled_extensions = disabled
scripts/img2imgalt.py

@@ -157,7 +157,7 @@ class Script(scripts.Script):
     def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
         # Override
         if override_sampler:
-            p.sampler_index = [sampler.name for sampler in sd_samplers.samplers].index("Euler")
+            p.sampler_name = "Euler"
         if override_prompt:
             p.prompt = original_prompt
             p.negative_prompt = original_negative_prompt

@@ -191,7 +191,7 @@ class Script(scripts.Script):
         combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness ** 2 + (1 - randomness) ** 2) ** 0.5)

-        sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, p.sampler_index, p.sd_model)
+        sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)

         sigmas = sampler.model_wrap.get_sigmas(p.steps)
scripts/xy_grid.py

@@ -10,9 +10,9 @@ import numpy as np
 import modules.scripts as scripts
 import gradio as gr

-from modules import images
+from modules import images, sd_samplers
 from modules.hypernetworks import hypernetwork
-from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img
+from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
 import modules.sd_samplers

@@ -60,9 +60,9 @@ def apply_order(p, x, xs):
     p.prompt = prompt_tmp + p.prompt


-def build_samplers_dict(p):
+def build_samplers_dict():
     samplers_dict = {}
-    for i, sampler in enumerate(get_correct_sampler(p)):
+    for i, sampler in enumerate(sd_samplers.all_samplers):
         samplers_dict[sampler.name.lower()] = i
         for alias in sampler.aliases:
             samplers_dict[alias.lower()] = i

@@ -70,7 +70,7 @@ def build_samplers_dict(p):
 def apply_sampler(p, x, xs):
-    sampler_index = build_samplers_dict(p).get(x.lower(), None)
+    sampler_index = build_samplers_dict().get(x.lower(), None)
     if sampler_index is None:
         raise RuntimeError(f"Unknown sampler: {x}")

@@ -78,7 +78,7 @@ def apply_sampler(p, x, xs):
 def confirm_samplers(p, xs):
-    samplers_dict = build_samplers_dict(p)
+    samplers_dict = build_samplers_dict()
     for x in xs:
         if x.lower() not in samplers_dict.keys():
             raise RuntimeError(f"Unknown sampler: {x}")
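build_samplers_dict() no longer needs the processing object: it maps every sampler name and alias (lower-cased) from sd_samplers.all_samplers to its index, and apply_sampler()/confirm_samplers() simply look axis values up in that dict. A standalone sketch with toy entries (the alias string here is invented for illustration):

from collections import namedtuple

SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])

# Toy stand-in for sd_samplers.all_samplers.
all_samplers = [
    SamplerData('DDIM', None, ['ddim_alias'], {}),
    SamplerData('PLMS', None, [], {}),
]

def build_samplers_dict():
    samplers_dict = {}
    for i, sampler in enumerate(all_samplers):
        samplers_dict[sampler.name.lower()] = i
        for alias in sampler.aliases:
            samplers_dict[alias.lower()] = i
    return samplers_dict

print(build_samplers_dict())              # {'ddim': 0, 'ddim_alias': 0, 'plms': 1}
print(build_samplers_dict().get('plms'))  # 1, as used by apply_sampler()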