import json
import logging
import os
import random
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional

import numpy as np
from rouge import Rouge
from tqdm import tqdm

from camel.agents import ChatAgent
from camel.benchmarks.base import BaseBenchmark
from camel.messages import BaseMessage
from camel.utils import download_github_subdirectory

logger = logging.getLogger(__name__)

# Make the current working directory importable so the downloaded
# api_bank package can be resolved at runtime.
current_folder = os.getcwd()
if current_folder not in sys.path:
    sys.path.append(current_folder)


def process_messages(
    chat_history: List[Dict[str, Any]],
    prompt: str,
) -> List[Dict[str, str]]:
    r"""Processes chat history into a structured format for further use.

    Args:
        chat_history (List[Dict[str, Any]]): A list of dictionaries
            representing the chat history.
        prompt (str): A prompt to be set as the system message.

    Returns:
        List[Dict[str, str]]: A list of dictionaries representing
            the processed messages, where each dictionary has:
            - 'role': The role of the message ('system', 'user', or
              'assistant').
            - 'content': The content of the message, including formatted
              API responses when applicable.
    """
    messages = [{'role': 'system', 'content': prompt}]
    for item in chat_history:
        role_map = {'User': 'user', 'AI': 'assistant', 'API': 'system'}
        chat_role = role_map.get(item['role'], 'unknown')
        if item['role'] == 'API':
            # Render API turns as "[api_name(k='v', ...)] Response: ...".
            chat_content = '[{}({})] Response: {}'.format(
                item['api_name'],
                ', '.join(
                    '{}=\'{}\''.format(k, v)
                    for k, v in item['param_dict'].items()
                ),
                str(item['result']['output']),
            )
        else:
            chat_content = item['text']
        messages.append({'role': chat_role, 'content': chat_content})
    return messages
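
# Illustrative example (the history entry below is hypothetical):
#   process_messages(
#       [{'role': 'API', 'api_name': 'QueryMeeting',
#         'param_dict': {'date': '2023-01-01'},
#         'result': {'output': 'No meetings found.'}}],
#       'You are a helpful assistant.',
#   )
# returns the system prompt followed by a system turn whose content is
# "[QueryMeeting(date='2023-01-01')] Response: No meetings found.".

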
class APIBankBenchmark(BaseBenchmark):
    r"""API-Bank Benchmark adapted from `API-Bank: A Comprehensive
    Benchmark for Tool-Augmented LLMs`
    <https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/api-bank>.

    Args:
        save_to (str): The file to save the results.
        processes (int, optional): The number of processes to use.
            (default: :obj:`1`)
    """

    def __init__(
        self,
        save_to: str,
        processes: int = 1,
    ):
        r"""Initialize the APIBank benchmark.

        Args:
            save_to (str): The file to save the results.
            processes (int, optional): The number of processes to use for
                parallel processing. (default: :obj:`1`)
        """
        super().__init__("apibank", "api_bank", save_to, processes)
        self._data: Dict[str, List[APIBankSample]] = dict()

    def download(self):
        r"""Download APIBank dataset and code from Github."""
        repo = "AlibabaResearch/DAMO-ConvAI"
        subdir = "api-bank"
        data_dir = self.data_dir

        download_github_subdirectory(repo, subdir, data_dir)

        # Make the freshly downloaded package importable.
        sys.path.insert(0, self.data_dir)
        logger.info("Download completed.")

    def load(self, level: str, force_download: bool = False):
        r"""Load the APIBank Benchmark dataset.

        Args:
            level (str): Level to run benchmark on, either `"level-1"`
                or `"level-2"`.
            force_download (bool, optional): Whether to force download
                the data. (default: :obj:`False`)
        """
        if force_download:
            logger.info("Force downloading data.")
            self.download()

        if level == "level-1":
            file_path = Path("api_bank/lv1-lv2-samples/level-1-given-desc")
        elif level == "level-2":
            file_path = Path("api_bank/lv1-lv2-samples/level-2-toolsearcher")
        else:
            raise ValueError(f"Unknown level: {level!r}")
        jsonl_files = [
            f for f in os.listdir(file_path) if f.endswith('.jsonl')
        ]
        for file in tqdm(jsonl_files, desc="Processing files"):
            history = []
            with open(file_path / file, 'r') as f:
                for line in f:
                    history.append(json.loads(line))
            samples = APIBankSample.from_chat_history(history)
            self._data[file.rsplit('.', 1)[0]] = samples

        def process_files(folder_path, replacements):
            r"""Replace absolute imports in downloaded files with
            relative imports."""
            for file in os.listdir(folder_path):
                if file.endswith(".py"):
                    file_path = os.path.join(folder_path, file)
                    try:
                        with open(file_path, "r", encoding="utf-8") as fh:
                            content = fh.read()

                        original_content = content

                        for pattern, replacement in replacements:
                            content = re.sub(pattern, replacement, content)

                        if content != original_content:
                            with open(
                                file_path, "w", encoding="utf-8"
                            ) as fh:
                                fh.write(content)
                            logger.info(f"Updated file: {file_path}")

                    except Exception as e:
                        logger.warning(
                            f"Error processing file {file_path}: {e}"
                        )

        api_bank_folder = "api_bank"
        apis_folder = os.path.join(api_bank_folder, "apis")

        # Rewrite absolute imports in the downloaded sources so they
        # resolve as package-relative imports at runtime.
        apis_replacements = [
            (r"from apis.api", "from .api"),
            (r"from apis import", "from .api import"),
        ]

        api_bank_replacements = [
            (r"from apis", "from .apis"),
            (r"from api_call_extraction", "from .api_call_extraction"),
            (r"f'{basename}", r"f'api_bank.{basename}"),
        ]

        process_files(apis_folder, apis_replacements)
        process_files(api_bank_folder, api_bank_replacements)

    def run(
        self,
        agent: ChatAgent,
        level: Literal["level-1", "level-2"],
        api_test_enabled: bool = True,
        randomize: bool = False,
        subset: Optional[int] = None,
    ) -> Dict[str, Any]:
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The agent to run the benchmark.
            level (Literal["level-1", "level-2"]): The level to run the
                benchmark on.
            api_test_enabled (bool, optional): Whether to test API calling
                (`True`) or response generation (`False`).
                (default: :obj:`True`)
            randomize (bool, optional): Whether to randomize the data.
                (default: :obj:`False`)
            subset (Optional[int], optional): The subset of data to run.
                (default: :obj:`None`)

        Returns:
            Dict[str, Any]: The results of the benchmark.
        """
        logger.info(f"Running API-Bank benchmark on {level}.")
        self.load(level)
        datas = self._data

        if randomize:
            randomized_items = list(datas.items())
            random.shuffle(randomized_items)
            datas = dict(randomized_items)
        if subset:
            datas = dict(list(datas.items())[:subset])

        logger.info(f"Number of tasks: {len(datas)}")

        self._results = []

        # Level-2 asks the model to locate the right tool through the
        # ToolSearcher API; level-1 supplies API descriptions directly.
        tool_search_enabled = level == "level-2"
        dialog_test_enabled = not api_test_enabled
        total_api_calls, correct_api_calls, rougel_scores = 0, 0, []

        with open(self.save_to, "w") as f:
            for test in tqdm(datas, desc="Running"):
                samples = self._data[test]
                evaluator = Evaluator(samples)

                for sample_id in evaluator.get_all_sample_ids():
                    sample = evaluator.dataset[sample_id]

                    if (
                        sample.ground_truth['role'] == 'API'
                        and api_test_enabled
                    ):
                        # API-call test: the model must emit the next API
                        # request in [ApiName(key='value', ...)] format.
                        if tool_search_enabled:
                            _, chat_history = evaluator.get_model_input(
                                sample_id
                            )
                            api_descriptions = evaluator.get_api_description(
                                'ToolSearcher'
                            )
                        else:
                            api_descriptions, chat_history = (
                                evaluator.get_model_input(sample_id)
                            )
                        messages = process_messages(
                            chat_history, API_CALL_PROMPT + api_descriptions
                        )
                        model_output = agent_call(messages, agent)
                        api_call = get_api_call(model_output)

                        if api_call:
                            try:
                                correct, model_output_result = (
                                    evaluator.evaluate(sample_id, api_call)
                                )
                            except AssertionError as e:
                                if 'The API name is not correct.' not in str(
                                    e
                                ):
                                    raise
                                logger.info(f'AssertionError: {e}')
                                correct = False
                                model_output_result = str(e)
                        else:
                            model_output_result = 'No API call found'
                            correct = False
                        if correct:
                            correct_api_calls += 1
                            logger.info(
                                'Correct API call: {} '
                                'Ground truth: {}'.format(
                                    api_call, sample.ground_truth
                                )
                            )
                        else:
                            logger.info(
                                'Incorrect model output: {} Result: {} '
                                'Ground truth: {} File: {} Sample ID: {} '
                                'Messages: {}'.format(
                                    model_output.replace('\n', ' '),
                                    model_output_result,
                                    sample.ground_truth,
                                    test,
                                    sample_id,
                                    messages[1:],
                                )
                            )
                        total_api_calls += 1
                        self._results.append(
                            {
                                'Role': 'API',
                                'Model_output': model_output,
                                'Model_output_result': model_output_result,
                                'Ground_truth': sample.ground_truth,
                                'Test': test,
                                'Correct': correct,
                            }
                        )
                        f.write(
                            json.dumps(self._results[-1], indent=2) + "\n"
                        )

                    elif (
                        sample.ground_truth['role'] == 'AI'
                        and dialog_test_enabled
                    ):
                        # Dialogue test: the model must produce the next AI
                        # utterance, scored with ROUGE-L against the ground
                        # truth.
                        api_descriptions, chat_history = (
                            evaluator.get_model_input(sample_id)
                        )

                        messages = process_messages(
                            chat_history, RESPONSE_PROMPT + api_descriptions
                        )
                        model_output = agent_call(messages, agent)

                        if model_output:
                            score = evaluator.evaluate(
                                sample_id, model_output
                            )
                        else:
                            score = 0
                        rougel_scores.append(score)
                        if score < 0.2:
                            logger.info(
                                'Low score: {} Score: {} Ground truth: {} '
                                'Test: {} Sample ID: {} '
                                'Messages: {}'.format(
                                    model_output.replace('\n', ' '),
                                    score,
                                    sample.ground_truth,
                                    test,
                                    sample_id,
                                    messages[1:],
                                )
                            )

                        self._results.append(
                            {
                                'Role': 'AI',
                                'Model_output': model_output,
                                'Score': score,
                                'Ground_truth': sample.ground_truth,
                                'Test': test,
                            }
                        )
                        f.write(
                            json.dumps(self._results[-1], indent=2) + "\n"
                        )

                f.flush()

        if api_test_enabled:
            return {
                'total': total_api_calls,
                'correct': correct_api_calls,
                'accuracy': correct_api_calls / total_api_calls
                if total_api_calls
                else 0,
            }
        else:
            return {'Dialog_score': np.mean(rougel_scores)}
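
# Illustrative usage (the ChatAgent construction below is a hypothetical
# sketch; configure the agent as your application requires):
#   agent = ChatAgent("You are a helpful assistant.")
#   benchmark = APIBankBenchmark(save_to="apibank_results.jsonl")
#   benchmark.download()
#   results = benchmark.run(agent, "level-1", api_test_enabled=True)
#   print(results["accuracy"])

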
def agent_call(messages: List[Dict], agent: ChatAgent):
    r"""Add messages to agent memory and get response."""
    for i, msg in enumerate(messages):
        if msg['role'] == 'user':
            message = BaseMessage.make_user_message(
                role_name="CAMEL User", content=msg['content']
            )
        elif msg['role'] == 'assistant':
            message = BaseMessage.make_assistant_message(
                role_name="CAMEL Assistant", content=msg['content']
            )
        elif msg['role'] == 'system':
            message = BaseMessage.make_assistant_message(
                role_name="System", content=msg['content']
            )
        else:
            raise ValueError(f"Unrecognized role: {msg['role']}")

        # Record every message except the last, which is passed to
        # ``step`` below to elicit the model's response.
        if i == len(messages) - 1:
            break
        agent.record_message(message)

    response = agent.step(message)
    model_output = response.msgs[0].content
    agent.reset()
    return model_output
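
# Note: ``agent.reset()`` clears the agent's memory after each call, so every
# benchmark sample is evaluated from a clean conversational state.

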
def calculate_rouge_l_score(reference, hypothesis):
    r"""Calculate the ROUGE-L score between hypothesis and reference."""
    rouge = Rouge()
    scores = rouge.get_scores(hypothesis, reference)
    rouge_l_score = scores[0]['rouge-l']['f']
    return rouge_l_score
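
# Illustrative example (exact values depend on the ``rouge`` package's
# smoothing and tokenization):
#   calculate_rouge_l_score("the cat sat", "the cat sat")  # ~1.0
#   calculate_rouge_l_score("the cat sat", "a dog ran")    # ~0.0

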
def get_api_call(model_output):
    r"""Parse an API call from model output."""
    # Matches outputs of the form [ApiName(key1='value1', key2='value2')].
    api_call_pattern = r"\[(\w+)\((.*)\)\]"
    api_call_pattern = re.compile(api_call_pattern)
    match = api_call_pattern.search(model_output)
    if match:
        return match.group(0)
    else:
        return None
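
# Illustrative example (hypothetical model output):
#   get_api_call("Sure: [ToolSearcher(keywords='weather')]")
#   -> "[ToolSearcher(keywords='weather')]"

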
class APIBankSample:
    r"""APIBank sample used to load the datasets."""

    def __init__(self, chat_history, apis, ground_truth):
        self.chat_history = chat_history
        self.apis = apis
        self.ground_truth = ground_truth

    def __repr__(self):
        return 'Sample(chat_history={}, apis={}, ground_truth={})'.format(
            self.chat_history, self.apis, self.ground_truth
        )

    @classmethod
    def from_chat_history(cls, chat_history):
        apis = set()
        api_positions = []
        for i, item in enumerate(chat_history):
            if item['role'] == 'API':
                apis.add(item['api_name'])
                api_positions.append(i)

        # Each API turn yields two samples: one whose ground truth is the
        # API call itself, and one whose ground truth is the turn that
        # follows the call.
        samples = []
        for i in api_positions:
            sample = cls(chat_history[:i], apis, chat_history[i])
            samples.append(sample)
            sample = cls(chat_history[: i + 1], apis, chat_history[i + 1])
            samples.append(sample)

        return samples
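
# Illustrative example: a four-turn history [User, AI, API, AI] has
# api_positions == [2], producing one sample that predicts the API call
# from the first two turns and one that predicts the closing AI turn from
# the first three.

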
class Evaluator:
    r"""Evaluator for APIBank benchmark."""

    def __init__(self, samples: List[APIBankSample]):
        try:
            from api_bank.tool_manager import (
                ToolManager,
            )
        except Exception as e:
            logger.info(f"{e}, Module will be imported after download.")
        self.dataset = samples
        self.sample_ids = list(range(len(self.dataset)))
        # ToolManager resolves its "apis" folder relative to the current
        # working directory, so temporarily chdir into the downloaded repo.
        os.chdir("api_bank")
        self.tool_manager = ToolManager("apis")
        os.chdir("..")

    def get_all_sample_ids(self):
        return self.sample_ids

    def get_api_description(self, api_name):
        return self.tool_manager.get_api_description(api_name)

    def get_model_input(self, sample_id: int):
        sample = self.dataset[sample_id]
        apis = sample.apis
        chat_history = sample.chat_history
        api_descriptions = []
        for api_name in apis:
            api_descriptions.append(
                self.tool_manager.get_api_description(api_name)
            )
        api_description = '\n'.join(api_descriptions)
        return api_description, chat_history

    def evaluate(self, sample_id, model_output):
        try:
            from api_bank.api_call_extraction import (
                parse_api_call,
            )
        except Exception as e:
            logger.info(f"{e}, Module will be imported after download.")
        sample = self.dataset[sample_id]
        ground_truth = sample.ground_truth
        if ground_truth['role'] == 'API':
            # API-call samples: execute the predicted call and compare its
            # result against the ground truth; returns (correct, result).
            api_name, param_dict = parse_api_call(model_output)
            if api_name != ground_truth['api_name']:
                return False, 'API Name Mismatch: {} vs {}'.format(
                    api_name, ground_truth['api_name']
                )
            try:
                result = self.tool_manager.api_call(api_name, **param_dict)
            except Exception as e:
                return False, str(e)
            api = self.tool_manager.init_tool(api_name)
            try:
                correct = api.check_api_call_correctness(
                    result, ground_truth['result']
                )
            except KeyError:
                correct = False
                result = 'KeyError' + str(result)
            return correct, result
        elif ground_truth['role'] == 'AI':
            # Dialogue samples: returns a single ROUGE-L F-score.
            score = calculate_rouge_l_score(
                ground_truth['text'], model_output
            )
            return round(score, 4)
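
# Note: ``Evaluator.evaluate`` is polymorphic. ``run`` above unpacks a
# ``(correct, result)`` tuple for API samples and treats the return value
# as a single ROUGE-L float for dialogue samples.

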
API_CALL_PROMPT = '''
Based on the given API description and the existing \
conversation history 1..t, please generate the API request \
that the AI should call in step t+1 and output it in the \
format of [ApiName(key1='value1', key2='value2', ...)], \
replace the ApiName with the actual API name, and \
replace the key and value with the actual parameters. \
Your output should start with a square bracket "[" \
and end with a square bracket "]". Do not output any \
other explanation or prompt or the result of the API call in your output.
This year is 2023.
Input:
User: [User's utterance]
AI: [AI's utterance]

Expected output:
[ApiName(key1='value1', key2='value2', ...)]

API descriptions:
'''

RESPONSE_PROMPT = '''
Based on the given API description and the existing \
conversation history 1..t, please generate the next \
dialog that the AI should respond with after the API call t.
This year is 2023.
Input:
User: [User's utterance]
AI: [AI's utterance]
[ApiName(key1='value1', key2='value2', ...)]

Expected output:
AI: [AI's utterance]

API descriptions:
'''