Every public function in HMMRust. Each section gives the signature, what it returns, what to watch out for, and a usage snippet.
```rust
pub struct Hmm {
    pub n_states: usize,
    pub n_symbols: usize,
    pub pi: Array1<f64>, // initial state distribution, shape [N]
    pub a: Array2<f64>,  // transition matrix, shape [N × N]
    pub b: Array2<f64>,  // emission matrix, shape [N × M]
}
```

All fields are public so you can inspect and hand-tune parameters after training (or before, if you know what you're doing).
Random initialisation. Every row of A and B is sampled from a flat
Dirichlet — uniform over the probability simplex. pi is sampled the same way.
```rust
let mut rng = rand::thread_rng();
let hmm = Hmm::new(3, 10, &mut rng);
```

Dirichlet sampling means every parameter is guaranteed to be a valid probability distribution (non-negative, rows sum to 1). But it also means the model starts with no prior structure, so Baum-Welch has to discover it from data alone. For hard problems, consider multiple restarts.
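To make the flat-Dirichlet initialisation concrete, here is a minimal sketch of one standard way to sample it: draw independent Exp(1) variates and normalise them. The `XorShift` generator and `sample_flat_dirichlet` are hypothetical helpers written only so the example runs without external crates. HMMRust's own sampling code may look quite different.

```rust
/// Tiny xorshift64* generator so the sketch needs no external crates.
/// Hypothetical helper: HMMRust itself takes a `rand` RNG instead.
struct XorShift(u64);

impl XorShift {
    /// Uniform draw strictly inside (0, 1). The seed must be non-zero.
    fn next_f64(&mut self) -> f64 {
        self.0 ^= self.0 >> 12;
        self.0 ^= self.0 << 25;
        self.0 ^= self.0 >> 27;
        // Keep the top 52 bits so that `bits + 0.5` is exactly representable.
        let bits = self.0.wrapping_mul(0x2545_F491_4F6C_DD1D) >> 12;
        (bits as f64 + 0.5) / (1u64 << 52) as f64
    }
}

/// Sample one point from a flat Dirichlet(1, ..., 1) over `k` categories:
/// draw k Exp(1) variates via -ln(u) and divide by their sum.
fn sample_flat_dirichlet(rng: &mut XorShift, k: usize) -> Vec<f64> {
    let draws: Vec<f64> = (0..k).map(|_| -rng.next_f64().ln()).collect();
    let total: f64 = draws.iter().sum();
    draws.into_iter().map(|x| x / total).collect()
}
```

Calling `sample_flat_dirichlet` once per row of A and B, and once for pi, gives parameters that are valid distributions by construction.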
Construct from explicit parameters. Validates that:
- `pi` sums to 1 (± 1e-10)
- `a` has shape `[N, N]` and every row sums to 1
- `b` has shape `[N, M]` and every row sums to 1
- `M` (`n_symbols`) is inferred from `b.ncols()`
Returns Err(String) describing which check failed.
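The checks listed above can be sketched on plain `Vec`s as follows. This is an illustration only, not HMMRust's implementation: `check_row` and `validate_params` are hypothetical names, the tolerance is passed in explicitly, and the non-negativity test is an extra check that the list above does not mention.

```rust
/// Check that `row` is a probability distribution within `tol`.
/// (The non-negativity test is an addition for this sketch.)
fn check_row(name: &str, row: &[f64], tol: f64) -> Result<(), String> {
    if row.iter().any(|&p| p < 0.0 || !p.is_finite()) {
        return Err(format!("{name} contains a negative or non-finite entry"));
    }
    let sum: f64 = row.iter().sum();
    if (sum - 1.0).abs() > tol {
        return Err(format!("{name} sums to {sum}, expected 1"));
    }
    Ok(())
}

/// Validate (pi, A, B) and return the inferred (n_states, n_symbols).
fn validate_params(
    pi: &[f64],
    a: &[Vec<f64>],
    b: &[Vec<f64>],
    tol: f64,
) -> Result<(usize, usize), String> {
    let n = pi.len();
    check_row("pi", pi, tol)?;
    if a.len() != n || a.iter().any(|r| r.len() != n) {
        return Err(format!("a must have shape [{n}, {n}]"));
    }
    if b.len() != n {
        return Err(format!("b must have {n} rows"));
    }
    // M is inferred from the width of b.
    let m = b.first().map_or(0, |r| r.len());
    if b.iter().any(|r| r.len() != m) {
        return Err("rows of b have different lengths".to_string());
    }
    for (i, row) in a.iter().enumerate() {
        check_row(&format!("a row {i}"), row, tol)?;
    }
    for (i, row) in b.iter().enumerate() {
        check_row(&format!("b row {i}"), row, tol)?;
    }
    Ok((n, m))
}
```

Returning the first failing check as a `String` mirrors the `Err(String)` behaviour described above.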
```rust
use ndarray::{Array1, Array2};

let pi = Array1::from_vec(vec![0.5, 0.5]);
let a = Array2::from_shape_vec((2, 2), vec![0.7, 0.3, 0.4, 0.6]).unwrap();
let b = Array2::from_shape_vec((2, 3), vec![0.5, 0.3, 0.2, 0.2, 0.4, 0.4]).unwrap();
let hmm = Hmm::from_params(pi, a, b)
    .expect("parameters should be valid");
```

If you need a different tolerance or custom validation, construct the struct directly:
```rust
let hmm = Hmm {
    n_states: 2,
    n_symbols: 3,
    pi: my_pi,
    a: my_a,
    b: my_b,
};
```

Sample a hidden state sequence and the corresponding observation sequence from the model. Returns `(states, observations)`, both `Vec<usize>` of length `len`.
```rust
let (hidden, observed) = hmm.generate(&mut rng, 500);
assert_eq!(hidden.len(), 500);
assert_eq!(observed.len(), 500);
```

Uses `rand::distributions::WeightedIndex` internally. Panics if any probability row in A or B is all zeros, which shouldn't happen with valid parameters.
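As a rough picture of what sampling from the model involves, the sketch below does ancestral sampling with a hand-rolled categorical draw. It is not HMMRust's code: `sample_index` and this `generate` are hypothetical stand-ins that take a closure yielding uniform values in [0, 1) and return `None` for an all-zero row instead of panicking.

```rust
/// Pick an index from `probs` given a uniform draw `u` in [0, 1).
/// Returns None when the row has no positive mass.
fn sample_index(probs: &[f64], u: f64) -> Option<usize> {
    let total: f64 = probs.iter().sum();
    if !(total > 0.0) {
        return None;
    }
    let target = u * total;
    let mut acc = 0.0;
    let mut last_positive = None;
    for (i, &p) in probs.iter().enumerate() {
        if p > 0.0 {
            acc += p;
            last_positive = Some(i);
            if target < acc {
                return Some(i);
            }
        }
    }
    last_positive // guards against rounding when u is very close to 1
}

/// Ancestral sampling: draw the first state from pi, then alternate
/// emission draws from B and transition draws from A.
fn generate(
    pi: &[f64],
    a: &[Vec<f64>],
    b: &[Vec<f64>],
    len: usize,
    mut uniform: impl FnMut() -> f64,
) -> Option<(Vec<usize>, Vec<usize>)> {
    let mut states = Vec::with_capacity(len);
    let mut obs = Vec::with_capacity(len);
    let mut s = sample_index(pi, uniform())?;
    for t in 0..len {
        states.push(s);
        obs.push(sample_index(&b[s], uniform())?);
        if t + 1 < len {
            s = sample_index(&a[s], uniform())?;
        }
    }
    Some((states, obs))
}
```

Passing the uniform source as a closure keeps the sketch deterministic under test. In real use it would wrap a proper RNG.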
All algorithm functions live in hmmrust::algorithms. They're free functions,
not methods — keeps the Hmm struct clean and lets algorithms live in a
separate module.
Import everything:
```rust
use hmmrust::algorithms;
```

Or cherry-pick:

```rust
use hmmrust::algorithms::{forward, viterbi, baum_welch};
```

Scaled forward algorithm. Returns:
| Return | Type | Shape | Meaning |
|---|---|---|---|
| `alpha` | `Array2<f64>` | [T × N] | scaled forward probabilities |
| `scale` | `Vec<f64>` | T | scaling factors at each time step |
Each row of alpha sums to 1.0 up to floating-point rounding. The scaling
factors c_t are needed by backward() and log_likelihood(). Without
scaling, alpha values would shrink exponentially: with a typical per-step
factor of around 0.1, f64 underflows (below roughly 10⁻³⁰⁸) after a few hundred steps.
```rust
let obs = vec![0, 1, 0, 0, 1];
let (alpha, scale) = algorithms::forward(&hmm, &obs);
// alpha[(t, i)] = scaled P(state = i at t | observations up to t)
let row_sum: f64 = alpha.row(3).sum();
assert!((row_sum - 1.0).abs() < 1e-10);
```

Edge case: if an observation is out of range (≥ `hmm.n_symbols`), the emission lookup indexes past the end of `b` and panics with an ndarray index-out-of-bounds error. Validate your input first.
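The scaling scheme described above can be sketched on plain `Vec`s as follows. This is an illustrative version rather than HMMRust's implementation: `forward_scaled` is a hypothetical name, and unlike the library function it checks the observations up front and returns an `Err` instead of panicking.

```rust
/// Scaled forward pass. Returns (alpha, scale), where every alpha row sums
/// to 1 and scale[t] = c_t = 1 / (unnormalised row sum at time t).
/// A row with zero total mass would give c_t = inf; this sketch does not
/// handle that degenerate case.
fn forward_scaled(
    pi: &[f64],
    a: &[Vec<f64>],
    b: &[Vec<f64>],
    obs: &[usize],
) -> Result<(Vec<Vec<f64>>, Vec<f64>), String> {
    let n = pi.len();
    let m = b.first().map_or(0, |row| row.len());
    // Validate input first so a bad symbol cannot index out of bounds.
    if let Some((t, &o)) = obs.iter().enumerate().find(|&(_, &o)| o >= m) {
        return Err(format!("observation {o} at t={t} is out of range 0..{m}"));
    }
    let mut alpha: Vec<Vec<f64>> = Vec::with_capacity(obs.len());
    let mut scale: Vec<f64> = Vec::with_capacity(obs.len());
    for (t, &o) in obs.iter().enumerate() {
        let mut row: Vec<f64> = (0..n)
            .map(|j| {
                // Predict state j from pi (t = 0) or from the previous row.
                let prior = if t == 0 {
                    pi[j]
                } else {
                    let prev: &Vec<f64> = &alpha[t - 1];
                    (0..n).map(|i| prev[i] * a[i][j]).sum::<f64>()
                };
                prior * b[j][o]
            })
            .collect();
        let c = 1.0 / row.iter().sum::<f64>();
        for x in row.iter_mut() {
            *x *= c; // rescale so the row sums to 1
        }
        alpha.push(row);
        scale.push(c);
    }
    Ok((alpha, scale))
}
```

Because each row is renormalised, the magnitudes stay near 1 regardless of sequence length, and the discarded mass is recorded in `scale`.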
Scaled backward algorithm. Takes the same scale vector returned by forward().
| Return | Type | Shape | Meaning |
|---|---|---|---|
| `beta` | `Array2<f64>` | [T × N] | scaled backward probabilities |
Uses the same scaling factors as the forward pass. This is critical — if you pass different scale values, gamma and xi computations in Baum‑Welch will be wrong.
```rust
let (alpha, scale) = algorithms::forward(&hmm, &obs);
let beta = algorithms::backward(&hmm, &obs, &scale);
// beta[(T-1, i)] = scale[T-1] (by construction)
for i in 0..hmm.n_states {
    assert!((beta[(obs.len() - 1, i)] - scale[obs.len() - 1]).abs() < 1e-10);
}
```

Extract log P(observations | model) from the forward scaling factors.
```rust
let (_, scale) = algorithms::forward(&hmm, &obs);
let ll = algorithms::log_likelihood(&scale);
// ll = -Σ_t ln(c_t)
```

The log-likelihood is always ≤ 0. For a perfect model on deterministic data, it approaches 0. For long real-world sequences, expect large negative numbers (-1000, -5000, etc.).
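The formula in the comment, ll = -Σ_t ln(c_t), is short enough to sketch directly. The version below works on a plain slice and adds a hypothetical `raw_likelihood` helper purely to show why the sum of logs is used instead of the product. Neither function is HMMRust's own code.

```rust
/// log P(obs | model) = -Σ_t ln(c_t), with c_t the forward scaling factors.
fn log_likelihood(scale: &[f64]) -> f64 {
    -scale.iter().map(|c| c.ln()).sum::<f64>()
}

/// The same quantity as a plain probability, Π_t (1 / c_t).
/// Shown only for contrast: it underflows to 0.0 on long sequences.
fn raw_likelihood(scale: &[f64]) -> f64 {
    scale.iter().map(|c| 1.0 / c).product()
}
```

With 1000 scaling factors of 5.0 (a per-step probability of 0.2), `raw_likelihood` collapses to zero while `log_likelihood` stays finite at about -1609.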
Log‑space Viterbi decoding. Returns the single most likely hidden state sequence.
```rust
let path = algorithms::viterbi(&hmm, &obs);
assert_eq!(path.len(), obs.len());
for &state in &path {
    assert!(state < hmm.n_states);
}
```

Computes log(pi), log(A), and log(B) up front. Zeros in any probability are mapped to `f64::NEG_INFINITY` so they can't win the max. Runs in O(T·N²) time and O(T·N) memory.
Why log‑space? In a 500‑step sequence with 5 states, the raw probability of any single path is roughly 0.2^(2·500) ≈ 10⁻⁷⁰⁰ — way below f64's minimum. Adding logs instead of multiplying keeps everything well‑behaved.
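A compact sketch of log-space Viterbi on plain `Vec`s may help make the description concrete. It follows the behaviour described above (zeros become negative infinity, logs are added instead of probabilities multiplied), but `viterbi_log` and `ln_or_neg_inf` are hypothetical names and this is not the library's implementation. Observations are assumed to be in range.

```rust
/// ln(p), with zero mapped to negative infinity so it can never win a max.
fn ln_or_neg_inf(p: f64) -> f64 {
    if p > 0.0 { p.ln() } else { f64::NEG_INFINITY }
}

/// Most likely state sequence, computed entirely in log-space.
fn viterbi_log(pi: &[f64], a: &[Vec<f64>], b: &[Vec<f64>], obs: &[usize]) -> Vec<usize> {
    let n = pi.len();
    if obs.is_empty() || n == 0 {
        return Vec::new();
    }
    // delta[i] = best log-probability of any path ending in state i.
    let mut delta: Vec<f64> = (0..n)
        .map(|i| ln_or_neg_inf(pi[i]) + ln_or_neg_inf(b[i][obs[0]]))
        .collect();
    // One row of backpointers per transition: O(T·N) memory.
    let mut back: Vec<Vec<usize>> = Vec::with_capacity(obs.len());
    for &o in &obs[1..] {
        let mut next = vec![f64::NEG_INFINITY; n];
        let mut ptr = vec![0usize; n];
        for j in 0..n {
            for i in 0..n {
                let cand = delta[i] + ln_or_neg_inf(a[i][j]);
                if cand > next[j] {
                    next[j] = cand;
                    ptr[j] = i;
                }
            }
            next[j] += ln_or_neg_inf(b[j][o]);
        }
        delta = next;
        back.push(ptr);
    }
    // Start from the best final state and follow the backpointers.
    let mut state = (0..n).fold(0, |best, i| if delta[i] > delta[best] { i } else { best });
    let mut path = vec![state; obs.len()];
    for t in (1..obs.len()).rev() {
        state = back[t - 1][state];
        path[t - 1] = state;
    }
    path
}
```

The two nested loops over states inside the loop over time give the O(T·N²) running time.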
Forward‑backward smoothing. Returns gamma: the posterior probability of each state at each time step given ALL observations.
| Return | Type | Shape | Meaning |
|---|---|---|---|
| `gamma` | `Array2<f64>` | [T × N] | γ_t(i) = P(state = i at t \| all observations) |
```rust
let gamma = algorithms::smooth(&hmm, &obs);
// Each row sums to 1
for t in 0..obs.len() {
    let sum: f64 = gamma.row(t).sum();
    assert!((sum - 1.0).abs() < 1e-10);
}
// gamma[(42, 1)] = probability the HMM was in state 1 at time step 42
// (assumes obs.len() > 42)
println!("P(state=1 at t=42) = {:.3}", gamma[(42, 1)]);
```

Internally calls forward() + backward(), then normalises their product.
If the product of alpha and beta for a timestep is all zero (degenerate model), that
row stays zeros.
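The normalisation step, including the all-zero case, can be sketched as below. The function name `posterior` and the use of nested `Vec`s are assumptions made for the example. The library works on `Array2<f64>`.

```rust
/// gamma[t][i] ∝ alpha[t][i] * beta[t][i], normalised per time step.
/// A row whose product is entirely zero is left as zeros.
fn posterior(alpha: &[Vec<f64>], beta: &[Vec<f64>]) -> Vec<Vec<f64>> {
    alpha
        .iter()
        .zip(beta)
        .map(|(a_row, b_row)| {
            let mut row: Vec<f64> = a_row.iter().zip(b_row).map(|(x, y)| x * y).collect();
            let total: f64 = row.iter().sum();
            if total > 0.0 {
                for g in row.iter_mut() {
                    *g /= total;
                }
            }
            row
        })
        .collect()
}
```

Normalising per row means the result does not depend on which scaling convention the forward and backward passes used, as long as both used the same one.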
Baum‑Welch expectation‑maximization training. Modifies hmm in place.
| Param | Type | Meaning |
|---|---|---|
| `hmm` | `&mut Hmm` | starting parameters (overwritten) |
| `sequences` | `&[Vec<usize>]` | training data: one or more observation sequences |
| `max_iter` | `usize` | maximum EM iterations |
| `tol` | `f64` | stop when log-likelihood improvement < tol |
Returns the final log‑likelihood (summed across all sequences).
```rust
let train_data = vec![
    vec![0, 1, 0, 2, 1, 0],
    vec![1, 0, 0, 1, 2, 1],
];
let final_ll = algorithms::baum_welch(&mut hmm, &train_data, 100, 1e-5);
println!("trained model ll: {:.2}", final_ll);
```

What happens inside:
1. For each sequence: run forward, run backward, compute gamma, compute xi
2. Accumulate expected counts across all sequences
3. Re-estimate pi, A, B from expected counts
4. Check convergence (LL delta < tol)
5. If not converged and `max_iter` not hit, go to 1
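The control flow of the loop above (iterate, compare log-likelihoods, stop on `tol` or `max_iter`) can be sketched independently of the HMM maths. `run_em` is a hypothetical driver, not part of HMMRust. It assumes `step` performs one full E+M pass and returns the summed log-likelihood it computed.

```rust
/// Generic EM driver. `step` runs one E+M pass and returns the
/// log-likelihood computed during that pass.
/// Returns (last log-likelihood, number of iterations performed).
fn run_em(mut step: impl FnMut() -> f64, max_iter: usize, tol: f64) -> (f64, usize) {
    let mut prev = f64::NEG_INFINITY;
    for iter in 1..=max_iter {
        let ll = step();
        // Converged once the improvement over the previous pass is below tol.
        if ll - prev < tol {
            return (ll, iter);
        }
        prev = ll;
    }
    (prev, max_iter)
}
```

Starting `prev` at negative infinity makes the first improvement infinite, so the loop never stops on its first pass.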
Gotchas:
- All sequences must use the same observation space (0..n_symbols). Passing a value ≥ n_symbols panics.
- A single short sequence (e.g. 10 observations with 5 hidden states) won't give useful results. There simply isn't enough signal.
- The algorithm is NOT guaranteed to find the global optimum. It converges to a local maximum of the likelihood surface. Run multiple restarts.
- If a row in `accum_a` or `accum_b` has zero total expected counts, that row is replaced with a uniform distribution as a fallback. This can happen if the model never assigns significant probability to a particular state.
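The uniform fallback in the last gotcha amounts to a guarded row normalisation. A minimal sketch, assuming the accumulators are stored as nested `Vec`s (`normalise_rows` is a hypothetical name, not a library function):

```rust
/// Turn accumulated expected counts into probability rows in place.
/// A row with zero total count falls back to the uniform distribution.
fn normalise_rows(counts: &mut [Vec<f64>]) {
    for row in counts.iter_mut() {
        let total: f64 = row.iter().sum();
        if total > 0.0 {
            for x in row.iter_mut() {
                *x /= total;
            }
        } else if !row.is_empty() {
            let uniform = 1.0 / row.len() as f64;
            for x in row.iter_mut() {
                *x = uniform;
            }
        }
    }
}
```

Without the fallback, a zero-count row would divide by zero and fill the matrix with NaN, which then propagates through every later iteration.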
Convergence behaviour:
With the flat‑Dirichlet init from Hmm::new(), you'll typically see:
- Iterations 1–10: rapid improvement (LL jumps by 100s of nats)
- Iterations 10–30: moderate improvement (LL increases by 1–10 nats/iter)
- Iterations 30–50: slow convergence (LL creeps up by < 1 nat/iter)
- Iterations 50+: plateau. EM does not decrease the likelihood, so any apparent oscillation at this stage is floating-point noise.
If the model converges to LL ≈ ‑300 on 500 observations of a 6‑symbol problem,
that's suspicious — it probably landed in a bad local optimum like the
alternating-state trap. Try again with a different random initialisation.
```rust
// Keep the best of several randomly initialised training runs.
let mut best_ll = f64::NEG_INFINITY;
let mut best_hmm = Hmm::new(n_states, n_symbols, &mut rng); // placeholder, replaced below
for _ in 0..5 {
    let mut candidate = Hmm::new(n_states, n_symbols, &mut rng);
    let ll = algorithms::baum_welch(&mut candidate, &sequences, 150, 1e-5);
    if ll > best_ll {
        best_ll = ll;
        best_hmm = candidate;
    }
}
```

To check the trained model on data it has not seen, hold out the tail of the sequence as a validation set:

```rust
// Split data: first 80% for training, last 20% for validation
let split = observations.len() * 4 / 5;
let train = observations[..split].to_vec();
let valid = observations[split..].to_vec();
algorithms::baum_welch(&mut hmm, &[train], 100, 1e-5);
let (_, scale) = algorithms::forward(&hmm, &valid);
let valid_ll = algorithms::log_likelihood(&scale);
println!("validation ll: {:.2}", valid_ll);
```

Viterbi labels (state 0, state 1, …) are arbitrary: the model doesn't know which state means "fair die" vs "loaded die." To align labels:
```rust
// Compute the mean emitted symbol index for each state
let mut means: Vec<(usize, f64)> = (0..hmm.n_states)
    .map(|s| {
        let avg: f64 = (0..hmm.n_symbols)
            .map(|k| k as f64 * hmm.b[(s, k)])
            .sum();
        (s, avg)
    })
    .collect();
means.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
// Now means[0].0 = lowest-mean state, means[means.len() - 1].0 = highest-mean state
println!("bear state: {}", means[0].0);
println!("bull state: {}", means[means.len() - 1].0);
```
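Once the states are ordered by mean emission, a decoded path can be rewritten in that canonical order so that runs with different random initialisations become comparable. The sketch below shows one way to do it on a plain `Vec` emission matrix. `canonical_labels` and `relabel` are hypothetical helpers, not part of HMMRust.

```rust
/// rank[old_state] = position of that state when states are sorted by
/// their mean emitted symbol index, Σ_k k · b[s][k] (lowest mean first).
fn canonical_labels(b: &[Vec<f64>]) -> Vec<usize> {
    let means: Vec<f64> = b
        .iter()
        .map(|row| row.iter().enumerate().map(|(k, &p)| k as f64 * p).sum())
        .collect();
    let mut order: Vec<usize> = (0..b.len()).collect();
    order.sort_by(|&x, &y| means[x].total_cmp(&means[y]));
    let mut rank = vec![0; b.len()];
    for (new, &old) in order.iter().enumerate() {
        rank[old] = new;
    }
    rank
}

/// Rewrite a decoded state path using the canonical labels.
fn relabel(path: &[usize], rank: &[usize]) -> Vec<usize> {
    path.iter().map(|&s| rank[s]).collect()
}
```

Sorting by mean emission only makes sense when the symbol indices are ordered (such as binned returns or die faces). For unordered symbols a different alignment criterion is needed.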