Spaces:
Sleeping
Sleeping
File size: 2,836 Bytes
adecc9b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 | from typing import Any, Optional, Tuple
from typing_extensions import Self
import json
import os
import random
from open_biomed.data import Protein, Text
from open_biomed.datasets.base_dataset import BaseDataset, assign_split, featurize
from open_biomed.utils.config import Config
from open_biomed.utils.featurizer import Featurizer
class ProteinTextDataset(BaseDataset):
    """Base dataset holding parallel lists of proteins and texts.

    Subclasses populate ``self.proteins`` and ``self.texts``; the dataset
    length is the number of proteins.
    """

    def __init__(self, cfg: Config, featurizer: Featurizer) -> None:
        # Create the containers before delegating to BaseDataset.__init__.
        # NOTE(review): presumably BaseDataset.__init__ triggers data loading
        # that appends to these lists — confirm against base_dataset.
        self.proteins = []
        self.texts = []
        super().__init__(cfg, featurizer)

    def __len__(self) -> int:
        """Return the number of proteins loaded."""
        return len(self.proteins)
class TextBasedProteinGenerationDataset(ProteinTextDataset):
    """Dataset pairing a text description with a target protein.

    Each item is a dict with a ``"text"`` entry (the description) and a
    ``"label"`` entry (the protein); ``__getitem__`` is wrapped by the
    ``featurize`` decorator imported from ``base_dataset``.
    """

    def __init__(self, cfg: Config, featurizer: Featurizer) -> None:
        # No extra state; defer entirely to ProteinTextDataset.
        super().__init__(cfg, featurizer)

    def _load_data(self) -> None:
        # The labels are the proteins themselves: bind the very same list
        # object (no copy), so ``labels`` and ``proteins`` stay in sync.
        self.labels = self.proteins

    @featurize
    def __getitem__(self, index) -> Any:
        description = self.texts[index]
        target = self.proteins[index]
        return {"text": description, "label": target}
class MolInstructionsForProteinDesign(TextBasedProteinGenerationDataset):
    """Text-to-protein design dataset read from ``protein_design.json``.

    Reads ``<cfg.path>/protein_design.json``. For each record, ``input``
    becomes a ``Text`` (``Text.from_str``), and the second-to-last line of
    ``output`` becomes a ``Protein`` (``Protein.from_fasta``). Records whose
    sequence contains ``"X"`` are skipped.

    Split assignment (stored in ``self.split_indexes``):
      * records whose ``metadata["split"]`` is not ``"train"`` are indexed
        under that split name;
      * ``"train"`` records go to ``valid`` when ``random.randint(1, 100) > 95``
        (5% chance) and to ``train`` otherwise;
      * with ``cfg.debug`` set, loading stops after 500 kept records. Every
        ``"train"`` record stays in ``train``, and those with kept-index >= 498
        are additionally indexed under ``valid``.
    """

    def __init__(self, cfg: Config, featurizer: Featurizer) -> None:
        super().__init__(cfg, featurizer)

    def _load_data(self) -> None:
        """Parse the JSON file, fill texts/proteins and ``self.split_indexes``."""
        self.split_indexes = {"train": [], "valid": [], "test": []}
        path = os.path.join(self.cfg.path, "protein_design.json")
        # Context manager closes the handle deterministically; JSON text is
        # UTF-8 by specification, so don't depend on the platform locale.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        cnt = 0  # position of the next kept sample in self.texts / self.proteins
        for sample in data:
            # NOTE(review): assumes the sequence sits on the second-to-last
            # line of ``output`` (e.g. inside a fenced block) — TODO confirm.
            seq = sample["output"].split("\n")[-2]
            if "X" in seq:
                # Presumably "X" marks an unknown residue; such sequences are dropped.
                continue
            self.texts.append(Text.from_str(sample["input"]))
            self.proteins.append(Protein.from_fasta(seq))

            split_name = sample["metadata"]["split"]
            if split_name == "train":
                if self.cfg.debug:
                    # Debug: keep everything in train and reuse the last few
                    # samples as a tiny validation set (overlaps train on purpose).
                    self.split_indexes["train"].append(cnt)
                    if cnt >= 498:
                        self.split_indexes["valid"].append(cnt)
                # NOTE(review): uses the global ``random`` state, so this
                # train/valid hold-out is reproducible only if the caller seeds
                # ``random`` beforehand.
                elif random.randint(1, 100) > 95:
                    self.split_indexes["valid"].append(cnt)
                else:
                    self.split_indexes["train"].append(cnt)
            else:
                self.split_indexes[split_name].append(cnt)

            cnt += 1
            if cnt >= 500 and self.cfg.debug:
                break
        super()._load_data()

    @assign_split
    def split(self, split_cfg: Optional[Config] = None) -> Tuple[Self, Self, Self]:
        """Return the ``(train, valid, test)`` subsets from ``self.split_indexes``.

        ``split_cfg`` is not read by this method body (the ``assign_split``
        decorator, defined elsewhere, may use it).
        """
        attrs = ["proteins", "texts", "labels"]
        return (
            self.get_subset(self.split_indexes["train"], attrs),
            self.get_subset(self.split_indexes["valid"], attrs),
            self.get_subset(self.split_indexes["test"], attrs),
        )