autoprognosis.plugins.prediction.classifiers.plugin_tabnet module
- class TabNetPlugin(n_d: int = 64, n_a: int = 64, lr: float = 0.001, n_steps: int = 3, gamma: float = 1.5, n_independent: int = 2, n_shared: int = 2, lambda_sparse: float = 0.0001, momentum: float = 0.3, clip_value: float = 2.0, max_epochs: int = 1000, patience: int = 20, batch_size: int = 50, random_state: int = 0, **kwargs: Any)
Bases:
autoprognosis.plugins.prediction.classifiers.base.ClassifierPlugin
Classification plugin based on TabNet. TabNet uses sequential attention to choose which features to reason from at each decision step, enabling interpretability and more efficient learning as the learning capacity is used for the most salient features.
- Parameters
n_d – int Width of the decision prediction layer. Bigger values give the model more capacity, at the risk of overfitting. Values typically range from 8 to 64.
n_a – int Width of the attention embedding for each mask. According to the paper, n_d = n_a is usually a good choice.
lr – float Learning rate
n_steps – int Number of steps in the architecture (usually between 3 and 10)
gamma – float Coefficient for feature reuse in the masks. A value close to 1 makes mask selection least correlated between layers. Values range from 1.0 to 2.0.
n_independent – int Number of independent Gated Linear Units layers at each step. Usual values range from 1 to 5.
n_shared – int Number of shared Gated Linear Units at each step. Usual values range from 1 to 5.
lambda_sparse – float This is the extra sparsity loss coefficient as proposed in the original paper. The bigger this coefficient is, the sparser your model will be in terms of feature selection. Depending on the difficulty of your problem, reducing this value could help.
momentum – float Momentum for batch normalization, typically ranges from 0.01 to 0.4.
clip_value – float If a float is given this will clip the gradient at clip_value.
max_epochs – int Maximum number of epochs for training.
patience – int Number of consecutive epochs without improvement before performing early stopping.
batch_size – int Batch size
random_state – int Random seed
Example
>>> from autoprognosis.plugins.prediction import Predictions
>>> plugin = Predictions(category="classifiers").get("tabnet", max_epochs=100)
>>> from sklearn.datasets import load_iris
>>> X, y = load_iris(return_X_y=True)
>>> plugin.fit_predict(X, y)  # returns the probabilities for each class
Original implementation: https://github.com/dreamquark-ai/tabnet
- change_output(output: str) None
- explain(X: pandas.core.frame.DataFrame, *args: Any, **kwargs: Any) pandas.core.frame.DataFrame
- fit(X: pandas.core.frame.DataFrame, *args: Any, **kwargs: Any) autoprognosis.plugins.core.base_plugin.Plugin
Train the plugin
- Parameters
X – pd.DataFrame
- fit_predict(X: pandas.core.frame.DataFrame, *args: Any, **kwargs: Any) pandas.core.frame.DataFrame
Fit the model and predict the training data. Used by predictors.
- fit_transform(X: pandas.core.frame.DataFrame, *args: Any, **kwargs: Any) pandas.core.frame.DataFrame
Fit the model and transform the training data. Used by imputers and preprocessors.
- classmethod fqdn() str
The fully-qualified name of the plugin: type->subtype->name
- get_args() dict
- static hyperparameter_space(*args: Any, **kwargs: Any) List[autoprognosis.plugins.core.params.Params]
The hyperparameter search domain, used for tuning.
- classmethod hyperparameter_space_fqdn(*args: Any, **kwargs: Any) List[autoprognosis.plugins.core.params.Params]
The hyperparameter domain, using the fully-qualified name.
- is_fitted() bool
Check if the model was trained
- classmethod load(buff: bytes) autoprognosis.plugins.prediction.classifiers.plugin_tabnet.TabNetPlugin
Load the plugin from bytes
- static name() str
The name of the plugin, e.g.: xgboost
- predict(X: pandas.core.frame.DataFrame, *args: Any, **kwargs: Any) pandas.core.frame.DataFrame
Run predictions for the input. Used by predictors.
- Parameters
X – pd.DataFrame
- predict_proba(X: pandas.core.frame.DataFrame, *args: Any, **kwargs: Any) pandas.core.frame.DataFrame
- classmethod sample_hyperparameters(trial: optuna.trial.Trial, *args: Any, **kwargs: Any) Dict[str, Any]
Sample hyperparameters for Optuna.
- classmethod sample_hyperparameters_fqdn(trial: optuna.trial.Trial, *args: Any, **kwargs: Any) Dict[str, Any]
Sample hyperparameters using the fully-qualified name.
- classmethod sample_hyperparameters_np(random_state: int = 0, *args: Any, **kwargs: Any) Dict[str, Any]
Sample hyperparameters as a dict.
- save() bytes
Save the plugin to bytes
- score(X: pandas.core.frame.DataFrame, y: pandas.core.frame.DataFrame, metric: str = 'aucroc') float
- static subtype() str
The type of the plugin, e.g.: classifier
- transform(X: pandas.core.frame.DataFrame) pandas.core.frame.DataFrame
Transform the input. Used by imputers and preprocessors.
- Parameters
X – pd.DataFrame
- static type() str
The type of the plugin, e.g.: prediction