
Commit 3b1a513

Merge branch 'main' into khazic/fix/pp-vlm-preserve-forward
2 parents 1a09f17 + 174ba8d commit 3b1a513

17 files changed: 840 additions & 123 deletions


nemo_automodel/_transformers/retrieval.py

Lines changed: 147 additions & 21 deletions
@@ -22,6 +22,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import AutoConfig, AutoModel, AutoModelForSequenceClassification, PreTrainedModel
+from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
 from transformers.utils import logging
 
 from nemo_automodel._transformers.registry import ModelRegistry
@@ -30,6 +31,114 @@
 logger = logging.get_logger(__name__)
 
 
+def _extract_submodel(model: nn.Module, extract_submodel: str) -> PreTrainedModel:
+    """Extract a nested submodel from a loaded model using a dotted attribute path."""
+    extracted_model = model
+    for attr in extract_submodel.split("."):
+        extracted_model = getattr(extracted_model, attr)
+    if not hasattr(extracted_model, "config"):
+        raise ValueError(
+            f"Extracted submodel at '{extract_submodel}' has no .config attribute. "
+            f"The submodel must be a PreTrainedModel for save/reload to work. "
+            f"Got {type(extracted_model).__name__}."
+        )
+    return extracted_model
+
+
+def _get_supported_backbone_class(model_type: str, task: str) -> type[nn.Module] | None:
+    """Return the registered retrieval backbone class for a model type and task."""
+    task_map = SUPPORTED_BACKBONES.get(model_type.lower())
+    if task_map is None:
+        return None
+
+    arch_name = task_map.get(task)
+    if arch_name is None:
+        raise ValueError(
+            f"Unsupported task '{task}' for model type '{model_type}'. Available tasks: {', '.join(task_map)}."
+        )
+
+    if arch_name not in ModelRegistry.model_arch_name_to_cls:
+        raise ValueError(f"Model class '{arch_name}' not found in ModelRegistry.")
+
+    logger.info(f"Using {arch_name} from registry")
+    return ModelRegistry.model_arch_name_to_cls[arch_name]
+
+
+def _move_to_extracted_dtype(model: nn.Module, extracted_model: nn.Module) -> nn.Module:
+    """Move a newly-built model to the dtype used by the extracted model."""
+    for parameter in extracted_model.parameters():
+        return model.to(dtype=parameter.dtype)
+    for buffer in extracted_model.buffers():
+        return model.to(dtype=buffer.dtype)
+    return model
+
+
+def _load_from_extracted_state(
+    backbone_class: type[PreTrainedModel],
+    config,
+    extracted_model: PreTrainedModel,
+) -> PreTrainedModel:
+    """Load a target backbone from an extracted model's in-memory state dict."""
+    # Use the base HF loader because some retrieval classes override
+    # from_pretrained for path-based checkpoint loading.
+    backbone = PreTrainedModel.from_pretrained.__func__(
+        backbone_class,
+        None,
+        config=config,
+        state_dict=extracted_model.state_dict(),
+    )
+    return _move_to_extracted_dtype(backbone, extracted_model)
+
+
+def _build_backbone_from_extracted_submodel(
+    extracted_model: PreTrainedModel,
+    task: str,
+    pooling: Optional[str],
+    num_labels: Optional[int],
+    temperature: Optional[float],
+) -> PreTrainedModel:
+    """Build a task-specific retrieval backbone from an extracted text submodel."""
+    text_config = extracted_model.config
+    model_type = getattr(text_config, "model_type", "")
+    task_map = SUPPORTED_BACKBONES.get(model_type.lower())
+    has_supported_target = task_map is not None and task in task_map
+
+    if task_map is not None and not has_supported_target and task != "score":
+        raise ValueError(
+            f"Unsupported task '{task}' for model type '{model_type}'. Available tasks: {', '.join(task_map)}."
+        )
+
+    if task == "score" and not has_supported_target:
+        config = text_config.__class__.from_dict(text_config.to_dict())
+        try:
+            backbone_class = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[type(config)]
+        except KeyError as exc:
+            raise ValueError(f"No HuggingFace sequence-classification model found for '{model_type}'.") from exc
+    elif not has_supported_target:
+        return extracted_model
+    else:
+        backbone_class = _get_supported_backbone_class(model_type, task)
+        config_class = getattr(backbone_class, "config_class", None)
+        if config_class is None or not hasattr(text_config, "to_dict"):
+            return extracted_model
+
+        config_dict = text_config.to_dict()
+        config_dict.pop("model_type", None)
+        config = config_class(**config_dict)
+
+    attn_implementation = getattr(text_config, "_attn_implementation", None)
+    if attn_implementation is not None:
+        config._attn_implementation = attn_implementation
+    if has_supported_target and pooling is not None:
+        config.pooling = pooling
+    if num_labels is not None:
+        config.num_labels = num_labels
+    if has_supported_target and temperature is not None:
+        config.temperature = temperature
+
+    return _load_from_extracted_state(backbone_class, config, extracted_model)
+
+
 def pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool_type: str) -> torch.Tensor:
     """
     Pool hidden states using the specified pooling method.
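
The extraction helper above is plain dotted-attribute traversal plus a guard that the result carries a .config. A minimal standalone sketch of the traversal (the ToyVLM class is invented for illustration; the real helper additionally rejects submodules that lack a .config attribute):

import torch.nn as nn

# Toy stand-in for a VLM whose text tower sits at model.language_model
# (classes invented for this sketch; real models come from transformers).
class ToyVLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Conv2d(3, 8, 3)
        self.language_model = nn.Linear(8, 8)

target = ToyVLM()
for attr in "language_model".split("."):
    target = getattr(target, attr)
print(type(target).__name__)  # Linear: the dotted path walked to the text tower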
@@ -100,15 +209,25 @@ def build_encoder_backbone(
     task: str,
     trust_remote_code: bool = False,
     pooling: Optional[str] = None,
+    extract_submodel: Optional[str] = None,
+    num_labels: Optional[int] = None,
+    temperature: Optional[float] = None,
     **hf_kwargs,
 ) -> PreTrainedModel:
     """Build an encoder backbone from a pretrained checkpoint.
 
-    For model types listed in :data:`SUPPORTED_BACKBONES`, resolves the
-    custom bidirectional architecture class from :class:`ModelRegistry`.
-    For all other model types, falls back to
-    ``AutoModel.from_pretrained`` (or ``AutoModelForSequenceClassification``
-    for the ``"score"`` task).
+    When ``extract_submodel`` is set, loads the parent model with HuggingFace
+    Auto classes and extracts the dotted path. For supported extracted text
+    backbones, it then builds the registered retrieval class for the requested
+    task (bidirectional base model for ``"embedding"``, sequence-classification
+    wrapper for ``"score"``). For unsupported extracted text backbones, it
+    returns the extracted model for ``"embedding"`` and wraps it with
+    ``AutoModelForSequenceClassification`` for ``"score"``.
+
+    Without ``extract_submodel``, model types listed in
+    :data:`SUPPORTED_BACKBONES` resolve to custom bidirectional classes from
+    :class:`ModelRegistry`; all other model types fall back to HuggingFace Auto
+    classes.
 
     Args:
         model_name_or_path: Path or HuggingFace Hub identifier.
@@ -117,6 +236,10 @@ def build_encoder_backbone(
         pooling: Bi-encoder pooling strategy for registry backbones (e.g. Llama bidirectional)
             that accept it on ``from_pretrained``. Must not be forwarded to standard HF models
             (e.g. Qwen3) loaded via ``AutoModel``; those only receive ``**hf_kwargs``.
+        extract_submodel: Dotted attribute path to extract from the loaded model
+            (e.g. ``"language_model"`` to extract the text backbone from a VLM).
+        num_labels: Number of labels for reranking/classification backbones.
+        temperature: Optional retrieval score temperature for custom retrieval backbones.
         **hf_kwargs: Extra keyword arguments forwarded to ``from_pretrained``.
 
     Returns:
@@ -129,29 +252,34 @@ def build_encoder_backbone(
     config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
     model_type = getattr(config, "model_type", "")
 
-    task_map = SUPPORTED_BACKBONES.get(model_type.lower())
-
-    if task_map is not None:
-        arch_name = task_map.get(task)
-        if arch_name is None:
-            raise ValueError(
-                f"Unsupported task '{task}' for model type '{model_type}'. Available tasks: {', '.join(task_map)}."
-            )
-
-        if arch_name not in ModelRegistry.model_arch_name_to_cls:
-            raise ValueError(f"Model class '{arch_name}' not found in ModelRegistry.")
-
-        BidirectionalModelClass = ModelRegistry.model_arch_name_to_cls[arch_name]
-        logger.info(f"Using {arch_name} from registry")
+    if extract_submodel is not None:
+        logger.info(f"Loading {model_name_or_path} with HuggingFace Auto classes to extract {extract_submodel}")
+        model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **hf_kwargs)
+        extracted_model = _extract_submodel(model, extract_submodel)
+        return _build_backbone_from_extracted_submodel(
+            extracted_model,
+            task=task,
+            pooling=pooling,
+            num_labels=num_labels,
+            temperature=temperature,
+        )
 
+    BidirectionalModelClass = _get_supported_backbone_class(model_type, task)
+    if BidirectionalModelClass is not None:
         if pooling is not None:
             hf_kwargs["pooling"] = pooling
+        if num_labels is not None:
+            hf_kwargs["num_labels"] = num_labels
+        if temperature is not None:
+            hf_kwargs["temperature"] = temperature
         return BidirectionalModelClass.from_pretrained(
             model_name_or_path, trust_remote_code=trust_remote_code, **hf_kwargs
         )
 
     # Fallback: use HuggingFace Auto classes for model types not in SUPPORTED_BACKBONES
     logger.info(f"Model type '{model_type}' not in SUPPORTED_BACKBONES; falling back to HuggingFace Auto classes")
+    if task == "score" and num_labels is not None:
+        hf_kwargs["num_labels"] = num_labels
     if task == "score":
         return AutoModelForSequenceClassification.from_pretrained(
             model_name_or_path, trust_remote_code=trust_remote_code, **hf_kwargs
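
Taken together, the new path means a VLM checkpoint can feed the retrieval stack directly. A hypothetical call under assumed names (the checkpoint id and pooling value are placeholders, not from this commit):

from nemo_automodel._transformers.retrieval import build_encoder_backbone

# Placeholder Hub id for a VLM whose text tower lives at .language_model.
backbone = build_encoder_backbone(
    "org/example-vlm",
    task="embedding",
    pooling="avg",                      # assumed pooling name; valid options depend on the registry class
    extract_submodel="language_model",  # dotted path into the loaded parent model
)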
@@ -205,8 +333,6 @@ def save_encoder_pretrained(model: nn.Module, save_directory: str, **kwargs) ->
     "llama_bidirec": _LLAMA_TASKS,
     "ministral3": _MINISTRAL3_BIDIREC_TASKS,
     "ministral3_bidirec": _MINISTRAL3_BIDIREC_TASKS,
-    # Mistral3-VL Hub configs use top-level model_type "mistral3" (language is nested under text_config).
-    "mistral3": _MINISTRAL3_BIDIREC_TASKS,
 }
 
 
nemo_automodel/components/datasets/vlm/collate_fns.py

Lines changed: 11 additions & 3 deletions
@@ -871,6 +871,10 @@ def kimi_k25_vl_collate_fn(
     all_expanded = []
     all_pixel_values = []
    all_grid_thws = []
+    # Per-sample image counts, kept in lockstep with all_expanded so that
+    # n_images_per_sample length matches batch_size downstream. Samples that
+    # are text-only or whose image region was orphaned by truncation get 0.
+    per_sample_image_count: List[int] = []
 
     for i, conversation in enumerate(conversations):
         # Collect medias for this conversation
@@ -923,12 +927,14 @@ def kimi_k25_vl_collate_fn(
 
         # Only include image data if all expanded image tokens survived truncation.
         # Partial truncation into image regions would cause a mismatch in the model forward.
+        sample_image_count = 0
         if grid_thws is not None:
             merge_h, merge_w = _DEFAULT_MERGE_KERNEL
             expected_image_tokens = sum(int((h // merge_h) * (w // merge_w)) for _, h, w in grid_thws.tolist())
             actual_image_tokens = (input_ids == media_token_id).sum().item()
             if actual_image_tokens == expected_image_tokens:
                 all_grid_thws.append(grid_thws)
+                sample_image_count = int(grid_thws.shape[0])
                 if "pixel_values" in sample_batch:
                     all_pixel_values.append(sample_batch["pixel_values"])
                 else:
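
The survival check compares how many placeholder tokens remain after truncation with how many the patch grids imply. The arithmetic, in a small self-contained sketch with illustrative grid values and an assumed 2x2 merge kernel:

import torch

grid_thws = torch.tensor([[1, 8, 12], [1, 4, 6]])  # illustrative (t, h, w) patch grids for two images
merge_h, merge_w = 2, 2                            # assumed kernel, mirroring _DEFAULT_MERGE_KERNEL
expected = sum(int((h // merge_h) * (w // merge_w)) for _, h, w in grid_thws.tolist())
print(expected)  # (4*6) + (2*3) = 30 placeholder tokens the sample must still contain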
@@ -943,6 +949,7 @@ def kimi_k25_vl_collate_fn(
                 "attention_mask": attention_mask,
             }
         )
+        per_sample_image_count.append(sample_image_count)
 
     if not all_expanded:
         raise ValueError(
@@ -990,9 +997,10 @@ def kimi_k25_vl_collate_fn(
         result["grid_thws"] = torch.cat(all_grid_thws, dim=0)
         # Also add as image_grid_hws for PP chunking in finetune.py
         result["image_grid_hws"] = result["grid_thws"][:, 1:]  # [N, 3] -> [N, 2] (drop temporal dim, keep H,W)
-        # Per-sample image counts for PP chunking
-        image_counts = [g.shape[0] for g in all_grid_thws]
-        result["n_images_per_sample"] = torch.tensor(image_counts, dtype=torch.long)
+        # Per-sample image counts for PP chunking. Length must equal batch_size,
+        # so include zeros for text-only samples and for samples whose image
+        # region was orphaned by truncation.
+        result["n_images_per_sample"] = torch.tensor(per_sample_image_count, dtype=torch.long)
 
     # Build labels
     labels = build_labels_from_template(
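
The fix is bookkeeping discipline: append exactly one count per sample, zero included, so the tensor length always equals batch_size. A toy illustration (sample layout invented):

import torch

# Toy batch: sample 0 keeps 2 images, sample 1 is text-only, sample 2 lost its
# image region to truncation. One append per sample keeps the count vector in
# lockstep with the batch.
kept_image_counts = [2, None, None]  # None = no usable image region
per_sample_image_count = [c if c is not None else 0 for c in kept_image_counts]
n_images_per_sample = torch.tensor(per_sample_image_count, dtype=torch.long)
assert n_images_per_sample.tolist() == [2, 0, 0]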

nemo_automodel/components/distributed/parallelizer.py

Lines changed: 32 additions & 16 deletions
@@ -1195,46 +1195,55 @@ def validate_tp_mesh(model, tp_mesh):
     )
 
 
-def _find_largest_module_list(model: nn.Module) -> Optional[nn.ModuleList]:
+def _find_largest_module_list(model: nn.Module) -> Optional[Union[nn.ModuleList, nn.ModuleDict]]:
     """
-    Heuristic function to find the largest nn.ModuleList in a model.
+    Heuristic function to find the largest layer container in a model.
 
-    This function recursively traverses the model to find all nn.ModuleList instances
-    and returns the one with the most modules. This is useful as a fallback when
-    the model architecture is unknown, since transformer layers are typically
-    organized in ModuleLists.
+    This function recursively traverses the model to find all nn.ModuleList and
+    pipeline-split nn.ModuleDict instances and returns the one with the most
+    modules. This is useful as a fallback when the model architecture is unknown,
+    since transformer layers are typically organized in ModuleLists. Pipeline
+    splitting converts ModuleLists to ModuleDicts keyed by original layer index.
 
     Args:
         model (nn.Module): The model to search through.
 
     Returns:
-        Optional[nn.ModuleList]: The largest ModuleList found, or None if no ModuleList exists.
+        Optional[Union[nn.ModuleList, nn.ModuleDict]]: The largest layer container found, or None.
     """
-    largest_module_list = None
+    largest_module_list: Optional[Union[nn.ModuleList, nn.ModuleDict]] = None
     largest_size = 0
 
+    def _is_pp_layer_module_dict(module: nn.ModuleDict) -> bool:
+        # functional.py converts split ModuleLists to ModuleDicts with stringified
+        # numeric indices. Avoid treating arbitrary named ModuleDicts (for example
+        # adapter registries) as transformer layer containers in the heuristic path.
+        return all(key.isdigit() for key in module.keys())
+
     def _recursive_search(module: nn.Module, path: str = ""):
         nonlocal largest_module_list, largest_size
 
         for name, child in module.named_children():
             current_path = f"{path}.{name}" if path else name
 
-            if isinstance(child, nn.ModuleList):
+            if isinstance(child, nn.ModuleList) or (
+                isinstance(child, nn.ModuleDict) and _is_pp_layer_module_dict(child)
+            ):
                 current_size = len(child)
                 if current_size > largest_size:
                     largest_size = current_size
                     largest_module_list = child
-                    logger.debug(f"Found ModuleList at {current_path} with {current_size} modules")
+                    logger.debug(f"Found {type(child).__name__} at {current_path} with {current_size} modules")
 
             # Continue recursive search
             _recursive_search(child, current_path)
 
     _recursive_search(model)
 
     if largest_module_list is not None:
-        logger.info(f"Largest ModuleList found with {largest_size} modules")
+        logger.info(f"Largest layer container found with {largest_size} modules")
     else:
-        logger.warning("No ModuleList found in the model")
+        logger.warning("No ModuleList or ModuleDict found in the model")
 
     return largest_module_list
 
@@ -1320,6 +1329,8 @@ def _extend_layers(layers, modules):
     for m in modules:
         if isinstance(m, nn.ModuleList):
             layers.extend(m)
+        elif isinstance(m, nn.ModuleDict):
+            layers.extend(m.values())
         else:
             layers.append(m)
 
@@ -1338,15 +1349,20 @@ def _extend_layers(layers, modules):
     elif hasattr(model, "layers"):
         layers.extend(model.layers)
     else:
-        # Use heuristic to find the largest ModuleList in the model
+        # Use heuristic to find the largest layer container in the model.
         logger.warning(f"Unknown model type: {model_cls}. Using heuristic to find transformer layers.")
         largest_module_list = _find_largest_module_list(model)
         if largest_module_list is None:
-            # If no ModuleList found, still raise an exception
+            # If no layer container is found, still raise an exception.
             print(model)
-            raise ValueError(f"Unknown model type: {model_cls} and no ModuleList found in model structure")
+            raise ValueError(
+                f"Unknown model type: {model_cls} and no ModuleList or ModuleDict found in model structure"
+            )
 
-        layers.extend(largest_module_list)
+        if isinstance(largest_module_list, nn.ModuleDict):
+            layers.extend(largest_module_list.values())
+        else:
+            layers.extend(largest_module_list)
         logger.info(f"Successfully extracted {len(largest_module_list)} layers using heuristic")
 
     assert all(isinstance(m, nn.Module) for m in layers), "layers should be nn.Module instances"
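
Both container types then flatten to the same ordered layer list. A simplified standalone stand-in for the helper above (not the module's actual code path, which also handles the heuristic fallback):

import torch.nn as nn

def _extend_layers(layers, modules):
    # Simplified stand-in mirroring the diff: ModuleDict values and
    # ModuleList entries both flatten into one ordered layer list.
    for m in modules:
        if isinstance(m, nn.ModuleList):
            layers.extend(m)
        elif isinstance(m, nn.ModuleDict):
            layers.extend(m.values())
        else:
            layers.append(m)

layers = []
_extend_layers(layers, [nn.ModuleList([nn.Linear(4, 4)]), nn.ModuleDict({"0": nn.Linear(4, 4)})])
assert len(layers) == 2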

nemo_automodel/components/models/deepseek_v4/layers.py

Lines changed: 4 additions & 3 deletions
@@ -46,8 +46,9 @@
 See ``_hc_split_sinkhorn`` for the pure-torch port of the reference mixer
 (ported from miles PR 1045's ``kernel/sinkhorn.py``).
 
-Sliding-window / compress-ratio attention is NOT yet implemented.
-All layers use full causal attention regardless of compress_ratios.
+Compress-ratio attention (Compressor + Indexer) is wired into
+DeepseekV4Attention.forward for layers with compress_ratio > 0.
+All layers share the same sliding-window causal mask on the local KV path.
 """
 
 from __future__ import annotations
@@ -473,7 +474,7 @@ def eager_attention_with_sink(
     sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
     combined = torch.cat([attn_weights, sinks.to(attn_weights.dtype)], dim=-1)
     combined = combined - combined.max(dim=-1, keepdim=True).values
-    probs = F.softmax(combined, dim=-1, dtype=combined.dtype)[..., :-1]
+    probs = F.softmax(combined, dim=-1, dtype=torch.float32)[..., :-1]
     probs = F.dropout(probs, p=dropout, training=module.training).to(value_states.dtype)
     return torch.matmul(probs, value_states).transpose(1, 2).contiguous(), probs
 
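
The one-line change computes the sink-augmented softmax in float32 before the sink column is dropped, avoiding half-precision rounding in the normalizer. A self-contained sketch of the same pattern with illustrative shapes:

import torch
import torch.nn.functional as F

attn_weights = torch.randn(1, 2, 3, 3, dtype=torch.bfloat16)  # [batch, heads, q_len, k_len], illustrative
sinks = torch.zeros(1, 2, 3, 1, dtype=torch.bfloat16)         # one sink logit per row

combined = torch.cat([attn_weights, sinks], dim=-1)
combined = combined - combined.max(dim=-1, keepdim=True).values
# Normalize in float32 so the sink column absorbs probability mass without
# bf16 rounding error, then drop it and cast back for the value matmul.
probs = F.softmax(combined, dim=-1, dtype=torch.float32)[..., :-1]
probs = probs.to(attn_weights.dtype)
assert probs.sum(dim=-1).max() <= 1.0 + 1e-3  # rows sum to less than 1 once the sink is removed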
