How to sample from the validation set when using Trainer?

When using the Trainer, e.g.

# set training arguments - these params are not really tuned, feel free to change
# NOTE(review): this snippet is truncated as posted. The closing parenthesis of
# Seq2SeqTrainingArguments( and the arguments to Seq2SeqTrainer( are missing,
# so it does not run as-is.
training_args = Seq2SeqTrainingArguments(
    logging_steps=2,  # set to 1000 for full training
    save_steps=16,    # set to 500 for full training
    eval_steps=4,     # set to 8000 for full training
    warmup_steps=1,   # set to 2000 for full training
    max_steps=16,     # delete for full training
    # overwrite_output_dir=True,

# instantiate trainer
trainer = Seq2SeqTrainer(

Is there some way to randomly select/sample a subset of eval_data at every evaluation, i.e. every `eval_steps` steps?

E.g. I have tried

# NOTE(review): the right-hand side of this assignment and the Seq2SeqTrainer(...)
# arguments were lost in the post. Per the surrounding text, eval_data was presumably a
# fixed subset of the eval set chosen once before training; confirm against the original.
eval_data =
trainer = Seq2SeqTrainer(

But that statically defines the eval_data subset before training. Is it possible to do the selection during training, so that a different subset is drawn at every evaluation point?


No, that’s not supported. The evaluation dataset has to be fixed.


Thanks @sgugger for the prompt reply!

1 Like


  1. Will this ever be supported?
  2. Why does the eval set have to be fixed?
  3. How do we avoid memory issues? When both the training set and the eval set sample data for logging, I’ve run into OOM (out-of-memory) errors. Can the two run separately?

FYI, here is the best solution I know of:

    # - Get eval data set (AF for us),
    per_device_eval_batch_size = 4  # TODO: change to something larger, right now due to size of my debug0
    # eval_steps=1000
    # TODO: probably need to write a collate_fn for the eval so that the eval is done right?
    # TODO: we need ppl (and ideally token edit distance for eval, reason explained here:
    path, name = 'brando/debug1_af', None
    # Non-streaming load of the "test" split, formatted to return torch tensors.
    eval_dataset = load_dataset(path, name, streaming=False, split="test").with_format(type="torch") 
    # NOTE(review): the right-hand side below was lost in the post. As written this line is a
    # syntax error; presumably it selected a subset of eval_dataset — confirm.
    eval_dataset =
    ## eval_dataset = train_dataset  # TODO: fix obviously to something else using af
    # Take a small batch only to discover the column names. The non-streaming branch after
    # `else` was lost in the post. NOTE(review): `streaming` is not defined in this snippet.
    raw_text_batch = eval_dataset.take(per_device_eval_batch_size) if streaming else
    # Names of all raw columns; they are all dropped after tokenization (remove_columns below).
    column_names = next(iter(raw_text_batch)).keys()
    def eval_preprocess(examples):
        """Tokenize each (formal, informal) statement pair as one space-joined string.

        Meant for ``Dataset.map(..., batched=True)``, where every column value is a list of
        strings. Returns one sequence per example, padded/truncated to ``max_length`` tokens.
        ``tokenizer`` and ``max_length`` come from the enclosing scope.

        Raises:
            TypeError: if called on a single (unbatched) example.
        """
        formal = examples["formal statement"]
        informal = examples["generated informal statement"]
        if isinstance(formal, str):
            # Unbatched input: zip() below would silently pair up individual characters.
            raise TypeError("eval_preprocess expects batched input: use .map(..., batched=True)")
        # Join per example. Concatenating the column *lists* (`formal + [' '] + informal`)
        # yields 2n + 1 separate sequences for a batch of n instead of n joined pairs.
        texts = [f"{formal_text} {informal_text}" for formal_text, informal_text in zip(formal, informal)]
        return tokenizer(texts, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt")
    remove_columns = column_names  # drop every raw column after tokenization to avoid bugs in the collate function of task2vec's pytorch data loader
    # NOTE(review): this helper shadows the built-in `map`. Its return expression was garbled in
    # the post (the receiver and the mapped function are missing); presumably it was
    # `batch.map(eval_preprocess, batched=True, remove_columns=remove_columns)` — confirm.
    def map(batch):
        return, batched=True, remove_columns=remove_columns)
    # Tokenize the eval set via the helper above.
    eval_dataset = map(eval_dataset)
    # NOTE(review): self-assignment with no effect; likely a leftover of a lost transformation.
    train_dataset = train_dataset

    # -- Compute max steps
    per_device_train_batch_size = batch_size
    # TODO: with streaming datasets the size is hard to get (count the rows, read it from the
    # metadata, or use a huge number). `num_rows` requires a non-streaming (map-style) dataset.
    dataset_size: int = train_dataset.num_rows
    # HF `max_steps` counts optimizer updates, and each update consumes
    # gradient_accumulation_steps micro-batches of per_device_train_batch_size examples.
    # Dividing by the batch size alone would train gradient_accumulation_steps times too long.
    # Ceiling division: the last partial micro-batch still runs (assumes dataloader_drop_last=False).
    num_micro_batches: int = -(-dataset_size // per_device_train_batch_size)
    steps_per_epoch: int = max(num_micro_batches // gradient_accumulation_steps, 1)
    # NOTE(review): under multi-device (DDP) training each update also spans every device, so
    # divide by the world size as well in that case.
    # NOTE(review): `num_epochs` is not assigned in this snippet (`num_epochs = 1` was commented
    # out in the post). Make sure it is defined before this point.
    max_steps = steps_per_epoch * num_epochs
    print(f'{num_epochs=} {max_steps=}')
    ## DOESNT WORK num_train_epochs = 3  # TODO: since I decided to do streaming = False and if we collect enough data it's unlikely we see it all hopefully (if we do 3 times seems good given that LLMs are trained to see the data only once this seems a sensible soln, + in the imagenet days things were trained to convergence with no overfitting ref:

    # -- Define custom collate function
    def custom_collate_fn(data: list[dict[str, str]], tokenizer: PreTrainedTokenizer) -> dict[str, torch.Tensor]:
        """Collate raw text examples into a causal-LM batch that trains on the first EOS only.

        Each example's ``"text"`` value (missing or falsy values become ``""``) is tokenized and
        padded/truncated to ``max_length``, which is read from the enclosing scope. The
        ``input_ids`` are then cloned into ``labels``; cloning means masking labels never
        modifies the inputs. In every row that contains the EOS token:

        * the first EOS keeps its label and gets ``attention_mask = 1``;
        * every later EOS gets ``label = -100`` so the loss ignores it. When EOS doubles as
          the pad token, this covers all padding.

        Args:
            data: Examples as dicts; only the ``"text"`` key is read.
            tokenizer: Tokenizer used for encoding. Side effect: if it has no pad token, its
                ``pad_token`` is set to ``eos_token``.

        Returns:
            The tokenizer output (``input_ids``, ``attention_mask``, ...) plus ``labels``.
        """
        # The original note says training uses the full (llama) context length, so padding is
        # not expected. The EOS-as-pad fallback is still needed for padding="max_length" to work.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Every example contributes exactly one sequence. The original loop never appended, so
        # it tokenized an empty list.
        sequences: list[str] = [example.get("text", "") or "" for example in data]
        tokenized_data = tokenizer(sequences, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt")
        tokenized_data["labels"] = tokenized_data["input_ids"].clone()  # HF models read the "labels" key for the loss
        eos_token_id = tokenizer.eos_token_id
        for row, input_ids in enumerate(tokenized_data["input_ids"]):
            eos_positions = (input_ids == eos_token_id).nonzero(as_tuple=True)[0]
            if eos_positions.nelement() == 0:
                continue  # no EOS in this row: leave mask and labels untouched
            tokenized_data["attention_mask"][row, eos_positions[0]] = 1
            tokenized_data["labels"][row, eos_positions[1:]] = -100  # no-op when there is one EOS
        return tokenized_data

    # - Debug before training to see data
    # NOTE(review): the expression before `if not` (the map-style branch) was lost in the post;
    # as written this line is a syntax error.
    sample_data = if not isinstance(train_dataset, datasets.iterable_dataset.IterableDataset) else train_dataset.take(per_device_train_batch_size)
    # Run the collator once on a sample so data problems show up before training starts.
    processed_data = custom_collate_fn(sample_data, tokenizer=tokenizer)

    # -- Training arguments and trainer instantiation ref:
    # Results go to a date-stamped directory, or to a fixed one when debugging.
    output_dir = Path(f'~/data/maf_data/results_{today}/').expanduser() if not debug else Path(f'~/data/maf_data/results/').expanduser()
    print(f'{debug=} {output_dir=} \n {report_to=}')
    # NOTE(review): the closing `)` of TrainingArguments( is missing in the post.
    training_args = TrainingArguments(
        output_dir=output_dir,  # The output directory where the model predictions and checkpoints will be written.
        # num_train_epochs = num_train_epochs,
        max_steps=max_steps,  # TODO: hard to fix, see above
        gradient_accumulation_steps=gradient_accumulation_steps,  # based on alpaca, allows to process effective_batch_size = gradient_accumulation_steps * batch_size, num its to accumulate before opt update step
        gradient_checkpointing = gradient_checkpointing,  # TODO depending on hardware set to true?
        optim="paged_adamw_32bit",  # David hall says to keep 32bit opt TODO: if we are using brain float 16 bf16 should we be using 32 bit? are optimizers always fp32?
        # NOTE(review): per the HF TrainingArguments docs, warmup_steps overrides warmup_ratio,
        # so warmup_ratio=0.03 below is most likely ignored. Keep only one of the two.
        warmup_steps=500,  # TODO: once real training starts we can select this number for llama v2, what does llama v2 do to make it stable while v1 didn't?
        warmup_ratio=0.03,  # copying alpaca for now, fraction of total steps used for a linear warmup, TODO once real training starts change?
        # weight_decay=0.01,  # TODO once real training change?
        weight_decay=0.00,  # TODO once real training change?
        learning_rate = 1e-5,  # TODO once real training change? anything larger than 1e-3 I've had terrible experiences with
        max_grad_norm=1.0, # TODO once real training change?
        lr_scheduler_type="cosine",  # TODO once real training change? using what I've seen most in vision
        save_steps=2000,  # alpaca does 2000, other defaults were 500
        # logging_steps=250,
        # logging_steps=50,
        # Presumably required because custom_collate_fn reads the raw "text" column, which the
        # Trainer would otherwise drop as it is not a model input — confirm.
        remove_unused_columns=False,
        report_to=report_to,  # change to wandb!
        fp16=False,  # never ever set to True
        # NOTE(review): needs a CUDA device; torch.cuda.current_device() fails without one.
        bf16=torch.cuda.get_device_capability(torch.cuda.current_device())[0] >= 8,  # bf16 when capability >= 8 (bf16 available), otherwise fp32 since fp16 is off; set False to force fp32
    # print(f'{training_args=}')

    # TODO: might be nice to figure out how llamav2 counts the number of tokens they've trained on
    # NOTE(review): the Trainer(...) call is truncated in the post; the remaining arguments and
    # the closing parenthesis are missing.
    trainer = Trainer(
        # Bind the tokenizer so the Trainer can call the collator with just the list of examples.
        data_collator=lambda data: custom_collate_fn(data, tokenizer=tokenizer)

All links related to this question: