# Source: ANKH_base / modeling_ankh.py
# Uploaded by lhallee ("Upload modeling_ankh.py with huggingface_hub").
# Commit 7ec329a (verified); file size 109 kB.
from __future__ import annotations
import torch
import torch._inductor.config as inductor_config
import torch._dynamo as dynamo
# NOTE: every statement in this run mutates process-wide torch settings at
# import time, so importing this module affects all models in the process.
# Enable TensorFloat32 tensor cores for float32 matmul (Ampere+ GPUs)
# Provides significant speedup with minimal precision loss
torch.set_float32_matmul_precision('high')
# Enable TF32 for matrix multiplications and cuDNN operations
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Enable cuDNN autotuner - finds fastest algorithms for your hardware
# Best when input sizes are consistent; may slow down first iterations
torch.backends.cudnn.benchmark = True
# Deterministic operations off for speed (set True if reproducibility needed)
torch.backends.cudnn.deterministic = False
# Candidate GEMM backends for Inductor autotuning
# (presumably only consulted in max-autotune compile modes - TODO confirm).
inductor_config.max_autotune_gemm_backends = "ATEN,CUTLASS,FBGEMM"
# Ask Dynamo to capture scalar-producing ops instead of graph-breaking on them
# (NOTE(review): confirm against the torch version in use).
dynamo.config.capture_scalar_outputs = True
# Recompilation cap; `torch._dynamo` is the same module imported as `dynamo` above.
torch._dynamo.config.recompile_limit = 16
import io
import os
import queue
import sqlite3
import struct
import threading
import time
import networkx as nx
import numpy as np
import torch
from tqdm.auto import tqdm
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
from transformers import PreTrainedTokenizerBase
# SQLite stores tensors as compact blobs. Keep this header format compatible
# with Protify readers that share the same dtype/version codes.
# First byte of every compact blob; embedding_blob_to_tensor checks it to tell
# compact blobs apart from the legacy formats.
_COMPACT_VERSION = 0x01
# torch dtype -> one-byte dtype code written into the blob header.
_DTYPE_TO_CODE = {torch.float16: 0, torch.bfloat16: 1, torch.float32: 2}
# Header dtype code -> torch dtype the deserialized tensor is returned as.
_CODE_TO_DTYPE = {0: torch.float16, 1: torch.bfloat16, 2: torch.float32}
# Header dtype code -> numpy dtype of the stored payload. Code 1 (bfloat16)
# maps to np.float16 because bfloat16 tensors are written as float16 bytes.
_CODE_TO_NP_DTYPE = {0: np.float16, 1: np.float16, 2: np.float32}
def tensor_to_embedding_blob(tensor: torch.Tensor) -> bytes:
    """Serialize a tensor to compact binary format for SQLite blob storage.

    Format: [version:1][dtype_code:1][ndim:4][shape:4*ndim][raw_bytes]

    bfloat16 tensors are stored as float16 bytes (numpy lacks bfloat16)
    but tagged with dtype_code=1 so they can be cast back on read.
    Falls back to torch.save for unsupported dtypes.

    Args:
        tensor: Tensor on any device; it is detached and copied to CPU first.

    Returns:
        The serialized blob (compact header + payload, or a torch.save archive
        for dtypes missing from _DTYPE_TO_CODE).
    """
    # detach() first: Tensor.numpy() refuses tensors that still require grad.
    cpu_tensor = tensor.detach().cpu()
    if cpu_tensor.dtype not in _DTYPE_TO_CODE:
        buffer = io.BytesIO()
        torch.save(cpu_tensor, buffer)
        return buffer.getvalue()

    dtype_code = _DTYPE_TO_CODE[cpu_tensor.dtype]
    if cpu_tensor.dtype == torch.bfloat16:
        # numpy has no bfloat16, so the payload is written as float16 bytes.
        raw = cpu_tensor.half().numpy().tobytes()
    else:
        raw = cpu_tensor.numpy().tobytes()

    shape = cpu_tensor.shape
    header = struct.pack(
        f'<BBi{len(shape)}i', _COMPACT_VERSION, dtype_code, len(shape), *shape
    )
    return header + raw
def _compact_header(dtype: torch.dtype, shape: tuple) -> bytes:
    """Return only the compact-blob header for `dtype` and `shape`.

    Layout is little-endian: version byte, dtype code byte, int32 ndim,
    then one int32 per dimension.
    """
    rank = len(shape)
    layout = f'<BBi{rank}i'
    fields = (_COMPACT_VERSION, _DTYPE_TO_CODE[dtype], rank) + tuple(shape)
    return struct.pack(layout, *fields)
def batch_tensor_to_blobs(batch: torch.Tensor) -> List[bytes]:
    """Serialize a batch of same-shape tensors to compact blobs (fast path for vectors).

    Builds the header once and slices raw bytes per row. Much faster than
    per-row tensor_to_embedding_blob calls for uniform-shape batches.

    Args:
        batch: Tensor with at least 2 dims; dim 0 indexes the rows.

    Returns:
        One blob per row, in row order. An empty batch yields an empty list.
    """
    assert batch.ndim >= 2, f"Expected batch with >= 2 dims, got {batch.ndim}"
    num_rows = batch.shape[0]
    if num_rows == 0:
        # Nothing to serialize; also avoids dividing by zero for the stride below.
        return []

    # detach() so batches that still require grad can be converted with .numpy().
    t = batch.detach().cpu()
    if t.dtype not in _DTYPE_TO_CODE:
        # Unsupported dtype: defer to the per-row serializer (torch.save fallback).
        return [tensor_to_embedding_blob(row) for row in t]

    # bfloat16 payloads are written as float16 bytes, but the header keeps the
    # bfloat16 code so readers cast back on load.
    arr = t.half().numpy() if t.dtype == torch.bfloat16 else t.numpy()
    header = _compact_header(t.dtype, tuple(t.shape[1:]))
    raw = arr.tobytes()
    stride = len(raw) // num_rows
    return [header + raw[i * stride:(i + 1) * stride] for i in range(num_rows)]
def embedding_blob_to_tensor(blob: bytes, fallback_shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
    """Deserialize a blob back to a tensor. Auto-detects compact vs legacy formats.

    Formats are tried in order: compact header, torch.save archive, raw
    float32 bytes (which need `fallback_shape`).

    Args:
        blob: Bytes read from the embeddings table.
        fallback_shape: Shape used only for the oldest raw-float32 rows.

    Returns:
        A CPU tensor.

    Raises:
        AssertionError: If no format matches and `fallback_shape` is None.
    """
    if len(blob) >= 6 and blob[0] == _COMPACT_VERSION:
        dtype_code = blob[1]
        ndim = struct.unpack_from('<i', blob, 2)[0]
        data_offset = 6 + 4 * ndim
        # A legacy blob can start with byte 0x01 by coincidence, so only trust
        # the header when it is fully self-consistent; otherwise fall through.
        if dtype_code in _CODE_TO_DTYPE and ndim >= 0 and data_offset <= len(blob):
            shape = struct.unpack_from(f'<{ndim}i', blob, 6)
            np_dtype = _CODE_TO_NP_DTYPE[dtype_code]
            expected_bytes = int(np.prod(shape, dtype=np.int64)) * np.dtype(np_dtype).itemsize
            if all(dim >= 0 for dim in shape) and expected_bytes == len(blob) - data_offset:
                # copy() gives a writable array that does not alias `blob`.
                arr = np.frombuffer(blob, dtype=np_dtype, offset=data_offset).copy().reshape(shape)
                t = torch.from_numpy(arr)
                target_dtype = _CODE_TO_DTYPE[dtype_code]
                if target_dtype != t.dtype:
                    # bfloat16 payloads are stored as float16 bytes; cast back.
                    t = t.to(target_dtype)
                return t

    # Older `.pth`-style blobs were written with torch.save.
    try:
        buffer = io.BytesIO(blob)
        return torch.load(buffer, map_location='cpu', weights_only=True)
    except Exception:
        # Deliberate best-effort: not a torch.save archive, try the raw format.
        pass

    # Oldest SQLite rows stored raw float32 bytes and need a caller-supplied shape.
    assert fallback_shape is not None, "Cannot deserialize blob: unknown format and no fallback_shape provided."
    arr = np.frombuffer(blob, dtype=np.float32).copy().reshape(fallback_shape)
    return torch.from_numpy(arr)
def select_hidden_state_embeddings(
    last_hidden_state: torch.Tensor,
    hidden_states: Optional[Tuple[torch.Tensor, ...]],
    hidden_state_index: int = -1,
    store_all_hidden_states: bool = False,
) -> torch.Tensor:
    """Choose which hidden state(s) of a model output become the embedding.

    Returns `last_hidden_state` for index -1, the indexed entry of
    `hidden_states` for any other index, or every entry of `hidden_states`
    stacked along a new dim 1 when `store_all_hidden_states` is set.
    """
    assert isinstance(hidden_state_index, int), "hidden_state_index must be an integer."

    if not store_all_hidden_states:
        if hidden_state_index == -1:
            # Fast path: no per-layer outputs are needed.
            return last_hidden_state
        assert hidden_states is not None, "hidden_state_index selection requires output_hidden_states=True."
        return hidden_states[hidden_state_index]

    assert hidden_states is not None, "store_all_hidden_states requires output_hidden_states=True."
    assert len(hidden_states) > 0, "Model returned no hidden states."
    layers = tuple(hidden_states)
    return torch.stack(layers, dim=1)
def _trim_full_embedding(embedding: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Drop the positions of one sample's embedding where the mask is zero.

    A 2-dim embedding is masked along dim 0, a 3-dim embedding along dim 1;
    the last dim is kept as-is. Any other rank raises AssertionError.
    """
    keep = attention_mask.bool()
    rank = embedding.ndim
    if rank not in (2, 3):
        raise AssertionError(f"Expected full embedding tensor with 2 or 3 dims, got {embedding.ndim}.")

    width = embedding.shape[-1]
    if rank == 2:
        kept = embedding[keep]
        return kept.reshape(-1, width)
    kept = embedding[:, keep, :]
    return kept.reshape(embedding.shape[0], -1, width)
def pool_embeddings(
    embeddings: Dict[str, torch.Tensor],
    pooling_types: List[str] = ['mean'],
    hidden_state_index: int = -1,
) -> Dict[str, torch.Tensor]:
    """Pool stored embeddings into one CPU vector per sequence.

    1-dim entries are treated as already pooled and passed through. 3-dim
    entries are first indexed with `hidden_state_index` along dim 0, and the
    resulting 2-dim embedding is pooled with `Pooler(pooling_types)`.
    """
    pooler = Pooler(pooling_types)
    result: Dict[str, torch.Tensor] = {}
    for sequence, embedding in embeddings.items():
        assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)."
        assert isinstance(embedding, torch.Tensor), "Expected embedding dictionary values to be tensors."
        if embedding.ndim != 1:
            if embedding.ndim == 3:
                embedding = embedding[hidden_state_index]
            assert embedding.ndim == 2, f"Expected token-wise embedding with 2 dims, got {embedding.ndim}."
            # The pooler works on batches, so wrap and unwrap a batch of one.
            batched = embedding.unsqueeze(0)
            embedding = pooler(batched).squeeze(0)
        result[sequence] = embedding.cpu()
    return result
def load_pooled_embeddings_from_pth(
    save_path: str,
    pooling_types: List[str] = ['mean'],
    hidden_state_index: int = -1,
) -> Dict[str, torch.Tensor]:
    """Load a `.pth` embedding dictionary from disk and pool every entry.

    The file is loaded onto CPU with `weights_only=True` and must hold a dict;
    pooling is delegated to `pool_embeddings`.
    """
    file_found = os.path.exists(save_path)
    assert file_found, f"Embedding file does not exist: {save_path}"
    stored = torch.load(save_path, map_location="cpu", weights_only=True)
    assert isinstance(stored, dict), "Expected .pth embeddings file to contain a dictionary."
    return pool_embeddings(
        stored,
        pooling_types=pooling_types,
        hidden_state_index=hidden_state_index,
    )
def load_pooled_embeddings_from_db(
    db_path: str,
    sequences: Optional[List[str]] = None,
    pooling_types: List[str] = ['mean'],
    hidden_state_index: int = -1,
) -> Dict[str, torch.Tensor]:
    """Load embeddings from the SQLite `embeddings` table and pool them.

    Args:
        db_path: Path to an existing SQLite database.
        sequences: Sequences to fetch; None loads every row, an empty list
            returns an empty dict without pooling.
        pooling_types: Pooling names forwarded to `pool_embeddings`.
        hidden_state_index: Layer index forwarded to `pool_embeddings`.

    Returns:
        Mapping from sequence to its pooled CPU tensor.
    """
    assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
    loaded: Dict[str, torch.Tensor] = {}
    if sequences is not None and len(sequences) == 0:
        return loaded

    # SQLite caps the number of bound parameters per statement, so large
    # requests are fetched in chunks instead of one giant IN (...) list.
    chunk_size = 900
    # `with sqlite3.connect(...)` only manages the transaction, not the handle,
    # so the connection is closed explicitly.
    conn = sqlite3.connect(db_path, timeout=30)
    try:
        cursor = conn.cursor()
        if sequences is None:
            cursor.execute("SELECT sequence, embedding FROM embeddings")
            rows = cursor.fetchall()
        else:
            rows = []
            wanted = list(sequences)
            for start in range(0, len(wanted), chunk_size):
                chunk = wanted[start:start + chunk_size]
                placeholders = ",".join(["?"] * len(chunk))
                cursor.execute(
                    f"SELECT sequence, embedding FROM embeddings WHERE sequence IN ({placeholders})",
                    tuple(chunk),
                )
                rows.extend(cursor.fetchall())
        for sequence, embedding_bytes in rows:
            loaded[sequence] = embedding_blob_to_tensor(embedding_bytes)
    finally:
        conn.close()
    return pool_embeddings(loaded, pooling_types=pooling_types, hidden_state_index=hidden_state_index)
def maybe_compile(model: torch.nn.Module, dynamic: bool = False) -> torch.nn.Module:
    """Return `model` wrapped by torch.compile when that is safe, else unchanged.

    Compilation is skipped for dynamic=True (padding='longest') because
    flex attention's create_block_mask is incompatible with dynamic shapes
    under torch.compile, causing CUDA illegal memory access. Any exception
    raised while compiling is reported and the model is returned as-is.
    """
    if not dynamic:
        try:
            model = torch.compile(model)
            print("Model compiled")
        except Exception as e:
            print(f"Skipping torch.compile: {e}")
        return model

    print("Skipping torch.compile (dynamic shapes + flex attention incompatible)")
    return model
def build_collator(
    tokenizer: PreTrainedTokenizerBase,
    padding: str = 'max_length',
    max_length: int = 512,
) -> Callable[[List[str]], Dict[str, torch.Tensor]]:
    """Build a DataLoader collate function that tokenizes a list of sequences.

    The tokenizer is always called with truncation and `return_tensors="pt"`.
    For any padding mode other than 'max_length', `pad_to_multiple_of=8` is
    passed as well.
    """
    # Decided once here; the collate function only forwards it.
    extra_kwargs: Dict[str, Any] = {}
    if padding != 'max_length':
        extra_kwargs['pad_to_multiple_of'] = 8

    def _collate_fn(sequences: List[str]) -> Dict[str, torch.Tensor]:
        return tokenizer(
            sequences,
            return_tensors="pt",
            padding=padding,
            truncation=True,
            max_length=max_length,
            **extra_kwargs,
        )

    return _collate_fn
def _make_embedding_progress(
    dataloader: DataLoader,
    padding: str,
    n_warmup: int = 3,
    n_calibration: int = 5,
) -> Iterator[Tuple[int, Any]]:
    """Progress-bar wrapper for embedding loops. Drop-in replacement for enumerate(dataloader).

    When padding='max_length', all batches have uniform cost so plain tqdm works.
    When padding='longest' (sorted longest-first), batch times vary dramatically.
    In that case: yield warmup batches first (compiler warmup + OOM check on longest
    sequences), then time mid-length calibration batches to estimate total ETA.

    Every batch is yielded exactly once, in dataloader order, together with its
    true index, whatever `n_warmup` / `n_calibration` are.

    Keep in sync with protify/embedder.py and core/atlas/precomputed.py.

    Args:
        dataloader: Sized iterable of batches.
        padding: Padding mode; 'max_length' selects the plain progress bar.
        n_warmup: Number of leading batches shown under the warmup bar.
        n_calibration: Number of mid-run batches timed for the ETA estimate.

    Yields:
        (batch_index, batch) tuples.
    """
    total = len(dataloader)
    n_warmup = max(n_warmup, 0)
    # Without calibration batches there is nothing to average, so use the plain bar.
    if padding == 'max_length' or n_calibration <= 0 or total <= n_warmup + n_calibration:
        for i, batch in tqdm(enumerate(dataloader), total=total, desc='Embedding batches'):
            yield i, batch
        return

    dl_iter = iter(dataloader)
    # Calibration normally starts mid-run, but it must never start before the
    # warmup ends nor run past the last batch; otherwise indices would repeat
    # or the iterator would be exhausted early.
    mid_start = min(max(total // 2, n_warmup), total - n_calibration)

    # Warm up on the longest batches first; sorted inputs make these the OOM-risk
    # and compile-stabilization cases.
    warmup_bar = tqdm(range(n_warmup), desc='Warmup (longest batches)', leave=False)
    try:
        for i in warmup_bar:
            batch = next(dl_iter)
            yield i, batch
    finally:
        # finally: also closes the bar when the consumer abandons the generator.
        warmup_bar.close()

    # Move toward mid-length batches for ETA calibration, yielding every real
    # batch on the way so no sequences are skipped.
    intermediate_bar = tqdm(
        range(n_warmup, mid_start), desc='Embedding batches', leave=False,
    )
    try:
        for i in intermediate_bar:
            batch = next(dl_iter)
            yield i, batch
    finally:
        intermediate_bar.close()

    # Mid-length batches give a better remaining-time estimate than the longest
    # warmup batches. The timer spans the yield, so it includes the consumer's
    # work on the batch.
    calibration_times: List[float] = []
    cal_bar = tqdm(range(n_calibration), desc='Calibrating ETA', leave=False)
    try:
        for j in cal_bar:
            t0 = time.perf_counter()
            batch = next(dl_iter)
            yield mid_start + j, batch
            calibration_times.append(time.perf_counter() - t0)
    finally:
        cal_bar.close()

    avg_time = sum(calibration_times) / len(calibration_times)
    remaining_start = mid_start + n_calibration
    remaining_count = total - remaining_start
    estimated_total_seconds = avg_time * remaining_count

    # Finish the tail with the calibrated ETA shown in the progress bar.
    main_bar = tqdm(
        range(remaining_count),
        desc='Embedding batches',
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]',
    )
    try:
        main_bar.set_postfix_str(f'ETA ~{estimated_total_seconds:.0f}s (calibrated)')
        for k in main_bar:
            batch = next(dl_iter)
            yield remaining_start + k, batch
    finally:
        main_bar.close()
class _SQLWriter:
    """Context manager for async SQL embedding writes. Matches core/embed/storage.SQLEmbeddingWriter.

    Batches passed to `write_batch` are queued and inserted into the
    `embeddings` table by a background thread. Leaving the context drains the
    queue and joins the thread. If the writer thread failed, its exception is
    re-raised to the caller instead of being lost.
    """

    def __init__(self, conn: sqlite3.Connection, queue_maxsize: int = 4) -> None:
        self._conn = conn
        # Bounded queue: producers block once `queue_maxsize` batches are pending.
        self._queue: queue.Queue = queue.Queue(maxsize=queue_maxsize)
        self._thread: Optional[threading.Thread] = None
        # First exception raised by the writer thread, if any.
        self._error: Optional[BaseException] = None

    def __enter__(self) -> "_SQLWriter":
        self._error = None
        self._thread = threading.Thread(target=self._writer_loop, daemon=True)
        self._thread.start()
        return self

    def write_batch(self, rows: List[Tuple[str, bytes]]) -> None:
        """Queue `(sequence, blob)` rows; raises the writer's error if it already failed."""
        if self._error is not None:
            raise self._error
        self._queue.put(rows)

    def _writer_loop(self) -> None:
        cursor = self._conn.cursor()
        while True:
            item = self._queue.get()
            if item is None:
                # Sentinel pushed by __exit__.
                break
            if self._error is not None:
                # Already failed: keep draining so producers never block forever
                # on a full queue, but stop touching the database.
                continue
            try:
                cursor.executemany("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", item)
                # Commit only when the queue is momentarily empty, so bursts of
                # batches share one transaction.
                if self._queue.qsize() == 0:
                    self._conn.commit()
            except Exception as exc:
                self._error = exc
        if self._error is None:
            try:
                self._conn.commit()
            except Exception as exc:
                self._error = exc

    def __exit__(self, *exc) -> None:
        if self._thread is not None:
            self._queue.put(None)
            self._thread.join()
            self._thread = None
        # Surface a writer failure unless the with-body is already raising.
        body_failed = bool(exc) and exc[0] is not None
        if self._error is not None and not body_failed:
            error, self._error = self._error, None
            raise error
class Pooler:
    """Reduce token-level embeddings to one vector per sample.

    Each name in `pooling_types` selects one pooling function. `__call__`
    applies them in order and concatenates the results along the last dim.
    Embeddings are indexed as (batch, tokens, hidden) throughout.
    """

    def __init__(self, pooling_types: List[str]) -> None:
        self.pooling_types = pooling_types
        self.pooling_options: Dict[str, Callable] = {
            'mean': self.mean_pooling,
            'max': self.max_pooling,
            'norm': self.norm_pooling,
            'median': self.median_pooling,
            'std': self.std_pooling,
            'var': self.var_pooling,
            'cls': self.cls_pooling,
            'parti': self._pool_parti,
        }

    def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor:
        """Collapse dim 1 of `attentions` by taking the element-wise maximum."""
        assert isinstance(attentions, torch.Tensor)
        maxed_attentions = torch.max(attentions, dim=1)[0]
        return maxed_attentions

    def _page_rank(self, attention_matrix: np.ndarray, personalization: Optional[dict] = None, nstart: Optional[dict] = None, prune_type: str = "top_k_outdegree") -> Dict[int, float]:
        """Run PageRank on the directed graph whose weighted adjacency is `attention_matrix`.

        `prune_type` is accepted for interface compatibility but is not used here.
        """
        G = self._convert_to_graph(attention_matrix)
        if G.number_of_nodes() != attention_matrix.shape[0]:
            raise Exception(
                f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.")
        if G.number_of_edges() == 0:
            raise Exception("You don't seem to have any attention edges left in the graph.")
        return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100)

    def _convert_to_graph(self, matrix: np.ndarray) -> nx.DiGraph:
        """Build a directed graph from an adjacency matrix."""
        G = nx.from_numpy_array(matrix, create_using=nx.DiGraph)
        return G

    def _calculate_importance_weights(self, dict_importance: Dict[int, float], attention_mask: Optional[torch.Tensor] = None) -> np.ndarray:
        """Normalize PageRank scores of unmasked tokens so they sum to 1.

        Entries whose mask value is 0 are removed from `dict_importance` in place.
        """
        if attention_mask is not None:
            for k in list(dict_importance.keys()):
                if attention_mask[k] == 0:
                    del dict_importance[k]
        total = sum(dict_importance.values())
        return np.array([v / total for v in dict_importance.values()])

    def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attention-weighted pooling using PageRank importance per token.

        Args:
            emb: Token embeddings indexed (batch, tokens, hidden).
            attentions: Attention tensor whose dim 1 is max-reduced before use.
            attention_mask: Optional (batch, tokens) mask; None keeps all tokens.

        Returns:
            Pooled embeddings on `emb`'s device and in `emb`'s dtype.
        """
        # PageRank runs on CPU numpy arrays, so move and detach everything first.
        maxed = self._create_pooled_matrices_across_layers(attentions).detach().cpu()
        if maxed.dtype not in (torch.float32, torch.float64):
            maxed = maxed.float()
        maxed_attentions = maxed.numpy()

        emb_cpu = emb.detach().cpu()
        if emb_cpu.dtype not in (torch.float32, torch.float64):
            # numpy cannot represent bfloat16, so upcast reduced-precision inputs.
            emb_cpu = emb_cpu.float()

        if attention_mask is None:
            masks = torch.ones(emb_cpu.shape[:2], dtype=torch.long)
        else:
            masks = attention_mask.detach().cpu()

        emb_pooled = []
        for e, a, mask in zip(emb_cpu, maxed_attentions, masks):
            dict_importance = self._page_rank(a)
            importance_weights = self._calculate_importance_weights(dict_importance, mask)
            # Select by mask rather than by prefix length so left-padded inputs
            # line up with the importance weights too.
            kept_tokens = e[mask.bool()].numpy()
            emb_pooled.append(np.average(kept_tokens, weights=importance_weights, axis=0))
        pooled = torch.tensor(np.array(emb_pooled))
        # Match the input so the result can be concatenated with other poolings.
        return pooled.to(device=emb.device, dtype=emb.dtype)

    def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        """Mean over tokens; masked positions are excluded when a mask is given."""
        if attention_mask is None:
            return emb.mean(dim=1)
        else:
            attention_mask = attention_mask.unsqueeze(-1)
            return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)

    def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        """Element-wise max over tokens; masked positions are filled with -inf first."""
        if attention_mask is None:
            return emb.max(dim=1).values
        else:
            mask = attention_mask.unsqueeze(-1).bool()
            return emb.masked_fill(~mask, float('-inf')).max(dim=1).values

    def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        """L2 norm over tokens; masked positions are zeroed first."""
        if attention_mask is None:
            return emb.norm(dim=1, p=2)
        else:
            attention_mask = attention_mask.unsqueeze(-1)
            return (emb * attention_mask).norm(dim=1, p=2)

    def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        """Median over tokens; masked positions become NaN and are skipped by nanmedian."""
        if attention_mask is None:
            return emb.median(dim=1).values
        else:
            mask = attention_mask.unsqueeze(-1).bool()
            return emb.masked_fill(~mask, float('nan')).nanmedian(dim=1).values

    def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        """Standard deviation over tokens.

        NOTE(review): the unmasked branch uses torch's default `std`, while the
        masked branch is the square root of `var_pooling`, which divides by the
        token count - confirm the two are meant to differ.
        """
        if attention_mask is None:
            return emb.std(dim=1)
        else:
            var = self.var_pooling(emb, attention_mask, **kwargs)
            return torch.sqrt(var)

    def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        """Variance over tokens; the masked branch divides by the number of unmasked tokens."""
        if attention_mask is None:
            return emb.var(dim=1)
        else:
            attention_mask = attention_mask.unsqueeze(-1)
            token_count = attention_mask.sum(dim=1)
            mean = (emb * attention_mask).sum(dim=1) / token_count
            mean = mean.unsqueeze(1)
            squared_diff = (emb - mean) ** 2
            var = (squared_diff * attention_mask).sum(dim=1) / token_count
            return var

    def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        """Return the embedding at token position 0."""
        return emb[:, 0, :]

    def __call__(
        self,
        emb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attentions: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply every configured pooling and concatenate along the last dim.

        Raises:
            AssertionError: If any sample's attention mask is all zeros.
            KeyError: If a configured pooling name is unknown.
        """
        if attention_mask is not None:
            assert attention_mask.sum(dim=-1).min() > 0, (
                "Pooler received samples with all-zero attention masks. "
                "This causes NaN from division by zero. Filter empty inputs before pooling."
            )
        final_emb: List[torch.Tensor] = []
        for pooling_type in self.pooling_types:
            final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions))
        return torch.cat(final_emb, dim=-1)
class ProteinDataset(TorchDataset):
    """Map-style dataset that serves raw protein sequence strings by index."""

    def __init__(self, sequences: List[str]) -> None:
        # The list is kept by reference, not copied.
        self.sequences = sequences

    def __len__(self) -> int:
        """Number of sequences held."""
        count = len(self.sequences)
        return count

    def __getitem__(self, idx: int) -> str:
        """Sequence stored at position `idx`."""
        item = self.sequences[idx]
        return item
def parse_fasta(fasta_path: str) -> List[str]:
    """Read a FASTA file and return its sequences in file order.

    Header lines (starting with '>') only separate records; their text is
    discarded. Wrapped sequence lines are joined, blank lines are ignored and
    headers without any sequence lines produce no entry.
    """
    assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}"
    records: List[str] = []
    pieces: List[str] = []

    def _flush() -> None:
        # Emit the record collected so far, if it has any residues.
        if pieces:
            records.append(''.join(pieces))
            pieces.clear()

    with open(fasta_path, 'r') as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text.startswith('>'):
                _flush()
            elif text:
                pieces.append(text)
    _flush()
    return records
class EmbeddingMixin:
    """Mixin that adds embedding, caching and loading helpers to a model class.

    The host class is expected to provide `parameters()` and
    `config.model_type`, a `tokenizer` attribute when used in tokenizer mode,
    and to override `_embed`.
    """

    # Max bound parameters per `IN (...)` query; SQLite caps host parameters
    # per statement, so larger requests are split into chunks of this size.
    _SQL_IN_CHUNK = 900

    def _embed(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        hidden_state_index: int = -1,
        store_all_hidden_states: bool = False,
    ) -> torch.Tensor:
        """Model-specific embedding forward pass; must be overridden by the host class."""
        raise NotImplementedError

    @property
    def device(self) -> torch.device:
        """Get the device of the model."""
        return next(self.parameters()).device

    def _read_sequences_from_db(self, db_path: str) -> Set[str]:
        """Read sequences from SQLite database."""
        # `with sqlite3.connect(...)` would not close the handle, so close it explicitly.
        conn = sqlite3.connect(db_path, timeout=30)
        try:
            c = conn.cursor()
            c.execute("SELECT sequence FROM embeddings")
            return {row[0] for row in c.fetchall()}
        finally:
            conn.close()

    def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None:
        """Create the `embeddings(sequence, embedding)` table if missing and commit."""
        cursor = conn.cursor()
        cursor.execute(
            "CREATE TABLE IF NOT EXISTS embeddings ("
            "sequence TEXT PRIMARY KEY, "
            "embedding BLOB NOT NULL"
            ")"
        )
        conn.commit()

    def load_embeddings_from_pth(self, save_path: str) -> Dict[str, torch.Tensor]:
        """Load a `.pth` embedding dict onto CPU and validate its key/value types."""
        assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}"
        payload = torch.load(save_path, map_location="cpu", weights_only=True)
        assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary."
        for sequence, tensor in payload.items():
            assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)."
            assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors."
        return payload

    def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
        """Load unpooled embeddings from the SQLite `embeddings` table.

        Args:
            db_path: Path to an existing SQLite database.
            sequences: Sequences to fetch; None loads every row, an empty list
                returns an empty dict.

        Returns:
            Mapping from sequence to its deserialized tensor.
        """
        assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
        loaded: Dict[str, torch.Tensor] = {}
        conn = sqlite3.connect(db_path, timeout=30)
        try:
            self._ensure_embeddings_table(conn)
            cursor = conn.cursor()
            if sequences is None:
                cursor.execute("SELECT sequence, embedding FROM embeddings")
                rows = cursor.fetchall()
            else:
                if len(sequences) == 0:
                    return loaded
                rows = []
                wanted = list(sequences)
                # Chunked so large requests stay under SQLite's parameter limit.
                for start in range(0, len(wanted), self._SQL_IN_CHUNK):
                    chunk = wanted[start:start + self._SQL_IN_CHUNK]
                    placeholders = ",".join(["?"] * len(chunk))
                    cursor.execute(
                        f"SELECT sequence, embedding FROM embeddings WHERE sequence IN ({placeholders})",
                        tuple(chunk),
                    )
                    rows.extend(cursor.fetchall())
            for sequence, embedding_bytes in rows:
                loaded[sequence] = embedding_blob_to_tensor(embedding_bytes)
        finally:
            conn.close()
        return loaded

    def pool_embeddings(
        self,
        embeddings: Dict[str, torch.Tensor],
        pooling_types: List[str] = ['mean'],
        hidden_state_index: int = -1,
    ) -> Dict[str, torch.Tensor]:
        """Method form of the module-level `pool_embeddings`."""
        return pool_embeddings(embeddings, pooling_types=pooling_types, hidden_state_index=hidden_state_index)

    def load_pooled_embeddings_from_pth(
        self,
        save_path: str,
        pooling_types: List[str] = ['mean'],
        hidden_state_index: int = -1,
    ) -> Dict[str, torch.Tensor]:
        """Method form of the module-level `load_pooled_embeddings_from_pth`."""
        return load_pooled_embeddings_from_pth(
            save_path,
            pooling_types=pooling_types,
            hidden_state_index=hidden_state_index,
        )

    def load_pooled_embeddings_from_db(
        self,
        db_path: str,
        sequences: Optional[List[str]] = None,
        pooling_types: List[str] = ['mean'],
        hidden_state_index: int = -1,
    ) -> Dict[str, torch.Tensor]:
        """Method form of the module-level `load_pooled_embeddings_from_db`."""
        return load_pooled_embeddings_from_db(
            db_path,
            sequences=sequences,
            pooling_types=pooling_types,
            hidden_state_index=hidden_state_index,
        )

    def embed_dataset(
        self,
        sequences: Optional[List[str]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        batch_size: int = 2,
        max_len: int = 512,
        truncate: bool = True,
        full_embeddings: bool = False,
        embed_dtype: torch.dtype = torch.float32,
        pooling_types: List[str] = ['mean'],
        num_workers: int = 0,
        sql: bool = False,
        save: bool = True,
        sql_db_path: str = 'embeddings.db',
        save_path: str = 'embeddings.pth',
        fasta_path: Optional[str] = None,
        padding: str = 'max_length',
        hidden_state_index: int = -1,
        store_all_hidden_states: bool = False,
        **kwargs,
    ) -> Optional[Dict[str, torch.Tensor]]:
        """
        Embed a dataset of protein sequences.

        Supports two modes:
        - Tokenizer mode (ESM2/ESM++): provide `tokenizer` or use `self.tokenizer`.
        - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.

        Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via
        `fasta_path`, or both (the two sources are combined). At least one must be provided.

        Sequences are optionally truncated to `max_len`, de-duplicated and sorted
        longest-first. Sequences already present in the target store (SQLite table
        or existing `.pth` file) are skipped.

        Returns:
            None when `sql=True` (results live in `sql_db_path`); otherwise the
            dict of sequence -> CPU embedding, also saved to `save_path` when
            `save` is set.
        """
        if fasta_path is not None:
            fasta_sequences = parse_fasta(fasta_path)
            sequences = list(sequences or []) + fasta_sequences
        assert sequences is not None and len(sequences) > 0, \
            "Must provide at least one sequence via `sequences` or `fasta_path`."
        assert isinstance(hidden_state_index, int), "hidden_state_index must be an integer."
        assert full_embeddings or not store_all_hidden_states, \
            "store_all_hidden_states=True requires full_embeddings=True."
        sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
        sequences = sorted(sequences, key=len, reverse=True)
        pooler = Pooler(pooling_types) if not full_embeddings else None
        if tokenizer is None and self.config.model_type != "E1":
            tokenizer = self.tokenizer
        tokenizer_mode = tokenizer is not None

        # Resolve padding and compilation
        dynamic = padding == 'longest'
        compiled_model = maybe_compile(self, dynamic=dynamic)
        if tokenizer_mode:
            collate_fn = build_collator(tokenizer, padding=padding, max_length=max_len)
            device = self.device
        else:
            collate_fn = None
            device = None

        def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
            # 2-dim outputs are treated as already pooled; full embeddings are kept token-wise.
            assert isinstance(residue_embeddings, torch.Tensor)
            if full_embeddings or residue_embeddings.ndim == 2:
                return residue_embeddings
            return pooler(residue_embeddings, attention_mask)

        def iter_batches(to_embed: List[str]):
            # Yields (sequences, embeddings, attention_mask) per batch.
            if tokenizer_mode:
                assert collate_fn is not None
                assert device is not None
                dataset = ProteinDataset(to_embed)
                dataloader = DataLoader(
                    dataset,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    prefetch_factor=2 if num_workers > 0 else None,
                    collate_fn=collate_fn,
                    shuffle=False,
                    pin_memory=True,
                )
                for i, batch in _make_embedding_progress(dataloader, padding):
                    # shuffle=False, so batch i covers this slice of `to_embed`.
                    seqs = to_embed[i * batch_size:(i + 1) * batch_size]
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    residue_embeddings = compiled_model._embed(
                        input_ids,
                        attention_mask,
                        hidden_state_index=hidden_state_index,
                        store_all_hidden_states=store_all_hidden_states,
                    )
                    yield seqs, residue_embeddings, attention_mask
            else:
                for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'):
                    seqs = to_embed[batch_start:batch_start + batch_size]
                    batch_output = compiled_model._embed(
                        seqs,
                        return_attention_mask=True,
                        hidden_state_index=hidden_state_index,
                        store_all_hidden_states=store_all_hidden_states,
                        **kwargs,
                    )
                    assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)."
                    assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values."
                    residue_embeddings, attention_mask = batch_output
                    assert isinstance(attention_mask, torch.Tensor), "Sequence mode _embed must return attention_mask as a torch.Tensor."
                    yield seqs, residue_embeddings, attention_mask

        if sql:
            # Resume safely: skip sequences already present in the SQLite table.
            conn = sqlite3.connect(sql_db_path, timeout=30, check_same_thread=False)
            try:
                conn.execute('PRAGMA journal_mode=WAL')
                conn.execute('PRAGMA busy_timeout=30000')
                conn.execute('PRAGMA synchronous=OFF')
                conn.execute('PRAGMA cache_size=-64000')
                self._ensure_embeddings_table(conn)
                already_embedded = self._read_sequences_from_db(sql_db_path)
                to_embed = [seq for seq in sequences if seq not in already_embedded]
                print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
                print(f"Embedding {len(to_embed)} new sequences")
                if len(to_embed) > 0:
                    # Embed batches synchronously; serialize/write them on the SQL writer thread.
                    with _SQLWriter(conn) as writer:
                        with torch.inference_mode():
                            for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
                                embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
                                if full_embeddings:
                                    batch_rows = []
                                    for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                                        batch_rows.append((seq, tensor_to_embedding_blob(_trim_full_embedding(emb, mask))))
                                else:
                                    blobs = batch_tensor_to_blobs(embeddings)
                                    batch_rows = list(zip(seqs, blobs))
                                writer.write_batch(batch_rows)
            finally:
                # Close the handle even when embedding fails part-way through.
                conn.close()
            return None

        embeddings_dict = {}
        if os.path.exists(save_path):
            embeddings_dict = self.load_embeddings_from_pth(save_path)
            to_embed = [seq for seq in sequences if seq not in embeddings_dict]
            print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}")
            print(f"Embedding {len(to_embed)} new sequences")
        else:
            to_embed = sequences
            print(f"Embedding {len(to_embed)} new sequences")

        if len(to_embed) > 0:
            with torch.inference_mode():
                for seqs, residue_embeddings, attention_mask in iter_batches(to_embed):
                    embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype)
                    for seq, emb, mask in zip(seqs, embeddings, attention_mask):
                        if full_embeddings:
                            emb = _trim_full_embedding(emb, mask)
                        embeddings_dict[seq] = emb.cpu()

        if save:
            torch.save(embeddings_dict, save_path)

        return embeddings_dict
if __name__ == "__main__":
    # Manual smoke test for pooling shape behavior.
    # Runs 'max' and 'parti' pooling on random inputs and prints the output shape.
    pooler = Pooler(pooling_types=['max', 'parti'])
    batch_size = 8
    seq_len = 64
    hidden_size = 128
    num_layers = 12
    emb = torch.randn(batch_size, seq_len, hidden_size)
    # NOTE(review): randn values can be negative; confirm PageRank in the
    # 'parti' path tolerates negative edge weights.
    attentions = torch.randn(batch_size, num_layers, seq_len, seq_len)
    # All-ones mask: no position is treated as padding.
    attention_mask = torch.ones(batch_size, seq_len)
    y = pooler(emb=emb, attention_mask=attention_mask, attentions=attentions)
    print(y.shape)
"""Shared attention infrastructure for all FastPLMs models.
Contains: AttentionBackend enum, backend resolution, mask creation,
flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
"""
from enum import Enum
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange
try:
from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask
except ImportError:
create_block_mask = None
flex_attention = None
BlockMask = None
# Lazily-built torch.compile wrapper around flex_attention, shared module-wide.
_compiled_flex_attention = None


def _get_flex_attention_fn():
    """Return the flex_attention callable to use, or None if it is unavailable.

    The compiled wrapper is built on first use and cached. When the
    `_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG` flag is set on
    `torch.nn.attention.flex_attention`, the eager function is returned.
    """
    global _compiled_flex_attention
    if flex_attention is None:
        return None

    debug_eager = getattr(
        torch.nn.attention.flex_attention,
        "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG",
        False,
    )
    if debug_eager:
        return flex_attention

    if _compiled_flex_attention is None:
        _compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
    return _compiled_flex_attention
# HuggingFace `kernels` exposes slightly different APIs for Flash Attention 2
# and 3. Detect the loaded variant once so every caller uses the same dispatch.
def _infer_kernels_flash_variant(kernel) -> Optional[str]:
    """Classify a loaded flash-attention kernel by the entry points it exposes.

    Returns "flash_attn2", "flash_attn3", or None when neither API is complete.
    """
    # Checked in order; the first fully matching API wins.
    known_apis = (
        ("flash_attn2", ("fwd", "varlen_fwd")),
        ("flash_attn3", ("flash_attn_func", "flash_attn_varlen_func")),
    )
    for variant, required_attrs in known_apis:
        if all(hasattr(kernel, attr) for attr in required_attrs):
            return variant
    return None
def _try_get_kernels_flash():
    """Try to load a flash-attention kernel through the `kernels` package.

    flash-attn3 is tried first, then flash-attn2. A kernel is accepted only
    if `_infer_kernels_flash_variant` recognises its API.

    Returns:
        (kernel, variant) on success, or (None, None) when `kernels` is not
        installed or no usable kernel could be loaded.
    """
    try:
        from kernels import get_kernel
    except ImportError:
        return None, None

    for repo_id in ("kernels-community/flash-attn3", "kernels-community/flash-attn2"):
        try:
            candidate = get_kernel(repo_id)
            variant = _infer_kernels_flash_variant(candidate)
        except Exception:
            # Best-effort: any load failure just moves on to the next candidate.
            continue
        # Explicit check rather than `assert`, which is stripped under `python -O`
        # and would let an unsupported kernel through with variant None.
        if variant is not None:
            return candidate, variant
    return None, None
# Module-wide flash-kernel state, filled in by _ensure_flash_kernels_loaded().
_FLASH_KERNELS_LOADED = False
FLASH_KERNEL = None
FLASH_KERNEL_VARIANT = None


def _ensure_flash_kernels_loaded():
    """Resolve the flash-attention kernel once; later calls are no-ops.

    The loaded flag is set before the attempt, so a failed load is not retried.
    """
    global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
    if not _FLASH_KERNELS_LOADED:
        _FLASH_KERNELS_LOADED = True
        resolved = _try_get_kernels_flash()
        FLASH_KERNEL, FLASH_KERNEL_VARIANT = resolved
def _kernels_flash_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    """Dispatch a dense flash-attention forward to the loaded kernel variant.

    With `softmax_scale=None` the flash kernel uses its default
    `1 / sqrt(head_dim)`. Callers that pre-scale Q (the convention used by
    ESM2, DPLM, DPLM2, E1, ESMFold) must pass `softmax_scale=1.0`; otherwise
    the scale is applied twice. On DPLM-150M that produced pooled-embedding
    cosine around -0.12 and argmax agreement around 0.27 vs SDPA.

    Raises:
        AssertionError: If no kernel is loaded or the variant is unsupported.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    variant = FLASH_KERNEL_VARIANT

    if variant == "flash_attn2":
        fa2_outputs = FLASH_KERNEL.fwd(
            q=query_states, k=key_states, v=value_states,
            softmax_scale=softmax_scale, is_causal=causal,
        )
        return fa2_outputs[0]

    if variant != "flash_attn3":
        raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")

    try:
        output = FLASH_KERNEL.flash_attn_func(
            q=query_states, k=key_states, v=value_states,
            softmax_scale=softmax_scale, causal=causal,
        )
    except TypeError:
        # Keyword call rejected: retry positionally (the 0.0 is presumably the
        # dropout probability slot - TODO confirm against the kernel signature).
        output = FLASH_KERNEL.flash_attn_func(
            query_states, key_states, value_states,
            0.0, softmax_scale, causal,
        )
    # Some builds return a tuple; the attention output is its first element.
    return output[0] if isinstance(output, tuple) else output
def _kernels_flash_varlen_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_in_batch_q: int,
    max_seqlen_in_batch_k: int,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    """Run the loaded flash-attention kernel on packed variable-length sequences.

    The cumulative sequence lengths and per-batch maxima are forwarded to the
    kernel unchanged. As with `_kernels_flash_forward`, pass
    `softmax_scale=1.0` when Q has been pre-scaled by the caller so the scale
    is not applied twice.

    Returns the attention output only; extra kernel return values are dropped.

    Raises:
        AssertionError: if no kernel is loaded or the variant tag is unknown.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    variant = FLASH_KERNEL_VARIANT

    if variant == "flash_attn2":
        fa2_result = FLASH_KERNEL.varlen_fwd(
            q=query_states,
            k=key_states,
            v=value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            softmax_scale=softmax_scale,
            is_causal=causal,
        )
        return fa2_result[0]

    if variant != "flash_attn3":
        raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")

    try:
        # Preferred call: keyword arguments.
        output = FLASH_KERNEL.flash_attn_varlen_func(
            q=query_states,
            k=key_states,
            v=value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            softmax_scale=softmax_scale,
            causal=causal,
        )
    except TypeError:
        # Fallback for builds that reject the keyword form: positional call
        # with 0.0 in the slot that precedes softmax_scale.
        output = FLASH_KERNEL.flash_attn_varlen_func(
            query_states, key_states, value_states,
            cu_seqlens_q, cu_seqlens_k,
            max_seqlen_in_batch_q, max_seqlen_in_batch_k,
            0.0, softmax_scale, causal,
        )
    return output[0] if isinstance(output, tuple) else output
# Varlen flash attention runs only on real tokens. These helpers remove padding
# before the kernel call and restore the original padded batch shape afterward.
class IndexFirstAxis(torch.autograd.Function):
    """Autograd function that selects rows of `input` along dim 0 by `indices`.

    The forward pass gathers the requested rows; the backward pass scatters the
    incoming gradient back into a zero tensor shaped like the original input.
    """

    @staticmethod
    def forward(ctx, input, indices) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        trailing_shape = input.shape[1:]
        # Remember the original row count so backward can rebuild the shape.
        ctx.first_axis_dim = input.shape[0]
        flat_width = trailing_shape.numel()
        flat_input = rearrange(input, "b ... -> b (...)")
        gather_index = indices.unsqueeze(1).expand(-1, flat_width)
        selected = torch.gather(flat_input, 0, gather_index)
        return selected.reshape(-1, *trailing_shape)

    @staticmethod
    def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]:
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        trailing_shape = grad_output.shape[1:]
        flat_grad = rearrange(grad_output, "b ... -> b (...)")
        flat_width = flat_grad.shape[1]
        grad_input = torch.zeros(
            [ctx.first_axis_dim, flat_width],
            device=flat_grad.device,
            dtype=flat_grad.dtype,
        )
        # Rows that were never selected keep a zero gradient.
        grad_input.scatter_(0, indices.unsqueeze(1).expand(-1, flat_width), flat_grad)
        # `indices` receives no gradient.
        return grad_input.reshape(ctx.first_axis_dim, *trailing_shape), None
class IndexPutFirstAxis(torch.autograd.Function):
    """Autograd function that writes `values` into rows `indices` of a zero tensor.

    The output has `first_axis_dim` rows and the trailing shape of `values`.
    The backward pass reads the gradient back out at the same rows.
    """

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        padded_shape = (first_axis_dim, *values.shape[1:])
        padded = torch.zeros(padded_shape, device=values.device, dtype=values.dtype)
        padded[indices] = values
        return padded

    @staticmethod
    def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
        (indices,) = ctx.saved_tensors
        grad_values = grad_output[indices]
        # Neither `indices` nor `first_axis_dim` is differentiable.
        return grad_values, None, None
# Function-style entry points for the autograd Functions above, so callers can
# write `index_first_axis(x, idx)` instead of `IndexFirstAxis.apply(x, idx)`.
index_first_axis = IndexFirstAxis.apply
index_put_first_axis = IndexPutFirstAxis.apply
def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    """Scatter packed rows back into a zero-filled `(batch, seqlen, ...)` tensor.

    `indices` are positions in the flattened `batch * seqlen` axis; positions
    not listed stay zero.
    """
    total_positions = batch * seqlen
    flat_padded = index_put_first_axis(hidden_states, indices, total_positions)
    return rearrange(flat_padded, "(b s) ... -> b s ...", b=batch)
def _unpad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask_2d: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
    """Drop padded positions from Q/K/V and build varlen-kernel metadata.

    The `(batch, seq, heads, head_dim)` layout is taken from `query_layer`;
    the same flattened shape is applied to K and V. The same mask is used for
    queries and keys, so the cumulative lengths and the maximum length are each
    returned twice.

    Returns:
        (q, k, v, indices, (cu_seqlens, cu_seqlens), (max_seqlen, max_seqlen))
        where q/k/v hold only the positions selected by the mask, `indices`
        are their positions in the flattened `batch * seq` axis, and
        `cu_seqlens` is the int32 cumulative length vector with a leading 0.
    """
    batch_size, seq_len, num_heads, head_dim = query_layer.shape

    tokens_per_row = attention_mask_2d.sum(dim=1).int()
    cu_seqlens = F.pad(tokens_per_row.cumsum(0, dtype=torch.int32), (1, 0))
    # .item() forces a host sync; the kernel needs the maximum as a Python int.
    longest = int(tokens_per_row.max().item())
    keep = attention_mask_2d.flatten().nonzero(as_tuple=False).flatten()

    flat_shape = (batch_size * seq_len, num_heads, head_dim)
    packed_q, packed_k, packed_v = (
        index_first_axis(layer.reshape(flat_shape), keep)
        for layer in (query_layer, key_layer, value_layer)
    )
    return packed_q, packed_k, packed_v, keep, (cu_seqlens, cu_seqlens), (longest, longest)
def kernels_flash_attention_func(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask_2d: Optional[torch.Tensor] = None,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    """Public flash-attention entry point with optional padding handling.

    Dispatch:
        * non-causal with a padding mask -> padding is stripped, the varlen
          kernel runs on the packed tokens, and the output is re-padded to
          the query's `(batch, seq)` leading shape;
        * otherwise -> the dense kernel runs on the inputs as given. Note that
          with `causal=True` a supplied `attention_mask_2d` is not used.

    `softmax_scale`:
        None  -> kernel applies its default `1 / sqrt(head_dim)`.
        float -> kernel uses the given scale.

    Caller contract: model families that pre-scale Q by `1/sqrt(head_dim)`
    before calling this function (ESM2, DPLM, DPLM2, E1, and ESMFold do) must
    pass `softmax_scale=1.0`. Otherwise the kernel applies its default scale
    again, yielding an effective `1/head_dim` scale that drifts across layers.

    Raises:
        AssertionError: if the flash kernel is not loaded.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."

    use_varlen = (not causal) and attention_mask_2d is not None
    if not use_varlen:
        return _kernels_flash_forward(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            causal=causal,
            softmax_scale=softmax_scale,
        )

    batch_size, q_len = query_states.shape[:2]
    unpadded = _unpad_input(query_states, key_states, value_states, attention_mask_2d)
    packed_q, packed_k, packed_v, indices_q, cu_seqlens, max_seqlens = unpadded
    cu_seqlens_q, cu_seqlens_k = cu_seqlens
    max_seqlen_q, max_seqlen_k = max_seqlens

    attn_output_unpad = _kernels_flash_varlen_forward(
        query_states=packed_q,
        key_states=packed_k,
        value_states=packed_v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_in_batch_q=max_seqlen_q,
        max_seqlen_in_batch_k=max_seqlen_k,
        softmax_scale=softmax_scale,
    )
    return pad_input(attn_output_unpad, indices_q, batch_size, q_len)
# User-facing backend strings resolve to this enum before attention dispatch.
class AttentionBackend(Enum):
    """Attention implementations selectable by name; each value is the user-facing string."""

    # Placeholder resolved at runtime by `resolve_attention_backend`
    # (prefers kernels flash, then flex, then SDPA).
    AUTO = "auto"
    KERNELS_FLASH = "kernels_flash"
    FLEX = "flex"
    SDPA = "sdpa"


# Accepted backend strings, in enum definition order.
VALID_ATTENTION_BACKENDS = tuple(b.value for b in AttentionBackend)
# Set once the resolved backend has been printed, so the message appears only
# on the first resolution in the process.
_BACKEND_CONFIRMED = False


def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
    """Map a user-facing backend string to a concrete `AttentionBackend`.

    "auto" picks the first available of kernels flash, flex, SDPA. Explicit
    "kernels_flash" / "flex" requests fail if that implementation is missing.
    The flash kernel is only loaded when "auto" or "kernels_flash" is asked for.

    Raises:
        AssertionError: for an unknown string, or when an explicitly requested
            backend is unavailable.
    """
    global _BACKEND_CONFIRMED
    assert requested_backend in VALID_ATTENTION_BACKENDS, (
        f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
    )

    wants_auto = requested_backend == AttentionBackend.AUTO.value
    wants_flash = requested_backend == AttentionBackend.KERNELS_FLASH.value
    if wants_auto or wants_flash:
        _ensure_flash_kernels_loaded()

    if wants_auto:
        if FLASH_KERNEL is not None:
            resolved = AttentionBackend.KERNELS_FLASH
        elif flex_attention is not None:
            resolved = AttentionBackend.FLEX
        else:
            resolved = AttentionBackend.SDPA
    elif wants_flash:
        assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment."
        resolved = AttentionBackend.KERNELS_FLASH
    elif requested_backend == AttentionBackend.FLEX.value:
        assert flex_attention is not None, "Flex Attention is not available in this environment."
        resolved = AttentionBackend.FLEX
    elif requested_backend == AttentionBackend.SDPA.value:
        resolved = AttentionBackend.SDPA
    else:
        raise AssertionError(f"Unsupported attention backend: {requested_backend}")

    if not _BACKEND_CONFIRMED:
        print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'")
        _BACKEND_CONFIRMED = True
    return resolved
@torch.compiler.disable
def get_attention_mask(
    effective_backend: AttentionBackend,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]:
    """Build the padding mask in the form the chosen backend consumes.

    Intended to be called once per forward pass and shared by all layers.

    Returns:
        `(attention_mask_2d, attention_mask_4d, flex_block_mask)`. All three
        are None when no mask is given. Otherwise the 2D bool mask is always
        returned, plus a 4D key mask for SDPA/manual or a block mask for flex.
    """
    if attention_mask is None:
        return None, None, None

    padding_mask = attention_mask.bool()

    if effective_backend == AttentionBackend.KERNELS_FLASH:
        # The flash path unpads from the 2D mask directly.
        return padding_mask, None, None

    if effective_backend != AttentionBackend.FLEX:
        # SDPA/manual masks only keys. Padding queries still attend to real keys, so
        # their outputs stay finite instead of softmaxing over all -inf scores.
        key_mask_4d = padding_mask[:, None, None, :]
        return padding_mask, key_mask_4d, None

    assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
    valid_lens = padding_mask.sum(dim=-1)

    def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
        # Compares positions against the per-row valid count, so valid tokens
        # are treated as a prefix of each row (i.e. right padding).
        row_len = valid_lens[batch_idx]
        return (q_idx < row_len) & (kv_idx < row_len)

    flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
    return padding_mask, None, flex_block_mask
def bool_to_additive_mask(
    bool_mask: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Turn a bool validity mask into an additive mask of the given float dtype.

    True (valid) positions become 0.0 and False positions become -inf, so the
    result can be added to attention scores.

    Why a helper: filling a *bool* tensor with `-inf` via `masked_fill` yields
    another bool tensor, because `-inf` casts to `True`, which silently drops
    the mask. The float tensor must be allocated first and then filled; this
    function is the sanctioned way to do that for SDPA additive masks.

    Raises:
        AssertionError: if `bool_mask` is not a bool tensor.
    """
    assert bool_mask.dtype == torch.bool, (
        f"bool_to_additive_mask requires a bool tensor, got dtype={bool_mask.dtype}"
    )
    invalid = ~bool_mask
    zeros = torch.zeros_like(bool_mask, dtype=dtype)
    return zeros.masked_fill_(invalid, float("-inf"))
import typing as T
from dataclasses import dataclass, fields
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class TTTConfig:
    """Hyperparameters for LoRA-based test-time training (TTT).

    Construction does not validate anything; call `verify()` to check ranges.
    """

    # Optimisation schedule.
    lr: float = 4e-4
    steps: int = 30
    ags: int = 16  # gradient accumulation steps per optimizer step
    batch_size: int = 2
    # Masked-LM corruption: fraction of eligible tokens selected, and how the
    # selected tokens are split between "leave as is", "random replace" and
    # "mask token" (the remainder).
    mask_ratio: float = 0.15
    crop_size: int = 1024
    bert_leave_prob: float = 0.1
    bert_replace_prob: float = 0.1
    # Optimizer selection and its knobs.
    optimizer: str = "sgd"
    momentum: float = 0.0
    weight_decay: float = 0.0
    seed: int | None = 0
    # LoRA adapter shape and placement.
    lora_rank: int = 8
    lora_alpha: float = 32.0
    lora_target_replace_module: str | None = None
    lora_target_modules: tuple[str, ...] | None = None
    # Run-time behaviour switches.
    initial_state_reset: bool = True
    automatic_best_state_reset: bool = False
    eval_each_step: bool = False
    gradient_clip: bool = False
    gradient_clip_max_norm: float = 1.0

    @classmethod
    def from_kwargs(cls, **kwargs: T.Any) -> "TTTConfig":
        """Build a config from keyword arguments, rejecting unknown field names."""
        known = {spec.name for spec in fields(cls)}
        unknown_names = [name for name in kwargs if name not in known]
        assert not unknown_names, f"Unknown TTTConfig fields: {sorted(unknown_names)}"
        return cls(**kwargs)

    def merged(self, overrides: T.Mapping[str, T.Any] | "TTTConfig" | None) -> "TTTConfig":
        """Combine this config with `overrides`.

        * None -> returns `self` unchanged (same object).
        * a `TTTConfig` -> returns that object as is; it replaces this one.
        * a mapping -> returns a new config with the listed fields replaced;
          an unknown key raises AssertionError.
        """
        if overrides is None:
            return self
        if isinstance(overrides, TTTConfig):
            return overrides
        values = {spec.name: self.__dict__[spec.name] for spec in fields(self)}
        for key, new_value in overrides.items():
            assert key in values, f"Unknown TTTConfig field: {key}"
            values[key] = new_value
        return TTTConfig(**values)

    def verify(self) -> None:
        """Assert that every field is inside its supported range."""
        assert self.lr > 0.0, "TTT learning rate must be positive."
        assert self.steps >= 1, "TTT steps must be >= 1."
        assert self.ags >= 1, "TTT gradient accumulation steps must be >= 1."
        assert self.batch_size >= 1, "TTT batch_size must be >= 1."
        assert 0.0 < self.mask_ratio <= 1.0, "TTT mask_ratio must be in (0, 1]."
        assert self.crop_size >= 1, "TTT crop_size must be >= 1."
        assert self.lora_rank >= 1, "TTT v1 is LoRA-only, so lora_rank must be >= 1."
        assert self.lora_alpha > 0.0, "TTT lora_alpha must be positive."
        assert self.optimizer in {"adamw", "sgd"}, "TTT optimizer must be 'adamw' or 'sgd'."
        assert 0.0 <= self.bert_leave_prob <= 1.0, "bert_leave_prob must be in [0, 1]."
        assert 0.0 <= self.bert_replace_prob <= 1.0, "bert_replace_prob must be in [0, 1]."
        assert self.bert_leave_prob + self.bert_replace_prob <= 1.0, (
            "bert_leave_prob + bert_replace_prob must be <= 1."
        )
        # The clipping norm is only checked when clipping is enabled.
        assert not self.gradient_clip or self.gradient_clip_max_norm > 0.0, (
            "gradient_clip_max_norm must be positive."
        )
class LoraInjectedLinear(nn.Module):
    """Wrap a linear-like module with a trainable low-rank (LoRA) residual.

    The wrapped module is frozen. The adapter computes
    `lora_up(lora_down(x)) * alpha` in float32 and adds it to the wrapped
    module's output. `lora_up` starts at zero, so the wrapper initially
    reproduces the wrapped module exactly.

    Args:
        linear: module owning a 2D `weight` parameter of shape
            `(out_features, in_features)`.
        rank: inner dimension of the adapter.
        alpha: multiplier applied to the adapter output. It is used directly
            as the scale (it is not divided by `rank`).

    Raises:
        AssertionError: if the wrapped weight is not 2D.
    """

    def __init__(self, linear: nn.Module, rank: int, alpha: float) -> None:
        super().__init__()
        weight = linear._parameters["weight"]
        assert weight.ndim == 2, "LoRA can only wrap 2D linear weights."
        self.linear = linear
        self.linear.requires_grad_(False)
        self.rank = rank
        self.scale = alpha
        in_features = weight.shape[1]
        out_features = weight.shape[0]
        # Adapters are kept in float32 regardless of the wrapped module's dtype.
        self.lora_down = nn.Linear(in_features, rank, bias=False, dtype=torch.float32)
        self.lora_up = nn.Linear(rank, out_features, bias=False, dtype=torch.float32)
        self.lora_down.to(device=weight.device)
        self.lora_up.to(device=weight.device)
        nn.init.normal_(self.lora_down.weight, std=1.0 / rank)
        # Zero up-projection -> the adapter contributes nothing until trained.
        nn.init.zeros_(self.lora_up.weight)

    @property
    def weight(self) -> torch.Tensor:
        """Weight of the wrapped module, exposed so the wrapper looks like a linear layer."""
        return self.linear._parameters["weight"]

    @property
    def bias(self) -> torch.Tensor | None:
        """Bias of the wrapped module, or None when it has none.

        Uses `.get` because linear-like modules that never register a `bias`
        parameter have no such key at all (unlike `nn.Linear(bias=False)`,
        which registers it as None).
        """
        return self.linear._parameters.get("bias")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base = self.linear(x)
        # Run the adapter in float32, then cast back to the base output dtype.
        delta = self.lora_up(self.lora_down(x.to(dtype=torch.float32))) * self.scale
        return base + delta.to(dtype=base.dtype)
class FastPLMTestTimeTrainingMixin:
    """Mixin adding LoRA-based test-time training (TTT) with a masked-LM objective.

    The methods below rely on the host class providing `self.tokenizer`,
    `self.config.vocab_size`, the `nn.Module` API (`modules()`, `parameters()`,
    `train()`), and a `__call__` whose result has a `.logits` attribute.
    Methods prefixed `_ttt_` are hooks that subclasses may override.
    """

    def init_ttt(self, ttt_config: TTTConfig | T.Mapping[str, T.Any] | None = None) -> None:
        """Store a verified TTT config and mark LoRA as not yet injected.

        NOTE(review): this resets `_ttt_initialized`, so calling it after LoRA
        has already been injected makes the next `_ttt_ensure_initialized` run
        injection again, which descends into the existing wrappers' children.
        Confirm callers only invoke this before the first `ttt()`.
        """
        base_config = TTTConfig()
        self._ttt_cfg = base_config.merged(ttt_config)
        self._ttt_cfg.verify()
        self._ttt_initialized = False
        self._ttt_initial_state: list[dict[str, torch.Tensor]] | None = None

    @property
    def ttt_config(self) -> TTTConfig:
        """Active TTT config; a default one is created on first access."""
        if "_ttt_cfg" not in self.__dict__:
            self.init_ttt()
        return self._ttt_cfg

    def _ttt_get_trainable_modules(self) -> list[nn.Module]:
        """Roots searched for LoRA targets. Default: the whole model."""
        return [self]

    def _ttt_get_frozen_modules(self) -> list[nn.Module]:
        """Hook for modules to keep frozen. Default: none."""
        return []

    def _ttt_tokenize(
        self,
        seq: str | list[str] | None = None,
        input_ids: torch.Tensor | None = None,
        **kwargs: T.Any,
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """Return `input_ids` if given, else tokenize `seq` with padding."""
        del kwargs
        if input_ids is not None:
            return input_ids
        assert seq is not None, "Pass either seq or input_ids for TTT."
        tokenized = self.tokenizer(seq, return_tensors="pt", padding=True)
        return tokenized["input_ids"]

    def _ttt_mask_token(self) -> int:
        """Id written at positions chosen for masking."""
        return int(self.tokenizer.mask_token_id)

    def _ttt_padding_token(self) -> int:
        """Id treated as padding when building attention masks."""
        return int(self.tokenizer.pad_token_id)

    def _ttt_replacement_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """All non-special vocabulary ids, on the device/dtype of `input_ids`."""
        tokenizer = self.tokenizer
        special_ids = set(tokenizer.all_special_ids)
        vocab_size = int(self.config.vocab_size)
        ids = [idx for idx in range(vocab_size) if idx not in special_ids]
        assert len(ids) > 0, "TTT replacement token set is empty."
        return torch.tensor(ids, device=input_ids.device, dtype=input_ids.dtype)

    def _ttt_predict_logits(
        self,
        batch: torch.Tensor | dict[str, torch.Tensor],
        **kwargs: T.Any,
    ) -> torch.Tensor:
        """Run the model on a batch and return its logits.

        A dict batch is forwarded as keyword arguments. A bare id tensor gets
        an attention mask derived from the padding id.
        """
        del kwargs
        if isinstance(batch, dict):
            output = self(**batch)
            return output.logits
        attention_mask = batch.ne(self._ttt_padding_token())
        output = self(input_ids=batch, attention_mask=attention_mask)
        return output.logits

    def _ttt_eval_step(
        self,
        step: int,
        loss: float,
        seq: str | list[str] | None = None,
        input_ids: torch.Tensor | None = None,
        **kwargs: T.Any,
    ) -> tuple[dict[str, T.Any], float | None]:
        """Per-step evaluation hook returning `(metrics, score)`.

        The default reports nothing. `ttt()` treats a larger score as better.
        """
        del step, loss, seq, input_ids, kwargs
        return {}, None

    def _ttt_is_lora_target(
        self,
        name: str,
        full_name: str,
        module: nn.Module,
        active: bool,
        target_modules: tuple[str, ...] | None,
    ) -> bool:
        """Decide whether `module` should be wrapped with a LoRA adapter.

        A module qualifies when the search is active, it is not already a
        wrapper, it passes the optional name filter (short or dotted name), and
        it is either an `nn.Linear` or a class whose name contains "Linear"
        and that owns a 2D `weight` parameter.
        """
        if not active:
            return False
        if isinstance(module, LoraInjectedLinear):
            return False
        if (
            target_modules is not None
            and name not in target_modules
            and full_name not in target_modules
        ):
            return False
        if isinstance(module, nn.Linear):
            return True
        if "weight" not in module._parameters:
            return False
        weight = module._parameters["weight"]
        if weight is None or weight.ndim != 2:
            return False
        return "Linear" in module.__class__.__name__

    def _ttt_inject_lora(self) -> int:
        """Replace every target module with a `LoraInjectedLinear`; return the count.

        When `lora_target_replace_module` is set, the search only becomes
        active inside modules of that class name; otherwise it is active
        everywhere under the trainable roots.

        Raises:
            AssertionError: if nothing was wrapped.
        """
        cfg = self.ttt_config
        cfg.verify()
        target_class = cfg.lora_target_replace_module
        target_modules = cfg.lora_target_modules
        wrapped = 0

        def inject(module: nn.Module, prefix: str, active: bool) -> None:
            nonlocal wrapped
            # Iterate over a snapshot because children are replaced in place.
            for name, child in list(module.named_children()):
                full_name = f"{prefix}.{name}" if prefix else name
                child_active = active
                if target_class is not None:
                    child_active = active or child.__class__.__name__ == target_class
                if self._ttt_is_lora_target(name, full_name, child, child_active, target_modules):
                    setattr(
                        module,
                        name,
                        LoraInjectedLinear(child, rank=cfg.lora_rank, alpha=cfg.lora_alpha),
                    )
                    wrapped += 1
                    continue
                inject(child, full_name, child_active)

        for trainable_module in self._ttt_get_trainable_modules():
            inject(trainable_module, "", target_class is None)
        assert wrapped > 0, "TTT LoRA injection did not find any target modules."
        return wrapped

    def _ttt_lora_modules(self) -> list[LoraInjectedLinear]:
        """All LoRA wrappers in the model, in `modules()` order."""
        return [module for module in self.modules() if isinstance(module, LoraInjectedLinear)]

    def _ttt_lora_parameters(self) -> list[nn.Parameter]:
        """Adapter parameters (down then up, per wrapper); asserts there is at least one."""
        params: list[nn.Parameter] = []
        for module in self._ttt_lora_modules():
            params.extend(module.lora_down.parameters())
            params.extend(module.lora_up.parameters())
        assert len(params) > 0, "TTT has no LoRA parameters."
        return params

    def _ttt_snapshot_lora_state(self) -> list[dict[str, torch.Tensor]]:
        """Detached copies of every adapter's weights, one dict per wrapper."""
        snapshot = []
        for module in self._ttt_lora_modules():
            snapshot.append(
                {
                    "lora_down.weight": module.lora_down.weight.detach().clone(),
                    "lora_up.weight": module.lora_up.weight.detach().clone(),
                }
            )
        assert len(snapshot) > 0, "TTT has no LoRA state to snapshot."
        return snapshot

    def _ttt_restore_lora_state(self, state: list[dict[str, torch.Tensor]]) -> None:
        """Copy a snapshot back into the adapters, matched by position."""
        modules = self._ttt_lora_modules()
        assert len(modules) == len(state), "TTT LoRA state/module count mismatch."
        with torch.no_grad():
            for module, module_state in zip(modules, state):
                module.lora_down.weight.copy_(module_state["lora_down.weight"])
                module.lora_up.weight.copy_(module_state["lora_up.weight"])

    def _ttt_ensure_initialized(self) -> None:
        """Inject LoRA and record its initial weights, once."""
        if "_ttt_cfg" not in self.__dict__:
            self.init_ttt()
        if self._ttt_initialized:
            return
        self._ttt_inject_lora()
        self._ttt_initial_state = self._ttt_snapshot_lora_state()
        self._ttt_initialized = True

    def ttt_reset(self) -> None:
        """Restore the adapters to the weights they had right after injection."""
        self._ttt_ensure_initialized()
        assert self._ttt_initial_state is not None, "TTT initial state is not available."
        self._ttt_restore_lora_state(self._ttt_initial_state)

    def _ttt_make_optimizer(self) -> torch.optim.Optimizer:
        """Build SGD or AdamW over the adapter parameters, per the config."""
        cfg = self.ttt_config
        params = self._ttt_lora_parameters()
        if cfg.optimizer == "sgd":
            return torch.optim.SGD(
                params,
                lr=cfg.lr,
                momentum=cfg.momentum,
                weight_decay=cfg.weight_decay,
            )
        return torch.optim.AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay)

    def _ttt_to_device(
        self,
        batch: torch.Tensor | dict[str, torch.Tensor],
        device: torch.device,
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """Move a tensor, or every tensor of a dict batch, to `device`."""
        if isinstance(batch, dict):
            return {name: tensor.to(device) for name, tensor in batch.items()}
        return batch.to(device)

    def _ttt_input_ids_from_batch(
        self,
        batch: torch.Tensor | dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Extract the token ids from either batch representation."""
        if isinstance(batch, dict):
            return batch["input_ids"]
        return batch

    def _ttt_set_input_ids(
        self,
        batch: torch.Tensor | dict[str, torch.Tensor],
        input_ids: torch.Tensor,
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """Return the batch with its token ids replaced (dict batches are shallow-copied)."""
        if isinstance(batch, dict):
            updated = dict(batch)
            updated["input_ids"] = input_ids
            return updated
        return input_ids

    def _ttt_non_special_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Bool mask that is True where the token is neither padding nor special."""
        pad_token = self._ttt_padding_token()
        mask = input_ids.ne(pad_token)
        special_ids = set(self.tokenizer.all_special_ids)
        for special_id in special_ids:
            mask = mask & input_ids.ne(int(special_id))
        return mask

    def _ttt_sample_crop(
        self,
        batch: torch.Tensor | dict[str, torch.Tensor],
        generator: torch.Generator,
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """Crop dim 1 to a random window of `crop_size` when the input is longer.

        One window is drawn and applied to every row. In a dict batch, only
        tensors whose second dimension matches the id length are cropped.
        """
        input_ids = self._ttt_input_ids_from_batch(batch)
        cfg = self.ttt_config
        if input_ids.shape[1] <= cfg.crop_size:
            return batch
        high = input_ids.shape[1] - cfg.crop_size + 1
        start = int(
            torch.randint(
                high,
                (1,),
                generator=generator,
                device=input_ids.device,
            ).item()
        )
        end = start + cfg.crop_size
        if isinstance(batch, dict):
            cropped = {}
            for name, tensor in batch.items():
                if tensor.ndim >= 2 and tensor.shape[1] == input_ids.shape[1]:
                    cropped[name] = tensor[:, start:end]
                else:
                    cropped[name] = tensor
            return cropped
        return input_ids[:, start:end]

    def _ttt_sample_batch(
        self,
        tokenized: torch.Tensor | dict[str, torch.Tensor],
        generator: torch.Generator,
    ) -> tuple[torch.Tensor | dict[str, torch.Tensor], torch.Tensor]:
        """Draw one corrupted training batch and its labels.

        Steps: crop, sample `batch_size` rows with replacement, pick
        `mask_ratio` of the non-special tokens per row (at least one), then
        split the picked positions into leave / random-replace / mask-token
        according to the BERT probabilities. Labels are -100 everywhere except
        at the picked positions, where they hold the original token.
        """
        cfg = self.ttt_config
        batch = self._ttt_sample_crop(tokenized, generator)
        input_ids = self._ttt_input_ids_from_batch(batch)
        rows = torch.randint(
            input_ids.shape[0],
            (cfg.batch_size,),
            generator=generator,
            device=input_ids.device,
        )
        if isinstance(batch, dict):
            sampled: torch.Tensor | dict[str, torch.Tensor] = {}
            for name, tensor in batch.items():
                # Only tensors with a matching leading (batch) dim are row-sampled.
                if tensor.ndim >= 1 and tensor.shape[0] == input_ids.shape[0]:
                    sampled[name] = tensor.index_select(0, rows)
                else:
                    sampled[name] = tensor
        else:
            sampled = input_ids.index_select(0, rows)

        sampled_ids = self._ttt_input_ids_from_batch(sampled)
        labels = sampled_ids.clone()
        non_special = self._ttt_non_special_mask(sampled_ids)
        label_mask = torch.zeros_like(non_special)
        for row_idx in range(sampled_ids.shape[0]):
            candidate_positions = torch.where(non_special[row_idx])[0]
            if candidate_positions.numel() == 0:
                continue
            num_mask = max(1, int(round(candidate_positions.numel() * cfg.mask_ratio)))
            order = torch.randperm(
                candidate_positions.numel(),
                generator=generator,
                device=sampled_ids.device,
            )
            chosen = candidate_positions[order[:num_mask]]
            label_mask[row_idx, chosen] = True

        labels = labels.masked_fill(~label_mask, -100)
        masked_ids = sampled_ids.clone()
        chosen_positions = torch.where(label_mask)
        if chosen_positions[0].numel() > 0:
            random_values = torch.rand(
                chosen_positions[0].shape,
                generator=generator,
                device=sampled_ids.device,
            )
            # [0, leave) -> keep token; [leave, leave + replace) -> random
            # token; the rest -> mask token.
            leave = random_values < cfg.bert_leave_prob
            replace = (random_values >= cfg.bert_leave_prob) & (
                random_values < cfg.bert_leave_prob + cfg.bert_replace_prob
            )
            mask = ~(leave | replace)
            if mask.any():
                masked_ids[
                    chosen_positions[0][mask],
                    chosen_positions[1][mask],
                ] = self._ttt_mask_token()
            if replace.any():
                replacement_tokens = self._ttt_replacement_tokens(sampled_ids)
                replacement_idx = torch.randint(
                    replacement_tokens.shape[0],
                    (int(replace.sum().item()),),
                    generator=generator,
                    device=sampled_ids.device,
                )
                masked_ids[
                    chosen_positions[0][replace],
                    chosen_positions[1][replace],
                ] = replacement_tokens[replacement_idx]
        return self._ttt_set_input_ids(sampled, masked_ids), labels

    def ttt(
        self,
        seq: str | list[str] | None = None,
        input_ids: torch.Tensor | None = None,
        ttt_config: TTTConfig | T.Mapping[str, T.Any] | None = None,
        **kwargs: T.Any,
    ) -> dict[str, T.Any]:
        """Adapt the LoRA weights to the given sequence(s) with masked-LM steps.

        Runs `steps` optimizer updates, each accumulating `ags` micro-batches.
        Module training modes and `requires_grad` flags are restored on exit,
        even if training fails.

        Args:
            seq: raw sequence(s) to tokenize; ignored when `input_ids` is given.
            input_ids: pre-tokenized ids.
            ttt_config: optional overrides. After LoRA injection, the LoRA
                shape/placement fields may not change.
            **kwargs: forwarded to the tokenize / predict / eval hooks.

        Returns:
            Dict with `losses` (loss of the last micro-batch of each step),
            `step_metrics`, `best_step` and `best_metric`.

        Raises:
            AssertionError: on invalid config, or on an attempt to change LoRA
                structure after initialization.
        """
        if ttt_config is not None:
            if "_ttt_initialized" in self.__dict__ and self._ttt_initialized:
                next_cfg = self.ttt_config.merged(ttt_config)
                assert next_cfg.lora_rank == self.ttt_config.lora_rank, (
                    "Changing lora_rank after TTT initialization is not supported."
                )
                assert next_cfg.lora_alpha == self.ttt_config.lora_alpha, (
                    "Changing lora_alpha after TTT initialization is not supported."
                )
                assert (
                    next_cfg.lora_target_replace_module
                    == self.ttt_config.lora_target_replace_module
                ), "Changing LoRA target class after TTT initialization is not supported."
                assert next_cfg.lora_target_modules == self.ttt_config.lora_target_modules, (
                    "Changing LoRA target modules after TTT initialization is not supported."
                )
                # Validate before adopting, exactly as init_ttt does for the
                # first-time path; otherwise e.g. steps=0 or an unknown
                # optimizer name would be accepted silently.
                next_cfg.verify()
                self._ttt_cfg = next_cfg
            else:
                self.init_ttt(ttt_config)
        self._ttt_ensure_initialized()
        cfg = self.ttt_config
        if cfg.initial_state_reset:
            self.ttt_reset()

        device = next(self.parameters()).device
        tokenized = self._ttt_tokenize(seq=seq, input_ids=input_ids, **kwargs)
        tokenized = self._ttt_to_device(tokenized, device)
        # A CUDA generator is used for CUDA models; every other device type
        # falls back to a CPU generator.
        generator_device = device if device.type == "cuda" else torch.device("cpu")
        generator = torch.Generator(device=generator_device)
        if cfg.seed is not None:
            generator.manual_seed(cfg.seed)

        # Remember state that training mutates so `finally` can restore it.
        module_modes = {module: module.training for module in self.modules()}
        requires_grad = {param: param.requires_grad for param in self.parameters()}
        losses: list[float] = []
        step_metrics: list[dict[str, T.Any]] = []
        best_state: list[dict[str, torch.Tensor]] | None = None
        best_metric: float | None = None
        best_step = 0

        try:
            self.train()
            # Only the adapters are trainable.
            for param in self.parameters():
                param.requires_grad_(False)
            for param in self._ttt_lora_parameters():
                param.requires_grad_(True)

            optimizer = self._ttt_make_optimizer()
            optimizer.zero_grad(set_to_none=True)
            total_micro_steps = cfg.steps * cfg.ags
            for micro_step in range(total_micro_steps):
                batch, labels = self._ttt_sample_batch(tokenized, generator)
                logits = self._ttt_predict_logits(batch, **kwargs)
                labels = labels.to(device=logits.device)
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.shape[-1]),
                    labels.reshape(-1),
                    ignore_index=-100,
                )
                (loss / cfg.ags).backward()
                # Keep accumulating until a full group of micro-batches is done.
                if (micro_step + 1) % cfg.ags != 0:
                    continue

                if cfg.gradient_clip:
                    torch.nn.utils.clip_grad_norm_(
                        self._ttt_lora_parameters(),
                        cfg.gradient_clip_max_norm,
                    )
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                step = (micro_step + 1) // cfg.ags
                loss_value = float(loss.detach().item())
                losses.append(loss_value)
                if cfg.eval_each_step:
                    metrics, metric = self._ttt_eval_step(
                        step=step,
                        loss=loss_value,
                        seq=seq,
                        input_ids=input_ids,
                        **kwargs,
                    )
                    if len(metrics) > 0:
                        step_metrics.append(metrics)
                    if metric is not None and (
                        best_metric is None or metric > best_metric
                    ):
                        best_metric = metric
                        best_step = step
                        best_state = self._ttt_snapshot_lora_state()

            if cfg.automatic_best_state_reset and best_state is not None:
                self._ttt_restore_lora_state(best_state)
        finally:
            for param, value in requires_grad.items():
                param.requires_grad_(value)
            for module, training in module_modes.items():
                module.train(training)

        return {
            "losses": losses,
            "step_metrics": step_metrics,
            "best_step": best_step,
            "best_metric": best_metric,
        }
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Optional, Tuple, Dict, Any
from dataclasses import dataclass
from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer
from transformers.modeling_outputs import ModelOutput
# ---------------------------------------------------------------------------
# Output dataclasses
# ---------------------------------------------------------------------------
@dataclass
class AnkhEncoderOutput(ModelOutput):
    """Output container for the ANKH encoder; every field is optional and defaults to None."""

    last_hidden_state: Optional[torch.Tensor] = None
    # Tuples of tensors; presumably one entry per layer when requested —
    # confirm against the encoder forward.
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    attentions: Optional[Tuple[torch.Tensor, ...]] = None
@dataclass
class AnkhMaskedLMOutput(ModelOutput):
    """Output container for the ANKH masked-LM head; every field is optional and defaults to None."""

    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    last_hidden_state: Optional[torch.Tensor] = None
    # Tuples of tensors; presumably one entry per layer when requested —
    # confirm against the model forward.
    hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
    attentions: Optional[Tuple[torch.Tensor, ...]] = None
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
class FastAnkhConfig(PretrainedConfig):
    """Configuration for the fast ANKH encoder.

    `hidden_size` is exposed as an alias of `d_model` through `attribute_map`.
    `tie_word_embeddings` is always forced to False after the base initializer
    runs, regardless of what is passed in `kwargs`.
    """

    model_type = "fast_ankh"
    attribute_map = {"hidden_size": "d_model"}

    def __init__(
        self,
        vocab_size: int = 144,
        d_model: int = 768,
        d_kv: int = 64,
        d_ff: int = 3072,
        num_heads: int = 12,
        num_layers: int = 48,
        relative_attention_num_buckets: int = 64,
        relative_attention_max_distance: int = 128,
        dense_act_fn: str = "gelu_new",
        layer_norm_epsilon: float = 1e-6,
        initializer_factor: float = 1.0,
        pad_token_id: int = 0,
        eos_token_id: int = 1,
        attn_backend: str = "sdpa",
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, eos_token_id=eos_token_id, **kwargs)

        # Model dimensions.
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_kv = d_kv
        self.d_ff = d_ff
        self.num_heads = num_heads
        self.num_layers = num_layers

        # Relative position bias bucketing.
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance

        # Activation, normalisation and initialisation.
        self.dense_act_fn = dense_act_fn
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_factor = initializer_factor

        # Assigned after super().__init__ so it overrides any kwarg value.
        self.tie_word_embeddings = False

        # User-facing attention backend string.
        self.attn_backend = attn_backend

    def to_dict(self) -> Dict[str, Any]:
        """Serialize exactly as the base class does."""
        return super().to_dict()
def _load_ankh_tokenizer(config: FastAnkhConfig):
    """Load the tokenizer that matches the checkpoint the config came from.

    Uses `config._name_or_path` when it is a non-empty string; a bare config
    falls back to the "ElnaggarLab/ankh-base" tokenizer.
    """
    source = config._name_or_path
    has_source = isinstance(source, str) and len(source) > 0
    tokenizer_id = source if has_source else "ElnaggarLab/ankh-base"
    return AutoTokenizer.from_pretrained(tokenizer_id)
# ---------------------------------------------------------------------------
# Submodules
# ---------------------------------------------------------------------------
class AnkhRMSNorm(nn.Module):
    """T5-style RMS layer norm: rescales by the root mean square, with no mean subtraction or bias."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The mean of squares is accumulated in float32 whatever the input dtype.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = hidden_states * inv_rms
        # Cast to the weight dtype before applying the learned scale.
        return self.weight * normalized.to(self.weight.dtype)
def _gelu_new(x: torch.Tensor) -> torch.Tensor:
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
class AnkhGatedFFN(nn.Module):
    """T5-style gated feed-forward block: wo(act(wi_0(x)) * wi_1(x)), all projections bias-free."""

    def __init__(self, config: FastAnkhConfig):
        super().__init__()
        model_dim, ff_dim = config.d_model, config.d_ff
        self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False)
        self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False)
        self.wo = nn.Linear(ff_dim, model_dim, bias=False)
        # "silu" selects SiLU; any other value uses the tanh-approximate GELU.
        if config.dense_act_fn == "silu":
            self.act = F.silu
        else:
            self.act = _gelu_new

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gate = self.act(self.wi_0(hidden_states))
        linear_branch = self.wi_1(hidden_states)
        return self.wo(gate * linear_branch)
# ---------------------------------------------------------------------------
# Attention
# ---------------------------------------------------------------------------
class AnkhSelfAttention(nn.Module):
"""T5-style self-attention with relative position bias and multi-backend dispatch.
Only layer 0 has ``has_relative_attention_bias=True`` and owns the
``nn.Embedding`` that produces the position bias. All other layers
receive the precomputed bias through the forward call.
"""
def __init__(self, config: FastAnkhConfig, has_relative_attention_bias: bool = False):
super().__init__()
self.num_heads = config.num_heads
self.d_kv = config.d_kv
self.inner_dim = self.num_heads * self.d_kv
self.has_relative_attention_bias = has_relative_attention_bias
self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.relative_attention_max_distance = config.relative_attention_max_distance
self.q = nn.Linear(config.d_model, self.inner_dim, bias=False)
self.k = nn.Linear(config.d_model, self.inner_dim, bias=False)
self.v = nn.Linear(config.d_model, self.inner_dim, bias=False)
self.o = nn.Linear(self.inner_dim, config.d_model, bias=False)
# T5/ANKH attention is unscaled: scores = Q K^T (no 1/sqrt(d_kv)).
# The learned relative position bias absorbs any temperature.
self.scale = 1.0
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(
config.relative_attention_num_buckets, config.num_heads
)
self.attn_backend: AttentionBackend = AttentionBackend.SDPA # set by encoder
# ---- T5 relative position bucketing ----
@staticmethod
def _relative_position_bucket(
relative_position: torch.Tensor,
num_buckets: int = 32,
max_distance: int = 128,
) -> torch.Tensor:
"""Bidirectional log-bucketed relative position mapping (T5 style)."""
# Bidirectional: half buckets for negative, half for positive
num_buckets //= 2
relative_buckets = (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
max_exact = num_buckets // 2
is_small = relative_position < max_exact
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.clamp(relative_position_if_large, max=num_buckets - 1)
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
def compute_bias(self, query_length: int, key_length: int, device: torch.device) -> torch.Tensor:
"""Compute (1, H, Q, K) position bias tensor for SDPA / manual paths."""
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position
buckets = self._relative_position_bucket(
relative_position,
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(buckets) # (Q, K, H)
return values.permute(2, 0, 1).unsqueeze(0) # (1, H, Q, K)
# ---- Forward ----
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask_2d: Optional[torch.Tensor] = None,
    attention_mask_4d: Optional[torch.Tensor] = None,
    flex_block_mask: Optional[BlockMask] = None,
    position_bias: Optional[torch.Tensor] = None,
    flex_score_mod=None,
    output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Run self-attention; returns (attn_output, attn_weights_or_none, position_bias).

    Args:
        hidden_states: Input of shape ``(batch, seq_len, d_model)``.
        attention_mask_2d: Per-token validity mask. Only used as a fallback
            to rebuild the key-padding mask on the manual path
            (assumed ``(batch, seq_len)``, truthy = real token -- TODO confirm).
        attention_mask_4d: Boolean mask folded into the additive bias via
            ``bool_to_additive_mask``.
        flex_block_mask: Block mask forwarded to the flex kernel.
        position_bias: Additive bias computed by an earlier layer. When None
            and this module owns the bias table, it is computed here.
        flex_score_mod: ``score_mod`` callable forwarded to the flex kernel.
        output_attentions: When True the manual path is used so the attention
            probabilities can be returned.

    Returns:
        ``(attn_output, attn_weights, position_bias)`` where ``attn_weights``
        is None unless ``output_attentions`` is True, and ``position_bias`` is
        passed back so later layers can reuse it.

    Raises:
        AssertionError: If the configured backend is neither FLEX nor SDPA.
    """
    batch_size, seq_length = hidden_states.shape[:2]
    hidden_shape = (batch_size, seq_length, self.num_heads, self.d_kv)
    query_BHLD = self.q(hidden_states).view(hidden_shape).transpose(1, 2)
    key_BHLD = self.k(hidden_states).view(hidden_shape).transpose(1, 2)
    value_BHLD = self.v(hidden_states).view(hidden_shape).transpose(1, 2)

    # The dense bias is consumed by the SDPA path AND by the manual path that
    # output_attentions forces, whatever backend is configured. Only a pure
    # flex call (bias applied through score_mod) may skip it. Previously the
    # flex backend skipped it unconditionally, so output_attentions=True
    # produced attention weights with no position bias and no padding mask.
    uses_dense_bias = output_attentions or self.attn_backend != AttentionBackend.FLEX
    if position_bias is None and self.has_relative_attention_bias and uses_dense_bias:
        position_bias = self.compute_bias(seq_length, seq_length, hidden_states.device)
        padding_mask = attention_mask_4d
        if (
            padding_mask is None
            and attention_mask_2d is not None
            and self.attn_backend == AttentionBackend.FLEX
        ):
            # NOTE(review): masks built for the flex backend appear to carry
            # no 4D mask, so derive a key-padding mask from the 2D one.
            # Assumes attention_mask_2d is (batch, seq_len) -- confirm against
            # get_attention_mask.
            padding_mask = attention_mask_2d[:, None, None, :].bool()
        # Fold padding mask into position bias so layers don't need separate mask.
        if padding_mask is not None:
            position_bias = position_bias + bool_to_additive_mask(padding_mask, position_bias.dtype)

    if output_attentions:
        attn_output, attn_weights = self._manual_attn(query_BHLD, key_BHLD, value_BHLD, position_bias)
        return self.o(attn_output), attn_weights, position_bias

    if self.attn_backend == AttentionBackend.FLEX:
        attn_output = self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask, flex_score_mod)
    elif self.attn_backend == AttentionBackend.SDPA:
        attn_output = self._sdpa_attn(query_BHLD, key_BHLD, value_BHLD, position_bias)
    else:
        raise AssertionError(f"Unsupported backend for ANKH: {self.attn_backend}")
    return self.o(attn_output), None, position_bias
def _sdpa_attn(
self,
query_BHLD: torch.Tensor,
key_BHLD: torch.Tensor,
value_BHLD: torch.Tensor,
position_bias: Optional[torch.Tensor],
) -> torch.Tensor:
# SDPA: position_bias is (1, H, Q, K) additive bias (includes padding mask)
context_BHLD = F.scaled_dot_product_attention(
query_BHLD, key_BHLD, value_BHLD,
attn_mask=position_bias,
scale=self.scale,
)
return context_BHLD.transpose(1, 2).contiguous().view(
query_BHLD.shape[0], -1, self.inner_dim
)
def _flex_attn(
    self,
    query_BHLD: torch.Tensor,
    key_BHLD: torch.Tensor,
    value_BHLD: torch.Tensor,
    flex_block_mask: Optional[BlockMask],
    flex_score_mod,
) -> torch.Tensor:
    """Attend with the flex-attention kernel and merge the heads.

    The kernel is obtained from ``_get_flex_attention_fn()`` and receives the
    score modifier, the block mask and this module's ``scale``.

    Returns:
        Tensor of shape ``(batch, seq_len, inner_dim)``.
    """
    assert flex_attention is not None, "Flex attention is not available."
    flex_kernel = _get_flex_attention_fn()
    attended_BHLD = flex_kernel(
        query_BHLD,
        key_BHLD,
        value_BHLD,
        score_mod=flex_score_mod,
        block_mask=flex_block_mask,
        scale=self.scale,
    )
    batch = query_BHLD.shape[0]
    # (B, H, L, D) -> (B, L, H, D) -> (B, L, H * D)
    merged = attended_BHLD.transpose(1, 2).contiguous()
    return merged.view(batch, -1, self.inner_dim)
def _manual_attn(
    self,
    query_BHLD: torch.Tensor,
    key_BHLD: torch.Tensor,
    value_BHLD: torch.Tensor,
    position_bias: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Explicit matmul/softmax attention that also returns the probabilities.

    The softmax runs in float32 and the result is cast back to the dtype of
    the raw scores.

    Returns:
        ``(attn_output, attn_weights)`` with shapes
        ``(batch, seq_len, inner_dim)`` and ``(batch, H, L, L)``.
    """
    scores = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2)) * self.scale
    if position_bias is not None:
        scores = scores + position_bias
    probs = F.softmax(scores.float(), dim=-1).type_as(scores)

    attended_BHLD = torch.matmul(probs, value_BHLD)
    batch = query_BHLD.shape[0]
    # (B, H, L, D) -> (B, L, H, D) -> (B, L, H * D)
    merged = attended_BHLD.transpose(1, 2).contiguous().view(batch, -1, self.inner_dim)
    return merged, probs
# ---------------------------------------------------------------------------
# Encoder block & stack (T5-compatible key naming)
# ---------------------------------------------------------------------------
class AnkhSelfAttentionLayer(nn.Module):
    """Pre-norm residual self-attention sub-layer.

    The attribute names ``SelfAttention`` and ``layer_norm`` are kept so the
    state-dict keys line up with ``T5Block.layer[0]``.
    """

    def __init__(self, config: FastAnkhConfig, has_relative_attention_bias: bool = False):
        super().__init__()
        self.SelfAttention = AnkhSelfAttention(config, has_relative_attention_bias)
        self.layer_norm = AnkhRMSNorm(config.d_model, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask_2d: Optional[torch.Tensor] = None,
        attention_mask_4d: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[BlockMask] = None,
        position_bias: Optional[torch.Tensor] = None,
        flex_score_mod=None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Normalise, attend, and add the result back onto the input.

        Returns:
            ``(hidden_states, attn_weights_or_none, position_bias)``.
        """
        attention_kwargs = dict(
            attention_mask_2d=attention_mask_2d,
            attention_mask_4d=attention_mask_4d,
            flex_block_mask=flex_block_mask,
            position_bias=position_bias,
            flex_score_mod=flex_score_mod,
            output_attentions=output_attentions,
        )
        attn_output, attn_weights, position_bias = self.SelfAttention(
            self.layer_norm(hidden_states),
            **attention_kwargs,
        )
        # Residual connection around the attention sub-layer.
        return hidden_states + attn_output, attn_weights, position_bias
class AnkhFFLayer(nn.Module):
    """Pre-norm residual feed-forward sub-layer.

    The attribute names ``DenseReluDense`` and ``layer_norm`` are kept so the
    state-dict keys line up with ``T5Block.layer[1]``.
    """

    def __init__(self, config: FastAnkhConfig):
        super().__init__()
        self.DenseReluDense = AnkhGatedFFN(config)
        self.layer_norm = AnkhRMSNorm(config.d_model, eps=config.layer_norm_epsilon)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` plus the gated FFN applied to its normed copy."""
        ffn_update = self.DenseReluDense(self.layer_norm(hidden_states))
        return hidden_states + ffn_update
class AnkhBlock(nn.Module):
    """One transformer block: self-attention sub-layer followed by the FFN.

    Both sub-layers live in a ``.layer`` ModuleList (index 0 = attention,
    index 1 = feed-forward) to keep T5-compatible parameter naming.
    """

    def __init__(self, config: FastAnkhConfig, has_relative_attention_bias: bool = False):
        super().__init__()
        attention_sublayer = AnkhSelfAttentionLayer(config, has_relative_attention_bias)
        feed_forward_sublayer = AnkhFFLayer(config)
        self.layer = nn.ModuleList([attention_sublayer, feed_forward_sublayer])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask_2d: Optional[torch.Tensor] = None,
        attention_mask_4d: Optional[torch.Tensor] = None,
        flex_block_mask: Optional[BlockMask] = None,
        position_bias: Optional[torch.Tensor] = None,
        flex_score_mod=None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Apply attention then the FFN.

        Returns:
            ``(hidden_states, attn_weights_or_none, position_bias)``; the bias
            is whatever the attention sub-layer handed back.
        """
        attention_sublayer = self.layer[0]
        feed_forward_sublayer = self.layer[1]

        attended, attn_weights, position_bias = attention_sublayer(
            hidden_states,
            attention_mask_2d=attention_mask_2d,
            attention_mask_4d=attention_mask_4d,
            flex_block_mask=flex_block_mask,
            position_bias=position_bias,
            flex_score_mod=flex_score_mod,
            output_attentions=output_attentions,
        )
        return feed_forward_sublayer(attended), attn_weights, position_bias
# ---------------------------------------------------------------------------
# PreTrainedModel base
# ---------------------------------------------------------------------------
class AnkhPreTrainedModel(PreTrainedModel):
    """Shared base for the ANKH models: weight init and backend switching."""

    config_class = FastAnkhConfig
    base_model_prefix = "encoder"
    supports_gradient_checkpointing = True
    _no_split_modules = ["AnkhBlock"]

    @classmethod
    def is_remote_code(cls) -> bool:
        """Always True for this class."""
        return True

    @torch.no_grad()
    def _init_weights(self, module: nn.Module) -> None:
        """Initialise one sub-module according to its type.

        Linear weights are drawn from N(0, (factor * d_model^-0.5)^2),
        embedding weights from N(0, factor^2), and RMSNorm weights are set
        to one. Other module types are left untouched.
        """
        init_factor = self.config.initializer_factor
        if isinstance(module, nn.Linear):
            linear_std = init_factor * (self.config.d_model ** -0.5)
            module.weight.data.normal_(mean=0.0, std=linear_std)
            return
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=init_factor * 1.0)
            return
        if isinstance(module, AnkhRMSNorm):
            module.weight.data.fill_(1.0)

    def post_init(self) -> None:
        """Defer to the parent implementation."""
        super().post_init()

    def get_output_embeddings(self):
        """The base model has no output embedding."""
        return None

    @property
    def attn_backend(self) -> str:
        """Backend name currently stored on the config."""
        return self.config.attn_backend

    @attn_backend.setter
    def attn_backend(self, backend: str) -> None:
        """Store ``backend`` on the config and push the resolved value to sub-modules.

        ``kernels_flash`` is replaced by flex when available, otherwise SDPA.
        Encoders receive the value as ``attention_backend`` and attention
        modules as ``attn_backend``.
        """
        assert backend in VALID_ATTENTION_BACKENDS, (
            f"Unsupported attn_backend: {backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
        )
        self.config.attn_backend = backend

        effective = resolve_attention_backend(backend)
        if effective == AttentionBackend.KERNELS_FLASH:
            print("ANKH: kernels_flash -> flex/sdpa fallback")
            effective = AttentionBackend.SDPA if flex_attention is None else AttentionBackend.FLEX

        for submodule in self.modules():
            if isinstance(submodule, FAST_ANKH_ENCODER):
                submodule.attention_backend = effective
            elif isinstance(submodule, AnkhSelfAttention):
                submodule.attn_backend = effective
# ---------------------------------------------------------------------------
# FAST_ANKH_ENCODER (mirrors T5Stack key naming)
# ---------------------------------------------------------------------------
class FAST_ANKH_ENCODER(AnkhPreTrainedModel, EmbeddingMixin):
    """Inner encoder that mirrors T5Stack attribute naming for weight compliance.

    State dict keys: embed_tokens.*, block.{i}.layer.0.SelfAttention.*,
    block.{i}.layer.1.DenseReluDense.*, final_layer_norm.*.

    Only block 0 owns the relative-position bias table. On the SDPA/manual
    paths the bias it computes is threaded through the remaining blocks; on
    the flex path the bias is applied through a ``score_mod`` closure.
    """

    def __init__(self, config: FastAnkhConfig, **kwargs):
        """Build embeddings, the block stack and the final norm.

        ``kernels_flash`` is not supported here and is replaced by flex when
        available, otherwise SDPA. The resolved backend is copied onto every
        block's attention module.
        """
        AnkhPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        resolved = resolve_attention_backend(config.attn_backend)
        if resolved == AttentionBackend.KERNELS_FLASH:
            print("ANKH: kernels_flash not supported (relative position bias); falling back to flex/sdpa")
            resolved = AttentionBackend.FLEX if flex_attention is not None else AttentionBackend.SDPA
        self.attention_backend = resolved
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        # Only the first block carries the relative attention bias table.
        self.block = nn.ModuleList([
            AnkhBlock(config, has_relative_attention_bias=(i == 0))
            for i in range(config.num_layers)
        ])
        for blk in self.block:
            blk.layer[0].SelfAttention.attn_backend = self.attention_backend
        self.final_layer_norm = AnkhRMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.gradient_checkpointing = False
        self.tokenizer = _load_ankh_tokenizer(config)
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed_tokens = value

    @torch.compiler.disable
    def _compute_materialized_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Precompute full (Q, K, H) bias tensor for flex score_mod lookup."""
        bias_embedding = self.block[0].layer[0].SelfAttention.relative_attention_bias
        context_position = torch.arange(seq_len, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(seq_len, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # [q, k] = k - q
        buckets = AnkhSelfAttention._relative_position_bucket(
            relative_position,
            num_buckets=self.config.relative_attention_num_buckets,
            max_distance=self.config.relative_attention_max_distance,
        )
        return bias_embedding(buckets)  # (Q, K, H)

    def _build_flex_score_mod(self, seq_len: int, device: torch.device):
        """Build score_mod closure that reads from materialized bias tensor."""
        bias = self._compute_materialized_bias(seq_len, device)

        def score_mod(score, b, h, q_idx, kv_idx):
            # The bias depends on head and positions only, not on the batch index.
            return score + bias[q_idx, kv_idx, h]

        return score_mod

    def _embed(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        hidden_state_index: int = -1,
        store_all_hidden_states: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run the encoder and select the requested states.

        Per-layer hidden states are collected only when all states are
        requested or ``hidden_state_index`` differs from -1. The final
        selection is delegated to ``select_hidden_state_embeddings``.
        """
        hidden_states = self.embed_tokens(input_ids)
        output_hidden_states = store_all_hidden_states or hidden_state_index != -1
        encoder_output = self._run_encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )
        return select_hidden_state_embeddings(
            encoder_output.last_hidden_state,
            encoder_output.hidden_states,
            hidden_state_index=hidden_state_index,
            store_all_hidden_states=store_all_hidden_states,
        )

    def _run_encoder(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ) -> AnkhEncoderOutput:
        """Push embedded inputs through every block and the final norm.

        Args:
            hidden_states: Embedded inputs, indexed as ``(batch, seq_len, ...)``.
            attention_mask: Optional padding mask handed to ``get_attention_mask``.
            output_hidden_states: Collect the input of every block plus the
                final normalised output.
            output_attentions: Collect the attention weights of every block.

        Returns:
            ``AnkhEncoderOutput`` with ``last_hidden_state`` and, when
            requested, ``hidden_states`` / ``attentions`` tuples (None otherwise).
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        batch_size, seq_len = hidden_states.shape[:2]
        attention_mask_2d, attention_mask_4d, flex_block_mask = get_attention_mask(
            effective_backend=self.attention_backend,
            batch_size=batch_size,
            seq_len=seq_len,
            device=hidden_states.device,
            attention_mask=attention_mask,
        )
        flex_score_mod = None
        position_bias = None
        # output_attentions sends every layer down the manual attention path,
        # which never reads score_mod, so the (L, L, H) bias table is only
        # materialized when the flex kernel will actually run.
        if self.attention_backend == AttentionBackend.FLEX and not output_attentions:
            flex_score_mod = self._build_flex_score_mod(seq_len, hidden_states.device)
        for layer_module in self.block:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # Positional order must match AnkhBlock.forward's signature.
                hidden_states, attn_weights, position_bias = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask_2d,
                    attention_mask_4d,
                    flex_block_mask,
                    position_bias,
                    flex_score_mod,
                    output_attentions,
                )
            else:
                hidden_states, attn_weights, position_bias = layer_module(
                    hidden_states,
                    attention_mask_2d=attention_mask_2d,
                    attention_mask_4d=attention_mask_4d,
                    flex_block_mask=flex_block_mask,
                    position_bias=position_bias,
                    flex_score_mod=flex_score_mod,
                    output_attentions=output_attentions,
                )
            if all_attentions is not None:
                all_attentions = all_attentions + (attn_weights,)
        hidden_states = self.final_layer_norm(hidden_states)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        return AnkhEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs,
    ) -> AnkhEncoderOutput:
        """Encode token ids or precomputed embeddings.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. The
        two output flags fall back to the config values when None. Extra
        keyword arguments are accepted and ignored.

        Raises:
            ValueError: If both or neither of ``input_ids`` and
                ``inputs_embeds`` are provided.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            hidden_states = self.embed_tokens(input_ids)
        elif inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        return self._run_encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states or False,
            output_attentions=output_attentions or False,
        )
# ---------------------------------------------------------------------------
# Model classes
# ---------------------------------------------------------------------------
class FastAnkhModel(AnkhPreTrainedModel, EmbeddingMixin):
    """Bare ANKH encoder wrapper used for embedding extraction.

    Holds a ``shared`` embedding next to the encoder; the forward pass goes
    through ``encoder.embed_tokens`` only.
    """

    def __init__(self, config: FastAnkhConfig, **kwargs):
        AnkhPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        vocab_size, model_dim = config.vocab_size, config.d_model
        self.shared = nn.Embedding(vocab_size, model_dim)
        self.encoder = FAST_ANKH_ENCODER(config)
        self.post_init()

    @property
    def tokenizer(self):
        """Tokenizer owned by the inner encoder."""
        return self.encoder.tokenizer

    def get_input_embeddings(self):
        """Return the encoder's token embedding module."""
        return self.encoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the encoder's token embedding module."""
        self.encoder.embed_tokens = value

    def _embed(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        hidden_state_index: int = -1,
        store_all_hidden_states: bool = False,
    ) -> torch.Tensor:
        """Delegate embedding extraction to the inner encoder."""
        selection = dict(
            hidden_state_index=hidden_state_index,
            store_all_hidden_states=store_all_hidden_states,
        )
        return self.encoder._embed(input_ids, attention_mask, **selection)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs,
    ) -> AnkhEncoderOutput:
        """Run the inner encoder; extra keyword arguments are ignored."""
        encoder_output = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        return encoder_output
class FastAnkhForMaskedLM(FastPLMTestTimeTrainingMixin, AnkhPreTrainedModel, EmbeddingMixin):
    """ANKH encoder topped with a linear LM head for masked language modeling.

    NOTE: The LM head is a separate ``nn.Linear`` and is NOT tied to the
    embeddings. It is described as being initialized from the shared embedding
    weights; that copy is not performed in this class (presumably done at
    conversion/load time -- verify). The original ANKH models were trained
    with T5's span corruption objective using an encoder-decoder architecture,
    so this encoder-only MaskedLM variant is not pre-trained for standard MLM
    and requires additional fine-tuning.
    """

    def __init__(self, config: FastAnkhConfig, **kwargs):
        AnkhPreTrainedModel.__init__(self, config, **kwargs)
        self.config = config
        vocab_size, model_dim = config.vocab_size, config.d_model
        self.shared = nn.Embedding(vocab_size, model_dim)
        self.encoder = FAST_ANKH_ENCODER(config)
        self.lm_head = nn.Linear(model_dim, vocab_size, bias=False)
        self.loss_fct = nn.CrossEntropyLoss()
        self.post_init()
        self.init_ttt({"lora_target_replace_module": "AnkhSelfAttention"})

    @property
    def tokenizer(self):
        """Tokenizer owned by the inner encoder."""
        return self.encoder.tokenizer

    def get_input_embeddings(self):
        """Return the encoder's token embedding module."""
        return self.encoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the encoder's token embedding module."""
        self.encoder.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def _embed(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        hidden_state_index: int = -1,
        store_all_hidden_states: bool = False,
    ) -> torch.Tensor:
        """Delegate embedding extraction to the inner encoder."""
        selection = dict(
            hidden_state_index=hidden_state_index,
            store_all_hidden_states=store_all_hidden_states,
        )
        return self.encoder._embed(input_ids, attention_mask, **selection)

    def _ttt_get_trainable_modules(self) -> list[nn.Module]:
        """Modules exposed to test-time training: just the encoder."""
        return [self.encoder]

    def _ttt_tokenize(
        self,
        seq: str | list[str] | None = None,
        input_ids: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Return ``input_ids`` as given, or tokenize ``seq`` into padded ids.

        Each sequence is split into single characters joined by spaces before
        it is passed to the tokenizer. Extra keyword arguments are discarded.
        """
        del kwargs
        if input_ids is not None:
            return input_ids
        assert seq is not None, "Pass either seq or input_ids for ANKH TTT."
        if isinstance(seq, str):
            seq = [seq]
        spaced = [" ".join(residues) for residues in seq]
        encoded = self.tokenizer(spaced, return_tensors="pt", padding=True)
        return encoded["input_ids"]

    def _ttt_replacement_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Token ids of the 20 standard amino acids, on ``input_ids``' device and dtype."""
        to_id = self.tokenizer.convert_tokens_to_ids
        residue_ids = [to_id(residue) for residue in "ACDEFGHIKLMNPQRSTVWY"]
        return torch.tensor(residue_ids, device=input_ids.device, dtype=input_ids.dtype)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs,
    ) -> AnkhMaskedLMOutput:
        """Encode, project to vocabulary logits, and optionally compute the loss.

        The loss is cross-entropy over the flattened logits and labels; it is
        None when ``labels`` is not given. Extra keyword arguments are ignored.
        """
        encoder_output = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = encoder_output.last_hidden_state
        logits = self.lm_head(sequence_output)

        if labels is None:
            loss = None
        else:
            labels = labels.to(logits.device)
            flat_logits = logits.view(-1, self.config.vocab_size)
            loss = self.loss_fct(flat_logits, labels.view(-1))

        return AnkhMaskedLMOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=sequence_output,
            hidden_states=encoder_output.hidden_states,
            attentions=encoder_output.attentions,
        )
class FastAnkhForSequenceClassification(AnkhPreTrainedModel, EmbeddingMixin):
    """ANKH encoder with a mean-pooled linear head for sequence-level tasks."""

    def __init__(self, config: FastAnkhConfig, **kwargs):
        AnkhPreTrainedModel.__init__(self, config, **kwargs)
        self.num_labels = config.num_labels
        self.config = config
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        self.encoder = FAST_ANKH_ENCODER(config)
        self.classifier = nn.Linear(config.d_model, config.num_labels)
        # One criterion per supported problem type.
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.post_init()

    @property
    def tokenizer(self):
        """Tokenizer owned by the inner encoder."""
        return self.encoder.tokenizer

    def get_input_embeddings(self):
        """Return the encoder's token embedding module."""
        return self.encoder.embed_tokens

    def _embed(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        hidden_state_index: int = -1,
        store_all_hidden_states: bool = False,
    ) -> torch.Tensor:
        """Delegate embedding extraction to the inner encoder."""
        selection = dict(
            hidden_state_index=hidden_state_index,
            store_all_hidden_states=store_all_hidden_states,
        )
        return self.encoder._embed(input_ids, attention_mask, **selection)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs,
    ) -> AnkhMaskedLMOutput:
        """Encode, mean-pool, classify, and optionally compute the loss.

        Pooling averages over tokens where ``attention_mask`` is set, or over
        all tokens when no mask is given. When ``config.problem_type`` is
        unset it is inferred from ``num_labels`` and the label dtype and
        written back to the config. Extra keyword arguments are ignored.
        """
        encoder_output = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = encoder_output.last_hidden_state

        # Mean over non-padding tokens (plain mean without a mask).
        if attention_mask is None:
            pooled = sequence_output.mean(dim=1)
        else:
            token_weights = attention_mask.unsqueeze(-1).to(sequence_output.dtype)
            # clamp guards against division by zero for all-padding rows
            token_counts = token_weights.sum(dim=1).clamp(min=1)
            pooled = (sequence_output * token_weights).sum(dim=1) / token_counts
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    inferred = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    inferred = "single_label_classification"
                else:
                    inferred = "multi_label_classification"
                self.config.problem_type = inferred

            problem_type = self.config.problem_type
            if problem_type == "regression":
                if self.num_labels == 1:
                    loss = self.mse(logits.squeeze(), labels.squeeze())
                else:
                    loss = self.mse(logits, labels)
            elif problem_type == "single_label_classification":
                loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                loss = self.bce(logits, labels)

        return AnkhMaskedLMOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=sequence_output,
            hidden_states=encoder_output.hidden_states,
            attentions=encoder_output.attentions,
        )
class FastAnkhForTokenClassification(AnkhPreTrainedModel, EmbeddingMixin):
    """ANKH encoder with a per-token linear classification head."""

    def __init__(self, config: FastAnkhConfig, **kwargs):
        AnkhPreTrainedModel.__init__(self, config, **kwargs)
        self.num_labels = config.num_labels
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        self.encoder = FAST_ANKH_ENCODER(config)
        self.classifier = nn.Linear(config.d_model, config.num_labels)
        self.loss_fct = nn.CrossEntropyLoss()
        self.post_init()

    @property
    def tokenizer(self):
        """Tokenizer owned by the inner encoder."""
        return self.encoder.tokenizer

    def get_input_embeddings(self):
        """Return the encoder's token embedding module."""
        return self.encoder.embed_tokens

    def _embed(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        hidden_state_index: int = -1,
        store_all_hidden_states: bool = False,
    ) -> torch.Tensor:
        """Delegate embedding extraction to the inner encoder."""
        selection = dict(
            hidden_state_index=hidden_state_index,
            store_all_hidden_states=store_all_hidden_states,
        )
        return self.encoder._embed(input_ids, attention_mask, **selection)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs,
    ) -> AnkhMaskedLMOutput:
        """Encode, classify every token, and optionally compute the loss.

        The loss is cross-entropy over the flattened logits and labels; it is
        None when ``labels`` is not given. Extra keyword arguments are ignored.
        """
        encoder_output = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        sequence_output = encoder_output.last_hidden_state
        logits = self.classifier(sequence_output)

        if labels is None:
            loss = None
        else:
            labels = labels.to(logits.device)
            flat_logits = logits.view(-1, self.num_labels)
            loss = self.loss_fct(flat_logits, labels.view(-1))

        return AnkhMaskedLMOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=sequence_output,
            hidden_states=encoder_output.hidden_states,
            attentions=encoder_output.attentions,
        )