Implementing Text-to-SQL with LangChain and SQLAgent

LangChain currently offers two ways to interact with data stored in a database: the SQL Chain (SQLDatabaseChain) and the SQL Agent (SQLAgent, built with create_sql_agent).

See the previous article, "Implementing TextToSql with LangChain" (CSDN blog): https://blog.csdn.net/wangning0714/article/details/133915510?spm=1001.2014.3001.5502

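LangChain makes it easier to build complex applications on top of LLMs. It lets you connect an LLM to other data sources, so the application can work with a much wider range of information and becomes more capable and versatile.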

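An agent is a component that is driven by a large language model (LLM) and has access to a set of tools. What sets it apart is that it decides, based on the user's input, which tool to use next, and keeps going until it arrives at a satisfactory answer. In text-to-SQL, for example, if the generated SQL fails to execute, the LangChain SQLAgent does not give up. It passes the error to the LLM in a follow-up call and tries to correct the query. In theory, therefore, SQLAgent should outperform SQLChain in both productivity and accuracy.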
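The error-recovery behaviour described above can be sketched as a plain loop. The loop generates SQL, tries to execute it, and on failure passes the database error back to the generator. This is a minimal illustration under stated assumptions, not LangChain's implementation. fake_llm is a hypothetical stub standing in for the LLM call, and an in-memory SQLite database (Python's sqlite3 module) stands in for MySQL.

```python
import sqlite3
from typing import Callable, Optional


# Hypothetical stand-in for an LLM call: receives the question and the last
# database error (None on the first attempt) and returns a SQL string.
def fake_llm(question: str, last_error: Optional[str]) -> str:
    if last_error is None:
        # First attempt deliberately uses a column name that does not exist.
        return "SELECT name FROM users ORDER BY user_id"
    # After "seeing" the error, the stub switches to the correct column.
    return "SELECT username FROM users ORDER BY user_id"


def answer_with_retries(conn: sqlite3.Connection, question: str,
                        generate_sql: Callable[[str, Optional[str]], str],
                        max_attempts: int = 3) -> list:
    last_error: Optional[str] = None
    for _ in range(max_attempts):
        sql = generate_sql(question, last_error)
        try:
            return conn.execute(sql).fetchall()
        except sqlite3.Error as exc:
            # Feed the database error into the next generation step.
            last_error = str(exc)
    raise RuntimeError(f"giving up after {max_attempts} attempts: {last_error}")


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (user_id INTEGER PRIMARY KEY, username TEXT, email TEXT)")
conn.executemany("INSERT INTO users VALUES (?, ?, ?)",
                 [(1, "user1", None), (2, "user2", None), (3, "user3", None)])

rows = answer_with_retries(conn, "List all usernames", fake_llm)
print(rows)  # [('user1',), ('user2',), ('user3',)]
```

The first attempt fails with "no such column: name". The second attempt sees that message and succeeds. A real agent replaces the stub with an LLM prompt that includes the error text.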

This article shows how to let SQLAgent run operations against our MySQL database.

1. Dependencies

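Installing MySQL is not covered here. Make sure you have a working OpenAI API key. Besides langchain and langchain_experimental (imported below), the mysql+pymysql connection string used later also requires the PyMySQL driver.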

from langchain.llms import OpenAI
from langchain.sql_database import SQLDatabase
# SQLDatabaseChain is only needed for the SQLChain approach from the previous article
from langchain_experimental.sql.base import SQLDatabaseChain
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents import create_sql_agent
from langchain.agents.agent_types import AgentType

MySQL table structure. The SQL used to create the tables is as follows:

-- Users (id, name, email)
CREATE TABLE users (
    user_id INT PRIMARY KEY,
    username VARCHAR(255),
    email VARCHAR(255)
);

-- Products (product id, product name); created before orders, which references it
CREATE TABLE products (
    product_id INT PRIMARY KEY,
    product_name VARCHAR(255)
);

-- Orders (order number, user id, product id, order date)
CREATE TABLE orders (
    order_id INT PRIMARY KEY,
    user_id INT,
    product_id INT,
    order_date DATE,
    FOREIGN KEY (user_id) REFERENCES users(user_id),
    FOREIGN KEY (product_id) REFERENCES products(product_id)
);


-- Insert test data (parent tables first so the foreign keys in orders are satisfied)
INSERT INTO users (user_id, username, email)
VALUES 
(1, 'user1', '[email protected]'),
(2, 'user2', '[email protected]'),
(3, 'user3', '[email protected]');
INSERT INTO products (product_id, product_name)
VALUES 
(1, 'Product A'),
(2, 'Product B'),
(3, 'Product C');
INSERT INTO orders (order_id, user_id, product_id, order_date)
VALUES 
(1, 1, 1, '2023-01-15'),
(2, 1, 1, '2023-02-10'),
(3, 2, 2, '2023-03-05'),
(4, 3, 3, '2023-04-20');
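To try the queries without a MySQL server, you can load the same schema and data into an in-memory SQLite database with Python's standard sqlite3 module. This is only a stand-in. It assumes SQLite's INTEGER type in place of MySQL's INT and leaves the email values NULL. It also shows why orders has to be created and filled after its parent tables once foreign keys are enforced, since it references both users and products.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# SQLite enforces foreign keys only when this pragma is switched on.
conn.execute("PRAGMA foreign_keys = ON")
conn.executescript("""
CREATE TABLE users (user_id INTEGER PRIMARY KEY, username VARCHAR(255), email VARCHAR(255));
CREATE TABLE products (product_id INTEGER PRIMARY KEY, product_name VARCHAR(255));
CREATE TABLE orders (
    order_id INTEGER PRIMARY KEY,
    user_id INTEGER REFERENCES users(user_id),
    product_id INTEGER REFERENCES products(product_id),
    order_date DATE
);
-- Email values omitted in this stand-in.
INSERT INTO users VALUES (1, 'user1', NULL), (2, 'user2', NULL), (3, 'user3', NULL);
INSERT INTO products VALUES (1, 'Product A'), (2, 'Product B'), (3, 'Product C');
INSERT INTO orders VALUES
    (1, 1, 1, '2023-01-15'), (2, 1, 1, '2023-02-10'),
    (3, 2, 2, '2023-03-05'), (4, 3, 3, '2023-04-20');
""")

counts = {t: conn.execute(f"SELECT COUNT(*) FROM {t}").fetchone()[0]
          for t in ("users", "products", "orders")}
print(counts)  # {'users': 3, 'products': 3, 'orders': 4}

# An order pointing at a product that does not exist is rejected.
try:
    conn.execute("INSERT INTO orders VALUES (5, 1, 99, '2023-05-01')")
    fk_enforced = False
except sqlite3.IntegrityError:
    fk_enforced = True
print(fk_enforced)  # True
```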

2. Setting up the environment

This step is similar to the previous article.

# MySQL connection (the mysql+pymysql URI requires the PyMySQL driver)
db_user = "root"
db_password = "root"
db_host = "localhost"
db_name = "test_001"
db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}")
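The f-string above works for a password like "root". A password that contains characters such as "@", ":" or "/" would be misread as part of the URI structure. SQLAlchemy's documentation recommends URL-encoding such values with urllib.parse.quote_plus. The helper below is a hypothetical sketch of that. Its name, and the default port 3306 (MySQL's standard port), are assumptions rather than part of the original code.

```python
from urllib.parse import quote_plus


def build_mysql_uri(user: str, password: str, host: str, db_name: str,
                    port: int = 3306) -> str:
    """Hypothetical helper: SQLAlchemy URI for the PyMySQL driver.

    quote_plus percent-encodes reserved characters in the credentials so
    they cannot be confused with the URI's own separators.
    """
    return (f"mysql+pymysql://{quote_plus(user)}:{quote_plus(password)}"
            f"@{host}:{port}/{db_name}")


print(build_mysql_uri("root", "p@ss:word", "localhost", "test_001"))
# mysql+pymysql://root:p%40ss%3Aword@localhost:3306/test_001
```

The resulting string can be passed to SQLDatabase.from_uri in the same way as the f-string above.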

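# Use gpt-3.5 as the model
OPENAI_API_KEY = "your OpenAI API key"
# Note: newer LangChain versions expect chat models such as gpt-3.5-turbo to be created with ChatOpenAI (langchain.chat_models)
llm = OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY, model_name='gpt-3.5-turbo')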

3. Configuring SQLAgent

Create a SQLDatabaseToolkit from the database and the LLM, then pass it to create_sql_agent:

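# The toolkit exposes db to the agent as SQL tools (list tables, show schema,
# check a query, run a query); the llm is used by the query-checker tool
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,  # print each Thought / Action / Observation step
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,  # ReAct agent that picks tools by their descriptions
)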
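The run output in the next section shows the agent calling tools named sql_db_list_tables, sql_db_schema and sql_db_query. The rough imitation below is meant to make concrete what each of those tools hands back to the LLM. It returns a comma-separated list of tables, CREATE TABLE statements with a few sample rows, and the raw query result or an error message the agent can react to. The functions are hypothetical and only reuse the tool names. They use stdlib sqlite3, not LangChain's code, and SQLite stands in for MySQL.

```python
import sqlite3


def sql_db_list_tables(conn: sqlite3.Connection) -> str:
    """Comma-separated table names."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    return ", ".join(name for (name,) in rows)


def sql_db_schema(conn: sqlite3.Connection, table_names: str, sample_rows: int = 3) -> str:
    """CREATE TABLE statement plus a few sample rows for each requested table."""
    parts = []
    for name in (t.strip() for t in table_names.split(",")):
        found = conn.execute(
            "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (name,)
        ).fetchone()
        if found is None:
            return f"Error: table {name!r} not found"
        cur = conn.execute(f'SELECT * FROM "{name}" LIMIT {int(sample_rows)}')
        header = "\t".join(col[0] for col in cur.description)
        body = "\n".join("\t".join(map(str, row)) for row in cur.fetchall())
        parts.append(f"{found[0]}\n/*\n{sample_rows} rows from {name} table:\n{header}\n{body}\n*/")
    return "\n\n".join(parts)


def sql_db_query(conn: sqlite3.Connection, query: str) -> str:
    """Run a query; return the rows as text, or the error text for the agent to act on."""
    try:
        return str(conn.execute(query).fetchall())
    except sqlite3.Error as exc:
        return f"Error: {exc}"


# Hypothetical registry: tool name -> function, as an agent loop might dispatch them.
TOOLS = {
    "sql_db_list_tables": sql_db_list_tables,
    "sql_db_schema": sql_db_schema,
    "sql_db_query": sql_db_query,
}

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE products (product_id INTEGER PRIMARY KEY, product_name VARCHAR(255));
INSERT INTO products VALUES (1, 'Product A'), (2, 'Product B');
""")
print(TOOLS["sql_db_list_tables"](conn))
print(TOOLS["sql_db_schema"](conn, "products"))
print(TOOLS["sql_db_query"](conn, "SELECT COUNT(*) FROM products"))
print(TOOLS["sql_db_query"](conn, "SELECT nope FROM products"))
```

Returning errors as text rather than raising them is what lets the agent read the error as an Observation and try again.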

4. Running a query

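question = "Show me each user's orders and the products in those orders"
# In newer LangChain versions .run() is deprecated; the equivalent is
# agent_executor.invoke({"input": question})
agent_executor.run(question)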

The run output:

D:\Anaconda\envs\myenv\python.exe E:/Text-to-sql/git/pergit_wangning/code/1007code/TextToSql/LLMToSql-带有agent实现.py

> Entering new AgentExecutor chain...

Action: sql_db_list_tables
Action Input: ""
Observation: actor, magic_api_backup, magic_api_file, magic_api_info, orders, products, t_iit_m_actreg, t_iit_region, users
Thought:There are several tables in the database that could be relevant to the question. I should query the schema of the "users", "orders", and "products" tables to see what columns they have.
Action: sql_db_schema
Action Input: "users, orders, products"
D:\Anaconda\envs\myenv\lib\site-packages\langchain\__init__.py:34: UserWarning: Importing llm_cache from langchain root module is no longer supported. Please use langchain.globals.set_llm_cache() / langchain.globals.get_llm_cache() instead.
  warnings.warn(
Observation: 
CREATE TABLE orders (
	order_id INTEGER(11) NOT NULL, 
	user_id INTEGER(11), 
	product_id INTEGER(11), 
	order_date DATE, 
	PRIMARY KEY (order_id), 
	CONSTRAINT orders_ibfk_1 FOREIGN KEY(user_id) REFERENCES users (user_id), 
	CONSTRAINT orders_ibfk_2 FOREIGN KEY(product_id) REFERENCES products (product_id)
)ENGINE=InnoDB DEFAULT CHARSET=utf8

/*
3 rows from orders table:
order_id	user_id	product_id	order_date
1	1	1	2023-01-15
2	1	1	2023-02-10
3	2	2	2023-03-05
*/


CREATE TABLE products (
	product_id INTEGER(11) NOT NULL, 
	product_name VARCHAR(255), 
	PRIMARY KEY (product_id)
)ENGINE=InnoDB DEFAULT CHARSET=utf8

/*
3 rows from products table:
product_id	product_name
1	Product A
2	Product B
3	Product C
*/


CREATE TABLE users (
	user_id INTEGER(11) NOT NULL, 
	username VARCHAR(255), 
	email VARCHAR(255), 
	PRIMARY KEY (user_id)
)ENGINE=InnoDB DEFAULT CHARSET=utf8

/*
3 rows from users table:
user_id	username	email
1	user1	[email protected]
2	user2	[email protected]
3	user3	[email protected]
*/
Thought:The relevant tables for this question are "users", "orders", and "products". The "orders" table has columns for order_id, user_id, product_id, and order_date. The "products" table has columns for product_id and product_name. The "users" table has columns for user_id, username, and email. I can join these tables on the appropriate columns to get the desired information.

Action: sql_db_query
Action Input: "SELECT users.username, orders.order_id, products.product_name FROM users JOIN orders ON users.user_id = orders.user_id JOIN products ON orders.product_id = products.product_id LIMIT 10"
Observation: [('user1', 1, 'Product A'), ('user1', 2, 'Product A'), ('user2', 3, 'Product B'), ('user3', 4, 'Product C')]
Thought:The query returned the usernames, order IDs, and product names for the users, orders, and products tables. I can use this information to answer the original question.
Final Answer: The users correspond to the orders and the products in the following way:
- User 1 has ordered Product A with order IDs 1 and 2.
- User 2 has ordered Product B with order ID 3.
- User 3 has ordered Product C with order ID 4.

> Finished chain.

Process finished with exit code 0
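The Final Answer is the LLM's own paraphrase of the Observation. One cheap guard against hallucination is to rerun the agent's SQL yourself and compare. The sketch below does that against an in-memory SQLite copy of the test data. SQLite is an assumption made so the example runs without MySQL, and email values are left NULL. The query string is copied from the sql_db_query step in the run above. Because it has no ORDER BY, the rows are sorted before grouping them per user.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE users (user_id INTEGER PRIMARY KEY, username TEXT, email TEXT);
CREATE TABLE products (product_id INTEGER PRIMARY KEY, product_name TEXT);
CREATE TABLE orders (order_id INTEGER PRIMARY KEY, user_id INTEGER,
                     product_id INTEGER, order_date DATE);
INSERT INTO users VALUES (1, 'user1', NULL), (2, 'user2', NULL), (3, 'user3', NULL);
INSERT INTO products VALUES (1, 'Product A'), (2, 'Product B'), (3, 'Product C');
INSERT INTO orders VALUES (1, 1, 1, '2023-01-15'), (2, 1, 1, '2023-02-10'),
                          (3, 2, 2, '2023-03-05'), (4, 3, 3, '2023-04-20');
""")

# The SQL the agent issued via sql_db_query, copied from the run output.
AGENT_SQL = (
    "SELECT users.username, orders.order_id, products.product_name FROM users "
    "JOIN orders ON users.user_id = orders.user_id "
    "JOIN products ON orders.product_id = products.product_id LIMIT 10"
)

# The agent's SQL has no ORDER BY, so sort by order_id before comparing.
rows = sorted(conn.execute(AGENT_SQL).fetchall(), key=lambda r: r[1])

# Group orders per user, mirroring the structure of the Final Answer.
by_user = {}
for username, order_id, product_name in rows:
    by_user.setdefault(username, []).append((order_id, product_name))

for username, orders in by_user.items():
    print(username, orders)
```

On this data the grouped result matches the agent's Final Answer. user1 has orders 1 and 2 for Product A, user2 has order 3 for Product B, and user3 has order 4 for Product C.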

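5. Summary

In our tests we used GPT-3.5 to run a number of questions through both SQLChain and the SQLAgent from this article, and compared the results. SQLAgent was able to answer more of the questions than SQLChain.

LangChain's SQLAgent is a powerful tool that can build complex chains of LLM calls to answer user questions. It returned answers for most of the relevant questions, but it cannot prevent LLM hallucinations.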


Reposted from blog.csdn.net/wangning0714/article/details/133900753