import gradio as gr
from huggingface_hub import login
from smolagents import HfApiModel, Tool, CodeAgent
import os
import sys
import json

if './lib' not in sys.path:
    sys.path.append('./lib')

from ingestion_chroma import retrieve_info_from_db
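# ingestion_chroma is assumed to be a local helper module under ./lib that wraps
# the Chroma vector store used as the document knowledge base.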
############################################################################################
################################### TOOLS ##################################################
############################################################################################
def search_key(d, target_key):
    """
    Search for a key in a nested dictionary.

    :param d: The dictionary to search in.
    :param target_key: The key to look for.
    :return: The first value found for the key, as a string, or an error message
             listing the available global indicators.
    """
    results = []

    def recursive_search(d):
        if isinstance(d, dict):
            for key, value in d.items():
                if key == target_key:
                    results.append(value)
                if isinstance(value, dict):
                    recursive_search(value)
                elif isinstance(value, list):
                    for item in value:
                        if isinstance(item, dict):
                            recursive_search(item)

    recursive_search(d)

    if len(results) > 0:
        return str(results[0])
    else:
        return "Indicator not found. Try one of the global indicators in this list: ['ESRS E4', 'ESRS 2 MDR', 'ESRS S2', 'ESRS E2', 'ESRS S4', 'ESRS E5', 'ESRS 2', 'ESRS E1', 'ESRS S3', 'ESRS S1', 'ESRS G1', 'ESRS E3']"
############################################################################################
class Chroma_retrieverTool(Tool):
    name = "request"
    description = "Using semantic similarity, retrieve the text from the knowledge base that has the embedding closest to the query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to execute. It must be semantically close to the text being searched for. Use the affirmative form rather than a question.",
        },
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        assert isinstance(query, str), "The query needs to be a string."
        query_results = retrieve_info_from_db(query)
        str_result = "\nRetrieved texts:\n" + "".join(
            f"===== Text {i} =====\n" + query_results['documents'][0][i]
            for i in range(len(query_results['documents'][0]))
        )
        return str_result
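# Note: retrieve_info_from_db is expected to return a Chroma-style query result,
# i.e. a dict whose 'documents' entry is a list of lists (one inner list of texts
# per query), which is why the code above indexes query_results['documents'][0].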
############################################################################################
class ESRS_info_tool(Tool):
    name = "find_ESRS"
    description = "Find the description of an ESRS indicator, to help you identify which indicators the user wants to analyze."
    inputs = {
        "indicator": {
            "type": "string",
            "description": "The indicator name, formatted for example as 'ESRS EX' or 'EX'. Returns the description of the requested indicator.",
        },
    }
    output_type = "string"

    def forward(self, indicator: str) -> str:
        assert isinstance(indicator, str), "The indicator needs to be a string."
        with open('./data/dico_esrs.json') as json_data:
            dico_esrs = json.load(json_data)
        result = search_key(dico_esrs, indicator)
        return result
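# ./data/dico_esrs.json is assumed to be a (possibly nested) dictionary mapping
# ESRS codes such as 'ESRS E1' to their descriptions, hence the recursive lookup
# done by search_key.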
############################################################################################
############################################################################################
############################################################################################
def respond(message,
            history: list[tuple[str, str]],
            system_message,
            max_tokens,
            temperature,
            top_p):
    system_prompt_added = """You are an expert in environmental and corporate social responsibility. You must respond to requests by querying the document database.
User's question: """
    agent_output = agent.run(system_prompt_added + message)
    yield agent_output
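# Note: system_message, max_tokens, temperature and top_p come from the Gradio UI
# below but are not currently forwarded to the agent; only the user message,
# prefixed with the fixed system prompt, is used.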
############################################################################################
hf_token = os.getenv("HF_TOKEN_all")
login(hf_token)
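# HF_TOKEN_all is assumed to be a Space secret holding a Hugging Face access token
# with permission to call the Inference API for the model below.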
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")

retriever_tool = Chroma_retrieverTool()
get_ESRS_info_tool = ESRS_info_tool()

agent = CodeAgent(
    tools=[
        get_ESRS_info_tool,
        retriever_tool,
    ],
    model=model,
    max_steps=10,
    max_print_outputs_length=16000,
    additional_authorized_imports=['pandas', 'matplotlib', 'datetime'],
)
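# The CodeAgent lets the model write and execute Python snippets that call the two
# tools above; additional_authorized_imports whitelists the extra packages that this
# generated code is allowed to import.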
| """ | |
| For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
| """ | |
| demo = gr.ChatInterface( | |
| respond, | |
| additional_inputs=[ | |
| gr.Textbox(value="You are a friendly Chatbot.", label="System message"), | |
| gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"), | |
| gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"), | |
| gr.Slider( | |
| minimum=0.1, | |
| maximum=1.0, | |
| value=0.95, | |
| step=0.05, | |
| label="Top-p (nucleus sampling)", | |
| ), | |
| ], | |
| ) | |
if __name__ == "__main__":
    demo.launch()