File size: 5,269 Bytes
4d5a5ba
b2a1c79
f878b1b
e49d5ca
 
2ce0b48
f878b1b
 
2ce0b48
 
 
4d5a5ba
2ce0b48
 
 
 
fd4b655
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b21c373
 
 
 
 
fd4b655
2ce0b48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d5a5ba
2ce0b48
 
b21c373
2ce0b48
 
 
6aea901
2ce0b48
 
 
 
 
 
4d5a5ba
2ce0b48
 
 
fd4b655
4d5a5ba
2ce0b48
 
 
 
 
4d5a5ba
178fcdd
 
 
 
 
 
564da1a
 
 
 
 
 
 
4b1ac20
b2a1c79
 
4d5a5ba
2ce0b48
 
 
 
 
 
 
 
 
 
 
 
4d5a5ba
 
 
 
 
 
 
178fcdd
 
 
 
 
 
 
 
 
 
 
 
4d5a5ba
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import gradio as gr
from huggingface_hub import login
from smolagents import HfApiModel, Tool, CodeAgent

import os
import sys
import json

if './lib' not in sys.path :
    sys.path.append('./lib')
from ingestion_chroma import retrieve_info_from_db

############################################################################################
################################### TOOLS ##################################################
############################################################################################

def search_key(d, target_key):
    """
    Look up a key anywhere inside a nested dict/list structure.

    The structure is walked depth-first, in insertion order. A matching key is
    reported before the search descends into its own value, and the walk stops
    as soon as the first match has been found.

    :param d: The (possibly nested) dictionary or list to search.
    :param target_key: The key to look for (compared with ``==``).
    :return: ``str()`` of the first value stored under ``target_key``, or a
        fallback message listing the top-level ESRS indicators when the key
        does not occur anywhere in the structure.
    """
    def iter_matches(node):
        # Lazily yield every value stored under target_key, in walk order.
        if isinstance(node, dict):
            for key, value in node.items():
                if key == target_key:
                    yield value
                yield from iter_matches(value)
        elif isinstance(node, list):
            # Items may be dicts or further lists; anything else is skipped
            # by the recursive call because it is neither dict nor list.
            for item in node:
                yield from iter_matches(item)

    # Only the first hit is needed, so stop consuming the generator there.
    for match in iter_matches(d):
        return str(match)
    return "Indicator not found. Try globals indicators in this list : ['ESRS E4', 'ESRS 2 MDR', 'ESRS S2', 'ESRS E2', 'ESRS S4', 'ESRS E5', 'ESRS 2', 'ESRS E1', 'ESRS S3', 'ESRS S1', 'ESRS G1', 'ESRS E3']"
    
    

############################################################################################

class Chroma_retrieverTool(Tool):
    """Agent tool that fetches the knowledge-base passages closest to a query.

    The lookup itself is delegated to ``retrieve_info_from_db``; this class
    only exposes it to the agent and formats the hits as one block of text.
    """

    # Tool metadata: the name the agent calls, its description, the argument
    # schema and the declared output type.
    name = "request"
    description = "Using semantic similarity, retrieve the text from the knowledge base that has the embedding closest to the query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to execute must be semantically close to the text to search. Use the affirmative form rather than a question.",
        },
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return the retrieved passages for ``query`` as a single string.

        :param query: Text to search for in the knowledge base.
        :return: A header line followed by one ``===== Text i =====`` section
            per retrieved document.
        :raises TypeError: If ``query`` is not a string.
        """
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if not isinstance(query, str):
            raise TypeError("The request needs to be a string.")

        query_results = retrieve_info_from_db(query)
        # presumably 'documents' holds one list of hits per submitted query,
        # so index 0 is the hit list for `query` - confirm against retrieve_info_from_db
        documents = query_results['documents'][0]

        # Each section ends with a newline so the next header is not glued to
        # the end of the previous document.
        sections = [
            f"===== Text {position} =====\n{text}\n"
            for position, text in enumerate(documents)
        ]
        return "\nRetrieval texts : \n" + "".join(sections)
        
############################################################################################

class ESRS_info_tool(Tool):
    """Agent tool that returns the description stored for an ESRS indicator.

    Descriptions are read from ``./data/dico_esrs.json`` (path relative to the
    current working directory) on every call and looked up with ``search_key``.
    """

    # Tool metadata: the name the agent calls, its description, the argument
    # schema and the declared output type.
    name = "find_ESRS"
    description = "Find ESRS description to help you to find what indicators the user want to analyze"
    inputs = {
        "indicator": {
            "type": "string",
            "description": "The indicator name with format for example like following 'ESRS EX' or 'EX'. return the description of the indicator demanded.",
        },
    }
    output_type = "string"

    def forward(self, indicator: str) -> str:
        """Return the entry stored under ``indicator`` in the ESRS dictionary.

        :param indicator: Indicator name used as the lookup key.
        :return: Whatever ``search_key`` returns for that key: the stringified
            entry, or its "not found" message.
        :raises TypeError: If ``indicator`` is not a string.
        """
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if not isinstance(indicator, str):
            raise TypeError("The request needs to be a string.")

        # JSON text is UTF-8; an explicit encoding avoids depending on the
        # platform's default locale when the file contains non-ASCII text.
        with open('./data/dico_esrs.json', encoding="utf-8") as json_data:
            dico_esrs = json.load(json_data)

        return search_key(dico_esrs, indicator)
        
############################################################################################
############################################################################################
############################################################################################

def respond(message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,):
    """Chat callback: run the module-level ``agent`` on the user's message.

    :param message: The user's latest chat message.
    :param history: Previous chat turns; accepted for the chat-interface
        signature but not forwarded to the agent.
    :param system_message: Value of the "System message" input; unused.
    :param max_tokens: Value of the "Max new tokens" input; unused.
    :param temperature: Value of the "Temperature" input; unused.
    :param top_p: Value of the "Top-p" input; unused.
    :return: A generator yielding the agent's answer exactly once.
    """
    # The retrieval tool is registered under the name "request", so the prompt
    # refers to it by that name rather than to a non-existent "query" function.
    system_prompt_added = """You are an expert in environmental and corporate social responsibility. You must respond to requests using the `request` tool, which searches the document database.
User's question : """
    agent_output = agent.run(system_prompt_added + message)

    # The agent's final answer is not guaranteed to be text (it can be any
    # Python object); hand the chat UI a string in that case.
    if agent_output is not None and not isinstance(agent_output, str):
        agent_output = str(agent_output)

    yield agent_output
    
############################################################################################
# Authenticate against the Hugging Face Hub with the token read from the
# HF_TOKEN_all environment variable.
# NOTE(review): os.getenv returns None when the variable is unset; login(None)
# presumably falls back to an interactive prompt - confirm that is acceptable
# where this app is deployed.
hf_token = os.getenv("HF_TOKEN_all")
login(hf_token)
# Model that drives the agent.
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")

# Instantiate the two tools defined above and hand them to the agent.
retriever_tool = Chroma_retrieverTool()
get_ESRS_info_tool = ESRS_info_tool()
# `agent` is the module-level object used by respond() for every chat message.
agent = CodeAgent(
    tools=[
            get_ESRS_info_tool,
            retriever_tool,
          ],
    model=model,
    max_steps=10,
    max_print_outputs_length=16000,
    additional_authorized_imports=['pandas', 'matplotlib', 'datetime']
    )


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Chat UI wired to respond(). The four additional inputs below are passed to
# respond() as system_message, max_tokens, temperature and top_p.
# NOTE: respond() accepts these values but does not use them, so changing the
# widgets currently has no effect on the agent's answer.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


# Start the web app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()