gaussian-city / app.py
hysts's picture
hysts HF Staff
feat: migrate to ZeroGPU on Python 3.12 + PyTorch 2.8.
46b4a9e
Raw
History Blame
3.82 kB
import logging
import os
import pickle
import ssl
import sys
import urllib.request
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
# NOTE(review): this swaps the process-wide default HTTPS context for one that
# skips certificate verification. Every HTTPS connection opened through
# http.client/urllib without an explicit context, including the checkpoint
# downloads in get_models, is therefore open to man-in-the-middle tampering.
# Consider removing this once the runtime has a working CA bundle.
ssl._create_default_https_context = ssl._create_unverified_context

# Put the "gaussiancity" directory that sits next to this file on the module
# search path. Presumably this lets modules inside it import each other by
# top-level name; confirm against the package layout.
sys.path.append(
    os.path.join(os.path.split(__file__)[0], "gaussiancity")
)
def get_models(file_name):
    """Load a pretrained GaussianCity generator, downloading it if needed.

    Args:
        file_name: Checkpoint file name. It serves as both the local path
            (relative to the current working directory) and the file name
            inside the ``hzxie/gaussian-city`` Hugging Face repository.

    Returns:
        The ``gaussiancity.generator.Generator`` wrapped in
        ``torch.nn.DataParallel``, moved to CUDA, switched to eval mode, and
        loaded with the checkpoint's ``"gaussian_g"`` state dict.
    """
    import gaussiancity.generator

    if not os.path.exists(file_name):
        # Download to a side file and rename only after it completes. An
        # interrupted download then never leaves a truncated checkpoint that
        # the exists() check above would accept on the next start.
        partial_name = file_name + ".part"
        try:
            urllib.request.urlretrieve(
                f"https://huggingface.co/hzxie/gaussian-city/resolve/main/{file_name}",
                partial_name,
            )
            os.replace(partial_name, file_name)
        finally:
            if os.path.exists(partial_name):
                os.remove(partial_name)

    # weights_only=False is required because the checkpoint also stores a
    # config object under "cfg", not just tensors. This unpickles arbitrary
    # objects, so only load trusted checkpoints.
    ckpt = torch.load(file_name, map_location="cpu", weights_only=False)
    cfg = ckpt["cfg"]
    model = gaussiancity.generator.Generator(
        cfg.NETWORK.GAUSSIAN,
        n_classes=cfg.DATASETS.GOOGLE_EARTH.N_CLASSES,
        proj_size=cfg.DATASETS.GOOGLE_EARTH.PROJ_SIZE,
    )
    model = torch.nn.DataParallel(model).cuda().eval()
    # strict=False silently ignores missing/unexpected keys.
    # NOTE(review): confirm the checkpoint is expected not to match exactly.
    model.load_state_dict(ckpt["gaussian_g"], strict=False)
    return model
def _dump_pickle_atomic(obj, path):
    """Pickle *obj* to *path* through a temp file so a partial file is never left behind."""
    tmp_path = path + ".tmp"
    try:
        with open(tmp_path, "wb") as fp:
            pickle.dump(obj, fp)
        os.replace(tmp_path, path)
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def get_city_layout():
    """Build, or load from cache, the New York City layout used for generation.

    The layout dict is cached in ``assets/NYC.pkl`` and its instance centers
    in ``assets/CENTERS.pkl``. Each is rebuilt from the PNG assets and
    re-cached when its file is missing.

    Returns:
        dict with keys ``"TD_HF"`` (top-down height field; clipped to 500 when
        built here), ``"BU_HF"`` (all-zero array shaped like ``"TD_HF"``),
        ``"SEG"`` (segmentation map), ``"INS"`` (instance map), ``"PTS"``
        (point map) and ``"CTR"`` (instance centers).
    """
    import gaussiancity.inference

    layout_path = "assets/NYC.pkl"
    centers_path = "assets/CENTERS.pkl"

    if os.path.exists(layout_path):
        # pickle.load can execute arbitrary code; this cache must be trusted.
        with open(layout_path, "rb") as fp:
            layout = pickle.load(fp)
    else:
        with Image.open("assets/NYC-HghtFld.png") as img:
            td_hf = np.array(img).astype(np.int32)
        # Clip heights to 500.
        # Fix: nonzero is not supported for tensors with more than INT_MAX elements
        td_hf[td_hf > 500] = 500
        bu_hf = np.zeros_like(td_hf)
        with Image.open("assets/NYC-SegMap.png") as img:
            seg_map = np.array(img.convert("P")).astype(np.int32)
        # A copy is passed here, presumably because get_instance_seg_map
        # modifies its input in place; confirm.
        ins_map = gaussiancity.inference.get_instance_seg_map(seg_map.copy())
        pts_map = gaussiancity.inference.get_point_map(seg_map)
        layout = {
            "TD_HF": td_hf,
            "BU_HF": bu_hf,
            "SEG": seg_map,
            "INS": ins_map,
            "PTS": pts_map,
        }
        _dump_pickle_atomic(layout, layout_path)

    if os.path.exists(centers_path):
        with open(centers_path, "rb") as fp:
            centers = pickle.load(fp)
    else:
        centers = gaussiancity.inference.get_centers(layout["INS"], layout["TD_HF"])
        _dump_pickle_atomic(centers, centers_path)

    layout["CTR"] = centers
    return layout
logging.basicConfig(
    level=logging.INFO,
    format="[%(levelname)s] %(asctime)s %(message)s",
)

# The foreground/background generators and the city layout are loaded once, at
# import time, and then read as globals by get_generated_city on every request.
logging.info("Loading pretrained models...")
fgm, bgm = (
    get_models("GaussianCity-Fgnd.pth"),
    get_models("GaussianCity-Bgnd.pth"),
)

logging.info("Loading New York city layout to RAM...")
city_layout = get_city_layout()
@spaces.GPU
def get_generated_city(radius, altitude, azimuth, map_center):
    """Render one view of the generated city for the given camera settings.

    This is the Gradio handler. Its arguments come, in order, from the four
    sliders defined in ``main``. It returns whatever
    ``gaussiancity.inference.generate_city`` returns, which the interface
    shows through a ``gr.Image(type="numpy")`` output.
    """
    import gaussiancity.inference

    # ``spaces.GPU`` requests a GPU for the duration of this call on ZeroGPU,
    # so both generators are placed on CUDA here.
    foreground = fgm.to("cuda")
    background = bgm.to("cuda")

    # The single slider value fills both of the two center arguments
    # (presumably the x/y map-center coordinates; confirm against
    # generate_city).
    return gaussiancity.inference.generate_city(
        foreground,
        background,
        city_layout,
        map_center,
        map_center,
        radius,
        altitude,
        azimuth,
    )
def _read_text(path):
    """Return the contents of the UTF-8 text file *path*."""
    # Explicit encoding: the default is locale-dependent and can fail on
    # non-ASCII characters when the locale is not UTF-8.
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


def _strip_front_matter(markdown):
    """Return the text after the last ``---`` in *markdown*, or all of it if absent."""
    marker = "---"
    idx = markdown.rfind(marker)
    if idx == -1:
        # rfind returns -1 when not found. Slicing from -1 + 3 would silently
        # drop the first two characters.
        return markdown
    return markdown[idx + len(marker) :]


def main():
    """Build and launch the Gradio demo.

    The interface description is everything after the last ``---`` in
    ``README.md``, presumably the end of the Space's front matter. The
    article is ``ARTICLE.md``. Both paths are relative to the current working
    directory.
    """
    title = "Generative Gaussian Splatting for Unbounded 3D City Generation"
    desc = _strip_front_matter(_read_text("README.md"))
    arti = _read_text("ARTICLE.md")
    app = gr.Interface(
        get_generated_city,
        [
            gr.Slider(256, 960, value=768, step=4, label="Camera Radius (m)"),
            gr.Slider(256, 960, value=768, step=4, label="Camera Altitude (m)"),
            gr.Slider(0, 360, value=210, step=5, label="Camera Azimuth (°)"),
            gr.Slider(1024, 7168, value=3570, step=4, label="Map Center (px)"),
        ],
        [gr.Image(type="numpy", label="Generated City")],
        title=title,
        description=desc,
        article=arti,
        flagging_mode="never",
    )
    app.queue()
    app.launch()


if __name__ == "__main__":
    main()