X For You algorithm, line by line — Part 16: Phoenix runners + end-to-end pipeline
Part 16 of the deep dive into xai-org/x-algorithm. The Python runner infrastructure: ModelRunner / RetrievalModelRunner with Haiku transform setup, checkpoint loading from .npz, the unified embedding table layout, three apply functions for retrieval, and run_pipeline.py — the headline release addition that runs retrieval → ranking from exported checkpoints.
Session 15 covered the model architecture. This session covers the runner code that loads checkpoints, builds inference functions, encodes inputs, and drives the model end-to-end. It also covers the new run_pipeline.py (added in this release), which ties retrieval + ranking together as a runnable example.
Files covered (~1,470 LOC):
phoenix/
├── runners.py (807) ModelRunner, RecsysInferenceRunner, RetrievalModelRunner, etc.
├── run_pipeline.py (393) end-to-end retrieval → ranking from exported checkpoints
├── run_ranker.py (121) standalone ranker demo
└── run_retrieval.py (149) standalone retrieval demo
run_pipeline.py is the headline change in this release per the README:
A new phoenix/run_pipeline.py replaces the separate run_ranker.py and run_retrieval.py scripts with a single entry point that runs retrieval → ranking from exported checkpoints, mirroring how the two stages are composed in production.
runners.py (807 lines) — the inference machinery
The big infrastructure file. Splits into five chunks:
- Checkpoint loading + dummy data helpers.
- Action / continuous-action constants.
- ModelRunner + RecsysInferenceRunner for ranking.
- RetrievalModelRunner + RecsysRetrievalInferenceRunner for retrieval.
- Example-data factories for testing.
Imports
import functools
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, NamedTuple, Optional, Tuple
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from grok import TrainingState
from recsys_retrieval_model import PhoenixRetrievalModelConfig
from recsys_retrieval_model import RetrievalOutput as ModelRetrievalOutput
from recsys_model import (
PhoenixModelConfig,
RecsysBatch,
RecsysEmbeddings,
RecsysModelOutput,
)
rank_logger = logging.getLogger("rank")
TrainingState from grok — the standard JAX/Haiku pattern: a (params, …) container. Here only params is used since we're doing inference, not training.
Both models import their config + IO types from the model files (Session 15). RetrievalOutput is renamed to ModelRetrievalOutput because this file has another RetrievalOutput named tuple lower down (the inference-runner wrapper).
Checkpoint loading
def load_model_params(checkpoint_path: str) -> hk.Params:
"""Load model parameters from an exported checkpoint.
Args:
checkpoint_path: Path to model_params.npz file
Returns:
Haiku params dict (nested FrozenDict)
"""
data = np.load(checkpoint_path, allow_pickle=True)
params: dict = {}
for key in data.files:
parts = key.split("/")
module_path = "/".join(parts[:-1])
param_name = parts[-1]
params.setdefault(module_path, {})[param_name] = jnp.array(data[key])
return hk.data_structures.to_haiku_dict(params)
The checkpoint format: an .npz (numpy zipped) file where each key is a slash-separated path like phoenix_model/transformer/layer_0/attention/query/w. The last segment is the param name; everything before is the module path.
Reconstruct the nested dict structure that Haiku expects:
{
"phoenix_model/transformer/layer_0/attention/query": {"w": <array>, "b": <array>}
}
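hk.data_structures.to_haiku_dict converts the Python dict into the same two-level mapping type that Haiku's init and apply functions return, so loaded params are interchangeable with freshly initialized ones. The output goes into TrainingState(params=...).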
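A NumPy-only sketch of the key-grouping step, for poking at the format without JAX or Haiku installed. It writes a tiny in-memory .npz with made-up keys (not real Phoenix parameter names), then groups them the way load_model_params does, minus the jnp.array and to_haiku_dict conversions. The helper name group_flat_params is ours, not the repo's.

```python
import io

import numpy as np


def group_flat_params(npz) -> dict:
    """Group 'module/path/param' keys into {module_path: {param_name: array}}.

    Mirrors the loop in load_model_params, without the jnp/Haiku conversions.
    """
    params: dict = {}
    for key in npz.files:
        # Split on the last slash only: everything before it is the module path.
        module_path, _, param_name = key.rpartition("/")
        params.setdefault(module_path, {})[param_name] = npz[key]
    return params


# Build a tiny in-memory .npz with hypothetical slash-separated keys.
buf = io.BytesIO()
np.savez(
    buf,
    **{
        "phoenix_model/linear/w": np.ones((4, 2), dtype=np.float32),
        "phoenix_model/linear/b": np.zeros((2,), dtype=np.float32),
    },
)
buf.seek(0)

params = group_flat_params(np.load(buf))
print(sorted(params))                          # ['phoenix_model/linear']
print(sorted(params["phoenix_model/linear"]))  # ['b', 'w']
```

str.rpartition("/") is equivalent to the split/join pair in the original: a key with no slash lands under the empty module path "" in both versions.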
def load_embedding_table(path: str) -> np.ndarray:
"""Load an embedding table from an exported checkpoint.
Args:
path: Path to embedding_tables.npz file
Returns:
Dict with 'user_embeddings', 'item_embeddings', 'author_embeddings' arrays
"""
return dict(np.load(path))
Embedding tables are separate from model params. Why? Because:
- Embedding tables are huge (millions of rows × emb_size).
- In production they sit on parameter servers, not on the inference accelerator.
- For the open-source release, they're shipped as separate .npz files so you can use them without loading the full distributed system.
The return type annotation says np.ndarray but the body returns a dict — minor type hint bug.
Dummy data factories
def create_dummy_batch_from_config(
hash_config: Any,
history_len: int,
num_candidates: int,
num_actions: int,
batch_size: int = 1,
) -> RecsysBatch:
"""Create a dummy batch for initialization."""
return RecsysBatch(
user_hashes=np.zeros((batch_size, hash_config.num_user_hashes), dtype=np.int32),
history_post_hashes=np.zeros(
(batch_size, history_len, hash_config.num_item_hashes), dtype=np.int32
),
# ...
candidate_product_surface=np.zeros((batch_size, num_candidates), dtype=np.int32),
)
All-zeros batch. Used only for shape inference when initializing the Haiku transform. Haiku needs a sample input to figure out parameter shapes; the values don't matter.
def create_dummy_embeddings_from_config(
hash_config: Any,
emb_size: int,
history_len: int,
num_candidates: int,
batch_size: int = 1,
) -> RecsysEmbeddings:
"""Create dummy embeddings for initialization."""
return RecsysEmbeddings(
user_embeddings=np.zeros(...),
# ...
)
Same idea for embeddings: zeros of the right shape. Init pass produces the parameter tree; the values are then overwritten by load_model_params.
BaseModelRunner — abstract base
@dataclass
class BaseModelRunner(ABC):
"""Base class for model runners with shared initialization logic."""
bs_per_device: float = 2.0
rng_seed: int = 42
@property
@abstractmethod
def model(self) -> Any:
"""Return the model config."""
pass
@property
def _model_name(self) -> str:
"""Return model name for logging."""
return "model"
@abstractmethod
def make_forward_fn(self):
"""Create the forward function. Must be implemented by subclasses."""
pass
def initialize(self):
"""Initialize the model runner."""
self.model.initialize()
self.model.fprop_dtype = jnp.bfloat16
num_local_gpus = len(jax.local_devices())
self.batch_size = max(1, int(self.bs_per_device * num_local_gpus))
rank_logger.info(f"Initializing {self._model_name}...")
self.forward = self.make_forward_fn()
The abstract runner. Two abstract methods (model property, make_forward_fn) and a concrete initialize.
bs_per_device: float = 2.0 allows fractional per-device batch sizes: bs_per_device=0.5 with 4 devices gives batch_size=2. The max(1, ...) guarantees a batch size of at least 1.
num_local_gpus = len(jax.local_devices()) — works on TPU pods too (returns the local devices, which are the slice this process controls).
Note: initialize forces bfloat16 for the forward pass. That's a standard inference choice: half the memory of float32, with the same exponent range but fewer mantissa bits.
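The batch-size arithmetic is a one-liner, but int() truncates rather than rounds, which is easy to miss. Restated as a standalone function (local_batch_size is a hypothetical name, not from the repo), with the device count passed in instead of read from jax.local_devices():

```python
def local_batch_size(bs_per_device: float, num_local_devices: int) -> int:
    """Batch-size arithmetic from BaseModelRunner.initialize, as a standalone function."""
    # int() truncates toward zero; max(1, ...) stops the result from hitting 0.
    return max(1, int(bs_per_device * num_local_devices))


print(local_batch_size(2.0, 8))  # 16: the default, on 8 devices
print(local_batch_size(0.5, 4))  # 2: fractional per-device batches are allowed
print(local_batch_size(0.5, 3))  # 1: 1.5 truncates to 1, it is not rounded to 2
print(local_batch_size(0.2, 1))  # 1: 0.2 truncates to 0, then is clamped up to 1
```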
BaseInferenceRunner — abstract inference wrapper
@dataclass
class BaseInferenceRunner(ABC):
"""Base class for inference runners with shared dummy data creation."""
name: str
@property
@abstractmethod
def runner(self) -> BaseModelRunner:
"""Return the underlying model runner."""
pass
def _get_num_actions(self) -> int:
"""Get number of actions. Override in subclasses if needed."""
model_config = self.runner.model
if hasattr(model_config, "num_actions"):
return model_config.num_actions
return 19
def create_dummy_batch(self, batch_size: int = 1) -> RecsysBatch:
# ...
def create_dummy_embeddings(self, batch_size: int = 1) -> RecsysEmbeddings:
# ...
Helper layer on top of BaseModelRunner. Encapsulates dummy data creation so subclasses don't repeat it.
_get_num_actions falls back to a hard-coded 19 (the length of ACTIONS below) when the model config has no num_actions attribute. The retrieval model config doesn't have one, so it takes this fallback.
The action list
ACTIONS: List[str] = [
"favorite_score",
"reply_score",
"repost_score",
"photo_expand_score",
"click_score",
"profile_click_score",
"vqv_score",
"share_score",
"share_via_dm_score",
"share_via_copy_link_score",
"dwell_score",
"quote_score",
"quoted_click_score",
"follow_author_score",
"not_interested_score",
"block_author_score",
"mute_author_score",
"report_score",
"dwell_time",
]
CONTINUOUS_ACTIONS: List[str] = [
"reserved",
"dwell_time",
"video_watch_time",
"scroll_depth",
"reserved_3",
"reserved_4",
"reserved_5",
"reserved_6",
]
NEGATIVE_FEEDBACK_INDICES: List[int] = [
14,
15,
16,
17,
]
The action vocabulary — 19 discrete actions matching the Phoenix model's output dimensions.
Compare to the PhoenixScores fields we saw in Sessions 06 and 10: this list is the same set, in order. Index 0 = favorite, 1 = reply, ..., 18 = dwell_time. The Rust side (PhoenixScorer etc.) reads scores by these indices.
CONTINUOUS_ACTIONS (8 items, with several reserved_* slots) is what the continuous head predicts. Notable: index 0 is "reserved" — the continuous head produces 8 outputs but only 3 are currently meaningful (dwell_time, video_watch_time, scroll_depth). The slots are reserved for future features.
NEGATIVE_FEEDBACK_INDICES = [14, 15, 16, 17] — not_interested, block_author, mute_author, report. The negative-feedback signals. These are the ones with negative weights in the ranking formula (Session 10).
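A sketch of how a caller might read named columns out of a [batch, candidates, num_actions] probability array using these constants. The ACTION_INDEX lookup and the random stand-in array are illustrative assumptions; only ACTIONS and NEGATIVE_FEEDBACK_INDICES are copied from runners.py.

```python
import numpy as np

# Copied from runners.py.
ACTIONS = [
    "favorite_score", "reply_score", "repost_score", "photo_expand_score",
    "click_score", "profile_click_score", "vqv_score", "share_score",
    "share_via_dm_score", "share_via_copy_link_score", "dwell_score",
    "quote_score", "quoted_click_score", "follow_author_score",
    "not_interested_score", "block_author_score", "mute_author_score",
    "report_score", "dwell_time",
]
NEGATIVE_FEEDBACK_INDICES = [14, 15, 16, 17]

# Hypothetical name -> column lookup, so callers avoid hard-coded magic indices.
ACTION_INDEX = {name: i for i, name in enumerate(ACTIONS)}

# Random stand-in for the model's sigmoid outputs: [batch, candidates, num_actions].
probs = np.random.default_rng(0).random((2, 5, len(ACTIONS)))

p_favorite = probs[:, :, ACTION_INDEX["favorite_score"]]  # [2, 5]
p_negative = probs[:, :, NEGATIVE_FEEDBACK_INDICES]        # [2, 5, 4]

print([ACTIONS[i] for i in NEGATIVE_FEEDBACK_INDICES])
# ['not_interested_score', 'block_author_score', 'mute_author_score', 'report_score']
```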
RankingOutput
class RankingOutput(NamedTuple):
"""Output from ranking candidates.
Contains both the raw scores array and individual probability fields
for each engagement type.
"""
scores: jax.Array
ranked_indices: jax.Array
p_favorite_score: jax.Array
p_reply_score: jax.Array
p_repost_score: jax.Array
# ... 19 fields total
p_dwell_time: jax.Array
continuous_preds: Optional[jax.Array] = None
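A NamedTuple with both the full scores array [B, C, num_actions] (sigmoid probabilities rather than logits, despite the docstring's "raw scores") and 19 individual fields named p_<action>, one per ACTIONS entry (e.g. p_favorite_score, p_dwell_time), each of shape [B, C]. Plus ranked_indices: [B, C] (the candidate order by score) and continuous_preds.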
Why expose both forms? Convenience — Python callers can do output.p_favorite_score instead of slicing into the array. Also makes the named-tuple structure mirror the proto wire format on the Rust side.
ModelRunner — ranking specialization
@dataclass
class ModelRunner(BaseModelRunner):
"""Runner for the recommendation ranking model."""
_model: PhoenixModelConfig = None # type: ignore
def __init__(self, model: PhoenixModelConfig, bs_per_device: float = 2.0, rng_seed: int = 42):
self._model = model
self.bs_per_device = bs_per_device
self.rng_seed = rng_seed
@property
def model(self) -> PhoenixModelConfig:
return self._model
@property
def _model_name(self) -> str:
return "ranking model"
def make_forward_fn(self): # type: ignore
def forward(batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings):
out = self.model.make()(batch, recsys_embeddings)
return out
return hk.transform(forward)
The concrete ranking runner. hk.transform(forward) is Haiku's way of turning a function that uses hk.Modules into a (init, apply) pair. After this:
- forward.init(rng, batch, emb) returns the parameters.
- forward.apply(params, rng, batch, emb) runs the model.
make_forward_fn is the abstract method from BaseModelRunner — different for each model type.
def init(
self, rng: jax.Array, data: RecsysBatch, embeddings: RecsysEmbeddings
) -> TrainingState:
assert self.forward is not None
rng, init_rng = jax.random.split(rng)
params = self.forward.init(init_rng, data, embeddings)
return TrainingState(params=params)
def load_or_init(
self,
init_data: RecsysBatch,
init_embeddings: RecsysEmbeddings,
checkpoint_path: Optional[str] = None,
):
if checkpoint_path is not None:
params = load_model_params(checkpoint_path)
return TrainingState(params=params)
rng = jax.random.PRNGKey(self.rng_seed)
state = self.init(rng, init_data, init_embeddings)
return state
Two paths:
- No checkpoint: random init via hk.transform.init.
- With checkpoint: load params from disk.
In production: always load. In demos / test: random init.
RecsysInferenceRunner — ranking wrapper
@dataclass
class RecsysInferenceRunner(BaseInferenceRunner):
"""Inference runner for the recommendation ranking model."""
_runner: ModelRunner
def __init__(self, runner: ModelRunner, name: str):
self.name = name
self._runner = runner
@property
def runner(self) -> ModelRunner:
return self._runner
def initialize(self):
"""Initialize the inference runner."""
runner = self.runner
dummy_batch = self.create_dummy_batch(batch_size=1)
dummy_embeddings = self.create_dummy_embeddings(batch_size=1)
runner.initialize()
state = runner.load_or_init(dummy_batch, dummy_embeddings)
self.params = state.params
@functools.lru_cache
def model():
return runner.model.make()
def hk_forward(
batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
) -> RecsysModelOutput:
return model()(batch, recsys_embeddings)
The actual user-facing runner. initialize does the following:
- Build a dummy batch + embeddings.
- Call runner.initialize() (the concrete method inherited from BaseModelRunner).
- Call runner.load_or_init(...) to get params (random or from a checkpoint; as called here, with no checkpoint_path, it takes the random-init path).
- Cache the model instance via functools.lru_cache.
The @functools.lru_cache is interesting: model() is a zero-argument closure over runner, defined locally inside initialize. The cache means runner.model.make() runs only once and every later call gets the same model instance, which matters for Haiku's parameter-naming consistency.
def hk_rank_candidates(
batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
) -> RankingOutput:
"""Rank candidates by their predicted engagement scores."""
output = hk_forward(batch, recsys_embeddings)
logits = output.logits
probs = jax.nn.sigmoid(logits)
primary_scores = probs[:, :, 0]
ranked_indices = jnp.argsort(-primary_scores, axis=-1)
return RankingOutput(
scores=probs,
ranked_indices=ranked_indices,
p_favorite_score=probs[:, :, 0],
p_reply_score=probs[:, :, 1],
p_repost_score=probs[:, :, 2],
p_photo_expand_score=probs[:, :, 3],
p_click_score=probs[:, :, 4],
p_profile_click_score=probs[:, :, 5],
p_vqv_score=probs[:, :, 6],
p_share_score=probs[:, :, 7],
p_share_via_dm_score=probs[:, :, 8],
p_share_via_copy_link_score=probs[:, :, 9],
p_dwell_score=probs[:, :, 10],
p_quote_score=probs[:, :, 11],
p_quoted_click_score=probs[:, :, 12],
p_follow_author_score=probs[:, :, 13],
p_not_interested_score=probs[:, :, 14],
p_block_author_score=probs[:, :, 15],
p_mute_author_score=probs[:, :, 16],
p_report_score=probs[:, :, 17],
p_dwell_time=probs[:, :, 18],
continuous_preds=output.continuous_preds,
)
rank_ = hk.without_apply_rng(hk.transform(hk_rank_candidates))
self.rank_candidates = rank_.apply
Build the rank function:
- Forward pass → logits.
- Sigmoid to get per-action probabilities.
- Sort by index 0 (favorite_score) for ranked_indices. So argsort(-x) gives descending order, highest favorite_score first.
- Slice the probs by action index, populate the named-tuple fields.
hk.without_apply_rng(hk.transform(...)) — the alternate transform that returns an apply function NOT requiring an RNG argument. For deterministic inference (no dropout, etc.), drop the RNG ceremony.
self.rank_candidates = rank_.apply — stash the apply function as an attribute. Will be called as self.rank_candidates(params, batch, embeddings).
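The post-processing in hk_rank_candidates can be restated in NumPy to see exactly what ranked_indices contains. This sketch (rank_by_favorite is a hypothetical name) takes logits directly instead of running the transformer:

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def rank_by_favorite(logits: np.ndarray):
    """NumPy restatement of hk_rank_candidates' post-processing.

    logits: [B, C, num_actions]. Returns (probs, ranked_indices).
    """
    probs = sigmoid(logits)                                # per-action probabilities
    ranked_indices = np.argsort(-probs[:, :, 0], axis=-1)  # favorite_score, descending
    return probs, ranked_indices


logits = np.zeros((1, 4, 19), dtype=np.float32)
logits[0, :, 0] = [0.1, 2.0, -1.0, 0.5]  # favorite logits for 4 candidates
probs, ranked = rank_by_favorite(logits)
print(ranked)  # [[1 3 0 2]]
```

Because the sigmoid is monotonic, sorting the probabilities gives the same order as sorting the logits, apart from ties once very large logits saturate to 1.0 in floating point. The sigmoid matters for the per-action p_* fields, not for the ordering.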
def rank(self, batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings) -> RankingOutput:
"""Rank candidates for the given batch.
...
"""
return self.rank_candidates(self.params, batch, recsys_embeddings)
The user-facing entry point.
create_example_batch — synthetic data generator
def create_example_batch(
batch_size: int,
emb_size: int,
history_len: int,
num_candidates: int,
num_actions: int,
num_user_hashes: int = 2,
num_item_hashes: int = 2,
num_author_hashes: int = 2,
product_surface_vocab_size: int = 16,
num_user_embeddings: int = 1_000_000,
num_post_embeddings: int = 1_000_000,
num_author_embeddings: int = 1_000_000,
include_continuous_actions: bool = False,
include_timestamps: bool = False,
num_continuous_actions: int = 8,
) -> Tuple[RecsysBatch, RecsysEmbeddings]:
"""Create an example batch with random data for testing.
This simulates a recommendation scenario where:
- We have a user with some embedding
- The user has interacted with some posts in their history
- We want to rank a set of candidate posts
...
"""
rng = np.random.default_rng(42)
user_hashes = rng.integers(1, num_user_embeddings, size=(batch_size, num_user_hashes)).astype(
np.int32
)
history_post_hashes = rng.integers(
1, num_post_embeddings, size=(batch_size, history_len, num_item_hashes)
).astype(np.int32)
for b in range(batch_size):
valid_len = rng.integers(history_len // 2, history_len + 1)
history_post_hashes[b, valid_len:, :] = 0
Generate fake but realistic-shaped data for tests / demos. Hashes in [1, num_embeddings) — note 1, not 0: hash 0 is reserved for padding (we saw this in Session 15 — user_hashes[:, 0] != 0 is the validity check).
The per-batch loop:
for b in range(batch_size):
valid_len = rng.integers(history_len // 2, history_len + 1)
history_post_hashes[b, valid_len:, :] = 0
Right-pads each user's history randomly: every row keeps a random valid length between history_len // 2 and history_len inclusive (e.g., 64 to 128 items when history_len = 128), and the valid_len: slice fills the rest with zeros (= padding). This simulates real users with varying activity levels.
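Padding survives downstream because validity is recovered from the hashes themselves (0 = padding). A small sketch of the same right-padding scheme plus the mask recovery; right_pad_random is a hypothetical helper, and it draws all valid lengths in one call rather than per row, so its random values differ from create_example_batch even with the same seed:

```python
import numpy as np


def right_pad_random(hashes: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Zero the tail of each row, keeping history_len // 2 .. history_len items.

    Same scheme as the loop in create_example_batch; returns the valid lengths.
    """
    batch_size, history_len = hashes.shape[:2]
    valid_lens = rng.integers(history_len // 2, history_len + 1, size=batch_size)
    for b, valid_len in enumerate(valid_lens):
        hashes[b, valid_len:, :] = 0
    return valid_lens


rng = np.random.default_rng(42)
history_len = 8
hashes = rng.integers(1, 1_000_000, size=(3, history_len, 2)).astype(np.int32)
valid_lens = right_pad_random(hashes, rng)

# Validity comes from the hashes alone: a non-zero first hash means a real item.
mask = hashes[:, :, 0] != 0            # [batch, history_len] bool
print(valid_lens, mask.sum(axis=1))    # the two arrays match
```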
history_author_hashes = rng.integers(
1, num_author_embeddings, size=(batch_size, history_len, num_author_hashes)
).astype(np.int32)
for b in range(batch_size):
valid_len = rng.integers(history_len // 2, history_len + 1)
history_author_hashes[b, valid_len:, :] = 0
history_actions = (rng.random(size=(batch_size, history_len, num_actions)) > 0.7).astype(
np.float32
)
Author hashes follow the same pattern.
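Actions: each (history item, action) flag is independently 1 with probability 0.3 (rng.random(...) > 0.7). With 19 independent flags, almost every history item ends up with at least one positive action (the chance of none is 0.7^19 ≈ 0.1%), so this is placeholder data rather than a model of real engagement rates.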
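A quick empirical check of how dense these synthetic action flags are, using the same > 0.7 construction with hypothetical sizes (64 users × 128 history items × 19 actions):

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, history_len, num_actions = 64, 128, 19
# Same construction as history_actions in create_example_batch.
history_actions = (rng.random(size=(batch_size, history_len, num_actions)) > 0.7).astype(
    np.float32
)

positive_rate = history_actions.mean()                       # close to 0.3
no_action_frac = (history_actions.sum(axis=-1) == 0).mean()  # close to 0.7 ** 19
print(f"{positive_rate:.3f} {no_action_frac:.4f} {0.7 ** 19:.4f}")
```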
history_product_surface = rng.integers(
0, product_surface_vocab_size, size=(batch_size, history_len)
).astype(np.int32)
candidate_post_hashes = rng.integers(
1, num_post_embeddings, size=(batch_size, num_candidates, num_item_hashes)
).astype(np.int32)
candidate_author_hashes = rng.integers(
1, num_author_embeddings, size=(batch_size, num_candidates, num_author_hashes)
).astype(np.int32)
candidate_product_surface = rng.integers(
0, product_surface_vocab_size, size=(batch_size, num_candidates)
).astype(np.int32)
Standard random integers for everything else.
history_continuous_actions = None
if include_continuous_actions:
history_continuous_actions = np.zeros(
(batch_size, history_len, num_continuous_actions), dtype=np.float32
)
history_continuous_actions[:, :, 1] = rng.exponential(
scale=10.0, size=(batch_size, history_len)
).astype(np.float32)
Optional continuous actions: only index 1 (dwell_time) is populated, drawn from an exponential distribution with scale (mean) 10. That gives the right-skewed shape of real dwell-time data (most dwells are short, a few are long), though strictly speaking the exponential distribution is light-tailed rather than heavy-tailed.
candidate_impr_ts = None
candidate_post_creation_ts = None
if include_timestamps:
base_ts = 1700000000
candidate_impr_ts = np.full((batch_size, num_candidates), base_ts, dtype=np.int32)
age_seconds = rng.integers(60, 72 * 3600, size=(batch_size, num_candidates))
candidate_post_creation_ts = (candidate_impr_ts - age_seconds).astype(np.int32)
Optional timestamps. base_ts = 1700000000 ≈ Nov 2023. Candidate ages randomly from 1 minute to 72 hours.
batch = RecsysBatch(
user_hashes=user_hashes,
history_post_hashes=history_post_hashes,
# ...
candidate_post_creation_ts=candidate_post_creation_ts,
)
embeddings = RecsysEmbeddings(
user_embeddings=rng.normal(size=(batch_size, num_user_hashes, emb_size)).astype(np.float32),
history_post_embeddings=rng.normal(
size=(batch_size, history_len, num_item_hashes, emb_size)
).astype(np.float32),
# ...
)
return batch, embeddings
Random normal embeddings. Real embeddings would be the learned ones — but for demos, random Gaussian is fine.
RetrievalOutput (inference wrapper)
class RetrievalOutput(NamedTuple):
"""Output from retrieval inference.
Contains user representations and retrieved candidates.
"""
user_representation: jax.Array
top_k_indices: jax.Array
top_k_scores: jax.Array
This is the public retrieval output (the inference runner returns this). Identical shape to ModelRetrievalOutput from Session 15 — but kept as a separate class so the boundary between model code and runner code is clean.
RetrievalModelRunner
@dataclass
class RetrievalModelRunner(BaseModelRunner):
"""Runner for the Phoenix retrieval model."""
_model: PhoenixRetrievalModelConfig = None # type: ignore
def __init__(
self,
model: PhoenixRetrievalModelConfig,
bs_per_device: float = 2.0,
rng_seed: int = 42,
):
self._model = model
self.bs_per_device = bs_per_device
self.rng_seed = rng_seed
@property
def model(self) -> PhoenixRetrievalModelConfig:
return self._model
@property
def _model_name(self) -> str:
return "retrieval model"
Same pattern as ModelRunner. Different config type, different name.
def make_forward_fn(self): # type: ignore
def forward(
batch: RecsysBatch,
recsys_embeddings: RecsysEmbeddings,
corpus_embeddings: jax.Array,
top_k: int,
) -> ModelRetrievalOutput:
model = self.model.make()
out = model(batch, recsys_embeddings, corpus_embeddings, top_k)
_ = model.build_candidate_representation(batch, recsys_embeddings)
return out
return hk.transform(forward)
The forward function. Note the trailing call to build_candidate_representation with its result discarded (_ = ...).
Why? Because the CandidateTower parameters only get registered with Haiku if the function is called during init. The main model(batch, embeddings, corpus, top_k) only uses the user tower; the candidate tower is used separately (precomputed offline). To make sure the candidate-tower params are part of the initialized parameter tree, call it during forward — even if we throw the result away.
This is a Haiku-ism: parameters are discovered via use. The init pass walks the forward function once and collects every hk.get_parameter it encounters. Anything not called doesn't get a param.
def init(
self,
rng: jax.Array,
data: RecsysBatch,
embeddings: RecsysEmbeddings,
corpus_embeddings: jax.Array,
top_k: int,
) -> TrainingState:
assert self.forward is not None
rng, init_rng = jax.random.split(rng)
params = self.forward.init(init_rng, data, embeddings, corpus_embeddings, top_k)
return TrainingState(params=params)
Init signature includes corpus + top_k as required args.
RecsysRetrievalInferenceRunner
@dataclass
class RecsysRetrievalInferenceRunner(BaseInferenceRunner):
"""Inference runner for the Phoenix retrieval model.
This runner provides methods for:
1. Encoding users to get user representations
2. Encoding candidates to get candidate embeddings
3. Retrieving top-k candidates from a corpus
"""
_runner: RetrievalModelRunner = None # type: ignore
corpus_embeddings: jax.Array | None = None
corpus_post_ids: jax.Array | None = None
def __init__(self, runner: RetrievalModelRunner, name: str):
self.name = name
self._runner = runner
self.corpus_embeddings = None
self.corpus_post_ids = None
Holds the corpus as instance state. The corpus is set once via set_corpus and reused across many retrieve calls.
def initialize(self):
"""Initialize the retrieval inference runner."""
runner = self.runner
dummy_batch = self.create_dummy_batch(batch_size=1)
dummy_embeddings = self.create_dummy_embeddings(batch_size=1)
dummy_corpus = jnp.zeros((10, runner.model.emb_size), dtype=jnp.float32)
dummy_top_k = 5
runner.initialize()
state = runner.load_or_init(dummy_batch, dummy_embeddings, dummy_corpus, dummy_top_k)
self.params = state.params
Same dummy-init pattern. Note dummy corpus size 10 and top_k 5 — small enough to init quickly. After init, the corpus gets replaced via set_corpus.
@functools.lru_cache
def model():
return runner.model.make()
def hk_encode_user(batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings) -> jax.Array:
"""Encode user to get user representation."""
m = model()
user_rep, _ = m.build_user_representation(batch, recsys_embeddings)
return user_rep
def hk_encode_candidates(
batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
) -> jax.Array:
"""Encode candidates to get candidate representations."""
m = model()
cand_rep, _ = m.build_candidate_representation(batch, recsys_embeddings)
return cand_rep
def hk_retrieve(
batch: RecsysBatch,
recsys_embeddings: RecsysEmbeddings,
corpus_embeddings: jax.Array,
top_k: int,
) -> "RetrievalOutput":
"""Retrieve top-k candidates from corpus."""
m = model()
return m(batch, recsys_embeddings, corpus_embeddings, top_k)
Three separate functions for the three use cases:
- Encode user only — for online retrieval (build user vector, do ANN search externally).
- Encode candidates only — for batch corpus precomputation (offline pipeline that fills the corpus).
- Full retrieval — encode user + dot-product with corpus + top-K.
Each gets its own Haiku transform. Each shares the same params (same trained model) but exposes a different slice of its functionality.
encode_user_ = hk.without_apply_rng(hk.transform(hk_encode_user))
encode_candidates_ = hk.without_apply_rng(hk.transform(hk_encode_candidates))
retrieve_ = hk.without_apply_rng(hk.transform(hk_retrieve))
self.encode_user_fn = encode_user_.apply
self.encode_candidates_fn = encode_candidates_.apply
self.retrieve_fn = retrieve_.apply
Cache the apply functions.
def encode_user(self, batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings) -> jax.Array:
# ...
def encode_candidates(
self, batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings
) -> jax.Array:
# ...
def set_corpus(
self,
corpus_embeddings: jax.Array,
corpus_post_ids: jax.Array,
):
"""Set the corpus embeddings for retrieval."""
self.corpus_embeddings = corpus_embeddings
self.corpus_post_ids = corpus_post_ids
def retrieve(
self,
batch: RecsysBatch,
recsys_embeddings: RecsysEmbeddings,
top_k: int = 100,
corpus_embeddings: Optional[jax.Array] = None,
) -> RetrievalOutput:
"""Retrieve top-k candidates for users."""
if corpus_embeddings is None:
corpus_embeddings = self.corpus_embeddings
return self.retrieve_fn(self.params, batch, recsys_embeddings, corpus_embeddings, top_k)
Three public methods plus set_corpus. Note the default top_k = 100 for retrieve. And the optional corpus_embeddings override — pass in a different corpus per-call (e.g. for slicing by language or geography).
create_example_corpus
def create_example_corpus(
corpus_size: int,
emb_size: int,
seed: int = 123,
) -> Tuple[jax.Array, jax.Array]:
"""Create example corpus embeddings for testing retrieval.
...
"""
rng = np.random.default_rng(seed)
corpus_embeddings = rng.normal(size=(corpus_size, emb_size)).astype(np.float32)
norms = np.linalg.norm(corpus_embeddings, axis=-1, keepdims=True)
corpus_embeddings = corpus_embeddings / np.maximum(norms, 1e-12)
corpus_post_ids = np.arange(corpus_size, dtype=np.int64)
return jnp.array(corpus_embeddings), jnp.array(corpus_post_ids)
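Random Gaussian corpus, L2-normalized (np.maximum(norms, 1e-12) guards against division by zero). This matches the model's expectation: corpus embeddings are pre-normalized, so when the user representation is unit-length too, the dot product equals cosine similarity.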
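A brute-force NumPy stand-in for the retrieval step (dot product against the corpus, then top-K), assuming an L2-normalized query. It is not the model's own code, which Session 15 covers; top_k_by_dot and l2_normalize are hypothetical helper names. np.argpartition selects the K best in roughly linear time, and only those K are then sorted:

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Row-wise L2 normalization, guarded like create_example_corpus."""
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)


def top_k_by_dot(user_rep: np.ndarray, corpus: np.ndarray, k: int):
    """Brute-force dot-product top-k; returns (indices, scores), best first."""
    scores = user_rep @ corpus.T                            # [B, N]
    cand = np.argpartition(-scores, k - 1, axis=-1)[:, :k]  # top-k, unordered
    cand_scores = np.take_along_axis(scores, cand, axis=-1)
    order = np.argsort(-cand_scores, axis=-1)               # sort only the k winners
    return (
        np.take_along_axis(cand, order, axis=-1),
        np.take_along_axis(cand_scores, order, axis=-1),
    )


rng = np.random.default_rng(123)
corpus = l2_normalize(rng.normal(size=(1_000, 16)).astype(np.float32))
user = corpus[[7]]  # a unit-length query identical to corpus row 7
idx, scores = top_k_by_dot(user, corpus, k=5)
print(idx[0, 0], round(float(scores[0, 0]), 4))  # 7 1.0
```

For unit vectors the best possible dot product is 1, reached only by the query's own direction, which is why row 7 comes out on top.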
End of runners.py.
run_pipeline.py (393 lines) — the headline addition
The single-entry-point demo that runs retrieval → ranking from exported checkpoints. The README highlight of this release.
Imports + constants
"""End-to-end retrieval + ranking pipeline from exported artifacts.
Loads exported checkpoints, a pre-computed corpus, and a user action
sequence, then runs:
1. Retrieval: encode user history → dot product with corpus → top-K
2. Ranking: score top-K candidates with per-action engagement model
Usage:
python run_pipeline.py \
--artifacts_dir ./artifacts \
--sequence_file ./artifacts/example_sequence.json \
--corpus_file ./artifacts/sports_corpus.npz \
--top_k_retrieval 200 \
--top_k_display 30
"""
import argparse
import json
import logging
import os
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from grok import TransformerConfig
from recsys_model import (
HashConfig,
PhoenixModelConfig,
RecsysBatch,
RecsysEmbeddings,
)
from recsys_retrieval_model import PhoenixRetrievalModelConfig
from runners import load_embedding_table, load_model_params
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
logging.getLogger("jax").setLevel(logging.WARNING)
log = logging.getLogger(__name__)
IDX_FAV = 1 # SERVER_TWEET_FAV
IDX_REPLY = 4 # SERVER_TWEET_REPLY
IDX_QUOTE = 5 # SERVER_TWEET_QUOTE
IDX_RT = 6 # SERVER_TWEET_RETWEET
IDX_DWELL = 11 # CLIENT_TWEET_RECAP_DWELLED
IDX_VQV = 13 # CLIENT_TWEET_VIDEO_QUALITY_VIEW
Notice the index constants — these are different from the ACTIONS list in runners.py. These reference indices into the history actions vector (the user-action sequence input), NOT the model output. The action sequence has its own taxonomy (SERVER_TWEET_FAV, CLIENT_TWEET_RECAP_DWELLED, etc.).
So:
- Input actions (history): 19+ types indexed by these constants.
- Output predictions: 19 actions indexed by ACTIONS.
Different vocabularies because input includes server-side events (favs, retweets) and client-side events (dwell, video view), while output is what the model predicts the user will do.
Hash functions
def _hash_ids(ids, scales, biases, modulus, num_buckets):
"""Hash IDs using the same linear congruential hash as training.
Uses numpy int64 arithmetic (wrapping overflow) to match training.
"""
ids = np.asarray(ids, dtype=np.int64).ravel()
scales = np.array(scales, dtype=np.int64)
biases = np.array(biases, dtype=np.int64)
n, m = len(ids), len(scales)
out = np.empty((n, m), dtype=np.int32)
for i in range(n):
for j in range(m):
raw = (ids[i] * scales[j] + biases[j]) % np.int64(modulus)
out[i, j] = 0 if ids[i] == 0 else int((int(raw) % (num_buckets - 1)) + 1)
return out
Linear congruential hashing: (id * scale + bias) % modulus. Then map to [1, num_buckets - 1] (skipping 0 which is padding).
Multiple hashes per ID = multiple scales/biases. So with num_user_hashes=2, there are 2 scale-bias pairs and each ID produces 2 hashes.
The double-loop is slow Python — should be vectorized with numpy, but for a small number of IDs (e.g., 200 candidates × 2 hashes), it doesn't matter.
out[i, j] = 0 if ids[i] == 0 else int((int(raw) % (num_buckets - 1)) + 1):
- If the original ID is 0 (padding), keep the hash as 0.
- Otherwise, take raw mod (num_buckets - 1) + 1 to ensure the hash is in [1, num_buckets - 1].
This is the exact same hash function the training pipeline uses, so checkpoint params match the hash buckets they were trained with.
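The double loop can be vectorized with NumPy broadcasting. A sketch (hash_ids_vectorized is our name, not the repo's) that should match _hash_ids element for element, assuming int64 array arithmetic wraps on overflow the same way the scalar loop does:

```python
import numpy as np


def hash_ids_vectorized(ids, scales, biases, modulus, num_buckets):
    """Broadcast version of _hash_ids: returns int32 hashes of shape [n, m]."""
    ids = np.asarray(ids, dtype=np.int64).ravel()
    scales = np.asarray(scales, dtype=np.int64)
    biases = np.asarray(biases, dtype=np.int64)
    # [n, 1] * [m] -> [n, m]; int64 arithmetic wraps on overflow, like the scalar loop.
    raw = (ids[:, None] * scales + biases) % np.int64(modulus)
    # Shift into [1, num_buckets - 1] so 0 stays free for padding.
    hashed = raw % (num_buckets - 1) + 1
    # Padding IDs (0) hash to 0.
    return np.where(ids[:, None] == 0, 0, hashed).astype(np.int32)


print(hash_ids_vectorized([0, 5], scales=[3, 7], biases=[1, 2], modulus=101, num_buckets=10))
# rows: [0 0] for the padding ID, [8 2] for ID 5
```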
def build_hash_functions(config):
"""Build hash functions from published config."""
hp = config["hash_params"]
pad = 65
uv = config["user_vocab_size"]
iv = config["item_vocab_size"]
av = config["author_vocab_size"]
def hash_user(user_ids):
h = _hash_ids(user_ids, hp["user_hash_scales"], hp["user_biases"], hp["user_modulus"], uv)
return np.where(h == 0, 0, h + pad)
def hash_item(item_ids):
h = _hash_ids(item_ids, hp["item_hash_scales"], hp["item_biases"], hp["item_modulus"], iv)
return np.where(h == 0, 0, h + pad + uv)
def hash_author(author_ids):
h = _hash_ids(
author_ids, hp["author_hash_scales"], hp["author_biases"], hp["author_modulus"], av
)
return np.where(h == 0, 0, h + pad + uv + iv)
return hash_user, hash_item, hash_author
Build the three hash functions with separate scale/bias config per entity type. Plus a vocabulary offset:
- Users: indices [pad, pad + uv).
- Items: indices [pad + uv, pad + uv + iv).
- Authors: indices [pad + uv + iv, pad + uv + iv + av).
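pad = 65: the first 65 slots are reserved. Slot 0 is padding (every hash function maps ID 0 to index 0); the remaining reserved slots are presumably for special or system-internal tokens, and build_unified_emb_table leaves all 65 rows as zeros.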
So the single unified embedding table has slots laid out:
[0..pad) — reserved (padding etc.)
[pad..pad+uv) — user embeddings
[pad+uv..pad+uv+iv) — item embeddings
[pad+uv+iv..pad+uv+iv+av) — author embeddings
Each hash function shifts its raw hash by the appropriate offset.
Build unified embedding table
def build_unified_emb_table(emb_dict, config):
"""Reconstruct monolithic embedding table from split tables."""
emb_size = config["emb_size"]
uv = config["user_vocab_size"]
iv = config["item_vocab_size"]
av = config["author_vocab_size"]
pad = 65
table = np.zeros((pad + uv + iv + av, emb_size), dtype=np.float32)
table[pad : pad + uv] = emb_dict["user_embeddings"]
table[pad + uv : pad + uv + iv] = emb_dict["item_embeddings"]
table[pad + uv + iv : pad + uv + iv + av] = emb_dict["author_embeddings"]
return table
Concatenate the three split tables (user / item / author) into one monolithic table indexed by the unified hash space. The first 65 rows stay all-zeros (padding embeddings).
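A quick way to check that the table layout and the hash offsets agree: rebuild a table from toy split tables and confirm that offset indices land on the right source rows. The function below restates build_unified_emb_table with pad as a parameter; the sizes and random tables are made up for illustration.

```python
import numpy as np


def build_unified(emb_dict, uv, iv, av, emb_size, pad=65):
    """Same layout as build_unified_emb_table: reserved rows stay zero."""
    table = np.zeros((pad + uv + iv + av, emb_size), dtype=np.float32)
    table[pad : pad + uv] = emb_dict["user_embeddings"]
    table[pad + uv : pad + uv + iv] = emb_dict["item_embeddings"]
    table[pad + uv + iv : pad + uv + iv + av] = emb_dict["author_embeddings"]
    return table


rng = np.random.default_rng(0)
uv, iv, av, d, pad = 4, 6, 3, 8, 65
split = {
    "user_embeddings": rng.normal(size=(uv, d)).astype(np.float32),
    "item_embeddings": rng.normal(size=(iv, d)).astype(np.float32),
    "author_embeddings": rng.normal(size=(av, d)).astype(np.float32),
}
table = build_unified(split, uv, iv, av, d, pad)

# An item whose raw hash is h (h != 0) is looked up at row h + pad + uv.
h = 2
assert np.array_equal(table[h + pad + uv], split["item_embeddings"][h])
# The reserved rows are all zeros, so padding slots look up a zero vector.
assert not table[:pad].any()
print(table.shape)  # (78, 8)
```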
Build model config from JSON
def build_model_config(config, config_class):
"""Build a model config from the published JSON config."""
kwargs = dict(
emb_size=config["emb_size"],
history_seq_len=config["history_seq_len"],
candidate_seq_len=config["candidate_seq_len"],
hash_config=HashConfig(
num_user_hashes=config["num_user_hashes"],
num_item_hashes=config["num_item_hashes"],
num_author_hashes=config["num_author_hashes"],
),
product_surface_vocab_size=config.get("product_surface_vocab_size", 16),
model=TransformerConfig(
emb_size=config["emb_size"],
key_size=config["key_size"],
num_q_heads=config["num_heads"],
num_kv_heads=config["num_heads"],
num_layers=config["num_layers"],
widening_factor=2.0,
attn_output_multiplier=0.125,
),
)
if config_class == PhoenixModelConfig:
kwargs["num_actions"] = config["num_actions"]
kwargs["post_age_granularity_mins"] = config.get("post_age_granularity_mins", 60)
elif config_class == PhoenixRetrievalModelConfig:
kwargs["enable_linear_proj"] = True
mc = config_class(**kwargs)
mc.initialize()
return mc
Reconstruct the Python dataclass config from the JSON file shipped with the checkpoint. Generic across both model types — branches at the end on config_class.
Notable: num_q_heads = num_kv_heads = num_heads from JSON. So this config doesn't use grouped-query attention (where K/V has fewer heads than Q). The Grok-1 model uses GQA but the recsys variant doesn't.
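attn_output_multiplier=0.125: despite the name, in the Grok-1 code this factor multiplies the attention logits (query·key) before the softmax, not the attention output. 0.125 = 1/√64, i.e. the standard 1/√d_k scaling if the published key_size is 64 (the value the demo configs below use).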
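A numpy sketch of where such a multiplier sits, assuming the Grok-1 convention of scaling the query·key logits before the softmax. This is an illustration, not Phoenix's grok.py (Session 17 covers that file); it omits everything except the scaled softmax.

```python
import numpy as np


def attention_weights(q, k, multiplier):
    """Softmax over keys of (q @ k.T) * multiplier.

    q: [T, key_size], k: [S, key_size]. Returns [T, S]; each row sums to 1.
    """
    logits = (q @ k.T) * multiplier
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(logits)
    return w / w.sum(axis=-1, keepdims=True)


key_size = 64
assert np.isclose(0.125, 1.0 / np.sqrt(key_size))  # 1/sqrt(d_k)

rng = np.random.default_rng(0)
q = rng.normal(size=(3, key_size))
k = rng.normal(size=(5, key_size))
w = attention_weights(q, k, multiplier=0.125)
print(w.shape)  # (3, 5)
```

Without the 1/√d_k factor, dot products of random key_size-dimensional vectors grow with √key_size and push the softmax toward one-hot; the multiplier keeps the logits on a scale the softmax can spread over.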
main
def main():
parser = argparse.ArgumentParser(description="Run retrieval + ranking pipeline")
parser.add_argument("--artifacts_dir", default="./artifacts")
parser.add_argument(
"--sequence_file",
default=None,
help="Path to user action sequence JSON. Default: artifacts_dir/example_sequence.json",
)
parser.add_argument(
"--corpus_file",
default=None,
help="Path to corpus NPZ. Default: artifacts_dir/sports_corpus.npz",
)
parser.add_argument("--top_k_retrieval", type=int, default=200)
parser.add_argument("--top_k_display", type=int, default=30)
args = parser.parse_args()
CLI args. Default top_k_retrieval=200 — retrieve 200 from the corpus. top_k_display=30 — show 30 in the output table.
artifacts = args.artifacts_dir
seq_file = args.sequence_file or os.path.join(artifacts, "example_sequence.json")
corpus_file = args.corpus_file or os.path.join(artifacts, "sports_corpus.npz")
with open(os.path.join(artifacts, "retrieval", "config.json")) as f:
ret_cfg = json.load(f)
with open(os.path.join(artifacts, "ranker", "config.json")) as f:
rank_cfg = json.load(f)
Load both configs (retrieval + ranker). They're separate JSON files because they have different shapes.
emb_size = ret_cfg["emb_size"]
num_actions = ret_cfg["num_actions"]
hist_len = ret_cfg["history_seq_len"]
cand_len = ret_cfg["candidate_seq_len"]
log.info("Loading retrieval model...")
ret_params = load_model_params(os.path.join(artifacts, "retrieval", "model_params.npz"))
ret_emb = build_unified_emb_table(
load_embedding_table(os.path.join(artifacts, "retrieval", "embedding_tables.npz")),
ret_cfg,
)
log.info("Loading ranker model...")
rank_params = load_model_params(os.path.join(artifacts, "ranker", "model_params.npz"))
rank_emb = build_unified_emb_table(
load_embedding_table(os.path.join(artifacts, "ranker", "embedding_tables.npz")),
rank_cfg,
)
Load both models:
- Each has its own params + embedding tables. They are not shared — retrieval and ranking are independently trained.
- The retrieval model has fewer params (no candidate-tower transformer); the ranker has more.
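One thing worth guarding: emb_size, num_actions, hist_len and cand_len are read from ret_cfg only, yet the ranking pass later reuses hist_len, cand_len and num_actions to shape the ranker's inputs, while the ranker itself is built from rank_cfg. A small check like this hypothetical helper (not in the repo) would fail fast if the two exported configs ever disagreed:

```python
SHARED_KEYS = ("num_actions", "history_seq_len", "candidate_seq_len")


def check_shared_dims(ret_cfg, rank_cfg, keys=SHARED_KEYS):
    """Raise if the retrieval and ranker configs disagree on a shared dimension.

    Hypothetical helper: run_pipeline.py reads these values from ret_cfg only
    and reuses them when building the ranker's batches.
    """
    mismatches = {
        k: (ret_cfg.get(k), rank_cfg.get(k))
        for k in keys
        if ret_cfg.get(k) != rank_cfg.get(k)
    }
    if mismatches:
        raise ValueError(f"retrieval/ranker config mismatch: {mismatches}")


# Toy configs (values made up for illustration).
ret = {"num_actions": 8, "history_seq_len": 32, "candidate_seq_len": 8}
check_shared_dims(ret, dict(ret))  # identical configs pass silently
```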
log.info("Loading corpus...")
corpus = np.load(corpus_file, allow_pickle=True)
corpus_post_ids = corpus["post_ids"]
corpus_repr = corpus["candidate_representations"]
corpus_author_ids = corpus["author_ids"]
corpus_topics = corpus.get("topics", np.array([""] * len(corpus_post_ids)))
log.info(" %d posts, repr shape %s", len(corpus_post_ids), corpus_repr.shape)
Load the corpus. Keys: post_ids, candidate_representations, author_ids, topics (optional).
corpus_repr is precomputed with the retrieval model's candidate tower (an offline batch job). Shape [N, D]. At float32, a 1M-post corpus with D = 256 takes 1M × 256 × 4 bytes ≈ 1 GB.
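Since the corpus is an ordinary .npz, its key names and the optional-topics fallback can be exercised in memory. This sketch writes a tiny made-up corpus with the keys run_pipeline.py reads (leaving out topics) and loads it back the same way; the sizes and IDs are arbitrary.

```python
import io

import numpy as np

N, D = 5, 4
buf = io.BytesIO()
np.savez(
    buf,
    post_ids=np.arange(100, 100 + N, dtype=np.uint64),
    candidate_representations=np.ones((N, D), dtype=np.float32),
    author_ids=np.arange(200, 200 + N, dtype=np.uint64),
)
buf.seek(0)

corpus = np.load(buf, allow_pickle=True)
post_ids = corpus["post_ids"]
reprs = corpus["candidate_representations"]
# Same fallback as run_pipeline.py: one empty topic string per post when absent.
topics = corpus.get("topics", np.array([""] * len(post_ids)))

print(reprs.shape, reprs.nbytes)  # (5, 4) 80, i.e. N * D * 4 bytes for float32
```

The same N × D × 4 formula gives 1,024,000,000 bytes for 1M posts at D = 256.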
log.info("Loading user sequence from %s", seq_file)
with open(seq_file) as f:
seq = json.load(f)
user_id = seq["user_id"]
history = seq["history"]
log.info(" User %d, %d history items", user_id, len(history))
Load the user history. The format is a JSON with user_id + a list of history items each having post_id, author_id, actions: { "<idx>": <value> }.
hash_user, hash_item, hash_author = build_hash_functions(ret_cfg)
rank_hash_user, rank_hash_item, rank_hash_author = build_hash_functions(rank_cfg)
Two sets of hash functions, one per model (different vocab sizes / scales / biases). Important: a tweet ID hashed for retrieval generally lands in a different bucket than the same tweet ID hashed for ranking. Each model has its own learned representation per hash bucket.
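A toy illustration of why the two stages need separate hash functions: the same post ID pushed through two different settings lands in different rows. The multiply-add-mod form follows the summary later in this session ((id * scale + bias) % modulus); the extra fold into the vocab with a second modulo, the function name, and every constant are assumptions, not the repo's _hash_ids.

```python
import numpy as np


def toy_item_hash(ids, scales, biases, modulus, vocab, pad, uv):
    """Hypothetical multi-hash: one column per (scale, bias) pair.

    Mirrors hash_item's offsetting: raw 0 stays 0 (padding), others shift by pad + uv.
    Uses small IDs; real 64-bit tweet IDs would overflow this int64 arithmetic.
    """
    ids = np.asarray(ids, dtype=np.int64)[:, None]        # [N, 1]
    scales = np.asarray(scales, dtype=np.int64)[None, :]  # [1, H]
    biases = np.asarray(biases, dtype=np.int64)[None, :]  # [1, H]
    raw = ((ids * scales + biases) % modulus) % vocab     # [N, H]
    return np.where(raw == 0, 0, raw + pad + uv)


post_ids = [123456789, 987654321]
ret_rows = toy_item_hash(post_ids, [31, 37], [7, 11], 1_000_003, 1000, pad=65, uv=500)
rank_rows = toy_item_hash(post_ids, [41, 43], [3, 5], 999_983, 2000, pad=65, uv=800)
print(ret_rows)
print(rank_rows)  # different rows for the same IDs
```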
history_post_ids = np.zeros(hist_len, dtype=np.uint64)
history_author_ids = np.zeros(hist_len, dtype=np.uint64)
history_actions = np.zeros((hist_len, num_actions), dtype=np.float32)
for i, item in enumerate(history[:hist_len]):
history_post_ids[i] = item["post_id"]
history_author_ids[i] = item["author_id"]
for act_idx_str, act_val in item.get("actions", {}).items():
idx = int(act_idx_str)
if idx < num_actions:
history_actions[i, idx] = float(act_val)
Build the history arrays:
- Initialize to zeros (= padding).
- For up to hist_len items, populate post_id, author_id, actions.
- Actions are sparse: only present keys get set. The rest stay 0 (= action didn't happen).
The if idx < num_actions is defensive — protects against the JSON having an action index beyond what the model expects.
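Here is the same densification run on a small made-up record in the JSON shape described earlier (user_id, then history items with post_id, author_id and a sparse actions map keyed by stringified index). hist_len = 4 and num_actions = 3 are toy values; the out-of-range key "7" shows the defensive check at work.

```python
import json

import numpy as np

seq = json.loads("""
{"user_id": 42,
 "history": [
   {"post_id": 1001, "author_id": 7, "actions": {"0": 1, "2": 0.5, "7": 1}},
   {"post_id": 1002, "author_id": 8}
 ]}
""")

hist_len, num_actions = 4, 3
post_ids = np.zeros(hist_len, dtype=np.uint64)  # 0 = padding slot
author_ids = np.zeros(hist_len, dtype=np.uint64)
actions = np.zeros((hist_len, num_actions), dtype=np.float32)
for i, item in enumerate(seq["history"][:hist_len]):
    post_ids[i] = item["post_id"]
    author_ids[i] = item["author_id"]
    for idx_str, val in item.get("actions", {}).items():
        idx = int(idx_str)
        if idx < num_actions:  # silently drops action "7"
            actions[i, idx] = float(val)

print(post_ids)  # two real items, then two padding slots
print(actions)
```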
user_hashes = hash_user(np.array([user_id], dtype=np.uint64))
hist_post_h = hash_item(history_post_ids).reshape(1, hist_len, -1)
hist_author_h = hash_author(history_author_ids).reshape(1, hist_len, -1)
Hash everything. reshape(1, hist_len, -1) adds the batch dimension (1) and lets the last dim infer from num_hashes.
Retrieval pass
log.info("Running retrieval...")
ret_model_config = build_model_config(ret_cfg, PhoenixRetrievalModelConfig)
N_neg = 64
def ret_forward(batch, embeddings, gn_post, gn_auth):
model = ret_model_config.make()
user_repr, _ = model.build_user_representation(batch, embeddings)
combined_post = jnp.concatenate([embeddings.candidate_post_embeddings, gn_post], axis=1)
combined_auth = jnp.concatenate([embeddings.candidate_author_embeddings, gn_auth], axis=1)
combined_emb = RecsysEmbeddings(
user_embeddings=embeddings.user_embeddings,
history_post_embeddings=embeddings.history_post_embeddings,
candidate_post_embeddings=combined_post,
history_author_embeddings=embeddings.history_author_embeddings,
candidate_author_embeddings=combined_auth,
)
gn_ph = jnp.ones((1, N_neg, 2), dtype=jnp.int32)
gn_ah = jnp.ones((1, N_neg, 2), dtype=jnp.int32)
comb_ph = jnp.concatenate([batch.candidate_post_hashes, gn_ph], axis=1)
comb_ah = jnp.concatenate([batch.candidate_author_hashes, gn_ah], axis=1)
comb_ps = jnp.concatenate(
[batch.candidate_product_surface, jnp.zeros((1, N_neg), dtype=jnp.int32)], axis=1
)
comb_batch = batch._replace(
candidate_post_hashes=comb_ph,
candidate_author_hashes=comb_ah,
candidate_product_surface=comb_ps,
)
model.build_candidate_representation(comb_batch, combined_emb)
_ = hk.get_parameter("log_temperature", [], init=hk.initializers.Constant(0.0))
return user_repr
ret_fn = hk.without_apply_rng(hk.transform(ret_forward))
This is weird for an inference function. Let me unpack:
- N_neg = 64: number of "global negatives" used during training. They are appended along the candidate axis, so the training graph saw cand_len + N_neg candidate slots.
- The function builds combined embeddings (candidates + N_neg placeholder negatives) and calls build_candidate_representation on the combined batch.
- It also reads log_temperature (a learned scalar used in the InfoNCE loss).
None of this is needed for inference! The function only returns user_repr. But the candidate-tower call + log_temperature param read are needed for parameter loading:
When the model was trained, hk.transform recorded parameters from:
- build_user_representation (transformer + user reduce).
- build_candidate_representation on a [1, cand_len + N_neg, ...] shaped input.
- The log_temperature scalar.
If the inference function only called build_user_representation, its parameter tree would be a strict subset of the checkpoint's: the candidate-tower params and log_temperature would have no counterpart, so any loading code that expects the two trees to match would fail.
So this forward function is a shape-matching reconstruction of the training forward — runs everything for parameter registration but only returns what we need.
Ugly but pragmatic. A production export would presumably use a cleaner saved-model format.
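The mismatch is easy to see if you flatten both trees into sets of "module/param" keys. The module names below are invented for illustration (the real ones come from the Phoenix modules); the point is that a user-only forward yields a strict subset of what the checkpoint holds.

```python
def flat_keys(params):
    """Flatten a Haiku-style {module: {param: value}} dict into {"module/param"}."""
    return {f"{mod}/{name}" for mod, sub in params.items() for name in sub}


def diff_trees(expected, loaded):
    """Return (missing_from_loaded, unexpected_in_loaded) as sorted lists."""
    want, got = flat_keys(expected), flat_keys(loaded)
    return sorted(want - got), sorted(got - want)


# Hypothetical module names; values elided as None.
checkpoint = {
    "user_tower/transformer/layer_0": {"w": None},
    "candidate_tower/proj": {"w": None, "b": None},
    "~": {"log_temperature": None},
}
user_only_fn = {"user_tower/transformer/layer_0": {"w": None}}

missing, unexpected = diff_trees(user_only_fn, checkpoint)
print(missing)     # []
print(unexpected)  # ['candidate_tower/proj/b', 'candidate_tower/proj/w', '~/log_temperature']
```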
C = cand_len
batch = RecsysBatch(
user_hashes=jnp.asarray(user_hashes),
history_post_hashes=jnp.asarray(hist_post_h),
history_author_hashes=jnp.asarray(hist_author_h),
history_actions=jnp.asarray(history_actions.reshape(1, hist_len, num_actions)),
history_product_surface=jnp.zeros((1, hist_len), dtype=jnp.int32),
candidate_post_hashes=jnp.zeros((1, C, 2), dtype=jnp.int32),
candidate_author_hashes=jnp.zeros((1, C, 2), dtype=jnp.int32),
candidate_product_surface=jnp.zeros((1, C), dtype=jnp.int32),
)
emb_batch = RecsysEmbeddings(
user_embeddings=jnp.asarray(ret_emb[user_hashes]),
history_post_embeddings=jnp.asarray(ret_emb[hist_post_h]),
candidate_post_embeddings=jnp.zeros((1, C, 2, emb_size)),
history_author_embeddings=jnp.asarray(ret_emb[hist_author_h]),
candidate_author_embeddings=jnp.zeros((1, C, 2, emb_size)),
)
dummy_gn = jnp.zeros((1, N_neg, 2, emb_size))
user_repr = ret_fn.apply(ret_params, batch, emb_batch, dummy_gn, dummy_gn)
log.info(" User repr norm=%.4f", float(jnp.linalg.norm(user_repr)))
Look up embeddings from the unified table via numpy fancy indexing: ret_emb[user_hashes] returns shape [B, num_hashes, emb_size].
Candidate stuff is zeros (we don't have candidates yet — that's what we're retrieving). The dummy_gn provides the N_neg negatives.
Run the forward. Result: a [1, D] user vector.
user_repr norm log — should be ~1.0 since the model L2-normalizes. Sanity check.
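The shapes come straight from numpy's fancy indexing: indexing a [rows, emb_size] table with an integer array of shape S returns an array of shape S + [emb_size]. A toy check with made-up sizes:

```python
import numpy as np

rows, emb_size, hist_len, num_hashes = 100, 8, 5, 2
table = np.arange(rows * emb_size, dtype=np.float32).reshape(rows, emb_size)

user_hashes = np.array([[66, 67]])  # [B=1, num_hashes]
hist_post_h = np.arange(hist_len * num_hashes).reshape(1, hist_len, -1) + 80

user_emb = table[user_hashes]  # [1, num_hashes, emb_size]
hist_emb = table[hist_post_h]  # [1, hist_len, num_hashes, emb_size]
print(user_emb.shape, hist_emb.shape)  # (1, 2, 8) (1, 5, 2, 8)
```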
TOP_K = min(args.top_k_retrieval, len(corpus_post_ids))
scores = corpus_repr @ np.asarray(user_repr[0])
top_idx = np.argpartition(scores, -TOP_K)[-TOP_K:]
top_idx = top_idx[np.argsort(-scores[top_idx])]
ret_post_ids = corpus_post_ids[top_idx]
ret_author_ids = corpus_author_ids[top_idx]
ret_scores = scores[top_idx]
ret_topics = [str(corpus_topics[i]) for i in top_idx]
log.info(" Retrieved %d (score range: %.4f - %.4f)", TOP_K, ret_scores[-1], ret_scores[0])
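Manual top-K via numpy: argpartition finds the unsorted top K in O(N), then argsort orders just those K, so the step costs O(N + K log K) instead of the O(N log N) of a full sort.
corpus_repr @ user_repr[0] is [N, D] × [D] = [N]. This is dot-product similarity, which equals cosine similarity when both sides are L2-normalized (the user vector is; the corpus vectors presumably are too, since they come from the same model's candidate tower).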
Then look up post IDs, author IDs, topics for the top K.
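The argpartition-then-argsort pattern can be checked against a full sort. This sketch restates the retrieval step's two top-K lines as a function (top_k_desc is a made-up name) and adds a guard for k = 0, where a slice of [-0:] would otherwise return the whole array. With distinct scores it returns exactly the first k indices of a full descending sort.

```python
import numpy as np


def top_k_desc(scores, k):
    """Indices of the k largest scores, highest first.

    argpartition places the k largest in the last k slots (unordered) in O(N);
    argsort then orders just those k, for O(N + k log k) overall.
    """
    k = min(k, len(scores))
    if k <= 0:
        return np.empty(0, dtype=np.intp)
    idx = np.argpartition(scores, -k)[-k:]
    return idx[np.argsort(-scores[idx])]


rng = np.random.default_rng(0)
corpus = rng.normal(size=(1000, 16)).astype(np.float32)
corpus /= np.linalg.norm(corpus, axis=1, keepdims=True)  # unit rows
user = corpus[3] + 0.1 * rng.normal(size=16).astype(np.float32)
user /= np.linalg.norm(user)

scores = corpus @ user  # cosine similarities, shape [1000]
top = top_k_desc(scores, 10)
```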
Ranking pass
log.info("Ranking %d candidates...", TOP_K)
rank_model_config = build_model_config(rank_cfg, PhoenixModelConfig)
def rank_forward(b, e):
return rank_model_config.make()(b, e)
rank_fn = hk.without_apply_rng(hk.transform(rank_forward))
Build the ranker config + transform. Simpler than the retrieval forward — just call the model directly.
rank_user_h = rank_hash_user(np.array([user_id], dtype=np.uint64))
rank_hist_post_h = rank_hash_item(history_post_ids).reshape(1, hist_len, -1)
rank_hist_author_h = rank_hash_author(history_author_ids).reshape(1, hist_len, -1)
all_probs = []
for i in range(0, TOP_K, cand_len):
j = min(i + cand_len, TOP_K)
cs = j - i
cph = rank_hash_item(ret_post_ids[i:j]).reshape(1, cs, -1)
cah = rank_hash_author(ret_author_ids[i:j]).reshape(1, cs, -1)
if cs < cand_len:
cph = np.pad(cph, ((0, 0), (0, cand_len - cs), (0, 0)))
cah = np.pad(cah, ((0, 0), (0, cand_len - cs), (0, 0)))
rb = RecsysBatch(
user_hashes=jnp.asarray(rank_user_h),
history_post_hashes=jnp.asarray(rank_hist_post_h),
history_author_hashes=jnp.asarray(rank_hist_author_h),
history_actions=jnp.asarray(history_actions.reshape(1, hist_len, num_actions)),
history_product_surface=jnp.zeros((1, hist_len), dtype=jnp.int32),
candidate_post_hashes=jnp.asarray(cph),
candidate_author_hashes=jnp.asarray(cah),
candidate_product_surface=jnp.zeros((1, cand_len), dtype=jnp.int32),
)
re = RecsysEmbeddings(
user_embeddings=jnp.asarray(rank_emb[rank_user_h]),
history_post_embeddings=jnp.asarray(rank_emb[rank_hist_post_h]),
candidate_post_embeddings=jnp.asarray(rank_emb[cph]),
history_author_embeddings=jnp.asarray(rank_emb[rank_hist_author_h]),
candidate_author_embeddings=jnp.asarray(rank_emb[cah]),
)
out = rank_fn.apply(rank_params, rb, re)
probs = jax.nn.sigmoid(out.logits)
all_probs.append(np.asarray(probs[0, :cs, :]))
Batched ranking: process the 200 candidates in chunks of cand_len (e.g., 32). For each chunk:
- Hash post/author IDs (using the ranker's hash functions, not retrieval's).
- Pad the hashes if the last chunk is shorter than cand_len.
- Build the batch and look up embeddings from the ranker's unified table.
- Apply the model.
- Sigmoid the logits to get probabilities.
- Slice off the padding (probs[0, :cs, :]).
- Append to the list.
The model is fixed-shape (always cand_len candidates), so we pad and slice to handle the actual count.
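Stripped of the model, the chunking reduces to pad-then-slice. In this sketch fake_model is a made-up stand-in for rank_fn.apply (it just returns two values per slot), so you can check that 200 candidates in chunks of 32 take 7 calls and give back exactly 200 rows.

```python
import numpy as np


def rank_in_chunks(candidate_ids, cand_len, model):
    """Run a fixed-shape model over any number of candidates.

    Each call sees exactly cand_len slots; the tail chunk is zero-padded
    and the padded rows are sliced off before concatenating.
    """
    outputs, calls = [], 0
    for i in range(0, len(candidate_ids), cand_len):
        chunk = candidate_ids[i : i + cand_len]
        cs = len(chunk)
        if cs < cand_len:
            chunk = np.pad(chunk, (0, cand_len - cs))  # pad with 0 (padding id)
        out = model(chunk)  # [cand_len, num_actions]
        calls += 1
        outputs.append(out[:cs])  # drop padded rows
    return np.concatenate(outputs), calls


def fake_model(chunk):
    # Hypothetical stand-in for the ranker: two "action scores" per slot.
    assert chunk.shape == (32,)
    return np.stack([chunk / 1000.0, np.ones_like(chunk, dtype=float)], axis=1)


ids = np.arange(1, 201)  # 200 retrieved candidates
probs, calls = rank_in_chunks(ids, 32, fake_model)
print(calls, probs.shape)  # 7 (200, 2)
```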
all_probs = np.concatenate(all_probs)
weighted = (
all_probs[:, IDX_FAV] * 1.0
+ all_probs[:, IDX_REPLY] * 0.5
+ all_probs[:, IDX_RT] * 0.3
+ all_probs[:, IDX_DWELL] * 0.2
)
ranked = np.argsort(-weighted)
Simplified weighted score for the demo. Note that it indexes the ranker's output with IDX_FAV = 1, IDX_REPLY = 4, etc.: input-action indices, NOT positions in the ACTIONS output list (indices 0-18).
Read against ACTIONS, all_probs[:, IDX_FAV] (column 1) is reply_score rather than the fav probability, which would make this a bug in the demo.
Looking more carefully, though: the IDX_* constants reference SERVER_TWEET_FAV, SERVER_TWEET_REPLY, etc. from the input action taxonomy, and in some training schemas those indices happen to line up with the output heads. Which mapping the published checkpoint uses can't be settled from this file, so treat the per-column labels in the output table with caution until you've checked the model's actual action mapping.
Either way: the weighting [1.0, 0.5, 0.3, 0.2] for [fav, reply, retweet, dwell] is a quick demo formula, not production-tuned. The real production weights come from feature switches (Session 10).
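One way to sidestep the index ambiguity is to key the demo weights by action name and resolve names to columns from a single list describing the ranker's outputs. OUTPUT_ACTIONS below is a hypothetical, shortened list, not the published mapping; for the real checkpoint it would have to be replaced by the model's actual output order. The weights are the demo's [1.0, 0.5, 0.3, 0.2].

```python
import numpy as np

# Hypothetical output order of the ranker's heads (NOT the published mapping).
OUTPUT_ACTIONS = ["fav", "reply", "retweet", "dwell", "video_quality_view"]

DEMO_WEIGHTS = {"fav": 1.0, "reply": 0.5, "retweet": 0.3, "dwell": 0.2}


def weighted_score(probs, weights, output_actions):
    """probs: [num_candidates, num_actions] -> [num_candidates]."""
    col = {name: i for i, name in enumerate(output_actions)}
    missing = set(weights) - set(col)
    if missing:
        raise KeyError(f"weights reference unknown actions: {sorted(missing)}")
    w = np.zeros(len(output_actions))
    for name, weight in weights.items():
        w[col[name]] = weight
    return probs @ w


probs = np.array([[0.9, 0.1, 0.2, 0.5, 0.0],
                  [0.2, 0.8, 0.1, 0.5, 0.0]])
scores = weighted_score(probs, DEMO_WEIGHTS, OUTPUT_ACTIONS)
ranked = np.argsort(-scores)
```

A name that doesn't exist in the output list raises immediately instead of silently reading the wrong column.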
DISPLAY = min(args.top_k_display, TOP_K)
print("\n" + "=" * 120)
print(f"PIPELINE RESULTS — User {user_id}")
print(f"History: {len(history)} items | Corpus: {len(corpus_post_ids)} posts")
print(f"Retrieved top {TOP_K} → Ranked by engagement model")
print("=" * 120)
print(
f"{'Rank':<5} {'Score':<8} {'Ret':<7} {'Fav':<7} {'Reply':<7} "
f"{'RT':<7} {'Dwell':<7} {'VQV':<7} {'Topics':<30} Post URL"
)
print("-" * 120)
for rank, idx in enumerate(ranked[:DISPLAY]):
pid = int(ret_post_ids[idx])
print(
f"{rank+1:<5} {float(weighted[idx]):<8.4f} {float(ret_scores[idx]):<7.4f} "
f"{float(all_probs[idx, IDX_FAV]):<7.4f} {float(all_probs[idx, IDX_REPLY]):<7.4f} "
f"{float(all_probs[idx, IDX_RT]):<7.4f} {float(all_probs[idx, IDX_DWELL]):<7.4f} "
f"{float(all_probs[idx, IDX_VQV]):<7.4f} {ret_topics[idx][:28]:<30} "
f"https://x.com/a/status/{pid}"
)
print(
f"\nWeighted score range: "
f"[{float(weighted[ranked[-1]]):.4f}, {float(weighted[ranked[0]]):.4f}]"
)
print("=" * 120)
Print the results in a formatted table: rank, weighted score, retrieval score, individual action probabilities, topics, post URL.
The URL https://x.com/a/status/<pid> lets the demo user click through to see the actual post. Sanity check by eye: does the model's top recommendation make sense given the user's history?
End of run_pipeline.py.
run_ranker.py (121 lines) — standalone ranker demo
The pre-run_pipeline.py standalone ranker demo. Now superseded by the pipeline, but kept for reference.
def main():
# Model configuration
emb_size = 128 # Embedding dimension
num_actions = len(ACTIONS) # Number of explicit engagement actions
history_seq_len = 32 # Max history length
candidate_seq_len = 8 # Max candidates to rank
# Hash configuration
hash_config = HashConfig(
num_user_hashes=2,
num_item_hashes=2,
num_author_hashes=2,
)
recsys_model = PhoenixModelConfig(
emb_size=emb_size,
num_actions=num_actions,
history_seq_len=history_seq_len,
candidate_seq_len=candidate_seq_len,
hash_config=hash_config,
product_surface_vocab_size=16,
model=TransformerConfig(
emb_size=emb_size,
widening_factor=2,
key_size=64,
num_q_heads=2,
num_kv_heads=2,
num_layers=2,
attn_output_multiplier=0.125,
),
)
# Create inference runner
inference_runner = RecsysInferenceRunner(
runner=ModelRunner(
model=recsys_model,
bs_per_device=0.125,
),
name="recsys_local",
)
Build a tiny model: emb_size=128, 32-item history, 8 candidates, 2 transformer layers, 2 attention heads. Designed to run on a laptop.
bs_per_device=0.125 → batch_size=1 even with 1 GPU. Saves memory.
print("Initializing model...")
inference_runner.initialize()
print("Model initialized!")
# ...
example_batch, example_embeddings = create_example_batch(
batch_size=batch_size,
emb_size=emb_size,
history_len=history_seq_len,
# ...
)
# Rank candidates
ranking_output = inference_runner.rank(example_batch, example_embeddings)
# Display results
scores = np.array(ranking_output.scores[0]) # [num_candidates, num_actions]
ranked_indices = np.array(ranking_output.ranked_indices[0])
# ...
for rank, idx in enumerate(ranked_indices):
# ...
for action_idx, action_name in enumerate(action_names):
prob = float(scores[idx, action_idx])
bar = "█" * int(prob * 20) + "░" * (20 - int(prob * 20))
print(f" {action_name:24s}: {bar} {prob:.3f}")
Generates random fake data + runs the model + displays ASCII bar charts of each action probability per candidate. Educational demo.
The model is randomly initialized, so the probabilities are basically random ~0.5. To get meaningful predictions, load real params via inference_runner.runner.load_or_init(..., checkpoint_path=...).
End of run_ranker.py.
run_retrieval.py (149 lines) — standalone retrieval demo
Mirror of run_ranker.py for retrieval.
def main():
# Model configuration - same architecture as Phoenix ranker
emb_size = 128
num_actions = len(ACTIONS)
history_seq_len = 32
candidate_seq_len = 8
hash_config = HashConfig(
num_user_hashes=2,
num_item_hashes=2,
num_author_hashes=2,
)
# Configure the retrieval model - uses same transformer as Phoenix
retrieval_model_config = PhoenixRetrievalModelConfig(
emb_size=emb_size,
history_seq_len=history_seq_len,
candidate_seq_len=candidate_seq_len,
hash_config=hash_config,
product_surface_vocab_size=16,
model=TransformerConfig(
emb_size=emb_size,
widening_factor=2,
key_size=64,
num_q_heads=2,
num_kv_heads=2,
num_layers=2,
attn_output_multiplier=0.125,
),
)
inference_runner = RecsysRetrievalInferenceRunner(
runner=RetrievalModelRunner(
model=retrieval_model_config,
bs_per_device=0.125,
),
name="retrieval_local",
)
Same tiny-model config. Different runner type for retrieval.
print("Initializing retrieval model...")
inference_runner.initialize()
# ...
batch_size = 2 # Two users for demo
example_batch, example_embeddings = create_example_batch(
batch_size=batch_size,
# ...
)
# Step 1: Create a corpus of candidate posts
corpus_size = 1000 # Simulated corpus of 1000 posts
corpus_embeddings, corpus_post_ids = create_example_corpus(
corpus_size=corpus_size,
emb_size=emb_size,
seed=456,
)
print(f"Corpus size: {corpus_size} posts")
inference_runner.set_corpus(corpus_embeddings, corpus_post_ids)
# Step 2: Retrieve top-k candidates for each user
top_k = 10
retrieval_output = inference_runner.retrieve(
example_batch,
example_embeddings,
top_k=top_k,
)
The two-step pipeline: corpus precomputation (here just random) → retrieval call. Two users × 10 results each.
for user_idx in range(batch_size):
print(f"\n User {user_idx + 1}:")
print(f" {'Rank':<6} {'Post ID':<12} {'Score':<12}")
print(f" {'-' * 30}")
for rank in range(top_k):
post_id = top_k_indices[user_idx, rank]
score = top_k_scores[user_idx, rank]
bar = "█" * int((score + 1) * 10) + "░" * (20 - int((score + 1) * 10))
print(f" {rank + 1:<6} {post_id:<12} {bar} {score:.4f}")
ASCII bar chart of similarity scores. The (score + 1) * 10 mapping: scores are in [-1, 1] (cosine), mapped to [0, 20] for the bar.
End of run_retrieval.py.
What we've learned
Checkpoint format: .npz files containing parameters keyed by slash-separated paths (module/submodule/param_name). Loaded via load_model_params which reconstructs the nested Haiku dict.
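A minimal sketch of that round trip, assuming Haiku's two-level {module_path: {param_name: array}} layout and splitting each flat key on its last slash. The function names and module paths are hypothetical; this is not the repo's load_model_params.

```python
import io

import numpy as np


def save_params_npz(params, f):
    """Flatten {module_path: {name: array}} to "module_path/name" keys."""
    flat = {f"{mod}/{name}": arr for mod, sub in params.items() for name, arr in sub.items()}
    np.savez(f, **flat)


def load_params_npz(f):
    """Inverse of save_params_npz: split each key on its LAST slash."""
    params = {}
    with np.load(f) as data:
        for key in data.files:
            mod, name = key.rsplit("/", 1)
            params.setdefault(mod, {})[name] = data[key]
    return params


params = {
    "transformer/layer_0/linear": {"w": np.ones((4, 4), np.float32), "b": np.zeros(4, np.float32)},
    "~": {"log_temperature": np.array(0.0, np.float32)},
}
buf = io.BytesIO()
save_params_npz(params, buf)
buf.seek(0)
restored = load_params_npz(buf)
```

Splitting on the last slash matters because module paths themselves contain slashes; only the final segment is the parameter name.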
Split storage: model params + embedding tables are separate .npz files. Lets large embedding tables stay outside the param tree (different lifecycle, different storage).
Unified embedding table layout:
[0..65) — padding / system reserved
[65..65+UV) — user embeddings
[65+UV..65+UV+IV) — item embeddings
[65+UV+IV..) — author embeddings
Three hash functions, each with an offset, hash IDs into their respective regions.
Hash function: linear congruential (id * scale + bias) % modulus. Multiple (scale, bias) pairs per entity = multiple hashes (num_*_hashes). The same hash function is used in training and inference — embeddings learned for hash bucket h are the embeddings used at runtime for that hash bucket.
Dummy-init pattern: Haiku needs an example input to infer parameter shapes. All-zeros tensors of the right shape work fine for init; the actual values are then replaced via load_model_params.
The lru_cache of model(): Haiku functions track parameters by name. Caching the model instance ensures consistent naming across the init + apply phases.
Parameter discovery via use: hk.transform only collects parameters from functions it calls. RetrievalModelRunner.make_forward_fn calls build_candidate_representation even though it doesn't need the result — to make sure the candidate-tower params are registered in the parameter tree. Without this, loading trained checkpoints would fail.
Three apply functions for retrieval:
- encode_user_fn: user vector only.
- encode_candidates_fn: candidate vectors only (used to precompute the corpus offline).
- retrieve_fn: full top-K retrieval.
All share the same trained params, just expose different slices of the model.
Two hash-function systems: retrieval and ranking models have different hash functions because they have different vocab sizes / scale-bias pairs. A given tweet ID generally hashes to different buckets in each. Each model's embedding table is keyed by its own hash space.
Chunked ranking: the ranker has a fixed candidate_seq_len. To rank 200 candidates with cand_len=32, process in 7 chunks (the last one holds 8 real candidates plus 24 padding slots). The pad-then-slice pattern is common when running inference on fixed-shape models.
run_pipeline.py is the major addition this release — a runnable end-to-end demo. Before this release, retrieval and ranking were separately runnable but no glue tied them together. The new script:
- Loads both models from checkpoints.
- Builds the user representation via retrieval transformer.
- Top-K against a pre-computed corpus.
- Ranks the top-K via the ranker.
- Combines per-action probabilities into a weighted score.
- Prints a formatted table.
This is the public, runnable example mirroring how production composes retrieval and ranking.
Action vocabulary mismatch: there are two taxonomies:
- Input action indices (used in run_pipeline.py's IDX_FAV = 1 etc.) encode "what the user did in their history."
- Output action indices (used in runners.py's ACTIONS list, 0-indexed) encode "what the model predicts about the candidate."
Some indices coincidentally align; others don't. The mapping is model-specific.
Forward function as parameter shape-matcher: run_pipeline.py's ret_forward does a bunch of work it doesn't need (combines negatives, reads log_temperature) just so the parameter tree matches what was trained. Ugly but unavoidable given the open release's checkpoint structure.
Next session
Session 17 — Phoenix grok + tests (1,342 LOC). The actual transformer implementation + the test suite:
- phoenix/grok.py (616): the Grok-1 transformer adapted for recsys (attention, RoPE, layer norm, candidate isolation mask)
- phoenix/test_recsys_model.py (309): ranking model unit tests
- phoenix/test_recsys_retrieval_model.py (417): retrieval model unit tests
After Session 17, we leave Phoenix and enter the Grox content-classification pipeline (Sessions 18-22).