# Ollama API Space — Flask service exposing Ollama model management and
# text-generation endpoints. (Removed non-code scrape residue: "Spaces:
# Sleeping / Sleeping" — Hugging Face Spaces page status text, not source.)
import json
import logging
import os
import subprocess
from typing import Any, Dict, List

import requests
from flask import Flask, request, jsonify

app = Flask(__name__)
logging.basicConfig(level=logging.INFO)

# Configuration — overridable via environment for container deployments.
OLLAMA_BASE_URL = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
MODELS_DIR = os.getenv('MODELS_DIR', '/models')
# Whitelist of models clients are allowed to pull. Strip each entry so
# env values written with spaces ("llama2, codellama") parse correctly.
ALLOWED_MODELS = [
    m.strip()
    for m in os.getenv(
        'ALLOWED_MODELS',
        'llama2,llama2:13b,llama2:70b,codellama,neural-chat,gemma-3-270m',
    ).split(',')
]
class OllamaManager:
    """Thin HTTP client for an Ollama server.

    Wraps the /api/tags, /api/pull and /api/generate endpoints and keeps a
    cached list of model names in ``available_models``.
    """

    def __init__(self, base_url: str):
        self.base_url = base_url
        self.available_models: List[str] = []
        self.refresh_models()

    def refresh_models(self):
        """Refresh the cached list of available models from /api/tags."""
        try:
            resp = requests.get(f"{self.base_url}/api/tags", timeout=10)
            if resp.status_code != 200:
                self.available_models = []
                return
            tags = resp.json()
            self.available_models = [entry['name'] for entry in tags.get('models', [])]
        except Exception as e:
            # Best-effort: a dead Ollama backend degrades to an empty list.
            logging.error(f"Error refreshing models: {e}")
            self.available_models = []

    def list_models(self) -> List[str]:
        """Return the current model list, refreshing the cache first."""
        self.refresh_models()
        return self.available_models

    def pull_model(self, model_name: str) -> Dict[str, Any]:
        """Ask Ollama to pull *model_name*; return a status dict, never raise."""
        try:
            resp = requests.post(
                f"{self.base_url}/api/pull",
                json={"name": model_name},
                timeout=300,  # pulls can be slow: large model downloads
            )
            if resp.status_code == 200:
                return {"status": "success", "model": model_name}
            return {"status": "error", "message": f"Failed to pull model: {resp.text}"}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    def generate(self, model_name: str, prompt: str, **kwargs) -> Dict[str, Any]:
        """Run a non-streaming /api/generate call; extra kwargs are forwarded."""
        payload: Dict[str, Any] = {
            "model": model_name,
            "prompt": prompt,
            "stream": False,
        }
        payload.update(kwargs)
        try:
            resp = requests.post(
                f"{self.base_url}/api/generate",
                json=payload,
                timeout=120,
            )
            if resp.status_code != 200:
                return {"status": "error", "message": f"Generation failed: {resp.text}"}
            body = resp.json()
            return {
                "status": "success",
                "response": body.get('response', ''),
                "model": model_name,
                "usage": body.get('usage', {}),
            }
        except Exception as e:
            return {"status": "error", "message": str(e)}
# Initialize Ollama manager
# Module-level singleton shared by all request handlers below.
ollama_manager = OllamaManager(OLLAMA_BASE_URL)
@app.route('/')
def home():
    """Home page with API documentation.

    Fix: the handler was never registered with Flask — without
    ``@app.route`` the app served no routes at all.
    """
    return '''
<!DOCTYPE html>
<html>
<head>
    <title>Ollama API Space</title>
    <style>
        body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
        .endpoint { background: #f5f5f5; padding: 15px; margin: 10px 0; border-radius: 5px; }
        .method { background: #007bff; color: white; padding: 2px 8px; border-radius: 3px; font-size: 12px; }
        .url { font-family: monospace; background: #e9ecef; padding: 2px 6px; border-radius: 3px; }
    </style>
</head>
<body>
    <h1>π Ollama API Space</h1>
    <p>This Space provides API endpoints for Ollama model management and inference.</p>
    <h2>Available Endpoints</h2>
    <div class="endpoint">
        <span class="method">GET</span> <span class="url">/api/models</span>
        <p>List all available models</p>
    </div>
    <div class="endpoint">
        <span class="method">POST</span> <span class="url">/api/models/pull</span>
        <p>Pull a model from Ollama</p>
        <p>Body: {"name": "model_name"}</p>
    </div>
    <div class="endpoint">
        <span class="method">POST</span> <span class="url">/api/generate</span>
        <p>Generate text using a model</p>
        <p>Body: {"model": "model_name", "prompt": "your prompt"}</p>
    </div>
    <div class="endpoint">
        <span class="method">GET</span> <span class="url">/health</span>
        <p>Health check endpoint</p>
    </div>
    <h2>Usage Examples</h2>
    <p>You can use this API with OpenWebUI or any other client that supports REST APIs.</p>
    <h3>cURL Examples</h3>
    <pre>
# List models
curl https://your-space-url.hf.space/api/models
# Generate text
curl -X POST https://your-space-url.hf.space/api/generate \\
  -H "Content-Type: application/json" \\
  -d '{"model": "llama2", "prompt": "Hello, how are you?"}'
    </pre>
</body>
</html>
'''
@app.route('/api/models', methods=['GET'])
def list_models():
    """List all available models.

    Fix: registered the route (GET /api/models, as advertised by the home
    page) — the handler previously existed but was never wired to Flask.

    Returns JSON {"status", "models", "count"}, or 500 with an error message.
    """
    try:
        models = ollama_manager.list_models()
        return jsonify({
            "status": "success",
            "models": models,
            "count": len(models)
        })
    except Exception as e:
        # Top-level handler boundary: surface the failure as JSON, not a 500 page.
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/models/pull', methods=['POST'])
def pull_model():
    """Pull a model from Ollama.

    Fix: registered the route (POST /api/models/pull, as advertised by the
    home page) — the handler was never wired to Flask.

    Expects JSON body {"name": "<model>"}. The model must appear in
    ALLOWED_MODELS; otherwise 400. Proxies the pull result from Ollama.
    """
    try:
        data = request.get_json()
        if not data or 'name' not in data:
            return jsonify({"status": "error", "message": "Model name is required"}), 400
        model_name = data['name']
        # Whitelist check: prevents clients from pulling arbitrary models.
        if model_name not in ALLOWED_MODELS:
            return jsonify({"status": "error", "message": f"Model {model_name} not in allowed list"}), 400
        result = ollama_manager.pull_model(model_name)
        if result["status"] == "success":
            return jsonify(result), 200
        else:
            return jsonify(result), 500
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/generate', methods=['POST'])
def generate_text():
    """Generate text using a model.

    Fix: registered the route (POST /api/generate, as advertised by the
    home page) — the handler was never wired to Flask.

    Expects JSON body {"model": ..., "prompt": ..., [extra options]}.
    Extra keys are forwarded to Ollama as generation options.
    """
    try:
        data = request.get_json()
        if not data or 'model' not in data or 'prompt' not in data:
            return jsonify({"status": "error", "message": "Model name and prompt are required"}), 400
        model_name = data['model']
        prompt = data['prompt']
        # Forward any additional generation parameters (temperature, etc.)
        # — everything except the two required keys.
        kwargs = {k: v for k, v in data.items() if k not in ['model', 'prompt']}
        result = ollama_manager.generate(model_name, prompt, **kwargs)
        if result["status"] == "success":
            return jsonify(result), 200
        else:
            return jsonify(result), 500
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint.

    Fix: registered the route (GET /health, as advertised by the home
    page) — the handler was never wired to Flask.

    Probes the Ollama backend directly; 200 when reachable, 503 otherwise.
    """
    try:
        # Try to connect to Ollama
        response = requests.get(f"{OLLAMA_BASE_URL}/api/tags", timeout=5)
        if response.status_code == 200:
            return jsonify({
                "status": "healthy",
                "ollama_connection": "connected",
                # NOTE(review): reports the cached count, which may be stale
                # relative to the probe above — confirm this is intentional.
                "available_models": len(ollama_manager.available_models)
            })
        else:
            return jsonify({
                "status": "unhealthy",
                "ollama_connection": "failed",
                "error": f"Ollama returned status {response.status_code}"
            }), 503
    except Exception as e:
        return jsonify({
            "status": "unhealthy",
            "ollama_connection": "failed",
            "error": str(e)
        }), 503
if __name__ == '__main__':
    # Bind to all interfaces on port 7860 — presumably the Hugging Face
    # Spaces convention suggested by the home-page URLs; confirm for other
    # deployments. debug=False: never enable the Werkzeug debugger in a
    # publicly reachable container.
    app.run(host='0.0.0.0', port=7860, debug=False)