
FSDP Doesn't Work with model.generate() #30228

@QiyaoWei

System Info

  • transformers version: 4.39.3
  • Platform: Linux-5.15.0-1059-azure-x86_64-with-glibc2.31
  • Python version: 3.10.11
  • Huggingface_hub version: 0.22.2
  • Safetensors version: 0.4.2
  • Accelerate version: 0.29.2
  • Accelerate config:
    - compute_environment: LOCAL_MACHINE
    - distributed_type: FSDP
    - mixed_precision: bf16
    - use_cpu: False
    - debug: False
    - num_processes: 2
    - machine_rank: 0
    - num_machines: 1
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - enable_cpu_affinity: False
    - fsdp_config: {'fsdp_auto_wrap_policy': 'TRANSFORMER_BASED_WRAP', 'fsdp_backward_prefetch': 'BACKWARD_PRE', 'fsdp_cpu_ram_efficient_loading': True, 'fsdp_forward_prefetch': False, 'fsdp_offload_params': True, 'fsdp_sharding_strategy': 'FULL_SHARD', 'fsdp_state_dict_type': 'SHARDED_STATE_DICT', 'fsdp_sync_module_states': True, 'fsdp_use_orig_params': True}
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env: []
  • PyTorch version (GPU?): 2.2.2 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: YES
  • Using distributed or parallel set-up in script?: YES
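The Accelerate config listed above does not appear to be what the MWE actually runs with: the script starts its workers with torch.multiprocessing.spawn and passes its own options to FSDP directly. As a hedged illustration, the sketch below diffs a subset of the fsdp_config values against the kwargs hard-coded in FSDPTrainer. The key mapping (KEY_MAP) is an assumption made only for this comparison, not Accelerate's own translation logic, and values are written as plain strings and bools so it runs without PyTorch.

```python
# Hypothetical comparison: a subset of the Accelerate fsdp_config (from System Info)
# versus the kwargs that FSDPTrainer passes to FSDP in the MWE.
accelerate_fsdp_config = {
    "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
    "fsdp_backward_prefetch": "BACKWARD_PRE",
    "fsdp_forward_prefetch": False,
    "fsdp_offload_params": True,
    "fsdp_sharding_strategy": "FULL_SHARD",
    "fsdp_sync_module_states": True,
    "fsdp_use_orig_params": True,
}

# Values from shared_fsdp_kwargs in FSDPTrainer, as plain strings/bools.
script_fsdp_kwargs = {
    "backward_prefetch": "BACKWARD_PRE",
    "offload_params": False,  # CPUOffload(offload_params=False)
    "sharding_strategy": "FULL_SHARD",
    "sync_module_states": False,
    "use_orig_params": False,
}

# Assumed mapping from Accelerate keys to the script's argument names.
KEY_MAP = {
    "fsdp_backward_prefetch": "backward_prefetch",
    "fsdp_offload_params": "offload_params",
    "fsdp_sharding_strategy": "sharding_strategy",
    "fsdp_sync_module_states": "sync_module_states",
    "fsdp_use_orig_params": "use_orig_params",
}


def config_mismatches(accel: dict, script: dict) -> dict:
    """Return {script_key: (accelerate_value, script_value)} for settings that differ."""
    diffs = {}
    for accel_key, script_key in KEY_MAP.items():
        if accel_key in accel and script_key in script and accel[accel_key] != script[script_key]:
            diffs[script_key] = (accel[accel_key], script[script_key])
    return diffs


if __name__ == "__main__":
    mismatches = config_mismatches(accelerate_fsdp_config, script_fsdp_kwargs)
    for name, (accel_val, script_val) in sorted(mismatches.items()):
        print(f"{name}: accelerate={accel_val!r} script={script_val!r}")
```

Run as a script, it reports that offload_params, sync_module_states and use_orig_params differ (True in the Accelerate config, False in the script). Whether any of these differences affects the generate() error is not established here.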

Who can help?

@ArthurZucker @younesbelkada @gante (relevant to text models and generate())

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I am trying to use FSDP, but calling generate() on the FSDP-wrapped model raises an error, while the same call on the unwrapped model works. MWE below:

import functools
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    CPUOffload,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

class BasicTrainer(object):
    def __init__(self):
        # Baseline: load the plain (unwrapped) model and call generate(); this works.
        model_name_or_path = "openai-community/gpt2-large"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.policy = AutoModelForCausalLM.from_pretrained(model_name_or_path)

        tokenized = self.tokenizer("hi there", return_tensors="pt").to(self.policy.device)
        print(self.policy.generate(**tokenized))
        return

    def train(self):
        pass
    
def get_block_class_from_model(model: torch.nn.Module, block_class_name: str) -> torch.nn.Module:
    """Get the class of a block from a model, using the block's class name."""
    for module in model.modules():
        if module.__class__.__name__ == block_class_name:
            return module.__class__
    raise ValueError(f"Could not find block class {block_class_name} in model {model}")

def init_distributed(rank: int, world_size: int, master_addr: str = 'localhost', port: int = 12355, backend: str = 'nccl'):
    print(rank, 'initializing distributed')
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(port)
    torch.distributed.init_process_group(backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def worker_main(rank: int, world_size: int):

    init_distributed(rank, world_size)
    print(f'Creating trainer on process {rank} with world size {world_size}')
    trainer = FSDPTrainer()

    # trainer.train()
    # trainer.save()


def main():

    world_size = torch.cuda.device_count()
    print('starting', world_size, 'processes for FSDP training')
    torch.multiprocessing.spawn(worker_main, nprocs=world_size, args=(world_size,), join=True)
        
class FSDPTrainer(BasicTrainer):
    def __init__(self):
        # Runs the unwrapped BasicTrainer baseline first (this part succeeds).
        super().__init__()

        model_name_or_path = "openai-community/gpt2-large"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.policy = AutoModelForCausalLM.from_pretrained(model_name_or_path).to('cuda')

        # Wrap each GPT2Block as its own FSDP unit; all other modules stay in the root unit.
        wrap_class = get_block_class_from_model(self.policy, "GPT2Block")
        model_auto_wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={wrap_class},)

        shared_fsdp_kwargs = dict(
            auto_wrap_policy=model_auto_wrap_policy,
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            cpu_offload=CPUOffload(offload_params=False),
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            ignored_modules=None,
            limit_all_gathers=False,
            use_orig_params=False,
            sync_module_states=False
        )
        # mp_dtype = None means no dtype casting for params, reductions or buffers.
        mp_dtype = None
        policy_mp_policy = MixedPrecision(param_dtype=mp_dtype, reduce_dtype=mp_dtype, buffer_dtype=mp_dtype)
        self.policy = FSDP(self.policy, **shared_fsdp_kwargs, mixed_precision=policy_mp_policy)

        # This generate() call raises the RuntimeError shown in the traceback.
        tokenized = self.tokenizer("hi there", return_tensors="pt").to(self.policy.device)
        print(self.policy.generate(**tokenized))
        return
    
if __name__ == '__main__':

    main()  # BasicTrainer works, but FSDPTrainer errors

Full output (the two tensors are printed by BasicTrainer on each rank before FSDPTrainer fails):

starting 2 processes for FSDP training
1 initializing distributed
0 initializing distributed
Creating trainer on process 0 with world size 2
Creating trainer on process 1 with world size 2
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
/anaconda/lib/python3.10/site-packages/transformers/generation/utils.py:1132: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
  warnings.warn(
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
/anaconda/lib/python3.10/site-packages/transformers/generation/utils.py:1132: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
  warnings.warn(
tensor([[5303,  612,  318,  257, 1256,  286,  670,  284,  307, 1760,   13,  198,
          198,    1, 1135,  423,  284,  787, 1654,  326]])
tensor([[5303,  612,  318,  257, 1256,  286,  670,  284,  307, 1760,   13,  198,
          198,    1, 1135,  423,  284,  787, 1654,  326]])
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Traceback (most recent call last):
  File "/home/azureuser/f-divergence-dpo/mwe.py", line 98, in <module>
    main()
  File "/home/azureuser/f-divergence-dpo/mwe.py", line 61, in main
    torch.multiprocessing.spawn(worker_main, nprocs=world_size, args=(world_size,), join=True)
  File "/anaconda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 241, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
  File "/anaconda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 197, in start_processes
    while not context.join():
  File "/anaconda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 158, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/anaconda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 68, in _wrap
    fn(i, *args)
  File "/home/azureuser/f-divergence-dpo/mwe.py", line 51, in worker_main
    trainer = FSDPTrainer()
  File "/home/azureuser/f-divergence-dpo/mwe.py", line 92, in __init__
    print(self.policy.generate(**tokenized))
  File "/anaconda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/anaconda/lib/python3.10/site-packages/transformers/generation/utils.py", line 1527, in generate
    result = self._greedy_search(
  File "/anaconda/lib/python3.10/site-packages/transformers/generation/utils.py", line 2411, in _greedy_search
    outputs = self(
  File "/anaconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/anaconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/anaconda/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1074, in forward
    transformer_outputs = self.transformer(
  File "/anaconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/anaconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/anaconda/lib/python3.10/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 837, in forward
    inputs_embeds = self.wte(input_ids)
  File "/anaconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/anaconda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/anaconda/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 163, in forward
    return F.embedding(
  File "/anaconda/lib/python3.10/site-packages/torch/nn/functional.py", line 2237, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.
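Reading the traceback, one plausible explanation (not confirmed in this thread) is that generate is not defined on the FSDP wrapper, so attribute lookup falls through to the wrapped GPT2LMHeadModel (FSDP's __getattr__ forwards missing attributes to the wrapped module). Inside _greedy_search, self(...) then calls the unwrapped model's forward and skips FSDP's own forward, which is where the sharded parameters of the root unit (here including wte) would be all-gathered. The toy sketch below reproduces only that control flow in plain Python; ToyModel, ToyShardedWrapper, FreedStorageError and the wrapper's summon_full_params method are hypothetical stand-ins, not PyTorch classes.

```python
import contextlib


class FreedStorageError(RuntimeError):
    """Hypothetical stand-in for PyTorch's 'data is not allocated yet' error."""


class ToyModel:
    """Hypothetical stand-in for GPT2LMHeadModel: one 'weight' plus a generate()."""

    def __init__(self):
        self.weight = [1.0, 2.0, 3.0]  # full parameter before sharding

    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        if self.weight is None:  # storage was freed by the wrapper
            raise FreedStorageError("parameter storage is not allocated")
        return sum(w * x for w in self.weight)

    def generate(self, x):
        # Like the generation loop: calls self(...), i.e. the *inner* model.
        return [self(x) for _ in range(2)]


class ToyShardedWrapper:
    """Hypothetical stand-in for FSDP: params are gathered only in its own __call__."""

    def __init__(self, module):
        self._wrapped = module
        self._full_weight = module.weight
        module.weight = None  # "shard": the wrapped module's copy is freed

    def _unshard(self):
        self._wrapped.weight = self._full_weight

    def _reshard(self):
        self._wrapped.weight = None

    def __call__(self, x):
        self._unshard()
        try:
            return self._wrapped(x)
        finally:
            self._reshard()

    def __getattr__(self, name):
        # Like FSDP.__getattr__: unknown attributes fall through to the wrapped module.
        return getattr(self._wrapped, name)

    @contextlib.contextmanager
    def summon_full_params(self):
        # Toy analogue of gathering full params for the duration of a block.
        self._unshard()
        try:
            yield
        finally:
            self._reshard()


if __name__ == "__main__":
    model = ToyShardedWrapper(ToyModel())
    print(model(1.0))  # goes through the wrapper's __call__, so params are gathered
    try:
        model.generate(1.0)  # resolves to ToyModel.generate and bypasses the wrapper
    except FreedStorageError as err:
        print("generate failed:", err)
    with model.summon_full_params():
        print(model.generate(1.0))  # params stay gathered inside the block
```

In real PyTorch, one workaround used in some FSDP training code is to run generation inside FSDP.summon_full_params(self.policy, writeback=False, recurse=False). With recurse=False only the root unit's parameters are gathered, while the per-GPT2Block FSDP units still gather their own parameters when they are called. This has not been verified against the exact setup in this issue.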

Expected behavior

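The code above should run without error: generate() on the FSDP-wrapped model should produce output just as it does for the unwrapped model in BasicTrainer.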
