Instructions to use Lin-Chen/ShareCaptioner with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Lin-Chen/ShareCaptioner with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="Lin-Chen/ShareCaptioner", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Lin-Chen/ShareCaptioner", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import logging | |
| import math | |
| import os | |
| from contextlib import contextmanager | |
| import timm.models.hub as timm_hub | |
| import torch | |
| import torch.distributed as dist | |
| import torch.nn as nn | |
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is both available and initialized."""
    # `and` short-circuits, so is_initialized() is never queried when the
    # distributed package itself is unavailable.
    return dist.is_available() and dist.is_initialized()
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def is_main_process():
    """Return True when the current process has rank 0."""
    current_rank = get_rank()
    return current_rank == 0
def download_cached_file(url, check_hash=True, progress=False):
    """
    Fetch ``url`` through timm's cached download helper and return the local path.

    When running distributed, only the main process (rank 0) performs the
    download; every process then meets at a barrier before the path is returned.

    :param url: URL of the file to fetch.
    :param check_hash: forwarded to ``timm_hub.download_cached_file``.
    :param progress: forwarded to ``timm_hub.download_cached_file``.
    :return: timm's cache directory joined with the basename of the URL path.
    """
    # Only rank 0 downloads, so concurrent processes do not all fetch the file.
    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    # Hold the other ranks until rank 0 has finished the call above.
    if is_dist_avail_and_initialized():
        dist.barrier()

    # The path is recomputed locally on every rank rather than broadcast:
    # each process derives the same location from the URL.
    file_name = os.path.basename(torch.hub.urlparse(url).path)
    return os.path.join(timm_hub.get_cache_dir(), file_name)
@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
    """
    A context manager that will prevent any logging messages
    triggered during the body from being processed.

    The previous global disable level is restored on exit, including when the
    body raises.

    :param highest_level: the maximum logging level in use.
      This would only need to be changed if a custom level greater than CRITICAL
      is defined.
    """
    # two kind-of hacks here:
    #   * can't get the highest logging level in effect => delegate to the user
    #   * can't get the current module-level override => use an undocumented
    #     (but non-private!) interface
    previous_level = logging.root.manager.disable
    logging.disable(highest_level)
    try:
        yield
    finally:
        # Always put the earlier override back so nesting and errors are safe.
        logging.disable(previous_level)
class LoRALinear(nn.Linear):
    """
    ``nn.Linear`` with an additive low-rank (LoRA) branch.

    The forward pass returns
    ``linear(x) + lora_B(lora_A(lora_dropout(x))) * (lora_alpha / lora_r)``.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 device=None,
                 dtype=None,
                 lora_r=8,
                 lora_alpha=16,
                 lora_dropout=0.05,
                 **kwargs) -> None:
        """
        :param in_features: input feature size of the base linear layer.
        :param out_features: output feature size of the base linear layer.
        :param bias: whether the base linear layer has a bias term.
        :param device: device passed to the base layer and both LoRA projections.
        :param dtype: dtype passed to the base layer and both LoRA projections.
        :param lora_r: rank of the low-rank projections.
        :param lora_alpha: numerator of the LoRA scaling ``lora_alpha / lora_r``.
        :param lora_dropout: dropout probability applied to the LoRA branch
            input; values ``<= 0`` disable dropout.
        :param kwargs: accepted and ignored.
        """
        super().__init__(in_features, out_features, bias, device, dtype)
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            # nn.Identity instead of a lambda: a lambda attribute makes the
            # module unpicklable (e.g. torch.save of the whole model).
            # Identity has no parameters, so state_dict keys are unchanged.
            self.lora_dropout = nn.Identity()
        self.lora_scaling = self.lora_alpha / self.lora_r

        self.lora_A = nn.Linear(in_features,
                                self.lora_r,
                                bias=False,
                                device=device,
                                dtype=dtype)
        self.lora_B = nn.Linear(self.lora_r,
                                out_features,
                                bias=False,
                                device=device,
                                dtype=dtype)

        self.reset_parameters()

    def reset_parameters(self):
        """
        Re-initialize the LoRA projections only.

        This override does not touch the base ``weight``/``bias``.
        NOTE(review): presumably they are expected to be loaded from a
        checkpoint afterwards -- confirm.
        """
        # The guard covers calls made before lora_A/lora_B have been created.
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B.weight)

    def forward(self, x):
        """Apply the base linear layer plus the scaled LoRA branch; the result is cast back to ``x``'s dtype."""
        orig_type = x.dtype
        res = super().forward(x)
        # The LoRA branch receives a float32 copy of the input.
        # NOTE(review): this presumably needs float32 LoRA weights or an
        # active autocast context -- confirm against the callers.
        x = x.float()
        res += self.lora_B(self.lora_A(
            self.lora_dropout(x))) * self.lora_scaling
        return res.to(orig_type)