
hkust-nlp/B-STaR


B-STAR: Monitoring and Balancing Exploration and Exploitation in Self-Taught Reasoners

📄 Paper: https://arxiv.org/abs/2412.17256

B-STAR (Balanced Self-Taught Reasoner) is a framework that improves the self-improvement training of reasoning models by dynamically balancing exploration and exploitation throughout training. It is particularly effective on tasks that require complex reasoning, such as mathematical problem solving, coding, and commonsense reasoning.


Overview

Self-improvement in reasoning models is an iterative training process in which a model generates its own training data from its outputs. However, existing methods often stagnate after a few iterations because of an imbalance between two critical factors:

  1. Exploration: The model's ability to generate diverse and high-quality responses.
  2. Exploitation: The effectiveness of external rewards in distinguishing and leveraging high-quality responses.
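To make the two factors concrete, the sketch below computes simple per-query proxies. These are illustrative choices for this example, not necessarily the exact metrics used in the paper: the standard unbiased pass@k estimator and a count of distinct correct responses stand in for exploration, and the precision of a reward threshold (the share of kept responses that are actually correct) stands in for exploitation. All function names are hypothetical.

```python
from math import comb


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate from n samples, c of which are correct."""
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)


def distinct_correct(responses: list[str], correct: list[bool]) -> int:
    """Number of distinct correct responses (a simple diversity proxy)."""
    return len({r for r, ok in zip(responses, correct) if ok})


def filter_precision(rewards: list[float], correct: list[bool], threshold: float) -> float:
    """Fraction of responses kept by a reward threshold that are actually correct."""
    kept = [ok for r, ok in zip(rewards, correct) if r >= threshold]
    return sum(kept) / len(kept) if kept else 0.0


# Toy example: 4 sampled responses for one query (hypothetical data).
responses = ["x=2", "x=2", "x=3", "x=2 (via factoring)"]
correct = [True, True, False, True]
rewards = [0.9, 0.8, 0.7, 0.4]

print(pass_at_k(n=4, c=3, k=1))                             # 0.75
print(distinct_correct(responses, correct))                 # 2
print(filter_precision(rewards, correct, threshold=0.75))   # 1.0
```

Lowering the threshold in this toy example trades precision for recall: at 0.5 the filter also keeps the incorrect `x=3` response, and precision drops to 2/3.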


B-STAR introduces an adaptive mechanism that monitors and balances these factors dynamically, yielding consistent performance improvements over multiple training iterations.

Key Features

  • Dynamic Configuration Adjustments: Automatically tunes exploration and exploitation configurations (e.g., sampling temperature, reward thresholds) to optimize the self-improvement process.
  • Balance Score Metric: Quantifies the interplay between exploration and exploitation, guiding dynamic adjustments.
  • Generalization Across Tasks: Demonstrates effectiveness on mathematical reasoning, coding challenges, and commonsense reasoning tasks.
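A minimal sketch of how a balance score could drive the configuration choice, assuming the per-query form bs = min(n'/n*, 1) * (n'/n), where n' is the number of distinct correct responses that survive reward filtering, n is the total number kept, and n* is a target count. This is our reading of the paper; consult it for the exact definition. `select_config` grid-searches sampling temperature and reward threshold and returns the pair with the highest mean score. The function names, data layout, and grid values are hypothetical and may differ from what `train_bstar.sh` does.

```python
from itertools import product
from statistics import mean


def balance_score(kept: list[tuple[str, bool]], n_star: int) -> float:
    """Per-query balance score, assuming bs = min(n'/n*, 1) * (n'/n).

    kept: (response, is_correct) pairs that survived reward filtering.
    n' = distinct correct responses, n = number of kept responses.
    """
    if not kept:
        return 0.0
    n_distinct_correct = len({resp for resp, ok in kept if ok})
    quantity = min(n_distinct_correct / n_star, 1.0)  # enough distinct correct answers?
    quality = n_distinct_correct / len(kept)          # penalizes wrong or duplicate kept answers
    return quantity * quality


def select_config(samples, temperatures, thresholds, n_star):
    """Grid-search (temperature, threshold) for the highest mean balance score.

    samples[t] is a list of queries; each query is a list of
    (response, is_correct, reward) triples sampled at temperature t.
    """
    best, best_score = None, -1.0
    for t, thr in product(temperatures, thresholds):
        scores = [
            balance_score([(r, ok) for r, ok, rew in query if rew >= thr], n_star)
            for query in samples[t]
        ]
        score = mean(scores)
        if score > best_score:
            best, best_score = (t, thr), score
    return best, best_score


# Hypothetical samples for a single query at two temperatures.
samples = {
    0.5: [[("a", True, 0.9), ("a", True, 0.8), ("b", False, 0.6)]],
    1.0: [[("a", True, 0.9), ("c", True, 0.7), ("d", False, 0.3)]],
}
print(select_config(samples, [0.5, 1.0], [0.5, 0.8], n_star=2))  # ((1.0, 0.5), 1.0)
```

The first factor rewards having enough distinct correct responses (exploration); the second penalizes kept responses that are wrong or duplicated (exploitation). In the toy data, the higher temperature wins because it yields two distinct correct answers that both pass the 0.5 threshold while the incorrect one is filtered out.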

Results

B-STAR achieves state-of-the-art performance across various benchmarks:

  • Significant improvements over previous self-improvement methods.

  • Sustained performance growth across multiple iterations, outperforming existing methods that stagnate after a few iterations.

Reproduction

Our code builds upon easy-to-hard and gpt-accelerate. Please refer to gpt-accelerate for environment setup and model weight conversion instructions.

1. Prepare Model

We first need to prepare the model checkpoint in the gpt-fast format.

```bash
export DATA_DIR=/path/to/your/data/directory
export MODEL_REPO=mistralai/Mistral-7B-v0.1

python scripts/download.py \
    --repo_id $MODEL_REPO \
    --local_dir $DATA_DIR/checkpoints

python scripts/convert_hf_checkpoint.py \
    --checkpoint_dir $DATA_DIR/checkpoints/$MODEL_REPO \
    --target_precision bf16
```

2. Train SFT Model

```bash
export DATA_DIR=/path/to/your/data/directory
export MODEL_REPO=$DATA_DIR/checkpoints/Mistral-7B-v0.1
export OMP_NUM_THREADS=8

# Please download this dataset to a local folder first
SFT_TRAIN_DATA=https://huggingface.co/datasets/AndrewZeng/math-trn-format/blob/main/math_format.json
SFT_MODEL_SAVE_NAME=math_format_11k_mistral

torchrun --standalone --nproc_per_node=8 \
    train_sft.py \
    --do_train \
    --checkpoint_path $MODEL_REPO/model.pth \
    --source_max_len 768 \
    --target_max_len 768 \
    --total_max_len 1024 \
    --per_device_train_batch_size 16 \
    --micro_train_batch_size 4 \
    --learning_rate 5e-6 \
    --lr_eta_min 2e-7 \
    --num_train_epochs 3 \
    --dataset "$SFT_TRAIN_DATA" \
    --dataset_format "metamath" \
    --add_eos_to_marked_target \
    --save_strategy "steps" \
    --save_steps 25 \
    --optim_dtype bf16 \
    --save_total_limit 40 \
    --tensor_parallel_size 1 \
    --save_dir $DATA_DIR/checkpoints/$SFT_MODEL_SAVE_NAME \
    --resume_from_checkpoint
```
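A quick sanity check of what the SFT flags imply, assuming `per_device_train_batch_size` is the per-GPU batch per optimizer step, `micro_train_batch_size` is the size of each forward/backward pass (so gradients accumulate over the ratio of the two), and the data-parallel size is `nproc_per_node / tensor_parallel_size`. The dataset size is also an assumption, inferred from the save name `math_format_11k_mistral`.

```python
import math

# Values taken from the SFT command.
nproc_per_node = 8
tensor_parallel_size = 1
per_device_train_batch_size = 16
micro_train_batch_size = 4
num_train_epochs = 3
save_steps = 25

# Assumption: ~11k training examples, inferred from the save name;
# substitute the real size of the downloaded dataset.
num_examples = 11_000

data_parallel_size = nproc_per_node // tensor_parallel_size
grad_accum_steps = per_device_train_batch_size // micro_train_batch_size
global_batch_size = per_device_train_batch_size * data_parallel_size
# ceil() assumes the last partial batch of each epoch is kept, not dropped.
steps_per_epoch = math.ceil(num_examples / global_batch_size)
total_steps = steps_per_epoch * num_train_epochs
num_checkpoints = total_steps // save_steps

print(grad_accum_steps, global_batch_size, steps_per_epoch, total_steps, num_checkpoints)
# 4 128 86 258 10
```

Under these assumptions one run produces about 10 step-based checkpoints, well within `--save_total_limit 40`, so no intermediate checkpoint would be deleted.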

3. Train PRM Model

We construct the process reward model (PRM) training data following the Math-Shepherd approach and train the reward model with a pointwise objective.

```bash
export DATA_DIR=/path/to/your/data/directory
export MODEL_REPO=$DATA_DIR/checkpoints/Mistral-7B-v0.1
export OMP_NUM_THREADS=4

RM_DATA=train_prm_math_shepherd_mistral.json
RM_MODEL_SAVE_NAME=prm_model_mistral_sample_complete

torchrun --standalone --nproc_per_node=8 \
    train_rm_pointwise.py \
    --do_train \
    --checkpoint_path $MODEL_REPO/model.pth \
    --source_max_len 768 \
    --target_max_len 768 \
    --total_max_len 1024 \
    --per_device_train_batch_size 32 \
    --micro_train_batch_size 32 \
    --learning_rate 2e-6 \
    --lr_eta_min 2e-7 \
    --num_train_epochs 2 \
    --dataset "$RM_DATA" \
    --dataset_format "prm-v4" \
    --save_strategy epoch \
    --save_total_limit 5 \
    --train_on_every_token \
    --tensor_parallel_size 1 \
    --save_only_model True \
    --optim_dtype bf16 \
    --save_dir $DATA_DIR/checkpoints/$RM_MODEL_SAVE_NAME \
    --resume_from_checkpoint
```
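The PRM step names Math-Shepherd and a pointwise objective without spelling them out. A minimal sketch, assuming Math-Shepherd's rollout-based step labels (hard estimation: a step is labeled 1 if any completion rolled out from it reaches the gold answer; soft estimation: the fraction that do) and a per-step binary cross-entropy as the pointwise loss. Which variant the `prm-v4` data format uses is not stated here, and the function names are hypothetical.

```python
import math


def step_labels(completion_outcomes: list[list[bool]], soft: bool = False) -> list[float]:
    """Math-Shepherd-style step labels.

    completion_outcomes[i] records, for step i, whether each completion
    rolled out from the prefix ending at step i reached the gold answer.
    Hard labels: 1.0 if any rollout succeeds; soft labels: success rate.
    """
    labels = []
    for outcomes in completion_outcomes:
        if soft:
            labels.append(sum(outcomes) / len(outcomes))
        else:
            labels.append(1.0 if any(outcomes) else 0.0)
    return labels


def pointwise_bce(step_probs: list[float], labels: list[float], eps: float = 1e-7) -> float:
    """Mean binary cross-entropy between predicted step scores and step labels."""
    total = 0.0
    for p, y in zip(step_probs, labels):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
    return total / len(labels)


# Toy solution with 3 steps and 4 rollouts per step (hypothetical data).
outcomes = [
    [True, True, False, True],     # step 1: most rollouts still succeed
    [True, False, False, False],   # step 2: one rollout succeeds
    [False, False, False, False],  # step 3: no rollout succeeds
]
hard = step_labels(outcomes)             # [1.0, 1.0, 0.0]
soft = step_labels(outcomes, soft=True)  # [0.75, 0.25, 0.0]
loss = pointwise_bce([0.9, 0.6, 0.1], hard)
```

Under hard estimation, step 2 keeps a positive label because a single rollout succeeds; soft estimation exposes that the step is risky.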

4. Train B-STaR

```bash
## This is our initial release code.
## We are working hard to clean it up and make it clearer and more readable.
cd train_code
bash train_bstar.sh
```
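The contents of `train_bstar.sh` are not reproduced here. As a rough illustration of the data-selection step that each self-improvement iteration performs, the sketch below keeps sampled responses that have a correct final answer, whose PRM score clears a reward threshold, and that are not duplicates. Aggregating per-step PRM scores with `min()` and requiring final-answer correctness are assumptions for this sketch; the `Sample` class and `select_training_data` are hypothetical. The threshold is the kind of configuration that B-STaR adjusts from iteration to iteration.

```python
from dataclasses import dataclass


@dataclass
class Sample:
    query: str
    response: str
    is_correct: bool           # final answer matches the reference
    step_rewards: list[float]  # PRM score for each reasoning step


def select_training_data(samples: list[Sample], threshold: float) -> list[tuple[str, str]]:
    """Keep correct, deduplicated responses whose weakest step clears the threshold.

    Using min() over step rewards is an assumption made for this sketch.
    """
    seen = set()
    selected = []
    for s in samples:
        score = min(s.step_rewards) if s.step_rewards else 0.0
        key = (s.query, s.response)
        if s.is_correct and score >= threshold and key not in seen:
            seen.add(key)
            selected.append(key)
    return selected


samples = [
    Sample("q1", "a", True, [0.9, 0.8]),
    Sample("q1", "a", True, [0.9, 0.8]),    # duplicate
    Sample("q1", "b", True, [0.9, 0.3]),    # weak intermediate step
    Sample("q1", "c", False, [0.95, 0.9]),  # wrong final answer
]
print(select_training_data(samples, threshold=0.5))  # [('q1', 'a')]
```

Lowering the threshold to 0.2 would also admit response `b`, which illustrates the trade-off the balance score is meant to manage: more (and more diverse) training data versus cleaner training data.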

5. Evaluation

Coming soon!

Citation

If you find B-STaR useful, please cite our paper:

```bibtex
@article{zeng2024bstar,
  title={B-STAR: Monitoring and Balancing Exploration and Exploitation in Self-Taught Reasoners},
  author={Weihao Zeng and Yuzhen Huang and Lulu Zhao and Yijun Wang and Zifei Shan and Junxian He},
  journal={arXiv preprint arXiv:2412.17256},
  year={2024},
  url={https://arxiv.org/abs/2412.17256}
}
```
