バックプロパゲーション実装アーキテクチャ概要
[!NOTE] 最新の実装状況は 機能実装ステータス (Remaining Functionality) を参照してください。
対象読者: 勾配学習とSNNの検証を担当するエンジニア。参考になるコード:
backpropagation_verification.py,surrogate.py。
本ドキュメントは、EvoSpikeNetにおけるバックプロパゲーション基盤の実装を、コンセプト、モジュール構成、データフロー、シーケンスの観点から詳細に説明します。すべての図に読み方と意図を添え、設計を追いやすくしています。
1. コンセプト
- スパイクの非微分性を回避: スパイク関数の代わりに滑らかな勾配(サロゲート)を逆伝播専用で用い、PyTorchのautogradが勾配を流せるようにする。
- 勾配の正当性と安定性を検証: 有限差分による数値勾配との比較、サロゲート勾配の近似誤差測定、勾配/重みの数値安定性確認を自動化。
- 収束挙動とベンチマーク: 収束率、損失推移、勾配ノルムを追跡し、SNNとANNで学習効率を比較する。
2. モジュール構成
FastSigmoidautograd関数(surrogate.py): 前向きはステップ関数、後ろ向きは高速シグモイド由来の滑らかな勾配を返す。SurrogateGradient(backpropagation_verification.py): fast sigmoid, triangular, rectangular, exponential, super spike など複数のサロゲート勾配を提供。GradientVerifier(同backpropagation_verification.py): 有限差分とサロゲート誤差を計測し、最大誤差・平均誤差・相対誤差をレポート。NumericalStabilityTester(同backpropagation_verification.py): 勾配ノルム・NaN/Inf・条件数を反復測定し、安定性を判定。重み更新を含むパスも検証。ConvergenceAnalyzer(同backpropagation_verification.py): 反復学習し、損失と勾配ノルムの履歴から収束率を算出、忍耐値で早期終了を判定。ComparativeBenchmark(同backpropagation_verification.py): SNNとANNを同条件で訓練・評価し、損失/精度/時間を比較。BackpropagationVerificationSuite(同backpropagation_verification.py): 上記コンポーネントをまとめ、フル検証・レポート生成を行う。
3. アーキテクチャ概要図
graph LR
subgraph SNN学習基盤
A[Surrogate<br/>勾配関数群]
B[GradientVerifier<br/>有限差分/サロゲート]
C[NumericalStabilityTester<br/>勾配・重み安定性]
D[ConvergenceAnalyzer<br/>収束解析]
E[ComparativeBenchmark<br/>SNN vs ANN]
F[VerificationSuite<br/>統合オーケストレーション]
end
A --> B
A --> F
B --> F
C --> F
D --> F
E --> F
図の説明: Surrogate勾配(A)が勾配検証(B)と統合スイート(F)で利用される。安定性(C)・収束(D)・比較ベンチ(E)の結果がスイートに集約され、単一の検証フローで確認できる。
4. データフロー(勾配検証パス)
flowchart LR
In[サンプル入力<br/>サンプル教師] --> M[対象モデル]
M --> L[損失計算]
L -->|backward| Gv[勾配取得<br/>autograd]
L -->|有限差分| Fd[数値勾配]
Gv --> Err[誤差計測<br/>最大/平均/相対]
Fd --> Err
Err --> Rep[検証結果<br/>GradientVerificationResult]
図の説明: 入力と教師ラベルを用いてモデル出力と損失を計算。autogradで得た勾配と有限差分で得た勾配を比較し、誤差指標を算出して結果オブジェクトにまとめる。これがGradientVerifier.verify_finite_differenceの内部処理を可視化している。
5. データフロー(サロゲート検証パス)
flowchart LR
X[入力サンプル<br/>線形空間] --> SFn[スパイク関数<br/>Heaviside等]
X --> SG[サロゲート勾配関数]
SFn --> FD[有限差分<br/>スパイク差分]
SG --> ErrS[誤差計測<br/>サロゲート vs 有限差分]
FD --> ErrS
ErrS --> RepS[検証結果<br/>SURROGATE]
図の説明: スパイク関数の出力差分(有限差分)と、サロゲート勾配関数の出力を比較し、近似誤差を評価する。verify_surrogate_gradientのステップを示す。
6. シーケンス図(フル検証スイート)
sequenceDiagram
participant User as ユーザー
participant Suite as VerificationSuite
participant GV as GradientVerifier
participant NST as StabilityTester
participant CA as ConvergenceAnalyzer
participant CB as ComparativeBenchmark
User->>Suite: run_full_verification(model, loaders, loss, opt)
Suite->>GV: verify_finite_difference(...)
GV-->>Suite: gradient_verification
Suite->>NST: check_gradient_stability(...)
NST-->>Suite: gradient_stability
Suite->>NST: check_weight_stability(...)
NST-->>Suite: weight_stability
Suite->>CA: analyze_convergence(...)
CA-->>Suite: convergence
Suite->>CB: benchmark_training(...)
CB-->>Suite: snn/ann metrics
Suite-->>User: results + report
図の説明: run_full_verification呼び出しから各検証モジュールの実行順序と戻り値を示す。最後に統合結果とレポート文字列がユーザーへ返る。
7. 主要API概要
spike_function = FastSigmoid.apply: スパイク活性化の代替(前向きはステップ、後向きは滑らか勾配)。GradientVerifier.verify_finite_difference(model, inputs, targets, loss_fn): autograd勾配と有限差分の最大/平均/相対誤差を算出。GradientVerifier.verify_surrogate_gradient(spike_fn, surrogate_fn, ...): サロゲート勾配と有限差分を比較し、近似品質を評価。NumericalStabilityTester.check_gradient_stability(...): 勾配ノルム分布、NaN/Inf、条件数から安定性判定。NumericalStabilityTester.check_weight_stability(...): 反復学習中の重みノルム推移と条件数を検査。ConvergenceAnalyzer.analyze_convergence(...): 損失と勾配ノルム履歴、収束率、忍耐による早期停止を返す。ComparativeBenchmark.benchmark_training(...): SNN/ANNの損失・精度・訓練時間を並列比較。BackpropagationVerificationSuite.run_full_verification(...): 上記を一括実行して辞書形式の結果を返す。BackpropagationVerificationSuite.generate_report(results): 人可読な検証レポート文字列を生成。
8. トレーサビリティ
- サロゲート勾配実装:
surrogate.py - 勾配検証・安定性・収束・ベンチ:
backpropagation_verification.py - テストスイート:
test_backprop_verification.py - PoC比較: proof-of-concept/poc5_backprop_verification