Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
cc90dcc9
Unverified
Commit
cc90dcc9
authored
Nov 27, 2022
by
AUTOMATIC1111
Committed by
GitHub
Nov 27, 2022
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #4918 from brkirch/pytorch-fixes
Fixes for PyTorch 1.12.1 when using MPS
parents
10923f9b
e247b740
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
27 additions
and
10 deletions
+27
-10
devices.py
modules/devices.py
+24
-7
esrgan_model.py
modules/esrgan_model.py
+1
-1
scunet_model.py
modules/scunet_model.py
+1
-1
swinir_model.py
modules/swinir_model.py
+1
-1
No files found.
modules/devices.py
View file @
cc90dcc9
...
...
@@ -2,9 +2,10 @@ import sys, os, shlex
import
contextlib
import
torch
from
modules
import
errors
from
packaging
import
version
# has_mps is only available in nightly pytorch (for now) and
Mas
OS 12.3+.
# has_mps is only available in nightly pytorch (for now) and
mac
OS 12.3+.
# check `getattr` and try it for compatibility
def
has_mps
()
->
bool
:
if
not
getattr
(
torch
,
'has_mps'
,
False
):
...
...
@@ -99,9 +100,25 @@ def autocast(disable=False):
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
def
mps_contiguous
(
input_tensor
,
device
):
return
input_tensor
.
contiguous
()
if
device
.
type
==
'mps'
else
input_tensor
def
mps_contiguous_to
(
input_tensor
,
device
):
return
mps_contiguous
(
input_tensor
,
device
)
.
to
(
device
)
orig_tensor_to
=
torch
.
Tensor
.
to
def
tensor_to_fix
(
self
,
*
args
,
**
kwargs
):
if
self
.
device
.
type
!=
'mps'
and
\
((
len
(
args
)
>
0
and
isinstance
(
args
[
0
],
torch
.
device
)
and
args
[
0
]
.
type
==
'mps'
)
or
\
(
isinstance
(
kwargs
.
get
(
'device'
),
torch
.
device
)
and
kwargs
[
'device'
]
.
type
==
'mps'
)):
self
=
self
.
contiguous
()
return
orig_tensor_to
(
self
,
*
args
,
**
kwargs
)
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
orig_layer_norm
=
torch
.
nn
.
functional
.
layer_norm
def
layer_norm_fix
(
*
args
,
**
kwargs
):
if
len
(
args
)
>
0
and
isinstance
(
args
[
0
],
torch
.
Tensor
)
and
args
[
0
]
.
device
.
type
==
'mps'
:
args
=
list
(
args
)
args
[
0
]
=
args
[
0
]
.
contiguous
()
return
orig_layer_norm
(
*
args
,
**
kwargs
)
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
if
has_mps
()
and
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
"1.13"
):
torch
.
Tensor
.
to
=
tensor_to_fix
torch
.
nn
.
functional
.
layer_norm
=
layer_norm_fix
modules/esrgan_model.py
View file @
cc90dcc9
...
...
@@ -199,7 +199,7 @@ def upscale_without_tiling(model, img):
img
=
img
[:,
:,
::
-
1
]
img
=
np
.
ascontiguousarray
(
np
.
transpose
(
img
,
(
2
,
0
,
1
)))
/
255
img
=
torch
.
from_numpy
(
img
)
.
float
()
img
=
devices
.
mps_contiguous_to
(
img
.
unsqueeze
(
0
),
devices
.
device_esrgan
)
img
=
img
.
unsqueeze
(
0
)
.
to
(
devices
.
device_esrgan
)
with
torch
.
no_grad
():
output
=
model
(
img
)
output
=
output
.
squeeze
()
.
float
()
.
cpu
()
.
clamp_
(
0
,
1
)
.
numpy
()
...
...
modules/scunet_model.py
View file @
cc90dcc9
...
...
@@ -54,7 +54,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
img
=
img
[:,
:,
::
-
1
]
img
=
np
.
moveaxis
(
img
,
2
,
0
)
/
255
img
=
torch
.
from_numpy
(
img
)
.
float
()
img
=
devices
.
mps_contiguous_to
(
img
.
unsqueeze
(
0
),
device
)
img
=
img
.
unsqueeze
(
0
)
.
to
(
device
)
with
torch
.
no_grad
():
output
=
model
(
img
)
...
...
modules/swinir_model.py
View file @
cc90dcc9
...
...
@@ -111,7 +111,7 @@ def upscale(
img
=
img
[:,
:,
::
-
1
]
img
=
np
.
moveaxis
(
img
,
2
,
0
)
/
255
img
=
torch
.
from_numpy
(
img
)
.
float
()
img
=
devices
.
mps_contiguous_to
(
img
.
unsqueeze
(
0
),
devices
.
device_swinir
)
img
=
img
.
unsqueeze
(
0
)
.
to
(
devices
.
device_swinir
)
with
torch
.
no_grad
(),
precision_scope
(
"cuda"
):
_
,
_
,
h_old
,
w_old
=
img
.
size
()
h_pad
=
(
h_old
//
window_size
+
1
)
*
window_size
-
h_old
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment