Spaces:
Sleeping
Sleeping
| import uuid | |
| import chromadb | |
| import torch | |
| from langchain.vectorstores import Chroma | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.cross_encoders import HuggingFaceCrossEncoder | |
| import gradio as gr | |
# Choose the compute device once and share it between both models:
# CUDA when torch reports a usable GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Sentence-transformer model that turns stored texts and queries into vectors.
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs=dict(device=device),
)

# Persistent Chroma client rooted at ./chroma_db; the LangChain wrapper embeds
# through embedding_model and uses the "text_collection" collection.
chroma_client = chromadb.PersistentClient(path="./chroma_db")
vectorstore = Chroma(
    embedding_function=embedding_model,
    collection_name="text_collection",
    client=chroma_client,
)

# Cross-encoder that search_similar_texts uses to re-score retrieved candidates.
reranker = HuggingFaceCrossEncoder(
    model_kwargs=dict(device=device),
    model_name="BAAI/bge-reranker-base",
)
def add_text_to_db(text):
    """
    Store a single piece of text in the vector database.

    The text is passed to ``vectorstore.add_texts`` under a freshly generated
    UUID4 identifier, and a copy is also kept in the ``"text"`` metadata field.

    Args:
        text (str): The text to store. None, empty and whitespace-only
            values are rejected.

    Returns:
        str: A success message containing the new document ID, or an error
        message when the input is blank.
    """
    is_blank = not text or not text.strip()
    if is_blank:
        return "Error: Text cannot be empty."

    new_id = f"{uuid.uuid4()}"
    vectorstore.add_texts(texts=[text], metadatas=[{"text": text}], ids=[new_id])
    return "Text added successfully with ID: " + new_id
| def search_similar_texts(query, k, threshold): | |
| """ | |
| Search for the top k similar texts in the vector database and rerank them. | |
| Only return results with similarity scores above the threshold. | |
| Args: | |
| query (str): The search query. | |
| k (int): Number of results to return. | |
| threshold (float): Minimum similarity score (0 to 1). | |
| Returns: | |
| str: Formatted search results with similarity scores or "No such record". | |
| """ | |
| if not query or not query.strip(): | |
| return "Error: Query cannot be empty." | |
| if not isinstance(k, int) or k < 1: | |
| return "Error: k must be a positive integer." | |
| if not isinstance(threshold, (int, float)) or threshold < 0 or threshold > 1: | |
| return "Error: Threshold must be a number between 0 and 1." | |
| # Retrieve initial documents | |
| retriever = vectorstore.as_retriever(search_kwargs={"k": max(k * 2, 10)}) | |
| docs = retriever.get_relevant_documents(query) | |
| if not docs: | |
| return "No such record." | |
| # Compute reranker scores | |
| scored_docs = [] | |
| for doc in docs: | |
| text = doc.metadata.get("text", "No text available") | |
| # Compute score using reranker | |
| score = reranker.score([query, text]) | |
| doc.metadata["score"] = float(score) | |
| scored_docs.append((doc, score)) | |
| # Sort by score in descending order | |
| scored_docs.sort(key=lambda x: x[1], reverse=True) | |
| # Filter by threshold and limit to k | |
| results = [] | |
| for i, (doc, score) in enumerate(scored_docs[:k]): | |
| if score >= threshold: | |
| text = doc.metadata.get("text", "No text available") | |
| results.append(f"Result {i+1}:\nText: {text}\nScore: {score:.4f}\n") | |
| if not results: | |
| return "No such record." | |
| return "\n".join(results) | |
# Gradio user interface: the left column stores new text, and the right column
# runs a reranked semantic search over the stored texts.
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Search Pipeline")

    with gr.Row():
        # Ingestion panel.
        with gr.Column():
            gr.Markdown("## Add Text to Database")
            text_input = gr.Textbox(label="Enter text to add")
            add_button = gr.Button("Add Text")
            add_output = gr.Textbox(label="Result")

        # Search panel: query, number of results and minimum score.
        with gr.Column():
            gr.Markdown("## Search Similar Texts")
            query_input = gr.Textbox(label="Enter search query")
            k_input = gr.Number(
                label="Number of results (k)",
                value=5,
                precision=0,
            )
            threshold_input = gr.Number(
                label="Similarity threshold (0 to 1)",
                value=0.5,
                minimum=0,
                maximum=1,
            )
            search_button = gr.Button("Search")
            search_output = gr.Textbox(label="Search Results")

    # Event wiring: each button passes its input components, in list order,
    # as positional arguments to its handler.
    add_button.click(fn=add_text_to_db, inputs=[text_input], outputs=add_output)
    search_button.click(
        fn=search_similar_texts,
        inputs=[query_input, k_input, threshold_input],
        outputs=search_output,
    )

# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()