00001
00002
00003
00004
00005
00006
00007
00008 #ifndef PRIMAL_LPBOUND_HXX_
00009 #define PRIMAL_LPBOUND_HXX_
00010 #include <limits>
00011 #include <algorithm>
00012 #include <opengm/inference/auxiliary/transportationsolver.hxx>
00013 #include <opengm/graphicalmodel/graphicalmodel.hxx>
00014 #include <opengm/inference/trws/utilities2.hxx>
00015
00016 namespace opengm
00017 {
00018
00019 using trws_base::FactorWrapper;
00020 using trws_base::VariableToFactorMapping;
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
/// Tuning parameters of PrimalLPBound. Both values are handed unchanged to the
/// transportation solver that PrimalLPBound constructs.
template<class ValueType>
struct PrimalLPBound_Parameter
{
  /// @param relativePrecision   precision argument for the transportation solver
  /// @param maxIterationNumber  iteration limit for the transportation solver
  PrimalLPBound_Parameter(ValueType relativePrecision,
                          size_t maxIterationNumber)
    : relativePrecision_(relativePrecision)
    , maxIterationNumber_(maxIterationNumber)
  {
  }

  ValueType relativePrecision_;  // solver precision
  size_t maxIterationNumber_;    // solver iteration limit
};
00045
00060
00061 template <class GM,class ACC>
00062 class PrimalLPBound
00063 {
00064 public:
00065 typedef TransportSolver::TransportationSolver<ACC,FactorWrapper<typename GM::FactorType> > Solver;
00066 typedef typename Solver::floatType ValueType;
00067 typedef std::vector<ValueType> UnaryFactor;
00068 typedef typename GM::IndexType IndexType;
00069 typedef typename GM::LabelType LabelType;
00070
00071 static const IndexType InvalidIndex;
00072 static const ValueType ValueTypeNan;
00073
00074 typedef PrimalLPBound_Parameter<ValueType> Parameter;
00075
00076 PrimalLPBound(const GM& gm,const Parameter& param=Parameter(Solver::floatTypeEps,Solver::defaultMaxIterationNumber));
00077
00078 template<class ValueIterator>
00079 void setVariable(IndexType var, ValueIterator inputBegin);
00080 template<class ValueIterator>
00081 void getVariable(IndexType var, ValueIterator outputBegin);
00082
00083 ValueType getTotalValue();
00084 ValueType getFactorValue(IndexType factorId);
00085 ValueType getVariableValue(IndexType varId);
00086 template<class Matrix>
00087 ValueType getFactorVariable(IndexType factorId, Matrix& matrix);
00088
00089 void ResetBuffer(){_bufferedValues(_gm.numberOfFactors(),ValueTypeNan); _totalValue=ValueTypeNan;}
00090 bool IsValueBuffered(IndexType factorId)const{OPENGM_ASSERT(factorId<_bufferedValues.size()); return (_bufferedValues[factorId] != ValueTypeNan);}
00091 bool IsFactorVariableBuffered(IndexType factorId)const{return _lastActiveSolver==factorId;}
00092 static void CheckDuplicateUnaryFactors(const GM& gm);
00093 private:
00094 void _checkPWFactorID(IndexType factorId,const std::string& message_prefix=std::string());
00095 const GM& _gm;
00096 Solver _solver;
00097 std::vector<UnaryFactor> _unaryFactors;
00098 VariableToFactorMapping<GM> _mapping;
00099
00100 std::vector<ValueType> _bufferedValues;
00101 IndexType _lastActiveSolver;
00102 ValueType _totalValue;
00103 };
00104
00105 template <class GM,class ACC>
00106 void PrimalLPBound<GM,ACC>::CheckDuplicateUnaryFactors(const GM& gm)
00107 {
00108 std::vector<IndexType> numOfunaryFactors(gm.numberOfVariables(),0);
00109 for (IndexType factorId=0;factorId<gm.numberOfFactors();++factorId)
00110 {
00111 if (gm[factorId].numberOfVariables()!=1)
00112 continue;
00113
00114 numOfunaryFactors[gm[factorId].variableIndex(0)]++;
00115 }
00116
00117 IndexType moreCount=std::count_if(numOfunaryFactors.begin(),numOfunaryFactors.end(),std::bind2nd(std::greater<IndexType>(),1));
00118 if (moreCount!=0)
00119 throw std::runtime_error("PrimalLPBound::CheckDuplicateUnaryFactors: all variables must have not more then a single associated unary factor!");
00120 }
00121
00122 template <class GM,class ACC>
00123 const typename PrimalLPBound<GM,ACC>::IndexType PrimalLPBound<GM,ACC>::InvalidIndex=std::numeric_limits<IndexType>::max();
00124
00125 template <class GM,class ACC>
00126 const typename PrimalLPBound<GM,ACC>::ValueType PrimalLPBound<GM,ACC>::ValueTypeNan=std::numeric_limits<ValueType>::max();
00127
00128 template <class GM,class ACC>
00129 PrimalLPBound<GM,ACC>::PrimalLPBound(const GM& gm,const Parameter& param):
00130 _gm(gm),
00131 _solver(
00132 #ifdef TRWS_DEBUG_OUTPUT
00133 std::cerr,
00134 #endif
00135 param.relativePrecision_,param.maxIterationNumber_),
00136 _unaryFactors(gm.numberOfVariables()),
00137 _mapping(gm),
00138 _bufferedValues(gm.numberOfFactors(),ValueTypeNan),
00139 _lastActiveSolver(InvalidIndex),
00140 _totalValue(ValueTypeNan)
00141 {
00142 CheckDuplicateUnaryFactors(gm);
00143
00144 for (size_t i=0;i<_unaryFactors.size();++i)
00145 _unaryFactors[i].assign(_gm.numberOfLabels(i),0);
00146 }
00147
00148 template <class GM,class ACC>
00149 template<class Iterator>
00150 void PrimalLPBound<GM,ACC>::setVariable(IndexType var, Iterator inputBegin)
00151 {
00152 OPENGM_ASSERT(var < _gm.numberOfVariables());
00153 _totalValue=ValueTypeNan;
00154 std::copy(inputBegin,inputBegin+_unaryFactors[var].size(),_unaryFactors[var].begin());
00155
00156
00157 IndexType numOfFactors=_gm.numberOfFactors(var);
00158 for (IndexType i=0;i<numOfFactors;++i)
00159 {
00160 IndexType factorId=_gm.factorOfVariable(var,i);
00161 OPENGM_ASSERT(factorId < _gm.numberOfFactors() );
00162 _bufferedValues[factorId] = ValueTypeNan;
00163 }
00164 }
00165
00166 template <class GM,class ACC>
00167 template<class Iterator>
00168 void PrimalLPBound<GM,ACC>::getVariable(IndexType var, Iterator outputBegin)
00169 {
00170 OPENGM_ASSERT(var < _gm.numberOfVariables());
00171 std::copy(_unaryFactors[var].begin(),_unaryFactors[var].end(),outputBegin);
00172 }
00173
00174 template <class GM,class ACC>
00175 void PrimalLPBound<GM,ACC>::_checkPWFactorID(IndexType factorId, const std::string& message_prefix)
00176 {
00177 OPENGM_ASSERT(factorId < _gm.numberOfFactors());
00178 if (_gm[factorId].numberOfVariables() !=2 )
00179 std::runtime_error(message_prefix + "Function can be applied to second order factors only!");
00180 }
00181
00182 template <class GM,class ACC>
00183 typename PrimalLPBound<GM,ACC>::ValueType PrimalLPBound<GM,ACC>::getFactorValue(IndexType factorId)
00184 {
00185 _checkPWFactorID(factorId,"PrimalLPBound::getFactorValue(): ");
00186
00187 if (_bufferedValues[factorId] == ValueTypeNan)
00188 {
00189 const typename GM::FactorType& factor=_gm[factorId];
00190 IndexType var0=factor.variableIndex(0),
00191 var1=factor.variableIndex(1);
00192 _solver.Init(_unaryFactors[var0].size(),_unaryFactors[var1].size(),FactorWrapper<typename GM::FactorType>(factor));
00193 _bufferedValues[factorId]=_solver.Solve(_unaryFactors[var0].begin(),_unaryFactors[var1].begin());
00194 _lastActiveSolver=factorId;
00195 }
00196
00197 return _bufferedValues[factorId];
00198 }
00199
00200 template <class GM,class ACC>
00201 template<class Matrix>
00202 typename PrimalLPBound<GM,ACC>::ValueType PrimalLPBound<GM,ACC>::getFactorVariable(IndexType factorId, Matrix& matrix)
00203 {
00204 _checkPWFactorID(factorId,"PrimalLPBound::getFactorVariable(): ");
00205
00206 if (_lastActiveSolver!=factorId)
00207 getFactorValue(factorId);
00208
00209 return _solver.GetSolution(&matrix);
00210 }
00211
00212 template <class GM,class ACC>
00213 typename PrimalLPBound<GM,ACC>::ValueType PrimalLPBound<GM,ACC>::getVariableValue(IndexType varId)
00214 {
00215 OPENGM_ASSERT(varId < _gm.numberOfVariables());
00216 OPENGM_ASSERT(varId < _unaryFactors.size());
00217 IndexType factorId=_mapping(varId);
00218 OPENGM_ASSERT(_mapping(varId) < _gm.numberOfFactors());
00219 if (factorId==VariableToFactorMapping<GM>::InvalidIndex)
00220 return (ValueType)0;
00221
00222 if (_bufferedValues[factorId] != ValueTypeNan)
00223 return _bufferedValues[factorId];
00224
00225 ValueType sum=0;
00226 const UnaryFactor& uf=_unaryFactors[varId];
00227 OPENGM_ASSERT(_gm.numberOfLabels(varId)==uf.size());
00228 OPENGM_ASSERT(_gm.numberOfLabels(varId)>0);
00229 const typename GM::FactorType& f=_gm[factorId];
00230 for (LabelType i=0;i<uf.size();++i)
00231 sum+=uf[i]*f(&i);
00232
00233 _bufferedValues[factorId]=sum;
00234 return sum;
00235 }
00236
00237 template <class GM,class ACC>
00238 typename PrimalLPBound<GM,ACC>::ValueType PrimalLPBound<GM,ACC>::getTotalValue()
00239 {
00240 if (_totalValue==ValueTypeNan)
00241 {
00242 _totalValue=0;
00243 for (IndexType factorId=0;factorId<_gm.numberOfFactors();++factorId)
00244 {
00245 const typename GM::FactorType& f=_gm[factorId];
00246 switch (f.numberOfVariables())
00247 {
00248 case 1: _totalValue+=getVariableValue(f.variableIndex(0)); break;
00249 case 2: _totalValue+=getFactorValue(factorId);break;
00250 default: throw std::runtime_error("PrimalLPBound::getTotalValue(): Only factors of order <= 2 are supported!");
00251 }
00252 }
00253 }
00254 return _totalValue;
00255 }
00256
00257 }
00258 #endif