API Reference

Every public function in HMMRust. Each section gives the signature, what it returns, what to watch out for, and a usage snippet.


Hmm struct

pub struct Hmm {
    pub n_states: usize,
    pub n_symbols: usize,
    pub pi: Array1<f64>,    // initial state distribution, shape [N]
    pub a:  Array2<f64>,    // transition matrix, shape [N × N]
    pub b:  Array2<f64>,    // emission matrix, shape [N × M]
}

All fields are public so you can inspect and hand‑tune parameters after training (or before, if you know what you're doing).

Hmm::new(n_states, n_symbols, rng) → Self

Random initialisation. Every row of A and B is sampled from a flat Dirichlet — uniform over the probability simplex. pi is sampled the same way.

let mut rng = rand::thread_rng();
let hmm = Hmm::new(3, 10, &mut rng);

Dirichlet sampling means every parameter is guaranteed to be a valid probability distribution (non‑negative, rows sum to 1). But it also means the model starts out maximally uninformative — it has no prior structure. Baum‑Welch has to discover it from data alone. For hard problems, consider multiple restarts.
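As an illustration of what "flat Dirichlet" means, here is a minimal std-only sketch. It is not HMMRust's implementation: the function name `flat_dirichlet` is hypothetical, and the uniform draws are passed in as a slice instead of coming from an RNG. It relies on the standard fact that normalising independent Exp(1) draws gives a point that is uniform on the probability simplex.

```rust
/// Turn uniform draws in [0, 1) into one sample from a flat Dirichlet.
/// Hypothetical helper; `uniforms` stands in for an RNG.
fn flat_dirichlet(uniforms: &[f64]) -> Vec<f64> {
    // Inverse-CDF transform: -ln(1 - u) is Exp(1) when u is uniform on [0, 1).
    let draws: Vec<f64> = uniforms.iter().map(|&u| -(1.0 - u).ln()).collect();
    // Dividing by the total puts the vector on the simplex.
    // (If every draw were exactly 0 the total would be 0; a real sampler
    // would redraw in that case.)
    let total: f64 = draws.iter().sum();
    draws.iter().map(|e| e / total).collect()
}
```

Calling `flat_dirichlet(&[0.1, 0.5, 0.9])` yields three non-negative weights that sum to 1, which is the property the paragraph above depends on.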

Hmm::from_params(pi, a, b) → Result<Self, String>

Construct from explicit parameters. M (n_symbols) is inferred from b.ncols(). Validates that:

  • pi sums to 1 (± 1e-10)
  • a has shape [N, N] and every row sums to 1
  • b has shape [N, M] and every row sums to 1

Returns Err(String) describing which check failed.

let pi = Array1::from_vec(vec![0.5, 0.5]);
let a  = Array2::from_shape_vec((2,2), vec![0.7, 0.3, 0.4, 0.6]).unwrap();
let b  = Array2::from_shape_vec((2,3), vec![0.5, 0.3, 0.2, 0.2, 0.4, 0.4]).unwrap();

let hmm = Hmm::from_params(pi, a, b)
    .expect("parameters should be valid");

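If you need a different tolerance or custom validation, construct the struct directly. No checks run on this path, so invalid parameters go undetected until an algorithm misbehaves: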

let hmm = Hmm {
    n_states: 2,
    n_symbols: 3,
    pi: my_pi,
    a:  my_a,
    b:  my_b,
};
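If you bypass from_params, you may want a check of your own. The sketch below is one way to do it, assuming plain `Vec<Vec<f64>>` rows in place of `Array2<f64>` so that it runs without dependencies. The helper name `check_row_stochastic` is hypothetical and not part of HMMRust.

```rust
/// Check that every row of `m` is a probability distribution within `tol`.
/// Hypothetical helper for custom validation.
fn check_row_stochastic(name: &str, m: &[Vec<f64>], tol: f64) -> Result<(), String> {
    for (i, row) in m.iter().enumerate() {
        // `!(p >= 0.0)` is true for negative entries and for NaN.
        if let Some(p) = row.iter().find(|p| !(**p >= 0.0)) {
            return Err(format!("{}: row {} has invalid entry {}", name, i, p));
        }
        let sum: f64 = row.iter().sum();
        if (sum - 1.0).abs() > tol {
            return Err(format!("{}: row {} sums to {}, expected 1", name, i, sum));
        }
    }
    Ok(())
}
```

For example, `check_row_stochastic("a", &rows, 1e-6)` accepts looser rounding than the 1e-10 used by from_params.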

Hmm::generate(&self, rng, len) → (Vec<usize>, Vec<usize>)

Sample a hidden state sequence and corresponding observation sequence from the model. Returns (states, observations) — both Vec<usize> of length len.

let (hidden, observed) = hmm.generate(&mut rng, 500);
assert_eq!(hidden.len(), 500);
assert_eq!(observed.len(), 500);

Uses rand::distributions::WeightedIndex internally. Panics if any probability row in A or B has all zeros — which shouldn't happen with valid params.
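To make the sampling procedure concrete, here is a std-only sketch of the same idea. It is an illustration under stated assumptions, not the library's code: it uses an inverse-CDF lookup where HMMRust uses WeightedIndex, takes `Vec`-based parameters, and receives randomness through a hypothetical `uniform` closure that returns draws in [0, 1).

```rust
/// Pick an index from a categorical distribution by inverse-CDF lookup.
/// `u` must be a uniform draw in [0, 1).
fn sample_index(probs: &[f64], u: f64) -> usize {
    let mut cumulative = 0.0;
    for (i, &p) in probs.iter().enumerate() {
        cumulative += p;
        if u < cumulative {
            return i;
        }
    }
    probs.len() - 1 // guard against rounding when u is very close to 1
}

/// Sample `len` hidden states and observations.
/// `uniform` stands in for the `rng` argument of `Hmm::generate`.
fn generate(
    pi: &[f64],
    a: &[Vec<f64>],
    b: &[Vec<f64>],
    len: usize,
    mut uniform: impl FnMut() -> f64,
) -> (Vec<usize>, Vec<usize>) {
    let mut states: Vec<usize> = Vec::with_capacity(len);
    let mut observations: Vec<usize> = Vec::with_capacity(len);
    for t in 0..len {
        // First state comes from pi, later states from the previous state's row of A.
        let state = if t == 0 {
            sample_index(pi, uniform())
        } else {
            sample_index(&a[states[t - 1]], uniform())
        };
        // The observation is drawn from the current state's row of B.
        observations.push(sample_index(&b[state], uniform()));
        states.push(state);
    }
    (states, observations)
}
```

Each time step consumes two draws: one for the state and one for the emitted symbol.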


Algorithms

All algorithm functions live in hmmrust::algorithms. They're free functions, not methods — keeps the Hmm struct clean and lets algorithms live in a separate module.

Import everything:

use hmmrust::algorithms;

Or cherry‑pick:

use hmmrust::algorithms::{forward, viterbi, baum_welch};

forward(&hmm, observations) → (Array2<f64>, Vec<f64>)

Scaled forward algorithm. Returns:

Return  Type         Shape    Meaning
alpha   Array2<f64>  [T × N]  scaled forward probabilities
scale   Vec<f64>     [T]      scaling factors at each time step

Each row of alpha sums to 1.0 up to floating-point rounding. The scaling factors c_t are needed by backward() and log_likelihood(). Without scaling, alpha values shrink exponentially with sequence length and eventually underflow f64; how soon depends on the model, but a few hundred steps is typically enough.

let obs = vec![0, 1, 0, 0, 1];
let (alpha, scale) = algorithms::forward(&hmm, &obs);

// alpha[(t, i)] = P(state = i at t | observations up to t)
let row_sum: f64 = alpha.row(3).sum();
assert!((row_sum - 1.0).abs() < 1e-10);

Edge case: if an observation is out of range (≥ hmm.n_symbols), the emission probability is looked up normally — which means it panics with an ndarray index‑out‑of‑bounds error. Validate your input.
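The scaling scheme can be sketched in a few lines. This version is an illustration only: it uses `Vec<Vec<f64>>` in place of ndarray, and it assumes the convention c_t = 1 / (sum of the unnormalised row at time t), which is the convention implied by ll = -Σ_t ln(c_t) in log_likelihood().

```rust
/// Scaled forward pass. Returns (alpha, scale) where each alpha row sums to 1
/// and scale[t] = 1 / (sum of the unnormalised row at time t).
fn forward(
    pi: &[f64],
    a: &[Vec<f64>],
    b: &[Vec<f64>],
    obs: &[usize],
) -> (Vec<Vec<f64>>, Vec<f64>) {
    let n = pi.len();
    let mut alpha: Vec<Vec<f64>> = Vec::with_capacity(obs.len());
    let mut scale: Vec<f64> = Vec::with_capacity(obs.len());
    for (t, &o) in obs.iter().enumerate() {
        let mut row: Vec<f64> = (0..n)
            .map(|j| {
                // Probability of arriving in state j before seeing o.
                let prior = if t == 0 {
                    pi[j]
                } else {
                    let prev: &Vec<f64> = &alpha[t - 1];
                    (0..n).map(|i| prev[i] * a[i][j]).sum::<f64>()
                };
                prior * b[j][o] // panics here if o >= n_symbols
            })
            .collect();
        // Rescale so the row sums to 1, and remember the factor.
        let c = 1.0 / row.iter().sum::<f64>();
        row.iter_mut().for_each(|x| *x *= c);
        alpha.push(row);
        scale.push(c);
    }
    (alpha, scale)
}
```

Because every row is renormalised, the stored values stay in a comfortable range no matter how long the sequence is; the magnitude information lives in `scale`.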


backward(&hmm, observations, scale) → Array2<f64>

Scaled backward algorithm. Takes the same scale vector returned by forward().

Return  Type         Shape    Meaning
beta    Array2<f64>  [T × N]  scaled backward probabilities

Uses the same scaling factors as the forward pass. This is critical — if you pass different scale values, gamma and xi computations in Baum‑Welch will be wrong.

let (alpha, scale) = algorithms::forward(&hmm, &obs);
let beta = algorithms::backward(&hmm, &obs, &scale);

// beta[(T-1, i)] = scale[T-1]  (by construction)
for i in 0..hmm.n_states {
    assert!((beta[(obs.len()-1, i)] - scale[obs.len()-1]).abs() < 1e-10);
}
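For reference, a std-only sketch of the backward recursion with shared scaling factors. It is an assumption-laden illustration, not the library source: parameters are `Vec`-based, and the last row is initialised to scale[T-1], matching the assertion above.

```rust
/// Scaled backward pass that reuses the forward scaling factors.
fn backward(a: &[Vec<f64>], b: &[Vec<f64>], obs: &[usize], scale: &[f64]) -> Vec<Vec<f64>> {
    let n = a.len();
    let t_len = obs.len();
    let mut beta = vec![vec![0.0; n]; t_len];
    if t_len == 0 {
        return beta;
    }
    // Base case: the unscaled value is 1, so the scaled value is scale[T-1].
    beta[t_len - 1] = vec![scale[t_len - 1]; n];
    // Walk backwards, scaling step t with the same factor the forward pass used.
    for t in (0..t_len - 1).rev() {
        for i in 0..n {
            let s: f64 = (0..n)
                .map(|j| a[i][j] * b[j][obs[t + 1]] * beta[t + 1][j])
                .sum();
            beta[t][i] = scale[t] * s;
        }
    }
    beta
}
```

With this convention the sum over states of alpha times beta at time t equals scale[t], which is why mismatched scale vectors corrupt gamma and xi.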

log_likelihood(scale) → f64

Extract log P(observations | model) from forward scaling factors.

let (_, scale) = algorithms::forward(&hmm, &obs);
let ll = algorithms::log_likelihood(&scale);
// ll = -Σ_t ln(c_t)

The log-likelihood is always ≤ 0, because the probability of a discrete observation sequence is at most 1. It reaches 0 only when the model assigns probability 1 to the observed sequence. For long real-world sequences, expect large negative numbers (-1000, -5000, etc.).
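The identity behind this function can be written out as follows, assuming the scaling convention c_t = 1 / (sum of the unnormalised forward row at time t). The worked numbers use the two-state parameters from the from_params example with observations (0, 1).

```latex
% Product of the scaling factors is the reciprocal of the sequence probability:
\prod_{t=1}^{T} c_t = \frac{1}{P(O \mid \lambda)}
\quad\Longrightarrow\quad
\ln P(O \mid \lambda) = -\sum_{t=1}^{T} \ln c_t

% Worked example, O = (0, 1):
%   unnormalised row sums: 0.35 and 0.1185 / 0.35 \approx 0.3386
%   c_1 = 1 / 0.35, \quad c_2 \approx 1 / 0.3386
P(O \mid \lambda) = 0.25 \cdot 0.33 + 0.1 \cdot 0.36 = 0.1185,
\qquad
\ln P(O \mid \lambda) \approx -2.13
```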


viterbi(&hmm, observations) → Vec<usize>

Log‑space Viterbi decoding. Returns the single most likely hidden state sequence.

let path = algorithms::viterbi(&hmm, &obs);
assert_eq!(path.len(), obs.len());

for &state in &path {
    assert!(state < hmm.n_states);
}

Computes log(pi), log(A), and log(B) up front. Zeros in any probability are mapped to f64::NEG_INFINITY so they can't win the max. Runs in O(T·N²) time and O(T·N) memory.

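Why log-space? In a 500-step sequence with 5 states, if each transition and emission probability is around 0.2, the raw probability of any single path is roughly 0.2^(2·500) ≈ 10⁻⁷⁰⁰. That is far below the smallest positive f64 (about 5 × 10⁻³²⁴), so the product would underflow to 0. Adding logs instead of multiplying keeps everything well-behaved.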
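A compact std-only sketch of log-space Viterbi follows. It illustrates the technique described above and is not the library's implementation; parameters are `Vec`-based for self-containment. It relies on `0.0_f64.ln()` being negative infinity, so impossible transitions never win the max.

```rust
/// Log-space Viterbi: most likely hidden state path for `obs`.
fn viterbi(pi: &[f64], a: &[Vec<f64>], b: &[Vec<f64>], obs: &[usize]) -> Vec<usize> {
    let n = pi.len();
    if obs.is_empty() {
        return Vec::new();
    }
    // delta[i] = best log-probability of any path ending in state i.
    let mut delta: Vec<f64> = (0..n).map(|i| pi[i].ln() + b[i][obs[0]].ln()).collect();
    let mut back: Vec<Vec<usize>> = Vec::with_capacity(obs.len());
    for &o in &obs[1..] {
        let mut next = vec![f64::NEG_INFINITY; n];
        let mut ptr = vec![0usize; n];
        for j in 0..n {
            for i in 0..n {
                let score = delta[i] + a[i][j].ln();
                if score > next[j] {
                    next[j] = score;
                    ptr[j] = i; // remember the best predecessor of j
                }
            }
            next[j] += b[j][o].ln();
        }
        back.push(ptr);
        delta = next;
    }
    // Start from the best final state and follow the back-pointers.
    let mut state = 0;
    for i in 1..n {
        if delta[i] > delta[state] {
            state = i;
        }
    }
    let mut path = vec![state];
    for ptr in back.iter().rev() {
        state = ptr[state];
        path.push(state);
    }
    path.reverse();
    path
}
```

The two nested loops over states inside the loop over time give the O(T·N²) running time, and the stored back-pointers account for the O(T·N) memory.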


smooth(&hmm, observations) → Array2<f64>

Forward‑backward smoothing. Returns gamma: the posterior probability of each state at each time step given ALL observations.

Return  Type         Shape    Meaning
gamma   Array2<f64>  [T × N]  γ_t(i) = P(state = i at t | all observations)

let gamma = algorithms::smooth(&hmm, &obs);

// Each row sums to 1
for t in 0..obs.len() {
    let sum: f64 = gamma.row(t).sum();
    assert!((sum - 1.0).abs() < 1e-10);
}

// gamma[(42, 1)] = probability the HMM was in state 1 at time step 42
// (requires obs.len() > 42, otherwise the index panics)
println!("P(state=1 at t=42) = {:.3}", gamma[(42, 1)]);

Internally calls forward() + backward(), then normalises their product row by row. If the product of alpha and beta is zero for every state at some time step (degenerate model), that row stays all zeros and does not sum to 1.
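The normalisation step can be sketched on its own. This is an illustration with a hypothetical function name, `posterior`, operating on `Vec`-based alpha and beta tables that are assumed to come from a matching forward and backward pass.

```rust
/// Combine forward and backward tables into per-step state posteriors.
/// Rows whose product is all zero are left as zeros.
fn posterior(alpha: &[Vec<f64>], beta: &[Vec<f64>]) -> Vec<Vec<f64>> {
    alpha
        .iter()
        .zip(beta)
        .map(|(a_row, b_row)| {
            // Elementwise product is proportional to P(state = i at t | all observations).
            let mut row: Vec<f64> = a_row.iter().zip(b_row).map(|(x, y)| x * y).collect();
            let total: f64 = row.iter().sum();
            if total > 0.0 {
                row.iter_mut().for_each(|g| *g /= total);
            }
            row
        })
        .collect()
}
```

Because each row is divided by its own total, any common scaling of alpha or beta within a row cancels out.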


baum_welch(&mut hmm, sequences, max_iter, tol) → f64

Baum‑Welch expectation‑maximization training. Modifies hmm in place.

Param      Type           Meaning
hmm        &mut Hmm       starting parameters (overwritten)
sequences  &[Vec<usize>]  training data: one or more observation sequences
max_iter   usize          maximum EM iterations
tol        f64            stop when log-likelihood improvement < tol

Returns the final log‑likelihood (summed across all sequences).

let train_data = vec![
    vec![0, 1, 0, 2, 1, 0],
    vec![1, 0, 0, 1, 2, 1],
];
let final_ll = algorithms::baum_welch(&mut hmm, &train_data, 100, 1e-5);
println!("trained model ll: {:.2}", final_ll);

What happens inside:

  1. For each sequence: run forward, run backward, compute gamma, compute xi
  2. Accumulate expected counts across all sequences
  3. Re‑estimate pi, A, B from expected counts
  4. Check convergence (LL delta < tol)
  5. If not converged and max_iter not hit, go to 1
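Steps 4 and 5 amount to a small driver loop. The sketch below shows one plausible shape for it; `run_em` and its `em_step` closure are hypothetical names, and the exact stopping rule inside HMMRust may differ.

```rust
/// Drive an EM loop. `em_step` is a hypothetical closure that performs one
/// E-step plus M-step and returns the log-likelihood it measured.
/// Returns (last log-likelihood, iterations run).
fn run_em(mut em_step: impl FnMut() -> f64, max_iter: usize, tol: f64) -> (f64, usize) {
    let mut prev = f64::NEG_INFINITY;
    for iter in 1..=max_iter {
        let ll = em_step();
        // Converged once the improvement over the previous iteration is below tol.
        if ll - prev < tol {
            return (ll, iter);
        }
        prev = ll;
    }
    // Iteration cap reached without meeting the tolerance.
    (prev, max_iter)
}
```

Starting `prev` at negative infinity guarantees that the first iteration never triggers the convergence test.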

Gotchas:

  • All sequences must use the same observation space (0..n_symbols). Passing a value ≥ n_symbols panics.
  • A single short sequence (e.g. 10 observations with 5 hidden states) won't give useful results. There simply isn't enough signal.
  • The algorithm is NOT guaranteed to find the global optimum. It converges to a local maximum of the likelihood surface. Run multiple restarts.
  • If a row in accum_a or accum_b has zero total expected counts, that row is replaced with a uniform distribution as a fallback. This can happen if the model never assigns significant probability to a particular state.
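The uniform fallback in the last bullet can be sketched as a row normalisation over accumulated expected counts. This is an illustrative stand-in with a hypothetical name, `normalise_rows`, using `Vec<Vec<f64>>` in place of the library's accumulators.

```rust
/// Turn accumulated expected counts into probability rows.
/// A row with zero total count falls back to a uniform distribution.
fn normalise_rows(counts: &[Vec<f64>]) -> Vec<Vec<f64>> {
    counts
        .iter()
        .map(|row| {
            let total: f64 = row.iter().sum();
            if total > 0.0 {
                row.iter().map(|c| c / total).collect::<Vec<f64>>()
            } else {
                // No evidence for this state: spread the mass evenly.
                vec![1.0 / row.len() as f64; row.len()]
            }
        })
        .collect()
}
```

Without the fallback a zero-count row would divide by zero and fill the parameters with NaN.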

Convergence behaviour:

With the flat‑Dirichlet init from Hmm::new(), you'll typically see:

  • Iterations 1–10: rapid improvement (LL jumps by 100s of nats)
  • Iterations 10–30: moderate improvement (LL increases by 1–10 nats/iter)
  • Iterations 30–50: slow convergence (LL creeps up by < 1 nat/iter)
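  • Iterations 50+: plateau. In exact arithmetic EM never decreases the training log-likelihood, so small oscillations at this stage are usually floating-point noise.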

If the model converges to LL ≈ -300 on 500 observations of a 6-symbol problem, that's suspicious: it probably landed in a bad local optimum such as the alternating-state trap. Retrain from a different random initialisation.


Common patterns

Multiple random restarts

let mut best_ll = f64::NEG_INFINITY;
let mut best_hmm = Hmm::new(n_states, n_symbols, &mut rng);

for _ in 0..5 {
    let mut candidate = Hmm::new(n_states, n_symbols, &mut rng);
    let ll = algorithms::baum_welch(&mut candidate, &sequences, 150, 1e-5);
    if ll > best_ll {
        best_ll = ll;
        best_hmm = candidate;
    }
}

Training with validation

// Split data
let split = observations.len() * 4 / 5;
let train = observations[..split].to_vec();
let valid = observations[split..].to_vec();

algorithms::baum_welch(&mut hmm, &[train], 100, 1e-5);

let (_, scale) = algorithms::forward(&hmm, &valid);
let valid_ll = algorithms::log_likelihood(&scale);
println!("validation ll: {:.2}", valid_ll);

Detecting label swaps

Viterbi labels (state 0, state 1, …) are arbitrary — the model doesn't know which state means "fair die" vs "loaded die." To align labels:

// Expected emitted symbol index for each state: Σ_k k · b[(s, k)]
let mut means: Vec<(usize, f64)> = (0..hmm.n_states)
    .map(|s| {
        let avg = (0..hmm.n_symbols)
            .map(|k| k as f64 * hmm.b[(s, k)])
            .sum::<f64>();
        (s, avg)
    })
    .collect();
// total_cmp gives a total order on f64, so no unwrap is needed
means.sort_by(|x, y| x.1.total_cmp(&y.1));

// means is now sorted from the lowest-mean to the highest-mean state
println!("lowest-mean state: {}", means[0].0);
println!("highest-mean state: {}", means[means.len() - 1].0);
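Once states are ranked, a decoded path can be rewritten into the canonical labelling. The sketch below shows one way to do that; `canonical_labels` and `relabel` are hypothetical helpers, and the emission matrix is passed as `Vec<Vec<f64>>` so the example has no dependencies.

```rust
/// For each original state, compute its rank when states are sorted by
/// expected emitted symbol index. labels[old_state] = new_state.
fn canonical_labels(b: &[Vec<f64>]) -> Vec<usize> {
    let mut order: Vec<(usize, f64)> = b
        .iter()
        .enumerate()
        .map(|(s, row)| {
            let mean: f64 = row.iter().enumerate().map(|(k, p)| k as f64 * p).sum();
            (s, mean)
        })
        .collect();
    order.sort_by(|x, y| x.1.total_cmp(&y.1));
    let mut labels = vec![0; b.len()];
    for (rank, &(state, _)) in order.iter().enumerate() {
        labels[state] = rank;
    }
    labels
}

/// Rewrite a decoded path using the mapping from `canonical_labels`.
fn relabel(path: &[usize], labels: &[usize]) -> Vec<usize> {
    path.iter().map(|&s| labels[s]).collect()
}
```

Applying the same mapping to every restart makes Viterbi paths from different trained models directly comparable.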