Large Language Models (LLMs) and Neural Networks have revolutionized tasks like classification, summarization, and information extraction. However, achieving precision in these tasks often necessitates fine-tuning.
Traditional fine-tuning involves appending a task-specific head and updating the neural network’s weights during training. This approach contrasts with training from scratch, where the model’s weights are initialized randomly. In fine-tuning, the weights are already somewhat optimized from the pre-training phase.
Full fine-tuning entails training all layers of the neural network. While it typically yields superior results, it also demands substantial computational resources and time.
However, efficient fine-tuning methods have emerged as promising alternatives. Among these, Low Rank Adaptation (LoRA) stands out: it can match or even outperform full fine-tuning in certain scenarios, and it is notably less prone to catastrophic forgetting, the phenomenon where the pre-trained model's knowledge degrades during fine-tuning.
In the training of neural network models, a set of weights is established, denoted as $W$. If the neural network’s input is $x$, each layer’s output is computed by multiplying the weights with the input and applying a non-linear activation function like ReLU:
$$
h = \text{ReLU}(Wx)
$$
Here, $h$ represents the layer's output.
In traditional fine-tuning, the pre-trained neural network’s weights are adjusted to suit a new task. This process involves continuing training from the previous state, resulting in minor changes in the model’s weights, denoted as $\Delta{W}$, where the new model’s weights become $W + \Delta{W}$.
However, this method can be resource-intensive for Large Language Models. In LoRA, rather than directly modifying $W$, we decompose the weight matrix to achieve the desired adjustments.
Research suggests that the weights of neural networks are overparametrized, implying that not all weight elements are equally crucial. This idea is encapsulated in the intrinsic rank hypothesis.
The intrinsic rank hypothesis proposes that significant changes in neural networks can be captured using a lower-dimensional representation. When fine-tuning neural networks, these changes in weights can be effectively encapsulated using low-rank matrices, implying that only a subset of the weight changes is essential.
Based on the intrinsic rank hypothesis, we represent the change in $W$ with smaller matrices. Let's assume that $W$ has dimensions $(d, k)$. We define two smaller matrices, $B$ and $A$, with dimensions $(d, r)$ and $(r, k)$ respectively.
Here, $r$ represents the rank (reduced dimension) of the matrix, serving as a hyperparameter during model fine-tuning. The number of parameters of LoRA-adapted layers depends on the value of $r$.
The product of matrices $B$ and $A$ represents the change in the pre-trained weights, $\Delta{W}$. Thus, the updated weight matrix $W'$ becomes:
$$
W' = W + \Delta{W} = W + BA
$$
where:
$$
B \in \mathbb{R}^{d \times r}, \quad A \in \mathbb{R}^{r \times k}, \quad r \ll \min(d, k)
$$
In this equation, $W$ remains fixed (not updated during fine-tuning), while $A$ and $B$ are lower-dimensional matrices, with their product representing a low-rank approximation of $\Delta{W}$.
Authors typically initialize $A$ with random Gaussian values and $B$ with zeros, ensuring that $\Delta{W} = BA$ is zero at the beginning of training.
With LoRA, the output of the layer becomes:
$$
h = \text{ReLU}\big((W + BA)x\big) = \text{ReLU}(Wx + BAx)
$$
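To make this concrete, here is a minimal sketch of the LoRA forward pass in PyTorch (the dimensions are arbitrary and chosen only for illustration):

```python
import torch

d, k, r = 64, 32, 4                         # illustrative sizes, with r << min(d, k)
x = torch.randn(k)                          # input vector
W = torch.randn(d, k)                       # frozen pre-trained weight (never updated)
A = torch.randn(r, k, requires_grad=True)   # trainable, random Gaussian initialization
B = torch.zeros(d, r, requires_grad=True)   # trainable, zeros, so BA = 0 at the start

h = torch.relu(W @ x + B @ (A @ x))         # equivalent to ReLU((W + BA) x)
```

Only $A$ and $B$ receive gradients; $W$ stays exactly as it was after pre-training.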
By choosing a small rank $r$ for the matrices $B$ and $A$, the number of trainable parameters is significantly reduced. For instance, if $W$ is a $d \times d$ matrix, traditional weight updating would involve $d^2$ parameters. With $B$ and $A$, the total number of trainable parameters reduces to $2dr$, which is much smaller when $r \ll d$.
For example, suppose a model's weight matrix $W$ has dimensions $(300, 200)$, i.e. 60,000 parameters. Matrices $B$ and $A$ will then have dimensions $(300, r)$ and $(r, 200)$ respectively.
If $r = 4$, the total number of trainable parameters is $300 \times 4 + 4 \times 200 = 2{,}000$, roughly 3% of the original 60,000.
Similarly, if $r = 8$, the total is 4,000 trainable parameters, still substantially lower than the original 60,000 of the model.
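The same arithmetic can be verified quickly in Python (the dimensions mirror the example above):

```python
d, k = 300, 200
full = d * k                       # 60,000 parameters updated in full fine-tuning
for r in (4, 8):
    lora = d * r + r * k           # parameters in B (d x r) plus A (r x k)
    print(f"r={r}: {lora} trainable parameters ({100 * lora / full:.1f}% of {full})")
# r=4: 2000 trainable parameters (3.3% of 60000)
# r=8: 4000 trainable parameters (6.7% of 60000)
```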
As models grow larger, the ratio of trainable to frozen parameters diminishes rapidly. For example, the base GPT-2 (L) model has 774 million trainable parameters, whereas the LoRA-adapted model has only 770 thousand — fewer than 0.1% of the base total.
One of the most impressive aspects of LoRA is that LoRA fine-tuned models typically perform as well as, or better than, their fully fine-tuned counterparts.
The reduction in the number of trainable parameters achieved through Low-Rank Adaptation (LoRA) offers several significant benefits, particularly when fine-tuning large-scale neural networks:
- Reduced Memory Footprint: LoRA decreases memory needs by lowering the number of parameters to update, aiding in the management of large-scale models.
- Faster Training and Adaptation: By simplifying computational demands, LoRA accelerates the training and fine-tuning of large models for new tasks.
- Feasibility for Smaller Hardware: LoRA's lower parameter count enables the fine-tuning of substantial models on less powerful hardware, like modest GPUs or CPUs.
- Scaling to Larger Models: LoRA facilitates the expansion of AI models without a corresponding increase in computational resources, making the management of growing model sizes more practical.
def test(model):
    # Evaluate the model on the MNIST test set and record per-digit errors
    correct = 0
    total = 0
    wrong_counts = [0 for i in range(10)]
    with torch.no_grad():
        for data in tqdm(test_loader, desc="Testing"):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 28*28))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                else:
                    wrong_counts[y[idx]] += 1
                total += 1
    print(f"Accuracy: {round(correct/total, 3) * 100}")
    for i in range(len(wrong_counts)):
        print(f"Wrong counts for the digit {i}: {wrong_counts[i]}")

test(model)
Testing: 100%|██████████| 1000/1000 [00:06<00:00, 156.93it/s]
Accuracy: 89.1
Wrong counts for the digit 0: 33
Wrong counts for the digit 1: 12
Wrong counts for the digit 2: 147
Wrong counts for the digit 3: 115
Wrong counts for the digit 4: 69
Wrong counts for the digit 5: 196
Wrong counts for the digit 6: 98
Wrong counts for the digit 7: 91
Wrong counts for the digit 8: 185
Wrong counts for the digit 9: 148
Here, you can see the model's accuracy is 89.1%. It performs poorly on the digits 2, 3, 5, 8 and 9, so these are the digits we will target during fine-tuning.
Now, we will create a copy of the model so that we can compare against the original once we have fine-tuned it.
# Keep an untouched copy of the pre-trained model for later comparison
original_model = copy.deepcopy(model)
Before applying LoRA, let's create a function to get the number of trainable parameters in the model.
def count_parameters(model):
    # Count only the parameters that will receive gradient updates
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

count_parameters(model)
# 235,146
We have a total of 235,146 trainable parameters. Remember this number; we will compare against it later.
Now, I would like to discuss two methods for fine-tuning the model. PyTorch's `torch.nn.utils.parametrize` module lets us replace the weights of any PyTorch model with a custom function of those weights. We will go through this method first, since it lets us recover the original model's results without much hassle.

Parametrization replaces a parameter of the model with the output of some other function. In this case, we replace the "weight" of each layer with our custom function. After registering the parametrization, whenever we call the layer, it uses the value returned by that function rather than the original weights. This is helpful here because we don't want to touch the original weights of the model.
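The `LoraAdapter` used below is defined earlier in the article; as a reference, a minimal sketch consistent with how it is used here might look like this (the exact implementation and initialization scale may differ):

```python
import torch
import torch.nn as nn

class LoraAdapter(nn.Module):
    """Sketch of a LoRA parametrization: returns W + (alpha / rank) * (B @ A).

    Assumes the wrapped nn.Linear weight has shape (out_features, in_features),
    so B is (out_features, rank) and A is (rank, in_features).
    """
    def __init__(self, in_features, out_features, rank=2, alpha=1):
        super().__init__()
        self.A = nn.Parameter(torch.randn(rank, in_features) * 0.01)  # Gaussian init
        self.B = nn.Parameter(torch.zeros(out_features, rank))        # zeros, so BA = 0
        self.scale = alpha / rank
        self.enabled = True   # lets us switch LoRA on and off later

    def forward(self, original_weight):
        if self.enabled:
            return original_weight + self.scale * (self.B @ self.A)
        return original_weight
```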
import torch.nn.utils.parametrize as parametrize

# Replace the "weight" of every linear layer with the LoRA parametrization
for name, layer in model.named_children():
    if isinstance(layer, nn.Linear):
        parametrize.register_parametrization(
            layer, "weight",
            LoraAdapter(layer.in_features, layer.out_features, rank=2, alpha=1)
        )
Remember the model's layers: after registering the parametrization, an additional parametrized module is added inside each linear layer.

Printing the model shows that each linear layer now carries a parametrized weight backed by `LoraAdapter`. Since we are only going to train the `A` and `B` weight matrices, we freeze all of the model's original weights.
for name, param in model.named_parameters():
    # Freeze everything except the LoRA matrices A and B
    if "A" not in name and "B" not in name:
        param.requires_grad = False

count_parameters(model)
# 3124
Notice how the number of trainable parameters decreased from 238,270 (the original 235,146 plus the newly added LoRA parameters) to 3,124. Let's look at the percentage of parameters we are actually going to train.
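One quick way to compute that percentage (a sketch; the article's exact snippet is not shown here):

```python
trainable = count_parameters(model)                  # only A and B remain trainable
total = sum(p.numel() for p in model.parameters())   # includes the frozen weights
print(f"Trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
```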
So we are only training 1.32% of the original model's parameters. This is still a relatively high fraction because our model is very small; for big models, this number typically falls below 1%.
Now, let's create the fine-tuning dataset. We will keep only the digits on which the model performed poorly, and then train the model for 200 iterations. I am limiting the number of iterations for simplicity.
mnist_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)

# Build a boolean mask that keeps only the digits the model struggles with
for digit in [2, 3, 5, 8, 9]:
    if digit == 2:
        exclude_indices = mnist_dataset.targets == digit
    else:
        indices = mnist_dataset.targets == digit
        exclude_indices = exclude_indices | indices

mnist_dataset.data = mnist_dataset.data[exclude_indices]
mnist_dataset.targets = mnist_dataset.targets[exclude_indices]

train_loader = torch.utils.data.DataLoader(mnist_dataset, batch_size=10, shuffle=True)
train(train_loader, model, epochs=1, limit=200)
Here you can see the loss decreased from 0.72 to 0.32. However, the model may have overfitted to only the digits we fine-tuned on. Nevertheless, let's look at the results below.
As I said earlier, parametrization gives us the option to enable or disable the LoRA adapters at any time. We will create a function to do exactly that.
def enable_disable_lora(enabled=True):
    # Toggle the LoRA parametrization on each linear layer
    for layer in [model.linear1, model.linear2, model.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled
Let's check the results with LoRA disabled; this should reproduce the original model before fine-tuning.
enable_disable_lora(False)
test(model)
Testing: 100%|██████████| 1000/1000 [00:05<00:00, 171.80it/s]
Accuracy: 89.1
Wrong counts for the digit 0: 33
Wrong counts for the digit 1: 12
Wrong counts for the digit 2: 147
Wrong counts for the digit 3: 115
Wrong counts for the digit 4: 69
Wrong counts for the digit 5: 196
Wrong counts for the digit 6: 98
Wrong counts for the digit 7: 91
Wrong counts for the digit 8: 185
Wrong counts for the digit 9: 148
Now, with LoRA enabled after fine-tuning:
enable_disable_lora(True)
test(model)
Testing: 100%|██████████| 1000/1000 [00:07<00:00, 136.35it/s]
Accuracy: 68.30000000000001
Wrong counts for the digit 0: 543
Wrong counts for the digit 1: 484
Wrong counts for the digit 2: 79
Wrong counts for the digit 3: 101
Wrong counts for the digit 4: 609
Wrong counts for the digit 5: 98
Wrong counts for the digit 6: 348
Wrong counts for the digit 7: 765
Wrong counts for the digit 8: 103
Wrong counts for the digit 9: 44
As you can see, the model performs much better on the digits we fine-tuned on, but, as suspected, it has overfitted to them. So you have to be careful when fine-tuning any model.
For the second method, we apply a LoRA layer to every linear layer in the model and add the low-rank update to the original weights at inference time, building another model that contains all of this. I am using rank=2 here to check how the rank impacts fine-tuning.
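The exact code for this second approach is not reproduced here, but a minimal sketch of such a merged LoRA linear layer might look like the following (the class name `LoRALinear` is illustrative, not from the original code):

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Linear layer with frozen weights plus a trainable low-rank update B @ A."""
    def __init__(self, in_features, out_features, rank=2, alpha=1):
        super().__init__()
        # In practice, weight and bias would be copied from the pre-trained layer
        self.weight = nn.Parameter(torch.randn(out_features, in_features), requires_grad=False)
        self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False)
        self.A = nn.Parameter(torch.randn(rank, in_features) * 0.01)   # Gaussian init
        self.B = nn.Parameter(torch.zeros(out_features, rank))         # zero init
        self.scale = alpha / rank

    def forward(self, x):
        # Add the low-rank update to the frozen weight when computing the output
        merged_weight = self.weight + self.scale * (self.B @ self.A)
        return nn.functional.linear(x, merged_weight, self.bias)
```

After fine-tuning this second model on the same digits, the test results are: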
Testing: 100%|██████████| 1000/1000 [00:07<00:00, 138.74it/s]
Accuracy: 78.0
Wrong counts for the digit 0: 88
Wrong counts for the digit 1: 503
Wrong counts for the digit 2: 97
Wrong counts for the digit 3: 101
Wrong counts for the digit 4: 361
Wrong counts for the digit 5: 119
Wrong counts for the digit 6: 120
Wrong counts for the digit 7: 651
Wrong counts for the digit 8: 103
Wrong counts for the digit 9: 55
As we can see, this performed better than the previous method. This is related to the rank: we used rank-2 matrices here, and when fine-tuning any model you need to find the rank that works best and use it.
As with the previous method, this one also overfitted the data, but the accuracy improved by about 10 percentage points, which is a substantial gain.