Trompt

A prompt-based deep neural network for tabular data that separates learning into intrinsic column features and sample-specific feature importance.

Classes

class LinearEmbeddings(nn.Module)

Linear embeddings for continuous features.

Parameters:

  • n_features (int) - Number of continuous features.

  • d_embedding (int) - Embedding dimension.

Input Shape:

(*, n_features)

Output Shape:

(*, n_features, d_embedding)
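
The shape contract above can be illustrated with a minimal sketch. It assumes each continuous feature gets its own weight and bias vector; the class name `LinearEmbeddingsSketch` and the initialization scale are hypothetical and not taken from this module.

```python
import torch
import torch.nn as nn


class LinearEmbeddingsSketch(nn.Module):
    """Hypothetical stand-in: one weight and one bias vector per continuous feature."""

    def __init__(self, n_features: int, d_embedding: int) -> None:
        super().__init__()
        # Assumed initialization; the real module may initialize differently.
        self.weight = nn.Parameter(torch.randn(n_features, d_embedding) * 0.02)
        self.bias = nn.Parameter(torch.zeros(n_features, d_embedding))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (*, n_features) -> (*, n_features, 1) -> (*, n_features, d_embedding)
        return x.unsqueeze(-1) * self.weight + self.bias


emb = LinearEmbeddingsSketch(n_features=3, d_embedding=8)
out = emb(torch.randn(4, 3))  # batch of 4 rows, 3 continuous features
```

Each scalar feature value scales its own learned vector, so no parameters are shared between features.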

class CategoricalEmbeddings1d(nn.Module)

Embeddings for categorical features with support for unknown categories.

Parameters:

  • cardinalities (list[int]) - Number of distinct categories for each categorical feature.

  • d_embedding (int) - Embedding dimension.

Input Shape:

(*, n_cat_features)

Output Shape:

(*, n_cat_features, d_embedding)
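
One way to support unknown categories is to reserve an extra embedding row per feature. The sketch below does this under stated assumptions: the class name `CategoricalEmbeddingsSketch` and the choice of index `cardinality` as the "unknown" slot are illustrative, and the actual module may encode unknown values differently.

```python
import torch
import torch.nn as nn


class CategoricalEmbeddingsSketch(nn.Module):
    """Hypothetical stand-in: one embedding table per feature plus an unknown slot."""

    def __init__(self, cardinalities: list[int], d_embedding: int) -> None:
        super().__init__()
        self.cardinalities = cardinalities
        # One extra row per table; index == cardinality is the assumed "unknown" slot.
        self.embeddings = nn.ModuleList(
            [nn.Embedding(c + 1, d_embedding) for c in cardinalities]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x holds integer category codes with shape (*, n_cat_features).
        columns = []
        for i, (c, emb) in enumerate(zip(self.cardinalities, self.embeddings)):
            codes = x[..., i]
            # Route any code outside [0, c) to the unknown slot.
            valid = (codes >= 0) & (codes < c)
            codes = torch.where(valid, codes, torch.full_like(codes, c))
            columns.append(emb(codes))
        # Stack per-feature embeddings: (*, n_cat_features, d_embedding)
        return torch.stack(columns, dim=-2)


cat_emb = CategoricalEmbeddingsSketch(cardinalities=[2, 3], d_embedding=4)
x_cat = torch.tensor([[0, 2], [1, 7]])  # code 7 was never seen for the second feature
cat_out = cat_emb(x_cat)
```

With this layout, an unseen code still produces a valid (and trainable) embedding instead of an index error.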

class ImportanceGetter(nn.Module)

Computes feature importance scores using prompts and input features (Figure 3 Part 1).

Parameters:

  • P (int) - Number of prompts.

  • C (int) - Total number of features (numerical + categorical).

  • d (int) - Embedding dimension.

Input:

  • O (Tensor) - Output of the previous Trompt cell.

Output:

  • Tensor - Feature importance matrix.
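
The following is a minimal sketch of how such an importance matrix could be computed, based on a reading of the referenced paper rather than on this module's source. The class name `ImportanceGetterSketch`, the assumed input shape `(batch, P, d)`, the output shape `(batch, P, C)`, and the omission of any normalization layers are all assumptions.

```python
import torch
import torch.nn as nn


class ImportanceGetterSketch(nn.Module):
    """Hypothetical stand-in: prompts attend over columns to score feature importance."""

    def __init__(self, P: int, C: int, d: int) -> None:
        super().__init__()
        self.prompt_emb = nn.Parameter(torch.randn(P, d) * 0.02)  # one vector per prompt
        self.column_emb = nn.Parameter(torch.randn(C, d) * 0.02)  # one vector per column
        self.dense = nn.Linear(2 * d, d)

    def forward(self, O: torch.Tensor) -> torch.Tensor:
        # O: (batch, P, d), the assumed shape of the previous cell's output.
        prompt = self.prompt_emb.expand(O.shape[0], -1, -1)  # (batch, P, d)
        # Fuse prompts with the previous output, keeping residual paths for both.
        fused = self.dense(torch.cat([prompt, O], dim=-1)) + prompt + O
        # Softmax over columns: each prompt distributes importance across features.
        return torch.softmax(fused @ self.column_emb.T, dim=-1)  # (batch, P, C)


getter = ImportanceGetterSketch(P=6, C=5, d=8)
M = getter(torch.randn(4, 6, 8))
```

Because the softmax runs over the column axis, each prompt's scores form a distribution over the `C` features for every sample.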

class TromptEmbedding(nn.Module)

Combines numerical and categorical embeddings (Figure 3 Part 2).

Parameters:

  • n_num_features (int) - Number of numerical features.

  • cat_cardinalities (list[int]) - Number of distinct categories for each categorical feature.

  • d (int) - Embedding dimension.

Inputs:

  • x_num (Tensor) - Numerical feature tensor.

  • x_cat (Tensor) - Categorical feature tensor.

Output:

  • Tensor - Combined embeddings.
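
A minimal sketch of the combination step, assuming numerical features are embedded linearly with a ReLU, categorical features use lookup tables, and each group is layer-normalized before concatenation along the feature axis. The class name `TromptEmbeddingSketch` and these details follow a reading of the paper and are not confirmed by this module's source.

```python
import torch
import torch.nn as nn


class TromptEmbeddingSketch(nn.Module):
    """Hypothetical stand-in: embed both feature types and concatenate them."""

    def __init__(self, n_num_features: int, cat_cardinalities: list[int], d: int) -> None:
        super().__init__()
        self.num_weight = nn.Parameter(torch.randn(n_num_features, d) * 0.02)
        self.num_bias = nn.Parameter(torch.zeros(n_num_features, d))
        self.cat_emb = nn.ModuleList([nn.Embedding(c, d) for c in cat_cardinalities])
        self.ln_num = nn.LayerNorm(d)
        self.ln_cat = nn.LayerNorm(d)

    def forward(self, x_num: torch.Tensor, x_cat: torch.Tensor) -> torch.Tensor:
        # Numerical: (batch, n_num) -> (batch, n_num, d)
        e_num = torch.relu(x_num.unsqueeze(-1) * self.num_weight + self.num_bias)
        # Categorical: (batch, n_cat) -> (batch, n_cat, d)
        e_cat = torch.stack(
            [emb(x_cat[:, i]) for i, emb in enumerate(self.cat_emb)], dim=1
        )
        # Concatenate along the feature axis: (batch, n_num + n_cat, d)
        return torch.cat([self.ln_num(e_num), self.ln_cat(e_cat)], dim=1)


embed = TromptEmbeddingSketch(n_num_features=3, cat_cardinalities=[2, 4], d=8)
x_num = torch.randn(4, 3)
x_cat = torch.stack([torch.randint(0, 2, (4,)), torch.randint(0, 4, (4,))], dim=1)
E = embed(x_num, x_cat)
```

The feature axis of the result has length `C = n_num_features + len(cat_cardinalities)`, which matches the `C` parameter of `ImportanceGetter`.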

class Expander(nn.Module)

Expands input features using a linear layer and group normalization (Figure 3 Part 3).

Parameters:

  • P (int) - Number of prompts.

Input:

  • x (Tensor) - Feature embeddings to expand.

Output:

  • Tensor - Expanded tensor.
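
A minimal sketch of an expansion step built from a linear layer and group normalization, assuming the input has shape `(batch, C, d)` and the output has shape `(batch, P, C, d)`. The class name `ExpanderSketch`, the ReLU, the residual connection, and the `n_groups` parameter are assumptions for illustration.

```python
import torch
import torch.nn as nn


class ExpanderSketch(nn.Module):
    """Hypothetical stand-in: give every prompt its own view of the embeddings."""

    def __init__(self, P: int, n_groups: int = 2) -> None:
        super().__init__()
        self.lin = nn.Linear(1, P)
        # GroupNorm treats the prompt axis as channels; P must be divisible by n_groups.
        self.gn = nn.GroupNorm(n_groups, P)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, C, d) -> (batch, C, d, 1) -> (batch, C, d, P)
        expanded = torch.relu(self.lin(x.unsqueeze(-1)))
        # Move the prompt axis forward: (batch, P, C, d)
        expanded = expanded.permute(0, 3, 1, 2).contiguous()
        # Residual: broadcast the original embeddings across all prompts.
        return x.unsqueeze(1) + self.gn(expanded)


expander = ExpanderSketch(P=6)
X_hat = expander(torch.randn(4, 5, 8))
```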

class TromptCell(nn.Module)

Complete Trompt cell that combines embedding, importance calculation, and expansion.

Parameters:

  • n_num_features (int) - Number of numerical features.

  • cat_cardinalities (list[int]) - Number of distinct categories for each categorical feature.

  • P (int) - Number of prompts.

  • d (int) - Embedding dimension.

Inputs:

  • x_num (Tensor) - Numerical feature tensor.

  • x_cat (Tensor) - Categorical feature tensor.

  • O (Tensor) - Previous output tensor.

Output:

  • Tensor - Processed tensor.
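
The combination step can be sketched as an importance-weighted sum over features. The shapes below are assumptions based on the paper: an importance matrix `M` of shape `(batch, P, C)` and expanded embeddings `X_hat` of shape `(batch, P, C, d)`. Random tensors stand in for the outputs of the sub-modules.

```python
import torch

batch, P, C, d = 4, 6, 5, 8

# Stand-ins for the sub-module outputs (random values, illustration only).
M = torch.softmax(torch.randn(batch, P, C), dim=-1)  # importance per prompt and feature
X_hat = torch.randn(batch, P, C, d)                  # expanded feature embeddings

# Weighted sum over the feature axis: (batch, P, C) x (batch, P, C, d) -> (batch, P, d)
O = torch.einsum("bpc,bpcd->bpd", M, X_hat)

# Equivalent explicit form: broadcast the weights, then reduce over features.
O_ref = (M.unsqueeze(-1) * X_hat).sum(dim=2)
```

Under these assumptions the result has the same shape as the `O` input, so cells can be stacked by feeding each cell's output into the next.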

class TromptDecoder(nn.Module)

Decodes the output of the Trompt cells into final predictions.

Parameters:

  • d (int) - Embedding dimension of the Trompt cell output.

  • d_out (int) - Output dimension.

Input:

  • o (Tensor) - Input tensor from Trompt cells.

Output:

  • Tensor - Decoded predictions.
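
A minimal sketch of a decoder, assuming the cell output has shape `(batch, P, d)`, that prompts are pooled with learned softmax weights, and that a small feed-forward head produces the predictions. The class name `TromptDecoderSketch` and the exact layer order follow a reading of the paper and may differ from this module.

```python
import torch
import torch.nn as nn


class TromptDecoderSketch(nn.Module):
    """Hypothetical stand-in: pool over prompts, then project to the output size."""

    def __init__(self, d: int, d_out: int) -> None:
        super().__init__()
        self.score = nn.Linear(d, 1)  # one scalar score per prompt
        self.head = nn.Sequential(
            nn.Linear(d, d),
            nn.ReLU(),
            nn.LayerNorm(d),
            nn.Linear(d, d_out),
        )

    def forward(self, o: torch.Tensor) -> torch.Tensor:
        # o: (batch, P, d); weights over the prompt axis sum to 1.
        w = torch.softmax(self.score(o), dim=1)  # (batch, P, 1)
        pooled = (w * o).sum(dim=1)              # (batch, d)
        return self.head(pooled)                 # (batch, d_out)


decoder = TromptDecoderSketch(d=8, d_out=3)
pred = decoder(torch.randn(4, 6, 8))
```

For classification, `d_out` would typically be the number of classes and `pred` would hold logits; for regression, `d_out=1`.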

Reference

Kuan-Yu Chen, Ping-Han Chiang, Hsin-Rung Chou, Ting-Wei Chen, and Tien-Hao Chang. Trompt: Towards a Better Deep Neural Network for Tabular Data. arXiv:2305.18446 [cs.LG], 2023. https://arxiv.org/abs/2305.18446