
Commit effb533

Improve Hugging Face support

1 parent b91190f commit effb533

17 files changed: 267 additions & 270 deletions


src/fairseq2/checkpoint/_metadata_provider.py

Lines changed: 38 additions & 42 deletions
@@ -20,15 +20,15 @@
 )
 from fairseq2.file_system import FileMode, FileSystem
 from fairseq2.gang import GangError, Gangs
-from fairseq2.models.llama import LLAMA_MODEL_FAMILY, LLaMAConfig
-from fairseq2.models.llama.integ import convert_to_hg_llama_config
 from fairseq2.utils.structured import unstructure
 from fairseq2.utils.yaml import YamlDumper


 class CheckpointMetadataSaver(ABC):
     @abstractmethod
-    def save(self, model_family: str, model_config: object) -> None: ...
+    def save(
+        self, model_family: str, model_config: object, hg_model_config: object = None
+    ) -> None: ...


 @final
@@ -50,39 +50,14 @@ def __init__(
         self._file_system = file_system
         self._yaml_dumper = yaml_dumper

-    def save(self, model_family: str, model_config: object) -> None:
+    def save(
+        self, model_family: str, model_config: object, hg_model_config: object = None
+    ) -> None:
         if self._gangs.root.rank == 0:
-            unstructured_config = unstructure(model_config)
-
-            metadata: dict[str, object] = {
-                "name": "checkpoint",
-                "model_family": model_family,
-                "model_config": {
-                    "_set_": unstructured_config,
-                },
-            }
+            self._save_asset_card(model_family, model_config)

-            if self._gangs.tp.size != 1:
-                metadata["num_shards"] = self._gangs.tp.size
-
-            metadata_file = self._checkpoint_dir.joinpath("model.yaml")
-
-            def save_error() -> AssetMetadataSaveError:
-                return AssetMetadataSaveError(
-                    f"The checkpoint metadata cannot be saved to the '{metadata_file}' file. See the nested exception for details."
-                )
-
-            try:
-                self._file_system.make_directory(metadata_file.parent)
-            except OSError as ex:
-                raise save_error() from ex
-
-            try:
-                self._yaml_dumper.dump(metadata, metadata_file)
-            except OSError as ex:
-                raise save_error() from ex
-
-            self._save_huggingface_config(model_family, model_config)
+            if hg_model_config is not None:
+                self._save_hg_config(hg_model_config)

         try:
             self._gangs.root.barrier()
@@ -91,17 +66,38 @@ def save_error() -> AssetMetadataSaveError:
                 "The collective barrier after the checkpoint metadata save operation has failed. See the nested exception for details."
             ) from ex

-    def _save_huggingface_config(self, model_family: str, model_config: object) -> None:
-        if model_family != LLAMA_MODEL_FAMILY:
-            return
+    def _save_asset_card(self, model_family: str, model_config: object) -> None:
+        unstructured_model_config = unstructure(model_config)
+
+        metadata: dict[str, object] = {
+            "name": "checkpoint",
+            "model_family": model_family,
+            "model_config": {
+                "_set_": unstructured_model_config,
+            },
+        }

-        if not isinstance(model_config, LLaMAConfig):
-            raise TypeError(
-                f"`model_config` must be of type `{LLaMAConfig}`, but is of type `{type(model_config)}` instead."
+        if self._gangs.tp.size != 1:
+            metadata["num_shards"] = self._gangs.tp.size
+
+        metadata_file = self._checkpoint_dir.joinpath("model.yaml")
+
+        def save_error() -> AssetMetadataSaveError:
+            return AssetMetadataSaveError(
+                f"The checkpoint metadata cannot be saved to the '{metadata_file}' file. See the nested exception for details."
             )

-        hg_config = convert_to_hg_llama_config(model_config)
+        try:
+            self._file_system.make_directory(metadata_file.parent)
+        except OSError as ex:
+            raise save_error() from ex
+
+        try:
+            self._yaml_dumper.dump(metadata, metadata_file)
+        except OSError as ex:
+            raise save_error() from ex

+    def _save_hg_config(self, hg_model_config: object) -> None:
         hg_config_file = self._checkpoint_dir.joinpath("cc/config.json")

         def save_error() -> AssetMetadataSaveError:
@@ -120,7 +116,7 @@ def save_error() -> AssetMetadataSaveError:
             raise save_error() from ex

         try:
-            json.dump(hg_config, fp, indent=2, sort_keys=True)
+            json.dump(hg_model_config, fp, indent=2, sort_keys=True)
         except OSError as ex:
             raise save_error() from ex
         finally:
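
Note: after this change, save() no longer derives the Hugging Face config itself. The asset-card logic moves into _save_asset_card(), and the caller passes an already-converted hg_model_config, which is written out only when present. A toy stand-in illustrating the new contract (illustrative only; the real saver writes model.yaml and the HF config file through FileSystem/YamlDumper and synchronizes gangs, and all values below are placeholders):

from abc import ABC, abstractmethod


class CheckpointMetadataSaver(ABC):
    @abstractmethod
    def save(
        self, model_family: str, model_config: object, hg_model_config: object = None
    ) -> None: ...


class PrintingMetadataSaver(CheckpointMetadataSaver):
    # Toy implementation: prints instead of writing model.yaml / config.json.
    def save(
        self, model_family: str, model_config: object, hg_model_config: object = None
    ) -> None:
        print(f"asset card: family={model_family} config={model_config}")
        if hg_model_config is not None:
            print(f"HF config: {hg_model_config}")


# Placeholder configs, just to show the call shape.
PrintingMetadataSaver().save(
    "llama", {"model_dim": 4096}, hg_model_config={"model_type": "llama"}
)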

src/fairseq2/cli/_setup.py

Lines changed: 1 addition & 10 deletions
@@ -9,10 +9,7 @@
 from fairseq2.chatbots import UnknownChatbotError
 from fairseq2.cli.commands.assets import ListAssetsHandler, ShowAssetHandler
 from fairseq2.cli.commands.chatbot import RunChatbotHandler
-from fairseq2.cli.commands.llama import (
-    ConvertLLaMACheckpointHandler,
-    WriteHFLLaMAConfigHandler,
-)
+from fairseq2.cli.commands.llama import ConvertLLaMACheckpointHandler
 from fairseq2.cli.commands.recipe import RecipeCommandHandler
 from fairseq2.context import RuntimeContext
 from fairseq2.data.text.tokenizers import (
@@ -179,12 +176,6 @@ def _register_llama_cli(cli: Cli) -> None:
         help="convert fairseq2 LLaMA checkpoints to reference checkpoints",
     )

-    group.add_command(
-        name="write_hf_config",
-        handler=WriteHFLLaMAConfigHandler(),
-        help="write fairseq2 LLaMA configurations in Hugging Face format",
-    )
-

 def _register_lm_cli(cli: Cli) -> None:
     group = cli.add_group("lm", help="language model recipes")
src/fairseq2/cli/commands/llama/__init__.py

Lines changed: 0 additions & 3 deletions
@@ -9,6 +9,3 @@
 from fairseq2.cli.commands.llama._convert_checkpoint import (
     ConvertLLaMACheckpointHandler as ConvertLLaMACheckpointHandler,
 )
-from fairseq2.cli.commands.llama._write_hf_config import (
-    WriteHFLLaMAConfigHandler as WriteHFLLaMAConfigHandler,
-)

src/fairseq2/cli/commands/llama/_convert_checkpoint.py

Lines changed: 1 addition & 1 deletion
@@ -232,7 +232,7 @@ def file_write_error() -> CliCommandError:
         "dim": model_config.model_dim,
         "n_layers": model_config.num_layers,
         "n_heads": model_config.num_attn_heads,
-        "multiple_of": model_config.ffn_inner_dim_to_multiple,
+        "multiple_of": model_config.ffn_inner_dim_multiple_of,
         "rope_theta": model_config.rope_theta,
         "norm_eps": 1e-5,
     }
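
This hunk only renames the config attribute (ffn_inner_dim_to_multiple becomes ffn_inner_dim_multiple_of); the emitted dict is unchanged. For orientation, a sketch of the JSON this dict serializes to, with placeholder values standing in for a real LLaMAConfig (the output file name is not shown in this excerpt):

import json

# Placeholder values; a real LLaMAConfig supplies these fields.
params = {
    "dim": 4096,             # model_config.model_dim
    "n_layers": 32,          # model_config.num_layers
    "n_heads": 32,           # model_config.num_attn_heads
    "multiple_of": 256,      # model_config.ffn_inner_dim_multiple_of (renamed)
    "rope_theta": 500000.0,  # model_config.rope_theta
    "norm_eps": 1e-5,
}

print(json.dumps(params, indent=2))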

src/fairseq2/cli/commands/llama/_write_hf_config.py

Lines changed: 0 additions & 122 deletions
This file was deleted.

src/fairseq2/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
 from fairseq2.models._handler import CheckpointConverter as CheckpointConverter
 from fairseq2.models._handler import DelegatingModelHandler as DelegatingModelHandler
 from fairseq2.models._handler import FsdpApplier as FsdpApplier
+from fairseq2.models._handler import HGConfigConverter as HGConfigConverter
 from fairseq2.models._handler import ModelCompiler as ModelCompiler
 from fairseq2.models._handler import ModelFactory as ModelFactory
 from fairseq2.models._handler import ModelHandler as ModelHandler
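
The definition of HGConfigConverter is not part of this excerpt; judging by the name and by save() now accepting a pre-built hg_model_config, it is presumably the hook a model family registers to turn its fairseq2 config into a Hugging Face config. A speculative sketch of such a protocol, not the actual fairseq2.models._handler definition:

from typing import Protocol


class HGConfigConverter(Protocol):
    # Speculative shape, inferred from the surrounding commit; the real
    # definition lives in fairseq2.models._handler and may differ.
    def __call__(self, model_config: object) -> object:
        """Convert a fairseq2 model config into a Hugging Face config."""
        ...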
