Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .translate/state/jax_intro.md.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
source-sha: 450bafecd23db638602150b47f4272b98aad3146
synced-at: "2026-04-14"
source-sha: d08a73d48a409509d7d6f6585b99c2c8909c9a28
synced-at: "2026-05-14"
model: claude-sonnet-4-6
mode: UPDATE
section-count: 7
tool-version: 0.14.1
tool-version: 0.15.0
6 changes: 3 additions & 3 deletions .translate/state/numpy_vs_numba_vs_jax.md.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
source-sha: 450bafecd23db638602150b47f4272b98aad3146
synced-at: "2026-04-14"
source-sha: d08a73d48a409509d7d6f6585b99c2c8909c9a28
synced-at: "2026-05-14"
model: claude-sonnet-4-6
mode: UPDATE
section-count: 3
tool-version: 0.14.1
tool-version: 0.15.0
45 changes: 32 additions & 13 deletions lectures/jax_intro.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ translation:
JAX as a NumPy Replacement::Differences::A Workaround: 变通方法
Functional Programming: 函数式编程
Functional Programming::Pure functions: 纯函数
Functional Programming::Examples: 示例
Functional Programming::Examples -- Pure and Impure: 示例——纯函数与非纯函数
Functional Programming::Why Functional Programming?: 为什么要函数式编程?
Random numbers: 随机数
Random numbers::NumPy / MATLAB Approach: NumPy / MATLAB 方法
Expand Down Expand Up @@ -346,19 +346,20 @@ a
* 不会改变全局状态
* 不会修改传递给函数的数据(不可变数据)

### 示例
### 示例——纯函数与非纯函数

以下是一个*非纯*函数的示例:

```{code-cell} ipython3
tax_rate = 0.1
prices = [10.0, 20.0]

def add_tax(prices):
for i, price in enumerate(prices):
prices[i] = price * (1 + tax_rate)
print('Post-tax prices: ', prices)
return prices

prices = [10.0, 20.0]
add_tax(prices)
prices
```

这个函数不是纯函数,因为:
Expand All @@ -369,15 +370,21 @@ def add_tax(prices):
以下是一个*纯*版本:

```{code-cell} ipython3
tax_rate = 0.1
prices = (10.0, 20.0)

def add_tax_pure(prices, tax_rate):
new_prices = [price * (1 + tax_rate) for price in prices]
return new_prices

tax_rate = 0.1
prices = (10.0, 20.0)
after_tax_prices = add_tax_pure(prices, tax_rate)
after_tax_prices
```

这个纯版本通过函数参数使所有依赖关系变得明确,并且不修改任何外部状态。
这是纯函数,因为:

* 所有依赖关系通过函数参数显式传递
* 并且不修改任何外部状态

### 为什么要函数式编程?

Expand Down Expand Up @@ -427,7 +434,7 @@ print(np.random.randn(2))
* 它是非确定性的:相同的输入,不同的输出
* 它有副作用:它修改了全局随机数生成器状态

在并行化下很危险——必须仔细控制每个线程中发生的事情
在并行化下很危险——必须仔细控制每个线程中发生的事情

### JAX

Expand Down Expand Up @@ -544,7 +551,11 @@ plt.show()
下面的函数使用 `split` 生成 `k` 个(准)独立的随机 `n x n` 矩阵。

```{code-cell} ipython3
def gen_random_matrices(key, n=2, k=3):
def gen_random_matrices(
key, # JAX key for random numbers
n=2, # Matrices will be n x n
k=3 # Number of matrices to generate
):
matrices = []
for _ in range(k):
key, subkey = jax.random.split(key)
Expand All @@ -566,7 +577,7 @@ gen_random_matrices(key)

### 好处

JAX 的显式性带来了显著的好处
如上所述,这种显式性是很有价值的

* 可复现性:通过重用密钥轻松重现结果
* 并行化:控制各个线程上发生的事情
Expand Down Expand Up @@ -647,7 +658,14 @@ with qe.Timer():

结果与 `cos` 示例类似——JAX 更快,尤其是在 JIT 编译后的第二次运行中。

但我们仍在使用即时执行——大量内存和读写开销。
这是因为单个数组操作在 GPU 上并行化了。

但我们仍在使用即时执行:

* 由于中间数组导致大量内存占用
* 大量内存读写

此外,GPU 上还会启动许多独立的内核。

### 编译整个函数

Expand Down Expand Up @@ -681,7 +699,8 @@ with qe.Timer():

* 基于整个计算序列的积极优化
* 消除对硬件加速器的多次调用
* 不创建中间数组

内存占用也大大降低——不再创建中间数组。

顺便提一下,当针对 JIT 编译器的函数时,更常见的语法是:

Expand Down
111 changes: 64 additions & 47 deletions lectures/numpy_vs_numba_vs_jax.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ translation:
Vectorized operations: 向量化运算
Vectorized operations::Problem Statement: 问题陈述
Vectorized operations::NumPy vectorization: NumPy 向量化
Vectorized operations::Memory Issues: 内存问题
Vectorized operations::A Comparison with Numba: 与 Numba 的比较
Vectorized operations::Parallelized Numba: 并行化的 Numba
Vectorized operations::Vectorized code with JAX: 使用 JAX 的向量化代码
Expand Down Expand Up @@ -152,16 +153,33 @@ for x in grid:

让我们切换到 NumPy 并使用更大的网格。

```{code-cell} ipython3
grid = np.linspace(-3, 3, 3_000) # Large grid
```

作为向量化的第一步,我们可能会尝试这样的方式

```{code-cell} ipython3
# Large grid
z = np.max(f(grid, grid)) # This is wrong!
```

这里的问题是 `f(grid, grid)` 并不遵循嵌套循环。

从上图来看,它只计算了对角线上的 `f` 值。

要让 NumPy 在每个 `x,y` 对上计算 `f(x,y)`,我们需要使用 `np.meshgrid`。

这里我们使用 `np.meshgrid` 来创建二维输入网格 `x` 和 `y`,使得 `f(x, y)` 能生成乘积网格上的所有计算结果。

```{code-cell} ipython3
# Large grid
grid = np.linspace(-3, 3, 3_000)

x, y = np.meshgrid(grid, grid) # MATLAB style meshgrid
x_mesh, y_mesh = np.meshgrid(grid, grid) # MATLAB style meshgrid

with qe.Timer():
z_max_numpy = np.max(f(x, y))
z_max_numpy = np.max(f(x_mesh, y_mesh)) # This works
```

在向量化版本中,所有循环都在编译后的代码中执行。
Expand All @@ -174,9 +192,29 @@ with qe.Timer():
print(f"NumPy result: {z_max_numpy:.6f}")
```

### 内存问题

我们在合理的时间内得到了正确的解——但内存使用量非常大。

虽然扁平数组占用内存较少

```{code-cell} ipython3
grid.nbytes
```

但网格矩阵是二维的,因此内存占用非常大

```{code-cell} ipython3
x_mesh.nbytes + y_mesh.nbytes
```

此外,NumPy 的即时执行会创建许多相同大小的中间数组!

在实际研究计算中,这种内存使用可能是一个大问题。

### 与 Numba 的比较

现在让我们看看能否使用简单循环的 Numba 获得更好的性能。
让我们看看能否使用简单循环的 Numba 获得更好的性能。

```{code-cell} ipython3
@numba.jit
Expand Down Expand Up @@ -207,13 +245,13 @@ with qe.Timer():
compute_max_numba(grid)
```

根据您的机器,Numba 版本可能比 NumPy 稍慢或稍快
注意我们几乎不使用任何内存——我们只需要一维的 `grid`

在大多数情况下,我们发现 Numba 略胜一筹
此外,执行速度也很好

一方面,NumPy 将高效的算术运算与一定程度的多线程结合在一起,这提供了优势
在大多数机器上,Numba 版本会比 NumPy 稍快一些

另一方面,Numba 例程使用的内存少得多,因为我们只处理一个一维网格
原因是高效的机器码加上更少的内存读写

### 并行化的 Numba

Expand Down Expand Up @@ -307,39 +345,25 @@ with qe.Timer():

### JAX 加 vmap

NumPy 代码和上述 JAX 代码都存在一个问题:

虽然扁平数组占用内存较少

```{code-cell} ipython3
grid.nbytes
```

但网格矩阵的内存占用很大
由于我们在上面使用了 `jax.jit`,我们避免了创建许多中间数组。

```{code-cell} ipython3
x_mesh.nbytes + y_mesh.nbytes
```
但我们仍然创建了大数组 `z_max`、`x_mesh` 和 `y_mesh`。

在实际研究计算中,这种额外的内存使用可能是一个大问题。

幸运的是,JAX 提供了一种使用 [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) 的不同方法。

`vmap` 的思路是将向量化分阶段进行,将一个对单个值进行操作的函数转化为对数组进行操作的函数。
幸运的是,我们可以通过使用 [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) 来避免这一问题。

以下是我们将其应用于当前问题的方式。

```{code-cell} ipython3
@jax.jit
def compute_max_vmap(grid):
# 构建一个对给定 y,在所有 x 上取最大值的函数
f_vec_x_max = lambda y: jnp.max(f(grid, y))
compute_column_max = lambda y: jnp.max(f(grid, y))
# 向量化该函数,以便我们可以同时对所有 y 调用
f_vec_max = jax.vmap(f_vec_x_max)
# 在每个 y 处计算所有 x 上的最大值
maxes = f_vec_max(grid)
# 计算最大值的最大值并返回
return jnp.max(maxes)
vectorized_compute_column_max = jax.vmap(compute_column_max)
# 在每一行处计算列最大值
column_maxes = vectorized_compute_column_max(grid)
# 计算列最大值的最大值并返回
return jnp.max(column_maxes)
```

注意我们从不创建
Expand All @@ -348,6 +372,8 @@ def compute_max_vmap(grid):
* 二维网格 `y_mesh` 或
* 二维数组 `f(x,y)`

与 Numba 类似,我们只使用扁平数组 `grid`。

并且由于所有内容都在单个 `@jax.jit` 下,编译器可以将所有操作融合为一个优化的内核。

让我们试试。
Expand Down Expand Up @@ -378,13 +404,11 @@ with qe.Timer():

它在速度(通过 JIT 编译和并行化)和内存效率(通过 vmap)两方面都优于 NumPy。

此外,`vmap` 方法有时可以带来更清晰的代码
在 GPU 上运行时,它也优于 Numba

虽然 Numba 令人印象深刻,但 JAX 的优势在于,对于完全向量化的运算,我们可以在配备硬件加速器的机器上运行完全相同的代码,并在无需额外努力的情况下获得所有收益。

此外,JAX 已经知道如何有效地并行化许多常见的数组运算,这是快速执行的关键。

对于经济学、计量经济学和金融学中遇到的大多数情况,将高效并行化的工作交给 JAX 编译器,远比尝试手工编写这些例程要好得多。
```{note}
Numba 可以通过 `numba.cuda` 支持 GPU 编程,但这样我们需要手动进行并行化。对于经济学、计量经济学和金融学中遇到的大多数情况,将高效并行化的工作交给 JAX 编译器,远比尝试手工编写这些例程要好得多。
```

## 顺序运算

Expand Down Expand Up @@ -536,8 +560,6 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然

虽然 JAX 的 `at[t].set` 语法确实允许逐元素更新,但整体代码仍然比 Numba 等价版本更难阅读。

对于这类顺序运算,在代码清晰度和实现便利性方面,Numba 是明显的赢家。

## 总体建议

让我们退一步,总结一下各方的权衡取舍。
Expand All @@ -550,17 +572,12 @@ Numba 版本简单直观,易于阅读:我们只需分配一个数组,然

此外,JAX 函数支持自动微分,我们将在 {doc}`autodiff` 中进一步探讨。

对于**顺序操作**,Numba 具有明显优势
对于**顺序操作**,Numba 具有更简洁的语法

代码自然且可读——只需一个带有装饰器的 Python 循环——性能也非常出色。

JAX 可以通过 `lax.scan` 处理顺序问题,但语法不够直观。

```{note}
`lax.scan` 的一个重要优势是它支持通过循环进行自动微分,而 Numba 无法做到这一点。
如果您需要对顺序计算进行微分(例如,计算轨迹对模型参数的敏感性),尽管语法不够自然,JAX 仍是更好的选择。
```
JAX 可以通过 `lax.fori_loop` 或 `lax.scan` 处理顺序问题,但语法不够直观。

在实践中,许多问题涉及两种模式的混合
另一方面,JAX 版本支持自动微分

一个实用的经验法则是:对于新项目默认使用 JAX,尤其是当硬件加速或可微分性可能有用时,而当您有一个需要快速且可读的紧凑顺序循环时,则选用 Numba
例如,当我们希望计算轨迹对模型参数的敏感性时,这可能会很有用
Loading