diff --git a/.translate/state/jax_intro.md.yml b/.translate/state/jax_intro.md.yml index 4f0ca12..6d393a2 100644 --- a/.translate/state/jax_intro.md.yml +++ b/.translate/state/jax_intro.md.yml @@ -1,6 +1,6 @@ -source-sha: 450bafecd23db638602150b47f4272b98aad3146 -synced-at: "2026-04-14" +source-sha: d08a73d48a409509d7d6f6585b99c2c8909c9a28 +synced-at: "2026-05-14" model: claude-sonnet-4-6 mode: UPDATE section-count: 7 -tool-version: 0.14.1 +tool-version: 0.15.0 diff --git a/.translate/state/numpy_vs_numba_vs_jax.md.yml b/.translate/state/numpy_vs_numba_vs_jax.md.yml index e904fbe..66ae445 100644 --- a/.translate/state/numpy_vs_numba_vs_jax.md.yml +++ b/.translate/state/numpy_vs_numba_vs_jax.md.yml @@ -1,6 +1,6 @@ -source-sha: 450bafecd23db638602150b47f4272b98aad3146 -synced-at: "2026-04-14" +source-sha: d08a73d48a409509d7d6f6585b99c2c8909c9a28 +synced-at: "2026-05-14" model: claude-sonnet-4-6 mode: UPDATE section-count: 3 -tool-version: 0.14.1 +tool-version: 0.15.0 diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 49e6cee..a2af66b 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -24,7 +24,7 @@ translation: JAX as a NumPy Replacement::Differences::A Workaround: 变通方法 Functional Programming: 函数式编程 Functional Programming::Pure functions: 纯函数 - Functional Programming::Examples: 示例 + Functional Programming::Examples -- Pure and Impure: 示例——纯函数与非纯函数 Functional Programming::Why Functional Programming?: 为什么要函数式编程? Random numbers: 随机数 Random numbers::NumPy / MATLAB Approach: NumPy / MATLAB 方法 @@ -346,19 +346,20 @@ a * 不会改变全局状态 * 不会修改传递给函数的数据(不可变数据) -### 示例 +### 示例——纯函数与非纯函数 以下是一个*非纯*函数的示例: ```{code-cell} ipython3 tax_rate = 0.1 -prices = [10.0, 20.0] def add_tax(prices): for i, price in enumerate(prices): prices[i] = price * (1 + tax_rate) - print('Post-tax prices: ', prices) - return prices + +prices = [10.0, 20.0] +add_tax(prices) +prices ``` 这个函数不是纯函数,因为: @@ -369,15 +370,21 @@ def add_tax(prices): 以下是一个*纯*版本: ```{code-cell} ipython3 -tax_rate = 0.1 -prices = (10.0, 20.0) def add_tax_pure(prices, tax_rate): new_prices = [price * (1 + tax_rate) for price in prices] return new_prices + +tax_rate = 0.1 +prices = (10.0, 20.0) +after_tax_prices = add_tax_pure(prices, tax_rate) +after_tax_prices ``` -这个纯版本通过函数参数使所有依赖关系变得明确,并且不修改任何外部状态。 +这是纯函数,因为: + +* 所有依赖关系通过函数参数显式传递 +* 并且不修改任何外部状态 ### 为什么要函数式编程? @@ -427,7 +434,7 @@ print(np.random.randn(2)) * 它是非确定性的:相同的输入,不同的输出 * 它有副作用:它修改了全局随机数生成器状态 -在并行化下很危险——必须仔细控制每个线程中发生的事情! +在并行化下很危险——必须仔细控制每个线程中发生的事情。 ### JAX @@ -544,7 +551,11 @@ plt.show() 下面的函数使用 `split` 生成 `k` 个(准)独立的随机 `n x n` 矩阵。 ```{code-cell} ipython3 -def gen_random_matrices(key, n=2, k=3): +def gen_random_matrices( + key, # JAX key for random numbers + n=2, # Matrices will be n x n + k=3 # Number of matrices to generate + ): matrices = [] for _ in range(k): key, subkey = jax.random.split(key) @@ -566,7 +577,7 @@ gen_random_matrices(key) ### 好处 -JAX 的显式性带来了显著的好处: +如上所述,这种显式性是很有价值的: * 可复现性:通过重用密钥轻松重现结果 * 并行化:控制各个线程上发生的事情 @@ -647,7 +658,14 @@ with qe.Timer(): 结果与 `cos` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。 -但我们仍在使用即时执行——大量内存和读写开销。 +这是因为单个数组操作在 GPU 上并行化了。 + +但我们仍在使用即时执行: + +* 由于中间数组导致大量内存占用 +* 大量内存读写 + +此外,GPU 上还会启动许多独立的内核。 ### 编译整个函数 @@ -681,7 +699,8 @@ with qe.Timer(): * 基于整个计算序列的积极优化 * 消除对硬件加速器的多次调用 -* 不创建中间数组 + +内存占用也大大降低——不再创建中间数组。 顺便提一下,当针对 JIT 编译器的函数时,更常见的语法是: diff --git a/lectures/numpy_vs_numba_vs_jax.md b/lectures/numpy_vs_numba_vs_jax.md index 5621d66..560b836 100644 --- a/lectures/numpy_vs_numba_vs_jax.md +++ b/lectures/numpy_vs_numba_vs_jax.md @@ -13,6 +13,7 @@ translation: Vectorized operations: 向量化运算 Vectorized operations::Problem Statement: 问题陈述 Vectorized operations::NumPy vectorization: NumPy 向量化 + Vectorized operations::Memory Issues: 内存问题 Vectorized operations::A Comparison with Numba: 与 Numba 的比较 Vectorized operations::Parallelized Numba: 并行化的 Numba Vectorized operations::Vectorized code with JAX: 使用 JAX 的向量化代码 @@ -152,16 +153,33 @@ for x in grid: 让我们切换到 NumPy 并使用更大的网格。 +```{code-cell} ipython3 +grid = np.linspace(-3, 3, 3_000) # Large grid +``` + +作为向量化的第一步,我们可能会尝试这样的方式 + +```{code-cell} ipython3 +# Large grid +z = np.max(f(grid, grid)) # This is wrong! +``` + +这里的问题是 `f(grid, grid)` 并不遵循嵌套循环。 + +从上图来看,它只计算了对角线上的 `f` 值。 + +要让 NumPy 在每个 `x,y` 对上计算 `f(x,y)`,我们需要使用 `np.meshgrid`。 + 这里我们使用 `np.meshgrid` 来创建二维输入网格 `x` 和 `y`,使得 `f(x, y)` 能生成乘积网格上的所有计算结果。 ```{code-cell} ipython3 # Large grid grid = np.linspace(-3, 3, 3_000) -x, y = np.meshgrid(grid, grid) # MATLAB style meshgrid +x_mesh, y_mesh = np.meshgrid(grid, grid) # MATLAB style meshgrid with qe.Timer(): - z_max_numpy = np.max(f(x, y)) + z_max_numpy = np.max(f(x_mesh, y_mesh)) # This works ``` 在向量化版本中,所有循环都在编译后的代码中执行。 @@ -174,9 +192,29 @@ with qe.Timer(): print(f"NumPy result: {z_max_numpy:.6f}") ``` +### 内存问题 + +我们在合理的时间内得到了正确的解——但内存使用量非常大。 + +虽然扁平数组占用内存较少 + +```{code-cell} ipython3 +grid.nbytes +``` + +但网格矩阵是二维的,因此内存占用非常大 + +```{code-cell} ipython3 +x_mesh.nbytes + y_mesh.nbytes +``` + +此外,NumPy 的即时执行会创建许多相同大小的中间数组! + +在实际研究计算中,这种内存使用可能是一个大问题。 + ### 与 Numba 的比较 -现在让我们看看能否使用简单循环的 Numba 获得更好的性能。 +让我们看看能否使用简单循环的 Numba 获得更好的性能。 ```{code-cell} ipython3 @numba.jit @@ -207,13 +245,13 @@ with qe.Timer(): compute_max_numba(grid) ``` -根据您的机器,Numba 版本可能比 NumPy 稍慢或稍快。 +注意我们几乎不使用任何内存——我们只需要一维的 `grid`。 -在大多数情况下,我们发现 Numba 略胜一筹。 +此外,执行速度也很好。 -一方面,NumPy 将高效的算术运算与一定程度的多线程结合在一起,这提供了优势。 +在大多数机器上,Numba 版本会比 NumPy 稍快一些。 -另一方面,Numba 例程使用的内存少得多,因为我们只处理一个一维网格。 +原因是高效的机器码加上更少的内存读写。 ### 并行化的 Numba @@ -307,25 +345,11 @@ with qe.Timer(): ### JAX 加 vmap -NumPy 代码和上述 JAX 代码都存在一个问题: - -虽然扁平数组占用内存较少 - -```{code-cell} ipython3 -grid.nbytes -``` - -但网格矩阵的内存占用很大 +由于我们在上面使用了 `jax.jit`,我们避免了创建许多中间数组。 -```{code-cell} ipython3 -x_mesh.nbytes + y_mesh.nbytes -``` +但我们仍然创建了大数组 `z_max`、`x_mesh` 和 `y_mesh`。 -在实际研究计算中,这种额外的内存使用可能是一个大问题。 - -幸运的是,JAX 提供了一种使用 [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) 的不同方法。 - -`vmap` 的思路是将向量化分阶段进行,将一个对单个值进行操作的函数转化为对数组进行操作的函数。 +幸运的是,我们可以通过使用 [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) 来避免这一问题。 以下是我们将其应用于当前问题的方式。 @@ -333,13 +357,13 @@ x_mesh.nbytes + y_mesh.nbytes @jax.jit def compute_max_vmap(grid): # 构建一个对给定 y,在所有 x 上取最大值的函数 - f_vec_x_max = lambda y: jnp.max(f(grid, y)) + compute_column_max = lambda y: jnp.max(f(grid, y)) # 向量化该函数,以便我们可以同时对所有 y 调用 - f_vec_max = jax.vmap(f_vec_x_max) - # 在每个 y 处计算所有 x 上的最大值 - maxes = f_vec_max(grid) - # 计算最大值的最大值并返回 - return jnp.max(maxes) + vectorized_compute_column_max = jax.vmap(compute_column_max) + # 在每一行处计算列最大值 + column_maxes = vectorized_compute_column_max(grid) + # 计算列最大值的最大值并返回 + return jnp.max(column_maxes) ``` 注意我们从不创建 @@ -348,6 +372,8 @@ def compute_max_vmap(grid): * 二维网格 `y_mesh` 或 * 二维数组 `f(x,y)` +与 Numba 类似,我们只使用扁平数组 `grid`。 + 并且由于所有内容都在单个 `@jax.jit` 下,编译器可以将所有操作融合为一个优化的内核。 让我们试试。 @@ -378,13 +404,11 @@ with qe.Timer(): 它在速度(通过 JIT 编译和并行化)和内存效率(通过 vmap)两方面都优于 NumPy。 -此外,`vmap` 方法有时可以带来更清晰的代码。 +在 GPU 上运行时,它也优于 Numba。 -虽然 Numba 令人印象深刻,但 JAX 的优势在于,对于完全向量化的运算,我们可以在配备硬件加速器的机器上运行完全相同的代码,并在无需额外努力的情况下获得所有收益。 - -此外,JAX 已经知道如何有效地并行化许多常见的数组运算,这是快速执行的关键。 - -对于经济学、计量经济学和金融学中遇到的大多数情况,将高效并行化的工作交给 JAX 编译器,远比尝试手工编写这些例程要好得多。 +```{note} +Numba 可以通过 `numba.cuda` 支持 GPU 编程,但这样我们需要手动进行并行化。对于经济学、计量经济学和金融学中遇到的大多数情况,将高效并行化的工作交给 JAX 编译器,远比尝试手工编写这些例程要好得多。 +``` ## 顺序运算 @@ -536,8 +560,6 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然 虽然 JAX 的 `at[t].set` 语法确实允许逐元素更新,但整体代码仍然比 Numba 等价版本更难阅读。 -对于这类顺序运算,在代码清晰度和实现便利性方面,Numba 是明显的赢家。 - ## 总体建议 让我们退一步,总结一下各方的权衡取舍。 @@ -550,17 +572,12 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然 此外,JAX 函数支持自动微分,我们将在 {doc}`autodiff` 中进一步探讨。 -对于**顺序操作**,Numba 具有明显优势。 +对于**顺序操作**,Numba 具有更简洁的语法。 代码自然且可读——只需一个带有装饰器的 Python 循环——性能也非常出色。 -JAX 可以通过 `lax.scan` 处理顺序问题,但语法不够直观。 - -```{note} -`lax.scan` 的一个重要优势是它支持通过循环进行自动微分,而 Numba 无法做到这一点。 -如果您需要对顺序计算进行微分(例如,计算轨迹对模型参数的敏感性),尽管语法不够自然,JAX 仍是更好的选择。 -``` +JAX 可以通过 `lax.fori_loop` 或 `lax.scan` 处理顺序问题,但语法不够直观。 -在实践中,许多问题涉及两种模式的混合。 +另一方面,JAX 版本支持自动微分。 -一个实用的经验法则是:对于新项目默认使用 JAX,尤其是当硬件加速或可微分性可能有用时,而当您有一个需要快速且可读的紧凑顺序循环时,则选用 Numba。 +例如,当我们希望计算轨迹对模型参数的敏感性时,这可能会很有用。