Spaces:
Sleeping
Sleeping
File size: 5,269 Bytes
4d5a5ba b2a1c79 f878b1b e49d5ca 2ce0b48 f878b1b 2ce0b48 4d5a5ba 2ce0b48 fd4b655 b21c373 fd4b655 2ce0b48 4d5a5ba 2ce0b48 b21c373 2ce0b48 6aea901 2ce0b48 4d5a5ba 2ce0b48 fd4b655 4d5a5ba 2ce0b48 4d5a5ba 178fcdd 564da1a 4b1ac20 b2a1c79 4d5a5ba 2ce0b48 4d5a5ba 178fcdd 4d5a5ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import gradio as gr
from huggingface_hub import login
from smolagents import HfApiModel, Tool, CodeAgent
import os
import sys
import json
if './lib' not in sys.path :
sys.path.append('./lib')
from ingestion_chroma import retrieve_info_from_db
############################################################################################
################################### TOOLS ##################################################
############################################################################################
def search_key(d, target_key):
    """
    Search a nested dict/list structure for a key and return its first value as text.

    The structure is walked depth-first in iteration order. At each dict level a
    key is compared with ``target_key`` before its value is descended into, so the
    value returned is the first match in pre-order. Dicts and lists (including
    lists nested inside lists) are traversed; every other value is a leaf.

    :param d: The dictionary (or list of dictionaries) to search.
    :param target_key: The key to look for.
    :return: ``str()`` of the first value stored under ``target_key``, or a
        fallback message listing the top-level ESRS indicators when the key
        occurs nowhere in ``d``.
    """
    # Sentinel instead of None: a matching key may legitimately map to None.
    missing = object()

    def _find(node):
        # Return the first value stored under target_key inside node, or `missing`.
        if isinstance(node, dict):
            for key, value in node.items():
                if key == target_key:
                    return value
                found = _find(value)
                if found is not missing:
                    return found
        elif isinstance(node, list):
            for item in node:
                found = _find(item)
                if found is not missing:
                    return found
        return missing

    result = _find(d)
    if result is missing:
        return "Indicator not found. Try globals indicators in this list : ['ESRS E4', 'ESRS 2 MDR', 'ESRS S2', 'ESRS E2', 'ESRS S4', 'ESRS E5', 'ESRS 2', 'ESRS E1', 'ESRS S3', 'ESRS S1', 'ESRS G1', 'ESRS E3']"
    return str(result)
############################################################################################
class Chroma_retrieverTool(Tool):
    """smolagents tool that runs a semantic-similarity query against the knowledge base.

    The query string is passed to ``retrieve_info_from_db`` and the returned
    documents are concatenated into a single numbered text for the agent.
    """

    name = "request"
    description = "Using semantic similarity, retrieve the text from the knowledge base that has the embedding closest to the query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to execute must be semantically close to the text to search. Use the affirmative form rather than a question.",
        },
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return the documents retrieved for ``query`` as one formatted string.

        :param query: Free-text query sent to ``retrieve_info_from_db``.
        :return: A ``"Retrieval texts"`` header followed by one
            ``===== Text i =====`` section per retrieved document, each section
            terminated by a newline.
        :raises TypeError: if ``query`` is not a string.
        """
        # Raise rather than assert: assert statements are stripped under `python -O`.
        if not isinstance(query, str):
            raise TypeError("The request needs to be a string.")
        query_results = retrieve_info_from_db(query)
        # NOTE(review): assumes a Chroma-style result where 'documents' holds one
        # list of texts per submitted query; a single query is sent, hence [0].
        documents = query_results['documents'][0]
        # End every section with "\n" so a document's last line no longer runs
        # straight into the next "===== Text i =====" header.
        sections = [f"===== Text {i} =====\n{doc}\n" for i, doc in enumerate(documents)]
        return "\nRetrieval texts : \n" + "".join(sections)
############################################################################################
class ESRS_info_tool(Tool):
    """smolagents tool that looks up an ESRS indicator in ``./data/dico_esrs.json``.

    The lookup itself is delegated to ``search_key`` on the parsed JSON.
    """

    name = "find_ESRS"
    description = "Find ESRS description to help you to find what indicators the user want to analyze"
    inputs = {
        "indicator": {
            "type": "string",
            "description": "The indicator name with format for example like following 'ESRS EX' or 'EX'. return the description of the indicator demanded.",
        },
    }
    output_type = "string"

    def forward(self, indicator: str) -> str:
        """Return the description stored under ``indicator`` in the ESRS dictionary.

        :param indicator: Key to look up, e.g. ``'ESRS E1'``.
        :return: The string produced by ``search_key`` (the first matching value,
            or a "not found" message listing the top-level indicators).
        :raises TypeError: if ``indicator`` is not a string.
        """
        # Raise rather than assert: assert statements are stripped under `python -O`.
        if not isinstance(indicator, str):
            raise TypeError("The request needs to be a string.")
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        # The path is relative to the process's current working directory.
        with open('./data/dico_esrs.json', encoding='utf-8') as json_data:
            dico_esrs = json.load(json_data)
        return search_key(dico_esrs, indicator)
############################################################################################
############################################################################################
############################################################################################
def respond(message,
            history: list[tuple[str, str]],
            system_message,
            max_tokens,
            temperature,
            top_p,):
    """Answer one chat message by running the module-level smolagents ``agent``.

    The signature matches what ``gr.ChatInterface`` passes: the new message, the
    chat history, then one value per ``additional_inputs`` component. Only
    ``message`` is used. NOTE(review): ``history``, ``system_message``,
    ``max_tokens``, ``temperature`` and ``top_p`` are currently ignored, so the
    corresponding UI controls have no effect.

    :param message: The user's latest chat message.
    :yield: The agent's final answer as a string.
    """
    system_prompt_added = "You are an expert in environmental and corporate social responsibility. You must respond to requests using the query function in the document database.\nUser's question : "
    agent_output = agent.run(system_prompt_added + message)
    # The agent's final answer is not guaranteed to be a str (it may return a
    # number, dict, DataFrame, ...); the chat UI expects text.
    yield str(agent_output)
############################################################################################
# Authenticate with the Hugging Face Hub using the token from HF_TOKEN_all.
hf_token = os.getenv("HF_TOKEN_all")
if hf_token:
    login(hf_token)
else:
    # login() with no token falls back to an interactive prompt, which cannot be
    # answered in a headless deployment; rely on any existing credentials instead.
    print("HF_TOKEN_all is not set; skipping Hugging Face login.", file=sys.stderr)

# LLM backend for the agent (served through the Hugging Face inference API).
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")

# Tools exposed to the agent: ESRS dictionary lookup and knowledge-base retrieval.
retriever_tool = Chroma_retrieverTool()
get_ESRS_info_tool = ESRS_info_tool()

# Code-writing agent used by `respond`; generated code may import only the
# modules listed in additional_authorized_imports.
agent = CodeAgent(
    tools=[
        get_ESRS_info_tool,
        retriever_tool,
    ],
    model=model,
    max_steps=10,
    max_print_outputs_length=16000,
    additional_authorized_imports=['pandas', 'matplotlib', 'datetime'],
)
# Chat UI. Gradio calls `respond` with the message, the history and then one
# extra argument per component in `additional_inputs`, in the order listed.
# ChatInterface customization options: https://www.gradio.app/docs/chatinterface
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(label="System message", value="You are a friendly Chatbot."),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.1, maximum=1.0, step=0.05, value=0.95),
    ],
)

if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly.
    demo.launch()
|