Tutorial 1: Optimal Control for Discrete States¶
Week 3, Day 3: Optimal Control
By Neuromatch Academy
Content creators: Zhengwei Wu, Shreya Saxena, Xaq Pitkow
Content reviewers: Karolina Stosio, Roozbeh Farhoodi, Saeed Salehi, Ella Batty, Spiros Chavlis, Matt Krause, and Michael Waskom
Tutorial Objectives¶
In this tutorial, we will implement a binary control task: a Partially Observable Markov Decision Process (POMDP) that describes fishing. The agent (you) seeks reward from two fishing sites without directly observing where the school of fish is (a group of fish is called a school!). This makes the world a Hidden Markov Model. Based on when and where you catch fish, you keep updating your belief about the fish location, i.e. the posterior of the fish given past observations. You should control your position to get the most fish while minimizing the cost of switching sides.
You’ve already learned about stochastic dynamics, latent states, and measurements. Now we introduce you to the new concepts of control, utility, and policy.
Setup¶
# Imports
import numpy as np
import scipy
from scipy.linalg import inv
from math import isclose
import matplotlib.pyplot as plt
Figure Settings¶
#@title Figure Settings
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import ipywidgets as widgets
from ipywidgets import interact, fixed, HBox, Layout, VBox, interactive, Label
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/nma.mplstyle")
Plotting Functions¶
# @title Plotting Functions
def plot_fish(fish_state, ax=None):
"""
Plot the fish dynamics
"""
T = len(fish_state)
showlen = min(T, 200)
startT = 0
endT = startT + showlen
showT = range(startT, endT)
time_range = np.linspace(0, showlen - 1)
if not ax:
fig, ax = plt.subplots(1, 1, figsize=(12, 2.5))
ax.plot(- fish_state[showT], color='dodgerblue', markersize=10, linewidth=3.0)
ax.set_xlabel('time', fontsize=18)
ax.set_ylabel('Fish state', rotation=360, fontsize=18)
ax.yaxis.set_label_coords(-0.1, 0.25)
    ax.set_xticks([0, showlen])
    ax.tick_params(axis='both', which='major', labelsize=18)
    ax.set_xlim([0, showlen])
    ax.set_ylim([-1.1, 1.1])
    ax.set_yticks([-1, 1])
labels = [item.get_text() for item in ax.get_yticklabels()]
labels[0] = 'Right'
labels[1] = 'Left'
ax.set_yticklabels(labels)
def plot_measurement(measurement, ax=None):
"""
Plot the measurements
"""
T = len(measurement)
showlen = min(T, 200)
startT = 0
endT = startT + showlen
showT = range(startT, endT)
time_range = np.linspace(0, showlen - 1)
if not ax:
fig, ax = plt.subplots(1, 1, figsize=(12, 2.5))
ax.plot(measurement[showT], 'r*', markersize=5)
ax.set_xlabel('time', fontsize=18)
ax.set_ylabel('Measurement', rotation=360, fontsize=18)
ax.yaxis.set_label_coords(-0.2, 0.4)
    ax.set_xticks([0, showlen])
ax.tick_params(axis='both', which='major', labelsize=18)
ax.set_xlim([0, showlen])
ax.set_ylim([-.1, 1.1])
ax.set_yticks([0, 1])
ax.set_yticklabels(['no fish', 'caught fish'])
def plot_act_loc(loc, act, ax_loc=None):
"""
Plot the action and location of 200 time points
"""
T = len(act)
showlen = min(T, 200)
startT = 0
endT = startT + showlen
showT = range(startT, endT)
time_range = np.linspace(0, showlen - 1)
if not ax_loc:
fig, ax_loc = plt.subplots(1, 1, figsize=(12, 2.5))
act_int = (act == "switch").astype(int)
ax_loc.plot(-loc[showT], 'g.-', markersize=8, linewidth=5)
ax_loc.plot((act_int[showT] * 4 - 3) * .5, 'rv', markersize=12,
label='switch')
ax_loc.set_xlabel('time', fontsize=18)
ax_loc.set_ylabel('Your state', rotation=360, fontsize=18)
ax_loc.legend(loc="upper right", fontsize=12)
ax_loc.set_xlim([0, showlen])
ax_loc.set_ylim([-1.1, 1.1])
ax_loc.set_yticks([-1, 1])
    ax_loc.set_xticks([0, showlen])
ax_loc.tick_params(axis='both', which='major', labelsize=18)
labels = [item.get_text() for item in ax_loc.get_yticklabels()]
labels[1] = 'Left'
labels[0] = 'Right'
ax_loc.set_yticklabels(labels)
def plot_belief(belief, ax1=None, choose_policy=None, threshold=0.2):
    """
    Plot the belief dynamics of 200 time points
    (threshold is used only when choose_policy == "threshold")
    """
T = belief.shape[1]
showlen = min(T, 200)
startT = 0
endT = startT + showlen
showT = range(startT, endT)
    time_range = np.arange(showlen)
if not ax1:
fig, ax1 = plt.subplots(1, 1, figsize=(12, 2.5))
ax1.plot(belief[0, showT], color='dodgerblue', markersize=10, linewidth=3.0)
ax1.yaxis.set_label_coords(-0.1, 0.25)
ax1.set_xlabel('time', rotation=360, fontsize=18)
ax1.set_ylabel('Belief on \n left', rotation=360, fontsize=18)
ax1.tick_params(axis='both', which='major', labelsize=18)
ax1.set_xlim([0, showlen])
ax1.set_yticks([0, 1])
ax1.set_ylim([0, 1.1])
    ax1.set_xticks([0, showlen])
if choose_policy == "threshold":
ax2 = ax1.twinx()
ax2.plot(time_range, threshold * np.ones(time_range.shape), 'r--')
ax2.plot(time_range, (1 - threshold) * np.ones(time_range.shape), 'c--')
ax2.set_yticks([threshold, 1 - threshold])
ax2.set_ylim([0, 1.1])
ax2.tick_params(axis='both', which='major', labelsize=18)
labels = [item.get_text() for item in ax2.get_yticklabels()]
labels[0] = 'threshold to switch \n from left to right'
labels[-1] = 'threshold to switch \n from right to left'
ax2.set_yticklabels(labels)
def plot_dynamics(belief, loc, act, meas, fish_state, choose_policy):
"""
Plot the dynamics of 200 time points
"""
if choose_policy == 'threshold':
fig, [ax0, ax_loc, ax1, ax_bel] = plt.subplots(4, 1, figsize=(12, 9))
plot_fish(fish_state, ax=ax0)
plot_belief(belief, ax1=ax_bel)
plot_measurement(meas, ax=ax1)
plot_act_loc(loc, act, ax_loc=ax_loc)
else:
fig, [ax0, ax1, ax_bel] = plt.subplots(3, 1, figsize=(12, 7))
plot_fish(fish_state, ax=ax0)
plot_belief(belief, ax1=ax_bel)
plot_measurement(meas, ax=ax1)
plt.tight_layout()
plt.show()
def belief_histogram(belief, bins=100):
"""
Plot the histogram of belief states
"""
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
ax.hist(belief, bins)
ax.set_xlabel('belief', fontsize=18)
ax.set_ylabel('count', fontsize=18)
plt.show()
def plot_value_threshold(cost_sw=0.5, T=10000, p_stay=.95,
high_rew_p=.4, low_rew_p=.1, step=.05):
"""
Helper function to plot the value function and threshold
"""
    params = [T, p_stay, high_rew_p, low_rew_p, None]  # threshold placeholder (set inside value_threshold)
threshold_array, value_array = value_threshold(params, cost_sw, step)
yrange = np.max(value_array) - np.min(value_array)
fig_, ax = plt.subplots(1, 1, figsize=(8, 6))
ax.plot(threshold_array, value_array, 'b')
ax.set_ylim([np.min(value_array) - yrange * .1, np.max(value_array) + yrange * .1])
ax.set_title(f'threshold vs value with switching cost c = {cost_sw:.2f}',
fontsize=20)
ax.set_xlabel('threshold', fontsize=16)
ax.set_ylabel('value', fontsize=16)
plt.show()
Helper Functions¶
# @title Helper Functions
binomial = np.random.binomial
class ExerciseError(AssertionError):
    pass
def test_policy_threshold():
well_done = True
for loc in [-1, 1]:
threshold = 0.4
belief = np.array([.2, .3])
if policy_threshold(threshold, belief, loc) != "switch":
            raise ExerciseError("'policy_threshold' function is not correctly implemented!")
for loc in [1, -1]:
threshold = 0.6
belief = np.array([.7, .8])
if policy_threshold(threshold, belief, loc) != "stay":
            raise ExerciseError("'policy_threshold' function is not correctly implemented!")
print("Well Done!")
def test_value_function():
measurement = np.array([0, 0, 0, 1, 0, 0, 0, 0, 1, 1])
act = np.array(["switch", "stay", "switch", "stay", "stay",
"stay", "switch", "switch", "stay", "stay"])
cost_sw = .5
if not isclose(value_function(measurement, act, cost_sw), .1):
        raise ExerciseError("'value_function' function is not correctly implemented!")
print("Well Done!")
Section 1: Dynamics of Fishing¶
Video 1: Gone fishing¶
There are two locations for the fish and you (Left and Right). If you’re on the same side as the fish, you’ll catch more, with probability \(q_{\rm high}\) per discrete time step. Otherwise you may still catch fish, with probability \(q_{\rm low}\). One fish is worth 1 “point”.
The fish location \(s^{\rm fish}\) is latent. The only information you get about the fish location is when you catch one. Secretly at each time step, the fish may switch sides with a certain probability \(p_{\rm sw} = 1 - p_{\rm stay}\).
You are in control of your own location. You may stay on your current side with no cost, or switch to the other side and incur an action cost \(C\) (again, in units of fish).
You select controls or actions by following a policy, which defines what to do in any situation. Here the situation is specified by your location and your belief \(b_t\) about the fish location. For optimal control we assume that this belief is the posterior probability over the current fish location, given all the past measurements. We only need one number for this, since the fish are either on the left or the right. So we write
\[b_t = p(s^{\rm fish}_t = {\rm Left} \mid m_{0:t}, a_{0:t-1})\]
where \(m\) are the measurements, and \(a\) are the controls or actions (stay or switch).
Ultimately we will parameterize the policy by a simple threshold on beliefs. (This happens to be optimal if you pick the right threshold!) When your belief that fish are on your current side falls below a threshold \(\theta\), you switch to the other side.
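To make this concrete, below is a minimal numerical sketch of a single belief update (our illustration, not part of the original notebook; the parameter values are assumptions matching the defaults used later). It predicts where the fish might have moved, then weighs each hypothesis by the likelihood of the catch outcome:
# Minimal one-step belief update (illustrative sketch, not from the original notebook)
p_stay, q_high, q_low = 0.95, 0.4, 0.1  # assumed values, matching later defaults

def one_step_belief(b_left, loc, caught):
    """b_left: prior P(fish on Left); loc: -1 (Left) or 1 (Right);
    caught: 1 if a fish was caught, else 0."""
    # prediction step: the fish may secretly switch sides
    prior_left = b_left * p_stay + (1 - b_left) * (1 - p_stay)
    # likelihood of the catch outcome under each hypothesis
    p_catch_if_left = q_high if loc == -1 else q_low
    p_catch_if_right = q_high if loc == 1 else q_low
    lik_left = p_catch_if_left if caught else 1 - p_catch_if_left
    lik_right = p_catch_if_right if caught else 1 - p_catch_if_right
    # Bayes rule, then normalize
    post_left = prior_left * lik_left
    post_right = (1 - prior_left) * lik_right
    return post_left / (post_left + post_right)

print(one_step_belief(0.5, loc=-1, caught=1))  # catching a fish on the Left gives 0.8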
Your overall goals in this tutorial are:
Measure when fish are caught, first in the case where the school of fish doesn’t move.
For moving fish, plot their dynamics and your belief about them based on your measurements.
Compute the value for a given control policy.
Find the optimal policy for controlling your position.
Interactive Demo 1: Examining fish dynamics¶
In this demo, we will look at the dynamics of the fish moving from side to side. We will use the helper class binaryHMM, implemented in the next cell. Please take some time to study this class and its methods. You will then see a demo where you can change the probability of switching states and examine the resulting dynamics of the fish.
class binaryHMM():
def __init__(self, params, fish_initial=-1, loc_initial=-1):
self.params = params
self.fish_initial = fish_initial
self.loc_initial = loc_initial
def fish_state_telegraph(self, fish_past, p_stay):
"""
fish state update according to telegraph process
Args:
fish_past (int): the fish location (-1 for left side, 1 for right side)
p_stay : the probability that the state of a certain site stays the same
Returns:
fish_new (int): updated fish location
"""
        # we use the logical operation XOR (denoted by ^ in Python):
        # map fish_past from {-1, 1} to {0, 1}, flip it with probability
        # 1 - p_stay via the binomial draw, then map back to {-1, 1}
        fish_new = (1 - binomial(1, p_stay)) ^ ((fish_past + 1) // 2)
        fish_new = fish_new * 2 - 1
return fish_new
def fish_dynamics(self):
"""
fish state dynamics according to telegraph process
Returns:
fish_state (numpy array of int)
"""
T, p_stay, _, _, _ = self.params
fish_state = np.zeros(T, int) # -1: left side ; 1: right side
# initialization
fish_state[0] = self.fish_initial
for t in range(1, T):
fish_state[t] = self.fish_state_telegraph(fish_state[t - 1], p_stay)
return fish_state
def generate_process_lazy(self):
"""
        fish dynamics and measurements if you always stay in the initial location
        without changing sides
Returns:
fish_state (numpy array of int): locations of the fish
loc (numpy array of int): left or right site, -1 for left, and 1 for right
measurement (numpy array of binary): whether a reward is obtained
"""
T, _, high_rew_p, low_rew_p, _ = self.params
rew_p_vector = np.array([low_rew_p, high_rew_p])
fish_state = self.fish_dynamics()
loc = np.zeros(T, int) # -1: left side, 1: right side
measurement = np.zeros(T, int) # 0: no food, 1: get food
for t in range(0, T):
loc[t] = self.loc_initial
# new measurement
measurement[t] = binomial(1, rew_p_vector[(fish_state[t] == loc[t]) * 1])
return fish_state, loc, measurement
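As a quick sanity check (our addition, not in the original notebook), we can simulate the telegraph process and confirm that the fish switch sides at roughly the rate \(1 - p_{\rm stay}\):
# Sanity check (illustrative): the empirical switch rate should be about 1 - p_stay
params_check = [10000, 0.95, 0.4, 0.1, None]  # T, p_stay, high_rew_p, low_rew_p, threshold (unused)
hmm_check = binaryHMM(params_check)
s = hmm_check.fish_dynamics()
print("empirical switch rate:", np.mean(s[1:] != s[:-1]))  # ~0.05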
def update_ex_1(p_stay=.95, high_rew_p=.4, low_rew_p=.1, T=200):
"""
p_stay: probability fish stay
high_rew_p: p(catch fish) when you're on their side
low_rew_p : p(catch fish) when you're on other side
"""
    params = [T, p_stay, high_rew_p, low_rew_p, None]  # threshold is unused here
#### initial condition for fish [fish_initial] and you [loc_initial] ####
binaryHMM_test = binaryHMM(params, fish_initial=-1, loc_initial=-1)
fish_state = binaryHMM_test.fish_dynamics()
plot_fish(fish_state)
Make sure you execute this cell to enable the widget!
# @markdown Make sure you execute this cell to enable the widget!
widget=interactive(update_ex_1, {'manual': True},
high_rew_p=fixed(.4),
low_rew_p=fixed(.1),
p_stay=(.5, 1., .001),
T=fixed(200))
widget.children[-2].description='Run Simulation'
widget.children[-2].style.button_color='lightgreen'
controls = HBox(widget.children[:-1], layout=Layout(flex_flow='row wrap'))
output = widget.children[-1]
display(VBox([controls, output]))
Section 2: Catching fish¶
Video 2: Catch some fish¶
Interactive Demo 2: Catching fish¶
Now set \(p_{\rm stay} = 1\) so that the fish’s location is fixed, and we can directly see the chances of catching fish on each side. The variable fish_initial indicates the initial side of the fish, and loc_initial indicates your initial location. Each takes the value \(-1\) for left and \(1\) for right.
Instructions:
1. Set the two locations (fish_initial and loc_initial) to be the same, and measure when you catch fish.
2. Set the two locations (fish_initial and loc_initial) to be different, and measure when you catch fish.
3. Visually compare the measurements from 1 and 2.
Finally, you can also play around with the high_rew_p (high reward probability) and low_rew_p (low reward probability) sliders; a back-of-envelope expectation to compare against follows below.
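As a rough expectation (a back-of-envelope note of ours, not from the original notebook): with \(p_{\rm stay} = 1\), each time step is an independent Bernoulli trial, so over \(T\) steps you expect about high_rew_p \(\times T\) catches when you share the fish’s side and low_rew_p \(\times T\) when you don’t:
# Expected catch counts under the default slider values (illustrative)
T, high_rew_p, low_rew_p = 100, 0.6, 0.05  # assumed defaults from the widget below
print("same side:", high_rew_p * T)       # about 60 catches expected
print("opposite side:", low_rew_p * T)    # about 5 catches expected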
def update_ex_2(p_stay=1., high_rew_p=.6, low_rew_p=.05, T=100):
"""
p_stay: probability fish stay
high_rew_p: p(catch fish) when you're on their side
low_rew_p : p(catch fish) when you're on other side
"""
    params = [T, p_stay, high_rew_p, low_rew_p, None]  # threshold is unused here
#### initial condition for fish [fish_initial] and you [loc_initial] ####
binaryHMM_test = binaryHMM(params, fish_initial=-1, loc_initial=-1)
fish_state, loc, measurement = binaryHMM_test.generate_process_lazy()
plot_measurement(measurement)
Make sure you execute this cell to enable the widget!
#@markdown Make sure you execute this cell to enable the widget!
widget=interactive(update_ex_2, {'manual': True},
high_rew_p=(.0, 1., .001),
low_rew_p=(.0, 1., .001),
p_stay=fixed(1.),
T=fixed(100))
widget.children[-2].description='Run Simulation'
widget.children[-2].style.button_color='lightgreen'
controls = HBox(widget.children[:-1], layout=Layout(flex_flow='row wrap'))
output = widget.children[-1]
display(VBox([controls, output]))
Section 3: Belief dynamics and belief distributions¶
Video 3: Where are the fish?¶
We have provided a class for the binary task in which the agent always stays on one side, following the lazy policy function def policy_lazy(belief, loc) that we provide below. In this exercise, you will extend the module to generate the full dynamics, including beliefs and a moving agent. With the generated data, we will see how beliefs change over time, and how often different beliefs occur.
For convenience, your belief at time \(t\) is a 2-dimensional vector. The first element is the belief that the fish are on the left, and the second element is the belief that the fish are on the right. At every time step, these elements sum to \(1\).
We will first check the dynamics with lazy policy, and then explore the case with a threshold-based policy.
Evaluate the cells below to setup the lazy policy.
class binaryHMM_belief(binaryHMM):
def __init__(self, params, fish_initial = -1, loc_initial = -1,
choose_policy = 'threshold'):
binaryHMM.__init__(self, params, fish_initial, loc_initial)
self.choose_policy = choose_policy
def generate_process(self):
"""
        fish dynamics and measurements based on the chosen policy
Returns:
belief (numpy array of float): belief on the states of the two sites
act (numpy array of string): actions over time
loc (numpy array of int): left or right site
measurement (numpy array of binary): whether a reward is obtained
fish_state (numpy array of int): fish locations
"""
T, p_stay, high_rew_p, low_rew_p, threshold = self.params
fish_state = self.fish_dynamics() # -1: left side; 1: right side
loc = np.zeros(T, int) # -1: left side, 1: right side
measurement = np.zeros(T, int) # 0: no food, 1: get food
act = np.empty(T, dtype='object') # "stay", or "switch"
belief = np.zeros((2, T), float) # the probability that the fish is on the left (1st element)
# or on the right (2nd element),
# the beliefs on the two boxes sum up to be 1
rew_prob = np.array([low_rew_p, high_rew_p])
# initialization
loc[0] = -1
measurement[0] = 0
belief_0 = np.random.random(1)[0]
belief[:, 0] = np.array([belief_0, 1 - belief_0])
act[0] = self.policy(threshold, belief[:, 0], loc[0])
for t in range(1, T):
if act[t - 1] == "stay":
loc[t] = loc[t - 1]
else:
loc[t] = - loc[t - 1]
# new measurement
measurement[t] = binomial(1, rew_prob[(fish_state[t] == loc[t]) * 1])
belief[0, t] = self.belief_update(belief[0, t - 1] , loc[t],
measurement[t], p_stay,
high_rew_p, low_rew_p)
belief[1, t] = 1 - belief[0, t]
act[t] = self.policy(threshold, belief[:, t], loc[t])
return belief, loc, act, measurement, fish_state
def policy(self, threshold, belief, loc):
"""
chooses policy based on whether it is lazy policy
or a threshold-based policy
Args:
threshold (float): the threshold of belief on the current site,
when the belief is lower than the threshold, switch side
belief (numpy array of float): the belief on the two sites
loc (int) : the location of the agent
Returns:
act (string): "stay" or "switch"
"""
if self.choose_policy == "threshold":
act = policy_threshold(threshold, belief, loc)
if self.choose_policy == "lazy":
act = policy_lazy(belief, loc)
return act
def belief_update(self, belief_past, loc, measurement, p_stay,
high_rew_p, low_rew_p):
"""
using PAST belief on the LEFT box, CURRENT location and
and measurement to update belief
"""
rew_prob_matrix = np.array([[1 - high_rew_p, high_rew_p],
[1 - low_rew_p, low_rew_p]])
# update belief posterior, p(s[t] | measurement(0-t), act(0-t-1))
belief_0 = (belief_past * p_stay + (1 - belief_past) * (1 - p_stay)) *\
rew_prob_matrix[(loc + 1) // 2, measurement]
belief_1 = ((1 - belief_past) * p_stay + belief_past * (1 - p_stay)) *\
rew_prob_matrix[1-(loc + 1) // 2, measurement]
belief_0 = belief_0 / (belief_0 + belief_1)
return belief_0
def policy_lazy(belief, loc):
"""
This function is a lazy policy where stay is also taken
"""
act = "stay"
return act
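Before running the widget, here is a short usage sketch (ours, not part of the original notebook) that generates a lazy-policy run and inspects the resulting belief distribution with the belief_histogram helper defined earlier:
# Usage sketch (illustrative): lazy-policy run + belief histogram
params_lazy = [1000, 0.95, 0.4, 0.1, 0.2]  # T, p_stay, high_rew_p, low_rew_p, threshold
hmm_lazy = binaryHMM_belief(params_lazy, choose_policy="lazy")
belief, loc, act, measurement, fish_state = hmm_lazy.generate_process()
belief_histogram(belief[0], bins=50)  # distribution of the belief that fish are on the Left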
Interactive Demo 3: Task dynamics following a lazy policy¶
The policy parameter choose_policy can be either “lazy” or “threshold”. In the following example, use the lazy policy.
Instructions:
With the class defined above, we have created an object of binaryHMM_belief, given the parameters of the dynamics, params, and a parameter for the policy.
Run the dynamics and explain the time series of the beliefs you see.
def update_ex_3(p_stay=.98, threshold=.2, high_rew_p=.4, low_rew_p=.1, T=200):
"""
p_stay: probability fish stay
high_rew_p: p(catch fish) when you're on their side
low_rew_p : p(catch fish) when you're on other side
threshold: threshold of belief below which switching is taken
"""
params = [T, p_stay, high_rew_p, low_rew_p, threshold]
#### initial condition for fish [fish_initial] and you [loc_initial] ####
binaryHMM_test = binaryHMM_belief(params, choose_policy="lazy",
fish_initial=-1, loc_initial=-1)
belief, loc, act, measurement, fish_state = binaryHMM_test.generate_process()
plot_dynamics(belief, loc, act, measurement, fish_state,
binaryHMM_test.choose_policy)
Make sure you execute this cell to enable the widget!
#@markdown Make sure you execute this cell to enable the widget!
widget=interactive(update_ex_3, {'manual': True},
high_rew_p=(.0, 1., .001),
low_rew_p=(.0, 1., .001),
p_stay=(.5, 1., .001),
T=fixed(200),
threshold=fixed(.2))
widget.children[-2].description='Run Simulation'
widget.children[-2].style.button_color='lightgreen'
controls = HBox(widget.children[:-1], layout=Layout(flex_flow='row wrap'))
output = widget.children[-1]
display(VBox([controls, output]))
Section 4: Implementing a threshold policy¶
Video 4: How should you act?¶
Coding Exercise 4: Dynamics following a threshold-based policy¶
Instructions:
You need to code a new policy, def policy_threshold(threshold, belief, loc). The policy takes three inputs: a belief threshold, your belief about the fish state, and your location (\(-1\) for left, \(1\) for right). When your belief that you are on the same side as the fish drops below this threshold, you choose to switch; otherwise you stay.
You should return an action for each time \(t\), which takes the value “stay” or “switch”.
After you complete the threshold-based policy, create an object of binaryHMM_belief and set the policy parameter to choose_policy = "threshold".
We have provided an example set of parameters. You should play with the parameters to see the various dynamics.
def policy_threshold(threshold, belief, loc):
"""
chooses whether to switch side based on whether the belief
on the current site drops below the threshold
Args:
threshold (float): the threshold of belief on the current site,
when the belief is lower than the threshold, switch side
belief (numpy array of float, 2-dimensional): the belief on the
two sites at a certain time
loc (int) : the location of the agent at a certain time
-1 for left side, 1 for right side
Returns:
act (string): "stay" or "switch"
"""
############################################################################
## Insert code to:
## generate actions (Stay or Switch) for current belief and location
##
## Belief is a 2d vector: first element = Prob(fish on Left | measurements)
## second element = Prob(fish on Right | measurements)
## Returns "switch" if Belief that fish are in your current location < threshold
## "stay" otherwise
##
## Hint: use loc value to determine which row of belief you need to use
## see the docstring for more information about loc
##
## complete the function and remove
raise NotImplementedError("Student exercise: Please complete <act>")
############################################################################
# Write the if statement
if ...:
# action below threshold
act = ...
else:
# action above threshold
act = ...
return act
# Test your function
test_policy_threshold()
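If you get stuck, here is one possible completion (a sketch of ours; the official solution may differ in details):
# One possible completion (a sketch; the official solution may differ)
def policy_threshold(threshold, belief, loc):
    # belief[0] = P(fish on Left), belief[1] = P(fish on Right);
    # (loc + 1) // 2 maps loc in {-1, 1} to the belief index {0, 1}
    if belief[(loc + 1) // 2] < threshold:
        # belief that the fish share your side is below threshold
        act = "switch"
    else:
        # belief is at or above threshold
        act = "stay"
    return act

test_policy_threshold()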
Interactive Demo 4: Dynamics with different thresholds¶
def update_ex_4(p_stay=.98, threshold=.2, high_rew_p=.4, low_rew_p=.1, T=200):
"""
p_stay: probability fish stay
high_rew_p: p(catch fish) when you're on their side
low_rew_p : p(catch fish) when you're on other side
threshold: threshold of belief below which switching is taken
"""
params = [T, p_stay, high_rew_p, low_rew_p, threshold]
#### initial condition for fish [fish_initial] and you [loc_initial] ####
binaryHMM_test = binaryHMM_belief(params, fish_initial=-1, loc_initial=-1,
choose_policy="threshold")
belief, loc, act, measurement, fish_state = binaryHMM_test.generate_process()
plot_dynamics(belief, loc, act, measurement,
fish_state, binaryHMM_test.choose_policy)
Make sure you execute this cell to enable the widget!
#@markdown Make sure you execute this cell to enable the widget!
widget=interactive(update_ex_4, {'manual': True},
high_rew_p=fixed(.4),
low_rew_p=fixed(.1),
p_stay=fixed(.95),
T=fixed(200),
threshold=(.0, 1., .001))
widget.children[-2].description='Run Simulation'
widget.children[-2].style.button_color='lightgreen'
controls = HBox(widget.children[:-1], layout=Layout(flex_flow='row wrap'))
output = widget.children[-1]
display(VBox([controls, output]))
Section 5: Implementing a value function¶
Video 5: Evaluate policy¶
Coding Exercise 5: Implementing a value function¶
Now we have generated behavior for a policy parameterized by a threshold. While it seems clear that this is at least better than being lazy, we want to know how good it is. For that, we will calculate a value function. We will use this value to compare different policies, and maximize the amount of fish we catch while minimizing our effort.
Specifically, here the value is the total expected utility per unit time:
\[V = \frac{1}{T}\left\langle \sum_{t=1}^{T} \big( U_s(s_t) + U_a(a_t) \big) \right\rangle\]
where \(U_s(s_t)\) is the instantaneous utility (reward) from the site, and \(U_a(a_t)\) is the utility (negative cost) of the chosen action. Here, the action cost is 0 if you stay, and cost_sw if you switch.
We could take this average mathematically over the probabilities of rewards and actions. More simply, we get the same answer by averaging the actual rewards and costs over a long time, so that's what you should do.
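For example (this matches the test below): catching 3 fish in 10 time steps while switching 4 times at cost \(0.5\) gives a value of \((3 - 4 \times 0.5)/10 = 0.1\).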
Instructions
Fill in the function
value_function(measurement, act, cost_sw)
given a sequence of measurements, actions, and the cost of switching.Visually find the threshold that yields the highest total value. We have provided code for plotting value versus threshold. The threshold \(\theta^*\) with the highest value gives the optimal policy for controlling where you should fish.
def value_function(measurement, act, cost_sw):
"""
value function
Args:
act (numpy array of string): length T with each element
taking value "stay" or "switch"
cost_sw (float): the cost of switching side
measurement (numpy array of binary): whether a reward is obtained
Returns:
value (float): expected utility per unit time
"""
act_int = (act == "switch").astype(int)
T = len(measurement)
############################################################################
## Insert your code here to:
## compute the value function = rate of catching fish - costs
##
## complete the function and remove
raise NotImplementedError("Student exercise: Please complete <value>")
############################################################################
# Calculate the value function
value = ...
return value
# Test your function
test_value_function()
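If you get stuck, here is one possible completion (a sketch of ours; the official solution may differ in details):
# One possible completion (a sketch; the official solution may differ)
def value_function(measurement, act, cost_sw):
    act_int = (act == "switch").astype(int)
    T = len(measurement)
    # rate of catching fish minus the average switching cost per time step
    value = (np.sum(measurement) - cost_sw * np.sum(act_int)) / T
    return value

test_value_function()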
# Brute force search for optimal policy: loop over thresholds and compute value for each.
# This function is needed for the second exercise.
def value_threshold(params, cost_sw, step):
threshold_array = np.arange(0, .5 + step, step)
value_array = np.zeros(threshold_array.shape)
T, p_stay, high_rew_p, low_rew_p, _ = params
for i in range(len(threshold_array)):
threshold = threshold_array[i]
params = [T, p_stay, high_rew_p, low_rew_p, threshold]
binaryHMM_test = binaryHMM_belief(params, choose_policy="threshold")
belief, loc, act, measurement, fish_state = binaryHMM_test.generate_process()
value_array[i] = value_function(measurement, act, cost_sw)
return threshold_array, value_array
plot_value_threshold(cost_sw=0.5, p_stay=0.95, high_rew_p=0.4, low_rew_p=0.1, T=10000)
Summary¶
In this tutorial, you have used the Hidden Markov Models you learned about yesterday to figure out where the fish are, and then you have acted (moved location) to catch the most fish. In particular, you have found an optimal policy to guide your actions.
Bonus¶
Bonus Section 1: Different task, different optimal policy¶
Video 6: Sensitivity of optimal policy¶
Bonus Interactive Demo 1¶
Instructions: After plotting value versus threshold, adjust the various task parameters using the sliders below, and observe how the optimal threshold moves with:
the switching cost (cost_sw)
the fish dynamics (p_stay, where \(p_{\rm sw} = 1 - p_{\rm stay}\))
the probability of catching fish on each side (low_rew_p and high_rew_p)
Can you explain why the optimal threshold changes with these parameters?
EXPLAIN why the optimal threshold changes for:
lower switching cost?
faster fish dynamics?
rarer fish?
Note that it may require long simulations to see subtle changes in the values of different policies, so look for coarse trends first.
Make sure you execute this cell to enable the widget!
#@title
#@markdown Make sure you execute this cell to enable the widget!
widget=interactive(plot_value_threshold, {'manual': True},
T=fixed(10000),
p_stay=(0.5, 1., 0.001),
high_rew_p=(0., 1., 0.001),
low_rew_p=(0., 1., 0.001),
cost_sw=(0., 2., .1),
step=fixed(0.1))
widget.children[-2].description='Run Simulation'
widget.children[-2].style.button_color='lightgreen'
controls = HBox(widget.children[:-1], layout=Layout(flex_flow='row wrap'))
output = widget.children[-1]
display(VBox([controls, output]))