import gradio as gr
import subprocess
import time
import requests
import json
import sys
import os
import asyncio
import aiohttp

# Paths and server configuration.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_FILENAME = "GigaChat3-10B-A1.8B-Q8_0.gguf"
MODEL_PATH = os.path.join(BASE_DIR, MODEL_FILENAME)
TEMPLATE_PATH = os.path.join(BASE_DIR, "chat_template.jinja")
SERVER_PORT = 8080
API_BASE = f"http://localhost:{SERVER_PORT}/v1"


def start_llama_server():
    """Launch llama-server as a subprocess and block until it reports healthy."""
    if not os.path.exists(MODEL_PATH):
        print(f"CRITICAL ERROR: Model not found at {MODEL_PATH}")
        sys.exit(1)
    if not os.path.exists(TEMPLATE_PATH):
        print(f"CRITICAL ERROR: Template not found at {TEMPLATE_PATH}")
        sys.exit(1)

    llama_bin_path = "/app/build/bin/llama-server"
    cmd = [
        llama_bin_path,
        "-m", MODEL_PATH,
        "--chat-template-file", TEMPLATE_PATH,
        "--jinja",
        "-cmoe",
        "--port", str(SERVER_PORT),
        "--host", "0.0.0.0",
        "-c", "8192",
        "-np", "1",
        "--threads", "2",
        "-b", "512",
    ]
    print(f"Starting server with command: {' '.join(cmd)}")

    env = os.environ.copy()
    env["LD_LIBRARY_PATH"] = "/app/build/bin"
    process = subprocess.Popen(
        cmd,
        cwd="/app/build/bin",
        env=env,
        stdout=sys.stdout,
        stderr=sys.stderr,
    )

    # Poll the health endpoint for up to 90 seconds, printing a progress dot
    # every 5 attempts. A bare "except:" here would also swallow
    # KeyboardInterrupt, so catch only requests errors.
    print("Waiting for server to become healthy...")
    for i in range(90):
        try:
            resp = requests.get(f"http://localhost:{SERVER_PORT}/health", timeout=2)
            if resp.status_code == 200:
                print("\nServer is ready!")
                return process
        except requests.exceptions.RequestException:
            pass
        time.sleep(1)
        if i % 5 == 0:
            print(".", end="", flush=True)

    print("\nServer failed to start within timeout.")
    process.terminate()
    raise RuntimeError("Server failed to start")


server_process = start_llama_server()


async def chat_with_model(message, history):
    # Normalize the Gradio history into OpenAI-style chat messages. Gradio may
    # hand us the "messages" format (list of dicts), the legacy "tuples"
    # format, or a flat list of alternating user/assistant strings.
    messages = []
    if history:
        if isinstance(history[0], dict):
            for msg in history:
                role = msg.get("role")
                content_data = msg.get("content")
                content_str = ""
                if isinstance(content_data, str):
                    content_str = content_data
                elif isinstance(content_data, list):
                    # Multimodal content: keep only the text parts.
                    for part in content_data:
                        if isinstance(part, dict) and "text" in part:
                            content_str += part["text"]
                if role and content_str:
                    messages.append({"role": role, "content": content_str})
        elif isinstance(history[0], (list, tuple)):
            for item in history:
                if len(item) >= 2:
                    user_msg = item[0]
                    assistant_msg = item[1]
                    if user_msg and assistant_msg:
                        messages.append({"role": "user", "content": str(user_msg)})
                        messages.append({"role": "assistant", "content": str(assistant_msg)})
        elif isinstance(history[0], str):
            for i in range(0, len(history), 2):
                if i + 1 < len(history):
                    messages.append({"role": "user", "content": str(history[i])})
                    messages.append({"role": "assistant", "content": str(history[i + 1])})

    messages.append({"role": "user", "content": message})
    print(f"DEBUG: Sending {len(messages)} messages. Prompt caching should work now.")

    partial_text = ""
    timeout = aiohttp.ClientTimeout(total=600)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(
                f"{API_BASE}/chat/completions",
                json={
                    "messages": messages,
                    "temperature": 0.5,
                    "top_p": 0.95,
                    "max_tokens": 1024,
                    "stream": True,
                },
            ) as response:
                if response.status != 200:
                    yield f"Error: Server returned status {response.status}"
                    return
                # The server streams SSE lines of the form "data: {json}".
                async for line in response.content:
                    line = line.decode("utf-8").strip()
                    if not line:
                        continue
                    if line.startswith("data: "):
                        json_str = line[6:]
                        if json_str == "[DONE]":
                            break
                        try:
                            chunk = json.loads(json_str)
                            if "choices" in chunk and chunk["choices"]:
                                delta = chunk["choices"][0].get("delta", {})
                                content = delta.get("content")
                                if content:
                                    partial_text += content
                                    yield partial_text
                        except json.JSONDecodeError:
                            continue
    except asyncio.CancelledError:
        print("User stopped generation.")
        return
    except Exception as e:
        print(f"Error: {e}")
        if partial_text:
            # Surface whatever was generated before the failure.
            yield partial_text
        else:
            yield f"Error: {str(e)}"
        return


demo = gr.ChatInterface(
    fn=chat_with_model,
    title="GigaChat3-10B-A1.8B (Q8_0)",
    description="Running with llama.cpp b7130 on CPU",
    examples=["What is GigaChat?", "Write Python code", "What is quantum mechanics?"],
    concurrency_limit=1,
)

if __name__ == "__main__":
    try:
        demo.launch(server_name="0.0.0.0", server_port=7860)
    finally:
        # Make sure the llama-server subprocess dies with the app.
        if server_process:
            server_process.terminate()
            server_process.wait()
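
# Quick manual smoke test for the upstream llama-server (a sketch, assuming
# the server is reachable on localhost:8080 as configured above; the payload
# below is illustrative, not the exact request this app sends):
#
#   curl http://localhost:8080/health
#   curl http://localhost:8080/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Hello"}], "max_tokens": 32}'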