# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

"""Asset scaling utility for resizing 3D assets and URDF files."""

import math
import os
import shutil
import xml.etree.ElementTree as ET
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import trimesh
import tyro

from embodied_gen.models.gs_model import GaussianOperator
from embodied_gen.utils.log import logger

__all__ = ["AssetScaler", "AssetScaleConfig", "scale_asset", "entrypoint"]

# URDF metadata height fields (shared with urdf_convertor.py)
URDF_HEIGHT_FIELDS = ("min_height", "max_height", "real_height")

# Asset directory structure conventions
URDF_RESULT_DIR = "result"
MESH_DIR = "mesh"


@dataclass
class AssetScaleConfig:
    """Configuration for asset scaling.

    Args:
        urdf_path: Path to the URDF file to scale.
        scale_factor: Scaling factor (e.g., 0.8 for 80% size).
        inplace: If True, modify files in-place. output_dir is not required.
        output_dir: Root output directory for scaled assets
            (not needed if inplace=True).
    """

    urdf_path: str
    scale_factor: float
    inplace: bool = False
    output_dir: Optional[str] = None


class AssetScaler:
    """Scale 3D assets including meshes, Gaussian splats, and URDF metadata.

    This class handles the complete scaling workflow for embodied assets,
    processing OBJ, GLB, collision meshes, Gaussian splatting models, and
    URDF metadata files.

    Expected layout: ``<asset_dir>/result/<node_name>.urdf`` with meshes in
    ``<asset_dir>/result/mesh/``.
    """

    def __init__(
        self,
        urdf_path: str | Path,
        scale_factor: float,
        output_dir: Optional[str | Path] = None,
        inplace: bool = False,
    ) -> None:
        """Initialize the asset scaler.

        Args:
            urdf_path: Path to the URDF file to scale.
            scale_factor: Scaling factor (e.g., 0.8 for 80% size). Must be a
                positive, finite number.
            output_dir: Root output directory for scaled assets
                (not needed if inplace=True; ignored with a warning if given).
            inplace: If True, modify files in-place instead of copying to
                output_dir.

        Raises:
            FileNotFoundError: If URDF file does not exist.
            ValueError: If scale_factor is not positive and finite, if
                neither output_dir nor inplace is specified, or if the
                output asset directory overlaps the source asset directory.
        """
        self.urdf_path = Path(urdf_path)
        self.scale_factor = scale_factor
        self.inplace = inplace

        if not self.urdf_path.exists():
            raise FileNotFoundError(f"URDF file not found: {self.urdf_path}")
        # `nan <= 0` is False, so finiteness must be checked explicitly.
        if not math.isfinite(self.scale_factor) or self.scale_factor <= 0:
            raise ValueError(
                "Scale factor must be a positive finite number, "
                f"got: {self.scale_factor}"
            )

        # Derive asset directory structure from URDF path.
        # URDF is at: <asset_dir>/result/<node_name>.urdf
        # Normalise to an absolute path first: for relative inputs such as
        # "result/x.urdf" or "../result/x.urdf", `.parent.parent` would be
        # "." or "..", whose `.name` is not a real directory name.
        self.asset_dir = Path(os.path.abspath(self.urdf_path)).parent.parent
        self.node_name = self.urdf_path.stem

        # Handle inplace mode
        if self.inplace:
            if output_dir is not None:
                logger.warning(
                    f"output_dir={output_dir} is ignored when inplace=True"
                )
            self.output_dir = self.asset_dir.parent
            logger.info(
                f"Running in inplace mode, will modify {self.asset_dir} directly"
            )
        else:
            if output_dir is None:
                raise ValueError("output_dir is required when inplace=False")
            self.output_dir = Path(output_dir)
            self._ensure_output_is_separate(
                self.output_dir / self.asset_dir.name
            )

    def _ensure_output_is_separate(self, output_asset_dir: Path) -> None:
        """Reject output locations that overlap the source asset directory.

        Copy mode deletes ``output_asset_dir`` with ``shutil.rmtree`` and
        then repopulates it with ``shutil.copytree``. If it is the source
        directory or an ancestor of it, the source would be deleted before
        being copied. If it lies inside the source, copytree would be
        copying a directory into itself.

        Args:
            output_asset_dir: Planned destination for the copied asset.

        Raises:
            ValueError: If the destination and source directories overlap.
        """
        source = self.asset_dir.resolve()
        target = output_asset_dir.resolve()
        if target.is_relative_to(source) or source.is_relative_to(target):
            raise ValueError(
                f"Output asset directory {output_asset_dir} overlaps the "
                f"source asset directory {self.asset_dir}; use inplace=True "
                "to modify the source, or choose a different output_dir."
            )

    def scale(self) -> Path:
        """Execute the complete scaling workflow.

        Mesh files that are absent from the mesh directory are skipped.

        Returns:
            Path to the output URDF file.

        Raises:
            Exception: Any error raised while loading, scaling, or writing a
                mesh, splat, or URDF file is propagated.
        """
        if self.inplace:
            # Inplace mode: scale directly in asset_dir
            output_urdf_path = self.urdf_path
            self._scale_mesh_files_parallel(self.asset_dir)
            self._scale_urdf_metadata(output_urdf_path)
            logger.info(
                f"Scaled {self.asset_dir} by x{self.scale_factor} (inplace)"
            )
            return output_urdf_path

        # Normal mode: copy to output_dir, then scale the copy.
        output_asset_dir = self.output_dir / self.asset_dir.name
        output_urdf_path = self._copy_asset_structure(output_asset_dir)
        self._scale_mesh_files_parallel(output_asset_dir)
        self._scale_urdf_metadata(output_urdf_path)
        logger.info(
            f"Scaled {self.asset_dir} by x{self.scale_factor} -> {output_asset_dir}"
        )
        return output_urdf_path

    def _copy_asset_structure(self, output_asset_dir: Path) -> Path:
        """Copy asset directory structure to output location.

        Any existing content at ``output_asset_dir`` is removed first.

        Args:
            output_asset_dir: Target directory for copied assets.

        Returns:
            Path to the copied URDF file.
        """
        output_asset_dir.parent.mkdir(parents=True, exist_ok=True)
        # ignore_errors=True avoids an exists()-then-rmtree TOCTOU race.
        shutil.rmtree(output_asset_dir, ignore_errors=True)
        shutil.copytree(self.asset_dir, output_asset_dir)
        return output_asset_dir / URDF_RESULT_DIR / f"{self.node_name}.urdf"

    def _scale_mesh_files_parallel(self, output_asset_dir: Path) -> None:
        """Scale all mesh files in parallel for efficiency.

        Args:
            output_asset_dir: Directory containing assets to scale.
        """
        mesh_dir = output_asset_dir / URDF_RESULT_DIR / MESH_DIR

        # (file path, scaler) pairs; each scaler skips a missing file.
        tasks = [
            (mesh_dir / f"{self.node_name}.obj", self._scale_obj_mesh),
            (mesh_dir / f"{self.node_name}.glb", self._scale_glb_mesh),
            (
                mesh_dir / f"{self.node_name}_collision.obj",
                self._scale_collision_mesh,
            ),
            (
                mesh_dir / f"{self.node_name}_gs.ply",
                self._scale_gaussian_splat,
            ),
        ]

        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = [executor.submit(task, path) for path, task in tasks]
            for future in futures:
                future.result()  # Propagate any exceptions

    def _scale_obj_mesh(self, mesh_path: Path) -> None:
        """Scale OBJ mesh file in place; no-op if the file is missing."""
        if not mesh_path.exists():
            return
        mesh = trimesh.load(str(mesh_path))
        mesh.apply_scale(self.scale_factor)
        mesh.export(str(mesh_path))

    def _scale_glb_mesh(self, mesh_path: Path) -> None:
        """Scale GLB mesh file in place; no-op if the file is missing."""
        if not mesh_path.exists():
            return
        loaded = trimesh.load(str(mesh_path))
        if isinstance(loaded, trimesh.Scene):
            # NOTE(review): only the geometry vertices are scaled; scene-graph
            # node transforms (including any translations) are unchanged.
            # Confirm generated GLBs use identity node transforms.
            for mesh_part in loaded.geometry.values():
                mesh_part.apply_scale(self.scale_factor)
        else:
            # trimesh may return a single Trimesh, which has no `.geometry`.
            loaded.apply_scale(self.scale_factor)
        loaded.export(str(mesh_path))

    def _scale_collision_mesh(self, mesh_path: Path) -> None:
        """Scale multi-part collision OBJ in place; no-op if missing."""
        if not mesh_path.exists():
            return
        meshes = self._load_collision_obj(str(mesh_path))
        scene = trimesh.Scene()
        for mesh_part in meshes:
            mesh_part.apply_scale(self.scale_factor)
            scene.add_geometry(mesh_part)
        scene.export(str(mesh_path))

    def _scale_gaussian_splat(self, mesh_path: Path) -> None:
        """Scale Gaussian splatting PLY in place; no-op if missing."""
        if not mesh_path.exists():
            return
        gs_model: GaussianOperator = GaussianOperator.load_from_ply(
            str(mesh_path)
        )
        gs_model.rescale(self.scale_factor)
        gs_model.save_to_ply(str(mesh_path))

    def _scale_urdf_metadata(self, urdf_path: Path) -> None:
        """Scale height metadata in URDF file.

        Multiplies each of URDF_HEIGHT_FIELDS found under the first
        ``link/extra_info`` element by the scale factor, written with three
        decimals. Logs a warning and returns if ``extra_info`` is absent.

        Args:
            urdf_path: Path to URDF file to modify.
        """
        tree = ET.parse(str(urdf_path))
        root = tree.getroot()
        extra_info = root.find("link/extra_info")
        if extra_info is None:
            logger.warning(f"No extra_info found in URDF: {urdf_path}")
            return

        for height_field in URDF_HEIGHT_FIELDS:
            element = extra_info.find(height_field)
            if element is not None and element.text:
                scaled_value = float(element.text) * self.scale_factor
                element.text = f"{scaled_value:.3f}"

        tree.write(str(urdf_path), encoding="utf-8", xml_declaration=True)

    @staticmethod
    def _obj_vertex_index(token: str, num_vertices: int) -> int:
        """Convert one OBJ face token (``v``, ``v/vt`` or ``v/vt/vn``) to a
        0-based vertex index.

        OBJ indices are 1-based; negative indices count backwards from the
        most recently defined vertex (-1 is the last one read so far).

        Raises:
            ValueError: If the index is 0, which OBJ does not allow.
        """
        index = int(token.split("/")[0])
        if index > 0:
            return index - 1
        if index < 0:
            return num_vertices + index
        raise ValueError("Invalid OBJ vertex index 0")

    @staticmethod
    def _load_collision_obj(filepath: str) -> list[trimesh.Trimesh]:
        """Robustly load collision OBJ with multiple objects.

        Handles OBJ files with multiple objects/groups by parsing manually to
        avoid issues with trimesh's default loader. Only ``v``, ``f``, ``o``
        and ``g`` records are interpreted; everything else is ignored.

        Args:
            filepath: Path to collision OBJ file.

        Returns:
            List of trimesh objects, one per object group in the file
            (groups without faces are skipped).
        """
        vertices: list[list[float]] = []
        meshes: list[trimesh.Trimesh] = []
        current_faces: list[list[int]] = []

        def flush_part(faces: list[list[int]]) -> None:
            # OBJ vertex indices are file-global, so each part is built from
            # the full vertex list and then pruned to what it references.
            if faces and vertices:
                part = trimesh.Trimesh(
                    vertices=vertices, faces=faces, process=False
                )
                part.remove_unreferenced_vertices()
                meshes.append(part)

        # Lazy line iteration instead of readlines() for memory efficiency.
        with open(filepath, "r", encoding="utf-8") as f:
            for line in f:
                tokens = line.split()  # also handles tab separators
                if not tokens:
                    continue
                keyword = tokens[0]
                if keyword == "v":
                    vertices.append(
                        [float(tokens[1]), float(tokens[2]), float(tokens[3])]
                    )
                elif keyword == "f":
                    num_vertices = len(vertices)
                    current_faces.append(
                        [
                            AssetScaler._obj_vertex_index(t, num_vertices)
                            for t in tokens[1:]
                        ]
                    )
                elif keyword in ("o", "g"):
                    flush_part(current_faces)
                    current_faces = []

        # Flush final mesh
        flush_part(current_faces)
        return meshes


def scale_asset(
    urdf_path: str | Path,
    scale_factor: float,
    output_dir: Optional[str | Path] = None,
    inplace: bool = False,
) -> Path:
    """Scale a 3D asset from URDF file.

    Args:
        urdf_path: Path to the URDF file to scale.
        scale_factor: Scaling factor (e.g., 0.8 for 80% size).
        output_dir: Root output directory for scaled assets
            (not needed if inplace=True).
        inplace: If True, modify files in-place instead of copying to
            output_dir.

    Returns:
        Path to the output URDF file.
    """
    scaler = AssetScaler(urdf_path, scale_factor, output_dir, inplace)
    return scaler.scale()


def entrypoint() -> None:
    """CLI entrypoint for asset scaling."""
    config = tyro.cli(AssetScaleConfig)
    output_urdf = scale_asset(
        urdf_path=config.urdf_path,
        scale_factor=config.scale_factor,
        output_dir=config.output_dir,
        inplace=config.inplace,
    )
    logger.info(f"Scaled asset successfully: {output_urdf}")


if __name__ == "__main__":
    entrypoint()