RuntimeError: Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32 notwithstanding

Running the trainer fails with the following error:

Traceback (most recent call last):
  File "/home/dhl/LongChat-dev/longchat/dist_attn/train.py", line 9, in <module>
    train()
  File "/home/dhl/LongChat-dev/longchat/dist_attn/train_lightseq.py", line 383, in train
    trainer.train()
  File "/home/dhl/miniconda3/envs/long/lib/python3.10/site-packages/transformers/trainer.py", line 1859, in train
    return inner_training_loop(
  File "/home/dhl/miniconda3/envs/long/lib/python3.10/site-packages/transformers/trainer.py", line 2266, in _inner_training_loop
    self.optimizer.step()
  File "/home/dhl/miniconda3/envs/long/lib/python3.10/site-packages/accelerate/optimizer.py", line 170, in step
    self.optimizer.step(closure)
  File "/home/dhl/miniconda3/envs/long/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 75, in wrapper
    return wrapped(*args, **kwargs)
  File "/home/dhl/miniconda3/envs/long/lib/python3.10/site-packages/torch/optim/optimizer.py", line 385, in wrapper
    out = func(*args, **kwargs)
  File "/home/dhl/miniconda3/envs/long/lib/python3.10/site-packages/torch/optim/optimizer.py", line 76, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/home/dhl/miniconda3/envs/long/lib/python3.10/site-packages/torch/optim/adamw.py", line 187, in step
    adamw(
  File "/home/dhl/miniconda3/envs/long/lib/python3.10/site-packages/torch/optim/adamw.py", line 339, in adamw
    func(
  File "/home/dhl/miniconda3/envs/long/lib/python3.10/site-packages/torch/optim/adamw.py", line 516, in _multi_tensor_adamw
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([
  File "/home/dhl/miniconda3/envs/long/lib/python3.10/site-packages/torch/optim/optimizer.py", line 409, in _group_tensors_by_device_and_dtype
    return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
  File "/home/dhl/miniconda3/envs/long/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/dhl/miniconda3/envs/long/lib/python3.10/site-packages/torch/utils/_foreach_utils.py", line 38, in _group_tensors_by_device_and_dtype
    torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices).items()
RuntimeError: Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32 notwithstanding
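
For context, this error is raised by the multi-tensor ("foreach") AdamW path: torch groups each parameter together with its gradient and its exp_avg / exp_avg_sq state by (device, dtype) and requires that, for a given parameter index, all of these tensors agree; only the `step` tensors are allowed to stay on CPU in float32. Below is a minimal, hypothetical reproduction (not taken from the training script, and assuming a CUDA device) that hits the same check on torch 2.2 by creating the optimizer state in fp32 and then casting the parameter to fp16:

import torch

# Hypothetical reproduction: the optimizer state is created in fp32, then the
# parameter is cast to fp16, so param/grad (fp16) and exp_avg/exp_avg_sq (fp32)
# at the same index no longer share a dtype.
p = torch.nn.Parameter(torch.randn(8, device="cuda"))
opt = torch.optim.AdamW([p], lr=1e-3, foreach=True)

p.grad = torch.zeros_like(p)
opt.step()                      # creates fp32 exp_avg / exp_avg_sq state

p.data = p.data.half()          # parameter is now fp16, state is still fp32
p.grad = torch.zeros_like(p)    # fp16 gradient
opt.step()                      # RuntimeError: Tensors of the same index must be on the same device and the same dtype ...

So the error means that, for at least one parameter, the parameter/gradient tensors and the AdamW state tensors disagree in dtype or device by the time optimizer.step() runs.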

code (excerpt from train_lightseq.py)

class RandomDataset(Dataset):
    """Dataset with random input_ids and labels, attention_mask all ones."""
    
    def __init__(self, args: TrainingArguments, size: int):
        super(RandomDataset, self).__init__()
        self.size = size
        self.seq_size = get_sequence_parallel_size()
        self.model_max_length = args.model_max_length // self.seq_size
        
        vocab_size = 1000
        # Randomly generating input_ids and labels
        self.input_ids = torch.randint(
            low=0,
            high=vocab_size,
            size=(self.size, self.model_max_length),
            dtype=torch.long
        )
        self.labels = torch.randint(
            low=0,
            high=vocab_size,
            size=(self.size, self.model_max_length),
            dtype=torch.long
        )
        # attention_mask all ones
        self.attention_mask = torch.ones(
            (self.size, self.model_max_length),
            dtype=torch.long
        )

    def __len__(self):
        return self.size

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if isinstance(i, list):
            raise NotImplementedError(f"batch indexing (bs > 1) is not supported: {i}")
        return dict(
            input_ids=self.input_ids[i],
            labels=self.labels[i],
            attention_mask=self.attention_mask[i],
        )

def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer, data_args, training_args
) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = RandomDataset(training_args, 256)
    rank0_print("Loading data...")

    #train_json = json.load(open(data_args.data_path, "r"))
    #train_dataset = dataset_cls(train_json, tokenizer=tokenizer)

    #if data_args.eval_data_path:
    #    eval_json = json.load(open(data_args.eval_data_path, "r"))
    #    eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer)
    #else:
    eval_dataset = None

    return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)


def train():
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank

    # Set RoPE scaling factor
    config = transformers.AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        trust_remote_code=model_args.trust_remote_code,
    )
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    if orig_ctx_len and training_args.model_max_length > orig_ctx_len:
        scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
        config.rope_scaling = {"type": "linear", "factor": scaling_factor}
    config.use_cache = False

    # Load model and tokenizer
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=training_args.cache_dir,
        trust_remote_code=model_args.trust_remote_code,
    )#.to(dtype=torch.float16)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side=model_args.padding_side,
        use_fast=False,
        trust_remote_code=model_args.trust_remote_code,
    )

    if tokenizer.pad_token != tokenizer.unk_token:
        tokenizer.pad_token = tokenizer.unk_token

    # Load data
    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args, training_args=training_args)

    # Start the trainer
    trainer = Trainer(
        model=model, tokenizer=tokenizer, args=training_args, **data_module
    )
    if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    # Save model
    model.config.use_cache = True
    trainer.save_state()
    if trainer.is_deepspeed_enabled:
        trainer.save_model()
    else:
        trainer_save_model_safe(trainer)


if __name__ == "__main__":
    train()
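
To narrow down which tensors disagree, a small debugging helper (hypothetical, not part of the script above) can print the dtype and device of every parameter, its gradient, and its AdamW state tensors. Calling it on the trainer's optimizer (e.g. trainer.optimizer, assuming the wrapped optimizer still exposes param_groups and state) right before the failing step shows the mismatching index:

import torch

def dump_optimizer_tensors(optimizer: torch.optim.Optimizer) -> None:
    # Hypothetical debugging helper: list dtype/device for each parameter,
    # its gradient, and its optimizer state so the tensors that violate the
    # "same device and same dtype" grouping check can be located.
    for g_idx, group in enumerate(optimizer.param_groups):
        for p_idx, p in enumerate(group["params"]):
            parts = [f"param {p.dtype}/{p.device}"]
            if p.grad is not None:
                parts.append(f"grad {p.grad.dtype}/{p.grad.device}")
            for name, t in optimizer.state.get(p, {}).items():
                if torch.is_tensor(t):
                    parts.append(f"{name} {t.dtype}/{t.device}")
            print(f"group {g_idx}, param {p_idx}: " + ", ".join(parts))

If the dump shows, for example, fp16 parameters next to fp32 exp_avg / exp_avg_sq tensors, or a CPU tensor next to CUDA tensors, those are the entries the grouping check rejects.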

versions (pip freeze)

accelerate==0.30.0
aiofiles==23.2.1
aiohttp==3.9.5
aiosignal==1.3.1
altair==5.3.0
annotated-types==0.6.0
anyio==4.3.0
appdirs==1.4.4
async-timeout==4.0.3
attrs==23.2.0
certifi==2024.2.2
charset-normalizer==3.3.2
click==8.1.7
contourpy==1.2.1
cpm-kernels==1.0.11
cycler==0.12.1
distro==1.9.0
dnspython==2.6.1
docker-pycreds==0.4.0
einops==0.8.0
email_validator==2.1.1
exceptiongroup==1.2.1
fastapi==0.111.0
fastapi-cli==0.0.2
ffmpy==0.3.2
filelock==3.13.1
flash-attn==2.0.9
fonttools==4.51.0
frozenlist==1.4.1
fschat==0.2.36
fsspec==2024.2.0
gitdb==4.0.11
GitPython==3.1.43
gradio==4.29.0
gradio_client==0.16.1
h11==0.14.0
httpcore==1.0.5
httptools==0.6.1
httpx==0.27.0
huggingface-hub==0.23.0
idna==3.7
importlib_resources==6.4.0
iniconfig==2.0.0
Jinja2==3.1.3
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
kiwisolver==1.4.5
-e git+https://github.com/RulinShao/LongChat-dev@efe79a799a2f2919706f7172035290b6e7ecf0c2#egg=longchat
markdown-it-py==3.0.0
markdown2==2.4.13
MarkupSafe==2.1.5
matplotlib==3.8.4
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.5
networkx==3.2.1
nh3==0.2.17
ninja==1.11.1.1
numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.1.105
nvidia-nvtx-cu12==12.1.105
openai==1.25.1
orjson==3.10.3
packaging==24.0
pandas==2.2.2
pillow==10.2.0
pluggy==1.5.0
prompt-toolkit==3.0.43
protobuf==4.25.3
psutil==5.9.8
pydantic==2.7.1
pydantic_core==2.18.2
pydub==0.25.1
Pygments==2.18.0
pyparsing==3.1.2
pytest==8.2.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-multipart==0.0.9
pytz==2024.1
PyYAML==6.0.1
referencing==0.35.1
regex==2024.4.28
requests==2.31.0
rich==13.7.1
rpds-py==0.18.0
ruff==0.4.3
safetensors==0.4.3
semantic-version==2.10.0
sentencepiece==0.2.0
sentry-sdk==2.0.1
setproctitle==1.3.3
shellingham==1.5.4
shortuuid==1.0.13
six==1.16.0
smmap==5.0.1
sniffio==1.3.1
starlette==0.37.2
svgwrite==1.4.3
sympy==1.12
tiktoken==0.6.0
tokenizers==0.19.1
tomli==2.0.1
tomlkit==0.12.0
toolz==0.12.1
torch==2.2.2+cu121
torchaudio==2.2.2+cu121
torchvision==0.17.2+cu121
tqdm==4.66.4
transformers==4.40.1
triton==2.1.0
typer==0.12.3
typing_extensions==4.9.0
tzdata==2024.1
ujson==5.9.0
urllib3==2.2.1
uvicorn==0.29.0
uvloop==0.19.0
wandb==0.16.6
watchfiles==0.21.0
wavedrom==2.0.3.post3
wcwidth==0.2.13
websockets==11.0.3
yarl==1.9.4

Training on a single node with 8x A100 GPUs.
CUDA 12.1