neural-news-lab commited on
Commit
794de73
·
verified ·
1 Parent(s): e1c318f

Update new_sum.py

Browse files
Files changed (1) hide show
  1. new_sum.py +14 -13
new_sum.py CHANGED
@@ -1,9 +1,6 @@
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
2
  import torch
3
 
4
- # ======================
5
- # MODEL SETUP
6
- # ======================
7
  MODEL_NAME = "cointegrated/rut5-base-multitask"
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
 
@@ -11,37 +8,41 @@ config = AutoConfig.from_pretrained(MODEL_NAME)
11
  config.tie_word_embeddings = False
12
 
13
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
14
-
15
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, config=config).to(device)
16
  model.eval()
17
 
18
- # ======================
19
- # CORE FUNCTION
20
- # ======================
21
  def generate_summary(text: str) -> str:
22
  if not text:
23
  return ""
24
 
25
- prompt = "summarize | " + text
 
26
 
27
  inputs = tokenizer(
28
  prompt,
29
  return_tensors="pt",
30
  truncation=True,
31
- padding="max_length",
32
  max_length=512
33
  ).to(device)
34
 
35
  with torch.no_grad():
36
  outputs = model.generate(
37
  **inputs,
38
- max_length=200,
39
  min_length=30,
40
- num_beams=3,
41
  do_sample=False,
 
42
  no_repeat_ngram_size=3,
43
- repetition_penalty=1.2,
44
  early_stopping=True
45
  )
46
 
47
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
 
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
2
  import torch
3
 
 
 
 
4
  MODEL_NAME = "cointegrated/rut5-base-multitask"
5
  device = "cuda" if torch.cuda.is_available() else "cpu"
6
 
 
8
  config.tie_word_embeddings = False
9
 
10
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
11
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, config=config).to(device)
12
  model.eval()
13
 
14
+
 
 
15
  def generate_summary(text: str) -> str:
16
  if not text:
17
  return ""
18
 
19
+ # чуть лучше для T5
20
+ prompt = "summarize: " + text
21
 
22
  inputs = tokenizer(
23
  prompt,
24
  return_tensors="pt",
25
  truncation=True,
26
+ padding="longest",
27
  max_length=512
28
  ).to(device)
29
 
30
  with torch.no_grad():
31
  outputs = model.generate(
32
  **inputs,
33
+ max_length=150,
34
  min_length=30,
35
+ num_beams=4,
36
  do_sample=False,
37
+ repetition_penalty=2.0,
38
  no_repeat_ngram_size=3,
 
39
  early_stopping=True
40
  )
41
 
42
+ summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
43
+
44
+ # 🔥 ВАЖНО: защита от мусорных токенов
45
+ if "<0x" in summary or len(summary.strip()) < 10:
46
+ return "Model output invalid or unstable. Try different input."
47
+
48
+ return summary