def find_optimal_P(P, Q, mc):
    """
    Return the symbolic derivative of total profit with respect to price.

    Total profit is ``(Q * (P - mc)).sum()``. Despite the name, no root is
    solved here: the numerical root solve that used to follow is disabled, so
    the caller receives the gradient expression itself and is expected to find
    where it vanishes.

    Parameters
    ----------
    P
        Price variable the profit is differentiated with respect to.
    Q
        Quantity sold, expressed as a graph that depends on ``P``.
    mc
        Marginal cost subtracted from the price.

    Returns
    -------
    The gradient of the summed profit with respect to ``P``.
    """
    profit = (Q * (P - mc)).sum()
    return pt.grad(profit, P)
"id": "BeitshYMVkQU" + }, + "outputs": [], + "source": [ + "expr = find_optimal_P(price, expected_sales, mc=mc)" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T11:32:09.443902007Z", + "start_time": "2025-08-14T11:31:54.918556Z" + }, + "id": "jugOxL4DcRFN" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Add [id A] 5\n", + " ├─ Mul [id B] 4\n", + " │ ├─ Sub [id C] 3\n", + " │ │ ├─ price [id D]\n", + " │ │ └─ ExpandDims{axis=0} [id E] 2\n", + " │ │ └─ marginal_cost [id F]\n", + " │ └─ ExpandDims{axis=0} [id G] 0\n", + " │ └─ price_effect [id H]\n", + " ├─ trend [id I]\n", + " ├─ Mul [id J] 1\n", + " │ ├─ price [id D]\n", + " │ └─ ExpandDims{axis=0} [id G] 0\n", + " │ └─ ···\n", + " └─ seasonality [id K]\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 99, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Use existing rewrites to simplify expression\n", + "fgraph = FunctionGraph(outputs=[expr], clone=False)\n", + "rewrite_graph(fgraph, include=(\"canonicalize\",))\n", + "fgraph.dprint()" + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T11:32:09.445406846Z", + "start_time": "2025-08-14T11:31:55.243098Z" + }, + "id": "86-KeCOFWQZU" + }, + "outputs": [], + "source": [ + "# distribute_mul_over_add = PatternNodeRewriter(\n", + "# (pt.mul, (pt.add, \"x\", \"y\"), \"z\"),\n", + "# (pt.add, (pt.mul, \"x\", \"z\"), (pt.mul, \"y\", \"z\")),\n", + "# )\n", + "\n", + "distribute_mul_over_sub = PatternNodeRewriter(\n", + " (pt.mul, (pt.sub, \"x\", \"y\"), \"z\"),\n", + " (pt.add, (pt.mul, \"x\", \"z\"), (pt.mul, (pt.neg, \"y\"), \"z\")),\n", + ")\n", + "\n", + "combine_addition_terms = PatternNodeRewriter(\n", + " (pt.add, (pt.add, \"x\", \"y\"), \"z\", \"x\", \"w\"),\n", + " (pt.add, (pt.mul, \"x\", 2), (pt.add, \"y\", \"z\", 
\"w\")),\n", + ")\n", + "\n", + "# distribute_mul_over_add = out2in(distribute_mul_over_add, name=\"distribute_mul_add\")\n", + "distribute_mul_over_sub = out2in(distribute_mul_over_sub, name=\"distribute_mul_sub\")\n", + "combine_addition_terms = out2in(combine_addition_terms, name=\"combine_addition_terms\")\n", + "\n", + "# distribute\n", + "distribute_mul_over_sub.rewrite(fgraph)\n", + "# merge equivalent terms\n", + "MergeOptimizer().rewrite(fgraph)\n", + "# combine equivalent terms\n", + "combine_addition_terms.rewrite(fgraph)\n", + "# extract rewritten expression\n", + "expr = fgraph.outputs[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T11:32:09.446341276Z", + "start_time": "2025-08-14T11:31:56.276558Z" + }, + "id": "4qGBap72Xvvn" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Add [id A]\n", + " ├─ Mul [id B]\n", + " │ ├─ Mul [id C]\n", + " │ │ ├─ price [id D]\n", + " │ │ └─ ExpandDims{axis=0} [id E]\n", + " │ │ └─ price_effect [id F]\n", + " │ └─ ExpandDims{axis=0} [id G]\n", + " │ └─ 2 [id H]\n", + " └─ Add [id I]\n", + " ├─ Mul [id J]\n", + " │ ├─ Neg [id K]\n", + " │ │ └─ ExpandDims{axis=0} [id L]\n", + " │ │ └─ marginal_cost [id M]\n", + " │ └─ ExpandDims{axis=0} [id E]\n", + " │ └─ ···\n", + " ├─ trend [id N]\n", + " └─ seasonality [id O]\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 101, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "expr.dprint()" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T11:32:09.447033733Z", + "start_time": "2025-08-14T11:31:59.481064Z" + }, + "id": "8Fq10k2LcCY-" + }, + "outputs": [], + "source": [ + "# Create variations of a graph for pattern matching\n", + "rewrites = [\n", + " out2in(\n", + " PatternNodeRewriter((pt.add, \"x\", \"y\"), (pt.add, \"y\", \"x\")),\n", + 
def yield_arithmetic_variants(expr, n):
    """
    Lazily produce rewritten forms of ``expr`` by applying random rewrites.

    A single ``FunctionGraph`` is built around ``expr`` with ``clone=False``
    and rewritten over and over, so every yielded value is read from that same
    graph object.

    Parameters
    ----------
    expr
        Output variable of the expression to vary.
    n
        Number of rewrite attempts. Each attempt draws one entry of the
        module-level ``rewrites`` list with ``random.choice``.

    Yields
    ------
    The graph's first output after every attempt whose ``apply`` result is
    truthy, followed by one final yield of the first output once all attempts
    are used up.
    """
    graph = FunctionGraph(outputs=[expr], clone=False)
    attempts_left = n
    while attempts_left > 0:
        attempts_left -= 1
        chosen = random.choice(rewrites)
        # NOTE(review): this assumes `apply` returns something falsy when the
        # rewrite changed nothing. If it returns a non-empty profiling tuple the
        # check always passes and unchanged graphs are yielded too. TODO confirm.
        if chosen.apply(graph):
            yield graph.outputs[0]
    # Always hand back the final state, even if the last attempt already did.
    yield graph.outputs[0]
+ "base_uri": "https://localhost:8080/" + }, + "id": "8M-qjXBKa6Db", + "outputId": "cdce40c4-e1dd-4757-f4d6-f368643bb5c1" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True_div [id A]\n", + " ├─ Neg [id B]\n", + " │ └─ Add [id C]\n", + " │ ├─ Mul [id D]\n", + " │ │ ├─ Neg [id E]\n", + " │ │ │ └─ ExpandDims{axis=0} [id F]\n", + " │ │ │ └─ marginal_cost [id G]\n", + " │ │ └─ ExpandDims{axis=0} [id H]\n", + " │ │ └─ price_effect [id I]\n", + " │ ├─ trend [id J]\n", + " │ └─ seasonality [id K]\n", + " └─ Mul [id L]\n", + " ├─ ExpandDims{axis=0} [id H]\n", + " │ └─ ···\n", + " └─ ExpandDims{axis=0} [id M]\n", + " └─ 2 [id N]\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 104, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "optimal_result = -match_dict[b] / match_dict[a]\n", + "optimal_result.dprint()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T11:32:09.449645675Z", + "start_time": "2025-08-14T11:25:52.269957Z" + } + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index d08759a978..28ca7486ca 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -1315,7 +1315,7 @@ def c_code_cache_version(self): return v -softplus = Softplus(upgrade_to_float, name="scalar_softplus") +softplus = Softplus(upgrade_to_float, name="softplus") class Log1mexp(UnaryScalarOp): diff --git a/pytensor_tutorial_pricing.ipynb b/pytensor_tutorial_pricing.ipynb new file mode 100644 index 0000000000..2c39c43ae9 --- /dev/null +++ b/pytensor_tutorial_pricing.ipynb @@ -0,0 +1,3159 @@ +{ + "cells": [ + { 
+ "cell_type": "code", + "execution_count": 1, + "id": "28e4e5d2", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-18T20:46:58.416962Z", + "start_time": "2025-08-18T20:46:58.309266Z" + }, + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'width': 1400, 'height': 768, 'scroll': True}" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from notebook.services.config import ConfigManager\n", + "\n", + "\n", + "cm = ConfigManager()\n", + "cm.update(\n", + " \"livereveal\",\n", + " {\n", + " \"width\": 1400,\n", + " \"height\": 768,\n", + " \"scroll\": True,\n", + " },\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "9e2ce3b7", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "from copy import deepcopy\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from etuples import etuple\n", + "from IPython.display import clear_output\n", + "from unification import unify, var\n", + "\n", + "import pytensor\n", + "import pytensor.tensor as pt\n", + "from pytensor.graph import rewrite_graph\n", + "from pytensor.graph.basic import explicit_graph_inputs\n", + "from pytensor.graph.features import History\n", + "from pytensor.graph.fg import FunctionGraph\n", + "from pytensor.graph.replace import graph_replace\n", + "from pytensor.graph.rewriting.basic import MergeOptimizer, PatternNodeRewriter, out2in\n", + "from pytensor.tensor.optimize import root\n", + "\n", + "\n", + "np.seterr(all=\"ignore\");" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a3340e9e", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "SEED = sum(map(ord, \"Pytensor at EuroScipy\"))\n", + "rng = np.random.default_rng(SEED)\n", + "\n", + "plt.rcParams.update(\n", + " {\n", + " \"figure.figsize\": (14, 4),\n", + " 
\"figure.dpi\": 144,\n", + " \"figure.constrained_layout.use\": True,\n", + " \"axes.spines.top\": False,\n", + " \"axes.spines.bottom\": True,\n", + " \"axes.spines.left\": True,\n", + " \"axes.spines.right\": False,\n", + " \"axes.grid\": True,\n", + " \"grid.linewidth\": 0.5,\n", + " \"grid.linestyle\": \"--\",\n", + " \"xtick.labelsize\": \"x-large\",\n", + " \"ytick.labelsize\": \"x-large\",\n", + " }\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "69a5a94e", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "# Recent Developments in Pytensor, the Successor Package to Theano\n", + "\n", + "### Theano is dead, long live Theano" + ] + }, + { + "cell_type": "markdown", + "id": "3bc7a98d", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "# About Us\n", + "\n", + "- Jesse Grabowski\n", + " - Github [@jessegrabowski](https://github.com/jessegrabowski), [LinkedIn](https://www.linkedin.com/in/jessegrabowski/)\n", + " - PhD Candidate at Paris 1 - Pantheon Sorbonne\n", + " - 目前居住在上海\n", + " \n", + "- Ricardo Vieira\n", + " - Github [@ricardoV94](https://github.com/ricardoV94)\n", + " - Principal data scientist at PyMC Labs\n", + " - Living 2599 km from Lisbon, 292 km from Krakow, 4779 from Djibouti\n", + "\n", + "- Both: Core developers of PyTensor and PyMC (and more)" + ] + }, + { + "cell_type": "markdown", + "id": "fb75e67f-f1b5-4d44-bda6-d020f7635aab", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "# What was Theano?" + ] + }, + { + "attachments": { + "before-img.png": { + "image/png": "" + } + }, + "cell_type": "markdown", + "id": "10be13c2", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Popular perspective: One of the first widely used deep learning libraries (dating back to 2007). 
def cross_entropy_loss(y, p):
    """
    Sum the Bernoulli log-likelihood terms of labels ``y`` under probabilities ``p``.

    The value is ``sum(y * log(p) + (1 - y) * log(1 - p))``. It is returned
    without a sign flip, so it is non-positive for probabilities strictly
    between 0 and 1.

    This is intentionally the textbook formula with no numerical safeguards.
    """
    positive_term = y * np.log(p)
    negative_term = (1 - y) * np.log(1 - p)
    return (positive_term + negative_term).sum()


def compute_logistic_loss(X, y, alpha, beta):
    """
    Evaluate ``cross_entropy_loss`` for a logistic regression.

    Probabilities are ``expit(alpha + X @ beta)``, with ``alpha`` the intercept
    and ``beta`` the coefficient vector.
    """
    linear_predictor = alpha + X @ beta
    probabilities = expit(linear_predictor)
    return cross_entropy_loss(y, probabilities)
"c8afd178-1dbe-45f4-8668-5a1be319964b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "compute_logistic_loss" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "90771703-9dc4-43d2-903f-81f934ca11f8", + "metadata": {}, + "outputs": [], + "source": [ + "??compute_logistic_loss" + ] + }, + { + "cell_type": "markdown", + "id": "d9e460ed-e4d4-4bcd-9d40-d5285a049c39", + "metadata": {}, + "source": [ + "* Text-parsing\n", + "* Byte-code parsing?\n", + "* Tracing?" + ] + }, + { + "cell_type": "markdown", + "id": "5cd9701f", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "### The PyTensor solution: start with the graph" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "53180b93", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "import pytensor.tensor as pt\n", + "\n", + "\n", + "X_pt = pt.tensor(\"X\", shape=(None, 3))\n", + "y_pt = pt.tensor(\"y\", shape=(None,))\n", + "alpha_pt = pt.tensor(\"alpha\", shape=())\n", + "beta_pt = pt.tensor(\"beta\", shape=(3,))\n", + "\n", + "p = pt.sigmoid(alpha_pt + X_pt @ beta_pt)\n", + "loss = (y_pt * pt.log(p) + (1 - y_pt) * pt.log(1 - p)).sum()" + ] + }, + { + "cell_type": "markdown", + "id": "5c69a43f-e884-4728-b1d5-db3b6cd080b1", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Every variable is a symbolic place-holder, with a strict type." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "c81325a1-6a43-4b0f-8383-692d08fc13f3", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Sum{axes=None}.0" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loss" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "2c0a5be5591f936", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(pytensor.tensor.variable.TensorVariable, TensorType(float64, shape=()))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(loss), loss.type" + ] + }, + { + "cell_type": "markdown", + "id": "8df944a9-cb6b-4790-a387-8a70ec61662e", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Variables can either be root inputs, or created by the application of operations, which connect inputs to outputs." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "bc75e500fb9ece2f", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "pytensor.graph.basic.Apply" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(loss.owner)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "f1bb3439391f2509", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(pytensor.tensor.math.Sum, 'Sum{axes=None}')" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(loss.owner.op), str(loss.owner.op)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "da1d7026-b462-4f6c-b8e4-fdf1f84cc155", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "([Add.0], [Sum{axes=None}.0])" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loss.owner.inputs, loss.owner.outputs" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "c9a8edd7-f765-4118-bd00-400a6fe40437", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(pytensor.tensor.variable.TensorVariable, TensorType(float64, shape=(None,)))" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(loss.owner.inputs[0]), loss.owner.inputs[0].type" + ] + }, + { + "cell_type": "markdown", + "id": "b57e98dc", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "We'll stop here for now, but we could navigate the whole graph from outputs to inputs" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "316c181a13627039", + "metadata": { + "slideshow": { + 
"slide_type": "-" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Add\n", + " ├─ Mul\n", + " │ ├─ y\n", + " │ └─ Log\n", + " │ └─ Sigmoid\n", + " │ └─ Add\n", + " │ ├─ ExpandDims{axis=0}\n", + " │ │ └─ alpha\n", + " │ └─ Squeeze{axis=1}\n", + " │ └─ Matmul\n", + " │ ├─ X\n", + " │ └─ ExpandDims{axis=1}\n", + " │ └─ beta\n", + " └─ Mul\n", + " ├─ Sub\n", + " │ ├─ ExpandDims{axis=0}\n", + " │ │ └─ 1\n", + " │ └─ y\n", + " └─ Log\n", + " └─ Sub\n", + " ├─ ExpandDims{axis=0}\n", + " │ └─ 1\n", + " └─ Sigmoid\n", + " └─ ···\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loss.owner.inputs[0].dprint(id_type=\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "a00bc280", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "## What can you do with a graph?\n", + "\n", + "- Query it\n", + "- Transform it \n", + "- Evaluate it\n", + "- Rinse and repeat\n", + "\n", + "**... all in Python!**" + ] + }, + { + "cell_type": "markdown", + "id": "75dceee0", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "### Query\n", + "\n", + "Find the inputs to the loss function" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "d41f93fd", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[y, alpha, X, beta]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list(pytensor.graph.basic.explicit_graph_inputs(loss))" + ] + }, + { + "cell_type": "markdown", + "id": "c2bb5ca2", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "Are there any patterns that we can take advantage of? \n", + "`log(sigmoid(x))` is an expression that shows often in our domain, and can be numerically optimized. 
def find_log_sigmoid(variable):
    """
    Yield every ancestor of ``variable`` that has the form ``log(sigmoid(...))``.

    An ancestor qualifies when it is produced by an application of ``pt.log``
    whose first input is itself produced by an application of ``pt.sigmoid``.
    The yielded variable is the output of the ``log``, not of the ``sigmoid``.
    """
    for candidate in ancestors([variable]):
        outer = candidate.owner
        # Root variables have no owner; anything not produced by a log is skipped.
        if outer is None or not (outer.op == pt.log):
            continue
        inner = outer.inputs[0].owner
        # The log's argument must itself be the output of a sigmoid.
        if inner is not None and inner.op == pt.sigmoid:
            yield candidate
+ ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "38e081a0", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "stable_loss = loss\n", + "for log_sigmoid_term in find_log_sigmoid(loss):\n", + " x = log_sigmoid_term.owner.inputs[0].owner.inputs[0]\n", + " stable_loss = graph_replace(\n", + " stable_loss, replace={log_sigmoid_term: -pt.softplus(-x)}\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "b3f7e9ab", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "#### Add randomness to the system\n", + "\n", + "We can also mutate a graph to alter its behavior, for example to add randomness downstream of some variables." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "685575d4", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "noisy_loss = graph_replace(loss, replace={X_pt: X_pt + pt.random.normal(0, 1)})" + ] + }, + { + "cell_type": "markdown", + "id": "6aa053d8", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "### Ablate part of the system\n", + "\n", + "Or ablate part of the graph, say by removing the contribution from the bias term." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "b9264d11", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "ablated_loss = graph_replace(loss, replace={alpha_pt: alpha_pt * 0})" + ] + }, + { + "cell_type": "markdown", + "id": "205b30e0", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "### Other common transformations\n", + "\n", + "- Autodiff\n", + "- Vectorization\n", + "- Bit quantization\n", + "- Backend specialization\n", + "- Numerical simplification\n", + "- Dead code elimination" + ] + }, + { + "cell_type": "markdown", + "id": "ecb6b6d8", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "## Evaluate (performantly)\n", + "\n", + "At the end of the day we need to crunch real numbers. \n", + "\n", + "PyTensor allows compiling the same graph to different computational backends." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "56293fd6", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "eval_dict = {\n", + " X_pt: rng.normal(size=(100, 3)),\n", + " y_pt: rng.binomial(n=1, p=0.2, size=(100,)),\n", + " alpha_pt: rng.normal(size=()),\n", + " beta_pt: rng.normal(size=(3,)),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "b92143c6", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array(-86.58472878)" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loss.eval(eval_dict) # By default it uses a custom C backend" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "5d464526", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(-86.58472878, dtype=float64)" + ] + }, + "execution_count": 25, + "metadata": {}, + 
"output_type": "execute_result" + } + ], + "source": [ + "stable_loss.eval(eval_dict, mode=\"JAX\")" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "8be8abcd", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array(-96.51436343)" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "noisy_loss.eval(eval_dict, mode=\"NUMBA\")" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "8fbd53c9", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array(-129.37799432)" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "try:\n", + " import torch\n", + "\n", + " mode = \"PYTORCH\"\n", + "except ModuleNotFoundError:\n", + " mode = \"NUMBA\"\n", + "\n", + "ablated_loss.eval(eval_dict, mode=mode)" + ] + }, + { + "cell_type": "markdown", + "id": "34b42a83", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "# Example sandbox: Sales Forecasting\n", + "\n", + "To motivate the features of Pytensor, I am going to use a specific example centered around timeseries forecasting. \n", + "\n", + "If you are not interested in this topic, I don't blame you. But I hope to use the example to show:\n", + "\n", + "1. How pytensor helps researchers accelerate their workflow\n", + "2. How developers can build on top of Pytensor to make extremely flexible software" + ] + }, + { + "cell_type": "markdown", + "id": "11b0ee70", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "## From the researcher perspective\n", + "\n", + "A researcher is going to have a specific model that she's interested in working with. Although there is just one model, it might need to be transformed into many different forms, to do different tasks. 
For example:\n", + "\n", + "- Pre-estimation checks (simulation, finding solutions)\n", + "- Estimation (\"taking the model to data\") \n", + "- Post-estimation (using the model to make optimal decisions, forecasting, prediction)\n", + "\n", + "The researcher does *not* care about any of the details about how this happens. She wants to be able to define the function once, then have a nice API that allows all these things to happen.\n", + "\n", + "In addition, we're going to assume she's not an expert in numerical optimization. She'll give her model in the form that is most natural to her, as a researcher in her domain, **not** in the form that is most computationally snappy or numerically stable." + ] + }, + { + "cell_type": "markdown", + "id": "e99676be", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "## From the developer perspective\n", + "\n", + "On the developer side, we assume that there is a general form for the model or class of models we're interested in supporting. There might be a suite of tricks, simplifications, and stabilizations known in the literature that will allow for better performance. \n", + "\n", + "The developer wants to support the maximum number of models possible, but he also wants performant code. Pytensor allows him to define the general case, then add machinery to analyze it and look for special cases.\n", + "\n", + "We will see this when we turn our attention to \"post-estimation\". We imagine that our package is not simply for estimating sales, but then also using that model. In this example, we will think about using the model of sales to choose optimal prices." 
+ ] + }, + { + "cell_type": "markdown", + "id": "0212e627", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "## The Prophet Model\n", + "\n", + "The prophet model is a time series decomposition model proposed by facebook in ...\n", + "\n", + "Like many time series models, it seeks to decompose an observed signal into a level, trend, and seasonality, so that:\n", + "\n", + "$$ y_t = \\text{level}_t + \\text{trend}_t + \\text{seasonality}_t $$\n", + "\n", + "Unlike other time series models, though, it does this using linear features. So the whole thing collapses back to good old OLS, but with carefully chosen features. \n", + "\n", + "As a result, it is easy to extend to include additional components. " + ] + }, + { + "cell_type": "markdown", + "id": "02ab0f1d", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "## Modeling Sales\n", + "\n", + "Suppose we want to model sales of some widgets as a function of price, which we get to set. We observe the sales (with some noise), as well as the prices (noiselessly, since we set them). \n", + "\n", + "Maybe our widgets are in higher or lower demand at different times in the year, so there are annual up and down cycles in the data, regardless of the price. Finally, there are economic forces we can't control: the market for our widgets goes up and down of its own accord.\n", + "\n", + "So we can use the prophet model to describe sales, adding in a regression term for the effect of price on sales:\n", + "\n", + "$$ \n", + "\\text{sales}_t = \\text{level}_t + \\text{trend}_t + \\text{seasonality}_t + \\beta \\cdot \\text{price}_t \n", + "$$\n", + "\n", + "For details about this model, there was a very nice presentation by Matthijs Brouns implementing it in PyMC [here](https://www.youtube.com/watch?v=appLxcMLT9Y). For our purposes here, you just need to understand that the model is a linear regression model with some fancy transformations of the time variable." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "c8e259e0", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:06.239675Z", + "start_time": "2025-08-14T14:00:06.234183Z" + }, + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "def create_piecewise_trend(t, t_max, n_changepoints):\n", + " \"\"\"\n", + " Create a piecewise linear trend with n changepoints.\n", + " \"\"\"\n", + " s = pt.linspace(0, t_max, n_changepoints + 2)[1:-1]\n", + " A = (t[:, None] > s) * 1\n", + "\n", + " return A, s\n", + "\n", + "\n", + "def create_fourier_features(t, n, p=365.25):\n", + " \"\"\"\n", + " Create seasonal patterns using n fourier basis functions with period p\n", + " \"\"\"\n", + " x = 2 * np.pi * (pt.arange(n) + 1) * t[:, None] / p\n", + " return pt.concatenate((pt.cos(x), pt.sin(x)), axis=1)\n", + "\n", + "\n", + "def generate_features(t, t_max, n_changepoints=10, n_fourier=6, p=365.25):\n", + " \"\"\"\n", + " Generate piecewise trend matrices A and s, and seasonal pattern matrix X.\n", + " \"\"\"\n", + " A, s = create_piecewise_trend(t, t_max, n_changepoints)\n", + " X = create_fourier_features(t, n_fourier, p)\n", + "\n", + " return A, s, X" + ] + }, + { + "cell_type": "markdown", + "id": "657f33ae", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "### Define Symbolic Inputs\n", + "\n", + "Pytensor straddles the line between a symbolic algebra system like Maple or Sympy, and an array library like numpy. In general, we try to adhere very closely to numpy syntax. But like a symbolic algebra system, a pytensor program starts by declaring root variables.\n", + "\n", + "As symbols, these root variables can be freely manipulated. When the time comes to compile our program, these will have to be given as inputs by the user."
+ ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "b8a3cb2d", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "price, time = pt.vectors(\"price\", \"time\")\n", + "\n", + "initial_intercept = pt.scalar(\"initial_intercept\")\n", + "initial_slope = pt.scalar(\"initial_slope\")\n", + "trend_changes = pt.vector(\"trend_changes\")\n", + "seasonal_effect = pt.vector(\"seasonal_effect\")\n", + "price_effect = pt.scalar(\"price_effect\")\n", + "\n", + "# Gather everything together into a lists; this will be handy later\n", + "input_data = [time, price]\n", + "\n", + "params = [\n", + " initial_intercept,\n", + " initial_slope,\n", + " trend_changes,\n", + " seasonal_effect,\n", + " price_effect,\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "b8b58603", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Define the Prophet model\n", + "\n", + "All of this syntax should look just like numpy. If you hadn't already seen that the variables being manipulated are symbolic `vector` and `scalar` objects, you wouldn't know this isn't just ordinary computation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "8f2db2a2", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:06.311319Z", + "start_time": "2025-08-14T14:00:06.286197Z" + }, + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "A, s, X = generate_features(time, time.max(), p=52, n_fourier=1)\n", + "\n", + "intercept = initial_intercept + ((-s * A) * trend_changes[None]).sum(axis=1)\n", + "slope = (initial_slope + (A * trend_changes[None]).sum(axis=1)) * time\n", + "trend = intercept + slope\n", + "price_term = price_effect * price\n", + "\n", + "seasonal_term = X @ seasonal_effect\n", + "\n", + "expected_sales = trend + seasonal_term + price_term" + ] + }, + { + "cell_type": "markdown", + "id": "45777d36", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Compile a function\n", + "\n", + "Once we are done manipulating, we compile a function to actually perform numerical computation. This is done using `pytensor.function`. The `function` function needs to know:\n", + "\n", + "- `inputs`: root variables, to be provided by the user at runtime\n", + "- `outputs`: the variable(s) to be computed" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "f42335af", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "sales_fn = pytensor.function(inputs=[*input_data, *params], outputs=expected_sales)" + ] + }, + { + "cell_type": "markdown", + "id": "e6bcbe27", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "### Create a function to generate random data\n", + "\n", + "Next, we also compile a random function to generate observed data, including iid random noise.\n", + "\n", + "Handling random number generators in Pytensor is beyond the scope of this tutorial, so just accept that this cell does what I claim it does. 
The long and short of it is that Pytensor is purely functional, and does not allow side effects in functions by default. This includes advancing a random state! To handle this, we offer a `RandomStream` object that functions like `np.default_rng()`.\n", + "\n", + "To learn more, you can read [this tutorial](https://pytensor.readthedocs.io/en/latest/tutorial/prng.html)." + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "5913cede", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:07.107408Z", + "start_time": "2025-08-14T14:00:06.336499Z" + }, + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "pt_rng = pt.random.RandomStream(seed=SEED)\n", + "observation_noise = pt_rng.normal(scale=1, size=expected_sales.shape)\n", + "\n", + "observed_sales = expected_sales + observation_noise\n", + "data_fn = pytensor.function([*input_data, *params], observed_sales)" + ] + }, + { + "cell_type": "markdown", + "id": "c2943c54", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "### Generate Data\n", + "\n", + "Generate some data with known parameter values. Our job will be to recover these values from observed data using some curve fitting algorithm."
+ ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "cabf1861", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:07.115009Z", + "start_time": "2025-08-14T14:00:07.111620Z" + }, + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "true_values = {\n", + " \"initial_intercept\": 10.0,\n", + " \"initial_slope\": 0.01,\n", + " \"trend_changes\": rng.normal(scale=1e-3, size=(10,)),\n", + " \"seasonal_effect\": np.array([2.0, 1.0]),\n", + " \"price_effect\": -0.6,\n", + "}\n", + "T = 52 * 5\n", + "time_value = np.arange(T)\n", + "obs_prices = rng.normal(loc=10, scale=1, size=(T,))\n", + "sales_idx = rng.choice(T, size=(25,), replace=False)\n", + "\n", + "for idx in sales_idx:\n", + " obs_prices[idx : idx + 4] /= 2\n", + "prices_obs = np.pad(\n", + " np.convolve(obs_prices, np.full(5, 0.95 ** (np.arange(5))), mode=\"valid\") / 5,\n", + " (2, 2),\n", + " mode=\"mean\",\n", + ")\n", + "\n", + "sales_obs = data_fn(time_value, obs_prices, **true_values)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "43d6bb58-234c-401d-b369-d398d5bf7d2a", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([6.2953841 , 7.73225729, 7.70092471, 6.71385719, 6.87541428])" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sales_obs[:5]" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "18bf581e", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:07.305417Z", + "start_time": "2025-08-14T14:00:07.160318Z" + }, + "slideshow": { + "slide_type": "slide" + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(2, 1, sharex=True)\n", + "for axis, data, title in zip(\n", + " fig.axes, [obs_prices, sales_obs], [\"Price\", \"Observed Sales\"]\n", + "):\n", + " axis.plot(data, lw=2)\n", + " axis.set_title(title, fontsize=18)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "214067be", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "## Loss Function and Minimization\n", + "\n", + "To recover the parameters that generated the data, we can choose the parameters of the model to minimize the mean squared error between the estimation of the model and the data.\n", + "\n", + "There are lots of choices for how to actually do the minimization. We will do gradient descent" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "11c9bec6", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:07.317394Z", + "start_time": "2025-08-14T14:00:07.313602Z" + }, + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "learning_rate = pt.scalar(\"learning_rate\")\n", + "observed_sales = pt.vector(\"sales\")\n", + "\n", + "loss = ((expected_sales - observed_sales) ** 2).mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "fc5dc4910f8954b9", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:07.363965Z", + "start_time": "2025-08-14T14:00:07.360396Z" + }, + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "def sgd_optimization(update_fn, init_params, n_iter=60_000):\n", + " optim_params = deepcopy(init_params)\n", + "\n", + " history = np.empty(n_iter)\n", + " start_i = None\n", + "\n", + " for i in range(n_iter):\n", + " curr_loss, *curr_grads = update_fn(**optim_params)\n", + " lr = max(1e-3 * 0.999**i, 5e-5)\n", + "\n", + " for key, grad in zip(optim_params.keys(), curr_grads):\n", + " optim_params[key] -= lr * 2 * 
grad / (np.linalg.norm(grad) + 1e-8)\n", + "\n", + " history[i] = curr_loss\n", + " if curr_loss < 5 and not start_i:\n", + " start_i = i\n", + "\n", + " if start_i and (i % 1000 == 0):\n", + " clear_output(wait=True)\n", + " plt.plot(np.arange(start_i, i), history[start_i:i])\n", + " plt.show(block=False)\n", + "\n", + " return optim_params" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "8521ad39", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "def optimize(loss, params, init_params, n_iter=60_000):\n", + " param_grads = pt.grad(loss, params)\n", + " update_fn = pytensor.function([*params], [loss, *param_grads])\n", + "\n", + " return sgd_optimization(update_fn, init_params, n_iter)" + ] + }, + { + "cell_type": "markdown", + "id": "a2a9b395", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "### Insert data into the graph\n", + "\n", + "We're not actually going to vary the data from call to call, so there's no reason to have symbolic inputs for the data. Instead, we can just directly insert the data we're going to use.\n", + "\n", + "This is efficient because it gives Pytensor a chance to do constant folding. If we have something like `2 * observed_sales` in the graph, we can just compute that new constant once at compile time, rather than every time we call the function."
+ ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "a3cf0a03d705ff3d", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:07.413545Z", + "start_time": "2025-08-14T14:00:07.407860Z" + }, + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "loss_w_data = graph_replace(\n", + " loss, {time: time_value, price: prices_obs, observed_sales: sales_obs}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca44785cc513b0ae", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "init_params = {\n", + " k: np.random.normal(scale=0.1, size=np.asarray(v).shape)\n", + " for k, v in true_values.items()\n", + "}\n", + "init_params[\"initial_intercept\"] = np.array(sales_obs[0])\n", + "init_params[\"initial_slope\"] = np.array(np.diff(sales_obs)[1])\n", + "\n", + "optim_params = optimize(loss_w_data, params, init_params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9a3d241", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "optim_params" + ] + }, + { + "cell_type": "markdown", + "id": "34803ec5", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "## Predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44098dc2", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:14.047389Z", + "start_time": "2025-08-14T14:00:14.044521Z" + }, + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "predicted_sales = sales_fn(time_value, prices_obs, **optim_params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6f7f5221-211a-43e0-b2bd-5a4f3dd87d30", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "predicted_sales[:7]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e97d9125", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:14.287952Z", + "start_time": "2025-08-14T14:00:14.100195Z" + }, + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [], + "source": [ + "plt.plot(sales_obs, c=\"k\", label=\"Observed\", ls=\"--\")\n", + "plt.plot(predicted_sales, c=\"firebrick\", lw=2, label=\"Predicted\")\n", + "plt.legend(fontsize=16)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "33f2925b", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": 
[ + "## Reusing Intermediate Results\n", + "\n", + "A simple place where the symbolic approach shines is when we are interested in intermediate computations.\n", + "\n", + "Specifically, compiling a function has no implications for any of the other symbolic variables we created along the way. So we are free to go back and continue manipulating, or even to compile new functions using different parts of the graph.\n", + "\n", + "In the context of time series, we are often interested in **time series decomposition**. That is, we want to know what part of the variance is attributable to the trend or the seasonality, and what is residual variance. \n", + "\n", + "In our case, we can simply ask for a new function that computes each of the intermediate bits that went into `expected_sales`." + ] + }, + { + "cell_type": "markdown", + "id": "606bee0c", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "### Time Series Decomposition" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63739e3f", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:14.620862Z", + "start_time": "2025-08-14T14:00:14.351628Z" + }, + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "residual = observed_sales - expected_sales\n", + "f_decompose = pytensor.function(\n", + " [time, price, observed_sales, *params],\n", + " [intercept, slope, trend, seasonal_term, price_term, residual],\n", + ")\n", + "components = f_decompose(time_value, prices_obs, sales_obs, **optim_params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21c276df", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:14.631568Z", + "start_time": "2025-08-14T14:00:14.629357Z" + }, + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "component_names = [\n", + " \"Intercept\",\n", + " \"Slope\",\n", + " \"Trend = Intercept + Slope\",\n", + " \"Seasonal Effect\",\n", + " 
\"Price Effect\",\n", + " \"Residual\",\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27503e5d", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:14.988487Z", + "start_time": "2025-08-14T14:00:14.679424Z" + }, + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [], + "source": [ + "fig, ax = plt.subplots(6, 1, figsize=(14, 12), dpi=144, sharex=True)\n", + "for axis, data, name in zip(fig.axes, components, component_names):\n", + " axis.plot(data)\n", + " axis.set_title(name, fontdict={\"weight\": \"bold\"})\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "3c8ebb53", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "# Graph Rewriting\n", + "\n", + "Although we introduced pytensor as a \"static graph\" library, the graphs generated by pytensor can be freely manipulated. In particular, we are allowed to replace variables -- or even entire subgraphs! -- with new variables or subgraphs.\n", + "\n", + "In the next example we modify the expected sales function to simulate a counterfactual scenario where the seasonality term is only 10% as strong during a specific period. Perhaps a new pandemic outbreak forced everyone indoors, and the usual cycles are disrupted."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0fac9905-c26e-4c47-81db-5ed23736a71b", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "dampened_seasonal = seasonal_term[150:200].set(seasonal_term[150:200] * 0.1)\n", + "expected_sales_dampened = graph_replace(\n", + " expected_sales, {seasonal_term: dampened_seasonal}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35605033-83c7-4895-8ffb-a29bc8533ec5", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "sales_fn_dampened = pytensor.function(\n", + " inputs=[*input_data, *params], outputs=expected_sales_dampened\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f12e2a84-6f8f-449e-8473-4b1c4d3beb6a", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "new_predicted_sales = sales_fn_dampened(time_value, prices_obs, **optim_params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c9bdd89-78b9-4de7-94b9-879420fe6a9b", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "outputs": [], + "source": [ + "plt.plot(sales_obs, c=\"0.8\", label=\"Observed\", ls=\"-\", lw=3)\n", + "plt.plot(predicted_sales, c=\"firebrick\", lw=3, label=\"Predicted\")\n", + "plt.plot(new_predicted_sales, c=\"dodgerblue\", lw=3, ls=\"--\", label=\"Counterfactual\")\n", + "plt.legend(fontsize=16)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "3e58a787", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "## \"Constrained\" Optimization\n", + "\n", + "Often, we might want to restrict a parameter value to be in a certain range. For example, theory says the price effect should be negative -- if prices go up, demand should go down. 
To enforce this, we can use constrained optimization.\n", + "\n", + "One way to do this is to use an optimization algorithm that can handle boundaries. This would probably work fine in our case, but it can be fussy and we would have to import `scipy` and all that. \n", + "\n", + "Instead, we could also use a change of variables. Rather than choosing `price_effect` directly, we can have the optimizer choose an unconstrained variable `log_price_effect` (which plays the role of `log(-price_effect)`), then replace `price_effect` with `-exp(log_price_effect)` in the objective function. This means the optimizer can choose whatever it wants (because `log_price_effect` can be any value in $\\mathbb R$), but `-exp(log_price_effect)` will be strictly negative when actually evaluating the loss.\n", + "\n", + "We will have to account for the presence of the `exp` when taking gradients, but autodiff will handle that for us." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dda9b05b", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:15.016777Z", + "start_time": "2025-08-14T14:00:15.013174Z" + }, + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "log_price_effect = pt.scalar(\"log_price_effect\")\n", + "constrained_loss = graph_replace(loss_w_data, {price_effect: -pt.exp(log_price_effect)})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "afedc0b1", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "new_params = [*params[:-1], log_price_effect]\n", + "\n", + "new_init_params = deepcopy(init_params)\n", + "new_init_params[\"log_price_effect\"] = init_params[\"price_effect\"]\n", + "del new_init_params[\"price_effect\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "750738ab", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "constrained_optim_params = optimize(constrained_loss, new_params, new_init_params)" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "8d2abb11", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:22.020130Z", + "start_time": "2025-08-14T14:00:22.017505Z" + }, + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "constrained_optim_params" + ] + }, + { + "cell_type": "markdown", + "id": "70493621", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "Converting the unconstrained value back to the constrained space, we see that we get an answer similar to what we got when we did things the \"normal\" way.\n", + "\n", + "The point here was how easy it is to insert the unconstrained variable together with its inverse transformation in the loss function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b8d20f9f", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:22.073622Z", + "start_time": "2025-08-14T14:00:22.071148Z" + }, + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "-np.exp(constrained_optim_params[\"log_price_effect\"])" + ] + }, + { + "cell_type": "markdown", + "id": "b66e8495", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Post Estimation: Choosing Optimal Price\n", + "\n", + "We now have a demand function, which tells us how many widgets we will sell given a price. We can now ask how to set the price to maximize total profits, given this demand function.\n", + "\n", + "Our profit is just what we make by selling $Q_t$ widgets, minus the cost it took us to \"produce\" those widgets. Don't take \"produce\" too seriously -- we might be reselling things, in which case the cost of \"production\" is the wholesale cost. 
\n", + "\n", + "$$\n", + "\\pi_t = Q_t P_t - Q_t \\text{mc}_t = Q_t(P_t - \\text{mc}_t)\n", + "$$\n", + "\n", + "Where $Q_t$ is the quantity sold, $P_t$ is the unit price, and $\\text{mc}_t$ is the marginal cost of production for a single unit.\n", + "\n", + "We seek $P^\\star_t$ such that profits are maximized:\n", + "\n", + "$$\n", + "\\max_{P^\\star_t} Q_t(P^\\star_t - \\text{mc}_t)\n", + "$$\n", + "\n", + "The general strategy for solving these is that we take derivatives of the objective function with respect to the controls and set them equal to zero. In this case, that's $\\frac{\\partial \\pi_t}{\\partial P_t} = 0$. We then seek $P_t^\\star$ such that this equation is true.\n", + "\n", + "Assume we're making a package for this, and we know nothing about the $Q_t$ or $\\text{mc}_t$ that a user might provide. We can still symbolically solve this equation in the general case using a root finder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0403156", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:22.124362Z", + "start_time": "2025-08-14T14:00:22.122203Z" + }, + "slideshow": { + "slide_type": "slide" + } + }, + "outputs": [], + "source": [ + "def find_optimal_P(P, Q, mc):\n", + " pi = (Q * (P - mc)).sum()\n", + " dpi_dP = pt.grad(pi, P)\n", + " P_star, success = root(dpi_dP, P, method=\"lm\", optimizer_kwargs=dict(tol=1e-8))\n", + " return P_star, success" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "99c8f5a5", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:22.385504Z", + "start_time": "2025-08-14T14:00:22.174429Z" + }, + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "# Assume constant marginal cost\n", + "mc = pt.scalar(\"marginal_cost\")\n", + "\n", + "optimal_P, success = find_optimal_P(price, expected_sales, mc)\n", + "P_star_fn = pytensor.function([time, price, *params, mc], [optimal_P, success])\n", + "profit_fn = pytensor.function(\n",
+ " [price, observed_sales, mc], observed_sales * (price - mc)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d865efd0", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:22.455209Z", + "start_time": "2025-08-14T14:00:22.395131Z" + }, + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "P_star, success_flag = P_star_fn(\n", + " time_value, np.zeros_like(time_value), **optim_params, marginal_cost=5.0\n", + ")\n", + "success_flag" + ] + }, + { + "cell_type": "markdown", + "id": "2b64e9e5", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Optimal Price\n", + "\n", + "We can see that the optimal price fluctuates with demand. When demand is naturally higher, we should raise prices, and vice-versa. In general, we've been charging too little for the product." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c57d3f4", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:22.586735Z", + "start_time": "2025-08-14T14:00:22.462889Z" + }, + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "plt.plot(P_star, lw=2, label=\"Optimal price\")\n", + "plt.plot(prices_obs, lw=2, label=\"Observed price\")\n", + "plt.legend(fontsize=16)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "09a04631", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "### Optimal Sales\n", + "\n", + "Plugging $P^\\star$ back into the demand function, we see that we've also been selling too much stuff." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36a1a29cb0b4ef4c", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:22.717649Z", + "start_time": "2025-08-14T14:00:22.599125Z" + }, + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "Q_star = sales_fn(time_value, P_star, **optim_params)\n", + "\n", + "plt.plot(Q_star, lw=2, label=\"Sales under optimal price\")\n", + "plt.plot(sales_obs, lw=2, label=\"Observed sales\")\n", + "\n", + "plt.legend(fontsize=16)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "a29de953", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "### Optimal Profits\n", + "\n", + "Given that we solved a maximization problem, we expect the blue curve to be equal to or larger than the orange one. We see that this is indeed the case. In a few periods we happened to be close to the correct price by coincidence, but in general we weren't making as much as we could have been."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3f3d7aaf33cffaf4", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:22.841860Z", + "start_time": "2025-08-14T14:00:22.727940Z" + }, + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "plt.plot(\n", + " profit_fn(P_star, Q_star, marginal_cost=5.0),\n", + " lw=2,\n", + " label=\"Profits under optimal price\",\n", + ")\n", + "plt.plot(\n", + " profit_fn(\n", + " prices_obs, sales_fn(time_value, prices_obs, **optim_params), marginal_cost=5.0\n", + " ),\n", + " lw=2,\n", + " label=\"Observed profits\",\n", + ")\n", + "plt.legend(fontsize=16)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "02cf7ad4", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "# Supporting Special Cases: Linear Demand" + ] + }, + { + "cell_type": "markdown", + "id": "8bda8dae", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "source": [ + "If you are familiar with numerical optimizers, you know they can be fussy. We'd really like to avoid them if we can. In this case we can, because the demand function is just linear. 
Basically, we have:\n", + "\n", + "$$ Q_t = f(t) + \\beta P_t $$\n", + "\n", + "Where $f(t) = \\text{level}_t + \\text{trend}_t + \\text{seasonality}_t$\n", + "\n", + "Substituting that into the profit function:\n", + "\n", + "$$\n", + "\\max_{P_t} \\pi_t = (f(t) + \\beta P_t) (P_t - \\text{mc}_t)\n", + "$$\n", + "\n", + "\n", + "Expand terms:\n", + "\n", + "$$\n", + "\\max_{P_t} \\pi_t = f(t)P_t - f(t) \\text{mc}_t + \\beta P_t^2 - \\beta \\text{mc}_t P_t \n", + "$$\n", + "\n", + "Solve for a first-order condition:\n", + "\n", + "$$\n", + "\\begin{aligned}\n", + "\\frac{\\partial \\pi_t}{\\partial P_t} &= 0 \\Rightarrow \\\\\n", + "f(t) + 2 \\beta P_t - \\beta \\text{mc}_t &= 0 \\\\\n", + "P_t^\\star &= \\frac{\\beta \\text{mc}_t - f(t)}{2\\beta}\n", + "\\end{aligned}\n", + "$$\n", + "\n", + "This is a well-known result from Economics 101, giving the optimal price of a monopoly firm that faces a linear demand function." + ] + }, + { + "cell_type": "markdown", + "id": "dfc5787a", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "## Graph Analysis\n", + "\n", + "To do this, we need to dig a bit deeper into how Pytensor works.\n", + "\n", + "First, every value returned by a pytensor function is itself a *computational graph*. At any point, we can stop and start manipulating that graph.\n", + "\n", + "In this case, the graph we are interested in working with is the derivative of profit with respect to price. We want to identify cases where this derivative is of the form $a + bP = 0$. Then, we can extract $a$ and $b$, and compute $P^\\star = -\\frac{a}{b}$.\n", + "\n", + "This can be complex, though. $a$ and $b$ might themselves be complicated expressions. There are lots of weird little corner cases. What if $b$ is positive, but enters with subtraction, so we have $a - bP$? This is also valid, but we might miss it.\n", + "\n", + "To handle this, Pytensor has a database of transformations called **canonicalization**. 
This transforms a graph into a \"standard form\", so that further rewrites can reason about things in a standard way." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c057bb00", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:22.944364Z", + "start_time": "2025-08-14T14:00:22.902639Z" + }, + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "pi = (expected_sales * (price - mc)).sum()\n", + "dpi_dP = pt.grad(pi, price)\n", + "expr = rewrite_graph(dpi_dP, include=(\"canonicalize\",))" + ] + }, + { + "cell_type": "markdown", + "id": "f1409885", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "Here is a string representation of the derivative, after canonicalization. \n", + "\n", + "We can see that it's linear in price, because the whole thing is just an addition (the outer-most node is an `Add` Op with several inputs). `price [id D]` appears twice, in the first term `Mul(Sub(price, marginal_cost), price_effect))` $\\rightarrow \\beta(P - \\text{mc})$, then in the last term, as `Mul(price_effect, price)` $\\rightarrow \\beta P$\n", + "\n", + "So we just need to do some algebra to get this into our $a + bP$ form." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5cc257a302d8f6b", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:22.957135Z", + "start_time": "2025-08-14T14:00:22.952165Z" + }, + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "expr.dprint(depth=4)" + ] + }, + { + "cell_type": "markdown", + "id": "3090d8ab", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "Recall that we are starting from the form $Q(P - mc)$, where $Q$ itself is a function of $P$. To get to our $a + bP$ form, we need to distribute the multiplication across everything. 
This is additional canonicalization that is not done by default.\n", + "\n", + "We can add it using `PatternNodeRewriter`. This works by pattern matching on a graph. Take the `distribute_mul_over_add` rewrite for example:\n", + "\n", + "```py\n", + "distribute_mul_over_add = PatternNodeRewriter(\n", + " (pt.mul, (pt.add, \"x\", \"y\"), \"z\"),\n", + " (pt.add, (pt.mul, \"z\", \"x\"), (pt.mul, \"z\", \"y\")),\n", + ")\n", + "```\n", + "\n", + "It is going to look for any expression of the form `(x + y) * z`, and replace it with `(zx + zy)`. The rewrite is written as a tuple-encoded abstract syntax tree in prefix notation, where each tuple first specifies an operator, then a list of inputs. \n", + "\n", + "- The first argument is what to look for: a multiplication with two inputs, where the first input is an addition with two inputs\n", + "- The second argument is what to insert: an addition with two inputs, where each input is a multiplication of two inputs.\n", + "\n", + "The letters in the inputs and the outputs are matched and used consistently. So the first input to the input addition ends up as the 2nd input to the first multiplication in the output. \n", + "\n", + "We have to handle distribution over addition and subtraction separately, since these are different `Ops`. We also need a rewrite to collect repeated terms in an addition and replace them with multiplication. `combine_addition_terms` will look for expressions of the form `(x + y) + z + w + q + r + x` and replace them with `2x + y + z + w + q + r`.\n", + "\n", + "**Note!** If all this seems overfit to the specific problem, it is! `PatternNodeRewriter` is powerful, but also limited. In general, we can write rewrite *functions* to inspect and reason about graphs in arbitrary ways, all in pure Python. For the purposes of this tutorial, however, we want to keep things simple, so we only use pattern rewrites. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a5a2215", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "outputs": [], + "source": [ + "distribute_mul_over_sub = PatternNodeRewriter(\n", + " (pt.mul, (pt.sub, \"x\", \"y\"), \"z\"),\n", + " (pt.add, (pt.mul, \"z\", \"x\"), (pt.mul, \"z\", (pt.neg, \"y\"))),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "762f27d7", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:23.063446Z", + "start_time": "2025-08-14T14:00:23.053053Z" + }, + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "combine_addition_terms = PatternNodeRewriter(\n", + " (pt.add, (pt.add, \"x\", \"y\"), \"z\", \"w\", \"q\", \"r\", \"x\"),\n", + " (pt.add, (pt.mul, \"x\", 2), (pt.add, \"y\", \"z\", \"w\", \"q\", \"r\")),\n", + ")\n", + "\n", + "distribute_mul_over_sub = out2in(distribute_mul_over_sub, name=\"distribute_mul_sub\")\n", + "combine_addition_terms = out2in(combine_addition_terms, name=\"combine_addition_terms\")\n", + "\n", + "fgraph = FunctionGraph(outputs=[expr], clone=False)\n", + "\n", + "# Distribute multiplication\n", + "distribute_mul_over_sub.rewrite(fgraph)\n", + "\n", + "# Merge equivalent sub-expressions\n", + "MergeOptimizer().rewrite(fgraph)\n", + "\n", + "# Gather repeated additions into multiplication\n", + "combine_addition_terms.rewrite(fgraph)\n", + "\n", + "# Extract the rewritten expression\n", + "expr = fgraph.outputs[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ad906033", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:23.110862Z", + "start_time": "2025-08-14T14:00:23.106096Z" + }, + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "expr.dprint(depth=4)" + ] + }, + { + "cell_type": "markdown", + "id": "9215219a", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "Then next step is to 
handle those pesky corner cases I mentioned. We want to make sure that we recognize $a + bP$, but also $bP + a$, or $Pb + a$, or $a + Pb$. \n", + "\n", + "The next 3 rewrites just allow inputs to be permuted or regrouped. We know what we are looking for, and we can apply these rewrites in different combinations to do an exhaustive search. That way we can be sure we won't miss anything." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "742b67c8b729da37", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:23.211792Z", + "start_time": "2025-08-14T14:00:23.209182Z" + }, + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "# Create variations of a graph for pattern matching\n", + "rewrites = [\n", + " PatternNodeRewriter((pt.add, \"x\", \"y\"), (pt.add, \"y\", \"x\")),\n", + " PatternNodeRewriter((pt.mul, \"x\", \"y\"), (pt.mul, \"y\", \"x\")),\n", + " PatternNodeRewriter(\n", + " (pt.mul, (pt.mul, \"x\", \"y\"), \"z\"), (pt.mul, \"x\", (pt.mul, \"y\", \"z\"))\n", + " ),\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "0cb4048c", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "The `yield_rewrite_variants` function is very deep into the Pytensor weeds. \n", + "\n", + "It is going to apply the 3 rewrites defined in the `rewrites` list (two commutativity rules and one associativity rule) in different combinations.\n", + "\n", + "If it fails, it is able to rewind the rewrite it tried, so we can try a new form.\n", + "\n", + "It will also hash each unique form of the graph that it has tried, so that we don't end up stuck in an endless loop of meaningless permutations.\n", + "\n", + "For our purposes, it is not necessary that you totally understand the implementation details of what is happening here. Only the high-level concept. 
def yield_rewrite_variants(expr, rewrites, variants_seen=None):
    """Lazily enumerate graphs obtained from ``expr`` by applying ``rewrites``.

    The generator yields ``expr`` itself first. It then walks the nodes of
    the graph in reverse topological order, applies every node rewriter in
    ``rewrites`` to each single-output node, and recurses on every resulting
    graph whose text rendering has not been seen before.

    Parameters
    ----------
    expr
        Output variable of the graph to explore. The ``FunctionGraph`` is
        built with ``clone=False``, so it wraps this graph rather than a copy.
    rewrites
        Iterable of node rewriters exposing ``transform(fgraph, node)``.
    variants_seen : set of str, optional
        Text renderings (``dprint(file="str")``) of graphs already visited.
        It is shared across the recursion and updated in place; a fresh set
        is created when omitted.

    Yields
    ------
    The output variable of each distinct variant, starting with ``expr``.

    Raises
    ------
    ValueError
        If a rewriter returns a ``dict`` replacement.
    """
    yield expr

    if variants_seen is None:
        variants_seen = set()
    # The printed form of the graph doubles as a cheap fingerprint of the variant.
    variants_seen.add(expr.dprint(file="str"))

    history = History()
    fgraph = FunctionGraph(outputs=[expr], clone=False)
    # NOTE(review): checkpoint()/revert() used below are presumably installed
    # on the fgraph by the History feature -- confirm against PyTensor docs.
    fgraph.attach_feature(history)

    # The node order is fixed once, before any replacement is attempted.
    for node in reversed(fgraph.toposort()):
        if len(node.outputs) > 1:
            # Only work with single output nodes
            continue

        # Every rewriter is evaluated against the un-rewritten graph up front.
        replacements = [rewrite.transform(fgraph, node) for rewrite in rewrites]
        for replacement in replacements:
            if not replacement:
                continue
            if isinstance(replacement, dict):
                raise ValueError("Dict replacement not supported")

            # Apply one replacement at a time and recurse from there.
            checkpoint = fgraph.checkpoint()
            fgraph.replace_all(tuple(zip(node.outputs, replacement, strict=True)))
            variant = fgraph.outputs[0]
            if variant.dprint(file="str") not in variants_seen:
                # Try variants on top of this rewrite recursively.
                # NOTE(review): the graph is reverted once the consumer resumes
                # the generator, so a yielded variant should be inspected
                # before asking for the next one -- confirm.
                yield from yield_rewrite_variants(
                    variant, rewrites, variants_seen=variants_seen
                )
            fgraph.revert(checkpoint)  # Go back and try the next branch
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb9e785fa8d12230", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:23.301205Z", + "start_time": "2025-08-14T14:00:23.259125Z" + }, + "scrolled": true, + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "a, b, price_ = var(\"a\"), var(\"b\"), var(\"price\")\n", + "pattern = etuple(pt.add, etuple(pt.mul, price_, b), a)\n", + "\n", + "for variant in yield_rewrite_variants(expr, rewrites):\n", + " match_dict = unify(variant, pattern)\n", + " if match_dict and match_dict[price_] is price:\n", + " break\n", + "else:\n", + " raise ValueError(\"No matching variant found\")\n", + "match_dict" + ] + }, + { + "cell_type": "markdown", + "id": "6569d1d2", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "We found a match, which returns the graphs `a`, `b`, and `P` in our target expression $a + bP$.\n", + "\n", + "As expected, $a$ is a big function of all the non-price terms from the Prophet model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d677b5f", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "match_dict[a].dprint(depth=4)" + ] + }, + { + "cell_type": "markdown", + "id": "5e3b1cf5", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "While $b = 2\\beta$." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "996ac546", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "match_dict[b].dprint()" + ] + }, + { + "cell_type": "markdown", + "id": "fcc50d0c", + "metadata": { + "slideshow": { + "slide_type": "-" + } + }, + "source": [ + "With this in hand, we symbolically compute $\\text{Optimal P} = -\\frac{a}{b}$" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be105626", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-14T14:00:23.337512Z", + "start_time": "2025-08-14T14:00:23.332030Z" + }, + "slideshow": { + "slide_type": "-" + } + }, + "outputs": [], + "source": [ + "symbolic_P_star = -match_dict[a] / match_dict[b]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0b64c65", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "fn_P_star_2 = pytensor.function(\n", + " [time, price, *params, mc], symbolic_P_star, on_unused_input=\"ignore\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e0d50dd", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "analytic_P_star = fn_P_star_2(time_value, prices_obs, **optim_params, marginal_cost=5.0)" + ] + }, + { + "cell_type": "markdown", + "id": "e510ae79", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "As a sanity check, we plot our analytical solution found via rewrites against the numerical solution. As expected, they match!" 
def cannonicalize_sales_expr(expr):
    """Apply the notebook's custom rewrites to ``expr`` and return the result.

    Runs, in order, ``distribute_mul_over_sub``, ``MergeOptimizer`` and
    ``combine_addition_terms`` on a ``FunctionGraph`` built around ``expr``
    with ``clone=False``, then returns the graph's first output.

    The (misspelled) name is kept so existing callers keep working. In this
    notebook the argument is the profit derivative, not the sales expression.
    """
    fgraph = FunctionGraph(outputs=[expr], clone=False)

    distribute_mul_over_sub.rewrite(fgraph)  # distribute multiplication
    MergeOptimizer().rewrite(fgraph)  # merge equivalent sub-expressions
    combine_addition_terms.rewrite(fgraph)  # gather repeated additions
    return fgraph.outputs[0]


def find_exact_linear_solution(expr, price_var=None):
    """Look for ``expr == a + b * P`` and, if found, return ``-a / b``.

    Parameters
    ----------
    expr
        Symbolic first-order condition to analyse.
    price_var : optional
        The variable that must play the role of ``P`` in the match. When
        omitted, the notebook-level ``price`` variable is used, which is what
        this function always compared against before the parameter existed.

    Returns
    -------
    (P_star, success)
        ``P_star`` is the symbolic solution and ``success`` a constant
        ``True`` tensor when a linear form is found; ``(None, None)``
        otherwise.
    """
    if price_var is None:
        price_var = price

    expr = cannonicalize_sales_expr(expr)

    # Unification variables, named like the narrative: a + b * P
    a, b, price_ = var("a"), var("b"), var("price")
    pattern = etuple(pt.add, etuple(pt.mul, price_, b), a)

    for variant in yield_rewrite_variants(expr, rewrites):
        match_dict = unify(variant, pattern)
        # The identity check rejects matches where some other variable lands
        # in the price slot.
        if match_dict and match_dict[price_] is price_var:
            print("Found linear price function, using exact solution!")
            P_star = -match_dict[a] / match_dict[b]
            success = pt.as_tensor(np.array(True))
            return P_star, success

    return None, None


def find_optimal_P_v2(P, Q, mc):
    """Return the symbolic profit-maximising price and a success flag.

    Builds profit ``(Q * (P - mc)).sum()``, differentiates it with respect to
    ``P`` and canonicalizes the derivative. The exact linear solution is tried
    first; if none is found, the derivative is handed to ``root`` with ``P``
    as the variable to solve for.

    Parameters
    ----------
    P
        Price variable to optimise.
    Q
        Symbolic expected sales; expected to depend on ``P``.
    mc
        Marginal cost.

    Returns
    -------
    (P_star, success)
        Symbolic optimal price and symbolic success indicator.
    """
    pi = (Q * (P - mc)).sum()
    dpi_dP = pt.grad(pi, P)

    expr = rewrite_graph(dpi_dP, include=("canonicalize",))

    # Try for the exact solution, matching against the P we were given
    P_star, success = find_exact_linear_solution(expr, price_var=P)

    # If we fail, fall back to a numerical optimizer
    if P_star is None:
        print("No exact solution available, using numerical solver")
        # NOTE(review): `root` is defined outside this view; presumably
        # PyTensor's symbolic root finder -- confirm the import.
        P_star, success = root(expr, P, method="hybr", optimizer_kwargs=dict(tol=1e-8))

    return P_star, success
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c294e153", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "optimal_P, success = find_optimal_P_v2(price, expected_sales, mc)\n", + "P_star_fn = pytensor.function(\n", + " [time, price, *params, mc], [optimal_P, success], on_unused_input=\"ignore\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d685c25", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "P_star, success_flag = P_star_fn(\n", + " time_value, np.zeros_like(time_value), **optim_params, marginal_cost=5.0\n", + ")\n", + "Q_star = sales_fn(time_value, P_star, **optim_params)\n", + "\n", + "plt.plot(\n", + " profit_fn(P_star, Q_star, marginal_cost=5.0),\n", + " lw=2,\n", + " label=\"Profits under optimal price\",\n", + ")\n", + "plt.plot(\n", + " profit_fn(\n", + " prices_obs, sales_fn(time_value, prices_obs, **optim_params), marginal_cost=5.0\n", + " ),\n", + " lw=2,\n", + " label=\"Observed profits\",\n", + ")\n", + "plt.legend(fontsize=16)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "5d3eb601", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "# Handling Complex Scenarios\n", + "\n", + "Again, the power of pytensor is that it can keep handling more and more complex situations without much additional effort.\n", + "\n", + "A common situation in these types of price optimization problems is that decision makers have already committed to a certain price at certain times. So you're not actually able to freely optimize in every period. Instead, you have to do the optimization of certain periods, subject to fixed prices in other periods.\n", + "\n", + "In the final example, we suppose that we are free to optimize prices for 150 periods, between $t=100$ and $t=250$. 
def find_optimal_P_v3(P, Q, mc):
    """Optimise profit with respect to the single free input behind ``P``.

    Profit is ``(Q * (P - mc)).sum()``. Instead of differentiating with
    respect to ``P`` itself, the function looks up the explicit graph inputs
    of ``P`` and optimises with respect to that input. The optimum is then
    substituted back into ``P``, so the returned price has the same structure
    as the ``P`` that was passed in.

    Parameters
    ----------
    P
        Price expression; it must depend on exactly one explicit input.
    Q
        Symbolic expected sales built on top of ``P``.
    mc
        Marginal cost.

    Returns
    -------
    (P_star, success)
        ``P`` with the optimal value of its free input substituted in, and
        the success indicator of whichever solver was used.

    Raises
    ------
    ValueError
        If ``P`` does not depend on exactly one explicit input.
    """
    pi = (Q * (P - mc)).sum()

    # Check which root inputs of P are actually free variables
    free_inputs = list(explicit_graph_inputs([P]))
    if len(free_inputs) != 1:
        raise ValueError(
            "Expected P to depend on exactly one free input, "
            f"found {len(free_inputs)}: {free_inputs}"
        )
    [choice_variable] = free_inputs

    # Whatever we found, optimize it
    dpi_dP = pt.grad(pi, choice_variable)
    expr = rewrite_graph(dpi_dP, include=("canonicalize",))
    P_star, success = find_exact_linear_solution(expr)

    if P_star is None:
        print("No exact solution available, using numerical solver")
        # Note: this variant uses method="lm", unlike find_optimal_P_v2 ("hybr").
        P_star, success = root(
            expr, choice_variable, method="lm", optimizer_kwargs=dict(tol=1e-8)
        )

    # If we optimized with respect to an input to P, substitute the optimal value
    # back into the original P
    P_star = graph_replace(P, {choice_variable: P_star})

    return P_star, success
pt.as_tensor(prices_obs)[100:250].set(free_prices)\n", + "\n", + "sales_partially_fixed = graph_replace(expected_sales, {price: price_partially_fixed})\n", + "\n", + "optimal_P_constrained, success = find_optimal_P_v3(\n", + " price_partially_fixed, sales_partially_fixed, mc\n", + ")\n", + "\n", + "P_star_constrained_fn = pytensor.function(\n", + " [time, price, free_prices, *params, mc],\n", + " [optimal_P_constrained, success],\n", + " on_unused_input=\"ignore\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b2ebd77-0a60-4959-88a7-893ac43c4ef5", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "P_star_constrained, _ = P_star_constrained_fn(\n", + " time_value, prices_obs, np.zeros(150), **optim_params, marginal_cost=5.0\n", + ")\n", + "Q_star = sales_fn(time_value, P_star_constrained, **optim_params)" + ] + }, + { + "cell_type": "markdown", + "id": "de1e6f8e", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "Because we've inserted this `SetSubtensor` Op between the price and the model, our logic for finding $a + bP$ breaks. That fine though, we just drop back to the numerical solver." + ] + }, + { + "cell_type": "markdown", + "id": "bd7aff7a", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "source": [ + "As promised, we now only optimize the period between $t=100$ and $t=250$. Otherwise, we're constrained to accept the observed price." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58db2a42", + "metadata": { + "slideshow": { + "slide_type": "skip" + } + }, + "outputs": [], + "source": [ + "plt.plot(\n", + " profit_fn(P_star_constrained, Q_star, marginal_cost=5.0),\n", + " lw=3,\n", + " label=\"Profit under optimized price\",\n", + ")\n", + "plt.plot(\n", + " profit_fn(\n", + " prices_obs, sales_fn(time_value, prices_obs, **optim_params), marginal_cost=5.0\n", + " ),\n", + " lw=3,\n", + " ls=\"--\",\n", + " label=\"Observed profit\",\n", + ")\n", + "plt.legend(fontsize=16)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "d65e6c91", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "# Summary and Conclusion\n", + "\n", + "- Theano was a unique package, a symbolic math library that got pigeonholed as a deep learning library.\n", + "- The ability to symbolically manipulate programs is extremely powerful, enabling optimization, user-focused APIs, and transpilation.\n", + "- Pytensor continues the legacy of Theano, continuing to add more rewrites and optimizations, improve documentation, extend to new compiled backends, and add entirely new features." + ] + }, + { + "cell_type": "markdown", + "id": "04947ea8", + "metadata": { + "slideshow": { + "slide_type": "subslide" + } + }, + "source": [ + "One exciting example is an xarray-inspired named dimensions API for more human-readable tensor manipulations:\n", + "\n", + "```py\n", + "\n", + "import pytensor.xtensor as ptx\n", + "\n", + "logits = ptx.xtensor(dims=[\"user\", \"choice\"])\n", + "probs = ptx.softmax(logits, dim=\"choice\")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "f5815385", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "# Existing Pytensor Ecosystem\n", + "\n", + "\n", + "- [PyMC](https://github.com/pymc-devs/pymc)\n", + "- [pymc-marketing](https://github.com/pymc-labs/pymc-marketing)\n", + "- 
[CausalPy](https://github.com/pymc-labs/causalpy)\n", + "- [gEconpy](https://github.com/jessegrabowski/gEconpy)\n", + "- [pyhs3](https://github.com/scipp-atlas/pyhs3)\n", + "- [HSSM](https://github.com/lnccbrown/HSSM)\n", + "- [Celmech](https://github.com/shadden/celmech)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28e55d35", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "celltoolbar": "Slideshow", + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}