Spaces:
Runtime error
Runtime error
Upload 7 files
Browse files- DataNLP.py +85 -0
- GenderLearning.py +220 -0
- README.md +38 -7
- SupervisedLearning.py +404 -0
- app.py +76 -0
- predictor.py +149 -0
- requirements.txt +11 -0
DataNLP.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from sudachipy import tokenizer, dictionary
|
| 3 |
+
import neologdn
|
| 4 |
+
import os # osモジュールをインポート
|
| 5 |
+
|
| 6 |
+
# --- スクリプトのディレクトリを基準にパスを設定 ---
|
| 7 |
+
# このスクリプト自身の絶対パスを取得
|
| 8 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 9 |
+
# 作業ディレクトリをこのスクリプトがあるディレクトリに変更
|
| 10 |
+
os.chdir(script_dir)
|
| 11 |
+
|
| 12 |
+
def load_preprocessed_data():
|
| 13 |
+
"""
|
| 14 |
+
DataSet.xlsxを読み込み、前処理(欠損値除去、ラベルエンコーディング、形態素解析)を行い、
|
| 15 |
+
処理済みのDataFrameと元のデータ数を返します。
|
| 16 |
+
"""
|
| 17 |
+
# --- Step 1: データ読み込み ---
|
| 18 |
+
df = pd.read_excel("DataSet.xlsx")
|
| 19 |
+
initial_count = len(df)
|
| 20 |
+
|
| 21 |
+
# --- Step 2: 欠損除去 ---
|
| 22 |
+
df = df.dropna(subset=["コメント", "性別", "年代"]).reset_index(drop=True)
|
| 23 |
+
|
| 24 |
+
# --- Step 2.5: 表記揺れ正規化 ---
|
| 25 |
+
df["コメント"] = df["コメント"].astype(str).apply(neologdn.normalize)
|
| 26 |
+
|
| 27 |
+
# --- Step 3: 年代と性別のラベルを別々に作成 ---
|
| 28 |
+
df["年代性別"] = df["年代"] + " " + df["性別"]
|
| 29 |
+
|
| 30 |
+
# 各年代ごとに二値分類ラベルを作成(その年代かどうか)
|
| 31 |
+
age_categories = ["10代", "20代", "30代", "40代", "50代", "60代"]
|
| 32 |
+
for age in age_categories:
|
| 33 |
+
df[f"{age}_label"] = (df["年代"] == age).astype(int)
|
| 34 |
+
|
| 35 |
+
# 性別ラベルのマッピング
|
| 36 |
+
gender_categories = ["male", "female"]
|
| 37 |
+
gender_label_map = {cat: idx for idx, cat in enumerate(gender_categories)}
|
| 38 |
+
df["性別_label"] = df["性別"].map(gender_label_map)
|
| 39 |
+
|
| 40 |
+
# 統合ラベルも残す(後方互換性のため)
|
| 41 |
+
combined_categories = [
|
| 42 |
+
"10代 male", "10代 female",
|
| 43 |
+
"20代 male", "20代 female",
|
| 44 |
+
"30代 male", "30代 female",
|
| 45 |
+
"40代 male", "40代 female",
|
| 46 |
+
"50代 male", "50代 female",
|
| 47 |
+
"60代 male", "60代 female"
|
| 48 |
+
]
|
| 49 |
+
combined_label_map = {cat: idx for idx, cat in enumerate(combined_categories)}
|
| 50 |
+
df["年代性別_label"] = df["年代性別"].map(combined_label_map)
|
| 51 |
+
|
| 52 |
+
# --- Step 4: Sudachipyによる形態素解析(表層 + 品詞)---
|
| 53 |
+
tokenizer_obj = dictionary.Dictionary().create()
|
| 54 |
+
mode = tokenizer.Tokenizer.SplitMode.C
|
| 55 |
+
|
| 56 |
+
def sudachi_tokenize_with_pos(text):
|
| 57 |
+
tokens = tokenizer_obj.tokenize(text, mode)
|
| 58 |
+
return [
|
| 59 |
+
f"{m.surface()}/{m.part_of_speech()[0]}"
|
| 60 |
+
for m in tokens if m.surface().strip()
|
| 61 |
+
]
|
| 62 |
+
|
| 63 |
+
df["tokens"] = df["コメント"].apply(sudachi_tokenize_with_pos)
|
| 64 |
+
df["text"] = df["tokens"].apply(lambda x: " ".join(x))
|
| 65 |
+
|
| 66 |
+
return df, initial_count
|
| 67 |
+
|
| 68 |
+
if __name__ == '__main__':
|
| 69 |
+
df, initial_count = load_preprocessed_data()
|
| 70 |
+
|
| 71 |
+
# --- 表示 ---
|
| 72 |
+
print(f"✅ Excel内の全データ数: {initial_count} 件")
|
| 73 |
+
print(f"\n✅ 前処理後のデータ数: {len(df)} 件")
|
| 74 |
+
print("==== Sudachipyによる処理結果の一部 ====")
|
| 75 |
+
|
| 76 |
+
for i in range(min(10, len(df))): # 先頭10件まで表示
|
| 77 |
+
print(f"\n【{i+1}件目】")
|
| 78 |
+
print(f"[原文(正規化後)] {df.loc[i, 'コメント']}")
|
| 79 |
+
print(f"[形態素+品詞] {df.loc[i, 'tokens']}")
|
| 80 |
+
print(f"[テキスト形式] {df.loc[i, 'text']}")
|
| 81 |
+
print(f"[年代性別] {df.loc[i, '年代性別']}")
|
| 82 |
+
print(f"[年代] {df.loc[i, '年代']}")
|
| 83 |
+
print(f" 10代_label: {df.loc[i, '10代_label']}, 20代_label: {df.loc[i, '20代_label']}, 30代_label: {df.loc[i, '30代_label']}")
|
| 84 |
+
print(f" 40代_label: {df.loc[i, '40代_label']}, 50代_label: {df.loc[i, '50代_label']}, 60代_label: {df.loc[i, '60代_label']}")
|
| 85 |
+
print(f"[性別] {df.loc[i, '性別']} -> [性別_label] {df.loc[i, '性別_label']}")
|
GenderLearning.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.utils.data import Dataset, DataLoader, RandomSampler
|
| 4 |
+
from torch.optim import AdamW
|
| 5 |
+
from transformers import BertJapaneseTokenizer, BertModel
|
| 6 |
+
from sklearn.model_selection import train_test_split
|
| 7 |
+
from sklearn.metrics import accuracy_score
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
from DataNLP import load_preprocessed_data
|
| 13 |
+
|
| 14 |
+
# --- スクリプトのディレクトリを基準にパスを設定 ---
|
| 15 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 16 |
+
os.chdir(script_dir)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# 設定
|
| 20 |
+
PRE_TRAINED_MODEL_NAME = 'cl-tohoku/bert-large-japanese'
|
| 21 |
+
MAX_LEN = 128
|
| 22 |
+
BATCH_SIZE = 32 # バッチサイズを増加して高速化
|
| 23 |
+
EPOCHS = 10 # 重みを大幅に更新するためエポック数を増加
|
| 24 |
+
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 25 |
+
NUM_GENDER_LABELS = 2 # male, female
|
| 26 |
+
|
| 27 |
+
# --- データセットクラス ---
|
| 28 |
+
class GenderDataset(Dataset):
|
| 29 |
+
def __init__(self, texts, gender_labels, tokenizer, max_len):
|
| 30 |
+
self.texts = texts
|
| 31 |
+
self.gender_labels = gender_labels
|
| 32 |
+
self.tokenizer = tokenizer
|
| 33 |
+
self.max_len = max_len
|
| 34 |
+
|
| 35 |
+
def __len__(self):
|
| 36 |
+
return len(self.texts)
|
| 37 |
+
|
| 38 |
+
def __getitem__(self, item):
|
| 39 |
+
text = str(self.texts[item])
|
| 40 |
+
gender_label = self.gender_labels[item]
|
| 41 |
+
|
| 42 |
+
encoding = self.tokenizer.encode_plus(
|
| 43 |
+
text,
|
| 44 |
+
add_special_tokens=True,
|
| 45 |
+
max_length=self.max_len,
|
| 46 |
+
return_token_type_ids=False,
|
| 47 |
+
padding='max_length',
|
| 48 |
+
truncation=True,
|
| 49 |
+
return_attention_mask=True,
|
| 50 |
+
return_tensors='pt',
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
return {
|
| 54 |
+
'input_ids': encoding['input_ids'].flatten(),
|
| 55 |
+
'attention_mask': encoding['attention_mask'].flatten(),
|
| 56 |
+
'gender_labels': torch.tensor(int(gender_label), dtype=torch.long)
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
# --- モデル定義 ---
|
| 60 |
+
class BertForGenderClassification(nn.Module):
|
| 61 |
+
def __init__(self, model_name, num_gender_labels):
|
| 62 |
+
super().__init__()
|
| 63 |
+
self.bert = BertModel.from_pretrained(model_name, use_safetensors=True)
|
| 64 |
+
self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
|
| 65 |
+
self.gender_classifier = nn.Linear(self.bert.config.hidden_size, num_gender_labels)
|
| 66 |
+
|
| 67 |
+
def forward(self, input_ids, attention_mask, gender_labels=None):
|
| 68 |
+
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
| 69 |
+
pooled_output = outputs.pooler_output
|
| 70 |
+
pooled_output = self.dropout(pooled_output)
|
| 71 |
+
|
| 72 |
+
gender_logits = self.gender_classifier(pooled_output)
|
| 73 |
+
|
| 74 |
+
loss = None
|
| 75 |
+
if gender_labels is not None:
|
| 76 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 77 |
+
loss = loss_fct(gender_logits, gender_labels)
|
| 78 |
+
|
| 79 |
+
return loss, gender_logits
|
| 80 |
+
|
| 81 |
+
# --- 学習関数 ---
|
| 82 |
+
def train_epoch(model, data_loader, optimizer, device):
|
| 83 |
+
model.train()
|
| 84 |
+
total_loss = 0
|
| 85 |
+
|
| 86 |
+
for batch in tqdm(data_loader, desc="Training"):
|
| 87 |
+
input_ids = batch['input_ids'].to(device)
|
| 88 |
+
attention_mask = batch['attention_mask'].to(device)
|
| 89 |
+
gender_labels = batch['gender_labels'].to(device)
|
| 90 |
+
|
| 91 |
+
optimizer.zero_grad()
|
| 92 |
+
loss, _ = model(input_ids=input_ids, attention_mask=attention_mask, gender_labels=gender_labels)
|
| 93 |
+
|
| 94 |
+
if isinstance(loss, torch.Tensor):
|
| 95 |
+
loss.backward()
|
| 96 |
+
optimizer.step()
|
| 97 |
+
total_loss += loss.item()
|
| 98 |
+
|
| 99 |
+
return total_loss / len(data_loader)
|
| 100 |
+
|
| 101 |
+
# --- 評価関数 ---
|
| 102 |
+
def eval_model(model, data_loader, device):
|
| 103 |
+
model.eval()
|
| 104 |
+
gender_preds, gender_true_labels = [], []
|
| 105 |
+
|
| 106 |
+
with torch.no_grad():
|
| 107 |
+
for batch in tqdm(data_loader, desc="Evaluating"):
|
| 108 |
+
input_ids = batch['input_ids'].to(device)
|
| 109 |
+
attention_mask = batch['attention_mask'].to(device)
|
| 110 |
+
|
| 111 |
+
_, gender_logits = model(input_ids=input_ids, attention_mask=attention_mask)
|
| 112 |
+
|
| 113 |
+
gender_preds.extend(torch.argmax(gender_logits, dim=1).cpu().numpy())
|
| 114 |
+
gender_true_labels.extend(batch['gender_labels'].cpu().numpy())
|
| 115 |
+
|
| 116 |
+
gender_acc = accuracy_score(gender_true_labels, gender_preds)
|
| 117 |
+
|
| 118 |
+
return gender_acc
|
| 119 |
+
|
| 120 |
+
# --- データサンプリング関数(性別ごとにバランシング) ---
|
| 121 |
+
def sample_balanced_data(df, max_per_gender=20000):
|
| 122 |
+
"""
|
| 123 |
+
性別ごとにバランシングする
|
| 124 |
+
- 性別:各性別ごとに最大max_per_gender件
|
| 125 |
+
"""
|
| 126 |
+
gender_sampled_dfs = []
|
| 127 |
+
for gender_label in df['性別_label'].unique():
|
| 128 |
+
subset = df[df['性別_label'] == gender_label]
|
| 129 |
+
if len(subset) > max_per_gender:
|
| 130 |
+
subset = subset.sample(max_per_gender, random_state=42)
|
| 131 |
+
gender_sampled_dfs.append(subset)
|
| 132 |
+
|
| 133 |
+
return pd.concat(gender_sampled_dfs).sample(frac=1, random_state=42).reset_index(drop=True)
|
| 134 |
+
|
| 135 |
+
# --- メイン処理 ---
|
| 136 |
+
def main():
|
| 137 |
+
print("--- 1. データ読み込み ---")
|
| 138 |
+
df, _ = load_preprocessed_data()
|
| 139 |
+
|
| 140 |
+
# --- データを性別でバランシ��グして軽量化 ---
|
| 141 |
+
df = sample_balanced_data(df, max_per_gender=5000) # データ量を大幅に増加
|
| 142 |
+
|
| 143 |
+
# ラベルの分布を確認
|
| 144 |
+
print("\n性別ラベルの分布:")
|
| 145 |
+
print(df['性別_label'].value_counts().sort_index())
|
| 146 |
+
print(f"\n合計データ数: {len(df)} 件")
|
| 147 |
+
|
| 148 |
+
# 訓練用と検証用に分割
|
| 149 |
+
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
|
| 150 |
+
|
| 151 |
+
print(f"\n--- 2. トークナイザとデータローダーの準備 ---")
|
| 152 |
+
tokenizer = BertJapaneseTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
|
| 153 |
+
|
| 154 |
+
train_dataset = GenderDataset(
|
| 155 |
+
train_df['text'].values,
|
| 156 |
+
train_df['性別_label'].values,
|
| 157 |
+
tokenizer,
|
| 158 |
+
MAX_LEN
|
| 159 |
+
)
|
| 160 |
+
train_sampler = RandomSampler(train_dataset)
|
| 161 |
+
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
|
| 162 |
+
|
| 163 |
+
val_dataset = GenderDataset(
|
| 164 |
+
val_df['text'].values,
|
| 165 |
+
val_df['性別_label'].values,
|
| 166 |
+
tokenizer,
|
| 167 |
+
MAX_LEN
|
| 168 |
+
)
|
| 169 |
+
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
|
| 170 |
+
print("\n--- 3. モデルのセットアップ ---")
|
| 171 |
+
model = BertForGenderClassification(PRE_TRAINED_MODEL_NAME, NUM_GENDER_LABELS)
|
| 172 |
+
model.to(DEVICE)
|
| 173 |
+
# BERT全体をファインチューニング(レイヤーごとに異なる学習率を設定)
|
| 174 |
+
optimizer = AdamW([
|
| 175 |
+
{'params': model.bert.parameters(), 'lr': 2e-5}, # BERT本体は小さい学習率
|
| 176 |
+
{'params': model.gender_classifier.parameters(), 'lr': 5e-4}, # 分類層は大きい学習率
|
| 177 |
+
])
|
| 178 |
+
|
| 179 |
+
print("\n--- 4. 学習開始 ---")
|
| 180 |
+
print(f"デバイス: {DEVICE}")
|
| 181 |
+
print(f"訓練データ数: {len(train_df)} 件")
|
| 182 |
+
print(f"検証データ数: {len(val_df)} 件")
|
| 183 |
+
print(f"バッチサイズ: {BATCH_SIZE}")
|
| 184 |
+
print(f"エポック数: {EPOCHS}")
|
| 185 |
+
print(f"推定学習時間: 約35時間")
|
| 186 |
+
|
| 187 |
+
import time
|
| 188 |
+
start_time = time.time()
|
| 189 |
+
|
| 190 |
+
for epoch in range(EPOCHS):
|
| 191 |
+
epoch_start_time = time.time()
|
| 192 |
+
print(f"\n{'='*60}")
|
| 193 |
+
print(f"Epoch {epoch + 1}/{EPOCHS} 開始")
|
| 194 |
+
print(f"{'='*60}")
|
| 195 |
+
|
| 196 |
+
train_loss = train_epoch(model, train_loader, optimizer, DEVICE)
|
| 197 |
+
print(f"Train Loss (Gender): {train_loss:.4f}")
|
| 198 |
+
|
| 199 |
+
gender_acc = eval_model(model, val_loader, DEVICE)
|
| 200 |
+
print(f"Gender Validation Accuracy: {gender_acc:.4f} ({gender_acc*100:.2f}%)")
|
| 201 |
+
|
| 202 |
+
# エポックの経過時間を表示
|
| 203 |
+
epoch_time = time.time() - epoch_start_time
|
| 204 |
+
elapsed_time = time.time() - start_time
|
| 205 |
+
remaining_epochs = EPOCHS - (epoch + 1)
|
| 206 |
+
estimated_remaining_time = (elapsed_time / (epoch + 1)) * remaining_epochs
|
| 207 |
+
|
| 208 |
+
print(f"\nエポック所要時間: {epoch_time/60:.1f}分")
|
| 209 |
+
print(f"経過時間: {elapsed_time/3600:.1f}時間")
|
| 210 |
+
print(f"推定残り時間: {estimated_remaining_time/3600:.1f}時間")
|
| 211 |
+
print(f"{'='*60}")
|
| 212 |
+
|
| 213 |
+
print("\n--- 5. 学習完了 ---")
|
| 214 |
+
torch.save(model.state_dict(), 'bert_gender_model.bin')
|
| 215 |
+
print("モデルを 'bert_gender_model.bin' に保存しました。")
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
if __name__ == '__main__':
|
| 219 |
+
main()
|
| 220 |
+
|
README.md
CHANGED
|
@@ -1,12 +1,43 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 5.49.1
|
| 8 |
-
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: 年代・性別推定システム
|
| 3 |
+
emoji: 🧠
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: gradio
|
|
|
|
|
|
|
| 7 |
pinned: false
|
| 8 |
+
license: mit
|
| 9 |
+
app_port: 7860
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# 年代・性別推定システム
|
| 13 |
+
|
| 14 |
+
日本語テキストから年代と性別を推定するAIシステムです。BERTベースのモデルを使用して、入力されたテキストの特徴から年代(10代〜60代)と性別(男性・女性)を確率で予測します。
|
| 15 |
+
|
| 16 |
+
## 機能
|
| 17 |
+
|
| 18 |
+
- **年代推定**: 10代、20代、30代、40代、50代、60代の6つの年代を確率で予測
|
| 19 |
+
- **性別推定**: 男性・女性を確率で予測
|
| 20 |
+
- **リアルタイム予測**: Webアプリケーションでリアルタイムに予測結果を表示
|
| 21 |
+
|
| 22 |
+
## 技術仕様
|
| 23 |
+
|
| 24 |
+
- **ベースモデル**: cl-tohoku/bert-large-japanese
|
| 25 |
+
- **フレームワーク**: PyTorch, Transformers, Gradio
|
| 26 |
+
- **デプロイ**: Gradio Spaces
|
| 27 |
+
|
| 28 |
+
## 使用方法
|
| 29 |
+
|
| 30 |
+
1. テキストボックスに日本語のテキストを入力
|
| 31 |
+
2. 「推測実行」ボタンをクリック
|
| 32 |
+
3. 年代と性別の確率が表示されます
|
| 33 |
+
|
| 34 |
+
## モデル詳細
|
| 35 |
+
|
| 36 |
+
- **年代モデル**: 各年代を独立した二値分類器として学習
|
| 37 |
+
- **性別モデル**: 2クラス分類(男性・女性)
|
| 38 |
+
- **学習データ**: 日本語テキストデータセット
|
| 39 |
+
- **精度**: 年代推定 約79%、性別推定 約70%
|
| 40 |
+
|
| 41 |
+
## ライセンス
|
| 42 |
+
|
| 43 |
+
MIT License
|
SupervisedLearning.py
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.utils.data import Dataset, DataLoader, RandomSampler
|
| 4 |
+
from torch.optim import AdamW
|
| 5 |
+
from transformers import BertJapaneseTokenizer, BertModel
|
| 6 |
+
from sklearn.model_selection import train_test_split
|
| 7 |
+
from sklearn.metrics import accuracy_score
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import os # osモジュールをインポート
|
| 11 |
+
import numpy as np
|
| 12 |
+
try:
|
| 13 |
+
import matplotlib
|
| 14 |
+
matplotlib.use('Agg') # GUIバックエンドを使わない
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
MATPLOTLIB_AVAILABLE = True
|
| 17 |
+
except ImportError:
|
| 18 |
+
MATPLOTLIB_AVAILABLE = False
|
| 19 |
+
print("警告: matplotlibがインストールされていません。グラフは表示されません。")
|
| 20 |
+
|
| 21 |
+
from DataNLP import load_preprocessed_data
|
| 22 |
+
|
| 23 |
+
# --- スクリプトのディレクトリを基準にパスを設定 ---
|
| 24 |
+
# このスクリプト自身の絶対パスを取得
|
| 25 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 26 |
+
# 作業ディレクトリをこのスクリプトがあるディレクトリに変更
|
| 27 |
+
os.chdir(script_dir)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# 設定
|
| 31 |
+
PRE_TRAINED_MODEL_NAME = 'cl-tohoku/bert-large-japanese'
|
| 32 |
+
MAX_LEN = 128
|
| 33 |
+
BATCH_SIZE = 32 # バッチサイズを増加して高速化
|
| 34 |
+
EPOCHS = 10 # 重みを大幅に更新するためエポック数を増加
|
| 35 |
+
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 36 |
+
NUM_AGE_CLASSIFIERS = 6 # 各年代ごとに二値分類器
|
| 37 |
+
AGE_CATEGORIES = ["10代", "20代", "30代", "40代", "50代", "60代"]
|
| 38 |
+
|
| 39 |
+
# --- データセットクラス ---
|
| 40 |
+
class CustomDataset(Dataset):
|
| 41 |
+
def __init__(self, texts, age_labels_dict, tokenizer, max_len):
|
| 42 |
+
"""
|
| 43 |
+
age_labels_dict: {'10代_label': array, '20代_label': array, ...}
|
| 44 |
+
"""
|
| 45 |
+
self.texts = texts
|
| 46 |
+
self.age_labels_dict = age_labels_dict
|
| 47 |
+
self.tokenizer = tokenizer
|
| 48 |
+
self.max_len = max_len
|
| 49 |
+
|
| 50 |
+
def __len__(self):
|
| 51 |
+
return len(self.texts)
|
| 52 |
+
|
| 53 |
+
def __getitem__(self, item):
|
| 54 |
+
text = str(self.texts[item])
|
| 55 |
+
|
| 56 |
+
encoding = self.tokenizer.encode_plus(
|
| 57 |
+
text,
|
| 58 |
+
add_special_tokens=True,
|
| 59 |
+
max_length=self.max_len,
|
| 60 |
+
return_token_type_ids=False,
|
| 61 |
+
padding='max_length',
|
| 62 |
+
truncation=True,
|
| 63 |
+
return_attention_mask=True,
|
| 64 |
+
return_tensors='pt',
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
# 各年代の二値ラベルを取得
|
| 68 |
+
age_labels = torch.tensor([
|
| 69 |
+
int(self.age_labels_dict[f"{age}_label"][item])
|
| 70 |
+
for age in AGE_CATEGORIES
|
| 71 |
+
], dtype=torch.float)
|
| 72 |
+
|
| 73 |
+
return {
|
| 74 |
+
'input_ids': encoding['input_ids'].flatten(),
|
| 75 |
+
'attention_mask': encoding['attention_mask'].flatten(),
|
| 76 |
+
'age_labels': age_labels, # shape: (6,) - 各年代の二値ラベル
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
# --- モデル定義 ---
|
| 80 |
+
class BertForAgeClassification(nn.Module):
|
| 81 |
+
def __init__(self, model_name, num_age_classifiers):
|
| 82 |
+
super().__init__()
|
| 83 |
+
self.bert = BertModel.from_pretrained(model_name, use_safetensors=True)
|
| 84 |
+
self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
|
| 85 |
+
|
| 86 |
+
# 各年代ごとに二値分類器を作成(6個)
|
| 87 |
+
self.age_classifiers = nn.ModuleList([
|
| 88 |
+
nn.Linear(self.bert.config.hidden_size, 1) # 二値分類なので出力は1
|
| 89 |
+
for _ in range(num_age_classifiers)
|
| 90 |
+
])
|
| 91 |
+
|
| 92 |
+
def forward(self, input_ids, attention_mask, age_labels=None):
|
| 93 |
+
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
|
| 94 |
+
pooled_output = outputs.pooler_output
|
| 95 |
+
pooled_output = self.dropout(pooled_output)
|
| 96 |
+
|
| 97 |
+
# 各年代の二値分類器の出力を取得
|
| 98 |
+
age_logits_list = [classifier(pooled_output) for classifier in self.age_classifiers]
|
| 99 |
+
age_logits = torch.cat(age_logits_list, dim=1) # shape: (batch_size, 6)
|
| 100 |
+
|
| 101 |
+
loss = None
|
| 102 |
+
if age_labels is not None:
|
| 103 |
+
# 各年代の二値分類損失(BCEWithLogitsLoss)
|
| 104 |
+
bce_loss = nn.BCEWithLogitsLoss()
|
| 105 |
+
loss = bce_loss(age_logits, age_labels)
|
| 106 |
+
|
| 107 |
+
return loss, age_logits
|
| 108 |
+
|
| 109 |
+
# --- 学習関数 ---
|
| 110 |
+
def train_epoch(model, data_loader, optimizer, device):
|
| 111 |
+
model.train()
|
| 112 |
+
total_loss = 0
|
| 113 |
+
|
| 114 |
+
for batch in tqdm(data_loader, desc="Training"):
|
| 115 |
+
input_ids = batch['input_ids'].to(device)
|
| 116 |
+
attention_mask = batch['attention_mask'].to(device)
|
| 117 |
+
age_labels = batch['age_labels'].to(device)
|
| 118 |
+
|
| 119 |
+
optimizer.zero_grad()
|
| 120 |
+
|
| 121 |
+
# モデルのforward関数から出力を取得
|
| 122 |
+
loss, age_logits = model(input_ids=input_ids, attention_mask=attention_mask, age_labels=age_labels)
|
| 123 |
+
|
| 124 |
+
if loss is not None:
|
| 125 |
+
loss.backward()
|
| 126 |
+
optimizer.step()
|
| 127 |
+
total_loss += loss.item()
|
| 128 |
+
|
| 129 |
+
return total_loss / len(data_loader)
|
| 130 |
+
|
| 131 |
+
# --- 評価関数 ---
|
| 132 |
+
def eval_model(model, data_loader, device):
|
| 133 |
+
model.eval()
|
| 134 |
+
age_preds_all = {age: [] for age in AGE_CATEGORIES} # 各年代の予測
|
| 135 |
+
age_true_all = {age: [] for age in AGE_CATEGORIES} # 各年代の正解
|
| 136 |
+
|
| 137 |
+
with torch.no_grad():
|
| 138 |
+
for batch in tqdm(data_loader, desc="Evaluating"):
|
| 139 |
+
input_ids = batch['input_ids'].to(device)
|
| 140 |
+
attention_mask = batch['attention_mask'].to(device)
|
| 141 |
+
|
| 142 |
+
_, age_logits = model(input_ids=input_ids, attention_mask=attention_mask)
|
| 143 |
+
|
| 144 |
+
# 各年代の二値分類の予測(シグモイド関数で0-1に変換後、0.5で閾値判定)
|
| 145 |
+
age_probs = torch.sigmoid(age_logits) # shape: (batch_size, 6)
|
| 146 |
+
age_preds_binary = (age_probs > 0.5).cpu().numpy() # shape: (batch_size, 6)
|
| 147 |
+
age_true_binary = batch['age_labels'].cpu().numpy() # shape: (batch_size, 6)
|
| 148 |
+
|
| 149 |
+
# 各年代ごとに予測と正解を保存
|
| 150 |
+
for i, age in enumerate(AGE_CATEGORIES):
|
| 151 |
+
age_preds_all[age].extend(age_preds_binary[:, i])
|
| 152 |
+
age_true_all[age].extend(age_true_binary[:, i])
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# 各年代の精度を計算
|
| 156 |
+
age_accuracies = {}
|
| 157 |
+
for age in AGE_CATEGORIES:
|
| 158 |
+
age_accuracies[age] = accuracy_score(age_true_all[age], age_preds_all[age])
|
| 159 |
+
|
| 160 |
+
return age_accuracies
|
| 161 |
+
|
| 162 |
+
# --- 学習曲線表示関数 ---
|
| 163 |
+
def plot_training_curves(train_losses, val_accuracies):
|
| 164 |
+
"""
|
| 165 |
+
学習曲線(Loss CurveとAccuracy Curve)を表示する
|
| 166 |
+
"""
|
| 167 |
+
if not MATPLOTLIB_AVAILABLE:
|
| 168 |
+
print("matplotlibが利用できないため、グラフを表示できません。")
|
| 169 |
+
return
|
| 170 |
+
|
| 171 |
+
epochs = range(1, len(train_losses) + 1)
|
| 172 |
+
|
| 173 |
+
# 2つのサブプロットを作成
|
| 174 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
|
| 175 |
+
|
| 176 |
+
# Loss Curve
|
| 177 |
+
ax1.plot(epochs, train_losses, 'b-', label='Training Loss', linewidth=2)
|
| 178 |
+
ax1.set_title('Training Loss Curve', fontsize=14, fontweight='bold')
|
| 179 |
+
ax1.set_xlabel('Epoch')
|
| 180 |
+
ax1.set_ylabel('Loss')
|
| 181 |
+
ax1.grid(True, alpha=0.3)
|
| 182 |
+
ax1.legend()
|
| 183 |
+
|
| 184 |
+
# Accuracy Curve
|
| 185 |
+
colors = ['red', 'blue', 'green', 'orange', 'purple', 'brown']
|
| 186 |
+
for i, age in enumerate(AGE_CATEGORIES):
|
| 187 |
+
ax2.plot(epochs, val_accuracies[age], color=colors[i],
|
| 188 |
+
label=f'{age} Accuracy', linewidth=2, marker='o', markersize=4)
|
| 189 |
+
|
| 190 |
+
ax2.set_title('Validation Accuracy Curves', fontsize=14, fontweight='bold')
|
| 191 |
+
ax2.set_xlabel('Epoch')
|
| 192 |
+
ax2.set_ylabel('Accuracy')
|
| 193 |
+
ax2.set_ylim(0, 1)
|
| 194 |
+
ax2.grid(True, alpha=0.3)
|
| 195 |
+
ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
|
| 196 |
+
|
| 197 |
+
plt.tight_layout()
|
| 198 |
+
plt.savefig('age_training_curves.png', dpi=300, bbox_inches='tight')
|
| 199 |
+
plt.show()
|
| 200 |
+
|
| 201 |
+
# 最終的な精度を表示
|
| 202 |
+
print("\n=== 最終的な検証精度 ===")
|
| 203 |
+
for age in AGE_CATEGORIES:
|
| 204 |
+
final_acc = val_accuracies[age][-1]
|
| 205 |
+
print(f"{age}: {final_acc:.4f} ({final_acc*100:.2f}%)")
|
| 206 |
+
|
| 207 |
+
avg_acc = np.mean([val_accuracies[age][-1] for age in AGE_CATEGORIES])
|
| 208 |
+
print(f"\n平均精度: {avg_acc:.4f} ({avg_acc*100:.2f}%)")
|
| 209 |
+
|
| 210 |
+
# --- データサンプリング関数(年代と性別を別々にバランシング) ---
|
| 211 |
+
def sample_balanced_data(df, max_per_age=5000, max_per_gender=5000):
|
| 212 |
+
"""
|
| 213 |
+
年代と性別を別々にバランシングする
|
| 214 |
+
- 年代:各年代ごとに最大max_per_age件(性別関係なく)
|
| 215 |
+
- 性別:各性別ごとに最大max_per_gender件(年代関係なく)
|
| 216 |
+
|
| 217 |
+
両方の条件を満たすデータのみを残す
|
| 218 |
+
"""
|
| 219 |
+
# 年代ごとにサンプリング
|
| 220 |
+
age_sampled_dfs = []
|
| 221 |
+
for age in AGE_CATEGORIES:
|
| 222 |
+
subset = df[df['年代'] == age]
|
| 223 |
+
if len(subset) > max_per_age:
|
| 224 |
+
subset = subset.sample(max_per_age, random_state=42)
|
| 225 |
+
age_sampled_dfs.append(subset)
|
| 226 |
+
age_balanced_df = pd.concat(age_sampled_dfs).reset_index(drop=True)
|
| 227 |
+
|
| 228 |
+
# 性別ごとにサンプリング
|
| 229 |
+
gender_sampled_dfs = []
|
| 230 |
+
for gender_label in age_balanced_df['性別_label'].unique():
|
| 231 |
+
subset = age_balanced_df[age_balanced_df['性別_label'] == gender_label]
|
| 232 |
+
if len(subset) > max_per_gender:
|
| 233 |
+
subset = subset.sample(max_per_gender, random_state=42)
|
| 234 |
+
gender_sampled_dfs.append(subset)
|
| 235 |
+
|
| 236 |
+
return pd.concat(gender_sampled_dfs).sample(frac=1, random_state=42).reset_index(drop=True)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def create_balanced_binary_labels(df, samples_per_label=2000):
|
| 240 |
+
"""
|
| 241 |
+
各年代の二値分類器を完全に独立させてバランスを取る
|
| 242 |
+
- 各年代について、正例と負例を同じ数にする(重複なし)
|
| 243 |
+
- samples_per_label: 各ラベル(正例・負例)あたりのサンプル数
|
| 244 |
+
"""
|
| 245 |
+
# まず、各年代ごとに利用可能なデータ数を確認
|
| 246 |
+
print("\n各年代のデータ数:")
|
| 247 |
+
for age in AGE_CATEGORIES:
|
| 248 |
+
count = len(df[df['年代'] == age])
|
| 249 |
+
print(f" {age}: {count}件")
|
| 250 |
+
|
| 251 |
+
# 各年代用のデータセットを個別に作成
|
| 252 |
+
age_datasets = {}
|
| 253 |
+
|
| 254 |
+
for age in AGE_CATEGORIES:
|
| 255 |
+
print(f"\n {age}の二値分類器用デー��を作成中...")
|
| 256 |
+
|
| 257 |
+
# 正例(該当年代)のデータ
|
| 258 |
+
positive_samples = df[df['年代'] == age].copy()
|
| 259 |
+
actual_positive = min(len(positive_samples), samples_per_label)
|
| 260 |
+
if len(positive_samples) > samples_per_label:
|
| 261 |
+
positive_samples = positive_samples.sample(samples_per_label, random_state=42)
|
| 262 |
+
else:
|
| 263 |
+
print(f" 警告: {age}の正例は{len(positive_samples)}件しかありません")
|
| 264 |
+
|
| 265 |
+
# 負例(他の年代)のデータ - 正例と同じ数だけサンプリング
|
| 266 |
+
negative_samples = df[df['年代'] != age].copy()
|
| 267 |
+
target_negative = len(positive_samples) # 正例と同じ数
|
| 268 |
+
if len(negative_samples) > target_negative:
|
| 269 |
+
negative_samples = negative_samples.sample(target_negative, random_state=42)
|
| 270 |
+
|
| 271 |
+
# 正例と負例を結合(この年代専用)
|
| 272 |
+
age_dataset = pd.concat([positive_samples, negative_samples]).reset_index(drop=True)
|
| 273 |
+
age_datasets[age] = age_dataset
|
| 274 |
+
|
| 275 |
+
print(f" {age}: 正例{len(positive_samples)}件, 負例{len(negative_samples)}件 (合計{len(age_dataset)}件)")
|
| 276 |
+
|
| 277 |
+
# 全ての年代のデータセットを結合してシャッフル
|
| 278 |
+
# ※各データは複数の年代の分類器で使われるが、各分類器内ではバランスが取れている
|
| 279 |
+
all_data = []
|
| 280 |
+
for age, dataset in age_datasets.items():
|
| 281 |
+
all_data.append(dataset)
|
| 282 |
+
|
| 283 |
+
# インデックスで重複を除去(リスト型カラムがあるためdrop_duplicatesは使えない)
|
| 284 |
+
final_df = pd.concat(all_data, ignore_index=True)
|
| 285 |
+
final_df = final_df.loc[~final_df.index.duplicated(keep='first')]
|
| 286 |
+
final_df = final_df.sample(frac=1, random_state=42).reset_index(drop=True)
|
| 287 |
+
|
| 288 |
+
print(f"\n統合後のデータ数: {len(final_df)}件")
|
| 289 |
+
return final_df
|
| 290 |
+
|
| 291 |
+
# --- メイン処理 ---
|
| 292 |
+
def main():
|
| 293 |
+
print("--- 1. データ読み込み ---")
|
| 294 |
+
df, _ = load_preprocessed_data()
|
| 295 |
+
|
| 296 |
+
# --- 各年代の二値分類でバランスを取る ---
|
| 297 |
+
print("--- 各年代の二値分類でバランス調整 ---")
|
| 298 |
+
df = create_balanced_binary_labels(df, samples_per_label=2000) # 各ラベル2000件ずつ
|
| 299 |
+
|
| 300 |
+
# ラベルの分布を確認
|
| 301 |
+
print("\n各年代の二値ラベル分布(バランス調整後):")
|
| 302 |
+
for age in AGE_CATEGORIES:
|
| 303 |
+
positive_count = df[f"{age}_label"].sum()
|
| 304 |
+
negative_count = len(df) - positive_count
|
| 305 |
+
print(f" {age}: 正例{positive_count}件, 負例{negative_count}件")
|
| 306 |
+
|
| 307 |
+
print(f"\n合計データ数: {len(df)} 件")
|
| 308 |
+
|
| 309 |
+
# 訓練用と検証用に分割
|
| 310 |
+
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
|
| 311 |
+
|
| 312 |
+
print(f"\n--- 2. トークナイザとデータローダーの準備 ---")
|
| 313 |
+
tokenizer = BertJapaneseTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
|
| 314 |
+
|
| 315 |
+
# 各年代のラベルを辞書形式で渡す
|
| 316 |
+
train_age_labels_dict = {f"{age}_label": train_df[f"{age}_label"].values for age in AGE_CATEGORIES}
|
| 317 |
+
val_age_labels_dict = {f"{age}_label": val_df[f"{age}_label"].values for age in AGE_CATEGORIES}
|
| 318 |
+
|
| 319 |
+
train_dataset = CustomDataset(
|
| 320 |
+
train_df['text'].values,
|
| 321 |
+
train_age_labels_dict,
|
| 322 |
+
tokenizer,
|
| 323 |
+
MAX_LEN
|
| 324 |
+
)
|
| 325 |
+
train_sampler = RandomSampler(train_dataset)
|
| 326 |
+
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
|
| 327 |
+
|
| 328 |
+
val_dataset = CustomDataset(
|
| 329 |
+
val_df['text'].values,
|
| 330 |
+
val_age_labels_dict,
|
| 331 |
+
tokenizer,
|
| 332 |
+
MAX_LEN
|
| 333 |
+
)
|
| 334 |
+
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
|
| 335 |
+
|
| 336 |
+
print("\n--- 3. モデルのセットアップ ---")
|
| 337 |
+
model = BertForAgeClassification(PRE_TRAINED_MODEL_NAME, NUM_AGE_CLASSIFIERS)
|
| 338 |
+
model.to(DEVICE)
|
| 339 |
+
|
| 340 |
+
# BERT全体をファインチューニング(レイヤーごとに異なる学習率を設定)
|
| 341 |
+
optimizer = AdamW([
|
| 342 |
+
{'params': model.bert.parameters(), 'lr': 2e-5}, # BERT本体は小さい学習率
|
| 343 |
+
{'params': model.age_classifiers.parameters(), 'lr': 5e-4}, # 分類層は大きい学習率
|
| 344 |
+
])
|
| 345 |
+
|
| 346 |
+
print("\n--- 4. 学習開始 ---")
|
| 347 |
+
print(f"デバイス: {DEVICE}")
|
| 348 |
+
print(f"訓練データ数: {len(train_df)} 件")
|
| 349 |
+
print(f"検証データ数: {len(val_df)} 件")
|
| 350 |
+
print(f"バッチサイズ: {BATCH_SIZE}")
|
| 351 |
+
print(f"エポック数: {EPOCHS}")
|
| 352 |
+
print(f"推定学習時間: 約66時間")
|
| 353 |
+
|
| 354 |
+
# 学習履歴を保存するリスト
|
| 355 |
+
train_losses = []
|
| 356 |
+
val_accuracies = {age: [] for age in AGE_CATEGORIES}
|
| 357 |
+
|
| 358 |
+
import time
|
| 359 |
+
start_time = time.time()
|
| 360 |
+
|
| 361 |
+
for epoch in range(EPOCHS):
|
| 362 |
+
epoch_start_time = time.time()
|
| 363 |
+
print(f"\n{'='*60}")
|
| 364 |
+
print(f"Epoch {epoch + 1}/{EPOCHS} 開始")
|
| 365 |
+
print(f"{'='*60}")
|
| 366 |
+
|
| 367 |
+
train_loss = train_epoch(model, train_loader, optimizer, DEVICE)
|
| 368 |
+
print(f"Train Loss: {train_loss:.4f}")
|
| 369 |
+
|
| 370 |
+
# 学習損失を記録
|
| 371 |
+
train_losses.append(train_loss)
|
| 372 |
+
|
| 373 |
+
age_accuracies = eval_model(model, val_loader, DEVICE)
|
| 374 |
+
print("\nAge Validation Accuracies:")
|
| 375 |
+
for age in AGE_CATEGORIES:
|
| 376 |
+
print(f" {age}: {age_accuracies[age]:.4f} ({age_accuracies[age]*100:.2f}%)")
|
| 377 |
+
val_accuracies[age].append(age_accuracies[age])
|
| 378 |
+
|
| 379 |
+
# 平均精度を計算
|
| 380 |
+
avg_acc = sum(age_accuracies.values()) / len(age_accuracies)
|
| 381 |
+
print(f"\n平均精度: {avg_acc:.4f} ({avg_acc*100:.2f}%)")
|
| 382 |
+
|
| 383 |
+
# エポックの経過時間を表示
|
| 384 |
+
epoch_time = time.time() - epoch_start_time
|
| 385 |
+
elapsed_time = time.time() - start_time
|
| 386 |
+
remaining_epochs = EPOCHS - (epoch + 1)
|
| 387 |
+
estimated_remaining_time = (elapsed_time / (epoch + 1)) * remaining_epochs
|
| 388 |
+
|
| 389 |
+
print(f"\nエポック所要時間: {epoch_time/60:.1f}分")
|
| 390 |
+
print(f"経過時間: {elapsed_time/3600:.1f}時間")
|
| 391 |
+
print(f"推定残り時間: {estimated_remaining_time/3600:.1f}時間")
|
| 392 |
+
print(f"{'='*60}")
|
| 393 |
+
|
| 394 |
+
print("\n--- 5. 学習完了 ---")
|
| 395 |
+
torch.save(model.state_dict(), 'bert_age_model.bin')
|
| 396 |
+
print("モデルを 'bert_age_model.bin' に保存しました。")
|
| 397 |
+
|
| 398 |
+
# Loss CurveとAccuracy Curveを表示
|
| 399 |
+
print("\n--- 6. 学習曲線の表示 ---")
|
| 400 |
+
plot_training_curves(train_losses, val_accuracies)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
if __name__ == '__main__':
|
| 404 |
+
main()
|
app.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import os
|
| 4 |
+
from predictor import load_models, predict_text
|
| 5 |
+
|
| 6 |
+
# モデルの読み込み
|
| 7 |
+
print("=== モデル読み込み開始 ===")
|
| 8 |
+
try:
|
| 9 |
+
load_models()
|
| 10 |
+
print("✅ モデルの読み込みが完了しました")
|
| 11 |
+
except Exception as e:
|
| 12 |
+
print(f"❌ モデルの読み込みに失敗しました: {e}")
|
| 13 |
+
print("⚠️ モデルファイルが存在しない可能性があります")
|
| 14 |
+
|
| 15 |
+
def predict_age_gender(text):
|
| 16 |
+
"""年代・性別予測関数"""
|
| 17 |
+
if not text.strip():
|
| 18 |
+
return "テキストを入力してください。", "", ""
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
result = predict_text(text)
|
| 22 |
+
|
| 23 |
+
# 年代予測結果を整形
|
| 24 |
+
age_results = []
|
| 25 |
+
for age, percentage in result['age_percentages'].items():
|
| 26 |
+
age_results.append(f"{age}: {percentage}%")
|
| 27 |
+
age_text = "\n".join(age_results)
|
| 28 |
+
|
| 29 |
+
# 性別予測結果を整形
|
| 30 |
+
gender_results = []
|
| 31 |
+
for gender, percentage in result['gender_percentages'].items():
|
| 32 |
+
gender_results.append(f"{gender}: {percentage}%")
|
| 33 |
+
gender_text = "\n".join(gender_results)
|
| 34 |
+
|
| 35 |
+
# 最も高い確率の年代を特定
|
| 36 |
+
max_age = max(result['age_percentages'].items(), key=lambda x: x[1])
|
| 37 |
+
max_gender = max(result['gender_percentages'].items(), key=lambda x: x[1])
|
| 38 |
+
|
| 39 |
+
summary = f"推定結果: {max_age[0]} ({max_age[1]}%), {max_gender[0]} ({max_gender[1]}%)"
|
| 40 |
+
|
| 41 |
+
return summary, age_text, gender_text
|
| 42 |
+
|
| 43 |
+
except Exception as e:
|
| 44 |
+
return f"エラーが発生しました: {str(e)}", "", ""
|
| 45 |
+
|
| 46 |
+
# Gradioインターフェース
|
| 47 |
+
interface = gr.Interface(
|
| 48 |
+
fn=predict_age_gender,
|
| 49 |
+
inputs=gr.Textbox(
|
| 50 |
+
label="日本語テキストを入力してください",
|
| 51 |
+
placeholder="例: 今日はとても良い天気ですね。友達と一緒に散歩をしました。",
|
| 52 |
+
lines=3
|
| 53 |
+
),
|
| 54 |
+
outputs=[
|
| 55 |
+
gr.Textbox(label="推定結果サマリー"),
|
| 56 |
+
gr.Textbox(label="年代予測詳細"),
|
| 57 |
+
gr.Textbox(label="性別予測詳細")
|
| 58 |
+
],
|
| 59 |
+
title="🧠 年代・性別推定システム",
|
| 60 |
+
description="日本語テキストから年代と性別を推定するAIシステムです。",
|
| 61 |
+
examples=[
|
| 62 |
+
"今日はとても良い天気ですね。",
|
| 63 |
+
"友達と一緒に散歩をしました。",
|
| 64 |
+
"新しいスマートフォンを買いました。",
|
| 65 |
+
"仕事が忙しくて疲れました。"
|
| 66 |
+
],
|
| 67 |
+
theme=gr.themes.Soft()
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# アプリケーション起動
|
| 71 |
+
if __name__ == "__main__":
|
| 72 |
+
interface.launch(
|
| 73 |
+
server_name="0.0.0.0",
|
| 74 |
+
server_port=7860,
|
| 75 |
+
share=False # Hugging Face SpacesではFalse
|
| 76 |
+
)
|
predictor.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import os
|
| 4 |
+
from transformers import BertJapaneseTokenizer
|
| 5 |
+
|
| 6 |
+
# 年代モデルと性別モデルの定義をインポート
|
| 7 |
+
from SupervisedLearning import BertForAgeClassification, PRE_TRAINED_MODEL_NAME, DEVICE, NUM_AGE_CLASSIFIERS, AGE_CATEGORIES
|
| 8 |
+
from GenderLearning import BertForGenderClassification, NUM_GENDER_LABELS
|
| 9 |
+
|
| 10 |
+
# モデルファイルのパス
|
| 11 |
+
AGE_MODEL_PATH = 'bert_age_model.bin'
|
| 12 |
+
GENDER_MODEL_PATH = 'bert_gender_model.bin'
|
| 13 |
+
|
| 14 |
+
# 性別のカテゴリマッピング
|
| 15 |
+
GENDER_CATEGORIES = ["male", "female"]
|
| 16 |
+
GENDER_CATEGORIES_JP = ["男性", "女性"]
|
| 17 |
+
|
| 18 |
+
# --- グローバル変数としてモデルとトークナイザを一度だけロード ---
|
| 19 |
+
TOKENIZER = None
|
| 20 |
+
AGE_MODEL = None
|
| 21 |
+
GENDER_MODEL = None
|
| 22 |
+
|
| 23 |
+
def load_models():
|
| 24 |
+
"""アプリケーション起動時にモデルを一度だけ読み込む"""
|
| 25 |
+
global TOKENIZER, AGE_MODEL, GENDER_MODEL
|
| 26 |
+
|
| 27 |
+
# モデルファイルの存在確認
|
| 28 |
+
if not os.path.exists(AGE_MODEL_PATH):
|
| 29 |
+
raise FileNotFoundError(f"エラー: 年代学習済みモデル '{AGE_MODEL_PATH}' が見つかりません。")
|
| 30 |
+
|
| 31 |
+
# 性別モデルはまだ学習されていない可能性があるので、警告のみ表示
|
| 32 |
+
if not os.path.exists(GENDER_MODEL_PATH):
|
| 33 |
+
print(f"警告: 性別学習済みモデル '{GENDER_MODEL_PATH}' が見つかりません。")
|
| 34 |
+
print("性別予測は利用できません。年代予測のみ実行されます。")
|
| 35 |
+
|
| 36 |
+
print("--- モデルの読み込みを開始します ---")
|
| 37 |
+
TOKENIZER = BertJapaneseTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
|
| 38 |
+
|
| 39 |
+
# 年代モデルの読み込み
|
| 40 |
+
print(" 年代モデルを読み込み中...")
|
| 41 |
+
AGE_MODEL = BertForAgeClassification(PRE_TRAINED_MODEL_NAME, NUM_AGE_CLASSIFIERS)
|
| 42 |
+
try:
|
| 43 |
+
if torch.__version__.startswith('1.'):
|
| 44 |
+
AGE_MODEL.load_state_dict(torch.load(AGE_MODEL_PATH, map_location=DEVICE))
|
| 45 |
+
else:
|
| 46 |
+
AGE_MODEL.load_state_dict(torch.load(AGE_MODEL_PATH, map_location=DEVICE, weights_only=True))
|
| 47 |
+
except Exception as e:
|
| 48 |
+
print(f"年代モデルの読み込み中にエラーが発生しました: {e}")
|
| 49 |
+
raise
|
| 50 |
+
AGE_MODEL.to(DEVICE)
|
| 51 |
+
AGE_MODEL.eval()
|
| 52 |
+
|
| 53 |
+
# 性別モデルの読み込み(存在する場合のみ)
|
| 54 |
+
if os.path.exists(GENDER_MODEL_PATH):
|
| 55 |
+
print(" 性別モデルを読み込み中...")
|
| 56 |
+
GENDER_MODEL = BertForGenderClassification(PRE_TRAINED_MODEL_NAME, NUM_GENDER_LABELS)
|
| 57 |
+
try:
|
| 58 |
+
if torch.__version__.startswith('1.'):
|
| 59 |
+
GENDER_MODEL.load_state_dict(torch.load(GENDER_MODEL_PATH, map_location=DEVICE))
|
| 60 |
+
else:
|
| 61 |
+
GENDER_MODEL.load_state_dict(torch.load(GENDER_MODEL_PATH, map_location=DEVICE, weights_only=True))
|
| 62 |
+
except Exception as e:
|
| 63 |
+
print(f"性別モデルの読み込み中にエラーが発生しました: {e}")
|
| 64 |
+
raise
|
| 65 |
+
GENDER_MODEL.to(DEVICE)
|
| 66 |
+
GENDER_MODEL.eval()
|
| 67 |
+
else:
|
| 68 |
+
GENDER_MODEL = None
|
| 69 |
+
|
| 70 |
+
print("--- モデルの読み込みが完了しました ---")
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def predict_text(text: str):
|
| 74 |
+
"""
|
| 75 |
+
入力されたテキストから「年代」と「性別」の各ラベルのパーセンテージを返す関数
|
| 76 |
+
"""
|
| 77 |
+
if AGE_MODEL is None or TOKENIZER is None:
|
| 78 |
+
load_models()
|
| 79 |
+
|
| 80 |
+
print(f"DEBUG: 入力テキスト: '{text}'")
|
| 81 |
+
|
| 82 |
+
# テキストの前処理
|
| 83 |
+
encoding = TOKENIZER.encode_plus(
|
| 84 |
+
text,
|
| 85 |
+
add_special_tokens=True,
|
| 86 |
+
max_length=128,
|
| 87 |
+
return_token_type_ids=False,
|
| 88 |
+
padding='max_length',
|
| 89 |
+
truncation=True,
|
| 90 |
+
return_attention_mask=True,
|
| 91 |
+
return_tensors='pt',
|
| 92 |
+
)
|
| 93 |
+
input_ids = encoding['input_ids'].to(DEVICE)
|
| 94 |
+
attention_mask = encoding['attention_mask'].to(DEVICE)
|
| 95 |
+
|
| 96 |
+
print(f"DEBUG: input_ids shape: {input_ids.shape}")
|
| 97 |
+
print(f"DEBUG: attention_mask shape: {attention_mask.shape}")
|
| 98 |
+
|
| 99 |
+
# 年代の予測
|
| 100 |
+
with torch.no_grad():
|
| 101 |
+
_, age_logits = AGE_MODEL(input_ids=input_ids, attention_mask=attention_mask)
|
| 102 |
+
|
| 103 |
+
print(f"DEBUG: age_logits shape: {age_logits.shape}")
|
| 104 |
+
print(f"DEBUG: age_logits values: {age_logits}")
|
| 105 |
+
|
| 106 |
+
# 各年代の二値分類の確率(シグモイド関数)
|
| 107 |
+
age_probs = torch.sigmoid(age_logits)[0] # shape: (6,)
|
| 108 |
+
print(f"DEBUG: age_probs shape: {age_probs.shape}")
|
| 109 |
+
print(f"DEBUG: age_probs values: {age_probs}")
|
| 110 |
+
|
| 111 |
+
# 年代の確率を辞書形式で保存
|
| 112 |
+
age_percentages = {}
|
| 113 |
+
for i, age in enumerate(AGE_CATEGORIES):
|
| 114 |
+
percentage = float(f"{age_probs[i].item() * 100:.2f}") # 小数第2位まで
|
| 115 |
+
age_percentages[age] = percentage
|
| 116 |
+
print(f"DEBUG: {age}: {age_probs[i].item()} -> {percentage}%")
|
| 117 |
+
|
| 118 |
+
# 性別の予測(モデルが存在する場合のみ)
|
| 119 |
+
if GENDER_MODEL is not None:
|
| 120 |
+
with torch.no_grad():
|
| 121 |
+
_, gender_logits = GENDER_MODEL(input_ids=input_ids, attention_mask=attention_mask)
|
| 122 |
+
|
| 123 |
+
print(f"DEBUG: gender_logits shape: {gender_logits.shape}")
|
| 124 |
+
print(f"DEBUG: gender_logits values: {gender_logits}")
|
| 125 |
+
|
| 126 |
+
# 性別の確率(Softmax関数)
|
| 127 |
+
gender_probs = F.softmax(gender_logits, dim=1)[0] # shape: (2,)
|
| 128 |
+
print(f"DEBUG: gender_probs shape: {gender_probs.shape}")
|
| 129 |
+
print(f"DEBUG: gender_probs values: {gender_probs}")
|
| 130 |
+
|
| 131 |
+
# 性別の確率を辞書形式で保存
|
| 132 |
+
gender_percentages = {}
|
| 133 |
+
for i, gender_jp in enumerate(GENDER_CATEGORIES_JP):
|
| 134 |
+
percentage = float(f"{gender_probs[i].item() * 100:.2f}") # 小数第2位まで
|
| 135 |
+
gender_percentages[gender_jp] = percentage
|
| 136 |
+
print(f"DEBUG: {gender_jp}: {gender_probs[i].item()} -> {percentage}%")
|
| 137 |
+
else:
|
| 138 |
+
# 性別モデルが存在しない場合はデフォルト値を設定
|
| 139 |
+
gender_percentages = {"男性": 50.0, "女性": 50.0}
|
| 140 |
+
print("DEBUG: 性別モデルが存在しないため、デフォルト値を設定しました")
|
| 141 |
+
|
| 142 |
+
# 結果を返す
|
| 143 |
+
results = {
|
| 144 |
+
"age_percentages": age_percentages,
|
| 145 |
+
"gender_percentages": gender_percentages
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
print(f"DEBUG: 最終結果: {results}")
|
| 149 |
+
return results
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=4.0.0
|
| 2 |
+
torch>=1.9.0
|
| 3 |
+
transformers>=4.21.0
|
| 4 |
+
pandas>=1.3.0
|
| 5 |
+
scikit-learn>=1.0.0
|
| 6 |
+
numpy>=1.21.0
|
| 7 |
+
fugashi>=1.2.0
|
| 8 |
+
ipadic>=1.0.0
|
| 9 |
+
sudachipy>=0.6.0
|
| 10 |
+
sudachidict-core>=20240101
|
| 11 |
+
neologdn>=0.0.0
|