Running Two vLLM Instances on a Single GPU: A Deep Dive into VRAM Partitioning, KV Cache Compression, and Automated Memory Management
April 7, 2026 · AitherOS Engineering
TL;DR
We run two production vLLM instances — an 8B orchestrator and a 14B reasoning model — simultaneously on a single RTX 5090 (32 GB), with 5.8 GB free for image generation. This shouldn't be possible. Here's how we made it work, the five things vLLM gets wrong about multi-instance GPU sharing, and the three-line fix that saved the project.
The Problem
AitherOS is a local-first AI operating system. Everything runs on your hardware — orchestration, reasoning, embeddings, image generation, code execution — in a Docker Compose stack of ~87 containers. The brain of the system is a dual-model inference pipeline:
| Role | Model | Quantization | Purpose |
|---|---|---|---|
| Orchestrator | Nemotron-8B-AWQ | AWQ 4-bit weights, TQ4 KV cache | Tool routing, planning, conversation |
| Reasoning | DeepSeek-R1-14B-AWQ | AWQ 4-bit weights, TQ4 KV cache | Chain-of-thought, math, logic |
Both run on vLLM 0.19 via our custom TurboQuant-patched fork. The orchestrator handles every user interaction; reasoning is invoked as a "tool" for hard problems. We need both loaded simultaneously because cold-starting a 14B model takes 8–30 seconds — unacceptable for interactive use.
The constraint: one RTX 5090, 32 GB VRAM, no exceptions. The same GPU also needs to serve embeddings, and occasionally run FLUX for image generation (~12 GB).
Part 1: TurboQuant — 3.8× KV Cache Compression
The first innovation that makes any of this possible is TurboQuant (TQ4), our KV cache compression engine. Standard vLLM stores key-value cache in FP16, consuming 256 bytes per token per layer per head. At 48 layers and 8 KV heads (DeepSeek-R1-14B), a single 8192-token sequence eats 768 MB of KV cache.
TurboQuant uses vector quantization to compress KV pairs to 4-bit precision, achieving 3.8× compression with only +2.71% perplexity impact:
| KV Cache Format | Keys | Values | Compression | PPL Impact |
|---|---|---|---|---|
| FP16 (default) | 16-bit | 16-bit | 1× | baseline |
| FP8 | 8-bit | 8-bit | 2× | +0.5% |
| tq-t4nc | 4-bit | 4-bit | 3.8× | +2.71% |
| tq-k3v4nc | 3-bit | 4-bit | ~4.3× | +10.6% |
The implementation lives in aither-kvcache, a vLLM plugin that registers a custom attention backend:
```python
# plugin.py — the entire integration is one plugin registration
register_backend(
    AttentionBackendEnum.CUSTOM,
    "aither_kvcache.vllm.backend.TurboQuantBackend",
)
```
From the user's perspective, it's a single flag:
```shell
vllm serve your-model --kv-cache-dtype tq-t4nc
```
The packed format stores 128-dimensional head vectors in 64 bytes (vs 256 FP16) plus 4 bytes of per-vector norms, totaling 68 bytes per KV pair — a hair under 3.8× compression. The encode/decode cycle:
```
encode: FP16 → vector quantize → (packed_4bit, norms)
decode: (packed_4bit, norms) → approximate FP16 reconstruction
```
Crucially, this is calibration-free. Unlike GPTQ/AWQ weight quantization that requires calibration datasets, TurboQuant's vector quantization is applied at runtime to whatever KV pairs the model produces. No per-model profiling, no offline quantization step — just flip the flag.
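The encode/decode cycle can be sketched in a few lines. This is a deliberately simplified stand-in, symmetric scalar quantization with one per-vector scale rather than TurboQuant's actual vector-quantization codebooks, but it reproduces the byte layout: 128 dims become 64 bytes of packed 4-bit codes plus a 4-byte per-vector norm.

```python
import numpy as np

def tq4_encode(vec: np.ndarray) -> tuple[bytes, float]:
    """Quantize one 128-dim head vector to packed 4-bit codes + a scale.

    Illustrative only: symmetric scalar quantization to [-8, 7],
    two codes per byte. 128 dims -> 64 bytes + 4-byte scale = 68 bytes.
    """
    scale = float(np.abs(vec).max()) / 7.0 or 1.0          # avoid div-by-zero
    q = np.clip(np.round(vec / scale), -8, 7).astype(np.int16) + 8  # -> [0, 15]
    q = q.astype(np.uint8)
    packed = (q[0::2] << 4) | q[1::2]                      # two nibbles per byte
    return packed.tobytes(), scale

def tq4_decode(packed: bytes, scale: float) -> np.ndarray:
    """Approximate reconstruction from packed 4-bit codes."""
    codes = np.frombuffer(packed, dtype=np.uint8)
    q = np.empty(codes.size * 2, dtype=np.uint8)
    q[0::2] = codes >> 4
    q[1::2] = codes & 0x0F
    return (q.astype(np.float32) - 8.0) * scale
```

The round-trip error is bounded by half the quantization step, which is where the small perplexity impact in the table comes from.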
Why This Matters for Dual-Model
Without TQ4, the orchestrator's KV cache at 40K context would need ~6.1 GB (FP16). With TQ4: ~1.6 GB. That's 4.5 GB saved — enough to fit the entire DeepSeek model.
Part 2: The VRAM Budget Problem (And Why vLLM's gpu_memory_utilization Is Broken for Multi-Instance)
vLLM has a parameter called gpu_memory_utilization (default 0.90) that controls how much GPU memory the server claims. At startup, vLLM:
- Snapshots total GPU memory: 31.84 GiB
- Computes requested memory: total × utilization
- Check 1 (startup): asserts free_memory >= requested_memory
- Loads model weights into GPU
- Profiles CUDA graph memory (if applicable)
- Check 2 (KV budget): computes requested_memory - non_kv_memory = available_kv
The problem: both checks assume single-process GPU ownership.
When you run two vLLM instances on the same GPU, instance B sees instance A's VRAM as "used by other processes." Here's the math that broke everything:
Attempt 1: Naive Split (0.35 + 0.35)
```
Orchestrator (0.35): 31.84 × 0.35 = 11.14 GiB requested
  → Startup: 31.5 GiB free ≥ 11.14 ✓
  → Loads 6.36 GiB weights
  → CUDA graphs: ~14 GiB (piecewise, max_num_seqs=16)
  → Total actual usage: ~23 GiB

DeepSeek (0.35): 31.84 × 0.35 = 11.14 GiB requested
  → Startup: 8.8 GiB free ≥ 11.14? ✗ FAIL
```
Check 1 fails — there isn't 11.14 GiB free because the orchestrator already claimed 23 GiB.
Attempt 2: High Utilization (0.35 + 0.90)
```
DeepSeek (0.90): 31.84 × 0.90 = 28.66 GiB requested
  → Startup: 8.8 GiB free ≥ 28.66? ✗ FAIL even harder
```
Higher utilization makes the startup check worse because the requested amount is computed against total GPU, not free GPU.
Attempt 3: Even Higher (0.35 + 0.98)
```
DeepSeek (0.98): 31.84 × 0.98 = 31.20 GiB requested
  → Startup: 8.8 GiB free ≥ 31.20? ✗ Catastrophic fail
```
Attempt 4: Low Utilization That Passes Startup (0.35 + 0.50)
```
DeepSeek (0.50): 31.84 × 0.50 = 15.92 GiB requested
  → Startup: 16.68 GiB free ≥ 15.92? ✓ (barely)
  → Loads 9.4 GiB weights
  → KV budget: 15.92 - 9.4 - overhead = -0.75 GiB? ✗ Check 2 fails
```
The gpu_memory_utilization parameter creates two non-overlapping checks that cannot simultaneously pass when another process owns significant VRAM. There is no valid value for this parameter in a multi-instance scenario.
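The dead end can be shown mechanically. The sketch below models the two checks with an assumed ~2 GiB of non-KV overhead (activation buffers, allocator slack; a made-up figure for illustration) and sweeps every utilization value for the second instance:

```python
def multi_instance_feasible(total_gib, free_gib, weights_gib, overhead_gib, util):
    """Simplified model of vLLM's two memory checks.

    Check 1 (startup):   free_memory >= total * utilization
    Check 2 (KV budget): total * utilization - weights - overhead > 0
    overhead_gib is an assumed ~2 GiB of non-KV memory (illustrative).
    """
    requested = total_gib * util
    return free_gib >= requested and requested - weights_gib - overhead_gib > 0

# DeepSeek alone on an empty GPU: a wide band of utilizations works.
alone = [u / 100 for u in range(1, 100)
         if multi_instance_feasible(31.84, 31.5, 9.4, 2.0, u / 100)]

# DeepSeek with the orchestrator already holding ~23 GiB (8.8 GiB free):
# Check 1 needs util <= 8.8/31.84 ≈ 0.28, Check 2 needs util > 11.4/31.84 ≈ 0.36.
shared = [u / 100 for u in range(1, 100)
          if multi_instance_feasible(31.84, 8.8, 9.4, 2.0, u / 100)]
```

`shared` comes back empty: the two constraints define disjoint intervals, so no value of gpu_memory_utilization can satisfy both once another process owns the VRAM.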
Part 3: The Fix — --kv-cache-memory-bytes
Deep in vLLM's gpu_worker.py, there's an alternative code path:
```python
def determine_available_memory(self) -> int:
    if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
        # Skip memory profiling entirely
        self.model_runner.profile_run()
        logger.info(f"Reserved {kv_cache_memory_bytes} for KV Cache, "
                    "skipped memory profiling.")
        return kv_cache_memory_bytes
    # ... normal profiling path that fails for multi-instance
```
The --kv-cache-memory-bytes flag bypasses vLLM's entire memory budget system. It directly allocates a fixed amount for KV cache without checking utilization, profiling CUDA graphs, or computing budgets. The startup check still runs against gpu_memory_utilization, but we set that low enough to pass:
```
# The magic configuration for instance B:
VLLM_DS_GPU_UTIL: 0.50               # Just passes startup check (free > total*0.50)
--kv-cache-memory-bytes 1073741824   # 1 GiB KV cache, directly allocated
--enforce-eager                      # No CUDA graphs (saves ~2-3 GiB)
```
Final Working Configuration
| Parameter | Orchestrator (8199) | DeepSeek (8176) |
|---|---|---|
| Model | Nemotron-8B-AWQ (4-bit) | DeepSeek-R1-14B-AWQ (4-bit) |
| gpu_memory_utilization | 0.35 | 0.50 (startup check only) |
| kv-cache-memory-bytes | — (auto) | 1 GiB (direct) |
| KV cache dtype | tq-t4nc (3.8× compression) | tq-t4nc (3.8× compression) |
| CUDA graphs | Piecewise [1,2,4] | Enforce-eager |
| max_num_seqs | 4 | 4 |
| max_model_len | 40,960 | 8,192 |
| LoRA | 2 adapters | — |
| Tool calling | Hermes parser | — |
| VRAM used | ~20.7 GiB | ~5.6 GiB |
Total: 26.3 GiB used, 5.8 GiB free — enough for Windows compositor, embeddings, and burst image gen.
Part 4: CUDA Graph Memory — The Silent VRAM Killer
During debugging, we discovered that piecewise CUDA graphs are the single largest VRAM consumer — larger than the model weights themselves.
CUDA graphs pre-capture GPU kernel execution paths for fixed batch sizes. vLLM's piecewise mode captures graphs for every power-of-two batch size up to max_num_seqs. With max_num_seqs=16, that's captures for [1, 2, 4, 8, 16] — five graph sets, each containing compiled execution paths for the entire model forward pass.
Our measurements on the orchestrator (Nemotron-8B-AWQ):
| max_num_seqs | Graph Captures | VRAM for Graphs | Total VRAM |
|---|---|---|---|
| 16 | [1,2,4,8,16] | ~14.4 GiB | 23.4 GiB |
| 8 | [1,2,4,8] | ~11.2 GiB | 20.2 GiB |
| 4 | [1,2,4] | ~8.0 GiB | 20.7 GiB |
| 1 (enforce-eager) | none | 0 GiB | 9.4 GiB |
Reducing from 16 to 4 max_num_seqs saved 2.7 GiB — the difference between DeepSeek fitting and OOMing.
For the reasoning model (DeepSeek-R1-14B), we use --enforce-eager entirely. Reasoning requests come in short bursts (1-4 concurrent), so CUDA graph speedup is negligible, but the VRAM savings are enormous: 9.4 GiB vs 12+ GiB with graphs.
The lesson: max_num_seqs isn't just a concurrency parameter — it's a VRAM multiplier through CUDA graphs. For memory-constrained deployments, this is the first knob to turn.
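The capture schedule implied by the table is easy to model. This is a simplified sketch of the power-of-two rule described above, not vLLM's actual (more configurable) capture-size logic:

```python
def cudagraph_capture_sizes(max_num_seqs: int) -> list[int]:
    """Power-of-two batch sizes that piecewise mode captures, up to max_num_seqs.

    Each entry costs a full set of frozen execution graphs, so trimming
    max_num_seqs directly trims graph-capture VRAM.
    """
    sizes, n = [], 1
    while n <= max_num_seqs:
        sizes.append(n)
        n *= 2
    return sizes
```

With max_num_seqs=16 that yields five capture sets; with 4 it yields three, which is the knob we turned.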
Part 5: Graph-Aware KV Cache Eviction
Standard vLLM uses LRU eviction for KV cache — when blocks are full, the least-recently-used block gets evicted. This is fine for a chatbot, catastrophic for an agent.
In AitherOS, the orchestrator's context contains:
- System prompt — agent personality, capabilities, tool definitions (~2K tokens)
- Tool schemas — JSON schemas for 50+ tools (~4K tokens)
- Conversation history — user messages, assistant responses
- Tool results — structured data from executed tools
LRU eviction would happily evict the system prompt (set once, "old") to make room for new tool results. The model then forgets how to use tools.
Our eviction_plugin.py monkey-patches vLLM's FreeKVCacheBlockQueue.popleft() with importance-weighted scoring:
```python
# Block importance scores by source type
IMPORTANCE = {
    "system": 0.95,     # Almost never evict
    "tools": 0.85,      # Tool schemas are critical
    "user": 0.60,       # Recent user context
    "assistant": 0.30,  # Can be regenerated
}
```
The eviction policy also tracks co-attendance — which KV blocks are accessed together during attention. If block A and block B are frequently co-attended, evicting one makes the other useless, so they're evicted as a pair or not at all.
Protected sources (system and tools) are pinned — they never enter the eviction queue regardless of memory pressure.
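Putting the importance scores, co-attendance pairing, and pinning together, a toy version of the victim-selection logic might look like this. The scoring blend and its weights are illustrative, not the plugin's actual formula:

```python
from dataclasses import dataclass, field

IMPORTANCE = {"system": 0.95, "tools": 0.85, "user": 0.60, "assistant": 0.30}
PINNED = {"system", "tools"}

@dataclass
class Block:
    block_id: int
    source: str                 # which context region produced this block
    last_access: int = 0        # logical clock of last attention access
    partners: set = field(default_factory=set)  # co-attended block ids

def pick_victims(blocks, now):
    """Importance-weighted eviction sketch: lowest score goes first.

    Pinned sources never enter the candidate set, and a victim drags its
    co-attended partners along so neither survives uselessly alone.
    """
    candidates = [b for b in blocks if b.source not in PINNED]
    if not candidates:
        return []
    def score(b):
        recency = (now - b.last_access) / max(now, 1)   # 0 = hot, 1 = cold
        return IMPORTANCE[b.source] - 0.5 * recency
    victim = min(candidates, key=score)
    group = {victim.block_id} | victim.partners
    return [b for b in candidates if b.block_id in group]
```

A cold assistant block and its co-attended partner get evicted together, while system and tool blocks are never even considered.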
Part 6: Automated VRAM Zone Management
The hardest part isn't running two models — it's managing the GPU when a third workload appears. AitherOS supports image generation via FLUX (12 GB VRAM) and ComfyUI. When a user requests an image, the system must:
- Stop the reasoning container to free ~5.6 GiB
- Load and run the image pipeline
- Restart reasoning when the next reasoning request arrives
This is handled by two cooperating systems:
MicroScheduler VRAM Zones
The MicroScheduler (our central inference gateway, ~11K lines of Python) maintains VRAM zones:
```python
async def _stop_vllm_reasoning(self):
    """Docker-stop the reasoning container to free VRAM for image gen."""
    container = self._resolve_local_docker_container(
        os.environ["AITHER_VLLM_SWAP_URL"]
    )
    async with httpx.AsyncClient() as client:
        await client.post(
            f"http://aitheros-docker-proxy:2375/containers/{container}/stop?t=10"
        )

async def _ensure_vllm_running_lazy(self):
    """Cold-start reasoning on first request after image gen."""
    # Docker-start + health poll + request queuing
```
VRAMYield — External App VRAM Reclamation
When an external app (like a game) needs VRAM, AitherVRAMYield (port 8155) coordinates a full yield:
```python
async def _free_vllm(self):
    """Docker-stop reasoning container via the Docker proxy."""
    container = os.environ.get(
        "AITHER_VLLM_REASONING_CONTAINER", "aither-vllm-deepseek"
    )
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            f"http://aitheros-docker-proxy:2375/containers/{container}/stop?t=10"
        )
        if resp.status_code in (204, 304):
            self._status.vllm_freed = True
            self._status.freed_mb += 11_200  # ~11.2 GB estimated

async def _restart_vllm(self):
    """Docker-start reasoning container back up."""
    # POST /containers/{container}/start → 204/304
```
The key architectural decision: container stop/start instead of model load/unload. Raw vLLM has no /unload endpoint, and container lifecycle is already managed by Docker Compose with health checks. A container restart takes ~8 seconds (model load from disk cache), which is acceptable for an asynchronous operation.
Cloud Fallback During VRAM Yield
While the reasoning container is stopped, requests don't fail — they fall through to the cloud fallback chain:
```yaml
# cloud_providers.yaml — reasoning fallback chain
reasoning:
  - vllm_swap          # Local DeepSeek (if running)
  - vast_serverless    # Vast.ai serverless (scale-to-zero)
  - deepseek           # DeepSeek API
  - anthropic          # Claude (last resort)
```
The Vast.ai Serverless bridge (VastServerlessBridge.py) provides scale-to-zero reasoning at ~$0.20/M tokens — an order of magnitude cheaper than direct API calls, while maintaining the same model (DeepSeek-R1).
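A minimal sketch of how such a chain dispatches, assuming each provider is a callable that either returns a completion or raises. The provider names mirror the YAML above; everything else is illustrative:

```python
REASONING_CHAIN = ["vllm_swap", "vast_serverless", "deepseek", "anthropic"]

def route_reasoning(prompt, providers):
    """Walk the fallback chain until one provider answers.

    `providers` maps name -> callable; a stopped local instance raises
    (e.g. ConnectionError) and the request falls through to the next hop.
    Retries and backoff are omitted from this sketch.
    """
    errors = {}
    for name in REASONING_CHAIN:
        try:
            return name, providers[name](prompt)
        except Exception as exc:       # provider down or over quota
            errors[name] = exc
    raise RuntimeError(f"all reasoning providers failed: {list(errors)}")
```

While the reasoning container is docker-stopped, the first hop raises and the request lands on Vast.ai without the caller noticing.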
Part 7: The Docker Proxy Pattern
A subtle but critical piece of infrastructure: containers can't talk to the Docker daemon directly (no socket mount for security). Instead, we run a Docker API proxy (aitheros-docker-proxy:2375) that exposes a filtered subset of the Docker Engine API to the internal network.
MicroScheduler and VRAMYield call this proxy to stop/start/inspect containers:
```
MicroScheduler → http://aitheros-docker-proxy:2375/containers/aither-vllm-deepseek/stop
               → Docker Engine → container stops → VRAM freed
```
This avoids mounting /var/run/docker.sock into application containers (a security anti-pattern) while still enabling container lifecycle management from within the Docker network.
Part 8: GPU Profiles
The VRAM budget is defined in gpu_profiles.yaml, allowing different presets for different workloads:
```yaml
profiles:
  dev:
    vram_strategy: shared
    gpu_total_gb: 32
    target_utilization_pct: 76   # 40% orch + 35% reasoning + 1% embed
    models:
      orchestrator:
        vram_pct: 40
        cuda_graphs: true
      reasoning:
        enabled: true
        container: aither-vllm-deepseek
        vram_pct: 35
        sleep_mode: true         # VRAM yield enabled
  production:
    target_utilization_pct: 92   # All VRAM to orchestrator
    models:
      orchestrator:
        vram_pct: 92
      reasoning:
        enabled: false           # Cloud-only reasoning
  creative:
    target_utilization_pct: 90   # Orch 35% + ComfyUI 50% + embed 5%
```
Switching profiles is a one-command operation that adjusts container environment variables and restarts the affected services.
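A sketch of what that switch amounts to: translating a preset into the environment variables the containers read. Apart from VLLM_DS_GPU_UTIL, which appears earlier in the post, the variable names and the inline profile data here are invented for illustration:

```python
# Condensed stand-in for gpu_profiles.yaml (values from the YAML above)
PROFILES = {
    "dev":        {"orchestrator_pct": 40, "reasoning_enabled": True,  "reasoning_pct": 35},
    "production": {"orchestrator_pct": 92, "reasoning_enabled": False, "reasoning_pct": 0},
}

def profile_env(profile_name: str) -> dict[str, str]:
    """Translate a profile preset into per-container env vars (sketch)."""
    p = PROFILES[profile_name]
    env = {
        "VLLM_ORCH_GPU_UTIL": f"{p['orchestrator_pct'] / 100:.2f}",
        "REASONING_ENABLED": "1" if p["reasoning_enabled"] else "0",
    }
    if p["reasoning_enabled"]:
        env["VLLM_DS_GPU_UTIL"] = f"{p['reasoning_pct'] / 100:.2f}"
    return env
```

Restarting the affected services with the new env vars is then a plain `docker compose up -d` away.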
Performance Results
Dual-Model Throughput
| Metric | Orchestrator | DeepSeek Reasoning |
|---|---|---|
| Time-to-first-token | ~45ms | ~120ms |
| Tokens/sec (decode) | ~42 tok/s | ~28 tok/s |
| Max concurrent seqs | 4 | 4 |
| Max context length | 40,960 | 8,192 |
| Cold start time | ~45s (CUDA graphs) | ~8s (enforce-eager) |
VRAM Efficiency
| Component | VRAM (GiB) | % of 32 GB |
|---|---|---|
| Orchestrator weights | 6.4 | 20% |
| Orchestrator KV cache (TQ4) | ~1.6 | 5% |
| Orchestrator CUDA graphs | ~8.0 | 25% |
| Orchestrator overhead | ~4.7 | 15% |
| DeepSeek weights | 9.4 | 29% |
| DeepSeek KV cache (1 GiB) | 1.0 | 3% |
| Embeddings | ~0.4 | 1% |
| Free | ~5.8 | 18% |
KV Cache Capacity (TQ4 vs FP16)
| Model | FP16 KV @ 8K ctx | TQ4 KV @ 8K ctx | Savings |
|---|---|---|---|
| Nemotron-8B | 2.0 GiB | 0.53 GiB | 1.47 GiB |
| DeepSeek-R1-14B | 3.0 GiB | 0.79 GiB | 2.21 GiB |
| Combined | 5.0 GiB | 1.32 GiB | 3.68 GiB |
Without TQ4, the models wouldn't fit simultaneously. The 3.68 GiB saved by KV cache compression is the difference between a dual-model system and a single-model system with cloud fallback.
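These capacities follow directly from the per-vector sizes quoted in Part 1 (256 bytes FP16, 68 bytes TQ4) at max_num_seqs=4. A quick check for the DeepSeek row:

```python
GIB = 2**30

def kv_cache_gib(tokens, layers, kv_heads, bytes_per_pair, num_seqs=1):
    """KV cache footprint: tokens × layers × kv_heads × bytes per KV pair.

    bytes_per_pair is 256 for FP16 and 68 for TQ4 (64 B packed codes
    + 4 B norms), matching the sizes given earlier in the post.
    """
    return tokens * layers * kv_heads * bytes_per_pair * num_seqs / GIB

# DeepSeek-R1-14B: 48 layers, 8 KV heads, 8K context, 4 concurrent sequences
fp16 = kv_cache_gib(8192, 48, 8, 256, num_seqs=4)   # 3.0 GiB
tq4 = kv_cache_gib(8192, 48, 8, 68, num_seqs=4)     # ~0.80 GiB
```

The same formula with a single sequence reproduces the 768 MiB-per-sequence figure from Part 1.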
Lessons Learned
1. gpu_memory_utilization Doesn't Mean What You Think
In a multi-instance setup, gpu_memory_utilization is calculated against total GPU memory, not free GPU memory. There is no valid value for this parameter when significant VRAM is already allocated by another process. Use --kv-cache-memory-bytes for the second instance.
2. CUDA Graphs Are a VRAM Multiplier
max_num_seqs isn't just about concurrency — it determines how many CUDA graph captures are created. Each capture is essentially a frozen copy of the model's execution path. Going from 16 to 4 max sequences cut graph memory from ~14.4 GiB to ~8.0 GiB on an 8B model. For memory-constrained deployments, this is the highest-leverage knob.
3. Enforce-Eager Is Fine for Burst Workloads
CUDA graphs provide significant speedup for sustained throughput (many concurrent sequences, continuous batching). For burst workloads like reasoning (1-4 sequences, intermittent), the throughput difference is negligible but the VRAM savings are enormous. Don't pay the CUDA graph tax on secondary models.
4. Container Stop/Start > Model Load/Unload
We initially tried to build a /unload endpoint for VRAM reclamation. Container-level lifecycle management via Docker is simpler, more reliable, and leverages existing health checks and restart policies. The cold-start penalty (~8s for enforce-eager) is acceptable for an async operation.
5. KV Cache Compression Is the Real Enabler
Weight quantization (AWQ, GPTQ) gets all the attention, but KV cache compression is what enables multi-model GPU sharing. Weights are a fixed cost; KV cache scales with context length and concurrency. At 40K context, TQ4 saves 4.5 GiB on the orchestrator alone — more than the total weight reduction from FP16→AWQ.
The Stack
For reference, here's the complete inference stack running on a single RTX 5090:
```
┌──────────────────────────────────────────────────────┐
│                   RTX 5090 (32 GB)                   │
│                                                      │
│ ┌─────────────────────┐  ┌─────────────────────────┐ │
│ │ Orchestrator        │  │ DeepSeek Reasoning      │ │
│ │ Nemotron-8B-AWQ     │  │ DeepSeek-R1-14B-AWQ     │ │
│ │                     │  │                         │ │
│ │ TQ4 KV cache        │  │ TQ4 KV cache (1 GiB)    │ │
│ │ CUDA graphs [1,2,4] │  │ Enforce-eager           │ │
│ │ 40K context         │  │ 8K context              │ │
│ │ LoRA adapters       │  │ Burst reasoning         │ │
│ │ Tool calling        │  │                         │ │
│ │ Port 8199           │  │ Port 8176               │ │
│ │ ~20.7 GiB           │  │ ~5.6 GiB                │ │
│ └─────────────────────┘  └─────────────────────────┘ │
│                                                      │
│ ┌─────────────┐  ┌─────────────────────────────────┐ │
│ │ Embeddings  │  │ Free VRAM (~5.8 GiB)            │ │
│ │ Port 8209   │  │ → FLUX image gen (on demand)    │ │
│ │ ~0.4 GiB    │  │ → Gaming VRAM yield             │ │
│ └─────────────┘  └─────────────────────────────────┘ │
│                                                      │
│ MicroScheduler: VRAM zones, lazy-wake, cloud fallback│
│ VRAMYield: Docker stop/start via proxy               │
│ Cloud fallback: Vast.ai Serverless → DeepSeek API    │
└──────────────────────────────────────────────────────┘
```
Reproducing This Setup
The key configurations:
```shell
# Instance A (Primary — starts first, owns CUDA graphs)
vllm serve cyankiwi/Nemotron-Orchestrator-8B-AWQ-4bit \
  --kv-cache-dtype tq-t4nc \
  --gpu-memory-utilization 0.35 \
  --max-model-len 40960 \
  --max-num-seqs 4 \
  --compilation-config '{"cudagraph_mode":"piecewise","max_cudagraph_capture_size":4}' \
  --enable-prefix-caching \
  --port 8199

# Instance B (Secondary — starts second, bypasses memory budget)
vllm serve casperhansen/deepseek-r1-distill-qwen-14b-awq \
  --kv-cache-dtype tq-t4nc \
  --gpu-memory-utilization 0.50 \
  --kv-cache-memory-bytes 1073741824 \
  --max-model-len 8192 \
  --max-num-seqs 4 \
  --enforce-eager \
  --enable-prefix-caching \
  --port 8176
```
The critical insight: gpu_memory_utilization on the second instance only needs to pass the startup free-memory check (free >= total × util). Set it as low as possible. The actual KV cache allocation is controlled entirely by --kv-cache-memory-bytes.
What's Next
- TriAttention: Our spectral KV cache compression engine achieves ~10× compression (26 bytes/token vs 256 FP16) by decomposing RoPE attention into trigonometric series and storing only the dominant frequency components. Currently in calibration testing.
- Dynamic KV budget: Instead of a fixed --kv-cache-memory-bytes, dynamically resize the cache based on the current free VRAM reported by the Docker proxy.
- Unified CUDA graph pools: Sharing CUDA graph memory across instances for batch sizes that don't overlap in time.
The TurboQuant vLLM fork, graph-aware eviction plugin, and VRAM zone management are part of AitherOS. The --kv-cache-memory-bytes flag is stock vLLM 0.19+ — the single most important flag in the multi-instance story, buried in a help page nobody reads.