{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "source": [ "# Run On-Device LLM Inference with LiteRT-LM and Gemma 4\n", "\n", "This tutorial demonstrates how to use the **LiteRT-LM** Python library to run efficient, on-device LLM inference using `.litertlm` model files.\n", "\n", "[LiteRT-LM](https://ai.google.dev/edge/litert-lm) is a production-ready, open-source inference framework designed to deliver high-performance, cross-platform LLM deployments on edge devices.\n", "\n", "* **Cross-Platform Support**: Run on Android, iOS, Web, Desktop, and IoT (e.g. Raspberry Pi).\n", "* **Hardware Acceleration**: Get peak performance and system stability by leveraging GPU and NPU accelerators across diverse hardware.\n", "* **Multi-Modality**: Build with LLMs that have vision and audio support.\n", "* **Tool Use**: Function calling support for agentic workflows with constrained decoding for improved accuracy.\n", "* **Broad Model Support**: Run Gemma, Llama, Phi-4, Qwen and more.\n", "\n", "### Useful Links:\n", "* **Official Documentation**: https://ai.google.dev/edge/litert-lm\n", "* **GitHub Repository**: https://github.com/google-ai-edge/LiteRT-LM\n", "* **Web Demo Page**: https://google-ai-edge.github.io/LiteRT-LM/web_demos/chat/index.html\n", "* **LiteRT-LM Developers Blogpost**: https://developers.googleblog.com/blazing-fast-on-device-genai-with-litert-lm/\n", "\n", "---\n", "\n", "In this notebook, we will showcase the core capabilities of LiteRT-LM using the **Gemma 4 E2B** multimodal model in the following order:\n", "1. **Basic text generation**\n", "2. **Asynchronous streaming response**\n", "3. **Multi-modality (Vision / Image inputs)**\n", "4. **Multi-modality (Audio / Speech inputs)**\n", "5. 
**Custom system instructions & conversation history** (with switchable compact personas)\n", "6. **Speculative decoding with Multi-Token Prediction (MTP)** (optimized with streaming)\n", "7. **Benchmarking model execution speeds**" ], "metadata": { "id": "KLZy2OXwVYw2" } }, { "cell_type": "markdown", "source": [ "## 1. Setup and Installation\n", "\n", "First, let's install the required packages. We need `litert-lm-api` for the LiteRT-LM runtime, and `huggingface_hub` to easily download our optimized model from the Hugging Face model hub.\n" ], "metadata": { "id": "sn1KczgNVbag" } }, { "cell_type": "code", "source": [ "!pip install -q litert-lm-api huggingface_hub\n", "\n", "# Required for GPU\n", "!apt-get update && apt-get install -y libvulkan1" ], "metadata": { "id": "4cuBn5FSVdfp" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## 2. Download the Gemma 4 E2B Model\n", "\n", "The Gemma 4 E2B instruction-tuned model is hosted on Hugging Face in the `.litertlm` format, which is optimized specifically for on-device execution.\n", "\n", "We will download the `gemma-4-E2B-it.litertlm` file from the [litert-community/gemma-4-E2B-it-litert-lm](https://huggingface.co/litert-community/gemma-4-E2B-it-litert-lm) repository.\n" ], "metadata": { "id": "sVuZFNSFVik4" } }, { "cell_type": "code", "source": [ "from huggingface_hub import hf_hub_download\n", "\n", "print(\"Downloading Gemma 4 E2B model from Hugging Face. This may take a few minutes...\")\n", "model_path = hf_hub_download(\n", " repo_id=\"litert-community/gemma-4-E2B-it-litert-lm\",\n", " filename=\"gemma-4-E2B-it.litertlm\"\n", ")\n", "print(f\"Downloaded model successfully to: {model_path}\")" ], "metadata": { "id": "LC-YtowxVkBj" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## 3. Basic Text Generation\n", "\n", "To perform inference, we initialize the `Engine` with our downloaded model. 
The `Engine` manages model resources.\n", "\n", "We then create a `Conversation` object, which handles conversation history and state. Using the `with` statement (context manager) ensures that all on-device memory and hardware resources are properly released when done.\n" ], "metadata": { "id": "ZU3Y-SQtVyrg" } }, { "cell_type": "code", "source": [ "import litert_lm\n", "\n", "# Load the model using the Engine. We will use Backend.CPU() for local CPU execution.\n", "# (Note: GPU acceleration can be configured via backend=litert_lm.Backend.GPU() if supported)\n", "with litert_lm.Engine(model_path, backend=litert_lm.Backend.CPU()) as engine:\n", "    # Create a conversation instance\n", "    with engine.create_conversation() as conversation:\n", "        # Send a synchronous message\n", "        response = conversation.send_message(\"What is the capital of France?\")\n", "\n", "        # Extract the response text\n", "        text = response[\"content\"][0][\"text\"]\n", "        print(f\"Response:\\n{text}\")" ], "metadata": { "id": "vd2UPI1vVzWI" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## 4. Asynchronous Streaming (Token-by-Token)\n", "\n", "For interactive chat applications, waiting for the entire response to be generated can feel slow.\n", "\n", "LiteRT-LM provides `send_message_async`, which returns an iterator that yields response chunks in real time as they are decoded. 
This allows you to stream the output token-by-token.\n" ], "metadata": { "id": "gJdHVUtDZqdJ" } }, { "cell_type": "code", "source": [ "with litert_lm.Engine(model_path, backend=litert_lm.Backend.CPU()) as engine:\n", "    with engine.create_conversation() as conversation:\n", "        prompt = \"Tell me a short 3-sentence story about a brave little robot.\"\n", "        print(f\"Prompt: {prompt}\\n\\nStreaming Response:\\n\", end=\"\")\n", "\n", "        # Start asynchronous streaming\n", "        stream = conversation.send_message_async(prompt)\n", "        for chunk in stream:\n", "            # Response chunks are dictionary objects containing a content array\n", "            for item in chunk.get(\"content\", []):\n", "                if item.get(\"type\") == \"text\":\n", "                    print(item[\"text\"], end=\"\", flush=True)\n", "        print()" ], "metadata": { "id": "JS9EP0LsZr5S" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## 5. Multi-Modality (Vision / Image Input)\n", "\n", "The **Gemma 4 E2B** model natively supports vision (images) and audio inputs in addition to text.\n", "\n", "To pass an image to the model:\n", "1. Wrap the image and the text prompt in a `litert_lm.Contents` object using `litert_lm.Contents.of(...)`.\n", "2. 
Use `litert_lm.Content.ImageFile(image_path)` to specify the local path to the image.\n", "\n", "*Note: While CPU execution is shown here for simplicity, offloading vision encoding to GPU (via `vision_backend=litert_lm.Backend.GPU()`) is strongly recommended for interactive use cases.*\n" ], "metadata": { "id": "NyQnf9GDZx0o" } }, { "cell_type": "code", "source": [ "import urllib.request\n", "from PIL import Image, ImageDraw\n", "import os\n", "from IPython.display import display\n", "\n", "# Download a public image (a standard red STOP sign)\n", "image_url = \"https://upload.wikimedia.org/wikipedia/commons/f/f9/STOP_sign.jpg\"\n", "image_path = \"stop_sign.jpg\"\n", "\n", "print(f\"Downloading image from {image_url}...\")\n", "try:\n", "    # Wikimedia rejects requests without a User-Agent header (403 Forbidden), so set one explicitly.\n", "    req = urllib.request.Request(\n", "        image_url,\n", "        headers={'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}\n", "    )\n", "    with urllib.request.urlopen(req) as response, open(image_path, 'wb') as out_file:\n", "        out_file.write(response.read())\n", "    print(\"Download complete.\")\n", "except Exception as e:\n", "    print(f\"Failed to download image: {e}\")\n", "    # Fallback: create a red square with text \"STOP\" if download fails\n", "    img = Image.new(\"RGB\", (300, 300), color=\"red\")\n", "    draw = ImageDraw.Draw(img)\n", "    draw.text((120, 140), \"STOP\", fill=\"white\")\n", "    img.save(image_path)\n", "    print(\"Created a fallback image.\")\n", "\n", "# Open and display the image\n", "img = Image.open(image_path)\n", "img.thumbnail((300, 300))\n", "display(img)\n", "\n", "# Load the model with vision support enabled on CPU\n", "with litert_lm.Engine(\n", "    model_path,\n", "    backend=litert_lm.Backend.CPU(),\n", "    vision_backend=litert_lm.Backend.CPU()  # Specify CPU backend for the vision processor\n", ") as engine:\n", "    with engine.create_conversation() as conversation:\n", "        # Turn 1: Construct multimodal inputs combining image and a text prompt\n", "        multimodal_input = litert_lm.Contents.of(\n", "            litert_lm.Content.ImageFile(image_path),\n", "            \"Describe what you see in this image.\"\n", "        )\n", "\n", "        print(\"\\nSending image + prompt to the model (streaming)...\")\n", "        stream = conversation.send_message_async(multimodal_input)\n", "        print(f\"\\nModel Description:\\n\", end=\"\")\n", "        for chunk in stream:\n", "            for item in chunk.get(\"content\", []):\n", "                if item.get(\"type\") == \"text\":\n", "                    print(item[\"text\"], end=\"\", flush=True)\n", "        print(\"\\n\\n\" + \"-\" * 50 + \"\\n\")\n", "\n", "        # Turn 2: Ask the model to read the text (context and image are preserved!)\n", "        print(\"Asking the model to perform OCR on the same image (streaming)...\")\n", "        stream2 = conversation.send_message_async(\"What text is written on the sign?\")\n", "        print(f\"\\nModel OCR Result:\\n\", end=\"\")\n", "        for chunk in stream2:\n", "            for item in chunk.get(\"content\", []):\n", "                if item.get(\"type\") == \"text\":\n", "                    print(item[\"text\"], end=\"\", flush=True)\n", "        print()\n", "\n", "# Clean up the temporary image\n", "if os.path.exists(image_path):\n", "    os.remove(image_path)" ], "metadata": { "id": "lurL7oahZzDI" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## 6. Multi-Modality (Audio / Speech Input)\n", "\n", "In addition to images, **Gemma 4 E2B** natively supports audio inputs. This enables on-device **Automatic Speech Recognition (ASR)** and audio understanding.\n", "\n", "In this section, we will:\n", "1. Download a public audio sample (a WAV file containing spoken words).\n", "2. Display an interactive audio player inside the notebook.\n", "3. 
Send the audio along with a text prompt to perform on-device transcription (ASR) using streaming.\n", "\n", "*Note: As with vision, audio processing runs on CPU here for simplicity, but hardware acceleration is recommended for production.*" ], "metadata": { "id": "OwTWaRAWGEEd" } }, { "cell_type": "code", "source": [ "import urllib.request\n", "from IPython.display import Audio, display\n", "import os\n", "\n", "# Download a public audio file (contains spoken words: \"Have a wonderful day\")\n", "audio_url = \"https://github.com/google-ai-edge/LiteRT-LM/raw/refs/heads/main/runtime/testdata/have_a_wonderful_day.wav\"\n", "audio_path = \"have_a_wonderful_day.wav\"\n", "\n", "print(f\"Downloading audio from {audio_url}...\")\n", "try:\n", "    # Use a User-Agent to avoid potential 403 Forbidden errors\n", "    req = urllib.request.Request(\n", "        audio_url,\n", "        headers={'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64)'}\n", "    )\n", "    with urllib.request.urlopen(req) as response, open(audio_path, 'wb') as out_file:\n", "        out_file.write(response.read())\n", "    print(\"Download complete.\")\n", "except Exception as e:\n", "    print(f\"Failed to download audio: {e}\")\n", "\n", "# Play the audio in the notebook\n", "if os.path.exists(audio_path):\n", "    display(Audio(audio_path))\n", "\n", "    # Load the model with audio support enabled on CPU\n", "    with litert_lm.Engine(\n", "        model_path,\n", "        backend=litert_lm.Backend.CPU(),\n", "        audio_backend=litert_lm.Backend.CPU()  # Specify CPU backend for the audio processor\n", "    ) as engine:\n", "        with engine.create_conversation() as conversation:\n", "            # Construct multimodal inputs combining audio and a text prompt\n", "            multimodal_input = litert_lm.Contents.of(\n", "                litert_lm.Content.AudioFile(audio_path),\n", "                \"Transcribe this audio.\"\n", "            )\n", "\n", "            print(\"\\nSending audio + prompt to the model (streaming ASR)...\")\n", "            stream = conversation.send_message_async(multimodal_input)\n", "            print(f\"\\nModel Transcription:\\n\", end=\"\")\n", "            for chunk in stream:\n", "                for item in chunk.get(\"content\", []):\n", "                    if item.get(\"type\") == \"text\":\n", "                        print(item[\"text\"], end=\"\", flush=True)\n", "            print()\n", "\n", "    # Clean up the temporary audio file\n", "    os.remove(audio_path)\n", "else:\n", "    print(\"\\nError: Audio file was not downloaded successfully. Skipping inference.\")" ], "metadata": { "id": "vIGC7kfOGFKh" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## 7. System Instructions & Conversation History\n", "\n", "A `Conversation` object preserves the state and history of your conversation.\n", "\n", "You can also customize the assistant's persona and guidelines by passing a list of initial messages containing a **system instruction** when creating the conversation.\n", "\n", "In this example, we provide two switchable options for the assistant's persona:\n", "* **Option A: The Grumpy Pirate**: A curt, direct character who grunts and answers in at most 30 words.\n", "* **Option B: The Wise Zen Master**: A calm, cryptic character who answers with a short riddle/koan of at most 30 words.\n", "\n", "Both system instructions ask the model to answer in at most 30 words. This shows how system instructions can shape quite different assistant behaviors while keeping the generated output brief, which keeps on-device decoding latency low." ], "metadata": { "id": "n0FgtNFTZ2IJ" } }, { "cell_type": "code", "source": [ "# Configure the system instruction. Option A is active by default; uncomment Option B to switch (it overrides Option A).\n", "\n", "# Option A: The Grumpy Pirate (curt, direct)\n", "assistant_name = \"Pirate Assistant\"\n", "system_instruction = (\n", "    \"You are a grumpy, curt pirate who hates talking. You must always answer \"\n", "    \"in a succinct but critical paragraph of at most 30 words, starting with a pirate grunt \"\n", "    \"like 'Arr', 'Bah', or 'Avast'.\"\n", ")\n", "\n", "# # Option B: The Wise Zen Master (calm, cryptic) - Uncomment to switch:\n", "# assistant_name = \"Zen Master\"\n", "# system_instruction = (\n", "#     \"You are a wise, calm Zen Master. You must always answer with a short, \"\n", "#     \"cryptic riddle or koan of at most 30 words that forces the user to reflect.\"\n", "# )\n", "\n", "initial_messages = [\n", "    litert_lm.Message.system(system_instruction)\n", "]\n", "\n", "with litert_lm.Engine(model_path, backend=litert_lm.Backend.CPU()) as engine:\n", "    # Initialize conversation with our custom system instruction\n", "    with engine.create_conversation(messages=initial_messages) as conversation:\n", "\n", "        # Turn 1 (Streaming)\n", "        print(f\"User: How can I write clean code?\\n\\n{assistant_name} (streaming):\\n\", end=\"\")\n", "        stream = conversation.send_message_async(\"How can I write clean code?\")\n", "        for chunk in stream:\n", "            for item in chunk.get(\"content\", []):\n", "                if item.get(\"type\") == \"text\":\n", "                    print(item[\"text\"], end=\"\", flush=True)\n", "        print(\"\\n\\n\" + \"-\" * 50 + \"\\n\")\n", "\n", "        # Turn 2 (Context is automatically maintained in this conversation, Streaming)\n", "        print(f\"User: And what about testing?\\n\\n{assistant_name} (streaming):\\n\", end=\"\")\n", "        stream2 = conversation.send_message_async(\"And what about testing?\")\n", "        for chunk in stream2:\n", "            for item in chunk.get(\"content\", []):\n", "                if item.get(\"type\") == \"text\":\n", "                    print(item[\"text\"], end=\"\", flush=True)\n", "        print()" ], "metadata": { "id": "m4GU7oe6Z2v9" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## 8. 
Multi-Token Prediction (MTP)\n", "\n", "**Multi-Token Prediction (MTP)** is a performance optimization in LiteRT-LM that can substantially speed up decoding. It is a form of speculative decoding: several tokens are predicted per execution step and then verified together, so more than one token can be accepted per step instead of exactly one.\n", "\n", "To learn more about how MTP works and its performance benefits, check out the official [Google DeepMind Blog post](https://blog.google/innovation-and-ai/technology/developers-tools/multi-token-prediction-gemma-4/).\n", "\n", "To use MTP, set `enable_speculative_decoding=True` when creating the `Engine`." ], "metadata": { "id": "Dq2sff4EZ4ob" } }, { "cell_type": "code", "source": [ "# Initialize with enable_speculative_decoding=True to leverage Multi-Token Prediction (MTP)\n", "with litert_lm.Engine(\n", "    model_path,\n", "    backend=litert_lm.Backend.CPU(),\n", "    enable_speculative_decoding=True\n", ") as engine:\n", "    with engine.create_conversation() as conversation:\n", "        print(\"Sending prompt with MTP enabled (streaming):\\n\", end=\"\")\n", "        stream = conversation.send_message_async(\"Explain quantum computing in one sentence.\")\n", "        for chunk in stream:\n", "            for item in chunk.get(\"content\", []):\n", "                if item.get(\"type\") == \"text\":\n", "                    print(item[\"text\"], end=\"\", flush=True)\n", "        print()" ], "metadata": { "id": "gh89sKI3Z7Qf" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## 9. 
Benchmarking Performance\n", "\n", "LiteRT-LM includes built-in benchmarking utilities that let you measure important performance metrics for on-device execution:\n", "* **Model Init Time**: Time (in seconds) to load and prepare the model.\n", "* **Time-to-First-Token (TTFT)**: Latency from sending the prompt to receiving the first generated token.\n", "* **Prefill Speed**: Throughput during prompt ingestion (tokens/sec).\n", "* **Decode Speed**: Generation speed of subsequent tokens (tokens/sec).\n", "\n", "Let's measure these on our current environment.\n" ], "metadata": { "id": "9cbhDA8sZ8kL" } }, { "cell_type": "code", "source": [ "# Configure a benchmark run\n", "benchmark = litert_lm.Benchmark(\n", " model_path,\n", " litert_lm.Backend.CPU(),\n", " prefill_tokens=64, # Emulate a 64-token prompt\n", " decode_tokens=64, # Emulate generating 64 tokens\n", ")\n", "\n", "print(\"Running benchmark. Please wait...\")\n", "results = benchmark.run()\n", "\n", "print(\"\\n=== Benchmark Results ===\")\n", "print(f\"Model Init Time: {results.init_time_in_second:.4f} seconds\")\n", "print(f\"Time to First Token (TTFT): {results.time_to_first_token_in_second:.4f} seconds\")\n", "print(f\"Prefill Speed: {results.last_prefill_tokens_per_second:.2f} tokens/second\")\n", "print(f\"Decode Speed: {results.last_decode_tokens_per_second:.2f} tokens/second\")" ], "metadata": { "id": "Ac5zkAJFZ95J" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Summary & Next Steps\n", "\n", "Congratulations! You've completed this LiteRT-LM tutorial with Gemma 4 E2B. You now know how to:\n", "1. Download `.litertlm` optimized model files from Hugging Face.\n", "2. Run synchronous and asynchronous streaming text generation.\n", "3. Use native on-device Multi-Modality by passing both **image** and **audio** files (running on-device OCR and ASR).\n", "4. 
Configure custom system instructions with **switchable compact personas** and maintain conversation context.\n", "5. Optimize performance with speculative decoding / Multi-Token Prediction (MTP).\n", "6. Benchmark on-device execution speeds with the built-in suite.\n", "\n", "For more details and native deployment platforms, visit:\n", "* **Official Documentation**: [https://ai.google.dev/edge/litert-lm](https://ai.google.dev/edge/litert-lm)\n", "* **GitHub Repository**: [https://github.com/google-ai-edge/LiteRT-LM](https://github.com/google-ai-edge/LiteRT-LM)" ], "metadata": { "id": "XgnfA2I_aDZY" } } ] }