After using the map function to tokenize all splits via a lambda function, I'm unable to pass the encoded text/inputs forward through Trainer. Thanks in advance for any direction!
# Load the raw CSV, drop the columns we don't train on, and carve the single
# "train" split into train / test / valid (70% / 15% / 15%).
category_data = load_dataset("csv", data_files="testdatav2.csv")
category_data = category_data.remove_columns(
    ["amazoncontactid", "regiid", "lineofbusiness", "primary_label"]
)
category_data = category_data['train']

# First split holds out 30%; the second halves that hold-out into test/valid.
first_split = category_data.train_test_split(test_size=0.3)
second_split = first_split['test'].train_test_split(test_size=0.5)

from datasets.dataset_dict import DatasetDict

cd = DatasetDict(
    {
        'train': first_split['train'],
        'test': second_split['test'],
        'valid': second_split['train'],
    }
)
print(cd)
# NOTE(review): this snippet is an exact byte-for-byte duplicate of the one
# above — almost certainly an accidental paste; running it a second time
# redoes the download/split (and, with no fixed seed, produces a different
# random split each run).
category_data = load_dataset("csv", data_files="testdatav2.csv")
# Drop ID/label columns that are not model inputs.
category_data = category_data.remove_columns(["amazoncontactid", "regiid", "lineofbusiness", "primary_label"])
category_data = category_data['train']
# 70% train / 30% held out, then split the hold-out in half: 15% test, 15% valid.
train_testvalid = category_data.train_test_split(test_size=0.3)
test_valid = train_testvalid['test'].train_test_split(test_size=0.5)
from datasets.dataset_dict import DatasetDict
cd = DatasetDict({
'train': train_testvalid['train'],
'test': test_valid['test'],
'valid': test_valid['train']})
print(cd)
Using custom data configuration default-89c081370f72e624
Found cached dataset csv (/root/.cache/huggingface/datasets/csv/default-89c081370f72e624/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)
100%
1/1 [00:00<00:00, 31.60it/s]
DatasetDict({
train: Dataset({
features: ['transcript', 'idx'],
num_rows: 858
})
test: Dataset({
features: ['transcript', 'idx'],
num_rows: 185
})
valid: Dataset({
features: ['transcript', 'idx'],
num_rows: 184
})
})
Here is where I tokenize my examples:
from transformers import AutoTokenizer

# Tokenizer matching the model checkpoint we fine-tune.
model_transcripts = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_transcripts)

# Tokenize every split; batched=True lets the tokenizer process whole batches
# of transcripts per call instead of one example at a time.
transcripts_encoded = cd.map(lambda examples: tokenizer(examples["transcript"]), batched=True)
print(transcripts_encoded)

# BUG FIX: Dataset.set_format() mutates the dataset IN PLACE and returns None.
# The original code reassigned its return value
# (transcripts_encoded = transcripts_encoded.set_format(...)), which replaced
# the encoded DatasetDict with None — exactly the `None` printed below and why
# Trainer then had nothing to consume. Call it without assignment (or use
# .with_format(...), which returns a new formatted dataset).
# The stray trailing backquote after the call was also removed.
transcripts_encoded.set_format("torch",
columns=["input_ids", "attention_mask", "idx"])
print(transcripts_encoded)
None
I've checked my variable for any NaN values and can confirm there are none. The `None` printed above comes from reassigning the return value of `set_format`, which formats the dataset in place and returns `None`. Any help would be greatly appreciated!