Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
c4bfd20f
Commit
c4bfd20f
authored
Jan 12, 2023
by
Shondoit
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Hijack to add weighted_forward to model: return loss * weight map
parent
3715ece0
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
52 additions
and
0 deletions
+52
-0
sd_hijack.py
modules/sd_hijack.py
+52
-0
No files found.
modules/sd_hijack.py
View file @
c4bfd20f
import
torch
import
torch
from
torch.nn.functional
import
silu
from
torch.nn.functional
import
silu
from
types
import
MethodType
import
modules.textual_inversion.textual_inversion
import
modules.textual_inversion.textual_inversion
from
modules
import
devices
,
sd_hijack_optimizations
,
shared
,
sd_hijack_checkpoint
from
modules
import
devices
,
sd_hijack_optimizations
,
shared
,
sd_hijack_checkpoint
...
@@ -76,6 +77,54 @@ def fix_checkpoint():
...
@@ -76,6 +77,54 @@ def fix_checkpoint():
pass
pass
def
weighted_loss
(
sd_model
,
pred
,
target
,
mean
=
True
):
#Calculate the weight normally, but ignore the mean
loss
=
sd_model
.
_old_get_loss
(
pred
,
target
,
mean
=
False
)
#Check if we have weights available
weight
=
getattr
(
sd_model
,
'_custom_loss_weight'
,
None
)
if
weight
is
not
None
:
loss
*=
weight
#Return the loss, as mean if specified
return
loss
.
mean
()
if
mean
else
loss
def
weighted_forward
(
sd_model
,
x
,
c
,
w
,
*
args
,
**
kwargs
):
try
:
#Temporarily append weights to a place accessible during loss calc
sd_model
.
_custom_loss_weight
=
w
#Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
#Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
if
not
hasattr
(
sd_model
,
'_old_get_loss'
):
sd_model
.
_old_get_loss
=
sd_model
.
get_loss
sd_model
.
get_loss
=
MethodType
(
weighted_loss
,
sd_model
)
#Run the standard forward function, but with the patched 'get_loss'
return
sd_model
.
forward
(
x
,
c
,
*
args
,
**
kwargs
)
finally
:
try
:
#Delete temporary weights if appended
del
sd_model
.
_custom_loss_weight
except
AttributeError
as
e
:
pass
#If we have an old loss function, reset the loss function to the original one
if
hasattr
(
sd_model
,
'_old_get_loss'
):
sd_model
.
get_loss
=
sd_model
.
_old_get_loss
del
sd_model
.
_old_get_loss
def
apply_weighted_forward
(
sd_model
):
#Add new function 'weighted_forward' that can be called to calc weighted loss
sd_model
.
weighted_forward
=
MethodType
(
weighted_forward
,
sd_model
)
def
undo_weighted_forward
(
sd_model
):
try
:
del
sd_model
.
weighted_forward
except
AttributeError
as
e
:
pass
class
StableDiffusionModelHijack
:
class
StableDiffusionModelHijack
:
fixes
=
None
fixes
=
None
comments
=
[]
comments
=
[]
...
@@ -104,6 +153,8 @@ class StableDiffusionModelHijack:
...
@@ -104,6 +153,8 @@ class StableDiffusionModelHijack:
m
.
cond_stage_model
.
model
.
token_embedding
=
EmbeddingsWithFixes
(
m
.
cond_stage_model
.
model
.
token_embedding
,
self
)
m
.
cond_stage_model
.
model
.
token_embedding
=
EmbeddingsWithFixes
(
m
.
cond_stage_model
.
model
.
token_embedding
,
self
)
m
.
cond_stage_model
=
sd_hijack_open_clip
.
FrozenOpenCLIPEmbedderWithCustomWords
(
m
.
cond_stage_model
,
self
)
m
.
cond_stage_model
=
sd_hijack_open_clip
.
FrozenOpenCLIPEmbedderWithCustomWords
(
m
.
cond_stage_model
,
self
)
apply_weighted_forward
(
m
)
self
.
optimization_method
=
apply_optimizations
()
self
.
optimization_method
=
apply_optimizations
()
self
.
clip
=
m
.
cond_stage_model
self
.
clip
=
m
.
cond_stage_model
...
@@ -132,6 +183,7 @@ class StableDiffusionModelHijack:
...
@@ -132,6 +183,7 @@ class StableDiffusionModelHijack:
m
.
cond_stage_model
=
m
.
cond_stage_model
.
wrapped
m
.
cond_stage_model
=
m
.
cond_stage_model
.
wrapped
undo_optimizations
()
undo_optimizations
()
undo_weighted_forward
(
m
)
self
.
apply_circular
(
False
)
self
.
apply_circular
(
False
)
self
.
layers
=
None
self
.
layers
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment