age-gender-app / app.py
vola2004's picture
Upload 3 files
efbc6bc verified
Raw
History Blame
2.81 kB
import os
import sys
import pkg_resources
# 必要なライブラリをインストール
required_packages = ['Flask', 'torch', 'transformers', 'pandas', 'scikit-learn', 'tqdm', 'matplotlib', 'numpy']
for package in required_packages:
try:
pkg_resources.require(package)
except pkg_resources.DistributionNotFound:
os.system(f'pip install {package}')
from flask import Flask, render_template, request, jsonify
from predictor import load_models, predict_text
import sys
app = Flask(__name__)
# アプリケーション起動時にモデルを読み込み
print("Checking library versions...")
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print("Library check complete.")
# モデルを読み込み
load_models()
@app.route('/')
def index():
return render_template('index.html')
@app.route('/', methods=['POST'])
def predict():
input_text = request.form.get('text', '')
print("=== 推測ボタンが押されました ===", flush=True)
print(f"入力されたテキスト: '{input_text}'", flush=True)
print(f"テキストの長さ: {len(input_text)} 文字", flush=True)
print("================================", flush=True)
sys.stdout.flush() # 強制的にフラッシュ
if not input_text.strip():
print("⚠️ 空のテキストが入力されました", flush=True)
return render_template('index.html',
error="テキストを入力してください。",
input_text=input_text)
try:
# 予測を実行
result = predict_text(input_text)
print("--- 予測結果 ---")
print(f"入力文: {input_text}")
print("年代の確率:")
for age, percentage in result['age_percentages'].items():
print(f" {age}: {percentage}%")
print("性別の確率:")
for gender, percentage in result['gender_percentages'].items():
print(f" {gender}: {percentage}%")
print("-----------------")
return render_template('index.html',
result=result,
input_text=input_text)
except Exception as e:
print(f"予測中にエラーが発生しました: {e}", flush=True)
import traceback
traceback.print_exc()
return render_template('index.html',
error=f"予測中にエラーが発生しました: {str(e)}",
input_text=input_text)
if __name__ == '__main__':
# Hugging Face Spaces用の設定
port = int(os.environ.get('PORT', 5000))
debug = os.environ.get('FLASK_ENV') == 'development'
app.run(host='0.0.0.0', port=port, debug=debug)