@patrickvonplaten, thank you for that. I was scratching my head about that for a bit, hahaha. Also, thank you for pointing out that the forum is better suited for this type of question than GitHub. I’ll update here with my GitHub post.
Here is my updated model:
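from transformers import GPT2LMHeadModel
from FeatureExtraction.NGrams import *

class GPT2FinetunedWithNgrams(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)

    def load_ngrams_model(self, ngrams_model_path):
        self.ngrams_model = NGrams(ngrams_model_path)

    def forward(self, input_ids=None, return_tuple=None):  # remaining arguments omitted
        return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
        transformer_outputs = self.transformer(
            input_ids,
            # ... (remaining keyword arguments omitted)
        )
        # the first element of the transformer output holds the hidden states
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)
        #use gpt2 to generate a span of text based off input_ids?
        #gpt2_sent = ???
        loss = self.ngrams_model.sentence_loss(gpt2_sent)
        return (loss, lm_logits)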
And here is my training script, which uses the Transformers Trainer:
from text_gen_w_transformers.finetune_gpt2 import GPT2FinetunedWithNgrams
from transformers import Trainer, TrainingArguments

model = GPT2FinetunedWithNgrams.from_pretrained('gpt2')
training_args = TrainingArguments(
    # ... (arguments omitted)
)
trainer = Trainer(
    # ... (arguments omitted)
)
My questions are:
You can see from the #gpt2_sent = ??? comment in the model code that I presume this is the place where I would generate a GPT-2 sequence from the version of GPT-2 that is currently being fine-tuned. However, I am not sure what the best way to do this is. Any recommendations?
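To make the first question concrete, here is a minimal sketch of one option I can think of: greedy decoding, where the highest-scoring token is picked at each position of the logits. This is a toy version on plain Python lists; `greedy_decode`, `toy_vocab`, and `toy_logits` are hypothetical names for illustration only. With real tensors I assume the equivalent would be `lm_logits.argmax(dim=-1)` followed by the tokenizer's `decode` method. One caveat I am fairly sure of: argmax is not differentiable, so a loss computed on the decoded string would not, by itself, pass gradients back into GPT-2.

```python
# Toy sketch: greedy decoding from per-position logits.
# All names here are hypothetical and only illustrate the idea.

def greedy_decode(logits, vocab):
    """Pick the highest-scoring token id at each position and join the tokens."""
    token_ids = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return token_ids, " ".join(vocab[i] for i in token_ids)

toy_vocab = ["the", "cat", "sat", "mat"]
toy_logits = [
    [2.0, 0.1, 0.3, 0.0],  # position 0: "the" scores highest
    [0.2, 1.5, 0.1, 0.4],  # position 1: "cat" scores highest
    [0.0, 0.3, 3.1, 0.2],  # position 2: "sat" scores highest
]

ids, gpt2_sent = greedy_decode(toy_logits, toy_vocab)
print(ids, gpt2_sent)
```

In the real model, the logits at position i are predictions for the token at position i + 1, so the decoded sentence would be shifted by one relative to input_ids.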
In the training script, I am using the Trainer class. However, I don’t understand what the train_dataset parameter of Trainer expects. I have a CSV file that contains one sequence per line, but I have a feeling I need to construct a Dataset object or something.
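My current guess for the second question is a map-style dataset, meaning an object that implements `__len__` and `__getitem__`. Below is a rough sketch under that assumption. `LineByLineCsvDataset` and `toy_tokenize` are hypothetical names; in practice I would expect to subclass `torch.utils.data.Dataset`, tokenize with the GPT-2 tokenizer, and return tensors, with padding handled by a data collator.

```python
import csv
import io

class LineByLineCsvDataset:
    """Map-style dataset: one text sequence per CSV row, tokenized on access."""

    def __init__(self, csv_file, tokenize):
        # csv_file: an open text file object; tokenize: callable str -> list of ints
        self.rows = [row[0] for row in csv.reader(csv_file) if row]
        self.tokenize = tokenize

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        ids = self.tokenize(self.rows[idx])
        # For causal LM fine-tuning the labels are usually the input ids themselves.
        return {"input_ids": ids, "labels": list(ids)}

def toy_tokenize(text):
    # Stand-in tokenizer (assumption): maps each word to its length.
    return [len(word) for word in text.split()]

fake_csv = io.StringIO("hello world\nfine tuning gpt2 models\n")
dataset = LineByLineCsvDataset(fake_csv, toy_tokenize)
print(len(dataset), dataset[1])
```

An object like `dataset` is what I would then try passing as train_dataset.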
I haven’t tried to run this code yet because I need to fill in the two parts above, but I also think I’m not setting any of the parameters for transformer_outputs. It looks like they are set to None, and I don’t know if that will be problematic. Any thoughts on this?
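On the None defaults, my understanding (an assumption on my part, not something I have verified in the source) is that Trainer passes each batch to the model as keyword arguments, roughly `model(**inputs)`. If so, None would only be the fallback for arguments the batch does not supply. A toy stand-in for forward shows the mechanism; the function and the batch below are hypothetical:

```python
def forward(input_ids=None, attention_mask=None, labels=None):
    """Stand-in for a model's forward: report which arguments were actually supplied."""
    supplied = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
    return sorted(name for name, value in supplied.items() if value is not None)

# A batch shaped like the dict a dataset or collator might produce (toy values).
batch = {"input_ids": [[1, 2, 3]], "labels": [[1, 2, 3]]}

# Keys in the batch override the None defaults; attention_mask stays None.
received = forward(**batch)
print(received)
```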
I’ve been reading through the documentation and really like the library. I’m also new to it and to PyTorch, so I apologize if my questions are pretty basic. Thanks in advance for your help!