Merged

40 commits
00b364c
Up
patrickvonplaten Nov 30, 2025
05db608
WIP
patrickvonplaten Nov 30, 2025
b852600
WIP
patrickvonplaten Nov 30, 2025
723db43
WIP
patrickvonplaten Nov 30, 2025
ce3015d
Apply suggestions from code review
patrickvonplaten Dec 1, 2025
2d4578a
Merge branch 'main' into add_ministral_3
patrickvonplaten Dec 1, 2025
8470dca
Update src/transformers/models/ministral3/configuration_ministral3.py
patrickvonplaten Dec 1, 2025
b7100b0
Merge branch 'main' of github.com:huggingface/transformers into add_m…
ArthurZucker Dec 1, 2025
d551a48
fix most tests
ArthurZucker Dec 1, 2025
4f2331f
update docsting
ArthurZucker Dec 1, 2025
4d550ae
fixup
ArthurZucker Dec 1, 2025
d9d8aa6
typo in the ocnfig
ArthurZucker Dec 1, 2025
580e2f1
make the last 3 tests pass
ArthurZucker Dec 1, 2025
d2331f2
fix auto
ArthurZucker Dec 1, 2025
44f1f30
nits
ArthurZucker Dec 1, 2025
19c2af3
WIP
patrickvonplaten Dec 1, 2025
0c5eb8f
Merge branch 'main' of github.com:huggingface/transformers into add_m…
ArthurZucker Dec 1, 2025
7481a4a
Merge branch 'add_ministral_3' of https://github.com/patrickvonplaten…
patrickvonplaten Dec 1, 2025
668de82
Merge branch 'add_ministral_3' of github.com:patrickvonplaten/transfo…
ArthurZucker Dec 1, 2025
2cd4f15
WIP
patrickvonplaten Dec 1, 2025
5cb33cf
Merge branch 'add_ministral_3' of https://github.com/patrickvonplaten…
patrickvonplaten Dec 1, 2025
128d37c
WIP
patrickvonplaten Dec 1, 2025
0640a36
per tensor
MekkCyber Dec 1, 2025
4048042
WIP
patrickvonplaten Dec 1, 2025
0154cd0
Merge branch 'add_ministral_3' of https://github.com/patrickvonplaten…
patrickvonplaten Dec 1, 2025
67b1619
WIP
patrickvonplaten Dec 1, 2025
142f794
WIP
patrickvonplaten Dec 1, 2025
70f89d0
style
ArthurZucker Dec 1, 2025
de7888c
Merge branch 'add_ministral_3' of github.com:patrickvonplaten/transfo…
ArthurZucker Dec 1, 2025
3ee6f61
Merge branch 'main' of github.com:huggingface/transformers into add_m…
ArthurZucker Dec 1, 2025
547ac70
fixup
ArthurZucker Dec 1, 2025
b212ea0
WIP
patrickvonplaten Dec 1, 2025
b020d3b
WIP
patrickvonplaten Dec 1, 2025
5b813a6
WIP
patrickvonplaten Dec 1, 2025
2e4e5ae
WIP
patrickvonplaten Dec 1, 2025
4aad125
hack for now
MekkCyber Dec 1, 2025
2b41984
add todo
MekkCyber Dec 1, 2025
c1528a4
fixup
ArthurZucker Dec 1, 2025
730d489
WIP
patrickvonplaten Dec 1, 2025
f46a88a
WIP
patrickvonplaten Dec 1, 2025
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -598,6 +598,8 @@
       title: MiniMax
     - local: model_doc/ministral
       title: Ministral
+    - local: model_doc/ministral3
+      title: Ministral3
     - local: model_doc/mistral
       title: Mistral
     - local: model_doc/mixtral
114 changes: 114 additions & 0 deletions docs/source/en/model_doc/ministral3.md
@@ -0,0 +1,114 @@
<!--Copyright 2025 the HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be rendered properly in your Markdown viewer.

-->


# Ministral3

## Overview

A balanced model in the Ministral 3 family, Ministral 3 8B is a powerful yet efficient compact language model with vision capabilities.

This is the instruct post-trained version, fine-tuned to follow instructions, which makes it well suited for chat and instruction-based use cases.

The Ministral 3 family is designed for edge deployment, capable of running on a wide range of hardware.

Key features:
- Vision: Enables the model to analyze images and provide insights based on visual content, in addition to text.
- Multilingual: Supports dozens of languages, including English, French, Spanish, German, Italian, Portuguese, Dutch, Chinese, Japanese, Korean, and Arabic.
- System Prompt: Maintains strong adherence and support for system prompts.
- Agentic: Offers best-in-class agentic capabilities with native function calling and structured JSON output.
- Edge-Optimized: Delivers best-in-class performance at a small scale, deployable anywhere.
- Apache 2.0 License: Open-source license allowing usage and modification for both commercial and non-commercial purposes.
- Large Context Window: Supports a 256k context window.

## Usage examples

```py
import torch
from transformers import Mistral3ForConditionalGeneration, MistralCommonBackend


model_id = "mistralai/Ministral-3-3B-Instruct-2512"

tokenizer = MistralCommonBackend.from_pretrained(model_id)
model = Mistral3ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

image_url = "https://static.wikia.nocookie.net/essentialsdocs/images/7/70/Battle.png/revision/latest?cb=20220523172438"

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "What action do you think I should take in this situation? List all the possible actions and explain why you think they are good or bad.",
            },
            {"type": "image_url", "image_url": {"url": image_url}},
        ],
    },
]

tokenized = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True)

tokenized["input_ids"] = tokenized["input_ids"].to(device="cuda")
tokenized["pixel_values"] = tokenized["pixel_values"].to(dtype=torch.bfloat16, device="cuda")
image_sizes = [tokenized["pixel_values"].shape[-2:]]

output = model.generate(
    **tokenized,
    image_sizes=image_sizes,
    max_new_tokens=512,
)[0]

decoded_output = tokenizer.decode(output[len(tokenized["input_ids"][0]):])
print(decoded_output)
```
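
The same checkpoint also handles text-only chat. A minimal sketch under the same API assumptions as the example above (same model id and tokenizer behavior; not separately verified):

```py
import torch
from transformers import Mistral3ForConditionalGeneration, MistralCommonBackend

model_id = "mistralai/Ministral-3-3B-Instruct-2512"

tokenizer = MistralCommonBackend.from_pretrained(model_id)
model = Mistral3ForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# Text-only conversation: no image parts, so no pixel_values are produced.
messages = [{"role": "user", "content": "Summarize the key features of the Ministral 3 family in three bullet points."}]

tokenized = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True)
tokenized["input_ids"] = tokenized["input_ids"].to(device="cuda")

output = model.generate(**tokenized, max_new_tokens=256)[0]
print(tokenizer.decode(output[len(tokenized["input_ids"][0]):]))
```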


## Ministral3Config

[[autodoc]] Ministral3Config

## Ministral3PreTrainedModel

[[autodoc]] Ministral3PreTrainedModel
- forward

## Ministral3Model

[[autodoc]] Ministral3Model
- forward

## Ministral3ForCausalLM

[[autodoc]] Ministral3ForCausalLM

## Ministral3ForSequenceClassification

[[autodoc]] Ministral3ForSequenceClassification

## Ministral3ForTokenClassification

[[autodoc]] Ministral3ForTokenClassification

## Ministral3ForQuestionAnswering

[[autodoc]] Ministral3ForQuestionAnswering
59 changes: 40 additions & 19 deletions src/transformers/integrations/finegrained_fp8.py
@@ -185,14 +185,15 @@ def w8a8_block_fp8_matmul_triton(
     block_n, block_k = block_size[0], block_size[1]

     assert A.shape[-1] == B.shape[-1]
+
     assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
     assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
     M = A.numel() // A.shape[-1]

     assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
     N, K = B.shape
-    assert triton.cdiv(N, block_n) == Bs.shape[0]
-    assert triton.cdiv(K, block_k) == Bs.shape[1]
+    assert triton.cdiv(N, block_n) == Bs.shape[0], f"{N}, {block_n}, {Bs.shape}"
+    assert triton.cdiv(K, block_k) == Bs.shape[1], f"{K}, {block_k}, {Bs.shape}"

     C_shape = A.shape[:-1] + (N,)
     C = A.new_empty(C_shape, dtype=output_dtype)
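
For reference, `triton.cdiv` is ceiling division: the scale tensors carry one entry per (block_n × block_k) weight tile, which is exactly what these asserts check. A tiny standalone sketch of that relationship (no Triton required):

```py
# Ceiling division, equivalent to triton.cdiv(a, b).
def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b

# A (1024 x 576) weight with 128 x 128 blocks needs an 8 x 5 grid of scales.
N, K, block_n, block_k = 1024, 576, 128, 128
assert (cdiv(N, block_n), cdiv(K, block_k)) == (8, 5)
```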
@@ -322,21 +323,29 @@ def __init__(
         self.in_features = in_features
         self.out_features = out_features

+        if block_size is not None:
+            self.block_size = block_size
+        else:
+            self.block_size = (out_features, in_features)
+
         self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=FP8Linear.dtype, device=device))

         if self.weight.element_size() == 1:
-            scale_out_features = (out_features + block_size[0] - 1) // block_size[0]
-            scale_in_features = (in_features + block_size[1] - 1) // block_size[1]
-            self.weight_scale_inv = nn.Parameter(
-                torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device)
-            )
+            scale_out_features = (out_features + self.block_size[0] - 1) // self.block_size[0]
+            scale_in_features = (in_features + self.block_size[1] - 1) // self.block_size[1]
+            if scale_out_features * scale_in_features == 1:
+                self.weight_scale_inv = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
+            else:
+                self.weight_scale_inv = nn.Parameter(
+                    torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device)
+                )
         else:
             self.register_parameter("weight_scale_inv", None)

-        self.block_size = block_size
-
         self.activation_scheme = activation_scheme

+        if self.activation_scheme == "static":
+            self.activation_scale = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
+
         if bias:
             self.bias = nn.Parameter(torch.empty(self.out_features))
         else:
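
Context for the fallback: with `block_size = (out_features, in_features)` the entire weight is a single quantization block, i.e. per-tensor FP8 (the "per tensor" commit above), and `weight_scale_inv` collapses to a 0-dim scalar. A minimal standalone sketch of that degenerate case (function name and clamping are illustrative, not from the PR):

```py
import torch

def quantize_per_tensor_fp8(w: torch.Tensor):
    # One block covering the full tensor -> a single scalar scale.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    scale = w.abs().max().clamp(min=1e-12) / fp8_max
    return (w / scale).to(torch.float8_e4m3fn), scale

w = torch.randn(64, 32)
w_q, scale = quantize_per_tensor_fp8(w)
w_hat = w_q.to(torch.float32) * scale  # dequantize with the stored scalar
print((w - w_hat).abs().max())         # small quantization error
```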
@@ -356,15 +365,27 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
         torch_accelerator_module = getattr(torch, device_type, torch.cuda)
         with torch_accelerator_module.device(input.device):
-            qinput, scale = act_quant(input, self.block_size[1])
-            output = w8a8_block_fp8_matmul_triton(
-                qinput,
-                weight,
-                scale,
-                scale_inv,
-                self.block_size,
-                output_dtype=input.dtype,
-            )
+            if self.activation_scheme == "dynamic":
+                qinput, scale = act_quant(input, self.block_size[1])
+            elif self.activation_scheme == "static":
+                scale = self.activation_scale
+                qinput = (input / scale).to(torch.float8_e4m3fn)
+            else:
+                raise NotImplementedError("Not supported")
+            # TODO: fix this later to use the triton kernel
+            if self.activation_scheme == "static":
+                output = F.linear(qinput.to(torch.bfloat16), weight.to(torch.bfloat16), None) * scale_inv * scale
+                output = output.to(input.dtype)
+            else:
+                output = w8a8_block_fp8_matmul_triton(
+                    qinput,
+                    weight,
+                    scale,
+                    scale_inv,
+                    self.block_size,
+                    output_dtype=input.dtype,
+                )

         # Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
         # preceding operations are ready before proceeding
         torch_accelerator_module.synchronize()
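
Why the static branch multiplies by `scale_inv * scale`: the quantized operands are roughly `input / scale` and `weight / scale_inv`, so the bf16 matmul must be rescaled by both factors to recover the unquantized result. A self-contained check under per-tensor assumptions (variable names are ours):

```py
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x, w = torch.randn(4, 32), torch.randn(16, 32)

fp8_max = torch.finfo(torch.float8_e4m3fn).max
w_scale = w.abs().max() / fp8_max   # plays the role of scale_inv
a_scale = x.abs().max() / fp8_max   # plays the role of activation_scale
w_q = (w / w_scale).to(torch.float8_e4m3fn)
x_q = (x / a_scale).to(torch.float8_e4m3fn)

# Mirrors the static branch: bf16 matmul, then undo both scales.
y = F.linear(x_q.to(torch.bfloat16), w_q.to(torch.bfloat16), None) * w_scale * a_scale
print((y.float() - F.linear(x, w)).abs().max())  # small quantization error
```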
20 changes: 12 additions & 8 deletions src/transformers/integrations/mistral.py
@@ -84,19 +84,23 @@ def convert_tekken_tokenizer(tokenizer_file: str):

     # Extract vocab and special tokens
     vocab = mistral_tokenizer.instruct_tokenizer.tokenizer._tekken_token2id_nospecial
-    all_special = [
-        token.get("token_str", str(token))
-        if isinstance(token, dict)
-        else (token.value if hasattr(token, "value") else str(token))
-        for token in mistral_tokenizer.instruct_tokenizer.tokenizer._all_special_tokens
-    ]
-    specials_tokens = {token: all_special.index(token) for token in all_special}
+    sorted_tokens = sorted(mistral_tokenizer.instruct_tokenizer.tokenizer._all_special_tokens, key=lambda x: x["rank"])
+    all_special = [token["token_str"] for token in sorted_tokens]
+
+    specials_tokens = {token: idx for idx, token in enumerate(all_special)}
+
     specials_tokens.update(vocab)
     vocab = specials_tokens

+    # TODO(juliendenize): expose this in mistral-common to avoid accessing private attributes
+    # and improve maintainability
+    pattern = mistral_tokenizer.instruct_tokenizer.tokenizer._model._pat_str
+
     # Convert
     tokenizer = PreTrainedTokenizerFast(
-        tokenizer_object=MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted()
+        tokenizer_object=MistralConverter(
+            vocab=vocab, additional_special_tokens=all_special, pattern=pattern
+        ).converted()
     )

     # Post-process
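
The id fix here: special-token ids now come from each entry's `rank` field rather than from list order, so they match mistral-common's numbering. A toy illustration (the dict fields are inferred from the diff, not the actual mistral-common schema):

```py
# Hypothetical special-token entries, with the fields used by the diff above.
special_tokens = [
    {"rank": 1, "token_str": "</s>"},
    {"rank": 0, "token_str": "<s>"},
    {"rank": 2, "token_str": "[INST]"},
]

sorted_tokens = sorted(special_tokens, key=lambda x: x["rank"])
ids = {tok["token_str"]: idx for idx, tok in enumerate(sorted_tokens)}
assert ids == {"<s>": 0, "</s>": 1, "[INST]": 2}  # rank order, not list order
```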
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -222,6 +222,7 @@
     from .mimi import *
     from .minimax import *
     from .ministral import *
+    from .ministral3 import *
     from .mistral import *
     from .mistral3 import *
     from .mixtral import *
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -258,6 +258,7 @@
         ("mimi", "MimiConfig"),
         ("minimax", "MiniMaxConfig"),
         ("ministral", "MinistralConfig"),
+        ("ministral3", "Ministral3Config"),
         ("mistral", "MistralConfig"),
         ("mistral3", "Mistral3Config"),
         ("mixtral", "MixtralConfig"),
@@ -703,6 +704,7 @@
         ("mimi", "Mimi"),
         ("minimax", "MiniMax"),
         ("ministral", "Ministral"),
+        ("ministral3", "Ministral3"),
         ("mistral", "Mistral"),
         ("mistral3", "Mistral3"),
         ("mixtral", "Mixtral"),
5 changes: 5 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -258,6 +258,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("mimi", "MimiModel"),
         ("minimax", "MiniMaxModel"),
         ("ministral", "MinistralModel"),
+        ("ministral3", "Ministral3Model"),
         ("mistral", "MistralModel"),
         ("mistral3", "Mistral3Model"),
         ("mixtral", "MixtralModel"),
@@ -700,6 +701,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("megatron-bert", "MegatronBertForCausalLM"),
         ("minimax", "MiniMaxForCausalLM"),
         ("ministral", "MinistralForCausalLM"),
+        ("ministral3", "Ministral3ForCausalLM"),
         ("mistral", "MistralForCausalLM"),
         ("mixtral", "MixtralForCausalLM"),
         ("mllama", "MllamaForCausalLM"),
@@ -1254,6 +1256,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("megatron-bert", "MegatronBertForSequenceClassification"),
         ("minimax", "MiniMaxForSequenceClassification"),
         ("ministral", "MinistralForSequenceClassification"),
+        ("ministral3", "Ministral3ForSequenceClassification"),
         ("mistral", "MistralForSequenceClassification"),
         ("mixtral", "MixtralForSequenceClassification"),
         ("mobilebert", "MobileBertForSequenceClassification"),
@@ -1349,6 +1352,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("megatron-bert", "MegatronBertForQuestionAnswering"),
         ("minimax", "MiniMaxForQuestionAnswering"),
         ("ministral", "MinistralForQuestionAnswering"),
+        ("ministral3", "Ministral3ForQuestionAnswering"),
         ("mistral", "MistralForQuestionAnswering"),
         ("mixtral", "MixtralForQuestionAnswering"),
         ("mobilebert", "MobileBertForQuestionAnswering"),
@@ -1461,6 +1465,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("megatron-bert", "MegatronBertForTokenClassification"),
         ("minimax", "MiniMaxForTokenClassification"),
         ("ministral", "MinistralForTokenClassification"),
+        ("ministral3", "Ministral3ForTokenClassification"),
         ("mistral", "MistralForTokenClassification"),
         ("mixtral", "MixtralForTokenClassification"),
         ("mobilebert", "MobileBertForTokenClassification"),
23 changes: 22 additions & 1 deletion src/transformers/models/auto/tokenization_auto.py
@@ -212,12 +212,30 @@
         ("metaclip_2", "XLMRobertaTokenizerFast" if is_tokenizers_available() else None),
         ("mgp-str", "MgpstrTokenizer"),
         ("minimax", "GPT2Tokenizer" if is_tokenizers_available() else None),
+        (
+            "ministral3",
+            (
+                "MistralCommonBackend"
+                if is_mistral_common_available()
+                else ("LlamaTokenizer" if is_sentencepiece_available() else None),
+                "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
+            ),
+        ),
         (
             "mistral",
             "MistralCommonBackend"
             if is_mistral_common_available()
             else ("LlamaTokenizerFast" if is_tokenizers_available() else None),
         ),
+        (
+            "mistral3",
+            (
+                "MistralCommonBackend"
+                if is_mistral_common_available()
+                else ("LlamaTokenizer" if is_sentencepiece_available() else None),
+                "LlamaTokenizerFast" if is_tokenizers_available() and not is_mistral_common_available() else None,
+            ),
+        ),
         (
             "mixtral",
             "MistralCommonBackend"
@@ -384,7 +402,10 @@ def tokenizer_class_from_name(class_name: str) -> Union[type[Any], None]:
     for module_name, tokenizer_class in TOKENIZER_MAPPING_NAMES.items():
         if tokenizer_class == class_name:
             module_name = model_type_to_module_name(module_name)
-            if module_name in ["mistral", "mixtral", "ministral"] and class_name == "MistralCommonBackend":
+            if (
+                module_name in ["mistral", "mistral3", "mixtral", "ministral", "ministral3", "pixtral", "voxtral"]
+                and class_name == "MistralCommonTokenizer"
+            ):
                 module = importlib.import_module(".tokenization_mistral_common", "transformers")
             else:
                 module = importlib.import_module(f".{module_name}", "transformers.models")
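
In practice, the tokenizer mapping means `AutoTokenizer` picks a backend based on what is installed. A sketch of the expected resolution (not executed here; requires access to the checkpoint):

```py
from transformers import AutoTokenizer

# With mistral-common installed this should resolve to MistralCommonBackend;
# otherwise it falls back to the Llama slow/fast tokenizers per the mapping above.
tok = AutoTokenizer.from_pretrained("mistralai/Ministral-3-3B-Instruct-2512")
print(type(tok).__name__)
```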
27 changes: 27 additions & 0 deletions src/transformers/models/ministral3/__init__.py
@@ -0,0 +1,27 @@
# Copyright 2025 Mistral AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_ministral3 import *
    from .modeling_ministral3 import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
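
The lazy-module pattern above keeps `import transformers` cheap: submodules are only imported when an attribute is first accessed. A small sketch of the effect, assuming the standard `_LazyModule` behavior:

```py
# Importing the package does not yet import the configuration/modeling code...
import transformers.models.ministral3 as ministral3

# ...the real import happens on first attribute access.
config_cls = ministral3.Ministral3Config
print(config_cls.__name__)  # Ministral3Config
```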