Spaces:
Sleeping
Sleeping
| from typing import Any, Optional, Tuple | |
| from typing_extensions import Self | |
| import json | |
| import os | |
| import random | |
| from open_biomed.data import Protein, Text | |
| from open_biomed.datasets.base_dataset import BaseDataset, assign_split, featurize | |
| from open_biomed.utils.config import Config | |
| from open_biomed.utils.featurizer import Featurizer | |
| class ProteinTextDataset(BaseDataset): | |
| def __init__(self, cfg: Config, featurizer: Featurizer) -> None: | |
| self.proteins, self.texts = [], [] | |
| super().__init__(cfg, featurizer) | |
| def __len__(self) -> int: | |
| return len(self.proteins) | |
| class TextBasedProteinGenerationDataset(ProteinTextDataset): | |
| def __init__(self, cfg: Config, featurizer: Featurizer) -> None: | |
| super().__init__(cfg, featurizer) | |
| def _load_data(self) -> None: | |
| self.labels = self.proteins | |
| def __getitem__(self, index) -> Any: | |
| return { | |
| "text": self.texts[index], | |
| "label": self.proteins[index], | |
| } | |
| class MolInstructionsForProteinDesign(TextBasedProteinGenerationDataset): | |
| def __init__(self, cfg: Config, featurizer: Featurizer) -> None: | |
| super().__init__(cfg, featurizer) | |
| def _load_data(self) -> None: | |
| self.split_indexes = {"train": [], "valid": [], "test": []} | |
| data = json.load(open(os.path.join(self.cfg.path, "protein_design.json"), "r")) | |
| cnt = 0 | |
| for i, sample in enumerate(data): | |
| seq = sample["output"].split("\n")[-2] | |
| if "X" in seq: | |
| continue | |
| self.texts.append(Text.from_str(sample["input"])) | |
| self.proteins.append(Protein.from_fasta(seq)) | |
| if sample["metadata"]["split"] == "train": | |
| if self.cfg.debug: | |
| self.split_indexes["train"].append(cnt) | |
| if cnt >= 498: | |
| self.split_indexes["valid"].append(cnt) | |
| elif random.randint(1, 100) > 95: | |
| self.split_indexes["valid"].append(cnt) | |
| else: | |
| self.split_indexes["train"].append(cnt) | |
| else: | |
| self.split_indexes[sample["metadata"]["split"]].append(cnt) | |
| cnt += 1 | |
| if cnt >= 500 and self.cfg.debug: | |
| break | |
| # print(len(self.split_indexes["train"]), len(self.split_indexes["valid"]), len(self.split_indexes["test"])) | |
| super()._load_data() | |
| def split(self, split_cfg: Optional[Config]=None) -> Tuple[Self, Self, Self]: | |
| attrs = ["proteins", "texts", "labels"] | |
| ret = ( | |
| self.get_subset(self.split_indexes["train"], attrs), | |
| self.get_subset(self.split_indexes["valid"], attrs), | |
| self.get_subset(self.split_indexes["test"], attrs), | |
| ) | |
| del self | |
| return ret |