Spaces:
Sleeping
Sleeping
| from typing import Tuple, Union, Any, Dict, Optional, List | |
| from typing_extensions import Self | |
| import json | |
| import os | |
| from torch.utils.data import Dataset | |
| from open_biomed.data import Cell, Text | |
| from open_biomed.datasets.base_dataset import BaseDataset, assign_split, featurize | |
| from open_biomed.utils.config import Config | |
| from open_biomed.utils.featurizer import Featurizer, Featurized | |
| from datasets import load_from_disk | |
class CellAnnotationDataset(BaseDataset):
    """Dataset base class exposing cells, class texts and labels per sample.

    Subclasses are expected to fill ``cells``, ``labels`` and ``class_texts``;
    this class only defines the containers, the dataset length and the
    per-sample dictionary layout.

    NOTE(review): ``class_texts`` starts out as a dict but is looked up with
    the same ``index`` as the ``cells`` / ``labels`` lists -- confirm that the
    loaded mapping is really keyed by sample index.
    """

    def __init__(self, cfg: Config, featurizer: Featurizer) -> None:
        """Create empty sample containers, then run the base initializer.

        Args:
            cfg: Dataset configuration, forwarded to ``BaseDataset``.
            featurizer: Featurizer, forwarded to ``BaseDataset``.
        """
        # The containers are created *before* delegating to the parent
        # initializer; presumably the parent triggers data loading, which
        # needs them to exist -- TODO confirm against BaseDataset.
        self.cells = []
        self.labels = []
        self.class_texts = {}
        super().__init__(cfg, featurizer)

    def __len__(self) -> int:
        """Return the number of stored cells."""
        return len(self.cells)

    def __getitem__(self, index) -> Dict[str, Featurized[Any]]:
        """Return the sample stored at ``index``.

        Args:
            index: Position used to look up ``cells``, ``class_texts`` and
                ``labels``.

        Returns:
            A dict with the keys ``"cell"``, ``"class_text"`` and ``"label"``.
        """
        cell = self.cells[index]
        class_text = self.class_texts[index]
        label = self.labels[index]
        return {"cell": cell, "class_text": class_text, "label": label}
class CellAnnotation(CellAnnotationDataset):
    """Cell annotation dataset read from a directory given by ``cfg.path``.

    The directory is expected to contain ``data.dataset`` (read with
    ``datasets.load_from_disk``) and ``type2text.json``.
    """

    def __init__(self, cfg: Config, featurizer: Featurizer) -> None:
        """Initialize the dataset by delegating to the parent class.

        Args:
            cfg: Dataset configuration; ``cfg.path`` is read by ``_load_data``.
            featurizer: Featurizer, forwarded to the parent class.
        """
        super(CellAnnotation, self).__init__(cfg, featurizer)

    def _load_data(self) -> None:
        """Populate ``cells``, ``labels`` and ``class_texts`` from disk.

        Each sample's ``"input_ids"`` is wrapped with ``Cell.from_sequence``
        and its ``"celltype"`` with ``Text.from_str``. The parsed content of
        ``type2text.json`` is stored unchanged as ``class_texts``.
        """
        dataset = load_from_disk(os.path.join(self.cfg.path, "data.dataset"))

        # Use a context manager so the JSON file handle is closed
        # deterministically instead of being left to the garbage collector.
        type2text_path = os.path.join(self.cfg.path, "type2text.json")
        with open(type2text_path, "r") as type2text_file:
            class_texts = json.load(type2text_file)

        for sample in dataset:
            self.cells.append(Cell.from_sequence(sample["input_ids"]))
            self.labels.append(Text.from_str(sample["celltype"]))
        self.class_texts = class_texts