import torch import numpy as np from PIL import Image from typing import List, Union, Optional, Tuple from dataclasses import dataclass import logging import math import requests from io import BytesIO import warnings from tqdm import tqdm from transformers import logging as transformers_logging from diffusers import DiffusionPipeline from diffusers.utils import BaseOutput @dataclass class SdxsPipelineOutput(BaseOutput): images: Union[List[Image.Image], List[List[Image.Image]], np.ndarray] prompt: Optional[Union[str, List[str]]] = None class SdxsPipeline(DiffusionPipeline): MAX_TEXT_TOKENS = 250 def __init__(self, vae, text_encoder, tokenizer, transformer, scheduler): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer, scheduler=scheduler ) self.vae_scale_factor = 8 if getattr(self.tokenizer, "pad_token_id", None) is None: self.tokenizer.pad_token_id = getattr(self.tokenizer, "eos_token_id", 248044) if getattr(self.text_encoder.config, "pad_token_id", None) is None: self.text_encoder.config.pad_token_id = self.tokenizer.pad_token_id @staticmethod def load_image(source: Union[str, Image.Image], max_size: int = 1152) -> Tuple[Image.Image, int, int]: """ Загрузчик картинок с умным ресайзом для диффузионных сетей. - Приводит большую сторону к max_size (увеличивает или уменьшает). - Меньшую сторону делает кратной 8 (округляя вниз). - Делает center-crop, чтобы не искажать пропорции. """ # 1. Загрузка изображения if isinstance(source, Image.Image): img = source elif isinstance(source, str): if source.startswith(("http://", "https://")): response = requests.get(source, stream=True) response.raise_for_status() img = Image.open(BytesIO(response.content)) else: img = Image.open(source) else: raise ValueError("Источник должен быть объектом PIL.Image, URL или путем к файлу.") # Обязательный перевод в RGB img = img.convert("RGB") orig_w, orig_h = img.size # 2. Вычисляем масштаб так, чтобы бОльшая сторона стала ровно max_size scale = max_size / max(orig_w, orig_h) target_w = int(orig_w * scale) target_h = int(orig_h * scale) # 3. Делаем пропорциональный ресайз img = img.resize((target_w, target_h), Image.LANCZOS) # 4. Вычисляем финальные размеры, кратные 8 (округление вниз) # Так как 1152 делится на 8 без остатка, большая сторона не изменится. # Изменится только меньшая сторона (максимум на 7 пикселей). final_w = (target_w // 8) * 8 final_h = (target_h // 8) * 8 # 5. Центральный кроп (если размеры изменились) if final_w != target_w or final_h != target_h: left = (target_w - final_w) // 2 top = (target_h - final_h) // 2 right = left + final_w bottom = top + final_h img = img.crop((left, top, right, bottom)) return img, final_w, final_h @torch.no_grad() def refine_prompts( self, prompts: Union[str, List[str]], system_prompt: Optional[str] = None, temperature: float = 0.6, # Снижено для thinking VL enable_thinking: bool = True ) -> List[str]: """Refines a list of prompts using the Text Encoder (LLM).""" device = self.device if system_prompt is None: system_prompt = ( "You are a visual translator. Take the user's input in any language and convert it into an English 'compact image description'. " "Describe what is visually happening in 1-2 sentences. " "Output ONLY the final description.\n\n" "Input: " ) prompts_list = [prompts] if isinstance(prompts, str) else prompts refined_list = [] for p in prompts_list: full_text = system_prompt + p messages = [{"role": "user", "content": [{"type": "text", "text": full_text}]}] # 1. ВКЛЮЧАЕМ THINKING ЗДЕСЬ через chat_template_kwargs inputs = self.tokenizer.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt", chat_template_kwargs={"enable_thinking": enable_thinking} ).to(device) # 2. Используем только нативные для transformers параметры generated_ids = self.text_encoder.generate( **inputs, max_new_tokens=self.MAX_TEXT_TOKENS, do_sample=True, temperature=temperature, top_p=0.95, # Вместо top_k используем top_p repetition_penalty=1.15, # Снижено с 1.15 (presence_penalty в HF заменяется этим) pad_token_id=self.tokenizer.pad_token_id, ) generated_ids_trimmed = [ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] output_text = self.tokenizer.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] # 3. ВАЖНО: Вырезаем блок размышлений ... из финального ответа # Если модель думала, она выдаст мысль, а потом ответ. Нам нужен только ответ. if " Tuple[torch.Tensor, torch.Tensor]: device = self.device dtype = self.transformer.dtype if text is None: text = "" if isinstance(text, str): text = [text] formatted_prompts = [] for t in text: messages = [{"role": "user", "content": [{"type": "text", "text": t}]}] formatted_prompts.append(self.tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)) toks = self.tokenizer( formatted_prompts, padding="max_length", max_length=self.MAX_TEXT_TOKENS, truncation=True, return_tensors="pt" ).to(device) outputs = self.text_encoder( input_ids=toks.input_ids, attention_mask=toks.attention_mask, output_hidden_states=True ) last_hidden = outputs.hidden_states[-2].to(dtype=dtype) lengths = toks.attention_mask.sum(dim=1) for i, length in enumerate(lengths): last_hidden[i, length:] = 0 return last_hidden, toks.attention_mask.to(dtype=torch.int64) @torch.no_grad() def encode_image(self, image: Image.Image, height: int, width: int) -> torch.Tensor: """Кодирует PIL Image в латенты с учетом нормализации VAE Cosmos.""" device = self.device dtype = self.transformer.dtype # 1. Ресайз и приведение к тензору [-1, 1] image = image.convert("RGB").resize((width, height), Image.LANCZOS) img_np = np.array(image).astype(np.float32) / 255.0 img_tensor = torch.from_numpy(img_np).permute(2, 0, 1) # [C, H, W] img_tensor = (img_tensor - 0.5) / 0.5 # Нормализация в [-1, 1] img_tensor = img_tensor.unsqueeze(0).unsqueeze(2) # Добавляем Batch и Time -> [1, C, 1, H, W] # Переключаем VAE на GPU для энкода self.vae.to(device=device) # 2. Энкодинг with torch.no_grad(): latent_dist = self.vae.encode(img_tensor.to(dtype=dtype, device=device)).latent_dist latents = latent_dist.sample() # [1, C, 1, H, W] # 3. Обратная нормализация (как у тебя в датасете) # Переводим из пространства VAE в нормализованное пространство Diffusion if getattr(self.vae.config, "latents_std", None) is not None: l_mean = torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype) l_std = torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype) latents = (latents - l_mean) / l_std # Убираем VAE с GPU для экономии памяти self.vae.to("cpu") torch.cuda.empty_cache() return latents @torch.no_grad() def __call__( self, prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 1152, width: int = 768, num_inference_steps: int = 35, guidance_scale: float = 4.0, num_images_per_prompt: int = 1, seed: Optional[int] = None, output_type: str = "pil", return_dict: bool = True, show_progress_bar: bool = False, ): """Standard text-to-image generation (Frames=1)""" return self._generate( prompt=prompt, negative_prompt=negative_prompt, num_frames=1, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, num_images_per_prompt=num_images_per_prompt, seed=seed, output_type=output_type, show_progress_bar=show_progress_bar, image=None, strength=None ) @torch.no_grad() def img2img( self, image: Image.Image, prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, strength: float = 0.95, height: int = 1152, width: int = 768, num_inference_steps: int = 40, guidance_scale: float = 4.0, num_images_per_prompt: int = 1, seed: Optional[int] = None, output_type: str = "pil", return_dict: bool = True, show_progress_bar: bool = True, ): """Image-to-Image generation""" return self._generate( prompt=prompt, negative_prompt=negative_prompt, num_frames=1, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, num_images_per_prompt=num_images_per_prompt, seed=seed, output_type=output_type, show_progress_bar=show_progress_bar, image=image, strength=strength ) @torch.no_grad() def generate_video( self, prompt: Optional[Union[str, List[str]]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, num_frames: int = 9, height: int = 512, width: int = 512, num_inference_steps: int = 40, guidance_scale: float = 4.0, num_images_per_prompt: int = 1, seed: Optional[int] = None, output_type: str = "pil", show_progress_bar: bool = True, ): """Text-to-video generation (Frames > 1)""" return self._generate( prompt=prompt, negative_prompt=negative_prompt, num_frames=num_frames, height=height, width=width, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, num_images_per_prompt=num_images_per_prompt, seed=seed, output_type=output_type, show_progress_bar=show_progress_bar, image=None, strength=None ) @torch.no_grad() def _generate( self, prompt, negative_prompt, num_frames, height, width, num_inference_steps, guidance_scale, num_images_per_prompt, # Принимаем параметр seed, output_type, show_progress_bar, image=None, strength=None ): device = self.device dtype = self.transformer.dtype if seed is not None: generator = torch.Generator(device=device).manual_seed(seed) else: generator = None do_classifier_free_guidance = guidance_scale > 1.0 # 1. Encode Prompts prompt_embeds, prompt_mask = self.encode_text(prompt) prompt_batch_size = prompt_embeds.shape[0] # Итоговый батч сайз зависит от кол-ва промптов и множителя batch_size = prompt_batch_size * num_images_per_prompt # Дублируем позитивные эмбеддинги, если запрашивается > 1 изображения на промпт if num_images_per_prompt > 1: prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) if do_classifier_free_guidance: # Обрабатываем негативный промпт if negative_prompt is None: neg_text = [""] * batch_size neg_embeds, neg_mask = self.encode_text(neg_text) else: neg_embeds, neg_mask = self.encode_text(negative_prompt) # Выравниваем размерности негативного промпта под финальный batch_size if neg_embeds.shape[0] < batch_size: if neg_embeds.shape[0] == 1: # Один негативный промпт применяется ко всем neg_embeds = neg_embeds.repeat(batch_size, 1, 1) elif neg_embeds.shape[0] == prompt_batch_size and num_images_per_prompt > 1: # Если негативных столько же, сколько позитивных, размножаем их пропорционально neg_embeds = neg_embeds.repeat_interleave(num_images_per_prompt, dim=0) else: raise ValueError(f"Cannot map {neg_embeds.shape[0]} negative prompts to {batch_size} outputs.") text_embeddings = torch.cat([neg_embeds, prompt_embeds], dim=0) else: text_embeddings = prompt_embeds # 2. Prepare Timesteps (Bridging Diffusers Scheduler + Karras EDM) # Ищем sigma_max/min в конфиге шедулера (если ты их туда добавил), # если их там нет — используем константы Cosmos по умолчанию. sigma_max = getattr(self.scheduler.config, "sigma_max", 10.0) sigma_min = getattr(self.scheduler.config, "sigma_min", 0.002) # Создаем наше любимое экспоненциальное убывание для лучшей детализации custom_sigmas = torch.exp(torch.linspace( math.log(sigma_max), math.log(sigma_min), num_inference_steps, device=device, dtype=torch.float32 )) custom_sigmas[-1] = sigma_min # Безопасная интеграция со стандартным шедулером diffusers # Пробуем скормить ему кастомные сигмы (поддерживается в новых версиях diffusers) try: self.scheduler.set_timesteps(sigmas=custom_sigmas, device=device) except TypeError: # Фолбэк для старых версий или стандартного FlowMatchEulerDiscreteScheduler: # Задаем шаги, а затем принудительно перезаписываем внутренние массивы времени self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.timesteps = custom_sigmas.to(device=device, dtype=self.scheduler.timesteps.dtype) # Добавляем 0.0 в конец sigmas, чтобы метод step() правильно посчитал dt на последнем шаге self.scheduler.sigmas = torch.cat([custom_sigmas, torch.tensor([0.0], device=device)]) timesteps = self.scheduler.timesteps # 3. Prepare Latents latent_h = height // self.vae_scale_factor latent_w = width // self.vae_scale_factor if num_frames > 1 and hasattr(self.vae.config, "temporal_downsample") and any(self.vae.config.temporal_downsample): temporal_factor = 4 latent_f = (num_frames - 1) // temporal_factor + 1 else: latent_f = num_frames in_channels = self.transformer.config.in_channels latents = torch.randn( (batch_size, in_channels, latent_f, latent_h, latent_w), generator=generator, device=device, dtype=dtype ) # ВАЖНО: Умножаем начальный шум на sigma_max latents = latents * sigma_max padding_mask = torch.zeros((1, 1, latent_h, latent_w), device=device, dtype=dtype) # 4. Denoising Loop (Теперь со стандартным шедулером!) for i, t in enumerate(tqdm(timesteps, desc=f"Sampling {num_frames} frames", disable=not show_progress_bar)): # Берем текущую сигму прямо из шедулера current_sigma = self.scheduler.sigmas[i] # Математика прекондиционирования Cosmos остается неизменной current_t = current_sigma / (current_sigma + 1.0) c_in = 1.0 - current_t c_skip = 1.0 - current_t c_out = -current_t timestep_tensor = current_t.expand(latents.shape[0]).to(dtype) # Дублируем входы для CFG latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents timestep_tensor = torch.cat([timestep_tensor] * 2) if do_classifier_free_guidance else timestep_tensor # Применяем c_in latent_model_input = latent_model_input * c_in model_output = self.transformer( hidden_states=latent_model_input, timestep=timestep_tensor, encoder_hidden_states=text_embeddings, padding_mask=padding_mask, return_dict=False, )[0] if do_classifier_free_guidance: v_uncond, v_cond = model_output.chunk(2) # Получаем предсказание чистой картинки (x_0) x0_cond = (c_skip * latents + c_out * v_cond.float()).to(dtype) x0_uncond = (c_skip * latents + c_out * v_uncond.float()).to(dtype) # Применяем CFG x0_pred = x0_uncond + guidance_scale * (x0_cond - x0_uncond) else: x0_pred = (c_skip * latents + c_out * model_output.float()).to(dtype) # --- МАГИЯ ИНТЕГРАЦИИ С DIFFUSERS --- # Стандартный шедулер ожидает на вход вектор скорости (derivative/noise_pred). # Мы переводим наш x0_pred обратно в формат производной: noise_pred = (latents - x0_pred) / current_sigma # Передаем управление методу .step()! # Он сам внутри посчитает dt и обновит латенты. latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] # 5. Decode if output_type == "latent": return SdxsPipelineOutput(images=latents) if getattr(self.vae.config, "latents_std", None) is not None: l_mean = torch.tensor(self.vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype) l_std = torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype) latents = latents * l_std + l_mean # Декодируем видео-латенты with torch.no_grad(): image_output = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0] # Нормализация image_output = (image_output.clamp(-1, 1) + 1) / 2 # image_output имеет размерность [B, C, F, H, W] # Переводим в [B, F, H, W, C] для numpy image_np = image_output.cpu().permute(0, 2, 3, 4, 1).float().numpy() image_np = np.nan_to_num(image_np, nan=0.0) if output_type == "pil": batch_images = [] for b in range(batch_size): frames = [] # Итерируемся по оси фреймов (F) for f in range(image_np.shape[1]): frame_arr = (image_np[b, f] * 255).round().astype("uint8") frames.append(Image.fromarray(frame_arr)) # Если генерировали 1 кадр, возвращаем просто список картинок (как раньше) if num_frames == 1: batch_images.append(frames[0]) else: # Для видео возвращаем список списков кадров (по списку на каждый элемент батча) batch_images.append(frames) images = batch_images else: images = image_np return SdxsPipelineOutput(images=images, prompt=prompt)