1. The Cliff Walking Problem
The cliff-walking problem takes place on a 4 x 12 grid. The agent starts at the bottom-left cell of the grid and must reach the goal at the bottom-right cell. At each step the agent may move one cell in one of four directions (up, down, left, or right), and every move yields a reward of -1.

The agent's movement is subject to the following rules:
(1) The agent cannot move off the grid. If an action would take it off the grid, the agent stays in place for that step, but the action still yields a reward of -1.
(2) If the agent "falls off the cliff", it immediately returns to the start position and receives a reward of -100.
(3) The episode ends when the agent reaches the goal; the episode's total reward is the sum of the per-step rewards.
Design an algorithm that finds the optimal policy for the cliff-walking problem.
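As a quick sanity check (not part of the original program), a breadth-first search over the grid confirms that the shortest safe path from the start to the goal takes 13 steps (one step up, eleven right, one down), so the best achievable episode return is -13:

```python
from collections import deque

# BFS over the 4 x 12 grid, treating the cliff cells (row 3, columns 1-10)
# as walls. Each step costs -1, so the best episode return is minus the
# shortest path length.
def shortest_path_steps():
    start, goal = (3, 0), (3, 11)
    cliff = {(3, c) for c in range(1, 11)}
    queue = deque([(start, 0)])
    seen = {start}
    while queue:
        (row, col), steps = queue.popleft()
        if (row, col) == goal:
            return steps
        for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            nxt = (row + dr, col + dc)
            if 0 <= nxt[0] < 4 and 0 <= nxt[1] < 12 and nxt not in seen and nxt not in cliff:
                seen.add(nxt)
                queue.append((nxt, steps + 1))
    return None

print(shortest_path_steps())  # 13, so the optimal episode return is -13
```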
2. Temporal-Difference Reinforcement Learning Algorithms
2.1 The Q-learning Training Algorithm
Algorithm 1 Q-learning: an off-policy temporal-difference learning algorithm

Explanation:
The state is the agent's position in the 4 x 12 grid. The actions are up, down, left, and right, encoded as 0, 1, 2, and 3. The Q table Q(s,a) is initialized with Q = np.zeros([4, 12, 4]) and records the estimated score of each action in each state.
Position determines state; in the implementation, each cell is labeled with one of three states: 'ground', 'terminal', or 'trap'.
Off-policy: policy evaluation and policy improvement use different policies. "Off" means that the policy used to collect experience (the behavior policy) deviates from the policy being learned (the target policy).
On-policy: policy evaluation and policy improvement use the same policy.
Sarsa is an on-policy method: its behavior policy and its evaluation policy are both ε-greedy, and it takes the action first and then updates the policy.
Q-learning is an off-policy method: it updates the value function as if the next action were the one with the maximum estimated value, and then selects the actual next action with an ε-greedy policy. (The idea of Q-learning derives from value iteration. When the state and action spaces are too large to enumerate, the expected value of each state cannot be estimated exactly, so we can only use the limited sample data to estimate Q step by step, in a manner similar to gradient descent, rather than assigning values directly, until Q converges to the optimum.)
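The contrast between the two update targets can be sketched as follows (an illustrative snippet, not from the original program; γ = 0.95 matches the discount used in the training code):

```python
import numpy as np

GAMMA = 0.95  # discount factor, same value as in the training code

def q_learning_target(Q, reward, next_row, next_col):
    # Off-policy: bootstrap from the best action in the next state,
    # regardless of which action the behavior policy will actually take.
    return reward + GAMMA * Q[next_row, next_col].max()

def sarsa_target(Q, reward, next_row, next_col, next_action):
    # On-policy: bootstrap from the action the epsilon-greedy behavior
    # policy actually selected in the next state.
    return reward + GAMMA * Q[next_row, next_col, next_action]
```

The only difference between the two algorithms is this bootstrap target; the rest of the training loop is identical.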
Algorithm 2 Sarsa: an on-policy temporal-difference learning algorithm

2.2 The Principle of Q-learning
Q-learning is based on temporal-difference (TD) learning (see the book Neural Networks and Deep Learning by Qiu Xipeng). The details are as follows:


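The tabular update that the code below implements is the standard Q-learning rule (with the code's learning rate α = 0.1 and discount γ = 0.95):

```latex
Q(s_t, a_t) \leftarrow Q(s_t, a_t)
  + \alpha \Big[ r_{t+1} + \gamma \max_{a'} Q(s_{t+1}, a') - Q(s_t, a_t) \Big]
```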
3. Q-learning Training Code
import numpy as np
import random

# Get the state of a grid cell
def get_state(row, col):
    if row != 3:
        return 'ground'
    if row == 3 and col == 11:
        return 'terminal'
    if row == 3 and col == 0:
        return 'ground'
    return 'trap'

# Execute an action in a given state and get the corresponding reward
def move(row, col, action):
    # State check: in a trap or at the terminal no action can be taken,
    # and the reward is 0
    if get_state(row, col) in ['trap', 'terminal']:
        return row, col, 0
    # Position change for each of the four actions (up, down, left, right)
    if action == 0:
        row -= 1
    if action == 1:
        row += 1
    if action == 2:
        col -= 1
    if action == 3:
        col += 1
    # Clamp the position to the grid: rows in [0, 3], columns in [0, 11]
    row = max(0, row)
    row = min(3, row)
    col = max(0, col)
    col = min(11, col)
    # Falling into a trap yields -100; every other step yields -1,
    # which pushes the agent to finish the task as quickly as possible
    reward = -1
    if get_state(row, col) == 'trap':
        reward = -100
    return row, col, reward

# Initialize the Q table: the score of each action in each cell.
# Nothing is known at the start, so all entries are zero.
Q = np.zeros([4, 12, 4])

# Select an action for the current cell
def get_action(row, col):
    # Explore with a small probability (epsilon = 0.1)
    if random.random() < 0.1:
        return np.random.choice(range(4))
    # Otherwise return the highest-scoring action in the current Q table
    return Q[row, col].argmax()

# Compute the update for the current cell, given the reward obtained after
# taking the action and the next cell the agent arrives at
def update(row, col, action, reward, next_row, next_col):
    """The update target differs from Sarsa's:
    Sarsa: estimates the value Q[row, col, action] of the current
        epsilon-greedy policy (on-policy)
    Q-learning: directly estimates the optimal value of Q[next_row, next_col]
        (off-policy)
    On-policy: the behavior policy and the target policy are the same policy
    Off-policy: the behavior policy and the target policy are different policies
    """
    target = reward + Q[next_row, next_col].max() * 0.95
    value = Q[row, col, action]
    # Temporal-difference error, scaled by the learning rate 0.1
    td_error = 0.1 * (target - value)
    # Return the scaled error
    return td_error

def train():
    for epoch in range(4000):
        # Each episode starts from a random row in column 0, so that the agent
        # interacts with the environment as widely as possible; an initial
        # action is bound at the same time
        row = np.random.choice(range(4))
        col = 0
        action = get_action(row, col)  # random action, or the highest-scoring one
        # Sum of rewards for this episode; it should grow over training
        rewards_sum = 0
        # Keep exploring until the episode ends at the terminal or in a trap
        while get_state(row, col) not in ['terminal', 'trap']:
            # Move once from the current position (row, col) with the current
            # action to obtain the new position and the reward
            next_row, next_col, reward = move(row, col, action)
            rewards_sum += reward
            # Choose an action at the new position (random or highest-scoring)
            next_action = get_action(next_row, next_col)
            # Compute the update for this move
            td_error = update(row, col, action, reward, next_row, next_col)
            # Update the Q table
            Q[row, col, action] += td_error
            # Advance the state
            row, col, action = next_row, next_col, next_action
        if epoch % 500 == 0:
            print(f"epoch:{epoch}, rewards_sum:{rewards_sum}")

# Save the Q table to a file
def save():
    npy_file = './q_table.npy'
    np.save(npy_file, Q)
    print(npy_file + ' saved.')

# Load Q values from a file into the Q table
def restore(npy_file='./q_table.npy'):
    global Q
    Q = np.load(npy_file)
    print(npy_file + ' loaded.')

train()
save()
#print(Q)
#test()
The training results are as follows (this logged run used 10,000 epochs rather than the 4,000 in the listing above):
epoch:0, rewards_sum:-111
epoch:500, rewards_sum:-112
epoch:1000, rewards_sum:-13
epoch:1500, rewards_sum:-13
epoch:2000, rewards_sum:-16
epoch:2500, rewards_sum:-13
epoch:3000, rewards_sum:-14
epoch:3500, rewards_sum:-18
epoch:4000, rewards_sum:-104
epoch:4500, rewards_sum:-14
epoch:5000, rewards_sum:-14
epoch:5500, rewards_sum:-14
epoch:6000, rewards_sum:-13
epoch:6500, rewards_sum:-14
epoch:7000, rewards_sum:-14
epoch:7500, rewards_sum:-113
epoch:8000, rewards_sum:-109
epoch:8500, rewards_sum:-110
epoch:9000, rewards_sum:-108
epoch:9500, rewards_sum:-17
[[[-10.7336123 -10.24650042 -10.73344798 -10.24650042]
[-10.24602278 -9.73315833 -10.73177298 -9.73315833]
[ -9.73114076 -9.19279825 -10.24330152 -9.19279825]
[ -9.18860185 -8.62399815 -9.71934594 -8.62399815]
[ -8.62062915 -8.02526122 -9.16607 -8.02526122]
[ -8.01177885 -7.39501181 -8.61667075 -7.39501181]
[ -7.38284039 -6.73159137 -7.8800845 -6.73159137]
[ -6.65523622 -6.03325408 -7.35559032 -6.03325408]
[ -5.96181829 -5.29816219 -6.55092082 -5.29816219]
[ -5.20557519 -4.52438125 -5.71929578 -4.52438125]
[ -4.2665695 -3.709875 -4.76517345 -3.709875 ]
[ -3.29211171 -2.8525 -3.6500356 -3.30771276]]
[[-10.73311209 -9.73315833 -10.2448655 -9.73315833]
[-10.22383066 -9.19279825 -10.23292195 -9.19279825]
[ -9.71398046 -8.62399815 -9.73136975 -8.62399815]
[ -9.15720592 -8.02526122 -9.15786455 -8.02526122]
[ -8.54598217 -7.39501181 -8.50477269 -7.39501181]
[ -7.97384834 -6.73159137 -7.98814648 -6.73159137]
[ -7.37252716 -6.03325408 -7.35337743 -6.03325408]
[ -6.6883798 -5.29816219 -6.67601993 -5.29816219]
[ -5.96942013 -4.52438125 -5.9456582 -4.52438125]
[ -5.16871589 -3.709875 -5.08949395 -3.709875 ]
[ -4.43322669 -2.8525 -4.20632995 -2.8525 ]
[ -3.60748154 -1.95 -3.49776269 -2.77677297]]
[[-10.24650041 -10.24650037 -9.73315828 -9.19279825]
[ -9.73315833 -99.99999834 -9.73315826 -8.62399815]
[ -9.19279821 -99.99999958 -9.1927982 -8.02526122]
[ -8.62399815 -99.99999879 -8.62399814 -7.39501181]
[ -8.02526118 -99.99999989 -8.02526121 -6.73159137]
[ -7.3950118 -99.99999999 -7.39501181 -6.03325408]
[ -6.73159137 -99.99999993 -6.73159137 -5.29816219]
[ -6.03325407 -99.99999999 -6.03325407 -4.52438125]
[ -5.29816219 -99.99999953 -5.29816219 -3.709875 ]
[ -4.52438125 -99.99999995 -4.52438125 -2.8525 ]
[ -3.709875 -99.99999998 -3.709875 -1.95 ]
[ -2.8525 -1. -2.8525 -1.95 ]]
[[ -9.73315833 -10.24603711 -10.2462804 -99.99734386]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]]]
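The learned values are consistent with theory: for the start state (3, 0), the Q table above shows about -9.733 for the best action, which is exactly the discounted return of the optimal 13-step path with γ = 0.95. A quick check (not part of the original code):

```python
# With discount 0.95 and reward -1 per step, the discounted return of the
# optimal 13-step path is sum_{k=0}^{12} -0.95**k, which should match the
# ~-9.733 learned for the start state (3, 0) in the table above.
gamma = 0.95
ret = sum(-gamma**k for k in range(13))
print(round(ret, 4))  # -9.7332
```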
The results after training for only 1000 epochs are as follows (note that the entries next to the cliff, around -81 to -92 rather than approximately -100, show the Q table has not yet converged):
epoch:0, rewards_sum:-194
epoch:500, rewards_sum:-103
[[[ -9.94490996 -9.94360817 -9.97559011 -9.94452946]
[ -9.51585588 -9.44351953 -9.66630798 -9.44520752]
[ -8.94950359 -8.94035402 -8.97403876 -8.91880636]
[ -8.41839808 -8.37622344 -8.65962805 -8.3654473 ]
[ -7.91864297 -7.79685693 -7.97026698 -7.78756449]
[ -7.29418985 -7.19235148 -7.33119326 -7.17601956]
[ -6.58283168 -6.54340907 -6.5981396 -6.53974411]
[ -5.90440248 -5.88558833 -6.00542477 -5.8706739 ]
[ -5.20343271 -5.18787969 -5.47593413 -5.16342653]
[ -4.54982414 -4.43279411 -4.63806879 -4.42671076]
[ -3.72812163 -3.65173715 -3.92759929 -3.65447304]
[ -2.87025922 -2.84112686 -3.05667397 -2.87657142]]
[[ -9.79891126 -9.69117624 -9.88039999 -9.6878703 ]
[ -9.22944782 -9.15519518 -9.23105301 -9.15599234]
[ -8.74901972 -8.59526921 -8.9817166 -8.59481486]
[ -8.11761072 -8.00441628 -8.42395077 -8.0021373 ]
[ -7.47168752 -7.37896549 -7.57759868 -7.37745416]
[ -6.81696759 -6.71901959 -7.05975507 -6.71913662]
[ -6.26863472 -6.02464771 -6.41904386 -6.02488305]
[ -5.42534354 -5.29322521 -5.61223089 -5.2930385 ]
[ -4.71174584 -4.52179183 -4.72612995 -4.52188988]
[ -3.80662536 -3.70890451 -4.20602672 -3.70895616]
[ -3.39842609 -2.85226608 -2.91211438 -2.8522578 ]
[ -1.95923843 -1.94999982 -2.57083671 -2.31178442]]
[[ -9.47646205 -9.8651947 -9.52811595 -9.19279825]
[ -9.18880666 -92.02335569 -9.2185477 -8.62399815]
[ -8.6344983 -86.49148282 -8.63605077 -8.02526122]
[ -8.01909585 -89.05810109 -8.2776437 -7.39501181]
[ -7.61934948 -81.46979811 -7.72127988 -6.73159137]
[ -6.95180844 -89.05810109 -6.94801873 -6.03325408]
[ -6.37553803 -81.46979811 -6.08940406 -5.29816219]
[ -5.4130512 -84.99053647 -5.64342035 -4.52438125]
[ -4.95889574 -81.46979811 -5.12358091 -3.709875 ]
[ -4.28833417 -81.46979811 -4.04539097 -2.8525 ]
[ -3.50983057 -81.46979811 -3.54163826 -1.95 ]
[ -2.60357885 -1. -2.5591377 -1.83702669]]
[[ -9.73315832 -9.89557882 -9.84333736 -86.49148282]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]]]
4. Q-learning Test Code
import numpy as np
import random

Q = np.zeros([4, 12, 4])

# Load Q values from a file into the Q table
def restore(npy_file='./q_table.npy'):
    global Q
    Q = np.load(npy_file)
    #print(Q)
    print(npy_file + ' loaded.')

restore('./q_table.npy')
#print(Q)

# Get the state of a grid cell
def get_state(row, col):
    if row != 3:
        return 'ground'
    if row == 3 and col == 11:
        return 'terminal'
    if row == 3 and col == 0:
        return 'ground'
    return 'trap'

# Execute an action in a given state and get the corresponding reward
def move(row, col, action):
    # State check: in a trap or at the terminal no action can be taken,
    # and the reward is 0
    if get_state(row, col) in ['trap', 'terminal']:
        return row, col, 0
    # Position change for each of the four actions (up, down, left, right)
    if action == 0:
        row -= 1
    if action == 1:
        row += 1
    if action == 2:
        col -= 1
    if action == 3:
        col += 1
    # Clamp the position to the grid: rows in [0, 3], columns in [0, 11]
    row = max(0, row)
    row = min(3, row)
    col = max(0, col)
    col = min(11, col)
    # Falling into a trap yields -100; every other step yields -1,
    # which pushes the agent to finish the task as quickly as possible
    reward = -1
    if get_state(row, col) == 'trap':
        reward = -100
    return row, col, reward

# Select an action for the current cell
def get_action1(row, col):
    # Exploration is disabled during testing
    #if random.random() < 0.1:
    #    return np.random.choice(range(4))
    # Return the highest-scoring action in the current Q table
    return Q[row, col].argmax()

def test():
    for epoch in range(20):
        # Each test episode starts from a random position: a random row above
        # the cliff row, and a random column
        row = np.random.choice(range(3))
        #print(row)
        col = np.random.choice(range(12))
        print('Initial position:', row, col)
        #print(col)
        action = get_action1(row, col)
        # Sum of rewards for this episode
        rewards_sum = 0
        # Follow the greedy policy until the terminal or a trap is reached
        while get_state(row, col) not in ['terminal', 'trap']:
            # Move once from the current state to obtain the new state
            next_row, next_col, reward = move(row, col, action)
            #print(next_row, next_col)
            next_action = get_action1(next_row, next_col)
            print(next_row, next_col, next_action)
            rewards_sum += reward
            # No Q-table update during testing
            #td_error = update(row, col, action, reward, next_row, next_col)
            #Q[row, col, action] += td_error
            # Advance the state
            row, col, action = next_row, next_col, next_action
        if epoch % 1 == 0:
            print(f"epoch:{epoch}, rewards_sum:{rewards_sum}")
    #print(Q)

test()
The test results are as follows (the per-step position printouts are abridged; only the per-episode summaries are shown):
epoch:0, rewards_sum:-14
epoch:1, rewards_sum:-13
epoch:2, rewards_sum:-13
epoch:3, rewards_sum:-13
epoch:4, rewards_sum:-13
epoch:5, rewards_sum:-12
epoch:6, rewards_sum:-14
epoch:7, rewards_sum:-13
epoch:8, rewards_sum:-12
epoch:9, rewards_sum:-13
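To confirm the whole pipeline end to end, the following self-contained sketch condenses the training and test scripts above into one file (same environment and hyperparameters; the compact helpers `state` and `step` are renamed for brevity) and checks whether the learned greedy policy earns the optimal return of -13:

```python
import numpy as np

rng = np.random.default_rng(0)
Q = np.zeros([4, 12, 4])
MOVES = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # up, down, left, right

def state(row, col):
    # Row 3 holds the start (col 0), the terminal (col 11), and the cliff.
    if row != 3 or col == 0:
        return 'ground'
    return 'terminal' if col == 11 else 'trap'

def step(row, col, action):
    dr, dc = MOVES[action]
    row = min(3, max(0, row + dr))
    col = min(11, max(0, col + dc))
    return row, col, -100 if state(row, col) == 'trap' else -1

# Q-learning: epsilon = 0.1, alpha = 0.1, gamma = 0.95, random start row.
for _ in range(5000):
    row, col = int(rng.integers(4)), 0
    while state(row, col) == 'ground':
        a = int(rng.integers(4)) if rng.random() < 0.1 else int(Q[row, col].argmax())
        nr, nc, r = step(row, col, a)
        Q[row, col, a] += 0.1 * (r + 0.95 * Q[nr, nc].max() - Q[row, col, a])
        row, col = nr, nc

# Greedy rollout from the start (3, 0), capped to avoid infinite loops.
row, col, total = 3, 0, 0
for _ in range(100):
    if state(row, col) != 'ground':
        break
    row, col, r = step(row, col, int(Q[row, col].argmax()))
    total += r
print(total)  # -13 once training has converged
```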


