stable-diffusion-webui
Commit 7dbfd8a7 authored Dec 10, 2022 by AUTOMATIC
do not replace entire unet for the resolution hack
parent 2641d1b8
Showing 3 changed files with 33 additions and 30 deletions:

modules/sd_hijack.py                 +3   -2
modules/sd_hijack_optimizations.py   +0   -28
modules/sd_hijack_unet.py            +30  -0
modules/sd_hijack.py

@@ -11,7 +11,7 @@ import modules.textual_inversion.textual_inversion
 from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
 from modules.hypernetworks import hypernetwork
 from modules.shared import opts, device, cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet

 from modules.sd_hijack_optimizations import invokeAI_mps_available
@@ -35,11 +35,12 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"]
 ldm.modules.attention.print = lambda *args: None
 ldm.modules.diffusionmodules.model.print = lambda *args: None

 def apply_optimizations():
     undo_optimizations()

     ldm.modules.diffusionmodules.model.nonlinearity = silu
-    ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_hijack_optimizations.patched_unet_forward
+    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

     if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
         print("Applying xformers cross attention optimization.")
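The point of the change above is that the hijack no longer overrides `UNetModel.forward` wholesale; it only rebinds the module-level name `th` (the `import torch as th` alias used inside `ldm.modules.diffusionmodules.openaimodel`) to a wrapper object, so the rest of the original forward pass stays untouched. A minimal sketch of that monkey-patching technique, using made-up names (`toy_module`, `CatLogger`) rather than anything from the repository:

```python
import types

import torch


class CatLogger:
    """Stand-in for torch whose cat() logs shapes before delegating to torch.cat."""

    def __getattr__(self, item):
        # anything not overridden here falls through to the real torch
        return getattr(torch, item)

    def cat(self, tensors, *args, **kwargs):
        print("intercepted cat of", [tuple(t.shape) for t in tensors])
        return torch.cat(tensors, *args, **kwargs)


# A toy module standing in for ldm.modules.diffusionmodules.openaimodel:
# it has done `import torch as th` and calls th.cat() inside its forward code.
toy_module = types.ModuleType("toy_module")
toy_module.th = torch
toy_module.forward = lambda a, b: toy_module.th.cat([a, b], dim=1)

# The hijack rebinds one module-level name instead of replacing forward itself;
# the lookup happens at call time, so the original forward picks up the wrapper.
toy_module.th = CatLogger()
out = toy_module.forward(torch.zeros(1, 2, 4, 4), torch.zeros(1, 2, 4, 4))
print(out.shape)  # torch.Size([1, 4, 4, 4])
```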
modules/sd_hijack_optimizations.py

@@ -313,31 +313,3 @@ def xformers_attnblock_forward(self, x):
         return x + out
     except NotImplementedError:
         return cross_attention_attnblock_forward(self, x)
-
-
-def patched_unet_forward(self, x, timesteps=None, context=None, y=None, **kwargs):
-    assert (y is not None) == (self.num_classes is not None), "must specify y if and only if the model is class-conditional"
-    hs = []
-    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
-    emb = self.time_embed(t_emb)
-
-    if self.num_classes is not None:
-        assert y.shape == (x.shape[0],)
-        emb = emb + self.label_emb(y)
-
-    h = x.type(self.dtype)
-    for module in self.input_blocks:
-        h = module(h, emb, context)
-        hs.append(h)
-    h = self.middle_block(h, emb, context)
-    for module in self.output_blocks:
-        if h.shape[-2:] != hs[-1].shape[-2:]:
-            h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest")
-        h = torch.cat([h, hs.pop()], dim=1)
-        h = module(h, emb, context)
-    h = h.type(x.dtype)
-    if self.predict_codebook_ids:
-        return self.id_predictor(h)
-    else:
-        return self.out(h)
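The removed `patched_unet_forward` was a near-verbatim copy of `UNetModel.forward` whose only real addition was the `F.interpolate` call that resizes the decoder feature map to the skip connection's spatial size before concatenation. The mismatch it papers over comes from resolutions that are multiples of 8 but not of 64: the latent then has a side length not divisible by 8, so repeated stride-2 downsampling and 2x upsampling no longer line up. A rough worked example of that arithmetic (assuming an 8x VAE reduction, stride-2 downsamples that round up, and 2x nearest upsamples; the 712-pixel height is only an illustration):

```python
import math

image_height = 712                  # multiple of 8, but not of 64
latent = image_height // 8          # 89: not divisible by 2**3

# spatial sizes along the UNet's three downsampling steps
down = [latent]
for _ in range(3):
    down.append(math.ceil(down[-1] / 2))
print(down)                          # [89, 45, 23, 12]

# going back up, a 2x upsample of 12 gives 24, but the matching skip
# connection saved on the way down has height 23 -- cat() would fail
# without first resizing one of the two tensors
print(down[-1] * 2, "vs", down[-2])  # 24 vs 23
```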
modules/sd_hijack_unet.py (new file)

import torch


class TorchHijackForUnet:
    """
    This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
    """

    def __getattr__(self, item):
        if item == 'cat':
            return self.cat

        if hasattr(torch, item):
            return getattr(torch, item)

        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))

    def cat(self, tensors, *args, **kwargs):
        if len(tensors) == 2:
            a, b = tensors

            if a.shape[-2:] != b.shape[-2:]:
                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")

            tensors = (a, b)

        return torch.cat(tensors, *args, **kwargs)


th = TorchHijackForUnet()
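A quick usage sketch of the new helper (assuming the repository's `modules` package is importable; the tensor shapes are made up to mimic a 2x-upsampled decoder feature map and a slightly smaller skip connection):

```python
import torch

from modules.sd_hijack_unet import th

a = torch.zeros(1, 4, 24, 24)   # e.g. the upsampled decoder feature map
b = torch.zeros(1, 4, 23, 23)   # e.g. the skip connection saved on the way down

# th.cat() resizes the first tensor to the second one's spatial size
# (nearest-neighbour) before concatenating along the channel dimension.
print(th.cat((a, b), dim=1).shape)   # torch.Size([1, 8, 23, 23])

# every other attribute falls straight through to torch itself
print(th.float32 is torch.float32)   # True
```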