Reinforcement Learning with GRPO: Fine-Tuning a Small Language Model for Chain-of-Thought Math Reasoning, Similar to DeepSeek R1 Training
I found this lovely gist that is similar to DeepSeek R1 training.
It demonstrates how to train a small language model (Qwen or Llama) on the GSM8K math dataset using the GRPO algorithm as implemented in the Hugging Face TRL library. Multiple reward functions (correctness, integer output, and XML formatting) guide the model to produce answers in a specific chain-of-thought style (<reasoning>...</reasoning><answer>...</answer>).
Instead of simple supervised fine-tuning, the model is optimized via reinforcement learning: it generates multiple candidate outputs per prompt, scores them with the reward functions, and updates its parameters accordingly. The script also provides an optional LoRA (PEFT) configuration for more efficient fine-tuning, making it a compact example of how to implement a multi-reward RL framework to shape both the correctness and format of the model’s responses.
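To make the "generate, score, update" loop a bit more concrete, here is a minimal sketch of the group-relative part of GRPO: the rewards of all completions sampled for one prompt are normalized against each other, so completions that beat their group average get a positive advantage. The function name, the `eps` value, and the choice of the population standard deviation are assumptions made for illustration; the TRL implementation may differ in these details.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Normalize the rewards of one group of completions sampled for the same prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # assumption: population std; a sample std would also be plausible
    # eps keeps the division well-defined when every completion got the same reward
    return [(r - mu) / (sigma + eps) for r in rewards]


# One prompt, four sampled completions, only the first one was correct.
advantages = group_relative_advantages([2.0, 0.0, 0.0, 0.0])
print(advantages)
```

The correct completion ends up with a positive advantage and the other three with negative ones. If all completions in a group receive the same reward, every advantage is zero and that prompt contributes no learning signal.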
I love the simplicity of the reward function :)
From the author, about hardware:
works with peft on a single 80GB gpu for me, you can probably get it to work on a 40GB with a <1B model + shorter context length/smaller num_generations (though you probably won’t want to go lower than 4 or 6 samples, and will want grad accum >= 16 for stable training)
Napkin math (Qwen-1.5B, non-PEFT) is that GRPO needs the model (3GB) + reference copy (3GB) + Adam optimizer states (12GB) + activations. The latter adds up quickly with GRPO because you need a decent number of group samples (we’re already pushing it lower than is ideal, original paper uses 64 iirc)
60% is probably about as good as you should expect for Llama-1B, the original instruct only gets 44% so that’s already a nice bump. Qwen-1.5B can get closer to 80% with this setup from my tests. It might be worth playing around with adding more reward functions to incentivize better reasoning. You can have some fun here + treat it like prompt engineering (example).
I don’t want people to get their hopes up that they’ll magically see full-R1-like behavior, these are pretty tiny models, but I think it can be a cool testbed for different reward approaches
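The napkin math in the quote above can be reproduced with a few lines of arithmetic. This is only a sketch: the bytes-per-parameter figures (bf16 weights, two fp32 Adam moments) are my reading of the author's numbers, the function name is hypothetical, and activations are left out because they depend on context length and `num_generations`.

```python
def grpo_memory_estimate_gb(n_params: float) -> dict[str, float]:
    """Rough weight and optimizer memory in GB (1 GB = 1e9 bytes), excluding activations."""
    gb = 1e9
    estimate = {
        "policy_model": n_params * 2 / gb,     # assumed bf16: 2 bytes per parameter
        "reference_model": n_params * 2 / gb,  # frozen reference copy, also assumed bf16
        "adam_states": n_params * 8 / gb,      # assumed two fp32 moments: 2 * 4 bytes
    }
    estimate["total_without_activations"] = sum(estimate.values())
    return estimate


# Qwen-1.5B, non-PEFT, as in the quote
for name, size in grpo_memory_estimate_gb(1.5e9).items():
    print(f"{name}: {size:.1f} GB")
```

For 1.5B parameters this gives 3 GB + 3 GB + 12 GB, so roughly 18 GB are spent before any activations are stored.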
# train_grpo.py
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            #{'role': 'assistant', 'content': XML_COT_FORMAT.format(
            #    reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
            #    answer="7"
            #)},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion matches the exact format, newlines included."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets ".*?" span newlines; without it, multi-line reasoning never matches
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has the tags in order, with flexible whitespace."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets ".*?" span newlines; without it, multi-line reasoning never matches
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

#model_name = "meta-llama/Llama-3.2-1B-Instruct"
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

if "Llama" in model_name:
    output_dir = "outputs/Llama-1B-GRPO"
    run_name = "Llama-1B-GRPO-gsm8k"
else:
    output_dir = "outputs/Qwen-1.5B-GRPO"
    run_name = "Qwen-1.5B-GRPO-gsm8k"

training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=786,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    report_to="wandb",
    log_on_each_node=False,
)

peft_config = LoraConfig(
    r=16,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=None
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# use peft at your own risk; not working for me with multi-GPU training
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func],
    args=training_args,
    train_dataset=dataset,
    #peft_config=peft_config
)
trainer.train()
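As a quick sanity check of how the individual rewards add up, the sketch below reimplements the two answer-based reward functions from the script and sums them per completion. It assumes that the trainer combines reward functions by adding their scores with equal weight; `total_rewards` is a hypothetical helper written for this illustration and is not part of TRL.

```python
def extract_xml_answer(text: str) -> str:
    # Same helper as in the script: keep what sits between the last <answer> tag pair
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()


def correctness_reward(completions, answer) -> list[float]:
    extracted = [extract_xml_answer(c[0]["content"]) for c in completions]
    return [2.0 if r == a else 0.0 for r, a in zip(extracted, answer)]


def int_reward(completions) -> list[float]:
    extracted = [extract_xml_answer(c[0]["content"]) for c in completions]
    return [0.5 if r.isdigit() else 0.0 for r in extracted]


def total_rewards(completions, answer) -> list[float]:
    # Assumption: per-function rewards are simply added up for each completion
    per_func = [correctness_reward(completions, answer), int_reward(completions)]
    return [sum(scores) for scores in zip(*per_func)]


completions = [
    [{"role": "assistant", "content": "<reasoning>\n3 + 4 = 7\n</reasoning>\n<answer>\n7\n</answer>\n"}],
    [{"role": "assistant", "content": "<reasoning>\n3 + 4 = 8\n</reasoning>\n<answer>\n8\n</answer>\n"}],
    [{"role": "assistant", "content": "The answer is seven."}],
]
answer = ["7", "7", "7"]
print(total_rewards(completions, answer))  # [2.5, 0.5, 0.0]
```

A wrong but well-formed integer answer still earns the 0.5 integer reward, which gives the model partial credit for following the output format before it learns to be correct.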
GRPOConfig Training Arguments
- output_dir
  - The folder where your model checkpoints and related files will be saved.
- run_name
  - A name you choose to identify this training run.
- learning_rate = 5e-6
  - How fast the model learns. Smaller = slower but steadier learning.
- adam_beta1 = 0.9 and adam_beta2 = 0.99
  - Decay rates for the Adam optimizer’s running averages of the gradient (beta1) and of the squared gradient (beta2).
- weight_decay = 0.1
  - Helps prevent overfitting by shrinking the weights slightly at each update.
- warmup_ratio = 0.1
  - For the first 10% of training steps, the learning rate gradually goes from low to the set learning rate.
- lr_scheduler_type = 'cosine'
  - After warmup, the learning rate decreases smoothly along a cosine curve.
- logging_steps = 1
  - Log training info after every single step.
- bf16 = True
  - Uses bfloat16 precision to speed up training and reduce memory usage (on supported hardware).
- per_device_train_batch_size = 1
  - Each GPU/device trains on 1 example at a time (small batch size).
- gradient_accumulation_steps = 4
  - Accumulates gradients over 4 steps before updating (acts like a bigger batch size).
- num_generations = 16
  - The number of completions sampled for each prompt. GRPO compares the rewards within this group to decide which completions to reinforce.
- max_prompt_length = 256
  - The maximum number of tokens in the input prompt.
- max_completion_length = 786
  - The maximum number of tokens in the generated response.
- num_train_epochs = 1
  - The model will go through the entire dataset exactly once.
- save_steps = 100
  - Saves a checkpoint every 100 steps.
- max_grad_norm = 0.1
  - Gradient clipping: gradients are rescaled so that their overall norm does not exceed 0.1.
- report_to = "wandb"
  - Sends logs and metrics to the Weights & Biases service.
- log_on_each_node = False
  - In multi-node training, log only from the main node instead of from every node.
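To see how `learning_rate`, `warmup_ratio`, and the cosine scheduler interact, here is a small sketch of a linear-warmup, cosine-decay schedule in plain Python. It follows the usual shape of such schedules; the function name and the rounding of the warmup step count are assumptions for illustration, and the scheduler used by the trainer may round slightly differently.

```python
import math


def lr_at_step(step: int, total_steps: int, peak_lr: float = 5e-6,
               warmup_ratio: float = 0.1) -> float:
    """Linear warmup to peak_lr, then cosine decay towards zero."""
    warmup_steps = int(total_steps * warmup_ratio)  # assumption: rounded down
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


total = 1000  # hypothetical number of optimizer steps
for step in (0, 50, 100, 550, 1000):
    print(step, f"{lr_at_step(step, total):.2e}")
```

With these settings the learning rate climbs for the first 100 steps, peaks at 5e-6, passes half of the peak midway through the decay, and ends near zero.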
LoraConfig (PEFT) Arguments
- r = 16
  - The “rank” in LoRA (Low-Rank Adaptation). A bigger rank means more parameters to learn, but also needs more memory.
- lora_alpha = 64
  - A scaling factor for LoRA updates. The update is scaled by lora_alpha / r (here 64 / 16 = 4), which controls how strongly the extra parameters affect the model.
- target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"]
  - The layers of the model where LoRA is applied: here the attention projections (q, k, v, o) and the MLP projections (up, down, gate).
- task_type = "CAUSAL_LM"
  - Specifies the type of task (next-token prediction, a causal language modeling task).
- lora_dropout = 0.05
  - Dropout probability applied to the inputs of the LoRA layers during training, to reduce overfitting.
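The effect of `r` and `lora_alpha` can be checked with simple arithmetic. The sketch below counts the trainable parameters LoRA adds to a single linear layer and compares them to the full weight matrix. The layer size of 1536 is only an illustrative value, and the helper name is hypothetical.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Parameters of the two low-rank matrices: A is (r x d_in), B is (d_out x r)."""
    return r * d_in + d_out * r


d = 1536               # illustrative layer width, not read from any model config
r, lora_alpha = 16, 64

full = d * d                        # parameters in the frozen weight matrix
lora = lora_param_count(d, d, r)    # trainable parameters added by LoRA
print(f"full: {full}, lora: {lora}, ratio: {lora / full:.2%}")
print(f"update scaling (lora_alpha / r): {lora_alpha / r}")
```

For a square layer of this size, the adapter trains about 2% as many parameters as the layer itself, and the low-rank update is multiplied by 4 before being added to the layer output.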
What is LoRA?
LoRA (Low-Rank Adaptation) adds a small number of new parameters to an existing large model, and only these new parameters are trained. The original weights stay frozen, which makes the fine-tuning process more memory efficient.
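A minimal sketch of that idea, in plain Python with tiny matrices: the frozen weight `W` produces the usual output, and a trainable low-rank pair `A`, `B` adds a scaled correction on top. All names and sizes here are made up for illustration. Initializing `B` to zeros (the common convention) means the adapted layer starts out identical to the original one.

```python
import random


def matvec(m: list[list[float]], v: list[float]) -> list[float]:
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def lora_forward(W, A, B, x, alpha: float, r: int) -> list[float]:
    base = matvec(W, x)               # frozen pretrained path
    update = matvec(B, matvec(A, x))  # trainable low-rank path: B @ (A @ x)
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


random.seed(0)
d_in, d_out, r, alpha = 4, 3, 2, 8
W = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]  # frozen
A = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(r)]      # trainable
B = [[0.0] * r for _ in range(d_out)]                                  # trainable, zero init
x = [1.0, 2.0, 3.0, 4.0]

# With B at zero the adapter is a no-op, so training starts from the base model.
print(lora_forward(W, A, B, x, alpha, r) == matvec(W, x))  # True
```

During fine-tuning only `A` and `B` would receive gradients, which is why the optimizer state stays small compared to full fine-tuning.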