File size: 3,821 Bytes
83d5461
 
 
 
 
 
 
46b4a9e
 
 
83d5461
46b4a9e
83d5461
 
 
46b4a9e
83d5461
 
 
 
 
 
 
46b4a9e
83d5461
 
 
46b4a9e
83d5461
 
 
 
 
46b4a9e
83d5461
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46b4a9e
83d5461
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46b4a9e
 
 
 
 
 
 
 
 
 
83d5461
 
 
 
 
46b4a9e
 
 
83d5461
 
 
 
 
 
 
 
46b4a9e
83d5461
 
 
 
 
 
 
 
 
 
f424e40
 
 
83d5461
 
 
 
 
 
 
 
46b4a9e
 
83d5461
 
 
46b4a9e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import logging
import os
import pickle
import ssl
import sys
import urllib.request

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image

# NOTE(review): this swaps the factory used for default HTTPS contexts with the
# unverified one, so HTTPS requests made through the default context (e.g. the
# urllib download in get_models) skip certificate verification for the whole
# process. Presumably a workaround for certificate problems in the hosting
# environment -- confirm it is still required.
ssl._create_default_https_context = ssl._create_unverified_context

# Append the "gaussiancity" directory that sits next to this file to the module
# search path.
sys.path.append(os.path.join(os.path.dirname(__file__), "gaussiancity"))


def get_models(file_name):
    """Download (if missing) and load a pretrained GaussianCity generator.

    Args:
        file_name: Checkpoint file name. It is used both as the local path
            and as the file name in the ``hzxie/gaussian-city`` Hugging Face
            repository URL.

    Returns:
        The generator wrapped in ``torch.nn.DataParallel``, moved to CUDA and
        switched to eval mode, with the checkpoint's ``"gaussian_g"`` weights
        loaded.
    """
    import gaussiancity.generator

    if not os.path.exists(file_name):
        # Download to a temporary name and rename only on success, so an
        # interrupted transfer never leaves a truncated checkpoint behind
        # that would pass the existence check on the next start.
        partial_name = f"{file_name}.part"
        try:
            urllib.request.urlretrieve(
                f"https://huggingface.co/hzxie/gaussian-city/resolve/main/{file_name}",
                partial_name,
            )
            os.replace(partial_name, file_name)
        finally:
            # After a successful rename the partial file no longer exists;
            # after a failure this removes the incomplete download.
            if os.path.exists(partial_name):
                os.remove(partial_name)

    # NOTE(security): weights_only=False unpickles arbitrary Python objects
    # from the checkpoint (it also carries a "cfg" object), so the file must
    # come from a trusted source.
    ckpt = torch.load(file_name, map_location="cpu", weights_only=False)
    model = gaussiancity.generator.Generator(
        ckpt["cfg"].NETWORK.GAUSSIAN,
        n_classes=ckpt["cfg"].DATASETS.GOOGLE_EARTH.N_CLASSES,
        proj_size=ckpt["cfg"].DATASETS.GOOGLE_EARTH.PROJ_SIZE,
    )
    model = torch.nn.DataParallel(model).cuda().eval()
    # strict=False: keys that do not match between checkpoint and model are
    # tolerated instead of raising.
    model.load_state_dict(ckpt["gaussian_g"], strict=False)
    return model


def get_city_layout():
    """Load the New York city layout, building and caching it when needed.

    The layout dictionary is read from ``assets/NYC.pkl`` when that file
    exists; otherwise it is built from the height-field and segmentation PNGs
    in ``assets/`` and written to that pickle. The centers are handled the
    same way through ``assets/CENTERS.pkl``.

    Returns:
        A dict with the keys ``"TD_HF"``, ``"BU_HF"``, ``"SEG"``, ``"INS"``,
        ``"PTS"`` and ``"CTR"``.
    """
    import gaussiancity.inference

    inference = gaussiancity.inference
    layout_cache = "assets/NYC.pkl"
    centers_cache = "assets/CENTERS.pkl"

    if not os.path.exists(layout_cache):
        top_down = np.array(Image.open("assets/NYC-HghtFld.png")).astype(np.int32)
        # Cap heights at 500 in place. Per the original note, this works
        # around "nonzero is not supported for tensors with more than
        # INT_MAX elements".
        np.minimum(top_down, 500, out=top_down)
        bottom_up = np.zeros_like(top_down)
        segmentation = np.array(
            Image.open("assets/NYC-SegMap.png").convert("P")
        ).astype(np.int32)
        # The instance-map helper receives a copy of the segmentation map;
        # the point-map helper receives the original array.
        instances = inference.get_instance_seg_map(segmentation.copy())
        points = inference.get_point_map(segmentation)
        layout = {
            "TD_HF": top_down,
            "BU_HF": bottom_up,
            "SEG": segmentation,
            "INS": instances,
            "PTS": points,
        }
        with open(layout_cache, "wb") as cache_file:
            pickle.dump(layout, cache_file)
    else:
        with open(layout_cache, "rb") as cache_file:
            layout = pickle.load(cache_file)

    if not os.path.exists(centers_cache):
        centers = inference.get_centers(layout["INS"], layout["TD_HF"])
        with open(centers_cache, "wb") as cache_file:
            pickle.dump(centers, cache_file)
    else:
        with open(centers_cache, "rb") as cache_file:
            centers = pickle.load(cache_file)

    layout["CTR"] = centers
    return layout


# Startup work performed at import time: configure logging, then load both
# generators and the city layout into module-level globals that
# get_generated_city reads on every request.
logging.basicConfig(format="[%(levelname)s] %(asctime)s %(message)s", level=logging.INFO)

logging.info("Loading pretrained models...")
# fgm / bgm: generators loaded from the "Fgnd" and "Bgnd" checkpoints.
fgm = get_models("GaussianCity-Fgnd.pth")
bgm = get_models("GaussianCity-Bgnd.pth")

logging.info("Loading New York city layout to RAM...")
# city_layout: layout dictionary, read from the pickle caches or rebuilt.
city_layout = get_city_layout()


@spaces.GPU
def get_generated_city(radius, altitude, azimuth, map_center):
    """Generate a city for the given camera settings and map center.

    Args:
        radius: Value of the "Camera Radius (m)" slider.
        altitude: Value of the "Camera Altitude (m)" slider.
        azimuth: Value of the "Camera Azimuth (°)" slider.
        map_center: Value of the "Map Center (px)" slider.

    Returns:
        Whatever ``gaussiancity.inference.generate_city`` returns; the UI
        shows it in a numpy image component.
    """
    import gaussiancity.inference

    foreground_model = fgm.to("cuda")
    background_model = bgm.to("cuda")
    # map_center fills both positional center arguments -- presumably the x
    # and y coordinates; confirm against generate_city.
    return gaussiancity.inference.generate_city(
        foreground_model,
        background_model,
        city_layout,
        map_center,
        map_center,
        radius,
        altitude,
        azimuth,
    )


def main():
    """Build the Gradio interface and launch the app.

    The description is the part of ``README.md`` after the last ``---``
    marker, or the whole file when no marker is present. The article text is
    the content of ``ARTICLE.md``.
    """
    title = "Generative Gaussian Splatting for Unbounded 3D City Generation"
    # Read both files as UTF-8 explicitly so decoding does not depend on the
    # platform's locale encoding.
    with open("README.md", "r", encoding="utf-8") as f:
        markdown = f.read()
    marker_pos = markdown.rfind("---")
    # Without a marker, rfind returns -1 and the old slice silently dropped
    # the first two characters; fall back to the full text instead.
    desc = markdown[marker_pos + 3 :] if marker_pos != -1 else markdown
    with open("ARTICLE.md", "r", encoding="utf-8") as f:
        arti = f.read()

    app = gr.Interface(
        get_generated_city,
        [
            gr.Slider(256, 960, value=768, step=4, label="Camera Radius (m)"),
            gr.Slider(256, 960, value=768, step=4, label="Camera Altitude (m)"),
            gr.Slider(0, 360, value=210, step=5, label="Camera Azimuth (°)"),
            gr.Slider(1024, 7168, value=3570, step=4, label="Map Center (px)"),
        ],
        [gr.Image(type="numpy", label="Generated City")],
        title=title,
        description=desc,
        article=arti,
        flagging_mode="never",
    )
    app.queue()
    app.launch()


# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    main()