import uuid

import chromadb
import gradio as gr
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
# Initialize embedding model (bi-encoder used to index and query texts)
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Initialize persistent ChromaDB client and LangChain vector store wrapper
chroma_client = chromadb.PersistentClient(path="./chroma_db")
vectorstore = Chroma(
    client=chroma_client,
    collection_name="text_collection",
    embedding_function=embedding_model,
)
# Initialize cross-encoder reranker: rescores (query, passage) pairs for precision
reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
compressor = CrossEncoderReranker(model=reranker, top_n=5)

# Base retriever fetches a candidate pool; k and top_n are overridden per query
# inside search_similar_texts()
retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=retriever
)
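# Retrieval runs in two stages: the MiniLM bi-encoder does a fast approximate
# nearest-neighbour search over the Chroma collection, then the bge-reranker
# cross-encoder rescores each (query, candidate) pair jointly and keeps the best
# top_n. This trades a little latency for a noticeably better ordering of hits.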
def add_text_to_db(text):
    """
    Add a piece of text to the vector database.

    Args:
        text (str): The text to add.

    Returns:
        str: Confirmation message.
    """
    if not text or not text.strip():
        return "Error: Text cannot be empty."

    # Generate a unique ID so the same text can be stored more than once
    doc_id = str(uuid.uuid4())

    # Embed and store the text; keep a copy in metadata for display at search time
    vectorstore.add_texts(
        texts=[text],
        metadatas=[{"text": text}],
        ids=[doc_id],
    )
    return f"Text added successfully with ID: {doc_id}"
def search_similar_texts(query, k):
    """
    Search the vector database for the top k texts similar to the query and rerank them.

    Args:
        query (str): The search query.
        k (int): Number of results to return.

    Returns:
        str: Formatted search results with reranker scores.
    """
    if not query or not query.strip():
        return "Error: Query cannot be empty."
    try:
        k = int(k)
    except (TypeError, ValueError):
        return "Error: k must be a positive integer."
    if k < 1:
        return "Error: k must be a positive integer."

    # Retrieve a larger candidate pool, then rerank it down to the top k
    retriever.search_kwargs["k"] = max(k * 2, 10)  # candidate pool: 2k, at least 10
    compressor.top_n = k
    docs = compression_retriever.get_relevant_documents(query)
    if not docs:
        return "No results found."

    # CrossEncoderReranker returns documents without attaching scores to metadata,
    # so score the (query, text) pairs explicitly for display.
    texts = [doc.metadata.get("text", doc.page_content) for doc in docs[:k]]
    scores = reranker.score([(query, text) for text in texts])

    # Format results
    results = []
    for i, (text, score) in enumerate(zip(texts, scores)):
        results.append(f"Result {i+1}:\nText: {text}\nScore: {score:.4f}\n")
    return "\n".join(results) or "No results found."
# Gradio interface: one column to add texts, one to search them
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Search Pipeline")
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Add Text to Database")
            text_input = gr.Textbox(label="Enter text to add")
            add_button = gr.Button("Add Text")
            add_output = gr.Textbox(label="Result")
        with gr.Column():
            gr.Markdown("## Search Similar Texts")
            query_input = gr.Textbox(label="Enter search query")
            k_input = gr.Number(label="Number of results (k)", value=5, precision=0)
            search_button = gr.Button("Search")
            search_output = gr.Textbox(label="Search Results")

    # Button actions: wire the UI to the two pipeline functions
    add_button.click(
        fn=add_text_to_db,
        inputs=text_input,
        outputs=add_output,
    )
    search_button.click(
        fn=search_similar_texts,
        inputs=[query_input, k_input],
        outputs=search_output,
    )
# Launch Gradio app
if __name__ == "__main__":
    demo.launch()
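# To run locally (assuming the file is saved as app.py): `python app.py`, then open
# the URL Gradio prints (http://127.0.0.1:7860 by default). If deployed as a Gradio
# Space, the demo is launched automatically from app.py.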