Using custom models (not necessarily transformer-based) with generate() and sampling

I’m comparing a number of Seq2Seq models for a particular (non-NLP) task. I’ve got a version of this working in pure PyTorch, but I’m interested in porting the research to Hugging Face to take advantage of the existing sampling and beam-search capabilities. Most of my existing models are autoregressive LMs, but I also have an autoregressive LSTM baseline that would need to be ported to use the same sampling infrastructure.

Is this practical, or am I going to run into problems trying to use a non-transformer model in Transformers?

I assume that the way to do this is to derive from PreTrainedModel and return CausalLMOutput, but some of the optional parameters on the output (particularly attentions) are transformer-specific. Am I safe leaving these as None, or are there going to be times when the generation infrastructure relies on these values? Is there a better way to make use of Transformers’ generation/sampling/beam-search capabilities with custom models?

I think what you’re trying to do is feasible. As you mentioned, you could inherit from PreTrainedModel, but you could perhaps also just inherit from GenerationMixin. You’d also want to make sure the forward method of your custom class resembles the forward methods of Hugging Face models.

but some of the optional parameters on the output (particularly attentions) are transformer-specific. Am I safe leaving these as None, or are there going to be times when the generation infrastructure relies on these values?

You should be fine setting those to None. Most of the things in the return objects are only there so users can access them if desired. In fact, the attentions field of the return objects is usually None, because it’s controlled by the output_attentions arg, which is False by default.
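
For concreteness, here’s a rough sketch of the kind of thing I mean, using an LSTM LM as the custom model. The class and config names (LSTMLMConfig, LSTMForCausalLM) are made up for illustration, and the exact hooks generate() expects can differ between transformers versions (I believe recent releases want custom models to inherit GenerationMixin explicitly, which is why the sketch lists both base classes), so treat this as a starting point rather than working code:

```python
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutput


class LSTMLMConfig(PretrainedConfig):
    model_type = "custom_lstm_lm"

    def __init__(self, vocab_size=1000, hidden_size=256, num_layers=2,
                 pad_token_id=0, bos_token_id=1, eos_token_id=2, **kwargs):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id, **kwargs)


class LSTMForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = LSTMLMConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.lstm = nn.LSTM(config.hidden_size, config.hidden_size,
                            num_layers=config.num_layers, batch_first=True)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        hidden, _ = self.lstm(self.embed(input_ids))
        logits = self.lm_head(hidden)
        loss = None
        if labels is not None:
            # standard shifted next-token loss
            loss = nn.functional.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                labels[:, 1:].reshape(-1),
            )
        # attentions/hidden_states left as None; generate() only reads .logits
        return CausalLMOutput(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # no KV cache for an LSTM here, so just re-run the whole prefix each step
        return {"input_ids": input_ids}


model = LSTMForCausalLM(LSTMLMConfig())
prompt = torch.tensor([[1, 5, 7]])
sampled = model.generate(prompt, do_sample=True, max_new_tokens=10)
beamed = model.generate(prompt, num_beams=4, max_new_tokens=10)
```

The important bits are that forward accepts input_ids (plus whatever extra keyword arguments generate() passes along) and returns an output object with a logits field, and that prepare_inputs_for_generation tells generate() how to build the inputs for each decoding step. Without a key/value cache the whole prefix gets re-run every step, which is slow but fine for a baseline.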


Just inheriting GenerationMixin is a nice idea! I’ve got a version based on PreTrainedModel implemented now, but I haven’t made it fully compatible with everything that class expects.

I’ve opted to encapsulate the whole PreTrainedModel class inside a generate function to hide the details and only expose the few features that I’ve actually implemented.
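
Roughly along these lines (just an illustration of the shape of the wrapper, reusing the hypothetical LSTMForCausalLM from the sketch above):

```python
import torch


def generate(model, prompt_ids, max_new_tokens=50, num_beams=1, do_sample=False):
    # Thin wrapper around the hypothetical LSTMForCausalLM sketched above:
    # callers never touch the Hugging Face class directly, and only the
    # generation options that have actually been exercised are exposed.
    return model.generate(
        prompt_ids,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        do_sample=do_sample,
    )


# e.g. sampled = generate(model, torch.tensor([[1, 5, 7]]), do_sample=True)
```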