Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sample row in table info for SQLDatabase (#769) #782

Merged
merged 2 commits into from
Jan 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 70 additions & 10 deletions docs/modules/chains/examples/sqlite.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "a8fc8f23",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -242,6 +242,74 @@
"db_chain.run(\"What are some example tracks by composer Johann Sebastian Bach?\")"
]
},
{
"cell_type": "markdown",
"id": "bcc5e936",
"metadata": {},
"source": [
"## Adding first row of each table\n",
"Sometimes the format of the data is not obvious, and it helps to include the first row of each table in the prompt so that the LLM can see what the data looks like before it writes the final query. Here we use this feature to let the LLM know that artists are saved with their full names."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "9a22ee47",
"metadata": {},
"outputs": [],
"source": [
"db = SQLDatabase.from_uri(\n",
" \"sqlite:///../../../../notebooks/Chinook.db\", \n",
" include_tables=['Track'], # we include only one table to save tokens in the prompt :)\n",
" sample_row_in_table_info=True)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "bcb7a489",
"metadata": {},
"outputs": [],
"source": [
"db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "81e05d82",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
"What are some example tracks by Bach? \n",
"SQLQuery:Table 'Track' has columns: TrackId (INTEGER), Name (NVARCHAR(200)), AlbumId (INTEGER), MediaTypeId (INTEGER), GenreId (INTEGER), Composer (NVARCHAR(220)), Milliseconds (INTEGER), Bytes (INTEGER), UnitPrice (NUMERIC(10, 2)). Here is an example row for this table (long strings are truncated): ['1', 'For Those About To Rock (We Salute You)', '1', '1', '1', 'Angus Young, Malcolm Young, Brian Johnson', '343719', '11170334', '0.99'].\n",
"\u001b[32;1m\u001b[1;3m SELECT TrackId, Name, Composer FROM Track WHERE Composer LIKE '%Bach%' ORDER BY Name LIMIT 5;\u001b[0m\n",
"SQLResult: \u001b[33;1m\u001b[1;3m[(1709, 'American Woman', 'B. Cummings/G. Peterson/M.J. Kale/R. Bachman'), (3408, 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Johann Sebastian Bach'), (3433, 'Concerto No.2 in F Major, BWV1047, I. Allegro', 'Johann Sebastian Bach'), (3407, 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', 'Johann Sebastian Bach'), (3490, 'Partita in E Major, BWV 1006A: I. Prelude', 'Johann Sebastian Bach')]\u001b[0m\n",
"Answer:\u001b[32;1m\u001b[1;3m Some example tracks by Bach are 'American Woman', 'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria', 'Concerto No.2 in F Major, BWV1047, I. Allegro', 'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace', and 'Partita in E Major, BWV 1006A: I. Prelude'.\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' Some example tracks by Bach are \\'American Woman\\', \\'Aria Mit 30 Veränderungen, BWV 988 \"Goldberg Variations\": Aria\\', \\'Concerto No.2 in F Major, BWV1047, I. Allegro\\', \\'Concerto for 2 Violins in D Minor, BWV 1043: I. Vivace\\', and \\'Partita in E Major, BWV 1006A: I. Prelude\\'.'"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_chain.run(\"What are some example tracks by Bach?\")"
]
},
{
"cell_type": "markdown",
"id": "c12ae15a",
Expand Down Expand Up @@ -319,14 +387,6 @@
"source": [
"chain.run(\"How many employees are also customers?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2998b03",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -345,7 +405,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.8.16"
}
},
"nbformat": 4,
Expand Down
16 changes: 16 additions & 0 deletions langchain/sql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(
schema: Optional[str] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_row_in_table_info: bool = False,
):
"""Create engine from database URI."""
self._engine = engine
Expand All @@ -39,6 +40,7 @@ def __init__(
raise ValueError(
f"ignore_tables {missing_tables} not found in database"
)
self._sample_row_in_table_info = sample_row_in_table_info

@classmethod
def from_uri(cls, database_uri: str, **kwargs: Any) -> SQLDatabase:
Expand Down Expand Up @@ -69,14 +71,28 @@ def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
if missing_tables:
raise ValueError(f"table_names {missing_tables} not found in database")
all_table_names = table_names

template = "Table '{table_name}' has columns: {columns}."

tables = []
for table_name in all_table_names:

columns = []
for column in self._inspector.get_columns(table_name, schema=self._schema):
columns.append(f"{column['name']} ({str(column['type'])})")
column_str = ", ".join(columns)
table_str = template.format(table_name=table_name, columns=column_str)

if self._sample_row_in_table_info:
row_template = (
" Here is an example row for this table"
" (long strings are truncated): {sample_row}."
)
sample_row = self.run(f"SELECT * FROM '{table_name}' LIMIT 1")
if len(eval(sample_row)) > 0:
sample_row = " ".join([str(i)[:100] for i in eval(sample_row)[0]])
table_str += row_template.format(sample_row=sample_row)

tables.append(table_str)
return "\n".join(tables)

Expand Down
21 changes: 21 additions & 0 deletions tests/unit_tests/test_sql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,27 @@ def test_table_info() -> None:
assert sorted(output.split("\n")) == sorted(expected_output)


def test_table_info_w_sample_row() -> None:
    """Check table info output when ``sample_row_in_table_info`` is enabled.

    A single row is inserted into ``user``; the expected output carries an
    example-row sentence for ``user`` only, while ``company`` is described
    by its columns alone.
    """
    # Fresh in-memory SQLite database populated from the module's metadata.
    engine = create_engine("sqlite:///:memory:")
    metadata_obj.create_all(engine)

    # Seed the one row that should be echoed back as the sample row.
    seed_stmt = insert(user).values(user_id=13, user_name="Harrison")
    with engine.begin() as connection:
        connection.execute(seed_stmt)

    database = SQLDatabase(engine, sample_row_in_table_info=True)
    actual_lines = database.table_info.split("\n")

    expected_lines = [
        "Table 'company' has columns: company_id (INTEGER), "
        "company_location (VARCHAR).",
        "Table 'user' has columns: user_id (INTEGER), "
        "user_name (VARCHAR(16)). Here is an example row "
        "for this table (long strings are truncated): 13 Harrison.",
    ]
    # Sort both sides so the order in which tables are listed is irrelevant.
    assert sorted(actual_lines) == sorted(expected_lines)


def test_sql_database_run() -> None:
"""Test that commands can be run successfully and returned in correct format."""
engine = create_engine("sqlite:///:memory:")
Expand Down