LlamaIndex 使用 RouterOutputAgentWorkflow

2024-09-06 12:12

本文主要是介绍LlamaIndex 使用 RouterOutputAgentWorkflow,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

LlamaIndex 中提供了一个 RouterOutputAgentWorkflow 功能,可以集成多个 QueryTool,根据用户的输入判断使用那个 QueryEngine,在做查询的时候,可以从不同的数据源进行查询,例如确定的数据从数据库查询,如果是语义查询可以从向量数据库进行查询。本文将实现两个搜索引擎,根据不同 Query 使用不同 QueryEngine。

安装 MySQL 依赖

pip install mysql-connector-python  

搜索引擎

定义搜索引擎,初始两个数据源

  • 使用 MySQL 作为数据库的数据源
  • 使用 VectorIndex 作为语义搜索数据源
from pathlib import Path
from llama_index.core.tools import QueryEngineTool
from llama_index.core import VectorStoreIndex
import llm
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.query_engine import NLSQLTableQueryEngine
from llama_index.core import Settings
from llama_index.core import SQLDatabasefrom sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, select
Settings.llm = llm.get_ollama("mistral-nemo")
Settings.embed_model = llm.get_ollama_embbeding()engine = create_engine('mysql+mysqlconnector://root:123456@localhost:13306/db_llama', echo=True  
)def init_db():# 初始化数据库metadata_obj = MetaData()table_name = "city_stats"city_stats_table = Table(table_name,metadata_obj,Column("city_name", String(16), primary_key=True),Column("population", Integer, ),Column("state", String(16), nullable=False),)metadata_obj.create_all(engine)sql_database = SQLDatabase(engine, include_tables=["city_stats"])from sqlalchemy import insertrows = [{"city_name": "New York City", "population": 8336000, "state": "New York"},{"city_name": "Los Angeles", "population": 3822000, "state": "California"},{"city_name": "Chicago", "population": 2665000, "state": "Illinois"},{"city_name": "Houston", "population": 2303000, "state": "Texas"},{"city_name": "Miami", "population": 449514, "state": "Florida"},{"city_name": "Seattle", "population": 749256, "state": "Washington"},]for row in rows:stmt = insert(city_stats_table).values(**row)with engine.begin() as connection:cursor = connection.execute(stmt)from llama_index.core.query_engine import NLSQLTableQueryEnginesql_database = SQLDatabase(engine, include_tables=["city_stats"])
sql_query_engine = NLSQLTableQueryEngine(sql_database=sql_database,tables=["city_stats"]
)def get_doc_index()-> VectorStoreIndex:'''解析 words'''# 创建 OllamaEmbedding 实例,用于指定嵌入模型和服务的基本 URLollama_embedding = llm.get_ollama_embbeding()# 读取 "./data" 目录中的数据并加载为文档对象documents = SimpleDirectoryReader(input_files=[Path(__file__).parent / "data" / "LA.pdf"]).load_data()# 从文档中创建 VectorStoreIndex,并使用 OllamaEmbedding 作为嵌入模型vector_index = VectorStoreIndex.from_documents(documents, embed_model=ollama_embedding, transformations=[SentenceSplitter(chunk_size=1000, chunk_overlap=20)],)vector_index.set_index_id("vector_index")  # 设置索引 IDvector_index.storage_context.persist("./storage")  # 将索引持久化到 "./storage"return vector_indexllama_index_query_engine = get_doc_index().as_query_engine()sql_tool = QueryEngineTool.from_defaults(query_engine=sql_query_engine,description=("Useful for translating a natural language query into a SQL query over"" a table containing: city_stats, containing the population/state of"" each city located in the USA."),name="sql_tool"
)llama_cloud_tool = QueryEngineTool.from_defaults(query_engine=llama_index_query_engine,description=(f"Useful for answering semantic questions about certain cities in the US."),name="llama_cloud_tool"
)

创建工作流

下图中显示了工作流的节点,绿色背景节点是工作流的动作,例如大模型返回 ToolEvent,ToolEvent 节点执行并返回结果。
在这里插入图片描述
工作流定义代码:

from typing import Dict, List, Any, Optionalfrom llama_index.core.tools import BaseTool
from llama_index.core.llms import ChatMessage
from llama_index.core.llms.llm import ToolSelection, LLM
from llama_index.core.workflow import (Workflow,Event,StartEvent,StopEvent,step,Context
)
from llama_index.core.base.response.schema import Response
from llama_index.core.tools import FunctionTool
from llama_index.utils.workflow import draw_all_possible_flows
from llm import get_ollamafrom docs import enable_traceenable_trace()class InputEvent(Event):"""Input event."""class GatherToolsEvent(Event):"""Gather Tools Event"""tool_calls: Anyclass ToolCallEvent(Event):"""Tool Call event"""tool_call: ToolSelectionclass ToolCallEventResult(Event):"""Tool call event result."""msg: ChatMessageclass RouterOutputAgentWorkflow(Workflow):"""Custom router output agent workflow."""def __init__(self,tools: List[BaseTool],timeout: Optional[float] = 10.0,disable_validation: bool = False,verbose: bool = False,llm: Optional[LLM] = None,chat_history: Optional[List[ChatMessage]] = None,):"""Constructor."""super().__init__(timeout=timeout, disable_validation=disable_validation, verbose=verbose)self.tools: List[BaseTool] = toolsself.tools_dict: Optional[Dict[str, BaseTool]] = {tool.metadata.name: tool for tool in self.tools}self.llm: LLM = llmself.chat_history: List[ChatMessage] = chat_history or []def reset(self) -> None:"""Resets Chat History"""self.chat_history = []@step()async def prepare_chat(self, ev: StartEvent) -> InputEvent:message = ev.get("message")if message is None:raise ValueError("'message' field is required.")# add msg to chat historychat_history = self.chat_historychat_history.append(ChatMessage(role="user", content=message))return InputEvent()@step()async def chat(self, ev: InputEvent) -> GatherToolsEvent | StopEvent:"""Appends msg to chat history, then gets tool calls."""# Put msg into LLM with tools includedchat_res = await self.llm.achat_with_tools(self.tools,chat_history=self.chat_history,verbose=self._verbose,allow_parallel_tool_calls=True)tool_calls = self.llm.get_tool_calls_from_response(chat_res, error_on_no_tool_call=False)ai_message = chat_res.messageself.chat_history.append(ai_message)if self._verbose:print(f"Chat message: {ai_message.content}")# no tool calls, return chat message.if not tool_calls:return StopEvent(result=ai_message.content)return GatherToolsEvent(tool_calls=tool_calls)@step(pass_context=True)async def dispatch_calls(self, ctx: Context, ev: GatherToolsEvent) -> ToolCallEvent:"""Dispatches calls."""tool_calls = ev.tool_callsawait ctx.set("num_tool_calls", len(tool_calls))# trigger tool call eventsfor tool_call in tool_calls:ctx.send_event(ToolCallEvent(tool_call=tool_call))return None@step()async def call_tool(self, ev: ToolCallEvent) -> ToolCallEventResult:"""Calls tool."""tool_call = ev.tool_call# get tool ID and function callid_ = tool_call.tool_idif self._verbose:print(f"Calling function {tool_call.tool_name} with msg {tool_call.tool_kwargs}")# call function and put result into a chat messagetool = self.tools_dict[tool_call.tool_name]output = await tool.acall(**tool_call.tool_kwargs)msg = ChatMessage(name=tool_call.tool_name,content=str(output),role="tool",additional_kwargs={"tool_call_id": id_,"name": tool_call.tool_name})return ToolCallEventResult(msg=msg)@step(pass_context=True)async def gather(self, ctx: Context, ev: ToolCallEventResult) -> StopEvent | None:"""Gathers tool calls."""# wait for all tool call events to finish.tool_events = ctx.collect_events(ev, [ToolCallEventResult] * await ctx.get("num_tool_calls"))if not tool_events:return Nonefor tool_event in tool_events:# append tool call chat messages to historyself.chat_history.append(tool_event.msg)# # after all tool calls finish, pass input event back, restart agent loopreturn InputEvent()from muti_agent import sql_tool, llama_cloud_tool
wf = RouterOutputAgentWorkflow(tools=[sql_tool, llama_cloud_tool], verbose=True, timeout=120, llm=get_ollama("mistral-nemo"))async def main():result = await wf.run(message="Which city has the highest population?")print("RSULT ===============", result)# if __name__ == "__main__":
#     import asyncio#     asyncio.run(main())import gradio as grasync def random_response(message, history):wf.reset()result = await wf.run(message=message)print("RSULT ===============", result)return resultdemo = gr.ChatInterface(random_response, clear_btn=None, title="Qwen2")demo.launch()

输入问题是 “What are five popular travel spots in Los Angeles?”,自动路由到 VectorIndex 进行查询。
在这里插入图片描述
输入问题为 “which city has the most population” 时,调用数据库进行搜索。
在这里插入图片描述

总结

LlamaIndex 中搜索引擎自动路由,根据用户的输入型自动选择所需的搜索引擎,这里有一个需要注意的点,模型需要支持 Function Call。如果 Ollama 本地模型进行推理,不是所有的本地模型都支持Function Call,Llama3.1 和 mistral-nemo 是支持 Function Call 的,可以使用。

这篇关于LlamaIndex 使用 RouterOutputAgentWorkflow的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1141971

相关文章

springboot security快速使用示例详解

《springbootsecurity快速使用示例详解》:本文主要介绍springbootsecurity快速使用示例,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝... 目录创www.chinasem.cn建spring boot项目生成脚手架配置依赖接口示例代码项目结构启用s

Python如何使用__slots__实现节省内存和性能优化

《Python如何使用__slots__实现节省内存和性能优化》你有想过,一个小小的__slots__能让你的Python类内存消耗直接减半吗,没错,今天咱们要聊的就是这个让人眼前一亮的技巧,感兴趣的... 目录背景:内存吃得满满的类__slots__:你的内存管理小助手举个大概的例子:看看效果如何?1.

java中使用POI生成Excel并导出过程

《java中使用POI生成Excel并导出过程》:本文主要介绍java中使用POI生成Excel并导出过程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录需求说明及实现方式需求完成通用代码版本1版本2结果展示type参数为atype参数为b总结注:本文章中代码均为

Spring Boot3虚拟线程的使用步骤详解

《SpringBoot3虚拟线程的使用步骤详解》虚拟线程是Java19中引入的一个新特性,旨在通过简化线程管理来提升应用程序的并发性能,:本文主要介绍SpringBoot3虚拟线程的使用步骤,... 目录问题根源分析解决方案验证验证实验实验1:未启用keep-alive实验2:启用keep-alive扩展建

使用Java实现通用树形结构构建工具类

《使用Java实现通用树形结构构建工具类》这篇文章主要为大家详细介绍了如何使用Java实现通用树形结构构建工具类,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录完整代码一、设计思想与核心功能二、核心实现原理1. 数据结构准备阶段2. 循环依赖检测算法3. 树形结构构建4. 搜索子

GORM中Model和Table的区别及使用

《GORM中Model和Table的区别及使用》Model和Table是两种与数据库表交互的核心方法,但它们的用途和行为存在著差异,本文主要介绍了GORM中Model和Table的区别及使用,具有一... 目录1. Model 的作用与特点1.1 核心用途1.2 行为特点1.3 示例China编程代码2. Tab

SpringBoot使用OkHttp完成高效网络请求详解

《SpringBoot使用OkHttp完成高效网络请求详解》OkHttp是一个高效的HTTP客户端,支持同步和异步请求,且具备自动处理cookie、缓存和连接池等高级功能,下面我们来看看SpringB... 目录一、OkHttp 简介二、在 Spring Boot 中集成 OkHttp三、封装 OkHttp

使用Python实现获取网页指定内容

《使用Python实现获取网页指定内容》在当今互联网时代,网页数据抓取是一项非常重要的技能,本文将带你从零开始学习如何使用Python获取网页中的指定内容,希望对大家有所帮助... 目录引言1. 网页抓取的基本概念2. python中的网页抓取库3. 安装必要的库4. 发送HTTP请求并获取网页内容5. 解

使用Python实现网络设备配置备份与恢复

《使用Python实现网络设备配置备份与恢复》网络设备配置备份与恢复在网络安全管理中起着至关重要的作用,本文为大家介绍了如何通过Python实现网络设备配置备份与恢复,需要的可以参考下... 目录一、网络设备配置备份与恢复的概念与重要性二、网络设备配置备份与恢复的分类三、python网络设备配置备份与恢复实

C#中的 StreamReader/StreamWriter 使用示例详解

《C#中的StreamReader/StreamWriter使用示例详解》在C#开发中,StreamReader和StreamWriter是处理文本文件的核心类,属于System.IO命名空间,本... 目录前言一、什么是 StreamReader 和 StreamWriter?1. 定义2. 特点3. 用