[LangChain] SQL

LangChain study notes


Overview

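This article shows how to use SQLDatabaseChain to answer questions over a SQL database.

Under the hood, LangChain uses SQLAlchemy to connect to SQL databases. SQLDatabaseChain can therefore be used with any SQL dialect supported by SQLAlchemy, such as MS SQL, MySQL, MariaDB, PostgreSQL, Oracle SQL, Databricks and SQLite. For more information about the requirements for connecting to your database, see the SQLAlchemy documentation.

For example, connecting to MySQL requires an appropriate connector such as PyMySQL.
A URI for a MySQL connection might look like: mysql+pymysql://user:pass@some_mysql_db_address/db_name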
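
If the password contains characters such as `@` or `/`, it has to be percent-encoded before it goes into the URI. The following is a minimal sketch using only the standard library. The helper name `build_mysql_uri`, the sample password, and the default port 3306 are assumptions made for illustration; they are not part of LangChain or SQLAlchemy.

```python
from urllib.parse import quote_plus


def build_mysql_uri(user: str, password: str, host: str, db_name: str, port: int = 3306) -> str:
    """Build a SQLAlchemy-style MySQL URI, percent-encoding user and password."""
    return (
        f"mysql+pymysql://{quote_plus(user)}:{quote_plus(password)}"
        f"@{host}:{port}/{db_name}"
    )


# Hypothetical credentials, used only to show the encoding.
uri = build_mysql_uri("user", "p@ss/word", "some_mysql_db_address", "db_name")
print(uri)
```

The resulting string can then be passed to `SQLDatabase.from_uri` in the same way as the SQLite URI used in the rest of this article.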

Contents

from langchain import OpenAI, SQLDatabase, SQLDatabaseChain
db = SQLDatabase.from_uri("sqlite:///../../../../notebooks/Chinook.db")
llm = OpenAI(temperature=0, verbose=True)

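Note: for data-sensitive projects, you can specify return_direct=True when initializing SQLDatabaseChain to return the output of the SQL query directly, without any additional formatting. This prevents the LLM from seeing any contents of the database. Note, however, that by default the LLM still has access to the database schema (i.e. dialect, table names and key names).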

db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
db_chain.run("How many employees are there?")

Result:

    
    
    > Entering new SQLDatabaseChain chain...
    How many employees are there?
    SQLQuery:

    /workspace/langchain/langchain/sql_database.py:191: SAWarning: Dialect sqlite+pysqlite does *not* support Decimal objects natively, and SQLAlchemy must convert from floating point - rounding errors and other issues may occur. Please consider storing Decimal numbers as strings or integers on this platform for lossless storage.
      sample_rows = connection.execute(command)


    SELECT COUNT(*) FROM "Employee";
    SQLResult: [(8,)]
    Answer:There are 8 employees.
    > Finished chain.

    'There are 8 employees.'

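Use Query Checker

Sometimes the language model generates invalid SQL with small mistakes. These can be self-corrected with the same technique the SQL Database Agent uses: asking the LLM to try to fix the SQL. You can simply specify this option when creating the chain: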

# use_query_checker=True
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, use_query_checker=True)
db_chain.run("How many albums by Aerosmith?")

Result:

    
    
    > Entering new SQLDatabaseChain chain...
    How many albums by Aerosmith?
    SQLQuery:SELECT COUNT(*) FROM Album WHERE ArtistId = 3;
    SQLResult: [(1,)]
    Answer:There is 1 album by Aerosmith.
    > Finished chain.

    'There is 1 album by Aerosmith.'
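
The check-then-repair idea behind the query checker can be sketched without an LLM. LangChain's own checker sends the query back to the LLM for review. The sketch below is not that implementation: it uses SQLite's parser (via `EXPLAIN`) to detect the error, and `toy_fixer` is a hypothetical stand-in for the LLM that repairs a single known typo. The table and rows are made up for the demo.

```python
import sqlite3
from typing import Callable, List, Tuple


def run_with_checker(
    conn: sqlite3.Connection,
    query: str,
    fix_query: Callable[[str, str], str],
    max_attempts: int = 2,
) -> List[Tuple]:
    """Compile the query first; on error, ask fix_query for a corrected version."""
    for _ in range(max_attempts):
        try:
            # EXPLAIN makes SQLite parse and plan the statement without running it.
            conn.execute(f"EXPLAIN {query}")
            return conn.execute(query).fetchall()
        except sqlite3.Error as exc:
            query = fix_query(query, str(exc))
    raise ValueError(f"Query still invalid after {max_attempts} attempts: {query}")


def toy_fixer(query: str, error: str) -> str:
    # Hypothetical stand-in for the LLM: repairs one known typo and ignores the error text.
    return query.replace("SELEC ", "SELECT ", 1)


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)")
conn.executemany("INSERT INTO Employee (LastName) VALUES (?)", [("Adams",), ("Edwards",)])

rows = run_with_checker(conn, "SELEC COUNT(*) FROM Employee;", toy_fixer)
print(rows)
```

In a real chain the fixer would be an LLM call that receives the failing query, which is why the checker costs one extra model round trip per question.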

Customize Prompt

You can also customize the prompt that is used. Here is an example prompt that makes the chain understand that foobar and the Employee table are the same:

from langchain.prompts.prompt import PromptTemplate
# Custom prompt
_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}

If someone asks for the table foobar, they really mean the employee table.

Question: {input}"""
PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "dialect"], template=_DEFAULT_TEMPLATE
)
db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT, verbose=True)
db_chain.run("How many employees are there in the foobar table?")

Result:

    > Entering new SQLDatabaseChain chain...
    How many employees are there in the foobar table?
    # The query actually runs against the Employee table
    SQLQuery:SELECT COUNT(*) FROM Employee;
    SQLResult: [(8,)]
    Answer:There are 8 employees in the foobar table.
    > Finished chain.

    'There are 8 employees in the foobar table.'
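
To see what the LLM actually receives, the template can be rendered by hand. This is a minimal sketch that assumes the default f-string-style template format, so plain `str.format` stands in for `PromptTemplate`. The template is an abbreviated copy of the one above, and the one-column `table_info` value is made up for the demo.

```python
# Abbreviated copy of the custom template; str.format stands in for PromptTemplate.
TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run.

Only use the following tables:

{table_info}

If someone asks for the table foobar, they really mean the employee table.

Question: {input}"""

rendered = TEMPLATE.format(
    dialect="sqlite",
    # Made-up, shortened schema; the real chain fills in db.table_info here.
    table_info='CREATE TABLE "Employee" ("EmployeeId" INTEGER NOT NULL)',
    input="How many employees are there in the foobar table?",
)
print(rendered)
```

The extra sentence about foobar is ordinary prompt text, which is why the generated SQL in the result above targets Employee.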

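Return Intermediate Steps

We can also return the intermediate steps of the SQLDatabaseChain. This gives you access to the SQL statement that was generated, as well as the result of running it against the SQL database.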

db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT, verbose=True, use_query_checker=True, return_intermediate_steps=True)

result = db_chain("How many employees are there in the foobar table?")
result["intermediate_steps"]

Result:

    > Entering new SQLDatabaseChain chain...
    How many employees are there in the foobar table?
    SQLQuery:SELECT COUNT(*) FROM Employee;
    SQLResult: [(8,)]
    Answer:There are 8 employees in the foobar table.
    > Finished chain.

    [{'input': 'How many employees are there in the foobar table?\nSQLQuery:SELECT COUNT(*) FROM Employee;\nSQLResult: [(8,)]\nAnswer:',
      'top_k': '5',
      'dialect': 'sqlite',
      'table_info': '\nCREATE TABLE "Artist" (\n\t"ArtistId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(120), \n\tPRIMARY KEY ("ArtistId")\n)\n\n/*\n3 rows from Artist table:\nArtistId\tName\n1\tAC/DC\n2\tAccept\n3\tAerosmith\n*/\n\n\nCREATE TABLE "Employee" (\n\t"EmployeeId" INTEGER NOT NULL, \n\t"LastName" NVARCHAR(20) NOT NULL, \n\t"FirstName" NVARCHAR(20) NOT NULL, \n\t"Title" NVARCHAR(30), \n\t"ReportsTo" INTEGER, \n\t"BirthDate" DATETIME, \n\t"HireDate" DATETIME, \n\t"Address" NVARCHAR(70), \n\t"City" NVARCHAR(40), \n\t"State" NVARCHAR(40), \n\t"Country" NVARCHAR(40), \n\t"PostalCode" NVARCHAR(10), \n\t"Phone" NVARCHAR(24), \n\t"Fax" NVARCHAR(24), \n\t"Email" NVARCHAR(60), \n\tPRIMARY KEY ("EmployeeId"), \n\tFOREIGN KEY("ReportsTo") REFERENCES "Employee" ("EmployeeId")\n)\n\n/*\n3 rows from Employee table:\nEmployeeId\tLastName\tFirstName\tTitle\tReportsTo\tBirthDate\tHireDate\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\n1\tAdams\tAndrew\tGeneral Manager\tNone\t1962-02-18 00:00:00\t2002-08-14 00:00:00\t11120 Jasper Ave NW\tEdmonton\tAB\tCanada\tT5K 2N1\t+1 (780) 428-9482\t+1 (780) 428-3457\[email protected]\n2\tEdwards\tNancy\tSales Manager\t1\t1958-12-08 00:00:00\t2002-05-01 00:00:00\t825 8 Ave SW\tCalgary\tAB\tCanada\tT2P 2T3\t+1 (403) 262-3443\t+1 (403) 262-3322\[email protected]\n3\tPeacock\tJane\tSales Support Agent\t2\t1973-08-29 00:00:00\t2002-04-01 00:00:00\t1111 6 Ave SW\tCalgary\tAB\tCanada\tT2P 5M5\t+1 (403) 262-3443\t+1 (403) 262-6712\[email protected]\n*/\n\n\nCREATE TABLE "Genre" (\n\t"GenreId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(120), \n\tPRIMARY KEY ("GenreId")\n)\n\n/*\n3 rows from Genre table:\nGenreId\tName\n1\tRock\n2\tJazz\n3\tMetal\n*/\n\n\nCREATE TABLE "MediaType" (\n\t"MediaTypeId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(120), \n\tPRIMARY KEY ("MediaTypeId")\n)\n\n/*\n3 rows from MediaType table:\nMediaTypeId\tName\n1\tMPEG audio file\n2\tProtected AAC audio file\n3\tProtected MPEG-4 video file\n*/\n\n\nCREATE TABLE 
"Playlist" (\n\t"PlaylistId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(120), \n\tPRIMARY KEY ("PlaylistId")\n)\n\n/*\n3 rows from Playlist table:\nPlaylistId\tName\n1\tMusic\n2\tMovies\n3\tTV Shows\n*/\n\n\nCREATE TABLE "Album" (\n\t"AlbumId" INTEGER NOT NULL, \n\t"Title" NVARCHAR(160) NOT NULL, \n\t"ArtistId" INTEGER NOT NULL, \n\tPRIMARY KEY ("AlbumId"), \n\tFOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")\n)\n\n/*\n3 rows from Album table:\nAlbumId\tTitle\tArtistId\n1\tFor Those About To Rock We Salute You\t1\n2\tBalls to the Wall\t2\n3\tRestless and Wild\t2\n*/\n\n\nCREATE TABLE "Customer" (\n\t"CustomerId" INTEGER NOT NULL, \n\t"FirstName" NVARCHAR(40) NOT NULL, \n\t"LastName" NVARCHAR(20) NOT NULL, \n\t"Company" NVARCHAR(80), \n\t"Address" NVARCHAR(70), \n\t"City" NVARCHAR(40), \n\t"State" NVARCHAR(40), \n\t"Country" NVARCHAR(40), \n\t"PostalCode" NVARCHAR(10), \n\t"Phone" NVARCHAR(24), \n\t"Fax" NVARCHAR(24), \n\t"Email" NVARCHAR(60) NOT NULL, \n\t"SupportRepId" INTEGER, \n\tPRIMARY KEY ("CustomerId"), \n\tFOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")\n)\n\n/*\n3 rows from Customer table:\nCustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. 
Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\[email protected]\t3\n2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\[email protected]\t5\n3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\[email protected]\t3\n*/\n\n\nCREATE TABLE "Invoice" (\n\t"InvoiceId" INTEGER NOT NULL, \n\t"CustomerId" INTEGER NOT NULL, \n\t"InvoiceDate" DATETIME NOT NULL, \n\t"BillingAddress" NVARCHAR(70), \n\t"BillingCity" NVARCHAR(40), \n\t"BillingState" NVARCHAR(40), \n\t"BillingCountry" NVARCHAR(40), \n\t"BillingPostalCode" NVARCHAR(10), \n\t"Total" NUMERIC(10, 2) NOT NULL, \n\tPRIMARY KEY ("InvoiceId"), \n\tFOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")\n)\n\n/*\n3 rows from Invoice table:\nInvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n1\t2\t2009-01-01 00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n*/\n\n\nCREATE TABLE "Track" (\n\t"TrackId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(200) NOT NULL, \n\t"AlbumId" INTEGER, \n\t"MediaTypeId" INTEGER NOT NULL, \n\t"GenreId" INTEGER, \n\t"Composer" NVARCHAR(220), \n\t"Milliseconds" INTEGER NOT NULL, \n\t"Bytes" INTEGER, \n\t"UnitPrice" NUMERIC(10, 2) NOT NULL, \n\tPRIMARY KEY ("TrackId"), \n\tFOREIGN KEY("MediaTypeId") REFERENCES "MediaType" ("MediaTypeId"), \n\tFOREIGN KEY("GenreId") REFERENCES "Genre" ("GenreId"), \n\tFOREIGN KEY("AlbumId") REFERENCES "Album" ("AlbumId")\n)\n\n/*\n3 rows from Track table:\nTrackId\tName\tAlbumId\tMediaTypeId\tGenreId\tComposer\tMilliseconds\tBytes\tUnitPrice\n1\tFor Those About To Rock (We Salute You)\t1\t1\t1\tAngus Young, Malcolm Young, Brian 
Johnson\t343719\t11170334\t0.99\n2\tBalls to the Wall\t2\t2\t1\tNone\t342562\t5510424\t0.99\n3\tFast As a Shark\t3\t2\t1\tF. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman\t230619\t3990994\t0.99\n*/\n\n\nCREATE TABLE "InvoiceLine" (\n\t"InvoiceLineId" INTEGER NOT NULL, \n\t"InvoiceId" INTEGER NOT NULL, \n\t"TrackId" INTEGER NOT NULL, \n\t"UnitPrice" NUMERIC(10, 2) NOT NULL, \n\t"Quantity" INTEGER NOT NULL, \n\tPRIMARY KEY ("InvoiceLineId"), \n\tFOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), \n\tFOREIGN KEY("InvoiceId") REFERENCES "Invoice" ("InvoiceId")\n)\n\n/*\n3 rows from InvoiceLine table:\nInvoiceLineId\tInvoiceId\tTrackId\tUnitPrice\tQuantity\n1\t1\t2\t0.99\t1\n2\t1\t4\t0.99\t1\n3\t2\t6\t0.99\t1\n*/\n\n\nCREATE TABLE "PlaylistTrack" (\n\t"PlaylistId" INTEGER NOT NULL, \n\t"TrackId" INTEGER NOT NULL, \n\tPRIMARY KEY ("PlaylistId", "TrackId"), \n\tFOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), \n\tFOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")\n)\n\n/*\n3 rows from PlaylistTrack table:\nPlaylistId\tTrackId\n1\t3402\n1\t3389\n1\t3390\n*/',
      'stop': ['\nSQLResult:']},
     'SELECT COUNT(*) FROM Employee;',
     {'query': 'SELECT COUNT(*) FROM Employee;', 'dialect': 'sqlite'},
     'SELECT COUNT(*) FROM Employee;',
     '[(8,)]']
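
The list above mixes input dicts and output strings. The following is a minimal sketch of pulling the executed SQL and its raw result back out, assuming the shape shown in this output (it has no documented schema and may differ between LangChain versions). The helper name `summarize_steps` is hypothetical, and `steps` is an abbreviated hand-written copy of the output rather than a live chain result.

```python
from typing import Any, Dict, List


def summarize_steps(steps: List[Any]) -> Dict[str, str]:
    """Pick the executed SQL and its raw result out of an intermediate_steps list."""
    strings = [step for step in steps if isinstance(step, str)]
    if len(strings) < 2:
        raise ValueError("expected at least a SQL string and a result string")
    # Assumption: the last string is the SQL result and the one before it is the executed SQL.
    return {"sql_cmd": strings[-2], "sql_result": strings[-1]}


# Abbreviated copy of the output above (the long table_info entry is left out).
steps = [
    {"input": "How many employees are there in the foobar table?\nSQLQuery:", "dialect": "sqlite"},
    "SELECT COUNT(*) FROM Employee;",
    {"query": "SELECT COUNT(*) FROM Employee;", "dialect": "sqlite"},
    "SELECT COUNT(*) FROM Employee;",
    "[(8,)]",
]
summary = summarize_steps(steps)
print(summary)
```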

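Choosing how to limit the number of rows returned

If you are querying for several rows of a table, you can use the top_k parameter to set the maximum number of results to fetch (default 5, as the chain inputs printed elsewhere in this article show). This is useful for keeping query results from exceeding the maximum prompt length or consuming tokens unnecessarily.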

db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, use_query_checker=True, top_k=3)
db_chain.run("What are some example tracks by composer Johann Sebastian Bach?")

Result:

    
    
    > Entering new SQLDatabaseChain chain...
    What are some example tracks by composer Johann Sebastian Bach?
    SQLQuery:SELECT Name FROM Track WHERE Composer = 'Johann Sebastian Bach' LIMIT 3
    # Note that at most three rows are returned here
    SQLResult: [('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace',), ('Aria Mit 30 Veränderungen, BWV 988 "Goldberg Variations": Aria',), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude',)]
    Answer:Examples of tracks by Johann Sebastian Bach are Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace, Aria Mit 30 Veränderungen, BWV 988 "Goldberg Variations": Aria, and Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude.
    > Finished chain.
    
    'Examples of tracks by Johann Sebastian Bach are Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace, Aria Mit 30 Veränderungen, BWV 988 "Goldberg Variations": Aria, and Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude.'
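
The effect of the LIMIT clause on prompt size can be checked directly. This is a minimal sketch against an in-memory SQLite table filled with made-up rows (not the Chinook data). It compares the size of the stringified result, which is what ends up after "SQLResult:" in the prompt, with and without the limit.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Track (TrackId INTEGER PRIMARY KEY, Name TEXT, Composer TEXT)")
# Made-up rows, only to give the query something to return.
conn.executemany(
    "INSERT INTO Track (Name, Composer) VALUES (?, ?)",
    [(f"Track {i}", "Johann Sebastian Bach") for i in range(100)],
)

top_k = 3
base_query = "SELECT Name FROM Track WHERE Composer = 'Johann Sebastian Bach'"
all_rows = conn.execute(base_query).fetchall()
limited_rows = conn.execute(f"{base_query} LIMIT {top_k}").fetchall()

# The stringified rows are what would be appended to the prompt.
print(len(str(all_rows)), len(str(limited_rows)))
```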

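Adding example rows from each table

Sometimes the format of the data is not obvious, and it is best to include sample rows from the tables in the prompt so that the LLM understands the data before producing the final query. Here we use this feature to let the LLM know that artists are stored with their full names, by providing two rows from the Track table.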

db = SQLDatabase.from_uri(
    "sqlite:///../../../../notebooks/Chinook.db",
    include_tables=['Track'], # we include only one table to save tokens in the prompt :)
    sample_rows_in_table_info=2)

The sample rows are added to the prompt after each corresponding table's column information:

print(db.table_info)

    
CREATE TABLE "Track" (
    "TrackId" INTEGER NOT NULL, 
    "Name" NVARCHAR(200) NOT NULL, 
    "AlbumId" INTEGER, 
    "MediaTypeId" INTEGER NOT NULL, 
    "GenreId" INTEGER, 
    "Composer" NVARCHAR(220), 
    "Milliseconds" INTEGER NOT NULL, 
    "Bytes" INTEGER, 
    "UnitPrice" NUMERIC(10, 2) NOT NULL, 
    PRIMARY KEY ("TrackId"), 
    FOREIGN KEY("MediaTypeId") REFERENCES "MediaType" ("MediaTypeId"), 
    FOREIGN KEY("GenreId") REFERENCES "Genre" ("GenreId"), 
    FOREIGN KEY("AlbumId") REFERENCES "Album" ("AlbumId")
)
# These are the sample rows
/*
2 rows from Track table:
TrackId Name    AlbumId MediaTypeId GenreId Composer    Milliseconds    Bytes   UnitPrice
1   For Those About To Rock (We Salute You) 1   1   1   Angus Young, Malcolm Young, Brian Johnson   343719  11170334    0.99
2   Balls to the Wall   2   2   1   None    342562  5510424 0.99
*/
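
The layout above (the CREATE TABLE statement followed by a comment block of sample rows) can be approximated with the standard library. This is a sketch of the idea, not LangChain's code: the helper name `table_info` and the three-column demo table are assumptions made for illustration.

```python
import sqlite3


def table_info(conn: sqlite3.Connection, table: str, sample_rows: int = 2) -> str:
    """Return the CREATE TABLE statement followed by a comment block of sample rows."""
    (create_sql,) = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()
    # A table name cannot be bound as a parameter, so it is quoted as an identifier.
    quoted = '"' + table.replace('"', '""') + '"'
    cursor = conn.execute(f"SELECT * FROM {quoted} LIMIT ?", (sample_rows,))
    header = "\t".join(col[0] for col in cursor.description)
    lines = ["\t".join(str(value) for value in row) for row in cursor.fetchall()]
    sample = "\n".join([f"{sample_rows} rows from {table} table:", header, *lines])
    return f"{create_sql}\n\n/*\n{sample}\n*/"


conn = sqlite3.connect(":memory:")
conn.execute(
    'CREATE TABLE "Track" ("TrackId" INTEGER NOT NULL, "Name" TEXT NOT NULL, '
    '"Composer" TEXT, PRIMARY KEY ("TrackId"))'
)
conn.executemany(
    'INSERT INTO "Track" ("TrackId", "Name", "Composer") VALUES (?, ?, ?)',
    [
        (1, "For Those About To Rock (We Salute You)", "Angus Young, Malcolm Young, Brian Johnson"),
        (2, "Balls to the Wall", None),
        (3, "Fast As a Shark", "F. Baltes"),
    ],
)
info = table_info(conn, "Track", sample_rows=2)
print(info)
```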

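Custom Table Info

In some cases it can be useful to provide custom table information instead of the automatically generated table definitions and the first sample_rows_in_table_info sample rows described above.

For example, if you know that the first few rows of a table are uninformative, it may help to manually provide more diverse sample rows or to give the model more information. If there are unnecessary columns, you can also limit the columns that are visible to the model.

This information is provided as a dictionary, with table names as the keys and table information as the values. For example, let's provide a custom definition and sample rows for the Track table with only a few columns: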

custom_table_info = {
    "Track": """CREATE TABLE Track (
    "TrackId" INTEGER NOT NULL, 
    "Name" NVARCHAR(200) NOT NULL,
    "Composer" NVARCHAR(220),
    PRIMARY KEY ("TrackId")
)
/*
3 rows from Track table:
TrackId Name    Composer
1   For Those About To Rock (We Salute You) Angus Young, Malcolm Young, Brian Johnson
2   Balls to the Wall   None
3   My favorite song ever   The coolest composer of all time
*/"""
}

The code that uses the custom table info is as follows:

db = SQLDatabase.from_uri(
    "sqlite:///../../../../notebooks/Chinook.db",
    include_tables=['Track', 'Playlist'],
    sample_rows_in_table_info=2,
    custom_table_info=custom_table_info)

print(db.table_info)

Result:

    
    CREATE TABLE "Playlist" (
        "PlaylistId" INTEGER NOT NULL, 
        "Name" NVARCHAR(120), 
        PRIMARY KEY ("PlaylistId")
    )
    
    /*
    2 rows from Playlist table:
    PlaylistId  Name
    1   Music
    2   Movies
    */
    
    CREATE TABLE Track (
        "TrackId" INTEGER NOT NULL, 
        "Name" NVARCHAR(200) NOT NULL,
        "Composer" NVARCHAR(220),
        PRIMARY KEY ("TrackId")
    )
    /*
    3 rows from Track table:
    TrackId Name    Composer
    1   For Those About To Rock (We Salute You) Angus Young, Malcolm Young, Brian Johnson
    2   Balls to the Wall   None
    3   My favorite song ever   The coolest composer of all time
    */

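Note how our custom table definition and sample rows for Track override the sample_rows_in_table_info parameter.

Tables that are not overridden by custom_table_info, Playlist in this example, have their table info gathered automatically as usual.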
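
The override rule described above amounts to a per-table lookup with a fallback. This is a minimal sketch of that rule only: the helper name `resolve_table_info` and the shortened schema strings are hypothetical, and LangChain's real implementation may differ in its details.

```python
from typing import Dict, List


def resolve_table_info(
    include_tables: List[str],
    auto_info: Dict[str, str],
    custom_table_info: Dict[str, str],
) -> Dict[str, str]:
    """For each table, prefer the hand-written entry and fall back to the generated one."""
    resolved = {}
    for name in include_tables:
        if name in custom_table_info:
            resolved[name] = custom_table_info[name]
        else:
            resolved[name] = auto_info[name]
    return resolved


# Shortened, made-up stand-ins for the generated and hand-written table info.
auto_info = {
    "Track": 'CREATE TABLE "Track" (... all nine columns ...)',
    "Playlist": 'CREATE TABLE "Playlist" ("PlaylistId" INTEGER NOT NULL, "Name" NVARCHAR(120))',
}
custom_table_info = {
    "Track": 'CREATE TABLE Track ("TrackId" INTEGER NOT NULL, "Name" NVARCHAR(200) NOT NULL)'
}

resolved = resolve_table_info(["Track", "Playlist"], auto_info, custom_table_info)
print(resolved["Track"])
```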

db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
db_chain.run("What are some example tracks by Bach?")

Result:

    
    
    > Entering new SQLDatabaseChain chain...
    What are some example tracks by Bach?
    SQLQuery:SELECT "Name" FROM Track WHERE "Composer" LIKE '%Bach%' LIMIT 5;
    SQLResult: [('American Woman',), ('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace',), ('Aria Mit 30 Veränderungen, BWV 988 "Goldberg Variations": Aria',), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude',), ('Toccata and Fugue in D Minor, BWV 565: I. Toccata',)]
    Answer:text='You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\n\nUse the following format:\n\nQuestion: "Question here"\nSQLQuery: "SQL Query to run"\nSQLResult: "Result of the SQLQuery"\nAnswer: "Final answer here"\n\nOnly use the following tables:\n\nCREATE TABLE "Playlist" (\n\t"PlaylistId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(120), \n\tPRIMARY KEY ("PlaylistId")\n)\n\n/*\n2 rows from Playlist table:\nPlaylistId\tName\n1\tMusic\n2\tMovies\n*/\n\nCREATE TABLE Track (\n\t"TrackId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(200) NOT NULL,\n\t"Composer" NVARCHAR(220),\n\tPRIMARY KEY ("TrackId")\n)\n/*\n3 rows from Track table:\nTrackId\tName\tComposer\n1\tFor Those About To Rock (We Salute You)\tAngus Young, Malcolm Young, Brian Johnson\n2\tBalls to the Wall\tNone\n3\tMy favorite song ever\tThe coolest composer of all time\n*/\n\nQuestion: What are some example tracks by Bach?\nSQLQuery:SELECT "Name" FROM Track WHERE "Composer" LIKE \'%Bach%\' LIMIT 5;\nSQLResult: [(\'American Woman\',), (\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\',), (\'Aria Mit 30 Veränderungen, BWV 988 "Goldberg Variations": Aria\',), (\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. 
Prélude\',), (\'Toccata and Fugue in D Minor, BWV 565: I. Toccata\',)]\nAnswer:'
    You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
    Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
    Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
    Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
    
    Use the following format:
    
    Question: "Question here"
    SQLQuery: "SQL Query to run"
    SQLResult: "Result of the SQLQuery"
    Answer: "Final answer here"
    
    Only use the following tables:
    
    CREATE TABLE "Playlist" (
        "PlaylistId" INTEGER NOT NULL, 
        "Name" NVARCHAR(120), 
        PRIMARY KEY ("PlaylistId")
    )
    
    /*
    2 rows from Playlist table:
    PlaylistId  Name
    1   Music
    2   Movies
    */
    
    CREATE TABLE Track (
        "TrackId" INTEGER NOT NULL, 
        "Name" NVARCHAR(200) NOT NULL,
        "Composer" NVARCHAR(220),
        PRIMARY KEY ("TrackId")
    )
    /*
    3 rows from Track table:
    TrackId Name    Composer
    1   For Those About To Rock (We Salute You) Angus Young, Malcolm Young, Brian Johnson
    2   Balls to the Wall   None
    3   My favorite song ever   The coolest composer of all time
    */
    
    Question: What are some example tracks by Bach?
    SQLQuery:SELECT "Name" FROM Track WHERE "Composer" LIKE '%Bach%' LIMIT 5;
    SQLResult: [('American Woman',), ('Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace',), ('Aria Mit 30 Veränderungen, BWV 988 "Goldberg Variations": Aria',), ('Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude',), ('Toccata and Fugue in D Minor, BWV 565: I. Toccata',)]
    Answer:
    {
    
    'input': 'What are some example tracks by Bach?\nSQLQuery:SELECT "Name" FROM Track WHERE "Composer" LIKE \'%Bach%\' LIMIT 5;\nSQLResult: [(\'American Woman\',), (\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\',), (\'Aria Mit 30 Veränderungen, BWV 988 "Goldberg Variations": Aria\',), (\'Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude\',), (\'Toccata and Fugue in D Minor, BWV 565: I. Toccata\',)]\nAnswer:', 'top_k': '5', 'dialect': 'sqlite', 'table_info': '\nCREATE TABLE "Playlist" (\n\t"PlaylistId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(120), \n\tPRIMARY KEY ("PlaylistId")\n)\n\n/*\n2 rows from Playlist table:\nPlaylistId\tName\n1\tMusic\n2\tMovies\n*/\n\nCREATE TABLE Track (\n\t"TrackId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(200) NOT NULL,\n\t"Composer" NVARCHAR(220),\n\tPRIMARY KEY ("TrackId")\n)\n/*\n3 rows from Track table:\nTrackId\tName\tComposer\n1\tFor Those About To Rock (We Salute You)\tAngus Young, Malcolm Young, Brian Johnson\n2\tBalls to the Wall\tNone\n3\tMy favorite song ever\tThe coolest composer of all time\n*/', 'stop': ['\nSQLResult:']}
    Examples of tracks by Bach include "American Woman", "Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace", "Aria Mit 30 Veränderungen, BWV 988 'Goldberg Variations': Aria", "Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude", and "Toccata and Fugue in D Minor, BWV 565: I. Toccata".
    > Finished chain.

    'Examples of tracks by Bach include "American Woman", "Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace", "Aria Mit 30 Veränderungen, BWV 988 \'Goldberg Variations\': Aria", "Suite for Solo Cello No. 1 in G Major, BWV 1007: I. Prélude", and "Toccata and Fugue in D Minor, BWV 565: I. Toccata".'

SQLDatabaseSequentialChain

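A chain for querying a SQL database that is itself a sequential chain.
It works in two steps:

  1. Based on the query, determine which tables to use.
  2. Based on those tables, call the normal SQL database chain.

This is useful when the database contains a large number of tables.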
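
The two steps above can be sketched as follows. In the real chain an LLM picks the tables. Here `choose_tables` is a hypothetical keyword matcher standing in for that call, and the one-column schemas are made up for the demo. The point is only that step 2 sees the schemas of the selected tables and nothing else.

```python
import re
from typing import Dict, List


def choose_tables(question: str, table_names: List[str]) -> List[str]:
    """Hypothetical stand-in for the LLM step: keep tables whose name occurs in the question."""
    words = set(re.findall(r"[a-z]+", question.lower()))
    # Naive singular/plural match, e.g. "employees" selects "Employee".
    return [t for t in table_names if t.lower() in words or t.lower() + "s" in words]


def build_table_info(selected: List[str], schemas: Dict[str, str]) -> str:
    """Step 2 input: only the schemas of the selected tables go into the prompt."""
    return "\n\n".join(schemas[name] for name in selected)


# Made-up, shortened schemas for three of the Chinook tables.
schemas = {
    "Employee": 'CREATE TABLE "Employee" ("EmployeeId" INTEGER NOT NULL)',
    "Customer": 'CREATE TABLE "Customer" ("CustomerId" INTEGER NOT NULL)',
    "Track": 'CREATE TABLE "Track" ("TrackId" INTEGER NOT NULL)',
}
selected = choose_tables("How many employees are also customers?", list(schemas))
prompt_tables = build_table_info(selected, schemas)
print(selected)
```

With many tables, dropping the unselected schemas is what keeps the second prompt short.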

from langchain.chains import SQLDatabaseSequentialChain
db = SQLDatabase.from_uri("sqlite:///../../../../notebooks/Chinook.db")

chain = SQLDatabaseSequentialChain.from_llm(llm, db, verbose=True)

chain.run("How many employees are also customers?")

Result:

    
    
    > Entering new SQLDatabaseSequentialChain chain...
    # Names of the tables to use
    Table names to use:
    ['Employee', 'Customer']
    
    > Entering new SQLDatabaseChain chain...
    # Question: how many employees are also customers?
    How many employees are also customers?
    # The SQL that is executed
    SQLQuery:SELECT COUNT(*) FROM Employee e INNER JOIN Customer c ON e.EmployeeId = c.SupportRepId;
    # The result of the execution
    SQLResult: [(59,)]
    # The wording of the answer
    Answer:59 employees are also customers.
    > Finished chain.
    
    > Finished chain.
    
    '59 employees are also customers.'

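Using Local Language Models

Sometimes you may not have the option of using OpenAI or another service-hosted large language model.
You can, of course, try to use SQLDatabaseChain with a local model, but you will quickly find that most models you can run locally, even with a large GPU, struggle to generate correct output.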

import logging
import torch
from transformers import AutoTokenizer, GPT2TokenizerFast, pipeline, AutoModelForSeq2SeqLM, AutoModelForCausalLM
from langchain import HuggingFacePipeline

# Note: This model requires a large GPU, e.g. an 80GB A100. See documentation for other ways to run private non-OpenAI models.
model_id = "google/flan-ul2"
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, temperature=0)
device_id = -1  # default to no-GPU, but use GPU and half precision mode if available
if torch.cuda.is_available():
    device_id = 0
    try:
        model = model.half()
    except RuntimeError as exc:
        # logging.warn is a deprecated alias; logging.warning is the supported call
        logging.warning(f"Could not run model in half precision mode: {exc}")

tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline(task="text2text-generation", model=model, tokenizer=tokenizer, max_length=1024, device=device_id)

local_llm = HuggingFacePipeline(pipeline=pipe)

Note (output while loading the model):

    /workspace/langchain/.venv/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
      from .autonotebook import tqdm as notebook_tqdm
    Loading checkpoint shards: 100%|██████████| 8/8 [00:32<00:00,  4.11s/it]

from langchain import SQLDatabase, SQLDatabaseChain

db = SQLDatabase.from_uri("sqlite:///../../../../notebooks/Chinook.db", include_tables=['Customer'])
local_chain = SQLDatabaseChain.from_llm(local_llm, db, verbose=True, return_intermediate_steps=True, use_query_checker=True)

This model should work for very simple SQL queries, as long as you use the query checker as specified above, e.g.:

# Ask how many customers there are
local_chain("How many customers are there?")
    
    
    > Entering new SQLDatabaseChain chain...
    How many customers are there?
    SQLQuery:

    /workspace/langchain/.venv/lib/python3.9/site-packages/transformers/pipelines/base.py:1070: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
      warnings.warn(
    /workspace/langchain/.venv/lib/python3.9/site-packages/transformers/pipelines/base.py:1070: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
      warnings.warn(


    SELECT count(*) FROM Customer
    SQLResult: [(59,)]
    Answer:

    /workspace/langchain/.venv/lib/python3.9/site-packages/transformers/pipelines/base.py:1070: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
      warnings.warn(


    [59]
    > Finished chain.

    {'query': 'How many customers are there?',
     'result': '[59]',
     'intermediate_steps': [{'input': 'How many customers are there?\nSQLQuery:SELECT count(*) FROM Customer\nSQLResult: [(59,)]\nAnswer:',
       'top_k': '5',
       'dialect': 'sqlite',
       'table_info': '\nCREATE TABLE "Customer" (\n\t"CustomerId" INTEGER NOT NULL, \n\t"FirstName" NVARCHAR(40) NOT NULL, \n\t"LastName" NVARCHAR(20) NOT NULL, \n\t"Company" NVARCHAR(80), \n\t"Address" NVARCHAR(70), \n\t"City" NVARCHAR(40), \n\t"State" NVARCHAR(40), \n\t"Country" NVARCHAR(40), \n\t"PostalCode" NVARCHAR(10), \n\t"Phone" NVARCHAR(24), \n\t"Fax" NVARCHAR(24), \n\t"Email" NVARCHAR(60) NOT NULL, \n\t"SupportRepId" INTEGER, \n\tPRIMARY KEY ("CustomerId"), \n\tFOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")\n)\n\n/*\n3 rows from Customer table:\nCustomerId\tFirstName\tLastName\tCompany\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\tSupportRepId\n1\tLuís\tGonçalves\tEmbraer - Empresa Brasileira de Aeronáutica S.A.\tAv. Brigadeiro Faria Lima, 2170\tSão José dos Campos\tSP\tBrazil\t12227-000\t+55 (12) 3923-5555\t+55 (12) 3923-5566\[email protected]\t3\n2\tLeonie\tKöhler\tNone\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t+49 0711 2842222\tNone\[email protected]\t5\n3\tFrançois\tTremblay\tNone\t1498 rue Bélanger\tMontréal\tQC\tCanada\tH2G 1A7\t+1 (514) 721-4711\tNone\[email protected]\t3\n*/',
       'stop': ['\nSQLResult:']},
      'SELECT count(*) FROM Customer',
      {'query': 'SELECT count(*) FROM Customer', 'dialect': 'sqlite'},
      'SELECT count(*) FROM Customer',
      '[(59,)]']}

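Even this relatively large model will most likely fail to generate more complicated SQL by itself. However, you can log its inputs and outputs so that you can hand-correct them and later use the corrected examples for few-shot prompting. In practice, you could log any chain executions that raise exceptions (as in the example below), or collect direct user feedback in cases where the results are incorrect but no exception was raised.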

poetry run pip install pyyaml chromadb

import yaml
from typing import Dict
# Query: list all customer first names that start with 'a'
QUERY = "List all the customer first names that start with 'a'"

def _parse_example(result: Dict) -> Dict:
    sql_cmd_key = "sql_cmd"
    sql_result_key = "sql_result"
    table_info_key = "table_info"
    input_key = "input"
    final_answer_key = "answer"

    _example = {
        "input": result.get("query"),
    }

    steps = result.get("intermediate_steps")
    answer_key = sql_cmd_key # the first one
    for step in steps:
        # The steps are in pairs, a dict (input) followed by a string (output).
        # Unfortunately there is no schema but you can look at the input key of the
        # dict to see what the output is supposed to be
        if isinstance(step, dict):
            # Grab the table info from input dicts in the intermediate steps once
            if table_info_key not in _example:
                _example[table_info_key] = step.get(table_info_key)

            if input_key in step:
                if step[input_key].endswith("SQLQuery:"):
                    answer_key = sql_cmd_key # this is the SQL generation input
                if step[input_key].endswith("Answer:"):
                    answer_key = final_answer_key # this is the final answer input
            elif sql_cmd_key in step:
                _example[sql_cmd_key] = step[sql_cmd_key]
                answer_key = sql_result_key # this is SQL execution input
        elif isinstance(step, str):
            # The preceding element should have set the answer_key
            _example[answer_key] = step
    return _example

example: Dict
try:
    result = local_chain(QUERY)
    print("*** Query succeeded")
    example = _parse_example(result)
except Exception as exc:
    print("*** Query failed")
    result = {
        "query": QUERY,
        # Not every exception carries intermediate_steps, so fall back to an empty list
        "intermediate_steps": getattr(exc, "intermediate_steps", []),
    }
    example = _parse_example(result)


# print for now, in reality you may want to write this out to a YAML file or database for manual fix-ups offline
yaml_example = yaml.dump(example, allow_unicode=True)
print("\n" + yaml_example)

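Run the snippet above a few times, or log exceptions in your deployed environment, to collect a good number of input, table_info and sql_cmd examples generated by your language model. The sql_cmd values will sometimes be incorrect, and you can fix them by hand to build up a collection of examples. Here, for instance, we use YAML to keep a neat record of our inputs together with the corrected SQL outputs, which we can build up over time.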

YAML_EXAMPLES = """
- input: How many customers are not from Brazil?
  table_info: |
    CREATE TABLE "Customer" (
      "CustomerId" INTEGER NOT NULL, 
      "FirstName" NVARCHAR(40) NOT NULL, 
      "LastName" NVARCHAR(20) NOT NULL, 
      "Company" NVARCHAR(80), 
      "Address" NVARCHAR(70), 
      "City" NVARCHAR(40), 
      "State" NVARCHAR(40), 
      "Country" NVARCHAR(40), 
      "PostalCode" NVARCHAR(10), 
      "Phone" NVARCHAR(24), 
      "Fax" NVARCHAR(24), 
      "Email" NVARCHAR(60) NOT NULL, 
      "SupportRepId" INTEGER, 
      PRIMARY KEY ("CustomerId"), 
      FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
    )
  sql_cmd: SELECT COUNT(*) FROM "Customer" WHERE NOT "Country" = "Brazil";
  sql_result: "[(54,)]"
  answer: 54 customers are not from Brazil.
- input: list all the genres that start with 'r'
  table_info: |
    CREATE TABLE "Genre" (
      "GenreId" INTEGER NOT NULL, 
      "Name" NVARCHAR(120), 
      PRIMARY KEY ("GenreId")
    )

    /*
    3 rows from Genre table:
    GenreId Name
    1   Rock
    2   Jazz
    3   Metal
    */
  sql_cmd: SELECT "Name" FROM "Genre" WHERE "Name" LIKE 'r%';
  sql_result: "[('Rock',), ('Rock and Roll',), ('Reggae',), ('R&B/Soul',)]"
  answer: The genres that start with 'r' are Rock, Rock and Roll, Reggae and R&B/Soul. 
"""

Now that you have some examples (with manually corrected output SQL), you can build the few-shot prompt in the usual way with FewShotPromptTemplate:

from langchain import FewShotPromptTemplate, PromptTemplate
from langchain.chains.sql_database.prompt import _sqlite_prompt, PROMPT_SUFFIX
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.prompts.example_selector.semantic_similarity import SemanticSimilarityExampleSelector
from langchain.vectorstores import Chroma

example_prompt = PromptTemplate(
    input_variables=["table_info", "input", "sql_cmd", "sql_result", "answer"],
    template="{table_info}\n\nQuestion: {input}\nSQLQuery: {sql_cmd}\nSQLResult: {sql_result}\nAnswer: {answer}",
)

examples_dict = yaml.safe_load(YAML_EXAMPLES)

local_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

example_selector = SemanticSimilarityExampleSelector.from_examples(
                        # This is the list of examples available to select from.
                        examples_dict,
                        # This is the embedding class used to produce embeddings which are used to measure semantic similarity.
                        local_embeddings,
                        # This is the VectorStore class that is used to store the embeddings and do a similarity search over.
                        Chroma,  # type: ignore
                        # This is the number of examples to produce and include per prompt
                        k=min(3, len(examples_dict)),
                    )

few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=_sqlite_prompt + "Here are some examples:",
    suffix=PROMPT_SUFFIX,
    input_variables=["table_info", "input", "top_k"],
)
    Using embedded DuckDB without persistence: data will be transient
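
The selection step can be illustrated without embeddings or a vector store. This is a minimal sketch in which bag-of-words cosine similarity stands in for the HuggingFace embeddings and Chroma used above. It is a much cruder similarity measure, and `select_examples` is a hypothetical helper, but the ranking logic (score every stored example against the question and keep the top k) is the same idea.

```python
import math
import re
from collections import Counter
from typing import Dict, List


def _vector(text: str) -> Counter:
    """Lower-cased word counts, used as a crude stand-in for an embedding."""
    return Counter(re.findall(r"\w+", text.lower()))


def _cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[token] * b[token] for token in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def select_examples(question: str, examples: List[Dict[str, str]], k: int = 1) -> List[Dict[str, str]]:
    """Return the k examples whose 'input' is most similar to the question."""
    q = _vector(question)
    ranked = sorted(examples, key=lambda ex: _cosine(q, _vector(ex["input"])), reverse=True)
    return ranked[:k]


# The two 'input' values from the YAML examples above; the other fields are omitted.
examples = [
    {"input": "list all the genres that start with 'r'"},
    {"input": "How many customers are not from Brazil?"},
]
chosen = select_examples("How many customers are from Brazil?", examples, k=1)
print(chosen[0]["input"])
```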

With these few-shot prompts, the model should now do better.

local_chain = SQLDatabaseChain.from_llm(local_llm, db, prompt=few_shot_prompt, use_query_checker=True, verbose=True, return_intermediate_steps=True)

result = local_chain("How many customers are from Brazil?")

Result:

    
    
    > Entering new SQLDatabaseChain chain...
    How many customers are from Brazil?
    SQLQuery:SELECT count(*) FROM Customer WHERE Country = "Brazil";
    SQLResult: [(5,)]
    Answer:[5]
    > Finished chain.

Reference:

https://python.langchain.com/docs/modules/chains/popular/sqlite


Reposted from blog.csdn.net/u013066244/article/details/131651918