
Bug in async_grpo_full_finetune recipe: TrainingWorker receives identical advantages #2943

@jiatong-yu


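I was running the GSM8k example with the async_grpo_full_finetune_distributed.py recipe on 4 GPUs on a single node. I added logging to track which advantages the training worker pulls from the replay buffer. The PostProcessingWorker logs three distinct advantage values across the group of 16, but the TrainingWorker sees the same value (0.6057) in all 16 positions: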

(PostProcessingWorker pid=4057218) [Step 0] Rewards: 4.062 | Success: 0.344
(PostProcessingWorker pid=4057218)   Advantages: [0:-1.01, 1:-1.01, 2:0.61, 3:0.61, 4:0.61, 5:0.61, 6:-1.01, 7:0.61, 8:0.61, 9:0.61, 10:0.61, 11:-1.01, 12:0.61, 13:-2.62, 14:0.61, 15:0.61]
(ReplayBuffer( pid=4057217) INFO 01-04 13:24:30 [__init__.py:239] Automatically detected platform cuda.
(TrainingWorker pid=4057213) --------TrainingWorker | advantages: tensor([0.6057, 0.6057, 0.6057, 0.6057, 0.6057, 0.6057, 0.6057, 0.6057, 0.6057,
(TrainingWorker pid=4057213)         0.6057, 0.6057, 0.6057, 0.6057, 0.6057, 0.6057, 0.6057],
(TrainingWorker pid=4057213)        device='cuda:0')---------
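
As a sanity check on what the training worker should have received, the sketch below recomputes group-normalized advantages in plain Python. The reward values are an assumption, not data from the run: eleven rewards of 5.0, four of 2.5, and one of 0.0 are simply one combination that matches the logged mean of 4.062 and the positions of the three advantage levels. Normalizing by the sample standard deviation is likewise an assumption about the recipe; it is used here because it reproduces the logged numbers. `group_normalized_advantages` and `assumed_rewards` are hypothetical names.

```python
from statistics import fmean, stdev


def group_normalized_advantages(rewards):
    """Return (r - mean) / std for one group, using the sample standard deviation."""
    mean = fmean(rewards)
    std = stdev(rewards)  # sample (n - 1) standard deviation
    if std == 0:
        # All rewards equal: every advantage in the group is zero.
        return [0.0 for _ in rewards]
    return [(r - mean) / std for r in rewards]


# Hypothetical rewards chosen to match the logged mean (4.062) and the
# positions of the three advantage levels; not taken from the run itself.
assumed_rewards = [5.0] * 16
for i in (0, 1, 6, 11):
    assumed_rewards[i] = 2.5
assumed_rewards[13] = 0.0

advantages = group_normalized_advantages(assumed_rewards)
print(", ".join(f"{i}:{a:.2f}" for i, a in enumerate(advantages)))
```

Under these assumptions the output has the same three levels as the PostProcessingWorker log (0.61, -1.01, -2.62), so a tensor holding 0.6057 in every position cannot be a correctly normalized group: the mean of a normalized group has to be zero.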

This is where I added the logging:

    def _prepare_trajectory(
        self, raw_trajectory: Trajectory
    ) -> tuple[GRPOTrajectory, int, dict[str, Any]]:
        """Process raw trajectory, compute rewards, and prepare for optimization.

        Args:
            raw_trajectory (Trajectory): The trajectory sampled from the replay buffer.

        Returns:
            tuple[trajectory, context_length, metadata]
        """
        # Extract components from raw trajectory
        query_responses = raw_trajectory.query_responses
        responses = raw_trajectory.responses
        logprobs = raw_trajectory.logprobs
        ref_logprobs = raw_trajectory.ref_logprobs
        query_response_padding_masks = raw_trajectory.query_response_padding_masks
        seq_lens = raw_trajectory.seq_lens
        advantages = raw_trajectory.advantages
        answers = raw_trajectory.answers

        utils.log_rank_zero(log, f"--------TrainingWorker | advantages: {advantages}---------")
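
To compare the two workers' views line by line, a small helper could format the training-side advantages the same way the PostProcessingWorker log does and flag the degenerate case where every entry is identical. This is only a sketch: `summarize_advantages` is a hypothetical helper, not part of the recipe, and it takes any iterable of numbers so it can be tried without a GPU.

```python
def summarize_advantages(advantages, tol=1e-6):
    """Summarize a flat sequence of advantages for logging.

    Returns the count, the max - min spread, a flag that is True when more
    than one value is present and all of them are (nearly) identical, and an
    "index:value" string in the style of the PostProcessingWorker log.
    """
    values = [float(a) for a in advantages]
    if not values:
        return {"count": 0, "spread": 0.0, "degenerate": False, "formatted": "[]"}
    spread = max(values) - min(values)
    formatted = "[" + ", ".join(f"{i}:{v:.2f}" for i, v in enumerate(values)) + "]"
    return {
        "count": len(values),
        "spread": spread,
        "degenerate": len(values) > 1 and spread <= tol,
        "formatted": formatted,
    }


# Values as seen by the TrainingWorker in the log above.
summary = summarize_advantages([0.6057] * 16)
print(summary["degenerate"], summary["formatted"])
```

Inside `_prepare_trajectory`, the input could be something like `advantages.flatten().tolist()`, assuming `advantages` is a torch tensor as the log output suggests.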

Here's the config file I used:

# All orchestration args
orchestration:
  num_inference_workers: 1
  num_postprocessing_workers: 1
  num_training_workers: 1
  replay_buffer_size: ${inference.batch_size}  # TODO: Right now this can't be bigger, or else we'll get padding issues
  num_steps: 250

# All inference args
inference:
  engine: vllm
  model: ${base_model_path}
  top_k: null
  temperature: 1.0
  tensor_parallel_dim: 1
  max_generated_tokens: 512
  batch_size: 1
  group_size: 16
  total_batch_size: ${eval:'${inference.batch_size} * ${inference.group_size}'}
  steps_before_weight_sync: 1
  queue_maxsize: ${eval:'${orchestration.num_inference_workers} * ${training.steps_before_weight_sync}'}
...


Labels: bug
