Commit 3896242e (unverified)
Authored Dec 10, 2022 by AUTOMATIC1111; committed by GitHub on Dec 10, 2022
Parents: 505ec7e4, a8ae263c

Merge pull request #5415 from wywywywy/reinstate-ddpm-v1

Reinstate DDPM V1 to LDSR

Showing 3 changed files with 1451 additions and 1 deletion:

extensions-builtin/LDSR/ldsr_model_arch.py       +1     -0
extensions-builtin/LDSR/scripts/ldsr_model.py    +1     -1
extensions-builtin/LDSR/sd_hijack_ddpm_v1.py     +1449  -0
extensions-builtin/LDSR/ldsr_model_arch.py
@@ -22,6 +22,7 @@ class LDSR:
         pl_sd = torch.load(self.modelPath, map_location="cpu")
         sd = pl_sd["state_dict"]
         config = OmegaConf.load(self.yamlPath)
+        config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
         model = instantiate_from_config(config.model)
         model.load_state_dict(sd, strict=False)
         model.cuda()
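
For context, this one-line override works because ldm resolves `config.model.target` as a dotted
import path. A minimal sketch of that resolution, simplified from `ldm.util` (the stripped-down
error handling here is an assumption, not the library's exact code):

    import importlib

    def get_obj_from_str(string: str):
        # "ldm.models.diffusion.ddpm.LatentDiffusionV1" -> module path + attribute name
        module, cls = string.rsplit(".", 1)
        return getattr(importlib.import_module(module), cls)

    def instantiate_from_config(config):
        # instantiate config["target"] with config["params"] as keyword arguments
        return get_obj_from_str(config["target"])(**config.get("params", dict()))

The `getattr` on the imported module is why the `setattr()` registration at the bottom of
sd_hijack_ddpm_v1.py (below) is sufficient: once LatentDiffusionV1 is attached to
ldm.models.diffusion.ddpm, the dotted string resolves even though the class is defined in the
extension.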
extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -7,7 +7,7 @@ from basicsr.utils.download_util import load_file_from_url
 from modules.upscaler import Upscaler, UpscalerData
 from ldsr_model_arch import LDSR
 from modules import shared, script_callbacks
-import sd_hijack_autoencoder
+import sd_hijack_autoencoder, sd_hijack_ddpm_v1


 class UpscalerLDSR(Upscaler):
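
The amended import pulls in sd_hijack_ddpm_v1 purely for its side effects: the `setattr()` calls at
the bottom of the new file run at import time, so the V1 classes are registered before UpscalerLDSR
ever instantiates a model config. A minimal sketch of the ordering this relies on (module names as
in the diff; the check itself is hypothetical):

    import sd_hijack_autoencoder, sd_hijack_ddpm_v1   # registration happens here
    import ldm.models.diffusion.ddpm as ddpm

    # by the time a YAML names "ldm.models.diffusion.ddpm.LatentDiffusionV1",
    # the attribute already exists on the module
    assert hasattr(ddpm, "LatentDiffusionV1")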
extensions-builtin/LDSR/sd_hijack_ddpm_v1.py (new file, mode 100644)

# This script is copied from the compvis/stable-diffusion repo (aka the SD V1 repo)
# Original filename: ldm/models/diffusion/ddpm.py
# The purpose is to reinstate the old DDPM logic, which works with VQ, whereas the V2 one doesn't
# Some models such as LDSR require VQ to work correctly
# The classes are suffixed with "V1" and added back to the "ldm.models.diffusion.ddpm" module

import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange, repeat
from contextlib import contextmanager
from functools import partial
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only

from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ldm.models.diffusion.ddim import DDIMSampler

import ldm.models.diffusion.ddpm

__conditioning_keys__ = {'concat': 'c_concat',
                         'crossattn': 'c_crossattn',
                         'adm': 'y'}


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def uniform_on_device(r1, r2, shape, device):
    return (r1 - r2) * torch.rand(*shape, device=device) + r2


class DDPMV1(pl.LightningModule):
    # classic DDPM with Gaussian diffusion, in image space
    def __init__(self,
                 unet_config,
                 timesteps=1000,
                 beta_schedule="linear",
                 loss_type="l2",
                 ckpt_path=None,
                 ignore_keys=[],
                 load_only_unet=False,
                 monitor="val/loss",
                 use_ema=True,
                 first_stage_key="image",
                 image_size=256,
                 channels=3,
                 log_every_t=100,
                 clip_denoised=True,
                 linear_start=1e-4,
                 linear_end=2e-2,
                 cosine_s=8e-3,
                 given_betas=None,
                 original_elbo_weight=0.,
                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
                 l_simple_weight=1.,
                 conditioning_key=None,
                 parameterization="eps",  # all assuming fixed variance schedules
                 scheduler_config=None,
                 use_positional_encodings=False,
                 learn_logvar=False,
                 logvar_init=0.,
                 ):
        super().__init__()
        assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
        self.parameterization = parameterization
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        self.cond_stage_model = None
        self.clip_denoised = clip_denoised
        self.log_every_t = log_every_t
        self.first_stage_key = first_stage_key
        self.image_size = image_size  # try conv?
        self.channels = channels
        self.use_positional_encodings = use_positional_encodings
        self.model = DiffusionWrapperV1(unet_config, conditioning_key)
        count_params(self.model, verbose=True)
        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        self.use_scheduler = scheduler_config is not None
        if self.use_scheduler:
            self.scheduler_config = scheduler_config

        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight

        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)

        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)

        self.loss_type = loss_type

        self.learn_logvar = learn_logvar
        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
        if self.learn_logvar:
            self.logvar = nn.Parameter(self.logvar, requires_grad=True)

    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                       cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (
                    2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
        else:
            raise NotImplementedError("mu not supported")
        # TODO how to choose this term
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        assert not torch.isnan(self.lvlb_weights).all()

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).
        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        return (
                extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
                extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
                extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
                extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, clip_denoised: bool):
        model_out = self.model(x, t)
        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        if clip_denoised:
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, return_intermediates=False):
        device = self.betas.device
        b = shape[0]
        img = torch.randn(shape, device=device)
        intermediates = [img]
        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
                                clip_denoised=self.clip_denoised)
            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
                intermediates.append(img)
        if return_intermediates:
            return img, intermediates
        return img

    @torch.no_grad()
    def sample(self, batch_size=16, return_intermediates=False):
        image_size = self.image_size
        channels = self.channels
        return self.p_sample_loop((batch_size, channels, image_size, image_size),
                                  return_intermediates=return_intermediates)

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def get_loss(self, pred, target, mean=True):
        if self.loss_type == 'l1':
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == 'l2':
            if mean:
                loss = torch.nn.functional.mse_loss(target, pred)
            else:
                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
        else:
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")

        return loss

    def p_losses(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x_noisy, t)

        loss_dict = {}
        if self.parameterization == "eps":
            target = noise
        elif self.parameterization == "x0":
            target = x_start
        else:
            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")

        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])

        log_prefix = 'train' if self.training else 'val'

        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
        loss_simple = loss.mean() * self.l_simple_weight

        loss_vlb = (self.lvlb_weights[t] * loss).mean()
        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})

        loss = loss_simple + self.original_elbo_weight * loss_vlb

        loss_dict.update({f'{log_prefix}/loss': loss})

        return loss, loss_dict

    def forward(self, x, *args, **kwargs):
        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
        # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        return self.p_losses(x, t, *args, **kwargs)

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    def shared_step(self, batch):
        x = self.get_input(batch, self.first_stage_key)
        loss, loss_dict = self(x)
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.shared_step(batch)

        self.log_dict(loss_dict, prog_bar=True,
                      logger=True, on_step=True, on_epoch=True)

        self.log("global_step", self.global_step,
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)

        if self.use_scheduler:
            lr = self.optimizers().param_groups[0]['lr']
            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)

        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        _, loss_dict_no_ema = self.shared_step(batch)
        with self.ema_scope():
            _, loss_dict_ema = self.shared_step(batch)
            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self.model)

    def _get_rows_from_list(self, samples):
        n_imgs_per_row = len(samples)
        denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
        return denoise_grid

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
        log = dict()
        x = self.get_input(batch, self.first_stage_key)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        x = x.to(self.device)[:N]
        log["inputs"] = x

        # get diffusion row
        diffusion_row = list()
        x_start = x[:n_row]

        for t in range(self.num_timesteps):
            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                t = t.to(self.device).long()
                noise = torch.randn_like(x_start)
                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
                diffusion_row.append(x_noisy)

        log["diffusion_row"] = self._get_rows_from_list(diffusion_row)

        if sample:
            # get denoise row
            with self.ema_scope("Plotting"):
                samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)

            log["samples"] = samples
            log["denoise_row"] = self._get_rows_from_list(denoise_row)

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        if self.learn_logvar:
            params = params + [self.logvar]
        opt = torch.optim.AdamW(params, lr=lr)
        return opt


class LatentDiffusionV1(DDPMV1):
    """main class"""
    def __init__(self,
                 first_stage_config,
                 cond_stage_config,
                 num_timesteps_cond=None,
                 cond_stage_key="image",
                 cond_stage_trainable=False,
                 concat_mode=True,
                 cond_stage_forward=None,
                 conditioning_key=None,
                 scale_factor=1.0,
                 scale_by_std=False,
                 *args, **kwargs):
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        self.scale_by_std = scale_by_std
        assert self.num_timesteps_cond <= kwargs['timesteps']
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = 'concat' if concat_mode else 'crossattn'
        if cond_stage_config == '__is_unconditional__':
            conditioning_key = None
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        try:
            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        except:
            self.num_downs = 0
        if not scale_by_std:
            self.scale_factor = scale_factor
        else:
            self.register_buffer('scale_factor', torch.tensor(scale_factor))
        self.instantiate_first_stage(first_stage_config)
        self.instantiate_cond_stage(cond_stage_config)
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False
        self.bbox_tokenizer = None

        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True

    def make_cond_schedule(self, ):
        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
        self.cond_ids[:self.num_timesteps_cond] = ids

    @rank_zero_only
    @torch.no_grad()
    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
        # only for very first batch
        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
            assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
            # set rescale weight to 1./std of encodings
            print("### USING STD-RESCALING ###")
            x = super().get_input(batch, self.first_stage_key)
            x = x.to(self.device)
            encoder_posterior = self.encode_first_stage(x)
            z = self.get_first_stage_encoding(encoder_posterior).detach()
            del self.scale_factor
            self.register_buffer('scale_factor', 1. / z.flatten().std())
            print(f"setting self.scale_factor to {self.scale_factor}")
            print("### USING STD-RESCALING ###")

    def register_schedule(self,
                          given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)

        self.shorten_cond_schedule = self.num_timesteps_cond > 1
        if self.shorten_cond_schedule:
            self.make_cond_schedule()

    def instantiate_first_stage(self, config):
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()
        self.first_stage_model.train = disabled_train
        for param in self.first_stage_model.parameters():
            param.requires_grad = False

    def instantiate_cond_stage(self, config):
        if not self.cond_stage_trainable:
            if config == "__is_first_stage__":
                print("Using first stage also as cond stage.")
                self.cond_stage_model = self.first_stage_model
            elif config == "__is_unconditional__":
                print(f"Training {self.__class__.__name__} as an unconditional model.")
                self.cond_stage_model = None
                # self.be_unconditional = True
            else:
                model = instantiate_from_config(config)
                self.cond_stage_model = model.eval()
                self.cond_stage_model.train = disabled_train
                for param in self.cond_stage_model.parameters():
                    param.requires_grad = False
        else:
            assert config != '__is_first_stage__'
            assert config != '__is_unconditional__'
            model = instantiate_from_config(config)
            self.cond_stage_model = model

    def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
        denoise_row = []
        for zd in tqdm(samples, desc=desc):
            denoise_row.append(self.decode_first_stage(zd.to(self.device),
                                                       force_not_quantize=force_no_decoder_quantization))
        n_imgs_per_row = len(denoise_row)
        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W
        denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
        return denoise_grid

    def get_first_stage_encoding(self, encoder_posterior):
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
        return self.scale_factor * z

    def get_learned_conditioning(self, c):
        if self.cond_stage_forward is None:
            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
                c = self.cond_stage_model.encode(c)
                if isinstance(c, DiagonalGaussianDistribution):
                    c = c.mode()
            else:
                c = self.cond_stage_model(c)
        else:
            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
        return c

    def meshgrid(self, h, w):
        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)

        arr = torch.cat([y, x], dim=-1)
        return arr

    def delta_border(self, h, w):
        """
        :param h: height
        :param w: width
        :return: normalized distance to image border,
         with min distance = 0 at border and max dist = 0.5 at image center
        """
        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
        arr = self.meshgrid(h, w) / lower_right_corner
        dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
        dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
        edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
        return edge_dist

    def get_weighting(self, h, w, Ly, Lx, device):
        weighting = self.delta_border(h, w)
        weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
                               self.split_input_params["clip_max_weight"], )
        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)

        if self.split_input_params["tie_braker"]:
            L_weighting = self.delta_border(Ly, Lx)
            L_weighting = torch.clip(L_weighting,
                                     self.split_input_params["clip_min_tie_weight"],
                                     self.split_input_params["clip_max_tie_weight"])

            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
            weighting = weighting * L_weighting
        return weighting

    def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code
        """
        :param x: img of size (bs, c, h, w)
        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
        """
        bs, nc, h, w = x.shape

        # number of crops in image
        Ly = (h - kernel_size[0]) // stride[0] + 1
        Lx = (w - kernel_size[1]) // stride[1] + 1

        if uf == 1 and df == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)

            weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))

        elif uf > 1 and df == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
                                dilation=1, padding=0,
                                stride=(stride[0] * uf, stride[1] * uf))
            fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)

            weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))

        elif df > 1 and uf == 1:
            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
            unfold = torch.nn.Unfold(**fold_params)

            fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
                                dilation=1, padding=0,
                                stride=(stride[0] // df, stride[1] // df))
            fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)

            weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))

        else:
            raise NotImplementedError

        return fold, unfold, normalization, weighting

    @torch.no_grad()
    def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
                  cond_key=None, return_original_cond=False, bs=None):
        x = super().get_input(batch, k)
        if bs is not None:
            x = x[:bs]
        x = x.to(self.device)
        encoder_posterior = self.encode_first_stage(x)
        z = self.get_first_stage_encoding(encoder_posterior).detach()

        if self.model.conditioning_key is not None:
            if cond_key is None:
                cond_key = self.cond_stage_key
            if cond_key != self.first_stage_key:
                if cond_key in ['caption', 'coordinates_bbox']:
                    xc = batch[cond_key]
                elif cond_key == 'class_label':
                    xc = batch
                else:
                    xc = super().get_input(batch, cond_key).to(self.device)
            else:
                xc = x
            if not self.cond_stage_trainable or force_c_encode:
                if isinstance(xc, dict) or isinstance(xc, list):
                    # import pudb; pudb.set_trace()
                    c = self.get_learned_conditioning(xc)
                else:
                    c = self.get_learned_conditioning(xc.to(self.device))
            else:
                c = xc
            if bs is not None:
                c = c[:bs]

            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                ckey = __conditioning_keys__[self.model.conditioning_key]
                c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}

        else:
            c = None
            xc = None
            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                c = {'pos_x': pos_x, 'pos_y': pos_y}
        out = [z, c]
        if return_first_stage_outputs:
            xrec = self.decode_first_stage(z)
            out.extend([x, xrec])
        if return_original_cond:
            out.append(xc)
        return out

    @torch.no_grad()
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, 'b h w c -> b c h w').contiguous()

        z = 1. / self.scale_factor * z

        if hasattr(self, "split_input_params"):
            if self.split_input_params["patch_distributed_vq"]:
                ks = self.split_input_params["ks"]  # eg. (128, 128)
                stride = self.split_input_params["stride"]  # eg. (64, 64)
                uf = self.split_input_params["vqf"]
                bs, nc, h, w = z.shape
                if ks[0] > h or ks[1] > w:
                    ks = (min(ks[0], h), min(ks[1], w))
                    print("reducing Kernel")

                if stride[0] > h or stride[1] > w:
                    stride = (min(stride[0], h), min(stride[1], w))
                    print("reducing stride")

                fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)

                z = unfold(z)  # (bn, nc * prod(**ks), L)
                # 1. Reshape to img shape
                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

                # 2. apply model loop over last dim
                if isinstance(self.first_stage_model, VQModelInterface):
                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
                                                                 force_not_quantize=predict_cids or force_not_quantize)
                                   for i in range(z.shape[-1])]
                else:

                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
                                   for i in range(z.shape[-1])]

                o = torch.stack(output_list, axis=-1)  # # (bn, nc, ks[0], ks[1], L)
                o = o * weighting
                # Reverse 1. reshape to img shape
                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
                # stitch crops together
                decoded = fold(o)
                decoded = decoded / normalization  # norm is shape (1, 1, h, w)
                return decoded
            else:
                if isinstance(self.first_stage_model, VQModelInterface):
                    return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
                else:
                    return self.first_stage_model.decode(z)

        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
            else:
                return self.first_stage_model.decode(z)

    # same as above but without decorator
    def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, 'b h w c -> b c h w').contiguous()

        z = 1. / self.scale_factor * z

        if hasattr(self, "split_input_params"):
            if self.split_input_params["patch_distributed_vq"]:
                ks = self.split_input_params["ks"]  # eg. (128, 128)
                stride = self.split_input_params["stride"]  # eg. (64, 64)
                uf = self.split_input_params["vqf"]
                bs, nc, h, w = z.shape
                if ks[0] > h or ks[1] > w:
                    ks = (min(ks[0], h), min(ks[1], w))
                    print("reducing Kernel")

                if stride[0] > h or stride[1] > w:
                    stride = (min(stride[0], h), min(stride[1], w))
                    print("reducing stride")

                fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)

                z = unfold(z)  # (bn, nc * prod(**ks), L)
                # 1. Reshape to img shape
                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

                # 2. apply model loop over last dim
                if isinstance(self.first_stage_model, VQModelInterface):
                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
                                                                 force_not_quantize=predict_cids or force_not_quantize)
                                   for i in range(z.shape[-1])]
                else:

                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
                                   for i in range(z.shape[-1])]

                o = torch.stack(output_list, axis=-1)  # # (bn, nc, ks[0], ks[1], L)
                o = o * weighting
                # Reverse 1. reshape to img shape
                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
                # stitch crops together
                decoded = fold(o)
                decoded = decoded / normalization  # norm is shape (1, 1, h, w)
                return decoded
            else:
                if isinstance(self.first_stage_model, VQModelInterface):
                    return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
                else:
                    return self.first_stage_model.decode(z)

        else:
            if isinstance(self.first_stage_model, VQModelInterface):
                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
            else:
                return self.first_stage_model.decode(z)

    @torch.no_grad()
    def encode_first_stage(self, x):
        if hasattr(self, "split_input_params"):
            if self.split_input_params["patch_distributed_vq"]:
                ks = self.split_input_params["ks"]  # eg. (128, 128)
                stride = self.split_input_params["stride"]  # eg. (64, 64)
                df = self.split_input_params["vqf"]
                self.split_input_params['original_image_size'] = x.shape[-2:]
                bs, nc, h, w = x.shape
                if ks[0] > h or ks[1] > w:
                    ks = (min(ks[0], h), min(ks[1], w))
                    print("reducing Kernel")

                if stride[0] > h or stride[1] > w:
                    stride = (min(stride[0], h), min(stride[1], w))
                    print("reducing stride")

                fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
                z = unfold(x)  # (bn, nc * prod(**ks), L)
                # Reshape to img shape
                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

                output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
                               for i in range(z.shape[-1])]

                o = torch.stack(output_list, axis=-1)
                o = o * weighting

                # Reverse reshape to img shape
                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
                # stitch crops together
                decoded = fold(o)
                decoded = decoded / normalization
                return decoded

            else:
                return self.first_stage_model.encode(x)
        else:
            return self.first_stage_model.encode(x)

    def shared_step(self, batch, **kwargs):
        x, c = self.get_input(batch, self.first_stage_key)
        loss = self(x, c)
        return loss

    def forward(self, x, c, *args, **kwargs):
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        if self.model.conditioning_key is not None:
            assert c is not None
            if self.cond_stage_trainable:
                c = self.get_learned_conditioning(c)
            if self.shorten_cond_schedule:  # TODO: drop this option
                tc = self.cond_ids[t].to(self.device)
                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
        return self.p_losses(x, c, t, *args, **kwargs)

    def _rescale_annotations(self, bboxes, crop_coordinates):  # TODO: move to dataset
        def rescale_bbox(bbox):
            x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
            y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
            w = min(bbox[2] / crop_coordinates[2], 1 - x0)
            h = min(bbox[3] / crop_coordinates[3], 1 - y0)
            return x0, y0, w, h

        return [rescale_bbox(b) for b in bboxes]

    def apply_model(self, x_noisy, t, cond, return_ids=False):

        if isinstance(cond, dict):
            # hybrid case, cond is expected to be a dict
            pass
        else:
            if not isinstance(cond, list):
                cond = [cond]
            key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
            cond = {key: cond}

        if hasattr(self, "split_input_params"):
            assert len(cond) == 1  # todo can only deal with one conditioning atm
            assert not return_ids
            ks = self.split_input_params["ks"]  # eg. (128, 128)
            stride = self.split_input_params["stride"]  # eg. (64, 64)

            h, w = x_noisy.shape[-2:]

            fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)

            z = unfold(x_noisy)  # (bn, nc * prod(**ks), L)
            # Reshape to img shape
            z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )
            z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]

            if self.cond_stage_key in ["image", "LR_image", "segmentation",
                                       'bbox_img'] and self.model.conditioning_key:  # todo check for completeness
                c_key = next(iter(cond.keys()))  # get key
                c = next(iter(cond.values()))  # get value
                assert (len(c) == 1)  # todo extend to list with more than one elem
                c = c[0]  # get element

                c = unfold(c)
                c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

                cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]

            elif self.cond_stage_key == 'coordinates_bbox':
                assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'

                # assuming padding of unfold is always 0 and its dilation is always 1
                n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
                full_img_h, full_img_w = self.split_input_params['original_image_size']
                # as we are operating on latents, we need the factor from the original image size to the
                # spatial latent size to properly rescale the crops for regenerating the bbox annotations
                num_downs = self.first_stage_model.encoder.num_resolutions - 1
                rescale_latent = 2 ** (num_downs)

                # get top left positions of patches as conforming for the bbox tokenizer, therefore we
                # need to rescale the tl patch coordinates to be in between (0,1)
                tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
                                         rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
                                        for patch_nr in range(z.shape[-1])]

                # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
                patch_limits = [(x_tl, y_tl,
                                 rescale_latent * ks[0] / full_img_w,
                                 rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
                # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]

                # tokenize crop coordinates for the bounding boxes of the respective patches
                patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
                                      for bbox in patch_limits]  # list of length l with tensors of shape (1, 2)
                print(patch_limits_tknzd[0].shape)
                # cut tknzd crop position from conditioning
                assert isinstance(cond, dict), 'cond must be dict to be fed into model'
                cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
                print(cut_cond.shape)

                adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
                adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
                print(adapted_cond.shape)
                adapted_cond = self.get_learned_conditioning(adapted_cond)
                print(adapted_cond.shape)
                adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
                print(adapted_cond.shape)

                cond_list = [{'c_crossattn': [e]} for e in adapted_cond]

            else:
                cond_list = [cond for i in range(z.shape[-1])]  # Todo make this more efficient

            # apply model by loop over crops
            output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
            assert not isinstance(output_list[0],
                                  tuple)  # todo cant deal with multiple model outputs check this never happens

            o = torch.stack(output_list, axis=-1)
            o = o * weighting
            # Reverse reshape to img shape
            o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
            # stitch crops together
            x_recon = fold(o) / normalization

        else:
            x_recon = self.model(x_noisy, t, **cond)

        if isinstance(x_recon, tuple) and not return_ids:
            return x_recon[0]
        else:
            return x_recon

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
               extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def _prior_bpd(self, x_start):
        """
        Get the prior KL term for the variational lower-bound, measured in
        bits-per-dim.
        This term can't be optimized, as it only depends on the encoder.
        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        batch_size = x_start.shape[0]
        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
        kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
        return mean_flat(kl_prior) / np.log(2.0)

    def p_losses(self, x_start, cond, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_output = self.apply_model(x_noisy, t, cond)

        loss_dict = {}
        prefix = 'train' if self.training else 'val'

        if self.parameterization == "x0":
            target = x_start
        elif self.parameterization == "eps":
            target = noise
        else:
            raise NotImplementedError()

        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})

        logvar_t = self.logvar[t].to(self.device)
        loss = loss_simple / torch.exp(logvar_t) + logvar_t
        # loss = loss_simple / torch.exp(self.logvar) + self.logvar
        if self.learn_logvar:
            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
            loss_dict.update({'logvar': self.logvar.data.mean()})

        loss = self.l_simple_weight * loss.mean()

        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
        loss += (self.original_elbo_weight * loss_vlb)
        loss_dict.update({f'{prefix}/loss': loss})

        return loss, loss_dict

    def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
                        return_x0=False, score_corrector=None, corrector_kwargs=None):
        t_in = t
        model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)

        if score_corrector is not None:
            assert self.parameterization == "eps"
            model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)

        if return_codebook_ids:
            model_out, logits = model_out

        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        else:
            raise NotImplementedError()

        if clip_denoised:
            x_recon.clamp_(-1., 1.)
        if quantize_denoised:
            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        if return_codebook_ids:
            return model_mean, posterior_variance, posterior_log_variance, logits
        elif return_x0:
            return model_mean, posterior_variance, posterior_log_variance, x_recon
        else:
            return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
                 return_codebook_ids=False, quantize_denoised=False, return_x0=False,
                 temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
        b, *_, device = *x.shape, x.device
        outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
                                       return_codebook_ids=return_codebook_ids,
                                       quantize_denoised=quantize_denoised,
                                       return_x0=return_x0,
                                       score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
        if return_codebook_ids:
            raise DeprecationWarning("Support dropped.")
            model_mean, _, model_log_variance, logits = outputs
        elif return_x0:
            model_mean, _, model_log_variance, x0 = outputs
        else:
            model_mean, _, model_log_variance = outputs

        noise = noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

        if return_codebook_ids:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
        if return_x0:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
        else:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
                              img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
                              score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
                              log_every_t=None):
        if not log_every_t:
            log_every_t = self.log_every_t
        timesteps = self.num_timesteps
        if batch_size is not None:
            b = batch_size if batch_size is not None else shape[0]
            shape = [batch_size] + list(shape)
        else:
            b = batch_size = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=self.device)
        else:
            img = x_T
        intermediates = []
        if cond is not None:
            if isinstance(cond, dict):
                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
            else:
                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]

        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
                        total=timesteps) if verbose else reversed(
            range(0, timesteps))
        if type(temperature) == float:
            temperature = [temperature] * timesteps

        for i in iterator:
            ts = torch.full((b,), i, device=self.device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != 'hybrid'
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

            img, x0_partial = self.p_sample(img, cond, ts,
                                            clip_denoised=self.clip_denoised,
                                            quantize_denoised=quantize_denoised, return_x0=True,
                                            temperature=temperature[i], noise_dropout=noise_dropout,
                                            score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
            if mask is not None:
                assert x0 is not None
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img

            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(x0_partial)
            if callback: callback(i)
            if img_callback: img_callback(img, i)
        return img, intermediates

    @torch.no_grad()
    def p_sample_loop(self, cond, shape, return_intermediates=False,
                      x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, start_T=None,
                      log_every_t=None):

        if not log_every_t:
            log_every_t = self.log_every_t
        device = self.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        intermediates = [img]
        if timesteps is None:
            timesteps = self.num_timesteps

        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
            range(0, timesteps))

        if mask is not None:
            assert x0 is not None
            assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match

        for i in iterator:
            ts = torch.full((b,), i, device=device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != 'hybrid'
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

            img = self.p_sample(img, cond, ts,
                                clip_denoised=self.clip_denoised,
                                quantize_denoised=quantize_denoised)
            if mask is not None:
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img

            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(img)
            if callback: callback(i)
            if img_callback: img_callback(img, i)

        if return_intermediates:
            return img, intermediates
        return img

    @torch.no_grad()
    def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
               verbose=True, timesteps=None, quantize_denoised=False,
               mask=None, x0=None, shape=None, **kwargs):
        if shape is None:
            shape = (batch_size, self.channels, self.image_size, self.image_size)
        if cond is not None:
            if isinstance(cond, dict):
                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
            else:
                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
        return self.p_sample_loop(cond,
                                  shape,
                                  return_intermediates=return_intermediates, x_T=x_T,
                                  verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
                                  mask=mask, x0=x0)

    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):

        if ddim:
            ddim_sampler = DDIMSampler(self)
            shape = (self.channels, self.image_size, self.image_size)
            samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
                                                         shape, cond, verbose=False, **kwargs)

        else:
            samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
                                                 return_intermediates=True, **kwargs)

        return samples, intermediates

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                   plot_diffusion_rows=True, **kwargs):

        use_ddim = ddim_steps is not None

        log = dict()
        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
                                           return_first_stage_outputs=True,
                                           force_c_encode=True,
                                           return_original_cond=True,
                                           bs=N)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption"]:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
                log["conditioning"] = xc
            elif self.cond_stage_key == 'class_label':
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
                log['conditioning'] = xc
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with self.ema_scope("Plotting"):
                samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                         ddim_steps=ddim_steps, eta=ddim_eta)
                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

            if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
                    self.first_stage_model, IdentityFirstStage):
                # also display when quantizing x0 while sampling
                with self.ema_scope("Plotting Quantized Denoised"):
                    samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
                                                             ddim_steps=ddim_steps, eta=ddim_eta,
                                                             quantize_denoised=True)
                    # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
                    #                                      quantize_denoised=True)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_x0_quantized"] = x_samples

            if inpaint:
                # make a simple center square
                b, h, w = z.shape[0], z.shape[2], z.shape[3]
                mask = torch.ones(N, h, w).to(self.device)
                # zeros will be filled in
                mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
                mask = mask[:, None, ...]
                with self.ema_scope("Plotting Inpaint"):
                    samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                                 ddim_steps=ddim_steps, x0=z[:N], mask=mask)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_inpainting"] = x_samples
                log["mask"] = mask

                # outpaint
                with self.ema_scope("Plotting Outpaint"):
                    samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                                 ddim_steps=ddim_steps, x0=z[:N], mask=mask)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_outpainting"] = x_samples

        if plot_progressive_rows:
            with self.ema_scope("Plotting Progressives"):
                img, progressives = self.progressive_denoising(c,
                                                               shape=(self.channels, self.image_size, self.image_size),
                                                               batch_size=N)
            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
            log["progressive_row"] = prog_row

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        if self.cond_stage_trainable:
            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
            params = params + list(self.cond_stage_model.parameters())
        if self.learn_logvar:
            print('Diffusion model optimizing logvar')
            params.append(self.logvar)
        opt = torch.optim.AdamW(params, lr=lr)
        if self.use_scheduler:
            assert 'target' in self.scheduler_config
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [opt], scheduler
        return opt

    @torch.no_grad()
    def to_rgb(self, x):
        x = x.float()
        if not hasattr(self, "colorize"):
            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = nn.functional.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x


class DiffusionWrapperV1(pl.LightningModule):
    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc)
        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        else:
            raise NotImplementedError()

        return out


class Layout2ImgDiffusionV1(LatentDiffusionV1):
    # TODO: move all layout-specific hacks to this class
    def __init__(self, cond_stage_key, *args, **kwargs):
        assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
        super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)

    def log_images(self, batch, N=8, *args, **kwargs):
        logs = super().log_images(batch=batch, N=N, *args, **kwargs)

        key = 'train' if self.training else 'validation'
        dset = self.trainer.datamodule.datasets[key]
        mapper = dset.conditional_builders[self.cond_stage_key]

        bbox_imgs = []
        map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
        for tknzd_bbox in batch[self.cond_stage_key][:N]:
            bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
            bbox_imgs.append(bboximg)

        cond_img = torch.stack(bbox_imgs, dim=0)
        logs['bbox_image'] = cond_img
        return logs


setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1)
setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1)
setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1)
setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1)
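
Because the classes are attached under new, suffixed names rather than replacing ldm's own, the
stock LatentDiffusion is left untouched; only configs that explicitly name a ...V1 target pick up
the reinstated code path. A sketch of both resolutions side by side (assuming the hijack module has
already been imported):

    import ldm.models.diffusion.ddpm as ddpm

    v2 = ddpm.LatentDiffusion      # ldm's own class, unchanged
    v1 = ddpm.LatentDiffusionV1    # added by sd_hijack_ddpm_v1
    print(v1 is not v2)            # True: the hijack is additive, not a replacement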