00001 #ifndef TRWS_SUBPROBLEMSOLVER_HXX_
00002 #define TRWS_SUBPROBLEMSOLVER_HXX_
00003 #include <iostream>
00004 #include <list>
00005 #include <algorithm>
00006 #include <utility>
00007 #include <functional>
00008 #include <valarray>
00009
00010 #include <opengm/inference/trws/utilities2.hxx>
00011 #include <opengm/functions/view_fix_variables_function.hxx>
00012
00013 #ifdef TRWS_DEBUG_OUTPUT
00014 #include <opengm/inference/trws/output_debug_utils.hxx>
00015 #endif
00016
00017
00018 namespace trws_base{
00019
00020 #ifdef TRWS_DEBUG_OUTPUT
00021 using OUT::operator <<;
00022 #endif
00023
00024 template<class GM>
00025 class SequenceStorage
00026 {
00027 public:
00028
00029 typedef typename GM::ValueType ValueType;
00030 typedef typename GM::IndexType IndexType;
00031 typedef typename GM::LabelType LabelType;
00032 typedef std::vector<IndexType> IndexList;
00033 typedef std::vector<ValueType> UnaryFactor;
00034 typedef enum{Direct,Reverse} MoveDirection;
00035 typedef VariableToFactorMapping<GM> VariableToFactorMap;
00036
00037 SequenceStorage(const GM& masterModel,const VariableToFactorMap& var2FactorMap,const IndexList& variableList, const IndexList& pwFactorList,const IndexList& numberOfTreesPerFactor);
00038
00039 ~SequenceStorage(){};
00040 IndexType size()const{return (IndexType)_directIndex.size();};
00041
00042 #ifdef TRWS_DEBUG_OUTPUT
00043 void PrintTestData(std::ostream& fout)const;
00044 #endif
00045
00046 static MoveDirection ReverseDirection(MoveDirection dir);
00047
00048
00049
00050
00051
00052
00053
00054 void AllocateUnaryFactors(std::vector<UnaryFactor>* pfactors);
00055 MoveDirection pwDirection(IndexType pwInd)const{assert(pwInd<_pwDirection.size()); return _pwDirection[pwInd];};
00056 IndexType pwForwardFactor(IndexType var)const{assert(var<_pwForwardIndex.size()); return _pwForwardIndex[var];}
00057 const GM& masterModel()const{return _masterModel;}
00058
00059
00060
00061 const UnaryFactor& unaryFactors(IndexType indx)const{assert(indx<_unaryFactors.size()); return _unaryFactors[indx];}
00062 typename UnaryFactor::iterator ufBegin(IndexType indx){assert(indx<_unaryFactors.size()); return _unaryFactors[indx].begin();}
00063 typename UnaryFactor::iterator ufEnd (IndexType indx){assert(indx<_unaryFactors.size()); return _unaryFactors[indx].end() ;}
00064 IndexType varIndex(IndexType var)const{assert(var<_directIndex.size()); return _directIndex[var];};
00065
00066 template<class ITERATOR>
00067 ValueType evaluate(ITERATOR labeling);
00068 private:
00069 void _ConsistencyCheck();
00070 void _Reset(const IndexList& numOfSequencesPerFactor);
00071 void _Reset(IndexType var,IndexType numOfSequences);
00072
00073 const GM& _masterModel;
00075 IndexList _directIndex;
00076 IndexList _pwForwardIndex;
00077 std::vector<UnaryFactor> _unaryFactors;
00078 std::vector<MoveDirection> _pwDirection;
00079 const VariableToFactorMap& _var2FactorMap;
00080 };
00081
00082
00083
/// Classifies every factor of a model as POTTS or GENERAL and caches the
/// two Potts parameters for the factors recognised as pairwise Potts.
template<class GM>
class FunctionParameters
{
public:
    typedef enum {GENERAL,POTTS} FunctionType;
    typedef typename GM::ValueType ValueType;
    typedef std::valarray<ValueType> ParameterStorageType;
    typedef typename GM::IndexType IndexType;
    typedef typename GM::LabelType LabelType;

    /// Scans all factors of \p gm; the model is stored by reference.
    FunctionParameters(const GM& gm);

    /// Type assigned to factor \p factorId (no range check is performed).
    FunctionType getFunctionType(IndexType factorId) const
    {
        return _factorTypes[factorId];
    }

    /// Cached parameters of factor \p factorId; left empty for GENERAL factors.
    const ParameterStorageType& getFunctionParameters(IndexType factorId) const
    {
        return _parameters[factorId];
    }

#ifdef TRWS_DEBUG_OUTPUT
    void PrintStatusData(std::ostream& fout);
#endif

private:
    void _checkConsistency() const;
    void _getPottsParameters(const typename GM::FactorType& factor,ParameterStorageType* pstorage) const;

    const GM& _gm;
    std::vector<ParameterStorageType> _parameters;  // one entry per factor
    std::vector<FunctionType> _factorTypes;         // one entry per factor
};
00111
00112 template<class GM>
00113 FunctionParameters<GM>::FunctionParameters(const GM& gm)
00114 : _gm(gm),_parameters(_gm.numberOfFactors()),_factorTypes(_gm.numberOfFactors())
00115 {
00116 for (IndexType i=0;i<_gm.numberOfFactors();++i)
00117 {
00118 const typename GM::FactorType& f=_gm[i];
00119
00120 if ((f.numberOfVariables()==2) && f.isPotts())
00121 {
00122 _factorTypes[i]=POTTS;
00123 _getPottsParameters(f,&_parameters[i]);
00124 }else _factorTypes[i]=GENERAL;
00125 }
00126
00127
00128 }
00129
00130 template<class GM>
00131 void FunctionParameters<GM>::_checkConsistency()const
00132 {
00133 OPENGM_ASSERT(_parameters.size()==_gm.numberOfFactors());
00134 OPENGM_ASSERT(_factorTypes.size()==_gm.numberOfFactors());
00135 for (size_t i=0;i<_parameters.size();++i)
00136 if (_factorTypes[i]==POTTS)
00137 {
00138 OPENGM_ASSERT(_parameters[i].size()==2);
00139 }
00140 }
00141
00142 template<class GM>
00143 void FunctionParameters<GM>::_getPottsParameters(const typename GM::FactorType& f,ParameterStorageType* pstorage)const
00144 {
00145 pstorage->resize(2,0.0);
00146 LabelType v00[]={0,0};
00147 LabelType v01[]={0,1};
00148 LabelType v10[]={1,0};
00149 if ((f.numberOfLabels(0)>0) && (f.numberOfLabels(1)>0))
00150 (*pstorage)[1]=f(&v00[0]);
00151 if (f.numberOfLabels(0)>1)
00152 (*pstorage)[0]=f(&v10[0])-f(&v00[0]);
00153 else if (f.numberOfLabels(0)>1)
00154 (*pstorage)[0]=f(&v01[0])-f(&v00[0]);
00155 }
00156
#ifdef TRWS_DEBUG_OUTPUT
/// Writes the total number of factors and the number of factors tagged
/// POTTS to \p fout.
template<class GM>
void FunctionParameters<GM>::PrintStatusData(std::ostream& fout)
{
    size_t pottsCount=0;
    for (size_t factorId=0; factorId<_parameters.size(); ++factorId)
    {
        if (_factorTypes[factorId]==POTTS)
            ++pottsCount;
    }

    fout << "Total number of factors:" << _factorTypes.size() << std::endl;
    fout << "Number of POTTS p/w factors:" << pottsCount << std::endl;
}
#endif
00167
00168
00169
00170 template<class GM,class ACC,class InputIterator>
00171 class DynamicProgramming
00172 {
00173 public:
00174 typedef GM GMType;
00175 typedef ACC ACCType;
00176 typedef typename GM::ValueType ValueType;
00177 typedef typename GM::IndexType IndexType;
00178 typedef typename GM::LabelType LabelType;
00179
00180 typedef InputIterator InputIteratorType;
00181 typedef SequenceStorage<GM> Storage;
00182 typedef typename Storage::IndexList IndexList;
00183 typedef typename Storage::UnaryFactor UnaryFactor;
00184 typedef typename Storage::MoveDirection MoveDirection;
00185 typedef std::vector<IndexList> IndexTable;
00186 typedef FunctionParameters<GM> FactorProperties;
00187 typedef typename UnaryFactor::const_iterator ConstIterator;
00188 typedef typename GM::FactorType Factor;
00189 typedef std::pair<typename UnaryFactor::const_iterator,typename UnaryFactor::const_iterator> const_iterators_pair;
00190
00191 public:
00192 static const IndexType NaN;
00193
00194 DynamicProgramming(Storage& storage,const FactorProperties& factorProperties,bool fastComputations=true);
00195
00196 virtual ~DynamicProgramming(){};
00197
00210 void InitMove(){_InitMove(1.0,Storage::Direct);};
00211 void InitMove(MoveDirection movedirection){_InitMove(1.0,movedirection);};
00212 virtual void InitReverseMove(){_InitMove(_rho,Storage::ReverseDirection(_moveDirection));};
00213 virtual void Move();
00214 virtual void PushBack();
00215 virtual void MoveBack();
00219 const_iterators_pair GetMarginals()const{return std::make_pair(_marginals[_currentUnaryIndex].begin(),_marginals[_currentUnaryIndex].end());};
00220 const_iterators_pair GetMarginals(IndexType indx)const{assert(indx<_marginals.size()); return std::make_pair(_marginals[indx].begin(),_marginals[indx].end());};
00221
00222 ValueType GetObjectiveValue()const{return _objectiveValue;};
00223
00224
00225
00226 virtual ValueType ComputeObjectiveValue()=0;
00227
00228
00229
00230
00231
00232 virtual void IncreaseUnaryWeights(InputIteratorType begin,InputIteratorType end);
00233
00234
00235
00236 virtual void FinalizeMove();
00240 LabelType numOfLabels()const{const_iterators_pair p=GetMarginals(); return p.second-p.first;}
00241 virtual void UpdateMarginals();
00242
00243 virtual IndexType getNextPWId()const;
00244 virtual IndexType getPrevPWId()const;
00245
00246 MoveDirection getMoveDirection()const{return _moveDirection;}
00247 IndexType size()const{return (IndexType)_storage.size();}
00248 template<class ITERATOR>
00249 ValueType evaluate(ITERATOR labeling){return _storage.evaluate(labeling);}
00253 #ifdef TRWS_DEBUG_OUTPUT
00254 virtual void PrintTestData(std::ostream& fout)const;
00255 #endif
00256
00257 void SetFastComputation(bool fc){_fastComputation=fc;}
00258
00259 protected:
00260
00261 void _PottsUnaryTransform(LabelType newSize,const typename FactorProperties::ParameterStorageType& params);
00262
00263 void _InitReverseMoveBack(){_core_InitMoves(_rho,Storage::ReverseDirection(_moveDirection));};
00264 void _InitMove(ValueType rho,MoveDirection movedirection);
00265 virtual void _Push();
00266 void _core_InitMoves(ValueType rho,MoveDirection movedirection);
00267 void _PushMessagesToFactor();
00268 void _ClearMessages(UnaryFactor* pbuffer=0);
00269 virtual void _makeLocalCopyOfPWFactor(LabelType trgsize);
00270 void _SumUpBufferToMarginals();
00271 virtual void _BackUpForwardMarginals(){};
00272 virtual void _InitCurrentUnaryBuffer(IndexType index);
00273
00274 IndexType _core_next(IndexType begin,MoveDirection dir)const;
00275 IndexType _next(IndexType begin)const;
00276 IndexType _previous(IndexType begin)const;
00277 IndexType _nextPWIndex()const;
00278
00279 bool _fastComputation;
00280 Storage& _storage;
00281 const FactorProperties& _factorProperties;
00282
00283 std::vector<UnaryFactor> _marginals;
00284
00285 ValueType _objectiveValue;
00286 ValueType _rho;
00287 MoveDirection _moveDirection;
00288 bool _bInitializationNeeded;
00289
00290
00291 UnaryFactor _currentPWFactor;
00292 UnaryFactor _currentUnaryFactor;
00293 IndexType _currentUnaryIndex;
00294
00295 mutable UnaryFactor _unaryTemp;
00296 mutable Pseudo2DArray<ValueType> _spst;
00297 };
00298
00299
00300 template<class GM,class ACC,class InputIterator>
00301 class MaxSumSolver : public DynamicProgramming<GM,ACC,InputIterator>
00302 {
00303 public:
00304 typedef DynamicProgramming<GM,ACC,InputIterator> parent;
00305 typedef typename parent::ValueType ValueType;
00306 typedef typename parent::IndexType IndexType;
00307 typedef typename parent::LabelType LabelType;
00308 typedef typename parent::InputIteratorType InputIteratorType;
00309 typedef std::vector<LabelType> LabelingType;
00310 typedef typename parent::UnaryFactor UnaryFactor;
00311 typedef typename parent::Factor Factor;
00312 typedef typename parent::FactorProperties FactorProperties;
00313
00314
00315 MaxSumSolver(typename parent::Storage& storage,const FactorProperties& factorProperties,bool fastComputations=true)
00316 :parent(storage,factorProperties,fastComputations),
00317 _labeling(parent::size(),parent::NaN)
00318
00319 {};
00320
00321 #ifdef TRWS_DEBUG_OUTPUT
00322 void PrintTestData(std::ostream& fout)const
00323 {
00324 parent::PrintTestData(fout);
00325 fout << "_labeling: "<<_labeling<<std::endl;
00326 }
00327 #endif
00328
00329 ValueType ComputeObjectiveValue();
00330 const LabelingType& arg(){return _labeling;}
00331
00332 void FinalizeMove();
00333
00334 protected:
00335 void _Push();
00336 void _SumUpBackwardEdges(UnaryFactor* u, LabelType fixedLabel)const;
00337 void _EstimateOptimalLabeling();
00338 LabelingType _labeling;
00339 mutable UnaryFactor _marginalsTemp;
00340
00341 };
00342
00343 template<class GM,class ACC,class InputIterator>
00344 void MaxSumSolver<GM,ACC,InputIterator>::_EstimateOptimalLabeling()
00345 {
00346 OPENGM_ASSERT((parent::_currentUnaryIndex==0)||(parent::_currentUnaryIndex==parent::size()-1));
00347 OPENGM_ASSERT(_labeling[parent::_currentUnaryIndex]<parent::_marginals[parent::_currentUnaryIndex].size());
00348
00349 IndexType bk_currentUnaryIndex=parent::_currentUnaryIndex;
00350
00351 typename parent::MoveDirection bk_moveDirection=parent::_moveDirection;
00352 parent::_moveDirection=parent::Storage::ReverseDirection(parent::_moveDirection);
00353
00354
00355 LabelType optLabel=_labeling[parent::_currentUnaryIndex];
00356
00357 for (IndexType i=1;i<parent::size();++i)
00358 {
00359 parent::_currentUnaryIndex=parent::_next(parent::_currentUnaryIndex);
00360 _marginalsTemp=parent::_marginals[parent::_currentUnaryIndex];
00361 _SumUpBackwardEdges(&_marginalsTemp,optLabel);
00362
00363 _labeling[parent::_currentUnaryIndex]=optLabel=std::max_element(_marginalsTemp.begin(),_marginalsTemp.end(),
00364 ACC::template ibop<ValueType>)-_marginalsTemp.begin();
00365 }
00366
00367
00368 parent::_moveDirection=bk_moveDirection;
00369 parent::_currentUnaryIndex=bk_currentUnaryIndex;
00370 }
00371
00372 template<class GM,class ACC,class InputIterator>
00373 typename MaxSumSolver<GM,ACC,InputIterator>::ValueType
00374 MaxSumSolver<GM,ACC,InputIterator>::ComputeObjectiveValue()
00375 {
00376 _labeling[parent::_currentUnaryIndex]=std::max_element(parent::_marginals[parent::_currentUnaryIndex].begin(),
00377 parent::_marginals[parent::_currentUnaryIndex].end(),ACC::template ibop<ValueType>)
00378 -parent::_marginals[parent::_currentUnaryIndex].begin();
00379 return parent::_marginals[parent::_currentUnaryIndex][_labeling[parent::_currentUnaryIndex]];
00380 }
00381
00382 template<class GM,class ACC,class InputIterator>
00383 void MaxSumSolver<GM,ACC,InputIterator>::FinalizeMove()
00384 {
00385 parent::FinalizeMove();
00386 _EstimateOptimalLabeling();
00387 };
00388
/// Unary functor that compares its argument against a fixed value:
/// returns the argument when ACC::bop(x, value) holds, the fixed value otherwise.
template <class T,class ACC>
struct compToValue : std::unary_function<T,T>
{
    compToValue(T val) : _val(val) {}

    T operator()(T x) const
    {
        if (ACC::template bop<T>(x,_val))
            return x;
        return _val;
    }

private:
    T _val;  // reference value every argument is compared with
};
00396
00397 template<class GM,class ACC,class InputIterator>
00398 void DynamicProgramming<GM,ACC,InputIterator>::_PottsUnaryTransform(LabelType newSize,const typename FactorProperties::ParameterStorageType& params)
00399 {
00400 OPENGM_ASSERT(params.size()==2);
00401 UnaryFactor* puf=&(_currentUnaryFactor);
00402 if (newSize< puf->size())
00403 puf->resize(newSize);
00404
00405 typename UnaryFactor::iterator bestValIt=std::max_element(puf->begin(),puf->end(),ACC::template ibop<ValueType>);
00406 ValueType bestVal=*bestValIt;
00407 ValueType secondBestVal=bestVal;
00408 if (ACC::bop(params[0],static_cast<ValueType>(0.0)))
00409 {
00410 *bestValIt=ACC::template neutral<ValueType>();
00411 secondBestVal=*std::max_element(puf->begin(),puf->end(),ACC::template ibop<ValueType>);
00412 *bestValIt=bestVal;
00413 }
00414
00415 transform_inplace(puf->begin(),puf->end(),compToValue<ValueType,ACC>(bestVal+params[0]));
00416
00417 if (ACC::bop(params[0],static_cast<ValueType>(0.0)))
00418 ACC::op(secondBestVal+params[0],bestVal,*bestValIt);
00419
00420 if (params[1]!=0.0)
00421 transform_inplace(puf->begin(),puf->end(),std::bind1st(std::plus<ValueType>(),params[1]));
00422 if (newSize> puf->size())
00423 puf->resize(newSize,params[0]+params[1]+bestVal);
00424
00425 }
00426
00427 template<class GM,class ACC,class InputIterator>
00428 void MaxSumSolver<GM,ACC,InputIterator>::_Push()
00429 {
00430 IndexType factorId=parent::_storage.pwForwardFactor(parent::_nextPWIndex());
00431 if ((parent::_factorProperties.getFunctionType(factorId)==FunctionParameters<GM>::POTTS) && parent::_fastComputation)
00432 {
00433 parent::_currentUnaryIndex=parent::_next(parent::_currentUnaryIndex);
00434 LabelType newSize=parent::_storage.unaryFactors(parent::_currentUnaryIndex).size();
00435
00436
00437 parent::_PottsUnaryTransform(newSize,parent::_factorProperties.getFunctionParameters(factorId));
00438 std::transform(parent::_currentUnaryFactor.begin(),parent::_currentUnaryFactor.end(),
00439 parent::_storage.unaryFactors(parent::_currentUnaryIndex).begin(),
00440 parent::_currentUnaryFactor.begin(),plus2ndMul<ValueType>(1.0/parent::_rho));
00441 }else
00442 parent::_Push();
00443 }
00444
00445
00446
00447
00448
00449 template<class GM,class ACC,class InputIterator>
00450 class SumProdSolver : public DynamicProgramming<GM,ACC,InputIterator>
00451 {
00452 public:
00453 typedef DynamicProgramming<GM,ACC,InputIterator> parent;
00454 typedef typename parent::ValueType ValueType;
00455 typedef typename parent::IndexType IndexType;
00456 typedef typename parent::LabelType LabelType;
00457 typedef typename parent::InputIteratorType InputIteratorType;
00458 typedef typename parent::const_iterators_pair const_iterators_pair;
00459 typedef typename parent::Storage Storage;
00460 typedef typename parent::MoveDirection MoveDirection;
00461 typedef typename parent::UnaryFactor UnaryFactor;
00462 typedef typename parent::FactorProperties FactorProperties;
00463
00464
00465 SumProdSolver(Storage& storage,const FactorProperties& factorProperties,bool fastComputations=true)
00466 :parent(storage,factorProperties,fastComputations),_averagingFlag(false){ACC::op(1.0,-1.0,_mul);};
00467 void InitMove(ValueType rho){parent::_InitMove(rho,Storage::Direct);};
00468 void InitMove(ValueType rho,MoveDirection movedirection){parent::_InitMove(rho,movedirection);};
00469
00470 ValueType ComputeObjectiveValue();
00471 ValueType MoveBackGetDerivative();
00472 ValueType getDerivative()const{return _derivativeValue;}
00473 protected:
00474 void _Push();
00475 void _ExponentiatePWFactor();
00476 void _PushMessagesToVariable();
00477 void _PushAndAverage();
00478 void _UpdatePWAverage();
00479 ValueType _getMarginalsLogNormalizer()const{return parent::GetObjectiveValue()/parent::_rho;}
00480 ValueType _GetAveragedUnaryFactors();
00481 void _makeLocalCopyOfPWFactor(LabelType trgsize);
00482 void _InitCurrentUnaryBuffer(IndexType index);
00483
00484 ValueType _mul;
00485 bool _averagingFlag;
00486
00487
00488
00489
00490 UnaryFactor _unaryBuffer;
00491 UnaryFactor _copyPWfactor;
00492 ValueType _derivativeValue;
00493 };
00494
00495
00496
#ifdef TRWS_DEBUG_OUTPUT
/// Dumps the index tables, the unary potentials and the pairwise factor
/// orientations of the sequence to \p fout.
template<class GM>
void SequenceStorage<GM>::PrintTestData(std::ostream& fout) const
{
    fout << "_directIndex:" << _directIndex;
    fout << "_pwForwardIndex:" << _pwForwardIndex;
    fout << "_unaryFactors:" << std::endl << _unaryFactors;
    fout << "_pwDirection:" << _pwDirection;
}
#endif
00507
00508 template<class GM>
00509 SequenceStorage<GM>::SequenceStorage(const GM& masterModel,const VariableToFactorMap& var2FactorMap,
00510 const IndexList& variableList,
00511 const IndexList& pwFactorList,
00512 const IndexList& numOfSequencesPerFactor)
00513 :_masterModel(masterModel),
00514 _directIndex(variableList),
00515 _pwForwardIndex(pwFactorList),
00516 _pwDirection(pwFactorList.size())
00517 ,_var2FactorMap(var2FactorMap)
00518 {
00519 _ConsistencyCheck();
00520 AllocateUnaryFactors(&_unaryFactors);
00521 _Reset(numOfSequencesPerFactor);
00522 }
00523
00524 template<class GM>
00525 void SequenceStorage<GM>::_ConsistencyCheck()
00526 {
00527 exception_check((_directIndex.size()-1)==_pwForwardIndex.size(),"DynamicProgramming::_ConsistencyCheck(): (_directIndex.size()-1)!=_pwForwardIndex.size()");
00528
00529 LabelType v[2];
00530 for (IndexType i=0;i<size()-1;++i)
00531 {
00532 exception_check(_masterModel[pwForwardFactor(i)].numberOfVariables()==2,"DynamicProgramming::_ConsistencyCheck():factor.numberOfVariables()!=2");
00533 _masterModel[pwForwardFactor(i)].variableIndices(&v[0]);
00534
00535 if (v[0]==varIndex(i))
00536 {
00537 exception_check(v[1]==varIndex(i+1),"DynamicProgramming::_ConsistencyCheck(): v[1]!=varIndex(i+1)");
00538 _pwDirection[i]=Direct;
00539 }
00540 else if (v[0]==varIndex(i+1))
00541 {
00542 exception_check(v[1]==varIndex(i),"DynamicProgramming::_ConsistencyCheck(): v[1]!=varIndex(i)");
00543 _pwDirection[i]=Reverse;
00544 }
00545 else
00546 throw std::runtime_error("DynamicProgramming::_ConsistencyCheck(): pairwise factor does not correspond to unaries!");
00547 }
00548 }
00549
00550 template<class GM>
00551 void SequenceStorage<GM>::_Reset(const IndexList& numOfSequencesPerFactor)
00552 {
00553 for (IndexType var=0;var<size();++var)
00554 _Reset(var,numOfSequencesPerFactor[var]);
00555 };
00556
00557 template<class GM>
00558 void SequenceStorage<GM>::_Reset(IndexType var,IndexType numOfSequences)
00559 {
00560 assert(var<size());
00561 UnaryFactor& uf=_unaryFactors[var];
00562 _masterModel[_var2FactorMap(varIndex(var))].copyValues(uf.begin());
00563 transform_inplace(uf.begin(),uf.end(),std::bind2nd(std::multiplies<ValueType>(),1.0/numOfSequences));
00564
00565 };
00566
00567 template<class GM>
00568 void SequenceStorage<GM>::AllocateUnaryFactors(std::vector<UnaryFactor>* pfactors)
00569 {
00570 pfactors->resize(size());
00571 for (size_t i=0;i<pfactors->size();++i)
00572 (*pfactors)[i].assign(_masterModel[_var2FactorMap(varIndex(i))].size(),0.0);
00573 };
00574
00575 template<class GM>
00576 typename SequenceStorage<GM>::MoveDirection SequenceStorage<GM>::ReverseDirection(MoveDirection dir)
00577 {
00578 if (dir==Direct)
00579 return Reverse;
00580 else
00581 return Direct;
00582 }
00583
00584 template<class GM>
00585 template<class ITERATOR>
00586 typename SequenceStorage<GM>::ValueType
00587 SequenceStorage<GM>::evaluate(ITERATOR labeling)
00588 {
00589 ValueType value=0.0;
00590 for (size_t i=0;i<size();++i)
00591 {
00592 value+=_unaryFactors[i][*labeling];
00593 if (i<size()-1)
00594 {
00595 if (pwDirection(i)==Direct)
00596 value+=_masterModel[_pwForwardIndex[i]](labeling);
00597 else
00598 {
00599 std::valarray<LabelType> ind(2);
00600 ind[0]=*(labeling+1); ind[1]=*labeling;
00601 value+= _masterModel[_pwForwardIndex[i]](labeling);
00602 }
00603 }
00604 ++labeling;
00605 }
00606 return value;
00607 }
00608
00609
00610 template<class GM,class ACC,class InputIterator>
00611 const typename DynamicProgramming<GM,ACC,InputIterator>::IndexType DynamicProgramming<GM,ACC,InputIterator>::NaN=std::numeric_limits<IndexType>::max();
00612
00613 template<class GM,class ACC,class InputIterator>
00614 DynamicProgramming<GM,ACC,InputIterator>::DynamicProgramming(Storage& storage,const FactorProperties& factorProperties,bool fastComputation)
00615 :_fastComputation(fastComputation),
00616 _storage(storage),
00617 _factorProperties(factorProperties),
00618 _objectiveValue(0.0),
00619 _rho(1.0),
00620 _moveDirection(Storage::Direct),
00621 _bInitializationNeeded(true),
00622 _currentPWFactor(0),
00623 _currentUnaryFactor(0),
00624
00625 _currentUnaryIndex(NaN)
00626 {
00627 _storage.AllocateUnaryFactors(&_marginals);
00628 };
00629
#ifdef TRWS_DEBUG_OUTPUT
/// Dumps the marginals, the objective value, rho, the move direction and the
/// current working buffers / position to \p fout.
template<class GM,class ACC,class InputIterator>
void DynamicProgramming<GM,ACC,InputIterator>::PrintTestData(std::ostream& fout) const
{
    fout << "_marginals:" << std::endl << _marginals;
    fout << "_objectiveValue=" << _objectiveValue << std::endl;
    fout << "_rho=" << _rho << std::endl;
    fout << "_moveDirection=" << _moveDirection << std::endl;
    fout << "_currentPWFactor=" << _currentPWFactor;
    fout << "_currentUnaryFactor=" << _currentUnaryFactor;
    fout << "_currentUnaryIndex=" << _currentUnaryIndex << std::endl;
}
#endif
00643
00644 template<class GM,class ACC,class InputIterator>
00645 typename DynamicProgramming<GM,ACC,InputIterator>::IndexType
00646 DynamicProgramming<GM,ACC,InputIterator>::_core_next(IndexType begin,MoveDirection dir)const
00647 {
00648 if (dir==Storage::Direct)
00649 {
00650 assert(begin<_storage.size()-1);
00651 return ++begin;
00652 }
00653 else
00654 {
00655 assert((begin>0) && (begin<_storage.size()));
00656 return --begin;
00657 }
00658 }
00659
00660 template<class GM,class ACC,class InputIterator>
00661 typename DynamicProgramming<GM,ACC,InputIterator>::IndexType
00662 DynamicProgramming<GM,ACC,InputIterator>::_next(IndexType begin)const
00663 {
00664 return _core_next(begin,_moveDirection);
00665 }
00666
00667 template<class GM,class ACC,class InputIterator>
00668 typename DynamicProgramming<GM,ACC,InputIterator>::IndexType
00669 DynamicProgramming<GM,ACC,InputIterator>::_previous(IndexType begin)const
00670 {
00671 if (_moveDirection==Storage::Direct)
00672 return _core_next(begin,Storage::Reverse);
00673 else
00674 return _core_next(begin,Storage::Direct);
00675 }
00676
00677 template<class GM,class ACC,class InputIterator>
00678 typename DynamicProgramming<GM,ACC,InputIterator>::IndexType
00679 DynamicProgramming<GM,ACC,InputIterator>::_nextPWIndex()const
00680 {
00681 if (_moveDirection==Storage::Direct)
00682 return _currentUnaryIndex;
00683 else
00684 return _currentUnaryIndex-1;
00685 }
00686
00687
00688 template<class GM,class ACC,class InputIterator>
00689 void DynamicProgramming<GM,ACC,InputIterator>::_makeLocalCopyOfPWFactor(LabelType trgsize)
00690 {
00691 const Factor& f=_storage.masterModel()[_storage.pwForwardFactor(_nextPWIndex())];
00692 _currentPWFactor.resize(f.size());
00693 if ( ((_moveDirection==Storage::Direct) && (_storage.pwDirection(_nextPWIndex())==Storage::Direct)) ||
00694 ((_moveDirection==Storage::Reverse) && (_storage.pwDirection(_nextPWIndex())==Storage::Reverse)) )
00695 f.copyValues(_currentPWFactor.begin());
00696 else
00697 f.copyValuesSwitchedOrder(_currentPWFactor.begin());
00698 }
00699
00700
00701 template<class GM,class ACC,class InputIterator>
00702 void DynamicProgramming<GM,ACC,InputIterator>::_PushMessagesToFactor()
00703 {
00704 LabelType trgsize=_storage.unaryFactors(_next(_currentUnaryIndex)).size();
00705
00706
00707 _makeLocalCopyOfPWFactor(trgsize);
00708 assert(_currentPWFactor.size()==(_currentUnaryFactor.size()*trgsize));
00709
00710 if (_rho!=1.0) std::transform(_currentPWFactor.begin(),_currentPWFactor.end(),_currentPWFactor.begin(),std::bind2nd(std::multiplies<ValueType>(),1.0/_rho));
00711
00712 _spst.resize(_currentUnaryFactor.size(),trgsize);
00713
00714
00715 for (LabelType i=0;i<_currentUnaryFactor.size();++i)
00716 transform_inplace(_spst.beginSrcNC(&_currentPWFactor[0],i),_spst.endSrcNC(&_currentPWFactor[0],i),std::bind2nd(std::plus<ValueType>(),_currentUnaryFactor[i]));
00717 }
00718
00719 template<class GM,class ACC,class InputIterator>
00720 void DynamicProgramming<GM,ACC,InputIterator>::_InitCurrentUnaryBuffer(IndexType index)
00721 {
00722 assert(index < _storage.size());
00723 _currentUnaryIndex=index;
00724 _currentUnaryFactor.resize(_storage.unaryFactors(_currentUnaryIndex).size());
00725 std::copy(_storage.unaryFactors(_currentUnaryIndex).begin(),_storage.unaryFactors(_currentUnaryIndex).end(),_currentUnaryFactor.begin());
00726 }
00727
/// Subtracts from every element of [begin,end) the element that
/// std::max_element selects under \p comp and returns \p init plus that
/// element. The range must not be empty.
template<class T,class Iterator,class Comp>
T _MaxNormalize_inplace(Iterator begin, Iterator end, T init,Comp comp)
{
    const T extremum=*std::max_element(begin,end,comp);
    transform_inplace(begin,end,std::bind2nd(std::minus<T>(),extremum));
    return init+extremum;
}
00735
00736
00737 template<class GM,class ACC,class InputIterator>
00738 void DynamicProgramming<GM,ACC,InputIterator>::_ClearMessages(UnaryFactor* pbuffer)
00739 {
00740 LabelType srcsize=_storage.unaryFactors(_previous(_currentUnaryIndex)).size();
00741
00742 _spst.resize(srcsize,_currentUnaryFactor.size());
00743
00744 if (pbuffer==0)
00745 {
00746 for (LabelType i=0;i<_currentUnaryFactor.size();++i)
00747 _currentUnaryFactor[i]+=_MaxNormalize_inplace(_spst.beginTrgNC(&_currentPWFactor[0],i),_spst.endTrgNC(&_currentPWFactor[0],i),(ValueType)0.0,ACC::template ibop<ValueType>);
00748 }
00749 else
00750 {
00751 pbuffer->resize(_currentUnaryFactor.size());
00752 for (LabelType i=0;i<_currentUnaryFactor.size();++i)
00753 _currentUnaryFactor[i]+=(*pbuffer)[i]=_MaxNormalize_inplace(_spst.beginTrgNC(&_currentPWFactor[0],i),_spst.endTrgNC(&_currentPWFactor[0],i),(ValueType)0.0,ACC::template ibop<ValueType>);
00754 }
00755 }
00756
00757 template<class GM,class ACC,class InputIterator>
00758 void DynamicProgramming<GM,ACC,InputIterator>::_Push()
00759 {
00760
00761 _PushMessagesToFactor();
00762 _InitCurrentUnaryBuffer(_next(_currentUnaryIndex));
00763
00764 _ClearMessages();
00765 _BackUpForwardMarginals();
00766 }
00767
00768 template<class GM,class ACC,class InputIterator>
00769 void DynamicProgramming<GM,ACC,InputIterator>::UpdateMarginals()
00770 {
00771 std::copy(_currentUnaryFactor.begin(),_currentUnaryFactor.end(),_marginals[_currentUnaryIndex].begin());
00772 }
00773
00774 template<class GM,class ACC,class InputIterator>
00775 void DynamicProgramming<GM,ACC,InputIterator>::_SumUpBufferToMarginals()
00776 {
00777 UnaryFactor& marginals=_marginals[_currentUnaryIndex];
00778 std::transform(_currentUnaryFactor.begin(),_currentUnaryFactor.end(),marginals.begin(),marginals.begin(),std::plus<ValueType>());
00779 std::transform(marginals.begin(),marginals.end(),_storage.unaryFactors(_currentUnaryIndex).begin(),marginals.begin(),plus2ndMul<ValueType>(-1.0/_rho));
00780 }
00781
00782
00783 template<class GM,class ACC,class InputIterator>
00784 void DynamicProgramming<GM,ACC,InputIterator>::Move()
00785 {
00786 if (_bInitializationNeeded)
00787 {
00788 InitReverseMove();
00789 _bInitializationNeeded=false;
00790 }
00791
00792 for (IndexType i=0;i<_storage.size()-1;++i)
00793 {
00794 _Push();
00795 UpdateMarginals();
00796 }
00797
00798
00799 FinalizeMove();
00800 }
00801
00805 template<class GM,class ACC,class InputIterator>
00806 void DynamicProgramming<GM,ACC,InputIterator>::PushBack()
00807 {
00808 if (_bInitializationNeeded)
00809 {
00810 _InitReverseMoveBack();
00811 }
00812
00813 _Push();
00814 _SumUpBufferToMarginals();
00815 }
00816
00820 template<class GM,class ACC,class InputIterator>
00821 void DynamicProgramming<GM,ACC,InputIterator>::MoveBack()
00822 {
00823
00824 for (IndexType i=0;i<_storage.size()-1;++i)
00825 PushBack();
00826
00827 FinalizeMove();
00828 }
00829
00830 template<class GM,class ACC,class InputIterator>
00831 void DynamicProgramming<GM,ACC,InputIterator>::_core_InitMoves(ValueType rho,MoveDirection movedirection)
00832 {
00833 _rho=rho;
00834 _moveDirection=movedirection;
00835
00836 if (_moveDirection==Storage::Direct)
00837 _InitCurrentUnaryBuffer(0);
00838 else
00839 _InitCurrentUnaryBuffer(_storage.size()-1);
00840
00841 _bInitializationNeeded=false;
00842
00843 }
00844
00845 template<class GM,class ACC,class InputIterator>
00846 void DynamicProgramming<GM,ACC,InputIterator>::_InitMove(ValueType rho,MoveDirection movedirection)
00847 {
00848 _core_InitMoves(rho,movedirection);
00849
00850 UpdateMarginals();
00851 }
00852
00853 template<class GM,class ACC,class InputIterator>
00854 void DynamicProgramming<GM,ACC,InputIterator>::FinalizeMove()
00855 {
00856 _objectiveValue=ComputeObjectiveValue();
00857 _bInitializationNeeded=true;
00858 };
00859
00860
00861 template<class GM,class ACC,class InputIterator>
00862 void DynamicProgramming<GM,ACC,InputIterator>::IncreaseUnaryWeights(InputIteratorType begin,InputIteratorType end)
00863 {
00864 exception_check((LabelType)abs(end-begin)==_storage.unaryFactors(_currentUnaryIndex).size(),"SumProdSequenceTRWSSolver::IncreaseUnaryWeights(): (end-begin)!=unaryFactor.size()");
00865
00866 std::transform(begin,end,_storage.ufBegin(_currentUnaryIndex),_storage.ufBegin(_currentUnaryIndex),std::plus<ValueType>());
00867 std::transform(_currentUnaryFactor.begin(),_currentUnaryFactor.end(),begin,_currentUnaryFactor.begin(),plus2ndMul<ValueType>(1.0/_rho));
00868 }
00869
00870 template<class GM,class ACC,class InputIterator>
00871 typename DynamicProgramming<GM,ACC,InputIterator>::IndexType
00872 DynamicProgramming<GM,ACC,InputIterator>::getPrevPWId()const
00873 {
00874 if (_currentUnaryIndex >= _storage.size()) return NaN;
00875
00876 if (_moveDirection==Storage::Direct)
00877 return (_currentUnaryIndex==0 ? NaN : _storage.pwForwardFactor(_currentUnaryIndex-1));
00878 else
00879 return (_currentUnaryIndex==_storage.size()-1 ? NaN : _storage.pwForwardFactor(_currentUnaryIndex));
00880 }
00881
00882 template<class GM,class ACC,class InputIterator>
00883 typename DynamicProgramming<GM,ACC,InputIterator>::IndexType
00884 DynamicProgramming<GM,ACC,InputIterator>::getNextPWId()const
00885 {
00886 if (_currentUnaryIndex >= (IndexType)_storage.size()) return NaN;
00887
00888 if (_moveDirection==Storage::Direct)
00889 return (_currentUnaryIndex==_storage.size()-1 ? NaN : _storage.pwForwardFactor(_currentUnaryIndex));
00890 else
00891 return (_currentUnaryIndex==0 ? NaN : _storage.pwForwardFactor(_currentUnaryIndex-1));
00892 }
00893
00894 template<class GM,class ACC,class InputIterator>
00895 void MaxSumSolver<GM,ACC,InputIterator>::_SumUpBackwardEdges(UnaryFactor* pu, LabelType fixedLabel)const
00896 {
00897 UnaryFactor& u=*pu;
00898 IndexType factorId=parent::getPrevPWId();
00899 if ((parent::_factorProperties.getFunctionType(factorId)==FunctionParameters<GM>::POTTS) && parent::_fastComputation)
00900 {
00901 u[fixedLabel]-=parent::_factorProperties.getFunctionParameters(factorId)[0];
00902 }else
00903 {
00904 const typename GM::FactorType& pwfactor=parent::_storage.masterModel()[factorId];
00905
00906 OPENGM_ASSERT( (parent::_storage.varIndex(parent::_currentUnaryIndex)==pwfactor.variableIndex(0)) || (parent::_storage.varIndex(parent::_currentUnaryIndex)==pwfactor.variableIndex(1)));
00907
00908 IndexType localVarIndx = (parent::_storage.varIndex(parent::_currentUnaryIndex)==pwfactor.variableIndex(0) ? 1 : 0);
00909 opengm::ViewFixVariablesFunction<GM> pencil(pwfactor,
00910 std::vector<opengm::PositionAndLabel<IndexType,LabelType> >(1,
00911 opengm::PositionAndLabel<IndexType,LabelType>(localVarIndx,
00912 fixedLabel)));
00913
00914 for (LabelType j=0;j<u.size();++j)
00915 u[j]+=pencil(&j);
00916 }
00917 }
00918
00919
00920
00921 template<class GM,class ACC,class InputIterator>
00922 void SumProdSolver<GM,ACC,InputIterator>::_PushMessagesToVariable()
00923 {
00924 LabelType srcsize=parent::_marginals[parent::_previous(parent::_currentUnaryIndex)].size();
00925
00926 parent::_spst.resize(srcsize,parent::_currentUnaryFactor.size());
00927
00928
00929 for (LabelType i=0;i<parent::_currentUnaryFactor.size();++i)
00930 parent::_currentUnaryFactor[i]+=_mul*log(std::accumulate(parent::_spst.beginTrg(&parent::_currentPWFactor[0],i),parent::_spst.endTrg(&parent::_currentPWFactor[0],i),ValueType(0.0)));
00931 }
00932
00933 template<class GM,class ACC,class InputIterator>
00934 void SumProdSolver<GM,ACC,InputIterator>::_UpdatePWAverage()
00935 {
00936 std::transform(_unaryBuffer.begin(),_unaryBuffer.end(),parent::_marginals[parent::_currentUnaryIndex].begin(),
00937 _unaryBuffer.begin(),std::plus<ValueType>());
00938 transform_inplace(_unaryBuffer.begin(),_unaryBuffer.end(),std::bind2nd(std::minus<ValueType>(),_getMarginalsLogNormalizer()));
00939 transform_inplace(_unaryBuffer.begin(),_unaryBuffer.end(),mulAndExp<ValueType>(_mul));
00940
00941 LabelType srcsize=parent::_marginals[parent::_previous(parent::_currentUnaryIndex)].size();
00942 parent::_spst.resize(srcsize,parent::_currentUnaryFactor.size());
00943
00944
00945 for (LabelType i=0;i<parent::_currentUnaryFactor.size();++i)
00946 _unaryBuffer[i]*=std::inner_product(parent::_spst.beginTrg(&parent::_currentPWFactor[0],i),
00947 parent::_spst.endTrg(&parent::_currentPWFactor[0],i),
00948 parent::_spst.beginTrg(&_copyPWfactor[0],i),
00949 ValueType(0.0));
00950
00951 _derivativeValue+=std::accumulate(_unaryBuffer.begin(),_unaryBuffer.end(),(ValueType)0.0);
00952 }
00953
00954 template<class GM,class ACC,class InputIterator>
00955 void SumProdSolver<GM,ACC,InputIterator>::_PushAndAverage()
00956 {
00957
00958 parent::_PushMessagesToFactor();
00959 _InitCurrentUnaryBuffer(parent::_next(parent::_currentUnaryIndex));
00960
00961
00962 parent::_ClearMessages(&_unaryBuffer);
00963
00964
00965 _ExponentiatePWFactor();
00966
00967 _UpdatePWAverage();
00968
00969
00970 _PushMessagesToVariable();
00971
00972 }
00973
00974 template<class GM,class ACC,class InputIterator>
00975 typename SumProdSolver<GM,ACC,InputIterator>::ValueType
00976 SumProdSolver<GM,ACC,InputIterator>::_GetAveragedUnaryFactors()
00977 {
00978 ValueType unaryAverage=0.0;
00979 for (size_t i=0;i<parent::size();++i)
00980 {
00981 _unaryBuffer.resize(parent::_marginals[i].size());
00982 std::transform(parent::_marginals[i].begin(),parent::_marginals[i].end(),_unaryBuffer.begin(),std::bind2nd(std::minus<ValueType>(),_getMarginalsLogNormalizer()));
00983 transform_inplace(_unaryBuffer.begin(),_unaryBuffer.end(),mulAndExp<ValueType>(_mul));
00984 unaryAverage+=std::inner_product(_unaryBuffer.begin(),_unaryBuffer.end(),parent::_storage.unaryFactors(i).begin(),(ValueType)0.0);
00985 }
00986 return unaryAverage;
00987 }
00988
00989 template<class GM,class ACC,class InputIterator>
00990 typename SumProdSolver<GM,ACC,InputIterator>::ValueType
00991 SumProdSolver<GM,ACC,InputIterator>::MoveBackGetDerivative()
00992 {
00993 if (parent::_bInitializationNeeded)
00994 {
00995 parent::_InitReverseMoveBack();
00996 }
00997
00998 _averagingFlag=true;
00999 _derivativeValue=0.0;
01000 for (size_t i=0;i<parent::size()-1;++i)
01001 {
01002 _PushAndAverage();
01003 parent::_SumUpBufferToMarginals();
01004 }
01005
01006 _derivativeValue+=_GetAveragedUnaryFactors();
01007 parent::FinalizeMove();
01008 _averagingFlag=false;
01009 _derivativeValue=(parent::GetObjectiveValue()-_derivativeValue)/parent::_rho;
01010 return _derivativeValue;
01011 }
01012
01013 template<class GM,class ACC,class InputIterator>
01014 void SumProdSolver<GM,ACC,InputIterator>::_Push()
01015 {
01016 parent::_Push();
01017
01018
01019 _ExponentiatePWFactor();
01020
01021
01022 _PushMessagesToVariable();
01023 }
01024
// Unary functor mapping x to exp(-|x|), or to 0 when |x| is greater than or
// equal to the threshold passed to the constructor.
//
// T    Value type of the argument and of the result.
// ACC  Not used by the computation; the parameter is kept so that existing
//      instantiations thresholdMulAndExp<T,ACC> continue to compile.
template <class T,class ACC> struct thresholdMulAndExp {
    // Adaptor typedefs that used to come from std::unary_function<T,T>; that
    // base class was deprecated in C++11 and removed in C++17.
    typedef T argument_type;
    typedef T result_type;

    // threshold: magnitudes at or above this value are mapped to 0.
    thresholdMulAndExp(T threshold):_threshold(threshold){}

    // Const and free of member scratch state, so the functor can be invoked
    // through a const reference and copies carry nothing but the threshold.
    T operator() (T x)const
    {
        const T magnitude=std::fabs(x);
        return (magnitude >= _threshold ? T(0) : std::exp(-magnitude));
    }
private:
    T _threshold;
};
01034
01035
01036
01037 template<class GM,class ACC,class InputIterator>
01038 void SumProdSolver<GM,ACC,InputIterator>::_ExponentiatePWFactor()
01039 {
01040 transform_inplace(parent::_currentPWFactor.begin(),parent::_currentPWFactor.end(),thresholdMulAndExp<ValueType,ACC>(-log(std::numeric_limits<ValueType>::epsilon())));
01041 }
01042
// Shifts a range so that its extreme element becomes zero.
//
// Finds the element of [begin,end) that std::max_element selects under comp,
// writes (element - extreme) for every element to the range starting at
// outBegin, and returns init + extreme. The output may alias the input.
//
// begin, end  Input range; it is traversed twice.
// outBegin    Start of the output range; must have room for end-begin values.
// init        Value the extreme element is added to for the return value.
// comp        Strict ordering used to pick the extreme element.
//
// For an empty input range nothing is written and init is returned unchanged.
template<class T,class InputIterator,class OutputIterator,class Comp >
T _MaxNormalize(InputIterator begin, InputIterator end, OutputIterator outBegin, T init,Comp comp)
{
    // std::max_element returns end for an empty range; it must not be dereferenced.
    if (begin==end)
        return init;

    // Copy the extreme value first so that in-place use (outBegin==begin) is safe.
    const T extreme=*std::max_element(begin,end,comp);

    for (;begin!=end;++begin,++outBegin)
        *outBegin=static_cast<T>(*begin)-extreme;

    return init+extreme;
}
01050
01051
01052 template<class GM,class ACC,class InputIterator>
01053 typename SumProdSolver<GM,ACC,InputIterator>::ValueType
01054 SumProdSolver<GM,ACC,InputIterator>::ComputeObjectiveValue()
01055 {
01056 typename UnaryFactor::const_iterator begin=parent::_marginals[parent::_currentUnaryIndex].begin(),
01057 end=parent::_marginals[parent::_currentUnaryIndex].end();
01058 parent::_unaryTemp.resize(end-begin);
01059 ValueType logPartition= parent::_rho*_MaxNormalize(begin,end,parent::_unaryTemp.begin(),(ValueType)0.0,ACC::template ibop<ValueType>);
01060 std::transform(parent::_unaryTemp.begin(),parent::_unaryTemp.end(),parent::_unaryTemp.begin(),mulAndExp<ValueType>(_mul));
01061 logPartition+=_mul*parent::_rho*(log(std::accumulate(parent::_unaryTemp.begin(),parent::_unaryTemp.end(),(ValueType)0.0)));
01062 return logPartition;
01063 }
01064
01065 template<class GM,class ACC,class InputIterator>
01066 void SumProdSolver<GM,ACC,InputIterator>::_makeLocalCopyOfPWFactor(LabelType trgsize)
01067 {
01068 parent::_makeLocalCopyOfPWFactor(trgsize);
01069 if (_averagingFlag)
01070 _copyPWfactor=parent::_currentPWFactor;
01071 }
01072
01073 template<class GM,class ACC,class InputIterator>
01074 void SumProdSolver<GM,ACC,InputIterator>::_InitCurrentUnaryBuffer(IndexType index)
01075 {
01076 parent::_InitCurrentUnaryBuffer(index);
01077
01078 if (parent::_rho!=1.0) transform_inplace(parent::_currentUnaryFactor.begin(),parent::_currentUnaryFactor.end(),std::bind2nd(std::multiplies<ValueType>(),1.0/parent::_rho));
01079 }
01080
01081 };
01082
01083
01084 #endif