X For You algorithm, line by line — Part 15: Phoenix models (the ML core)
Part 15 of the deep dive into xai-org/x-algorithm. The actual neural networks: PhoenixModel (ranking transformer with user+history+candidates in one sequence, candidate isolation, multi-action heads) and PhoenixRetrievalModel (two-tower with transformer user encoder + MLP candidate tower, L2-normalized for ANN search). Hash embeddings, multi-hot action projection, continuous MLPs, post-age bucketing.
We leave Rust behind. The next six sessions cover the Python side: Phoenix (the ML models) and Grox (the LLM-driven content classification). This session reads the actual neural-network code that powers the For You feed's ranking and retrieval.
Two files, both implementing models built on top of the Grok transformer (xAI's open-source LLM architecture, adapted here for recsys). The framework: JAX + Haiku (DeepMind's neural-network library for JAX).
Files covered (1,068 LOC):
phoenix/
├── recsys_model.py (680) PhoenixModel — the ranking transformer
└── recsys_retrieval_model.py (388) PhoenixRetrievalModel — two-tower retrieval
The ranking model (recsys_model.py) is what PhoenixScorer (Session 10) calls per request to get per-candidate action probabilities. It concatenates user + history + candidates into one transformer input and uses candidate isolation (candidates can't attend to each other) so scores are independent.
The retrieval model (recsys_retrieval_model.py) is what PhoenixSource (Session 11) calls to get out-of-network candidates. Two-tower architecture: user tower (transformer-based) + candidate tower (MLP) project both to a shared embedding space; the system does nearest-neighbor search over the corpus.
recsys_model.py (680 lines) — the ranking transformer
License + imports
# Copyright 2026 X.AI Corp.
# ...
import logging
from dataclasses import dataclass
from typing import Any, NamedTuple, Optional, Tuple
import haiku as hk
import jax
import jax.numpy as jnp
from grok import (
TransformerConfig,
Transformer,
layer_norm,
right_anchored_rope_positions,
)
JAX + Haiku stack. haiku as hk for parameterized modules + parameter initialization. The README mentions: "The transformer implementation is ported from the Grok-1 open source release by xAI, adapted for recommendation system use cases." So the actual Transformer class is from grok (covered in Session 17), and this file builds the recsys head on top.
right_anchored_rope_positions — rotary position embeddings, with right-anchoring (recent history is index 0, oldest at the back). Important for how the model encodes recency.
Post age bucketization
POST_AGE_MAX_MINUTES = 4800

def compute_post_age_bucket(
    impr_ts_sec: jax.Array,
    post_creation_ts_sec: jax.Array,
    granularity_mins: int = 60,
) -> jax.Array:
    """Compute post age buckets from impression and creation timestamps."""
    num_normal_buckets = POST_AGE_MAX_MINUTES // granularity_mins
    overflow_bucket = num_normal_buckets + 1
    post_age_minutes = (impr_ts_sec - post_creation_ts_sec) // 60
    bucket = (post_age_minutes // granularity_mins) + 1
    bucket = jnp.clip(bucket, 0, overflow_bucket)
    bucket = jnp.where(
        (post_age_minutes < 0) | (impr_ts_sec == 0) | (post_creation_ts_sec == 0),
        0,
        bucket,
    )
    return bucket.astype(jnp.int32)
Convert (impression_ts, post_creation_ts) into a discrete bucket index.
- POST_AGE_MAX_MINUTES = 4800 = 80 hours. Posts older than this go in the "overflow" bucket.
- granularity_mins = 60 default → 80 buckets of 1 hour each + 1 overflow + 1 invalid.
- bucket = (age_minutes // 60) + 1: index 0 is reserved for missing/invalid.
- jnp.where handles three invalid cases: negative age (clock skew), missing impression timestamp, missing post creation timestamp. All → bucket 0.
Used as a categorical feature: each candidate gets an age bucket index, which is looked up in an embedding table. The model can learn that "post age = 3 hours" carries different relevance than "post age = 48 hours."
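To make the bucket boundaries concrete, here is a minimal scalar sketch of the same arithmetic in plain Python. The name post_age_bucket and the sample timestamps are illustrative and not from the repository; the real function works on JAX arrays and uses jnp.where instead of an early return.

```python
POST_AGE_MAX_MINUTES = 4800


def post_age_bucket(impr_ts_sec: int, created_ts_sec: int, granularity_mins: int = 60) -> int:
    """Scalar version of the bucketing arithmetic, for hand-checking."""
    overflow_bucket = POST_AGE_MAX_MINUTES // granularity_mins + 1
    age_minutes = (impr_ts_sec - created_ts_sec) // 60
    # Bucket 0 is reserved for negative ages and missing (zero) timestamps.
    if age_minutes < 0 or impr_ts_sec == 0 or created_ts_sec == 0:
        return 0
    bucket = age_minutes // granularity_mins + 1
    return min(bucket, overflow_bucket)


HOUR = 3600
now = 1_000_000  # arbitrary impression time in seconds (illustrative)
buckets = {
    "30 minutes": post_age_bucket(now, now - HOUR // 2),        # first normal bucket
    "3 hours": post_age_bucket(now, now - 3 * HOUR),
    "48 hours": post_age_bucket(now, now - 48 * HOUR),
    "79h59m": post_age_bucket(now, now - 80 * HOUR + 60),       # last normal bucket
    "80 hours": post_age_bucket(now, now - 80 * HOUR),          # overflow
    "created in the future": post_age_bucket(now, now + HOUR),  # invalid
    "missing creation time": post_age_bucket(now, 0),           # invalid
}
print(buckets)
```

With the default granularity, normal buckets run from 1 to 80, overflow is 81, and 0 means invalid, which accounts for the 82 rows the embedding table needs.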
Continuous-value normalization
@dataclass
class NormConfig:
    """Configuration for continuous value normalization."""
    norm_scale: float = 30.0
    use_log: bool = False

@dataclass
class ContinuousActionConfig:
    """Configuration for a single continuous action loss (e.g., dwell time)."""
    loss_weight: float = 0.0
    loss_type: str = "mae"
    tweedie_power: float = 1.5
    norm_config: NormConfig = None  # type: ignore

    def __post_init__(self):
        if self.norm_config is None:
            self.norm_config = NormConfig()

def normalize_continuous_value(
    values: jnp.ndarray,
    config: NormConfig,
) -> jnp.ndarray:
    """Normalize continuous values to 0-1 range."""
    values_clamped = jnp.clip(values, 0.0, config.norm_scale)
    if config.use_log:
        return jnp.log1p(values_clamped) / jnp.log1p(config.norm_scale)
    else:
        return values_clamped / config.norm_scale
Continuous actions (like dwell time: how long the user looked at the post) need normalization to a [0, 1] range before feeding into the model.
Two modes:
- Linear: clip(x, 0, scale) / scale. A simple ratio.
- Log: log1p(clip(x, 0, scale)) / log1p(scale). Compresses the distribution. Useful when values span orders of magnitude (3s vs 60s vs 600s for dwell time).
In both modes the clip runs first, so with the default norm_scale = 30.0 any value above 30 maps to exactly 1.0.
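A small scalar sketch of the two modes, using only the standard library, shows how differently they spread values below the clip point. The helper name normalize and the sample inputs are illustrative; the defaults mirror NormConfig.

```python
import math


def normalize(value: float, norm_scale: float = 30.0, use_log: bool = False) -> float:
    """Scalar version of normalize_continuous_value."""
    clamped = min(max(value, 0.0), norm_scale)  # same role as jnp.clip
    if use_log:
        return math.log1p(clamped) / math.log1p(norm_scale)
    return clamped / norm_scale


for raw in (0.0, 3.0, 30.0, 600.0):
    linear = normalize(raw)
    logged = normalize(raw, use_log=True)
    print(f"raw={raw:6.1f}  linear={linear:.3f}  log={logged:.3f}")
```

A raw value of 3.0 lands at 0.1 in linear mode but at roughly 0.4 in log mode, so the log variant gives small values more of the output range. Both modes saturate at 1.0 once the input reaches norm_scale.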
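tweedie_power=1.5 in ContinuousActionConfig is for the Tweedie loss, which suits skewed non-negative data with many zeros (a good fit for dwell time: most posts get 0s dwell, a few get a lot). Note that the default loss_type is "mae", so the power is presumably only read when a Tweedie loss is selected.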
loss_weight=0.0 default — continuous loss is opt-in.
Hash configuration
@dataclass
class HashConfig:
"""Configuration for hash-based embeddings."""
num_user_hashes: int = 2
num_item_hashes: int = 2
num_author_hashes: int = 2
num_ip_hashes: int = 0
Multi-hash embeddings: instead of one embedding per user, use multiple hash functions (e.g., 2) to map the user to multiple slots. The final embedding is the combined representation. Reduces collision impact — if two users hash to the same slot on hash 1, they probably hash to different slots on hash 2.
Standard technique from recsys at scale: lets the embedding table be smaller than the user/item count.
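The collision argument can be illustrated with a toy sketch. The hash functions below (plain modulo with two different moduli) are hypothetical stand-ins chosen so the example is easy to verify by hand; the hashing actually used upstream is not part of this file.

```python
def hash_slots(entity_id: int, moduli=(7, 11)) -> tuple:
    """Map an ID to one table slot per hash function (hypothetical hashes)."""
    # +1 keeps slot 0 unused, since the source reserves hash 0 for padding.
    return tuple(entity_id % m + 1 for m in moduli)


user_a = hash_slots(3)
user_b = hash_slots(10)

print(user_a)  # (4, 4)
print(user_b)  # (4, 11)
# The two users share a slot under the first hash but not under the second,
# so their combined (slot_1, slot_2) signatures still differ.
```

Two tables of 7 and 11 usable slots can distinguish up to 77 slot pairs while storing only 18 rows, which is the sense in which the tables can be smaller than the ID space.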
num_ip_hashes = 0 default — IP-based embeddings are opt-in (the IpQueryHydrator from Session 09 hydrates it; the model adds those embeddings if available).
The named tuples — input/output shapes
@dataclass
class RecsysEmbeddings:
"""Container for pre-looked-up embeddings from the embedding tables.
These embeddings are looked up from hash tables before being passed to the model.
The block_*_reduce functions will combine multiple hash embeddings into single representations.
"""
user_embeddings: jax.typing.ArrayLike
history_post_embeddings: jax.typing.ArrayLike
candidate_post_embeddings: jax.typing.ArrayLike
history_author_embeddings: jax.typing.ArrayLike
candidate_author_embeddings: jax.typing.ArrayLike
user_ip_embeddings: Optional[jax.typing.ArrayLike] = None
The pre-looked-up embeddings. The embedding tables themselves live elsewhere (in production: distributed parameter servers; we'll see in Session 16's runners). The model receives the embeddings already fetched.
Shapes:
- user_embeddings: [B, num_user_hashes, D]
- history_post_embeddings: [B, history_seq_len, num_item_hashes, D]
- candidate_post_embeddings: [B, num_candidates, num_item_hashes, D]
- Author equivalents: same layout as the post tensors, with num_author_hashes in place of num_item_hashes.
- IP: [B, num_ip_hashes, D].
class RecsysModelOutput(NamedTuple):
"""Output of the recommendation model."""
logits: jax.Array
continuous_preds: Optional[jax.Array] = None
class RecsysBatch(NamedTuple):
"""Input batch for the recommendation model.
Contains the feature data (hashes, actions, product surfaces) but NOT the embeddings.
Embeddings are passed separately via RecsysEmbeddings.
"""
user_hashes: jax.typing.ArrayLike
history_post_hashes: jax.typing.ArrayLike
history_author_hashes: jax.typing.ArrayLike
history_actions: jax.typing.ArrayLike
history_product_surface: jax.typing.ArrayLike
candidate_post_hashes: jax.typing.ArrayLike
candidate_author_hashes: jax.typing.ArrayLike
candidate_product_surface: jax.typing.ArrayLike
history_continuous_actions: Optional[jax.typing.ArrayLike] = None
candidate_impr_ts: Optional[jax.typing.ArrayLike] = None
candidate_post_creation_ts: Optional[jax.typing.ArrayLike] = None
user_ip_hashes: Optional[jax.typing.ArrayLike] = None
The batch input. Hashes, not raw IDs — by the time the data reaches the model, IDs have been hashed (multiple times — once per hash function). Plus action vectors, product surfaces, and (optional) continuous actions and timestamps.
Why separate Batch (data) and Embeddings (looked-up values)?
- The hashes drive the lookup but the embeddings themselves come from a separate distributed lookup step.
- Batch lives on CPU, embeddings live on accelerators.
- Lets the lookup-vs-compute boundary be explicit.
RecsysModelOutput:
- logits: [B, num_candidates, num_actions]. Raw scores for each discrete action.
- continuous_preds: [B, num_candidates, num_continuous_actions]. Predicted continuous values (dwell time, etc.).
block_user_reduce — combine user hash embeddings
def block_user_reduce(
user_hashes: jnp.ndarray,
user_embeddings: jnp.ndarray,
num_user_hashes: int,
emb_size: int,
embed_init_scale: float = 1.0,
*,
user_ip_embeddings: Optional[jnp.ndarray] = None,
num_ip_hashes: int = 0,
) -> Tuple[jax.Array, jax.Array]:
"""Combine multiple user hash embeddings into a single user representation.
...
"""
B = user_embeddings.shape[0]
D = emb_size
user_embedding = user_embeddings.reshape((B, 1, num_user_hashes * D))
embed_init = hk.initializers.VarianceScaling(embed_init_scale, mode="fan_out")
proj_mat_1 = hk.get_parameter(
"proj_mat_1",
[num_user_hashes * D, D],
dtype=jnp.float32,
init=lambda shape, dtype: embed_init(list(reversed(shape)), dtype).T,
)
user_embedding = jnp.dot(user_embedding.astype(proj_mat_1.dtype), proj_mat_1).astype(
user_embeddings.dtype
)
# hash 0 is reserved for padding)
if user_ip_embeddings is not None and num_ip_hashes > 0:
ip_emb = user_ip_embeddings.reshape((B, num_ip_hashes, D))
ip_emb = jnp.sum(ip_emb, axis=1, keepdims=True) # [B, 1, D]
user_embedding = user_embedding + ip_emb
user_padding_mask = (user_hashes[:, 0] != 0).reshape(B, 1).astype(jnp.bool_)
return user_embedding, user_padding_mask
Combine the num_user_hashes embeddings into a single D-dimensional user vector:
- Reshape [B, num_hashes, D] → [B, 1, num_hashes * D]: concatenate hash embeddings along the last axis.
- Project via a learned [num_hashes * D, D] matrix back to D dimensions.
- Add IP embeddings (if present): sum across IP hashes, add to user embedding.
- Build padding mask: user_hashes[:, 0] != 0. A value of 0 in the first hash slot is the padding sentinel.
The projection allows the model to learn how to weight the multi-hash representation. Could just average them, but a learned projection gives more flexibility.
hk.get_parameter registers a trainable parameter with Haiku. The init=lambda shape, dtype: embed_init(list(reversed(shape)), dtype).T looks like a quirk but has a purpose. Haiku's VarianceScaling reads a 2-D shape as (fan_in, fan_out), and the parameter shape here is [num_user_hashes * D, D], i.e. (input width, output width). Calling the initializer on the reversed shape and transposing the result means mode="fan_out" scales the variance by the projection's input width (num_user_hashes * D) rather than by D.
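The reshape-then-project step is easier to follow with concrete shapes. The sketch below uses NumPy as a stand-in for jax.numpy and a random matrix in place of the learned proj_mat_1, so the numbers are meaningless; only the shapes and the mask logic mirror the source. The sizes and hash values are made up.

```python
import numpy as np

B, H, D = 2, 2, 4  # batch, num_user_hashes, emb_size (illustrative sizes)
rng = np.random.default_rng(0)

user_hashes = np.array([[17, 42], [0, 0]])     # second row is padding
user_embeddings = rng.normal(size=(B, H, D))   # pre-looked-up table rows
proj_mat_1 = rng.normal(size=(H * D, D))       # stand-in for the learned parameter

flat = user_embeddings.reshape(B, 1, H * D)    # concatenate the hash embeddings
user_vector = flat @ proj_mat_1                # [B, 1, D]
user_padding_mask = (user_hashes[:, 0] != 0).reshape(B, 1)

# Projecting the concatenation equals summing one linear map per hash function.
per_hash_sum = sum(
    user_embeddings[:, h] @ proj_mat_1[h * D:(h + 1) * D] for h in range(H)
)

print(user_vector.shape)           # (2, 1, 4)
print(user_padding_mask.tolist())  # [[True], [False]]
```

The last identity is why the learned projection generalizes averaging: a plain mean corresponds to every per-hash block being the identity divided by H.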
block_history_reduce — combine history embeddings
def block_history_reduce(
history_post_hashes: jnp.ndarray,
history_post_embeddings: jnp.ndarray,
history_author_embeddings: jnp.ndarray,
history_product_surface_embeddings: jnp.ndarray,
history_actions_embeddings: jnp.ndarray,
num_item_hashes: int,
num_author_hashes: int,
embed_init_scale: float = 1.0,
*,
history_continuous_embeddings: Optional[jnp.ndarray] = None,
history_post_age_embeddings: Optional[jnp.ndarray] = None,
) -> Tuple[jax.Array, jax.Array]:
"""Combine history embeddings (post, author, actions, product_surface, ...) into sequence.
"""
B, S, _, D = history_post_embeddings.shape
history_post_embeddings_reshaped = history_post_embeddings.reshape((B, S, num_item_hashes * D))
history_author_embeddings_reshaped = history_author_embeddings.reshape(
(B, S, num_author_hashes * D)
)
parts = [
history_post_embeddings_reshaped,
history_author_embeddings_reshaped,
history_actions_embeddings,
history_product_surface_embeddings,
]
if history_continuous_embeddings is not None:
parts.append(history_continuous_embeddings)
if history_post_age_embeddings is not None:
parts.append(history_post_age_embeddings)
post_author_embedding = jnp.concatenate(parts, axis=-1)
embed_init = hk.initializers.VarianceScaling(embed_init_scale, mode="fan_out")
proj_mat_3 = hk.get_parameter(
"proj_mat_3",
[post_author_embedding.shape[-1], D],
dtype=jnp.float32,
init=lambda shape, dtype: embed_init(list(reversed(shape)), dtype).T,
)
history_embedding = jnp.dot(post_author_embedding.astype(proj_mat_3.dtype), proj_mat_3).astype(
post_author_embedding.dtype
)
history_embedding = history_embedding.reshape(B, S, D)
history_padding_mask = (history_post_hashes[:, :, 0] != 0).reshape(B, S)
return history_embedding, history_padding_mask
Same idea but for history: each position in the user's history is a (post, author, action, product surface, [continuous], [age]) tuple. Concatenate all those embeddings along the feature axis, project to D dimensions, output the sequence.
Each history position becomes a single D-dim vector. The transformer then processes the sequence of these vectors.
Optional inputs (history_continuous_embeddings, history_post_age_embeddings) allow extending the history feature set without breaking the API.
history_padding_mask: which positions in the history are valid. Hash 0 marks padding (history is right-padded to history_seq_len).
block_candidate_reduce — combine candidate embeddings
def block_candidate_reduce(
candidate_post_hashes: jnp.ndarray,
candidate_post_embeddings: jnp.ndarray,
candidate_author_embeddings: jnp.ndarray,
candidate_product_surface_embeddings: jnp.ndarray,
num_item_hashes: int,
num_author_hashes: int,
embed_init_scale: float = 1.0,
*,
candidate_post_age_embeddings: Optional[jnp.ndarray] = None,
) -> Tuple[jax.Array, jax.Array]:
"""Combine candidate embeddings (post, author, product_surface, ...) into sequence.
"""
B, C, _, D = candidate_post_embeddings.shape
candidate_post_embeddings_reshaped = candidate_post_embeddings.reshape(
(B, C, num_item_hashes * D)
)
candidate_author_embeddings_reshaped = candidate_author_embeddings.reshape(
(B, C, num_author_hashes * D)
)
parts = [
candidate_post_embeddings_reshaped,
candidate_author_embeddings_reshaped,
candidate_product_surface_embeddings,
]
if candidate_post_age_embeddings is not None:
parts.append(candidate_post_age_embeddings)
post_author_embedding = jnp.concatenate(parts, axis=-1)
embed_init = hk.initializers.VarianceScaling(embed_init_scale, mode="fan_out")
proj_mat_2 = hk.get_parameter(
"proj_mat_2",
[post_author_embedding.shape[-1], D],
dtype=jnp.float32,
init=lambda shape, dtype: embed_init(list(reversed(shape)), dtype).T,
)
candidate_embedding = jnp.dot(
post_author_embedding.astype(proj_mat_2.dtype), proj_mat_2
).astype(post_author_embedding.dtype)
candidate_padding_mask = (candidate_post_hashes[:, :, 0] != 0).reshape(B, C).astype(jnp.bool_)
return candidate_embedding, candidate_padding_mask
Like history but no actions (candidates haven't been acted on yet — they're what we're predicting actions for!) and no continuous embeddings. Otherwise same shape.
Note the three separate projection matrices:
- proj_mat_1 for users.
- proj_mat_2 for candidates.
- proj_mat_3 for history.
So each entity type has its own learned reducer. They're not shared — the model learns different combinations for each.
PhoenixModelConfig
@dataclass
class PhoenixModelConfig:
"""Configuration for the recommendation system model."""
model: TransformerConfig
emb_size: int
num_actions: int
history_seq_len: int = 128
candidate_seq_len: int = 32
name: Optional[str] = None
fprop_dtype: Any = jnp.bfloat16
hash_config: HashConfig = None # type: ignore
product_surface_vocab_size: int = 16
post_age_granularity_mins: int = 60
num_continuous_actions: int = 8
continuous_action_hidden_dim: int = 64
continuous_action_config: ContinuousActionConfig = None # type: ignore
use_ip_address: bool = False
right_anchored_rope: bool = False
mask_neg_feedback_on_negatives: bool = True
_initialized = False
Big config dataclass. The dimensional defaults:
- history_seq_len = 128: process the most recent 128 user actions.
- candidate_seq_len = 32: max 32 candidates per request.
- product_surface_vocab_size = 16: 16 different product surfaces (Home, Notifications, etc.). Each gets its own embedding.
- post_age_granularity_mins = 60: 1-hour age buckets.
- num_continuous_actions = 8: 8 different continuous-value action heads (different durations / engagement signals).
- continuous_action_hidden_dim = 64: small MLP hidden size for continuous embedding.
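fprop_dtype = jnp.bfloat16: bf16 for forward-prop. It halves activation memory relative to f32 and usually raises matmul throughput on accelerators. bf16 keeps f32's 8 exponent bits, so the dynamic range is the same, but it has far fewer mantissa bits, so precision is lower. That is presumably why the parameters in this file are created as float32 and results are cast to fprop_dtype afterwards.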
mask_neg_feedback_on_negatives: bool = True — handle "negative feedback" labels (block, mute, report) carefully during training. The flag's behavior lives in the loss computation (not in this file).
    def __post_init__(self):
        if self.hash_config is None:
            self.hash_config = HashConfig()
        if self.continuous_action_config is None:
            self.continuous_action_config = ContinuousActionConfig()

    @property
    def post_age_vocab_size(self) -> int:
        """Derived vocab size for post age buckets: num_normal + overflow + missing."""
        return (POST_AGE_MAX_MINUTES // self.post_age_granularity_mins) + 2

    def initialize(self):
        self._initialized = True
        return self

    def make(self):
        if not self._initialized:
            logger.warning(f"PhoenixModel {self.name} is not initialized. Initializing.")
            self.initialize()
        return PhoenixModel(
            model=self.model.make(),
            config=self,
            fprop_dtype=self.fprop_dtype,
        )
post_age_vocab_size = (4800 // 60) + 2 = 82 for default granularity. 80 normal buckets + overflow + missing.
initialize() / make() — two-stage construction pattern. initialize() is for any one-time setup (here a no-op, just sets the flag). make() builds the actual model.
PhoenixModel — the ranking model proper
@dataclass
class PhoenixModel(hk.Module):
"""A transformer-based recommendation model for ranking candidates."""
model: Transformer
config: PhoenixModelConfig
fprop_dtype: Any = jnp.bfloat16
name: Optional[str] = None
A Haiku module holding the inner Transformer + config. model is the actual transformer instance, built by self.model.make() inside PhoenixModelConfig.make() and passed in as a dataclass field.
Action embeddings
def _get_action_embeddings(
self,
actions: jax.Array,
) -> jax.Array:
"""Convert multi-hot action vectors to embeddings.
Uses a learned projection matrix to map the signed action vector
to the embedding dimension. This works for any number of actions.
"""
config = self.config
_, _, num_actions = actions.shape
D = config.emb_size
embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
action_projection = hk.get_parameter(
"action_projection",
[num_actions, D],
dtype=jnp.float32,
init=embed_init,
)
actions_signed = (2 * actions - 1).astype(jnp.float32)
action_emb = jnp.dot(actions_signed.astype(action_projection.dtype), action_projection)
valid_mask = jnp.any(actions, axis=-1, keepdims=True)
action_emb = action_emb * valid_mask
return action_emb.astype(self.fprop_dtype)
Multi-hot → embedding via a learned [num_actions, D] projection.
The key trick: actions_signed = 2 * actions - 1. This maps {0, 1} → {-1, +1}. Why? Because in a multi-hot vector, 0 would normally mean "no signal" — but in the recsys context, a 0 means "this action did NOT happen" which is itself a signal.
By using {-1, +1}, both happened and not-happened actions contribute (with opposite signs) to the embedding. The model learns directional features: "this position had favorite but not retweet" produces a specific embedding direction.
valid_mask = jnp.any(actions, axis=-1, keepdims=True) — if all actions are 0 (padding row), zero out the embedding. Padding has no signal.
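A tiny worked example makes the signed projection and the padding mask easier to check by hand. NumPy stands in for jax.numpy here, and the 3-action, 2-dimensional projection matrix is made up for illustration; in the model it is a learned parameter.

```python
import numpy as np

# One user, three history positions, three possible actions (multi-hot).
actions = np.array([[[1, 0, 0],     # only action 0 happened
                     [1, 1, 0],     # actions 0 and 1 happened
                     [0, 0, 0]]],   # padding row: nothing recorded
                   dtype=np.float32)

# Hypothetical [num_actions, D] projection with D = 2.
action_projection = np.array([[1.0, 0.0],
                              [0.0, 1.0],
                              [2.0, 2.0]], dtype=np.float32)

actions_signed = 2 * actions - 1                   # {0, 1} -> {-1, +1}
action_emb = actions_signed @ action_projection    # [1, 3, 2]

valid_mask = np.any(actions, axis=-1, keepdims=True)
action_emb = action_emb * valid_mask               # zero out the padding row

print(action_emb[0])
```

The first two rows come out as [-1, -3] and [-1, -1]: they differ only because action 1 flipped from "did not happen" to "happened". Without the mask the padding row would be [-3, -3], which is why the all-zero check is needed.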
Single-hot to embedding (lookup table)
def _single_hot_to_embeddings(
self,
input: jax.Array,
vocab_size: int,
emb_size: int,
name: str,
) -> jax.Array:
"""Convert single-hot indices to embeddings via lookup table.
...
"""
embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
embedding_table = hk.get_parameter(
name,
[vocab_size, emb_size],
dtype=jnp.float32,
init=embed_init,
)
input_one_hot = jax.nn.one_hot(input, vocab_size)
output = jnp.dot(input_one_hot, embedding_table)
return output.astype(self.fprop_dtype)
Standard embedding lookup. Done via one-hot + matrix multiply instead of take/gather. On TPU/GPU, the one-hot-and-matmul approach is often faster than indexed gather. Tradeoff: memory (one_hot materializes a [..., vocab_size] tensor).
Used for product surfaces and post age buckets.
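That the one-hot matmul really is a lookup can be confirmed with a few lines. This sketch uses NumPy in place of jax.numpy and builds the one-hot rows by indexing an identity matrix, which is an assumption of the example rather than how jax.nn.one_hot is implemented.

```python
import numpy as np

vocab_size, emb_size = 5, 3
# A recognizable table: row i holds the values 3i, 3i + 1, 3i + 2.
embedding_table = np.arange(vocab_size * emb_size, dtype=np.float32).reshape(
    vocab_size, emb_size
)
indices = np.array([[0, 3], [4, 1]])                            # [B=2, S=2]

one_hot = np.eye(vocab_size, dtype=np.float32)[indices]         # [2, 2, 5]
via_matmul = one_hot @ embedding_table                          # [2, 2, 3]
via_gather = embedding_table[indices]                           # [2, 2, 3]

print(np.array_equal(via_matmul, via_gather))  # True
```

The intermediate one_hot tensor is vocab_size wide per position, which is cheap for the small vocabularies used here (16 product surfaces, 82 age buckets) and would not be for a table with millions of rows.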
Unembedding and continuous head
def _get_unembedding(self) -> jax.Array:
"""Get the unembedding matrix for decoding to discrete action logits."""
config = self.config
embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
unembed_mat = hk.get_parameter(
"unembeddings",
[config.emb_size, config.num_actions],
dtype=jnp.float32,
init=embed_init,
)
return unembed_mat
def _get_continuous_head(self) -> jax.Array:
config = self.config
embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
continuous_mat = hk.get_parameter(
"continuous_unembeddings",
[config.emb_size, config.num_continuous_actions],
dtype=jnp.float32,
init=embed_init,
)
return continuous_mat
Two output heads:
- Unembedding [D, num_actions]: project transformer output → discrete action logits.
- Continuous head [D, num_continuous_actions]: project → continuous predictions (sigmoid'd to [0, 1] later).
These are not tied with the input embeddings — independent learned parameters. (Tied embeddings save parameters but constrain expressiveness; with so few actions, the constraint isn't worth it.)
Continuous-value embedding (MLP)
def _project_continuous_value_to_embedding(
self,
values: jnp.ndarray,
D: int,
param_name: str,
norm_config: NormConfig,
hidden_dim: int = 64,
) -> jax.Array:
values_normalized = normalize_continuous_value(values, norm_config)
values_expanded = values_normalized[..., None] # [B, seq_len, 1]
embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
proj1 = hk.get_parameter(
f"{param_name}_proj1",
[1, hidden_dim],
dtype=jnp.float32,
init=lambda shape, dtype: embed_init(list(reversed(shape)), dtype).T,
)
hidden = jnp.dot(values_expanded.astype(proj1.dtype), proj1)
hidden = jax.nn.gelu(hidden)
proj2 = hk.get_parameter(
f"{param_name}_proj2",
[hidden_dim, D],
dtype=jnp.float32,
init=lambda shape, dtype: embed_init(list(reversed(shape)), dtype).T,
)
embedding = jnp.dot(hidden, proj2)
return embedding.astype(self.fprop_dtype)
Continuous scalar → D-dim embedding via a 2-layer MLP with GELU activation. Pipeline:
- Normalize to [0, 1].
- Expand to [..., 1] (add a feature axis).
- Project to hidden_dim with weight [1, hidden_dim].
- GELU activation.
- Project to D with weight [hidden_dim, D].
So a single dwell-time value becomes a D-dim embedding via a small MLP. Used for history dwell time specifically (we'll see in build_inputs below).
GELU (Gaussian Error Linear Unit) — smoother than ReLU, common in transformers.
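The scalar-to-embedding path can be sketched in NumPy as follows. The random weights stand in for the learned proj1 and proj2 parameters, the sizes are illustrative, and the GELU here is the tanh approximation written out by hand, so treat it as a shape-level sketch rather than a numerical reproduction.

```python
import numpy as np


def gelu(x: np.ndarray) -> np.ndarray:
    """Tanh approximation of GELU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def scalar_to_embedding(values, proj1, proj2, norm_scale=30.0):
    """[B, S] scalars -> [B, S, D] embeddings via a bias-free 2-layer MLP."""
    normalized = np.clip(values, 0.0, norm_scale) / norm_scale  # linear mode
    hidden = gelu(normalized[..., None] @ proj1)                # [B, S, hidden_dim]
    return hidden @ proj2                                       # [B, S, D]


rng = np.random.default_rng(0)
hidden_dim, D = 8, 4                             # illustrative sizes
proj1 = rng.normal(size=(1, hidden_dim))         # stand-in for "<name>_proj1"
proj2 = rng.normal(size=(hidden_dim, D))         # stand-in for "<name>_proj2"

dwell = np.array([[3.0, 12.0], [0.0, 0.0]])      # second row: no dwell recorded
embeddings = scalar_to_embedding(dwell, proj1, proj2)

print(embeddings.shape)             # (2, 2, 4)
print(np.all(embeddings[1] == 0))   # True
```

Because neither layer has a bias and GELU(0) = 0, a zero dwell value maps to an all-zero embedding regardless of the weights.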
build_inputs — assemble the transformer input
This is the glue function that builds the full input sequence for the transformer.
def build_inputs(
self,
batch: RecsysBatch,
recsys_embeddings: RecsysEmbeddings,
) -> Tuple[jax.Array, jax.Array, int]:
"""Build input embeddings from batch and pre-looked-up embeddings.
...
Returns:
embeddings: [B, 1 + history_len + num_candidates, D]
padding_mask: [B, 1 + history_len + num_candidates]
candidate_start_offset: int - position where candidates start
"""
config = self.config
hash_config = config.hash_config
history_product_surface_embeddings = self._single_hot_to_embeddings(
batch.history_product_surface, # type: ignore
config.product_surface_vocab_size,
config.emb_size,
"product_surface_embedding_table",
)
candidate_product_surface_embeddings = self._single_hot_to_embeddings(
batch.candidate_product_surface, # type: ignore
config.product_surface_vocab_size,
config.emb_size,
"product_surface_embedding_table",
)
Lookup product surface embeddings for history and candidates. Same embedding table for both (name="product_surface_embedding_table" is identical in both calls). This is parameter sharing — a "Home" surface has the same embedding whether it appeared in history or candidates.
history_actions_embeddings = self._get_action_embeddings(batch.history_actions) # type: ignore
B_size = batch.history_product_surface.shape[0] # type: ignore
S_size = batch.history_product_surface.shape[1] # type: ignore
if batch.history_continuous_actions is not None:
dwell_values = batch.history_continuous_actions[:, :, 1] # index 1 = dwell_time
else:
dwell_values = jnp.zeros((B_size, S_size), dtype=jnp.float32)
history_continuous_embeddings = self._project_continuous_value_to_embedding(
dwell_values,
config.emb_size,
"history_dwell_time",
config.continuous_action_config.norm_config,
config.continuous_action_hidden_dim,
)
Build action and dwell-time embeddings for history. Index 1 = dwell_time in the continuous actions tensor (the comment confirms this). Only this one continuous feature gets a learned embedding (the others are predicted but not fed back as features).
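If continuous actions weren't provided (during retrieval or some inference paths), use zeros. Because the MLP has no bias terms and GELU(0) = 0, a zero input yields an exactly zero embedding, so the dwell-time slot contributes nothing to the history projection.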
user_embeddings, user_padding_mask = block_user_reduce(
batch.user_hashes, # type: ignore
recsys_embeddings.user_embeddings, # type: ignore
hash_config.num_user_hashes,
config.emb_size,
1.0,
user_ip_embeddings=recsys_embeddings.user_ip_embeddings,
num_ip_hashes=hash_config.num_ip_hashes,
)
history_embeddings, history_padding_mask = block_history_reduce(
batch.history_post_hashes, # type: ignore
recsys_embeddings.history_post_embeddings, # type: ignore
recsys_embeddings.history_author_embeddings, # type: ignore
history_product_surface_embeddings,
history_actions_embeddings,
hash_config.num_item_hashes,
hash_config.num_author_hashes,
1.0,
history_continuous_embeddings=history_continuous_embeddings,
)
Call the first two of the three reduce blocks (user, then history). Each returns (embedding, padding_mask).
C_size = batch.candidate_product_surface.shape[1] # type: ignore
if batch.candidate_impr_ts is not None and batch.candidate_post_creation_ts is not None:
post_age_buckets = compute_post_age_bucket(
batch.candidate_impr_ts,
batch.candidate_post_creation_ts,
config.post_age_granularity_mins,
)
else:
post_age_buckets = jnp.zeros((B_size, C_size), dtype=jnp.int32)
candidate_post_age_embeddings = self._single_hot_to_embeddings(
post_age_buckets,
config.post_age_vocab_size,
config.emb_size,
"post_age_embedding_table",
)
candidate_embeddings, candidate_padding_mask = block_candidate_reduce(
batch.candidate_post_hashes, # type: ignore
recsys_embeddings.candidate_post_embeddings, # type: ignore
recsys_embeddings.candidate_author_embeddings, # type: ignore
candidate_product_surface_embeddings,
hash_config.num_item_hashes,
hash_config.num_author_hashes,
1.0,
candidate_post_age_embeddings=candidate_post_age_embeddings,
)
Compute candidate post age buckets and embed them. Then reduce candidates (with the age embeddings).
Note: post age is only computed for candidates, not history. History positions already have a "natural" recency order via the sequence position.
embeddings = jnp.concatenate(
[user_embeddings, history_embeddings, candidate_embeddings], axis=1
)
padding_mask = jnp.concatenate(
[user_padding_mask, history_padding_mask, candidate_padding_mask], axis=1
)
candidate_start_offset = user_padding_mask.shape[1] + history_padding_mask.shape[1]
return embeddings.astype(self.fprop_dtype), padding_mask, candidate_start_offset
The big concatenation. Shape: [B, 1 + S + C, D] where:
- 1 = user token
- S = history sequence length
- C = candidate count
candidate_start_offset tells the transformer where candidates begin — used for candidate isolation below.
This is the core architecture: a single sequence containing user + history + candidates, processed by one transformer.
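The layout of the combined sequence, and how candidate_start_offset recovers the candidate positions afterwards, can be sketched with dummy arrays. NumPy stands in for jax.numpy and the sizes are illustrative, not the config defaults.

```python
import numpy as np

B, S, C, D = 2, 5, 3, 4  # batch, history length, candidates, emb size (illustrative)

# Constant fills make each segment easy to recognize after concatenation.
user_embeddings = np.zeros((B, 1, D))
history_embeddings = np.ones((B, S, D))
candidate_embeddings = np.full((B, C, D), 2.0)

embeddings = np.concatenate(
    [user_embeddings, history_embeddings, candidate_embeddings], axis=1
)
candidate_start_offset = user_embeddings.shape[1] + history_embeddings.shape[1]

# After the transformer, the same offset slices the candidate outputs back out.
candidate_slice = embeddings[:, candidate_start_offset:, :]

print(embeddings.shape)         # (2, 9, 4)
print(candidate_start_offset)   # 6
print(candidate_slice.shape)    # (2, 3, 4)
```

Since the offset is derived from static shapes, it is a plain Python int rather than a traced value, which is what allows it to be used as a slice bound.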
__call__ — forward pass
def __call__(
self,
batch: RecsysBatch,
recsys_embeddings: RecsysEmbeddings,
) -> RecsysModelOutput:
"""Forward pass for ranking candidates.
...
"""
embeddings, padding_mask, candidate_start_offset = self.build_inputs(
batch, recsys_embeddings
)
positions = None
if self.config.right_anchored_rope:
positions = right_anchored_rope_positions(
padding_mask,
history_seq_len=self.config.history_seq_len,
num_user_prefix_tokens=1,
)
# transformer
model_output = self.model(
embeddings,
padding_mask,
candidate_start_offset=candidate_start_offset,
positions=positions,
)
out_embeddings = model_output.embeddings
out_embeddings = layer_norm(out_embeddings)
candidate_embeddings = out_embeddings[:, candidate_start_offset:, :]
unembeddings = self._get_unembedding()
logits = jnp.dot(candidate_embeddings.astype(unembeddings.dtype), unembeddings)
logits = logits.astype(self.fprop_dtype)
continuous_mat = self._get_continuous_head()
continuous_logits = jnp.dot(
candidate_embeddings.astype(continuous_mat.dtype), continuous_mat
)
continuous_preds = jax.nn.sigmoid(continuous_logits).astype(self.fprop_dtype)
return RecsysModelOutput(logits=logits, continuous_preds=continuous_preds)
The full forward pass:
- Build inputs (we just walked this).
- Compute positions (if right-anchored RoPE is enabled). Position 0 anchors at the most-recent history item; older history positions get progressively higher indices. Right-anchored = the model sees recent history with low position indices regardless of padding.
- Transformer call. The Transformer receives candidate_start_offset and is expected to apply candidate isolation. Per the README, candidates cannot attend to each other, so the attention mask must let each candidate attend to user + history but not to the other candidates. We'll see exactly how when we read grok.py in Session 17.
- LayerNorm the output.
- Slice to keep only candidate positions (the first candidate_start_offset positions are user + history outputs, which we discard).
- Unembed: project to action logits [B, C, num_actions].
- Continuous predictions: project to continuous logits, then sigmoid to [0, 1].
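The output: per-candidate, per-action logits plus continuous predictions in [0, 1]. The logits are raw scores; no sigmoid is applied to them in this function, so the conversion to the per-action probabilities that PhoenixScorer (Session 10) stuffs into candidate.phoenix_scores presumably happens downstream.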
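Candidate isolation can be pictured as a boolean attention mask. The sketch below is an assumption about what such a mask could look like, not the code in grok.py: it assumes causal attention over the user + history prefix and lets each candidate attend to that prefix and to itself only. The function name and the NumPy implementation are illustrative.

```python
import numpy as np


def candidate_isolation_mask(seq_len: int, candidate_start_offset: int) -> np.ndarray:
    """allowed[q, k] is True when query position q may attend to key position k."""
    # Assumed base: causal attention (each position sees itself and earlier ones).
    allowed = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    # Candidates must not see other candidates...
    allowed[candidate_start_offset:, candidate_start_offset:] = False
    # ...but each candidate still attends to itself (assumption of this sketch).
    idx = np.arange(candidate_start_offset, seq_len)
    allowed[idx, idx] = True
    return allowed


# 1 user token + 3 history items + 2 candidates.
mask = candidate_isolation_mask(seq_len=6, candidate_start_offset=4)
print(mask.astype(int))
```

Under this mask the output for candidate 0 is unaffected by whether candidate 1 is in the batch, which is the property that makes per-candidate scores independent.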
End of recsys_model.py.
recsys_retrieval_model.py (388 lines) — two-tower retrieval
The retrieval model. Different goal: instead of scoring a small set of candidates, find the top K from a large corpus via nearest-neighbor search.
Imports
import logging
from dataclasses import dataclass
from typing import Any, NamedTuple, Optional, Tuple
import haiku as hk
import jax
import jax.numpy as jnp
from grok import TransformerConfig, Transformer
from recsys_model import (
HashConfig,
RecsysBatch,
RecsysEmbeddings,
block_history_reduce,
block_user_reduce,
)
logger = logging.getLogger(__name__)
EPS = 1e-12
INF = 1e12
Reuses the structural blocks from recsys_model.py: HashConfig, RecsysBatch, RecsysEmbeddings, plus the user and history reduce functions.
Two numerical constants: EPS for safe division, INF for masking (set masked-out scores to -INF).
RetrievalOutput
class RetrievalOutput(NamedTuple):
"""Output of the retrieval model."""
user_representation: jax.Array
top_k_indices: jax.Array
top_k_scores: jax.Array
The output is the user vector + top-K indices into the corpus + their scores. Indices, not the actual embeddings — the caller looks them up.
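How EPS, INF, and the top-K output could fit together is sketched below. This is a brute-force illustration under stated assumptions, not the retrieval code from the file: NumPy stands in for jax.numpy, the vectors and the validity flags are made up, and a production system would use an approximate nearest-neighbor index instead of a full argsort.

```python
import numpy as np

EPS = 1e-12
INF = 1e12


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale vectors to unit length; EPS keeps all-zero rows from dividing by zero."""
    norm_sq = np.sum(x**2, axis=-1, keepdims=True)
    return x / np.sqrt(np.maximum(norm_sq, EPS))


user = l2_normalize(np.array([3.0, 4.0]))
corpus = l2_normalize(np.array([[6.0, 8.0],     # same direction as the user
                                [1.0, 0.0],
                                [-3.0, -4.0],   # opposite direction
                                [0.0, 0.0]]))   # padding row
corpus_valid = np.array([True, True, True, False])

# With unit-length vectors the dot product is the cosine similarity.
scores = corpus @ user
scores = np.where(corpus_valid, scores, -INF)   # masked rows can never win

k = 2
top_k_indices = np.argsort(-scores)[:k]
top_k_scores = scores[top_k_indices]

print(top_k_indices)  # [0 1]
print(top_k_scores)
```

The two best matches score about 1.0 and 0.6, the opposite-direction row scores about -1.0, and the padding row is pushed to -INF so it cannot appear in the top K.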
CandidateTower
@dataclass
class CandidateTower(hk.Module):
"""Candidate tower that projects post+author embeddings to a shared embedding space.
This tower takes the concatenated embeddings of a post and its author,
and projects them to a normalized representation suitable for similarity search.
Supports two modes:
- enable_linear_proj=True: Two-layer MLP (SiLU) projection followed by L2 normalization.
- enable_linear_proj=False: Simple mean pooling across hash embeddings followed by
L2 normalization. More parameter-efficient but less expressive.
"""
emb_size: int
enable_linear_proj: bool = True
name: Optional[str] = None
def __call__(self, post_author_embedding: jax.Array) -> jax.Array:
"""Project post+author embeddings to normalized representation.
...
"""
if not self.enable_linear_proj:
candidate_representation = jnp.mean(post_author_embedding, axis=-2)
candidate_norm_sq = jnp.sum(candidate_representation**2, axis=-1, keepdims=True)
candidate_norm = jnp.sqrt(jnp.maximum(candidate_norm_sq, EPS))
candidate_representation = candidate_representation / candidate_norm
return candidate_representation.astype(post_author_embedding.dtype)
The simpler mode (enable_linear_proj=False): just mean-pool across the hash dimension, L2-normalize. Cheap, no learned parameters specific to this tower.
if len(post_author_embedding.shape) == 4:
B, C, _, _ = post_author_embedding.shape
post_author_embedding = jnp.reshape(post_author_embedding, (B, C, -1))
else:
B, _, _ = post_author_embedding.shape
post_author_embedding = jnp.reshape(post_author_embedding, (B, -1))
embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
proj_1 = hk.get_parameter(
"candidate_tower_projection_1",
[post_author_embedding.shape[-1], self.emb_size * 2],
dtype=jnp.float32,
init=embed_init,
)
proj_2 = hk.get_parameter(
"candidate_tower_projection_2",
[self.emb_size * 2, self.emb_size],
dtype=jnp.float32,
init=embed_init,
)
hidden = jnp.dot(post_author_embedding.astype(proj_1.dtype), proj_1)
hidden = jax.nn.silu(hidden)
candidate_embeddings = jnp.dot(hidden.astype(proj_2.dtype), proj_2)
candidate_norm_sq = jnp.sum(candidate_embeddings**2, axis=-1, keepdims=True)
candidate_norm = jnp.sqrt(jnp.maximum(candidate_norm_sq, EPS))
candidate_representation = candidate_embeddings / candidate_norm
return candidate_representation.astype(post_author_embedding.dtype)
The MLP mode (enable_linear_proj=True, default):
- Flatten the hash dimension into the feature axis.
- Project up to 2 * emb_size (expansion).
- SiLU activation (sigmoid-weighted linear unit, smooth like GELU).
- Project down to emb_size.
- L2-normalize.
Shape-handling: works for both [B, C, num_hashes, D] (per-batch candidates) and [B, num_hashes, D] (single candidate per batch, used in corpus precomputation).
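A minimal sketch of the MLP mode in plain jax.numpy, to make the two input shapes concrete. Assumptions: the weights here are random arrays rather than Haiku parameters, the function name candidate_tower_sketch is made up for illustration, and the dtype casts from the real code are left out.

```python
import jax
import jax.numpy as jnp

EPS = 1e-12

def candidate_tower_sketch(post_author_embedding, proj_1, proj_2):
    """Flatten hashes into features, apply a two-layer SiLU MLP, L2-normalize."""
    # Collapse the trailing [num_hashes, D] axes into a single feature axis.
    lead_shape = post_author_embedding.shape[:-2]
    x = post_author_embedding.reshape(lead_shape + (-1,))
    hidden = jax.nn.silu(x @ proj_1)   # expand to 2 * emb_size
    out = hidden @ proj_2              # project down to emb_size
    norm = jnp.sqrt(jnp.maximum(jnp.sum(out**2, axis=-1, keepdims=True), EPS))
    return out / norm

emb_size, num_hashes, D = 8, 4, 8
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
# Random stand-ins for the learned projections (assumption for this sketch).
proj_1 = jax.random.normal(k1, (num_hashes * D, 2 * emb_size))
proj_2 = jax.random.normal(k2, (2 * emb_size, emb_size))

batched = jax.random.normal(k3, (2, 5, num_hashes, D))  # [B, C, num_hashes, D]
single = jax.random.normal(k4, (2, num_hashes, D))      # [B, num_hashes, D]

batched_out = candidate_tower_sketch(batched, proj_1, proj_2)  # [2, 5, 8]
single_out = candidate_tower_sketch(single, proj_1, proj_2)    # [2, 8]
```

Both calls share the same weights because the flattened feature width (num_hashes * D) is identical; only the leading axes differ.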
The L2 normalization is critical: with both user and candidate vectors normalized to unit length, the dot product equals cosine similarity. That lets retrieval use approximate nearest-neighbor libraries such as FAISS with an inner-product index, since ranking unit vectors by inner product is the same as ranking them by cosine similarity.
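A quick numeric check of that equivalence, using two toy vectors chosen for easy arithmetic. The helper l2_normalize mirrors the EPS-guarded pattern from the tower code but is written here just for illustration.

```python
import jax.numpy as jnp

def l2_normalize(x, eps=1e-12):
    # Same guard as the tower: clamp the squared norm before the sqrt.
    norm = jnp.sqrt(jnp.maximum(jnp.sum(x**2, axis=-1, keepdims=True), eps))
    return x / norm

user = jnp.array([3.0, 4.0])
post = jnp.array([4.0, 3.0])

# Textbook cosine similarity: dot product divided by both norms.
cosine = jnp.dot(user, post) / (jnp.linalg.norm(user) * jnp.linalg.norm(post))

# Plain dot product of the already-normalized vectors.
dot_of_unit = jnp.dot(l2_normalize(user), l2_normalize(post))
```

Both values come out to 24 / 25 = 0.96, so once vectors are normalized at write time the serving path only needs a matmul.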
PhoenixRetrievalModelConfig
@dataclass
class PhoenixRetrievalModelConfig:
"""Configuration for the Phoenix Retrieval Model.
This model uses the same transformer architecture as the Phoenix ranker
for encoding user representations.
"""
model: TransformerConfig
emb_size: int
history_seq_len: int = 128
candidate_seq_len: int = 32
name: Optional[str] = None
fprop_dtype: Any = jnp.bfloat16
hash_config: HashConfig = None # type: ignore
product_surface_vocab_size: int = 16
enable_linear_proj: bool = True
_initialized: bool = False
def __post_init__(self):
if self.hash_config is None:
self.hash_config = HashConfig()
Smaller config than the ranker — no num_actions (retrieval doesn't predict actions), no continuous actions, no post age buckets. Just the basics + the candidate-tower mode toggle.
PhoenixRetrievalModel
@dataclass
class PhoenixRetrievalModel(hk.Module):
"""A two-tower retrieval model using the Phoenix transformer for user encoding.
This model implements the two-tower architecture for efficient retrieval:
- User Tower: Encodes user features + history using the Phoenix transformer
- Candidate Tower: Projects candidate embeddings to a shared space
The user and candidate representations are L2-normalized, enabling efficient
approximate nearest neighbor (ANN) search using dot product similarity.
"""
model: Transformer
config: PhoenixRetrievalModelConfig
fprop_dtype: Any = jnp.bfloat16
name: Optional[str] = None
The model. Notice the same Transformer class as the ranker — but used differently (only encodes user, not candidates).
build_user_representation
def build_user_representation(
self,
batch: RecsysBatch,
recsys_embeddings: RecsysEmbeddings,
) -> Tuple[jax.Array, jax.Array]:
"""Build user representation from user features and history.
Uses the Phoenix transformer to encode user + history embeddings
into a single user representation vector.
...
Returns:
user_representation: L2-normalized user embedding [B, D]
user_norm: Pre-normalization L2 norm [B, 1]
"""
config = self.config
hash_config = config.hash_config
history_product_surface_embeddings = self._single_hot_to_embeddings(
batch.history_product_surface, # type: ignore
config.product_surface_vocab_size,
config.emb_size,
"product_surface_embedding_table",
)
history_actions_embeddings = self._get_action_embeddings(batch.history_actions) # type: ignore
user_embeddings, user_padding_mask = block_user_reduce(
batch.user_hashes, # type: ignore
recsys_embeddings.user_embeddings, # type: ignore
hash_config.num_user_hashes,
config.emb_size,
1.0,
)
history_embeddings, history_padding_mask = block_history_reduce(
batch.history_post_hashes, # type: ignore
recsys_embeddings.history_post_embeddings, # type: ignore
recsys_embeddings.history_author_embeddings, # type: ignore
history_product_surface_embeddings,
history_actions_embeddings,
hash_config.num_item_hashes,
hash_config.num_author_hashes,
1.0,
)
embeddings = jnp.concatenate([user_embeddings, history_embeddings], axis=1)
padding_mask = jnp.concatenate([user_padding_mask, history_padding_mask], axis=1)
model_output = self.model(
embeddings.astype(self.fprop_dtype),
padding_mask,
candidate_start_offset=None,
)
Build the same user + history input but no candidates. Pass through the transformer with candidate_start_offset=None. The transformer treats this as plain causal attention (no isolation needed).
user_outputs = model_output.embeddings
mask_float = padding_mask.astype(jnp.float32)[:, :, None] # [B, T, 1]
user_embeddings_masked = user_outputs * mask_float
user_embedding_sum = jnp.sum(user_embeddings_masked, axis=1) # [B, D]
mask_sum = jnp.sum(mask_float, axis=1) # [B, 1]
user_representation = user_embedding_sum / jnp.maximum(mask_sum, 1.0)
user_norm_sq = jnp.sum(user_representation**2, axis=-1, keepdims=True)
user_norm = jnp.sqrt(jnp.maximum(user_norm_sq, EPS))
user_representation = user_representation / user_norm
return user_representation, user_norm
Mean-pool across the sequence (masked):
- Multiply by padding mask (zero out padded positions).
- Sum along the sequence axis.
- Divide by the number of valid positions (max(mask_sum, 1.0) to avoid divide-by-zero).
- L2-normalize.
So the user is represented as the mean of (user token + all history tokens) after the transformer. Different from the ranker, which takes specific candidate-position outputs.
max(mask_sum, 1.0) — for a totally-padded sequence (all-zeros mask), we'd divide by 0. Clip the denominator to 1.0. The numerator is also 0 (all values zeroed), so the result is 0. The L2 normalization would then produce 0 / sqrt(EPS) = 0 — a zero vector. Bad case but won't crash.
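The pooling steps and the all-padded edge case can be reproduced on a tiny input. This is a standalone sketch of the same arithmetic (the function name masked_mean_pool_normalized and the toy values are assumptions), not the model code itself.

```python
import jax.numpy as jnp

EPS = 1e-12

def masked_mean_pool_normalized(outputs, padding_mask):
    mask = padding_mask.astype(jnp.float32)[:, :, None]   # [B, T, 1]
    summed = jnp.sum(outputs * mask, axis=1)              # [B, D]
    count = jnp.sum(mask, axis=1)                         # [B, 1]
    mean = summed / jnp.maximum(count, 1.0)               # guard against 0 valid positions
    norm = jnp.sqrt(jnp.maximum(jnp.sum(mean**2, axis=-1, keepdims=True), EPS))
    return mean / norm

outputs = jnp.array([
    [[3.0, 0.0], [0.0, 4.0], [9.0, 9.0]],  # third position is padding
    [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]],  # every position is padding
])
padding_mask = jnp.array([[True, True, False], [False, False, False]])

pooled = masked_mean_pool_normalized(outputs, padding_mask)
```

Row 0 averages the two valid positions to [1.5, 2.0] and normalizes to [0.6, 0.8]; row 1 stays a zero vector instead of producing NaN.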
build_candidate_representation
def build_candidate_representation(
self,
batch: RecsysBatch,
recsys_embeddings: RecsysEmbeddings,
) -> Tuple[jax.Array, jax.Array]:
"""Build candidate (item) representations.
Projects post + author embeddings to a shared embedding space
using the candidate tower.
...
"""
config = self.config
candidate_post_embeddings = recsys_embeddings.candidate_post_embeddings
candidate_author_embeddings = recsys_embeddings.candidate_author_embeddings
post_author_embedding = jnp.concatenate(
[candidate_post_embeddings, candidate_author_embeddings], axis=2
)
candidate_tower = CandidateTower(
emb_size=config.emb_size,
enable_linear_proj=config.enable_linear_proj,
)
candidate_representation = candidate_tower(post_author_embedding)
candidate_padding_mask = (batch.candidate_post_hashes[:, :, 0] != 0).astype(jnp.bool_) # type: ignore
return candidate_representation, candidate_padding_mask
No transformer for candidates: just concatenate post + author embeddings (along the hash axis, axis 2) and pass them through the CandidateTower MLP. The candidate corpus is huge (millions of posts), so running a transformer per candidate would be too expensive; an MLP keeps the per-item cost low.
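A small shape check for the concatenation and the padding mask. The sizes and hash values below are made up; the only convention taken from the code is that a first hash of 0 marks a padded candidate slot.

```python
import jax.numpy as jnp

B, C, num_item_hashes, num_author_hashes, D = 2, 3, 2, 2, 4
candidate_post_embeddings = jnp.ones((B, C, num_item_hashes, D))
candidate_author_embeddings = jnp.zeros((B, C, num_author_hashes, D))

# Stack post and author hash embeddings along the hash axis (axis 2).
post_author_embedding = jnp.concatenate(
    [candidate_post_embeddings, candidate_author_embeddings], axis=2
)  # [B, C, num_item_hashes + num_author_hashes, D]

# Toy hashes: the last slot of user 0 and the last two of user 1 are padding.
candidate_post_hashes = jnp.array([
    [[11, 12], [21, 22], [0, 0]],
    [[31, 32], [0, 0], [0, 0]],
])
candidate_padding_mask = candidate_post_hashes[:, :, 0] != 0  # [B, C]
```

Only the first hash is inspected, so the mask is valid as long as real items never hash to 0 in that slot.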
This is the classic two-tower trade-off: rich user encoder (transformer), cheap candidate encoder (MLP). The candidate embeddings can be precomputed offline and stored in an ANN index.
__call__ — top-K retrieval
def __call__(
self,
batch: RecsysBatch,
recsys_embeddings: RecsysEmbeddings,
corpus_embeddings: jax.Array,
top_k: int,
corpus_mask: Optional[jax.Array] = None,
) -> RetrievalOutput:
"""Retrieve top-k candidates from corpus for each user.
Args:
batch: RecsysBatch containing hashes, actions, product surfaces
recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
corpus_embeddings: [N, D] normalized corpus candidate embeddings
top_k: Number of candidates to retrieve
corpus_mask: [N] optional mask for valid corpus entries
Returns:
RetrievalOutput containing user representation and top-k results
"""
user_representation, _ = self.build_user_representation(batch, recsys_embeddings)
top_k_indices, top_k_scores = self._retrieve_top_k(
user_representation, corpus_embeddings, top_k, corpus_mask
)
return RetrievalOutput(
user_representation=user_representation,
top_k_indices=top_k_indices,
top_k_scores=top_k_scores,
)
The retrieval entry point. Takes the user batch + a precomputed corpus (shape [N, D], already L2-normalized — typically stored on the device or streamed in chunks).
Computes user representation, then top-K from corpus. Returns indices + scores.
corpus_mask: Optional[jax.Array] — [N] boolean mask. Used to filter out items (e.g., user's own posts, blocked accounts) before the top-K. Could also be used to limit the corpus to a subset (e.g., user-language posts).
_retrieve_top_k
def _retrieve_top_k(
self,
user_representation: jax.Array,
corpus_embeddings: jax.Array,
top_k: int,
corpus_mask: Optional[jax.Array] = None,
) -> Tuple[jax.Array, jax.Array]:
"""Retrieve top-k candidates from a corpus for each user.
...
"""
scores = jnp.matmul(user_representation, corpus_embeddings.T)
if corpus_mask is not None:
scores = jnp.where(corpus_mask[None, :], scores, -INF)
top_k_scores, top_k_indices = jax.lax.top_k(scores, top_k)
return top_k_indices, top_k_scores
Brute-force top-K:
- matmul(user, corpus.T): shape [B, D] × [D, N] = [B, N]. One similarity score per (user, corpus item).
- Mask out invalid items (replace their scores with -INF so they can't win top-K).
- jax.lax.top_k: vectorized top-K, returns sorted scores + indices.
corpus_mask[None, :] broadcasts [N] → [1, N] for compatibility with the [B, N] scores.
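Here is the same logic as a free function on a four-item toy corpus, so the masking is visible in the result. The corpus, users, and mask values are invented for the example.

```python
import jax
import jax.numpy as jnp

INF = 1e12

def retrieve_top_k(user_representation, corpus_embeddings, top_k, corpus_mask=None):
    scores = user_representation @ corpus_embeddings.T        # [B, N]
    if corpus_mask is not None:
        # Masked-out items get a score no real item can fall below.
        scores = jnp.where(corpus_mask[None, :], scores, -INF)
    top_k_scores, top_k_indices = jax.lax.top_k(scores, top_k)
    return top_k_indices, top_k_scores

# Unit-length toy vectors, so scores are cosine similarities.
corpus = jnp.array([[1.0, 0.0], [0.0, 1.0], [0.6, 0.8], [0.8, 0.6]])
users = jnp.array([[1.0, 0.0], [0.0, 1.0]])
corpus_mask = jnp.array([False, True, True, True])  # item 0 is excluded

indices, scores = retrieve_top_k(users, corpus, top_k=2, corpus_mask=corpus_mask)
```

Without the mask, user 0's best match would be item 0 (score 1.0); with it, the top two are items 3 and 2.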
This is dense matmul, not approximate. For a corpus of N=10M items with D=256 and B=256 users in a batch, that's 10M × 256 × 256 ≈ 655 billion multiply-accumulates (about 1.3 TFLOPs at 2 FLOPs each) per batch, which is feasible on a single TPU.
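The arithmetic behind that estimate, spelled out. The sizes are the ones from the paragraph above; the memory line is an extra back-of-the-envelope figure for an f32 score matrix, not a measured number.

```python
N, D, B = 10_000_000, 256, 256

# One multiply-accumulate per (user, item, feature) triple.
macs = N * D * B

# Counting a multiply and an add as separate floating-point operations.
flops = 2 * macs

# The full [B, N] score matrix at 4 bytes per float32 entry.
score_bytes_f32 = B * N * 4
```

That works out to about 0.66 trillion multiply-accumulates, 1.3 trillion FLOPs, and roughly 10 GB for the score matrix if it is materialized in one piece.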
For production though, the corpus is much larger; the actual production setup likely uses partitioned corpus (each retrieval host serves a slice) + chunked matmul (process corpus in chunks of 100K at a time, merge top-K). The model code here is the exact version — clean for training; production wraps it with sharding.
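A sketch of what chunked top-K with a running merge could look like. This is an assumption about the general technique, not the production code: the function name chunked_top_k, the chunk size, and the Python loop are all illustrative.

```python
import jax
import jax.numpy as jnp

def chunked_top_k(user_representation, corpus_embeddings, top_k, chunk_size):
    best_scores, best_indices = None, None
    num_items = corpus_embeddings.shape[0]
    for start in range(0, num_items, chunk_size):
        chunk = corpus_embeddings[start:start + chunk_size]
        scores = user_representation @ chunk.T               # [B, chunk]
        k = min(top_k, chunk.shape[0])
        chunk_scores, chunk_idx = jax.lax.top_k(scores, k)
        chunk_idx = chunk_idx + start                         # local -> global ids
        if best_scores is None:
            cand_scores, cand_idx = chunk_scores, chunk_idx
        else:
            # Merge the running winners with this chunk's winners.
            cand_scores = jnp.concatenate([best_scores, chunk_scores], axis=1)
            cand_idx = jnp.concatenate([best_indices, chunk_idx], axis=1)
        k_merge = min(top_k, cand_scores.shape[1])
        best_scores, pos = jax.lax.top_k(cand_scores, k_merge)
        best_indices = jnp.take_along_axis(cand_idx, pos, axis=1)
    return best_indices, best_scores

key_u, key_c = jax.random.split(jax.random.PRNGKey(0))
users = jax.random.normal(key_u, (3, 16))
corpus = jax.random.normal(key_c, (1000, 16))

# Reference: exact top-K over the whole corpus in one matmul.
exact_scores, exact_indices = jax.lax.top_k(users @ corpus.T, 5)
chunk_indices, chunk_scores = chunked_top_k(users, corpus, top_k=5, chunk_size=128)
```

The merge is exact because any global top-K item must also be in the top-K of its own chunk, so peak memory drops from [B, N] to [B, chunk_size] without changing the result.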
End of recsys_retrieval_model.py.
What we've learned
Two models, one transformer:
- Ranking (PhoenixModel): user + history + candidates → per-candidate per-action probabilities.
- Retrieval (PhoenixRetrievalModel): two-tower: transformer encodes user, MLP encodes candidates → top-K similarity search.
Both share the same Transformer class (from grok, Session 17), HashConfig, RecsysBatch, RecsysEmbeddings, and the reduce blocks. Modular reuse.
The unified sequence input for ranking:
[user_token, history_0, history_1, ..., history_127, candidate_0, candidate_1, ..., candidate_31]
Single transformer pass over 1 + 128 + 32 = 161 tokens. Output for the last 32 positions = per-candidate scores.
Candidate isolation (from the README): candidates can't attend to each other. This makes the score for a candidate independent of the other candidates in the batch, so the same post scored for the same user and history gets the same score no matter which other candidates share the request. Critical for cache validity and reproducibility.
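One way to picture candidate isolation is as an attention mask. The sketch below is an assumption about how such a mask could be built (causal attention over the context, with candidate-to-candidate attention removed except for self-attention); the real mask lives in the grok transformer module and may differ in detail.

```python
import jax.numpy as jnp

def candidate_isolation_mask(seq_len, candidate_start_offset):
    """mask[q, k] is True where query position q may attend to key position k."""
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    is_candidate = jnp.arange(seq_len) >= candidate_start_offset
    # Pairs where both the query and the key are candidate positions.
    cand_to_cand = is_candidate[:, None] & is_candidate[None, :]
    self_attention = jnp.eye(seq_len, dtype=bool)
    # Keep causal links, but drop candidate-to-candidate links other than self.
    return causal & (~cand_to_cand | self_attention)

# Positions 0..2 are user + history; positions 3 and 4 are candidates.
mask = candidate_isolation_mask(seq_len=5, candidate_start_offset=3)
```

In this toy case the second candidate (row 4) sees positions 0, 1, 2 and itself, but not the first candidate at position 3.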
Hash-based embeddings: num_user_hashes = num_item_hashes = num_author_hashes = 2. Each entity contributes 2 embeddings, combined via learned projection. Reduces parameter count + collision impact.
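To see why two hashes soften collisions, here is a toy version. Everything in it is hypothetical: the hash functions, the table size, and the single shared table are stand-ins chosen so that two IDs collide on the first hash but not the second.

```python
import jax
import jax.numpy as jnp

TABLE_SIZE, D, NUM_HASHES = 1000, 4, 2

def toy_hashes(entity_id):
    # Hypothetical hash functions for illustration; slot 0 stays free for padding.
    h1 = (entity_id * 31 + 7) % 999 + 1
    h2 = (entity_id * 17 + 3) % 997 + 1
    return jnp.array([h1, h2])

key_table, key_proj = jax.random.split(jax.random.PRNGKey(0))
table = jax.random.normal(key_table, (TABLE_SIZE, D))
# Stand-in for the learned projection that combines the per-hash embeddings.
proj = jax.random.normal(key_proj, (NUM_HASHES * D, D))

def embed(entity_id):
    rows = table[toy_hashes(entity_id)]   # [NUM_HASHES, D]
    return rows.reshape(-1) @ proj        # [D]

hashes_a = toy_hashes(12345)
hashes_b = toy_hashes(12345 + 999)        # built to collide on the first hash
emb_a, emb_b = embed(12345), embed(12345 + 999)
```

The two IDs share their first table row but not their second, so the projected embeddings still differ.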
The reduce blocks:
- block_user_reduce: combines user hashes + (optional) IP embeddings.
- block_history_reduce: combines history post + author + actions + product_surface + (optional) continuous + (optional) age.
- block_candidate_reduce: combines candidate post + author + product_surface + (optional) age.
Each has its own learned projection (3 different proj_mats — not shared across entity types).
Three distinct embedding strategies:
- Hash lookup (block_*_reduce): users, posts, authors via multiple hashes → projected.
- Single-hot lookup (_single_hot_to_embeddings): product_surface, post_age_bucket; small vocabularies, direct table.
- Continuous MLP (_project_continuous_value_to_embedding): dwell time → MLP-projected to D.
{-1, +1} for multi-hot actions: maps 0 → -1, 1 → +1. Lets both "did happen" and "did not happen" contribute to the embedding via the same learned projection. Standard trick from sparse-feature embedding literature.
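The signed mapping is easy to show in isolation. This sketch assumes a single projection matrix with one row per action; the sizes and the random weights are placeholders, and any padding handling the real model applies is left out.

```python
import jax
import jax.numpy as jnp

num_actions, emb_size = 4, 8
# Random stand-in for the learned action projection (assumption).
action_projection = jax.random.normal(jax.random.PRNGKey(0), (num_actions, emb_size))

# Two history items: the first has actions 0 and 3 set, the second has none.
history_actions = jnp.array([[1.0, 0.0, 0.0, 1.0],
                             [0.0, 0.0, 0.0, 0.0]])

signed_actions = 2.0 * history_actions - 1.0           # 0 -> -1, 1 -> +1
action_embeddings = signed_actions @ action_projection  # [2, emb_size]
```

With a raw {0, 1} encoding the second item would project to the zero vector; with the signed encoding it projects to the negated sum of all action rows, so "nothing happened" is still a signal.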
Two output heads:
- Discrete: per-candidate per-action logits (sigmoid'd later by the caller for probabilities).
- Continuous: per-candidate per-continuous-action, sigmoid'd in-model to [0, 1].
Right-anchored RoPE: position 0 is the most-recent history token. As you go back in history, position increases. Means the model can extrapolate to longer histories (recent is always position 0).
Two-tower decoupling:
- User tower: heavy (transformer + reduce blocks).
- Candidate tower: light (MLP).
- Both end in L2 normalization → cosine similarity via dot product.
- Candidate embeddings can be precomputed → ANN index → cheap retrieval at request time.
Mean-pool for user representation in retrieval: rather than taking a "user token" output (like a [CLS] token in BERT), pool across all valid positions (user + history). Every valid position contributes to the final vector with equal weight.
The output of ranking matches what Session 10's PhoenixScorer consumes: per-candidate phoenix_scores containing favorite_score, reply_score, etc. The action heads in this model output are precisely those.
bf16 forward pass: half the bytes per number with the same exponent range as f32, at the cost of mantissa precision (8 significant bits instead of 24). Saves memory bandwidth; numerical accuracy stays acceptable for inference.
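The range-versus-precision trade can be read straight off jnp.finfo. A small check, with 1.001 picked as an arbitrary value that f32 can represent closely but bf16 cannot:

```python
import jax.numpy as jnp

bf16 = jnp.finfo(jnp.bfloat16)
f32 = jnp.finfo(jnp.float32)

# Both formats top out around 3.4e38, but bf16 has a much coarser step near 1.0.
bf16_max, f32_max = float(bf16.max), float(f32.max)
bf16_eps, f32_eps = float(bf16.eps), float(f32.eps)

# 1.001 sits between bf16's neighbors 1.0 and 1.0078125, so it rounds to 1.0.
x_bf16 = jnp.float32(1.001).astype(jnp.bfloat16)
```

The step size near 1.0 is 2^-7 for bf16 versus 2^-23 for f32, which is why the normalization code above casts back only at the very end.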
Next session
Session 16 — Phoenix runners (1,470 LOC). The orchestration code that takes the models from this session and runs them in production:
- phoenix/runners.py (807): the production runner with checkpoint loading, embedding lookups, batched inference
- phoenix/run_pipeline.py (393): the end-to-end inference entry point (new in this release)
- phoenix/run_ranker.py (121): standalone ranker runner
- phoenix/run_retrieval.py (149): standalone retrieval runner