Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Contribute to GitLab
Sign in / Register
Toggle navigation
S
stable-diffusion-webui
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
stable-diffusion-webui
Commits
b1198153
Commit
b1198153
authored
Jan 05, 2023
by
brkirch
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Use narrow instead of dynamic_slice
parent
3bfe2bb5
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
15 deletions
+19
-15
sub_quadratic_attention.py
modules/sub_quadratic_attention.py
+19
-15
No files found.
modules/sub_quadratic_attention.py
View file @
b1198153
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
# credit:
# credit:
# Amin Rezaei (original author)
# Amin Rezaei (original author)
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
# brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
# implementation of:
# implementation of:
# Self-attention Does Not Need O(n2) Memory":
# Self-attention Does Not Need O(n2) Memory":
# https://arxiv.org/abs/2112.05682v2
# https://arxiv.org/abs/2112.05682v2
...
@@ -16,13 +17,13 @@ from torch.utils.checkpoint import checkpoint
...
@@ -16,13 +17,13 @@ from torch.utils.checkpoint import checkpoint
import
math
import
math
from
typing
import
Optional
,
NamedTuple
,
Protocol
,
List
from
typing
import
Optional
,
NamedTuple
,
Protocol
,
List
def
dynamic_slice
(
def
narrow_trunc
(
x
:
Tensor
,
input
:
Tensor
,
starts
:
List
[
int
],
dim
:
int
,
sizes
:
List
[
int
],
start
:
int
,
length
:
int
)
->
Tensor
:
)
->
Tensor
:
slicing
=
[
slice
(
start
,
start
+
size
)
for
start
,
size
in
zip
(
starts
,
sizes
)]
return
torch
.
narrow
(
input
,
dim
,
start
,
length
if
input
.
shape
[
dim
]
>=
start
+
length
else
input
.
shape
[
dim
]
-
start
)
return
x
[
slicing
]
class
AttnChunk
(
NamedTuple
):
class
AttnChunk
(
NamedTuple
):
exp_values
:
Tensor
exp_values
:
Tensor
...
@@ -76,15 +77,17 @@ def _query_chunk_attention(
...
@@ -76,15 +77,17 @@ def _query_chunk_attention(
_
,
_
,
v_channels_per_head
=
value
.
shape
_
,
_
,
v_channels_per_head
=
value
.
shape
def
chunk_scanner
(
chunk_idx
:
int
)
->
AttnChunk
:
def
chunk_scanner
(
chunk_idx
:
int
)
->
AttnChunk
:
key_chunk
=
dynamic_slice
(
key_chunk
=
narrow_trunc
(
key
,
key
,
(
0
,
chunk_idx
,
0
),
1
,
(
batch_x_heads
,
kv_chunk_size
,
k_channels_per_head
)
chunk_idx
,
kv_chunk_size
)
)
value_chunk
=
dynamic_slice
(
value_chunk
=
narrow_trunc
(
value
,
value
,
(
0
,
chunk_idx
,
0
),
1
,
(
batch_x_heads
,
kv_chunk_size
,
v_channels_per_head
)
chunk_idx
,
kv_chunk_size
)
)
return
summarize_chunk
(
query
,
key_chunk
,
value_chunk
)
return
summarize_chunk
(
query
,
key_chunk
,
value_chunk
)
...
@@ -161,10 +164,11 @@ def efficient_dot_product_attention(
...
@@ -161,10 +164,11 @@ def efficient_dot_product_attention(
kv_chunk_size
=
max
(
kv_chunk_size
,
kv_chunk_size_min
)
kv_chunk_size
=
max
(
kv_chunk_size
,
kv_chunk_size_min
)
def
get_query_chunk
(
chunk_idx
:
int
)
->
Tensor
:
def
get_query_chunk
(
chunk_idx
:
int
)
->
Tensor
:
return
dynamic_slice
(
return
narrow_trunc
(
query
,
query
,
(
0
,
chunk_idx
,
0
),
1
,
(
batch_x_heads
,
min
(
query_chunk_size
,
q_tokens
),
q_channels_per_head
)
chunk_idx
,
min
(
query_chunk_size
,
q_tokens
)
)
)
summarize_chunk
:
SummarizeChunk
=
partial
(
_summarize_chunk
,
scale
=
scale
)
summarize_chunk
:
SummarizeChunk
=
partial
(
_summarize_chunk
,
scale
=
scale
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment