Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
4df63d2d
Commit
4df63d2d
authored
Jan 30, 2023
by
AUTOMATIC
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
split samplers into one more files for k-diffusion
parent
27447410
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
22 additions
and
348 deletions
+22
-348
sd_samplers.py
modules/sd_samplers.py
+5
-297
sd_samplers_common.py
modules/sd_samplers_common.py
+2
-1
sd_samplers_compvis.py
modules/sd_samplers_compvis.py
+8
-0
sd_samplers_kdiffusion.py
modules/sd_samplers_kdiffusion.py
+7
-50
No files found.
modules/sd_samplers.py
View file @
4df63d2d
This diff is collapsed.
Click to expand it.
modules/sd_samplers_common.py
View file @
4df63d2d
from
collections
import
namedtuple
,
deque
from
collections
import
namedtuple
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
PIL
import
Image
from
PIL
import
Image
...
@@ -64,6 +64,7 @@ class InterruptedException(BaseException):
...
@@ -64,6 +64,7 @@ class InterruptedException(BaseException):
# MPS fix for randn in torchsde
# MPS fix for randn in torchsde
# XXX move this to separate file for MPS
def
torchsde_randn
(
size
,
dtype
,
device
,
seed
):
def
torchsde_randn
(
size
,
dtype
,
device
,
seed
):
if
device
.
type
==
'mps'
:
if
device
.
type
==
'mps'
:
generator
=
torch
.
Generator
(
devices
.
cpu
)
.
manual_seed
(
int
(
seed
))
generator
=
torch
.
Generator
(
devices
.
cpu
)
.
manual_seed
(
int
(
seed
))
...
...
modules/sd_samplers_compvis.py
View file @
4df63d2d
import
math
import
math
import
ldm.models.diffusion.ddim
import
ldm.models.diffusion.plms
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -7,6 +9,12 @@ from modules.shared import state
...
@@ -7,6 +9,12 @@ from modules.shared import state
from
modules
import
sd_samplers_common
,
prompt_parser
,
shared
from
modules
import
sd_samplers_common
,
prompt_parser
,
shared
samplers_data_compvis
=
[
sd_samplers_common
.
SamplerData
(
'DDIM'
,
lambda
model
:
VanillaStableDiffusionSampler
(
ldm
.
models
.
diffusion
.
ddim
.
DDIMSampler
,
model
),
[],
{}),
sd_samplers_common
.
SamplerData
(
'PLMS'
,
lambda
model
:
VanillaStableDiffusionSampler
(
ldm
.
models
.
diffusion
.
plms
.
PLMSSampler
,
model
),
[],
{}),
]
class
VanillaStableDiffusionSampler
:
class
VanillaStableDiffusionSampler
:
def
__init__
(
self
,
constructor
,
sd_model
):
def
__init__
(
self
,
constructor
,
sd_model
):
self
.
sampler
=
constructor
(
sd_model
)
self
.
sampler
=
constructor
(
sd_model
)
...
...
modules/sd_samplers_kdiffusion.py
View file @
4df63d2d
...
@@ -2,18 +2,12 @@ from collections import deque
...
@@ -2,18 +2,12 @@ from collections import deque
import
torch
import
torch
import
inspect
import
inspect
import
k_diffusion.sampling
import
k_diffusion.sampling
import
ldm.models.diffusion.ddim
import
ldm.models.diffusion.plms
from
modules
import
prompt_parser
,
devices
,
sd_samplers_common
,
sd_samplers_compvis
from
modules
import
prompt_parser
,
devices
,
sd_samplers_common
,
sd_samplers_compvis
from
modules.shared
import
opts
,
state
from
modules.shared
import
opts
,
state
import
modules.shared
as
shared
import
modules.shared
as
shared
from
modules.script_callbacks
import
CFGDenoiserParams
,
cfg_denoiser_callback
from
modules.script_callbacks
import
CFGDenoiserParams
,
cfg_denoiser_callback
# imports for functions that previously were here and are used by other modules
from
modules.sd_samplers_common
import
samples_to_image_grid
,
sample_to_image
samplers_k_diffusion
=
[
samplers_k_diffusion
=
[
(
'Euler a'
,
'sample_euler_ancestral'
,
[
'k_euler_a'
,
'k_euler_ancestral'
],
{}),
(
'Euler a'
,
'sample_euler_ancestral'
,
[
'k_euler_a'
,
'k_euler_ancestral'
],
{}),
(
'Euler'
,
'sample_euler'
,
[
'k_euler'
],
{}),
(
'Euler'
,
'sample_euler'
,
[
'k_euler'
],
{}),
...
@@ -40,50 +34,6 @@ samplers_data_k_diffusion = [
...
@@ -40,50 +34,6 @@ samplers_data_k_diffusion = [
if
hasattr
(
k_diffusion
.
sampling
,
funcname
)
if
hasattr
(
k_diffusion
.
sampling
,
funcname
)
]
]
all_samplers
=
[
*
samplers_data_k_diffusion
,
sd_samplers_common
.
SamplerData
(
'DDIM'
,
lambda
model
:
sd_samplers_compvis
.
VanillaStableDiffusionSampler
(
ldm
.
models
.
diffusion
.
ddim
.
DDIMSampler
,
model
),
[],
{}),
sd_samplers_common
.
SamplerData
(
'PLMS'
,
lambda
model
:
sd_samplers_compvis
.
VanillaStableDiffusionSampler
(
ldm
.
models
.
diffusion
.
plms
.
PLMSSampler
,
model
),
[],
{}),
]
all_samplers_map
=
{
x
.
name
:
x
for
x
in
all_samplers
}
samplers
=
[]
samplers_for_img2img
=
[]
samplers_map
=
{}
def
create_sampler
(
name
,
model
):
if
name
is
not
None
:
config
=
all_samplers_map
.
get
(
name
,
None
)
else
:
config
=
all_samplers
[
0
]
assert
config
is
not
None
,
f
'bad sampler name: {name}'
sampler
=
config
.
constructor
(
model
)
sampler
.
config
=
config
return
sampler
def
set_samplers
():
global
samplers
,
samplers_for_img2img
hidden
=
set
(
opts
.
hide_samplers
)
hidden_img2img
=
set
(
opts
.
hide_samplers
+
[
'PLMS'
])
samplers
=
[
x
for
x
in
all_samplers
if
x
.
name
not
in
hidden
]
samplers_for_img2img
=
[
x
for
x
in
all_samplers
if
x
.
name
not
in
hidden_img2img
]
samplers_map
.
clear
()
for
sampler
in
all_samplers
:
samplers_map
[
sampler
.
name
.
lower
()]
=
sampler
.
name
for
alias
in
sampler
.
aliases
:
samplers_map
[
alias
.
lower
()]
=
sampler
.
name
set_samplers
()
sampler_extra_params
=
{
sampler_extra_params
=
{
'sample_euler'
:
[
's_churn'
,
's_tmin'
,
's_tmax'
,
's_noise'
],
'sample_euler'
:
[
's_churn'
,
's_tmin'
,
's_tmax'
,
's_noise'
],
'sample_heun'
:
[
's_churn'
,
's_tmin'
,
's_tmax'
,
's_noise'
],
'sample_heun'
:
[
's_churn'
,
's_tmin'
,
's_tmax'
,
's_noise'
],
...
@@ -92,6 +42,13 @@ sampler_extra_params = {
...
@@ -92,6 +42,13 @@ sampler_extra_params = {
class
CFGDenoiser
(
torch
.
nn
.
Module
):
class
CFGDenoiser
(
torch
.
nn
.
Module
):
"""
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
negative prompt.
"""
def
__init__
(
self
,
model
):
def
__init__
(
self
,
model
):
super
()
.
__init__
()
super
()
.
__init__
()
self
.
inner_model
=
model
self
.
inner_model
=
model
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment