Say I want to train a simple LSTM or MLP with Trainer (PyTorch nn.Modules). Do I just need to ensure the model adheres to the following?
Is there an example of using Trainer to train models that are not HF Transformers models? Any best practices?
I think the HF Trainer API is specifically for Transformers models, not for other models.
We don’t have an example, but as long as you follow the recommendations in that list in the documentation, you should be fine.
And if you use it successfully and want to publish a short write-up, we’ll make sure to share it!
I can confirm that you can train a simple LSTM or MLP with Trainer. This is nice, since I can stay within the HF ecosystem. I’m not sure I’ll have time to do a write-up, but as long as you follow the list in the original post, it works.
Hello ivnle!
I am currently trying to train an LSTM model that takes as input the embeddings output by a pretrained model from the HF Hub. I am following the Sharing custom models example, and my model class, which inherits from PreTrainedModel, is the following:

    import torch
    from transformers import PreTrainedModel

    # LSTMConfig (a PretrainedConfig subclass) and SentimentLSTM are defined elsewhere in my code
    class LSTMModel(PreTrainedModel):
        config_class = LSTMConfig

        def __init__(self, config, pretrained_model):
            super().__init__(config)
            # Custom LSTM that runs on top of the pretrained model's embeddings
            self.model = SentimentLSTM(pretrained_model,
                                       output_size=config.output_size,
                                       hidden_dim=config.hidden_dim,
                                       n_layers=config.n_layers)

        def forward(self, tensor, labels=None):
            logits = self.model(tensor)
            if labels is not None:
                # cross_entropy lives in torch.nn.functional, not torch.nn
                loss = torch.nn.functional.cross_entropy(logits, labels)
                return {"loss": loss, "logits": logits}
            return {"logits": logits}

Here, SentimentLSTM is my custom LSTM model.
Then I initialize the training arguments and the Trainer object and call trainer.train(). However, I get an error related to the tensor argument of forward, which apparently cannot be found.
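One likely cause, assuming Trainer's default `remove_unused_columns=True`: Trainer inspects the argument names of the model's `forward`, drops every dataset column (or batch key) that does not match one of them, and then calls `forward(**batch)`. If the feature column is named something other than `tensor`, it is removed and `forward` is called without `tensor`. The stdlib-only sketch below mimics that filtering with a hypothetical helper, `filter_batch_for_forward`, and a stand-in class `LSTMLike`; it is a simplification of the idea, not the library's actual code.

```python
import inspect


def filter_batch_for_forward(forward, batch, label_names=("labels",)):
    """Hypothetical, simplified version of the remove_unused_columns idea:
    keep only keys that are forward() arguments or label names."""
    allowed = set(inspect.signature(forward).parameters)
    allowed |= {"label", "label_ids", *label_names}
    kept = {k: v for k, v in batch.items() if k in allowed}
    dropped = sorted(set(batch) - set(kept))
    return kept, dropped


class LSTMLike:
    # Same argument names as LSTMModel.forward above
    def forward(self, tensor, labels=None):
        return {"logits": tensor}


model = LSTMLike()
# A dataset whose feature column is called "embeddings" instead of "tensor"
batch = {"embeddings": [[0.1, 0.2]], "labels": [1]}
kept, dropped = filter_batch_for_forward(model.forward, batch)
print(kept)     # {'labels': [1]}
print(dropped)  # ['embeddings']

try:
    # forward() is now called without `tensor`, so Python raises a TypeError
    model.forward(**kept)
except TypeError as err:
    print("TypeError:", err)
```

Under this assumption, naming the column `tensor` (for a datasets.Dataset, e.g. `dataset.rename_column("embeddings", "tensor")`), or renaming the forward argument to match the existing column, lets the filter keep it.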
Did you use the same example for your own training?
If so, what was your forward function in the corresponding LSTMModel class, and how did you pass an extra argument to this function through Trainer?
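One way to control exactly which keyword arguments reach `forward` is a custom `data_collator` that builds each batch dict with the keys `forward` expects. A minimal sketch, assuming each example is a dict with a hypothetical "embeddings" feature of fixed size and a "label" field; the collator name is also hypothetical. As far as I know, with the default `remove_unused_columns=True` Trainer drops keys that are not `forward` arguments before the collator runs, so this sketch assumes `remove_unused_columns=False`.

```python
import torch


def embeddings_to_tensor_collator(features):
    """Hypothetical collator: stacks each example's "embeddings" into the
    batch key `tensor` and its "label" into `labels`, matching forward()."""
    return {
        "tensor": torch.stack(
            [torch.as_tensor(f["embeddings"], dtype=torch.float32) for f in features]
        ),
        "labels": torch.tensor([f["label"] for f in features], dtype=torch.long),
    }


features = [
    {"embeddings": [0.1, 0.2, 0.3], "label": 1},
    {"embeddings": [0.4, 0.5, 0.6], "label": 0},
]
batch = embeddings_to_tensor_collator(features)
print(batch["tensor"].shape)  # torch.Size([2, 3])
print(batch["labels"])        # tensor([1, 0])

# Wiring it into Trainer (not executed here; model and train_ds are placeholders):
# args = TrainingArguments(output_dir="out", remove_unused_columns=False)
# trainer = Trainer(model=model, args=args, train_dataset=train_ds,
#                   data_collator=embeddings_to_tensor_collator)
```

Variable-length sequences would need padding inside the collator before stacking; the fixed-size embeddings here keep the sketch short.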
Thank you in advance,
Petrina