import gradio as gr
import subprocess
import time
import requests
import json
import sys
import os
import asyncio
import aiohttp
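
# Model file, chat template, and llama-server settings. API_BASE points at the
# OpenAI-compatible /v1 endpoint that llama-server exposes on localhost.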
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_FILENAME = "GigaChat3-10B-A1.8B-Q8_0.gguf"
MODEL_PATH = os.path.join(BASE_DIR, MODEL_FILENAME)
TEMPLATE_PATH = os.path.join(BASE_DIR, "chat_template.jinja")
SERVER_PORT = 8080
API_BASE = f"http://localhost:{SERVER_PORT}/v1"

def start_llama_server():
    """Launch llama-server as a subprocess and wait until it reports healthy."""
    if not os.path.exists(MODEL_PATH):
        print(f"CRITICAL ERROR: Model not found at {MODEL_PATH}")
        sys.exit(1)
    if not os.path.exists(TEMPLATE_PATH):
        print(f"CRITICAL ERROR: Template not found at {TEMPLATE_PATH}")
        sys.exit(1)

    llama_bin_path = "/app/build/bin/llama-server"
    cmd = [
        llama_bin_path,
        "-m", MODEL_PATH,
        "--chat-template-file", TEMPLATE_PATH,
        "--jinja",
        "-cmoe",
        "--port", str(SERVER_PORT),
        "--host", "0.0.0.0",
        "-c", "8192",
        "-np", "1",
        "--threads", "2",
        "-b", "512",
    ]
| print(f"Starting server with command: {' '.join(cmd)}") | |
| env = os.environ.copy() | |
| env['LD_LIBRARY_PATH'] = '/app/build/bin' | |
| process = subprocess.Popen( | |
| cmd, | |
| cwd="/app/build/bin", | |
| env=env, | |
| stdout=sys.stdout, | |
| stderr=sys.stderr | |
| ) | |
| print("Waiting for server to become healthy...") | |
| for i in range(90): | |
| try: | |
| resp = requests.get(f"http://localhost:{SERVER_PORT}/health", timeout=2) | |
| if resp.status_code == 200: | |
| print("\nServer is ready!") | |
| return process | |
| except: | |
| pass | |
| time.sleep(1) | |
| if i % 5 == 0: | |
| print(".", end="", flush=True) | |
| print("\nServer failed to start within timeout.") | |
| process.terminate() | |
| raise RuntimeError("Server failed to start") | |
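
# Start the llama.cpp backend once at import time; all Gradio requests share it.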
server_process = start_llama_server()


async def chat_with_model(message, history):
    # Normalize the Gradio history into OpenAI-style chat messages. Depending on
    # the Gradio version and configuration, history arrives as dicts ("messages"
    # format), (user, assistant) tuples, or a flat list of alternating strings.
    messages = []
    if history:
        if isinstance(history[0], dict):
            for msg in history:
                role = msg.get('role')
                content_data = msg.get('content')
                content_str = ""
                if isinstance(content_data, str):
                    content_str = content_data
                elif isinstance(content_data, list):
                    # Multimodal-style content: concatenate the text parts.
                    for part in content_data:
                        if isinstance(part, dict) and 'text' in part:
                            content_str += part['text']
                if role and content_str:
                    messages.append({"role": role, "content": content_str})
        elif isinstance(history[0], (list, tuple)):
            for item in history:
                if len(item) >= 2:
                    user_msg = item[0]
                    assistant_msg = item[1]
                    if user_msg and assistant_msg:
                        messages.append({"role": "user", "content": str(user_msg)})
                        messages.append({"role": "assistant", "content": str(assistant_msg)})
        elif isinstance(history[0], str):
            for i in range(0, len(history), 2):
                if i + 1 < len(history):
                    messages.append({"role": "user", "content": str(history[i])})
                    messages.append({"role": "assistant", "content": str(history[i + 1])})

    messages.append({"role": "user", "content": message})
| print(f"DEBUG: Sending {len(messages)} messages. Prompt caching should work now.") | |
| partial_text = "" | |
| timeout = aiohttp.ClientTimeout(total=600) | |
| try: | |
| async with aiohttp.ClientSession(timeout=timeout) as session: | |
| async with session.post( | |
| f"{API_BASE}/chat/completions", | |
| json={ | |
| "messages": messages, | |
| "temperature": 0.5, | |
| "top_p": 0.95, | |
| "max_tokens": 1024, | |
| "stream": True | |
| } | |
| ) as response: | |
| if response.status != 200: | |
| yield f"Error: Server returned status {response.status}" | |
| return | |
| async for line in response.content: | |
| line = line.decode('utf-8').strip() | |
| if not line: | |
| continue | |
| if line.startswith("data: "): | |
| json_str = line[6:] | |
| if json_str == "[DONE]": | |
| break | |
| try: | |
| chunk = json.loads(json_str) | |
| if "choices" in chunk and chunk["choices"]: | |
| delta = chunk["choices"][0].get("delta", {}) | |
| content = delta.get("content") | |
| if content: | |
| partial_text += content | |
| yield partial_text | |
| except json.JSONDecodeError: | |
| continue | |
    except asyncio.CancelledError:
        print("User stopped generation.")
        return
    except Exception as e:
        print(f"Error: {e}")
        if partial_text:
            yield partial_text
        else:
            yield f"Error: {str(e)}"
        return
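

# Chat UI; responses stream incrementally from the async generator above.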
demo = gr.ChatInterface(
    fn=chat_with_model,
    title="GigaChat3-10B-A1.8B (Q8_0)",
    description="Running with llama.cpp b7130 on CPU",
    examples=["What is GigaChat?", "Write Python code", "What is quantum mechanics?"],
    concurrency_limit=1
)

if __name__ == "__main__":
    try:
        demo.launch(server_name="0.0.0.0", server_port=7860)
    finally:
        # Make sure the llama-server subprocess shuts down with the app.
        if server_process:
            server_process.terminate()
            server_process.wait()