X For You algorithm, line by line — Part 17: Grok transformer + tests
Part 17 — the final Phoenix session. The Grok-1-derived transformer that powers both ranking and retrieval: candidate isolation attention mask, right-anchored RoPE positions, GQA with tanh-clamping, GeGLU feed-forward, the double-layer-norm DecoderLayer, plus the test suites that pin down the most subtle pieces.
The last Phoenix session. Three files: the actual Grok-1-derived transformer code that does the heavy lifting in both ranking and retrieval, plus the test suites that pin down the behavior of the critical pieces.
Files covered (1,342 LOC):
phoenix/
├── grok.py (616) Transformer + attention + RoPE + recsys mask
├── test_recsys_model.py (309) attention mask + RoPE + bucketing tests
└── test_recsys_retrieval_model.py (417) CandidateTower + retrieval model + runner tests
The README mentions: "The transformer implementation is ported from the Grok-1 open source release by xAI, adapted for recommendation system use cases." This file is what they mean by "adapted" — same transformer architecture as Grok-1 but with two key recsys-specific modifications:
- make_recsys_attn_mask: the candidate-isolation attention mask we keep referencing.
- right_anchored_rope_positions: RoPE positions that anchor the most-recent history token at a fixed index.
We'll spend most time on those, plus the standard transformer machinery.
grok.py (616 lines)
Imports + types
import logging
from dataclasses import dataclass
from typing import NamedTuple, Optional, Sequence, Union

import haiku as hk
import jax
import jax.numpy as jnp

logger = logging.getLogger(__name__)


class TrainingState(NamedTuple):
    """Container for the training state."""

    params: hk.Params
TrainingState is what runners.py (Session 16) wraps loaded params in. Standard Haiku pattern: a (params,) named tuple. Could carry optimizer state too if doing training; here just params.
ffn_size — round up FFN size to multiple of 8
def ffn_size(emb_size, widening_factor):
    _ffn_size = int(widening_factor * emb_size) * 2 // 3
    _ffn_size = _ffn_size + (8 - _ffn_size) % 8  # ensure it's a multiple of 8
    logger.debug(f"emd_size: {emb_size} adjusted ffn_size: {_ffn_size}")
    return _ffn_size
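The FFN (feed-forward network) hidden dimension. The * 2 // 3 is the gated-FFN adjustment. It is the same factor used for SwiGLU, although this file's DenseBlock actually gates with GELU (GeGLU). A gated FFN uses 3 linear projections (gate, up, down) instead of 2. To keep the parameter count comparable to a standard 2-projection FFN with hidden size widening_factor * emb_size, the hidden size is therefore scaled down by 2/3.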
Round up to multiple of 8 for hardware efficiency — matrix-multiply hardware likes dimensions divisible by 8 (or 16, or 64) for tensor-core utilization.
For emb_size=256, widening_factor=2.0: int(512) * 2 // 3 = 341, rounded to 344.
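To check the arithmetic, here is a standalone restatement of ffn_size. It keeps the same two lines of math and drops the debug log. It is evaluated for the 2.0 example above and for the default widening_factor of 4.0.

```python
def ffn_size(emb_size: int, widening_factor: float) -> int:
    """Standalone copy of the ffn_size math from grok.py, without logging."""
    size = int(widening_factor * emb_size) * 2 // 3  # gated-FFN 2/3 adjustment
    # Python's % is non-negative for a positive modulus, so this rounds UP to a multiple of 8.
    return size + (8 - size) % 8


for emb, wf in [(256, 2.0), (256, 4.0), (128, 4.0)]:
    print(emb, wf, ffn_size(emb, wf))  # 344, 688, 344
```

With the default widening_factor of 4.0 and emb_size=256, the hidden size comes out at 688. A vanilla FFN would use 1024.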
make_recsys_attn_mask — the critical recsys modification
def make_recsys_attn_mask(
    seq_len: int,
    candidate_start_offset: int,
    dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
    """Create attention mask for recommendation system inference.

    Creates a mask where:
    - Positions 0 to candidate_start_offset-1 (user+history): causal attention
    - Positions candidate_start_offset onwards (candidates): can attend to user+history
      and themselves (self-attention), but NOT to other candidates

    This ensures each candidate is scored independently based on user+history context.

    Args:
        seq_len: Total sequence length (user + history + candidates)
        candidate_start_offset: Position where candidates start in the sequence
        dtype: Data type for the mask

    Returns:
        Attention mask of shape [1, 1, seq_len, seq_len] where 1 means "can attend"
    """
    # Start with causal mask for the full sequence
    causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len), dtype=dtype))

    # Zero out candidate-to-candidate attention (bottom-right block)
    attn_mask = causal_mask.at[:, :, candidate_start_offset:, candidate_start_offset:].set(0)

    # Add back self-attention for candidates (diagonal of the candidate block)
    candidate_indices = jnp.arange(candidate_start_offset, seq_len)
    attn_mask = attn_mask.at[:, :, candidate_indices, candidate_indices].set(1)

    return attn_mask
This is the most important function in the file for recsys correctness.
The mask construction:
- Start with the full causal mask. tril (lower triangle) gives the classic "each token can attend to itself and previous tokens."
- Zero out the candidate block. attn_mask[:, :, c:, c:].set(0) removes ALL candidate-to-candidate attention, including each candidate attending to itself. That self-attention is restored in the next step.
- Add back the diagonal of the candidate block. attn_mask[:, :, indices, indices].set(1) restores self-attention for each candidate.
The .at[indices].set(value) pattern is JAX's functional update syntax (since JAX arrays are immutable).
Final mask structure for a sequence [user, h1, h2, c1, c2, c3]:
user h1 h2 c1 c2 c3
user 1 0 0 0 0 0 ← causal: only sees self
h1 1 1 0 0 0 0 ← causal: sees user + self
h2 1 1 1 0 0 0 ← causal: sees user, h1, self
c1 1 1 1 1 0 0 ← sees user+history + self only
c2 1 1 1 0 1 0 ← sees user+history + self only
c3 1 1 1 0 0 1 ← sees user+history + self only
(This is the exact expected matrix from test_full_mask_structure we'll read in the test file.)
Why this design?
- No candidate-to-candidate attention ⇒ each candidate's representation depends only on (user + history + itself). Two candidates in the same batch don't influence each other's scores.
- Self-attention for each candidate ⇒ the candidate's own embedding feeds into its attention output. Without it, every candidate would attend only to the shared user + history context. The attention sublayer would then return the same vector for all candidates, and a candidate's identity would reach later layers only through the residual path.
- Causal mask for user + history ⇒ standard transformer causality: older history doesn't see newer history. The user prefix sits at position 0, so every later position (history and candidates alike) can attend to it.
The README highlights this:
Candidate Isolation in Ranking: During transformer inference, candidates cannot attend to each other—only to the user context. This ensures the score for a post doesn't depend on which other posts are in the batch, making scores consistent and cacheable.
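The isolation property can be checked numerically. The sketch below is a NumPy stand-in for the JAX code, not the repo's implementation. NumPy's in-place assignment replaces JAX's .at[...].set(...). The sketch works in three parts:

1. It rebuilds the mask with the same three steps.
2. It runs one toy single-head attention layer where Q = K = V = the raw inputs. There are no learned weights; this is an assumption for illustration, since the real model has projections, RoPE and many layers.
3. It perturbs candidate c2 and confirms that every other position is unaffected.

```python
import numpy as np


def make_recsys_attn_mask(seq_len: int, candidate_start_offset: int) -> np.ndarray:
    """NumPy port of the three-step construction, returning a [seq_len, seq_len] mask."""
    mask = np.tril(np.ones((seq_len, seq_len), dtype=np.float32))  # 1. causal
    c = candidate_start_offset
    mask[c:, c:] = 0.0  # 2. drop the whole candidate block
    idx = np.arange(c, seq_len)
    mask[idx, idx] = 1.0  # 3. restore each candidate's self-attention
    return mask


def toy_attention(x: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Single-head attention with Q = K = V = x (no learned weights)."""
    logits = x @ x.T / np.sqrt(x.shape[-1])
    logits = np.where(mask > 0, logits, -1e30)  # masked keys get zero weight
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x


# Sequence [user, h1, h2, c1, c2, c3]; candidates start at index 3.
mask = make_recsys_attn_mask(6, 3)
rng = np.random.default_rng(0)
x = rng.normal(size=(6, 4))
out = toy_attention(x, mask)

x_perturbed = x.copy()
x_perturbed[4] += 10.0  # change candidate c2 only
out_perturbed = toy_attention(x_perturbed, mask)

print(np.allclose(out[[3, 5]], out_perturbed[[3, 5]]))  # True: c1 and c3 unchanged
print(np.allclose(out[4], out_perturbed[4]))  # False: c2 itself changes
```

Changing c2 moves only c2's own output. User and history rows are untouched too, because causality keeps them from looking forward.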
MHAOutput, DecoderOutput, TransformerOutput
class MHAOutput(NamedTuple):
    """Outputs of the multi-head attention operation."""

    embeddings: jax.Array


class DecoderOutput(NamedTuple):
    embeddings: jax.Array


class TransformerOutput(NamedTuple):
    embeddings: jax.Array
Three single-field named tuples. Why a named tuple for a single field? Future-proofing — if you later add attention_weights or intermediate_activations, you don't break callers.
right_anchored_rope_positions — the second recsys modification
def right_anchored_rope_positions(
    padding_mask: jax.Array,
    history_seq_len: int,
    num_user_prefix_tokens: int,
) -> jax.Array:
    """Compute RoPE positions where the newest history token always gets a fixed position."""
    history_start = num_user_prefix_tokens
    history_end = num_user_prefix_tokens + history_seq_len
    idx = jnp.arange(padding_mask.shape[1], dtype=jnp.int32)[None, :]
    history_len = padding_mask[:, history_start:history_end].sum(axis=1, dtype=jnp.int32)
    positions = jnp.where(
        (history_start <= idx) & (idx < history_end),
        history_end - history_len[:, None] + idx - history_start,
        idx,
    )
    positions = jnp.where(idx >= history_end, history_end, positions)
    positions = jnp.where(padding_mask, positions, 0).astype(jnp.float32)
    return positions
Right-anchored RoPE positions — fixes the "history of variable length" problem.
The standard RoPE assigns position 0 to the first token, 1 to the second, etc. But for history with variable length, this means two users with different history lengths get different position indices for the same age-of-event. User A with 50 history items has their most recent action at position 50; User B with 100 has it at position 100.
This breaks the model's ability to interpret position as recency. Fix: right-anchor the history, so the most recent valid history token always lands at the same position (history_end - 1).
Walking the code:
- history_start, history_end: the range of sequence indices where history lives.
- idx: [0, 1, 2, ..., T-1], shaped [1, T].
- history_len[b]: how many valid history items user b has.
- For indices inside the history range: position = history_end - history_len + (idx - history_start).

Which end of the history block holds the padding matters here, and the test test_padding_gets_zero pins it down:
padding_mask = jnp.array([[True, True, True, True, False, False, False, False]])
positions = right_anchored_rope_positions(
    padding_mask, history_seq_len=4, num_user_prefix_tokens=1
)
for i in range(4, 8):
    assert float(positions[0, i]) == 0.0
So padding_mask is [T, T, T, T, F, F, F, F]. The valid items are at the LEFT:
- Index 0 is the user prefix.
- Indices 1-3 are valid history.
- Index 4 is an unused history slot.
- Indices 5-7 lie past history_end and are padded too.

History is left-aligned, with padding at the end of the history block.

For this test, history_start = 1 and history_end = 5. history_len = padding_mask[:, 1:5].sum() = 3, because indices 1, 2 and 3 are True.
For idx = 1: position = 5 - 3 + 1 - 1 = 2.
For idx = 2: position = 5 - 3 + 2 - 1 = 3.
For idx = 3: position = 5 - 3 + 3 - 1 = 4.
So valid history items get positions 2, 3, 4, and the newest valid history item gets position history_end - 1 = 4.

In general, with left-aligned history, the formula maps:
- index history_start (the first, oldest valid item) → position history_end - history_len.
- index history_start + history_len - 1 (the last, newest valid item) → position history_end - 1.

For a larger example, take num_user_prefix_tokens = 1 and history_seq_len = 10, so history_start = 1 and history_end = 11. With 4 valid items at indices 1-4, those items get positions 7, 8, 9, 10. A shorter history therefore never shifts the end position. The newest item always sits at exactly history_end - 1, and older items count backwards from there. The model sees "recency" consistently.
positions = jnp.where(idx >= history_end, history_end, positions)
positions = jnp.where(padding_mask, positions, 0).astype(jnp.float32)
return positions
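Final adjustments:
- Candidates (idx >= history_end) all get position history_end, so every candidate shares the same position. The RoPE encoding for candidates is uniform: they differ from each other only via their input content, not position. The newest history item sits at history_end - 1, so every candidate is exactly one position after it, whatever the user's history length.
- Padding gets position 0. The padding mask removes padded slots from attention, so their position value never reaches the valid tokens; 0 is just a deterministic placeholder.
- User-prefix tokens (idx < history_start) keep their plain index as their position.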
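A NumPy port of the function makes the anchoring visible. It uses the same formula, with np in place of jnp so it runs without JAX. The example has two hypothetical users with different history lengths. Both share one layout: 1 user token, 4 history slots and 2 candidates.

```python
import numpy as np


def right_anchored_rope_positions(padding_mask, history_seq_len, num_user_prefix_tokens):
    """NumPy port of the JAX function above (same formula, np instead of jnp)."""
    padding_mask = np.asarray(padding_mask, dtype=bool)
    history_start = num_user_prefix_tokens
    history_end = num_user_prefix_tokens + history_seq_len
    idx = np.arange(padding_mask.shape[1], dtype=np.int32)[None, :]  # [1, T]
    history_len = padding_mask[:, history_start:history_end].sum(axis=1)  # [B]
    in_history = (history_start <= idx) & (idx < history_end)
    positions = np.where(
        in_history, history_end - history_len[:, None] + idx - history_start, idx
    )
    positions = np.where(idx >= history_end, history_end, positions)  # candidates
    return np.where(padding_mask, positions, 0).astype(np.float32)  # padding -> 0


# Layout: 1 user token, 4 history slots (left-aligned, padded at the end), 2 candidates.
padding_mask = np.array([
    [1, 1, 1, 1, 1, 1, 1],  # user A: 4 valid history items
    [1, 1, 1, 0, 0, 1, 1],  # user B: 2 valid history items
])
pos = right_anchored_rope_positions(padding_mask, history_seq_len=4, num_user_prefix_tokens=1)
print(pos)  # rows: [0, 1, 2, 3, 4, 5, 5] and [0, 3, 4, 0, 0, 5, 5]
```

For both users, the newest history item lands on position 4 (history_end - 1) and the candidates on 5. Under RoPE's relative encoding, every candidate therefore sees the newest history item at an offset of exactly 1.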
TransformerConfig
@dataclass
class TransformerConfig:
    emb_size: int
    key_size: int
    num_q_heads: int
    num_kv_heads: int
    num_layers: int
    widening_factor: float = 4.0
    attn_output_multiplier: float = 1.0
    name: Optional[str] = None

    def make(self) -> "Transformer":
        return Transformer(
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            widening_factor=self.widening_factor,
            key_size=self.key_size,
            attn_output_multiplier=self.attn_output_multiplier,
            num_layers=self.num_layers,
        )
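Standard transformer config. Notable:
- num_q_heads vs num_kv_heads: supports grouped-query attention (GQA). Q can have more heads than K/V, with num_q_heads a multiple of num_kv_heads. All query heads within a group share the same K/V head, which saves K/V memory.
- attn_output_multiplier: despite the name, it scales the attention logits before the tanh clamp and softmax (see below), not the attention output.
- widening_factor = 4.0 by default. Combined with * 2 // 3 in ffn_size, this gives an FFN hidden dim of about 2.67 * emb_size, rounded up to a multiple of 8.
- make() does not forward emb_size. The Transformer reads the model width from the shape of its input embeddings.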
hk_rms_norm, Linear, RMSNorm
def hk_rms_norm(
    x: jax.Array,
    fixed_scale=False,
) -> jax.Array:
    """Applies a unique LayerNorm to x with default settings."""
    ln = RMSNorm(axis=-1, create_scale=not fixed_scale)
    return ln(x)


class Linear(hk.Linear):
    def __init__(
        self,
        output_size: int,
        with_bias: bool = True,
        name: Optional[str] = None,
    ):
        super().__init__(
            output_size=output_size,
            with_bias=with_bias,
            name=name,
        )

    def __call__(  # type: ignore
        self,
        inputs: jax.Array,
    ) -> jax.Array:
        """Computes a linear transform of the input."""

        fprop_dtype = inputs.dtype
        if not inputs.shape:
            raise ValueError("Input must not be scalar.")

        input_size = inputs.shape[-1]
        output_size = self.output_size

        w = hk.get_parameter(
            "w", [input_size, output_size], jnp.float32, init=hk.initializers.Constant(0)
        )

        out = jnp.dot(inputs, w.astype(fprop_dtype))
        if self.with_bias:
            b = hk.get_parameter(
                "b", [self.output_size], jnp.float32, init=hk.initializers.Constant(0)
            )
            b = jnp.broadcast_to(b, out.shape)
            out = out + b.astype(fprop_dtype)

        return out
A custom Linear that overrides Haiku's. Two differences:
- fp32 parameters, low-precision forward: params are stored as jnp.float32 and cast to fprop_dtype (the input's dtype, e.g. bf16) on use. This keeps the stored weights precise while the forward pass runs in the cheaper dtype.
- Zero initialization: init=hk.initializers.Constant(0). By default, hk.Linear uses a truncated normal scaled by 1/sqrt(fan_in). Why zero? Because this code is for inference: the params will be overwritten by load_model_params, so zero init is just a shape-correct placeholder.
If you wanted to train from this code, you'd need to swap the initializer.
class RMSNorm(hk.RMSNorm):
    def __init__(
        self,
        axis: Union[int, Sequence[int], slice],
        eps: float = 1e-5,
        name: Optional[str] = None,
        create_scale: bool = True,
    ):
        super().__init__(axis, eps, create_scale=create_scale, name=name)

    def __call__(self, inputs: jax.Array):
        fprop_dtype = inputs.dtype
        param_shape = (inputs.shape[-1],)
        if self.create_scale:
            scale = hk.get_parameter(
                "scale",
                param_shape,
                dtype=jnp.float32,
                init=hk.initializers.Constant(0),
            )
            scale = jnp.broadcast_to(scale.astype(jnp.float32), inputs.shape)
        else:
            scale = 1.0
        inputs = inputs.astype(jnp.float32)
        scale = jnp.float32(scale)
        mean_squared = jnp.mean(jnp.square(inputs), axis=[-1], keepdims=True)
        mean_squared = jnp.broadcast_to(mean_squared, inputs.shape)
        normed_inputs = inputs * jax.lax.rsqrt(mean_squared + self.eps)
        outputs = scale * normed_inputs
        return outputs.astype(fprop_dtype)
RMSNorm (Root Mean Square Layer Normalization). The standard modern alternative to LayerNorm. It is simpler (no mean subtraction, no bias term) and cheaper, and in practice about as effective in transformers.
The math:
- mean_squared = mean(x²) over the last axis
- normed = x / sqrt(mean_squared + eps)
- output = scale * normed
Same fp32-params + bf16-output pattern as Linear. The scale parameter is per-feature ([D]).
Note init=hk.initializers.Constant(0): the scale is zero-initialized. Again, this is a placeholder for loaded checkpoints; a trained checkpoint supplies the learned scales. Until the params are loaded, every RMSNorm with create_scale=True outputs all zeros.
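A NumPy sketch of the same math illustrates three properties of RMSNorm. It omits the fp32/bf16 casting.
1. With a unit scale, the output has RMS close to 1.
2. The output is invariant to rescaling the input (up to eps).
3. With the zero-initialized placeholder scale, the output is all zeros.

```python
import numpy as np


def rms_norm(x: np.ndarray, scale: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """NumPy sketch of the RMSNorm math: no mean subtraction, no bias."""
    mean_squared = np.mean(np.square(x), axis=-1, keepdims=True)
    return scale * x / np.sqrt(mean_squared + eps)


x = np.array([[3.0, -4.0, 0.0, 5.0]])
unit = rms_norm(x, scale=np.ones(4))
print(np.sqrt(np.mean(unit**2)))  # close to 1.0
print(np.allclose(rms_norm(2 * x, np.ones(4)), unit))  # True: scale-invariant
print(rms_norm(x, scale=np.zeros(4)))  # zero-initialized scale: all zeros
```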
rotate_half + RotaryEmbedding
def rotate_half(
    x: jax.Array,
) -> jax.Array:
    """Obtain the rotated counterpart of each feature"""
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate((-x2, x1), axis=-1)


class RotaryEmbedding(hk.Module):
    """Applies rotary embeddings (RoPE) to the input sequence tensor,
    as described in https://arxiv.org/abs/2104.09864.

    Attributes:
        dim (int): Dimensionality of the feature vectors
        base_exponent (int): Base exponent to compute embeddings from
    """

    def __init__(
        self,
        dim: int,
        name: Optional[str] = None,
        base_exponent: int = 10000,
    ):
        super().__init__(name)
        self.dim = dim
        self.base_exponent = base_exponent
        assert self.dim % 2 == 0
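RoPE (Rotary Position Embedding), standard in modern transformers. It treats pairs of features as 2D vectors and rotates each pair by an angle proportional to the token's position. This implementation pairs feature i with feature i + dim/2 (the "rotate-half" layout) rather than pairing adjacent features.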
rotate_half swaps and negates: (x1, x2) → (-x2, x1). This is the 90° rotation in the 2D plane spanned by the two feature halves.
base_exponent = 10000 is the standard RoPE base.
    def __call__(
        self,
        x: jax.Array,
        seq_dim: int,
        offset: jax.Array,
        const_position: Optional[int] = None,
        t: Optional[jax.Array] = None,
    ) -> jax.Array:
        fprop_dtype = x.dtype
        # Compute the per-dimension frequencies
        exponents = jnp.arange(0, self.dim, 2, dtype=jnp.float32)
        inv_freq = jnp.asarray(
            1.0 / (self.base_exponent ** (exponents / self.dim)), dtype=jnp.float32
        )

        if jnp.shape(offset) == ():
            # Offset can be a scalar or one offset per batch element.
            offset = jnp.expand_dims(offset, 0)

        # Compute the per element phase (to pass into sin and cos)
        if const_position:
            t = const_position * jnp.ones(
                (
                    1,
                    x.shape[seq_dim],
                ),
                dtype=jnp.float32,
            )
        elif t is None:
            t = jnp.arange(x.shape[seq_dim], dtype=jnp.float32) + jnp.expand_dims(offset, -1)
        phase = jnp.einsum("bi,j->bij", t, inv_freq)
        phase = jnp.tile(phase, reps=(1, 2))[:, :, None, :]

        x = x * jnp.cos(phase) + rotate_half(x) * jnp.sin(phase)
        x = x.astype(fprop_dtype)

        return x
The RoPE math:
- inv_freq[i] = 1 / 10000^(2i/dim) for i in [0, dim/2). Higher feature indices get lower frequencies. The low-index features rotate fast and are sensitive to short-range offsets; the high-index features rotate slowly and capture long-range offsets.
- t = positions: either constant (const_position), custom (the t param), or sequential (arange + offset).
- phase = t * inv_freq as an outer product, shape [batch, seq_len, dim/2].
- phase = tile(phase, (1, 2)) duplicates it so the phase array has shape [batch, seq_len, dim], with matching first and second halves. [:, :, None, :] then adds a head axis so the phase broadcasts over heads.
- x * cos(phase) + rotate_half(x) * sin(phase) applies the 2D rotation to each (i, i + dim/2) feature pair.
The t parameter is what the recsys code uses: pass right-anchored positions, get position-aware encoding that respects the right-anchor.
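The recsys code relies on one property of RoPE: the rotated q·k score depends only on the position offset, not on absolute positions. The NumPy sketch below checks this. It uses the same half-split layout as RotaryEmbedding (rotate_half plus tiled phases), applied to a single [T, dim] array without the batch or head axes.

```python
import numpy as np


def rotate_half(x: np.ndarray) -> np.ndarray:
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate((-x2, x1), axis=-1)


def apply_rope(x: np.ndarray, positions: np.ndarray, base: float = 10000.0) -> np.ndarray:
    """x: [T, dim]; positions: [T]. Same half-split layout as RotaryEmbedding."""
    dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))  # [dim/2]
    phase = np.outer(positions, inv_freq)  # [T, dim/2]
    phase = np.tile(phase, (1, 2))  # [T, dim]: both halves share the angles
    return x * np.cos(phase) + rotate_half(x) * np.sin(phase)


rng = np.random.default_rng(0)
q, k = rng.normal(size=(1, 8)), rng.normal(size=(1, 8))


def score(q_pos: float, k_pos: float) -> float:
    """Dot product of the rotated query and key at the given positions."""
    return float(np.sum(apply_rope(q, np.array([q_pos])) * apply_rope(k, np.array([k_pos]))))


# Same offset (1) at very different absolute positions gives the same score.
print(score(5.0, 4.0), score(105.0, 104.0))
```

This is why right-anchoring works. A candidate at history_end attending to the newest history item at history_end - 1 always sees an offset of 1, whatever the absolute positions.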
MultiHeadAttention
The actual attention computation. Long.
class MultiHeadAttention(hk.Module):
    def __init__(
        self,
        num_q_heads: int,
        num_kv_heads: int,
        key_size: int,
        *,
        with_bias: bool = True,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        attn_output_multiplier: float = 1.0,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.key_size = key_size
        self.value_size = value_size or key_size
        self.model_size = model_size or key_size * num_q_heads
        self.attn_output_multiplier = attn_output_multiplier
        self.with_bias = with_bias
model_size defaults to key_size * num_q_heads — the standard transformer convention.
    def __call__(
        self,
        query: jax.Array,
        key: jax.Array,
        value: jax.Array,
        mask: jax.Array,
        positions: Optional[jax.Array] = None,
    ) -> MHAOutput:
        projection = self._linear_projection

        # Check that the keys and values have consistent batch size and sequence length.
        assert key.shape[:2] == value.shape[:2], f"key/value shape: {key.shape}/{value.shape}"

        if mask is not None:
            assert mask.ndim == 4
            assert mask.shape[0] in {
                1,
                query.shape[0],
            }, f"mask/query shape: {mask.shape}/{query.shape}"
        # ...
Sanity checks. The mask must be 4D [B, H, T_q, T_k] (or broadcastable to that), key/value must agree on batch + sequence length.
# Compute key/query/values (overload K/Q/V to denote the respective sizes).
assert self.num_q_heads % self.num_kv_heads == 0
query_heads = projection(query, self.key_size, self.num_q_heads, name="query")
key_heads = projection(key, self.key_size, self.num_kv_heads, name="key")
value_heads = projection(value, self.value_size, self.num_kv_heads, name="value")
Project to Q, K, V. Q has more heads (num_q_heads) than K, V (num_kv_heads) — GQA setup. Multiple Q heads share each K/V pair.
rotate = RotaryEmbedding(dim=self.key_size, base_exponent=int(1e4))
key_heads = rotate(key_heads, seq_dim=1, offset=0, t=positions)
query_heads = rotate(query_heads, seq_dim=1, offset=0, t=positions)
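Apply RoPE to Q and K only, not V. Position matters only in the q·k dot product, and rotating both Q and K by their own positions makes each score depend on the relative offset between them. The custom positions (t=positions) are passed through when provided.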
b, t, h, d = query_heads.shape
_, _, kv_h, _ = key_heads.shape
assert h % kv_h == 0, f"query_heads {h} must be a multiple of kv_heads {kv_h}"
query_heads = jnp.reshape(query_heads, (b, t, kv_h, h // kv_h, d))
Reshape Q for GQA: split the h Q heads into kv_h groups of h // kv_h heads each. Now Q is [B, T, kv_h, h_per_group, d] and K, V are [B, T, kv_h, d].
# Compute attention weights.
# Attention softmax is always carried out in fp32.
attn_logits = jnp.einsum("...thHd,...Thd->...hHtT", query_heads, key_heads).astype(
jnp.float32
)
attn_logits *= self.attn_output_multiplier
max_attn_val = jnp.array(30.0, dtype=attn_logits.dtype)
attn_logits = max_attn_val * jnp.tanh(attn_logits / max_attn_val)
Compute attention logits via einsum. The string "...thHd,...Thd->...hHtT":
- t = query sequence position, T = key sequence position.
- h = kv group, H = query head within the group.
- d = feature dim.
- Output: [batch, kv_group, q_head, query_pos, key_pos].
Cast to fp32 — softmax precision matters.
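attn_output_multiplier scales the logits. This code has no separate 1/sqrt(key_size) factor, so this multiplier is the only scaling applied to the logits. Setting it to 1/sqrt(key_size) recovers standard scaled dot-product attention. Then comes tanh-clamping: max_attn_val * tanh(logits / max_attn_val) smoothly squashes the logits into (-30, 30). Small logits pass through almost unchanged, while huge ones are capped, which keeps the softmax from collapsing onto a single key. This stability trick is often called logit soft-capping.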
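A NumPy sketch of the cap shows its two regimes, using the same cap value of 30 as the code. It is close to the identity for small logits and saturates near ±30 for large ones, while staying monotone.

```python
import numpy as np

MAX_ATTN_VAL = 30.0  # same cap value as in grok.py


def soft_cap(logits: np.ndarray, cap: float = MAX_ATTN_VAL) -> np.ndarray:
    """Smoothly squash logits into (-cap, cap): cap * tanh(logits / cap)."""
    return cap * np.tanh(logits / cap)


logits = np.array([0.5, 5.0, 30.0, 300.0, -1e4])
print(soft_cap(logits))  # about 0.5, slightly under 5, ~22.8, just under 30, -30
```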
mask = mask[:, :, None, :, :]
if mask is not None:
    if mask.ndim != attn_logits.ndim:
        raise ValueError(
            f"Mask dimensionality {mask.ndim} must match logits dimensionality "
            f"{attn_logits.ndim} for {mask.shape}/{attn_logits.shape}."
        )
    attn_logits = jnp.where(mask, attn_logits, -1e30)
attn_weights = jax.nn.softmax(attn_logits).astype(query.dtype)  # [H, T', T]
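Apply mask: where(mask, logits, -1e30). Wherever the mask is 0, the logit becomes -1e30, effectively -∞, so the softmax assigns those positions zero probability.
The mask[:, :, None, :, :] indexing inserts an axis for H (query heads within a group). The [B, 1, T, T] mask then broadcasts against the [B, kv_h, H, T, T] logits. This indexing runs before the `if mask is not None` check, so passing mask=None would raise on that line before the check is reached. In practice, Transformer always builds and passes a mask.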
Cast back to query dtype (bf16) after softmax.
# Weight the values by the attention and flatten the head vectors.
attn = jnp.einsum("...hHtT,...Thd->...thHd", attn_weights, value_heads)
leading_dims = attn.shape[:2]
attn = jnp.reshape(attn, (*leading_dims, -1)) # [T', H*V]
# Apply another projection to get the final embeddings.
final_projection = Linear(self.model_size, with_bias=False)
return MHAOutput(final_projection(attn))
Apply attention weights to V, flatten heads, project to model_size.
The output projection uses Linear with with_bias=False — biases on attention outputs are skipped (common modern convention).
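GQA with K/V shared inside each group is mathematically the same as plain multi-head attention with each K/V head duplicated for its group. It just stores fewer K/V heads. The NumPy sketch below mirrors the einsum strings, soft cap and mask broadcast above. It makes two simplifications for illustration:
- The Q/K/V projections, RoPE and the output projection are omitted.
- 1/sqrt(d) is passed as the logit multiplier; this is an assumed value, not one taken from the repo's config.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def gqa_attention(q, k, v, mask, logit_multiplier=1.0, max_attn_val=30.0):
    """q: [B, T, h, d] with h query heads; k, v: [B, T, kv_h, d]; mask: [1 or B, 1, T, T]."""
    b, t, h, d = q.shape
    kv_h = k.shape[2]
    q = q.reshape(b, t, kv_h, h // kv_h, d)  # group the query heads
    # In the einsum, 'h' is the kv group and 'H' the query head inside the group.
    logits = np.einsum("bthHd,bThd->bhHtT", q, k) * logit_multiplier
    logits = max_attn_val * np.tanh(logits / max_attn_val)  # soft cap
    logits = np.where(mask[:, :, None, :, :] > 0, logits, -1e30)
    weights = softmax(logits, axis=-1)
    out = np.einsum("bhHtT,bThd->bthHd", weights, v)
    return out.reshape(b, t, h * d)  # flatten heads (output projection omitted)


rng = np.random.default_rng(0)
B, T, h, kv_h, d = 2, 5, 4, 2, 8
q = rng.normal(size=(B, T, h, d))
k = rng.normal(size=(B, T, kv_h, d))
v = rng.normal(size=(B, T, kv_h, d))
causal = np.tril(np.ones((1, 1, T, T)))

gqa = gqa_attention(q, k, v, causal, logit_multiplier=1 / np.sqrt(d))
# Plain MHA: repeat each K/V head for the query heads in its group.
mha = gqa_attention(
    q, np.repeat(k, h // kv_h, axis=2), np.repeat(v, h // kv_h, axis=2),
    causal, logit_multiplier=1 / np.sqrt(d),
)
print(np.allclose(gqa, mha))  # True
```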
MHABlock, DenseBlock, DecoderLayer
@dataclass
class MHABlock(hk.Module):
    """A MHA Block"""

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    attn_output_multiplier: float = 1.0

    @hk.transparent
    def __call__(
        self,
        inputs: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, 1, T, T] or [B, 1, 1, T] or B[1, 1, 1, 1]
        positions: Optional[jax.Array] = None,
    ) -> MHAOutput:
        _, _, model_size = inputs.shape
        assert mask.ndim == 4, f"shape: {mask.shape}"
        assert mask.shape[2] in {1, inputs.shape[1]}, str(mask.shape)
        assert mask.shape[3] in {1, inputs.shape[1]}, str(mask.shape)
        side_input = inputs

        def attn_block(query, key, value, mask) -> MHAOutput:
            return MultiHeadAttention(
                num_q_heads=self.num_q_heads,
                num_kv_heads=self.num_kv_heads,
                key_size=self.key_size,
                model_size=model_size,
                attn_output_multiplier=self.attn_output_multiplier,
            )(query, key, value, mask, positions=positions)

        attn_output = attn_block(inputs, side_input, side_input, mask)
        h_attn = attn_output.embeddings

        return MHAOutput(embeddings=h_attn)
A thin wrapper that uses inputs for all three of Q/K/V: self-attention. The side_input = inputs makes it explicit.
@hk.transparent — Haiku decorator that makes the module's parameters live in the enclosing module's namespace (instead of creating a sub-namespace). Lets the calling layer have a flatter parameter tree.
@dataclass
class DenseBlock(hk.Module):
    num_q_heads: int
    num_kv_heads: int
    key_size: int
    widening_factor: float = 4.0

    @hk.transparent
    def __call__(
        self,
        inputs: jax.Array,  # [B, T, D]
    ) -> jax.Array:  # [B, T, D]
        _, _, model_size = inputs.shape
        h_v = Linear(
            ffn_size(model_size, self.widening_factor),
            with_bias=False,
            name="linear_v",
        )(inputs)
        h_w1 = jax.nn.gelu(
            Linear(
                ffn_size(model_size, self.widening_factor),
                with_bias=False,
            )(inputs)
        )
        h_dense = Linear(model_size, with_bias=False)(h_w1 * h_v)

        return h_dense
The FFN. This is GeGLU (Gated GELU): two parallel projections, one passed through GELU, then element-wise multiplied. Then projected back to model_size.
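The math: out = Linear_down(GELU(Linear_gate(x)) * Linear_v(x)). Three linear layers (gate, value/up, down). GLU-style FFNs like this are widely reported to beat a vanilla 2-layer FFN at the same parameter count. Note that jax.nn.gelu defaults to the tanh approximation of GELU.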
The * 2 // 3 adjustment in ffn_size keeps this iso-parameter with a vanilla widening_factor * D FFN.
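A NumPy sketch of the GeGLU forward pass follows, together with the parameter count that the 2/3 factor is balancing. The weight names w_gate, w_v and w_down are hypothetical. The small random weights are for illustration only; the real parameters come from the checkpoint.

```python
import numpy as np


def ffn_size(emb_size: int, widening_factor: float) -> int:
    size = int(widening_factor * emb_size) * 2 // 3
    return size + (8 - size) % 8


def gelu_tanh(x: np.ndarray) -> np.ndarray:
    """Tanh approximation of GELU (the jax.nn.gelu default)."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def geglu_ffn(x, w_gate, w_v, w_down):
    """out = (GELU(x @ w_gate) * (x @ w_v)) @ w_down, no biases (as in DenseBlock)."""
    return (gelu_tanh(x @ w_gate) * (x @ w_v)) @ w_down


D, wf = 256, 4.0
F = ffn_size(D, wf)  # 688
rng = np.random.default_rng(0)
w_gate, w_v = rng.normal(size=(D, F)) * 0.02, rng.normal(size=(D, F)) * 0.02
w_down = rng.normal(size=(F, D)) * 0.02
x = rng.normal(size=(3, D))
print(geglu_ffn(x, w_gate, w_v, w_down).shape)  # (3, 256)

gated_params = 3 * D * F  # gate + value + down projections
vanilla_params = 2 * D * int(wf * D)  # up + down of a plain FFN
print(gated_params, vanilla_params, gated_params / vanilla_params)  # 528384 524288 1.0078125
```

The gated FFN ends up within about 1% of the vanilla parameter count. The small excess comes from rounding 682 up to 688.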
@dataclass
class DecoderLayer(hk.Module):
    """A transformer stack."""

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    num_layers: int

    layer_index: Optional[int] = None
    widening_factor: float = 4.0
    name: Optional[str] = None
    attn_output_multiplier: float = 1.0

    def __call__(
        self,
        inputs: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, 1, T, T] or [B, 1, 1, T]
        padding_mask: Optional[jax.Array],
        positions: Optional[jax.Array] = None,
    ) -> DecoderOutput:
        """Transforms input embedding sequences to output embedding sequences."""
        del padding_mask  # Unused.

        def layer_norm(x):
            return hk_rms_norm(x)

        h = inputs

        attn_output = MHABlock(
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            key_size=self.key_size,
            attn_output_multiplier=self.attn_output_multiplier,
        )(layer_norm(h), mask, positions=positions)
        h_attn = attn_output.embeddings

        h_attn = layer_norm(h_attn)
        h += h_attn

        def base_dense_block(h):
            h = DenseBlock(
                num_q_heads=self.num_q_heads,
                num_kv_heads=self.num_kv_heads,
                key_size=self.key_size,
                widening_factor=self.widening_factor,
            )(h)
            return h

        h_dense = base_dense_block(layer_norm(h))

        h_dense = layer_norm(h_dense)
        h += h_dense

        return DecoderOutput(
            embeddings=h,
        )
One transformer layer. The structure:
inputs
↓
RMSNorm
↓
MHA (self-attention)
↓
RMSNorm ← double layer norm: also after the attention block
↓
+ inputs ← residual
↓
RMSNorm
↓
GeGLU FFN
↓
RMSNorm ← double layer norm again
↓
+ previous ← residual
↓
output
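Double layer norm is unusual; most transformers normalize once per sublayer. This is one of the Grok-1 design choices preserved here. Each sublayer gets a pre-norm on its input AND a norm on its output, an arrangement sometimes called "sandwich" normalization. Unlike classic post-LN, the second norm is applied before the residual add, so the residual stream itself is never normalized. Each layer_norm call instantiates a fresh RMSNorm, so the four norms in a layer have independent learned scales.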
del padding_mask — the param exists in the signature for API compat but isn't used. The mask already incorporates padding.
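The residual arithmetic of DecoderLayer can be sketched in NumPy, with the attention and FFN replaced by stand-in functions (an assumption for illustration). The sketch demonstrates the consequence of normalizing each sublayer's output before the add. Even a sublayer with enormous gain contributes an update of bounded RMS, set by the post-norm scale, while the large residual stream passes through unnormalized.

```python
import numpy as np


def rms_norm(x, scale, eps=1e-5):
    ms = np.mean(np.square(x), axis=-1, keepdims=True)
    return scale * x / np.sqrt(ms + eps)


def sandwich_decoder_layer(h, attn_fn, ffn_fn, scales):
    """Residual structure of DecoderLayer with stand-in sublayers.

    scales holds four per-feature vectors, one per RMSNorm call.
    """
    h = h + rms_norm(attn_fn(rms_norm(h, scales[0])), scales[1])  # pre-norm, post-norm, add
    h = h + rms_norm(ffn_fn(rms_norm(h, scales[2])), scales[3])
    return h


D = 8
rng = np.random.default_rng(0)
h0 = rng.normal(size=(4, D)) * 100.0  # a residual stream with a large magnitude
ones = [np.ones(D)] * 4

# Stand-in sublayers (assumptions): identity "attention", a huge-gain "FFN".
out = sandwich_decoder_layer(h0, lambda x: x, lambda x: 1e6 * x, ones)
update = out - h0
print(np.sqrt(np.mean(update**2, axis=-1)))  # per-token update RMS stays near 2
```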
Transformer — the full stack
@dataclass
class Transformer(hk.Module):
    """A transformer stack."""

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    widening_factor: float
    attn_output_multiplier: float
    num_layers: int
    name: Optional[str] = None

    def __call__(
        self,
        embeddings: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, T]
        candidate_start_offset: Optional[int] = None,
        positions: Optional[jax.Array] = None,
    ) -> TransformerOutput:
        # ...
        fprop_dtype = embeddings.dtype
        _, seq_len, _ = embeddings.shape
        padding_mask = mask.copy()
        mask = mask[:, None, None, :]  # [B, H=1, T'=1, T]

        if candidate_start_offset is not None:
            # Use recommendation system attention mask where candidates attend to
            # user+history and themselves, but not to other candidates
            attn_mask = make_recsys_attn_mask(seq_len, candidate_start_offset, fprop_dtype)
            mask = mask * attn_mask
        else:
            # Standard causal mask for autoregressive sequence modelling
            causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len))).astype(
                fprop_dtype
            )  # [B=1, H=1, T, T]
            mask = mask * causal_mask  # [B, H=1, T, T]
Mask construction — this is where the candidate_start_offset parameter matters.
The input mask is a 2D padding mask [B, T]. Expand to [B, 1, 1, T].
Two branches:
- candidate_start_offset is not None (ranking): build the recsys mask we read earlier and multiply it with the padding mask. The result has both the causal-with-candidate-isolation structure AND the padding zeros.
- candidate_start_offset is None (retrieval): standard causal mask. The retrieval model uses this since there are no candidates in the sequence.
Multiplying the padding mask with the structural mask combines them: a key position is attendable iff the padding mask says it's valid AND the structural mask says it's reachable. The padding mask only constrains the key axis; padded query rows are still computed.
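A NumPy sketch of the ranking branch shows the combination on a tiny hypothetical sequence [user, h1, h2, c1, c2]. History slot h2 is padding. The [B, T] padding mask is expanded to [B, 1, 1, T] and multiplied with the [1, 1, T, T] structural mask, exactly as in the code above.

```python
import numpy as np


def make_recsys_attn_mask(seq_len: int, candidate_start_offset: int) -> np.ndarray:
    """NumPy port returning the [1, 1, seq_len, seq_len] structural mask."""
    mask = np.tril(np.ones((1, 1, seq_len, seq_len), dtype=np.float32))
    c = candidate_start_offset
    mask[:, :, c:, c:] = 0.0
    idx = np.arange(c, seq_len)
    mask[:, :, idx, idx] = 1.0
    return mask


# Sequence [user, h1, h2, c1, c2] where history slot h2 is padding.
padding = np.array([[1, 1, 0, 1, 1]], dtype=np.float32)  # [B, T]
combined = padding[:, None, None, :] * make_recsys_attn_mask(5, candidate_start_offset=3)
print(combined[0, 0])
```

Column 2 (the padded h2) is zeroed for every query, so no token can attend to it. Row 2 is still populated, because padding only masks the key axis.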
        h = embeddings

        def block(
            h,
            mask,
            padding_mask,
            layer_index: Optional[int] = None,
            widening_factor: Optional[int] = None,
            name: Optional[str] = None,
        ) -> DecoderOutput:
            return DecoderLayer(
                num_q_heads=self.num_q_heads,
                num_kv_heads=self.num_kv_heads,
                key_size=self.key_size,
                widening_factor=widening_factor or self.widening_factor,
                num_layers=self.num_layers,
                attn_output_multiplier=self.attn_output_multiplier,
                name=name,
                layer_index=layer_index,
            )(h, mask, padding_mask, positions=positions)

        for i in range(self.num_layers):
            decoder_output = block(
                h,
                mask,
                padding_mask,
                layer_index=i,
                name=f"decoder_layer_{i}",
            )
            h = decoder_output.embeddings

        return TransformerOutput(
            embeddings=h,
        )
Stack num_layers decoder layers. Each is named decoder_layer_0, decoder_layer_1, ... — the names matter for parameter loading from checkpoints.
Sequential pass — no skipping, no early exit. Output is the final hidden states [B, T, D].
End of grok.py.
test_recsys_model.py (309 lines)
The test file for the ranking model side. Four test classes:
- TestMakeRecsysAttnMask: pins down the attention mask behavior.
- TestRightAnchoredRopePositions: RoPE positions.
- TestComputePostAgeBucket: age bucketing.
- TestNormalizeContinuousValue: normalization.
We'll quote the most important ones.
TestMakeRecsysAttnMask — the canary test
class TestMakeRecsysAttnMask:
    """Tests for the make_recsys_attn_mask function."""

    def test_output_shape(self):
        """Test that the output has the correct shape [1, 1, seq_len, seq_len]."""
        seq_len = 10
        candidate_start_offset = 5

        mask = make_recsys_attn_mask(seq_len, candidate_start_offset)

        assert mask.shape == (1, 1, seq_len, seq_len)
Basic shape test. The mask is [1, 1, seq_len, seq_len] — broadcasts across batch and heads.
    def test_user_history_has_causal_attention(self):
        """Test that user+history positions (before candidate_start_offset) have causal attention."""
        seq_len = 8
        candidate_start_offset = 5

        mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
        mask_2d = mask[0, 0]

        for i in range(candidate_start_offset):
            for j in range(candidate_start_offset):
                if j <= i:
                    assert mask_2d[i, j] == 1, f"Position {i} should attend to position {j}"
                else:
                    assert (
                        mask_2d[i, j] == 0
                    ), f"Position {i} should NOT attend to future position {j}"
For positions in [0, candidate_start_offset): standard causal triangle. j <= i attends, j > i doesn't.
    def test_candidates_attend_to_user_history(self):
        """Test that candidates can attend to all user+history positions."""
        # ...
        for candidate_pos in range(candidate_start_offset, seq_len):
            for history_pos in range(candidate_start_offset):
                assert (
                    mask_2d[candidate_pos, history_pos] == 1
                ), f"Candidate at {candidate_pos} should attend to user+history at {history_pos}"
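Every candidate attends to every user+history position. All candidates come after the whole user+history block, so the causal triangle already allows these edges. The test checks that zeroing the candidate block didn't remove any of them.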
    def test_candidates_attend_to_themselves(self):
        """Test that candidates can attend to themselves (self-attention)."""
        # ...
        for candidate_pos in range(candidate_start_offset, seq_len):
            assert (
                mask_2d[candidate_pos, candidate_pos] == 1
            ), f"Candidate at {candidate_pos} should attend to itself"
    def test_candidates_do_not_attend_to_other_candidates(self):
        """Test that candidates cannot attend to other candidates."""
        # ...
        for query_pos in range(candidate_start_offset, seq_len):
            for key_pos in range(candidate_start_offset, seq_len):
                if query_pos != key_pos:
                    assert (
                        mask_2d[query_pos, key_pos] == 0
                    ), f"Candidate at {query_pos} should NOT attend to candidate at {key_pos}"
Self-attention only for candidates. No candidate-to-candidate cross-attention.
    def test_full_mask_structure(self):
        """Test the complete mask structure with a small example."""
        # Sequence: [user, h1, h2, c1, c2, c3]
        # Positions: 0     1   2   3   4   5
        seq_len = 6
        candidate_start_offset = 3

        mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
        mask_2d = mask[0, 0]

        expected = np.array(
            [
                [1, 0, 0, 0, 0, 0],  # user
                [1, 1, 0, 0, 0, 0],  # h1
                [1, 1, 1, 0, 0, 0],  # h2
                [1, 1, 1, 1, 0, 0],  # c1: user+history + self
                [1, 1, 1, 0, 1, 0],  # c2: user+history + self
                [1, 1, 1, 0, 0, 1],  # c3: user+history + self
            ],
            dtype=np.float32,
        )

        np.testing.assert_array_equal(
            np.array(mask_2d),
            expected,
            err_msg="Full mask structure does not match expected pattern",
        )
The canonical example. Matches the diagram from earlier in this article. This is the most readable test in the class: the full expected matrix is laid out explicitly.
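A minimal NumPy sketch of a mask function that reproduces the expected matrix above. The name make_recsys_attn_mask_sketch is hypothetical and this is not the grok.py implementation; it assumes the [1, 1, T, T] output shape that the tests index with mask[0, 0].

```python
import numpy as np


def make_recsys_attn_mask_sketch(seq_len, candidate_start_offset, dtype=np.float32):
    """Hypothetical re-implementation: causal mask minus candidate-to-candidate edges."""
    q = np.arange(seq_len)[:, None]  # query positions (rows)
    k = np.arange(seq_len)[None, :]  # key positions (columns)

    causal = k <= q  # standard lower triangle, diagonal included

    # Block every edge where both query and key are candidates, except self-attention.
    both_candidates = (q >= candidate_start_offset) & (k >= candidate_start_offset)
    cross_candidate = both_candidates & (q != k)

    mask = causal & ~cross_candidate
    # Leading [batch, head] axes so callers can index mask[0, 0].
    return mask.astype(dtype)[None, None, :, :]


mask = make_recsys_attn_mask_sketch(seq_len=6, candidate_start_offset=3)
print(mask[0, 0].astype(int))
```

The candidate-to-history edge needs no special case: every user/history index is smaller than every candidate index, so the causal triangle already allows it. Only the candidate block gets carved down to its diagonal.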
def test_dtype_preserved(self):
"""Test that the specified dtype is used."""
seq_len = 5
candidate_start_offset = 3
mask_f32 = make_recsys_attn_mask(seq_len, candidate_start_offset, dtype=jnp.float32)
mask_f16 = make_recsys_attn_mask(seq_len, candidate_start_offset, dtype=jnp.float16)
assert mask_f32.dtype == jnp.float32
assert mask_f16.dtype == jnp.float16
Dtype passthrough.
def test_single_candidate(self):
"""Test edge case with a single candidate."""
seq_len = 4
candidate_start_offset = 3
mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
mask_2d = mask[0, 0]
expected = np.array(
[
[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1],
],
dtype=np.float32,
)
np.testing.assert_array_equal(np.array(mask_2d), expected)
Single candidate edge case: with only one candidate, the mask is just a regular causal triangle. The "no candidate-to-candidate" rule is vacuous; the candidate just sees itself.
def test_all_candidates(self):
"""Test edge case where all positions except first are candidates."""
seq_len = 4
candidate_start_offset = 1
mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
mask_2d = mask[0, 0]
expected = np.array(
[
[1, 0, 0, 0], # user
[1, 1, 0, 0], # c1: user + self
[1, 0, 1, 0], # c2: user + self
[1, 0, 0, 1], # c3: user + self
],
dtype=np.float32,
)
np.testing.assert_array_equal(np.array(mask_2d), expected)
No-history edge case: user only, then candidates. Each candidate sees just (user, self).
TestRightAnchoredRopePositions
class TestRightAnchoredRopePositions:
"""Tests for the right_anchored_rope_positions function."""
def test_output_shape(self):
"""Test that the output has the correct shape [B, T]."""
B, T = 2, 10
padding_mask = jnp.ones((B, T), dtype=jnp.bool_)
positions = right_anchored_rope_positions(
padding_mask, history_seq_len=6, num_user_prefix_tokens=1
)
assert positions.shape == (B, T)
def test_prefix_positions_preserved(self):
"""Test that prefix token positions are 0..num_prefix-1."""
B, T = 1, 10
padding_mask = jnp.ones((B, T), dtype=jnp.bool_)
positions = right_anchored_rope_positions(
padding_mask, history_seq_len=6, num_user_prefix_tokens=2
)
assert float(positions[0, 0]) == 0.0
assert float(positions[0, 1]) == 1.0
Prefix tokens get positions 0..num_prefix-1, so with num_user_prefix_tokens=1 the user token is always at position 0.
def test_candidates_share_position(self):
"""Test that all candidate positions are the same (history_end)."""
B = 1
num_prefix = 1
history_len = 4
num_candidates = 3
T = num_prefix + history_len + num_candidates
padding_mask = jnp.ones((B, T), dtype=jnp.bool_)
positions = right_anchored_rope_positions(
padding_mask, history_seq_len=history_len, num_user_prefix_tokens=num_prefix
)
history_end = num_prefix + history_len
for c in range(num_candidates):
assert float(positions[0, history_end + c]) == float(history_end)
All candidates get position history_end. So they share one position — the model's candidate representations differ only via input content, not position. Consistent with the candidate-isolation logic.
def test_padding_gets_zero(self):
"""Test that padded positions get position 0."""
B, T = 1, 8
padding_mask = jnp.array([[True, True, True, True, False, False, False, False]])
positions = right_anchored_rope_positions(
padding_mask, history_seq_len=4, num_user_prefix_tokens=1
)
for i in range(4, 8):
assert float(positions[0, i]) == 0.0
Padded slots get position 0. The mask zeros their attention contribution anyway, but the position is well-defined.
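The four tests pin down prefix positions, the shared candidate position, and zeroed padding, but not where a partially filled history lands. The NumPy sketch below is one hypothetical implementation consistent with all four; the right-anchoring rule (shift the valid history tokens so the last one lands on history_end - 1) is an assumption drawn from the article's description of the function, not read from grok.py, and the function name is made up.

```python
import numpy as np


def right_anchored_positions_sketch(padding_mask, history_seq_len, num_user_prefix_tokens):
    """Hypothetical sketch; padding_mask is a [B, T] bool array (True = real token)."""
    padding_mask = np.asarray(padding_mask, dtype=bool)
    _, T = padding_mask.shape
    history_end = num_user_prefix_tokens + history_seq_len
    idx = np.arange(T)[None, :]  # [1, T], broadcasts against [B, T]

    is_prefix = idx < num_user_prefix_tokens
    is_history = (idx >= num_user_prefix_tokens) & (idx < history_end)
    is_candidate = idx >= history_end

    # Rank each valid history token among the valid history tokens in its row.
    valid_history = padding_mask & is_history
    num_valid = valid_history.sum(axis=1, keepdims=True)  # [B, 1]
    rank = np.cumsum(valid_history, axis=1) - 1           # 0-based rank
    # Right-anchor: the most recent valid history token lands on history_end - 1.
    history_pos = history_end - num_valid + rank

    positions = np.where(is_prefix, idx, 0)
    positions = np.where(valid_history, history_pos, positions)
    positions = np.where(is_candidate, history_end, positions)  # candidates share one position
    positions = np.where(padding_mask, positions, 0)            # padded slots get 0
    return positions.astype(np.float32)


full = np.ones((1, 8), dtype=bool)
print(right_anchored_positions_sketch(full, history_seq_len=4, num_user_prefix_tokens=1))
```

With a full history this reduces to ordinary indices for user and history; only a partially padded history gets shifted right.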
TestComputePostAgeBucket
class TestComputePostAgeBucket:
"""Tests for the compute_post_age_bucket function."""
def test_basic_bucketing(self):
"""Test basic bucketing with 60-minute granularity."""
# Post that is 30 minutes old -> bucket 1 (0-59 minutes)
impr_ts = jnp.array([[1000000]])
post_ts = jnp.array([[1000000 - 30 * 60]])
bucket = compute_post_age_bucket(impr_ts, post_ts, granularity_mins=60)
assert int(bucket[0, 0]) == 1
30 minutes old → bucket 1 (bucket 1 covers 0-59 minutes; bucket 0 is reserved for missing timestamps).
def test_two_hour_post(self):
"""Test a 2-hour-old post -> bucket 3 (120-179 minutes)."""
impr_ts = jnp.array([[1000000]])
post_ts = jnp.array([[1000000 - 120 * 60]])
bucket = compute_post_age_bucket(impr_ts, post_ts, granularity_mins=60)
assert int(bucket[0, 0]) == 3
2 hours = 120 minutes → 120 // 60 + 1 = 3. So bucket = (minutes // granularity) + 1. Bucket 1 covers 0-59 minutes, bucket 2 covers 60-119, bucket 3 covers 120-179.
def test_missing_timestamp_zero(self):
"""Test that missing timestamps (0) map to bucket 0."""
impr_ts = jnp.array([[0]])
post_ts = jnp.array([[1000000]])
bucket = compute_post_age_bucket(impr_ts, post_ts, granularity_mins=60)
assert int(bucket[0, 0]) == 0
# ... same for missing post_ts
Sentinel handling: ts of 0 → bucket 0.
def test_negative_age_maps_to_zero(self):
"""Test that negative age (clock skew) maps to bucket 0."""
impr_ts = jnp.array([[1000000]])
post_ts = jnp.array([[1000000 + 60 * 60]]) # post created AFTER impression
bucket = compute_post_age_bucket(impr_ts, post_ts, granularity_mins=60)
assert int(bucket[0, 0]) == 0
Clock-skew handling: negative age → bucket 0. Same as missing.
def test_overflow_bucket(self):
"""Test that very old posts go to the overflow bucket."""
# POST_AGE_MAX_MINUTES = 4800 (80 hours)
impr_ts = jnp.array([[1000000]])
post_ts = jnp.array([[1000000 - 5000 * 60]]) # 5000 minutes old
bucket = compute_post_age_bucket(impr_ts, post_ts, granularity_mins=60)
# overflow bucket = 4800 // 60 + 1 = 81
assert int(bucket[0, 0]) == 81
Posts older than 80 hours → bucket 81. Everything past the cap collapses into a single "ancient post" bucket.
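The five tests are enough to reconstruct the bucketing rule. A NumPy sketch consistent with them, assuming timestamps are in seconds (the tests subtract minutes * 60) and that ages are capped at POST_AGE_MAX_MINUTES before bucketing. The cap value is the 4800 quoted in the test comment; the function name is hypothetical.

```python
import numpy as np

POST_AGE_MAX_MINUTES = 4800  # 80 hours, as quoted in the test comment


def post_age_bucket_sketch(impression_ts, post_ts, granularity_mins=60):
    """Hypothetical sketch: bucket 0 = missing/invalid, else (minutes // g) + 1, capped."""
    impression_ts = np.asarray(impression_ts, dtype=np.int64)
    post_ts = np.asarray(post_ts, dtype=np.int64)

    age_mins = (impression_ts - post_ts) // 60           # seconds -> whole minutes
    missing = (impression_ts == 0) | (post_ts == 0)      # 0 is the "unknown" sentinel
    invalid = missing | (age_mins < 0)                   # clock skew: post after impression

    capped = np.minimum(age_mins, POST_AGE_MAX_MINUTES)  # older posts share one bucket
    bucket = capped // granularity_mins + 1              # bucket 1 starts at age 0
    return np.where(invalid, 0, bucket)
```

Checking the sentinels before using the age matters: an impression timestamp of 0 yields a large negative age, and a post timestamp of 0 yields a huge positive one that would otherwise land in the overflow bucket.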
TestNormalizeContinuousValue
class TestNormalizeContinuousValue:
"""Tests for continuous value normalization."""
def test_linear_normalization(self):
"""Test linear normalization: x / norm_scale."""
config = NormConfig(norm_scale=30.0, use_log=False)
values = jnp.array([0.0, 15.0, 30.0, 60.0])
result = normalize_continuous_value(values, config)
np.testing.assert_allclose(np.array(result), [0.0, 0.5, 1.0, 1.0], atol=1e-6)
Linear: x / 30. Note 60.0 → 1.0 (clamped).
def test_log_normalization(self):
"""Test log normalization: log1p(x) / log1p(norm_scale)."""
config = NormConfig(norm_scale=30.0, use_log=True)
values = jnp.array([0.0, 30.0])
result = normalize_continuous_value(values, config)
assert float(result[0]) == pytest.approx(0.0, abs=1e-6)
assert float(result[1]) == pytest.approx(1.0, abs=1e-6)
def test_clamping(self):
"""Test that values are clamped to [0, norm_scale]."""
config = NormConfig(norm_scale=10.0, use_log=False)
values = jnp.array([-5.0, 0.0, 5.0, 15.0])
result = normalize_continuous_value(values, config)
np.testing.assert_allclose(np.array(result), [0.0, 0.0, 0.5, 1.0], atol=1e-6)
Standard edge-case tests. log1p for log mode (handles 0 cleanly: log1p(0) = log(1) = 0). Clamping at both ends.
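The three tests fully determine the behavior on the inputs they cover. A minimal NumPy sketch that passes them; NormConfigSketch and normalize_sketch are hypothetical stand-ins for NormConfig and normalize_continuous_value, and clamping before the division is an assumption that matches the test_clamping docstring.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class NormConfigSketch:
    norm_scale: float
    use_log: bool = False


def normalize_sketch(values, config):
    """Clamp to [0, norm_scale], then map linearly or logarithmically onto [0, 1]."""
    x = np.clip(np.asarray(values, dtype=np.float32), 0.0, config.norm_scale)
    if config.use_log:
        # log1p keeps x = 0 at exactly 0 and compresses heavy-tailed counts.
        return np.log1p(x) / np.log1p(config.norm_scale)
    return x / config.norm_scale


print(normalize_sketch([-5.0, 0.0, 5.0, 15.0], NormConfigSketch(norm_scale=10.0)))
```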
test_recsys_retrieval_model.py (417 lines)
Tests for the retrieval model. Three test classes:
- TestCandidateTower: the MLP/mean-pool candidate encoder.
- TestPhoenixRetrievalModel: the full retrieval model.
- TestRetrievalInferenceRunner: the runner from Session 16.
TestCandidateTower
class TestCandidateTower(unittest.TestCase):
"""Tests for the CandidateTower module."""
def test_candidate_tower_output_shape(self):
"""Test that candidate tower produces correct output shape."""
emb_size = 64
batch_size = 4
num_candidates = 8
num_hashes = 4
def forward(x):
tower = CandidateTower(emb_size=emb_size, enable_linear_proj=True)
return tower(x)
forward_fn = hk.without_apply_rng(hk.transform(forward))
rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (batch_size, num_candidates, num_hashes, emb_size))
params = forward_fn.init(rng, x)
output = forward_fn.apply(params, x)
self.assertEqual(output.shape, (batch_size, num_candidates, emb_size))
Shape test: [B, C, num_hashes, D] → [B, C, D]. The candidate tower collapses the hash axis into one D-dimensional vector per candidate.
def test_candidate_tower_normalized(self):
"""Test that candidate tower output is L2 normalized."""
# ...
params = forward_fn.init(rng, x)
output = forward_fn.apply(params, x)
norms = jnp.sqrt(jnp.sum(output**2, axis=-1))
np.testing.assert_array_almost_equal(norms, jnp.ones_like(norms), decimal=5)
Output is L2-normalized to unit vectors. Critical for cosine-similarity retrieval. The model code (Session 15) ensures this in both modes.
def test_candidate_tower_mean_pooling(self):
"""Test candidate tower with mean pooling (no linear projection)."""
# ...
def forward(x):
tower = CandidateTower(emb_size=emb_size, enable_linear_proj=False)
return tower(x)
# ...
output = forward_fn.apply(params, x)
self.assertEqual(output.shape, (batch_size, num_candidates, emb_size))
norms = jnp.sqrt(jnp.sum(output**2, axis=-1))
np.testing.assert_array_almost_equal(norms, jnp.ones_like(norms), decimal=5)
Mean-pooling mode also produces normalized outputs. Same shape, same norm guarantee.
def test_mean_pooling_has_no_params(self):
"""Test that mean pooling mode introduces no learned parameters."""
# ...
params = forward_fn.init(rng, x)
# Mean pooling should have no parameters
total_params = sum(p.size for p in jax.tree.leaves(params))
self.assertEqual(total_params, 0)
Zero learned parameters in mean-pooling mode. jax.tree.leaves flattens the Haiku params tree into a list of arrays; p.size is each array's element count, and the sum totals them. A total of 0 confirms no hk.get_parameter calls fired.
This is the "more parameter-efficient but less expressive" mode referenced in Session 15. Useful when you want to use the post + author hash embeddings directly without a learned projection on top.
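A NumPy sketch of the mean-pooling path, assuming it is exactly "average over the hash axis, then L2-normalize" as the test names describe; the function name and the eps guard are hypothetical. Because it uses only array operations, there is nothing to learn, which is what test_mean_pooling_has_no_params checks from the Haiku side.

```python
import numpy as np


def mean_pool_candidate_tower_sketch(x, eps=1e-12):
    """Hypothetical mean-pool tower: [B, C, num_hashes, D] -> unit vectors [B, C, D]."""
    pooled = x.mean(axis=-2)                               # average the hash embeddings
    norms = np.linalg.norm(pooled, axis=-1, keepdims=True)
    return pooled / np.maximum(norms, eps)                 # eps avoids divide-by-zero


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8, 4, 64))
out = mean_pool_candidate_tower_sketch(x)
print(out.shape)                                           # (4, 8, 64)
print(np.allclose(np.linalg.norm(out, axis=-1), 1.0))      # True
```

The linear-projection mode swaps the mean for the learned MLP described in Session 15, which is where that mode's parameters come from.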
TestPhoenixRetrievalModel
class TestPhoenixRetrievalModel(unittest.TestCase):
"""Tests for the full Phoenix Retrieval Model."""
def setUp(self):
"""Set up test fixtures."""
self.emb_size = 64
self.history_seq_len = 16
self.candidate_seq_len = 8
self.batch_size = 2
self.num_actions = 19
self.corpus_size = 100
self.top_k = 10
self.hash_config = HashConfig(
num_user_hashes=2,
num_item_hashes=2,
num_author_hashes=2,
)
self.config = PhoenixRetrievalModelConfig(
emb_size=self.emb_size,
history_seq_len=self.history_seq_len,
candidate_seq_len=self.candidate_seq_len,
hash_config=self.hash_config,
product_surface_vocab_size=16,
enable_linear_proj=True,
model=TransformerConfig(
emb_size=self.emb_size,
widening_factor=2,
key_size=32,
num_q_heads=2,
num_kv_heads=2,
num_layers=1,
attn_output_multiplier=0.125,
),
)
Tiny test config: emb_size=64, 16-item history, 8 candidates, single transformer layer, 2 heads. Same architecture as production, just shrunk.
def test_model_forward(self):
"""Test model forward pass produces correct output shapes."""
def forward(batch, embeddings, corpus_embeddings, top_k):
model = self.config.make()
return model(batch, embeddings, corpus_embeddings, top_k)
forward_fn = hk.without_apply_rng(hk.transform(forward))
batch, embeddings = self._create_test_batch()
corpus_embeddings, _ = self._create_test_corpus()
rng = jax.random.PRNGKey(0)
params = forward_fn.init(rng, batch, embeddings, corpus_embeddings, self.top_k)
output = forward_fn.apply(params, batch, embeddings, corpus_embeddings, self.top_k)
self.assertEqual(output.user_representation.shape, (self.batch_size, self.emb_size))
self.assertEqual(output.top_k_indices.shape, (self.batch_size, self.top_k))
self.assertEqual(output.top_k_scores.shape, (self.batch_size, self.top_k))
Forward pass + output-shape checks. Both init and apply run end-to-end through the full pipeline.
def test_user_representation_normalized(self):
"""Test that user representations are L2 normalized."""
# ...
norms = jnp.sqrt(jnp.sum(output.user_representation**2, axis=-1))
np.testing.assert_array_almost_equal(norms, jnp.ones(self.batch_size), decimal=5)
def test_candidate_representation_normalized(self):
"""Test that candidate representations from build_candidate_representation are L2 normalized."""
# ...
norms = jnp.sqrt(jnp.sum(cand_rep**2, axis=-1))
np.testing.assert_array_almost_equal(
norms, jnp.ones((self.batch_size, self.candidate_seq_len)), decimal=5
)
Both user and candidate representations are L2-normalized. Required for cosine-similarity retrieval.
def test_retrieve_top_k(self):
"""Test top-k retrieval through __call__."""
# ...
self.assertEqual(output.top_k_indices.shape, (self.batch_size, self.top_k))
self.assertEqual(output.top_k_scores.shape, (self.batch_size, self.top_k))
self.assertTrue(jnp.all(output.top_k_indices >= 0))
self.assertTrue(jnp.all(output.top_k_indices < self.corpus_size))
for b in range(self.batch_size):
scores = np.array(output.top_k_scores[b])
self.assertTrue(np.all(scores[:-1] >= scores[1:]))
Three guarantees:
- Indices and scores both have shape [batch_size, top_k].
- Indices are valid (in [0, corpus_size)).
- Scores are sorted descending (top-K): scores[:-1] >= scores[1:] means each consecutive pair is non-increasing.
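Because both representations are unit vectors, a dot product is a cosine similarity, and top-K retrieval reduces to a matrix multiply plus a sort. A NumPy sketch of that step under this assumption; the function names are hypothetical, and the test only tells us the model does this somewhere inside __call__.

```python
import numpy as np


def retrieve_top_k_sketch(user_rep, corpus_emb, top_k):
    """Hypothetical sketch. user_rep: [B, D], corpus_emb: [N, D], both L2-normalized."""
    scores = user_rep @ corpus_emb.T                       # [B, N] cosine similarities
    # argsort on negated scores gives descending order; keep the first top_k columns.
    top_k_indices = np.argsort(-scores, axis=-1)[:, :top_k]
    top_k_scores = np.take_along_axis(scores, top_k_indices, axis=-1)
    return top_k_indices, top_k_scores


def l2_normalize(v):
    return v / np.linalg.norm(v, axis=-1, keepdims=True)


rng = np.random.default_rng(0)
user = l2_normalize(rng.normal(size=(2, 64)))
corpus = l2_normalize(rng.normal(size=(100, 64)))
idx, sc = retrieve_top_k_sketch(user, corpus, top_k=10)
```

A full argsort is fine at test scale; for large corpora, np.argpartition (or jax.lax.top_k in JAX) avoids sorting every score.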
def test_mean_pooling_model_forward(self):
"""Test model forward pass with mean pooling candidate tower."""
config = PhoenixRetrievalModelConfig(
# ...
enable_linear_proj=False, # Mean pooling
# ...
)
# ...
output = forward_fn.apply(params, batch, embeddings, corpus_embeddings, self.top_k)
self.assertEqual(output.user_representation.shape, (self.batch_size, self.emb_size))
self.assertEqual(output.top_k_indices.shape, (self.batch_size, self.top_k))
Confirm the mean-pooling mode also produces a runnable model.
TestRetrievalInferenceRunner
class TestRetrievalInferenceRunner(unittest.TestCase):
"""Tests for the retrieval inference runner."""
def setUp(self):
"""Set up test fixtures."""
# ...
def test_runner_initialization(self):
"""Test that runner initializes correctly."""
runner = RecsysRetrievalInferenceRunner(
runner=RetrievalModelRunner(
model=self.config,
bs_per_device=0.125,
),
name="test_retrieval",
)
runner.initialize()
self.assertIsNotNone(runner.params)
def test_runner_encode_user(self):
"""Test user encoding through runner."""
runner = RecsysRetrievalInferenceRunner(
runner=RetrievalModelRunner(
model=self.config,
bs_per_device=0.125,
),
name="test_retrieval",
)
runner.initialize()
batch, embeddings = create_example_batch(...)
user_rep = runner.encode_user(batch, embeddings)
self.assertEqual(user_rep.shape, (self.batch_size, self.emb_size))
Test the inference runner from Session 16. Confirms:
- Runner initializes (params are not None).
- encode_user returns the right shape.
- retrieve (next test) returns correctly shaped outputs.
def test_runner_retrieve(self):
"""Test retrieval through runner."""
# ...
runner.initialize()
batch, embeddings = create_example_batch(...)
corpus_size = 100
corpus_embeddings, corpus_post_ids = create_example_corpus(corpus_size, self.emb_size)
runner.set_corpus(corpus_embeddings, corpus_post_ids)
top_k = 10
output = runner.retrieve(batch, embeddings, top_k=top_k)
self.assertEqual(output.user_representation.shape, (self.batch_size, self.emb_size))
self.assertEqual(output.top_k_indices.shape, (self.batch_size, top_k))
self.assertEqual(output.top_k_scores.shape, (self.batch_size, top_k))
End-to-end runner test: init → set_corpus → retrieve. Pins down the inference contract.
What we've learned
The Grok transformer is mostly a standard modern LLM transformer:
- RMSNorm instead of LayerNorm.
- RoPE for position encoding.
- GeGLU FFN with SwiGLU-style * 2 // 3 scaling.
- GQA (num_q_heads >= num_kv_heads).
- Tanh-clamping on attention logits for stability.
- Double layer norm per layer.
- Causal attention by default.
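A NumPy sketch of grouped-query attention as it is usually formulated: the query heads are split into num_kv_heads groups, and every query head in a group shares one key/value head. The names and the [T, heads, d] layout are illustrative assumptions, not grok.py's code; masking, RoPE and tanh-clamping are left out.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def gqa_sketch(q, k, v):
    """q: [T, Hq, d]; k, v: [T, Hkv, d] with Hq a multiple of Hkv. Returns [T, Hq, d]."""
    T, Hq, d = q.shape
    Hkv = k.shape[1]
    assert Hq % Hkv == 0, "num_q_heads must be a multiple of num_kv_heads"
    group = Hq // Hkv
    # Split the query heads into [Hkv, group] so each group lines up with one KV head.
    qg = q.reshape(T, Hkv, group, d)
    logits = np.einsum("tkgd,skd->kgts", qg, k) / np.sqrt(d)  # [Hkv, group, T, T]
    weights = softmax(logits, axis=-1)
    out = np.einsum("kgts,skd->tkgd", weights, v)              # [T, Hkv, group, d]
    return out.reshape(T, Hq, d)


rng = np.random.default_rng(0)
q = rng.normal(size=(5, 4, 8))  # 5 tokens, 4 query heads, head dim 8
k = rng.normal(size=(5, 2, 8))  # 2 shared key heads
v = rng.normal(size=(5, 2, 8))
out = gqa_sketch(q, k, v)
```

Repeating each KV head group times and running ordinary multi-head attention gives the same result; GQA just avoids projecting and storing the duplicates. The test config earlier sets num_q_heads = num_kv_heads = 2, where GQA reduces to plain multi-head attention.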
Two recsys-specific modifications:
- make_recsys_attn_mask: the candidate isolation mask. Critical for score consistency.
- right_anchored_rope_positions: anchors the most-recent history token at a fixed position. Lets the model treat position as recency.
The recsys attention mask pattern (memorize this):
- Causal within user + history.
- Candidates see all of user + history.
- Candidates see themselves but no other candidates.
Why candidate isolation matters:
- Score for candidate X doesn't depend on which other candidates are in the batch.
- Reproducible across requests (a candidate scored alone vs. with others gets the same score).
- Cacheable — the model output for a (user, candidate) pair is invariant to other candidates.
- Required for the partial-scoring caches (Sessions 12-13).
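The first three properties can be checked numerically. This toy NumPy example (single head, identity projections, no positional encoding, mask rule re-implemented inline as an assumption mirroring the tests) places the same candidate in two different candidate sets and confirms its attention output is identical.

```python
import numpy as np


def isolation_mask(seq_len, cand_start):
    q = np.arange(seq_len)[:, None]
    k = np.arange(seq_len)[None, :]
    cross = (q >= cand_start) & (k >= cand_start) & (q != k)
    return (k <= q) & ~cross


def masked_self_attention(x, mask):
    """Single head with identity projections: softmax(x x^T / sqrt(d)) x, masked."""
    logits = x @ x.T / np.sqrt(x.shape[-1])
    logits = np.where(mask, logits, -1e30)  # blocked edges get zero weight after exp
    logits -= logits.max(axis=-1, keepdims=True)
    w = np.exp(logits)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ x


rng = np.random.default_rng(0)
user_and_history = rng.normal(size=(4, 16))  # user + 3 history tokens
cand_x = rng.normal(size=(1, 16))            # the candidate we track
others_a = rng.normal(size=(2, 16))
others_b = rng.normal(size=(5, 16))

seq_a = np.concatenate([user_and_history, cand_x, others_a])  # X first among 3
seq_b = np.concatenate([user_and_history, others_b, cand_x])  # X last among 6

out_a = masked_self_attention(seq_a, isolation_mask(len(seq_a), 4))
out_b = masked_self_attention(seq_b, isolation_mask(len(seq_b), 4))
print(np.allclose(out_a[4], out_b[-1]))  # True
```

The argument extends to every layer: history tokens never attend to candidates, so their hidden states are candidate-independent at every depth, and each candidate only ever reads those states plus its own. The demo ignores positions; in the full model, giving every candidate the same RoPE position (history_end) keeps the slot a candidate lands in from affecting its score.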
The RoPE customization fixes the "variable-length history" problem:
- Standard: index = position, so the most recent item's index grows with history length and differs from user to user.
- Right-anchored: the most recent history token always sits at history_end - 1, regardless of how full the history is.
- The model can learn that a token a given distance before history_end happened that many actions ago, consistently across users.
Candidate positions are all the same: all candidates get position history_end. They differ only in content, not position.
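Why shared, right-anchored positions give consistent recency: RoPE rotates each pair of query/key dimensions by an angle proportional to the token's position, so the rotated dot product depends only on the difference between the two positions. A NumPy sketch of textbook RoPE (the half-split layout and base 10000 are assumptions, not read from grok.py) compares a candidate scoring a key two positions back under a short and a long history.

```python
import numpy as np


def rope(x, pos, base=10000.0):
    """Rotate vector x (even length d) to position pos using the half-split RoPE layout."""
    d = x.shape[-1]
    half = d // 2
    freqs = base ** (-np.arange(half) / half)  # one frequency per 2-D pair
    angle = pos * freqs
    cos, sin = np.cos(angle), np.sin(angle)
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


rng = np.random.default_rng(0)
q, k = rng.normal(size=16), rng.normal(size=16)

# Short history: candidate at history_end = 5, key two positions back at 3.
short_score = rope(q, 5.0) @ rope(k, 3.0)
# Long history: candidate at history_end = 33, key two positions back at 31.
long_score = rope(q, 33.0) @ rope(k, 31.0)
print(np.isclose(short_score, long_score))  # True
```

Because the score sees only the offset, anchoring the history against a fixed candidate position makes "two slots before the candidate" mean the same recency for every user.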
fp32 params + bf16 forward: parameters are stored as float32, cast to bfloat16 on forward pass. Softmax always done in float32. Standard mixed-precision setup.
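A NumPy sketch of that dtype flow only. NumPy has no bfloat16, so float16 stands in as the compute dtype; note that bfloat16 keeps float32's 8-bit exponent range and float16 does not. The params dict and all names are hypothetical.

```python
import numpy as np

COMPUTE_DTYPE = np.float16  # stand-in for bfloat16


def softmax_fp32(logits):
    """Upcast to float32 before the max-subtract / exp / normalize steps."""
    logits = logits.astype(np.float32)
    logits = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(logits)
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
params = {  # master copies stay in float32
    "w_q": rng.normal(size=(16, 16)).astype(np.float32),
    "w_k": rng.normal(size=(16, 16)).astype(np.float32),
}
x = rng.normal(size=(6, 16)).astype(np.float32)

# Forward pass: cast params and activations down, run the matmuls in low precision.
xc = x.astype(COMPUTE_DTYPE)
q = xc @ params["w_q"].astype(COMPUTE_DTYPE)
k = xc @ params["w_k"].astype(COMPUTE_DTYPE)
probs = softmax_fp32(q @ k.T)  # attention weights computed in float32
```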
Zero parameter init: hk.initializers.Constant(0) everywhere. This is fine because the code is for inference: parameters get overwritten by load_model_params. For training, swap to VarianceScaling.
The double layer norm per layer: pre-norm AND post-norm. Unusual. Inherited from Grok-1.
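A NumPy sketch of the per-layer pattern, assuming the Grok-1 arrangement: normalize each sublayer's input (pre-norm) and its output before the residual add (post-norm), with RMSNorm as the norm. The sublayers are hypothetical linear placeholders rather than real attention and GeGLU blocks, and RMSNorm's learned scale is omitted.

```python
import numpy as np


def rms_norm(x, eps=1e-5):
    """RMSNorm without the learned scale: x / sqrt(mean(x^2) + eps)."""
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)


def decoder_layer_sketch(x, attn, ffn):
    """Pre-norm AND post-norm around each sublayer, then the residual add."""
    h = x + rms_norm(attn(rms_norm(x)))  # attention block
    h = h + rms_norm(ffn(rms_norm(h)))   # feed-forward block
    return h


rng = np.random.default_rng(0)
w_a = rng.normal(size=(8, 8))
w_f = rng.normal(size=(8, 8)) * 100.0  # deliberately badly scaled sublayer
out = decoder_layer_sketch(rng.normal(size=(3, 8)), lambda t: t @ w_a, lambda t: t @ w_f)
```

One visible effect of the post-norm: whatever a sublayer outputs, its contribution to the residual stream has unit RMS, even with the 100x-scaled weights above. That is a plausible motivation for the design, not one the source states.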
Tanh-clamped attention logits: max_attn_val * tanh(logits / max_attn_val). Prevents softmax overflow in bf16. The clamp value is 30.0. Standard stability hack in modern LLMs.
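A NumPy sketch of the soft cap using the 30.0 value quoted above (the function name is hypothetical). Near zero, tanh(z) ≈ z, so small logits pass through almost unchanged, while any logit, however large, ends up within ±30.

```python
import numpy as np


def soft_cap(logits, max_attn_val=30.0):
    """Smoothly squash logits into [-max_attn_val, max_attn_val]."""
    return max_attn_val * np.tanh(logits / max_attn_val)


logits = np.array([0.5, 5.0, 30.0, 300.0, -3000.0])
print(np.round(soft_cap(logits), 3))
```

Unlike a hard clip, tanh is strictly increasing and has a nonzero gradient everywhere, so the ordering of logits is preserved and training signal still flows through large values.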
GeGLU FFN with iso-parameter scaling: ffn_size = (widening_factor * emb_size) * 2 / 3, rounded up to a multiple of 8. The * 2 / 3 factor keeps the three-projection GeGLU at roughly the same param count as a vanilla two-projection FFN.
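The iso-parameter arithmetic, checked in NumPy. ffn_size_sketch follows the formula above (integer division, then round up to a multiple of 8); the GeGLU forward uses the standard formulation, gelu(x W_gate) * (x W_value) projected back down, with the tanh GELU approximation. All names are hypothetical, not grok.py's.

```python
import numpy as np


def ffn_size_sketch(emb_size, widening_factor):
    size = int(widening_factor * emb_size) * 2 // 3  # 2/3 offsets the third matrix
    return size + (-size) % 8                        # round up to a multiple of 8


def gelu_tanh(x):
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def geglu_ffn(x, w_gate, w_value, w_out):
    """GeGLU: gelu(x @ w_gate) gates (x @ w_value), then project back down."""
    return (gelu_tanh(x @ w_gate) * (x @ w_value)) @ w_out


D, wf = 64, 2                             # the test config's emb_size and widening_factor
F = ffn_size_sketch(D, wf)                # 128 * 2 // 3 = 85, rounded up to 88
geglu_params = 3 * D * F                  # gate, value and output matrices
vanilla_params = 2 * D * (wf * D)         # up + down projections at width 128
print(F, geglu_params, vanilla_params)    # 88 16896 16384
```

The small gap here comes entirely from rounding 85 up to 88. With illustrative values emb_size=6144 and widening_factor=8, ffn_size is exactly 32768 and the two counts match.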
Test-driven specs: the test files act as the spec for make_recsys_attn_mask and right_anchored_rope_positions. If you were re-implementing in another framework, these tests are what you'd target. The expected-matrix test in test_full_mask_structure is the most useful — it shows the exact desired output.
Two model modes for CandidateTower:
- enable_linear_proj=True (default): 2-layer MLP with SiLU + L2-norm.
- enable_linear_proj=False: mean-pool + L2-norm. Zero learned parameters (the test confirms this). Useful when you want minimal model size or when the hash embeddings are already expressive enough.
Sorted top-K invariant: the retrieval test confirms scores[:-1] >= scores[1:] — output indices are returned in descending score order. The caller can trust the ordering.
Runner-level tests: TestRetrievalInferenceRunner validates the full init → set_corpus → retrieve pipeline. The tests are basically usage examples — copy them to start using the runner.
Phoenix tour complete
That's the whole Phoenix Python codebase. We've now read:
- Session 15: recsys_model.py + recsys_retrieval_model.py (model architecture).
- Session 16: runners.py + run_pipeline.py + run_ranker.py + run_retrieval.py (inference machinery + end-to-end demo).
- Session 17 (this one): grok.py + tests (the transformer + behavior specs).
Series progress:
| Sessions | Module | LOC |
|---|---|---|
| 01 | candidate-pipeline | 1,031 |
| 02-03 | thunder | 1,808 |
| 04-14 | home-mixer | ~11,700 |
| 15-17 | phoenix | 3,880 |
| Total | | ~18,419 |
5 sessions left, all in Grox (the LLM-driven content classification pipeline).
Next session
Session 18 — Grox core (~900 LOC). The orchestration layer for the LLM-driven content pipeline:
- grox/__init__.py, dispatcher.py, engine.py, main.py: the runtime
- grox/schedules/*: schedule + context management
- grox/generators/*: task and stream generators
- grox/lib/*: utilities
This is the start of the content-understanding pipeline that produces the safety labels and topic classifications we saw consumed throughout the home-mixer Rust code.