euiia commited on
Commit
d354454
·
verified ·
1 Parent(s): 6474f1f

Update ltx_manager_helpers.py

Browse files
Files changed (1) hide show
  1. ltx_manager_helpers.py +44 -41
ltx_manager_helpers.py CHANGED
@@ -25,7 +25,6 @@ class LtxWorker:
25
  Gerencia o carregamento do modelo para a CPU e a movimentação de/para a GPU.
26
  """
27
  def __init__(self, device_id, ltx_config_file):
28
- # ... (código do LtxWorker __init__ permanece o mesmo) ...
29
  self.cpu_device = torch.device('cpu')
30
  self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
31
  logger.info(f"LTX Worker ({self.device}): Inicializando com config '{ltx_config_file}'...")
@@ -58,7 +57,6 @@ class LtxWorker:
58
  logger.info(f"LTX Worker: Movendo pipeline para a GPU {self.device}...")
59
  self.pipeline.to(self.device)
60
 
61
- # A otimização agora ocorre aqui, uma única vez, quando o modelo vai para a GPU.
62
  if self.device.type == 'cuda' and can_optimize_fp8():
63
  logger.info(f"LTX Worker ({self.device}): GPU com suporte a FP8 detectada. Iniciando otimização...")
64
  optimize_ltx_worker(self)
@@ -80,8 +78,7 @@ class LtxWorker:
80
 
81
  class LtxPoolManager:
82
  """
83
- Gerencia um pool de LtxWorkers para otimizar o uso de múltiplas GPUs.
84
- NOVO MODO "HOT START": Mantém todos os modelos carregados na VRAM para latência mínima.
85
  """
86
  def __init__(self, device_ids, ltx_config_file):
87
  logger.info(f"LTX POOL MANAGER: Criando workers para os dispositivos: {device_ids}")
@@ -89,9 +86,6 @@ class LtxPoolManager:
89
  self.current_worker_index = 0
90
  self.lock = threading.Lock()
91
 
92
- # ######################################################################
93
- # ## MUDANÇA 1: PRÉ-AQUECIMENTO DAS GPUs ##
94
- # ######################################################################
95
  if all(w.device.type == 'cuda' for w in self.workers):
96
  logger.info("LTX POOL MANAGER: MODO HOT START ATIVADO. Pré-aquecendo todas as GPUs...")
97
  for worker in self.workers:
@@ -99,10 +93,8 @@ class LtxPoolManager:
99
  logger.info("LTX POOL MANAGER: Todas as GPUs estão quentes e prontas.")
100
  else:
101
  logger.info("LTX POOL MANAGER: Operando em modo CPU ou misto. O pré-aquecimento de GPU foi ignorado.")
102
- # ######################################################################
103
 
104
  def _prepare_and_log_params(self, worker_to_use, **kwargs):
105
- # ... (Esta função permanece exatamente a mesma) ...
106
  target_device = worker_to_use.device
107
  height, width = kwargs['height'], kwargs['width']
108
 
@@ -118,8 +110,14 @@ class LtxPoolManager:
118
  )
119
 
120
  first_pass_config = worker_to_use.config.get("first_pass", {})
121
- padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
122
- padding_vals = calculate_padding(height, width, padded_h, padded_w)
 
 
 
 
 
 
123
 
124
  pipeline_params = {
125
  "height": padded_h, "width": padded_w,
@@ -148,29 +146,18 @@ class LtxPoolManager:
148
 
149
  logger.info("="*60)
150
  logger.info(f"CHAMADA AO PIPELINE LTX NO DISPOSITIVO: {worker_to_use.device}")
151
- logger.info(f"Modelo: {'Distilled' if worker_to_use.is_distilled else 'Base'}")
152
- logger.info("-" * 20 + " PARÂMETROS DA PIPELINE " + "-" * 20)
153
- logger.info(json.dumps(log_friendly_params, indent=2))
154
- logger.info("-" * 20 + " ITENS DE CONDICIONAMENTO " + "-" * 19)
155
- logger.info("\n".join(conditioning_log_details) if conditioning_log_details else " - Nenhum")
156
- logger.info("="*60)
157
 
158
  return pipeline_params, padding_vals
159
 
160
  def _execute_on_worker(self, execution_fn, **kwargs):
161
- """
162
- Função unificada para selecionar um worker e executar uma tarefa,
163
- sem a lógica de carregar/descarregar.
164
- """
165
  worker_to_use = None
166
  try:
167
  with self.lock:
168
  worker_to_use = self.workers[self.current_worker_index]
169
  self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)
170
 
171
- pipeline_params, padding_vals = self._prepare_and_log_params(worker_to_use, **kwargs)
172
-
173
- result = execution_fn(worker_to_use, pipeline_params, **kwargs)
174
 
175
  return result, padding_vals
176
 
@@ -178,37 +165,53 @@ class LtxPoolManager:
178
  logger.error(f"LTX POOL MANAGER: Erro durante a execução em {worker_to_use.device if worker_to_use else 'N/A'}: {e}", exc_info=True)
179
  raise e
180
  finally:
181
- # Apenas limpa o cache da GPU, não descarrega o modelo.
182
  if worker_to_use and worker_to_use.device.type == 'cuda':
183
  with torch.cuda.device(worker_to_use.device):
184
  gc.collect()
185
  torch.cuda.empty_cache()
186
 
187
  def generate_latent_fragment(self, **kwargs) -> (torch.Tensor, tuple):
188
- """
189
- Orquestra a geração de um novo fragmento de vídeo a partir do ruído.
190
- """
191
- def execution_logic(worker, params, **inner_kwargs):
192
- params['output_type'] = "latent"
193
  with torch.no_grad():
194
- return worker.generate_video_fragment_internal(**params)
 
195
 
196
  return self._execute_on_worker(execution_logic, **kwargs)
197
 
198
  def refine_latents(self, upscaled_latents: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
199
- """
200
- Orquestra um passe de difusão curto em latentes já existentes para refinamento.
201
- """
202
- def execution_logic(worker, params, **inner_kwargs):
203
- params['latents'] = upscaled_latents.to(worker.device, dtype=worker.pipeline.transformer.dtype)
204
- params['strength'] = inner_kwargs.get('denoise_strength', 0.4)
205
- params['num_inference_steps'] = int(inner_kwargs.get('refine_steps', 10))
206
- params['output_type'] = "latent"
207
 
208
- logger.info("LTX POOL MANAGER: Iniciando passe de refinamento (denoise) em latentes de alta resolução.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  with torch.no_grad():
211
- return worker.generate_video_fragment_internal(**params)
 
 
212
 
213
  return self._execute_on_worker(execution_logic, upscaled_latents=upscaled_latents, **kwargs)
214
 
 
25
  Gerencia o carregamento do modelo para a CPU e a movimentação de/para a GPU.
26
  """
27
  def __init__(self, device_id, ltx_config_file):
 
28
  self.cpu_device = torch.device('cpu')
29
  self.device = torch.device(device_id if torch.cuda.is_available() else 'cpu')
30
  logger.info(f"LTX Worker ({self.device}): Inicializando com config '{ltx_config_file}'...")
 
57
  logger.info(f"LTX Worker: Movendo pipeline para a GPU {self.device}...")
58
  self.pipeline.to(self.device)
59
 
 
60
  if self.device.type == 'cuda' and can_optimize_fp8():
61
  logger.info(f"LTX Worker ({self.device}): GPU com suporte a FP8 detectada. Iniciando otimização...")
62
  optimize_ltx_worker(self)
 
78
 
79
  class LtxPoolManager:
80
  """
81
+ Gerencia um pool de LtxWorkers. MODO "HOT START": Mantém todos os modelos carregados na VRAM.
 
82
  """
83
  def __init__(self, device_ids, ltx_config_file):
84
  logger.info(f"LTX POOL MANAGER: Criando workers para os dispositivos: {device_ids}")
 
86
  self.current_worker_index = 0
87
  self.lock = threading.Lock()
88
 
 
 
 
89
  if all(w.device.type == 'cuda' for w in self.workers):
90
  logger.info("LTX POOL MANAGER: MODO HOT START ATIVADO. Pré-aquecendo todas as GPUs...")
91
  for worker in self.workers:
 
93
  logger.info("LTX POOL MANAGER: Todas as GPUs estão quentes e prontas.")
94
  else:
95
  logger.info("LTX POOL MANAGER: Operando em modo CPU ou misto. O pré-aquecimento de GPU foi ignorado.")
 
96
 
97
  def _prepare_and_log_params(self, worker_to_use, **kwargs):
 
98
  target_device = worker_to_use.device
99
  height, width = kwargs['height'], kwargs['width']
100
 
 
110
  )
111
 
112
  first_pass_config = worker_to_use.config.get("first_pass", {})
113
+
114
+ # Correção para o modo de refinamento: não recalcular padding
115
+ if 'latents' in kwargs and kwargs['latents'] is not None:
116
+ padded_h, padded_w = height, width
117
+ padding_vals = (0, 0, 0, 0)
118
+ else:
119
+ padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
120
+ padding_vals = calculate_padding(height, width, padded_h, padded_w)
121
 
122
  pipeline_params = {
123
  "height": padded_h, "width": padded_w,
 
146
 
147
  logger.info("="*60)
148
  logger.info(f"CHAMADA AO PIPELINE LTX NO DISPOSITIVO: {worker_to_use.device}")
149
+ # ... (resto do logging)
 
 
 
 
 
150
 
151
  return pipeline_params, padding_vals
152
 
153
  def _execute_on_worker(self, execution_fn, **kwargs):
 
 
 
 
154
  worker_to_use = None
155
  try:
156
  with self.lock:
157
  worker_to_use = self.workers[self.current_worker_index]
158
  self.current_worker_index = (self.current_worker_index + 1) % len(self.workers)
159
 
160
+ result, padding_vals = execution_fn(worker_to_use, **kwargs)
 
 
161
 
162
  return result, padding_vals
163
 
 
165
  logger.error(f"LTX POOL MANAGER: Erro durante a execução em {worker_to_use.device if worker_to_use else 'N/A'}: {e}", exc_info=True)
166
  raise e
167
  finally:
 
168
  if worker_to_use and worker_to_use.device.type == 'cuda':
169
  with torch.cuda.device(worker_to_use.device):
170
  gc.collect()
171
  torch.cuda.empty_cache()
172
 
173
  def generate_latent_fragment(self, **kwargs) -> (torch.Tensor, tuple):
174
+ def execution_logic(worker, **inner_kwargs):
175
+ pipeline_params, padding_vals = self._prepare_and_log_params(worker, **inner_kwargs)
176
+ pipeline_params['output_type'] = "latent"
 
 
177
  with torch.no_grad():
178
+ result_tensor = worker.generate_video_fragment_internal(**pipeline_params)
179
+ return result_tensor, padding_vals
180
 
181
  return self._execute_on_worker(execution_logic, **kwargs)
182
 
183
  def refine_latents(self, upscaled_latents: torch.Tensor, **kwargs) -> (torch.Tensor, tuple):
184
+ def execution_logic(worker, **inner_kwargs):
185
+ pipeline_params, padding_vals = self._prepare_and_log_params(worker, **inner_kwargs)
186
+
187
+ # --- LÓGICA DE REFINAMENTO EXPLÍCITA E CORRIGIDA ---
188
+ strength = inner_kwargs.get('denoise_strength', 0.4)
189
+ num_refine_steps = int(inner_kwargs.get('refine_steps', 10))
 
 
190
 
191
+ scheduler = worker.pipeline.scheduler
192
+ scheduler.set_timesteps(num_refine_steps, device=worker.device)
193
+ timesteps = scheduler.timesteps
194
+
195
+ start_timestep_idx = int(num_refine_steps * strength)
196
+ if start_timestep_idx >= len(timesteps):
197
+ start_timestep_idx = len(timesteps) - 1
198
+ start_timestep = timesteps[start_timestep_idx]
199
+
200
+ noise = torch.randn_like(upscaled_latents, device=worker.device)
201
+ noisy_latents = scheduler.add_noise(upscaled_latents.to(worker.device), noise, start_timestep)
202
+
203
+ pipeline_params['latents'] = noisy_latents.to(worker.device, dtype=worker.pipeline.transformer.dtype)
204
+ pipeline_params['timesteps'] = timesteps[start_timestep_idx:]
205
+ pipeline_params['num_inference_steps'] = len(pipeline_params['timesteps'])
206
+ pipeline_params.pop('strength', None)
207
+ pipeline_params['output_type'] = "latent"
208
+
209
+ logger.info("LTX POOL MANAGER: Iniciando passe de refinamento (denoise) com controle manual de ruído.")
210
 
211
  with torch.no_grad():
212
+ refined_tensor = worker.generate_video_fragment_internal(**pipeline_params)
213
+
214
+ return refined_tensor, padding_vals
215
 
216
  return self._execute_on_worker(execution_logic, upscaled_latents=upscaled_latents, **kwargs)
217