rohitchandra committed on
Commit
726f816
·
1 Parent(s): dba87ef

fixed RAG state message issue

Browse files
Files changed (2) hide show
  1. agents/rag_agent.py +74 -24
  2. app.py +4 -14
agents/rag_agent.py CHANGED
@@ -64,6 +64,7 @@ class AgenticRAGState(MessagesState):
64
  is_sufficient: bool = False
65
  retry_count: int = 0 # Track number of retries to prevent infinite loops
66
  max_retries: int = 3 # Maximum number of query rewrites allowed
 
67
 
68
 
69
  class AgenticRAGChat(ChatInterface):
@@ -207,17 +208,38 @@ class AgenticRAGChat(ChatInterface):
207
  """Evaluate the documents retrieved from the retriever tool."""
208
  print("Evaluating documents...")
209
 
210
- # Get original user question and retrieved docs
211
- user_question = state["messages"][0].content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  retrieved_docs = state["messages"][-1].content
213
 
 
 
214
  chain = DOCUMENT_EVALUATOR_PROMPT | self.evaluator_llm
215
  evaluation = chain.invoke({
216
  "question": user_question,
217
  "retrieved_docs": retrieved_docs
218
  })
219
 
220
- print(f"Evaluation result: {evaluation}")
221
  return {
222
  "is_sufficient": evaluation.is_sufficient,
223
  "feedback": evaluation.feedback
@@ -227,9 +249,22 @@ class AgenticRAGChat(ChatInterface):
227
  """Synthesize the final answer from retrieved documents."""
228
  print("Synthesizing answer...")
229
 
230
- user_question = state["messages"][0].content
 
 
 
 
 
 
 
 
 
 
 
231
  retrieved_docs = state["messages"][-1].content
232
 
 
 
233
  chain = DOCUMENT_SYNTHESIZER_PROMPT | self.llm
234
  answer = chain.invoke({
235
  "question": user_question,
@@ -242,10 +277,25 @@ class AgenticRAGChat(ChatInterface):
242
  """Rewrite the query based on evaluation feedback."""
243
  print("Rewriting query...")
244
 
245
- user_question = state["messages"][0].content
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  retrieved_docs = state["messages"][-1].content
247
  feedback = state["feedback"]
248
 
 
 
249
  chain = QUERY_REWRITER_PROMPT | self.llm
250
  new_query = chain.invoke({
251
  "question": user_question,
@@ -253,8 +303,11 @@ class AgenticRAGChat(ChatInterface):
253
  "retrieved_docs": retrieved_docs
254
  })
255
 
256
- print(f"Rewritten query: {new_query.content}")
257
- return {"messages": [new_query]}
 
 
 
258
 
259
  def _create_graph(self) -> Any:
260
  """Create the agentic RAG graph."""
@@ -331,30 +384,29 @@ class AgenticRAGChat(ChatInterface):
331
  # Convert chat history to messages
332
  history_messages = self._convert_history_to_messages(chat_history)
333
 
 
 
 
 
334
  # Add the current message
335
- history_messages.append(HumanMessage(content=message))
 
336
 
337
  # Create initial state with full conversation history
 
338
  state = AgenticRAGState(
339
- messages=history_messages, # Include full history instead of just current message
340
  feedback="",
341
  is_sufficient=False,
342
  retry_count=0,
343
- max_retries=3
 
 
344
  )
345
 
346
- # state = AgenticRAGState(
347
- # messages=[HumanMessage(content=message)],
348
- # feedback="",
349
- # is_sufficient=False,
350
- # retry_count=0,
351
- # max_retries=3 # Limit to 3 retries to prevent infinite loops
352
- # )
353
-
354
  try:
355
- # Run the workflow
356
  # Run the workflow with increased recursion limit
357
- config = {"recursion_limit": 30} # Increased but reasonable limit
358
  result = self.graph.invoke(state, config=config)
359
 
360
  print("\n=== RAG QUERY COMPLETED ===")
@@ -371,10 +423,8 @@ class AgenticRAGChat(ChatInterface):
371
 
372
  except Exception as e:
373
  print(f"Error in RAG processing: {e}")
374
- # Provide a more helpful fallback response
375
  if "recursion" in str(e).lower():
376
  return ("I had difficulty finding the exact information you're looking for in the documents. "
377
- "Based on the available documents, I can see references to various offices and services, "
378
- "but I couldn't find specific details about Mission Support Services. "
379
- "You might want to try asking about a specific aspect or department.")
380
  return f"I encountered an error while searching for information: {str(e)}"
 
64
  is_sufficient: bool = False
65
  retry_count: int = 0 # Track number of retries to prevent infinite loops
66
  max_retries: int = 3 # Maximum number of query rewrites allowed
67
+ current_query_index: int = 0 # Track which message is the current query
68
 
69
 
70
  class AgenticRAGChat(ChatInterface):
 
208
  """Evaluate the documents retrieved from the retriever tool."""
209
  print("Evaluating documents...")
210
 
211
+ # Check if we've hit max retries
212
+ if state.get("retry_count", 0) >= state.get("max_retries", 3):
213
+ print(f"Max retries ({state.get('max_retries', 3)}) reached. Forcing synthesis with available documents.")
214
+ return {
215
+ "is_sufficient": True, # Force synthesis even if not perfect
216
+ "feedback": "Maximum retries reached. Using available documents."
217
+ }
218
+
219
+ # Get the CURRENT user question, not the first message in history
220
+ # Use the current_query_index to get the right message
221
+ current_query_index = state.get("current_query_index", 0)
222
+
223
+ # Find the current query message
224
+ user_messages = [msg for msg in state["messages"] if isinstance(msg, HumanMessage)]
225
+ if current_query_index < len(state["messages"]):
226
+ user_question = state["messages"][current_query_index].content
227
+ else:
228
+ # Fallback: get the last user message
229
+ user_question = user_messages[-1].content if user_messages else state["messages"][-1].content
230
+
231
+ # Get the retrieved documents (should be the last message)
232
  retrieved_docs = state["messages"][-1].content
233
 
234
+ print(f"Evaluating for query: '{user_question[:50]}...'") # Debug print
235
+
236
  chain = DOCUMENT_EVALUATOR_PROMPT | self.evaluator_llm
237
  evaluation = chain.invoke({
238
  "question": user_question,
239
  "retrieved_docs": retrieved_docs
240
  })
241
 
242
+ print(f"Evaluation result: {evaluation} (Retry {state.get('retry_count', 0)}/{state.get('max_retries', 3)})")
243
  return {
244
  "is_sufficient": evaluation.is_sufficient,
245
  "feedback": evaluation.feedback
 
249
  """Synthesize the final answer from retrieved documents."""
250
  print("Synthesizing answer...")
251
 
252
+ # Get the CURRENT user question using the index
253
+ current_query_index = state.get("current_query_index", 0)
254
+
255
+ # Find the current query message
256
+ user_messages = [msg for msg in state["messages"] if isinstance(msg, HumanMessage)]
257
+ if current_query_index < len(state["messages"]):
258
+ user_question = state["messages"][current_query_index].content
259
+ else:
260
+ # Fallback: get the last user message
261
+ user_question = user_messages[-1].content if user_messages else state["messages"][-1].content
262
+
263
+ # Get the retrieved documents
264
  retrieved_docs = state["messages"][-1].content
265
 
266
+ print(f"Synthesizing answer for: '{user_question[:50]}...'") # Debug print
267
+
268
  chain = DOCUMENT_SYNTHESIZER_PROMPT | self.llm
269
  answer = chain.invoke({
270
  "question": user_question,
 
277
  """Rewrite the query based on evaluation feedback."""
278
  print("Rewriting query...")
279
 
280
+ # Increment retry count
281
+ current_retry = state.get("retry_count", 0)
282
+
283
+ # Get the CURRENT user question using the index
284
+ current_query_index = state.get("current_query_index", 0)
285
+
286
+ # Find the current query message
287
+ user_messages = [msg for msg in state["messages"] if isinstance(msg, HumanMessage)]
288
+ if current_query_index < len(state["messages"]):
289
+ user_question = state["messages"][current_query_index].content
290
+ else:
291
+ # Fallback: get the last user message
292
+ user_question = user_messages[-1].content if user_messages else state["messages"][-1].content
293
+
294
  retrieved_docs = state["messages"][-1].content
295
  feedback = state["feedback"]
296
 
297
+ print(f"Rewriting query for: '{user_question[:50]}...'") # Debug print
298
+
299
  chain = QUERY_REWRITER_PROMPT | self.llm
300
  new_query = chain.invoke({
301
  "question": user_question,
 
303
  "retrieved_docs": retrieved_docs
304
  })
305
 
306
+ print(f"Rewritten query (Attempt {current_retry + 1}/{state.get('max_retries', 3)}): {new_query.content}")
307
+ return {
308
+ "messages": [new_query],
309
+ "retry_count": current_retry + 1 # Increment retry count
310
+ }
311
 
312
  def _create_graph(self) -> Any:
313
  """Create the agentic RAG graph."""
 
384
  # Convert chat history to messages
385
  history_messages = self._convert_history_to_messages(chat_history)
386
 
387
+ # Mark the position where the current query starts
388
+ # This is important for the evaluator to know which is the actual query
389
+ history_length = len(history_messages)
390
+
391
  # Add the current message
392
+ current_query_message = HumanMessage(content=message)
393
+ history_messages.append(current_query_message)
394
 
395
  # Create initial state with full conversation history
396
+ # Store the index of the current query for reference
397
  state = AgenticRAGState(
398
+ messages=history_messages,
399
  feedback="",
400
  is_sufficient=False,
401
  retry_count=0,
402
+ max_retries=3,
403
+ # Add this to track the current query index
404
+ current_query_index=history_length # This is the index of the current query
405
  )
406
 
 
 
 
 
 
 
 
 
407
  try:
 
408
  # Run the workflow with increased recursion limit
409
+ config = {"recursion_limit": 30}
410
  result = self.graph.invoke(state, config=config)
411
 
412
  print("\n=== RAG QUERY COMPLETED ===")
 
423
 
424
  except Exception as e:
425
  print(f"Error in RAG processing: {e}")
 
426
  if "recursion" in str(e).lower():
427
  return ("I had difficulty finding the exact information you're looking for in the documents. "
428
+ "Based on the available documents, I can see references to various topics, "
429
+ "but I couldn't find specific details. You might want to try asking about a specific aspect.")
 
430
  return f"I encountered an error while searching for information: {str(e)}"
app.py CHANGED
@@ -81,9 +81,9 @@ def create_demo():
81
  examples=[
82
  "What is 847 * 293?",
83
  "What's today's date?",
84
- "What's the weather in San Francisco?",
85
- "Explain quantum computing in simple terms",
86
- "Research the impact of AI on healthcare",
87
  ],
88
  theme=gr.themes.Soft(),
89
  analytics_enabled=False,
@@ -95,14 +95,4 @@ def create_demo():
95
  if __name__ == "__main__":
96
  # Create and launch the demo
97
  demo = create_demo()
98
-
99
- # Check if running in Hugging Face Spaces
100
- if os.environ.get("SPACE_ID"):
101
- # Hugging Face Spaces configuration
102
- demo.launch(
103
- server_name="0.0.0.0",
104
- server_port=int(os.environ.get("PORT", 7860))
105
- )
106
- else:
107
- # Local development - use simple defaults
108
- demo.launch()
 
81
  examples=[
82
  "What is 847 * 293?",
83
  "What's today's date?",
84
+ # "What's the weather in San Francisco?",
85
+ # "Explain quantum computing in simple terms",
86
+ # "Research the impact of AI on healthcare",
87
  ],
88
  theme=gr.themes.Soft(),
89
  analytics_enabled=False,
 
95
  if __name__ == "__main__":
96
  # Create and launch the demo
97
  demo = create_demo()
98
+ demo.launch()