autoprognosis.plugins.prediction.risk_estimation.plugin_deephit module

class DeepHitRiskEstimationPlugin(model: Optional[Any] = None, num_durations: int = 10, batch_size: int = 100, epochs: int = 5000, lr: float = 0.01, dim_hidden: int = 300, alpha: float = 0.28, sigma: float = 0.38, dropout: float = 0.2, patience: int = 20, batch_norm: bool = False, random_state: int = 0, hyperparam_search_iterations: Optional[int] = None, **kwargs: Any)

Bases: autoprognosis.plugins.prediction.risk_estimation.base.RiskEstimationPlugin

DeepHit plugin for survival analysis. DeepHit uses a deep neural network to learn the distribution of survival times directly. It makes no assumptions about the underlying stochastic process and allows the relationship between covariates and risk(s) to change over time. Most importantly, DeepHit handles competing risks, i.e. settings in which more than one event of interest is possible.
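For each subject, DeepHit's output is a probability mass function (PMF) over a fixed grid of discrete time points. The `num_durations` parameter sets the size of that grid. The minimal sketch below assumes a single risk and a softmax output layer, as in the DeepHit paper. It shows how such a PMF turns into a cumulative incidence F(t) and a survival function S(t) = 1 - F(t). The function name `pmf_to_survival` and the logits are hypothetical and not part of the plugin's API.

```python
import numpy as np


def pmf_to_survival(logits: np.ndarray) -> np.ndarray:
    """Turn per-subject logits over K time bins into survival curves.

    Sketch for a single risk: a softmax over the K bins gives a PMF,
    its cumulative sum gives F(t_k) = P(T <= t_k), and S(t_k) = 1 - F(t_k).
    """
    # Numerically stable softmax over the time axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    pmf = np.exp(shifted)
    pmf /= pmf.sum(axis=1, keepdims=True)
    cdf = np.cumsum(pmf, axis=1)
    return 1.0 - cdf


# Two subjects, four time bins (hypothetical logits).
# Subject 0 puts most mass early, subject 1 late.
logits = np.array([[2.0, 0.5, 0.1, 0.0],
                   [0.0, 0.1, 0.5, 2.0]])
surv = pmf_to_survival(logits)
print(surv.round(3))
```

Some implementations add an extra padding bin before the softmax so that S(t) need not reach zero at the last grid point. pycox's `DeepHitSingle` is one example. This sketch omits that detail.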

Parameters
  • num_durations – int Number of points in the survival function

  • batch_size – int Batch size

  • epochs – int Maximum number of training epochs

  • lr – float Learning rate

  • dim_hidden – int Number of neurons in the hidden layers

  • alpha – float Weight in (0, 1) balancing the likelihood loss and the rank loss (L2 in the paper). 1 gives only the likelihood loss, and 0 gives only the rank loss. (default: 0.28)

  • sigma – float Scale σ of the η (eta) function in the rank loss (L2 in the paper) (default: 0.38)

  • dropout – float Dropout rate

  • patience – int Number of epochs without improvement before training stops early

  • batch_norm – bool Enable/disable batch normalization

  • random_state – int Random seed
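The minimal NumPy sketch below makes `num_durations`, `alpha` and `sigma` concrete for a single-risk DeepHit objective. Durations are mapped onto an equidistant grid of `num_durations` points. The loss is `alpha * NLL + (1 - alpha) * rank_loss`, where the ranking term adds exp(-(F_i(t_i) - F_j(t_i)) / sigma) for each comparable pair (i, j). This is an illustration under simplifying assumptions: a single risk, an equidistant grid, and ties ignored when forming pairs. It is not the plugin's actual training code, and `discretize` and `deephit_loss` are hypothetical names.

```python
import numpy as np


def discretize(durations: np.ndarray, num_durations: int):
    """Map each duration onto an equidistant grid with num_durations points."""
    cuts = np.linspace(0.0, durations.max(), num_durations)
    # Index of the first grid point >= duration (clipped for safety).
    idx = np.searchsorted(cuts, durations, side="left")
    return np.clip(idx, 0, num_durations - 1), cuts


def deephit_loss(pmf, idx, events, alpha=0.28, sigma=0.38, eps=1e-7):
    """alpha * NLL + (1 - alpha) * ranking loss for a single risk."""
    n = pmf.shape[0]
    rows = np.arange(n)
    cdf = np.cumsum(pmf, axis=1)  # cdf[i, k] = F_i(t_k)
    # Likelihood: P(T = t_i) if the event was observed, P(T > t_i) if censored.
    p_event = pmf[rows, idx]
    p_surv = 1.0 - cdf[rows, idx]
    nll = -np.mean(events * np.log(p_event + eps)
                   + (1 - events) * np.log(p_surv + eps))
    # Ranking: subject i (event at t_i) should have a higher F at t_i than
    # any subject j whose observed time is later than t_i.
    f_i = cdf[rows, idx]       # F_i(t_i)
    f_j = cdf[:, idx].T        # f_j[i, j] = F_j(t_i)
    diff = f_i[:, None] - f_j
    comparable = (events[:, None] == 1) & (idx[:, None] < idx[None, :])
    rank = np.mean(comparable * np.exp(-diff / sigma))
    return alpha * nll + (1 - alpha) * rank


durations = np.array([1.0, 9.0, 4.0])
events = np.array([1, 1, 0])
idx, cuts = discretize(durations, num_durations=10)

uniform = np.full((3, 10), 0.1)
peaked = np.full((3, 10), 0.01)
peaked[np.arange(3), idx] += 0.9  # most mass at each subject's own bin
print(deephit_loss(uniform, idx, events), deephit_loss(peaked, idx, events))
```

Setting `alpha=1.0` keeps only the likelihood term, and `alpha=0.0` keeps only the ranking term. This matches the description of `alpha` above.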

Example

>>> from autoprognosis.plugins.prediction import Predictions
>>> from pycox.datasets import metabric
>>>
>>> df = metabric.read_df()
>>> X = df.drop(["duration", "event"], axis=1)
>>> Y = df["event"]
>>> T = df["duration"]
>>>
>>> plugin = Predictions(category="risk_estimation").get("deephit")
>>> plugin.fit(X, T, Y)
>>>
>>> eval_time_horizons = [int(T[Y == 1].quantile(0.50))]
>>> plugin.predict(X, eval_time_horizons)
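The example evaluates a single horizon: the median time among subjects who experienced the event. The same idea extends to several horizons through a list of quantiles. The sketch below uses small synthetic `T` and `Y` series in place of the METABRIC data. The quantile levels are an arbitrary choice.

```python
import pandas as pd

# Synthetic stand-ins for the T (duration) and Y (event) series in the example.
T = pd.Series([5.0, 12.0, 20.0, 31.0, 44.0, 60.0, 75.0, 90.0])
Y = pd.Series([1, 0, 1, 1, 0, 1, 1, 0])

# Horizons at the 25th, 50th and 75th percentiles of observed event times.
event_times = T[Y == 1]
eval_time_horizons = [int(event_times.quantile(q)) for q in (0.25, 0.50, 0.75)]
print(eval_time_horizons)  # [20, 31, 60]
```

You can then pass the resulting list to `plugin.predict(X, eval_time_horizons)`, as in the example.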

References: [1] Changhee Lee, William R. Zame, Jinsung Yoon, and Mihaela van der Schaar. DeepHit: A deep learning approach to survival analysis with competing risks. In Thirty-Second AAAI Conference on Artificial Intelligence, 2018. http://medianetlab.ee.ucla.edu/papers/AAAI_2018_DeepHit

change_output(output: str) → None
explain(X: pandas.core.frame.DataFrame, *args: Any, **kwargs: Any) → pandas.core.frame.DataFrame
fit(X: pandas.core.frame.DataFrame, *args: Any, **kwargs: Any) → autoprognosis.plugins.prediction.risk_estimation.base.RiskEstimationPlugin

Train the plugin

Parameters

X – pd.DataFrame Covariates. The durations (T) and event indicators (Y) follow as positional arguments, as in plugin.fit(X, T, Y) in the example above.

fit_predict(X: pandas.core.frame.DataFrame, *args: Any, **kwargs: Any) → pandas.core.frame.DataFrame

Fit the model and predict the training data. Used by predictors.

fit_transform(X: pandas.core.frame.DataFrame, *args: Any, **kwargs: Any) → pandas.core.frame.DataFrame

Fit the model and transform the training data. Used by imputers and preprocessors.

classmethod fqdn() → str

The fully-qualified name of the plugin: type->subtype->name

get_args() → dict
static hyperparameter_space(*args: Any, **kwargs: Any) → List[autoprognosis.plugins.core.params.Params]

The hyperparameter search domain, used for tuning.

classmethod hyperparameter_space_fqdn(*args: Any, **kwargs: Any) → List[autoprognosis.plugins.core.params.Params]

The hyperparameter domain using the fully-qualified name.
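The naming methods combine type, subtype and name. Here is a minimal sketch of that combination for this plugin. It uses the values this page suggests: type `prediction`, subtype `risk_estimation` (from the module path) and name `deephit` (from the example). It assumes a "." separator, since the docstrings write `type->subtype->name` without naming one. `ToyPlugin` is a hypothetical stand-in, and its `hyperparameter_space()` returns plain strings instead of `Params` objects.

```python
from typing import List


class ToyPlugin:
    """Hypothetical stand-in that mimics the naming methods documented here."""

    @staticmethod
    def type() -> str:
        return "prediction"

    @staticmethod
    def subtype() -> str:
        return "risk_estimation"

    @staticmethod
    def name() -> str:
        return "deephit"

    @classmethod
    def fqdn(cls) -> str:
        # Assumption: components are joined with ".".
        return ".".join([cls.type(), cls.subtype(), cls.name()])

    @staticmethod
    def hyperparameter_space() -> List[str]:
        # Hypothetical subset of tunable argument names.
        return ["lr", "dim_hidden", "alpha", "sigma"]

    @classmethod
    def hyperparameter_space_fqdn(cls) -> List[str]:
        # Prefix every hyperparameter with the plugin's fully-qualified name.
        return [f"{cls.fqdn()}.{p}" for p in cls.hyperparameter_space()]


print(ToyPlugin.fqdn())
print(ToyPlugin.hyperparameter_space_fqdn()[0])
```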

is_fitted() → bool

Check if the model was trained

classmethod load(buff: bytes) → autoprognosis.plugins.prediction.risk_estimation.plugin_deephit.DeepHitRiskEstimationPlugin

Load the plugin from bytes

static name() → str

The name of the plugin, e.g.: xgboost

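predict(X: pandas.core.frame.DataFrame, *args: Any, **kwargs: Any) → pandas.core.frame.DataFrame

Run predictions for the input. Used by predictors.

Parameters

X – pd.DataFrame Covariates. The evaluation time horizons follow as a positional argument, as in plugin.predict(X, eval_time_horizons) in the example above.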
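A risk-estimation plugin is queried at arbitrary time horizons, but DeepHit only produces the survival function on a grid of `num_durations` points. One common way to bridge the two is to read the survival value at the last grid point at or before each horizon and report the risk as 1 - S(t). The sketch below does this, assuming step (last-value-carried-forward) interpolation. The function name and the output layout (one column per horizon) are hypothetical. They do not describe what `predict` does internally.

```python
import numpy as np
import pandas as pd


def risk_at_horizons(surv: np.ndarray, grid: np.ndarray, horizons) -> pd.DataFrame:
    """Read 1 - S(t) at each horizon from survival curves on a time grid.

    surv: (n_subjects, n_grid) survival probabilities at the times in grid.
    Uses a step function: S(t) equals the value at the last grid time <= t,
    and S(t) = 1 before the first grid time.
    """
    # Index of the last grid point <= horizon (-1 if the horizon precedes the grid).
    pos = np.searchsorted(grid, horizons, side="right") - 1
    padded = np.hstack([np.ones((surv.shape[0], 1)), surv])  # column 0 is S = 1
    risk = 1.0 - padded[:, pos + 1]
    return pd.DataFrame(risk, columns=list(horizons))


grid = np.array([0.0, 10.0, 20.0, 30.0])
surv = np.array([[1.0, 0.9, 0.7, 0.4],
                 [1.0, 0.95, 0.9, 0.85]])
print(risk_at_horizons(surv, grid, [15, 30]))
```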

predict_proba(X: pandas.core.frame.DataFrame, *args: Any, **kwargs: Any) → pandas.core.frame.DataFrame
classmethod sample_hyperparameters(trial: optuna.trial.Trial, *args: Any, **kwargs: Any) → Dict[str, Any]

Sample hyperparameters for Optuna.

classmethod sample_hyperparameters_fqdn(trial: optuna.trial.Trial, *args: Any, **kwargs: Any) → Dict[str, Any]

Sample hyperparameters using the fully-qualified name.

classmethod sample_hyperparameters_np(random_state: int = 0, *args: Any, **kwargs: Any) → Dict[str, Any]

Sample hyperparameters as a dict.
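For a given `random_state`, `sample_hyperparameters_np` returns one random configuration as a dict. The sketch below shows the general pattern with NumPy's seeded generator. The search space is hypothetical: its names come from the constructor, but the ranges and choices are invented for illustration. It is not the plugin's actual `hyperparameter_space()`, and the function here is a stand-in, not the classmethod itself.

```python
from typing import Any, Dict

import numpy as np

# Hypothetical search space: categorical choices (lists) and (low, high) float ranges.
SEARCH_SPACE: Dict[str, Any] = {
    "dim_hidden": [100, 200, 300, 400],
    "num_durations": [10, 20, 50],
    "batch_norm": [False, True],
    "lr": (1e-4, 1e-2),
    "alpha": (0.0, 0.5),
    "sigma": (0.0, 0.5),
    "dropout": (0.0, 0.2),
}


def sample_hyperparameters_np(random_state: int = 0) -> Dict[str, Any]:
    """Draw one configuration; the same seed always yields the same dict."""
    rng = np.random.default_rng(random_state)
    sample: Dict[str, Any] = {}
    for name, domain in SEARCH_SPACE.items():
        if isinstance(domain, list):
            # Categorical: pick one option uniformly.
            sample[name] = domain[rng.integers(len(domain))]
        else:
            # Continuous: uniform draw in [low, high).
            low, high = domain
            sample[name] = float(rng.uniform(low, high))
    return sample


print(sample_hyperparameters_np(random_state=42))
```

Seeding a local generator, rather than the global NumPy state, keeps each call reproducible without side effects on other code.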

save() → bytes

Save the plugin to bytes
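`save()` returns the plugin as bytes, and the classmethod `load(buff)` rebuilds it. A trained model can therefore be stored with `buff = plugin.save()` and restored with `DeepHitRiskEstimationPlugin.load(buff)`. The self-contained sketch below illustrates that bytes round-trip on a hypothetical `ToyModel` class, using the standard `pickle` module. This page does not say which serializer the plugin uses internally, so `pickle` is an assumption.

```python
import pickle
from typing import Any, Dict


class ToyModel:
    """Hypothetical stand-in with the same save/load contract as the plugin."""

    def __init__(self, **params: Any) -> None:
        self.params: Dict[str, Any] = params
        self.fitted = False

    def save(self) -> bytes:
        # Serialize the constructor arguments and trained state (plain Python objects).
        return pickle.dumps({"params": self.params, "fitted": self.fitted})

    @classmethod
    def load(cls, buff: bytes) -> "ToyModel":
        # Rebuild a fresh instance from the serialized state.
        state = pickle.loads(buff)
        obj = cls(**state["params"])
        obj.fitted = state["fitted"]
        return obj


model = ToyModel(num_durations=10, lr=0.01)
model.fitted = True
restored = ToyModel.load(model.save())
print(restored.params, restored.fitted)
```

Only load bytes from trusted sources: unpickling untrusted data can execute arbitrary code.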

score(X: pandas.core.frame.DataFrame, y: pandas.core.frame.DataFrame, metric: str = 'aucroc') → float
static subtype() → str

The subtype of the plugin, e.g.: classifier

transform(X: pandas.core.frame.DataFrame) → pandas.core.frame.DataFrame

Transform the input. Used by imputers and preprocessors.

Parameters

X – pd.DataFrame

static type() → str

The type of the plugin, e.g.: prediction

plugin

alias of autoprognosis.plugins.prediction.risk_estimation.plugin_deephit.DeepHitRiskEstimationPlugin