Multiplexing LoRA adapters per request in Ray Serve lets a single deployment serve multiple fine-tuned variants of one base model concurrently, routing each incoming request to the specific LoRA adapter it needs.
This is what it looks like in practice. Imagine you have a base LLM, say Llama 2. Without adapter multiplexing, if you wanted to serve it for three different tasks (e.g., summarization, translation, code generation), you’d typically run three separate Ray Serve deployments, each holding a full copy of Llama 2. That roughly triples the memory needed for weights alone.
from typing import Dict, Optional

import torch
from peft import PeftModel
from ray import serve
from ray.serve.handle import DeploymentHandle
from starlette.requests import Request
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Assume you have a base model and multiple LoRA adapters saved
base_model_name = "meta-llama/Llama-2-7b-hf"
lora_adapter_paths = {
    "summarization": "/path/to/your/lora/summarization",
    "translation": "/path/to/your/lora/translation",
    "code_gen": "/path/to/your/lora/code_gen",
}


@serve.deployment(num_replicas=1)
class MultiLoRAModel:
    def __init__(self, base_model_name: str, lora_adapter_paths: Dict[str, str]):
        self.base_model_name = base_model_name
        self.lora_adapter_paths = lora_adapter_paths
        self.base_model = None
        # One PeftModel wraps the base model and holds every adapter loaded so far.
        self.peft_model: Optional[PeftModel] = None
        self.loaded_adapters = set()
        self._load_base_model()

    def _load_base_model(self):
        print(f"Loading base model: {self.base_model_name}")
        self.base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            # Example for memory efficiency: 8-bit weights (requires bitsandbytes)
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
        print("Base model loaded.")

    def _load_lora_adapter(self, adapter_name: str) -> PeftModel:
        if adapter_name not in self.loaded_adapters:
            print(f"Loading LoRA adapter: {adapter_name}")
            path = self.lora_adapter_paths[adapter_name]
            if self.peft_model is None:
                # First adapter: wrap the base model once. Naming the adapter
                # is key for PEFT's multi-adapter support.
                self.peft_model = PeftModel.from_pretrained(
                    self.base_model, path, adapter_name=adapter_name
                )
            else:
                # Later adapters are added to the same wrapper; nothing is merged
                # into the base weights, so adapters stay switchable.
                self.peft_model.load_adapter(path, adapter_name=adapter_name)
            self.loaded_adapters.add(adapter_name)
            print(f"LoRA adapter '{adapter_name}' loaded.")
        return self.peft_model

    async def __call__(self, request: Request):
        data = await request.json()
        prompt = data.get("prompt")
        adapter_key = data.get("adapter", "summarization")  # Default to summarization
        if not prompt:
            return {"error": "Missing 'prompt'."}
        if adapter_key not in self.lora_adapter_paths:
            return {"error": f"Adapter '{adapter_key}' not found."}

        # Dynamically load or get the adapter, then make it the active one
        model = self._load_lora_adapter(adapter_key)
        model.set_adapter(adapter_key)

        # --- Actual Inference ---
        # generate() is synchronous and there is no await between set_adapter()
        # and generate(), so another request on this replica cannot change the
        # active adapter in between.
        print(f"Generating text for prompt: '{prompt}' using adapter: '{adapter_key}'")
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.base_model.device)
        with torch.no_grad():
            # max_new_tokens=128 is an arbitrary cap chosen for illustration
            output_ids = model.generate(**inputs, max_new_tokens=128)
        generated_text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
        # --- End Inference ---

        return {"generated_text": generated_text, "adapter_used": adapter_key}


# Deploy the model
multi_lora_app = MultiLoRAModel.bind(base_model_name, lora_adapter_paths)
handle: DeploymentHandle = serve.run(multi_lora_app)
This setup drastically cuts down on memory. Instead of three full LLMs, you have one base LLM and then just the small LoRA adapter weights for each task loaded into memory. When a request comes in for summarization, Ray Serve routes it to one of the MultiLoRAModel replicas. The replica then dynamically loads (or already has loaded) the summarization LoRA weights, applies them to the base model for that specific request’s inference, and returns the result.
The core problem this solves is resource exhaustion for serving multiple fine-tuned models. Traditionally, if you need to serve a base model fine-tuned for N different tasks, you’d deploy N separate models, each with its own copy of the base model weights. This leads to N times the VRAM and RAM usage. Model multiplexing with LoRA allows you to run a single instance of the base model and only load the delta weights (the LoRA adapters) for each task. This means your VRAM and RAM usage scales with the size of the base model plus the sum of the sizes of the LoRA adapters, not N times the size of the base model.
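To put rough numbers on that scaling, here is a back-of-the-envelope sketch. It counts weights only (no KV cache or activations) and rests on assumptions that are not from the setup above: fp16 storage, LoRA rank 8 applied to two 4096x4096 projections per layer, and an approximate 6.74 billion parameters for Llama 2 7B. Adjust the constants to match your own adapter configuration.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """A rank-r LoRA pair (A: r x d_in, B: d_out x r) adds r * (d_in + d_out) parameters."""
    return rank * (d_in + d_out)


# --- Assumptions for illustration only ---
BASE_PARAMS = 6_740_000_000   # approximate parameter count of Llama 2 7B
BYTES_PER_PARAM = 2           # fp16
HIDDEN = 4096                 # Llama 2 7B hidden size
LAYERS = 32                   # Llama 2 7B transformer layers
RANK = 8                      # assumed LoRA rank
TARGETS_PER_LAYER = 2         # assumed: two square projections adapted per layer
N_TASKS = 3                   # summarization, translation, code generation

adapter_params = LAYERS * TARGETS_PER_LAYER * lora_param_count(HIDDEN, HIDDEN, RANK)
adapter_bytes = adapter_params * BYTES_PER_PARAM
base_bytes = BASE_PARAMS * BYTES_PER_PARAM

separate_bytes = N_TASKS * base_bytes                    # N full copies
multiplexed_bytes = base_bytes + N_TASKS * adapter_bytes  # one base + N adapters

GIB = 2**30
print(f"params per adapter:    {adapter_params:,}")
print(f"N separate models:     {separate_bytes / GIB:.2f} GiB")
print(f"one base + N adapters: {multiplexed_bytes / GIB:.2f} GiB")
```

Under these assumptions each adapter is about 8 MiB of weights against a base model of roughly 12.6 GiB, so the adapter term is almost negligible next to the base model.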
Internally, Ray Serve routes each call to a running replica of the MultiLoRAModel deployment. HTTP requests arrive through Serve’s HTTP proxy, while calls made from Python go through a DeploymentHandle, which acts as a proxy when you call handle.remote(). Inside the MultiLoRAModel class, the __call__ method receives the request. It inspects the request payload (e.g., a JSON object containing the prompt and an adapter key) to determine which LoRA adapter to use. It then calls _load_lora_adapter if that adapter hasn’t been loaded yet. _load_lora_adapter uses a library like PEFT (Parameter-Efficient Fine-Tuning) to load the adapter weights. Crucially, PEFT allows you to load multiple adapters into the same base model and switch between them dynamically using methods like set_adapter(). Once the correct adapter is active, the model performs inference.
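The payload inspection step can be isolated into a small helper that is easy to test without a GPU or a running cluster. This is a minimal sketch: resolve_request and its default_adapter parameter are hypothetical names introduced here, and the "prompt" and "adapter" field names simply follow the JSON shape described above.

```python
from typing import Iterable, Tuple


def resolve_request(
    payload: dict,
    known_adapters: Iterable[str],
    default_adapter: str = "summarization",
) -> Tuple[str, str]:
    """Return (prompt, adapter_name) or raise ValueError for a bad payload."""
    prompt = payload.get("prompt")
    if not isinstance(prompt, str) or not prompt.strip():
        raise ValueError("payload must contain a non-empty 'prompt' string")

    adapter = payload.get("adapter", default_adapter)
    if not isinstance(adapter, str) or adapter not in set(known_adapters):
        raise ValueError(f"unknown adapter: {adapter!r}")

    return prompt, adapter


known = ["summarization", "translation", "code_gen"]
print(resolve_request({"prompt": "Translate: hello", "adapter": "translation"}, known))
print(resolve_request({"prompt": "Summarize this."}, known))  # falls back to the default
```

Rejecting bad payloads before touching the model keeps a malformed request from triggering an adapter load or a tokenizer error.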
The key levers you control are:
- base_model_name: Which foundational model you’re starting with.
- lora_adapter_paths: A mapping from a logical adapter name (e.g., "translation") to the file path where its LoRA weights are stored.
- Request Payload: How you indicate which adapter to use. This is typically a field in the JSON body of the HTTP request (e.g., {"prompt": "...", "adapter": "code_gen"}).
- num_replicas: The number of MultiLoRAModel instances running. Ray Serve will automatically distribute incoming requests across these replicas.
When you have multiple LoRA adapters loaded into a single PeftModel instance, the set_adapter() method is what truly enables multiplexing. It selects which adapter’s low-rank weights are applied, without re-instantiating or reloading the base model. Because no base weights are copied, the switch is cheap relative to inference, which makes it suitable for per-request switching. The active adapter is shared state on the model, though, so the switch and the generation that follows it must not interleave with another request on the same replica.
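Why switching is cheap becomes clearer from the LoRA arithmetic itself: the output is the shared base projection plus a low-rank correction, y = W x + scaling * B (A x), and changing the adapter only changes which (A, B) pair is used. The toy sketch below illustrates this with plain Python lists. It is a conceptual model, not PEFT's implementation, and the matrices and the lora_forward name are made up for illustration.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def lora_forward(base_weight, adapters, active, x, scaling=1.0):
    """Compute W x + scaling * B (A x) using the active adapter's (A, B) pair."""
    a_matrix, b_matrix = adapters[active]
    base_out = matvec(base_weight, x)                   # shared by every adapter
    delta = matvec(b_matrix, matvec(a_matrix, x))       # low-rank correction
    return [b + scaling * d for b, d in zip(base_out, delta)]


# One shared 2x2 base weight and two rank-1 adapters (A is 1x2, B is 2x1).
W = [[1.0, 0.0], [0.0, 1.0]]
adapters = {
    "summarization": ([[1.0, 0.0]], [[0.5], [0.0]]),
    "translation": ([[0.0, 1.0]], [[0.0], [2.0]]),
}
x = [2.0, 3.0]

print(lora_forward(W, adapters, "summarization", x))  # [3.0, 3.0]
print(lora_forward(W, adapters, "translation", x))    # [2.0, 9.0]
```

The base weight W is never modified or duplicated; selecting a different key in adapters is the whole cost of the switch in this toy model.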
The next challenge you’ll encounter is managing the lifecycle of these loaded adapters, particularly when dealing with a very large number of adapters or when memory becomes a constraint and you need to unload inactive adapters.
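One common way to bound memory is a least-recently-used (LRU) policy: keep at most a fixed number of adapters resident and evict the one that has gone longest without a request. The sketch below shows that policy in isolation. AdapterLRU, load, and unload are hypothetical names; in a PEFT-based replica the load hook would wrap something like load_adapter() and the unload hook something like delete_adapter(), but that wiring is an assumption and is left out here. Ray Serve also ships a built-in serve.multiplexed decorator that caps models per replica via max_num_models_per_replica, which may be a better fit than hand-rolled caching.

```python
from collections import OrderedDict
from typing import Any, Callable, List, Optional


class AdapterLRU:
    """Keep at most `capacity` adapters resident, evicting the least recently used."""

    def __init__(
        self,
        capacity: int,
        load: Callable[[str], Any],
        unload: Optional[Callable[[str], None]] = None,
    ):
        if capacity < 1:
            raise ValueError("capacity must be at least 1")
        self.capacity = capacity
        self._load = load
        self._unload = unload
        self._resident: "OrderedDict[str, Any]" = OrderedDict()

    def acquire(self, name: str) -> Any:
        """Return the adapter, loading it (and evicting others) if needed."""
        if name in self._resident:
            self._resident.move_to_end(name)  # mark as most recently used
            return self._resident[name]
        while len(self._resident) >= self.capacity:
            evicted, _ = self._resident.popitem(last=False)  # oldest entry first
            if self._unload is not None:
                self._unload(evicted)
        self._resident[name] = self._load(name)
        return self._resident[name]

    def resident(self) -> List[str]:
        """Adapter names from least to most recently used."""
        return list(self._resident)


unloaded: List[str] = []


def fake_load(name: str) -> str:
    # Stand-in for reading adapter weights from disk.
    return f"weights:{name}"


cache = AdapterLRU(capacity=2, load=fake_load, unload=unloaded.append)
for requested in ["summarization", "translation", "summarization", "code_gen"]:
    cache.acquire(requested)

print(cache.resident())  # ['summarization', 'code_gen']
print(unloaded)          # ['translation']
```

The capacity would normally be chosen from the memory left over after the base model is loaded, divided by the size of a typical adapter.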