javipd99 committed on
Commit
389eab4
·
verified ·
1 Parent(s): 3278331

Pre-warm the VLM vision encoder on GPU (load the VLM before the other models) and reduce num_ctx from 8192 to 4096

Browse files
eneas/segmentation/generic_category.py CHANGED
@@ -286,6 +286,23 @@ class GenericCategorySegmenter:
286
  logger.warning(f"Could not pull model (server may be down or model unavailable): {e}")
287
  logger.info("Will attempt to use model anyway (may already be cached)")
288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  # Mark VLM model as loaded and ready for inference
290
  self.vlm_model = True
291
 
@@ -468,7 +485,7 @@ Example responses:
468
  model=self.vlm_model_name,
469
  messages=messages,
470
  format=ValidationResult.model_json_schema(),
471
- options={"temperature": 0.0, "num_predict": num_predict, "num_ctx": 8192},
472
  keep_alive=-1,
473
  )
474
 
@@ -851,9 +868,9 @@ Example responses:
851
  )
852
 
853
  # Load models
 
854
  self._load_grounding_model()
855
  self._load_image_text_model()
856
- self._load_vlm_model()
857
 
858
  # Load SAM2 model for segmentation
859
  self._load_sam2_model()
 
286
  logger.warning(f"Could not pull model (server may be down or model unavailable): {e}")
287
  logger.info("Will attempt to use model anyway (may already be cached)")
288
 
289
+ # Warm the vision encoder onto the GPU while VRAM is free; otherwise Ollama
290
+ # offloads the mmproj projector to CPU under VRAM pressure (~9s/image vs ~1s).
291
+ if self.device == "cuda":
292
+ try:
293
+ buf = io.BytesIO()
294
+ Image.new("RGB", (64, 64), (32, 32, 32)).save(buf, format="JPEG")
295
+ dummy_image = base64.b64encode(buf.getvalue()).decode("utf-8")
296
+ ollama.chat(
297
+ model=self.vlm_model_name,
298
+ messages=[{"role": "user", "content": "ok", "images": [dummy_image]}],
299
+ options={"temperature": 0.0, "num_predict": 1, "num_ctx": 4096},
300
+ keep_alive=-1,
301
+ )
302
+ logger.info("VLM vision encoder pre-warmed on GPU")
303
+ except Exception as e:
304
+ logger.warning(f"VLM pre-warm failed (non-fatal): {e}")
305
+
306
  # Mark VLM model as loaded and ready for inference
307
  self.vlm_model = True
308
 
 
485
  model=self.vlm_model_name,
486
  messages=messages,
487
  format=ValidationResult.model_json_schema(),
488
+ options={"temperature": 0.0, "num_predict": num_predict, "num_ctx": 4096},
489
  keep_alive=-1,
490
  )
491
 
 
868
  )
869
 
870
  # Load models
871
+ self._load_vlm_model()
872
  self._load_grounding_model()
873
  self._load_image_text_model()
 
874
 
875
  # Load SAM2 model for segmentation
876
  self._load_sam2_model()