2022-08-11,   김수민

이번 포스팅에서는 Gopgle JAX에 대해 알아보겠습니다.


JAX

Google JAX는 딥러닝 연구를 돕기 위해 개발된 NumPy Framwork 입니다.

특히 JAX는 NumPy의 연산들은 GPU를 가능하게 해 기존의 연산 속도를 개선한 기술입니다.

AutoGrad 기술과 XLA를 결합한 것이 가장 큰 특징입니다.

  • AutoGrad
    • 자동 미분 엔진으로 Python 과 NumPy 코드를 자동으로 미분하는 기술입니다.
    • tf. GradientTape API와 비슷하지만 구현 방식이 다릅니다.
      • AutoGrad는 함수에서 바로 경사도를 계산하지만, Tensorflow에서는 역전파를 사용해 손실 차이를 계산하고, 손실의 경사도를 추정해 다음 단계의 최선 값 예측을 수행합니다.
      • 또한, AutoGrad는 Python으로 작성되었지만, Tensorflow는 C++로 작성되고 Python Rapper을 사용해 구현되어 있습니다.
  • XLA(Accelarted Linear Algebra)

    • 최적화를 통해 계산 속도를 높일 수 있는 선행 대수 용 도메인별 그래프 기반의 컴파일러입니다.

    • JIT 컴파일러라고도 불립니다.

    • 소스코드의 변경 없이 모델 학습을 가속화할 수 있으며, 특히 속도와 메모리 사용을 개선했다는 특징이 있습니다.

    • tf.Graph를 주어진 모델에 맞추어 생성되는 계산 커널 시퀀스로 컴파일 합니다.

      계산 커널은 그 모델에 고유한 커널입니다. 따라서 모델에 특화된 정보 최적화로 활용됩니다.

특히 해당 기술들의 결합으로 Tensorflow와 Pytorch같은 프레임워크에 의존 없이 NumPy 만으로 Neural Network를 작성해 GPU로 학습이 가능하게 됩니다.

해당 기술은 2018년도 말에 제안되었으며, Git에 처음 릴리즈 된 것은 2020년 5월 입니다.

또한 Apache License 2.0을 사용하고 있습니다.


특징을 정리해보면 다음과 같습니다.

  • 신경망에서 기능을 추가 할 때, Tensorflow등과 같은 프레임워크보다 쉽게 확장이 가능합니다.
  • 컴파일러 최적화 기술 중 속도와 성능 측면에서 큰 이점을 갖고 있습니다.
  • 제한적인 경우에 대해서만 적용이 가능하다는 한계를 갖고 있습니다.

주의사항

1. 설치 주의 사항

  • JAX는 Windows는 지원하고 있지 않습니다.
  • CUDA 적용 시, CUDA v.11.1 이상에서만 지원하고 있습니다.

2. JAX 코드 작성 시 주의사항

  • 작성한 python 함수에 JAX를 적용하고자 하면, Pure Python Function 형태인 경우에만 적용이 가능합니다

    • Pure Python Function 이란, 입력이 같으면 출력이 항상 같음이 보장된 함수를 의미합니다.
    • 즉, 전역변수가 필요한 연산의 경우 그 결과가 원하는 바와 다를 수 있어 적용하는 경우 그 결과를 장담할 수 없습니다.

    < Example >

    g = 0. 
    def impure_uses_globals(x):
     return x + g 
    print ("First call: ", jit(impure_uses_globals)(4.)) 
    g = 10. # Update the global 
    print ("Second call: ", jit(impure_uses_globals)(5.)) 
    print ("Third call, different type: ", 
                     jit(impure_uses_globals)(jnp.array([4.])))
    
    First call: 4.0 
    Second call: 5.0 
    Third call, different type: [14.]
    

    Second call에서 작성자가 원했던 결과는 15.0 입니다. 하지만, g 값을 캐시로 저장해 5+0 연산을 수행했고,

    Third call에서 입력인자의 형식이 float 에서 array로 변경되면서 JAX가 함수에 대한 재 컴파일을 수행하게되 변경된 g값을 사용해 14.0을 출력합니다.

    즉, 사용자가 원하는 바와는 다른 전역변수의 사용이 이루어져 결과가 예상과 달라지게 된 것입니다.

  • JAX함수에서 Python의 Iterator을 사용하면 잘못된 결과가 발생할 수 있습니다.

    • JAX의 range함수와 Python의 Iterator의 사용을 구분해야 합니다.

    < Example >

    # lax.fori_loop 
    # expected result 45 
    array = jnp.arange(10) 
    print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) 
      
    # unexpected result 0
    iterator = iter(range(10)) 
    print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0))
    
    45 
    0
    

    JAX의 arange 함수를 사용해 array를 만든 경우의 합은 예상대로 45를 출력하지만,

    Python iterator을 사용해 array를 만든 경우 jax의 fori_loop을 사용하면 합을 정상적으로 구하지 못합니다.

  • JAX Device array에 대한 In-Place 업데이트는 사용할 수 없습니다.

    • In-Place Update는 array의 위치에 바로 =을 사용해 값을 넣는 방법을 말합니다.

    • 이 경우, Exception ‘<class ‘jaxlib.xla_extension.DeviceArray’>’ object does not support item assignment을 발생시킵니다.

    < Example >

    # JAX Device array 선언
    jax_array = jnp.zeros((3,3), dtype=jnp.float32)
      
    try:
     jax_array[1, :] = 1.0 # In-Place Update는 오류를 발생시킴
                           # jax_array.at[1, :].set(1.0) 을 사용해야함 
    except Exception as e:
     print("Exception {}".format(e))
    
    Exception '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. 
    
  • Python의 list 혹은 tuple은 JAX 함수의 입력으로 넣을 경우 Error을 반환합니다.

    • JAX에서 list와 tuple 은 각각의 요소들을 각각 다른 JAX 변수로 처리합니다.

    • 효과적으로 사용하기 위해서는 JAX array 혹은 NumPy로 변환하여 전달해 주어야 합니다.

    • 이 경우, TypeError를 발생시킵니다.

    < Example >

    try:
     jnp.sum([1, 2, 3])
     # -> jnp.sum(jnp.array([1,2,3]))
       
    except TypeError as e:
     print(f"TypeError: {e}")
    
    TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.
    

이번 포스팅에서는 Google JAX에 대해 간략한 설명 및 특징과 주의사항에 대해 알아보았습니다.

다음 시간에는 간단한 기능의 샘플코드에 대해 알아보겠습니다.


< 참고 자료 >

[1] https://mjshin.tistory.com/13

[2] https://jax.readthedocs.io/en/latest/jax-101/index.html

[3] https://jjeamin.github.io/posts/jax/

업데이트: