2929"""
3030
3131import fnmatch
32- import functools
3332import json
3433import os
3534from typing import Any
3635
37- import torch
3836from fakequant_worker import disable_compilation
37+ from vllm .attention .layer import Attention as VLLMAttention
3938from vllm .v1 .worker .gpu_worker import Worker as BaseWorker
4039
4140import modelopt .torch .sparsity .attention_sparsity as mtsa
42- from modelopt .torch .kernels .triton_fa import attention as triton_attention
41+ from modelopt .torch .sparsity .attention_sparsity .plugins .vllm import (
42+ ModelOptSparseAttentionImpl ,
43+ set_sparse_config ,
44+ )
4345
4446# ---------------------------------------------------------------------------
4547# Configuration from environment variables
@@ -117,115 +119,43 @@ def _match_sparse_config(module_name: str, sparse_cfg: dict) -> dict | None:
117119 return None
118120
119121
def _replace_attention_impl(worker, config: dict) -> None:
    """Replace each attention layer's impl with ``ModelOptSparseAttentionImpl``.

    Shared by ``SparseAttnWorker`` and ``SparseQuantWorker``: both call this
    after ``load_model`` so sparse attention is installed before warm-up.

    Args:
        worker: A vLLM worker; its ``model_runner.model`` is walked for
            ``vllm.attention.layer.Attention`` modules.
        config: Sparse-attention settings dict. ``config["calib_config_path"]``,
            when truthy, points at a JSON calibration config loaded via
            ``_load_sparse_config``; otherwise a config is built from the
            remaining keys via ``_build_sparse_config``.

    Returns:
        None. Mutates ``module.impl`` in place on every matched layer.
    """
    # Resolve the sparse config: calibration file takes precedence.
    if config["calib_config_path"]:
        cfg = _load_sparse_config(config["calib_config_path"])
    else:
        cfg = _build_sparse_config(config)

    # No config means sparse attention is disabled — leave the model untouched.
    if cfg is None:
        return

    # Publish the config so ModelOptSparseAttentionImpl instances can read it.
    set_sparse_config(cfg)

    model = worker.model_runner.model
    if hasattr(model, "unwrap"):
        model = model.unwrap()

    patched = 0
    for _name, module in model.named_modules():
        if not isinstance(module, VLLMAttention):
            continue
        old_impl = module.impl
        # Rebuild the impl with the old impl's geometry/scaling parameters.
        # NOTE(review): sliding_window is deliberately passed as None rather
        # than copied from old_impl — confirm the sparse impl is meant to
        # override any windowed-attention configuration.
        module.impl = ModelOptSparseAttentionImpl(
            num_heads=old_impl.num_heads,
            head_size=old_impl.head_size,
            scale=old_impl.scale,
            num_kv_heads=old_impl.num_kv_heads,
            alibi_slopes=old_impl.alibi_slopes,
            sliding_window=None,
            kv_cache_dtype=old_impl.kv_cache_dtype,
            logits_soft_cap=old_impl.logits_soft_cap,
            attn_type=old_impl.attn_type,
            kv_sharing_target_layer_name=old_impl.kv_sharing_target_layer_name,
        )
        patched += 1
    print(f"[ModelOpt] Sparse attention: replaced impl on {patched} attention layers")
229159
230160
231161# ---------------------------------------------------------------------------
@@ -244,78 +174,32 @@ class SparseAttnWorker(BaseWorker):
244174 def load_model (self , * args , ** kwargs ) -> None :
245175 """Load model, then replace attention impl with sparse variant."""
246176 super ().load_model (* args , ** kwargs )
247-
248- if sparse_config ["calib_config_path" ]:
249- cfg = _load_sparse_config (sparse_config ["calib_config_path" ])
250- else :
251- cfg = _build_sparse_config (sparse_config )
252-
253- if cfg is None :
254- return
255-
256- from modelopt .torch .sparsity .attention_sparsity .plugins .vllm import (
257- ModelOptSparseAttentionImpl ,
258- set_sparse_config ,
259- )
260-
261- set_sparse_config (cfg )
262-
263- from vllm .attention .layer import Attention as VLLMAttention
264-
265- model = self .model_runner .model
266- if hasattr (model , "unwrap" ):
267- model = model .unwrap ()
268-
269- patched = 0
270- for name , module in model .named_modules ():
271- if isinstance (module , VLLMAttention ):
272- old_impl = module .impl
273- module .impl = ModelOptSparseAttentionImpl (
274- num_heads = old_impl .num_heads ,
275- head_size = old_impl .head_size ,
276- scale = old_impl .scale ,
277- num_kv_heads = old_impl .num_kv_heads ,
278- alibi_slopes = old_impl .alibi_slopes ,
279- sliding_window = None ,
280- kv_cache_dtype = old_impl .kv_cache_dtype ,
281- logits_soft_cap = old_impl .logits_soft_cap ,
282- attn_type = old_impl .attn_type ,
283- kv_sharing_target_layer_name = old_impl .kv_sharing_target_layer_name ,
284- )
285- patched += 1
286- print (f"[ModelOpt] Sparse attention: replaced impl on { patched } attention layers" )
177+ _replace_attention_impl (self , sparse_config )
287178
288179
class SparseQuantWorker(BaseWorker):
    """vLLM worker that applies quantization + sparse attention.

    Quantization uses the standard registry-based ``mtq.quantize()``.
    Sparse attention replaces FlashAttentionImpl with ModelOptSparseAttentionImpl
    (same approach as SparseAttnWorker).
    """

    def load_model(self, *args, **kwargs) -> None:
        """Load the model, then swap every attention impl for the sparse variant."""
        super().load_model(*args, **kwargs)
        _replace_attention_impl(self, sparse_config)

    def compile_or_warm_up_model(self) -> None:
        """Apply (fake) quantization before running the standard warm-up."""
        # Deferred import: fakequant_worker is a sibling script, imported lazily
        # to match the absolute-import style used at module top level.
        from fakequant_worker import _fakequant_run_prolog_worker, quant_config

        model = self.model_runner.model
        if hasattr(model, "unwrap"):
            model = model.unwrap()

        # Quantize with torch.compile disabled so the prolog sees eager modules.
        with disable_compilation(model):
            if quant_config["quant_cfg"] or quant_config["kv_quant_cfg"]:
                _fakequant_run_prolog_worker(self)

        # NOTE(review): warm-up runs outside disable_compilation so the model
        # compiles normally — confirm this matches the intended ordering.
        super().compile_or_warm_up_model()
0 commit comments