[ad_1]
Popüler açık kaynaklı TensorFlow makine öğrenimi platformunu güçlendiren yenilikler arasında otomatik farklılaştırma (otograd) ve XLA (Hızlandırılmış Doğrusal Cebir) derin öğrenme için derleyiciyi optimize ediyor.
Bu iki teknolojiyi bir araya getiren bir diğer proje olan Google JAX, hız ve performans açısından önemli avantajlar sunuyor. GPU’larda veya TPU’larda çalıştırıldığında JAX, çağrı yapan diğer programların yerini alabilir. Dizi, ancak programları çok daha hızlı çalışır. Ek olarak, sinir ağları için JAX kullanmak, TensorFlow gibi daha büyük bir çerçeveyi genişletmekten çok yeni işlevler eklemeyi çok daha kolay hale getirebilir.
Bu makale, faydalarına ve sınırlamalarına genel bir bakış, kurulum talimatları ve Colab’daki Google JAX hızlı başlangıcına ilk bakış dahil olmak üzere Google JAX’ı tanıtmaktadır.
Otograd nedir?
Autograd, Ryan Adams’ın Harvard Intelligent Probabilistic Systems Group’ta bir araştırma projesi olarak başlayan otomatik bir farklılaşma motorudur. Bu yazı itibariyle, motorun bakımı yapılıyor ancak artık aktif olarak geliştirilmiyor. Bunun yerine geliştiricileri, Autograd’ı XLA JIT derlemesi gibi ek özelliklerle birleştiren Google JAX üzerinde çalışıyor. bu otograd motor yerel Python ve NumPy kodunu otomatik olarak ayırt edebilir. Birincil amaçlanan uygulaması, gradyan tabanlı optimizasyondur.
TensorFlow’lar tf.GradientTape
API, Autograd’a benzer fikirlere dayanmaktadır, ancak uygulaması aynı değildir. Autograd tamamen Python’da yazılmıştır ve gradyanı doğrudan fonksiyondan hesaplarken, TensorFlow’un gradyan teyp işlevselliği ince bir Python sarıcı ile C++ ile yazılmıştır. TensorFlow, kayıptaki farklılıkları hesaplamak, kaybın gradyanını tahmin etmek ve bir sonraki en iyi adımı tahmin etmek için geri yayılımı kullanır.
XLA nedir?
XLA TensorFlow tarafından geliştirilen lineer cebir için alana özgü bir derleyicidir. TensorFlow belgelerine göre, XLA, potansiyel olarak kaynak kodu değişikliği olmadan TensorFlow modellerini hızlandırabilir, bu da hızı ve bellek kullanımını iyileştirir. Bir örnek 2020 Google’dır BERT MLPerf kıyaslama gönderimiburada XLA kullanan 8 Volta V100 GPU, ~7x performans iyileştirmesi ve ~5x parti boyutu iyileştirmesi elde etti.
XLA, bir TensorFlow grafiğini, belirli bir model için özel olarak oluşturulmuş bir dizi hesaplama çekirdeğinde derler. Bu çekirdekler modele özgü olduğundan, optimizasyon için modele özgü bilgilerden yararlanabilirler. TensorFlow içinde XLA, JIT (tam zamanında) derleyicisi olarak da adlandırılır. içindeki bir bayrakla etkinleştirebilirsiniz. @tf.function
Python dekoratörü, şöyle:
@tf.function(jit_compile=True)
XLA’yı TensorFlow’da aşağıdakileri ayarlayarak da etkinleştirebilirsiniz: TF_XLA_FLAGS
ortam değişkeni veya bağımsız çalıştırarak tfcompile
alet.
TensorFlow dışında, XLA programları aşağıdakiler tarafından oluşturulabilir:
Google JAX’ı kullanmaya başlayın
geçtim JAX Hızlı Başlangıç varsayılan olarak bir GPU kullanan Colab’da. İsterseniz bir TPU kullanmayı seçebilirsiniz, ancak aylık ücretsiz TPU kullanımı sınırlıdır. Ayrıca bir çalıştırmanız gerekir özel başlatma Google JAX için bir Colab TPU kullanmak için.
Hızlı başlangıca ulaşmak için Colab’da aç düğmesinin üst kısmındaki JAX’ta Paralel Değerlendirme dokümantasyon sayfası. Bu sizi canlı not defteri ortamına geçirecektir. Ardından, aşağı bırakın Bağlamak Barındırılan bir çalışma zamanına bağlanmak için not defterindeki düğme.
Hızlı başlangıcın bir GPU ile çalıştırılması, JAX’in matris ve doğrusal cebir işlemlerini ne kadar hızlandırabileceğini açıkça ortaya koydu. Daha sonra not defterinde, mikrosaniye cinsinden ölçülen JIT hızlandırılmış süreleri gördüm. Kodu okuduğunuzda, çoğu, derin öğrenmede kullanılan ortak işlevleri ifade ederek hafızanızı çalıştırabilir.
Şekil 1. Google JAX hızlı başlangıcında bir matris matematik örneği.
JAX nasıl kurulur
Bir JAX kurulumu, işletim sisteminize ve CPU, GPU veya TPU sürümü seçimine uygun olmalıdır. CPU’lar için basittir; örneğin, JAX’i dizüstü bilgisayarınızda çalıştırmak istiyorsanız şunu girin:
pip install --upgrade pip
pip install --upgrade "jax[cpu]"
GPU’lar için, sahip olmanız gerekir CUDA ve CuDNN uyumlu bir NVIDIA sürücüsü ile birlikte yüklenir. İhtiyacın olacak oldukça yeni sürümler ikinizde. CUDA ve CuDNN’nin son sürümlerine sahip Linux’ta, önceden oluşturulmuş CUDA uyumlu tekerlekler kurabilirsiniz; aksi takdirde, yapmanız gerekir kaynaktan inşa.
JAX ayrıca önceden oluşturulmuş tekerlekler sağlar Google Bulut TPUs. Cloud TPU’lar, Colab TPU’larından daha yenidir ve geriye dönük uyumlu değildir, ancak Colab ortamları zaten JAX ve doğru TPU desteğini içerir.
JAX API’si
Var JAX API’sine üç katman. En üst düzeyde JAX, NumPy API’sinin bir aynasını uygular, jax.numpy
. Hemen hemen ile yapılabilecek herhangi bir şey numpy
ile yapılabilir jax.numpy
. sınırlaması jax.numpy
NumPy dizilerinden farklı olarak, JAX dizileri değişmezdir, yani bir kez oluşturulduktan sonra içerikleri değiştirilemez.
JAX API’nin orta katmanı jax.lax
NumPy katmanından daha katı ve genellikle daha güçlü olan . içindeki tüm işlemler jax.numpy
sonunda tanımlanan fonksiyonlar cinsinden ifade edilir. jax.lax
. Süre jax.numpy
karışık veri türleri arasında işlemlere izin vermek için argümanları dolaylı olarak teşvik edecek, jax.lax
olmayacak; bunun yerine, açık promosyon işlevleri sağlar.
API’nin en alt katmanı XLA’dır. Herşey jax.lax
işlemler, XLA’daki işlemler için Python sarmalayıcılarıdır. Her JAX işlemi sonunda, JIT derlemesini sağlayan bu temel XLA işlemleri cinsinden ifade edilir.
JAX’in Sınırlamaları
JAX dönüşümleri ve derleme yalnızca işlevsel olarak saf Python işlevlerinde çalışmak üzere tasarlanmıştır. Bir işlevin yan etkisi varsa, print()
deyimi, kod üzerinden birden çok çalıştırmanın farklı yan etkileri olacaktır. A print()
sonraki çalışmalarda farklı şeyler yazdırır veya hiçbir şey yazdırmazdı.
JAX’ın diğer sınırlamaları, yerinde mutasyonlara izin vermemeyi içerir (çünkü diziler değişmezdir). Bu sınırlama, yerinde olmayan dizi güncellemelerine izin verilerek azaltılır:
updated_array = jax_array.at[1, :].set(1.0)
Ayrıca, JAX varsayılan olarak tek duyarlıklı sayılara (float32
), NumPy varsayılan olarak çift duyarlık (float64
). Gerçekten çift hassasiyete ihtiyacınız varsa, JAX’ı şu şekilde ayarlayabilirsiniz: jax_enable_x64
modu. Genel olarak, tek duyarlıklı hesaplamalar daha hızlı çalışır ve daha az GPU belleği gerektirir.
Hızlandırılmış sinir ağı için JAX kullanma
Bu noktada, açıkça belirtilmelidir ki, abilir JAX’ta hızlandırılmış sinir ağlarını uygulayın. Öte yandan, tekerleği neden yeniden icat edelim? Google Araştırma grupları ve DeepMind, JAX tabanlı açık kaynaklı birkaç sinir ağı kitaplığı: Keten örnekler ve nasıl yapılır kılavuzları ile sinir ağı eğitimi için tam özellikli bir kütüphanedir. haiku sinir ağı modülleri içindir, Optaks gradyan işleme ve optimizasyon içindir, RLax RL (pekiştirmeli öğrenme) algoritmaları içindir ve chex güvenilir kod ve test içindir.
JAX hakkında daha fazla bilgi edinin
Buna ek olarak JAX Hızlı BaşlangıçJAX’ın bir öğreticiler dizisi Colab’da çalıştırabileceğinizi (ve çalıştırmanız gerektiğini). İlk öğretici, nasıl kullanılacağını gösterir. jax.numpy
fonksiyonlar, grad
ve value_and_grad
işlevleri ve @jit
dekoratör. Bir sonraki öğretici, JIT derlemesi hakkında daha derine iniyor. Son öğreticide, hem tekli hem de çok ana bilgisayarlı ortamlarda işlevleri nasıl derleyeceğinizi ve otomatik olarak bölümlendireceğinizi öğreniyorsunuz.
Ayrıca JAX referans belgelerini de okuyabilirsiniz (ve yapmalısınız). SSS) ve gelişmiş öğreticileri çalıştırın ( Autodiff Yemek Kitabı) Colab’da. Son olarak, aşağıdakilerden başlayarak API belgelerini okumalısınız: ana JAX paketi.
Telif Hakkı © 2022 IDG Communications, Inc.
[ad_2]
Kaynak : https://www.infoworld.com/article/3666812/what-is-google-jax-numpy-on-accelerators.html#tk.rss_all