Minimal-RL by RLHFlow

LLM fine-tuning for mathematical reasoning via RL

Created 9 months ago
256 stars

Top 98.5% on SourcePulse

Project Summary

This project investigates reinforcement learning (RL) algorithms for fine-tuning large language models (LLMs) on mathematical reasoning tasks. It compares RAFT++ (rejection sampling), Vanilla Reinforce, and GRPO to understand which factors drive successful LLM fine-tuning. The research offers insights into algorithm performance, convergence, and exploration, and introduces Reinforce-rej, a new, more KL-efficient variant.

How It Works

The project revisits and strengthens RL algorithms for LLM post-training. RAFT++ is rejection sampling (training only on responses judged correct) augmented with importance sampling and clipping. Vanilla Reinforce is a simplified policy-gradient algorithm without a critic. GRPO, a Reinforce variant, samples multiple responses per prompt and normalizes rewards within each group. Key findings show that RAFT++ is competitive and converges faster early in training, but that negative samples remain vital for exploration and for preventing distributional collapse, a benefit RAFT++ forgoes by training only on accepted (positive) samples. GRPO's advantage over standard Reinforce is attributed to its implicit filtering of prompts whose sampled responses are all incorrect. A sketch of the two update ideas follows.
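The following is a minimal sketch of the two update ideas described above, assuming binary correctness rewards and PyTorch; the function names (raft_pp_mask, grpo_advantages, clipped_pg_loss) are illustrative and are not the repository's actual API.

    import torch

    def raft_pp_mask(rewards: torch.Tensor) -> torch.Tensor:
        """RAFT++-style rejection sampling: keep only responses whose reward
        indicates a correct answer; rejected samples get zero weight."""
        return (rewards > 0).float()

    def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """GRPO-style advantages: normalize rewards within each prompt's group
        of sampled responses (mean-centered, divided by the group std).
        Groups where every response has the same reward (all right or all
        wrong) receive zero advantage, i.e. those prompts are implicitly filtered."""
        mean = rewards.mean(dim=-1, keepdim=True)
        std = rewards.std(dim=-1, keepdim=True)
        return (rewards - mean) / (std + eps)

    def clipped_pg_loss(logp_new, logp_old, advantages, clip=0.2):
        """Clipped importance-sampling policy-gradient objective; a RAFT++-style
        update can be sketched by plugging the rejection mask in as the advantage."""
        ratio = torch.exp(logp_new - logp_old)
        unclipped = ratio * advantages
        clipped = torch.clamp(ratio, 1 - clip, 1 + clip) * advantages
        return -torch.min(unclipped, clipped).mean()

    # Toy example: 2 prompts, 4 sampled responses each, binary correctness rewards.
    rewards = torch.tensor([[1., 0., 0., 1.],    # mixed group -> useful signal
                            [0., 0., 0., 0.]])   # all wrong -> zero GRPO advantage
    print(raft_pp_mask(rewards))
    print(grpo_advantages(rewards))

    logp_old = torch.randn_like(rewards)
    logp_new = logp_old + 0.01 * torch.randn_like(rewards)
    print(clipped_pg_loss(logp_new, logp_old, raft_pp_mask(rewards)))

In the toy example, the second prompt's all-wrong group gets zero GRPO advantage, which illustrates the implicit prompt filtering credited with GRPO's edge over standard Reinforce.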

Quick Start & Requirements

Environment setup requires a Python virtual environment (e.g., python -m venv ~/.python/raftpp) or Conda. Dependencies include PyTorch 2.4.0 with CUDA 12.4 (torch==2.4.0 --index-url https://download.pytorch.org/whl/cu124), flash-attn, and additional project dependencies.
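A minimal setup sketch, assuming a bash shell and the venv path mentioned above; the flash-attn install flag and any further dependencies should be checked against the repository's own requirements.

    python -m venv ~/.python/raftpp
    source ~/.python/raftpp/bin/activate
    pip install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu124
    # flash-attn's docs typically recommend installing without build isolation
    pip install flash-attn --no-build-isolation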

Health Check

Last Commit: 8 months ago
Responsiveness: Inactive
Pull Requests (30d): 0
Issues (30d): 0
Star History: 5 stars in the last 30 days
