PTaRL
A regularization-based framework that enhances prediction on tabular data by constructing a prototype-based projection space and projecting representations into it.
Functions
def run_one_epoch(model, data_loader, loss_func, model_type, config, regularize, ot_weight, diversity_weight, r_weight, diversity, optimizer=None)
Runs one training epoch with optional OT (Optimal Transport) and diversity regularization.
Parameters:
model - Neural network model.
data_loader - Data loader for training.
loss_func - Loss function.
model_type (str) - Model type (e.g., 'ot' for optimal transport).
config - Configuration dictionary.
regularize (bool) - Whether to apply regularization.
ot_weight (float) - Weight for OT loss.
diversity_weight (float) - Weight for diversity loss.
r_weight (float) - Weight for regularization loss.
diversity (bool) - Whether to apply diversity regularization.
optimizer (optional, default None) - Optimizer for training.
Returns:
float - Average loss for the epoch.
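The way the three weights enter the epoch loss can be illustrated with a minimal sketch. It assumes the per-batch objective is a weighted sum of the task loss and the OT, regularization, and diversity terms, that the model runs in OT mode, and that the flags simply switch their terms off. The helper names `combine_losses` and `average_epoch_loss` are hypothetical and not part of the library; the actual gating inside run_one_epoch may differ.

```python
def combine_losses(task_loss, ot_loss, diversity_loss, reg_loss,
                   ot_weight, diversity_weight, r_weight,
                   regularize=True, diversity=True):
    """Hypothetical helper: weighted sum of the loss terms for one batch."""
    # Assumption: the OT term is always added when the model runs in OT mode.
    total = task_loss + ot_weight * ot_loss
    if regularize:
        total += r_weight * reg_loss
    if diversity:
        total += diversity_weight * diversity_loss
    return total


def average_epoch_loss(batch_losses):
    """Average the per-batch totals into a single epoch-level float."""
    return sum(batch_losses) / len(batch_losses)


# Two illustrative batches with made-up loss values.
batch_totals = [
    combine_losses(0.75, 0.5, 0.25, 0.5,
                   ot_weight=0.5, diversity_weight=2.0, r_weight=1.0),
    combine_losses(0.5, 0.5, 0.25, 0.5,
                   ot_weight=0.5, diversity_weight=2.0, r_weight=1.0,
                   diversity=False),
]
print(average_epoch_loss(batch_totals))
```

With diversity=False the second batch omits the diversity term entirely, so diversity_weight has no effect on it.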
def run_one_epoch_val(model, data_loader, loss_func, model_type, config, is_regression)
Runs one validation epoch.
Parameters:
model - Neural network model.
data_loader - Data loader for validation.
loss_func - Loss function.
model_type (str) - Model type.
config - Configuration dictionary.
is_regression (bool) - Whether the task is regression.
Returns:
tuple - Predictions and ground truth labels.
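The returned predictions and labels still have to be turned into a validation score. The sketch below shows one plausible way to do that, assuming RMSE for regression and accuracy for classification, and assuming the tuple is ordered (predictions, labels). The function `score_predictions` is a hypothetical helper, not part of the library, and plain Python lists stand in for the arrays the real function presumably returns.

```python
import math


def score_predictions(preds, labels, is_regression):
    """Hypothetical scorer: RMSE for regression, accuracy for classification."""
    if is_regression:
        mse = sum((p - y) ** 2 for p, y in zip(preds, labels)) / len(labels)
        return math.sqrt(mse)
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return correct / len(labels)


# Made-up outputs standing in for the (predictions, labels) tuple.
rmse = score_predictions([1.0, 2.0, 5.0], [1.0, 2.0, 2.0], is_regression=True)
acc = score_predictions([0, 1, 1, 0], [0, 1, 0, 0], is_regression=False)
print(rmse, acc)
```

Note that the two metrics point in opposite directions: lower RMSE is better, higher accuracy is better, so any model selection built on them needs to know which task it is scoring.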
def fit_Ptarl(args, model, train_loader, val_loader, loss_func, model_type, config, regularize, is_regression, ot_weight, diversity_weight, r_weight, diversity, seed, save_path)
Fits a PTaRL model with training and validation.
Parameters:
args - Command line arguments.
model - Neural network model.
train_loader - Training data loader.
val_loader - Validation data loader.
loss_func - Loss function.
model_type (str) - Model type.
config - Configuration dictionary.
regularize (bool) - Whether to apply regularization.
is_regression (bool) - Whether the task is regression.
ot_weight (float) - Weight for OT loss.
diversity_weight (float) - Weight for diversity loss.
r_weight (float) - Weight for regularization loss.
diversity (bool) - Whether to apply diversity regularization.
seed (int) - Random seed.
save_path (str) - Path to save model.
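The overall shape of such a fit routine can be sketched as a loop that alternates a training epoch with a validation epoch and remembers the best validation score. This is an assumption about the control flow, not a description of fit_Ptarl itself: the source does not state how the best model is selected or when it is written to save_path. The function `fit` and its callable arguments are hypothetical stand-ins for the real training and validation passes.

```python
def fit(train_epoch, validate, n_epochs, higher_is_better):
    """Hypothetical loop: train, validate, and remember the best epoch."""
    best_score, best_epoch, history = None, None, []
    for epoch in range(n_epochs):
        train_loss = train_epoch(epoch)   # stands in for run_one_epoch
        score = validate(epoch)           # stands in for scoring run_one_epoch_val
        history.append((epoch, train_loss, score))
        improved = (
            best_score is None
            or (score > best_score if higher_is_better else score < best_score)
        )
        if improved:
            best_score, best_epoch = score, epoch
            # Assumption: a real run would save the model to save_path here.
    return best_epoch, best_score, history


# Made-up curves replacing real training and validation results.
train_losses = [1.0, 0.5, 0.25, 0.125]
val_rmse = [0.9, 0.6, 0.7, 0.65]
best_epoch, best_score, history = fit(
    train_epoch=lambda e: train_losses[e],
    validate=lambda e: val_rmse[e],
    n_epochs=4,
    higher_is_better=False,  # RMSE: lower is better
)
print(best_epoch, best_score)
```

In this toy run the training loss keeps falling while validation RMSE is lowest at epoch 1, which is why selection is based on the validation score rather than the training loss.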
def test(model, test_loader, no_ot=False)
Tests a trained model.
Parameters:
model - Trained model.
test_loader - Test data loader.
no_ot (bool, optional, default False) - Whether to disable OT.
Returns:
tuple - Predictions and ground truth labels.
def generate_topic(model, train_loader, n_clusters)
Generates topics from a trained model.
Parameters:
model - Trained model.
train_loader - Training data loader.
n_clusters (int) - Number of clusters/topics.
Returns:
np.ndarray - Generated topics.
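If the topics are obtained by clustering the representations the trained model produces for the training data, with n_clusters cluster centers serving as the topics, the core step resembles k-means. The sketch below is a plain-Python version of Lloyd's algorithm under that assumption. The function `kmeans` is hypothetical, the representations are made-up 2-D points, and the result is a list of lists rather than the np.ndarray the real function returns.

```python
def kmeans(points, n_clusters, n_iters=20):
    """Plain Lloyd's algorithm; the first n_clusters points seed the centers."""
    centers = [list(p) for p in points[:n_clusters]]
    for _ in range(n_iters):
        # Assign each point to its nearest center (squared Euclidean distance).
        groups = [[] for _ in range(n_clusters)]
        for p in points:
            dists = [sum((a - b) ** 2 for a, b in zip(p, c)) for c in centers]
            groups[dists.index(min(dists))].append(p)
        # Move each center to the mean of its group; keep it if the group is empty.
        new_centers = []
        for center, group in zip(centers, groups):
            if group:
                new_centers.append([sum(col) / len(group) for col in zip(*group)])
            else:
                new_centers.append(center)
        if new_centers == centers:
            break  # assignments are stable, so the centers have converged
        centers = new_centers
    return centers


# Made-up representations standing in for the model's outputs on train_loader.
representations = [[0.0, 0.0], [4.0, 4.0], [0.0, 1.0], [4.0, 5.0]]
prototypes = kmeans(representations, n_clusters=2)
print(prototypes)
```

Seeding the centers with the first points keeps the example deterministic; a practical implementation would more likely use a library clustering routine with randomized initialization controlled by a seed.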
References:
Hangting Ye, Wei Fan, Xiaozhuang Song, Shun Zheng, He Zhao, Dandan Guo, and Yi Chang. PTaRL: Prototype-based Tabular Representation Learning via Space Calibration. In Proceedings of the Twelfth International Conference on Learning Representations, 2024. https://openreview.net/pdf?id=G32oY4Vnm8