00001 #pragma once
00002 #ifndef OPENGM_MESSAGE_PASSING_HXX
00003 #define OPENGM_MESSAGE_PASSING_HXX
00004
00005 #include <vector>
00006 #include <map>
00007 #include <list>
00008 #include <set>
00009
00010 #include "opengm/opengm.hxx"
00011 #include "opengm/inference/inference.hxx"
00012 #include "opengm/inference/messagepassing/messagepassing_trbp.hxx"
00013 #include "opengm/inference/messagepassing/messagepassing_bp.hxx"
00014 #include "opengm/utilities/tribool.hxx"
00015 #include "opengm/utilities/metaprogramming.hxx"
00016 #include "opengm/operations/maximizer.hxx"
00017 #include "opengm/inference/visitors/visitor.hxx"
00018
00019 namespace opengm {
00020
00023 struct MaxDistance {
00027 template<class M>
00028 static typename M::ValueType
00029 op(const M& in1, const M& in2)
00030 {
00031 typedef typename M::ValueType ValueType;
00032 ValueType v1,v2,d1,d2;
00033 Maximizer::neutral(v1);
00034 Maximizer::neutral(v2);
00035 for(size_t n=0; n<in1.size(); ++n) {
00036 d1=in1(n)-in2(n);
00037 d2=-d1;
00038 Maximizer::op(d1,v1);
00039 Maximizer::op(d2,v2);
00040 }
00041 Maximizer::op(v2,v1);
00042 return v1;
00043 }
00044 };
00045
00048 template<class GM, class ACC, class UPDATE_RULES, class DIST=opengm::MaxDistance>
00049 class MessagePassing : public Inference<GM, ACC> {
00050 public:
00051 typedef GM GraphicalModelType;
00052 typedef ACC Accumulation;
00053 typedef ACC AccumulatorType;
00054 OPENGM_GM_TYPE_TYPEDEFS;
00055 typedef DIST Distance;
00056 typedef typename UPDATE_RULES::FactorHullType FactorHullType;
00057 typedef typename UPDATE_RULES::VariableHullType VariableHullType;
00058
00060 typedef VerboseVisitor<MessagePassing<GM, ACC, UPDATE_RULES, DIST> > VerboseVisitorType;
00062 typedef TimingVisitor<MessagePassing<GM, ACC, UPDATE_RULES, DIST> > TimingVisitorType;
00064 typedef EmptyVisitor<MessagePassing<GM, ACC, UPDATE_RULES, DIST> > EmptyVisitorType;
00065
00066 struct Parameter {
00067 typedef typename UPDATE_RULES::SpecialParameterType SpecialParameterType;
00068 Parameter
00069 (
00070 const size_t maximumNumberOfSteps = 100,
00071 const ValueType bound = static_cast<ValueType> (0.000000),
00072 const ValueType damping = static_cast<ValueType> (0),
00073 const SpecialParameterType & specialParameter =SpecialParameterType(),
00074 const opengm::Tribool isAcyclic = opengm::Tribool::Maybe
00075 )
00076 : maximumNumberOfSteps_(maximumNumberOfSteps),
00077 bound_(bound),
00078 damping_(damping),
00079 inferSequential_(false),
00080 useNormalization_(true),
00081 specialParameter_(specialParameter),
00082 isAcyclic_(isAcyclic)
00083 {}
00084
00085 size_t maximumNumberOfSteps_;
00086 ValueType bound_;
00087 ValueType damping_;
00088 bool inferSequential_;
00089 std::vector<size_t> sortedNodeList_;
00090 bool useNormalization_;
00091 SpecialParameterType specialParameter_;
00092 opengm::Tribool isAcyclic_;
00093 };
00094
00096 struct Message {
00097 Message()
00098 : nodeId_(-1),
00099 internalMessageId_(-1)
00100 {}
00101 Message(const size_t nodeId, const size_t & internalMessageId)
00102 : nodeId_(nodeId),
00103 internalMessageId_(internalMessageId)
00104 {}
00105
00106 size_t nodeId_;
00107 size_t internalMessageId_;
00108 };
00110
00111 MessagePassing(const GraphicalModelType&, const Parameter& = Parameter());
00112 std::string name() const;
00113 const GraphicalModelType& graphicalModel() const;
00114 InferenceTermination marginal(const size_t, IndependentFactorType& out) const;
00115 InferenceTermination factorMarginal(const size_t, IndependentFactorType & out) const;
00116 ValueType convergenceXF() const;
00117 ValueType convergenceFX() const;
00118 ValueType convergence() const;
00119 virtual void reset();
00120 InferenceTermination infer();
00121 template<class VisitorType>
00122 InferenceTermination infer(VisitorType&);
00123 void propagate(const ValueType& = 0);
00124 InferenceTermination arg(std::vector<LabelType>&, const size_t = 1) const;
00125
00126
00127
00128 private:
00129 void inferAcyclic();
00130 void inferParallel();
00131 void inferSequential();
00132 template<class VisitorType>
00133 void inferParallel(VisitorType&);
00134 template<class VisitorType>
00135 void inferAcyclic(VisitorType&);
00136 template<class VisitorType>
00137 void inferSequential(VisitorType&);
00138 private:
00139 const GraphicalModelType& gm_;
00140 Parameter parameter_;
00141 std::vector<FactorHullType> factorHulls_;
00142 std::vector<VariableHullType> variableHulls_;
00143 };
00144
00145 template<class GM, class ACC, class UPDATE_RULES, class DIST>
00146 MessagePassing<GM, ACC, UPDATE_RULES, DIST>::MessagePassing
00147 (
00148 const GraphicalModelType& gm,
00149 const typename MessagePassing<GM, ACC, UPDATE_RULES, DIST>::Parameter& parameter
00150 )
00151 : gm_(gm),
00152 parameter_(parameter)
00153 {
00154 if(parameter_.sortedNodeList_.size() == 0) {
00155 parameter_.sortedNodeList_.resize(gm.numberOfVariables());
00156 for (size_t i = 0; i < gm.numberOfVariables(); ++i)
00157 parameter_.sortedNodeList_[i] = i;
00158 }
00159 OPENGM_ASSERT(parameter_.sortedNodeList_.size() == gm.numberOfVariables());
00160
00161 UPDATE_RULES::initializeSpecialParameter(gm_,this->parameter_);
00162
00163
00164 variableHulls_.resize(gm.numberOfVariables(), VariableHullType ());
00165 for (size_t i = 0; i < gm.numberOfVariables(); ++i) {
00166 variableHulls_[i].assign(gm, i, ¶meter_.specialParameter_);
00167 }
00168 factorHulls_.resize(gm.numberOfFactors(), FactorHullType ());
00169 for (size_t i = 0; i < gm.numberOfFactors(); i++) {
00170 factorHulls_[i].assign(gm, i, variableHulls_, ¶meter_.specialParameter_);
00171 }
00172 }
00173
00174 template<class GM, class ACC, class UPDATE_RULES, class DIST>
00175 void
00176 MessagePassing<GM, ACC, UPDATE_RULES, DIST>::reset()
00177 {
00178 if(parameter_.sortedNodeList_.size() == 0) {
00179 parameter_.sortedNodeList_.resize(gm_.numberOfVariables());
00180 for (size_t i = 0; i < gm_.numberOfVariables(); ++i)
00181 parameter_.sortedNodeList_[i] = i;
00182 }
00183 OPENGM_ASSERT(parameter_.sortedNodeList_.size() == gm_.numberOfVariables());
00184 UPDATE_RULES::initializeSpecialParameter(gm_,this->parameter_);
00185
00186
00187 variableHulls_.resize(gm_.numberOfVariables(), VariableHullType ());
00188 for (size_t i = 0; i < gm_.numberOfVariables(); ++i) {
00189 variableHulls_[i].assign(gm_, i, ¶meter_.specialParameter_);
00190 }
00191 factorHulls_.resize(gm_.numberOfFactors(), FactorHullType ());
00192 for (size_t i = 0; i < gm_.numberOfFactors(); i++) {
00193 factorHulls_[i].assign(gm_, i, variableHulls_, ¶meter_.specialParameter_);
00194 }
00195 }
00196
00197 template<class GM, class ACC, class UPDATE_RULES, class DIST>
00198 inline std::string
00199 MessagePassing<GM, ACC, UPDATE_RULES, DIST>::name() const {
00200 return "MP";
00201 }
00202
00203 template<class GM, class ACC, class UPDATE_RULES, class DIST>
00204 inline const typename MessagePassing<GM, ACC, UPDATE_RULES, DIST>::GraphicalModelType&
00205 MessagePassing<GM, ACC, UPDATE_RULES, DIST>::graphicalModel() const {
00206 return gm_;
00207 }
00208
00209 template<class GM, class ACC, class UPDATE_RULES, class DIST>
00210 inline InferenceTermination
00211 MessagePassing<GM, ACC, UPDATE_RULES, DIST>::infer() {
00212 EmptyVisitorType v;
00213 return infer(v);
00214 }
00215
00216 template<class GM, class ACC, class UPDATE_RULES, class DIST>
00217 template<class VisitorType>
00218 inline InferenceTermination
00219 MessagePassing<GM, ACC, UPDATE_RULES, DIST>::infer
00220 (
00221 VisitorType& visitor
00222 ) {
00223 if (parameter_.isAcyclic_ == opengm::Tribool::True) {
00224 parameter_.useNormalization_=false;
00225 inferAcyclic(visitor);
00226 } else if (parameter_.isAcyclic_ == opengm::Tribool::False) {
00227 if (parameter_.inferSequential_) {
00228 inferSequential(visitor);
00229 } else {
00230 inferParallel(visitor);
00231 }
00232 } else {
00233 if (gm_.isAcyclic()) {
00234 parameter_.isAcyclic_ = opengm::Tribool::True;
00235 parameter_.useNormalization_=false;
00236 inferAcyclic(visitor);
00237 } else {
00238 parameter_.isAcyclic_ = opengm::Tribool::False;
00239 if (parameter_.inferSequential_) {
00240 inferSequential(visitor);
00241 } else {
00242 inferParallel(visitor);
00243 }
00244 }
00245 }
00246 return NORMAL;
00247 }
00248
00254 template<class GM, class ACC, class UPDATE_RULES, class DIST>
00255 inline void
00256 MessagePassing<GM, ACC, UPDATE_RULES, DIST>::inferAcyclic() {
00257 EmptyVisitorType v;
00258 return infer(v);
00259 }
00260
00262
00268 template<class GM, class ACC, class UPDATE_RULES, class DIST>
00269 template<class VisitorType>
00270 void
00271 MessagePassing<GM, ACC, UPDATE_RULES, DIST>::inferAcyclic
00272 (
00273 VisitorType& visitor
00274 )
00275 {
00276 OPENGM_ASSERT(gm_.isAcyclic());
00277 visitor.begin(*this);
00278 size_t numberOfVariables = gm_.numberOfVariables();
00279 size_t numberOfFactors = gm_.numberOfFactors();
00280
00281
00282 std::vector<std::vector<size_t> > counterVar2FacMessage(numberOfVariables);
00283 std::vector<std::vector<size_t> > counterFac2VarMessage(numberOfFactors);
00284
00285 std::vector<Message> ready2SendVar2FacMessage;
00286 std::vector<Message> ready2SendFac2VarMessage;
00287 ready2SendVar2FacMessage.reserve(100);
00288 ready2SendFac2VarMessage.reserve(100);
00289 for (size_t fac = 0; fac < numberOfFactors; ++fac) {
00290 counterFac2VarMessage[fac].resize(gm_[fac].numberOfVariables(), gm_[fac].numberOfVariables() - 1);
00291 }
00292 for (size_t var = 0; var < numberOfVariables; ++var) {
00293 counterVar2FacMessage[var].resize(gm_.numberOfFactors(var));
00294 for (size_t i = 0; i < gm_.numberOfFactors(var); ++i) {
00295 counterVar2FacMessage[var][i] = gm_.numberOfFactors(var) - 1;
00296 }
00297 }
00298
00299 for (size_t var = 0; var < numberOfVariables; ++var) {
00300 for (size_t i = 0; i < counterVar2FacMessage[var].size(); ++i) {
00301 if (counterVar2FacMessage[var][i] == 0) {
00302 --counterVar2FacMessage[var][i];
00303 ready2SendVar2FacMessage.push_back(Message(var, i));
00304 }
00305 }
00306 }
00307 for (size_t fac = 0; fac < numberOfFactors; ++fac) {
00308 for (size_t i = 0; i < counterFac2VarMessage[fac].size(); ++i) {
00309 if (counterFac2VarMessage[fac][i] == 0) {
00310 --counterFac2VarMessage[fac][i];
00311 ready2SendFac2VarMessage.push_back(Message(fac, i));
00312 }
00313 }
00314 }
00315
00316 while (ready2SendVar2FacMessage.size() > 0 || ready2SendFac2VarMessage.size() > 0) {
00317 while (ready2SendVar2FacMessage.size() > 0) {
00318 Message m = ready2SendVar2FacMessage.back();
00319 size_t nodeId = m.nodeId_;
00320 size_t factorId = gm_.factorOfVariable(nodeId,m.internalMessageId_);
00321
00322 variableHulls_[nodeId].propagate(gm_, m.internalMessageId_, 0, false);
00323 ready2SendVar2FacMessage.pop_back();
00324
00325 for (size_t i = 0; i < gm_[factorId].numberOfVariables(); ++i) {
00326 if (gm_[factorId].variableIndex(i) != nodeId) {
00327 if (--counterFac2VarMessage[factorId][i] == 0) {
00328 ready2SendFac2VarMessage.push_back(Message(factorId, i));
00329 }
00330 }
00331 }
00332 }
00333 while (ready2SendFac2VarMessage.size() > 0) {
00334 Message m = ready2SendFac2VarMessage.back();
00335 size_t factorId = m.nodeId_;
00336 size_t nodeId = gm_[factorId].variableIndex(m.internalMessageId_);
00337
00338 factorHulls_[factorId].propagate(m.internalMessageId_, 0, parameter_.useNormalization_);
00339 ready2SendFac2VarMessage.pop_back();
00340
00341 for (size_t i = 0; i < gm_.numberOfFactors(nodeId); ++i) {
00342 if (gm_.factorOfVariable(nodeId,i) != factorId) {
00343 if (--counterVar2FacMessage[nodeId][i] == 0) {
00344 ready2SendVar2FacMessage.push_back(Message(nodeId, i));
00345 }
00346 }
00347 }
00348 }
00349 visitor(*this);
00350 }
00351 visitor.end(*this);
00352
00353 }
00354
00356 template<class GM, class ACC, class UPDATE_RULES, class DIST>
00357 inline void MessagePassing<GM, ACC, UPDATE_RULES, DIST>::propagate
00358 (
00359 const ValueType& damping
00360 ) {
00361 for (size_t i = 0; i < variableHulls_.size(); ++i) {
00362 variableHulls_[i].propagateAll(damping, false);
00363 }
00364 for (size_t i = 0; i < factorHulls_.size(); ++i) {
00365 factorHulls_[i].propagateAll(damping, parameter_.useNormalization_);
00366 }
00367 }
00368
00370 template<class GM, class ACC, class UPDATE_RULES, class DIST>
00371 inline void MessagePassing<GM, ACC, UPDATE_RULES, DIST>::inferParallel() {
00372 EmptyVisitorType v;
00373 return infer(v);
00374 }
00375
00378 template<class GM, class ACC, class UPDATE_RULES, class DIST>
00379 template<class VisitorType>
00380 inline void MessagePassing<GM, ACC, UPDATE_RULES, DIST>::inferParallel
00381 (
00382 VisitorType& visitor
00383 )
00384 {
00385 ValueType c = 0;
00386 ValueType damping = parameter_.damping_;
00387 visitor.begin(*this);
00388
00389
00390 for (size_t i = 0; i < factorHulls_.size(); ++i) {
00391 if (factorHulls_[i].numberOfBuffers() < 2) {
00392 factorHulls_[i].propagateAll(0, parameter_.useNormalization_);
00393 factorHulls_[i].propagateAll(0, parameter_.useNormalization_);
00394 }
00395 }
00396 for (unsigned long n = 0; n < parameter_.maximumNumberOfSteps_; ++n) {
00397 for (size_t i = 0; i < variableHulls_.size(); ++i) {
00398 variableHulls_[i].propagateAll(gm_, damping, false);
00399 }
00400 for (size_t i = 0; i < factorHulls_.size(); ++i) {
00401 if (factorHulls_[i].numberOfBuffers() >= 2)
00402 factorHulls_[i].propagateAll(damping, parameter_.useNormalization_);
00403 }
00404 visitor(*this);
00405 c = convergence();
00406 if (c < parameter_.bound_) {
00407 break;
00408 }
00409 }
00410 visitor.end(*this);
00411
00412 }
00413
00422 template<class GM, class ACC, class UPDATE_RULES, class DIST>
00423 inline void MessagePassing<GM, ACC, UPDATE_RULES, DIST>::inferSequential() {
00424 EmptyVisitorType v;
00425 return infer(v);
00426 }
00427
00438 template<class GM, class ACC, class UPDATE_RULES, class DIST>
00439 template<class VisitorType>
00440 inline void MessagePassing<GM, ACC, UPDATE_RULES, DIST>::inferSequential
00441 (
00442 VisitorType& visitor
00443 ) {
00444 OPENGM_ASSERT(parameter_.sortedNodeList_.size() == gm_.numberOfVariables());
00445 visitor.begin(*this);
00446 ValueType damping = parameter_.damping_;
00447
00448
00449 std::vector<size_t> nodeOrder(gm_.numberOfVariables());
00450 for (size_t o = 0; o < gm_.numberOfVariables(); ++o) {
00451 nodeOrder[parameter_.sortedNodeList_[o]] = o;
00452 }
00453
00454
00455 for (size_t f = 0; f < factorHulls_.size(); ++f) {
00456 if (factorHulls_[f].numberOfBuffers() < 2) {
00457 factorHulls_[f].propagateAll(0, parameter_.useNormalization_);
00458 factorHulls_[f].propagateAll(0, parameter_.useNormalization_);
00459 }
00460 }
00461
00462
00463 std::vector<std::vector<size_t> > inversePositions(gm_.numberOfVariables());
00464 for(size_t var=0; var<gm_.numberOfVariables();++var) {
00465 for(size_t i=0; i<gm_.numberOfFactors(var); ++i) {
00466 size_t factorId = gm_.factorOfVariable(var,i);
00467 for(size_t j=0; j<gm_.numberOfVariables(factorId);++j) {
00468 if(gm_.variableOfFactor(factorId,j)==var) {
00469 inversePositions[var].push_back(j);
00470 break;
00471 }
00472 }
00473 }
00474 }
00475
00476
00477
00478 for (unsigned long itteration = 0; itteration < parameter_.maximumNumberOfSteps_; ++itteration) {
00479 if(itteration%2==0) {
00480
00481 for (size_t o = 0; o < gm_.numberOfVariables(); ++o) {
00482 size_t variableId = parameter_.sortedNodeList_[o];
00483
00484 for(size_t i=0; i<gm_.numberOfFactors(variableId); ++i) {
00485 size_t factorId = gm_.factorOfVariable(variableId,i);
00486 factorHulls_[factorId].propagate(inversePositions[variableId][i], damping, parameter_.useNormalization_);
00487 }
00488
00489
00490 variableHulls_[variableId].propagateAll(gm_, damping, false);
00491 }
00492 }
00493 else{
00494
00495 for (size_t o = 0; o < gm_.numberOfVariables(); ++o) {
00496 size_t variableId = parameter_.sortedNodeList_[gm_.numberOfVariables() - 1 - o];
00497
00498 for(size_t i=0; i<gm_.numberOfFactors(variableId); ++i) {
00499 size_t factorId = gm_.factorOfVariable(variableId,i);
00500 factorHulls_[factorId].propagate(inversePositions[variableId][i], damping, parameter_.useNormalization_);
00501 }
00502
00503 variableHulls_[variableId].propagateAll(gm_, damping, false);
00504 }
00505 }
00506 visitor(*this);
00507
00508 }
00509 visitor.end(*this);
00510 }
00511
00512 template<class GM, class ACC, class UPDATE_RULES, class DIST>
00513 inline InferenceTermination
00514 MessagePassing<GM, ACC, UPDATE_RULES, DIST>::marginal
00515 (
00516 const size_t variableIndex,
00517 IndependentFactorType & out
00518 ) const {
00519 OPENGM_ASSERT(variableIndex < variableHulls_.size());
00520 variableHulls_[variableIndex].marginal(gm_, variableIndex, out, parameter_.useNormalization_);
00521 return NORMAL;
00522 }
00523
00524 template<class GM, class ACC, class UPDATE_RULES, class DIST>
00525 inline InferenceTermination
00526 MessagePassing<GM, ACC, UPDATE_RULES, DIST>::factorMarginal
00527 (
00528 const size_t factorIndex,
00529 IndependentFactorType &out
00530 ) const {
00531 typedef typename GM::OperatorType OP;
00532 OPENGM_ASSERT(factorIndex < factorHulls_.size());
00533 out.assign(gm_, gm_[factorIndex].variableIndicesBegin(), gm_[factorIndex].variableIndicesEnd(), OP::template neutral<ValueType>());
00534 factorHulls_[factorIndex].marginal(out, parameter_.useNormalization_);
00535 return NORMAL;
00536 }
00537
00539 template<class GM, class ACC, class UPDATE_RULES, class DIST>
00540 inline typename MessagePassing<GM, ACC, UPDATE_RULES, DIST>::ValueType
00541 MessagePassing<GM, ACC, UPDATE_RULES, DIST>::convergenceXF() const {
00542 ValueType result = 0;
00543 for (size_t j = 0; j < factorHulls_.size(); ++j) {
00544 for (size_t i = 0; i < factorHulls_[j].numberOfBuffers(); ++i) {
00545 ValueType d = factorHulls_[j].template distance<DIST > (i);
00546 if (d > result) {
00547 result = d;
00548 }
00549 }
00550 }
00551 return result;
00552 }
00553
00555 template<class GM, class ACC, class UPDATE_RULES, class DIST>
00556 inline typename MessagePassing<GM, ACC, UPDATE_RULES, DIST>::ValueType
00557 MessagePassing<GM, ACC, UPDATE_RULES, DIST>::convergenceFX() const {
00558 ValueType result = 0;
00559 for (size_t j = 0; j < variableHulls_.size(); ++j) {
00560 for (size_t i = 0; i < variableHulls_[j].numberOfBuffers(); ++i) {
00561 ValueType d = variableHulls_[j].template distance<DIST > (i);
00562 if (d > result) {
00563 result = d;
00564 }
00565 }
00566 }
00567 return result;
00568 }
00569
00571 template<class GM, class ACC, class UPDATE_RULES, class DIST>
00572 inline typename MessagePassing<GM, ACC, UPDATE_RULES, DIST>::ValueType
00573 MessagePassing<GM, ACC, UPDATE_RULES, DIST>::convergence() const {
00574 return convergenceXF();
00575 }
00576
00577 template<class GM, class ACC,class UPDATE_RULES, class DIST >
00578 inline InferenceTermination
00579 MessagePassing<GM, ACC, UPDATE_RULES, DIST>::arg
00580 (
00581 std::vector<LabelType>& conf,
00582 const size_t N
00583 ) const {
00584 if (N != 1) {
00585 throw RuntimeError("This implementation of message passing cannot return the k-th optimal configuration.");
00586 }
00587 else {
00588 if (parameter_.isAcyclic_ == opengm::Tribool::True) {
00589 return this->modeFromFactorMarginal(conf);
00590 }
00591 else {
00592 return this->modeFromFactorMarginal(conf);
00593
00594 }
00595 }
00596 }
00597
00598 }
00599
#endif // #ifndef OPENGM_MESSAGE_PASSING_HXX