GitHub: https://github.com/ysymyth/ReAct
Human cognition typically exhibits some capacity for self-regulation and strategization: when solving a complex problem, we naturally use working memory to interleave task-relevant actions with reasoning.
This tight synergy between “acting” and “reasoning” allows humans to learn new tasks quickly and perform robust decision making or reasoning, even under previously unseen circumstances or facing information uncertainties.
Although techniques such as Chain-of-Thought (CoT) prompting can improve a large language model's step-by-step reasoning through in-context prompts, CoT remains a static black box: it reasons only over the model's internal representations, so its conclusions can drift away from real-world knowledge, and it cannot react to new information or update its knowledge mid-reasoning. This can trigger hallucinations during the reasoning process.

Some prior work has studied embedded reasoning, but none has investigated how to combine reasoning and acting in a synergistic way to solve general tasks, or whether such a combination yields systematic benefits over reasoning or acting alone.

This paper proposes ReAct, which uses prompting to let a large language model interleave actions and reasoning. The figure below illustrates how ReAct works on two examples (HotpotQA and ALFWorld), compared against three baselines (standard ICL, CoT, and Act-only):

Note that neither standard ICL nor CoT arrives at the correct answer, because neither uses actions to obtain feedback from the environment; Act-only, for its part, never truly combines acting with reasoning.
To review the Act-only setting (analogous to the reinforcement learning setup): at time step $t$, the agent observes the environment's feedback $o_t \in \mathcal{O}$ (where $\mathcal{O}$ is the observation space) and then takes an action $a_t \in \mathcal{A}$ (where $\mathcal{A}$ is the action space). The agent therefore needs to learn a policy $\pi(a_t \mid c_t)$ that produces an appropriate action $a_t$ conditioned on the current state and the history of past interactions, $c_t = (o_1, a_1, \cdots, o_{t-1}, a_{t-1}, o_t)$.
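To make the formalism concrete, here is a minimal sketch of the Act-only loop, assuming a gym-style `env` and a hypothetical `agent_policy` function standing in for $\pi(a_t \mid c_t)$ (both names are illustrative, not from the paper):

```python
# Minimal Act-only loop: the context c_t is just the running
# (o_1, a_1, ..., o_t) history; every model output is an environment action.
def act_only_episode(env, agent_policy, max_steps=10):
    context = [env.reset()]                  # c_1 = (o_1)
    reward = 0
    for _ in range(max_steps):
        action = agent_policy(context)       # a_t ~ pi(a_t | c_t)
        obs, reward, done, info = env.step(action)
        context += [action, obs]             # c_{t+1} = (c_t, a_t, o_{t+1})
        if done:
            break
    return context, reward
```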
ReAct's key innovation is to augment the action space with a language-based reasoning action space $\mathcal{L}$, i.e., $\hat{\mathcal{A}} = \mathcal{A} \cup \mathcal{L}$. In other words, at any time step the agent can either execute a concrete action $a_t$ or trigger a reasoning step (a "thought") $\hat{a}_t$. Crucially, triggering a reasoning step does not affect the environment.
A thought $\hat{a}_t$ aims to compose useful information by reasoning over the current context $c_t$, and updates the context to $c_{t+1} = (c_t, \hat{a}_t)$ to support future reasoning or acting.
In other words, ReAct extends Act-only with an additional reasoning action. While interacting with the external environment through actions, the language model can reason at the right moments and use those thoughts to adjust its subsequent actions. Compared with Act-only, this reasoning can constrain and correct actions that would otherwise be erroneous or infeasible.
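A minimal sketch of this extension, with the same hypothetical `env` and `agent_policy` names as above: the only change from the Act-only loop is that thoughts (marked here with a `think[...]` prefix) are appended to the context without being sent to the environment:

```python
# ReAct loop: actions come from the augmented space A ∪ L. A thought
# a_hat_t only rewrites the context (c_{t+1} = (c_t, a_hat_t));
# only concrete actions are executed in the environment.
def react_episode(env, agent_policy, max_steps=10):
    context = [env.reset()]
    reward = 0
    for _ in range(max_steps):
        action = agent_policy(context)
        if action.startswith("think["):      # reasoning action a_hat_t in L
            context += [action]              # no env.step(): the env is untouched
            continue
        obs, reward, done, info = env.step(action)  # concrete action a_t in A
        context += [action, obs]
        if done:
            break
    return context, reward
```

(The `WikiEnv` implementation later in this post instead accepts `think[...]` as a no-op action and simply replies "Nice thought."; either way the environment state is unchanged.)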
As shown in the figure below:

In this example, Thought 1 amounts to planning; Thought 2 extracts the valuable information from Obs 1; Thought 3 redirects the action; Thought 4 summarizes and draws the conclusion. Without these reasoning steps, execution could well get stuck at Act 3, and the model would start to hallucinate.

A few details deserve attention when implementing ReAct.
(1) The language-based reasoning action space $\mathcal{L}$ is effectively unbounded, so it must be constrained through the prompt, i.e., in-context instructions tell the model which actions are available for the current task.
In ReAct, hand-written in-context exemplars are provided for each domain; each exemplar is a thought-action-observation sequence.
Take HotpotQA as an example:
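For reference, here is one such exemplar in the thought-action-observation format, abridged (observations shortened) from the `webthink_simple6` prompt in the repo:

```
Question: What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into?
Thought 1: I need to search Colorado orogeny, find the area that the eastern sector of the Colorado orogeny extends into, then find the elevation range of the area.
Action 1: Search[Colorado orogeny]
Observation 1: The Colorado orogeny was an episode of mountain building (an orogeny) in Colorado and surrounding areas.
Thought 2: It does not mention the eastern sector. So I need to look up eastern sector.
Action 2: Lookup[eastern sector]
Observation 2: (Result 1 / 1) The eastern sector extends into the High Plains and is called the Central Plains orogeny.
Thought 3: The eastern sector of Colorado orogeny extends into the High Plains. So I need to search High Plains and find its elevation range.
Action 3: Search[High Plains (United States)]
Observation 3: The High Plains are a subregion of the Great Plains. [...] the High Plains rise in elevation from around 1,800 to 7,000 ft.
Thought 4: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer is 1,800 to 7,000 ft.
Action 4: Finish[1,800 to 7,000 ft]
```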
So ReAct operates essentially the same way as CoT; the difference is that Act and Reasoning steps are interleaved, inserted into one another within the trace.
(2) Sparsity of reasoning versus action

In decision-making settings, a thought does not have to follow every action: thoughts can appear sparsely, only at the points in the trajectory where they matter, with the model itself deciding when to think. As before, everything is implemented through ICL exemplars. Two tasks are involved; for each, an action space is defined and a prompt is constructed. Each prompt contains several human-annotated thought-action-observation sequences, all written according to a predefined pattern, as in the illustrative snippet below:
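As an illustration (not a verbatim prompt), the ALFWorld-style traces in the repo mark a thought with `think:`, to which the environment simply replies `OK.`; note that a single thought suffices for a long run of actions:

```
You are in the middle of a room. Looking around you, you see a countertop 1 and a drawer 1.
Your task is to: put a pepper shaker in a drawer.
> think: To solve the task, I need to find a pepper shaker, then put it in a drawer.
OK.
> go to countertop 1
On the countertop 1, you see a bread 1 and a pepper shaker 1.
> take pepper shaker 1 from countertop 1
You pick up the pepper shaker 1 from the countertop 1.
> go to drawer 1
The drawer 1 is closed.
> open drawer 1
You open the drawer 1.
> put pepper shaker 1 in/on drawer 1
You put the pepper shaker 1 in/on the drawer 1.
```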
Combining parametric (internal) knowledge with external knowledge
The table below compares how ReAct and CoT behave on HotpotQA:

Of all tasks ReAct completes successfully, 94% are genuinely correct.

Some tasks fail; most failures stem from reasoning errors or from erroneous or failed searches. A further portion of the predictions are substantively correct but do not match the gold label exactly.
For decision making, two tasks are selected: ALFWorld and WebShop. Their prompts are shown below:
Taking HotpotQA as an example, a test sample looks like this:
```json
{
  "question": "What government position was held by the woman who portrayed Corliss Archer in the film Kiss and Tell?",
  "answer": "Chief of Protocol",
  "type": "bridge"
}
```
Because HotpotQA requires an environment that can search Wikipedia, we first define `WikiEnv`.
```python
import ast
import json
import time

import gym        # reinforcement learning toolkit; provides the Env/Space interfaces used here
import requests
from bs4 import BeautifulSoup  # parse the HTML of Wikipedia pages
# import wikipedia


def clean_str(p):
    return p.encode().decode("unicode-escape").encode("latin1").decode("utf-8")


class textSpace(gym.spaces.Space):
    def contains(self, x) -> bool:
        """Return boolean specifying if x is a valid member of this space."""
        return isinstance(x, str)


class WikiEnv(gym.Env):

    def __init__(self):
        """Initialize the environment."""
        super().__init__()
        self.page = None            # current Wikipedia page
        self.obs = None             # current observation
        self.lookup_keyword = None  # current lookup keyword
        self.lookup_list = None     # list of paragraphs containing current lookup keyword
        self.lookup_cnt = None      # current lookup index
        self.steps = 0              # current number of steps
        self.answer = None          # current answer from the agent
        self.observation_space = self.action_space = textSpace()
        self.search_time = 0
        self.num_searches = 0

    def _get_obs(self):
        return self.obs

    def _get_info(self):
        return {"steps": self.steps, "answer": self.answer}

    def reset(self, seed=None, return_info=False, options=None):
        # We need the following line to seed self.np_random
        # super().reset(seed=seed)
        self.obs = ("Interact with Wikipedia using search[], lookup[], and "
                    "finish[].\n")
        self.page = None
        self.lookup_keyword = None
        self.lookup_list = None
        self.lookup_cnt = None
        self.steps = 0
        self.answer = None
        observation = self._get_obs()
        info = self._get_info()
        return (observation, info) if return_info else observation

    def construct_lookup_list(self, keyword):
        # lookup action: collect all sentences on the current page containing the keyword
        # find all paragraphs
        if self.page is None:
            return []
        paragraphs = self.page.split("\n")
        paragraphs = [p.strip() for p in paragraphs if p.strip()]

        # find all sentences
        sentences = []
        for p in paragraphs:
            sentences += p.split('. ')
        sentences = [s.strip() + '.' for s in sentences if s.strip()]

        parts = [p for p in sentences if keyword.lower() in p.lower()]
        return parts

    @staticmethod
    def get_page_obs(page):
        # summarize a page as its first five sentences
        paragraphs = page.split("\n")
        paragraphs = [p.strip() for p in paragraphs if p.strip()]

        sentences = []
        for p in paragraphs:
            sentences += p.split('. ')
        sentences = [s.strip() + '.' for s in sentences if s.strip()]
        return ' '.join(sentences[:5])

    def search_step(self, entity):
        # search action: query Wikipedia for the entity and parse the result page
        entity_ = entity.replace(" ", "+")
        search_url = f"https://en.wikipedia.org/w/index.php?search={entity_}"
        old_time = time.time()
        response_text = requests.get(search_url).text
        self.search_time += time.time() - old_time
        self.num_searches += 1
        soup = BeautifulSoup(response_text, features="html.parser")
        result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
        if result_divs:  # mismatch: no exact page, report similar titles instead
            self.result_titles = [clean_str(div.get_text().strip()) for div in result_divs]
            self.obs = f"Could not find {entity}. Similar: {self.result_titles[:5]}."
        else:
            page = [p.get_text().strip() for p in soup.find_all("p") + soup.find_all("ul")]
            if any("may refer to:" in p for p in page):
                # disambiguation page: retry with brackets around the entity
                self.search_step("[" + entity + "]")
            else:
                self.page = ""
                for p in page:
                    if len(p.split(" ")) > 2:
                        self.page += clean_str(p)
                        if not p.endswith("\n"):
                            self.page += "\n"
                self.obs = self.get_page_obs(self.page)
                self.lookup_keyword = self.lookup_list = self.lookup_cnt = None

    def step(self, action):
        # Dispatch on the action type:
        #   search[entity]  -> self.search_step
        #   lookup[keyword] -> self.construct_lookup_list
        #   finish[answer]  -> record the answer and end the episode
        #   think[...]      -> no-op: thoughts do not change the environment state
        reward = 0
        done = False
        action = action.strip()
        if self.answer is not None:  # already finished
            done = True
            return self.obs, reward, done, self._get_info()

        if action.startswith("search[") and action.endswith("]"):
            entity = action[len("search["):-1]
            self.search_step(entity)
        elif action.startswith("lookup[") and action.endswith("]"):
            keyword = action[len("lookup["):-1]
            if self.lookup_keyword != keyword:  # reset lookup
                self.lookup_keyword = keyword
                self.lookup_list = self.construct_lookup_list(keyword)
                self.lookup_cnt = 0
            if self.lookup_cnt >= len(self.lookup_list):
                self.obs = "No more results.\n"
            else:
                self.obs = (f"(Result {self.lookup_cnt + 1} / {len(self.lookup_list)}) "
                            + self.lookup_list[self.lookup_cnt])
                self.lookup_cnt += 1
        elif action.startswith("finish[") and action.endswith("]"):
            answer = action[len("finish["):-1]
            self.answer = answer
            done = True
            # reward stays 0 here; the task wrapper computes the real reward
            self.obs = f"Episode finished, reward = {reward}\n"
        elif action.startswith("think[") and action.endswith("]"):
            self.obs = "Nice thought."
        else:
            self.obs = "Invalid action: {}".format(action)

        self.steps += 1
        return self.obs, reward, done, self._get_info()

    def get_time_info(self):
        speed = self.search_time / self.num_searches if self.num_searches else 0
        return {
            "call_speed": speed,
            "call_time": self.search_time,
            "num_calls": self.num_searches,
        }
```
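As a quick sanity check, the environment can be driven by hand (illustrative; the observation text depends on live Wikipedia):

```python
env = WikiEnv()
print(env.reset())                      # usage instructions for the agent
obs, r, done, info = env.step("search[Colorado orogeny]")
print(obs)                              # first sentences of the page, or similar titles
obs, r, done, info = env.step("lookup[eastern sector]")
print(obs)                              # "(Result 1 / N) ..." sentence containing the keyword
obs, r, done, info = env.step("finish[1,800 to 7,000 ft]")
print(done, info["answer"])             # True, plus the recorded answer
```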
Next, define a Wrapper that encapsulates everything related to HotpotQA, including the environment.
```python
import json

import gym
import numpy as np

# DATA_DIR, HOTPOTQA_SPLIT_FILE, normalize_answer and f1_score are defined
# elsewhere in the repo's wrappers.py.

class HotPotQAWrapper(gym.Wrapper):
    def __init__(self, env, split):
        super().__init__(env)
        data_file = f"{DATA_DIR}/{HOTPOTQA_SPLIT_FILE[split]}"
        self.data = json.load(open(data_file))
        self.data = [(d['question'], d['answer']) for d in self.data]
        self.data_idx = 0
        self.split = split

    def reset(self, seed=None, return_info=False, options=None, idx=None):
        self.env.reset(seed=seed, return_info=return_info, options=options)
        try:
            self.env.step('')
        except:
            pass
        self.env.reset(seed=seed, return_info=return_info, options=options)
        # pick a random question unless a specific index is given
        self.data_idx = int(np.random.randint(len(self.data))) if idx is None else idx
        observation = f"Question: {self.data[self.data_idx][0]}"
        info = self._get_info()
        return (observation, info) if return_info else observation

    def _get_info(self):
        return {
            "steps": self.steps,
            "answer": self.answer,
            "question": self.data[self.data_idx][0],
            "hotpot_split": self.split,
        }

    def get_reward(self, info):
        if info['answer'] is not None:
            # NOTE: the names are swapped (pred holds the gold answer, gt the
            # prediction), but exact match is symmetric, so the score is unaffected
            pred = normalize_answer(self.data[self.data_idx][1])
            gt = normalize_answer(info['answer'])
            score = (pred == gt)
            return int(score)
        return 0

    def get_metrics(self, info):
        if info['answer'] is not None:
            pred = normalize_answer(self.data[self.data_idx][1])
            gt = normalize_answer(info['answer'])
            em = (pred == gt)
            f1 = f1_score(pred, gt)[0]
            return {'reward': em, 'em': em, 'f1': f1}
        return {'reward': 0, 'em': 0, 'f1': 0}

    def step(self, action):
        # TODO: first step obs does not have question.
        obs, _, done, info = self.env.step(action)
        # the reward here mainly adapts to the gym interface; in practice the
        # evaluation metric itself (e.g., EM) could be used directly
        reward = self.get_reward(info)
        if done:
            obs = f"Episode finished, reward = {reward}\n"
            info.update({"gt_answer": self.data[self.data_idx][1], "question_idx": self.data_idx})
            info.update(self.get_metrics(info))
        return obs, reward, done, info

    def __len__(self):
        return len(self.data)
```
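`normalize_answer` and `f1_score` are not shown in the excerpt above; a minimal sketch of the standard SQuAD-style versions, consistent with how the wrapper calls them (`f1_score(...)` returning a tuple whose first element is the F1 value), might look like:

```python
import re
import string
from collections import Counter

def normalize_answer(s):
    # standard SQuAD normalization: lowercase, drop punctuation/articles/extra whitespace
    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)
    def white_space_fix(text):
        return " ".join(text.split())
    def remove_punc(text):
        return "".join(ch for ch in text if ch not in set(string.punctuation))
    return white_space_fix(remove_articles(remove_punc(s.lower())))

def f1_score(prediction, ground_truth):
    # token-level F1; returns (f1, precision, recall) to match f1_score(...)[0] above
    pred_tokens = prediction.split()
    gt_tokens = ground_truth.split()
    common = Counter(pred_tokens) & Counter(gt_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0, 0, 0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gt_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall
```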
With the Env and Wrapper defined, we can now run ReAct.
```python
import os

import openai
import requests

import wikienv, wrappers

env = wikienv.WikiEnv()
env = wrappers.HotPotQAWrapper(env, split="dev")
env = wrappers.LoggingWrapper(env)

openai.api_key = os.environ["OPENAI_API_KEY"]

def llm(prompt, stop=["\n"]):
    # greedy decoding (temperature=0); generation halts at the given stop strings
    response = openai.Completion.create(
        model="text-davinci-002",
        prompt=prompt,
        temperature=0,
        max_tokens=100,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=stop,
    )
    return response["choices"][0]["text"]

def step(env, action):
    # retry on Wikipedia request timeouts, up to 10 attempts
    attempts = 0
    while attempts < 10:
        try:
            return env.step(action)
        except requests.exceptions.Timeout:
            attempts += 1
```
```python
import json
import random
import time

idxs = list(range(7405))            # all tasks in the split
random.Random(233).shuffle(idxs)    # fixed-seed shuffle

# First load the in-context exemplars and build the prompt
folder = './prompts/'
prompt_file = 'prompts_naive.json'
with open(folder + prompt_file, 'r') as f:
    prompt_dict = json.load(f)

webthink_examples = prompt_dict['webthink_simple6']
instruction = """Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types: 
(1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.
(2) Lookup[keyword], which returns the next sentence containing keyword in the current passage.
(3) Finish[answer], which returns the answer and finishes the task.
Here are some examples.
"""
# Instruction + in-context exemplars
webthink_prompt = instruction + webthink_examples

def webthink(idx=None, prompt=webthink_prompt, to_print=True):
    # reset the environment to the question with index idx
    question = env.reset(idx=idx)
    if to_print:
        print(idx, question)
    prompt += question + "\n"   # append the question to the prompt
    n_calls, n_badcalls = 0, 0
    # one task runs at most 7 thought/action rounds
    for i in range(1, 8):
        n_calls += 1
        # Given the current prompt (initially instruction + exemplars, later the
        # growing trajectory), ask the LLM for the next thought and action
        thought_action = llm(prompt + f"Thought {i}:", stop=[f"\nObservation {i}:"])
        try:
            thought, action = thought_action.strip().split(f"\nAction {i}: ")
        except:
            print('ohh...', thought_action)
            n_badcalls += 1
            n_calls += 1
            # fall back: keep the thought, then ask the LLM again for the action only
            thought = thought_action.strip().split('\n')[0]
            action = llm(prompt + f"Thought {i}: {thought}\nAction {i}:", stop=[f"\n"]).strip()
        # Execute the action in the environment and get the observation.
        # step() dispatches search/lookup/finish etc. to the external tools.
        obs, r, done, info = step(env, action[0].lower() + action[1:])
        obs = obs.replace('\\n', '')
        # fold Thought, Action and Observation back into the prompt
        step_str = f"Thought {i}: {thought}\nAction {i}: {action}\nObservation {i}: {obs}\n"
        prompt += step_str
        if to_print:
            print(step_str)
        if done:
            break
    if not done:
        obs, r, done, info = step(env, "finish[]")
    if to_print:
        print(info, '\n')
    info.update({'n_calls': n_calls, 'n_badcalls': n_badcalls, 'traj': prompt})
    return r, info

rs = []
infos = []
old_time = time.time()
for i in idxs[:500]:
    # run inference on task number i
    r, info = webthink(i, to_print=True)
    rs.append(info['em'])
    infos.append(info)
    print(sum(rs), len(rs), sum(rs) / len(rs), (time.time() - old_time) / len(rs))
    print('-----------')
    print()
```