Zalando's Recommendation Systems based on Graph Neural Networks (GNNs)
Google's Jax Guide for PyTorch developers
Articles
Zalando wrote about how they use Graph Neural Networks (GNNs) to enhance the recommendation system for the Zalando Homepage. They leverage GNNs to capture complex interaction patterns between users and content, which improves the relevance of their recommendations.
The main challenge is to predict the Click-Through Rate (CTR) for potential content shown to users on the Zalando Homepage.
The recommendation system is modeled as a bipartite graph with two node types: users and items. The links between these nodes represent user-item interactions such as clicks and views. The task is formulated as a link prediction problem, where the goal is to predict future user-item interactions based on past interactions.
Technical Approach
Rather than directly predicting clicks using a GNN model, Zalando's approach involves:
1. Training a GNN to generate embeddings for users and content based on click prediction tasks.
2. Using these embeddings as additional inputs to the existing production model.
This strategy allows Zalando to leverage GNN capabilities while integrating smoothly with their current infrastructure, avoiding significant operational changes.
Data
The training and evaluation datasets are prepared using the PyTorch Geometric library, which provides functionalities for efficient graph data loading, manipulation, and batching. The datasets are based on user-content activity data at a per-request level, labeled as clicked or not clicked.
Architecture
The GNN architecture is based on GraphSage and operates through message passing and feature aggregation:
1. Initial node features:
- User nodes: Information about recently ordered articles
- Content nodes: Article representations associated with the content
2. Message passing: Nodes send their features to adjacent nodes, potentially transformed by neural network layers.
3. Feature aggregation: Nodes combine incoming features from neighbors using operations like summing or averaging.
4. Embedding generation: As the network depth increases, the GNN captures more distant relationships, effectively generating embeddings for all nodes.
5. Classification: The generated embeddings are passed through a classifier to predict the existence of a "clicked" link between a user and a content node.
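The steps above can be sketched in a few lines of plain Python. This is a minimal illustration of GraphSAGE-style mean aggregation, not Zalando's implementation: the toy graph, the feature values, the scalar weights `w_self` and `w_neigh` (real layers use learned weight matrices), and the dot-product-plus-sigmoid `click_score` standing in for the classifier are all assumptions made for the example.

```python
import math

# Toy bipartite graph: users u0, u1 and content nodes c0, c1, c2.
# All feature values below are made up for illustration.
features = {
    "u0": [1.0, 0.0],
    "u1": [0.0, 1.0],
    "c0": [1.0, 1.0],
    "c1": [0.0, 2.0],
    "c2": [2.0, 0.0],
}
clicks = [("u0", "c0"), ("u0", "c1"), ("u1", "c1"), ("u1", "c2")]


def build_neighbors(links):
    """Build an undirected adjacency list from (user, content) links."""
    neighbors = {}
    for user, content in links:
        neighbors.setdefault(user, []).append(content)
        neighbors.setdefault(content, []).append(user)
    return neighbors


def mean_vector(vectors):
    """Element-wise mean of a non-empty list of equal-length vectors."""
    return [sum(values) / len(vectors) for values in zip(*vectors)]


def sage_layer(node_features, neighbors, w_self=1.0, w_neigh=0.5):
    """One layer: h_v = ReLU(w_self * x_v + w_neigh * mean of neighbor features)."""
    updated = {}
    for node, own in node_features.items():
        incoming = [node_features[n] for n in neighbors.get(node, [])]
        aggregated = mean_vector(incoming) if incoming else [0.0] * len(own)
        updated[node] = [
            max(0.0, w_self * x + w_neigh * a) for x, a in zip(own, aggregated)
        ]
    return updated


def click_score(user_emb, content_emb):
    """Hypothetical classifier: sigmoid of the embedding dot product."""
    dot = sum(u * c for u, c in zip(user_emb, content_emb))
    return 1.0 / (1.0 + math.exp(-dot))


neighbors = build_neighbors(clicks)
hop1 = sage_layer(features, neighbors)  # each node sees its direct neighbors
hop2 = sage_layer(hop1, neighbors)      # stacking layers reaches 2-hop neighbors
score = click_score(hop2["u0"], hop2["c2"])  # u0 and c2 share no direct link
```

Stacking the layer twice is what lets `u0` pick up information from `c2`, which it only reaches through `c1` and `u1`.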
Graph Mini-batching
To handle large-scale data, Zalando employs mini-batch training:
- Sampling subgraphs and computing embeddings in parallel
- Sampling links with neighborhoods of both adjacent nodes
- Using disjoint subgraphs for each mini-batch
- Utilizing disjoint sets of links for message passing and supervision signals to prevent information leakage
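The last point can be illustrated with a small, library-free sketch. The function name `split_links` and the `supervision_ratio` parameter are hypothetical; the article does not say how Zalando performs the split. PyTorch Geometric's `RandomLinkSplit` transform offers a `disjoint_train_ratio` parameter for a similar purpose.

```python
import random


def split_links(links, supervision_ratio=0.3, seed=0):
    """Split links into disjoint message-passing and supervision sets.

    Supervision links are the ones the model must predict. Keeping them
    out of the message-passing set means the model cannot read the
    answer directly from the graph structure.
    """
    rng = random.Random(seed)
    shuffled = list(links)
    rng.shuffle(shuffled)
    n_supervision = max(1, int(len(shuffled) * supervision_ratio))
    supervision = shuffled[:n_supervision]
    message_passing = shuffled[n_supervision:]
    return message_passing, supervision


# Made-up (user, content) links for illustration.
all_links = [(f"u{i % 4}", f"c{i}") for i in range(10)]
message_passing_links, supervision_links = split_links(all_links)
```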
Integration with Production Model
The GNN-generated embeddings for users and content are used as additional features in the existing production model. This approach allows Zalando to:
1. Leverage GNN strengths without major infrastructure changes
2. Retrain the GNN model daily to reflect the latest user-content interactions
3. Address the cold-start problem for new nodes by using initial features and existing 'view' links.
Advantages of GNN Features
GNN-generated embeddings offer several advantages over static features:
1. Task-specific training: Embeddings are trained directly for the click prediction objective.
2. Contextual adaptation: The model encodes both intrinsic properties of content and its evolving relationships with users and other content.
3. Dynamic updates: Features are updated as nodes gain more connections and interactions within the graph.
4. Cold-start problem mitigation: Even new nodes with no clicks can be processed using initial features and 'view' links.
Google wrote a comprehensive guide for PyTorch developers looking to understand and transition to JAX, a high-performance numerical computation library with automatic differentiation support. The post compares JAX to PyTorch, highlighting key differences and similarities while walking through the process of training a simple neural network for the Titanic survival prediction task.
JAX Ecosystem and Modularity
JAX is designed as a flexible, modular ecosystem that focuses on high-performance numerical computation and automatic differentiation. Unlike PyTorch, which provides built-in support for neural networks and optimizers, JAX allows users to bring in their preferred frameworks. The tutorial uses the Flax Neural Network library and the Optax optimization library to demonstrate JAX's capabilities.
Functional Programming in JAX
JAX embraces functional programming, which differs from PyTorch's object-oriented approach. This paradigm focuses on pure functions that don't mutate state or have side effects, always producing the same output for the same input. The benefits of this approach include:
1. Just-In-Time (JIT) compilation for significant speed improvements
2. Easier sharding and parallelization of operations
3. Predictability and reproducibility
Data Loading
Data loading in JAX is similar to PyTorch. Developers can use PyTorch datasets and dataloaders with a simple `collate_fn` that converts batches to JAX's NumPy-like arrays.
Model Definition
The post compares model definition in PyTorch and JAX using Flax's NNX API. Both frameworks use similar structures:
1. `__init__` method to define model layers
2. `forward` method in PyTorch, equivalent to `__call__` in NNX
The syntax and structure are nearly identical, making the transition easier for PyTorch users.
Model Initialization and Usage
Model initialization in NNX is similar to PyTorch, with both frameworks eagerly initializing model parameters. The main difference in NNX is the requirement to pass a pseudorandom number generator (PRNG) key when instantiating the model. This approach aligns with JAX's functional nature, ensuring reproducibility and enabling parallelization.
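The role of the explicit PRNG key can be shown with core JAX alone. This is a minimal sketch, not the post's NNX code: `init_weights` is a hypothetical helper, and the seed and shapes are arbitrary.

```python
import jax
import jax.numpy as jnp


def init_weights(key, shape):
    """Pure function: the same key and shape always give the same weights."""
    return jax.random.normal(key, shape)


root_key = jax.random.PRNGKey(0)             # arbitrary seed
key_a, key_b = jax.random.split(root_key)    # derive independent keys

w1 = init_weights(key_a, (2, 3))
w2 = init_weights(key_a, (2, 3))  # same key, so identical values
w3 = init_weights(key_b, (2, 3))  # different key, so different values

same_key_matches = bool(jnp.array_equal(w1, w2))
different_key_matches = bool(jnp.array_equal(w1, w3))
```

Because randomness is passed in rather than read from hidden global state, the initialization is reproducible and can be split safely across parallel workers.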
Training Step and Backpropagation
The post details the differences in training loops between PyTorch and Flax NNX:
Setup
Both frameworks allow the creation of optimizers and specification of optimization algorithms. NNX simplifies this process by allowing direct passing of the model to the optimizer.
Forward and Backward Pass
The most significant difference between PyTorch and JAX lies in the forward and backward pass implementation:
1. PyTorch uses `loss.backward()` to trigger autograd for gradient computation.
2. JAX uses `nnx.value_and_grad` or `nnx.grad` to create a function that returns gradients of the loss with respect to model parameters.
Optimizer Step
While PyTorch updates weights in-place using `optimizer.step()`, NNX requires explicitly passing the calculated gradients to the optimizer for weight updates.
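The functional pattern can be sketched with core JAX. This example uses `jax.value_and_grad` and a hand-written gradient descent update instead of `nnx.value_and_grad` and an NNX optimizer, so it is an illustration of the idea and not the post's code. The linear model, the toy data, and `LEARNING_RATE` are assumptions.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # arbitrary value chosen for this toy problem


def predict(params, x):
    """Linear model: x @ w + b."""
    return x @ params["w"] + params["b"]


def loss_fn(params, x, y):
    """Mean squared error between predictions and targets."""
    return jnp.mean((predict(params, x) - y) ** 2)


@jax.jit
def train_step(params, x, y):
    # value_and_grad returns the loss and the gradients w.r.t. params.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Gradients are passed explicitly; new params are returned, not mutated.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# Toy data generated from y = 2x + 1.
x = jnp.array([[0.0], [1.0], [2.0], [3.0]])
y = jnp.array([1.0, 3.0, 5.0, 7.0])
params = {"w": jnp.zeros((1,)), "b": jnp.zeros(())}

first_loss = float(loss_fn(params, x, y))
for _ in range(200):
    params, loss = train_step(params, x, y)
final_loss = float(loss_fn(params, x, y))
```

Where PyTorch accumulates gradients on the parameters and `optimizer.step()` reads them implicitly, the gradients here are ordinary return values that the update step receives as arguments.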
Full Training Loop
The post provides full training loop implementations for both PyTorch and JAX/Flax NNX, highlighting their similarities and differences. The main distinctions stem from the object-oriented vs. functional programming approaches.
Performance Benefits
JAX's functional approach enables powerful optimizations:
1. JIT compilation
2. Automatic parallelization
The post demonstrates a significant speedup in training time (from 6.25 minutes to 1.8 minutes on a P100 GPU) by simply adding `@nnx.jit` annotations to functions.
Flax Linen API
While NNX is recommended for new users, the post also introduces the Flax Linen API, which is still widely used in frameworks like MaxText and MaxDiffusion. Linen adheres more closely to pure functional programming principles. The post provides a code example using Linen, highlighting the main differences from NNX.
Libraries
DeepSeek Coder is composed of a series of code language models, each trained from scratch on 2T tokens, with a composition of 87% code and 13% natural language in both English and Chinese. The models come in various sizes, ranging from 1B to 33B parameters. Each model is pre-trained on a project-level code corpus with a window size of 16K and an extra fill-in-the-blank task, to support project-level code completion and infilling. For coding capabilities, DeepSeek Coder achieves state-of-the-art performance among open-source code models on multiple programming languages and various benchmarks.
RecBole is a Python and PyTorch library for reproducing and developing recommendation algorithms in a unified, comprehensive, and efficient framework for research purposes. It includes 91 recommendation algorithms, covering four major categories:
- General Recommendation
- Sequential Recommendation
- Context-aware Recommendation
- Knowledge-based Recommendation
Transformer-based Hybrid Recommender System investigates the effectiveness of three Transformers (BERT, RoBERTa, XLNet) in handling data sparsity and cold start problems in recommender systems. The authors present a Transformer-based hybrid recommender system that predicts missing ratings and extracts semantic embeddings from user reviews to mitigate these issues. They conducted two experiments on the Amazon Beauty product review dataset, focusing on multi-label text classification under sparse data and on recommendation performance in user cold start settings. Their findings show that XLNet outperforms the other Transformers in both tasks, and that the proposed methods outperform the traditional ALS method in data sparsity and cold start scenarios. The study confirms the effectiveness of Transformers under cold start conditions in recommender systems and highlights the need for further study and improvement of the fine-tuning process.
Below the Fold
Apache Answer is Q&A platform software for teams at any scale. Whether it’s a community forum, help center, or knowledge management platform, you can always count on Answer.
Ubicloud is an open source cloud that can run anywhere. Think of it as an open alternative to cloud providers, like what Linux is to proprietary operating systems.
Ubicloud provides IaaS cloud features on bare metal providers, such as Hetzner, Leaseweb, and AWS Bare Metal. You can set it up yourself on these providers or use their managed service.
Bash is great, but when it comes to writing more complex scripts, many people prefer a more convenient programming language. JavaScript is a perfect choice, but the Node.js standard library requires additional hassle before using. The zx package provides useful wrappers around `child_process`, escapes arguments and gives sensible defaults.