Fix IndexError in finetune scripts when last logit chunk becomes empty #2141
Conversation
Co-authored-by: Borda <6035284+Borda@users.noreply.github.com>
@copilot try to split the test into multiple smaller ones, each focused only on a single case
Pull request overview
Fixes an edge case in the finetuning training loop when using lm_head_chunk_size=128: after shifting logits to align with next-token targets, the final logits chunk can become empty (when the last chunk has length 1), which can break downstream loss computation.
Changes:
- Drop an empty final logits chunk after the shift step in all affected finetune scripts (see the sketch after this list).
- Add tests covering the “empty last chunk after shift” scenario.
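For context, here is a minimal sketch of the affected pattern. The helper name training_step_loss is hypothetical, and the real scripts handle batching and masking differently, but the shift and the new empty-chunk drop follow this shape:

```python
import torch

from litgpt.utils import chunked_cross_entropy


def training_step_loss(model, input_ids: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: shift logits, drop an empty final chunk, compute the loss."""
    # With lm_head_chunk_size=128 the model returns the logits as a list of chunks.
    logits = model(input_ids, lm_head_chunk_size=128)
    # Shift the final chunk so each position predicts the *next* token.
    logits[-1] = logits[-1][..., :-1, :]
    # The fix: a final chunk of length 1 becomes empty after the shift and would
    # break chunked_cross_entropy, so drop it before computing the loss.
    if logits[-1].size(1) == 0:
        logits = logits[:-1]
    return chunked_cross_entropy(logits, targets[..., 1:])
```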
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| litgpt/finetune/lora.py | Removes empty final logits chunk after shift before loss computation. |
| litgpt/finetune/lora_legacy.py | Same empty-chunk removal in legacy LoRA finetune loop. |
| litgpt/finetune/adapter.py | Same empty-chunk removal in Adapter finetune loop. |
| litgpt/finetune/adapter_v2.py | Same empty-chunk removal in Adapter v2 finetune loop. |
| extensions/xla/finetune/adapter.py | Same empty-chunk removal for XLA training loop. |
| tests/test_utils.py | Adds tests for loss correctness when the last logits chunk becomes empty after shift. |
```python
def test_chunked_cross_entropy_chunking_and_shift_T129():
    """Test chunking and shift logic for T=129, ensuring last chunk becomes empty."""
    B, V = 2, 100
    lm_head_chunk_size = 128
    T = 129
    regular_logits = torch.randn(B, T, V)
    targets = torch.randint(0, V, (B, T))

    # Simulate chunking
    chunked_logits = list(regular_logits.split(lm_head_chunk_size, dim=1))
    assert len(chunked_logits) == 2
    assert chunked_logits[-1].size(1) == 1

    # Apply shift
    chunked_logits[-1] = chunked_logits[-1][..., :-1, :]
    assert chunked_logits[-1].size(1) == 0


def test_chunked_cross_entropy_empty_removal_T129():
    """Test empty chunk removal for T=129, resulting in a single remaining chunk."""
    B, V = 2, 100
    lm_head_chunk_size = 128
    T = 129
    regular_logits = torch.randn(B, T, V)

    chunked_logits = list(regular_logits.split(lm_head_chunk_size, dim=1))
    chunked_logits[-1] = chunked_logits[-1][..., :-1, :]  # Shift to make empty

    # Apply removal
    if chunked_logits[-1].size(1) == 0:
        chunked_logits = chunked_logits[:-1]
    assert len(chunked_logits) == 1


def test_chunked_cross_entropy_loss_computation_T129():
    """Test loss computation for T=129 after empty chunk removal."""
    B, V = 2, 100
    lm_head_chunk_size = 128
    T = 129
    regular_logits = torch.randn(B, T, V)
    targets = torch.randint(0, V, (B, T))

    chunked_logits = list(regular_logits.split(lm_head_chunk_size, dim=1))
    chunked_logits[-1] = chunked_logits[-1][..., :-1, :]
    if chunked_logits[-1].size(1) == 0:
        chunked_logits = chunked_logits[:-1]
```
The block of tests starting here (T129/T257 chunking/shift/empty_removal/loss/baseline comparisons) duplicates the coverage already added earlier in this file by test_chunked_cross_entropy_with_empty_last_chunk_*. This adds a lot of redundant runtime and maintenance overhead; please remove these duplicates or collapse into a single parametrized test for T={129,257}.
Suggested change:

```python
@pytest.mark.parametrize("T", (129, 257))
def test_chunked_cross_entropy_empty_last_chunk_chunking_shift_and_loss(T):
    """Test empty-last-chunk handling after shift for sequence lengths just over a chunk boundary."""
    B, V = 2, 100
    lm_head_chunk_size = 128
    regular_logits = torch.randn(B, T, V)
    targets = torch.randint(0, V, (B, T))
    chunked_logits = list(regular_logits.split(lm_head_chunk_size, dim=1))
    assert len(chunked_logits) == (T + lm_head_chunk_size - 1) // lm_head_chunk_size
    assert chunked_logits[-1].size(1) == 1
    chunked_logits[-1] = chunked_logits[-1][..., :-1, :]
    assert chunked_logits[-1].size(1) == 0
    if chunked_logits[-1].size(1) == 0:
        chunked_logits = chunked_logits[:-1]
    assert len(chunked_logits) == T // lm_head_chunk_size
```
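If the duplicates are collapsed this way, the remaining cases can be run selectively with pytest's `-k` filter, for example `pytest tests/test_utils.py -k chunked_cross_entropy`.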
Problem
Users encountered an IndexError when running finetune_lora and other finetune scripts on Gemma models (and potentially other models with certain sequence lengths). The error occurred during training when processing batches with specific sequence lengths.

Root Cause

When lm_head_chunk_size=128 is used in the model forward pass, logits are returned as a list of chunks. The finetune code applies a shift operation to align predictions with targets.
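As a toy illustration of that shift (shapes only, not the script's actual code), a final chunk of length 1 ends up with length 0:

```python
import torch

last_chunk = torch.randn(2, 1, 100)  # (batch, seq_len=1, vocab): final chunk of a T=129 sequence
shifted = last_chunk[..., :-1, :]    # drop the last position to align with next-token targets
print(shifted.shape)                 # torch.Size([2, 0, 100]) -- empty along the sequence dim
```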
The bug: when the last chunk has a sequence length of exactly 1, the slicing operation [..., :-1, :] creates a chunk with length 0. This empty chunk then causes chunked_cross_entropy to fail because PyTorch's split() function doesn't accept a split size of 0. This occurs when the total sequence length is of the form 128*n + 1 (e.g., 1, 129, 257, 385, etc.).

Solution
Added a simple check after the shift operation to remove empty chunks:
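A sketch of that check, mirroring the logic exercised by the new tests (the variable holding the chunk list may be named differently in each script):

```python
# After the shift, the final chunk may have sequence length 0; drop it so every
# chunk passed to chunked_cross_entropy has a non-zero length.
if logits[-1].size(1) == 0:
    logits = logits[:-1]
```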
This ensures all chunks passed to the loss function have non-zero sequence length.
Changes
- litgpt/finetune/lora.py
- litgpt/finetune/adapter.py
- litgpt/finetune/adapter_v2.py
- litgpt/finetune/lora_legacy.py
- extensions/xla/finetune/adapter.py
- test_chunked_cross_entropy_with_empty_last_chunk() to validate the fix

Impact
Fixes #[issue_number]
Warning

Firewall rules blocked me from connecting to one or more addresses. I tried to connect to the following addresses, but was blocked by firewall rules:

- huggingface.co
- python3 (dns block)

If you need me to access, download, or install something from one of these locations, you can either:
Original prompt
This section details the original issue you should resolve.
<issue_title>finetune_lora on gemma bug</issue_title>
<issue_description>

### Bug description
I am trying to use finetune_lora to do PEFT on a Gemma model, and I have tried:
Both encounter IndexError. I have also tried other model series like QwQ and Llama, etc.; all look fine.
It seems some people have met a similar bug (but on gemma-7b); I'm not sure whether it's the same problem.
What operating system are you using?
Linux
LitGPT Version
litgpt 0.5.7 & litgpt 0.5.8.dev1