stable-diffusion-webui · Commit 89e1df01
authored Dec 03, 2022 by AUTOMATIC

Merge remote-tracking branch 'wywywywy/autoencoder-hijack'

parents b6e5edd7 7193814c

Showing 2 changed files with 287 additions and 1 deletion (+287 -1):
  modules/sd_hijack.py               +1    -1
  modules/sd_hijack_autoencoder.py   +286  -0
modules/sd_hijack.py

@@ -11,7 +11,7 @@ import modules.textual_inversion.textual_inversion
 from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
 from modules.hypernetworks import hypernetwork
 from modules.shared import opts, device, cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_autoencoder
 from modules.sd_hijack_optimizations import invokeAI_mps_available
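Note that nothing in sd_hijack.py references the new module by name: the added import exists purely for its side effects, because the setattr calls at the bottom of the new file (shown below) run at import time. A minimal, self-contained sketch of that import-side-effect pattern, using hypothetical stand-in names:

# Minimal sketch of the import-side-effect pattern (all names here are
# hypothetical stand-ins; the real file patches ldm.models.autoencoder).
import types

target = types.ModuleType("target")  # stand-in for the module being patched

class RestoredClass:  # stand-in for the class being put back
    pass

# Module-level statements like this run once, on first import -- so merely
# importing the hijack module is enough to apply the patch:
setattr(target, "RestoredClass", RestoredClass)
assert target.RestoredClass is RestoredClass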
modules/sd_hijack_autoencoder.py
0 → 100644

# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
# The VQModel & VQModelInterface classes were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
# As the LDSR upscaler relies on VQModel & VQModelInterface, this hijack puts them back into ldm.models.autoencoder
import numpy as np                               # needed by get_input's random resize path
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from packaging import version                    # needed by the pl version check in _validation_step
from torch.optim.lr_scheduler import LambdaLR    # needed by configure_optimizers
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.ema import LitEma               # needed when use_ema=True
from ldm.util import instantiate_from_config

import ldm.models.autoencoder


class VQModel(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        quant, diff, (_, _, ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(np.arange(lower_size, upper_size + 16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train",
                                            predicted_indices=ind)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val" + suffix,
                                        predicted_indices=ind
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val" + suffix,
                                            predicted_indices=ind
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor * self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quantize.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3:
                    xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x


class VQModelInterface(VQModel):
    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(embed_dim=embed_dim, *args, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, h, force_not_quantize=False):
        # also go through quantization layer
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec


setattr(ldm.models.autoencoder, "VQModel", VQModel)
setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
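With the hijack applied, `from ldm.models.autoencoder import VQModelInterface` resolves again, which is what the LDSR upscaler relies on. A minimal smoke-test sketch, assuming a working webui/ldm environment; the ddconfig and lossconfig values below are illustrative placeholders, not LDSR's actual settings:

# Usage sketch (placeholder configs, small enough for a quick smoke test).
import torch

import modules.sd_hijack_autoencoder  # noqa: F401 -- imported only for its setattr side effects
import ldm.models.autoencoder

assert hasattr(ldm.models.autoencoder, "VQModelInterface")

ddconfig = dict(  # placeholder architecture config, not LDSR's
    double_z=False, z_channels=3, resolution=256, in_channels=3, out_ch=3,
    ch=128, ch_mult=[1, 2, 4], num_res_blocks=2, attn_resolutions=[], dropout=0.0,
)
lossconfig = {"target": "torch.nn.Identity"}  # stand-in loss; fine for inference-only use

model = ldm.models.autoencoder.VQModelInterface(
    embed_dim=3, n_embed=8192, ddconfig=ddconfig, lossconfig=lossconfig,
).eval()

with torch.no_grad():
    h = model.encode(torch.randn(1, 3, 64, 64))  # pre-quant latents, here (1, 3, 16, 16)
    rec = model.decode(h)                        # quantize, then decode back to (1, 3, 64, 64)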