01 - トランスフォーマーの注意メカニズム: 完全ガイド
2017 年、「必要なのは注意だけ」という Google Brain の論文が状況を永遠に変えました ディープラーニングの分野。著者である Vaswani らは、アーキテクチャを提案しました。 と呼ばれるメカニズムに完全に基づいています 注意、ネットワークを排除する それまで主流だった再帰型 (RNN) および畳み込みシステム。結果は、 アーキテクチャ トランスフォーマー、今日は GPT-4 の基地で、クロード、ラマ 3、 BERT、T5、ビジョン トランスフォーマー、およびほぼすべてのフロンティア モデル。
注意のメカニズムを理解することは学術的な訓練ではありません。注意のメカニズムを理解するための基礎です。 彼らは、LoRA の微調整、量子化、プルーニング、デプロイメントなどの技術を構築します。 エッジ デバイス、このシリーズで取り上げるすべてのトピック。理解が無いままに 注意がどのように機能するかをしっかりと理解していても、その後の最適化はそれぞれブラックボックスのままです。
このシリーズの最初の記事では 高度なディープラーニングとエッジ展開、 最初の直感から数式に至るまで、注意を深く探っていきます。 PyTorch での実装から、Flash Attendant 3 や Grouped-Query などの最新のバリアントまで 注意。
シリーズ概要
| # | アイテム | 集中 |
|---|---|---|
| 1 | あなたはここにいます - トランスフォーマーの注意メカニズム | 自己注意、マルチヘッド、完全なアーキテクチャ |
| 2 | LoRA、QLoRA、アダプターによる微調整 | パラメータ効率の高い微調整 |
| 3 | モデルの量子化 | INT8、INT4、GPTQ、AWQ |
| 4 | 剪定と圧縮 | パラメータの削減、蒸留 |
| 5 | 知識の蒸留 | 教師と生徒、知識の伝達 |
| 6 | オラマとLLMローカル | 局所推論、最適化 |
| 7 | ビジョントランスフォーマー | ViT、DINO、画像分類 |
| 8 | エッジ展開 | ONNX、TensorRT、モバイル デバイス |
| 9 | NAS と AutoML | ニューラル アーキテクチャの検索 |
| 10 | ベンチマークと最適化 | プロファイリング、メトリクス、チューニング |
何を学ぶか
- RNN と LSTM は長いシーケンスには十分ではなかったため
- アテンションメカニズムの背後にある直感: クエリ、キー、値
- スケーリングされたドット積注意の完全な公式
- マルチヘッド アテンションの仕組みと複数のヘッドが必要な理由
- セルフアテンションとクロスアテンションの違い
- 位置エンコーディングが順序の問題を解決する方法
- 完全な Transformer アーキテクチャ: エンコーダーとデコーダー
- PyTorch での 1 行ずつの実践的な実装
- 最新のバリエーション: フラッシュ アテンション 3、GQA、スライディング ウィンドウ アテンション
- 実際のアーキテクチャ: GPT (デコーダのみ)、BERT (エンコーダのみ)、T5 (エンコーダ-デコーダ)
1. 順序の問題: 注意を払う前に
なぜ注目が革命だったかを理解するには、次のモデルから始めなければなりません。 彼らは彼女に先立った。シーケンス (テキスト、オーディオ、時系列) の深層学習は、 2 つのアーキテクチャによって支配されています。 RNN (リカレント ニューラル ネットワーク) そして LSTM (長短期記憶).
1.1 RNN とシーケンシャル ボトルネック
RNN はシーケンスを一度に 1 トークンずつ処理し、隠し状態を渡します あるタイムステップから次のタイムステップへ。各トークンは非表示状態を更新し、その状態として機能します。 これまで見てきたシーケンスの「記憶」。
Input: x1 -----> x2 -----> x3 -----> x4 -----> x5
| | | | |
v v v v v
Hidden: h1 -----> h2 -----> h3 -----> h4 -----> h5
| |
v v
Output: y1 y5
Problema: h5 deve "ricordare" x1 attraverso 4 passaggi.
Con sequenze di 1000+ token, l'informazione di x1 svanisce.
これが問題です 長期にわたる依存関係。 「3年前に保護施設から引き取られて暮らしていた猫は、 家族と一緒に幸せに、 彼は寝ていました 「ソファの上」の場合、RNN は「猫」を接続する必要があります 数十の中間トークンを「眠った」。ベクトルに圧縮された隠れた状態 サイズが固定されているため、必然的に古い情報が失われます。
1.2 LSTM: 解決策ではなく改善
LSTM は、ゲートメカニズム (入力ゲート、忘却ゲート、出力ゲート) を導入しました。 どの情報を保持し、どの情報を破棄するかを制御します。これにより状況は改善され、 しかし解決しませんでした。 LSTM には依然として 2 つの基本的な問題があります。
RNN/LSTM の制限
| 問題 | 説明 | インパクト |
|---|---|---|
| 連続性 | 各トークンは前のトークンに依存するため、並列化できません | 長いシーケンスのトレーニングが非常に遅い |
| ボトルネック | すべての情報は単一の通信事業者を通過します | 100 ~ 200 トークンを超えるシーケンスによる情報損失 |
| グラデーション消失 | バックプロパゲーション中に勾配が指数関数的に縮小する | モデルは遠い関係を学習できません |
各トークンがアクセスできるメカニズムが必要でした。 直接 ある シーケンス内の他のトークン。中間状態を経由する必要はありません。 この仕組みと、注意.
2. 注意とは何か:直観
アテンションは、モデルが次のことを行うことを可能にするメカニズムです。 自分自身に集中する 気をつけて 出力を生成するときに、入力の最も関連性の高い部分に重点を置きます。代わりに シーケンス全体を 1 つのベクトルに圧縮するには、注意がつながりを生み出します。 各出力位置とすべての入力位置の間で直接接続されます。
例え: 図書館での検索
あなたが本屋で「トランスフォーマーの歴史」に関する情報を探していると想像してください。 あなたは心に一つのことを考えています リクエスト (質問)。各本には、 タイトル (Key) はその内容を説明します。タイトルが質問と一致する場合は、 を抽出します コンテンツ その本の(価値)。注意力が働きます まさに次のように:
- クエリ (Q): 「私は何を探しているのですか?」 - 現在のトークンが尋ねる質問
- キー (K): 「この商品には何が入っているの?」 - シーケンス内の各トークンのラベル
- 値 (V): 「情報は次のとおりです」 – 各トークンの実際の内容
このメカニズムは、 互換性スコア クエリと各キーの間。 このスコアは、対応する値にどの程度の注意を払うかを決定します。スコアが来る ソフトマックスを介して正規化され、合計が 1 になる重みが得られ、最終結果は 1 になります。 値の加重平均。
Token corrente: "dormiva"
Query di "dormiva": "Chi sta compiendo questa azione?"
Key Score Peso (softmax)
"Il" -----> 0.1 0.02
"gatto" -----> 4.8 0.65 <-- Alta attenzione!
"che" -----> 0.3 0.03
"era" -----> 0.2 0.02
"stato" -----> 0.1 0.02
"adottato" -----> 1.2 0.08
"..." -----> ... ...
"sul" -----> 2.1 0.12
"divano" -----> 0.8 0.06
Output = 0.02 * V("Il") + 0.65 * V("gatto") + 0.03 * V("che") + ...
Il modello ha imparato che "gatto" e il soggetto di "dormiva",
anche se sono separati da molti token.
3. スケーリングされたドット積の注意: 公式
トランスフォーマーで使用される注意の数学的定式化と スケーリング済み 内積注意。シンプルかつ計算的にエレガントです 行列演算の使用により効率的です。
注意の公式
アテンション(Q, K, V) = ソフトマックス(Q * K^T / sqrt(d_k)) * V
どこ:
- Q (クエリ): サイズ (n x d_k) の行列。n はトークンの数、d_k はクエリ/キーのサイズです。
- K (キー): 次元 (n x d_k) の配列
- V (値): 次元 (n x d_v) の行列。d_v は値の次元です。
- d_k: キーのサイズ。スケーリング係数として使用されます。
- Q*K^T: クエリとキーの間のスカラー積 (n x n スコア行列)
- /sqrt(d_k): 勾配を安定させるためのスケーリング係数
- ソフトマックス: スコアを合計が 1 になる重みに正規化します。
3.1 スケーリングが必要な理由
要因がなければ sqrt(d_k)、Q と K の間の内積により増加する値が生成されます。
寸法 d_k に比例します。 d_k = 512 の場合、スカラー積は次の値に達します。
非常に大きな値。これらの値がソフトマックスに達すると、分布が生成されます。
ほぼワンホット (1 つの重みが 1 に近く、他のすべての重みが 0 に近い)、勾配が非常に小さい。
スケーリングによってこの問題は回避されます。
Senza scaling (d_k = 512):
Score raw: [120.3, 115.8, 2.1, -5.4]
Softmax: [0.989, 0.011, 0.000, 0.000] <-- Quasi one-hot, gradienti ~0
Con scaling (/ sqrt(512) = / 22.6):
Score scaled: [5.32, 5.12, 0.09, -0.24]
Softmax: [0.44, 0.36, 0.10, 0.10] <-- Distribuzione morbida, gradienti sani
3.2 ステップバイステップ: アテンションの計算
3 つのトークンのシーケンスと d_k = 4 の具体的な数値例を見てみましょう。
Sequenza: ["The", "cat", "sat"]
Step 1: Genera Q, K, V tramite proiezioni lineari
Q = X * W_Q K = X * W_K V = X * W_V
Q = [[1.0, 0.5, 0.3, 0.2], (The)
[0.8, 1.2, 0.1, 0.9], (cat)
[0.3, 0.4, 1.1, 0.6]] (sat)
K = [[0.9, 0.6, 0.4, 0.1],
[0.7, 1.1, 0.2, 0.8],
[0.4, 0.3, 1.0, 0.5]]
V = [[0.2, 0.8, 0.1, 0.5],
[0.9, 0.3, 0.7, 0.2],
[0.4, 0.6, 0.5, 0.8]]
Step 2: Calcola Q * K^T (matrice 3x3 di score)
Score[i][j] = dot(Q[i], K[j])
Scores = [[1.19, 1.37, 0.89],
[1.35, 1.77, 1.10],
[0.98, 1.15, 1.42]]
Step 3: Scala per sqrt(d_k) = sqrt(4) = 2
Scaled = [[0.60, 0.69, 0.45],
[0.68, 0.89, 0.55],
[0.49, 0.58, 0.71]]
Step 4: Applica softmax per riga
Weights = [[0.33, 0.36, 0.31], (The guarda The, cat, sat)
[0.32, 0.40, 0.28], (cat guarda The, cat, sat)
[0.29, 0.32, 0.39]] (sat guarda The, cat, sat)
Step 5: Moltiplica pesi per V
Output[0] = 0.33*V[0] + 0.36*V[1] + 0.31*V[2]
= [0.51, 0.56, 0.39, 0.48]
複雑さへの注意
行列 Q * K^T には次元があります n×nここで、n はシーケンスの長さです。 n = 1000 の場合、行列には 1,000,000 個の要素があります。 n = 100,000 の場合、100 億個の要素があります。 この O(n^2) 二次計算の複雑さが、Transformers の主なボトルネックです。 そして、Flash アテンションやスライディング ウィンドウ アテンションのような亜種が開発された理由も説明します。
4. マルチヘッドアテンション: 複数の角度から見る
1 つのアテンション操作で、トークン間の 1 つのタイプの関係が取得されます。しかし、人間関係 シーケンス内には、構文関係 (主語と動詞)、意味関係 (同義語、 コンテキスト)、位置(隣接するトークン)、その他多数。そこには マルチヘッドアテンション は、異なる投影と並行して注意を実行することで、この問題を解決します。
Input X (dimensione: n x d_model, es. n x 512)
|
+---> Head 1: Q1=X*Wq1, K1=X*Wk1, V1=X*Wv1 --> Attention(Q1,K1,V1) --> Z1
| (d_k = d_model/h = 64)
+---> Head 2: Q2=X*Wq2, K2=X*Wk2, V2=X*Wv2 --> Attention(Q2,K2,V2) --> Z2
|
+---> Head 3: Q3=X*Wq3, K3=X*Wk3, V3=X*Wv3 --> Attention(Q3,K3,V3) --> Z3
|
+---> ...
|
+---> Head 8: Q8=X*Wq8, K8=X*Wk8, V8=X*Wv8 --> Attention(Q8,K8,V8) --> Z8
|
v
Concatena: [Z1; Z2; Z3; ... Z8] (dimensione: n x d_model)
|
v
Proiezione finale: Concat * W_O (dimensione: n x d_model)
Con h = 8 頭と d_model = 512、各ヘッドは 1 つで動作します。
次元空間 d_k = d_v = 512 / 8 = 64。総計算コスト
ヘッドが動作するため、フルサイズのシングルアテンションと同様です。
より小さな部分空間で並行して。
すべての頭が学ぶこと
実証研究によると、さまざまな頭がさまざまなパターンに特化していることがわかりました。
- ヘッド 1: 主語と動詞の関係を学ぶことができた
- ヘッド 2: 共参照関係(代名詞とその先行詞)を学ぶことができた
- ヘッド 3: 隣接するトークン (ローカル N グラム) に焦点を当てる可能性があります
- ヘッド 4: 文間の長期的な関係を捉えることができる
- その他のヘッド: 構文パターン、エンティティ、談話構造
マルチヘッド アテンションフォーミュラ
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) * W_O
どこ head_i = アテンション(Q * W_Qi, K * W_Ki, V * W_Vi)
元の論文の一般的なパラメータ: d_model = 512、h = 8、d_k = d_v = 64。 最新のモデル: d_model = 4096 ~ 8192、h = 32 ~ 128。
5. 自己注意: 他のすべてを監視するトークン
La 自意識 クエリ、キー、値の由来となる特定のケース すべて同じシーケンスからのものです。各トークンは独自のクエリ、キー、値を生成し、そのクエリを使用します。 他のすべてのトークン (それ自体を含む) のキーを「クエリ」します。
Frase: "The cat sat on the mat"
Attention Matrix (ogni riga somma a 1.0):
The cat sat on the mat
The [0.15 0.25 0.10 0.05 0.15 0.30]
cat [0.10 0.20 0.35 0.05 0.05 0.25]
sat [0.05 0.40 0.15 0.20 0.05 0.15]
on [0.05 0.10 0.30 0.10 0.15 0.30]
the [0.20 0.15 0.05 0.10 0.10 0.40]
mat [0.10 0.15 0.15 0.25 0.15 0.20]
Osservazioni:
- "sat" presta molta attenzione a "cat" (0.40) --> soggetto-verbo
- "on" presta attenzione a "sat" (0.30) e "mat" (0.30) --> relazione spaziale
- "the" (seconda occorrenza) presta molta attenzione a "mat" (0.40) --> articolo-sostantivo
自己注意はトランスフォーマーの核心です。そしてモデルの構築を可能にするもの 文脈上の表現: 各トークンの表現は埋め込まれます 関連性によって重み付けされたシーケンス全体からの情報。 「銀行」という言葉には、 「川銀行」と「銀行口座」では、周囲のトークンが異なるため表現が異なります。 注意を通じてその表現に影響を与えます。
デコーダにおけるマスクされたセルフアテンション
生成モデル (デコーダ) では、自己注意は 仮面舞踏会: 毎 token は以前のトークンのみを参照でき、将来のトークンは参照できません。これは実装されています ソフトマックスの前に将来のトークンのスコアを -infinity に設定し、 重みはゼロに等しい。これは、 因果関係の注意 GPT、Llamaで使用 そしてすべての自己回帰モデル。
Mask per sequenza di 5 token (0 = visibile, -inf = mascherato):
t1 t2 t3 t4 t5
t1 [ 0 -inf -inf -inf -inf ]
t2 [ 0 0 -inf -inf -inf ]
t3 [ 0 0 0 -inf -inf ]
t4 [ 0 0 0 0 -inf ]
t5 [ 0 0 0 0 0 ]
Dopo la softmax:
t1 vede solo [t1]
t2 vede solo [t1, t2]
t3 vede solo [t1, t2, t3]
...e cosi via
6. クロスアテンション: エンコーダーとデコーダーが通信するとき
La 交差注意 (またはエンコーダとデコーダの注意)とそのメカニズム デコーダがエンコーダ出力を「監視」できるようにします。自分自身への注意とは異なり、 ここで、Q、K、V は同じシーケンスから来ており、クロスアテンションではクエリが来ています。 デコーダからのキー/値とエンコーダからのキー/値。
ENCODER (processa l'input, es. frase in italiano):
"Il gatto dorme" --> Encoder --> Rappresentazioni encoder (K_enc, V_enc)
DECODER (genera l'output, es. traduzione in inglese):
"The cat" --> Self-Attention mascherata --> Q_dec
CROSS-ATTENTION:
Q = Q_dec (dal decoder: "cosa sto cercando per generare il prossimo token?")
K = K_enc (dall'encoder: "cosa contiene ogni token dell'input?")
V = V_enc (dall'encoder: "ecco le informazioni dell'input")
Il decoder può "guardare" tutta la sequenza dell'encoder
per decidere quale token generare dopo.
相互注意は建築の基本です エンコーダ-デコーダ 使用済み 機械翻訳 (T5、mBART)、テキストの要約、条件付き生成用。 たとえば、T5 では、エンコーダーが入力テキストを処理し、デコーダーがテキストを生成します。 出力。クロスアテンションを使用して各生成ステップでエンコーダーを調べます。
トランスフォーマーにおける 3 つの注意点
| タイプ | Qソース | ソースK、V | どこで使用するか |
|---|---|---|---|
| セルフアテンション (エンコーダー) | エンコーダ入力 | エンコーダ入力 | BERTエンコーダ、T5エンコーダ |
| 仮面をかぶった自己注意 | 入力デコーダ | 入力デコーダ | GPT、ラマ、T5 デコーダー |
| クロスアテンション | デコーダ | エンコーダ出力 | T5デコーダ、mBART |
7. 位置エンコーディング: トランスフォーマーが順序を知る方法
トークンを順番に処理する RNN とは異なり、 自意識 順序に関して不変:結果は変わりません 入力トークンを並べ替えた場合。 「猫は魚を食べる」と「猫は魚を食べる」 追加のメカニズムなしで同じ出力が生成されます。の 位置的な エンコーディング の場所に関する情報を追加することで、この問題を解決します。 それぞれのトークン。
7.1 正弦波位置エンコーディング (オリジナル論文)
元の論文では、正弦関数を使用して位置エンコーディングを生成しています。
正弦波位置エンコーディング式
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
どこ 位置 およびシーケンス内のトークンの位置 e i そしてサイズ。 偶数の位置ではサインが使用され、奇数の位置ではコサインが使用されます。それぞれに異なる周波数 サイズによりモデルは相対的な位置関係を学習できます。
Posizione 0: [sin(0), cos(0), sin(0), cos(0), ...] = [0.00, 1.00, 0.00, 1.00, ...]
Posizione 1: [sin(1), cos(1), sin(0.01), cos(0.01)] = [0.84, 0.54, 0.01, 1.00, ...]
Posizione 2: [sin(2), cos(2), sin(0.02), cos(0.02)] = [0.91, -0.42, 0.02, 1.00, ...]
L'embedding finale di ogni token e:
token_embedding = word_embedding + positional_encoding
Le frequenze più basse (dimensioni alte) catturano posizioni globali.
Le frequenze più alte (dimensioni basse) catturano posizioni locali.
7.2 学習された位置エンコーディング
位置正弦波エンコーディングおよび 学習された埋め込み (学んだ): トレーニング可能なパラメータの配列。位置ごとに 1 行。 このアプローチは BERT と GPT-2 で使用されます。利点は、モデルが学習できることです。 特定のタスクに最適な位置パターン。デメリットは長さです シーケンスの最大値であり、トレーニング時に固定されます。
位置エンコーディングの比較
| タイプ | 利点 | 短所 | で使用されます |
|---|---|---|---|
| 正弦波 | 追加のパラメータはなく、より長いシーケンスに一般化されます | 固定パターン、タスクに最適化されていない | オリジナルトランス |
| 学んだ | 特定のタスク向けに最適化 | 固定最大長、複数のパラメータ | バート、GPT-2 |
| RoPE(ロータリー) | 相対位置をキャプチャ、拡張可能 | 実装の複雑さの増大 | ラマ、ミストラル、GPT-NeoX |
| アリバイ | パラメータなし、適切な外挿 | 線形バイアスが制限となる可能性がある | ブルーム議員 |
8. 完全なトランスアーキテクチャ
すべてのパズルのピースを手に入れたら、トランスフォーマーのアーキテクチャを組み立てることができます。 完了しました。オリジナルのトランスフォーマーは次のもので構成されています。 スタックエンコーダ そして スタックデコーダ, each made up of N identical layers (N = 6 in the paper オリジナル)。
INPUT EMBEDDING + POSITIONAL ENCODING
|
+---------v-----------+
| ENCODER STACK | x N (6 nel paper originale)
| |
| +--Multi-Head-------+
| | Self-Attention |
| +------|------------+
| v
| +--Add & Norm-------+ (residual connection + layer norm)
| +------|------------+
| v
| +--Feed-Forward-----+ (2 layer lineari con ReLU/GELU)
| | Network | (d_model -> d_ff -> d_model)
| +------|------------+ (d_ff = 4 * d_model = 2048)
| v
| +--Add & Norm-------+
| +------|------------+
+---------|-----------+
|
| (K, V per cross-attention)
|
OUTPUT EMBEDDING + POSITIONAL ENCODING
|
+---------v-----------+
| DECODER STACK | x N
| |
| +--Masked Multi-----+
| | Head Self-Attn | (causal mask: vede solo il passato)
| +------|------------+
| v
| +--Add & Norm-------+
| +------|------------+
| v
| +--Cross-Attention--+ (Q dal decoder, K/V dall'encoder)
| +------|------------+
| v
| +--Add & Norm-------+
| +------|------------+
| v
| +--Feed-Forward-----+
| +------|------------+
| v
| +--Add & Norm-------+
| +------|------------+
+---------|-----------+
|
v
Linear + Softmax
|
v
Output Probabilities (vocabulario)
8.1 残留接続
各サブレイヤー (アテンションまたはフィードフォワード) には、 残りの接続:
サブレイヤーの出力が入力に追加されます。式 e
output = LayerNorm(x + SubLayer(x))。残留接続が問題を解決する
ディープネットワークにおける勾配消失問題、勾配の流れを可能にする
ショートカット接続を通じて直接。
8.2 フィードフォワードネットワーク
アテンションの後、各トークンは フィードフォワードネットワーク 各位置に独立して適用されます。 2 つの線形変換で構成されます 非線形活性化を伴う (元の論文では ReLU、モデルでは GELU または SwiGLU) 現代):
FFN(x) = W2 * アクティベーション(W1 * x + b1) + b2
内部寸法 (d_ff) は通常、d_model の 4 倍です。 d_model = 512 の場合、 d_ff = 2048。Llama 3 のような最新のモデルでは、d_ff は d_model = 4096 で 14,336 まで上がります。
8.3 層の正規化
La レイヤーの正規化 次元に沿ってアクティベーションを正規化します (バッチではなく) 機能の。トレーニングを安定させ、収束を加速します。 オリジナルのTransformerではPost-LNが使用されています(残留接続後の正規化)。 しかし、最新のモデルのほとんどは LN以前 (正規化 サブレイヤーの前)、トレーニング中により安定します。
9. PyTorch の実装: ゼロからのセルフアテンション
理論からコードに移りましょう。スケーリングされたドット積アテンションを実装し、 事前に構築されたモジュールを使用せずに、PyTorch で最初からマルチヘッド アテンションを作成します。
9.1 スケーリングされた内積注意
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = None,
dropout: nn.Dropout = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Scaled Dot-Product Attention.
Args:
query: (batch, heads, seq_len, d_k)
key: (batch, heads, seq_len, d_k)
value: (batch, heads, seq_len, d_v)
mask: (batch, 1, 1, seq_len) o (batch, 1, seq_len, seq_len)
dropout: modulo dropout opzionale
Returns:
output: (batch, heads, seq_len, d_v)
attention_weights: (batch, heads, seq_len, seq_len)
"""
d_k = query.size(-1)
# Step 1: Calcola gli score Q * K^T / sqrt(d_k)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# Step 2: Applica la maschera (opzionale)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Step 3: Softmax per ottenere i pesi di attention
attention_weights = F.softmax(scores, dim=-1)
# Step 4: Dropout opzionale sui pesi
if dropout is not None:
attention_weights = dropout(attention_weights)
# Step 5: Moltiplica pesi per Value
output = torch.matmul(attention_weights, value)
return output, attention_weights
9.2 マルチヘッドアテンション
class MultiHeadAttention(nn.Module):
"""
Multi-Head Attention implementata da zero.
Parametri:
d_model: dimensione del modello (es. 512)
num_heads: numero di teste di attention (es. 8)
dropout: tasso di dropout (es. 0.1)
"""
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
super().__init__()
assert d_model % num_heads == 0, \
f"d_model ({d_model}) deve essere divisibile per num_heads ({num_heads})"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # dimensione per testa
# Proiezioni lineari per Q, K, V e output
self.w_q = nn.Linear(d_model, d_model, bias=False)
self.w_k = nn.Linear(d_model, d_model, bias=False)
self.w_v = nn.Linear(d_model, d_model, bias=False)
self.w_o = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def split_heads(self, x: torch.Tensor) -> torch.Tensor:
"""
Riorganizza il tensore da (batch, seq_len, d_model)
a (batch, num_heads, seq_len, d_k).
"""
batch_size, seq_len, _ = x.size()
x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
return x.transpose(1, 2) # (batch, heads, seq_len, d_k)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = None
) -> torch.Tensor:
"""
Forward pass.
Per Self-Attention: query = key = value = X
Per Cross-Attention: query = decoder, key = value = encoder
"""
batch_size = query.size(0)
# 1. Proiezioni lineari
q = self.w_q(query) # (batch, seq_len, d_model)
k = self.w_k(key)
v = self.w_v(value)
# 2. Dividi in teste
q = self.split_heads(q) # (batch, heads, seq_len, d_k)
k = self.split_heads(k)
v = self.split_heads(v)
# 3. Scaled Dot-Product Attention
attn_output, attn_weights = scaled_dot_product_attention(
q, k, v, mask=mask, dropout=self.dropout
)
# 4. Concatena le teste
# (batch, heads, seq_len, d_k) -> (batch, seq_len, d_model)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, -1, self.d_model)
# 5. Proiezione finale
output = self.w_o(attn_output)
return output
9.3 使用例
# Configurazione
batch_size = 2
seq_len = 10
d_model = 512
num_heads = 8
# Crea il modulo
mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
# Input random (simula una sequenza di token embeddings)
x = torch.randn(batch_size, seq_len, d_model)
# Self-Attention (query = key = value)
output = mha(query=x, key=x, value=x)
print(f"Input shape: {x.shape}") # torch.Size([2, 10, 512])
print(f"Output shape: {output.shape}") # torch.Size([2, 10, 512])
# Causal mask per decoder (triangolare inferiore)
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
# Masked Self-Attention
output_masked = mha(query=x, key=x, value=x, mask=causal_mask)
print(f"Masked output shape: {output_masked.shape}")
# Cross-Attention (query dal decoder, key/value dall'encoder)
encoder_output = torch.randn(batch_size, 20, d_model) # sequenza encoder più lunga
decoder_input = torch.randn(batch_size, seq_len, d_model)
cross_attn_output = mha(
query=decoder_input,
key=encoder_output,
value=encoder_output
)
print(f"Cross-attention shape: {cross_attn_output.shape}") # [2, 10, 512]
10. 注意力の現代的バリエーション
標準的な注意の O(n^2) 二次複雑さが、 多数の最適化されたバリアント。これらのバリエーションは最新モデルの基本です 10万から100万以上のトークンまでのコンテキストを管理します。
10.1 フラッシュ アテンション (v1、v2、v3)
フラッシュアテンションTri Dao と同僚によって開発された、数学を変更しない 注目を集めていますが、その実装はハードウェア レベルで根本的に最適化されています。アイデア キーを使用して、注意スコアの完全な n x n マトリックスを具体化することを避けます。 代わりに 1 つのアプローチを使用する GPU メモリ (HBM) タイル張りの 誰が働いていますか 完全に SRAM (高速オンチップ メモリ) 内にあります。
フラッシュアテンションの進化
| バージョン | Anno | 主要なイノベーション | パフォーマンス |
|---|---|---|---|
| フラッシュアテンション1 | 2022年 | タイリング + 融合カーネル、IO 認識 | 標準と比較して 2 ~ 4 倍のスピードアップ |
| フラッシュアテンション2 | 2023年 | 並列処理が向上し、通信量が減少 | v1 の 2 倍 |
| フラッシュアテンション3 | 2024年 | Hopper GPU、FP8、ワープの特殊化での非同期 | H100 で最大 740 TFLOPS (FP16)、FP8 で 1.2 PFLOPS |
Flash Attendant 3 は、NVIDIA Hopper GPU (H100/H200) の特定の特性を利用しています。 非同期 Tensor Core と TMA (Tensor Memory Accelerator) の間でオーバーレイする 計算とデータ転送、 ワープ特化 インターリーブ用 matmul 演算とソフトマックス演算の最適化、e FP8 ブロック量子化 数値誤差は単純な FP8 実装よりも 2.6 倍低くなります。フラッシュ アテンションは、PyTorch、Hugging Face Transformers、vLLM、および TensorRT-LLM に統合されました。
10.2 マルチクエリアテンション (MQA)
2019 年に Shazeer によって提案された、 マルチクエリアテンション 大幅に減少する 推論中に KV キャッシュに必要なメモリ。別個のセットを用意する代わりに 各ヘッドのキーと値の MQA 共有 シングル すべての中で K と V のセット ヘッド、さまざまなクエリを維持します。
Multi-Head Attention (MHA) - Standard:
Head 1: Q1, K1, V1 | KV Cache per head: d_k * seq_len * 2
Head 2: Q2, K2, V2 | KV Cache totale: h * d_k * seq_len * 2
... | Con h=32, d_k=128, seq=4096:
Head h: Qh, Kh, Vh | = 32 * 128 * 4096 * 2 = 33.5 MB per layer
Multi-Query Attention (MQA):
Head 1: Q1 \
Head 2: Q2 |--- K_shared, V_shared
... | KV Cache totale: d_k * seq_len * 2
Head h: Qh / = 128 * 4096 * 2 = 1.05 MB per layer (32x meno!)
10.3 グループ化されたクエリ アテンション (GQA)
GQA、エインズリーらによって導入されました。 2023 年に、MHA と MQA の間の妥協案。 すべてのヘッド (MQA) 間で 1 つの K/V セットを共有したり、各ヘッドに 1 つの K/V セットを共有したりする代わりに、 ヘッド (MHA)、GQA グループがヘッド gグループ、各グループと 一連の K/V を共有します。 g = 1 の場合は MQA が得られ、g = h の場合は MHA が得られます。
Esempio: 8 query heads, 2 KV groups (g=2)
Gruppo 1: Q1, Q2, Q3, Q4 condividono K1, V1
Gruppo 2: Q5, Q6, Q7, Q8 condividono K2, V2
KV Cache: g * d_k * seq_len * 2 = 2 * 128 * 4096 * 2 = 2.1 MB
(16x meno di MHA, ma solo 2x più di MQA)
Modelli che usano GQA:
- Llama 2 (70B): 8 KV heads, 64 query heads
- Llama 3: GQA con rapporto 8:1
- Mistral 7B: 8 KV heads, 32 query heads
アテンションのバリエーションの比較
| 変異体 | KVヘッド | KVキャッシュメモリ | 品質 | モデル |
|---|---|---|---|---|
| MHA | h (すべて) | 最大 | 改善する | BERT、GPT-2、GPT-3 |
| GQA | g (グループ) | 時間/日の削減 | MHAとほぼ同等 | ラマ 2/3、ミストラル |
| MQA | 1 | 最小限 | 若干の減少 | パルコン、ファルコン |
10.4 スライディングウィンドウの注意
La スライディングウィンドウの注意ミストラルとロングフォーマーで使用される、制限 各位置の w トークンのローカル ウィンドウに注目してください。計算する代わりに シーケンス全体 (O(n^2)) に注目すると、各トークンは前の w トークンのみを参照します。 複雑さを O(n * w) に軽減します。
Sequenza: t1 t2 t3 t4 t5 t6 t7 t8
Attention di t5 (window=3): vede solo [t3, t4, t5]
Attention di t8 (window=3): vede solo [t6, t7, t8]
Attention Matrix (1 = visibile, 0 = mascherato):
t1 t2 t3 t4 t5 t6 t7 t8
t1 [ 1 0 0 0 0 0 0 0 ]
t2 [ 1 1 0 0 0 0 0 0 ]
t3 [ 1 1 1 0 0 0 0 0 ]
t4 [ 0 1 1 1 0 0 0 0 ]
t5 [ 0 0 1 1 1 0 0 0 ]
t6 [ 0 0 0 1 1 1 0 0 ]
t7 [ 0 0 0 0 1 1 1 0 ]
t8 [ 0 0 0 0 0 1 1 1 ]
L'informazione NON si perde: attraverso più layer stacked,
l'informazione di t1 può raggiungere t8 per propagazione.
Con L layer e window w, la reception field effettiva e L * w.
10.5 リングアテンションとページドアテンション
非常に長いコンテキスト (100 万トークン以上) では、さらなる革新が現れています。
- 呼び出し音の注意: アテンションの計算を複数の GPU に分散します リング状に組織されます。各 GPU はシーケンスのセグメントに対するアテンションを計算します。 そして結果を次の GPU に渡します。 RingX (2025) は 94% の効率を達成 最大 4096 個の GPU、100 万個のトークン シーケンス。
- ページ付き注意: 仮想メモリ管理からインスピレーションを得た オペレーティング システムでは、KV キャッシュを不連続なブロック (ページ) に割り当て、 記憶の断片化。これは vLLM の基礎であり、最大のバッチ サイズを許可します。 76倍も高い。
- フレックスアテンション (PyTorch): いくつかの機能をサポートする統合 API 未満のアテンションのバリアント (GQA、コーザル、スライディング ウィンドウ、PagesAttention) 専用実装と比較して 5% のオーバーヘッド。
11. アプリケーション: 変圧器アーキテクチャの実践
Transformer アーキテクチャにより、3 つの主要なモデル ファミリが誕生しました。 注意の使い方が異なります。
11.1 エンコーダのみ: BERT と派生関数
エンコーダ専用モデルでは、 双方向の自己注意: 各トークン シーケンス内の他のすべてのトークン (以前のトークンとそれらのトークンの両方) を確認できます。 後続のもの。そのため、言語理解のタスクに最適です。
BERT (トランスフォーマーからの双方向エンコーダー表現)
- 事前トレーニング: マスクされた言語モデル (MLM) + 次の文の予測
- 注意: 双方向のセルフアテンション (シーケンス全体を見る)
- タスク: 分類、固有表現認識、質問応答
- バリエーション: ロベルタ、アルバート、デベルタ、ディスティルバート
11.2 デコーダのみ: GPT と LLM ファミリ
デコーダ専用モデルでは、 仮面をかぶった自己注意(因果関係): 各トークン 以前のトークンのみが表示されます。これらは自己回帰テキスト生成用に最適化されています。
デコーダ専用モデル
| モデル | パラメータ | 注意のバリエーション | コンテキストウィンドウ |
|---|---|---|---|
| GPT-3 | 175B | 標準MHA | 2K~4Kトークン |
| GPT-4 | ~1.8T (MoE) | GQA(推定) | 128,000トークン |
| ラマ 3 405B | 405B | GQA + RoPE | 128,000トークン |
| ミストラル 7B | 7.3B | GQA + スライディング ウィンドウ | 32,000 トークン |
| クロード (人族) | 非公開 | 非公開 | 200,000トークン |
11.3 エンコーダ/デコーダ: T5 および Seq2Seq モデル
エンコーダ/デコーダ モデルは、次の 3 種類の注意をすべて使用します。 自意識 エンコーダの双方向, デコーダでのマスクされた自己注意 e 交差注意 デコーダとエンコーダの間。次のようなタスクに最適です。 入力を出力 (翻訳、要約、質問への回答) に変換します。
エンコーダ/デコーダ モデル
- T5: 「テキストからテキストへの転送トランスフォーマー」 - 各タスクはテキストインテキストアウトとして定式化されます。
- バート: 生成と理解のためのノイズ除去オートエンコーダー
- mBART: BART多言語翻訳対応
- フラン-T5: T5 は命令チューニングで指示されています
11.4 ビジョントランスフォーマー (ViT)
注意は文章だけにとどまりません。ザ ビジョントランスフォーマー を適用します 画像への自己注意、画像をパッチ (例: 16x16 ピクセル) に分割します。 各パッチを「トークン」として扱います。これは、注意力が あらゆる種類のシーケンシャル データに適用できる一般的なメカニズム。
Immagine 224x224 pixel
|
v
Dividi in patch 16x16: (224/16)^2 = 196 patch
|
v
Ogni patch -> flatten -> proiezione lineare -> patch embedding
|
v
[CLS] + 196 patch embeddings + positional encoding
|
v
Transformer Encoder (self-attention su 197 token)
|
v
[CLS] token -> classificazione dell'immagine
結論と次のステップ
この記事では、問題から注意のメカニズム全体を取り上げました。 RNN の長期依存関係、クエリキー値の直観、公式 スケーリングされたドット積の注意からマルチヘッドの注意、アーキテクチャに至るまで トランスフォーマー完成。 PyTorch e でセルフアテンションを一から実装しました 数百万のトークンを使用したモデルを可能にする最新のバリエーションを調査しました 文脈の。
注意は、最新のディープラーニングすべてが構築される基本的な要素です。 どのように機能するかを理解すると、一部の最適化が機能する理由を理解できるようになります。 特定のモデルが他のモデルよりも高速である理由と、適切なアーキテクチャを選択する方法 あなたのユースケースに合わせて。
覚えておくべき重要な概念
- 注意 ボトルネックを発生させずに、任意のトークンのペア間の直接接続を可能にします。
- スケーリング (sqrt(d_k)) ソフトマックスでの不安定な勾配を防止します
- マルチヘッド 追加コストなしでさまざまな関係を並行してキャプチャします
- 自己注意 文脈に応じた表現を作成します。 クロスアテンション エンコーダとデコーダを接続する
- 位置エンコーディング 次数情報を提供します (正弦波、学習済み、RoPE)
- フラッシュアテンション 数学を変更せずにハードウェア実装を最適化する
- GQA 品質 (MHA) と効率 (MQA) の間の最適な妥協点
Nel 次の記事 シリーズの中で、 微調整 LoRA、QLoRA、アダプターを備えた変圧器の: 事前トレーニングされたモデルを当てはめる方法 パラメータのほんの一部を変更するだけで特定のタスクに適用できるため、 GPU とメモリのコストは大幅に増加します。
追加リソース
- 原紙: 「必要なのは注意だけです」 (Vaswani 他、2017)
- フラッシュ注意 3: 「非同期性と低精度による高速かつ正確な注意」 (Dao et al.、2024)
- GQA 文書: 「GQA: 一般化されたマルチクエリ トランスフォーマー モデルのトレーニング」 (Ainslie et al.、2023)
- 図解されたトランス: ジェイ・アランマーによるビジュアルガイド
- PyTorch ドキュメント: 最適化された実装のための torch.nn.Multiheadtention
- 抱き合う顔: 実際の例を含むトランスフォーマーのドキュメント







