bioflow / open_biomed /datasets /protein_text_dataset.py
Rami-Troudi's picture
Phase 1: FastAPI integration with DeepPurpose DTI predictor
adecc9b
raw
history blame
2.84 kB
from typing import Any, Optional, Tuple
from typing_extensions import Self
import json
import os
import random
from open_biomed.data import Protein, Text
from open_biomed.datasets.base_dataset import BaseDataset, assign_split, featurize
from open_biomed.utils.config import Config
from open_biomed.utils.featurizer import Featurizer
class ProteinTextDataset(BaseDataset):
def __init__(self, cfg: Config, featurizer: Featurizer) -> None:
self.proteins, self.texts = [], []
super().__init__(cfg, featurizer)
def __len__(self) -> int:
return len(self.proteins)
class TextBasedProteinGenerationDataset(ProteinTextDataset):
def __init__(self, cfg: Config, featurizer: Featurizer) -> None:
super().__init__(cfg, featurizer)
def _load_data(self) -> None:
self.labels = self.proteins
@featurize
def __getitem__(self, index) -> Any:
return {
"text": self.texts[index],
"label": self.proteins[index],
}
class MolInstructionsForProteinDesign(TextBasedProteinGenerationDataset):
def __init__(self, cfg: Config, featurizer: Featurizer) -> None:
super().__init__(cfg, featurizer)
def _load_data(self) -> None:
self.split_indexes = {"train": [], "valid": [], "test": []}
data = json.load(open(os.path.join(self.cfg.path, "protein_design.json"), "r"))
cnt = 0
for i, sample in enumerate(data):
seq = sample["output"].split("\n")[-2]
if "X" in seq:
continue
self.texts.append(Text.from_str(sample["input"]))
self.proteins.append(Protein.from_fasta(seq))
if sample["metadata"]["split"] == "train":
if self.cfg.debug:
self.split_indexes["train"].append(cnt)
if cnt >= 498:
self.split_indexes["valid"].append(cnt)
elif random.randint(1, 100) > 95:
self.split_indexes["valid"].append(cnt)
else:
self.split_indexes["train"].append(cnt)
else:
self.split_indexes[sample["metadata"]["split"]].append(cnt)
cnt += 1
if cnt >= 500 and self.cfg.debug:
break
# print(len(self.split_indexes["train"]), len(self.split_indexes["valid"]), len(self.split_indexes["test"]))
super()._load_data()
@assign_split
def split(self, split_cfg: Optional[Config]=None) -> Tuple[Self, Self, Self]:
attrs = ["proteins", "texts", "labels"]
ret = (
self.get_subset(self.split_indexes["train"], attrs),
self.get_subset(self.split_indexes["valid"], attrs),
self.get_subset(self.split_indexes["test"], attrs),
)
del self
return ret