Methods

Deep learning method implementations that wrap model architectures with training logic.

This section contains method classes that provide a unified interface for training, validation, and prediction across all deep learning models in TALENT. Each method class inherits from the base Method class and implements model-specific logic while maintaining consistent APIs.

Base Method Class (method/base.py)

All method implementations inherit from the base Method class which provides the core training, validation, and prediction workflow. Understanding this base class is essential for using any method in TALENT.

class Method

Abstract base class that provides a unified interface for all deep learning methods in TALENT.

Key Features:

  • Consistent training/validation/prediction workflow across all models

  • Automatic data preprocessing and formatting with multiple encoding options

  • Model construction and optimization setup with configurable parameters

  • Early stopping and checkpoint management for robust training

  • Comprehensive evaluation metrics for regression and classification tasks

  • Flexible handling of numerical and categorical features

Core Base Methods

__init__(args, is_regression)

Initialize the method with configuration and task type.

Parameters:

  • args (argparse.Namespace) – Configuration arguments containing model, training, and data processing settings

  • is_regression (bool) – Whether the task is regression (True) or classification (False)

Initialization Process:

  1. Configuration Setup: Store arguments and determine task type

  2. Statistics Reset: Initialize training counters and timers

  3. Logging Setup: Create training log dictionary with best performance tracking

  4. Device Setup: Configure GPU/CPU device for training

construct_model(model_config=None)

Abstract method to construct the specific model architecture. Must be implemented by each method subclass.

Parameters:

  • model_config (dict, optional) – Model-specific configuration parameters. If None, uses args.config[‘model’]

Implementation Notes:

  • Each method class overrides this to create its specific model type

  • Model is moved to appropriate device (GPU/CPU) and set to correct precision (float/double)

  • Configuration parameters are model-specific (e.g., hidden dimensions, number of layers, etc.)

data_format(is_train=True, N=None, C=None, y=None)

Format and preprocess data for training or inference. This is the core data processing pipeline.

Parameters:

  • is_train (bool) – Whether formatting for training (True) or inference (False)

  • N (dict, optional) – Numerical features dictionary with train/val/test splits

  • C (dict, optional) – Categorical features dictionary with train/val/test splits

  • y (dict, optional) – Target labels dictionary with train/val/test splits

Training Mode Processing Pipeline:

  1. NaN Handling: .. code-block:: python

    self.N, self.C, self.num_new_value, self.imputer, self.cat_new_value =

    data_nan_process(self.N, self.C, self.args.num_nan_policy, self.args.cat_nan_policy)

  2. Label Processing: .. code-block:: python

    self.y, self.y_info, self.label_encoder =

    data_label_process(self.y, self.is_regression)

  3. Numerical Encoding: Apply binning, quantile transformation, or other numerical policies

  4. Categorical Encoding: Apply ordinal, one-hot, or other categorical encoding strategies

  5. Normalization: Apply standardization, min-max scaling, or quantile normalization

  6. DataLoader Creation: Create PyTorch DataLoaders for training and validation

Inference Mode Processing:

Uses previously fitted transformers (encoders, normalizers) to process test data consistently.

fit(data, info, train=True, config=None)

Main training method that orchestrates the entire training process.

Parameters:

  • data (tuple) – (N, C, y) containing numerical features, categorical features, and labels

  • info (dict) – Dataset information including feature names, types, and metadata

  • train (bool) – Whether to actually train the model (False for loading checkpoints only)

  • config (dict, optional) – Override configuration for hyperparameter tuning

Returns:

  • float – Total training time in seconds

Training Process:

  1. Data Setup: Create Dataset object and extract feature information

  2. Data Processing: Call data_format() to preprocess all data

  3. Model Construction: Call construct_model() to build the neural network

  4. Optimizer Setup: Initialize AdamW optimizer with configured learning rate and weight decay

  5. Training Loop: For each epoch, call train_epoch() and validate()

  6. Early Stopping: Stop training if validation performance doesn’t improve for 20 epochs

  7. Checkpoint Saving: Save best and final model weights

train_epoch(epoch)

Train the model for one epoch using the training data.

Parameters:

  • epoch (int) – Current epoch number for logging

Training Steps per Batch:

  1. Feature Extraction: Handle numerical and categorical features appropriately

  2. Forward Pass: Compute model predictions

  3. Loss Computation: Calculate training loss using appropriate criterion

  4. Backward Pass: Compute gradients via backpropagation

  5. Parameter Update: Apply optimizer step

  6. Progress Logging: Display training progress every 50 batches

validate(epoch)

Validate the model on the validation set and handle early stopping.

Parameters:

  • epoch (int) – Current epoch number

Validation Process:

  1. Set Evaluation Mode: model.eval() to disable dropout and batch norm updates

  2. Inference Loop: Process validation batches without gradients

  3. Metric Computation: Calculate validation metrics using metric() method

  4. Best Model Tracking: Save model checkpoint if validation performance improved

  5. Early Stopping Logic: Increment counter if no improvement; stop after 20 epochs

  6. Logging: Record validation results and save training log

predict(data, info, model_name)

Make predictions on test data using a trained model.

Parameters:

  • data (tuple) – (N, C, y) test data

  • info (dict) – Dataset information

  • model_name (str) – Model checkpoint name (‘best-val’ or ‘epoch-last’)

Returns:

  • tuple – (loss, metrics, metric_names, predictions)

Prediction Process:

  1. Model Loading: Load trained weights from checkpoint

  2. Data Processing: Apply fitted transformers to test data

  3. Inference: Generate predictions in evaluation mode

  4. Metric Calculation: Compute comprehensive evaluation metrics

metric(predictions, labels, y_info)

Compute comprehensive evaluation metrics based on task type.

Parameters:

  • predictions (np.ndarray) – Model predictions

  • labels (np.ndarray) – Ground truth labels

  • y_info (dict) – Label processing information including classes and normalization details

Returns:

  • tuple – (metrics_values, metric_names)

Task-Specific Metrics:

Regression Tasks:

  • MAE: Mean Absolute Error - \(\frac{1}{n}\sum_{i=1}^{n}|y_i - \hat{y}_i|\)

  • R²: Coefficient of determination - \(1 - \frac{SS_{res}}{SS_{tot}}\)

  • RMSE: Root Mean Squared Error - \(\sqrt{\frac{1}{n}\sum_{i=1}^{n}(y_i - \hat{y}_i)^2}\)

Binary Classification:

  • Accuracy: Overall classification accuracy

  • Balanced Recall: Balanced accuracy score handling class imbalance

  • Macro Precision: Macro-averaged precision across classes

  • F1 Score: Binary F1 score - \(2 \cdot \frac{precision \cdot recall}{precision + recall}\)

  • Log Loss: Cross-entropy loss

  • AUC: Area under ROC curve for probability predictions

Multi-class Classification:

  • Accuracy: Overall classification accuracy

  • Balanced Recall: Balanced accuracy score

  • Macro Precision: Macro-averaged precision

  • Macro F1: Macro-averaged F1 score

  • Log Loss: Cross-entropy loss

  • Macro AUC: One-vs-Rest AUC score

Utility Functions

check_softmax(logits)

Ensure logits are properly normalized probabilities for classification tasks.

Parameters:

  • logits (np.ndarray) – Array of shape (N, C) with raw logits or probabilities

Returns:

  • np.ndarray – Properly normalized probabilities summing to 1

Mathematical Process:

  1. Probability Check: Verify if values are in [0,1] and sum to 1 per sample

  2. Softmax Application: If not probabilities, apply numerically stable softmax:

    \[p_i = \frac{\exp(x_i - \max(x))}{\sum_{j} \exp(x_j - \max(x))}\]
  3. Numerical Stability: Subtract max before exp to prevent overflow

Method Implementations

Methods are organized by complexity. Simple methods primarily use base class functionality, while advanced methods implement specialized training procedures.

Simple Methods (Using Base Class Functions)

These methods only override construct_model() and use all base class functionality for training, validation, and prediction.

class MLPMethod

Multi-Layer Perceptron method - the simplest and most widely applicable method.

Features: Fast training, good baseline performance, suitable for most tabular tasks

Requirements: cat_policy cannot be ‘indices’ (categorical features must be encoded)

Usage: Ideal for beginners or when you need fast training with decent performance

class ResNetMethod

Residual Network method with skip connections for deeper networks.

Features: Prevents gradient vanishing, supports various activations (ReLU, GELU, ReGLU, GeGLU)

Usage: When you need deeper networks than MLP without gradient problems

class SNNMethod

Self-Normalizing Network method using SELU activation.

Features: Automatic normalization properties, suitable for deeper networks

Usage: Alternative to ResNet when you want self-normalizing properties

class NodeMethod

Neural Oblivious Decision Ensembles method implementing neural decision trees.

Features: Tree-like decision making with neural network flexibility

Usage: When you want tree-like interpretability with neural network power

class GrowNetMethod

Gradient boosting with neural network weak learners.

Features: Combines gradient boosting with neural networks

Usage: For ensemble-based approaches with neural components

class GrandeMethod

Gradient-boosted neural decision ensembles for tree-mimic behavior.

Features: Neural implementation of decision tree ensembles

Usage: When you want tree ensemble performance with neural network flexibility

Transformer-Based Methods (Using Base Class Functions)

These transformer methods use standard base class training but require specific categorical policies.

class FTTMethod

Feature Tokenizer Transformer - one of the best performing methods.

Features: Feature tokenization, multi-head attention, state-of-the-art performance

Requirements: cat_policy must be ‘indices’ (uses raw categorical indices)

Usage: First choice for best performance on most datasets

class SaintMethod

Self-Attention and Intersample Attention Transformer.

Features: Row and column attention, enhanced feature interactions

Usage: For complex datasets where feature interactions are important

class TabTransformerMethod

Transformer with column-wise attention for categorical features.

Features: Contextual embeddings, strong on categorical-heavy datasets

Usage: When your dataset has many important categorical features

Advanced Methods with Custom Data Processing

These methods override data_format() and implement specialized data handling.

class TabNetMethod

TabNet with sequential attention and custom data processing pipeline.

Features: * Interpretable sequential feature selection * Custom TabNet-specific data processing * Sparse feature selection with attention visualization

Requirements: cat_policy cannot be ‘indices’ (requires encoded categorical features)

Special Data Processing:

TabNet bypasses the standard data_format() method and implements its own data processing:

data_format(is_train=True, N=None, C=None, y=None)

Custom data processing optimized for TabNet’s requirements.

Key Differences from Base:

  • Direct integration with TabNet’s internal categorical handling

  • Specialized preprocessing for TabNet’s attention mechanism

  • Custom DataLoader creation for TabNet-compatible data structures

Usage: When you need interpretable predictions with attention visualization

Methods with Custom Training Procedures

These methods implement specialized training workflows by overriding training-related methods.

class TabRMethod

TabR method implementing KNN-attention hybrid with retrieval-based training.

Features: * Combines KNN retrieval with attention mechanisms * Context-aware predictions using training set as retrieval candidates * Custom fit method with retrieval context management

Requirements: * cat_policy must be ‘tabr_ohe’ * num_policy must be ‘none’

Custom Methods:

fit(data, info, train=True, config=None)

Enhanced fit method with retrieval-based training setup.

Special Features:

  • Context Management: Maintains context_size=96 for efficient retrieval

  • Candidate Selection: Uses full training set as retrieval candidates

  • Index Tracking: Manages training indices for neighbor search

Retrieval Process:

  1. Setup Retrieval Context: Initialize training indices and context size limits

  2. Model Construction: Build TabR with both attention and KNN components

  3. Candidate Management: Store training data for runtime retrieval

  4. Training Loop: Standard epoch-based training with retrieval context

Usage: Excellent performance on many datasets, especially with clear patterns

class ExcelFormerMethod

ExcelFormer with semi-permeable attention and mixup training strategies.

Features: * Multiple mixup strategies (feat_mix, hidden_mix, naive_mix) * Mutual information-based feature importance scoring * Custom training process with enhanced data augmentation

Custom Methods:

fit(data, info, train=True, config=None)

Enhanced fit with mutual information preprocessing.

Preprocessing Steps:

  1. MI Score Computation: Calculate mutual information between features and targets

  2. Feature Ranking: Sort features by importance for mixup weighting

  3. Mixup Configuration: Setup augmentation parameters based on MI scores

train_epoch(epoch)

Custom training with mixup augmentation strategies.

Mixup Training Process:

Feature Mixup (`mix_type=’feat_mix’`):

\[\begin{split}\lambda = \sum (\text{MI_scores} \odot \text{feat_masks}) \\ \text{loss} = \lambda \cdot \text{criterion}(\text{pred}, y) + (1-\lambda) \cdot \text{criterion}(\text{pred}, y[\text{shuffled}])\end{split}\]

Hidden Mixup (`mix_type=’hidden_mix’`):

\[\text{loss} = \text{feat_masks} \cdot \text{criterion}(\text{pred}, y) + (1-\text{feat_masks}) \cdot \text{criterion}(\text{pred}, y[\text{shuffled}])\]

Usage: When you need enhanced generalization through sophisticated data augmentation

class TromptMethod

Trompt method with prompt-based learning and multiple prediction cycles.

Features: * Prompt-based neural architecture separating intrinsic and sample-specific features * Multiple learning cycles for improved performance * Custom training with repeated targets

Requirements: cat_policy must be ‘indices’

Custom Training:

train_epoch(epoch)

Custom training with prompt-based multi-cycle learning.

Prompt Learning Process:

  1. Multi-cycle Forward: Uses model.forward_for_training() for multiple prediction cycles

  2. Target Repetition: Repeats targets for each prediction cycle

  3. Cycle Loss: Computes loss across all cycles for robust learning

Mathematical Formulation:

For n_cycles prediction cycles:

\[\begin{split}\text{output} = \text{model.forward_for_training}(X_{num}, X_{cat}) \\ \text{output} = \text{output.view}(-1, d_{out}) \\ y_{repeated} = y.\text{repeat_interleave}(n_{cycles})\end{split}\]

Usage: For complex learning scenarios requiring prompt-based architectures

Methods with Specialized Architectures

class ModernNCAMethod

Modern Nearest Class Analysis with embedding-based distance learning.

Features: * Embedding space optimization for distance-based classification * Neighborhood sampling for efficient training * Interpretable neighbor relationships

Usage: Excellent for datasets with clear class structure and when interpretability through neighbors is desired

class ProtoGateMethod

Prototype-based gating method for interpretable feature selection.

Features: * Prototype-based learning with gating mechanisms * Enhanced interpretability through learned prototypes * Suitable for high-dimensional data with sparse relevant features

Usage: When you need prototype-based explanations and adaptive feature selection

Foundation Model Methods

class TabPFNMethod

Prior-free neural network method using pre-trained foundation models.

Features: * Pre-trained weights for immediate deployment * Zero-shot learning capabilities on new datasets * No gradient-based training required

Usage: For rapid deployment without training when you have limited data or time

class TabICLMethod

In-context learning method for tabular data using foundation model approaches.

Features: * Meta-learning from diverse tabular datasets * Adaptive to new domains without fine-tuning * Foundation model approach for tabular data

Usage: When you need fast adaptation to new tabular domains

Method Usage Guidelines

For Beginners: Start with MLPMethod or ResNetMethod - they’re simple, fast, and provide good baselines.

For Best Performance: Try FTTMethod, TabNetMethod, or ModernNCAMethod - these often achieve state-of-the-art results.

For Interpretability: Use TabNetMethod (attention visualization), ProtoGateMethod (prototypes), or ModernNCAMethod (neighbors).

For Speed: Choose MLPMethod, SNNMethod, or simple foundation models like TabPFNMethod.

For Complex Features: Consider SaintMethod, TabTransformerMethod, or ExcelFormerMethod with mixup.

Common Usage Pattern:

from TALENT.model.utils import get_method

# Get method class
MethodClass = get_method('ftt')  # or 'mlp', 'tabnet', etc.

# Initialize method
method = MethodClass(args, is_regression=True)

# Train the method
time_cost = method.fit(train_data, info)

# Make predictions
loss, metrics, metric_names, predictions = method.predict(test_data, info, 'best-val')