LiteRT-LM
yuhuichen1015 committed on
Commit 83f14bc · verified · 1 Parent(s): a4a831c

Upload notebook.ipynb

Uploaded the Colab to showcase how the LiteRT-LM Python API works.

Files changed (1)
  1. notebook.ipynb +534 -0
notebook.ipynb ADDED
@@ -0,0 +1,534 @@
+ {
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "gpuType": "T4"
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Run On-Device LLM Inference with LiteRT-LM and Gemma 4\n",
+ "\n",
+ "This tutorial demonstrates how to use the **LiteRT-LM** Python library to run efficient, on-device LLM inference using `.litertlm` model files.\n",
+ "\n",
+ "[LiteRT-LM](https://ai.google.dev/edge/litert-lm) is a production-ready, open-source inference framework designed to deliver high-performance, cross-platform LLM deployments on edge devices.\n",
+ "\n",
+ "* **Cross-Platform Support**: Run on Android, iOS, Web, Desktop, and IoT (e.g. Raspberry Pi).\n",
+ "* **Hardware Acceleration**: Get peak performance and system stability by leveraging GPU and NPU accelerators across diverse hardware.\n",
+ "* **Multi-Modality**: Build with LLMs that have vision and audio support.\n",
+ "* **Tool Use**: Function calling support for agentic workflows with constrained decoding for improved accuracy.\n",
+ "* **Broad Model Support**: Run Gemma, Llama, Phi-4, Qwen and more.\n",
+ "\n",
+ "### Useful Links:\n",
+ "* **Official Documentation**: https://ai.google.dev/edge/litert-lm\n",
+ "* **GitHub Repository**: https://github.com/google-ai-edge/LiteRT-LM\n",
+ "* **Web Demo Page**: https://google-ai-edge.github.io/LiteRT-LM/web_demos/chat/index.html\n",
+ "* **LiteRT-LM Developers Blog Post**: https://developers.googleblog.com/blazing-fast-on-device-genai-with-litert-lm/\n",
+ "\n",
+ "---\n",
+ "\n",
+ "In this notebook, we will showcase the core capabilities of LiteRT-LM using the **Gemma 4 E2B** multimodal model in the following order:\n",
+ "1. **Basic text generation**\n",
+ "2. **Asynchronous streaming response**\n",
+ "3. **Multi-modality (Vision / Image inputs)**\n",
+ "4. **Multi-modality (Audio / Speech inputs)**\n",
+ "5. **Custom system instructions & conversation history** (with switchable compact personas)\n",
+ "6. **Speculative decoding with Multi-Token Prediction (MTP)** (optimized with streaming)\n",
+ "7. **Benchmarking model execution speeds**"
+ ],
+ "metadata": {
+ "id": "KLZy2OXwVYw2"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## 1. Setup and Installation\n",
+ "\n",
+ "First, let's install the required packages. We need `litert-lm-api` for the LiteRT-LM runtime, and `huggingface_hub` to easily download our optimized model from the Hugging Face model hub.\n"
+ ],
+ "metadata": {
+ "id": "sn1KczgNVbag"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install -q litert-lm-api huggingface_hub\n",
+ "\n",
+ "# Required for GPU\n",
+ "!apt-get update && apt-get install -y libvulkan1"
+ ],
+ "metadata": {
+ "id": "4cuBn5FSVdfp"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## 2. Download the Gemma 4 E2B Model\n",
+ "\n",
+ "The Gemma 4 E2B instruction-tuned model is hosted on Hugging Face in the `.litertlm` format, which is optimized specifically for on-device execution.\n",
+ "\n",
+ "We will download the `gemma-4-E2B-it.litertlm` file from the [litert-community/gemma-4-E2B-it-litert-lm](https://huggingface.co/litert-community/gemma-4-E2B-it-litert-lm) repository.\n"
+ ],
+ "metadata": {
+ "id": "sVuZFNSFVik4"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from huggingface_hub import hf_hub_download\n",
+ "\n",
+ "print(\"Downloading Gemma 4 E2B model from Hugging Face. This may take a few minutes...\")\n",
+ "model_path = hf_hub_download(\n",
+ "    repo_id=\"litert-community/gemma-4-E2B-it-litert-lm\",\n",
+ "    filename=\"gemma-4-E2B-it.litertlm\"\n",
+ ")\n",
+ "print(f\"Downloaded model successfully to: {model_path}\")"
+ ],
+ "metadata": {
+ "id": "LC-YtowxVkBj"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## 3. Basic Text Generation\n",
+ "\n",
+ "To perform inference, we initialize the `Engine` with our downloaded model. The `Engine` manages model resources.\n",
+ "\n",
+ "We then create a `Conversation` object, which handles conversation history and state. Using the `with` statement (context manager) ensures that all on-device memory and hardware resources are properly released when done.\n"
+ ],
+ "metadata": {
+ "id": "ZU3Y-SQtVyrg"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import litert_lm\n",
+ "\n",
+ "# Load the model using the Engine. We will use Backend.CPU() for local CPU execution.\n",
+ "# (Note: GPU acceleration can be configured via backend=litert_lm.Backend.GPU() if supported)\n",
+ "with litert_lm.Engine(model_path, backend=litert_lm.Backend.CPU()) as engine:\n",
+ "    # Create a conversation instance\n",
+ "    with engine.create_conversation() as conversation:\n",
+ "        # Send a synchronous message\n",
+ "        response = conversation.send_message(\"What is the capital of France?\")\n",
+ "\n",
+ "        # Extract the response text\n",
+ "        text = response[\"content\"][0][\"text\"]\n",
+ "        print(f\"Response:\\n{text}\")"
+ ],
+ "metadata": {
+ "id": "vd2UPI1vVzWI"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## 4. Asynchronous Streaming (Token-by-Token)\n",
+ "\n",
+ "For interactive chat applications, waiting for the entire response to generate can feel slow.\n",
+ "\n",
+ "LiteRT-LM provides `send_message_async`, which returns an iterator that yields response chunks as they are decoded, so you can stream the output token by token.\n"
+ ],
+ "metadata": {
+ "id": "gJdHVUtDZqdJ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "with litert_lm.Engine(model_path, backend=litert_lm.Backend.CPU()) as engine:\n",
+ "    with engine.create_conversation() as conversation:\n",
+ "        prompt = \"Tell me a short 3-sentence story about a brave little robot.\"\n",
+ "        print(f\"Prompt: {prompt}\\n\\nStreaming Response:\\n\", end=\"\")\n",
+ "\n",
+ "        # Start asynchronous streaming\n",
+ "        stream = conversation.send_message_async(prompt)\n",
+ "        for chunk in stream:\n",
+ "            # Response chunks are dictionary objects containing a content array\n",
+ "            for item in chunk.get(\"content\", []):\n",
+ "                if item.get(\"type\") == \"text\":\n",
+ "                    print(item[\"text\"], end=\"\", flush=True)\n",
+ "        print()"
+ ],
+ "metadata": {
+ "id": "JS9EP0LsZr5S"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## 5. Multi-Modality (Vision / Image Input)\n",
+ "\n",
+ "The **Gemma 4 E2B** model natively supports vision (images) and audio inputs in addition to text.\n",
+ "\n",
+ "To pass an image to the model:\n",
+ "1. Wrap the inputs in a `litert_lm.Contents` object.\n",
+ "2. Use `litert_lm.Content.ImageFile(image_path)` to specify the local path to the image.\n",
+ "\n",
+ "*Note: While CPU execution is shown here for simplicity, offloading vision encoding to GPU (via `vision_backend=litert_lm.Backend.GPU()`) is strongly recommended for interactive use cases.*\n"
+ ],
+ "metadata": {
+ "id": "NyQnf9GDZx0o"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import urllib.request\n",
+ "from PIL import Image, ImageDraw\n",
+ "import os\n",
+ "from IPython.display import display\n",
+ "\n",
+ "# Download a public image (a standard red STOP sign)\n",
+ "image_url = \"https://upload.wikimedia.org/wikipedia/commons/f/f9/STOP_sign.jpg\"\n",
+ "image_path = \"stop_sign.jpg\"\n",
+ "\n",
+ "print(f\"Downloading image from {image_url}...\")\n",
+ "try:\n",
+ "    # Wikimedia requires a User-Agent header to allow downloads, otherwise it returns 403 Forbidden.\n",
+ "    req = urllib.request.Request(\n",
+ "        image_url,\n",
+ "        headers={'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}\n",
+ "    )\n",
+ "    with urllib.request.urlopen(req) as response, open(image_path, 'wb') as out_file:\n",
+ "        out_file.write(response.read())\n",
+ "    print(\"Download complete.\")\n",
+ "except Exception as e:\n",
+ "    print(f\"Failed to download image: {e}\")\n",
+ "    # Fallback: create a red square with text \"STOP\" if download fails\n",
+ "    img = Image.new(\"RGB\", (300, 300), color=\"red\")\n",
+ "    draw = ImageDraw.Draw(img)\n",
+ "    draw.text((120, 140), \"STOP\", fill=\"white\")\n",
+ "    img.save(image_path)\n",
+ "    print(\"Created a fallback image.\")\n",
+ "\n",
+ "# Open and display the image\n",
+ "img = Image.open(image_path)\n",
+ "img.thumbnail((300, 300))\n",
+ "display(img)\n",
+ "\n",
+ "# Load the model with vision support enabled on CPU\n",
+ "with litert_lm.Engine(\n",
+ "    model_path,\n",
+ "    backend=litert_lm.Backend.CPU(),\n",
+ "    vision_backend=litert_lm.Backend.CPU()  # Specify CPU backend for the vision processor\n",
+ ") as engine:\n",
+ "    with engine.create_conversation() as conversation:\n",
+ "        # Turn 1: Construct multimodal inputs combining image and a text prompt\n",
+ "        multimodal_input = litert_lm.Contents.of(\n",
+ "            litert_lm.Content.ImageFile(image_path),\n",
+ "            \"Describe what you see in this image.\"\n",
+ "        )\n",
+ "\n",
+ "        print(\"\\nSending image + prompt to the model (streaming)...\")\n",
+ "        stream = conversation.send_message_async(multimodal_input)\n",
+ "        print(\"\\nModel Description:\\n\", end=\"\")\n",
+ "        for chunk in stream:\n",
+ "            for item in chunk.get(\"content\", []):\n",
+ "                if item.get(\"type\") == \"text\":\n",
+ "                    print(item[\"text\"], end=\"\", flush=True)\n",
+ "        print(\"\\n\\n\" + \"-\" * 50 + \"\\n\")\n",
+ "\n",
+ "        # Turn 2: Ask the model to read the text (context and image are preserved!)\n",
+ "        print(\"Asking the model to perform OCR on the same image (streaming)...\")\n",
+ "        stream2 = conversation.send_message_async(\"What text is written on the sign?\")\n",
+ "        print(\"\\nModel OCR Result:\\n\", end=\"\")\n",
+ "        for chunk in stream2:\n",
+ "            for item in chunk.get(\"content\", []):\n",
+ "                if item.get(\"type\") == \"text\":\n",
+ "                    print(item[\"text\"], end=\"\", flush=True)\n",
+ "        print()\n",
+ "\n",
+ "# Clean up the temporary image\n",
+ "if os.path.exists(image_path):\n",
+ "    os.remove(image_path)"
+ ],
+ "metadata": {
+ "id": "lurL7oahZzDI"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## 6. Multi-Modality (Audio / Speech Input)\n",
+ "\n",
+ "In addition to images, **Gemma 4 E2B** natively supports audio inputs. This enables on-device **Automatic Speech Recognition (ASR)** and audio understanding.\n",
+ "\n",
+ "In this section, we will:\n",
+ "1. Download a public audio sample (a WAV file containing spoken words).\n",
+ "2. Display an interactive audio player inside the notebook.\n",
+ "3. Send the audio along with a text prompt to perform on-device transcription (ASR) using streaming.\n",
+ "\n",
+ "*Note: As with vision, audio processing runs on the CPU here for simplicity, but hardware acceleration is recommended for production.*"
+ ],
+ "metadata": {
+ "id": "OwTWaRAWGEEd"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import urllib.request\n",
+ "from IPython.display import Audio, display\n",
+ "import os\n",
+ "\n",
+ "# Download a public audio file (contains spoken words: \"Have a wonderful day\")\n",
+ "audio_url = \"https://github.com/google-ai-edge/LiteRT-LM/raw/refs/heads/main/runtime/testdata/have_a_wonderful_day.wav\"\n",
+ "audio_path = \"have_a_wonderful_day.wav\"\n",
+ "\n",
+ "print(f\"Downloading audio from {audio_url}...\")\n",
+ "try:\n",
+ "    # Use a User-Agent to avoid potential 403 Forbidden errors\n",
+ "    req = urllib.request.Request(\n",
+ "        audio_url,\n",
+ "        headers={'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64)'}\n",
+ "    )\n",
+ "    with urllib.request.urlopen(req) as response, open(audio_path, 'wb') as out_file:\n",
+ "        out_file.write(response.read())\n",
+ "    print(\"Download complete.\")\n",
+ "except Exception as e:\n",
+ "    print(f\"Failed to download audio: {e}\")\n",
+ "\n",
+ "# Play the audio in the notebook\n",
+ "if os.path.exists(audio_path):\n",
+ "    display(Audio(audio_path))\n",
+ "\n",
+ "    # Load the model with audio support enabled on CPU\n",
+ "    with litert_lm.Engine(\n",
+ "        model_path,\n",
+ "        backend=litert_lm.Backend.CPU(),\n",
+ "        audio_backend=litert_lm.Backend.CPU()  # Specify CPU backend for the audio processor\n",
+ "    ) as engine:\n",
+ "        with engine.create_conversation() as conversation:\n",
+ "            # Construct multimodal inputs combining audio and a text prompt\n",
+ "            multimodal_input = litert_lm.Contents.of(\n",
+ "                litert_lm.Content.AudioFile(audio_path),\n",
+ "                \"Transcribe this audio.\"\n",
+ "            )\n",
+ "\n",
+ "            print(\"\\nSending audio + prompt to the model (streaming ASR)...\")\n",
+ "            stream = conversation.send_message_async(multimodal_input)\n",
+ "            print(\"\\nModel Transcription:\\n\", end=\"\")\n",
+ "            for chunk in stream:\n",
+ "                for item in chunk.get(\"content\", []):\n",
+ "                    if item.get(\"type\") == \"text\":\n",
+ "                        print(item[\"text\"], end=\"\", flush=True)\n",
+ "            print()\n",
+ "\n",
+ "    # Clean up the temporary audio file\n",
+ "    os.remove(audio_path)\n",
+ "else:\n",
+ "    print(\"\\nError: Audio file was not downloaded successfully. Skipping inference.\")"
+ ],
+ "metadata": {
+ "id": "vIGC7kfOGFKh"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## 7. System Instructions & Conversation History\n",
+ "\n",
+ "A `Conversation` object preserves the state and history of your conversation.\n",
+ "\n",
+ "You can also customize the assistant's persona and guidelines by passing a list of initial messages containing a **system instruction** when creating the conversation.\n",
+ "\n",
+ "In this example, we provide two switchable options for the assistant's persona:\n",
+ "* **Option A: The Grumpy Pirate**: A curt, direct character who grunts and answers in at most 30 words.\n",
+ "* **Option B: The Wise Zen Master**: A calm, cryptic character who answers with a short riddle/koan of at most 30 words.\n",
+ "\n",
+ "Both personas are instructed to answer in at most 30 words. This demonstrates how system instructions can shape diverse assistant behaviors while keeping the generated output brief, which minimizes on-device decoding latency."
+ ],
+ "metadata": {
+ "id": "n0FgtNFTZ2IJ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Configure the system instruction. Choose one of the options below by uncommenting:\n",
+ "\n",
+ "# Option A: The Grumpy Pirate (curt, direct)\n",
+ "assistant_name = \"Pirate Assistant\"\n",
+ "system_instruction = (\n",
+ "    \"You are a grumpy, curt pirate who hates talking. You must always answer \"\n",
+ "    \"in a succinct but critical paragraph of at most 30 words, starting with a pirate grunt \"\n",
+ "    \"like 'Arr', 'Bah', or 'Avast'.\"\n",
+ ")\n",
+ "\n",
+ "# # Option B: The Wise Zen Master (calm, cryptic) - Uncomment to switch:\n",
+ "# assistant_name = \"Zen Master\"\n",
+ "# system_instruction = (\n",
+ "#     \"You are a wise, calm Zen Master. You must always answer with a short, \"\n",
+ "#     \"cryptic riddle or koan of at most 30 words that forces the user to reflect.\"\n",
+ "# )\n",
+ "\n",
+ "initial_messages = [\n",
+ "    litert_lm.Message.system(system_instruction)\n",
+ "]\n",
+ "\n",
+ "with litert_lm.Engine(model_path, backend=litert_lm.Backend.CPU()) as engine:\n",
+ "    # Initialize conversation with our custom system instruction\n",
+ "    with engine.create_conversation(messages=initial_messages) as conversation:\n",
+ "\n",
+ "        # Turn 1 (Streaming)\n",
+ "        print(f\"User: How can I write clean code?\\n\\n{assistant_name} (streaming):\\n\", end=\"\")\n",
+ "        stream = conversation.send_message_async(\"How can I write clean code?\")\n",
+ "        for chunk in stream:\n",
+ "            for item in chunk.get(\"content\", []):\n",
+ "                if item.get(\"type\") == \"text\":\n",
+ "                    print(item[\"text\"], end=\"\", flush=True)\n",
+ "        print(\"\\n\\n\" + \"-\" * 50 + \"\\n\")\n",
+ "\n",
+ "        # Turn 2 (Context is automatically maintained in this conversation, Streaming)\n",
+ "        print(f\"User: And what about testing?\\n\\n{assistant_name} (streaming):\\n\", end=\"\")\n",
+ "        stream2 = conversation.send_message_async(\"And what about testing?\")\n",
+ "        for chunk in stream2:\n",
+ "            for item in chunk.get(\"content\", []):\n",
+ "                if item.get(\"type\") == \"text\":\n",
+ "                    print(item[\"text\"], end=\"\", flush=True)\n",
+ "        print()"
+ ],
+ "metadata": {
+ "id": "m4GU7oe6Z2v9"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## 8. Multi-Token Prediction (MTP)\n",
+ "\n",
+ "**Multi-Token Prediction (MTP)** is a performance optimization in LiteRT-LM that accelerates decoding. It works by predicting multiple tokens in parallel per execution step (speculative decoding).\n",
+ "\n",
+ "To learn more about how MTP works and its performance benefits, check out the official [Google DeepMind Blog post](https://blog.google/innovation-and-ai/technology/developers-tools/multi-token-prediction-gemma-4/).\n",
+ "\n",
+ "<img src=\"https://storage.googleapis.com/gweb-uniblog-publish-prod/images/Chart_Blog_Updated.width-1000.format-webp.webp\" width=\"600\" alt=\"MTP Speedup Chart\" />\n",
+ "\n",
+ "To use MTP, set `enable_speculative_decoding=True` when creating the `Engine`."
+ ],
+ "metadata": {
+ "id": "Dq2sff4EZ4ob"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Initialize with enable_speculative_decoding=True to leverage Multi-Token Prediction (MTP)\n",
+ "with litert_lm.Engine(\n",
+ "    model_path,\n",
+ "    backend=litert_lm.Backend.CPU(),\n",
+ "    enable_speculative_decoding=True\n",
+ ") as engine:\n",
+ "    with engine.create_conversation() as conversation:\n",
+ "        print(\"Sending prompt with MTP enabled (streaming):\\n\", end=\"\")\n",
+ "        stream = conversation.send_message_async(\"Explain quantum computing in one sentence.\")\n",
+ "        for chunk in stream:\n",
+ "            for item in chunk.get(\"content\", []):\n",
+ "                if item.get(\"type\") == \"text\":\n",
+ "                    print(item[\"text\"], end=\"\", flush=True)\n",
+ "        print()"
+ ],
+ "metadata": {
+ "id": "gh89sKI3Z7Qf"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## 9. Benchmarking Performance\n",
+ "\n",
+ "LiteRT-LM includes built-in benchmarking utilities that let you measure important performance metrics for on-device execution:\n",
+ "* **Model Init Time**: Time (in seconds) to load and prepare the model.\n",
+ "* **Time-to-First-Token (TTFT)**: Latency from sending the prompt to receiving the first generated token.\n",
+ "* **Prefill Speed**: Throughput during prompt ingestion (tokens/sec).\n",
+ "* **Decode Speed**: Generation speed of subsequent tokens (tokens/sec).\n",
+ "\n",
+ "Let's measure these in our current environment.\n"
+ ],
+ "metadata": {
+ "id": "9cbhDA8sZ8kL"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Configure a benchmark run\n",
+ "benchmark = litert_lm.Benchmark(\n",
+ "    model_path,\n",
+ "    litert_lm.Backend.CPU(),\n",
+ "    prefill_tokens=64,  # Emulate a 64-token prompt\n",
+ "    decode_tokens=64,  # Emulate generating 64 tokens\n",
+ ")\n",
+ "\n",
+ "print(\"Running benchmark. Please wait...\")\n",
+ "results = benchmark.run()\n",
+ "\n",
+ "print(\"\\n=== Benchmark Results ===\")\n",
+ "print(f\"Model Init Time: {results.init_time_in_second:.4f} seconds\")\n",
+ "print(f\"Time to First Token (TTFT): {results.time_to_first_token_in_second:.4f} seconds\")\n",
+ "print(f\"Prefill Speed: {results.last_prefill_tokens_per_second:.2f} tokens/second\")\n",
+ "print(f\"Decode Speed: {results.last_decode_tokens_per_second:.2f} tokens/second\")"
+ ],
+ "metadata": {
+ "id": "Ac5zkAJFZ95J"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Summary & Next Steps\n",
+ "\n",
+ "Congratulations! You've completed this LiteRT-LM tutorial with Gemma 4 E2B. You now know how to:\n",
+ "1. Download `.litertlm` optimized model files from Hugging Face.\n",
+ "2. Run synchronous and asynchronous streaming text generation.\n",
+ "3. Use native on-device Multi-Modality by passing both **image** and **audio** files (running on-device OCR and ASR).\n",
+ "4. Configure custom system instructions with **switchable compact personas** and maintain conversation context.\n",
+ "5. Optimize performance with speculative decoding / Multi-Token Prediction (MTP).\n",
+ "6. Benchmark on-device execution speeds with the built-in suite.\n",
+ "\n",
+ "For more details and native deployment platforms, visit:\n",
+ "* **Official Documentation**: [https://ai.google.dev/edge/litert-lm](https://ai.google.dev/edge/litert-lm)\n",
+ "* **GitHub Repository**: [https://github.com/google-ai-edge/LiteRT-LM](https://github.com/google-ai-edge/LiteRT-LM)"
+ ],
+ "metadata": {
+ "id": "XgnfA2I_aDZY"
+ }
+ }
+ ]
+ }
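
The streaming cells in the notebook repeat the same nested loop to pull text out of each response chunk. The following is a minimal sketch of how that loop could be factored into a helper, assuming chunks are dictionaries shaped like the ones the notebook reads (a `content` list whose items carry `type` and `text` keys). The names `iter_text` and `fake_stream` are hypothetical and not part of LiteRT-LM; the stand-in list replaces the iterator returned by `conversation.send_message_async(...)` so that the snippet runs without a model.

```python
from typing import Any, Dict, Iterable, Iterator


def iter_text(stream: Iterable[Dict[str, Any]]) -> Iterator[str]:
    """Yield the text fragments from a stream of response chunks.

    Assumes each chunk looks like {"content": [{"type": "text", "text": "..."}]},
    the shape the notebook cells read. Chunks without a "content" key and
    items whose type is not "text" are skipped.
    """
    for chunk in stream:
        for item in chunk.get("content", []):
            if item.get("type") == "text":
                yield item["text"]


# Hypothetical stand-in for the iterator returned by send_message_async().
fake_stream = [
    {"content": [{"type": "text", "text": "Hello"}]},
    {"content": [{"type": "other"}]},  # skipped: not a text item
    {},  # skipped: no "content" key
    {"content": [{"type": "text", "text": ", world"}]},
]

# Join the fragments into the full response text.
full_text = "".join(iter_text(fake_stream))
print(full_text)
```

In a notebook cell, the same helper could wrap a real stream, for example `for piece in iter_text(stream): print(piece, end="", flush=True)`, which keeps the token-by-token display while removing the duplicated loops.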