from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import asyncio
import os
import logging
import sys
from dotenv import load_dotenv
from .config import DATASET_CONFIGS, load_prompt_template
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam
import json

# Load environment variables
load_dotenv()

# Lazy imports to avoid blocking startup
# from .pipeline import RAGPipeline  # Will import when needed
# import umap  # Will import when needed for visualization
# import plotly.express as px  # Will import when needed for visualization
# import plotly.graph_objects as go  # Will import when needed for visualization
# from plotly.subplots import make_subplots  # Will import when needed for visualization
# import numpy as np  # Will import when needed for visualization
# from sklearn.preprocessing import normalize  # Will import when needed for visualization
# import pandas as pd  # Will import when needed for visualization

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

app = FastAPI(title="RAG Pipeline API", description="Multi-dataset RAG API", version="1.0.0")

# Initialize OpenRouter client
openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
if not openrouter_api_key:
    raise ValueError("OPENROUTER_API_KEY environment variable is not set")

openrouter_client = OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=openrouter_api_key
)

# Model configuration
MODEL_NAME = "z-ai/glm-4.5-air:free"

# Initialize pipelines for all datasets
pipelines = {}
google_api_key = os.getenv("GOOGLE_API_KEY")

logger.info("Starting RAG Pipeline API")
logger.info(f"Port from env: {os.getenv('PORT', 'Not set - will use 8000')}")
logger.info(f"Google API Key present: {'Yes' if google_api_key else 'No'}")
logger.info(f"Available datasets: {list(DATASET_CONFIGS.keys())}")


# Define tools for the GLM model
def rag_qa(question: str, dataset: str = "developer-portfolio") -> str:
    """
    Get answers from the RAG pipeline for specific questions about the dataset.

    Args:
        question: The question to answer using the RAG pipeline
        dataset: The dataset to search in (default: developer-portfolio)

    Returns:
        Answer from the RAG pipeline
    """
    try:
        # Check if pipelines are loaded
        if not pipelines:
            return ("RAG Pipeline is running but datasets are still loading in the "
                    "background. Please try again in a moment.")

        # Select the appropriate pipeline based on dataset
        if dataset not in pipelines:
            return f"Dataset '{dataset}' not available. Available datasets: {list(pipelines.keys())}"

        selected_pipeline = pipelines[dataset]
        answer = selected_pipeline.answer_question(question)
        return answer
    except Exception as e:
        return f"Error accessing RAG pipeline: {str(e)}"
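
# Illustrative usage of rag_qa (a sketch, not executed at import time): once
# the background loader has populated `pipelines`, a call such as
#
#     rag_qa("What projects are in the portfolio?")
#
# returns the selected pipeline's answer string. An unknown dataset name
# returns the "not available" message above rather than raising, so the model
# always receives a plain string back from the tool.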
# Tool definitions for GLM
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "rag_qa",
            "description": "Get answers from the RAG pipeline for specific questions about datasets",
            "parameters": {
                "type": "object",
                "properties": {
                    "question": {
                        "type": "string",
                        "description": "The question to answer using the RAG pipeline"
                    },
                    "dataset": {
                        "type": "string",
                        "description": "The dataset to search in (default: developer-portfolio)",
                        "default": "developer-portfolio"
                    }
                },
                "required": ["question"]
            }
        }
    }
]

# Don't load datasets during startup - do it asynchronously after server starts
logger.info("RAG Pipeline API is ready to serve requests - datasets will load in background")

# Visualization function disabled to speed up startup
# def create_3d_visualization(pipeline):
#     ... (commented out for faster startup)


class Question(BaseModel):
    text: str
    dataset: str = "developer-portfolio"  # Default dataset


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: list[ChatMessage]
    dataset: str = "developer-portfolio"  # Default dataset


@app.post("/chat")
async def chat_with_ai(request: ChatRequest):
    """
    Chat with the AI assistant. The AI will use the RAG pipeline when needed
    to answer questions about the datasets.
    """
    try:
        # Convert messages to OpenAI format with proper typing
        messages: list[ChatCompletionMessageParam] = [
            {"role": msg.role, "content": msg.content}  # type: ignore
            for msg in request.messages
        ]

        # Add a system message to guide the AI; the portfolio dataset gets its
        # own prompt, all other datasets share the generic one
        prompt_file = (
            "system-instruction.txt"
            if request.dataset == "developer-portfolio"
            else "generic-system-instruction.txt"
        )
        system_message: ChatCompletionMessageParam = {
            "role": "system",
            "content": load_prompt_template(prompt_file)
        }
        messages.insert(0, system_message)

        # Make the API call with tools
        response = openrouter_client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            tools=TOOLS,  # type: ignore
            tool_choice="auto"
        )

        message = response.choices[0].message
        finish_reason = response.choices[0].finish_reason

        # Handle tool calls
        if finish_reason == "tool_calls" and message.tool_calls:
            tool_results = []

            # Execute tool calls
            for tool_call in message.tool_calls:
                if tool_call.function.name == "rag_qa":
                    # Parse arguments
                    args = json.loads(tool_call.function.arguments)
                    question = args.get("question", "")
                    dataset = args.get("dataset", request.dataset)

                    # Call the rag_qa function
                    result = rag_qa(question, dataset)
                    tool_results.append({
                        "tool_call_id": tool_call.id,
                        "result": result
                    })

            # Add tool results to conversation and get final response
            assistant_message: ChatCompletionMessageParam = {
                "role": "assistant",
                "content": message.content or "",
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments
                        }
                    }
                    for tc in message.tool_calls
                ]
            }
            messages.append(assistant_message)

            for tool_result in tool_results:
                tool_message: ChatCompletionMessageParam = {
                    "role": "tool",
                    "tool_call_id": tool_result["tool_call_id"],
                    "content": tool_result["result"]
                }
                messages.append(tool_message)

            # Get final response
            final_response = openrouter_client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages
            )

            return {
                "response": final_response.choices[0].message.content,
                "tool_calls": [
                    {
                        "name": tc.function.name,
                        "arguments": tc.function.arguments
                    }
                    for tc in message.tool_calls
                ]
            }
        else:
            # Direct response without tool calls
            return {
                "response": message.content,
                "tool_calls": None
            }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
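
# Illustrative request/response shapes for POST /chat (field names follow the
# ChatRequest model above; the example content is hypothetical):
#
#   Request:  {"messages": [{"role": "user", "content": "What stack is used?"}],
#              "dataset": "developer-portfolio"}
#   Response: {"response": "...", "tool_calls": [{"name": "rag_qa", ...}]}
#             when the model invoked the tool, or
#             {"response": "...", "tool_calls": null} for a direct answer.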
# /answer endpoint removed - use /chat for all interactions


@app.get("/datasets")
async def list_datasets():
    """List all available datasets"""
    return {"datasets": list(pipelines.keys())}


@app.get("/questions")
async def list_questions(dataset: str = "developer-portfolio"):
    """List all questions for a given dataset"""
    if dataset not in pipelines:
        raise HTTPException(
            status_code=400,
            detail=f"Dataset '{dataset}' not available. Available datasets: {list(pipelines.keys())}"
        )

    selected_pipeline = pipelines[dataset]
    questions = [
        doc.meta['question']
        for doc in selected_pipeline.documents
        if 'question' in doc.meta
    ]
    return {"dataset": dataset, "questions": questions}


async def load_datasets_background():
    """Load datasets in background after server starts"""
    global pipelines

    # Import RAGPipeline only when needed
    from .pipeline import RAGPipeline

    # Only load developer-portfolio to save memory
    dataset_name = "developer-portfolio"
    try:
        logger.info(f"Loading dataset: {dataset_name}")
        pipeline = RAGPipeline.from_preset(preset_name=dataset_name)
        pipelines[dataset_name] = pipeline
        logger.info(f"Successfully loaded {dataset_name}")
    except Exception as e:
        logger.error(f"Failed to load {dataset_name}: {e}")

    logger.info(f"Background loading complete - {len(pipelines)} datasets loaded")


@app.on_event("startup")
async def startup_event():
    logger.info("FastAPI application startup complete")
    logger.info(f"Server should be running on port: {os.getenv('PORT', '8000')}")
    # Start loading datasets in background (non-blocking)
    asyncio.create_task(load_datasets_background())


@app.on_event("shutdown")
async def shutdown_event():
    logger.info("FastAPI application shutting down")


@app.get("/")
async def root():
    """Root endpoint"""
    return {
        "status": "ok",
        "message": "RAG Pipeline API",
        "version": "1.0.0",
        "datasets": list(pipelines.keys())
    }


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    logger.info("Health check called")
    loading_status = "complete" if "developer-portfolio" in pipelines else "loading"
    return {
        "status": "healthy",
        "datasets_loaded": len(pipelines),
        "total_datasets": 1,  # Only loading developer-portfolio
        "loading_status": loading_status,
        "port": os.getenv('PORT', '8000')
    }
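
# Local development entrypoint (a sketch - the module path "app.main" is an
# assumption; adjust it to wherever this file actually lives in its package):
#
#   uvicorn app.main:app --host 0.0.0.0 --port ${PORT:-8000}
#
# Launching via uvicorn with a module path keeps the relative imports above
# (.config, .pipeline) working, since the file is imported as part of a package
# rather than run as a standalone script.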