Repository: xai-org/x-algorithm
Branch: main
Commit: aaa167b3de8a
Files: 78
Total size: 300.0 KB
Directory structure:
gitextract_bsr8dmh8/
├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── candidate-pipeline/
│ ├── candidate_pipeline.rs
│ ├── filter.rs
│ ├── hydrator.rs
│ ├── lib.rs
│ ├── query_hydrator.rs
│ ├── scorer.rs
│ ├── selector.rs
│ ├── side_effect.rs
│ └── source.rs
├── home-mixer/
│ ├── candidate_hydrators/
│ │ ├── core_data_candidate_hydrator.rs
│ │ ├── gizmoduck_hydrator.rs
│ │ ├── in_network_candidate_hydrator.rs
│ │ ├── mod.rs
│ │ ├── subscription_hydrator.rs
│ │ ├── vf_candidate_hydrator.rs
│ │ └── video_duration_candidate_hydrator.rs
│ ├── candidate_pipeline/
│ │ ├── candidate.rs
│ │ ├── candidate_features.rs
│ │ ├── mod.rs
│ │ ├── phoenix_candidate_pipeline.rs
│ │ ├── query.rs
│ │ └── query_features.rs
│ ├── filters/
│ │ ├── age_filter.rs
│ │ ├── author_socialgraph_filter.rs
│ │ ├── core_data_hydration_filter.rs
│ │ ├── dedup_conversation_filter.rs
│ │ ├── drop_duplicates_filter.rs
│ │ ├── ineligible_subscription_filter.rs
│ │ ├── mod.rs
│ │ ├── muted_keyword_filter.rs
│ │ ├── previously_seen_posts_filter.rs
│ │ ├── previously_served_posts_filter.rs
│ │ ├── retweet_deduplication_filter.rs
│ │ ├── self_tweet_filter.rs
│ │ └── vf_filter.rs
│ ├── lib.rs
│ ├── main.rs
│ ├── query_hydrators/
│ │ ├── mod.rs
│ │ ├── user_action_seq_query_hydrator.rs
│ │ └── user_features_query_hydrator.rs
│ ├── scorers/
│ │ ├── author_diversity_scorer.rs
│ │ ├── mod.rs
│ │ ├── oon_scorer.rs
│ │ ├── phoenix_scorer.rs
│ │ └── weighted_scorer.rs
│ ├── selectors/
│ │ ├── mod.rs
│ │ └── top_k_score_selector.rs
│ ├── server.rs
│ ├── side_effects/
│ │ ├── cache_request_info_side_effect.rs
│ │ └── mod.rs
│ └── sources/
│ ├── mod.rs
│ ├── phoenix_source.rs
│ └── thunder_source.rs
├── phoenix/
│ ├── README.md
│ ├── grok.py
│ ├── pyproject.toml
│ ├── recsys_model.py
│ ├── recsys_retrieval_model.py
│ ├── run_ranker.py
│ ├── run_retrieval.py
│ ├── runners.py
│ ├── test_recsys_model.py
│ └── test_recsys_retrieval_model.py
└── thunder/
  ├── deserializer.rs
  ├── kafka/
  │ ├── mod.rs
  │ ├── tweet_events_listener.rs
  │ ├── tweet_events_listener_v2.rs
  │ └── utils.rs
  ├── kafka_utils.rs
  ├── lib.rs
  ├── main.rs
  ├── posts/
  │ ├── mod.rs
  │ └── post_store.rs
  └── thunder_service.rs
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__/
================================================
FILE: CODE_OF_CONDUCT.md
================================================
Be excellent to each other.
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# X For You Feed Algorithm
This repository contains the core recommendation system powering the "For You" feed on X. It combines in-network content (from accounts you follow) with out-of-network content (discovered through ML-based retrieval) and ranks everything using a Grok-based transformer model.
> **Note:** The transformer implementation is ported from the [Grok-1 open source release](https://github.com/xai-org/grok-1) by xAI, adapted for recommendation system use cases.
## Table of Contents
- [Overview](#overview)
- [System Architecture](#system-architecture)
- [Components](#components)
- [Home Mixer](#home-mixer)
- [Thunder](#thunder)
- [Phoenix](#phoenix)
- [Candidate Pipeline](#candidate-pipeline)
- [How It Works](#how-it-works)
- [Pipeline Stages](#pipeline-stages)
- [Scoring and Ranking](#scoring-and-ranking)
- [Filtering](#filtering)
- [Key Design Decisions](#key-design-decisions)
- [License](#license)
---
## Overview
The For You feed algorithm retrieves, ranks, and filters posts from two sources:
1. **In-Network (Thunder)**: Posts from accounts you follow
2. **Out-of-Network (Phoenix Retrieval)**: Posts discovered from a global corpus
Both sources are combined and ranked together using **Phoenix**, a Grok-based transformer model that predicts engagement probabilities for each post. The final score is a weighted combination of these predicted engagements.
We have eliminated every single hand-engineered feature and most heuristics from the system. The Grok-based transformer does all the heavy lifting by understanding your engagement history (what you liked, replied to, shared, etc.) and using that to determine what content is relevant to you.
---
## System Architecture
```
┌────────────────────────────────────────────────────────────────────────────────────────┐
│                                  FOR YOU FEED REQUEST                                  │
└────────────────────────────────────────────────────────────────────────────────────────┘
                                            │
                                            ▼
┌────────────────────────────────────────────────────────────────────────────────────────┐
│                                       HOME MIXER                                       │
│                                 (Orchestration Layer)                                  │
├────────────────────────────────────────────────────────────────────────────────────────┤
│                                                                                        │
│  ┌──────────────────────────────────────────────────────────────────────────────────┐  │
│  │                                 QUERY HYDRATION                                  │  │
│  │ ┌────────────────────────────┐  ┌──────────────────────────────────────────────┐ │  │
│  │ │    User Action Sequence    │  │                User Features                 │ │  │
│  │ │    (engagement history)    │  │     (following list, preferences, etc.)      │ │  │
│  │ └────────────────────────────┘  └──────────────────────────────────────────────┘ │  │
│  └──────────────────────────────────────────────────────────────────────────────────┘  │
│                                           │                                            │
│                                           ▼                                            │
│  ┌──────────────────────────────────────────────────────────────────────────────────┐  │
│  │                                CANDIDATE SOURCES                                 │  │
│  │ ┌────────────────────────────────────┐    ┌────────────────────────────────────┐ │  │
│  │ │              THUNDER               │    │         PHOENIX RETRIEVAL          │ │  │
│  │ │         (In-Network Posts)         │    │       (Out-of-Network Posts)       │ │  │
│  │ │                                    │    │                                    │ │  │
│  │ │        Posts from accounts         │    │     ML-based similarity search     │ │  │
│  │ │             you follow             │    │        across global corpus        │ │  │
│  │ └────────────────────────────────────┘    └────────────────────────────────────┘ │  │
│  └──────────────────────────────────────────────────────────────────────────────────┘  │
│                                           │                                            │
│                                           ▼                                            │
│  ┌──────────────────────────────────────────────────────────────────────────────────┐  │
│  │                                    HYDRATION                                     │  │
│  │   Fetch additional data: core post metadata, author info, media entities, etc.   │  │
│  └──────────────────────────────────────────────────────────────────────────────────┘  │
│                                           │                                            │
│                                           ▼                                            │
│  ┌──────────────────────────────────────────────────────────────────────────────────┐  │
│  │                                    FILTERING                                     │  │
│  │ Remove: duplicates, old posts, self-posts, blocked authors, muted keywords, etc. │  │
│  └──────────────────────────────────────────────────────────────────────────────────┘  │
│                                           │                                            │
│                                           ▼                                            │
│  ┌──────────────────────────────────────────────────────────────────────────────────┐  │
│  │                                     SCORING                                      │  │
│  │ ┌──────────────────────────┐                                                     │  │
│  │ │      Phoenix Scorer      │ Grok-based Transformer predicts:                    │  │
│  │ │     (ML Predictions)     │ P(like), P(reply), P(repost), P(click)...           │  │
│  │ └──────────────────────────┘                                                     │  │
│  │              │                                                                   │  │
│  │              ▼                                                                   │  │
│  │ ┌──────────────────────────┐                                                     │  │
│  │ │     Weighted Scorer      │ Weighted Score = Σ (weight × P(action))             │  │
│  │ │  (Combine predictions)   │                                                     │  │
│  │ └──────────────────────────┘                                                     │  │
│  │              │                                                                   │  │
│  │              ▼                                                                   │  │
│  │ ┌──────────────────────────┐                                                     │  │
│  │ │     Author Diversity     │ Attenuate repeated author scores                    │  │
│  │ │          Scorer          │ to ensure feed diversity                            │  │
│  │ └──────────────────────────┘                                                     │  │
│  └──────────────────────────────────────────────────────────────────────────────────┘  │
│                                           │                                            │
│                                           ▼                                            │
│  ┌──────────────────────────────────────────────────────────────────────────────────┐  │
│  │                                    SELECTION                                     │  │
│  │                   Sort by final score, select top K candidates                   │  │
│  └──────────────────────────────────────────────────────────────────────────────────┘  │
│                                           │                                            │
│                                           ▼                                            │
│  ┌──────────────────────────────────────────────────────────────────────────────────┐  │
│  │                            FILTERING (Post-Selection)                            │  │
│  │              Visibility filtering (deleted/spam/violence/gore etc.)              │  │
│  └──────────────────────────────────────────────────────────────────────────────────┘  │
│                                                                                        │
└────────────────────────────────────────────────────────────────────────────────────────┘
                                            │
                                            ▼
┌────────────────────────────────────────────────────────────────────────────────────────┐
│                                  RANKED FEED RESPONSE                                  │
└────────────────────────────────────────────────────────────────────────────────────────┘
```
---
## Components
### Home Mixer
**Location:** [`home-mixer/`](home-mixer/)
The orchestration layer that assembles the For You feed. It leverages the `CandidatePipeline` framework with the following stages:
| Stage | Description |
|-------|-------------|
| Query Hydrators | Fetch user context (engagement history, following list) |
| Sources | Retrieve candidates from Thunder and Phoenix |
| Hydrators | Enrich candidates with additional data |
| Filters | Remove ineligible candidates |
| Scorers | Predict engagement and compute final scores |
| Selector | Sort by score and select top K |
| Post-Selection Filters | Final visibility and dedup checks |
| Side Effects | Cache request info for future use |
The server exposes a gRPC endpoint (`ScoredPostsService`) that returns ranked posts for a given user.
---
### Thunder
**Location:** [`thunder/`](thunder/)
An in-memory post store and real-time ingestion pipeline that tracks recent posts from all users. It:
- Consumes post create/delete events from Kafka
- Maintains per-user stores for original posts, replies/reposts, and video posts
- Serves "in-network" post candidates from accounts the requesting user follows
- Automatically trims posts older than the retention period
Thunder enables sub-millisecond lookups for in-network content without hitting an external database.
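A minimal sketch of a per-author store with retention trimming, assuming each post carries a creation timestamp in seconds and that retention is a single cutoff. `Post`, `PostStore`, and its methods are hypothetical names, not Thunder's API; the real store lives in `thunder/posts/post_store.rs` and keeps separate stores for original posts, replies/reposts, and video posts, which this sketch collapses into one queue per author.

```rust
use std::collections::{HashMap, VecDeque};

/// Hypothetical minimal post record; field names are illustrative.
#[derive(Clone, Debug, PartialEq)]
pub struct Post {
    pub post_id: u64,
    pub author_id: u64,
    pub created_at_secs: u64,
}

/// Illustrative in-memory store: one time-ordered queue of posts per author.
#[derive(Default)]
pub struct PostStore {
    by_author: HashMap<u64, VecDeque<Post>>,
}

impl PostStore {
    /// Appends a post; assumes create events arrive roughly in creation order.
    pub fn insert(&mut self, post: Post) {
        self.by_author.entry(post.author_id).or_default().push_back(post);
    }

    /// Removes a post by ID (e.g. on a delete event).
    pub fn delete(&mut self, author_id: u64, post_id: u64) {
        if let Some(posts) = self.by_author.get_mut(&author_id) {
            posts.retain(|p| p.post_id != post_id);
        }
    }

    /// Pops posts created before `cutoff_secs` off the front of each queue.
    pub fn trim_older_than(&mut self, cutoff_secs: u64) {
        for posts in self.by_author.values_mut() {
            while posts.front().map_or(false, |p| p.created_at_secs < cutoff_secs) {
                posts.pop_front();
            }
        }
        // Forget authors with no remaining posts.
        self.by_author.retain(|_, posts| !posts.is_empty());
    }

    /// In-network candidates: every stored post from each followed author.
    pub fn recent_posts_for(&self, following: &[u64]) -> Vec<Post> {
        following
            .iter()
            .filter_map(|author| self.by_author.get(author))
            .flat_map(|posts| posts.iter().cloned())
            .collect()
    }
}
```

Keeping each author's posts in creation order means trimming only ever touches the front of a queue, so retention cleanup does not scan posts that are still fresh.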
---
### Phoenix
**Location:** [`phoenix/`](phoenix/)
The ML component with two main functions:
#### 1. Retrieval (Two-Tower Model)
Finds relevant out-of-network posts:
- **User Tower**: Encodes user features and engagement history into an embedding
- **Candidate Tower**: Encodes all posts into embeddings
- **Similarity Search**: Retrieves top-K posts via dot product similarity
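The similarity-search step amounts to scoring every candidate embedding against the user embedding and keeping the best K. A brute-force sketch, assuming both towers have already produced `f32` embeddings of equal length; `top_k_by_dot_product` is a hypothetical helper, and a production system would more likely use an approximate nearest-neighbour index, which this README does not describe.

```rust
/// Dot product of two equal-length embeddings.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Brute-force retrieval: score every candidate against the user embedding
/// and return (candidate_index, score) pairs, highest score first.
pub fn top_k_by_dot_product(user: &[f32], candidates: &[Vec<f32>], k: usize) -> Vec<(usize, f32)> {
    let mut scored: Vec<(usize, f32)> = candidates
        .iter()
        .enumerate()
        .map(|(i, c)| (i, dot(user, c)))
        .collect();
    // total_cmp gives a total order on f32 (NaN-safe); b vs a sorts descending.
    scored.sort_by(|a, b| b.1.total_cmp(&a.1));
    scored.truncate(k);
    scored
}
```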
#### 2. Ranking (Transformer with Candidate Isolation)
Predicts engagement probabilities for each candidate:
- Takes user context (engagement history) and candidate posts as input
- Uses special attention masking so candidates cannot attend to each other
- Outputs probabilities for each action type (like, reply, repost, click, etc.)
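The candidate-isolation rule can be expressed as a boolean attention mask over a sequence of user-context positions followed by candidate positions. In this sketch the candidate rows follow the README (context plus self, never other candidates); the causal rule inside the context block is an assumption, and `candidate_isolation_mask` is a hypothetical helper rather than the Phoenix implementation, which is in Python under `phoenix/`.

```rust
/// Builds a boolean attention mask for a sequence laid out as
/// `[ctx_0, .., ctx_{n_ctx-1}, cand_0, .., cand_{n_cand-1}]`.
/// `mask[i][j] == true` means position `i` may attend to position `j`.
pub fn candidate_isolation_mask(n_ctx: usize, n_cand: usize) -> Vec<Vec<bool>> {
    let n = n_ctx + n_cand;
    (0..n)
        .map(|i| {
            (0..n)
                .map(|j| {
                    if i < n_ctx {
                        // Assumed: ordinary causal attention inside the user context.
                        j <= i
                    } else {
                        // Candidates see the whole context and themselves, never each other.
                        j < n_ctx || j == i
                    }
                })
                .collect()
        })
        .collect()
}
```

Under this mask, a candidate row has no `true` entries outside the context block and its own diagonal, so adding or removing other candidates cannot change what that candidate attends to.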
See [`phoenix/README.md`](phoenix/README.md) for detailed architecture documentation.
---
### Candidate Pipeline
**Location:** [`candidate-pipeline/`](candidate-pipeline/)
A reusable framework for building recommendation pipelines. Defines traits for:
| Trait | Purpose |
|-------|---------|
| `Source` | Fetch candidates from a data source |
| `Hydrator` | Enrich candidates with additional features |
| `Filter` | Remove candidates that shouldn't be shown |
| `Scorer` | Compute scores for ranking |
| `Selector` | Sort and select top candidates |
| `SideEffect` | Run async side effects (caching, logging) |
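The framework runs query hydrators, sources, and hydrators concurrently, while filters and scorers run sequentially. A component that returns an error, or a hydrator or scorer that returns the wrong number of results, is logged and skipped instead of failing the whole request.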
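The framework's handling of failing hydrators can be sketched synchronously. This is a simplified stand-in, not the crate's API: `SimpleHydrator`, `Candidate`, `AuthorHydrator`, and `BrokenHydrator` are hypothetical, and the real `run_hydrators` in `candidate_pipeline.rs` is async, awaits all hydrators with `join_all`, and logs through the `log` crate rather than returning messages.

```rust
/// Hypothetical candidate with one field filled in by a hydrator.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Candidate {
    pub post_id: u64,
    pub author_name: Option<String>,
}

/// Simplified, synchronous stand-in for the crate's async `Hydrator` trait.
pub trait SimpleHydrator {
    fn name(&self) -> &str;
    /// Returns one hydrated copy per input candidate, or an error.
    fn hydrate(&self, candidates: &[Candidate]) -> Result<Vec<Candidate>, String>;
    /// Copies only the fields this hydrator owns from `hydrated` into `candidate`.
    fn update(&self, candidate: &mut Candidate, hydrated: Candidate);
}

pub struct AuthorHydrator;

impl SimpleHydrator for AuthorHydrator {
    fn name(&self) -> &str {
        "AuthorHydrator"
    }
    fn hydrate(&self, candidates: &[Candidate]) -> Result<Vec<Candidate>, String> {
        Ok(candidates
            .iter()
            .map(|c| Candidate { author_name: Some(format!("user_{}", c.post_id)), ..c.clone() })
            .collect())
    }
    fn update(&self, candidate: &mut Candidate, hydrated: Candidate) {
        candidate.author_name = hydrated.author_name;
    }
}

/// Always fails, e.g. because its backend is unavailable.
pub struct BrokenHydrator;

impl SimpleHydrator for BrokenHydrator {
    fn name(&self) -> &str {
        "BrokenHydrator"
    }
    fn hydrate(&self, _candidates: &[Candidate]) -> Result<Vec<Candidate>, String> {
        Err("backend unavailable".to_string())
    }
    fn update(&self, _candidate: &mut Candidate, _hydrated: Candidate) {}
}

/// Runs every hydrator against the same input snapshot, then merges results.
/// Errors and length mismatches are recorded and skipped, never fatal.
pub fn run_hydrators(
    mut candidates: Vec<Candidate>,
    hydrators: &[Box<dyn SimpleHydrator>],
) -> (Vec<Candidate>, Vec<String>) {
    let expected_len = candidates.len();
    let results: Vec<_> = hydrators.iter().map(|h| h.hydrate(&candidates)).collect();
    let mut skipped = Vec::new();
    for (h, result) in hydrators.iter().zip(results) {
        match result {
            Ok(hydrated) if hydrated.len() == expected_len => {
                for (c, hc) in candidates.iter_mut().zip(hydrated) {
                    h.update(c, hc);
                }
            }
            Ok(hydrated) => skipped.push(format!(
                "{}: length_mismatch expected={} got={}",
                h.name(),
                expected_len,
                hydrated.len()
            )),
            Err(err) => skipped.push(format!("{}: failed: {}", h.name(), err)),
        }
    }
    (candidates, skipped)
}
```

The length check matters because updates are merged positionally: a hydrator that drops or adds a result would otherwise write its fields onto the wrong candidates.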
---
## How It Works
### Pipeline Stages
1. **Query Hydration**: Fetch the user's recent engagement history and metadata (e.g., following list)
2. **Candidate Sourcing**: Retrieve candidates from:
- **Thunder**: Recent posts from followed accounts (in-network)
- **Phoenix Retrieval**: ML-discovered posts from the global corpus (out-of-network)
3. **Candidate Hydration**: Enrich candidates with:
- Core post data (text, media, etc.)
- Author information (username, verification status)
- Video duration (for video posts)
- Subscription status
4. **Pre-Scoring Filters**: Remove posts that are:
- Duplicates
- Too old
- From the viewer themselves
- From blocked/muted accounts
- Containing muted keywords
- Previously seen or recently served
- Ineligible subscription content
5. **Scoring**: Apply multiple scorers sequentially:
- **Phoenix Scorer**: Get ML predictions from the Phoenix transformer model
- **Weighted Scorer**: Combine predictions into a final relevance score
- **Author Diversity Scorer**: Attenuate repeated author scores for diversity
- **OON Scorer**: Adjust scores for out-of-network content
6. **Selection**: Sort by score and select the top K candidates
7. **Post-Selection Processing**: Run post-selection hydration and filtering (visibility filtering, conversation deduplication) on the selected candidates before they are served
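The diversity and selection part of steps 5 and 6 can be sketched as follows. The README only says repeated authors are "attenuated"; the geometric decay used here (each further post from the same author, in score order, is multiplied by `decay`) is an assumed rule, `Scored`, `attenuate_repeated_authors`, and `select_top_k` are hypothetical names, and the OON scorer is not modelled.

```rust
use std::collections::HashMap;

/// Hypothetical scored candidate.
#[derive(Clone, Debug, PartialEq)]
pub struct Scored {
    pub post_id: u64,
    pub author_id: u64,
    pub score: f64,
}

/// Assumed attenuation rule: walking candidates in descending score order,
/// the n-th post (0-based) from the same author is multiplied by `decay^n`.
pub fn attenuate_repeated_authors(mut candidates: Vec<Scored>, decay: f64) -> Vec<Scored> {
    candidates.sort_by(|a, b| b.score.total_cmp(&a.score));
    let mut seen: HashMap<u64, i32> = HashMap::new();
    for c in candidates.iter_mut() {
        let n = seen.entry(c.author_id).or_insert(0);
        c.score *= decay.powi(*n);
        *n += 1;
    }
    candidates
}

/// Top-K selection: sort by (attenuated) score, highest first, and keep `k`.
pub fn select_top_k(mut candidates: Vec<Scored>, k: usize) -> Vec<Scored> {
    candidates.sort_by(|a, b| b.score.total_cmp(&a.score));
    candidates.truncate(k);
    candidates
}
```

With `decay = 0.5`, an author's posts scored 1.0 and 0.9 become 1.0 and 0.45, so a 0.6 post from a different author moves ahead of the second one before the top K are taken.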
---
### Scoring and Ranking
The Phoenix Grok-based transformer model predicts probabilities for multiple engagement types:
```
Predictions:
├── P(favorite)
├── P(reply)
├── P(repost)
├── P(quote)
├── P(click)
├── P(profile_click)
├── P(video_view)
├── P(photo_expand)
├── P(share)
├── P(dwell)
├── P(follow_author)
├── P(not_interested)
├── P(block_author)
├── P(mute_author)
└── P(report)
```
The **Weighted Scorer** combines these into a final score:
```
Final Score = Σ (weight_i × P(action_i))
```
Positive actions (like, repost, share) have positive weights. Negative actions (block, mute, report) have negative weights, pushing down content the user would likely dislike.
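The weighted sum can be written directly. The README does not publish the weights, so the values in the test and example below are invented placeholders, and `weighted_score` is a hypothetical helper rather than the Home Mixer `WeightedScorer` API. Actions with no configured weight contribute zero.

```rust
use std::collections::HashMap;

/// Final Score = Σ weight_i × P(action_i).
/// Actions missing from `weights` contribute nothing.
pub fn weighted_score(predictions: &HashMap<&str, f64>, weights: &HashMap<&str, f64>) -> f64 {
    predictions
        .iter()
        .map(|(action, p)| weights.get(*action).copied().unwrap_or(0.0) * p)
        .sum()
}
```

For example, with P(favorite) = 0.5 at weight 1.0, P(reply) = 0.1 at weight 2.0, and P(block_author) = 0.2 at weight -3.0, the score is 0.5 + 0.2 - 0.6 = 0.1: a moderately likely block cancels most of the positive signal.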
---
### Filtering
Filters run at two stages:
**Pre-Scoring Filters:**
| Filter | Purpose |
|--------|---------|
| `DropDuplicatesFilter` | Remove duplicate post IDs |
| `CoreDataHydrationFilter` | Remove posts that failed to hydrate core metadata |
| `AgeFilter` | Remove posts older than threshold |
| `SelfpostFilter` | Remove user's own posts |
| `RepostDeduplicationFilter` | Dedupe reposts of same content |
| `IneligibleSubscriptionFilter` | Remove paywalled content user can't access |
| `PreviouslySeenPostsFilter` | Remove posts user has already seen |
| `PreviouslyServedPostsFilter` | Remove posts already served in session |
| `MutedKeywordFilter` | Remove posts with user's muted keywords |
| `AuthorSocialgraphFilter` | Remove posts from blocked/muted authors |
**Post-Selection Filters:**
| Filter | Purpose |
|--------|---------|
| `VFFilter` | Remove posts that are deleted/spam/violence/gore etc. |
| `DedupConversationFilter` | Deduplicate multiple branches of the same conversation thread |
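Each filter splits its input into kept and removed candidates, which is the shape `run_filters` in `candidate_pipeline.rs` consumes (`result.kept`, `result.removed`). A sketch of two pre-scoring filters, assuming a post ID and a creation timestamp in seconds are all they need; `Post`, `FilterResult`, `drop_duplicates`, and `age_filter` are hypothetical simplifications of `DropDuplicatesFilter` and `AgeFilter`, and the real age threshold is not stated in this README.

```rust
use std::collections::HashSet;

/// Hypothetical minimal post record.
#[derive(Clone, Debug, PartialEq)]
pub struct Post {
    pub post_id: u64,
    pub created_at_secs: u64,
}

/// Mirrors the kept/removed split each pipeline filter returns.
pub struct FilterResult {
    pub kept: Vec<Post>,
    pub removed: Vec<Post>,
}

/// Keeps the first occurrence of each post ID; later copies are removed.
pub fn drop_duplicates(posts: Vec<Post>) -> FilterResult {
    let mut seen = HashSet::new();
    // HashSet::insert returns false if the ID was already present.
    let (kept, removed): (Vec<Post>, Vec<Post>) =
        posts.into_iter().partition(|p| seen.insert(p.post_id));
    FilterResult { kept, removed }
}

/// Removes posts older than `max_age_secs` relative to `now_secs`.
pub fn age_filter(posts: Vec<Post>, now_secs: u64, max_age_secs: u64) -> FilterResult {
    let (kept, removed): (Vec<Post>, Vec<Post>) = posts
        .into_iter()
        .partition(|p| now_secs.saturating_sub(p.created_at_secs) <= max_age_secs);
    FilterResult { kept, removed }
}
```

Returning the removed candidates instead of discarding them lets the pipeline report them in `PipelineResult::filtered_candidates`.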
---
## Key Design Decisions
### 1. No Hand-Engineered Features
The system relies entirely on the Grok-based transformer to learn relevance from user engagement sequences. No manual feature engineering for content relevance. This significantly reduces the complexity in our data pipelines and serving infrastructure.
### 2. Candidate Isolation in Ranking
During transformer inference, candidates cannot attend to each other—only to the user context. This ensures the score for a post doesn't depend on which other posts are in the batch, making scores consistent and cacheable.
### 3. Hash-Based Embeddings
Both retrieval and ranking use multiple hash functions for embedding lookup.
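The README does not spell out the hashing scheme. A common version of the idea, sketched here under that assumption: hash each ID with several seeds into one shared table and sum the selected rows, so two IDs get identical embeddings only when all of their hashed buckets coincide, which is far less likely than a single-hash collision. `HashedEmbedding`, the toy table initialisation, and summation as the combiner are all assumptions; Phoenix itself is written in Python (`phoenix/`), not Rust.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Hypothetical multi-hash embedding table: each ID is hashed with
/// `num_hashes` different seeds into a shared table of rows.
pub struct HashedEmbedding {
    table: Vec<Vec<f32>>, // num_buckets rows, each of length dim
    num_hashes: u64,
}

impl HashedEmbedding {
    pub fn new(num_buckets: usize, dim: usize, num_hashes: u64) -> Self {
        // Deterministic toy values so the example is reproducible; a real
        // table would hold learned parameters.
        let table: Vec<Vec<f32>> = (0..num_buckets)
            .map(|b| (0..dim).map(|d| ((b * dim + d) % 7) as f32 * 0.1).collect())
            .collect();
        Self { table, num_hashes }
    }

    /// Maps (seed, id) to a bucket index in the shared table.
    fn bucket(&self, id: u64, seed: u64) -> usize {
        let mut h = DefaultHasher::new();
        (seed, id).hash(&mut h);
        (h.finish() % self.table.len() as u64) as usize
    }

    /// Looks up one row per hash function and sums them element-wise.
    pub fn lookup(&self, id: u64) -> Vec<f32> {
        let dim = self.table[0].len();
        let mut out = vec![0.0; dim];
        for seed in 0..self.num_hashes {
            let row = &self.table[self.bucket(id, seed)];
            for (o, v) in out.iter_mut().zip(row) {
                *o += v;
            }
        }
        out
    }
}
```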
### 4. Multi-Action Prediction
Rather than predicting a single "relevance" score, the model predicts probabilities for many actions.
### 5. Composable Pipeline Architecture
The `candidate-pipeline` crate provides a flexible framework for building recommendation pipelines with:
- Separation of pipeline execution and monitoring from business logic
- Parallel execution of independent stages and graceful error handling
- Easy addition of new sources, hydrators, filters, and scorers
---
## License
This project is licensed under the Apache License 2.0. See [LICENSE](LICENSE) for details.
================================================
FILE: candidate-pipeline/candidate_pipeline.rs
================================================
use crate::filter::Filter;
use crate::hydrator::Hydrator;
use crate::query_hydrator::QueryHydrator;
use crate::scorer::Scorer;
use crate::selector::Selector;
use crate::side_effect::{SideEffect, SideEffectInput};
use crate::source::Source;
use futures::future::join_all;
use log::{error, info, warn};
use std::sync::Arc;
use tonic::async_trait;
#[derive(Copy, Clone, Debug)]
pub enum PipelineStage {
QueryHydrator,
Source,
Hydrator,
PostSelectionHydrator,
Filter,
PostSelectionFilter,
Scorer,
}
pub struct PipelineResult<Q, C> {
pub retrieved_candidates: Vec<C>,
pub filtered_candidates: Vec<C>,
pub selected_candidates: Vec<C>,
pub query: Arc<Q>,
}
/// Provides a stable request identifier for logging/tracing.
pub trait HasRequestId {
fn request_id(&self) -> &str;
}
#[async_trait]
pub trait CandidatePipeline<Q, C>: Send + Sync
where
Q: HasRequestId + Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
{
fn query_hydrators(&self) -> &[Box<dyn QueryHydrator<Q>>];
fn sources(&self) -> &[Box<dyn Source<Q, C>>];
fn hydrators(&self) -> &[Box<dyn Hydrator<Q, C>>];
fn filters(&self) -> &[Box<dyn Filter<Q, C>>];
fn scorers(&self) -> &[Box<dyn Scorer<Q, C>>];
fn selector(&self) -> &dyn Selector<Q, C>;
fn post_selection_hydrators(&self) -> &[Box<dyn Hydrator<Q, C>>];
fn post_selection_filters(&self) -> &[Box<dyn Filter<Q, C>>];
fn side_effects(&self) -> Arc<Vec<Box<dyn SideEffect<Q, C>>>>;
fn result_size(&self) -> usize;
async fn execute(&self, query: Q) -> PipelineResult<Q, C> {
let hydrated_query = self.hydrate_query(query).await;
let candidates = self.fetch_candidates(&hydrated_query).await;
let hydrated_candidates = self.hydrate(&hydrated_query, candidates).await;
let (kept_candidates, mut filtered_candidates) = self
.filter(&hydrated_query, hydrated_candidates.clone())
.await;
let scored_candidates = self.score(&hydrated_query, kept_candidates).await;
let selected_candidates = self.select(&hydrated_query, scored_candidates);
let post_selection_hydrated_candidates = self
.hydrate_post_selection(&hydrated_query, selected_candidates)
.await;
let (mut final_candidates, post_selection_filtered_candidates) = self
.filter_post_selection(&hydrated_query, post_selection_hydrated_candidates)
.await;
filtered_candidates.extend(post_selection_filtered_candidates);
final_candidates.truncate(self.result_size());
let arc_hydrated_query = Arc::new(hydrated_query);
let input = Arc::new(SideEffectInput {
query: arc_hydrated_query.clone(),
selected_candidates: final_candidates.clone(),
});
self.run_side_effects(input);
PipelineResult {
retrieved_candidates: hydrated_candidates,
filtered_candidates,
selected_candidates: final_candidates,
query: arc_hydrated_query,
}
}
/// Run all query hydrators in parallel and merge results into the query.
async fn hydrate_query(&self, query: Q) -> Q {
let request_id = query.request_id().to_string();
let hydrators: Vec<_> = self
.query_hydrators()
.iter()
.filter(|h| h.enable(&query))
.collect();
let hydrate_futures = hydrators.iter().map(|h| h.hydrate(&query));
let results = join_all(hydrate_futures).await;
let mut hydrated_query = query;
for (hydrator, result) in hydrators.iter().zip(results) {
match result {
Ok(hydrated) => {
hydrator.update(&mut hydrated_query, hydrated);
}
Err(err) => {
error!(
"request_id={} stage={:?} component={} failed: {}",
request_id,
PipelineStage::QueryHydrator,
hydrator.name(),
err
);
}
}
}
hydrated_query
}
/// Run all candidate sources in parallel and collect results.
async fn fetch_candidates(&self, query: &Q) -> Vec<C> {
let request_id = query.request_id().to_string();
let sources: Vec<_> = self.sources().iter().filter(|s| s.enable(query)).collect();
let source_futures = sources.iter().map(|s| s.get_candidates(query));
let results = join_all(source_futures).await;
let mut collected = Vec::new();
for (source, result) in sources.iter().zip(results) {
match result {
Ok(mut candidates) => {
info!(
"request_id={} stage={:?} component={} fetched {} candidates",
request_id,
PipelineStage::Source,
source.name(),
candidates.len()
);
collected.append(&mut candidates);
}
Err(err) => {
error!(
"request_id={} stage={:?} component={} failed: {}",
request_id,
PipelineStage::Source,
source.name(),
err
);
}
}
}
collected
}
/// Run all candidate hydrators in parallel and merge results into candidates.
async fn hydrate(&self, query: &Q, candidates: Vec<C>) -> Vec<C> {
self.run_hydrators(query, candidates, self.hydrators(), PipelineStage::Hydrator)
.await
}
/// Run post-selection candidate hydrators in parallel and merge results into candidates.
async fn hydrate_post_selection(&self, query: &Q, candidates: Vec<C>) -> Vec<C> {
self.run_hydrators(
query,
candidates,
self.post_selection_hydrators(),
PipelineStage::PostSelectionHydrator,
)
.await
}
/// Shared helper to hydrate with a provided hydrator list.
async fn run_hydrators(
&self,
query: &Q,
mut candidates: Vec<C>,
hydrators: &[Box<dyn Hydrator<Q, C>>],
stage: PipelineStage,
) -> Vec<C> {
let request_id = query.request_id().to_string();
let hydrators: Vec<_> = hydrators.iter().filter(|h| h.enable(query)).collect();
let expected_len = candidates.len();
let hydrate_futures = hydrators.iter().map(|h| h.hydrate(query, &candidates));
let results = join_all(hydrate_futures).await;
for (hydrator, result) in hydrators.iter().zip(results) {
match result {
Ok(hydrated) => {
if hydrated.len() == expected_len {
hydrator.update_all(&mut candidates, hydrated);
} else {
warn!(
"request_id={} stage={:?} component={} skipped: length_mismatch expected={} got={}",
request_id,
stage,
hydrator.name(),
expected_len,
hydrated.len()
);
}
}
Err(err) => {
error!(
"request_id={} stage={:?} component={} failed: {}",
request_id,
stage,
hydrator.name(),
err
);
}
}
}
candidates
}
/// Run all filters sequentially. Each filter partitions candidates into kept and removed.
async fn filter(&self, query: &Q, candidates: Vec<C>) -> (Vec<C>, Vec<C>) {
self.run_filters(query, candidates, self.filters(), PipelineStage::Filter)
.await
}
/// Run post-selection filters sequentially on the selected candidates.
async fn filter_post_selection(&self, query: &Q, candidates: Vec<C>) -> (Vec<C>, Vec<C>) {
self.run_filters(
query,
candidates,
self.post_selection_filters(),
PipelineStage::PostSelectionFilter,
)
.await
}
/// Shared helper to run filters sequentially from a provided filter list.
async fn run_filters(
&self,
query: &Q,
mut candidates: Vec<C>,
filters: &[Box<dyn Filter<Q, C>>],
stage: PipelineStage,
) -> (Vec<C>, Vec<C>) {
let request_id = query.request_id().to_string();
let mut all_removed = Vec::new();
for filter in filters.iter().filter(|f| f.enable(query)) {
let backup = candidates.clone();
match filter.filter(query, candidates).await {
Ok(result) => {
candidates = result.kept;
all_removed.extend(result.removed);
}
Err(err) => {
error!(
"request_id={} stage={:?} component={} failed: {}",
request_id,
stage,
filter.name(),
err
);
candidates = backup;
}
}
}
info!(
"request_id={} stage={:?} kept {}, removed {}",
request_id,
stage,
candidates.len(),
all_removed.len()
);
(candidates, all_removed)
}
/// Run all scorers sequentially and apply their results to candidates.
async fn score(&self, query: &Q, mut candidates: Vec<C>) -> Vec<C> {
let request_id = query.request_id().to_string();
let expected_len = candidates.len();
for scorer in self.scorers().iter().filter(|s| s.enable(query)) {
match scorer.score(query, &candidates).await {
Ok(scored) => {
if scored.len() == expected_len {
scorer.update_all(&mut candidates, scored);
} else {
warn!(
"request_id={} stage={:?} component={} skipped: length_mismatch expected={} got={}",
request_id,
PipelineStage::Scorer,
scorer.name(),
expected_len,
scored.len()
);
}
}
Err(err) => {
error!(
"request_id={} stage={:?} component={} failed: {}",
request_id,
PipelineStage::Scorer,
scorer.name(),
err
);
}
}
}
candidates
}
/// Select (sort and truncate) candidates using the configured selector, if it is enabled for this query.
fn select(&self, query: &Q, candidates: Vec<C>) -> Vec<C> {
if self.selector().enable(query) {
self.selector().select(query, candidates)
} else {
candidates
}
}
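/// Run all enabled side effects concurrently in a detached Tokio task.
/// The caller does not wait for them, and their results are ignored.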
fn run_side_effects(&self, input: Arc<SideEffectInput<Q, C>>) {
let side_effects = self.side_effects();
tokio::spawn(async move {
let futures = side_effects
.iter()
.filter(|se| se.enable(input.query.clone()))
.map(|se| se.run(input.clone()));
let _ = join_all(futures).await;
});
}
}
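`run_filters` clones the candidate list before each filter because `Filter::filter` takes ownership of it. If a filter returns `Err`, the backup is restored, so a failing filter acts as a no-op instead of dropping candidates. The following is a minimal synchronous sketch of that recovery pattern. It assumes plain closures in place of the async `Filter` trait, and `FilterFn` and `run_filters_sync` are hypothetical names that are not part of the crate:

```rust
/// Partition produced by one filter (mirrors the shape of `FilterResult`).
pub struct FilterResult<C> {
    pub kept: Vec<C>,
    pub removed: Vec<C>,
}

/// Hypothetical synchronous stand-in for an async `Filter` implementation.
pub type FilterFn<C> = Box<dyn Fn(Vec<C>) -> Result<FilterResult<C>, String>>;

/// Apply filters in order. If a filter fails, restore the pre-filter
/// backup and continue with the next filter.
pub fn run_filters_sync<C: Clone>(
    mut candidates: Vec<C>,
    filters: &[FilterFn<C>],
) -> (Vec<C>, Vec<C>) {
    let mut all_removed = Vec::new();
    for filter in filters {
        // The filter consumes its input, so keep a copy to fall back on.
        let backup = candidates.clone();
        match filter(candidates) {
            Ok(result) => {
                candidates = result.kept;
                all_removed.extend(result.removed);
            }
            Err(_err) => {
                // A failed filter removes nothing.
                candidates = backup;
            }
        }
    }
    (candidates, all_removed)
}
```

The trade-off is one clone per enabled filter in exchange for never losing candidates to a transient backend error.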
================================================
FILE: candidate-pipeline/filter.rs
================================================
use std::any::{Any, type_name_of_val};
use tonic::async_trait;
use crate::util;
pub struct FilterResult<C> {
pub kept: Vec<C>,
pub removed: Vec<C>,
}
/// Filters run sequentially and partition candidates into kept and removed sets
#[async_trait]
pub trait Filter<Q, C>: Any + Send + Sync
where
Q: Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
{
/// Decide if this filter should run for the given query
fn enable(&self, _query: &Q) -> bool {
true
}
/// Filter candidates by evaluating each against some criteria.
/// Returns a FilterResult containing kept candidates (which continue to the next stage)
/// and removed candidates (which are excluded from further processing).
async fn filter(&self, query: &Q, candidates: Vec<C>) -> Result<FilterResult<C>, String>;
/// Returns a stable name for logging/metrics.
fn name(&self) -> &'static str {
util::short_type_name(type_name_of_val(self))
}
}
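A filter usually splits its input with `Iterator::partition`, keeping the side for which the predicate is `true`. It can also override `enable` to skip work when there is nothing to remove. The sketch below uses a synchronous mirror of the trait so that it runs without `tonic` or `async_trait`. `SyncFilter`, `Post`, `Query`, and `AuthorBlocklistFilter` are hypothetical names for illustration only:

```rust
use std::collections::HashSet;

pub struct FilterResult<C> {
    pub kept: Vec<C>,
    pub removed: Vec<C>,
}

/// Hypothetical synchronous mirror of the async `Filter` trait.
pub trait SyncFilter<Q, C> {
    fn enable(&self, _query: &Q) -> bool {
        true
    }
    fn filter(&self, query: &Q, candidates: Vec<C>) -> Result<FilterResult<C>, String>;
}

#[derive(Clone, Debug, PartialEq)]
pub struct Post {
    pub tweet_id: i64,
    pub author_id: u64,
}

/// Hypothetical query carrying the viewer's blocked author IDs.
pub struct Query {
    pub blocked_author_ids: HashSet<u64>,
}

/// Removes posts whose author the viewer has blocked.
pub struct AuthorBlocklistFilter;

impl SyncFilter<Query, Post> for AuthorBlocklistFilter {
    fn enable(&self, query: &Query) -> bool {
        // Skip the pass entirely when there is nothing to remove.
        !query.blocked_author_ids.is_empty()
    }

    fn filter(&self, query: &Query, candidates: Vec<Post>) -> Result<FilterResult<Post>, String> {
        // `true` from the predicate means "keep".
        let (kept, removed): (Vec<Post>, Vec<Post>) = candidates
            .into_iter()
            .partition(|p| !query.blocked_author_ids.contains(&p.author_id));
        Ok(FilterResult { kept, removed })
    }
}
```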
================================================
FILE: candidate-pipeline/hydrator.rs
================================================
use crate::util;
use std::any::{Any, type_name_of_val};
use tonic::async_trait;
/// Hydrators run in parallel and update candidate fields.
#[async_trait]
pub trait Hydrator<Q, C>: Any + Send + Sync
where
Q: Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
{
/// Decide if this hydrator should run for the given query
fn enable(&self, _query: &Q) -> bool {
true
}
/// Hydrate candidates by performing async operations.
/// Returns candidates with this hydrator's fields populated.
///
/// IMPORTANT: The returned vector must have the same candidates in the same order as the input.
/// Dropping candidates in a hydrator is not allowed - use a filter stage instead.
async fn hydrate(&self, query: &Q, candidates: &[C]) -> Result<Vec<C>, String>;
/// Update a single candidate with the hydrated fields.
/// Only the fields this hydrator is responsible for should be copied.
fn update(&self, candidate: &mut C, hydrated: C);
/// Update all candidates with the hydrated fields from `hydrated`.
/// Default implementation iterates and calls `update` for each pair.
fn update_all(&self, candidates: &mut [C], hydrated: Vec<C>) {
for (c, h) in candidates.iter_mut().zip(hydrated) {
self.update(c, h);
}
}
fn name(&self) -> &'static str {
util::short_type_name(type_name_of_val(self))
}
}
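Every hydrator sees the same input and returns default-filled copies. Each `update` therefore copies only the fields its hydrator owns, which lets several hydrators' results be merged without clobbering one another. The merge also skips any output whose length differs from the input. Below is a minimal synchronous sketch of that contract, assuming a simplified `Candidate` type. `SyncHydrator`, `InNetworkStub`, `VideoStub`, and `hydrate_all` are hypothetical names:

```rust
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Candidate {
    pub tweet_id: i64,
    pub in_network: Option<bool>,
    pub video_duration_ms: Option<i32>,
}

/// Hypothetical synchronous mirror of the `Hydrator` trait.
pub trait SyncHydrator<C> {
    fn hydrate(&self, candidates: &[C]) -> Result<Vec<C>, String>;
    fn update(&self, candidate: &mut C, hydrated: C);
    fn update_all(&self, candidates: &mut [C], hydrated: Vec<C>) {
        for (c, h) in candidates.iter_mut().zip(hydrated) {
            self.update(c, h);
        }
    }
}

/// Marks every candidate as in-network (stand-in for a real lookup).
struct InNetworkStub;
impl SyncHydrator<Candidate> for InNetworkStub {
    fn hydrate(&self, candidates: &[Candidate]) -> Result<Vec<Candidate>, String> {
        // One output per input, in input order, with only the owned field set.
        Ok(candidates
            .iter()
            .map(|_| Candidate { in_network: Some(true), ..Default::default() })
            .collect())
    }
    fn update(&self, c: &mut Candidate, h: Candidate) {
        // Copying the whole struct here would reset `tweet_id` to 0.
        c.in_network = h.in_network;
    }
}

/// Pretends even tweet IDs are 15-second videos.
struct VideoStub;
impl SyncHydrator<Candidate> for VideoStub {
    fn hydrate(&self, candidates: &[Candidate]) -> Result<Vec<Candidate>, String> {
        Ok(candidates
            .iter()
            .map(|c| Candidate {
                video_duration_ms: if c.tweet_id % 2 == 0 { Some(15_000) } else { None },
                ..Default::default()
            })
            .collect())
    }
    fn update(&self, c: &mut Candidate, h: Candidate) {
        c.video_duration_ms = h.video_duration_ms;
    }
}

/// Run every hydrator on the same input, then merge the results,
/// skipping failed outputs and outputs whose length does not match.
pub fn hydrate_all<C>(mut candidates: Vec<C>, hydrators: &[&dyn SyncHydrator<C>]) -> Vec<C> {
    let expected_len = candidates.len();
    let results: Vec<_> = hydrators.iter().map(|h| h.hydrate(&candidates)).collect();
    for (h, result) in hydrators.iter().zip(results) {
        if let Ok(hydrated) = result {
            if hydrated.len() == expected_len {
                h.update_all(&mut candidates, hydrated);
            }
        }
    }
    candidates
}
```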
================================================
FILE: candidate-pipeline/lib.rs
================================================
pub mod candidate_pipeline;
pub mod filter;
pub mod hydrator;
pub mod query_hydrator;
pub mod scorer;
pub mod selector;
pub mod side_effect;
pub mod source;
pub mod util;
================================================
FILE: candidate-pipeline/query_hydrator.rs
================================================
use std::any::{Any, type_name_of_val};
use tonic::async_trait;
use crate::util;
#[async_trait]
pub trait QueryHydrator<Q>: Any + Send + Sync
where
Q: Clone + Send + Sync + 'static,
{
/// Decide if this query hydrator should run for the given query
fn enable(&self, _query: &Q) -> bool {
true
}
/// Hydrate the query by performing async operations.
/// Returns a new query with this hydrator's fields populated.
async fn hydrate(&self, query: &Q) -> Result<Q, String>;
/// Update the query with the hydrated fields.
/// Only the fields this hydrator is responsible for should be copied.
fn update(&self, query: &mut Q, hydrated: Q);
fn name(&self) -> &'static str {
util::short_type_name(type_name_of_val(self))
}
}
================================================
FILE: candidate-pipeline/scorer.rs
================================================
use crate::util;
use std::any::type_name_of_val;
use tonic::async_trait;
/// Scorers update candidate fields (like a score field) and run sequentially
#[async_trait]
pub trait Scorer<Q, C>: Send + Sync
where
Q: Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
{
/// Decide if this scorer should run for the given query
fn enable(&self, _query: &Q) -> bool {
true
}
/// Score candidates by performing async operations.
/// Returns candidates with this scorer's fields populated.
///
/// IMPORTANT: The returned vector must have the same candidates in the same order as the input.
/// Dropping candidates in a scorer is not allowed - use a filter stage instead.
async fn score(&self, query: &Q, candidates: &[C]) -> Result<Vec<C>, String>;
/// Update a single candidate with the scored fields.
/// Only the fields this scorer is responsible for should be copied.
fn update(&self, candidate: &mut C, scored: C);
/// Update all candidates with the scored fields from `scored`.
/// Default implementation iterates and calls `update` for each pair.
fn update_all(&self, candidates: &mut [C], scored: Vec<C>) {
for (c, s) in candidates.iter_mut().zip(scored) {
self.update(c, s);
}
}
fn name(&self) -> &'static str {
util::short_type_name(type_name_of_val(self))
}
}
================================================
FILE: candidate-pipeline/selector.rs
================================================
use crate::util;
use std::any::type_name_of_val;
pub trait Selector<Q, C>: Send + Sync
where
Q: Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
{
/// Default selection: sort by score in descending order, then truncate to `size()` if one is set.
fn select(&self, _query: &Q, candidates: Vec<C>) -> Vec<C> {
let mut sorted = self.sort(candidates);
if let Some(limit) = self.size() {
sorted.truncate(limit);
}
sorted
}
/// Decide if this selector should run for the given query
fn enable(&self, _query: &Q) -> bool {
true
}
/// Extract the score from a candidate to use for sorting.
fn score(&self, candidate: &C) -> f64;
/// Sort candidates by their scores in descending order.
fn sort(&self, candidates: Vec<C>) -> Vec<C> {
let mut sorted = candidates;
sorted.sort_by(|a, b| {
self.score(b)
.partial_cmp(&self.score(a))
.unwrap_or(std::cmp::Ordering::Equal)
});
sorted
}
/// Optionally provide a size to select. Defaults to no truncation if not overridden.
fn size(&self) -> Option<usize> {
None
}
fn name(&self) -> &'static str {
util::short_type_name(type_name_of_val(self))
}
}
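The default `sort` compares scores with `partial_cmp(...).unwrap_or(Equal)`. Under that rule a NaN score compares equal to every other score, so the comparator is not a total order. Recent Rust releases document that `slice::sort_by` may panic in that case. The sketch below shows one alternative for a top-k selection: map missing or NaN scores to negative infinity and compare with `f64::total_cmp`. The `Scored` type and `select_top_k` are hypothetical, and this is not how the crate's `TopKScoreSelector` is known to behave:

```rust
/// Hypothetical candidate with an optional final score.
#[derive(Clone, Debug, PartialEq)]
pub struct Scored {
    pub id: u64,
    pub score: Option<f64>,
}

/// Sort by score in descending order and keep at most `k` candidates.
/// Missing and NaN scores sort last. `total_cmp` is a total order,
/// so the comparator is always valid for `sort_by`.
pub fn select_top_k(mut candidates: Vec<Scored>, k: usize) -> Vec<Scored> {
    let key = |c: &Scored| match c.score {
        Some(s) if !s.is_nan() => s,
        _ => f64::NEG_INFINITY,
    };
    // Reversed arguments give a descending order. The sort is stable,
    // so tied scores keep their original relative order.
    candidates.sort_by(|a, b| key(b).total_cmp(&key(a)));
    candidates.truncate(k);
    candidates
}
```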
================================================
FILE: candidate-pipeline/side_effect.rs
================================================
use crate::util;
use std::any::type_name_of_val;
use std::sync::Arc;
use tonic::async_trait;
/// A side effect is an action that runs on the selected candidates without changing, or delaying the return of, the pipeline result.
#[derive(Clone)]
pub struct SideEffectInput<Q, C> {
pub query: Arc<Q>,
pub selected_candidates: Vec<C>,
}
#[async_trait]
pub trait SideEffect<Q, C>: Send + Sync
where
Q: Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
{
/// Decide if this side-effect should be run
fn enable(&self, _query: Arc<Q>) -> bool {
true
}
async fn run(&self, input: Arc<SideEffectInput<Q, C>>) -> Result<(), String>;
fn name(&self) -> &'static str {
util::short_type_name(type_name_of_val(self))
}
}
================================================
FILE: candidate-pipeline/source.rs
================================================
use std::any::{Any, type_name_of_val};
use tonic::async_trait;
use crate::util;
#[async_trait]
pub trait Source<Q, C>: Any + Send + Sync
where
Q: Clone + Send + Sync + 'static,
C: Clone + Send + Sync + 'static,
{
/// Decide if this source should run for the given query
fn enable(&self, _query: &Q) -> bool {
true
}
async fn get_candidates(&self, query: &Q) -> Result<Vec<C>, String>;
fn name(&self) -> &'static str {
util::short_type_name(type_name_of_val(self))
}
}
================================================
FILE: home-mixer/candidate_hydrators/core_data_candidate_hydrator.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::tweet_entity_service_client::TESClient;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
pub struct CoreDataCandidateHydrator {
pub tes_client: Arc<dyn TESClient + Send + Sync>,
}
impl CoreDataCandidateHydrator {
pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
Self { tes_client }
}
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for CoreDataCandidateHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(
&self,
_query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let client = &self.tes_client;
let tweet_ids = candidates.iter().map(|c| c.tweet_id).collect::<Vec<_>>();
let post_features = client.get_tweet_core_datas(tweet_ids.clone()).await;
let post_features = post_features.map_err(|e| e.to_string())?;
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
for tweet_id in tweet_ids {
let post_features = post_features.get(&tweet_id);
let core_data = post_features.and_then(|x| x.as_ref());
let text = core_data.map(|x| x.text.clone());
let hydrated = PostCandidate {
author_id: core_data.map(|x| x.author_id).unwrap_or_default(),
retweeted_user_id: core_data.and_then(|x| x.source_user_id),
retweeted_tweet_id: core_data.and_then(|x| x.source_tweet_id),
in_reply_to_tweet_id: core_data.and_then(|x| x.in_reply_to_tweet_id),
tweet_text: text.unwrap_or_default(),
..Default::default()
};
hydrated_candidates.push(hydrated);
}
Ok(hydrated_candidates)
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
candidate.retweeted_user_id = hydrated.retweeted_user_id;
candidate.retweeted_tweet_id = hydrated.retweeted_tweet_id;
candidate.in_reply_to_tweet_id = hydrated.in_reply_to_tweet_id;
candidate.tweet_text = hydrated.tweet_text;
}
}
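The batched TES response is a map whose values are themselves optional. A tweet can be missing from the map or present with `None`, and `.get(..).and_then(|x| x.as_ref())` folds both cases into one default. Because the loop walks `tweet_ids`, the output keeps the input order. One consequence is that an empty `tweet_text` cannot tell "no text" apart from "lookup failed". Below is a minimal sketch of that lookup, assuming a reduced `CoreData` payload. The names `CoreData` and `texts_in_order` are hypothetical:

```rust
use std::collections::HashMap;

/// Hypothetical subset of the core-data payload.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct CoreData {
    pub author_id: u64,
    pub text: String,
}

/// Look up each requested ID in a batched response, in input order.
/// An absent ID and an ID mapped to `None` both yield an empty string.
pub fn texts_in_order(ids: &[i64], response: &HashMap<i64, Option<CoreData>>) -> Vec<String> {
    ids.iter()
        .map(|id| {
            response
                .get(id)
                .and_then(|entry| entry.as_ref())
                .map(|core| core.text.clone())
                .unwrap_or_default()
        })
        .collect()
}
```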
================================================
FILE: home-mixer/candidate_hydrators/gizmoduck_hydrator.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::gizmoduck_client::GizmoduckClient;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
pub struct GizmoduckCandidateHydrator {
pub gizmoduck_client: Arc<dyn GizmoduckClient + Send + Sync>,
}
impl GizmoduckCandidateHydrator {
pub async fn new(gizmoduck_client: Arc<dyn GizmoduckClient + Send + Sync>) -> Self {
Self { gizmoduck_client }
}
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for GizmoduckCandidateHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(
&self,
_query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let client = &self.gizmoduck_client;
let author_ids = candidates.iter().map(|c| c.author_id as i64);
let retweet_user_ids = candidates
.iter()
.filter_map(|c| c.retweeted_user_id)
.map(|id| id as i64);
// Vec::dedup only drops consecutive repeats, so sort before deduplicating.
let mut user_ids_to_fetch: Vec<i64> = author_ids.chain(retweet_user_ids).collect();
user_ids_to_fetch.sort_unstable();
user_ids_to_fetch.dedup();
let users = client.get_users(user_ids_to_fetch).await;
let users = users.map_err(|e| e.to_string())?;
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
for candidate in candidates {
let user = users
.get(&(candidate.author_id as i64))
.and_then(|user| user.as_ref());
let user_counts = user.and_then(|user| user.user.as_ref().map(|u| &u.counts));
let user_profile = user.and_then(|user| user.user.as_ref().map(|u| &u.profile));
let author_followers_count: Option<i32> =
user_counts.map(|x| x.followers_count).map(|x| x as i32);
let author_screen_name: Option<String> = user_profile.map(|x| x.screen_name.clone());
let retweet_user = candidate
.retweeted_user_id
.and_then(|retweeted_user_id| users.get(&(retweeted_user_id as i64)))
.and_then(|user| user.as_ref());
let retweet_profile =
retweet_user.and_then(|user| user.user.as_ref().map(|u| &u.profile));
let retweeted_screen_name: Option<String> =
retweet_profile.map(|x| x.screen_name.clone());
let hydrated = PostCandidate {
author_followers_count,
author_screen_name,
retweeted_screen_name,
..Default::default()
};
hydrated_candidates.push(hydrated);
}
Ok(hydrated_candidates)
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
candidate.author_followers_count = hydrated.author_followers_count;
candidate.author_screen_name = hydrated.author_screen_name;
candidate.retweeted_screen_name = hydrated.retweeted_screen_name;
}
}
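`Vec::dedup` removes only consecutive duplicates. Calling it on an unsorted mix of author IDs and retweeted-user IDs can therefore leave repeats in the batch request to Gizmoduck. The following sketch builds the request list with sort-then-dedup. The function name `user_ids_to_fetch` and its slice-based signature are hypothetical:

```rust
/// Merge author IDs and retweeted-user IDs into one deduplicated request list.
pub fn user_ids_to_fetch(author_ids: &[u64], retweeted_user_ids: &[Option<u64>]) -> Vec<i64> {
    let mut ids: Vec<i64> = author_ids
        .iter()
        .copied()
        // `flatten` skips candidates that are not retweets (`None`).
        .chain(retweeted_user_ids.iter().flatten().copied())
        .map(|id| id as i64)
        .collect();
    // Sorting makes every duplicate adjacent, so `dedup` removes all of them.
    ids.sort_unstable();
    ids.dedup();
    ids
}
```

Order does not matter here because the response is looked up by user ID. If request order did matter, a `HashSet` used as a "seen" set would deduplicate while keeping first occurrences.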
================================================
FILE: home-mixer/candidate_hydrators/in_network_candidate_hydrator.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use std::collections::HashSet;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
pub struct InNetworkCandidateHydrator;
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for InNetworkCandidateHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(
&self,
query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let viewer_id = query.user_id as u64;
let followed_ids: HashSet<u64> = query
.user_features
.followed_user_ids
.iter()
.copied()
.map(|id| id as u64)
.collect();
let hydrated_candidates = candidates
.iter()
.map(|candidate| {
let is_self = candidate.author_id == viewer_id;
let is_in_network = is_self || followed_ids.contains(&candidate.author_id);
PostCandidate {
in_network: Some(is_in_network),
..Default::default()
}
})
.collect();
Ok(hydrated_candidates)
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
candidate.in_network = hydrated.in_network;
}
}
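The in-network rule above is "the viewer wrote it, or the viewer follows its author". Collecting the followed IDs into a `HashSet` first makes each membership check O(1) on average instead of a linear scan per candidate. Below is a standalone sketch of the same rule. It assumes followed IDs arrive as `i64`, as the `as u64` cast in the hydrator suggests, and the function name `in_network_flags` is hypothetical:

```rust
use std::collections::HashSet;

/// Return one in-network flag per author, in input order.
pub fn in_network_flags(viewer_id: u64, followed_user_ids: &[i64], author_ids: &[u64]) -> Vec<bool> {
    let followed: HashSet<u64> = followed_user_ids.iter().map(|&id| id as u64).collect();
    author_ids
        .iter()
        .map(|author| *author == viewer_id || followed.contains(author))
        .collect()
}
```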
================================================
FILE: home-mixer/candidate_hydrators/mod.rs
================================================
pub mod core_data_candidate_hydrator;
pub mod gizmoduck_hydrator;
pub mod in_network_candidate_hydrator;
pub mod subscription_hydrator;
pub mod vf_candidate_hydrator;
pub mod video_duration_candidate_hydrator;
================================================
FILE: home-mixer/candidate_hydrators/subscription_hydrator.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::tweet_entity_service_client::TESClient;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
pub struct SubscriptionHydrator {
pub tes_client: Arc<dyn TESClient + Send + Sync>,
}
impl SubscriptionHydrator {
pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
Self { tes_client }
}
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for SubscriptionHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(
&self,
_query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let client = &self.tes_client;
let tweet_ids = candidates.iter().map(|c| c.tweet_id).collect::<Vec<_>>();
let post_features = client.get_subscription_author_ids(tweet_ids.clone()).await;
let post_features = post_features.map_err(|e| e.to_string())?;
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
for tweet_id in tweet_ids {
let post_features = post_features.get(&tweet_id);
let subscription_author_id = post_features.and_then(|x| *x);
let hydrated = PostCandidate {
subscription_author_id,
..Default::default()
};
hydrated_candidates.push(hydrated);
}
Ok(hydrated_candidates)
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
candidate.subscription_author_id = hydrated.subscription_author_id;
}
}
================================================
FILE: home-mixer/candidate_hydrators/vf_candidate_hydrator.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use futures::future::join;
use std::collections::HashMap;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
use xai_twittercontext_proto::GetTwitterContextViewer;
use xai_twittercontext_proto::TwitterContextViewer;
use xai_visibility_filtering::models::FilteredReason;
use xai_visibility_filtering::vf_client::SafetyLevel;
use xai_visibility_filtering::vf_client::SafetyLevel::{TimelineHome, TimelineHomeRecommendations};
use xai_visibility_filtering::vf_client::VisibilityFilteringClient;
pub struct VFCandidateHydrator {
pub vf_client: Arc<dyn VisibilityFilteringClient + Send + Sync>,
}
impl VFCandidateHydrator {
pub async fn new(vf_client: Arc<dyn VisibilityFilteringClient + Send + Sync>) -> Self {
Self { vf_client }
}
async fn fetch_vf_results(
client: &Arc<dyn VisibilityFilteringClient + Send + Sync>,
tweet_ids: Vec<i64>,
safety_level: SafetyLevel,
for_user_id: i64,
context: Option<TwitterContextViewer>,
) -> Result<HashMap<i64, Option<FilteredReason>>, String> {
if tweet_ids.is_empty() {
return Ok(HashMap::new());
}
client
.get_result(tweet_ids, safety_level, for_user_id, context)
.await
.map_err(|e| e.to_string())
}
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for VFCandidateHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(
&self,
query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let context = query.get_viewer();
let user_id = query.user_id;
let client = &self.vf_client;
let mut in_network_ids = Vec::new();
let mut oon_ids = Vec::new();
for candidate in candidates.iter() {
if candidate.in_network.unwrap_or(false) {
in_network_ids.push(candidate.tweet_id);
} else {
oon_ids.push(candidate.tweet_id);
}
}
let in_network_future = Self::fetch_vf_results(
client,
in_network_ids,
TimelineHome,
user_id,
context.clone(),
);
let oon_future = Self::fetch_vf_results(
client,
oon_ids,
TimelineHomeRecommendations,
user_id,
context,
);
let (in_network_result, oon_result) = join(in_network_future, oon_future).await;
let mut result: HashMap<i64, Option<FilteredReason>> = HashMap::new();
result.extend(in_network_result?);
result.extend(oon_result?);
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
for candidate in candidates {
let visibility_reason = result.get(&candidate.tweet_id);
let visibility_reason = visibility_reason.unwrap_or(&None);
let hydrated = PostCandidate {
visibility_reason: visibility_reason.clone(),
..Default::default()
};
hydrated_candidates.push(hydrated);
}
Ok(hydrated_candidates)
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
candidate.visibility_reason = hydrated.visibility_reason;
}
}
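The VF hydrator splits tweet IDs by `in_network` so that each group is checked under its own safety level (`TimelineHome` or `TimelineHomeRecommendations`). Because of `unwrap_or(false)`, an unknown `in_network` value is routed to the out-of-network group. The two result maps are then merged and read back in candidate order, and a missing entry means "no filter reason". Below is a minimal sketch of the split and merge steps. `FilteredReason` is modeled as a `String`, which is an assumption, and `split_by_network` and `reasons_in_order` are hypothetical names:

```rust
use std::collections::HashMap;

/// Split `(tweet_id, in_network)` pairs into in-network and out-of-network IDs.
/// Unknown membership (`None`) counts as out-of-network.
pub fn split_by_network(cands: &[(i64, Option<bool>)]) -> (Vec<i64>, Vec<i64>) {
    let mut in_network_ids = Vec::new();
    let mut oon_ids = Vec::new();
    for &(id, in_network) in cands {
        if in_network.unwrap_or(false) {
            in_network_ids.push(id);
        } else {
            oon_ids.push(id);
        }
    }
    (in_network_ids, oon_ids)
}

/// Merge both result maps and return one reason per ID, in input order.
/// IDs absent from both maps get `None`.
pub fn reasons_in_order(
    ids: &[i64],
    in_network: HashMap<i64, Option<String>>,
    oon: HashMap<i64, Option<String>>,
) -> Vec<Option<String>> {
    let mut merged = in_network;
    merged.extend(oon);
    ids.iter().map(|id| merged.get(id).cloned().flatten()).collect()
}
```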
================================================
FILE: home-mixer/candidate_hydrators/video_duration_candidate_hydrator.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::candidate_features::MediaInfo;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::tweet_entity_service_client::TESClient;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::hydrator::Hydrator;
pub struct VideoDurationCandidateHydrator {
pub tes_client: Arc<dyn TESClient + Send + Sync>,
}
impl VideoDurationCandidateHydrator {
pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
Self { tes_client }
}
}
#[async_trait]
impl Hydrator<ScoredPostsQuery, PostCandidate> for VideoDurationCandidateHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(
&self,
_query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let client = &self.tes_client;
let tweet_ids = candidates.iter().map(|c| c.tweet_id).collect::<Vec<_>>();
let post_features = client.get_tweet_media_entities(tweet_ids.clone()).await;
let post_features = post_features.map_err(|e| e.to_string())?;
let mut hydrated_candidates = Vec::with_capacity(candidates.len());
for tweet_id in tweet_ids {
let post_features = post_features.get(&tweet_id);
let media_entities = post_features.and_then(|x| x.as_ref());
let video_duration_ms = media_entities.and_then(|entities| {
entities.iter().find_map(|entity| {
if let Some(MediaInfo::VideoInfo(video_info)) = &entity.media_info {
Some(video_info.duration_millis)
} else {
None
}
})
});
let hydrated = PostCandidate {
video_duration_ms,
..Default::default()
};
hydrated_candidates.push(hydrated);
}
Ok(hydrated_candidates)
}
fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidate) {
candidate.video_duration_ms = hydrated.video_duration_ms;
}
}
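`find_map` scans a post's media entities and stops at the first one that yields `Some`, so only the first video's duration is used. Entities without `media_info`, or with non-video media, are skipped. The sketch below reproduces that lookup with local copies of the media types. The `Photo` variant is a hypothetical addition used only to show non-video media being skipped, since the crate's `MediaInfo` has just `VideoInfo`:

```rust
#[derive(Clone, Debug, PartialEq)]
pub struct VideoInfo {
    pub duration_millis: i32,
}

/// `Photo` is hypothetical; the crate's enum only has `VideoInfo`.
#[derive(Clone, Debug, PartialEq)]
pub enum MediaInfo {
    VideoInfo(VideoInfo),
    Photo,
}

#[derive(Clone, Debug, PartialEq)]
pub struct MediaEntity {
    pub media_info: Option<MediaInfo>,
}

/// Duration of the first video among a post's media entities, if any.
pub fn first_video_duration_ms(entities: Option<&[MediaEntity]>) -> Option<i32> {
    entities.and_then(|entities| {
        entities.iter().find_map(|entity| match &entity.media_info {
            Some(MediaInfo::VideoInfo(video)) => Some(video.duration_millis),
            _ => None,
        })
    })
}
```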
================================================
FILE: home-mixer/candidate_pipeline/candidate.rs
================================================
use std::collections::HashMap;
use xai_home_mixer_proto as pb;
use xai_visibility_filtering::models as vf;
#[derive(Clone, Debug, Default)]
pub struct PostCandidate {
pub tweet_id: i64,
pub author_id: u64,
pub tweet_text: String,
pub in_reply_to_tweet_id: Option<u64>,
pub retweeted_tweet_id: Option<u64>,
pub retweeted_user_id: Option<u64>,
pub phoenix_scores: PhoenixScores,
pub prediction_request_id: Option<u64>,
pub last_scored_at_ms: Option<u64>,
pub weighted_score: Option<f64>,
pub score: Option<f64>,
pub served_type: Option<pb::ServedType>,
pub in_network: Option<bool>,
pub ancestors: Vec<u64>,
pub video_duration_ms: Option<i32>,
pub author_followers_count: Option<i32>,
pub author_screen_name: Option<String>,
pub retweeted_screen_name: Option<String>,
pub visibility_reason: Option<vf::FilteredReason>,
pub subscription_author_id: Option<u64>,
}
#[derive(Clone, Debug, Default)]
pub struct PhoenixScores {
pub favorite_score: Option<f64>,
pub reply_score: Option<f64>,
pub retweet_score: Option<f64>,
pub photo_expand_score: Option<f64>,
pub click_score: Option<f64>,
pub profile_click_score: Option<f64>,
pub vqv_score: Option<f64>,
pub share_score: Option<f64>,
pub share_via_dm_score: Option<f64>,
pub share_via_copy_link_score: Option<f64>,
pub dwell_score: Option<f64>,
pub quote_score: Option<f64>,
pub quoted_click_score: Option<f64>,
pub follow_author_score: Option<f64>,
pub not_interested_score: Option<f64>,
pub block_author_score: Option<f64>,
pub mute_author_score: Option<f64>,
pub report_score: Option<f64>,
// Continuous actions
pub dwell_time: Option<f64>,
}
pub trait CandidateHelpers {
fn get_screen_names(&self) -> HashMap<u64, String>;
}
impl CandidateHelpers for PostCandidate {
fn get_screen_names(&self) -> HashMap<u64, String> {
let mut screen_names = HashMap::<u64, String>::new();
if let Some(author_screen_name) = self.author_screen_name.clone() {
screen_names.insert(self.author_id, author_screen_name);
}
if let (Some(retweeted_screen_name), Some(retweeted_user_id)) =
(self.retweeted_screen_name.clone(), self.retweeted_user_id)
{
screen_names.insert(retweeted_user_id, retweeted_screen_name);
}
screen_names
}
}
================================================
FILE: home-mixer/candidate_pipeline/candidate_features.rs
================================================
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct PureCoreData {
pub author_id: u64,
pub text: String,
pub source_tweet_id: Option<u64>,
pub source_user_id: Option<u64>,
pub in_reply_to_tweet_id: Option<u64>,
pub in_reply_to_user_id: Option<u64>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct ExclusiveTweetControl {
pub conversation_author_id: i64,
}
pub type MediaEntities = Vec<MediaEntity>;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct MediaEntity {
pub media_info: Option<MediaInfo>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub enum MediaInfo {
VideoInfo(VideoInfo),
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct VideoInfo {
pub duration_millis: i32,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct Share {
pub source_tweet_id: u64,
pub source_user_id: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct Reply {
pub in_reply_to_tweet_id: Option<u64>,
pub in_reply_to_user_id: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct GizmoduckUserCounts {
pub followers_count: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct GizmoduckUserProfile {
pub screen_name: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct GizmoduckUser {
pub user_id: u64,
pub profile: GizmoduckUserProfile,
pub counts: GizmoduckUserCounts,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct GizmoduckUserResult {
pub user: Option<GizmoduckUser>,
}
================================================
FILE: home-mixer/candidate_pipeline/mod.rs
================================================
pub mod candidate;
pub mod candidate_features;
pub mod phoenix_candidate_pipeline;
pub mod query;
pub mod query_features;
================================================
FILE: home-mixer/candidate_pipeline/phoenix_candidate_pipeline.rs
================================================
use crate::candidate_hydrators::core_data_candidate_hydrator::CoreDataCandidateHydrator;
use crate::candidate_hydrators::gizmoduck_hydrator::GizmoduckCandidateHydrator;
use crate::candidate_hydrators::in_network_candidate_hydrator::InNetworkCandidateHydrator;
use crate::candidate_hydrators::subscription_hydrator::SubscriptionHydrator;
use crate::candidate_hydrators::vf_candidate_hydrator::VFCandidateHydrator;
use crate::candidate_hydrators::video_duration_candidate_hydrator::VideoDurationCandidateHydrator;
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::gizmoduck_client::{GizmoduckClient, ProdGizmoduckClient};
use crate::clients::phoenix_prediction_client::{
PhoenixPredictionClient, ProdPhoenixPredictionClient,
};
use crate::clients::phoenix_retrieval_client::{
PhoenixRetrievalClient, ProdPhoenixRetrievalClient,
};
use crate::clients::s2s::{S2S_CHAIN_PATH, S2S_CRT_PATH, S2S_KEY_PATH};
use crate::clients::socialgraph_client::SocialGraphClient;
use crate::clients::strato_client::{ProdStratoClient, StratoClient};
use crate::clients::thunder_client::ThunderClient;
use crate::clients::tweet_entity_service_client::{ProdTESClient, TESClient};
use crate::clients::uas_fetcher::UserActionSequenceFetcher;
use crate::filters::age_filter::AgeFilter;
use crate::filters::author_socialgraph_filter::AuthorSocialgraphFilter;
use crate::filters::core_data_hydration_filter::CoreDataHydrationFilter;
use crate::filters::dedup_conversation_filter::DedupConversationFilter;
use crate::filters::drop_duplicates_filter::DropDuplicatesFilter;
use crate::filters::ineligible_subscription_filter::IneligibleSubscriptionFilter;
use crate::filters::muted_keyword_filter::MutedKeywordFilter;
use crate::filters::previously_seen_posts_filter::PreviouslySeenPostsFilter;
use crate::filters::previously_served_posts_filter::PreviouslyServedPostsFilter;
use crate::filters::retweet_deduplication_filter::RetweetDeduplicationFilter;
use crate::filters::self_tweet_filter::SelfTweetFilter;
use crate::filters::vf_filter::VFFilter;
use crate::params;
use crate::query_hydrators::user_action_seq_query_hydrator::UserActionSeqQueryHydrator;
use crate::query_hydrators::user_features_query_hydrator::UserFeaturesQueryHydrator;
use crate::scorers::author_diversity_scorer::AuthorDiversityScorer;
use crate::scorers::oon_scorer::OONScorer;
use crate::scorers::phoenix_scorer::PhoenixScorer;
use crate::scorers::weighted_scorer::WeightedScorer;
use crate::selectors::TopKScoreSelector;
use crate::side_effects::cache_request_info_side_effect::CacheRequestInfoSideEffect;
use crate::sources::phoenix_source::PhoenixSource;
use crate::sources::thunder_source::ThunderSource;
use std::sync::Arc;
use std::time::Duration;
use tonic::async_trait;
use xai_candidate_pipeline::candidate_pipeline::CandidatePipeline;
use xai_candidate_pipeline::filter::Filter;
use xai_candidate_pipeline::hydrator::Hydrator;
use xai_candidate_pipeline::query_hydrator::QueryHydrator;
use xai_candidate_pipeline::scorer::Scorer;
use xai_candidate_pipeline::selector::Selector;
use xai_candidate_pipeline::side_effect::SideEffect;
use xai_candidate_pipeline::source::Source;
use xai_visibility_filtering::vf_client::{
ProdVisibilityFilteringClient, VisibilityFilteringClient,
};
pub struct PhoenixCandidatePipeline {
query_hydrators: Vec<Box<dyn QueryHydrator<ScoredPostsQuery>>>,
sources: Vec<Box<dyn Source<ScoredPostsQuery, PostCandidate>>>,
hydrators: Vec<Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>>,
filters: Vec<Box<dyn Filter<ScoredPostsQuery, PostCandidate>>>,
scorers: Vec<Box<dyn Scorer<ScoredPostsQuery, PostCandidate>>>,
selector: TopKScoreSelector,
post_selection_hydrators: Vec<Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>>,
post_selection_filters: Vec<Box<dyn Filter<ScoredPostsQuery, PostCandidate>>>,
side_effects: Arc<Vec<Box<dyn SideEffect<ScoredPostsQuery, PostCandidate>>>>,
}
impl PhoenixCandidatePipeline {
async fn build_with_clients(
uas_fetcher: Arc<UserActionSequenceFetcher>,
phoenix_client: Arc<dyn PhoenixPredictionClient + Send + Sync>,
phoenix_retrieval_client: Arc<dyn PhoenixRetrievalClient + Send + Sync>,
thunder_client: Arc<ThunderClient>,
strato_client: Arc<dyn StratoClient + Send + Sync>,
tes_client: Arc<dyn TESClient + Send + Sync>,
gizmoduck_client: Arc<dyn GizmoduckClient + Send + Sync>,
vf_client: Arc<dyn VisibilityFilteringClient + Send + Sync>,
) -> PhoenixCandidatePipeline {
// Query Hydrators
let query_hydrators: Vec<Box<dyn QueryHydrator<ScoredPostsQuery>>> = vec![
Box::new(UserActionSeqQueryHydrator::new(uas_fetcher)),
Box::new(UserFeaturesQueryHydrator {
strato_client: strato_client.clone(),
}),
];
// Sources
let phoenix_source = Box::new(PhoenixSource {
phoenix_retrieval_client,
});
let thunder_source = Box::new(ThunderSource { thunder_client });
let sources: Vec<Box<dyn Source<ScoredPostsQuery, PostCandidate>>> =
vec![phoenix_source, thunder_source];
// Hydrators
let hydrators: Vec<Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>> = vec![
Box::new(InNetworkCandidateHydrator),
Box::new(CoreDataCandidateHydrator::new(tes_client.clone()).await),
Box::new(VideoDurationCandidateHydrator::new(tes_client.clone()).await),
Box::new(SubscriptionHydrator::new(tes_client.clone()).await),
Box::new(GizmoduckCandidateHydrator::new(gizmoduck_client).await),
];
// Filters
let filters: Vec<Box<dyn Filter<ScoredPostsQuery, PostCandidate>>> = vec![
Box::new(DropDuplicatesFilter),
Box::new(CoreDataHydrationFilter),
Box::new(AgeFilter::new(Duration::from_secs(params::MAX_POST_AGE))),
Box::new(SelfTweetFilter),
Box::new(RetweetDeduplicationFilter),
Box::new(IneligibleSubscriptionFilter),
Box::new(PreviouslySeenPostsFilter),
Box::new(PreviouslyServedPostsFilter),
Box::new(MutedKeywordFilter::new()),
Box::new(AuthorSocialgraphFilter),
];
// Scorers
let phoenix_scorer = Box::new(PhoenixScorer { phoenix_client });
let weighted_scorer = Box::new(WeightedScorer);
let author_diversity_scorer = Box::new(AuthorDiversityScorer::default());
let oon_scorer = Box::new(OONScorer);
let scorers: Vec<Box<dyn Scorer<ScoredPostsQuery, PostCandidate>>> = vec![
phoenix_scorer,
weighted_scorer,
author_diversity_scorer,
oon_scorer,
];
// Selector
let selector = TopKScoreSelector;
// Post-selection hydrators
let post_selection_hydrators: Vec<Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>> =
vec![Box::new(VFCandidateHydrator::new(vf_client.clone()).await)];
// Post-selection filters
let post_selection_filters: Vec<Box<dyn Filter<ScoredPostsQuery, PostCandidate>>> =
vec![Box::new(VFFilter), Box::new(DedupConversationFilter)];
// Side Effects
let side_effects: Arc<Vec<Box<dyn SideEffect<ScoredPostsQuery, PostCandidate>>>> =
Arc::new(vec![Box::new(CacheRequestInfoSideEffect { strato_client })]);
PhoenixCandidatePipeline {
query_hydrators,
hydrators,
filters,
sources,
scorers,
selector,
post_selection_hydrators,
post_selection_filters,
side_effects,
}
}
pub async fn prod() -> PhoenixCandidatePipeline {
let uas_fetcher =
Arc::new(UserActionSequenceFetcher::new().expect("Failed to create UAS fetcher"));
let _sgs_client = Arc::new(SocialGraphClient::new());
let phoenix_client = Arc::new(
ProdPhoenixPredictionClient::new()
.await
.expect("Failed to create Phoenix prediction client"),
);
let phoenix_retrieval_client = Arc::new(
ProdPhoenixRetrievalClient::new()
.await
.expect("Failed to create Phoenix retrieval client"),
);
let thunder_client = Arc::new(ThunderClient::new().await);
let strato_client = Arc::new(
ProdStratoClient::new()
.await
.expect("Failed to create Strato client"),
);
let tes_client = Arc::new(
ProdTESClient::new()
.await
.expect("Failed to create TES client"),
);
let gizmoduck_client = Arc::new(
ProdGizmoduckClient::new()
.await
.expect("Failed to create Gizmoduck client"),
);
let vf_client = Arc::new(
ProdVisibilityFilteringClient::new(
S2S_CHAIN_PATH.clone(),
S2S_CRT_PATH.clone(),
S2S_KEY_PATH.clone()
)
.await
.expect("Failed to create VF client"),
);
PhoenixCandidatePipeline::build_with_clients(
uas_fetcher,
phoenix_client,
phoenix_retrieval_client,
thunder_client,
strato_client,
tes_client,
gizmoduck_client,
vf_client,
)
.await
}
}
#[async_trait]
impl CandidatePipeline<ScoredPostsQuery, PostCandidate> for PhoenixCandidatePipeline {
fn query_hydrators(&self) -> &[Box<dyn QueryHydrator<ScoredPostsQuery>>] {
&self.query_hydrators
}
fn sources(&self) -> &[Box<dyn Source<ScoredPostsQuery, PostCandidate>>] {
&self.sources
}
fn hydrators(&self) -> &[Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>] {
&self.hydrators
}
fn filters(&self) -> &[Box<dyn Filter<ScoredPostsQuery, PostCandidate>>] {
&self.filters
}
fn scorers(&self) -> &[Box<dyn Scorer<ScoredPostsQuery, PostCandidate>>] {
&self.scorers
}
fn selector(&self) -> &dyn Selector<ScoredPostsQuery, PostCandidate> {
&self.selector
}
fn post_selection_hydrators(&self) -> &[Box<dyn Hydrator<ScoredPostsQuery, PostCandidate>>] {
&self.post_selection_hydrators
}
fn post_selection_filters(&self) -> &[Box<dyn Filter<ScoredPostsQuery, PostCandidate>>] {
&self.post_selection_filters
}
fn side_effects(&self) -> Arc<Vec<Box<dyn SideEffect<ScoredPostsQuery, PostCandidate>>>> {
Arc::clone(&self.side_effects)
}
fn result_size(&self) -> usize {
params::RESULT_SIZE
}
}
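The `CandidatePipeline` impl above only hands each stage to the executor. The executor itself lives in `candidate-pipeline/candidate_pipeline.rs`, which is not shown in this section. Below is a simplified, synchronous sketch of the stage order the accessors suggest: sources, then filters, then scoring, then ranked selection, then post-selection filters, then truncation to `result_size`. The `Candidate` type and the function-pointer stages are hypothetical stand-ins. Hydration, async execution, error handling and side effects are all omitted, so this is not the real executor.

```rust
// Hypothetical, simplified model of the stage order implied by the
// CandidatePipeline accessors. Not the real executor.

#[derive(Clone, Debug, PartialEq)]
pub struct Candidate {
    pub tweet_id: i64,
    pub score: Option<f64>,
}

type FilterFn = fn(&Candidate) -> bool; // true = keep
type ScoreFn = fn(&Candidate) -> f64;

pub fn run_pipeline(
    sources: &[Vec<Candidate>],
    filters: &[FilterFn],
    scorer: ScoreFn,
    post_selection_filters: &[FilterFn],
    result_size: usize,
) -> Vec<Candidate> {
    // 1. Gather candidates from every source.
    let mut candidates: Vec<Candidate> = sources.iter().flatten().cloned().collect();

    // 2. Apply pre-scoring filters in order.
    for keep in filters {
        candidates.retain(|c| keep(c));
    }

    // 3. Score every remaining candidate.
    for c in candidates.iter_mut() {
        c.score = Some(scorer(c));
    }

    // 4. Select: rank by score, highest first (unscored sorts last).
    candidates.sort_by(|a, b| {
        b.score
            .unwrap_or(f64::NEG_INFINITY)
            .total_cmp(&a.score.unwrap_or(f64::NEG_INFINITY))
    });

    // 5. Post-selection filters (e.g. visibility) run on the ranked list.
    for keep in post_selection_filters {
        candidates.retain(|c| keep(c));
    }

    // 6. Truncate to the configured result size.
    candidates.truncate(result_size);
    candidates
}
```

Running the post-selection filters after ranking means expensive checks, such as visibility filtering, only touch the short list that survived selection.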
================================================
FILE: home-mixer/candidate_pipeline/query.rs
================================================
use crate::candidate_pipeline::query_features::UserFeatures;
use crate::util::request_util::generate_request_id;
use xai_candidate_pipeline::candidate_pipeline::HasRequestId;
use xai_home_mixer_proto::ImpressionBloomFilterEntry;
use xai_twittercontext_proto::{GetTwitterContextViewer, TwitterContextViewer};
#[derive(Clone, Default, Debug)]
pub struct ScoredPostsQuery {
pub user_id: i64,
pub client_app_id: i32,
pub country_code: String,
pub language_code: String,
pub seen_ids: Vec<i64>,
pub served_ids: Vec<i64>,
pub in_network_only: bool,
pub is_bottom_request: bool,
pub bloom_filter_entries: Vec<ImpressionBloomFilterEntry>,
pub user_action_sequence: Option<xai_recsys_proto::UserActionSequence>,
pub user_features: UserFeatures,
pub request_id: String,
}
impl ScoredPostsQuery {
pub fn new(
user_id: i64,
client_app_id: i32,
country_code: String,
language_code: String,
seen_ids: Vec<i64>,
served_ids: Vec<i64>,
in_network_only: bool,
is_bottom_request: bool,
bloom_filter_entries: Vec<ImpressionBloomFilterEntry>,
) -> Self {
let request_id = format!("{}-{}", generate_request_id(), user_id);
Self {
user_id,
client_app_id,
country_code,
language_code,
seen_ids,
served_ids,
in_network_only,
is_bottom_request,
bloom_filter_entries,
user_action_sequence: None,
user_features: UserFeatures::default(),
request_id,
}
}
}
impl GetTwitterContextViewer for ScoredPostsQuery {
fn get_viewer(&self) -> Option<TwitterContextViewer> {
Some(TwitterContextViewer {
user_id: self.user_id,
client_application_id: self.client_app_id as i64,
request_country_code: self.country_code.clone(),
request_language_code: self.language_code.clone(),
..Default::default()
})
}
}
impl HasRequestId for ScoredPostsQuery {
fn request_id(&self) -> &str {
&self.request_id
}
}
================================================
FILE: home-mixer/candidate_pipeline/query_features.rs
================================================
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct UserFeatures {
pub muted_keywords: Vec<String>,
pub blocked_user_ids: Vec<i64>,
pub muted_user_ids: Vec<i64>,
pub followed_user_ids: Vec<i64>,
pub subscribed_user_ids: Vec<i64>,
}
================================================
FILE: home-mixer/filters/age_filter.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::util::snowflake;
use std::time::Duration;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
/// Filter that removes tweets older than a specified duration.
pub struct AgeFilter {
pub max_age: Duration,
}
impl AgeFilter {
pub fn new(max_age: Duration) -> Self {
Self { max_age }
}
fn is_within_age(&self, tweet_id: i64) -> bool {
snowflake::duration_since_creation_opt(tweet_id)
.map(|age| age <= self.max_age)
.unwrap_or(false)
}
}
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for AgeFilter {
async fn filter(
&self,
_query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let (kept, removed): (Vec<_>, Vec<_>) = candidates
.into_iter()
.partition(|c| self.is_within_age(c.tweet_id));
Ok(FilterResult { kept, removed })
}
}
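`AgeFilter` relies on `snowflake::duration_since_creation_opt`, and the `util` module that provides it is excluded from the release. The sketch below assumes IDs follow the standard Twitter snowflake layout, where the bits above the low 22 hold milliseconds since the Twitter epoch `1288834974657`. Under that assumption, the age check can be written as follows. The explicit `now_ms` parameter is a hypothetical addition that keeps the function deterministic. The real helper presumably reads the system clock.

```rust
use std::time::Duration;

// Standard Twitter snowflake epoch (2010-11-04T01:42:54.657Z), in Unix ms.
// Assumption: the excluded `snowflake` module uses the same layout.
const TWITTER_EPOCH_MS: i64 = 1_288_834_974_657;
const TIMESTAMP_SHIFT: u32 = 22;

/// Creation time (Unix ms) encoded in a snowflake ID; None for invalid IDs.
fn creation_time_ms(tweet_id: i64) -> Option<i64> {
    if tweet_id <= 0 {
        return None;
    }
    Some((tweet_id >> TIMESTAMP_SHIFT) + TWITTER_EPOCH_MS)
}

/// Age of the post relative to `now_ms`.
fn duration_since_creation_opt(tweet_id: i64, now_ms: i64) -> Option<Duration> {
    let created_ms = creation_time_ms(tweet_id)?;
    let age_ms = now_ms.checked_sub(created_ms)?;
    // A negative age means the ID is "from the future"; treat it as unknown.
    if age_ms < 0 {
        None
    } else {
        Some(Duration::from_millis(age_ms as u64))
    }
}

/// Mirrors `AgeFilter::is_within_age`: an unknown age counts as too old.
fn is_within_age(tweet_id: i64, now_ms: i64, max_age: Duration) -> bool {
    duration_since_creation_opt(tweet_id, now_ms)
        .map(|age| age <= max_age)
        .unwrap_or(false)
}
```

As in `AgeFilter`, the `unwrap_or(false)` fails closed: a candidate whose age cannot be determined is removed rather than kept.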
================================================
FILE: home-mixer/filters/author_socialgraph_filter.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
/// Removes candidates whose author is blocked or muted by the viewer.
pub struct AuthorSocialgraphFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for AuthorSocialgraphFilter {
async fn filter(
&self,
query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let viewer_blocked_user_ids = &query.user_features.blocked_user_ids;
let viewer_muted_user_ids = &query.user_features.muted_user_ids;
if viewer_blocked_user_ids.is_empty() && viewer_muted_user_ids.is_empty() {
return Ok(FilterResult {
kept: candidates,
removed: Vec::new(),
});
}
let mut kept: Vec<PostCandidate> = Vec::new();
let mut removed: Vec<PostCandidate> = Vec::new();
for candidate in candidates {
let author_id = candidate.author_id as i64;
let muted = viewer_muted_user_ids.contains(&author_id);
let blocked = viewer_blocked_user_ids.contains(&author_id);
if muted || blocked {
removed.push(candidate);
} else {
kept.push(candidate);
}
}
Ok(FilterResult { kept, removed })
}
}
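`AuthorSocialgraphFilter` checks each author with `Vec::contains`. Each lookup is therefore linear in the length of the viewer's block and mute lists. For viewers with long lists, one alternative is to build a single `HashSet` once per request, which gives average O(1) lookups. The following standalone sketch uses a hypothetical `Post` type that carries only the fields needed here.

```rust
use std::collections::HashSet;

// Hypothetical stand-in for PostCandidate.
#[derive(Debug, Clone, PartialEq)]
pub struct Post {
    pub tweet_id: i64,
    pub author_id: u64,
}

/// Splits posts into (kept, removed), removing those whose author the viewer
/// has blocked or muted. The exclusion set is built once per call.
pub fn filter_blocked_or_muted(
    posts: Vec<Post>,
    blocked_user_ids: &[i64],
    muted_user_ids: &[i64],
) -> (Vec<Post>, Vec<Post>) {
    // Fast path, as in the original filter.
    if blocked_user_ids.is_empty() && muted_user_ids.is_empty() {
        return (posts, Vec::new());
    }
    let excluded: HashSet<i64> = blocked_user_ids
        .iter()
        .chain(muted_user_ids)
        .copied()
        .collect();
    // partition puts `true` results in the first Vec (kept).
    posts
        .into_iter()
        .partition(|p| !excluded.contains(&(p.author_id as i64)))
}
```

Merging the two lists into a single set is safe here because the filter treats blocked and muted authors the same way.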
================================================
FILE: home-mixer/filters/core_data_hydration_filter.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
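/// Removes candidates whose core data did not hydrate: the author ID is 0 or
/// the post text is empty or whitespace-only.
pub struct CoreDataHydrationFilter;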
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for CoreDataHydrationFilter {
async fn filter(
&self,
_query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let (kept, removed) = candidates
.into_iter()
.partition(|c| c.author_id != 0 && !c.tweet_text.trim().is_empty());
Ok(FilterResult { kept, removed })
}
}
================================================
FILE: home-mixer/filters/dedup_conversation_filter.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use std::collections::HashMap;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
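/// Keeps only the highest-scored candidate per conversation. Conversations are keyed by
/// the smallest ancestor ID (the root, when the full ancestor chain is present); ties keep
/// the earlier candidate.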
pub struct DedupConversationFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for DedupConversationFilter {
async fn filter(
&self,
_query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let mut kept: Vec<PostCandidate> = Vec::new();
let mut removed: Vec<PostCandidate> = Vec::new();
let mut best_per_convo: HashMap<u64, (usize, f64)> = HashMap::new();
for candidate in candidates {
let conversation_id = get_conversation_id(&candidate);
let score = candidate.score.unwrap_or(0.0);
if let Some((kept_idx, best_score)) = best_per_convo.get_mut(&conversation_id) {
if score > *best_score {
let previous = std::mem::replace(&mut kept[*kept_idx], candidate);
removed.push(previous);
*best_score = score;
} else {
removed.push(candidate);
}
} else {
let idx = kept.len();
best_per_convo.insert(conversation_id, (idx, score));
kept.push(candidate);
}
}
Ok(FilterResult { kept, removed })
}
}
fn get_conversation_id(candidate: &PostCandidate) -> u64 {
candidate
.ancestors
.iter()
.copied()
.min()
.unwrap_or(candidate.tweet_id as u64)
}
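The following standalone sketch reproduces `DedupConversationFilter`'s logic on a hypothetical `Post` type. It is useful for reasoning about edge cases. A post with no ancestors is its own conversation root, so a root tweet and a reply to it share a key. When a later post in the same conversation has a higher score, it takes over the slot of the previously kept candidate. As a result, `kept` preserves the position where each conversation first appeared.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for PostCandidate with only the fields used here.
#[derive(Debug, Clone, PartialEq)]
pub struct Post {
    pub tweet_id: i64,
    pub ancestors: Vec<u64>,
    pub score: Option<f64>,
}

/// Conversation key: the smallest ancestor ID, or the post's own ID.
fn conversation_id(post: &Post) -> u64 {
    post.ancestors.iter().copied().min().unwrap_or(post.tweet_id as u64)
}

/// Returns (kept, removed), keeping the highest-scored post per conversation.
pub fn dedup_conversations(posts: Vec<Post>) -> (Vec<Post>, Vec<Post>) {
    let mut kept: Vec<Post> = Vec::new();
    let mut removed: Vec<Post> = Vec::new();
    // conversation id -> (index into `kept`, best score so far)
    let mut best: HashMap<u64, (usize, f64)> = HashMap::new();
    for post in posts {
        let convo = conversation_id(&post);
        let score = post.score.unwrap_or(0.0);
        if let Some((idx, best_score)) = best.get_mut(&convo) {
            if score > *best_score {
                // Swap in place so the conversation keeps its original slot.
                removed.push(std::mem::replace(&mut kept[*idx], post));
                *best_score = score;
            } else {
                removed.push(post);
            }
        } else {
            best.insert(convo, (kept.len(), score));
            kept.push(post);
        }
    }
    (kept, removed)
}
```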
================================================
FILE: home-mixer/filters/drop_duplicates_filter.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use std::collections::HashSet;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
/// Removes candidates with a repeated tweet ID, keeping the first occurrence.
pub struct DropDuplicatesFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for DropDuplicatesFilter {
async fn filter(
&self,
_query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let mut seen_ids = HashSet::new();
let mut kept = Vec::new();
let mut removed = Vec::new();
for candidate in candidates {
if seen_ids.insert(candidate.tweet_id) {
kept.push(candidate);
} else {
removed.push(candidate);
}
}
Ok(FilterResult { kept, removed })
}
}
================================================
FILE: home-mixer/filters/ineligible_subscription_filter.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use std::collections::HashSet;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
/// Filters out subscription-only posts from authors the viewer is not subscribed to.
pub struct IneligibleSubscriptionFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for IneligibleSubscriptionFilter {
async fn filter(
&self,
query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let subscribed_user_ids: HashSet<u64> = query
.user_features
.subscribed_user_ids
.iter()
.map(|id| *id as u64)
.collect();
let (kept, removed): (Vec<_>, Vec<_>) =
candidates
.into_iter()
.partition(|candidate| match candidate.subscription_author_id {
Some(author_id) => subscribed_user_ids.contains(&author_id),
None => true,
});
Ok(FilterResult { kept, removed })
}
}
================================================
FILE: home-mixer/filters/mod.rs
================================================
pub mod age_filter;
pub mod author_socialgraph_filter;
pub mod core_data_hydration_filter;
pub mod dedup_conversation_filter;
pub mod drop_duplicates_filter;
pub mod ineligible_subscription_filter;
pub mod muted_keyword_filter;
pub mod previously_seen_posts_filter;
pub mod previously_served_posts_filter;
pub mod retweet_deduplication_filter;
pub mod self_tweet_filter;
pub mod vf_filter;
================================================
FILE: home-mixer/filters/muted_keyword_filter.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
use xai_post_text::{MatchTweetGroup, TokenSequence, TweetTokenizer, UserMutes};
pub struct MutedKeywordFilter {
pub tokenizer: Arc<TweetTokenizer>,
}
impl MutedKeywordFilter {
pub fn new() -> Self {
let tokenizer = TweetTokenizer::new();
Self {
tokenizer: Arc::new(tokenizer),
}
}
}
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for MutedKeywordFilter {
#[xai_stats_macro::receive_stats]
async fn filter(
&self,
query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let muted_keywords = &query.user_features.muted_keywords;
if muted_keywords.is_empty() {
return Ok(FilterResult {
kept: candidates,
removed: vec![],
});
}
let token_sequences: Vec<TokenSequence> = muted_keywords
.iter()
.map(|k| self.tokenizer.tokenize(k))
.collect();
let user_mutes = UserMutes::new(token_sequences);
let matcher = MatchTweetGroup::new(user_mutes);
let mut kept = Vec::new();
let mut removed = Vec::new();
for candidate in candidates {
let tweet_text_token_sequence = self.tokenizer.tokenize(&candidate.tweet_text);
if matcher.matches(&tweet_text_token_sequence) {
// Matches muted keywords - should be removed/filtered out
removed.push(candidate);
} else {
// Does not match muted keywords - keep it
kept.push(candidate);
}
}
Ok(FilterResult { kept, removed })
}
}
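`TweetTokenizer` and `MatchTweetGroup` come from `xai_post_text`, which is not part of this release, so their exact semantics are unknown. The sketch below is only a rough illustration of phrase-level keyword muting. It makes two assumptions. First, tokenization lowercases the text and splits on non-alphanumeric characters. Second, a muted phrase matches when its tokens appear contiguously in the post's tokens. The real matcher may behave differently, for example in its Unicode, hashtag or stemming rules.

```rust
/// Simplified tokenizer (assumption): lowercase, split on anything that is
/// not alphanumeric, drop empty tokens.
fn tokenize(text: &str) -> Vec<String> {
    text.split(|c: char| !c.is_alphanumeric())
        .filter(|t| !t.is_empty())
        .map(|t| t.to_lowercase())
        .collect()
}

/// True if `phrase` occurs as a contiguous run of tokens in `text`.
fn contains_phrase(text: &[String], phrase: &[String]) -> bool {
    // `windows(0)` would panic, so empty phrases never match.
    !phrase.is_empty() && text.windows(phrase.len()).any(|w| w == phrase)
}

/// True if any muted keyword or phrase matches the post text.
pub fn matches_muted(post_text: &str, muted_keywords: &[&str]) -> bool {
    let tokens = tokenize(post_text);
    muted_keywords
        .iter()
        .map(|k| tokenize(k))
        .any(|phrase| contains_phrase(&tokens, &phrase))
}
```

Matching on whole tokens instead of raw substrings prevents a muted keyword from also matching longer words that contain it, such as "spoiler" inside "spoilerfree".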
================================================
FILE: home-mixer/filters/previously_seen_posts_filter.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::util::bloom_filter::BloomFilter;
use crate::util::candidates_util::get_related_post_ids;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
/// Removes posts the viewer has already seen: a candidate is dropped if any of its related
/// post IDs is in the client-supplied `seen_ids` or may be present in one of the request's
/// impression Bloom filters.
pub struct PreviouslySeenPostsFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for PreviouslySeenPostsFilter {
async fn filter(
&self,
query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let bloom_filters = query
.bloom_filter_entries
.iter()
.map(BloomFilter::from_entry)
.collect::<Vec<_>>();
let (removed, kept): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
get_related_post_ids(c).iter().any(|&post_id| {
query.seen_ids.contains(&post_id)
|| bloom_filters
.iter()
.any(|filter| filter.may_contain(post_id))
})
});
Ok(FilterResult { kept, removed })
}
}
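`BloomFilter::from_entry` and `may_contain` live in the excluded `util` module. The property this filter depends on is that a Bloom filter has no false negatives: every inserted ID reports `true`. It may, however, return false positives, so a small fraction of unseen posts can be dropped as "seen". Below is a minimal sketch of the data structure. It uses double hashing over `std`'s `DefaultHasher`. The hash functions, bit layout and sizing are assumptions, and the sketch does not match the serialized `ImpressionBloomFilterEntry` format.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Minimal Bloom filter over i64 post IDs (illustrative only).
pub struct SimpleBloomFilter {
    bits: Vec<u64>,
    num_bits: usize,
    num_hashes: u32,
}

impl SimpleBloomFilter {
    pub fn new(num_bits: usize, num_hashes: u32) -> Self {
        assert!(num_bits > 0 && num_hashes > 0);
        Self { bits: vec![0; (num_bits + 63) / 64], num_bits, num_hashes }
    }

    // Two 64-bit hashes derived by salting the input with different tags.
    fn hashes(id: i64) -> (u64, u64) {
        let mut h1 = DefaultHasher::new();
        (0u8, id).hash(&mut h1);
        let mut h2 = DefaultHasher::new();
        (1u8, id).hash(&mut h2);
        (h1.finish(), h2.finish() | 1) // odd step so probes spread out
    }

    // Bit positions via double hashing: h1 + i * h2 (mod num_bits).
    fn positions(&self, id: i64) -> impl Iterator<Item = usize> {
        let (h1, h2) = Self::hashes(id);
        let m = self.num_bits as u64;
        (0..self.num_hashes as u64)
            .map(move |i| (h1.wrapping_add(i.wrapping_mul(h2)) % m) as usize)
    }

    pub fn insert(&mut self, id: i64) {
        let positions: Vec<usize> = self.positions(id).collect();
        for p in positions {
            self.bits[p / 64] |= 1u64 << (p % 64);
        }
    }

    /// `false` means definitely not inserted; `true` means possibly inserted.
    pub fn may_contain(&self, id: i64) -> bool {
        self.positions(id)
            .all(|p| (self.bits[p / 64] & (1u64 << (p % 64))) != 0)
    }
}
```

`DefaultHasher` output is not guaranteed to be stable across Rust releases. That is acceptable within a single process, but a filter that is serialized and shipped between services, as the request's entries are, needs a fixed, specified hash function.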
================================================
FILE: home-mixer/filters/previously_served_posts_filter.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::util::candidates_util::get_related_post_ids;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
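/// On bottom requests only, removes candidates whose related post IDs were already
/// served, as listed in the request's `served_ids`.
pub struct PreviouslyServedPostsFilter;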
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for PreviouslyServedPostsFilter {
fn enable(&self, query: &ScoredPostsQuery) -> bool {
query.is_bottom_request
}
async fn filter(
&self,
query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let (removed, kept): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
get_related_post_ids(c)
.iter()
.any(|id| query.served_ids.contains(id))
});
Ok(FilterResult { kept, removed })
}
}
================================================
FILE: home-mixer/filters/retweet_deduplication_filter.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use std::collections::HashSet;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
/// Deduplicates retweets: a retweet is dropped if its source tweet already appeared earlier
/// in the list (as an original or as another retweet). Original tweets are always kept.
pub struct RetweetDeduplicationFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for RetweetDeduplicationFilter {
async fn filter(
&self,
_query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let mut seen_tweet_ids: HashSet<u64> = HashSet::new();
let mut kept = Vec::new();
let mut removed = Vec::new();
for candidate in candidates {
match candidate.retweeted_tweet_id {
Some(retweeted_id) => {
// Remove if we've already seen this tweet (as original or retweet)
if seen_tweet_ids.insert(retweeted_id) {
kept.push(candidate);
} else {
removed.push(candidate);
}
}
None => {
// Mark this original tweet ID as seen so retweets of it get filtered
seen_tweet_ids.insert(candidate.tweet_id as u64);
kept.push(candidate);
}
}
}
Ok(FilterResult { kept, removed })
}
}
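The standalone sketch below mirrors `RetweetDeduplicationFilter` on a hypothetical `Post` type and shows that the filter depends on candidate order. A retweet is removed only when its source tweet was already seen earlier in the list. Original tweets are always kept, even if a retweet of them came first. Duplicate originals are handled earlier in the chain by `DropDuplicatesFilter`.

```rust
use std::collections::HashSet;

// Hypothetical stand-in for PostCandidate.
#[derive(Debug, Clone, PartialEq)]
pub struct Post {
    pub tweet_id: i64,
    pub retweeted_tweet_id: Option<u64>,
}

/// Returns (kept, removed) using the same rules as RetweetDeduplicationFilter.
pub fn dedup_retweets(posts: Vec<Post>) -> (Vec<Post>, Vec<Post>) {
    let mut seen: HashSet<u64> = HashSet::new();
    let (mut kept, mut removed) = (Vec::new(), Vec::new());
    for post in posts {
        match post.retweeted_tweet_id {
            // Keep a retweet only the first time its source tweet appears.
            Some(source_id) => {
                if seen.insert(source_id) {
                    kept.push(post);
                } else {
                    removed.push(post);
                }
            }
            // Originals are always kept, but marked as seen for later retweets.
            None => {
                seen.insert(post.tweet_id as u64);
                kept.push(post);
            }
        }
    }
    (kept, removed)
}
```

In the test case below, original tweet 9 is kept even though a retweet of it (tweet 3) was kept earlier.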
================================================
FILE: home-mixer/filters/self_tweet_filter.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
/// Filter that removes tweets where the author is the viewer.
pub struct SelfTweetFilter;
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for SelfTweetFilter {
async fn filter(
&self,
query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let viewer_id = query.user_id as u64;
let (kept, removed): (Vec<_>, Vec<_>) = candidates
.into_iter()
.partition(|c| c.author_id != viewer_id);
Ok(FilterResult { kept, removed })
}
}
================================================
FILE: home-mixer/filters/vf_filter.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use tonic::async_trait;
use xai_candidate_pipeline::filter::{Filter, FilterResult};
use xai_visibility_filtering::models::{Action, FilteredReason};
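/// Removes candidates flagged by visibility filtering: a safety result with a `Drop`
/// action, or any other filtered reason.
pub struct VFFilter;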
#[async_trait]
impl Filter<ScoredPostsQuery, PostCandidate> for VFFilter {
#[xai_stats_macro::receive_stats]
async fn filter(
&self,
_query: &ScoredPostsQuery,
candidates: Vec<PostCandidate>,
) -> Result<FilterResult<PostCandidate>, String> {
let (removed, kept): (Vec<_>, Vec<_>) = candidates
.into_iter()
.partition(|c| should_drop(&c.visibility_reason));
Ok(FilterResult { kept, removed })
}
}
fn should_drop(reason: &Option<FilteredReason>) -> bool {
match reason {
Some(FilteredReason::SafetyResult(safety_result)) => {
matches!(safety_result.action, Action::Drop(_))
}
Some(_) => true,
None => false,
}
}
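`should_drop` handles filtered reasons in two ways. Any reason other than a safety result is treated as a drop. A safety result is a drop only when its action is `Action::Drop`. The `xai_visibility_filtering` types are not shown in this section, so the sketch below uses hypothetical stand-in enums to make that decision table explicit. The `Allow` action, the `Other` reason and the `String` payloads are invented for illustration.

```rust
// Hypothetical stand-ins for xai_visibility_filtering::models types.
#[derive(Debug, Clone, PartialEq)]
pub enum Action {
    Allow,
    Drop(String), // payload type is an assumption
}

#[derive(Debug, Clone, PartialEq)]
pub struct SafetyResult {
    pub action: Action,
}

#[derive(Debug, Clone, PartialEq)]
pub enum FilteredReason {
    SafetyResult(SafetyResult),
    Other(String), // stands in for any non-safety filtered reason
}

/// Same decision logic as VFFilter's `should_drop`.
pub fn should_drop(reason: &Option<FilteredReason>) -> bool {
    match reason {
        Some(FilteredReason::SafetyResult(r)) => matches!(r.action, Action::Drop(_)),
        Some(_) => true,
        None => false,
    }
}
```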
================================================
FILE: home-mixer/lib.rs
================================================
mod candidate_hydrators;
mod candidate_pipeline;
pub mod clients; // Excluded from open source release for security reasons
mod filters;
pub mod params; // Excluded from open source release for security reasons
mod query_hydrators;
pub mod scorers;
mod selectors;
mod server;
mod side_effects;
mod sources;
pub mod util; // Excluded from open source release for security reasons
pub use server::HomeMixerServer;
================================================
FILE: home-mixer/main.rs
================================================
use clap::Parser;
use log::info;
use std::time::Duration;
use tonic::codec::CompressionEncoding;
use tonic::service::RoutesBuilder;
use tonic_reflection::server::Builder;
use xai_home_mixer_proto as pb;
use xai_http_server::{CancellationToken, GrpcConfig, HttpServer};
use xai_home_mixer::HomeMixerServer;
use xai_home_mixer::params;
#[derive(Parser, Debug)]
#[command(about = "HomeMixer gRPC Server")]
struct Args {
#[arg(long)]
grpc_port: u16,
#[arg(long)]
metrics_port: u16,
#[arg(long)]
reload_interval_minutes: u64,
#[arg(long)]
chunk_size: usize,
}
#[xai_stats_macro::main(name = "home-mixer")]
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let args = Args::parse();
xai_init_utils::init().log();
xai_init_utils::init().rustls();
info!(
"Starting server with gRPC port: {}, metrics port: {}, reload interval: {} minutes, chunk size: {}",
args.grpc_port, args.metrics_port, args.reload_interval_minutes, args.chunk_size,
);
// Create the service implementation
let service = HomeMixerServer::new().await;
// Build the gRPC reflection service so clients can discover the API
let reflection_service = Builder::configure()
.register_encoded_file_descriptor_set(pb::FILE_DESCRIPTOR_SET)
.build_v1()?;
let mut grpc_routes = RoutesBuilder::default();
grpc_routes.add_service(
pb::scored_posts_service_server::ScoredPostsServiceServer::new(service)
.max_decoding_message_size(params::MAX_GRPC_MESSAGE_SIZE)
.max_encoding_message_size(params::MAX_GRPC_MESSAGE_SIZE)
.accept_compressed(CompressionEncoding::Gzip)
.accept_compressed(CompressionEncoding::Zstd)
.send_compressed(CompressionEncoding::Gzip)
.send_compressed(CompressionEncoding::Zstd),
);
grpc_routes.add_service(reflection_service);
let grpc_config = GrpcConfig::new(args.grpc_port, grpc_routes.routes());
let http_router = axum::Router::default();
let mut server = HttpServer::new(
args.metrics_port,
http_router,
Some(grpc_config),
CancellationToken::new(),
Duration::from_secs(20),
)
.await?;
server.set_readiness(true);
info!("Server ready");
server.wait_for_termination().await;
info!("Server shutdown complete");
Ok(())
}
================================================
FILE: home-mixer/query_hydrators/mod.rs
================================================
pub mod user_action_seq_query_hydrator;
pub mod user_features_query_hydrator;
================================================
FILE: home-mixer/query_hydrators/user_action_seq_query_hydrator.rs
================================================
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::uas_fetcher::{UserActionSequenceFetcher, UserActionSequenceOps};
use crate::params as p;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tonic::async_trait;
use xai_candidate_pipeline::query_hydrator::QueryHydrator;
use xai_recsys_aggregation::aggregation::{DefaultAggregator, UserActionAggregator};
use xai_recsys_aggregation::filters::{
AggregatedActionFilter, DenseAggregatedActionFilter, KeepOriginalUserActionFilter,
UserActionFilter,
};
use xai_recsys_proto::{
AggregatedUserActionList, Mask, MaskType, UserActionSequence, UserActionSequenceDataContainer,
UserActionSequenceMeta, user_action_sequence_data_container::Data as ProtoDataContainer,
};
use xai_uas_thrift::convert::thrift_to_proto_aggregated_user_action;
use xai_uas_thrift::user_action_sequence::{
AggregatedUserAction as ThriftAggregatedUserAction,
UserActionSequence as ThriftUserActionSequence,
UserActionSequenceMeta as ThriftUserActionSequenceMeta,
};
/// Hydrate a sequence that captures the user's recent actions
pub struct UserActionSeqQueryHydrator {
pub uas_fetcher: Arc<UserActionSequenceFetcher>,
global_filter: Arc<dyn UserActionFilter>,
aggregator: Arc<dyn UserActionAggregator>,
post_filters: Vec<Arc<dyn AggregatedActionFilter>>,
}
impl UserActionSeqQueryHydrator {
pub fn new(uas_fetcher: Arc<UserActionSequenceFetcher>) -> Self {
Self {
uas_fetcher,
global_filter: Arc::new(KeepOriginalUserActionFilter::new()),
aggregator: Arc::new(DefaultAggregator),
post_filters: vec![Arc::new(DenseAggregatedActionFilter::new())],
}
}
}
#[async_trait]
impl QueryHydrator<ScoredPostsQuery> for UserActionSeqQueryHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(&self, query: &ScoredPostsQuery) -> Result<ScoredPostsQuery, String> {
let uas_thrift = self
.uas_fetcher
.get_by_user_id(query.user_id)
.await
.map_err(|e| format!("Failed to fetch user action sequence: {}", e))?;
let aggregated_uas_proto =
self.aggregate_user_action_sequence(query.user_id, uas_thrift)?;
Ok(ScoredPostsQuery {
user_action_sequence: Some(aggregated_uas_proto),
..Default::default()
})
}
fn update(&self, query: &mut ScoredPostsQuery, hydrated: ScoredPostsQuery) {
query.user_action_sequence = hydrated.user_action_sequence;
}
fn name(&self) -> &'static str {
std::any::type_name::<Self>()
}
}
impl UserActionSeqQueryHydrator {
fn aggregate_user_action_sequence(
&self,
user_id: i64,
uas_thrift: ThriftUserActionSequence,
) -> Result<UserActionSequence, String> {
// Extract user_actions from thrift sequence
let thrift_user_actions = uas_thrift.user_actions.clone().unwrap_or_default();
if thrift_user_actions.is_empty() {
return Err(format!("No user actions found for user {}", user_id));
}
// Pre-aggregation filter
let filtered_actions = self.global_filter.run(thrift_user_actions);
if filtered_actions.is_empty() {
return Err(format!(
"No user actions remaining after filtering for user {}",
user_id
));
}
// Aggregate
let mut aggregated_actions =
self.aggregator
.run(&filtered_actions, p::UAS_WINDOW_TIME_MS, 0);
// Post-aggregation filters
for filter in &self.post_filters {
aggregated_actions = filter.run(aggregated_actions);
}
// Truncate to max sequence length (keep last N items)
if aggregated_actions.len() > p::UAS_MAX_SEQUENCE_LENGTH {
let drain_count = aggregated_actions.len() - p::UAS_MAX_SEQUENCE_LENGTH;
aggregated_actions.drain(0..drain_count);
}
// Convert to proto format
let original_metadata = uas_thrift.metadata.clone().unwrap_or_default();
convert_to_proto_sequence(
user_id,
original_metadata,
aggregated_actions,
self.aggregator.name(),
)
}
}
fn convert_to_proto_sequence(
user_id: i64,
original_metadata: ThriftUserActionSequenceMeta,
aggregated_actions: Vec<ThriftAggregatedUserAction>,
aggregator_name: &str,
) -> Result<UserActionSequence, String> {
if aggregated_actions.is_empty() {
return Err("Cannot create sequence from empty aggregated actions".to_string());
}
let first_sequence_time = aggregated_actions
.first()
.and_then(|a| a.impressed_time_ms)
.unwrap_or(0) as u64;
let last_sequence_time = aggregated_actions
.last()
.and_then(|a| a.impressed_time_ms)
.unwrap_or(0) as u64;
// Preserve lastModifiedEpochMs and lastKafkaPublishEpochMs from original metadata
let last_modified_epoch_ms = original_metadata.last_modified_epoch_ms.unwrap_or(0) as u64;
let previous_kafka_publish_epoch_ms =
original_metadata.last_kafka_publish_epoch_ms.unwrap_or(0) as u64;
let proto_metadata = UserActionSequenceMeta {
length: aggregated_actions.len() as u64,
first_sequence_time,
last_sequence_time,
last_modified_epoch_ms,
previous_kafka_publish_epoch_ms,
};
// Convert thrift aggregated actions to proto
let mut proto_agg_actions = Vec::with_capacity(aggregated_actions.len());
for action in aggregated_actions {
proto_agg_actions.push(
thrift_to_proto_aggregated_user_action(action)
.map_err(|e| format!("Failed to convert aggregated action: {}", e))?,
);
}
let aggregation_time_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
let agg_list = AggregatedUserActionList {
aggregated_user_actions: proto_agg_actions,
aggregation_provider: aggregator_name.to_string(),
aggregation_time_ms,
};
let mask = Mask {
mask_type: MaskType::NewEvent as i32,
mask: vec![false; agg_list.aggregated_user_actions.len()],
};
// Build the final UserActionSequence
Ok(UserActionSequence {
user_id: user_id as u64,
metadata: Some(proto_metadata),
user_actions_data: Some(UserActionSequenceDataContainer {
data: Some(ProtoDataContainer::OrderedAggregatedUserActionsList(
agg_list,
)),
}),
masks: vec![mask],
..Default::default()
})
}
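Two details of `aggregate_user_action_sequence` and `convert_to_proto_sequence` are easy to misread. First, truncation drains from the front of the vector. Assuming actions are in chronological order, the most recent `UAS_MAX_SEQUENCE_LENGTH` actions are the ones that survive. Second, `first_sequence_time` and `last_sequence_time` are read from the first and last surviving actions, and default to 0 when `impressed_time_ms` is missing. The standalone sketch below uses a hypothetical `Action` type and a caller-supplied `max_len` in place of the excluded `params` constant.

```rust
// Hypothetical stand-in for the thrift AggregatedUserAction.
#[derive(Debug, Clone, PartialEq)]
pub struct Action {
    pub impressed_time_ms: Option<i64>,
}

/// Keeps only the last `max_len` actions, as the hydrator does with `drain`.
pub fn keep_last(mut actions: Vec<Action>, max_len: usize) -> Vec<Action> {
    if actions.len() > max_len {
        let drain_count = actions.len() - max_len;
        actions.drain(0..drain_count);
    }
    actions
}

/// (first_sequence_time, last_sequence_time), with missing times as 0.
pub fn sequence_bounds(actions: &[Action]) -> (u64, u64) {
    let time = |a: Option<&Action>| a.and_then(|x| x.impressed_time_ms).unwrap_or(0) as u64;
    (time(actions.first()), time(actions.last()))
}
```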
================================================
FILE: home-mixer/query_hydrators/user_features_query_hydrator.rs
================================================
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::candidate_pipeline::query_features::UserFeatures;
use crate::clients::strato_client::StratoClient;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::query_hydrator::QueryHydrator;
use xai_strato::{StratoResult, StratoValue, decode};
pub struct UserFeaturesQueryHydrator {
pub strato_client: Arc<dyn StratoClient + Send + Sync>,
}
#[async_trait]
impl QueryHydrator<ScoredPostsQuery> for UserFeaturesQueryHydrator {
#[xai_stats_macro::receive_stats]
async fn hydrate(&self, query: &ScoredPostsQuery) -> Result<ScoredPostsQuery, String> {
let result = self
.strato_client
.get_user_features(query.user_id)
.await
.map_err(|e| e.to_string())?;
let decoded: StratoResult<StratoValue<UserFeatures>> = decode(&result);
match decoded {
StratoResult::Ok(v) => {
let user_features = v.v.unwrap_or_default();
Ok(ScoredPostsQuery {
user_features,
..Default::default()
})
}
StratoResult::Err(_) => Err("Error received from strato".to_string()),
}
}
fn update(&self, query: &mut ScoredPostsQuery, hydrated: ScoredPostsQuery) {
query.user_features = hydrated.user_features;
}
fn name(&self) -> &'static str {
std::any::type_name::<Self>()
}
}
================================================
FILE: home-mixer/scorers/author_diversity_scorer.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::params as p;
use std::cmp::Ordering;
use std::collections::HashMap;
use tonic::async_trait;
use xai_candidate_pipeline::scorer::Scorer;
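/// Diversifies authors within a single feed response: each additional post by the same author
/// has its weighted score multiplied by `(1 - floor) * decay_factor^position + floor`.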
pub struct AuthorDiversityScorer {
decay_factor: f64,
floor: f64,
}
impl Default for AuthorDiversityScorer {
fn default() -> Self {
Self::new(p::AUTHOR_DIVERSITY_DECAY, p::AUTHOR_DIVERSITY_FLOOR)
}
}
impl AuthorDiversityScorer {
pub fn new(decay_factor: f64, floor: f64) -> Self {
Self {
decay_factor,
floor,
}
}
fn multiplier(&self, position: usize) -> f64 {
(1.0 - self.floor) * self.decay_factor.powf(position as f64) + self.floor
}
}
#[async_trait]
impl Scorer<ScoredPostsQuery, PostCandidate> for AuthorDiversityScorer {
#[xai_stats_macro::receive_stats]
async fn score(
&self,
_query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let mut author_counts: HashMap<u64, usize> = HashMap::new();
let mut scored = vec![PostCandidate::default(); candidates.len()];
let mut ordered: Vec<(usize, &PostCandidate)> = candidates.iter().enumerate().collect();
ordered.sort_by(|(_, a), (_, b)| {
let a_score = a.weighted_score.unwrap_or(f64::NEG_INFINITY);
let b_score = b.weighted_score.unwrap_or(f64::NEG_INFINITY);
b_score.partial_cmp(&a_score).unwrap_or(Ordering::Equal)
});
for (original_idx, candidate) in ordered {
let entry = author_counts.entry(candidate.author_id).or_insert(0);
let position = *entry;
*entry += 1;
let multiplier = self.multiplier(position);
let adjusted_score = candidate.weighted_score.map(|score| score * multiplier);
let updated = PostCandidate {
score: adjusted_score,
..Default::default()
};
scored[original_idx] = updated;
}
Ok(scored)
}
fn update(&self, candidate: &mut PostCandidate, scored: PostCandidate) {
candidate.score = scored.score;
}
}
================================================
FILE: home-mixer/scorers/mod.rs
================================================
pub mod author_diversity_scorer;
pub mod oon_scorer;
pub mod phoenix_scorer;
pub mod weighted_scorer;
================================================
FILE: home-mixer/scorers/oon_scorer.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::params as p;
use tonic::async_trait;
use xai_candidate_pipeline::scorer::Scorer;
/// Prioritizes in-network candidates by scaling the score of out-of-network candidates
/// (`in_network == Some(false)`) by `OON_WEIGHT_FACTOR`.
pub struct OONScorer;
#[async_trait]
impl Scorer<ScoredPostsQuery, PostCandidate> for OONScorer {
async fn score(
&self,
_query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let scored = candidates
.iter()
.map(|c| {
let updated_score = c.score.map(|base_score| match c.in_network {
Some(false) => base_score * p::OON_WEIGHT_FACTOR,
_ => base_score,
});
PostCandidate {
score: updated_score,
..Default::default()
}
})
.collect();
Ok(scored)
}
fn update(&self, candidate: &mut PostCandidate, scored: PostCandidate) {
candidate.score = scored.score;
}
}
================================================
FILE: home-mixer/scorers/phoenix_scorer.rs
================================================
use crate::candidate_pipeline::candidate::{PhoenixScores, PostCandidate};
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::phoenix_prediction_client::PhoenixPredictionClient;
use crate::util::request_util;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tonic::async_trait;
use xai_candidate_pipeline::scorer::Scorer;
use xai_recsys_proto::{ActionName, ContinuousActionName};
pub struct PhoenixScorer {
pub phoenix_client: Arc<dyn PhoenixPredictionClient + Send + Sync>,
}
#[async_trait]
impl Scorer<ScoredPostsQuery, PostCandidate> for PhoenixScorer {
#[xai_stats_macro::receive_stats]
async fn score(
&self,
query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let user_id = query.user_id as u64;
let prediction_request_id = request_util::generate_request_id();
let last_scored_at_ms = Self::current_timestamp_millis();
if let Some(sequence) = &query.user_action_sequence {
let tweet_infos: Vec<xai_recsys_proto::TweetInfo> = candidates
.iter()
.map(|c| {
let tweet_id = c.retweeted_tweet_id.unwrap_or(c.tweet_id as u64);
let author_id = c.retweeted_user_id.unwrap_or(c.author_id);
xai_recsys_proto::TweetInfo {
tweet_id,
author_id,
..Default::default()
}
})
.collect();
let result = self
.phoenix_client
.predict(user_id, sequence.clone(), tweet_infos)
.await;
if let Ok(response) = result {
let predictions_map = self.build_predictions_map(&response);
let scored_candidates = candidates
.iter()
.map(|c| {
// For retweets, look up predictions using the original tweet id
let lookup_tweet_id = c.retweeted_tweet_id.unwrap_or(c.tweet_id as u64);
let phoenix_scores = predictions_map
.get(&lookup_tweet_id)
.map(|preds| self.extract_phoenix_scores(preds))
.unwrap_or_default();
PostCandidate {
phoenix_scores,
prediction_request_id: Some(prediction_request_id),
last_scored_at_ms,
..Default::default()
}
})
.collect();
return Ok(scored_candidates);
}
}
        // Return candidates unchanged if there is no user action sequence or the prediction call failed
Ok(candidates.to_vec())
}
fn update(&self, candidate: &mut PostCandidate, scored: PostCandidate) {
candidate.phoenix_scores = scored.phoenix_scores;
candidate.prediction_request_id = scored.prediction_request_id;
candidate.last_scored_at_ms = scored.last_scored_at_ms;
}
}
impl PhoenixScorer {
/// Builds Map[tweet_id -> ActionPredictions]
fn build_predictions_map(
&self,
response: &xai_recsys_proto::PredictNextActionsResponse,
) -> HashMap<u64, ActionPredictions> {
let mut predictions_map = HashMap::new();
let Some(distribution_set) = response.distribution_sets.first() else {
return predictions_map;
};
for distribution in &distribution_set.candidate_distributions {
let Some(candidate) = &distribution.candidate else {
continue;
};
let tweet_id = candidate.tweet_id;
let action_probs: HashMap<usize, f64> = distribution
.top_log_probs
.iter()
.enumerate()
.map(|(idx, log_prob)| (idx, (*log_prob as f64).exp()))
.collect();
let continuous_values: HashMap<usize, f64> = distribution
.continuous_actions_values
.iter()
.enumerate()
.map(|(idx, value)| (idx, *value as f64))
.collect();
predictions_map.insert(
tweet_id,
ActionPredictions {
action_probs,
continuous_values,
},
);
}
predictions_map
}
fn extract_phoenix_scores(&self, p: &ActionPredictions) -> PhoenixScores {
PhoenixScores {
favorite_score: p.get(ActionName::ServerTweetFav),
reply_score: p.get(ActionName::ServerTweetReply),
retweet_score: p.get(ActionName::ServerTweetRetweet),
photo_expand_score: p.get(ActionName::ClientTweetPhotoExpand),
click_score: p.get(ActionName::ClientTweetClick),
profile_click_score: p.get(ActionName::ClientTweetClickProfile),
vqv_score: p.get(ActionName::ClientTweetVideoQualityView),
share_score: p.get(ActionName::ClientTweetShare),
share_via_dm_score: p.get(ActionName::ClientTweetClickSendViaDirectMessage),
share_via_copy_link_score: p.get(ActionName::ClientTweetShareViaCopyLink),
dwell_score: p.get(ActionName::ClientTweetRecapDwelled),
quote_score: p.get(ActionName::ServerTweetQuote),
quoted_click_score: p.get(ActionName::ClientQuotedTweetClick),
follow_author_score: p.get(ActionName::ClientTweetFollowAuthor),
not_interested_score: p.get(ActionName::ClientTweetNotInterestedIn),
block_author_score: p.get(ActionName::ClientTweetBlockAuthor),
mute_author_score: p.get(ActionName::ClientTweetMuteAuthor),
report_score: p.get(ActionName::ClientTweetReport),
dwell_time: p.get_continuous(ContinuousActionName::DwellTime),
}
}
fn current_timestamp_millis() -> Option<u64> {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.ok()
.map(|duration| duration.as_millis() as u64)
}
}
struct ActionPredictions {
/// Map of action index -> probability (exp of log prob)
action_probs: HashMap<usize, f64>,
/// Map of continuous action index -> value
continuous_values: HashMap<usize, f64>,
}
impl ActionPredictions {
fn get(&self, action: ActionName) -> Option<f64> {
self.action_probs.get(&(action as usize)).copied()
}
fn get_continuous(&self, action: ContinuousActionName) -> Option<f64> {
self.continuous_values.get(&(action as usize)).copied()
}
}
================================================
FILE: home-mixer/scorers/weighted_scorer.rs
================================================
use crate::candidate_pipeline::candidate::{PhoenixScores, PostCandidate};
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::params as p;
use crate::util::score_normalizer::normalize_score;
use tonic::async_trait;
use xai_candidate_pipeline::scorer::Scorer;
pub struct WeightedScorer;
#[async_trait]
impl Scorer<ScoredPostsQuery, PostCandidate> for WeightedScorer {
#[xai_stats_macro::receive_stats]
async fn score(
&self,
_query: &ScoredPostsQuery,
candidates: &[PostCandidate],
) -> Result<Vec<PostCandidate>, String> {
let scored = candidates
.iter()
.map(|c| {
let weighted_score = Self::compute_weighted_score(c);
let normalized_weighted_score = normalize_score(c, weighted_score);
PostCandidate {
weighted_score: Some(normalized_weighted_score),
..Default::default()
}
})
.collect();
Ok(scored)
}
fn update(&self, candidate: &mut PostCandidate, scored: PostCandidate) {
candidate.weighted_score = scored.weighted_score;
}
}
impl WeightedScorer {
fn apply(score: Option<f64>, weight: f64) -> f64 {
score.unwrap_or(0.0) * weight
}
fn compute_weighted_score(candidate: &PostCandidate) -> f64 {
let s: &PhoenixScores = &candidate.phoenix_scores;
let vqv_weight = Self::vqv_weight_eligibility(candidate);
let combined_score = Self::apply(s.favorite_score, p::FAVORITE_WEIGHT)
+ Self::apply(s.reply_score, p::REPLY_WEIGHT)
+ Self::apply(s.retweet_score, p::RETWEET_WEIGHT)
+ Self::apply(s.photo_expand_score, p::PHOTO_EXPAND_WEIGHT)
+ Self::apply(s.click_score, p::CLICK_WEIGHT)
+ Self::apply(s.profile_click_score, p::PROFILE_CLICK_WEIGHT)
+ Self::apply(s.vqv_score, vqv_weight)
+ Self::apply(s.share_score, p::SHARE_WEIGHT)
+ Self::apply(s.share_via_dm_score, p::SHARE_VIA_DM_WEIGHT)
+ Self::apply(s.share_via_copy_link_score, p::SHARE_VIA_COPY_LINK_WEIGHT)
+ Self::apply(s.dwell_score, p::DWELL_WEIGHT)
+ Self::apply(s.quote_score, p::QUOTE_WEIGHT)
+ Self::apply(s.quoted_click_score, p::QUOTED_CLICK_WEIGHT)
+ Self::apply(s.dwell_time, p::CONT_DWELL_TIME_WEIGHT)
+ Self::apply(s.follow_author_score, p::FOLLOW_AUTHOR_WEIGHT)
+ Self::apply(s.not_interested_score, p::NOT_INTERESTED_WEIGHT)
+ Self::apply(s.block_author_score, p::BLOCK_AUTHOR_WEIGHT)
+ Self::apply(s.mute_author_score, p::MUTE_AUTHOR_WEIGHT)
+ Self::apply(s.report_score, p::REPORT_WEIGHT);
Self::offset_score(combined_score)
}
fn vqv_weight_eligibility(candidate: &PostCandidate) -> f64 {
if candidate
.video_duration_ms
.is_some_and(|ms| ms > p::MIN_VIDEO_DURATION_MS)
{
p::VQV_WEIGHT
} else {
0.0
}
}
fn offset_score(combined_score: f64) -> f64 {
if p::WEIGHTS_SUM == 0.0 {
combined_score.max(0.0)
} else if combined_score < 0.0 {
(combined_score + p::NEGATIVE_WEIGHTS_SUM) / p::WEIGHTS_SUM * p::NEGATIVE_SCORES_OFFSET
} else {
combined_score + p::NEGATIVE_SCORES_OFFSET
}
}
}
================================================
FILE: home-mixer/selectors/mod.rs
================================================
mod top_k_score_selector;
pub use top_k_score_selector::TopKScoreSelector;
================================================
FILE: home-mixer/selectors/top_k_score_selector.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::params;
use xai_candidate_pipeline::selector::Selector;
pub struct TopKScoreSelector;
impl Selector<ScoredPostsQuery, PostCandidate> for TopKScoreSelector {
fn score(&self, candidate: &PostCandidate) -> f64 {
candidate.score.unwrap_or(f64::NEG_INFINITY)
}
fn size(&self) -> Option<usize> {
Some(params::TOP_K_CANDIDATES_TO_SELECT)
}
}
================================================
FILE: home-mixer/server.rs
================================================
use crate::candidate_pipeline::candidate::CandidateHelpers;
use crate::candidate_pipeline::phoenix_candidate_pipeline::PhoenixCandidatePipeline;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use log::info;
use std::sync::Arc;
use std::time::Instant;
use tonic::{Request, Response, Status};
use xai_candidate_pipeline::candidate_pipeline::CandidatePipeline;
use xai_home_mixer_proto as pb;
use xai_home_mixer_proto::{ScoredPost, ScoredPostsResponse};
pub struct HomeMixerServer {
phx_candidate_pipeline: Arc<PhoenixCandidatePipeline>,
}
impl HomeMixerServer {
pub async fn new() -> Self {
HomeMixerServer {
phx_candidate_pipeline: Arc::new(PhoenixCandidatePipeline::prod().await),
}
}
}
#[tonic::async_trait]
impl pb::scored_posts_service_server::ScoredPostsService for HomeMixerServer {
#[xai_stats_macro::receive_stats]
async fn get_scored_posts(
&self,
request: Request<pb::ScoredPostsQuery>,
) -> Result<Response<ScoredPostsResponse>, Status> {
let proto_query = request.into_inner();
if proto_query.viewer_id == 0 {
return Err(Status::invalid_argument("viewer_id must be specified"));
}
let start = Instant::now();
let query = ScoredPostsQuery::new(
proto_query.viewer_id,
proto_query.client_app_id,
proto_query.country_code,
proto_query.language_code,
proto_query.seen_ids,
proto_query.served_ids,
proto_query.in_network_only,
proto_query.is_bottom_request,
proto_query.bloom_filter_entries,
);
info!("Scored Posts request - request_id {}", query.request_id);
let pipeline_result = self.phx_candidate_pipeline.execute(query).await;
let scored_posts: Vec<ScoredPost> = pipeline_result
.selected_candidates
.into_iter()
.map(|candidate| {
let screen_names = candidate.get_screen_names();
ScoredPost {
tweet_id: candidate.tweet_id as u64,
author_id: candidate.author_id,
retweeted_tweet_id: candidate.retweeted_tweet_id.unwrap_or(0),
retweeted_user_id: candidate.retweeted_user_id.unwrap_or(0),
in_reply_to_tweet_id: candidate.in_reply_to_tweet_id.unwrap_or(0),
score: candidate.score.unwrap_or(0.0) as f32,
in_network: candidate.in_network.unwrap_or(false),
served_type: candidate.served_type.map(|t| t as i32).unwrap_or_default(),
last_scored_timestamp_ms: candidate.last_scored_at_ms.unwrap_or(0),
prediction_request_id: candidate.prediction_request_id.unwrap_or(0),
ancestors: candidate.ancestors,
screen_names,
visibility_reason: candidate.visibility_reason.map(|r| r.into()),
}
})
.collect();
info!(
"Scored Posts response - request_id {} - {} posts ({} ms)",
pipeline_result.query.request_id,
scored_posts.len(),
start.elapsed().as_millis()
);
Ok(Response::new(ScoredPostsResponse { scored_posts }))
}
}
================================================
FILE: home-mixer/side_effects/cache_request_info_side_effect.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::strato_client::StratoClient;
use std::env;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::side_effect::{SideEffect, SideEffectInput};
use xai_strato::{StratoResult, StratoValue, decode};
pub struct CacheRequestInfoSideEffect {
pub strato_client: Arc<dyn StratoClient + Send + Sync>,
}
#[async_trait]
impl SideEffect<ScoredPostsQuery, PostCandidate> for CacheRequestInfoSideEffect {
fn enable(&self, query: Arc<ScoredPostsQuery>) -> bool {
env::var("APP_ENV").unwrap_or_default() == "prod" && !query.in_network_only
}
async fn run(
&self,
input: Arc<SideEffectInput<ScoredPostsQuery, PostCandidate>>,
) -> Result<(), String> {
let user_id: i64 = input.query.user_id;
let post_ids: Vec<i64> = input
.selected_candidates
.iter()
.map(|c| c.tweet_id)
.collect();
let client = &self.strato_client;
let res = client
.store_request_info(user_id, post_ids)
.await
.map_err(|e| e.to_string())?;
let decoded: StratoResult<StratoValue<()>> = decode(&res);
match decoded {
StratoResult::Ok(_) => Ok(()),
StratoResult::Err(_) => Err("error received from strato".to_string()),
}
}
}
================================================
FILE: home-mixer/side_effects/mod.rs
================================================
pub mod cache_request_info_side_effect;
================================================
FILE: home-mixer/sources/mod.rs
================================================
pub mod phoenix_source;
pub mod thunder_source;
================================================
FILE: home-mixer/sources/phoenix_source.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::phoenix_retrieval_client::PhoenixRetrievalClient;
use crate::params as p;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::source::Source;
use xai_home_mixer_proto as pb;
pub struct PhoenixSource {
pub phoenix_retrieval_client: Arc<dyn PhoenixRetrievalClient + Send + Sync>,
}
#[async_trait]
impl Source<ScoredPostsQuery, PostCandidate> for PhoenixSource {
fn enable(&self, query: &ScoredPostsQuery) -> bool {
!query.in_network_only
}
#[xai_stats_macro::receive_stats]
async fn get_candidates(&self, query: &ScoredPostsQuery) -> Result<Vec<PostCandidate>, String> {
let user_id = query.user_id as u64;
let sequence = query
.user_action_sequence
.as_ref()
.ok_or_else(|| "PhoenixSource: missing user_action_sequence".to_string())?;
let response = self
.phoenix_retrieval_client
.retrieve(user_id, sequence.clone(), p::PHOENIX_MAX_RESULTS)
.await
.map_err(|e| format!("PhoenixSource: {}", e))?;
let candidates: Vec<PostCandidate> = response
.top_k_candidates
.into_iter()
.flat_map(|scored_candidates| scored_candidates.candidates)
.filter_map(|scored_candidate| scored_candidate.candidate)
.map(|tweet_info| PostCandidate {
tweet_id: tweet_info.tweet_id as i64,
author_id: tweet_info.author_id,
in_reply_to_tweet_id: Some(tweet_info.in_reply_to_tweet_id),
served_type: Some(pb::ServedType::ForYouPhoenixRetrieval),
..Default::default()
})
.collect();
Ok(candidates)
}
}
================================================
FILE: home-mixer/sources/thunder_source.rs
================================================
use crate::candidate_pipeline::candidate::PostCandidate;
use crate::candidate_pipeline::query::ScoredPostsQuery;
use crate::clients::thunder_client::{ThunderClient, ThunderCluster};
use crate::params as p;
use std::sync::Arc;
use tonic::async_trait;
use xai_candidate_pipeline::source::Source;
use xai_home_mixer_proto as pb;
use xai_thunder_proto::GetInNetworkPostsRequest;
use xai_thunder_proto::in_network_posts_service_client::InNetworkPostsServiceClient;
pub struct ThunderSource {
pub thunder_client: Arc<ThunderClient>,
}
#[async_trait]
impl Source<ScoredPostsQuery, PostCandidate> for ThunderSource {
#[xai_stats_macro::receive_stats]
async fn get_candidates(&self, query: &ScoredPostsQuery) -> Result<Vec<PostCandidate>, String> {
let cluster = ThunderCluster::Amp;
let channel = self
.thunder_client
.get_random_channel(cluster)
.ok_or_else(|| "ThunderSource: no available channel".to_string())?;
let mut client = InNetworkPostsServiceClient::new(channel.clone());
let following_list = &query.user_features.followed_user_ids;
let request = GetInNetworkPostsRequest {
user_id: query.user_id as u64,
following_user_ids: following_list.iter().map(|&id| id as u64).collect(),
max_results: p::THUNDER_MAX_RESULTS,
exclude_tweet_ids: vec![],
algorithm: "default".to_string(),
debug: false,
is_video_request: false,
};
let response = client
.get_in_network_posts(request)
.await
.map_err(|e| format!("ThunderSource: {}", e))?;
let candidates: Vec<PostCandidate> = response
.into_inner()
.posts
.into_iter()
.map(|post| {
let in_reply_to_tweet_id = post
.in_reply_to_post_id
.and_then(|id| u64::try_from(id).ok());
let conversation_id = post.conversation_id.and_then(|id| u64::try_from(id).ok());
let mut ancestors = Vec::new();
if let Some(reply_to) = in_reply_to_tweet_id {
ancestors.push(reply_to);
if let Some(root) = conversation_id.filter(|&root| root != reply_to) {
ancestors.push(root);
}
}
PostCandidate {
tweet_id: post.post_id,
author_id: post.author_id as u64,
in_reply_to_tweet_id,
ancestors,
served_type: Some(pb::ServedType::ForYouInNetwork),
..Default::default()
}
})
.collect();
Ok(candidates)
}
}
================================================
FILE: phoenix/README.md
================================================
# Phoenix: Recommendation System
This repository contains JAX example code for the Phoenix recommendation system, which powers content ranking and retrieval. Phoenix uses transformer-based architectures for both **retrieval** (finding relevant candidates from millions of items) and **ranking** (ordering a smaller set of candidates by predicted engagement).
> **Note:** The sample transformer implementation in this repository is ported from the [Grok-1 open source release](https://github.com/xai-org/grok-1) by xAI. The core transformer architecture comes from Grok-1, adapted here for recommendation system use cases with custom input embeddings and attention masking for candidate isolation. This code is representative of the model used internally with the exception of specific scaling optimizations.
## Table of Contents
- [Overview](#overview)
- [Architecture](#architecture)
- [Two-Stage Recommendation Pipeline](#two-stage-recommendation-pipeline)
- [Retrieval: Two-Tower Model](#retrieval-two-tower-model)
- [Ranking: Transformer with Candidate Isolation](#ranking-transformer-with-candidate-isolation)
- [Key Design Decisions](#key-design-decisions)
- [Running the Code](#running-the-code)
- [License](#license)
---
## Overview
Phoenix is a recommendation system that predicts user engagement (likes, reposts, replies, etc.) for content. It operates in two stages:
1. **Retrieval**: Efficiently narrow millions of candidates down to thousands using approximate nearest neighbor (ANN) search
2. **Ranking**: Score and order the retrieved candidates using a more expressive transformer model
---
## Architecture
### Two-Stage Recommendation Pipeline
```
┌────────────────────────────────────────────────────────────────────────────────┐
│                            RECOMMENDATION PIPELINE                             │
├────────────────────────────────────────────────────────────────────────────────┤
│                                                                                │
│  ┌──────────┐     ┌─────────────────────┐     ┌─────────────────────┐          │
│  │          │     │                     │     │                     │          │
│  │   User   │────▶│      STAGE 1:       │────▶│      STAGE 2:       │────▶ Feed│
│  │ Request  │     │      RETRIEVAL      │     │       RANKING       │          │
│  │          │     │     (Two-Tower)     │     │    (Transformer)    │          │
│  └──────────┘     │                     │     │                     │          │
│                   │  Millions → 1000s   │     │   1000s → Ranked    │          │
│                   └─────────────────────┘     └─────────────────────┘          │
│                                                                                │
└────────────────────────────────────────────────────────────────────────────────┘
```
---
### Retrieval: Two-Tower Model
The retrieval stage uses a **two-tower architecture** that enables efficient similarity search at scale.
#### How Retrieval Works
1. **User Tower**: Encodes user features and engagement history through a transformer to produce a normalized user embedding `[B, D]`
2. **Candidate Tower**: Computes normalized embeddings for all items in the corpus `[N, D]`
3. **Similarity Search**: Retrieves top-K candidates using dot product similarity
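
The similarity-search step can be sketched in plain Python, assuming both towers have already produced their embeddings. Because both sides are L2-normalized, the dot product equals cosine similarity. `l2_normalize` and `retrieve_top_k` are hypothetical helpers. The exact brute-force scan below is a stand-in for the approximate nearest neighbor (ANN) index a production system would use.

```python
import heapq
import math
from typing import List, Sequence


def l2_normalize(v: Sequence[float]) -> List[float]:
    """Scale v to unit length (an all-zero vector is returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


def retrieve_top_k(
    user_emb: Sequence[float],
    candidate_embs: Sequence[Sequence[float]],
    k: int,
) -> List[int]:
    """Return indices of the k candidates with the highest dot-product score.

    Brute-force stand-in for an ANN index: every candidate is scored.
    """
    user = l2_normalize(user_emb)
    scores = (
        (sum(u * c for u, c in zip(user, l2_normalize(cand))), idx)
        for idx, cand in enumerate(candidate_embs)
    )
    # nlargest keeps only k items, so memory stays O(k) while scanning the corpus.
    return [idx for _, idx in heapq.nlargest(k, scores)]


if __name__ == "__main__":
    user = [1.0, 0.0]
    corpus = [[0.0, 1.0], [2.0, 0.1], [1.0, 1.0], [-1.0, 0.0]]
    print(retrieve_top_k(user, corpus, k=2))  # [1, 2]
```

Candidate embeddings do not depend on the user, so in practice they can be computed once and indexed. Only the user tower then needs to run per request.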
---
### Ranking: Transformer with Candidate Isolation
The ranking model uses a transformer architecture where **candidates cannot attend to each other** during inference. This is a critical design choice: it ensures that the score for a candidate does not depend on which other candidates are in the batch.
#### Ranking Model Architecture
```
                            PHOENIX RANKING MODEL
┌────────────────────────────────────────────────────────────────────────────┐
│                                                                            │
│                               OUTPUT LOGITS                                │
│                      [B, num_candidates, num_actions]                      │
│                                     │                                      │
│                                     │ Unembedding                          │
│                                     │ Projection                           │
│                                     │                                      │
│                     ┌───────────────┴───────────────┐                      │
│                     │                               │                      │
│                     │   Extract Candidate Outputs   │                      │
│                     │   (positions after history)   │                      │
│                     │                               │                      │
│                     └───────────────┬───────────────┘                      │
│                                     │                                      │
│                     ┌───────────────┴───────────────┐                      │
│                     │                               │                      │
│                     │          Transformer          │                      │
│                     │    (with special masking)     │                      │
│                     │                               │                      │
│                     │   Candidates CANNOT attend    │                      │
│                     │         to each other         │                      │
│                     │                               │                      │
│                     └───────────────┬───────────────┘                      │
│                                     │                                      │
│        ┌────────────────────────────┼────────────────────────────┐         │
│        │                            │                            │         │
│        ▼                            ▼                            ▼         │
│   ┌──────────┐             ┌─────────────────┐             ┌────────────┐  │
│   │   User   │             │     History     │             │ Candidates │  │
│   │Embedding │             │   Embeddings    │             │ Embeddings │  │
│   │  [B, 1]  │             │    [B, S, D]    │             │ [B, C, D]  │  │
│   │          │             │                 │             │            │  │
│   │   User   │             │ Posts + Authors │             │  Posts +   │  │
│   │  Hashes  │             │   + Actions +   │             │ Authors +  │  │
│   │          │             │ Product Surface │             │  Product   │  │
│   └──────────┘             └─────────────────┘             │  Surface   │  │
│                                                            └────────────┘  │
│                                                                            │
└────────────────────────────────────────────────────────────────────────────┘
```
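
The "Extract Candidate Outputs" and "Unembedding Projection" steps at the top of the diagram can be sketched as follows. The sketch assumes that the transformer returns hidden states shaped `[B, T, D]`, that candidates occupy every position from `candidate_start` onward, and that the unembedding is a plain `[D, num_actions]` weight matrix with no bias. `candidate_logits` is a hypothetical name. Nested lists stand in for JAX arrays so that the shapes are easy to follow.

```python
from typing import List

Matrix = List[List[float]]  # [rows][cols]


def candidate_logits(
    hidden: List[Matrix],  # [B][T][D] transformer outputs
    candidate_start: int,  # index of the first candidate position
    unembed: Matrix,       # [D][num_actions] projection weights
) -> List[Matrix]:         # [B][C][num_actions]
    num_actions = len(unembed[0])
    logits = []
    for seq in hidden:
        # Keep only the positions after user + history: one row per candidate.
        cand_states = seq[candidate_start:]
        # Project each candidate's D-dim state onto num_actions logits.
        logits.append([
            [sum(h[d] * unembed[d][a] for d in range(len(h))) for a in range(num_actions)]
            for h in cand_states
        ])
    return logits


if __name__ == "__main__":
    # B=1, T=4 (1 user + 1 history + 2 candidates), D=2, num_actions=3
    hidden = [[[9.0, 9.0], [9.0, 9.0], [1.0, 0.0], [0.0, 2.0]]]
    unembed = [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5]]
    print(candidate_logits(hidden, candidate_start=2, unembed=unembed))
    # [[[1.0, 0.0, 0.5], [0.0, 2.0, 1.0]]]
```

The user and history positions (the `9.0` rows) contribute to the output only through attention inside the transformer. Once the hidden states exist, they are simply dropped.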
#### Attention Mask: Candidate Isolation
A key detail is the **attention mask** that prevents candidates from attending to each other while still allowing them to attend to the user and history:
```
                     ATTENTION MASK VISUALIZATION
                         Keys (what we attend TO)
          ────────────────────────────────────────────────────────────▶
         │ User │ History (S positions) │ Candidates (C positions)    │
    ┌────┼──────┼───────────────────────┼─────────────────────────────┤
  │ │    │      │                       │                             │
  │ │ U  │  ✓   │     ✓ ✓ ✓ ✓ ✓ ✓ ✓     │        ✗ ✗ ✗ ✗ ✗ ✗ ✗        │
  │ │    │      │                       │                             │
  │ ├────┼──────┼───────────────────────┼─────────────────────────────┤
  Q │    │      │                       │                             │
  u │ H  │  ✓   │     ✓ ✓ ✓ ✓ ✓ ✓ ✓     │        ✗ ✗ ✗ ✗ ✗ ✗ ✗        │
  e │ i  │  ✓   │     ✓ ✓ ✓ ✓ ✓ ✓ ✓     │        ✗ ✗ ✗ ✗ ✗ ✗ ✗        │
  r │ s  │  ✓   │     ✓ ✓ ✓ ✓ ✓ ✓ ✓     │        ✗ ✗ ✗ ✗ ✗ ✗ ✗        │
  i │ t  │  ✓   │     ✓ ✓ ✓ ✓ ✓ ✓ ✓     │        ✗ ✗ ✗ ✗ ✗ ✗ ✗        │
  e │    │      │                       │                             │
  s ├────┼──────┼───────────────────────┼─────────────────────────────┤
  │ │    │      │                       │ DIAGONAL ONLY (self-attend) │
  │ │ C  │  ✓   │     ✓ ✓ ✓ ✓ ✓ ✓ ✓     │        ✓ ✗ ✗ ✗ ✗ ✗ ✗        │
  │ │ a  │  ✓   │     ✓ ✓ ✓ ✓ ✓ ✓ ✓     │        ✗ ✓ ✗ ✗ ✗ ✗ ✗        │
  │ │ n  │  ✓   │     ✓ ✓ ✓ ✓ ✓ ✓ ✓     │        ✗ ✗ ✓ ✗ ✗ ✗ ✗        │
  │ │ d  │  ✓   │     ✓ ✓ ✓ ✓ ✓ ✓ ✓     │        ✗ ✗ ✗ ✓ ✗ ✗ ✗        │
  │ │ i  │  ✓   │     ✓ ✓ ✓ ✓ ✓ ✓ ✓     │        ✗ ✗ ✗ ✗ ✓ ✗ ✗        │
  │ │ d  │  ✓   │     ✓ ✓ ✓ ✓ ✓ ✓ ✓     │        ✗ ✗ ✗ ✗ ✗ ✓ ✗        │
  ▼ │ s  │  ✓   │     ✓ ✓ ✓ ✓ ✓ ✓ ✓     │        ✗ ✗ ✗ ✗ ✗ ✗ ✓        │
    │    │      │                       │                             │
    └────┴──────┴───────────────────────┴─────────────────────────────┘
✓ = Can attend (1)    ✗ = Cannot attend (0)
Legend:
├─ User + History: Full bidirectional attention among themselves
├─ Candidates → User/History: Candidates CAN attend to user and history
└─ Candidates → Candidates: Candidates CANNOT attend to each other (only self)
```
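
The mask can be reproduced in plain Python. This sketch mirrors `make_recsys_attn_mask` in `grok.py`, which starts from a causal (lower-triangular) mask. As a result, inside the user + history block each position attends only to itself and to earlier positions, whereas the diagram above draws that block as fully bidirectional. Treat the causal choice as what this repository's code does, not as a statement about the production model. `recsys_attn_mask` and the toy attention `attend` are hypothetical helpers written for illustration: `attend` uses a single head, has no learned projections, and sets queries = keys = values.

```python
import math
from typing import List

Vector = List[float]


def recsys_attn_mask(seq_len: int, candidate_start_offset: int) -> List[List[int]]:
    """mask[q][k] == 1 means query position q may attend to key position k."""
    mask = []
    for q in range(seq_len):
        row = []
        for k in range(seq_len):
            causal = k <= q
            in_candidate_block = q >= candidate_start_offset and k >= candidate_start_offset
            # Inside the candidate block only the diagonal (self-attention) survives.
            row.append(int(causal and (not in_candidate_block or k == q)))
        mask.append(row)
    return mask


def attend(xs: List[Vector], mask: List[List[int]]) -> List[Vector]:
    """Toy single-head self-attention with queries = keys = values = xs."""
    out = []
    for q, xq in enumerate(xs):
        allowed = [k for k in range(len(xs)) if mask[q][k]]
        logits = [sum(a * b for a, b in zip(xq, xs[k])) for k in allowed]
        top = max(logits)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in logits]
        total = sum(weights)
        out.append([
            sum(w * xs[k][d] for w, k in zip(weights, allowed)) / total
            for d in range(len(xq))
        ])
    return out


if __name__ == "__main__":
    # Positions 0-2: user + history; positions 3-4: two candidates.
    mask = recsys_attn_mask(seq_len=5, candidate_start_offset=3)
    for row in mask:
        print(row)
    seq = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [0.3, 0.7], [0.9, 0.1]]
    swapped = seq[:4] + [[-2.0, 5.0]]  # change only the last candidate
    # The first candidate's output is identical in both batches.
    print(attend(seq, mask)[3] == attend(swapped, mask)[3])  # True
```

Replacing the last candidate changes that candidate's own output but leaves position 3 untouched. This is the batch-independence property that candidate isolation is meant to guarantee.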
---
## Key Design Decisions
### 1. Hash-Based Embeddings
Both the retrieval and ranking models use multiple hash functions for embedding lookup.
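
Multi-hash embedding lookup can be sketched as follows. The sketch assumes that each ID is hashed by several independent hash functions into one shared table and that the selected rows are summed. The seed-prefixed BLAKE2b hash, the table size, and summation as the combiner are all illustrative assumptions, not the Phoenix implementation. `hash_buckets` and `MultiHashEmbedding` are hypothetical names. The idea is that with several hashes, two IDs are unlikely to collide on every row.

```python
import hashlib
import random
from typing import List


def hash_buckets(item_id: int, num_hashes: int, num_buckets: int) -> List[int]:
    """Map item_id to num_hashes bucket indices using seed-prefixed BLAKE2b hashes."""
    buckets = []
    for seed in range(num_hashes):
        digest = hashlib.blake2b(f"{seed}:{item_id}".encode(), digest_size=8).digest()
        buckets.append(int.from_bytes(digest, "little") % num_buckets)
    return buckets


class MultiHashEmbedding:
    """Shared embedding table read through several hash functions (hypothetical)."""

    def __init__(self, num_buckets: int, dim: int, num_hashes: int, seed: int = 0):
        rng = random.Random(seed)
        self.num_buckets = num_buckets
        self.num_hashes = num_hashes
        # Small random init; a trained model would learn these rows.
        self.table = [[rng.gauss(0.0, 0.02) for _ in range(dim)] for _ in range(num_buckets)]

    def lookup(self, item_id: int) -> List[float]:
        rows = [self.table[b] for b in hash_buckets(item_id, self.num_hashes, self.num_buckets)]
        # Sum the selected rows element-wise into one dim-sized vector.
        return [sum(col) for col in zip(*rows)]


if __name__ == "__main__":
    emb = MultiHashEmbedding(num_buckets=1000, dim=4, num_hashes=2)
    vec = emb.lookup(123456789)
    print(len(vec))  # 4
    print(hash_buckets(123456789, 2, 1000) == hash_buckets(123456789, 2, 1000))  # True
```

A stable hash such as BLAKE2b is used instead of Python's built-in `hash()`, because `hash()` of strings is randomized per process. A lookup table must map an ID to the same rows across runs.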
### 2. Shared Architecture
The retrieval user tower uses the same transformer architecture as the ranking model.
### 3. Multi-Action Prediction
The ranking model predicts multiple engagement types simultaneously:
```
Output: [B, num_candidates, num_actions]
│
▼
┌─────────────────────────────────────┐
│ Like │ Repost │ Reply │ Click │ ... │
└─────────────────────────────────────┘
```
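
Downstream, each candidate's per-action outputs are combined into a single ranking score. The sketch below mirrors what `PhoenixScorer` and `WeightedScorer` in home-mixer appear to do: it exponentiates per-action log-probabilities into probabilities, then takes a weighted sum in which negative actions carry negative weights. The action names and weight values are hypothetical placeholders. The offset and normalization that `WeightedScorer` applies afterwards are omitted.

```python
import math
from typing import Dict, List

# Hypothetical action order and weights; the real values live in home-mixer params.
ACTIONS = ["like", "repost", "reply", "click", "block_author"]
WEIGHTS = {"like": 1.0, "repost": 1.0, "reply": 2.0, "click": 0.5, "block_author": -10.0}


def action_probs(log_probs: List[float]) -> Dict[str, float]:
    """Map one candidate's per-action log-probabilities to probabilities."""
    return {name: math.exp(lp) for name, lp in zip(ACTIONS, log_probs)}


def weighted_score(probs: Dict[str, float]) -> float:
    """Combine action probabilities; actions missing from probs contribute 0."""
    return sum(w * probs.get(name, 0.0) for name, w in WEIGHTS.items())


if __name__ == "__main__":
    # One candidate's outputs: one log-probability per action.
    log_probs = [math.log(0.2), math.log(0.05), math.log(0.01), math.log(0.5), math.log(0.001)]
    probs = action_probs(log_probs)
    print(round(weighted_score(probs), 4))  # 0.51
```

Predicting every action from a single forward pass lets the weights be retuned without retraining the model.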
---
## Running the Code
### Installation
Install [uv](https://docs.astral.sh/uv/getting-started/installation/)
### Running the Ranker
```shell
uv run run_ranker.py
```
### Running Retrieval
```shell
uv run run_retrieval.py
```
### Running Tests
```shell
uv run pytest test_recsys_model.py test_recsys_retrieval_model.py
```
================================================
FILE: phoenix/grok.py
================================================
# Copyright 2026 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass
from typing import NamedTuple, Optional, Sequence, Union
import haiku as hk
import jax
import jax.numpy as jnp
logger = logging.getLogger(__name__)
class TrainingState(NamedTuple):
"""Container for the training state."""
params: hk.Params
def ffn_size(emb_size, widening_factor):
_ffn_size = int(widening_factor * emb_size) * 2 // 3
    _ffn_size = _ffn_size + (8 - _ffn_size) % 8  # round up to the next multiple of 8
    logger.debug(f"emb_size: {emb_size} adjusted ffn_size: {_ffn_size}")
return _ffn_size
def make_recsys_attn_mask(
seq_len: int,
candidate_start_offset: int,
dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
"""Create attention mask for recommendation system inference.
Creates a mask where:
- Positions 0 to candidate_start_offset-1 (user+history): causal attention
- Positions candidate_start_offset onwards (candidates): can attend to user+history
and themselves (self-attention), but NOT to other candidates
This ensures each candidate is scored independently based on user+history context.
Args:
seq_len: Total sequence length (user + history + candidates)
candidate_start_offset: Position where candidates start in the sequence
dtype: Data type for the mask
Returns:
Attention mask of shape [1, 1, seq_len, seq_len] where 1 means "can attend"
"""
# Start with causal mask for the full sequence
causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len), dtype=dtype))
# Zero out candidate-to-candidate attention (bottom-right block)
attn_mask = causal_mask.at[:, :, candidate_start_offset:, candidate_start_offset:].set(0)
# Add back self-attention for candidates (diagonal of the candidate block)
candidate_indices = jnp.arange(candidate_start_offset, seq_len)
attn_mask = attn_mask.at[:, :, candidate_indices, candidate_indices].set(1)
return attn_mask
class MHAOutput(NamedTuple):
"""Outputs of the multi-head attention operation."""
embeddings: jax.Array
class DecoderOutput(NamedTuple):
embeddings: jax.Array
class TransformerOutput(NamedTuple):
embeddings: jax.Array
@dataclass
class TransformerConfig:
emb_size: int
key_size: int
num_q_heads: int
num_kv_heads: int
num_layers: int
widening_factor: float = 4.0
attn_output_multiplier: float = 1.0
name: Optional[str] = None
def make(self) -> "Transformer":
return Transformer(
num_q_heads=self.num_q_heads,
num_kv_heads=self.num_kv_heads,
widening_factor=self.widening_factor,
key_size=self.key_size,
attn_output_multiplier=self.attn_output_multiplier,
num_layers=self.num_layers,
)
def hk_rms_norm(
x: jax.Array,
fixed_scale=False,
) -> jax.Array:
    """Applies RMSNorm over the last axis of x (no learnable scale if fixed_scale is True)."""
ln = RMSNorm(axis=-1, create_scale=not fixed_scale)
return ln(x)


class Linear(hk.Linear):
    def __init__(
        self,
        output_size: int,
        with_bias: bool = True,
        name: Optional[str] = None,
    ):
        super().__init__(
            output_size=output_size,
            with_bias=with_bias,
            name=name,
        )

    def __call__(  # type: ignore
        self,
        inputs: jax.Array,
    ) -> jax.Array:
        """Computes a linear transform of the input."""

        fprop_dtype = inputs.dtype
        if not inputs.shape:
            raise ValueError("Input must not be scalar.")

        input_size = inputs.shape[-1]
        output_size = self.output_size

        w = hk.get_parameter(
            "w", [input_size, output_size], jnp.float32, init=hk.initializers.Constant(0)
        )

        out = jnp.dot(inputs, w.astype(fprop_dtype))
        if self.with_bias:
            b = hk.get_parameter(
                "b", [self.output_size], jnp.float32, init=hk.initializers.Constant(0)
            )
            b = jnp.broadcast_to(b, out.shape)
            out = out + b.astype(fprop_dtype)

        return out


class RMSNorm(hk.RMSNorm):
    def __init__(
        self,
        axis: Union[int, Sequence[int], slice],
        eps: float = 1e-5,
        name: Optional[str] = None,
        create_scale: bool = True,
    ):
        super().__init__(axis, eps, create_scale=create_scale, name=name)

    def __call__(self, inputs: jax.Array):
        fprop_dtype = inputs.dtype
        param_shape = (inputs.shape[-1],)
        if self.create_scale:
            scale = hk.get_parameter(
                "scale",
                param_shape,
                dtype=jnp.float32,
                init=hk.initializers.Constant(0),
            )
            scale = jnp.broadcast_to(scale.astype(jnp.float32), inputs.shape)
        else:
            scale = 1.0
        inputs = inputs.astype(jnp.float32)
        scale = jnp.float32(scale)
        mean_squared = jnp.mean(jnp.square(inputs), axis=[-1], keepdims=True)
        mean_squared = jnp.broadcast_to(mean_squared, inputs.shape)

        normed_inputs = inputs * jax.lax.rsqrt(mean_squared + self.eps)

        outputs = scale * normed_inputs

        return outputs.astype(fprop_dtype)
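

# Illustrative sketch (hypothetical helper `_example_rms_norm`, not part of the
# original module): what RMSNorm.__call__ computes, written in NumPy over the
# last axis: y = scale * x / sqrt(mean(x**2) + eps). Unlike LayerNorm there is
# no mean subtraction and no bias. The sketch uses a scalar `scale` (default 1)
# instead of the learned per-feature parameter.
import numpy as np


def _example_rms_norm(x: np.ndarray, scale: float = 1.0, eps: float = 1e-5) -> np.ndarray:
    x = x.astype(np.float32)
    mean_squared = np.mean(np.square(x), axis=-1, keepdims=True)
    return scale * x / np.sqrt(mean_squared + eps)


# Usage: _example_rms_norm(np.array([3.0, 4.0])) has a root-mean-square of ~1.0,
# because mean(x**2) = 12.5 and each entry is divided by sqrt(12.5).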


def rotate_half(
    x: jax.Array,
) -> jax.Array:
    """Obtain the rotated counterpart of each feature: [x1, x2] -> [-x2, x1] along the last axis."""
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate((-x2, x1), axis=-1)


class RotaryEmbedding(hk.Module):
    """Applies rotary embeddings (RoPE) to the input sequence tensor,
    as described in https://arxiv.org/abs/2104.09864.

    Attributes:
        dim (int): Dimensionality of the feature vectors
        base_exponent (int): Base exponent to compute embeddings from
    """

    def __init__(
        self,
        dim: int,
        name: Optional[str] = None,
        base_exponent: int = 10000,
    ):
        super().__init__(name)
        self.dim = dim
        self.base_exponent = base_exponent
        assert self.dim % 2 == 0

    def __call__(
        self,
        x: jax.Array,
        seq_dim: int,
        offset: jax.Array,
        const_position: Optional[int] = None,
        t: Optional[jax.Array] = None,
    ) -> jax.Array:
        fprop_dtype = x.dtype
        # Compute the per-dimension frequencies
        exponents = jnp.arange(0, self.dim, 2, dtype=jnp.float32)
        inv_freq = jnp.asarray(
            1.0 / (self.base_exponent ** (exponents / self.dim)), dtype=jnp.float32
        )

        if jnp.shape(offset) == ():
            # Offset can be a scalar or one offset per batch element.
            offset = jnp.expand_dims(offset, 0)

        # Compute the per element phase (to pass into sin and cos).
        # `is not None` so that const_position=0 is honoured rather than ignored.
        if const_position is not None:
            t = const_position * jnp.ones(
                (
                    1,
                    x.shape[seq_dim],
                ),
                dtype=jnp.float32,
            )
        elif t is None:
            t = jnp.arange(x.shape[seq_dim], dtype=jnp.float32) + jnp.expand_dims(offset, -1)
        phase = jnp.einsum("bi,j->bij", t, inv_freq)
        phase = jnp.tile(phase, reps=(1, 2))[:, :, None, :]

        x = x * jnp.cos(phase) + rotate_half(x) * jnp.sin(phase)
        x = x.astype(fprop_dtype)

        return x
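

# Illustrative sketch (hypothetical helper `_example_rope`, not part of the
# original module): a NumPy version of RotaryEmbedding for a single [T, dim]
# sequence. It mirrors the code above: frequencies base**(-2i/dim), the phase
# duplicated across both halves (like jnp.tile(..., reps=(1, 2))), and
# x * cos + rotate_half(x) * sin. Each feature pair (x[i], x[i + dim/2]) is
# rotated by a position-dependent angle, so norms are preserved and the dot
# product of a rotated query and key depends only on their position difference.
import numpy as np


def _example_rope(x: np.ndarray, positions: np.ndarray, base: float = 10000.0) -> np.ndarray:
    """x: [T, dim] with even dim; positions: [T]. Returns the rotated [T, dim] array."""
    dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2, dtype=np.float64) / dim))
    phase = np.outer(positions, inv_freq)  # [T, dim/2]
    phase = np.concatenate([phase, phase], axis=-1)  # [T, dim]
    x1, x2 = np.split(x, 2, axis=-1)
    rotated_half = np.concatenate([-x2, x1], axis=-1)
    return x * np.cos(phase) + rotated_half * np.sin(phase)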


class MultiHeadAttention(hk.Module):
    def __init__(
        self,
        num_q_heads: int,
        num_kv_heads: int,
        key_size: int,
        *,
        with_bias: bool = True,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        attn_output_multiplier: float = 1.0,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.key_size = key_size
        self.value_size = value_size or key_size
        self.model_size = model_size or key_size * num_q_heads
        self.attn_output_multiplier = attn_output_multiplier
        self.with_bias = with_bias

    def __call__(
        self,
        query: jax.Array,
        key: jax.Array,
        value: jax.Array,
        mask: jax.Array,
    ) -> MHAOutput:
        # In shape hints below, we suppress the leading dims [...] for brevity.
        # Hence e.g. [A, B] should be read in every case as [..., A, B].
        projection = self._linear_projection

        # Check that the keys and values have consistent batch size and sequence length.
        assert key.shape[:2] == value.shape[:2], f"key/value shape: {key.shape}/{value.shape}"

        if mask is not None:
            assert mask.ndim == 4
            assert mask.shape[0] in {
                1,
                query.shape[0],
            }, f"mask/query shape: {mask.shape}/{query.shape}"
            assert key.shape[0] in {
                1,
                query.shape[0],
            }, f"key/query shape: {key.shape}/{query.shape}"
            assert mask.shape[1] == 1
            assert mask.shape[2] in {
                1,
                query.shape[1],
            }, f"mask/query shape: {mask.shape}/{query.shape}"
            assert mask.shape[3] in {
                1,
                key.shape[1],
            }, f"mask/key shape: {mask.shape}/{key.shape}"

        # Compute key/query/values (overload K/Q/V to denote the respective sizes).
        assert self.num_q_heads % self.num_kv_heads == 0
        query_heads = projection(query, self.key_size, self.num_q_heads, name="query")
        key_heads = projection(key, self.key_size, self.num_kv_heads, name="key")
        value_heads = projection(value, self.value_size, self.num_kv_heads, name="value")

        rotate = RotaryEmbedding(dim=self.key_size, base_exponent=int(1e4))
        key_heads = rotate(key_heads, seq_dim=1, offset=0)
        query_heads = rotate(query_heads, seq_dim=1, offset=0)

        b, t, h, d = query_heads.shape
        _, _, kv_h, _ = key_heads.shape
        assert h % kv_h == 0, f"query_heads {h} must be a multiple of kv_heads {kv_h}"

        # Group the query heads so that each key/value head serves h // kv_h query heads.
        query_heads = jnp.reshape(query_heads, (b, t, kv_h, h // kv_h, d))

        # Compute attention weights.
        # Attention softmax is always carried out in fp32.
        attn_logits = jnp.einsum("...thHd,...Thd->...hHtT", query_heads, key_heads).astype(
            jnp.float32
        )
        attn_logits *= self.attn_output_multiplier
        # Soft-cap the logits to (-30, 30).
        max_attn_val = jnp.array(30.0, dtype=attn_logits.dtype)
        attn_logits = max_attn_val * jnp.tanh(attn_logits / max_attn_val)

        if mask is not None:
            # Add a query-group axis so the mask broadcasts against [B, kv_h, h // kv_h, T', T].
            mask = mask[:, :, None, :, :]
            if mask.ndim != attn_logits.ndim:
                raise ValueError(
                    f"Mask dimensionality {mask.ndim} must match logits dimensionality "
                    f"{attn_logits.ndim} for {mask.shape}/{attn_logits.shape}."
                )
            attn_logits = jnp.where(mask, attn_logits, -1e30)
        attn_weights = jax.nn.softmax(attn_logits).astype(query.dtype)  # [H, T', T]

        # Weight the values by the attention and flatten the head vectors.
        attn = jnp.einsum("...hHtT,...Thd->...thHd", attn_weights, value_heads)
        leading_dims = attn.shape[:2]
        attn = jnp.reshape(attn, (*leading_dims, -1))  # [T', H*V]

        # Apply another projection to get the final embeddings.
        final_projection = Linear(self.model_size, with_bias=False)
        return MHAOutput(final_projection(attn))

    @hk.transparent
    def _linear_projection(
        self,
        x: jax.Array,
        head_size: int,
        num_heads: int,
        name: Optional[str] = None,
    ) -> jax.Array:
        y = Linear(num_heads * head_size, with_bias=False, name=name)(x)
        *leading_dims, _ = x.shape
        return y.reshape((*leading_dims, num_heads, head_size))
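

# Illustrative sketch (hypothetical helper `_example_gqa_softcapped_attention`,
# not part of the original module): the grouped-query attention and logit
# soft-capping steps of MultiHeadAttention, in NumPy for one batch element.
# It omits the Q/K/V/output projections, RoPE and attn_output_multiplier.
# Query heads are reshaped to [T, kv_h, h // kv_h, d] so each key/value head
# is shared by a group of query heads; logits are squashed with
# cap * tanh(logits / cap) before the masked softmax.
import numpy as np


def _example_gqa_softcapped_attention(q, k, v, mask=None, cap=30.0):
    """q: [T, h, d]; k, v: [S, kv_h, d]; mask: [T, S] of 0/1 or None. Returns [T, h * d]."""
    t, h, d = q.shape
    _, kv_h, _ = k.shape
    assert h % kv_h == 0
    qg = q.reshape(t, kv_h, h // kv_h, d)
    logits = np.einsum("thHd,Thd->hHtT", qg, k)
    logits = cap * np.tanh(logits / cap)
    if mask is not None:
        logits = np.where(mask[None, None].astype(bool), logits, -1e30)
    # Numerically stable softmax over the key axis.
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    out = np.einsum("hHtT,Thd->thHd", weights, v)
    return out.reshape(t, -1)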


@dataclass
class MHABlock(hk.Module):
    """A multi-head self-attention block."""

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    attn_output_multiplier: float = 1.0

    @hk.transparent
    def __call__(
        self,
        inputs: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, 1, T, T] or [B, 1, 1, T] or [B, 1, 1, 1]
    ) -> MHAOutput:
        _, _, model_size = inputs.shape
        assert mask.ndim == 4, f"shape: {mask.shape}"
        assert mask.shape[2] in {1, inputs.shape[1]}, str(mask.shape)
        assert mask.shape[3] in {1, inputs.shape[1]}, str(mask.shape)
        side_input = inputs

        def attn_block(query, key, value, mask) -> MHAOutput:
            return MultiHeadAttention(
                num_q_heads=self.num_q_heads,
                num_kv_heads=self.num_kv_heads,
                key_size=self.key_size,
                model_size=model_size,
                attn_output_multiplier=self.attn_output_multiplier,
            )(query, key, value, mask)

        attn_output = attn_block(inputs, side_input, side_input, mask)
        h_attn = attn_output.embeddings

        return MHAOutput(embeddings=h_attn)


@dataclass
class DenseBlock(hk.Module):
    """GELU-gated feed-forward block: Linear(gelu(W1 x) * (V x))."""

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    widening_factor: float = 4.0

    @hk.transparent
    def __call__(
        self,
        inputs: jax.Array,  # [B, T, D]
    ) -> jax.Array:  # [B, T, D]
        _, _, model_size = inputs.shape
        h_v = Linear(
            ffn_size(model_size, self.widening_factor),
            with_bias=False,
            name="linear_v",
        )(inputs)
        h_w1 = jax.nn.gelu(
            Linear(
                ffn_size(model_size, self.widening_factor),
                with_bias=False,
            )(inputs)
        )
        h_dense = Linear(model_size, with_bias=False)(h_w1 * h_v)

        return h_dense
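

# Illustrative sketch (hypothetical helpers `_example_gelu_tanh` and
# `_example_gelu_gated_ffn`, not part of the original module): the DenseBlock
# computation in NumPy, out = (gelu(x @ W1) * (x @ V)) @ W_out. The hidden
# width F stands in for ffn_size(model_size, widening_factor), which is defined
# elsewhere in this file. The tanh GELU approximation matches the default of
# jax.nn.gelu (approximate=True).
import numpy as np


def _example_gelu_tanh(x):
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def _example_gelu_gated_ffn(x, w1, wv, w_out):
    """x: [T, D]; w1, wv: [D, F]; w_out: [F, D]. Returns [T, D]."""
    return (_example_gelu_tanh(x @ w1) * (x @ wv)) @ w_out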


@dataclass
class DecoderLayer(hk.Module):
    """A single transformer decoder layer: self-attention followed by a dense block."""

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    num_layers: int
    layer_index: Optional[int] = None
    widening_factor: float = 4.0
    name: Optional[str] = None
    attn_output_multiplier: float = 1.0

    def __call__(
        self,
        inputs: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, 1, T, T] or [B, 1, 1, T]
        padding_mask: Optional[jax.Array],
    ) -> DecoderOutput:
        """Transforms input embedding sequences to output embedding sequences."""
        del padding_mask  # Unused.

        def layer_norm(x):
            return hk_rms_norm(x)

        h = inputs

        attn_output = MHABlock(
            num_q_heads=self.num_q_heads,
            num_kv_heads=self.num_kv_heads,
            key_size=self.key_size,
            attn_output_multiplier=self.attn_output_multiplier,
        )(layer_norm(h), mask)
        h_attn = attn_output.embeddings

        h_attn = layer_norm(h_attn)
        h += h_attn

        def base_dense_block(h):
            h = DenseBlock(
                num_q_heads=self.num_q_heads,
                num_kv_heads=self.num_kv_heads,
                key_size=self.key_size,
                widening_factor=self.widening_factor,
            )(h)
            return h

        h_dense = base_dense_block(layer_norm(h))

        h_dense = layer_norm(h_dense)
        h += h_dense

        return DecoderOutput(
            embeddings=h,
        )


def layer_norm(x):
    return hk_rms_norm(x)


@dataclass
class Transformer(hk.Module):
    """A transformer stack."""

    num_q_heads: int
    num_kv_heads: int
    key_size: int
    widening_factor: float
    attn_output_multiplier: float
    num_layers: int
    name: Optional[str] = None

    def __call__(
        self,
        embeddings: jax.Array,  # [B, T, D]
        mask: jax.Array,  # [B, T]
        candidate_start_offset: Optional[int] = None,
    ) -> TransformerOutput:
        """Transforms input embedding sequences to output embedding sequences.

        Args:
            embeddings: Input embeddings of shape [B, T, D]
            mask: Padding mask of shape [B, T], True for valid positions
            candidate_start_offset: If provided, positions >= this offset are treated as
                candidates that can only attend to positions before the offset (user+history)
                and themselves (self-attention), but not to other candidates.
                Used for recommendation system inference.

        Returns:
            TransformerOutput containing the output embeddings.
        """

        fprop_dtype = embeddings.dtype
        _, seq_len, _ = embeddings.shape
        padding_mask = mask.copy()
        mask = mask[:, None, None, :]  # [B, H=1, T'=1, T]

        if candidate_start_offset is not None:
            # Use recommendation system attention mask where candidates attend to
            # user+history and themselves, but not to other candidates
            attn_mask = make_recsys_attn_mask(seq_len, candidate_start_offset, fprop_dtype)
            mask = mask * attn_mask
        else:
            # Standard causal mask for autoregressive sequence modelling
            causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len))).astype(
                fprop_dtype
            )  # [B=1, H=1, T, T]
            mask = mask * causal_mask  # [B, H=1, T, T]

        h = embeddings

        def block(
            h,
            mask,
            padding_mask,
            layer_index: Optional[int] = None,
            widening_factor: Optional[float] = None,
            name: Optional[str] = None,
        ) -> DecoderOutput:
            return DecoderLayer(
                num_q_heads=self.num_q_heads,
                num_kv_heads=self.num_kv_heads,
                key_size=self.key_size,
                widening_factor=widening_factor or self.widening_factor,
                num_layers=self.num_layers,
                attn_output_multiplier=self.attn_output_multiplier,
                name=name,
                layer_index=layer_index,
            )(h, mask, padding_mask)

        for i in range(self.num_layers):
            decoder_output = block(
                h,
                mask,
                padding_mask,
                layer_index=i,
                name=f"decoder_layer_{i}",
            )
            h = decoder_output.embeddings

        return TransformerOutput(
            embeddings=h,
        )
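

# Illustrative sketch (hypothetical helper `_example_combined_mask`, not part of
# the original module): how the [B, T] padding mask is combined with the causal
# mask in Transformer.__call__ (the branch without candidate_start_offset).
# The padding mask is broadcast over the key axis only, so padded positions are
# never attended to, while padded query rows still produce (ignored) outputs.
import numpy as np


def _example_combined_mask(padding_mask: np.ndarray) -> np.ndarray:
    """padding_mask: [B, T] bool. Returns a [B, 1, T, T] 0/1 float mask."""
    _, t = padding_mask.shape
    key_mask = padding_mask[:, None, None, :].astype(np.float32)  # [B, 1, 1, T]
    causal = np.tril(np.ones((1, 1, t, t), dtype=np.float32))  # [1, 1, T, T]
    return key_mask * causal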
================================================
FILE: phoenix/pyproject.toml
================================================
[project]
name = "grok-1"
version = "0.1.0"
description = "Grok-1 model"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"dm-haiku>=0.0.13",
"jax==0.8.1",
"numpy>=1.26.4",
"pyright>=1.1.408",
]
[tool.uv]
environments = [
"sys_platform == 'darwin'",
"sys_platform == 'linux'",
]
[tool.ruff]
indent-width = 4
line-length = 100
[tool.ruff.lint]
ignore = [
"E722",
"E731",
"E741",
"F405",
"E402",
"F403",
]
select = ["ISC001"]
[dependency-groups]
dev = [
"pytest",
]
================================================
FILE: phoenix/recsys_model.py
================================================
# Copyright 2026 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass
from typing import Any, NamedTuple, Optional, Tuple
import haiku as hk
import jax
import jax.numpy as jnp
from grok import (
TransformerConfig,
Transformer,
layer_norm,
)
logger = logging.getLogger(__name__)


@dataclass
class HashConfig:
    """Configuration for hash-based embeddings."""

    num_user_hashes: int = 2
    num_item_hashes: int = 2
    num_author_hashes: int = 2


@dataclass
class RecsysEmbeddings:
    """Container for pre-looked-up embeddings from the embedding tables.

    These embeddings are looked up from hash tables before being passed to the model.
    The block_*_reduce functions combine multiple hash embeddings into single representations.
    """

    user_embeddings: jax.typing.ArrayLike
    history_post_embeddings: jax.typing.ArrayLike
    candidate_post_embeddings: jax.typing.ArrayLike
    history_author_embeddings: jax.typing.ArrayLike
    candidate_author_embeddings: jax.typing.ArrayLike


class RecsysModelOutput(NamedTuple):
    """Output of the recommendation model."""

    logits: jax.Array


class RecsysBatch(NamedTuple):
    """Input batch for the recommendation model.

    Contains the feature data (hashes, actions, product surfaces) but NOT the embeddings.
    Embeddings are passed separately via RecsysEmbeddings.
    """

    user_hashes: jax.typing.ArrayLike
    history_post_hashes: jax.typing.ArrayLike
    history_author_hashes: jax.typing.ArrayLike
    history_actions: jax.typing.ArrayLike
    history_product_surface: jax.typing.ArrayLike
    candidate_post_hashes: jax.typing.ArrayLike
    candidate_author_hashes: jax.typing.ArrayLike
    candidate_product_surface: jax.typing.ArrayLike


def block_user_reduce(
    user_hashes: jnp.ndarray,
    user_embeddings: jnp.ndarray,
    num_user_hashes: int,
    emb_size: int,
    embed_init_scale: float = 1.0,
) -> Tuple[jax.Array, jax.Array]:
    """Combine multiple user hash embeddings into a single user representation.

    Args:
        user_hashes: [B, num_user_hashes] - hash values (0 = invalid/padding)
        user_embeddings: [B, num_user_hashes, D] - looked-up embeddings
        num_user_hashes: number of hash functions used
        emb_size: embedding dimension D
        embed_init_scale: initialization scale for projection

    Returns:
        user_embedding: [B, 1, D] - combined user embedding
        user_padding_mask: [B, 1] - True where user is valid
    """
    B = user_embeddings.shape[0]
    D = emb_size

    user_embedding = user_embeddings.reshape((B, 1, num_user_hashes * D))

    embed_init = hk.initializers.VarianceScaling(embed_init_scale, mode="fan_out")
    proj_mat_1 = hk.get_parameter(
        "proj_mat_1",
        [num_user_hashes * D, D],
        dtype=jnp.float32,
        init=lambda shape, dtype: embed_init(list(reversed(shape)), dtype).T,
    )

    user_embedding = jnp.dot(user_embedding.astype(proj_mat_1.dtype), proj_mat_1).astype(
        user_embeddings.dtype
    )

    # Hash 0 is reserved for padding, so a user is valid iff its first hash is non-zero.
    user_padding_mask = (user_hashes[:, 0] != 0).reshape(B, 1).astype(jnp.bool_)

    return user_embedding, user_padding_mask
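

# Illustrative sketch (hypothetical helper `_example_user_reduce`, not part of
# the original module): the block_user_reduce computation in NumPy. The N
# per-hash embeddings of size D are concatenated into one N * D vector and
# projected back to D with a matrix `proj` (a stand-in for proj_mat_1); the
# user is marked valid when its first hash is non-zero (hash 0 = padding).
import numpy as np


def _example_user_reduce(user_hashes, user_embeddings, proj):
    """user_hashes: [B, N]; user_embeddings: [B, N, D]; proj: [N * D, D]."""
    b, n, d = user_embeddings.shape
    flat = user_embeddings.reshape(b, 1, n * d)  # concatenate the N hash embeddings
    user_embedding = flat @ proj  # [B, 1, D]
    user_padding_mask = (user_hashes[:, 0] != 0).reshape(b, 1)
    return user_embedding, user_padding_mask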


def block_history_reduce(
    history_post_hashes: jnp.ndarray,
    history_post_embeddings: jnp.ndarray,
    history_author_embeddings: jnp.ndarray,
    history_product_surface_embeddings: jnp.ndarray,
    history_actions_embeddings: jnp.ndarray,
    num_item_hashes: int,
    num_author_hashes: int,
    embed_init_scale: float = 1.0,
) -> Tuple[jax.Array, jax.Array]:
    """Combine history embeddings (post, author, actions, product_surface) into sequence.

    Args:
        history_post_hashes: [B, S, num_item_hashes]
        history_post_embeddings: [B, S, num_item_hashes, D]
        history_author_embeddings: [B, S, num_author_hashes, D]
        history_product_surface_embeddings: [B, S, D]
        history_actions_embeddings: [B, S, D]
        num_item_hashes: number of hash functions for items
        num_author_hashes: number of hash functions for authors
        embed_init_scale: initialization scale

    Returns:
        history_embeddings: [B, S, D]
        history_padding_mask: [B, S]
    """
    B, S, _, D = history_post_embeddings.shape

    history_post_embeddings_reshaped = history_post_embeddings.reshape((B, S, num_item_hashes * D))

    history_author_embeddings_reshaped = history_author_embeddings.reshape(
        (B, S, num_author_hashes * D)
    )

    post_author_embedding = jnp.concatenate(
        [
            history_post_embeddings_reshaped,
            history_author_embeddings_reshaped,
            history_actions_embeddings,
            history_product_surface_embeddings,
        ],
        axis=-1,
    )

    embed_init = hk.initializers.VarianceScaling(embed_init_scale, mode="fan_out")
    proj_mat_3 = hk.get_parameter(
        "proj_mat_3",
        [post_author_embedding.shape[-1], D],
        dtype=jnp.float32,
        init=lambda shape, dtype: embed_init(list(reversed(shape)), dtype).T,
    )

    history_embedding = jnp.dot(post_author_embedding.astype(proj_mat_3.dtype), proj_mat_3).astype(
        post_author_embedding.dtype
    )

    history_embedding = history_embedding.reshape(B, S, D)

    history_padding_mask = (history_post_hashes[:, :, 0] != 0).reshape(B, S)

    return history_embedding, history_padding_mask


def block_candidate_reduce(
    candidate_post_hashes: jnp.ndarray,
    candidate_post_embeddings: jnp.ndarray,
    candidate_author_embeddings: jnp.ndarray,
    candidate_product_surface_embeddings: jnp.ndarray,
    num_item_hashes: int,
    num_author_hashes: int,
    embed_init_scale: float = 1.0,
) -> Tuple[jax.Array, jax.Array]:
    """Combine candidate embeddings (post, author, product_surface) into sequence.

    Args:
        candidate_post_hashes: [B, C, num_item_hashes]
        candidate_post_embeddings: [B, C, num_item_hashes, D]
        candidate_author_embeddings: [B, C, num_author_hashes, D]
        candidate_product_surface_embeddings: [B, C, D]
        num_item_hashes: number of hash functions for items
        num_author_hashes: number of hash functions for authors
        embed_init_scale: initialization scale

    Returns:
        candidate_embeddings: [B, C, D]
        candidate_padding_mask: [B, C]
    """
    B, C, _, D = candidate_post_embeddings.shape

    candidate_post_embeddings_reshaped = candidate_post_embeddings.reshape(
        (B, C, num_item_hashes * D)
    )
    candidate_author_embeddings_reshaped = candidate_author_embeddings.reshape(
        (B, C, num_author_hashes * D)
    )

    post_author_embedding = jnp.concatenate(
        [
            candidate_post_embeddings_reshaped,
            candidate_author_embeddings_reshaped,
            candidate_product_surface_embeddings,
        ],
        axis=-1,
    )

    embed_init = hk.initializers.VarianceScaling(embed_init_scale, mode="fan_out")
    proj_mat_2 = hk.get_parameter(
        "proj_mat_2",
        [post_author_embedding.shape[-1], D],
        dtype=jnp.float32,
        init=lambda shape, dtype: embed_init(list(reversed(shape)), dtype).T,
    )

    candidate_embedding = jnp.dot(
        post_author_embedding.astype(proj_mat_2.dtype), proj_mat_2
    ).astype(post_author_embedding.dtype)

    candidate_padding_mask = (candidate_post_hashes[:, :, 0] != 0).reshape(B, C).astype(jnp.bool_)

    return candidate_embedding, candidate_padding_mask
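

# Illustrative sketch (hypothetical helper `_example_projection_input_widths`,
# not part of the original module): the width of the concatenated vector that
# proj_mat_3 (history) and proj_mat_2 (candidates) project down to D.
# History concatenates post hashes, author hashes, the action embedding and the
# product-surface embedding; candidates have no action embedding.
def _example_projection_input_widths(emb_size, num_item_hashes=2, num_author_hashes=2):
    hashed = (num_item_hashes + num_author_hashes) * emb_size
    history_width = hashed + emb_size + emb_size  # + actions + product surface
    candidate_width = hashed + emb_size  # + product surface
    return history_width, candidate_width


# Usage: with the default HashConfig (2 hashes each) and D = 128,
# _example_projection_input_widths(128) gives (768, 640).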


@dataclass
class PhoenixModelConfig:
    """Configuration for the recommendation system model."""

    model: TransformerConfig
    emb_size: int
    num_actions: int
    history_seq_len: int = 128
    candidate_seq_len: int = 32
    name: Optional[str] = None
    fprop_dtype: Any = jnp.bfloat16
    hash_config: HashConfig = None  # type: ignore
    product_surface_vocab_size: int = 16

    _initialized = False

    def __post_init__(self):
        if self.hash_config is None:
            self.hash_config = HashConfig()

    def initialize(self):
        self._initialized = True
        return self

    def make(self):
        if not self._initialized:
            logger.warning(f"PhoenixModel {self.name} is not initialized. Initializing.")
            self.initialize()

        return PhoenixModel(
            model=self.model.make(),
            config=self,
            fprop_dtype=self.fprop_dtype,
        )


@dataclass
class PhoenixModel(hk.Module):
    """A transformer-based recommendation model for ranking candidates."""

    model: Transformer
    config: PhoenixModelConfig
    fprop_dtype: Any = jnp.bfloat16
    name: Optional[str] = None

    def _get_action_embeddings(
        self,
        actions: jax.Array,
    ) -> jax.Array:
        """Convert multi-hot action vectors to embeddings.

        Uses a learned projection matrix to map the signed action vector (each 0/1
        action flag becomes -1/+1) to the embedding dimension. Positions with no
        actions at all (e.g. padding) get a zero embedding. This works for any
        number of actions.
        """
        config = self.config
        _, _, num_actions = actions.shape
        D = config.emb_size

        embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
        action_projection = hk.get_parameter(
            "action_projection",
            [num_actions, D],
            dtype=jnp.float32,
            init=embed_init,
        )

        actions_signed = (2 * actions - 1).astype(jnp.float32)

        action_emb = jnp.dot(actions_signed.astype(action_projection.dtype), action_projection)

        valid_mask = jnp.any(actions, axis=-1, keepdims=True)
        action_emb = action_emb * valid_mask

        return action_emb.astype(self.fprop_dtype)

    def _single_hot_to_embeddings(
        self,
        input: jax.Array,
        vocab_size: int,
        emb_size: int,
        name: str,
    ) -> jax.Array:
        """Convert single-hot indices to embeddings via lookup table.

        Args:
            input: [B, S] tensor of categorical indices
            vocab_size: size of the vocabulary
            emb_size: embedding dimension
            name: name for the embedding table parameter

        Returns:
            embeddings: [B, S, emb_size]
        """
        embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
        embedding_table = hk.get_parameter(
            name,
            [vocab_size, emb_size],
            dtype=jnp.float32,
            init=embed_init,
        )

        input_one_hot = jax.nn.one_hot(input, vocab_size)
        output = jnp.dot(input_one_hot, embedding_table)
        return output.astype(self.fprop_dtype)

    def _get_unembedding(self) -> jax.Array:
        """Get the unembedding matrix for decoding to logits."""
        config = self.config
        embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
        unembed_mat = hk.get_parameter(
            "unembeddings",
            [config.emb_size, config.num_actions],
            dtype=jnp.float32,
            init=embed_init,
        )
        return unembed_mat

    def build_inputs(
        self,
        batch: RecsysBatch,
        recsys_embeddings: RecsysEmbeddings,
    ) -> Tuple[jax.Array, jax.Array, int]:
        """Build input embeddings from batch and pre-looked-up embeddings.

        Args:
            batch: RecsysBatch containing hashes, actions, product surfaces
            recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings

        Returns:
            embeddings: [B, 1 + history_len + num_candidates, D]
            padding_mask: [B, 1 + history_len + num_candidates]
            candidate_start_offset: int - position where candidates start
        """
        config = self.config
        hash_config = config.hash_config

        history_product_surface_embeddings = self._single_hot_to_embeddings(
            batch.history_product_surface,  # type: ignore
            config.product_surface_vocab_size,
            config.emb_size,
            "product_surface_embedding_table",
        )
        candidate_product_surface_embeddings = self._single_hot_to_embeddings(
            batch.candidate_product_surface,  # type: ignore
            config.product_surface_vocab_size,
            config.emb_size,
            "product_surface_embedding_table",
        )

        history_actions_embeddings = self._get_action_embeddings(batch.history_actions)  # type: ignore

        user_embeddings, user_padding_mask = block_user_reduce(
            batch.user_hashes,  # type: ignore
            recsys_embeddings.user_embeddings,  # type: ignore
            hash_config.num_user_hashes,
            config.emb_size,
            1.0,
        )

        history_embeddings, history_padding_mask = block_history_reduce(
            batch.history_post_hashes,  # type: ignore
            recsys_embeddings.history_post_embeddings,  # type: ignore
            recsys_embeddings.history_author_embeddings,  # type: ignore
            history_product_surface_embeddings,
            history_actions_embeddings,
            hash_config.num_item_hashes,
            hash_config.num_author_hashes,
            1.0,
        )

        candidate_embeddings, candidate_padding_mask = block_candidate_reduce(
            batch.candidate_post_hashes,  # type: ignore
            recsys_embeddings.candidate_post_embeddings,  # type: ignore
            recsys_embeddings.candidate_author_embeddings,  # type: ignore
            candidate_product_surface_embeddings,
            hash_config.num_item_hashes,
            hash_config.num_author_hashes,
            1.0,
        )

        embeddings = jnp.concatenate(
            [user_embeddings, history_embeddings, candidate_embeddings], axis=1
        )
        padding_mask = jnp.concatenate(
            [user_padding_mask, history_padding_mask, candidate_padding_mask], axis=1
        )

        candidate_start_offset = user_padding_mask.shape[1] + history_padding_mask.shape[1]

        return embeddings.astype(self.fprop_dtype), padding_mask, candidate_start_offset

    def __call__(
        self,
        batch: RecsysBatch,
        recsys_embeddings: RecsysEmbeddings,
    ) -> RecsysModelOutput:
        """Forward pass for ranking candidates.

        Args:
            batch: RecsysBatch containing hashes, actions, product surfaces
            recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings

        Returns:
            RecsysModelOutput containing logits for each candidate.
            Shape = [B, num_candidates, num_actions]
        """
        embeddings, padding_mask, candidate_start_offset = self.build_inputs(
            batch, recsys_embeddings
        )

        # transformer
        model_output = self.model(
            embeddings,
            padding_mask,
            candidate_start_offset=candidate_start_offset,
        )

        out_embeddings = model_output.embeddings

        out_embeddings = layer_norm(out_embeddings)

        candidate_embeddings = out_embeddings[:, candidate_start_offset:, :]

        unembeddings = self._get_unembedding()
        logits = jnp.dot(candidate_embeddings.astype(unembeddings.dtype), unembeddings)
        logits = logits.astype(self.fprop_dtype)

        return RecsysModelOutput(logits=logits)
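

# Illustrative sketch (hypothetical helper `_example_signed_action_embeddings`,
# not part of the original module): the _get_action_embeddings computation in
# NumPy. Multi-hot 0/1 action flags are mapped to -1/+1, projected with a
# [num_actions, D] matrix (a stand-in for the learned `action_projection`),
# and rows with no action at all are zeroed out.
import numpy as np


def _example_signed_action_embeddings(actions, projection):
    """actions: [B, S, A] multi-hot 0/1; projection: [A, D]. Returns [B, S, D]."""
    signed = (2 * actions - 1).astype(np.float32)  # 0 -> -1, 1 -> +1
    emb = signed @ projection
    valid = np.any(actions, axis=-1, keepdims=True)  # False where no action occurred
    return emb * valid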
================================================
FILE: phoenix/recsys_retrieval_model.py
================================================
# Copyright 2026 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass
from typing import Any, NamedTuple, Optional, Tuple
import haiku as hk
import jax
import jax.numpy as jnp
from grok import TransformerConfig, Transformer
from recsys_model import (
HashConfig,
RecsysBatch,
RecsysEmbeddings,
block_history_reduce,
block_user_reduce,
)
logger = logging.getLogger(__name__)
EPS = 1e-12
INF = 1e12


class RetrievalOutput(NamedTuple):
    """Output of the retrieval model."""

    user_representation: jax.Array
    top_k_indices: jax.Array
    top_k_scores: jax.Array
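

# Illustrative sketch (hypothetical helper `_example_top_k_retrieval`, not part
# of the original module): a NumPy dot-product top-k over a corpus, producing
# the kind of indices/scores that RetrievalOutput holds. With L2-normalized
# user and corpus vectors the dot product is cosine similarity. How
# PhoenixRetrievalModel.__call__ actually applies `corpus_mask` is not shown in
# this excerpt; pushing masked entries to -INF (the INF constant above) is an
# ASSUMPTION made for the sketch.
import numpy as np


def _example_top_k_retrieval(user_repr, corpus, top_k, corpus_mask=None, inf=1e12):
    """user_repr: [B, D]; corpus: [N, D]; corpus_mask: [N] bool or None."""
    scores = user_repr @ corpus.T  # [B, N]
    if corpus_mask is not None:
        scores = np.where(corpus_mask[None, :], scores, -inf)
    top_k_indices = np.argsort(-scores, axis=-1)[:, :top_k]
    top_k_scores = np.take_along_axis(scores, top_k_indices, axis=-1)
    return top_k_indices, top_k_scores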


@dataclass
class CandidateTower(hk.Module):
    """Candidate tower that projects post+author embeddings to a shared embedding space.

    This tower takes the concatenated embeddings of a post and its author,
    and projects them to a normalized representation suitable for similarity search.
    """

    emb_size: int
    name: Optional[str] = None

    def __call__(self, post_author_embedding: jax.Array) -> jax.Array:
        """Project post+author embeddings to normalized representation.

        Args:
            post_author_embedding: Concatenated post and author embeddings
                Shape: [B, C, num_hashes, D] or [B, num_hashes, D]

        Returns:
            Normalized candidate representation
                Shape: [B, C, D] or [B, D]
        """
        if len(post_author_embedding.shape) == 4:
            B, C, _, _ = post_author_embedding.shape
            post_author_embedding = jnp.reshape(post_author_embedding, (B, C, -1))
        else:
            B, _, _ = post_author_embedding.shape
            post_author_embedding = jnp.reshape(post_author_embedding, (B, -1))

        embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")

        proj_1 = hk.get_parameter(
            "candidate_tower_projection_1",
            [post_author_embedding.shape[-1], self.emb_size * 2],
            dtype=jnp.float32,
            init=embed_init,
        )

        proj_2 = hk.get_parameter(
            "candidate_tower_projection_2",
            [self.emb_size * 2, self.emb_size],
            dtype=jnp.float32,
            init=embed_init,
        )

        hidden = jnp.dot(post_author_embedding.astype(proj_1.dtype), proj_1)
        hidden = jax.nn.silu(hidden)
        candidate_embeddings = jnp.dot(hidden.astype(proj_2.dtype), proj_2)

        candidate_norm_sq = jnp.sum(candidate_embeddings**2, axis=-1, keepdims=True)
        candidate_norm = jnp.sqrt(jnp.maximum(candidate_norm_sq, EPS))
        candidate_representation = candidate_embeddings / candidate_norm

        return candidate_representation.astype(post_author_embedding.dtype)
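

# Illustrative sketch (hypothetical helper `_example_l2_normalize`, not part of
# the original module): the L2 normalization at the end of CandidateTower (and
# of build_user_representation) in NumPy. Clamping the squared norm at EPS
# before the square root avoids division by zero, so an all-zero vector stays
# all-zero instead of becoming NaN.
import numpy as np


def _example_l2_normalize(x, eps=1e-12):
    norm_sq = np.sum(x**2, axis=-1, keepdims=True)
    return x / np.sqrt(np.maximum(norm_sq, eps))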


@dataclass
class PhoenixRetrievalModelConfig:
    """Configuration for the Phoenix Retrieval Model.

    This model uses the same transformer architecture as the Phoenix ranker
    for encoding user representations.
    """

    model: TransformerConfig
    emb_size: int
    history_seq_len: int = 128
    candidate_seq_len: int = 32
    name: Optional[str] = None
    fprop_dtype: Any = jnp.bfloat16
    hash_config: HashConfig = None  # type: ignore
    product_surface_vocab_size: int = 16

    _initialized: bool = False

    def __post_init__(self):
        if self.hash_config is None:
            self.hash_config = HashConfig()

    def initialize(self):
        self._initialized = True
        return self

    def make(self):
        if not self._initialized:
            logger.warning(f"PhoenixRetrievalModel {self.name} is not initialized. Initializing.")
            self.initialize()

        return PhoenixRetrievalModel(
            model=self.model.make(),
            config=self,
            fprop_dtype=self.fprop_dtype,
        )
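

# Illustrative sketch (hypothetical helper `_example_masked_mean_pool`, not part
# of the original module): the pooling step of
# PhoenixRetrievalModel.build_user_representation in NumPy. Transformer outputs
# at padded positions are zeroed, the rest are averaged, and the divisor is
# clamped at 1 so an all-padding row yields zeros rather than NaN.
import numpy as np


def _example_masked_mean_pool(outputs, padding_mask):
    """outputs: [B, T, D]; padding_mask: [B, T] bool. Returns [B, D]."""
    mask = padding_mask.astype(np.float32)[:, :, None]  # [B, T, 1]
    summed = np.sum(outputs * mask, axis=1)  # [B, D]
    count = np.sum(mask, axis=1)  # [B, 1]
    return summed / np.maximum(count, 1.0)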


@dataclass
class PhoenixRetrievalModel(hk.Module):
    """A two-tower retrieval model using the Phoenix transformer for user encoding.

    This model implements the two-tower architecture for efficient retrieval:
    - User Tower: Encodes user features + history using the Phoenix transformer
    - Candidate Tower: Projects candidate embeddings to a shared space

    The user and candidate representations are L2-normalized, enabling efficient
    approximate nearest neighbor (ANN) search using dot product similarity.
    """

    model: Transformer
    config: PhoenixRetrievalModelConfig
    fprop_dtype: Any = jnp.bfloat16
    name: Optional[str] = None

    def _get_action_embeddings(
        self,
        actions: jax.Array,
    ) -> jax.Array:
        """Convert multi-hot action vectors to embeddings."""
        config = self.config
        _, _, num_actions = actions.shape
        D = config.emb_size

        embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
        action_projection = hk.get_parameter(
            "action_projection",
            [num_actions, D],
            dtype=jnp.float32,
            init=embed_init,
        )

        actions_signed = (2 * actions - 1).astype(jnp.float32)
        action_emb = jnp.dot(actions_signed.astype(action_projection.dtype), action_projection)

        valid_mask = jnp.any(actions, axis=-1, keepdims=True)
        action_emb = action_emb * valid_mask

        return action_emb.astype(self.fprop_dtype)

    def _single_hot_to_embeddings(
        self,
        input: jax.Array,
        vocab_size: int,
        emb_size: int,
        name: str,
    ) -> jax.Array:
        """Convert single-hot indices to embeddings via lookup table."""
        embed_init = hk.initializers.VarianceScaling(1.0, mode="fan_out")
        embedding_table = hk.get_parameter(
            name,
            [vocab_size, emb_size],
            dtype=jnp.float32,
            init=embed_init,
        )

        input_one_hot = jax.nn.one_hot(input, vocab_size)
        output = jnp.dot(input_one_hot, embedding_table)
        return output.astype(self.fprop_dtype)

    def build_user_representation(
        self,
        batch: RecsysBatch,
        recsys_embeddings: RecsysEmbeddings,
    ) -> Tuple[jax.Array, jax.Array]:
        """Build user representation from user features and history.

        Uses the Phoenix transformer to encode user + history embeddings, then
        mean-pools the outputs over valid (non-padding) positions into a single
        user representation vector.

        Args:
            batch: RecsysBatch containing hashes, actions, product surfaces
            recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings

        Returns:
            user_representation: L2-normalized user embedding [B, D]
            user_norm: Pre-normalization L2 norm [B, 1]
        """
        config = self.config
        hash_config = config.hash_config

        history_product_surface_embeddings = self._single_hot_to_embeddings(
            batch.history_product_surface,  # type: ignore
            config.product_surface_vocab_size,
            config.emb_size,
            "product_surface_embedding_table",
        )

        history_actions_embeddings = self._get_action_embeddings(batch.history_actions)  # type: ignore

        user_embeddings, user_padding_mask = block_user_reduce(
            batch.user_hashes,  # type: ignore
            recsys_embeddings.user_embeddings,  # type: ignore
            hash_config.num_user_hashes,
            config.emb_size,
            1.0,
        )

        history_embeddings, history_padding_mask = block_history_reduce(
            batch.history_post_hashes,  # type: ignore
            recsys_embeddings.history_post_embeddings,  # type: ignore
            recsys_embeddings.history_author_embeddings,  # type: ignore
            history_product_surface_embeddings,
            history_actions_embeddings,
            hash_config.num_item_hashes,
            hash_config.num_author_hashes,
            1.0,
        )

        embeddings = jnp.concatenate([user_embeddings, history_embeddings], axis=1)
        padding_mask = jnp.concatenate([user_padding_mask, history_padding_mask], axis=1)

        model_output = self.model(
            embeddings.astype(self.fprop_dtype),
            padding_mask,
            candidate_start_offset=None,
        )

        user_outputs = model_output.embeddings

        mask_float = padding_mask.astype(jnp.float32)[:, :, None]  # [B, T, 1]
        user_embeddings_masked = user_outputs * mask_float
        user_embedding_sum = jnp.sum(user_embeddings_masked, axis=1)  # [B, D]
        mask_sum = jnp.sum(mask_float, axis=1)  # [B, 1]
        user_representation = user_embedding_sum / jnp.maximum(mask_sum, 1.0)

        user_norm_sq = jnp.sum(user_representation**2, axis=-1, keepdims=True)
        user_norm = jnp.sqrt(jnp.maximum(user_norm_sq, EPS))
        user_representation = user_representation / user_norm

        return user_representation, user_norm

    def build_candidate_representation(
        self,
        batch: RecsysBatch,
        recsys_embeddings: RecsysEmbeddings,
    ) -> Tuple[jax.Array, jax.Array]:
        """Build candidate (item) representations.

        Projects post + author embeddings to a shared embedding space
        using the candidate tower MLP.

        Args:
            batch: RecsysBatch containing candidate hashes
            recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings

        Returns:
            candidate_representation: L2-normalized candidate embeddings [B, C, D]
            candidate_padding_mask: Valid candidate mask [B, C]
        """
        config = self.config

        candidate_post_embeddings = recsys_embeddings.candidate_post_embeddings
        candidate_author_embeddings = recsys_embeddings.candidate_author_embeddings
        post_author_embedding = jnp.concatenate(
            [candidate_post_embeddings, candidate_author_embeddings], axis=2
        )

        candidate_tower = CandidateTower(
            emb_size=config.emb_size,
        )
        candidate_representation = candidate_tower(post_author_embedding)

        candidate_padding_mask = (batch.candidate_post_hashes[:, :, 0] != 0).astype(jnp.bool_)  # type: ignore
        return candidate_representation, candidate_padding_mask

    def __call__(
        self,
        batch: RecsysBatch,
        recsys_embeddings: RecsysEmbeddings,
        corpus_embeddings: jax.Array,
        top_k: int,
        corpus_mask: Optional[jax.Array] = None,
    ) -> RetrievalOutput:
        """Retrieve top-k candidates from corpus for each user.

        Args:
            batch: RecsysBatch containing hashes, actions, product surfaces
            recsys_embeddings: RecsysEmbeddings containing pre-looked-up embeddings
            corpus_embeddings: [N, D] normalized corpus candidate embeddings
            top_k: Number of candidates to retrieve
            corpus_mask: [N] optional mask for valid corpus entries

        Returns:
            RetrievalOutput containing user representation and top-k results
        """
        user_representation, _ = self.build_user_representation(batch, recsys_embeddings)

        top_k_indices, top_k_scores = self._retrieve_top_k(
            user_representation, corpus_embeddings, top_k, corpus_mask
        )

        return RetrievalOutput(
            user_representation=user_representation,
            top_k_indices=top_k_indices,
            top_k_scores=top_k_scores,
        )

    def _retrieve_top_k(
        self,
        user_representation: jax.Array,
        corpus_embeddings: jax.Array,
        top_k: int,
        corpus_mask: Optional[jax.Array] = None,
    ) -> Tuple[jax.Array, jax.Array]:
        """Retrieve top-k candidates from a corpus for each user.

        Args:
            user_representation: [B, D] normalized user embeddings
            corpus_embeddings: [N, D] normalized corpus candidate embeddings
            top_k: Number of candidates to retrieve
            corpus_mask: [N] optional mask for valid corpus entries

        Returns:
            top_k_indices: [B, K] indices of top-k candidates
            top_k_scores: [B, K] similarity scores of top-k candidates
        """
        scores = jnp.matmul(user_representation, corpus_embeddings.T)

        if corpus_mask is not None:
            scores = jnp.where(corpus_mask[None, :], scores, -INF)

        top_k_scores, top_k_indices = jax.lax.top_k(scores, top_k)

        return top_k_indices, top_k_scores
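
`PhoenixRetrievalModel` relies on three array operations. The first is a one-hot matmul used as an embedding lookup. The second is masked mean-pooling over the transformer outputs, followed by L2 normalization. The third is a masked dot-product top-k over the corpus. The NumPy sketch below mirrors these operations so you can check shapes and masking behavior without Haiku or JAX. It is illustrative only and is not code from the repository. `EPS` and the large negative masking constant are assumed values; the repository defines its own `EPS` and `INF`. Tied scores may also be ordered differently from `jax.lax.top_k`.

```python
import numpy as np

EPS = 1e-12        # assumed value; the repository defines its own EPS
NEG_INF = -1e30    # assumed stand-in for the repository's -INF mask value


def one_hot_lookup(indices: np.ndarray, table: np.ndarray) -> np.ndarray:
    """one_hot(indices) @ table, as in _single_hot_to_embeddings.

    Assumes 0 <= index < vocab_size. For such indices the result equals the
    direct gather table[indices].
    """
    one_hot = np.eye(table.shape[0], dtype=table.dtype)[indices]  # [..., V]
    return one_hot @ table                                        # [..., D]


def masked_mean_l2(outputs: np.ndarray, padding_mask: np.ndarray):
    """Mean-pool [B, T, D] outputs over valid positions, then L2-normalize."""
    mask = padding_mask.astype(np.float32)[:, :, None]   # [B, T, 1]
    summed = (outputs * mask).sum(axis=1)                 # [B, D]
    count = mask.sum(axis=1)                              # [B, 1]
    pooled = summed / np.maximum(count, 1.0)              # avoid divide-by-zero
    norm = np.sqrt(np.maximum((pooled ** 2).sum(axis=-1, keepdims=True), EPS))
    return pooled / norm, norm                            # unit vectors, [B, 1] norms


def retrieve_top_k(user, corpus, k, corpus_mask=None):
    """Dot-product scores [B, N], masked entries pushed to NEG_INF, top-k per row."""
    scores = user @ corpus.T
    if corpus_mask is not None:
        scores = np.where(corpus_mask[None, :], scores, NEG_INF)
    idx = np.argsort(-scores, axis=-1, kind="stable")[:, :k]  # descending order
    return idx, np.take_along_axis(scores, idx, axis=-1)
```

Both the user representation and the corpus embeddings are L2-normalized, so each dot-product score is a cosine similarity. That is why a single matmul is enough to rank the whole corpus.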
================================================
FILE: phoenix/run_ranker.py
================================================
# Copyright 2026 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
from grok import TransformerConfig
from recsys_model import PhoenixModelConfig, HashConfig
from runners import RecsysInferenceRunner, ModelRunner, create_example_batch, ACTIONS


def main():
    # Model configuration
    emb_size = 128  # Embedding dimension
    num_actions = len(ACTIONS)  # Number of explicit engagement actions
    history_seq_len = 32  # Max history length
    candidate_seq_len = 8  # Max candidates to rank

    # Hash configuration
    hash_config = HashConfig(
        num_user_hashes=2,
        num_item_hashes=2,
        num_author_hashes=2,
    )

    recsys_model = PhoenixModelConfig(
        emb_size=emb_size,
        num_actions=num_actions,
        history_seq_len=history_seq_len,
        candidate_seq_len=candidate_seq_len,
        hash_config=hash_config,
        product_surface_vocab_size=16,
        model=TransformerConfig(
            emb_size=emb_size,
            widening_factor=2,
            key_size=64,
            num_q_heads=2,
            num_kv_heads=2,
            num_layers=2,
            attn_output_multiplier=0.125,
        ),
    )

    # Create inference runner
    inference_runner = RecsysInferenceRunner(
        runner=ModelRunner(
            model=recsys_model,
            bs_per_device=0.125,
        ),
        name="recsys_local",
    )

    print("Initializing model...")
    inference_runner.initialize()
    print("Model initialized!")

    # Create example batch with simulated posts
    print("\n" + "=" * 70)
    print("RECOMMENDATION SYSTEM DEMO")
    print("=" * 70)

    batch_size = 1
    example_batch, example_embeddings = create_example_batch(
        batch_size=batch_size,
        emb_size=emb_size,
        history_len=history_seq_len,
        num_candidates=candidate_seq_len,
        num_actions=num_actions,
        num_user_hashes=hash_config.num_user_hashes,
        num_item_hashes=hash_config.num_item_hashes,
        num_author_hashes=hash_config.num_author_hashes,
        product_surface_vocab_size=recsys_model.product_surface_vocab_size,
    )

    action_names = [action.replace("_", " ").title() for action in ACTIONS]

    # Count valid history items (where first post hash is non-zero)
    valid_history_count = int((ex
SYMBOL INDEX (387 symbols across 59 files)
FILE: candidate-pipeline/candidate_pipeline.rs
type PipelineStage (line 14) | pub enum PipelineStage {
type PipelineResult (line 24) | pub struct PipelineResult<Q, C> {
type HasRequestId (line 32) | pub trait HasRequestId {
method request_id (line 33) | fn request_id(&self) -> &str;
type CandidatePipeline (line 37) | pub trait CandidatePipeline<Q, C>: Send + Sync
method query_hydrators (line 42) | fn query_hydrators(&self) -> &[Box<dyn QueryHydrator<Q>>];
method sources (line 43) | fn sources(&self) -> &[Box<dyn Source<Q, C>>];
method hydrators (line 44) | fn hydrators(&self) -> &[Box<dyn Hydrator<Q, C>>];
method filters (line 45) | fn filters(&self) -> &[Box<dyn Filter<Q, C>>];
method scorers (line 46) | fn scorers(&self) -> &[Box<dyn Scorer<Q, C>>];
method selector (line 47) | fn selector(&self) -> &dyn Selector<Q, C>;
method post_selection_hydrators (line 48) | fn post_selection_hydrators(&self) -> &[Box<dyn Hydrator<Q, C>>];
method post_selection_filters (line 49) | fn post_selection_filters(&self) -> &[Box<dyn Filter<Q, C>>];
method side_effects (line 50) | fn side_effects(&self) -> Arc<Vec<Box<dyn SideEffect<Q, C>>>>;
method result_size (line 51) | fn result_size(&self) -> usize;
method execute (line 53) | async fn execute(&self, query: Q) -> PipelineResult<Q, C> {
method hydrate_query (line 95) | async fn hydrate_query(&self, query: Q) -> Q {
method fetch_candidates (line 126) | async fn fetch_candidates(&self, query: &Q) -> Vec<C> {
method hydrate (line 160) | async fn hydrate(&self, query: &Q, candidates: Vec<C>) -> Vec<C> {
method hydrate_post_selection (line 166) | async fn hydrate_post_selection(&self, query: &Q, candidates: Vec<C>) ...
method run_hydrators (line 177) | async fn run_hydrators(
method filter (line 220) | async fn filter(&self, query: &Q, candidates: Vec<C>) -> (Vec<C>, Vec<...
method filter_post_selection (line 226) | async fn filter_post_selection(&self, query: &Q, candidates: Vec<C>) -...
method run_filters (line 237) | async fn run_filters(
method score (line 276) | async fn score(&self, query: &Q, mut candidates: Vec<C>) -> Vec<C> {
method select (line 310) | fn select(&self, query: &Q, candidates: Vec<C>) -> Vec<C> {
method run_side_effects (line 319) | fn run_side_effects(&self, input: Arc<SideEffectInput<Q, C>>) {
FILE: candidate-pipeline/filter.rs
type FilterResult (line 6) | pub struct FilterResult<C> {
type Filter (line 13) | pub trait Filter<Q, C>: Any + Send + Sync
method enable (line 19) | fn enable(&self, _query: &Q) -> bool {
method filter (line 26) | async fn filter(&self, query: &Q, candidates: Vec<C>) -> Result<Filter...
method name (line 29) | fn name(&self) -> &'static str {
FILE: candidate-pipeline/hydrator.rs
type Hydrator (line 7) | pub trait Hydrator<Q, C>: Any + Send + Sync
method enable (line 13) | fn enable(&self, _query: &Q) -> bool {
method hydrate (line 22) | async fn hydrate(&self, query: &Q, candidates: &[C]) -> Result<Vec<C>,...
method update (line 26) | fn update(&self, candidate: &mut C, hydrated: C);
method update_all (line 30) | fn update_all(&self, candidates: &mut [C], hydrated: Vec<C>) {
method name (line 36) | fn name(&self) -> &'static str {
FILE: candidate-pipeline/query_hydrator.rs
type QueryHydrator (line 7) | pub trait QueryHydrator<Q>: Any + Send + Sync
method enable (line 12) | fn enable(&self, _query: &Q) -> bool {
method hydrate (line 18) | async fn hydrate(&self, query: &Q) -> Result<Q, String>;
method update (line 22) | fn update(&self, query: &mut Q, hydrated: Q);
method name (line 24) | fn name(&self) -> &'static str {
FILE: candidate-pipeline/scorer.rs
type Scorer (line 7) | pub trait Scorer<Q, C>: Send + Sync
method enable (line 13) | fn enable(&self, _query: &Q) -> bool {
method score (line 22) | async fn score(&self, query: &Q, candidates: &[C]) -> Result<Vec<C>, S...
method update (line 26) | fn update(&self, candidate: &mut C, scored: C);
method update_all (line 30) | fn update_all(&self, candidates: &mut [C], scored: Vec<C>) {
method name (line 36) | fn name(&self) -> &'static str {
FILE: candidate-pipeline/selector.rs
type Selector (line 4) | pub trait Selector<Q, C>: Send + Sync
method select (line 10) | fn select(&self, _query: &Q, candidates: Vec<C>) -> Vec<C> {
method enable (line 19) | fn enable(&self, _query: &Q) -> bool {
method score (line 24) | fn score(&self, candidate: &C) -> f64;
method sort (line 27) | fn sort(&self, candidates: Vec<C>) -> Vec<C> {
method size (line 38) | fn size(&self) -> Option<usize> {
method name (line 42) | fn name(&self) -> &'static str {
FILE: candidate-pipeline/side_effect.rs
type SideEffectInput (line 8) | pub struct SideEffectInput<Q, C> {
type SideEffect (line 14) | pub trait SideEffect<Q, C>: Send + Sync
method enable (line 20) | fn enable(&self, _query: Arc<Q>) -> bool {
method run (line 24) | async fn run(&self, input: Arc<SideEffectInput<Q, C>>) -> Result<(), S...
method name (line 26) | fn name(&self) -> &'static str {
FILE: candidate-pipeline/source.rs
type Source (line 7) | pub trait Source<Q, C>: Any + Send + Sync
method enable (line 13) | fn enable(&self, _query: &Q) -> bool {
method get_candidates (line 17) | async fn get_candidates(&self, query: &Q) -> Result<Vec<C>, String>;
method name (line 19) | fn name(&self) -> &'static str {
FILE: home-mixer/candidate_hydrators/core_data_candidate_hydrator.rs
type CoreDataCandidateHydrator (line 8) | pub struct CoreDataCandidateHydrator {
method new (line 13) | pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
method hydrate (line 21) | async fn hydrate(
method update (line 52) | fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidat...
FILE: home-mixer/candidate_hydrators/gizmoduck_hydrator.rs
type GizmoduckCandidateHydrator (line 8) | pub struct GizmoduckCandidateHydrator {
method new (line 13) | pub async fn new(gizmoduck_client: Arc<dyn GizmoduckClient + Send + Sy...
method hydrate (line 21) | async fn hydrate(
method update (line 76) | fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidat...
FILE: home-mixer/candidate_hydrators/in_network_candidate_hydrator.rs
type InNetworkCandidateHydrator (line 7) | pub struct InNetworkCandidateHydrator;
method hydrate (line 12) | async fn hydrate(
method update (line 41) | fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidat...
FILE: home-mixer/candidate_hydrators/subscription_hydrator.rs
type SubscriptionHydrator (line 8) | pub struct SubscriptionHydrator {
method new (line 13) | pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
method hydrate (line 21) | async fn hydrate(
method update (line 47) | fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidat...
FILE: home-mixer/candidate_hydrators/vf_candidate_hydrator.rs
type VFCandidateHydrator (line 15) | pub struct VFCandidateHydrator {
method new (line 20) | pub async fn new(vf_client: Arc<dyn VisibilityFilteringClient + Send +...
method fetch_vf_results (line 24) | async fn fetch_vf_results(
method hydrate (line 45) | async fn hydrate(
method update (line 98) | fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidat...
FILE: home-mixer/candidate_hydrators/video_duration_candidate_hydrator.rs
type VideoDurationCandidateHydrator (line 9) | pub struct VideoDurationCandidateHydrator {
method new (line 14) | pub async fn new(tes_client: Arc<dyn TESClient + Send + Sync>) -> Self {
method hydrate (line 22) | async fn hydrate(
method update (line 59) | fn update(&self, candidate: &mut PostCandidate, hydrated: PostCandidat...
FILE: home-mixer/candidate_pipeline/candidate.rs
type PostCandidate (line 6) | pub struct PostCandidate {
type PhoenixScores (line 30) | pub struct PhoenixScores {
type CandidateHelpers (line 53) | pub trait CandidateHelpers {
method get_screen_names (line 54) | fn get_screen_names(&self) -> HashMap<u64, String>;
method get_screen_names (line 58) | fn get_screen_names(&self) -> HashMap<u64, String> {
FILE: home-mixer/candidate_pipeline/candidate_features.rs
type PureCoreData (line 5) | pub struct PureCoreData {
type ExclusiveTweetControl (line 16) | pub struct ExclusiveTweetControl {
type MediaEntities (line 20) | pub type MediaEntities = Vec<MediaEntity>;
type MediaEntity (line 24) | pub struct MediaEntity {
type MediaInfo (line 30) | pub enum MediaInfo {
type VideoInfo (line 36) | pub struct VideoInfo {
type Share (line 42) | pub struct Share {
type Reply (line 49) | pub struct Reply {
type GizmoduckUserCounts (line 56) | pub struct GizmoduckUserCounts {
type GizmoduckUserProfile (line 62) | pub struct GizmoduckUserProfile {
type GizmoduckUser (line 68) | pub struct GizmoduckUser {
type GizmoduckUserResult (line 76) | pub struct GizmoduckUserResult {
FILE: home-mixer/candidate_pipeline/phoenix_candidate_pipeline.rs
type PhoenixCandidatePipeline (line 60) | pub struct PhoenixCandidatePipeline {
method build_with_clients (line 73) | async fn build_with_clients(
method prod (line 162) | pub async fn prod() -> PhoenixCandidatePipeline {
method query_hydrators (line 217) | fn query_hydrators(&self) -> &[Box<dyn QueryHydrator<ScoredPostsQuery>...
method sources (line 221) | fn sources(&self) -> &[Box<dyn Source<ScoredPostsQuery, PostCandidate>...
method hydrators (line 224) | fn hydrators(&self) -> &[Box<dyn Hydrator<ScoredPostsQuery, PostCandid...
method filters (line 228) | fn filters(&self) -> &[Box<dyn Filter<ScoredPostsQuery, PostCandidate>...
method scorers (line 232) | fn scorers(&self) -> &[Box<dyn Scorer<ScoredPostsQuery, PostCandidate>...
method selector (line 236) | fn selector(&self) -> &dyn Selector<ScoredPostsQuery, PostCandidate> {
method post_selection_hydrators (line 240) | fn post_selection_hydrators(&self) -> &[Box<dyn Hydrator<ScoredPostsQu...
method post_selection_filters (line 244) | fn post_selection_filters(&self) -> &[Box<dyn Filter<ScoredPostsQuery,...
method side_effects (line 248) | fn side_effects(&self) -> Arc<Vec<Box<dyn SideEffect<ScoredPostsQuery,...
method result_size (line 252) | fn result_size(&self) -> usize {
FILE: home-mixer/candidate_pipeline/query.rs
type ScoredPostsQuery (line 8) | pub struct ScoredPostsQuery {
method new (line 24) | pub fn new(
method get_viewer (line 54) | fn get_viewer(&self) -> Option<TwitterContextViewer> {
method request_id (line 66) | fn request_id(&self) -> &str {
FILE: home-mixer/candidate_pipeline/query_features.rs
type UserFeatures (line 5) | pub struct UserFeatures {
FILE: home-mixer/filters/age_filter.rs
type AgeFilter (line 9) | pub struct AgeFilter {
method new (line 14) | pub fn new(max_age: Duration) -> Self {
method is_within_age (line 18) | fn is_within_age(&self, tweet_id: i64) -> bool {
method filter (line 27) | async fn filter(
FILE: home-mixer/filters/author_socialgraph_filter.rs
type AuthorSocialgraphFilter (line 7) | pub struct AuthorSocialgraphFilter;
method filter (line 11) | async fn filter(
FILE: home-mixer/filters/core_data_hydration_filter.rs
type CoreDataHydrationFilter (line 6) | pub struct CoreDataHydrationFilter;
method filter (line 10) | async fn filter(
FILE: home-mixer/filters/dedup_conversation_filter.rs
type DedupConversationFilter (line 8) | pub struct DedupConversationFilter;
method filter (line 12) | async fn filter(
function get_conversation_id (line 44) | fn get_conversation_id(candidate: &PostCandidate) -> u64 {
FILE: home-mixer/filters/drop_duplicates_filter.rs
type DropDuplicatesFilter (line 7) | pub struct DropDuplicatesFilter;
method filter (line 11) | async fn filter(
FILE: home-mixer/filters/ineligible_subscription_filter.rs
type IneligibleSubscriptionFilter (line 8) | pub struct IneligibleSubscriptionFilter;
method filter (line 12) | async fn filter(
FILE: home-mixer/filters/muted_keyword_filter.rs
type MutedKeywordFilter (line 8) | pub struct MutedKeywordFilter {
method new (line 13) | pub fn new() -> Self {
method filter (line 24) | async fn filter(
FILE: home-mixer/filters/previously_seen_posts_filter.rs
type PreviouslySeenPostsFilter (line 10) | pub struct PreviouslySeenPostsFilter;
method filter (line 14) | async fn filter(
FILE: home-mixer/filters/previously_served_posts_filter.rs
type PreviouslyServedPostsFilter (line 7) | pub struct PreviouslyServedPostsFilter;
method enable (line 11) | fn enable(&self, query: &ScoredPostsQuery) -> bool {
method filter (line 15) | async fn filter(
FILE: home-mixer/filters/retweet_deduplication_filter.rs
type RetweetDeduplicationFilter (line 9) | pub struct RetweetDeduplicationFilter;
method filter (line 13) | async fn filter(
FILE: home-mixer/filters/self_tweet_filter.rs
type SelfTweetFilter (line 7) | pub struct SelfTweetFilter;
method filter (line 11) | async fn filter(
FILE: home-mixer/filters/vf_filter.rs
type VFFilter (line 7) | pub struct VFFilter;
method filter (line 12) | async fn filter(
function should_drop (line 25) | fn should_drop(reason: &Option<FilteredReason>) -> bool {
FILE: home-mixer/main.rs
type Args (line 17) | struct Args {
function main (line 30) | async fn main() -> anyhow::Result<()> {
FILE: home-mixer/query_hydrators/user_action_seq_query_hydrator.rs
type UserActionSeqQueryHydrator (line 25) | pub struct UserActionSeqQueryHydrator {
method new (line 33) | pub fn new(uas_fetcher: Arc<UserActionSequenceFetcher>) -> Self {
method hydrate (line 46) | async fn hydrate(&self, query: &ScoredPostsQuery) -> Result<ScoredPost...
method update (line 62) | fn update(&self, query: &mut ScoredPostsQuery, hydrated: ScoredPostsQu...
method name (line 66) | fn name(&self) -> &'static str {
method aggregate_user_action_sequence (line 72) | fn aggregate_user_action_sequence(
function convert_to_proto_sequence (line 119) | fn convert_to_proto_sequence(
FILE: home-mixer/query_hydrators/user_features_query_hydrator.rs
type UserFeaturesQueryHydrator (line 9) | pub struct UserFeaturesQueryHydrator {
method hydrate (line 16) | async fn hydrate(&self, query: &ScoredPostsQuery) -> Result<ScoredPost...
method update (line 34) | fn update(&self, query: &mut ScoredPostsQuery, hydrated: ScoredPostsQu...
method name (line 38) | fn name(&self) -> &'static str {
FILE: home-mixer/scorers/author_diversity_scorer.rs
type AuthorDiversityScorer (line 10) | pub struct AuthorDiversityScorer {
method new (line 22) | pub fn new(decay_factor: f64, floor: f64) -> Self {
method multiplier (line 29) | fn multiplier(&self, position: usize) -> f64 {
method score (line 37) | async fn score(
method update (line 70) | fn update(&self, candidate: &mut PostCandidate, scored: PostCandidate) {
method default (line 16) | fn default() -> Self {
FILE: home-mixer/scorers/oon_scorer.rs
type OONScorer (line 8) | pub struct OONScorer;
method score (line 12) | async fn score(
method update (line 35) | fn update(&self, candidate: &mut PostCandidate, scored: PostCandidate) {
FILE: home-mixer/scorers/phoenix_scorer.rs
type PhoenixScorer (line 12) | pub struct PhoenixScorer {
method score (line 19) | async fn score(
method update (line 78) | fn update(&self, candidate: &mut PostCandidate, scored: PostCandidate) {
method build_predictions_map (line 87) | fn build_predictions_map(
method extract_phoenix_scores (line 129) | fn extract_phoenix_scores(&self, p: &ActionPredictions) -> PhoenixScor...
method current_timestamp_millis (line 153) | fn current_timestamp_millis() -> Option<u64> {
type ActionPredictions (line 161) | struct ActionPredictions {
method get (line 169) | fn get(&self, action: ActionName) -> Option<f64> {
method get_continuous (line 173) | fn get_continuous(&self, action: ContinuousActionName) -> Option<f64> {
FILE: home-mixer/scorers/weighted_scorer.rs
type WeightedScorer (line 8) | pub struct WeightedScorer;
method score (line 13) | async fn score(
method update (line 34) | fn update(&self, candidate: &mut PostCandidate, scored: PostCandidate) {
method apply (line 40) | fn apply(score: Option<f64>, weight: f64) -> f64 {
method compute_weighted_score (line 44) | fn compute_weighted_score(candidate: &PostCandidate) -> f64 {
method vqv_weight_eligibility (line 72) | fn vqv_weight_eligibility(candidate: &PostCandidate) -> f64 {
method offset_score (line 83) | fn offset_score(combined_score: f64) -> f64 {
FILE: home-mixer/selectors/top_k_score_selector.rs
type TopKScoreSelector (line 6) | pub struct TopKScoreSelector;
method score (line 9) | fn score(&self, candidate: &PostCandidate) -> f64 {
method size (line 12) | fn size(&self) -> Option<usize> {
FILE: home-mixer/server.rs
type HomeMixerServer (line 12) | pub struct HomeMixerServer {
method new (line 17) | pub async fn new() -> Self {
method get_scored_posts (line 27) | async fn get_scored_posts(
FILE: home-mixer/side_effects/cache_request_info_side_effect.rs
type CacheRequestInfoSideEffect (line 10) | pub struct CacheRequestInfoSideEffect {
method enable (line 16) | fn enable(&self, query: Arc<ScoredPostsQuery>) -> bool {
method run (line 20) | async fn run(
FILE: home-mixer/sources/phoenix_source.rs
type PhoenixSource (line 10) | pub struct PhoenixSource {
method enable (line 16) | fn enable(&self, query: &ScoredPostsQuery) -> bool {
method get_candidates (line 21) | async fn get_candidates(&self, query: &ScoredPostsQuery) -> Result<Vec...
FILE: home-mixer/sources/thunder_source.rs
type ThunderSource (line 12) | pub struct ThunderSource {
method get_candidates (line 19) | async fn get_candidates(&self, query: &ScoredPostsQuery) -> Result<Vec...
FILE: phoenix/grok.py
class TrainingState (line 26) | class TrainingState(NamedTuple):
function ffn_size (line 32) | def ffn_size(emb_size, widening_factor):
function make_recsys_attn_mask (line 39) | def make_recsys_attn_mask(
class MHAOutput (line 74) | class MHAOutput(NamedTuple):
class DecoderOutput (line 80) | class DecoderOutput(NamedTuple):
class TransformerOutput (line 84) | class TransformerOutput(NamedTuple):
class TransformerConfig (line 89) | class TransformerConfig:
method make (line 101) | def make(self) -> "Transformer":
function hk_rms_norm (line 112) | def hk_rms_norm(
class Linear (line 121) | class Linear(hk.Linear):
method __init__ (line 122) | def __init__(
method __call__ (line 134) | def __call__( # type: ignore
class RMSNorm (line 162) | class RMSNorm(hk.RMSNorm):
method __init__ (line 163) | def __init__(
method __call__ (line 172) | def __call__(self, inputs: jax.Array):
function rotate_half (line 197) | def rotate_half(
class RotaryEmbedding (line 205) | class RotaryEmbedding(hk.Module):
method __init__ (line 214) | def __init__(
method __call__ (line 225) | def __call__(
class MultiHeadAttention (line 264) | class MultiHeadAttention(hk.Module):
method __init__ (line 265) | def __init__(
method __call__ (line 286) | def __call__(
method _linear_projection (line 366) | def _linear_projection(
class MHABlock (line 379) | class MHABlock(hk.Module):
method __call__ (line 388) | def __call__(
class DenseBlock (line 415) | class DenseBlock(hk.Module):
method __call__ (line 422) | def __call__(
class DecoderLayer (line 444) | class DecoderLayer(hk.Module):
method __call__ (line 456) | def __call__(
function layer_norm (line 500) | def layer_norm(x):
class Transformer (line 505) | class Transformer(hk.Module):
method __call__ (line 516) | def __call__(
FILE: phoenix/recsys_model.py
class HashConfig (line 33) | class HashConfig:
class RecsysEmbeddings (line 42) | class RecsysEmbeddings:
class RecsysModelOutput (line 56) | class RecsysModelOutput(NamedTuple):
class RecsysBatch (line 62) | class RecsysBatch(NamedTuple):
function block_user_reduce (line 79) | def block_user_reduce(
function block_history_reduce (line 122) | def block_history_reduce(
function block_candidate_reduce (line 185) | def block_candidate_reduce(
class PhoenixModelConfig (line 246) | class PhoenixModelConfig:
method __post_init__ (line 264) | def __post_init__(self):
method initialize (line 268) | def initialize(self):
method make (line 272) | def make(self):
class PhoenixModel (line 285) | class PhoenixModel(hk.Module):
method _get_action_embeddings (line 293) | def _get_action_embeddings(
method _single_hot_to_embeddings (line 323) | def _single_hot_to_embeddings(
method _get_unembedding (line 353) | def _get_unembedding(self) -> jax.Array:
method build_inputs (line 365) | def build_inputs(
method __call__ (line 439) | def __call__(
FILE: phoenix/recsys_retrieval_model.py
class RetrievalOutput (line 38) | class RetrievalOutput(NamedTuple):
class CandidateTower (line 47) | class CandidateTower(hk.Module):
method __call__ (line 57) | def __call__(self, post_author_embedding: jax.Array) -> jax.Array:
class PhoenixRetrievalModelConfig (line 103) | class PhoenixRetrievalModelConfig:
method __post_init__ (line 124) | def __post_init__(self):
method initialize (line 128) | def initialize(self):
method make (line 132) | def make(self):
class PhoenixRetrievalModel (line 145) | class PhoenixRetrievalModel(hk.Module):
method _get_action_embeddings (line 161) | def _get_action_embeddings(
method _single_hot_to_embeddings (line 186) | def _single_hot_to_embeddings(
method build_user_representation (line 206) | def build_user_representation(
method build_candidate_representation (line 278) | def build_candidate_representation(
method __call__ (line 314) | def __call__(
method _retrieve_top_k (line 346) | def _retrieve_top_k(
FILE: phoenix/run_ranker.py
function main (line 24) | def main():
FILE: phoenix/run_retrieval.py
function main (line 31) | def main():
FILE: phoenix/runners.py
function create_dummy_batch_from_config (line 41) | def create_dummy_batch_from_config(
function create_dummy_embeddings_from_config (line 80) | def create_dummy_embeddings_from_config(
class BaseModelRunner (line 121) | class BaseModelRunner(ABC):
method model (line 129) | def model(self) -> Any:
method _model_name (line 134) | def _model_name(self) -> str:
method make_forward_fn (line 139) | def make_forward_fn(self):
method initialize (line 143) | def initialize(self):
class BaseInferenceRunner (line 156) | class BaseInferenceRunner(ABC):
method runner (line 163) | def runner(self) -> BaseModelRunner:
method _get_num_actions (line 167) | def _get_num_actions(self) -> int:
method create_dummy_batch (line 174) | def create_dummy_batch(self, batch_size: int = 1) -> RecsysBatch:
method create_dummy_embeddings (line 185) | def create_dummy_embeddings(self, batch_size: int = 1) -> RecsysEmbedd...
method initialize (line 197) | def initialize(self):
class RankingOutput (line 225) | class RankingOutput(NamedTuple):
class ModelRunner (line 258) | class ModelRunner(BaseModelRunner):
method __init__ (line 263) | def __init__(self, model: PhoenixModelConfig, bs_per_device: float = 2...
method model (line 269) | def model(self) -> PhoenixModelConfig:
method _model_name (line 273) | def _model_name(self) -> str:
method make_forward_fn (line 276) | def make_forward_fn(self): # type: ignore
method init (line 283) | def init(
method load_or_init (line 291) | def load_or_init(
class RecsysInferenceRunner (line 302) | class RecsysInferenceRunner(BaseInferenceRunner):
method __init__ (line 307) | def __init__(self, runner: ModelRunner, name: str):
method runner (line 312) | def runner(self) -> ModelRunner:
method initialize (line 315) | def initialize(self):
method rank (line 376) | def rank(self, batch: RecsysBatch, recsys_embeddings: RecsysEmbeddings...
function create_example_batch (line 389) | def create_example_batch(
class RetrievalOutput (line 490) | class RetrievalOutput(NamedTuple):
class RetrievalModelRunner (line 504) | class RetrievalModelRunner(BaseModelRunner):
method __init__ (line 509) | def __init__(
method model (line 520) | def model(self) -> PhoenixRetrievalModelConfig:
method _model_name (line 524) | def _model_name(self) -> str:
method make_forward_fn (line 527) | def make_forward_fn(self): # type: ignore
method init (line 542) | def init(
method load_or_init (line 555) | def load_or_init(
class RecsysRetrievalInferenceRunner (line 568) | class RecsysRetrievalInferenceRunner(BaseInferenceRunner):
method __init__ (line 582) | def __init__(self, runner: RetrievalModelRunner, name: str):
method runner (line 589) | def runner(self) -> RetrievalModelRunner:
method initialize (line 592) | def initialize(self):
method encode_user (line 642) | def encode_user(self, batch: RecsysBatch, recsys_embeddings: RecsysEmb...
method encode_candidates (line 654) | def encode_candidates(
method set_corpus (line 668) | def set_corpus(
method retrieve (line 682) | def retrieve(
function create_example_corpus (line 706) | def create_example_corpus(
FILE: phoenix/test_recsys_model.py
class TestMakeRecsysAttnMask (line 22) | class TestMakeRecsysAttnMask:
method test_output_shape (line 25) | def test_output_shape(self):
method test_user_history_has_causal_attention (line 34) | def test_user_history_has_causal_attention(self):
method test_candidates_attend_to_user_history (line 51) | def test_candidates_attend_to_user_history(self):
method test_candidates_attend_to_themselves (line 65) | def test_candidates_attend_to_themselves(self):
method test_candidates_do_not_attend_to_other_candidates (line 78) | def test_candidates_do_not_attend_to_other_candidates(self):
method test_full_mask_structure (line 93) | def test_full_mask_structure(self):
method test_dtype_preserved (line 134) | def test_dtype_preserved(self):
method test_single_candidate (line 145) | def test_single_candidate(self):
method test_all_candidates (line 165) | def test_all_candidates(self):
FILE: phoenix/test_recsys_retrieval_model.py
class TestCandidateTower (line 38) | class TestCandidateTower(unittest.TestCase):
method test_candidate_tower_output_shape (line 41) | def test_candidate_tower_output_shape(self):
method test_candidate_tower_normalized (line 62) | def test_candidate_tower_normalized(self):
method test_candidate_tower_mean_pooling (line 84) | def test_candidate_tower_mean_pooling(self):
class TestPhoenixRetrievalModel (line 109) | class TestPhoenixRetrievalModel(unittest.TestCase):
method setUp (line 112) | def setUp(self):
method _create_test_batch (line 145) | def _create_test_batch(self) -> tuple:
method _create_test_corpus (line 159) | def _create_test_corpus(self):
method test_model_forward (line 163) | def test_model_forward(self):
method test_user_representation_normalized (line 183) | def test_user_representation_normalized(self):
method test_candidate_representation_normalized (line 202) | def test_candidate_representation_normalized(self):
method test_retrieve_top_k (line 223) | def test_retrieve_top_k(self):
class TestRetrievalInferenceRunner (line 250) | class TestRetrievalInferenceRunner(unittest.TestCase):
method setUp (line 253) | def setUp(self):
method test_runner_initialization (line 284) | def test_runner_initialization(self):
method test_runner_encode_user (line 298) | def test_runner_encode_user(self):
method test_runner_retrieve (line 324) | def test_runner_retrieve(self):
FILE: thunder/deserializer.rs
function deserialize_tweet_event (line 8) | pub fn deserialize_tweet_event(payload: &[u8]) -> Result<TweetEvent> {
function deserialize_event (line 16) | pub fn deserialize_event(payload: &[u8]) -> Result<Event> {
function deserialize_tweet_event_v2 (line 24) | pub fn deserialize_tweet_event_v2(payload: &[u8]) -> Result<InNetworkEve...
FILE: thunder/kafka/tweet_events_listener.rs
function monitor_partition_lag (line 27) | async fn monitor_partition_lag(
function is_eligible_video (line 55) | fn is_eligible_video(tweet: &Tweet) -> bool {
function start_partition_lag_monitor (line 77) | pub fn start_partition_lag_monitor(
function start_tweet_event_processing (line 92) | pub async fn start_tweet_event_processing(
function spawn_processing_threads (line 125) | fn spawn_processing_threads(
function process_message_batch (line 193) | async fn process_message_batch(
function process_tweet_events (line 349) | async fn process_tweet_events(
FILE: thunder/kafka/tweet_events_listener_v2.rs
function start_tweet_event_processing_v2 (line 23) | pub async fn start_tweet_event_processing_v2(
function spawn_processing_threads_v2 (line 45) | fn spawn_processing_threads_v2(
function deserialize_batch (line 119) | fn deserialize_batch(
function process_tweet_events_v2 (line 170) | async fn process_tweet_events_v2(
FILE: thunder/kafka/utils.rs
function create_kafka_consumer (line 9) | pub async fn create_kafka_consumer(
function deserialize_kafka_messages (line 22) | pub fn deserialize_kafka_messages<T, F>(
FILE: thunder/kafka_utils.rs
constant TWEET_EVENT_TOPIC (line 15) | const TWEET_EVENT_TOPIC: &str = "";
constant TWEET_EVENT_DEST (line 16) | const TWEET_EVENT_DEST: &str = "";
constant IN_NETWORK_EVENTS_DEST (line 18) | const IN_NETWORK_EVENTS_DEST: &str = "";
constant IN_NETWORK_EVENTS_TOPIC (line 19) | const IN_NETWORK_EVENTS_TOPIC: &str = "";
function start_kafka (line 21) | pub async fn start_kafka(
FILE: thunder/main.rs
function main (line 16) | async fn main() -> Result<()> {
FILE: thunder/posts/post_store.rs
type TinyPost (line 21) | pub struct TinyPost {
method new (line 28) | pub fn new(post_id: i64, created_at: i64) -> Self {
type PostStore (line 39) | pub struct PostStore {
method new (line 57) | pub fn new(retention_seconds: u64, request_timeout_ms: u64) -> Self {
method mark_as_deleted (line 69) | pub fn mark_as_deleted(&self, posts: Vec<TweetDeleteEvent>) {
method insert_posts (line 86) | pub fn insert_posts(&self, mut posts: Vec<LightPost>) {
method finalize_init (line 103) | pub async fn finalize_init(&self) -> Result<()> {
method insert_posts_internal (line 115) | fn insert_posts_internal(&self, posts: Vec<LightPost>) {
method get_videos_by_users (line 171) | pub fn get_videos_by_users(
method get_all_posts_by_users (line 193) | pub fn get_all_posts_by_users(
method get_posts_from_map (line 228) | pub fn get_posts_from_map(
method start_stats_logger (line 331) | pub fn start_stats_logger(self: Arc<Self>) {
method start_auto_trim (line 393) | pub fn start_auto_trim(self: Arc<Self>, interval_minutes: u64) {
method trim_old_posts (line 409) | pub async fn trim_old_posts(&self) -> usize {
method sort_all_user_posts (line 479) | pub async fn sort_all_user_posts(&self) {
method clear (line 512) | pub fn clear(&self) {
method default (line 522) | fn default() -> Self {
FILE: thunder/thunder_service.rs
type ThunderServiceImpl (line 29) | pub struct ThunderServiceImpl {
method new (line 39) | pub fn new(
method server (line 56) | pub fn server(self) -> InNetworkPostsServiceServer<Self> {
method analyze_and_report_post_statistics (line 64) | fn analyze_and_report_post_statistics(posts: &[LightPost], stage: &str) {
method get_in_network_posts (line 154) | async fn get_in_network_posts(
function score_recent (line 334) | fn score_recent(mut light_posts: Vec<LightPost>, max_results: usize) -> ...
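The `test_recsys_model.py` entries above define the shape of the ranking attention mask. User history uses causal attention. Each candidate attends to the full user history and to itself, but never to another candidate.

Below is a minimal sketch in plain Python. It assumes the sequence is laid out as history tokens followed by candidate tokens, with candidates starting at a hypothetical `candidate_start` offset. The function name `make_candidate_isolation_mask` is also hypothetical. The repository's version is JAX code, and its exact signature is not shown in this index.

```python
from typing import List


def make_candidate_isolation_mask(seq_len: int, candidate_start: int) -> List[List[int]]:
    """Return mask[q][k] == 1 if query position q may attend to key position k.

    Hypothetical layout: positions [0, candidate_start) are user-history tokens,
    positions [candidate_start, seq_len) are candidate tokens.
    """
    if not 0 <= candidate_start <= seq_len:
        raise ValueError("candidate_start must lie in [0, seq_len]")
    mask = [[0] * seq_len for _ in range(seq_len)]
    for q in range(seq_len):
        if q < candidate_start:
            # History token: ordinary causal attention over itself and earlier history.
            for k in range(q + 1):
                mask[q][k] = 1
        else:
            # Candidate token: sees the entire user history ...
            for k in range(candidate_start):
                mask[q][k] = 1
            # ... and itself, but no other candidate.
            mask[q][q] = 1
    return mask


for row in make_candidate_isolation_mask(5, 3):
    print(row)
```

With three history tokens and two candidates, the last two rows are `[1, 1, 1, 1, 0]` and `[1, 1, 1, 0, 1]`. Each candidate sees the history plus its own position only. A plausible motivation for this isolation is that a candidate's score then does not depend on which other candidates share the request.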
Condensed preview: 78 files, each showing path, character count, and a content snippet.
[
{
"path": ".gitignore",
"chars": 13,
"preview": "__pycache__/\n"
},
{
"path": "CODE_OF_CONDUCT.md",
"chars": 28,
"preview": "Be excellent to each other.\n"
},
{
"path": "LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 16439,
"preview": "# X For You Feed Algorithm\n\nThis repository contains the core recommendation system powering the \"For You\" feed on X. It"
},
{
"path": "candidate-pipeline/candidate_pipeline.rs",
"chars": 11850,
"preview": "use crate::filter::Filter;\nuse crate::hydrator::Hydrator;\nuse crate::query_hydrator::QueryHydrator;\nuse crate::scorer::S"
},
{
"path": "candidate-pipeline/filter.rs",
"chars": 990,
"preview": "use std::any::{Any, type_name_of_val};\nuse tonic::async_trait;\n\nuse crate::util;\n\npub struct FilterResult<C> {\n pub k"
},
{
"path": "candidate-pipeline/hydrator.rs",
"chars": 1423,
"preview": "use crate::util;\nuse std::any::{Any, type_name_of_val};\nuse tonic::async_trait;\n\n// Hydrators run in parallel and update"
},
{
"path": "candidate-pipeline/lib.rs",
"chars": 171,
"preview": "pub mod candidate_pipeline;\npub mod filter;\npub mod hydrator;\npub mod query_hydrator;\npub mod scorer;\npub mod selector;\n"
},
{
"path": "candidate-pipeline/query_hydrator.rs",
"chars": 784,
"preview": "use std::any::{Any, type_name_of_val};\nuse tonic::async_trait;\n\nuse crate::util;\n\n#[async_trait]\npub trait QueryHydrator"
},
{
"path": "candidate-pipeline/scorer.rs",
"chars": 1405,
"preview": "use crate::util;\nuse std::any::type_name_of_val;\nuse tonic::async_trait;\n\n/// Scorers update candidate fields (like a sc"
},
{
"path": "candidate-pipeline/selector.rs",
"chars": 1287,
"preview": "use crate::util;\nuse std::any::type_name_of_val;\n\npub trait Selector<Q, C>: Send + Sync\nwhere\n Q: Clone + Send + Sync"
},
{
"path": "candidate-pipeline/side_effect.rs",
"chars": 740,
"preview": "use crate::util;\nuse std::any::type_name_of_val;\nuse std::sync::Arc;\nuse tonic::async_trait;\n\n// A side-effect is an act"
},
{
"path": "candidate-pipeline/source.rs",
"chars": 520,
"preview": "use std::any::{Any, type_name_of_val};\nuse tonic::async_trait;\n\nuse crate::util;\n\n#[async_trait]\npub trait Source<Q, C>:"
},
{
"path": "home-mixer/candidate_hydrators/core_data_candidate_hydrator.rs",
"chars": 2288,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse cra"
},
{
"path": "home-mixer/candidate_hydrators/gizmoduck_hydrator.rs",
"chars": 3385,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse cra"
},
{
"path": "home-mixer/candidate_hydrators/in_network_candidate_hydrator.rs",
"chars": 1410,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse std"
},
{
"path": "home-mixer/candidate_hydrators/mod.rs",
"chars": 210,
"preview": "pub mod core_data_candidate_hydrator;\npub mod gizmoduck_hydrator;\npub mod in_network_candidate_hydrator;\npub mod subscri"
},
{
"path": "home-mixer/candidate_hydrators/subscription_hydrator.rs",
"chars": 1710,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse cra"
},
{
"path": "home-mixer/candidate_hydrators/vf_candidate_hydrator.rs",
"chars": 3463,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse fut"
},
{
"path": "home-mixer/candidate_hydrators/video_duration_candidate_hydrator.rs",
"chars": 2172,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::candidate_features::MediaInfo;\nu"
},
{
"path": "home-mixer/candidate_pipeline/candidate.rs",
"chars": 2426,
"preview": "use std::collections::HashMap;\nuse xai_home_mixer_proto as pb;\nuse xai_visibility_filtering::models as vf;\n\n#[derive(Clo"
},
{
"path": "home-mixer/candidate_pipeline/candidate_features.rs",
"chars": 2201,
"preview": "use serde::{Deserialize, Serialize};\n\n#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]\n#[serde(rename"
},
{
"path": "home-mixer/candidate_pipeline/mod.rs",
"chars": 122,
"preview": "pub mod candidate;\npub mod candidate_features;\npub mod phoenix_candidate_pipeline;\npub mod query;\npub mod query_features"
},
{
"path": "home-mixer/candidate_pipeline/phoenix_candidate_pipeline.rs",
"chars": 10852,
"preview": "use crate::candidate_hydrators::core_data_candidate_hydrator::CoreDataCandidateHydrator;\nuse crate::candidate_hydrators:"
},
{
"path": "home-mixer/candidate_pipeline/query.rs",
"chars": 2160,
"preview": "use crate::candidate_pipeline::query_features::UserFeatures;\nuse crate::util::request_util::generate_request_id;\nuse xai"
},
{
"path": "home-mixer/candidate_pipeline/query_features.rs",
"chars": 351,
"preview": "use serde::{Deserialize, Serialize};\n\n#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]\n#[serde(rename"
},
{
"path": "home-mixer/filters/age_filter.rs",
"chars": 1092,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse cra"
},
{
"path": "home-mixer/filters/author_socialgraph_filter.rs",
"chars": 1493,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse ton"
},
{
"path": "home-mixer/filters/core_data_hydration_filter.rs",
"chars": 689,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse ton"
},
{
"path": "home-mixer/filters/dedup_conversation_filter.rs",
"chars": 1780,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse std"
},
{
"path": "home-mixer/filters/drop_duplicates_filter.rs",
"chars": 900,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse std"
},
{
"path": "home-mixer/filters/ineligible_subscription_filter.rs",
"chars": 1177,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse std"
},
{
"path": "home-mixer/filters/mod.rs",
"chars": 391,
"preview": "pub mod age_filter;\npub mod author_socialgraph_filter;\npub mod core_data_hydration_filter;\npub mod dedup_conversation_fi"
},
{
"path": "home-mixer/filters/muted_keyword_filter.rs",
"chars": 1937,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse std"
},
{
"path": "home-mixer/filters/previously_seen_posts_filter.rs",
"chars": 1295,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse cra"
},
{
"path": "home-mixer/filters/previously_served_posts_filter.rs",
"chars": 918,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse cra"
},
{
"path": "home-mixer/filters/retweet_deduplication_filter.rs",
"chars": 1545,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse std"
},
{
"path": "home-mixer/filters/self_tweet_filter.rs",
"chars": 773,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse ton"
},
{
"path": "home-mixer/filters/vf_filter.rs",
"chars": 1029,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse ton"
},
{
"path": "home-mixer/lib.rs",
"chars": 413,
"preview": "mod candidate_hydrators;\nmod candidate_pipeline;\npub mod clients; // Excluded from open source release for security reas"
},
{
"path": "home-mixer/main.rs",
"chars": 2378,
"preview": "use clap::Parser;\nuse log::info;\nuse std::time::Duration;\n\nuse tonic::codec::CompressionEncoding;\nuse tonic::service::Ro"
},
{
"path": "home-mixer/query_hydrators/mod.rs",
"chars": 78,
"preview": "pub mod user_action_seq_query_hydrator;\npub mod user_features_query_hydrator;\n"
},
{
"path": "home-mixer/query_hydrators/user_action_seq_query_hydrator.rs",
"chars": 6709,
"preview": "use crate::candidate_pipeline::query::ScoredPostsQuery;\nuse crate::clients::uas_fetcher::{UserActionSequenceFetcher, Use"
},
{
"path": "home-mixer/query_hydrators/user_features_query_hydrator.rs",
"chars": 1511,
"preview": "use crate::candidate_pipeline::query::ScoredPostsQuery;\nuse crate::candidate_pipeline::query_features::UserFeatures;\nuse"
},
{
"path": "home-mixer/scorers/author_diversity_scorer.rs",
"chars": 2320,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse cra"
},
{
"path": "home-mixer/scorers/mod.rs",
"chars": 102,
"preview": "pub mod author_diversity_scorer;\npub mod oon_scorer;\npub mod phoenix_scorer;\npub mod weighted_scorer;\n"
},
{
"path": "home-mixer/scorers/oon_scorer.rs",
"chars": 1135,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse cra"
},
{
"path": "home-mixer/scorers/phoenix_scorer.rs",
"chars": 6786,
"preview": "use crate::candidate_pipeline::candidate::{PhoenixScores, PostCandidate};\nuse crate::candidate_pipeline::query::ScoredPo"
},
{
"path": "home-mixer/scorers/weighted_scorer.rs",
"chars": 3437,
"preview": "use crate::candidate_pipeline::candidate::{PhoenixScores, PostCandidate};\nuse crate::candidate_pipeline::query::ScoredPo"
},
{
"path": "home-mixer/selectors/mod.rs",
"chars": 76,
"preview": "mod top_k_score_selector;\n\npub use top_k_score_selector::TopKScoreSelector;\n"
},
{
"path": "home-mixer/selectors/top_k_score_selector.rs",
"chars": 493,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse cra"
},
{
"path": "home-mixer/server.rs",
"chars": 3334,
"preview": "use crate::candidate_pipeline::candidate::CandidateHelpers;\nuse crate::candidate_pipeline::phoenix_candidate_pipeline::P"
},
{
"path": "home-mixer/side_effects/cache_request_info_side_effect.rs",
"chars": 1443,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse cra"
},
{
"path": "home-mixer/side_effects/mod.rs",
"chars": 40,
"preview": "pub mod cache_request_info_side_effect;\n"
},
{
"path": "home-mixer/sources/mod.rs",
"chars": 48,
"preview": "pub mod phoenix_source;\npub mod thunder_source;\n"
},
{
"path": "home-mixer/sources/phoenix_source.rs",
"chars": 1852,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse cra"
},
{
"path": "home-mixer/sources/thunder_source.rs",
"chars": 2797,
"preview": "use crate::candidate_pipeline::candidate::PostCandidate;\nuse crate::candidate_pipeline::query::ScoredPostsQuery;\nuse cra"
},
{
"path": "phoenix/README.md",
"chars": 10474,
"preview": "# Phoenix: Recommendation System\n\nThis repository contains JAX example code for the Phoenix recommendation system, which"
},
{
"path": "phoenix/grok.py",
"chars": 18731,
"preview": "# Copyright 2026 X.AI Corp.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "phoenix/pyproject.toml",
"chars": 536,
"preview": "[project]\nname = \"grok-1\"\nversion = \"0.1.0\"\ndescription = \"Grok-1 model\"\nreadme = \"README.md\"\nrequires-python = \">=3.11\""
},
{
"path": "phoenix/recsys_model.py",
"chars": 15808,
"preview": "# Copyright 2026 X.AI Corp.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "phoenix/recsys_retrieval_model.py",
"chars": 12838,
"preview": "# Copyright 2026 X.AI Corp.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "phoenix/run_ranker.py",
"chars": 4030,
"preview": "# Copyright 2026 X.AI Corp.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "phoenix/run_retrieval.py",
"chars": 4819,
"preview": "# Copyright 2026 X.AI Corp.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "phoenix/runners.py",
"chars": 23996,
"preview": "# Copyright 2026 X.AI Corp.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "phoenix/test_recsys_model.py",
"chars": 6708,
"preview": "# Copyright 2026 X.AI Corp.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "phoenix/test_recsys_retrieval_model.py",
"chars": 12672,
"preview": "# Copyright 2026 X.AI Corp.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this f"
},
{
"path": "thunder/deserializer.rs",
"chars": 1109,
"preview": "use crate::schema::{events::Event, tweet_events::TweetEvent};\nuse anyhow::{Context, Result};\nuse prost::Message;\nuse thr"
},
{
"path": "thunder/kafka/mod.rs",
"chars": 80,
"preview": "pub mod tweet_events_listener;\npub mod tweet_events_listener_v2;\npub mod utils;\n"
},
{
"path": "thunder/kafka/tweet_events_listener.rs",
"chars": 14217,
"preview": "use anyhow::{Context, Result};\nuse log::{error, info, warn};\nuse prost::Message;\nuse std::sync::Arc;\nuse std::sync::atom"
},
{
"path": "thunder/kafka/tweet_events_listener_v2.rs",
"chars": 9785,
"preview": "use anyhow::Result;\nuse log::{info, warn};\nuse std::sync::Arc;\nuse std::sync::atomic::{AtomicUsize, Ordering};\nuse std::"
},
{
"path": "thunder/kafka/utils.rs",
"chars": 1420,
"preview": "use anyhow::{Context, Result};\nuse std::sync::Arc;\nuse tokio::sync::RwLock;\nuse xai_kafka::{KafkaMessage, config::KafkaC"
},
{
"path": "thunder/kafka_utils.rs",
"chars": 4312,
"preview": "use anyhow::{Context, Result};\nuse std::sync::Arc;\nuse xai_kafka::KafkaProducerConfig;\nuse xai_kafka::config::{KafkaConf"
},
{
"path": "thunder/lib.rs",
"chars": 196,
"preview": "pub mod args;\npub mod config;\npub mod deserializer;\npub mod kafka;\npub mod kafka_utils;\npub mod metrics;\npub mod o2;\npub"
},
{
"path": "thunder/main.rs",
"chars": 3119,
"preview": "use anyhow::{Context, Result};\nuse axum::Router;\nuse clap::Parser;\nuse log::info;\nuse std::sync::Arc;\nuse std::time::{Du"
},
{
"path": "thunder/posts/mod.rs",
"chars": 20,
"preview": "pub mod post_store;\n"
},
{
"path": "thunder/posts/post_store.rs",
"chars": 20066,
"preview": "use anyhow::Result;\nuse dashmap::DashMap;\nuse log::info;\nuse std::collections::{HashSet, VecDeque};\nuse std::sync::Arc;\n"
},
{
"path": "thunder/thunder_service.rs",
"chars": 12487,
"preview": "use lazy_static::lazy_static;\nuse log::{debug, info, warn};\nuse std::cmp::Reverse;\nuse std::collections::HashSet;\nuse st"
}
]
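The `candidate-pipeline` crate and the `home-mixer` modules listed above suggest a staged flow:

1. Sources produce candidates.
2. Filters drop ineligible candidates.
3. Scorers assign scores.
4. A selector keeps the top K.

The sketch below is a toy illustration of that flow, written in Python for consistency with the Phoenix code. It is not the Rust implementation. `Candidate`, `run_filters`, `weighted_score`, `run_pipeline`, the action names, and the weights are all hypothetical. Hydrators and side effects are omitted.

```python
import heapq
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Sequence, Tuple


@dataclass
class Candidate:
    post_id: int
    author_id: int
    # Hypothetical per-action predicted probabilities, e.g. {"like": 0.3}.
    action_probs: Dict[str, float] = field(default_factory=dict)
    score: float = 0.0


# A filter returns True to keep a candidate.
Filter = Callable[[Candidate], bool]


def run_filters(
    candidates: Sequence[Candidate], filters: Sequence[Filter]
) -> Tuple[List[Candidate], List[Candidate]]:
    """Split candidates into (passed, dropped); a candidate passes only if every filter keeps it."""
    passed: List[Candidate] = []
    dropped: List[Candidate] = []
    for c in candidates:
        (passed if all(f(c) for f in filters) else dropped).append(c)
    return passed, dropped


def weighted_score(c: Candidate, weights: Dict[str, float]) -> float:
    """Weighted sum of action probabilities; actions without a weight contribute 0."""
    return sum(weights.get(action, 0.0) * p for action, p in c.action_probs.items())


def run_pipeline(
    source: Callable[[], List[Candidate]],
    filters: Sequence[Filter],
    weights: Dict[str, float],
    k: int,
) -> List[Candidate]:
    candidates = source()                                   # 1. source
    passed, _dropped = run_filters(candidates, filters)     # 2. filter
    for c in passed:                                        # 3. score
        c.score = weighted_score(c, weights)
    return heapq.nlargest(k, passed, key=lambda c: c.score)  # 4. select top K


# Toy usage: one viewer, a self-post filter, and made-up weights.
VIEWER_ID = 1


def toy_source() -> List[Candidate]:
    return [
        Candidate(10, 1, {"like": 0.9}),                # the viewer's own post
        Candidate(11, 2, {"like": 0.2, "reply": 0.5}),
        Candidate(12, 3, {"like": 0.6}),
        Candidate(13, 4, {"like": 0.1, "block": 0.3}),
    ]


def not_self(c: Candidate) -> bool:
    return c.author_id != VIEWER_ID


WEIGHTS = {"like": 1.0, "reply": 2.0, "block": -5.0}

top = run_pipeline(toy_source, [not_self], WEIGHTS, k=2)
print([c.post_id for c in top])  # → [11, 12]
```

The negative weight on a hypothetical `block` action shows how a weighted sum can push down posts likely to draw negative feedback. `heapq.nlargest` keeps selection at O(n log k) rather than sorting every candidate.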
About this extraction
This page contains the full source code of the xai-org/x-algorithm GitHub repository, extracted and formatted as plain text. The extraction includes 78 files (300.0 KB), approximately 70.1k tokens, and a symbol index with 387 extracted functions, classes, methods, constants, and types.