autoprognosis.plugins.prediction.risk_estimation.plugin_deephit module
- class DeepHitRiskEstimationPlugin(model: Optional[Any] = None, num_durations: int = 10, batch_size: int = 100, epochs: int = 5000, lr: float = 0.01, dim_hidden: int = 300, alpha: float = 0.28, sigma: float = 0.38, dropout: float = 0.2, patience: int = 20, batch_norm: bool = False, random_state: int = 0, hyperparam_search_iterations: Optional[int] = None, **kwargs: Any)
Bases: autoprognosis.plugins.prediction.risk_estimation.base.RiskEstimationPlugin
DeepHit plugin for survival analysis. DeepHit [1] uses a deep neural network to learn the distribution of survival times directly. It makes no assumptions about the underlying stochastic process and allows the relationship between covariates and risk(s) to change over time. Most importantly, DeepHit handles competing risks, i.e. settings in which there is more than one possible event of interest.
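To make "the distribution of survival times" concrete: in the DeepHit paper the network outputs, for each individual, a probability mass over (event type, discrete time) pairs. The plain-Python sketch below uses a hypothetical, hand-written mass for one individual (not real model output) to show how the cause-specific cumulative incidence functions (CIFs) and the overall survival function follow from it.

```python
from itertools import accumulate
from typing import List


def cumulative_incidence(pmf: List[List[float]]) -> List[List[float]]:
    """CIF_k(t) = sum over s <= t of P(event of cause k at time s), one row per cause k."""
    return [list(accumulate(row)) for row in pmf]


def survival(pmf: List[List[float]]) -> List[float]:
    """S(t) = 1 - sum_k CIF_k(t): probability that no event of any cause has occurred by t."""
    cifs = cumulative_incidence(pmf)
    n_times = len(pmf[0])
    return [1.0 - sum(cif[t] for cif in cifs) for t in range(n_times)]


# Hypothetical output for one individual: 2 competing causes x 4 time points.
# The mass here sums to 0.6; in this sketch the remaining 0.4 is treated as
# "no event within the time grid" (an assumption made for illustration).
pmf = [
    [0.10, 0.05, 0.05, 0.10],  # cause 1
    [0.05, 0.05, 0.10, 0.10],  # cause 2
]
print(cumulative_incidence(pmf))
print(survival(pmf))
```

Because each CIF is a running sum of non-negative mass, the survival curve derived this way is automatically non-increasing.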
- Parameters
num_durations – int Number of discrete time points at which the survival function is estimated
batch_size – int Batch size
epochs – int Maximum number of training epochs
lr – float Learning rate
dim_hidden – int Number of neurons in the hidden layers
alpha – float Weight in (0, 1) balancing the likelihood loss and the rank loss (L2 in the paper): 1 gives only the likelihood loss, and 0 gives only the rank loss. (default: 0.28)
sigma – float Scale parameter σ of the function η in the rank loss (L2 in the paper). (default: 0.38)
dropout – float Dropout rate
patience – int Number of epochs without improvement before training stops early
batch_norm – bool Enable or disable batch normalization
random_state – int Random seed
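The alpha and sigma parameters shape the training objective. The sketch below is a plain-Python, single-risk illustration of a DeepHit-style loss. It assumes the convex combination alpha * likelihood + (1 - alpha) * rank loss implied by the alpha description above, and the paper's pairwise penalty η(x, y) = exp(-(x - y) / σ). How each term is normalized (here: means) is an assumption of this sketch, not necessarily what the library computes.

```python
import math
from itertools import accumulate
from typing import Sequence

# pmf[i][k]: predicted probability that subject i has the event at grid index k.
# idx[i]: grid index of subject i's observed time; event[i]: 1 = event, 0 = censored.


def nll(pmf: Sequence[Sequence[float]], idx: Sequence[int], event: Sequence[int]) -> float:
    """Mean negative log-likelihood (single risk, discrete time).

    An event at t_i contributes -log P(T = t_i); censoring at t_i contributes -log P(T > t_i).
    """
    total = 0.0
    for p, k, e in zip(pmf, idx, event):
        if e:
            total -= math.log(p[k])
        else:
            total -= math.log(1.0 - sum(p[: k + 1]))
    return total / len(pmf)


def rank_loss(
    pmf: Sequence[Sequence[float]], idx: Sequence[int], event: Sequence[int], sigma: float
) -> float:
    """Pairwise ranking term over comparable pairs (i had the event and t_i < t_j).

    Each pair contributes exp(-(F_i(t_i) - F_j(t_i)) / sigma), where F is the
    predicted CDF; the pair terms are averaged (an assumption of this sketch).
    """
    cdfs = [list(accumulate(p)) for p in pmf]
    n = len(pmf)
    terms = [
        math.exp(-(cdfs[i][idx[i]] - cdfs[j][idx[i]]) / sigma)
        for i in range(n)
        if event[i]
        for j in range(n)
        if idx[i] < idx[j]
    ]
    return sum(terms) / len(terms) if terms else 0.0


def deephit_loss(
    pmf: Sequence[Sequence[float]],
    idx: Sequence[int],
    event: Sequence[int],
    alpha: float = 0.28,
    sigma: float = 0.38,
) -> float:
    """alpha = 1 -> likelihood loss only; alpha = 0 -> rank loss only."""
    return alpha * nll(pmf, idx, event) + (1 - alpha) * rank_loss(pmf, idx, event, sigma)


# Hypothetical toy batch: 3 subjects on a 3-point time grid.
pmf = [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3], [0.1, 0.2, 0.7]]
idx = [0, 1, 1]
event = [1, 1, 0]
print(deephit_loss(pmf, idx, event))  # defaults mirror the signature: alpha=0.28, sigma=0.38
```

A smaller sigma sharpens the penalty, so mis-ordered pairs (where F_i(t_i) is not clearly above F_j(t_i)) dominate the rank term.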
Example
>>> from autoprognosis.plugins.prediction import Predictions
>>> from pycox.datasets import metabric
>>>
>>> df = metabric.read_df()
>>> X = df.drop(["duration", "event"], axis=1)
>>> Y = df["event"]
>>> T = df["duration"]
>>>
>>> plugin = Predictions(category="risk_estimation").get("deephit")
>>> plugin.fit(X, T, Y)
>>>
>>> eval_time_horizons = [int(T[Y.iloc[:] == 1].quantile(0.50))]
>>> plugin.predict(X, eval_time_horizons)
References: [1] Changhee Lee, William R Zame, Jinsung Yoon, and Mihaela van der Schaar. DeepHit: A deep learning approach to survival analysis with competing risks. In Thirty-Second AAAI Conference on Artificial Intelligence, 2018. http://medianetlab.ee.ucla.edu/papers/AAAI_2018_DeepHit
- change_output(output: str) → None
- explain(X: pandas.core.frame.DataFrame, *args: Any, **kwargs: Any) → pandas.core.frame.DataFrame
- fit(X: pandas.core.frame.DataFrame, *args: Any, **kwargs: Any) → autoprognosis.plugins.prediction.risk_estimation.base.RiskEstimationPlugin
Train the plugin.
- Parameters
X – pd.DataFrame Covariates
*args – The event times T and event indicators Y, passed positionally as in the Example
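During fit, the epochs and patience parameters interact: training can end before epochs is reached when progress stalls. The plain-Python sketch below illustrates patience-based early stopping, assuming a held-out validation loss is monitored and that "improvement" means beating the best loss seen so far. The exact criterion, and whether the best weights are restored afterwards, depend on the underlying training code; the function name is hypothetical.

```python
from typing import Sequence


def early_stopping_epoch(val_losses: Sequence[float], patience: int) -> int:
    """Return the index of the epoch at which training would stop.

    Training stops once `patience` consecutive epochs pass without improving on
    the best validation loss so far; otherwise it runs until the losses run out,
    which plays the role of the `epochs` limit here.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(val_losses) - 1


losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9]  # hypothetical validation losses
print(early_stopping_epoch(losses, patience=3))
```

With a large epochs value such as the default 5000, patience is usually what actually ends training.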
- fit_predict(X: pandas.core.frame.DataFrame, *args: Any, **kwargs: Any) → pandas.core.frame.DataFrame
Fit the model and predict the training data. Used by predictors.
- fit_transform(X: pandas.core.frame.DataFrame, *args: Any, **kwargs: Any) → pandas.core.frame.DataFrame
Fit the model and transform the training data. Used by imputers and preprocessors.
- classmethod fqdn() → str
The fully-qualified name of the plugin: type->subtype->name
- get_args() → dict
- static hyperparameter_space(*args: Any, **kwargs: Any) → List[autoprognosis.plugins.core.params.Params]
The hyperparameter search domain, used for tuning.
- classmethod hyperparameter_space_fqdn(*args: Any, **kwargs: Any) → List[autoprognosis.plugins.core.params.Params]
The hyperparameter domain using the fully-qualified name.
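The fully-qualified name combines type, subtype and name, and the *_fqdn variants use it to namespace hyperparameters, which avoids collisions when several plugins are tuned in one search. A minimal sketch, assuming "." as the separator (the exact format is not stated here) and using the values this page implies: type "prediction", subtype "risk_estimation", name "deephit". Both helper functions are hypothetical.

```python
from typing import Dict, List


def fqdn(plugin_type: str, subtype: str, name: str) -> str:
    """Compose a fully-qualified plugin name; the '.' separator is an assumption."""
    return f"{plugin_type}.{subtype}.{name}"


def prefix_params(full_name: str, param_names: List[str]) -> Dict[str, str]:
    """Map each local hyperparameter name to its fully-qualified form."""
    return {p: f"{full_name}.{p}" for p in param_names}


name = fqdn("prediction", "risk_estimation", "deephit")
print(name)  # prediction.risk_estimation.deephit
print(prefix_params(name, ["lr", "alpha"]))
```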
- is_fitted() → bool
Check whether the model has been trained.
- classmethod load(buff: bytes) → autoprognosis.plugins.prediction.risk_estimation.plugin_deephit.DeepHitRiskEstimationPlugin
Load the plugin from bytes.
- static name() → str
The name of the plugin, e.g.: xgboost
- predict(X: pandas.core.frame.DataFrame, *args: Any, **kwargs: Any) → pandas.core.frame.DataFrame
Run predictions for the input. Used by predictors.
- Parameters
X – pd.DataFrame Covariates
*args – The evaluation time horizons, passed positionally as in the Example
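predict receives the evaluation horizons as an extra argument, as in the Example. Assuming the value reported at each horizon is a risk of the form 1 - S(t), read off the survival curve estimated on the num_durations grid, the lookup can be sketched as follows. The step-function lookup is an assumption; the library may interpolate differently. The grid and curve below are hypothetical.

```python
from bisect import bisect_right
from typing import List, Sequence


def risk_at_horizons(
    time_grid: Sequence[float], surv: Sequence[float], horizons: Sequence[float]
) -> List[float]:
    """Risk 1 - S(h) at each horizon h.

    S is treated as a right-continuous step function on the sorted grid,
    with S = 1 before the first grid point.
    """
    risks = []
    for h in horizons:
        k = bisect_right(time_grid, h) - 1  # index of the last grid point <= h
        s = 1.0 if k < 0 else surv[k]
        risks.append(1.0 - s)
    return risks


grid = [0.0, 50.0, 100.0, 150.0]  # hypothetical grid with num_durations = 4
surv = [1.0, 0.9, 0.7, 0.4]       # hypothetical predicted survival curve
print(risk_at_horizons(grid, surv, [75.0, 150.0]))
```

A coarser grid (smaller num_durations) makes this lookup more step-like, which is one reason the number of durations matters at prediction time.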
- predict_proba(X: pandas.core.frame.DataFrame, *args: Any, **kwargs: Any) → pandas.core.frame.DataFrame
- classmethod sample_hyperparameters(trial: optuna.trial.Trial, *args: Any, **kwargs: Any) → Dict[str, Any]
Sample hyperparameters for Optuna.
- classmethod sample_hyperparameters_fqdn(trial: optuna.trial.Trial, *args: Any, **kwargs: Any) → Dict[str, Any]
Sample hyperparameters using the fully-qualified name.
- classmethod sample_hyperparameters_np(random_state: int = 0, *args: Any, **kwargs: Any) → Dict[str, Any]
Sample hyperparameters as a dict.
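sample_hyperparameters_np draws a configuration as a dict, seeded by random_state. The sketch below shows the idea with the standard library. The search space is hypothetical: the names mirror constructor arguments, but the ranges and choices are illustrative only and are not the plugin's actual hyperparameter_space().

```python
import random
from typing import Any, Dict, List, Tuple, Union

# Hypothetical search space; ranges and choices are illustrative only.
SPACE: Dict[str, Union[Tuple[float, float], List[Any]]] = {
    "lr": (1e-4, 1e-2),             # continuous range (low, high)
    "dim_hidden": [100, 200, 300],  # categorical choices
    "alpha": (0.0, 1.0),
}


def sample_np(random_state: int = 0) -> Dict[str, Any]:
    """Draw one configuration; the same random_state always gives the same draw."""
    rng = random.Random(random_state)  # local RNG, so global random state is untouched
    params: Dict[str, Any] = {}
    for name, domain in SPACE.items():
        if isinstance(domain, tuple):
            params[name] = rng.uniform(*domain)
        else:
            params[name] = rng.choice(domain)
    return params


print(sample_np(random_state=0))
```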
- save() → bytes
Save the plugin to bytes.
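save and load round-trip a plugin through bytes. The serialization format is not documented here, so the sketch below uses pickle on a hypothetical ToyPlugin class purely to illustrate the save/load contract (load(save()) reproduces an equivalent object). As with any pickle-based format, only load bytes from trusted sources.

```python
import pickle
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyPlugin:
    """Hypothetical stand-in for a fitted plugin: a hyperparameter plus learned state."""

    lr: float = 0.01
    weights: List[float] = field(default_factory=list)

    def save(self) -> bytes:
        # pickle is an assumption for this sketch, not the library's documented format
        return pickle.dumps(self)

    @classmethod
    def load(cls, buff: bytes) -> "ToyPlugin":
        obj = pickle.loads(buff)
        if not isinstance(obj, cls):
            raise TypeError(f"expected {cls.__name__}, got {type(obj).__name__}")
        return obj


plugin = ToyPlugin(lr=0.001, weights=[0.5, -1.0])
restored = ToyPlugin.load(plugin.save())
print(restored)
```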
- score(X: pandas.core.frame.DataFrame, y: pandas.core.frame.DataFrame, metric: str = 'aucroc') → float
- static subtype() → str
The subtype of the plugin, e.g.: classifier
- transform(X: pandas.core.frame.DataFrame) → pandas.core.frame.DataFrame
Transform the input. Used by imputers and preprocessors.
- Parameters
X – pd.DataFrame
- static type() → str
The type of the plugin, e.g.: prediction