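"""Gradio app for the TAAIC Tool Validation challenge.

Connects to an MCP server (SSE URL or local script), loads the
aitxchallenge/Phase1_Model_Validator dataset, and runs each validation case
through Claude using the server's tools.
"""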
import asyncio
import os
import json
from typing import List, Dict, Any, Union
from contextlib import AsyncExitStack

import gradio as gr
from gradio.components.chatbot import ChatMessage
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
from anthropic import Anthropic
from datasets import load_dataset
import pandas as pd
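
# A dedicated module-level event loop lets the synchronous Gradio callbacks
# below drive the async MCP client calls via loop.run_until_complete().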
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)


class MCPClientWrapper:
    def __init__(self):
        self.session = None
        self.exit_stack = None
        self.anthropic = None
        self.tools = []
        self.dataset = None
        self.validation_results = []

    def set_api_key(self, api_key: str) -> str:
        """Set the Anthropic API key and initialize the client"""
        if not api_key or not api_key.strip():
            return "Please enter a valid Anthropic API key"
        try:
            self.anthropic = Anthropic(api_key=api_key.strip())
            return "API key set successfully ✓"
        except Exception as e:
            return f"Failed to set API key: {str(e)}"

    def connect(self, server_input: str) -> str:
        if not self.anthropic:
            return "Please set your Anthropic API key first"
        return loop.run_until_complete(self._connect(server_input))
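
    # Opens an SSE transport for http(s) URLs or a stdio transport for local
    # scripts, then initializes the session and lists the server's tools.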
    async def _connect(self, server_input: str) -> str:
        if self.exit_stack:
            await self.exit_stack.aclose()
        self.exit_stack = AsyncExitStack()
        try:
            # Check if input is a URL (starts with http:// or https://)
            if server_input.startswith(('http://', 'https://')):
                # Connect via SSE
                read, write = await self.exit_stack.enter_async_context(
                    sse_client(server_input)
                )
                connection_type = "SSE URL"
            else:
                # Connect via stdio (local file)
                is_python = server_input.endswith('.py')
                command = "python" if is_python else "node"
                server_params = StdioServerParameters(
                    command=command,
                    args=[server_input],
                    env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"}
                )
                read, write = await self.exit_stack.enter_async_context(
                    stdio_client(server_params)
                )
                connection_type = "Local script"
            self.session = await self.exit_stack.enter_async_context(
                ClientSession(read, write)
            )
            await self.session.initialize()
            response = await self.session.list_tools()
            self.tools = [{
                "name": tool.name,
                "description": tool.description,
                "input_schema": tool.inputSchema
            } for tool in response.tools]
            tool_names = [tool["name"] for tool in self.tools]
            return f"Connected to MCP server via {connection_type}. Available tools: {', '.join(tool_names)}"
        except Exception as e:
            return f"Connection failed: {str(e)}"

    def load_dataset(self) -> tuple:
        """Load the TAAIC Phase1 validation dataset"""
        try:
            self.dataset = load_dataset("aitxchallenge/Phase1_Model_Validator", split="train")
            dataset_info = f"Dataset loaded successfully! {len(self.dataset)} validation cases available."
            # Create a preview of the dataset
            df = pd.DataFrame(self.dataset)
            preview = df.head().to_string()
            return (
                dataset_info,
                gr.Button("Validate", interactive=True),
                gr.Textbox(value=f"Dataset Preview:\n{preview}", visible=True)
            )
        except Exception as e:
            return (
                f"Failed to load dataset: {str(e)}",
                gr.Button("Load Dataset", interactive=True),
                gr.Textbox(visible=False)
            )

    def validate_tools(self) -> str:
        """Run validation on all dataset cases"""
        if not self.anthropic:
            return "Please set your Anthropic API key first."
        if self.dataset is None:
            return "Please load the dataset first."
        if not self.session:
            return "Please connect to an MCP server first."
        return loop.run_until_complete(self._run_validation())
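
    # Runs every dataset case through Claude and the MCP tools, then builds a
    # plain-text summary report.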
    async def _run_validation(self) -> str:
        """Async validation runner"""
        self.validation_results = []
        total_cases = len(self.dataset)
        passed = 0
        failed = 0
        for i, case in enumerate(self.dataset):
            # Extract test case information up front so it is also available for error reporting
            query = case.get('query', case.get('question', ''))
            expected_output = case.get('expected_output', case.get('expected', ''))
            test_id = case.get('id', f'test_{i}')
            try:
                # Run the query through the MCP tools
                result = await self._validate_single_case(query, expected_output, test_id)
                self.validation_results.append(result)
                if result['passed']:
                    passed += 1
                else:
                    failed += 1
            except Exception as e:
                failed += 1
                self.validation_results.append({
                    'test_id': test_id,
                    'query': query,
                    'error': str(e),
                    'passed': False
                })
        # Generate validation report
        success_rate = (passed / total_cases) * 100 if total_cases else 0.0
        report = f"""
VALIDATION COMPLETE
==================
Total Cases: {total_cases}
Passed: {passed}
Failed: {failed}
Success Rate: {success_rate:.1f}%

DETAILED RESULTS:
"""
        for result in self.validation_results:
            status = "✓ PASS" if result['passed'] else "✗ FAIL"
            report += f"\n{status} [{result['test_id']}] {result['query'][:50]}..."
            if not result['passed'] and 'error' in result:
                report += f"\n    Error: {result['error']}"
        return report

    async def _validate_single_case(self, query: str, expected_output: str, test_id: str) -> Dict[str, Any]:
        """Validate a single test case"""
        try:
            # Send query to Claude with MCP tools
            claude_messages = [{"role": "user", "content": query}]
            response = self.anthropic.messages.create(
                model="claude-3-5-sonnet-20241022",
                max_tokens=1000,
                messages=claude_messages,
                tools=self.tools
            )
            # Process tool calls if any
            actual_output = ""
            for content in response.content:
                if content.type == 'text':
                    actual_output += content.text
                elif content.type == 'tool_use':
                    tool_result = await self.session.call_tool(content.name, content.input)
                    actual_output += str(tool_result.content)
            # Simple validation logic - you may want to customize this
            passed = self._validate_output(actual_output, expected_output)
            return {
                'test_id': test_id,
                'query': query,
                'expected': expected_output,
                'actual': actual_output,
                'passed': passed
            }
        except Exception as e:
            return {
                'test_id': test_id,
                'query': query,
                'error': str(e),
                'passed': False
            }

    def _validate_output(self, actual: str, expected: str) -> bool:
        """Basic output validation - customize based on your needs"""
        # This is a simple implementation - you may want more sophisticated validation
        if not expected:
            return True  # If no expected output specified, consider it passed
        # You can implement more sophisticated matching here
        # For now, using simple substring matching
        return expected.lower() in actual.lower()
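
    # Interactive chat path (not wired up in the UI below; see the
    # commented-out msg.submit handler at the bottom of gradio_interface).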
    def process_message(self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]) -> tuple:
        if not self.anthropic:
            return history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": "Please set your Anthropic API key first."}
            ], gr.Textbox(value="")
        if not self.session:
            return history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": "Please connect to an MCP server first."}
            ], gr.Textbox(value="")
        new_messages = loop.run_until_complete(self._process_query(message, history))
        return history + [{"role": "user", "content": message}] + new_messages, gr.Textbox(value="")
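
    # Sends the conversation to Claude, executes any requested MCP tool calls,
    # and renders the results (including image payloads) as chat messages.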
    async def _process_query(self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]):
        claude_messages = []
        for msg in history:
            if isinstance(msg, ChatMessage):
                role, content = msg.role, msg.content
            else:
                role, content = msg.get("role"), msg.get("content")
            if role in ["user", "assistant", "system"]:
                claude_messages.append({"role": role, "content": content})
        claude_messages.append({"role": "user", "content": message})

        response = self.anthropic.messages.create(
            model="claude-3-5-sonnet-20241022",
            max_tokens=1000,
            messages=claude_messages,
            tools=self.tools
        )

        result_messages = []
        for content in response.content:
            if content.type == 'text':
                result_messages.append({
                    "role": "assistant",
                    "content": content.text
                })
            elif content.type == 'tool_use':
                tool_name = content.name
                tool_args = content.input
                result_messages.append({
                    "role": "assistant",
                    "content": f"I'll use the {tool_name} tool to help answer your question.",
                    "metadata": {
                        "title": f"Using tool: {tool_name}",
                        "log": f"Parameters: {json.dumps(tool_args, ensure_ascii=True)}",
                        "status": "pending",
                        "id": f"tool_call_{tool_name}"
                    }
                })
                result_messages.append({
                    "role": "assistant",
                    "content": "```json\n" + json.dumps(tool_args, indent=2, ensure_ascii=True) + "\n```",
                    "metadata": {
                        "parent_id": f"tool_call_{tool_name}",
                        "id": f"params_{tool_name}",
                        "title": "Tool Parameters"
                    }
                })
                try:
                    result = await self.session.call_tool(tool_name, tool_args)
                    if len(result_messages) >= 2 and "metadata" in result_messages[-2]:
                        result_messages[-2]["metadata"]["status"] = "done"
                    result_messages.append({
                        "role": "assistant",
                        "content": "Here are the results from the tool:",
                        "metadata": {
                            "title": f"Tool Result for {tool_name}",
                            "status": "done",
                            "id": f"result_{tool_name}"
                        }
                    })
                    result_content = result.content
                    if isinstance(result_content, list):
                        result_content = "\n".join(str(item) for item in result_content)
                    try:
                        result_json = json.loads(result_content)
                        if isinstance(result_json, dict) and "type" in result_json:
                            if result_json["type"] == "image" and "url" in result_json:
                                result_messages.append({
                                    "role": "assistant",
                                    "content": {"path": result_json["url"], "alt_text": result_json.get("message", "Generated image")},
                                    "metadata": {
                                        "parent_id": f"result_{tool_name}",
                                        "id": f"image_{tool_name}",
                                        "title": "Generated Image"
                                    }
                                })
                            else:
                                result_messages.append({
                                    "role": "assistant",
                                    "content": "```\n" + result_content + "\n```",
                                    "metadata": {
                                        "parent_id": f"result_{tool_name}",
                                        "id": f"raw_result_{tool_name}",
                                        "title": "Raw Output"
                                    }
                                })
                    except Exception:
                        result_messages.append({
                            "role": "assistant",
                            "content": "```\n" + result_content + "\n```",
                            "metadata": {
                                "parent_id": f"result_{tool_name}",
                                "id": f"raw_result_{tool_name}",
                                "title": "Raw Output"
                            }
                        })
                    claude_messages.append({"role": "user", "content": f"Tool result for {tool_name}: {result_content}"})
                    next_response = self.anthropic.messages.create(
                        model="claude-3-5-sonnet-20241022",
                        max_tokens=1000,
                        messages=claude_messages,
                    )
                    if next_response.content and next_response.content[0].type == 'text':
                        result_messages.append({
                            "role": "assistant",
                            "content": next_response.content[0].text
                        })
                except Exception as e:
                    result_messages.append({
                        "role": "assistant",
                        "content": f"Error calling tool {tool_name}: {str(e)}",
                        "metadata": {
                            "title": f"Error - {tool_name}",
                            "status": "error",
                            "id": f"error_{tool_name}"
                        }
                    })
        return result_messages
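

# Single shared client instance used by all Gradio callbacks.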
client = MCPClientWrapper()


def gradio_interface():
    with gr.Blocks(title="TAAIC Tool Validation") as demo:
        gr.Markdown("# TAAIC Tool Validation")
        gr.Markdown("Connect your Gradio MCP tool for validation as part of the TAAIC challenge.")

        # API key input section
        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                api_key_input = gr.Textbox(
                    label="Anthropic API Key",
                    placeholder="Enter your Anthropic API key (sk-ant-...)",
                    type="password"
                )
            with gr.Column(scale=1):
                api_key_btn = gr.Button("Set API Key")
        api_key_status = gr.Textbox(label="API Key Status", interactive=False)

        # MCP server connection section
        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                server_input = gr.Textbox(
                    label="MCP Server URL or Script Path",
                    placeholder="Enter URL (e.g., https://cyrilzakka-clinical-trials.hf.space/gradio_api/mcp/sse) or local script path (e.g., weather.py)",
                    value="https://cyrilzakka-clinical-trials.hf.space/gradio_api/mcp/sse"
                )
            with gr.Column(scale=1):
                connect_btn = gr.Button("Connect")
        status = gr.Textbox(label="Connection Status", interactive=False)

        # Dataset loading section
        with gr.Row(equal_height=True):
            with gr.Column(scale=3):
                dataset_status = gr.Textbox(
                    label="Dataset Status",
                    value="Click 'Load Dataset' to load validation cases",
                    interactive=False
                )
            with gr.Column(scale=1):
                dataset_btn = gr.Button("Load Dataset", interactive=True)
        dataset_preview = gr.Textbox(
            label="Dataset Preview",
            visible=False,
            interactive=False,
            max_lines=10
        )

        # Validation results
        validation_results = gr.Textbox(
            label="Validation Results",
            visible=False,
            interactive=False,
            max_lines=20
        )

        # Event handlers
        api_key_btn.click(client.set_api_key, inputs=api_key_input, outputs=api_key_status)
        connect_btn.click(client.connect, inputs=server_input, outputs=status)
        dataset_btn.click(
            client.load_dataset,
            outputs=[dataset_status, dataset_btn, dataset_preview]
        )

        def run_validation():
            # validate_tools() itself reports a missing API key, dataset, or server connection
            return gr.Textbox(value=client.validate_tools(), visible=True)

        dataset_btn.click(
            run_validation,
            outputs=validation_results,
            show_progress=True
        )

        # msg.submit(client.process_message, [msg, chatbot], [chatbot, msg])
        # clear_btn.click(lambda: [], None, chatbot)
    return demo


if __name__ == "__main__":
    interface = gradio_interface()
    interface.launch(debug=True)