stable-diffusion-webui, commit ada17dbd
Authored Jan 27, 2023 by brkirch

Refactor conditional casting, fix upscalers

Parent: c4b9b07d
Showing 5 changed files, with 25 additions and 10 deletions:

  modules/devices.py           +8  -0
  modules/processing.py        +8  -7
  modules/realesrgan_model.py  +1  -1
  modules/sd_hijack.py         +1  -1
  modules/sd_hijack_unet.py    +7  -1
modules/devices.py

@@ -83,6 +83,14 @@ dtype_unet = torch.float16
 unet_needs_upcast = False
 
 
+def cond_cast_unet(input):
+    return input.to(dtype_unet) if unet_needs_upcast else input
+
+
+def cond_cast_float(input):
+    return input.float() if unet_needs_upcast else input
+
+
 def randn(seed, shape):
     torch.manual_seed(seed)
     if device.type == 'mps':
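The two new helpers are the core of the refactor: every call site that previously spelled out `x.to(dtype_unet) if unet_needs_upcast else x` inline can now call `cond_cast_unet(x)` or `cond_cast_float(x)`, and both are no-ops when upcasting is disabled. A minimal sketch of their behavior, with the module globals stubbed to assumed values (`unet_needs_upcast` is normally set by startup code):

import torch

# Stand-ins for the globals in modules/devices.py; values here are assumptions for the demo.
dtype_unet = torch.float16
unet_needs_upcast = True

def cond_cast_unet(input):
    # Cast to the UNet dtype only when upcasting is enabled.
    return input.to(dtype_unet) if unet_needs_upcast else input

def cond_cast_float(input):
    # Promote to float32 only when upcasting is enabled.
    return input.float() if unet_needs_upcast else input

x = torch.randn(2, 4)                    # float32 input
print(cond_cast_unet(x).dtype)           # torch.float16
print(cond_cast_float(x.half()).dtype)   # torch.float32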
modules/processing.py

@@ -172,8 +172,7 @@ class StableDiffusionProcessing:
         midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
         midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
 
-        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image))
-        conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
+        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
         conditioning = torch.nn.functional.interpolate(
             self.sd_model.depth_model(midas_in),
             size=conditioning_image.shape[2:],

@@ -217,7 +216,7 @@ class StableDiffusionProcessing:
         )
 
         # Encode the new masked image using first stage of network.
-        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image))
+        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
 
         # Create the concatenated conditioning tensor to be fed to `c_concat`
         conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])

@@ -228,16 +227,18 @@ class StableDiffusionProcessing:
         return image_conditioning
 
     def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+        source_image = devices.cond_cast_float(source_image)
+
         # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
         # identify itself with a field common to all models. The conditioning_key is also hybrid.
         if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
-            return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image)
+            return self.depth2img_image_conditioning(source_image)
 
         if self.sd_model.cond_stage_key == "edit":
             return self.edit_image_conditioning(source_image)
 
         if self.sampler.conditioning_key in {'hybrid', 'concat'}:
-            return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask)
+            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
 
         # Dummy zero conditioning if we're not using inpainting or depth model.
         return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)

@@ -417,7 +418,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
 
 def decode_first_stage(model, x):
     with devices.autocast(disable=x.dtype == devices.dtype_vae):
-        x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x)
+        x = model.decode_first_stage(x)
 
     return x

@@ -1001,7 +1002,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
 
         image = torch.from_numpy(batch_images)
         image = 2. * image - 1.
-        image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None)
+        image = image.to(shared.device)
 
         self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
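The pattern throughout this file is the same: inline `if devices.unet_needs_upcast` conditionals are deleted from each call site, a single `devices.cond_cast_float(source_image)` at the top of `img2img_image_conditioning` covers all the conditioning paths below it, and the VAE-side casts for `encode_first_stage`/`decode_first_stage` move into the hooks added in modules/sd_hijack_unet.py (shown further down). A hedged before/after sketch of the call-site simplification, with `encode` as a hypothetical stand-in for the real first-stage encoder:

import torch

unet_needs_upcast = True  # assumed for the demo

def cond_cast_float(input):
    return input.float() if unet_needs_upcast else input

def encode(image):
    # Hypothetical stand-in for sd_model.encode_first_stage + get_first_stage_encoding.
    return image * 0.18215

source_image = torch.randn(1, 3, 64, 64, dtype=torch.float16)

# Before: every call site repeated the conditional cast.
old = encode(source_image.float() if unet_needs_upcast else source_image)

# After: one unconditional call at the top of the function.
source_image = cond_cast_float(source_image)
new = encode(source_image)

assert old.dtype == new.dtype == torch.float32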
modules/realesrgan_model.py

@@ -46,7 +46,7 @@ class UpscalerRealESRGAN(Upscaler):
             scale=info.scale,
             model_path=info.local_data_path,
             model=info.model(),
-            half=not cmd_opts.no_half,
+            half=not cmd_opts.no_half and not cmd_opts.upcast_sampling,
             tile=opts.ESRGAN_tile,
             tile_pad=opts.ESRGAN_tile_overlap,
         )
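This is the upscaler fix from the commit title: RealESRGAN was still being constructed with `half=True` when `--upcast-sampling` was active, presumably breaking upscaling in that mode; the new predicate keeps the upscaler in full precision whenever either flag is set. An exhaustive check of the logic (flag names come from the diff; the harness itself is illustrative, not part of the codebase):

from itertools import product

for no_half, upcast_sampling in product([False, True], repeat=2):
    half = not no_half and not upcast_sampling
    print(f"no_half={no_half!s:<5} upcast_sampling={upcast_sampling!s:<5} -> half={half}")
# Half precision is now used only when neither flag is set.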
modules/sd_hijack.py

@@ -171,7 +171,7 @@ class EmbeddingsWithFixes(torch.nn.Module):
         vecs = []
         for fixes, tensor in zip(batch_fixes, inputs_embeds):
             for offset, embedding in fixes:
-                emb = embedding.vec.to(devices.dtype_unet) if devices.unet_needs_upcast else embedding.vec
+                emb = devices.cond_cast_unet(embedding.vec)
                 emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                 tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
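Beyond swapping the inline conditional for `devices.cond_cast_unet`, the surrounding splice is worth reading: it overwrites the token embeddings after `offset` with the textual-inversion vectors in `embedding.vec`, truncated so the prompt length is preserved. A small self-contained demo with synthetic shapes:

import torch

tensor = torch.zeros(7, 4)   # per-prompt token embeddings (7 tokens, dim 4)
emb = torch.ones(3, 4)       # embedding.vec, possibly cast by cond_cast_unet
offset = 1

emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
print(tensor.shape)  # torch.Size([7, 4]); rows 2..4 are now the embedding vectors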
modules/sd_hijack_unet.py

@@ -55,8 +55,14 @@ class GELUHijack(torch.nn.GELU, torch.nn.Module):
 unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
 
 CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
-CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast)
+CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
 if version.parse(torch.__version__) <= version.parse("1.13.1"):
     CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
     CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
     CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
+
+first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
+first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
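The new `first_stage_cond`/`first_stage_sub` pair moves the VAE casts out of modules/processing.py and into conditional hooks: whenever upcasting is on and the UNet weights are fp16, inputs to `encode_first_stage`/`decode_first_stage` are cast to `devices.dtype_vae`, and the result of `get_first_stage_encoding` is promoted to float32. The `timestep_embedding` hook also now casts its result to float32 when the incoming timesteps are int64, rather than always to `devices.dtype_unet`. `CondFunc` itself is defined elsewhere in the repository and is not shown in this diff; what follows is only a hedged sketch of the pattern it implements, using plain attribute patching instead of the real dotted-path resolution:

import functools

def cond_func_sketch(owner, name, sub_func, cond_func):
    # Patch owner.name so that sub_func (which receives the original function
    # as its first argument) runs only while cond_func returns True; otherwise
    # the original is called unchanged.
    orig = getattr(owner, name)

    @functools.wraps(orig)
    def wrapper(*args, **kwargs):
        if cond_func(orig, *args, **kwargs):
            return sub_func(orig, *args, **kwargs)
        return orig(*args, **kwargs)

    setattr(owner, name, wrapper)

class Model:
    def forward(self, x):
        return x

upcast_enabled = True
cond_func_sketch(Model, 'forward',
                 lambda orig, self, x: orig(self, x * 2),  # substitute behavior
                 lambda orig, self, x: upcast_enabled)     # gate
print(Model().forward(3))  # 6 while upcast_enabled is True; 3 otherwise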