Error loading timesfm: it appears to be looking for torch_model.ckpt when I am loading the JAX model

Error loading the timesfm model from timesfm/notebooks/covariates.ipynb at master · google-research/timesfm · GitHub

import timesfm

model = timesfm.TimesFm(
    hparams=timesfm.TimesFmHparams(
        backend="cpu",
        per_core_batch_size=32,
        horizon_len=128,
        num_layers=50,
        #use_positional_embedding=False,
        context_len=2048,
    ),
    checkpoint=timesfm.TimesFmCheckpoint(
        huggingface_repo_id="google/timesfm-2.0-500m-jax"),
)

TimesFM v1.2.0. See timesfm/README.md at master · google-research/timesfm · GitHub for updated APIs.
Loaded Jax TimesFM.
Loaded PyTorch TimesFM.
README.md: 7.34kB [00:00, 10.1MB/s] | 0/6 [00:00<?, ?it/s]
descriptor.pbtxt: 100%|█████████████████████████████████████████████████████████████████████████████████████| 490/490 [00:00<00:00, 966kB/s]
_CHECKPOINT_METADATA: 100%|███████████████████████████████████████████████████████████████████████████████| 92.0/92.0 [00:00<00:00, 465kB/s]
metadata: 107kB [00:00, 150MB/s] | 0.00/92.0 [00:00<?, ?B/s]
.gitattributes: 1.69kB [00:00, 5.69MB/s]
checkpoint: 100%|██████████████████████████████████████████████████████████████████████████████████████| 2.00G/2.00G [01:46<00:00, 18.7MB/s]
Fetching 6 files: 100%|███████████████████████████████████████████████████████████████████████████████████████| 6/6 [01:46<00:00, 17.79s/it]
Traceback (most recent call last):
  File "/Users/rc/Downloads/price_model/run_timesfm.py", line 6, in <module>
    model = timesfm.TimesFm(
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/timesfm/timesfm_base.py", line 180, in __init__
    self.load_from_checkpoint(checkpoint)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/timesfm/timesfm_torch.py", line 61, in load_from_checkpoint
    loaded_checkpoint = torch.load(checkpoint_path, weights_only=True)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/serialization.py", line 998, in load
    with _open_file_like(f, 'rb') as opened_file:
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/serialization.py", line 445, in _open_file_like
    return _open_file(name_or_buffer, mode)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/serialization.py", line 426, in __init__
    super().__init__(open(name, mode))
FileNotFoundError: [Errno 2] No such file or directory: '/Users/rc/.cache/huggingface/hub/models--google--timesfm-2.0-500m-jax/snapshots/47dedfcadf2abace1cc96071ddb798cfcd3bfcef/torch_model.ckpt'
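A quick way to confirm the mismatch is to look at what actually landed in the snapshot directory. The download log above lists README.md, descriptor.pbtxt, _CHECKPOINT_METADATA, metadata, .gitattributes and checkpoint, but no torch_model.ckpt, which is the file the PyTorch loader tries to open. Below is a minimal, stdlib-only sketch of that check; the helper names `has_torch_checkpoint` and `list_snapshot_files` are hypothetical and not part of timesfm or huggingface_hub.

```python
from pathlib import Path
from typing import List, Union

# File name the PyTorch loader tries to open, taken from the FileNotFoundError above.
TORCH_CKPT_NAME = "torch_model.ckpt"


def has_torch_checkpoint(snapshot_dir: Union[str, Path]) -> bool:
    """Return True if the snapshot directory contains torch_model.ckpt at its top level."""
    # is_file() follows symlinks, which matters because Hugging Face cache
    # snapshots usually link to files stored elsewhere in the cache.
    return (Path(snapshot_dir) / TORCH_CKPT_NAME).is_file()


def list_snapshot_files(snapshot_dir: Union[str, Path]) -> List[str]:
    """List every file under the snapshot directory, relative to it, sorted."""
    root = Path(snapshot_dir)
    return sorted(str(p.relative_to(root)) for p in root.rglob("*") if p.is_file())
```

Called on the snapshot path from the FileNotFoundError, `has_torch_checkpoint` would be expected to return False, and `list_snapshot_files` shows which files the JAX repo provides instead.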


If the JAX-related libraries are not installed correctly, the import will fail and TimesFM will fall back to the PyTorch version by default.

I think it’s probably easier to use the PyTorch version, i.e. the google/timesfm-2.0-500m-pytorch checkpoint…
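One way to avoid this kind of mismatch is to choose the checkpoint from whichever class `timesfm.TimesFm` actually resolved to, since the traceback shows the call going through `timesfm/timesfm_torch.py`. The sketch below assumes the PyTorch class is defined in a module named `timesfm.timesfm_torch` and the JAX class in `timesfm.timesfm_jax`. The helper `checkpoint_repo_for` is hypothetical; google/timesfm-2.0-500m-pytorch is the PyTorch counterpart of the JAX repo used in the question.

```python
# Hypothetical helper, not part of timesfm: map the backend class that
# timesfm.TimesFm resolved to onto the matching 2.0-500m checkpoint repo.
REPO_BY_MODULE = {
    "timesfm_torch": "google/timesfm-2.0-500m-pytorch",  # assumed PyTorch module name
    "timesfm_jax": "google/timesfm-2.0-500m-jax",        # assumed JAX module name
}


def checkpoint_repo_for(model_cls: type) -> str:
    """Return the Hugging Face repo ID whose weights match model_cls's backend."""
    # Last component of the defining module, e.g. "timesfm.timesfm_torch" -> "timesfm_torch".
    module_name = model_cls.__module__.rsplit(".", 1)[-1]
    try:
        return REPO_BY_MODULE[module_name]
    except KeyError:
        raise ValueError(
            f"Unrecognized TimesFM backend module: {model_cls.__module__!r}"
        ) from None
```

Assuming `timesfm.TimesFm` is the backend class itself, it could be passed as `huggingface_repo_id=checkpoint_repo_for(timesfm.TimesFm)` in the `TimesFmCheckpoint(...)` call from the question. Raising on an unrecognized module is deliberate: silently defaulting to one backend would reproduce exactly this kind of error.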

Thank you. I see it’s loading TimesFmTorch, so I used the PyTorch checkpoint as you suggested, and it worked.
