AI Model Compression and Optimization
Deploy AI models efficiently by reducing model size and inference latency. Covers quantization, pruning, knowledge distillation, ONNX runtime, TensorRT optimization, and the trade-offs between model quality and deployment constraints.
Production AI models must run within strict constraints — latency budgets under 100ms, memory limits on edge devices, infrastructure cost targets. Model compression enables deploying state-of-the-art models in real-world environments where raw compute is limited or expensive.
Compression Techniques
| Technique | Size Reduction | Quality Impact | Speedup |
|---|---|---|---|
| Quantization (INT8) | 4x smaller | < 1% accuracy loss | 2-4x faster |
| Pruning (unstructured) | 2-10x smaller | 1-3% accuracy loss | Variable |
| Knowledge distillation | 3-10x smaller | 1-5% accuracy loss | 3-10x faster |
| Weight sharing | 2-4x smaller | < 1% accuracy loss | 1.5-2x faster |
| Low-rank factorization | 2-4x smaller | 1-2% accuracy loss | 2-3x faster |
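Before choosing a technique, it helps to quantify the baseline. A minimal sketch (PyTorch, matching the examples below) that reports parameter count and approximate in-memory size, assuming model is an already-loaded torch.nn.Module:
import torch
def model_footprint(model: torch.nn.Module):
    # Sum parameter and buffer sizes to approximate the in-memory footprint
    n_params = sum(p.numel() for p in model.parameters())
    n_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    n_bytes += sum(b.numel() * b.element_size() for b in model.buffers())
    return n_params, n_bytes / 1e6  # (count, MB)
params, size_mb = model_footprint(model)
print(f"{params / 1e6:.1f}M parameters, {size_mb:.1f} MB")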
Quantization
Post-Training Quantization
import os
import torch
# Load the full-precision model (FP32)
model = torch.load("model_fp32.pt")
# Dynamic quantization (weights only)
quantized_model = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},  # quantize Linear layers
    dtype=torch.qint8
)
# Save the quantized model, then compare file sizes
torch.save(quantized_model, "model_int8.pt")
print(f"FP32 model: {os.path.getsize('model_fp32.pt') / 1e6:.1f} MB")
print(f"INT8 model: {os.path.getsize('model_int8.pt') / 1e6:.1f} MB")
# FP32 model: 438.0 MB
# INT8 model: 112.0 MB (4x smaller)
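Validate accuracy before shipping the quantized model (see the anti-patterns below). A minimal sketch of a top-1 accuracy check, assuming a classification model and an eval_loader that yields (inputs, labels) batches:
def evaluate(model, eval_loader):
    # Top-1 accuracy on a held-out evaluation set
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in eval_loader:
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / total
fp32_acc = evaluate(model, eval_loader)
int8_acc = evaluate(quantized_model, eval_loader)
print(f"FP32: {fp32_acc:.4f}  INT8: {int8_acc:.4f}  delta: {fp32_acc - int8_acc:.4f}")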
Quantization-Aware Training (QAT)
# Simulate quantization during training for better accuracy
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)
# Train with fake-quantization ops inserted into the graph
for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
# Convert the trained model to an actual INT8 quantized model
model.eval()
model_quantized = torch.quantization.convert(model, inplace=False)
Knowledge Distillation
# Teacher: large, accurate model
# Student: small, fast model
# Goal: the student learns to match the teacher's outputs
import torch.nn as nn
import torch.nn.functional as F
class DistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
    def forward(self, student_logits, teacher_logits, targets):
        # Soft targets: learn from the teacher's probability distribution
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1),
            reduction='batchmean'
        ) * (self.temperature ** 2)
        # Hard targets: still learn from the ground-truth labels
        hard_loss = F.cross_entropy(student_logits, targets)
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss
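A minimal sketch of how this loss is used in a distillation training step; teacher, student, optimizer, and dataloader are assumed to already exist, and the teacher runs under no_grad so only the student is updated:
distill_loss = DistillationLoss(temperature=4.0, alpha=0.7)
teacher.eval()    # the teacher is frozen; it only provides soft targets
student.train()
for inputs, targets in dataloader:
    with torch.no_grad():
        teacher_logits = teacher(inputs)
    student_logits = student(inputs)
    loss = distill_loss(student_logits, teacher_logits, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()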
ONNX Runtime Optimization
import onnxruntime as ort
# Export the PyTorch model to ONNX
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)  # example input shape; adjust to your model
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    opset_version=17,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}}  # allow variable batch size
)
# Optimize with ONNX Runtime graph optimizations
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session_options.optimized_model_filepath = "model_optimized.onnx"
session = ort.InferenceSession(
    "model.onnx",
    session_options,
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)
# Run inference (input_data is a NumPy array keyed by the exported input name)
result = session.run(None, {"input": input_data})
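Before relying on the exported graph, confirm it matches the original model numerically; a minimal sketch, reusing dummy_input from the export step above:
import numpy as np
with torch.no_grad():
    torch_out = model(dummy_input).numpy()
onnx_out = session.run(None, {"input": dummy_input.numpy()})[0]
# Small numerical differences are expected; fail loudly if they exceed tolerance
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-5)
print("PyTorch and ONNX Runtime outputs match within tolerance")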
Deployment Decision Matrix
Target: Cloud GPU (A100)
→ FP16 inference (minimal optimization needed)
→ Batch processing for throughput
Target: Cloud CPU (standard instances)
→ INT8 quantization + ONNX Runtime
→ Latency: 10-50ms per inference
Target: Edge device (mobile, IoT)
→ Knowledge distillation (small student model)
→ INT8 quantization + TensorFlow Lite / Core ML
→ Latency: 5-20ms, model size < 50MB
Target: Browser (WebAssembly)
→ Tiny model + ONNX Runtime Web / TF.js
→ Model size < 10MB
→ Latency: 20-100ms
Anti-Patterns
| Anti-Pattern | Consequence | Fix |
|---|---|---|
| Deploy FP32 in production | 4x memory waste, 2-4x slower | Quantize to INT8 or FP16 |
| Quantize without validation | Silent accuracy degradation | Benchmark on evaluation set before deploying |
| One model for all form factors | Over-powered for simple cases | Right-size model per deployment target |
| Skip profiling | Optimizing wrong bottleneck | Profile before optimizing (GPU util, memory, latency); see the latency sketch after this table |
| No A/B testing compressed model | Cannot measure real-world impact | Shadow or canary test compressed model |
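The "Skip profiling" fix calls for measuring latency before optimizing anything. A minimal wall-clock latency sketch (CPU timing with warm-up runs; GPU measurements would also need torch.cuda.synchronize()), assuming model and a representative sample_input:
import time
def measure_latency(model, sample_input, warmup=10, runs=100):
    # Median latency in milliseconds over repeated runs
    model.eval()
    with torch.no_grad():
        for _ in range(warmup):   # warm-up: caches, allocator, lazy init
            model(sample_input)
        times = []
        for _ in range(runs):
            start = time.perf_counter()
            model(sample_input)
            times.append((time.perf_counter() - start) * 1000)
    return sorted(times)[len(times) // 2]
print(f"Median latency: {measure_latency(model, sample_input):.1f} ms")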
Model compression is the bridge between research accuracy and production constraints. The best model is not the most accurate — it is the most accurate model that meets your latency, cost, and resource requirements.