
Whisper processor.batch_decode() ignores the skip_special_tokens parameter #44811

@cfasana

Description


System Info

  • transformers version: 4.57.6
  • Platform: Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
  • Python version: 3.10.13
  • Huggingface_hub version: 0.36.2
  • Safetensors version: 0.7.0
  • Accelerate version: 1.13.0
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.10.0+cu130 (CUDA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: NVIDIA GeForce RTX 3080 Ti Laptop GPU

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Running the code below, it appears that the skip_special_tokens parameter is being ignored.

from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

# load model and processor
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
model.config.forced_decoder_ids = None

# load dummy dataset and read audio files
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features 

# generate token ids
predicted_ids = model.generate(input_features)
# decode token ids to text
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
print("[Skip special tokens=False] ", transcription)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
print("[Skip special tokens=True] ", transcription)

The output is the following:
[Skip special tokens=False] [' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.']
[Skip special tokens=True] [' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.']
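For reference, this is the behavior I would expect from skip_special_tokens, sketched here with a toy decoder over a hypothetical vocabulary (the ids and token strings are made up for illustration and are not Whisper's real vocabulary):

```python
# Toy illustration of the expected skip_special_tokens semantics.
# The ids/tokens below are hypothetical, not Whisper's actual vocabulary.
SPECIAL = {50258: "<|startoftranscript|>", 50259: "<|en|>", 50257: "<|endoftext|>"}
VOCAB = {100: " Hello", 101: " world"}

def toy_decode(ids, skip_special_tokens):
    """Decode ids to text, dropping special tokens only when asked."""
    pieces = []
    for i in ids:
        if i in SPECIAL:
            if not skip_special_tokens:
                pieces.append(SPECIAL[i])
        else:
            pieces.append(VOCAB[i])
    return "".join(pieces)

ids = [50258, 50259, 100, 101, 50257]
print(toy_decode(ids, skip_special_tokens=False))
# <|startoftranscript|><|en|> Hello world<|endoftext|>
print(toy_decode(ids, skip_special_tokens=True))
#  Hello world
```

With skip_special_tokens=False the special tokens should appear in the output; with True they should be dropped. In the first snippet above, both calls behave as if skip_special_tokens were always True.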

If, however, a dict output is requested by passing return_dict_in_generate=True:

# generate token ids
predicted_ids = model.generate(input_features, return_dict_in_generate=True)
# decode token ids to text
transcription = processor.batch_decode(predicted_ids.sequences, skip_special_tokens=False)
print("[Skip special tokens=False] ", transcription)
transcription = processor.batch_decode(predicted_ids.sequences, skip_special_tokens=True)
print("[Skip special tokens=True] ", transcription)

The output is (correctly) the following:
[Skip special tokens=False] ['<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.<|endoftext|>']
[Skip special tokens=True] [' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.']

Expected behavior

The processor.batch_decode() function should also output the special tokens when skip_special_tokens=False is passed, just as it does when decoding the .sequences attribute of the dict output.
