| import os |
| import torch |
| from transformers import PreTrainedModel, GenerationConfig, BertLMHeadModel |
| from transformers.modeling_outputs import Seq2SeqLMOutput |
| from torch import nn |
| from torch.nn import CrossEntropyLoss |
| from typing import Optional, Tuple, Union |
| from torch.utils.data import Dataset |
| from PIL import Image |
|
|
class MyModel(PreTrainedModel):
    """Image-to-text model: a Nougat image encoder feeding a separately-trained text decoder.

    The encoder is taken from ``nougat_model.encoder`` and the decoder from
    ``trans_model.decoder``. A linear layer projects the encoder's last hidden
    state into the decoder's hidden size so it can be fed to the decoder as
    ``encoder_hidden_states`` (cross-attention memory).
    """

    def __init__(self, config, trans_model, nougat_model):
        """Assemble the model from two existing models.

        Args:
            config: Configuration passed to ``PreTrainedModel``.
            trans_model: Model whose ``.decoder`` is reused as the text decoder.
            nougat_model: Model whose ``.encoder`` is reused as the image encoder.
        """
        super().__init__(config)
        self.encoder = nougat_model.encoder
        self.decoder = trans_model.decoder
        # Bridge the (possibly different) hidden sizes of encoder and decoder.
        self.project = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)

    def _encode(self, pixel_values, encoder_outputs, output_attentions, output_hidden_states, return_dict):
        """Run the encoder (unless outputs were supplied) and project its last hidden state.

        Returns:
            ``(encoder_outputs, encoder_hidden_states, projected_encoder_hidden_states)``.
        """
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        # Index 0 is the last hidden state for both plain tuples and ModelOutput
        # objects, so this works whether or not return_dict was set.
        encoder_hidden_states = encoder_outputs[0]
        return encoder_outputs, encoder_hidden_states, self.project(encoder_hidden_states)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict=True,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Encode ``pixel_values`` and run the decoder with teacher forcing.

        Args:
            pixel_values: Images for the encoder. Ignored when ``encoder_outputs`` is given.
            decoder_input_ids: Decoder input token ids.
            decoder_attention_mask: Attention mask for ``decoder_input_ids``.
            encoder_outputs: Precomputed encoder outputs (tuple or ModelOutput). When
                given, the encoder is not run again.
            past_key_values: Decoder cache, forwarded unchanged.
            decoder_inputs_embeds: Decoder input embeddings. Forwarded as
                ``inputs_embeds`` only when provided.
            labels: Target ids compared position-by-position with the decoder logits.
                No shift is applied here. Positions equal to ``-100`` (the
                ``CrossEntropyLoss`` default ``ignore_index``) are ignored.
            use_cache, output_attentions, output_hidden_states: Forwarded to the submodules.
            return_dict: Return a ``Seq2SeqLMOutput`` when true, otherwise a tuple of
                ``(loss?,) + decoder outputs + encoder outputs``.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            ``Seq2SeqLMOutput`` or a tuple, depending on ``return_dict``.
        """
        encoder_outputs, encoder_hidden_states, encoder_hidden_states_proj = self._encode(
            pixel_values, encoder_outputs, output_attentions, output_hidden_states, return_dict
        )

        # Only pass inputs_embeds when the caller supplied it, so decoders whose
        # forward() lacks that argument keep working.
        extra_decoder_kwargs = {}
        if decoder_inputs_embeds is not None:
            extra_decoder_kwargs["inputs_embeds"] = decoder_inputs_embeds

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states_proj,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **extra_decoder_kwargs,
        )

        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss_fct = CrossEntropyLoss()
            # Flatten by the actual logits width (robust to resized heads) and keep
            # labels on the same device as the logits.
            loss = loss_fct(
                logits.reshape(-1, logits.size(-1)),
                labels.to(logits.device).reshape(-1).long(),
            )

        if not return_dict:
            # A caller-supplied encoder_outputs may be a ModelOutput; convert it so
            # tuple concatenation works.
            if hasattr(encoder_outputs, "to_tuple"):
                encoder_tuple = encoder_outputs.to_tuple()
            else:
                encoder_tuple = tuple(encoder_outputs)
            output = tuple(decoder_outputs) + encoder_tuple
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_hidden_states,
            encoder_hidden_states=getattr(encoder_outputs, "hidden_states", None),
            encoder_attentions=getattr(encoder_outputs, "attentions", None),
        )

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict=True,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ):
        """Generate token ids conditioned on ``pixel_values`` (or precomputed ``encoder_outputs``).

        Runs without gradient tracking. The projected encoder states are passed to
        ``self.decoder.generate`` as ``encoder_hidden_states``, together with
        ``generation_config`` and any extra ``kwargs`` (e.g. ``max_new_tokens``,
        ``num_beams``).

        ``decoder_input_ids``, ``decoder_attention_mask``, ``past_key_values``,
        ``decoder_inputs_embeds``, ``labels`` and ``use_cache`` are accepted for
        signature compatibility with ``forward`` and are ignored.

        NOTE(review): whether ``encoder_hidden_states`` actually reaches the
        decoder's forward during generation depends on the decoder's
        ``prepare_inputs_for_generation``. Confirm for the decoder class in use.

        Returns:
            Whatever ``self.decoder.generate`` returns (typically token ids).
        """
        _, _, encoder_hidden_states_proj = self._encode(
            pixel_values, encoder_outputs, output_attentions, output_hidden_states, return_dict
        )
        return self.decoder.generate(
            encoder_hidden_states=encoder_hidden_states_proj,
            generation_config=generation_config,
            **kwargs,
        )
|
|
class MyDataset(Dataset):
    """Dataset pairing page images (``<name>.png``) with markup text (``<name>.mmd``).

    Each item is a dict with:
      - ``pixel_values``: the processor's ``pixel_values`` with the leading batch
        dimension squeezed away;
      - ``decoder_input_ids``: token ids padded with ``tokenizer.pad_token_id`` to
        ``max_length``;
      - ``labels``: the same ids shifted left by one, padded with ``-100`` to
        ``max_length``.
    """

    def __init__(self, processor, tokenizer, name_list, max_length, image_dir, text_dir, *, drop_token_ids=(6,)):
        """Store the preprocessing components and file locations.

        Args:
            processor: Callable mapping a PIL image to an object with ``.pixel_values``.
            tokenizer: Callable tokenizer exposing ``bos_token_id`` and ``pad_token_id``.
            name_list: File stems; ``<stem>.png`` and ``<stem>.mmd`` are loaded per item.
            max_length: Length that token sequences are truncated and padded to.
            image_dir: Directory containing the ``.png`` files.
            text_dir: Directory containing the ``.mmd`` files.
            drop_token_ids: Token ids removed from every tokenized text. Defaults to
                ``(6,)``, the id previously hard-coded. What that id denotes depends on
                the tokenizer and is not visible here.
        """
        self.processor = processor
        self.tokenizer = tokenizer
        self.name_list = name_list
        self.max_length = max_length
        self.image_dir = image_dir
        self.text_dir = text_dir
        self.drop_token_ids = frozenset(drop_token_ids)

    def __len__(self):
        """Return the number of samples (one per entry in ``name_list``)."""
        return len(self.name_list)

    def __getitem__(self, index):
        """Load, preprocess and return sample ``index`` as a dict of tensors."""
        name = self.name_list[index]
        encoding = {}
        encoding['pixel_values'] = self._load_pixel_values(os.path.join(self.image_dir, name + '.png'))

        text_file_path = os.path.join(self.text_dir, name + '.mmd')
        # Explicit encoding so decoding does not depend on the platform's locale.
        with open(text_file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        encoding.update(self._encode_text(text))
        return encoding

    def _load_pixel_values(self, image_file_path):
        """Open an image, force RGB, and return the processor's pixel tensor without the batch dim."""
        # PIL opens files lazily; the context manager guarantees the file handle is
        # released (important with many DataLoader workers).
        with Image.open(image_file_path) as image:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            return self.processor(image, return_tensors="pt").pixel_values.squeeze(0)

    def _encode_text(self, text):
        """Tokenize ``text`` into padded ``decoder_input_ids`` and shifted ``labels`` tensors."""
        input_ids = self.tokenizer(text, max_length=self.max_length, truncation=True).input_ids
        input_ids = [x for x in input_ids if x not in self.drop_token_ids]
        # Replace the first token with BOS.
        # NOTE(review): assumes the tokenizer emits a leading special token; if it
        # does not, the first real token is dropped here. Confirm.
        input_ids = [self.tokenizer.bos_token_id] + input_ids[1:]

        n_pad = self.max_length - len(input_ids)
        decoder_input_ids = input_ids + [self.tokenizer.pad_token_id] * n_pad
        # Labels are the inputs shifted left by one. The extra -100 restores the
        # length to max_length so labels align position-by-position with the inputs.
        labels = input_ids[1:] + [-100] * (n_pad + 1)
        return {
            'decoder_input_ids': torch.tensor(decoder_input_ids, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long),
        }