
Commit 4936d0f

Add xformers to training scripts (#103)
1 parent 7dd0467 commit 4936d0f

3 files changed

Lines changed: 88 additions & 4 deletions


lora_diffusion/xformers_utils.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+import functools
+
+import torch
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.utils.import_utils import is_xformers_available
+
+from .lora import LoraInjectedLinear
+
+if is_xformers_available():
+    import xformers
+    import xformers.ops
+else:
+    xformers = None
+
+
+@functools.cache
+def test_xformers_backwards(size):
+    @torch.enable_grad()
+    def _grad(size):
+        q = torch.randn((1, 4, size), device="cuda")
+        k = torch.randn((1, 4, size), device="cuda")
+        v = torch.randn((1, 4, size), device="cuda")
+
+        q = q.detach().requires_grad_()
+        k = k.detach().requires_grad_()
+        v = v.detach().requires_grad_()
+
+        out = xformers.ops.memory_efficient_attention(q, k, v)
+        loss = out.sum(2).mean(0).sum()
+
+        return torch.autograd.grad(loss, v)
+
+    try:
+        _grad(size)
+        print(size, "pass")
+        return True
+    except Exception as e:
+        print(size, "fail")
+        return False
+
+
+def set_use_memory_efficient_attention_xformers(
+    module: torch.nn.Module, valid: bool
+) -> None:
+    def fn_test_dim_head(module: torch.nn.Module):
+        if isinstance(module, BasicTransformerBlock):
+            # dim_head isn't stored anywhere, so back-calculate
+            source = module.attn1.to_v
+            if isinstance(source, LoraInjectedLinear):
+                source = source.linear
+
+            dim_head = source.out_features // module.attn1.heads
+
+            result = test_xformers_backwards(dim_head)
+
+            # If dim_head > dim_head_max, turn xformers off
+            if not result:
+                module.set_use_memory_efficient_attention_xformers(False)
+
+        for child in module.children():
+            fn_test_dim_head(child)
+
+    if not is_xformers_available() and valid:
+        print("XFormers is not available. Skipping.")
+        return
+
+    module.set_use_memory_efficient_attention_xformers(valid)
+
+    if valid:
+        fn_test_dim_head(module)
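
For orientation, here is a minimal sketch of how the new helper is meant to be used; it simply mirrors the wiring added to the two training scripts below. The StableDiffusionPipeline load and the model ID are illustrative assumptions, not part of this commit; only the set_use_memory_efficient_attention_xformers calls come from the code above.

from diffusers import StableDiffusionPipeline

from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers

# Illustrative model load; any diffusers UNet/VAE pair is handled the same way.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")

# Turn memory-efficient attention on; any BasicTransformerBlock whose head dim
# fails the xformers backward probe is switched back off for that block.
set_use_memory_efficient_attention_xformers(pipe.unet, True)
set_use_memory_efficient_attention_xformers(pipe.vae, True)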

train_lora_dreambooth.py

Lines changed: 9 additions & 2 deletions
@@ -36,9 +36,9 @@
     save_lora_weight,
     save_safeloras,
 )
-
-from torch.utils.data import Dataset
+from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers
 from PIL import Image
+from torch.utils.data import Dataset
 from torchvision import transforms
 
 from pathlib import Path
@@ -450,6 +450,9 @@ def parse_args(input_args=None):
         required=False,
         help="Should images be resized to --resolution before training?",
     )
+    parser.add_argument(
+        "--use_xformers", action="store_true", help="Whether or not to use xformers"
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -615,6 +618,10 @@ def main(args):
         )
         break
 
+    if args.use_xformers:
+        set_use_memory_efficient_attention_xformers(unet, True)
+        set_use_memory_efficient_attention_xformers(vae, True)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         if args.train_text_encoder:

train_lora_w_ti.py

Lines changed: 9 additions & 2 deletions
@@ -36,9 +36,9 @@
     save_lora_weight,
     extract_lora_ups_down,
 )
-
-from torch.utils.data import Dataset
+from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers
 from PIL import Image
+from torch.utils.data import Dataset
 from torchvision import transforms
 
 from pathlib import Path
@@ -575,6 +575,9 @@ def parse_args(input_args=None):
         required=False,
         help="Should images be resized to --resolution before training?",
     )
+    parser.add_argument(
+        "--use_xformers", action="store_true", help="Whether or not to use xformers"
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -774,6 +777,10 @@ def main(args):
         print("Before training: text encoder First Layer lora down", _down.weight.data)
         break
 
+    if args.use_xformers:
+        set_use_memory_efficient_attention_xformers(unet, True)
+        set_use_memory_efficient_attention_xformers(vae, True)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         if args.train_text_encoder:
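
Both scripts gate the new behaviour behind --use_xformers (a store_true flag), so existing invocations are unchanged. If the per-head-dim probe needs checking on a particular GPU, test_xformers_backwards from xformers_utils.py can also be called on its own; a hedged sketch, assuming CUDA and xformers are installed, with head sizes that are merely illustrative:

from lora_diffusion.xformers_utils import test_xformers_backwards

# Each call builds a small q/k/v batch on the GPU, runs
# xformers.ops.memory_efficient_attention, and tries the backward pass.
for dim_head in (40, 64, 80, 160):
    ok = test_xformers_backwards(dim_head)  # also prints "<size> pass" / "<size> fail"
    print(f"dim_head={dim_head}: {'supported' if ok else 'unsupported'}")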
