Cardiosense-AG committed on
Commit
2fc86e8
·
verified ·
1 Parent(s): ec5fb8f

Update src/rag_index.py

Browse files
Files changed (1) hide show
  1. src/rag_index.py +24 -3
src/rag_index.py CHANGED
@@ -93,22 +93,42 @@ def _format_e5_query(text: str) -> str:
93
  return f"query: {text.strip()}"
94
 
95
  def search_index(query_text: str, *args, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]:
96
- """Flexible signature: supports legacy (query, top_k) or explicit embedder/bundle."""
 
 
 
 
 
97
  if len(args) >= 2:
98
  embedder = args[0]
99
  bundle = args[1]
 
 
 
 
100
  else:
101
  embedder, bundle = load_index_bundle()
102
 
103
- index: faiss.Index = bundle["index"]
104
- chunks: List[Dict[str, Any]] = bundle["chunks"]
 
 
 
 
 
 
105
 
 
 
 
 
106
  q = _format_e5_query(query_text)
107
  qv = np.asarray(embedder.encode([q], normalize_embeddings=True), dtype="float32")
108
 
109
  scores, idxs = index.search(qv, top_k)
110
  idxs, scores = idxs[0], scores[0]
111
 
 
112
  results: List[Dict[str, Any]] = []
113
  for rank, (i, s) in enumerate(zip(idxs, scores)):
114
  if i < 0 or i >= len(chunks):
@@ -123,6 +143,7 @@ def search_index(query_text: str, *args, top_k: int = 5, **kwargs) -> List[Dict[
123
  })
124
  return results
125
 
 
126
  # ------------------------------ stats helper ---------------------------------
127
 
128
  def index_stats(idx_dir: Path | None = None) -> Dict[str, Any]:
 
93
  return f"query: {text.strip()}"
94
 
95
  def search_index(query_text: str, *args, top_k: int = 5, **kwargs) -> List[Dict[str, Any]]:
96
+ """
97
+ Flexible signature:
98
+ - search_index(query, top_k=5)
99
+ - search_index(query, embedder, bundle, top_k=5)
100
+ """
101
+ # --- Determine how the function was called ---
102
  if len(args) >= 2:
103
  embedder = args[0]
104
  bundle = args[1]
105
+ elif len(args) == 1:
106
+ # Called as search_index(query, embedder, top_k=5)
107
+ embedder = args[0]
108
+ _, bundle = load_index_bundle()
109
  else:
110
  embedder, bundle = load_index_bundle()
111
 
112
+ # --- Ensure top_k is an integer ---
113
+ try:
114
+ top_k = int(top_k)
115
+ except Exception:
116
+ raise TypeError(f"top_k must be int, got {type(top_k)}")
117
+
118
+ index = bundle.get("index")
119
+ chunks: List[Dict[str, Any]] = bundle.get("chunks", [])
120
 
121
+ if index is None or not hasattr(index, "search"):
122
+ raise ValueError("Invalid FAISS index bundle.")
123
+
124
+ # --- Embed and search ---
125
  q = _format_e5_query(query_text)
126
  qv = np.asarray(embedder.encode([q], normalize_embeddings=True), dtype="float32")
127
 
128
  scores, idxs = index.search(qv, top_k)
129
  idxs, scores = idxs[0], scores[0]
130
 
131
+ # --- Collect results ---
132
  results: List[Dict[str, Any]] = []
133
  for rank, (i, s) in enumerate(zip(idxs, scores)):
134
  if i < 0 or i >= len(chunks):
 
143
  })
144
  return results
145
 
146
+
147
  # ------------------------------ stats helper ---------------------------------
148
 
149
  def index_stats(idx_dir: Path | None = None) -> Dict[str, Any]: