author: leshe4ka46 <alex9102naid1@ya.ru>, 2025-10-18 12:25:53 +0300
committer: leshe4ka46 <alex9102naid1@ya.ru>, 2025-10-18 12:25:53 +0300
commit: 910a222fa60ce6ea0831f2956470b8a0b9f62670
tree: 1d6bbccafb667731ad127f93390761100fc11b53 /Fundamentals_of_Accelerated_Data_Science/3-04_logistic_regression.ipynb
parent: 35b9040e4104b0e79bf243a2c9769c589f96e2c4
nvidia2
Diffstat (limited to 'Fundamentals_of_Accelerated_Data_Science/3-04_logistic_regression.ipynb')
-rw-r--r-- Fundamentals_of_Accelerated_Data_Science/3-04_logistic_regression.ipynb | 1904
1 file changed, 1904 insertions, 0 deletions
diff --git a/Fundamentals_of_Accelerated_Data_Science/3-04_logistic_regression.ipynb b/Fundamentals_of_Accelerated_Data_Science/3-04_logistic_regression.ipynb
new file mode 100644
index 0000000..3890ea7
--- /dev/null
+++ b/Fundamentals_of_Accelerated_Data_Science/3-04_logistic_regression.ipynb
@@ -0,0 +1,1904 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "<img src=\"./images/DLI_Header.png\" width=400/>"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Fundamentals of Accelerated Data Science # "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 04 - Logistic Regression ##\n",
+ "\n",
+ "**Table of Contents**\n",
+ "<br>\n",
+ "This notebook uses GPU-accelerated logistic regression to predict infection risk based on features of our population members. It covers the following sections: \n",
+ "1. [Environment](#Environment)\n",
+ "2. [Load Data](#Load-Data)\n",
+ "3. [Logistic Regression](#Logistic-Regression)\n",
+ " * [Viewing the Regression](#Viewing-the-Regression)\n",
+ " * [Estimate Probability of Infection](#Estimate-Probability-of-Infection)\n",
+ "4. [Model Explainability](#Model-Explainability)\n",
+ " * [Show Infection Prevalence is Related to Age](#Show-Infection-Prevalence-is-Related-to-Age)\n",
+ " * [Exercise #1 - Show Infection Prevalence is Related to Sex](#Exercise-#1---Show-Infection-Prevalence-is-Related-to-Sex)\n",
+ "5. [Making Predictions with Separate Training and Test Data](#Making-Predictions-with-Separate-Training-and-Test-Data)\n",
+ " * [Exercise #2 - Fit Logistic Regression Model Using Training Data](#Exercise-#2---Fit-Logistic-Regression-Model-Using-Training-Data)\n",
+ " * [Use Test Data to Validate Model](#Use-Test-Data-to-Validate-Model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Environment ##"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import cudf\n",
+ "import cuml\n",
+ "\n",
+ "import cupy as cp"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Load Data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "gdf = cudf.read_csv('./data/clean_uk_pop_full.csv', usecols=['age', 'sex', 'infected'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "age float64\n",
+ "sex float64\n",
+ "infected float64\n",
+ "dtype: object"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "gdf.dtypes"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(58479894, 3)"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "gdf.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>age</th>\n",
+ " <th>sex</th>\n",
+ " <th>infected</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0</th>\n",
+ " <td>0.0</td>\n",
+ " <td>0.0</td>\n",
+ " <td>0.0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>1</th>\n",
+ " <td>0.0</td>\n",
+ " <td>0.0</td>\n",
+ " <td>0.0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>2</th>\n",
+ " <td>0.0</td>\n",
+ " <td>0.0</td>\n",
+ " <td>0.0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>3</th>\n",
+ " <td>0.0</td>\n",
+ " <td>0.0</td>\n",
+ " <td>0.0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>4</th>\n",
+ " <td>0.0</td>\n",
+ " <td>0.0</td>\n",
+ " <td>0.0</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " age sex infected\n",
+ "0 0.0 0.0 0.0\n",
+ "1 0.0 0.0 0.0\n",
+ "2 0.0 0.0 0.0\n",
+ "3 0.0 0.0 0.0\n",
+ "4 0.0 0.0 0.0"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "gdf.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Logistic Regression ##\n",
+ "Logistic regression can be used to estimate the probability of an outcome as a function of some (assumed independent) inputs. In our case, we would like to estimate infection risk based on population members' age and sex.\n",
+ "\n",
+ "Below we train a logistic regression model. We first create a cuML logistic regression instance `logreg`. The `logreg.fit` method takes two arguments: the model's independent variables *X*, and the dependent variable *y*. Fit the `logreg` model using the `gdf` columns `age` and `sex` as *X* and the `infected` column as *y*."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<style>#sk-container-id-1 {\n",
+ " /* Definition of color scheme common for light and dark mode */\n",
+ " --sklearn-color-text: black;\n",
+ " --sklearn-color-line: gray;\n",
+ " /* Definition of color scheme for unfitted estimators */\n",
+ " --sklearn-color-unfitted-level-0: #fff5e6;\n",
+ " --sklearn-color-unfitted-level-1: #f6e4d2;\n",
+ " --sklearn-color-unfitted-level-2: #ffe0b3;\n",
+ " --sklearn-color-unfitted-level-3: chocolate;\n",
+ " /* Definition of color scheme for fitted estimators */\n",
+ " --sklearn-color-fitted-level-0: #f0f8ff;\n",
+ " --sklearn-color-fitted-level-1: #d4ebff;\n",
+ " --sklearn-color-fitted-level-2: #b3dbfd;\n",
+ " --sklearn-color-fitted-level-3: cornflowerblue;\n",
+ "\n",
+ " /* Specific color for light theme */\n",
+ " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
+ " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
+ " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
+ " --sklearn-color-icon: #696969;\n",
+ "\n",
+ " @media (prefers-color-scheme: dark) {\n",
+ " /* Redefinition of color scheme for dark theme */\n",
+ " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
+ " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
+ " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
+ " --sklearn-color-icon: #878787;\n",
+ " }\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 {\n",
+ " color: var(--sklearn-color-text);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 pre {\n",
+ " padding: 0;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 input.sk-hidden--visually {\n",
+ " border: 0;\n",
+ " clip: rect(1px 1px 1px 1px);\n",
+ " clip: rect(1px, 1px, 1px, 1px);\n",
+ " height: 1px;\n",
+ " margin: -1px;\n",
+ " overflow: hidden;\n",
+ " padding: 0;\n",
+ " position: absolute;\n",
+ " width: 1px;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 div.sk-dashed-wrapped {\n",
+ " border: 1px dashed var(--sklearn-color-line);\n",
+ " margin: 0 0.4em 0.5em 0.4em;\n",
+ " box-sizing: border-box;\n",
+ " padding-bottom: 0.4em;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 div.sk-container {\n",
+ " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
+ " but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
+ " so we also need the `!important` here to be able to override the\n",
+ " default hidden behavior on the sphinx rendered scikit-learn.org.\n",
+ " See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
+ " display: inline-block !important;\n",
+ " position: relative;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 div.sk-text-repr-fallback {\n",
+ " display: none;\n",
+ "}\n",
+ "\n",
+ "div.sk-parallel-item,\n",
+ "div.sk-serial,\n",
+ "div.sk-item {\n",
+ " /* draw centered vertical line to link estimators */\n",
+ " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
+ " background-size: 2px 100%;\n",
+ " background-repeat: no-repeat;\n",
+ " background-position: center center;\n",
+ "}\n",
+ "\n",
+ "/* Parallel-specific style estimator block */\n",
+ "\n",
+ "#sk-container-id-1 div.sk-parallel-item::after {\n",
+ " content: \"\";\n",
+ " width: 100%;\n",
+ " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
+ " flex-grow: 1;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 div.sk-parallel {\n",
+ " display: flex;\n",
+ " align-items: stretch;\n",
+ " justify-content: center;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ " position: relative;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 div.sk-parallel-item {\n",
+ " display: flex;\n",
+ " flex-direction: column;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 div.sk-parallel-item:first-child::after {\n",
+ " align-self: flex-end;\n",
+ " width: 50%;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 div.sk-parallel-item:last-child::after {\n",
+ " align-self: flex-start;\n",
+ " width: 50%;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 div.sk-parallel-item:only-child::after {\n",
+ " width: 0;\n",
+ "}\n",
+ "\n",
+ "/* Serial-specific style estimator block */\n",
+ "\n",
+ "#sk-container-id-1 div.sk-serial {\n",
+ " display: flex;\n",
+ " flex-direction: column;\n",
+ " align-items: center;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ " padding-right: 1em;\n",
+ " padding-left: 1em;\n",
+ "}\n",
+ "\n",
+ "\n",
+ "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
+ "clickable and can be expanded/collapsed.\n",
+ "- Pipeline and ColumnTransformer use this feature and define the default style\n",
+ "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
+ "*/\n",
+ "\n",
+ "/* Pipeline and ColumnTransformer style (default) */\n",
+ "\n",
+ "#sk-container-id-1 div.sk-toggleable {\n",
+ " /* Default theme specific background. It is overwritten whether we have a\n",
+ " specific estimator or a Pipeline/ColumnTransformer */\n",
+ " background-color: var(--sklearn-color-background);\n",
+ "}\n",
+ "\n",
+ "/* Toggleable label */\n",
+ "#sk-container-id-1 label.sk-toggleable__label {\n",
+ " cursor: pointer;\n",
+ " display: block;\n",
+ " width: 100%;\n",
+ " margin-bottom: 0;\n",
+ " padding: 0.5em;\n",
+ " box-sizing: border-box;\n",
+ " text-align: center;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n",
+ " /* Arrow on the left of the label */\n",
+ " content: \"▸\";\n",
+ " float: left;\n",
+ " margin-right: 0.25em;\n",
+ " color: var(--sklearn-color-icon);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n",
+ " color: var(--sklearn-color-text);\n",
+ "}\n",
+ "\n",
+ "/* Toggleable content - dropdown */\n",
+ "\n",
+ "#sk-container-id-1 div.sk-toggleable__content {\n",
+ " max-height: 0;\n",
+ " max-width: 0;\n",
+ " overflow: hidden;\n",
+ " text-align: left;\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 div.sk-toggleable__content.fitted {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 div.sk-toggleable__content pre {\n",
+ " margin: 0.2em;\n",
+ " border-radius: 0.25em;\n",
+ " color: var(--sklearn-color-text);\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
+ " /* Expand drop-down */\n",
+ " max-height: 200px;\n",
+ " max-width: 100%;\n",
+ " overflow: auto;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
+ " content: \"▾\";\n",
+ "}\n",
+ "\n",
+ "/* Pipeline/ColumnTransformer-specific style */\n",
+ "\n",
+ "#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
+ " color: var(--sklearn-color-text);\n",
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
+ "}\n",
+ "\n",
+ "/* Estimator-specific style */\n",
+ "\n",
+ "/* Colorize estimator box */\n",
+ "#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n",
+ "#sk-container-id-1 div.sk-label label {\n",
+ " /* The background is the default theme color */\n",
+ " color: var(--sklearn-color-text-on-default-background);\n",
+ "}\n",
+ "\n",
+ "/* On hover, darken the color of the background */\n",
+ "#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n",
+ " color: var(--sklearn-color-text);\n",
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
+ "}\n",
+ "\n",
+ "/* Label box, darken color on hover, fitted */\n",
+ "#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
+ " color: var(--sklearn-color-text);\n",
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
+ "}\n",
+ "\n",
+ "/* Estimator label */\n",
+ "\n",
+ "#sk-container-id-1 div.sk-label label {\n",
+ " font-family: monospace;\n",
+ " font-weight: bold;\n",
+ " display: inline-block;\n",
+ " line-height: 1.2em;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 div.sk-label-container {\n",
+ " text-align: center;\n",
+ "}\n",
+ "\n",
+ "/* Estimator-specific */\n",
+ "#sk-container-id-1 div.sk-estimator {\n",
+ " font-family: monospace;\n",
+ " border: 1px dotted var(--sklearn-color-border-box);\n",
+ " border-radius: 0.25em;\n",
+ " box-sizing: border-box;\n",
+ " margin-bottom: 0.5em;\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 div.sk-estimator.fitted {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-0);\n",
+ "}\n",
+ "\n",
+ "/* on hover */\n",
+ "#sk-container-id-1 div.sk-estimator:hover {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 div.sk-estimator.fitted:hover {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
+ "}\n",
+ "\n",
+ "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
+ "\n",
+ "/* Common style for \"i\" and \"?\" */\n",
+ "\n",
+ ".sk-estimator-doc-link,\n",
+ "a:link.sk-estimator-doc-link,\n",
+ "a:visited.sk-estimator-doc-link {\n",
+ " float: right;\n",
+ " font-size: smaller;\n",
+ " line-height: 1em;\n",
+ " font-family: monospace;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ " border-radius: 1em;\n",
+ " height: 1em;\n",
+ " width: 1em;\n",
+ " text-decoration: none !important;\n",
+ " margin-left: 1ex;\n",
+ " /* unfitted */\n",
+ " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
+ " color: var(--sklearn-color-unfitted-level-1);\n",
+ "}\n",
+ "\n",
+ ".sk-estimator-doc-link.fitted,\n",
+ "a:link.sk-estimator-doc-link.fitted,\n",
+ "a:visited.sk-estimator-doc-link.fitted {\n",
+ " /* fitted */\n",
+ " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
+ " color: var(--sklearn-color-fitted-level-1);\n",
+ "}\n",
+ "\n",
+ "/* On hover */\n",
+ "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
+ ".sk-estimator-doc-link:hover,\n",
+ "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
+ ".sk-estimator-doc-link:hover {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-3);\n",
+ " color: var(--sklearn-color-background);\n",
+ " text-decoration: none;\n",
+ "}\n",
+ "\n",
+ "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
+ ".sk-estimator-doc-link.fitted:hover,\n",
+ "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
+ ".sk-estimator-doc-link.fitted:hover {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-3);\n",
+ " color: var(--sklearn-color-background);\n",
+ " text-decoration: none;\n",
+ "}\n",
+ "\n",
+ "/* Span, style for the box shown on hovering the info icon */\n",
+ ".sk-estimator-doc-link span {\n",
+ " display: none;\n",
+ " z-index: 9999;\n",
+ " position: relative;\n",
+ " font-weight: normal;\n",
+ " right: .2ex;\n",
+ " padding: .5ex;\n",
+ " margin: .5ex;\n",
+ " width: min-content;\n",
+ " min-width: 20ex;\n",
+ " max-width: 50ex;\n",
+ " color: var(--sklearn-color-text);\n",
+ " box-shadow: 2pt 2pt 4pt #999;\n",
+ " /* unfitted */\n",
+ " background: var(--sklearn-color-unfitted-level-0);\n",
+ " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
+ "}\n",
+ "\n",
+ ".sk-estimator-doc-link.fitted span {\n",
+ " /* fitted */\n",
+ " background: var(--sklearn-color-fitted-level-0);\n",
+ " border: var(--sklearn-color-fitted-level-3);\n",
+ "}\n",
+ "\n",
+ ".sk-estimator-doc-link:hover span {\n",
+ " display: block;\n",
+ "}\n",
+ "\n",
+ "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
+ "\n",
+ "#sk-container-id-1 a.estimator_doc_link {\n",
+ " float: right;\n",
+ " font-size: 1rem;\n",
+ " line-height: 1em;\n",
+ " font-family: monospace;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ " border-radius: 1rem;\n",
+ " height: 1rem;\n",
+ " width: 1rem;\n",
+ " text-decoration: none;\n",
+ " /* unfitted */\n",
+ " color: var(--sklearn-color-unfitted-level-1);\n",
+ " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 a.estimator_doc_link.fitted {\n",
+ " /* fitted */\n",
+ " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
+ " color: var(--sklearn-color-fitted-level-1);\n",
+ "}\n",
+ "\n",
+ "/* On hover */\n",
+ "#sk-container-id-1 a.estimator_doc_link:hover {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-3);\n",
+ " color: var(--sklearn-color-background);\n",
+ " text-decoration: none;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-3);\n",
+ "}\n",
+ "</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>LogisticRegression()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;LogisticRegression<span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>LogisticRegression()</pre></div> </div></div></div></div>"
+ ],
+ "text/plain": [
+ "LogisticRegression()"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "logreg = cuml.LogisticRegression()\n",
+ "logreg.fit(gdf[['age', 'sex']], gdf['infected'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Viewing the Regression ###\n",
+ "After fitting the model, we could use `logreg.predict` to estimate whether someone has more than a 50% chance of being infected. However, because the virus has low prevalence in the population (around 1-2% in this data set), individual probabilities of infection are well below 50%, and the model should correctly predict that no one is individually likely to have the infection.\n",
+ "\n",
+ "We also have access to the model coefficients at `logreg.coef_` and the intercept at `logreg.intercept_`. As the type checks below show, `logreg.coef_` is a cuDF DataFrame and `logreg.intercept_` is a cuDF Series. \n",
+ "\n",
+ "Below we view these values. Notice that changing sex from 0 to 1 has the same effect on the logit as increasing age by ~47 years (0.695666 / 0.014861)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "cudf.core.dataframe.DataFrame"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "type(logreg.coef_)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "cudf.core.series.Series"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "type(logreg.intercept_)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Coefficients: [age, sex]\n",
+ "[0 0.014861\n",
+ "Name: 0, dtype: float64, 0 0.695666\n",
+ "Name: 1, dtype: float64]\n",
+ "Intercept:\n",
+ "-5.222369426308725\n"
+ ]
+ }
+ ],
+ "source": [
+ "logreg_coef = logreg.coef_\n",
+ "logreg_int = logreg.intercept_\n",
+ "\n",
+ "print(\"Coefficients: [age, sex]\")\n",
+ "print([logreg_coef[0], logreg_coef[1]])\n",
+ "\n",
+ "print(\"Intercept:\")\n",
+ "print(logreg_int[0])"
+ ]
+ },
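+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To put the two coefficients on a common scale, we can divide the sex coefficient by the age coefficient. The result is the number of years of age that shifts the logit by as much as changing sex from 0 to 1. The cell below is a minimal sketch that is not part of the original notebook: it hard-codes the rounded values printed above, and the names `coef_age`, `coef_sex` and `years_equivalent` are introduced here for illustration only."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Rounded values copied from the coefficients printed above (assumed fixed for this sketch)\n",
+ "coef_age = 0.014861\n",
+ "coef_sex = 0.695666\n",
+ "\n",
+ "# Years of age that shift the logit as much as changing sex from 0 to 1\n",
+ "years_equivalent = coef_sex / coef_age\n",
+ "print(f\"{years_equivalent:.1f} years\")"
+ ]
+ },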
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Estimate Probability of Infection ###\n",
+ "As with all logistic regressions, the coefficients allow us to calculate the logit for each individual; from that, we can calculate the estimated probability of infection. \n",
+ "\n",
+ "**Note**: Remembering that a 1 indicates 'infected', we assign that class's probability to a new column in the original dataframe. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>0</th>\n",
+ " <th>1</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0</th>\n",
+ " <td>0.994634</td>\n",
+ " <td>0.005366</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>1</th>\n",
+ " <td>0.994634</td>\n",
+ " <td>0.005366</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>2</th>\n",
+ " <td>0.994634</td>\n",
+ " <td>0.005366</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>3</th>\n",
+ " <td>0.994634</td>\n",
+ " <td>0.005366</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>4</th>\n",
+ " <td>0.994634</td>\n",
+ " <td>0.005366</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>...</th>\n",
+ " <td>...</td>\n",
+ " <td>...</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>58479889</th>\n",
+ " <td>0.960428</td>\n",
+ " <td>0.039572</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>58479890</th>\n",
+ " <td>0.960428</td>\n",
+ " <td>0.039572</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>58479891</th>\n",
+ " <td>0.960428</td>\n",
+ " <td>0.039572</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>58479892</th>\n",
+ " <td>0.960428</td>\n",
+ " <td>0.039572</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>58479893</th>\n",
+ " <td>0.960428</td>\n",
+ " <td>0.039572</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "<p>58479894 rows × 2 columns</p>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " 0 1\n",
+ "0 0.994634 0.005366\n",
+ "1 0.994634 0.005366\n",
+ "2 0.994634 0.005366\n",
+ "3 0.994634 0.005366\n",
+ "4 0.994634 0.005366\n",
+ "... ... ...\n",
+ "58479889 0.960428 0.039572\n",
+ "58479890 0.960428 0.039572\n",
+ "58479891 0.960428 0.039572\n",
+ "58479892 0.960428 0.039572\n",
+ "58479893 0.960428 0.039572\n",
+ "\n",
+ "[58479894 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "class_probs = logreg.predict_proba(gdf[['age', 'sex']])\n",
+ "class_probs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "gdf['risk'] = class_probs[1]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Looking at the original records with their new estimated risks, we can see how estimated risk varies across individuals."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>age</th>\n",
+ " <th>sex</th>\n",
+ " <th>infected</th>\n",
+ " <th>risk</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>57475949</th>\n",
+ " <td>84.0</td>\n",
+ " <td>1.0</td>\n",
+ " <td>0.0</td>\n",
+ " <td>0.036319</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>43073547</th>\n",
+ " <td>39.0</td>\n",
+ " <td>1.0</td>\n",
+ " <td>0.0</td>\n",
+ " <td>0.018944</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>46406333</th>\n",
+ " <td>48.0</td>\n",
+ " <td>1.0</td>\n",
+ " <td>0.0</td>\n",
+ " <td>0.021596</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>17837831</th>\n",
+ " <td>48.0</td>\n",
+ " <td>0.0</td>\n",
+ " <td>0.0</td>\n",
+ " <td>0.010889</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>16785013</th>\n",
+ " <td>45.0</td>\n",
+ " <td>0.0</td>\n",
+ " <td>0.0</td>\n",
+ " <td>0.010419</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " age sex infected risk\n",
+ "57475949 84.0 1.0 0.0 0.036319\n",
+ "43073547 39.0 1.0 0.0 0.018944\n",
+ "46406333 48.0 1.0 0.0 0.021596\n",
+ "17837831 48.0 0.0 0.0 0.010889\n",
+ "16785013 45.0 0.0 0.0 0.010419"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "gdf.take(cp.random.choice(gdf.shape[0], size=5, replace=False))"
+ ]
+ },
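+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a rough cross-check, the same risk can be computed by hand: the logit is `intercept + coef_age * age + coef_sex * sex`, and applying the sigmoid to the logit gives the estimated probability. The cell below is a sketch that is not part of the original notebook. It hard-codes the rounded parameter values printed earlier, and `sigmoid` and `manual_risk` are hypothetical helper names. If the fit matches the recorded outputs, the results should land close to the `risk` values shown for the same `age` and `sex` inputs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import math\n",
+ "\n",
+ "# Rounded values copied from the printed coefficients and intercept (assumed fixed for this sketch)\n",
+ "COEF_AGE = 0.014861\n",
+ "COEF_SEX = 0.695666\n",
+ "INTERCEPT = -5.222369426308725\n",
+ "\n",
+ "def sigmoid(z):\n",
+ "    \"\"\"Map a logit (log-odds) to a probability in (0, 1).\"\"\"\n",
+ "    return 1.0 / (1.0 + math.exp(-z))\n",
+ "\n",
+ "def manual_risk(age, sex):\n",
+ "    \"\"\"Estimated infection probability from the hard-coded parameters.\"\"\"\n",
+ "    logit = INTERCEPT + COEF_AGE * age + COEF_SEX * sex\n",
+ "    return sigmoid(logit)\n",
+ "\n",
+ "# Compare with the `risk` column for rows that have the same age and sex\n",
+ "print(manual_risk(age=0.0, sex=0.0))\n",
+ "print(manual_risk(age=48.0, sex=0.0))\n",
+ "print(manual_risk(age=48.0, sex=1.0))"
+ ]
+ },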
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Model Explainability ##\n",
+ "Model explainability refers to the ability to understand and explain the decisions and reasoning underlying the predictions from machine learning models. It can be achieved by investigating how the feature variables are related to the target variable. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Show Infection Prevalence is Related to Age ###\n",
+ "The positive coefficient on age suggests that the virus is more prevalent in older people, even when controlling for sex.\n",
+ "\n",
+ "For this exercise, show that infection prevalence has some relationship to age by printing the mean `infected` values for the oldest and youngest members of the population when grouped by age:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " infected\n",
+ "age \n",
+ "66.0 0.020700\n",
+ "71.0 0.021292\n",
+ "64.0 0.020675\n",
+ "77.0 0.022102\n",
+ "82.0 0.022929\n",
+ " infected\n",
+ "age \n",
+ "33.0 0.015707\n",
+ "76.0 0.021928\n",
+ "74.0 0.021807\n",
+ "79.0 0.022518\n",
+ "86.0 0.023417\n"
+ ]
+ }
+ ],
+ "source": [
+ "# %load solutions/risk_by_age\n",
+ "age_groups = gdf[['age', 'infected']].groupby(['age'])\n",
+ "print(age_groups.mean().head())\n",
+ "print(age_groups.mean().tail())\n"
+ ]
+ },
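+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The output above is not ordered by age, so `head()` and `tail()` do not necessarily show the youngest and oldest groups. One way to make the comparison explicit is to sort the grouped means by their index before printing. The cell below is a sketch that is not part of the original solution; it assumes `gdf` is still loaded from the earlier cells, and `age_means` is a name introduced here for illustration."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Mean infection rate per age, sorted so the youngest ages come first\n",
+ "age_means = gdf[['age', 'infected']].groupby('age').mean().sort_index()\n",
+ "\n",
+ "print(age_means.head())  # youngest age groups\n",
+ "print(age_means.tail())  # oldest age groups"
+ ]
+ },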
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Exercise #1 - Show Infection Prevalence is Related to Sex ###\n",
+ "Similarly, the positive coefficient on sex suggests that the virus is more prevalent in people with sex = `1` (females), even when controlling for age.\n",
+ "\n",
+ "**Instructions**: <br>\n",
+ "* Modify the `<FIXME>` only and execute the cell below to show that infection prevalence has some relationship to sex by printing the mean `infected` values for the population when grouped by sex."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>infected</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>sex</th>\n",
+ " <th></th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0.0</th>\n",
+ " <td>0.010140</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>1.0</th>\n",
+ " <td>0.020713</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " infected\n",
+ "sex \n",
+ "0.0 0.010140\n",
+ "1.0 0.020713"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sex_groups = gdf[['sex', 'infected']].groupby(['sex'])\n",
+ "sex_groups.mean()"
+ ]
+ },
+ {
+ "cell_type": "raw",
+ "metadata": {
+ "scrolled": true
+ },
+ "source": [
+ "\n",
+ "sex_groups = gdf[['sex', 'infected']].groupby(['sex'])\n",
+ "sex_groups.mean()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Click ... for solution. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Making Predictions with Separate Training and Test Data ##\n",
+    "To judge how well a model generalizes, we split the data into a training set and a test set. We train the model on the training set only, then use the held-out test set to evaluate its performance. This gives a more realistic assessment of how the model will perform on new, unseen data in real-world applications. It also lets us detect **overfitting**, which occurs when a model performs well on its training data but poorly on new data. Since we often don't have access to truly new data, splitting the existing data simulates that scenario. \n",
+    "\n",
+    "cuML gives us a simple method for producing paired training/testing data. Here `train_size=0.9` assigns 90% of the rows to the training set and holds out the remaining 10% for testing:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "X_train, X_test, y_train, y_test = cuml.train_test_split(gdf[['age', 'sex']], gdf['infected'], train_size=0.9)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Exercise #2 - Fit Logistic Regression Model Using Training Data ###\n",
+ "\n",
+ "**Instructions**: <br>\n",
+    "* Execute the first cell below to create a new logistic regression model, `logreg`.\n",
+    "* In the second cell below, modify only the `<FIXME>`, then execute the cell to fit the new model with the *X* and *y* training data just created."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "logreg = cuml.LogisticRegression()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<style>#sk-container-id-3 {\n",
+ " /* Definition of color scheme common for light and dark mode */\n",
+ " --sklearn-color-text: black;\n",
+ " --sklearn-color-line: gray;\n",
+ " /* Definition of color scheme for unfitted estimators */\n",
+ " --sklearn-color-unfitted-level-0: #fff5e6;\n",
+ " --sklearn-color-unfitted-level-1: #f6e4d2;\n",
+ " --sklearn-color-unfitted-level-2: #ffe0b3;\n",
+ " --sklearn-color-unfitted-level-3: chocolate;\n",
+ " /* Definition of color scheme for fitted estimators */\n",
+ " --sklearn-color-fitted-level-0: #f0f8ff;\n",
+ " --sklearn-color-fitted-level-1: #d4ebff;\n",
+ " --sklearn-color-fitted-level-2: #b3dbfd;\n",
+ " --sklearn-color-fitted-level-3: cornflowerblue;\n",
+ "\n",
+ " /* Specific color for light theme */\n",
+ " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
+ " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
+ " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
+ " --sklearn-color-icon: #696969;\n",
+ "\n",
+ " @media (prefers-color-scheme: dark) {\n",
+ " /* Redefinition of color scheme for dark theme */\n",
+ " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
+ " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
+ " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
+ " --sklearn-color-icon: #878787;\n",
+ " }\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 {\n",
+ " color: var(--sklearn-color-text);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 pre {\n",
+ " padding: 0;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 input.sk-hidden--visually {\n",
+ " border: 0;\n",
+ " clip: rect(1px 1px 1px 1px);\n",
+ " clip: rect(1px, 1px, 1px, 1px);\n",
+ " height: 1px;\n",
+ " margin: -1px;\n",
+ " overflow: hidden;\n",
+ " padding: 0;\n",
+ " position: absolute;\n",
+ " width: 1px;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 div.sk-dashed-wrapped {\n",
+ " border: 1px dashed var(--sklearn-color-line);\n",
+ " margin: 0 0.4em 0.5em 0.4em;\n",
+ " box-sizing: border-box;\n",
+ " padding-bottom: 0.4em;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 div.sk-container {\n",
+ " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
+ " but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
+ " so we also need the `!important` here to be able to override the\n",
+ " default hidden behavior on the sphinx rendered scikit-learn.org.\n",
+ " See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
+ " display: inline-block !important;\n",
+ " position: relative;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 div.sk-text-repr-fallback {\n",
+ " display: none;\n",
+ "}\n",
+ "\n",
+ "div.sk-parallel-item,\n",
+ "div.sk-serial,\n",
+ "div.sk-item {\n",
+ " /* draw centered vertical line to link estimators */\n",
+ " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
+ " background-size: 2px 100%;\n",
+ " background-repeat: no-repeat;\n",
+ " background-position: center center;\n",
+ "}\n",
+ "\n",
+ "/* Parallel-specific style estimator block */\n",
+ "\n",
+ "#sk-container-id-3 div.sk-parallel-item::after {\n",
+ " content: \"\";\n",
+ " width: 100%;\n",
+ " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
+ " flex-grow: 1;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 div.sk-parallel {\n",
+ " display: flex;\n",
+ " align-items: stretch;\n",
+ " justify-content: center;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ " position: relative;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 div.sk-parallel-item {\n",
+ " display: flex;\n",
+ " flex-direction: column;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 div.sk-parallel-item:first-child::after {\n",
+ " align-self: flex-end;\n",
+ " width: 50%;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 div.sk-parallel-item:last-child::after {\n",
+ " align-self: flex-start;\n",
+ " width: 50%;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 div.sk-parallel-item:only-child::after {\n",
+ " width: 0;\n",
+ "}\n",
+ "\n",
+ "/* Serial-specific style estimator block */\n",
+ "\n",
+ "#sk-container-id-3 div.sk-serial {\n",
+ " display: flex;\n",
+ " flex-direction: column;\n",
+ " align-items: center;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ " padding-right: 1em;\n",
+ " padding-left: 1em;\n",
+ "}\n",
+ "\n",
+ "\n",
+ "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
+ "clickable and can be expanded/collapsed.\n",
+ "- Pipeline and ColumnTransformer use this feature and define the default style\n",
+ "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
+ "*/\n",
+ "\n",
+ "/* Pipeline and ColumnTransformer style (default) */\n",
+ "\n",
+ "#sk-container-id-3 div.sk-toggleable {\n",
+ " /* Default theme specific background. It is overwritten whether we have a\n",
+ " specific estimator or a Pipeline/ColumnTransformer */\n",
+ " background-color: var(--sklearn-color-background);\n",
+ "}\n",
+ "\n",
+ "/* Toggleable label */\n",
+ "#sk-container-id-3 label.sk-toggleable__label {\n",
+ " cursor: pointer;\n",
+ " display: block;\n",
+ " width: 100%;\n",
+ " margin-bottom: 0;\n",
+ " padding: 0.5em;\n",
+ " box-sizing: border-box;\n",
+ " text-align: center;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 label.sk-toggleable__label-arrow:before {\n",
+ " /* Arrow on the left of the label */\n",
+ " content: \"▸\";\n",
+ " float: left;\n",
+ " margin-right: 0.25em;\n",
+ " color: var(--sklearn-color-icon);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 label.sk-toggleable__label-arrow:hover:before {\n",
+ " color: var(--sklearn-color-text);\n",
+ "}\n",
+ "\n",
+ "/* Toggleable content - dropdown */\n",
+ "\n",
+ "#sk-container-id-3 div.sk-toggleable__content {\n",
+ " max-height: 0;\n",
+ " max-width: 0;\n",
+ " overflow: hidden;\n",
+ " text-align: left;\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 div.sk-toggleable__content.fitted {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 div.sk-toggleable__content pre {\n",
+ " margin: 0.2em;\n",
+ " border-radius: 0.25em;\n",
+ " color: var(--sklearn-color-text);\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 div.sk-toggleable__content.fitted pre {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
+ " /* Expand drop-down */\n",
+ " max-height: 200px;\n",
+ " max-width: 100%;\n",
+ " overflow: auto;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
+ " content: \"▾\";\n",
+ "}\n",
+ "\n",
+ "/* Pipeline/ColumnTransformer-specific style */\n",
+ "\n",
+ "#sk-container-id-3 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
+ " color: var(--sklearn-color-text);\n",
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
+ "}\n",
+ "\n",
+ "/* Estimator-specific style */\n",
+ "\n",
+ "/* Colorize estimator box */\n",
+ "#sk-container-id-3 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 div.sk-label label.sk-toggleable__label,\n",
+ "#sk-container-id-3 div.sk-label label {\n",
+ " /* The background is the default theme color */\n",
+ " color: var(--sklearn-color-text-on-default-background);\n",
+ "}\n",
+ "\n",
+ "/* On hover, darken the color of the background */\n",
+ "#sk-container-id-3 div.sk-label:hover label.sk-toggleable__label {\n",
+ " color: var(--sklearn-color-text);\n",
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
+ "}\n",
+ "\n",
+ "/* Label box, darken color on hover, fitted */\n",
+ "#sk-container-id-3 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
+ " color: var(--sklearn-color-text);\n",
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
+ "}\n",
+ "\n",
+ "/* Estimator label */\n",
+ "\n",
+ "#sk-container-id-3 div.sk-label label {\n",
+ " font-family: monospace;\n",
+ " font-weight: bold;\n",
+ " display: inline-block;\n",
+ " line-height: 1.2em;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 div.sk-label-container {\n",
+ " text-align: center;\n",
+ "}\n",
+ "\n",
+ "/* Estimator-specific */\n",
+ "#sk-container-id-3 div.sk-estimator {\n",
+ " font-family: monospace;\n",
+ " border: 1px dotted var(--sklearn-color-border-box);\n",
+ " border-radius: 0.25em;\n",
+ " box-sizing: border-box;\n",
+ " margin-bottom: 0.5em;\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-0);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 div.sk-estimator.fitted {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-0);\n",
+ "}\n",
+ "\n",
+ "/* on hover */\n",
+ "#sk-container-id-3 div.sk-estimator:hover {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-2);\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 div.sk-estimator.fitted:hover {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-2);\n",
+ "}\n",
+ "\n",
+ "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
+ "\n",
+ "/* Common style for \"i\" and \"?\" */\n",
+ "\n",
+ ".sk-estimator-doc-link,\n",
+ "a:link.sk-estimator-doc-link,\n",
+ "a:visited.sk-estimator-doc-link {\n",
+ " float: right;\n",
+ " font-size: smaller;\n",
+ " line-height: 1em;\n",
+ " font-family: monospace;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ " border-radius: 1em;\n",
+ " height: 1em;\n",
+ " width: 1em;\n",
+ " text-decoration: none !important;\n",
+ " margin-left: 1ex;\n",
+ " /* unfitted */\n",
+ " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
+ " color: var(--sklearn-color-unfitted-level-1);\n",
+ "}\n",
+ "\n",
+ ".sk-estimator-doc-link.fitted,\n",
+ "a:link.sk-estimator-doc-link.fitted,\n",
+ "a:visited.sk-estimator-doc-link.fitted {\n",
+ " /* fitted */\n",
+ " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
+ " color: var(--sklearn-color-fitted-level-1);\n",
+ "}\n",
+ "\n",
+ "/* On hover */\n",
+ "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
+ ".sk-estimator-doc-link:hover,\n",
+ "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
+ ".sk-estimator-doc-link:hover {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-3);\n",
+ " color: var(--sklearn-color-background);\n",
+ " text-decoration: none;\n",
+ "}\n",
+ "\n",
+ "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
+ ".sk-estimator-doc-link.fitted:hover,\n",
+ "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
+ ".sk-estimator-doc-link.fitted:hover {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-3);\n",
+ " color: var(--sklearn-color-background);\n",
+ " text-decoration: none;\n",
+ "}\n",
+ "\n",
+ "/* Span, style for the box shown on hovering the info icon */\n",
+ ".sk-estimator-doc-link span {\n",
+ " display: none;\n",
+ " z-index: 9999;\n",
+ " position: relative;\n",
+ " font-weight: normal;\n",
+ " right: .2ex;\n",
+ " padding: .5ex;\n",
+ " margin: .5ex;\n",
+ " width: min-content;\n",
+ " min-width: 20ex;\n",
+ " max-width: 50ex;\n",
+ " color: var(--sklearn-color-text);\n",
+ " box-shadow: 2pt 2pt 4pt #999;\n",
+ " /* unfitted */\n",
+ " background: var(--sklearn-color-unfitted-level-0);\n",
+ " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
+ "}\n",
+ "\n",
+ ".sk-estimator-doc-link.fitted span {\n",
+ " /* fitted */\n",
+ " background: var(--sklearn-color-fitted-level-0);\n",
+ " border: var(--sklearn-color-fitted-level-3);\n",
+ "}\n",
+ "\n",
+ ".sk-estimator-doc-link:hover span {\n",
+ " display: block;\n",
+ "}\n",
+ "\n",
+ "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
+ "\n",
+ "#sk-container-id-3 a.estimator_doc_link {\n",
+ " float: right;\n",
+ " font-size: 1rem;\n",
+ " line-height: 1em;\n",
+ " font-family: monospace;\n",
+ " background-color: var(--sklearn-color-background);\n",
+ " border-radius: 1rem;\n",
+ " height: 1rem;\n",
+ " width: 1rem;\n",
+ " text-decoration: none;\n",
+ " /* unfitted */\n",
+ " color: var(--sklearn-color-unfitted-level-1);\n",
+ " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 a.estimator_doc_link.fitted {\n",
+ " /* fitted */\n",
+ " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
+ " color: var(--sklearn-color-fitted-level-1);\n",
+ "}\n",
+ "\n",
+ "/* On hover */\n",
+ "#sk-container-id-3 a.estimator_doc_link:hover {\n",
+ " /* unfitted */\n",
+ " background-color: var(--sklearn-color-unfitted-level-3);\n",
+ " color: var(--sklearn-color-background);\n",
+ " text-decoration: none;\n",
+ "}\n",
+ "\n",
+ "#sk-container-id-3 a.estimator_doc_link.fitted:hover {\n",
+ " /* fitted */\n",
+ " background-color: var(--sklearn-color-fitted-level-3);\n",
+ "}\n",
+ "</style><div id=\"sk-container-id-3\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>LogisticRegression()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" checked><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;LogisticRegression<span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>LogisticRegression()</pre></div> </div></div></div></div>"
+ ],
+ "text/plain": [
+ "LogisticRegression()"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "logreg.fit(X_train, y_train)"
+ ]
+ },
+ {
+ "cell_type": "raw",
+ "metadata": {},
+ "source": [
+ "\n",
+ "logreg.fit(X_train, y_train)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Click ... for solution. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Use Test Data to Validate Model ###\n",
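+    "We can now use the same procedure as above to predict infection risk for the test data. `predict_proba` returns one column of probabilities per class, so we select column `1`, the estimated probability of being infected:"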
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "16557172 0.010267\n",
+ "21666454 0.012426\n",
+ "25019343 0.014598\n",
+ "32613710 0.012409\n",
+ "5458911 0.006700\n",
+ " ... \n",
+ "29193786 0.010716\n",
+ "10832641 0.008235\n",
+ "50674662 0.024982\n",
+ "15628357 0.009970\n",
+ "44635132 0.020095\n",
+ "Name: 1, Length: 5847990, dtype: float64"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y_test_pred = logreg.predict_proba(X_test, convert_dtype=True)[1]\n",
+ "y_test_pred.index = X_test.index\n",
+ "y_test_pred"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "As we saw before, very few people in the population are actually infected, even among the highest-risk groups. As a simple check of the model, we split the test set into people with above-average and below-average predicted risk, then compare the observed infection rate in each group with that group's mean predicted risk. If the model fits well, the two values should be close."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>age</th>\n",
+ " <th>sex</th>\n",
+ " <th>infected</th>\n",
+ " <th>predicted_risk</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>high_risk</th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>False</th>\n",
+ " <td>29.536875</td>\n",
+ " <td>0.252218</td>\n",
+ " <td>0.009992</td>\n",
+ " <td>0.010328</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>True</th>\n",
+ " <td>56.187385</td>\n",
+ " <td>0.889923</td>\n",
+ " <td>0.023676</td>\n",
+ " <td>0.023323</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " age sex infected predicted_risk\n",
+ "high_risk \n",
+ "False 29.536875 0.252218 0.009992 0.010328\n",
+ "True 56.187385 0.889923 0.023676 0.023323"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "test_results = cudf.DataFrame()\n",
+ "test_results['age'] = X_test['age']\n",
+ "test_results['sex'] = X_test['sex']\n",
+ "test_results['infected'] = y_test\n",
+ "test_results['predicted_risk'] = y_test_pred\n",
+ "\n",
+ "test_results['high_risk'] = test_results['predicted_risk'] > test_results['predicted_risk'].mean()\n",
+ "\n",
+ "risk_groups = test_results.groupby('high_risk')\n",
+ "risk_groups.mean()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "Finally, in a fraction of a second, we can do a two-tier analysis, grouping by sex and age:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 310 ms, sys: 56.2 ms, total: 367 ms\n",
+ "Wall time: 366 ms\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th>infected</th>\n",
+ " <th>predicted_risk</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>sex</th>\n",
+ " <th>age</th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th rowspan=\"3\" valign=\"top\">0.0</th>\n",
+ " <th>6.0</th>\n",
+ " <td>0.003474</td>\n",
+ " <td>0.005867</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>1.0</th>\n",
+ " <td>0.000764</td>\n",
+ " <td>0.005449</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>14.0</th>\n",
+ " <td>0.006123</td>\n",
+ " <td>0.006602</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>1.0</th>\n",
+ " <th>5.0</th>\n",
+ " <td>0.006264</td>\n",
+ " <td>0.011532</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th rowspan=\"6\" valign=\"top\">0.0</th>\n",
+ " <th>60.0</th>\n",
+ " <td>0.013377</td>\n",
+ " <td>0.012984</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>...</th>\n",
+ " <td>...</td>\n",
+ " <td>...</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>35.0</th>\n",
+ " <td>0.010341</td>\n",
+ " <td>0.008995</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>13.0</th>\n",
+ " <td>0.005650</td>\n",
+ " <td>0.006505</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>32.0</th>\n",
+ " <td>0.010420</td>\n",
+ " <td>0.008606</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>19.0</th>\n",
+ " <td>0.007208</td>\n",
+ " <td>0.007107</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>1.0</th>\n",
+ " <th>12.0</th>\n",
+ " <td>0.010666</td>\n",
+ " <td>0.012778</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "<p>182 rows × 2 columns</p>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " infected predicted_risk\n",
+ "sex age \n",
+ "0.0 6.0 0.003474 0.005867\n",
+ " 1.0 0.000764 0.005449\n",
+ " 14.0 0.006123 0.006602\n",
+ "1.0 5.0 0.006264 0.011532\n",
+ "0.0 60.0 0.013377 0.012984\n",
+ "... ... ...\n",
+ " 35.0 0.010341 0.008995\n",
+ " 13.0 0.005650 0.006505\n",
+ " 32.0 0.010420 0.008606\n",
+ " 19.0 0.007208 0.007107\n",
+ "1.0 12.0 0.010666 0.012778\n",
+ "\n",
+ "[182 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "s_groups = test_results[['sex', 'age', 'infected', 'predicted_risk']].groupby(['sex', 'age'])\n",
+ "s_groups.mean()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "# Restart the kernel to free the GPU memory held by this notebook\n",
+    "import IPython\n",
+    "app = IPython.Application.instance()\n",
+    "app.kernel.do_shutdown(True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Well Done!** Let's move to the [next notebook](3-05_knn.ipynb). "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "<img src=\"./images/DLI_Header.png\" width=400/>"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.15"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}