Using LLMs and Langchain to Ask Questions about SQL Databases

Some of the most valuable information for making LLMs useful is in structured data such as databases. In this article we show how we can use LangChain to help LLMs answer questions based on information stored in an SQL database.
natural-language-processing
agents
langchain
openai
llm-evaluation
Author

Pranath Fernando

Published

August 20, 2023

1 Introduction

SQL databases are frequently used to hold enterprise data. Natural language interaction with SQL databases is made feasible by LLMs such as OpenAI’s ChatGPT and GPT Models. LangChain provides SQL Chains and Agents for building and running SQL queries based on natural language prompts. These SQL Chains and Agents are compatible with any SQL dialect supported by SQLAlchemy (e.g., MySQL, PostgreSQL, Oracle SQL, Databricks, SQLite).

They enable use cases like:

  • Creating queries that will be executed in response to natural language questions
  • Developing chatbots that can answer queries based on database data
  • Developing custom dashboards based on information that a user wants to analyse

In this article we will see different ways we can use LangChain and LLMs to ask questions about data in an SQL database.

2 Overview

LangChain provides tools to interact with SQL Databases:

  1. Build SQL queries based on natural language user questions
  2. Query a SQL database using chains for query creation and execution
  3. Interact with a SQL database using agents for robust and flexible querying

3 Import Libs & Setup

First, get required packages and set environment variables:

import os
from dotenv import load_dotenv
load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

For our project we are also going to use LangSmith for logging and visualising runs. I wrote an article introducing LangSmith previously. Let’s set up and configure LangSmith.

from uuid import uuid4

unique_id = uuid4().hex[0:8]
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = f"Langchain SQL Demo - {unique_id}"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_API_KEY"] = os.getenv("LANGCHAIN_API_KEY")  # Update to your API key

# Used by the agent in this post
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")

We will also use an SQLite connection with the Chinook database. The Chinook database is a sample database that represents a digital media store, including tables for artists, albums, media tracks, invoices and customers.

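There are 11 tables in the Chinook sample database. The list below uses plural, lowercase table names; the Chinook_Sqlite.sql script used in this article names the same tables Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack and Track.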

  • employees table stores employees data such as employee id, last name, first name, etc. It also has a field named ReportsTo to specify who reports to whom.
  • customers table stores customers data.
  • invoices & invoice_items tables: these two tables store invoice data. The invoices table stores invoice header data and the invoice_items table stores the invoice line items.
  • artists table stores artists data. It is a simple table that contains only the artist id and name.
  • albums table stores data about a list of tracks. Each album belongs to one artist. However, one artist may have multiple albums.
  • media_types table stores media types such as MPEG audio and AAC audio files.
  • genres table stores music types such as rock, jazz, metal, etc.
  • tracks table stores the data of songs. Each track belongs to one album.
  • playlists & playlist_track tables: the playlists table stores data about playlists. Each playlist contains a list of tracks, and each track may belong to multiple playlists, so the relationship between the playlists table and the tracks table is many-to-many. The playlist_track table is used to reflect this relationship.

Follow installation steps to create Chinook.db in the same directory as this notebook:

  • Save this file to the directory as Chinook_Sqlite.sql
  • Run sqlite3 Chinook.db
  • Run .read Chinook_Sqlite.sql
  • Test SELECT * FROM Artist LIMIT 10;

Now, Chinook.db is in our directory.
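
As a quick sanity check after loading, the table names can also be listed from Python with the standard-library sqlite3 module. This is only a minimal sketch: the helper name list_tables is hypothetical, and a tiny in-memory database stands in for Chinook.db so that the snippet runs on its own.

```python
import sqlite3


def list_tables(conn):
    """Return the names of user tables, sorted alphabetically."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' "
        "ORDER BY name"
    ).fetchall()
    return [name for (name,) in rows]


# In-memory stand-in; with the real file this would be sqlite3.connect("Chinook.db").
conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "Artist" ("ArtistId" INTEGER PRIMARY KEY, "Name" NVARCHAR(120))')
conn.execute('CREATE TABLE "Album" ("AlbumId" INTEGER PRIMARY KEY, "Title" NVARCHAR(160))')

tables = list_tables(conn)
print(tables)
```

Pointed at the real Chinook.db, the same helper should list the eleven tables described above.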

Let’s create a SQLDatabaseChain to create and execute SQL queries.

from langchain.utilities import SQLDatabase
from langchain.llms import OpenAI
from langchain_experimental.sql import SQLDatabaseChain

db = SQLDatabase.from_uri("sqlite:///docs/Chinook.db")
llm = OpenAI(temperature=0, verbose=True)
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
db_chain.run("How many customers are there?")


> Entering new SQLDatabaseChain chain...
How many customers are there?
SQLQuery:SELECT COUNT(*) FROM Customer;
SQLResult: [(59,)]
Answer:There are 59 customers.
> Finished chain.
/Users/pranathfernando/opt/anaconda3/lib/python3.9/site-packages/langchain/utilities/sql_database.py:357: SAWarning: Dialect sqlite+pysqlite does *not* support Decimal objects natively, and SQLAlchemy must convert from floating point - rounding errors and other issues may occur. Please consider storing Decimal numbers as strings or integers on this platform for lossless storage.
  sample_rows_result = connection.execute(command)  # type: ignore
'There are 59 customers.'

Note that this both creates and executes the query. In the following sections, we will cover the 3 different use cases mentioned in the overview.

4 Case 1: Text-to-SQL query

from langchain.chat_models import ChatOpenAI
from langchain.chains import create_sql_query_chain

Let’s create the chain that will build the SQL Query:

chain = create_sql_query_chain(ChatOpenAI(temperature=0), db)
response = chain.invoke({"question":"How many customers are there"})
print(response)
SELECT COUNT(*) FROM Customer

After building the SQL query based on a user question, we can execute the query:

db.run(response)
'[(59,)]'

As we can see, the SQL Query Builder chain only created the query, and we handled the query execution separately.

4.1 Go deeper

Looking under the hood

We can look at the LangSmith trace to unpack this:

This is the full text of the prompt created from that query:

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:

CREATE TABLE "Album" (
    "AlbumId" INTEGER NOT NULL, 
    "Title" NVARCHAR(160) NOT NULL, 
    "ArtistId" INTEGER NOT NULL, 
    PRIMARY KEY ("AlbumId"), 
    FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId Title   ArtistId
1   For Those About To Rock We Salute You   1
2   Balls to the Wall   2
3   Restless and Wild   2
*/


CREATE TABLE "Artist" (
    "ArtistId" INTEGER NOT NULL, 
    "Name" NVARCHAR(120), 
    PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId    Name
1   AC/DC
2   Accept
3   Aerosmith
*/


CREATE TABLE "Customer" (
    "CustomerId" INTEGER NOT NULL, 
    "FirstName" NVARCHAR(40) NOT NULL, 
    "LastName" NVARCHAR(20) NOT NULL, 
    "Company" NVARCHAR(80), 
    "Address" NVARCHAR(70), 
    "City" NVARCHAR(40), 
    "State" NVARCHAR(40), 
    "Country" NVARCHAR(40), 
    "PostalCode" NVARCHAR(10), 
    "Phone" NVARCHAR(24), 
    "Fax" NVARCHAR(24), 
    "Email" NVARCHAR(60) NOT NULL, 
    "SupportRepId" INTEGER, 
    PRIMARY KEY ("CustomerId"), 
    FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Customer table:
CustomerId  FirstName   LastName    Company Address City    State   Country PostalCode  Phone   Fax Email   SupportRepId
1   Luís    Gonçalves   Embraer - Empresa Brasileira de Aeronáutica S.A.    Av. Brigadeiro Faria Lima, 2170 São José dos Campos SP  Brazil  12227-000   +55 (12) 3923-5555  +55 (12) 3923-5566  luisg@embraer.com.br    3
2   Leonie  Köhler  None    Theodor-Heuss-Straße 34 Stuttgart   None    Germany 70174   +49 0711 2842222    None    leonekohler@surfeu.de   5
3   François    Tremblay    None    1498 rue Bélanger   Montréal    QC  Canada  H2G 1A7 +1 (514) 721-4711   None    ftremblay@gmail.com 3
*/


CREATE TABLE "Employee" (
    "EmployeeId" INTEGER NOT NULL, 
    "LastName" NVARCHAR(20) NOT NULL, 
    "FirstName" NVARCHAR(20) NOT NULL, 
    "Title" NVARCHAR(30), 
    "ReportsTo" INTEGER, 
    "BirthDate" DATETIME, 
    "HireDate" DATETIME, 
    "Address" NVARCHAR(70), 
    "City" NVARCHAR(40), 
    "State" NVARCHAR(40), 
    "Country" NVARCHAR(40), 
    "PostalCode" NVARCHAR(10), 
    "Phone" NVARCHAR(24), 
    "Fax" NVARCHAR(24), 
    "Email" NVARCHAR(60), 
    PRIMARY KEY ("EmployeeId"), 
    FOREIGN KEY("ReportsTo") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Employee table:
EmployeeId  LastName    FirstName   Title   ReportsTo   BirthDate   HireDate    Address City    State   Country PostalCode  Phone   Fax Email
1   Adams   Andrew  General Manager None    1962-02-18 00:00:00 2002-08-14 00:00:00 11120 Jasper Ave NW Edmonton    AB  Canada  T5K 2N1 +1 (780) 428-9482   +1 (780) 428-3457   andrew@chinookcorp.com
2   Edwards Nancy   Sales Manager   1   1958-12-08 00:00:00 2002-05-01 00:00:00 825 8 Ave SW    Calgary AB  Canada  T2P 2T3 +1 (403) 262-3443   +1 (403) 262-3322   nancy@chinookcorp.com
3   Peacock Jane    Sales Support Agent 2   1973-08-29 00:00:00 2002-04-01 00:00:00 1111 6 Ave SW   Calgary AB  Canada  T2P 5M5 +1 (403) 262-3443   +1 (403) 262-6712   jane@chinookcorp.com
*/


CREATE TABLE "Genre" (
    "GenreId" INTEGER NOT NULL, 
    "Name" NVARCHAR(120), 
    PRIMARY KEY ("GenreId")
)

/*
3 rows from Genre table:
GenreId Name
1   Rock
2   Jazz
3   Metal
*/


CREATE TABLE "Invoice" (
    "InvoiceId" INTEGER NOT NULL, 
    "CustomerId" INTEGER NOT NULL, 
    "InvoiceDate" DATETIME NOT NULL, 
    "BillingAddress" NVARCHAR(70), 
    "BillingCity" NVARCHAR(40), 
    "BillingState" NVARCHAR(40), 
    "BillingCountry" NVARCHAR(40), 
    "BillingPostalCode" NVARCHAR(10), 
    "Total" NUMERIC(10, 2) NOT NULL, 
    PRIMARY KEY ("InvoiceId"), 
    FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
)

/*
3 rows from Invoice table:
InvoiceId   CustomerId  InvoiceDate BillingAddress  BillingCity BillingState    BillingCountry  BillingPostalCode   Total
1   2   2009-01-01 00:00:00 Theodor-Heuss-Straße 34 Stuttgart   None    Germany 70174   1.98
2   4   2009-01-02 00:00:00 Ullevålsveien 14    Oslo    None    Norway  0171    3.96
3   8   2009-01-03 00:00:00 Grétrystraat 63 Brussels    None    Belgium 1000    5.94
*/


CREATE TABLE "InvoiceLine" (
    "InvoiceLineId" INTEGER NOT NULL, 
    "InvoiceId" INTEGER NOT NULL, 
    "TrackId" INTEGER NOT NULL, 
    "UnitPrice" NUMERIC(10, 2) NOT NULL, 
    "Quantity" INTEGER NOT NULL, 
    PRIMARY KEY ("InvoiceLineId"), 
    FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), 
    FOREIGN KEY("InvoiceId") REFERENCES "Invoice" ("InvoiceId")
)

/*
3 rows from InvoiceLine table:
InvoiceLineId   InvoiceId   TrackId UnitPrice   Quantity
1   1   2   0.99    1
2   1   4   0.99    1
3   2   6   0.99    1
*/


CREATE TABLE "MediaType" (
    "MediaTypeId" INTEGER NOT NULL, 
    "Name" NVARCHAR(120), 
    PRIMARY KEY ("MediaTypeId")
)

/*
3 rows from MediaType table:
MediaTypeId Name
1   MPEG audio file
2   Protected AAC audio file
3   Protected MPEG-4 video file
*/


CREATE TABLE "Playlist" (
    "PlaylistId" INTEGER NOT NULL, 
    "Name" NVARCHAR(120), 
    PRIMARY KEY ("PlaylistId")
)

/*
3 rows from Playlist table:
PlaylistId  Name
1   Music
2   Movies
3   TV Shows
*/


CREATE TABLE "PlaylistTrack" (
    "PlaylistId" INTEGER NOT NULL, 
    "TrackId" INTEGER NOT NULL, 
    PRIMARY KEY ("PlaylistId", "TrackId"), 
    FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), 
    FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)

/*
3 rows from PlaylistTrack table:
PlaylistId  TrackId
1   3402
1   3389
1   3390
*/


CREATE TABLE "Track" (
    "TrackId" INTEGER NOT NULL, 
    "Name" NVARCHAR(200) NOT NULL, 
    "AlbumId" INTEGER, 
    "MediaTypeId" INTEGER NOT NULL, 
    "GenreId" INTEGER, 
    "Composer" NVARCHAR(220), 
    "Milliseconds" INTEGER NOT NULL, 
    "Bytes" INTEGER, 
    "UnitPrice" NUMERIC(10, 2) NOT NULL, 
    PRIMARY KEY ("TrackId"), 
    FOREIGN KEY("MediaTypeId") REFERENCES "MediaType" ("MediaTypeId"), 
    FOREIGN KEY("GenreId") REFERENCES "Genre" ("GenreId"), 
    FOREIGN KEY("AlbumId") REFERENCES "Album" ("AlbumId")
)

/*
3 rows from Track table:
TrackId Name    AlbumId MediaTypeId GenreId Composer    Milliseconds    Bytes   UnitPrice
1   For Those About To Rock (We Salute You) 1   1   1   Angus Young, Malcolm Young, Brian Johnson   343719  11170334    0.99
2   Balls to the Wall   2   2   1   None    342562  5510424 0.99
3   Fast As a Shark 3   2   1   F. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman 230619  3990994 0.99
*/

Question: How many customers are there
SQLQuery: 

Some papers have reported good performance when prompting with:

  • A CREATE TABLE description for each table, which includes column names, their types, etc.
  • Followed by three example rows from a SELECT statement

create_sql_query_chain adopts this best practice (see more in this blog).
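
To make the format concrete, here is a rough sketch of how such a table description could be assembled by hand with the standard-library sqlite3 module. It is not LangChain's implementation: the function name describe_table is hypothetical, and a small in-memory Genre table is used in place of the full Chinook database.

```python
import sqlite3


def describe_table(conn, table, sample_rows=3):
    """Return the CREATE TABLE statement followed by a few sample rows in a comment."""
    row = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()
    if row is None:
        raise ValueError(f"unknown table: {table}")
    create_stmt = row[0]
    # The table name was checked against sqlite_master above, so quoting it is safe.
    cursor = conn.execute(f'SELECT * FROM "{table}" LIMIT ?', (sample_rows,))
    header = "\t".join(col[0] for col in cursor.description)
    lines = ["\t".join(str(value) for value in r) for r in cursor.fetchall()]
    body = "\n".join([header, *lines])
    return f"{create_stmt}\n\n/*\n{len(lines)} rows from {table} table:\n{body}\n*/"


conn = sqlite3.connect(":memory:")
conn.execute(
    'CREATE TABLE "Genre" ("GenreId" INTEGER NOT NULL, "Name" NVARCHAR(120), '
    'PRIMARY KEY ("GenreId"))'
)
conn.executemany(
    'INSERT INTO "Genre" VALUES (?, ?)',
    [(1, "Rock"), (2, "Jazz"), (3, "Metal"), (4, "Alternative & Punk")],
)

table_info = describe_table(conn, "Genre")
print(table_info)
```

Only the first three rows are included, which keeps the prompt short while still showing the model what the stored values look like.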

Improvements

The query builder can be improved in a variety of ways, including (but not limited to):

  • Tailoring the database description to your particular use case
  • Using a vector database to provide dynamic examples that are relevant to the individual user question
  • Hardcoding a few instances of questions and their related SQL query in the prompt

All of these examples involve changing the prompt for the chain. For example, we could include the following instances in our prompt:

from langchain.prompts import PromptTemplate

TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}.

Some examples of SQL queries that correspond to questions are:

{few_shot_examples}

Question: {input}"""

CUSTOM_PROMPT = PromptTemplate(
    input_variables=["input", "few_shot_examples", "table_info", "dialect"], template=TEMPLATE
)
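
To illustrate how the placeholders might be filled, the sketch below formats a few question/query pairs and substitutes them into a shortened stand-in for the template using plain str.format. The example pairs, the helper format_examples, and the shortened template text are assumptions for illustration; how a custom prompt is passed to a particular chain is not covered here.

```python
# Shortened stand-in for the template above, using the same placeholder names.
TEMPLATE = (
    "Given an input question, create a syntactically correct {dialect} query to run.\n\n"
    "Only use the following tables:\n{table_info}\n\n"
    "Some examples of SQL queries that correspond to questions are:\n"
    "{few_shot_examples}\n\n"
    "Question: {input}"
)

# Hypothetical question/query pairs written against the Chinook schema.
examples = {
    "How many employees are there?": 'SELECT COUNT(*) FROM "Employee"',
    "List all customers from Canada.": (
        'SELECT "FirstName", "LastName" FROM "Customer" WHERE "Country" = \'Canada\''
    ),
}


def format_examples(pairs):
    """Render each pair in the same Question / SQLQuery layout the prompt uses."""
    return "\n\n".join(f"Question: {q}\nSQLQuery: {sql}" for q, sql in pairs.items())


prompt = TEMPLATE.format(
    dialect="sqlite",
    table_info='CREATE TABLE "Customer" (...)',
    few_shot_examples=format_examples(examples),
    input="How many customers are there?",
)
print(prompt)
```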

5 Case 2: Text-to-SQL query and execution

We can use SQLDatabaseChain from langchain_experimental to create and run SQL queries.

from langchain.llms import OpenAI
from langchain_experimental.sql import SQLDatabaseChain

llm = OpenAI(temperature=0, verbose=True)
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
db_chain.run("How many customers are there?")


> Entering new SQLDatabaseChain chain...
How many customers are there?
SQLQuery:SELECT COUNT(*) FROM Customer;
SQLResult: [(59,)]
Answer:There are 59 customers.
> Finished chain.
'There are 59 customers.'

As we can see, we get the same result as the previous case.

Here, the chain also handles the query execution and provides a final answer based on the user question and the query result.

Be careful when using this approach, as it carries risks similar to SQL injection:

  • The chain executes queries that were created by an LLM and were not validated
  • As a result, records may be created, modified or deleted unintentionally

This is why SQLDatabaseChain lives in langchain_experimental.
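
One way to reduce this risk is to make the connection itself refuse writes. The sketch below uses SQLite's PRAGMA query_only on a plain sqlite3 connection; it illustrates the idea only and does not show how LangChain opens its connections. The helper run_generated_sql is hypothetical, and the in-memory table stands in for the real database.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "Customer" ("CustomerId" INTEGER PRIMARY KEY, "FirstName" TEXT)')
conn.execute('INSERT INTO "Customer" ("FirstName") VALUES (\'Luís\')')
conn.commit()

# From here on, SQLite rejects statements that would modify the database.
conn.execute("PRAGMA query_only = ON")


def run_generated_sql(conn, sql):
    """Run a (possibly LLM-generated) statement, reporting rejections instead of raising."""
    try:
        return conn.execute(sql).fetchall()
    except sqlite3.Error as exc:
        return f"rejected: {exc}"


read_result = run_generated_sql(conn, 'SELECT COUNT(*) FROM "Customer"')
write_result = run_generated_sql(conn, 'DELETE FROM "Customer"')
print(read_result)
print(write_result)
```

For server databases, the analogous precaution would be to connect as a user that has only read permissions.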

5.1 Go deeper

Looking under the hood

We can use the LangSmith trace to see what is happening under the hood:

  • As discussed above, first we create the query:
text: ' SELECT COUNT(*) FROM "Customer";'
  • Then, it executes the query and passes the results to an LLM for synthesis.

Improvements

The performance of the SQLDatabaseChain can be enhanced in several ways:

You might find SQLDatabaseSequentialChain useful for cases in which the number of tables in the database is large.

This Sequential Chain handles the process of:

  1. Determining which tables to use based on the user question
  2. Calling the normal SQL database chain using only relevant tables
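
The two steps can be pictured with the toy sketch below. It does not use SQLDatabaseSequentialChain itself: a naive keyword match stands in for the LLM call that chooses tables, and the schema snippets are shortened, so every name here is illustrative.

```python
# Shortened, illustrative schema snippets keyed by table name.
TABLE_INFO = {
    "Album": 'CREATE TABLE "Album" ("AlbumId" INTEGER, "Title" NVARCHAR(160), "ArtistId" INTEGER)',
    "Artist": 'CREATE TABLE "Artist" ("ArtistId" INTEGER, "Name" NVARCHAR(120))',
    "Customer": 'CREATE TABLE "Customer" ("CustomerId" INTEGER, "FirstName" NVARCHAR(40), "Country" NVARCHAR(40))',
    "Invoice": 'CREATE TABLE "Invoice" ("InvoiceId" INTEGER, "CustomerId" INTEGER, "Total" NUMERIC(10, 2))',
}


def choose_tables(question, table_names):
    """Hypothetical stand-in for the LLM call that picks the relevant tables."""
    text = question.lower()
    return [name for name in table_names if name.lower() in text]


def build_table_info(question, table_info):
    """Step 1: pick tables. Step 2: keep only their descriptions for the prompt."""
    chosen = choose_tables(question, table_info)
    # Fall back to every table if the heuristic finds nothing.
    chosen = chosen or list(table_info)
    return chosen, "\n\n".join(table_info[name] for name in chosen)


chosen, reduced_info = build_table_info("How many customers are there?", TABLE_INFO)
print(chosen)
print(reduced_info)
```

The benefit is a smaller prompt: only the schema of the selected tables is sent to the model in the second step.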

Adding Sample Rows

Providing sample data can help the LLM construct correct queries when the data format is not obvious.

For example, we can tell the LLM that artists are saved with their full names by providing two rows from the Track table.

db = SQLDatabase.from_uri(
    "sqlite:///docs/Chinook.db",
    include_tables=['Track'], # we include only one table to save tokens in the prompt :)
    sample_rows_in_table_info=2)

The sample rows are added to the prompt after each corresponding table’s column information.

We can use db.table_info and check which sample rows are included:

print(db.table_info)

CREATE TABLE "Track" (
    "TrackId" INTEGER NOT NULL, 
    "Name" NVARCHAR(200) NOT NULL, 
    "AlbumId" INTEGER, 
    "MediaTypeId" INTEGER NOT NULL, 
    "GenreId" INTEGER, 
    "Composer" NVARCHAR(220), 
    "Milliseconds" INTEGER NOT NULL, 
    "Bytes" INTEGER, 
    "UnitPrice" NUMERIC(10, 2) NOT NULL, 
    PRIMARY KEY ("TrackId"), 
    FOREIGN KEY("MediaTypeId") REFERENCES "MediaType" ("MediaTypeId"), 
    FOREIGN KEY("GenreId") REFERENCES "Genre" ("GenreId"), 
    FOREIGN KEY("AlbumId") REFERENCES "Album" ("AlbumId")
)

/*
2 rows from Track table:
TrackId Name    AlbumId MediaTypeId GenreId Composer    Milliseconds    Bytes   UnitPrice
1   For Those About To Rock (We Salute You) 1   1   1   Angus Young, Malcolm Young, Brian Johnson   343719  11170334    0.99
2   Balls to the Wall   2   2   1   None    342562  5510424 0.99
*/

6 Case 3: SQL agents

LangChain has a SQL Agent that is more flexible than the ‘SQLDatabaseChain’ in communicating with SQL Databases.

The following are the primary benefits of utilising the SQL Agent:

  • It can answer questions based on the schema as well as the content of the databases (for example, describing a specific table).
  • It can recover from problems by running a created query, capturing the traceback, and correctly rebuilding it.
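
The recovery behaviour in the second point can be sketched as a retry loop: run the query, and on failure hand the error message back to something that proposes a corrected query. In the sketch below that "something" is a hard-coded function standing in for an LLM call, so run_with_recovery and fake_llm_fix are hypothetical names and not part of LangChain.

```python
import sqlite3


def run_with_recovery(conn, sql, fix_query, max_attempts=3):
    """Execute sql; on a database error, ask fix_query for a corrected statement."""
    attempts = []
    for _ in range(max_attempts):
        try:
            rows = conn.execute(sql).fetchall()
            attempts.append((sql, None))
            return rows, attempts
        except sqlite3.Error as exc:
            attempts.append((sql, str(exc)))
            sql = fix_query(sql, str(exc))
    raise RuntimeError(f"no working query after {max_attempts} attempts")


def fake_llm_fix(sql, error):
    """Stand-in for the LLM: repairs one known mistake based on the error text."""
    if "no such table: customers" in error:
        return sql.replace("customers", "Customer")
    return sql


conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "Customer" ("CustomerId" INTEGER PRIMARY KEY, "Country" TEXT)')
conn.executemany('INSERT INTO "Customer" ("Country") VALUES (?)', [("Brazil",), ("Canada",)])

rows, attempts = run_with_recovery(conn, "SELECT COUNT(*) FROM customers", fake_llm_fix)
print(rows)
print(attempts)
```

The first attempt fails because the table name is wrong, and the second succeeds after the error message has been used to correct it.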

In this article the author described reasons why you might want to consider using an agent for SQL queries rather than just a chain:

‘…Let us first understand what is an agent and why it might be preferred over a simple SQLChain. An agent is a component that has access to a suite of tools, including a Large Language Model (LLM). Its distinguishing characteristic lies in its ability to make informed decisions based on user input, utilizing the appropriate tools until it achieves a satisfactory answer. For example in the context of text-to-SQL, the LangChain SQLAgent will not give up if there is an error in executing the generated SQL. Instead, it will attempt to recover by interpreting the error in a subsequent LLM call and rectify the issue. Therefore, in theory, SQLAgent should outperform SQLChain in productivity and accuracy’

And this is what that author found from their experiments:

‘…During our tests, we ran multiple questions on both SQLChain and SQLAgent using GPT-3.5 and compared their respective results. Our findings revealed that SQLAgent outperformed SQLChain by answering a greater number of questions…For accuracy, however, our findings also indicate a higher incidence of incorrect responses from SQLAgent. Besides the general shortcomings of using LLM to query database, we hypothesize that SQLAgent will occasionally make its best attempt to answer a question even when concrete results cannot be obtained from the SQL query.’

The ‘create_sql_agent’ method is used to initialise the agent.

The agent uses the ‘SQLDatabaseToolkit’, which provides tools to:

  • Create and run queries
  • Verify query syntax
  • Get table descriptions
  • … and much more
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
# from langchain.agents import AgentExecutor
from langchain.agents.agent_types import AgentType

db = SQLDatabase.from_uri("sqlite:///docs/Chinook.db")
llm = OpenAI(temperature=0, verbose=True)

agent_executor = create_sql_agent(
    llm=OpenAI(temperature=0),
    toolkit=SQLDatabaseToolkit(db=db, llm=OpenAI(temperature=0)),
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)

6.1 Agent task example #1 - Running queries

agent_executor.run(
    "List the total sales per country. Which country's customers spent the most?"
)


> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input: 
Observation: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Thought: I should query the schema of the Invoice and Customer tables.
Action: sql_db_schema
Action Input: Invoice, Customer
Observation: 
CREATE TABLE "Customer" (
    "CustomerId" INTEGER NOT NULL, 
    "FirstName" NVARCHAR(40) NOT NULL, 
    "LastName" NVARCHAR(20) NOT NULL, 
    "Company" NVARCHAR(80), 
    "Address" NVARCHAR(70), 
    "City" NVARCHAR(40), 
    "State" NVARCHAR(40), 
    "Country" NVARCHAR(40), 
    "PostalCode" NVARCHAR(10), 
    "Phone" NVARCHAR(24), 
    "Fax" NVARCHAR(24), 
    "Email" NVARCHAR(60) NOT NULL, 
    "SupportRepId" INTEGER, 
    PRIMARY KEY ("CustomerId"), 
    FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Customer table:
CustomerId  FirstName   LastName    Company Address City    State   Country PostalCode  Phone   Fax Email   SupportRepId
1   Luís    Gonçalves   Embraer - Empresa Brasileira de Aeronáutica S.A.    Av. Brigadeiro Faria Lima, 2170 São José dos Campos SP  Brazil  12227-000   +55 (12) 3923-5555  +55 (12) 3923-5566  luisg@embraer.com.br    3
2   Leonie  Köhler  None    Theodor-Heuss-Straße 34 Stuttgart   None    Germany 70174   +49 0711 2842222    None    leonekohler@surfeu.de   5
3   François    Tremblay    None    1498 rue Bélanger   Montréal    QC  Canada  H2G 1A7 +1 (514) 721-4711   None    ftremblay@gmail.com 3
*/


CREATE TABLE "Invoice" (
    "InvoiceId" INTEGER NOT NULL, 
    "CustomerId" INTEGER NOT NULL, 
    "InvoiceDate" DATETIME NOT NULL, 
    "BillingAddress" NVARCHAR(70), 
    "BillingCity" NVARCHAR(40), 
    "BillingState" NVARCHAR(40), 
    "BillingCountry" NVARCHAR(40), 
    "BillingPostalCode" NVARCHAR(10), 
    "Total" NUMERIC(10, 2) NOT NULL, 
    PRIMARY KEY ("InvoiceId"), 
    FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
)

/*
3 rows from Invoice table:
InvoiceId   CustomerId  InvoiceDate BillingAddress  BillingCity BillingState    BillingCountry  BillingPostalCode   Total
1   2   2009-01-01 00:00:00 Theodor-Heuss-Straße 34 Stuttgart   None    Germany 70174   1.98
2   4   2009-01-02 00:00:00 Ullevålsveien 14    Oslo    None    Norway  0171    3.96
3   8   2009-01-03 00:00:00 Grétrystraat 63 Brussels    None    Belgium 1000    5.94
*/
Thought: I should query the total sales per country.
Action: sql_db_query
Action Input: SELECT Country, SUM(Total) AS TotalSales FROM Invoice INNER JOIN Customer ON Invoice.CustomerId = Customer.CustomerId GROUP BY Country ORDER BY TotalSales DESC LIMIT 10
Observation: [('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62)]
Thought: I now know the final answer
Final Answer: The country with the highest total sales is the USA, with a total of $523.06.

> Finished chain.
'The country with the highest total sales is the USA, with a total of $523.06.'

Looking at the LangSmith trace, we can see:

  • The agent is using a ReAct style prompt
  • First, it lists the tables using the sql_db_list_tables tool (Action: sql_db_list_tables)
  • Given the tables as an observation, it thinks and then determines the next action:
Observation: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Thought: I should query the schema of the Invoice and Customer tables.
Action: sql_db_schema
Action Input: Invoice, Customer
  • It then formulates the query using the schema from tool sql_db_schema
Thought: I should query the total sales per country.
Action: sql_db_query
Action Input: SELECT Country, SUM(Total) AS TotalSales FROM Invoice INNER JOIN Customer ON Invoice.CustomerId = Customer.CustomerId GROUP BY Country ORDER BY TotalSales DESC LIMIT 10

  • It finally executes the generated query using tool sql_db_query

6.2 Agent task example #2 - Describing a Table

agent_executor.run("Describe the playlisttrack table")


> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input: 
Observation: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Thought: I should query the schema of the PlaylistTrack table
Action: sql_db_schema
Action Input: PlaylistTrack
Observation: 
CREATE TABLE "PlaylistTrack" (
    "PlaylistId" INTEGER NOT NULL, 
    "TrackId" INTEGER NOT NULL, 
    PRIMARY KEY ("PlaylistId", "TrackId"), 
    FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), 
    FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)

/*
3 rows from PlaylistTrack table:
PlaylistId  TrackId
1   3402
1   3389
1   3390
*/
Thought: I now know the final answer
Final Answer: The PlaylistTrack table contains two columns, PlaylistId and TrackId, which are both integers and form a primary key. It also has two foreign keys, one to the Track table and one to the Playlist table.

> Finished chain.
'The PlaylistTrack table contains two columns, PlaylistId and TrackId, which are both integers and form a primary key. It also has two foreign keys, one to the Track table and one to the Playlist table.'

7 Extending the SQL Toolkit with Domain Specific Knowledge Tools

A recent LangChain blog article (5/9/23) highlighted how few-shot examples can be used to bring in domain-specific knowledge as an alternative approach.

Although the Langchain SQL Toolkit includes all of the tools needed to begin working on a database, several additional tools may be beneficial for increasing the agent’s capabilities. This is especially important when attempting to integrate domain-specific information in the solution to increase overall performance.

Here are a few examples:

  • Including dynamic few-shot examples
  • Identifying misspellings of proper nouns for use as column filters

We can develop distinct tools to address these unique use cases and include them as an addition to the regular SQL Toolkit. Let’s look at how to incorporate these two bespoke tools.

7.1 Including dynamic few-shot examples

To integrate dynamic few-shot examples, we require a custom Retriever Tool that searches the vector database for examples that are semantically related to the user’s query.

Let’s begin by making a dictionary out of several examples:

few_shots = {'List all artists.': 'SELECT * FROM artists;',
              "Find all albums for the artist 'AC/DC'.": "SELECT * FROM albums WHERE ArtistId = (SELECT ArtistId FROM artists WHERE Name = 'AC/DC');",
              "List all tracks in the 'Rock' genre.": "SELECT * FROM tracks WHERE GenreId = (SELECT GenreId FROM genres WHERE Name = 'Rock');",
              'Find the total duration of all tracks.': 'SELECT SUM(Milliseconds) FROM tracks;',
              'List all customers from Canada.': "SELECT * FROM customers WHERE Country = 'Canada';",
              'How many tracks are there in the album with ID 5?': 'SELECT COUNT(*) FROM tracks WHERE AlbumId = 5;',
              'Find the total number of invoices.': 'SELECT COUNT(*) FROM invoices;',
              'List all tracks that are longer than 5 minutes.': 'SELECT * FROM tracks WHERE Milliseconds > 300000;',
              'Who are the top 5 customers by total purchase?': 'SELECT CustomerId, SUM(Total) AS TotalPurchase FROM invoices GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;',
              'Which albums are from the year 2000?': "SELECT * FROM albums WHERE strftime('%Y', ReleaseDate) = '2000';",
              'How many employees are there': 'SELECT COUNT(*) FROM "employee"'
             }

We can then create a retriever using the list of questions, assigning the target SQL query as metadata:

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.schema import Document

embeddings = OpenAIEmbeddings()

few_shot_docs = [Document(page_content=question, metadata={'sql_query': few_shots[question]}) for question in few_shots.keys()]
vector_db = FAISS.from_documents(few_shot_docs, embeddings)
retriever = vector_db.as_retriever()
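
To show what the retriever is doing conceptually, here is a dependency-free sketch that ranks stored questions by similarity to the user question and returns the attached SQL. A simple bag-of-words cosine similarity stands in for OpenAI embeddings and FAISS, so the ranking quality is not comparable; the function names are hypothetical.

```python
import math
import re
from collections import Counter


def embed(text):
    """Toy 'embedding': a bag of lowercase word counts."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a, b):
    dot = sum(a[t] * b[t] for t in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


few_shots = {
    "How many employees are there": 'SELECT COUNT(*) FROM "employee"',
    "List all customers from Canada.": "SELECT * FROM customers WHERE Country = 'Canada';",
    "Find the total duration of all tracks.": "SELECT SUM(Milliseconds) FROM tracks;",
}


def most_similar(question, examples, k=1):
    """Return the k stored (question, sql) pairs closest to the user question."""
    q = embed(question)
    ranked = sorted(examples, key=lambda ex: cosine(q, embed(ex)), reverse=True)
    return [(ex, examples[ex]) for ex in ranked[:k]]


best = most_similar("How many employees do we have?", few_shots)
print(best)
```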

Now we can create our own custom tool and append it as a new tool in the create_sql_agent function:

from langchain.agents.agent_toolkits import create_retriever_tool

tool_description = """
This tool will help you understand similar examples to adapt them to the user question.
Input to this tool should be the user question.
"""

retriever_tool = create_retriever_tool(
        retriever,
        name='sql_get_similar_examples',
        description=tool_description
    )
custom_tool_list = [retriever_tool]

We can now create the agent, adjusting the standard SQL agent suffix to tell it when to use the new tool. Explaining this in the tool description alone is the simplest approach, but it is frequently insufficient, so we also state it in the agent prompt using the suffix argument in the constructor.

from langchain.agents import create_sql_agent, AgentType
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.utilities import SQLDatabase
from langchain.chat_models import ChatOpenAI

db = SQLDatabase.from_uri("sqlite:///docs/Chinook.db")
llm = ChatOpenAI(model_name='gpt-4',temperature=0)

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

custom_suffix = """
I should first get the similar examples I know.
If the examples are enough to construct the query, I can build it.
Otherwise, I can then look at the tables in the database to see what I can query.
Then I should query the schema of the most relevant tables
"""

agent = create_sql_agent(llm=llm,
                         toolkit=toolkit,
                         verbose=True,
                         agent_type=AgentType.OPENAI_FUNCTIONS,
                         extra_tools=custom_tool_list,
                         suffix=custom_suffix
                        )
agent.run("How many employees do we have?")


> Entering new AgentExecutor chain...

Invoking: `sql_get_similar_examples` with `How many employees do we have?`


[Document(page_content='How many employees are there', metadata={'sql_query': 'SELECT COUNT(*) FROM "employee"'}), Document(page_content='Find the total number of invoices.', metadata={'sql_query': 'SELECT COUNT(*) FROM invoices;'}), Document(page_content='Who are the top 5 customers by total purchase?', metadata={'sql_query': 'SELECT CustomerId, SUM(Total) AS TotalPurchase FROM invoices GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;'}), Document(page_content='List all customers from Canada.', metadata={'sql_query': "SELECT * FROM customers WHERE Country = 'Canada';"})]
Invoking: `sql_db_query_checker` with `SELECT COUNT(*) FROM employee`
responded: {content}

SELECT COUNT(*) FROM employee
Invoking: `sql_db_query` with `SELECT COUNT(*) FROM employee`


[(8,)]We have 8 employees.

> Finished chain.
'We have 8 employees.'

7.2 Identifying and fixing proper noun misspellings

To accurately filter data from columns that contain proper nouns such as addresses, song titles, or artists, we must first double-check the spelling.

We can do this by building a vector store that holds all of the distinct proper nouns in the database. The agent can then query that vector store each time a proper noun appears in a question to find the right spelling for that word. This way, before constructing the target query, the agent can make sure it understands which entity the user is referring to.

Let’s take a similar approach to the few-shot examples, but without the metadata: we embed the proper nouns and then query the vector store to find the one that is most similar to the misspelt name in the user's question.
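To make the lookup idea concrete before we involve embeddings, here is a minimal sketch that swaps in plain string similarity from Python's standard library (difflib) in place of a vector store. This is not what the agent uses; the helper name closest_names and the small known_names list are hypothetical and only illustrate the "find the nearest known spelling" step.

```python
import difflib

def closest_names(query, names, n=3, cutoff=0.6):
    """Return up to n known names whose spelling is closest to the query."""
    # Compare case-insensitively, but return the original spelling.
    by_lower = {name.lower(): name for name in names}
    matches = difflib.get_close_matches(
        query.lower(), list(by_lower), n=n, cutoff=cutoff
    )
    return [by_lower[m] for m in matches]

# Hypothetical sample of proper nouns; in the article these come from the database.
known_names = ["Alice In Chains", "House Of Pain", "Aisha Duo", "AC/DC"]
print(closest_names("alis in pains", known_names))  # ['Alice In Chains']
```

Character-level similarity like this only catches spelling slips. An embedding-based store, as used in the rest of this section, can also match names that are similar in meaning, which is why we use it for the agent.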

First we need the unique values for each entity we want, for which we define a function that parses the result into a list of elements:

import ast
import re

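def run_query_save_results(db, query):
    # db.run returns the rows as a string, e.g. "[('AC/DC',), ('Accept',)]"
    res = db.run(query)
    # Parse the string back into tuples, flatten them, and drop empty values
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    # Remove standalone numbers and any leftover surrounding whitespace
    res = [re.sub(r'\b\d+\b', '', string).strip() for string in res]
    return res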

artists = run_query_save_results(db, "SELECT Name FROM Artist")
albums = run_query_save_results(db, "SELECT Title FROM Album")
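To see what the parsing step does without needing a database connection, here is a small sketch that applies the same logic to a hard-coded string. The function name parse_names and the sample rows are made up for illustration; the string just mimics the shape of a stringified list of row tuples.

```python
import ast
import re

def parse_names(raw_result):
    """Flatten a stringified list of row tuples into a clean list of names."""
    rows = ast.literal_eval(raw_result)
    # Flatten the tuples and skip empty values such as None
    names = [el for row in rows for el in row if el]
    # Strip standalone numbers, then trim whitespace
    return [re.sub(r'\b\d+\b', '', name).strip() for name in names]

# Hypothetical sample input, shaped like a stringified query result
sample = "[('AC/DC',), ('U2',), ('Blink 182',), (None,)]"
print(parse_names(sample))  # ['AC/DC', 'U2', 'Blink']
```

Note that the regular expression only removes numbers that stand on their own, so 'U2' is kept while the trailing '182' is dropped. That is worth keeping in mind if numbers are a meaningful part of the names in your data.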

Now we can proceed with creating the custom retriever tool and the final agent:

from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS


texts = (artists + albums)

embeddings = OpenAIEmbeddings()
vector_db = FAISS.from_texts(texts, embeddings)
retriever = vector_db.as_retriever()

retriever_tool = create_retriever_tool(
        retriever,
        name='name_search',
        description='use to learn how a piece of data is actually written, can be from names, surnames, addresses etc'
    )

custom_tool_list = [retriever_tool]

from langchain.agents import create_sql_agent, AgentType
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.utilities import SQLDatabase
from langchain.chat_models import ChatOpenAI

# db = SQLDatabase.from_uri("sqlite:///Chinook.db")
llm = ChatOpenAI(model_name='gpt-4', temperature=0)

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

custom_suffix = """
If a user asks for me to filter based on proper nouns, I should first check the spelling using the name_search tool.
Otherwise, I can then look at the tables in the database to see what I can query.
Then I should query the schema of the most relevant tables
"""

agent = create_sql_agent(llm=llm,
                         toolkit=toolkit,
                         verbose=True,
                         agent_type=AgentType.OPENAI_FUNCTIONS,
                         extra_tools=custom_tool_list,
                         suffix=custom_suffix
                        )
agent.run("How many albums does alis in pains have?")


> Entering new AgentExecutor chain...

Invoking: `name_search` with `alis in pains`


[Document(page_content='House of Pain', metadata={}), Document(page_content='Alice In Chains', metadata={}), Document(page_content='Aisha Duo', metadata={}), Document(page_content='House Of Pain', metadata={})]
Invoking: `sql_db_list_tables` with ``
responded: {content}

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Invoking: `sql_db_schema` with `Album, Artist`
responded: {content}


CREATE TABLE "Album" (
    "AlbumId" INTEGER NOT NULL, 
    "Title" NVARCHAR(160) NOT NULL, 
    "ArtistId" INTEGER NOT NULL, 
    PRIMARY KEY ("AlbumId"), 
    FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId Title   ArtistId
1   For Those About To Rock We Salute You   1
2   Balls to the Wall   2
3   Restless and Wild   2
*/


CREATE TABLE "Artist" (
    "ArtistId" INTEGER NOT NULL, 
    "Name" NVARCHAR(120), 
    PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId    Name
1   AC/DC
2   Accept
3   Aerosmith
*/
Invoking: `sql_db_query_checker` with `SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')`
responded: {content}

SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')
Invoking: `sql_db_query` with `SELECT COUNT(*) FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')`


[(1,)]Alice In Chains has 1 album in the database.

> Finished chain.
'Alice In Chains has 1 album in the database.'
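To illustrate why the spelling check matters, here is a minimal sketch that runs the query the agent generated against a tiny in-memory SQLite database. The tables mirror the Album and Artist schema shown above, but the rows are toy data invented for this example (including the album title), not the real Chinook contents.

```python
import sqlite3

# Toy stand-in for the Chinook tables (illustrative rows, not the full dataset)
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name NVARCHAR(120));
CREATE TABLE Album (
    AlbumId INTEGER PRIMARY KEY,
    Title NVARCHAR(160) NOT NULL,
    ArtistId INTEGER NOT NULL REFERENCES Artist (ArtistId)
);
INSERT INTO Artist VALUES (1, 'AC/DC'), (2, 'Alice In Chains');
INSERT INTO Album VALUES
    (1, 'For Those About To Rock We Salute You', 1),
    (2, 'Example Album', 2);
""")

# The query the agent produced after correcting the spelling
agent_query = """
SELECT COUNT(*) FROM Album
WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'Alice In Chains')
"""
album_count = conn.execute(agent_query).fetchone()[0]

# The same query with the user's original misspelling matches nothing
misspelt_query = agent_query.replace("Alice In Chains", "alis in pains")
misspelt_count = conn.execute(misspelt_query).fetchone()[0]

print(album_count, misspelt_count)  # 1 0
```

Because the filter is an exact string match, the misspelt name silently returns a count of zero rather than an error, which is exactly the kind of wrong answer the name_search step helps the agent avoid.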

8 Further Reading

To learn more about the SQL Agent and how it works, please refer to the SQL Agent Toolkit and LangChain Use Cases - SQL documentation.

You can also check out agents for other document types:

  • Pandas Agent
  • CSV Agent

9 Acknowledgements

I’d like to express my thanks to the wonderful Langsmith Documentation and acknowledge the use of some images and other materials from the documentation in this article.
