Large Language Models (LLMs) are revolutionizing AI research. Many tasks that once required complex, task-specific models can now be solved in minutes with the help of LLMs. They are not only generation models: they can also tackle summarization, classification, question answering, clustering, and much more. While all these benefits are fantastic, running LLMs on your own machine can be challenging due to their size. Most LLMs require large GPUs to run, which might be feasible for big companies but can be a stumbling block for individuals.
Enter quantization. Quantization is a method that significantly reduces the size of any model. In quantization, we convert the model’s parameters from higher precision, like FLOAT32 (FP32), to lower precision, such as INT4 or INT8. This greatly shrinks the model’s size. However, since we’re reducing the precision of the model’s parameters, there’s a slight decrease in accuracy. But the trade-off in size might be well worth it.
As I mentioned earlier, running an LLM with 100 billion parameters on your computer is not feasible, so we need a way to run these models on our machines without significant performance degradation. Models are usually trained in higher precision, i.e., FLOAT32, and when we quantize these models, we typically convert them to a lower precision range like INT8 or even INT4.
Let’s use an LLM as an example. Llama3-70B has 70 billion parameters. To store this model in full precision, i.e., FP32, we need $70 \space billion \times 32 \space bit = \frac{70 \times 10^9 \times 32}{8 \times 10^9} = 280 \space GB$ of storage.
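Now, let’s also estimate the GPU memory (VRAM) required to run this 70-billion-parameter model. The calculations in this article use the rule of thumb
$$ M \approx \frac{P \times Q \times 4}{32} \times 1.2 $$
where $M$ is the memory in GB, $P$ is the number of parameters in billions, $Q$ is the number of bits per parameter, 4 is the number of bytes in a 32-bit parameter, and the factor 1.2 is an overhead allowance of about 20%. For full precision, $Q = 32$:
$$ M \approx \frac{70 \times 32 \times 4}{32} \times 1.2 = 336 \space GB $$
So, the memory required to run Llama3-70B in full precision is approximately 336 GB.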
This is really high: we need a high-end GPU or multiple GPUs to run this model in full precision. Even a smaller 7B model needs about 28 GB just to hold its FP32 weights.
This is out of reach for most individual users. But if we quantize the same model to 4-bit or 8-bit, we need far less VRAM. Let’s calculate for a 7B model:
For 4-bit
$$ M \approx \frac{7 \times 4 \times 4}{32} \times 1.2 $$
$$M \approx 4.2 \space GB$$
For 8-bit
$$ M \approx \frac{7 \times 8 \times 4}{32} \times 1.2 $$
$$M \approx 8.4 \space GB$$
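The same arithmetic can be wrapped in a small helper. This is only a sketch of the rule of thumb used in the calculations above; the function name `estimate_vram_gb` and its `overhead` parameter are my own, chosen for illustration.

```python
def estimate_vram_gb(params_billion, bits, overhead=1.2):
    """Rough VRAM estimate in GB: (P * Q * 4 / 32) * overhead."""
    bytes_per_fp32_param = 4
    return params_billion * bits * bytes_per_fp32_param / 32 * overhead


for name, params, bits in [
    ("Llama3-70B, FP32", 70, 32),
    ("7B, 8-bit", 7, 8),
    ("7B, 4-bit", 7, 4),
]:
    print(f"{name}: ~{estimate_vram_gb(params, bits):.1f} GB")
```

The three printed estimates come out at roughly 336 GB, 8.4 GB and 4.2 GB, matching the hand calculations.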
Even after quantization, these models still need a fair amount of GPU memory, but it is now within an acceptable range, and we can run them on our own computer.
So, quantization helps us run these models on smaller GPUs without much performance degradation.
Before delving into the details of how we quantize models from high precision to low precision, let’s take a look at how computers store numbers in CPU or memory. For every bit of a number, the computer requires an equal amount of storage. For instance, if we have a 32-bit number, the computer will use 32 bits to store that number in memory.
Computers store integers in two ways: unsigned and signed. A signed value stores the sign (positive or negative) of the number, while an unsigned value cannot. To store the sign, the computer uses 1 bit of memory.
Let’s look at an example of how an 8-bit unsigned integer is stored in memory.
The range of an $n$-bit unsigned integer is $[0, 2^n-1]$, which gives $[0, 255]$ for 8 bits. This means it can only store values from 0 to 255. When quantizing, a value that lies outside these boundaries is set to 0 if it is smaller and to 255 if it is larger.
This is different for signed integers, as the range is calculated as:
$$range \space [- \space 2^{n-1}, 2^{n-1}-1]$$
For 8-bit signed integers the range will be $[-128, 127]$
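As a quick check of these formulas, here is a small sketch that computes the representable range for a given bit width and clamps out-of-range values into it. The helper names `int_range` and `saturate` are hypothetical, used only for this illustration.

```python
def int_range(n_bits, signed):
    """Return the (min, max) values representable by an n-bit integer."""
    if signed:
        return -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1
    return 0, 2 ** n_bits - 1


def saturate(value, n_bits, signed):
    """Clamp a value into the representable range instead of wrapping around."""
    low, high = int_range(n_bits, signed)
    return max(low, min(high, value))


print(int_range(8, signed=False))        # (0, 255)
print(int_range(8, signed=True))         # (-128, 127)
print(saturate(300, 8, signed=False))    # 255
print(saturate(-5, 8, signed=False))     # 0
```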
This is completely different when it comes to floating point numbers. In floating point numbers, we will have three components:
Sign
Exponent or Range
Fraction or Precision or Mantissa
Below, you can see how a computer stores floating-point numbers. Each of these formats consumes a different amount of memory.
For example, float32 allocates 1 bit for sign, 8 bits for exponent and 23 bits for mantissa.
Similarly, float16 or FP16 allocates 1 bit for the sign, but just 5 bits for the exponent and 10 bits for the mantissa. On the other hand, BF16 (the B stands for Brain, as in Google Brain) allocates 1 bit for the sign, 8 bits for the exponent, and just 7 bits for the mantissa.
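To make the bit layout concrete, the sketch below splits a float32 into its three fields using only the standard library. It also imitates BF16 by keeping the top 16 bits of the float32 pattern. That truncation is a simplification on my part (real conversions usually round to nearest), and both helper names are hypothetical.

```python
import struct


def float32_fields(x):
    """Split a float32 into (sign, exponent, mantissa) fields: 1 + 8 + 23 bits."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    sign = bits >> 31
    exponent = (bits >> 23) & 0xFF
    mantissa = bits & 0x7FFFFF
    return sign, exponent, mantissa


def truncate_to_bfloat16(x):
    """Keep only the top 16 bits of the float32 pattern (1 + 8 + 7 bits)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]


print(float32_fields(1.0))     # (0, 127, 0)
print(float32_fields(-2.5))    # (1, 128, 2097152)
print(truncate_to_bfloat16(3.14159))
```

Because BF16 keeps all 8 exponent bits, it covers the same range as float32 and only gives up precision, which is why the truncated value of 3.14159 is close to, but not exactly, the original.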
So, in short, converting from a higher-memory format to a lower-memory format is called quantization. In deep learning terms, Float32 is referred to as single or full precision, while Float16 and BFloat16 are called half precision. By default, deep learning models are trained and stored in full precision. The most commonly used conversions are from full precision to the int8 and int4 formats.
In asymmetric mode, we map the min-max range of the floating-point values to the min-max range of the target precision. This is done using a zero-point, also called the quantization bias. Here the zero-point can be a non-zero value.
$$X_q = clamp \left( round\left(\frac{X}{S} + Z\right) ; 0, 2^n - 1\right)$$
Where $round(\cdot)$ rounds to the nearest integer, and
$$
clamp(x;a,c) = \begin{cases}
a &\quad x < a \\
x &\quad a \le x \le c \\
c &\quad x>c
\end{cases}
$$
$$ S = \frac{X_{max} - X_{min}}{2^n - 1} $$
$$Z = -\frac{X_{min}}{S}$$
Here,
$X \quad = $ Original floating point tensor
$X_q \quad = $ Quantized Tensor
$S \quad=$ Scale Factor
$Z \quad= $ Zero Point
$n \quad= $ Number of bits used for quantization
Note that in the derivation above we use an unsigned integer to represent the quantized range. That is, $X_q \in [0, 2^n - 1]$. One could use a signed integer if necessary (perhaps due to hardware considerations). This can be achieved by subtracting $2^{n-1}$.
As we can see, there is considerable data loss when quantizing the tensor: the Mean Squared Error is 202.06, which is really high. But more optimized methods are available these days, with which the data loss is much smaller.
Also, look at the tensor after we dequantize it: there is a large difference from the original there as well.
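For reference, the asymmetric formulas can be sketched in plain Python as follows. This is not the code that produced the error quoted above: the input values are hypothetical, so the error it prints will differ. Rounding to the nearest integer is added explicitly, since the quantized values must be integers, and the sketch assumes the input is not constant (otherwise the scale would be zero).

```python
def asymmetric_quantize(values, n_bits=8):
    """Quantize floats to unsigned n-bit integers with a scale and zero-point."""
    q_max = 2 ** n_bits - 1
    x_min, x_max = min(values), max(values)
    scale = (x_max - x_min) / q_max         # S
    zero_point = round(-x_min / scale)      # Z, rounded to an integer
    quantized = [
        max(0, min(q_max, round(x / scale + zero_point)))   # clamp(...; 0, 2^n - 1)
        for x in values
    ]
    return quantized, scale, zero_point


def dequantize(quantized, scale, zero_point):
    """Map the integers back to approximate float values."""
    return [scale * (q - zero_point) for q in quantized]


# Hypothetical tensor values, chosen only for illustration.
x = [-1.2, -0.4, 0.0, 0.3, 0.9, 1.5]
x_q, scale, zero_point = asymmetric_quantize(x)
x_hat = dequantize(x_q, scale, zero_point)
mse = sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)
print("quantized:  ", x_q)
print("dequantized:", [round(v, 4) for v in x_hat])
print("scale:", scale, "zero-point:", zero_point, "MSE:", mse)
```

The smallest input maps to 0 and the largest to 255, so the whole unsigned range is used, and each dequantized value lands within one quantization step of the original.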
In symmetric quantization when converting from higher precision to lower precision, we can always restrict to values between $[-(2^{n-1} - 1), \space + 2^{n-1}-1]$ and ensure that the zero of the input perfectly maps to the zero of the output leading to a symmetric mapping.
For FLOAT16 to INT8 Conversion, we restrict the values between -127 to +127.
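A minimal sketch of symmetric quantization, assuming the scale is taken from the largest absolute value in the tensor (the usual choice, though not spelled out above). The input values are hypothetical.

```python
def symmetric_quantize(values, n_bits=8):
    """Quantize floats to signed integers in [-(2^(n-1) - 1), 2^(n-1) - 1]."""
    q_max = 2 ** (n_bits - 1) - 1            # 127 for 8 bits
    abs_max = max(abs(v) for v in values)
    scale = abs_max / q_max                  # zero-point is always 0
    quantized = [max(-q_max, min(q_max, round(v / scale))) for v in values]
    return quantized, scale


x = [-1.2, -0.4, 0.0, 0.3, 0.9, 1.5]
x_q, scale = symmetric_quantize(x)
print(x_q)      # [-102, -34, 0, 25, 76, 127]
print(scale)
```

Note that the float zero maps exactly to the integer zero, and that the most negative input only reaches -102 because the range is stretched to fit the larger positive side.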
The asymmetric range fully utilizes the quantized range, because we map the min-max values of the float tensor exactly onto the min-max of the quantized range.
In symmetric mode, if the float range is biased towards one side, a significant part of the quantized range is dedicated to values that we’ll never see. This can result in greater loss.
On the other hand, the zero-point in asymmetric quantization puts extra load on the hardware, as it requires extra calculation, while symmetric quantization is much simpler. So we mostly use symmetric quantization.
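The range-utilization argument can be checked numerically. The sketch below quantizes hypothetical one-sided data (evenly spaced values in [0, 6]) with both schemes and counts how many distinct integer levels each one actually uses.

```python
# Hypothetical one-sided data: 1000 evenly spaced values in [0, 6].
values = [i * 6 / 999 for i in range(1000)]

# Asymmetric: map [x_min, x_max] onto the unsigned range [0, 255].
scale_a = (max(values) - min(values)) / 255
zero_point = round(-min(values) / scale_a)
levels_asym = {max(0, min(255, round(v / scale_a + zero_point))) for v in values}

# Symmetric: map [-max|x|, +max|x|] onto the signed range [-127, 127].
scale_s = max(abs(v) for v in values) / 127
levels_sym = {max(-127, min(127, round(v / scale_s))) for v in values}

print(len(levels_asym), "distinct levels used by asymmetric quantization")
print(len(levels_sym), "distinct levels used by symmetric quantization")
```

With no negative inputs, the symmetric scheme never produces a negative integer, so about half of its levels go unused, while the asymmetric scheme uses all 256.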
Above, we saw that the zero point of symmetric quantization is zero, while it is different in case of Asymmetric quantization. How do we decide this?
Let’s take an example. Every integer or floating-point format has its own range (-128 to 127 for int8), and the scaling factor essentially divides this range into equal steps. Since quantization reduces high-precision values to a lower precision, we need to clip those values at some point, say alpha for negative values and beta for positive values. Any value beyond alpha or beta is not meaningful, because it maps to the same output as alpha or beta itself. For INT8, the corresponding limits are -127 and +127 (we use -127 for numerical stability; this is called restricted quantization). The process of choosing the clipping values alpha and beta, and hence the clipping range, is called calibration.
To avoid cutting off too many numbers, a simple solution is to set alpha to $X_{\text{min}}$ and beta to $X_{\text{max}}$. Then we can easily figure out the scale factor, $S$, using these smallest and largest numbers. But this might make our range uneven. For instance, if the largest number ($X_{\text{max}}$) is 1.5 and the smallest ($X_{\text{min}}$) is -1.2, the range isn’t balanced around zero. To make it balanced, we pick the larger absolute value of the two ends (1.5 here) and use it as the cutoff on both sides, so that the zero-point is 0.
This balanced (symmetric) method is what we use when quantizing neural network weights. It’s simpler because the zero-point is always 0, making the math easier.
Now, let’s consider when our numbers are mostly on one side, like the positive side. This is similar to the outputs of popular activation functions like ReLU or GeLU. Also, these activation outputs change based on the input. For example, showing the network two images of a cat might give different outputs.
While the minimum and maximum values work, outliers can sometimes distort the quantization. In such cases, we can use percentiles to choose the values of alpha and beta.
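To see why percentiles help, the sketch below calibrates the same hypothetical activations twice: once with min-max, and once with the 0.1th and 99.9th percentiles. The data, the percentile choices and the simple index-based `percentile` helper are all assumptions made for illustration.

```python
import random


def percentile(sorted_values, pct):
    """Simple index-based percentile of an already sorted list (pct in [0, 100])."""
    index = round(pct / 100 * (len(sorted_values) - 1))
    return sorted_values[index]


random.seed(0)
# Hypothetical activations: mostly small values plus a single large outlier.
activations = [random.gauss(0.0, 1.0) for _ in range(10_000)] + [50.0]
ordered = sorted(activations)

# Min-max calibration: the outlier stretches the clipping range.
alpha_minmax, beta_minmax = ordered[0], ordered[-1]

# Percentile calibration: clip at the 0.1th and 99.9th percentiles instead.
alpha_pct, beta_pct = percentile(ordered, 0.1), percentile(ordered, 99.9)

scale_minmax = (beta_minmax - alpha_minmax) / 255
scale_pct = (beta_pct - alpha_pct) / 255
print(f"min-max:    [{alpha_minmax:.2f}, {beta_minmax:.2f}], scale = {scale_minmax:.4f}")
print(f"percentile: [{alpha_pct:.2f}, {beta_pct:.2f}], scale = {scale_pct:.4f}")
```

A smaller scale means finer steps for the bulk of the values, at the cost of clipping the few values that fall outside the chosen percentiles.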
In Post-Training Quantization (PTQ), we quantize the weights of an already trained model. This is straightforward and easy to implement; however, it may degrade the performance of the model slightly due to the loss of precision in the weights.
To better calibrate the model, its weights and activations are evaluated on a representative dataset to determine the range of values they take and the resulting quantization parameters (alpha, beta, scale, and zero-point). We then use these parameters to quantize the model.
Based on the methods of quantization, we can further divide PTQ into three categories:
Dynamic-Range Quantization: In this method, we quantize the model based on the global range of the data. This method produces a small model, but there may be a bit more accuracy loss.
Weight Quantization: In this method, we only quantize the weights of model leaving activations in their high precision. There may be higher accuracy loss with this method.
Per-Channel Quantization: In this method, we quantize the model parameters based on the dynamic range per channel rather than globally. This helps in achieving optimal accuracy.
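The difference between a global range and a per-channel range is easiest to see on a tiny example. The sketch below uses a hypothetical weight matrix with one small-valued and one large-valued channel, and symmetric int8 quantization as an assumed scheme.

```python
# Hypothetical weight matrix: one row per output channel.
weights = [
    [0.02, -0.01, 0.03],    # channel 0: small values
    [1.50, -2.00, 0.75],    # channel 1: large values
]
q_max = 127  # symmetric int8


def quantize_row(row, scale):
    """Quantize one row of weights with the given scale."""
    return [max(-q_max, min(q_max, round(w / scale))) for w in row]


# Global range: one scale shared by every channel.
global_scale = max(abs(w) for row in weights for w in row) / q_max
global_q = [quantize_row(row, global_scale) for row in weights]

# Per-channel range: one scale per row.
channel_scales = [max(abs(w) for w in row) / q_max for row in weights]
per_channel_q = [quantize_row(row, s) for row, s in zip(weights, channel_scales)]

print("global:     ", global_q)
print("per-channel:", per_channel_q)
```

With a single global scale, the small-valued channel collapses onto just a few integer levels, while per-channel scales let it use the full int8 range.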
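Let’s try PTQ on a small fully connected network for classifying MNIST digits.
import torch
import torch.nn as nn
from tqdm import tqdm

device = torch.device("cpu")

class DigitsNet(nn.Module):
    def __init__(self, input_shape, num_classes):
        super(DigitsNet, self).__init__()
        self.linear1 = nn.Linear(input_shape, 100)
        self.linear2 = nn.Linear(100, 100)
        self.linear3 = nn.Linear(100, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Note: ReLU is applied only once, after the last linear layer.
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = self.relu(x)
        return x

model = DigitsNet(28*28, 10).to(device)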
Now, let’s create a simple training loop to train and test the model:
# train_loader and test_loader are the MNIST DataLoaders
# (their creation is shown in the data-loading code later in this article).
def train(train_loader, model, epochs=5, limit=None):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    total_iterations = 0
    for epoch in range(epochs):
        model.train()
        loss_sum = 0
        num_iterations = 0
        data_iterator = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        if limit is not None:
            data_iterator.total = limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = model(x.view(-1, 28*28))
            loss = ce_loss(output, y)
            # Accumulate a plain float so the autograd graph is not kept alive.
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()
            # Stop early once the optional iteration limit is reached.
            if limit is not None and total_iterations >= limit:
                return

def test(model):
    correct = 0
    total = 0
    wrong_counts = [0 for i in range(10)]
    with torch.no_grad():
        for data in tqdm(test_loader, desc="Testing"):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 28*28))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                else:
                    wrong_counts[y[idx]] += 1
                total += 1
    print(f"Accuracy: {round(correct/total, 3) * 100}")
Our original model is 360 KB before quantization, with 90.6% accuracy. Now let’s quantize the model.
To quantize the model, we first create an exact copy of the model with two extra layers, i.e. quant and dequant. These layers mark where tensors are converted from floating point to the quantized format and back, and they help us find the optimal range.
class QuantizedDigitsNet(nn.Module):
    def __init__(self, input_shape, num_classes):
        super(QuantizedDigitsNet, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(input_shape, 100)
        self.linear2 = nn.Linear(100, 100)
        self.linear3 = nn.Linear(100, num_classes)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.quant(x)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x
As I said earlier, we run the model on some sample data to calibrate the range. This gives us the values of $\alpha$, $\beta$, $S$ and $Z$.
In the above model, we used two observers to observe the model’s behaviour. These observers calculate the required values when we calibrate the model. To calibrate it, we simply pass the test set through the model for one epoch.
As we can see above, we have a MinMaxObserver in each layer of our model. When the test set is passed through the model, these observers calculate the values of min_val and max_val.
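The prepare, calibrate and convert steps can be sketched with PyTorch's eager-mode quantization API. This is a sketch under assumptions rather than the exact code used for the results in this article: it uses a tiny hypothetical network (`TinyNet`), random calibration data, the `torch.ao.quantization` namespace (the newer home of `torch.quantization`), and `default_qconfig`, which attaches MinMaxObserver modules.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import (
    DeQuantStub, QuantStub, convert, default_qconfig, prepare,
)


class TinyNet(nn.Module):
    """Hypothetical stand-in for the article's network, kept small for brevity."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.linear = nn.Linear(8, 4)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.linear(self.quant(x)))


torch.manual_seed(0)
net = TinyNet().eval()
net.qconfig = default_qconfig      # MinMaxObserver for activations and weights
prepared = prepare(net)            # returns a copy with observers attached

# Calibration: run representative data (random here) through the model.
with torch.no_grad():
    for _ in range(10):
        prepared(torch.randn(16, 8))

observer = prepared.linear.activation_post_process
print(observer.min_val, observer.max_val)   # ranges recorded during calibration

quantized = convert(prepared)      # swap float modules for quantized ones
print(quantized.linear.weight().dtype)
```

The observers only record statistics during the calibration passes. The actual conversion of weights to int8 happens in `convert`.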
As we can see, after quantization, the model weights are converted into int8.
We can also see the loss introduced when dequantizing the quantized model. To see how much memory we saved and how much performance we lost, let’s check the accuracy and size of the model.
The size is reduced by almost 4 times. This is reasonable, since we are reducing the weights from fp32 to int8. The quantized model is slightly larger than a quarter of the original, as we also need to store the scale and other parameters after quantization.
Regarding the accuracy, we didn’t lose much from quantization. The loss is only 0.3%, which is really good.
While this is just an example, we may see more or less loss in real-life models. There are many advanced methods and formats, such as GPTQ, AWQ, and GGUF, for quantizing a model after training. I will talk about them in later articles.
Unlike PTQ, Quantization-Aware Training (QAT) integrates the weight conversion process into the training stage. This often results in superior model performance, but it’s more computationally demanding. A widely used related technique is QLoRA, which fine-tunes small adapter weights on top of a quantized base model.
As we move from float to a lower precision, we generally notice a significant accuracy drop, as this is a lossy process. This loss can be minimized with the help of quant-aware training. Basically, quant-aware training simulates low-precision behavior in the forward pass, while the backward pass remains the same. This induces some quantization error, which accumulates in the total loss of the model, and hence the optimizer tries to reduce it by adjusting the parameters accordingly. This makes our parameters more robust to quantization, making our process almost lossless.
To introduce the quantization loss we introduce something known as FakeQuant nodes into our model after every operation involving computations to obtain the output in the range of our required precision. A FakeQuant node is basically a combination of Quantize and Dequantize operations stacked together.
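Conceptually, a FakeQuant node is small. The sketch below shows the forward computation for a single value, assuming symmetric int8 quantization with a hypothetical scale. In a real framework the same operation runs on whole tensors and is paired with a rule for passing gradients through the rounding step.

```python
def fake_quantize(x, scale, zero_point=0, q_min=-127, q_max=127):
    """Quantize then immediately dequantize, so the output stays a float."""
    q = max(q_min, min(q_max, round(x / scale) + zero_point))   # Quantize
    return scale * (q - zero_point)                             # Dequantize


scale = 1.5 / 127   # hypothetical scale for values in [-1.5, 1.5]
for x in [0.1234, -0.7, 1.5, 3.0]:
    x_fq = fake_quantize(x, scale)
    print(f"{x:+.4f} -> {x_fq:+.4f} (error {x - x_fq:+.4f})")
```

The output is still a float, but it can only take values on the int8 grid, so the rest of the network sees the same rounding and clipping error it will meet at inference. The last input, 3.0, lies outside the range and is clipped to about 1.5.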
Now that we have defined our FakeQuant nodes, we need to determine the correct position to insert them in the graph. We need to apply Quantize operations on our weights and activations using the following rules:
Weights need to be quantized before they are multiplied or convolved with the input.
Our graph should display inference behavior while training so the BatchNorm layers must be folded and Dropouts must be removed.
Outputs of each layer are generally quantized after the activation layer like Relu is applied to them which is beneficial because most optimized hardware generally have the activation function fused with the main operation.
We also need to quantize the outputs of layers like Concat and Add where the outputs of several layers are merged.
We do not need to quantize the bias during training as we would be using int32 bias during inference and that can be calculated later on with the parameters obtained using the quantization of weights and activations.
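As an illustration of the last rule, one common convention (an assumption here, since the exact scheme is not spelled out above) is to give the int32 bias a scale equal to the product of the input scale and the weight scale, with a zero-point of 0. All numbers below are hypothetical.

```python
# Hypothetical scales obtained from quantizing activations and weights.
input_scale = 0.02
weight_scale = 0.005
bias_fp32 = [0.31, -0.127, 0.0049]

# Assumed convention: the bias uses the product of the input and weight
# scales, so it can be added directly to the int32 accumulator.
bias_scale = input_scale * weight_scale
int32_min, int32_max = -(2 ** 31), 2 ** 31 - 1
bias_int32 = [
    max(int32_min, min(int32_max, round(b / bias_scale))) for b in bias_fp32
]
print(bias_int32)   # [3100, -1270, 49]
```

Because the bias scale is derived entirely from the other two scales, nothing extra has to be learned or stored for the bias during training.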
Now that our graph is ready, we train the graph with the quantization layers. While training, we use the quantization layers only in the forward pass to introduce extra quantization error, which helps in reducing quantization loss.
Now that our model is trained and ready, we take the trained weights and quantize them using the parameters we get from QAT. Since the quantized model only accepts quantized input, we also need to quantize the input at inference time.
The functions to load the data and the training and testing loops are the same as before. I will just go through creating the model and preparing it for training.
import os
from tqdm import tqdm
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

_ = torch.manual_seed(0)

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1301,), (0.3081,))])
mnist_trainset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)
mnist_testset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10)
device = torch.device("cpu")
def train(train_loader, model, epochs=5, limit=None):
    ce_loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    total_iterations = 0
    for epoch in range(epochs):
        model.train()
        loss_sum = 0
        num_iterations = 0
        data_iterator = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        if limit is not None:
            data_iterator.total = limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = model(x.view(-1, 28*28))
            loss = ce_loss(output, y)
            # Accumulate a plain float so the autograd graph is not kept alive.
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()
            # Stop early once the optional iteration limit is reached.
            if limit is not None and total_iterations >= limit:
                return

def print_model_size(model):
    # Save the weights to a temporary file and report its size on disk.
    torch.save(model.state_dict(), "temp_model.pt")
    print("Size (KB): ", os.path.getsize("temp_model.pt")/1e3)
    os.remove("temp_model.pt")

def test(model):
    correct = 0
    total = 0
    wrong_counts = [0 for i in range(10)]
    with torch.no_grad():
        for data in tqdm(test_loader, desc="Testing"):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 28*28))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                else:
                    wrong_counts[y[idx]] += 1
                total += 1
    print(f"Accuracy: {round(correct/total, 3) * 100}")
Let’s create a model. Previously, we created a model and trained it without quantization. In this method, we introduce fake quantization layers before training, so we add the quant and dequant stubs when creating the model, as shown below:
class DigitsNet(nn.Module):
    def __init__(self, input_shape, num_classes):
        super(DigitsNet, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(input_shape, 100)
        self.linear2 = nn.Linear(100, 100)
        self.linear3 = nn.Linear(100, num_classes)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.quant(x)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = self.relu(x)
        x = self.dequant(x)
        return x

model = DigitsNet(28*28, 10).to(device)
Now, we also introduce observers, which help us get the quantization range and other parameters.
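A minimal sketch of this step with PyTorch's eager-mode API might look as follows. It is an illustration under assumptions, not the exact code behind this article: it uses a tiny hypothetical network (`TinyQATNet`), random data in place of MNIST, the `torch.ao.quantization` namespace, and the default QAT configuration for the `"fbgemm"` backend (typical for x86 CPUs).

```python
import torch
import torch.nn as nn
from torch.ao.quantization import (
    DeQuantStub, QuantStub, convert, get_default_qat_qconfig, prepare_qat,
)


class TinyQATNet(nn.Module):
    """Hypothetical stand-in for DigitsNet, kept small for brevity."""

    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.linear = nn.Linear(8, 4)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.linear(self.quant(x)))


torch.manual_seed(0)
net = TinyQATNet().train()                        # prepare_qat expects training mode
net.qconfig = get_default_qat_qconfig("fbgemm")   # backend choice is an assumption
prepared = prepare_qat(net)                       # attaches FakeQuant modules and observers

optimizer = torch.optim.Adam(prepared.parameters(), lr=0.001)
for _ in range(5):                                # stand-in for the real training loop
    x, y = torch.randn(16, 8), torch.randint(0, 4, (16,))
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(prepared(x), y)
    loss.backward()
    optimizer.step()

prepared.eval()
quantized = convert(prepared)                     # produce the actual int8 model
print(quantized.linear.weight().dtype)
```

During training the model still computes in floating point. The FakeQuant modules only simulate int8 rounding, and the weights become true int8 tensors only after `convert`.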
In conclusion, quantization emerges as a pivotal solution to mitigate the computational barriers posed by Large Language Models (LLMs). While these models offer unparalleled capabilities in various tasks, their size necessitates significant resources for efficient execution. Quantization addresses this challenge by converting model parameters from higher precision formats like FLOAT32 to lower precision formats such as INT8 or INT4, substantially reducing storage and memory requirements. Whether through asymmetric or symmetric methods, quantization offers a trade-off between computational efficiency and accuracy preservation, enabling the execution of LLMs on standard consumer-grade hardware.
Moreover, modes of quantization, such as Post-Training Quantization (PTQ) and Quantization-Aware Training (QAT), provide flexibility in integrating quantization into the model lifecycle. PTQ simplifies the process by quantizing pre-trained models, albeit with potential accuracy loss, while QAT, although computationally demanding, offers superior performance by integrating quantization into the training stage. As research and development in quantization techniques progress, we anticipate further advancements in efficiency and performance, ultimately democratizing access to advanced AI capabilities and fostering widespread adoption across diverse hardware infrastructures.