Using hyperparameter-search in Trainer

Hi @sgugger, I am using Ray Tune with Hugging Face for hyperparameter tuning. Here is my code snippet:

from ray.tune.schedulers import PopulationBasedTraining
from ray.tune import uniform
from random import randint

# Population Based Training scheduler that maximizes `mean_accuracy`.
scheduler = PopulationBasedTraining(
    mode="max",
    metric="mean_accuracy",
    perturbation_interval=2,
    hyperparam_mutations={
        # `uniform` here is `ray.tune.uniform`, which builds a search-space
        # object rather than drawing a float. It is therefore passed
        # directly instead of being wrapped in a lambda: a lambda would hand
        # the search-space object itself back as the "sampled" value.
        "weight_decay": uniform(0.0, 0.3),
        "learning_rate": uniform(1e-5, 5e-5),
        # Lists are treated as a set of categorical choices.
        "per_gpu_train_batch_size": [16, 32, 64],
        "num_train_epochs": [2, 3, 4],
        # `random.randint` returns a plain int, so a callable works here.
        "warmup_steps": lambda: randint(0, 500),
    },
)

# NOTE(review): the mutation fix above is independent of the
# `ModuleNotFoundError: No module named 'datasets_modules'` reported below,
# which is raised while Ray deserializes the objects passed to the trial.
best_trial = trainer.hyperparameter_search(
    direction="maximize",
    backend="ray",
    n_trials=4,
    keep_checkpoints_num=1,
    scheduler=scheduler,
)

However, this code results in the following error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/ray/tune/trial_runner.py", line 586, in _process_trial
    results = self.trial_executor.fetch_result(trial)
  File "/usr/local/lib/python3.7/dist-packages/ray/tune/ray_trial_executor.py", line 609, in fetch_result
    result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
  File "/usr/local/lib/python3.7/dist-packages/ray/_private/client_mode_hook.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/ray/worker.py", line 1456, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(TuneError): ray::ImplicitFunc.train_buffered() (pid=800, ip=172.28.0.2)
  File "python/ray/_raylet.pyx", line 480, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 432, in ray._raylet.execute_task.function_executor
  File "/usr/local/lib/python3.7/dist-packages/ray/tune/trainable.py", line 167, in train_buffered
    result = self.train()
  File "/usr/local/lib/python3.7/dist-packages/ray/tune/trainable.py", line 226, in train
    result = self.step()
  File "/usr/local/lib/python3.7/dist-packages/ray/tune/function_runner.py", line 366, in step
    self._report_thread_runner_error(block=True)
  File "/usr/local/lib/python3.7/dist-packages/ray/tune/function_runner.py", line 513, in _report_thread_runner_error
    ("Trial raised an exception. Traceback:\n{}".format(err_tb_str)
ray.tune.error.TuneError: Trial raised an exception. Traceback:
ray::ImplicitFunc.train_buffered() (pid=800, ip=172.28.0.2)
  File "/usr/local/lib/python3.7/dist-packages/ray/tune/function_runner.py", line 248, in run
    self._entrypoint()
  File "/usr/local/lib/python3.7/dist-packages/ray/tune/function_runner.py", line 316, in entrypoint
    self._status_reporter.get_checkpoint())
  File "/usr/local/lib/python3.7/dist-packages/ray/tune/function_runner.py", line 576, in _trainable_func
    output = fn()
  File "/usr/local/lib/python3.7/dist-packages/ray/tune/function_runner.py", line 651, in _inner
    inner(config, checkpoint_dir=None)
  File "/usr/local/lib/python3.7/dist-packages/ray/tune/function_runner.py", line 644, in inner
    fn_kwargs[k] = parameter_registry.get(prefix + k)
  File "/usr/local/lib/python3.7/dist-packages/ray/tune/registry.py", line 167, in get
    return ray.get(self.references[k])
  File "/usr/local/lib/python3.7/dist-packages/ray/_private/client_mode_hook.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/ray/serialization.py", line 245, in deserialize_objects
    self._deserialize_object(data, metadata, object_ref))
  File "/usr/local/lib/python3.7/dist-packages/ray/serialization.py", line 192, in _deserialize_object
    return self._deserialize_msgpack_data(data, metadata_fields)
  File "/usr/local/lib/python3.7/dist-packages/ray/serialization.py", line 170, in _deserialize_msgpack_data
    python_objects = self._deserialize_pickle5_data(pickle5_data)
  File "/usr/local/lib/python3.7/dist-packages/ray/serialization.py", line 160, in _deserialize_pickle5_data
    obj = pickle.loads(in_band)
ModuleNotFoundError: No module named 'datasets_modules'

I would really appreciate any help identifying the cause of this problem, thanks! :slight_smile:

Note: I have the dataset correctly imported, and everything else works as expected; only this snippet produces the error shown above.