| from dataclasses import dataclass, field |
| from typing import Optional |
|
|
| import torch |
| import torch.distributions as dists |
| from torch.nn.utils.rnn import pad_sequence |
|
|
| import numpy as np |
|
|
| import math |
|
|
|
|
| import sacrebleu |
|
|
| from rouge import Rouge |
|
|
@dataclass
class DiscreteDiffusionGeneratorArguments:
    """Decoding options read by ``DiscreteDiffusionGenerator``."""

    # Number of denoising steps performed per generation call.
    max_iterations: int = 10
    # Forwarded to ``initialize_decode_samples``; ``mbr * length_beam`` candidates
    # per input are later reduced to one by mean positional score.
    mbr: int = 1
    # Forwarded to ``initialize_decode_samples``; see ``mbr``.
    length_beam: int = 1
    # Forwarded to ``initialize_decode_samples`` as ``oracle_length``.
    oracle_length: bool = False
    # "cmlm", "ar", or a reparameterized spec
    # "<prefix>-<cond|uncond>-<deterministic|stochastic<scale>>-<linear|cosine>".
    strategy: str = "reparam-uncond-deterministic-cosine"
    # True: take the argmax token; False: sample from the tempered distribution.
    argmax_decoding: bool = True
    # Passed to ``dictionary.string`` when decoding with a dictionary.
    bpe: str = "sentencepiece"
    # Tokenizer name passed to ``sacrebleu.corpus_bleu``.
    bleu_tokenize: str = "13a"
    # Keep a snapshot of the tokens after every denoising step.
    return_history: bool = False
    # Divides the log-probabilities before sampling (unused when argmax_decoding).
    temperature: float = 0.8
|
|
|
|
|
|
def topk_masking(scores, cutoff_len, stochastic=False, temp=1.0):
    """Mark the ``cutoff_len`` lowest-scoring positions of each row.

    Args:
        scores: float tensor of shape [b, n].
        cutoff_len: integer tensor of shape [b, 1]; how many positions to
            select in each row.
        stochastic: if True, add Gumbel noise scaled by ``temp`` to the scores
            before ranking, so the selection is sampled rather than strictly
            the lowest scores.
        temp: scale of the Gumbel noise (only used when ``stochastic``).

    Returns:
        Bool tensor of shape [b, n]. An entry is True where the (possibly
        perturbed) score is strictly below the ``cutoff_len``-th smallest value
        of its row. Because the comparison is strict, ties at the cutoff can
        select fewer than ``cutoff_len`` positions. Rows with ``cutoff_len == 0``
        select nothing. Rows with ``cutoff_len >= n`` select every position.
    """
    if stochastic:
        # Gumbel(0, 1) noise; the 1e-8 terms guard against log(0).
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(scores) + 1e-8) + 1e-8)
        _scores = scores + temp * gumbel_noise
    else:
        _scores = scores
    num_positions = _scores.size(-1)
    sorted_scores = _scores.sort(-1)[0]
    # Clamp so gather stays in range; rows asking for >= n positions are
    # handled explicitly below instead of raising an index error.
    cutoff = sorted_scores.gather(dim=-1, index=cutoff_len.clamp(max=num_positions - 1))
    masking = (_scores < cutoff) | (cutoff_len >= num_positions)
    return masking
|
|
|
|
class MergeBLEU(object):
    """Combine accumulated BLEU sufficient statistics into one corpus BLEU score."""

    def __call__(self, evalpreds):
        """Compute corpus BLEU from stacked per-batch statistics.

        Args:
            evalpreds: indexable pair. Each element is reshaped to rows of 5
                counts, which are summed over rows.
                ``evalpreds[0]``: columns 0-3 are passed as the matched n-gram
                counts (``correct``) and column 4 as ``sys_len``.
                ``evalpreds[1]``: columns 0-3 are passed as the total n-gram
                counts (``total``) and column 4 as ``ref_len``.

        Returns:
            ``{"bleu": score}``, computed with exponential smoothing.
        """
        import inspect

        hyp_counts = evalpreds[0].reshape(-1, 5).astype('long').sum(0).tolist()
        ref_counts = evalpreds[1].reshape(-1, 5).astype('long').sum(0).tolist()

        # sacrebleu >= 2 exposes compute_bleu on the BLEU class; older versions
        # export it at module level.
        try:
            from sacrebleu.metrics import BLEU
            bleu_fn = BLEU.compute_bleu
        except ImportError:
            bleu_fn = sacrebleu.compute_bleu

        # The smoothing keyword was renamed between sacrebleu versions.
        arg_names = inspect.getfullargspec(bleu_fn)[0]
        smoothing = (
            {"smooth_method": "exp"} if "smooth_method" in arg_names else {"smooth": "exp"}
        )

        result = bleu_fn(
            correct=hyp_counts[:4],
            total=ref_counts[:4],
            sys_len=hyp_counts[-1],
            ref_len=ref_counts[-1],
            **smoothing
        )
        return {"bleu": result.score}
|
|
class MergeRouge(object):
    """Combine per-batch average ROUGE scores into a single size-weighted average."""

    def __call__(self, evalpreds):
        """Return the batch-size-weighted mean of per-batch ROUGE averages.

        Args:
            evalpreds: indexable pair ``(avg_rouge, batch_size)`` of
                broadcast-compatible array-likes supporting elementwise ``*``
                and ``.sum()`` (presumably NumPy arrays with one entry per
                evaluated batch -- confirm against the caller).

        Returns:
            ``{"rouge": sum(avg_rouge * batch_size) / sum(batch_size)}``.
        """
        avg_rouge, batch_size = evalpreds[0], evalpreds[1]
        # Weight each batch mean by its size so the result is a per-example mean.
        rouge = (avg_rouge * batch_size).sum() / batch_size.sum()
        return {"rouge": rouge}
| |
|
|
class DiscreteDiffusionGenerator:
    """Iteratively decode sequences from a discrete-diffusion (mask-predict) model.

    Generation starts from the samples returned by
    ``model.initialize_decode_samples`` and applies :meth:`denoise_step`
    ``args.max_iterations`` times. ``args.strategy`` selects the update rule:

    * ``"cmlm"``: each still-masked position is filled independently with
      probability ``1 / (max_step - step)``.
    * ``"ar"``: per row, the position whose index equals the number of
      non-mask/eos/pad tokens is filled.
    * anything else: reparameterized decoding (see :meth:`_reparam_decoding`).
      The strategy string is parsed as
      ``"<prefix>-<cond|uncond>-<deterministic|stochastic<scale>>-<linear|cosine>"``.

    The class also provides helpers to detokenize outputs and score them with
    BLEU and ROUGE-L.
    """

    def __init__(self, args, dictionary=None, tokenizer=None) -> None:
        """
        Args:
            args: decoding options. Attributes read include ``return_history``,
                ``strategy``, ``argmax_decoding``, ``temperature``,
                ``max_iterations``, ``mbr``, ``length_beam``,
                ``oracle_length``, ``bpe`` and ``bleu_tokenize``.
            dictionary: vocabulary exposing ``pad()``, ``bos()``, ``eos()``,
                ``mask_index``, ``string()`` and ``pad_word``.
            tokenizer: tokenizer exposing ``pad_token_id``, ``bos_token_id``,
                ``eos_token_id``, ``mask_token_id`` and ``batch_decode``.

        Exactly one of ``dictionary`` and ``tokenizer`` must be given.
        """
        self.args = args
        self.dictionary = dictionary
        self.tokenizer = tokenizer
        self.write_prediction = None

        assert (dictionary is None) != (tokenizer is None), \
            "exactly one of `dictionary` and `tokenizer` must be provided"

        self.retain_history = args.return_history

        if dictionary is not None:
            self.pad_id = dictionary.pad()
            self.bos_id = dictionary.bos()
            self.eos_id = dictionary.eos()
            self.mask_id = dictionary.mask_index
        else:
            self.pad_id = tokenizer.pad_token_id
            self.bos_id = tokenizer.bos_token_id
            self.eos_id = tokenizer.eos_token_id
            self.mask_id = tokenizer.mask_token_id

        self.rouge = Rouge(["rouge-l"])

    def set_write_to(self, path):
        """Store ``path`` in ``self.write_prediction`` (not read by the methods shown here)."""
        self.write_prediction = path

    def _reparam_decoding(
        self,
        output_tokens,
        output_scores,
        cur_tokens,
        cur_scores,
        decoding_strategy,
        xt_neq_x0,
        non_special_sym_mask,
        t,
        max_step,
        noise
    ):
        """Apply one step of reparameterized decoding, updating tokens in place.

        Args:
            output_tokens: current token ids; modified in place.
            output_scores: current per-position scores; modified in place.
            cur_tokens: the model's proposed token for every position.
            cur_scores: log-probability of each proposed token.
            decoding_strategy:
                ``"<prefix>-<cond|uncond>-<deterministic|stochastic<scale>>-<linear|cosine>"``.
            xt_neq_x0: bool mask, True at positions that still hold noise
                (not yet decoded).
            non_special_sym_mask: bool mask, True at positions that may be
                changed. Other positions are never re-noised.
            t: current step index; max_step: total number of steps.
            noise: replacement for re-noised positions. Either a scalar token id
                or a tensor indexed by the same mask.

        Returns:
            The updated ``xt_neq_x0`` mask.

        Raises:
            NotImplementedError: for an unknown schedule, top-k mode, condition
                or ``noise`` type.
        """
        _, condition, topk_mode, schedule = decoding_strategy.split("-")

        # Fraction of changeable positions that are (re-)noised after step t:
        # it goes from 1 at t=0 to 0 at t=max_step.
        if schedule == "linear":
            rate = 1 - t / max_step
        elif schedule == "cosine":
            rate = np.cos(t / max_step * np.pi * 0.5)
        else:
            raise NotImplementedError

        cutoff_len = (
            non_special_sym_mask.sum(1, keepdim=True).type_as(output_scores) * rate
        ).long()
        # Give fixed positions a very high score so they are never among the lowest-k.
        _scores_for_topk = cur_scores.masked_fill(~non_special_sym_mask, 1000.0)

        if topk_mode.startswith("stochastic"):
            noise_scale = float(topk_mode.replace("stochastic", ""))
            lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=True, temp=noise_scale * rate)
        elif topk_mode == "deterministic":
            lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=False)
        else:
            raise NotImplementedError

        # not_v1_t: which already-decoded positions get re-noised.
        # "cond" re-noises only those whose proposal equals the current token
        # with a lower score and that fall in the lowest-k.
        if condition == "cond":
            not_v1_t = (cur_tokens == output_tokens) & (cur_scores < output_scores) & lowest_k_mask
        elif condition == "uncond":
            not_v1_t = lowest_k_mask
        else:
            raise NotImplementedError

        # not_v2_t: which still-noisy positions stay noisy.
        not_v2_t = lowest_k_mask

        masked_to_noise = (~xt_neq_x0 & not_v1_t) | (xt_neq_x0 & not_v2_t)
        if isinstance(noise, torch.Tensor):
            output_tokens.masked_scatter_(masked_to_noise, noise[masked_to_noise])
        elif isinstance(noise, (int, float)):
            output_tokens.masked_fill_(masked_to_noise, noise)
        else:
            raise NotImplementedError("noise should be either a tensor or a scalar")
        output_scores.masked_fill_(masked_to_noise, -math.inf)

        # Noisy positions outside the lowest-k take the model's proposal.
        masked_to_x0 = xt_neq_x0 & ~not_v2_t
        output_tokens.masked_scatter_(masked_to_x0, cur_tokens[masked_to_x0])
        output_scores.masked_scatter_(masked_to_x0, cur_scores[masked_to_x0])

        new_xt_neq_x0 = (xt_neq_x0 | not_v1_t) & not_v2_t
        return new_xt_neq_x0

    def _select_tokens(self, scores):
        """Choose one token per position from log-probabilities ``scores``.

        Returns ``(token_scores, tokens)``. With ``args.argmax_decoding`` the
        argmax is taken. Otherwise a token is sampled from
        ``Categorical(logits=scores / args.temperature)``. ``token_scores`` is
        the untempered log-probability of each chosen token.
        """
        if self.args.argmax_decoding:
            return scores.max(-1)
        tokens = dists.Categorical(logits=scores / self.args.temperature).sample()
        token_scores = torch.gather(scores, -1, tokens.unsqueeze(-1)).squeeze(-1)
        return token_scores, tokens

    def denoise_step(self, model, decoder_out, partial_masks):
        """Run one denoising iteration.

        Args:
            model: callable invoked as ``model(output_tokens, partial_masks)``.
                Its output is treated as vocabulary logits over the last dim.
            decoder_out: state object (presumably a NamedTuple, since
                ``_replace`` is used) with fields ``output_tokens``,
                ``output_scores``, ``output_masks``, ``non_fixed_sym_masks``,
                ``step``, ``max_step`` and ``history``.
            partial_masks: forwarded unchanged to ``model``.

        Returns:
            ``decoder_out`` with ``step`` incremented and updated tokens,
            scores, masks and (if ``retain_history``) history.
            ``output_tokens`` is also modified in place.
        """
        output_tokens = decoder_out.output_tokens
        output_scores = decoder_out.output_scores
        prev_step, cur_step = decoder_out.step, decoder_out.step + 1
        max_step = decoder_out.max_step

        logits = model(output_tokens, partial_masks)
        # Never propose the mask token itself.
        logits[..., self.mask_id] = -math.inf
        scores = torch.log_softmax(logits, dim=-1)

        if self.args.strategy == "cmlm":
            output_masks = output_tokens.eq(self.mask_id)
            # Fill each masked slot with probability 1 / (remaining steps), so
            # the last step (prev_step == max_step - 1) fills every slot.
            unmask_prob = 1 / (max_step - prev_step)
            changes = torch.rand(output_tokens.shape, device=output_tokens.device) < unmask_prob
            changes = changes & output_masks
            output_scores, new_tokens = self._select_tokens(scores)
            output_tokens[changes] = new_tokens[changes]
        elif self.args.strategy == "ar":
            output_masks = output_tokens.eq(self.mask_id)
            # The slot to fill in each row sits at the index equal to the
            # count of tokens that are neither mask, eos nor pad.
            next_index = (
                output_tokens.ne(self.mask_id)
                & output_tokens.ne(self.eos_id)
                & output_tokens.ne(self.pad_id)
            ).sum(dim=-1)
            positions = torch.arange(
                output_tokens.size(-1), device=output_tokens.device
            ).expand(output_tokens.shape)
            fill = next_index[:, None] == positions
            output_scores, new_tokens = self._select_tokens(scores)
            output_tokens[fill] = new_tokens[fill]
        else:
            cur_scores, cur_tokens = self._select_tokens(scores)
            cur_scores = cur_scores.to(output_scores)

            output_masks = self._reparam_decoding(
                output_tokens=output_tokens,
                output_scores=output_scores,
                cur_tokens=cur_tokens,
                cur_scores=cur_scores,
                decoding_strategy=self.args.strategy,
                xt_neq_x0=decoder_out.output_masks,
                non_special_sym_mask=decoder_out.non_fixed_sym_masks,
                t=cur_step,
                max_step=max_step,
                noise=self.mask_id
            )

        if self.retain_history:
            # Snapshot a copy, because output_tokens keeps being updated in place.
            history = ([] if decoder_out.history is None else decoder_out.history) + [output_tokens.clone()]
        else:
            history = None

        return decoder_out._replace(
            step=cur_step,
            output_tokens=output_tokens,
            output_scores=output_scores,
            output_masks=output_masks,
            history=history,
        )

    def decode(self, seqs_tensors, preserve_special=False):
        """Detokenize a batch of token-id sequences into lower-cased strings.

        Negative ids are first mapped to ``self.pad_id``. The replacement is
        done on a copy, so the caller's tensor or array is left unchanged.

        Args:
            seqs_tensors: 2-D tensor (or array-like) of token ids.
            preserve_special: keep special tokens / pad words in the output.

        Returns:
            List of decoded, lower-cased strings, one per row.
        """
        if isinstance(seqs_tensors, torch.Tensor):
            seqs_tensors = seqs_tensors.clone()
        else:
            seqs_tensors = np.array(seqs_tensors)
        seqs_tensors[seqs_tensors < 0] = self.pad_id

        if self.dictionary is not None:
            seqs = [
                self.dictionary.string(seq, self.args.bpe).strip()
                for seq in seqs_tensors
            ]
            if not preserve_special:
                seqs = [seq.replace(self.dictionary.pad_word, '') for seq in seqs]
        else:
            seqs = self.tokenizer.batch_decode(seqs_tensors, skip_special_tokens=(not preserve_special))
        return [seq.lower() for seq in seqs]

    def compute_bleu(self, hyps, refs):
        """Corpus BLEU of ``hyps`` against single references ``refs``.

        Tensor inputs are detokenized with :meth:`decode` first.
        """
        if isinstance(hyps, torch.Tensor):
            hyps = self.decode(hyps)
        if isinstance(refs, torch.Tensor):
            refs = self.decode(refs)
        return sacrebleu.corpus_bleu(hyps, [refs], tokenize=self.args.bleu_tokenize)

    def compute_rouge(self, hyps, refs):
        """ROUGE-L F value of ``hyps`` against single references ``refs``.

        Tensor inputs are detokenized with :meth:`decode` first. Presumably the
        ``rouge`` package returns an aggregate ``{'rouge-l': {'f': ...}}`` here;
        verify this against the installed version.
        """
        if isinstance(hyps, torch.Tensor):
            hyps = self.decode(hyps)
        if isinstance(refs, torch.Tensor):
            refs = self.decode(refs)
        return self.rouge.get_scores(hyps, [[ref] for ref in refs])['rouge-l']['f']

    def _initialize(self, model, inputs):
        """Build the initial decoder state for ``inputs``.

        Returns ``(partial_masks, decoder_out)`` as produced by
        ``initialize_decode_samples`` on the unwrapped model, with
        ``decoder_out.step`` set to 0 and ``max_step`` set to
        ``args.max_iterations``.
        """
        net_input = inputs["net_input"]
        src_tokens = net_input["src_tokens"]
        partial_masks = net_input["partial_masks"]
        # Fall back to the partial masks when no separate prefix mask is given.
        if "prefix_masks" in net_input:
            prefix_masks = net_input["prefix_masks"]
        else:
            prefix_masks = partial_masks

        # Wrappers such as (Distributed)DataParallel expose the network as `.module`.
        raw_model = model.module if hasattr(model, "module") else model
        partial_masks, decoder_out = raw_model.initialize_decode_samples(
            src_tokens, partial_masks, prefix_masks,
            oracle_length=self.args.oracle_length,
            length_beam=self.args.length_beam,
            mbr=self.args.mbr,
        )
        decoder_out = decoder_out._replace(step=0, max_step=self.args.max_iterations)
        return partial_masks, decoder_out

    def stepwise_generate(self, model, inputs):
        """Yield the decoder state after each of the ``args.max_iterations`` steps.

        Unlike :meth:`generate`, this is not wrapped in ``torch.no_grad()``
        and yields raw decoder states rather than finalized hypotheses.
        """
        partial_masks, prev_decoder_out = self._initialize(model, inputs)
        for _ in range(self.args.max_iterations):
            prev_decoder_out = self.denoise_step(model, prev_decoder_out, partial_masks)
            yield prev_decoder_out

    @torch.no_grad()
    def generate(self, model, inputs):
        """Run the full denoising loop and return finalized token sequences.

        Returns:
            finalized: tensor of shape [num_selected, max_len]. Each row holds a
                selected hypothesis with pad/bos/eos tokens and partial-mask
                positions removed, right-padded with ``self.pad_id``.
            history: if ``args.return_history`` is set, one list of per-step
                token tensors (filtered the same way) per selected hypothesis;
                otherwise None.
        """
        partial_masks, prev_decoder_out = self._initialize(model, inputs)
        for _ in range(self.args.max_iterations):
            prev_decoder_out = self.denoise_step(model, prev_decoder_out, partial_masks)

        def finalized_hypos(tokens, scores, partial_mask, history=None):
            # Drop special tokens and the conditioning (partial-mask) positions.
            cutoff = (
                tokens.ne(self.pad_id) &
                tokens.ne(self.bos_id) &
                tokens.ne(self.eos_id) &
                (~partial_mask)
            )
            tokens = tokens[cutoff]
            if scores is None:
                score = None
            else:
                scores = scores[cutoff]
                score = scores.mean().item()
            ret_dict = {
                "tokens": tokens,
                "positional_scores": scores,
                "score": score,
                "alignment": None
            }
            if history is not None:
                ret_dict["history"] = [
                    finalized_hypos(history_tokens, None, partial_mask, history=None)
                    for history_tokens in history
                ]
            return ret_dict

        def score_select(hyps):
            # Keep the candidate with the highest mean positional score.
            index = np.argmax([hyp["score"] for hyp in hyps])
            return hyps[index]

        output_tokens, output_scores = prev_decoder_out.output_tokens, prev_decoder_out.output_scores
        batch = output_tokens.size(0)
        if self.retain_history:
            full_history = prev_decoder_out.history
            # Re-group the step-major history into one per-sample list.
            histories = [
                [full_history[j][i] for j in range(self.args.max_iterations)]
                for i in range(batch)
            ]
        else:
            histories = [None] * batch
        hyps = [
            finalized_hypos(tokens, scores, partial_mask, history)
            for tokens, scores, partial_mask, history
            in zip(output_tokens, output_scores, partial_masks, histories)
        ]

        # Each input yields mbr * length_beam candidates. Assumes they are
        # contiguous in the batch (TODO confirm against initialize_decode_samples).
        # Selection is by mean score; no MBR re-ranking is performed.
        repetition = self.args.mbr * self.args.length_beam
        if repetition > 1:
            hyps = [score_select(hyps[i:i + repetition]) for i in range(0, len(hyps), repetition)]

        finalized = pad_sequence([h["tokens"] for h in hyps], batch_first=True, padding_value=self.pad_id)
        history = [[item["tokens"] for item in h["history"]] for h in hyps] if self.retain_history else None
        return finalized, history