TabCaps

A capsule network that encapsulates all feature values of a record into vectorial features.

Functions

class TabCapsModel(BaseEstimator)

TabCaps model for tabular data classification using capsule networks.

Parameters:

  • decode (bool, optional, Default is False) - Whether to use reconstruction.

  • mean (int, optional) - Mean value for normalization.

  • std (int, optional) - Standard deviation for normalization.

  • sub_class (int, optional, Default is 1) - Number of sub-classes.

  • init_dim (int, optional) - Initial dimension.

  • primary_capsule_size (int, optional, Default is 16) - Primary capsule size.

  • digit_capsule_size (int, optional, Default is 16) - Digit capsule size.

  • leaves (int, optional, Default is 32) - Number of leaves.

  • seed (int, optional, Default is 0) - Random seed.

  • verbose (int, optional, Default is 1) - Verbosity level.

  • optimizer_fn (Any, optional) - Optimizer function.

  • optimizer_params (Dict, optional) - Optimizer parameters.

  • scheduler_fn (Any, optional) - Scheduler function.

  • scheduler_params (Dict, optional) - Scheduler parameters.

  • input_dim (int, optional) - Input dimension.

  • output_dim (int, optional) - Output dimension.

  • device_name (str, optional, Default is “auto”) - Device name.

Methods:

  • fit(self, X_train, y_train, eval_set=None, eval_name=None, eval_metric=None, max_epochs=100, patience=10, batch_size=1024, virtual_batch_size=256, callbacks=None, logname=None, resume_dir=None, device_id=None, cfg=None) - Train the model.

  • predict(self, X, y, decode=False) - Make predictions.

  • save_check(self, path, seed) - Save model checkpoint.

  • load_model(self, filepath, input_dim, output_dim) - Load saved model.

class CapsuleClassifier(nn.Module)

Capsule network classifier for tabular data.

Parameters:

  • input_dim (int) - Input dimension.

  • output_dim (int) - Output dimension.

  • out_capsule_num (int) - Number of output capsules.

  • init_dim (int) - Initial dimension.

  • primary_capsule_dim (int) - Primary capsule dimension.

  • digit_capsule_dim (int) - Digit capsule dimension.

  • n_leaves (int) - Number of leaves.

Input:

  • x (Tensor) - Input tensor.

Output:

  • Tensor - Classification output.

class ReconstructCapsNet(nn.Module)

Capsule network with reconstruction capabilities.

Parameters:

  • input_dim (int) - Input dimension.

  • output_dim (int) - Output dimension.

  • out_capsule_num (int) - Number of output capsules.

  • init_dim (int) - Initial dimension.

  • primary_capsule_dim (int) - Primary capsule dimension.

  • digit_capsule_dim (int) - Digit capsule dimension.

  • n_leaves (int) - Number of leaves.

Input:

  • x (Tensor) - Input tensor.

  • y_one_hot (Tensor) - One-hot encoded labels.

Output:

  • tuple - (classification_output, reconstruction_output).

class MarginLoss(nn.Module)

Margin loss for capsule networks.

Parameters:

  • m_plus (float, optional, Default is 0.9) - Positive margin.

  • m_minus (float, optional, Default is 0.1) - Negative margin.

  • lambda_val (float, optional, Default is 0.5) - Lambda value.

Input:

  • y_pred (Tensor) - Predicted outputs.

  • y_true (Tensor) - True labels.

Output:

  • Tensor - Loss value.

class AbstractLayer(nn.Module)

Abstract layer for capsule networks.

Parameters:

  • base_input_dim (int) - Base input dimension.

  • base_output_dim (int) - Base output dimension.

  • k (int) - Number of branches.

  • virtual_batch_size (int) - Virtual batch size.

  • bias (bool, optional, Default is False) - Whether to use bias.

Input:

  • x (Tensor) - Input tensor.

Output:

  • Tensor - Layer output.

References:

Jintai Chen, Kuanlun Liao, Yanwen Fang, Danny Z. Chen, Jian Wu. TABCAPS: A CAPSULE NEURAL NETWORK FOR TABULAR DATA CLASSIFICATION WITH BOW ROUTING. In Proceedings of the 11th International Conference on Learning Representations, 2023. https://openreview.net/pdf?id=G32oY4Vnm8