JAX
def square(x):
XLA сприяє:
Основні переважні аспекти JAX:
<div style="background:#fff4e5; border-left:6px solid #f39c12; padding:12px; margin:12px 0;">
{{SEO
|title=JAX — Python-бібліотека для високопродуктивних обчислень, automatic differentiation, NumPy API і машинного навчання
|description=JAX — Wiki-стаття про Python-бібліотеку для високопродуктивних числових обчислень, automatic differentiation, JIT-компіляції, NumPy-подібного API, GPU/TPU-прискорення і machine learning. Розглянуто jax.numpy, grad, jit, vmap, pmap, XLA, pure functions, immutable arrays, PRNG, JAX ecosystem, Flax, Optax, Haiku, Equinox, переваги, обмеження, безпеку і відповідальне використання.
|keywords=JAX, jax.numpy, jnp, Google JAX, Python JAX, automatic differentiation, autograd, jit, vmap, pmap, XLA, GPU, TPU, NumPy API, machine learning, deep learning, high-performance computing, differentiable programming, Flax, Optax, Haiku, Equinox, neural networks, functional programming, JAX arrays
|alternativeTo=ручна реалізація automatic differentiation; повільні NumPy-обчислення без GPU/TPU; самописна JIT-компіляція; складне масштабування числових обчислень; ручне векторизування циклів; окремі інструменти для gradient-based optimization; класичні Python-обчислення без accelerator support
}}
* писати NumPy-подібний код;
* автономно обчислювати gradients;
* компілювати функції через jit;
* векторизувати функції через vmap;
* паралелити обчислення через pmap;
* працювати з GPU і TPU;
* будувати neural networks через додаткові бібліотеки;
* створювати differentiable programs;
* оптимізувати числові функції;
* виконувати research-oriented ML-експерименти.; '''Практична цінність:''' якщо наукова модель диференційована, JAX має змогу допомогти оптимізувати її параметри через gradients.; '''Pytrees''' — це вкладені структури Python, які JAX має змогу обробляти як дерева даних.; * Документація Flax.; TensorFlow
'''Просте пояснення:''' JAX Array — це масив для числових обчислень, який має змогу працювати в JAX-світі: з gradients, JIT і прискорювачами.; * Документація Haiku.; Приклади:
== Приклади задач ==
import jax
</div>
!; Optax має змогу використовуватися для:
<div style="background:#fff7ed; border-left:6px solid #fb923c; padding:12px; margin:12px 0;">
== JAX для research ==
Для neural networks зазвичай використовують:
Приклад:
Приклади:
Для налагодження корисно:
* control flow;
* shapes;
* static arguments;
* error messages;
* recompilation;
* debug behavior.;</div>
</div>
* physics simulations;
* optimization;
* differential equations;
* computational biology;
* probabilistic modeling;
* numerical methods;
* inverse problems;
* differentiable rendering;
* scientific machine learning.; JAX часто порівнюють із TensorFlow.;</div>
== Immutable arrays ==
'''jax.grad''' — це трансформація, яка створює функцію для обчислення gradient.; * JAX Quickstart.; Результат: compiled version функції для швидшого виконання.; !; Проблеми можуть виникати, якщо:
import jax
'''Увага:''' JAX не автономно пришвидшує будь-який Python-код.; Навколо нього існує програмний пакет бібліотек.; '''Практична роль:''' Equinox зручний для користувачів, які хочуть поєднати JAX-підхід із простими Python-класами.; Потрібно враховувати:
- писати pure functions;
- передавати state явно;
- використовувати jax.numpy замість numpy у JAX-функціях;
- спочатку перевіряти код без jit;
- використовувати jit для “гарячих” обчислень;
- використовувати vmap замість ручних циклів;
- контролювати shapes і dtypes;
- правильно працювати з PRNG keys;
- зберігати прості й тестовані функції;
- вимірювати продуктивність;
- уникати зайвих device-host transfers;
- документувати numerical assumptions;
- тестувати gradients.;
Рекомендовано:
- batch processing;
- per-example gradients;
- vectorized evaluation;
- заміни Python loops;
- прискорення обчислень;
- cleaner code.; JAX часто застосовується в машинному навчанні, deep learning, наукових обчисленнях, optimization, differentiable programming, research-проєктах і задачах, де потрібне поєднання гнучкого Python-коду з високою продуктивністю.; !; * писати JAX-код як звичайний NumPy без урахування immutability;
- забувати розділяти random keys;
- додавати side effects у jit-функції;
- очікувати, що print працюватиме як у звичайному Python;
- створювати багато recompilations через змінні shapes;
- використовувати Python loops замість vmap або scan;
- переносити інформаційні дані між CPU і GPU занадто часто;
- не тестувати функції до jit;
- не контролювати dtype;
- не зберігати reproducibility.;
'''Просте пояснення:''' JAX спочатку “дивиться” на функцію як на обчислення, яке можна трансформувати, а вже потім виконує оптимізований варіант.; state.append(x)
jit
def loss(w):
== JAX і Scikit-learn ==
</div>
* multi-GPU training;
* multi-TPU computation;
* паралельного виконання batch;
* distributed-style обчислень;
* масштабування ML-експериментів.; JAX
'''Haiku''' — це бібліотека для neural networks на JAX, створювалась як DeepMind.; Інструмент: jax.vmap.; JAX — це інструмент для обчислень і ML, з цієї причини відповідальність за моделі та їхнє використання залишається за розробником.;=== JIT-компіляція ===
Debugging у JAX має змогу бути складнішим, ніж у звичайному Python, особливо всередині `jit`.;== ліцензійний пакет ==
* параметрів моделей;
* gradients;
* optimizer state;
* batch data;
* structured outputs;
* tree transformations.; * JAX automatic differentiation documentation.;== Типові помилки в JAX ==
== переважні аспекти JAX ==
def f(x):
* спочатку запускати без jit;
* перевіряти shapes;
* перевіряти dtypes;
* використовувати менші приклади;
* уникати зайвої складності;
* тестувати функції окремо;
* додавати asserts там, де доречно;
* розуміти tracing;
* обережно працювати з print у compiled code.;== Haiku ==
<div style="background:#fff4e5; border-left:6px solid #f39c12; padding:12px; margin:12px 0;">
<div style="background:#ecfdf5; border-left:6px solid #10b981; padding:12px; margin:12px 0;">
'''Висновок:''' NumPy — базова бібліотека числових обчислень, а JAX додає до NumPy-подібного стилю autodiff, JIT і accelerator support.; Це низькорівнева й гнучка платформа числових обчислень і трансформацій, поверх якої часто використовують додаткові бібліотеки.; '''JAX Array''' — це ключовий тип масиву в JAX.; JAX
<syntaxhighlight lang="python">
import jax.numpy as jnp
Pytree має змогу містити:
* optimization;
* training neural networks;
* loss functions;
* scientific computing;
* differentiable simulations;
* gradient-based methods.; JAX сам по собі не має такого центрального high-level neural network API, як `torch.nn` у PyTorch або Keras у TensorFlow.;== Flax ==
<syntaxhighlight lang="python">
</div>
* list;
* tuple;
* dict;
* dataclass;
* nested structures;
* arrays;
* parameters of neural networks.;
- створювати modules;
- керувати parameters;
- будувати neural networks;
- працювати з JAX transformations;
- організовувати model code.;== Загальний характеристика ==
y = x.at [0].set(10)
JAX і Scikit-learn мають різні ролі.; |-
| ключовий стиль | Functional programming і transformations | Imperative/eager style із dynamic computation graph |
| Autodiff | grad як функціональна трансформація | autograd через tensor operations |
| Neural network API | Зазвичай через Flax, Haiku, Equinox | torch.nn вбудований у PyTorch |
| Research | Сильний у composable transformations і accelerator-oriented code | Дуже популярний у deep learning research |
| Стан моделі | Часто передається явно | Часто зберігається в modules/objects |
key = jax.random.PRNGKey(0)
JAX для neural networks
JAX arrays схожі на NumPy arrays, але мають важливі відмінності:
Equinox
return (w - 5.0) ** 2
Інструменти: JAX + Flax/Haiku/Equinox + Optax.;</syntaxhighlight> x = jnp.array([1.0, 2.0, 3.0]) Головне правило: у JAX shapes і dtypes — це частина дизайну програми, а не другорядна деталь.; b = jax.random.uniform(key2, shape=(3,))
vmap
Flax застосовується для:
Вона надає можливість автономно обчислювати похідні функцій.;До них належать:
істотно: pmap складніший за grad, jit і vmap.; key1, key2 = jax.random.split(key)
- ліцензію JAX;
- ліцензії залежностей;
- ліцензії моделей;
- ліцензії датасетів;
- умови використання accelerator-середовища;
- політики організації;
- вимоги до attribution.; Практична роль: grad надає можливість писати математичну функцію напряму, а похідні для оптимізації отримувати автономно.;
- компілювати array operations;
- оптимізувати граф обчислень;
- виконувати код на CPU, GPU або TPU;
- об’єднувати операції;
- зменшувати overhead;
- пришвидшувати великі обчислення.; Scikit-learn
JAX застосовується не лише для нейронних мереж, а й для наукових обчислень.;== XLA ==
JAX дуже популярний у research-середовищах, з цієї причини що він надає можливість оперативно експериментувати з математичними ідеями.; Критерій
Небажаний підхід:
Pure functions
</syntaxhighlight>
Небезпека: код має змогу виглядати схожим на NumPy, але поводитися інакше через JAX-трансформації, компіляцію і immutable arrays.; {| class="wikitable"
| ; state = []
JAX додатково часто порівнюють із PyTorch.;
return x ** 2 + 3 * x + 1 Equinox має змогу бути корисним для: Приклад: grad_loss = jax.grad(loss)
JAX є собою open-source проєктом.; def impure_function(x):
Критично: швидка модель не означає правильна модель.; Практична ідея: явні random keys роблять випадковість контрольованішою, відтворюванішою і суміснішою з functional programming.; print(grad_loss(2.0)) Суть jit: JAX компілює Python-функцію у швидший обчислювальний код, який має змогу результативно виконуватися на accelerator hardware.; jax.pmap — це трансформація для паралельного виконання обчислень на кількох devices.;</syntaxhighlight>
a = jax.random.normal(key1, shape=(3,))
`vmap` корисний для: JAX ArrayСуть jax.numpy: розробник пише код у стилі NumPy, але отримує можливість використовувати JAX-трансформації: grad, jit, vmap та інші.; Код потрібно писати з урахуванням JIT, vectorization і device execution.; Критерій Просте пояснення: pytree надає можливість JAX працювати не лише з одним масивом, а з цілою вкладеною структурою масивів.;</syntaxhighlight> істотно: у JAX стан моделі й параметри часто передаються явно, що має змогу бути незвично для користувачів PyTorch або Keras.; !;Це має змогу впливати на:
Примітка: Haiku є собою одним із варіантів neural network framework поверх JAX, але не є собою єдиним стандартом.; import jax.numpy as jnp </syntaxhighlight> Суть екосистеми: JAX дає фундаментальні трансформації й обчислення, а додаткові бібліотеки додають neural networks, optimizers, checkpoints, probabilistic programming та інші інструменти.; Optax — це бібліотека optimization algorithms для JAX.;== grad == Pytrees часто використовуються для:
jax.numpy підтримує роботу багато знайомих операцій: істотно: JAX — це не повна high-level ML-платформа на кшталт TensorFlow або PyTorch.; Для ефективного використання потрібно розуміти devices, sharding, data layout і синхронізацію.;def compute(x): JAX використовує explicit random keys.;Добре працюють: Безпека і відповідальне використанняДив.; додатково
Тут `y` — новий масив із оновленим значенням.; x = jax.random.normal(key, shape=(3,)) офіційно затверджений GitHub-репозиторій JAX описує його як систему для composable transformations of Python+NumPy programs, а серед ключових трансформацій виділяє `grad`, `jit` і `vmap`.; Репозиторій JAX поширюється під ліцензією Apache 2.0.; Висновок: JAX більше схожий на гнучку систему числових трансформацій, а TensorFlow — на ширшу end-to-end ML-платформу.; JAX можна розглядати як систему перетворень для числових Python-функцій.; Підказка: JAX варто вивчати через маленькі функції: спочатку jnp, потім grad, потім jit, потім vmap.; Приклад: Головна перевага: JAX надає можливість комбінувати математично чистий Python-код із потужними трансформаціями для gradients, compilation і vectorization.;Pure function — це функція, яка:
JAX застосовується там, де потрібні швидкі числові обчислення і gradients.; return jnp.sin(x) * jnp.cos(x) + x ** 2 Інструмент: jax.grad.;== jax.numpy == істотно: JAX-трансформації краще працюють із функціональним стилем програмування, де стан передається явно, а не змінюється приховано.;
Для кількох випадкових операцій key потрібно розділяти: Типові помилки користувачів
Гірше працюють:
<div style="background:#e7f3ff; border-left:6px solid #2b7cff; padding:12px; margin:12px 0;">
== Automatic differentiation ==
== PRNG у JAX ==
!; * neural networks;
* scientific computing;
* differentiable programming;
* structured models;
* research code;
* функціонального стилю з класами.; У JAX істотно контролювати shape і dtype.; Під час tracing JAX не завжди має звичайні Python-значення, а функціонує з абстрактними представленнями.;<div style="background:#fff7ed; border-left:6px solid #fb923c; padding:12px; margin:12px 0;">
'''Практична порада:''' якщо задача потребує gradients, accelerator execution і кастомної математики, JAX має змогу бути дуже сильним вибором.;== JAX і TensorFlow ==
</div>
* очікування NumPy-style mutation;
* використання side effects у jit-функціях;
* неправильна робота з random keys;
* надмірна recompilation;
* Python control flow там, де потрібен JAX control flow;
* змішування NumPy і jax.numpy без розуміння наслідків;
* передача Python objects у jit без static_argnums;
* часті device-host transfers;
* неправильне використання vmap;
* недостатнє розуміння shapes.; result = compute(jnp.ones((1000,)))
'''Головне правило:''' JAX найкраще функціонує тоді, коли код написаний функціонально, інформаційні дані мають стабільні shapes, а transformations використовуються усвідомлено.;</div>
import jax.numpy as jnp
import jax.numpy as jnp
Приклад:
<div style="background:#fef2f2; border-left:6px solid #ef4444; padding:12px; margin:12px 0;">
<div style="background:#fef2f2; border-left:6px solid #ef4444; padding:12px; margin:12px 0;">
* Flax;
* Haiku;
* Equinox;
* custom JAX code;
* Optax для optimizers.; import jax
<syntaxhighlight lang="python">
JAX має змогу бути дуже швидким, але продуктивність залежить від стилю коду.; return x * 2
Перед використанням у продукті потрібно перевіряти:
!; * якість даних;
* bias;
* correctness of gradients;
* reproducibility;
* numerical stability;
* privacy;
* security of model deployment;
* ліцензії даних;
* вплив ML-рішень на користувачів;
* моніторинг після deployment.;<div style="background:#e7f3ff; border-left:6px solid #2b7cff; padding:12px; margin:12px 0;">
!; Вона сприяє:
== Tracing ==
print(df(2.0))
'''jax.numpy''' або '''jnp''' — це NumPy-подібний API у JAX.; Це означає, що масив не змінюється “на місці” так само, як це часто роблять у NumPy.;<div style="background:#e7f3ff; border-left:6px solid #2b7cff; padding:12px; margin:12px 0;">
== Shape і dtype ==
<syntaxhighlight lang="python">
== Для чого застосовується JAX ==
!; JAX
!;<div style="background:#ecfdf5; border-left:6px solid #10b981; padding:12px; margin:12px 0;">
<div style="background:#e8f8f5; border-left:6px solid #16a085; padding:12px; margin:12px 0;">
<div style="background:#ecfdf5; border-left:6px solid #10b981; padding:12px; margin:12px 0;">
Результат: векторизована функція без ручного Python loop.; Типовий приклад:
</div>
Можливі складнощі:
Поширені помилки:
'''Висновок:''' PyTorch часто зручніший для класичного object-oriented deep learning workflow, а JAX — для функціонального, трансформаційного і research-oriented підходу.;</div>
'''Для research:''' JAX цінують за те, що transformations можна комбінувати: як ілюстрація, grad + jit + vmap.; '''Просте пояснення:''' vmap бере функцію для одного прикладу і автономно робить її функцією для batch.; Окремо варто відзначити JIT-компіляції, векторизації, роботи з NumPy-подібним API і запуску обчислень на CPU, GPU і TPU виступає ключовою рисою високопродуктивних числових обчислень забезпечується через '''JAX'''.; '''jax.jit''' — це трансформація, яка компілює функцію для швидшого виконання.; import jax
'''Практична роль:''' якщо JAX — це обчислювальний фундамент, то Flax часто застосовується як high-level neural network library поверх JAX.; '''Практична роль:''' Optax часто застосовується разом із JAX і Flax для навчання neural networks.; return x ** 2
JAX найкраще функціонує з '''pure functions'''.; * Документація Equinox.; Вона поєднує NumPy-подібний API із потужними функціональними трансформаціями: `grad`, `jit`, `vmap`, `pmap`.; Якщо задача проста й таблична, Scikit-learn або NumPy можуть бути практичнішими.;<div style="background:#e8f8f5; border-left:6px solid #16a085; padding:12px; margin:12px 0;">
JAX має обмеження.; '''Flax''' — це бібліотека для neural networks на JAX.; @jax.jit
<div style="background:#e7f3ff; border-left:6px solid #2b7cff; padding:12px; margin:12px 0;">
== JAX і PyTorch ==
'''Automatic differentiation''' — одна з ключових можливостей JAX.; |-
| ключовий фокус
| Числові обчислення, autodiff, JIT, research ML
| Класичне машинне навчання
|-
| Типові задачі
| Neural networks, optimization, differentiable programming
| Classification, regression, clustering, preprocessing
|-
| API
| Функціональні transformations
| fit/predict/transform
|-
| Для табличного ML
| Можна, але часто потребує більше коду
| Дуже комфортно
|-
| Для gradients
| Сильна сторона
| Не ключовий фокус
|}
Приклад:
== pmap ==
Він надає можливість:
== Хороші практики роботи з JAX ==
<div style="background:#e7f3ff; border-left:6px solid #2b7cff; padding:12px; margin:12px 0;">
</div>
* Flax;
* Optax;
* Haiku;
* Equinox;
* Orbax;
* Chex;
* JAXopt;
* NumPyro;
* Distrax;
* TFP on JAX.; return x * 2
print(batched_square(jnp.array([1, 2, 3, 4])))
* SGD;
* Adam;
* AdamW;
* learning rate schedules;
* gradient transformations;
* gradient clipping;
* optimizer state;
* training loops.; '''Перевага:''' JAX поєднує знайомий стиль NumPy із сучасними можливостями для machine learning і high-performance computing.; PyTorch
</div>
<syntaxhighlight lang="python">
|-
| ключовий стиль
| Функціональні transformations: grad, jit, vmap
| Повна ML-платформа з Keras, TensorFlow Lite, Serving, TFX
|-
| Рівень
| Нижчий і гнучкіший для research
| Ширша production-екосистема
|-
| Neural networks
| Через Flax, Haiku, Equinox та інші бібліотеки
| Через Keras і TensorFlow API
|-
| Компіляція
| XLA через jit
| TensorFlow graph/XLA у відповідних сценаріях
|-
| Типове використання
| Research, differentiable programming, high-performance numeric code
| Production ML, deep learning, mobile/browser deployment
|}
<div style="background:#fdecea; border-left:6px solid #e74c3c; padding:12px; margin:12px 0;">
import jax.numpy as jnp
def pure_function(x):
Замість in-place mutation застосовується функціональний стиль актуалізація.; JAX особливо корисний для research, differentiable programming, optimization, neural networks, scientific computing і задач, де потрібно поєднати математичну гнучкість із продуктивністю.;== Обмеження JAX ==
== Висновок ==
Результат: training loop із gradients, optimizer update і evaluation.;<div style="background:#ecfdf5; border-left:6px solid #10b981; padding:12px; margin:12px 0;">
Він надає можливість писати код, схожий на NumPy:
== Pytrees ==
Задача: навчити neural network.; Критерій
`grad` часто застосовується для:
== Тематичні мітки ==
== JAX ecosystem ==
<syntaxhighlight lang="text">
</div>
=== Neural network training ===
Equinox — це бібліотека для JAX, яка надає можливість описувати neural networks і differentiable programs через Python-класи, сумісні з pytrees.; df = jax.grad(f) це Python-бібліотека; додатково реалізовано автоматичного диференціювання.;== JAX для наукових обчислень ==
'''істотно:''' open-source ліцензійний пакет JAX не скасовує обмежень на інформаційні дані, моделі або сторонні бібліотеки, які використовуються разом із ним.;</div>
</div>
'''Небезпека:''' JAX-код має змогу бути дуже швидким, але неправильна технічна архітектура обчислень має змогу зробити його повільним, нестабільним або важким для налагодження.;
Вона надає можливість застосувати функцію до batch даних без ручного написання циклу.; * залежить лише від своїх аргументів;
</syntaxhighlight> Задача: застосувати функцію до batch прикладів.; Задача: пришвидшити числову функцію, яка викликається багато разів.;Суть immutable arrays: замість зміни масиву на місці JAX створює нове логічне представлення результату, що краще узгоджується з трансформаціями й компіляцією.; batched_square = jax.vmap(square) Automatic differentiationy = jnp.sin(x) + x ** 2
* defining neural networks;
* training models;
* research experiments;
* transformer models;
* model state;
* neural network modules;
* integration with Optax;
* large-scale ML research.; !; Приклади:
ПродуктивністьJIT означає Just-In-Time compilation.; Практична роль: XLA є собою однією з причин, чому JAX має змогу виконувати числові функції оперативно після компіляції.; * вищий поріг входу;
<syntaxhighlight lang="text"> VectorizationJAX дуже схожий на NumPy за стилем API, але має важливі відмінності.; `jit` має змогу пришвидшити обчислення, особливо якщо:
<syntaxhighlight lang="python"> |
|---|