Building Offline-Friendly AI Features: When to Ship a Small Model On-Device

Ravinder · 9 min read
AI · On-Device · LLM · Mobile · Privacy

The feature request came in as "make the AI work offline." Simple enough on the surface. Six months later, after shipping a 1.2B parameter model to 4 million mobile devices, I can tell you: the hard problems have nothing to do with the model. They're about what happens when the model's answer diverges from what the cloud would have said, which device segments you actually need to support, and what you do when a user's battery is at 4%.

This is what I wish I'd known before starting.

When On-Device Makes Sense

Not every AI feature should run on-device. The business case only holds under specific conditions:

On-device is the right call when:

  • You have latency-sensitive features that cloud round-trips make awkward (real-time suggestions, keyboard autocomplete, live transcription)
  • Your users have significant offline use (field workers, pilots, remote healthcare, developing markets with spotty connectivity)
  • Privacy is a first-class requirement (medical notes, legal documents, therapy journaling — data that should never leave the device)
  • Your cloud inference costs are high and the task can be solved by a smaller model

On-device is the wrong call when:

  • The task requires current information (stock prices, weather, live data)
  • Model updates are critical (safety-relevant tasks where a model error is high-stakes)
  • Your device footprint (storage, RAM, battery) is already tight and users will notice

The correct mental model isn't "replace cloud AI with on-device AI." It's "which tasks can I handle locally, and what's my fallback for the ones I can't?"
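That routing question can be made concrete. Here is a minimal sketch of the decision, assuming hypothetical task labels and a 20% battery floor (none of these names come from a real codebase):

```python
from dataclasses import dataclass

# Illustrative assumptions: which tasks the local model handles,
# and the minimum battery level for local inference.
LOCAL_CAPABLE_TASKS = {"autocomplete", "summarize_note", "extract_action_items"}
MIN_BATTERY = 0.20

@dataclass
class DeviceState:
    online: bool
    battery_level: float      # 0.0-1.0
    thermal_throttled: bool

def route(task: str, state: DeviceState) -> str:
    """Decide where a request runs: 'cloud', 'on_device', or 'unavailable'.

    Mirrors the fallback pattern described later in the article:
    prefer cloud when online and healthy, fall back to the local
    model when offline or throttled, and admit defeat otherwise.
    """
    if state.online and not state.thermal_throttled:
        return "cloud"
    if task in LOCAL_CAPABLE_TASKS and state.battery_level > MIN_BATTERY:
        return "on_device"
    return "cloud" if state.online else "unavailable"
```

Keeping the policy in one pure function like this makes it trivial to unit-test the offline and throttled branches, which are otherwise hard to reproduce in manual QA.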

The Model Landscape

graph TD
    subgraph "Device Classes"
        HC[High-end Mobile<br/>iPhone 15 Pro, Pixel 8 Pro<br/>8-12 GB RAM, Neural Engine]
        MC[Mid-range Mobile<br/>iPhone 13, Pixel 7a<br/>4-6 GB RAM]
        LC[Low-end / Older<br/>2-4 GB RAM<br/>No dedicated NPU]
    end
    subgraph "Model Size Fit"
        HC --> M3B[1B-3B params<br/>Phi-3-mini, Gemma-2B]
        MC --> M1B[500M-1.5B params<br/>Phi-3.5-mini quantized<br/>MobileLLM-1B]
        LC --> M500[100M-500M params<br/>distilled task-specific<br/>or cloud-only]
    end

The models worth knowing for mobile in 2025:

Phi-3-mini (3.8B, Microsoft) — The best general-purpose model in the 3-4B range. Quantized to INT4 via GGUF, it runs at ~12 tokens/sec on an iPhone 15 Pro via llama.cpp. Instruction-following is surprisingly strong for its size.

Gemma-2B (Google) — Beats Phi-3 on certain reasoning tasks despite being the smaller model, but carries a larger memory footprint than you'd expect at the same quantization level. Use it if you need stronger math/code.

MobileLLM (Meta, 1B/125M) — Purpose-built for mobile. Architectural choices (grouped-query attention, deep-thin design) specifically target mobile inference efficiency. 125M variant can run on low-end devices.

Phi-3.5-mini (3.8B) — Follow-up to Phi-3-mini with better multilingual support. Important if your user base is non-English-primary.
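A quick way to sanity-check whether any of these fits a device class is a back-of-envelope footprint estimate: parameters times bits per weight divided by 8, plus some headroom for embeddings kept at higher precision, quantization scales, and runtime buffers. The 20% overhead factor below is an assumed fudge factor, not a measurement:

```python
def quantized_size_gb(params_billions: float, bits_per_weight: int,
                      overhead: float = 0.2) -> float:
    """Rough in-memory size of a quantized model's weights.

    overhead is an assumed allowance for higher-precision embeddings,
    per-block scales, and runtime buffers.
    """
    weight_bytes = params_billions * 1e9 * bits_per_weight / 8
    return weight_bytes * (1 + overhead) / 1e9

# Phi-3-mini (3.8B) at INT4  -> ~2.3 GB: high-end devices only
# MobileLLM-125M at INT8     -> ~0.15 GB: feasible on low-end devices
```

Note this covers weights only; the KV cache grows with context length on top of it, which is why a model that "fits" at load time can still be evicted under memory pressure mid-generation.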

Distillation for Task-Specific Models

General-purpose 3B models are impressive but often overkill — or underkill — for your specific task. Task-specific distillation gives you a model that's smaller, faster, and more reliable for your use case.

The process:

flowchart LR
    A[Teacher Model<br/>GPT-4 / Sonnet] --> B[Generate outputs<br/>on task dataset]
    B --> C[Student init<br/>Phi-3-mini or Gemma-2B]
    C --> D[Knowledge Distillation<br/>KL divergence on logits]
    D --> E[Task-Specific<br/>Student Model]
    E --> F[Quantize to INT4<br/>GGUF / CoreML]
    F --> G[Deploy to device]

from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer
import torch
import torch.nn.functional as F
 
def distillation_loss(student_logits, teacher_logits, labels, temperature=3.0, alpha=0.5):
    """
    Combined loss: cross-entropy on labels + KL divergence from teacher.
    alpha: weight on distillation loss (vs standard CE loss)
    temperature: higher T = softer distributions = more transfer
    """
    # Standard supervised loss
    ce_loss = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)),
        labels.view(-1),
        ignore_index=-100,
    )
    
    # Distillation loss
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    kl_loss = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * (temperature ** 2)
    
    return alpha * kl_loss + (1 - alpha) * ce_loss
 
class DistillationTrainer(SFTTrainer):
    def __init__(self, teacher_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model.eval()  # frozen; only the student trains
    
    # **kwargs absorbs extra arguments (e.g. num_items_in_batch)
    # passed by newer transformers versions
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(**inputs)
        
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)
        
        loss = distillation_loss(
            student_logits=outputs.logits,
            teacher_logits=teacher_outputs.logits,
            labels=inputs["labels"],
        )
        
        return (loss, outputs) if return_outputs else loss

For a focused task (e.g., "extract action items from meeting notes"), you can distill a 3B general model into a 500M task-specific model that performs equivalently on that task at 40% of the inference cost and memory footprint.
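The `temperature` parameter in the distillation loss above is doing real work: dividing logits by T before the softmax flattens the teacher's distribution, so the student sees the relative probabilities of the runner-up answers instead of a near-one-hot target. A quick plain-Python illustration:

```python
import math

def softened(logits, temperature):
    """Softmax of logits / temperature, computed in plain Python."""
    scaled = [x / temperature for x in logits]
    exps = [math.exp(x) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# At T=1 the top class dominates the distribution; at T=3 the
# relative probabilities of the runner-up classes (the "dark
# knowledge" the student learns from) become far more visible.
```

This is also why the KL term is scaled by `temperature ** 2` in the loss: the gradients through the softened softmax shrink by roughly 1/T² as T grows, and the scaling keeps the two loss terms balanced.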

Converting for Mobile Deployment

iOS — CoreML / Metal

import coremltools as ct
import numpy as np
import torch
from transformers import AutoModelForCausalLM
 
# Load the model in a TorchScript-friendly mode and put it in eval mode
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-3-mini-4k-instruct", torchscript=True
)
model.eval()
 
# Trace with a representative token sequence
# (full export pipeline abbreviated for space)
example_input = torch.randint(0, 32000, (1, 128))
traced_model = torch.jit.trace(model, example_input)
 
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(shape=(1, ct.RangeDim(1, 512)), dtype=np.int32)],
    compute_precision=ct.precision.FLOAT16,  # ANE uses FP16
    compute_units=ct.ComputeUnit.ALL,        # use ANE + GPU + CPU
)
 
# Apply compression for on-device storage
op_config = ct.optimize.coreml.OpLinearQuantizerConfig(
    mode="linear_symmetric",
    dtype="int4",
    granularity="per_block",
    block_size=32,
)
config = ct.optimize.coreml.OptimizationConfig(global_config=op_config)
mlmodel_quantized = ct.optimize.coreml.linear_quantize_weights(mlmodel, config=config)
mlmodel_quantized.save("phi3-mini-int4.mlpackage")

Android — MediaPipe / ONNX Runtime Mobile

// Android inference with ONNX Runtime Mobile
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import android.content.Context
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
 
class OnDeviceLLM(context: Context) {
    private val session: OrtSession
    private val tokenizer: Tokenizer
    
    init {
        val sessionOptions = OrtSession.SessionOptions().apply {
            addNnapi()          // use Android Neural Networks API
            setInterOpNumThreads(2)
            setIntraOpNumThreads(4)
        }
        val modelBytes = context.assets.open("phi3-mini-q4.onnx").readBytes()
        session = OrtEnvironment.getEnvironment().createSession(modelBytes, sessionOptions)
        tokenizer = Tokenizer.fromAsset(context, "tokenizer.json")
    }
    
    fun generate(prompt: String, maxTokens: Int = 256): Flow<String> = flow {
        val env = OrtEnvironment.getEnvironment()
        var currentIds = tokenizer.encode(prompt).map { it.toLong() }.toLongArray()
        
        repeat(maxTokens) {
            val inputs = mapOf(
                "input_ids" to OnnxTensor.createTensor(env, arrayOf(currentIds))
            )
            val output = session.run(inputs)
            val logits = output[0].value as Array<Array<FloatArray>>
            // Greedy decode: take the highest-logit token at the last position
            val nextToken = logits[0].last()
                .withIndex().maxByOrNull { it.value }!!.index
            
            if (nextToken == tokenizer.eosTokenId) return@flow
            
            emit(tokenizer.decode(intArrayOf(nextToken)))
            currentIds += nextToken.toLong()
        }
    }
}

Battery and Thermal Management

This is where most teams underestimate the work. A 3B INT4 model generating text continuously will drain a modern iPhone at roughly 8-12% per minute under load. Users will notice.

// iOS: Thermal and battery-aware inference throttling
import Foundation
import UIKit
 
extension Notification.Name {
    static let inferenceModeChanged = Notification.Name("inferenceModeChanged")
}
 
class ThermalAwareInferenceManager {
    private var thermalObserver: NSObjectProtocol?
    
    enum InferenceMode {
        case full          // normal generation
        case throttled     // reduce batch, increase token delay
        case cloudFallback // offload to cloud, preserve battery
    }
    
    var currentMode: InferenceMode = .full
    
    init() {
        // batteryLevel reads -1 unless monitoring is explicitly enabled
        UIDevice.current.isBatteryMonitoringEnabled = true
        thermalObserver = NotificationCenter.default.addObserver(
            forName: ProcessInfo.thermalStateDidChangeNotification,
            object: nil,
            queue: .main
        ) { [weak self] _ in
            self?.updateMode()
        }
    }
    
    func updateMode() {
        let batteryLevel = UIDevice.current.batteryLevel  // -1 if unknown
        let thermalState = ProcessInfo.processInfo.thermalState
        
        // Patterns sharing a case must bind the same variables,
        // so each binding pattern gets its own case.
        switch (thermalState, batteryLevel) {
        case (.critical, _):
            currentMode = .cloudFallback
        case (.serious, let b) where b >= 0 && b < 0.20:
            currentMode = .cloudFallback
        case (.serious, _):
            currentMode = .throttled
        case (.fair, let b) where b >= 0 && b < 0.30:
            currentMode = .throttled
        default:
            currentMode = .full
        }
        
        NotificationCenter.default.post(
            name: .inferenceModeChanged,
            object: currentMode
        )
    }
    
    func shouldUseOnDevice() -> Bool {
        currentMode != .cloudFallback
    }
}

Practical battery rules:

  • Don't run inference on background threads without user awareness — it drains battery without visible feedback
  • Implement token delay (add 10-20ms between tokens when throttled) — it reduces GPU utilization without breaking the streaming UX
  • Respect low-power mode: ProcessInfo.processInfo.isLowPowerModeEnabled → fall back to cloud automatically
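The token-delay rule is simple to wire into the streaming loop itself. A sketch of the idea, where `tokens` stands in for whatever async token stream your inference runtime exposes (the names here are illustrative):

```python
import asyncio
from typing import AsyncIterator

async def throttled_stream(
    tokens: AsyncIterator[str],
    throttled: bool,
    delay_ms: int = 15,
) -> AsyncIterator[str]:
    """Re-emit a token stream, inserting a small sleep between tokens
    when the device is throttled. The pause lowers sustained GPU/NPU
    utilization without visibly breaking the streaming UX."""
    async for token in tokens:
        yield token
        if throttled:
            await asyncio.sleep(delay_ms / 1000)
```

Because the delay lives in the presentation layer rather than the inference kernel, you can flip it on and off from the thermal manager without touching the model runtime.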

The Sync-Online Fallback Pattern

On-device models will produce different answers than cloud models on the same query. This is not a bug you can fully eliminate — it's a property of using different models. Design around it.

sequenceDiagram
    participant User
    participant App
    participant OnDevice
    participant Cloud
    User->>App: Request
    App->>App: Check: online + non-throttled?
    alt Online and capable
        App->>Cloud: Send request
        Cloud-->>App: High-quality response
    else Offline or throttled
        App->>OnDevice: Send request
        OnDevice-->>App: Local response
        App->>App: Store: (request, local_response, timestamp)
    end
    App-->>User: Show response (with offline indicator if applicable)
    Note over App,Cloud: Later, when online
    App->>Cloud: Sync: re-run stored requests
    Cloud-->>App: Cloud responses
    App->>App: Diff local vs cloud responses
    alt Significant divergence
        App-->>User: "Updated answer available"
    end

import hashlib
import json
from datetime import datetime
from dataclasses import dataclass, asdict
 
@dataclass
class OfflineResponse:
    request_id: str
    query: str
    local_response: str
    timestamp: str
    model_version: str
    synced: bool = False
    cloud_response: str | None = None
 
class OfflineSyncManager:
    def __init__(self, local_store, cloud_client, divergence_threshold: float = 0.3):
        self.store = local_store
        self.cloud = cloud_client
        self.threshold = divergence_threshold
    
    def record_offline_response(self, query: str, response: str, model_ver: str) -> str:
        request_id = hashlib.sha256(f"{query}{datetime.utcnow().isoformat()}".encode()).hexdigest()[:16]
        record = OfflineResponse(
            request_id=request_id,
            query=query,
            local_response=response,
            timestamp=datetime.utcnow().isoformat(),
            model_version=model_ver,
        )
        self.store.save(request_id, asdict(record))
        return request_id
    
    async def sync_pending(self) -> list[dict]:
        unsynced = self.store.get_unsynced()
        diverged = []
        
        for record in unsynced:
            cloud_resp = await self.cloud.generate(record["query"])
            divergence = self._measure_divergence(record["local_response"], cloud_resp)
            
            record["cloud_response"] = cloud_resp
            record["synced"] = True
            self.store.update(record["request_id"], record)
            
            if divergence > self.threshold:
                diverged.append({
                    "request_id": record["request_id"],
                    "divergence": divergence,
                    "query_snippet": record["query"][:80],
                })
        
        return diverged  # caller decides whether to notify user
    
    def _measure_divergence(self, local: str, cloud: str) -> float:
        # Simple word-overlap similarity — replace with embedding similarity in production
        local_words = set(local.lower().split())
        cloud_words = set(cloud.lower().split())
        if not local_words or not cloud_words:
            return 1.0
        intersection = local_words & cloud_words
        union = local_words | cloud_words
        return 1.0 - (len(intersection) / len(union))

The key UX decision: when do you proactively notify the user that the offline answer diverged from cloud? Our answer: only when the divergence score exceeds 0.4 and the topic is factual (not creative). Creative tasks have legitimate variation; factual tasks need surfacing.
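That policy is small enough to encode directly. The 0.4 threshold and the factual-versus-creative split mirror the rule above; the task labels are illustrative:

```python
# Illustrative task labels -- factual tasks get divergence surfaced,
# creative tasks are allowed to vary between local and cloud models.
FACTUAL_TASKS = {"qa", "summarize_note", "extract_action_items"}

def should_notify_user(divergence: float, task_type: str,
                       threshold: float = 0.4) -> bool:
    """Notify only when a factual answer diverged meaningfully from
    the cloud response; creative output has legitimate variation."""
    return task_type in FACTUAL_TASKS and divergence > threshold
```

Keeping the threshold as a parameter also lets you tune it remotely per task as you gather data on which divergences users actually care about.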

Privacy Architecture

On-device processing is only a privacy guarantee if you actually process on-device. Common mistakes that undermine it:

  1. Logging queries for "improvement" on first launch — if the query goes to your servers for any reason, the privacy claim is void
  2. Sending embeddings to cloud for retrieval — embeddings are nearly invertible for short texts; this leaks information
  3. Crash reporting that includes model I/O — Sentry and Firebase Crashlytics can inadvertently capture model input/output in crash traces

// iOS: Ensure on-device processing stays on-device
class PrivacyGuardedInference {
    private let onDeviceModel: OnDeviceLLM
    private let analyticsEnabled: Bool
    
    init(model: OnDeviceLLM, analyticsEnabled: Bool) {
        self.onDeviceModel = model
        self.analyticsEnabled = analyticsEnabled
    }
    
    func infer(query: String) async -> String {
        // Never log the query itself — only aggregate metrics
        let startTime = Date()
        let result = await onDeviceModel.generate(query)
        let latency = Date().timeIntervalSince(startTime)
        
        if analyticsEnabled {
            // Only log performance metrics, never content
            Analytics.log("inference_completed", properties: [
                "latency_ms": latency * 1000,
                "token_count": result.split(separator: " ").count,
                "model_version": onDeviceModel.version,
                // NO: "query": query  ← never do this
                // NO: "response": result ← never do this
            ])
        }
        
        return result
    }
}

Key Takeaways

  • On-device AI is justified by latency sensitivity, offline requirements, or privacy constraints — not by novelty; evaluate the actual need before building.
  • Phi-3-mini INT4 via GGUF or CoreML is the best general-purpose starting point for high-end mobile; MobileLLM-1B fits mid-range devices with adequate task-specific tuning.
  • Task-specific distillation from a teacher model (GPT-4 / Sonnet) into a student (Gemma-2B / Phi-3-mini) dramatically outperforms general-purpose models at 40-60% of the footprint.
  • Battery and thermal management are non-negotiable: monitor thermalState and batteryLevel, implement cloud fallback when critical, and never run inference continuously in the background.
  • Design the offline/online divergence explicitly — store offline responses, sync and diff against cloud answers, and notify users when factual responses diverge significantly.
  • Privacy on-device only holds if you audit your entire data pipeline: crash reporters, analytics SDKs, and sync mechanisms can all inadvertently exfiltrate model I/O.