age-gender-app / app_docker.py
vola2004's picture
Upload 8 files
0dc3843 verified
Raw
History Blame
4.19 kB
import gradio as gr
import torch
import os
import gzip
import shutil
from predictor import load_models, predict_text
def extract_compressed_models():
"""圧縮されたモデルファイルを展開"""
try:
print("=== 圧縮モデルファイルの展開 ===")
# 圧縮ファイルを展開
compressed_files = [
("bert_age_model.bin.gz", "bert_age_model.bin"),
("bert_gender_model.bin.gz", "bert_gender_model.bin")
]
for compressed_file, extracted_file in compressed_files:
if os.path.exists(compressed_file) and not os.path.exists(extracted_file):
print(f"📦 {compressed_file} を展開中...")
with gzip.open(compressed_file, 'rb') as f_in:
with open(extracted_file, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
print(f"✅ {extracted_file} の展開が完了しました")
elif os.path.exists(extracted_file):
print(f"✅ {extracted_file} は既に存在します")
else:
print(f"❌ {compressed_file} が見つかりません")
return True
except Exception as e:
print(f"❌ 展開エラー: {e}")
return False
def predict_age_gender(text):
"""年代・性別予測関数"""
if not text.strip():
return "テキストを入力してください。", "", ""
try:
result = predict_text(text)
# 年代予測結果を整形
age_results = []
for age, percentage in result['age_percentages'].items():
age_results.append(f"{age}: {percentage}%")
age_text = "\n".join(age_results)
# 性別予測結果を整形
gender_results = []
for gender, percentage in result['gender_percentages'].items():
gender_results.append(f"{gender}: {percentage}%")
gender_text = "\n".join(gender_results)
# 最も高い確率の年代を特定
max_age = max(result['age_percentages'].items(), key=lambda x: x[1])
max_gender = max(result['gender_percentages'].items(), key=lambda x: x[1])
summary = f"推定結果: {max_age[0]} ({max_age[1]}%), {max_gender[0]} ({max_gender[1]}%)"
return summary, age_text, gender_text
except Exception as e:
return f"エラーが発生しました: {str(e)}", "", ""
# モデルの展開と読み込み
print("=== アプリケーション初期化 ===")
# 圧縮ファイルを展開
if extract_compressed_models():
# モデルの読み込み
print("=== モデル読み込み開始 ===")
try:
load_models()
print("✅ モデルの読み込みが完了しました")
except Exception as e:
print(f"❌ モデルの読み込みに失敗しました: {e}")
print("⚠️ モデルファイルが存在しない可能性があります")
else:
print("❌ モデルの展開に失敗しました")
# Gradioインターフェース
interface = gr.Interface(
fn=predict_age_gender,
inputs=gr.Textbox(
label="日本語テキストを入力してください",
placeholder="例: 今日はとても良い天気ですね。友達と一緒に散歩をしました。",
lines=3
),
outputs=[
gr.Textbox(label="推定結果サマリー"),
gr.Textbox(label="年代予測詳細"),
gr.Textbox(label="性別予測詳細")
],
title="🧠 年代・性別推定システム",
description="日本語テキストから年代と性別を推定するAIシステムです。",
examples=[
"今日はとても良い天気ですね。",
"友達と一緒に散歩をしました。",
"新しいスマートフォンを買いました。",
"仕事が忙しくて疲れました。"
],
theme=gr.themes.Soft()
)
# アプリケーション起動
if __name__ == "__main__":
interface.launch(
server_name="0.0.0.0",
server_port=7860,
share=False # Hugging Face SpacesではFalse
)