Skip to content

Conversation

@djsaunde
Copy link
Collaborator

@djsaunde djsaunde commented Nov 7, 2025

This PR patches the SFTTrainer class' constructor to auto-enable sample packing when applicable (i.e., when the user is not training a multi-modal model or passing in a custom data collator). Sample packing is a good default in SFT since we reduce the amount of zero-padding we see and increase the training token/s throughput as a result.

Followup to #3525; can merge after that PR is. I closed that PR, I think we should just merge this one.

Needs to be extensively tested so we don't break any existing notebooks / scripts.

Edit: This PR no longer auto-enables sample packing as it could slightly change the SFT training dynamics. Users must now pass packing=True into their SFTConfig objects.

@djsaunde djsaunde self-assigned this Nov 7, 2025
@djsaunde djsaunde removed the request for review from danielhanchen November 7, 2025 19:03
@djsaunde djsaunde changed the title auto-enable sample packing auto-enable SFT sample packing Nov 7, 2025
@djsaunde djsaunde force-pushed the auto-packing branch 2 times, most recently from 072de80 to efe4424 Compare November 10, 2025 20:23
@djsaunde djsaunde force-pushed the auto-packing branch 3 times, most recently from f352814 to 9cd0b26 Compare November 19, 2025 19:14
@djsaunde djsaunde mentioned this pull request Nov 20, 2025
1 task
@djsaunde djsaunde marked this pull request as ready for review November 20, 2025 13:47
@djsaunde
Copy link
Collaborator Author

I tested all the SFT notebooks on the main page README. I'd like to test a few more notebooks, including ones that we don't expect to use packing (multimodal training, RL training, ...).

@djsaunde
Copy link
Collaborator Author

Okay, I ran through all the notebooks displayed in the README. As expected, non of the non-text-only SFT notebooks utilized sample packing, and none of them hit any errors when switching to this branch.

@djsaunde djsaunde changed the title auto-enable SFT sample packing SFT sample packing Nov 21, 2025
window_size = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)

use_varlen = (
seq_info is not None and past_key_value is None and window_size == (-1, -1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: So when past_key_value is not None aka decoding phase of generation, do we always use flash_attn_func and not flash_attn_varlen? Is that what this is trying to do? Does that place an implicit assumption that post prefill, inputs are padded to neat shapes?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correct! the flash attention varlen API doesn't support key/value caching, and we don't pack inputs during inference anyways, so we use the dense API. no such assumption AFAIK, not sure I follow you on that.

K: Tensor,
V: Tensor,
) -> Tensor:
"""Run attention using config / context info."""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we please add a couple of lines of info on which backend is used when and why?

)

if requires_grad:
K_mod = K_mod.reshape(bsz, kv_seq_len, n_heads, head_dim)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the expansion happening? Otherwise reshaping to n_heads would not be possible right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just before this, K_mod / V_mod are viewed as (bsz, kv_seq_len, n_kv_heads, 1, head_dim) and then expanded:

            K_mod = K_t.view(bsz, kv_seq_len, config.n_kv_heads, 1, head_dim)
            V_mod = V_t.view(bsz, kv_seq_len, config.n_kv_heads, 1, head_dim)
            K_mod = K_mod.expand(
                bsz, kv_seq_len, config.n_kv_heads, config.n_groups, head_dim
            )
            V_mod = V_mod.expand(
                bsz, kv_seq_len, config.n_kv_heads, config.n_groups, head_dim
            )

            if requires_grad:
                K_mod = K_mod.reshape(bsz, kv_seq_len, n_heads, head_dim)
                V_mod = V_mod.reshape(bsz, kv_seq_len, n_heads, head_dim)
            else:
                Q_mod = Q_t.view(
                    bsz, q_len, config.n_kv_heads, config.n_groups, head_dim
                )

attn_bias, XFORMERS_BLOCK_DIAG_CLS
)

if config.n_groups != 1 and not requires_grad and has_block:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Feel like we should condense these into the same check at the cost of repeating the code of single out = xformers_attn()...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I combined them but kept the single xformers_attn call, hopefully I understood you correctly.

@danielhanchen danielhanchen merged commit 50325e0 into unslothai:main Dec 10, 2025
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants