So, you’ve trained a Large Language Model (LLM) or picked up a powerful open-source one. That’s the first mountain climbed. Now comes the next, arguably more complex one: how do you actually serve this model to potentially thousands of users efficiently without setting your cloud budget on fire?
LLMs are notoriously massive and hungry for computational resources. A naive approach to serving them will quickly lead to slow response times and skyrocketing costs. But over the last few years, a stack of brilliant optimization techniques has emerged, turning what was once impractical into a robust, scalable reality.
The Foundation: Understanding the KV Cache
At its core, an LLM generates text one token (roughly, a word or part of a word) at a time in a process called autoregression. To generate the next token, it needs the context of all the tokens that came before it.
The most basic way to do this is also the most inefficient: for every new token, the model would re-process the entire sequence from the very beginning.
The snippet below is that naive loop in code. Notice how the full, ever-growing `input_ids` tensor (and its `attention_mask`) is fed back through the model on every iteration:
```python
import time

import torch

# Assumes `model`, `tokenizer`, and `inputs` (a dict holding "input_ids" and
# "attention_mask" tensors) were created earlier, e.g. with Hugging Face
# transformers. Each call re-runs the *entire* sequence through the model.
def generate_token(batch):
    with torch.no_grad():
        logits = model(**batch).logits
    return logits[:, -1, :].argmax(dim=-1)  # greedy choice of the next token

generated_tokens = []
next_inputs = inputs
durations_s = []
for _ in range(10):
    t0 = time.time()
    next_token_id = generate_token(next_inputs)
    durations_s += [time.time() - t0]
    # Naive, cache-free decoding: append the new token and feed the whole
    # growing sequence (plus an extended attention mask) on the next step.
    next_inputs = {
        "input_ids": torch.cat(
            [next_inputs["input_ids"], next_token_id.reshape((1, 1))],
            dim=1),
        "attention_mask": torch.cat(
            [next_inputs["attention_mask"], torch.tensor([[1]])],
            dim=1),
    }
    next_token = tokenizer.decode(next_token_id)
    generated_tokens.append(next_token)
```
This is where the KV Cache comes in, and it’s the single most important optimization for LLM inference.
- What it is: In the “attention” mechanism of a Transformer, the model calculates three matrices from the input: a Query, a Key, and a Value. To generate a new token, its Query is compared against the Keys of all previous tokens to figure out “what to pay attention to,” and then a weighted sum of the Values is used to produce the output. The KV Cache simply stores the Key and Value matrices for all previous tokens so they don’t have to be recalculated every single time.
- The Analogy We Love: Think of it like taking notes in a meeting. To understand the last sentence spoken, you don’t need to re-listen to the entire meeting recording from the start. You just glance at your notes (the KV Cache).
- A Point of Confusion: In code, you might see the `attention_mask` being manually updated in a loop. Why? The attention mask tells the model which tokens to pay attention to. As we generate a new token and add it to our KV Cache, we must also extend the mask to tell the model, “Hey, pay attention to this new token, too!”
The result is a two-phase generation process:
- Prefill: The slow first step where the prompt is processed and the initial KV Cache is filled. This determines the “Time to First Token” (TTFT).
- Decode: The fast subsequent steps, where each new token is generated by re-using the cache. This determines the tokens-per-second throughput.
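For contrast, here is a minimal sketch of the same greedy loop with the cache enabled. It assumes the same `model`, `tokenizer`, and `inputs` as above and a Hugging Face-style causal LM that returns and accepts `past_key_values`; during decode, only the newest token is fed through the model while the cache supplies the Keys and Values of everything before it.

```python
import torch

with torch.no_grad():
    # Prefill: run the whole prompt once and fill the KV Cache.
    outputs = model(**inputs, use_cache=True)
    past_key_values = outputs.past_key_values
    next_token_id = outputs.logits[:, -1, :].argmax(dim=-1)

    attention_mask = inputs["attention_mask"]
    generated_tokens = []
    for _ in range(10):
        generated_tokens.append(tokenizer.decode(next_token_id))
        # Decode: feed ONLY the newest token; the cached Keys/Values cover
        # the rest of the sequence. The mask still grows by one each step.
        attention_mask = torch.cat(
            [attention_mask, torch.ones((1, 1), dtype=attention_mask.dtype)],
            dim=1)
        outputs = model(
            input_ids=next_token_id.reshape((1, 1)),
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True)
        past_key_values = outputs.past_key_values
        next_token_id = outputs.logits[:, -1, :].argmax(dim=-1)
```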
Scaling Up: Continuous Batching
Okay, so we can serve one user efficiently. What happens when we have dozens of requests hitting our server at once? The obvious answer is to “batch” them—process them together to maximize GPU utilization. But how we batch makes all the difference.
- The Problem with Static Batching: The simple approach is to group, say, 8 requests and process them as one batch. The catch? The entire batch is only as fast as its slowest request. If one user asks for a 500-token essay and seven others ask for a 10-token answer, those seven will finish quickly and then sit idle, wasting precious GPU cycles while the long request finishes.
- The Analogy: It’s like a bus that must drive to the final destination of its last passenger before it can pick up anyone new, even if everyone else got off at the first stop.
- The Solution: Continuous Batching. This is a smarter scheduling algorithm. As soon as any request in the batch finishes, its spot is immediately filled with a new request from the waiting queue. This keeps the GPU constantly fed with useful work. The result is dramatically higher throughput and lower average latency for all users.
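To make the scheduling idea concrete, here is a minimal, framework-free sketch. The `Request` class and its `step()` method are hypothetical stand-ins for “run one decode step and report whether the request finished”; the point is the slot-refill logic, not a real inference engine.

```python
from collections import deque

class Request:
    """Hypothetical request: step() fakes one decode step, True when done."""
    def __init__(self, prompt, max_new_tokens):
        self.prompt = prompt
        self.remaining = max_new_tokens

    def step(self):
        self.remaining -= 1          # stand-in for generating one token
        return self.remaining <= 0   # finished?

def continuous_batching(waiting: deque, max_batch_size: int = 8):
    active = []
    while waiting or active:
        # The key difference from static batching: refill free slots
        # immediately instead of waiting for the whole batch to drain.
        while waiting and len(active) < max_batch_size:
            active.append(waiting.popleft())
        # One decode step for every active request (a single batched
        # forward pass on a real server), dropping the ones that finished.
        active = [req for req in active if not req.step()]

# Seven short requests and one 500-token essay share the same batch, and the
# short ones free their slots for new arrivals as soon as they finish.
queue = deque(Request(f"prompt {i}", max_new_tokens=500 if i == 0 else 10)
              for i in range(8))
continuous_batching(queue)
```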
Fine-tuning with LoRA
LoRA is an efficient fine-tuning technique that adapts a model’s behavior by injecting small, trainable “adapter” layers into an existing model, without altering the model’s original, massive weights.
Imagine we set up a model with a `hidden_size` of 1024. The weight matrix `W` of the `model.linear` layer is a `1024 x 1024` matrix, containing over 1 million parameters (1024 × 1024 = 1,048,576).
If you were to perform traditional fine-tuning on this model, you would need to update all of these 1 million+ parameters, which consumes significant computational resources. Furthermore, for each new task you fine-tune, you need to save a new, full-sized copy of the model.
The LoRA Solution: Don’t Modify, Just Add
LoRA’s approach is not to modify the original weights `W`, but to “freeze” them and add a “bypass” path alongside.
- Create low-rank matrices: It creates two very “thin” matrices: `lora_a` (shape `1024 x 2`) and `lora_b` (shape `2 x 1024`). The number 2 here is the “rank.”
- Calculate the update: These two small matrices are multiplied (`W2 = lora_a @ lora_b`) to produce an “update matrix” that has the same shape (`1024 x 1024`) as the original weight matrix `W`.
- Combine the results: During the forward pass, the model’s output is the sum of two parts:
  - Output from the original path: `base_output = X @ W`
  - Output from the LoRA bypass: `lora_output = X @ lora_a @ lora_b`
  - Final result: `total_output = base_output + lora_output`
This shows that we only need to add less than 0.4% of the original parameter count to simulate a full update of the entire weight matrix. During fine-tuning, we train only this 0.4% of the parameters while the millions of original parameters remain frozen. This dramatically saves computational and storage resources.
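Here is a minimal PyTorch sketch of the arithmetic above. The `LoRALinear` name is illustrative rather than any particular library’s API; the `1024 x 1024` base weight is frozen and only the rank-2 `lora_a` / `lora_b` matrices stay trainable.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen base weight plus a trainable low-rank bypass (illustrative)."""
    def __init__(self, hidden_size=1024, rank=2):
        super().__init__()
        self.base = nn.Linear(hidden_size, hidden_size, bias=False)
        self.base.weight.requires_grad_(False)  # freeze the original W
        self.lora_a = nn.Parameter(torch.randn(hidden_size, rank) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(rank, hidden_size))

    def forward(self, x):
        base_output = x @ self.base.weight.T          # original frozen path
        lora_output = x @ self.lora_a @ self.lora_b   # low-rank bypass
        return base_output + lora_output

layer = LoRALinear()
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(f"trainable: {trainable} / {total} = {trainable / total:.2%}")
# trainable: 4096 / 1052672 = 0.39%
```

Initializing `lora_b` to zeros means the bypass contributes nothing at the start of fine-tuning, so training begins from the unmodified base model’s behavior.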
Multi-LoRA
Imagine a cloud service platform that uses a powerful foundation model (like Llama 3). Now, there are hundreds or thousands of customers, each of whom has fine-tuned this foundation model with their own data using LoRA to adapt it to their specific tasks (e.g., Customer A uses it for customer service chats, Customer B for summarizing legal documents, Customer C for generating marketing copy, etc.).
Now, the server receives a batch of requests, which contains requests from different customers. This means that for each request in the batch, we need to apply a different LoRA adapter.
The most naive (and worst) approach would be to deploy a separate full model instance for each customer. This would immediately exhaust GPU memory, making the cost prohibitively high.
The purpose of Multi-LoRA is to solve this exact problem: it allows you to load just one copy of the giant foundation model into memory, and then load hundreds or thousands of tiny LoRA adapters (the A/B matrices) alongside it. When processing a batch of requests, the system can dynamically apply the correct LoRA adapter for each individual request within the batch.
Method 1: Looping
The `LoopMultiLoraModel` class demonstrates the most intuitive method: using a `for` loop to iterate through each request in the batch.
- In each step of the loop, it finds the corresponding `lora_a` and `lora_b` matrices for the current request using `lora_indices`.
- It then performs the LoRA computation for that single request: `y[batch_idx] += x[batch_idx] @ lora_a @ lora_b`.
Disadvantage: Python loops, especially in high-performance computing scenarios, are notoriously inefficient. They cannot fully leverage the massively parallel processing capabilities of a GPU, leading to higher latency as the batch size grows. The first chart in the script clearly illustrates this: the latency grows linearly with the batch size, indicating poor performance.
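The original class isn’t reproduced here, but the following self-contained sketch (with made-up sizes and a random placeholder for the base layer’s output) shows the shape of the looping approach:

```python
import torch

batch_size, hidden_size, rank, num_loras = 8, 1024, 2, 16

# One stack of A and B matrices per adapter (i.e. per customer).
loras_a = torch.randn(num_loras, hidden_size, rank)
loras_b = torch.randn(num_loras, rank, hidden_size)

x = torch.randn(batch_size, 1, hidden_size)             # one token per request
lora_indices = torch.randint(num_loras, (batch_size,))  # adapter id per request
y = x @ torch.randn(hidden_size, hidden_size)           # shared base-model output

# Looping approach: apply each request's adapter one at a time.
for batch_idx in range(batch_size):
    idx = lora_indices[batch_idx]
    lora_a = loras_a[idx]   # (hidden_size, rank)
    lora_b = loras_b[idx]   # (rank, hidden_size)
    y[batch_idx] += x[batch_idx] @ lora_a @ lora_b
```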
Method 2: Gathering and Vectorizing
The `GatheredMultiLoraModel` class demonstrates the efficient implementation.
- Instead of a loop, it uses a critical operation: `torch.index_select`. This operation can, in a single step, **“gather”** all the required LoRA weights (`loras_a` and `loras_b`) for every request in the batch into new tensors, based on the `lora_indices`.
- It then performs a single, batch-wide matrix multiplication: `y += x @ lora_a @ lora_b`. Here, `x`, `lora_a`, and `lora_b` are all tensors containing the data for the entire batch.
Advantage: This operation is highly vectorized and can be processed in parallel at high speed by PyTorch on the GPU. The second chart in the script proves this: even as the batch size increases, the latency growth is far flatter than the looping method, resulting in much higher performance.
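Under the same made-up setup as the loop sketch above, the gathered version replaces the Python loop with one `torch.index_select` and a single batched matrix multiplication:

```python
import torch

batch_size, hidden_size, rank, num_loras = 8, 1024, 2, 16

loras_a = torch.randn(num_loras, hidden_size, rank)
loras_b = torch.randn(num_loras, rank, hidden_size)

x = torch.randn(batch_size, 1, hidden_size)
lora_indices = torch.randint(num_loras, (batch_size,))
y = x @ torch.randn(hidden_size, hidden_size)  # shared base-model output

# Gather every request's adapter weights in one shot ...
lora_a = torch.index_select(loras_a, 0, lora_indices)  # (batch, hidden, rank)
lora_b = torch.index_select(loras_b, 0, lora_indices)  # (batch, rank, hidden)

# ... then apply them all at once with batched matmuls instead of a loop.
y += x @ lora_a @ lora_b
```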
Production Systems Like LoRAX
A production-grade serving framework like Predibase’s LoRAX is what you get when you put all these concepts together into a single, polished system.
- It reports TTFT and token throughput separately, reflecting the prefill/decode split that the KV Cache enables.
- It keeps latency low for concurrent requests of very different lengths, which is exactly what Continuous Batching delivers.
- It serves many fine-tuned variants of one base model via an `adapter_id` request parameter, which is Multi-LoRA in action.