コンテンツにスキップ

バックプロパゲーション実装アーキテクチャ概要

[!NOTE] 最新の実装状況は 機能実装ステータス (Remaining Functionality) を参照してください。

対象読者: 勾配学習とSNNの検証を担当するエンジニア。参考になるコード: backpropagation_verification.py, surrogate.py

本ドキュメントは、EvoSpikeNetにおけるバックプロパゲーション基盤の実装を、コンセプト、モジュール構成、データフロー、シーケンスの観点から詳細に説明します。すべての図に読み方と意図を添え、設計を追いやすくしています。

1. コンセプト

  • スパイクの非微分性を回避: スパイク関数の代わりに滑らかな勾配(サロゲート)を逆伝播専用で用い、PyTorchのautogradが勾配を流せるようにする。
  • 勾配の正当性と安定性を検証: 有限差分による数値勾配との比較、サロゲート勾配の近似誤差測定、勾配/重みの数値安定性確認を自動化。
  • 収束挙動とベンチマーク: 収束率、損失推移、勾配ノルムを追跡し、SNNとANNで学習効率を比較する。

2. モジュール構成

  • FastSigmoid autograd関数(surrogate.py): 前向きはステップ関数、後ろ向きは高速シグモイド由来の滑らかな勾配を返す。
  • SurrogateGradientbackpropagation_verification.py): fast sigmoid, triangular, rectangular, exponential, super spike など複数のサロゲート勾配を提供。
  • GradientVerifier(同 backpropagation_verification.py): 有限差分とサロゲート誤差を計測し、最大誤差・平均誤差・相対誤差をレポート。
  • NumericalStabilityTester(同 backpropagation_verification.py): 勾配ノルム・NaN/Inf・条件数を反復測定し、安定性を判定。重み更新を含むパスも検証。
  • ConvergenceAnalyzer(同 backpropagation_verification.py): 反復学習し、損失と勾配ノルムの履歴から収束率を算出、忍耐値で早期終了を判定。
  • ComparativeBenchmark(同 backpropagation_verification.py): SNNとANNを同条件で訓練・評価し、損失/精度/時間を比較。
  • BackpropagationVerificationSuite(同 backpropagation_verification.py): 上記コンポーネントをまとめ、フル検証・レポート生成を行う。

3. アーキテクチャ概要図

graph LR
    subgraph SNN学習基盤
        A[Surrogate<br/>勾配関数群]
        B[GradientVerifier<br/>有限差分/サロゲート]
        C[NumericalStabilityTester<br/>勾配・重み安定性]
        D[ConvergenceAnalyzer<br/>収束解析]
        E[ComparativeBenchmark<br/>SNN vs ANN]
        F[VerificationSuite<br/>統合オーケストレーション]
    end
    A --> B
    A --> F
    B --> F
    C --> F
    D --> F
    E --> F
図の説明: Surrogate勾配(A)が勾配検証(B)と統合スイート(F)で利用される。安定性(C)・収束(D)・比較ベンチ(E)の結果がスイートに集約され、単一の検証フローで確認できる。

4. データフロー(勾配検証パス)

flowchart LR
    In[サンプル入力<br/>サンプル教師] --> M[対象モデル]
    M --> L[損失計算]
    L -->|backward| Gv[勾配取得<br/>autograd]
    L -->|有限差分| Fd[数値勾配]
    Gv --> Err[誤差計測<br/>最大/平均/相対]
    Fd --> Err
    Err --> Rep[検証結果<br/>GradientVerificationResult]
図の説明: 入力と教師ラベルを用いてモデル出力と損失を計算。autogradで得た勾配と有限差分で得た勾配を比較し、誤差指標を算出して結果オブジェクトにまとめる。これがGradientVerifier.verify_finite_differenceの内部処理を可視化している。

5. データフロー(サロゲート検証パス)

flowchart LR
    X[入力サンプル<br/>線形空間] --> SFn[スパイク関数<br/>Heaviside等]
    X --> SG[サロゲート勾配関数]
    SFn --> FD[有限差分<br/>スパイク差分]
    SG --> ErrS[誤差計測<br/>サロゲート vs 有限差分]
    FD --> ErrS
    ErrS --> RepS[検証結果<br/>SURROGATE]
図の説明: スパイク関数の出力差分(有限差分)と、サロゲート勾配関数の出力を比較し、近似誤差を評価する。verify_surrogate_gradientのステップを示す。

6. シーケンス図(フル検証スイート)

sequenceDiagram
    participant User as ユーザー
    participant Suite as VerificationSuite
    participant GV as GradientVerifier
    participant NST as StabilityTester
    participant CA as ConvergenceAnalyzer
    participant CB as ComparativeBenchmark

    User->>Suite: run_full_verification(model, loaders, loss, opt)
    Suite->>GV: verify_finite_difference(...)
    GV-->>Suite: gradient_verification
    Suite->>NST: check_gradient_stability(...)
    NST-->>Suite: gradient_stability
    Suite->>NST: check_weight_stability(...)
    NST-->>Suite: weight_stability
    Suite->>CA: analyze_convergence(...)
    CA-->>Suite: convergence
    Suite->>CB: benchmark_training(...)
    CB-->>Suite: snn/ann metrics
    Suite-->>User: results + report
図の説明: run_full_verification呼び出しから各検証モジュールの実行順序と戻り値を示す。最後に統合結果とレポート文字列がユーザーへ返る。

7. 主要API概要

  • spike_function = FastSigmoid.apply: スパイク活性化の代替(前向きはステップ、後向きは滑らか勾配)。
  • GradientVerifier.verify_finite_difference(model, inputs, targets, loss_fn): autograd勾配と有限差分の最大/平均/相対誤差を算出。
  • GradientVerifier.verify_surrogate_gradient(spike_fn, surrogate_fn, ...): サロゲート勾配と有限差分を比較し、近似品質を評価。
  • NumericalStabilityTester.check_gradient_stability(...): 勾配ノルム分布、NaN/Inf、条件数から安定性判定。
  • NumericalStabilityTester.check_weight_stability(...): 反復学習中の重みノルム推移と条件数を検査。
  • ConvergenceAnalyzer.analyze_convergence(...): 損失と勾配ノルム履歴、収束率、忍耐による早期停止を返す。
  • ComparativeBenchmark.benchmark_training(...): SNN/ANNの損失・精度・訓練時間を並列比較。
  • BackpropagationVerificationSuite.run_full_verification(...): 上記を一括実行して辞書形式の結果を返す。
  • BackpropagationVerificationSuite.generate_report(results): 人可読な検証レポート文字列を生成。

8. トレーサビリティ

  • サロゲート勾配実装: surrogate.py
  • 勾配検証・安定性・収束・ベンチ: backpropagation_verification.py
  • テストスイート: test_backprop_verification.py
  • PoC比較: proof-of-concept/poc5_backprop_verification