Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
241cbc4d
Unverified
Commit
241cbc4d
authored
Nov 29, 2022
by
wywywywy
Committed by
GitHub
Nov 29, 2022
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Hijack VQModelInterface back to AutoEncoder
parent
4b3c5bc2
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
282 additions
and
0 deletions
+282
-0
sd_hijack_autoencoder.py
modules/sd_hijack_autoencoder.py
+282
-0
No files found.
modules/sd_hijack_autoencoder.py
0 → 100644
View file @
241cbc4d
import
torch
import
pytorch_lightning
as
pl
import
torch.nn.functional
as
F
from
contextlib
import
contextmanager
from
taming.modules.vqvae.quantize
import
VectorQuantizer2
as
VectorQuantizer
from
ldm.modules.diffusionmodules.model
import
Encoder
,
Decoder
from
ldm.util
import
instantiate_from_config
import
ldm.models.autoencoder
class
VQModel
(
pl
.
LightningModule
):
def
__init__
(
self
,
ddconfig
,
lossconfig
,
n_embed
,
embed_dim
,
ckpt_path
=
None
,
ignore_keys
=
[],
image_key
=
"image"
,
colorize_nlabels
=
None
,
monitor
=
None
,
batch_resize_range
=
None
,
scheduler_config
=
None
,
lr_g_factor
=
1.0
,
remap
=
None
,
sane_index_shape
=
False
,
# tell vector quantizer to return indices as bhw
use_ema
=
False
):
super
()
.
__init__
()
self
.
embed_dim
=
embed_dim
self
.
n_embed
=
n_embed
self
.
image_key
=
image_key
self
.
encoder
=
Encoder
(
**
ddconfig
)
self
.
decoder
=
Decoder
(
**
ddconfig
)
self
.
loss
=
instantiate_from_config
(
lossconfig
)
self
.
quantize
=
VectorQuantizer
(
n_embed
,
embed_dim
,
beta
=
0.25
,
remap
=
remap
,
sane_index_shape
=
sane_index_shape
)
self
.
quant_conv
=
torch
.
nn
.
Conv2d
(
ddconfig
[
"z_channels"
],
embed_dim
,
1
)
self
.
post_quant_conv
=
torch
.
nn
.
Conv2d
(
embed_dim
,
ddconfig
[
"z_channels"
],
1
)
if
colorize_nlabels
is
not
None
:
assert
type
(
colorize_nlabels
)
==
int
self
.
register_buffer
(
"colorize"
,
torch
.
randn
(
3
,
colorize_nlabels
,
1
,
1
))
if
monitor
is
not
None
:
self
.
monitor
=
monitor
self
.
batch_resize_range
=
batch_resize_range
if
self
.
batch_resize_range
is
not
None
:
print
(
f
"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}."
)
self
.
use_ema
=
use_ema
if
self
.
use_ema
:
self
.
model_ema
=
LitEma
(
self
)
print
(
f
"Keeping EMAs of {len(list(self.model_ema.buffers()))}."
)
if
ckpt_path
is
not
None
:
self
.
init_from_ckpt
(
ckpt_path
,
ignore_keys
=
ignore_keys
)
self
.
scheduler_config
=
scheduler_config
self
.
lr_g_factor
=
lr_g_factor
@
contextmanager
def
ema_scope
(
self
,
context
=
None
):
if
self
.
use_ema
:
self
.
model_ema
.
store
(
self
.
parameters
())
self
.
model_ema
.
copy_to
(
self
)
if
context
is
not
None
:
print
(
f
"{context}: Switched to EMA weights"
)
try
:
yield
None
finally
:
if
self
.
use_ema
:
self
.
model_ema
.
restore
(
self
.
parameters
())
if
context
is
not
None
:
print
(
f
"{context}: Restored training weights"
)
def
init_from_ckpt
(
self
,
path
,
ignore_keys
=
list
()):
sd
=
torch
.
load
(
path
,
map_location
=
"cpu"
)[
"state_dict"
]
keys
=
list
(
sd
.
keys
())
for
k
in
keys
:
for
ik
in
ignore_keys
:
if
k
.
startswith
(
ik
):
print
(
"Deleting key {} from state_dict."
.
format
(
k
))
del
sd
[
k
]
missing
,
unexpected
=
self
.
load_state_dict
(
sd
,
strict
=
False
)
print
(
f
"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if
len
(
missing
)
>
0
:
print
(
f
"Missing Keys: {missing}"
)
print
(
f
"Unexpected Keys: {unexpected}"
)
def
on_train_batch_end
(
self
,
*
args
,
**
kwargs
):
if
self
.
use_ema
:
self
.
model_ema
(
self
)
def
encode
(
self
,
x
):
h
=
self
.
encoder
(
x
)
h
=
self
.
quant_conv
(
h
)
quant
,
emb_loss
,
info
=
self
.
quantize
(
h
)
return
quant
,
emb_loss
,
info
def
encode_to_prequant
(
self
,
x
):
h
=
self
.
encoder
(
x
)
h
=
self
.
quant_conv
(
h
)
return
h
def
decode
(
self
,
quant
):
quant
=
self
.
post_quant_conv
(
quant
)
dec
=
self
.
decoder
(
quant
)
return
dec
def
decode_code
(
self
,
code_b
):
quant_b
=
self
.
quantize
.
embed_code
(
code_b
)
dec
=
self
.
decode
(
quant_b
)
return
dec
def
forward
(
self
,
input
,
return_pred_indices
=
False
):
quant
,
diff
,
(
_
,
_
,
ind
)
=
self
.
encode
(
input
)
dec
=
self
.
decode
(
quant
)
if
return_pred_indices
:
return
dec
,
diff
,
ind
return
dec
,
diff
def
get_input
(
self
,
batch
,
k
):
x
=
batch
[
k
]
if
len
(
x
.
shape
)
==
3
:
x
=
x
[
...
,
None
]
x
=
x
.
permute
(
0
,
3
,
1
,
2
)
.
to
(
memory_format
=
torch
.
contiguous_format
)
.
float
()
if
self
.
batch_resize_range
is
not
None
:
lower_size
=
self
.
batch_resize_range
[
0
]
upper_size
=
self
.
batch_resize_range
[
1
]
if
self
.
global_step
<=
4
:
# do the first few batches with max size to avoid later oom
new_resize
=
upper_size
else
:
new_resize
=
np
.
random
.
choice
(
np
.
arange
(
lower_size
,
upper_size
+
16
,
16
))
if
new_resize
!=
x
.
shape
[
2
]:
x
=
F
.
interpolate
(
x
,
size
=
new_resize
,
mode
=
"bicubic"
)
x
=
x
.
detach
()
return
x
def
training_step
(
self
,
batch
,
batch_idx
,
optimizer_idx
):
# https://github.com/pytorch/pytorch/issues/37142
# try not to fool the heuristics
x
=
self
.
get_input
(
batch
,
self
.
image_key
)
xrec
,
qloss
,
ind
=
self
(
x
,
return_pred_indices
=
True
)
if
optimizer_idx
==
0
:
# autoencode
aeloss
,
log_dict_ae
=
self
.
loss
(
qloss
,
x
,
xrec
,
optimizer_idx
,
self
.
global_step
,
last_layer
=
self
.
get_last_layer
(),
split
=
"train"
,
predicted_indices
=
ind
)
self
.
log_dict
(
log_dict_ae
,
prog_bar
=
False
,
logger
=
True
,
on_step
=
True
,
on_epoch
=
True
)
return
aeloss
if
optimizer_idx
==
1
:
# discriminator
discloss
,
log_dict_disc
=
self
.
loss
(
qloss
,
x
,
xrec
,
optimizer_idx
,
self
.
global_step
,
last_layer
=
self
.
get_last_layer
(),
split
=
"train"
)
self
.
log_dict
(
log_dict_disc
,
prog_bar
=
False
,
logger
=
True
,
on_step
=
True
,
on_epoch
=
True
)
return
discloss
def
validation_step
(
self
,
batch
,
batch_idx
):
log_dict
=
self
.
_validation_step
(
batch
,
batch_idx
)
with
self
.
ema_scope
():
log_dict_ema
=
self
.
_validation_step
(
batch
,
batch_idx
,
suffix
=
"_ema"
)
return
log_dict
def
_validation_step
(
self
,
batch
,
batch_idx
,
suffix
=
""
):
x
=
self
.
get_input
(
batch
,
self
.
image_key
)
xrec
,
qloss
,
ind
=
self
(
x
,
return_pred_indices
=
True
)
aeloss
,
log_dict_ae
=
self
.
loss
(
qloss
,
x
,
xrec
,
0
,
self
.
global_step
,
last_layer
=
self
.
get_last_layer
(),
split
=
"val"
+
suffix
,
predicted_indices
=
ind
)
discloss
,
log_dict_disc
=
self
.
loss
(
qloss
,
x
,
xrec
,
1
,
self
.
global_step
,
last_layer
=
self
.
get_last_layer
(),
split
=
"val"
+
suffix
,
predicted_indices
=
ind
)
rec_loss
=
log_dict_ae
[
f
"val{suffix}/rec_loss"
]
self
.
log
(
f
"val{suffix}/rec_loss"
,
rec_loss
,
prog_bar
=
True
,
logger
=
True
,
on_step
=
False
,
on_epoch
=
True
,
sync_dist
=
True
)
self
.
log
(
f
"val{suffix}/aeloss"
,
aeloss
,
prog_bar
=
True
,
logger
=
True
,
on_step
=
False
,
on_epoch
=
True
,
sync_dist
=
True
)
if
version
.
parse
(
pl
.
__version__
)
>=
version
.
parse
(
'1.4.0'
):
del
log_dict_ae
[
f
"val{suffix}/rec_loss"
]
self
.
log_dict
(
log_dict_ae
)
self
.
log_dict
(
log_dict_disc
)
return
self
.
log_dict
def
configure_optimizers
(
self
):
lr_d
=
self
.
learning_rate
lr_g
=
self
.
lr_g_factor
*
self
.
learning_rate
print
(
"lr_d"
,
lr_d
)
print
(
"lr_g"
,
lr_g
)
opt_ae
=
torch
.
optim
.
Adam
(
list
(
self
.
encoder
.
parameters
())
+
list
(
self
.
decoder
.
parameters
())
+
list
(
self
.
quantize
.
parameters
())
+
list
(
self
.
quant_conv
.
parameters
())
+
list
(
self
.
post_quant_conv
.
parameters
()),
lr
=
lr_g
,
betas
=
(
0.5
,
0.9
))
opt_disc
=
torch
.
optim
.
Adam
(
self
.
loss
.
discriminator
.
parameters
(),
lr
=
lr_d
,
betas
=
(
0.5
,
0.9
))
if
self
.
scheduler_config
is
not
None
:
scheduler
=
instantiate_from_config
(
self
.
scheduler_config
)
print
(
"Setting up LambdaLR scheduler..."
)
scheduler
=
[
{
'scheduler'
:
LambdaLR
(
opt_ae
,
lr_lambda
=
scheduler
.
schedule
),
'interval'
:
'step'
,
'frequency'
:
1
},
{
'scheduler'
:
LambdaLR
(
opt_disc
,
lr_lambda
=
scheduler
.
schedule
),
'interval'
:
'step'
,
'frequency'
:
1
},
]
return
[
opt_ae
,
opt_disc
],
scheduler
return
[
opt_ae
,
opt_disc
],
[]
def
get_last_layer
(
self
):
return
self
.
decoder
.
conv_out
.
weight
def
log_images
(
self
,
batch
,
only_inputs
=
False
,
plot_ema
=
False
,
**
kwargs
):
log
=
dict
()
x
=
self
.
get_input
(
batch
,
self
.
image_key
)
x
=
x
.
to
(
self
.
device
)
if
only_inputs
:
log
[
"inputs"
]
=
x
return
log
xrec
,
_
=
self
(
x
)
if
x
.
shape
[
1
]
>
3
:
# colorize with random projection
assert
xrec
.
shape
[
1
]
>
3
x
=
self
.
to_rgb
(
x
)
xrec
=
self
.
to_rgb
(
xrec
)
log
[
"inputs"
]
=
x
log
[
"reconstructions"
]
=
xrec
if
plot_ema
:
with
self
.
ema_scope
():
xrec_ema
,
_
=
self
(
x
)
if
x
.
shape
[
1
]
>
3
:
xrec_ema
=
self
.
to_rgb
(
xrec_ema
)
log
[
"reconstructions_ema"
]
=
xrec_ema
return
log
def
to_rgb
(
self
,
x
):
assert
self
.
image_key
==
"segmentation"
if
not
hasattr
(
self
,
"colorize"
):
self
.
register_buffer
(
"colorize"
,
torch
.
randn
(
3
,
x
.
shape
[
1
],
1
,
1
)
.
to
(
x
))
x
=
F
.
conv2d
(
x
,
weight
=
self
.
colorize
)
x
=
2.
*
(
x
-
x
.
min
())
/
(
x
.
max
()
-
x
.
min
())
-
1.
return
x
class
VQModelInterface
(
VQModel
):
def
__init__
(
self
,
embed_dim
,
*
args
,
**
kwargs
):
super
()
.
__init__
(
embed_dim
=
embed_dim
,
*
args
,
**
kwargs
)
self
.
embed_dim
=
embed_dim
def
encode
(
self
,
x
):
h
=
self
.
encoder
(
x
)
h
=
self
.
quant_conv
(
h
)
return
h
def
decode
(
self
,
h
,
force_not_quantize
=
False
):
# also go through quantization layer
if
not
force_not_quantize
:
quant
,
emb_loss
,
info
=
self
.
quantize
(
h
)
else
:
quant
=
h
quant
=
self
.
post_quant_conv
(
quant
)
dec
=
self
.
decoder
(
quant
)
return
dec
setattr
(
ldm
.
models
.
autoencoder
,
"VQModel"
,
VQModel
)
setattr
(
ldm
.
models
.
autoencoder
,
"VQModelInterface"
,
VQModelInterface
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment