Add probability weighting
This commit is contained in:
parent
8a4963d5e9
commit
36e33d59cb
|
@ -7,7 +7,7 @@ Purrchance is an unofficial Rust implementation of the
|
|||
|
||||
- [ ] Parsing grammars from text format
|
||||
- [x] Basic lists
|
||||
- [ ] Probability weights
|
||||
- [x] Probability weights
|
||||
- [ ] Single-item lists
|
||||
- [ ] Escape sequences
|
||||
- [ ] Shorthand lists
|
||||
|
|
12
src/lib.rs
12
src/lib.rs
|
@ -29,11 +29,11 @@ impl Purrchance for Expr {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct List(Vec<Expr>);
|
||||
pub struct List(Vec<(Expr, f64)>);
|
||||
|
||||
impl Purrchance for List {
|
||||
fn eval(&self, g: &Grammar) -> Option<String> {
|
||||
self.0.choose(&mut thread_rng())?.eval(g)
|
||||
self.0.choose_weighted(&mut thread_rng(), |item| item.1).ok()?.0.eval(g)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -64,7 +64,7 @@ mod tests {
|
|||
let sym1 = Symbol::Terminal("hell".to_string());
|
||||
let sym2 = Symbol::Terminal("o world".to_string());
|
||||
let expr = Expr(vec![sym1, sym2]);
|
||||
let list = List(vec![expr]);
|
||||
let list = List(vec![(expr, 1.0)]);
|
||||
let g = Grammar(HashMap::new());
|
||||
assert_eq!(list.eval(&g), Some("hello world".to_string()));
|
||||
}
|
||||
|
@ -73,7 +73,7 @@ mod tests {
|
|||
fn eval_multiple_terminals_expr_list() {
|
||||
let sym1 = Symbol::Terminal("hello".to_string());
|
||||
let sym2 = Symbol::Terminal("goodbye".to_string());
|
||||
let list = List(vec![Expr(vec![sym1]), Expr(vec![sym2])]);
|
||||
let list = List(vec![(Expr(vec![sym1]), 1.0), (Expr(vec![sym2]), 1.0)]);
|
||||
let g = Grammar(HashMap::new());
|
||||
assert!(vec![Some("hello".to_string()), Some("goodbye".to_string())].contains(&list.eval(&g)));
|
||||
}
|
||||
|
@ -88,7 +88,7 @@ mod tests {
|
|||
#[test]
|
||||
fn eval_valid_nonterminal() {
|
||||
let term = Symbol::Terminal("hello world".to_string());
|
||||
let list = List(vec![Expr(vec![term])]);
|
||||
let list = List(vec![(Expr(vec![term]), 1.0)]);
|
||||
let mut g = Grammar(HashMap::new());
|
||||
g.0.insert("output".to_string(), list);
|
||||
let nt = Symbol::NonTerminal("output".to_string());
|
||||
|
@ -98,7 +98,7 @@ mod tests {
|
|||
#[test]
|
||||
fn eval_missing_nonterminal() {
|
||||
let term = Symbol::Terminal("hello world".to_string());
|
||||
let list = List(vec![Expr(vec![term])]);
|
||||
let list = List(vec![(Expr(vec![term]), 1.0)]);
|
||||
let mut g = Grammar(HashMap::new());
|
||||
g.0.insert("output".to_string(), list);
|
||||
let nt = Symbol::NonTerminal("missing".to_string());
|
||||
|
|
Loading…
Reference in a new issue