X For You algorithm, line by line — Part 17: Grok transformer + tests

Part 17 — the final Phoenix session. The Grok-1-derived transformer that powers both ranking and retrieval: candidate isolation attention mask, right-anchored RoPE positions, GQA with tanh-clamping, GeGLU feed-forward, the double-layer-norm DecoderLayer, plus the test suites that pin down the most subtle pieces.

May 15, 2026·31 min read

The last Phoenix session. Three files: the actual Grok-1-derived transformer code that does the heavy lifting in both ranking and retrieval, plus the test suites that pin down the behavior of the critical pieces.

Files covered (1,342 LOC):

phoenix/
├── grok.py                          (616)  Transformer + attention + RoPE + recsys mask
├── test_recsys_model.py             (309)  attention mask + RoPE + bucketing tests
└── test_recsys_retrieval_model.py   (417)  CandidateTower + retrieval model + runner tests

The README mentions: "The transformer implementation is ported from the Grok-1 open source release by xAI, adapted for recommendation system use cases." This file is what they mean by "adapted" — same transformer architecture as Grok-1 but with two key recsys-specific modifications:

  1. make_recsys_attn_mask — the candidate-isolation attention mask we keep referencing.
  2. right_anchored_rope_positions — RoPE positions that anchor the most-recent history token at a fixed index.

We'll spend most time on those, plus the standard transformer machinery.


grok.py (616 lines)

Imports + types

import logging
from dataclasses import dataclass
from typing import NamedTuple, Optional, Sequence, Union

import haiku as hk
import jax
import jax.numpy as jnp

logger = logging.getLogger(__name__)


class TrainingState(NamedTuple):
    """Container for the training state."""

    params: hk.Params

TrainingState is what runners.py (Session 16) wraps loaded params in. Standard Haiku pattern: a (params,) named tuple. Could carry optimizer state too if doing training; here just params.

ffn_size — round up FFN size to multiple of 8

def ffn_size(emb_size, widening_factor):
    _ffn_size = int(widening_factor * emb_size) * 2 // 3
    _ffn_size = _ffn_size + (8 - _ffn_size) % 8  # ensure it's a multiple of 8
    logger.debug(f"emd_size: {emb_size} adjusted ffn_size: {_ffn_size}")
    return _ffn_size

The FFN (feed-forward network) hidden dimension. The * 2 // 3 is the gated-FFN adjustment: a gated FFN such as SwiGLU (or the GeGLU used in this file) has 3 linear projections (gate, up, down), so to keep the parameter count comparable to a standard 2-projection FFN at widening_factor * emb_size, you scale the hidden size down by 2/3.

Round up to multiple of 8 for hardware efficiency — matrix-multiply hardware likes dimensions divisible by 8 (or 16, or 64) for tensor-core utilization.

For emb_size=256, widening_factor=2.0: int(512) * 2 // 3 = 341, rounded to 344.
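As a quick check of the arithmetic above, here is a minimal sketch that repeats the quoted function without the logging call. The extra (emb_size, widening_factor) pairs are illustrative assumptions, not values taken from the repository.

```python
def ffn_size(emb_size: int, widening_factor: float) -> int:
    """Same arithmetic as the quoted function, minus the debug log."""
    size = int(widening_factor * emb_size) * 2 // 3  # gated-FFN 2/3 scaling
    return size + (8 - size) % 8                     # round up to a multiple of 8


# The first pair is the example from the text; the others are made up.
for emb, wf in [(256, 2.0), (256, 4.0), (512, 4.0)]:
    print(emb, wf, ffn_size(emb, wf))
```

The (8 - size) % 8 term is zero whenever size is already a multiple of 8, so the function never rounds up by a full extra block.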

make_recsys_attn_mask — the critical recsys modification

def make_recsys_attn_mask(
    seq_len: int,
    candidate_start_offset: int,
    dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
    """Create attention mask for recommendation system inference.

    Creates a mask where:
    - Positions 0 to candidate_start_offset-1 (user+history): causal attention
    - Positions candidate_start_offset onwards (candidates): can attend to user+history
      and themselves (self-attention), but NOT to other candidates

    This ensures each candidate is scored independently based on user+history context.

    Args:
        seq_len: Total sequence length (user + history + candidates)
        candidate_start_offset: Position where candidates start in the sequence
        dtype: Data type for the mask

    Returns:
        Attention mask of shape [1, 1, seq_len, seq_len] where 1 means "can attend"
    """
    # Start with causal mask for the full sequence
    causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len), dtype=dtype))

    # Zero out candidate-to-candidate attention (bottom-right block)
    attn_mask = causal_mask.at[:, :, candidate_start_offset:, candidate_start_offset:].set(0)

    # Add back self-attention for candidates (diagonal of the candidate block)
    candidate_indices = jnp.arange(candidate_start_offset, seq_len)
    attn_mask = attn_mask.at[:, :, candidate_indices, candidate_indices].set(1)

    return attn_mask

This is the most important function in the file for recsys correctness.

The mask construction:

  1. Start with full causal mask: tril (lower triangle) gives the classic "each token can attend to itself and previous tokens."
  2. Zero out the candidate block: causal_mask.at[:, :, c:, c:].set(0), with c = candidate_start_offset. This removes ALL candidate-to-candidate attention (including each candidate attending to itself, which we'll restore).
  3. Add back the diagonal in the candidate block: attn_mask.at[:, :, indices, indices].set(1) restores self-attention for each candidate.

The .at[indices].set(value) pattern is JAX's functional update syntax (since JAX arrays are immutable).

Final mask structure for a sequence [user, h1, h2, c1, c2, c3]:

       user  h1   h2   c1   c2   c3
user    1    0    0    0    0    0    ← causal: only sees self
h1      1    1    0    0    0    0    ← causal: sees user + self
h2      1    1    1    0    0    0    ← causal: sees user, h1, self
c1      1    1    1    1    0    0    ← sees user+history + self only
c2      1    1    1    0    1    0    ← sees user+history + self only
c3      1    1    1    0    0    1    ← sees user+history + self only

(This is the exact expected matrix from test_full_mask_structure we'll read in the test file.)
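The three construction steps can be reproduced without JAX. The sketch below is a plain-Python stand-in (the name recsys_mask is hypothetical) that builds the same 0/1 pattern as nested lists and prints the matrix for the six-token example.

```python
def recsys_mask(seq_len: int, candidate_start: int) -> list:
    """Plain-Python equivalent of: tril, zero the candidate block, restore its diagonal."""
    # Step 1: causal (lower-triangular) mask.
    mask = [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]
    # Steps 2 and 3: inside the candidate block keep only the diagonal.
    for i in range(candidate_start, seq_len):
        for j in range(candidate_start, seq_len):
            mask[i][j] = 1 if i == j else 0
    return mask


labels = ["user", "h1", "h2", "c1", "c2", "c3"]
for name, row in zip(labels, recsys_mask(6, 3)):
    print(f"{name:>4}", row)
```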

Why this design?

  • No candidate-to-candidate attention ⇒ each candidate's representation depends only on (user + history + itself). Two candidates in the same batch don't influence each other's scores.
  • Self-attention for each candidate ⇒ the candidate's own key and value take part in its attention output. Without the diagonal, a candidate's attention output could only mix user + history values; its own content would enter through the query and the residual connection alone.
  • Causal mask for user + history ⇒ standard transformer causality. Older history doesn't see newer history. Because the user prefix sits at position 0, every later position can attend to it.

The README highlights this:

Candidate Isolation in Ranking: During transformer inference, candidates cannot attend to each other—only to the user context. This ensures the score for a post doesn't depend on which other posts are in the batch, making scores consistent and cacheable.
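To see the isolation property in action, here is a toy sketch: single-head attention with identity Q/K/V projections over tiny hand-written vectors. Everything in it (attend, recsys_mask, the numbers) is an illustrative assumption rather than the Phoenix code. It scores the same target candidate next to two different neighbours, once with the isolation mask and once with a plain causal mask.

```python
import math


def recsys_mask(seq_len, candidate_start):
    mask = [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]
    for i in range(candidate_start, seq_len):
        for j in range(candidate_start, seq_len):
            mask[i][j] = 1 if i == j else 0
    return mask


def attend(tokens, mask):
    """Toy single-head attention with identity Q/K/V projections."""
    outputs = []
    for i, query in enumerate(tokens):
        logits = [
            sum(a * b for a, b in zip(query, key)) if mask[i][j] else -1e30
            for j, key in enumerate(tokens)
        ]
        top = max(logits)
        weights = [math.exp(logit - top) for logit in logits]
        total = sum(weights)
        outputs.append([
            sum(w / total * value[d] for w, value in zip(weights, tokens))
            for d in range(len(query))
        ])
    return outputs


context = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]  # user, h1, h2
target = [0.3, -0.2]                              # the candidate being scored
batch_a = context + [[0.9, 0.9], target]
batch_b = context + [[-2.0, 4.0], target]         # same target, different neighbour

isolated = recsys_mask(5, 3)
causal = [[1 if j <= i else 0 for j in range(5)] for i in range(5)]

iso_a, iso_b = attend(batch_a, isolated)[4], attend(batch_b, isolated)[4]
cau_a, cau_b = attend(batch_a, causal)[4], attend(batch_b, causal)[4]
print("isolated mask, same output:", iso_a == iso_b)
print("causal mask, same output:  ", cau_a == cau_b)
```

With the isolation mask the target's output does not change when its neighbour changes; with the plain causal mask it does. The same argument should carry through a stack of layers, since user and history rows never attend to candidates either.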

MHAOutput, DecoderOutput, TransformerOutput

class MHAOutput(NamedTuple):
    """Outputs of the multi-head attention operation."""

    embeddings: jax.Array


class DecoderOutput(NamedTuple):
    embeddings: jax.Array


class TransformerOutput(NamedTuple):
    embeddings: jax.Array

Three single-field named tuples. Why a named tuple for a single field? Future-proofing — if you later add attention_weights or intermediate_activations, you don't break callers.

right_anchored_rope_positions — the second recsys modification

def right_anchored_rope_positions(
    padding_mask: jax.Array,
    history_seq_len: int,
    num_user_prefix_tokens: int,
) -> jax.Array:
    """Compute RoPE positions where the newest history token always gets a fixed position."""
    history_start = num_user_prefix_tokens
    history_end = num_user_prefix_tokens + history_seq_len

    idx = jnp.arange(padding_mask.shape[1], dtype=jnp.int32)[None, :]
    history_len = padding_mask[:, history_start:history_end].sum(axis=1, dtype=jnp.int32)

    positions = jnp.where(
        (history_start <= idx) & (idx < history_end),
        history_end - history_len[:, None] + idx - history_start,
        idx,
    )

    positions = jnp.where(idx >= history_end, history_end, positions)
    positions = jnp.where(padding_mask, positions, 0).astype(jnp.float32)

    return positions

Right-anchored RoPE positions — fixes the "history of variable length" problem.

The standard RoPE assigns position 0 to the first token, 1 to the second, etc. But for history with variable length, this means two users with different history lengths get different position indices for the same age-of-event. User A with 50 history items has their most recent action at position 50; User B with 100 has it at position 100.

This breaks the model's ability to interpret position as recency. Fix: right-anchor the history, so the most recent history token always lands on the same position (history_end - 1), with candidates placed right after it at history_end.

Walking the code:

  • history_start, history_end — the range of indices in the sequence where history lives.
  • idx: [0, 1, 2, ..., T-1] shaped [1, T].
  • history_len[b] = how many valid history items user b has.
  • For positions in the history range:
    • Compute history_end - history_len + idx - history_start. Since history_end - history_start = history_seq_len, this equals idx + (history_seq_len - history_len): every valid history token is shifted right by the number of unused history slots.

The formula only anchors the newest item if the valid history is packed at the left of the history block, with padding on the right. A setup with num_user_prefix_tokens=1 and history_seq_len=10 gives history_start=1 and history_end=11. A user with history_len=4 has valid items at sequence indices 1 to 4, and the shift is 10 - 4 = 6:

  • For index 1 (the first valid item): position = 11 - 4 + 1 - 1 = 7.
  • For index 4 (the last valid item): position = 11 - 4 + 4 - 1 = 10 = history_end - 1.

The test test_padding_gets_zero uses exactly this left-packed layout:

padding_mask = jnp.array([[True, True, True, True, False, False, False, False]])
positions = right_anchored_rope_positions(
    padding_mask, history_seq_len=4, num_user_prefix_tokens=1
)
for i in range(4, 8):
    assert float(positions[0, i]) == 0.0

So padding_mask is [T, T, T, T, F, F, F, F]. The valid items are at the LEFT (indices 0-3). Index 0 is the user prefix. Indices 1-3 are valid history. Index 4 is an empty history slot, and indices 5-7 are padded candidate slots.

For this test: history_len = padding_mask[:, 1:5].sum() = 3 (indices 1, 2, 3 are True).

For idx = 1: position = 5 - 3 + 1 - 1 = 2. For idx = 2: position = 5 - 3 + 2 - 1 = 3. For idx = 3: position = 5 - 3 + 3 - 1 = 4.

So valid history items get positions 2, 3, 4. The newest valid history item gets position history_end - 1 = 4.

The intent: anchor the most recent history token at exactly history_end - 1, with older history at smaller position indices. With valid items packed from history_start, the formula maps:

  • index history_start (the first valid item) → position history_end - history_len.
  • index history_start + history_len - 1 (the last valid item) → position history_end - 1.

So a shorter history doesn't shift the end position: the newest item always lands on the same position, and the model sees "recency" consistently.

    positions = jnp.where(idx >= history_end, history_end, positions)
    positions = jnp.where(padding_mask, positions, 0).astype(jnp.float32)

    return positions

Final adjustments:

  • Candidates (idx >= history_end) → all get position history_end. So all candidates share the same position. The RoPE encoding for candidates is uniform; they only differ from each other via their input content, not position.
  • Padding → position 0. The padded slots don't affect attention (the padding mask zeros their contribution), but giving them position 0 keeps RoPE math well-defined.
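The whole function can be mirrored in plain Python for a single sequence. This is a sketch, not the JAX code: the name right_anchored_positions and the two example users are assumptions, while the last call reuses the mask from test_padding_gets_zero.

```python
def right_anchored_positions(padding_mask, history_seq_len, num_user_prefix_tokens):
    """Plain-Python version of the quoted function for one sequence (no batch axis)."""
    start = num_user_prefix_tokens
    end = start + history_seq_len
    history_len = sum(padding_mask[start:end])  # number of valid history slots
    positions = []
    for idx, valid in enumerate(padding_mask):
        if not valid:
            pos = 0                              # padding always gets position 0
        elif start <= idx < end:
            pos = end - history_len + idx - start  # history, shifted right
        elif idx >= end:
            pos = end                            # all candidates share one position
        else:
            pos = idx                            # user prefix keeps its index
        positions.append(float(pos))
    return positions


# Hypothetical layout: 1 user token, 4 history slots, 2 candidates.
full_history = [True, True, True, True, True, True, True]
short_history = [True, True, True, False, False, True, True]
print(right_anchored_positions(full_history, 4, 1))
print(right_anchored_positions(short_history, 4, 1))

# Mask from test_padding_gets_zero.
print(right_anchored_positions([True] * 4 + [False] * 4, 4, 1))
```

In both hypothetical users the newest valid history token ends up at position 4 (history_end - 1) and the candidates at position 5, regardless of how many history slots are filled.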

TransformerConfig

@dataclass
class TransformerConfig:
    emb_size: int
    key_size: int
    num_q_heads: int
    num_kv_heads: int
    num_layers: int
    widening_factor: float = 4.0

    attn_output_multiplier: float = 1.0

    name: Optional[str] = None

    def make(self) -> "Transformer":
        return Transformer(
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            widening_factor=self.widening_factor,
            key_size=self.key_size,
            attn_output_multiplier=self.attn_output_multiplier,
            num_layers=self.num_layers,
        )

Standard transformer config. Notable:

  • num_q_heads vs num_kv_heads: supports grouped-query attention (GQA). Q has more heads than K/V, and all query heads within a group share the same K/V head, which saves K/V memory.
  • attn_output_multiplier: scales the attention logits before the tanh clamp and the softmax (see below).
  • widening_factor = 4.0 default. Combined with * 2 // 3 in ffn_size, this gives an FFN dim of about 2.67 * emb_size.

hk_rms_norm, Linear, RMSNorm

def hk_rms_norm(
    x: jax.Array,
    fixed_scale=False,
) -> jax.Array:
    """Applies a unique LayerNorm to x with default settings."""
    ln = RMSNorm(axis=-1, create_scale=not fixed_scale)
    return ln(x)


class Linear(hk.Linear):
    def __init__(
        self,
        output_size: int,
        with_bias: bool = True,
        name: Optional[str] = None,
    ):
        super().__init__(
            output_size=output_size,
            with_bias=with_bias,
            name=name,
        )

    def __call__(  # type: ignore
        self,
        inputs: jax.Array,
    ) -> jax.Array:
        """Computes a linear transform of the input."""

        fprop_dtype = inputs.dtype
        if not inputs.shape:
            raise ValueError("Input must not be scalar.")

        input_size = inputs.shape[-1]
        output_size = self.output_size

        w = hk.get_parameter(
            "w", [input_size, output_size], jnp.float32, init=hk.initializers.Constant(0)
        )

        out = jnp.dot(inputs, w.astype(fprop_dtype))
        if self.with_bias:
            b = hk.get_parameter(
                "b", [self.output_size], jnp.float32, init=hk.initializers.Constant(0)
            )
            b = jnp.broadcast_to(b, out.shape)
            out = out + b.astype(fprop_dtype)

        return out

A custom Linear that overrides Haiku's. Two reasons:

  1. fp32 parameters, bf16 forward: params stored as jnp.float32, cast to fprop_dtype on use. Keeps gradient updates precise while inference runs in bf16.
  2. Zero initialization: init=hk.initializers.Constant(0). Stock hk.Linear defaults to a truncated normal scaled by 1/sqrt(input_size). Why zero? Most likely because this code is for inference: the params will be overwritten by load_model_params, so zero init is just a shape-correct placeholder.

If you wanted to train from this code, you'd need to swap the initializer.

class RMSNorm(hk.RMSNorm):
    def __init__(
        self,
        axis: Union[int, Sequence[int], slice],
        eps: float = 1e-5,
        name: Optional[str] = None,
        create_scale: bool = True,
    ):
        super().__init__(axis, eps, create_scale=create_scale, name=name)

    def __call__(self, inputs: jax.Array):
        fprop_dtype = inputs.dtype
        param_shape = (inputs.shape[-1],)
        if self.create_scale:
            scale = hk.get_parameter(
                "scale",
                param_shape,
                dtype=jnp.float32,
                init=hk.initializers.Constant(0),
            )
            scale = jnp.broadcast_to(scale.astype(jnp.float32), inputs.shape)
        else:
            scale = 1.0
        inputs = inputs.astype(jnp.float32)
        scale = jnp.float32(scale)
        mean_squared = jnp.mean(jnp.square(inputs), axis=[-1], keepdims=True)
        mean_squared = jnp.broadcast_to(mean_squared, inputs.shape)

        normed_inputs = inputs * jax.lax.rsqrt(mean_squared + self.eps)

        outputs = scale * normed_inputs

        return outputs.astype(fprop_dtype)

RMSNorm (Root Mean Square Layer Normalization). A common modern alternative to LayerNorm: simpler (no mean subtraction, no bias), cheaper, and widely reported to work comparably in transformers.

The math:

  • mean_squared = mean(x²)
  • normed = x / sqrt(mean_squared + eps)
  • output = scale * normed

Same fp32-params + bf16-output pattern as Linear. The scale parameter is per-feature ([D]).

Note init=hk.initializers.Constant(0) — zero init for the scale. Again, placeholder for loaded checkpoints. A trained checkpoint would have scales ~1.0 (or learned values).
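The three lines of math fit in a few lines of plain Python. The sketch below (rms_norm is a hypothetical stand-in operating on one feature vector) also shows why the zero-initialised scale is only a placeholder: until real scales are loaded, the output is all zeros.

```python
import math


def rms_norm(x, scale=None, eps=1e-5):
    """RMSNorm over one feature vector; scale=None means no learned scale."""
    mean_squared = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_squared + eps)
    if scale is None:
        scale = [1.0] * len(x)
    return [s * v * inv_rms for s, v in zip(scale, x)]


x = [3.0, -4.0, 0.0, 5.0]
y = rms_norm(x)
rms_of_y = math.sqrt(sum(v * v for v in y) / len(y))
print(y)
print(rms_of_y)                           # close to 1 after normalisation
print(rms_norm(x, scale=[0.0] * len(x)))  # zero scale: every output is zero
```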

rotate_half + RotaryEmbedding

def rotate_half(
    x: jax.Array,
) -> jax.Array:
    """Obtain the rotated counterpart of each feature"""
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate((-x2, x1), axis=-1)


class RotaryEmbedding(hk.Module):
    """Applies rotary embeddings (RoPE) to the input sequence tensor,
    as described in https://arxiv.org/abs/2104.09864.

    Attributes:
        dim (int): Dimensionality of the feature vectors
        base_exponent (int): Base exponent to compute embeddings from
    """

    def __init__(
        self,
        dim: int,
        name: Optional[str] = None,
        base_exponent: int = 10000,
    ):
        super().__init__(name)
        self.dim = dim
        self.base_exponent = base_exponent
        assert self.dim % 2 == 0

RoPE (Rotary Position Embedding). Standard for modern transformers. Treats each pair of features as a 2D vector and rotates it by an angle proportional to the position.

rotate_half swaps and negates: (x1, x2) → (-x2, x1). This is the 90° rotation in the 2D plane spanned by the two feature halves.

base_exponent = 10000 is the standard RoPE base.

    def __call__(
        self,
        x: jax.Array,
        seq_dim: int,
        offset: jax.Array,
        const_position: Optional[int] = None,
        t: Optional[jax.Array] = None,
    ) -> jax.Array:
        fprop_dtype = x.dtype
        # Compute the per-dimension frequencies
        exponents = jnp.arange(0, self.dim, 2, dtype=jnp.float32)
        inv_freq = jnp.asarray(
            1.0 / (self.base_exponent ** (exponents / self.dim)), dtype=jnp.float32
        )

        if jnp.shape(offset) == ():
            # Offset can be a scalar or one offset per batch element.
            offset = jnp.expand_dims(offset, 0)

        # Compute the per element phase (to pass into sin and cos)
        if const_position:
            t = const_position * jnp.ones(
                (
                    1,
                    x.shape[seq_dim],
                ),
                dtype=jnp.float32,
            )
        elif t is None:
            t = jnp.arange(x.shape[seq_dim], dtype=jnp.float32) + jnp.expand_dims(offset, -1)
        phase = jnp.einsum("bi,j->bij", t, inv_freq)
        phase = jnp.tile(phase, reps=(1, 2))[:, :, None, :]

        x = x * jnp.cos(phase) + rotate_half(x) * jnp.sin(phase)
        x = x.astype(fprop_dtype)

        return x

The RoPE math:

  1. inv_freq[i] = 1 / 10000^(2i/dim) for i in [0, dim/2). Lower frequencies for higher dimensions — so low-dim features rotate fast (capture short-range patterns), high-dim features rotate slow (capture long-range patterns).
  2. t = positions — either constant, custom (t param), or sequential (arange + offset).
  3. phase = t * inv_freq — outer product. Shape [batch, seq_len, dim/2].
  4. phase = tile(phase, (1, 2)): duplicate along the last axis so the phase array has shape [batch, seq_len, dim] (first half and second half match); the following [:, :, None, :] adds a head axis for broadcasting.
  5. x * cos(phase) + rotate_half(x) * sin(phase) — apply the 2D rotation in the feature plane.

The t parameter is what the recsys code uses: pass right-anchored positions, get position-aware encoding that respects the right-anchor.
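A small numeric check can make the rotation concrete. The sketch below is a plain-Python stand-in for one feature vector (rope and dot are hypothetical helpers), using the same half-split pairing as rotate_half. It illustrates the property that makes custom positions meaningful: the query-key dot product depends only on the difference between the two positions.

```python
import math


def rope(x, position, base=10000.0):
    """Rotate one feature vector, pairing x[i] with x[i + dim/2] as rotate_half does."""
    dim = len(x)
    half = dim // 2
    out = [0.0] * dim
    for i in range(half):
        angle = position / (base ** (2 * i / dim))  # position * inv_freq[i]
        cos, sin = math.cos(angle), math.sin(angle)
        a, b = x[i], x[i + half]
        out[i] = a * cos - b * sin          # x1 * cos + (-x2) * sin
        out[i + half] = b * cos + a * sin   # x2 * cos + x1 * sin
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


q = [0.5, -1.0, 2.0, 0.25]
k = [1.5, 0.5, -0.5, 1.0]
near = dot(rope(q, 3), rope(k, 1))      # positions 3 and 1, distance 2
far = dot(rope(q, 103), rope(k, 101))   # positions 103 and 101, distance 2
other = dot(rope(q, 3), rope(k, 2))     # distance 1
print(near, far, other)
```

near and far agree up to floating-point error because both pairs are two positions apart, while other differs. With right-anchored positions, every candidate sits at history_end, so its distance to the newest history token is the same for every user.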

MultiHeadAttention

The actual attention computation. Long.

class MultiHeadAttention(hk.Module):
    def __init__(
        self,
        num_q_heads: int,
        num_kv_heads: int,
        key_size: int,
        *,
        with_bias: bool = True,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        attn_output_multiplier: float = 1.0,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.key_size = key_size
        self.value_size = value_size or key_size
        self.model_size = model_size or key_size * num_q_heads
        self.attn_output_multiplier = attn_output_multiplier
        self.with_bias = with_bias

model_size defaults to key_size * num_q_heads — the standard transformer convention.

    def __call__(
        self,
        query: jax.Array,
        key: jax.Array,
        value: jax.Array,
        mask: jax.Array,
        positions: Optional[jax.Array] = None,
    ) -> MHAOutput:
        projection = self._linear_projection

        # Check that the keys and values have consistent batch size and sequence length.
        assert key.shape[:2] == value.shape[:2], f"key/value shape: {key.shape}/{value.shape}"

        if mask is not None:
            assert mask.ndim == 4
            assert mask.shape[0] in {
                1,
                query.shape[0],
            }, f"mask/query shape: {mask.shape}/{query.shape}"
            # ...

Sanity checks. The mask must be 4D [B, H, T_q, T_k] (or broadcastable to that), key/value must agree on batch + sequence length.

        # Compute key/query/values (overload K/Q/V to denote the respective sizes).
        assert self.num_q_heads % self.num_kv_heads == 0
        query_heads = projection(query, self.key_size, self.num_q_heads, name="query")
        key_heads = projection(key, self.key_size, self.num_kv_heads, name="key")
        value_heads = projection(value, self.value_size, self.num_kv_heads, name="value")

Project to Q, K, V. Q has more heads (num_q_heads) than K, V (num_kv_heads) — GQA setup. Multiple Q heads share each K/V pair.

        rotate = RotaryEmbedding(dim=self.key_size, base_exponent=int(1e4))
        key_heads = rotate(key_heads, seq_dim=1, offset=0, t=positions)
        query_heads = rotate(query_heads, seq_dim=1, offset=0, t=positions)

Apply RoPE to Q and K but not V. Position enters attention only through the Q·K dot product, which after rotation depends on the difference between query and key positions, so V needs no rotation. Pass the custom positions if available.

        b, t, h, d = query_heads.shape
        _, _, kv_h, _ = key_heads.shape
        assert h % kv_h == 0, f"query_heads {h} must be a multiple of kv_heads {kv_h}"

        query_heads = jnp.reshape(query_heads, (b, t, kv_h, h // kv_h, d))

Reshape Q for GQA: split the h Q heads into kv_h groups of h // kv_h heads each. Now Q is [B, T, kv_h, h_per_group, d] and K, V are [B, T, kv_h, d].
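The reshape implies a simple mapping from query head to K/V head. As a minimal sketch (the helper name and the 8-query-head, 2-KV-head sizes are illustrative assumptions, not the Phoenix configuration), a row-major reshape from h heads to (kv_h, h // kv_h) puts consecutive query heads in the same group:

```python
def kv_group_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """K/V head read by a query head after the (kv_h, h // kv_h) reshape."""
    assert num_q_heads % num_kv_heads == 0
    heads_per_group = num_q_heads // num_kv_heads
    return q_head // heads_per_group


# Hypothetical sizes: 8 query heads sharing 2 K/V heads.
mapping = [kv_group_for_query_head(j, 8, 2) for j in range(8)]
print(mapping)
```

Query heads 0 to 3 read K/V head 0 and heads 4 to 7 read K/V head 1, so the K and V projections are a quarter of the size they would be with one K/V head per query head.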

        # Compute attention weights.
        # Attention softmax is always carried out in fp32.
        attn_logits = jnp.einsum("...thHd,...Thd->...hHtT", query_heads, key_heads).astype(
            jnp.float32
        )
        attn_logits *= self.attn_output_multiplier
        max_attn_val = jnp.array(30.0, dtype=attn_logits.dtype)
        attn_logits = max_attn_val * jnp.tanh(attn_logits / max_attn_val)

Compute attention logits via einsum. The string "...thHd,...Thd->...hHtT":

  • t = query seq position, T = key seq position.
  • h = kv-group, H = q-head within the group.
  • d = feature dim.
  • Output: [batch, kv_group, q_head, query_pos, key_pos].

Cast to fp32 — softmax precision matters.

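attn_output_multiplier is a scaling factor applied to the raw logits; in the quoted code it is the only scaling they receive, so it plays the role of the usual 1/sqrt(key_size) factor. Then tanh-clamping: max_attn_val * tanh(logits / max_attn_val) squashes logits into (-max_attn_val, max_attn_val), here (-30, 30). Small logits pass through almost unchanged, while very large ones saturate instead of dominating the softmax. This kind of logit soft-capping is a stability technique inherited from the Grok-1 code.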
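To get a feel for the clamp, here is a short sketch with a scalar helper (soft_clamp is a hypothetical name; the 30.0 default matches max_attn_val in the quoted code, and the sample logits are made up).

```python
import math


def soft_clamp(logit: float, max_val: float = 30.0) -> float:
    """max_val * tanh(logit / max_val): near-identity for small inputs, bounded for large."""
    return max_val * math.tanh(logit / max_val)


for logit in [0.5, 5.0, 30.0, 100.0, -1000.0]:
    print(logit, round(soft_clamp(logit), 4))
```

Unlike a hard clip, the function stays smooth and strictly increasing over the range where logits usually live, so the ordering of moderate logits is preserved.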

        mask = mask[:, :, None, :, :]

        if mask is not None:
            if mask.ndim != attn_logits.ndim:
                raise ValueError(
                    f"Mask dimensionality {mask.ndim} must match logits dimensionality "
                    f"{attn_logits.ndim} for {mask.shape}/{attn_logits.shape}."
                )
            attn_logits = jnp.where(mask, attn_logits, -1e30)
        attn_weights = jax.nn.softmax(attn_logits).astype(query.dtype)  # [H, T', T]

Apply mask: where(mask, logits, -1e30). Where the mask is 0, the logit is set to -1e30, which is effectively -∞ going into softmax, so masked positions get zero probability.

The mask[:, :, None, :, :] indexing inserts an axis for the query-head-within-group dimension so the mask broadcasts across heads.

Cast back to query dtype (bf16) after softmax.
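The masking step is easy to try on a single row of logits. A minimal plain-Python sketch (masked_softmax and the sample numbers are assumptions for illustration):

```python
import math


def masked_softmax(logits, mask, fill=-1e30):
    """Softmax over one row, with masked-out entries replaced by a large negative fill."""
    masked = [logit if keep else fill for logit, keep in zip(logits, mask)]
    top = max(masked)
    exps = [math.exp(v - top) for v in masked]  # exp of a huge negative underflows to 0.0
    total = sum(exps)
    return [e / total for e in exps]


weights = masked_softmax([2.0, 1.0, 3.0, 0.5], [1, 1, 0, 0])
print(weights)
print(masked_softmax([1.0, 2.0], [0, 0]))  # fully masked row
```

Masked entries come out as exactly 0.0. In this toy version a fully masked row yields uniform weights rather than NaN, which is one practical reason to prefer a large finite fill over a true -inf.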

        # Weight the values by the attention and flatten the head vectors.
        attn = jnp.einsum("...hHtT,...Thd->...thHd", attn_weights, value_heads)
        leading_dims = attn.shape[:2]
        attn = jnp.reshape(attn, (*leading_dims, -1))  # [T', H*V]

        # Apply another projection to get the final embeddings.
        final_projection = Linear(self.model_size, with_bias=False)
        return MHAOutput(final_projection(attn))

Apply attention weights to V, flatten heads, project to model_size.

The output projection uses Linear with with_bias=False — biases on attention outputs are skipped (common modern convention).

MHABlock, DenseBlock, DecoderLayer

@dataclass
class MHABlock(hk.Module):
    """A MHA Block"""

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    attn_output_multiplier: float = 1.0

    @hk.transparent
    def __call__(
        self,
        inputs: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, 1, T, T] or [B, 1, 1, T] or B[1, 1, 1, 1]
        positions: Optional[jax.Array] = None,
    ) -> MHAOutput:
        _, _, model_size = inputs.shape
        assert mask.ndim == 4, f"shape: {mask.shape}"
        assert mask.shape[2] in {1, inputs.shape[1]}, str(mask.shape)
        assert mask.shape[3] in {1, inputs.shape[1]}, str(mask.shape)
        side_input = inputs

        def attn_block(query, key, value, mask) -> MHAOutput:
            return MultiHeadAttention(
                num_q_heads=self.num_q_heads,
                num_kv_heads=self.num_kv_heads,
                key_size=self.key_size,
                model_size=model_size,
                attn_output_multiplier=self.attn_output_multiplier,
            )(query, key, value, mask, positions=positions)

        attn_output = attn_block(inputs, side_input, side_input, mask)
        h_attn = attn_output.embeddings

        return MHAOutput(embeddings=h_attn)

A thin wrapper that uses inputs for all three of Q/K/V: self-attention. The side_input = inputs makes it explicit.

@hk.transparent — Haiku decorator that makes the module's parameters live in the enclosing module's namespace (instead of creating a sub-namespace). Lets the calling layer have a flatter parameter tree.

@dataclass
class DenseBlock(hk.Module):
    num_q_heads: int
    num_kv_heads: int
    key_size: int
    widening_factor: float = 4.0

    @hk.transparent
    def __call__(
        self,
        inputs: jax.Array,  # [B, T, D]
    ) -> jax.Array:  # [B, T, D]
        _, _, model_size = inputs.shape
        h_v = Linear(
            ffn_size(model_size, self.widening_factor),
            with_bias=False,
            name="linear_v",
        )(inputs)
        h_w1 = jax.nn.gelu(
            Linear(
                ffn_size(model_size, self.widening_factor),
                with_bias=False,
            )(inputs)
        )
        h_dense = Linear(model_size, with_bias=False)(h_w1 * h_v)

        return h_dense

The FFN. This is GeGLU (Gated GELU): two parallel projections, one passed through GELU, then element-wise multiplied. Then projected back to model_size.

The math: out = Linear_down(GELU(Linear_up_1(x)) * Linear_up_2(x)). Three linear layers (gate, up, down), all without bias. Gated FFNs of this kind are generally reported to outperform a vanilla FFN at the same parameter count.

The * 2 // 3 adjustment in ffn_size keeps this iso-parameter with a vanilla widening_factor * D FFN.
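The iso-parameter claim can be checked with simple counting. The sketch below assumes bias-free layers (as in the quoted DenseBlock) and uses made-up model sizes; vanilla_ffn_params and gated_ffn_params are hypothetical helpers.

```python
def ffn_size(emb_size: int, widening_factor: float) -> int:
    size = int(widening_factor * emb_size) * 2 // 3
    return size + (8 - size) % 8


def vanilla_ffn_params(d: int, widening_factor: float) -> int:
    """Two bias-free projections: d -> wf*d and wf*d -> d."""
    return 2 * d * int(widening_factor * d)


def gated_ffn_params(d: int, widening_factor: float) -> int:
    """Three bias-free projections (gate, up, down) at the reduced hidden size."""
    return 3 * d * ffn_size(d, widening_factor)


for d in [256, 512]:  # illustrative sizes only
    vanilla = vanilla_ffn_params(d, 4.0)
    gated = gated_ffn_params(d, 4.0)
    print(d, vanilla, gated, round(gated / vanilla, 4))
```

The ratio lands slightly above 1 because of the round-up to a multiple of 8.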

@dataclass
class DecoderLayer(hk.Module):
    """A transformer stack."""

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    num_layers: int
    layer_index: Optional[int] = None
    widening_factor: float = 4.0
    name: Optional[str] = None
    attn_output_multiplier: float = 1.0

    def __call__(
        self,
        inputs: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, 1, T, T] or [B, 1, 1, T]
        padding_mask: Optional[jax.Array],
        positions: Optional[jax.Array] = None,
    ) -> DecoderOutput:
        """Transforms input embedding sequences to output embedding sequences."""
        del padding_mask  # Unused.

        def layer_norm(x):
            return hk_rms_norm(x)

        h = inputs

        attn_output = MHABlock(
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            key_size=self.key_size,
            attn_output_multiplier=self.attn_output_multiplier,
        )(layer_norm(h), mask, positions=positions)
        h_attn = attn_output.embeddings

        h_attn = layer_norm(h_attn)
        h += h_attn

        def base_dense_block(h):
            h = DenseBlock(
                num_q_heads=self.num_q_heads,
                num_kv_heads=self.num_kv_heads,
                key_size=self.key_size,
                widening_factor=self.widening_factor,
            )(h)
            return h

        h_dense = base_dense_block(layer_norm(h))

        h_dense = layer_norm(h_dense)
        h += h_dense

        return DecoderOutput(
            embeddings=h,
        )

One transformer layer. The structure:

inputs
↓
RMSNorm
↓
MHA (self-attention)
↓
RMSNorm    ← double layer norm: also after the attention block
↓
+ inputs   ← residual
↓
RMSNorm
↓
GeGLU FFN
↓
RMSNorm    ← double layer norm again
↓
+ previous ← residual
↓
output

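Double layer norm is unusual; most transformers normalize once per sublayer. Here each sublayer's input is normalized (pre-norm) and its output is normalized again before the residual add, a pattern sometimes called sandwich norm. This is one of the Grok-1 design choices preserved. Because hk_rms_norm builds a new RMSNorm on every call, each layer carries four separate norm scales.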

del padding_mask — the param exists in the signature for API compat but isn't used. The mask already incorporates padding.
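The residual wiring can be written out for a single vector. This is only a structural sketch: the sublayers are passed in as plain functions, the norm has no learned scale, and identity stands in for attention and the FFN, all of which are simplifying assumptions.

```python
import math


def rms_norm(x, eps=1e-5):
    inv = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv for v in x]


def decoder_layer(h, attn_fn, ffn_fn):
    """Same wiring as the quoted layer: norm, sublayer, norm, then residual add."""
    attn_out = rms_norm(attn_fn(rms_norm(h)))     # norm -> attention -> norm
    h = [a + b for a, b in zip(h, attn_out)]      # first residual add
    ffn_out = rms_norm(ffn_fn(rms_norm(h)))       # norm -> FFN -> norm
    return [a + b for a, b in zip(h, ffn_out)]    # second residual add


def identity(v):
    return v


out = decoder_layer([3.0, 4.0], identity, identity)
print(out)
```

Because each sublayer output is normalized before it is added, every residual update has an RMS close to 1 (times the learned scale in the real code), whatever the magnitude of the sublayer's raw output.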

Transformer — the full stack

@dataclass
class Transformer(hk.Module):
    """A transformer stack."""

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    widening_factor: float
    attn_output_multiplier: float
    num_layers: int
    name: Optional[str] = None

    def __call__(
        self,
        embeddings: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, T]
        candidate_start_offset: Optional[int] = None,
        positions: Optional[jax.Array] = None,
    ) -> TransformerOutput:
        # ...

        fprop_dtype = embeddings.dtype
        _, seq_len, _ = embeddings.shape
        padding_mask = mask.copy()
        mask = mask[:, None, None, :]  # [B, H=1, T'=1, T]

        if candidate_start_offset is not None:
            # Use recommendation system attention mask where candidates attend to
            # user+history and themselves, but not to other candidates
            attn_mask = make_recsys_attn_mask(seq_len, candidate_start_offset, fprop_dtype)
            mask = mask * attn_mask
        else:
            # Standard causal mask for autoregressive sequence modelling
            causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len))).astype(
                fprop_dtype
            )  # [B=1, H=1, T, T]
            mask = mask * causal_mask  # [B, H=1, T, T]

Mask construction — this is where the candidate_start_offset parameter matters.

The input mask is a 2D padding mask [B, T]. Expand to [B, 1, 1, T].

Two branches:

  • candidate_start_offset is not None (ranking): build the recsys mask we read earlier, multiply with the padding mask. The result has both the causal-with-candidate-isolation structure AND padding zeros.
  • candidate_start_offset is None (retrieval): standard causal mask. The retrieval model uses this since there are no candidates in the sequence.

Multiplying the padding mask with the structural mask combines them: a position is attendable iff (padding mask says it's valid) AND (structural mask says it's reachable).
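The broadcast multiply can be spelled out with nested lists. In this sketch (combine, recsys_mask and the five-token layout are assumptions for illustration) the padding mask acts on the key axis only, mirroring the [B, 1, 1, T] expansion.

```python
def recsys_mask(seq_len, candidate_start):
    mask = [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]
    for i in range(candidate_start, seq_len):
        for j in range(candidate_start, seq_len):
            mask[i][j] = 1 if i == j else 0
    return mask


def combine(padding_mask, structural_mask):
    """combined[i][j] = padding_mask[j] * structural_mask[i][j] (padding hides keys)."""
    n = len(padding_mask)
    return [
        [int(padding_mask[j]) * structural_mask[i][j] for j in range(n)]
        for i in range(n)
    ]


# Hypothetical layout: user, h1, empty history slot, c1, c2.
padding = [True, True, False, True, True]
combined = combine(padding, recsys_mask(5, 3))
for row in combined:
    print(row)
```

Column 2 is zeroed for every query, so nothing attends to the empty slot. The padded row itself keeps its other entries, since the padding mask is only broadcast along the key axis.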

        h = embeddings

        def block(
            h,
            mask,
            padding_mask,
            layer_index: Optional[int] = None,
            widening_factor: Optional[int] = None,
            name: Optional[str] = None,
        ) -> DecoderOutput:
            return DecoderLayer(
                num_q_heads=self.num_q_heads,
                num_kv_heads=self.num_kv_heads,
                key_size=self.key_size,
                widening_factor=widening_factor or self.widening_factor,
                num_layers=self.num_layers,
                attn_output_multiplier=self.attn_output_multiplier,
                name=name,
                layer_index=layer_index,
            )(h, mask, padding_mask, positions=positions)

        for i in range(self.num_layers):
            decoder_output = block(
                h,
                mask,
                padding_mask,
                layer_index=i,
                name=f"decoder_layer_{i}",
            )
            h = decoder_output.embeddings

        return TransformerOutput(
            embeddings=h,
        )

Stack num_layers decoder layers. Each is named decoder_layer_0, decoder_layer_1, ... — the names matter for parameter loading from checkpoints.

Sequential pass — no skipping, no early exit. Output is the final hidden states [B, T, D].

End of grok.py.


test_recsys_model.py (309 lines)

The test file for the ranking model side. Four test classes:

  1. TestMakeRecsysAttnMask — pins down the attention mask behavior.
  2. TestRightAnchoredRopePositions — RoPE positions.
  3. TestComputePostAgeBucket — age bucketing.
  4. TestNormalizeContinuousValue — normalization.

We'll quote the most important ones.

TestMakeRecsysAttnMask — the canary test

class TestMakeRecsysAttnMask:
    """Tests for the make_recsys_attn_mask function."""

    def test_output_shape(self):
        """Test that the output has the correct shape [1, 1, seq_len, seq_len]."""
        seq_len = 10
        candidate_start_offset = 5

        mask = make_recsys_attn_mask(seq_len, candidate_start_offset)

        assert mask.shape == (1, 1, seq_len, seq_len)

Basic shape test. The mask is [1, 1, seq_len, seq_len] — broadcasts across batch and heads.

    def test_user_history_has_causal_attention(self):
        """Test that user+history positions (before candidate_start_offset) have causal attention."""
        seq_len = 8
        candidate_start_offset = 5

        mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
        mask_2d = mask[0, 0]

        for i in range(candidate_start_offset):
            for j in range(candidate_start_offset):
                if j <= i:
                    assert mask_2d[i, j] == 1, f"Position {i} should attend to position {j}"
                else:
                    assert (
                        mask_2d[i, j] == 0
                    ), f"Position {i} should NOT attend to future position {j}"

For positions in [0, candidate_start_offset): standard causal triangle. j <= i attends, j > i doesn't.

    def test_candidates_attend_to_user_history(self):
        """Test that candidates can attend to all user+history positions."""
        # ...
        for candidate_pos in range(candidate_start_offset, seq_len):
            for history_pos in range(candidate_start_offset):
                assert (
                    mask_2d[candidate_pos, history_pos] == 1
                ), f"Candidate at {candidate_pos} should attend to user+history at {history_pos}"

Every candidate attends to every user+history position. Unconditional — no causality constraint on the candidate-to-history edge.

    def test_candidates_attend_to_themselves(self):
        """Test that candidates can attend to themselves (self-attention)."""
        for candidate_pos in range(candidate_start_offset, seq_len):
            assert (
                mask_2d[candidate_pos, candidate_pos] == 1
            ), f"Candidate at {candidate_pos} should attend to itself"

    def test_candidates_do_not_attend_to_other_candidates(self):
        """Test that candidates cannot attend to other candidates."""
        for query_pos in range(candidate_start_offset, seq_len):
            for key_pos in range(candidate_start_offset, seq_len):
                if query_pos != key_pos:
                    assert (
                        mask_2d[query_pos, key_pos] == 0
                    ), f"Candidate at {query_pos} should NOT attend to candidate at {key_pos}"

Within the candidate block, each candidate attends only to itself. No candidate-to-candidate cross-attention.

    def test_full_mask_structure(self):
        """Test the complete mask structure with a small example."""
        # Sequence: [user, h1, h2, c1, c2, c3]
        # Positions:  0     1   2   3   4   5

        seq_len = 6
        candidate_start_offset = 3

        mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
        mask_2d = mask[0, 0]

        expected = np.array(
            [
                [1, 0, 0, 0, 0, 0],  # user
                [1, 1, 0, 0, 0, 0],  # h1
                [1, 1, 1, 0, 0, 0],  # h2
                [1, 1, 1, 1, 0, 0],  # c1: user+history + self
                [1, 1, 1, 0, 1, 0],  # c2: user+history + self
                [1, 1, 1, 0, 0, 1],  # c3: user+history + self
            ],
            dtype=np.float32,
        )

        np.testing.assert_array_equal(
            np.array(mask_2d),
            expected,
            err_msg="Full mask structure does not match expected pattern",
        )

The canonical example. Matches the diagram from earlier in this article. This is the most readable test in the class: the full expected matrix is laid out explicitly.
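
For anyone porting the mask to another framework, here is one way to reproduce that expected matrix in NumPy. It is a sketch written against the tests quoted here, not a copy of make_recsys_attn_mask; the function name recsys_attn_mask_sketch is hypothetical.

```python
import numpy as np

def recsys_attn_mask_sketch(seq_len, candidate_start_offset, dtype=np.float32):
    """Hypothetical NumPy re-implementation targeting the quoted tests."""
    # Start from a causal lower triangle over the whole sequence.
    mask = np.tril(np.ones((seq_len, seq_len), dtype=dtype))
    # Remove every candidate-to-candidate edge...
    mask[candidate_start_offset:, candidate_start_offset:] = 0
    # ...then restore the diagonal so each candidate still attends to itself.
    idx = np.arange(candidate_start_offset, seq_len)
    mask[idx, idx] = 1
    # Add broadcast axes for batch and heads: [1, 1, T, T].
    return mask[None, None, :, :]

# Sequence: [user, h1, h2, c1, c2, c3]
mask = recsys_attn_mask_sketch(seq_len=6, candidate_start_offset=3)
print(mask[0, 0])
```

The same function should also cover the single-candidate and no-history edge cases below, since both are just different values of candidate_start_offset.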

    def test_dtype_preserved(self):
        """Test that the specified dtype is used."""
        seq_len = 5
        candidate_start_offset = 3

        mask_f32 = make_recsys_attn_mask(seq_len, candidate_start_offset, dtype=jnp.float32)
        mask_f16 = make_recsys_attn_mask(seq_len, candidate_start_offset, dtype=jnp.float16)

        assert mask_f32.dtype == jnp.float32
        assert mask_f16.dtype == jnp.float16

Dtype passthrough.

    def test_single_candidate(self):
        """Test edge case with a single candidate."""
        seq_len = 4
        candidate_start_offset = 3

        mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
        mask_2d = mask[0, 0]

        expected = np.array(
            [
                [1, 0, 0, 0],
                [1, 1, 0, 0],
                [1, 1, 1, 0],
                [1, 1, 1, 1],
            ],
            dtype=np.float32,
        )

        np.testing.assert_array_equal(np.array(mask_2d), expected)

Single candidate edge case: with only one candidate, the mask is just a regular causal triangle. The "no candidate-to-candidate" rule is vacuous; the candidate sees user + history plus itself.

    def test_all_candidates(self):
        """Test edge case where all positions except first are candidates."""
        seq_len = 4
        candidate_start_offset = 1

        mask = make_recsys_attn_mask(seq_len, candidate_start_offset)
        mask_2d = mask[0, 0]

        expected = np.array(
            [
                [1, 0, 0, 0],  # user
                [1, 1, 0, 0],  # c1: user + self
                [1, 0, 1, 0],  # c2: user + self
                [1, 0, 0, 1],  # c3: user + self
            ],
            dtype=np.float32,
        )

        np.testing.assert_array_equal(np.array(mask_2d), expected)

No history edge case: user only, then candidates. Each candidate sees just (user, self).

TestRightAnchoredRopePositions

class TestRightAnchoredRopePositions:
    """Tests for the right_anchored_rope_positions function."""

    def test_output_shape(self):
        """Test that the output has the correct shape [B, T]."""
        B, T = 2, 10
        padding_mask = jnp.ones((B, T), dtype=jnp.bool_)
        positions = right_anchored_rope_positions(
            padding_mask, history_seq_len=6, num_user_prefix_tokens=1
        )
        assert positions.shape == (B, T)

    def test_prefix_positions_preserved(self):
        """Test that prefix token positions are 0..num_prefix-1."""
        B, T = 1, 10
        padding_mask = jnp.ones((B, T), dtype=jnp.bool_)
        positions = right_anchored_rope_positions(
            padding_mask, history_seq_len=6, num_user_prefix_tokens=2
        )
        assert float(positions[0, 0]) == 0.0
        assert float(positions[0, 1]) == 1.0

Prefix tokens get positions 0..num_prefix-1. So the user token (if num_user_prefix_tokens=1) is always at position 0.

    def test_candidates_share_position(self):
        """Test that all candidate positions are the same (history_end)."""
        B = 1
        num_prefix = 1
        history_len = 4
        num_candidates = 3
        T = num_prefix + history_len + num_candidates

        padding_mask = jnp.ones((B, T), dtype=jnp.bool_)
        positions = right_anchored_rope_positions(
            padding_mask, history_seq_len=history_len, num_user_prefix_tokens=num_prefix
        )

        history_end = num_prefix + history_len
        for c in range(num_candidates):
            assert float(positions[0, history_end + c]) == float(history_end)

All candidates get position history_end. So they share one position — the model's candidate representations differ only via input content, not position. Consistent with the candidate-isolation logic.

    def test_padding_gets_zero(self):
        """Test that padded positions get position 0."""
        B, T = 1, 8
        padding_mask = jnp.array([[True, True, True, True, False, False, False, False]])
        positions = right_anchored_rope_positions(
            padding_mask, history_seq_len=4, num_user_prefix_tokens=1
        )
        for i in range(4, 8):
            assert float(positions[0, i]) == 0.0

Padded slots get position 0. The mask zeros their attention contribution anyway, but the position is well-defined.
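
Taken together, the four tests constrain the function quite tightly. The NumPy sketch below is one implementation that satisfies them; it is not the code from grok.py. It assumes that valid history tokens are packed at the left of the history block with padding on the right, and the name right_anchored_positions_sketch is hypothetical.

```python
import numpy as np

def right_anchored_positions_sketch(padding_mask, history_seq_len, num_user_prefix_tokens):
    """Hypothetical sketch consistent with the quoted tests (not the real function)."""
    padding_mask = np.asarray(padding_mask, dtype=bool)
    B, T = padding_mask.shape
    history_end = num_user_prefix_tokens + history_seq_len
    idx = np.arange(T)[None, :]  # [1, T]

    # Count valid history tokens per row; the shift moves the last valid
    # history token to position history_end - 1.
    hist = padding_mask[:, num_user_prefix_tokens:history_end]
    shift = history_seq_len - hist.sum(axis=1, keepdims=True)  # [B, 1]

    positions = np.broadcast_to(idx, (B, T)).astype(np.float32)
    is_hist = (idx >= num_user_prefix_tokens) & (idx < history_end)
    positions = np.where(is_hist, positions + shift, positions)
    # All candidates share the single position history_end.
    positions = np.where(idx >= history_end, history_end, positions)
    # Padded slots get position 0.
    return np.where(padding_mask, positions, 0.0)

full = right_anchored_positions_sketch(np.ones((1, 8), dtype=bool), 4, 1)
partial = right_anchored_positions_sketch(
    [[True, True, True, True, False, False, False, False]], 4, 1
)
print(full)
print(partial)
```

In the partially filled row, the three valid history tokens land on positions 2, 3, 4, so the newest one still sits at history_end - 1 even though one history slot is empty.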

TestComputePostAgeBucket

class TestComputePostAgeBucket:
    """Tests for the compute_post_age_bucket function."""

    def test_basic_bucketing(self):
        """Test basic bucketing with 60-minute granularity."""
        # Post that is 30 minutes old -> bucket 1 (0-59 minutes)
        impr_ts = jnp.array([[1000000]])
        post_ts = jnp.array([[1000000 - 30 * 60]])
        bucket = compute_post_age_bucket(impr_ts, post_ts, granularity_mins=60)
        assert int(bucket[0, 0]) == 1

30 minutes old → bucket 1 (i.e., 0-59 minutes, bucket 0 is reserved for missing).

    def test_two_hour_post(self):
        """Test a 2-hour-old post -> bucket 3 (120-179 minutes)."""
        impr_ts = jnp.array([[1000000]])
        post_ts = jnp.array([[1000000 - 120 * 60]])
        bucket = compute_post_age_bucket(impr_ts, post_ts, granularity_mins=60)
        assert int(bucket[0, 0]) == 3

2 hours = 120 minutes → 120 // 60 + 1 = 3. So bucket = (minutes // granularity) + 1. Bucket 1 covers 0-59 minutes, bucket 2 covers 60-119, bucket 3 covers 120-179.

    def test_missing_timestamp_zero(self):
        """Test that missing timestamps (0) map to bucket 0."""
        impr_ts = jnp.array([[0]])
        post_ts = jnp.array([[1000000]])
        bucket = compute_post_age_bucket(impr_ts, post_ts, granularity_mins=60)
        assert int(bucket[0, 0]) == 0
        # ... same for missing post_ts

Sentinel handling: ts of 0 → bucket 0.

    def test_negative_age_maps_to_zero(self):
        """Test that negative age (clock skew) maps to bucket 0."""
        impr_ts = jnp.array([[1000000]])
        post_ts = jnp.array([[1000000 + 60 * 60]])  # post created AFTER impression
        bucket = compute_post_age_bucket(impr_ts, post_ts, granularity_mins=60)
        assert int(bucket[0, 0]) == 0

Clock-skew handling: negative age → bucket 0. Same as missing.

    def test_overflow_bucket(self):
        """Test that very old posts go to the overflow bucket."""
        # POST_AGE_MAX_MINUTES = 4800 (80 hours)
        impr_ts = jnp.array([[1000000]])
        post_ts = jnp.array([[1000000 - 5000 * 60]])  # 5000 minutes old
        bucket = compute_post_age_bucket(impr_ts, post_ts, granularity_mins=60)
        # overflow bucket = 4800 // 60 + 1 = 81
        assert int(bucket[0, 0]) == 81

Posts past the 80-hour cap → bucket 81. Everything beyond the cap collapses into this single overflow bucket, so all "ancient" posts are treated as one category.
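
The five tests above pin down a simple rule. Here is a scalar, pure-Python sketch of it, assuming timestamps are in seconds (the tests multiply minutes by 60) and taking POST_AGE_MAX_MINUTES = 4800 from the test comment. The real function operates on arrays; post_age_bucket_sketch is a hypothetical name.

```python
POST_AGE_MAX_MINUTES = 4800  # 80 hours, the value quoted in the test comment

def post_age_bucket_sketch(impr_ts, post_ts, granularity_mins=60):
    """Hypothetical scalar version of the bucketing rule the tests describe."""
    # A timestamp of 0 is the "missing" sentinel -> bucket 0.
    if impr_ts == 0 or post_ts == 0:
        return 0
    age_secs = impr_ts - post_ts
    # Negative age (clock skew) is treated like missing -> bucket 0.
    if age_secs < 0:
        return 0
    # Cap the age so everything older shares one overflow bucket.
    age_mins = min(age_secs // 60, POST_AGE_MAX_MINUTES)
    # Bucket 0 is reserved, so real ages start at bucket 1.
    return age_mins // granularity_mins + 1

for minutes_old in (30, 120, 5000):
    print(minutes_old, post_age_bucket_sketch(1_000_000, 1_000_000 - minutes_old * 60))
```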

TestNormalizeContinuousValue

class TestNormalizeContinuousValue:
    """Tests for continuous value normalization."""

    def test_linear_normalization(self):
        """Test linear normalization: x / norm_scale."""
        config = NormConfig(norm_scale=30.0, use_log=False)
        values = jnp.array([0.0, 15.0, 30.0, 60.0])
        result = normalize_continuous_value(values, config)
        np.testing.assert_allclose(np.array(result), [0.0, 0.5, 1.0, 1.0], atol=1e-6)

Linear: x / 30. Note 60.0 → 1.0 (clamped).

    def test_log_normalization(self):
        """Test log normalization: log1p(x) / log1p(norm_scale)."""
        config = NormConfig(norm_scale=30.0, use_log=True)
        values = jnp.array([0.0, 30.0])
        result = normalize_continuous_value(values, config)
        assert float(result[0]) == pytest.approx(0.0, abs=1e-6)
        assert float(result[1]) == pytest.approx(1.0, abs=1e-6)

    def test_clamping(self):
        """Test that values are clamped to [0, norm_scale]."""
        config = NormConfig(norm_scale=10.0, use_log=False)
        values = jnp.array([-5.0, 0.0, 5.0, 15.0])
        result = normalize_continuous_value(values, config)
        np.testing.assert_allclose(np.array(result), [0.0, 0.0, 0.5, 1.0], atol=1e-6)

Standard edge-case tests. log1p for log mode (handles 0 cleanly: log1p(0) = log(1) = 0). Clamping at both ends.
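
The behavior these three tests describe fits in a few lines. The following is a scalar, pure-Python sketch of it, assuming the clamp to [0, norm_scale] happens before the optional log transform; normalize_sketch is a hypothetical name and takes the two config fields as plain arguments instead of a NormConfig.

```python
import math

def normalize_sketch(x, norm_scale, use_log):
    """Hypothetical scalar version of the normalization the tests describe."""
    # Clamp to [0, norm_scale] so the result always lies in [0, 1].
    clamped = min(max(x, 0.0), norm_scale)
    if use_log:
        # log1p maps 0 -> 0 and norm_scale -> log1p(norm_scale).
        return math.log1p(clamped) / math.log1p(norm_scale)
    return clamped / norm_scale

print([normalize_sketch(v, 30.0, use_log=False) for v in (0.0, 15.0, 30.0, 60.0)])
print([normalize_sketch(v, 30.0, use_log=True) for v in (0.0, 30.0)])
```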


test_recsys_retrieval_model.py (417 lines)

Tests for the retrieval model. Three test classes:

  1. TestCandidateTower — the MLP/mean-pool candidate encoder.
  2. TestPhoenixRetrievalModel — the full retrieval model.
  3. TestRetrievalInferenceRunner — the runner from Session 16.

TestCandidateTower

class TestCandidateTower(unittest.TestCase):
    """Tests for the CandidateTower module."""

    def test_candidate_tower_output_shape(self):
        """Test that candidate tower produces correct output shape."""
        emb_size = 64
        batch_size = 4
        num_candidates = 8
        num_hashes = 4

        def forward(x):
            tower = CandidateTower(emb_size=emb_size, enable_linear_proj=True)
            return tower(x)

        forward_fn = hk.without_apply_rng(hk.transform(forward))

        rng = jax.random.PRNGKey(0)
        x = jax.random.normal(rng, (batch_size, num_candidates, num_hashes, emb_size))

        params = forward_fn.init(rng, x)
        output = forward_fn.apply(params, x)

        self.assertEqual(output.shape, (batch_size, num_candidates, emb_size))

Shape test: [B, C, num_hashes, D] → [B, C, D]. The candidate tower flattens hashes.

    def test_candidate_tower_normalized(self):
        """Test that candidate tower output is L2 normalized."""
        # ...
        params = forward_fn.init(rng, x)
        output = forward_fn.apply(params, x)

        norms = jnp.sqrt(jnp.sum(output**2, axis=-1))
        np.testing.assert_array_almost_equal(norms, jnp.ones_like(norms), decimal=5)

Output is L2-normalized to unit vectors. Critical for cosine-similarity retrieval. The model code (Session 15) ensures this in both modes.

    def test_candidate_tower_mean_pooling(self):
        """Test candidate tower with mean pooling (no linear projection)."""
        # ...
        def forward(x):
            tower = CandidateTower(emb_size=emb_size, enable_linear_proj=False)
            return tower(x)

        # ...
        output = forward_fn.apply(params, x)

        self.assertEqual(output.shape, (batch_size, num_candidates, emb_size))

        norms = jnp.sqrt(jnp.sum(output**2, axis=-1))
        np.testing.assert_array_almost_equal(norms, jnp.ones_like(norms), decimal=5)

Mean-pooling mode also produces normalized outputs. Same shape, same norm guarantee.
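
To make the mean-pooling path concrete, here is a NumPy sketch of "average over the hash axis, then L2-normalize". It mirrors the shapes used in the tests but is not the CandidateTower code; mean_pool_tower_sketch and the eps guard are assumptions.

```python
import numpy as np

def mean_pool_tower_sketch(x, eps=1e-12):
    """Hypothetical mean-pool + L2-norm over input of shape [B, C, num_hashes, D]."""
    pooled = x.mean(axis=2)  # average the hash embeddings -> [B, C, D]
    norm = np.linalg.norm(pooled, axis=-1, keepdims=True)
    # eps guards against division by zero for an all-zero vector.
    return pooled / np.maximum(norm, eps)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8, 4, 64))  # batch=4, candidates=8, hashes=4, emb=64
out = mean_pool_tower_sketch(x)
print(out.shape)
print(np.linalg.norm(out, axis=-1).round(5))
```

Nothing in this function is learned, which is the property the next test checks on the real module.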

    def test_mean_pooling_has_no_params(self):
        """Test that mean pooling mode introduces no learned parameters."""
        # ...
        params = forward_fn.init(rng, x)
        # Mean pooling should have no parameters
        total_params = sum(p.size for p in jax.tree.leaves(params))
        self.assertEqual(total_params, 0)

Zero learned parameters in mean-pooling mode. jax.tree.leaves flattens the Haiku params tree into a list of arrays; p.size is each array's element count, and the sum is the total parameter count. Total = 0 confirms no hk.get_parameter calls fired.

This is the "more parameter-efficient but less expressive" mode referenced in Session 15. Useful when you want to use the post + author hash embeddings directly without a learned projection on top.

TestPhoenixRetrievalModel

class TestPhoenixRetrievalModel(unittest.TestCase):
    """Tests for the full Phoenix Retrieval Model."""

    def setUp(self):
        """Set up test fixtures."""
        self.emb_size = 64
        self.history_seq_len = 16
        self.candidate_seq_len = 8
        self.batch_size = 2
        self.num_actions = 19
        self.corpus_size = 100
        self.top_k = 10

        self.hash_config = HashConfig(
            num_user_hashes=2,
            num_item_hashes=2,
            num_author_hashes=2,
        )

        self.config = PhoenixRetrievalModelConfig(
            emb_size=self.emb_size,
            history_seq_len=self.history_seq_len,
            candidate_seq_len=self.candidate_seq_len,
            hash_config=self.hash_config,
            product_surface_vocab_size=16,
            enable_linear_proj=True,
            model=TransformerConfig(
                emb_size=self.emb_size,
                widening_factor=2,
                key_size=32,
                num_q_heads=2,
                num_kv_heads=2,
                num_layers=1,
                attn_output_multiplier=0.125,
            ),
        )

Tiny test config: emb_size=64, 16-item history, 8 candidates, single transformer layer, 2 heads. Same architecture as production, just shrunk.

    def test_model_forward(self):
        """Test model forward pass produces correct output shapes."""

        def forward(batch, embeddings, corpus_embeddings, top_k):
            model = self.config.make()
            return model(batch, embeddings, corpus_embeddings, top_k)

        forward_fn = hk.without_apply_rng(hk.transform(forward))

        batch, embeddings = self._create_test_batch()
        corpus_embeddings, _ = self._create_test_corpus()

        rng = jax.random.PRNGKey(0)
        params = forward_fn.init(rng, batch, embeddings, corpus_embeddings, self.top_k)
        output = forward_fn.apply(params, batch, embeddings, corpus_embeddings, self.top_k)

        self.assertEqual(output.user_representation.shape, (self.batch_size, self.emb_size))
        self.assertEqual(output.top_k_indices.shape, (self.batch_size, self.top_k))
        self.assertEqual(output.top_k_scores.shape, (self.batch_size, self.top_k))

Forward pass + output-shape checks. The full pipeline init runs end-to-end.

    def test_user_representation_normalized(self):
        """Test that user representations are L2 normalized."""
        # ...
        norms = jnp.sqrt(jnp.sum(output.user_representation**2, axis=-1))
        np.testing.assert_array_almost_equal(norms, jnp.ones(self.batch_size), decimal=5)

    def test_candidate_representation_normalized(self):
        """Test that candidate representations from build_candidate_representation are L2 normalized."""
        # ...
        norms = jnp.sqrt(jnp.sum(cand_rep**2, axis=-1))
        np.testing.assert_array_almost_equal(
            norms, jnp.ones((self.batch_size, self.candidate_seq_len)), decimal=5
        )

Both user and candidate representations are L2-normalized. Required for cosine-similarity retrieval.
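
The reason normalization matters is that the dot product of two unit vectors is exactly their cosine similarity, so retrieval can use a plain matrix multiply. A small NumPy check of that identity, with random vectors standing in for real representations:

```python
import numpy as np

def l2_normalize(v, eps=1e-12):
    # Scale to unit length; eps guards against division by zero.
    return v / np.maximum(np.linalg.norm(v, axis=-1, keepdims=True), eps)

rng = np.random.default_rng(0)
user = rng.normal(size=(64,))
cand = rng.normal(size=(64,))

# Cosine similarity computed the long way...
cosine = user @ cand / (np.linalg.norm(user) * np.linalg.norm(cand))
# ...equals a plain dot product once both vectors are unit length.
dot_of_units = l2_normalize(user) @ l2_normalize(cand)
# Rescaling a raw vector does not change the normalized score.
dot_rescaled = l2_normalize(10.0 * user) @ l2_normalize(cand)

print(np.isclose(cosine, dot_of_units), np.isclose(dot_of_units, dot_rescaled))
```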

    def test_retrieve_top_k(self):
        """Test top-k retrieval through __call__."""
        # ...
        self.assertEqual(output.top_k_indices.shape, (self.batch_size, self.top_k))
        self.assertEqual(output.top_k_scores.shape, (self.batch_size, self.top_k))

        self.assertTrue(jnp.all(output.top_k_indices >= 0))
        self.assertTrue(jnp.all(output.top_k_indices < self.corpus_size))

        for b in range(self.batch_size):
            scores = np.array(output.top_k_scores[b])
            self.assertTrue(np.all(scores[:-1] >= scores[1:]))

Three guarantees:

  1. Output shapes are [batch_size, top_k] for both indices and scores.
  2. Indices are valid ([0, corpus_size)).
  3. Scores are sorted descending (top-K). scores[:-1] >= scores[1:]: each consecutive pair is non-increasing.

    def test_mean_pooling_model_forward(self):
        """Test model forward pass with mean pooling candidate tower."""
        config = PhoenixRetrievalModelConfig(
            # ...
            enable_linear_proj=False,  # Mean pooling
            # ...
        )

        # ...
        output = forward_fn.apply(params, batch, embeddings, corpus_embeddings, self.top_k)

        self.assertEqual(output.user_representation.shape, (self.batch_size, self.emb_size))
        self.assertEqual(output.top_k_indices.shape, (self.batch_size, self.top_k))

Confirm the mean-pooling mode also produces a runnable model.

TestRetrievalInferenceRunner

class TestRetrievalInferenceRunner(unittest.TestCase):
    """Tests for the retrieval inference runner."""

    def setUp(self):
        """Set up test fixtures."""
        # ...

    def test_runner_initialization(self):
        """Test that runner initializes correctly."""
        runner = RecsysRetrievalInferenceRunner(
            runner=RetrievalModelRunner(
                model=self.config,
                bs_per_device=0.125,
            ),
            name="test_retrieval",
        )

        runner.initialize()

        self.assertIsNotNone(runner.params)

    def test_runner_encode_user(self):
        """Test user encoding through runner."""
        runner = RecsysRetrievalInferenceRunner(
            runner=RetrievalModelRunner(
                model=self.config,
                bs_per_device=0.125,
            ),
            name="test_retrieval",
        )
        runner.initialize()

        batch, embeddings = create_example_batch(...)

        user_rep = runner.encode_user(batch, embeddings)

        self.assertEqual(user_rep.shape, (self.batch_size, self.emb_size))

Test the inference runner from Session 16. Confirms:

  • Runner initializes (params are not None).
  • encode_user returns the right shape.
  • retrieve (next test) returns correctly shaped outputs.

    def test_runner_retrieve(self):
        """Test retrieval through runner."""
        # ...
        runner.initialize()

        batch, embeddings = create_example_batch(...)

        corpus_size = 100
        corpus_embeddings, corpus_post_ids = create_example_corpus(corpus_size, self.emb_size)
        runner.set_corpus(corpus_embeddings, corpus_post_ids)

        top_k = 10
        output = runner.retrieve(batch, embeddings, top_k=top_k)

        self.assertEqual(output.user_representation.shape, (self.batch_size, self.emb_size))
        self.assertEqual(output.top_k_indices.shape, (self.batch_size, top_k))
        self.assertEqual(output.top_k_scores.shape, (self.batch_size, top_k))

End-to-end runner test: init → set_corpus → retrieve. Pins down the inference contract.


What we've learned

The Grok transformer is mostly a standard modern LLM transformer:

  • RMSNorm instead of LayerNorm.
  • RoPE for position encoding.
  • GeGLU FFN with the * 2 // 3 hidden-size scaling commonly used for SwiGLU.
  • GQA (num_q_heads >= num_kv_heads).
  • Tanh-clamping on attention logits for stability.
  • Double layer norm per layer.
  • Causal attention by default.

Two recsys-specific modifications:

  1. make_recsys_attn_mask — the candidate isolation mask. Critical for score consistency.
  2. right_anchored_rope_positions — anchors the most-recent history token at a fixed position. Lets the model treat position as recency.

The recsys attention mask pattern (memorize this):

  • Causal within user + history.
  • Candidates see all of user + history.
  • Candidates see themselves but no other candidates.

Why candidate isolation matters:

  • Score for candidate X doesn't depend on which other candidates are in the batch.
  • Reproducible across requests (a candidate scored alone vs. with others gets the same score).
  • Cacheable — the model output for a (user, candidate) pair is invariant to other candidates.
  • Required for the partial-scoring caches (Sessions 12-13).
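
The invariance can be demonstrated on a toy single-head attention layer. This NumPy sketch uses identity projections and random embeddings, so it only illustrates the masking effect and is not the model's attention code; all names are hypothetical.

```python
import numpy as np

def recsys_mask(seq_len, start):
    # Causal over the whole sequence, identity inside the candidate block.
    mask = np.tril(np.ones((seq_len, seq_len)))
    mask[start:, start:] = np.eye(seq_len - start)
    return mask

def masked_attention(x, mask):
    # Toy single-head attention with identity Q/K/V projections.
    logits = x @ x.T / np.sqrt(x.shape[-1])
    logits = np.where(mask > 0, logits, -np.inf)
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x

rng = np.random.default_rng(0)
context = rng.normal(size=(3, 8))   # user + 2 history tokens
cand = rng.normal(size=(1, 8))      # the candidate we score
others_1 = rng.normal(size=(2, 8))  # two different sets of
others_2 = rng.normal(size=(2, 8))  # neighbouring candidates

seq_1 = np.concatenate([context, others_1, cand])
seq_2 = np.concatenate([context, others_2, cand])

iso = recsys_mask(6, 3)
causal = np.tril(np.ones((6, 6)))

iso_1, iso_2 = masked_attention(seq_1, iso)[5], masked_attention(seq_2, iso)[5]
causal_1, causal_2 = masked_attention(seq_1, causal)[5], masked_attention(seq_2, causal)[5]

print(np.allclose(iso_1, iso_2))        # candidate output ignores its neighbours
print(np.allclose(causal_1, causal_2))  # plain causal mask leaks the neighbours
```

The same argument should carry over to a multi-layer stack: user and history rows are causal and never see candidates, so every layer presents a candidate with identical context regardless of which other candidates are in the sequence.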

The RoPE customization fixes the "variable-length history" problem:

  • Standard: position = sequence index, so the position of the most recent history item depends on how full the history is.
  • Right-anchored: most-recent history is always near history_end - 1, regardless of how full the history is.
  • The model can learn that a fixed distance from history_end means a fixed recency (e.g. "50 positions before history_end" = "about 50 actions ago"), consistently across users.

Candidate positions are all the same: all candidates get position history_end. They differ only in content, not position.

fp32 params + bf16 forward: parameters are stored as float32, cast to bfloat16 on forward pass. Softmax always done in float32. Standard mixed-precision setup.

Zero parameter init: hk.initializers.Constant(0) everywhere. Because this code is for inference — parameters get overwritten by load_model_params. For training, swap to VarianceScaling.

The double layer norm per layer: pre-norm AND post-norm. Unusual. Inherited from Grok-1.

Tanh-clamped attention logits: max_attn_val * tanh(logits / max_attn_val). Prevents softmax overflow in bf16. The clamp value is 30.0. Standard stability hack in modern LLMs.
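
A quick numeric check of that formula shows why it is a soft cap and not a hard clip. This is a scalar, pure-Python sketch using the 30.0 value quoted above; soft_cap is a hypothetical name.

```python
import math

def soft_cap(logit, max_attn_val=30.0):
    # Approximately the identity for small logits, saturating at +/- max_attn_val.
    return max_attn_val * math.tanh(logit / max_attn_val)

for logit in (0.5, 10.0, 100.0, -1000.0):
    print(logit, soft_cap(logit))
```

Small logits pass through almost unchanged, while arbitrarily large ones are squashed to at most 30 in magnitude, so the exponentials inside the softmax stay bounded.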

GeGLU FFN with iso-parameter scaling: ffn_size = (widening_factor * emb_size) * 2 / 3, rounded to multiple of 8. The * 2 / 3 factor keeps the 3-projection GeGLU at the same param count as a vanilla 2-projection FFN.
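
The iso-parameter claim can be checked with a little arithmetic. The sketch below assumes the rounding goes up to the next multiple of 8 and ignores biases; the function names and the example sizes are illustrative, not taken from the production config.

```python
def ffn_size_sketch(emb_size, widening_factor):
    """Hypothetical hidden-size rule: 2/3 of the widened size, rounded up to a multiple of 8."""
    size = int(widening_factor * emb_size) * 2 // 3
    size += (8 - size) % 8  # assumption: round up, not down
    return size

def vanilla_ffn_params(emb_size, widening_factor):
    # Two projections: emb -> hidden and hidden -> emb.
    return 2 * emb_size * int(widening_factor * emb_size)

def geglu_ffn_params(emb_size, widening_factor):
    # Three projections: gate, value, and output.
    return 3 * emb_size * ffn_size_sketch(emb_size, widening_factor)

emb_size, widening_factor = 1024, 4  # illustrative sizes
hidden = ffn_size_sketch(emb_size, widening_factor)
ratio = geglu_ffn_params(emb_size, widening_factor) / vanilla_ffn_params(emb_size, widening_factor)
print(hidden, ratio)
```

With three projections of width 2/3 · w · d instead of two of width w · d, the weight count is 3 · d · (2/3 · w · d) = 2 · w · d², matching the vanilla FFN up to the rounding.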

Test-driven specs: the test files act as the spec for make_recsys_attn_mask and right_anchored_rope_positions. If you were re-implementing in another framework, these tests are what you'd target. The expected-matrix test in test_full_mask_structure is the most useful — it shows the exact desired output.

Two model modes for CandidateTower:

  • enable_linear_proj=True (default): 2-layer MLP with SiLU + L2-norm.
  • enable_linear_proj=False: mean-pool + L2-norm. Zero learned parameters — the test confirms this. Useful when you want minimal model size or when the hash embeddings are already expressive enough.

Sorted top-K invariant: the retrieval test confirms scores[:-1] >= scores[1:] — output indices are returned in descending score order. The caller can trust the ordering.
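
The invariants the test checks (shape, index range, descending order) can be reproduced with a brute-force NumPy top-k over dot-product scores. The quoted tests do not show how the model computes top-k internally, so this is only a sketch of the contract; top_k_sketch is a hypothetical name.

```python
import numpy as np

def top_k_sketch(user_rep, corpus, k):
    """Hypothetical brute-force top-k: user_rep is [B, D], corpus is [N, D]."""
    scores = user_rep @ corpus.T                 # [B, N] dot-product scores
    order = np.argsort(-scores, axis=-1)[:, :k]  # corpus indices, best first
    top_scores = np.take_along_axis(scores, order, axis=-1)
    return order, top_scores

rng = np.random.default_rng(0)
user_rep = rng.normal(size=(2, 64))
corpus = rng.normal(size=(100, 64))

indices, scores = top_k_sketch(user_rep, corpus, k=10)
print(indices.shape, scores.shape)
print(bool(np.all(scores[:, :-1] >= scores[:, 1:])))
```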

Runner-level tests: TestRetrievalInferenceRunner validates the full init → set_corpus → retrieve pipeline. The tests are basically usage examples — copy them to start using the runner.


Phoenix tour complete

That's the whole Phoenix Python codebase. We've now read:

  • Session 15: recsys_model.py + recsys_retrieval_model.py — model architecture.
  • Session 16: runners.py + run_pipeline.py + run_ranker.py + run_retrieval.py — inference machinery + end-to-end demo.
  • Session 17 (this): grok.py + tests — the transformer + behavior specs.

Series progress:

Sessions   Module               LOC
01         candidate-pipeline   1,031
02-03      thunder              1,808
04-14      home-mixer           ~11,700
15-17      phoenix              3,880
Total                           ~18,419

5 sessions left, all in Grox (the LLM-driven content classification pipeline).

Next session

Session 18 — Grox core (~900 LOC). The orchestration layer for the LLM-driven content pipeline:

  • grox/__init__.py, dispatcher.py, engine.py, main.py — the runtime
  • grox/schedules/* — schedule + context management
  • grox/generators/* — task and stream generators
  • grox/lib/* — utilities

This is the start of the content-understanding pipeline that produces the safety labels and topic classifications we saw consumed throughout the home-mixer Rust code.