bioflow / open_biomed /datasets /cell_annotation_dataset.py
Rami-Troudi's picture
Phase 1: FastAPI integration with DeepPurpose DTI predictor
adecc9b
raw
history blame
1.57 kB
from typing import Tuple, Union, Any, Dict, Optional, List
from typing_extensions import Self
import json
import os
from torch.utils.data import Dataset
from open_biomed.data import Cell, Text
from open_biomed.datasets.base_dataset import BaseDataset, assign_split, featurize
from open_biomed.utils.config import Config
from open_biomed.utils.featurizer import Featurizer, Featurized
from datasets import load_from_disk
class CellAnnotationDataset(BaseDataset):
    """Base dataset pairing cells with a class text and a label.

    The sample containers are created empty before the ``BaseDataset``
    constructor is invoked; nothing in this class fills them itself.
    """

    def __init__(self, cfg: Config, featurizer: Featurizer) -> None:
        """Create empty sample containers, then run the base initializer.

        Args:
            cfg: Dataset configuration, forwarded to ``BaseDataset``.
            featurizer: Featurizer, forwarded to ``BaseDataset``.
        """
        # The containers must exist before the base initializer runs.
        # Presumably the base class triggers data loading from its
        # constructor -- TODO confirm against BaseDataset.
        self.cells = []
        self.labels = []
        self.class_texts = {}
        super().__init__(cfg, featurizer)

    def __len__(self) -> int:
        """Return the number of stored cells."""
        return len(self.cells)

    @featurize
    def __getitem__(self, index) -> Dict[str, Featurized[Any]]:
        """Return the cell, class text and label stored at ``index``.

        The returned dict is passed through the ``featurize`` decorator.
        """
        cell = self.cells[index]
        # NOTE(review): ``class_texts`` starts out as a dict but is looked up
        # with the sample index here -- confirm the loaded container is keyed
        # (or ordered) so that this lookup is valid.
        class_text = self.class_texts[index]
        label = self.labels[index]
        return {"cell": cell, "class_text": class_text, "label": label}
class CellAnnotation(CellAnnotationDataset):
    """Cell annotation dataset read from a directory on disk.

    Under ``cfg.path`` the loader expects:

    * ``data.dataset`` -- a dataset readable by ``load_from_disk`` whose
      samples provide ``"input_ids"`` and ``"celltype"`` entries;
    * ``type2text.json`` -- a JSON file stored as ``class_texts``.
    """

    def __init__(self, cfg: Config, featurizer: Featurizer) -> None:
        """Initialize the dataset by delegating to the parent class.

        Args:
            cfg: Dataset configuration; ``cfg.path`` is read in ``_load_data``.
            featurizer: Featurizer, forwarded to the parent class.
        """
        super().__init__(cfg, featurizer)

    def _load_data(self) -> None:
        """Fill ``cells``, ``labels`` and ``class_texts`` from ``cfg.path``.

        Each sample's ``"input_ids"`` becomes a ``Cell`` (via
        ``Cell.from_sequence``) and its ``"celltype"`` becomes a ``Text``
        (via ``Text.from_str``); both are appended in dataset order.
        """
        dataset = load_from_disk(os.path.join(self.cfg.path, "data.dataset"))
        # Use a context manager so the file handle is closed deterministically
        # instead of being left to the garbage collector.
        with open(os.path.join(self.cfg.path, "type2text.json"), "r") as f:
            class_texts = json.load(f)
        for sample in dataset:
            self.cells.append(Cell.from_sequence(sample["input_ids"]))
            self.labels.append(Text.from_str(sample["celltype"]))
        self.class_texts = class_texts