File size: 4,205 Bytes
c7829ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
Local FastMCP client used by the Gradio simulator service.

This module provides a lightweight in-process client that routes every request
through the registered FastMCP server instance so the application can rely on
actual MCP tool invocations (and capture request/response telemetry) instead
of calling helper functions directly.
"""

from __future__ import annotations

import json
import threading
import time
from copy import deepcopy
from datetime import datetime, timezone
from typing import Any, Callable, Sequence

import anyio

from mcp.types import ContentBlock


class LocalFastMCPClient:
    """Synchronous helper that forwards calls to a FastMCP server instance.

    Each public method runs the server coroutine to completion with
    ``anyio.run`` while holding an internal lock, so calls issued through one
    client instance are executed one at a time.
    """

    def __init__(self, server, log_callback: Callable[[dict[str, Any]], None] | None = None):
        """Remember the server and the optional telemetry callback.

        Args:
            server: Object exposing awaitable ``list_tools()`` and
                ``call_tool(name, arguments)`` methods.
            log_callback: Optional callable that receives one entry dict for
                every completed ``call_tool`` invocation.
        """
        self._server = server
        self._log_callback = log_callback
        # Serializes the anyio.run() calls issued through this client.
        self._lock = threading.Lock()

    def list_tools(self) -> Any:
        """Expose server tool metadata (used for debugging/tests)."""

        async def _list_tools():
            return await self._server.list_tools()

        with self._lock:
            return anyio.run(_list_tools)

    def call_tool(self, name: str, **arguments: Any) -> dict[str, Any]:
        """Invoke an MCP tool and return a normalized response.

        Args:
            name: Name of the tool to call on the server.
            **arguments: Tool arguments; entries whose value is ``None`` are
                dropped before the call.

        Returns:
            The server result passed through ``_normalize_result``.

        Raises:
            Whatever the server call raises; in that case nothing is logged.
        """
        clean_args = {key: value for key, value in arguments.items() if value is not None}
        # Started before taking the lock, so the logged duration includes any
        # time spent waiting for a concurrent call to finish.
        start = time.perf_counter()

        async def _call():
            return await self._server.call_tool(name, clean_args)

        with self._lock:
            raw_result = anyio.run(_call)

        normalized = self._normalize_result(raw_result)
        self._log(name, clean_args, normalized, start)
        return normalized

    # --------------------------------------------------------------------- #
    # Internal utilities
    # --------------------------------------------------------------------- #
    def _normalize_result(self, result: Any) -> dict[str, Any]:
        """Convert FastMCP responses into standard dicts for easier handling.

        Handling, in order:

        * a ``dict`` is returned unchanged (same object, not copied);
        * a ``(content_blocks, structured_dict)`` pair is unwrapped to its
          content blocks, or to the structured dict when there are no blocks;
        * a non-string sequence of blocks is returned as the JSON parsed from
          its first text block when possible, otherwise as
          ``{"status": "ok", "content": [...]}``;
        * anything else (including ``str``/``bytes``) is wrapped as
          ``{"status": "ok", "data": <deep copy>}``.

        NOTE: when the first text block holds a JSON array, the parsed list is
        returned as-is, so the result is not strictly always a dict.
        """
        if isinstance(result, dict):
            return result

        # FastMCP tools with structured output can hand back both the content
        # blocks and the structured payload as a 2-tuple; without unwrapping,
        # each half would be stringified into a meaningless text block.
        if self._is_structured_pair(result):
            content, structured = result
            if not content:
                return structured
            result = content

        # str/bytes are Sequences too, but iterating them would yield one
        # "block" per character, so they are treated as scalar data below.
        if isinstance(result, Sequence) and not isinstance(result, (str, bytes, bytearray)):
            parsed = self._maybe_parse_json_from_blocks(result)
            if parsed is not None:
                return parsed
            # Every MCP content block is a pydantic model, so duck-typing on
            # ``model_dump`` covers them without a separate isinstance check.
            blocks: list[dict[str, Any]] = [
                block.model_dump(mode="json")
                if hasattr(block, "model_dump")
                else {"type": "text", "text": str(block)}
                for block in result
            ]
            return {"status": "ok", "content": blocks}

        return {"status": "ok", "data": deepcopy(result)}

    @staticmethod
    def _is_structured_pair(result: Any) -> bool:
        """Return True for a ``(block_sequence, structured_dict)`` 2-tuple."""
        if not isinstance(result, tuple) or len(result) != 2:
            return False
        content, structured = result
        return (
            isinstance(structured, dict)
            and isinstance(content, Sequence)
            and not isinstance(content, (str, bytes, bytearray))
        )

    def _maybe_parse_json_from_blocks(self, blocks: Sequence[Any]) -> dict[str, Any] | None:
        """Parse JSON out of the first block when it is a JSON text block.

        Only ``blocks[0]`` is inspected. Its text comes from the ``"text"``
        key for dict blocks of type ``"text"``, or from the ``text`` attribute
        for any other object.

        Returns:
            The decoded JSON value (a dict, or a list for a JSON array), or
            ``None`` when there is no first block, no string text, the text
            does not start with ``{``/``[``, or decoding fails.
        """
        if not blocks:
            return None
        first = blocks[0]
        if isinstance(first, dict):
            text = first.get("text") if first.get("type") == "text" else None
        else:
            # Text content blocks expose their payload as ``.text``.
            text = getattr(first, "text", None)
        # Guard against non-string payloads, which have no ``.strip()``.
        if not isinstance(text, str):
            return None
        stripped = text.strip()
        if not stripped or stripped[0] not in "{[":
            return None
        try:
            return json.loads(stripped)
        except json.JSONDecodeError:
            return None

    def _log(self, name: str, arguments: dict[str, Any], result: dict[str, Any], start: float) -> None:
        """Send invocation metadata to the optional callback.

        Args:
            name: Tool name that was invoked.
            arguments: Arguments sent to the tool (deep-copied into the entry).
            result: Normalized result (deep-copied into the entry).
            start: ``time.perf_counter()`` reading taken when the call began.
        """
        if not self._log_callback:
            return
        duration_ms = round((time.perf_counter() - start) * 1000, 1)
        # datetime.utcnow() is deprecated; take an aware UTC "now" and drop
        # the tzinfo so the emitted string keeps the same naive-UTC format.
        timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        entry = {
            "timestamp": timestamp,
            "tool": name,
            "arguments": deepcopy(arguments),
            "result": deepcopy(result),
            "duration_ms": duration_ms,
        }
        self._log_callback(entry)