import torch import torch.nn.functional as F import os from transformers import BertJapaneseTokenizer # 年代モデルと性別モデルの定義をインポート from SupervisedLearning import BertForAgeClassification, PRE_TRAINED_MODEL_NAME, DEVICE, NUM_AGE_CLASSIFIERS, AGE_CATEGORIES from GenderLearning import BertForGenderClassification, NUM_GENDER_LABELS # モデルファイルのパス AGE_MODEL_PATH = 'bert_age_model.bin' GENDER_MODEL_PATH = 'bert_gender_model.bin' # 性別のカテゴリマッピング GENDER_CATEGORIES = ["male", "female"] GENDER_CATEGORIES_JP = ["男性", "女性"] # --- グローバル変数としてモデルとトークナイザを一度だけロード --- TOKENIZER = None AGE_MODEL = None GENDER_MODEL = None def load_models(): """アプリケーション起動時にモデルを一度だけ読み込む""" global TOKENIZER, AGE_MODEL, GENDER_MODEL # モデルファイルの存在確認 if not os.path.exists(AGE_MODEL_PATH): raise FileNotFoundError(f"エラー: 年代学習済みモデル '{AGE_MODEL_PATH}' が見つかりません。") # 性別モデルはまだ学習されていない可能性があるので、警告のみ表示 if not os.path.exists(GENDER_MODEL_PATH): print(f"警告: 性別学習済みモデル '{GENDER_MODEL_PATH}' が見つかりません。") print("性別予測は利用できません。年代予測のみ実行されます。") print("--- モデルの読み込みを開始します ---") TOKENIZER = BertJapaneseTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME) # 年代モデルの読み込み print(" 年代モデルを読み込み中...") AGE_MODEL = BertForAgeClassification(PRE_TRAINED_MODEL_NAME, NUM_AGE_CLASSIFIERS) try: if torch.__version__.startswith('1.'): AGE_MODEL.load_state_dict(torch.load(AGE_MODEL_PATH, map_location=DEVICE)) else: AGE_MODEL.load_state_dict(torch.load(AGE_MODEL_PATH, map_location=DEVICE, weights_only=True)) except Exception as e: print(f"年代モデルの読み込み中にエラーが発生しました: {e}") raise AGE_MODEL.to(DEVICE) AGE_MODEL.eval() # 性別モデルの読み込み(存在する場合のみ) if os.path.exists(GENDER_MODEL_PATH): print(" 性別モデルを読み込み中...") GENDER_MODEL = BertForGenderClassification(PRE_TRAINED_MODEL_NAME, NUM_GENDER_LABELS) try: if torch.__version__.startswith('1.'): GENDER_MODEL.load_state_dict(torch.load(GENDER_MODEL_PATH, map_location=DEVICE)) else: GENDER_MODEL.load_state_dict(torch.load(GENDER_MODEL_PATH, map_location=DEVICE, weights_only=True)) except Exception as e: print(f"性別モデルの読み込み中にエラーが発生しました: {e}") raise GENDER_MODEL.to(DEVICE) GENDER_MODEL.eval() else: GENDER_MODEL = None print("--- モデルの読み込みが完了しました ---") def predict_text(text: str): """ 入力されたテキストから「年代」と「性別」の各ラベルのパーセンテージを返す関数 """ if AGE_MODEL is None or TOKENIZER is None: load_models() print(f"DEBUG: 入力テキスト: '{text}'") # テキストの前処理 encoding = TOKENIZER.encode_plus( text, add_special_tokens=True, max_length=128, return_token_type_ids=False, padding='max_length', truncation=True, return_attention_mask=True, return_tensors='pt', ) input_ids = encoding['input_ids'].to(DEVICE) attention_mask = encoding['attention_mask'].to(DEVICE) print(f"DEBUG: input_ids shape: {input_ids.shape}") print(f"DEBUG: attention_mask shape: {attention_mask.shape}") # 年代の予測 with torch.no_grad(): _, age_logits = AGE_MODEL(input_ids=input_ids, attention_mask=attention_mask) print(f"DEBUG: age_logits shape: {age_logits.shape}") print(f"DEBUG: age_logits values: {age_logits}") # 各年代の二値分類の確率(シグモイド関数) age_probs = torch.sigmoid(age_logits)[0] # shape: (6,) print(f"DEBUG: age_probs shape: {age_probs.shape}") print(f"DEBUG: age_probs values: {age_probs}") # 年代の確率を辞書形式で保存 age_percentages = {} for i, age in enumerate(AGE_CATEGORIES): percentage = float(f"{age_probs[i].item() * 100:.2f}") # 小数第2位まで age_percentages[age] = percentage print(f"DEBUG: {age}: {age_probs[i].item()} -> {percentage}%") # 性別の予測(モデルが存在する場合のみ) if GENDER_MODEL is not None: with torch.no_grad(): _, gender_logits = GENDER_MODEL(input_ids=input_ids, attention_mask=attention_mask) print(f"DEBUG: gender_logits shape: {gender_logits.shape}") print(f"DEBUG: gender_logits values: {gender_logits}") # 性別の確率(Softmax関数) gender_probs = F.softmax(gender_logits, dim=1)[0] # shape: (2,) print(f"DEBUG: gender_probs shape: {gender_probs.shape}") print(f"DEBUG: gender_probs values: {gender_probs}") # 性別の確率を辞書形式で保存 gender_percentages = {} for i, gender_jp in enumerate(GENDER_CATEGORIES_JP): percentage = float(f"{gender_probs[i].item() * 100:.2f}") # 小数第2位まで gender_percentages[gender_jp] = percentage print(f"DEBUG: {gender_jp}: {gender_probs[i].item()} -> {percentage}%") else: # 性別モデルが存在しない場合はデフォルト値を設定 gender_percentages = {"男性": 50.0, "女性": 50.0} print("DEBUG: 性別モデルが存在しないため、デフォルト値を設定しました") # 結果を返す results = { "age_percentages": age_percentages, "gender_percentages": gender_percentages } print(f"DEBUG: 最終結果: {results}") return results