"""Fine-tune a SPLADE-style sparse encoder on MS MARCO with a KL-divergence + Margin-MSE distillation loss.

Steps: load ``Luyu/co-condenser-marco`` as a ``SparseEncoder``, read train/eval datasets that must
already exist on disk, train with ``SpladeLoss`` wrapping ``SparseDistillKLDivMarginMSELoss``,
evaluate on the NanoBEIR msmarco/nfcorpus/nq subsets before and after training, save the final
model under ``models/<run_name>/final`` and optionally push it to the Hugging Face Hub.
"""

import logging
from typing import Iterable

import torch
from datasets import load_dataset, load_from_disk
from sentence_transformers import (
    SparseEncoder,
    SparseEncoderModelCardData,
    SparseEncoderTrainer,
    SparseEncoderTrainingArguments,
)
from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator
from sentence_transformers.sparse_encoder.losses import (
    SparseDistillKLDivLoss,
    SparseMarginMSELoss,
    SpladeLoss,
)
from sentence_transformers.training_args import BatchSamplers

logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

# 1. Load a model to finetune with 2. (Optional) model card data
model = SparseEncoder(
    "Luyu/co-condenser-marco",  # "naver/splade-cocondenser-selfdistil",
    model_card_data=SparseEncoderModelCardData(
        language="en",
        license="apache-2.0",
        model_name="CoCondenser finetuned on MS MARCO",
    ),
    # revision="refs/pr/2",
    # similarity_fn_name="cosine",
)

# 3. Load the pre-computed MS MARCO training data from disk.
# Source data: https://huggingface.co/datasets/sentence-transformers/msmarco
# This script does not build these datasets; they must be created beforehand.
# NOTE(review): judging from SparseDistillKLDivMarginMSELoss below, each row presumably holds the
# query, the positive and the negative passage(s) plus a list of teacher scores whose first entry
# belongs to the positive passage — confirm against the code that produced these directories.
train_dataset_path = "ms-marco-kldiv-train-minilm"
eval_dataset_path = "ms-marco-kldiv-eval-minilm"
logging.info("Read train dataset")
try:
    train_dataset = load_from_disk(train_dataset_path)
    eval_dataset = load_from_disk(eval_dataset_path)
except FileNotFoundError as err:
    # Stop with a readable message and a non-zero exit status rather than exiting silently.
    raise SystemExit(
        f"Could not load the pre-computed datasets {train_dataset_path!r} and {eval_dataset_path!r} from disk "
        f"({err}). Create them (e.g. with `Dataset.save_to_disk`) before running this script."
    ) from err
logging.info(train_dataset)


class SparseDistillKLDivMarginMSELoss(torch.nn.Module):
    """Distillation loss that pairs a KL-divergence loss with a Margin-MSE loss.

    Intended to be wrapped by ``SpladeLoss`` (or ``CSRLoss``), which supplies the embeddings and
    calls :meth:`compute_loss_from_embeddings`; calling the module directly raises
    ``NotImplementedError``.

    Args:
        model: The sparse encoder being trained. Its ``similarity_pairwise`` method is used as the
            similarity function of both sub-losses.
        kl_div_weight: Multiplier applied to the KL-divergence component.
        mse_weight: Multiplier applied to the Margin-MSE component.
    """

    def __init__(self, model: SparseEncoder, kl_div_weight: float = 1.0, mse_weight: float = 1.0) -> None:
        super().__init__()
        self.model = model
        self.kl_div_weight = kl_div_weight
        self.mse_weight = mse_weight
        self.distill_kl_div_loss = SparseDistillKLDivLoss(
            model=model,
            similarity_fct=model.similarity_pairwise,
        )
        self.margin_mse_loss = SparseMarginMSELoss(
            model=model,
            similarity_fct=model.similarity_pairwise,
        )

    def forward(
        self, sentence_features: Iterable[dict[str, torch.Tensor]], labels: torch.Tensor
    ) -> dict[str, torch.Tensor]:
        raise NotImplementedError(
            "This loss function is designed to be used with SpladeLoss or CSRLoss, not directly."
        )

    def compute_loss_from_embeddings(
        self, embeddings: list[torch.Tensor], labels: torch.Tensor
    ) -> dict[str, torch.Tensor]:
        """Compute the weighted KL-divergence and Margin-MSE loss components.

        Args:
            embeddings: Embeddings handed over by the wrapping loss; presumably
                ``[query, positive, negative_1, ...]`` in dataset-column order — TODO confirm.
            labels: 2-D teacher scores: column 0 is the positive passage's score, the remaining
                columns are the negative passages' scores.

        Returns:
            ``{"kl_loss": ..., "mse_loss": ...}``, each already multiplied by its weight. The two
            components are returned separately, not summed here.
        """
        kl_loss = self.distill_kl_div_loss.compute_loss_from_embeddings(embeddings, labels)
        # MarginMSE expects margins (positive score minus each negative score) as labels,
        # not the raw per-passage scores that the KL-divergence loss consumes.
        mse_labels = labels[:, 0].unsqueeze(1) - labels[:, 1:]
        mse_loss = self.margin_mse_loss.compute_loss_from_embeddings(embeddings, mse_labels)
        return {
            "kl_loss": kl_loss * self.kl_div_weight,
            "mse_loss": mse_loss * self.mse_weight,
        }


# 4. Define a loss function
loss = SpladeLoss(
    model=model,
    loss=SparseDistillKLDivMarginMSELoss(model, kl_div_weight=1.0, mse_weight=0.05),
    lambda_query=5e-4,
    lambda_corpus=5e-4,
)

# 5. (Optional) Specify training arguments
run_name = "splade-cocondenser-msmarco-kldiv-marginmse-minilm"
args = SparseEncoderTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    # The distillation losses above score each query only against its own passages (pairwise
    # similarity), so they don't depend on in-batch negatives; NO_DUPLICATES is optional here.
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,
    logging_steps=100,
    run_name=run_name,  # Will be used in W&B if `wandb` is installed
)

# 6. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = SparseNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=16)
dev_evaluator(model)

# 7. Create a trainer & train
trainer = SparseEncoderTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

# 8. Evaluate the model performance again after training
dev_evaluator(model)

# 9. Save the trained model
final_output_dir = f"models/{run_name}/final"
model.save_pretrained(final_output_dir)

# 10. (Optional) Push it to the Hugging Face Hub
try:
    model.push_to_hub(run_name)
except Exception:
    # The model is already saved locally, so an upload failure (e.g. not logged in) is logged
    # with the traceback and manual-upload instructions instead of crashing the finished run.
    logging.exception(
        "Error uploading the model to the Hugging Face Hub. To upload it manually, run `huggingface-cli login`, "
        "then load the model with `SparseEncoder(%r)` and call `model.push_to_hub(%r)`.",
        final_output_dir,
        run_name,
    )