SentenceTransformer TrainingArguments torch and accelerate version issue

0

I am using the script below to train a custom embedding model. The data consists of descriptions and their corresponding search queries, so that a custom embedding model can be trained on both. I had been using sentence-transformers 2.2.2, but after updating to version 3.0.1 I saw that it recommends using the SentenceTransformerTrainer object for training, which the new fit method relies on (sentence-transformers/sentence_transformers/fit_mixin.py at b37f470e1625878b0f31525251db74b658a26dcb · UKPLab/sentence-transformers · GitHub).

# Fine-tune a SentenceTransformer on (description, query) pairs using the
# sentence-transformers v3 trainer API with MultipleNegativesRankingLoss.
import logging
import os
from datetime import datetime

import accelerate
import pandas as pd
import sentence_transformers
import transformers
from datasets import Dataset
from sentence_transformers import SentenceTransformer, InputExample, losses, LoggingHandler
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import SentenceTransformerTrainingArguments

# Print the installed versions to help diagnose dependency mismatches.
print("transformers version:", transformers.__version__)
print("accelerate version:", accelerate.__version__)
print("sentence_transformers version:", sentence_transformers.__version__)

# Enable logging
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

# Load your data into a pandas dataframe
df = pd.read_csv('my_data.csv')  # Replace with your dataframe loading method

# Use bracket indexing: `df.query` is the DataFrame.query *method*, not the
# "query" column, so iterating over it would raise a TypeError.
# Rows missing either field are dropped and values are cast to str, because
# NaN cells cannot be tokenized.
pairs = df[["description", "query"]].dropna()
description = pairs["description"].astype(str).tolist()
query = pairs["query"].astype(str).tolist()

# sentence-transformers v3 passes dataset columns to the loss by *order*, not
# by name: for MultipleNegativesRankingLoss the first column is the anchor and
# the second is the positive. Put "query" first if queries should be anchors.
train_examples = Dataset.from_dict(
    {
        "description": description,
        "query": query,
    }
)

# Define hyperparameters
train_batch_size = 16
num_epochs = 4
model_save_path = 'output/training_sts_model_' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
os.makedirs(model_save_path, exist_ok=True)

# Initialize a pre-trained model
model = SentenceTransformer('distilbert-base-nli-mean-tokens')

# Define the loss function (uses the other in-batch positives as negatives)
train_loss = losses.MultipleNegativesRankingLoss(model=model)

# NOTE: transformers checks whether accelerate is available when transformers
# itself is first imported. If accelerate was installed or upgraded after that
# (for example, in a Jupyter kernel that is still running), creating the
# training arguments can still raise "requires accelerate>=...". Restart the
# kernel/interpreter so the installed version is detected.
training_args = SentenceTransformerTrainingArguments(
    output_dir=model_save_path,
    overwrite_output_dir=True,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=train_batch_size,
)

# Create the SentenceTransformerTrainer
trainer = SentenceTransformerTrainer(
    model=model,
    args=training_args,
    train_dataset=train_examples,
    loss=train_loss,
)

# Train the model
trainer.train()

# By default trainer.train() only writes intermediate checkpoint-* folders
# under output_dir. Save the final model explicitly so the message below is
# accurate.
model.save(model_save_path)

print("Model training complete. Model saved to:", model_save_path)

I get the following error at the `training_args = SentenceTransformerTrainingArguments(...)` step:

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[49], line 49
     46 train_loss = losses.MultipleNegativesRankingLoss(model=model)
     48 # Define training arguments
---> 49 training_args = SentenceTransformerTrainingArguments(
     50     output_dir=model_save_path,
     51     overwrite_output_dir=True,
     52     num_train_epochs=num_epochs,
     53     per_device_train_batch_size=train_batch_size
     54 )
     56 # Create the SentenceTransformerTrainer
     57 trainer = SentenceTransformerTrainer(
     58     model=model,
     59     args=training_args,
   (...)
     62     evaluator=evaluator,
     63 )

File <string>:133, in __init__(self, output_dir, overwrite_output_dir, do_train, do_eval, do_predict, eval_strategy, prediction_loss_only, per_device_train_batch_size, per_device_eval_batch_size, per_gpu_train_batch_size, per_gpu_eval_batch_size, gradient_accumulation_steps, eval_accumulation_steps, eval_delay, torch_empty_cache_steps, learning_rate, weight_decay, adam_beta1, adam_beta2, adam_epsilon, max_grad_norm, num_train_epochs, max_steps, lr_scheduler_type, lr_scheduler_kwargs, warmup_ratio, warmup_steps, log_level, log_level_replica, log_on_each_node, logging_dir, logging_strategy, logging_first_step, logging_steps, logging_nan_inf_filter, save_strategy, save_steps, save_total_limit, save_safetensors, save_on_each_node, save_only_model, restore_callback_states_from_checkpoint, no_cuda, use_cpu, use_mps_device, seed, data_seed, jit_mode_eval, use_ipex, bf16, fp16, fp16_opt_level, half_precision_backend, bf16_full_eval, fp16_full_eval, tf32, local_rank, ddp_backend, tpu_num_cores, tpu_metrics_debug, debug, dataloader_drop_last, eval_steps, dataloader_num_workers, dataloader_prefetch_factor, past_index, run_name, disable_tqdm, remove_unused_columns, label_names, load_best_model_at_end, metric_for_best_model, greater_is_better, ignore_data_skip, fsdp, fsdp_min_num_params, fsdp_config, fsdp_transformer_layer_cls_to_wrap, accelerator_config, deepspeed, label_smoothing_factor, optim, optim_args, adafactor, group_by_length, length_column_name, report_to, ddp_find_unused_parameters, ddp_bucket_cap_mb, ddp_broadcast_buffers, dataloader_pin_memory, dataloader_persistent_workers, skip_memory_metrics, use_legacy_prediction_loop, push_to_hub, resume_from_checkpoint, hub_model_id, hub_strategy, hub_token, hub_private_repo, hub_always_push, gradient_checkpointing, gradient_checkpointing_kwargs, include_inputs_for_metrics, eval_do_concat_batches, fp16_backend, evaluation_strategy, push_to_hub_model_id, push_to_hub_organization, push_to_hub_token, mp_parameters, 
auto_find_batch_size, full_determinism, torchdynamo, ray_scope, ddp_timeout, torch_compile, torch_compile_backend, torch_compile_mode, dispatch_batches, split_batches, include_tokens_per_second, include_num_input_tokens_seen, neftune_noise_alpha, optim_target_modules, batch_eval_metrics, eval_on_start, eval_use_gather_object, batch_sampler, multi_dataset_batch_sampler)

File ~/anaconda3/envs/python3/lib/python3.10/site-packages/sentence_transformers/training_args.py:73, in SentenceTransformerTrainingArguments.__post_init__(self)
     72 def __post_init__(self):
---> 73     super().__post_init__()
     75     self.batch_sampler = BatchSamplers(self.batch_sampler)
     76     self.multi_dataset_batch_sampler = MultiDatasetBatchSamplers(self.multi_dataset_batch_sampler)

File ~/anaconda3/envs/python3/lib/python3.10/site-packages/transformers/training_args.py:1730, in TrainingArguments.__post_init__(self)
   1728 # Initialize device before we proceed
   1729 if self.framework == "pt" and is_torch_available():
-> 1730     self.device
   1732 if self.torchdynamo is not None:
   1733     warnings.warn(
   1734         "`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
   1735         " `torch_compile_backend` instead",
   1736         FutureWarning,
   1737     )

File ~/anaconda3/envs/python3/lib/python3.10/site-packages/transformers/training_args.py:2227, in TrainingArguments.device(self)
   2223 """
   2224 The device used by this process.
   2225 """
   2226 requires_backends(self, ["torch"])
-> 2227 return self._setup_devices

File ~/anaconda3/envs/python3/lib/python3.10/site-packages/transformers/utils/generic.py:60, in cached_property.__get__(self, obj, objtype)
     58 cached = getattr(obj, attr, None)
     59 if cached is None:
---> 60     cached = self.fget(obj)
     61     setattr(obj, attr, cached)
     62 return cached

File ~/anaconda3/envs/python3/lib/python3.10/site-packages/transformers/training_args.py:2103, in TrainingArguments._setup_devices(self)
   2101 if not is_sagemaker_mp_enabled():
   2102     if not is_accelerate_available():
-> 2103         raise ImportError(
   2104             f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: "
   2105             "Please run `pip install transformers[torch]` or `pip install accelerate -U`"
   2106         )
   2107 # We delay the init of `PartialState` to the end for clarity
   2108 accelerator_state_kwargs = {"enabled": True, "use_configured_state": False}

ImportError: Using the `Trainer` with `PyTorch` requires `accelerate>=0.21.0`: Please run `pip install transformers[torch]` or `pip install accelerate -U`

My library versions are:

transformers version: 4.44.2
accelerate version: 0.33.0
sentence_transformers version: 3.0.1

Can someone suggest what to do? My dependencies are up to date, but I still cannot instantiate the SentenceTransformerTrainingArguments object.