
Commit 1be43cf

Eetq (#195)
1 parent 5e68edd commit 1be43cf

7 files changed

Lines changed: 117 additions & 1 deletion


.github/workflows/build.yaml

Lines changed: 1 addition & 1 deletion
@@ -151,4 +151,4 @@ jobs:
 
       # Delete the SHA image(s) from containerd store
       sudo ctr i rm $(sudo ctr i ls -q)
-
+

Dockerfile

Lines changed: 10 additions & 0 deletions
@@ -146,6 +146,13 @@ COPY server/punica_kernels/ .
 ENV TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX"
 RUN python setup.py build
 
+# Build eetq kernels
+FROM kernel-builder as eetq-kernels-builder
+WORKDIR /usr/src
+COPY server/Makefile-eetq Makefile
+# Build specific version of eetq
+RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq
+
 # LoRAX base image
 FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base

@@ -194,6 +201,9 @@ COPY --from=punica-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/cond
 # Copy build artifacts from megablocks builder
 COPY --from=megablocks-kernels-builder /usr/src/megablocks/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
 
+# Copy build artifacts from eetq builder
+COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
+
 # Install flash-attention dependencies
 RUN pip install einops --no-cache-dir

launcher/src/main.rs

Lines changed: 4 additions & 0 deletions
@@ -26,6 +26,7 @@ enum Quantization {
     BitsandbytesFP4,
     Gptq,
     Awq,
+    Eetq,
     Hqq_4bit,
     Hqq_3bit,
     Hqq_2bit,
@@ -50,6 +51,9 @@ impl std::fmt::Display for Quantization {
             Quantization::Awq => {
                 write!(f, "awq")
             }
+            Quantization::Eetq => {
+                write!(f, "eetq")
+            }
             Quantization::Hqq_4bit => {
                 write!(f, "hqq-4bit")
             }

server/Makefile

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@ include Makefile-flash-att
 include Makefile-flash-att-v2
 include Makefile-vllm
 include Makefile-megablocks
+include Makefile-eetq
 
 unit-tests:
 	pytest -s -vv -m "not private" tests

server/Makefile-eetq

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+eetq_commit := cc2fdb4637e03652ac264eaef44dd8492472de01 # 323827dd471458a84e9c840f614e4592b157a4b1
+
+eetq:
+	# Clone eetq
+	pip install packaging
+	git clone https://github.com/NetEase-FuXi/EETQ.git eetq
+
+build-eetq: eetq
+	cd eetq && git fetch && git checkout $(eetq_commit) && git submodule update --init --recursive
+	cd eetq && python setup.py build
+
+install-eetq: build-eetq
+	cd eetq && python setup.py install
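
After `make install-eetq` (or inside the `eetq-kernels-builder` Docker stage above), a quick way to confirm the kernels built correctly is to import the two symbols the server uses. A minimal sanity-check sketch, assuming the build finished in a CUDA-enabled PyTorch environment; the symbol names come from the layers.py diff below:

# Sanity-check sketch: verifies the EETQ extension is importable.
# quant_weights / w8_a16_gemm are the symbols layers.py imports below.
from EETQ import quant_weights, w8_a16_gemm  # raises ImportError if the build failed

print("EETQ kernels available")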

server/lorax_server/cli.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ class Quantization(str, Enum):
     bitsandbytes_fp4 = "bitsandbytes-fp4"
     gptq = "gptq"
     awq = "awq"
+    eetq = "eetq"
     hqq_4bit = "hqq-4bit"
     hqq_3bit = "hqq-3bit"
     hqq_2bit = "hqq-2bit"
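
With the new enum member, `--quantize eetq` becomes a valid value for the server CLI. A minimal sketch of how the flag value round-trips through this `(str, Enum)` class; the assertions are illustrative, not taken from the commit:

from lorax_server.cli import Quantization

# "eetq" now parses to a valid Quantization member, so the CLI can
# validate the --quantize flag against the enum.
assert Quantization("eetq") is Quantization.eetq
assert Quantization.eetq.value == "eetq"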

server/lorax_server/utils/layers.py

Lines changed: 87 additions & 0 deletions
@@ -21,6 +21,14 @@
 except ImportError:
     HAS_AWQ = False
 
+HAS_EETQ = False
+try:
+    from EETQ import quant_weights, w8_a16_gemm
+
+    HAS_EETQ = True
+except ImportError:
+    pass
+
 HAS_HQQ = True
 try:
     from hqq.core.quantize import BaseQuantizeConfig, HQQLinear
@@ -127,6 +135,78 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         x = torch.addmm(self.bias, input.view(-1, input.size(-1)), self.weight)
         x = x.view(size_out)
         return x
+
+
+class EETQLinear(nn.Module):
+    """
+    EETQLinear applies a quantized linear transformation to the input tensor.
+
+    Args:
+        weight (torch.Tensor): The weight tensor for the linear transformation.
+        bias (torch.Tensor): The bias tensor for the linear transformation.
+
+    Attributes:
+        weight (torch.Tensor): The int8-quantized weight tensor.
+        scale (torch.Tensor): The scale tensor used for quantization.
+        bias (torch.Tensor): The bias tensor for the linear transformation.
+
+    """
+
+    def __init__(
+        self,
+        weight,
+        bias,
+    ) -> None:
+        super().__init__()
+        # Remember the device where the weight tensor currently lives.
+        device = weight.device
+
+        # Transpose the weight tensor and make a contiguous copy of it on the CPU.
+        # contiguous() ensures the tensor occupies a single contiguous block of
+        # memory, which the quantization routine expects.
+        weight_transposed = torch.t(weight)
+        weight_contiguous = weight_transposed.contiguous()
+        weight_cpu = weight_contiguous.cpu()
+
+        # Quantize the weights to int8 with EETQ's quant_weights;
+        # the final argument (False) disables in-place quantization.
+        weight_quantized, scale = quant_weights(weight_cpu, torch.int8, False)
+
+        # Move the quantized weights and the quantization scale
+        # back to the original CUDA device.
+        self.weight = weight_quantized.cuda(device)
+        self.scale = scale.cuda(device)
+
+        # If a bias is present, move it to the same device; otherwise store None.
+        if bias is not None:
+            self.bias = bias.cuda(device)
+        else:
+            self.bias = None
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        """
+        Performs the forward pass of the layer.
+
+        Args:
+            input (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor.
+        """
+        # w8_a16_gemm multiplies the fp16 input by the int8 weight and
+        # rescales the result with self.scale.
+        gemm_output = w8_a16_gemm(input, self.weight, self.scale)
+
+        # Add the bias to the GEMM output when one is present;
+        # otherwise return the GEMM output unchanged.
+        if self.bias is not None:
+            final_output = gemm_output + self.bias
+        else:
+            final_output = gemm_output
+
+        # Return the final output.
+        return final_output
+
 
 
 class Linear8bitLt(nn.Module):
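
A minimal usage sketch for the new layer, assuming EETQ and a CUDA device are available; the 4096-dim fp16 tensors below are illustrative, not taken from the commit:

import torch

from lorax_server.utils.layers import EETQLinear

# fp16 weight and bias, as a checkpoint loader would provide them.
weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
bias = torch.zeros(4096, dtype=torch.float16, device="cuda")

layer = EETQLinear(weight, bias)  # weight is quantized to int8 at construction time
x = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
y = layer(x)  # w8_a16_gemm(x, int8 weight, scale) + bias
print(y.shape)  # torch.Size([2, 4096])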
@@ -254,6 +334,13 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False):
             bias,
             quant_type="fp4",
         )
+    elif quantize == "eetq":
+        if HAS_EETQ:
+            linear = EETQLinear(weight, bias)
+        else:
+            raise ImportError(
+                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
+            )
     elif quantize == "gptq":
         try:
             qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
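
In practice the layer is reached through `get_linear`, which dispatches on the `quantize` string; a sketch reusing the illustrative tensors from the previous example:

from lorax_server.utils.layers import get_linear

# Selects EETQLinear when HAS_EETQ is True; otherwise raises ImportError
# with the install hint shown in the diff above.
linear = get_linear(weight, bias, quantize="eetq")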
