Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,93 +2,213 @@ import gradio as gr
|
|
| 2 |
from sentence_transformers import SentenceTransformer
|
| 3 |
import torch
|
| 4 |
|
| 5 |
-
#
|
| 6 |
-
|
| 7 |
|
| 8 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
if not query.strip():
|
| 10 |
-
return "Please enter a query."
|
| 11 |
if not documents.strip():
|
| 12 |
-
return "Please enter documents (one per line)."
|
| 13 |
-
|
| 14 |
-
# Split documents by lines
|
| 15 |
-
doc_list = [doc.strip() for doc in documents.split('\n') if doc.strip()]
|
| 16 |
|
|
|
|
|
|
|
| 17 |
if not doc_list:
|
| 18 |
-
return "Please enter at least one document."
|
| 19 |
-
|
| 20 |
-
# Encode query and documents
|
| 21 |
-
query_embeddings = model.encode_query(query)
|
| 22 |
-
document_embeddings = model.encode_document(doc_list)
|
| 23 |
-
|
| 24 |
-
# Compute similarities
|
| 25 |
-
similarities = model.similarity(query_embeddings, document_embeddings)
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
|
|
|
| 29 |
|
| 30 |
-
|
| 31 |
results = []
|
| 32 |
-
for i, idx in enumerate(
|
| 33 |
score = similarities[0][idx].item()
|
| 34 |
-
|
| 35 |
-
results.append(f"{i+1}. Score: {score:.4f}\n Document: {doc}")
|
| 36 |
-
|
| 37 |
return "\n\n".join(results)
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
gr.Examples(
|
| 68 |
examples=[
|
| 69 |
[
|
| 70 |
"Which planet is known as the Red Planet?",
|
| 71 |
-
"Venus is
|
| 72 |
-
|
| 73 |
-
[
|
| 74 |
-
"What causes seasons on Earth?",
|
| 75 |
-
"The tilt of Earth's axis causes different parts of the planet to receive varying amounts of sunlight throughout the year.\nThe moon's gravitational pull affects ocean tides but not seasons.\nEarth's orbit around the sun is slightly elliptical, but this has minimal effect on seasons.\nThe rotation of Earth on its axis causes day and night cycles."
|
| 76 |
]
|
| 77 |
],
|
| 78 |
-
inputs=[query_input,
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
search_btn.click(
|
| 82 |
-
fn=find_similar_documents,
|
| 83 |
-
inputs=[query_input, documents_input],
|
| 84 |
-
outputs=output
|
| 85 |
-
)
|
| 86 |
-
|
| 87 |
-
# Allow Enter key to trigger search
|
| 88 |
-
query_input.submit(
|
| 89 |
-
fn=find_similar_documents,
|
| 90 |
-
inputs=[query_input, documents_input],
|
| 91 |
-
outputs=output
|
| 92 |
)
|
| 93 |
|
| 94 |
-
demo.launch()
|
|
|
|
| 2 |
from sentence_transformers import SentenceTransformer
|
| 3 |
import torch
|
| 4 |
|
| 5 |
+
# Module-level cache so each model is loaded from disk/hub at most once.
loaded_models = {}

def load_model(model_name):
    """Return the SentenceTransformer for *model_name*, loading and caching it on first use."""
    if model_name not in loaded_models:
        loaded_models[model_name] = SentenceTransformer(model_name)
    return loaded_models[model_name]
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def find_similar_documents(query, documents, model_name):
    """Rank *documents* (one per line) by similarity to *query* using *model_name*.

    Returns a Markdown string with one entry per document, ordered from most
    to least similar, or a "⚠️ ..." warning string when the input is invalid.
    """
    if not query.strip():
        return "⚠️ Please enter a query."
    if not documents.strip():
        return "⚠️ Please enter documents (one per line)."

    doc_list = [d.strip() for d in documents.split("\n") if d.strip()]
    if not doc_list:
        return "⚠️ Please enter at least one valid document."

    # Load the model only after validation succeeds: loading is expensive
    # (possible hub download), so bad input should not trigger it.
    model = load_model(model_name)

    query_emb = model.encode_query(query)
    doc_emb = model.encode_document(doc_list)
    similarities = model.similarity(query_emb, doc_emb)

    # Indices of documents sorted by descending similarity; convert to plain
    # Python ints up front rather than indexing the list with 0-dim tensors.
    sorted_idx = torch.argsort(similarities[0], descending=True).tolist()
    results = []
    for rank, idx in enumerate(sorted_idx, start=1):
        score = similarities[0][idx].item()
        results.append(f"**{rank}. (Score: {score:.4f})**\n{doc_list[idx]}")
    return "\n\n".join(results)
|
| 39 |
|
| 40 |
+
|
| 41 |
+
def compare_models(query, documents, tarka_model, open_model):
    """Rank the same documents with two models and return both result strings.

    Returns a pair of Markdown strings (Tarka results, open-source results).
    On invalid input returns a "⚠️ ..." warning in the first slot and "" in
    the second, matching the two-output Gradio wiring.
    """
    if not query.strip():
        return "⚠️ Please enter a query.", ""
    if not documents.strip():
        return "⚠️ Please enter documents (one per line).", ""

    doc_list = [d.strip() for d in documents.split("\n") if d.strip()]
    if not doc_list:
        return "⚠️ Please enter at least one valid document.", ""

    # Validate BEFORE loading: model loads are expensive (possible hub
    # downloads) and must not be triggered by empty/blank document input.
    tarka = load_model(tarka_model)
    openm = load_model(open_model)

    def rank(model):
        # One model's ranked results as Markdown entries, best match first.
        q = model.encode_query(query)
        d = model.encode_document(doc_list)
        sim = model.similarity(q, d)
        order = torch.argsort(sim[0], descending=True).tolist()
        return [
            f"**{i+1}. (Score: {sim[0][idx]:.4f})**\n{doc_list[idx]}"
            for i, idx in enumerate(order)
        ]

    return "\n\n".join(rank(tarka)), "\n\n".join(rank(openm))
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# --- UI Layout ---
with gr.Blocks(
    title="Document Similarity Explorer",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="indigo", neutral_hue="zinc")
) as demo:

    gr.Markdown("# 🔍 Document Similarity Explorer")
    gr.Markdown("Compare document relevance across embedding models easily.")

    with gr.Tabs():
        # ----------------- SINGLE MODEL TAB -----------------
        with gr.Tab("Single Model Search"):
            with gr.Row():
                with gr.Column(scale=1):
                    model_selector = gr.Dropdown(
                        label="Select Embedding Model",
                        choices=[
                            "Tarka-AIR/Tarka-Embedding-150M-V1",
                            "sentence-transformers/all-MiniLM-L6-v2",
                            "intfloat/e5-base-v2",
                            "BAAI/bge-small-en-v1.5"
                        ],
                        value="Tarka-AIR/Tarka-Embedding-150M-V1"
                    )
                    # Status line shown once a model has been (pre)loaded.
                    loading_msg = gr.Markdown(visible=False)

                    query_input = gr.Textbox(
                        label="Query",
                        placeholder="Enter your search query...",
                        lines=2
                    )

                    docs_input = gr.Textbox(
                        label="Documents",
                        placeholder="Enter one document per line...",
                        lines=10
                    )

                    search_btn = gr.Button("🔎 Search", variant="primary")

                with gr.Column(scale=1):
                    result_box = gr.Markdown(label="Results", elem_id="results-box")

            def on_model_change(model_name):
                # Pre-load the selected model so the first search is fast.
                # NOTE: the previous in-handler `loading_msg.update(...)` call
                # was a no-op — Gradio components are updated via handler
                # RETURN values, not instance method calls — so dropping it
                # loses no behavior.
                load_model(model_name)
                return gr.update(value=f"✅ {model_name} loaded and ready!", visible=True)

            model_selector.change(fn=on_model_change, inputs=[model_selector], outputs=[loading_msg])

            search_btn.click(fn=find_similar_documents,
                             inputs=[query_input, docs_input, model_selector],
                             outputs=result_box)

            # Allow Enter key in the query box to trigger the same search.
            query_input.submit(fn=find_similar_documents,
                               inputs=[query_input, docs_input, model_selector],
                               outputs=result_box)

        # ----------------- COMPARISON TAB -----------------
        with gr.Tab("Compare Models"):
            with gr.Row():
                with gr.Column(scale=1):
                    tarka_selector = gr.Dropdown(
                        label="Tarka Model",
                        choices=[
                            "Tarka-AIR/Tarka-Embedding-150M-V1",
                            "Tarka-AIR/Tarka-Embedding-200M-V1",
                            "Tarka-AIR/Tarka-Embedding-300M-V1"
                        ],
                        value="Tarka-AIR/Tarka-Embedding-150M-V1"
                    )

                    open_selector = gr.Dropdown(
                        label="Open Source Model",
                        choices=[
                            "sentence-transformers/all-MiniLM-L6-v2",
                            "intfloat/e5-base-v2",
                            "BAAI/bge-small-en-v1.5",
                            "thenlper/gte-base"
                        ],
                        value="sentence-transformers/all-MiniLM-L6-v2"
                    )

                    compare_loading = gr.Markdown(visible=False)

                    query_compare = gr.Textbox(
                        label="Query",
                        placeholder="Enter query to compare...",
                        lines=2
                    )
                    docs_compare = gr.Textbox(
                        label="Documents",
                        placeholder="Enter documents (one per line)...",
                        lines=10
                    )
                    compare_btn = gr.Button("⚖️ Compare Models", variant="primary")

                with gr.Column(scale=2):
                    with gr.Row():
                        tarka_output = gr.Markdown(label="Tarka Model Results")
                        open_output = gr.Markdown(label="Open Source Model Results")

            def on_compare_models_load(tarka_model, open_model):
                # Pre-load both models; status is reported via the return value
                # (in-handler `compare_loading.update(...)` had no effect).
                load_model(tarka_model)
                load_model(open_model)
                return gr.update(value="✅ Models ready for comparison!", visible=True)

            tarka_selector.change(fn=on_compare_models_load,
                                  inputs=[tarka_selector, open_selector],
                                  outputs=[compare_loading])
            open_selector.change(fn=on_compare_models_load,
                                 inputs=[tarka_selector, open_selector],
                                 outputs=[compare_loading])

            compare_btn.click(fn=compare_models,
                              inputs=[query_compare, docs_compare, tarka_selector, open_selector],
                              outputs=[tarka_output, open_output])

            query_compare.submit(fn=compare_models,
                                 inputs=[query_compare, docs_compare, tarka_selector, open_selector],
                                 outputs=[tarka_output, open_output])

    # Example block (fills the single-model tab's inputs).
    gr.Examples(
        examples=[
            [
                "Which planet is known as the Red Planet?",
                "Venus is Earth's twin.\nMars, the Red Planet.\nJupiter is the biggest.\nSaturn has rings.",
                "Tarka-AIR/Tarka-Embedding-150M-V1"
            ]
        ],
        inputs=[query_input, docs_input, model_selector],
        label="Try Example"
    )

demo.launch()
|