Instructions to use Lin-Chen/ShareCaptioner with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Lin-Chen/ShareCaptioner with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("feature-extraction", model="Lin-Chen/ShareCaptioner", trust_remote_code=True)

# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Lin-Chen/ShareCaptioner", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
import copy
import os
import sys

# Put this file's directory first on sys.path -- presumably so modules that
# ship alongside this file (remote-code checkpoint) also resolve by absolute
# name; TODO confirm it is still needed given the relative imports below.
dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, dir_path)
import contextlib
# Importing the submodule also binds the top-level `torch` name used below.
import torch.utils.checkpoint
import torch.nn as nn
from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
# The wildcard imports presumably supply names used below that are not
# imported explicitly here (e.g. PreTrainedModel, InternLMXComposerConfig,
# InternLMForCausalLM, create_eva_vit_g, all_logging_disabled, partial);
# which module provides each one is not visible from this file.
from .modeling_vit import *
from .modeling_InternLM import *
from .modeling_utils import *
from .resampler import create_resampler
from transformers.utils import logging

# Module logger from transformers' logging helper; the visible code reports
# progress with print() instead.
logger = logging.get_logger(__name__)
class InternLMXComposerForCausalLM(PreTrainedModel):
    """Vision-language causal LM built around an InternLM language model.

    Images are encoded into LLM-space embeddings (``encode_img``), combined
    with embedded text into a prompt (``wrap_prompt``), and answered with
    ``internlm_model.generate`` (``generate`` / ``chat``).
    """
    config_class = InternLMXComposerConfig
    # Auto class that transformers associates with this custom model (it is
    # recorded in the auto_map when saving remote code, per
    # register_for_auto_class conventions).
    _auto_class = "AutoModelForCausalLM"
    # Default keyword arguments for internlm_model.generate. get_gen_args
    # deep-copies this dict and lets per-call kwargs override entries.
    # NOTE(review): do_sample=True together with num_beams=5 selects
    # beam-sample decoding in HF generate -- confirm this is intended.
    gen_config = dict(
        num_beams=5,
        do_sample=True,
        min_length=1,
        repetition_penalty=1.5,
        length_penalty=1.0,
        temperature=1.0,
        max_new_tokens=500,
    )
def __init__(self, config):
    """Build the vision tower, resampler, language model and image transform.

    Args:
        config: model configuration. This constructor reads
            ``config.max_length`` and, on the torch 2 path, ``config.device``.
            It also passes the whole config to
            ``InternLMForCausalLM._from_config``.
    """
    super().__init__(config)
    self.max_length = config.max_length
    print (f'Set max length to {self.max_length}')
    print('Init VIT ... ', end='')
    # Vision backbone for 448x448 inputs (factory comes from the wildcard
    # imports, presumably .modeling_vit).
    self.visual_encoder = create_eva_vit_g(img_size=448)
    # No normalization is applied after the vision backbone: identity
    # placeholder kept so encode_img can call it uniformly.
    self.ln_vision = nn.Identity()
    # Instance-level flag that transformers checks before enabling gradient
    # checkpointing.
    self.supports_gradient_checkpointing = True
    print('Done')
    print('Init Perceive Sampler ... ', end='')
    # all_logging_disabled() comes from the wildcard imports; judging by its
    # name it silences log output while the resampler is built -- confirm.
    with all_logging_disabled():
        # Resampler with 256 query tokens. encode_img feeds its output into
        # internlm_proj, so that output's last dim must be 4096.
        self.Qformer = create_resampler(num_query_token=256)
    print('Done')
    print('Init InternLM ... ', end='')
    # Marker embeddings that encode_img places before and after each
    # image's tokens. They are zero-initialized and excluded from gradient
    # updates; their values presumably come from the checkpoint.
    self.flag_image_start = nn.Parameter(torch.zeros([1, 1, 4096]))
    self.flag_image_end = nn.Parameter(torch.zeros([1, 1, 4096]))
    self.flag_image_start.requires_grad = False
    self.flag_image_end.requires_grad = False
    # NOTE(review): only the first character of the version string is
    # inspected, so a two-digit major version would be misread.
    if int(torch.__version__[0]) == 1:
        # torch 1.x: materialize the LLM normally, then cast to float16.
        self.internlm_model = InternLMForCausalLM._from_config(config).to(
            torch.float16)
    else:
        assert int(torch.__version__[0]) == 2
        # speed up init llm
        # torch 2.x: build on the meta device (no real storage and no weight
        # init), then allocate uninitialized storage on config.device and
        # cast to float16. Weights are presumably filled in later by
        # checkpoint loading -- confirm.
        with torch.device('meta'):
            self.internlm_model = InternLMForCausalLM._from_config(config)
        self.internlm_model.to_empty(device=config.device).to(torch.float16)
    # Projects the 4096-dim resampler outputs to the LLM's hidden size.
    self.internlm_proj = nn.Linear(4096,
                                   self.internlm_model.config.hidden_size)
    print('Done')
    # Preprocessing that encode_img applies to path inputs: bicubic resize
    # to 448x448, conversion to tensor, then per-channel normalization. The
    # mean/std values match the widely used CLIP constants.
    self.vis_processor = transforms.Compose([
        transforms.Resize((448, 448),
                          interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                             (0.26862954, 0.26130258, 0.27577711)),
    ])
    # Must be attached by the caller before encode_text / generate / chat.
    self.tokenizer = None
@property
def eoh(self):
    """Marker appended after the user's message, before ``<|Bot|>:``.

    This is exposed as a property because every caller reads it as a string
    attribute (``self.eoh`` inside f-strings). As a plain method, those reads
    interpolated the bound-method repr into the prompt instead of the marker.
    """
    return '<TOKENS_UNUSED_0>'
@property
def eoa(self):
    """Marker at which decoded answers are truncated.

    This is exposed as a property because callers use it as a string
    (``out_text.split(self.eoa)``). As a plain method, that call raised
    TypeError.
    """
    return '<TOKENS_UNUSED_1>'
def get_input_embeddings(self):
    """Return the token-embedding module of the wrapped InternLM model."""
    language_model = self.internlm_model
    return language_model.get_input_embeddings()
| def _set_gradient_checkpointing(self, module, value=False): | |
| if value: | |
| self.internlm_model.apply( | |
| partial(self.internlm_model._set_gradient_checkpointing, value=True) | |
| ) | |
def encode_img(self, image):
    """Encode an image into a sequence of language-model input embeddings.

    Args:
        image: one of
            - ``None``;
            - a filesystem path (``str``);
            - a PIL image;
            - an already-preprocessed ``torch.Tensor`` batch.
            Paths and PIL images are converted to RGB, passed through
            ``self.vis_processor``, given a leading batch dimension and moved
            to ``self.device``. Tensors are used as-is, with no preprocessing
            and no device move.

    Returns:
        ``None`` when ``image`` is ``None``. Otherwise, the concatenation
        along dim 1 of ``flag_image_start``, the projected resampler outputs
        and ``flag_image_end``, with the markers expanded to the batch size.

    Raises:
        AssertionError: if ``image`` is of any other type.
    """
    if image is None:
        return None
    if not isinstance(image, torch.Tensor):
        # PIL is only touched for non-tensor inputs, so pre-processed
        # tensor batches never depend on it.
        if isinstance(image, str):
            image = Image.open(image)
        assert isinstance(image, Image.Image), (
            f'unsupported image type: {type(image).__name__}')
        image = self.vis_processor(image.convert("RGB"))
        image = image.unsqueeze(0).to(self.device)
    device = image.device
    image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
    # Resample the visual features, then project them to the LLM hidden
    # size. (An unused all-ones attention-mask tensor used to be allocated
    # here; it was never read, so it has been dropped.)
    query_output = self.Qformer(image_embeds)
    inputs_internlm = self.internlm_proj(query_output)
    batch_size = inputs_internlm.shape[0]
    # Bracket each sample's image tokens with the start/end marker embeddings.
    return torch.cat([
        self.flag_image_start.expand(batch_size, -1, -1),
        inputs_internlm,
        self.flag_image_end.expand(batch_size, -1, -1),
    ], dim=1)
def encode_text(self, text, add_special_tokens=False):
    """Tokenize ``text`` and look up its token embeddings in the LLM.

    Args:
        text: input accepted by ``self.tokenizer``.
        add_special_tokens: forwarded to the tokenizer; off by default.

    Returns:
        Output of ``internlm_model.model.embed_tokens`` for the token ids.
        The ids are moved to ``self.device`` before the lookup.
    """
    encoded = self.tokenizer(text,
                             return_tensors='pt',
                             add_special_tokens=add_special_tokens)
    token_ids = encoded.input_ids.to(self.device)
    embed_tokens = self.internlm_model.model.embed_tokens
    return embed_tokens(token_ids)
def decode_text(self, out_embeds):
    """Decode the first generated sequence and cut it at ``self.eoa``.

    Args:
        out_embeds: output of ``internlm_model.generate``. Despite the name,
            it is passed to ``tokenizer.batch_decode``, so it holds token ids.

    Returns:
        Decoded text of batch element 0, up to but not including the first
        occurrence of ``self.eoa``. Returns the whole text if ``self.eoa``
        does not occur.
    """
    # NOTE(review): if eoa is registered as a special token,
    # skip_special_tokens=True removes it before the split below -- confirm.
    texts = self.tokenizer.batch_decode(out_embeds, skip_special_tokens=True)
    answer, *_ = texts[0].split(self.eoa)
    return answer
def wrap_text(self, user_text, bot_text='', add_special=True):
    """Format one exchange as ``<|User|>:`` text, a newline, and ``<|Bot|>:`` text.

    Args:
        user_text: the user's message.
        bot_text: text placed after the bot tag (empty by default).
        add_special: when true, ``self.eoh`` is inserted directly after the
            user's message. Otherwise nothing is inserted there.

    Returns:
        The formatted prompt string.
    """
    marker = self.eoh if add_special else ''
    return '<|User|>:{}{}\n<|Bot|>:{}'.format(user_text, marker, bot_text)
def get_gen_args(self, **kwargs):
    """Return generation kwargs: a deep copy of ``gen_config`` plus overrides.

    ``self.gen_config`` itself is never modified. Entries in ``kwargs``
    replace same-named defaults, and new keys are appended after them.
    """
    return {**copy.deepcopy(self.gen_config), **kwargs}
def generate(self, text, image=None, **kwargs):
    """Answer a single prompt (no history) and return the decoded text.

    NOTE(review): this overrides the ``generate`` inherited through
    ``PreTrainedModel`` with a text-level signature.

    Args:
        text: user prompt.
        image: optional image accepted by ``encode_img``.
        **kwargs: per-call overrides of ``gen_config``.

    Returns:
        The output of ``decode_text`` on the generated ids.
    """
    prompt = self.wrap_prompt(self.encode_text(text), self.encode_img(image))
    gen_args = self.get_gen_args(**kwargs)
    output_ids = self.internlm_model.generate(inputs_embeds=prompt, **gen_args)
    return self.decode_text(output_ids)
def chat(self, text, image=None, history=None, **kwargs):
    """Run one conversational turn and return the reply plus updated history.

    Args:
        text: user message for this turn.
        image: optional image accepted by ``encode_img``.
        history: ``None``, or the list of per-turn embedding tensors
            returned by earlier calls. They are prepended to this turn's
            prompt.
        **kwargs: per-call overrides of ``gen_config``.

    Returns:
        ``(reply_text, history)``. A tensor for this turn is appended to
        ``history`` in place and the same list is returned; a fresh list is
        created when ``history`` is ``None``. The tensor consists of the
        prompt rebuilt with ``add_special=False`` followed by the re-embedded
        reply.
    """
    user_embeds = self.encode_text(text)
    image_embeds = self.encode_img(image)
    full_prompt = self.wrap_prompt(user_embeds, image_embeds, history=history)
    generated = self.internlm_model.generate(
        inputs_embeds=full_prompt, **self.get_gen_args(**kwargs))
    reply = self.decode_text(generated)

    # trunc at eoh and eoa: re-embed the already-truncated reply so it can be
    # replayed in later turns.
    # NOTE(review): this tokenizer call relies on the tokenizer's default
    # special-token handling, whereas encode_text passes
    # add_special_tokens=False by default -- confirm this is intended.
    reply_ids = self.tokenizer(reply, return_tensors='pt').input_ids
    reply_embeds = self.internlm_model.model.embed_tokens(
        reply_ids.to(self.device))
    turn_prompt = self.wrap_prompt(user_embeds, image_embeds, add_special=False)
    turn = torch.cat([turn_prompt, reply_embeds], dim=1)

    turns = [] if history is None else history
    turns.append(turn)
    return reply, turns
def wrap_prompt(self,
                text_embeds,
                img_embeds=None,
                history=None,
                add_special=True):
    """Build the embedding sequence for one user turn, optionally after history.

    The turn is concatenated along dim 1 as: the embedded ``<|User|>:``
    segment, the image embeddings, ``text_embeds``, then a closing segment.
    The closing segment is ``self.eoh`` + newline + ``<|Bot|>:`` when
    ``add_special`` is true. Otherwise it is just ``<|Bot|>:``, the form
    ``chat`` uses for stored history.

    Args:
        text_embeds: embedded user text. Its size(0) and size(-1) shape the
            empty image placeholder.
        img_embeds: optional image embeddings (e.g. from ``encode_img``).
        history: optional list of earlier turn tensors, prepended in order.
        add_special: whether the closing segment includes ``self.eoh``.

    Returns:
        The concatenated prompt embeddings.
    """
    opener = '<|User|>:'
    closer = f'{self.eoh}\n<|Bot|>:' if add_special else '<|Bot|>:'
    # Tokenizer special tokens are requested only for the opener, and only
    # when there is no history to prepend.
    opener_embeds = self.encode_text(opener,
                                     add_special_tokens=history is None)
    closer_embeds = self.encode_text(closer, add_special_tokens=False)
    if img_embeds is None:
        # Zero-length placeholder keeps the concatenation below uniform.
        img_embeds = text_embeds.new_empty(text_embeds.size(0), 0,
                                           text_embeds.size(-1))
    turn = torch.cat([opener_embeds, img_embeds, text_embeds, closer_embeds],
                     dim=1)
    if history is None:
        return turn
    return torch.cat([*history, turn], dim=1)