Paper Reading/Review

[리뷰] KalmanNet: Neural Network Aided Kalman Filtering for Partially Known Dynamics

  • -
728x90
반응형

이번에는 IEEE Transactions on Signal Processing 2022에 발표된 논문인 KalmanNet: Neural Network Aided Kalman Filtering for Partially Known Dynamics를 읽고, 리뷰해보고자 합니다.

Index
1. Background
    1.1. State Space Model
    1.2. Data-Aided Filtering Problem Formulation
    1.3. Extended Kalman Filter
    1.4. Recurrent Neural Network
    1.5. Back Propagation Through Time
    1.6. Truncated BPTT
    1.7. Gated Recurrent Unit
2. Abstract
3. Introduction
4. Related Work
5. Method
    5.1. High Level Architecture
    5.2. Input Features
    5.3. Neural Network Architecture
    5.4. Training Algorithm
6. Experiment
7. Conclusion 

1. Background

1.1. State Space Model

  • observation 근사치와 숨겨진 state 복구의 영역에서 연구되며, 이산 시간에서 동적 시스템을 고려
  • \( t \in \mathbb{Z} \)에 대해 아래와 같이 표현
    • state evolution model \( x_{t} = f(x_{t-1})+w_{t} \)
      • \( w_{t} \sim \mathcal{N}(0,Q) \), \( x_{t} \in \mathbb{R}^{m} \)
      • system의 역학에 의해 결정
      • object의 위치, 속도, 가속도 등을 결정할 수 있음
    • observation model \( y_{t} = h(x_{t})+v_{t} \)
      • \( v_{t} \sim \mathcal{N}(0,R) \), \( y_{t} \in \mathbb{R}^{n} \) 
      • 센서의 측정값으로, 관찰의 유형과 품질에 의해 결정
      • \( \mathbb{Z} \)는 정수, \( \mathbb{R} \)은 실수, \( x_{t} \)는 \( t \)시점에서의 state vector, \( f(\cdot) \)는 state evolution 함수, \( w_{t} \)는 additive white Gaussian noise, \( Q \)는 covariance, \( y_{t} \)는 \( t \)시점에서의 observation vector, \( h(\cdot) \)는 additive white Gaussian noise \( v_{t} \)와 covariance \( R \)에 의해 손상된 non-linear observation mapping
  • evolution이나 observation이 linear transformation인 경우 아래와 같은 행렬 \( F \), \( H \) 존재
    • \( f(x_{t-1})=F \cdot x_{t-1} \), \( h(x_{t})=H \cdot x_{t} \)
  • model의 매개변수는 알 수 없으며, 실시간으로 추정하기 위한 전용 메커니즘 도입 필요

1.2. Data-Aided Filtering Problem Formulation

  • filtering 문제는 실시간 tracking의 핵심으로, \( y_{t} \)에 기초한 \( x_{t} \)를 실시간으로 제공해야 함
  • filtering에 필요한 모든 데이터를 알지 못함
    • \( f(\cdot) \), \( h(\cdot) \)에 대해 정확하거나 근사치를 알고 있는 상태
    • noise \( w_{t} \), \( v_{t} \), \( Q \), \( R \)은 모르는 상태
  • online으로 학습에 필요한 GT data를 수집하거나, 알고리즘을 통해 GT를 계산할 수 있음

1.3. Extended Kalman Filter

 

Extended Kalman Filter

0. 들어가기에 앞서 본 게시글은 Extended Kalman Filter에 대해 쉽게 이해할 수 있도록 최대한 간략하게 작성한 글입니다. 더욱 자세한 내용을 알고싶다면, 아래의 참고 링크 부분의 링크를 참고 바랍

alstn59v.tistory.com

1.4. ~ 1.6. RNN, BPTT, Truncated BPTT

 

Recurrent Neural Network

1. 개념 순차 데이터를 처리하는데 적합한 신경망 machine translation, DNA analysis, voice recognition, motion recognition, sentiment analysis 등에 이용 hidden layer의 neuron에서 출력된 값이 다시 그 neuron의 입력으로 사

alstn59v.tistory.com

1.7. Gated Recurrent Unit

 

Gated Recurrent Unit

0. 들어가기에 앞서 본 게시글은 다양한 GRU에 대해 쉽게 이해할 수 있도록 최대한 간략하게 작성한 글입니다. 더욱 자세한 내용을 알고싶다면, 아래의 참고 링크 부분의 링크를 참고 바랍니다. 1.

alstn59v.tistory.com

 

2. Abstract

  • linear gaussian state space의 경우 kalman filter가 가장 복잡도가 낮은 방법
    • 그러나 real world는 전체 정보가 아닌 일부 정보와, non-linear motion이 주로 등장
  • 학습을 기반으로 실시간 state를 추정하는 KalmanNet 제시
    • data로부터 복잡한 motion을 학습
    • kalman filter의 data efficiency와 interpretability를 유지

 

3. Introduction

  • kalman filter의 낮은 복잡성 때문에, 다양한 목적으로 상태 추정을 위해 사용됨
  • 현실 세계는 주로 non-linear state space 이어서 성능 저하 발생
    • 이를 handling하기 위해 extended kalman filter, unscented kalman filter가 등장하였으며, monte-carlo sampling을 이용하는 방식인 다양한 particle filter 등장
  • 위의 모든 filter들은 model-based 방식이기 때문에, 성능이 역학이나 domain에 대한 정확한 지식과 model의 가정에 의존
    • 어느 정도의 불확실성을 대처하도록 설계되었지만, 완전한 domain 지식으로도 성능을 달성하기 힘듬
    • 가정과 실제가 얼마나 차이나는지에 대한 어느 정도의 지식이 필요
  • 명시적이고 정확한 지식을 요구하지 않는 DNN의 성공에 영감을 받음
    • 다루기 힘든 환경에서 time-series 작업을 잘 수행
    • 그러나, 연산 과정에 대한 interpretability가 부족하며, 많은 양의 데이터와 매개변수가 필요
  • 따라서 딥러닝과 고전적인 방식의 낮은 복잡성을 모두 활용하는 하이브리드 접근 방식 제안
    • noise에 대한 정보를 모르고, model에 대해 부분적으로 알려지거나 근사된다고 가정
    • kalman gain에 대한 계산이 noise와 domain 정보에 대한 의존성을 가진 주요한 부분으로 식별하여 소형 RNN으로 대체
    • supervised learning 방식 이용

 

4. Related Work

  • 새 관측 정보를 사용하여, 이전의 추정치를 업데이트
    • 간단하고 noise가 gaussian 분포를 따르는 경우 : kalman filter, extended kalman filter, unscented kalman filter, …, etc
    • 복잡하고 noise가 gaussian 분포를 따르지 않는 경우 : particle filter
  • kalman filter와 머신러닝의 조합
    • system의 역학이 복잡한 경우 DNN을 통해 학습하여 사용할 수 있지만, 실시간 filtering이 불가능
    • 누락된 매개변수를 사전에 조정하는데 machine learning을 활용하지만, 기본 역학에 제시된 매개변수로 제한되는 문제가 있음

 

5. Method

5.1. High Level Architecture

  • kalman filter의 공분산 행렬 \( Q \)와 \( R \)은 알려져있지 않기 때문에, 해당 행렬이 사용되는 부분인 kalman gain에 대해 학습
    • 학습된 값을 kalman filter에서 이용
  • extended kalman filter처럼 predict와 update의 두 단계를 가짐
    • predict 단계
      • first-order statistical moments만 예측한다는 점을 제외하고 EKF와 동일
      • noise의 정보에 의존하지 않으며, second-order statistical moments의 명시적 추정치를 유지하지 않음
    • update 단계
      • kalman gain의 계산이 명시적으로 제공되지 않고 RNN을 이용해 학습
      • second-order statistical moments를 암시적으로 추적

5.2. Input Features

  • 매 \( t \)순간 kalman gain을 계산하기 위해 \( z_{t} \)과 \( \hat{x}^{-}_{t-1} \)의 통계 정보를 포함하는 입력이 제공되어야 함
  • 따라서, 알려지지 않은 통계와 관련된 아래 값들을 RNN의 input feature로 사용 가능
    • state evolution process의 정보 캡슐화
      • observation difference : \( \Delta \tilde{z} _{t} = z_{t}-z_{t-1} \)
      • forward evolution difference : \( \Delta \tilde{x} _{t} = \hat{x}^{-}_{t|t}-\hat{x}^{-}_{t-1|t-1} \)
        • 연속적인 두 시점의 사후 상태 추정의 차이값으로, \( t \)시점에서 사용 가능한 feature는 \( \Delta \tilde{x} _{t-1} \)
    • state estimate의 불확실성 캡슐화
      • innovation difference : \( \Delta z_{t} = z_{t}-\hat{z}_{t|t-1} \)
      • forward update difference : \( \Delta \hat{x}_{t} = \hat{x}^{-}_{t|t}-\hat{x}^{-}_{t|t-1} \)
        • 사후 상태 추정과 이전 상태 추정의 차이값으로, \( t \)시점에서 사용 가능한 feature는 \( \Delta \hat{x} _{t-1} \)
    • difference 연산은 예측 가능한 요소를 제거하고, difference의 time-series는 대부분 noise의 통계에 영향을 받음
      • 경험적 관찰에 따르면 좋은 조합은 \( \{ \Delta \tilde{z} _{t} , \Delta z_{t}, \Delta \hat{x}_{t}\} \)와 \( \{ \Delta \tilde{z} _{t} , \Delta \tilde{x} _{t}, \Delta \hat{x}_{t}\} \)

5.3. Neural Network Architecture

  • KalmanNet의 DNN이 input features를 이용하여 kalman gain 계산
    • 재귀적 특성 때문에 RNN을 사용하여 memory의 요소를 포함해야 함
  • kalman gain 계산을 위해 2가지 architecture를 고려
    • RNN의 내부 메모리를 사용하여 kalman gain 계산에 필요한 statistical moments인 \( P^{-}_{k} \)와 \(  HP^{-}_{k}H^{T}+R \)를 암시적으로 추적하기 위해 GRU cell을 이용하는 구조
      • 크기가 \( m^{2}+n^{2} \)인 hidden state를 가지며, input layer로 FC 사용
      • GRU의 state vector \( \bf{h} \)는 \( K_{k} \in \mathbb{R}^{m \times n} \)에 mapping
      • 많은 수의 state variables를 사용하게 되고, overparameterization을 초래
    • kalman gain 계산에 필요한 statistical moments 각각에 대해 별도의 GRU cell을 이용하는 구조
      • 첫 번째 GRU layer는 \( Q \)를 추적하여 \( m^{2} \) 변수를 추적
      • 두 번째와 세 번째 GRU layer는 \( P^{-}_{k} \)와 \(  HP^{-}_{k}H^{T}+R \)를 각각 추적하여 \( m^{2} \)와 \( n^{2} \) 크기의 hidden state 변수를 가짐
      • GRU는 상호 연결되어 kalman gain을 구하는데 관여
      • 이전에 고려한 구조보다 덜 추상적이기 때문에, 학습해야할 매개변수가 \( 5 \times 10^{5} \)개 에서 \( 2.5 \times 10^{4} \)개로 줄어듬

5.4. Training Algorithm

  • supvervised learning 이용
  • \( \hat{x}_{t|t} \)를 직접 생성하는 대신, end-to-end 학습 수행을 통해 kalman gain 계산
    • 즉, RNN의 output이 아닌 \( \hat{x}_{t} \)를 기반으로 연속 집합 \( \mathbb{R}^{m} \)을 취하는 loss \( \mathcal{L} \) 계산
      • \( \mathcal{L}=\Vert x_{t}-\hat{x}_{t|t}\Vert^{2} \)
      • \( \frac{\delta \mathcal{L}}{\delta K_{t}}=\frac{\delta \Vert x_{t}-\hat{x}_{t|t}\Vert^{2}}{\delta K_{t}}=2 \cdot (K_{t}\cdot \Delta z_{t}-\Delta x_{t})\cdot \Delta z_{t}^{T} \text{ where } \Delta x_{t} \triangleq x_{t}-\hat{x}_{t|t-1} \)
    • kalman gain의 계산에 대한 loss를 backpropagate 수행
  • 학습에는 다양한 길이를 가진 \( N \)개의 trajectory가 dataset으로 사용 됨
    • 즉, \( T_{i} \)를 \( i^{th} \) trajectory의 길이로 하는 \( \mathcal{D} = \{ (Z_{i}, X_{i}) \}^{N}_{1} \)
      • \( Z_{i}=[z^{(i)}_{1}, \dots ,z^{(i)}_{T_{i}}] \), \( X_{i}=[x^{(i)}_{0}, x^{(i)}_{1},\dots ,x^{(i)}_{T_{i}}] \)
  • \( \Theta \)가 RNN의 훈련 가능한 매개변수이고, \( \gamma \)를 정규화 상관계수라 할 때, 정규화된 L2 MSE loss를 측정할 수 있음
    • \( \ell_{i}=\frac{1}{T_{i}}\sum^{T_{i}}_{t=1}{\Vert\hat{x}_{t}(z^{(i)}_{t};\Theta)-x^{(i)}_{t}\Vert^{2}+\gamma \cdot \Vert \Theta \Vert^{2}} \)
  • \( \Theta \)를 최적화 하기 위해, 모든 \( k \)로 index된 mini-batch에 대해 \( M<N \)인 trajectory \( i^{k}_{1},\dots,i^{k}_{M} \)를 이용하여 mini-batch loss \( \mathcal{L}_{k} \) 계산
    • \( \mathcal{L}_{k}=\frac{1}{M}\sum^{M}_{j=1}{{\ell_{i}}^{k}_{j}(\Theta)} \)
  • 외부의 재귀와 내부의 RNN을 모두 가진 구조이기 때문에, 학습에 BPTT 알고리즘 이용
    • shared network 매개변수를 사용하여 시간을 기준으로 unfold한 다음, forward와 backward gradient를 계산
  • 학습을 위해 BPTT 알고리즘의 다양한 변형을 고려
    • 방법 1) BPTT를 그대로 적용하여, 각 iteration에 대해 전체 trajectory에 걸쳐 gradient 계산
    • 방법 2) truncated BPTT를 적용하여, long trajectory dataset이 주어질 때 이 trajectory를 tracklet으로 분리한 뒤 shuffle하여 학습에 이용
    • 방법 3) truncated BPTT 대신 각 trajectory를 상대적으로 짧은 tracklet으로 잘라내어 학습에 이용
    • 방법 1은 계산 cost가 높고 불안정할 수 있기 때문에, 방법 2를 통해 어느정도 훈련한 후, 방법 1을 통해 튜닝
    • 방법 3은 linear state space에서 가장 적합하다고 알려짐

 

6. Experiment

  • KalmanNet의 학습 곡선과 MMSE의 달성

  • non-linear state space model에서 성능 비교

  • noisy state에서 성능 비교

  • Lorenz Attractor에서의 prediction 시각화 비교

  • NCLT dataset에서의 성능 비교

 

7. Conclusion 

  • 고전적인 방법의 MB EKF와 딥 러닝의 조합인 KalmanNet 제시
    • state space-model-dependent를 적은 양의 dataset으로 훈련할 수 있는 RNN으로 대체
    • model 불일치와 non-linearity를 극복하는 방법에 대해 학습이 가능함을 보여줌
    • tracklet을 사용하여 훈련하는 동안 임의의 trajectory를 사용할 수 있도록 하는 방법 제시
    • 복잡성을 줄여 계산 성능이 제한된 장비에 적용할 수 있음

 

 

논문 링크

https://arxiv.org/abs/2107.10043

https://github.com/KalmanNet/KalmanNet_TSP

728x90
반응형
Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.