X For You algorithm, line by line · Part 16

X For You algorithm, line by line — Part 16: Phoenix runners + end-to-end pipeline

Part 16 of the deep dive into xai-org/x-algorithm. The Python runner infrastructure: ModelRunner / RetrievalModelRunner with Haiku transform setup, checkpoint loading from .npz, the unified embedding table layout, three apply functions for retrieval, and run_pipeline.py — the headline release addition that runs retrieval → ranking from exported checkpoints.

May 15, 2026 · 30 min read

Session 15 covered the model architecture. This session covers the runner code that loads checkpoints, builds inference functions, encodes inputs, and drives the model end-to-end. It also covers run_pipeline.py, new in this release, which ties retrieval and ranking together as a runnable example.

Files covered (~1,470 LOC):

phoenix/
├── runners.py          (807)  ModelRunner, RecsysInferenceRunner, RetrievalModelRunner, etc.
├── run_pipeline.py     (393)  end-to-end retrieval → ranking from exported checkpoints
├── run_ranker.py       (121)  standalone ranker demo
└── run_retrieval.py    (149)  standalone retrieval demo

run_pipeline.py is the headline change in this release per the README:

A new phoenix/run_pipeline.py replaces the separate run_ranker.py and run_retrieval.py scripts with a single entry point that runs retrieval → ranking from exported checkpoints, mirroring how the two stages are composed in production.


runners.py (807 lines) — the inference machinery

The big infrastructure file. Splits into five chunks:

  1. Checkpoint loading + dummy data helpers.
  2. Action / continuous-action constants.
  3. ModelRunner + RecsysInferenceRunner for ranking.
  4. RetrievalModelRunner + RecsysRetrievalInferenceRunner for retrieval.
  5. Example-data factories for testing.

Imports

import functools
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, NamedTuple, Optional, Tuple

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

from grok import TrainingState
from recsys_retrieval_model import PhoenixRetrievalModelConfig
from recsys_retrieval_model import RetrievalOutput as ModelRetrievalOutput

from recsys_model import (
    PhoenixModelConfig,
    RecsysBatch,
    RecsysEmbeddings,
    RecsysModelOutput,
)

rank_logger = logging.getLogger("rank")

TrainingState from grok — the standard JAX/Haiku pattern: a (params, …) container. Here only params is used since we're doing inference, not training.

The runner imports each model's config and I/O types from the model files (Session 15). RetrievalOutput is imported as ModelRetrievalOutput because this file defines its own RetrievalOutput named tuple further down (the inference-runner wrapper).

Checkpoint loading

def load_model_params(checkpoint_path: str) -> hk.Params:
    """Load model parameters from an exported checkpoint.

    Args:
        checkpoint_path: Path to model_params.npz file

    Returns:
        Haiku params dict (nested FrozenDict)
    """
    data = np.load(checkpoint_path, allow_pickle=True)
    params: dict = {}
    for key in data.files:
        parts = key.split("/")
        module_path = "/".join(parts[:-1])
        param_name = parts[-1]
        params.setdefault(module_path, {})[param_name] = jnp.array(data[key])
    return hk.data_structures.to_haiku_dict(params)

The checkpoint format: an .npz (numpy zipped) file where each key is a slash-separated path like phoenix_model/transformer/layer_0/attention/query/w. The last segment is the param name; everything before is the module path.

The loop rebuilds the two-level dict structure that Haiku expects, keyed first by module path and then by parameter name:

{
    "phoenix_model/transformer/layer_0/attention/query": {"w": <array>, "b": <array>}
}

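hk.data_structures.to_haiku_dict copies that dict into the mapping type Haiku's own init and apply functions return. In current Haiku releases that is an ordinary two-level dict rather than an immutable FrozenDict, so the docstring's wording is a little dated. The output goes into TrainingState(params=...).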
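A minimal round trip, assuming only numpy: the hypothetical save_flat helper writes a two-level dict into an .npz with slash-joined keys (the layout described above), and load_two_level mirrors the splitting loop in load_model_params but stops at a plain dict instead of calling Haiku. Both helper names are illustrative, not part of the repo.

```python
import os
import tempfile

import numpy as np


def save_flat(path, params):
    """Hypothetical exporter: flatten {module_path: {name: array}} into slash-joined .npz keys."""
    flat = {
        f"{module}/{name}": value
        for module, group in params.items()
        for name, value in group.items()
    }
    np.savez(path, **flat)


def load_two_level(path):
    """Same splitting logic as load_model_params, minus the Haiku conversion."""
    params = {}
    with np.load(path) as data:
        for key in data.files:
            # Everything before the last "/" is the module path; the last segment is the param name.
            module_path, _, param_name = key.rpartition("/")
            params.setdefault(module_path, {})[param_name] = data[key]
    return params


params = {
    "phoenix_model/transformer/layer_0/attention/query": {
        "w": np.ones((4, 8), dtype=np.float32),
        "b": np.zeros((8,), dtype=np.float32),
    }
}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model_params.npz")
    save_flat(path, params)
    restored = load_two_level(path)
```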

def load_embedding_table(path: str) -> np.ndarray:
    """Load an embedding table from an exported checkpoint.

    Args:
        path: Path to embedding_tables.npz file

    Returns:
        Dict with 'user_embeddings', 'item_embeddings', 'author_embeddings' arrays
    """
    return dict(np.load(path))

Embedding tables are separate from model params. Why? Because:

  • Embedding tables are huge (millions of rows × emb_size).
  • In production they sit on parameter servers, not on the inference accelerator.
  • For the open-source release, they're shipped as separate .npz files so you can use them without loading the full distributed system.

The return type annotation says np.ndarray but the body returns a dict — minor type hint bug.
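A corrected signature would annotate the dict the body actually returns. A minimal sketch (the _fixed suffix is hypothetical; reading the arrays inside a with block is a small addition so the .npz file handle gets closed):

```python
import os
import tempfile
from typing import Dict

import numpy as np


def load_embedding_table_fixed(path: str) -> Dict[str, np.ndarray]:
    """Load an embedding_tables.npz file into a {name: array} dict."""
    with np.load(path) as data:
        # Index every key inside the block so the arrays are read before the file closes.
        return {key: data[key] for key in data.files}


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "embedding_tables.npz")
    np.savez(
        path,
        user_embeddings=np.zeros((10, 4), dtype=np.float32),
        item_embeddings=np.zeros((20, 4), dtype=np.float32),
        author_embeddings=np.zeros((5, 4), dtype=np.float32),
    )
    tables = load_embedding_table_fixed(path)
```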

Dummy data factories

def create_dummy_batch_from_config(
    hash_config: Any,
    history_len: int,
    num_candidates: int,
    num_actions: int,
    batch_size: int = 1,
) -> RecsysBatch:
    """Create a dummy batch for initialization."""
    return RecsysBatch(
        user_hashes=np.zeros((batch_size, hash_config.num_user_hashes), dtype=np.int32),
        history_post_hashes=np.zeros(
            (batch_size, history_len, hash_config.num_item_hashes), dtype=np.int32
        ),
        # ...
        candidate_product_surface=np.zeros((batch_size, num_candidates), dtype=np.int32),
    )

All-zeros batch. Used only for shape inference when initializing the Haiku transform. Haiku needs a sample input to figure out parameter shapes; the values don't matter.

def create_dummy_embeddings_from_config(
    hash_config: Any,
    emb_size: int,
    history_len: int,
    num_candidates: int,
    batch_size: int = 1,
) -> RecsysEmbeddings:
    """Create dummy embeddings for initialization."""
    return RecsysEmbeddings(
        user_embeddings=np.zeros(...),
        # ... 
    )

Same idea for embeddings: zeros of the right shape. Init pass produces the parameter tree; the values are then overwritten by load_model_params.

BaseModelRunner — abstract base

@dataclass
class BaseModelRunner(ABC):
    """Base class for model runners with shared initialization logic."""

    bs_per_device: float = 2.0
    rng_seed: int = 42

    @property
    @abstractmethod
    def model(self) -> Any:
        """Return the model config."""
        pass

    @property
    def _model_name(self) -> str:
        """Return model name for logging."""
        return "model"

    @abstractmethod
    def make_forward_fn(self):
        """Create the forward function. Must be implemented by subclasses."""
        pass

    def initialize(self):
        """Initialize the model runner."""
        self.model.initialize()
        self.model.fprop_dtype = jnp.bfloat16
        num_local_gpus = len(jax.local_devices())

        self.batch_size = max(1, int(self.bs_per_device * num_local_gpus))

        rank_logger.info(f"Initializing {self._model_name}...")
        self.forward = self.make_forward_fn()

The abstract runner. Two abstract methods (model property, make_forward_fn) and a concrete initialize.

bs_per_device: float = 2.0 allows fractional per-device batch sizes: bs_per_device=0.5 with 4 devices gives batch_size=2. The max(1, ...) guarantees a batch size of at least 1.
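The arithmetic is easy to check in isolation. A standalone sketch (global_batch_size is a hypothetical name; its body is the same expression initialize uses):

```python
def global_batch_size(bs_per_device: float, num_devices: int) -> int:
    """Mirror of the batch-size line in BaseModelRunner.initialize."""
    # int() truncates toward zero; max(1, ...) keeps at least one example per step.
    return max(1, int(bs_per_device * num_devices))


examples = {
    (2.0, 8): global_batch_size(2.0, 8),    # 16
    (0.5, 4): global_batch_size(0.5, 4),    # 2
    (0.5, 1): global_batch_size(0.5, 1),    # int(0.5) == 0, clamped to 1
    (0.75, 3): global_batch_size(0.75, 3),  # int(2.25) == 2
}
```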

num_local_gpus = len(jax.local_devices()) counts every accelerator this process can address (GPUs or TPU devices), despite the variable name. On a TPU pod it returns only the devices attached to the local host, not the whole pod.

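Note that initialize sets fprop_dtype to bfloat16 for the forward pass. That's a standard inference choice: half the memory of float32 with the same 8-bit exponent, so the same dynamic range, at the cost of mantissa precision.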
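bfloat16 is exactly the top 16 bits of a float32 (sign, 8 exponent bits, 7 mantissa bits), which is why the range survives while precision drops. A numpy sketch that emulates the conversion by truncation; truncate_to_bfloat16 is a hypothetical helper, and real conversions usually round to nearest even rather than truncate:

```python
import numpy as np


def truncate_to_bfloat16(x) -> np.ndarray:
    """Emulate float32 -> bfloat16 by zeroing the low 16 bits of each float32.

    Keeps the sign bit, all 8 exponent bits and the top 7 mantissa bits.
    Truncation (round toward zero) is used for simplicity.
    """
    bits = np.atleast_1d(np.asarray(x, dtype=np.float32)).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)


values = np.array([1.0001, 3.0e5, 1.0e38], dtype=np.float32)
bf16 = truncate_to_bfloat16(values)
with np.errstate(over="ignore"):
    fp16 = values.astype(np.float16)  # float16 tops out at 65504, so large values become inf
```

The first value collapses to 1.0 (precision lost), while 3e5 and 1e38 stay finite in the bfloat16 emulation but overflow in float16.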

BaseInferenceRunner — abstract inference wrapper

@dataclass
class BaseInferenceRunner(ABC):
    """Base class for inference runners with shared dummy data creation."""

    name: str

    @property
    @abstractmethod
    def runner(self) -> BaseModelRunner:
        """Return the underlying model runner."""
        pass

    def _get_num_actions(self) -> int:
        """Get number of actions. Override in subclasses if needed."""
        model_config = self.runner.model
        if hasattr(model_config, "num_actions"):
            return model_config.num_actions
        return 19

    def create_dummy_batch(self, batch_size: int = 1) -> RecsysBatch:
        ...  # body elided
    def create_dummy_embeddings(self, batch_size: int = 1) -> RecsysEmbeddings:
        ...  # body elided

Helper layer on top of BaseModelRunner. Encapsulates dummy data creation so subclasses don't repeat it.

_get_num_actions defaults to 19 (the count in ACTIONS below) — the retrieval model doesn't have a num_actions attribute, so it falls back to the constant.

The action list

ACTIONS: List[str] = [
    "favorite_score",
    "reply_score",
    "repost_score",
    "photo_expand_score",
    "click_score",
    "profile_click_score",
    "vqv_score",
    "share_score",
    "share_via_dm_score",
    "share_via_copy_link_score",
    "dwell_score",
    "quote_score",
    "quoted_click_score",
    "follow_author_score",
    "not_interested_score",
    "block_author_score",
    "mute_author_score",
    "report_score",
    "dwell_time",
]

CONTINUOUS_ACTIONS: List[str] = [
    "reserved",
    "dwell_time",
    "video_watch_time",
    "scroll_depth",
    "reserved_3",
    "reserved_4",
    "reserved_5",
    "reserved_6",
]

NEGATIVE_FEEDBACK_INDICES: List[int] = [
    14,
    15,
    16,
    17,
]

The action vocabulary — 19 discrete actions matching the Phoenix model's output dimensions.

Compare to the PhoenixScores fields we saw in Sessions 06 and 10: this list is the same set, in order. Index 0 = favorite, 1 = reply, ..., 18 = dwell_time. The Rust side (PhoenixScorer etc.) reads scores by these indices.

CONTINUOUS_ACTIONS (8 entries) is what the continuous head predicts. The head produces 8 outputs, but only 3 are currently meaningful: dwell_time (index 1), video_watch_time (2) and scroll_depth (3). Index 0 ("reserved") and indices 4 to 7 (reserved_3 to reserved_6) are placeholders for future features.

NEGATIVE_FEEDBACK_INDICES = [14, 15, 16, 17] selects not_interested, block_author, mute_author and report: the negative-feedback signals. These are the ones with negative weights in the ranking formula (Session 10).
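Because everything is positional, name-to-index lookups are worth deriving rather than memorizing. A small sketch that copies the two lists from above and resolves them; ACTION_INDEX is a hypothetical helper, not something runners.py defines:

```python
ACTIONS = [
    "favorite_score", "reply_score", "repost_score", "photo_expand_score",
    "click_score", "profile_click_score", "vqv_score", "share_score",
    "share_via_dm_score", "share_via_copy_link_score", "dwell_score",
    "quote_score", "quoted_click_score", "follow_author_score",
    "not_interested_score", "block_author_score", "mute_author_score",
    "report_score", "dwell_time",
]
NEGATIVE_FEEDBACK_INDICES = [14, 15, 16, 17]

# Name -> output column, e.g. for reading probs[:, :, ACTION_INDEX["reply_score"]].
ACTION_INDEX = {name: i for i, name in enumerate(ACTIONS)}
negative_names = [ACTIONS[i] for i in NEGATIVE_FEEDBACK_INDICES]
```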

RankingOutput

class RankingOutput(NamedTuple):
    """Output from ranking candidates.

    Contains both the raw scores array and individual probability fields
    for each engagement type.
    """

    scores: jax.Array

    ranked_indices: jax.Array

    p_favorite_score: jax.Array
    p_reply_score: jax.Array
    p_repost_score: jax.Array
    # ... 19 fields total
    p_dwell_time: jax.Array
    continuous_preds: Optional[jax.Array] = None

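A NamedTuple carrying both the full probability array scores, shaped [B, C, num_actions], and 19 per-action fields named p_ plus the ACTIONS entry (p_favorite_score, ..., p_dwell_time), each shaped [B, C]. It also holds ranked_indices, shaped [B, C] (the candidate order by score), and the optional continuous_preds.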

Why expose both forms? Convenience: Python callers can write output.p_favorite_score instead of slicing into the array. It also makes the named-tuple structure mirror the proto wire format on the Rust side.

ModelRunner — ranking specialization

@dataclass
class ModelRunner(BaseModelRunner):
    """Runner for the recommendation ranking model."""

    _model: PhoenixModelConfig = None  # type: ignore

    def __init__(self, model: PhoenixModelConfig, bs_per_device: float = 2.0, rng_seed: int = 42):
        self._model = model
        self.bs_per_device = bs_per_device
        self.rng_seed = rng_seed

    @property
    def model(self) -> PhoenixModelConfig:
        return self._model

    @property
    def _model_name(self) -> str:
        return "ranking model"

    def make_forward_fn(self):  # type: ignore
        def forward(batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings):
            out = self.model.make()(batch, recsys_embeddings)
            return out

        return hk.transform(forward)

The concrete ranking runner. hk.transform(forward) is Haiku's way of turning a function that uses hk.Modules into an (init, apply) pair. After this:

  • forward.init(rng, batch, emb) returns the parameters.
  • forward.apply(params, rng, batch, emb) runs the model.

make_forward_fn is the abstract method from BaseModelRunner — different for each model type.

    def init(
        self, rng: jax.Array, data: RecsysBatch, embeddings: RecsysEmbeddings
    ) -> TrainingState:
        assert self.forward is not None
        rng, init_rng = jax.random.split(rng)
        params = self.forward.init(init_rng, data, embeddings)
        return TrainingState(params=params)

    def load_or_init(
        self,
        init_data: RecsysBatch,
        init_embeddings: RecsysEmbeddings,
        checkpoint_path: Optional[str] = None,
    ):
        if checkpoint_path is not None:
            params = load_model_params(checkpoint_path)
            return TrainingState(params=params)
        rng = jax.random.PRNGKey(self.rng_seed)
        state = self.init(rng, init_data, init_embeddings)
        return state

Two paths:

  • No checkpoint: random init via hk.transform.init.
  • With checkpoint: load params from disk.

In production: always load. In demos / test: random init.

RecsysInferenceRunner — ranking wrapper

@dataclass
class RecsysInferenceRunner(BaseInferenceRunner):
    """Inference runner for the recommendation ranking model."""

    _runner: ModelRunner

    def __init__(self, runner: ModelRunner, name: str):
        self.name = name
        self._runner = runner

    @property
    def runner(self) -> ModelRunner:
        return self._runner

    def initialize(self):
        """Initialize the inference runner."""
        runner = self.runner

        dummy_batch = self.create_dummy_batch(batch_size=1)
        dummy_embeddings = self.create_dummy_embeddings(batch_size=1)

        runner.initialize()

        state = runner.load_or_init(dummy_batch, dummy_embeddings)
        self.params = state.params

        @functools.lru_cache
        def model():
            return runner.model.make()

        def hk_forward(
            batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
        ) -> RecsysModelOutput:
            return model()(batch, recsys_embeddings)

The actual user-facing runner. initialize does:

  1. Build dummy batch + embeddings.
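  2. Call runner.initialize(), the concrete method inherited from BaseModelRunner (it sets the batch size and builds the transformed forward function).
  3. Call runner.load_or_init(...) to get params. No checkpoint_path is passed here, so as written this takes the random-init path.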
  4. Cache the model instance via functools.lru_cache.

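The @functools.lru_cache is interesting: model() is a zero-argument factory closed over runner, and the cache means every call after the first returns the same object built by runner.model.make(). Without it, calling model() twice inside one transformed function would construct a second module, which Haiku would name with a _1 suffix and therefore point at different parameter paths.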
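The caching behaviour itself is plain Python. A sketch with a stand-in factory (FakeModel is hypothetical; in the runner the factory body is runner.model.make()). The bare @functools.lru_cache form, without parentheses, needs Python 3.8+:

```python
import functools


class FakeModel:
    """Stand-in for the object runner.model.make() returns."""

    instances = 0

    def __init__(self):
        FakeModel.instances += 1


@functools.lru_cache
def model():
    # Zero-argument factory: the first call builds the object, later calls reuse it.
    return FakeModel()


first = model()
second = model()
```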

        def hk_rank_candidates(
            batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
        ) -> RankingOutput:
            """Rank candidates by their predicted engagement scores."""
            output = hk_forward(batch, recsys_embeddings)
            logits = output.logits

            probs = jax.nn.sigmoid(logits)

            primary_scores = probs[:, :, 0]

            ranked_indices = jnp.argsort(-primary_scores, axis=-1)

            return RankingOutput(
                scores=probs,
                ranked_indices=ranked_indices,
                p_favorite_score=probs[:, :, 0],
                p_reply_score=probs[:, :, 1],
                p_repost_score=probs[:, :, 2],
                p_photo_expand_score=probs[:, :, 3],
                p_click_score=probs[:, :, 4],
                p_profile_click_score=probs[:, :, 5],
                p_vqv_score=probs[:, :, 6],
                p_share_score=probs[:, :, 7],
                p_share_via_dm_score=probs[:, :, 8],
                p_share_via_copy_link_score=probs[:, :, 9],
                p_dwell_score=probs[:, :, 10],
                p_quote_score=probs[:, :, 11],
                p_quoted_click_score=probs[:, :, 12],
                p_follow_author_score=probs[:, :, 13],
                p_not_interested_score=probs[:, :, 14],
                p_block_author_score=probs[:, :, 15],
                p_mute_author_score=probs[:, :, 16],
                p_report_score=probs[:, :, 17],
                p_dwell_time=probs[:, :, 18],
                continuous_preds=output.continuous_preds,
            )

        rank_ = hk.without_apply_rng(hk.transform(hk_rank_candidates))
        self.rank_candidates = rank_.apply

Build the rank function:

  1. Forward pass → logits.
  2. Sigmoid to get per-action probabilities.
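  3. Sort candidates by index 0 (favorite_score) to get ranked_indices. jnp.argsort sorts ascending, so argsort(-x) gives descending order: highest favorite_score first. Note that the demo ranks on P(favorite) alone, not on the weighted multi-action score from Session 10.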
  4. Slice the probs by action index, populate the named-tuple fields.
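A numpy sketch of the sorting step on a toy [B=1, C=4, num_actions=2] probability array; numpy's argsort stands in for jnp.argsort, which has the same ascending semantics, and the probabilities are made up:

```python
import numpy as np

# probs[b, c, a]: probability of action a for candidate c; index 0 = favorite_score.
probs = np.array([[[0.10, 0.5],
                   [0.90, 0.2],
                   [0.40, 0.7],
                   [0.65, 0.1]]])

primary_scores = probs[:, :, 0]                        # [B, C]
ranked_indices = np.argsort(-primary_scores, axis=-1)  # descending by favorite probability
top_candidate = ranked_indices[0, 0]
```

Candidate 2 has the highest reply probability but still ranks third, because only column 0 drives the order.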

hk.without_apply_rng(hk.transform(...)) wraps the transformed pair so that apply no longer takes an RNG argument (init still does). For deterministic inference (no dropout, etc.), that drops the RNG ceremony.

self.rank_candidates = rank_.apply — stash the apply function as an attribute. Will be called as self.rank_candidates(params, batch, embeddings).
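The 19 hand-written slices could also be generated from ACTIONS, which ties the field order to the action list. A sketch using collections.namedtuple as a stand-in for RankingOutput; RankingOutputSketch and build_output are hypothetical, and this is an alternative construction, not what runners.py does:

```python
from collections import namedtuple

import numpy as np

ACTIONS = [
    "favorite_score", "reply_score", "repost_score", "photo_expand_score",
    "click_score", "profile_click_score", "vqv_score", "share_score",
    "share_via_dm_score", "share_via_copy_link_score", "dwell_score",
    "quote_score", "quoted_click_score", "follow_author_score",
    "not_interested_score", "block_author_score", "mute_author_score",
    "report_score", "dwell_time",
]

# Same field names as RankingOutput: scores, ranked_indices, p_<action>..., continuous_preds.
RankingOutputSketch = namedtuple(
    "RankingOutputSketch",
    ["scores", "ranked_indices"] + [f"p_{a}" for a in ACTIONS] + ["continuous_preds"],
)


def build_output(probs, continuous_preds=None):
    ranked_indices = np.argsort(-probs[:, :, 0], axis=-1)
    # One [B, C] slice per action, keyed by field name.
    per_action = {f"p_{name}": probs[:, :, i] for i, name in enumerate(ACTIONS)}
    return RankingOutputSketch(
        scores=probs,
        ranked_indices=ranked_indices,
        continuous_preds=continuous_preds,
        **per_action,
    )


rng = np.random.default_rng(0)
out = build_output(rng.random((2, 5, len(ACTIONS))))
```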

    def rank(self, batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings) -> RankingOutput:
        """Rank candidates for the given batch.
        ...
        """
        return self.rank_candidates(self.params, batch, recsys_embeddings)

The user-facing entry point.

create_example_batch — synthetic data generator

def create_example_batch(
    batch_size: int,
    emb_size: int,
    history_len: int,
    num_candidates: int,
    num_actions: int,
    num_user_hashes: int = 2,
    num_item_hashes: int = 2,
    num_author_hashes: int = 2,
    product_surface_vocab_size: int = 16,
    num_user_embeddings: int = 1_000_000,
    num_post_embeddings: int = 1_000_000,
    num_author_embeddings: int = 1_000_000,
    include_continuous_actions: bool = False,
    include_timestamps: bool = False,
    num_continuous_actions: int = 8,
) -> Tuple[RecsysBatch, RecsysEmbeddings]:
    """Create an example batch with random data for testing.

    This simulates a recommendation scenario where:
    - We have a user with some embedding
    - The user has interacted with some posts in their history
    - We want to rank a set of candidate posts
    ...
    """
    rng = np.random.default_rng(42)

    user_hashes = rng.integers(1, num_user_embeddings, size=(batch_size, num_user_hashes)).astype(
        np.int32
    )

    history_post_hashes = rng.integers(
        1, num_post_embeddings, size=(batch_size, history_len, num_item_hashes)
    ).astype(np.int32)

    for b in range(batch_size):
        valid_len = rng.integers(history_len // 2, history_len + 1)
        history_post_hashes[b, valid_len:, :] = 0

Generate fake but realistic-shaped data for tests / demos. Hashes in [1, num_embeddings) — note 1, not 0: hash 0 is reserved for padding (we saw this in Session 15 — user_hashes[:, 0] != 0 is the validity check).

The per-batch loop:

for b in range(batch_size):
    valid_len = rng.integers(history_len // 2, history_len + 1)
    history_post_hashes[b, valid_len:, :] = 0

Right-pads each user's history to a random length: valid_len is drawn from [history_len // 2, history_len], so with history_len = 128, for example, some users keep all 128 items and others as few as 64. That simulates users with varying activity levels. The valid_len: slice fills the tail with zeros (= padding).
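The padding convention pays off downstream: a validity mask is just a nonzero check. A sketch that reproduces the padding loop and recovers each user's history length from the first hash column (batch_size = 4, history_len = 128 and the vocabulary size are assumed values):

```python
import numpy as np

rng = np.random.default_rng(42)
batch_size, history_len, num_item_hashes = 4, 128, 2

history_post_hashes = rng.integers(
    1, 1_000_000, size=(batch_size, history_len, num_item_hashes)
).astype(np.int32)
for b in range(batch_size):
    valid_len = rng.integers(history_len // 2, history_len + 1)
    history_post_hashes[b, valid_len:, :] = 0

# Real hashes are >= 1, so hash 0 marks padding.
valid_mask = history_post_hashes[:, :, 0] != 0  # [B, history_len]
lengths = valid_mask.sum(axis=-1)
```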

    history_author_hashes = rng.integers(
        1, num_author_embeddings, size=(batch_size, history_len, num_author_hashes)
    ).astype(np.int32)
    for b in range(batch_size):
        valid_len = rng.integers(history_len // 2, history_len + 1)
        history_author_hashes[b, valid_len:, :] = 0

    history_actions = (rng.random(size=(batch_size, history_len, num_actions)) > 0.7).astype(
        np.float32
    )

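Author hashes follow the same pattern, though valid_len is re-drawn for each user, so the padded tail of history_author_hashes need not line up with that of history_post_hashes. Harmless for random demo data.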

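Actions: each of the num_actions flags is independently 1 with probability 0.3 (rng.random(...) > 0.7). With every action, including report and block, firing at 30%, this is shape-filling noise rather than a realistic engagement distribution; on real timelines most posts go unengaged.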

    history_product_surface = rng.integers(
        0, product_surface_vocab_size, size=(batch_size, history_len)
    ).astype(np.int32)

    candidate_post_hashes = rng.integers(
        1, num_post_embeddings, size=(batch_size, num_candidates, num_item_hashes)
    ).astype(np.int32)

    candidate_author_hashes = rng.integers(
        1, num_author_embeddings, size=(batch_size, num_candidates, num_author_hashes)
    ).astype(np.int32)

    candidate_product_surface = rng.integers(
        0, product_surface_vocab_size, size=(batch_size, num_candidates)
    ).astype(np.int32)

Standard random integers for everything else.

    history_continuous_actions = None
    if include_continuous_actions:
        history_continuous_actions = np.zeros(
            (batch_size, history_len, num_continuous_actions), dtype=np.float32
        )
        history_continuous_actions[:, :, 1] = rng.exponential(
            scale=10.0, size=(batch_size, history_len)
        ).astype(np.float32)

Optional continuous actions: only index 1 (dwell_time) is populated, drawn from an exponential distribution with scale 10 (mean 10). The distribution is right-skewed, like real dwell-time data: most dwells are short, a few are long.
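A quick numpy check of what scale=10.0 implies: the mean is 10, the median is 10·ln 2 ≈ 6.9 (below the mean, which is the right skew), and about 63% of draws fall under the mean. Sample size and seed are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)
dwell = rng.exponential(scale=10.0, size=200_000)

mean = dwell.mean()
median = np.median(dwell)
share_under_mean = (dwell < 10.0).mean()  # theory: 1 - e^-1 ≈ 0.632
```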

    candidate_impr_ts = None
    candidate_post_creation_ts = None
    if include_timestamps:
        base_ts = 1700000000
        candidate_impr_ts = np.full((batch_size, num_candidates), base_ts, dtype=np.int32)
        age_seconds = rng.integers(60, 72 * 3600, size=(batch_size, num_candidates))
        candidate_post_creation_ts = (candidate_impr_ts - age_seconds).astype(np.int32)

Optional timestamps. base_ts = 1700000000 is Nov 14, 2023 (UTC). Candidate ages are drawn uniformly from 1 minute up to, but not including, 72 hours.
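The date and the int32 dtype are easy to sanity-check with the standard library. One side note the code implies: int32 Unix timestamps run out at 2**31 - 1, in January 2038, so these fields only have headroom until then. Seed and shapes below are arbitrary:

```python
from datetime import datetime, timezone

import numpy as np

base_ts = 1700000000
base_dt = datetime.fromtimestamp(base_ts, tz=timezone.utc)
int32_limit_dt = datetime.fromtimestamp(np.iinfo(np.int32).max, tz=timezone.utc)

rng = np.random.default_rng(0)
age_seconds = rng.integers(60, 72 * 3600, size=(2, 5))  # [60 s, 72 h)
impr_ts = np.full((2, 5), base_ts, dtype=np.int32)
creation_ts = (impr_ts - age_seconds).astype(np.int32)
```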

    batch = RecsysBatch(
        user_hashes=user_hashes,
        history_post_hashes=history_post_hashes,
        # ...
        candidate_post_creation_ts=candidate_post_creation_ts,
    )

    embeddings = RecsysEmbeddings(
        user_embeddings=rng.normal(size=(batch_size, num_user_hashes, emb_size)).astype(np.float32),
        history_post_embeddings=rng.normal(
            size=(batch_size, history_len, num_item_hashes, emb_size)
        ).astype(np.float32),
        # ...
    )

    return batch, embeddings

Random normal embeddings. Real embeddings would be the learned ones — but for demos, random Gaussian is fine.

RetrievalOutput (inference wrapper)

class RetrievalOutput(NamedTuple):
    """Output from retrieval inference.

    Contains user representations and retrieved candidates.
    """

    user_representation: jax.Array

    top_k_indices: jax.Array

    top_k_scores: jax.Array

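This is the runner-level retrieval output. It has the same fields as ModelRetrievalOutput from Session 15 but is kept as a separate class so the boundary between model code and runner code stays clean. Note, though, that in the code below hk_retrieve returns the model's output directly, so retrieve() actually hands back a ModelRetrievalOutput despite its RetrievalOutput annotation. Given the identical field layout, callers reading .top_k_indices or .top_k_scores won't notice.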

RetrievalModelRunner

@dataclass
class RetrievalModelRunner(BaseModelRunner):
    """Runner for the Phoenix retrieval model."""

    _model: PhoenixRetrievalModelConfig = None  # type: ignore

    def __init__(
        self,
        model: PhoenixRetrievalModelConfig,
        bs_per_device: float = 2.0,
        rng_seed: int = 42,
    ):
        self._model = model
        self.bs_per_device = bs_per_device
        self.rng_seed = rng_seed

    @property
    def model(self) -> PhoenixRetrievalModelConfig:
        return self._model

    @property
    def _model_name(self) -> str:
        return "retrieval model"

Same pattern as ModelRunner. Different config type, different name.

    def make_forward_fn(self):  # type: ignore
        def forward(
            batch: RecsysBatch,
            recsys_embeddings: RecsysEmbeddings,
            corpus_embeddings: jax.Array,
            top_k: int,
        ) -> ModelRetrievalOutput:
            model = self.model.make()
            out = model(batch, recsys_embeddings, corpus_embeddings, top_k)

            _ = model.build_candidate_representation(batch, recsys_embeddings)
            return out

        return hk.transform(forward)

The forward function. Note the trailing call to build_candidate_representation with its result discarded (_ = ...).

Why? Because the CandidateTower parameters only get registered with Haiku if the function is called during init. The main model(batch, embeddings, corpus, top_k) only uses the user tower; the candidate tower is used separately (precomputed offline). To make sure the candidate-tower params are part of the initialized parameter tree, call it during forward — even if we throw the result away.

This is a Haiku-ism: parameters are discovered via use. The init pass walks the forward function once and collects every hk.get_parameter it encounters. Anything not called doesn't get a param.
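To see why the discarded call matters, here is a toy, pure-Python imitation of discovery by use. It is not Haiku's API: ToyTransform, get_parameter and the tower functions are hypothetical stand-ins. During init a parameter exists only if the code path requesting it runs, and its shape comes from the (dummy) input:

```python
import numpy as np


class ToyTransform:
    """Toy imitation of init-by-use; not Haiku's implementation."""

    def __init__(self, fn):
        self.fn = fn
        self.params = {}

    def get_parameter(self, name, shape):
        # The first request creates the parameter; later requests reuse it.
        if name not in self.params:
            self.params[name] = np.zeros(shape, dtype=np.float32)
        return self.params[name]

    def init(self, *args):
        self.params = {}
        self.fn(self, *args)
        return self.params


def user_tower(t, x):
    return x @ t.get_parameter("user_tower/w", (x.shape[-1], 4))


def candidate_tower(t, x):
    return x @ t.get_parameter("candidate_tower/w", (x.shape[-1], 4))


def forward_user_only(t, x):
    return user_tower(t, x)


def forward_with_discarded_call(t, x):
    out = user_tower(t, x)
    _ = candidate_tower(t, x)  # result thrown away, but the parameter now exists
    return out


x = np.zeros((1, 8), dtype=np.float32)  # dummy input: only its shape matters
params_a = ToyTransform(forward_user_only).init(x)
params_b = ToyTransform(forward_with_discarded_call).init(x)
```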

    def init(
        self,
        rng: jax.Array,
        data: RecsysBatch,
        embeddings: RecsysEmbeddings,
        corpus_embeddings: jax.Array,
        top_k: int,
    ) -> TrainingState:
        assert self.forward is not None
        rng, init_rng = jax.random.split(rng)
        params = self.forward.init(init_rng, data, embeddings, corpus_embeddings, top_k)
        return TrainingState(params=params)

Init signature includes corpus + top_k as required args.

RecsysRetrievalInferenceRunner

@dataclass
class RecsysRetrievalInferenceRunner(BaseInferenceRunner):
    """Inference runner for the Phoenix retrieval model.

    This runner provides methods for:
    1. Encoding users to get user representations
    2. Encoding candidates to get candidate embeddings
    3. Retrieving top-k candidates from a corpus
    """

    _runner: RetrievalModelRunner = None  # type: ignore

    corpus_embeddings: jax.Array | None = None
    corpus_post_ids: jax.Array | None = None

    def __init__(self, runner: RetrievalModelRunner, name: str):
        self.name = name
        self._runner = runner
        self.corpus_embeddings = None
        self.corpus_post_ids = None

Holds the corpus as instance state. The corpus is set once via set_corpus and reused across many retrieve calls.

    def initialize(self):
        """Initialize the retrieval inference runner."""
        runner = self.runner

        dummy_batch = self.create_dummy_batch(batch_size=1)
        dummy_embeddings = self.create_dummy_embeddings(batch_size=1)
        dummy_corpus = jnp.zeros((10, runner.model.emb_size), dtype=jnp.float32)
        dummy_top_k = 5

        runner.initialize()

        state = runner.load_or_init(dummy_batch, dummy_embeddings, dummy_corpus, dummy_top_k)
        self.params = state.params

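Same dummy-init pattern. The dummy corpus has 10 rows and top_k is 5, small enough to init quickly. The corpus is a model input rather than a parameter, so its row count doesn't affect any parameter shape; the real corpus is supplied later via set_corpus.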

        @functools.lru_cache
        def model():
            return runner.model.make()

        def hk_encode_user(batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings) -> jax.Array:
            """Encode user to get user representation."""
            m = model()
            user_rep, _ = m.build_user_representation(batch, recsys_embeddings)
            return user_rep

        def hk_encode_candidates(
            batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
        ) -> jax.Array:
            """Encode candidates to get candidate representations."""
            m = model()
            cand_rep, _ = m.build_candidate_representation(batch, recsys_embeddings)
            return cand_rep

        def hk_retrieve(
            batch: RecsysBatch,
            recsys_embeddings: RecsysEmbeddings,
            corpus_embeddings: jax.Array,
            top_k: int,
        ) -> "RetrievalOutput":
            """Retrieve top-k candidates from corpus."""
            m = model()
            return m(batch, recsys_embeddings, corpus_embeddings, top_k)

Three separate functions for the three use cases:

  1. Encode user only — for online retrieval (build user vector, do ANN search externally).
  2. Encode candidates only — for batch corpus precomputation (offline pipeline that fills the corpus).
  3. Full retrieval — encode user + dot-product with corpus + top-K.

Each gets its own Haiku transform. Each shares the same params (same trained model) but exposes a different slice of its functionality.
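The full-retrieval path boils down to a matrix product and a top-K selection. A numpy sketch under stated assumptions: random unit vectors stand in for the user-tower output and the precomputed corpus, and np.argpartition is one way to do the top-K (the model code from Session 15 may use something else, such as jax.lax.top_k). l2_normalize and retrieve_top_k are hypothetical helpers:

```python
import numpy as np


def l2_normalize(x, eps=1e-12):
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)


def retrieve_top_k(user_rep, corpus, top_k):
    """user_rep: [B, D], corpus: [N, D] -> (indices [B, K], scores [B, K]), best first."""
    scores = user_rep @ corpus.T  # [B, N] dot products
    # Unordered top-K: the first top_k slots hold the top_k largest scores.
    part = np.argpartition(-scores, top_k - 1, axis=-1)[:, :top_k]
    part_scores = np.take_along_axis(scores, part, axis=-1)
    order = np.argsort(-part_scores, axis=-1)  # sort just the K winners
    idx = np.take_along_axis(part, order, axis=-1)
    return idx, np.take_along_axis(scores, idx, axis=-1)


rng = np.random.default_rng(0)
corpus = l2_normalize(rng.normal(size=(1000, 16)))  # "offline": encode the corpus once
user_rep = l2_normalize(rng.normal(size=(2, 16)))   # "online": encode the user per request
top_idx, top_scores = retrieve_top_k(user_rep, corpus, top_k=5)
```

argpartition avoids a full O(N log N) sort of the corpus scores; only the K survivors get sorted.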

        encode_user_ = hk.without_apply_rng(hk.transform(hk_encode_user))
        encode_candidates_ = hk.without_apply_rng(hk.transform(hk_encode_candidates))
        retrieve_ = hk.without_apply_rng(hk.transform(hk_retrieve))

        self.encode_user_fn = encode_user_.apply
        self.encode_candidates_fn = encode_candidates_.apply
        self.retrieve_fn = retrieve_.apply

Stash the three apply functions as attributes, just as the ranking runner did with rank_candidates.

    def encode_user(self, batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings) -> jax.Array:
        ...  # body elided

    def encode_candidates(
        self, batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
    ) -> jax.Array:
        ...  # body elided

    def set_corpus(
        self,
        corpus_embeddings: jax.Array,
        corpus_post_ids: jax.Array,
    ):
        """Set the corpus embeddings for retrieval."""
        self.corpus_embeddings = corpus_embeddings
        self.corpus_post_ids = corpus_post_ids

    def retrieve(
        self,
        batch: RecsysBatch,
        recsys_embeddings: RecsysEmbeddings,
        top_k: int = 100,
        corpus_embeddings: Optional[jax.Array] = None,
    ) -> RetrievalOutput:
        """Retrieve top-k candidates for users."""
        if corpus_embeddings is None:
            corpus_embeddings = self.corpus_embeddings

        return self.retrieve_fn(self.params, batch, recsys_embeddings, corpus_embeddings, top_k)

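Three public methods (encode_user, encode_candidates, retrieve) plus set_corpus. retrieve defaults to top_k = 100 and accepts an optional corpus_embeddings override, so you can pass a different corpus per call (e.g. one sliced by language or geography). With an override, top_k_indices are positions in the corpus you passed, not in the stored corpus_post_ids, so the caller has to map them back to post IDs itself.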
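A sketch of the bookkeeping a per-call override implies, in numpy. The language tags, the mask and the argsort-based top-K are hypothetical stand-ins; the point is that the sliced embeddings and the sliced ID array must be built together, because the returned indices are positions in the slice:

```python
import numpy as np

rng = np.random.default_rng(1)
corpus_embeddings = rng.normal(size=(8, 4)).astype(np.float32)
corpus_post_ids = np.arange(1000, 1008, dtype=np.int64)  # IDs stored via set_corpus
corpus_lang = np.array(["en", "ja", "en", "es", "en", "ja", "en", "es"])  # hypothetical metadata

mask = corpus_lang == "en"
sub_embeddings = corpus_embeddings[mask]  # what you would pass as the corpus_embeddings override
sub_post_ids = corpus_post_ids[mask]      # keep the matching IDs alongside it

user_rep = rng.normal(size=(1, 4)).astype(np.float32)
scores = user_rep @ sub_embeddings.T
top_k_indices = np.argsort(-scores, axis=-1)[:, :2]  # positions within sub_embeddings
top_post_ids = sub_post_ids[top_k_indices]           # map positions back to post IDs
```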

create_example_corpus

def create_example_corpus(
    corpus_size: int,
    emb_size: int,
    seed: int = 123,
) -> Tuple[jax.Array, jax.Array]:
    """Create example corpus embeddings for testing retrieval.
    ...
    """
    rng = np.random.default_rng(seed)

    corpus_embeddings = rng.normal(size=(corpus_size, emb_size)).astype(np.float32)
    norms = np.linalg.norm(corpus_embeddings, axis=-1, keepdims=True)
    corpus_embeddings = corpus_embeddings / np.maximum(norms, 1e-12)

    corpus_post_ids = np.arange(corpus_size, dtype=np.int64)

    return jnp.array(corpus_embeddings), jnp.array(corpus_post_ids)

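Random Gaussian corpus, L2-normalized (np.maximum(norms, 1e-12) guards against dividing by a zero norm). This matches the model's expectation: corpus embeddings are pre-normalized, so a dot product with a unit-norm user vector equals cosine similarity.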
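A quick check of that claim on random data: after row-wise L2 normalization, dotting a corpus row with a unit-norm user vector gives the same number as the cosine similarity of the unnormalized vectors, and the 1e-12 floor turns an all-zero row into zeros instead of NaNs:

```python
import numpy as np

rng = np.random.default_rng(123)
raw = rng.normal(size=(5, 8)).astype(np.float32)
norms = np.linalg.norm(raw, axis=-1, keepdims=True)
corpus = raw / np.maximum(norms, 1e-12)  # same normalization as create_example_corpus

user = rng.normal(size=8).astype(np.float32)
user_unit = user / np.linalg.norm(user)

dot_scores = corpus @ user_unit
cosine = (raw @ user) / (np.linalg.norm(raw, axis=-1) * np.linalg.norm(user))

# Degenerate case: a zero row stays zero rather than becoming 0/0.
zero_row = np.zeros((1, 8), dtype=np.float32)
safe = zero_row / np.maximum(np.linalg.norm(zero_row, axis=-1, keepdims=True), 1e-12)
```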

End of runners.py.


run_pipeline.py (393 lines) — the headline addition

The single-entry-point demo that runs retrieval → ranking from exported checkpoints. The README highlight of this release.

Imports + constants

"""End-to-end retrieval + ranking pipeline from exported artifacts.

Loads exported checkpoints, a pre-computed corpus, and a user action
sequence, then runs:
  1. Retrieval: encode user history → dot product with corpus → top-K
  2. Ranking: score top-K candidates with per-action engagement model

Usage:
    python run_pipeline.py \
        --artifacts_dir ./artifacts \
        --sequence_file ./artifacts/example_sequence.json \
        --corpus_file ./artifacts/sports_corpus.npz \
        --top_k_retrieval 200 \
        --top_k_display 30
"""

import argparse
import json
import logging
import os

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

from grok import TransformerConfig
from recsys_model import (
    HashConfig,
    PhoenixModelConfig,
    RecsysBatch,
    RecsysEmbeddings,
)
from recsys_retrieval_model import PhoenixRetrievalModelConfig
from runners import load_embedding_table, load_model_params

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
logging.getLogger("jax").setLevel(logging.WARNING)
log = logging.getLogger(__name__)

IDX_FAV = 1  # SERVER_TWEET_FAV
IDX_REPLY = 4  # SERVER_TWEET_REPLY
IDX_QUOTE = 5  # SERVER_TWEET_QUOTE
IDX_RT = 6  # SERVER_TWEET_RETWEET
IDX_DWELL = 11  # CLIENT_TWEET_RECAP_DWELLED
IDX_VQV = 13  # CLIENT_TWEET_VIDEO_QUALITY_VIEW

Notice the index constants — these are different from the ACTIONS list in runners.py. These reference indices into the history actions vector (the user-action sequence input), NOT the model output. The action sequence has its own taxonomy (SERVER_TWEET_FAV, CLIENT_TWEET_RECAP_DWELLED, etc.).

So:

  • Input actions (history): 19+ types indexed by these constants.
  • Output predictions: 19 actions indexed by ACTIONS.

Different vocabularies because input includes server-side events (favs, retweets) and client-side events (dwell, video view), while output is what the model predicts the user will do.

Hash functions

def _hash_ids(ids, scales, biases, modulus, num_buckets):
    """Hash IDs using the same linear congruential hash as training.

    Uses numpy int64 arithmetic (wrapping overflow) to match training.
    """
    ids = np.asarray(ids, dtype=np.int64).ravel()
    scales = np.array(scales, dtype=np.int64)
    biases = np.array(biases, dtype=np.int64)
    n, m = len(ids), len(scales)
    out = np.empty((n, m), dtype=np.int32)
    for i in range(n):
        for j in range(m):
            raw = (ids[i] * scales[j] + biases[j]) % np.int64(modulus)
            out[i, j] = 0 if ids[i] == 0 else int((int(raw) % (num_buckets - 1)) + 1)
    return out

Linear congruential hashing: (id * scale + bias) % modulus. Then map to [1, num_buckets - 1] (skipping 0 which is padding).

Multiple hashes per ID = multiple scales/biases. So with num_user_hashes=2, there are 2 scale-bias pairs and each ID produces 2 hashes.

The double-loop is slow Python — should be vectorized with numpy, but for a small number of IDs (e.g., 200 candidates × 2 hashes), it doesn't matter.

out[i, j] = 0 if ids[i] == 0 else int((int(raw) % (num_buckets - 1)) + 1):

  • If the original ID is 0 (padding), keep the hash as 0.
  • Otherwise, take raw mod (num_buckets - 1) + 1 to ensure the hash is in [1, num_buckets - 1].

This is the exact same hash function the training pipeline uses, so checkpoint params match the hash buckets they were trained with.
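The double loop above can be vectorized with NumPy broadcasting. The sketch below (hash_ids_vectorized is a hypothetical name, not part of the repo) computes the same (id * scale + bias) % modulus table for all IDs and hash slots at once. int64 array arithmetic wraps on overflow (silently, for arrays) just like the scalar int64 arithmetic in the loop, so the outputs should match.

```python
import numpy as np


def hash_ids_vectorized(ids, scales, biases, modulus, num_buckets):
    """Broadcasted equivalent of the _hash_ids double loop (a sketch)."""
    ids = np.asarray(ids, dtype=np.int64).ravel()[:, None]   # [n, 1]
    scales = np.asarray(scales, dtype=np.int64)[None, :]    # [1, m]
    biases = np.asarray(biases, dtype=np.int64)[None, :]    # [1, m]
    # NumPy's % takes the sign of the divisor, so raw is non-negative for modulus > 0.
    raw = (ids * scales + biases) % np.int64(modulus)        # [n, m]
    hashed = raw % (num_buckets - 1) + 1                     # into [1, num_buckets - 1]
    return np.where(ids == 0, 0, hashed).astype(np.int32)    # ID 0 (padding) stays 0
```

Broadcasting builds the whole [n, m] table in a few array ops, so the cost no longer scales with Python-level iteration.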

def build_hash_functions(config):
    """Build hash functions from published config."""
    hp = config["hash_params"]
    pad = 65
    uv = config["user_vocab_size"]
    iv = config["item_vocab_size"]
    av = config["author_vocab_size"]

    def hash_user(user_ids):
        h = _hash_ids(user_ids, hp["user_hash_scales"], hp["user_biases"], hp["user_modulus"], uv)
        return np.where(h == 0, 0, h + pad)

    def hash_item(item_ids):
        h = _hash_ids(item_ids, hp["item_hash_scales"], hp["item_biases"], hp["item_modulus"], iv)
        return np.where(h == 0, 0, h + pad + uv)

    def hash_author(author_ids):
        h = _hash_ids(
            author_ids, hp["author_hash_scales"], hp["author_biases"], hp["author_modulus"], av
        )
        return np.where(h == 0, 0, h + pad + uv + iv)

    return hash_user, hash_item, hash_author

Build the three hash functions with separate scale/bias config per entity type. Plus a vocabulary offset:

  • Users: indices [pad, pad + uv).
  • Items: indices [pad + uv, pad + uv + iv).
  • Authors: indices [pad + uv + iv, pad + uv + iv + av).

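pad = 65: the first 65 rows are reserved. Row 0 is where padding lands (a hash of 0 stays 0 in every hash function). The code zero-fills all 65 rows without saying what the other 64 are for; presumably they hold special tokens or other system-internal symbols.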

So the single unified embedding table has slots laid out:

[0..pad)        — reserved (padding etc.)
[pad..pad+uv)   — user embeddings
[pad+uv..pad+uv+iv)   — item embeddings
[pad+uv+iv..pad+uv+iv+av)  — author embeddings

Each hash function shifts its raw hash by the appropriate offset.
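To make the offsets concrete, here is a small helper (hypothetical, not in the repo) that inverts the layout: given a row of the unified table, it reports which region the row belongs to and its row within that entity's split table. With the shifts in build_hash_functions, a non-padding raw hash h in [1, vocab - 1] lands at local row h, so local row 0 of each region is never produced by hashing. The vocab sizes in the usage line are made up.

```python
PAD = 65  # reserved rows at the start of the unified table, as in build_hash_functions


def unified_region(index, uv, iv, av, pad=PAD):
    """Map a unified-table row to (region, row within that region's split table)."""
    total = pad + uv + iv + av
    if not 0 <= index < total:
        raise IndexError(f"row {index} outside a table of {total} rows")
    if index < pad:
        return "reserved", index
    if index < pad + uv:
        return "user", index - pad
    if index < pad + uv + iv:
        return "item", index - pad - uv
    return "author", index - pad - uv - iv


print(unified_region(78, uv=10, iv=20, av=5))  # ('item', 3)
```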

Build unified embedding table

def build_unified_emb_table(emb_dict, config):
    """Reconstruct monolithic embedding table from split tables."""
    emb_size = config["emb_size"]
    uv = config["user_vocab_size"]
    iv = config["item_vocab_size"]
    av = config["author_vocab_size"]
    pad = 65
    table = np.zeros((pad + uv + iv + av, emb_size), dtype=np.float32)
    table[pad : pad + uv] = emb_dict["user_embeddings"]
    table[pad + uv : pad + uv + iv] = emb_dict["item_embeddings"]
    table[pad + uv + iv : pad + uv + iv + av] = emb_dict["author_embeddings"]
    return table

Concatenate the three split tables (user / item / author) into one monolithic table indexed by the unified hash space. The first 65 rows stay all-zeros (padding embeddings).
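Because the regions are contiguous, build_unified_emb_table is equivalent to a single vertical concatenation, which makes the layout easy to check. The sketch below (unified_table_via_concat is a hypothetical helper; the toy sizes are made up) rebuilds the table that way and reads an item embedding back through its shifted index.

```python
import numpy as np


def unified_table_via_concat(emb_dict, pad=65):
    """Stack [reserved | users | items | authors] into one float32 table.

    Matches build_unified_emb_table when each split table has exactly the
    vocab size the config declares.
    """
    user = np.asarray(emb_dict["user_embeddings"], dtype=np.float32)
    item = np.asarray(emb_dict["item_embeddings"], dtype=np.float32)
    author = np.asarray(emb_dict["author_embeddings"], dtype=np.float32)
    reserved = np.zeros((pad, user.shape[1]), dtype=np.float32)  # padding rows stay zero
    return np.concatenate([reserved, user, item, author], axis=0)


# Toy split tables: uv=3, iv=4, av=2, emb_size=5 (hypothetical sizes).
rng = np.random.default_rng(0)
uv, iv, av, d, pad = 3, 4, 2, 5, 65
emb_dict = {
    "user_embeddings": rng.normal(size=(uv, d)),
    "item_embeddings": rng.normal(size=(iv, d)),
    "author_embeddings": rng.normal(size=(av, d)),
}
table = unified_table_via_concat(emb_dict, pad)
h = 2  # a raw item hash in [1, iv - 1]
item_row = table[pad + uv + h]  # same values as emb_dict["item_embeddings"][h], in float32
```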

Build model config from JSON

def build_model_config(config, config_class):
    """Build a model config from the published JSON config."""
    kwargs = dict(
        emb_size=config["emb_size"],
        history_seq_len=config["history_seq_len"],
        candidate_seq_len=config["candidate_seq_len"],
        hash_config=HashConfig(
            num_user_hashes=config["num_user_hashes"],
            num_item_hashes=config["num_item_hashes"],
            num_author_hashes=config["num_author_hashes"],
        ),
        product_surface_vocab_size=config.get("product_surface_vocab_size", 16),
        model=TransformerConfig(
            emb_size=config["emb_size"],
            key_size=config["key_size"],
            num_q_heads=config["num_heads"],
            num_kv_heads=config["num_heads"],
            num_layers=config["num_layers"],
            widening_factor=2.0,
            attn_output_multiplier=0.125,
        ),
    )
    if config_class == PhoenixModelConfig:
        kwargs["num_actions"] = config["num_actions"]
        kwargs["post_age_granularity_mins"] = config.get("post_age_granularity_mins", 60)
    elif config_class == PhoenixRetrievalModelConfig:
        kwargs["enable_linear_proj"] = True

    mc = config_class(**kwargs)
    mc.initialize()
    return mc

Reconstruct the Python dataclass config from the JSON file shipped with the checkpoint. Generic across both model types — branches at the end on config_class.

Notable: num_q_heads = num_kv_heads = num_heads from JSON. So this config doesn't use grouped-query attention (where K/V has fewer heads than Q). The Grok-1 model uses GQA but the recsys variant doesn't.

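attn_output_multiplier=0.125: despite the name, in the Grok-1 code this factor multiplies the attention logits before the softmax rather than the attention output (assuming Phoenix's grok.py keeps that behavior). 0.125 = 1/√64, the standard 1/√key_size scaling for key_size = 64, which is the key size the demo configs in run_ranker.py and run_retrieval.py use.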
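The scaling is easy to write out. A minimal single-head NumPy sketch, assuming Phoenix keeps Grok-1's convention of multiplying the raw query-key logits by attn_output_multiplier before the softmax (it omits masking, logit soft-capping, and the rest of the real attention module; attention_weights is a hypothetical name):

```python
import numpy as np


def attention_weights(q, k, attn_output_multiplier):
    """Softmax attention weights with multiplier-scaled logits (single head).

    q: [T, key_size] queries, k: [S, key_size] keys -> [T, S] weights.
    """
    logits = (q @ k.T) * attn_output_multiplier
    logits = logits - logits.max(axis=-1, keepdims=True)  # subtract row max for stability
    weights = np.exp(logits)
    return weights / weights.sum(axis=-1, keepdims=True)


key_size = 64
multiplier = 0.125  # equals 1 / np.sqrt(key_size) exactly
```

With key_size = 64, this is ordinary scaled dot-product attention: the multiplier plays the role of the 1/√d_k factor.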

main

def main():
    parser = argparse.ArgumentParser(description="Run retrieval + ranking pipeline")
    parser.add_argument("--artifacts_dir", default="./artifacts")
    parser.add_argument(
        "--sequence_file",
        default=None,
        help="Path to user action sequence JSON. Default: artifacts_dir/example_sequence.json",
    )
    parser.add_argument(
        "--corpus_file",
        default=None,
        help="Path to corpus NPZ. Default: artifacts_dir/sports_corpus.npz",
    )
    parser.add_argument("--top_k_retrieval", type=int, default=200)
    parser.add_argument("--top_k_display", type=int, default=30)
    args = parser.parse_args()

CLI args. Default top_k_retrieval=200 — retrieve 200 from the corpus. top_k_display=30 — show 30 in the output table.

    artifacts = args.artifacts_dir
    seq_file = args.sequence_file or os.path.join(artifacts, "example_sequence.json")
    corpus_file = args.corpus_file or os.path.join(artifacts, "sports_corpus.npz")

    with open(os.path.join(artifacts, "retrieval", "config.json")) as f:
        ret_cfg = json.load(f)
    with open(os.path.join(artifacts, "ranker", "config.json")) as f:
        rank_cfg = json.load(f)

Load both configs (retrieval + ranker). They're separate JSON files because they have different shapes.

    emb_size = ret_cfg["emb_size"]
    num_actions = ret_cfg["num_actions"]
    hist_len = ret_cfg["history_seq_len"]
    cand_len = ret_cfg["candidate_seq_len"]

    log.info("Loading retrieval model...")
    ret_params = load_model_params(os.path.join(artifacts, "retrieval", "model_params.npz"))
    ret_emb = build_unified_emb_table(
        load_embedding_table(os.path.join(artifacts, "retrieval", "embedding_tables.npz")),
        ret_cfg,
    )

    log.info("Loading ranker model...")
    rank_params = load_model_params(os.path.join(artifacts, "ranker", "model_params.npz"))
    rank_emb = build_unified_emb_table(
        load_embedding_table(os.path.join(artifacts, "ranker", "embedding_tables.npz")),
        rank_cfg,
    )

Load both models:

  • Each has its own params + embedding tables. They are not shared — retrieval and ranking are independently trained.
  • The retrieval model has fewer params (no candidate-tower transformer); the ranker has more.

    log.info("Loading corpus...")
    corpus = np.load(corpus_file, allow_pickle=True)
    corpus_post_ids = corpus["post_ids"]
    corpus_repr = corpus["candidate_representations"]
    corpus_author_ids = corpus["author_ids"]
    corpus_topics = corpus.get("topics", np.array([""] * len(corpus_post_ids)))
    log.info("  %d posts, repr shape %s", len(corpus_post_ids), corpus_repr.shape)

Load the corpus. Keys: post_ids, candidate_representations, author_ids, topics (optional).

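corpus_repr is pre-computed using the retrieval model's candidate tower (offline batch job). Shape [N, D]. Stored as float32 (4 bytes per value), a 1M-post corpus at D=256 takes 1M × 256 × 4 bytes ≈ 1.02 GB; getting it down to 256 MB would require one byte per value (e.g., int8).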
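The memory estimate is just rows × dimension × bytes per element. A one-line helper (hypothetical, for back-of-envelope sizing) makes the dtype dependence explicit:

```python
import numpy as np


def corpus_nbytes(num_posts, dim, dtype=np.float32):
    """Bytes for an [num_posts, dim] array of candidate vectors stored as `dtype`."""
    return num_posts * dim * np.dtype(dtype).itemsize


print(corpus_nbytes(1_000_000, 256))            # 1024000000 (float32)
print(corpus_nbytes(1_000_000, 256, np.int8))   # 256000000 (one byte per value)
```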

    log.info("Loading user sequence from %s", seq_file)
    with open(seq_file) as f:
        seq = json.load(f)

    user_id = seq["user_id"]
    history = seq["history"]
    log.info("  User %d, %d history items", user_id, len(history))

Load the user history. The file is a JSON object with a user_id and a history list; each history item has a post_id, an author_id, and an actions map of the form { "<idx>": <value> }, keyed by the action index as a string.

    hash_user, hash_item, hash_author = build_hash_functions(ret_cfg)
    rank_hash_user, rank_hash_item, rank_hash_author = build_hash_functions(rank_cfg)

Two sets of hash functions, one per model, because each config carries its own vocab sizes, scales, and biases. Important: whenever those differ, a tweet ID hashed for retrieval lands in a different bucket than the same tweet ID hashed for ranking. Each model has its own learned representation per hash bucket.

    history_post_ids = np.zeros(hist_len, dtype=np.uint64)
    history_author_ids = np.zeros(hist_len, dtype=np.uint64)
    history_actions = np.zeros((hist_len, num_actions), dtype=np.float32)

    for i, item in enumerate(history[:hist_len]):
        history_post_ids[i] = item["post_id"]
        history_author_ids[i] = item["author_id"]
        for act_idx_str, act_val in item.get("actions", {}).items():
            idx = int(act_idx_str)
            if idx < num_actions:
                history_actions[i, idx] = float(act_val)

Build the history arrays:

  1. Initialize to zeros (= padding).
  2. For up to hist_len items, populate post_id, author_id, actions.
  3. Actions are sparse — only present keys get set. The rest stays 0 (= action didn't happen).

The if idx < num_actions is defensive — protects against the JSON having an action index beyond what the model expects.
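For concreteness, here is what a sequence file of that shape might look like, parsed by the same loop wrapped in a function. The IDs, action indices, and num_actions=19 below are made up; only the structure follows the script, and build_history_arrays is a hypothetical name.

```python
import json

import numpy as np

# Hypothetical sequence file contents; only the structure mirrors the script.
EXAMPLE_SEQUENCE = """
{"user_id": 42,
 "history": [
   {"post_id": 1001, "author_id": 7, "actions": {"1": 1, "11": 1}},
   {"post_id": 1002, "author_id": 8, "actions": {"99": 1}}
 ]}
"""


def build_history_arrays(history, hist_len, num_actions):
    """Truncate to hist_len, zero-pad the rest, and scatter sparse actions."""
    post_ids = np.zeros(hist_len, dtype=np.uint64)
    author_ids = np.zeros(hist_len, dtype=np.uint64)
    actions = np.zeros((hist_len, num_actions), dtype=np.float32)
    for i, item in enumerate(history[:hist_len]):
        post_ids[i] = item["post_id"]
        author_ids[i] = item["author_id"]
        for idx_str, value in item.get("actions", {}).items():
            idx = int(idx_str)  # JSON object keys are always strings
            if idx < num_actions:  # drop indices the model has no slot for
                actions[i, idx] = float(value)
    return post_ids, author_ids, actions


seq = json.loads(EXAMPLE_SEQUENCE)
post_ids, author_ids, actions = build_history_arrays(seq["history"], hist_len=4, num_actions=19)
```

The second item's action "99" is dropped by the idx < num_actions guard, and slots 2 and 3 stay all-zero padding.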

    user_hashes = hash_user(np.array([user_id], dtype=np.uint64))
    hist_post_h = hash_item(history_post_ids).reshape(1, hist_len, -1)
    hist_author_h = hash_author(history_author_ids).reshape(1, hist_len, -1)

Hash everything. reshape(1, hist_len, -1) adds the batch dimension (1) and lets the last dim infer from num_hashes.
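The shapes are easiest to see on toy data. In this sketch (sizes and hash values are hypothetical; lookup is an illustrative helper), a [hist_len, num_hashes] hash array gets its batch dimension from reshape(1, hist_len, -1), and NumPy fancy indexing into the unified table appends the embedding dimension:

```python
import numpy as np


def lookup(table, hashes):
    """Fancy-index a [rows, D] table with an int array of any shape -> shape + (D,)."""
    return table[np.asarray(hashes)]


hist_len, num_hashes, emb_size = 4, 2, 8  # hypothetical sizes
table = np.random.default_rng(0).normal(size=(100, emb_size)).astype(np.float32)
table[0] = 0.0  # row 0 is the padding row

# What a hash call returns for hist_len IDs: [hist_len, num_hashes].
# The last two history slots are padding, so their hashes are 0.
flat = np.array([[70, 71], [80, 91], [0, 0], [0, 0]], dtype=np.int32)
batched = flat.reshape(1, hist_len, -1)  # [1, hist_len, num_hashes]
emb = lookup(table, batched)             # [1, hist_len, num_hashes, emb_size]
```

Padding slots hash to 0, so they pick up the all-zero row 0.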

Retrieval pass

    log.info("Running retrieval...")
    ret_model_config = build_model_config(ret_cfg, PhoenixRetrievalModelConfig)

    N_neg = 64

    def ret_forward(batch, embeddings, gn_post, gn_auth):
        model = ret_model_config.make()
        user_repr, _ = model.build_user_representation(batch, embeddings)
        combined_post = jnp.concatenate([embeddings.candidate_post_embeddings, gn_post], axis=1)
        combined_auth = jnp.concatenate([embeddings.candidate_author_embeddings, gn_auth], axis=1)
        combined_emb = RecsysEmbeddings(
            user_embeddings=embeddings.user_embeddings,
            history_post_embeddings=embeddings.history_post_embeddings,
            candidate_post_embeddings=combined_post,
            history_author_embeddings=embeddings.history_author_embeddings,
            candidate_author_embeddings=combined_auth,
        )
        gn_ph = jnp.ones((1, N_neg, 2), dtype=jnp.int32)
        gn_ah = jnp.ones((1, N_neg, 2), dtype=jnp.int32)
        comb_ph = jnp.concatenate([batch.candidate_post_hashes, gn_ph], axis=1)
        comb_ah = jnp.concatenate([batch.candidate_author_hashes, gn_ah], axis=1)
        comb_ps = jnp.concatenate(
            [batch.candidate_product_surface, jnp.zeros((1, N_neg), dtype=jnp.int32)], axis=1
        )
        comb_batch = batch._replace(
            candidate_post_hashes=comb_ph,
            candidate_author_hashes=comb_ah,
            candidate_product_surface=comb_ps,
        )
        model.build_candidate_representation(comb_batch, combined_emb)
        _ = hk.get_parameter("log_temperature", [], init=hk.initializers.Constant(0.0))
        return user_repr

    ret_fn = hk.without_apply_rng(hk.transform(ret_forward))

This is weird for an inference function. Let me unpack:

  • N_neg = 64 — number of "global negatives" used during training (in-batch negative sampling).
  • The function builds combined embeddings (candidate + N_neg fake negatives) and calls build_candidate_representation on the combined batch.
  • It also reads log_temperature (a learned scalar used in InfoNCE loss).

None of this is needed for inference! The function only returns user_repr. The candidate-tower call and the log_temperature read appear to be there for parameter bookkeeping:

When the model was trained, hk.transform recorded parameters from:

  1. build_user_representation (transformer + user reduce).
  2. build_candidate_representation on a [1, cand_len + N_neg, ...] shaped input.
  3. The log_temperature scalar.

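If the inference function only called build_user_representation, it would register a strict subset of the trained parameter tree. Any init-then-load path (the dummy-init pattern runners.py uses) would then produce a tree missing the candidate-tower and log_temperature entries, so it would no longer line up with the checkpoint.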

So this forward function is a shape-matching reconstruction of the training forward — runs everything for parameter registration but only returns what we need.

Ugly but pragmatic. Real production would have a cleaner saved-model format.

    C = cand_len
    batch = RecsysBatch(
        user_hashes=jnp.asarray(user_hashes),
        history_post_hashes=jnp.asarray(hist_post_h),
        history_author_hashes=jnp.asarray(hist_author_h),
        history_actions=jnp.asarray(history_actions.reshape(1, hist_len, num_actions)),
        history_product_surface=jnp.zeros((1, hist_len), dtype=jnp.int32),
        candidate_post_hashes=jnp.zeros((1, C, 2), dtype=jnp.int32),
        candidate_author_hashes=jnp.zeros((1, C, 2), dtype=jnp.int32),
        candidate_product_surface=jnp.zeros((1, C), dtype=jnp.int32),
    )
    emb_batch = RecsysEmbeddings(
        user_embeddings=jnp.asarray(ret_emb[user_hashes]),
        history_post_embeddings=jnp.asarray(ret_emb[hist_post_h]),
        candidate_post_embeddings=jnp.zeros((1, C, 2, emb_size)),
        history_author_embeddings=jnp.asarray(ret_emb[hist_author_h]),
        candidate_author_embeddings=jnp.zeros((1, C, 2, emb_size)),
    )
    dummy_gn = jnp.zeros((1, N_neg, 2, emb_size))
    user_repr = ret_fn.apply(ret_params, batch, emb_batch, dummy_gn, dummy_gn)
    log.info("  User repr norm=%.4f", float(jnp.linalg.norm(user_repr)))

Look up embeddings from the unified table via numpy fancy indexing: ret_emb[user_hashes] returns shape [B, num_hashes, emb_size].

Candidate stuff is zeros (we don't have candidates yet — that's what we're retrieving). The dummy_gn provides the N_neg negatives.

Run the forward. Result: a [1, D] user vector.

user_repr norm log — should be ~1.0 since the model L2-normalizes. Sanity check.

    TOP_K = min(args.top_k_retrieval, len(corpus_post_ids))
    scores = corpus_repr @ np.asarray(user_repr[0])
    top_idx = np.argpartition(scores, -TOP_K)[-TOP_K:]
    top_idx = top_idx[np.argsort(-scores[top_idx])]
    ret_post_ids = corpus_post_ids[top_idx]
    ret_author_ids = corpus_author_ids[top_idx]
    ret_scores = scores[top_idx]
    ret_topics = [str(corpus_topics[i]) for i in top_idx]
    log.info("  Retrieved %d (score range: %.4f - %.4f)", TOP_K, ret_scores[-1], ret_scores[0])

Manual top-K via numpy: argpartition for unsorted top-K (O(N) instead of O(N log N)), then argsort within those K for proper ordering.

corpus_repr @ user_repr[0] is [N, D] × [D] = [N]. Dot product similarity (cosine, since both are normalized).

Then look up post IDs, author IDs, topics for the top K.
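The retrieval step fits in a small function. This sketch mirrors the argpartition + argsort pattern above (retrieve_top_k is a hypothetical name) and adds a guard for k <= 0, where slicing with [-0:] would otherwise return every index:

```python
import numpy as np


def retrieve_top_k(corpus_repr, user_vec, k):
    """Dot-product top-k: returns (indices sorted by descending score, all scores)."""
    scores = corpus_repr @ user_vec          # [N]; cosine if both sides are L2-normalized
    k = min(k, len(scores))
    if k <= 0:
        return np.empty(0, dtype=np.intp), scores
    idx = np.argpartition(scores, -k)[-k:]   # O(N): the k best, in no particular order
    idx = idx[np.argsort(-scores[idx])]      # O(k log k): order only those k
    return idx, scores
```

Usage: with an L2-normalized [N, D] corpus and a normalized user vector, retrieve_top_k(corpus, user, 200) returns the same order a full np.argsort(-scores)[:200] would, without sorting all N scores.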

Ranking pass

    log.info("Ranking %d candidates...", TOP_K)
    rank_model_config = build_model_config(rank_cfg, PhoenixModelConfig)

    def rank_forward(b, e):
        return rank_model_config.make()(b, e)

    rank_fn = hk.without_apply_rng(hk.transform(rank_forward))

Build the ranker config + transform. Simpler than the retrieval forward — just call the model directly.

    rank_user_h = rank_hash_user(np.array([user_id], dtype=np.uint64))
    rank_hist_post_h = rank_hash_item(history_post_ids).reshape(1, hist_len, -1)
    rank_hist_author_h = rank_hash_author(history_author_ids).reshape(1, hist_len, -1)

    all_probs = []
    for i in range(0, TOP_K, cand_len):
        j = min(i + cand_len, TOP_K)
        cs = j - i
        cph = rank_hash_item(ret_post_ids[i:j]).reshape(1, cs, -1)
        cah = rank_hash_author(ret_author_ids[i:j]).reshape(1, cs, -1)
        if cs < cand_len:
            cph = np.pad(cph, ((0, 0), (0, cand_len - cs), (0, 0)))
            cah = np.pad(cah, ((0, 0), (0, cand_len - cs), (0, 0)))
        rb = RecsysBatch(
            user_hashes=jnp.asarray(rank_user_h),
            history_post_hashes=jnp.asarray(rank_hist_post_h),
            history_author_hashes=jnp.asarray(rank_hist_author_h),
            history_actions=jnp.asarray(history_actions.reshape(1, hist_len, num_actions)),
            history_product_surface=jnp.zeros((1, hist_len), dtype=jnp.int32),
            candidate_post_hashes=jnp.asarray(cph),
            candidate_author_hashes=jnp.asarray(cah),
            candidate_product_surface=jnp.zeros((1, cand_len), dtype=jnp.int32),
        )
        re = RecsysEmbeddings(
            user_embeddings=jnp.asarray(rank_emb[rank_user_h]),
            history_post_embeddings=jnp.asarray(rank_emb[rank_hist_post_h]),
            candidate_post_embeddings=jnp.asarray(rank_emb[cph]),
            history_author_embeddings=jnp.asarray(rank_emb[rank_hist_author_h]),
            candidate_author_embeddings=jnp.asarray(rank_emb[cah]),
        )
        out = rank_fn.apply(rank_params, rb, re)
        probs = jax.nn.sigmoid(out.logits)
        all_probs.append(np.asarray(probs[0, :cs, :]))

Batched ranking: process the TOP_K retrieved candidates (200 by default) in chunks of cand_len (e.g., 32). For each chunk:

  1. Hash post/author IDs (using the ranker's hash functions, not retrieval's).
  2. If this is a short last chunk, zero-pad the hash arrays up to cand_len; hash 0 maps to the all-zero padding row.
  3. Build the RecsysBatch.
  4. Look up embeddings from the ranker's unified table (padded hashes included) to build the RecsysEmbeddings.
  5. Apply the model.
  6. Sigmoid the logits to get probabilities.
  7. Slice off the padding (probs[0, :cs, :]).
  8. Append to the list.

The model is fixed-shape (always cand_len candidates), so we pad and slice to handle the actual count.
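The pad-and-slice loop generalizes to any fixed-shape scorer. In this sketch, score_fn is a hypothetical stand-in for the ranker's apply call, and the equivalence with unchunked scoring assumes each candidate's score does not depend on the other candidates in its chunk (which is what a candidate isolation mask is meant to guarantee):

```python
import numpy as np


def chunk_bounds(total, chunk):
    """(start, end) pairs covering range(total) in steps of `chunk`."""
    return [(i, min(i + chunk, total)) for i in range(0, total, chunk)]


def score_in_fixed_chunks(cand_hashes, chunk, score_fn):
    """Score N > 0 candidates with a scorer that only accepts exactly `chunk` rows.

    cand_hashes: [N, num_hashes] int array.
    score_fn: maps [chunk, num_hashes] -> [chunk, num_actions].
    """
    outputs = []
    for start, end in chunk_bounds(len(cand_hashes), chunk):
        block = cand_hashes[start:end]
        n_real = end - start
        if n_real < chunk:
            # Zero-pad the short last chunk; hash 0 is the padding bucket.
            block = np.pad(block, ((0, chunk - n_real), (0, 0)))
        scores = score_fn(block)
        outputs.append(scores[:n_real])  # drop rows produced for padding
    return np.concatenate(outputs, axis=0)
```

With 200 candidates and a chunk of 32, chunk_bounds yields 7 ranges, the last being (192, 200).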

    all_probs = np.concatenate(all_probs)
    weighted = (
        all_probs[:, IDX_FAV] * 1.0
        + all_probs[:, IDX_REPLY] * 0.5
        + all_probs[:, IDX_RT] * 0.3
        + all_probs[:, IDX_DWELL] * 0.2
    )
    ranked = np.argsort(-weighted)

Simplified weighted score for the demo. The IDX_* constants (IDX_FAV = 1, IDX_REPLY = 4, ...) are named after the raw event taxonomy, and here they index the ranker's output probabilities. They are not positions in runners.py's 0-18 ACTIONS list.

If the published ranker's output heads were ordered like ACTIONS, all_probs[:, IDX_FAV=1] would actually read reply_score, and the demo formula would be mixing up actions. The code leans the other way, though: history_actions, whose columns follow the raw event indices, is fed unchanged into the ranker, so the published checkpoint plausibly uses one taxonomy for both its inputs and its output heads, in which case the indices are right. The script itself doesn't settle it, so check the checkpoint's action mapping before trusting the per-action columns.

Either way: the weighting [1.0, 0.5, 0.3, 0.2] for [fav, reply, retweet, dwell] is a quick demo formula, not production-tuned. The real production weights come from feature switches (Session 10).
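Keying the demo weights by index keeps the taxonomy question in one place: if a checkpoint turns out to order its outputs differently, only the dictionary changes. A minimal sketch using the indices and weights from the script (DEMO_WEIGHTS and weighted_score are hypothetical names):

```python
import numpy as np

# Index -> weight, copied from run_pipeline.py's demo formula.
DEMO_WEIGHTS = {
    1: 1.0,   # IDX_FAV
    4: 0.5,   # IDX_REPLY
    6: 0.3,   # IDX_RT
    11: 0.2,  # IDX_DWELL
}


def weighted_score(probs, weights=DEMO_WEIGHTS):
    """Linear combination of per-action probabilities: [N, num_actions] -> [N]."""
    idx = np.fromiter(weights.keys(), dtype=np.intp)
    w = np.fromiter(weights.values(), dtype=np.float64)
    return probs[:, idx] @ w


# ranked = np.argsort(-weighted_score(all_probs))  # same ordering as the script
```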

    DISPLAY = min(args.top_k_display, TOP_K)
    print("\n" + "=" * 120)
    print(f"PIPELINE RESULTS — User {user_id}")
    print(f"History: {len(history)} items | Corpus: {len(corpus_post_ids)} posts")
    print(f"Retrieved top {TOP_K} → Ranked by engagement model")
    print("=" * 120)
    print(
        f"{'Rank':<5} {'Score':<8} {'Ret':<7} {'Fav':<7} {'Reply':<7} "
        f"{'RT':<7} {'Dwell':<7} {'VQV':<7} {'Topics':<30} Post URL"
    )
    print("-" * 120)

    for rank, idx in enumerate(ranked[:DISPLAY]):
        pid = int(ret_post_ids[idx])
        print(
            f"{rank+1:<5} {float(weighted[idx]):<8.4f} {float(ret_scores[idx]):<7.4f} "
            f"{float(all_probs[idx, IDX_FAV]):<7.4f} {float(all_probs[idx, IDX_REPLY]):<7.4f} "
            f"{float(all_probs[idx, IDX_RT]):<7.4f} {float(all_probs[idx, IDX_DWELL]):<7.4f} "
            f"{float(all_probs[idx, IDX_VQV]):<7.4f} {ret_topics[idx][:28]:<30} "
            f"https://x.com/a/status/{pid}"
        )

    print(
        f"\nWeighted score range: "
        f"[{float(weighted[ranked[-1]]):.4f}, {float(weighted[ranked[0]]):.4f}]"
    )
    print("=" * 120)

Print the results in a formatted table: rank, weighted score, retrieval score, individual action probabilities, topics, post URL.

The URL https://x.com/a/status/<pid> lets the demo user click through to see the actual post. Sanity check by eye: does the model's top recommendation make sense given the user's history?

End of run_pipeline.py.


run_ranker.py (121 lines) — standalone ranker demo

The standalone ranker demo that predates run_pipeline.py. Now superseded by the pipeline, but kept for reference.

def main():
    # Model configuration
    emb_size = 128  # Embedding dimension
    num_actions = len(ACTIONS)  # Number of explicit engagement actions
    history_seq_len = 32  # Max history length
    candidate_seq_len = 8  # Max candidates to rank

    # Hash configuration
    hash_config = HashConfig(
        num_user_hashes=2,
        num_item_hashes=2,
        num_author_hashes=2,
    )

    recsys_model = PhoenixModelConfig(
        emb_size=emb_size,
        num_actions=num_actions,
        history_seq_len=history_seq_len,
        candidate_seq_len=candidate_seq_len,
        hash_config=hash_config,
        product_surface_vocab_size=16,
        model=TransformerConfig(
            emb_size=emb_size,
            widening_factor=2,
            key_size=64,
            num_q_heads=2,
            num_kv_heads=2,
            num_layers=2,
            attn_output_multiplier=0.125,
        ),
    )

    # Create inference runner
    inference_runner = RecsysInferenceRunner(
        runner=ModelRunner(
            model=recsys_model,
            bs_per_device=0.125,
        ),
        name="recsys_local",
    )

Build a tiny model: emb_size=128, 32-item history, 8 candidates, 2 transformer layers, 2 attention heads. Designed to run on a laptop.

bs_per_device=0.125 → batch_size=1 even with 1 GPU. Saves memory.

    print("Initializing model...")
    inference_runner.initialize()
    print("Model initialized!")

    # ...

    example_batch, example_embeddings = create_example_batch(
        batch_size=batch_size,
        emb_size=emb_size,
        history_len=history_seq_len,
        # ...
    )

    # Rank candidates
    ranking_output = inference_runner.rank(example_batch, example_embeddings)

    # Display results
    scores = np.array(ranking_output.scores[0])  # [num_candidates, num_actions]
    ranked_indices = np.array(ranking_output.ranked_indices[0])

    # ...
    for rank, idx in enumerate(ranked_indices):
        # ...
        for action_idx, action_name in enumerate(action_names):
            prob = float(scores[idx, action_idx])
            bar = "█" * int(prob * 20) + "░" * (20 - int(prob * 20))
            print(f"    {action_name:24s}: {bar} {prob:.3f}")

Generates random fake data + runs the model + displays ASCII bar charts of each action probability per candidate. Educational demo.

The model is randomly initialized, so the probabilities carry no real signal (with small initial logits they tend to cluster around 0.5). To get meaningful predictions, load real params via inference_runner.runner.load_or_init(..., checkpoint_path=...).

End of run_ranker.py.


run_retrieval.py (149 lines) — standalone retrieval demo

Mirror of run_ranker.py for retrieval.

def main():
    # Model configuration - same architecture as Phoenix ranker
    emb_size = 128
    num_actions = len(ACTIONS)
    history_seq_len = 32
    candidate_seq_len = 8

    hash_config = HashConfig(
        num_user_hashes=2,
        num_item_hashes=2,
        num_author_hashes=2,
    )

    # Configure the retrieval model - uses same transformer as Phoenix
    retrieval_model_config = PhoenixRetrievalModelConfig(
        emb_size=emb_size,
        history_seq_len=history_seq_len,
        candidate_seq_len=candidate_seq_len,
        hash_config=hash_config,
        product_surface_vocab_size=16,
        model=TransformerConfig(
            emb_size=emb_size,
            widening_factor=2,
            key_size=64,
            num_q_heads=2,
            num_kv_heads=2,
            num_layers=2,
            attn_output_multiplier=0.125,
        ),
    )

    inference_runner = RecsysRetrievalInferenceRunner(
        runner=RetrievalModelRunner(
            model=retrieval_model_config,
            bs_per_device=0.125,
        ),
        name="retrieval_local",
    )

Same tiny-model config. Different runner type for retrieval.

    print("Initializing retrieval model...")
    inference_runner.initialize()

    # ...
    batch_size = 2  # Two users for demo
    example_batch, example_embeddings = create_example_batch(
        batch_size=batch_size,
        # ...
    )

    # Step 1: Create a corpus of candidate posts
    corpus_size = 1000  # Simulated corpus of 1000 posts
    corpus_embeddings, corpus_post_ids = create_example_corpus(
        corpus_size=corpus_size,
        emb_size=emb_size,
        seed=456,
    )
    print(f"Corpus size: {corpus_size} posts")

    inference_runner.set_corpus(corpus_embeddings, corpus_post_ids)

    # Step 2: Retrieve top-k candidates for each user
    top_k = 10
    retrieval_output = inference_runner.retrieve(
        example_batch,
        example_embeddings,
        top_k=top_k,
    )

The two-step pipeline: corpus precomputation (here just random) → retrieval call. Two users × 10 results each.

    for user_idx in range(batch_size):
        print(f"\n  User {user_idx + 1}:")
        print(f"    {'Rank':<6} {'Post ID':<12} {'Score':<12}")
        print(f"    {'-' * 30}")
        for rank in range(top_k):
            post_id = top_k_indices[user_idx, rank]
            score = top_k_scores[user_idx, rank]
            bar = "█" * int((score + 1) * 10) + "░" * (20 - int((score + 1) * 10))
            print(f"    {rank + 1:<6} {post_id:<12} {bar} {score:.4f}")

ASCII bar chart of similarity scores. The (score + 1) * 10 mapping: scores are in [-1, 1] (cosine), mapped to [0, 20] for the bar.
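The same mapping as a standalone helper (similarity_bar is a hypothetical name; the clamp is an addition that keeps the bar at a fixed width if a score falls outside [-1, 1], e.g. because vectors are not exactly normalized):

```python
def similarity_bar(score, width=20):
    """Map a cosine score in [-1, 1] to a fixed-width bar of filled/empty cells."""
    filled = int((score + 1) * width / 2)  # width=20 gives the script's (score + 1) * 10
    filled = max(0, min(width, filled))    # clamp out-of-range scores
    return "█" * filled + "░" * (width - filled)


print(similarity_bar(0.0))  # ten filled cells, then ten empty ones
```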

End of run_retrieval.py.


What we've learned

Checkpoint format: .npz files containing parameters keyed by slash-separated paths (module/submodule/param_name). Loaded via load_model_params which reconstructs the nested Haiku dict.
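A sketch of the unflattening step, assuming each key is module_path/param_name with the parameter name after the last slash. unflatten_params is a hypothetical helper, not the repo's load_model_params, and it assumes every key contains at least one "/":

```python
import numpy as np


def unflatten_params(flat):
    """Group 'module/submodule/param' keys into a {module: {param: array}} mapping."""
    params = {}
    for key, value in flat.items():
        module, _, name = key.rpartition("/")  # split at the last slash
        params.setdefault(module, {})[name] = np.asarray(value)
    return params


# np.load on an .npz returns a mapping-like NpzFile, so it can be passed in directly:
# params = unflatten_params(np.load("model_params.npz"))
```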

Split storage: model params + embedding tables are separate .npz files. Lets large embedding tables stay outside the param tree (different lifecycle, different storage).

Unified embedding table layout:

[0..65)            — padding / system reserved
[65..65+UV)        — user embeddings
[65+UV..65+UV+IV)  — item embeddings
[65+UV+IV..)       — author embeddings

Three hash functions, each with an offset, hash IDs into their respective regions.

Hash function: linear congruential (id * scale + bias) % modulus. Multiple (scale, bias) pairs per entity = multiple hashes (num_*_hashes). The same hash function is used in training and inference — embeddings learned for hash bucket h are the embeddings used at runtime for that hash bucket.

Dummy-init pattern: Haiku needs an example input to infer parameter shapes. All-zeros tensors of the right shape work fine for init; the actual values are then replaced via load_model_params.

The lru_cache of model(): Haiku functions track parameters by name. Caching the model instance ensures consistent naming across the init + apply phases.

Parameter discovery via use: hk.transform only collects parameters from functions it calls. RetrievalModelRunner.make_forward_fn calls build_candidate_representation even though it doesn't need the result — to make sure the candidate-tower params are registered in the parameter tree. Without this, loading trained checkpoints would fail.

Three apply functions for retrieval:

  • encode_user_fn — user vector only.
  • encode_candidates_fn — candidate vectors only (used to precompute the corpus offline).
  • retrieve_fn — full top-K retrieval.

All share the same trained params, just expose different slices of the model.

Two hash-function systems: retrieval and ranking models each build their own hash functions from their own config (vocab sizes, scale-bias pairs). Whenever those differ, a given tweet ID hashes to different buckets in each. Each model's embedding table is keyed by its own hash space.

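Chunked ranking: the ranker has a fixed candidate_seq_len. To rank 200 candidates with cand_len=32, process ceil(200 / 32) = 7 chunks: six full ones, and a last one holding 8 real candidates plus 24 padding slots. Padding the input and slicing the padded rows off the output is a common pattern for running fixed-shape models on variable-size inputs.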

run_pipeline.py is the major addition this release — a runnable end-to-end demo. Before this release, retrieval and ranking were separately runnable but no glue tied them together. The new script:

  1. Loads both models from checkpoints.
  2. Builds the user representation via retrieval transformer.
  3. Top-K against a pre-computed corpus.
  4. Ranks the top-K via the ranker.
  5. Combines per-action probabilities into a weighted score.
  6. Prints a formatted table.

This is a public, runnable approximation of how production composes the two stages; the weighted-score formula is a demo stand-in for the production weights.

Action vocabulary mismatch: there are two taxonomies:

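  • Raw event indices (the keys of the history actions map, which run_pipeline.py's IDX_FAV = 1 etc. are named after): encode "what the user did in their history."
  • Output action indices (runners.py's 0-indexed ACTIONS list): encode "what the model predicts about the candidate."

The two orderings differ (index 1 is SERVER_TWEET_FAV in one and reply_score in the other), so which one a given checkpoint's output heads follow is model-specific and worth checking before reading per-action columns.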

Forward function as parameter shape-matcher: run_pipeline.py's ret_forward does a bunch of work it doesn't need (combines negatives, reads log_temperature) just so the parameter tree matches what was trained. Ugly but unavoidable given the open release's checkpoint structure.


Next session

Session 17 — Phoenix grok + tests (1,342 LOC). The actual transformer implementation + the test suite:

  • phoenix/grok.py (616) — the Grok-1 transformer adapted for recsys (attention, RoPE, layer norm, candidate isolation mask)
  • phoenix/test_recsys_model.py (309) — ranking model unit tests
  • phoenix/test_recsys_retrieval_model.py (417) — retrieval model unit tests

After Session 17, we leave Phoenix and enter the Grox content-classification pipeline (Sessions 18-22).