| import torch |
| import torch.nn as nn |
| from torch.nn.utils.rnn import pad_sequence |
| from transformers import PreTrainedModel, AutoModelForMaskedLM, AutoConfig |
| try: |
| from .configuration_dlm import DiscreteDiffusionConfig |
| except ImportError: |
| from configuration_dlm import DiscreteDiffusionConfig |
|
|
| from collections import namedtuple |
| import math |
| import numpy as np |
| from typing import List, Optional, Tuple, Union |
|
|
# Immutable state threaded through iterative decoding. Fields:
#   output_tokens, output_scores  - current tokens and their scores
#   output_masks                  - positions still to be denoised
#   non_fixed_sym_masks           - positions the decoder is allowed to rewrite
#   attn                          - attention (unused by the visible code, set to None)
#   step, max_step                - current / total number of denoising steps
#   history                       - list of per-step token snapshots, or None
decoder_out_t = namedtuple(
    "decoder_out_t",
    "output_tokens output_scores output_masks non_fixed_sym_masks attn step max_step history",
)
|
|
def topk_masking(scores, cutoff_len, stochastic=False, temp=1.0):
    """Mark, in each row, the entries whose score is below the row's k-th smallest value.

    Args:
        scores: ``[b, n]`` tensor of scores.
        cutoff_len: ``[b, 1]`` integer tensor; per-row index ``k`` into the
            ascending-sorted scores used as the threshold.
        stochastic: when True, rank ``scores + temp * gumbel_noise`` instead of
            the raw scores.
        temp: scale of the Gumbel noise (only used when ``stochastic``).

    Returns:
        Boolean ``[b, n]`` mask, True where the (possibly perturbed) score is
        strictly smaller than the ``k``-th smallest value of its row. Without
        ties this selects exactly the ``k`` lowest entries.
    """
    ranked = scores
    if stochastic:
        uniform = torch.rand_like(scores)
        # Gumbel(0, 1) sample; the epsilons keep both logs finite.
        gumbel = -torch.log(-torch.log(uniform + 1e-8) + 1e-8)
        ranked = scores + temp * gumbel
    ascending, _ = torch.sort(ranked, dim=-1)
    threshold = torch.gather(ascending, -1, cutoff_len)
    return ranked < threshold
|
|
class DiscreteDiffusionModel(PreTrainedModel):
    """Absorbing-state discrete diffusion language model on top of a masked-LM backbone.

    The backbone is built from ``config.backbone_config`` with
    ``AutoModelForMaskedLM``. The methods below access ``self.model.roberta``
    and ``self.model.lm_head.decoder``, so a RoBERTa-style masked LM is
    assumed. A one-layer transformer encoder (``length_trm``) followed by an
    MLP (``length_predictor``) predicts target-length logits from the prefix.
    """

    config_class = DiscreteDiffusionConfig
    _keys_to_ignore_on_load_missing = [
        "fake_layer",
        "length_trm",
        "length_predictor",
        "model.lm_head.decoder.weight",
    ]

    def __init__(self, config: DiscreteDiffusionConfig):
        super().__init__(config)
        self.config = config
        self.args = config

        if not config.backbone_config:
            raise ValueError("backbone_config must be provided in config")
        backbone_config_obj = AutoConfig.for_model(**config.backbone_config)
        self.model = AutoModelForMaskedLM.from_config(backbone_config_obj)

        if config.tie_word_embeddings:
            # Share the output projection with the input embedding matrix.
            self.model.lm_head.decoder.weight = self.model.roberta.embeddings.word_embeddings.weight

        self.mask_id = config.mask_token_id
        self.bos_id = config.bos_token_id
        self.eos_id = config.eos_token_id
        self.pad_id = config.pad_token_id

        if config.lora:
            self.add_fake_layer()

        # Length head. Output index k is read as length k + 1 by
        # initialize_decode_samples.
        self.length_trm = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.config.hidden_size,
                nhead=self.config.num_attention_heads,
                dim_feedforward=self.config.intermediate_size,
                batch_first=True,
            ),
            num_layers=1,
        )
        self.length_predictor = nn.Sequential(
            nn.Linear(self.config.hidden_size, self.config.intermediate_size),
            nn.Tanh(),
            nn.Linear(self.config.intermediate_size, self.config.max_position_embeddings),
        )

    def add_fake_layer(self):
        """Register a zero-initialised ``fake_layer`` parameter of size ``hidden_size``.

        ``forward`` adds ``fake_layer * 0`` to the input embeddings while
        training, so it never changes any value. NOTE(review): presumably this
        exists so that the embedding output requires grad when the backbone is
        frozen (LoRA); confirm.
        """
        self.fake_layer = nn.Parameter(torch.zeros((self.config.hidden_size,)))

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Enable gradient checkpointing on the backbone.

        Args:
            gradient_checkpointing_kwargs: forwarded to the backbone when given
                (callers such as the HF ``Trainer`` pass this keyword). When
                None, the backbone method is called without arguments, as before.
        """
        if gradient_checkpointing_kwargs is None:
            self.model.gradient_checkpointing_enable()
        else:
            self.model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
            )

    def _tie_weights(self):
        """Tie the weights between the input embeddings and the output embeddings."""
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(
                self.model.lm_head.decoder,
                self.model.roberta.embeddings.word_embeddings,
            )

    def _init_weights(self, module):
        """Initialise ``module`` via the parent implementation, then re-tie embeddings."""
        super()._init_weights(module)
        self._tie_weights()

    @property
    def _tied_weights_keys(self):
        """Return the keys of tied weights (the LM-head decoder when tying is enabled)."""
        if self.config.tie_word_embeddings:
            return ["model.lm_head.decoder.weight"]
        return []

    def q_sample_coupled(self, x_0, t1, t2, maskable_mask):
        """Draw two coupled noisy copies of ``x_0`` under absorbing diffusion.

        ``t1``/``t2`` (one timestep per row) are reordered so that ``t1 >= t2``.
        ``x_t1`` replaces each ``maskable_mask`` position by ``mask_id`` with
        probability ``t1 / num_diffusion_timesteps``. ``x_t2`` keeps each of
        ``x_t1``'s masked positions masked with probability ``t2 / t1``.
        Rows where ``t1 == t2`` are instead resampled independently at rate
        ``t1 / num_diffusion_timesteps``.

        Returns:
            dict with both samples concatenated along dim 0: ``"x_t"`` (tokens),
            ``"t"`` (float timesteps) and ``"mask_mask"`` (masked positions).
        """
        assert self.config.diffusion_type == "absorbing", "we only support absorbing diffusion temporarily"
        t1_eq_t2_mask = (t1 == t2)
        t1, t2 = torch.maximum(t1, t2).float(), torch.minimum(t1, t2).float()

        u = torch.rand_like(x_0, dtype=torch.float)
        t1_mask = (u < (t1 / self.config.num_diffusion_timesteps)[:, None]) & maskable_mask
        x_t1 = x_0.masked_fill(t1_mask, self.mask_id)

        u = torch.rand_like(x_0, dtype=torch.float)
        # Unmask a (t1 - t2) / t1 fraction of x_t1's masks -> kept with prob t2 / t1.
        t2_mask = t1_mask & (u > ((t1 - t2) / t1)[:, None])
        u = torch.rand_like(x_0[t1_eq_t2_mask], dtype=torch.float)
        t2_mask[t1_eq_t2_mask] = (
            (u < (t1[t1_eq_t2_mask] / self.config.num_diffusion_timesteps)[:, None])
            & maskable_mask[t1_eq_t2_mask]
        )
        x_t2 = x_0.masked_fill(t2_mask, self.mask_id)

        return {
            "x_t": torch.cat([x_t1, x_t2], dim=0),
            "t": torch.cat([t1, t2]),
            "mask_mask": torch.cat([t1_mask, t2_mask], dim=0),
        }

    def initialize_decode_samples(self, tokens, partial_masks, prefix_masks, oracle_length=False, length_beam=1, mbr=1):
        """Build the initial decoder state for iterative denoising.

        Without ``oracle_length``:
            The prefix selected by ``prefix_masks`` is kept. The target length is
            taken from the top ``length_beam`` predictions of ``forward_length``.
            It is capped at ``min(100, 3 * src_length)`` and by the position
            budget. Each candidate row is ``prefix + mask_id * length + eos``,
            repeated ``mbr`` times.
        With ``oracle_length``:
            Every non-pad/bos/eos token outside ``prefix_masks`` is replaced by
            ``mask_id``, and rows are repeated ``mbr`` times.

        Returns:
            ``(partial_masks, decoder_out_t)``. ``partial_masks`` is True for
            positions that are excluded from ``non_fixed_sym_masks``.

        Raises:
            NotImplementedError: if ``tokens`` is None.
        """
        if tokens is None:
            raise NotImplementedError

        if not oracle_length:
            inputs_tokens = tokens.masked_fill(~prefix_masks, self.pad_id)
            src_length = inputs_tokens.ne(self.pad_id).sum(dim=-1)
            inputs_tokens = inputs_tokens[:, :src_length.max()]
            length_logits = self.forward_length(inputs_tokens)

            max_allowed_length = torch.min(
                torch.tensor([100]).to(src_length.device),
                (src_length * 3)[:, None],
            )
            # Shape [batch, length_beam]; index k of the length logits means length k + 1.
            length = torch.min(
                torch.min(length_logits.topk(length_beam, dim=-1).indices + 1, max_allowed_length),
                self.config.max_position_embeddings - 2 - src_length[:, None] - 1,
            )

            output_tokens = []
            new_partial_masks = []
            for i, token in enumerate(inputs_tokens):
                src_len = int(src_length[i])
                for b in range(length_beam):
                    # A negative budget means "no room": emit only eos, and keep
                    # the token row and mask row the same length.
                    n_masks = max(int(length[i][b]), 0)
                    seq = torch.cat([
                        token[:src_len],
                        torch.full((n_masks,), self.mask_id, dtype=token.dtype, device=token.device),
                        torch.tensor([self.eos_id], dtype=token.dtype, device=token.device),
                    ])
                    p_mask = torch.cat([
                        partial_masks[i][:src_len],
                        torch.zeros(n_masks + 1, dtype=partial_masks.dtype, device=partial_masks.device),
                    ])
                    for _ in range(mbr):
                        output_tokens.append(seq)
                        new_partial_masks.append(p_mask)

            output_tokens = pad_sequence(output_tokens, batch_first=True, padding_value=self.pad_id)
            # Padding is marked as fixed so it is never rewritten.
            partial_masks = pad_sequence(new_partial_masks, batch_first=True, padding_value=True)
            output_mask = output_tokens.eq(self.mask_id)
            non_fixed_sym_masks = (
                output_tokens.ne(self.pad_id)
                & output_tokens.ne(self.bos_id)
                & ~partial_masks
            )
        else:
            output_tokens = torch.stack([token for token in tokens for _ in range(mbr)])
            partial_masks = torch.stack([mask for mask in partial_masks for _ in range(mbr)])
            prefix_masks = torch.stack([mask for mask in prefix_masks for _ in range(mbr)])
            output_mask = (
                output_tokens.ne(self.pad_id)
                & output_tokens.ne(self.bos_id)
                & output_tokens.ne(self.eos_id)
                & ~prefix_masks
            )
            output_tokens = output_tokens.masked_fill(output_mask, self.mask_id)
            non_fixed_sym_masks = output_mask.clone()

        output_scores = torch.zeros_like(output_tokens, dtype=torch.float)

        return partial_masks, decoder_out_t(
            output_tokens=output_tokens,
            output_scores=output_scores,
            output_masks=output_mask,
            non_fixed_sym_masks=non_fixed_sym_masks,
            attn=None,
            step=0,
            max_step=math.inf,
            history=None,
        )

    def forward_length(self, input_ids):
        """Predict target-length logits for each row of ``input_ids``.

        The backbone encoder runs under ``torch.no_grad()``, so gradients only
        reach ``length_trm`` and ``length_predictor``. Features are mean-pooled
        over non-pad positions.

        Returns:
            Logits of width ``max_position_embeddings`` per row.
        """
        attention_mask = input_ids.ne(self.pad_id).int()
        with torch.no_grad():
            _feature = self.model.roberta(input_ids, attention_mask=attention_mask)[0]
        feature = self.length_trm(_feature, src_key_padding_mask=(1 - attention_mask).bool())
        # Clamp so an all-padding row pools to zeros instead of NaN (0 / 0).
        length = attention_mask.sum(dim=-1).clamp(min=1)
        pooled_feature = (
            feature.masked_fill((attention_mask == 0)[:, :, None], 0).float().sum(1) / length[:, None]
        )
        return self.length_predictor(pooled_feature.to(feature))

    def forward(self, prev_output_tokens, partial_mask, attention_mask=None, loss_mask=None, cache=None):
        """Encode ``prev_output_tokens`` with the backbone and return LM-head logits.

        Args:
            prev_output_tokens: token ids, ``[b, n]``.
            partial_mask: boolean ``[b, n]``. With ``attention_strategy ==
                "prefix_lm"``, rows where it is True attend only to other True
                positions; the remaining rows follow ``attention_mask``. This
                path builds a 3-D mask; assumes the backbone accepts one
                (TODO confirm).
            attention_mask: defaults to ``prev_output_tokens != pad_id``.
            loss_mask: optional boolean mask. When given, only the selected
                positions are projected, and the result is flattened over them.
            cache: unused; kept for interface compatibility.

        Returns:
            Output of ``self.model.lm_head``. Any NaN in the backbone output is
            replaced by 0 beforehand.
        """
        input_ids = prev_output_tokens
        if attention_mask is None:
            attention_mask = prev_output_tokens.ne(self.pad_id).int()

        embeddings = self.model.roberta.embeddings.word_embeddings(input_ids)

        if hasattr(self, "fake_layer") and self.training:
            self.fake_layer.requires_grad = True
            embeddings = embeddings + self.fake_layer * 0

        if self.config.attention_strategy == "prefix_lm":
            ext_partial_mask = partial_mask.float()
            ext_partial_mask = torch.bmm(ext_partial_mask[:, :, None], ext_partial_mask[:, None, :]).int()
            ext_mask = attention_mask[:, None, :].repeat(1, attention_mask.size(-1), 1)
            # Indexed assignment requires matching dtypes, and a caller-supplied
            # attention_mask may be int64 or bool rather than int32.
            ext_mask[partial_mask] = ext_partial_mask[partial_mask].to(ext_mask.dtype)
            outputs = self.model.roberta(inputs_embeds=embeddings, attention_mask=ext_mask)[0]
        else:
            outputs = self.model.roberta(inputs_embeds=embeddings, attention_mask=attention_mask)[0]

        nan_positions = torch.isnan(outputs)
        if nan_positions.any():
            # Out-of-place so the autograd graph's tensor is not mutated.
            outputs = outputs.masked_fill(nan_positions, 0)

        if loss_mask is not None:
            outputs = outputs[loss_mask]
        return self.model.lm_head(outputs)

    def _reparam_decoding(
        self,
        output_tokens,
        output_scores,
        cur_tokens,
        cur_scores,
        decoding_strategy,
        xt_neq_x0,
        non_special_sym_mask,
        t,
        max_step,
        noise,
    ):
        """Apply one top-k re-masking step, updating ``output_tokens``/``output_scores`` in place.

        ``decoding_strategy`` has four ``-``-separated fields; the first is
        ignored. The others are:
            ``condition``: ``"cond"`` or ``"uncond"``.
            ``topk_mode``: ``"deterministic"`` or ``"stochastic<scale>"``.
            ``schedule``: ``"linear"`` (``rate = 1 - t/max_step``) or
                ``"cosine"`` (``rate = cos(pi/2 * t/max_step)``).
        The ``rate`` fraction of ``non_special_sym_mask`` positions with the
        lowest ``cur_scores`` is selected via ``topk_masking``.

        Selected positions that are still noised (``xt_neq_x0``) are set to
        ``noise`` with score ``-inf``. So are selected already-denoised
        positions: always under ``uncond``, and under ``cond`` only when
        ``cur_tokens`` equals the current token with a lower score. Noised
        positions that are not selected take ``cur_tokens``/``cur_scores``.

        Returns:
            Boolean mask of positions that are still noised after this step.

        Raises:
            NotImplementedError: for an unknown schedule, top-k mode or
                condition, or if ``noise`` is neither a tensor nor a scalar.
        """
        _, condition, topk_mode, schedule = decoding_strategy.split("-")

        if schedule == "linear":
            rate = 1 - t / max_step
        elif schedule == "cosine":
            rate = np.cos(t / max_step * np.pi * 0.5)
        else:
            raise NotImplementedError

        cutoff_len = (
            non_special_sym_mask.sum(1, keepdim=True).type_as(output_scores) * rate
        ).long()
        # Special/fixed positions get a huge score so they are never among the lowest k.
        _scores_for_topk = cur_scores.masked_fill(~non_special_sym_mask, 1000.0)

        if topk_mode.startswith("stochastic"):
            noise_scale = float(topk_mode.replace("stochastic", ""))
            lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=True, temp=noise_scale * rate)
        elif topk_mode == "deterministic":
            lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=False)
        else:
            raise NotImplementedError

        if condition == "cond":
            not_v1_t = (cur_tokens == output_tokens) & (cur_scores < output_scores) & lowest_k_mask
        elif condition == "uncond":
            not_v1_t = lowest_k_mask
        else:
            raise NotImplementedError

        not_v2_t = lowest_k_mask

        masked_to_noise = (~xt_neq_x0 & not_v1_t) | (xt_neq_x0 & not_v2_t)
        if isinstance(noise, torch.Tensor):
            output_tokens.masked_scatter_(masked_to_noise, noise[masked_to_noise])
        elif isinstance(noise, (int, float)):
            output_tokens.masked_fill_(masked_to_noise, noise)
        else:
            raise NotImplementedError("noise should be either a tensor or a scalar")
        output_scores.masked_fill_(masked_to_noise, -math.inf)

        masked_to_x0 = xt_neq_x0 & ~not_v2_t
        output_tokens.masked_scatter_(masked_to_x0, cur_tokens[masked_to_x0])
        output_scores.masked_scatter_(masked_to_x0, cur_scores[masked_to_x0])

        new_xt_neq_x0 = (xt_neq_x0 | not_v1_t) & not_v2_t
        return new_xt_neq_x0

    def _sample_tokens(self, scores, temperature):
        """Choose one token per position from log-probabilities ``scores``.

        Uses argmax when ``config.argmax_decoding`` is set. Otherwise it samples
        from ``Categorical(logits=scores / temperature)``.

        Returns:
            ``(token_scores, tokens)``, where ``token_scores`` are the untempered
            ``scores`` of the chosen tokens.
        """
        if getattr(self.config, "argmax_decoding", False):
            token_scores, tokens = scores.max(-1)
            return token_scores, tokens
        import torch.distributions as dists
        tokens = dists.Categorical(logits=scores / temperature).sample()
        token_scores = torch.gather(scores, -1, tokens.unsqueeze(-1)).squeeze(-1)
        return token_scores, tokens

    def denoise_step(self, decoder_out, partial_masks, temperature=1.0, strategy="reparam-uncond-deterministic-cosine"):
        """Run one denoising iteration and return the updated ``decoder_out_t``.

        Strategies:
            ``"cmlm"``: each position still equal to ``mask_id`` is filled with
                probability ``1 / (max_step - step)`` with a freshly predicted token.
            ``"ar"``: fills only the position whose index equals the number of
                non-mask/non-eos/non-pad tokens. NOTE(review): this is the first
                masked slot when the filled tokens form a contiguous prefix.
            otherwise: ``"reparam-<cond>-<topk>-<schedule>"``, delegated to
                ``_reparam_decoding``.

        ``decoder_out.output_tokens`` is modified in place. When
        ``decoder_out.history`` is a list, a copy of the new tokens is appended
        to a new history list.
        """
        output_tokens = decoder_out.output_tokens
        output_scores = decoder_out.output_scores
        prev_step, cur_step = decoder_out.step, decoder_out.step + 1
        max_step = decoder_out.max_step

        logits = self.forward(output_tokens, partial_masks)
        # Never predict the mask token itself.
        logits[..., self.mask_id] = -math.inf
        scores = torch.log_softmax(logits, dim=-1)

        if strategy == "cmlm":
            output_masks = output_tokens.eq(self.mask_id)
            unmask_prob = 1 / (max_step - prev_step)
            changes = torch.rand(output_tokens.shape, device=output_tokens.device) < unmask_prob
            changes = torch.bitwise_and(changes, output_masks)
            output_scores, new_tokens = self._sample_tokens(scores, temperature)
            output_tokens[changes] = new_tokens[changes]
        elif strategy == "ar":
            output_masks = output_tokens.eq(self.mask_id)
            unmask_indices = (
                output_tokens.ne(self.mask_id)
                & output_tokens.ne(self.eos_id)
                & output_tokens.ne(self.pad_id)
            ).sum(dim=-1)
            indices = torch.arange(output_tokens.size(-1)).expand(output_tokens.shape).to(output_masks.device)
            output_scores, new_tokens = self._sample_tokens(scores, temperature)
            target = unmask_indices[:, None] == indices
            output_tokens[target] = new_tokens[target]
        else:
            cur_scores, cur_tokens = self._sample_tokens(scores, temperature)
            # masked_scatter_ in _reparam_decoding needs matching dtypes.
            cur_scores = cur_scores.to(output_scores)

            output_masks = self._reparam_decoding(
                output_tokens=output_tokens,
                output_scores=output_scores,
                cur_tokens=cur_tokens,
                cur_scores=cur_scores,
                decoding_strategy=strategy,
                xt_neq_x0=decoder_out.output_masks,
                non_special_sym_mask=decoder_out.non_fixed_sym_masks,
                t=cur_step,
                max_step=max_step,
                noise=self.mask_id,
            )

        history = (
            None if decoder_out.history is None
            else decoder_out.history + [output_tokens.clone()]
        )

        return decoder_out._replace(
            step=cur_step,
            output_tokens=output_tokens,
            output_scores=output_scores,
            output_masks=output_masks,
            history=history,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        attention_mask=None,
        max_iterations=10,
        strategy="reparam-uncond-deterministic-cosine",
        temperature=1.0,
        return_history=False,
        max_length=128,
        **kwargs
    ):
        """Generate ``max_length`` tokens after every source row of ``input_ids``.

        Each source row is its non-pad tokens, with a trailing ``eos`` dropped.
        It is followed by ``max_length`` mask tokens and an ``eos``, and the
        masks are denoised for ``max_iterations`` steps with ``denoise_step``.

        Args:
            input_ids: source token ids. Rows are cut to their non-pad count, so
                padding is assumed to be on the right.
            attention_mask: accepted for API compatibility and not used. The
                fixed prefix is derived from ``input_ids`` and ``pad_id``.
            max_iterations: number of denoising steps.
            strategy: decoding strategy passed to ``denoise_step``.
            temperature: sampling temperature (ignored with ``argmax_decoding``).
            return_history: record intermediate tokens in the decoder state.
                They are not part of the return value.
            max_length: number of mask tokens appended per row.
            **kwargs: ``mbr`` and ``length_beam`` (default 1). Their product is
                the number of candidates decoded per source row; the one with
                the highest mean token score is kept.

        Returns:
            Generated tokens, one row per source row, right-padded with
            ``pad_id``. The source, pad/bos/eos tokens, and everything from the
            first ``eos`` onward are removed.
        """
        src_tokens = input_ids
        device, dtype = src_tokens.device, src_tokens.dtype

        # Candidates per source row. Each row is replicated so that every group
        # of `repetition` consecutive hypotheses belongs to the same source.
        repetition = kwargs.get("mbr", 1) * kwargs.get("length_beam", 1)
        num_copies = max(repetition, 1)

        src_length = src_tokens.ne(self.pad_id).sum(dim=-1)
        output_tokens = []
        new_partial_masks = []
        for i in range(src_tokens.size(0)):
            src_len = int(src_length[i])
            src_seq = src_tokens[i, :src_len]

            # The generated span gets its own eos, so drop the source's (empty rows are skipped).
            if src_len > 0 and src_seq[-1] == self.eos_id:
                src_seq = src_seq[:-1]
                src_len -= 1

            seq = torch.cat([
                src_seq,
                torch.full((max_length,), self.mask_id, dtype=dtype, device=device),
                torch.tensor([self.eos_id], dtype=dtype, device=device),
            ])
            # Source positions are fixed; the masks and the final eos are not.
            mask = torch.cat([
                torch.ones(src_len, dtype=torch.bool, device=device),
                torch.zeros(max_length + 1, dtype=torch.bool, device=device),
            ])
            output_tokens.extend([seq] * num_copies)
            new_partial_masks.extend([mask] * num_copies)

        output_tokens = pad_sequence(output_tokens, batch_first=True, padding_value=self.pad_id)
        partial_masks = pad_sequence(new_partial_masks, batch_first=True, padding_value=True)

        output_mask = output_tokens.eq(self.mask_id)
        non_fixed_sym_masks = (
            output_tokens.ne(self.pad_id)
            & output_tokens.ne(self.bos_id)
            & ~partial_masks
        )
        output_scores = torch.zeros_like(output_tokens, dtype=torch.float)

        prev_decoder_out = decoder_out_t(
            output_tokens=output_tokens,
            output_scores=output_scores,
            output_masks=output_mask,
            non_fixed_sym_masks=non_fixed_sym_masks,
            attn=None,
            step=0,
            max_step=max_iterations,
            history=[] if return_history else None,
        )

        for _ in range(max_iterations):
            prev_decoder_out = self.denoise_step(
                prev_decoder_out, partial_masks, temperature=temperature, strategy=strategy
            )

        def finalized_hypos(tokens, scores, partial_mask, history=None):
            """Cut a row at its first eos and keep only generated, non-special tokens."""
            full_partial_mask = partial_mask
            eos_positions = (tokens == self.eos_id).nonzero(as_tuple=True)[0]
            if len(eos_positions) > 0:
                first_eos = eos_positions[0].item()
                tokens = tokens[:first_eos]
                if scores is not None:
                    scores = scores[:first_eos]
                partial_mask = partial_mask[:first_eos]

            keep = (
                tokens.ne(self.pad_id)
                & tokens.ne(self.bos_id)
                & tokens.ne(self.eos_id)
                & ~partial_mask
            )
            tokens = tokens[keep]
            if scores is None:
                score = None
            else:
                scores = scores[keep]
                score = scores.mean().item() if len(scores) > 0 else 0.0
            ret_dict = {
                "tokens": tokens,
                "positional_scores": scores,
                "score": score,
                "alignment": None,
            }
            if history is not None:
                # Use the untruncated mask: each history step is cut at its own eos.
                ret_dict["history"] = [
                    finalized_hypos(history_tokens, None, full_partial_mask, history=None)
                    for history_tokens in history
                ]
            return ret_dict

        def score_select(hyps):
            """Return the hypothesis with the highest mean token score."""
            index = np.argmax([hyp["score"] for hyp in hyps])
            return hyps[index]

        output_tokens, output_scores = prev_decoder_out.output_tokens, prev_decoder_out.output_scores
        num_rows = output_tokens.size(0)

        if return_history and prev_decoder_out.history is not None:
            full_history = prev_decoder_out.history
            histories = [[step_tokens[i] for step_tokens in full_history] for i in range(num_rows)]
        else:
            histories = [None] * num_rows

        hyps = [
            finalized_hypos(tokens, scores, partial_mask, history)
            for tokens, scores, partial_mask, history in zip(output_tokens, output_scores, partial_masks, histories)
        ]

        if repetition > 1:
            hyps = [score_select(hyps[i:i + repetition]) for i in range(0, len(hyps), repetition)]

        return pad_sequence([h["tokens"] for h in hyps], batch_first=True, padding_value=self.pad_id)
|
|