This documentation is part of the "Projects with Books" initiative at zenOSmosis.
The source code for this project is available on GitHub.
Machine Learning Training Pipeline
Purpose and Scope
This page documents the machine learning training pipeline for the narrative_stack system, specifically the Stage 1 autoencoder that learns latent representations of US GAAP financial concepts. The training pipeline consumes preprocessed concept/unit/value triplets and their semantic embeddings to train a neural network that can encode financial data into a compressed latent space.
The pipeline uses PyTorch Lightning for training orchestration, implements custom iterable datasets for efficient data streaming from the simd-r-drive WebSocket server, and provides comprehensive experiment tracking through TensorBoard.
Training Pipeline Architecture
The training pipeline operates as a streaming system that continuously fetches preprocessed triplets from the UsGaapStore and feeds them through the autoencoder model. The architecture emphasizes memory efficiency and reproducibility by avoiding full-dataset loads into RAM.
```mermaid
graph TB
subgraph "Data Source Layer"
DataStoreWsClient["DataStoreWsClient\n(simd_r_drive_ws_client)"]
UsGaapStore["UsGaapStore\nlookup_by_index()"]
end
subgraph "Dataset Layer"
IterableDataset["IterableConceptValueDataset\ninternal_batch_size=64\nreturn_scaler=True\nshuffle=True/False"]
CollateFunction["collate_with_scaler()\nBatch construction"]
end
subgraph "PyTorch Lightning Training Loop"
DataLoader["DataLoader\nbatch_size from hparams\nnum_workers=2\npin_memory=True\npersistent_workers=True"]
Model["Stage1Autoencoder\nEncoder → Latent → Decoder"]
Optimizer["Adam Optimizer\n+ CosineAnnealingWarmRestarts\nReduceLROnPlateau"]
Trainer["pl.Trainer\nEarlyStopping\nModelCheckpoint\ngradient_clip_val"]
end
subgraph "Monitoring & Persistence"
TensorBoard["TensorBoardLogger\ntrain_loss\nval_loss_epoch\nlearning_rate"]
Checkpoints["Model Checkpoints\n.ckpt files\nsave_top_k=1\nmonitor='val_loss_epoch'"]
end
DataStoreWsClient --> UsGaapStore
UsGaapStore --> IterableDataset
IterableDataset --> CollateFunction
CollateFunction --> DataLoader
DataLoader --> Model
Model --> Optimizer
Optimizer --> Trainer
Trainer --> TensorBoard
Trainer --> Checkpoints
Checkpoints -.->|Resume training| Model
```
Sources: python/narrative_stack/notebooks/old.stage1_training_(no_pre_dedupe).ipynb456-556
Stage1Autoencoder Model
Model Architecture
The Stage1Autoencoder is a fully-connected autoencoder that learns to compress financial concept embeddings combined with their scaled values into a lower-dimensional latent space. The model reconstructs its input, forcing the latent representation to capture the most important features of the US GAAP taxonomy.
```mermaid
graph LR
Input["Input Tensor\n[embedding + scaled_value]\nDimension: embedding_dim + 1"]
Encoder["Encoder Network\nfc1 → dropout → ReLU\nfc2 → dropout → ReLU"]
Latent["Latent Space\nDimension: latent_dim"]
Decoder["Decoder Network\nfc3 → dropout → ReLU\nfc4 → dropout → output"]
Output["Reconstructed Input\nSame dimension as input"]
Loss["MSE Loss\ninput vs output"]
Input --> Encoder
Encoder --> Latent
Latent --> Decoder
Decoder --> Output
Output --> Loss
Input --> Loss
```
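The architecture above can be sketched in plain `torch.nn` (omitting the Lightning training wrapper). The layer grouping follows the diagram's fc1-fc4 naming; `hidden_dim` and the default values are illustrative assumptions, not taken from the notebook:

```python
import torch
from torch import nn


class Stage1Autoencoder(nn.Module):
    """Minimal sketch of the fully connected autoencoder described above.

    Layer comments (fc1-fc4) follow the diagram; hidden_dim is an
    assumption, not a value from the source notebook.
    """

    def __init__(self, input_dim: int, latent_dim: int = 64,
                 hidden_dim: int = 256, dropout_rate: float = 0.1):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),   # fc1
            nn.Dropout(dropout_rate),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),  # fc2
            nn.Dropout(dropout_rate),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),  # fc3
            nn.Dropout(dropout_rate),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),   # fc4 (no activation on output)
        )
        self.loss_fn = nn.MSELoss()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.encoder(x))

    def reconstruction_loss(self, x: torch.Tensor) -> torch.Tensor:
        # Reconstruction objective: MSE between input and output.
        return self.loss_fn(self.forward(x), x)
```

Because the model reconstructs its own input, the target tensor is simply the input tensor, which is why the dataset yields identical `x` and `y` vectors.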
Hyperparameters
The model exposes the following configurable hyperparameters through its hparams attribute:
| Parameter | Description | Typical value |
|---|---|---|
| `input_dim` | Input dimension (embedding dimension + 1 for the scaled value) | Varies with embedding size |
| `latent_dim` | Dimension of the compressed latent space | 64-128 |
| `dropout_rate` | Dropout probability for regularization | 0.0-0.2 |
| `lr` | Initial learning rate | 1e-5 to 5e-5 |
| `min_lr` | Minimum learning rate for the schedulers | 1e-7 to 1e-6 |
| `batch_size` | Training batch size | 32 |
| `weight_decay` | L2 regularization coefficient | 1e-8 to 1e-4 |
| `gradient_clip` | Maximum gradient norm | 0.0-1.0 |
Sources: python/narrative_stack/notebooks/old.stage1_training_(no_pre_dedupe).ipynb479-490
Loss Function and Optimization
The model uses Mean Squared Error (MSE) loss between the input and reconstructed output. The optimization strategy combines:
- Adam optimizer with configurable learning rate and weight decay. python/narrative_stack/notebooks/old.stage1_training_(no_pre_dedupe).ipynb479-490
- CosineAnnealingWarmRestarts scheduler for cyclical learning rate annealing.
- ReduceLROnPlateau for adaptive learning rate reduction when validation loss plateaus.
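A sketch of this optimization setup using only `torch.optim`; the scheduler arguments (`T_0`, `factor`, `patience`) are illustrative, since the notebook's exact values are not reproduced here:

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ReduceLROnPlateau

model = nn.Linear(10, 10)  # stand-in for Stage1Autoencoder

# Adam with configurable learning rate and weight decay.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5, weight_decay=1e-6)

# Cyclical annealing: cosine decay that restarts every T_0 epochs.
cosine = CosineAnnealingWarmRestarts(optimizer, T_0=10, eta_min=1e-7)

# Adaptive reduction when the monitored validation loss plateaus.
plateau = ReduceLROnPlateau(optimizer, mode="min", factor=0.5,
                            patience=5, min_lr=1e-7)

# Per-epoch usage (val_loss is a placeholder value here):
val_loss = 0.123
cosine.step()
plateau.step(val_loss)
```

In a LightningModule these would normally be returned from `configure_optimizers`; the step calls above show the per-epoch mechanics.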
Dataset and Data Loading
IterableConceptValueDataset
The IterableConceptValueDataset is a custom PyTorch IterableDataset that streams training data from the UsGaapStore without loading the entire dataset into memory.
Key characteristics:
```mermaid
graph TB
subgraph "Dataset Initialization"
Config["simd_r_drive_server_config\nhost + port"]
Params["Dataset Parameters\ninternal_batch_size\nreturn_scaler\nshuffle"]
end
subgraph "Data Streaming Process"
Store["UsGaapStore instance\nget_triplet_count()"]
IndexGen["Index Generator\nSequential or shuffled\nbased on shuffle param"]
BatchFetch["Internal Batching\nFetch internal_batch_size items\nvia batch_lookup_by_indices()"]
Unpack["Unpack Triplet Data\nembedding\nscaled_value\nscaler (optional)"]
end
subgraph "Output"
Tensor["PyTorch Tensors\nx: [embedding + scaled_value]\ny: [embedding + scaled_value]\nscaler: RobustScaler object"]
end
Config --> Store
Params --> IndexGen
Store --> IndexGen
IndexGen --> BatchFetch
BatchFetch --> Unpack
Unpack --> Tensor
```
- Iterable streaming: data is fetched on demand during iteration. python/narrative_stack/notebooks/old.stage1_training_(no_pre_dedupe).ipynb502-522
- Internal batching: fetches `internal_batch_size` items (typically 64) per request from the WebSocket server to reduce network overhead. python/narrative_stack/notebooks/old.stage1_training_(no_pre_dedupe).ipynb175-176
- Optional shuffling: randomizes index order for training, or keeps sequential order for validation.
DataLoader and Collation
The collate_with_scaler function handles batch construction when the dataset returns triplets (x, y, scaler). It stacks the tensors into batches while preserving the scaler objects in a list. python/narrative_stack/notebooks/old.stage1_training_(no_pre_dedupe).ipynb506-507
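A plausible shape for this collate function, assuming the dataset yields `(x, y, scaler)` triplets as described; the exact signature in the notebook may differ:

```python
import torch


def collate_with_scaler(batch):
    """Stack (x, y, scaler) triplets into batch tensors while keeping the
    scaler objects in a plain Python list (they are not tensors and
    cannot be stacked)."""
    xs, ys, scalers = zip(*batch)
    return torch.stack(xs), torch.stack(ys), list(scalers)
```

It is wired into the `DataLoader` via `collate_fn=collate_with_scaler`.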
| Parameter | Value | Purpose |
|---|---|---|
| `batch_size` | `model.hparams.batch_size` | Outer batch size for model training |
| `num_workers` | 2 | Parallel data-loading worker processes |
| `pin_memory` | `True` | Faster host-to-GPU transfers |
| `persistent_workers` | `True` | Keeps worker processes alive between epochs |
Sources: python/narrative_stack/notebooks/old.stage1_training_(no_pre_dedupe).ipynb175-176 python/narrative_stack/notebooks/old.stage1_training_(no_pre_dedupe).ipynb502-522
Training Configuration
PyTorch Lightning Trainer Setup
The training pipeline uses PyTorch Lightning’s Trainer class to orchestrate the training loop, validation, and callbacks. python/narrative_stack/notebooks/old.stage1_training_(no_pre_dedupe).ipynb468-548
Callbacks and Persistence
- EarlyStopping: monitors `val_loss_epoch` and stops training if no improvement occurs for 20 consecutive epochs. python/narrative_stack/notebooks/old.stage1_training_(no_pre_dedupe).ipynb528-530
- ModelCheckpoint: saves the best model weights, based on validation loss, to `OUTPUT_PATH`. python/narrative_stack/notebooks/old.stage1_training_(no_pre_dedupe).ipynb532-539
- TensorBoardLogger: automatically logs `train_loss`, `val_loss_epoch`, and `learning_rate`. python/narrative_stack/notebooks/old.stage1_training_(no_pre_dedupe).ipynb543
Sources: python/narrative_stack/notebooks/old.stage1_training_(no_pre_dedupe).ipynb528-548
Checkpointing and Resuming Training
The pipeline supports resuming training from a .ckpt file. This is handled by passing ckpt_path to trainer.fit(). python/narrative_stack/notebooks/old.stage1_training_(no_pre_dedupe).ipynb555
Alternatively, a model can be loaded for fine-tuning with modified hyperparameters:
python/narrative_stack/notebooks/old.stage1_training_(no_pre_dedupe).ipynb479-486
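Lightning's `load_from_checkpoint` accepts keyword arguments that override the hyperparameters stored in the checkpoint. A torch-only sketch of the equivalent mechanics (the model, paths, and values are illustrative stand-ins):

```python
import io

import torch
from torch import nn

# Stand-in model; the real pipeline uses Stage1Autoencoder.
model = nn.Sequential(nn.Linear(9, 4), nn.ReLU(), nn.Linear(4, 9))

# Save weights plus hyperparameters, roughly what a Lightning .ckpt holds.
buf = io.BytesIO()
torch.save({"state_dict": model.state_dict(),
            "hyper_parameters": {"lr": 2e-5, "dropout_rate": 0.1}}, buf)
buf.seek(0)

# Reload for fine-tuning, overriding the stored learning rate.
ckpt = torch.load(buf)
model.load_state_dict(ckpt["state_dict"])
new_lr = 5e-6  # fine-tuning override, replacing the stored 2e-5
optimizer = torch.optim.Adam(model.parameters(), lr=new_lr)
```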
Integration with Rust Caches
While the training occurs in Python, the underlying data is often derived from the Rust preprocessor. The Caches struct in the Rust application manages the preprocessor_cache.bin and http_storage_cache.bin src/caches.rs:11-14 which provide the raw data that the Python UsGaapStore eventually consumes. The Caches::open function src/caches.rs:29-51 ensures these data stores are correctly initialized on disk before the training pipeline attempts to access them via the WebSocket bridge.
Sources: src/caches.rs:11-60 python/narrative_stack/notebooks/old.stage1_training_(no_pre_dedupe).ipynb456-556