Revised Belief Propagation #26
Conversation
Introduce `BeliefPropagationProblem` wrapper to hold the cache and the error `diff` field. Also simplifies some kwargs wrangling.
…be set from another cache
Also includes some fixes to the way `TensorNetwork` types are constructed based on index structure.
for more information, see https://pre-commit.ci
…instead of trying to operate on existing graphs The reason for this is: - One only cares about the edges of the input graph - A simple graph cannot be used as it "forgets" its edge names resulting in recursion - As shown with `TensorNetwork`, removing edges may not always be defined.
…s from an array.
This was caused by the change to the `cache` being backed by a directed graph.
| function messages(all_messages::AbstractGraph) | ||
| return map(edge -> message(all_messages, edge), edges(all_messages)) | ||
| end |
There was a problem hiding this comment.
Can't this be defined as messages(all_messages::AbstractGraph) = messages(all_messages, edges(all_messages))?
There was a problem hiding this comment.
(Also all_messages is slightly strange to me as an argument name for this.)
| function _message_cache_underlying_graph(graph::AbstractGraph) | ||
| digraph = similar_graph(NamedDiGraph, vertices(graph)) | ||
| for edge in edges(graph) | ||
| add_edge!(digraph, edge) | ||
| if !is_directed(graph) | ||
| add_edge!(digraph, reverse(edge)) | ||
| end | ||
| end | ||
| return digraph | ||
| end |
There was a problem hiding this comment.
Couldn't we just use a conversion, i.e. NamedDiGraph(graph::AbstractGraph), instead of introducing a new function for this?
| return cache | ||
| end | ||
|
|
||
| function beliefpropagation(network::AbstractGraph, messages::Dictionary; kwargs...) |
There was a problem hiding this comment.
| function beliefpropagation(network::AbstractGraph, messages::Dictionary; kwargs...) | |
| function beliefpropagation(network::AbstractGraph, messages; kwargs...) |
I think it would be better to accept any iterator of messages, then it is up to the MessageCache constructor to convert to a Dictionary (but that is just an internal detail).
| return beliefpropagation(network, cache; kwargs...) | ||
| end | ||
|
|
||
| function beliefpropagation(network, cache; kwargs...) |
There was a problem hiding this comment.
| function beliefpropagation(network, cache; kwargs...) | |
| function beliefpropagation(network::AbstractGraph, cache::MessageCache; kwargs...) |
reflecting the comment above.
|
|
||
| struct MessageCache{MT, V, E} <: AbstractDataGraph{V, Nothing, MT} | ||
| messages::Dictionary{E, MT} | ||
| underlying_graph::NamedDiGraph{V} |
There was a problem hiding this comment.
Honestly I was picturing the MessageCache would just be a dictionary, and wouldn't even have a graph stored. Can't the graph be determined from the network that is passed around in the problem? Is a graph really needed here? The edges are defined as the keys of the dictionary anyway.
… the `stopping_criterion` kwarg.
Co-authored-by: Matthew Fishman <mtfishman@users.noreply.github.com>
| maxiter, | ||
| tol, | ||
| stopping_criterion |
There was a problem hiding this comment.
What I was picturing was just having a single keyword argument stopping_criterion. Users can either pass an actual stopping criterion object, or a NamedTuple of keyword arguments, i.e.:
beliefpropagation(network, cache; stopping_criterion = (; maxiter = 10, tol = 1e-4))
beliefpropagation(network, cache; stopping_criterion = StopAfterIteration(10))Then there would be a single processing function that takes a NamedTuple and constructs the proper stopping criterion object, or if a stopping criteria is passed, it returns that object. See the design of how alg and trunc process NamedTuple into algorithm and truncation objects in MatrixAlgebraKit.
| if cache1.underlying_graph != cache2.underlying_graph | ||
| return false | ||
| elseif cache1.messages != cache2.messages | ||
| return false | ||
| end | ||
| return true |
There was a problem hiding this comment.
| if cache1.underlying_graph != cache2.underlying_graph | |
| return false | |
| elseif cache1.messages != cache2.messages | |
| return false | |
| end | |
| return true | |
| return (cache1.underlying_graph == cache2.underlying_graph) && (cache1.messages == cache2.messages) |
This PR express belief propagation in terms of the new interface based on AlgorithmsInterface.jl and the included AlgorithmsInterfaceExtensions.jl library.