| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import logging |
| import os |
| import random |
| import subprocess |
|
|
| import numpy as np |
| import torch |
| from diffusers import ( |
| AutoencoderKL, |
| EulerDiscreteScheduler, |
| UNet2DConditionModel, |
| ) |
| from kolors.models.modeling_chatglm import ChatGLMModel |
| from kolors.models.tokenization_chatglm import ChatGLMTokenizer |
| from kolors.models.unet_2d_condition import ( |
| UNet2DConditionModel as UNet2DConditionModelIP, |
| ) |
| from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import ( |
| StableDiffusionXLPipeline, |
| ) |
| from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import ( |
| StableDiffusionXLPipeline as StableDiffusionXLPipelineIP, |
| ) |
| from PIL import Image |
| from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection |
|
|
# Module-wide logging: INFO level via the root logger, plus a named logger
# for this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Public API of this module.
__all__ = [
    "build_text2img_ip_pipeline",
    "build_text2img_pipeline",
    "text2img_gen",
    "download_kolors_weights",
]

# Prompt templates; ``{object}`` is filled with the user's object description.
# PROMPT_APPEND is an alternative template that text2img_gen does not use.
PROMPT_APPEND = (
    "Angled 3D view of one {object}, centered, no cropping, no occlusion, "
    "isolated product photo, placed horizontally, no surroundings, "
    "high-quality appearance, vivid colors, on a plain clean surface, "
    "3D style revealing multiple surfaces"
)
# Template applied by text2img_gen to every prompt.
PROMPT_KAPPEND = (
    "Single {object}, in the center of the image, "
    "white background, 3D style, best quality"
)
|
|
|
|
def download_kolors_weights(local_dir: str = "weights/Kolors") -> None:
    """Download the Kolors base weights and the Kolors IP-Adapter-Plus weights.

    Two HuggingFace repositories are fetched, in this order, with the
    ``huggingface-cli download`` command:

    * ``Kwai-Kolors/Kolors`` into ``local_dir``.
    * ``Kwai-Kolors/Kolors-IP-Adapter-Plus`` into the sibling directory
      ``{local_dir}/../Kolors-IP-Adapter-Plus`` (the location the IP-Adapter
      pipeline builder in this module loads it from).

    The CLI is invoked on every call; no local existence check is done here.

    Args:
        local_dir (str, optional): Local directory to store the Kolors weights.

    Raises:
        subprocess.CalledProcessError: If a download command exits non-zero.
        FileNotFoundError: If the ``huggingface-cli`` executable is not found.
    """
    logger.info("Download kolors weights from huggingface...")
    os.makedirs(local_dir, exist_ok=True)

    # The IP-Adapter weights are stored next to (not inside) the base weights.
    downloads = (
        ("Kwai-Kolors/Kolors", local_dir),
        (
            "Kwai-Kolors/Kolors-IP-Adapter-Plus",
            f"{local_dir}/../Kolors-IP-Adapter-Plus",
        ),
    )
    for repo_id, target_dir in downloads:
        subprocess.run(
            [
                "huggingface-cli",
                "download",
                "--resume-download",
                repo_id,
                "--local-dir",
                target_dir,
            ],
            check=True,
        )
|
|
|
|
def build_text2img_ip_pipeline(
    ckpt_dir: str,
    ref_scale: float,
    device: str = "cuda",
) -> StableDiffusionXLPipelineIP:
    """Assemble a Kolors SDXL pipeline conditioned on a reference image (IP-Adapter).

    Weights are (re)fetched through ``download_kolors_weights(ckpt_dir)``.
    Every model component is then loaded from ``ckpt_dir`` in float16, and the
    IP-Adapter-Plus weights and CLIP image encoder are loaded from the sibling
    directory ``{ckpt_dir}/../Kolors-IP-Adapter-Plus``.

    Args:
        ckpt_dir (str): Directory holding the Kolors checkpoints.
        ref_scale (float): IP-Adapter scale applied to the reference image.
        device (str, optional): Device the finished pipeline is moved to.

    Returns:
        StableDiffusionXLPipelineIP: Ready-to-use pipeline on ``device``.

    Example:
        ```py
        from embodied_gen.models.text_model import build_text2img_ip_pipeline
        pipe = build_text2img_ip_pipeline("weights/Kolors", ref_scale=0.3)
        ```
    """
    download_kolors_weights(ckpt_dir)

    fp16 = torch.float16
    glm_dir = f"{ckpt_dir}/text_encoder"
    adapter_dir = f"{ckpt_dir}/../Kolors-IP-Adapter-Plus"

    glm_encoder = ChatGLMModel.from_pretrained(glm_dir, torch_dtype=fp16).half()
    glm_tokenizer = ChatGLMTokenizer.from_pretrained(glm_dir)
    autoencoder = AutoencoderKL.from_pretrained(
        f"{ckpt_dir}/vae", revision=None
    ).half()
    noise_scheduler = EulerDiscreteScheduler.from_pretrained(
        f"{ckpt_dir}/scheduler"
    )
    denoiser = UNet2DConditionModelIP.from_pretrained(
        f"{ckpt_dir}/unet", revision=None
    ).half()
    clip_encoder = CLIPVisionModelWithProjection.from_pretrained(
        f"{adapter_dir}/image_encoder",
        ignore_mismatched_sizes=True,
    ).to(dtype=fp16)
    clip_processor = CLIPImageProcessor(size=336, crop_size=336)

    pipeline = StableDiffusionXLPipelineIP(
        vae=autoencoder,
        text_encoder=glm_encoder,
        tokenizer=glm_tokenizer,
        unet=denoiser,
        scheduler=noise_scheduler,
        image_encoder=clip_encoder,
        feature_extractor=clip_processor,
        force_zeros_for_empty_prompt=False,
    )

    # Keep a reference to the UNet's existing ``encoder_hid_proj`` under
    # ``text_encoder_hid_proj`` before ``load_ip_adapter`` runs. NOTE(review):
    # presumably this preserves the text projection the Kolors UNet relies on;
    # confirm against kolors.models.unet_2d_condition.
    unet_module = pipeline.unet
    if hasattr(unet_module, "encoder_hid_proj"):
        unet_module.text_encoder_hid_proj = unet_module.encoder_hid_proj

    pipeline.load_ip_adapter(
        adapter_dir,
        subfolder="",
        weight_name=["ip_adapter_plus_general.bin"],
    )
    pipeline.set_ip_adapter_scale([ref_scale])

    pipeline = pipeline.to(device)
    # The image encoder is moved explicitly as well. NOTE(review): pipeline.to
    # likely already covers it; kept to preserve the original behavior.
    pipeline.image_encoder = pipeline.image_encoder.to(device)
    return pipeline
|
|
|
|
def build_text2img_pipeline(
    ckpt_dir: str,
    device: str = "cuda",
) -> StableDiffusionXLPipeline:
    """Assemble a plain (text-only) Kolors SDXL pipeline.

    Weights are (re)fetched through ``download_kolors_weights(ckpt_dir)``,
    then the ChatGLM text encoder and tokenizer, VAE, Euler scheduler and UNet
    are loaded from ``ckpt_dir``. The neural components are cast to float16.

    Args:
        ckpt_dir (str): Directory holding the Kolors checkpoints.
        device (str, optional): Device the finished pipeline is moved to.

    Returns:
        StableDiffusionXLPipeline: Ready-to-use pipeline on ``device``.

    Example:
        ```py
        from embodied_gen.models.text_model import build_text2img_pipeline
        pipe = build_text2img_pipeline("weights/Kolors")
        ```
    """
    download_kolors_weights(ckpt_dir)

    fp16 = torch.float16
    glm_dir = f"{ckpt_dir}/text_encoder"

    # Components are loaded in the same order as before; the pipeline
    # constructor binds them by keyword.
    components = dict(
        text_encoder=ChatGLMModel.from_pretrained(
            glm_dir, torch_dtype=fp16
        ).half(),
        tokenizer=ChatGLMTokenizer.from_pretrained(glm_dir),
        vae=AutoencoderKL.from_pretrained(
            f"{ckpt_dir}/vae", revision=None
        ).half(),
        scheduler=EulerDiscreteScheduler.from_pretrained(
            f"{ckpt_dir}/scheduler"
        ),
        unet=UNet2DConditionModel.from_pretrained(
            f"{ckpt_dir}/unet", revision=None
        ).half(),
    )
    return StableDiffusionXLPipeline(
        **components,
        force_zeros_for_empty_prompt=False,
    ).to(device)
|
|
|
|
def text2img_gen(
    prompt: str,
    n_sample: int,
    guidance_scale: float,
    pipeline: StableDiffusionXLPipeline | StableDiffusionXLPipelineIP,
    ip_image: Image.Image | str | None = None,
    image_wh: tuple[int, int] = (1024, 1024),
    infer_step: int = 50,
    ip_image_size: int = 512,
    seed: int | None = None,
) -> list[Image.Image]:
    """Generate images for an object description with a Kolors SDXL pipeline.

    The prompt is stripped and inserted into the ``PROMPT_KAPPEND`` template
    before generation.

    Args:
        prompt (str): Object description to render.
        n_sample (int): Number of images to generate
            (``num_images_per_prompt``).
        guidance_scale (float): Classifier-free guidance scale.
        pipeline (StableDiffusionXLPipeline | StableDiffusionXLPipelineIP):
            Pipeline instance, e.g. from ``build_text2img_pipeline`` or
            ``build_text2img_ip_pipeline``.
        ip_image (Image.Image | str | None, optional): Reference image, or a
            path to one, forwarded as ``ip_adapter_image``. It is resized to
            ``ip_image_size`` x ``ip_image_size`` first. Presumably only
            meaningful with an IP-Adapter pipeline.
        image_wh (tuple[int, int], optional): Output size as (width, height).
        infer_step (int, optional): Number of denoising steps.
        ip_image_size (int, optional): Side length the reference image is
            resized to.
        seed (int | None, optional): If given, seeds a ``torch.Generator`` on
            ``pipeline.device`` and also the global torch, NumPy and
            ``random`` RNGs.

    Returns:
        list[Image.Image]: The ``images`` field of the pipeline output.

    Example:
        ```py
        from embodied_gen.models.text_model import (
            build_text2img_pipeline,
            text2img_gen,
        )
        pipe = build_text2img_pipeline("weights/Kolors")
        images = text2img_gen(
            prompt="banana", n_sample=3, guidance_scale=7.5, pipeline=pipe
        )
        images[0].save("banana.png")
        ```
    """
    prompt = PROMPT_KAPPEND.format(object=prompt.strip())
    logger.info("Processing prompt: %s", prompt)

    generator = None
    if seed is not None:
        generator = torch.Generator(pipeline.device).manual_seed(seed)
        # Also seed the global RNGs so any other randomness is reproducible.
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    kwargs = dict(
        prompt=prompt,
        height=image_wh[1],
        width=image_wh[0],
        num_inference_steps=infer_step,
        guidance_scale=guidance_scale,
        num_images_per_prompt=n_sample,
        generator=generator,
    )
    if ip_image is not None:
        target_size = (ip_image_size, ip_image_size)
        if isinstance(ip_image, str):
            # Image.open is lazy and keeps the file open. resize() returns a
            # new, fully loaded image, so the handle can be closed right away.
            with Image.open(ip_image) as source:
                ip_image = source.resize(target_size)
        else:
            ip_image = ip_image.resize(target_size)
        kwargs.update(ip_adapter_image=[ip_image])

    return pipeline(**kwargs).images
|
|