Source code for langchain.chains.sql_database.query

from typing import List, Optional, TypedDict, Union

from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable, RunnableParallel

from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS
from langchain.utilities.sql_database import SQLDatabase


def _strip(text: str) -> str:
    return text.strip()


[docs]class SQLInput(TypedDict): """Input for a SQL Chain.""" question: str
[docs]class SQLInputWithTables(TypedDict): """Input for a SQL Chain.""" question: str table_names_to_use: List[str]
[docs]def create_sql_query_chain( llm: BaseLanguageModel, db: SQLDatabase, prompt: Optional[BasePromptTemplate] = None, k: int = 5, ) -> Runnable[Union[SQLInput, SQLInputWithTables], str]: """Create a chain that generates SQL queries. *Security Note*: This chain generates SQL queries for the given database. The SQLDatabase class provides a get_table_info method that can be used to get column information as well as sample data from the table. To mitigate risk of leaking sensitive data, limit permissions to read and scope to the tables that are needed. Optionally, use the SQLInputWithTables input type to specify which tables are allowed to be accessed. Control access to who can submit requests to this chain. See https://python.langchain.com/docs/security for more information. Args: llm: The language model to use db: The SQLDatabase to generate the query for prompt: The prompt to use. If none is provided, will choose one based on dialect. Defaults to None. k: The number of results per select statement to return. Defaults to 5. Returns: A chain that takes in a question and generates a SQL query that answers that question. """ if prompt is not None: prompt_to_use = prompt elif db.dialect in SQL_PROMPTS: prompt_to_use = SQL_PROMPTS[db.dialect] else: prompt_to_use = PROMPT inputs = { "input": lambda x: x["question"] + "\nSQLQuery: ", "top_k": lambda _: k, "table_info": lambda x: db.get_table_info( table_names=x.get("table_names_to_use") ), } if "dialect" in prompt_to_use.input_variables: inputs["dialect"] = lambda _: (db.dialect, prompt_to_use) return ( RunnableParallel(inputs) | prompt_to_use | llm.bind(stop=["\nSQLResult:"]) | StrOutputParser() | _strip )