Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,19 @@ def AscendCudaGraphMixin_make_buffers_cudagraph(
input_buffers["x_active_mask"] = torch.zeros(
(max_batches), dtype=torch.bool, device=device
)

# ssm
if graph_meta.is_ssm:
input_buffers["state_ids"] = torch.full(
(max_batches,), -1, dtype=torch.int64, device=device
)

# mrope
if graph_meta.use_mrope:
input_buffers["mrope_position_ids"] = torch.zeros(
3, max_tokens, dtype=torch.int64, device=device
)

return input_buffers


Expand All @@ -94,6 +107,8 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph(
kv_start_indices: Tensor = attn_metadata.kv_start_indices
moe_metadata = get_step_ctx_manager().current_context().moe_metadata
x_active_mask: Tensor = moe_metadata.x_active_mask
q_start_loc: Tensor = attn_metadata.q_start_loc

input_buffers: BuffType = graph_meta.input_buffers

batch_size, num_blocks = block_offsets.size()
Expand All @@ -119,6 +134,15 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph(
input_buffers["x_active_mask"].fill_(0)
input_buffers["x_active_mask"][:batch_size] = x_active_mask

# ssm
if graph_meta.is_ssm:
input_buffers["q_start_loc"][: batch_size + 1] = q_start_loc
input_buffers["q_start_loc"][batch_size + 1 :] = q_start_loc[-1]

state_ids = kwargs["state_ids"]
input_buffers["state_ids"].fill_(-1)
input_buffers["state_ids"][: state_ids.size(0)].copy_(state_ids)

if inputs_embeds is not None:
emb_size = inputs_embeds.size(-1)
if "inputs_embeds" not in input_buffers:
Expand All @@ -135,6 +159,7 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph(
attn_metadata.kv_seqlens = input_buffers["kv_seqlens"]
attn_metadata.kv_start_indices = input_buffers["kv_start_indices"]
moe_metadata.x_active_mask = input_buffers["x_active_mask"]
attn_metadata.q_start_loc = input_buffers["q_start_loc"]

new_inputs = dict(
past_key_values=past_key_values,
Expand All @@ -150,6 +175,18 @@ def AscendCudaGraphMixin_fill_buffers_cudagraph(

new_inputs.update(kwargs)

# ssm: override kwargs' variable-length state_ids with the fixed-size buffer
if graph_meta.is_ssm:
new_inputs["state_ids"] = input_buffers["state_ids"]

# mrope
if graph_meta.use_mrope:
mrope_position_ids = kwargs.get("mrope_position_ids", None)
if mrope_position_ids is not None:
input_buffers["mrope_position_ids"].zero_()
input_buffers["mrope_position_ids"][:, :num_tokens] = mrope_position_ids
new_inputs["mrope_position_ids"] = input_buffers["mrope_position_ids"]

return new_inputs


Expand All @@ -163,6 +200,14 @@ def AscendCudaGraphMixin_update_context_cudagraph(self, graph_meta, context):
context.kv_start_indices = input_buffers["kv_start_indices"]
context.moe_metadata.x_active_mask = input_buffers["x_active_mask"]

# ssm
if graph_meta.is_ssm:
context.state_offsets = input_buffers["state_ids"]

# mrope
if graph_meta.use_mrope:
context.mrope_position_ids = input_buffers["mrope_position_ids"]


CudaGraphMixin.make_buffers_cudagraph = AscendCudaGraphMixin_make_buffers_cudagraph
CudaGraphMixin.fill_buffers_cudagraph = AscendCudaGraphMixin_fill_buffers_cudagraph
Expand Down Expand Up @@ -254,6 +299,8 @@ def __init__(
input_buffers=dict(),
output_buffers=dict(),
vocab_size=self.model_config.vocab_size,
is_ssm=len(model_config.states_shapes) > 0,
use_mrope=model_config.use_mrope,
)
self.device = device
self.max_batches = max_batches
Expand Down
Loading
Loading