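# NOTE: page-extraction residue ("Spaces: Sleeping") — appears to be a hosting
# status banner captured with the listing; it is not part of the program.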
import os
import sqlite3
import threading
import traceback
# Agent requirements and type hints
from typing import Annotated, Literal, TypedDict, Any

from dotenv import load_dotenv
from flask import Flask, render_template, request, redirect, url_for, send_from_directory, flash
from flask_socketio import SocketIO
from werkzeug.utils import secure_filename

# LangChain and agent imports
from langchain_community.chat_models.huggingface import ChatHuggingFace  # if needed later
from langchain.agents import Tool
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
from langchain_core.callbacks import CallbackManager, BaseCallbackHandler
from langchain_community.agent_toolkits.load_tools import load_tools
from langchain_core.tools import tool
from langchain_community.agent_toolkits import PowerBIToolkit
from langchain.chains import LLMMathChain
from langchain import hub
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.messages import AIMessage, ToolMessage
from pydantic import BaseModel, Field
# Kept after the `typing` import on purpose: this binding of TypedDict is the
# one in effect for the rest of the module, exactly as in the original order.
from typing_extensions import TypedDict
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import AnyMessage, add_messages
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode
# Load environment variables
load_dotenv()

# Global configuration variables
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
# Anchored to BASE_DIR rather than the current working directory so that the
# folder files are saved into is the same one Flask resolves for
# `static_folder='uploads'` (Flask resolves that relative to the module).
UPLOAD_FOLDER = os.path.join(BASE_DIR, "uploads")
DATABASE_URI = f"sqlite:///{os.path.join(BASE_DIR, 'data', 'mydb.db')}"
print("DATABASE URI:", DATABASE_URI)

# API Keys from .env file
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")

# `os.environ[...] = None` raises TypeError, so only re-export keys that are
# actually present; a missing key is reported instead of crashing at import.
if GROQ_API_KEY is not None:
    os.environ["GROQ_API_KEY"] = GROQ_API_KEY
else:
    print("[WARN]: GROQ_API_KEY is not set.")
if MISTRAL_API_KEY is not None:
    os.environ["MISTRAL_API_KEY"] = MISTRAL_API_KEY
else:
    print("[WARN]: MISTRAL_API_KEY is not set.")

# Global variables for dynamic agent and DB file path; initially None.
agent_app = None       # compiled agent workflow, built lazily on first query
abs_file_path = None   # absolute path of the most recently uploaded DB file
db_path = None         # path the uploaded DB file was saved to

# (Removed a stray module-level `print(traceback.format_exc())`: with no
# exception being handled it only prints "NoneType: None".)
# =============================================================================
# create_agent_app: Given a database path, initialize the agent workflow.
# =============================================================================
def create_agent_app(db_path: str):
    """Build and compile the SQL question-answering workflow for one SQLite file.

    The compiled graph runs these nodes in order: a hard-coded
    ``sql_db_list_tables`` tool call, the list-tables tool, an LLM step bound
    to the schema tool, the schema tool, and then a loop of
    ``query_gen`` -> ``correct_query`` -> ``execute_query`` -> ``query_gen``
    that ends once ``query_gen`` emits a tool call (expected to be
    ``SubmitFinalAnswer``).

    Side effect: rebinds the module-level ``DATABASE_URI`` to the SQLite URI
    of ``db_path``.

    Args:
        db_path: Path (absolute or relative) to the SQLite database file.

    Returns:
        The compiled workflow returned by ``StateGraph.compile()``.

    Raises:
        FileNotFoundError: If ``db_path`` does not point to an existing file.
        RuntimeError: If the SQL toolkit does not expose the list-tables or
            schema tool the workflow depends on.
    """
    # Function-scope imports, as in the original: these packages are only
    # needed once an agent is actually built.
    from langchain_community.agent_toolkits import SQLDatabaseToolkit
    from langchain_community.utilities import SQLDatabase
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_groq import ChatGroq

    global DATABASE_URI

    # -------------------------------------------------------------------------
    # Resolve the database file and create the SQLDatabase connection.
    # -------------------------------------------------------------------------
    abs_db_path_local = os.path.abspath(db_path)
    if not os.path.isfile(abs_db_path_local):
        # Fail loudly here instead of handing a non-existent path to the
        # database layer.
        raise FileNotFoundError(f"SQLite database file not found: {abs_db_path_local}")
    db_uri = f"sqlite:///{abs_db_path_local}"
    # Keep the global a real URI (it is initialised as one at module level);
    # the original overwrote it with the bare filesystem path.
    DATABASE_URI = db_uri
    print("db_uri", db_uri)
    db_instance = SQLDatabase.from_uri(db_uri)
    print("db_instance----->", db_instance)

    # Use ChatGroq as our LLM here; you can swap to ChatMistralAI if preferred.
    llm = ChatGroq(model="llama3-70b-8192")

    # -------------------------------------------------------------------------
    # Define a tool for executing SQL queries.
    # -------------------------------------------------------------------------
    # NOTE(review): whatever SQL the model produces is executed as-is; the
    # "no DML" rule lives only in the prompt text below. Consider a read-only
    # connection if uploaded databases must never be modified.
    @tool
    def db_query_tool(query: str) -> str:
        """
        Executes a SQL query on the connected SQLite database.
        Parameters:
        query (str): A SQL query string to be executed.
        Returns:
        str: The result from the database if successful, or an error message if not.
        """
        result = db_instance.run_no_throw(query)
        return result if result else "Error: Query failed. Please rewrite your query and try again."

    # -------------------------------------------------------------------------
    # Pydantic model for final answer
    # -------------------------------------------------------------------------
    class SubmitFinalAnswer(BaseModel):
        final_answer: str = Field(..., description="The final answer to the user")

    # -------------------------------------------------------------------------
    # Define state type for our workflow.
    # -------------------------------------------------------------------------
    class State(TypedDict):
        messages: Annotated[list[AnyMessage], add_messages]

    # -------------------------------------------------------------------------
    # Set up prompt templates for query checking and query generation.
    # -------------------------------------------------------------------------
    query_check_system = (
        "You are a SQL expert with a strong attention to detail.\n"
        "Double check the SQLite query for common mistakes, including:\n"
        "- Using NOT IN with NULL values\n"
        "- Using UNION when UNION ALL should have been used\n"
        "- Using BETWEEN for exclusive ranges\n"
        "- Data type mismatch in predicates\n"
        "- Properly quoting identifiers\n"
        "- Using the correct number of arguments for functions\n"
        "- Casting to the correct data type\n"
        "- Using the proper columns for joins\n\n"
        "If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n"
        "You will call the appropriate tool to execute the query after running this check."
    )
    query_check_prompt = ChatPromptTemplate.from_messages([
        ("system", query_check_system),
        ("placeholder", "{messages}"),
    ])
    query_check = query_check_prompt | llm.bind_tools([db_query_tool])

    query_gen_system = (
        "You are a SQL expert with a strong attention to detail.\n\n"
        "Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.\n\n"
        "DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.\n\n"
        "When generating the query:\n"
        "Output the SQL query that answers the input question without a tool call.\n"
        "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.\n"
        "You can order the results by a relevant column to return the most interesting examples in the database.\n"
        "Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n\n"
        "If you get an error while executing a query, rewrite the query and try again.\n"
        "If you get an empty result set, you should try to rewrite the query to get a non-empty result set.\n"
        "NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.\n\n"
        "If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.\n"
        "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. Do not return any SQL query except answer."
    )
    query_gen_prompt = ChatPromptTemplate.from_messages([
        ("system", query_gen_system),
        ("placeholder", "{messages}"),
    ])
    query_gen = query_gen_prompt | llm.bind_tools([SubmitFinalAnswer])

    # -------------------------------------------------------------------------
    # Create SQL toolkit and pick out the two tools the graph needs.
    # -------------------------------------------------------------------------
    toolkit_instance = SQLDatabaseToolkit(db=db_instance, llm=llm)
    tools_by_name = {t.name: t for t in toolkit_instance.get_tools()}
    required_tool_names = ("sql_db_list_tables", "sql_db_schema")
    missing_tool_names = [n for n in required_tool_names if n not in tools_by_name]
    if missing_tool_names:
        # The original silently passed None into ToolNode in this case.
        raise RuntimeError(f"SQL toolkit is missing required tools: {missing_tool_names}")
    list_tables_tool = tools_by_name["sql_db_list_tables"]
    get_schema_tool = tools_by_name["sql_db_schema"]

    # -------------------------------------------------------------------------
    # Define workflow nodes and fallback functions.
    # -------------------------------------------------------------------------
    def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
        # Seed the conversation with a synthetic tool call so the first real
        # step is always "list the tables".
        return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}

    def handle_tool_error(state: State) -> dict:
        # Fallback for a failed tool node: answer every pending tool call with
        # an error ToolMessage so the model can retry.
        error = state.get("error")
        tool_calls = state["messages"][-1].tool_calls
        return {"messages": [
            ToolMessage(content=f"Error: {repr(error)}. Please fix your mistakes.", tool_call_id=tc["id"])
            for tc in tool_calls
        ]}

    def create_tool_node_with_fallback(tools_list: list) -> RunnableWithFallbacks[Any, dict]:
        return ToolNode(tools_list).with_fallbacks([RunnableLambda(handle_tool_error)], exception_key="error")

    def model_get_schema_node(state: State) -> dict:
        # Let the model decide which table schemas to fetch.
        return {"messages": [model_get_schema.invoke(state["messages"])]}

    def query_gen_node(state: State):
        message = query_gen.invoke(state)
        # Any tool call other than SubmitFinalAnswer is answered with an error
        # message; should_continue then routes back to query_gen.
        tool_messages = [
            ToolMessage(
                content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes.",
                tool_call_id=tc["id"],
            )
            for tc in (message.tool_calls or [])
            if tc["name"] != "SubmitFinalAnswer"
        ]
        return {"messages": [message] + tool_messages}

    def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
        last_message = state["messages"][-1]
        if getattr(last_message, "tool_calls", None):
            return END
        content = last_message.content
        # Message content is not guaranteed to be a plain string; only a
        # string can carry the "Error:" prefix this check looks for.
        if isinstance(content, str) and content.startswith("Error:"):
            return "query_gen"
        return "correct_query"

    def model_check_query(state: State) -> dict[str, list[AIMessage]]:
        # Only the latest message (the candidate query) is sent for checking.
        return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}

    # -------------------------------------------------------------------------
    # Assemble the graph.
    # -------------------------------------------------------------------------
    model_get_schema = llm.bind_tools([get_schema_tool])

    workflow = StateGraph(State)
    workflow.add_node("first_tool_call", first_tool_call)
    workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tables_tool]))
    workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))
    workflow.add_node("model_get_schema", model_get_schema_node)
    workflow.add_node("query_gen", query_gen_node)
    workflow.add_node("correct_query", model_check_query)
    workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))

    workflow.add_edge(START, "first_tool_call")
    workflow.add_edge("first_tool_call", "list_tables_tool")
    workflow.add_edge("list_tables_tool", "model_get_schema")
    workflow.add_edge("model_get_schema", "get_schema_tool")
    workflow.add_edge("get_schema_tool", "query_gen")
    workflow.add_conditional_edges("query_gen", should_continue)
    workflow.add_edge("correct_query", "execute_query")
    workflow.add_edge("execute_query", "query_gen")

    # Return compiled workflow
    return workflow.compile()
# =============================================================================
# create_app: The application factory.
# =============================================================================
def create_app():
    """Create the Flask application and its Socket.IO server.

    Registers four views: the index page, a file-serving route for uploads,
    a POST endpoint that runs the agent in the background, and an upload
    endpoint for the SQLite file. Progress and results are pushed to clients
    through the ``log`` and ``final`` Socket.IO events.

    Returns:
        tuple: ``(flask_app, socketio)``.
    """
    # NOTE(review): with static_folder='uploads' everything in the uploads
    # directory, including the uploaded database, is also downloadable under
    # /uploads/. Confirm that is intended.
    flask_app = Flask(__name__, static_url_path='/uploads', static_folder='uploads')
    # NOTE(review): cors_allowed_origins="*" accepts Socket.IO connections
    # from any origin.
    socketio = SocketIO(flask_app, cors_allowed_origins="*")

    # Ensure uploads folder exists.
    os.makedirs(UPLOAD_FOLDER, exist_ok=True)
    flask_app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

    # NOTE(review): the listing contained no route registrations, so none of
    # the views below were reachable. The URL rules added here are inferred
    # from the function names and comments; confirm them against the
    # templates and front-end code.

    # -------------------------------------------------------------------------
    # Serve uploaded files via a custom route.
    # -------------------------------------------------------------------------
    @flask_app.route("/files/<path:filename>")
    def uploaded_file(filename):
        return send_from_directory(flask_app.config['UPLOAD_FOLDER'], filename)

    # -------------------------------------------------------------------------
    # Helper: run_agent runs the agent with the given prompt.
    # -------------------------------------------------------------------------
    def run_agent(prompt, socketio):
        """Run the agent on ``prompt`` and emit the outcome over Socket.IO.

        Emits ``final`` with the answer text, or with a failure message when
        no database was uploaded, the agent raised, or the last message did
        not carry a ``final_answer`` tool call.
        """
        global agent_app, abs_file_path
        if not abs_file_path:
            socketio.emit("log", {"message": "[ERROR]: No DB file uploaded."})
            socketio.emit("final", {"message": "No database available. Please upload one and try again."})
            return
        try:
            # Lazy agent initialization: use the previously uploaded DB.
            if agent_app is None:
                print("[INFO]: Initializing agent for the first time...")
                agent_app = create_agent_app(abs_file_path)
                socketio.emit("log", {"message": "[INFO]: Agent initialized."})
            query = {"messages": [("user", prompt)]}
            result = agent_app.invoke(query)
            try:
                result = result["messages"][-1].tool_calls[0]["args"]["final_answer"]
            except (KeyError, IndexError, AttributeError, TypeError):
                # The run ended without a well-formed SubmitFinalAnswer call.
                result = "Query failed or no valid answer found."
            print("final_answer------>", result)
            socketio.emit("final", {"message": result})
        except Exception as e:
            # Top-level boundary of the background task: report, never raise.
            print(f"[ERROR]: {str(e)}")
            socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
            socketio.emit("final", {"message": "Generation failed."})

    # -------------------------------------------------------------------------
    # Route: index page.
    # -------------------------------------------------------------------------
    @flask_app.route("/")
    def index():
        return render_template("index.html")

    # -------------------------------------------------------------------------
    # Route: generate (POST) – receives a prompt and runs the agent.
    # -------------------------------------------------------------------------
    @flask_app.route("/generate", methods=["POST"])
    def generate():
        try:
            socketio.emit("log", {"message": "[STEP]: Entering query_gen..."})
            # silent=True: a missing or malformed JSON body yields None
            # instead of raising, and is treated as "no prompt".
            data = request.get_json(silent=True)
            if not isinstance(data, dict):
                data = {}
            prompt = data.get("prompt", "")
            if not prompt:
                socketio.emit("log", {"message": "[ERROR]: No prompt provided."})
                return "No prompt provided", 400
            socketio.emit("log", {"message": f"[INFO]: Received prompt: {prompt}"})
            # Use the Socket.IO helper rather than a raw threading.Thread so
            # the task matches whichever async mode the server runs in.
            thread = socketio.start_background_task(run_agent, prompt, socketio)
            socketio.emit("log", {"message": f"[INFO]: Starting thread: {thread}"})
            return "OK", 200
        except Exception as e:
            print(f"[ERROR]: {str(e)}")
            socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
            return "ERROR", 500

    # -------------------------------------------------------------------------
    # Route: upload (GET/POST) – handles uploading the SQLite DB file.
    # -------------------------------------------------------------------------
    @flask_app.route("/upload", methods=["GET", "POST"])
    def upload():
        global abs_file_path, agent_app, db_path
        try:
            if request.method != "POST":
                return render_template("upload.html")
            file = request.files.get("file")
            if not file:
                print("No file uploaded")
                return "No file uploaded", 400
            filename = secure_filename(file.filename)
            if not filename.lower().endswith('.db'):
                # Previously a non-.db upload was dropped without any signal.
                socketio.emit("log", {"message": f"[ERROR]: Unsupported file '{filename}'. Only .db files are accepted."})
                return render_template("upload.html"), 400
            # Every upload is stored under the same fixed name.
            db_path = os.path.join(flask_app.config['UPLOAD_FOLDER'], "uploaded.db")
            print("Saving file to:", db_path)
            file.save(db_path)
            abs_file_path = os.path.abspath(db_path)  # Save it here; agent init will occur on first query.
            # Drop any agent built for the previous upload so the next query
            # rebuilds it against the new file.
            agent_app = None
            print(f"[INFO]: File '{filename}' uploaded. Agent will be initialized on first query.")
            socketio.emit("log", {"message": f"[INFO]: Database file '{filename}' uploaded."})
            return redirect(url_for("index"))
        except Exception as e:
            print(f"[ERROR]: {str(e)}")
            socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
            return render_template("upload.html")

    return flask_app, socketio
# =============================================================================
# Create the app for Gunicorn compatibility.
# =============================================================================
# Built at import time so a WSGI server can import `app` from this module.
app, socketio_instance = create_app()

if __name__ == "__main__":
    # Debug mode is opt-in through the environment rather than hard-coded:
    # it enables the interactive debugger, which must not be reachable on a
    # publicly served instance.
    debug_enabled = os.getenv("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    socketio_instance.run(app, debug=debug_enabled)