Commit 3d976d2

Merge pull request #44 from Agent-RL/re-call
feat: add implementation of ReCall
2 parents 3b2fd14 + 7fcf4d3

File tree: 139 files changed (+12642, -3396 lines)


README.md

Lines changed: 67 additions & 53 deletions
@@ -1,76 +1,79 @@
 <div align="center">

-# ***ReSearch***: Learning to ***Re***ason with ***Search*** for LLMs via Reinforcement Learning
+# ***ReCall***: Learning to ***Re***ason with Tool ***Call*** for LLMs via Reinforcement Learning

-[![Arxiv](https://img.shields.io/badge/paper-A82F27?style=for-the-badge&logo=arxiv)](https://arxiv.org/abs/2503.19470) [![Model](https://img.shields.io/badge/model-4169E1?style=for-the-badge&logo=huggingface)](https://huggingface.co/collections/agentrl/research-67e506a0311bea06dc54878b)
+[![Notion](https://img.shields.io/badge/blog-black?style=for-the-badge&logo=notion)](https://attractive-almandine-935.notion.site/ReCall-Learning-to-Reason-with-Tool-Call-for-LLMs-via-Reinforcement-Learning-1d7aec91e9bb8006ad40f9edbfe2191a) [![Arxiv](https://img.shields.io/badge/paper-A82F27?style=for-the-badge&logo=arxiv)](https://arxiv.org/abs/2503.19470) [![Model](https://img.shields.io/badge/model-4169E1?style=for-the-badge&logo=huggingface)](https://huggingface.co/collections/agentrl/research-67e506a0311bea06dc54878b)

 </div>

+We introduce ***ReCall***, a novel framework that trains LLMs to ***Re***ason with Tool ***Call*** via reinforcement learning—without requiring any supervised data on tool-use trajectories or reasoning steps. *ReCall* empowers LLMs to agentically use and combine arbitrary tools like [OpenAI o3](https://openai.com/index/introducing-o3-and-o4-mini/), offering an accessible approach toward general-purpose agents. Additionally, we provide a novel perspective on generating synthetic data with diverse environments and complex multi-step tasks, enabling LLMs to develop sophisticated tool-based reasoning capabilities. This is a work in progress and we are actively working on it.
+
+> [!IMPORTANT]
+> *ReCall* is the successor to [*ReSearch*](https://arxiv.org/abs/2503.19470) and represents a more comprehensive framework that extends beyond the search tool to support reasoning with any user-defined tools. It can be a drop-in replacement for *ReSearch*. We've archived the original implementation of *ReSearch* in the branch `re-search`.
+
 <p align="center">
-  <img src="./assets/intro_bar.png" width="90%" alt="Intro" />
-  <img src="./assets/method.png" width="90%" alt="Method" />
+  <img src="./assets/overview.png" width="90%" alt="Overview" />
+  <img src="./assets/eval_bar.png" width="90%" alt="Eval" />
 </p>

-We propose ***ReSearch***, a novel framework that trains LLMs to ***Re***ason with ***Search*** via reinforcement learning without using any supervised data on reasoning steps. Our approach treats search operations as integral components of the reasoning chain, where when and how to perform searches is guided by text-based thinking, and search results subsequently influence further reasoning.
-
 ## 📰 News
-- **[2025-03-27]** 🤗 We release our trained models on [Hugging Face](https://huggingface.co/collections/agentrl/research-67e506a0311bea06dc54878b), please check it out!
-- **[2025-03-26]** 🎉 We release the paper, update the code and open-source the models.
+- **[2025-04-24]** 🎉 We release the first version of *ReCall*, and archive the original implementation of *ReSearch*.
+  - ➡️ The name of the repository is changed from *ReSearch* to *ReCall*.
+  - 📝 We release a [blog](https://attractive-almandine-935.notion.site/ReCall-Learning-to-Reason-with-Tool-Call-for-LLMs-via-Reinforcement-Learning-1d7aec91e9bb8006ad40f9edbfe2191a) to introduce the idea of *ReCall*.
+  - 📦 The current implementation of *ReCall* is based on verl 0.3.0 + vllm 0.8.4.
+- **[2025-03-27]** 🤗 We release our trained *ReSearch* models on [Hugging Face](https://huggingface.co/collections/agentrl/research-67e506a0311bea06dc54878b), please check it out!
+- **[2025-03-26]** 🎉 We release the paper and update the code of *ReSearch*.
   - 📝 The **paper is released** on arXiv; more details and evaluation results can be found in our [paper](https://arxiv.org/abs/2503.19470).
   - 🛠️ The **repository is updated** with the new implementation, especially the rollout with search during RL training. This version of the implementation is based on the latest release of verl.
-- **[2025-03-03]** ✅ We have released the preview version of ReSearch implementation.
+- **[2025-03-03]** ✅ We have released the preview version of the *ReSearch* implementation.

 ## 📦 Installation

 We recommend using conda to manage the environment. First create a conda environment and activate it.
 ```bash
-conda create -n re-search python==3.10
-conda activate re-search
+conda create -n re-call python==3.10
+conda activate re-call
 ```
-Then install dependencies, and our modified verl and flashrag packages under ```src/``` will be installed in the editable mode. Check out ```setup.py``` for details.
+Then install dependencies; the packages under ```src/``` will be installed in editable mode. Check out ```setup.py``` for details.
 ```bash
-pip3 install torch==2.4.0 --index-url https://download.pytorch.org/whl/cu124
-pip3 install flash-attn --no-build-isolation
-git clone https://github.com/Agent-RL/ReSearch.git
-cd ReSearch
+git clone https://github.com/Agent-RL/ReCall.git
+cd ReCall
 pip3 install -e .
+pip3 install flash-attn --no-build-isolation
 ```
-As described in the [FlashRAG](https://github.com/RUC-NLPIR/FlashRAG?tab=readme-ov-file#wrench-installation), due to the incompatibility when installing faiss using pip, we need to use the following conda command to install faiss-gpu.
+If you want to host a Wikipedia RAG system based on FlashRAG, you also need faiss-gpu. As described in the [FlashRAG](https://github.com/RUC-NLPIR/FlashRAG?tab=readme-ov-file#wrench-installation) documentation, faiss cannot be installed reliably via pip, so use the following conda command to install faiss-gpu.
 ```bash
 conda install -c pytorch -c nvidia faiss-gpu=1.8.0
 ```

 ## 🚀 Quick Start

-### Retriever Serving
+> If you want to learn the details of the current version of *ReCall*, please refer to the [blog](https://attractive-almandine-935.notion.site/ReCall-Learning-to-Reason-with-Tool-Call-for-LLMs-via-Reinforcement-Learning-1d7aec91e9bb8006ad40f9edbfe2191a) first.
+
+### Data Preparation

-As described in our paper, during model training and evaluation, search operation will be conducted in the rollout and inference process. In practice, we host a retriever service via FlashRAG and FastAPI. Hence, the search operation is standardized to be an API call. This serving can be used to decouple the search operation from the reinforcement learning process, making the training and evaluation more clear and flexible.
+*ReCall* is trained on a mixture of our synthetic dataset `SynTool` and the training set of `MuSiQue`. You can download the preprocessed training data from [here](https://huggingface.co/datasets/agentrl/ReCall-data) and use it directly for training.

-Before starting the retriever serving, you need download the [pre-indexed wikipedia](https://github.com/RUC-NLPIR/FlashRAG?tab=readme-ov-file#index), [wikipedia corpus and corresponding retriever models](https://github.com/RUC-NLPIR/FlashRAG/blob/main/docs/original_docs/reproduce_experiment.md#preliminary). More details can be found in the documentation of FlashRAG.
+### Sandbox Serving

-For starting the retriever serving, you need to first fill the `scripts/serving/retriever_config.yaml` with the correct path to the retrieval model, index, and corpus, and available GPU ids. Then, you can run the following command to start the retriever serving:
+Since tools are implemented as executable Python code, the tool executor is responsible for running that code. To ensure safety and security, we implement a sandbox for running Python code on a remote server. To launch the sandbox service, run the following command:
 ```bash
 cd scripts/serving
-python retriever_serving.py \
-    --config retriever_config.yaml \
-    --num_retriever {num_retriever} \
-    --port {port}
+python sandbox.py --port {port}
 ```
+Note: The current implementation is a basic sandbox environment. We plan to use a more robust and secure sandbox in future updates. We recommend hosting the sandbox on a remote server, as local hosting may expose your machine to potential security risks.

-The started retriever serving will be used in the training and evaluation process in the following part.
-
-### Data Preparation
+### Retriever Serving

-*ReSearch* is trained on the training set of MuSiQue, and evaluated on the dev set of HotpotQA, 2WikiMultiHopQA, MuSiQue and Bamboogle. For downloading the datasets, please refer to the `data/download_dataset.sh` script.
-```bash
-cd data
-bash download_dataset.sh
-```
+For training on MuSiQue data with a Wikipedia search tool, we provide a Wikipedia retriever service implemented with FlashRAG and FastAPI. Before starting the retriever serving, you need to download the [pre-indexed wikipedia](https://github.com/RUC-NLPIR/FlashRAG?tab=readme-ov-file#index) and the [wikipedia corpus and corresponding retriever models](https://github.com/RUC-NLPIR/FlashRAG/blob/main/docs/original_docs/reproduce_experiment.md#preliminary). More details can be found in the documentation of FlashRAG.

-For preparing the training and validation data for following reinforcement learning, please run this script to parse the MuSiQue dataset to the parquet format.
+To start the retriever serving, first fill `scripts/serving/retriever_config.yaml` with the correct paths to the retrieval model, index, and corpus, and the available GPU ids. Then run the following command:
 ```bash
-cd data
-python prepare_musique.py
+cd scripts/serving
+python retriever_serving.py \
+    --config retriever_config.yaml \
+    --num_retriever {num_retriever} \
+    --port {port}
 ```

 ### Training
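The `Sandbox Serving` section added in the hunk above launches `scripts/serving/sandbox.py` but does not show what the service looks like. Below is a minimal sketch of such a code-execution endpoint, not the repository's actual implementation: the `/execute` route, payload fields, and response shape are assumptions, and the bare `exec` is exactly the kind of risk the README's note about remote hosting warns about.

```python
# Sketch of a minimal code-execution sandbox service (illustration only).
# The real service lives in scripts/serving/sandbox.py; the /execute route,
# request fields, and response shape here are assumptions.
import contextlib
import io

import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class CodeRequest(BaseModel):
    code: str  # Python source that defines and calls the user's tools

@app.post("/execute")
def execute(req: CodeRequest):
    buf = io.StringIO()
    try:
        with contextlib.redirect_stdout(buf):
            # No real isolation here -- this is why the README recommends
            # hosting the sandbox on a throwaway remote machine.
            exec(req.code, {})
        return {"output": buf.getvalue(), "error": None}
    except Exception as e:
        # Return tool failures as data so the rollout can feed them back
        # to the model instead of crashing the training loop.
        return {"output": buf.getvalue(), "error": repr(e)}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
```

Returning errors as data rather than raising keeps a failed tool call inside the model's context, which is what lets RL training learn from unsuccessful tool use.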
@@ -83,11 +86,12 @@ Here is an example of training Qwen2.5-7B-Instruct with 4 GPUs locally. Note tha
 cd scripts/train
 bash train.sh \
     --train_batch_size 8 \
-    --ppo_mini_batch_size 8 \
-    --apply_chat True \
-    --prompt_template_name re_search_template_sys \
+    --ppo_mini_batch_size 4 \
+    --use_re_call True \
+    --prompt_template_name re_call_template_sys \
     --actor_model_path {model/path/to/qwen2.5-7b-instruct} \
     --search_url {your-hosted-retriever-url} \
+    --sandbox_url {your-hosted-sandbox-url} \
     --project_name {wandb-project-name} \
     --experiment_name {wandb-experiment-name} \
     --nnodes 1 \
@@ -97,19 +101,18 @@ bash train.sh \
     --total_epochs 2 \
     --wandb_api_key {your-wandb-api-key} \
     --save_path {path/to/save} \
-    --train_files {path/to/train/parquet/data} \
-    --test_files {path/to/test/parquet/data}
+    --train_files "['train1.parquet', 'train2.parquet']" \
+    --test_files "['test1.parquet', 'test2.parquet']"
 ```
-- For training base (pre-trained) models, please use `--apply_chat False` and `--prompt_template_name re_search_template`
-- For training instruction-tuned models, please use `--apply_chat True` and `--prompt_template_name re_search_template_sys`

 #### Multi-node training

-If you want to **fully reproduce** the results in our paper, please refer to the multi-node training script in `scripts/train/train_multi_node.sh`, as well as the implementation details in our paper.
+If you want to **fully reproduce** *ReCall*, please refer to the multi-node training script in `scripts/train/train_multi_node.sh`.

-### Evaluation
-
-We recommend using [SGLang](https://docs.sglang.ai/) to serve the trained model. You can download our open-sourced models or trained your own models to conduct the evaluation. Here is an example of launching the model serving:
+### Inference
+This section demonstrates how to perform inference using the trained *ReCall* model. We provide a standard wrapper class in `src/re_call/inference/re_call.py` that simplifies the inference process. To get started, you only need to provide the model URL and sandbox URL, then use the `run` function to execute inference. The `ReCall` class handles all the orchestration between model generation and tool execution internally. For a practical example of using the `ReCall` class, please refer to our sample implementation at `scripts/inference/re_call_use_case.py`.
+
+For model serving, we recommend using [SGLang](https://docs.sglang.ai/). You can either download our open-source models or train your own models to conduct the inference. Here is an example of how to launch the model service:
 ```bash
 python3 -m sglang.launch_server \
     --served-model-name {trained/model/name} \
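The new `### Inference` paragraph in the hunk above describes the `ReCall` wrapper without showing its call signature. The snippet below is a hypothetical usage sketch: the import path, constructor parameter, and `run` arguments are all assumptions, so treat `src/re_call/inference/re_call.py` and `scripts/inference/re_call_use_case.py` as the authoritative reference.

```python
# Hypothetical usage of the ReCall inference wrapper described above.
# Import path, constructor parameter, and run() signature are assumptions;
# see src/re_call/inference/re_call.py for the actual interface.
from re_call import ReCall  # assumed import from the editable install

# A user-defined tool environment: plain Python source the sandbox can exec.
env = """
def add(a: int, b: int) -> int:
    return a + b
"""

recall = ReCall(executor_url="http://{your-sandbox-host}:{port}")  # assumed parameter name
answer = recall.run(
    env=env,                                       # assumed argument names
    question="What is 17 + 25?",
    model_url="http://{your-sglang-host}:{port}",  # the SGLang server launched below
)
print(answer)
```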
@@ -125,28 +128,39 @@ python3 -m sglang.launch_server \
     --disable-radix-cache
 ```

-We use [FlashRAG](https://github.com/RUC-NLPIR/FlashRAG) as the standard evaluation environment. Here is an example of evaluating the performance of ReSearch-Qwen-7B-Instruct on Bamboogle test set.
+### Evaluation
+
+#### Multi-hop QA
+
+For the evaluation on multi-hop QA, we use [FlashRAG](https://github.com/RUC-NLPIR/FlashRAG) as the standard evaluation environment. To download the evaluation data, please run the following command:
+```bash
+cd data
+bash download_dataset.sh
+```
+Here is an example of evaluating the performance of ReCall-Qwen-7B-Instruct on the Bamboogle test set.
 ```bash
 cd scripts/evaluation
 python run_eval.py \
     --config_path eval_config.yaml \
-    --method_name research \
+    --method_name re-call \
     --data_dir {root/path/to/evaluation/data} \
     --dataset_name bamboogle \
     --split test \
     --save_dir {your-save-dir} \
-    --save_note research_qwen7b_ins
+    --save_note re-call_qwen7b_ins \
     --sgl_remote_url {your-launched-sgl-url} \
     --remote_retriever_url {your-hosted-retriever-url} \
     --generator_model {your-local-model-path} \
-    --apply_chat True
+    --sandbox_url {your-hosted-sandbox-url}
 ```
+For more details about the configuration, please refer to the `scripts/evaluation/eval_config.yaml` file.

-For base model, please use `--apply_chat False` and for instruction-tuned model, please use `--apply_chat True`, for loading correct prompt template when conducting evaluation for *ReSearch* model. For more details about the configuration, please refer to the `scripts/evaluation/eval_config.yaml` file.
+#### BFCL
+We will release the evaluation code on BFCL soon.

 ## 🤝 Acknowledge

-This training implementation is based on [verl](https://github.com/volcengine/verl) and the evaluation is based on [FlashRAG](https://github.com/RUC-NLPIR/FlashRAG). The serving of retriever is based on [FastAPI](https://github.com/fastapi/fastapi). The model serving is based on [SGLang](https://docs.sglang.ai/). *ReSearch* models are trained based on [Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/). We sincerely appreciate their contributions to the open-source community.
+This training implementation is based on [verl](https://github.com/volcengine/verl) and the evaluation is based on [FlashRAG](https://github.com/RUC-NLPIR/FlashRAG) and BFCL. The serving of the sandbox and retriever is based on [FastAPI](https://github.com/fastapi/fastapi). The model serving is based on [SGLang](https://docs.sglang.ai/). *ReCall* models are trained based on [Qwen2.5](https://qwenlm.github.io/blog/qwen2.5/). We sincerely appreciate their contributions to the open-source community.

 ## 📚 Citation

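The `Data Preparation` section added earlier in this diff links to preprocessed parquet files, and `train.sh` now accepts Python-list strings of parquet paths. Below is a small sketch of fetching and inspecting that data; the dataset id comes from the README link, but the file layout and column names are assumptions to verify after download.

```python
# Sketch: download the preprocessed ReCall training data and inspect it.
# The repo id is taken from the README link; the parquet layout and column
# names are assumptions -- check df.columns after loading.
import glob
import os

import pandas as pd
from huggingface_hub import snapshot_download

local_dir = snapshot_download(repo_id="agentrl/ReCall-data", repo_type="dataset")

files = glob.glob(os.path.join(local_dir, "**", "*.parquet"), recursive=True)
df = pd.read_parquet(files[0])
print(files)                # candidate paths for --train_files / --test_files
print(df.columns.tolist())  # inspect the schema before training
print(df.head(1))
```

The resulting paths are what the quoted list arguments such as `--train_files "['train1.parquet', 'train2.parquet']"` in the training command expect.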
assets/eval_bar.png

682 KB

assets/intro_bar.png

-1.26 MB
Binary file not shown.

assets/method.png

-200 KB
Binary file not shown.

assets/overview.png

535 KB

data/prepare_musique.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

scripts/evaluation/eval_config.yaml

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ retrieval_pooling_method: ~ # set automatically if not provided
 # -------------------------------------------------Generator Settings------------------------------------------------#
 framework: sgl_remote # inference framework of the LLM, supporting: 'hf','vllm','fschat'
 sgl_remote_url: "your-sgl-remote-url"
+sandbox_url: "your-sandbox-url"
 generator_model: "the-model-local-path" # name or path of the generator model, for loading the tokenizer
 generator_max_input_len: 8192 # max length of the input
 generation_params:
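This one-line change only registers `sandbox_url` in the YAML; the `run_eval.py` diff below threads the same key through a `config_dict`. Here is a sketch of the expected override behavior, assuming FlashRAG's `Config(config_path, config_dict)` gives dict entries precedence over YAML values (verify against your installed FlashRAG):

```python
# Sketch of the config flow for the new sandbox_url key, mirroring the
# Config(args.config_path, config_dict) call in run_eval.py below.
# The dict-overrides-YAML precedence is an assumption to verify.
from flashrag.config import Config

config_dict = {
    "sandbox_url": "http://{your-sandbox-host}:{port}",  # CLI flag value
    "sgl_remote_url": "http://{your-sglang-host}:{port}",
}
config = Config("eval_config.yaml", config_dict)
print(config["sandbox_url"])  # Config supports dict-style access in FlashRAG
```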

scripts/evaluation/run_eval.py

Lines changed: 7 additions & 6 deletions
@@ -71,35 +71,35 @@ def ircot(args, config_dict):

     result = pipeline.run(test_data)

-def research(args, config_dict):
+def re_call(args, config_dict):
     config = Config(args.config_path, config_dict)
     all_split = get_dataset(config)
     test_data = all_split[args.split]

-    from flashrag.pipeline import ReSearchPipeline
-    pipeline = ReSearchPipeline(config, apply_chat=args.apply_chat)
+    from flashrag.pipeline import ReCallPipeline
+    pipeline = ReCallPipeline(config)
     result = pipeline.run(test_data)

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Running exp")
     parser.add_argument("--config_path", type=str, default="./eval_config.yaml")
-    parser.add_argument("--method_name", type=str, default="research")
+    parser.add_argument("--method_name", type=str, default="re-call")
     parser.add_argument("--data_dir", type=str, default="your-data-dir")
     parser.add_argument("--dataset_name", type=str, default="bamboogle")
     parser.add_argument("--split", type=str, default="test")
     parser.add_argument("--save_dir", type=str, default="your-save-dir")
     parser.add_argument("--save_note", type=str, default='your-save-note-for-identification')
     parser.add_argument("--sgl_remote_url", type=str, default="your-sgl-remote-url")
+    parser.add_argument("--sandbox_url", type=str, default="your-sandbox-url")
     parser.add_argument("--remote_retriever_url", type=str, default="your-remote-retriever-url")
     parser.add_argument("--generator_model", type=str, default="your-local-model-path")
-    parser.add_argument("--apply_chat", type=bool, default=True)

     func_dict = {
         "naive": naive,
         "zero-shot": zero_shot,
         "iterretgen": iterretgen,
         "ircot": ircot,
-        "research": research,
+        "re-call": re_call,
     }

     args = parser.parse_args()

@@ -113,6 +113,7 @@ def research(args, config_dict):
         "sgl_remote_url": args.sgl_remote_url,
         "remote_retriever_url": args.remote_retriever_url,
         "generator_model": args.generator_model,
+        "sandbox_url": args.sandbox_url,
     }

     func = func_dict[args.method_name]
