age-gender-app / predictor.py
vola2004's picture
Upload 3 files
efbc6bc verified
Raw
History Blame
6.21 kB
import torch
import torch.nn.functional as F
import os
from transformers import BertJapaneseTokenizer
# 年代モデルと性別モデルの定義をインポート
from SupervisedLearning import BertForAgeClassification, PRE_TRAINED_MODEL_NAME, DEVICE, NUM_AGE_CLASSIFIERS, AGE_CATEGORIES
from GenderLearning import BertForGenderClassification, NUM_GENDER_LABELS
# モデルファイルのパス
AGE_MODEL_PATH = 'bert_age_model.bin'
GENDER_MODEL_PATH = 'bert_gender_model.bin'
# 性別のカテゴリマッピング
GENDER_CATEGORIES = ["male", "female"]
GENDER_CATEGORIES_JP = ["男性", "女性"]
# --- グローバル変数としてモデルとトークナイザを一度だけロード ---
TOKENIZER = None
AGE_MODEL = None
GENDER_MODEL = None
def load_models():
"""アプリケーション起動時にモデルを一度だけ読み込む"""
global TOKENIZER, AGE_MODEL, GENDER_MODEL
# モデルファイルの存在確認
if not os.path.exists(AGE_MODEL_PATH):
raise FileNotFoundError(f"エラー: 年代学習済みモデル '{AGE_MODEL_PATH}' が見つかりません。")
# 性別モデルはまだ学習されていない可能性があるので、警告のみ表示
if not os.path.exists(GENDER_MODEL_PATH):
print(f"警告: 性別学習済みモデル '{GENDER_MODEL_PATH}' が見つかりません。")
print("性別予測は利用できません。年代予測のみ実行されます。")
print("--- モデルの読み込みを開始します ---")
TOKENIZER = BertJapaneseTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
# 年代モデルの読み込み
print(" 年代モデルを読み込み中...")
AGE_MODEL = BertForAgeClassification(PRE_TRAINED_MODEL_NAME, NUM_AGE_CLASSIFIERS)
try:
if torch.__version__.startswith('1.'):
AGE_MODEL.load_state_dict(torch.load(AGE_MODEL_PATH, map_location=DEVICE))
else:
AGE_MODEL.load_state_dict(torch.load(AGE_MODEL_PATH, map_location=DEVICE, weights_only=True))
except Exception as e:
print(f"年代モデルの読み込み中にエラーが発生しました: {e}")
raise
AGE_MODEL.to(DEVICE)
AGE_MODEL.eval()
# 性別モデルの読み込み(存在する場合のみ)
if os.path.exists(GENDER_MODEL_PATH):
print(" 性別モデルを読み込み中...")
GENDER_MODEL = BertForGenderClassification(PRE_TRAINED_MODEL_NAME, NUM_GENDER_LABELS)
try:
if torch.__version__.startswith('1.'):
GENDER_MODEL.load_state_dict(torch.load(GENDER_MODEL_PATH, map_location=DEVICE))
else:
GENDER_MODEL.load_state_dict(torch.load(GENDER_MODEL_PATH, map_location=DEVICE, weights_only=True))
except Exception as e:
print(f"性別モデルの読み込み中にエラーが発生しました: {e}")
raise
GENDER_MODEL.to(DEVICE)
GENDER_MODEL.eval()
else:
GENDER_MODEL = None
print("--- モデルの読み込みが完了しました ---")
def predict_text(text: str):
"""
入力されたテキストから「年代」と「性別」の各ラベルのパーセンテージを返す関数
"""
if AGE_MODEL is None or TOKENIZER is None:
load_models()
print(f"DEBUG: 入力テキスト: '{text}'")
# テキストの前処理
encoding = TOKENIZER.encode_plus(
text,
add_special_tokens=True,
max_length=128,
return_token_type_ids=False,
padding='max_length',
truncation=True,
return_attention_mask=True,
return_tensors='pt',
)
input_ids = encoding['input_ids'].to(DEVICE)
attention_mask = encoding['attention_mask'].to(DEVICE)
print(f"DEBUG: input_ids shape: {input_ids.shape}")
print(f"DEBUG: attention_mask shape: {attention_mask.shape}")
# 年代の予測
with torch.no_grad():
_, age_logits = AGE_MODEL(input_ids=input_ids, attention_mask=attention_mask)
print(f"DEBUG: age_logits shape: {age_logits.shape}")
print(f"DEBUG: age_logits values: {age_logits}")
# 各年代の二値分類の確率(シグモイド関数)
age_probs = torch.sigmoid(age_logits)[0] # shape: (6,)
print(f"DEBUG: age_probs shape: {age_probs.shape}")
print(f"DEBUG: age_probs values: {age_probs}")
# 年代の確率を辞書形式で保存
age_percentages = {}
for i, age in enumerate(AGE_CATEGORIES):
percentage = float(f"{age_probs[i].item() * 100:.2f}") # 小数第2位まで
age_percentages[age] = percentage
print(f"DEBUG: {age}: {age_probs[i].item()} -> {percentage}%")
# 性別の予測(モデルが存在する場合のみ)
if GENDER_MODEL is not None:
with torch.no_grad():
_, gender_logits = GENDER_MODEL(input_ids=input_ids, attention_mask=attention_mask)
print(f"DEBUG: gender_logits shape: {gender_logits.shape}")
print(f"DEBUG: gender_logits values: {gender_logits}")
# 性別の確率(Softmax関数)
gender_probs = F.softmax(gender_logits, dim=1)[0] # shape: (2,)
print(f"DEBUG: gender_probs shape: {gender_probs.shape}")
print(f"DEBUG: gender_probs values: {gender_probs}")
# 性別の確率を辞書形式で保存
gender_percentages = {}
for i, gender_jp in enumerate(GENDER_CATEGORIES_JP):
percentage = float(f"{gender_probs[i].item() * 100:.2f}") # 小数第2位まで
gender_percentages[gender_jp] = percentage
print(f"DEBUG: {gender_jp}: {gender_probs[i].item()} -> {percentage}%")
else:
# 性別モデルが存在しない場合はデフォルト値を設定
gender_percentages = {"男性": 50.0, "女性": 50.0}
print("DEBUG: 性別モデルが存在しないため、デフォルト値を設定しました")
# 結果を返す
results = {
"age_percentages": age_percentages,
"gender_percentages": gender_percentages
}
print(f"DEBUG: 最終結果: {results}")
return results