Unlocking Cognitive Capabilities and Analyzing the Perception-Logic Trade-off
Paper • 2602.23730 • Published
Resources for reproducing the results of the paper:
"Unlocking Cognitive Capabilities and Analyzing the Perception-Logic Trade-off"
MERaLiON-Omni-MM-10B is a 10B-parameter multimodal model that accepts image, video, and audio inputs and generates text output in five languages used across Southeast Asia (English, Mandarin, Indonesian, Thai, Malay).
Architecture: MERaLiON-2-10B (AudioLLM) + Qwen2.5-Omni ViT, trained with LoRA on ViT + audio encoder + LLM (ACRC strategy).
This is the pre-GRPO checkpoint. For the GRPO-aligned SOTA model, see MERaLiON-Omni-GRPO-10B.
This model is released under a Research-Only License. Commercial use is prohibited without explicit authorization.
Tested with the following environment:
| Dependency | Version |
|---|---|
| Python | 3.10 |
| torch | 2.6.0 |
| transformers | 4.52.0 |
| flash-attn | 2.7.4 |
| torchaudio | 2.6.0 |
| pillow | 11.1.0 |
| triton | 3.2.0 |
| CUDA | 12.7 |
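pip install torch==2.6.0 torchaudio==2.6.0 transformers==4.52.0 pillow==11.1.0 triton==3.2.0
pip install flash-attn --no-build-isolation  # build against the torch installed above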
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from mm_utils_mm import process_mm_info # bundled in this repository
model_path = "zzlynxSG/MERaLiON-Omni-MM-10B"

# The processor (custom code from the repo) is used below both for its
# tokenizer and for turning text + media into model inputs.
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
# Pad on the left so prompts end right before the generated tokens, which is
# the usual setup for batched generation.
processor.tokenizer.padding_side = "left"
# Fall back to EOS as the pad token when the tokenizer does not define one.
if processor.tokenizer.pad_token is None:
    processor.tokenizer.pad_token = processor.tokenizer.eos_token

# Load bfloat16 weights with FlashAttention-2; device_map="auto" lets
# transformers/accelerate place the model on the available device(s).
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    use_safetensors=True,
    device_map="auto",
)
# Pixel budgets passed to process_mm_info. They are multiples of 28*28
# (presumably the vision encoder's merged patch size — confirm in mm_utils_mm).
MIN_PIX = 4 * 28 * 28
MAX_PIX = 8192 * 28 * 28
# Used as the "total_pixels" budget for video input in the video example.
MAX_PIX_VIDEO = 8600 * 28 * 28

# --- Image understanding ---
# The <ImageHere> placeholder marks where the image features go in the prompt.
chat_prompt = processor.tokenizer.apply_chat_template(
    conversation=[[{"role": "user", "content": "Follow the text instruction based on the following image: <ImageHere> \n Describe this image."}]],
    tokenize=False,
    add_generation_prompt=True,
)[0]
mm_input = [{"role": "user", "content": [
    {"type": "image", "image": "your_image.jpg", "min_pixels": MIN_PIX, "max_pixels": MAX_PIX},
]}]
audios, images, videos = process_mm_info(mm_input)
inputs = processor(text=[chat_prompt], audios=audios, images=images, videos=videos).to(model.device).to(model.dtype)
with torch.inference_mode():
    outputs = model.generate(**inputs, max_new_tokens=512, temperature=0.7, do_sample=True)
# generate() returns prompt + completion for this model family; drop the
# prompt tokens so only the model's answer is printed.
generated_ids = outputs[:, inputs["input_ids"].shape[1]:]
print(processor.decode(generated_ids[0], skip_special_tokens=True))
# --- Audio understanding (speech transcription) ---
# The <SpeechHere> placeholder marks where the audio features go in the prompt.
chat_prompt = processor.tokenizer.apply_chat_template(
    conversation=[[{"role": "user", "content": "Follow the text instruction based on the following audio: <SpeechHere> \n Transcribe this audio."}]],
    tokenize=False,
    add_generation_prompt=True,
)[0]
mm_input = [{"role": "user", "content": [
    {"type": "audio", "audio": "your_audio.mp3"},
]}]
audios, images, videos = process_mm_info(mm_input)
inputs = processor(text=[chat_prompt], audios=audios, images=images, videos=videos).to(model.device).to(model.dtype)
with torch.inference_mode():
    outputs = model.generate(**inputs, max_new_tokens=512, temperature=0.7, do_sample=True)
# generate() returns prompt + completion for this model family; drop the
# prompt tokens so only the transcription is printed.
generated_ids = outputs[:, inputs["input_ids"].shape[1]:]
print(processor.decode(generated_ids[0], skip_special_tokens=True))
# --- Video understanding (frames + the video's audio track) ---
# The <VideoHere> placeholder marks where the video features go in the prompt.
chat_prompt = processor.tokenizer.apply_chat_template(
    conversation=[[{"role": "user", "content": "Follow the text instruction based on the following video: <VideoHere> \n Summarize this video."}]],
    tokenize=False,
    add_generation_prompt=True,
)[0]
mm_input = [{"role": "user", "content": [
    {"type": "video", "video": "your_video.mp4", "fps": 2, "min_frames": 32, "max_frames": 64,
     "min_pixels": MIN_PIX, "total_pixels": MAX_PIX_VIDEO},
]}]
# use_audio_in_video=True asks process_mm_info to also extract the video's audio.
audios, images, videos = process_mm_info(mm_input, use_audio_in_video=True)
inputs = processor(text=[chat_prompt], audios=audios, images=images, videos=videos).to(model.device).to(model.dtype)
with torch.inference_mode():
    outputs = model.generate(**inputs, max_new_tokens=512, temperature=0.7, do_sample=True)
# generate() returns prompt + completion for this model family; drop the
# prompt tokens so only the summary is printed.
generated_ids = outputs[:, inputs["input_ids"].shape[1]:]
print(processor.decode(generated_ids[0], skip_special_tokens=True))
from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
from threading import Thread
class StopWordCriteria(StoppingCriteria):
    """Stop generation once a sequence ends with the token ids of a stop word.

    Args:
        tokenizer: Tokenizer used to encode ``stop_word`` (without special tokens).
        device: Device on which the stop-sequence tensor is created.
        stop_word: Text that ends generation; defaults to ``"</answer>"``.

    NOTE: matching is done on token ids, so it may miss the stop word if the
    model emits it with a different tokenization than ``tokenizer.encode``
    produces for the bare string.
    """

    def __init__(self, tokenizer, device, stop_word="</answer>"):
        stop_ids = tokenizer.encode(stop_word, add_special_tokens=False)
        self.stop_seq = torch.tensor(stop_ids, dtype=torch.long, device=device)

    def __call__(self, input_ids, scores, **kwargs):
        """Return a bool tensor of shape (batch,): True where the row ends with the stop sequence."""
        n = self.stop_seq.numel()
        batch_size = input_ids.shape[0]
        # An empty stop word, or a sequence shorter than it, can never match.
        if n == 0 or input_ids.shape[1] < n:
            return torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
        # Align devices: with device_map="auto" the ids may not live where
        # the stop tensor was created.
        stop_seq = self.stop_seq.to(input_ids.device)
        return (input_ids[:, -n:] == stop_seq).all(dim=-1)
# --- Streaming generation (reuses `inputs` from the preceding example) ---
# skip_prompt=True makes the streamer emit only newly generated text.
streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
stop_criteria = StoppingCriteriaList([StopWordCriteria(processor.tokenizer, model.device)])
gen_kwargs = dict(
    **inputs,
    streamer=streamer,
    max_new_tokens=2048,
    temperature=0.75,
    top_p=0.9,
    top_k=50,
    do_sample=True,
    stopping_criteria=stop_criteria,
)
# generate() blocks until done, so run it on a worker thread and consume the
# streamed text chunks on this thread as they arrive.
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
for token in streamer:
    print(token, end="", flush=True)
print()
# Wait for the generation thread to finish before moving on.
thread.join()