diff --git a/.gitignore b/.gitignore index 7c4d72e..a1a53a2 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ __pycache__/ *.csv *.geojson jupytext.toml +private/ # C extensions *.so diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2d27e0e..82e5746 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,32 +1,32 @@ repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 - hooks: - - id: check-toml - - id: check-yaml - - id: end-of-file-fixer - types: [python] - - id: trailing-whitespace - - id: requirements-txt-fixer - - id: check-added-large-files - args: ["--maxkb=500"] + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + types: [python] + - id: trailing-whitespace + - id: requirements-txt-fixer + - id: check-added-large-files + args: ["--maxkb=500"] - - repo: https://github.com/psf/black - rev: 25.1.0 - hooks: - - id: black-jupyter + - repo: https://github.com/psf/black + rev: 25.1.0 + hooks: + - id: black-jupyter - - repo: https://github.com/codespell-project/codespell - rev: v2.4.1 - hooks: - - id: codespell - args: - [ - "--ignore-words-list=aci,acount,acounts,fallow,ges,hart,hist,nd,ned,ois,wqs,watermask,tre", - "--skip=*.csv,*.geojson,*.json,*.yml*.js,*.html,*cff,*.pdf", - ] + - repo: https://github.com/codespell-project/codespell + rev: v2.4.1 + hooks: + - id: codespell + args: + [ + "--ignore-words-list=aci,acount,acounts,fallow,ges,hart,hist,nd,ned,ois,wqs,watermask,tre,mape", + "--skip=*.csv,*.geojson,*.json,*.yml*.js,*.html,*cff,*.pdf", + ] - - repo: https://github.com/kynan/nbstripout - rev: 0.8.1 - hooks: - - id: nbstripout + - repo: https://github.com/kynan/nbstripout + rev: 0.8.1 + hooks: + - id: nbstripout diff --git a/docs/PDFM/pdfm_superresolution.ipynb b/docs/PDFM/pdfm_superresolution.ipynb new file mode 100644 index 0000000..643134e --- /dev/null +++ b/docs/PDFM/pdfm_superresolution.ipynb @@ -0,0 +1,824 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "TVZlb_WyecAl" + }, + "source": [ + "# PDFM Super Resolution with Zillow Housing Data\n", + "\n", + "This notebook is adapted from the PDFM notebook example [here](https://github.com/google-research/population-dynamics/blob/master/notebooks/pdfm_superresolution_and_imputation.ipynb). Credits to the original authors." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YHgjY5eb7TBE" + }, + "source": [ + "## Data Preparation" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VzNXfZyVS5wY" + }, + "source": [ + "### Step 1: Download a csv file of the embeddings using this [link](https://docs.google.com/forms/d/e/1FAIpQLSeZLIqTCIx1-OiBzUnqXZpu_k5M223ZvMmqwQhMZ_0TkaWhEQ/viewform).\n", + "\n", + "The county and ZCTA (zipcode census tabulation area) embeddings are available in different files.\n", + "\n", + "Here we assume that you have obtained the embeddings and uploaded them to a Google Drive directory called `pdfm_embeddings/v0/us`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Ky58xlzJfPJ-" + }, + "outputs": [], + "source": [ + "import math\n", + "import pandas as pd\n", + "import numpy as np\n", + "import geopandas as gpd\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.linear_model import Ridge\n", + "from sklearn import metrics as skmetrics\n", + "import lightgbm as lgbm\n", + "from sklearn.pipeline import make_pipeline\n", + "from sklearn import preprocessing\n", + "import seaborn as sns" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "lC7wR5HD5RVR", + "outputId": "9177e012-12e8-474d-aef3-0e35e3adcab4" + }, + "outputs": [], + "source": [ + "BASE_PATH = \"./\" # Set this to the path where your data files are located\n", + "\n", + "county_embeddings = pd.read_csv(BASE_PATH + \"county_embeddings.csv\").set_index(\"place\")\n", + "zip_embeddings = pd.read_csv(BASE_PATH + \"zcta_embeddings.csv\").set_index(\"place\")\n", + "embeddings = pd.concat([county_embeddings, zip_embeddings])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "V_4Lt8Yl5vDG", + "outputId": "acfb79cc-a2b3-4577-fb90-c66e964fdb98" + }, + "outputs": [], + "source": [ + "embedding_features = [f\"feature{x}\" for x in range(330)]\n", + "embeddings.head(2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "V15mXezNgezz" + }, + "source": [ + "### Step 2: Download and load a few variables from GitHub" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "zhvi_county_url = \"https://github.com/opengeos/datasets/releases/download/us/zillow_home_value_index_by_county.csv\"\n", + "zhvi_zipcode_url = \"https://github.com/opengeos/datasets/releases/download/us/zillow_home_value_index_by_zipcode.csv\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "zhvi_county_df = pd.read_csv(\n", + " zhvi_county_url, dtype={\"StateCodeFIPS\": \"string\", \"MunicipalCodeFIPS\": \"string\"}\n", + ")\n", + "zhvi_county_df[\"place\"] = (\n", + " \"geoId/\" + zhvi_county_df[\"StateCodeFIPS\"] + zhvi_county_df[\"MunicipalCodeFIPS\"]\n", + ")\n", + "zhvi_county_df = zhvi_county_df.set_index(\"place\")\n", + "zhvi_county_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "zhvi_zipcode_df = pd.read_csv(zhvi_zipcode_url, dtype={\"RegionName\": \"string\"})\n", + "zhvi_zipcode_df[\"place\"] = zhvi_zipcode_df[\"RegionName\"].apply(lambda x: f\"zip/{x}\")\n", + "zhvi_zipcode_df = zhvi_zipcode_df.set_index(\"place\")\n", + "zhvi_zipcode_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "zhvi_df = pd.concat([zhvi_county_df, zhvi_zipcode_df])\n", + "zhvi_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "zhvi_df[-5:]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "len(zhvi_df)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = embeddings.join(zhvi_df, how=\"inner\")\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df[-3:]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OOnmR8lhg8VF" + }, + "source": [ + "## Data Visualizations" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kOI8jZyCG33q" + }, + "source": [ + "### Download the county and zcta (Zipcode census tabulation area) level geojson file.\n", + "\n", + "The county and zcta level geojson file are available in the same folder as the embeddings. Download the geojson file and upload to Google Colab." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WvfGKA4kGTCj" + }, + "outputs": [], + "source": [ + "county_geo = gpd.read_file(BASE_PATH + \"county.geojson\").set_index(\"place\")\n", + "zip_geo = gpd.read_file(BASE_PATH + \"zcta.geojson\").set_index(\"place\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "aa4yagcBHo-J", + "outputId": "17513a0e-ea90-4103-8caa-4af339fa94cd" + }, + "outputs": [], + "source": [ + "geo = pd.concat([county_geo, zip_geo])\n", + "embeddings = gpd.GeoDataFrame(embeddings, geometry=geo.geometry)\n", + "embeddings.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = embeddings.join(zhvi_df).set_geometry(\"geometry\")\n", + "df.head(1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df[\"county_id\"] = df[\"county\"] + df[\"state\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z1voPA0zLJ0I" + }, + "source": [ + "### Map out an embedding dimension spatially" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BHctv_WXiNcS" + }, + "outputs": [], + "source": [ + "def get_locale(df, index, states=None, counties=None):\n", + " df = df[df.index.isin(index)]\n", + " if not states and not counties:\n", + " return df\n", + " filter = df.state.isin(states)\n", + " if counties:\n", + " filter &= df.county.isin(counties)\n", + " return df[filter]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yaZ2_936KTQg", + "outputId": "5bb24978-921c-425a-a788-5a0cabbb90f1" + }, + "outputs": [], + "source": [ + "# @title Map out an embedding dimension feature0 spatially across all counties in US\n", + "feature = embedding_features[300]\n", + "ax = get_locale(embeddings, embeddings.index).plot(feature)\n", + "_ = ax.set_title(feature + \" in counties\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7WQHnSLyKxMu", + "outputId": "2651cf5f-b231-4f02-c3f9-473cc58c85f7" + }, + "outputs": [], + "source": [ + "# @title Map out an embedding dimension feature0 spatially across all counties and zipcodes in NY state\n", + "fig, ax = plt.subplots(1, 2, figsize=(8, 4))\n", + "state = \"NY\"\n", + "get_locale(embeddings, county_embeddings.index, states=[state]).plot(feature, ax=ax[0])\n", + "get_locale(embeddings, zip_embeddings.index, states=[state]).plot(feature, ax=ax[1])\n", + "fig.suptitle(f\"{feature} in {state}\")\n", + "ax[0].set(title=\"counties\")\n", + "ax[1].set(title=\"zip codes\")\n", + "plt.setp(ax, xticks=[], yticks=[])\n", + "fig.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6jGB-wETlYFX" + }, + "source": [ + "## Applying the embeddings in a prediction task" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "jhVnsHRUnmAc" + }, + "outputs": [], + "source": [ + "def evaluate(df: pd.DataFrame) -> dict:\n", + " \"\"\"Evaluates the model performance on the given dataframe.\n", + "\n", + " Args:\n", + " df: A pandas DataFrame with columns 'y' and 'y_pred'.\n", + "\n", + " Returns:\n", + " A dictionary of performance metrics.\n", + " \"\"\"\n", + " # Ensure necessary columns exist and drop rows with NaN or zero in 'y'\n", + " if not {\"y\", \"y_pred\"}.issubset(df.columns):\n", + " raise ValueError(\"DataFrame must contain 'y' and 'y_pred' columns\")\n", + "\n", + " df = df.dropna(subset=[\"y\", \"y_pred\"])\n", + " df = df[df[\"y\"] != 0]\n", + "\n", + " r2 = skmetrics.r2_score(df[\"y\"], df[\"y_pred\"])\n", + " correlation = float(df[\"y\"].corr(df[\"y_pred\"]))\n", + " rmse = math.sqrt(skmetrics.mean_squared_error(df[\"y\"], df[\"y_pred\"]))\n", + " mae = float(skmetrics.mean_absolute_error(df[\"y\"], df[\"y_pred\"]))\n", + " mape = float(skmetrics.mean_absolute_percentage_error(df[\"y\"], df[\"y_pred\"]))\n", + "\n", + " return {\n", + " \"r2\": r2,\n", + " \"rmse\": rmse,\n", + " \"mae\": mae,\n", + " \"mape\": mape,\n", + " \"correlation\": correlation,\n", + " }\n", + "\n", + "\n", + "def subset_eval(\n", + " label: str,\n", + " county_name: str,\n", + " state: str,\n", + " gpred: gpd.GeoDataFrame,\n", + " visualize: bool = True,\n", + " cmap: str = \"Greys\",\n", + ") -> dict:\n", + " \"\"\"Runs intra-county or intra-state evaluation and visualizes the results.\n", + "\n", + " Args:\n", + " label: The label for the title of the visualization.\n", + " county_name: The specific county name to filter.\n", + " state: The specific state name to filter.\n", + " gpred: GeoDataFrame containing 'y', 'y_pred', 'state', and 'county' columns.\n", + " visualize: Whether to display visualizations.\n", + " cmap: Colormap for visualizations.\n", + "\n", + " Returns:\n", + " A dictionary of performance metrics.\n", + " \"\"\"\n", + " # Apply filters based on state and county name\n", + " subset = gpred.copy()\n", + " if state:\n", + " subset = subset[subset[\"state\"] == state]\n", + " if county_name:\n", + " subset = subset[subset[\"county\"] == county_name]\n", + "\n", + " # Drop rows where 'y' is NaN\n", + " subset = subset.dropna(subset=[\"y\", \"y_pred\"])\n", + " eval_metrics = evaluate(subset)\n", + "\n", + " if visualize:\n", + " _, ax = plt.subplots(1, 3, figsize=(12, 4))\n", + "\n", + " # Scatter plot of predicted vs actual\n", + " subset.plot.scatter(\"y\", \"y_pred\", alpha=0.8, ax=ax[2], color=\"darkgray\")\n", + " x0, x1 = (\n", + " subset[[\"y\", \"y_pred\"]].min().min(),\n", + " subset[[\"y\", \"y_pred\"]].max().max(),\n", + " )\n", + " ax[2].plot([x0, x1], [x0, x1], ls=\"--\", color=\"black\")\n", + " ax[2].set_title(\n", + " f'r={eval_metrics[\"correlation\"]:.2f}, mae={eval_metrics[\"mae\"]:.2f}'\n", + " )\n", + "\n", + " # Maps of actual and predicted values\n", + " subset.plot(\n", + " \"y\",\n", + " legend=True,\n", + " ax=ax[0],\n", + " vmin=x0,\n", + " vmax=x1,\n", + " cmap=cmap,\n", + " legend_kwds={\"fraction\": 0.02, \"pad\": 0.05},\n", + " )\n", + " ax[0].set_title(\"Actual\")\n", + " subset.plot(\"y_pred\", legend=False, ax=ax[1], vmin=x0, vmax=x1, cmap=cmap)\n", + " ax[1].set_title(\"Predicted\")\n", + "\n", + " plt.setp(ax[:2], xticks=[], yticks=[])\n", + " plt.suptitle(f\"{label} - {county_name}, {state}\")\n", + " plt.tight_layout()\n", + "\n", + " return eval_metrics\n", + "\n", + "\n", + "def make_predictions_df(\n", + " predictions: np.ndarray, test_df: gpd.GeoDataFrame, label: str\n", + ") -> gpd.GeoDataFrame:\n", + " \"\"\"Creates a GeoDataFrame with predictions, true labels, and geographic info.\n", + "\n", + " Args:\n", + " predictions: A sequence of predictions.\n", + " test_df: The original test GeoDataFrame that the predictions are based on.\n", + " label: The column name for the true label in `test_df`.\n", + "\n", + " Returns:\n", + " A GeoDataFrame for evaluation and visualizations.\n", + " \"\"\"\n", + " if label not in test_df.columns:\n", + " raise ValueError(\n", + " f\"The specified label '{label}' does not exist in test_df columns.\"\n", + " )\n", + "\n", + " df_predictions = pd.DataFrame(\n", + " {\"y\": test_df[label], \"y_pred\": predictions}, index=test_df.index\n", + " )\n", + " return test_df[[\"geometry\", \"state\", \"county\"]].join(df_predictions)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qQ55ASjJlgNL" + }, + "source": [ + "## Superresolution - Train the model on counties and make predictions for zip code\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# @title Train on counties and predict for zip codes\n", + "label = \"2025-01-31\"\n", + "data = df[df[label].notna()]\n", + "train = data[data.index.isin(county_geo.index)]\n", + "test = data[data.index.isin(zip_geo.index)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "len(train), len(test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Vey-QWY7lOvA", + "outputId": "e15f9340-d156-458c-e27b-f7b4bf40cb9c" + }, + "outputs": [], + "source": [ + "model = Ridge()\n", + "model.fit(train[embedding_features], train[label])\n", + "predictions = model.predict(test[embedding_features])\n", + "gdf_predictions = make_predictions_df(predictions, test, label)\n", + "evaluate(gdf_predictions)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 713 + }, + "id": "MG8RTdsjpAuL", + "outputId": "5a4654f5-ee14-44c0-d91b-f88268ff091f" + }, + "outputs": [], + "source": [ + "# @title Visualize some test set predictions\n", + "_ = subset_eval(label, \"Harris County\", \"TX\", gdf_predictions, cmap=\"Blues\")\n", + "_ = subset_eval(label, \"Greenville County\", \"SC\", gdf_predictions, cmap=\"Blues\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 365 + }, + "id": "02C2CbCdp1l0", + "outputId": "99aeadb5-a5c8-49b1-9332-48b15b08f42c" + }, + "outputs": [], + "source": [ + "# @title Evaluate over a state by setting the county to an empty string.\n", + "_ = subset_eval(label, \"\", \"NY\", gdf_predictions, cmap=\"Blues\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uohbUjMZyD63" + }, + "source": [ + "## Imputation - zip -> zip\n", + "\n", + "Train on zipcodes in a subset of counties." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "csYlgZKLqFoZ", + "outputId": "5dc400f3-afa4-4da9-ad74-03771a3dae36" + }, + "outputs": [], + "source": [ + "# @title train on zip codes in 20% of the counties, test on the remaining 80%.\n", + "\n", + "\n", + "def get_train_test_split(training_fraction=0.8):\n", + " data = df[df.index.isin(zip_embeddings.index)].copy()\n", + " # Split the zip codes by county into train/test sets.\n", + " train_counties = (\n", + " data.drop_duplicates(\"county_id\").sample(frac=training_fraction).county_id\n", + " )\n", + " train = data[data.county_id.isin(train_counties)]\n", + " test = data[~data.index.isin(train.index)]\n", + " print(\n", + " \"# training counties:\",\n", + " len(train_counties),\n", + " \"\\n# training zip codes:\",\n", + " train.shape[0],\n", + " \"\\n# test zip codes:\",\n", + " test.shape[0],\n", + " )\n", + " return train, test\n", + "\n", + "\n", + "def run_imputation_model(\n", + " train, test, label, min_population=500, model_class=Ridge, model_kwargs={}\n", + "):\n", + " train = train[(train.population >= min_population) & train[label].notna()]\n", + " test = test[(test.population >= min_population) & test[label].notna()]\n", + " model = make_pipeline(preprocessing.MinMaxScaler(), model_class(**model_kwargs))\n", + " model.fit(train[embedding_features], train[label])\n", + " predictions = model.predict(test[embedding_features])\n", + " gdf_predictions = make_predictions_df(predictions, test, label)\n", + " results = evaluate(gdf_predictions)\n", + " return model, results\n", + "\n", + "\n", + "# Increasing this value generally improves performance.\n", + "training_fraction = 0.2\n", + "label = \"2025-01-31\"\n", + "train, test = get_train_test_split(training_fraction)\n", + "model, results = run_imputation_model(train, test, label)\n", + "results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "tx-tx4uz3Etr", + "outputId": "2ce32c01-c4a2-421b-c22f-1b5c109ce25d" + }, + "outputs": [], + "source": [ + "# @title Visualize a few counties from the test set.\n", + "test_counties = test.county_id.unique()\n", + "large_counties = (\n", + " df[df.county_id.isin(test_counties)]\n", + " .sort_values(\"population\", ascending=False)[[\"state\", \"county\", \"population\"]]\n", + " .head(4)\n", + ")\n", + "for _, row in large_counties.iterrows():\n", + " _ = subset_eval(label, row.county, row.state, gdf_predictions, cmap=\"Blues\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 258 + }, + "id": "3bFwhB0E2qzy", + "outputId": "621fe283-b98d-4a68-e664-5b0fd369e658" + }, + "outputs": [], + "source": [ + "# @title Try other labels.\n", + "labels = [\n", + " \"2024-01-31\",\n", + " \"2024-02-29\",\n", + " \"2024-03-31\",\n", + " \"2024-04-30\",\n", + " \"2024-05-31\",\n", + " \"2024-06-30\",\n", + " \"2024-07-31\",\n", + " \"2024-08-31\",\n", + " \"2024-09-30\",\n", + " \"2024-10-31\",\n", + " \"2024-11-30\",\n", + " \"2024-12-31\",\n", + "]\n", + "train, test = get_train_test_split(0.8)\n", + "models_by_label = {}\n", + "metrics_df = pd.DataFrame(\n", + " columns=[\"label\", \"correlation\", \"r2\", \"rmse\", \"mae\", \"mape\", \"model\"]\n", + ")\n", + "for label in labels:\n", + " models_by_label[label], results = run_imputation_model(train, test, label)\n", + " results[\"label\"] = label\n", + " results[\"model\"] = \"linear\"\n", + " metrics_df.loc[len(metrics_df)] = results\n", + "\n", + "metrics_df.round(3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 467 + }, + "id": "8NlkBAcl0UZe", + "outputId": "ac55f509-e675-4586-82d2-9aa62d1dfcb6" + }, + "outputs": [], + "source": [ + "# @title Try LightGBM models instead of linear.\n", + "\n", + "# This will take a few minutes to run.\n", + "models_by_label_lgbm = {}\n", + "metrics_df_lgbm = pd.DataFrame(\n", + " columns=[\"label\", \"r2\", \"rmse\", \"mae\", \"mape\", \"correlation\", \"model\"]\n", + ")\n", + "for label in labels:\n", + " models_by_label_lgbm[label], results = run_imputation_model(\n", + " train,\n", + " test,\n", + " label,\n", + " model_class=lgbm.LGBMRegressor,\n", + " model_kwargs={\n", + " \"min_child_samples\": 40,\n", + " \"importance_type\": \"gain\",\n", + " \"n_estimators\": 400,\n", + " \"learning_rate\": 0.04,\n", + " \"force_col_wise\": True,\n", + " },\n", + " )\n", + " results[\"label\"] = label\n", + " results[\"model\"] = \"lgbm\"\n", + " metrics_df_lgbm.loc[len(metrics_df_lgbm)] = results\n", + "\n", + "metrics_df_lgbm.round(3)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NlKp3P_oQBFS" + }, + "source": [ + "The LGBM results are mostly comparable with the linear model. They can be improved with more iterations and lower learning rate. You can also try setting `feature_fraction=0.5`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 448 + }, + "id": "EdnJWqnlzv5O", + "outputId": "c5228bf1-54a8-415c-abb3-5f2a3d00cc79" + }, + "outputs": [], + "source": [ + "# @title LightGBM feature importance\n", + "\n", + "features = {\n", + " \"trends\": (128, embedding_features[:128]),\n", + " \"maps\": (128, embedding_features[128:256]),\n", + " \"weather\": (74, embedding_features[256:]),\n", + "}\n", + "all_importance = []\n", + "for label, model in models_by_label_lgbm.items():\n", + " importance = pd.DataFrame(\n", + " model[1].feature_importances_, index=embedding_features, columns=[\"importance\"]\n", + " )\n", + " importance[\"importance\"] = importance[\"importance\"].abs()\n", + " for feature, dims in features.items():\n", + " importance.loc[dims[1], \"feature\"] = feature\n", + " importance = importance.groupby(\"feature\").importance.sum().reset_index()\n", + " importance[\"importance\"] = importance.importance / importance.importance.sum() * 100\n", + " importance[\"label\"] = label\n", + " all_importance.append(importance)\n", + "all_importance = pd.concat(all_importance)\n", + "_, ax = plt.subplots(figsize=(10, 3))\n", + "sns.barplot(\n", + " data=all_importance,\n", + " x=\"label\",\n", + " y=\"importance\",\n", + " hue=\"feature\",\n", + " hue_order=features.keys(),\n", + " ax=ax,\n", + ")\n", + "_ = plt.xticks(rotation=30)" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "Z1voPA0zLJ0I" + ], + "provenance": [] + }, + "kernelspec": { + "display_name": "geo", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/mkdocs.yml b/mkdocs.yml index 3790ab7..bbbf4ca 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -59,6 +59,7 @@ nav: - PDFM: - PDFM/zillow_home_value.ipynb - PDFM/map_pdfm_features.ipynb + - PDFM/pdfm_superresolution.ipynb - Workshops: - workshops/AIforGood_2025.ipynb - Opengeos: https://github.com/opengeos