from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
import logging
import sys
import json
import asyncio

from dotenv import load_dotenv
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam

from .config import DATASET_CONFIGS, load_prompt_template

load_dotenv()
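
# Log to stdout so hosted platforms that collect stdout capture these messages.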
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

app = FastAPI(title="RAG Pipeline API", description="Multi-dataset RAG API", version="1.0.0")
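
# Fail fast at import time if the OpenRouter key is missing, rather than on the first request.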
openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
if not openrouter_api_key:
    raise ValueError("OPENROUTER_API_KEY environment variable is not set")

openrouter_client = OpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=openrouter_api_key
)

MODEL_NAME = "z-ai/glm-4.5-air:free"
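
# Pipelines are filled in by the background loader; requests that arrive first get a retry message.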
pipelines = {}
google_api_key = os.getenv("GOOGLE_API_KEY")

logger.info("Starting RAG Pipeline API")
logger.info(f"Port from env: {os.getenv('PORT', 'Not set - will use 8000')}")
logger.info(f"Google API Key present: {'Yes' if google_api_key else 'No'}")
logger.info(f"Available datasets: {list(DATASET_CONFIGS.keys())}")
def rag_qa(question: str, dataset: str = "developer-portfolio") -> str:
    """
    Get answers from the RAG pipeline for specific questions about the dataset.

    Args:
        question: The question to answer using the RAG pipeline
        dataset: The dataset to search in (default: developer-portfolio)

    Returns:
        Answer from the RAG pipeline
    """
    try:
        if not pipelines:
            return "RAG Pipeline is running but datasets are still loading in the background. Please try again in a moment."

        if dataset not in pipelines:
            return f"Dataset '{dataset}' not available. Available datasets: {list(pipelines.keys())}"

        selected_pipeline = pipelines[dataset]
        answer = selected_pipeline.answer_question(question)
        return answer
    except Exception as e:
        return f"Error accessing RAG pipeline: {str(e)}"
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "rag_qa",
            "description": "Get answers from the RAG pipeline for specific questions about datasets",
            "parameters": {
                "type": "object",
                "properties": {
                    "question": {
                        "type": "string",
                        "description": "The question to answer using the RAG pipeline"
                    },
                    "dataset": {
                        "type": "string",
                        "description": "The dataset to search in (default: developer-portfolio)",
                        "default": "developer-portfolio"
                    }
                },
                "required": ["question"]
            }
        }
    }
]

logger.info("RAG Pipeline API is ready to serve requests - datasets will load in background")
class Question(BaseModel):
    text: str
    dataset: str = "developer-portfolio"


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: list[ChatMessage]
    dataset: str = "developer-portfolio"
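

# /chat runs a two-pass tool-calling loop: the first completion may request rag_qa calls,
# which are executed locally and fed back for a second completion that writes the final answer.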
@app.post("/chat")
async def chat_with_ai(request: ChatRequest):
    """
    Chat with the AI assistant. The AI will use the RAG pipeline when needed to answer questions about the datasets.
    """
    try:
        # Convert the Pydantic request models into OpenAI-style message dicts.
        messages: list[ChatCompletionMessageParam] = [
            {"role": msg.role, "content": msg.content}
            for msg in request.messages
        ]

        # Prepend the dataset-appropriate system prompt.
        if request.dataset == "developer-portfolio":
            system_prompt = load_prompt_template("system-instruction.txt")
        else:
            system_prompt = load_prompt_template("generic-system-instruction.txt")
        system_message: ChatCompletionMessageParam = {
            "role": "system",
            "content": system_prompt
        }
        messages.insert(0, system_message)

        # First pass: let the model decide whether to call the rag_qa tool.
        response = openrouter_client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            tools=TOOLS,
            tool_choice="auto"
        )

        message = response.choices[0].message
        finish_reason = response.choices[0].finish_reason

        if finish_reason == "tool_calls" and message.tool_calls:
            # Execute each requested rag_qa call locally.
            tool_results = []
            for tool_call in message.tool_calls:
                if tool_call.function.name == "rag_qa":
                    args = json.loads(tool_call.function.arguments)
                    question = args.get("question")
                    dataset = args.get("dataset", request.dataset)

                    result = rag_qa(question, dataset)
                    tool_results.append({
                        "tool_call_id": tool_call.id,
                        "result": result
                    })

            # Echo the assistant's tool-call turn back into the conversation.
            assistant_message: ChatCompletionMessageParam = {
                "role": "assistant",
                "content": message.content or "",
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments
                        }
                    }
                    for tc in message.tool_calls
                ]
            }
            messages.append(assistant_message)

            # Append one tool message per result, keyed by tool_call_id.
            for tool_result in tool_results:
                tool_message: ChatCompletionMessageParam = {
                    "role": "tool",
                    "tool_call_id": tool_result["tool_call_id"],
                    "content": tool_result["result"]
                }
                messages.append(tool_message)

            # Second pass: the model composes the final answer from the tool output.
            final_response = openrouter_client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages
            )

            return {
                "response": final_response.choices[0].message.content,
                "tool_calls": [
                    {
                        "name": tc.function.name,
                        "arguments": tc.function.arguments
                    }
                    for tc in message.tool_calls
                ]
            }
        else:
            return {
                "response": message.content,
                "tool_calls": None
            }

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
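

# Example request (hypothetical values, assuming a local server on the default port 8000):
#   curl -X POST http://localhost:8000/chat \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "What projects are in this portfolio?"}]}'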
@app.get("/datasets")
async def list_datasets():
    """List all available datasets"""
    return {"datasets": list(pipelines.keys())}


@app.get("/questions")
async def list_questions(dataset: str = "developer-portfolio"):
    """List all questions for a given dataset"""
    if dataset not in pipelines:
        raise HTTPException(status_code=400, detail=f"Dataset '{dataset}' not available. Available datasets: {list(pipelines.keys())}")

    selected_pipeline = pipelines[dataset]
    questions = [doc.meta['question'] for doc in selected_pipeline.documents if 'question' in doc.meta]
    return {"dataset": dataset, "questions": questions}
async def load_datasets_background():
    """Load datasets in background after server starts"""
    global pipelines

    from .pipeline import RAGPipeline

    dataset_name = "developer-portfolio"
    try:
        logger.info(f"Loading dataset: {dataset_name}")
        pipeline = RAGPipeline.from_preset(preset_name=dataset_name)
        pipelines[dataset_name] = pipeline
        logger.info(f"Successfully loaded {dataset_name}")
    except Exception as e:
        logger.error(f"Failed to load {dataset_name}: {e}")
    logger.info(f"Background loading complete - {len(pipelines)} datasets loaded")
@app.on_event("startup")
async def startup_event():
    logger.info("FastAPI application startup complete")
    logger.info(f"Server should be running on port: {os.getenv('PORT', '8000')}")

    # Schedule dataset loading on the running event loop without blocking startup.
    asyncio.create_task(load_datasets_background())


@app.on_event("shutdown")
async def shutdown_event():
    logger.info("FastAPI application shutting down")


@app.get("/")
async def root():
    """Root endpoint"""
    return {"status": "ok", "message": "RAG Pipeline API", "version": "1.0.0", "datasets": list(pipelines.keys())}


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    logger.info("Health check called")
    loading_status = "complete" if "developer-portfolio" in pipelines else "loading"
    return {
        "status": "healthy",
        "datasets_loaded": len(pipelines),
        "total_datasets": 1,
        "loading_status": loading_status,
        "port": os.getenv('PORT', '8000')
    }