Перейти до вмісту

JAX

Матеріал з K2 ERP Wiki
Версія від 19:11, 8 травня 2026, створена R (обговорення | внесок) (Створена сторінка: {{SEO |title=JAX — Python-бібліотека для високопродуктивних обчислень, automatic differentiation, NumPy API і машинного навчання |description=JAX — Wiki-стаття про Python-бібліотеку для високопродуктивних числових обчислень, automatic differentiation, JIT-компіляції, NumPy-подібного API, GPU/TPU-прискорення і...)
(різн.) ← Попередня версія | Поточна версія (різн.) | Новіша версія → (різн.)

def square(x):

XLA сприяє:

Основні переважні аспекти JAX:

<div style="background:#fff4e5; border-left:6px solid #f39c12; padding:12px; margin:12px 0;">

{{SEO
|title=JAX  Python-бібліотека для високопродуктивних обчислень, automatic differentiation, NumPy API і машинного навчання
|description=JAX  Wiki-стаття про Python-бібліотеку для високопродуктивних числових обчислень, automatic differentiation, JIT-компіляції, NumPy-подібного API, GPU/TPU-прискорення і machine learning. Розглянуто jax.numpy, grad, jit, vmap, pmap, XLA, pure functions, immutable arrays, PRNG, JAX ecosystem, Flax, Optax, Haiku, Equinox, переваги, обмеження, безпеку і відповідальне використання.
|keywords=JAX, jax.numpy, jnp, Google JAX, Python JAX, automatic differentiation, autograd, jit, vmap, pmap, XLA, GPU, TPU, NumPy API, machine learning, deep learning, high-performance computing, differentiable programming, Flax, Optax, Haiku, Equinox, neural networks, functional programming, JAX arrays
|alternativeTo=ручна реалізація automatic differentiation; повільні NumPy-обчислення без GPU/TPU; самописна JIT-компіляція; складне масштабування числових обчислень; ручне векторизування циклів; окремі інструменти для gradient-based optimization; класичні Python-обчислення без accelerator support
}}

* писати NumPy-подібний код;
* автономно обчислювати gradients;
* компілювати функції через jit;
* векторизувати функції через vmap;
* паралелити обчислення через pmap;
* працювати з GPU і TPU;
* будувати neural networks через додаткові бібліотеки;
* створювати differentiable programs;
* оптимізувати числові функції;
* виконувати research-oriented ML-експерименти.; '''Практична цінність:''' якщо наукова модель диференційована, JAX має змогу допомогти оптимізувати її параметри через gradients.; '''Pytrees'''  це вкладені структури Python, які JAX має змогу обробляти як дерева даних.; * Документація Flax.; TensorFlow
'''Просте пояснення:''' JAX Array  це масив для числових обчислень, який має змогу працювати в JAX-світі: з gradients, JIT і прискорювачами.; * Документація Haiku.; Приклади:

== Приклади задач ==

import jax
</div>
!; Optax має змогу використовуватися для:
<div style="background:#fff7ed; border-left:6px solid #fb923c; padding:12px; margin:12px 0;">
== JAX для research ==

Для neural networks зазвичай використовують:

Приклад:

Приклади:

Для налагодження корисно:

* control flow;
* shapes;
* static arguments;
* error messages;
* recompilation;
* debug behavior.;</div>

</div>

* physics simulations;
* optimization;
* differential equations;
* computational biology;
* probabilistic modeling;
* numerical methods;
* inverse problems;
* differentiable rendering;
* scientific machine learning.; JAX часто порівнюють із TensorFlow.;</div>
== Immutable arrays ==
'''jax.grad'''  це трансформація, яка створює функцію для обчислення gradient.; * JAX Quickstart.; Результат: compiled version функції для швидшого виконання.; !; Проблеми можуть виникати, якщо:
import jax
'''Увага:''' JAX не автономно пришвидшує будь-який Python-код.; Навколо нього існує програмний пакет бібліотек.; '''Практична роль:''' Equinox зручний для користувачів, які хочуть поєднати JAX-підхід із простими Python-класами.; Потрібно враховувати:
  • писати pure functions;
  • передавати state явно;
  • використовувати jax.numpy замість numpy у JAX-функціях;
  • спочатку перевіряти код без jit;
  • використовувати jit для “гарячих” обчислень;
  • використовувати vmap замість ручних циклів;
  • контролювати shapes і dtypes;
  • правильно працювати з PRNG keys;
  • зберігати прості й тестовані функції;
  • вимірювати продуктивність;
  • уникати зайвих device-host transfers;
  • документувати numerical assumptions;
  • тестувати gradients.;
Під час роботи з JAX часто виникають типові помилки.; x = jnp.array([1, 2, 3]) Приклад: У JAX робота з випадковістю відрізняється від NumPy.; JAX не намагається бути однією великою бібліотекою для всього.; JAX можна використовувати в різних сценаріях.; Помилка: обирати JAX лише з цієї причини, що він швидкий.;== Джерела == Головна думка: JAX — це не без ускладнень “швидкий NumPy”, а платформа composable transformations для Python-функцій, яка відкриває потужні функціональні можливості для gradients, JIT, vectorization і accelerator-based computing.; * JAX GitHub repository.; Tracing — це механізм, через який JAX аналізує функцію для трансформацій на кшталт `jit`, `grad` або `vmap`.;== Debugging у JAX == </syntaxhighlight>

Рекомендовано:

  • batch processing;
  • per-example gradients;
  • vectorized evaluation;
  • заміни Python loops;
  • прискорення обчислень;
  • cleaner code.; JAX часто застосовується в машинному навчанні, deep learning, наукових обчисленнях, optimization, differentiable programming, research-проєктах і задачах, де потрібне поєднання гнучкого Python-коду з високою продуктивністю.; !; * писати JAX-код як звичайний NumPy без урахування immutability;
  • забувати розділяти random keys;
  • додавати side effects у jit-функції;
  • очікувати, що print працюватиме як у звичайному Python;
  • створювати багато recompilations через змінні shapes;
  • використовувати Python loops замість vmap або scan;
  • переносити інформаційні дані між CPU і GPU занадто часто;
  • не тестувати функції до jit;
  • не контролювати dtype;
  • не зберігати reproducibility.;
    '''Просте пояснення:''' JAX спочатку дивиться на функцію як на обчислення, яке можна трансформувати, а вже потім виконує оптимізований варіант.; state.append(x)
    
Суть automatic differentiation: JAX має змогу сам побудувати функцію, яка обчислює gradient іншої функції.;

jit

def loss(w):

== JAX і Scikit-learn ==

</div>

* multi-GPU training;
* multi-TPU computation;
* паралельного виконання batch;
* distributed-style обчислень;
* масштабування ML-експериментів.; JAX
'''Haiku''' — це бібліотека для neural networks на JAX, створювалась як DeepMind.; Інструмент: jax.vmap.; JAX — це інструмент для обчислень і ML, з цієї причини відповідальність за моделі та їхнє використання залишається за розробником.;=== JIT-компіляція ===

Debugging у JAX має змогу бути складнішим, ніж у звичайному Python, особливо всередині `jit`.;== ліцензійний пакет ==

* параметрів моделей;
* gradients;
* optimizer state;
* batch data;
* structured outputs;
* tree transformations.; * JAX automatic differentiation documentation.;== Типові помилки в JAX ==

== переважні аспекти JAX ==

def f(x):

* спочатку запускати без jit;
* перевіряти shapes;
* перевіряти dtypes;
* використовувати менші приклади;
* уникати зайвої складності;
* тестувати функції окремо;
* додавати asserts там, де доречно;
* розуміти tracing;
* обережно працювати з print у compiled code.;== Haiku ==

<div style="background:#fff4e5; border-left:6px solid #f39c12; padding:12px; margin:12px 0;">

<div style="background:#ecfdf5; border-left:6px solid #10b981; padding:12px; margin:12px 0;">

'''Висновок:''' NumPy — базова бібліотека числових обчислень, а JAX додає до NumPy-подібного стилю autodiff, JIT і accelerator support.; Це низькорівнева й гнучка платформа числових обчислень і трансформацій, поверх якої часто використовують додаткові бібліотеки.; '''JAX Array''' — це ключовий тип масиву в JAX.; JAX

<syntaxhighlight lang="python">
import jax.numpy as jnp

Pytree має змогу містити:

* optimization;
* training neural networks;
* loss functions;
* scientific computing;
* differentiable simulations;
* gradient-based methods.; JAX сам по собі не має такого центрального high-level neural network API, як `torch.nn` у PyTorch або Keras у TensorFlow.;== Flax ==

<syntaxhighlight lang="python">

</div>

* list;
* tuple;
* dict;
* dataclass;
* nested structures;
* arrays;
* parameters of neural networks.;
XLA або Accelerated Linear Algebra — це компілятор, який застосовується JAX для оптимізації числових обчислень.; Висновок: Scikit-learn краще підходить для класичного tabular ML, а JAX — для задач, де потрібні gradients, JIT і custom numerical computation.; Задача: знайти gradient loss-функції.;
  • створювати modules;
  • керувати parameters;
  • будувати neural networks;
  • працювати з JAX transformations;
  • організовувати model code.;== Загальний характеристика ==

y = x.at [0].set(10)

JAX і Scikit-learn мають різні ролі.; |-

ключовий стиль Functional programming і transformations Imperative/eager style із dynamic computation graph
Autodiff grad як функціональна трансформація autograd через tensor operations
Neural network API Зазвичай через Flax, Haiku, Equinox torch.nn вбудований у PyTorch
Research Сильний у composable transformations і accelerator-oriented code Дуже популярний у deep learning research
Стан моделі Часто передається явно Часто зберігається в modules/objects

key = jax.random.PRNGKey(0)

JAX для neural networks

JAX arrays схожі на NumPy arrays, але мають важливі відмінності:

Equinox

return (w - 5.0) ** 2

Інструменти: JAX + Flax/Haiku/Equinox + Optax.;</syntaxhighlight> x = jnp.array([1.0, 2.0, 3.0]) Головне правило: у JAX shapes і dtypes — це частина дизайну програми, а не другорядна деталь.; b = jax.random.uniform(key2, shape=(3,))

vmap

Flax застосовується для:

Вона надає можливість автономно обчислювати похідні функцій.;

До них належать:

істотно: pmap складніший за grad, jit і vmap.; key1, key2 = jax.random.split(key)

  • ліцензію JAX;
  • ліцензії залежностей;
  • ліцензії моделей;
  • ліцензії датасетів;
  • умови використання accelerator-середовища;
  • політики організації;
  • вимоги до attribution.; Практична роль: grad надає можливість писати математичну функцію напряму, а похідні для оптимізації отримувати автономно.;
Інструмент: jax.jit.;
  • компілювати array operations;
  • оптимізувати граф обчислень;
  • виконувати код на CPU, GPU або TPU;
  • об’єднувати операції;
  • зменшувати overhead;
  • пришвидшувати великі обчислення.; Scikit-learn
jax.vmap — це трансформація для автоматичної векторизації функцій.;

JAX застосовується не лише для нейронних мереж, а й для наукових обчислень.;== XLA ==

JAX дуже популярний у research-середовищах, з цієї причини що він надає можливість оперативно експериментувати з математичними ідеями.; Критерій

Небажаний підхід:

Pure functions

</syntaxhighlight>

Небезпека: код має змогу виглядати схожим на NumPy, але поводитися інакше через JAX-трансформації, компіляцію і immutable arrays.; {| class="wikitable"

; state = []
JAX додатково часто порівнюють із PyTorch.;
return x ** 2 + 3 * x + 1

Equinox має змогу бути корисним для: Приклад:

grad_loss = jax.grad(loss)

  • arrays;
  • matrix operations;
  • linear algebra;
  • broadcasting;
  • elementwise functions;
  • reductions;
  • reshaping;
  • indexing;
  • mathematical functions.;

JAX є собою open-source проєктом.; def impure_function(x):

  • shape змінюється між викликами jit-функції;
  • dtype не той, який очікувався;
  • інформаційні дані не на з цієї причини device;
  • модель очікує batch, а отримує один приклад;
  • vmap застосований по неправильній осі;
  • broadcasting функціонує не так, як очікувалося.; pmap має змогу використовуватися для:

Критично: швидка модель не означає правильна модель.; Практична ідея: явні random keys роблять випадковість контрольованішою, відтворюванішою і суміснішою з functional programming.; print(grad_loss(2.0))

Суть jit: JAX компілює Python-функцію у швидший обчислювальний код, який має змогу результативно виконуватися на accelerator hardware.; jax.pmap — це трансформація для паралельного виконання обчислень на кількох devices.;</syntaxhighlight>

  • навчання neural network;
  • custom optimization;
  • differentiable physics simulation;
  • research prototype;
  • reinforcement learning;
  • probabilistic modeling;
  • scientific computing;
  • gradient-based calibration;
  • vectorized numerical experiments;
  • high-performance array computation;
  • TPU-based experiments;
  • custom loss functions.;

a = jax.random.normal(key1, shape=(3,))

JAX-документація зазначає, що autodiff у JAX надає можливість без зайвих зусиль обчислювати похідні вищих порядків, бо функції, які обчислюють derivatives, самі можуть бути диференційованими.;
  • функція викликається багато разів;
  • обчислення великі;
  • застосовується GPU або TPU;
  • є собою багато array operations;
  • код підходить для компіляції.; * Документація Optax.;== Optax ==

`vmap` корисний для:

JAX Array

Суть jax.numpy: розробник пише код у стилі NumPy, але отримує можливість використовувати JAX-трансформації: grad, jit, vmap та інші.; Код потрібно писати з урахуванням JIT, vectorization і device execution.; Критерій

Просте пояснення: pytree надає можливість JAX працювати не лише з одним масивом, а з цілою вкладеною структурою масивів.;

</syntaxhighlight>

істотно: у JAX стан моделі й параметри часто передаються явно, що має змогу бути незвично для користувачів PyTorch або Keras.; !;

Це має змогу впливати на:

  • великі array operations;
  • jit-compiled functions;
  • vectorized code;
  • batch computation;
  • accelerator-friendly logic;
  • pure functions;
  • мінімум Python loops у compiled hot path.; * JAX documentation щодо jit, vmap, pmap і pytrees.;== JAX і NumPy ==

Примітка: Haiku є собою одним із варіантів neural network framework поверх JAX, але не є собою єдиним стандартом.; import jax.numpy as jnp

</syntaxhighlight>

Суть екосистеми: JAX дає фундаментальні трансформації й обчислення, а додаткові бібліотеки додають neural networks, optimizers, checkpoints, probabilistic programming та інші інструменти.; Optax — це бібліотека optimization algorithms для JAX.;== grad ==

Pytrees часто використовуються для:

  • Офіційна документація JAX.;</syntaxhighlight>

jax.numpy підтримує роботу багато знайомих операцій:

істотно: JAX — це не повна high-level ML-платформа на кшталт TensorFlow або PyTorch.; Для ефективного використання потрібно розуміти devices, sharding, data layout і синхронізацію.;

def compute(x):

JAX використовує explicit random keys.;

Добре працюють:

Безпека і відповідальне використання

Див.; додатково

  • custom loss functions;
  • differentiable simulations;
  • optimization algorithms;
  • neural architectures;
  • reinforcement learning;
  • probabilistic programming;
  • scientific ML;
  • large-scale research;
  • vectorized experiments;
  • accelerator-friendly code.; JAX arrays зазвичай розглядаються як immutable.; Результати JAX-обчислень потрібно тестувати, перевіряти і валідувати на реальних сценаріях.; NumPy

Тут `y` — новий масив із оновленим значенням.; x = jax.random.normal(key, shape=(3,)) офіційно затверджений GitHub-репозиторій JAX описує його як систему для composable transformations of Python+NumPy programs, а серед ключових трансформацій виділяє `grad`, `jit` і `vmap`.; Репозиторій JAX поширюється під ліцензією Apache 2.0.; Висновок: JAX більше схожий на гнучку систему числових трансформацій, а TensorFlow — на ширшу end-to-end ML-платформу.; JAX можна розглядати як систему перетворень для числових Python-функцій.; Підказка: JAX варто вивчати через маленькі функції: спочатку jnp, потім grad, потім jit, потім vmap.; Приклад:

Головна перевага: JAX надає можливість комбінувати математично чистий Python-код із потужними трансформаціями для gradients, compilation і vectorization.;

Pure function — це функція, яка:

  • machine learning research;
  • deep learning;
  • neural networks;
  • optimization;
  • automatic differentiation;
  • scientific computing;
  • simulation;
  • probabilistic modeling;
  • differentiable programming;
  • reinforcement learning;
  • large-scale numerical computing;
  • GPU/TPU acceleration.; * model parameters;
  • forward function;
  • loss function;
  • grad;
  • optimizer update;
  • jit;
  • batch processing;
  • evaluation.; Критерій

JAX застосовується там, де потрібні швидкі числові обчислення і gradients.; return jnp.sin(x) * jnp.cos(x) + x ** 2

Інструмент: jax.grad.;== jax.numpy ==

істотно: JAX-трансформації краще працюють із функціональним стилем програмування, де стан передається явно, а не змінюється приховано.;
  • багато дрібних Python-викликів;
  • часті передачі даних між host і device;
  • side effects;
  • динамічні форми масивів;
  • погано структурований код;
  • надмірна recompilation.;

Для кількох випадкових операцій key потрібно розділяти:

Типові помилки користувачів

;== Типові сценарії використання ==

Він корисний для: JAX — це Python-бібліотека для високопродуктивних числових обчислень, automatic differentiation, JIT-компіляції, векторизації і роботи з accelerator hardware.; import jax

Практична порада: перед оптимізацією через jit спочатку варто переконатися, що функція правильно функціонує у звичайному режимі.;

Типовий training loop у JAX складається з:

  • NumPy-подібний API;
  • automatic differentiation;
  • jit compilation;
  • vmap для vectorization;
  • pmap для parallelism;
  • GPU/TPU support;
  • composable transformations;
  • functional programming style;
  • зручність для research;
  • сильний для optimization;
  • підходить для differentiable programming;
  • програмний пакет Flax, Optax, Haiku, Equinox.; import jax.numpy as jnp

Основна ідея: JAX надає можливість писати код у стилі NumPy, але додавати до нього automatic differentiation, JIT-компіляцію, векторизацію і прискорення на GPU/TPU.; Результат: функція, яка повертає похідну або gradients параметрів.; |-

ключовий фокус Прискорені числові обчислення, transformations, autodiff Загальні числові обчислення в Python
GPU/TPU супровід accelerator execution Зазвичай CPU-орієнтований
Automatic differentiation Вбудовано через grad Немає вбудованого autodiff
JIT є собою через jax.jit Немає стандартного JIT у NumPy
Mutability Functional-style updates Часто in-place mutation
Гірше працюють:
<div style="background:#e7f3ff; border-left:6px solid #2b7cff; padding:12px; margin:12px 0;">

== Automatic differentiation ==

== PRNG у JAX ==

!; * neural networks;
* scientific computing;
* differentiable programming;
* structured models;
* research code;
* функціонального стилю з класами.; У JAX істотно контролювати shape і dtype.; Під час tracing JAX не завжди має звичайні Python-значення, а функціонує з абстрактними представленнями.;<div style="background:#fff7ed; border-left:6px solid #fb923c; padding:12px; margin:12px 0;">
'''Практична порада:''' якщо задача потребує gradients, accelerator execution і кастомної математики, JAX має змогу бути дуже сильним вибором.;== JAX і TensorFlow ==

</div>

* очікування NumPy-style mutation;
* використання side effects у jit-функціях;
* неправильна робота з random keys;
* надмірна recompilation;
* Python control flow там, де потрібен JAX control flow;
* змішування NumPy і jax.numpy без розуміння наслідків;
* передача Python objects у jit без static_argnums;
* часті device-host transfers;
* неправильне використання vmap;
* недостатнє розуміння shapes.; result = compute(jnp.ones((1000,)))
'''Головне правило:''' JAX найкраще функціонує тоді, коли код написаний функціонально, інформаційні дані мають стабільні shapes, а transformations використовуються усвідомлено.;</div>
import jax.numpy as jnp
import jax.numpy as jnp

Приклад:

<div style="background:#fef2f2; border-left:6px solid #ef4444; padding:12px; margin:12px 0;">

<div style="background:#fef2f2; border-left:6px solid #ef4444; padding:12px; margin:12px 0;">

* Flax;
* Haiku;
* Equinox;
* custom JAX code;
* Optax для optimizers.; import jax

<syntaxhighlight lang="python">

JAX має змогу бути дуже швидким, але продуктивність залежить від стилю коду.; return x * 2
Перед використанням у продукті потрібно перевіряти:
!; * якість даних;
* bias;
* correctness of gradients;
* reproducibility;
* numerical stability;
* privacy;
* security of model deployment;
* ліцензії даних;
* вплив ML-рішень на користувачів;
* моніторинг після deployment.;<div style="background:#e7f3ff; border-left:6px solid #2b7cff; padding:12px; margin:12px 0;">

!; Вона сприяє:
== Tracing ==
print(df(2.0))

'''jax.numpy''' або '''jnp'''  це NumPy-подібний API у JAX.; Це означає, що масив не змінюється на місці так само, як це часто роблять у NumPy.;<div style="background:#e7f3ff; border-left:6px solid #2b7cff; padding:12px; margin:12px 0;">
== Shape і dtype ==

<syntaxhighlight lang="python">
== Для чого застосовується JAX ==

!; JAX
!;<div style="background:#ecfdf5; border-left:6px solid #10b981; padding:12px; margin:12px 0;">
<div style="background:#e8f8f5; border-left:6px solid #16a085; padding:12px; margin:12px 0;">
<div style="background:#ecfdf5; border-left:6px solid #10b981; padding:12px; margin:12px 0;">

Результат: векторизована функція без ручного Python loop.; Типовий приклад:
</div>
Можливі складнощі:

Поширені помилки:

'''Висновок:''' PyTorch часто зручніший для класичного object-oriented deep learning workflow, а JAX  для функціонального, трансформаційного і research-oriented підходу.;</div>

'''Для research:''' JAX цінують за те, що transformations можна комбінувати: як ілюстрація, grad + jit + vmap.; '''Просте пояснення:''' vmap бере функцію для одного прикладу і автономно робить її функцією для batch.; Окремо варто відзначити JIT-компіляції, векторизації, роботи з NumPy-подібним API і запуску обчислень на CPU, GPU і TPU виступає ключовою рисою високопродуктивних числових обчислень забезпечується через '''JAX'''.; '''jax.jit'''  це трансформація, яка компілює функцію для швидшого виконання.; import jax
'''Практична роль:''' якщо JAX  це обчислювальний фундамент, то Flax часто застосовується як high-level neural network library поверх JAX.; '''Практична роль:''' Optax часто застосовується разом із JAX і Flax для навчання neural networks.; return x ** 2
JAX найкраще функціонує з '''pure functions'''.; * Документація Equinox.; Вона поєднує NumPy-подібний API із потужними функціональними трансформаціями: `grad`, `jit`, `vmap`, `pmap`.; Якщо задача проста й таблична, Scikit-learn або NumPy можуть бути практичнішими.;<div style="background:#e8f8f5; border-left:6px solid #16a085; padding:12px; margin:12px 0;">
JAX має обмеження.; '''Flax'''  це бібліотека для neural networks на JAX.; @jax.jit
<div style="background:#e7f3ff; border-left:6px solid #2b7cff; padding:12px; margin:12px 0;">

== JAX і PyTorch ==

'''Automatic differentiation'''  одна з ключових можливостей JAX.; |-
| ключовий фокус
| Числові обчислення, autodiff, JIT, research ML
| Класичне машинне навчання
|-
| Типові задачі
| Neural networks, optimization, differentiable programming
| Classification, regression, clustering, preprocessing
|-
| API
| Функціональні transformations
| fit/predict/transform
|-
| Для табличного ML
| Можна, але часто потребує більше коду
| Дуже комфортно
|-
| Для gradients
| Сильна сторона
| Не ключовий фокус
|}

Приклад:
== pmap ==
Він надає можливість:

== Хороші практики роботи з JAX ==

<div style="background:#e7f3ff; border-left:6px solid #2b7cff; padding:12px; margin:12px 0;">

</div>

* Flax;
* Optax;
* Haiku;
* Equinox;
* Orbax;
* Chex;
* JAXopt;
* NumPyro;
* Distrax;
* TFP on JAX.; return x * 2

print(batched_square(jnp.array([1, 2, 3, 4])))

* SGD;
* Adam;
* AdamW;
* learning rate schedules;
* gradient transformations;
* gradient clipping;
* optimizer state;
* training loops.; '''Перевага:''' JAX поєднує знайомий стиль NumPy із сучасними можливостями для machine learning і high-performance computing.; PyTorch
</div>
<syntaxhighlight lang="python">
|-
| ключовий стиль
| Функціональні transformations: grad, jit, vmap
| Повна ML-платформа з Keras, TensorFlow Lite, Serving, TFX
|-
| Рівень
| Нижчий і гнучкіший для research
| Ширша production-екосистема
|-
| Neural networks
| Через Flax, Haiku, Equinox та інші бібліотеки
| Через Keras і TensorFlow API
|-
| Компіляція
| XLA через jit
| TensorFlow graph/XLA у відповідних сценаріях
|-
| Типове використання
| Research, differentiable programming, high-performance numeric code
| Production ML, deep learning, mobile/browser deployment
|}

<div style="background:#fdecea; border-left:6px solid #e74c3c; padding:12px; margin:12px 0;">
import jax.numpy as jnp
def pure_function(x):

Замість in-place mutation застосовується функціональний стиль актуалізація.; JAX особливо корисний для research, differentiable programming, optimization, neural networks, scientific computing і задач, де потрібно поєднати математичну гнучкість із продуктивністю.;== Обмеження JAX ==

== Висновок ==

Результат: training loop із gradients, optimizer update і evaluation.;<div style="background:#ecfdf5; border-left:6px solid #10b981; padding:12px; margin:12px 0;">
Він надає можливість писати код, схожий на NumPy:
== Pytrees ==

Задача: навчити neural network.; Критерій

`grad` часто застосовується для:

== Тематичні мітки ==
== JAX ecosystem ==
<syntaxhighlight lang="text">
</div>
=== Neural network training ===

Equinox — це бібліотека для JAX, яка надає можливість описувати neural networks і differentiable programs через Python-класи, сумісні з pytrees.; df = jax.grad(f)

це Python-бібліотека; додатково реалізовано автоматичного диференціювання.;== JAX для наукових обчислень ==
'''істотно:''' open-source ліцензійний пакет JAX не скасовує обмежень на інформаційні дані, моделі або сторонні бібліотеки, які використовуються разом із ним.;</div>

</div>

'''Небезпека:''' JAX-код має змогу бути дуже швидким, але неправильна технічна архітектура обчислень має змогу зробити його повільним, нестабільним або важким для налагодження.;

Вона надає можливість застосувати функцію до batch даних без ручного написання циклу.; * залежить лише від своїх аргументів;

  • не змінює зовнішній стан;
  • не має прихованих побічних ефектів;
  • для однакових входів повертає однаковий результат.; JAX
  • можуть виконуватися на accelerator hardware;
  • підтримують JAX-трансформації;
  • зазвичай є собою immutable;
  • можуть бути частиною compiled computation;
  • можуть брати участь в automatic differentiation;
  • можуть переноситися між devices.; Типові задачі:

</syntaxhighlight>

Задача: застосувати функцію до batch прикладів.; Задача: пришвидшити числову функцію, яка викликається багато разів.;

Суть immutable arrays: замість зміни масиву на місці JAX створює нове логічне представлення результату, що краще узгоджується з трансформаціями й компіляцією.; batched_square = jax.vmap(square)

Automatic differentiation

y = jnp.sin(x) + x ** 2

* defining neural networks;
* training models;
* research experiments;
* transformer models;
* model state;
* neural network modules;
* integration with Optax;
* large-scale ML research.; !; Приклади:

Продуктивність

JIT означає Just-In-Time compilation.; Практична роль: XLA є собою однією з причин, чому JAX має змогу виконувати числові функції оперативно після компіляції.; * вищий поріг входу;

  • незвичний functional style;
  • immutable arrays;
  • explicit PRNG keys;
  • складніші помилки при jit;
  • потрібно розуміти tracing;
  • не всі NumPy-патерни переносяться напряму;
  • neural network API винесений в окремі бібліотеки;
  • production deployment має змогу потребувати додаткової роботи;
  • складніше debugging у compiled code;
  • можливі проблеми сумісності з версіями CUDA/TPU stack.; Водночас JAX потребує розуміння functional programming, immutable arrays, explicit random keys, tracing, shapes, dtypes і особливостей compiled execution.;

<syntaxhighlight lang="text">

Vectorization

JAX дуже схожий на NumPy за стилем API, але має важливі відмінності.; `jit` має змогу пришвидшити обчислення, особливо якщо:

<syntaxhighlight lang="python">