Chapter 44 — AI Hardware Acceleration & On-Device AI
Overview
Optimize performance and cost with accelerators and compilers; enable private, low-latency on-device AI.
Modern AI workloads demand specialized hardware and optimized runtimes to achieve acceptable performance and economics. This chapter covers hardware acceleration strategies from cloud GPUs/TPUs to on-device inference, compiler optimizations, and the architectural tradeoffs that shape deployment decisions.
Topics
- GPU/TPU: memory, kernel fusion, mixed precision, batching.
- Compilers: TensorRT, XLA, ONNX Runtime, quantization and pruning.
- On-device: model size limits, privacy, offline-first design.
Deliverables
- Performance benchmarks, deployment profiles, and cost model.
Why It Matters
The right acceleration strategy can reduce cost per request, improve latency, and enable private on-device experiences—often unlocking use cases that are infeasible on CPU or in the cloud.
Business Impact:
- Cost Optimization: A well-tuned GPU deployment can be 10-100x cheaper per inference than CPU (see the cost sketch after this list)
- Latency: On-device models eliminate network round-trips, achieving <100ms response times
- Privacy: Local inference keeps sensitive data on-device, meeting regulatory requirements
- Scale: Proper batching and quantization enable serving millions of requests with manageable infrastructure
- ROI: Hardware acceleration investments typically pay back in 6-12 months through reduced cloud costs
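To make the cost claims above concrete, serving cost is usually estimated from instance price and sustained throughput. A back-of-envelope sketch; the hourly prices and throughput numbers below are hypothetical, not vendor quotes:

# Minimal cost-model sketch (illustrative numbers only):
# cost per 1k requests = hourly instance price / (throughput * 3600) * 1000

def cost_per_1k_requests(hourly_price_usd: float, throughput_rps: float) -> float:
    """Back-of-envelope serving cost, assuming the instance is fully utilized."""
    requests_per_hour = throughput_rps * 3600
    return hourly_price_usd / requests_per_hour * 1000

# Example: a hypothetical $4/hr GPU at 890 req/s vs a $0.40/hr CPU at 5 req/s
print(cost_per_1k_requests(4.00, 890))   # ~$0.001 per 1k requests
print(cost_per_1k_requests(0.40, 5))     # ~$0.022 per 1k requests

The GPU wins on cost per request despite the higher hourly price only if you can keep it busy; at low utilization the comparison flips, which is why the per-request framing matters.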
Hardware Architecture Comparison
GPU vs TPU vs CPU: When to Use What
| Hardware | Best For | Strengths | Limitations | Cost Profile |
|---|---|---|---|---|
| CPU | Low-volume inference, batch jobs | Universal availability, flexible | Slow for large models | Low upfront, high per-inference |
| GPU (NVIDIA A100/H100) | Training, high-throughput serving | Mature ecosystem, versatile | Power consumption, cost | Medium-high |
| TPU (v4/v5) | Large-scale training, batch inference | Training efficiency, custom ops | Vendor lock-in, limited flexibility | Competitive at scale |
| Edge NPU/DSP | On-device mobile/IoT | Power efficient, always available | Model size limits, quantization required | Negligible per-inference |
| AWS Inferentia/Trainium | Cloud inference at scale | Cost-optimized for inference | Limited model support | Low at scale |
Hardware Specification Deep-Dive
graph TB
  subgraph "Cloud Acceleration Stack"
    A[Model Input] --> B{Batch Size}
    B --> C[GPU Memory]
    C --> D[Tensor Cores]
    D --> E[Kernel Fusion]
    E --> F[Mixed Precision FP16/BF16]
    F --> G[Output]
  end
  subgraph "On-Device Stack"
    H[Model Input] --> I[Quantization INT8/INT4]
    I --> J[NPU/DSP]
    J --> K[Model Cache]
    K --> L[Output]
  end
  style D fill:#90EE90
  style J fill:#FFB6C1
Hardware & Runtime Deep-Dive
GPU/TPU Optimization Strategies
Memory Bandwidth Optimization
Modern accelerators are often memory-bandwidth bound rather than compute-bound. Understanding data movement is critical:
# Example: Memory-efficient attention with Flash Attention
import math
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

def standard_attention(Q, K, V):
    # Memory: O(N²) - materializes the full attention matrix
    d_k = Q.size(-1)
    attention_scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    attention_probs = F.softmax(attention_scores, dim=-1)
    return attention_probs @ V

def flash_attention(Q, K, V):
    # Memory: O(N) - tiled computation, no materialized attention matrix
    # 2-4x faster, enables longer sequences
    return flash_attn_func(Q, K, V)

# Performance comparison
# Standard: 512 tokens = 2GB memory, 45ms
# Flash: 512 tokens = 0.5GB memory, 18ms
# Standard: 2048 tokens = OOM
# Flash: 2048 tokens = 2GB memory, 120ms
Tensor Cores and Mixed Precision
NVIDIA Tensor Cores provide 8-16x speedup for specific operations when using FP16/BF16:
import torch
from torch.cuda.amp import autocast, GradScaler

# Automatic Mixed Precision (AMP) training
model = YourModel().cuda()   # placeholder for your own nn.Module
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler()

# dataloader and criterion are defined elsewhere
for batch, targets in dataloader:
    optimizer.zero_grad()
    # Automatically use FP16 where safe
    with autocast():
        output = model(batch)
        loss = criterion(output, targets)
    # Scale loss to prevent underflow
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

# Results: 2-3x speedup, same accuracy, 40% memory reduction
Kernel Fusion and Operator Optimization
# Unfused operations - multiple kernel launches
import math
import torch

def unfused_gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

# Fused kernel - single GPU operation
import triton
import triton.language as tl

@triton.jit
def fused_gelu_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # Compute GELU in a single kernel (1.41421356237 = sqrt(2))
    output = x * 0.5 * (1.0 + tl.erf(x / 1.41421356237))
    tl.store(output_ptr + offsets, output, mask=mask)

# Performance: Unfused=2.1ms, Fused=0.7ms (3x faster)
Batch Sizing Strategy
| Batch Size | Throughput (req/sec) | Latency p50 | Latency p95 | Memory Usage | Cost per 1k requests |
|---|---|---|---|---|---|
| 1 | 45 | 22ms | 28ms | 4GB | $2.40 |
| 8 | 280 | 28ms | 35ms | 8GB | $0.38 |
| 32 | 890 | 36ms | 48ms | 16GB | $0.12 |
| 128 | 2100 | 61ms | 89ms | 32GB | $0.05 |
| 256 | 2400 | 107ms | 152ms | 40GB | $0.04 |
Key Insight: Optimal batch size balances throughput and latency. For user-facing apps, batch=8-32 is often ideal; for offline jobs, maximize batch size.
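For user-facing services, a common way to stay near the batch=8-32 sweet spot without hard-coding a batch size is dynamic batching: queue incoming requests and flush on whichever comes first, a size cap or a small time budget. A minimal single-threaded sketch (the size cap and wait budget are illustrative; production serving frameworks implement this server-side):

import time
from collections import deque

class DynamicBatcher:
    """Flush a batch when it reaches max_batch_size or max_wait_ms, whichever comes first."""
    def __init__(self, run_batch, max_batch_size=32, max_wait_ms=5):
        self.run_batch = run_batch            # callable: list of requests -> list of results
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.queue = deque()

    def submit(self, request):
        self.queue.append((request, time.monotonic()))

    def maybe_flush(self):
        if not self.queue:
            return []
        oldest_age_ms = (time.monotonic() - self.queue[0][1]) * 1000
        if len(self.queue) >= self.max_batch_size or oldest_age_ms >= self.max_wait_ms:
            take = min(len(self.queue), self.max_batch_size)
            batch = [self.queue.popleft()[0] for _ in range(take)]
            return self.run_batch(batch)
        return []

The wait budget bounds the latency penalty of batching (here at most 5ms added to p50), while the size cap bounds memory use.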
Quantization: Accuracy vs Performance Tradeoffs
Quantization reduces model precision from FP32/FP16 to INT8/INT4, dramatically improving speed and memory:
Quantization Methods Comparison
| Method | Description | Accuracy Impact | Speed Gain | Memory Reduction |
|---|---|---|---|---|
| Post-Training Dynamic | Quantize weights, compute activations in FP32 | Minimal (<1%) | 2-3x | 4x |
| Post-Training Static | Quantize weights + activations, requires calibration | Low (1-2%) | 3-4x | 4x |
| Quantization-Aware Training | Train with quantization simulation | Negligible | 3-4x | 4x |
| 4-bit Quantization (GPTQ/AWQ) | Extreme compression for LLMs | Low-Medium (2-5%) | 4-6x | 8x |
| 1-bit (BitNet) | Binary weights | High (experimental) | 10x+ | 32x |
Practical Quantization Example
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

# Original model: 13B parameters × 2 bytes (FP16) = 26GB
model_id = "meta-llama/Llama-2-13b-hf"

# Method 1: 8-bit quantization with bitsandbytes
from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
)
model_8bit = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto"
)
# Result: 13GB memory, 2.5x faster, <1% accuracy drop

# Method 2: 4-bit GPTQ quantization
quantize_config = BaseQuantizeConfig(
    bits=4,
    group_size=128,
    desc_act=False,
)
model_4bit = AutoGPTQForCausalLM.from_pretrained(
    model_id,
    quantize_config=quantize_config
)
# Result: 7GB memory, 4x faster, 2-3% accuracy drop

# Accuracy comparison on MMLU benchmark
# FP16: 54.2% accuracy
# 8-bit: 54.0% accuracy (-0.2%)
# 4-bit GPTQ: 52.8% accuracy (-1.4%)
# 4-bit AWQ: 53.5% accuracy (-0.7%)
Outlier Handling in Quantization
# Challenge: Outlier activations destroy quantization quality
# Solution: mixed precision + outlier extraction (LLM.int8()-style)
import torch
import torch.nn as nn

class MixedPrecisionLinear(nn.Module):
    def __init__(self, in_features, out_features, outlier_threshold=6.0):
        super().__init__()
        # INT8 weights live in a buffer (integer tensors cannot require gradients)
        self.register_buffer("weight_int8", torch.zeros(out_features, in_features, dtype=torch.int8))
        self.scale = nn.Parameter(torch.ones(out_features))
        self.outlier_cols = []        # column indices kept in FP16
        self.outlier_weights = None   # FP16 weights for those columns

    def forward(self, x):
        # Outlier dimensions stay in FP16
        outlier_output = torch.matmul(x[:, self.outlier_cols], self.outlier_weights)
        # INT8 matmul (dequantized here for clarity) for the remaining dimensions
        regular_output = torch.matmul(x, self.weight_int8.float().t()) * self.scale
        return regular_output + outlier_output

# Maintains 99%+ accuracy while achieving 3-4x speedup
Compiler Optimization
Modern AI compilers perform graph-level optimizations that are infeasible to implement manually:
Compiler Stack Comparison
graph LR
  A[Model Definition] --> B{Framework}
  B -->|PyTorch| C[TorchScript]
  B -->|TensorFlow| D[XLA]
  B -->|ONNX| E[ONNX Runtime]
  C --> F[TensorRT]
  D --> F
  E --> F
  F --> G[Optimized Inference]
  C --> H[TVM]
  D --> H
  E --> H
  H --> I[Multi-Backend Deployment]
  style F fill:#90EE90
  style H fill:#87CEEB
| Compiler | Best For | Key Optimizations | Limitations |
|---|---|---|---|
| TensorRT | NVIDIA GPU inference | Layer fusion, precision calibration, kernel auto-tuning | NVIDIA-only, complex setup |
| XLA | TPU, GPU training/inference | Operator fusion, memory planning, compilation cache | Limited op coverage, cold-start |
| ONNX Runtime | Cross-platform inference | Execution providers, graph optimization | Variable optimization quality |
| TVM | Diverse hardware targets | Auto-tuning, custom backends | Steep learning curve, immature |
| OpenVINO | Intel CPU/GPU/VPU | Model optimization, quantization | Intel ecosystem only |
TensorRT Optimization Example
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

# Convert an ONNX-exported PyTorch model to a TensorRT engine
def build_engine(onnx_path, fp16_mode=True, int8_mode=False):
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)

    # Parse ONNX model
    with open(onnx_path, 'rb') as model:
        parser.parse(model.read())

    config = builder.create_builder_config()
    config.max_workspace_size = 1 << 30  # 1GB

    if fp16_mode:
        config.set_flag(trt.BuilderFlag.FP16)
    if int8_mode:
        config.set_flag(trt.BuilderFlag.INT8)
        # Requires a calibration dataset
        config.int8_calibrator = MyCalibrator()  # user-supplied calibrator

    # Build and optimize
    engine = builder.build_engine(network, config)
    return engine

# Performance comparison on ResNet-50 inference (batch=32)
# PyTorch eager: 42ms
# PyTorch JIT: 35ms
# TensorRT FP32: 18ms
# TensorRT FP16: 8ms
# TensorRT INT8: 5ms
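TensorRT is NVIDIA-only; for the cross-platform path in the table above, the same ONNX export can be served with ONNX Runtime. A minimal sketch, assuming a model already exported to a hypothetical model.onnx with a single float32 image input:

import numpy as np
import onnxruntime as ort

# Pick the best available execution provider; ORT falls back to CPU if CUDA is absent.
session = ort.InferenceSession(
    "model.onnx",  # hypothetical path to an exported model
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

input_name = session.get_inputs()[0].name
dummy_input = np.random.rand(1, 3, 224, 224).astype(np.float32)

# Run inference; output order matches session.get_outputs()
outputs = session.run(None, {input_name: dummy_input})
print(outputs[0].shape)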
Graph Optimization Patterns
# Before optimization: five separate operations (five kernel launches)
import torch

def unoptimized_layer_norm(x, weight, bias, eps=1e-5):
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)  # LayerNorm uses the biased variance
    normalized = (x - mean) / torch.sqrt(var + eps)
    scaled = normalized * weight
    output = scaled + bias
    return output

# After compiler fusion: 1 fused kernel
# Automatically performed by XLA/TensorRT
# 3-4x faster, reduced memory bandwidth
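In PyTorch, the same kind of fusion can be requested explicitly through TorchScript (the "PyTorch JIT" row in the ResNet-50 comparison above). A small sketch; how much actually fuses depends on the fuser and the device:

import torch

@torch.jit.script
def fused_layer_norm(x, weight, bias, eps: float = 1e-5):
    # TorchScript's fuser can combine these elementwise ops into fewer kernels
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps) * weight + bias

x = torch.randn(32, 1024)
out = fused_layer_norm(x, torch.ones(1024), torch.zeros(1024))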
Memory Management for Large Models
KV-Cache Management for LLMs
# Standard attention caches key/value tensors for every layer.
# Memory ≈ batch_size × seq_len × num_layers × hidden_size × 2 (K and V) × bytes per value
# Problem: 70B model, batch=32, seq=2048
# KV cache = 32 × 2048 × 80 × 8192 × 2 × 2 bytes ≈ 160GB!

# Solution 1: Paged Attention (vLLM) - illustrative sketch, not the vLLM implementation
import torch

class PagedAttention:
    def __init__(self, block_size=16, num_blocks=1000, hidden_size=8192):
        self.block_size = block_size
        # Pre-allocate block pool
        self.kv_blocks = torch.zeros(num_blocks, block_size, hidden_size)
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}  # Maps sequence_id -> [block_ids]

    def append_tokens(self, seq_id, new_kv):
        # Allocate blocks on demand, only when the sequence outgrows its current blocks
        blocks = self.block_tables.setdefault(seq_id, [])
        if len(blocks) * self.block_size < new_kv.shape[0]:
            blocks.append(self.free_blocks.pop())
        # Blocks can be shared across sequences (prefix caching);
        # in practice this achieves 3-5x memory reduction

# Solution 2: Multi-Query Attention (MQA)
# Share K/V across attention heads
# Memory reduction: 8x for 32-head models

# Solution 3: Grouped-Query Attention (GQA)
# Hybrid: group heads, share K/V within groups
# Used in Llama 2: 8 K/V heads for 32 query heads = 4x reduction
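To make Solutions 2 and 3 concrete, here is a minimal grouped-query attention sketch: project only num_kv_heads K/V heads and repeat each across its query group, so the KV cache shrinks by num_q_heads / num_kv_heads. The head counts mirror the Llama 2 numbers above, but this is an illustrative sketch (no masking, no cache), not the Llama 2 implementation:

import torch
import torch.nn.functional as F

def grouped_query_attention(x, wq, wk, wv, num_q_heads=32, num_kv_heads=8):
    """x: (batch, seq, dim); wq/wk/wv: projection matrices. Illustrative GQA without masking."""
    b, s, d = x.shape
    head_dim = d // num_q_heads
    q = (x @ wq).view(b, s, num_q_heads, head_dim).transpose(1, 2)   # (b, Hq, s, hd)
    k = (x @ wk).view(b, s, num_kv_heads, head_dim).transpose(1, 2)  # (b, Hkv, s, hd)
    v = (x @ wv).view(b, s, num_kv_heads, head_dim).transpose(1, 2)
    # Repeat each K/V head across its query group: the KV cache stays Hkv/Hq of the MHA size
    group = num_q_heads // num_kv_heads
    k = k.repeat_interleave(group, dim=1)
    v = v.repeat_interleave(group, dim=1)
    attn = F.softmax(q @ k.transpose(-2, -1) / head_dim**0.5, dim=-1)
    return (attn @ v).transpose(1, 2).reshape(b, s, d)

# Example shapes: wq is (d, d); wk and wv project to num_kv_heads × head_dim = 256
x = torch.randn(2, 16, 1024)
out = grouped_query_attention(x, torch.randn(1024, 1024), torch.randn(1024, 256), torch.randn(1024, 256))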
CPU-GPU Transfer Optimization
# Slow: synchronous transfers
for batch in dataloader:
    batch_gpu = batch.cuda()   # Transfer
    output = model(batch_gpu)  # Compute
    # GPU idle during transfer

# Fast: asynchronous transfers with pinned memory
import torch
from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size=32,
    pin_memory=True,  # Use page-locked (pinned) host memory
    num_workers=4
)

# Create a dedicated CUDA stream for copies
stream = torch.cuda.Stream()
for batch in dataloader:
    with torch.cuda.stream(stream):
        batch_gpu = batch.cuda(non_blocking=True)  # Async transfer
    torch.cuda.current_stream().wait_stream(stream)
    output = model(batch_gpu)

# Result: transfer overlaps with computation, 20-30% speedup
On-Device AI
On-device inference presents unique constraints: model size, power consumption, thermal management, and intermittent connectivity.
Device Constraints by Platform
| Platform | Typical RAM | Model Size Limit | NPU/DSP | Power Budget | Connectivity |
|---|---|---|---|---|---|
| High-end Phone | 8-12GB | 1-2GB | Yes (50-100 TOPS) | 5-8W burst | Intermittent |
| Mid-range Phone | 4-6GB | 200-500MB | Limited (10-20 TOPS) | 3-5W burst | Intermittent |
| Smartwatch | 1-2GB | 50-100MB | Basic | 0.5-1W | Sporadic |
| IoT Device | 256MB-1GB | 10-50MB | Sometimes | 0.1-0.5W | Edge/offline |
| AR Glasses | 2-4GB | 200-500MB | Yes | 2-3W | Tethered/edge |
On-Device Architecture Pattern
graph TB
  A[User Input] --> B{Network Available?}
  B -->|Yes| C[Cloud Model]
  B -->|No| D[On-Device Model]
  C --> E{Confidence > Threshold?}
  E -->|Yes| F[Return Result]
  E -->|No| G[Hybrid: On-Device + Cloud]
  D --> H{Result Quality Check}
  H -->|Good| F
  H -->|Uncertain| I[Queue for Cloud Sync]
  subgraph "On-Device Stack"
    D --> J[Quantized Model<br/>INT8/INT4]
    J --> K[NPU Inference]
    K --> L[Local Cache]
  end
  subgraph "Cloud Stack"
    C --> M[Full Precision Model]
    M --> N[GPU Inference]
  end
  style D fill:#FFB6C1
  style C fill:#90EE90
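The on-device leg of this pattern typically runs a quantized model through a mobile runtime such as TensorFlow Lite. A minimal Python sketch (the model filename is hypothetical; on Android/iOS the equivalent calls go through the platform bindings):

import numpy as np
import tensorflow as tf

# Load a quantized .tflite model and run one inference
interpreter = tf.lite.Interpreter(model_path="model_int8.tflite")  # hypothetical file
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed an input matching the model's expected shape and dtype
dummy = np.zeros(input_details[0]["shape"], dtype=input_details[0]["dtype"])
interpreter.set_tensor(input_details[0]["index"], dummy)
interpreter.invoke()

result = interpreter.get_tensor(output_details[0]["index"])
print(result.shape)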
Model Compression Techniques
Technique Comparison
| Technique | Size Reduction | Accuracy Impact | Inference Speed | Training Required |
|---|---|---|---|---|
| Pruning (Structured) | 2-4x | Low (1-3%) | 2-3x faster | Fine-tuning |
| Pruning (Unstructured) | 5-10x | Medium (3-8%) | Limited speedup | Fine-tuning |
| Distillation | 3-10x | Low-Medium (2-5%) | 3-10x faster | Full retraining |
| Quantization | 2-8x | Low (1-3%) | 2-4x faster | Optional (QAT) |
| LoRA Adapters | Base model + 1% | Minimal | Same as base | Adapter training |
| Neural Architecture Search | Variable | Optimized | Variable | Full retraining |
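The pruning rows in the table above are the only techniques without an example later in this chapter, so here is a minimal structured-pruning sketch using torch.nn.utils.prune; in practice you fine-tune afterwards to recover the accuracy impact noted in the table:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(1024, 1024)

# Structured pruning: remove 50% of output rows by L2 norm (dim=0 prunes whole rows)
prune.ln_structured(layer, name="weight", amount=0.5, n=2, dim=0)

# The mask is applied on the fly; make it permanent before export
prune.remove(layer, "weight")
print(f"Zeroed rows: {(layer.weight.abs().sum(dim=1) == 0).sum().item()} / 1024")

Structured pruning removes whole rows/channels, which is why it translates into real speedups on standard hardware, whereas unstructured sparsity usually needs sparse kernels to pay off.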
Knowledge Distillation Example
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationTrainer:
    def __init__(self, teacher_model, student_model, temperature=3.0, alpha=0.7):
        self.teacher = teacher_model.eval()
        self.student = student_model
        self.temperature = temperature
        self.alpha = alpha  # Weight for distillation loss

    def distillation_loss(self, student_logits, teacher_logits, labels):
        # Soft targets from teacher
        soft_targets = F.softmax(teacher_logits / self.temperature, dim=1)
        soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
        distillation_loss = F.kl_div(soft_student, soft_targets, reduction='batchmean')
        distillation_loss *= self.temperature ** 2
        # Hard targets (ground truth)
        student_loss = F.cross_entropy(student_logits, labels)
        # Combined loss
        return self.alpha * distillation_loss + (1 - self.alpha) * student_loss

    def train_step(self, inputs, labels):
        with torch.no_grad():
            teacher_logits = self.teacher(inputs)
        student_logits = self.student(inputs)
        loss = self.distillation_loss(student_logits, teacher_logits, labels)
        return loss

# Example: BERT-base (110M) → DistilBERT (66M)
# Size: 40% reduction
# Speed: 60% faster
# Accuracy: 97% of original on GLUE
LoRA for On-Device Personalization
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Base model stays frozen and shared
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# LoRA adds small trainable adapters
lora_config = LoraConfig(
    r=8,  # Low rank
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
# Output: trainable params: 4.2M || all params: 6,742M || trainable: 0.06%

# On-device: ship base model + personalized adapters
# Base model: 7GB (shared across users)
# Per-user adapter: 8MB (unique)
# Total: 7GB + 8MB per user vs. 7GB per user
Privacy-Preserving On-Device AI
# Pattern: On-device feature extraction + cloud inference
import requests
import torch
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights

class PrivacyPreservingPipeline:
    def __init__(self):
        # On-device: extract privacy-safe features
        self.feature_extractor = mobilenet_v3_small(weights=MobileNet_V3_Small_Weights.DEFAULT)
        self.feature_extractor.eval()
        # Cloud: task-specific head
        self.cloud_classifier = None

    def process_on_device(self, image):
        # Extract embeddings locally
        with torch.no_grad():
            features = self.feature_extractor(image)
        # Features are abstract and do not reveal raw image content
        return features

    def process_cloud(self, features):
        # Send only features, not the raw image
        response = requests.post(
            'https://api.example.com/classify',
            json={'features': features.tolist()}
        )
        return response.json()

# Privacy benefits:
# - Raw images never leave the device
# - Network bandwidth: 2MB image → 2KB features (1000x reduction)
# - Supports GDPR/CCPA data minimization
Offline-First Design
class OfflineFirstModel:
    # offline_inference, online_inference, is_online, is_wifi, is_battery_low,
    # get_latest_version, download_model, and rule_based_fallback are implemented elsewhere.
    def __init__(self):
        self.online_model_url = "https://api.example.com/model"
        self.offline_model_path = "/data/local/model.tflite"
        self.model_version = "v2.3"

    def predict(self, input_data):
        # 1. Try the offline model first
        try:
            result = self.offline_inference(input_data)
            confidence = result['confidence']
            # 2. Use the online model if confidence is low
            if confidence < 0.7 and self.is_online():
                online_result = self.online_inference(input_data)
                if online_result['confidence'] > confidence:
                    result = online_result
                # Update the offline model in the background
                self.check_model_update()
            return result
        except Exception:
            # 3. Fall back to a rule-based system
            return self.rule_based_fallback(input_data)

    def check_model_update(self):
        # Download a new model in the background when on WiFi
        if self.is_wifi() and not self.is_battery_low():
            new_version = self.get_latest_version()
            if new_version > self.model_version:
                self.download_model(new_version)

# User experience:
# - Always functional, even offline
# - Seamless quality upgrades when online
# - Respects the user's data plan and battery
Evaluation
Comprehensive Benchmarking Framework
import time
import psutil
import numpy as np
from dataclasses import dataclass

@dataclass
class BenchmarkResult:
    throughput_rps: float
    latency_p50_ms: float
    latency_p95_ms: float
    latency_p99_ms: float
    memory_peak_mb: float
    gpu_utilization_pct: float
    cost_per_1k_requests: float
    accuracy_metric: float

class ModelBenchmark:
    # get_gpu_utilization, calculate_cost, and evaluate_accuracy are deployment-specific helpers.
    def __init__(self, model, test_dataset, warmup_steps=100):
        self.model = model
        self.test_dataset = test_dataset
        self.warmup_steps = warmup_steps

    def benchmark(self, batch_size=1, num_requests=1000) -> BenchmarkResult:
        # Warmup
        for _ in range(self.warmup_steps):
            self.model(self.test_dataset[0])

        latencies = []
        memory_usage = []
        start_time = time.time()

        for i in range(num_requests):
            # Measure latency
            request_start = time.time()
            output = self.model(self.test_dataset[i % len(self.test_dataset)])
            latencies.append((time.time() - request_start) * 1000)
            # Sample memory every 10th request
            if i % 10 == 0:
                memory_usage.append(psutil.Process().memory_info().rss / 1024 / 1024)

        total_time = time.time() - start_time

        return BenchmarkResult(
            throughput_rps=num_requests / total_time,
            latency_p50_ms=np.percentile(latencies, 50),
            latency_p95_ms=np.percentile(latencies, 95),
            latency_p99_ms=np.percentile(latencies, 99),
            memory_peak_mb=max(memory_usage),
            gpu_utilization_pct=self.get_gpu_utilization(),
            cost_per_1k_requests=self.calculate_cost(total_time, num_requests),
            accuracy_metric=self.evaluate_accuracy()
        )

# Example comparison
configs = [
    ("FP32 CPU", model_fp32, 'cpu'),
    ("FP16 GPU", model_fp16, 'cuda'),
    ("INT8 GPU", model_int8, 'cuda'),
    ("TensorRT INT8", model_trt, 'cuda'),
]
results = []
for name, model, device in configs:
    benchmark = ModelBenchmark(model.to(device), test_dataset)
    result = benchmark.benchmark(batch_size=32)
    results.append((name, result))
    print(f"{name}: {result.latency_p95_ms:.1f}ms p95, ${result.cost_per_1k_requests:.2f}/1k")
Accuracy vs Performance Tradeoff Matrix
| Configuration | Accuracy (MMLU) | Latency p95 | Throughput | Memory | Cost/1M tokens |
|---|---|---|---|---|---|
| Llama-2-70B FP16 | 68.9% | 2800ms | 12 req/s | 140GB | $2.10 |
| Llama-2-70B INT8 | 68.5% | 1200ms | 28 req/s | 70GB | $0.95 |
| Llama-2-70B INT4 | 67.2% | 650ms | 52 req/s | 35GB | $0.52 |
| Llama-2-13B FP16 | 54.8% | 480ms | 68 req/s | 26GB | $0.38 |
| Llama-2-13B INT4 | 53.2% | 180ms | 180 req/s | 7GB | $0.14 |
| Llama-2-7B INT4 | 45.3% | 95ms | 340 req/s | 4GB | $0.07 |
Decision Framework (a small selection sketch follows this list):
- Accuracy-critical (legal, medical): FP16 or INT8 only
- User-facing low-latency: INT4 with continuous monitoring
- Batch processing: Maximize throughput with INT4
- Cost-sensitive: Smallest model that meets accuracy threshold
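One way to operationalize this framework is to encode the tradeoff table and pick the cheapest configuration that clears your accuracy and latency thresholds. A small sketch using the numbers from the table above; treat the helper as illustrative, not a sizing tool:

# (name, accuracy_pct, latency_p95_ms, cost_per_1m_tokens_usd) from the table above
CONFIGS = [
    ("Llama-2-70B FP16", 68.9, 2800, 2.10),
    ("Llama-2-70B INT8", 68.5, 1200, 0.95),
    ("Llama-2-70B INT4", 67.2, 650, 0.52),
    ("Llama-2-13B FP16", 54.8, 480, 0.38),
    ("Llama-2-13B INT4", 53.2, 180, 0.14),
    ("Llama-2-7B INT4", 45.3, 95, 0.07),
]

def pick_config(min_accuracy: float, max_latency_ms: float):
    # Cheapest configuration that satisfies both SLOs, or None if nothing qualifies
    candidates = [c for c in CONFIGS if c[1] >= min_accuracy and c[2] <= max_latency_ms]
    return min(candidates, key=lambda c: c[3]) if candidates else None

print(pick_config(min_accuracy=67.0, max_latency_ms=1500))  # -> Llama-2-70B INT4
print(pick_config(min_accuracy=50.0, max_latency_ms=200))   # -> Llama-2-13B INT4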
Case Study: Mobile Document Scanner
Problem Statement
A fintech company needed to extract data from financial documents (bank statements, invoices, receipts) in their mobile app. Requirements:
- Sub-300ms end-to-end latency
- Work offline (poor connectivity in rural areas)
- Privacy: documents must not leave device
- Accuracy: >95% field extraction accuracy
Architecture
graph TB
  A[Document Photo] --> B[On-Device Pipeline]
  subgraph "On-Device Processing"
    B --> C[Layout Detection<br/>MobileNetV3 + FPN<br/>INT8, 15MB]
    C --> D[Text Detection<br/>CRAFT detector<br/>INT8, 8MB]
    D --> E[OCR<br/>TrOCR-small quantized<br/>INT8, 35MB]
    E --> F[Field Extraction<br/>DistilBERT + rules<br/>INT8, 25MB]
  end
  F --> G{Confidence Check}
  G -->|High| H[Return Results]
  G -->|Low| I{Online?}
  I -->|Yes| J[Cloud VLM Verification<br/>GPT-4V]
  I -->|No| K[Flag for Review]
  J --> H
  K --> L[Sync When Online]
  style B fill:#FFB6C1
  style J fill:#90EE90
Implementation Details
# On-device pipeline
# load_quantized_model, crop_image, and validate_field are app-specific helpers.
import numpy as np
import openai

class DocumentScanner:
    def __init__(self):
        # Total model size: ~85MB (fits comfortably in memory)
        self.layout_detector = load_quantized_model('layout_det_int8.tflite')
        self.text_detector = load_quantized_model('craft_int8.tflite')
        self.ocr_model = load_quantized_model('trocr_small_int8.tflite')
        self.field_extractor = load_quantized_model('distilbert_int8.tflite')

    def process_document(self, image):
        # Stage 1: Layout detection (40ms)
        layout_boxes = self.layout_detector(image)

        # Stage 2: Text detection (60ms)
        text_regions = []
        for box in layout_boxes:
            region_img = crop_image(image, box)
            text_boxes = self.text_detector(region_img)
            text_regions.extend(text_boxes)

        # Stage 3: OCR (120ms)
        texts = []
        for region in text_regions:
            region_img = crop_image(image, region)
            text = self.ocr_model(region_img)
            texts.append((region, text))

        # Stage 4: Field extraction (50ms)
        fields = self.field_extractor(texts, layout_boxes)

        # Total: ~270ms
        return {
            'fields': fields,
            'confidence': self.calculate_confidence(fields),
            'raw_text': texts
        }

    def calculate_confidence(self, fields):
        # Confidence from OCR scores and field validation
        scores = [f['ocr_confidence'] for f in fields]
        validation = [self.validate_field(f) for f in fields]
        return np.mean(scores) * np.mean(validation)

# Cloud fallback (only for low-confidence cases)
class CloudVerification:
    def verify_fields(self, image, on_device_result):
        # Use GPT-4V for verification
        response = openai.ChatCompletion.create(
            model="gpt-4-vision-preview",
            messages=[{
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image}},
                    {"type": "text", "text": f"Verify these extracted fields: {on_device_result['fields']}"}
                ]
            }]
        )
        return response.choices[0].message.content
Performance Results
| Metric | Target | Achieved | Notes |
|---|---|---|---|
| Latency p95 | <300ms | 285ms | On-device only |
| Accuracy | >95% | 96.8% | With cloud fallback: 98.2% |
| Offline Success | >90% | 92% | High confidence extractions |
| Cloud Requests | <10% | 8% | Only low-confidence cases |
| Privacy | 100% local | 100% | Documents never uploaded |
| Cost per Extract | <$0.01 | $0.0008 | 8% × $0.01 cloud cost |
Optimization Techniques Used
- Model Distillation: Trained layout detector from Detectron2 → MobileNetV3 (500MB → 15MB)
- Quantization: INT8 post-training quantization across all models
- Pipeline Optimization: Overlapped stages using async processing
- Smart Cropping: Only process detected regions, not full image
- Caching: Cache layout results for multi-page documents
Business Impact
- User Experience: Instant results without loading spinners
- Privacy: Meets GDPR/CCPA requirements, key selling point
- Reliability: Works in areas with poor connectivity (35% of users)
- Cost: 12x cheaper than cloud-only solution
- Scale: Handles 2M documents/day without additional cloud costs
Implementation Checklist
Phase 1: Hardware Selection & Baseline (Week 1-2)
- Define SLOs and Constraints
  - Target latency (p50, p95, p99)
  - Throughput requirements (requests/second)
  - Cost budget per request
  - Accuracy thresholds
  - Privacy/locality requirements
- Hardware Selection
  - Cloud: GPU (A100, H100), TPU (v4, v5), or Inferentia
  - Edge: Device profiles (phone, IoT, embedded)
  - Hybrid: Define on-device vs cloud split
- Baseline Measurements
  - CPU baseline (throughput, latency, cost)
  - GPU/TPU baseline with naive deployment
  - Profile memory usage and bottlenecks
  - Measure end-to-end request flow
Phase 2: Optimization & Compilation (Week 3-4)
- Quantization
  - Post-training dynamic quantization
  - Static quantization with calibration dataset
  - Quantization-aware training if needed
  - A/B test accuracy on representative data
- Compiler Optimization
  - TensorRT/XLA compilation
  - Kernel fusion analysis
  - Mixed precision (FP16/BF16)
  - Batch size optimization
- Memory Optimization
  - KV-cache strategy for LLMs
  - Paged attention implementation
  - CPU-GPU transfer optimization
  - Memory profiling and leak detection
Phase 3: On-Device Deployment (Week 5-6, if applicable)
- Model Compression
  - Pruning (structured/unstructured)
  - Knowledge distillation
  - LoRA adapters for personalization
  - Neural architecture search
- Device Integration
  - Target NPU/DSP if available
  - Offline-first UX design
  - Model update mechanism
  - Battery and thermal monitoring
- Privacy Design
  - On-device feature extraction
  - Minimize data uploads
  - Federated learning if needed
  - Privacy policy updates
Phase 4: Evaluation & Monitoring (Week 7-8)
- Comprehensive Benchmarking
  - Latency distribution (p50, p95, p99)
  - Throughput under load
  - Memory usage (peak, steady-state)
  - Cost per request calculation
  - Accuracy comparison matrix
- Reliability Testing
  - OOM scenarios and recovery
  - Thermal throttling behavior
  - Network failure handling
  - Graceful degradation
- Production Monitoring
  - Latency and throughput dashboards
  - Error rates and OOM frequency
  - Cost tracking and alerts
  - Accuracy monitoring (when ground truth available)
Phase 5: Continuous Improvement (Ongoing)
- Performance Iteration
  - Profile new bottlenecks
  - Update compilers and drivers
  - New quantization techniques
  - Hardware upgrades evaluation
- Model Updates
  - A/B test new model versions
  - Gradual rollout strategy
  - Rollback procedures
  - Version compatibility
- Cost Optimization
  - Right-size instance types
  - Spot/preemptible instances
  - Auto-scaling policies
  - Multi-cloud cost comparison
Common Pitfalls & Solutions
| Pitfall | Impact | Solution |
|---|---|---|
| Ignoring warmup time | First requests 10x slower | Pre-warm models, use keep-alive |
| Wrong batch size | Suboptimal cost/latency | Benchmark multiple batch sizes, use dynamic batching |
| Memory leaks | OOM crashes | Profile with torch.cuda.memory, clear caches |
| CPU-GPU bottleneck | GPU underutilized | Use pinned memory, async transfers, prefetching |
| Over-quantization | Accuracy drops | Start with 8-bit, measure accuracy carefully, use QAT |
| Ignoring outliers | p99 latency spikes | Use separate outlier handling, GC tuning |
| No fallback strategy | Crashes on edge cases | Implement graceful degradation, monitoring |
| Poor thermal design | Device throttling | Duty cycling, reduce model size, background processing |
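The warmup pitfall is the cheapest to fix: run a few dummy inferences at startup so the first user request does not pay for lazy initialization and kernel compilation. A minimal sketch (the example input shape is illustrative):

import torch

def prewarm(model, example_input, iterations=10):
    """Run a few dummy inferences so the first real requests see steady-state latency."""
    model.eval()
    with torch.no_grad():
        for _ in range(iterations):
            model(example_input)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure the warmup kernels have actually finished

# Call once at service startup, before marking the instance healthy, e.g.:
# prewarm(model, torch.randn(1, 3, 224, 224, device="cuda"))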
Best Practices Summary
- Measure First: Always baseline before optimizing
- Incremental Optimization: Apply one technique at a time, measure impact
- Target-Specific: Different deployments need different optimizations
- Monitor Accuracy: Quantization and compression can degrade quality
- Design for Failure: Handle OOM, thermal limits, network issues gracefully
- Privacy by Design: Minimize data movement, prefer local processing
- Cost-Aware: Track cost per request, not just performance
- User-Centric: Optimize for perceived latency and offline reliability
Further Reading
- Hardware: NVIDIA TensorRT documentation, Google TPU performance guide
- Quantization: "LLM.int8()" paper, GPTQ, AWQ papers
- On-Device: TensorFlow Lite guide, Core ML documentation
- Compilers: TVM tutorials, XLA optimization guide
- Memory: vLLM paper on paged attention, FlashAttention paper