Skip to content

Commit d46baa1

Browse files
committed
Update translation: lectures/numpy_vs_numba_vs_jax.md
1 parent 641af6c commit d46baa1

1 file changed

Lines changed: 38 additions & 22 deletions

File tree

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ translation:
2424
Sequential operations::Numba Version: نسخه Numba
2525
Sequential operations::JAX Version: نسخه JAX
2626
Sequential operations::Summary: خلاصه
27+
Overall recommendations: توصیه‌های کلی
2728
---
2829

2930
(parallel)=
@@ -69,13 +70,17 @@ tags: [hide-output]
6970

7071
```{code-cell} ipython3
7172
import random
73+
from functools import partial
74+
7275
import numpy as np
76+
import numba
7377
import quantecon as qe
7478
import matplotlib.pyplot as plt
7579
from mpl_toolkits.mplot3d.axes3d import Axes3D
7680
from matplotlib import cm
7781
import jax
7882
import jax.numpy as jnp
83+
from jax import lax
7984
```
8085

8186
## عملیات برداری شده
@@ -113,7 +118,7 @@ ax.plot_surface(x,
113118
y,
114119
f(x, y),
115120
rstride=2, cstride=2,
116-
cmap=cm.jet,
121+
cmap=cm.viridis,
117122
alpha=0.7,
118123
linewidth=0.25)
119124
ax.set_zlim(-0.5, 1.0)
@@ -139,7 +144,6 @@ for x in grid:
139144
m = z
140145
```
141146

142-
143147
### برداری‌سازی NumPy
144148

145149
اگر به برداری‌سازی به سبک NumPy تغییر دهیم، می‌توانیم از یک شبکه بسیار بزرگتر استفاده کنیم و کد نسبتاً سریع اجرا می‌شود.
@@ -164,14 +168,11 @@ print(f"NumPy result: {z_max_numpy:.6f}")
164168

165169
(موازی‌سازی نمی‌تواند بسیار کارآمد باشد زیرا فایل باینری قبل از اینکه اندازه آرایه‌های `x` و `y` را ببیند کامپایل می‌شود.)
166170

167-
168171
### مقایسه با Numba
169172

170173
حالا بیایید ببینیم آیا می‌توانیم با استفاده از Numba با یک حلقه ساده به عملکرد بهتری دست یابیم.
171174

172175
```{code-cell} ipython3
173-
import numba
174-
175176
@numba.jit
176177
def compute_max_numba(grid):
177178
m = -np.inf
@@ -185,9 +186,9 @@ def compute_max_numba(grid):
185186
grid = np.linspace(-3, 3, 3_000)
186187
187188
with qe.Timer(precision=8):
188-
z_max_numpy = compute_max_numba(grid)
189+
z_max_numba = compute_max_numba(grid)
189190
190-
print(f"Numba result: {z_max_numpy:.6f}")
191+
print(f"Numba result: {z_max_numba:.6f}")
191192
```
192193

193194
بیایید دوباره اجرا کنیم تا زمان کامپایل حذف شود.
@@ -203,7 +204,6 @@ with qe.Timer(precision=8):
203204

204205
از طرف دیگر، روال Numba از حافظه بسیار کمتری استفاده می‌کند، زیرا ما فقط با یک شبکه یک‌بعدی کار می‌کنیم.
205206

206-
207207
### Numba موازی شده
208208

209209
حالا بیایید موازی‌سازی با Numba را با استفاده از `prange` امتحان کنیم:
@@ -278,7 +278,6 @@ with qe.Timer(precision=8):
278278

279279
برای دستگاه‌های قدرتمندتر و اندازه‌های شبکه بزرگتر، موازی‌سازی می‌تواند افزایش سرعت قابل توجهی ایجاد کند، حتی روی CPU.
280280

281-
282281
### کد برداری شده با JAX
283282

284283
در ظاهر، کد برداری شده در JAX شبیه به کد NumPy است.
@@ -299,7 +298,7 @@ def f(x, y):
299298

300299
```{code-cell} ipython3
301300
grid = jnp.linspace(-3, 3, 3_000)
302-
x_mesh, y_mesh = np.meshgrid(grid, grid)
301+
x_mesh, y_mesh = jnp.meshgrid(grid, grid)
303302
304303
with qe.Timer(precision=8):
305304
z_max = jnp.max(f(x_mesh, y_mesh))
@@ -316,11 +315,10 @@ with qe.Timer(precision=8):
316315
z_max.block_until_ready()
317316
```
318317

319-
پس از کامپایل، JAX به دلیل شتاب GPU به طور قابل توجهی سریعتر از NumPy است.
318+
پس از کامپایل، JAX به ویژه روی GPU به طور قابل توجهی سریعتر از NumPy است.
320319

321320
سربار کامپایل یک هزینه یک‌بار مصرف است که زمانی که تابع به طور مکرر فراخوانی می‌شود، بازگشت سرمایه دارد.
322321

323-
324322
### JAX به علاوه vmap
325323

326324
یک مشکل با کد NumPy و کد JAX وجود دارد:
@@ -382,7 +380,6 @@ with qe.Timer(precision=8):
382380

383381
ما این ایده‌ها را بیشتر هنگام حل مسائل بزرگتر بررسی خواهیم کرد.
384382

385-
386383
### نسخه 2 vmap
387384

388385
می‌توانیم با استفاده از vmap همچنان کارآمدتر از نظر حافظه باشیم.
@@ -417,7 +414,7 @@ def compute_max_vmap_v2(grid):
417414
with qe.Timer(precision=8):
418415
z_max = compute_max_vmap_v2(grid).block_until_ready()
419416
420-
print(f"JAX vmap v1 result: {z_max:.6f}")
417+
print(f"JAX vmap v2 result: {z_max:.6f}")
421418
```
422419

423420
بیایید دوباره اجرا کنیم تا زمان کامپایل حذف شود:
@@ -429,7 +426,6 @@ with qe.Timer(precision=8):
429426

430427
اگر این را روی GPU اجرا می‌کنید، همانطور که ما این کار را می‌کنیم، باید افزایش سرعت قابل توجه دیگری را ببینید.
431428

432-
433429
### خلاصه
434430

435431
به نظر ما، JAX برنده برای عملیات برداری شده است.
@@ -444,7 +440,6 @@ with qe.Timer(precision=8):
444440

445441
برای اکثر موارد مواجه شده در اقتصاد، اقتصادسنجی و امور مالی، بسیار بهتر است که برای موازی‌سازی کارآمد به کامپایلر JAX تحویل دهیم تا اینکه سعی کنیم این روال‌ها را خودمان کدنویسی دستی کنیم.
446442

447-
448443
## عملیات ترتیبی
449444

450445
برخی عملیات ذاتاً ترتیبی هستند -- و از این رو برداری کردن آنها دشوار یا غیرممکن است.
@@ -453,7 +448,6 @@ with qe.Timer(precision=8):
453448

454449
برای مقایسه این انتخاب‌ها، مسئله تکرار روی نقشه درجه دوم را که در {doc}`سخنرانی Numba <numba>` خود دیدیم، دوباره بررسی خواهیم کرد.
455450

456-
457451
### نسخه Numba
458452

459453
در اینجا نسخه Numba آمده است.
@@ -497,9 +491,6 @@ Numba این عملیات ترتیبی را به طور بسیار کارآمد
497491
(ما `n` را ایستا نگه می‌داریم زیرا بر اندازه آرایه تأثیر می‌گذارد و از این رو JAX می‌خواهد روی مقدار آن در کد کامپایل شده تخصصی شود.)
498492

499493
```{code-cell} ipython3
500-
from jax import lax
501-
from functools import partial
502-
503494
cpu = jax.devices("cpu")[0]
504495
505496
@partial(jax.jit, static_argnums=(1,), device=cpu)
@@ -542,7 +533,6 @@ JAX نیز برای این عملیات ترتیبی کاملاً کارآمد
542533

543534
هم JAX و هم Numba عملکرد قوی پس از کامپایل ارائه می‌دهند، با این که Numba معمولاً (اما نه همیشه) سرعت‌های کمی بهتری در عملیات کاملاً ترتیبی ارائه می‌دهد.
544535

545-
546536
### خلاصه
547537

548538
در حالی که هم Numba و هم JAX عملکرد قوی برای عملیات ترتیبی ارائه می‌دهند، *تفاوت‌های قابل توجهی در خوانایی کد و سهولت استفاده وجود دارد*.
@@ -555,4 +545,30 @@ JAX نیز برای این عملیات ترتیبی کاملاً کارآمد
555545

556546
علاوه بر این، آرایه‌های تغییرناپذیر JAX به این معنی است که نمی‌توانیم به سادگی عناصر آرایه را در جا به‌روزرسانی کنیم و تکرار مستقیم الگوریتم مورد استفاده توسط Numba را سخت می‌کند.
557547

558-
برای این نوع عملیات ترتیبی، Numba برنده واضح از نظر وضوح کد و سهولت پیاده‌سازی، و همچنین عملکرد بالا است.
548+
برای این نوع عملیات ترتیبی، Numba برنده واضح از نظر وضوح کد و سهولت پیاده‌سازی، و همچنین عملکرد بالا است.
549+
550+
## توصیه‌های کلی
551+
552+
حال قدمی به عقب بر می‌داریم و مبادلات را خلاصه می‌کنیم.
553+
554+
برای **عملیات برداری‌سازی‌شده**، JAX قوی‌ترین انتخاب است.
555+
556+
به لطف کامپایل JIT و موازی‌سازی کارآمد روی CPU و GPU، در سرعت با NumPy برابری می‌کند یا از آن پیشی می‌گیرد.
557+
558+
تبدیل `vmap` مصرف حافظه را کاهش می‌دهد و اغلب نسبت به برداری‌سازی سنتی مبتنی بر meshgrid، کد روشن‌تری ارائه می‌دهد.
559+
560+
علاوه بر این، توابع JAX به‌صورت خودکار مشتق‌پذیر هستند، همان‌طور که در {doc}`autodiff` بررسی می‌کنیم.
561+
562+
برای **عملیات ترتیبی**، Numba مزایای آشکاری دارد.
563+
564+
کد طبیعی و خوانا است --- صرفاً یک حلقه پایتون با یک decorator --- و کارایی آن عالی است.
565+
566+
JAX می‌تواند مسائل ترتیبی را از طریق `lax.scan` مدیریت کند، اما نحو آن کمتر شهودی است و برای کارهای کاملاً ترتیبی، بهره‌وری اضافی ناچیز است.
567+
568+
با این حال، `lax.scan` یک مزیت مهم دارد: از مشتق‌گیری خودکار در طول حلقه پشتیبانی می‌کند، که Numba قادر به انجام آن نیست.
569+
570+
اگر نیاز دارید از طریق یک محاسبه ترتیبی مشتق بگیرید (مثلاً محاسبه حساسیت‌های یک مسیر نسبت به پارامترهای مدل)، JAX علی‌رغم نحو کمتر طبیعی‌اش، انتخاب بهتری است.
571+
572+
در عمل، بسیاری از مسائل ترکیبی از هر دو الگو هستند.
573+
574+
یک قاعده سرانگشتی مناسب: برای پروژه‌های جدید، به‌ویژه زمانی که شتاب‌دهی سخت‌افزاری یا مشتق‌پذیری ممکن است مفید باشد، به‌طور پیش‌فرض از JAX استفاده کنید، و هنگامی که یک حلقه ترتیبی فشرده نیاز به سرعت و خوانایی دارد، به Numba متوسل شوید.

0 commit comments

Comments
 (0)