stable-diffusion-webui · Commit 30228c67

Unverified commit, authored Feb 04, 2023 by AUTOMATIC1111 and committed by GitHub on Feb 04, 2023.

Merge pull request #7461 from brkirch/mac-fixes

Move Mac related code to separate file

Parents: c4b9ed1a, 4306659c

Showing 3 changed files with 60 additions and 61 deletions (+60 −61).
modules/devices.py            +7  −45
modules/mac_specific.py       +53 −0
modules/sd_samplers_common.py +0  −16
modules/devices.py (view file @ 30228c67)

-import sys, os, shlex
+import sys
 import contextlib
 import torch
 from modules import errors
-from modules.sd_hijack_utils import CondFunc
-from packaging import version

+if sys.platform == "darwin":
+    from modules import mac_specific

-# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
-# check `getattr` and try it for compatibility
 def has_mps() -> bool:
-    if not getattr(torch, 'has_mps', False):
-        return False
-    try:
-        torch.zeros(1).to(torch.device("mps"))
-        return True
-    except Exception:
+    if sys.platform != "darwin":
         return False
+    else:
+        return mac_specific.has_mps

 def extract_device_id(args, name):
     for x in range(len(args)):
...
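The hunk above turns devices.has_mps() into a thin platform gate: the MPS probe itself now lives in mac_specific, and the import is guarded so non-macOS platforms never load that module. A minimal caller sketch (hypothetical usage code, not part of this commit):

import torch
from modules import devices

# Fall back to CPU unless the platform-gated check reports a working MPS backend.
device = torch.device("mps") if devices.has_mps() else torch.device("cpu")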
@@ -155,36 +150,3 @@ def test_for_nans(x, where):
     message += " Use --disable-nan-check commandline argument to disable this check."

     raise NansException(message)
-
-
-# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
-def cumsum_fix(input, cumsum_func, *args, **kwargs):
-    if input.device.type == 'mps':
-        output_dtype = kwargs.get('dtype', input.dtype)
-        if output_dtype == torch.int64:
-            return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
-        elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
-            return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
-    return cumsum_func(input, *args, **kwargs)
-
-
-if has_mps():
-    if version.parse(torch.__version__) < version.parse("1.13"):
-        # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
-
-        # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
-        CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
-            lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
-        # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
-        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
-            lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
-        # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
-        CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
-    elif version.parse(torch.__version__) > version.parse("1.13.1"):
-        cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
-        cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
-        cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
-        CondFunc('torch.cumsum', cumsum_fix_func, None)
-        CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
-        CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
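Everything removed above reappears in the new module below. For context, a hedged sketch of the failure mode cumsum_fix guards against (values are illustrative and the helper name is hypothetical): on affected PyTorch builds, int64 cumulative sums on MPS return wrong results, so the fix computes them on the CPU and moves the result back.

import torch

# Known-good reference computed on the CPU.
t = torch.ones(4, dtype=torch.int64)
reference = t.cumsum(0)  # tensor([1, 2, 3, 4])

# What cumsum_fix does for int64 inputs that live on MPS: hop to the CPU,
# run cumsum there, then return the result to the original device.
def cumsum_via_cpu(x):
    return x.cpu().cumsum(0).to(x.device)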
modules/mac_specific.py
0 → 100644
View file @
30228c67
import torch
from modules import paths
from modules.sd_hijack_utils import CondFunc
from packaging import version


# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
# check `getattr` and try it for compatibility
def check_for_mps() -> bool:
    if not getattr(torch, 'has_mps', False):
        return False
    try:
        torch.zeros(1).to(torch.device("mps"))
        return True
    except Exception:
        return False
has_mps = check_for_mps()


# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
def cumsum_fix(input, cumsum_func, *args, **kwargs):
    if input.device.type == 'mps':
        output_dtype = kwargs.get('dtype', input.dtype)
        if output_dtype == torch.int64:
            return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
        elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
            return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
    return cumsum_func(input, *args, **kwargs)


if has_mps:
    # MPS fix for randn in torchsde
    CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')

    if version.parse(torch.__version__) < version.parse("1.13"):
        # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working

        # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
        CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
            lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
        # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
            lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
        # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
        CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
    elif version.parse(torch.__version__) > version.parse("1.13.1"):
        cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
        cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
        cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
        CondFunc('torch.cumsum', cumsum_fix_func, None)
        CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
        CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
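All of these patches are installed through CondFunc from modules.sd_hijack_utils: given a dotted path, a substitute, and a condition, it swaps in the substitute only for calls the condition matches (a None condition, as in the cumsum patches, means always). A simplified sketch of that pattern, assuming only plain module-level functions (the project's real helper also resolves attribute paths such as torch.Tensor.to; the function name here is hypothetical):

import importlib

def cond_func_sketch(orig_path, sub_func, cond_func):
    # Resolve the module and attribute named by the dotted path.
    module_path, _, attr = orig_path.rpartition('.')
    module = importlib.import_module(module_path)
    orig = getattr(module, attr)

    def wrapper(*args, **kwargs):
        # Route through the substitute when the condition matches (or always,
        # when no condition was given); otherwise call the original.
        if cond_func is None or cond_func(orig, *args, **kwargs):
            return sub_func(orig, *args, **kwargs)
        return orig(*args, **kwargs)

    setattr(module, attr, wrapper)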
modules/sd_samplers_common.py (view file @ 30228c67)
...
@@ -2,7 +2,6 @@ from collections import namedtuple
 import numpy as np
 import torch
 from PIL import Image
-import torchsde._brownian.brownian_interval
 from modules import devices, processing, images, sd_vae_approx
 from modules.shared import opts, state
...
@@ -61,18 +60,3 @@ def store_latent(decoded):
 class InterruptedException(BaseException):
     pass
-
-
-# MPS fix for randn in torchsde
-# XXX move this to separate file for MPS
-def torchsde_randn(size, dtype, device, seed):
-    if device.type == 'mps':
-        generator = torch.Generator(devices.cpu).manual_seed(int(seed))
-        return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
-    else:
-        generator = torch.Generator(device).manual_seed(int(seed))
-        return torch.randn(size, dtype=dtype, device=device, generator=generator)
-
-torchsde._brownian.brownian_interval._randn = torchsde_randn
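The deleted torchsde_randn (whose "# XXX move this to separate file for MPS" note this commit resolves) survives as the torchsde CondFunc in mac_specific.py, using the same trick: draw the noise on the CPU with an explicitly seeded generator, then move it to the target device, so MPS runs reproduce the CPU noise sequence. A standalone sketch of that trick (the function name is hypothetical):

import torch

def seeded_randn_via_cpu(size, dtype, device, seed):
    # Seed a CPU generator so the sequence is reproducible across backends.
    generator = torch.Generator(torch.device("cpu")).manual_seed(int(seed))
    noise = torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=generator)
    # Move the CPU-drawn noise to the requested device (e.g. "mps").
    return noise.to(device)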