Maxton's Blog


RL Study Notes: Value Function Approximation

Summary of value function approximation in RL, covering linear/non-linear forms, state distributions, gradient methods, DQN, and experience replay.

Value Function Approximation

Linear Function Form

$$\hat{v}(s, w) = as + b = \underbrace{[s, 1]}_{\phi^T(s)} \underbrace{\begin{bmatrix} a \\ b \end{bmatrix}}_{w} = \phi^T(s)w$$

Where:

  • $w$ is the parameter vector.
  • $\phi(s)$ is the feature vector of state $s$.
  • $\hat{v}(s, w)$ is linear with respect to $w$.

Non-linear Function Form

$$\hat{v}(s, w) = as^2 + bs + c = \underbrace{[s^2, s, 1]}_{\phi^T(s)} \underbrace{\begin{bmatrix} a \\ b \\ c \end{bmatrix}}_{w} = \phi^T(s)w$$

In this case:

  • The dimensions of $w$ and $\phi(s)$ increase, potentially making the numerical fit more accurate.
  • Although $\hat{v}(s, w)$ is non-linear with respect to the state $s$, it remains linear with respect to the parameter $w$. The non-linear features are encapsulated in the mapping $\phi(s)$.
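The two feature maps above can be sketched in a few lines of Python (a toy illustration, assuming numpy; the function and variable names here are my own, not from the original):

```python
import numpy as np

def phi_linear(s):
    # Linear features: phi(s) = [s, 1]
    return np.array([s, 1.0])

def phi_quadratic(s):
    # Quadratic features: phi(s) = [s^2, s, 1]
    return np.array([s**2, s, 1.0])

def v_hat(s, w, phi):
    # v_hat(s, w) = phi(s)^T w -- linear in w even when phi is non-linear in s
    return phi(s) @ w

w_lin = np.array([2.0, 3.0])        # a = 2, b = 3
w_quad = np.array([1.0, 2.0, 3.0])  # a = 1, b = 2, c = 3

print(v_hat(2.0, w_lin, phi_linear))     # 2*2 + 3 = 7.0
print(v_hat(2.0, w_quad, phi_quadratic)) # 4 + 4 + 3 = 11.0
```

Richer features change only $\phi$; the estimator and any gradient-based update stay the same because the model remains linear in $w$.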

State Value Estimation

Objective Function:

$$J(w) = \mathbb{E}[(v_\pi(S) - \hat{v}(S, w))^2]$$

  • The core objective is to find the optimal parameters $w$ that minimize this objective function.
  • $S$ is a random variable; its probability distribution is usually chosen in one of the following two ways:

Uniform Distribution

$$J(w) = \frac{1}{|\mathcal{S}|} \sum_{s \in \mathcal{S}} (v_\pi(s) - \hat{v}(s, w))^2$$

  • The uniform distribution treats all states equally. However, in actual reinforcement learning, some states are visited more frequently and are more critical, so this distribution is often unsuitable.

Stationary Distribution

The stationary distribution describes the long-run behavior of a Markov process. Here, $\{d_{\pi}(s)\}_{s \in \mathcal{S}}$ is the state distribution, satisfying $d_{\pi}(s) \geq 0$ and $\sum_{s \in \mathcal{S}} d_{\pi}(s) = 1$.

$$J(w) = \sum_{s \in \mathcal{S}} d_{\pi}(s)(v_{\pi}(s) - \hat{v}(s, w))^2$$

  • $d_{\pi}(s)$ is the stationary probability of being in state $s$ under policy $\pi$. Weighting by the stationary distribution yields smaller fitting errors on frequently visited states.
  • The stationary distribution satisfies:

$$d_{\pi}^T = d_{\pi}^T P_{\pi}$$

where $P_{\pi}$ is the state transition matrix under policy $\pi$, the same matrix that appears in the Bellman equation.
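As an illustration, the stationary distribution of a small transition matrix (the numbers below are made up) can be found by iterating $d^T \leftarrow d^T P_\pi$:

```python
import numpy as np

# Hypothetical 3-state transition matrix P_pi under some policy; rows sum to 1.
P = np.array([[0.5, 0.5, 0.0],
              [0.1, 0.6, 0.3],
              [0.2, 0.3, 0.5]])

# Power iteration: start from any distribution and repeatedly apply d^T <- d^T P.
d = np.ones(3) / 3
for _ in range(1000):
    d = d @ P

print(d)                      # converged stationary distribution
print(np.allclose(d, d @ P))  # True: d satisfies d^T = d^T P_pi
```

In practice $d_\pi$ is rarely computed explicitly; sampling states by following $\pi$ weights the objective by $d_\pi$ automatically.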

Optimization Methods

Update the parameters with gradient descent:

$$w_{k+1} = w_k - \alpha_k \nabla_w J(w_k)$$

The true gradient is derived as follows:

$$\begin{aligned} \nabla_w J(w) &= \nabla_w \mathbb{E}[(v_\pi(S) - \hat{v}(S, w))^2] \\ &= \mathbb{E}[\nabla_w (v_\pi(S) - \hat{v}(S, w))^2] \\ &= 2\mathbb{E}[(v_\pi(S) - \hat{v}(S, w))(-\nabla_w \hat{v}(S, w))] \\ &= -2\mathbb{E}[(v_\pi(S) - \hat{v}(S, w))\nabla_w \hat{v}(S, w)] \end{aligned}$$

In practice, stochastic gradient descent (SGD) is commonly used:

$$w_{t+1} = w_t + \alpha_t (v_\pi(s_t) - \hat{v}(s_t, w_t)) \nabla_w \hat{v}(s_t, w_t)$$

where $s_t$ is a sample of $S$. For brevity, the constant $2$ is absorbed into the learning rate $\alpha_t$. Since the true $v_{\pi}(s_t)$ is unknown, it must be replaced with an estimate:

  • Monte Carlo (MC): use the discounted return $g_t$ observed in an episode to approximate $v_{\pi}(s_t)$:
    $$w_{t+1} = w_t + \alpha_t (g_t - \hat{v}(s_t, w_t)) \nabla_w \hat{v}(s_t, w_t)$$
  • Temporal Difference (TD): treat the target $r_{t+1} + \gamma \hat{v}(s_{t+1}, w_t)$ as an approximation of $v_{\pi}(s_t)$:
    $$w_{t+1} = w_t + \alpha_t [r_{t+1} + \gamma \hat{v}(s_{t+1}, w_t) - \hat{v}(s_t, w_t)] \nabla_w \hat{v}(s_t, w_t)$$
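Both updates share the same shape and differ only in the target. A minimal sketch for a linear model (toy numbers, numpy assumed; `phi`, `mc_update`, and `td_update` are illustrative names):

```python
import numpy as np

# Linear model v_hat(s, w) = phi(s)^T w, so grad_w v_hat = phi(s).
def phi(s):
    return np.array([s, 1.0])

def mc_update(w, s, g, alpha=0.1):
    # MC target: the observed discounted return g_t approximates v_pi(s_t).
    return w + alpha * (g - phi(s) @ w) * phi(s)

def td_update(w, s, r, s_next, gamma=0.9, alpha=0.1):
    # TD target: r_{t+1} + gamma * v_hat(s_{t+1}, w) approximates v_pi(s_t).
    target = r + gamma * (phi(s_next) @ w)
    return w + alpha * (target - phi(s) @ w) * phi(s)

w = np.zeros(2)
w = mc_update(w, s=1.0, g=5.0)              # w becomes [0.5, 0.5]
w = td_update(w, s=1.0, r=1.0, s_next=2.0)  # bootstraps from the current w
print(w)                                    # approximately [0.635 0.635]
```

The MC target waits until the episode ends; the TD target is available after every step but bootstraps from the current estimate.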

TD-Linear Algorithm

In the linear case $\hat{v}(s, w) = \phi^T(s)w$, the gradient is:

$$\nabla_w \hat{v}(s, w) = \phi(s)$$

Substituting this gradient into the TD algorithm gives:

$$w_{t+1} = w_t + \alpha_t [r_{t+1} + \gamma \phi^T(s_{t+1}) w_t - \phi^T(s_t) w_t] \phi(s_t)$$

This is TD learning with linear function approximation, referred to as TD-Linear for short.
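To see TD-Linear at work, here is a sketch on a classic 5-state random walk (the environment, seed, and feature choice are illustrative, not from the original; the true interior values $s/4$ lie in the span of the quadratic features, so the fit can be exact in principle):

```python
import numpy as np

rng = np.random.default_rng(0)
n_states, gamma, alpha = 5, 1.0, 0.05

def phi(s):
    # Quadratic features of the normalized state index.
    x = s / (n_states - 1)
    return np.array([x**2, x, 1.0])

w = np.zeros(3)
for _ in range(3000):
    s = 2  # episodes start in the middle state
    while 0 < s < n_states - 1:
        s_next = s + rng.choice([-1, 1])
        r = 1.0 if s_next == n_states - 1 else 0.0
        # Terminal states have value 0, so the bootstrap term is dropped there.
        v_next = 0.0 if s_next in (0, n_states - 1) else phi(s_next) @ w
        # TD-Linear update: w += alpha * [r + gamma*v(s') - v(s)] * phi(s)
        w += alpha * (r + gamma * v_next - phi(s) @ w) * phi(s)
        s = s_next

# True values of states 1..3 are 0.25, 0.5, 0.75; the fit should be close.
print([round(float(phi(s) @ w), 2) for s in range(1, 4)])
```

Only three parameters are learned here, versus five table entries in the tabular setting; this gap widens dramatically for large state spaces.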

Derivative Analysis of Linear Approximation

In linear approximation, $\hat{v}(s, w)$ is a scalar (the predicted state value) and $w$ is a vector (the weight parameters).

  1. Deconstructing the linear expression. For column vectors $\phi(s) = [\phi_1, \dots, \phi_n]^T$ and $w = [w_1, \dots, w_n]^T$, the inner product is:

     $$\hat{v}(s, w) = \sum_{i=1}^{n} \phi_i w_i$$

  2. Differentiating with respect to a vector. $\nabla_w \hat{v}(s, w)$ is the vector of partial derivatives of this scalar with respect to each component of $w$:

     $$\frac{\partial}{\partial w_i} (\phi_1 w_1 + \dots + \phi_n w_n) = \phi_i$$

     Stacking these components gives $\nabla_w \hat{v}(s, w) = \phi(s)$.

Tabular Representation

The tabular method is a special case of linear function approximation. Let the feature vector of state $s$ be a one-hot vector:

$$\phi(s) = e_s \in \mathbb{R}^{|\mathcal{S}|}$$

Then:

$$\hat{v}(s, w) = e_s^T w = w(s)$$

That is, $w(s)$ extracts the component of the vector $w$ corresponding to state $s$.
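A quick numerical check (toy table, numpy assumed) that one-hot features reduce the linear model to a table lookup:

```python
import numpy as np

n_states = 4
w = np.array([0.1, 0.5, -0.2, 0.9])  # hypothetical value table stored in w

def phi(s):
    # One-hot feature vector e_s
    e = np.zeros(n_states)
    e[s] = 1.0
    return e

# v_hat(s, w) = e_s^T w simply reads out the s-th entry of the table.
print(phi(2) @ w)   # -0.2
```

Likewise, the TD-Linear update $w \leftarrow w + \alpha_t \delta_t e_{s_t}$ touches only the component $w(s_t)$, recovering tabular TD(0).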

Action Value Function Approximation

Sarsa with Function Approximation

$$w_{t+1} = w_t + \alpha_t \left[ r_{t+1} + \gamma \hat{q}(s_{t+1}, a_{t+1}, w_t) - \hat{q}(s_t, a_t, w_t) \right] \nabla_w \hat{q}(s_t, a_t, w_t)$$

Q-learning with Function Approximation

$$w_{t+1} = w_t + \alpha_t \left[ r_{t+1} + \gamma \max_{a \in \mathcal{A}(s_{t+1})} \hat{q}(s_{t+1}, a, w_t) - \hat{q}(s_t, a_t, w_t) \right] \nabla_w \hat{q}(s_t, a_t, w_t)$$
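One step of this update, sketched with a linear $\hat{q}$ over one-hot state-action features (a toy setup with made-up numbers, numpy assumed):

```python
import numpy as np

n_states, n_actions = 3, 2
gamma, alpha = 0.9, 0.1

def phi(s, a):
    # One-hot state-action features: a simple linear parameterization of q_hat.
    e = np.zeros(n_states * n_actions)
    e[s * n_actions + a] = 1.0
    return e

def q_hat(s, a, w):
    return phi(s, a) @ w

def q_learning_step(w, s, a, r, s_next):
    # Off-policy TD target: max over next actions, per the formula above.
    target = r + gamma * max(q_hat(s_next, b, w) for b in range(n_actions))
    return w + alpha * (target - q_hat(s, a, w)) * phi(s, a)

w = np.zeros(n_states * n_actions)
w = q_learning_step(w, s=0, a=1, r=1.0, s_next=2)
print(q_hat(0, 1, w))   # 0.1 after one step (alpha * r, since q starts at 0)
```

Sarsa's update is identical except the target uses the action $a_{t+1}$ actually taken rather than the max.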

Deep Q-Network (DQN)

DQN uses a neural network to approximate the non-linear Q function.

Loss Function:

$$J(w) = \mathbb{E} \left[ \left( R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w) - \hat{q}(S, A, w) \right)^2 \right]$$

This is essentially minimizing the Bellman optimality error. Define the target value $y$ as:

$$y \doteq R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w)$$

To stabilize training and prevent the target value from shifting with every network update, DQN introduces a dual-network architecture:

  • Main Network: $\hat{q}(S, A, w)$, responsible for evaluating the current action and updated at every step.
  • Target Network: $\hat{q}(S', a, w_T)$, providing a stable target value $y$.

With the target network introduced, the loss function becomes:

$$J(w) = \mathbb{E} \left[ \left( R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w_T) - \hat{q}(S, A, w) \right)^2 \right]$$

Treating $w_T$ as a constant (i.e., excluded from the gradient calculation), gradient descent updates only $w$:

$$\nabla_w J(w) = -2\mathbb{E} \left[ \left( R + \gamma \max_{a \in \mathcal{A}(S')} \hat{q}(S', a, w_T) - \hat{q}(S, A, w) \right) \nabla_w \hat{q}(S, A, w) \right]$$

Note: the parameters $w$ of the main network are periodically copied to the target network $w_T$.
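The interaction between the two networks can be sketched with a linear-in-features stand-in for the neural network (all numbers below are made up for illustration; only the target-freezing and periodic-copy logic is the point):

```python
import numpy as np

gamma, alpha, sync_every = 0.9, 0.1, 25

# Linear stand-in for the network: q_hat(s, a) = phi(s, a)^T w.
w = np.zeros(3)    # main network parameters, updated every step
w_T = w.copy()     # target network parameters, a frozen copy

phi_sa = np.array([1.0, 0.0, 0.5])          # features of the visited (s, a)
phi_next = [np.array([0.0, 1.0, 0.2]),      # features of (s', a) for each a
            np.array([0.3, 0.0, 1.0])]

for step in range(500):
    # Target y uses the frozen parameters w_T, so it stays fixed between syncs.
    y = 1.0 + gamma * max(p @ w_T for p in phi_next)
    # Gradient step on (y - q_hat)^2 with respect to w only.
    w += alpha * (y - phi_sa @ w) * phi_sa
    if (step + 1) % sync_every == 0:
        w_T = w.copy()   # periodically copy main network -> target network

print(phi_sa @ w)  # approaches the fixed point of q = 1 + gamma * max(...)
```

Between syncs this is an ordinary regression toward a fixed target; the periodic copy is what advances the target, which is exactly why freezing $w_T$ stabilizes training.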

Experience Replay

  • Motivation: sequential data collected in reinforcement learning is strongly correlated; training on it directly can easily destabilize the network.
  • Mechanism: store the agent-environment interaction data as tuples $(s, a, r, s')$ in a replay buffer $\mathcal{B}$.
  • Sampling: during training, draw a random mini-batch from the buffer, usually uniformly, thereby breaking the temporal correlations between samples and significantly improving data efficiency.
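The mechanism above can be sketched as a small buffer class (a minimal illustration; real implementations typically add array batching, prioritized sampling, etc.):

```python
import random
from collections import deque

class ReplayBuffer:
    def __init__(self, capacity):
        # deque(maxlen=...) discards the oldest transitions once full.
        self.buffer = deque(maxlen=capacity)

    def push(self, s, a, r, s_next):
        self.buffer.append((s, a, r, s_next))

    def sample(self, batch_size):
        # Uniform random mini-batch breaks temporal correlations.
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)

buf = ReplayBuffer(capacity=1000)
for t in range(50):
    buf.push(t, 0, 1.0, t + 1)   # store (s, a, r, s') tuples
batch = buf.sample(8)
print(len(batch))   # 8
```

Each stored transition can be reused in many mini-batches, which is where the data-efficiency gain comes from.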