purrchance/src/lib.rs

108 lines
3.2 KiB
Rust

extern crate rand;
use rand::{seq::SliceRandom, thread_rng};
use std::collections::HashMap;
/// Something that can be expanded into text with the help of a [`Grammar`].
pub trait Purrchance {
    /// Expands `self` into a string, using `g` to resolve anything it refers to.
    ///
    /// Returns `None` when the expansion cannot be produced.
    fn eval(&self, g: &Grammar) -> Option<String>;
}
/// A single element of an expression.
///
/// Derives the common value-type traits so symbols can be debug-printed,
/// cloned and compared (e.g. in assertions).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Symbol {
    /// Literal text; evaluates to the contained string.
    Terminal(String),
    /// The name of a grammar rule; evaluates to whatever that rule produces.
    NonTerminal(String),
}
impl Purrchance for Symbol {
    /// Evaluates a single symbol.
    ///
    /// * `Terminal` yields a copy of its text.
    /// * `NonTerminal` looks its name up in `g` and evaluates the rule found
    ///   there; an unknown name yields `None`.
    ///
    /// There is no recursion guard here: a rule that refers back to itself
    /// may recurse without bound.
    fn eval(&self, g: &Grammar) -> Option<String> {
        match self {
            Symbol::Terminal(text) => Some(text.clone()),
            Symbol::NonTerminal(name) => g.0.get(name).and_then(|rule| rule.eval(g)),
        }
    }
}
/// An ordered sequence of [`Symbol`]s; evaluating it concatenates the text
/// produced by each symbol.
pub struct Expr(Vec<Symbol>);
impl Purrchance for Expr {
    /// Evaluates every symbol in order and concatenates the results.
    ///
    /// Returns `None` as soon as any symbol evaluates to `None`. An empty
    /// expression yields `Some` of the empty string.
    fn eval(&self, g: &Grammar) -> Option<String> {
        // Collecting `Option<String>` items into an `Option<String>` both
        // short-circuits on the first `None` and concatenates the pieces, so
        // no intermediate `Vec<String>` + `join("")` is needed.
        self.0.iter().map(|sym| sym.eval(g)).collect()
    }
}
/// A collection of alternative [`Expr`]s, each paired with the `f64` weight
/// used when one of them is picked at random.
pub struct List(Vec<(Expr, f64)>);
impl Purrchance for List {
    /// Picks one entry at random, weighted by each entry's `f64`, and
    /// evaluates its expression.
    ///
    /// Returns `None` if `choose_weighted` reports an error (the error value
    /// is discarded) or if the chosen expression evaluates to `None`.
    fn eval(&self, g: &Grammar) -> Option<String> {
        let mut rng = thread_rng();
        let picked = self.0.choose_weighted(&mut rng, |entry| entry.1).ok()?;
        let expr = &picked.0;
        expr.eval(g)
    }
}
/// A set of named rules: maps each rule name to the [`List`] it expands to.
// NOTE(review): the inner map is private and no public constructor is visible
// in this file, so a `Grammar` can apparently only be built from inside this
// module — confirm whether that is intended.
pub struct Grammar(HashMap<String, List>);
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a terminal symbol holding `text`.
    fn terminal(text: &str) -> Symbol {
        Symbol::Terminal(String::from(text))
    }

    /// Builds a grammar that contains no rules.
    fn empty_grammar() -> Grammar {
        Grammar(HashMap::new())
    }

    /// Builds a grammar whose single rule `output` has one alternative,
    /// the terminal "hello world".
    fn output_grammar() -> Grammar {
        let rule = List(vec![(Expr(vec![terminal("hello world")]), 1.0)]);
        let mut grammar = empty_grammar();
        grammar.0.insert(String::from("output"), rule);
        grammar
    }

    #[test]
    fn eval_terminal() {
        let grammar = empty_grammar();
        let result = terminal("hello world").eval(&grammar);
        assert_eq!(result, Some(String::from("hello world")));
    }

    #[test]
    fn eval_terminals_expr() {
        let grammar = empty_grammar();
        let expr = Expr(vec![terminal("hell"), terminal("o world")]);
        assert_eq!(expr.eval(&grammar), Some(String::from("hello world")));
    }

    #[test]
    fn eval_single_terminals_expr_list() {
        let grammar = empty_grammar();
        let expr = Expr(vec![terminal("hell"), terminal("o world")]);
        let list = List(vec![(expr, 1.0)]);
        assert_eq!(list.eval(&grammar), Some(String::from("hello world")));
    }

    #[test]
    fn eval_multiple_terminals_expr_list() {
        let grammar = empty_grammar();
        let list = List(vec![
            (Expr(vec![terminal("hello")]), 1.0),
            (Expr(vec![terminal("goodbye")]), 1.0),
        ]);
        // Either alternative is acceptable; the pick is random.
        let acceptable = [
            Some(String::from("hello")),
            Some(String::from("goodbye")),
        ];
        let outcome = list.eval(&grammar);
        assert!(acceptable.contains(&outcome));
    }

    #[test]
    fn eval_empty_list() {
        let grammar = empty_grammar();
        let list = List(vec![]);
        assert_eq!(list.eval(&grammar), None);
    }

    #[test]
    fn eval_valid_nonterminal() {
        let grammar = output_grammar();
        let symbol = Symbol::NonTerminal(String::from("output"));
        assert_eq!(symbol.eval(&grammar), Some(String::from("hello world")));
    }

    #[test]
    fn eval_missing_nonterminal() {
        let grammar = output_grammar();
        let symbol = Symbol::NonTerminal(String::from("missing"));
        assert_eq!(symbol.eval(&grammar), None);
    }
}