xinjjj commited on
Commit
238c93c
·
1 Parent(s): d7cbf4a

feat(gpt): add GPT-5.4 support

Browse files
common.py CHANGED
@@ -682,7 +682,9 @@ def generate_texture_mvimages(
682
  if use_ip_adapter and not PIPELINE_HAS_IP_ADAPTER:
683
  logger.info("Load IP adapter into default texture pipeline")
684
  if hasattr(PIPELINE.unet, "encoder_hid_proj"):
685
- PIPELINE.unet.text_encoder_hid_proj = PIPELINE.unet.encoder_hid_proj
 
 
686
  PIPELINE.load_ip_adapter(
687
  "./weights/Kolors-IP-Adapter-Plus",
688
  subfolder="",
@@ -691,7 +693,9 @@ def generate_texture_mvimages(
691
  PIPELINE_HAS_IP_ADAPTER = True
692
 
693
  if PIPELINE_HAS_IP_ADAPTER:
694
- PIPELINE.set_ip_adapter_scale([ip_adapt_scale if use_ip_adapter else 0.0])
 
 
695
 
696
  try:
697
  img_save_paths = infer_pipe(
 
682
  if use_ip_adapter and not PIPELINE_HAS_IP_ADAPTER:
683
  logger.info("Load IP adapter into default texture pipeline")
684
  if hasattr(PIPELINE.unet, "encoder_hid_proj"):
685
+ PIPELINE.unet.text_encoder_hid_proj = (
686
+ PIPELINE.unet.encoder_hid_proj
687
+ )
688
  PIPELINE.load_ip_adapter(
689
  "./weights/Kolors-IP-Adapter-Plus",
690
  subfolder="",
 
693
  PIPELINE_HAS_IP_ADAPTER = True
694
 
695
  if PIPELINE_HAS_IP_ADAPTER:
696
+ PIPELINE.set_ip_adapter_scale(
697
+ [ip_adapt_scale if use_ip_adapter else 0.0]
698
+ )
699
 
700
  try:
701
  img_save_paths = infer_pipe(
embodied_gen/scripts/render_mv.py CHANGED
@@ -135,9 +135,7 @@ def infer_pipe(
135
  for attn_processor in pipeline.unet.attn_processors.values()
136
  )
137
  use_ip_adapter = (
138
- ip_adapt_scale > 0
139
- and ip_img_path is not None
140
- and len(ip_img_path) > 0
141
  )
142
  effective_ip_adapt_scale = ip_adapt_scale if use_ip_adapter else 0.0
143
  if use_ip_adapter:
 
135
  for attn_processor in pipeline.unet.attn_processors.values()
136
  )
137
  use_ip_adapter = (
138
+ ip_adapt_scale > 0 and ip_img_path is not None and len(ip_img_path) > 0
 
 
139
  )
140
  effective_ip_adapt_scale = ip_adapt_scale if use_ip_adapter else 0.0
141
  if use_ip_adapter:
embodied_gen/utils/gpt_clients.py CHANGED
@@ -46,6 +46,9 @@ __all__ = [
46
  _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
47
  CONFIG_FILE = os.path.join(_CURRENT_DIR, "gpt_config.yaml")
48
  DEFAULT_GPT_TIMEOUT = float(os.environ.get("GPT_TIMEOUT", 120))
 
 
 
49
 
50
 
51
  def combine_images_to_grid(
@@ -148,6 +151,11 @@ class GPTclient:
148
 
149
  logger.info(f"Using GPT model: {self.model_name}.")
150
 
 
 
 
 
 
151
  @retry(
152
  retry=retry_if_not_exception_type(openai.BadRequestError),
153
  wait=wait_random_exponential(min=1, max=10),
@@ -215,21 +223,49 @@ class GPTclient:
215
  }
216
  )
217
 
218
- payload = {
219
- "messages": [
220
- {"role": "system", "content": system_role},
221
- {"role": "user", "content": content_user},
222
- ],
223
- "temperature": 0.1,
224
- "max_tokens": 500,
225
- "top_p": 0.1,
226
- "frequency_penalty": 0,
227
- "presence_penalty": 0,
228
- "stop": None,
229
- "model": self.model_name,
230
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  if params:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  payload.update(params)
234
 
235
  response = None
@@ -253,15 +289,19 @@ class GPTclient:
253
  ConnectionError: If connection fails.
254
  """
255
  try:
256
- response = self.completion_with_backoff(
257
  messages=[
258
  {"role": "system", "content": "You are a test system."},
259
  {"role": "user", "content": "Hello"},
260
  ],
261
  model=self.model_name,
262
- temperature=0,
263
- max_tokens=100,
264
  )
 
 
 
 
 
 
265
  response.choices[0].message.content
266
  logger.info("Connection check success.")
267
  except Exception:
 
46
  _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
47
  CONFIG_FILE = os.path.join(_CURRENT_DIR, "gpt_config.yaml")
48
  DEFAULT_GPT_TIMEOUT = float(os.environ.get("GPT_TIMEOUT", 120))
49
+ # GPT-5.x counts reasoning tokens against this cap, so it must be high
50
+ # enough to leave room for both reasoning and the visible reply.
51
+ GPT5_DEFAULT_MAX_COMPLETION_TOKENS = 8192
52
 
53
 
54
  def combine_images_to_grid(
 
151
 
152
  logger.info(f"Using GPT model: {self.model_name}.")
153
 
154
+ @staticmethod
155
+ def _is_gpt5_model(model_name: str) -> bool:
156
+ name = (model_name or "").lower()
157
+ return "gpt-5" in name or "gpt5" in name
158
+
159
  @retry(
160
  retry=retry_if_not_exception_type(openai.BadRequestError),
161
  wait=wait_random_exponential(min=1, max=10),
 
223
  }
224
  )
225
 
226
+ is_gpt5 = self._is_gpt5_model(self.model_name)
227
+ if is_gpt5:
228
+ # GPT-5.x only supports default temperature/top_p and uses
229
+ # `max_completion_tokens` instead of `max_tokens`.
230
+ payload = {
231
+ "messages": [
232
+ {"role": "system", "content": system_role},
233
+ {"role": "user", "content": content_user},
234
+ ],
235
+ "max_completion_tokens": GPT5_DEFAULT_MAX_COMPLETION_TOKENS,
236
+ "model": self.model_name,
237
+ }
238
+ else:
239
+ payload = {
240
+ "messages": [
241
+ {"role": "system", "content": system_role},
242
+ {"role": "user", "content": content_user},
243
+ ],
244
+ "temperature": 0.1,
245
+ "max_tokens": 500,
246
+ "top_p": 0.1,
247
+ "frequency_penalty": 0,
248
+ "presence_penalty": 0,
249
+ "stop": None,
250
+ "model": self.model_name,
251
+ }
252
 
253
  if params:
254
+ params = dict(params)
255
+ if is_gpt5:
256
+ # GPT-5.x rejects custom temperature/top_p/penalty/stop and
257
+ # uses `max_completion_tokens` instead of `max_tokens`.
258
+ if "max_tokens" in params and "max_completion_tokens" not in params:
259
+ params["max_completion_tokens"] = params.pop("max_tokens")
260
+ for k in (
261
+ "temperature",
262
+ "top_p",
263
+ "frequency_penalty",
264
+ "presence_penalty",
265
+ "stop",
266
+ "max_tokens",
267
+ ):
268
+ params.pop(k, None)
269
  payload.update(params)
270
 
271
  response = None
 
289
  ConnectionError: If connection fails.
290
  """
291
  try:
292
+ probe_kwargs = dict(
293
  messages=[
294
  {"role": "system", "content": "You are a test system."},
295
  {"role": "user", "content": "Hello"},
296
  ],
297
  model=self.model_name,
 
 
298
  )
299
+ if self._is_gpt5_model(self.model_name):
300
+ probe_kwargs["max_completion_tokens"] = 100
301
+ else:
302
+ probe_kwargs["temperature"] = 0
303
+ probe_kwargs["max_tokens"] = 100
304
+ response = self.completion_with_backoff(**probe_kwargs)
305
  response.choices[0].message.content
306
  logger.info("Connection check success.")
307
  except Exception:
embodied_gen/utils/gpt_config.yaml CHANGED
@@ -1,5 +1,5 @@
1
  # config.yaml
2
- agent_type: "qwen2.5-vl" # gpt-4o or qwen2.5-vl
3
 
4
  gpt-4o:
5
  endpoint: https://xxx.openai.azure.com
@@ -7,6 +7,12 @@ gpt-4o:
7
  api_version: 2025-xx-xx
8
  model_name: yfb-gpt-4o
9
 
 
 
 
 
 
 
10
  qwen2.5-vl:
11
  endpoint: https://openrouter.ai/api/v1
12
  api_key: sk-or-v1-xxx
 
1
  # config.yaml
2
+ agent_type: "gpt-5.4" # gpt-4o, gpt-5.4 or qwen2.5-vl
3
 
4
  gpt-4o:
5
  endpoint: https://xxx.openai.azure.com
 
7
  api_version: 2025-xx-xx
8
  model_name: yfb-gpt-4o
9
 
10
+ gpt-5.4:
11
+ endpoint: https://yfb-openai-sweden.openai.azure.com/
12
+ api_key: xxx
13
+ api_version: 2024-12-01-preview
14
+ model_name: gpt-5.4
15
+
16
  qwen2.5-vl:
17
  endpoint: https://openrouter.ai/api/v1
18
  api_key: sk-or-v1-xxx