21 | 21 | except ImportError: |
22 | 22 | HAS_AWQ = False |
23 | 23 | |
| 24 | +HAS_EETQ = False |
| 25 | +try: |
| 26 | + from EETQ import quant_weights, w8_a16_gemm |
| 27 | + |
| 28 | + HAS_EETQ = True |
| 29 | +except ImportError: |
| 30 | + pass |
| 31 | + |
24 | 32 | HAS_HQQ = True |
25 | 33 | try: |
26 | 34 | from hqq.core.quantize import BaseQuantizeConfig, HQQLinear |
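
For reference, the new HAS_EETQ flag follows the same optional-import pattern as the HAS_AWQ and HAS_HQQ guards above. A minimal sketch of how a caller might gate on it (require_eetq is a hypothetical helper, not part of this diff; the error message mirrors the one raised in get_linear below):

```python
def require_eetq() -> None:
    # Hypothetical convenience guard; the diff below raises the same
    # ImportError inline inside get_linear instead.
    if not HAS_EETQ:
        raise ImportError(
            "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
        )
```
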
@@ -127,6 +135,78 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: |
127 | 135 | x = torch.addmm(self.bias, input.view(-1, input.size(-1)), self.weight) |
128 | 136 | x = x.view(size_out) |
129 | 137 | return x |
| 138 | + |
| 139 | + |
| 140 | +class EETQLinear(nn.Module): |
| 141 | + """ |
| 142 | + The EETQLinear module applies a quantized linear transformation to the input tensor. |
| 143 | + |
| 144 | + Args: |
| 145 | + weight (torch.Tensor): The weight tensor for the linear transformation. |
| 146 | + bias (torch.Tensor): The bias tensor for the linear transformation. |
| 147 | + |
| 148 | + Attributes: |
| 149 | + weight (torch.Tensor): The weight tensor for the linear transformation. |
| 150 | + scale (torch.Tensor): The scale tensor used for quantization. |
| 151 | + bias (torch.Tensor): The bias tensor for the linear transformation. |
| 152 | + |
| 153 | + """ |
| 154 | + |
| 155 | + def __init__( |
| 156 | + self, |
| 157 | + weight, |
| 158 | + bias, |
| 159 | + ) -> None: |
| 160 | + super().__init__() |
| 161 | + # Get the device where the weight tensor is currently stored. |
| 162 | + device = weight.device |
| 163 | + |
| 164 | + # Transpose the weight tensor and make a contiguous copy of it on the CPU. |
| 165 | + # contiguous() ensures the tensor occupies a single contiguous block of |
| 166 | + # memory, which low-level kernels such as quant_weights typically expect. |
| 167 | + weight_transposed = torch.t(weight) |
| 168 | + weight_contiguous = weight_transposed.contiguous() |
| 169 | + weight_cpu = weight_contiguous.cpu() |
| 170 | + |
| 171 | + # Quantize the weights to int8 with EETQ's quant_weights; the final |
| 172 | + # argument (False) means the quantization is not performed in place. |
| 173 | + weight_quantized, scale = quant_weights(weight_cpu, torch.int8, False) |
| 174 | + |
| 175 | + # Move the quantized weights and the scale back to the original device. |
| 176 | + # cuda(device) places each tensor on that CUDA device. |
| 177 | + self.weight = weight_quantized.cuda(device) |
| 178 | + self.scale = scale.cuda(device) |
| 179 | + |
| 180 | + # If a bias is present, move it to the GPU as well. If not, set the bias to None. |
| 181 | + if bias is not None: |
| 182 | + self.bias = bias.cuda(device) |
| 183 | + else: |
| 184 | + self.bias = None |
| 185 | + |
| 186 | + def forward(self, input: torch.Tensor) -> torch.Tensor: |
| 187 | + """ |
| 188 | + Performs the forward pass of the layer. |
| 189 | + |
| 190 | + Args: |
| 191 | + input (torch.Tensor): The input tensor. |
| 192 | + |
| 193 | + Returns: |
| 194 | + torch.Tensor: The output tensor. |
| 195 | + """ |
| 196 | + # w8_a16_gemm multiplies the 16-bit activation input with the int8 weight |
| 197 | + # and rescales the result with the per-channel scale factors (self.scale). |
| 198 | + gemm_output = w8_a16_gemm(input, self.weight, self.scale) |
| 199 | + |
| 200 | + # Add the bias when one is present; otherwise return the GEMM output |
| 201 | + # unchanged. |
| 202 | + if self.bias is not None: |
| 203 | + final_output = gemm_output + self.bias |
| 204 | + else: |
| 205 | + final_output = gemm_output |
| 206 | + |
| 207 | + # The final output is returned. |
| 208 | + return final_output |
| 209 | + |
130 | 210 | |
131 | 211 | |
132 | 212 | class Linear8bitLt(nn.Module): |
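
To make the numerics of EETQLinear concrete, here is a rough pure-PyTorch sketch of what it computes: per-output-channel symmetric int8 quantization of the transposed weight, followed by a dequantize-and-matmul. EETQ's quant_weights and w8_a16_gemm are fused CUDA kernels whose exact rounding and memory layout differ; the function names below are hypothetical reference implementations, not EETQ's API.

```python
import torch


def quantize_per_channel_int8(w: torch.Tensor):
    # w has shape (in_features, out_features), i.e. already transposed
    # as in EETQLinear.__init__. One scale per output channel (column),
    # symmetric around zero.
    scale = w.abs().amax(dim=0).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return q, scale


def w8_a16_gemm_reference(x: torch.Tensor, q: torch.Tensor, scale: torch.Tensor):
    # Dequantize the int8 weight and multiply:
    # (..., in_features) @ (in_features, out_features) -> (..., out_features).
    return x @ (q.to(x.dtype) * scale)
```

On a float16 input x of shape (batch, in_features), w8_a16_gemm_reference(x, q, scale) approximates x @ w up to quantization error, which is what the fused kernel computes in a single pass.
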
@@ -254,6 +334,13 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False): |
254 | 334 | bias, |
255 | 335 | quant_type="fp4", |
256 | 336 | ) |
| 337 | + elif quantize == "eetq": |
| 338 | + if HAS_EETQ: |
| 339 | + linear = EETQLinear(weight, bias) |
| 340 | + else: |
| 341 | + raise ImportError( |
| 342 | + "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" |
| 343 | + ) |
257 | 344 | elif quantize == "gptq": |
258 | 345 | try: |
259 | 346 | qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight |
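
A hedged usage sketch of the new dispatch path, assuming a CUDA device, an installed EETQ build, and the get_linear signature shown in the hunk header above; the shapes are purely illustrative:

```python
import torch

# nn.Linear-style weight: (out_features, in_features).
weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
bias = torch.zeros(4096, dtype=torch.float16, device="cuda")

linear = get_linear(weight, bias, quantize="eetq")

x = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
y = linear(x)  # int8-weight / fp16-activation GEMM plus bias
```
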