I am trying to create a PretrainedConfig, which uses another config as sub-config. However, the resulting composite config cannot be saved via save_pretrained method and thus causes the Trainer to crash. Anyone can tell me what i did wrong? Sample code for quick replication:
from transformers import PretrainedConfig
class ConfigA(PretrainedConfig):
def __init__(self, a=10):
super().__init__()
self.a = a
class ConfigB(PretrainedConfig):
def __init__(self, b=1):
super().__init__()
self.b = b
class CompositeConfig(PretrainedConfig):
is_composition = True
def __init__(self, a: ConfigA, b: ConfigB, c=1):
super().__init__()
self.a = a
self.b = b
self.c = c
a = ConfigA()
b = ConfigB()
c = CompositeConfig(a, b)
a.save_pretrained("test/a")
b.save_pretrained("test/b")
c.save_pretrained("test/c")
output:
Traceback (most recent call last):
File "/Users/andre/PycharmProjects/aho-master-ocr/tests/configs.py", line 32, in <module>
c.save_pretrained("test/c")
File "/Users/andre/opt/anaconda3/envs/easyocr/lib/python3.8/site-packages/transformers/configuration_utils.py", line 457, in save_pretrained
self.to_json_file(output_config_file, use_diff=True)
File "/Users/andre/opt/anaconda3/envs/easyocr/lib/python3.8/site-packages/transformers/configuration_utils.py", line 900, in to_json_file
writer.write(self.to_json_string(use_diff=use_diff))
File "/Users/andre/opt/anaconda3/envs/easyocr/lib/python3.8/site-packages/transformers/configuration_utils.py", line 886, in to_json_string
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
File "/Users/andre/opt/anaconda3/envs/easyocr/lib/python3.8/json/__init__.py", line 234, in dumps
return cls(
File "/Users/andre/opt/anaconda3/envs/easyocr/lib/python3.8/json/encoder.py", line 201, in encode
chunks = list(chunks)
File "/Users/andre/opt/anaconda3/envs/easyocr/lib/python3.8/json/encoder.py", line 431, in _iterencode
yield from _iterencode_dict(o, _current_indent_level)
File "/Users/andre/opt/anaconda3/envs/easyocr/lib/python3.8/json/encoder.py", line 405, in _iterencode_dict
yield from chunks
File "/Users/andre/opt/anaconda3/envs/easyocr/lib/python3.8/json/encoder.py", line 438, in _iterencode
o = _default(o)
File "/Users/andre/opt/anaconda3/envs/easyocr/lib/python3.8/json/encoder.py", line 179, in default
raise TypeError(f'Object of type {o.__class__.__name__} '
TypeError: Object of type ConfigA is not JSON serializable