Commit 726f816 · Parent(s): dba87ef

fixed RAG state message issue

Files changed:
- agents/rag_agent.py (+74 -24)
- app.py (+4 -14)
agents/rag_agent.py  CHANGED

@@ -64,6 +64,7 @@ class AgenticRAGState(MessagesState):
     is_sufficient: bool = False
     retry_count: int = 0  # Track number of retries to prevent infinite loops
     max_retries: int = 3  # Maximum number of query rewrites allowed
+    current_query_index: int = 0  # Track which message is the current query
 
 
 class AgenticRAGChat(ChatInterface):
@@ -207,17 +208,38 @@ class AgenticRAGChat(ChatInterface):
         """Evaluate the documents retrieved from the retriever tool."""
         print("Evaluating documents...")
 
-        #
-
+        # Check if we've hit max retries
+        if state.get("retry_count", 0) >= state.get("max_retries", 3):
+            print(f"Max retries ({state.get('max_retries', 3)}) reached. Forcing synthesis with available documents.")
+            return {
+                "is_sufficient": True,  # Force synthesis even if not perfect
+                "feedback": "Maximum retries reached. Using available documents."
+            }
+
+        # Get the CURRENT user question, not the first message in history
+        # Use the current_query_index to get the right message
+        current_query_index = state.get("current_query_index", 0)
+
+        # Find the current query message
+        user_messages = [msg for msg in state["messages"] if isinstance(msg, HumanMessage)]
+        if current_query_index < len(state["messages"]):
+            user_question = state["messages"][current_query_index].content
+        else:
+            # Fallback: get the last user message
+            user_question = user_messages[-1].content if user_messages else state["messages"][-1].content
+
+        # Get the retrieved documents (should be the last message)
         retrieved_docs = state["messages"][-1].content
 
+        print(f"Evaluating for query: '{user_question[:50]}...'")  # Debug print
+
         chain = DOCUMENT_EVALUATOR_PROMPT | self.evaluator_llm
         evaluation = chain.invoke({
             "question": user_question,
             "retrieved_docs": retrieved_docs
         })
 
-        print(f"Evaluation result: {evaluation}")
+        print(f"Evaluation result: {evaluation} (Retry {state.get('retry_count', 0)}/{state.get('max_retries', 3)})")
         return {
             "is_sufficient": evaluation.is_sufficient,
             "feedback": evaluation.feedback
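The evaluator above, together with the synthesizer and rewriter in the hunks that follow, resolves the current question with the same index-then-fallback lookup. As a standalone illustration, here is a minimal sketch of that pattern; the helper name is hypothetical, since the commit keeps the logic inline in each node:

from langchain_core.messages import HumanMessage

def resolve_current_question(state: dict) -> str:
    """Hypothetical helper mirroring the inline lookup in this diff."""
    messages = state["messages"]
    index = state.get("current_query_index", 0)
    if index < len(messages):
        # The message stored at current_query_index is the live query
        return messages[index].content
    # Fallback: the most recent user message, else the last message of any kind
    user_messages = [m for m in messages if isinstance(m, HumanMessage)]
    return user_messages[-1].content if user_messages else messages[-1].content

Factoring the lookup out like this would also remove the three-way duplication across the nodes.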
@@ -227,9 +249,22 @@ class AgenticRAGChat(ChatInterface):
         """Synthesize the final answer from retrieved documents."""
         print("Synthesizing answer...")
 
-
+        # Get the CURRENT user question using the index
+        current_query_index = state.get("current_query_index", 0)
+
+        # Find the current query message
+        user_messages = [msg for msg in state["messages"] if isinstance(msg, HumanMessage)]
+        if current_query_index < len(state["messages"]):
+            user_question = state["messages"][current_query_index].content
+        else:
+            # Fallback: get the last user message
+            user_question = user_messages[-1].content if user_messages else state["messages"][-1].content
+
+        # Get the retrieved documents
         retrieved_docs = state["messages"][-1].content
 
+        print(f"Synthesizing answer for: '{user_question[:50]}...'")  # Debug print
+
         chain = DOCUMENT_SYNTHESIZER_PROMPT | self.llm
         answer = chain.invoke({
             "question": user_question,
@@ -242,10 +277,25 @@ class AgenticRAGChat(ChatInterface):
         """Rewrite the query based on evaluation feedback."""
         print("Rewriting query...")
 
-
+        # Increment retry count
+        current_retry = state.get("retry_count", 0)
+
+        # Get the CURRENT user question using the index
+        current_query_index = state.get("current_query_index", 0)
+
+        # Find the current query message
+        user_messages = [msg for msg in state["messages"] if isinstance(msg, HumanMessage)]
+        if current_query_index < len(state["messages"]):
+            user_question = state["messages"][current_query_index].content
+        else:
+            # Fallback: get the last user message
+            user_question = user_messages[-1].content if user_messages else state["messages"][-1].content
+
         retrieved_docs = state["messages"][-1].content
         feedback = state["feedback"]
 
+        print(f"Rewriting query for: '{user_question[:50]}...'")  # Debug print
+
         chain = QUERY_REWRITER_PROMPT | self.llm
         new_query = chain.invoke({
             "question": user_question,
@@ -253,8 +303,11 @@ class AgenticRAGChat(ChatInterface):
             "retrieved_docs": retrieved_docs
         })
 
-        print(f"Rewritten query: {new_query.content}")
-        return {
+        print(f"Rewritten query (Attempt {current_retry + 1}/{state.get('max_retries', 3)}): {new_query.content}")
+        return {
+            "messages": [new_query],
+            "retry_count": current_retry + 1  # Increment retry count
+        }
 
     def _create_graph(self) -> Any:
         """Create the agentic RAG graph."""
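The rewriter's new return value relies on LangGraph's state-update semantics: in a MessagesState subclass, the messages key carries the add_messages reducer, so a returned list is appended to the history rather than replacing it, while scalar keys such as retry_count are simply overwritten. A toy sketch of that contract, with an assumed stand-in state and a canned rewrite in place of the commit's prompt chain:

from langchain_core.messages import HumanMessage
from langgraph.graph import MessagesState

class DemoState(MessagesState):  # stand-in for AgenticRAGState
    retry_count: int = 0

def rewrite_node(state: DemoState) -> dict:
    # Toy rewrite; the real node invokes QUERY_REWRITER_PROMPT | self.llm
    new_query = HumanMessage(content=f"(rewritten) {state['messages'][-1].content}")
    return {
        "messages": [new_query],                         # appended by add_messages
        "retry_count": state.get("retry_count", 0) + 1,  # scalar: overwritten
    }

Appending rather than replacing also matters downstream: the nodes read retrieved documents from state["messages"][-1], so the graph has to route the rewritten query back through the retriever before evaluating again.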
@@ -331,30 +384,29 @@ class AgenticRAGChat(ChatInterface):
         # Convert chat history to messages
         history_messages = self._convert_history_to_messages(chat_history)
 
+        # Mark the position where the current query starts
+        # This is important for the evaluator to know which is the actual query
+        history_length = len(history_messages)
+
         # Add the current message
-
+        current_query_message = HumanMessage(content=message)
+        history_messages.append(current_query_message)
 
         # Create initial state with full conversation history
+        # Store the index of the current query for reference
         state = AgenticRAGState(
-            messages=history_messages,
+            messages=history_messages,
             feedback="",
             is_sufficient=False,
             retry_count=0,
-            max_retries=3
+            max_retries=3,
+            # Add this to track the current query index
+            current_query_index=history_length  # This is the index of the current query
         )
 
-        # state = AgenticRAGState(
-        #     messages=[HumanMessage(content=message)],
-        #     feedback="",
-        #     is_sufficient=False,
-        #     retry_count=0,
-        #     max_retries=3  # Limit to 3 retries to prevent infinite loops
-        # )
-
         try:
-            # Run the workflow
             # Run the workflow with increased recursion limit
-            config = {"recursion_limit": 30}
+            config = {"recursion_limit": 30}
             result = self.graph.invoke(state, config=config)
 
             print("\n=== RAG QUERY COMPLETED ===")
@@ -371,10 +423,8 @@ class AgenticRAGChat(ChatInterface):
 
         except Exception as e:
             print(f"Error in RAG processing: {e}")
-            # Provide a more helpful fallback response
             if "recursion" in str(e).lower():
                 return ("I had difficulty finding the exact information you're looking for in the documents. "
-
-
-                        "You might want to try asking about a specific aspect or department.")
+                        "Based on the available documents, I can see references to various topics, "
+                        "but I couldn't find specific details. You might want to try asking about a specific aspect.")
             return f"I encountered an error while searching for information: {str(e)}"
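Two safety nets now work together: the evaluator forces is_sufficient=True once retry_count reaches max_retries, and recursion_limit=30 caps total graph steps if anything still cycles. The edge wiring itself lives in _create_graph, which this commit does not touch; a hedged sketch of how a conditional edge would typically consume these flags (node names here are assumptions, not from the commit):

def route_after_evaluation(state) -> str:
    # Same condition the evaluator's guard enforces: stop rewriting once the
    # documents are judged sufficient or the retry budget is spent
    if state["is_sufficient"] or state.get("retry_count", 0) >= state.get("max_retries", 3):
        return "synthesize_answer"
    return "rewrite_query"

# Assumed wiring inside _create_graph (not shown in this diff):
# graph.add_conditional_edges("evaluate_documents", route_after_evaluation)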
app.py  CHANGED

@@ -81,9 +81,9 @@ def create_demo():
         examples=[
             "What is 847 * 293?",
             "What's today's date?",
-            "What's the weather in San Francisco?",
-            "Explain quantum computing in simple terms",
-            "Research the impact of AI on healthcare",
+            # "What's the weather in San Francisco?",
+            # "Explain quantum computing in simple terms",
+            # "Research the impact of AI on healthcare",
         ],
         theme=gr.themes.Soft(),
         analytics_enabled=False,
@@ -95,14 +95,4 @@ def create_demo():
 if __name__ == "__main__":
     # Create and launch the demo
     demo = create_demo()
-
-    # Check if running in Hugging Face Spaces
-    if os.environ.get("SPACE_ID"):
-        # Hugging Face Spaces configuration
-        demo.launch(
-            server_name="0.0.0.0",
-            server_port=int(os.environ.get("PORT", 7860))
-        )
-    else:
-        # Local development - use simple defaults
-        demo.launch()
+    demo.launch()
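The removed SPACE_ID branch was redundant: launch() already reads its server settings from the environment, defaulting server_name and server_port from GRADIO_SERVER_NAME and GRADIO_SERVER_PORT, which Gradio hosts such as Hugging Face Spaces set for you. A sketch of the equivalence (assumed env values, not part of the commit):

import os

# What a Spaces-like host effectively provides; locally these are absent
# and Gradio falls back to 127.0.0.1:7860
os.environ.setdefault("GRADIO_SERVER_NAME", "0.0.0.0")
os.environ.setdefault("GRADIO_SERVER_PORT", "7860")

demo = create_demo()  # from app.py
demo.launch()         # picks up the env-based defaults; no branching needed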