{ "cells": [ { "cell_type": "markdown", "id": "molecular-moscow", "metadata": {}, "source": [ "# AutoPrognosis survival analysis\n", "\n", "Welcome! This tutorial walks you through selecting a model for a survival analysis task using AutoPrognosis." ] }, { "cell_type": "markdown", "id": "auburn-hygiene", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": null, "id": "wanted-point", "metadata": {}, "outputs": [], "source": [ "# stdlib\n", "import json\n", "import warnings\n", "\n", "# third party\n", "import pandas as pd\n", "from lifelines.datasets import load_rossi\n", "from sklearn.model_selection import train_test_split\n", "\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "id": "573e7cfc", "metadata": {}, "source": [ "### Import RiskEstimationStudy\n", "\n", "RiskEstimationStudy is the engine that automatically learns an ensemble of survival analysis pipelines and their hyperparameters." ] }, { "cell_type": "code", "execution_count": null, "id": "c1304c76", "metadata": {}, "outputs": [], "source": [ "# autoprognosis absolute\n", "from autoprognosis.studies.risk_estimation import RiskEstimationStudy" ] }, { "cell_type": "markdown", "id": "devoted-console", "metadata": {}, "source": [ "### Load the target dataset\n", "\n", "AutoPrognosis expects the input data as a `pandas.DataFrame`.\n", "\n", "For this example, we will use the [Rossi dataset](https://rdrr.io/cran/RcmdrPlugin.survival/man/Rossi.html), where `week` is the time to event and `arrest` is the event indicator."
] }, { "cell_type": "code", "execution_count": null, "id": "coated-innocent", "metadata": {}, "outputs": [], "source": [ "# third party\n", "from lifelines.datasets import load_rossi\n", "\n", "rossi = load_rossi()\n", "\n", "# Covariates, event indicator, and time to event\n", "X = rossi.drop([\"week\", \"arrest\"], axis=1)\n", "Y = rossi[\"arrest\"]\n", "T = rossi[\"week\"]\n", "\n", "# Evaluate at the quartiles of the observed event times\n", "eval_time_horizons = [\n", "    int(T[Y == 1].quantile(0.25)),\n", "    int(T[Y == 1].quantile(0.50)),\n", "    int(T[Y == 1].quantile(0.75)),\n", "]" ] }, { "cell_type": "markdown", "id": "refined-booth", "metadata": {}, "source": [ "### Create the risk estimation study\n", "\n", "AutoPrognosis provides default plugins, but it also lets you customize which plugins the pipelines use.\n", "\n", "You can see the supported plugins below:" ] }, { "cell_type": "code", "execution_count": null, "id": "1def3193", "metadata": {}, "outputs": [], "source": [ "# stdlib\n", "import json\n", "from pathlib import Path\n", "\n", "# autoprognosis absolute\n", "from autoprognosis.plugins import Plugins\n", "\n", "# List the available plugins\n", "print(json.dumps(Plugins().list_available(), indent=2))" ] }, { "cell_type": "markdown", "id": "94bfdf48", "metadata": {}, "source": [ "We will restrict the search to a few risk estimation plugins and create the risk estimation study." ] }, { "cell_type": "code", "execution_count": null, "id": "incident-familiar", "metadata": {}, "outputs": [], "source": [ "# Directory where the study stores its results\n", "workspace = Path(\"workspace\")\n", "workspace.mkdir(parents=True, exist_ok=True)\n", "\n", "study_name = \"test_risk_estimation_studies\"\n", "\n", "study = RiskEstimationStudy(\n", "    study_name=study_name,\n", "    dataset=rossi,\n", "    target=\"arrest\",\n", "    time_to_event=\"week\",\n", "    time_horizons=eval_time_horizons,\n", "    num_iter=10,  # DELETE THIS LINE FOR BETTER RESULTS. Number of Bayesian optimization (BO) iterations per estimator. Default: 50\n", "    num_study_iter=1,  # DELETE THIS LINE FOR BETTER RESULTS. Number of outer optimization iterations. 
Default: 5\n", "    risk_estimators=[\n", "        \"cox_ph\",\n", "        \"lognormal_aft\",\n", "        \"loglogistic_aft\",\n", "    ],  # DELETE THIS ARGUMENT FOR BETTER RESULTS.\n", "    workspace=workspace,\n", "    score_threshold=0.4,\n", ")" ] }, { "cell_type": "markdown", "id": "dominican-tulsa", "metadata": {}, "source": [ "### Search for the best ensemble\n", "\n", "Running the study searches for the best ensemble and saves it as `model.p` in the study's workspace folder." ] }, { "cell_type": "code", "execution_count": null, "id": "529fd74d", "metadata": {}, "outputs": [], "source": [ "study.run()" ] }, { "cell_type": "code", "execution_count": null, "id": "0677190e", "metadata": {}, "outputs": [], "source": [ "# stdlib\n", "import pprint\n", "\n", "# autoprognosis absolute\n", "from autoprognosis.utils.serialization import load_model_from_file\n", "from autoprognosis.utils.tester import evaluate_survival_estimator\n", "\n", "# Load the model saved by the study\n", "output = workspace / study_name / \"model.p\"\n", "\n", "model = load_model_from_file(output)\n", "\n", "# Evaluate the model at the selected time horizons\n", "metrics = evaluate_survival_estimator(model, X, T, Y, eval_time_horizons)\n", "\n", "print(f\"Model {model.name()}\")\n", "print(\"Score:\")\n", "\n", "pprint.pprint(metrics)" ] }, { "cell_type": "markdown", "id": "6652ec88", "metadata": {}, "source": [ "## Serialization" ] }, { "cell_type": "code", "execution_count": null, "id": "216d6efa", "metadata": {}, "outputs": [], "source": [ "# autoprognosis absolute\n", "from autoprognosis.utils.serialization import load_from_file, save_to_file\n", "\n", "out = workspace / \"tmp.bkp\"\n", "\n", "# Fit the model\n", "model.fit(X, T, Y)\n", "\n", "# Save\n", "save_to_file(out, model)\n", "\n", "# Reload\n", "loaded_model = load_from_file(out)\n", "\n", "print(loaded_model.name())\n", "\n", "assert loaded_model.name() == model.name()\n", "\n", "# Remove the temporary backup file\n", "out.unlink()" ] }, { "cell_type": "markdown", "id": "neural-dutch", "metadata": {}, "source": [ "## Congratulations!\n", "\n", "Congratulations on completing this notebook tutorial! 
If you enjoyed this and would like to join the movement towards machine learning and AI for medicine, you can do so in the following ways!\n", "\n", "### Star AutoPrognosis on GitHub\n", "\n", "The easiest way to help our community is to star the repos! This helps raise awareness of the tools we're building.\n", "\n", "- [Star AutoPrognosis](https://github.com/vanderschaarlab/autoprognosis)\n", "- [Star HyperImpute](https://github.com/vanderschaarlab/hyperimpute)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 5 }