SFT sample packing #3566
Conversation
I tested all the SFT notebooks on the main page README. I'd like to test a few more notebooks, including ones that we don't expect to use packing (multimodal training, RL training, ...).

Okay, I ran through all the notebooks displayed in the README. As expected, none of the non-text-only SFT notebooks utilized sample packing, and none of them hit any errors when switching to this branch.
window_size = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)

use_varlen = (
    seq_info is not None and past_key_value is None and window_size == (-1, -1)
Q: So when past_key_value is not None, i.e. the decoding phase of generation, do we always use flash_attn_func and not flash_attn_varlen? Is that what this is trying to do? Does that place an implicit assumption that, post-prefill, inputs are padded to neat shapes?
Correct! The flash attention varlen API doesn't support key/value caching, and we don't pack inputs during inference anyway, so we use the dense API. No such assumption AFAIK; not sure I follow you on that.
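For readers following along, here is a minimal sketch of that dispatch using the public flash-attn API; the helper name, the shape of seq_info, and the surrounding arguments are assumptions for illustration, not the actual unsloth implementation:

from flash_attn import flash_attn_func, flash_attn_varlen_func

def dispatch_attention(q, k, v, seq_info=None, past_key_value=None, causal=True):
    # Varlen path: only for packed training inputs, i.e. the packed sequence
    # boundaries are known and there is no KV cache to append to.
    use_varlen = seq_info is not None and past_key_value is None
    if use_varlen:
        # seq_info is assumed to carry (cu_seqlens, max_seqlen) from the collator;
        # q/k/v are flattened to (total_tokens, n_heads, head_dim) for this path.
        cu_seqlens, max_seqlen = seq_info
        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
            causal=causal,
        )
    # Decoding / unpacked path: dense kernel with q/k/v shaped
    # (bsz, seq_len, n_heads, head_dim), which works with the KV cache.
    return flash_attn_func(q, k, v, causal=causal)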
unsloth/utils/attention_dispatch.py
Outdated
    K: Tensor,
    V: Tensor,
) -> Tensor:
    """Run attention using config / context info."""
Can we please add a couple of lines of info on which backend is used when and why?
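Something along these lines might do it (a suggestion only, summarizing the backend choices mentioned in this thread; the signature and exact wording would need to match the final code):

from torch import Tensor

def run_attention(Q: Tensor, K: Tensor, V: Tensor) -> Tensor:  # other params omitted
    """Run attention using config / context info.

    Backend selection (summary):
    - flash-attn varlen: packed training batches where packed-sequence metadata
      is available and there is no KV cache, so attention never crosses sample
      boundaries without needing an explicit mask.
    - flash-attn dense: inference / decoding with a KV cache, which the varlen
      API does not support; inputs are not packed at inference time anyway.
    - xformers: fallback path, using a block-diagonal attention bias to emulate
      the varlen behaviour for packed inputs.
    """
    ...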
)

if requires_grad:
    K_mod = K_mod.reshape(bsz, kv_seq_len, n_heads, head_dim)
Where is the expansion happening? Otherwise, reshaping to n_heads would not be possible, right?
Just before this, K_mod / V_mod are viewed as (bsz, kv_seq_len, n_kv_heads, 1, head_dim) and then expanded:

K_mod = K_t.view(bsz, kv_seq_len, config.n_kv_heads, 1, head_dim)
V_mod = V_t.view(bsz, kv_seq_len, config.n_kv_heads, 1, head_dim)
K_mod = K_mod.expand(
    bsz, kv_seq_len, config.n_kv_heads, config.n_groups, head_dim
)
V_mod = V_mod.expand(
    bsz, kv_seq_len, config.n_kv_heads, config.n_groups, head_dim
)
if requires_grad:
    K_mod = K_mod.reshape(bsz, kv_seq_len, n_heads, head_dim)
    V_mod = V_mod.reshape(bsz, kv_seq_len, n_heads, head_dim)
else:
    Q_mod = Q_t.view(
        bsz, q_len, config.n_kv_heads, config.n_groups, head_dim
    )
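A tiny, self-contained demo of that view → expand → reshape pattern (toy shapes and standalone names, purely illustrative):

import torch

bsz, kv_seq_len, head_dim = 2, 16, 64
n_kv_heads, n_groups = 4, 2          # GQA: n_heads = n_kv_heads * n_groups
n_heads = n_kv_heads * n_groups

K_t = torch.randn(bsz, kv_seq_len, n_kv_heads, head_dim)

# Insert a singleton group axis, then expand it; expand returns a view, no copy yet.
K_mod = K_t.view(bsz, kv_seq_len, n_kv_heads, 1, head_dim)
K_mod = K_mod.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)

# reshape collapses (n_kv_heads, n_groups) into n_heads; since the expanded tensor
# is not contiguous, this materializes a copy with each KV head repeated per group.
K_full = K_mod.reshape(bsz, kv_seq_len, n_heads, head_dim)

assert K_full.shape == (bsz, kv_seq_len, n_heads, head_dim)
# Heads 0 and 1 belong to the same KV head's group, so they share values.
assert torch.equal(K_full[:, :, 0], K_full[:, :, 1])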
unsloth/utils/attention_dispatch.py
Outdated
    attn_bias, XFORMERS_BLOCK_DIAG_CLS
)

if config.n_groups != 1 and not requires_grad and has_block:
Nit: I feel like we should condense these into the same check, even at the cost of repeating the single out = xformers_attn(...) call.
I combined them but kept the single xformers_attn call, hopefully I understood you correctly.
This PR patches the SFTTrainer class's constructor to auto-enable sample packing when applicable (i.e., when the user is not training a multimodal model or passing in a custom data collator). Sample packing is a good default in SFT since it reduces the amount of zero-padding and thereby increases training token/s throughput.
Follow-up to #3525; can merge after that PR is in. (Edit: I closed that PR; I think we should just merge this one.) This needs to be extensively tested so we don't break any existing notebooks / scripts.
Edit: This PR no longer auto-enables sample packing as it could slightly change the SFT training dynamics. Users must now pass
packing=True into their SFTConfig objects.
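For reference, opting in would look roughly like this (a minimal sketch against the TRL-style SFTConfig / SFTTrainer API; the model and dataset variables are placeholders):

from trl import SFTConfig, SFTTrainer

training_args = SFTConfig(
    output_dir="outputs",
    per_device_train_batch_size=2,
    packing=True,              # opt in to sample packing explicitly
)

trainer = SFTTrainer(
    model=model,               # an already-loaded (e.g. Unsloth-patched) causal LM
    train_dataset=train_dataset,
    args=training_args,
)
trainer.train()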