Spaces:
Build error
Build error
| import torch | |
| import torch.nn.functional as F | |
| import os | |
| from transformers import BertJapaneseTokenizer | |
| # 年代モデルと性別モデルの定義をインポート | |
| from SupervisedLearning import BertForAgeClassification, PRE_TRAINED_MODEL_NAME, DEVICE, NUM_AGE_CLASSIFIERS, AGE_CATEGORIES | |
| from GenderLearning import BertForGenderClassification, NUM_GENDER_LABELS | |
| # モデルファイルのパス | |
| AGE_MODEL_PATH = 'bert_age_model.bin' | |
| GENDER_MODEL_PATH = 'bert_gender_model.bin' | |
| # 性別のカテゴリマッピング | |
| GENDER_CATEGORIES = ["male", "female"] | |
| GENDER_CATEGORIES_JP = ["男性", "女性"] | |
| # --- グローバル変数としてモデルとトークナイザを一度だけロード --- | |
| TOKENIZER = None | |
| AGE_MODEL = None | |
| GENDER_MODEL = None | |
| def load_models(): | |
| """アプリケーション起動時にモデルを一度だけ読み込む""" | |
| global TOKENIZER, AGE_MODEL, GENDER_MODEL | |
| # モデルファイルの存在確認 | |
| if not os.path.exists(AGE_MODEL_PATH): | |
| raise FileNotFoundError(f"エラー: 年代学習済みモデル '{AGE_MODEL_PATH}' が見つかりません。") | |
| # 性別モデルはまだ学習されていない可能性があるので、警告のみ表示 | |
| if not os.path.exists(GENDER_MODEL_PATH): | |
| print(f"警告: 性別学習済みモデル '{GENDER_MODEL_PATH}' が見つかりません。") | |
| print("性別予測は利用できません。年代予測のみ実行されます。") | |
| print("--- モデルの読み込みを開始します ---") | |
| TOKENIZER = BertJapaneseTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME) | |
| # 年代モデルの読み込み | |
| print(" 年代モデルを読み込み中...") | |
| AGE_MODEL = BertForAgeClassification(PRE_TRAINED_MODEL_NAME, NUM_AGE_CLASSIFIERS) | |
| try: | |
| if torch.__version__.startswith('1.'): | |
| AGE_MODEL.load_state_dict(torch.load(AGE_MODEL_PATH, map_location=DEVICE)) | |
| else: | |
| AGE_MODEL.load_state_dict(torch.load(AGE_MODEL_PATH, map_location=DEVICE, weights_only=True)) | |
| except Exception as e: | |
| print(f"年代モデルの読み込み中にエラーが発生しました: {e}") | |
| raise | |
| AGE_MODEL.to(DEVICE) | |
| AGE_MODEL.eval() | |
| # 性別モデルの読み込み(存在する場合のみ) | |
| if os.path.exists(GENDER_MODEL_PATH): | |
| print(" 性別モデルを読み込み中...") | |
| GENDER_MODEL = BertForGenderClassification(PRE_TRAINED_MODEL_NAME, NUM_GENDER_LABELS) | |
| try: | |
| if torch.__version__.startswith('1.'): | |
| GENDER_MODEL.load_state_dict(torch.load(GENDER_MODEL_PATH, map_location=DEVICE)) | |
| else: | |
| GENDER_MODEL.load_state_dict(torch.load(GENDER_MODEL_PATH, map_location=DEVICE, weights_only=True)) | |
| except Exception as e: | |
| print(f"性別モデルの読み込み中にエラーが発生しました: {e}") | |
| raise | |
| GENDER_MODEL.to(DEVICE) | |
| GENDER_MODEL.eval() | |
| else: | |
| GENDER_MODEL = None | |
| print("--- モデルの読み込みが完了しました ---") | |
| def predict_text(text: str): | |
| """ | |
| 入力されたテキストから「年代」と「性別」の各ラベルのパーセンテージを返す関数 | |
| """ | |
| if AGE_MODEL is None or TOKENIZER is None: | |
| load_models() | |
| print(f"DEBUG: 入力テキスト: '{text}'") | |
| # テキストの前処理 | |
| encoding = TOKENIZER.encode_plus( | |
| text, | |
| add_special_tokens=True, | |
| max_length=128, | |
| return_token_type_ids=False, | |
| padding='max_length', | |
| truncation=True, | |
| return_attention_mask=True, | |
| return_tensors='pt', | |
| ) | |
| input_ids = encoding['input_ids'].to(DEVICE) | |
| attention_mask = encoding['attention_mask'].to(DEVICE) | |
| print(f"DEBUG: input_ids shape: {input_ids.shape}") | |
| print(f"DEBUG: attention_mask shape: {attention_mask.shape}") | |
| # 年代の予測 | |
| with torch.no_grad(): | |
| _, age_logits = AGE_MODEL(input_ids=input_ids, attention_mask=attention_mask) | |
| print(f"DEBUG: age_logits shape: {age_logits.shape}") | |
| print(f"DEBUG: age_logits values: {age_logits}") | |
| # 各年代の二値分類の確率(シグモイド関数) | |
| age_probs = torch.sigmoid(age_logits)[0] # shape: (6,) | |
| print(f"DEBUG: age_probs shape: {age_probs.shape}") | |
| print(f"DEBUG: age_probs values: {age_probs}") | |
| # 年代の確率を辞書形式で保存 | |
| age_percentages = {} | |
| for i, age in enumerate(AGE_CATEGORIES): | |
| percentage = float(f"{age_probs[i].item() * 100:.2f}") # 小数第2位まで | |
| age_percentages[age] = percentage | |
| print(f"DEBUG: {age}: {age_probs[i].item()} -> {percentage}%") | |
| # 性別の予測(モデルが存在する場合のみ) | |
| if GENDER_MODEL is not None: | |
| with torch.no_grad(): | |
| _, gender_logits = GENDER_MODEL(input_ids=input_ids, attention_mask=attention_mask) | |
| print(f"DEBUG: gender_logits shape: {gender_logits.shape}") | |
| print(f"DEBUG: gender_logits values: {gender_logits}") | |
| # 性別の確率(Softmax関数) | |
| gender_probs = F.softmax(gender_logits, dim=1)[0] # shape: (2,) | |
| print(f"DEBUG: gender_probs shape: {gender_probs.shape}") | |
| print(f"DEBUG: gender_probs values: {gender_probs}") | |
| # 性別の確率を辞書形式で保存 | |
| gender_percentages = {} | |
| for i, gender_jp in enumerate(GENDER_CATEGORIES_JP): | |
| percentage = float(f"{gender_probs[i].item() * 100:.2f}") # 小数第2位まで | |
| gender_percentages[gender_jp] = percentage | |
| print(f"DEBUG: {gender_jp}: {gender_probs[i].item()} -> {percentage}%") | |
| else: | |
| # 性別モデルが存在しない場合はデフォルト値を設定 | |
| gender_percentages = {"男性": 50.0, "女性": 50.0} | |
| print("DEBUG: 性別モデルが存在しないため、デフォルト値を設定しました") | |
| # 結果を返す | |
| results = { | |
| "age_percentages": age_percentages, | |
| "gender_percentages": gender_percentages | |
| } | |
| print(f"DEBUG: 最終結果: {results}") | |
| return results | |