# upscaler_specialist.py # Copyright (C) 2025 Carlos Rodrigues # Especialista ADUC para upscaling espacial de tensores latentes. import torch import logging from diffusers import LTXLatentUpsamplePipeline from ltx_manager_helpers import ltx_manager_singleton logger = logging.getLogger(__name__) class UpscalerSpecialist: """ Especialista responsável por aumentar a resolução espacial de tensores latentes usando o LTX Video Spatial Upscaler. """ def __init__(self): # Força uso de CUDA se disponível self.device = "cuda" if torch.cuda.is_available() else "cpu" self.base_vae = None self.pipe_upsample = None def _lazy_init(self): """Inicializa VAE e pipeline apenas quando necessário.""" if self.base_vae is None: try: if ltx_manager_singleton.workers: self.base_vae = ltx_manager_singleton.workers[0].pipeline.vae logger.info("[Upscaler] VAE base obtido com sucesso.") else: logger.warning("[Upscaler] Nenhum worker disponível no ltx_manager_singleton.") except Exception as e: logger.error(f"[Upscaler] Falha ao inicializar VAE: {e}") return if self.pipe_upsample is None and self.base_vae is not None: try: self.pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained( "linoyts/LTX-Video-spatial-upscaler-0.9.8", vae=self.base_vae, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32 ).to(self.device) logger.info("[Upscaler] Pipeline carregado com sucesso.") except Exception as e: logger.error(f"[Upscaler] Falha ao carregar pipeline: {e}") @torch.no_grad() def upscale(self, latents: torch.Tensor) -> torch.Tensor: """Aplica o upscaling 2x nos tensores latentes fornecidos.""" self._lazy_init() if self.pipe_upsample is None: logger.warning("[Upscaler] Pipeline indisponível. Retornando latentes originais.") return latents try: logger.info(f"[Upscaler] Recebido shape {latents.shape}. Executando upscale em {self.device}...") result = self.pipe_upsample(latents=latents, output_type="latent") logger.info(f"[Upscaler] Upscale concluído. Novo shape: {result.latents.shape}") return result.latents except Exception as e: logger.error(f"[Upscaler] Erro durante upscale: {e}", exc_info=True) return latents # --------------------------- # Singleton global # --------------------------- upscaler_specialist_singleton = UpscalerSpecialist()