X For You algorithm, line by line — Part 19: Grox plans + data loaders
Part 19 — Grox's plan layer is a dependency-DAG executor: each plan declares tasks and a dependency map, asyncio futures wire them up, PlanMaster runs all 9 plans in parallel per task. Data loaders cover Kafka (streaming, with prefetch + thread-pool Thrift decode), Strato (on-demand RPC), and a separate-process ASR pipeline with ffmpeg + multimodal LLM.
Building on Session 18, this article covers two more Grox subsystems:
- grox/plans/ (462 LOC, 11 files): the dependency-DAG execution layer. A Plan declares a set of tasks (the imperative units of work) and a dependency graph between them; when Plan.execute(task) is called, it runs the DAG concurrently and produces a TaskResult. The PlanMaster runs all plans in parallel, each plan filtering itself based on the task's eligibility set.
- grox/data_loaders/ (819 LOC, 4 files): the message-queue abstraction and concrete loaders. MessageQueueLoader is the abstract interface; KafkaLoader and its four subclasses are the concrete consumers; TweetStratoLoader, UserStratoLoader, ReplyRankingScoreStratoLoader, etc. are the synchronous KV-store loaders for on-demand reads; ASRProcessor is a separate-process speech-recognition pipeline.
Total: 1,281 LOC.
There's also grox/data_loaders/data_types.py and grox/data_loaders/media_processor.py imported throughout but not present in the open-source dump. We'll note their interfaces where they appear.
plans/plan.py (106 lines)
The abstract Plan base class.
import time
import asyncio
import logging
import traceback
from abc import ABC
from functools import cache
from grox.lib.utils import camel_to_snake
from grox.tasks.task import Task, TaskResultCategory
from monitor.metrics import Metrics
from grox.schedules.types import TaskResult, TaskContext, TaskPayload, TaskEligibility
logger = logging.getLogger(__name__)
Task is the abstract base class for individual units of work — we'll see it in Session 21. TaskResultCategory is an enum with at least a SKIPPED variant (and presumably SUCCESS/ERROR).
Class structure
class Plan(ABC):
    TASKS: dict[str, type[Task]] = {}
    TASK_DEPENDENCIES: dict[str, set[str]] = {}
    REQUIRED_ELIGIBILITY: TaskEligibility

    def __init__(self):
        self.deps = set([d for deps in self.TASK_DEPENDENCIES.values() for d in deps])
        if any(t not in self.TASKS for t in self.deps) or any(
            t not in self.TASKS for t in self.TASK_DEPENDENCIES.keys()
        ):
            raise ValueError("Not every task in TASK_DEPENDENCIES is defined in TASKS")
Three class attributes that subclasses fill:
- TASKS: map of task name → Task class. The keys are arbitrary string identifiers, used only within this plan to refer to tasks for dependency wiring.
- TASK_DEPENDENCIES: map of task name → set of task names this task waits on.
- REQUIRED_ELIGIBILITY: the single eligibility that gates whether this plan runs at all.
The constructor flattens the dependency graph into self.deps (all task names that appear as a dependency) and validates: every task referenced (either as a key or as a value in the dependency map) must be declared in TASKS. This is the cheap up-front sanity check — it catches typos at process startup, not runtime.
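To make the up-front check concrete, here is a minimal sketch that reproduces only the constructor logic. MiniPlan, Fetch, Score and try_build are hypothetical stand-ins invented for illustration; they are not part of Grox.

```python
class MiniPlan:
    """Hypothetical stand-in for Plan: only the constructor check is reproduced."""

    TASKS: dict = {}
    TASK_DEPENDENCIES: dict = {}

    def __init__(self):
        # Every name that appears on the "waited on" side of the graph.
        self.deps = {d for deps in self.TASK_DEPENDENCIES.values() for d in deps}
        if any(t not in self.TASKS for t in self.deps) or any(
            t not in self.TASKS for t in self.TASK_DEPENDENCIES
        ):
            raise ValueError("Not every task in TASK_DEPENDENCIES is defined in TASKS")


class Fetch: ...
class Score: ...


class GoodPlan(MiniPlan):
    TASKS = {"fetch": Fetch, "score": Score}
    TASK_DEPENDENCIES = {"fetch": set(), "score": {"fetch"}}


class TypoPlan(MiniPlan):
    TASKS = {"fetch": Fetch, "score": Score}
    TASK_DEPENDENCIES = {"fetch": set(), "score": {"fetsh"}}  # misspelled dependency


def try_build(plan_cls) -> str:
    """Instantiate a plan and report whether validation accepted it."""
    try:
        plan_cls()
        return "ok"
    except ValueError as exc:
        return f"rejected: {exc}"
```

Calling try_build(TypoPlan) reports a rejection as soon as the plan object is constructed, long before any task payload is processed.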
Execute the plan
    async def execute(self, task: TaskPayload) -> TaskResult | None:
        if not self._eligible(task):
            return None
        Metrics.counter("plan.execute.count").add(
            1, attributes={"plan_name": self.get_name()}
        )
        logger.debug(f"Creating execution plan for graph: {self.TASK_DEPENDENCIES}")
        loop = asyncio.get_running_loop()
        dependencies = {task: loop.create_future() for task in self.deps}
        start = time.perf_counter()
        ctx = TaskContext(task)
If the eligibility doesn't match, return None immediately — the caller (PlanMaster) filters None results.
The crucial line is dependencies = {task: loop.create_future() for task in self.deps}. A future per dependency. Each task that produces a value for a dependency will call future.set_result(...) when done; each task that consumes a dependency will await on it. This is the classic Python pattern for a DAG executor — instead of writing your own topological sort, you let asyncio's scheduler do it for you. Tasks all start concurrently; they immediately await their predecessors' futures; the ones with no predecessors start work right away; the ones with predecessors block until they're set.
Note self.deps is only the dependee side. Looking back at the constructor: self.deps is the union of values in TASK_DEPENDENCIES. So we create futures for tasks that are waited on, not necessarily for every task. Tasks that produce a side effect but aren't depended on (e.g., a publish-Kafka task at a leaf) don't get a future.
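The future-per-dependency idea can be sketched in a few lines. This is an illustrative toy, not the Grox code: run_dag and the four node names are made up, and the "work" is a single asyncio.sleep(0).

```python
import asyncio


async def run_dag(deps: dict) -> list:
    """Start every node at once; each node waits on its predecessors' futures."""
    loop = asyncio.get_running_loop()
    # One future per node that somebody waits on (the analogue of self.deps).
    waited_on = {d for ds in deps.values() for d in ds}
    futures = {name: loop.create_future() for name in waited_on}
    order = []

    async def run(name: str) -> None:
        # Blocks until every predecessor has called set_result.
        await asyncio.gather(*(futures[d] for d in deps[name]))
        await asyncio.sleep(0)  # stand-in for real work
        order.append(name)
        if name in futures:  # leaf nodes have no future to resolve
            futures[name].set_result(name)

    # Launch in reverse on purpose: the futures, not the launch order,
    # decide when each node actually does its work.
    await asyncio.gather(*(run(n) for n in reversed(list(deps))))
    return order


graph = {
    "filter": set(),
    "hydrate": {"filter"},
    "classify": {"hydrate"},
    "publish": {"classify"},
}
order = asyncio.run(run_dag(graph))
```

Even though the coroutines are launched back to front, the recorded order follows the dependency chain, with no explicit topological sort anywhere.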
        try:
            await asyncio.gather(
                *[self._execute_task(t, ctx, dependencies) for t in self.TASKS.keys()]
            )
            Metrics.counter("plan.execute.success.count").add(
                1, attributes={"plan_name": self.get_name()}
            )
        except Exception as e:
            logger.error(f"Error executing plan: {traceback.format_exc()}")
            ctx.errors.append(e)
            Metrics.counter("plan.execute.failed.count").add(
                1, attributes={"plan_name": self.get_name()}
            )
asyncio.gather over every task in TASKS. Each task starts as its own coroutine, immediately blocks on its dependencies, then runs when ready. An exception in any task escapes through gather and is caught here — accumulated into ctx.errors (so we can still produce a result).
        finally:
            duration = time.perf_counter() - start
            Metrics.histogram("plan.execute.duration").record(
                duration, attributes={"plan_name": self.get_name()}
            )
            for fut in dependencies.values():
                try:
                    if not fut.done():
                        fut.cancel()
                except Exception:
                    logger.error(
                        f"Error canceling dependency future: {traceback.format_exc()}"
                    )
            dependencies.clear()
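Always record duration, always cancel undone futures. A future can still be pending after gather returns when an upstream task failed: its dependents re-raise that exception while awaiting their dependencies and never resolve their own futures, and gather itself returns on the first exception while sibling coroutines may still be waiting. Cancelling the leftover futures wakes any such awaiters instead of leaving them to hang forever. dependencies.clear() drops the references for GC.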
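Here is a small sketch of how a future can be left pending after a failure, and what the cleanup loop does about it. The names upstream and downstream are hypothetical, and gather is called with return_exceptions=True here only so the demo can inspect both outcomes (the real code uses a plain gather inside try/except).

```python
import asyncio


async def demo():
    loop = asyncio.get_running_loop()
    futures = {name: loop.create_future() for name in ("upstream", "downstream")}

    async def upstream():
        exc = RuntimeError("upstream failed")
        futures["upstream"].set_exception(exc)  # wakes up anyone awaiting it
        raise exc

    async def downstream():
        await futures["upstream"]  # the exception is re-raised here...
        futures["downstream"].set_result("done")  # ...so this line never runs

    outcomes = await asyncio.gather(upstream(), downstream(), return_exceptions=True)
    pending_before = [name for name, fut in futures.items() if not fut.done()]

    # Same cleanup idea as the finally block: cancel whatever is unresolved.
    for fut in futures.values():
        if not fut.done():
            fut.cancel()
    return outcomes, pending_before, futures["downstream"].cancelled()


outcomes, pending_before, cancelled = asyncio.run(demo())
```

The downstream future is the one left dangling: its owner failed before reaching set_result, so only the cleanup loop ever resolves it.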
        return TaskResult(
            task=task,
            content_categories=[c.model_copy() for c in ctx.content_categories],
            task_started_at=ctx.start_time,
            task_finished_at=time.perf_counter(),
            multimodal_post_embedding=ctx.multimodal_post_embedding,
            reason=ctx.reason,
            success=len(ctx.errors) == 0,
            error="\n".join([str(e) for e in ctx.errors]),
        )
Build the result from the mutated context. content_categories is model_copy'd — these are Pydantic models and copying defends against later in-place mutation. The other fields are scalars or already-immutable.
success = len(ctx.errors) == 0 — note: a single failed task means the whole plan is marked failed. Errors are joined with newlines for display.
Single task execution with dependency wait
    def _eligible(self, ctx: TaskPayload) -> bool:
        return self.REQUIRED_ELIGIBILITY in ctx.eligibilities

    async def _execute_task(
        self, task_name: str, ctx: TaskContext, dependencies: dict[str, asyncio.Future]
    ):
        logger.debug(f"Waiting for task to become ready: {task_name}")
        task = self.TASKS[task_name]
        deps = self.TASK_DEPENDENCIES.get(task_name, set())
        dep_futures = [dependencies[d] for d in deps]
        dep_results = await asyncio.gather(*dep_futures)
        task_future = dependencies.get(task_name, None)
        if any(r == TaskResultCategory.SKIPPED for r in dep_results):
            if task_future is not None:
                task_future.set_result(TaskResultCategory.SKIPPED)
            return
        logger.debug(f"Started executing task: {task_name}")
        try:
            res = await task.exec(ctx)
        except Exception as e:
            if task_future is not None:
                task_future.set_exception(e)
            raise e
        if task_future is not None:
            task_future.set_result(res)
        logger.debug(f"Finished executing task: {task_name}")
This is the per-task wrapper:
- Gather dependency results. All dependencies' futures get awaited. Note this also serves as "wait for them to be ready", since await fut returns the value set_result(value) was called with.
- Skip propagation: if any dependency returned TaskResultCategory.SKIPPED, this task skips too, and propagates the SKIPPED downstream. This is crucial for the filter pattern: a filter task at the start of the plan can decide "this post doesn't qualify" and SKIPPED cascades through the entire DAG. No downstream task runs.
- Run the task via task.exec(ctx). If it throws, propagate the exception via set_exception (so downstream tasks awaiting this future see the exception), then re-raise so Plan.execute's outer try/except catches it.
- Set the future's result if anyone's waiting on this task.
The dependencies.get(task_name, None) is None for leaf tasks (those not waited on by anything). We skip the future bookkeeping for them — they just run and finish.
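The skip cascade is easy to see in a toy version of the wrapper. Everything below is a simplified sketch: Category is a hypothetical stand-in for TaskResultCategory, the tasks are plain coroutine functions rather than Task classes, and error handling is left out.

```python
import asyncio
from enum import Enum


class Category(Enum):  # hypothetical stand-in for TaskResultCategory
    SKIPPED = "skipped"
    SUCCESS = "success"


async def run_plan(tasks: dict, deps: dict) -> list:
    loop = asyncio.get_running_loop()
    futures = {d: loop.create_future() for ds in deps.values() for d in ds}
    ran = []  # names of tasks whose body actually executed

    async def execute(name: str) -> None:
        dep_results = await asyncio.gather(*(futures[d] for d in deps.get(name, set())))
        fut = futures.get(name)
        if any(r == Category.SKIPPED for r in dep_results):
            if fut is not None:
                fut.set_result(Category.SKIPPED)  # pass the skip downstream
            return
        ran.append(name)
        res = await tasks[name]()
        if fut is not None:
            fut.set_result(res)

    await asyncio.gather(*(execute(n) for n in tasks))
    return ran


async def reject():
    return Category.SKIPPED


async def work():
    return Category.SUCCESS


tasks = {"filter": reject, "hydrate": work, "classify": work, "publish": work}
deps = {
    "filter": set(),
    "hydrate": {"filter"},
    "classify": {"hydrate"},
    "publish": {"classify"},
}
ran = asyncio.run(run_plan(tasks, deps))
```

Only the filter's body runs; the three tasks behind it see SKIPPED on their dependency, forward it, and return without doing any work.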
    @cache
    def get_name(self) -> str:
        return camel_to_snake(self.__class__.__name__)
PlanInitialBanger → plan_initial_banger. Used for metrics attribution. @cache because plans are long-lived singletons and the result never changes.
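The helper camel_to_snake lives in grox.lib.utils and isn't shown here, so the version below is only a plausible guess at its behavior (a regex that inserts an underscore before each capital letter except the first). The real helper may treat digits or acronyms differently.

```python
import re
from functools import cache


def camel_to_snake(name: str) -> str:
    # Assumed behavior: "_" before every capital except the first, then lowercase.
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()


class PlanInitialBanger:  # bare stand-in, only here to exercise get_name
    @cache
    def get_name(self) -> str:
        return camel_to_snake(self.__class__.__name__)


name = PlanInitialBanger().get_name()
```

Because @cache keys on self, each plan instance computes its name once and every later metrics call is a dictionary lookup.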
plans/plan_master.py (62 lines)
Runs all plans in parallel and merges their results.
import asyncio
from grox.plans.plan import Plan
from grox.schedules.types import TaskResult, TaskPayload
from grox.plans.plan_spam_comment import PlanSpamComment
from grox.plans.plan_initial_banger import PlanInitialBanger
from grox.plans.plan_post_embedding_with_summary import PlanPostEmbeddingWithSummary
from grox.plans.plan_post_embedding_v5 import PlanPostEmbeddingV5
from grox.plans.plan_post_embedding_v5_for_reply import PlanPostEmbeddingV5ForReply
from grox.plans.plan_post_embedding_with_summary_for_reply import (
PlanPostEmbeddingWithSummaryForReply,
)
from grox.plans.plan_post_safety import PlanPostSafety
from grox.plans.plan_reply_ranking import PlanReplyRanking
from grox.plans.plan_safety_ptos import PlanSafetyPtos
class PlanMaster:
    ALL_PLANS: list[Plan] = [
        PlanInitialBanger(),
        PlanPostSafety(),
        PlanSpamComment(),
        PlanPostEmbeddingWithSummary(),
        PlanPostEmbeddingWithSummaryForReply(),
        PlanPostEmbeddingV5(),
        PlanPostEmbeddingV5ForReply(),
        PlanReplyRanking(),
        PlanSafetyPtos(),
    ]
Nine plans, each a singleton instance. The list order doesn't matter operationally because plans run via asyncio.gather — but it does fix the merge order for multimodal_post_embedding (we'll see why below).
    @classmethod
    async def exec(cls, task: TaskPayload) -> TaskResult:
        results = await asyncio.gather(*[p.execute(task) for p in cls.ALL_PLANS])
        result = cls.merge_results(task, [r for r in results if r is not None])
        return result
Run every plan concurrently. Each plan returns None if it's not eligible (the _eligible check at the top of execute). The non-None results get merged.
For a typical task carrying {SPAM_COMMENT, REPLY_RANKING} eligibilities, two plans return TaskResults — the others all return None. So this is gather-N-but-only-keep-the-relevant-ones rather than a costly invocation pattern. The cost of running an execute() with the wrong eligibility is just one branch + one return.
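A stripped-down sketch of the fan-out-then-filter pattern, assuming nothing beyond asyncio. FakePlan, the plan names and the string results are invented for illustration.

```python
import asyncio


class FakePlan:
    """Hypothetical plan: eligible only if its flag is in the payload's set."""

    def __init__(self, name: str, required: str):
        self.name, self.required = name, required

    async def execute(self, eligibilities: set):
        if self.required not in eligibilities:
            return None  # ineligible: one membership test, one return
        await asyncio.sleep(0)  # stand-in for running the plan's DAG
        return f"{self.name}:done"


ALL_PLANS = [
    FakePlan("banger", "BANGER_INITIAL_SCREEN"),
    FakePlan("spam", "SPAM_COMMENT"),
    FakePlan("reply_ranking", "REPLY_RANKING"),
]


async def exec_all(eligibilities: set) -> list:
    # gather returns results in the order of ALL_PLANS, not completion order.
    results = await asyncio.gather(*(p.execute(eligibilities) for p in ALL_PLANS))
    return [r for r in results if r is not None]


kept = asyncio.run(exec_all({"SPAM_COMMENT", "REPLY_RANKING"}))
```

The ineligible plan contributes a None that is filtered out, and the surviving results keep the list order of ALL_PLANS.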
Result merge
    @classmethod
    def merge_results(cls, task: TaskPayload, results: list[TaskResult]) -> TaskResult:
        multimodal_post_embedding = [
            r.multimodal_post_embedding
            for r in results
            if r.multimodal_post_embedding is not None
        ]
        if multimodal_post_embedding:
            multimodal_post_embedding = multimodal_post_embedding[0]
        else:
            multimodal_post_embedding = None
        return TaskResult(
            task=task,
            content_categories=[
                c.model_copy() for r in results for c in r.content_categories
            ],
            task_started_at=min(r.task_started_at for r in results),
            task_finished_at=max(r.task_finished_at for r in results),
            multimodal_post_embedding=multimodal_post_embedding,
            reason="\n".join([r.reason for r in results if r.reason]),
            success=all(r.success for r in results),
            error="\n".join(
                [r.error or "unknown error" for r in results if not r.success]
            ),
        )
Merge semantics:
- multimodal_post_embedding: take the first non-None one. This is why list order in ALL_PLANS matters. A task is expected to yield at most one embedding (it is eligible for one of the four embedding plans: with-summary, with-summary for reply, v5, v5 for reply, but never all). Defensive: if somehow two plans produced one, the first wins.
- content_categories: concatenate across all plans (a task might produce both banger annotations AND safety annotations).
- task_started_at / task_finished_at: min start, max finish. This makes the result's duration the wall-clock duration of the full ensemble, not the sum. Both min() and max() raise ValueError on an empty sequence, so merge_results relies on at least one plan having been eligible.
- reason: concat with newlines, only non-empty.
- success: AND of all sub-plans; ANY failure = task failure.
- error: concat the failed-plan errors.
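The merge rules are simple enough to replay on toy data. Result below is a hypothetical, trimmed-down stand-in for TaskResult (a dataclass rather than a Pydantic model), and the numbers are made up.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Result:  # hypothetical, trimmed-down stand-in for TaskResult
    started: float
    finished: float
    embedding: Optional[list] = None
    categories: list = field(default_factory=list)
    success: bool = True
    error: str = ""


def merge(results: list) -> Result:
    embeddings = [r.embedding for r in results if r.embedding is not None]
    failed = [r for r in results if not r.success]
    return Result(
        started=min(r.started for r in results),    # earliest start
        finished=max(r.finished for r in results),  # latest finish
        embedding=embeddings[0] if embeddings else None,  # first one wins
        categories=[c for r in results for c in r.categories],
        success=all(r.success for r in results),
        error="\n".join(r.error or "unknown error" for r in failed),
    )


merged = merge(
    [
        Result(10.0, 12.5, categories=["banger"]),
        Result(10.2, 11.0, embedding=[0.1, 0.2]),
        Result(10.1, 13.0, embedding=[9.9], success=False),
    ]
)
```

The third result fails without an error message, so the merged result is marked failed with the "unknown error" placeholder, while the embedding still comes from the second result because it appears first in the list.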
The nine plan files (29-38 lines each)
Each plan is a tiny config class. Let me walk through them quickly — they're all the same shape (TASKS map + TASK_DEPENDENCIES map) — and then I'll abstract the common pattern.
plan_initial_banger.py
class PlanInitialBanger(Plan):
    REQUIRED_ELIGIBILITY = TaskEligibility.BANGER_INITIAL_SCREEN
    TASKS = {
        "task_initial_banger_filter": TaskInitialBangerFilter,
        "task_banger_annotation_rate_limit": TaskRateLimitBangerAnnotationWithPost,
        "task_media_hydration": TaskMediaHydrationBanger,
        "task_banger_screen_initial": TaskBangerScreen,
        "task_grok_upa_action_with_labels": TaskGrokUpaActionWithLabels,
        "task_publish_unified_post_annotations_manhattan": TaskPublishUnifiedPostAnnotationsManhattan,
        "task_publish_kafka": TaskPublishKafka,
    }
    TASK_DEPENDENCIES = {
        "task_initial_banger_filter": set(),
        "task_banger_annotation_rate_limit": {"task_initial_banger_filter"},
        "task_media_hydration": {"task_banger_annotation_rate_limit"},
        "task_banger_screen_initial": {"task_media_hydration"},
        "task_grok_upa_action_with_labels": {"task_banger_screen_initial"},
        "task_publish_unified_post_annotations_manhattan": {
            "task_banger_screen_initial"
        },
        "task_publish_kafka": {"task_publish_unified_post_annotations_manhattan"},
    }
The "banger" plan: identify high-quality candidates for the initial-screen surface. Drawing the DAG:
filter → rate_limit → media_hydration → banger_screen ┬→ grok_upa_action_with_labels
                                                      │
                                                      └→ publish_manhattan → publish_kafka
Mostly linear with a fork at the end: after the banger-screen classifier produces its annotation, we do two things in parallel — apply a Grok action with labels (post-process for additional metadata), and publish to Manhattan (X's KV store). Then publish-Kafka depends only on the Manhattan publish, not on the Grok action (so the Kafka publish doesn't wait for the slower Grok call).
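Since TASK_DEPENDENCIES already has the "node → set of predecessors" shape that the standard library's graphlib.TopologicalSorter accepts (Python 3.9+), we can sanity-check the drawing above by grouping the tasks into waves. This is only an analysis aid; the waves helper is made up and Grox itself does not use graphlib.

```python
from graphlib import TopologicalSorter

TASK_DEPENDENCIES = {
    "task_initial_banger_filter": set(),
    "task_banger_annotation_rate_limit": {"task_initial_banger_filter"},
    "task_media_hydration": {"task_banger_annotation_rate_limit"},
    "task_banger_screen_initial": {"task_media_hydration"},
    "task_grok_upa_action_with_labels": {"task_banger_screen_initial"},
    "task_publish_unified_post_annotations_manhattan": {"task_banger_screen_initial"},
    "task_publish_kafka": {"task_publish_unified_post_annotations_manhattan"},
}


def waves(deps: dict) -> list:
    """Group tasks into levels: everything in a level could run concurrently."""
    sorter = TopologicalSorter(deps)
    sorter.prepare()  # also raises CycleError if the graph has a cycle
    levels = []
    while sorter.is_active():
        ready = sorted(sorter.get_ready())
        levels.append(ready)
        sorter.done(*ready)
    return levels


levels = waves(TASK_DEPENDENCIES)
```

The fifth wave contains both the Grok action and the Manhattan publish, which is the fork in the diagram. Waves are a slight simplification: in the future-based executor, publish_kafka starts as soon as the Manhattan publish resolves, whether or not the Grok action has finished.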
plan_spam_comment.py
class PlanSpamComment(Plan):
    REQUIRED_ELIGIBILITY = TaskEligibility.SPAM_COMMENT
    TASKS = {
        "task_spam_filter": TaskSpamFilter,
        "task_reply_spam_annotation_rate_limit": TaskRateLimitReplySpamAnnotationWithPost,
        "task_media_hydration": TaskMediaHydration,
        "task_spam_detection": TaskSpamDetection,
        "task_publish_reply_spam_mh": TaskWriteReplySpamManhattan,
        "task_publish_kafka": TaskPublishKafka,
    }
    TASK_DEPENDENCIES = {
        "task_spam_filter": set(),
        "task_reply_spam_annotation_rate_limit": {"task_spam_filter"},
        "task_media_hydration": {"task_reply_spam_annotation_rate_limit"},
        "task_spam_detection": {"task_media_hydration"},
        "task_publish_reply_spam_mh": {"task_spam_detection"},
        "task_publish_kafka": {"task_spam_detection"},
    }
Spam-comment plan:
filter → rate_limit → media_hydration → spam_detection ┬→ publish_mh
                                                       └→ publish_kafka
Same shape — filter → rate-limit → media → ML inference → publish.
plan_post_safety.py
class PlanPostSafety(Plan):
    REQUIRED_ELIGIBILITY = TaskEligibility.POST_SAFETY
    TASKS = {
        "task_post_safety_deluxe_filter": TaskPostSafetyDeluxeFilter,
        "task_post_safety_annotation_rate_limit": TaskRateLimitPostSafetyAnnotationWithPost,
        "task_media_hydration": TaskMediaHydrationBanger,
        "task_post_safety_screen_deluxe": TaskPostSafetyScreenDeluxe,
        "task_grok_upa_action_with_labels": TaskGrokUpaActionWithLabels,
        "task_upsert_tweet_bool_metadata_to_unified_post_annotations_manhattan": TaskUpsertTweetBoolMetadataToUnifiedPostAnnotation,
    }
The "post safety deluxe" plan: same pattern (filter → rate_limit → media → safety_screen → fork to grok+manhattan-upsert). No Kafka publish at the end of this one — its outputs go only to Manhattan.
plan_safety_ptos.py
class PlanSafetyPtos(Plan):
    REQUIRED_ELIGIBILITY = TaskEligibility.SAFETY_PTOS
    TASKS = {
        "task_safety_ptos_filter": TaskSafetyPtosFilter,
        "task_safety_ptos_annotation_rate_limit": TaskRateLimitSafetyPtosAnnotationWithPost,
        "task_media_hydration": TaskMediaHydration,
        "task_safety_ptos_category_detection": TaskSafetyPtosCategoryDetection,
        "task_safety_ptos_policy_detection": TaskSafetyPtosPolicyDetection,
        "task_write_safety_post_annotations_result_sink": TaskWriteSafetyPostAnnotationsResultSink,
    }
    TASK_DEPENDENCIES = {
        "task_safety_ptos_filter": {},
        "task_safety_ptos_annotation_rate_limit": {"task_safety_ptos_filter"},
        "task_media_hydration": {"task_safety_ptos_annotation_rate_limit"},
        "task_safety_ptos_category_detection": {"task_media_hydration"},
        "task_safety_ptos_policy_detection": {"task_safety_ptos_category_detection"},
        "task_write_safety_post_annotations_result_sink": {"task_safety_ptos_policy_detection"},
    }
PTOS is read here as "Pillars of Safety", though the code shown never spells the acronym out, so treat that expansion as a guess. Two-stage classifier: first detect a content category (this post is about Sex / Violence / etc.), then a policy detection (does it violate the policy for that category?). The category result is a dependency for policy. This is the "two-step" pattern: cheap broad classifier first, expensive fine-grained one only on the matched category.
Notice task_safety_ptos_filter: {} — using {} (empty dict) instead of set() for the empty value. That's a bug-of-no-consequence (Python's {} is a dict, not a set, but the iteration semantics are the same when empty). The other plans correctly use set(). This wouldn't fail validation in Plan.__init__ because the empty dict iterates to nothing.
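A quick check of why the stray {} is harmless, using a two-entry excerpt of the dependency map (the second key is shortened for the example):

```python
empty_literal = {}

# {} is a dict, not a set; set() is the only way to spell an empty set.
literal_is_dict = isinstance(empty_literal, dict)
literal_is_set = isinstance(empty_literal, set)

deps_map = {
    "task_safety_ptos_filter": {},  # empty dict, as in the source
    "rate_limit": {"task_safety_ptos_filter"},  # a real set
}

# The same flattening the Plan constructor performs: iterating an empty
# dict yields nothing, exactly like iterating an empty set.
flattened = {d for deps in deps_map.values() for d in deps}
```

It would only start to matter if some code called a set-only method such as .add() or .union() on the value.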
plan_reply_ranking.py
class PlanReplyRanking(Plan):
    REQUIRED_ELIGIBILITY = TaskEligibility.REPLY_RANKING
    TASKS = {
        "task_reply_ranking_filter": TaskReplyRankingFilter,
        "task_reply_ranking_annotation_rate_limit": TaskRateLimitReplyRankingAnnotationWithPost,
        "task_media_hydration": TaskMediaHydration,
        "task_rank_replies": TaskRankReplies,
        "task_write_reply_ranking_manhattan": TaskWriteReplyRankingManhattan,
    }
Reply-ranking plan: filter → rate_limit → media → rank_replies → write_manhattan. This is the system that ranks reply posts under a given parent post.
plan_post_embedding_with_summary.py and ..._for_reply.py
class PlanPostEmbeddingWithSummary(Plan):
    REQUIRED_ELIGIBILITY = TaskEligibility.POST_EMBEDDING_WITH_SUMMARY
    TASKS = {
        "task_post_embedding_rate_limit_summary": TaskRateLimitEmbeddingWithPostSummary,
        "task_post_embedding_with_summary_filter": TaskPostEmbeddingWithSummaryFilter,
        "task_media_hydration": TaskMediaHydration,
        "task_post_embedding_summarizer": TaskPostEmbeddingSummarizer,
        "task_multimodal_post_embedding_with_summary": TaskMultimodalPostEmbeddingWithSummary,
        "task_write_post_embedding_sink_v3": TaskWriteMMEmbeddingSinkV3,
    }
Embedding-with-summary plan: rate_limit → filter → media → summarizer → embedding → write-v3-sink.
The summarizer is the LLM call that produces a text summary of the post's content (handling text + image OCR + video transcript); the embedding task takes that summary and produces a vector embedding. Note the summarizer → embedding chain — you can't embed without first having a summary.
The _for_reply variant is identical except for the filter (which has different criteria for reply posts) and the rate-limit (different per-tier QPS budgets).
plan_post_embedding_v5.py and ..._for_reply.py
class PlanPostEmbeddingV5(Plan):
    REQUIRED_ELIGIBILITY = TaskEligibility.MM_EMB_V5
    TASKS = {
        "task_post_embedding_rate_limit": TaskRateLimitEmbeddingV5,
        "task_media_hydration": TaskMediaHydration,
        "task_asr_transcription": TaskASRTranscription,
        "task_multimodal_post_embedding_v5": TaskMultimodalPostEmbeddingV5,
        "task_write_post_embedding_sink_v5": TaskWriteMMEmbeddingSinkV5SkipKafkaForReplies,
    }
    TASK_DEPENDENCIES = {
        "task_post_embedding_rate_limit": set(),
        "task_media_hydration": {"task_post_embedding_rate_limit"},
        "task_asr_transcription": {"task_media_hydration"},
        "task_multimodal_post_embedding_v5": {"task_asr_transcription"},
        "task_write_post_embedding_sink_v5": {"task_multimodal_post_embedding_v5"},
    }
V5 embedding plan: rate_limit → media → ASR transcription → embedding → write-v5-sink.
Differences from v3:
- No filter step — V5 embeds everything that hits the rate-limited budget.
- No summarizer — replaces text summarization with ASR (transcribing audio from videos).
- Different sink (TaskWriteMMEmbeddingSinkV5SkipKafkaForReplies vs TaskWriteMMEmbeddingSinkV3).
The v5 design philosophy: instead of compressing into a text summary, let the multimodal model directly attend to the visual frames + transcribed audio. The sink name SkipKafkaForReplies is interesting — it means: for reply posts specifically, write to the KV store but don't emit a Kafka event. Probably because the downstream Kafka consumers of post-embeddings only care about top-level posts.
The _for_reply variant adds a task_post_embedding_filter_for_reply filter step and uses TaskWriteMMEmbeddingSinkV5 (the normal sink, with Kafka). Inverse of the V5 main plan — replies get a filter, top-level posts don't.
The common pattern
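All nine plans follow this template, with small variations in the first two steps (the with-summary embedding plan, for instance, rate-limits before it filters):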
[filter →] rate_limit → media_hydration → [classifier_or_embedder] → publish (fork as needed)
Filter decides whether this specific post is worth processing — based on language, post type, author tier, etc. Skipped here means everything downstream skips.
Rate-limit is a global throttle — each plan has its own QPS budget so a flood in one category doesn't starve another.
Media hydration downloads the post's images/videos and decodes them — slow I/O, parallelizable.
Classifier/embedder is the LLM/transformer call — the expensive step.
Publish writes the result to KV (Manhattan) and/or emits a Kafka event for downstream consumers.
The dependency graph is mostly linear (each step waits on the previous), with occasional forks at the end for parallel publishes. This is stage-pipelined parallelism: the system can have many tasks in flight at the same time, with task-A's filter running concurrently with task-B's classifier and task-C's publish.
data_loaders/message_queue_loader.py (39 lines)
The abstract interface every message-queue loader implements.
import logging
from abc import ABC, abstractmethod
from grox.data_loaders.data_types import Post, User, UserContext, GroxContentAnalysis
from collections.abc import AsyncGenerator
from pydantic import BaseModel
logger = logging.getLogger(__name__)
class MessageQueuePayload(BaseModel):
    mid: str
    post: Post | None = None
    user: User | None = None
    user_context: UserContext | None = None
    grox_content_analysis: GroxContentAnalysis | None = None
    deadline_ts_secs: int


class MessageQueueLoader(ABC):
    def __init__(self):
        pass

    @abstractmethod
    async def start(self):
        pass

    @abstractmethod
    async def stop(self):
        pass

    @abstractmethod
    def poll(self) -> AsyncGenerator[MessageQueuePayload | None, None]:
        pass

    @abstractmethod
    async def ack(self, mid: str, success: bool = True):
        pass
Four abstract methods: start, stop, poll, ack. poll() returns an async generator yielding payloads (or None when no message is immediately available). ack(mid, success) is called by the dispatcher after processing; success=False is presumably intended to let the underlying queue retry or dead-letter the message, although what actually happens is up to each implementation.
MessageQueuePayload is the unified shape — mid (message ID, used for ack routing), the actual content (post / user / context / analysis), and deadline_ts_secs (when to stop trying).
We saw Post, User, UserContext, and GroxContentAnalysis imported from data_types — that file isn't in the open-source dump, but from context we know:
- Post is a Pydantic model with fields id (post ID), media, text content, author info, safety labels, etc. It has class methods like from_thrift_content_understanding_metadata (deserialize from Thrift bytes).
- User is a Pydantic model for an author.
- UserContext wraps a user + some context (e.g., the recent posts they've seen), used for reply-ranking style tasks.
- GroxContentAnalysis is a Pydantic model holding the prior analysis result for recovery flows.
data_loaders/kafka_loader.py (232 lines)
Concrete Kafka-backed loader. Uses aiokafka + an internal kafka_cli wrapper.
Setup
import struct
import time
import uuid
import asyncio
import logging
import traceback
from abc import abstractmethod
from typing import override
from collections.abc import AsyncGenerator
from concurrent.futures import ThreadPoolExecutor
from aiokafka import TopicPartition
from kafka_cli.config import KafkaMessage
from grox.config.config import KafkaTopicName, grox_config
from kafka_cli.consumer import KafkaConsumer
from grox.data_loaders.data_types import Post, User, GroxContentAnalysis
from grox.data_loaders.message_queue_loader import (
MessageQueueLoader,
MessageQueuePayload,
)
from monitor.metrics import Metrics
from thrifts.serdes import SerDesError
logger = logging.getLogger(__name__)
MAX_WORKING_THREADS = 12
MAX_WORKING_THREADS = 12 is the size of the thread pool used to deserialize batch messages. We'll see why.
class _Payload(MessageQueuePayload):
    tp: TopicPartition
    offset: int
Internal subclass that carries Kafka-specific fields. Not exposed to the dispatcher — only used inside the loader.
Constructor
class KafkaLoader(MessageQueueLoader):
    def __init__(self, topic_name: KafkaTopicName):
        super().__init__()
        self._initialized = False
        self._shutdown_event = asyncio.Event()
        self.topic_name = topic_name
        self.loader_config = grox_config.get_kafka_loader_topic(topic_name)
        self.consumer_config = grox_config.get_kafka_consumer_topic(topic_name)
        self.consumer = KafkaConsumer(self.consumer_config)
        self.loaded_messages: dict[str, tuple[TopicPartition, int]] = {}
        self.queue: asyncio.Queue[MessageQueuePayload] = asyncio.Queue()
        self._prefetcher_task: asyncio.Task | None = None
The two-config split is interesting: loader_config (deadlines, prefetch sizes) and consumer_config (broker addresses, group ID) — both per-topic from the global config. The KafkaConsumer (from kafka_cli) is the actual aiokafka wrapper.
self.queue is the internal prefetch buffer — the loader runs a background prefetcher task that fills this queue, and poll() drains from it. This decouples the consumer's IO loop from the dispatcher's consumption rate.
loaded_messages is declared but as we'll see, not actually populated by _prefetcher — it's probably dead code (or used in a path we don't see).
Start / Stop / Poll / Ack
    async def start(self):
        logger.info(f"Initializing KafkaLoader, topic: {self.topic_name}")
        self._initialized = True
        await self.consumer.start()
        self._prefetcher_task = asyncio.create_task(self._prefetcher())
        self._initialized = True
        logger.info(f"KafkaLoader initialized, topic: {self.topic_name}")
Start the consumer, spawn the prefetcher. Note self._initialized = True appears twice (looks like a leftover edit).
    async def stop(self):
        logger.warning(f"Stopping KafkaLoader, topic: {self.topic_name}")
        self._shutdown_event.set()
        try:
            if self._prefetcher_task:
                await asyncio.wait_for(self._prefetcher_task, 5)
        except asyncio.TimeoutError:
            logger.warning(
                f"Waiting prefetcher to stop timed out, topic: {self.topic_name}"
            )
        await self.consumer.stop()
        logger.warning(f"KafkaLoader stopped, topic: {self.topic_name}")
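Signal shutdown via event; wait up to 5s for the prefetcher to finish; stop the consumer. The timeout is defensive: if the prefetcher is stuck on an IO call, we don't want to hang shutdown forever. When the timeout fires, asyncio.wait_for also cancels the prefetcher task before raising TimeoutError.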
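The event-plus-timeout shutdown can be tried out in isolation. Worker below is a hypothetical miniature with the same shape as the loader's start/stop pair; the tick lengths and timeouts are arbitrary values picked to make one worker stop cleanly and the other time out.

```python
import asyncio


class Worker:
    """Hypothetical background loop that polls a shutdown event every `tick`."""

    def __init__(self, tick: float):
        self._shutdown = asyncio.Event()
        self._tick = tick
        self.task = None

    async def _loop(self) -> None:
        while not self._shutdown.is_set():
            await asyncio.sleep(self._tick)  # stand-in for a blocking IO call

    def start(self) -> None:
        self.task = asyncio.create_task(self._loop())

    async def stop(self, timeout: float) -> str:
        self._shutdown.set()
        try:
            await asyncio.wait_for(self.task, timeout)
            return "stopped cleanly"
        except asyncio.TimeoutError:
            # wait_for has already cancelled the task at this point.
            return "timed out"


async def demo():
    fast, slow = Worker(tick=0.01), Worker(tick=5.0)
    fast.start()
    slow.start()
    await asyncio.sleep(0.02)
    fast_outcome = await fast.stop(timeout=0.5)
    slow_outcome = await slow.stop(timeout=0.05)
    return fast_outcome, slow_outcome, slow.task.cancelled()


fast_outcome, slow_outcome, slow_cancelled = asyncio.run(demo())
```

The slow worker never gets to re-check its event within the timeout, so stop() gives up on it and the task ends up cancelled rather than finished.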
    async def poll(self) -> AsyncGenerator[MessageQueuePayload | None, None]:
        while not self._shutdown_event.is_set() or not self.queue.empty():
            try:
                yield self.queue.get_nowait()
            except asyncio.QueueEmpty:
                logger.debug(
                    f"Queue is empty, waiting for prefetcher to fill, topic: {self.topic_name}"
                )
                yield None
            except Exception:
                logger.error(
                    f"Error polling from kafka, topic: {self.topic_name}, error: {traceback.format_exc()}"
                )
                yield None
Drains the internal queue, yielding None when empty so the caller (StreamTaskGenerator) can rate-limit / shutdown-check / interleave other work. The loop drains the queue even after shutdown is signaled — same pattern as the engine / dispatcher.
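To see the drain-after-shutdown behavior from the caller's side, here is a sketch with a hypothetical QueueBackedLoader that keeps only the poll() loop. The consumer loop is invented for the demo; the real StreamTaskGenerator is not shown in this part.

```python
import asyncio
from collections.abc import AsyncGenerator
from typing import Optional


class QueueBackedLoader:
    """Hypothetical loader reduced to the queue, the event and poll()."""

    def __init__(self):
        self.queue = asyncio.Queue()
        self.shutdown = asyncio.Event()

    async def poll(self) -> AsyncGenerator[Optional[str], None]:
        # Keep going until shutdown is signalled AND the queue is empty.
        while not self.shutdown.is_set() or not self.queue.empty():
            try:
                yield self.queue.get_nowait()
            except asyncio.QueueEmpty:
                yield None  # nothing available right now


async def demo():
    loader = QueueBackedLoader()
    loader.queue.put_nowait("m1")
    loader.queue.put_nowait("m2")
    seen, idle_polls = [], 0
    async for item in loader.poll():
        if item is None:
            idle_polls += 1
            # Simulate a late message landing just before shutdown is signalled.
            loader.queue.put_nowait("m3")
            loader.shutdown.set()
            await asyncio.sleep(0)  # a real caller would back off here
            continue
        seen.append(item)
    return seen, idle_polls


seen, idle_polls = asyncio.run(demo())
```

The late message is still delivered even though shutdown was already set, because the loop condition only turns false once the queue is empty as well. Note that poll() never awaits on its own, so a caller that ignores the None results without sleeping would spin.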
    async def ack(self, mid: str, success: bool = True):
        pass
Ack is a no-op! The KafkaLoader doesn't track or commit offsets per message. Kafka auto-commit is presumably enabled in the consumer config, so offsets are committed periodically based on time, not on per-message acks. Under that assumption a process crash can go either way: messages handled after the last commit are re-processed on restart, while messages whose offsets were already committed but which were still sitting in the in-memory prefetch queue (or in flight in the dispatcher) are lost. So the guarantee is best-effort rather than strict at-least-once. The retry logic in the dispatcher operates on the dispatcher's in-memory state, not by replaying from Kafka.
This is a pragmatic choice: for an enrichment pipeline that already writes to idempotent destinations (Manhattan put-by-id, Kafka publishes that downstream consumers dedupe), duplicates are harmless, an occasional dropped message on a crash is presumably tolerable, and the operational simplicity of auto-commit is preferred over per-message offset tracking.
Message deserialization in a thread pool
    @abstractmethod
    def _messages_to_payloads(self, messages: list[KafkaMessage]) -> list[_Payload]:
        pass

    def _process_messages(self, messages: list[KafkaMessage]) -> list[_Payload]:
        group_size = max(1, len(messages) // MAX_WORKING_THREADS)
        message_groups = [
            messages[i : i + group_size] for i in range(0, len(messages), group_size)
        ]
        with ThreadPoolExecutor(max_workers=MAX_WORKING_THREADS) as executor:
            payloads = []
            for result in executor.map(self._messages_to_payloads, message_groups):
                payloads.extend(result)
        return payloads
Each subclass implements _messages_to_payloads (Thrift bytes → Post object). The base class splits a batch into chunks of len(messages) // 12 messages (roughly 12 chunks, more when the batch size is not a multiple of 12) and decodes the chunks on a 12-thread pool.
This only pays off under one condition. Thrift deserialization is CPU-bound, and pure-Python decoding holds the GIL, so the threads would simply take turns and the pool would add overhead without any speedup. If the Thrift codec is backed by a C extension that releases the GIL while decoding, which is presumably the case here, splitting a batch across 12 threads can approach linear speedup on the deserialization phase. Note also that _prefetcher calls _process_messages synchronously, so the event loop is blocked until the whole batch has been decoded.
group_size = max(1, len(messages) // 12): if you have fewer than 12 messages, each chunk holds a single message (so some threads might idle, but no harm done). 12 is hardcoded; you might think it should be os.cpu_count() but apparently 12 is what was tuned for in production.
Note this is a fresh ThreadPoolExecutor per batch — creating threads is cheap relative to the cost of decoding a large Thrift message, and you don't have to manage long-lived thread state.
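The chunking arithmetic is worth running on a few batch sizes. The sketch below copies the group_size formula; split and decode_all are hypothetical helper names, and bytes.decode stands in for the Thrift decoder.

```python
from concurrent.futures import ThreadPoolExecutor

MAX_WORKING_THREADS = 12


def split(messages: list, workers: int = MAX_WORKING_THREADS) -> list:
    """Same chunking rule as _process_messages."""
    group_size = max(1, len(messages) // workers)
    return [messages[i : i + group_size] for i in range(0, len(messages), group_size)]


def decode_all(messages: list) -> list:
    groups = split(messages)
    with ThreadPoolExecutor(max_workers=MAX_WORKING_THREADS) as executor:
        out = []
        # executor.map yields results in input order, so message order survives.
        for decoded in executor.map(lambda g: [m.decode() for m in g], groups):
            out.extend(decoded)
    return out


# Number of chunks produced for a few batch sizes.
chunk_counts = {n: len(split(list(range(n)))) for n in (5, 24, 25, 35)}
decoded = decode_all([b"a", b"b", b"c"])
```

A batch of 24 splits into exactly 12 chunks, but 25 gives 13 and 35 gives 18, because integer division rounds the chunk size down. The extra chunks simply queue behind the 12 workers.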
The prefetcher loop
    async def _prefetcher(self) -> None:
        logger.info(f"Starting prefetcher, topic: {self.topic_name}")
        prefetching_threshold = self.loader_config.prefetching_threshold
        prefetching_batch_size = self.loader_config.prefetching_batch_size
        while not self._is_shutdown():
            if self.queue.qsize() < prefetching_threshold:
                logger.debug(
                    f"Inventory low at {self.queue.qsize()}, prefetching {prefetching_batch_size} messages, topic: {self.topic_name}"
                )
                try:
                    messages = await self.consumer.poll(prefetching_batch_size)
                    try:
                        payloads = self._process_messages(messages)
                    except SerDesError:
                        logger.error(
                            f"Error processing messages, error: {traceback.format_exc()}"
                        )
                        raise
                    await asyncio.gather(
                        *[
                            self.queue.put(
                                MessageQueuePayload(
                                    mid=payload.mid,
                                    user=payload.user,
                                    post=payload.post,
                                    user_context=payload.user_context,
                                    grox_content_analysis=payload.grox_content_analysis,
                                    deadline_ts_secs=payload.deadline_ts_secs,
                                )
                            )
                            for payload in payloads
                        ]
                    )
                    logger.debug(
                        f"Prefetched {prefetching_batch_size} messages, inventory now at {self.queue.qsize()}, topic: {self.topic_name}"
                    )
                except Exception:
                    logger.error(
                        f"Error prefetching messages, error: {traceback.format_exc()}"
                    )
                    await asyncio.sleep(0.1)
            else:
                await asyncio.sleep(0.1)
        logger.warning("Prefetcher stopped")
The prefetcher is the inventory keeper: only call Kafka when the internal queue's size drops below prefetching_threshold. Top up by prefetching_batch_size. This is classic prefetch-on-demand:
- If consumers are slow (engine is the bottleneck), the queue stays full and we don't burn Kafka calls.
- If consumers are fast, the queue drops and we fetch a fresh batch.
Note the conversion: payloads come out of _process_messages as _Payload (with tp/offset), but get re-wrapped to plain MessageQueuePayload when pushed onto the queue. The tp/offset info is dropped — we're not using it. Consistent with the no-op ack.
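Exception strategy: SerDesError from Thrift decoding is logged and re-raised, which reads like a fail-fast path: a serialization error indicates a schema mismatch, and dying is better than silently dropping messages. But the inner try sits inside the outer try, so unless SerDesError derives directly from BaseException (its definition isn't shown in this file), the re-raise is caught by the outer except Exception, logged a second time, and followed by the same 0.1s sleep as any other failure. The prefetcher keeps running and the failed batch never reaches the queue. Other exceptions (Kafka connection issues etc.) take the outer handler directly: log, sleep 0.1s, retry.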
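The nested try/except in the prefetcher is worth tracing by hand. Here is a minimal sketch of the same shape, with a local SerDesError stand-in that is assumed to be an ordinary Exception subclass (the real class isn't shown in this file); decode and prefetch_step are hypothetical names.

```python
class SerDesError(Exception):
    """Local stand-in; assumed to be an ordinary Exception subclass."""


def decode(batch: str) -> str:
    if batch == "bad":
        raise SerDesError("schema mismatch")
    return batch.upper()


def prefetch_step(batch: str, log: list) -> None:
    """One loop iteration with the same nested try/except shape."""
    try:
        try:
            payload = decode(batch)
        except SerDesError:
            log.append("inner: logged, re-raising")
            raise
        log.append(f"queued {payload}")
    except Exception:
        # A re-raised SerDesError lands here, so the loop survives.
        log.append("outer: logged, sleeping")


log: list = []
for batch in ["a", "bad", "b"]:
    prefetch_step(batch, log)
```

Under that assumption the bad batch is logged twice and dropped, and the next batch is processed normally.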
Concrete subclasses
class KafkaPostLoader(KafkaLoader):
def __init__(self, topic_name: KafkaTopicName):
super().__init__(topic_name)
@override
def _messages_to_payloads(self, messages: list[KafkaMessage]) -> list[_Payload]:
return [
_Payload(
mid=uuid.uuid4().hex,
post=Post.from_thrift_content_understanding_metadata(message.value),
tp=message.tp,
offset=message.offset,
deadline_ts_secs=int(time.time())
+ self.loader_config.task_deadline_secs,
)
for message in messages
]
The post loader: deserialize message.value (raw Thrift bytes) into a Post via Post.from_thrift_content_understanding_metadata. The mid is a fresh UUID hex — the loader generates its own message IDs because Kafka offsets are not unique across topics and we want a globally unique routing key.
deadline_ts_secs = now + task_deadline_secs from config — typical values would be a few minutes. Downstream filters check this deadline and skip if exceeded.
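As a rough sketch of the deadline arithmetic: make_deadline mirrors the expression in the loaders, while is_expired is a hypothetical downstream check (the real filters live elsewhere, and whether they compare with > or >= is an assumption).

```python
import time
from typing import Optional


def make_deadline(task_deadline_secs: int, now: Optional[float] = None) -> int:
    """Absolute deadline in whole seconds: int(now) + budget."""
    current = time.time() if now is None else now
    return int(current) + task_deadline_secs


def is_expired(deadline_ts_secs: int, now: Optional[float] = None) -> bool:
    """Hypothetical downstream check: skip work once the deadline has passed."""
    current = time.time() if now is None else now
    return current > deadline_ts_secs


deadline = make_deadline(300, now=1000.7)
```

Because the timestamp is truncated with int() before the budget is added, the effective budget can be up to one second shorter than task_deadline_secs.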
class KafkaPostEmbeddingRequestLoader(KafkaLoader):
@override
def _messages_to_payloads(self, messages: list[KafkaMessage]) -> list[_Payload]:
return [
_Payload(
mid=uuid.uuid4().hex,
post=Post.from_thrift_post_embedding_request(message.value),
tp=message.tp,
offset=message.offset,
deadline_ts_secs=int(time.time())
+ self.loader_config.task_deadline_secs,
)
for message in messages
]
Same shape, but a different Thrift schema — embedding-request messages have a slightly different layout (probably with hints about which embedding model to use).
class KafkaGroxContentAnalysisLoader(KafkaLoader):
@override
def _messages_to_payloads(self, messages: list[KafkaMessage]) -> list[_Payload]:
return [
_Payload(
mid=uuid.uuid4().hex,
grox_content_analysis=GroxContentAnalysis.from_thrift_content_understanding_metadata(
message.value
),
tp=message.tp,
offset=message.offset,
deadline_ts_secs=int(time.time())
+ self.loader_config.task_deadline_secs,
)
for message in messages
]
class KafkaTweetEmbeddingLoader(KafkaLoader):
@override
def _messages_to_payloads(self, messages: list[KafkaMessage]) -> list[_Payload]:
return [
_Payload(
mid=uuid.uuid4().hex,
post=Post.from_thrift_tweet_embedding(message.value),
tp=message.tp,
offset=message.offset,
deadline_ts_secs=int(time.time())
+ self.loader_config.task_deadline_secs,
)
for message in messages
]
Two more loaders for different message schemas. The recovery / re-process flows use these.
Note: stream_generator.py imported KafkaAdPostLoader, KafkaPostLoader, KafkaPostEmbeddingRequestLoader — but KafkaAdPostLoader isn't defined in this file. Either it's elsewhere (likely a missing file in the open-source dump) or it was removed but the import wasn't cleaned up.
data_loaders/strato_loader.py (154 lines)
Strato is the X-internal RPC framework / KV-store query system (we saw it referenced earlier in home-mixer as strato_client::tonic_runtime). The strato_loader module wraps a handful of Strato endpoints for on-demand reads/writes.
import asyncio
import logging
from grox.data_loaders.data_types import Post, User
from strato_http.queries.data_types import (
ReplyRankingScore,
ReplyRankingScoreKafka,
)
from strato_http.queries.content_understanding_author_metadata import (
StratoContentUnderstandingAuthorMetadata,
)
from strato_http.queries.content_understanding_post_quote_metadata import (
StratoContentUnderstandingPostQuoteMetadata,
)
from strato_http.queries.content_understanding_metadata_v2 import (
StratoContentUnderstandingMetadataV2,
)
from strato_http.queries.reply_ranking_score import StratoReplyRankingScore
from strato_http.queries.reply_spam_annotation import StratoReplySpamAnnotation
from strato_http.queries.reply_ranking_score_kafka_v2 import (
StratoReplyRankingScoreV2Kafka,
)
from strato_http.queries.safety_label import StratoSafetyLabel
from strato_http.queries.user_recent_posts import StratoUserRecentPosts
from grox.data_loaders.mappers.post_mapper import PostMapper
Eight Strato endpoints; one mapper.
TweetStratoLoader
class TweetStratoLoader:
    content_understanding_metadata_strato = StratoContentUnderstandingMetadataV2()
    content_understanding_post_quote_metadata_strato = (
        StratoContentUnderstandingPostQuoteMetadata()
    )

    @classmethod
    async def load_post(
        cls, tweet_id: str, include_ancestors: bool = True
    ) -> Post | None:
        if include_ancestors:
            content_understanding_metadata = (
                await cls.content_understanding_metadata_strato.fetch(int(tweet_id))
            )
            if content_understanding_metadata:
                post = PostMapper.from_strato_content_understanding_metadata(
                    content_understanding_metadata
                )
                return post
        else:
            post_with_quote_metadata = (
                await cls.content_understanding_post_quote_metadata_strato.fetch(
                    int(tweet_id)
                )
            )
            if post_with_quote_metadata:
                post = PostMapper.from_strato_post_with_quote_metadata(
                    post_with_quote_metadata
                )
                return post
        return None
Load a post by tweet ID. Two query variants:
- include_ancestors=True (default): use the V2 endpoint that returns content metadata for this tweet and its ancestors (parent for replies, original for quotes). Most expensive, useful when computing reply features.
- include_ancestors=False: use the lighter endpoint that returns the tweet + just its quote metadata. Cheaper for tasks that only need this one post.
Both go through PostMapper (in mappers/post_mapper.py — outside our scope but probably a static class with thrift-to-pydantic mappers).
Note these queries are classmethods on class-level singletons — the actual Strato client objects (StratoContentUnderstandingMetadataV2()) are created once when the module loads, and reused for every call.
UserStratoLoader
class UserStratoLoader:
    strato = StratoContentUnderstandingAuthorMetadata()

    @classmethod
    async def load_user(cls, user_id: int) -> User | None:
        strato_user = await cls.strato.fetch(user_id)
        if not strato_user:
            logger.warning(f"failed to hydrate user with {user_id=}, not found")
            return None
        return PostMapper._from_strato_user_metadata_to_user(strato_user)
One-shot user fetch by ID. Returns None if not found.
ReplyRankingScoreStratoLoader
class ReplyRankingScoreStratoLoader:
strato = StratoReplyRankingScore()
reply_ranking_v2_kafka_strato = StratoReplyRankingScoreV2Kafka()
@classmethod
async def save_reply_ranking_score(
cls, post_id: str, reply_ranking_score: ReplyRankingScore
):
await cls.strato.put(int(post_id), reply_ranking_score)
@classmethod
async def save_reply_ranking_kafka_v2(
cls, post_id: str, reply_ranking_score_kafka: ReplyRankingScoreKafka
):
await cls.reply_ranking_v2_kafka_strato.insert(
int(post_id), reply_ranking_score_kafka
)
Save reply-ranking scores. Two writers because there are two destinations: a Manhattan KV (StratoReplyRankingScore.put) and a Kafka topic (StratoReplyRankingScoreV2Kafka.insert). The _kafka_v2 is the newer rollout — both probably write in parallel during the migration period.
ReplySpamStratoLoader
class ReplySpamStratoLoader:
strato = StratoReplySpamAnnotation()
@classmethod
async def save_spam_reply_annotation(
cls, post_id: str, score: float, positive: bool, reason: str
):
await cls.strato.put(int(post_id), score, positive, reason)
Save a spam-annotation result: a score (continuous), a positive flag (boolean classification), and a reason (text explanation from the classifier). The positive boolean is the binarized version of the score against some threshold; storing both means downstream code can re-threshold without re-classifying.
UserRecentPostsLoader
class UserRecentPostsLoader:
recent_posts_strato = StratoUserRecentPosts()
post_hydrator = StratoContentUnderstandingPostQuoteMetadata()
safety_label = StratoSafetyLabel()
@classmethod
async def load(cls, user_id: int, limit: int = 10) -> list[Post]:
res = await cls.recent_posts_strato.fetch(
user_id, limit=limit, max_per_type=limit
)
if not res or "v" not in res:
logger.warning(f"No recent posts found for {user_id=}")
return []
post_ids: list[int] = []
for _post_type, posts in res["v"]:
for post in posts:
if _post_type == "TypeRetweet":
if "inReactionToPostId" in post:
post_ids.append(post["inReactionToPostId"])
else:
if "postId" in post:
post_ids.append(post["postId"])
Load the K most recent posts by a user. Step 1: fetch the metadata (post IDs grouped by type). Step 2: walk the result, picking the post ID — for retweets, that's inReactionToPostId (the original tweet's ID, not the retweet event), for everything else it's postId.
The result shape is a magic dict with key "v" containing a list of (post_type, posts_list) tuples — typical Strato Thrift-flavored shape.
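The ID-picking walk can be pulled out into a pure function and exercised on made-up data. A sketch only: extract_post_ids is a hypothetical name, and the "TypeOriginal" label in the sample is invented (only "TypeRetweet" appears in the source).

```python
from typing import Any


def extract_post_ids(res: dict) -> list:
    """Collect post IDs from a {"v": [(post_type, posts), ...]} result."""
    post_ids: list = []
    for post_type, posts in res.get("v", []):
        # Retweets point at the original post; everything else uses its own ID.
        key = "inReactionToPostId" if post_type == "TypeRetweet" else "postId"
        for post in posts:
            if key in post:
                post_ids.append(post[key])
    return post_ids


# Invented sample data in the shape described above.
sample: dict = {
    "v": [
        ("TypeOriginal", [{"postId": 101}, {"postId": 102}]),
        ("TypeRetweet", [{"postId": 201, "inReactionToPostId": 55}, {"postId": 202}]),
    ]
}
post_ids = extract_post_ids(sample)
```

The second retweet entry has no inReactionToPostId, so it is skipped rather than falling back to its own postId, matching the original branch structure.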
if not post_ids:
return []
tasks = [cls.post_hydrator.fetch(post_id) for post_id in post_ids]
results = await asyncio.gather(*tasks, return_exceptions=True)
hydrated: list[Post] = []
for post_id, result in zip(post_ids, results):
if isinstance(result, Exception):
logger.warning(
f"Failed to hydrate recent post {post_id} for {user_id=}: {result}"
)
continue
if result is None:
continue
try:
hydrated.append(PostMapper.from_strato_post_with_quote_metadata(result))
except Exception:
logger.warning(
f"Failed to map recent post {post_id} for {user_id=}", exc_info=True
)
for post in hydrated:
post.safety_labels = await cls.safety_label.scan(post.id)
return hydrated
Step 3: hydrate each post ID in parallel via asyncio.gather(..., return_exceptions=True) — the return_exceptions=True is important: a single failed hydration (post deleted, query timed out) shouldn't fail the whole batch.
Step 4: filter results, map to Pydantic Posts, skipping failures.
Step 5: serially scan safety labels for each hydrated post. This is O(N) Strato calls in serial — could be parallelized with another gather, but presumably the labels scan is fast enough that serializing is fine. (Or: the scan is O(scan size) and parallelizing wouldn't help if upstream is the bottleneck.)
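Steps 3 to 5 can be sketched with stand-in coroutines. This is a variation, not the original: fetch_post and scan_labels are hypothetical fakes for the Strato clients, and the safety-label scan is done with a second gather instead of the serial loop.

```python
import asyncio
from typing import Optional


async def fetch_post(post_id: int) -> Optional[dict]:
    """Hypothetical hydrator: fails for ID 2, finds nothing for ID 3."""
    await asyncio.sleep(0)
    if post_id == 2:
        raise TimeoutError("query timed out")
    if post_id == 3:
        return None
    return {"id": post_id}


async def scan_labels(post_id: int) -> list:
    """Hypothetical safety-label scan."""
    await asyncio.sleep(0)
    return [f"label-for-{post_id}"]


async def load_recent(post_ids: list) -> list:
    results = await asyncio.gather(
        *(fetch_post(pid) for pid in post_ids), return_exceptions=True
    )
    # Exceptions come back as values, so one failure does not sink the batch.
    hydrated = [r for r in results if r is not None and not isinstance(r, Exception)]
    # Variation on the original: scan labels concurrently instead of one by one.
    labels = await asyncio.gather(*(scan_labels(p["id"]) for p in hydrated))
    for post, post_labels in zip(hydrated, labels):
        post["safety_labels"] = post_labels
    return hydrated


posts = asyncio.run(load_recent([1, 2, 3, 4]))
```

Note that the second gather here does not pass return_exceptions=True, so a failed scan would fail the whole load; whether that is acceptable depends on how critical the labels are to the caller.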
Used in tasks that need a user's posting history as context (e.g., spam detection that wants to see if this user habitually posts spam).
data_loaders/asr_processor.py (394 lines)
The biggest single file in this session. ASR = Automatic Speech Recognition. Transcribes video audio so we can include it in the multimodal embedding.
The architecture: separate processes running ffmpeg + an LLM-based transcription endpoint. ffmpeg is a heavy native subprocess; spinning it up per video and running multiple in parallel saturates IO and CPU. Doing this in the engine process would block the asyncio loop, so it's isolated.
Request / Result types
import asyncio
import base64
import logging
import os
import subprocess
import tempfile
import time
import traceback
from multiprocessing import Event, Process, Queue
from multiprocessing.synchronize import Event as MultiprocessingEvent
from queue import Empty
import aiohttp
from cachetools import TTLCache
from pydantic import BaseModel
from grox.config.config import grox_config
from grox.schedules.init import init_proc
from monitor.logging import Logging
from monitor.metrics import Metrics
logger = logging.getLogger(__name__)
class _ASRRequest(BaseModel):
post_id: str
video_url: str
max_audio_duration_s: float
class _ASRResult(BaseModel):
post_id: str
transcript: str | None = None
error: str | None = None
Pydantic models so they pickle cleanly across process boundaries. Errors return as error="..." rather than as exceptions — exceptions don't pickle well, and we want the calling code to see "what went wrong" not "stack trace from inside ffmpeg."
Extract WAV from a remote video URL
def _extract_wav_from_url(
video_url: str, max_duration_s: float | None = None
) -> bytes | None:
with tempfile.TemporaryDirectory() as tmpdir:
wav_path = os.path.join(tmpdir, "audio.wav")
cmd = [
"ffmpeg",
"-y",
"-timeout",
"60000000",
"-rw_timeout",
"60000000",
"-reconnect",
"1",
"-reconnect_streamed",
"1",
"-reconnect_delay_max",
"5",
"-i",
video_url,
"-vn",
"-acodec",
"pcm_s16le",
"-ar",
"16000",
"-ac",
"1",
]
if max_duration_s is not None and max_duration_s > 0:
cmd += ["-t", str(max_duration_s)]
cmd.append(wav_path)
result = subprocess.run(cmd, capture_output=True, timeout=180)
if result.returncode != 0:
if not os.path.exists(wav_path):
return None
raise subprocess.CalledProcessError(
result.returncode, cmd, result.stdout, result.stderr
)
with open(wav_path, "rb") as f:
return f.read()
Build an ffmpeg command:
- -y: overwrite output
- -timeout, -rw_timeout: 60-second I/O timeouts (in microseconds, that's 60_000_000)
- -reconnect 1, -reconnect_streamed 1, -reconnect_delay_max 5: retry HTTP on transient failures
- -i <url>: input (ffmpeg knows HTTP)
- -vn: no video output
- -acodec pcm_s16le: 16-bit signed little-endian PCM
- -ar 16000: 16kHz sample rate (standard for ASR)
- -ac 1: mono
- -t <duration>: cap audio at max_duration_s
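A Python subprocess wrapper around ffmpeg, with a 180-second total timeout (if it elapses, subprocess.run raises TimeoutExpired, which the caller handles). Otherwise there are three outcomes:
- Returncode != 0 and the output file doesn't exist → return None. The caller reads this as "no audio stream" (image-only posts, for instance), though any failure that leaves no output file, such as an unreachable URL, takes the same path.
- Returncode != 0 and the file exists → raise CalledProcessError (partial output, treat as a real failure).
- Success → read and return the WAV bytes.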
The temp dir auto-cleans on context exit.
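One way to make the flag list checkable without ffmpeg installed is to split command construction from execution. A sketch under that assumption: build_ffmpeg_cmd is a hypothetical helper, not a function in the source, and it only reproduces the argument list shown above.

```python
from typing import Optional


def build_ffmpeg_cmd(
    video_url: str, wav_path: str, max_duration_s: Optional[float] = None
) -> list:
    """Build the ffmpeg argv for extracting 16 kHz mono PCM audio."""
    cmd = [
        "ffmpeg", "-y",
        "-timeout", "60000000",      # microseconds
        "-rw_timeout", "60000000",   # microseconds
        "-reconnect", "1",
        "-reconnect_streamed", "1",
        "-reconnect_delay_max", "5",
        "-i", video_url,
        "-vn",
        "-acodec", "pcm_s16le",
        "-ar", "16000",
        "-ac", "1",
    ]
    # The duration cap is optional and ignored when zero or negative.
    if max_duration_s is not None and max_duration_s > 0:
        cmd += ["-t", str(max_duration_s)]
    cmd.append(wav_path)  # the output path always comes last
    return cmd


capped = build_ffmpeg_cmd("https://example.com/v.mp4", "/tmp/audio.wav", 30.0)
uncapped = build_ffmpeg_cmd("https://example.com/v.mp4", "/tmp/audio.wav")
```

Passing the command as a list (rather than a shell string) also means the video URL is never interpreted by a shell.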
Clean ASR output
def _clean_asr(raw: str) -> str:
if "<asr_text>" in raw:
raw = raw.split("<asr_text>", 1)[1]
if "</asr_text>" in raw:
raw = raw.split("</asr_text>", 1)[0]
return raw.strip()
The ASR model is prompt-engineered to return text wrapped in <asr_text>...</asr_text> tags. This function unwraps. Defensive splits — only split if the tag exists, only on first occurrence.
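To see the defensive splits in action, here is the same function run on a few invented inputs: fully wrapped, not wrapped at all, and truncated before the closing tag.

```python
def _clean_asr(raw: str) -> str:
    # Same logic as the function above, repeated so the example is self-contained.
    if "<asr_text>" in raw:
        raw = raw.split("<asr_text>", 1)[1]
    if "</asr_text>" in raw:
        raw = raw.split("</asr_text>", 1)[0]
    return raw.strip()


# Sample inputs are invented for illustration.
wrapped = _clean_asr("noise <asr_text> hello world </asr_text> trailing")
unwrapped = _clean_asr("  plain transcript  ")
unclosed = _clean_asr("<asr_text>cut off mid-sen")
```

The truncated case matters in practice: if the model hits max_tokens before emitting the closing tag, the partial transcript still comes through.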
Worker process
class _ASRWorker:
def __init__(
self, task_queue: Queue, resp_queue: Queue, shutdown_event: MultiprocessingEvent
):
self._task_queue: Queue[tuple[_ASRRequest, dict[str, str]]] = task_queue
self._resp_queue: Queue[_ASRResult] = resp_queue
self._shutdown_event: MultiprocessingEvent = shutdown_event
The worker grabs three handles from the parent. The task queue carries (request, log_context) tuples — the log context is passed across the process boundary so we can stamp the worker's logs with the right post/user IDs.
async def _transcribe(self, request: _ASRRequest) -> str | None:
asr_config = grox_config.asr
t_start = time.monotonic()
loop = asyncio.get_event_loop()
wav_bytes = await loop.run_in_executor(
None, _extract_wav_from_url, request.video_url, request.max_audio_duration_s
)
t_extract = time.monotonic() - t_start
Metrics.histogram("asr_proc.extract_duration_s").record(t_extract)
if wav_bytes is None:
return None
Metrics.histogram("asr_proc.audio_bytes").record(len(wav_bytes))
logger.debug(
f"Extracted audio in {t_extract:.2f}s, size={len(wav_bytes)} bytes"
)
loop.run_in_executor(None, ...) runs the blocking ffmpeg call in a default thread pool — keeps the worker's asyncio loop free to do other things (like accept more tasks).
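A small sketch of why the executor matters: blocking_extract is a hypothetical stand-in for the ffmpeg call (it just sleeps), and heartbeat shows the loop still making progress while the blocking call runs in a thread.

```python
import asyncio
import time


def blocking_extract(seconds: float) -> str:
    """Stand-in for the blocking ffmpeg call."""
    time.sleep(seconds)
    return "wav-bytes"


async def heartbeat(ticks: list) -> None:
    """Shows the loop keeps running while the blocking call is in flight."""
    for i in range(3):
        ticks.append(i)
        await asyncio.sleep(0.01)


async def executor_demo() -> tuple:
    ticks: list = []
    loop = asyncio.get_running_loop()
    # None selects the loop's default ThreadPoolExecutor.
    extract = loop.run_in_executor(None, blocking_extract, 0.1)
    result, _ = await asyncio.gather(extract, heartbeat(ticks))
    return result, ticks


result, ticks = asyncio.run(executor_demo())
```

Calling blocking_extract directly inside the coroutine would freeze the heartbeat for the whole 0.1 seconds; in the worker, that would mean no new requests get picked up while ffmpeg runs.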
t_start = time.monotonic()
b64_audio = base64.b64encode(wav_bytes).decode()
body = {
"model": "default",
"messages": [
{
"role": "user",
"content": [
{
"type": "audio_url",
"audio_url": {"url": f"data:audio/wav;base64,{b64_audio}"},
},
{"type": "text", "text": "Transcribe this audio."},
],
}
],
"temperature": asr_config.temperature,
"max_tokens": asr_config.max_tokens,
}
async with self._session.post(
f"{asr_config.endpoint}/v1/chat/completions",
json=body,
timeout=aiohttp.ClientTimeout(total=asr_config.timeout),
) as resp:
POST to an OpenAI-compatible chat-completions endpoint with the audio as a data:audio/wav;base64,... URL inside the message content. This is the multimodal-LLM API contract. The model name is "default" — the actual model is configured server-side.
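The body construction can be sketched as a pure function, which makes the data-URL encoding easy to verify by round-tripping. build_asr_body is a hypothetical helper; the temperature and max_tokens values below are placeholders, not the configured ones.

```python
import base64


def build_asr_body(wav_bytes: bytes, temperature: float, max_tokens: int) -> dict:
    """Build a chat-completions body with the audio inlined as a data URL."""
    b64_audio = base64.b64encode(wav_bytes).decode()
    return {
        "model": "default",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "audio_url",
                        "audio_url": {"url": f"data:audio/wav;base64,{b64_audio}"},
                    },
                    {"type": "text", "text": "Transcribe this audio."},
                ],
            }
        ],
        "temperature": temperature,
        "max_tokens": max_tokens,
    }


# Placeholder bytes and settings, for illustration only.
body = build_asr_body(b"RIFF", temperature=0.0, max_tokens=256)
url = body["messages"][0]["content"][0]["audio_url"]["url"]
decoded = base64.b64decode(url.split(",", 1)[1])
```

Base64 turns every 3 bytes into 4 characters, so the request is about a third larger than the WAV itself. At 16 kHz mono 16-bit, audio is 32,000 bytes per second, so a 60-second clip is roughly 1.9 MB raw and about 2.6 MB once encoded.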
if resp.status == 200:
data = await resp.json()
raw_transcript = data["choices"][0]["message"]["content"].strip()
transcript = _clean_asr(raw_transcript)
t_transcribe = time.monotonic() - t_start
Metrics.histogram("asr_proc.transcribe_duration_s").record(t_transcribe)
Metrics.histogram("asr_proc.transcript_chars").record(len(transcript))
if "usage" in data:
usage = data["usage"]
if "prompt_tokens" in usage:
Metrics.histogram("asr_proc.prompt_tokens").record(
usage["prompt_tokens"]
)
if "completion_tokens" in usage:
Metrics.histogram("asr_proc.completion_tokens").record(
usage["completion_tokens"]
)
if "total_tokens" in usage:
Metrics.histogram("asr_proc.total_tokens").record(
usage["total_tokens"]
)
return transcript
else:
error_text = await resp.text()
raise Exception(
f"ASR request failed with status {resp.status}: {error_text}"
)
Parse the response, clean the transcript, record metrics including token usage from the response's usage field (so we can monitor cost/budget).
Non-200 → raise — caught and handled in _process.
Process a single request with rich error handling
async def _process(self, request: _ASRRequest, ctx: dict[str, str]) -> None:
attributes = {"pid": str(os.getpid())}
with Metrics.tracer("asr_proc").start_as_current_span("asr.process"):
Logging.set_context(**ctx)
start = time.perf_counter()
try:
Metrics.counter("asr_proc.total.count").add(1, attributes=attributes)
transcript = await self._transcribe(request)
if transcript is None:
logger.debug(
f"Video has no audio stream for post {request.post_id}, skipping ASR"
)
Metrics.counter("asr_proc.skip.count").add(
1, attributes={**attributes, "reason": "no_audio_stream"}
)
self._resp_queue.put(
_ASRResult(post_id=request.post_id, error="no_audio_stream")
)
else:
Metrics.counter("asr_proc.success.count").add(
1, attributes=attributes
)
self._resp_queue.put(
_ASRResult(post_id=request.post_id, transcript=transcript)
)
Run the transcription, dispatch on outcome:
- None = no audio stream (image-only post). Treat as a skip; emit a result with error="no_audio_stream" so the caller knows.
- transcript = success.
except subprocess.TimeoutExpired:
logger.warning(
f"FFmpeg timeout extracting audio for post {request.post_id}"
)
Metrics.counter("asr_proc.error.count").add(
1, attributes={**attributes, "reason": "ffmpeg_timeout"}
)
self._resp_queue.put(
_ASRResult(post_id=request.post_id, error="ffmpeg_timeout")
)
except subprocess.CalledProcessError as e:
error_msg = e.stderr.decode() if e.stderr else str(e)
logger.warning(f"FFmpeg error for post {request.post_id}: {error_msg}")
Metrics.counter("asr_proc.error.count").add(
1, attributes={**attributes, "reason": "ffmpeg_error"}
)
self._resp_queue.put(
_ASRResult(
post_id=request.post_id, error=f"ffmpeg_error: {error_msg}"
)
)
except asyncio.TimeoutError:
logger.warning(f"ASR request timed out for post {request.post_id}")
Metrics.counter("asr_proc.error.count").add(
1, attributes={**attributes, "reason": "asr_timeout"}
)
self._resp_queue.put(
_ASRResult(post_id=request.post_id, error="asr_timeout")
)
except Exception as e:
logger.error(
f"ASR processing failed for post {request.post_id}: {traceback.format_exc()}"
)
Metrics.counter("asr_proc.error.count").add(
1, attributes={**attributes, "reason": "unknown"}
)
self._resp_queue.put(_ASRResult(post_id=request.post_id, error=str(e)))
finally:
end = time.perf_counter()
Metrics.histogram("asr_proc.duration").record(end - start)
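Three named exception arms (subprocess.TimeoutExpired, subprocess.CalledProcessError, asyncio.TimeoutError, each with its own metric reason label) + a catch-all. Every path through _process puts an _ASRResult on the response queue, so as long as the worker process itself survives, the caller's future gets resolved.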
The metric attributes include reason: "ffmpeg_timeout" etc., so the SRE team can break down ASR failures by cause without grep-ing logs.
Worker process main loop
async def _init_run(self) -> None:
await init_proc("asr_proc")
self._session = aiohttp.ClientSession()
async def _run(self) -> None:
logger.info("starting ASR worker process loop")
pending: set[asyncio.Task] = set()
while not self._is_shutdown() or not self._task_queue.empty():
try:
request, ctx = self._task_queue.get(block=False)
except Empty:
await asyncio.sleep(0.01)
continue
try:
task = asyncio.create_task(self._process(request, ctx))
pending.add(task)
task.add_done_callback(pending.discard)
except Exception:
logger.error(
f"error processing ASR request {request.post_id}: {traceback.format_exc()}"
)
if pending:
logger.info(f"ASR worker draining {len(pending)} in-flight tasks")
await asyncio.gather(*pending, return_exceptions=True)
logger.warning("ASR worker process loop done")
Same engine-style loop: poll the queue, create_task per request, track pending tasks via a set with add_done_callback(pending.discard) for cleanup. On shutdown, drain pending before exiting.
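The pending-set idiom can be sketched on its own. The event loop only keeps weak references to tasks, so holding them in a set prevents them from being garbage-collected mid-flight, and the done-callback keeps the set from growing without bound. The work coroutine is a made-up placeholder for _process.

```python
import asyncio


async def work(n: int) -> int:
    await asyncio.sleep(0.01 * n)
    return n


async def pending_demo() -> tuple:
    pending: set = set()
    for n in range(3):
        task = asyncio.create_task(work(n))
        pending.add(task)  # strong reference keeps the task alive
        task.add_done_callback(pending.discard)  # self-cleanup on completion
    in_flight = len(pending)
    # Drain: snapshot the set, since callbacks mutate it while we wait.
    await asyncio.gather(*list(pending), return_exceptions=True)
    await asyncio.sleep(0)  # let any remaining done-callbacks run
    return in_flight, len(pending)


in_flight, remaining = asyncio.run(pending_demo())
```

With return_exceptions=True the drain waits for every task even if some of them raise, which is what you want on shutdown.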
def run(self) -> None:
async def wrapper():
await self._init_run()
try:
await self._run()
finally:
await self._session.close()
asyncio.run(wrapper())
def _start_loop(self) -> Process:
process = Process(target=self.run)
process.start()
return process
def start(self) -> list[Process]:
return [self._start_loop() for _ in range(grox_config.asr.max_workers)]
start() fans out to grox_config.asr.max_workers worker processes. They all share the same task queue, response queue, and shutdown event (each process holds its own handle to them). Python's multiprocessing.Queue is process-safe, so multiple consumers can race to pull from it and each request is delivered to exactly one worker.
def _is_shutdown(self) -> bool:
try:
return self._shutdown_event.is_set()
except BrokenPipeError:
logger.error("Broken pipe error, assuming shutdown")
return True
except Exception:
logger.error(
f"Error checking shutdown event, assuming shutdown: {traceback.format_exc()}"
)
return True
Same shutdown pattern as engine/dispatcher.
Parent-side ASRProcessor
class ASRProcessor:
_task_queue: Queue = Queue()
_resp_queue: Queue = Queue()
_shutdown_event = Event()
_inflights: dict[str, asyncio.Future[str | None]] = {}
_workers: list[Process] = []
_initialized = False
_result_task: asyncio.Task | None = None
_cache: TTLCache = TTLCache(maxsize=1_000, ttl=300)
A singleton via classmethods. _cache is a TTL cache (1000 entries, 5 min) — if multiple tasks ask for the same post's transcript within 5 minutes, only one ASR run happens.
@classmethod
async def process(
cls, post_id: str, video_url: str, max_audio_duration_s: float | None = None
) -> str | None:
if not cls._initialized:
raise RuntimeError("ASR processor not initialized")
cached = cls._cache.get(post_id)
if cached is not None:
Metrics.counter("asr_proc.cache_hit.count").add(1)
return cached
if max_audio_duration_s is None:
max_audio_duration_s = grox_config.asr.max_audio_duration_s
future = cls._submit(post_id, video_url, max_audio_duration_s)
transcript = await future
if transcript is not None:
cls._cache[post_id] = transcript
return transcript
The public API. Cache-first; on miss, submit a request and wait on the future. Cache only on success (so failures don't poison the cache for 5 minutes).
@classmethod
def _submit(
cls, post_id: str, video_url: str, max_audio_duration_s: float
) -> asyncio.Future[str | None]:
if post_id in cls._inflights:
return cls._inflights[post_id]
request = _ASRRequest(
post_id=post_id,
video_url=video_url,
max_audio_duration_s=max_audio_duration_s,
)
cls._task_queue.put((request, Logging.get_context()))
future: asyncio.Future[str | None] = asyncio.get_running_loop().create_future()
cls._inflights[post_id] = future
return future
Request deduplication: if the same post_id is already in flight, return the existing future. Two concurrent callers both end up awaiting the same future, and the second caller doesn't trigger a second ASR job. Combined with the TTL cache, this gives you roughly "one ASR run per post per 5-minute window". It is not a hard guarantee: failures are never cached and the cache holds at most 1,000 entries, so a post can be transcribed again sooner.
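The in-flight dedup can be sketched with plain asyncio. The names here (submit, fake_worker, waiter, the submitted list) are hypothetical; the submitted list stands in for the multiprocessing task queue and fake_worker for the result loop.

```python
import asyncio

inflights: dict = {}
submitted: list = []  # records what would go onto the task queue


def submit(post_id: str) -> asyncio.Future:
    """Return the in-flight future for post_id, creating it on first request."""
    if post_id in inflights:
        return inflights[post_id]
    submitted.append(post_id)
    future = asyncio.get_running_loop().create_future()
    inflights[post_id] = future
    return future


async def waiter(future: asyncio.Future) -> str:
    return await future


async def fake_worker() -> None:
    """Stand-in for the result loop: resolve the future after a short delay."""
    await asyncio.sleep(0.01)
    inflights.pop("post-1").set_result("transcript")


async def dedup_demo() -> list:
    first = submit("post-1")
    second = submit("post-1")  # same future, no second submission
    results = await asyncio.gather(waiter(first), waiter(second), fake_worker())
    return results[:2]


results = asyncio.run(dedup_demo())
```

Both waiters receive the same transcript from a single submission. The flip side is that both also share the same failure: if the future resolves to None, every caller waiting on it gets None.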
Logging.get_context() captures the current task/post/user log tags so they can be replayed in the worker process.
@classmethod
async def _result_loop(cls) -> None:
logger.info("ASR processor result loop started")
while not cls._shutdown_event.is_set() or cls._inflights:
try:
result: _ASRResult = cls._resp_queue.get(block=False)
future = cls._inflights.pop(result.post_id, None)
if not future:
logger.warning(f"no future found for post {result.post_id}")
continue
if result.error:
if result.error == "no_audio_stream":
logger.debug(
f"ASR skipped for post {result.post_id}: no audio stream"
)
else:
logger.warning(
f"ASR failed for post {result.post_id}: {result.error}"
)
future.set_result(None)
else:
future.set_result(result.transcript)
except Empty:
await asyncio.sleep(0.01)
except Exception:
logger.error(f"Error processing ASR result: {traceback.format_exc()}")
logger.warning("ASR processor result loop done")
Read from the response queue, find the corresponding future, set it. Error → set_result(None) (treat both errors and no-audio-stream as "no transcript"). The caller gets None and decides whether to proceed without ASR.
@classmethod
def start(cls) -> None:
if cls._initialized:
logger.warning("ASR processor already initialized")
return
logger.info(
f"starting ASR processor with {grox_config.asr.max_workers} workers"
)
cls._workers = _ASRWorker(
cls._task_queue, cls._resp_queue, cls._shutdown_event
).start()
cls._result_task = asyncio.create_task(cls._result_loop())
cls._initialized = True
@classmethod
async def stop(cls, timeout: float = 5) -> None:
logger.warning("stopping ASR processor")
cls._shutdown_event.set()
for worker in cls._workers:
if worker.is_alive():
worker.join(timeout)
if cls._result_task and not cls._result_task.done():
cls._result_task.cancel()
logger.warning("ASR processor stopped")
Start: spawn the worker processes via _ASRWorker.start() (returns a list of Process objects), kick off the result loop.
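Stop: signal shutdown, wait for each worker (with a per-worker timeout, so the total stop wait is at most N × 5s by default), cancel the result loop. worker.join is a blocking call, so the event loop is stalled while stop() waits.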
Note start() is called from Engine._init_run (Session 18); the engine owns the ASRProcessor. The dispatcher and main don't know about it.
Summary
Plans are the dependency-DAG executor. Each plan has a set of named tasks and a dependency map between them; execute() launches them all in parallel via asyncio.gather, each task awaiting on its dependencies' asyncio futures. The PlanMaster runs all 9 plans in parallel per task; only the ones whose REQUIRED_ELIGIBILITY is in the task's eligibility set actually do work. Results merge by concatenating content categories, taking the first multimodal embedding, and AND'ing the success flags.
The DAG pattern is stage-pipelined: filter → rate_limit → media_hydration → ML inference → publish, with parallel forks at the publish stage. Across plans, common stages (e.g., task_media_hydration) are reused as classes — the same class shows up in multiple plans' TASKS dicts but is instantiated per-plan via Task.exec(ctx) calls (we'll see that in Session 21).
Data loaders split two ways:
- Kafka is the streaming inbound path, with batched prefetch + thread-pool Thrift decode + auto-commit (at-least-once delivery). Different _messages_to_payloads overrides handle different schemas (post / embedding-request / content-analysis / tweet-embedding).
- Strato is the on-demand RPC path, used by tasks that need to fetch additional context (parent posts, author info, user history) or write results back (reply-ranking scores, spam annotations).
ASRProcessor is a beautifully orthogonal subsystem: separate worker processes running ffmpeg + a multimodal-LLM transcription endpoint, with cache + dedup at the public API. Worth its own session-sized file (394 LOC) because video transcription has so many failure modes: three named exceptions plus a catch-all, and "no audio stream" is treated as a distinct skip case rather than an error.
Next session
Session 20 — Grox embedder + summarizer + classifiers (~1,450 LOC).
- grox/embedder/*: the multimodal embedding pipeline (image + text + audio → vector).
- grox/summarizer/*: the LLM-based post-content summarizer.
- grox/classifiers/*: banger / safety / spam / reply-ranking classifiers (all LLM calls with structured-output parsing).
This is where the actual ML model calls live. Plans declare what to do; classifiers implement how to do it.