Methods
Deep learning method implementations that wrap model architectures with training logic.
This section contains method classes that provide a unified interface for training, validation, and prediction across all deep learning models in TALENT. Each method class inherits from the base Method class and implements model-specific logic while maintaining consistent APIs.
Base Method Class (method/base.py)
All method implementations inherit from the base Method class which provides the core training, validation, and prediction workflow. Understanding this base class is essential for using any method in TALENT.
- class Method
Abstract base class that provides a unified interface for all deep learning methods in TALENT.
Key Features:
Consistent training/validation/prediction workflow across all models
Automatic data preprocessing and formatting with multiple encoding options
Model construction and optimization setup with configurable parameters
Early stopping and checkpoint management for robust training
Comprehensive evaluation metrics for regression and classification tasks
Flexible handling of numerical and categorical features
Core Base Methods
- __init__(args, is_regression)
Initialize the method with configuration and task type.
Parameters:
args (argparse.Namespace) – Configuration arguments containing model, training, and data processing settings
is_regression (bool) – Whether the task is regression (True) or classification (False)
Initialization Process:
Configuration Setup: Store arguments and determine task type
Statistics Reset: Initialize training counters and timers
Logging Setup: Create training log dictionary with best performance tracking
Device Setup: Configure GPU/CPU device for training
- construct_model(model_config=None)
Abstract method to construct the specific model architecture. Must be implemented by each method subclass.
Parameters:
model_config (dict, optional) – Model-specific configuration parameters. If None, uses args.config['model']
Implementation Notes:
Each method class overrides this to create its specific model type
Model is moved to appropriate device (GPU/CPU) and set to correct precision (float/double)
Configuration parameters are model-specific (e.g., hidden dimensions, number of layers, etc.)
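The override pattern described in these notes can be sketched without any deep learning dependency. `BaseMethodSketch`, `TinyMLPMethod`, and the `d_in`/`d_hidden` keys are hypothetical stand-ins, not TALENT classes or options; only the fallback to `args.config['model']` is taken from the parameter description above. A plain dict stands in for the network object, so device placement and precision are not shown.

```python
from abc import ABC, abstractmethod
from argparse import Namespace


class BaseMethodSketch(ABC):
    """Hypothetical stand-in for the base Method class (not TALENT code)."""

    def __init__(self, args, is_regression):
        self.args = args
        self.is_regression = is_regression
        self.model = None

    @abstractmethod
    def construct_model(self, model_config=None):
        """Each subclass builds its own architecture here."""


class TinyMLPMethod(BaseMethodSketch):
    """Hypothetical subclass that only overrides construct_model()."""

    def construct_model(self, model_config=None):
        # Fall back to the stored configuration when no override is given.
        if model_config is None:
            model_config = self.args.config["model"]
        # A plain dict stands in for the real network object.
        self.model = {"type": "mlp", **model_config}
        return self.model


args = Namespace(config={"model": {"d_in": 8, "d_hidden": 64}})
method = TinyMLPMethod(args, is_regression=True)
default_model = method.construct_model()
tuned_model = method.construct_model({"d_in": 8, "d_hidden": 128})
```

Passing an explicit `model_config` mirrors how a hyperparameter search could rebuild the model with different settings while leaving `args` untouched.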
- data_format(is_train=True, N=None, C=None, y=None)
Format and preprocess data for training or inference. This is the core data processing pipeline.
Parameters:
is_train (bool) – Whether formatting for training (True) or inference (False)
N (dict, optional) – Numerical features dictionary with train/val/test splits
C (dict, optional) – Categorical features dictionary with train/val/test splits
y (dict, optional) – Target labels dictionary with train/val/test splits
Training Mode Processing Pipeline:
NaN Handling:
self.N, self.C, self.num_new_value, self.imputer, self.cat_new_value = data_nan_process(self.N, self.C, self.args.num_nan_policy, self.args.cat_nan_policy)
Label Processing:
self.y, self.y_info, self.label_encoder = data_label_process(self.y, self.is_regression)
Numerical Encoding: Apply binning, quantile transformation, or other numerical policies
Categorical Encoding: Apply ordinal, one-hot, or other categorical encoding strategies
Normalization: Apply standardization, min-max scaling, or quantile normalization
DataLoader Creation: Create PyTorch DataLoaders for training and validation
Inference Mode Processing:
Uses previously fitted transformers (encoders, normalizers) to process test data consistently.
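The difference between the two modes comes down to fitting statistics on the training split and reusing them at inference time. The sketch below illustrates this with a hypothetical `StandardizerSketch` class in pure Python; it is not TALENT's normalizer, and the toy `N` dictionary only imitates the train/test split layout described above.

```python
import math


class StandardizerSketch:
    """Hypothetical per-column standardizer; fitted on the training split only."""

    def fit(self, rows):
        n = len(rows)
        n_cols = len(rows[0])
        self.mean = [sum(r[j] for r in rows) / n for j in range(n_cols)]
        self.std = []
        for j in range(n_cols):
            var = sum((r[j] - self.mean[j]) ** 2 for r in rows) / n
            # Guard against constant columns to avoid division by zero.
            self.std.append(math.sqrt(var) or 1.0)
        return self

    def transform(self, rows):
        return [
            [(v - m) / s for v, m, s in zip(r, self.mean, self.std)]
            for r in rows
        ]


N = {"train": [[1.0, 10.0], [3.0, 10.0]], "test": [[5.0, 10.0]]}

# Training mode: fit statistics on the training split.
scaler = StandardizerSketch().fit(N["train"])
train_scaled = scaler.transform(N["train"])

# Inference mode: reuse the fitted statistics, never refit on test data.
test_scaled = scaler.transform(N["test"])
```

Refitting on the test split would map the single test row to zero and hide the fact that it lies three training standard deviations above the training mean.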
- fit(data, info, train=True, config=None)
Main training method that orchestrates the entire training process.
Parameters:
data (tuple) – (N, C, y) containing numerical features, categorical features, and labels
info (dict) – Dataset information including feature names, types, and metadata
train (bool) – Whether to actually train the model (False for loading checkpoints only)
config (dict, optional) – Override configuration for hyperparameter tuning
Returns:
float – Total training time in seconds
Training Process:
Data Setup: Create Dataset object and extract feature information
Data Processing: Call data_format() to preprocess all data
Model Construction: Call construct_model() to build the neural network
Optimizer Setup: Initialize AdamW optimizer with configured learning rate and weight decay
Training Loop: For each epoch, call train_epoch() and validate()
Early Stopping: Stop training if validation performance doesn’t improve for 20 epochs
Checkpoint Saving: Save best and final model weights
- train_epoch(epoch)
Train the model for one epoch using the training data.
Parameters:
epoch (int) – Current epoch number for logging
Training Steps per Batch:
Feature Extraction: Handle numerical and categorical features appropriately
Forward Pass: Compute model predictions
Loss Computation: Calculate training loss using appropriate criterion
Backward Pass: Compute gradients via backpropagation
Parameter Update: Apply optimizer step
Progress Logging: Display training progress every 50 batches
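The per-batch steps can be illustrated with a framework-free stand-in. The sketch below assumes a linear model with squared-error loss and a plain gradient descent update in place of the neural network and AdamW optimizer that the base class uses; `train_epoch_sketch` and its arguments are hypothetical names, and feature extraction is reduced to reading rows from each batch.

```python
def train_epoch_sketch(weights, batches, lr=0.1, log_every=50):
    """One epoch of gradient descent on a linear model (illustrative stand-in)."""
    total_loss = 0.0
    for i, (X, y) in enumerate(batches, start=1):
        # Forward pass: one prediction per row.
        preds = [sum(w * x for w, x in zip(weights, row)) for row in X]
        # Loss computation: mean squared error over the batch.
        errors = [p - t for p, t in zip(preds, y)]
        loss = sum(e * e for e in errors) / len(y)
        # Backward pass: gradient of the loss with respect to each weight.
        grads = [
            2.0 * sum(e * row[j] for e, row in zip(errors, X)) / len(y)
            for j in range(len(weights))
        ]
        # Parameter update: plain gradient descent step.
        weights = [w - lr * g for w, g in zip(weights, grads)]
        total_loss += loss
        # Progress logging at a fixed batch interval.
        if i % log_every == 0:
            print(f"batch {i}: running loss {total_loss / i:.4f}")
    return weights, total_loss / len(batches)


# Toy data generated by y = 2 * x, split into single-row batches.
batches = [([[x / 10.0]], [2.0 * x / 10.0]) for x in range(1, 11)] * 20
weights, first_loss = train_epoch_sketch([0.0], batches)
_, second_loss = train_epoch_sketch(weights, batches)
```

The second call starts from the weights returned by the first, which is the role the model's persistent parameters play across epochs.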
- validate(epoch)
Validate the model on the validation set and handle early stopping.
Parameters:
epoch (int) – Current epoch number
Validation Process:
Set Evaluation Mode: model.eval() to disable dropout and batch norm updates
Inference Loop: Process validation batches without gradients
Metric Computation: Calculate validation metrics using metric() method
Best Model Tracking: Save model checkpoint if validation performance improved
Early Stopping Logic: Increment counter if no improvement; stop after 20 epochs
Logging: Record validation results and save training log
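The best-model tracking and early stopping logic can be sketched as a small loop over validation scores. The function name and the `higher_is_better` flag are assumptions made for illustration; the patience of 20 epochs is the value stated above.

```python
def run_validation_sketch(val_scores, patience=20, higher_is_better=True):
    """Track the best epoch and stop after `patience` epochs without improvement."""
    best_score, best_epoch, bad_epochs = None, -1, 0
    for epoch, score in enumerate(val_scores):
        improved = best_score is None or (
            score > best_score if higher_is_better else score < best_score
        )
        if improved:
            # This is where a 'best-val' checkpoint would be written.
            best_score, best_epoch, bad_epochs = score, epoch, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return best_epoch, best_score, epoch  # stopped early
    return best_epoch, best_score, len(val_scores) - 1


# Accuracy improves for three epochs, then plateaus.
scores = [0.70, 0.75, 0.80] + [0.79] * 30
best_epoch, best_score, last_epoch = run_validation_sketch(scores)
```

With this input the best epoch is 2 and the loop stops at epoch 22, after 20 consecutive epochs without improvement.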
- predict(data, info, model_name)
Make predictions on test data using a trained model.
Parameters:
data (tuple) – (N, C, y) test data
info (dict) – Dataset information
model_name (str) – Model checkpoint name ('best-val' or 'epoch-last')
Returns:
tuple – (loss, metrics, metric_names, predictions)
Prediction Process:
Model Loading: Load trained weights from checkpoint
Data Processing: Apply fitted transformers to test data
Inference: Generate predictions in evaluation mode
Metric Calculation: Compute comprehensive evaluation metrics
- metric(predictions, labels, y_info)
Compute comprehensive evaluation metrics based on task type.
Parameters:
predictions (np.ndarray) – Model predictions
labels (np.ndarray) – Ground truth labels
y_info (dict) – Label processing information including classes and normalization details
Returns:
tuple – (metrics_values, metric_names)
Task-Specific Metrics:
Regression Tasks:
MAE: Mean Absolute Error - \(\frac{1}{n}\sum_{i=1}^{n}|y_i - \hat{y}_i|\)
R²: Coefficient of determination - \(1 - \frac{SS_{res}}{SS_{tot}}\)
RMSE: Root Mean Squared Error - \(\sqrt{\frac{1}{n}\sum_{i=1}^{n}(y_i - \hat{y}_i)^2}\)
Binary Classification:
Accuracy: Overall classification accuracy
Balanced Recall: Balanced accuracy score (mean of per-class recall), which accounts for class imbalance
Macro Precision: Macro-averaged precision across classes
F1 Score: Binary F1 score - \(2 \cdot \frac{precision \cdot recall}{precision + recall}\)
Log Loss: Cross-entropy loss
AUC: Area under ROC curve for probability predictions
Multi-class Classification:
Accuracy: Overall classification accuracy
Balanced Recall: Balanced accuracy score
Macro Precision: Macro-averaged precision
Macro F1: Macro-averaged F1 score
Log Loss: Cross-entropy loss
Macro AUC: One-vs-Rest AUC score
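As a worked illustration of the formulas above, the following sketch computes the three regression metrics and balanced recall in pure Python. The function names are hypothetical and the return shape only imitates the (values, names) tuple described for metric(); it is not the library's implementation.

```python
import math


def regression_metrics(y_true, y_pred):
    """Return (values, names) for MAE, R2, and RMSE."""
    n = len(y_true)
    mae = sum(abs(t - p) for t, p in zip(y_true, y_pred)) / n
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    mean_true = sum(y_true) / n
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    r2 = 1.0 - ss_res / ss_tot
    rmse = math.sqrt(ss_res / n)
    return (mae, r2, rmse), ("MAE", "R2", "RMSE")


def balanced_recall(y_true, y_pred):
    """Mean of per-class recall, i.e. balanced accuracy."""
    recalls = []
    for cls in sorted(set(y_true)):
        idx = [i for i, t in enumerate(y_true) if t == cls]
        recalls.append(sum(y_pred[i] == cls for i in idx) / len(idx))
    return sum(recalls) / len(recalls)


values, names = regression_metrics([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 6.0])
bal = balanced_recall([0, 0, 0, 1], [0, 0, 0, 0])
```

In the classification example, always predicting class 0 gives a plain accuracy of 0.75 but a balanced recall of 0.5, which is why the balanced score is reported alongside accuracy.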
Utility Functions
- check_softmax(logits)
Ensure logits are properly normalized probabilities for classification tasks.
Parameters:
logits (np.ndarray) – Array of shape (N, C) with raw logits or probabilities
Returns:
np.ndarray – Properly normalized probabilities summing to 1
Mathematical Process:
Probability Check: Verify if values are in [0,1] and sum to 1 per sample
Softmax Application: If not probabilities, apply numerically stable softmax:
\[p_i = \frac{\exp(x_i - \max(x))}{\sum_{j} \exp(x_j - \max(x))}\]
Numerical Stability: Subtract max before exp to prevent overflow
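A minimal pure-Python sketch of the two steps follows. The name `check_softmax_sketch` and the tolerance `tol` are assumptions for illustration; the library function operates on NumPy arrays and may apply its probability check differently.

```python
import math


def check_softmax_sketch(rows, tol=1e-6):
    """Return row-wise probabilities, applying a stable softmax if needed."""
    is_prob = all(
        all(0.0 <= v <= 1.0 for v in row) and abs(sum(row) - 1.0) <= tol
        for row in rows
    )
    if is_prob:
        return [list(row) for row in rows]
    out = []
    for row in rows:
        m = max(row)  # subtracting the max keeps exp() from overflowing
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


probs = check_softmax_sketch([[1000.0, 1000.0], [0.0, math.log(3.0)]])
unchanged = check_softmax_sketch([[0.25, 0.75]])
```

Calling `math.exp(1000.0)` directly raises OverflowError, while the shifted version evaluates `exp(0.0)` for the largest entry of each row.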
Method Implementations
Methods are organized by complexity. Simple methods primarily use base class functionality, while advanced methods implement specialized training procedures.
Simple Methods (Using Base Class Functions)
These methods only override construct_model() and use all base class functionality for training, validation, and prediction.
- class MLPMethod
Multi-Layer Perceptron method - the simplest and most widely applicable method.
Features: Fast training, good baseline performance, suitable for most tabular tasks
Requirements: cat_policy cannot be ‘indices’ (categorical features must be encoded)
Usage: Ideal for beginners or when you need fast training with decent performance
- class ResNetMethod
Residual Network method with skip connections for deeper networks.
Features: Prevents gradient vanishing, supports various activations (ReLU, GELU, ReGLU, GeGLU)
Usage: When you need deeper networks than MLP without gradient problems
- class SNNMethod
Self-Normalizing Network method using SELU activation.
Features: Automatic normalization properties, suitable for deeper networks
Usage: Alternative to ResNet when you want self-normalizing properties
- class NodeMethod
Neural Oblivious Decision Ensembles method implementing neural decision trees.
Features: Tree-like decision making with neural network flexibility
Usage: When you want tree-like interpretability with neural network power
- class GrowNetMethod
Gradient boosting with neural network weak learners.
Features: Combines gradient boosting with neural networks
Usage: For ensemble-based approaches with neural components
- class GrandeMethod
Gradient-based decision tree ensembles (GRANDE) that mimic tree behavior while being trained with gradient descent.
Features: Neural implementation of decision tree ensembles
Usage: When you want tree ensemble performance with neural network flexibility
Transformer-Based Methods (Using Base Class Functions)
These transformer methods use standard base class training but require specific categorical policies.
- class FTTMethod
Feature Tokenizer Transformer - one of the best performing methods.
Features: Feature tokenization, multi-head attention, state-of-the-art performance
Requirements: cat_policy must be ‘indices’ (uses raw categorical indices)
Usage: First choice for best performance on most datasets
- class SaintMethod
Self-Attention and Intersample Attention Transformer.
Features: Row and column attention, enhanced feature interactions
Usage: For complex datasets where feature interactions are important
- class TabTransformerMethod
Transformer with column-wise attention for categorical features.
Features: Contextual embeddings, strong on categorical-heavy datasets
Usage: When your dataset has many important categorical features
Advanced Methods with Custom Data Processing
These methods override data_format() and implement specialized data handling.
- class TabNetMethod
TabNet with sequential attention and custom data processing pipeline.
Features:
Interpretable sequential feature selection
Custom TabNet-specific data processing
Sparse feature selection with attention visualization
Requirements: cat_policy cannot be ‘indices’ (requires encoded categorical features)
Special Data Processing:
TabNet bypasses the standard data_format() method and implements its own data processing:
- data_format(is_train=True, N=None, C=None, y=None)
Custom data processing optimized for TabNet’s requirements.
Key Differences from Base:
Direct integration with TabNet’s internal categorical handling
Specialized preprocessing for TabNet’s attention mechanism
Custom DataLoader creation for TabNet-compatible data structures
Usage: When you need interpretable predictions with attention visualization
Methods with Custom Training Procedures
These methods implement specialized training workflows by overriding training-related methods.
- class TabRMethod
TabR method implementing KNN-attention hybrid with retrieval-based training.
Features:
Combines KNN retrieval with attention mechanisms
Context-aware predictions using training set as retrieval candidates
Custom fit method with retrieval context management
Requirements:
cat_policy must be ‘tabr_ohe’
num_policy must be ‘none’
Custom Methods:
- fit(data, info, train=True, config=None)
Enhanced fit method with retrieval-based training setup.
Special Features:
Context Management: Maintains context_size=96 for efficient retrieval
Candidate Selection: Uses full training set as retrieval candidates
Index Tracking: Manages training indices for neighbor search
Retrieval Process:
Setup Retrieval Context: Initialize training indices and context size limits
Model Construction: Build TabR with both attention and KNN components
Candidate Management: Store training data for runtime retrieval
Training Loop: Standard epoch-based training with retrieval context
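The candidate selection step can be pictured as a nearest-neighbor lookup over the stored training rows. The sketch below is a simplification under stated assumptions: it uses squared Euclidean distance on raw features rather than a learned representation, and it excludes the query row from its own candidates. `retrieve_context` is a hypothetical helper, not part of TabRMethod.

```python
def retrieve_context(query_idx, train_X, context_size=96):
    """Return indices of the nearest training rows, excluding the query itself."""
    query = train_X[query_idx]
    distances = []
    for i, row in enumerate(train_X):
        if i == query_idx:
            continue  # a training row must not retrieve itself
        dist = sum((a - b) ** 2 for a, b in zip(query, row))
        distances.append((dist, i))
    distances.sort()
    return [i for _, i in distances[:context_size]]


train_X = [[0.0], [1.0], [2.0], [10.0]]
neighbors = retrieve_context(query_idx=1, train_X=train_X, context_size=2)
```

Tracking training indices, as the retrieval process above describes, is what makes the self-exclusion check possible.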
Usage: Excellent performance on many datasets, especially with clear patterns
- class ExcelFormerMethod
ExcelFormer with semi-permeable attention and mixup training strategies.
Features:
Multiple mixup strategies (feat_mix, hidden_mix, naive_mix)
Mutual information-based feature importance scoring
Custom training process with enhanced data augmentation
Custom Methods:
- fit(data, info, train=True, config=None)
Enhanced fit with mutual information preprocessing.
Preprocessing Steps:
MI Score Computation: Calculate mutual information between features and targets
Feature Ranking: Sort features by importance for mixup weighting
Mixup Configuration: Set up augmentation parameters based on MI scores
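The scoring and ranking steps can be illustrated for discrete features with a pure-Python mutual information estimate. This is a sketch under assumptions: the helper names are hypothetical, the scores are scaled to sum to 1 so they can act as mixing weights, and the estimator actually used by the method (for example for continuous features) is not shown here.

```python
import math
from collections import Counter


def mutual_information(feature, labels):
    """Mutual information (in nats) between two discrete sequences."""
    n = len(labels)
    joint = Counter(zip(feature, labels))
    p_f = Counter(feature)
    p_y = Counter(labels)
    mi = 0.0
    for (f, y), count in joint.items():
        p_joint = count / n
        mi += p_joint * math.log(p_joint * n * n / (p_f[f] * p_y[y]))
    return mi


def normalized_mi_scores(columns, labels):
    """Per-feature MI scores scaled to sum to 1 (assumed normalization)."""
    raw = [mutual_information(col, labels) for col in columns]
    total = sum(raw)
    return [r / total for r in raw]


labels = [0, 0, 1, 1]
columns = [
    [0, 0, 1, 1],  # perfectly informative feature
    [0, 1, 0, 1],  # feature independent of the label
]
mi_scores = normalized_mi_scores(columns, labels)
ranking = sorted(range(len(columns)), key=lambda j: mi_scores[j], reverse=True)
```

Here the informative feature receives the entire weight and is ranked first, while the independent feature scores zero.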
- train_epoch(epoch)
Custom training with mixup augmentation strategies.
Mixup Training Process:
Feature Mixup (`mix_type='feat_mix'`):
\[\begin{split}\lambda = \sum (\text{MI_scores} \odot \text{feat_masks}) \\ \text{loss} = \lambda \cdot \text{criterion}(\text{pred}, y) + (1-\lambda) \cdot \text{criterion}(\text{pred}, y[\text{shuffled}])\end{split}\]
Hidden Mixup (`mix_type='hidden_mix'`):
\[\text{loss} = \text{feat_masks} \cdot \text{criterion}(\text{pred}, y) + (1-\text{feat_masks}) \cdot \text{criterion}(\text{pred}, y[\text{shuffled}])\]
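A small numeric example may help with the feature mixup formula. The sketch assumes that the MI scores are normalized to sum to 1, so that lambda lies in [0, 1], and it takes the two criterion values as plain numbers instead of computing them from a model. `feat_mix_loss` is a hypothetical helper.

```python
def feat_mix_loss(mi_scores, feat_mask, loss_orig, loss_shuffled):
    """Blend two loss values using the MI mass of the kept features."""
    lam = sum(s * m for s, m in zip(mi_scores, feat_mask))
    return lam * loss_orig + (1.0 - lam) * loss_shuffled, lam


# Three features whose normalized MI scores sum to 1 (assumed).
mi_scores = [0.5, 0.3, 0.2]
# 1 keeps the feature from the original sample, 0 takes it from the shuffled one.
feat_mask = [1, 0, 1]
loss, lam = feat_mix_loss(mi_scores, feat_mask, loss_orig=0.4, loss_shuffled=1.0)
```

Because the kept features carry 0.7 of the MI mass, the original target contributes 70% of the loss and the shuffled target the remaining 30%.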
Usage: When you need enhanced generalization through sophisticated data augmentation
- class TromptMethod
Trompt method with prompt-based learning and multiple prediction cycles.
Features:
Prompt-based neural architecture separating intrinsic and sample-specific features
Multiple learning cycles for improved performance
Custom training with repeated targets
Requirements: cat_policy must be ‘indices’
Custom Training:
- train_epoch(epoch)
Custom training with prompt-based multi-cycle learning.
Prompt Learning Process:
Multi-cycle Forward: Uses model.forward_for_training() for multiple prediction cycles
Target Repetition: Repeats targets for each prediction cycle
Cycle Loss: Computes loss across all cycles for robust learning
Mathematical Formulation:
For n_cycles prediction cycles:
\[\begin{split}\text{output} = \text{model.forward_for_training}(X_{num}, X_{cat}) \\ \text{output} = \text{output.view}(-1, d_{out}) \\ y_{repeated} = y.\text{repeat_interleave}(n_{cycles})\end{split}\]
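The reshaping and target repetition above only line up if each sample's cycles are stored contiguously. The sketch below checks that alignment with nested lists, assuming an output layout of (batch, n_cycles, d_out); the helper functions are pure-Python stand-ins for the tensor operations `view` and `repeat_interleave`.

```python
def repeat_interleave(values, repeats):
    """Repeat each element `repeats` times, keeping copies adjacent."""
    return [v for v in values for _ in range(repeats)]


def flatten_cycles(output):
    """Reshape (batch, n_cycles, d_out) nested lists to (batch * n_cycles, d_out)."""
    return [cycle for sample in output for cycle in sample]


n_cycles = 3
y = [0, 1]
# Hypothetical model output: 2 samples, 3 cycles, 2 logits each.
output = [
    [[2.0, 0.1], [1.5, 0.2], [1.8, 0.0]],
    [[0.1, 2.2], [0.3, 1.9], [0.0, 2.5]],
]
flat_output = flatten_cycles(output)
y_repeated = repeat_interleave(y, n_cycles)
```

Row k of `flat_output` belongs to sample `k // n_cycles`, which is exactly the target stored at position k of `y_repeated`. Tiling the targets as [0, 1, 0, 1, 0, 1] instead would pair predictions with the wrong labels.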
Usage: For complex learning scenarios requiring prompt-based architectures
Methods with Specialized Architectures
- class ModernNCAMethod
Modern Neighborhood Components Analysis (ModernNCA) with embedding-based distance learning.
Features:
Embedding space optimization for distance-based classification
Neighborhood sampling for efficient training
Interpretable neighbor relationships
Usage: Excellent for datasets with clear class structure and when interpretability through neighbors is desired
- class ProtoGateMethod
Prototype-based gating method for interpretable feature selection.
Features:
Prototype-based learning with gating mechanisms
Enhanced interpretability through learned prototypes
Suitable for high-dimensional data with sparse relevant features
Usage: When you need prototype-based explanations and adaptive feature selection
Foundation Model Methods
- class TabPFNMethod
Prior-data Fitted Network (TabPFN) method using a pre-trained foundation model.
Features:
Pre-trained weights for immediate deployment
Zero-shot learning capabilities on new datasets
No gradient-based training required
Usage: For rapid deployment without training when you have limited data or time
- class TabICLMethod
In-context learning method for tabular data using foundation model approaches.
Features:
Meta-learning from diverse tabular datasets
Adaptive to new domains without fine-tuning
Foundation model approach for tabular data
Usage: When you need fast adaptation to new tabular domains
Method Usage Guidelines
For Beginners: Start with MLPMethod or ResNetMethod - they’re simple, fast, and provide good baselines.
For Best Performance: Try FTTMethod, TabNetMethod, or ModernNCAMethod - these often achieve state-of-the-art results.
For Interpretability: Use TabNetMethod (attention visualization), ProtoGateMethod (prototypes), or ModernNCAMethod (neighbors).
For Speed: Choose MLPMethod, SNNMethod, or a pre-trained foundation model such as TabPFNMethod, which needs no gradient-based training.
For Complex Features: Consider SaintMethod, TabTransformerMethod, or ExcelFormerMethod with mixup.
Common Usage Pattern:
from TALENT.model.utils import get_method
# Get method class
MethodClass = get_method('ftt') # or 'mlp', 'tabnet', etc.
# Initialize method
method = MethodClass(args, is_regression=True)
# Train the method
time_cost = method.fit(train_data, info)
# Make predictions
loss, metrics, metric_names, predictions = method.predict(test_data, info, 'best-val')
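Several methods in this section state requirements on cat_policy and num_policy, and checking them before training can save a failed run. The helper below is a hypothetical sketch, not part of TALENT: the rules are copied from the Requirements entries above, the keys 'tabr' and 'trompt' are assumed spellings of the names passed to get_method(), and the policy values 'ohe' and 'quantile' are illustrative only.

```python
# Constraints copied from the method descriptions in this section.
# 'tabr' and 'trompt' are assumed spellings; only 'ftt', 'mlp', and
# 'tabnet' appear in the usage pattern above.
REQUIREMENTS = {
    "mlp": {"cat_policy_not": "indices"},
    "tabnet": {"cat_policy_not": "indices"},
    "ftt": {"cat_policy": "indices"},
    "trompt": {"cat_policy": "indices"},
    "tabr": {"cat_policy": "tabr_ohe", "num_policy": "none"},
}


def check_policies(model_type, cat_policy, num_policy):
    """Return a list of violated requirements (empty if the setup is valid)."""
    rules = REQUIREMENTS.get(model_type, {})
    problems = []
    if "cat_policy" in rules and cat_policy != rules["cat_policy"]:
        problems.append(f"cat_policy must be '{rules['cat_policy']}'")
    if rules.get("cat_policy_not") == cat_policy:
        problems.append(f"cat_policy cannot be '{cat_policy}'")
    if "num_policy" in rules and num_policy != rules["num_policy"]:
        problems.append(f"num_policy must be '{rules['num_policy']}'")
    return problems


ok = check_policies("ftt", cat_policy="indices", num_policy="none")
bad = check_policies("tabr", cat_policy="ohe", num_policy="quantile")
```

Methods without an entry in the table pass the check unchanged, since no requirement is stated for them here.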