| """ | |
| Shared helper functions used by both HuggingFace Space and Local environments. | |
| Contains: configuration, memory management, vectorstore operations, PDF helpers, and UI utilities. | |
| """ | |
| import os | |
| from typing import List, Optional | |
| from datetime import datetime | |
| from collections import deque | |
| import gradio as gr | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_core.documents import Document | |
| from langchain_core.tools import tool | |

# ============================================================================
# CONFIGURATION - All settings in one place
# ============================================================================
def setup():
    """
    Central configuration for the RAG Agent application.
    Modify these values to customize the application behavior.
    Returns a config dictionary with all settings.
    """
    return {
        # Model Configuration
        "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
| "ollama_model": "qwen2m:latest", # Local Ollama model | |
| "hf_model": "mistralai/Mistral-7B-Instruct-v0.2", # HuggingFace cloud model | |
| "ollama_base_url": "http://localhost:11434", | |
| # Text Splitting Configuration | |
| "chunk_size": 1000, | |
| "chunk_overlap": 200, | |
| # Search Configuration | |
| "search_k": 5, # Number of documents to retrieve | |
| "search_content_limit": 500, # Max chars to show per chunk | |
| # LLM Generation Configuration | |
| "max_tokens": 512, | |
| "temperature": 0.1, # Lower = more deterministic | |
| "temperature_fallback": 0.7, # For text_generation fallback | |
| # Memory Configuration | |
| "max_memory_turns": 50, # Max conversation turns to store | |
| "memory_context_limit": 500, # Max chars per memory entry | |
| # Server Configuration | |
| "server_port": 7860, | |
| "server_host": "0.0.0.0", | |
| # UI Configuration | |
| "chatbot_height": 600, | |
| "progress_bar_length": 20, | |
| "chat_progress_bar_length": 15, | |
| } | |
| # Initialize configuration | |
| CONFIG = setup() | |
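
# Example usage from an entrypoint (a minimal sketch; "demo" stands for a
# hypothetical gr.Blocks app defined elsewhere):
#   demo.launch(server_name=CONFIG["server_host"], server_port=CONFIG["server_port"])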

# ============================================================================
# ENVIRONMENT DETECTION
# ============================================================================
IS_HF_SPACE = os.getenv("SPACE_ID") is not None

# Check storage persistence status
HAS_PERSISTENT_STORAGE = IS_HF_SPACE and os.path.exists("/data")

# Directories - use persistent storage on HF Spaces if available
DATA_DIR = "/data" if HAS_PERSISTENT_STORAGE else "data"
EMBEDDINGS_DIR = os.path.join(DATA_DIR, "embeddings")
os.makedirs(DATA_DIR, exist_ok=True)  # ensure the data folder exists before listing PDFs

STORAGE_WARNING = "" if not IS_HF_SPACE else (
    "✅ Persistent storage enabled - files will survive restarts" if HAS_PERSISTENT_STORAGE else
    "⚠️ Temporary storage - uploaded files will be lost when Space restarts"
)

# Initialize embeddings (shared across environments)
embeddings = HuggingFaceEmbeddings(model_name=CONFIG["embedding_model"])

# Global vectorstore (will be set by build_vectorstore)
vs = None

# ============================================================================
# CONVERSATION MEMORY
# ============================================================================
conversation_memory: deque = deque(maxlen=CONFIG["max_memory_turns"])

def add_to_memory(role: str, content: str):
    """Add a message to conversation memory with timestamp."""
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    conversation_memory.append({
        "timestamp": timestamp,
        "role": role,
        "content": content
    })
    print(f"💾 Memory updated: {role} message added (total: {len(conversation_memory)} turns)")

def get_memory_context(last_n: int = 10) -> str:
    """Get the last N conversation turns as context."""
    if not conversation_memory:
        return "No previous conversation history."
    recent = list(conversation_memory)[-last_n:]
    context_parts = []
    for msg in recent:
        role_emoji = "👤" if msg["role"] == "user" else "🤖"
        context_parts.append(
            f"[{msg['timestamp']}] {role_emoji} {msg['role'].upper()}: "
            f"{msg['content'][:CONFIG['memory_context_limit']]}"
        )
    return "\n\n".join(context_parts)

def search_memory(query: str) -> str:
    """Search conversation memory for relevant past discussions."""
    if not conversation_memory:
        return "No conversation history to search."
    query_lower = query.lower()
    matches = []
    for msg in conversation_memory:
        content_lower = msg["content"].lower()
        # Simple keyword matching
        if any(word in content_lower for word in query_lower.split()):
            role_emoji = "👤" if msg["role"] == "user" else "🤖"
            matches.append(
                f"[{msg['timestamp']}] {role_emoji} {msg['role'].upper()}: "
                f"{msg['content'][:CONFIG['memory_context_limit'] - 200]}..."
            )
    if matches:
        return f"Found {len(matches)} relevant conversation(s):\n\n" + "\n\n---\n\n".join(matches[:5])
    else:
        return f"No conversations found matching '{query}'."

def clear_memory():
    """Clear all conversation memory."""
    conversation_memory.clear()
    print("🧹 Conversation memory cleared")

# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================
def get_timestamp() -> str:
    """Get current timestamp in HH:MM:SS format."""
    return datetime.now().strftime("%H:%M:%S")

def create_elapsed_timer(start_time: datetime):
    """Create an elapsed-time function for tracking duration."""
    def get_elapsed() -> str:
        elapsed = datetime.now() - start_time
        return f"⏱️ {elapsed.total_seconds():.1f}s"
    return get_elapsed

def format_progress_bar(elapsed_time: str, percentage: int, message: str, bar_length: int = 20) -> str:
    """Format progress with a visual progress bar using Unicode blocks."""
    filled_length = int(bar_length * percentage / 100)
    bar = '█' * filled_length + '░' * (bar_length - filled_length)
    return f"{elapsed_time} | [{percentage:3d}%] {bar} {message}"

# ============================================================================
# FLOATING PROGRESS BAR HTML/JS (for Gradio UI)
# ============================================================================
def floating_progress_bar_html():
    """Return HTML+JS for a floating, borderless, fit-content progress bar overlay."""
    return '''
    <div id="floating-progress" style="
        display: none;
        position: fixed;
        top: 20px; left: 50%; transform: translateX(-50%);
        background: #222; color: #fff; padding: 8px 0; border-radius: 8px; z-index: 9999;
        font-family: monospace; font-size: 1.2em; box-shadow: none; border: none;
        width: fit-content; min-width: 0; max-width: none;
    ">
    [....................................................................................................]
    </div>
    <script>
    function showProgressBar(barText) {
        var el = document.getElementById('floating-progress');
        el.innerText = barText;
        el.style.display = 'block';
    }
    function hideProgressBar() {
        document.getElementById('floating-progress').style.display = 'none';
    }
    // Example usage (remove or replace with Python/Gradio event):
    // showProgressBar('[|||||||||||||.............]');
    // setTimeout(hideProgressBar, 2000);
    </script>
    '''
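
# Example wiring (a minimal sketch; gr.HTML injects raw HTML/JS into a Blocks layout):
#   with gr.Blocks() as demo:
#       gr.HTML(floating_progress_bar_html())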

# ============================================================================
# PDF HELPERS
# ============================================================================
def get_pdf_list() -> List[str]:
    """Get list of PDF files in the data folder."""
    return [f for f in os.listdir(DATA_DIR) if f.endswith(".pdf")]

def get_pdf_list_ui() -> List[str]:
    """Get PDF list for UI dropdown (with error handling)."""
    try:
        return get_pdf_list()
    except Exception as e:
        print(f"Error getting PDF list: {e}")
        return []

def make_pdf_dropdown(value=None):
    """Create a PDF dropdown with the current file list."""
    return gr.Dropdown(choices=get_pdf_list_ui(), value=value)
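
# Example usage (a minimal sketch; returning a fresh gr.Dropdown from an event
# handler is how Gradio refreshes the choices after an upload):
#   with gr.Blocks() as demo:
#       pdf_dropdown = make_pdf_dropdown()
#       refresh_btn = gr.Button("Refresh")
#       refresh_btn.click(fn=make_pdf_dropdown, outputs=pdf_dropdown)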

# ============================================================================
# VECTORSTORE OPERATIONS
# ============================================================================
def build_vectorstore(force_rebuild: bool = False) -> Optional[FAISS]:
    """Build or load FAISS vectorstore from PDFs.

    Args:
        force_rebuild: If True, rebuild from scratch even if an existing vectorstore is found
    """
    global vs
    # Check if we should load an existing vectorstore
    if not force_rebuild and os.path.exists(os.path.join(EMBEDDINGS_DIR, "index.faiss")):
        try:
            print("📂 Loading existing vectorstore...")
            vectorstore = FAISS.load_local(EMBEDDINGS_DIR, embeddings, allow_dangerous_deserialization=True)
            print("✅ Vectorstore loaded successfully")
            vs = vectorstore
            return vectorstore
        except Exception as e:
            print(f"❌ Error loading vectorstore: {e}, rebuilding...")
    # Build a new vectorstore from PDFs
    pdf_files = get_pdf_list()
    if not pdf_files:
        print("No PDF files found to build embeddings")
        vs = None
        return None
    print(f"🔨 Building vectorstore from {len(pdf_files)} PDF(s): {pdf_files}")
    docs: List[Document] = []
    for filename in pdf_files:
        try:
            filepath = os.path.join(DATA_DIR, filename)
            print(f"📄 Loading {filename}...")
            loader = PyPDFLoader(filepath)
            file_docs = loader.load()
            docs.extend(file_docs)
            print(f"✅ Loaded {len(file_docs)} pages from {filename}")
        except Exception as e:
            print(f"❌ Error loading {filename}: {e}")
            continue
    if not docs:
        print("⚠️ No documents could be loaded")
        vs = None
        return None
    print(f"✂️ Splitting {len(docs)} pages into chunks...")
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CONFIG["chunk_size"],
        chunk_overlap=CONFIG["chunk_overlap"]
    )
    splits = splitter.split_documents(docs)
    print(f"🧩 Created {len(splits)} text chunks")
    print("🤖 Creating FAISS embeddings...")
    try:
        vectorstore = FAISS.from_documents(splits, embeddings)
        print(f"💾 Saving vectorstore to {EMBEDDINGS_DIR}...")
        vectorstore.save_local(EMBEDDINGS_DIR)
        vs = vectorstore
        print("✅ Vectorstore built and saved successfully")
        return vectorstore
    except Exception as e:
        print(f"❌ Failed to build vectorstore: {e}")
        import traceback
        traceback.print_exc()
        vs = None
        return None

def get_vectorstore():
    """Get the current vectorstore instance."""
    return vs

def set_vectorstore(vectorstore):
    """Set the vectorstore instance."""
    global vs
    vs = vectorstore
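
# Example usage (a minimal sketch; force_rebuild=True re-indexes after new uploads):
#   vectorstore = build_vectorstore(force_rebuild=True)
#   if vectorstore is not None:
#       print(vectorstore.similarity_search("attention mechanism", k=2))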

# ============================================================================
# RAG AGENT TOOLS (LangChain @tool decorator pattern)
# ============================================================================
@tool
def list_documents() -> str:
    """List all available PDF documents in the system. Use this tool when the user asks what documents are available, what files they have, or wants to see the document list."""
    pdfs = get_pdf_list()
    if pdfs:
        return f"📚 Available documents: {', '.join(pdfs)}"
    else:
        return "📚 No documents are currently uploaded."

@tool
def count_documents() -> str:
    """Count the total number of uploaded PDF documents. Use this tool when the user asks how many documents they have or wants a document count."""
    count = len(get_pdf_list())
    return f"📊 Total documents: {count}"

@tool
def search_documents(query: str) -> str:
    """Search document content using RAG (Retrieval Augmented Generation). Use this tool to find information within the uploaded PDF documents based on a search query."""
    global vs
    # Check if we have any PDF files first
    pdf_files = get_pdf_list()
    if not pdf_files:
        return "📚 No documents are currently uploaded. Please upload PDF files first."
    # Force reload vectorstore from disk if files exist
    print(f"🔍 Checking vectorstore for {len(pdf_files)} PDF files...")
    # Check if FAISS files exist on disk
    faiss_path = os.path.join(EMBEDDINGS_DIR, "index.faiss")
    pkl_path = os.path.join(EMBEDDINGS_DIR, "index.pkl")
    if os.path.exists(faiss_path) and os.path.exists(pkl_path):
        print("📂 Found vectorstore files, loading...")
        try:
            # Force reload from disk
            vs = FAISS.load_local(EMBEDDINGS_DIR, embeddings, allow_dangerous_deserialization=True)
            print("✅ Vectorstore loaded successfully from disk")
        except Exception as e:
            print(f"❌ Error loading vectorstore: {e}")
            vs = None
    else:
        print("📂 No vectorstore files found, attempting to build...")
        vs = build_vectorstore()
    if vs is None:
        return f"📚 Found {len(pdf_files)} document(s) but the search index could not be created. Please try re-uploading your files."
    try:
        search_query = query  # use the raw query as the search string
        print(f"🔍 Searching vectorstore for: {search_query}")
        # Use similarity_search_with_score so relevance can be reported
        docs_with_scores = vs.similarity_search_with_score(search_query, k=CONFIG["search_k"])
        if docs_with_scores:
            # Scores are L2 distances (lower is better); no threshold filtering is applied here
            context_parts = []
            for doc, score in docs_with_scores:
                # Get source file and page from metadata
                source = os.path.basename(doc.metadata.get('source', 'Unknown'))
                page = doc.metadata.get('page', '?')
                # Include score and source in debug output
                print(f"  📄 Score: {score:.3f} | Source: {source} pg{page} - {doc.page_content[:50]}...")
                # Show chunk content with source info
                context_parts.append(
                    f"[Source: {source}, Page: {page}, Relevance: {score:.2f}]\n"
                    f"{doc.page_content[:CONFIG['search_content_limit']]}"
                )
            context = "\n\n---\n\n".join(context_parts)
            print(f"✅ Found {len(docs_with_scores)} document chunks")
            return f"🔍 Search results for '{query}':\n\n{context}"
        else:
            print(f"⚠️ No relevant documents found for query: {query}")
            return f"📚 No relevant information found for '{query}' in your {len(pdf_files)} document(s). Try different keywords or check if your documents contain relevant content."
    except Exception as e:
        error_msg = f"🔍 Search error: {str(e)}. You have {len(pdf_files)} documents available."
        print(f"❌ Search error: {str(e)}")
        import traceback
        traceback.print_exc()
        return error_msg

@tool
def search_conversation_history(query: str) -> str:
    """Search through previous conversation history to find past discussions. Use this tool when the user asks about something they discussed before, wants to recall previous answers, or references past conversations."""
    return search_memory(query)

@tool
def get_recent_conversation(turns: int = 5) -> str:
    """Get the most recent conversation turns. Use this tool when the user asks what they were discussing, wants a summary of recent chat, or needs context from earlier in the conversation."""
    return get_memory_context(last_n=turns)

# List of all available tools
AGENT_TOOLS = [list_documents, count_documents, search_documents, search_conversation_history, get_recent_conversation]
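
# Example wiring (a minimal sketch, assuming a tool-calling chat model such as
# ChatOllama; the actual agent loop presumably lives in the environment-specific
# entrypoints):
#   from langchain_ollama import ChatOllama
#   llm = ChatOllama(model=CONFIG["ollama_model"], base_url=CONFIG["ollama_base_url"])
#   llm_with_tools = llm.bind_tools(AGENT_TOOLS)
#   ai_msg = llm_with_tools.invoke("How many documents are loaded?")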

# Sample question texts - Enhanced for agent capabilities
SAMPLE_Q1 = "How many documents are loaded? List their names and types."
SAMPLE_Q2 = "Summarize the key points of each document in 5 bullet points."
SAMPLE_Q3 = "What is the attention mechanism? List the main topics."
SAMPLE_Q4 = "How can I cook chicken breast with Philips air fryer recipes?"
SAMPLE_Q5 = "Summarize each document in max 10 bullet points."
SAMPLE_Q6 = "What did we discuss earlier?"
SAMPLE_Q7 = "Summarize it in 50 words."