 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import AutoConfig, AutoModel, AutoModelForSequenceClassification, PreTrainedModel
+from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
 from transformers.utils import logging
 
 from nemo_automodel._transformers.registry import ModelRegistry
 
 logger = logging.get_logger(__name__)
 
 
+def _extract_submodel(model: nn.Module, extract_submodel: str) -> PreTrainedModel:
+    """Extract a nested submodel from a loaded model using a dotted attribute path."""
+    extracted_model = model
+    for attr in extract_submodel.split("."):
+        extracted_model = getattr(extracted_model, attr)
+    if not hasattr(extracted_model, "config"):
+        raise ValueError(
+            f"Extracted submodel at '{extract_submodel}' has no .config attribute. "
+            f"The submodel must be a PreTrainedModel for save/reload to work. "
+            f"Got {type(extracted_model).__name__}."
+        )
+    return extracted_model
+
+
+def _get_supported_backbone_class(model_type: str, task: str) -> type[nn.Module] | None:
+    """Return the registered retrieval backbone class for a model type and task."""
+    task_map = SUPPORTED_BACKBONES.get(model_type.lower())
+    if task_map is None:
+        return None
+
+    arch_name = task_map.get(task)
+    if arch_name is None:
+        raise ValueError(
+            f"Unsupported task '{task}' for model type '{model_type}'. Available tasks: {', '.join(task_map)}."
+        )
+
+    if arch_name not in ModelRegistry.model_arch_name_to_cls:
+        raise ValueError(f"Model class '{arch_name}' not found in ModelRegistry.")
+
+    logger.info(f"Using {arch_name} from registry")
+    return ModelRegistry.model_arch_name_to_cls[arch_name]
+
+
+def _move_to_extracted_dtype(model: nn.Module, extracted_model: nn.Module) -> nn.Module:
+    """Move a newly built model to the dtype used by the extracted model."""
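+    # Returning on the first loop iteration picks up the dtype of the first
+    # parameter; fall back to the first buffer when there are no parameters.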
+    for parameter in extracted_model.parameters():
+        return model.to(dtype=parameter.dtype)
+    for buffer in extracted_model.buffers():
+        return model.to(dtype=buffer.dtype)
+    return model
+
+
+def _load_from_extracted_state(
+    backbone_class: type[PreTrainedModel],
+    config,
+    extracted_model: PreTrainedModel,
+) -> PreTrainedModel:
+    """Load a target backbone from an extracted model's in-memory state dict."""
+    # Use the base HF loader because some retrieval classes override
+    # from_pretrained for path-based checkpoint loading.
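+    # The checkpoint path is None: weights come from the in-memory state_dict.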
+    backbone = PreTrainedModel.from_pretrained.__func__(
+        backbone_class,
+        None,
+        config=config,
+        state_dict=extracted_model.state_dict(),
+    )
+    return _move_to_extracted_dtype(backbone, extracted_model)
+
+
+def _build_backbone_from_extracted_submodel(
+    extracted_model: PreTrainedModel,
+    task: str,
+    pooling: Optional[str],
+    num_labels: Optional[int],
+    temperature: Optional[float],
+) -> PreTrainedModel:
+    """Build a task-specific retrieval backbone from an extracted text submodel."""
+    text_config = extracted_model.config
+    model_type = getattr(text_config, "model_type", "")
+    task_map = SUPPORTED_BACKBONES.get(model_type.lower())
+    has_supported_target = task_map is not None and task in task_map
+
+    if task_map is not None and not has_supported_target and task != "score":
+        raise ValueError(
+            f"Unsupported task '{task}' for model type '{model_type}'. Available tasks: {', '.join(task_map)}."
+        )
+
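+    # A "score" request with no registered backbone falls back to the stock
+    # HF sequence-classification head for this model type.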
+    if task == "score" and not has_supported_target:
+        config = text_config.__class__.from_dict(text_config.to_dict())
+        try:
+            backbone_class = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[type(config)]
+        except KeyError as exc:
+            raise ValueError(f"No HuggingFace sequence-classification model found for '{model_type}'.") from exc
+    elif not has_supported_target:
+        return extracted_model
+    else:
+        backbone_class = _get_supported_backbone_class(model_type, task)
+        config_class = getattr(backbone_class, "config_class", None)
+        if config_class is None or not hasattr(text_config, "to_dict"):
+            return extracted_model
+
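+        # Rebuild the config under the registry class; drop model_type, which
+        # the target config class defines itself.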
+        config_dict = text_config.to_dict()
+        config_dict.pop("model_type", None)
+        config = config_class(**config_dict)
+
+    attn_implementation = getattr(text_config, "_attn_implementation", None)
+    if attn_implementation is not None:
+        config._attn_implementation = attn_implementation
+    if has_supported_target and pooling is not None:
+        config.pooling = pooling
+    if num_labels is not None:
+        config.num_labels = num_labels
+    if has_supported_target and temperature is not None:
+        config.temperature = temperature
+
+    return _load_from_extracted_state(backbone_class, config, extracted_model)
+
+
 def pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool_type: str) -> torch.Tensor:
     """
     Pool hidden states using the specified pooling method.
@@ -100,15 +209,25 @@ def build_encoder_backbone(
     task: str,
     trust_remote_code: bool = False,
     pooling: Optional[str] = None,
+    extract_submodel: Optional[str] = None,
+    num_labels: Optional[int] = None,
+    temperature: Optional[float] = None,
     **hf_kwargs,
 ) -> PreTrainedModel:
     """Build an encoder backbone from a pretrained checkpoint.
 
-    For model types listed in :data:`SUPPORTED_BACKBONES`, resolves the
-    custom bidirectional architecture class from :class:`ModelRegistry`.
-    For all other model types, falls back to
-    ``AutoModel.from_pretrained`` (or ``AutoModelForSequenceClassification``
-    for the ``"score"`` task).
+    When ``extract_submodel`` is set, loads the parent model with HuggingFace
+    Auto classes and extracts the submodel at that dotted attribute path. For
+    supported extracted text backbones, it then builds the registered
+    retrieval class for the requested task (bidirectional base model for
+    ``"embedding"``, sequence-classification wrapper for ``"score"``). For
+    unsupported extracted text backbones, it returns the extracted model for
+    ``"embedding"`` and wraps it with ``AutoModelForSequenceClassification``
+    for ``"score"``.
+
+    Without ``extract_submodel``, model types listed in
+    :data:`SUPPORTED_BACKBONES` resolve to custom bidirectional classes from
+    :class:`ModelRegistry`; all other model types fall back to HuggingFace Auto
+    classes.
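+
+    A minimal illustrative call (checkpoint name hypothetical) that pulls the
+    text backbone out of a VLM::
+
+        build_encoder_backbone("org/some-vlm", task="embedding", extract_submodel="language_model")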
 
     Args:
         model_name_or_path: Path or HuggingFace Hub identifier.
@@ -117,6 +236,10 @@ def build_encoder_backbone(
         pooling: Bi-encoder pooling strategy for registry backbones (e.g. Llama bidirectional)
             that accept it on ``from_pretrained``. Must not be forwarded to standard HF models
             (e.g. Qwen3) loaded via ``AutoModel``; those only receive ``**hf_kwargs``.
+        extract_submodel: Dotted attribute path to extract from the loaded model
+            (e.g. ``"language_model"`` to extract the text backbone from a VLM).
+        num_labels: Number of labels for reranking/classification backbones.
+        temperature: Optional retrieval score temperature for custom retrieval backbones.
         **hf_kwargs: Extra keyword arguments forwarded to ``from_pretrained``.
 
     Returns:
@@ -129,29 +252,34 @@ def build_encoder_backbone(
     config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
     model_type = getattr(config, "model_type", "")
 
-    task_map = SUPPORTED_BACKBONES.get(model_type.lower())
-
-    if task_map is not None:
-        arch_name = task_map.get(task)
-        if arch_name is None:
-            raise ValueError(
-                f"Unsupported task '{task}' for model type '{model_type}'. Available tasks: {', '.join(task_map)}."
-            )
-
-        if arch_name not in ModelRegistry.model_arch_name_to_cls:
-            raise ValueError(f"Model class '{arch_name}' not found in ModelRegistry.")
-
-        BidirectionalModelClass = ModelRegistry.model_arch_name_to_cls[arch_name]
-        logger.info(f"Using {arch_name} from registry")
+    if extract_submodel is not None:
+        logger.info(f"Loading {model_name_or_path} with HuggingFace Auto classes to extract {extract_submodel}")
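+        # Load the full parent model (e.g. a VLM) first, then pull out the
+        # nested text submodel by attribute path.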
+        model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **hf_kwargs)
+        extracted_model = _extract_submodel(model, extract_submodel)
+        return _build_backbone_from_extracted_submodel(
+            extracted_model,
+            task=task,
+            pooling=pooling,
+            num_labels=num_labels,
+            temperature=temperature,
+        )
 
+    BidirectionalModelClass = _get_supported_backbone_class(model_type, task)
+    if BidirectionalModelClass is not None:
         if pooling is not None:
             hf_kwargs["pooling"] = pooling
+        if num_labels is not None:
+            hf_kwargs["num_labels"] = num_labels
+        if temperature is not None:
+            hf_kwargs["temperature"] = temperature
         return BidirectionalModelClass.from_pretrained(
             model_name_or_path, trust_remote_code=trust_remote_code, **hf_kwargs
         )
 
     # Fallback: use HuggingFace Auto classes for model types not in SUPPORTED_BACKBONES
     logger.info(f"Model type '{model_type}' not in SUPPORTED_BACKBONES; falling back to HuggingFace Auto classes")
+    if task == "score" and num_labels is not None:
+        hf_kwargs["num_labels"] = num_labels
     if task == "score":
         return AutoModelForSequenceClassification.from_pretrained(
             model_name_or_path, trust_remote_code=trust_remote_code, **hf_kwargs
@@ -205,8 +333,6 @@ def save_encoder_pretrained(model: nn.Module, save_directory: str, **kwargs) ->
     "llama_bidirec": _LLAMA_TASKS,
     "ministral3": _MINISTRAL3_BIDIREC_TASKS,
     "ministral3_bidirec": _MINISTRAL3_BIDIREC_TASKS,
-    # Mistral3-VL Hub configs use top-level model_type "mistral3" (language is nested under text_config).
-    "mistral3": _MINISTRAL3_BIDIREC_TASKS,
 }
 
 