On the Factory Floor: ML Engineering for Industrial-Scale Ads Recommendation Models from Google
Korvus, Local-Gemma, STORM, KaTeX
Article
This paper presents Google's search ads click-through rate (CTR) prediction system, offering insights into the challenges and solutions for large-scale industrial recommender systems. The CTR prediction model described has billions of weights, trains on over 100 billion examples, and performs inference at more than 100,000 requests per second. It is, of course, a large engineering system that has to handle several challenges at once:
Balancing accuracy improvements with training and serving costs
Maintaining model simplicity to allow for ongoing R&D
Addressing the unique challenges of online advertising, such as non-stationarity of data and calibration requirements
Given these challenges, Google groups the techniques it adopts for building large-scale recommender systems into two categories: model performance (improving the model's accuracy) and model efficiency (decreasing cost without hurting model performance). Balancing the two keeps the model's ROI (return on investment) in check, so that the model is never scaled up without thinking about efficiency.
Model Architecture and Training
The model uses a deep neural network (DNN) architecture with the following components:
Input representation: Ad-query pairs are represented using sparse features, including n-grams and sub-word units, mapped to embedding tables.
Network structure:
Embedding input layer (E)
Multiple fully-connected hidden layers ($H_i = \sigma(H_{i-1} W_i)$)
Output layer with sigmoid activation for click probability
Loss function: Logistic loss of observed click label y with respect to prediction ŷ
Training method:
Synchronous minibatch SGD on TPUs
AdaGrad optimizer for both embedding and dense network weights
Online learning with a single sequential pass over logged examples in chronological order
Evaluation: Progressive validation, using predictions on each example before it is trained on
The only thing that is proprietary/custom to Google is the usage of TPUs over GPUs. Other than that, most of the work is well established in the industry.
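To make the training and evaluation recipe above concrete, here is a toy sketch. It is not Google's system: a logistic model with one weight per sparse feature stands in for the embedding tables and DNN, and the feature strings, learning rate, and initial accumulator value are all hypothetical. It does combine the three pieces the paper describes: logistic loss, per-coordinate AdaGrad updates, and progressive validation over a single chronological pass.

```python
import math
from collections import defaultdict


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def log_loss(y: int, p: float, eps: float = 1e-12) -> float:
    # Logistic loss of click label y in {0, 1} with respect to prediction p.
    p = min(max(p, eps), 1.0 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


class OnlineAdagradCTR:
    """Toy CTR model: one weight per sparse feature id (a 1-d 'embedding')."""

    def __init__(self, lr: float = 0.1, init_accum: float = 0.1):
        self.lr = lr
        self.w = defaultdict(float)                   # feature id -> weight
        self.accum = defaultdict(lambda: init_accum)  # AdaGrad sum of squared gradients

    def predict(self, features: list[str]) -> float:
        return sigmoid(sum(self.w[f] for f in features))

    def update(self, features: list[str], y: int) -> None:
        g = self.predict(features) - y  # gradient of log loss w.r.t. the logit
        for f in features:
            self.accum[f] += g * g
            self.w[f] -= self.lr * g / math.sqrt(self.accum[f])


def progressive_validation(model: OnlineAdagradCTR, stream) -> float:
    """Single chronological pass: score each example BEFORE training on it."""
    total, n = 0.0, 0
    for features, y in stream:
        total += log_loss(y, model.predict(features))  # evaluate first...
        model.update(features, y)                      # ...then train on it
        n += 1
    return total / n


# Hypothetical stream: one query shown with an ad that gets clicked and one that doesn't.
stream = [(["q:shoes", "ad:nike"], 1), (["q:shoes", "ad:acme"], 0)] * 200
model = OnlineAdagradCTR()
print(f"progressive log loss: {progressive_validation(model, stream):.3f}")
```

Because every example is scored before the model has seen it, the averaged loss behaves like a held-out metric even though no data is set aside.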
Efficiency Techniques
1. Matrix Factorization (MF)
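To reduce computational cost while maintaining accuracy, the paper uses MF bottleneck layers between non-linearities, factoring an m × n weight matrix into an m × k matrix followed by a k × n matrix:
Reduces compute from m × n to (m × k) + (k × n), a saving whenever k < mn / (m + n)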
Allows for linear scaling of compute when increasing layer sizes
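The saving is easy to check with a bit of arithmetic. This sketch counts multiply-adds per example for a dense m × n layer and for its rank-k factorization; the layer sizes and the bottleneck width k = 128 are illustrative choices, not values from the paper.

```python
def dense_cost(m: int, n: int) -> int:
    # Multiply-adds per example for an m x n fully connected layer.
    return m * n


def bottleneck_cost(m: int, n: int, k: int) -> int:
    # W (m x n) replaced by U (m x k) followed by V (k x n), no non-linearity between.
    return m * k + k * n


k = 128  # bottleneck width held fixed while the layer grows (illustrative)
for m, n in [(1024, 1024), (2048, 2048), (4096, 4096)]:
    print(f"{m}x{n}: dense={dense_cost(m, n):,}  bottleneck={bottleneck_cost(m, n, k):,}")
```

Doubling both m and n quadruples the dense cost but only doubles the bottleneck cost, which is the linear scaling mentioned above.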
2. AutoML for Efficiency
The paper describes an automated approach to optimize model architecture for accuracy/cost trade-offs:
Uses neural architecture search based on weight sharing
Components: weight-sharing network, RL controller, and constraints
Explores network configurations (layer widths, embedding dimensions) efficiently
Finds model versions with neutral accuracy and decreased costs
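At the heart of cost-aware architecture search is a reward that trades accuracy against a cost budget. The sketch below uses an absolute-value cost penalty, a form used in earlier weight-sharing NAS work such as TuNAS; whether the paper uses exactly this form is an assumption. The quality function is a made-up stand-in for the accuracy a weight-sharing supernet would report, and because the toy search space has only nine candidates, plain enumeration replaces the RL controller.

```python
import itertools
import math

INPUT_DIM = 512
WIDTHS = [256, 512, 1024]  # candidate widths for each of two hidden layers (illustrative)
TARGET_COST = 400_000      # hypothetical budget C0, in multiply-adds per example
BETA = -0.05               # hypothetical penalty strength; must be negative


def cost(arch: tuple[int, int]) -> int:
    # Multiply-adds for input -> hidden 1 -> hidden 2 -> one output logit.
    w1, w2 = arch
    return INPUT_DIM * w1 + w1 * w2 + w2


def proxy_quality(arch: tuple[int, int]) -> float:
    # HYPOTHETICAL stand-in for the accuracy a weight-sharing supernet would
    # report for this sub-network; a real search measures it on held-out data.
    w1, w2 = arch
    return 0.70 + 0.010 * math.log2(w1 / 256) + 0.004 * math.log2(w2 / 256)


def reward(arch: tuple[int, int]) -> float:
    # Absolute-value penalty: punishes both overshooting and undershooting the budget.
    return proxy_quality(arch) + BETA * abs(cost(arch) / TARGET_COST - 1.0)


# An RL controller would sample architectures and shift its policy toward high
# reward; with only nine candidates we can enumerate them instead.
candidates = list(itertools.product(WIDTHS, repeat=2))
best = max(candidates, key=reward)
print(best, cost(best), round(reward(best), 6))
```

With this hypothetical budget, the search prefers widening the first layer over the second at roughly the same cost, and it also penalizes architectures that land well under budget, not just those over it.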
3. Sampling Strategies
To reduce training data without sacrificing accuracy, the authors employ:
Importance sampling based on example difficulty
Stratified sampling to balance rare events
Down-sampling of negative examples
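Negative down-sampling is the simplest of the three to sketch. A common way to keep the loss unbiased, assumed here rather than taken from the paper, is to give each surviving negative an importance weight of 1 / rate. The function names, the toy data, and the 10% rate are illustrative.

```python
import math
import random


def downsample_negatives(examples, neg_rate: float, seed: int = 0):
    """Keep every positive and each negative with probability neg_rate.

    Kept negatives get importance weight 1 / neg_rate, so the weighted loss is
    an unbiased estimate of the loss on the full, unsampled data.
    """
    rng = random.Random(seed)  # seeded so the sample is reproducible
    out = []
    for features, y in examples:
        if y == 1:
            out.append((features, y, 1.0))
        elif rng.random() < neg_rate:
            out.append((features, y, 1.0 / neg_rate))
    return out


def weighted_log_loss(y: int, p: float, w: float) -> float:
    # Logistic loss scaled by the example's importance weight.
    return -w * (y * math.log(p) + (1 - y) * math.log(1.0 - p))


# Hypothetical, heavily imbalanced data: 10 clicks among 1,000 impressions.
data = [([f"ex{i}"], 1 if i < 10 else 0) for i in range(1000)]
sampled = downsample_negatives(data, neg_rate=0.1)
print(len(sampled), "examples kept")
```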
Accuracy Improvement Techniques
1. Distillation
The paper describes two applications of distillation:
a. Cross-model distillation:
Train a large, complex "teacher" model
Use teacher's predictions to train a smaller, simpler "student" model
Allows for exploring complex architectures without increasing serving costs
b. Self-distillation:
Use the model's own past predictions as soft targets
Helps with non-stationarity and provides a form of ensembling
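Both flavors can share one loss. This sketch blends the hard click label with a soft target; the mixing weight alpha is a hypothetical hyperparameter, and the paper's exact formulation may differ. For cross-model distillation the soft target is the teacher's prediction, and for self-distillation it is the model's own earlier prediction.

```python
import math


def log_loss(target: float, p: float, eps: float = 1e-12) -> float:
    # Cross-entropy between a (possibly soft) target in [0, 1] and prediction p.
    p = min(max(p, eps), 1.0 - eps)
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def distillation_loss(y: int, student_p: float, soft_target: float, alpha: float = 0.5) -> float:
    """Mix the hard label loss with a loss toward the soft target.

    soft_target: the teacher's prediction (cross-model distillation) or the
    model's own past prediction (self-distillation).
    """
    return (1.0 - alpha) * log_loss(y, student_p) + alpha * log_loss(soft_target, student_p)


print(distillation_loss(y=1, student_p=0.6, soft_target=0.8, alpha=0.5))
```

The soft term is minimized when the student matches the soft target, which is how the teacher's (or past model's) knowledge is transferred.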
2. Second-Order Optimization
The authors report the first known large-scale deployment of a second-order optimizer in a production neural network:
Distributed Shampoo algorithm used for training dense network weights
Improved convergence speed and final model quality
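For a single weight matrix W with gradient G, Shampoo keeps two preconditioners, L accumulating G G^T and R accumulating G^T G, and updates W with L^(-1/4) G R^(-1/4). The NumPy sketch below shows only that core step; the learning rate and initial values are illustrative, and the distributed production version adds machinery omitted here (for example, computing the inverse roots only periodically and grafting the step size from another optimizer).

```python
import numpy as np


def matrix_power_sym(mat: np.ndarray, power: float, eps: float = 1e-6) -> np.ndarray:
    # Power of a symmetric PSD matrix via eigendecomposition; eps keeps it invertible.
    vals, vecs = np.linalg.eigh(mat)
    vals = np.maximum(vals, 0.0) + eps
    return (vecs * vals ** power) @ vecs.T  # V diag(vals^power) V^T


class ShampooMatrix:
    """Single-matrix Shampoo step (no blocking, grafting, or distribution)."""

    def __init__(self, shape: tuple[int, int], lr: float = 0.1, init: float = 1e-4):
        m, n = shape
        self.lr = lr
        self.left = init * np.eye(m)   # accumulates G G^T  (m x m)
        self.right = init * np.eye(n)  # accumulates G^T G  (n x n)

    def step(self, w: np.ndarray, grad: np.ndarray) -> np.ndarray:
        self.left += grad @ grad.T
        self.right += grad.T @ grad
        precond = matrix_power_sym(self.left, -0.25) @ grad @ matrix_power_sym(self.right, -0.25)
        return w - self.lr * precond


# Toy problem: minimize 0.5 * ||W - T||^2, whose gradient is W - T.
rng = np.random.default_rng(0)
target = rng.normal(size=(4, 3))
w = np.zeros((4, 3))
opt = ShampooMatrix(w.shape)
for _ in range(50):
    w = opt.step(w, w - target)
print("final loss:", 0.5 * float(np.sum((w - target) ** 2)))
```

Unlike AdaGrad's per-coordinate scaling, the two Kronecker-factored preconditioners capture correlations between rows and between columns of W, at the price of the matrix roots.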
3. Mixture of Experts (MoE)
Adaptation of MoE for online learning:
Sparse gating network to select a subset of experts for each example
Experts are full DNN models
Challenges addressed: expert collapse, input/output normalization, data freshness
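The sparse gating idea fits in a few lines. In this sketch the gate logits are passed in directly (in the model they would come from a small gating network), each expert is any callable standing in for a full DNN, and renormalizing the softmax over the selected top-k experts is one common choice rather than necessarily the paper's.

```python
import math
from typing import Callable, Sequence


def softmax(xs: Sequence[float]) -> list[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_moe(x, gate_logits: Sequence[float], experts: Sequence[Callable], k: int = 2) -> float:
    """Sparse MoE forward pass: evaluate only the k highest-scoring experts."""
    top = sorted(range(len(gate_logits)), key=lambda i: gate_logits[i], reverse=True)[:k]
    weights = softmax([gate_logits[i] for i in top])  # renormalize over selected experts
    return sum(w * experts[i](x) for w, i in zip(weights, top))


# Hypothetical experts; in the paper each expert is a full DNN.
experts = [lambda x: x, lambda x: 2 * x, lambda x: 3 * x, lambda x: 4 * x]
print(top_k_moe(1.0, [0.1, 2.0, -1.0, 1.5], experts, k=2))
```

Only the selected experts are executed, which is what keeps serving cost well below that of running every expert on every example.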
4. Multi-Task Learning
The model is trained to predict multiple related targets:
Main task: click prediction
Auxiliary tasks: conversion prediction, user satisfaction signals
Shared embedding layer and task-specific output layers
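A minimal sketch of the multi-task objective: one shared representation feeds separate linear heads, and the training loss is a weighted sum of per-task logistic losses. The task names, head shapes, and task weights are hypothetical.

```python
import math


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def log_loss(y: int, p: float, eps: float = 1e-12) -> float:
    p = min(max(p, eps), 1.0 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


def multitask_loss(shared, heads, labels, task_weights) -> float:
    """Weighted sum of per-task logistic losses computed from one shared representation.

    shared:       output of the shared embedding layer (a list of floats here)
    heads:        task name -> (weights, bias) of that task's own output layer
    labels:       task name -> 0/1 label
    task_weights: task name -> weight of that task's loss in the total
    """
    total = 0.0
    for task, (w, b) in heads.items():
        logit = sum(wi * si for wi, si in zip(w, shared)) + b
        total += task_weights[task] * log_loss(labels[task], sigmoid(logit))
    return total


heads = {"click": ([0.8, -0.2], 0.1), "conversion": ([0.1, 0.5], -2.0)}
print(multitask_loss([0.3, -1.2], heads, {"click": 1, "conversion": 0},
                     {"click": 1.0, "conversion": 0.5}))
```

Gradients from every task flow into the shared layer, so the auxiliary labels act as extra supervision for the click model.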
Reproducibility Improvements
The paper addresses the challenge of irreproducibility in non-convex optimization:
Identified sources of irreproducibility:
Hardware non-determinism
Algorithmic non-determinism
Data ordering and sampling
Solutions implemented:
Deterministic seeding of PRNGs
Consistent data ordering and sampling
Smooth activation functions (e.g., GELU) instead of ReLU
Lowered learning rates
Interestingly, Google cares about determinism not for debugging, monitoring, or observability, but to reduce the run-to-run variance in model performance metrics.
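Two of the fixes listed above are easy to illustrate. In this sketch, a dedicated, explicitly seeded PRNG makes the data order identical across runs, and GELU is compared with ReLU around zero: GELU's slope changes smoothly, while ReLU's jumps from 0 to 1, so tiny numerical differences between runs near zero can flip ReLU gradients. The function names and seed value are illustrative.

```python
import math
import random


def deterministic_shuffle(items, seed: int) -> list:
    # A dedicated, explicitly seeded PRNG yields the same order on every run,
    # independent of any other code that touches the global random state.
    rng = random.Random(seed)
    out = list(items)
    rng.shuffle(out)
    return out


def relu_grad(x: float) -> float:
    # Derivative of ReLU (taking 0 at x == 0): a hard jump from 0 to 1.
    return 1.0 if x > 0 else 0.0


def gelu(x: float) -> float:
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_grad(x: float) -> float:
    # d/dx [x * Phi(x)] = Phi(x) + x * phi(x), continuous everywhere.
    cdf = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    return cdf + x * pdf


for x in (-1e-6, 1e-6):
    print(f"x={x:+.0e}  relu'={relu_grad(x):.3f}  gelu'={gelu_grad(x):.6f}")
```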
UI Generalization
To improve generalization across different UI treatments, the authors introduce:
Model factorization:
Separate "semantic" and "positional" components of the model
Semantic component: processes query and ad content
Positional component: handles UI-specific features
Allows for better transfer learning across UI configurations
Constrained learning:
Add constraints to force the model to learn specific behaviors
Example: Monotonicity constraints for ad position effects
Results: Improved generalization performance across UI treatments and reduced system irreproducibility
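One way to combine model factorization with a monotonicity constraint, sketched under assumptions (an additive split of the logit and a softplus parameterization; the paper's exact form may differ): the semantic part produces a position-independent logit, and the positional part adds a bias that is non-increasing in slot index, so the model cannot learn that a lower ad slot raises CTR. All names and values are hypothetical.

```python
import math


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def softplus(x: float) -> float:
    # Numerically stable log(1 + e^x); always > 0.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def positional_biases(raw_deltas: list[float]) -> list[float]:
    """Turn unconstrained parameters into a non-increasing bias per ad slot.

    Slot 0 (top) gets 0; every lower slot subtracts softplus(delta) > 0,
    which enforces the monotonicity constraint by construction.
    """
    biases = [0.0]
    for d in raw_deltas:
        biases.append(biases[-1] - softplus(d))
    return biases


def factorized_ctr(semantic_logit: float, slot: int, raw_deltas: list[float]) -> float:
    # Semantic logit: query/ad relevance. Positional bias: depends only on the UI slot.
    return sigmoid(semantic_logit + positional_biases(raw_deltas)[slot])


raw = [0.3, -1.0, 2.0]  # hypothetical learned parameters for slots 1..3
ctrs = [factorized_ctr(0.5, slot, raw) for slot in range(4)]
print([round(c, 4) for c in ctrs])
```

In this sketch, relevance lives entirely in the semantic logit, so that part can be reused when the page layout changes while only the small positional table adapts.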
Bias Constraints
The paper describes a general-purpose technique for adding bias constraints to the model:
Motivation: Improve generalization and system reproducibility
Implementation:
Define a set of constraint functions g(x, y) that should have zero expectation
Add a penalty term to the loss function: λ * (E[g(x, y)])^2
Use online estimates of E[g(x, y)] during training
Applications:
Calibration constraints
Fairness constraints
Domain-specific constraints (e.g., monotonicity)
Results: Improved generalization and reduced system irreproducibility
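The bookkeeping behind the penalty is small. This sketch uses the calibration constraint g(x, y) = y − ŷ and an exponential moving average as the online estimate of E[g(x, y)]; the decay and λ values are hypothetical, and a real trainer would also differentiate the penalty with respect to the model parameters, which is omitted here.

```python
def calibration_g(labels: list[int], preds: list[float]) -> list[float]:
    # Calibration constraint: g(x, y) = y - y_hat should average to zero.
    return [y - p for y, p in zip(labels, preds)]


class OnlineBiasPenalty:
    """Online estimate of E[g(x, y)] and the penalty lambda * E[g(x, y)]^2."""

    def __init__(self, lam: float = 1.0, decay: float = 0.99):
        self.lam = lam
        self.decay = decay
        self.mean_g = 0.0  # exponential moving average of g over batches

    def update(self, g_batch: list[float]) -> float:
        batch_mean = sum(g_batch) / len(g_batch)
        self.mean_g = self.decay * self.mean_g + (1.0 - self.decay) * batch_mean
        return self.lam * self.mean_g ** 2


penalty = OnlineBiasPenalty()
# A calibrated batch (mean g = 0) adds nothing; a batch that under-predicts clicks does.
print(penalty.update(calibration_g([1, 0], [0.5, 0.5])))
print(penalty.update(calibration_g([1, 1], [0.5, 0.5])))
```

The penalty term is added to the logistic loss, so training is nudged toward predictions whose average bias on the constrained slice is zero.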
Empirical Results and Comparisons
How do all of these accuracy and efficiency techniques fare in terms of model ROI, you might ask?
The paper provides a comparison of the relative impact of various techniques:
Baseline (naive scaling): 1.000x accuracy, 1.000x training cost
Bottlenecks: 0.999x accuracy, 0.930x training cost
AutoML: 1.001x accuracy, 0.920x training cost
Sampling: 1.000x accuracy, 0.700x training cost
Distillation: 1.005x accuracy, 1.000x training cost
Second-order optimization: 1.005x accuracy, 1.000x training cost
Mixture of Experts: 1.015x accuracy, 1.100x training cost
Multi-task learning: 1.007x accuracy, 1.050x training cost
This is an interesting way to look at the results. Some techniques improve accuracy, either at no extra training cost (distillation, second-order optimization) or at a higher one (MoE, multi-task learning), while others, like sampling, are pure efficiency optimizations that cut cost without much degradation in accuracy. Some, like AutoML, provide little to no accuracy improvement but a significant gain in exploration efficiency.
I like how methodically the paper explores different techniques, weighing the return each one delivers against the investment it requires.
Libraries
STORM (Synthesis of Topic Outlines through Retrieval and Multi-perspective Question Asking) is an LLM system that writes Wikipedia-like articles from scratch based on Internet search.
While the system cannot produce publication-ready articles (its output often requires a significant number of edits), experienced Wikipedia editors have found it helpful in their pre-writing stage.
local-gemma provides an easy and fast way to run Gemma-2 locally, directly from your CLI or via a Python library. It is built on top of the 🤗 Transformers and bitsandbytes libraries.
It can be configured to give fully equivalent results to the original implementation, or reduce memory requirements down to just the largest layer in the model!
KaTeX is a fast, easy-to-use JavaScript library for TeX math rendering on the web.
Fast: KaTeX renders its math synchronously and doesn't need to reflow the page.
Print quality: KaTeX's layout is based on Donald Knuth's TeX, the gold standard for math typesetting.
Self contained: KaTeX has no dependencies and can easily be bundled with your website resources.
Server side rendering: KaTeX produces the same output regardless of browser or environment, so you can pre-render expressions using Node.js and send them as plain HTML.
Official code for XLand-100B: A Large-Scale Multi-Task Dataset for In-Context Reinforcement Learning, which presents two large datasets for in-context RL based on the XLand-MiniGrid environment: XLand-100B and a smaller version, XLand-Trivial-20B. Together, they contain about 3.5B episodes, 130B transitions, and 40,000 unique tasks, which is more than any other dataset currently available in RL. Furthermore, the datasets are unique in that they contain the complete training histories of the base algorithms, rather than just expert transitions or partial replay buffers. With these datasets, the authors aim to democratize research in the rapidly growing field of in-context RL and provide a solid foundation for further scaling.
Korvus is a search SDK that unifies the entire RAG pipeline in a single database query. Built on top of Postgres with bindings for Python, JavaScript and Rust, Korvus delivers high-performance, customizable search capabilities with minimal infrastructure concerns. It is an all-in-one, open-source RAG (Retrieval-Augmented Generation) pipeline built for Postgres. It combines LLMs, vector memory, embedding generation, reranking, summarization and custom models into a single query, maximizing performance and simplifying your search architecture.
Korvus stands out by harnessing the full power of Postgres for RAG operations:
Postgres-Native RAG: Korvus leverages Postgres' robust capabilities, allowing you to perform complex RAG operations directly within your database. This approach eliminates the need for external services and API calls, significantly reducing latency and complexity.
Single Query Efficiency: With Korvus, your entire RAG pipeline - from embedding generation to text generation - is executed in a single SQL query. This "one query to rule them all" approach simplifies your architecture and boosts performance.
Scalability and Performance: By building on Postgres, Korvus inherits its excellent scalability and performance characteristics. As your data grows, Korvus grows with it, maintaining high performance even with large datasets.
4M: Massively Multimodal Masked Modeling is a scalable, open-source framework for training any-to-any multimodal foundation models across tens of modalities and tasks.
It uses tokenization and masking to scale to many diverse modalities. Models trained using 4M can perform a wide range of vision tasks, transfer well to unseen tasks and modalities, and are flexible and steerable multimodal generative models. The project page has further details.
Models
Apple released a new family of models called DCLM on Hugging Face. It is a transformer-based LLM with the following specifications:
Architecture: Decoder-only Transformer
Framework: PyTorch with OpenLM
Optimizer: AdamW
Learning Rate: 2e-3 (peak)
Weight Decay: 0.05
Batch Size: 2048 sequences
Sequence Length: 2048 tokens
Total Training Tokens: 2.5T
Hardware: Trained on H100 GPUs