Spaces:
Sleeping
Sleeping
EtienneB
commited on
Commit
·
fc5e0c3
1
Parent(s):
916fd5c
Update agent.py
Browse files
agent.py
CHANGED
|
@@ -3,7 +3,8 @@ import os
|
|
| 3 |
import re
|
| 4 |
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
-
from langchain_core.messages import HumanMessage, SystemMessage,
|
|
|
|
| 7 |
from langchain_huggingface import (ChatHuggingFace, HuggingFaceEmbeddings,
|
| 8 |
HuggingFaceEndpoint)
|
| 9 |
from langgraph.graph import START, MessagesState, StateGraph
|
|
@@ -71,18 +72,25 @@ def build_graph():
|
|
| 71 |
llm_with_tools = llm.bind_tools(tools)
|
| 72 |
|
| 73 |
# --- Nodes ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
def assistant(state: MessagesState):
|
| 75 |
-
"""Assistant node"""
|
| 76 |
messages_with_system_prompt = [sys_msg] + state["messages"]
|
| 77 |
llm_response = llm_with_tools.invoke(messages_with_system_prompt)
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
if answer_text.strip().lower().startswith("final answer:"):
|
| 81 |
-
answer_text = answer_text.split(":", 1)[1].strip()
|
| 82 |
-
# Get task_id from state or set a placeholder
|
| 83 |
-
task_id = state.get("task_id", "1") # Replace with actual logic if needed
|
| 84 |
formatted = [{"task_id": task_id, "submitted_answer": answer_text}]
|
| 85 |
-
return {"messages": [formatted]}
|
| 86 |
|
| 87 |
# --- Graph Definition ---
|
| 88 |
builder = StateGraph(MessagesState)
|
|
@@ -154,9 +162,12 @@ if __name__ == "__main__":
|
|
| 154 |
print(message.content)
|
| 155 |
print("-----------------------")
|
| 156 |
else:
|
| 157 |
-
output =
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import re
|
| 4 |
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
+
from langchain_core.messages import (AIMessage, HumanMessage, SystemMessage,
|
| 7 |
+
ToolMessage)
|
| 8 |
from langchain_huggingface import (ChatHuggingFace, HuggingFaceEmbeddings,
|
| 9 |
HuggingFaceEndpoint)
|
| 10 |
from langgraph.graph import START, MessagesState, StateGraph
|
|
|
|
| 72 |
llm_with_tools = llm.bind_tools(tools)
|
| 73 |
|
| 74 |
# --- Nodes ---
|
| 75 |
+
def extract_answer(llm_output):
|
| 76 |
+
# Try to parse as JSON if possible
|
| 77 |
+
try:
|
| 78 |
+
# If the LLM output is a JSON list, extract the answer
|
| 79 |
+
parsed = json.loads(llm_output.strip().split('\n')[0])
|
| 80 |
+
if isinstance(parsed, list) and isinstance(parsed[0], dict) and "submitted_answer" in parsed[0]:
|
| 81 |
+
return parsed[0]["submitted_answer"]
|
| 82 |
+
except Exception:
|
| 83 |
+
pass
|
| 84 |
+
# Otherwise, just return the first line (before any explanation)
|
| 85 |
+
return llm_output.strip().split('\n')[0]
|
| 86 |
+
|
| 87 |
def assistant(state: MessagesState):
|
|
|
|
| 88 |
messages_with_system_prompt = [sys_msg] + state["messages"]
|
| 89 |
llm_response = llm_with_tools.invoke(messages_with_system_prompt)
|
| 90 |
+
answer_text = extract_answer(llm_response.content)
|
| 91 |
+
task_id = str(state.get("task_id", "1")) # Ensure task_id is a string
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
formatted = [{"task_id": task_id, "submitted_answer": answer_text}]
|
| 93 |
+
return {"messages": [AIMessage(content=json.dumps(formatted, ensure_ascii=False))]}
|
| 94 |
|
| 95 |
# --- Graph Definition ---
|
| 96 |
builder = StateGraph(MessagesState)
|
|
|
|
| 162 |
print(message.content)
|
| 163 |
print("-----------------------")
|
| 164 |
else:
|
| 165 |
+
output = message.content # This is a string
|
| 166 |
+
try:
|
| 167 |
+
parsed = json.loads(output)
|
| 168 |
+
if isinstance(parsed, list) and "task_id" in parsed[0] and "submitted_answer" in parsed[0]:
|
| 169 |
+
print("✅ Output is in the correct format!")
|
| 170 |
+
else:
|
| 171 |
+
print("❌ Output is NOT in the correct format!")
|
| 172 |
+
except Exception as e:
|
| 173 |
+
print("❌ Output is NOT in the correct format!", e)
|