Learning Tools (Experimental 🧪)
Using Large Language Models (LLMs) with tools has been a popular topic recently, with awesome works such as ToolFormer and ToolBench. In TRL, we provide a simple example of how to teach LLMs to use tools with reinforcement learning.
Here's an overview of the scripts in the trl repository:
- Script to train an LLM to use a calculator with reinforcement learning.
- Script to train an LLM to use a wiki tool to answer questions.
- Script to train an LLM to use a python interpreter to solve math puzzles.
Note that the scripts above rely heavily on the `TextEnvironment` API, which is still under active development; the API may change in the future. Please see TextEnvironment for the related docs.
Learning to Use a Calculator
The rough idea is as follows:
1. Load a tool such as ybelkada/simple-calculator that parses a text calculation like `"14 + 34"` and returns the calculated number:

```python
from transformers import AutoTokenizer, load_tool
tool = load_tool("ybelkada/simple-calculator")
tool_fn = lambda text: str(round(float(tool(text)), 2))  # rounding to 2 decimal places
```

2. Define a reward function that returns a positive reward if the tool returns the correct answer. In the script we create a dummy reward function like `reward_fn = lambda x: 1`, but we override the rewards directly later.

3. Create a prompt that shows how to use the tools:

```python
# system prompt
prompt = """\
What is 13.1-3?

<request><SimpleCalculatorTool>13.1-3<call>10.1<response>

Result=10.1<submit>

What is 4*3?

<request><SimpleCalculatorTool>4*3<call>12<response>

Result=12<submit>

What is 12.1+1?

<request><SimpleCalculatorTool>12.1+1<call>13.1<response>

Result=13.1<submit>

What is 12.1-20?

<request><SimpleCalculatorTool>12.1-20<call>-7.9<response>

Result=-7.9<submit>"""
```

4. Create a `trl.TextEnvironment` with the model:

```python
env = TextEnvironment(
    model,
    tokenizer,
    {"SimpleCalculatorTool": tool_fn},
    reward_fn,
    prompt,
    generation_kwargs=generation_kwargs,
)
```

5. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `<call>` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will visualize the tokens.

6. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss; make sure to pass that argument to `step`.
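Putting these steps together, the core training loop is roughly the following (a minimal sketch; it assumes `env` and `ppo_trainer` have been constructed as above, and the actual calculator script organizes this differently):

```python
# a minimal sketch of the loop; assumes `env` (TextEnvironment) and
# `ppo_trainer` (PPOTrainer) are set up as in the previous steps
tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]
for _ in range(10):  # number of PPO updates, chosen arbitrarily here
    queries, responses, masks, rewards, histories = env.run(tasks)
    train_stats = ppo_trainer.step(queries, responses, rewards, masks)
```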
Experiment results
We trained a model with the above script for 10 random seeds. You can reproduce the runs with the benchmark script in the trl repository; feel free to remove the `--slurm-*` arguments if you don't have access to a Slurm cluster.
We can then use openrlbenchmark to plot the results. While one or two experiments crashed for some reason, most of the runs obtained near-perfect proficiency in the calculator task.
(Early Experiments 🧪): learning to use a wiki tool for question answering
The ToolFormer paper shows an interesting use case that utilizes a Wikipedia search tool to help answer questions. In this section, we attempt to perform similar experiments but use RL instead to teach the model to use a wiki tool on the TriviaQA dataset.
Note that many settings are different so the results are not directly comparable.
Building a search index
Since ToolFormer did not open-source its code, we needed to first replicate the search index. Their paper mentions that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from KILT.
Fortunately, pyserini already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index.
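A minimal version of that lookup, using pyserini's prebuilt `wikipedia-kilt-doc` index, might look like this (a sketch; the exact snippet behind the space may differ):

```python
from pyserini.search.lucene import LuceneSearcher
import json

# load pyserini's prebuilt BM25 index over the KILT Wikipedia dump
searcher = LuceneSearcher.from_prebuilt_index("wikipedia-kilt-doc")

def search(query):
    hits = searcher.search(query, k=1)  # top-1 BM25 hit
    contents = json.loads(hits[0].raw)["contents"]
    return contents

print(search("tennis racket"))
```

For a query like "tennis racket", this prints the contents of the best-matching Wikipedia page.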
We then basically deployed this snippet as a Hugging Face space, so that we can use the space as a `transformers.Tool` later.
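Assuming such a space exists under some id (the id below is a placeholder), it can then be loaded as a tool, e.g.:

```python
from transformers import load_tool

# placeholder space id; substitute the id of the space you deployed
wiki_tool = load_tool("your-username/pyserini-wikipedia-kilt-doc")
print(wiki_tool("tennis racket"))
```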

Experiment settings
We use the following settings:
- use the `bigcode/starcoderbase` model as the base model
- use the `pyserini-wikipedia-kilt-doc` space as the wiki tool, and only use the first paragraphs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool
- test if the response contains the answer string: if so, give a reward of 1, otherwise give a reward of 0 (a sketch of such a reward function is shown after this list)
  - note that this is a simplified evaluation criterion; in ToolFormer, the authors check if the first 20 words of the response contain the correct answer
- use a few-shot prompt that demonstrates the usage of the wiki tool (an illustrative version is shown below, after the reward sketch)
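The answer-matching reward can be implemented along these lines (a sketch of a hypothetical helper, not necessarily the exact function used in the script):

```python
# hypothetical helper implementing the exact-match reward described above
def exact_match_reward(responses, answers):
    """Reward of 1.0 if the answer string appears in the response, else 0.0."""
    return [1.0 if answer in response else 0.0
            for response, answer in zip(responses, answers)]
```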
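An illustrative few-shot prompt in the same `<request>...<call>...<response>...<submit>` format used for the calculator (the `Wiki` tool name and the example question below are assumptions, not the exact prompt used in the experiment):

```python
# illustrative prompt; tool name "Wiki" and the sample Q/A are assumptions
prompt = """\
Answer the following question:

Q: In which country is the Eiffel Tower located?

<request><Wiki>Eiffel Tower<call>The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France.<response>

Result=France<submit>

Q: """
```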
Result and Discussion
Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves mostly trend upward, but one of the experiments crashed.

The Wandb report is available for further inspection.
Note that the correct rate of the trained model is on the low end, which could be due to the following reasons:
- incorrect searches: When given the question `"What is Bruce Willis' real first name?"`, if the model searches for `Bruce Willis`, our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988." But a correct search should surface "Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985–1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988–2013) and other roles."

- unnecessarily long responses: The wiki tool by default sometimes outputs very long sequences. E.g., when the wiki tool searches for "Brown Act":
  - Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 'et seq.', is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies."
  - ToolFormer's wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies.", which is more succinct.

(Early Experiments 🧪): solving math puzzles with a python interpreter
In this section, we attempt to teach the model to use a python interpreter to solve math puzzles. The rough idea is to give the agent a prompt like the following:
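A sketch of such a prompt (the `PythonInterpreter` tool name and the sample question are assumptions; the actual prompt in the script may differ):

```python
# illustrative prompt; tool name "PythonInterpreter" and the sample
# question are assumptions, not the exact prompt from the script
prompt = """\
Example of using a Python API to solve math questions.

Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?

<request><PythonInterpreter>
def solution():
    money_initial = 23
    bagels = 5
    bagel_cost = 3
    money_spent = bagels * bagel_cost
    money_left = money_initial - money_spent
    return money_left
print(solution())
<call>8<response>

Result=8<submit>

Q: """
```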
The training experiment can be found at https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y.
