I was running the GSM8k example with the `asyn_grpo_full_finetune_distributed.py` recipe on 4 GPUs on a single node. I added logging to track which advantages the training worker pulls from the replay buffer, and here's what I got:
```
(PostProcessingWorker pid=4057218) [Step 0] Rewards: 4.062 | Success: 0.344
(PostProcessingWorker pid=4057218) Advantages: [0:-1.01, 1:-1.01, 2:0.61, 3:0.61, 4:0.61, 5:0.61, 6:-1.01, 7:0.61, 8:0.61, 9:0.61, 10:0.61, 11:-1.01, 12:0.61, 13:-2.62, 14:0.61, 15:0.61]
(ReplayBuffer pid=4057217) INFO 01-04 13:24:30 [__init__.py:239] Automatically detected platform cuda.
(TrainingWorker pid=4057213) --------TrainingWorker | advantages: tensor([0.6057, 0.6057, 0.6057, 0.6057, 0.6057, 0.6057, 0.6057, 0.6057, 0.6057,
(TrainingWorker pid=4057213)         0.6057, 0.6057, 0.6057, 0.6057, 0.6057, 0.6057, 0.6057],
(TrainingWorker pid=4057213)         device='cuda:0')---------
```

The PostProcessingWorker computes a mix of positive and negative advantages for the group, but the TrainingWorker pulls a constant tensor (all 0.6057) from the replay buffer. This is where I added the logging:
```python
def _prepare_trajectory(
    self, raw_trajectory: Trajectory
) -> tuple[GRPOTrajectory, int, dict[str, Any]]:
    """Process raw trajectory, compute rewards, and prepare for optimization.

    Args:
        raw_trajectory (Trajectory): The trajectory sampled from the replay buffer.

    Returns:
        tuple[trajectory, context_length, metadata]
    """
    # Extract components from raw trajectory
    query_responses = raw_trajectory.query_responses
    responses = raw_trajectory.responses
    logprobs = raw_trajectory.logprobs
    ref_logprobs = raw_trajectory.ref_logprobs
    query_response_padding_masks = raw_trajectory.query_response_padding_masks
    seq_lens = raw_trajectory.seq_lens
    advantages = raw_trajectory.advantages
    answers = raw_trajectory.answers
    utils.log_rank_zero(log, f"--------TrainingWorker | advantages: {advantages}---------")
```

Here's the config file I used:
```yaml
# All orchestration args
orchestration:
  num_inference_workers: 1
  num_postprocessing_workers: 1
  num_training_workers: 1
  replay_buffer_size: ${inference.batch_size} # TODO: Right now this can't be bigger, or else we'll get padding issues
  num_steps: 250

# All inference args
inference:
  engine: vllm
  model: ${base_model_path}
  top_k: null
  temperature: 1.0
  tensor_parallel_dim: 1
  max_generated_tokens: 512
  batch_size: 1
  group_size: 16
  total_batch_size: ${eval:'${inference.batch_size} * ${inference.group_size}'}
  steps_before_weight_sync: 1
  queue_maxsize: ${eval:'${orchestration.num_inference_workers} * ${training.steps_before_weight_sync}'}
  ...
```
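For context, here is a plain-arithmetic sketch of how the `${eval:...}` interpolations above resolve with these values (the variable names just mirror the YAML keys; the actual resolution is done by the config system):

```python
# Mirror the YAML values and resolve the interpolated fields by hand.
batch_size = 1                  # inference.batch_size
group_size = 16                 # inference.group_size
num_inference_workers = 1      # orchestration.num_inference_workers
steps_before_weight_sync = 1   # steps_before_weight_sync

total_batch_size = batch_size * group_size                        # 1 * 16 = 16
queue_maxsize = num_inference_workers * steps_before_weight_sync  # 1 * 1 = 1
replay_buffer_size = batch_size                                   # tied to batch_size per the TODO

print(total_batch_size, queue_maxsize, replay_buffer_size)  # → 16 1 1
```

So each step fills the replay buffer with exactly one group of 16 samples, which matches the length of the advantages tensors in the logs above.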
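For reference: a group-normalized advantage (subtract the group mean, divide by the group std — my assumption about what the recipe computes) with hypothetical rewards chosen to reproduce the logged pattern gives exactly the mixed -1.01 / 0.61 / -2.62 values the PostProcessingWorker prints. So the constant 0.6057 tensor on the TrainingWorker side looks like a buffer/indexing mismatch rather than a degenerate reward group:

```python
import torch

def group_normalized_advantages(rewards: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    # Group-relative normalization: (r - mean) / (std + eps). This is an
    # assumption about the recipe's exact formula, used only to show what
    # shape of output the TrainingWorker should be receiving.
    return (rewards - rewards.mean()) / (rewards.std() + eps)

# Hypothetical rewards for one group of 16 completions, with the same
# success/failure pattern as the logged advantages (one outlier at index 13).
rewards = torch.tensor([0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 0.0, 5.0,
                        5.0, 5.0, 5.0, 0.0, 5.0, -5.0, 5.0, 5.0])
adv = group_normalized_advantages(rewards)
print(adv)  # ≈ [-1.01, -1.01, 0.61, ..., -2.62, 0.61, 0.61] — mixed, like the PostProcessingWorker log
```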