2020)
2121from fairseq2 .file_system import FileMode , FileSystem
2222from fairseq2 .gang import GangError , Gangs
23- from fairseq2 .models .llama import LLAMA_MODEL_FAMILY , LLaMAConfig
24- from fairseq2 .models .llama .integ import convert_to_hg_llama_config
2523from fairseq2 .utils .structured import unstructure
2624from fairseq2 .utils .yaml import YamlDumper
2725
2826
2927class CheckpointMetadataSaver (ABC ):
3028 @abstractmethod
31- def save (self , model_family : str , model_config : object ) -> None : ...
29+ def save (
30+ self , model_family : str , model_config : object , hg_model_config : object = None
31+ ) -> None : ...
3232
3333
3434@final
@@ -50,39 +50,14 @@ def __init__(
5050 self ._file_system = file_system
5151 self ._yaml_dumper = yaml_dumper
5252
53- def save (self , model_family : str , model_config : object ) -> None :
53+ def save (
54+ self , model_family : str , model_config : object , hg_model_config : object = None
55+ ) -> None :
5456 if self ._gangs .root .rank == 0 :
55- unstructured_config = unstructure (model_config )
56-
57- metadata : dict [str , object ] = {
58- "name" : "checkpoint" ,
59- "model_family" : model_family ,
60- "model_config" : {
61- "_set_" : unstructured_config ,
62- },
63- }
57+ self ._save_asset_card (model_family , model_config )
6458
65- if self ._gangs .tp .size != 1 :
66- metadata ["num_shards" ] = self ._gangs .tp .size
67-
68- metadata_file = self ._checkpoint_dir .joinpath ("model.yaml" )
69-
70- def save_error () -> AssetMetadataSaveError :
71- return AssetMetadataSaveError (
72- f"The checkpoint metadata cannot be saved to the '{ metadata_file } ' file. See the nested exception for details."
73- )
74-
75- try :
76- self ._file_system .make_directory (metadata_file .parent )
77- except OSError as ex :
78- raise save_error () from ex
79-
80- try :
81- self ._yaml_dumper .dump (metadata , metadata_file )
82- except OSError as ex :
83- raise save_error () from ex
84-
85- self ._save_huggingface_config (model_family , model_config )
59+ if hg_model_config is not None :
60+ self ._save_hg_config (hg_model_config )
8661
8762 try :
8863 self ._gangs .root .barrier ()
@@ -91,17 +66,38 @@ def save_error() -> AssetMetadataSaveError:
9166 "The collective barrier after the checkpoint metadata save operation has failed. See the nested exception for details."
9267 ) from ex
9368
94- def _save_huggingface_config (self , model_family : str , model_config : object ) -> None :
95- if model_family != LLAMA_MODEL_FAMILY :
96- return
69+ def _save_asset_card (self , model_family : str , model_config : object ) -> None :
70+ unstructured_model_config = unstructure (model_config )
71+
72+ metadata : dict [str , object ] = {
73+ "name" : "checkpoint" ,
74+ "model_family" : model_family ,
75+ "model_config" : {
76+ "_set_" : unstructured_model_config ,
77+ },
78+ }
9779
98- if not isinstance (model_config , LLaMAConfig ):
99- raise TypeError (
100- f"`model_config` must be of type `{ LLaMAConfig } `, but is of type `{ type (model_config )} ` instead."
80+ if self ._gangs .tp .size != 1 :
81+ metadata ["num_shards" ] = self ._gangs .tp .size
82+
83+ metadata_file = self ._checkpoint_dir .joinpath ("model.yaml" )
84+
85+ def save_error () -> AssetMetadataSaveError :
86+ return AssetMetadataSaveError (
87+ f"The checkpoint metadata cannot be saved to the '{ metadata_file } ' file. See the nested exception for details."
10188 )
10289
103- hg_config = convert_to_hg_llama_config (model_config )
90+ try :
91+ self ._file_system .make_directory (metadata_file .parent )
92+ except OSError as ex :
93+ raise save_error () from ex
94+
95+ try :
96+ self ._yaml_dumper .dump (metadata , metadata_file )
97+ except OSError as ex :
98+ raise save_error () from ex
10499
100+ def _save_hg_config (self , hg_model_config : object ) -> None :
105101 hg_config_file = self ._checkpoint_dir .joinpath ("cc/config.json" )
106102
107103 def save_error () -> AssetMetadataSaveError :
@@ -120,7 +116,7 @@ def save_error() -> AssetMetadataSaveError:
120116 raise save_error () from ex
121117
122118 try :
123- json .dump (hg_config , fp , indent = 2 , sort_keys = True )
119+ json .dump (hg_model_config , fp , indent = 2 , sort_keys = True )
124120 except OSError as ex :
125121 raise save_error () from ex
126122 finally :
0 commit comments