#pragma once
#ifndef OPENGM_ALPHABETASWAP_HXX
#define OPENGM_ALPHABETASWAP_HXX

#include <algorithm>
#include <string>
#include <vector>

#include "opengm/inference/inference.hxx"
#include "opengm/inference/visitors/visitor.hxx"
00009
00010 namespace opengm {
00011
00014 template<class GM, class INF>
00015 class AlphaBetaSwap : public Inference<GM, typename INF::AccumulationType> {
00016 public:
00017 typedef GM GraphicalModelType;
00018 typedef INF InferenceType;
00019 typedef typename INF::AccumulationType AccumulationType;
00020 OPENGM_GM_TYPE_TYPEDEFS;
00021 typedef VerboseVisitor<AlphaBetaSwap<GM,INF> > VerboseVisitorType;
00022 typedef TimingVisitor<AlphaBetaSwap<GM,INF> > TimingVisitorType;
00023 typedef EmptyVisitor<AlphaBetaSwap<GM,INF> > EmptyVisitorType;
00024
00025 struct Parameter {
00026 Parameter() {
00027 maxNumberOfIterations_ = 1000;
00028 }
00029
00030 typename InferenceType::Parameter parameter_;
00031 size_t maxNumberOfIterations_;
00032 };
00033
00034 AlphaBetaSwap(const GraphicalModelType&, Parameter = Parameter());
00035 std::string name() const;
00036 const GraphicalModelType& graphicalModel() const;
00037 InferenceTermination infer();
00038 template<class VISITOR>
00039 InferenceTermination infer(VISITOR & );
00040 void reset();
00041 void setStartingPoint(typename std::vector<LabelType>::const_iterator);
00042 InferenceTermination arg(std::vector<LabelType>&, const size_t = 1) const;
00043
00044 private:
00045 const GraphicalModelType& gm_;
00046 Parameter parameter_;
00047 std::vector<LabelType> label_;
00048 size_t alpha_;
00049 size_t beta_;
00050 size_t maxState_;
00051 void increment();
00052 void addUnary(INF&, const size_t var, const ValueType v0, const ValueType v1);
00053 void addPairwise(INF&, const size_t var1, const size_t var2, const ValueType v0, const ValueType v1, const ValueType v2, const ValueType v3);
00054 };
00055
00056
00057 template<class GM, class INF>
00058 inline void
00059 AlphaBetaSwap<GM, INF>::reset() {
00060 alpha_ = 0;
00061 beta_ = 0;
00062 std::fill(label_.begin(),label_.end(),0);
00063 }
00064
00065 template<class GM, class INF>
00066 inline void
00067 AlphaBetaSwap<GM, INF>::increment() {
00068 if (++beta_ >= maxState_) {
00069 if (++alpha_ >= maxState_ - 1) {
00070 alpha_ = 0;
00071 }
00072 beta_ = alpha_ + 1;
00073 }
00074 OPENGM_ASSERT(alpha_ < maxState_);
00075 OPENGM_ASSERT(beta_ < maxState_);
00076 OPENGM_ASSERT(alpha_ < beta_);
00077 }
00078
00079 template<class GM, class INF>
00080 inline std::string
00081 AlphaBetaSwap<GM, INF>::name() const {
00082 return "Alpha-Beta-Swap";
00083 }
00084
00085 template<class GM, class INF>
00086 inline const typename AlphaBetaSwap<GM, INF>::GraphicalModelType&
00087 AlphaBetaSwap<GM, INF>::graphicalModel() const {
00088 return gm_;
00089 }
00090
00091 template<class GM, class INF>
00092 inline AlphaBetaSwap<GM, INF>::AlphaBetaSwap
00093 (
00094 const GraphicalModelType& gm,
00095 Parameter para
00096 )
00097 : gm_(gm)
00098 {
00099 parameter_ = para;
00100 label_.resize(gm_.numberOfVariables(), 0);
00101 alpha_ = 0;
00102 beta_ = 0;
00103 for (size_t j = 0; j < gm_.numberOfFactors(); ++j) {
00104 if (gm_[j].numberOfVariables() > 2) {
00105 throw RuntimeError("This implementation of Alpha-Beta-Swap supports only factors of order <= 2.");
00106 }
00107 }
00108 maxState_ = 0;
00109 for (size_t i = 0; i < gm_.numberOfVariables(); ++i) {
00110 size_t numSt = gm_.numberOfLabels(i);
00111 if (numSt > maxState_)
00112 maxState_ = numSt;
00113 }
00114 }
00115
00116 template<class GM, class INF>
00117 inline void
00118 AlphaBetaSwap<GM,INF>::setStartingPoint
00119 (
00120 typename std::vector<typename AlphaBetaSwap<GM,INF>::LabelType>::const_iterator begin
00121 ) {
00122 try{
00123 label_.assign(begin, begin+gm_.numberOfVariables());
00124 }
00125 catch(...) {
00126 throw RuntimeError("unsuitable starting point");
00127 }
00128 }
00129
00130 template<class GM, class INF>
00131 inline void
00132 AlphaBetaSwap<GM, INF>::addUnary
00133 (
00134 INF& inf,
00135 const size_t var1,
00136 const ValueType v0,
00137 const ValueType v1
00138 ) {
00139 const size_t shape[] = {2};
00140 const size_t vars[] = {var1};
00141 opengm::IndependentFactor<ValueType,IndexType,LabelType> fac(vars, vars + 1, shape, shape + 1);
00142 fac(0) = v0;
00143 fac(1) = v1;
00144 inf.addFactor(fac);
00145 }
00146
00147 template<class GM, class INF>
00148 inline void
00149 AlphaBetaSwap<GM, INF>::addPairwise
00150 (
00151 INF& inf,
00152 const size_t var1,
00153 const size_t var2,
00154 const ValueType v0,
00155 const ValueType v1,
00156 const ValueType v2,
00157 const ValueType v3
00158 ) {
00159 const size_t shape[] = {2, 2};
00160 const size_t vars[] = {var1, var2};
00161 opengm::IndependentFactor<ValueType,IndexType,LabelType> fac(vars, vars + 2, shape, shape + 2);
00162 fac(0, 0) = v0;
00163 fac(0, 1) = v1;
00164 fac(1, 0) = v2;
00165 fac(1, 1) = v3;
00166 OPENGM_ASSERT(v1 + v2 - v0 - v3 >= 0);
00167 inf.addFactor(fac);
00168 }
00169 template<class GM, class INF>
00170 InferenceTermination
00171 AlphaBetaSwap<GM, INF>::infer() {
00172 EmptyVisitorType v;
00173 return infer(v);
00174 }
00175
00176 template<class GM, class INF>
00177 template<class VISITOR>
00178 InferenceTermination
00179 AlphaBetaSwap<GM, INF>::infer
00180 (
00181 VISITOR & visitor
00182 ) {
00183 visitor.begin(*this,0,0);
00184
00185 size_t it = 0;
00186 size_t countUnchanged = 0;
00187 size_t numberOfVariables = gm_.numberOfVariables();
00188 std::vector<size_t> variable2Node(numberOfVariables, 0);
00189 ValueType energy = gm_.evaluate(label_);
00190 size_t vecA[1];
00191 size_t vecB[1];
00192 size_t vecAA[2];
00193 size_t vecAB[2];
00194 size_t vecBA[2];
00195 size_t vecBB[2];
00196 size_t vecAX[2];
00197 size_t vecBX[2];
00198 size_t vecXA[2];
00199 size_t vecXB[2];
00200 size_t numberOfLabelPairs = maxState_*(maxState_ - 1)/2;
00201 while (it++ < parameter_.maxNumberOfIterations_ && countUnchanged < numberOfLabelPairs) {
00202 increment();
00203 size_t counter = 0;
00204 std::vector<size_t> numFacDim(4, 0);
00205 for (size_t i = 0; i < numberOfVariables; ++i) {
00206 if (label_[i] == alpha_ || label_[i] == beta_) {
00207 variable2Node[i] = counter++;
00208 }
00209 }
00210 if (counter == 0) {
00211 continue;
00212 }
00213 INF inf(counter, numFacDim);
00214 vecA[0] = alpha_;
00215 vecB[0] = beta_;
00216 vecAA[0] = alpha_;
00217 vecAA[1] = alpha_;
00218 vecBB[0] = beta_;
00219 vecBB[1] = beta_;
00220 vecBA[0] = beta_;
00221 vecBA[1] = alpha_;
00222 vecAB[0] = alpha_;
00223 vecAB[1] = beta_;
00224 vecAX[0] = alpha_;
00225 vecBX[0] = beta_;
00226 vecXA[1] = alpha_;
00227 vecXB[1] = beta_;
00228 for (size_t k = 0; k < gm_.numberOfFactors(); ++k) {
00229 const FactorType& factor = gm_[k];
00230 if (factor.numberOfVariables() == 1) {
00231 size_t var = factor.variableIndex(0);
00232 size_t node = variable2Node[var];
00233 if (label_[var] == alpha_ || label_[var] == beta_) {
00234 OPENGM_ASSERT(alpha_ < gm_.numberOfLabels(var));
00235 OPENGM_ASSERT(beta_ < gm_.numberOfLabels(var));
00236 addUnary(inf, node, factor(vecA), factor(vecB));
00237
00238 }
00239 } else if (factor.numberOfVariables() == 2) {
00240 size_t var1 = factor.variableIndex(0);
00241 size_t var2 = factor.variableIndex(1);
00242 size_t node1 = variable2Node[var1];
00243 size_t node2 = variable2Node[var2];
00244
00245 if ((label_[var1] == alpha_ || label_[var1] == beta_) && (label_[var2] == alpha_ || label_[var2] == beta_)) {
00246 addPairwise(inf, node1, node2, factor(vecAA), factor(vecAB), factor(vecBA), factor(vecBB));
00247
00248 } else if ((label_[var1] == alpha_ || label_[var1] == beta_) && (label_[var2] != alpha_ && label_[var2] != beta_)) {
00249 vecAX[1] = vecBX[1] = label_[var2];
00250 addUnary(inf, node1, factor(vecAX), factor(vecBX));
00251
00252 } else if ((label_[var2] == alpha_ || label_[var2] == beta_) && (label_[var1] != alpha_ && label_[var1] != beta_)) {
00253 vecXA[0] = vecXB[0] = label_[var1];
00254 addUnary(inf, node2, factor(vecXA), factor(vecXB));
00255
00256 }
00257 }
00258 }
00259 std::vector<LabelType> state;
00260 inf.infer();
00261 inf.arg(state);
00262 OPENGM_ASSERT(state.size() == counter);
00263 for (size_t var = 0; var < numberOfVariables; ++var) {
00264 if (label_[var] == alpha_ || label_[var] == beta_) {
00265 if (state[variable2Node[var]] == 0)
00266 label_[var] = alpha_;
00267 else
00268 label_[var] = beta_;
00269 } else {
00270
00271 }
00272 }
00273 ValueType energy2 = gm_.evaluate(label_);
00274 visitor(*this,energy2,energy);
00275 OPENGM_ASSERT(!AccumulationType::ibop(energy2, energy));
00276 if (AccumulationType::bop(energy2, energy)) {
00277 energy = energy2;
00278 } else {
00279 ++countUnchanged;
00280 }
00281 }
00282 visitor.end(*this,energy,energy);
00283 return NORMAL;
00284 }
00285
00286 template<class GM, class INF>
00287 inline InferenceTermination
00288 AlphaBetaSwap<GM, INF>::arg(std::vector<LabelType>& arg, const size_t n) const {
00289 if (n > 1) {
00290 return UNKNOWN;
00291 } else {
00292 OPENGM_ASSERT(label_.size() == gm_.numberOfVariables());
00293 arg.resize(label_.size());
00294 for (size_t i = 0; i < label_.size(); ++i)
00295 arg[i] = label_[i];
00296 return NORMAL;
00297 }
00298 }
00299
00300 }
00301
00302 #endif // #ifndef OPENGM_ALPHABEATSWAP_HXX