I’m trying to train a model with very standard HF code I’ve used before:
import os
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch
from pathlib import Path
import glob
def preprocess_function_proofnet_simple(examples: dict[str, list], tokenizer, max_length: int = 1024) -> dict[str, torch.Tensor]:
    """
    Preprocess the input data for the proofnet dataset.

    Args:
        examples: The examples to preprocess.
        tokenizer: The tokenizer for encoding the texts.

    Returns:
        The processed model inputs.
    """
    inputs = [f"{examples['nl_statement'][i]}{tokenizer.eos_token}{examples['formal_statement'][i]}" for i in range(len(examples['nl_statement']))]
    model_inputs = tokenizer(inputs, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
    labels = model_inputs.input_ids.clone()
    labels[labels == tokenizer.pad_token_id] = -100
    model_inputs["labels"] = labels
    return model_inputs

def get_proofnet_dataset(tokenizer, preprocess_function=preprocess_function_proofnet_simple):
    dataset_val = load_dataset("hoskinson-center/proofnet", split='validation')
    dataset_test = load_dataset("hoskinson-center/proofnet", split='test')
    val_dataset = dataset_val.map(lambda examples: preprocess_function(examples, tokenizer), batched=True, remove_columns=["nl_statement", "formal_statement"])
    test_dataset = dataset_test.map(lambda examples: preprocess_function(examples, tokenizer), batched=True, remove_columns=["nl_statement", "formal_statement"])
    return val_dataset, test_dataset
print()
# export PYTORCH_ENABLE_MPS_FALLBACK=1
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
# Load Hugging Face token from file
with open(Path("~/keys/hf_file_key.txt").expanduser(), "r") as file:
    hf_token = file.read().strip()
# Set the Hugging Face token as an environment variable
os.environ["HF_TOKEN"] = hf_token
# Login using the token
from huggingface_hub import login
login(token=os.getenv("HF_TOKEN"))
# Load model and tokenizer
pretrained_model_name_or_path = "openai-community/gpt2"
if 'gpt2' in pretrained_model_name_or_path:
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    print(f'{tokenizer.pad_token=}')
    print(f'{tokenizer.eos_token=}\n{tokenizer.eos_token_id=}')
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path)
# device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
# device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
device = torch.device('cpu')
print(f'{device=}')
model = model.to(device)
max_length: int = tokenizer.model_max_length
print(f'{max_length=}')
# Define training arguments with memory optimization tricks
training_args = TrainingArguments(
    output_dir="~/tmp/results",  # Output directory for saving model checkpoints
    per_device_train_batch_size=1,  # Training batch size per device
    per_device_eval_batch_size=1,  # Evaluation batch size per device
    max_steps=2,  # Total number of training steps
    logging_dir='~/tmp/logs',  # Directory for storing logs
    logging_steps=10,  # Frequency of logging steps
    gradient_accumulation_steps=1,  # Accumulate gradients to simulate a larger batch size
    save_steps=500,  # Save checkpoint every 500 steps
    save_total_limit=3,  # Only keep the last 3 checkpoints
    evaluation_strategy="steps",  # Evaluate model at specified steps
    eval_steps=100,  # Evaluate every 100 steps
    gradient_checkpointing=True,  # Enable gradient checkpointing to save memory
    optim="paged_adamw_32bit",  # Optimizer choice with memory optimization
    learning_rate=1e-5,  # Learning rate for training
    warmup_ratio=0.01,  # Warmup ratio for learning rate schedule
    weight_decay=0.01,  # Weight decay for regularization
    lr_scheduler_type='cosine',  # Learning rate scheduler type
    report_to="none",  # Disable reporting to external tracking tools
    # bf16=torch.cuda.is_bf16_supported(),  # Use BF16 if supported by the hardware
    half_precision_backend="auto",  # Automatically select the best backend for mixed precision
    # dataloader_num_workers=4,  # TODO Number of subprocesses for data loading
    # dataloader_pin_memory=True,  # TODO periphery, Pin memory in data loaders for faster transfer to GPU
    # skip_memory_metrics=True,  # Skip memory metrics to save memory
    # dataloader_prefetch_factor=2,  # TODO periphery, Number of batches to prefetch
    # torchdynamo="nvfuser",  # TODO periphery, Use NVFuser backend for optimized torch operations
    full_determinism=True,  # TODO periphery, Ensure reproducibility
    use_cpu=True,
)
train_dataset, test_dataset = get_proofnet_dataset(tokenizer)
# Initialize the Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
# eval_dataset=eval_dataset,
)
# Start training
print(f'\n-- Start training')
trainer.train()
# Save the model and tokenizer
trainer.save_model(output_dir="~/tmp/results")
tokenizer.save_pretrained("~/tmp/results")  # save_pretrained takes the directory as a positional arg, not output_dir=
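As a side note, to rule out the data side on its own, here is a quick standalone check I could run with the same tokenizer (the example strings are made up by me, not taken from the dataset):

    # Standalone sanity check of the preprocessing (my own addition, not part of the training script):
    # run a tiny fake batch through the preprocess function and inspect shapes and label masking.
    fake_batch = {
        'nl_statement': ["If x is even then x + 1 is odd."],
        'formal_statement': ["theorem even_add_one_odd (x : nat) : true := trivial"],
    }
    out = preprocess_function_proofnet_simple(fake_batch, tokenizer, max_length=32)
    print(out['input_ids'].shape, out['labels'].shape)  # expect (1, 32) for both
    print((out['labels'] == -100).sum())  # positions masked out of the loss (padding, plus anything equal to the pad/eos id)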
But no matter what I do, e.g.:
- I've forced CPU in every way I can think of, just to get it to train at all.
- I'm using an HF dataset from the Hub that I've used before.
- I upgraded PyTorch:
      pip install --upgrade torch
- I disabled MPS and tried to make sure the CPU was used:
      training_args = TrainingArguments(
          …
          no_cuda=True,
          use_mps_device=True if torch.backends.mps.is_available() else False,
          …
      )
- I saw the suggestion "1. Verify data types: Ensure that your model and data are using compatible data types. MPS might have issues with certain data types.", but that shouldn't matter here, because the HF Trainer does this on its own by fetching the device from my model, and I've checked this code before.
- And yes, I did set
      device = torch.device("cpu")
but it doesn't work, and I get a very cryptic error I've never seen before that turns up nothing on Google:
Exception has occurred: AttributeError
'NoneType' object has no attribute 'cget_managed_ptr'
File "/Users/me/py_proj/py_src/train/hf_trainer_train.py", line 93, in <module>
trainer.train()
AttributeError: 'NoneType' object has no attribute 'cget_managed_ptr'
what is going on? How do I debug this?
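One isolation step I'm planning to try (just my guess at a useful way to narrow down which part of trainer.train() blows up) is to build the Trainer pieces one at a time instead of going straight into the training loop:

    # Isolation sketch (my own guess, not from any tutorial): construct the optimizer and pull one
    # batch outside trainer.train(), so the traceback points at a specific component.
    trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
    optimizer = trainer.create_optimizer()  # does optim="paged_adamw_32bit" even construct?
    print(type(optimizer))
    batch = next(iter(trainer.get_train_dataloader()))  # does the dataloader yield a plain CPU batch?
    print({k: v.shape for k, v in batch.items()})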
Possibly related: I also get this odd warning:
'NoneType' object has no attribute 'cadam32bit_grad_fp32'
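Those attribute names look like calls into a compiled extension, and the obvious candidate in my environment is bitsandbytes (see the pip list below), since as far as I understand optim="paged_adamw_32bit" is backed by it. I'm not certain that's the culprit, but a minimal check would be to import it directly and see what it reports:

    # Minimal check (my assumption: the NoneType is a missing native-library handle inside bitsandbytes).
    import bitsandbytes as bnb
    print(bnb.__version__)  # 0.42.0 in my env
    import bitsandbytes.functional as F  # this module calls into the compiled library
    print(getattr(F, "lib", "<no lib attribute>"))  # if this is None, lib.cget_managed_ptr would fail exactly like my traceback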
My conda env (locally; on the server I'm using a venv):
% pip list
Package Version Editable project location
----------------------- ----------- ------------------------------
absl-py 2.1.0
accelerate 0.32.1
my_proj 0.0.1 /Users/me/my_proj/py_src
aiohttp 3.9.5
aiosignal 1.3.1
alembic 1.13.2
annotated-types 0.7.0
anthropic 0.31.1
anthropic-bedrock 0.8.0
anyio 4.4.0
attrs 23.2.0
backoff 2.2.1
backports.tarfile 1.2.0
bitsandbytes 0.42.0
boto3 1.34.145
botocore 1.34.145
certifi 2024.7.4
charset-normalizer 3.3.2
click 8.1.7
colorlog 6.8.2
contourpy 1.2.1
cycler 0.12.1
datasets 2.20.0
dill 0.3.8
distro 1.9.0
docker-pycreds 0.4.0
docutils 0.21.2
dspy-ai 2.4.12
evaluate 0.4.2
filelock 3.15.4
fire 0.6.0
fonttools 4.53.1
frozenlist 1.4.1
fsspec 2024.5.0
gitdb 4.0.11
GitPython 3.1.43
greenlet 3.0.3
grpcio 1.64.1
h11 0.14.0
httpcore 1.0.5
httpx 0.27.0
huggingface-hub 0.23.4
idna 3.7
importlib_metadata 8.0.0
jaraco.classes 3.4.0
jaraco.context 5.3.0
jaraco.functools 4.0.1
Jinja2 3.1.4
jiter 0.5.0
jmespath 1.0.1
joblib 1.3.2
jsonlines 4.0.0
keyring 25.2.1
kiwisolver 1.4.5
lark-parser 0.12.0
Mako 1.3.5
Markdown 3.6
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib 3.9.1
mdurl 0.1.2
more-itertools 10.3.0
mpmath 1.3.0
multidict 6.0.5
multiprocess 0.70.16
networkx 3.3
nh3 0.2.18
nltk 3.8.1
numpy 1.26.4
nvidia-htop 1.2.0
openai 1.35.13
optuna 3.6.1
packaging 24.1
pandas 2.2.2
pillow 10.4.0
pip 24.0
pkginfo 1.10.0
platformdirs 4.2.2
plotly 5.22.0
protobuf 4.25.3
psutil 6.0.0
pyarrow 16.1.0
pyarrow-hotfix 0.6
pydantic 2.8.2
pydantic_core 2.20.1
Pygments 2.18.0
pyparsing 3.1.2
python-dateutil 2.9.0.post0
pytz 2024.1
PyYAML 6.0.1
readme_renderer 44.0
regex 2024.5.15
requests 2.32.3
requests-toolbelt 1.0.0
rfc3986 2.0.0
rich 13.7.1
s3transfer 0.10.2
safetensors 0.4.3
scikit-learn 1.5.1
scipy 1.14.0
seaborn 0.13.2
sentry-sdk 2.10.0
setproctitle 1.3.3
setuptools 69.5.1
six 1.16.0
smmap 5.0.1
sniffio 1.3.1
SQLAlchemy 2.0.31
structlog 24.2.0
sympy 1.13.0
tenacity 8.5.0
tensorboard 2.17.0
tensorboard-data-server 0.7.2
termcolor 2.4.0
threadpoolctl 3.5.0
tokenizers 0.19.1
torch 2.2.2
tqdm 4.66.4
transformers 4.42.4
twine 5.1.1
typing_extensions 4.12.2
tzdata 2024.1
ujson 5.10.0
urllib3 2.2.2
wandb 0.17.4
Werkzeug 3.0.3
wheel 0.43.0
xxhash 3.4.1
yarl 1.9.4
zipp 3.19.2