Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
3bfc6c07
Unverified
Commit
3bfc6c07
authored
Dec 24, 2022
by
AUTOMATIC1111
Committed by
GitHub
Dec 24, 2022
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #5810 from brkirch/fix-training-mps
Training fixes for MPS
parents
f0dfed2a
cca16373
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
6 deletions
+15
-6
devices.py
modules/devices.py
+9
-0
safe.py
modules/safe.py
+6
-6
No files found.
modules/devices.py
View file @
3bfc6c07
...
@@ -125,7 +125,16 @@ def layer_norm_fix(*args, **kwargs):
...
@@ -125,7 +125,16 @@ def layer_norm_fix(*args, **kwargs):
return
orig_layer_norm
(
*
args
,
**
kwargs
)
return
orig_layer_norm
(
*
args
,
**
kwargs
)
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
orig_tensor_numpy
=
torch
.
Tensor
.
numpy
def
numpy_fix
(
self
,
*
args
,
**
kwargs
):
if
self
.
requires_grad
:
self
=
self
.
detach
()
return
orig_tensor_numpy
(
self
,
*
args
,
**
kwargs
)
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
if
has_mps
()
and
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
"1.13"
):
if
has_mps
()
and
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
"1.13"
):
torch
.
Tensor
.
to
=
tensor_to_fix
torch
.
Tensor
.
to
=
tensor_to_fix
torch
.
nn
.
functional
.
layer_norm
=
layer_norm_fix
torch
.
nn
.
functional
.
layer_norm
=
layer_norm_fix
torch
.
Tensor
.
numpy
=
numpy_fix
modules/safe.py
View file @
3bfc6c07
...
@@ -37,16 +37,16 @@ class RestrictedUnpickler(pickle.Unpickler):
...
@@ -37,16 +37,16 @@ class RestrictedUnpickler(pickle.Unpickler):
if
module
==
'collections'
and
name
==
'OrderedDict'
:
if
module
==
'collections'
and
name
==
'OrderedDict'
:
return
getattr
(
collections
,
name
)
return
getattr
(
collections
,
name
)
if
module
==
'torch._utils'
and
name
in
[
'_rebuild_tensor_v2'
,
'_rebuild_parameter'
]:
if
module
==
'torch._utils'
and
name
in
[
'_rebuild_tensor_v2'
,
'_rebuild_parameter'
,
'_rebuild_device_tensor_from_numpy'
]:
return
getattr
(
torch
.
_utils
,
name
)
return
getattr
(
torch
.
_utils
,
name
)
if
module
==
'torch'
and
name
in
[
'FloatStorage'
,
'HalfStorage'
,
'IntStorage'
,
'LongStorage'
,
'DoubleStorage'
,
'ByteStorage'
]:
if
module
==
'torch'
and
name
in
[
'FloatStorage'
,
'HalfStorage'
,
'IntStorage'
,
'LongStorage'
,
'DoubleStorage'
,
'ByteStorage'
,
'float32'
]:
return
getattr
(
torch
,
name
)
return
getattr
(
torch
,
name
)
if
module
==
'torch.nn.modules.container'
and
name
in
[
'ParameterDict'
]:
if
module
==
'torch.nn.modules.container'
and
name
in
[
'ParameterDict'
]:
return
getattr
(
torch
.
nn
.
modules
.
container
,
name
)
return
getattr
(
torch
.
nn
.
modules
.
container
,
name
)
if
module
==
'numpy.core.multiarray'
and
name
==
'scalar'
:
if
module
==
'numpy.core.multiarray'
and
name
in
[
'scalar'
,
'_reconstruct'
]
:
return
numpy
.
core
.
multiarray
.
scalar
return
getattr
(
numpy
.
core
.
multiarray
,
name
)
if
module
==
'numpy'
and
name
==
'dtype'
:
if
module
==
'numpy'
and
name
in
[
'dtype'
,
'ndarray'
]
:
return
numpy
.
dtype
return
getattr
(
numpy
,
name
)
if
module
==
'_codecs'
and
name
==
'encode'
:
if
module
==
'_codecs'
and
name
==
'encode'
:
return
encode
return
encode
if
module
==
"pytorch_lightning.callbacks"
and
name
==
'model_checkpoint'
:
if
module
==
"pytorch_lightning.callbacks"
and
name
==
'model_checkpoint'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment