Skip to content

Commit 4644bf5

Browse files
committed
Remove the monkey-patch code and unify impl replacement
Signed-off-by: Kai Xu <kaix@nvidia.com>
1 parent e4c4680 commit 4644bf5

1 file changed

Lines changed: 46 additions & 162 deletions

File tree

examples/vllm_serve/sparse_attn_worker.py

Lines changed: 46 additions & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,19 @@
2929
"""
3030

3131
import fnmatch
32-
import functools
3332
import json
3433
import os
3534
from typing import Any
3635

37-
import torch
3836
from fakequant_worker import disable_compilation
37+
from vllm.attention.layer import Attention as VLLMAttention
3938
from vllm.v1.worker.gpu_worker import Worker as BaseWorker
4039

4140
import modelopt.torch.sparsity.attention_sparsity as mtsa
42-
from modelopt.torch.kernels.triton_fa import attention as triton_attention
41+
from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import (
42+
ModelOptSparseAttentionImpl,
43+
set_sparse_config,
44+
)
4345

4446
# ---------------------------------------------------------------------------
4547
# Configuration from environment variables
@@ -117,115 +119,43 @@ def _match_sparse_config(module_name: str, sparse_cfg: dict) -> dict | None:
117119
return None
118120

119121

120-
def _sparse_attention_forward(module, query, key, value, kv_cache, attn_metadata, **kwargs):
121-
"""Sparse attention forward — used by SparseQuantWorker for direct module patching."""
122-
if not getattr(module, "_sparse_enabled", False):
123-
return module._original_forward(query, key, value, kv_cache, attn_metadata, **kwargs)
122+
def _replace_attention_impl(worker, config: dict):
123+
"""Replace FlashAttentionImpl with ModelOptSparseAttentionImpl on all Attention layers.
124124
125-
from vllm._custom_ops import reshape_and_cache_flash
126-
127-
reshape_and_cache_flash(
128-
key,
129-
value,
130-
kv_cache,
131-
attn_metadata.slot_mapping,
132-
module.impl.kv_cache_dtype,
133-
getattr(module.impl, "k_scale", 1.0),
134-
getattr(module.impl, "v_scale", 1.0),
135-
)
136-
137-
# Unpack paged KV cache
138-
k_cache = kv_cache[:, 0] # [num_blocks, page_size, num_kv_heads, head_dim]
139-
v_cache = kv_cache[:, 1]
140-
page_size = k_cache.shape[1]
141-
142-
output = torch.empty_like(query)
143-
sm_scale = module.impl.scale
144-
sparse_kw = module._sparse_kw
145-
146-
# Paged KV kwargs
147-
paged_kw = {
148-
"k_cache": k_cache,
149-
"v_cache": v_cache,
150-
"page_size": page_size,
151-
}
152-
153-
if attn_metadata.num_prefill_tokens > 0:
154-
pm = attn_metadata.prefill
155-
n = attn_metadata.num_prefill_tokens
156-
output[:n] = triton_attention(
157-
q=query[:n],
158-
k=query[:0], # dummy, not used in paged mode
159-
v=query[:0],
160-
b_start_loc=pm.query_start_loc,
161-
b_seq_len=pm.seq_lens_q,
162-
max_input_len=int(pm.seq_lens_q.max().item()),
163-
is_causal=True,
164-
softmax_scale=sm_scale,
165-
b_seq_len_k=pm.seq_lens,
166-
max_input_len_k=int(pm.seq_lens.max().item()),
167-
block_table=pm.block_tables,
168-
**paged_kw,
169-
**sparse_kw,
170-
)
171-
172-
if attn_metadata.num_decode_tokens > 0:
173-
dm = attn_metadata.decode
174-
offset = attn_metadata.num_prefill_tokens
175-
nd = attn_metadata.num_decode_tokens
176-
output[offset : offset + nd] = triton_attention(
177-
q=query[offset : offset + nd],
178-
k=query[:0], # dummy, not used in paged mode
179-
v=query[:0],
180-
b_start_loc=dm.query_start_loc,
181-
b_seq_len=torch.ones(nd, dtype=torch.int32, device=query.device),
182-
max_input_len=1,
183-
is_causal=True,
184-
softmax_scale=sm_scale,
185-
b_seq_len_k=dm.seq_lens,
186-
max_input_len_k=int(dm.seq_lens.max().item()),
187-
block_table=dm.block_tables,
188-
**paged_kw,
189-
**sparse_kw,
190-
)
191-
192-
return output
193-
194-
195-
def _apply_sparse_to_attention_modules(model, sparse_cfg: dict):
196-
"""Walk model modules, patch attention layers with sparse forward.
197-
198-
Used by SparseQuantWorker where registry-based mtsa.sparsify() cannot
199-
find already-quantized attention modules (forward identity check fails).
125+
Shared by SparseAttnWorker and SparseQuantWorker.
200126
"""
201-
from vllm.attention.layer import Attention as VLLMAttention
202-
203-
for name, module in model.named_modules():
204-
if not isinstance(module, VLLMAttention):
205-
continue
127+
if config["calib_config_path"]:
128+
cfg = _load_sparse_config(config["calib_config_path"])
129+
else:
130+
cfg = _build_sparse_config(config)
206131

207-
layer_cfg = _match_sparse_config(name, sparse_cfg)
208-
if layer_cfg is None or not layer_cfg.get("enable", True):
209-
continue
132+
if cfg is None:
133+
return
210134

211-
# Build kernel kwargs from layer config
212-
sparse_kw = {}
213-
sparsity_n = layer_cfg.get("sparsity_n", 0)
214-
if sparsity_n > 0:
215-
sparse_kw["sparsity_n"] = sparsity_n
216-
sparse_kw["sparsity_m"] = layer_cfg.get("sparsity_m", 4)
217-
sparse_kw["num_sink_tokens"] = layer_cfg.get("num_sink_tokens", 0)
218-
sparse_kw["dense_window_size"] = layer_cfg.get("dense_window_size", 1)
219-
threshold = layer_cfg.get("skip_softmax_threshold", None)
220-
if threshold:
221-
sparse_kw["skip_softmax_threshold"] = threshold
135+
set_sparse_config(cfg)
222136

223-
module._sparse_enabled = True
224-
module._sparse_kw = sparse_kw
137+
model = worker.model_runner.model
138+
if hasattr(model, "unwrap"):
139+
model = model.unwrap()
225140

226-
original_forward = module.forward
227-
module._original_forward = original_forward
228-
module.forward = functools.partial(_sparse_attention_forward, module)
141+
patched = 0
142+
for name, module in model.named_modules():
143+
if isinstance(module, VLLMAttention):
144+
old_impl = module.impl
145+
module.impl = ModelOptSparseAttentionImpl(
146+
num_heads=old_impl.num_heads,
147+
head_size=old_impl.head_size,
148+
scale=old_impl.scale,
149+
num_kv_heads=old_impl.num_kv_heads,
150+
alibi_slopes=old_impl.alibi_slopes,
151+
sliding_window=None,
152+
kv_cache_dtype=old_impl.kv_cache_dtype,
153+
logits_soft_cap=old_impl.logits_soft_cap,
154+
attn_type=old_impl.attn_type,
155+
kv_sharing_target_layer_name=old_impl.kv_sharing_target_layer_name,
156+
)
157+
patched += 1
158+
print(f"[ModelOpt] Sparse attention: replaced impl on {patched} attention layers")
229159

230160

231161
# ---------------------------------------------------------------------------
@@ -244,78 +174,32 @@ class SparseAttnWorker(BaseWorker):
244174
def load_model(self, *args, **kwargs) -> None:
245175
"""Load model, then replace attention impl with sparse variant."""
246176
super().load_model(*args, **kwargs)
247-
248-
if sparse_config["calib_config_path"]:
249-
cfg = _load_sparse_config(sparse_config["calib_config_path"])
250-
else:
251-
cfg = _build_sparse_config(sparse_config)
252-
253-
if cfg is None:
254-
return
255-
256-
from modelopt.torch.sparsity.attention_sparsity.plugins.vllm import (
257-
ModelOptSparseAttentionImpl,
258-
set_sparse_config,
259-
)
260-
261-
set_sparse_config(cfg)
262-
263-
from vllm.attention.layer import Attention as VLLMAttention
264-
265-
model = self.model_runner.model
266-
if hasattr(model, "unwrap"):
267-
model = model.unwrap()
268-
269-
patched = 0
270-
for name, module in model.named_modules():
271-
if isinstance(module, VLLMAttention):
272-
old_impl = module.impl
273-
module.impl = ModelOptSparseAttentionImpl(
274-
num_heads=old_impl.num_heads,
275-
head_size=old_impl.head_size,
276-
scale=old_impl.scale,
277-
num_kv_heads=old_impl.num_kv_heads,
278-
alibi_slopes=old_impl.alibi_slopes,
279-
sliding_window=None,
280-
kv_cache_dtype=old_impl.kv_cache_dtype,
281-
logits_soft_cap=old_impl.logits_soft_cap,
282-
attn_type=old_impl.attn_type,
283-
kv_sharing_target_layer_name=old_impl.kv_sharing_target_layer_name,
284-
)
285-
patched += 1
286-
print(f"[ModelOpt] Sparse attention: replaced impl on {patched} attention layers")
177+
_replace_attention_impl(self, sparse_config)
287178

288179

289180
class SparseQuantWorker(BaseWorker):
290181
"""vLLM worker that applies quantization + sparse attention.
291182
292183
Quantization uses the standard registry-based ``mtq.quantize()``.
293-
Sparse attention uses direct module walk because the registry cannot
294-
match already-quantized attention modules (forward identity check).
184+
Sparse attention replaces FlashAttentionImpl with ModelOptSparseAttentionImpl
185+
(same approach as SparseAttnWorker).
295186
"""
296187

188+
def load_model(self, *args, **kwargs) -> None:
189+
"""Load model, then replace attention impl with sparse variant."""
190+
super().load_model(*args, **kwargs)
191+
_replace_attention_impl(self, sparse_config)
192+
297193
def compile_or_warm_up_model(self) -> None:
298-
"""Apply quantization then sparse attention before warm-up."""
299-
from .fakequant_worker import _fakequant_run_prolog_worker, quant_config
194+
"""Apply quantization before warm-up."""
195+
from fakequant_worker import _fakequant_run_prolog_worker, quant_config
300196

301197
model = self.model_runner.model
302198
if hasattr(model, "unwrap"):
303199
model = model.unwrap()
304200

305201
with disable_compilation(model):
306-
# Step 1: Quantize
307202
if quant_config["quant_cfg"] or quant_config["kv_quant_cfg"]:
308203
_fakequant_run_prolog_worker(self)
309204

310-
# Step 2: Apply sparse attention via direct module walk
311-
if sparse_config["calib_config_path"]:
312-
cfg = _load_sparse_config(sparse_config["calib_config_path"])
313-
elif sparse_config["sparse_cfg"]:
314-
cfg = _build_sparse_config(sparse_config)
315-
else:
316-
cfg = None
317-
318-
if cfg is not None:
319-
_apply_sparse_to_attention_modules(model, cfg)
320-
321205
super().compile_or_warm_up_model()

0 commit comments

Comments (0)