Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
aa60fc66
Unverified
Commit
aa60fc66
authored
Jan 19, 2023
by
AUTOMATIC1111
Committed by
GitHub
Jan 19, 2023
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #6922 from brkirch/cumsum-fix
Improve cumsum fix for MPS
parents
0f9cacaa
a255dac4
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
4 deletions
+7
-4
devices.py
modules/devices.py
+7
-4
No files found.
modules/devices.py
View file @
aa60fc66
...
...
@@ -169,8 +169,10 @@ orig_Tensor_cumsum = torch.Tensor.cumsum
def
cumsum_fix
(
input
,
cumsum_func
,
*
args
,
**
kwargs
):
if
input
.
device
.
type
==
'mps'
:
output_dtype
=
kwargs
.
get
(
'dtype'
,
input
.
dtype
)
if
any
(
output_dtype
==
broken_dtype
for
broken_dtype
in
[
torch
.
bool
,
torch
.
int8
,
torch
.
int16
,
torch
.
int64
])
:
if
output_dtype
==
torch
.
int64
:
return
cumsum_func
(
input
.
cpu
(),
*
args
,
**
kwargs
)
.
to
(
input
.
device
)
elif
cumsum_needs_bool_fix
and
output_dtype
==
torch
.
bool
or
cumsum_needs_int_fix
and
(
output_dtype
==
torch
.
int8
or
output_dtype
==
torch
.
int16
):
return
cumsum_func
(
input
.
to
(
torch
.
int32
),
*
args
,
**
kwargs
)
.
to
(
torch
.
int64
)
return
cumsum_func
(
input
,
*
args
,
**
kwargs
)
...
...
@@ -181,9 +183,10 @@ if has_mps():
torch
.
nn
.
functional
.
layer_norm
=
layer_norm_fix
torch
.
Tensor
.
numpy
=
numpy_fix
elif
version
.
parse
(
torch
.
__version__
)
>
version
.
parse
(
"1.13.1"
):
if
not
torch
.
Tensor
([
1
,
2
])
.
to
(
torch
.
device
(
"mps"
))
.
equal
(
torch
.
Tensor
([
1
,
1
])
.
to
(
torch
.
device
(
"mps"
))
.
cumsum
(
0
,
dtype
=
torch
.
int16
)):
torch
.
cumsum
=
lambda
input
,
*
args
,
**
kwargs
:
(
cumsum_fix
(
input
,
orig_cumsum
,
*
args
,
**
kwargs
)
)
torch
.
Tensor
.
cumsum
=
lambda
self
,
*
args
,
**
kwargs
:
(
cumsum_fix
(
self
,
orig_Tensor_cumsum
,
*
args
,
**
kwargs
)
)
cumsum_needs_int_fix
=
not
torch
.
Tensor
([
1
,
2
])
.
to
(
torch
.
device
(
"mps"
))
.
equal
(
torch
.
ShortTensor
([
1
,
1
])
.
to
(
torch
.
device
(
"mps"
))
.
cumsum
(
0
))
cumsum_needs_bool_fix
=
not
torch
.
BoolTensor
([
True
,
True
])
.
to
(
device
=
torch
.
device
(
"mps"
),
dtype
=
torch
.
int64
)
.
equal
(
torch
.
BoolTensor
([
True
,
False
])
.
to
(
torch
.
device
(
"mps"
))
.
cumsum
(
0
))
torch
.
cumsum
=
lambda
input
,
*
args
,
**
kwargs
:
(
cumsum_fix
(
input
,
orig_cumsum
,
*
args
,
**
kwargs
)
)
torch
.
Tensor
.
cumsum
=
lambda
self
,
*
args
,
**
kwargs
:
(
cumsum_fix
(
self
,
orig_Tensor_cumsum
,
*
args
,
**
kwargs
)
)
orig_narrow
=
torch
.
narrow
torch
.
narrow
=
lambda
*
args
,
**
kwargs
:
(
orig_narrow
(
*
args
,
**
kwargs
)
.
clone
()
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment