| import os |
| import cv2 |
| import gradio as gr |
| import numpy as np |
| import random |
| import base64 |
| import requests |
| import json |
| import time |
|
|
# Upper bound (inclusive) for randomly drawn seeds and for the seed slider.
MAX_SEED = 999999

# Bundled example assets live in an "assets" folder next to this script.
example_path = os.path.join(os.path.dirname(__file__), 'assets')


def _scan_examples(subdir):
    """Return ``(sorted file names, matching full paths)`` for ``assets/<subdir>``."""
    folder = os.path.join(example_path, subdir)
    names = sorted(os.listdir(folder))
    return names, [os.path.join(folder, name) for name in names]


# Reference (style) images; the bare file names are also used to look up prompts.
refer_list, refer_list_path = _scan_examples("refer_imgs")

# Example person / identity photos.
human_list, human_list_path = _scan_examples("id_imgs")

# Prompts for the reference images, matched to refer_list by position
# (prompt_list[i] belongs to refer_list[i]).
prompt_list = [
    # ... put your prompt_list here ...
]

# Example prompts offered for the free-text prompt box.
prompt_list_single = [
    'hold a sign writing "Kolors Portrait with Flux"',
    "A beautiful girl reading book, high quality.",
    "Full body shot of young Asian face woman in sun hat and white dress standing on sunny beach with sea and mountains in background, high quality, sharp focus,",
]
| |
def imgpath_to_prompt(img_path):
    """Return the bundled prompt for a reference example image, or None.

    ``prompt_list`` is matched to ``refer_list`` by position, so the image's
    file name is located in ``refer_list`` and the prompt at the same index is
    returned.

    Returns None when no path is given, when the file is not one of the bundled
    reference examples, or when ``prompt_list`` has no entry at that position
    (for example while it is still the empty placeholder). In those cases the
    caller keeps the user's own prompt.
    """
    if not img_path:
        return None
    img_name = os.path.basename(img_path)
    try:
        prompt_id = refer_list.index(img_name)
    except ValueError:
        return None
    # prompt_list may be shorter than refer_list. Indexing past its end used to
    # raise IndexError and abort the whole generation request.
    if prompt_id >= len(prompt_list):
        return None
    return prompt_list[prompt_id]
| |
def _encode_jpeg_b64(bgr_img):
    """Encode a BGR image as JPEG and return it as a base64 ASCII string."""
    ok, buf = cv2.imencode('.jpg', bgr_img)
    if not ok:
        raise gr.Error("Failed to encode image as JPEG")
    return base64.b64encode(buf.tobytes()).decode('utf-8')


def _load_template(template_img, prompt):
    """Resolve the optional reference image.

    Returns ``(encoded_template_or_None, prompt)``. When the reference image is
    a bundled example with a known prompt, that prompt replaces the user's one.
    """
    if not template_img:
        return None, prompt
    example_prompt = imgpath_to_prompt(template_img)
    if example_prompt is not None:
        prompt = example_prompt
    template_bgr = cv2.imread(template_img)
    if template_bgr is None:
        # Unreadable file: send no template, same as the original behaviour.
        return None, prompt
    return _encode_jpeg_b64(template_bgr), prompt


def _submit_task(base_url, headers, payload):
    """POST the generation request and return the task id reported by the server."""
    try:
        response = requests.post(
            base_url + "Submit", headers=headers, data=json.dumps(payload), timeout=50
        )
    except requests.RequestException as e:
        raise gr.Error(f"Post Exception: {e}") from e
    if response.status_code != 200:
        raise gr.Error(f"Submit request failed with status code {response.status_code}")
    try:
        body = response.json()['result']
        status = body['status']
        if status == "success":
            return body['result']
    except (ValueError, KeyError, TypeError) as e:
        raise gr.Error(f"Post Exception: {e}") from e
    # Raised outside the try so the message is not re-wrapped as "Post Exception".
    raise gr.Error(f"Submit error status: {status}")


def _query_task(base_url, headers, task_id):
    """Issue one status query and return ``(status, payload)``.

    ``payload`` is the base64 result only when ``status == "success"``.
    A non-200 HTTP reply is final and raises gr.Error, as before.
    """
    response = requests.get(
        base_url + "Query", params={"taskId": task_id}, headers=headers, timeout=20
    )
    if response.status_code != 200:
        raise gr.Error(f"Query failed with status {response.status_code}")
    body = response.json()['result']
    status = body['status']
    return status, (body['result'] if status == "success" else None)


def _decode_result_image(result_b64):
    """Decode the server's base64 image payload into an RGB numpy array."""
    try:
        raw = base64.b64decode(result_b64)
    except (ValueError, TypeError) as e:
        raise gr.Error(f"Failed to decode result image: {e}") from e
    if not raw:
        raise gr.Error("Failed to decode result image: empty payload")
    # IMREAD_COLOR always yields 3-channel BGR, so the conversion below cannot
    # fail on grayscale results (IMREAD_UNCHANGED could return 2-D arrays).
    bgr = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
    if bgr is None:
        raise gr.Error("Failed to decode result image")
    # OpenCV decodes to BGR; Gradio displays RGB.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _poll_result(base_url, headers, task_id, max_retry=12, interval=1):
    """Poll the task until it succeeds, fails, or ``max_retry`` attempts pass.

    Returns the RGB result image. Raises gr.Error on a server-reported error,
    a non-200 query, or an undecodable image. If the task is still pending after
    the last attempt, it raises with the most recent transient error, or with a
    timeout message (previously this case raised an empty error message).
    """
    info = "Timeout, please try later"
    for _ in range(max_retry):
        try:
            status, payload = _query_task(base_url, headers, task_id)
        except requests.exceptions.ReadTimeout:
            info = "Timeout, please try later"
        except (requests.RequestException, ValueError, KeyError, TypeError) as e:
            info = f"Exception during query: {e}"
        else:
            if status == "success":
                return _decode_result_image(payload)
            if status == "error":
                raise gr.Error("Error from server")
            # Any other status: still running, so wait and retry.
        time.sleep(interval)
    raise gr.Error(info)


def portrait_gen(person_img, template_img, prompt, seed, randomize_seed):
    """Generate a stylised portrait through the remote avatar service.

    Args:
        person_img: RGB numpy image of the person (from ``gr.Image(type="numpy")``).
        template_img: optional file path of a style reference image ("" or None
            when unused).
        prompt: text prompt. It is overridden by the bundled prompt when the
            reference image is one of the bundled examples.
        seed: seed value from the slider. It is ignored when ``randomize_seed``
            is true.
        randomize_seed: draw a fresh seed in ``[0, MAX_SEED]`` when true.

    Returns:
        ``(result_img, seed_used, "Success")`` where ``result_img`` is an RGB
        numpy array.

    Raises:
        gr.Error: missing person image, submit failure, server-reported error,
            undecodable result, or polling exhausted.

    The service location and credentials come from the ``avatar_url``,
    ``token`` and ``referer`` environment variables.
    """
    if person_img is None:
        raise gr.Error("Empty person image")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # The slider may hand back a float; send and report an integer seed.
    seed = int(seed)

    encoded_person_img = _encode_jpeg_b64(cv2.cvtColor(person_img, cv2.COLOR_RGB2BGR))
    encoded_template_img, prompt = _load_template(template_img, prompt)

    base_url = "http://" + os.environ.get('avatar_url', '')
    headers = {
        'Content-Type': 'application/json',
        'token': os.environ.get('token', ''),
        'referer': os.environ.get('referer', ''),
    }
    payload = {
        "templateImage": encoded_template_img,
        "humanImage": encoded_person_img,
        "seed": seed,
        "prompt": prompt,
    }

    task_id = _submit_task(base_url, headers, payload)
    time.sleep(9)  # give the model time to generate before the first poll
    result_img = _poll_result(base_url, headers, task_id)
    return result_img, seed, "Success"
| |
# Custom CSS passed to gr.Blocks. The ids match the elem_id values of the
# layout columns ("col-left", "col-mid", "col-right") and the Run button
# ("button") created in the UI section. No component in this file uses
# "#col-showcase"; it may be left over from an earlier layout.
css = """
#col-left { max-width: 270px; margin: 0 auto;}
#col-mid { max-width: 270px; margin: 0 auto;}
#col-right { max-width: 550px; margin: 0 auto;}
#col-showcase { max-width: 1100px; margin: 0 auto;}
#button { color: blue;}
"""
|
|
def change_mode(mode: str):
    """Switch the middle column between prompt input and reference-image input.

    Returns a pair of updates for ``(prompt, refer_img)``. The component being
    hidden also has its value cleared, so a stale prompt or reference image is
    not passed to ``portrait_gen`` with the next run. An unknown mode raises
    ``gr.Error``.
    """
    if mode not in ("Input prompt", "Ref. image"):
        raise gr.Error("no such mode!")
    if mode == "Ref. image":
        return gr.update(visible=False, value=""), gr.update(visible=True)
    return gr.update(visible=True), gr.update(visible=False, value=None)
| |
# Browser-side half of the mode switch, attached via `js=` on the mode radio.
# It shows the prompt examples and hides the reference-image examples for
# "Input prompt", and does the reverse for any other mode. It returns [mode]
# unchanged.
# Every lookup is null-guarded: no component in this file sets elem_id
# "text_examples" or "image_examples". querySelector therefore returned null for
# those ids, and the old unguarded `.hidden = ...` threw a TypeError that broke
# the handler.
change_mode_js = """
function change_mode_js(mode){
    const showPrompt = mode === "Input prompt";
    const setHidden = (selector, hidden) => {
        const el = document.querySelector(selector);
        if (el) {
            el.hidden = hidden;
        }
    };
    setHidden("#text_examples", !showPrompt);
    setHidden("#prompt_examples", !showPrompt);
    setHidden("#image_examples", showPrompt);
    setHidden("#ref_examples", showPrompt);
    return [mode];
}
"""
| |
# UI layout: step headers, then three columns (person image / style reference
# or prompt / results), wired to portrait_gen. This file relies on Gradio 4 APIs
# (`js=` on listeners, `concurrency_limit=`). In Gradio 4, gr.Image takes
# `sources=[...]`; the Gradio 3 keyword `source=` is no longer accepted.
with gr.Blocks(css=css) as Portrait:
    gr.HTML("<h1>Kolors Portrait with Flux</h1>")
    with gr.Row():
        with gr.Column(elem_id="col-left"):
            gr.HTML("<h3>Step 1. Upload a portrait image ⬇️</h3>")
        with gr.Column(elem_id="col-mid"):
            gr.HTML("<h3>Step 2. Set a style reference ⬇️</h3>")
        with gr.Column(elem_id="col-right"):
            gr.HTML("<h3>Step 3. Press “Run” to get results ⬇️</h3>")

    with gr.Row():
        with gr.Column(elem_id="col-left"):
            imgs = gr.Image(label="Person Images", sources=["upload"], type="numpy", width=400)
            gr.Examples(
                examples=human_list_path,
                inputs=imgs,
                label="Person Images Examples",
                examples_per_page=10,
            )
        with gr.Column(elem_id="col-mid"):
            mode = gr.Radio(
                label="Set a reference image or input prompt",
                choices=["Ref. image", "Input prompt"],
                value="Input prompt",
            )
            # NOTE(review): refer_img starts visible although the default mode is
            # "Input prompt"; change_mode only runs on change. Confirm intended.
            refer_img = gr.Image(label="Reference images", sources=["upload"], type="filepath")
            gr.Examples(
                examples=refer_list_path,
                inputs=refer_img,
                label='Reference images Examples',
                elem_id="ref_examples",
                examples_per_page=9,
            )
            prompt = gr.Textbox(
                label="Input prompts",
                placeholder="Enter your prompt",
                max_lines=1,
                visible=True,
            )
            gr.Examples(
                examples=prompt_list_single,
                inputs=prompt,
                label='Prompt Examples',
                elem_id="prompt_examples",
            )
            # The JS toggles the example galleries in the browser; change_mode
            # toggles and clears the prompt / reference-image inputs.
            mode.change(
                fn=change_mode,
                inputs=mode,
                outputs=[prompt, refer_img],
                js=change_mode_js,
            )
        with gr.Column(elem_id="col-right"):
            image_out = gr.Image(label="Results", show_share_button=False)
            with gr.Row():
                seed = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=0, label="Seed")
                randomize_seed = gr.Checkbox(label="Random seed", value=True)
            with gr.Row():
                seed_used = gr.Number(label="Seed used", interactive=False)
                result_info = gr.Textbox(label="Response", interactive=False)
            test_button = gr.Button(value="Run", elem_id="button")

    test_button.click(
        fn=portrait_gen,
        inputs=[imgs, refer_img, prompt, seed, randomize_seed],
        outputs=[image_out, seed_used, result_info],
        api_name=False,
        concurrency_limit=10,
    )

Portrait.queue(api_open=False).launch(show_api=False)
|
|