Status: Closed · Label: bug (Something isn't working)
Description
It seems torch.compile doesn't support views across dtypes (bitcasts). This causes the PYTORCH_COMPILE backend and `model = torch.compile(model)` to break when `view_as_float` is set to `True`:
```
BackendCompilerFailed: backend='inductor' raised:
LoweringException: NotImplementedError: bitcast torch.float16 to different bitwidth type torch.uint8 is not supported yet.
```

Wrapping the view with `torch.jit.ignore` doesn't work in this case.
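For reference, the operation that inductor refuses to lower is an eager dtype view (bitcast): the same bytes are reinterpreted under a different dtype, so the element count scales with the bitwidth ratio. A minimal eager sketch (standalone, not using hqq):

```python
import torch

# Eagerly, the bitcast view is fine: float16 -> uint8 reinterprets the same
# bytes, doubling the element count (2 bytes per float16 -> 1 byte per uint8).
x = torch.zeros(4, dtype=torch.float16)
y = x.view(torch.uint8)

assert y.shape == (8,)                # 4 float16 elements = 8 bytes
assert y.data_ptr() == x.data_ptr()   # no copy: same underlying storage
```

It is only inductor's lowering of this bitcast that raises the `NotImplementedError` above.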
Minimal code to reproduce the issue:
```python
import torch
from hqq.core.quantize import *

HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP)

#######################################################################################
batch_size    = 1
context_size  = 512
compute_dtype = torch.float16

linear_layer = torch.nn.Linear(4096, 4096)
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False,
                                  quant_zero=False, offload_meta=False, view_as_float=True)
hqq_linear   = HQQLinear(linear_layer, quant_config, compute_dtype=compute_dtype, del_orig=False)

@torch.jit.ignore
def dequantize_Wq_aten(W_q, meta):
    if meta['view_as_float']:
        W_q = W_q.view(meta['unpack_view_dtype'])
    return hqq_aten.dequantize(W_q, meta['scale'], meta['zero'], meta['shape'],
                               meta['group_size'] if meta['group_size'] else -1,
                               meta['nbits'], meta['axis'], meta['packing'])

@torch.compile()
def dequantize(hqq_layer):
    return dequantize_Wq_aten(hqq_layer.W_q, hqq_layer.meta)
```
```python
######################################################################################
# This works:
hqq_linear.W_q.data = hqq_linear.W_q.data.view(hqq_linear.meta['unpack_view_dtype'])
W_r = dequantize(hqq_linear)

# This breaks:
hqq_linear.W_q.data = hqq_linear.W_q.data.view(compute_dtype)
W_r = dequantize(hqq_linear)
```

A workaround would be moving the view call outside `dequantize`, but this would make the code more complicated and require another call to revert back to float bitpacking.
This is mainly a PyTorch bug, so I filed the issue there as well: pytorch/pytorch#120998
@KeremTurgutlu fyi