Accommodate changes to search-graph.

This commit is contained in:
2025-01-21 23:12:54 -05:00
parent 844120f0b4
commit d413b757d4
6 changed files with 365 additions and 372 deletions

View File

@@ -10,18 +10,19 @@ use crate::SearchParameters;
use log::trace;
use rand::Rng;
use rand::seq::IteratorRandom;
use search_graph::{EdgeRef, Open, Graph, VertexRef};
/// Provides a method for selecting incoming parent edges to follow during the
/// backprop phase of MCTS.
pub trait BackpropSelector<'id>: for<'a> From<&'a SearchParameters> {
type Items: Iterator<Item = search_graph::view::EdgeRef<'id>>;
type Items: Iterator<Item = EdgeRef<Open<'id>>>;
/// Returns the edges to follow when pushing statistics back up through the
/// search graph.
fn select<G: Game, R: Rng>(
&self,
graph: &search_graph::view::View<'_, 'id, G::State, VertexData, EdgeData<G>>,
node: search_graph::view::NodeRef<'id>,
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
node: VertexRef<Open<'id>>,
payoff: &G::Payoff,
rng: &mut R,
) -> Self::Items;
@@ -44,16 +45,16 @@ impl<'a> From<&'a SearchParameters> for FirstParentSelector {
}
impl<'id> BackpropSelector<'id> for FirstParentSelector {
type Items = std::option::IntoIter<search_graph::view::EdgeRef<'id>>;
type Items = std::option::IntoIter<EdgeRef<Open<'id>>>;
fn select<G: Game, R: Rng>(
&self,
graph: &search_graph::view::View<'_, 'id, G::State, VertexData, EdgeData<G>>,
node: search_graph::view::NodeRef<'id>,
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
node: VertexRef<Open<'id>>,
_payoff: &G::Payoff,
_rng: &mut R,
) -> Self::Items {
graph.parents(node).next().into_iter()
graph.vertex(node).parents.first().copied().into_iter()
}
}
@@ -72,16 +73,16 @@ impl<'a> From<&'a SearchParameters> for RandomParentSelector {
}
impl<'id> BackpropSelector<'id> for RandomParentSelector {
type Items = std::option::IntoIter<search_graph::view::EdgeRef<'id>>;
type Items = std::option::IntoIter<EdgeRef<Open<'id>>>;
fn select<G: Game, R: Rng>(
&self,
graph: &search_graph::view::View<'_, 'id, G::State, VertexData, EdgeData<G>>,
node: search_graph::view::NodeRef<'id>,
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
node: VertexRef<Open<'id>>,
_payoff: &G::Payoff,
rng: &mut R,
) -> Self::Items {
graph.parents(node).choose(rng).into_iter()
graph.vertex(node).parents.iter().copied().choose(rng).into_iter()
}
}
@@ -93,12 +94,12 @@ impl<'id> BackpropSelector<'id> for RandomParentSelector {
/// yields may be affected by statistics updates if they are applied during
/// iteration.
pub fn backprop_iter<'a, 'b, 'id, G, S, R>(
graph: &'b search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
node: search_graph::view::NodeRef<'id>,
graph: &'b Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
node: VertexRef<Open<'id>>,
payoff: &'b G::Payoff,
selector: &'b S,
rng: &'b mut R,
) -> impl Iterator<Item = search_graph::view::EdgeRef<'id>> + 'b
) -> impl Iterator<Item = EdgeRef<Open<'id>>> + 'b
where
'a: 'b,
G: Game,
@@ -109,34 +110,34 @@ where
}
/// Chains together parent traversals.
struct BackpropIter<'a, 'b, 'id, G, S, R>
struct BackpropIter<'a, 'id, G, S, R>
where
G: Game,
S: BackpropSelector<'id> + 'b,
S: BackpropSelector<'id> + 'a,
R: Rng,
{
graph: &'b search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
graph: &'a Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
/// Nodes whose parent edges to traverse.
stack: Vec<search_graph::view::NodeRef<'id>>,
stack: Vec<VertexRef<Open<'id>>>,
/// Edges from most recently examined node.
parent_edges: S::Items,
payoff: &'b G::Payoff,
selector: &'b S,
rng: &'b mut R,
payoff: &'a G::Payoff,
selector: &'a S,
rng: &'a mut R,
}
impl<'a, 'b, 'id, G, S, R> BackpropIter<'a, 'b, 'id, G, S, R>
impl<'a, 'id, G, S, R> BackpropIter<'a, 'id, G, S, R>
where
G: Game,
S: BackpropSelector<'id> + 'b,
S: BackpropSelector<'id> + 'a,
R: Rng,
{
fn new(
graph: &'b search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
node: search_graph::view::NodeRef<'id>,
payoff: &'b G::Payoff,
selector: &'b S,
rng: &'b mut R,
graph: &'a Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
node: VertexRef<Open<'id>>,
payoff: &'a G::Payoff,
selector: &'a S,
rng: &'a mut R,
) -> Self {
let parent_edges = selector.select(graph, node, payoff, rng);
BackpropIter {
@@ -150,18 +151,18 @@ where
}
}
impl<'a, 'b, 'id, G, S, R> Iterator for BackpropIter<'a, 'b, 'id, G, S, R>
impl<'a, 'id, G, S, R> Iterator for BackpropIter<'a, 'id, G, S, R>
where
G: Game,
S: BackpropSelector<'id> + 'b,
S: BackpropSelector<'id> + 'a,
R: Rng,
{
type Item = search_graph::view::EdgeRef<'id>;
type Item = EdgeRef<Open<'id>>;
fn next(&mut self) -> Option<Self::Item> {
while let Some(parent) = self.parent_edges.next() {
if !self.graph.edge_data(parent).mark_backprop_traversal() {
self.stack.push(self.graph.edge_source(parent));
if !self.graph.edge(parent).data.mark_backprop_traversal() {
self.stack.push(self.graph.edge(parent).source);
return Some(parent);
}
}
@@ -171,8 +172,8 @@ where
.selector
.select(self.graph, node, self.payoff, self.rng);
while let Some(parent) = self.parent_edges.next() {
if !self.graph.edge_data(parent).mark_backprop_traversal() {
self.stack.push(self.graph.edge_source(parent));
if !self.graph.edge(parent).data.mark_backprop_traversal() {
self.stack.push(self.graph.edge(parent).source);
return Some(parent);
}
}
@@ -187,23 +188,23 @@ where
/// Iterable view over parents of a graph node that selects parents for which
/// this node is a best child.
pub struct ParentSelectionIter<'a, 'b, 'id, G, I>
pub struct ParentSelectionIter<'a, 'id, G, I>
where
G: Game,
I: Iterator<Item = search_graph::view::EdgeRef<'id>>,
I: Iterator<Item = EdgeRef<Open<'id>>>,
{
graph: &'b search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
graph: &'a Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
parents: I,
explore_bias: f64,
}
impl<'a, 'b, 'id, G, I> ParentSelectionIter<'a, 'b, 'id, G, I>
impl<'a, 'id, G, I> ParentSelectionIter<'a, 'id, G, I>
where
G: Game,
I: Iterator<Item = search_graph::view::EdgeRef<'id>>,
I: Iterator<Item = EdgeRef<Open<'id>>>,
{
pub fn new(
graph: &'b search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
graph: &'a Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
parents: I,
explore_bias: f64,
) -> Self {
@@ -215,19 +216,19 @@ where
}
}
impl<'a, 'b, 'id, G, I> Iterator for ParentSelectionIter<'a, 'b, 'id, G, I>
impl<'a, 'id, G, I> Iterator for ParentSelectionIter<'a, 'id, G, I>
where
G: Game,
I: Iterator<Item = search_graph::view::EdgeRef<'id>>,
I: Iterator<Item = EdgeRef<Open<'id>>>,
{
type Item = search_graph::view::EdgeRef<'id>;
type Item = EdgeRef<Open<'id>>;
fn next(&mut self) -> Option<Self::Item> {
loop {
match self.parents.next() {
None => return None,
Some(e) => {
if self.graph.edge_data(e).mark_backprop_traversal() {
if self.graph.edge(e).data.mark_backprop_traversal() {
trace!("ParentSelectionIter::next: edge was already visited in backtrace",);
continue;
}
@@ -235,14 +236,14 @@ where
trace!(
"ParentSelectionIter::next: edge {:?} (from node {:?}) is a best child",
e,
self.graph.edge_source(e),
self.graph.edge(e).source,
);
return Some(e);
}
trace!(
"ParentSelectionIter::next: edge {:?} (data {:?}) is not a best child",
e,
self.graph.edge_data(e),
self.graph.edge(e).data,
);
}
}
@@ -257,8 +258,8 @@ where
/// `selector`. Ancestors are selected by recursively applying `selector` to the
/// parent that it selects.
pub fn backprop<'a, 'id, G, S, R>(
graph: &search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
node: search_graph::view::NodeRef<'id>,
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
node: VertexRef<Open<'id>>,
payoff: &G::Payoff,
selector: &S,
rng: &mut R,
@@ -271,7 +272,7 @@ pub fn backprop<'a, 'id, G, S, R>(
// updating the statistics alters best child status, and backprop_iter returns
// a lazy iterator.
let statistics: Vec<&EdgeData<G>> = backprop_iter(graph, node, payoff, selector, rng)
.map(|edge| graph.edge_data(edge))
.map(|edge| &graph.edge(edge).data)
.collect();
for s in statistics {
s.statistics.increment(payoff);

View File

@@ -22,6 +22,7 @@ use std::result::Result;
use log::trace;
use rand::Rng;
use search_graph::{Closed, Open, Graph, VertexRef};
/// Wraps a decision made by a rollout policy.
///
@@ -63,8 +64,8 @@ pub struct ActionStatistics<G: Game> {
/// Creates a new search graph suitable for Monte Carlo tree search through the
/// state space of the game `G`.
pub fn new_search_graph<G: Game>() -> search_graph::Graph<G::State, VertexData, EdgeData<G>> {
search_graph::Graph::<G::State, VertexData, EdgeData<G>>::new()
pub fn new_search_graph<G: Game>() -> search_graph::Graph<Closed, G::State, VertexData, EdgeData<G>> {
Graph::default()
}
/// Parameters for a round of Monte Carlo tree search.
@@ -80,38 +81,34 @@ pub struct SearchParameters {
/// Recursively traverses the search graph to find a game state from which to
/// perform payoff estimates.
pub struct RolloutPhase<'a, 'id, R: Rng, G: Game> {
pub struct RolloutPhase<'id, R: Rng> {
rng: R,
settings: SearchParameters,
graph: search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
root_node: search_graph::view::NodeRef<'id>,
root_node: VertexRef<Open<'id>>,
}
impl<'a, 'id, R: Rng, G: Game> RolloutPhase<'a, 'id, R, G> {
pub fn initialize(
impl<'id, R: Rng> RolloutPhase<'id, R> {
pub fn initialize<G: Game>(
rng: R,
settings: SearchParameters,
root_state: G::State,
mut graph: search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
graph: &mut Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
) -> Self {
trace!("initializing rollout phase to state: {:?}", root_state);
let root_node = match graph.find_node(&root_state) {
Some(n) => n,
None => graph.append_node(root_state.clone(), VertexData::default()),
};
let root_node = graph.find_or_add_vertex(root_state.clone(), VertexData::default);
RolloutPhase {
rng,
settings,
graph,
root_node,
}
}
pub fn rollout<S: RolloutSelector>(
pub fn rollout<G: Game, S: RolloutSelector>(
mut self,
) -> Result<ScoringPhase<'a, 'id, R, G>, rollout::RolloutError<G, S::Error>> {
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
) -> Result<ScoringPhase<'id, R>, rollout::RolloutError<G, S::Error>> {
let result = rollout::rollout(
&self.graph,
graph,
self.root_node,
S::from(&self.settings),
&mut self.rng,
@@ -121,52 +118,50 @@ impl<'a, 'id, R: Rng, G: Game> RolloutPhase<'a, 'id, R, G> {
trace!("rollout result has target node: {:?}", node);
trace!(
"rollout result has state: {:?}",
self.graph.node_state(node)
graph.game_state_at(node)
);
ScoringPhase {
rng: self.rng,
settings: self.settings,
graph: self.graph,
root_node: self.root_node,
rollout_node: node,
}
})
}
pub fn root_node(&self) -> search_graph::view::NodeRef<'id> {
pub fn root_node(&self) -> VertexRef<Open<'id>> {
self.root_node
}
pub fn recover_components(
pub fn recover_rng(
self,
) -> (
R,
search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
) {
(self.rng, self.graph)
) -> R {
self.rng
}
}
/// Computes an estimate of the score for a game state selected during rollout.
pub struct ScoringPhase<'a, 'id, R: Rng, G: Game> {
pub struct ScoringPhase<'id, R: Rng> {
rng: R,
settings: SearchParameters,
graph: search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
root_node: search_graph::view::NodeRef<'id>,
rollout_node: search_graph::view::NodeRef<'id>,
root_node: VertexRef<Open<'id>>,
rollout_node: VertexRef<Open<'id>>,
}
impl<'a, 'id, R: Rng, G: Game> ScoringPhase<'a, 'id, R, G> {
pub fn root_node(&self) -> search_graph::view::NodeRef<'id> {
impl<'id, R: Rng> ScoringPhase<'id, R> {
pub fn root_node(&self) -> VertexRef<Open<'id>> {
self.root_node
}
pub fn rollout_node(&self) -> search_graph::view::NodeRef<'id> {
pub fn rollout_node(&self) -> VertexRef<Open<'id>> {
self.rollout_node
}
pub fn score<S: Simulator>(mut self) -> Result<BackpropPhase<'a, 'id, R, G>, S::Error> {
let payoff = match G::payoff_of(self.graph.node_state(self.rollout_node())) {
pub fn score<G: Game, S: Simulator>(
mut self,
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
) -> Result<BackpropPhase<'id, R, G>, S::Error> {
let payoff = match G::payoff_of(graph.game_state_at(self.rollout_node())) {
Some(p) => {
trace!("direct payoff found: {:?}", p);
p
@@ -174,14 +169,13 @@ impl<'a, 'id, R: Rng, G: Game> ScoringPhase<'a, 'id, R, G> {
None => {
trace!("simulating to find payoff");
let simulator = S::from(&self.settings);
simulator.simulate::<G, R>(self.graph.node_state(self.rollout_node()), &mut self.rng)?
simulator.simulate::<G, R>(graph.game_state_at(self.rollout_node()), &mut self.rng)?
}
};
trace!("scoring phase finds payoff {:?}", payoff);
Ok(BackpropPhase {
rng: self.rng,
settings: self.settings,
graph: self.graph,
root_node: self.root_node,
rollout_node: self.rollout_node,
payoff,
@@ -193,19 +187,21 @@ impl<'a, 'id, R: Rng, G: Game> ScoringPhase<'a, 'id, R, G> {
///
/// The strategy for finding the game-state statistics to update during backprop
/// is determined by `BackpropSelector`.
pub struct BackpropPhase<'a, 'id, R: Rng, G: Game> {
pub struct BackpropPhase<'id, R: Rng, G: Game> {
rng: R,
settings: SearchParameters,
graph: search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
root_node: search_graph::view::NodeRef<'id>,
rollout_node: search_graph::view::NodeRef<'id>,
root_node: VertexRef<Open<'id>>,
rollout_node: VertexRef<Open<'id>>,
payoff: G::Payoff,
}
impl<'a, 'id, R: Rng, G: Game> BackpropPhase<'a, 'id, R, G> {
pub fn backprop<S: BackpropSelector<'id>>(mut self) -> ExpandPhase<'a, 'id, R, G> {
impl<'id, R: Rng, G: Game> BackpropPhase<'id, R, G> {
pub fn backprop<S: BackpropSelector<'id>>(
mut self,
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
) -> ExpandPhase<'id, R> {
backprop::backprop(
&self.graph,
graph,
self.rollout_node,
&self.payoff,
&S::from(&self.settings),
@@ -214,7 +210,6 @@ impl<'a, 'id, R: Rng, G: Game> BackpropPhase<'a, 'id, R, G> {
ExpandPhase {
rng: self.rng,
settings: self.settings,
graph: self.graph,
root_node: self.root_node,
rollout_node: self.rollout_node,
}
@@ -227,45 +222,35 @@ impl<'a, 'id, R: Rng, G: Game> BackpropPhase<'a, 'id, R, G> {
/// state, adding them to the graph if they don't already exist, and then
/// creating an edge from the original node to the node for the resulting game
/// state.
pub struct ExpandPhase<'a, 'id, R: Rng, G: Game> {
pub struct ExpandPhase<'id, R: Rng> {
rng: R,
settings: SearchParameters,
graph: search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
root_node: search_graph::view::NodeRef<'id>,
rollout_node: search_graph::view::NodeRef<'id>,
root_node: VertexRef<Open<'id>>,
rollout_node: VertexRef<Open<'id>>,
}
impl<'a, 'id, R: Rng, G: Game> ExpandPhase<'a, 'id, R, G> {
pub fn expand(mut self) -> RolloutPhase<'a, 'id, R, G> {
if self.graph.node_data(self.rollout_node).mark_expanded() {
impl<'id, R: Rng> ExpandPhase<'id, R> {
pub fn expand<G: Game>(
self,
graph: &mut Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
) -> RolloutPhase<'id, R> {
if graph.vertex(self.rollout_node).data.mark_expanded() {
trace!("rollout node was already marked as expanded; ExpandPhase does nothing");
} else {
let parent_state = self.graph.node_state(self.rollout_node).clone();
let parent_state = graph.game_state_at(self.rollout_node).clone();
for action in parent_state.actions() {
trace!("ExpandPhase adds edge for action {:?}", action);
let mut child_state = parent_state.clone();
trace!("ExpandState old state: {:?}", child_state);
child_state.do_action(&action);
trace!("ExpandState new state: {:?}", child_state);
let child = match self.graph.find_node(&child_state) {
Some(n) => {
trace!("ExpandState expanded to existing game state");
n
}
None => {
trace!("ExpandState expanded to new game state");
self.graph.append_node(child_state, Default::default())
}
};
self
.graph
.append_edge(self.rollout_node, child, EdgeData::new(action));
let child = graph.find_or_add_vertex(child_state, Default::default);
graph.find_or_add_edge(self.rollout_node, child, || EdgeData::new(action));
}
}
RolloutPhase {
rng: self.rng,
settings: self.settings,
graph: self.graph,
root_node: self.root_node,
}
}
@@ -298,43 +283,38 @@ mod test {
}
}
type Graph = search_graph::Graph<tictactoe::State, VertexData, EdgeData<tictactoe::ScoredGame>>;
type Graph = search_graph::Graph<search_graph::Closed, tictactoe::State, VertexData, EdgeData<tictactoe::ScoredGame>>;
#[test]
fn rollout_init() {
let mut graph = Graph::new();
search_graph::view::of_graph(&mut graph, |view| {
Graph::new().open(|mut graph| {
RolloutPhase::initialize(
default_rng(),
default_settings(),
default_game_state(),
view,
&mut graph,
);
});
let node = graph.find_node(&default_game_state());
assert!(node.is_some());
let node = node.unwrap();
assert!(node.is_leaf());
assert!(node.is_root());
assert_eq!(default_game_state(), *node.get_label());
let node = graph.find_vertex_with_state(&default_game_state()).unwrap();
let v = graph.vertex(node);
assert!(v.children.is_empty());
assert!(v.parents.is_empty());
assert_eq!(default_game_state(), *graph.game_state_at(node));
});
}
#[test]
fn integration_test() {
let mut graph = Graph::new();
search_graph::view::of_graph(&mut graph, |view| {
Graph::new().open(|mut graph| {
// Rollout.
let rollout_phase = RolloutPhase::initialize(
default_rng(),
default_settings(),
default_game_state(),
view,
&mut graph,
);
let rollout_target = crate::rollout::rollout(
&rollout_phase.graph,
&graph,
rollout_phase.root_node,
ucb::Rollout::from(&default_settings()),
&mut default_rng(),
@@ -343,301 +323,303 @@ mod test {
assert_eq!(rollout_phase.root_node(), rollout_target);
assert_eq!(
tictactoe::State::default(),
*rollout_phase.graph.node_state(rollout_target)
*graph.game_state_at(rollout_target)
);
// Scoring.
let scoring_phase: ScoringPhase<'_, '_, _, tictactoe::ScoredGame> =
rollout_phase.rollout::<ucb::Rollout>().unwrap();
let scoring_phase: ScoringPhase<'_, _> =
rollout_phase.rollout::<_, ucb::Rollout>(&graph).unwrap();
assert_eq!(scoring_phase.rollout_node(), rollout_target);
assert_eq!(scoring_phase.root_node(), rollout_target);
// Backprop.
let backprop_phase: BackpropPhase<'_, '_, _, tictactoe::ScoredGame> = scoring_phase
.score::<simulation::RandomSimulator>()
let backprop_phase: BackpropPhase<'_, _, tictactoe::ScoredGame> = scoring_phase
.score::<_, simulation::RandomSimulator>(&graph)
.unwrap();
// Expand.
let expand_phase: ExpandPhase<'_, '_, _, tictactoe::ScoredGame> =
backprop_phase.backprop::<backprop::FirstParentSelector>();
expand_phase.expand();
});
let expand_phase: ExpandPhase<'_, _> =
backprop_phase.backprop::<backprop::FirstParentSelector>(&graph);
expand_phase.expand(&mut graph);
{
// Search graph should consist of root, edges for each possible move, and
// leaves for the result of each move.
let node = graph.find_node(&Default::default()).unwrap();
assert!(node.is_root());
assert!(!node.is_leaf());
assert_eq!(9, node.get_child_list().len());
for child in node.get_child_list().iter() {
assert_eq!(0, child.get_data().statistics.visits());
let node = graph.find_vertex_with_state(&Default::default()).unwrap();
let v = graph.vertex(node);
assert!(v.parents.is_empty());
assert!(!v.children.is_empty());
assert_eq!(9, v.children.len());
for child in v.children.iter().cloned() {
assert_eq!(0, graph.edge(child).data.statistics.visits());
assert_eq!(
0,
child
.get_data()
graph.edge(child)
.data
.statistics
.score(crate::statistics::two_player::Player::One)
);
assert_eq!(
0,
child
.get_data()
graph
.edge(child)
.data
.statistics
.score(crate::statistics::two_player::Player::Two)
);
assert!(child.get_target().is_leaf());
assert!(!child.get_target().is_root());
assert!(graph.vertex(graph.edge(child).target).children.is_empty());
assert!(!graph.vertex(graph.edge(child).target).parents.is_empty());
}
}
search_graph::view::of_graph(&mut graph, |view| {
RolloutPhase::initialize(
default_rng(),
default_settings(),
default_game_state(),
view,
&mut graph,
)
.rollout::<ucb::Rollout>()
.rollout::<_, ucb::Rollout>(&graph)
.unwrap()
.score::<simulation::RandomSimulator>()
.score::<_, simulation::RandomSimulator>(&graph)
.unwrap()
.backprop::<backprop::FirstParentSelector>()
.expand();
});
.backprop::<backprop::FirstParentSelector>(&mut graph)
.expand(&mut graph);
{
// Two levels of should be expanded.
// Two levels should be expanded.
assert_eq!(18, graph.vertex_count());
assert_eq!(17, graph.edge_count());
let node = graph.find_node(&Default::default()).unwrap();
assert!(node.is_root());
assert!(!node.is_leaf());
assert_eq!(9, node.get_child_list().len());
for (i, child) in node.get_child_list().iter().enumerate() {
let node = graph.find_vertex_with_state(&Default::default()).unwrap();
let v = graph.vertex(node);
assert!(v.parents.is_empty());
assert!(!v.children.is_empty());
assert_eq!(9, v.children.len());
for (i, child) in v.children.iter().copied().enumerate() {
if i == 7 {
assert_eq!(1, child.get_data().statistics.visits());
assert_eq!(1, graph.edge(child).data.statistics.visits());
assert_eq!(
0,
child
.get_data()
1,
graph
.edge(child)
.data
.statistics
.score(crate::statistics::two_player::Player::One)
);
assert_eq!(
1,
child
.get_data()
0,
graph
.edge(child)
.data
.statistics
.score(crate::statistics::two_player::Player::Two)
);
} else {
assert_eq!(0, child.get_data().statistics.visits());
assert_eq!(0, graph.edge(child).data.statistics.visits());
assert_eq!(
0,
child
.get_data()
graph
.edge(child)
.data
.statistics
.score(crate::statistics::two_player::Player::One)
);
assert_eq!(
0,
child
.get_data()
graph
.edge(child)
.data
.statistics
.score(crate::statistics::two_player::Player::Two)
);
}
}
}
search_graph::view::of_graph(&mut graph, |view| {
RolloutPhase::initialize(
default_rng(),
default_settings(),
default_game_state(),
view,
&mut graph,
)
.rollout::<ucb::Rollout>()
.rollout::<_, ucb::Rollout>(&graph)
.unwrap()
.score::<simulation::RandomSimulator>()
.score::<_, simulation::RandomSimulator>(&graph)
.unwrap()
.backprop::<backprop::FirstParentSelector>()
.expand();
.backprop::<backprop::FirstParentSelector>(&mut graph)
.expand(&mut graph);
});
}
#[test]
fn integration_test_first_parent_selector() {
let mut graph = Graph::new();
search_graph::view::of_graph(&mut graph, |view| {
Graph::new().open(|mut graph| {
let mut rollout = RolloutPhase::initialize(
default_rng(),
default_settings(),
default_game_state(),
view,
&mut graph,
);
for _ in 0..10000 {
let score = rollout.rollout::<ucb::Rollout>().unwrap();
let backprop = score.score::<simulation::RandomSimulator>().unwrap();
let expand = backprop.backprop::<backprop::FirstParentSelector>();
rollout = expand.expand();
let score = rollout.rollout::<_, ucb::Rollout>(&graph).unwrap();
let backprop = score.score::<_, simulation::RandomSimulator>(&graph).unwrap();
let expand = backprop.backprop::<backprop::FirstParentSelector>(&graph);
rollout = expand.expand(&mut graph);
}
assert_eq!(755, graph.vertex_count());
assert_eq!(947, graph.edge_count());
let child_statistics: Vec<&crate::statistics::two_player::ScoredStatistics<_>> =
graph.vertex(graph
.find_vertex_with_state(&default_game_state())
.unwrap())
.children
.iter()
.copied()
.map(|c| &graph.edge(c).data.statistics)
.collect();
assert_eq!(9, child_statistics.len());
assert_eq!(7, child_statistics[0].visits());
assert_eq!(2, child_statistics[0].score(crate::statistics::two_player::Player::One));
assert_eq!(2, child_statistics[0].score(crate::statistics::two_player::Player::Two));
assert_eq!(14, child_statistics[1].visits());
assert_eq!(9, child_statistics[1].score(crate::statistics::two_player::Player::One));
assert_eq!(5, child_statistics[1].score(crate::statistics::two_player::Player::Two));
assert_eq!(13, child_statistics[2].visits());
assert_eq!(9, child_statistics[2].score(crate::statistics::two_player::Player::One));
assert_eq!(4, child_statistics[2].score(crate::statistics::two_player::Player::Two));
assert_eq!(15, child_statistics[3].visits());
assert_eq!(7, child_statistics[3].score(crate::statistics::two_player::Player::One));
assert_eq!(7, child_statistics[3].score(crate::statistics::two_player::Player::Two));
assert_eq!(32, child_statistics[4].visits());
assert_eq!(20, child_statistics[4].score(crate::statistics::two_player::Player::One));
assert_eq!(7, child_statistics[4].score(crate::statistics::two_player::Player::Two));
assert_eq!(6, child_statistics[5].visits());
assert_eq!(2, child_statistics[5].score(crate::statistics::two_player::Player::One));
assert_eq!(3, child_statistics[5].score(crate::statistics::two_player::Player::Two));
assert_eq!(3, child_statistics[6].visits());
assert_eq!(0, child_statistics[6].score(crate::statistics::two_player::Player::One));
assert_eq!(2, child_statistics[6].score(crate::statistics::two_player::Player::Two));
assert_eq!(9, child_statistics[7].visits());
assert_eq!(4, child_statistics[7].score(crate::statistics::two_player::Player::One));
assert_eq!(4, child_statistics[7].score(crate::statistics::two_player::Player::Two));
assert_eq!(9900, child_statistics[8].visits());
assert_eq!(43, child_statistics[8].score(crate::statistics::two_player::Player::One));
assert_eq!(29, child_statistics[8].score(crate::statistics::two_player::Player::Two));
});
assert_eq!(3295, graph.vertex_count());
assert_eq!(5361, graph.edge_count());
let child_statistics: Vec<&crate::statistics::two_player::ScoredStatistics<_>> =
graph
.find_node(&default_game_state())
.unwrap()
.get_child_list()
.iter()
.map(|c| &c.get_data().statistics)
.collect();
assert_eq!(108, child_statistics[0].visits());
assert_eq!(28, child_statistics[0].score(crate::statistics::two_player::Player::One));
assert_eq!(78, child_statistics[0].score(crate::statistics::two_player::Player::Two));
assert_eq!(209, child_statistics[1].visits());
assert_eq!(153, child_statistics[1].score(crate::statistics::two_player::Player::One));
assert_eq!(50, child_statistics[1].score(crate::statistics::two_player::Player::Two));
assert_eq!(754, child_statistics[2].visits());
assert_eq!(557, child_statistics[2].score(crate::statistics::two_player::Player::One));
assert_eq!(178, child_statistics[2].score(crate::statistics::two_player::Player::Two));
assert_eq!(2023, child_statistics[3].visits());
assert_eq!(1548, child_statistics[3].score(crate::statistics::two_player::Player::One));
assert_eq!(459, child_statistics[3].score(crate::statistics::two_player::Player::Two));
assert_eq!(2105, child_statistics[4].visits());
assert_eq!(834, child_statistics[4].score(crate::statistics::two_player::Player::One));
assert_eq!(1181, child_statistics[4].score(crate::statistics::two_player::Player::Two));
assert_eq!(1172, child_statistics[5].visits());
assert_eq!(574, child_statistics[5].score(crate::statistics::two_player::Player::One));
assert_eq!(590, child_statistics[5].score(crate::statistics::two_player::Player::Two));
assert_eq!(953, child_statistics[6].visits());
assert_eq!(483, child_statistics[6].score(crate::statistics::two_player::Player::One));
assert_eq!(282, child_statistics[6].score(crate::statistics::two_player::Player::Two));
assert_eq!(806, child_statistics[7].visits());
assert_eq!(382, child_statistics[7].score(crate::statistics::two_player::Player::One));
assert_eq!(410, child_statistics[7].score(crate::statistics::two_player::Player::Two));
assert_eq!(1869, child_statistics[8].visits());
assert_eq!(678, child_statistics[8].score(crate::statistics::two_player::Player::One));
assert_eq!(913, child_statistics[8].score(crate::statistics::two_player::Player::Two));
}
#[test]
fn integration_test_random_parent_selector() {
let mut graph = Graph::new();
search_graph::view::of_graph(&mut graph, |view| {
Graph::new().open(|mut graph| {
let mut rollout = RolloutPhase::initialize(
default_rng(),
default_settings(),
default_game_state(),
view,
&mut graph,
);
for _ in 0..10000 {
let score = rollout.rollout::<ucb::Rollout>().unwrap();
let backprop = score.score::<simulation::RandomSimulator>().unwrap();
let expand = backprop.backprop::<backprop::RandomParentSelector>();
rollout = expand.expand();
let score = rollout.rollout::<_, ucb::Rollout>(&graph).unwrap();
let backprop = score.score::<_, simulation::RandomSimulator>(&graph).unwrap();
let expand = backprop.backprop::<backprop::RandomParentSelector>(&mut graph);
rollout = expand.expand(&mut graph);
}
});
assert_eq!(5795, graph.vertex_count());
assert_eq!(13874, graph.edge_count());
assert_eq!(4946, graph.vertex_count());
assert_eq!(10077, graph.edge_count());
let child_statistics: Vec<&crate::statistics::two_player::ScoredStatistics<_>> =
graph
.find_node(&default_game_state())
.unwrap()
.get_child_list()
.iter()
.map(|c| &c.get_data().statistics)
.vertex(graph
.find_vertex_with_state(&default_game_state())
.unwrap())
.children
.iter()
.copied()
.map(|c| &graph.edge(c).data.statistics)
.collect();
assert_eq!(1204, child_statistics[0].visits());
assert_eq!(708, child_statistics[0].score(crate::statistics::two_player::Player::One));
assert_eq!(347, child_statistics[0].score(crate::statistics::two_player::Player::Two));
assert_eq!(790, child_statistics[1].visits());
assert_eq!(396, child_statistics[1].score(crate::statistics::two_player::Player::One));
assert_eq!(266, child_statistics[1].score(crate::statistics::two_player::Player::Two));
assert_eq!(983, child_statistics[2].visits());
assert_eq!(553, child_statistics[2].score(crate::statistics::two_player::Player::One));
assert_eq!(249, child_statistics[2].score(crate::statistics::two_player::Player::Two));
assert_eq!(829, child_statistics[3].visits());
assert_eq!(471, child_statistics[3].score(crate::statistics::two_player::Player::One));
assert_eq!(258, child_statistics[3].score(crate::statistics::two_player::Player::Two));
assert_eq!(2015, child_statistics[4].visits());
assert_eq!(1224, child_statistics[4].score(crate::statistics::two_player::Player::One));
assert_eq!(388, child_statistics[4].score(crate::statistics::two_player::Player::Two));
assert_eq!(795, child_statistics[5].visits());
assert_eq!(449, child_statistics[5].score(crate::statistics::two_player::Player::One));
assert_eq!(242, child_statistics[5].score(crate::statistics::two_player::Player::Two));
assert_eq!(1035, child_statistics[6].visits());
assert_eq!(578, child_statistics[6].score(crate::statistics::two_player::Player::One));
assert_eq!(287, child_statistics[6].score(crate::statistics::two_player::Player::Two));
assert_eq!(1002, child_statistics[7].visits());
assert_eq!(581, child_statistics[7].score(crate::statistics::two_player::Player::One));
assert_eq!(271, child_statistics[7].score(crate::statistics::two_player::Player::Two));
assert_eq!(1346, child_statistics[8].visits());
assert_eq!(798, child_statistics[8].score(crate::statistics::two_player::Player::One));
assert_eq!(399, child_statistics[8].score(crate::statistics::two_player::Player::Two));
assert_eq!(914, child_statistics[0].visits());
assert_eq!(361, child_statistics[0].score(crate::statistics::two_player::Player::One));
assert_eq!(184, child_statistics[0].score(crate::statistics::two_player::Player::Two));
assert_eq!(1143, child_statistics[1].visits());
assert_eq!(426, child_statistics[1].score(crate::statistics::two_player::Player::One));
assert_eq!(216, child_statistics[1].score(crate::statistics::two_player::Player::Two));
assert_eq!(1039, child_statistics[2].visits());
assert_eq!(305, child_statistics[2].score(crate::statistics::two_player::Player::One));
assert_eq!(154, child_statistics[2].score(crate::statistics::two_player::Player::Two));
assert_eq!(770, child_statistics[3].visits());
assert_eq!(258, child_statistics[3].score(crate::statistics::two_player::Player::One));
assert_eq!(125, child_statistics[3].score(crate::statistics::two_player::Player::Two));
assert_eq!(3273, child_statistics[4].visits());
assert_eq!(1154, child_statistics[4].score(crate::statistics::two_player::Player::One));
assert_eq!(384, child_statistics[4].score(crate::statistics::two_player::Player::Two));
assert_eq!(705, child_statistics[5].visits());
assert_eq!(261, child_statistics[5].score(crate::statistics::two_player::Player::One));
assert_eq!(155, child_statistics[5].score(crate::statistics::two_player::Player::Two));
assert_eq!(648, child_statistics[6].visits());
assert_eq!(270, child_statistics[6].score(crate::statistics::two_player::Player::One));
assert_eq!(162, child_statistics[6].score(crate::statistics::two_player::Player::Two));
assert_eq!(795, child_statistics[7].visits());
assert_eq!(225, child_statistics[7].score(crate::statistics::two_player::Player::One));
assert_eq!(128, child_statistics[7].score(crate::statistics::two_player::Player::Two));
assert_eq!(712, child_statistics[8].visits());
assert_eq!(212, child_statistics[8].score(crate::statistics::two_player::Player::One));
assert_eq!(168, child_statistics[8].score(crate::statistics::two_player::Player::Two));
});
}
// Integration test for `ucb::BestParentBackprop`: starting from
// `default_game_state()`, runs 10000 iterations of the
// rollout -> score -> backprop -> expand phase chain, then compares the graph
// size and the statistics of the child edges of the vertex holding
// `default_game_state()` against literal expected values.
//
// NOTE(review): this span looks like a rendered diff without +/- markers, so
// lines from two revisions are interleaved. Presumably the
// `search_graph::view` calls (`of_graph`, `find_node`, `get_child_list`,
// `get_data`) are the pre-change side and the `Graph::open` / `vertex` /
// `edge` calls are the post-change side -- confirm against the real file
// before relying on either.
#[test]
fn integration_test_best_parent_selector() {
let mut graph = Graph::new();
search_graph::view::of_graph(&mut graph, |view| {
Graph::new().open(|mut graph| {
let mut rollout = RolloutPhase::initialize(
default_rng(),
default_settings(),
default_game_state(),
view,
&mut graph,
);
// Each pass ends with `expand`, whose result becomes the next `rollout`
// phase. Both the argument-free calls and the calls that take `graph`
// explicitly appear below.
for _ in 0..10000 {
let score = rollout.rollout::<ucb::Rollout>().unwrap();
let backprop = score.score::<simulation::RandomSimulator>().unwrap();
let expand = backprop.backprop::<ucb::BestParentBackprop>();
rollout = expand.expand();
let score = rollout.rollout::<_, ucb::Rollout>(&mut graph).unwrap();
let backprop = score.score::<_, simulation::RandomSimulator>(&graph).unwrap();
let expand = backprop.backprop::<ucb::BestParentBackprop>(&mut graph);
rollout = expand.expand(&mut graph);
}
});
// Graph size after the search. Two pairs of expected counts are present
// (NOTE(review): presumably one pair per revision -- confirm).
assert_eq!(4471, graph.vertex_count());
assert_eq!(10889, graph.edge_count());
assert_eq!(5601, graph.vertex_count());
assert_eq!(14471, graph.edge_count());
// Collect a reference to the statistics of every child edge of the vertex
// looked up by `default_game_state()`.
let child_statistics: Vec<&crate::statistics::two_player::ScoredStatistics<_>> =
graph
.find_node(&default_game_state())
.unwrap()
.get_child_list()
.vertex(graph
.find_vertex_with_state(&default_game_state())
.unwrap())
.children
.iter()
.map(|c| &c.get_data().statistics)
.copied()
.map(|c| &graph.edge(c).data.statistics)
.collect();
// First set of expectations: visits and per-player scores for children 0..=8.
assert_eq!(229, child_statistics[0].visits());
assert_eq!(138, child_statistics[0].score(crate::statistics::two_player::Player::One));
assert_eq!(53, child_statistics[0].score(crate::statistics::two_player::Player::Two));
assert_eq!(73, child_statistics[1].visits());
assert_eq!(33, child_statistics[1].score(crate::statistics::two_player::Player::One));
assert_eq!(30, child_statistics[1].score(crate::statistics::two_player::Player::Two));
assert_eq!(322, child_statistics[2].visits());
assert_eq!(203, child_statistics[2].score(crate::statistics::two_player::Player::One));
assert_eq!(83, child_statistics[2].score(crate::statistics::two_player::Player::Two));
assert_eq!(90, child_statistics[3].visits());
assert_eq!(44, child_statistics[3].score(crate::statistics::two_player::Player::One));
assert_eq!(32, child_statistics[3].score(crate::statistics::two_player::Player::Two));
assert_eq!(9033, child_statistics[4].visits());
assert_eq!(7693, child_statistics[4].score(crate::statistics::two_player::Player::One));
assert_eq!(945, child_statistics[4].score(crate::statistics::two_player::Player::Two));
assert_eq!(74, child_statistics[5].visits());
assert_eq!(34, child_statistics[5].score(crate::statistics::two_player::Player::One));
assert_eq!(35, child_statistics[5].score(crate::statistics::two_player::Player::Two));
assert_eq!(118, child_statistics[6].visits());
assert_eq!(62, child_statistics[6].score(crate::statistics::two_player::Player::One));
assert_eq!(43, child_statistics[6].score(crate::statistics::two_player::Player::Two));
assert_eq!(98, child_statistics[7].visits());
assert_eq!(49, child_statistics[7].score(crate::statistics::two_player::Player::One));
assert_eq!(36, child_statistics[7].score(crate::statistics::two_player::Player::Two));
assert_eq!(180, child_statistics[8].visits());
assert_eq!(104, child_statistics[8].score(crate::statistics::two_player::Player::One));
assert_eq!(41, child_statistics[8].score(crate::statistics::two_player::Player::Two));
// Second set of expectations for the same nine children.
assert_eq!(888, child_statistics[0].visits());
assert_eq!(325, child_statistics[0].score(crate::statistics::two_player::Player::One));
assert_eq!(225, child_statistics[0].score(crate::statistics::two_player::Player::Two));
assert_eq!(330, child_statistics[1].visits());
assert_eq!(138, child_statistics[1].score(crate::statistics::two_player::Player::One));
assert_eq!(123, child_statistics[1].score(crate::statistics::two_player::Player::Two));
assert_eq!(918, child_statistics[2].visits());
assert_eq!(340, child_statistics[2].score(crate::statistics::two_player::Player::One));
assert_eq!(236, child_statistics[2].score(crate::statistics::two_player::Player::Two));
assert_eq!(792, child_statistics[3].visits());
assert_eq!(323, child_statistics[3].score(crate::statistics::two_player::Player::One));
assert_eq!(239, child_statistics[3].score(crate::statistics::two_player::Player::Two));
assert_eq!(3817, child_statistics[4].visits());
assert_eq!(1054, child_statistics[4].score(crate::statistics::two_player::Player::One));
assert_eq!(423, child_statistics[4].score(crate::statistics::two_player::Player::Two));
assert_eq!(657, child_statistics[5].visits());
assert_eq!(235, child_statistics[5].score(crate::statistics::two_player::Player::One));
assert_eq!(172, child_statistics[5].score(crate::statistics::two_player::Player::Two));
assert_eq!(727, child_statistics[6].visits());
assert_eq!(246, child_statistics[6].score(crate::statistics::two_player::Player::One));
assert_eq!(172, child_statistics[6].score(crate::statistics::two_player::Player::Two));
assert_eq!(758, child_statistics[7].visits());
assert_eq!(270, child_statistics[7].score(crate::statistics::two_player::Player::One));
assert_eq!(192, child_statistics[7].score(crate::statistics::two_player::Player::Two));
assert_eq!(1152, child_statistics[8].visits());
assert_eq!(448, child_statistics[8].score(crate::statistics::two_player::Player::One));
assert_eq!(304, child_statistics[8].score(crate::statistics::two_player::Player::Two));
});
}
}

View File

@@ -10,6 +10,7 @@ use std::fmt;
use std::result::Result;
use rand::Rng;
use search_graph::{EdgeRef, Graph, Open, VertexRef};
/// Error type for MCTS rollout.
pub enum RolloutError<G: Game, E: Error> {
@@ -75,10 +76,10 @@ pub trait RolloutSelector: for<'a> From<&'a SearchParameters> {
/// Returns the child edge of `parent` that should be followed, or an error.
///
/// `graph` is the search graph containing `parent`, and `rng` is a source of
/// randomness made available to the selector.
// NOTE(review): two parameter/return spellings are interleaved below (the
// `search_graph::view` types and the `Graph<Open<'id>, ..>` types); this looks
// like both sides of a diff -- confirm which one the real file contains.
fn select<'a, 'id, G: Game, R: Rng>(
&self,
graph: &search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
parent: search_graph::view::NodeRef<'id>,
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
parent: VertexRef<Open<'id>>,
rng: &mut R,
) -> Result<search_graph::view::EdgeRef<'id>, Self::Error>;
) -> Result<EdgeRef<Open<'id>>, Self::Error>;
}
/// Traverses the game graph downwards from `node` down to some terminating
@@ -90,28 +91,28 @@ pub trait RolloutSelector: for<'a> From<&'a SearchParameters> {
///
/// Selection will be done minimax-style, i.e., always trying to maximize the
/// score for the currently active player.
pub fn rollout<'a, 'id, G, S, R>(
graph: &search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
mut node: search_graph::view::NodeRef<'id>,
pub fn rollout<'id, G, S, R>(
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
mut node: VertexRef<Open<'id>>,
selector: S,
rng: &mut R,
) -> Result<search_graph::view::NodeRef<'id>, RolloutError<G, S::Error>>
) -> Result<VertexRef<Open<'id>>, RolloutError<G, S::Error>>
where
G: Game,
S: RolloutSelector,
R: Rng,
{
loop {
if let Some(_) = G::payoff_of(graph.node_state(node)) {
if let Some(_) = G::payoff_of(graph.game_state_at(node)) {
// Hit known payoff.
break;
} else if graph.child_count(node) == 0 {
} else if graph.vertex(node).children.is_empty() {
// Hit leaf in search graph.
break;
} else {
let child = selector.select(graph, node, rng)?;
graph.edge_data(child).mark_rollout_traversal();
node = graph.edge_target(child);
graph.edge(child).data.mark_rollout_traversal();
node = graph.edge(child).target;
}
}
Ok(node)

View File

@@ -158,11 +158,16 @@ impl<M: PlayerMapping> ScoredStatistics<M> {
let visits = cmp::min(old_visits + 1, VISITS_MAX);
let score_one = cmp::min(old_score_one + score_one, SCORE_MAX);
let score_two = cmp::min(old_score_two + score_two, SCORE_MAX);
success = self.packed.compare_and_swap(
old_packed,
pack_scores(visits, score_one, score_two),
atomic::Ordering::SeqCst,
) == old_packed;
let cex = match self.packed.compare_exchange_weak(
old_packed,
pack_scores(visits, score_one, score_two),
atomic::Ordering::SeqCst,
atomic::Ordering::SeqCst,
) {
Ok(n) => n,
Err(_) => continue,
};
success = cex == old_packed;
}
}
}

View File

@@ -121,11 +121,11 @@ impl game::State for State {
&self.active_player
}
/// Returns an iterator of `Action`s built by the `iterate!` comprehension
/// below: one per `(row, column)` pair with `row` and `column` in `0..3` for
/// which `self.board.get(row, column)` is `None`, each carrying
/// `self.active_player` as its player.
// NOTE(review): two signatures and two bodies are interleaved here (a boxed
// `dyn Iterator` version and an `impl Iterator` version); presumably these are
// the two sides of a diff -- confirm which one the real file contains.
fn actions<'s>(&'s self) -> Box<dyn Iterator<Item = Action> + 's> {
Box::new(iterate![for row in 0..3;
for column in 0..3;
if self.board.get(row, column).is_none();
yield Action { row, column, player: self.active_player, }])
fn actions<'s>(&'s self) -> impl Iterator<Item=Self::Action> + 's {
iterate![for row in 0..3;
for column in 0..3;
if self.board.get(row, column).is_none();
yield Action { row, column, player: self.active_player, }]
}
fn do_action(&mut self, action: &Action) {

View File

@@ -6,7 +6,7 @@ use crate::graph::{EdgeData, VertexData};
use crate::rollout::RolloutSelector;
use log::{error, trace};
use rand::Rng;
use search_graph;
use search_graph::{EdgeRef, Graph, Open, VertexRef};
use std::cmp::Ordering;
use std::error::Error;
@@ -18,9 +18,9 @@ use std::result::Result;
// Successful outcome of evaluating the UCB policy for one child edge: either
// the edge should simply be selected, or a finite score was computed for it.
// NOTE(review): each variant appears twice below (once with
// `search_graph::view::EdgeRef<'id>` and once with `EdgeRef<Open<'id>>`);
// presumably these are the two sides of a diff -- confirm.
pub enum UcbSuccess<'id> {
/// No (finite) value can be computed, but the UCB policy indicates that
/// this child should be selected. E.g., the child has not yet been visited.
Select(search_graph::view::EdgeRef<'id>),
Select(EdgeRef<Open<'id>>),
/// A value is computed.
// The `f64` is the UCB score associated with the edge.
Value(search_graph::view::EdgeRef<'id>, f64),
Value(EdgeRef<Open<'id>>, f64),
}
/// Represents error conditions when computing the UCB score for a child.
@@ -63,15 +63,15 @@ impl Error for UcbError {
fn edge_ucb_iter<'v, 'id, G>(
log_parent_visits: f64,
explore_bias: f64,
graph: &'v search_graph::view::View<'_, 'id, G::State, VertexData, EdgeData<G>>,
edges: impl Iterator<Item = search_graph::view::EdgeRef<'id>> + 'v,
graph: &'v Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
edges: impl Iterator<Item = EdgeRef<Open<'id>>> + 'v,
)
-> impl Iterator<Item = Result<UcbSuccess<'id>, UcbError>> + 'v
where
G: Game,
{
edges.map(move |e| {
if graph.edge_data(e).statistics.visits() == 0 {
if graph.edge(e).data.statistics.visits() == 0 {
trace!("edge_ucb_iter selects unvisited action");
Ok(UcbSuccess::Select(e))
} else {
@@ -93,16 +93,17 @@ where
pub fn child_score<'a, 'id, G: Game>(
log_parent_visits: f64,
explore_bias: f64,
graph: &search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
child: search_graph::view::EdgeRef<'id>,
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
child: EdgeRef<Open<'id>>,
) -> UcbSuccess<'id> {
let statistics = &graph.edge_data(child).statistics;
let statistics = &graph.edge(child).data.statistics;
if statistics.visits() == 0 {
UcbSuccess::Select(child)
} else {
let child_visits = statistics.visits() as f64;
let vertex = graph.edge(child).source;
let child_score =
statistics.score(graph.node_state(graph.edge_source(child)).active_player()) as f64;
statistics.score(graph.game_state_at(vertex).active_player()) as f64;
let ucb =
child_score / child_visits + explore_bias * f64::sqrt(log_parent_visits / child_visits);
UcbSuccess::Value(child, ucb)
@@ -122,23 +123,23 @@ pub fn child_score<'a, 'id, G: Game>(
/// graph (not just a tree), we want to know all of the parent edges which could
/// have rolled out to a given child.
pub fn is_best_child<'a, 'id, G: Game>(
graph: &search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
e: search_graph::view::EdgeRef<'id>,
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
e: EdgeRef<Open<'id>>,
explore_bias: f64,
) -> bool {
let statistics = &graph.edge_data(e).statistics;
let statistics = &graph.edge(e).data.statistics;
// trace!("is_best_child: edge {} has {} visits", e.get_id(), stats.visits);
if statistics.visits() == 0 {
// Edge has been visited, but statistics aren't yet updated.
// trace!("is_best_child: edge {} is a best child because stats.visits == 0", e.get_id());
return true;
}
let parent = graph.edge_source(e);
let parent = graph.edge(e).source;
// trace!("is_best_child: edge {} (from node {}) has {} siblings", e.get_id(), parent.get_id(), siblings.len());
let log_parent_visits = {
let mut parent_visits = 0;
for child_edge in graph.children(parent) {
parent_visits += graph.edge_data(child_edge).statistics.visits();
for child_edge in &graph.vertex(parent).children {
parent_visits += graph.edge(*child_edge).data.statistics.visits();
}
f64::ln(parent_visits as f64)
};
@@ -148,7 +149,7 @@ pub fn is_best_child<'a, 'id, G: Game>(
log_parent_visits,
explore_bias,
graph,
graph.children(parent),
graph.vertex(parent).children.iter().copied(),
);
// Scan through siblings to find the maximum UCB score. This is
// short-circuited using a lazy iterator to ameliorate the O(n) running
@@ -204,19 +205,19 @@ pub fn is_best_child<'a, 'id, G: Game>(
///
/// This function will panic if `parent` has no children.
pub fn find_best_child<'a, 'id, G, R>(
graph: &search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
parent: search_graph::view::NodeRef<'id>,
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
parent: VertexRef<Open<'id>>,
explore_bias: f64,
rng: &mut R,
) -> Result<search_graph::view::EdgeRef<'id>, UcbError>
) -> Result<EdgeRef<Open<'id>>, UcbError>
where
G: Game,
R: Rng,
{
let log_parent_visits = {
let mut parent_visits = 0;
for child in graph.children(parent) {
parent_visits += graph.edge_data(child).statistics.visits();
for child in &graph.vertex(parent).children {
parent_visits += graph.edge(*child).data.statistics.visits();
}
if parent_visits == 0 {
// When we visit a vertex for the first time, it will have zero visits.
@@ -231,9 +232,9 @@ where
log_parent_visits,
explore_bias,
graph,
graph.children(parent),
graph.vertex(parent).children.iter().copied(),
);
let mut best: search_graph::view::EdgeRef<'id>;
let mut best: EdgeRef<Open<'id>>;
let mut best_ucb: f64;
match ucb_iter.next().expect("vertex has no children")? {
UcbSuccess::Select(e) => {
@@ -301,10 +302,10 @@ impl RolloutSelector for Rollout {
/// Selects the child edge of `parent` to follow during rollout by delegating
/// to `find_best_child` with this selector's `explore_bias`.
///
/// Any `UcbError` reported by `find_best_child` is returned unchanged.
// NOTE(review): two parameter/return spellings are interleaved below
// (`search_graph::view` types and `Graph<Open<'id>, ..>` types); presumably
// these are the two sides of a diff -- confirm.
fn select<'a, 'id, G: Game, R: Rng>(
&self,
graph: &search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
parent: search_graph::view::NodeRef<'id>,
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
parent: VertexRef<Open<'id>>,
rng: &mut R,
) -> Result<search_graph::view::EdgeRef<'id>, UcbError> {
) -> Result<EdgeRef<Open<'id>>, UcbError> {
find_best_child(graph, parent, self.explore_bias, rng)
}
}
@@ -330,18 +331,21 @@ impl<'id> BackpropSelector<'id> for BestParentBackprop {
// ATCs/HKTs. We need ATC support because this iterator type will have its
// lifetime constrained by the borrow of `graph` in the select method, but
// that lifetime isn't known statically.
type Items = std::vec::IntoIter<search_graph::view::EdgeRef<'id>>;
type Items = std::vec::IntoIter<EdgeRef<Open<'id>>>;
/// Returns the parent edges of `node` to follow during backprop: those for
/// which `is_best_child(graph, edge, self.explore_bias)` is true.
///
/// The matching edges are collected into a `Vec` and returned as its owning
/// iterator. `_payoff` and `_rng` are not used by this selector.
// NOTE(review): two versions of the signature and of the iterator chain are
// interleaved below (`graph.parents(node)` vs.
// `graph.vertex(node).parents.iter()`); presumably these are the two sides of
// a diff -- confirm which one the real file contains.
fn select<G: Game, R: Rng>(
&self,
graph: &search_graph::view::View<'_, 'id, G::State, VertexData, EdgeData<G>>,
node: search_graph::view::NodeRef<'id>,
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
node: VertexRef<Open<'id>>,
_payoff: &G::Payoff,
_rng: &mut R,
) -> Self::Items {
let result: Vec<search_graph::view::EdgeRef<'id>> = graph
.parents(node)
.filter(|&parent_edge| is_best_child(graph, parent_edge, self.explore_bias))
let result: Vec<EdgeRef<Open<'id>>> = graph
.vertex(node)
.parents
.iter()
.filter(|&parent_edge| is_best_child(graph, *parent_edge, self.explore_bias))
.copied()
.collect();
result.into_iter()
}