Google JAX - 개요
이번 포스팅에서는 Gopgle JAX에 대해 알아보겠습니다.
JAX
Google JAX는 딥러닝 연구를 돕기 위해 개발된 NumPy Framwork 입니다.
특히 JAX는 NumPy의 연산들은 GPU를 가능하게 해 기존의 연산 속도를 개선한 기술입니다.
AutoGrad 기술과 XLA를 결합한 것이 가장 큰 특징입니다.
- AutoGrad
- 자동 미분 엔진으로 Python 과 NumPy 코드를 자동으로 미분하는 기술입니다.
- tf. GradientTape API와 비슷하지만 구현 방식이 다릅니다.
- AutoGrad는 함수에서 바로 경사도를 계산하지만, Tensorflow에서는 역전파를 사용해 손실 차이를 계산하고, 손실의 경사도를 추정해 다음 단계의 최선 값 예측을 수행합니다.
- 또한, AutoGrad는 Python으로 작성되었지만, Tensorflow는 C++로 작성되고 Python Rapper을 사용해 구현되어 있습니다.
-
XLA(Accelarted Linear Algebra)
-
최적화를 통해 계산 속도를 높일 수 있는 선행 대수 용 도메인별 그래프 기반의 컴파일러입니다.
-
JIT 컴파일러라고도 불립니다.
-
소스코드의 변경 없이 모델 학습을 가속화할 수 있으며, 특히 속도와 메모리 사용을 개선했다는 특징이 있습니다.
-
tf.Graph를 주어진 모델에 맞추어 생성되는 계산 커널 시퀀스로 컴파일 합니다.
계산 커널은 그 모델에 고유한 커널입니다. 따라서 모델에 특화된 정보 최적화로 활용됩니다.
-
특히 해당 기술들의 결합으로 Tensorflow와 Pytorch같은 프레임워크에 의존 없이 NumPy 만으로 Neural Network를 작성해 GPU로 학습이 가능하게 됩니다.
해당 기술은 2018년도 말에 제안되었으며, Git에 처음 릴리즈 된 것은 2020년 5월 입니다.
또한 Apache License 2.0을 사용하고 있습니다.
특징을 정리해보면 다음과 같습니다.
- 신경망에서 기능을 추가 할 때, Tensorflow등과 같은 프레임워크보다 쉽게 확장이 가능합니다.
- 컴파일러 최적화 기술 중 속도와 성능 측면에서 큰 이점을 갖고 있습니다.
- 제한적인 경우에 대해서만 적용이 가능하다는 한계를 갖고 있습니다.
주의사항
1. 설치 주의 사항
- JAX는 Windows는 지원하고 있지 않습니다.
- CUDA 적용 시, CUDA v.11.1 이상에서만 지원하고 있습니다.
2. JAX 코드 작성 시 주의사항
-
작성한 python 함수에 JAX를 적용하고자 하면, Pure Python Function 형태인 경우에만 적용이 가능합니다
- Pure Python Function 이란, 입력이 같으면 출력이 항상 같음이 보장된 함수를 의미합니다.
- 즉, 전역변수가 필요한 연산의 경우 그 결과가 원하는 바와 다를 수 있어 적용하는 경우 그 결과를 장담할 수 없습니다.
< Example >
g = 0. def impure_uses_globals(x): return x + g print ("First call: ", jit(impure_uses_globals)(4.)) g = 10. # Update the global print ("Second call: ", jit(impure_uses_globals)(5.)) print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))
First call: 4.0 Second call: 5.0 Third call, different type: [14.]
Second call에서 작성자가 원했던 결과는 15.0 입니다. 하지만, g 값을 캐시로 저장해 5+0 연산을 수행했고,
Third call에서 입력인자의 형식이 float 에서 array로 변경되면서 JAX가 함수에 대한 재 컴파일을 수행하게되 변경된 g값을 사용해 14.0을 출력합니다.
즉, 사용자가 원하는 바와는 다른 전역변수의 사용이 이루어져 결과가 예상과 달라지게 된 것입니다.
-
JAX함수에서 Python의 Iterator을 사용하면 잘못된 결과가 발생할 수 있습니다.
- JAX의 range함수와 Python의 Iterator의 사용을 구분해야 합니다.
< Example >
# lax.fori_loop # expected result 45 array = jnp.arange(10) print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # unexpected result 0 iterator = iter(range(10)) print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0))
45 0
JAX의 arange 함수를 사용해 array를 만든 경우의 합은 예상대로 45를 출력하지만,
Python iterator을 사용해 array를 만든 경우 jax의 fori_loop을 사용하면 합을 정상적으로 구하지 못합니다.
-
JAX Device array에 대한 In-Place 업데이트는 사용할 수 없습니다.
-
In-Place Update는 array의 위치에 바로 =을 사용해 값을 넣는 방법을 말합니다.
-
이 경우, Exception ‘<class ‘jaxlib.xla_extension.DeviceArray’>’ object does not support item assignment을 발생시킵니다.
< Example >
# JAX Device array 선언 jax_array = jnp.zeros((3,3), dtype=jnp.float32) try: jax_array[1, :] = 1.0 # In-Place Update는 오류를 발생시킴 # jax_array.at[1, :].set(1.0) 을 사용해야함 except Exception as e: print("Exception {}".format(e))
Exception '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable.
-
-
Python의 list 혹은 tuple은 JAX 함수의 입력으로 넣을 경우 Error을 반환합니다.
-
JAX에서 list와 tuple 은 각각의 요소들을 각각 다른 JAX 변수로 처리합니다.
-
효과적으로 사용하기 위해서는 JAX array 혹은 NumPy로 변환하여 전달해 주어야 합니다.
-
이 경우, TypeError를 발생시킵니다.
< Example >
try: jnp.sum([1, 2, 3]) # -> jnp.sum(jnp.array([1,2,3])) except TypeError as e: print(f"TypeError: {e}")
TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.
-
이번 포스팅에서는 Google JAX에 대해 간략한 설명 및 특징과 주의사항에 대해 알아보았습니다.
다음 시간에는 간단한 기능의 샘플코드에 대해 알아보겠습니다.
< 참고 자료 >
[1] https://mjshin.tistory.com/13
[2] https://jax.readthedocs.io/en/latest/jax-101/index.html
[3] https://jjeamin.github.io/posts/jax/