
4lisyd/HMMRust


                                 ▄    ▄▄▄  ▄    ▄  ▄▄▄   ▄▄▄   ▄▄▄  ▄▄▄
                                █▀▄  █  █ ██▄██  █  █ █  █    █  █
                                █ ▐▌ █  █ █ ▀ █  █▄█  ▀▀▀▄    █  ▀▀▀▄
                                █  █ █▄▄█ █   █  █▄█  █▄▄█    █  █▄▄█
                                     Rust Hidden Markov Models

Baum‑Welch  ·  Viterbi  ·  Forward‑Backward  ·  Pure Rust


No C bindings, no Python wrappers: just Rust, ndarray, and four algorithms (forward, backward, Viterbi, Baum‑Welch) that Rabiner wrote down in 1989. Trains with Baum‑Welch, decodes with Viterbi, smooths with forward‑backward. Forward‑backward uses per‑step scaling and Viterbi runs in log‑space, so you don't wake up to a NaN.




🚀 Quick start

git clone https://github.com/4lisyd/HMMRust
cd HMMRust
cargo run --example casino

You'll see a 2‑state HMM trained on 1000 simulated die rolls: the classic "occasionally dishonest casino" from Durbin et al., Biological Sequence Analysis (1998). The terminal output is colour‑coded because apparently I never outgrew syntax highlighting.

cargo run --example stock_regime

Generates 600 days of synthetic returns, trains a 3‑state model, prints a full colour‑coded timeline of bull/bear/sideways regimes with learned parameters.

# Cargo.toml
[dependencies]
hmmrust = "0.1"
ndarray = "0.15"         # you'll need this for the matrix types
rand = "0.8"             # for the RNG passed to Hmm::new / generate (match hmmrust's rand version)

🤔 What's an HMM?

An HMM is a Markov chain with hidden states. You get a stream of symbols. Some invisible state machine flips between states and each state has its own way of generating symbols. You never see the states — you only see the symbols.

           a_ij                a_jk
     ┌─────────────┐     ┌─────────────┐
     │             │     │             │
     ▼             │     ▼             │
  ┌─────┐       ┌─────┐       ┌─────┐
  │  i  │──────►│  j  │──────►│  k  │      ← states (hidden)
  └──┬──┘       └──┬──┘       └──┬──┘
     │ b_i(o)      │ b_j(o)      │ b_k(o)
     ▼             ▼             ▼
  symbol        symbol        symbol          ← observations (visible)

The questions:

  #  Problem                                   Algorithm
  1  What's P(observations | model)?           Forward
  2  What's the most likely state path?        Viterbi
  3  What parameters best explain the data?    Baum‑Welch

🧠 How it works

1. Forward — score a sequence

Compute α_t(i), the joint probability of seeing the first t observations and being in state i at time t:

α₁(i) = πᵢ · bᵢ(o₁)
α_{t+1}(j) = [ Σᵢ α_t(i) · a_ij ] · b_j(o_{t+1})

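Summing the unscaled α_T(i) over all states gives P(sequence | model). Computed naively, α shrinks geometrically and underflows on long sequences, so at each step we rescale by 1 / Σᵢ α_t(i) to keep numbers in floating‑point range. The log‑likelihood drops out of the scaling factors for free: log P(sequence | model) = Σ_t log Σᵢ α_t(i), where each inner sum is taken just before that step's rescaling.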
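A minimal sketch of the scaled forward pass, assuming plain `Vec` parameters instead of the crate's ndarray types (`forward_scaled` and this layout are illustrative, not hmmrust's API). It stores the normaliser Σᵢ α_t(i) itself rather than its reciprocal, so the log‑likelihood is a plain sum of logs; the crate's c factors may follow either convention.

```rust
/// Scaled forward pass (sketch, hypothetical helper).
/// pi: N initial probs, a: N x N transitions, b: N x M emissions.
/// Returns (scaled alpha with rows summing to 1, per-step normalisers, log-likelihood).
fn forward_scaled(
    pi: &[f64],
    a: &[Vec<f64>],
    b: &[Vec<f64>],
    obs: &[usize],
) -> (Vec<Vec<f64>>, Vec<f64>, f64) {
    let n = pi.len();
    let mut alpha: Vec<Vec<f64>> = Vec::with_capacity(obs.len());
    let mut norms: Vec<f64> = Vec::with_capacity(obs.len());
    for (t, &o) in obs.iter().enumerate() {
        // Unnormalised alpha for this step, built from the previous (scaled) row.
        let mut row: Vec<f64> = (0..n)
            .map(|j| {
                let prior = if t == 0 {
                    pi[j]
                } else {
                    (0..n).map(|i| alpha[t - 1][i] * a[i][j]).sum::<f64>()
                };
                prior * b[j][o]
            })
            .collect();
        // Rescale so the row sums to 1 and remember the normaliser.
        let s: f64 = row.iter().sum();
        for x in row.iter_mut() {
            *x /= s;
        }
        norms.push(s);
        alpha.push(row);
    }
    // log P(obs | model) = sum over t of log(normaliser_t)
    let ll = norms.iter().map(|s| s.ln()).sum::<f64>();
    (alpha, norms, ll)
}
```

Because every row is renormalised, each normaliser equals P(o_t | o₁ … o_{t−1}), and their product is P(sequence | model).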

2. Viterbi — find the best path

Same structure as forward, except swap Σ for max and keep a backpointer matrix. Runs in O(T·N²). Done in log‑space so underflow is a non‑issue.
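A log‑space Viterbi sketch under the same assumptions (plain `Vec` parameters, hypothetical `viterbi_log` name). Relative to the forward recursion, the inner Σ becomes a max over predecessors and products become sums of logs; a zero probability turns into `-inf` and simply loses every comparison.

```rust
/// Log-space Viterbi (sketch, hypothetical helper): most likely hidden state path.
fn viterbi_log(pi: &[f64], a: &[Vec<f64>], b: &[Vec<f64>], obs: &[usize]) -> Vec<usize> {
    let n = pi.len();
    let t_len = obs.len();
    if t_len == 0 {
        return Vec::new();
    }
    // delta[j]: best log-probability of any path ending in state j at the current step.
    let mut delta: Vec<f64> = (0..n).map(|j| pi[j].ln() + b[j][obs[0]].ln()).collect();
    // back[t][j]: best predecessor of state j at step t (the backpointer matrix).
    let mut back = vec![vec![0usize; n]; t_len];
    for t in 1..t_len {
        let mut next = vec![f64::NEG_INFINITY; n];
        for j in 0..n {
            // max over predecessors i instead of the forward pass's sum
            let (best_i, best) = (0..n)
                .map(|i| (i, delta[i] + a[i][j].ln()))
                .fold((0, f64::NEG_INFINITY), |acc, x| if x.1 > acc.1 { x } else { acc });
            next[j] = best + b[j][obs[t]].ln();
            back[t][j] = best_i;
        }
        delta = next;
    }
    // Start from the best final state and follow the backpointers to t = 0.
    let mut state = (0..n).fold(0, |best, j| if delta[j] > delta[best] { j } else { best });
    let mut path = vec![0; t_len];
    path[t_len - 1] = state;
    for t in (1..t_len).rev() {
        state = back[t][state];
        path[t - 1] = state;
    }
    path
}
```

The two nested loops over states inside the loop over time are where the O(T·N²) cost comes from.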

3. Baum‑Welch — learn the parameters

Expectation‑Maximization for HMMs. Give it one or more observation sequences (unlabelled — you don't know the true states). It alternates:

E‑step — run forward and backward to get posteriors:

γ_t(i) = P(state = i at time t | all observations)
ξ_t(i,j) = P(state = i at t AND state = j at t+1 | all observations)
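A sketch of the E‑step for a single sequence, assuming plain `Vec` parameters and a hypothetical `e_step` helper (not part of hmmrust's API). The backward pass reuses the forward normalisers, and each γ_t and ξ_t is renormalised per time step, which cancels whatever scaling constants were applied.

```rust
/// E-step sketch: returns (gamma, xi) with gamma[t][i] = P(state i at t | obs)
/// and xi[t][i][j] = P(state i at t, state j at t+1 | obs).
fn e_step(
    pi: &[f64],
    a: &[Vec<f64>],
    b: &[Vec<f64>],
    obs: &[usize],
) -> (Vec<Vec<f64>>, Vec<Vec<Vec<f64>>>) {
    let (n, t_len) = (pi.len(), obs.len());

    // Forward pass; each row rescaled to sum to 1, c[t] keeps the normaliser.
    let mut alpha = vec![vec![0.0f64; n]; t_len];
    let mut c = vec![0.0f64; t_len];
    for t in 0..t_len {
        for j in 0..n {
            let prior = if t == 0 {
                pi[j]
            } else {
                (0..n).map(|i| alpha[t - 1][i] * a[i][j]).sum::<f64>()
            };
            alpha[t][j] = prior * b[j][obs[t]];
        }
        c[t] = alpha[t].iter().sum::<f64>();
        for j in 0..n {
            alpha[t][j] /= c[t];
        }
    }

    // Backward pass, rescaled with the same normalisers (beta = 1 at the last step).
    let mut beta = vec![vec![1.0f64; n]; t_len];
    for t in (0..t_len.saturating_sub(1)).rev() {
        for i in 0..n {
            let s: f64 = (0..n)
                .map(|j| a[i][j] * b[j][obs[t + 1]] * beta[t + 1][j])
                .sum();
            beta[t][i] = s / c[t + 1];
        }
    }

    // gamma_t(i) ∝ alpha_t(i) * beta_t(i), normalised per time step.
    let mut gamma = vec![vec![0.0f64; n]; t_len];
    for t in 0..t_len {
        let s: f64 = (0..n).map(|i| alpha[t][i] * beta[t][i]).sum();
        for i in 0..n {
            gamma[t][i] = alpha[t][i] * beta[t][i] / s;
        }
    }

    // xi_t(i,j) ∝ alpha_t(i) * a_ij * b_j(o_{t+1}) * beta_{t+1}(j), normalised per t.
    let mut xi = vec![vec![vec![0.0f64; n]; n]; t_len.saturating_sub(1)];
    for t in 0..t_len.saturating_sub(1) {
        let mut s = 0.0f64;
        for i in 0..n {
            for j in 0..n {
                let v = alpha[t][i] * a[i][j] * b[j][obs[t + 1]] * beta[t + 1][j];
                xi[t][i][j] = v;
                s += v;
            }
        }
        for row in xi[t].iter_mut() {
            for v in row.iter_mut() {
                *v /= s;
            }
        }
    }
    (gamma, xi)
}
```

A useful sanity check: summing ξ_t(i,j) over j recovers γ_t(i).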

M‑step — re‑estimate parameters from expected counts:

π̂_i    = γ₁(i)
â_ij   = Σ_{t=1}^{T−1} ξ_t(i,j) / Σ_{t=1}^{T−1} γ_t(i)
b̂_j(k) = Σ_{t: o_t=k} γ_t(j) / Σ_t γ_t(j)
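The M‑step update rules as a sketch for one sequence (hypothetical `m_step`, plain `Vec` types). With several training sequences you would accumulate the numerators and denominators across all of them before dividing.

```rust
/// M-step sketch: re-estimate (pi, A, B) from gamma[t][i], xi[t][i][j]
/// and the observed symbols; `m` is the number of distinct symbols.
fn m_step(
    gamma: &[Vec<f64>],
    xi: &[Vec<Vec<f64>>],
    obs: &[usize],
    m: usize,
) -> (Vec<f64>, Vec<Vec<f64>>, Vec<Vec<f64>>) {
    assert!(obs.len() >= 2, "need at least two observations");
    let n = gamma[0].len();
    let t_len = obs.len();

    // pi_i = gamma_1(i)
    let pi = gamma[0].clone();

    // a_ij = sum_{t < T} xi_t(i,j) / sum_{t < T} gamma_t(i)
    let mut a = vec![vec![0.0f64; n]; n];
    for i in 0..n {
        let denom: f64 = (0..t_len - 1).map(|t| gamma[t][i]).sum();
        for j in 0..n {
            let num: f64 = (0..t_len - 1).map(|t| xi[t][i][j]).sum();
            a[i][j] = num / denom;
        }
    }

    // b_j(k) = sum_{t: o_t = k} gamma_t(j) / sum_t gamma_t(j)
    let mut b = vec![vec![0.0f64; m]; n];
    for j in 0..n {
        let denom: f64 = (0..t_len).map(|t| gamma[t][j]).sum();
        for t in 0..t_len {
            b[j][obs[t]] += gamma[t][j] / denom;
        }
    }
    (pi, a, b)
}
```

If a state receives no posterior mass, the denominators are zero and its rows become NaN; a robust implementation needs a guard or a small pseudocount there.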

Repeat until the log‑likelihood stops improving.
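The outer loop as a sketch, assuming `step` runs one E‑step plus M‑step and returns the resulting log‑likelihood. EM never decreases the likelihood (in exact arithmetic), so the loop stops once the gain falls below `tol`. The hypothetical `max_iter` / `tol` pair presumably corresponds to the `100, 1e-5` arguments in the API example, but that reading of hmmrust's parameters is an assumption.

```rust
/// Outer EM loop (sketch, hypothetical helper): call `step` until the
/// log-likelihood improves by less than `tol` or `max_iter` is reached.
/// Returns the last log-likelihood seen.
fn run_until_converged<F: FnMut() -> f64>(mut step: F, max_iter: usize, tol: f64) -> f64 {
    let mut prev = f64::NEG_INFINITY;
    for _ in 0..max_iter {
        let ll = step();
        // First iteration: ll - (-inf) = +inf, so it never stops immediately.
        if ll - prev < tol {
            return ll;
        }
        prev = ll;
    }
    prev
}
```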

Heads‑up. Baum‑Welch is sensitive to initialisation. Hmm::new() samples from a flat Dirichlet, which is maximally uninformative. On small datasets EM can settle into a crap local maximum. Solution: multiple random restarts (the examples do this) or hand‑pick starting parameters if you know the domain. Read the theory doc for details → docs/API.md
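A sketch of flat‑Dirichlet initialisation plus random restarts. A Dirichlet(1, …, 1) row can be drawn by normalising i.i.d. Exp(1) samples (−ln U for uniform U). The tiny `XorShift` generator and the `flat_dirichlet` / `best_of_restarts` helpers are hypothetical stand‑ins so the block needs no crates; hmmrust itself uses rand and rand_distr.

```rust
/// Minimal xorshift64 PRNG (hypothetical helper; seed must be non-zero).
struct XorShift(u64);

impl XorShift {
    /// Uniform draw in (0, 1], so ln() never sees zero.
    fn next_f64(&mut self) -> f64 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        ((self.0 >> 11) as f64 + 1.0) / (1u64 << 53) as f64
    }
}

/// One probability row from a flat Dirichlet(1, ..., 1): normalised Exp(1) draws.
fn flat_dirichlet(rng: &mut XorShift, k: usize) -> Vec<f64> {
    let draws: Vec<f64> = (0..k).map(|_| -rng.next_f64().ln()).collect();
    let s: f64 = draws.iter().sum();
    draws.into_iter().map(|x| x / s).collect()
}

/// Random restarts: build `restarts` starting models with `init`, score each
/// with `train` (e.g. run Baum-Welch, return the final log-likelihood), keep the best.
fn best_of_restarts<M>(
    restarts: usize,
    mut init: impl FnMut() -> M,
    mut train: impl FnMut(&mut M) -> f64,
) -> Option<(M, f64)> {
    let mut best: Option<(M, f64)> = None;
    for _ in 0..restarts {
        let mut model = init();
        let ll = train(&mut model);
        if best.as_ref().map_or(true, |(_, b)| ll > *b) {
            best = Some((model, ll));
        }
    }
    best
}
```

In practice `init` would assemble π, A and B from `flat_dirichlet` rows, and `train` would run Baum‑Welch on that model.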


🔧 API

All the heavy lifting lives in the algorithms module. The Hmm struct handles construction and sampling.

use hmmrust::algorithms;           // forward, backward, viterbi, smooth, baum_welch, log_likelihood
use hmmrust::Hmm;                  // the struct + constructors
use ndarray::{Array1, Array2};     // parameter types for Hmm::from_params

let mut rng = rand::thread_rng();

// ── random 3‑state, 10‑symbol HMM ────────────────────────
let random_hmm = Hmm::new(3, 10, &mut rng);

// ── or build one by hand (2 states, 3 symbols) ───────────
// pi: initial distribution, a: N×N transitions, b: N×M emissions; every row sums to 1
let pi = Array1::from_vec(vec![0.5, 0.5]);
let a  = Array2::from_shape_vec((2, 2), vec![0.7, 0.3, 0.4, 0.6]).unwrap();
let b  = Array2::from_shape_vec((2, 3), vec![0.5, 0.3, 0.2, 0.2, 0.4, 0.4]).unwrap();
let mut hmm = Hmm::from_params(pi, a, b).unwrap();   // `mut`: baum_welch below takes &mut

// ── generate synthetic data ──────────────────────────────
let (_hidden_states, observations) = hmm.generate(&mut rng, 1000);

// ── train (can pass multiple sequences) ─────────────────
let final_ll = algorithms::baum_welch(&mut hmm, &[observations.clone()], 100, 1e-5);

// ── decode the most likely state path ────────────────────
let path  = algorithms::viterbi(&hmm, &observations);

// ── posterior state probabilities at each step ───────────
let gamma = algorithms::smooth(&hmm, &observations);

// ── score a sequence ─────────────────────────────────────
let (_, scale) = algorithms::forward(&hmm, &observations);
let ll = algorithms::log_likelihood(&scale);

  Function        Signature                                Returns
  forward         (&Hmm, &[usize])                         (Array2<f64>, Vec<f64>): scaled alpha, c factors
  backward        (&Hmm, &[usize], &[f64])                 Array2<f64>: scaled beta
  log_likelihood  (&[f64])                                 f64: log P(obs | model)
  viterbi         (&Hmm, &[usize])                         Vec<usize>: best state path
  smooth          (&Hmm, &[usize])                         Array2<f64>: gamma posteriors
  baum_welch      (&mut Hmm, &[Vec<usize>], usize, f64)    f64: final log‑likelihood

Full signatures and examples → docs/API.md


📸 Examples

Casino — cargo run --example casino

A 2‑state HMM (fair die / loaded die). Trains with 3 random restarts, picks the best, decodes the hidden states. Printed with ANSI colours:

  True:    LLLLLLLLLLLLLLLLLLLLLLFFFFFFFFFFFFFFFFFFFFFFFFFF...LLL...FFFFFFFF...LLL...
  Decoded: LLLLLLLLLLLLLLLLLLLLLLFFFFFFFFFFFFFFFFFFFFFFFFFFFF...LLL...FFFFFFFF...LLL...

  Accuracy: 71.7% (717 / 1000 correctly labelled)

  First 20 smoothed posteriors P(loaded | obs):
    t= 0  [████████████████████████████████████████] 1.000
    t= 1  [███████████████████████████████████████]  0.994
    t= 2  [███████████████████████████████████████]  0.991
    ...

  Learned A = [[0.958, 0.042], [0.025, 0.975]]
  (true A    = [[0.950, 0.050], [0.100, 0.900]])

State persistence correctly recovered — the model learned that regimes last many steps, not just one.

Stock regime — cargo run --example stock_regime

3‑state HMM (bear / sideways / bull) on discretized returns. Prints a timeline block:

  True:  ▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄...bear/side/bull
  Model: ▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄...

  Accuracy: 74.5% (447 / 600 correct)
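The README doesn't spell out how returns are discretised. One simple scheme, shown as a hypothetical sketch (the `discretize` helper and the symmetric `band` threshold are assumptions, not necessarily what stock_regime does), maps each daily return to a down / flat / up symbol so a discrete‑emission HMM can consume it.

```rust
/// Map real-valued returns to 3 discrete symbols (hypothetical scheme):
/// 0 = down (r < -band), 1 = flat, 2 = up (r > band).
fn discretize(returns: &[f64], band: f64) -> Vec<usize> {
    returns
        .iter()
        .map(|&r| if r < -band { 0 } else if r > band { 2 } else { 1 })
        .collect()
}
```

The resulting `Vec<usize>` has the same shape as the observation sequences the algorithms take.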

📚 Documentation

  • docs/API.md — every public function, signatures, edge cases
  • docs/GETTING_STARTED.md — install, first model, gotchas
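  • Rabiner 1989, "A Tutorial on Hidden Markov Models and Selected Applications in Speech Recognition" (Proceedings of the IEEE): the canonical reference. If you implement an HMM and don't cite this paper, did you even implement an HMM?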

🧪 Running tests

cargo test

# 17 tests:
#   ✓ row‑stochastic validation
#   ✓ forward scaling
#   ✓ log‑likelihood is finite & negative
#   ✓ gamma rows sum to 1
#   ✓ viterbi path length = observation length
#   ✓ viterbi on deterministic model
#   ✓ baum‑welch increases likelihood
#   ✓ baum‑welch keeps params valid
#   ✓ no underflow on 2000‑step sequence
#   ✓ deterministic parameter recovery
#   ... and 7 more

📦 Dependencies

ndarray, rand, rand_distr. That's it. No BLAS, no LAPACK, no nalgebra. The matrices are tiny — we're doing HMMs, not training GPT. Total dependency tree pulls ~20 crates.


📄 License

MIT. Do whatever you want. PRs welcome — just match the vibe.

About

A performant hidden Markov model implementation in the Rust programming language.
