Accommodate changes to search-graph.
This commit is contained in:
101
src/backprop.rs
101
src/backprop.rs
@@ -10,18 +10,19 @@ use crate::SearchParameters;
|
||||
use log::trace;
|
||||
use rand::Rng;
|
||||
use rand::seq::IteratorRandom;
|
||||
use search_graph::{EdgeRef, Open, Graph, VertexRef};
|
||||
|
||||
/// Provides a method for selecting incoming parent edges to follow during
|
||||
/// backprop phase of MCTS.
|
||||
pub trait BackpropSelector<'id>: for<'a> From<&'a SearchParameters> {
|
||||
type Items: Iterator<Item = search_graph::view::EdgeRef<'id>>;
|
||||
type Items: Iterator<Item = EdgeRef<Open<'id>>>;
|
||||
|
||||
/// Returns the edges to follow when pushing statistics back up through the
|
||||
/// search graph.
|
||||
fn select<G: Game, R: Rng>(
|
||||
&self,
|
||||
graph: &search_graph::view::View<'_, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
node: search_graph::view::NodeRef<'id>,
|
||||
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
node: VertexRef<Open<'id>>,
|
||||
payoff: &G::Payoff,
|
||||
rng: &mut R,
|
||||
) -> Self::Items;
|
||||
@@ -44,16 +45,16 @@ impl<'a> From<&'a SearchParameters> for FirstParentSelector {
|
||||
}
|
||||
|
||||
impl<'id> BackpropSelector<'id> for FirstParentSelector {
|
||||
type Items = std::option::IntoIter<search_graph::view::EdgeRef<'id>>;
|
||||
type Items = std::option::IntoIter<EdgeRef<Open<'id>>>;
|
||||
|
||||
fn select<G: Game, R: Rng>(
|
||||
&self,
|
||||
graph: &search_graph::view::View<'_, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
node: search_graph::view::NodeRef<'id>,
|
||||
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
node: VertexRef<Open<'id>>,
|
||||
_payoff: &G::Payoff,
|
||||
_rng: &mut R,
|
||||
) -> Self::Items {
|
||||
graph.parents(node).next().into_iter()
|
||||
graph.vertex(node).parents.first().copied().into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,16 +73,16 @@ impl<'a> From<&'a SearchParameters> for RandomParentSelector {
|
||||
}
|
||||
|
||||
impl<'id> BackpropSelector<'id> for RandomParentSelector {
|
||||
type Items = std::option::IntoIter<search_graph::view::EdgeRef<'id>>;
|
||||
type Items = std::option::IntoIter<EdgeRef<Open<'id>>>;
|
||||
|
||||
fn select<G: Game, R: Rng>(
|
||||
&self,
|
||||
graph: &search_graph::view::View<'_, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
node: search_graph::view::NodeRef<'id>,
|
||||
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
node: VertexRef<Open<'id>>,
|
||||
_payoff: &G::Payoff,
|
||||
rng: &mut R,
|
||||
) -> Self::Items {
|
||||
graph.parents(node).choose(rng).into_iter()
|
||||
graph.vertex(node).parents.iter().copied().choose(rng).into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,12 +94,12 @@ impl<'id> BackpropSelector<'id> for RandomParentSelector {
|
||||
/// yields may be affected by statistics updates if they are applied during
|
||||
/// iteration.
|
||||
pub fn backprop_iter<'a, 'b, 'id, G, S, R>(
|
||||
graph: &'b search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
node: search_graph::view::NodeRef<'id>,
|
||||
graph: &'b Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
node: VertexRef<Open<'id>>,
|
||||
payoff: &'b G::Payoff,
|
||||
selector: &'b S,
|
||||
rng: &'b mut R,
|
||||
) -> impl Iterator<Item = search_graph::view::EdgeRef<'id>> + 'b
|
||||
) -> impl Iterator<Item = EdgeRef<Open<'id>>> + 'b
|
||||
where
|
||||
'a: 'b,
|
||||
G: Game,
|
||||
@@ -109,34 +110,34 @@ where
|
||||
}
|
||||
|
||||
/// Chains together parent traversals.
|
||||
struct BackpropIter<'a, 'b, 'id, G, S, R>
|
||||
struct BackpropIter<'a, 'id, G, S, R>
|
||||
where
|
||||
G: Game,
|
||||
S: BackpropSelector<'id> + 'b,
|
||||
S: BackpropSelector<'id> + 'a,
|
||||
R: Rng,
|
||||
{
|
||||
graph: &'b search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
graph: &'a Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
/// Nodes whose parent edges to traverse.
|
||||
stack: Vec<search_graph::view::NodeRef<'id>>,
|
||||
stack: Vec<VertexRef<Open<'id>>>,
|
||||
/// Edges from most recently examined node.
|
||||
parent_edges: S::Items,
|
||||
payoff: &'b G::Payoff,
|
||||
selector: &'b S,
|
||||
rng: &'b mut R,
|
||||
payoff: &'a G::Payoff,
|
||||
selector: &'a S,
|
||||
rng: &'a mut R,
|
||||
}
|
||||
|
||||
impl<'a, 'b, 'id, G, S, R> BackpropIter<'a, 'b, 'id, G, S, R>
|
||||
impl<'a, 'id, G, S, R> BackpropIter<'a, 'id, G, S, R>
|
||||
where
|
||||
G: Game,
|
||||
S: BackpropSelector<'id> + 'b,
|
||||
S: BackpropSelector<'id> + 'a,
|
||||
R: Rng,
|
||||
{
|
||||
fn new(
|
||||
graph: &'b search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
node: search_graph::view::NodeRef<'id>,
|
||||
payoff: &'b G::Payoff,
|
||||
selector: &'b S,
|
||||
rng: &'b mut R,
|
||||
graph: &'a Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
node: VertexRef<Open<'id>>,
|
||||
payoff: &'a G::Payoff,
|
||||
selector: &'a S,
|
||||
rng: &'a mut R,
|
||||
) -> Self {
|
||||
let parent_edges = selector.select(graph, node, payoff, rng);
|
||||
BackpropIter {
|
||||
@@ -150,18 +151,18 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b, 'id, G, S, R> Iterator for BackpropIter<'a, 'b, 'id, G, S, R>
|
||||
impl<'a, 'id, G, S, R> Iterator for BackpropIter<'a, 'id, G, S, R>
|
||||
where
|
||||
G: Game,
|
||||
S: BackpropSelector<'id> + 'b,
|
||||
S: BackpropSelector<'id> + 'a,
|
||||
R: Rng,
|
||||
{
|
||||
type Item = search_graph::view::EdgeRef<'id>;
|
||||
type Item = EdgeRef<Open<'id>>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
while let Some(parent) = self.parent_edges.next() {
|
||||
if !self.graph.edge_data(parent).mark_backprop_traversal() {
|
||||
self.stack.push(self.graph.edge_source(parent));
|
||||
if !self.graph.edge(parent).data.mark_backprop_traversal() {
|
||||
self.stack.push(self.graph.edge(parent).source);
|
||||
return Some(parent);
|
||||
}
|
||||
}
|
||||
@@ -171,8 +172,8 @@ where
|
||||
.selector
|
||||
.select(self.graph, node, self.payoff, self.rng);
|
||||
while let Some(parent) = self.parent_edges.next() {
|
||||
if !self.graph.edge_data(parent).mark_backprop_traversal() {
|
||||
self.stack.push(self.graph.edge_source(parent));
|
||||
if !self.graph.edge(parent).data.mark_backprop_traversal() {
|
||||
self.stack.push(self.graph.edge(parent).source);
|
||||
return Some(parent);
|
||||
}
|
||||
}
|
||||
@@ -187,23 +188,23 @@ where
|
||||
|
||||
/// Iterable view over parents of a graph node that selects parents for which
|
||||
/// this node is a best child.
|
||||
pub struct ParentSelectionIter<'a, 'b, 'id, G, I>
|
||||
pub struct ParentSelectionIter<'a, 'id, G, I>
|
||||
where
|
||||
G: Game,
|
||||
I: Iterator<Item = search_graph::view::EdgeRef<'id>>,
|
||||
I: Iterator<Item = EdgeRef<Open<'id>>>,
|
||||
{
|
||||
graph: &'b search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
graph: &'a Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
parents: I,
|
||||
explore_bias: f64,
|
||||
}
|
||||
|
||||
impl<'a, 'b, 'id, G, I> ParentSelectionIter<'a, 'b, 'id, G, I>
|
||||
impl<'a, 'id, G, I> ParentSelectionIter<'a, 'id, G, I>
|
||||
where
|
||||
G: Game,
|
||||
I: Iterator<Item = search_graph::view::EdgeRef<'id>>,
|
||||
I: Iterator<Item = EdgeRef<Open<'id>>>,
|
||||
{
|
||||
pub fn new(
|
||||
graph: &'b search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
graph: &'a Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
parents: I,
|
||||
explore_bias: f64,
|
||||
) -> Self {
|
||||
@@ -215,19 +216,19 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, 'b, 'id, G, I> Iterator for ParentSelectionIter<'a, 'b, 'id, G, I>
|
||||
impl<'a, 'id, G, I> Iterator for ParentSelectionIter<'a, 'id, G, I>
|
||||
where
|
||||
G: Game,
|
||||
I: Iterator<Item = search_graph::view::EdgeRef<'id>>,
|
||||
I: Iterator<Item = EdgeRef<Open<'id>>>,
|
||||
{
|
||||
type Item = search_graph::view::EdgeRef<'id>;
|
||||
type Item = EdgeRef<Open<'id>>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
loop {
|
||||
match self.parents.next() {
|
||||
None => return None,
|
||||
Some(e) => {
|
||||
if self.graph.edge_data(e).mark_backprop_traversal() {
|
||||
if self.graph.edge(e).data.mark_backprop_traversal() {
|
||||
trace!("ParentSelectionIter::next: edge was already visited in backtrace",);
|
||||
continue;
|
||||
}
|
||||
@@ -235,14 +236,14 @@ where
|
||||
trace!(
|
||||
"ParentSelectionIter::next: edge {:?} (from node {:?}) is a best child",
|
||||
e,
|
||||
self.graph.edge_source(e),
|
||||
self.graph.edge(e).source,
|
||||
);
|
||||
return Some(e);
|
||||
}
|
||||
trace!(
|
||||
"ParentSelectionIter::next: edge {:?} (data {:?}) is not a best child",
|
||||
e,
|
||||
self.graph.edge_data(e),
|
||||
self.graph.edge(e).data,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -257,8 +258,8 @@ where
|
||||
/// `selector`. Ancestors are selected by recursively applying `selector` to the
|
||||
/// parent that it selects.
|
||||
pub fn backprop<'a, 'id, G, S, R>(
|
||||
graph: &search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
node: search_graph::view::NodeRef<'id>,
|
||||
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
node: VertexRef<Open<'id>>,
|
||||
payoff: &G::Payoff,
|
||||
selector: &S,
|
||||
rng: &mut R,
|
||||
@@ -271,7 +272,7 @@ pub fn backprop<'a, 'id, G, S, R>(
|
||||
// updating the statistics alters best child status, and backprop_iter returns
|
||||
// a lazy iterator.
|
||||
let statistics: Vec<&EdgeData<G>> = backprop_iter(graph, node, payoff, selector, rng)
|
||||
.map(|edge| graph.edge_data(edge))
|
||||
.map(|edge| &graph.edge(edge).data)
|
||||
.collect();
|
||||
for s in statistics {
|
||||
s.statistics.increment(payoff);
|
||||
|
||||
518
src/lib.rs
518
src/lib.rs
@@ -22,6 +22,7 @@ use std::result::Result;
|
||||
|
||||
use log::trace;
|
||||
use rand::Rng;
|
||||
use search_graph::{Closed, Open, Graph, VertexRef};
|
||||
|
||||
/// Wraps a decision made by a rollout policy.
|
||||
///
|
||||
@@ -63,8 +64,8 @@ pub struct ActionStatistics<G: Game> {
|
||||
|
||||
/// Creates a new search graph suitable for Monte Carlo tree search through the
|
||||
/// state space of the game `G`.
|
||||
pub fn new_search_graph<G: Game>() -> search_graph::Graph<G::State, VertexData, EdgeData<G>> {
|
||||
search_graph::Graph::<G::State, VertexData, EdgeData<G>>::new()
|
||||
pub fn new_search_graph<G: Game>() -> search_graph::Graph<Closed, G::State, VertexData, EdgeData<G>> {
|
||||
Graph::default()
|
||||
}
|
||||
|
||||
/// Parameters for a round of Monte Carlo tree search.
|
||||
@@ -80,38 +81,34 @@ pub struct SearchParameters {
|
||||
|
||||
/// Recursively traverses the search graph to find a game state from which to
|
||||
/// perform payoff estimates.
|
||||
pub struct RolloutPhase<'a, 'id, R: Rng, G: Game> {
|
||||
pub struct RolloutPhase<'id, R: Rng> {
|
||||
rng: R,
|
||||
settings: SearchParameters,
|
||||
graph: search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
root_node: search_graph::view::NodeRef<'id>,
|
||||
root_node: VertexRef<Open<'id>>,
|
||||
}
|
||||
|
||||
impl<'a, 'id, R: Rng, G: Game> RolloutPhase<'a, 'id, R, G> {
|
||||
pub fn initialize(
|
||||
impl<'id, R: Rng> RolloutPhase<'id, R> {
|
||||
pub fn initialize<G: Game>(
|
||||
rng: R,
|
||||
settings: SearchParameters,
|
||||
root_state: G::State,
|
||||
mut graph: search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
graph: &mut Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
) -> Self {
|
||||
trace!("initializing rollout phase to state: {:?}", root_state);
|
||||
let root_node = match graph.find_node(&root_state) {
|
||||
Some(n) => n,
|
||||
None => graph.append_node(root_state.clone(), VertexData::default()),
|
||||
};
|
||||
let root_node = graph.find_or_add_vertex(root_state.clone(), VertexData::default);
|
||||
RolloutPhase {
|
||||
rng,
|
||||
settings,
|
||||
graph,
|
||||
root_node,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rollout<S: RolloutSelector>(
|
||||
pub fn rollout<G: Game, S: RolloutSelector>(
|
||||
mut self,
|
||||
) -> Result<ScoringPhase<'a, 'id, R, G>, rollout::RolloutError<G, S::Error>> {
|
||||
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
) -> Result<ScoringPhase<'id, R>, rollout::RolloutError<G, S::Error>> {
|
||||
let result = rollout::rollout(
|
||||
&self.graph,
|
||||
graph,
|
||||
self.root_node,
|
||||
S::from(&self.settings),
|
||||
&mut self.rng,
|
||||
@@ -121,52 +118,50 @@ impl<'a, 'id, R: Rng, G: Game> RolloutPhase<'a, 'id, R, G> {
|
||||
trace!("rollout result has target node: {:?}", node);
|
||||
trace!(
|
||||
"rollout result has state: {:?}",
|
||||
self.graph.node_state(node)
|
||||
graph.game_state_at(node)
|
||||
);
|
||||
ScoringPhase {
|
||||
rng: self.rng,
|
||||
settings: self.settings,
|
||||
graph: self.graph,
|
||||
root_node: self.root_node,
|
||||
rollout_node: node,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn root_node(&self) -> search_graph::view::NodeRef<'id> {
|
||||
pub fn root_node(&self) -> VertexRef<Open<'id>> {
|
||||
self.root_node
|
||||
}
|
||||
|
||||
pub fn recover_components(
|
||||
pub fn recover_rng(
|
||||
self,
|
||||
) -> (
|
||||
R,
|
||||
search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
) {
|
||||
(self.rng, self.graph)
|
||||
) -> R {
|
||||
self.rng
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes an estimate of the score for a game state selected during rollout.
|
||||
pub struct ScoringPhase<'a, 'id, R: Rng, G: Game> {
|
||||
pub struct ScoringPhase<'id, R: Rng> {
|
||||
rng: R,
|
||||
settings: SearchParameters,
|
||||
graph: search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
root_node: search_graph::view::NodeRef<'id>,
|
||||
rollout_node: search_graph::view::NodeRef<'id>,
|
||||
root_node: VertexRef<Open<'id>>,
|
||||
rollout_node: VertexRef<Open<'id>>,
|
||||
}
|
||||
|
||||
impl<'a, 'id, R: Rng, G: Game> ScoringPhase<'a, 'id, R, G> {
|
||||
pub fn root_node(&self) -> search_graph::view::NodeRef<'id> {
|
||||
impl<'id, R: Rng> ScoringPhase<'id, R> {
|
||||
pub fn root_node(&self) -> VertexRef<Open<'id>> {
|
||||
self.root_node
|
||||
}
|
||||
|
||||
pub fn rollout_node(&self) -> search_graph::view::NodeRef<'id> {
|
||||
pub fn rollout_node(&self) -> VertexRef<Open<'id>> {
|
||||
self.rollout_node
|
||||
}
|
||||
|
||||
pub fn score<S: Simulator>(mut self) -> Result<BackpropPhase<'a, 'id, R, G>, S::Error> {
|
||||
let payoff = match G::payoff_of(self.graph.node_state(self.rollout_node())) {
|
||||
pub fn score<G: Game, S: Simulator>(
|
||||
mut self,
|
||||
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
) -> Result<BackpropPhase<'id, R, G>, S::Error> {
|
||||
let payoff = match G::payoff_of(graph.game_state_at(self.rollout_node())) {
|
||||
Some(p) => {
|
||||
trace!("direct payoff found: {:?}", p);
|
||||
p
|
||||
@@ -174,14 +169,13 @@ impl<'a, 'id, R: Rng, G: Game> ScoringPhase<'a, 'id, R, G> {
|
||||
None => {
|
||||
trace!("simulating to find payoff");
|
||||
let simulator = S::from(&self.settings);
|
||||
simulator.simulate::<G, R>(self.graph.node_state(self.rollout_node()), &mut self.rng)?
|
||||
simulator.simulate::<G, R>(graph.game_state_at(self.rollout_node()), &mut self.rng)?
|
||||
}
|
||||
};
|
||||
trace!("scoring phase finds payoff {:?}", payoff);
|
||||
Ok(BackpropPhase {
|
||||
rng: self.rng,
|
||||
settings: self.settings,
|
||||
graph: self.graph,
|
||||
root_node: self.root_node,
|
||||
rollout_node: self.rollout_node,
|
||||
payoff,
|
||||
@@ -193,19 +187,21 @@ impl<'a, 'id, R: Rng, G: Game> ScoringPhase<'a, 'id, R, G> {
|
||||
///
|
||||
/// The strategy for finding the game-state statistics to update during backprop
|
||||
/// is determined by `BackpropSelector`.
|
||||
pub struct BackpropPhase<'a, 'id, R: Rng, G: Game> {
|
||||
pub struct BackpropPhase<'id, R: Rng, G: Game> {
|
||||
rng: R,
|
||||
settings: SearchParameters,
|
||||
graph: search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
root_node: search_graph::view::NodeRef<'id>,
|
||||
rollout_node: search_graph::view::NodeRef<'id>,
|
||||
root_node: VertexRef<Open<'id>>,
|
||||
rollout_node: VertexRef<Open<'id>>,
|
||||
payoff: G::Payoff,
|
||||
}
|
||||
|
||||
impl<'a, 'id, R: Rng, G: Game> BackpropPhase<'a, 'id, R, G> {
|
||||
pub fn backprop<S: BackpropSelector<'id>>(mut self) -> ExpandPhase<'a, 'id, R, G> {
|
||||
impl<'id, R: Rng, G: Game> BackpropPhase<'id, R, G> {
|
||||
pub fn backprop<S: BackpropSelector<'id>>(
|
||||
mut self,
|
||||
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
) -> ExpandPhase<'id, R> {
|
||||
backprop::backprop(
|
||||
&self.graph,
|
||||
graph,
|
||||
self.rollout_node,
|
||||
&self.payoff,
|
||||
&S::from(&self.settings),
|
||||
@@ -214,7 +210,6 @@ impl<'a, 'id, R: Rng, G: Game> BackpropPhase<'a, 'id, R, G> {
|
||||
ExpandPhase {
|
||||
rng: self.rng,
|
||||
settings: self.settings,
|
||||
graph: self.graph,
|
||||
root_node: self.root_node,
|
||||
rollout_node: self.rollout_node,
|
||||
}
|
||||
@@ -227,45 +222,35 @@ impl<'a, 'id, R: Rng, G: Game> BackpropPhase<'a, 'id, R, G> {
|
||||
/// state, adding them to the graph if they don't already exist, and then
|
||||
/// creating an edge from the original node to the node for the resulting game
|
||||
/// state.
|
||||
pub struct ExpandPhase<'a, 'id, R: Rng, G: Game> {
|
||||
pub struct ExpandPhase<'id, R: Rng> {
|
||||
rng: R,
|
||||
settings: SearchParameters,
|
||||
graph: search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
root_node: search_graph::view::NodeRef<'id>,
|
||||
rollout_node: search_graph::view::NodeRef<'id>,
|
||||
root_node: VertexRef<Open<'id>>,
|
||||
rollout_node: VertexRef<Open<'id>>,
|
||||
}
|
||||
|
||||
impl<'a, 'id, R: Rng, G: Game> ExpandPhase<'a, 'id, R, G> {
|
||||
pub fn expand(mut self) -> RolloutPhase<'a, 'id, R, G> {
|
||||
if self.graph.node_data(self.rollout_node).mark_expanded() {
|
||||
impl<'id, R: Rng> ExpandPhase<'id, R> {
|
||||
pub fn expand<G: Game>(
|
||||
self,
|
||||
graph: &mut Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
) -> RolloutPhase<'id, R> {
|
||||
if graph.vertex(self.rollout_node).data.mark_expanded() {
|
||||
trace!("rollout node was already marked as expanded; ExpandPhase does nothing");
|
||||
} else {
|
||||
let parent_state = self.graph.node_state(self.rollout_node).clone();
|
||||
let parent_state = graph.game_state_at(self.rollout_node).clone();
|
||||
for action in parent_state.actions() {
|
||||
trace!("ExpandPhase adds edge for action {:?}", action);
|
||||
let mut child_state = parent_state.clone();
|
||||
trace!("ExpandState old state: {:?}", child_state);
|
||||
child_state.do_action(&action);
|
||||
trace!("ExpandState new state: {:?}", child_state);
|
||||
let child = match self.graph.find_node(&child_state) {
|
||||
Some(n) => {
|
||||
trace!("ExpandState expanded to existing game state");
|
||||
n
|
||||
}
|
||||
None => {
|
||||
trace!("ExpandState expanded to new game state");
|
||||
self.graph.append_node(child_state, Default::default())
|
||||
}
|
||||
};
|
||||
self
|
||||
.graph
|
||||
.append_edge(self.rollout_node, child, EdgeData::new(action));
|
||||
let child = graph.find_or_add_vertex(child_state, Default::default);
|
||||
graph.find_or_add_edge(self.rollout_node, child, || EdgeData::new(action));
|
||||
}
|
||||
}
|
||||
RolloutPhase {
|
||||
rng: self.rng,
|
||||
settings: self.settings,
|
||||
graph: self.graph,
|
||||
root_node: self.root_node,
|
||||
}
|
||||
}
|
||||
@@ -298,43 +283,38 @@ mod test {
|
||||
}
|
||||
}
|
||||
|
||||
type Graph = search_graph::Graph<tictactoe::State, VertexData, EdgeData<tictactoe::ScoredGame>>;
|
||||
type Graph = search_graph::Graph<search_graph::Closed, tictactoe::State, VertexData, EdgeData<tictactoe::ScoredGame>>;
|
||||
|
||||
#[test]
|
||||
fn rollout_init() {
|
||||
let mut graph = Graph::new();
|
||||
|
||||
search_graph::view::of_graph(&mut graph, |view| {
|
||||
Graph::new().open(|mut graph| {
|
||||
RolloutPhase::initialize(
|
||||
default_rng(),
|
||||
default_settings(),
|
||||
default_game_state(),
|
||||
view,
|
||||
&mut graph,
|
||||
);
|
||||
});
|
||||
|
||||
let node = graph.find_node(&default_game_state());
|
||||
assert!(node.is_some());
|
||||
let node = node.unwrap();
|
||||
assert!(node.is_leaf());
|
||||
assert!(node.is_root());
|
||||
assert_eq!(default_game_state(), *node.get_label());
|
||||
let node = graph.find_vertex_with_state(&default_game_state()).unwrap();
|
||||
let v = graph.vertex(node);
|
||||
assert!(v.children.is_empty());
|
||||
assert!(v.parents.is_empty());
|
||||
assert_eq!(default_game_state(), *graph.game_state_at(node));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn integration_test() {
|
||||
let mut graph = Graph::new();
|
||||
|
||||
search_graph::view::of_graph(&mut graph, |view| {
|
||||
Graph::new().open(|mut graph| {
|
||||
// Rollout.
|
||||
let rollout_phase = RolloutPhase::initialize(
|
||||
default_rng(),
|
||||
default_settings(),
|
||||
default_game_state(),
|
||||
view,
|
||||
&mut graph,
|
||||
);
|
||||
let rollout_target = crate::rollout::rollout(
|
||||
&rollout_phase.graph,
|
||||
&graph,
|
||||
rollout_phase.root_node,
|
||||
ucb::Rollout::from(&default_settings()),
|
||||
&mut default_rng(),
|
||||
@@ -343,301 +323,303 @@ mod test {
|
||||
assert_eq!(rollout_phase.root_node(), rollout_target);
|
||||
assert_eq!(
|
||||
tictactoe::State::default(),
|
||||
*rollout_phase.graph.node_state(rollout_target)
|
||||
*graph.game_state_at(rollout_target)
|
||||
);
|
||||
|
||||
// Scoring.
|
||||
let scoring_phase: ScoringPhase<'_, '_, _, tictactoe::ScoredGame> =
|
||||
rollout_phase.rollout::<ucb::Rollout>().unwrap();
|
||||
let scoring_phase: ScoringPhase<'_, _> =
|
||||
rollout_phase.rollout::<_, ucb::Rollout>(&graph).unwrap();
|
||||
assert_eq!(scoring_phase.rollout_node(), rollout_target);
|
||||
assert_eq!(scoring_phase.root_node(), rollout_target);
|
||||
|
||||
// Backprop.
|
||||
let backprop_phase: BackpropPhase<'_, '_, _, tictactoe::ScoredGame> = scoring_phase
|
||||
.score::<simulation::RandomSimulator>()
|
||||
let backprop_phase: BackpropPhase<'_, _, tictactoe::ScoredGame> = scoring_phase
|
||||
.score::<_, simulation::RandomSimulator>(&graph)
|
||||
.unwrap();
|
||||
|
||||
// Expand.
|
||||
let expand_phase: ExpandPhase<'_, '_, _, tictactoe::ScoredGame> =
|
||||
backprop_phase.backprop::<backprop::FirstParentSelector>();
|
||||
expand_phase.expand();
|
||||
});
|
||||
let expand_phase: ExpandPhase<'_, _> =
|
||||
backprop_phase.backprop::<backprop::FirstParentSelector>(&graph);
|
||||
expand_phase.expand(&mut graph);
|
||||
|
||||
{
|
||||
// Search graph should consist of root, edges for each possible move, and
|
||||
// leaves for the result of each move.
|
||||
let node = graph.find_node(&Default::default()).unwrap();
|
||||
assert!(node.is_root());
|
||||
assert!(!node.is_leaf());
|
||||
assert_eq!(9, node.get_child_list().len());
|
||||
for child in node.get_child_list().iter() {
|
||||
assert_eq!(0, child.get_data().statistics.visits());
|
||||
let node = graph.find_vertex_with_state(&Default::default()).unwrap();
|
||||
let v = graph.vertex(node);
|
||||
assert!(v.parents.is_empty());
|
||||
assert!(!v.children.is_empty());
|
||||
assert_eq!(9, v.children.len());
|
||||
for child in v.children.iter().cloned() {
|
||||
assert_eq!(0, graph.edge(child).data.statistics.visits());
|
||||
assert_eq!(
|
||||
0,
|
||||
child
|
||||
.get_data()
|
||||
graph.edge(child)
|
||||
.data
|
||||
.statistics
|
||||
.score(crate::statistics::two_player::Player::One)
|
||||
);
|
||||
assert_eq!(
|
||||
0,
|
||||
child
|
||||
.get_data()
|
||||
graph
|
||||
.edge(child)
|
||||
.data
|
||||
.statistics
|
||||
.score(crate::statistics::two_player::Player::Two)
|
||||
);
|
||||
assert!(child.get_target().is_leaf());
|
||||
assert!(!child.get_target().is_root());
|
||||
assert!(graph.vertex(graph.edge(child).target).children.is_empty());
|
||||
assert!(!graph.vertex(graph.edge(child).target).parents.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
search_graph::view::of_graph(&mut graph, |view| {
|
||||
RolloutPhase::initialize(
|
||||
default_rng(),
|
||||
default_settings(),
|
||||
default_game_state(),
|
||||
view,
|
||||
&mut graph,
|
||||
)
|
||||
.rollout::<ucb::Rollout>()
|
||||
.rollout::<_, ucb::Rollout>(&graph)
|
||||
.unwrap()
|
||||
.score::<simulation::RandomSimulator>()
|
||||
.score::<_, simulation::RandomSimulator>(&graph)
|
||||
.unwrap()
|
||||
.backprop::<backprop::FirstParentSelector>()
|
||||
.expand();
|
||||
});
|
||||
.backprop::<backprop::FirstParentSelector>(&mut graph)
|
||||
.expand(&mut graph);
|
||||
|
||||
{
|
||||
// Two levels of should be expanded.
|
||||
// Two levels should be expanded.
|
||||
assert_eq!(18, graph.vertex_count());
|
||||
assert_eq!(17, graph.edge_count());
|
||||
let node = graph.find_node(&Default::default()).unwrap();
|
||||
assert!(node.is_root());
|
||||
assert!(!node.is_leaf());
|
||||
assert_eq!(9, node.get_child_list().len());
|
||||
for (i, child) in node.get_child_list().iter().enumerate() {
|
||||
let node = graph.find_vertex_with_state(&Default::default()).unwrap();
|
||||
let v = graph.vertex(node);
|
||||
assert!(v.parents.is_empty());
|
||||
assert!(!v.children.is_empty());
|
||||
assert_eq!(9, v.children.len());
|
||||
for (i, child) in v.children.iter().copied().enumerate() {
|
||||
if i == 7 {
|
||||
assert_eq!(1, child.get_data().statistics.visits());
|
||||
assert_eq!(1, graph.edge(child).data.statistics.visits());
|
||||
assert_eq!(
|
||||
0,
|
||||
child
|
||||
.get_data()
|
||||
1,
|
||||
graph
|
||||
.edge(child)
|
||||
.data
|
||||
.statistics
|
||||
.score(crate::statistics::two_player::Player::One)
|
||||
);
|
||||
assert_eq!(
|
||||
1,
|
||||
child
|
||||
.get_data()
|
||||
0,
|
||||
graph
|
||||
.edge(child)
|
||||
.data
|
||||
.statistics
|
||||
.score(crate::statistics::two_player::Player::Two)
|
||||
);
|
||||
} else {
|
||||
assert_eq!(0, child.get_data().statistics.visits());
|
||||
assert_eq!(0, graph.edge(child).data.statistics.visits());
|
||||
assert_eq!(
|
||||
0,
|
||||
child
|
||||
.get_data()
|
||||
graph
|
||||
.edge(child)
|
||||
.data
|
||||
.statistics
|
||||
.score(crate::statistics::two_player::Player::One)
|
||||
);
|
||||
assert_eq!(
|
||||
0,
|
||||
child
|
||||
.get_data()
|
||||
graph
|
||||
.edge(child)
|
||||
.data
|
||||
.statistics
|
||||
.score(crate::statistics::two_player::Player::Two)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
search_graph::view::of_graph(&mut graph, |view| {
|
||||
RolloutPhase::initialize(
|
||||
default_rng(),
|
||||
default_settings(),
|
||||
default_game_state(),
|
||||
view,
|
||||
&mut graph,
|
||||
)
|
||||
.rollout::<ucb::Rollout>()
|
||||
.rollout::<_, ucb::Rollout>(&graph)
|
||||
.unwrap()
|
||||
.score::<simulation::RandomSimulator>()
|
||||
.score::<_, simulation::RandomSimulator>(&graph)
|
||||
.unwrap()
|
||||
.backprop::<backprop::FirstParentSelector>()
|
||||
.expand();
|
||||
.backprop::<backprop::FirstParentSelector>(&mut graph)
|
||||
.expand(&mut graph);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn integration_test_first_parent_selector() {
|
||||
let mut graph = Graph::new();
|
||||
search_graph::view::of_graph(&mut graph, |view| {
|
||||
Graph::new().open(|mut graph| {
|
||||
let mut rollout = RolloutPhase::initialize(
|
||||
default_rng(),
|
||||
default_settings(),
|
||||
default_game_state(),
|
||||
view,
|
||||
&mut graph,
|
||||
);
|
||||
for _ in 0..10000 {
|
||||
let score = rollout.rollout::<ucb::Rollout>().unwrap();
|
||||
let backprop = score.score::<simulation::RandomSimulator>().unwrap();
|
||||
let expand = backprop.backprop::<backprop::FirstParentSelector>();
|
||||
rollout = expand.expand();
|
||||
let score = rollout.rollout::<_, ucb::Rollout>(&graph).unwrap();
|
||||
let backprop = score.score::<_, simulation::RandomSimulator>(&graph).unwrap();
|
||||
let expand = backprop.backprop::<backprop::FirstParentSelector>(&graph);
|
||||
rollout = expand.expand(&mut graph);
|
||||
}
|
||||
|
||||
assert_eq!(755, graph.vertex_count());
|
||||
assert_eq!(947, graph.edge_count());
|
||||
|
||||
let child_statistics: Vec<&crate::statistics::two_player::ScoredStatistics<_>> =
|
||||
graph.vertex(graph
|
||||
.find_vertex_with_state(&default_game_state())
|
||||
.unwrap())
|
||||
.children
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|c| &graph.edge(c).data.statistics)
|
||||
.collect();
|
||||
assert_eq!(9, child_statistics.len());
|
||||
assert_eq!(7, child_statistics[0].visits());
|
||||
assert_eq!(2, child_statistics[0].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(2, child_statistics[0].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(14, child_statistics[1].visits());
|
||||
assert_eq!(9, child_statistics[1].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(5, child_statistics[1].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(13, child_statistics[2].visits());
|
||||
assert_eq!(9, child_statistics[2].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(4, child_statistics[2].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(15, child_statistics[3].visits());
|
||||
assert_eq!(7, child_statistics[3].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(7, child_statistics[3].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(32, child_statistics[4].visits());
|
||||
assert_eq!(20, child_statistics[4].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(7, child_statistics[4].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(6, child_statistics[5].visits());
|
||||
assert_eq!(2, child_statistics[5].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(3, child_statistics[5].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(3, child_statistics[6].visits());
|
||||
assert_eq!(0, child_statistics[6].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(2, child_statistics[6].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(9, child_statistics[7].visits());
|
||||
assert_eq!(4, child_statistics[7].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(4, child_statistics[7].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(9900, child_statistics[8].visits());
|
||||
assert_eq!(43, child_statistics[8].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(29, child_statistics[8].score(crate::statistics::two_player::Player::Two));
|
||||
});
|
||||
|
||||
assert_eq!(3295, graph.vertex_count());
|
||||
assert_eq!(5361, graph.edge_count());
|
||||
|
||||
let child_statistics: Vec<&crate::statistics::two_player::ScoredStatistics<_>> =
|
||||
graph
|
||||
.find_node(&default_game_state())
|
||||
.unwrap()
|
||||
.get_child_list()
|
||||
.iter()
|
||||
.map(|c| &c.get_data().statistics)
|
||||
.collect();
|
||||
assert_eq!(108, child_statistics[0].visits());
|
||||
assert_eq!(28, child_statistics[0].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(78, child_statistics[0].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(209, child_statistics[1].visits());
|
||||
assert_eq!(153, child_statistics[1].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(50, child_statistics[1].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(754, child_statistics[2].visits());
|
||||
assert_eq!(557, child_statistics[2].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(178, child_statistics[2].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(2023, child_statistics[3].visits());
|
||||
assert_eq!(1548, child_statistics[3].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(459, child_statistics[3].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(2105, child_statistics[4].visits());
|
||||
assert_eq!(834, child_statistics[4].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(1181, child_statistics[4].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(1172, child_statistics[5].visits());
|
||||
assert_eq!(574, child_statistics[5].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(590, child_statistics[5].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(953, child_statistics[6].visits());
|
||||
assert_eq!(483, child_statistics[6].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(282, child_statistics[6].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(806, child_statistics[7].visits());
|
||||
assert_eq!(382, child_statistics[7].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(410, child_statistics[7].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(1869, child_statistics[8].visits());
|
||||
assert_eq!(678, child_statistics[8].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(913, child_statistics[8].score(crate::statistics::two_player::Player::Two));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn integration_test_random_parent_selector() {
|
||||
let mut graph = Graph::new();
|
||||
search_graph::view::of_graph(&mut graph, |view| {
|
||||
Graph::new().open(|mut graph| {
|
||||
let mut rollout = RolloutPhase::initialize(
|
||||
default_rng(),
|
||||
default_settings(),
|
||||
default_game_state(),
|
||||
view,
|
||||
&mut graph,
|
||||
);
|
||||
for _ in 0..10000 {
|
||||
let score = rollout.rollout::<ucb::Rollout>().unwrap();
|
||||
let backprop = score.score::<simulation::RandomSimulator>().unwrap();
|
||||
let expand = backprop.backprop::<backprop::RandomParentSelector>();
|
||||
rollout = expand.expand();
|
||||
let score = rollout.rollout::<_, ucb::Rollout>(&graph).unwrap();
|
||||
let backprop = score.score::<_, simulation::RandomSimulator>(&graph).unwrap();
|
||||
let expand = backprop.backprop::<backprop::RandomParentSelector>(&mut graph);
|
||||
rollout = expand.expand(&mut graph);
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(5795, graph.vertex_count());
|
||||
assert_eq!(13874, graph.edge_count());
|
||||
assert_eq!(4946, graph.vertex_count());
|
||||
assert_eq!(10077, graph.edge_count());
|
||||
|
||||
let child_statistics: Vec<&crate::statistics::two_player::ScoredStatistics<_>> =
|
||||
graph
|
||||
.find_node(&default_game_state())
|
||||
.unwrap()
|
||||
.get_child_list()
|
||||
.iter()
|
||||
.map(|c| &c.get_data().statistics)
|
||||
.vertex(graph
|
||||
.find_vertex_with_state(&default_game_state())
|
||||
.unwrap())
|
||||
.children
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|c| &graph.edge(c).data.statistics)
|
||||
.collect();
|
||||
assert_eq!(1204, child_statistics[0].visits());
|
||||
assert_eq!(708, child_statistics[0].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(347, child_statistics[0].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(790, child_statistics[1].visits());
|
||||
assert_eq!(396, child_statistics[1].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(266, child_statistics[1].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(983, child_statistics[2].visits());
|
||||
assert_eq!(553, child_statistics[2].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(249, child_statistics[2].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(829, child_statistics[3].visits());
|
||||
assert_eq!(471, child_statistics[3].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(258, child_statistics[3].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(2015, child_statistics[4].visits());
|
||||
assert_eq!(1224, child_statistics[4].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(388, child_statistics[4].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(795, child_statistics[5].visits());
|
||||
assert_eq!(449, child_statistics[5].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(242, child_statistics[5].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(1035, child_statistics[6].visits());
|
||||
assert_eq!(578, child_statistics[6].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(287, child_statistics[6].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(1002, child_statistics[7].visits());
|
||||
assert_eq!(581, child_statistics[7].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(271, child_statistics[7].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(1346, child_statistics[8].visits());
|
||||
assert_eq!(798, child_statistics[8].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(399, child_statistics[8].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(914, child_statistics[0].visits());
|
||||
assert_eq!(361, child_statistics[0].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(184, child_statistics[0].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(1143, child_statistics[1].visits());
|
||||
assert_eq!(426, child_statistics[1].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(216, child_statistics[1].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(1039, child_statistics[2].visits());
|
||||
assert_eq!(305, child_statistics[2].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(154, child_statistics[2].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(770, child_statistics[3].visits());
|
||||
assert_eq!(258, child_statistics[3].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(125, child_statistics[3].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(3273, child_statistics[4].visits());
|
||||
assert_eq!(1154, child_statistics[4].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(384, child_statistics[4].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(705, child_statistics[5].visits());
|
||||
assert_eq!(261, child_statistics[5].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(155, child_statistics[5].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(648, child_statistics[6].visits());
|
||||
assert_eq!(270, child_statistics[6].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(162, child_statistics[6].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(795, child_statistics[7].visits());
|
||||
assert_eq!(225, child_statistics[7].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(128, child_statistics[7].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(712, child_statistics[8].visits());
|
||||
assert_eq!(212, child_statistics[8].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(168, child_statistics[8].score(crate::statistics::two_player::Player::Two));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn integration_test_best_parent_selector() {
|
||||
let mut graph = Graph::new();
|
||||
search_graph::view::of_graph(&mut graph, |view| {
|
||||
Graph::new().open(|mut graph| {
|
||||
let mut rollout = RolloutPhase::initialize(
|
||||
default_rng(),
|
||||
default_settings(),
|
||||
default_game_state(),
|
||||
view,
|
||||
&mut graph,
|
||||
);
|
||||
for _ in 0..10000 {
|
||||
let score = rollout.rollout::<ucb::Rollout>().unwrap();
|
||||
let backprop = score.score::<simulation::RandomSimulator>().unwrap();
|
||||
let expand = backprop.backprop::<ucb::BestParentBackprop>();
|
||||
rollout = expand.expand();
|
||||
let score = rollout.rollout::<_, ucb::Rollout>(&mut graph).unwrap();
|
||||
let backprop = score.score::<_, simulation::RandomSimulator>(&graph).unwrap();
|
||||
let expand = backprop.backprop::<ucb::BestParentBackprop>(&mut graph);
|
||||
rollout = expand.expand(&mut graph);
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(4471, graph.vertex_count());
|
||||
assert_eq!(10889, graph.edge_count());
|
||||
assert_eq!(5601, graph.vertex_count());
|
||||
assert_eq!(14471, graph.edge_count());
|
||||
|
||||
let child_statistics: Vec<&crate::statistics::two_player::ScoredStatistics<_>> =
|
||||
graph
|
||||
.find_node(&default_game_state())
|
||||
.unwrap()
|
||||
.get_child_list()
|
||||
.vertex(graph
|
||||
.find_vertex_with_state(&default_game_state())
|
||||
.unwrap())
|
||||
.children
|
||||
.iter()
|
||||
.map(|c| &c.get_data().statistics)
|
||||
.copied()
|
||||
.map(|c| &graph.edge(c).data.statistics)
|
||||
.collect();
|
||||
assert_eq!(229, child_statistics[0].visits());
|
||||
assert_eq!(138, child_statistics[0].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(53, child_statistics[0].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(73, child_statistics[1].visits());
|
||||
assert_eq!(33, child_statistics[1].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(30, child_statistics[1].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(322, child_statistics[2].visits());
|
||||
assert_eq!(203, child_statistics[2].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(83, child_statistics[2].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(90, child_statistics[3].visits());
|
||||
assert_eq!(44, child_statistics[3].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(32, child_statistics[3].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(9033, child_statistics[4].visits());
|
||||
assert_eq!(7693, child_statistics[4].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(945, child_statistics[4].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(74, child_statistics[5].visits());
|
||||
assert_eq!(34, child_statistics[5].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(35, child_statistics[5].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(118, child_statistics[6].visits());
|
||||
assert_eq!(62, child_statistics[6].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(43, child_statistics[6].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(98, child_statistics[7].visits());
|
||||
assert_eq!(49, child_statistics[7].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(36, child_statistics[7].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(180, child_statistics[8].visits());
|
||||
assert_eq!(104, child_statistics[8].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(41, child_statistics[8].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(888, child_statistics[0].visits());
|
||||
assert_eq!(325, child_statistics[0].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(225, child_statistics[0].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(330, child_statistics[1].visits());
|
||||
assert_eq!(138, child_statistics[1].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(123, child_statistics[1].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(918, child_statistics[2].visits());
|
||||
assert_eq!(340, child_statistics[2].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(236, child_statistics[2].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(792, child_statistics[3].visits());
|
||||
assert_eq!(323, child_statistics[3].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(239, child_statistics[3].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(3817, child_statistics[4].visits());
|
||||
assert_eq!(1054, child_statistics[4].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(423, child_statistics[4].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(657, child_statistics[5].visits());
|
||||
assert_eq!(235, child_statistics[5].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(172, child_statistics[5].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(727, child_statistics[6].visits());
|
||||
assert_eq!(246, child_statistics[6].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(172, child_statistics[6].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(758, child_statistics[7].visits());
|
||||
assert_eq!(270, child_statistics[7].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(192, child_statistics[7].score(crate::statistics::two_player::Player::Two));
|
||||
assert_eq!(1152, child_statistics[8].visits());
|
||||
assert_eq!(448, child_statistics[8].score(crate::statistics::two_player::Player::One));
|
||||
assert_eq!(304, child_statistics[8].score(crate::statistics::two_player::Player::Two));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ use std::fmt;
|
||||
use std::result::Result;
|
||||
|
||||
use rand::Rng;
|
||||
use search_graph::{EdgeRef, Graph, Open, VertexRef};
|
||||
|
||||
/// Error type for MCTS rollout.
|
||||
pub enum RolloutError<G: Game, E: Error> {
|
||||
@@ -75,10 +76,10 @@ pub trait RolloutSelector: for<'a> From<&'a SearchParameters> {
|
||||
/// Returns the element of `children` that should be followed, or an error.
|
||||
fn select<'a, 'id, G: Game, R: Rng>(
|
||||
&self,
|
||||
graph: &search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
parent: search_graph::view::NodeRef<'id>,
|
||||
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
parent: VertexRef<Open<'id>>,
|
||||
rng: &mut R,
|
||||
) -> Result<search_graph::view::EdgeRef<'id>, Self::Error>;
|
||||
) -> Result<EdgeRef<Open<'id>>, Self::Error>;
|
||||
}
|
||||
|
||||
/// Traverses the game graph downwards from `node` down to some terminating
|
||||
@@ -90,28 +91,28 @@ pub trait RolloutSelector: for<'a> From<&'a SearchParameters> {
|
||||
///
|
||||
/// Selection will be done minimax-style, i.e., always trying to maximize the
|
||||
/// score for the currently active player.
|
||||
pub fn rollout<'a, 'id, G, S, R>(
|
||||
graph: &search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
mut node: search_graph::view::NodeRef<'id>,
|
||||
pub fn rollout<'id, G, S, R>(
|
||||
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
mut node: VertexRef<Open<'id>>,
|
||||
selector: S,
|
||||
rng: &mut R,
|
||||
) -> Result<search_graph::view::NodeRef<'id>, RolloutError<G, S::Error>>
|
||||
) -> Result<VertexRef<Open<'id>>, RolloutError<G, S::Error>>
|
||||
where
|
||||
G: Game,
|
||||
S: RolloutSelector,
|
||||
R: Rng,
|
||||
{
|
||||
loop {
|
||||
if let Some(_) = G::payoff_of(graph.node_state(node)) {
|
||||
if let Some(_) = G::payoff_of(graph.game_state_at(node)) {
|
||||
// Hit known payoff.
|
||||
break;
|
||||
} else if graph.child_count(node) == 0 {
|
||||
} else if graph.vertex(node).children.is_empty() {
|
||||
// Hit leaf in search graph.
|
||||
break;
|
||||
} else {
|
||||
let child = selector.select(graph, node, rng)?;
|
||||
graph.edge_data(child).mark_rollout_traversal();
|
||||
node = graph.edge_target(child);
|
||||
graph.edge(child).data.mark_rollout_traversal();
|
||||
node = graph.edge(child).target;
|
||||
}
|
||||
}
|
||||
Ok(node)
|
||||
|
||||
@@ -158,11 +158,16 @@ impl<M: PlayerMapping> ScoredStatistics<M> {
|
||||
let visits = cmp::min(old_visits + 1, VISITS_MAX);
|
||||
let score_one = cmp::min(old_score_one + score_one, SCORE_MAX);
|
||||
let score_two = cmp::min(old_score_two + score_two, SCORE_MAX);
|
||||
success = self.packed.compare_and_swap(
|
||||
old_packed,
|
||||
pack_scores(visits, score_one, score_two),
|
||||
atomic::Ordering::SeqCst,
|
||||
) == old_packed;
|
||||
let cex = match self.packed.compare_exchange_weak(
|
||||
old_packed,
|
||||
pack_scores(visits, score_one, score_two),
|
||||
atomic::Ordering::SeqCst,
|
||||
atomic::Ordering::SeqCst,
|
||||
) {
|
||||
Ok(n) => n,
|
||||
Err(_) => continue,
|
||||
};
|
||||
success = cex == old_packed;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,11 +121,11 @@ impl game::State for State {
|
||||
&self.active_player
|
||||
}
|
||||
|
||||
fn actions<'s>(&'s self) -> Box<dyn Iterator<Item = Action> + 's> {
|
||||
Box::new(iterate![for row in 0..3;
|
||||
for column in 0..3;
|
||||
if self.board.get(row, column).is_none();
|
||||
yield Action { row, column, player: self.active_player, }])
|
||||
fn actions<'s>(&'s self) -> impl Iterator<Item=Self::Action> + 's {
|
||||
iterate![for row in 0..3;
|
||||
for column in 0..3;
|
||||
if self.board.get(row, column).is_none();
|
||||
yield Action { row, column, player: self.active_player, }]
|
||||
}
|
||||
|
||||
fn do_action(&mut self, action: &Action) {
|
||||
|
||||
70
src/ucb.rs
70
src/ucb.rs
@@ -6,7 +6,7 @@ use crate::graph::{EdgeData, VertexData};
|
||||
use crate::rollout::RolloutSelector;
|
||||
use log::{error, trace};
|
||||
use rand::Rng;
|
||||
use search_graph;
|
||||
use search_graph::{EdgeRef, Graph, Open, VertexRef};
|
||||
|
||||
use std::cmp::Ordering;
|
||||
use std::error::Error;
|
||||
@@ -18,9 +18,9 @@ use std::result::Result;
|
||||
pub enum UcbSuccess<'id> {
|
||||
/// No (finite) value can be computed, but the UCB policy indicates that
|
||||
/// this child should be selected. E.g., the child has not yet been visited.
|
||||
Select(search_graph::view::EdgeRef<'id>),
|
||||
Select(EdgeRef<Open<'id>>),
|
||||
/// A value is computed.
|
||||
Value(search_graph::view::EdgeRef<'id>, f64),
|
||||
Value(EdgeRef<Open<'id>>, f64),
|
||||
}
|
||||
|
||||
/// Represents error conditions when computing the UCB score for a child.
|
||||
@@ -63,15 +63,15 @@ impl Error for UcbError {
|
||||
fn edge_ucb_iter<'v, 'id, G>(
|
||||
log_parent_visits: f64,
|
||||
explore_bias: f64,
|
||||
graph: &'v search_graph::view::View<'_, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
edges: impl Iterator<Item = search_graph::view::EdgeRef<'id>> + 'v,
|
||||
graph: &'v Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
edges: impl Iterator<Item = EdgeRef<Open<'id>>> + 'v,
|
||||
)
|
||||
-> impl Iterator<Item = Result<UcbSuccess<'id>, UcbError>> + 'v
|
||||
where
|
||||
G: Game,
|
||||
{
|
||||
edges.map(move |e| {
|
||||
if graph.edge_data(e).statistics.visits() == 0 {
|
||||
if graph.edge(e).data.statistics.visits() == 0 {
|
||||
trace!("edge_ucb_iter selects unvisited action");
|
||||
Ok(UcbSuccess::Select(e))
|
||||
} else {
|
||||
@@ -93,16 +93,17 @@ where
|
||||
pub fn child_score<'a, 'id, G: Game>(
|
||||
log_parent_visits: f64,
|
||||
explore_bias: f64,
|
||||
graph: &search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
child: search_graph::view::EdgeRef<'id>,
|
||||
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
child: EdgeRef<Open<'id>>,
|
||||
) -> UcbSuccess<'id> {
|
||||
let statistics = &graph.edge_data(child).statistics;
|
||||
let statistics = &graph.edge(child).data.statistics;
|
||||
if statistics.visits() == 0 {
|
||||
UcbSuccess::Select(child)
|
||||
} else {
|
||||
let child_visits = statistics.visits() as f64;
|
||||
let vertex = graph.edge(child).source;
|
||||
let child_score =
|
||||
statistics.score(graph.node_state(graph.edge_source(child)).active_player()) as f64;
|
||||
statistics.score(graph.game_state_at(vertex).active_player()) as f64;
|
||||
let ucb =
|
||||
child_score / child_visits + explore_bias * f64::sqrt(log_parent_visits / child_visits);
|
||||
UcbSuccess::Value(child, ucb)
|
||||
@@ -122,23 +123,23 @@ pub fn child_score<'a, 'id, G: Game>(
|
||||
/// graph (not just a tree), we want to know all of the parent edges which could
|
||||
/// have rolled out to a given child.
|
||||
pub fn is_best_child<'a, 'id, G: Game>(
|
||||
graph: &search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
e: search_graph::view::EdgeRef<'id>,
|
||||
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
e: EdgeRef<Open<'id>>,
|
||||
explore_bias: f64,
|
||||
) -> bool {
|
||||
let statistics = &graph.edge_data(e).statistics;
|
||||
let statistics = &graph.edge(e).data.statistics;
|
||||
// trace!("is_best_child: edge {} has {} visits", e.get_id(), stats.visits);
|
||||
if statistics.visits() == 0 {
|
||||
// Edge has been visited, but statistics aren't yet updated.
|
||||
// trace!("is_best_child: edge {} is a best child because stats.visits == 0", e.get_id());
|
||||
return true;
|
||||
}
|
||||
let parent = graph.edge_source(e);
|
||||
let parent = graph.edge(e).source;
|
||||
// trace!("is_best_child: edge {} (from node {}) has {} siblings", e.get_id(), parent.get_id(), siblings.len());
|
||||
let log_parent_visits = {
|
||||
let mut parent_visits = 0;
|
||||
for child_edge in graph.children(parent) {
|
||||
parent_visits += graph.edge_data(child_edge).statistics.visits();
|
||||
for child_edge in &graph.vertex(parent).children {
|
||||
parent_visits += graph.edge(*child_edge).data.statistics.visits();
|
||||
}
|
||||
f64::ln(parent_visits as f64)
|
||||
};
|
||||
@@ -148,7 +149,7 @@ pub fn is_best_child<'a, 'id, G: Game>(
|
||||
log_parent_visits,
|
||||
explore_bias,
|
||||
graph,
|
||||
graph.children(parent),
|
||||
graph.vertex(parent).children.iter().copied(),
|
||||
);
|
||||
// Scan through siblings to find the maximum UCB score. This is
|
||||
// short-circuited using a lazy iterator to ameliorate the O(n) running
|
||||
@@ -204,19 +205,19 @@ pub fn is_best_child<'a, 'id, G: Game>(
|
||||
///
|
||||
/// This function will panic if `parent` has no children.
|
||||
pub fn find_best_child<'a, 'id, G, R>(
|
||||
graph: &search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
parent: search_graph::view::NodeRef<'id>,
|
||||
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
parent: VertexRef<Open<'id>>,
|
||||
explore_bias: f64,
|
||||
rng: &mut R,
|
||||
) -> Result<search_graph::view::EdgeRef<'id>, UcbError>
|
||||
) -> Result<EdgeRef<Open<'id>>, UcbError>
|
||||
where
|
||||
G: Game,
|
||||
R: Rng,
|
||||
{
|
||||
let log_parent_visits = {
|
||||
let mut parent_visits = 0;
|
||||
for child in graph.children(parent) {
|
||||
parent_visits += graph.edge_data(child).statistics.visits();
|
||||
for child in &graph.vertex(parent).children {
|
||||
parent_visits += graph.edge(*child).data.statistics.visits();
|
||||
}
|
||||
if parent_visits == 0 {
|
||||
// When we visit a vertex for the first time, it will have zero visits.
|
||||
@@ -231,9 +232,9 @@ where
|
||||
log_parent_visits,
|
||||
explore_bias,
|
||||
graph,
|
||||
graph.children(parent),
|
||||
graph.vertex(parent).children.iter().copied(),
|
||||
);
|
||||
let mut best: search_graph::view::EdgeRef<'id>;
|
||||
let mut best: EdgeRef<Open<'id>>;
|
||||
let mut best_ucb: f64;
|
||||
match ucb_iter.next().expect("vertex has no children")? {
|
||||
UcbSuccess::Select(e) => {
|
||||
@@ -301,10 +302,10 @@ impl RolloutSelector for Rollout {
|
||||
|
||||
fn select<'a, 'id, G: Game, R: Rng>(
|
||||
&self,
|
||||
graph: &search_graph::view::View<'a, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
parent: search_graph::view::NodeRef<'id>,
|
||||
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
parent: VertexRef<Open<'id>>,
|
||||
rng: &mut R,
|
||||
) -> Result<search_graph::view::EdgeRef<'id>, UcbError> {
|
||||
) -> Result<EdgeRef<Open<'id>>, UcbError> {
|
||||
find_best_child(graph, parent, self.explore_bias, rng)
|
||||
}
|
||||
}
|
||||
@@ -330,18 +331,21 @@ impl<'id> BackpropSelector<'id> for BestParentBackprop {
|
||||
// ATCs/HKTs. We need ATC support because this iterator type will have its
|
||||
// lifetime constrained by the borrow of `graph` in the select method, but
|
||||
// that lifetime isn't known statically.
|
||||
type Items = std::vec::IntoIter<search_graph::view::EdgeRef<'id>>;
|
||||
type Items = std::vec::IntoIter<EdgeRef<Open<'id>>>;
|
||||
|
||||
fn select<G: Game, R: Rng>(
|
||||
&self,
|
||||
graph: &search_graph::view::View<'_, 'id, G::State, VertexData, EdgeData<G>>,
|
||||
node: search_graph::view::NodeRef<'id>,
|
||||
graph: &Graph<Open<'id>, G::State, VertexData, EdgeData<G>>,
|
||||
node: VertexRef<Open<'id>>,
|
||||
_payoff: &G::Payoff,
|
||||
_rng: &mut R,
|
||||
) -> Self::Items {
|
||||
let result: Vec<search_graph::view::EdgeRef<'id>> = graph
|
||||
.parents(node)
|
||||
.filter(|&parent_edge| is_best_child(graph, parent_edge, self.explore_bias))
|
||||
let result: Vec<EdgeRef<Open<'id>>> = graph
|
||||
.vertex(node)
|
||||
.parents
|
||||
.iter()
|
||||
.filter(|&parent_edge| is_best_child(graph, *parent_edge, self.explore_bias))
|
||||
.copied()
|
||||
.collect();
|
||||
result.into_iter()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user