Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
21642000
Commit
21642000
authored
Jan 12, 2023
by
Shondoit
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add PNG alpha channel as weight maps to data entries
parent
c4bfd20f
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
38 additions
and
13 deletions
+38
-13
dataset.py
modules/textual_inversion/dataset.py
+38
-13
No files found.
modules/textual_inversion/dataset.py
View file @
21642000
...
@@ -19,9 +19,10 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*")
...
@@ -19,9 +19,10 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class
DatasetEntry
:
class
DatasetEntry
:
def
__init__
(
self
,
filename
=
None
,
filename_text
=
None
,
latent_dist
=
None
,
latent_sample
=
None
,
cond
=
None
,
cond_text
=
None
,
pixel_values
=
None
):
def
__init__
(
self
,
filename
=
None
,
filename_text
=
None
,
latent_dist
=
None
,
latent_sample
=
None
,
cond
=
None
,
cond_text
=
None
,
pixel_values
=
None
,
weight
=
None
):
self
.
filename
=
filename
self
.
filename
=
filename
self
.
filename_text
=
filename_text
self
.
filename_text
=
filename_text
self
.
weight
=
weight
self
.
latent_dist
=
latent_dist
self
.
latent_dist
=
latent_dist
self
.
latent_sample
=
latent_sample
self
.
latent_sample
=
latent_sample
self
.
cond
=
cond
self
.
cond
=
cond
...
@@ -56,10 +57,16 @@ class PersonalizedBase(Dataset):
...
@@ -56,10 +57,16 @@ class PersonalizedBase(Dataset):
print
(
"Preparing dataset..."
)
print
(
"Preparing dataset..."
)
for
path
in
tqdm
.
tqdm
(
self
.
image_paths
):
for
path
in
tqdm
.
tqdm
(
self
.
image_paths
):
alpha_channel
=
None
if
shared
.
state
.
interrupted
:
if
shared
.
state
.
interrupted
:
raise
Exception
(
"interrupted"
)
raise
Exception
(
"interrupted"
)
try
:
try
:
image
=
Image
.
open
(
path
)
.
convert
(
'RGB'
)
image
=
Image
.
open
(
path
)
#Currently does not work for single color transparency
#We would need to read image.info['transparency'] for that
if
'A'
in
image
.
getbands
():
alpha_channel
=
image
.
getchannel
(
'A'
)
image
=
image
.
convert
(
'RGB'
)
if
not
varsize
:
if
not
varsize
:
image
=
image
.
resize
((
width
,
height
),
PIL
.
Image
.
BICUBIC
)
image
=
image
.
resize
((
width
,
height
),
PIL
.
Image
.
BICUBIC
)
except
Exception
:
except
Exception
:
...
@@ -87,17 +94,33 @@ class PersonalizedBase(Dataset):
...
@@ -87,17 +94,33 @@ class PersonalizedBase(Dataset):
with
devices
.
autocast
():
with
devices
.
autocast
():
latent_dist
=
model
.
encode_first_stage
(
torchdata
.
unsqueeze
(
dim
=
0
))
latent_dist
=
model
.
encode_first_stage
(
torchdata
.
unsqueeze
(
dim
=
0
))
if
latent_sampling_method
==
"once"
or
(
latent_sampling_method
==
"deterministic"
and
not
isinstance
(
latent_dist
,
DiagonalGaussianDistribution
)):
#Perform latent sampling, even for random sampling.
latent_sample
=
model
.
get_first_stage_encoding
(
latent_dist
)
.
squeeze
()
.
to
(
devices
.
cpu
)
#We need the sample dimensions for the weights
latent_sampling_method
=
"once"
if
latent_sampling_method
==
"deterministic"
:
entry
=
DatasetEntry
(
filename
=
path
,
filename_text
=
filename_text
,
latent_sample
=
latent_sample
)
if
isinstance
(
latent_dist
,
DiagonalGaussianDistribution
):
elif
latent_sampling_method
==
"deterministic"
:
# Works only for DiagonalGaussianDistribution
# Works only for DiagonalGaussianDistribution
latent_dist
.
std
=
0
latent_dist
.
std
=
0
else
:
latent_sample
=
model
.
get_first_stage_encoding
(
latent_dist
)
.
squeeze
()
.
to
(
devices
.
cpu
)
latent_sampling_method
=
"once"
entry
=
DatasetEntry
(
filename
=
path
,
filename_text
=
filename_text
,
latent_sample
=
latent_sample
)
latent_sample
=
model
.
get_first_stage_encoding
(
latent_dist
)
.
squeeze
()
.
to
(
devices
.
cpu
)
elif
latent_sampling_method
==
"random"
:
entry
=
DatasetEntry
(
filename
=
path
,
filename_text
=
filename_text
,
latent_dist
=
latent_dist
)
if
alpha_channel
is
not
None
:
channels
,
*
latent_size
=
latent_sample
.
shape
weight_img
=
alpha_channel
.
resize
(
latent_size
)
npweight
=
np
.
array
(
weight_img
)
.
astype
(
np
.
float32
)
#Repeat for every channel in the latent sample
weight
=
torch
.
tensor
([
npweight
]
*
channels
)
.
reshape
([
channels
]
+
latent_size
)
#Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.
weight
-=
weight
.
min
()
weight
/=
weight
.
mean
()
else
:
#If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later
weight
=
torch
.
ones
([
channels
]
+
latent_size
)
if
latent_sampling_method
==
"random"
:
entry
=
DatasetEntry
(
filename
=
path
,
filename_text
=
filename_text
,
latent_dist
=
latent_dist
,
weight
=
weight
)
else
:
entry
=
DatasetEntry
(
filename
=
path
,
filename_text
=
filename_text
,
latent_sample
=
latent_sample
,
weight
=
weight
)
if
not
(
self
.
tag_drop_out
!=
0
or
self
.
shuffle_tags
):
if
not
(
self
.
tag_drop_out
!=
0
or
self
.
shuffle_tags
):
entry
.
cond_text
=
self
.
create_text
(
filename_text
)
entry
.
cond_text
=
self
.
create_text
(
filename_text
)
...
@@ -110,6 +133,7 @@ class PersonalizedBase(Dataset):
...
@@ -110,6 +133,7 @@ class PersonalizedBase(Dataset):
del
torchdata
del
torchdata
del
latent_dist
del
latent_dist
del
latent_sample
del
latent_sample
del
weight
self
.
length
=
len
(
self
.
dataset
)
self
.
length
=
len
(
self
.
dataset
)
self
.
groups
=
list
(
groups
.
values
())
self
.
groups
=
list
(
groups
.
values
())
...
@@ -195,6 +219,7 @@ class BatchLoader:
...
@@ -195,6 +219,7 @@ class BatchLoader:
self
.
cond_text
=
[
entry
.
cond_text
for
entry
in
data
]
self
.
cond_text
=
[
entry
.
cond_text
for
entry
in
data
]
self
.
cond
=
[
entry
.
cond
for
entry
in
data
]
self
.
cond
=
[
entry
.
cond
for
entry
in
data
]
self
.
latent_sample
=
torch
.
stack
([
entry
.
latent_sample
for
entry
in
data
])
.
squeeze
(
1
)
self
.
latent_sample
=
torch
.
stack
([
entry
.
latent_sample
for
entry
in
data
])
.
squeeze
(
1
)
self
.
weight
=
torch
.
stack
([
entry
.
weight
for
entry
in
data
])
.
squeeze
(
1
)
#self.emb_index = [entry.emb_index for entry in data]
#self.emb_index = [entry.emb_index for entry in data]
#print(self.latent_sample.device)
#print(self.latent_sample.device)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment