import gradio as gr
import subprocess
import time
import requests
import json
import sys
import os
import asyncio
import aiohttp
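
# Model, template, and server configuration. The GGUF model and Jinja chat
# template are expected to sit next to this script; llama-server listens on a
# local port that the Gradio app proxies requests to.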
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_FILENAME = "GigaChat3-10B-A1.8B-Q8_0.gguf"
MODEL_PATH = os.path.join(BASE_DIR, MODEL_FILENAME)
TEMPLATE_PATH = os.path.join(BASE_DIR, "chat_template.jinja")
SERVER_PORT = 8080
API_BASE = f"http://localhost:{SERVER_PORT}/v1"
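

# Launch llama-server as a background subprocess and block until its /health
# endpoint reports ready, so the Gradio app never starts without a backend.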
def start_llama_server():
    if not os.path.exists(MODEL_PATH):
        print(f"CRITICAL ERROR: Model not found at {MODEL_PATH}")
        sys.exit(1)
    if not os.path.exists(TEMPLATE_PATH):
        print(f"CRITICAL ERROR: Template not found at {TEMPLATE_PATH}")
        sys.exit(1)
    llama_bin_path = "/app/build/bin/llama-server"
    cmd = [
        llama_bin_path,
        "-m", MODEL_PATH,
        "--chat-template-file", TEMPLATE_PATH,
        "--jinja",                    # render the chat template with llama.cpp's Jinja engine
        "-cmoe",                      # --cpu-moe: keep MoE expert weights in CPU memory
        "--port", str(SERVER_PORT),
        "--host", "0.0.0.0",
        "-c", "8192",                 # context window (tokens)
        "-np", "1",                   # a single parallel slot
        "--threads", "2",
        "-b", "512",                  # logical batch size
    ]
    print(f"Starting server with command: {' '.join(cmd)}")

    env = os.environ.copy()
    env['LD_LIBRARY_PATH'] = '/app/build/bin'
    process = subprocess.Popen(
        cmd,
        cwd="/app/build/bin",
        env=env,
        stdout=sys.stdout,
        stderr=sys.stderr
    )
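
    # Poll the /health endpoint for up to 90 seconds; llama-server answers
    # with HTTP 200 once the model has finished loading.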
print("Waiting for server to become healthy...")
for i in range(90):
try:
resp = requests.get(f"http://localhost:{SERVER_PORT}/health", timeout=2)
if resp.status_code == 200:
print("\nServer is ready!")
return process
except:
pass
time.sleep(1)
if i % 5 == 0:
print(".", end="", flush=True)
print("\nServer failed to start within timeout.")
process.terminate()
raise RuntimeError("Server failed to start")
server_process = start_llama_server()
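

# Normalize whatever history format this Gradio version hands us (dicts in
# "messages" mode, [user, assistant] pairs, or a flat list of strings) into
# OpenAI-style messages, then stream the completion back token by token.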
async def chat_with_model(message, history):
    messages = []
    if history:
        if isinstance(history[0], dict):
            # "messages" format: [{"role": ..., "content": ...}, ...]
            for msg in history:
                role = msg.get('role')
                content_data = msg.get('content')
                content_str = ""
                if isinstance(content_data, str):
                    content_str = content_data
                elif isinstance(content_data, list):
                    # Multi-part content: concatenate the text parts.
                    for part in content_data:
                        if isinstance(part, dict) and 'text' in part:
                            content_str += part['text']
                if role and content_str:
                    messages.append({"role": role, "content": content_str})
        elif isinstance(history[0], (list, tuple)):
            # Legacy "tuples" format: [[user, assistant], ...]
            for item in history:
                if len(item) >= 2:
                    user_msg = item[0]
                    assistant_msg = item[1]
                    if user_msg and assistant_msg:
                        messages.append({"role": "user", "content": str(user_msg)})
                        messages.append({"role": "assistant", "content": str(assistant_msg)})
        elif isinstance(history[0], str):
            # Flat list of alternating user/assistant strings.
            for i in range(0, len(history), 2):
                if i + 1 < len(history):
                    messages.append({"role": "user", "content": str(history[i])})
                    messages.append({"role": "assistant", "content": str(history[i + 1])})
    messages.append({"role": "user", "content": message})
    print(f"DEBUG: Sending {len(messages)} messages (stable prefix enables llama-server prompt caching).")
partial_text = ""
timeout = aiohttp.ClientTimeout(total=600)
try:
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(
f"{API_BASE}/chat/completions",
json={
"messages": messages,
"temperature": 0.5,
"top_p": 0.95,
"max_tokens": 1024,
"stream": True
}
) as response:
if response.status != 200:
yield f"Error: Server returned status {response.status}"
return
async for line in response.content:
line = line.decode('utf-8').strip()
if not line:
continue
if line.startswith("data: "):
json_str = line[6:]
if json_str == "[DONE]":
break
try:
chunk = json.loads(json_str)
if "choices" in chunk and chunk["choices"]:
delta = chunk["choices"][0].get("delta", {})
content = delta.get("content")
if content:
partial_text += content
yield partial_text
except json.JSONDecodeError:
continue
except asyncio.CancelledError:
print("User stopped generation.")
return
except Exception as e:
print(f"Error: {e}")
if partial_text:
yield partial_text
else:
yield f"Error: {str(e)}"
return
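

# concurrency_limit=1 matches the single slot ("-np 1") given to llama-server,
# so concurrent requests queue in Gradio instead of colliding at the backend.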
demo = gr.ChatInterface(
    fn=chat_with_model,
    title="GigaChat3-10B-A1.8B (Q8_0)",
    description="Running with llama.cpp b7130 on CPU",
    examples=["What is GigaChat?", "Write Python code", "What is quantum mechanics?"],
    concurrency_limit=1
)
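
# Make sure the llama-server subprocess is torn down when Gradio exits.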
if __name__ == "__main__":
    try:
        demo.launch(server_name="0.0.0.0", server_port=7860)
    finally:
        if server_process:
            server_process.terminate()
            server_process.wait()