Add probability weighting

This commit is contained in:
xenofem 2020-06-18 21:21:39 -04:00
parent 8a4963d5e9
commit 36e33d59cb
2 changed files with 7 additions and 7 deletions

View file

@ -7,7 +7,7 @@ Purrchance is an unofficial Rust implementation of the
- [ ] Parsing grammars from text format - [ ] Parsing grammars from text format
- [x] Basic lists - [x] Basic lists
- [ ] Probability weights - [x] Probability weights
- [ ] Single-item lists - [ ] Single-item lists
- [ ] Escape sequences - [ ] Escape sequences
- [ ] Shorthand lists - [ ] Shorthand lists

View file

@ -29,11 +29,11 @@ impl Purrchance for Expr {
} }
} }
pub struct List(Vec<Expr>); pub struct List(Vec<(Expr, f64)>);
impl Purrchance for List { impl Purrchance for List {
fn eval(&self, g: &Grammar) -> Option<String> { fn eval(&self, g: &Grammar) -> Option<String> {
self.0.choose(&mut thread_rng())?.eval(g) self.0.choose_weighted(&mut thread_rng(), |item| item.1).ok()?.0.eval(g)
} }
} }
@ -64,7 +64,7 @@ mod tests {
let sym1 = Symbol::Terminal("hell".to_string()); let sym1 = Symbol::Terminal("hell".to_string());
let sym2 = Symbol::Terminal("o world".to_string()); let sym2 = Symbol::Terminal("o world".to_string());
let expr = Expr(vec![sym1, sym2]); let expr = Expr(vec![sym1, sym2]);
let list = List(vec![expr]); let list = List(vec![(expr, 1.0)]);
let g = Grammar(HashMap::new()); let g = Grammar(HashMap::new());
assert_eq!(list.eval(&g), Some("hello world".to_string())); assert_eq!(list.eval(&g), Some("hello world".to_string()));
} }
@ -73,7 +73,7 @@ mod tests {
fn eval_multiple_terminals_expr_list() { fn eval_multiple_terminals_expr_list() {
let sym1 = Symbol::Terminal("hello".to_string()); let sym1 = Symbol::Terminal("hello".to_string());
let sym2 = Symbol::Terminal("goodbye".to_string()); let sym2 = Symbol::Terminal("goodbye".to_string());
let list = List(vec![Expr(vec![sym1]), Expr(vec![sym2])]); let list = List(vec![(Expr(vec![sym1]), 1.0), (Expr(vec![sym2]), 1.0)]);
let g = Grammar(HashMap::new()); let g = Grammar(HashMap::new());
assert!(vec![Some("hello".to_string()), Some("goodbye".to_string())].contains(&list.eval(&g))); assert!(vec![Some("hello".to_string()), Some("goodbye".to_string())].contains(&list.eval(&g)));
} }
@ -88,7 +88,7 @@ mod tests {
#[test] #[test]
fn eval_valid_nonterminal() { fn eval_valid_nonterminal() {
let term = Symbol::Terminal("hello world".to_string()); let term = Symbol::Terminal("hello world".to_string());
let list = List(vec![Expr(vec![term])]); let list = List(vec![(Expr(vec![term]), 1.0)]);
let mut g = Grammar(HashMap::new()); let mut g = Grammar(HashMap::new());
g.0.insert("output".to_string(), list); g.0.insert("output".to_string(), list);
let nt = Symbol::NonTerminal("output".to_string()); let nt = Symbol::NonTerminal("output".to_string());
@ -98,7 +98,7 @@ mod tests {
#[test] #[test]
fn eval_missing_nonterminal() { fn eval_missing_nonterminal() {
let term = Symbol::Terminal("hello world".to_string()); let term = Symbol::Terminal("hello world".to_string());
let list = List(vec![Expr(vec![term])]); let list = List(vec![(Expr(vec![term]), 1.0)]);
let mut g = Grammar(HashMap::new()); let mut g = Grammar(HashMap::new());
g.0.insert("output".to_string(), list); g.0.insert("output".to_string(), list);
let nt = Symbol::NonTerminal("missing".to_string()); let nt = Symbol::NonTerminal("missing".to_string());