| import os |
| import io |
| import json |
| import uuid |
| import base64 |
| import time |
| import random |
| import math |
| from typing import List, Dict, Tuple, Optional |
|
|
| import gradio as gr |
| import spaces |
|
|
| |
| |
# The Ollama client is a hard requirement: abort start-up with an actionable
# message instead of failing later on the first request.
try:
    from ollama import Client
except Exception as import_error:
    raise RuntimeError(
        "Failed to import the 'ollama' Python client. Ensure it's in requirements.txt."
    ) from import_error
|
|
# Port the Gradio server binds to (overridable through the PORT env var).
DEFAULT_PORT = int(os.getenv("PORT", 7860))

# Ollama endpoint: OLLAMA_HOST wins, then OLLAMA_BASE_URL; an empty string
# means "no explicit host configured".
DEFAULT_OLLAMA_HOST = (
    os.getenv("OLLAMA_HOST", "").strip()
    or os.getenv("OLLAMA_BASE_URL", "").strip()
    or ""
)

# Model preselected in the UI when the server offers it.
DEFAULT_MODEL = os.getenv("OLLAMA_MODEL", "llama3.1")

APP_TITLE = "Ollama Chat (Gradio + Docker)"
APP_DESCRIPTION = """
A lightweight, fully functional chat UI for Ollama, designed to run on Hugging Face Spaces (Docker).
- Bring your own Ollama host (set OLLAMA_HOST in repo secrets or via the UI).
- Streamed responses, model management (list/pull), and basic vision support (image input).
- Compatible with Spaces ZeroGPU via a @spaces.GPU-decorated function (see GPU Tools panel).
"""
|
|
|
|
def ensure_scheme(host: str) -> str:
    """
    Normalize a user-supplied Ollama host string.

    Surrounding whitespace is removed, "http://" is prepended when the value
    carries neither an "http://" nor an "https://" prefix, and trailing
    slashes are dropped.

    Returns an empty string for empty, whitespace-only or None input, which
    callers treat as "no host configured".
    """
    # Strip before the emptiness check: a whitespace-only value must not be
    # turned into the bogus URL "http:".
    host = (host or "").strip()
    if not host:
        return ""
    if not host.startswith(("http://", "https://")):
        host = "http://" + host
    return host.rstrip("/")
|
|
|
|
def get_client(host: str) -> Client:
    """
    Build an Ollama client for *host*.

    The host is normalized with ensure_scheme(); when nothing usable remains
    the client is created without arguments so it applies its own defaults.
    """
    normalized = ensure_scheme(host)
    return Client(host=normalized) if normalized else Client()
|
|
|
|
def list_models(host: str) -> Tuple[List[str], Optional[str]]:
    """
    List the models available on the Ollama server at *host*.

    Returns:
        (names, None) with the sorted model names on success, or
        ([], error_message) when the request fails for any reason.
    """
    try:
        data = get_client(host).list()
        names = []
        for entry in data.get("models") or []:
            # Read "model" first and fall back to "name" so both response
            # shapes are accepted.
            # NOTE(review): newer ollama client releases appear to populate
            # only "model" - confirm against the pinned client version.
            name = entry.get("model") or entry.get("name")
            if name:
                names.append(name)
        return sorted(names), None
    except Exception as e:
        # Best-effort by design: the UI shows the message instead of crashing.
        return [], f"Unable to list models from {host or '(env default)'}: {e}"
|
|
|
|
def test_connection(host: str) -> Tuple[bool, str]:
    """
    Check whether the Ollama server at *host* answers a model-list request.

    Returns (ok, message): ok is False only when listing failed; an empty
    model list still counts as a successful connection.
    """
    names, err = list_models(host)
    if err:
        return False, err

    target = host or '(env default)'
    if names:
        return True, f"Connected to {target}; found {len(names)} models."
    return True, f"Connected to {target} but no models found. Pull one to continue."
|
|
|
|
def show_model(host: str, model: str) -> Tuple[Optional[dict], Optional[str]]:
    """
    Fetch the server's description of *model*.

    Returns (info, None) on success or (None, error_message) when the
    client call raises.
    """
    try:
        return get_client(host).show(model=model), None
    except Exception as exc:
        return None, f"Unable to show model '{model}': {exc}"
|
|
|
|
def pull_model(host: str, model: str):
    """
    Generator that pulls a model on the remote Ollama host, yielding progress strings.

    Yields a hint and stops when *model* is empty, reports models that are
    already present without pulling, and otherwise yields one line per
    progress update followed by a completion line. Any exception is reported
    as a final "Error pulling ..." line instead of being raised.
    """
    if not model:
        yield "Provide a model name to pull (e.g., llama3.1, mistral, qwen2.5:latest)"
        return
    try:
        client = get_client(host)
        already, _ = show_model(host, model)
        if already:
            yield f"Model '{model}' already present on the host."
            return

        yield f"Pulling '{model}' from registry..."
        for part in client.pull(model=model, stream=True):
            # "or" rather than a .get() default: a field that is present but
            # None must still become a usable value, otherwise the division
            # below raises TypeError and aborts progress reporting.
            status = part.get("status") or ""
            total = part.get("total") or 0
            completed = part.get("completed") or 0
            line = status
            if total:
                line += f" ({(completed / total * 100):.1f}%)"
            yield line
        yield f"Finished pulling '{model}'."
    except Exception as e:
        yield f"Error pulling '{model}': {e}"
|
|
|
|
def encode_image_to_base64(path: str) -> Optional[str]:
    """
    Read the file at *path* and return its contents as a base64 text string.

    Returns None when the file cannot be read or encoded (best-effort: the
    caller simply skips such images).
    """
    try:
        with open(path, "rb") as image_file:
            raw_bytes = image_file.read()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode("utf-8")
    except Exception:
        return None
|
|
|
|
| def build_ollama_messages( |
| system_prompt: str, |
| convo_messages: List[Dict], |
| user_text: str, |
| image_paths: Optional[List[str]] = None, |
| ) -> List[Dict]: |
| """ |
| Returns the full message list to send to Ollama, including system prompt (if provided), |
| past conversation, and the new user message. |
| """ |
| messages = [] |
| if system_prompt.strip(): |
| messages.append({"role": "system", "content": system_prompt.strip()}) |
|
|
| messages.extend(convo_messages or []) |
|
|
| msg: Dict = {"role": "user", "content": user_text or ""} |
| if image_paths: |
| images_b64 = [] |
| for p in image_paths: |
| b64 = encode_image_to_base64(p) |
| if b64: |
| images_b64.append(b64) |
| if images_b64: |
| msg["images"] = images_b64 |
| messages.append(msg) |
| return messages |
|
|
|
|
def messages_for_chatbot(
    text: str,
    image_paths: Optional[List[str]] = None,
    role: str = "user",
) -> Dict:
    """
    Build one chat-display message: {"role": role, "content": [parts...]}.

    The stripped text (when non-empty) becomes a {"type": "text"} part and
    every entry of *image_paths* becomes a {"type": "image"} part holding the
    path as given. If there would be no parts at all, a single empty text
    part is used so the content list is never empty.
    """
    parts: List[Dict] = []

    stripped = (text or "").strip()
    if stripped:
        parts.append({"type": "text", "text": stripped})

    parts.extend({"type": "image", "image": path} for path in image_paths or [])

    if not parts:
        parts = [{"type": "text", "text": ""}]
    return {"role": role, "content": parts}
|
|
|
|
def stream_chat(
    host: str,
    model: str,
    system_prompt: str,
    temperature: float,
    top_p: float,
    top_k: int,
    repeat_penalty: float,
    num_ctx: int,
    max_tokens: Optional[int],
    seed: Optional[int],
    convo_messages: List[Dict],
    chatbot_history: List[Dict],
    user_text: str,
    image_files: Optional[List[str]],
):
    """
    Stream a chat completion from Ollama and update Gradio Chatbot incrementally.

    Yields (chatbot_history, status_text, convo_messages) tuples:
      * once immediately, with the user message and an empty assistant
        placeholder appended to the display history;
      * after every streamed chunk that carries text;
      * once more at the end, with the user/assistant turn appended to
        convo_messages.

    A blank message without images is rejected with a hint and nothing is
    sent. If the request fails, the assistant placeholder is replaced by the
    error text and convo_messages is yielded unchanged. The input lists are
    never mutated.
    """
    convo_messages = list(convo_messages or [])
    chatbot_history = list(chatbot_history or [])

    # Nothing to send: do not spend a model call on an empty prompt.
    if not (user_text or "").strip() and not image_files:
        yield chatbot_history, "Please enter a message or attach an image.", convo_messages
        return

    host_label = ensure_scheme(host) or "(env default)"

    # Display history: the user's turn plus an empty assistant bubble that is
    # filled in while chunks arrive.
    chatbot_history.append(messages_for_chatbot(user_text, image_files, role="user"))
    chatbot_history.append(messages_for_chatbot("", None, role="assistant"))

    ollama_messages = build_ollama_messages(
        system_prompt or "", convo_messages, user_text, image_files
    )
    # Reuse the already-encoded user message for the stored conversation
    # rather than base64-encoding every image again.
    user_message = ollama_messages[-1]

    options = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repeat_penalty": repeat_penalty,
        "num_ctx": num_ctx,
    }
    if max_tokens is not None and max_tokens > 0:
        options["num_predict"] = max_tokens
    if seed is not None:
        options["seed"] = seed

    assistant_text_accum = ""
    start_time = time.time()
    status_md = f"Model: {model} | Host: {host_label} | Streaming..."

    yield chatbot_history, status_md, convo_messages

    try:
        # Created inside the try block so a bad host is reported in the chat
        # like any other generation error.
        client = get_client(host)
        for part in client.chat(
            model=model,
            messages=ollama_messages,
            stream=True,
            options=options,
        ):
            msg = part.get("message", {}) or {}
            delta = msg.get("content", "")
            if delta:
                assistant_text_accum += delta
                chatbot_history[-1] = messages_for_chatbot(
                    assistant_text_accum, None, role="assistant"
                )

            if part.get("done", False):
                # "or 0": a counter that is present but None must not break
                # the arithmetic and wipe out a finished answer.
                eval_count = part.get("eval_count") or 0
                prompt_eval_count = part.get("prompt_eval_count") or 0
                total = time.time() - start_time
                tok_s = (eval_count / total) if total > 0 else 0.0
                status_md = (
                    f"Model: {model} | Host: {host_label} | "
                    f"Prompt tokens: {prompt_eval_count} | Output tokens: {eval_count} | "
                    f"Time: {total:.2f}s | Speed: {tok_s:.1f} tok/s"
                )
            yield chatbot_history, status_md, convo_messages

        convo_messages = convo_messages + [
            user_message,
            {"role": "assistant", "content": assistant_text_accum},
        ]
        yield chatbot_history, status_md, convo_messages

    except Exception as e:
        err_msg = f"Error during generation: {e}"
        chatbot_history[-1] = messages_for_chatbot(err_msg, None, role="assistant")
        yield chatbot_history, err_msg, convo_messages
|
|
|
|
def clear_conversation():
    """Return fresh empty values: two empty lists followed by an empty string."""
    empty_chat: list = []
    empty_convo: list = []
    return empty_chat, empty_convo, ""
|
|
|
|
def export_conversation(history: List[Dict], convo_messages: List[Dict]) -> Tuple[str, str]:
    """
    Write the conversation to a JSON file and return (path, status_message).

    The file contains the display history ("chat_messages"), the messages
    sent to Ollama ("ollama_messages") and a small "meta" section. It is
    created as a uniquely named file in the system temp directory so that
    exports made in the same second cannot overwrite each other and the
    working directory does not need to be writable.
    """
    import tempfile  # local import: only needed by this export helper

    history = history or []
    export_blob = {
        "chat_messages": history,
        "ollama_messages": convo_messages,
        "meta": {
            "title": APP_TITLE,
            "exported_at": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),
            "version": "1.1",
        },
    }
    # delete=False: the file must outlive this function so it can be served.
    with tempfile.NamedTemporaryFile(
        mode="w",
        encoding="utf-8",
        prefix=f"chat_export_{int(time.time())}_",
        suffix=".json",
        delete=False,
    ) as f:
        json.dump(export_blob, f, ensure_ascii=False, indent=2)
        path = f.name
    return path, f"Exported {len(history)} messages to {path}"
|
|
|
|
| |
@spaces.GPU
def gpu_ping(workload: int = 256) -> dict:
    """
    Minimal function to satisfy ZeroGPU Spaces requirement and optionally exercise the GPU.

    Always runs a small CPU trigonometry loop of max(1, workload) iterations.
    In addition, if torch can be imported and reports CUDA as available, a
    256x256 matrix multiplication is executed on the GPU.

    Returns a dict with "ok", "elapsed_s", "acc_checksum" (fractional part of
    the CPU accumulator) and "info" (mode, ops, and CUDA details).
    """
    t0 = time.time()

    acc = 0.0
    for _ in range(max(1, workload)):
        x = random.random() * 1000.0
        s = math.sin(x)
        c = math.cos(x)
        t = math.tan(x) if abs(c) > 1e-9 else 1.0
        # tan(x) is exactly 0 when x == 0.0 (random() may return it); the
        # term is then 0 anyway, so skip the division instead of raising.
        acc += s * c / t if t else 0.0

    info = {"mode": "cpu", "ops": workload}

    try:
        import torch
        if torch.cuda.is_available():
            a = torch.randn((256, 256), device="cuda")
            b = torch.mm(a, a)
            # .item() copies the result to the host, which requires the
            # matmul to have completed.
            _ = float(b.mean().item())
            info["mode"] = "cuda"
            info["device"] = torch.cuda.get_device_name(torch.cuda.current_device())
            info["cuda"] = True
        else:
            info["cuda"] = False
    except Exception:
        # Deliberately broad: a missing torch or any CUDA failure is reported
        # in the result rather than failing the ping.
        info["cuda"] = "unavailable"

    elapsed = time.time() - t0
    return {"ok": True, "elapsed_s": round(elapsed, 4), "acc_checksum": float(acc % 1.0), "info": info}
| |
|
|
|
|
def ui() -> gr.Blocks:
    """
    Build the Gradio Blocks application and wire up all event handlers.

    Layout: a chat column (chatbot, message box, image upload, action
    buttons, status line) and a settings column (connection, model
    selection/pull, system prompt, generation settings, GPU tools).

    Returns the un-launched gr.Blocks instance.
    """
    with gr.Blocks(title=APP_TITLE, theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"# {APP_TITLE}")
        gr.Markdown(APP_DESCRIPTION)

        # Per-session state.
        state_convo = gr.State([])      # messages in the shape sent to Ollama
        state_history = gr.State([])    # mirror of the chatbot display value
        state_system_prompt = gr.State("")
        state_host = gr.State(DEFAULT_OLLAMA_HOST)
        state_session = gr.State(str(uuid.uuid4()))

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(label="Chat", type="messages", height=520, avatar_images=(None, None))
                with gr.Row():
                    txt = gr.Textbox(
                        label="Your message",
                        placeholder="Ask anything...",
                        autofocus=True,
                        scale=4,
                    )
                    image_files = gr.Files(
                        label="Optional image(s)",
                        file_types=["image"],
                        type="filepath",
                        visible=True,
                    )
                with gr.Row():
                    send_btn = gr.Button("Send", variant="primary")
                    stop_btn = gr.Button("Stop")
                    clear_btn = gr.Button("Clear")
                    export_btn = gr.Button("Export")

                status = gr.Markdown("Ready.", elem_id="status_box")

            with gr.Column(scale=2):
                gr.Markdown("## Connection")
                host_in = gr.Textbox(
                    label="Ollama Host URL",
                    placeholder="http://127.0.0.1:11434 (or leave blank to use server env OLLAMA_HOST)",
                    value=DEFAULT_OLLAMA_HOST,
                )
                with gr.Row():
                    test_btn = gr.Button("Test Connection")
                    refresh_models_btn = gr.Button("Refresh Models")

                models_dd = gr.Dropdown(
                    choices=[],
                    value=None,
                    label="Model",
                    allow_custom_value=True,
                    info="Select a model from the server or type a name (e.g., llama3.1, mistral, phi4:latest)",
                )
                pull_model_txt = gr.Textbox(
                    label="Pull Model (by name)",
                    placeholder="e.g., llama3.1, mistral, qwen2.5:latest",
                )
                pull_btn = gr.Button("Pull Model")
                pull_log = gr.Textbox(label="Pull Progress", interactive=False, lines=6)

                gr.Markdown("## System Prompt")
                sys_prompt = gr.Textbox(
                    label="System Prompt",
                    placeholder="You are a helpful assistant...",
                    lines=4,
                    value=os.getenv("SYSTEM_PROMPT", ""),
                )

                gr.Markdown("## Generation Settings")
                with gr.Row():
                    temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
                    top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p")
                with gr.Row():
                    top_k = gr.Slider(0, 200, value=40, step=1, label="Top-k")
                    repeat_penalty = gr.Slider(0.0, 2.0, value=1.1, step=0.01, label="Repeat Penalty")
                with gr.Row():
                    num_ctx = gr.Slider(256, 8192, value=4096, step=256, label="Context Window (num_ctx)")
                    max_tokens = gr.Slider(0, 8192, value=0, step=16, label="Max New Tokens (0 = auto)")
                seed = gr.Number(value=None, label="Seed (optional)", precision=0)

                gr.Markdown("## GPU Tools (ZeroGPU compatible)")
                with gr.Row():
                    gpu_workload = gr.Slider(64, 4096, value=256, step=64, label="GPU Ping Workload")
                with gr.Row():
                    gpu_btn = gr.Button("Run GPU Ping")
                gpu_out = gr.Textbox(label="GPU Ping Result", lines=6, interactive=False)

        def _on_load():
            """Populate the model dropdown from the default host on page load."""
            host = DEFAULT_OLLAMA_HOST
            names, err = list_models(host)
            if err:
                status_msg = f"Note: {err}"
            else:
                status_msg = f"Loaded {len(names)} models from {ensure_scheme(host) or '(env default)'}."

            value = DEFAULT_MODEL if DEFAULT_MODEL in names else (names[0] if names else None)
            return (
                # A single update object sets choices AND value; returning the
                # bare list would be taken as the dropdown's value.
                gr.update(choices=names, value=value),
                host,
                status_msg,
                [], [], "",
            )

        load_outputs = [
            models_dd,
            host_in,
            status,
            state_history, state_convo, state_system_prompt,
        ]
        demo.load(_on_load, outputs=load_outputs)

        def set_host(h):
            """Keep the normalized host in session state."""
            return ensure_scheme(h)

        host_in.change(set_host, inputs=host_in, outputs=state_host)

        def _test(h, current_model):
            """Test the connection and refresh the model dropdown."""
            ok, msg = test_connection(h)
            names, err = list_models(h) if ok else ([], None)
            # current_model is the live dropdown selection passed in as an
            # input; the component's .value attribute only holds the value it
            # was constructed with.
            if current_model in names:
                model_val = current_model
            else:
                model_val = names[0] if names else None
            if err:
                msg += f"\nAlso: {err}"
            return gr.update(choices=names, value=model_val), msg

        test_btn.click(_test, inputs=[host_in, models_dd], outputs=[models_dd, status])

        # "Refresh Models" performs the same round trip as "Test Connection".
        refresh_models_btn.click(_test, inputs=[host_in, models_dd], outputs=[models_dd, status])

        def _pull(h, name):
            """Stream pull progress lines into the progress box."""
            if not name:
                yield "Please enter a model name to pull."
                return
            for line in pull_model(h, name.strip()):
                yield line

        pull_btn.click(_pull, inputs=[host_in, pull_model_txt], outputs=pull_log)

        clear_btn.click(clear_conversation, outputs=[chatbot, state_convo, status])

        export_file = gr.File(label="Download Conversation", visible=True)
        export_btn.click(export_conversation, inputs=[state_history, state_convo], outputs=[export_file, status])

        def _submit(
            h, m, sp, t, tp, tk, rp, ctx, mx, sd, convo, history, text, files
        ):
            """Coerce the widget values and delegate to stream_chat()."""
            # Slider value 0 means "let the server decide" -> no num_predict.
            mx_int = int(mx) if mx and int(mx) > 0 else None
            sd_int = int(sd) if sd is not None else None
            yield from stream_chat(
                host=h,
                model=m or DEFAULT_MODEL,
                system_prompt=sp or "",
                temperature=float(t),
                top_p=float(tp),
                top_k=int(tk),
                repeat_penalty=float(rp),
                num_ctx=int(ctx),
                max_tokens=mx_int,
                seed=sd_int,
                convo_messages=convo,
                chatbot_history=history,
                user_text=text,
                image_files=files,
            )

        submit_inputs = [
            host_in, models_dd, sys_prompt, temperature, top_p, top_k,
            repeat_penalty, num_ctx, max_tokens, seed,
            state_convo, state_history, txt, image_files,
        ]
        submit_outputs = [chatbot, status, state_convo]

        click_event = send_btn.click(_submit, inputs=submit_inputs, outputs=submit_outputs)
        enter_event = txt.submit(_submit, inputs=submit_inputs, outputs=submit_outputs)

        # Stop must cancel a generation regardless of how it was started
        # (button click or Enter in the textbox).
        stop_btn.click(None, None, None, cancels=[click_event, enter_event])

        def _post_send():
            """Clear the message box and the image upload after sending."""
            return "", None

        send_btn.click(_post_send, outputs=[txt, image_files])
        txt.submit(_post_send, outputs=[txt, image_files])

        def _sync_chatbot_state(history):
            """Mirror the chatbot value into state_history."""
            return history

        chatbot.change(_sync_chatbot_state, inputs=chatbot, outputs=state_history)

        def _gpu_ping_ui(n):
            """Run gpu_ping() and render its result (or the failure) as text."""
            try:
                res = gpu_ping(int(n))
                try:
                    return json.dumps(res, indent=2)
                except Exception:
                    return str(res)
            except Exception as e:
                return f"GPU ping failed: {e}"

        gpu_btn.click(_gpu_ping_ui, inputs=[gpu_workload], outputs=[gpu_out])

    return demo
|
|
|
|
if __name__ == "__main__":
    # Build the app, enable the request queue (required for streaming
    # generators), then serve on all interfaces at the configured port.
    demo = ui()
    demo.queue(default_concurrency_limit=10).launch(
        server_name="0.0.0.0",
        server_port=DEFAULT_PORT,
        show_api=True,
        ssr_mode=False,
    )
|
|