00001 #pragma once
00002 #ifndef OPENGM_INFERENCE_HXX
00003 #define OPENGM_INFERENCE_HXX
00004
00005 #include <vector>
00006 #include <string>
00007 #include <list>
00008 #include <limits>
00009 #include <exception>
00010
00011 #include "opengm/opengm.hxx"
00012
// OPENGM_GM_TYPE_TYPEDEFS
// Declares the standard graphical-model type aliases (LabelType, IndexType,
// ValueType, OperatorType, FactorType, IndependentFactorType,
// FunctionIdentifier) in the current scope, taken from a type named
// GraphicalModelType that must already be visible there. Because every alias
// is spelled with `typename`, it is presumably intended for template contexts
// (e.g. a class templated on the model) -- confirm at the use site.
// NOTE: the expansion deliberately ends WITHOUT a ';' (the last line only
// carries the line-continuation backslash), so write
// `OPENGM_GM_TYPE_TYPEDEFS;` where it is used.
#define OPENGM_GM_TYPE_TYPEDEFS \
typedef typename GraphicalModelType::LabelType LabelType; \
typedef typename GraphicalModelType::IndexType IndexType; \
typedef typename GraphicalModelType::ValueType ValueType; \
typedef typename GraphicalModelType::OperatorType OperatorType; \
typedef typename GraphicalModelType::FactorType FactorType; \
typedef typename GraphicalModelType::IndependentFactorType IndependentFactorType; \
typedef typename GraphicalModelType::FunctionIdentifier FunctionIdentifier \
00021
00022 namespace opengm {
00023
/// Status returned by inference algorithms and by the query methods of
/// Inference (arg, args, marginal, factorMarginal, ...).
///
/// The enumerators are numbered consecutively from zero:
/// UNKNOWN == 0, NORMAL == 1, TIMEOUT == 2, CONVERGENCE == 3,
/// INFERENCE_ERROR == 4.
enum InferenceTermination {
   UNKNOWN = 0,     ///< returned by the base-class queries that are not implemented
   NORMAL,          ///< the status the helpers in this header treat as success
   TIMEOUT,         ///< == 2
   CONVERGENCE,     ///< == 3
   INFERENCE_ERROR  ///< == 4
};
00031
00033 template <class GM, class ACC>
00034 class Inference
00035 {
00036 public:
00037 typedef GM GraphicalModelType;
00038 typedef ACC AccumulationType;
00039 typedef typename GraphicalModelType::LabelType LabelType;
00040 typedef typename GraphicalModelType::IndexType IndexType;
00041 typedef typename GraphicalModelType::ValueType ValueType;
00042 typedef typename GraphicalModelType::OperatorType OperatorType;
00043 typedef typename GraphicalModelType::FactorType FactorType;
00044 typedef typename GraphicalModelType::IndependentFactorType IndependentFactorType;
00045 typedef typename GraphicalModelType::FunctionIdentifier FunctionIdentifier;
00046
00047 virtual ~Inference() {}
00048
00049 virtual std::string name() const = 0;
00050 virtual const GraphicalModelType& graphicalModel() const = 0;
00051 virtual InferenceTermination infer() = 0;
00055
00056
00057 virtual void setStartingPoint(typename std::vector<LabelType>::const_iterator);
00058 virtual InferenceTermination arg(std::vector<LabelType>&, const size_t = 1) const;
00059 virtual InferenceTermination args(std::vector<std::vector<LabelType> >&) const;
00060 virtual InferenceTermination marginal(const size_t, IndependentFactorType&) const;
00061 virtual InferenceTermination factorMarginal(const size_t, IndependentFactorType&) const;
00062 virtual ValueType bound() const;
00063 virtual ValueType value() const;
00064 InferenceTermination constrainedOptimum(std::vector<IndexType>&,std::vector<LabelType>&, std::vector<LabelType>&) const;
00065 InferenceTermination modeFromMarginal(std::vector<LabelType>&) const;
00066 InferenceTermination modeFromFactorMarginal(std::vector<LabelType>&) const;
00067 };
00068
00072 template<class GM, class ACC>
00073 inline InferenceTermination
00074 Inference<GM, ACC>::arg(
00075 std::vector<LabelType>& arg,
00076 const size_t argIndex
00077 ) const
00078 {
00079 return UNKNOWN;
00080 }
00081
00084 template<class GM, class ACC>
00085 inline void
00086 Inference<GM, ACC>::setStartingPoint(
00087 typename std::vector<LabelType>::const_iterator begin
00088 )
00089 {}
00090
00091 template<class GM, class ACC>
00092 inline InferenceTermination
00093 Inference<GM, ACC>::args(
00094 std::vector<std::vector<LabelType> >& out
00095 ) const
00096 {
00097 return UNKNOWN;
00098 }
00099
00103 template<class GM, class ACC>
00104 inline InferenceTermination
00105 Inference<GM, ACC>::marginal(
00106 const size_t variableIndex,
00107 IndependentFactorType& out
00108 ) const
00109 {
00110 return UNKNOWN;
00111 }
00112
00116 template<class GM, class ACC>
00117 inline InferenceTermination
00118 Inference<GM, ACC>::factorMarginal(
00119 const size_t factorIndex,
00120 IndependentFactorType& out
00121 ) const
00122 {
00123 return UNKNOWN;
00124 }
00125
00126 template<class GM, class ACC>
00127 InferenceTermination
00128 Inference<GM, ACC>::constrainedOptimum(
00129 std::vector<IndexType>& variableIndices,
00130 std::vector<LabelType>& givenLabels,
00131 std::vector<LabelType>& conf
00132 ) const
00133 {
00134 const GM& gm = graphicalModel();
00135 std::vector<IndexType> waitingVariables;
00136 size_t variableId = 0;
00137 size_t numberOfVariables = gm.numberOfVariables();
00138 size_t numberOfFixedVariables = 0;
00139 conf.assign(gm.numberOfVariables(),std::numeric_limits<LabelType>::max());
00140 OPENGM_ASSERT(variableIndices.size()>=givenLabels.size());
00141 for(size_t i=0; i<givenLabels.size() ;++i) {
00142 OPENGM_ASSERT( variableIndices[i]<gm.numberOfVariables());
00143 OPENGM_ASSERT( givenLabels[i]<gm.numberOfLabels(variableIndices[i]));
00144 conf[variableIndices[i]] = givenLabels[i];
00145 waitingVariables.push_back(variableIndices[i]);
00146 ++numberOfFixedVariables;
00147 }
00148 while(variableId<gm.numberOfVariables() && numberOfFixedVariables<numberOfVariables) {
00149 while(waitingVariables.size()>0 && numberOfFixedVariables<numberOfVariables) {
00150 size_t var = waitingVariables.back();
00151 waitingVariables.pop_back();
00152
00153
00154 for(size_t i=0; i<gm.numberOfFactors(var); ++i) {
00155 size_t var2=var;
00156 size_t factorId = gm.factorOfVariable(var,i);
00157 for(size_t n=0; n<gm[factorId].numberOfVariables();++n) {
00158 if(conf[gm[factorId].variableIndex(n)] == std::numeric_limits<LabelType>::max()) {
00159 var2=gm[factorId].variableIndex(n);
00160 break;
00161 }
00162 }
00163 if(var2 != var) {
00164
00165 IndependentFactorType t;
00166
00167 for(size_t i=0; i<gm.numberOfFactors(var2); ++i) {
00168 size_t factorId = gm.factorOfVariable(var2,i);
00169 std::vector<IndexType> knownVariables;
00170 std::vector<LabelType> knownStates;
00171 std::vector<IndexType> unknownVariables;
00172 IndependentFactorType out;
00173 InferenceTermination term = factorMarginal(factorId, out);
00174 if(NORMAL != term) {
00175 return term;
00176 }
00177 for(size_t n=0; n<gm[factorId].numberOfVariables();++n) {
00178 if(gm[factorId].variableIndex(n)!=var2) {
00179 if(conf[gm[factorId].variableIndex(n)] < std::numeric_limits<LabelType>::max()) {
00180 knownVariables.push_back(gm[factorId].variableIndex(n));
00181 knownStates.push_back(conf[gm[factorId].variableIndex(n)]);
00182 }else{
00183 unknownVariables.push_back(gm[factorId].variableIndex(n));
00184 }
00185 }
00186 }
00187
00188 out.fixVariables(knownVariables.begin(), knownVariables.end(), knownStates.begin());
00189 if(unknownVariables.size()>0)
00190 out.template accumulate<AccumulationType>(unknownVariables.begin(),unknownVariables.end());
00191 OperatorType::op(out,t);
00192 }
00193 ValueType value;
00194 std::vector<LabelType> state(t.numberOfVariables());
00195 t.template accumulate<AccumulationType>(value,state);
00196 conf[var2] = state[0];
00197 ++numberOfFixedVariables;
00198 waitingVariables.push_back(var2);
00199 }
00200 }
00201 }
00202 if(conf[variableId]==std::numeric_limits<LabelType>::max()) {
00203
00204 IndependentFactorType out;
00205 InferenceTermination term = marginal(variableId, out);
00206 if(NORMAL != term) {
00207 return term;
00208 }
00209 ValueType value;
00210 std::vector<LabelType> state(out.numberOfVariables());
00211 out.template accumulate<AccumulationType>(value,state);
00212 conf[variableId] = state[0];
00213 waitingVariables.push_back(variableId);
00214 }
00215 ++variableId;
00216 }
00217 return NORMAL;
00218 }
00219
00220 template<class GM, class ACC>
00221 InferenceTermination
00222 Inference<GM, ACC>::modeFromMarginal(
00223 std::vector<LabelType>& conf
00224 ) const
00225 {
00226 const GM& gm = graphicalModel();
00227
00228 size_t numberOfNodes = gm.numberOfVariables();
00229 conf.resize(gm.numberOfVariables());
00230 IndependentFactorType out;
00231 for(size_t node=0; node<numberOfNodes; ++node) {
00232 InferenceTermination term = marginal(node, out);
00233 if(NORMAL != term) {
00234 return term;
00235 }
00236 ValueType value = out(0);
00237 size_t state = 0;
00238 for(size_t i=1; i<gm.numberOfLabels(node); ++i) {
00239 if(ACC::bop(out(i), value)) {
00240 value = out(i);
00241 state = i;
00242 }
00243 }
00244 conf[node] = state;
00245 }
00246 return NORMAL;
00247 }
00248
00249 template<class GM, class ACC>
00250 InferenceTermination
00251 Inference<GM, ACC>::modeFromFactorMarginal(
00252 std::vector<LabelType>& conf
00253 ) const
00254 {
00255 const GM& gm = graphicalModel();
00256 std::vector<IndexType> knownVariables;
00257 std::vector<LabelType> knownStates;
00258 IndependentFactorType out;
00259 for(size_t node=0; node<gm.numberOfVariables(); ++node) {
00260 InferenceTermination term = marginal(node, out);
00261 if(NORMAL != term) {
00262 return term;
00263 }
00264 ValueType value = out(0);
00265 size_t state = 0;
00266 bool unique = true;
00267 for(size_t i=1; i<gm.numberOfLabels(node); ++i) {
00268
00269
00270
00271
00272
00273 if(fabs(out(i) - value)<0.00001) {
00274 unique=false;
00275 }
00276 else if(ACC::bop(out(i), value)) {
00277 value = out(i);
00278 state = i;
00279 unique=true;
00280 }
00281 }
00282 if(unique) {
00283 knownVariables.push_back(node);
00284 knownStates.push_back(state);
00285 }
00286 }
00287 return constrainedOptimum( knownVariables, knownStates, conf);
00288 }
00289
00291 template<class GM, class ACC>
00292 typename GM::ValueType
00293 Inference<GM, ACC>::value() const
00294 {
00295 if(ACC::hasbop()){
00296
00297 std::vector<LabelType> s;
00298 const GM& gm = graphicalModel();
00299 if(NORMAL == arg(s)) {
00300 return gm.evaluate(s);
00301 }
00302 else {
00303 return ACC::template neutral<ValueType>();
00304 }
00305 }else{
00306
00307
00308 return std::numeric_limits<ValueType>::quiet_NaN();
00309 }
00310 }
00311
00313 template<class GM, class ACC>
00314 typename GM::ValueType
00315 Inference<GM, ACC>::bound() const {
00316 if(ACC::hasbop()){
00317
00318 return ACC::template ineutral<ValueType>();
00319 }else{
00320
00321
00322 return std::numeric_limits<ValueType>::quiet_NaN();
00323 }
00324 }
00325
00326 }
00327
00328 #endif // #ifndef OPENGM_INFERENCE_HXX