00001 #ifndef TRWS_BASE_HXX_
00002 #define TRWS_BASE_HXX_
#include <time.h>

#include <algorithm>
#include <cmath>
#include <functional>
#include <iostream>
#include <limits>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

#include <opengm/inference/trws/trws_decomposition.hxx>
#include <opengm/inference/trws/trws_subproblemsolver.hxx>
#include <opengm/functions/view_fix_variables_function.hxx>
#include <opengm/inference/inference.hxx>
#include <opengm/inference/visitors/visitor.hxx>
00011
00012 namespace trws_base{
00013
00014 template<class GM>
00015 class DecompositionStorage
00016 {
00017 public:
00018 typedef GM GraphicalModelType;
00019 typedef SequenceStorage<GM> SubModel;
00020 typedef typename GM::ValueType ValueType;
00021 typedef typename GM::IndexType IndexType;
00022 typedef typename GM::LabelType LabelType;
00023 typedef typename MonotoneChainsDecomposition<GM>::SubVariable SubVariable;
00024 typedef typename MonotoneChainsDecomposition<GM>::SubVariableListType SubVariableListType;
00025 typedef typename SubModel::UnaryFactor UnaryFactor;
00026 typedef enum {GRIDSTRUCTURE, GENERALSTRUCTURE} StructureType;
00027 typedef VariableToFactorMapping<GM> VariableToFactorMap;
00028
00029 DecompositionStorage(const GM& gm,StructureType structureType=GENERALSTRUCTURE);
00030 ~DecompositionStorage();
00031
00032 const GM& masterModel()const{return _gm;}
00033 LabelType numberOfLabels(IndexType varId)const{return _gm.numberOfLabels(varId);}
00034 IndexType numberOfModels()const{return (IndexType)_subModels.size();}
00035 IndexType numberOfSharedVariables()const{return (IndexType)_variableDecomposition.size();}
00036 SubModel& subModel(IndexType modelId){return *_subModels[modelId];}
00037 const SubModel& subModel(IndexType modelId)const{return *_subModels[modelId];}
00038 IndexType size(IndexType subModelId)const{return (IndexType)_subModels[subModelId]->size();}
00039
00040 const SubVariableListType& getSubVariableList(IndexType varId)const{return _variableDecomposition[varId];}
00041 StructureType getStructureType()const{return _structureType;}
00042 #ifdef TRWS_DEBUG_OUTPUT
00043 void PrintTestData(std::ostream& fout)const;
00044 void PrintVariableDecompositionConsistency(std::ostream& fout)const;
00045 #endif
00046
00047 private:
00048 void _InitSubModels();
00049 const GM& _gm;
00050 StructureType _structureType;
00051 std::vector<SubModel*> _subModels;
00052 std::vector<SubVariableListType> _variableDecomposition;
00053 VariableToFactorMap _var2FactorMap;
00054 };
00055
/// Adapts an OpenGM visitor to the (value, bound) callback interface used inside
/// TRWSPrototype. Every callback forwards the stored inference object together with
/// the current primal value and dual bound to the wrapped visitor.
///
/// Neither pointer is owned; both must outlive the wrapper.
template<class VISITOR, class INFERENCE_TYPE>
class VisitorWrapper
{
public:
   typedef VISITOR VisitorType;
   typedef INFERENCE_TYPE InferenceType;
   typedef typename InferenceType::ValueType ValueType;

   /// @param pvisitor   visitor receiving the callbacks (not owned)
   /// @param pinference inference object handed to every callback (not owned)
   VisitorWrapper(VISITOR* pvisitor,INFERENCE_TYPE* pinference)
   : _visitor(pvisitor)
   , _inference(pinference)
   {}

   /// Forwards to VISITOR::begin(inference, value, bound).
   void begin(ValueType value,ValueType bound)
   {
      VisitorType& target=*_visitor;
      target.begin(*_inference,value,bound);
   }

   /// Forwards to VISITOR::end(inference, value, bound).
   void end(ValueType value,ValueType bound)
   {
      VisitorType& target=*_visitor;
      target.end(*_inference,value,bound);
   }

   /// Forwards to VISITOR::operator()(inference, value, bound); invoked once per iteration.
   void operator() (ValueType value,ValueType bound)
   {
      VisitorType& target=*_visitor;
      target(*_inference,value,bound);
   }

private:
   VisitorType* _visitor;      // not owned
   InferenceType* _inference;  // not owned
};
00074
/// Run-time parameters shared by all TRW-S variants (TRWSPrototype and derived classes).
template<class ValueType>
struct TRWSPrototype_Parameters
{
   /// Upper limit on the number of iterations of the main inference loop.
   size_t maxNumberOfIterations_;
   /// Duality-gap threshold; inference stops once the gap is within it.
   ValueType precision_;
   /// true: precision_ is an absolute gap; false: it is relative to the dual bound.
   bool absolutePrecision_;
   /// Relative dual-improvement threshold; a negative value disables this stopping test.
   ValueType minRelativeDualImprovement_;
   /// Forwarded to the sub-solvers; also enables the Potts shortcut when estimating labels.
   bool fastComputations_;

   TRWSPrototype_Parameters(size_t maxIternum,
                            ValueType precision=1.0,
                            bool absolutePrecision=true,
                            ValueType minRelativeDualImprovement=-1.0,
                            bool fastComputations=true)
   : maxNumberOfIterations_(maxIternum)
   , precision_(precision)
   , absolutePrecision_(absolutePrecision)
   , minRelativeDualImprovement_(minRelativeDualImprovement)
   , fastComputations_(fastComputations)
   {}
};
00096
00097 template<class GM>
00098 class PreviousFactorTable
00099 {
00100 public:
00101 typedef typename GM::IndexType IndexType;
00102 typedef SequenceStorage<GM> Storage;
00103 typedef typename Storage::MoveDirection MoveDirection;
00104 struct FactorVarID
00105 {
00106 FactorVarID(){};
00107 FactorVarID(IndexType fID,IndexType vID,IndexType lID):
00108 factorId(fID),varId(vID),localId(lID){};
00109
00110 #ifdef TRWS_DEBUG_OUTPUT
00111 void print(std::ostream& out)const{out <<"("<<factorId<<","<<varId<<","<<localId<<"),";}
00112 #endif
00113
00114 IndexType factorId;
00115 IndexType varId;
00116 IndexType localId;
00117 };
00118 typedef std::vector<FactorVarID> FactorList;
00119 typedef typename FactorList::const_iterator const_iterator;
00120
00121 PreviousFactorTable(const GM& gm);
00122 const_iterator begin(IndexType varId,MoveDirection md)const{return (md==Storage::Direct ? _forwardFactors[varId].begin() : _backwardFactors[varId].begin());}
00123 const_iterator end(IndexType varId,MoveDirection md)const{return (md==Storage::Direct ? _forwardFactors[varId].end() : _backwardFactors[varId].end());}
00124 #ifdef TRWS_DEBUG_OUTPUT
00125 void PrintTestData(std::ostream& fout);
00126 #endif
00127 private:
00128 std::vector<FactorList> _forwardFactors;
00129 std::vector<FactorList> _backwardFactors;
00130 };
00131
00132 template<class GM>
00133 PreviousFactorTable<GM>::PreviousFactorTable(const GM& gm):
00134 _forwardFactors(gm.numberOfVariables()),
00135 _backwardFactors(gm.numberOfVariables())
00136 {
00137 std::vector<IndexType> varIDs(2);
00138 for (IndexType factorId=0;factorId<gm.numberOfFactors();++factorId)
00139 {
00140 switch (gm[factorId].numberOfVariables())
00141 {
00142 case 1: break;
00143 case 2:
00144 gm[factorId].variableIndices(varIDs.begin());
00145 if (varIDs[0] < varIDs[1])
00146 {
00147 _forwardFactors[varIDs[1]].push_back(FactorVarID(factorId,varIDs[0],0));
00148 _backwardFactors[varIDs[0]].push_back(FactorVarID(factorId,varIDs[1],1));
00149 }
00150 else
00151 {
00152 _forwardFactors[varIDs[0]].push_back(FactorVarID(factorId,varIDs[1],1));
00153 _backwardFactors[varIDs[1]].push_back(FactorVarID(factorId,varIDs[0],0));
00154 }
00155 break;
00156 default: throw std::runtime_error("PreviousFactor::PreviousFactor: only the factors of order <=2 are supported!");
00157 }
00158 }
00159 }
00160
#ifdef TRWS_DEBUG_OUTPUT
/// Debug dump: prints the forward lists, then the backward lists, one line per variable.
template<class GM>
void PreviousFactorTable<GM>::PrintTestData(std::ostream& fout)
{
   const std::vector<FactorList>* tables[2]={&_forwardFactors,&_backwardFactors};
   const char* titles[2]={"Forward factors:","Backward factors:"};

   for (int t=0;t<2;++t)
   {
      fout << titles[t] << std::endl;
      const std::vector<FactorList>& table=*tables[t];
      for (size_t varId=0;varId<table.size();++varId)
      {
         fout << "varId="<<varId<<", ";
         for (typename FactorList::const_iterator it=table[varId].begin();it!=table[varId].end();++it)
            it->print(fout);
         fout << std::endl;
      }
   }
}
#endif
00184
/// Common driver of the TRW-S-style solvers in this file (SumProdTRWS, MaxSumTRWS).
///
/// Works on a DecompositionStorage and allocates one SubSolver per sub-model. It
/// alternates moves over the shared variables, averaging the marginals of each
/// variable's copies. It keeps the current dual bound (bound()) and the best integer
/// labeling found so far (value()/arg()).
/// Derived classes provide the marginal hooks (_normalizeMarginals, _postprocessMarginals,
/// _SumUpForwardMarginals) and _InitMove(); these are pure virtual here.
/// The class is non-copyable (copy constructor/assignment are private and undefined).
template <class SubSolver>
class TRWSPrototype
{
public:
   typedef typename SubSolver::GMType GM;
   typedef GM GraphicalModelType;
   typedef typename SubSolver::ACCType ACC;
   typedef ACC AccumulationType;
   typedef SubSolver SubSolverType;
   typedef FunctionParameters<GM> FactorProperties;
   typedef opengm::EmptyVisitor< TRWSPrototype<SubSolverType> > EmptyVisitorParent;
   typedef VisitorWrapper<EmptyVisitorParent,TRWSPrototype<SubSolver> > EmptyVisitorType;

   typedef typename SubSolver::const_iterators_pair const_marginals_iterators_pair;
   typedef typename GM::ValueType ValueType;
   typedef typename GM::IndexType IndexType;
   typedef typename GM::LabelType LabelType;
   typedef opengm::InferenceTermination InferenceTermination;
   typedef typename std::vector<ValueType> OutputContainerType;
   typedef typename OutputContainerType::iterator OutputIteratorType;

   typedef TRWSPrototype_Parameters<ValueType> Parameters;

   typedef SequenceStorage<GM> SubModel;
   typedef DecompositionStorage<GM> Storage;
   typedef typename Storage::UnaryFactor UnaryFactor;

   /// storage is held by reference and must outlive this object.
   TRWSPrototype(Storage& storage,const Parameters& params
#ifdef TRWS_DEBUG_OUTPUT
                 ,std::ostream& fout=std::cout
#endif
                 );
   virtual ~TRWSPrototype();

   /// Value of the best integer labeling found so far (ACC::neutral() until one is evaluated).
   virtual ValueType GetBestIntegerBound()const{return _bestIntegerBound;};
   /// Same as GetBestIntegerBound().
   virtual ValueType value()const{return _bestIntegerBound;}
   /// Current dual bound.
   virtual ValueType bound()const{return _dualBound;}
   /// Best integer labeling found so far.
   virtual const std::vector<LabelType>& arg()const{return _bestIntegerLabeling;}

#ifdef TRWS_DEBUG_OUTPUT
   virtual void PrintTestData(std::ostream& fout)const;
#endif

   /// true if |primalBound - dualBound| is within Parameters::precision_
   /// (absolute or relative, depending on Parameters::absolutePrecision_).
   bool CheckDualityGap(ValueType primalBound,ValueType dualBound);
   /// Default implementation writes nothing and returns (0,0); SumProdTRWS declares its own version.
   virtual std::pair<ValueType,ValueType> GetMarginals(IndexType variable, OutputIteratorType begin){return std::make_pair((ValueType)0,(ValueType)0);};
   void GetMarginalsMove();
   void BackwardMove();

   /// Objective value of sub-solver i (no bounds check).
   ValueType getBound(size_t i)const{return _subSolvers[i]->GetObjectiveValue();}
   /// Runs inference with an empty visitor.
   virtual InferenceTermination infer(){EmptyVisitorParent vis; EmptyVisitorType visitor(&vis,this); return infer(visitor);};
   template<class VISITOR> InferenceTermination infer(VISITOR&);
   void ForwardMove();
   /// |_dualBound - _oldDualBound| as computed by the last _CheckStoppingCondition() call.
   ValueType lastDualUpdate()const{return _lastDualUpdate;}

   /// Like infer(VISITOR&) but without the visitor's begin()/end() notifications.
   template<class VISITOR> InferenceTermination infer_visitor_updates(VISITOR&);
   /// Runs only the iteration loop (_core_infer), without _InitMove() or the initial forward move.
   InferenceTermination core_infer(){EmptyVisitorParent vis; EmptyVisitorType visitor(&vis,this); return _core_infer(visitor);};
   const FactorProperties& getFactorProperties()const{return _factorProperties;}
protected:
   void _EstimateIntegerLabeling();
   template <class VISITOR> InferenceTermination _core_infer(VISITOR&);
   virtual ValueType _GetPrimalBound(){_EvaluateIntegerBounds(); return GetBestIntegerBound();}
   virtual void _postprocessMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end)=0;
   virtual void _normalizeMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end,SubSolver* subSolver)=0;
   void _EvaluateIntegerBounds();

   virtual void _SumUpForwardMarginals(std::vector<ValueType>* pout,const_marginals_iterators_pair itpair)=0;
   // Stores in _integerLabeling[varId] the position that std::max_element selects when
   // ACC::ibop is used as the ordering. Presumably this is the best label under ACC
   // (e.g. the smallest entry for a minimizing ACC) -- TODO confirm against ACC::ibop.
   void _EstimateIntegerLabel(IndexType varId,const std::vector<ValueType>& sumMarginal)
   {_integerLabeling[varId]=std::max_element(sumMarginal.begin(),sumMarginal.end(),ACC::template ibop<ValueType>)-sumMarginal.begin();}

   void _InitSubSolvers();
   void _ForwardMove();
   void _FinalizeMove();
   ValueType _GetObjectiveValue();
   IndexType _order(IndexType i);
   IndexType _core_order(IndexType i,IndexType totalSize);
   bool _CheckConvergence(ValueType relativeThreshold);
   virtual bool _CheckStoppingCondition(InferenceTermination* pterminationCode);
   // Hook called once per iteration of _core_infer(); no-op by default.
   virtual void _EstimateTRWSBound(){};

   virtual void _InitMove()=0;

   // NOTE: the constructor's mem-initializer list follows this declaration order.
   Storage& _storage;
   FactorProperties _factorProperties;
   PreviousFactorTable<GM> _ftable;       // pairwise factors to lower/higher-indexed neighbours
   Parameters _parameters;

#ifdef TRWS_DEBUG_OUTPUT
   std::ostream& _fout;
#endif

   ValueType _dualBound;
   ValueType _oldDualBound;               // dual bound of the previous iteration
   ValueType _lastDualUpdate;

   typename SubModel::MoveDirection _moveDirection;
   std::vector<SubSolver*> _subSolvers;   // one per sub-model, owned

   // Per-sub-model scratch buffer for the marginals of the variable processed in BackwardMove().
   std::vector<std::vector<ValueType> > _marginals;

   ValueType _integerBound;               // value of the current integer labeling
   ValueType _bestIntegerBound;

   std::vector<LabelType> _integerLabeling;
   std::vector<LabelType> _bestIntegerLabeling;


   std::vector<ValueType> _sumMarginal;   // scratch buffer used by _EstimateIntegerLabeling()
   mutable typename FactorProperties::ParameterStorageType _factorParameters;

private:
   // Non-copyable: declared but not defined.
   TRWSPrototype(TRWSPrototype&);
   TRWSPrototype& operator =(TRWSPrototype&);
};
00301
00302 template<class ValueType>
00303 struct SumProdTRWS_Parameters : public TRWSPrototype_Parameters<ValueType>
00304 {
00305 typedef TRWSPrototype_Parameters<ValueType> parent;
00306 ValueType smoothingValue_;
00307 SumProdTRWS_Parameters(size_t maxIternum,
00308 ValueType smValue,
00309 ValueType precision=1.0,
00310 bool absolutePrecision=true,
00311 ValueType minRelativeDualImprovement=2*std::numeric_limits<ValueType>::epsilon(),
00312 bool fastComputations=true)
00313 :parent(maxIternum,precision,absolutePrecision,minRelativeDualImprovement,fastComputations),
00314 smoothingValue_(smValue){};
00315 };
00316
/// TRW-S variant built on SumProdSolver sub-solvers and parameterised by a smoothing value.
/// Implements the marginal hooks of TRWSPrototype and adds marginal/derivative queries.
template<class GM,class ACC>
class SumProdTRWS : public TRWSPrototype<SumProdSolver<GM,ACC,typename std::vector<typename GM::ValueType>::const_iterator> >
{
public:
   typedef TRWSPrototype<SumProdSolver<GM,ACC,typename std::vector<typename GM::ValueType>::const_iterator> > parent;
   typedef ACC AccumulationType;
   typedef GM GraphicalModelType;
   typedef typename parent::SubSolverType SubSolver;
   typedef typename parent::const_marginals_iterators_pair const_marginals_iterators_pair;
   typedef typename parent::ValueType ValueType;
   typedef typename parent::IndexType IndexType;
   typedef typename parent::LabelType LabelType;
   typedef typename parent::InferenceTermination InferenceTermination;
   typedef SequenceStorage<GM> SubModel;
   typedef DecompositionStorage<GM> Storage;
   typedef typename parent::OutputContainerType OutputContainerType;
   typedef typename OutputContainerType::iterator OutputIteratorType;

   typedef SumProdTRWS_Parameters<ValueType> Parameters;

   /// storage is held by reference by the base class and must outlive this object.
   SumProdTRWS(Storage& storage,const Parameters& params
#ifdef TRWS_DEBUG_OUTPUT
               ,std::ostream& fout=std::cout
#endif
               ):
   parent(storage,params
#ifdef TRWS_DEBUG_OUTPUT
          ,fout
#endif
          ),
   _smoothingValue(params.smoothingValue_)
   {};
   ~SumProdTRWS(){};

#ifdef TRWS_DEBUG_OUTPUT
   void PrintTestData(std::ostream& fout)const;
#endif

   /// Sets the smoothing value and immediately calls _InitMove()
   /// (presumably so the sub-solvers pick up the new value).
   void SetSmoothing(ValueType smoothingValue){_smoothingValue=smoothingValue;_InitMove();}
   ValueType GetSmoothing()const{return _smoothingValue;}




   std::pair<ValueType,ValueType> GetMarginals(IndexType variable, OutputIteratorType begin);
   ValueType GetMarginalsAndDerivativeMove();
   /// Derivative reported by sub-solver i (SumProdSolver::getDerivative()); no bounds check.
   ValueType getDerivative(size_t i)const{return parent::_subSolvers[i]->getDerivative();}

protected:
   void _SumUpForwardMarginals(std::vector<ValueType>* pout,const_marginals_iterators_pair itpair);
   void _postprocessMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end);
   void _normalizeMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end,SubSolver* subSolver);
   void _InitMove();


   ValueType _smoothingValue;
};
00374
00375 template<class ValueType>
00376 struct MaxSumTRWS_Parameters : public TRWSPrototype_Parameters<ValueType>
00377 {
00378 typedef TRWSPrototype_Parameters<ValueType> parent;
00379 MaxSumTRWS_Parameters(size_t maxIternum,
00380 ValueType precision=1.0,
00381 bool absolutePrecision=true,
00382 ValueType minRelativeDualImprovement=-1.0,
00383 bool fastComputations=true,
00384 bool canonicalNormalization=false):
00385 parent(maxIternum,precision,absolutePrecision,minRelativeDualImprovement,fastComputations),
00386 canonicalNormalization_(canonicalNormalization){};
00387
00388 bool canonicalNormalization_;
00389 };
00390
/// TRW-S variant built on MaxSumSolver sub-solvers. Implements the marginal hooks of
/// TRWSPrototype and overrides _EstimateTRWSBound() and _CheckStoppingCondition().
/// It also exposes a tree-agreement query; its definition is not shown here.
template<class GM,class ACC>
class MaxSumTRWS : public TRWSPrototype<MaxSumSolver<GM,ACC,typename std::vector<typename GM::ValueType>::const_iterator> >
{
public:
   typedef TRWSPrototype<MaxSumSolver<GM,ACC,typename std::vector<typename GM::ValueType>::const_iterator> > parent;

   typedef typename parent::SubSolverType SubSolver;
   typedef typename parent::const_marginals_iterators_pair const_marginals_iterators_pair;
   typedef typename parent::ValueType ValueType;
   typedef typename parent::IndexType IndexType;
   typedef typename parent::LabelType LabelType;
   typedef typename parent::InferenceTermination InferenceTermination;
   typedef typename parent::EmptyVisitorType EmptyVisitorType;
   typedef typename parent::UnaryFactor UnaryFactor;
   typedef ACC AccumulationType;
   typedef GM GraphicalModelType;
   typedef typename parent::OutputContainerType OutputContainerType;

   typedef SequenceStorage<GM> SubModel;
   typedef DecompositionStorage<GM> Storage;

   typedef MaxSumTRWS_Parameters<ValueType> Parameters;

   /// storage is held by reference by the base class and must outlive this object.
   /// _pseudoBoundValue and _localConsistencyCounter start at 0.
   MaxSumTRWS(Storage& storage,const Parameters& params
#ifdef TRWS_DEBUG_OUTPUT
              ,std::ostream& fout=std::cout
#endif
              ):
   parent(storage,params
#ifdef TRWS_DEBUG_OUTPUT
          ,fout
#endif
          ),
   _canonicalNormalization(params.canonicalNormalization_),
   _pseudoBoundValue(0.0),
   _localConsistencyCounter(0)
   {}
   ~MaxSumTRWS(){};

   // NOTE(review): semantics defined outside this view -- presumably fills `out` with a
   // per-node agreement flag and optionally a labeling via plabeling; confirm at the definition.
   void getTreeAgreement(std::vector<bool>& out,std::vector<LabelType>* plabeling=0);
   bool CheckTreeAgreement(InferenceTermination* pterminationCode);
protected:
   void _SumUpForwardMarginals(std::vector<ValueType>* pout,const_marginals_iterators_pair itpair);
   void _postprocessMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end);
   void _normalizeMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end,SubSolver* subSolver);
   void _InitMove();
   void _EstimateTRWSBound();
   bool _CheckStoppingCondition(InferenceTermination* pterminationCode);

   bool _canonicalNormalization;
   ValueType _pseudoBoundValue;
   size_t _localConsistencyCounter;



   // Working buffers; their use is in definitions outside this view.
   std::vector<bool> _treeAgree;
   std::vector<bool> _mask;
   std::vector<bool> _nodeMask;
};
00450
00451
00452 template <class SubSolver>
00453 TRWSPrototype<SubSolver>::TRWSPrototype(Storage& storage,const Parameters& params
00454 #ifdef TRWS_DEBUG_OUTPUT
00455 ,std::ostream& fout
00456 #endif
00457 ):
00458 _storage(storage),
00459 _factorProperties(storage.masterModel()),
00460 _ftable(storage.masterModel()),
00461 _parameters(params),
00462 #ifdef TRWS_DEBUG_OUTPUT
00463 _fout(fout),
00464 #endif
00465 _dualBound(ACC::template ineutral<ValueType>()),
00466 _oldDualBound(ACC::template ineutral<ValueType>()),
00467 _lastDualUpdate(0),
00468 _moveDirection(SubModel::Direct),
00469 _subSolvers(),
00470 _marginals(),
00471 _integerBound(ACC::template neutral<ValueType>()),
00472 _bestIntegerBound(ACC::template neutral<ValueType>()),
00473 _integerLabeling(storage.masterModel().numberOfVariables(),0),
00474 _bestIntegerLabeling(storage.masterModel().numberOfVariables(),0),
00475 _sumMarginal()
00476 {
00477 #ifdef TRWS_DEBUG_OUTPUT
00478 _fout.precision(16);
00479 #endif
00480 _InitSubSolvers();
00481 _marginals.resize(_storage.numberOfModels());
00482 #ifdef TRWS_DEBUG_OUTPUT
00483 _factorProperties.PrintStatusData(fout);
00484 #endif
00485 }
00486
00487 template <class SubSolver>
00488 TRWSPrototype<SubSolver>::~TRWSPrototype()
00489 {
00490 for_each(_subSolvers.begin(),_subSolvers.end(),DeallocatePointer<SubSolver>);
00491 };
00492
00493 template <class SubSolver>
00494 void TRWSPrototype<SubSolver>::_InitSubSolvers()
00495 {
00496 _subSolvers.resize(_storage.numberOfModels());
00497 for (size_t modelId=0;modelId<_subSolvers.size();++modelId)
00498 _subSolvers[modelId]= new SubSolver(_storage.subModel(modelId),_factorProperties,_parameters.fastComputations_);
00499 }
00500
00501 template <class SubSolver>
00502 bool TRWSPrototype<SubSolver>::CheckDualityGap(ValueType primalBound,ValueType dualBound)
00503 {
00504
00505
00506 if (_parameters.absolutePrecision_)
00507 {
00508 if (fabs(primalBound-dualBound) <= _parameters.precision_)
00509 {
00510 return true;
00511 }
00512 }
00513 else
00514 {
00515 if (fabs((primalBound-dualBound)/dualBound)<= _parameters.precision_ )
00516 return true;
00517 }
00518 return false;
00519 }
00520
00521 template <class SubSolver>
00522 bool TRWSPrototype<SubSolver>::_CheckConvergence(ValueType relativeThreshold)
00523 {
00524 if (relativeThreshold >=0.0)
00525 {
00526 ValueType mul; ACC::iop(-1.0,1.0,mul);
00527 if (ACC::bop(_dualBound, (_oldDualBound + _dualBound*mul*relativeThreshold)))
00528 return true;
00529 }
00530 return false;
00531 }
00532
00533 template <class SubSolver>
00534 bool TRWSPrototype<SubSolver>::_CheckStoppingCondition(InferenceTermination* pterminationCode)
00535 {
00536 _lastDualUpdate=fabs(_dualBound-_oldDualBound);
00537
00538 if (CheckDualityGap(_bestIntegerBound,_dualBound))
00539 {
00540 #ifdef TRWS_DEBUG_OUTPUT
00541 _fout << "TRWSPrototype::_CheckStoppingCondition(): duality gap <= specified precision!" <<std::endl;
00542 #endif
00543 *pterminationCode=opengm::CONVERGENCE;
00544 return true;
00545 }
00546
00547 if (_CheckConvergence(_parameters.minRelativeDualImprovement_))
00548 {
00549 #ifdef TRWS_DEBUG_OUTPUT
00550 _fout << "TRWSPrototype::_CheckStoppingCondition(): Dual update is smaller than the specified threshold. Stopping"<<std::endl;
00551 #endif
00552 *pterminationCode=opengm::NORMAL;
00553 return true;
00554 }
00555
00556 _oldDualBound=_dualBound;
00557
00558 return false;
00559 }
00560
00561 template <class SubSolver>
00562 template <class VISITOR>
00563 typename TRWSPrototype<SubSolver>::InferenceTermination TRWSPrototype<SubSolver>::_core_infer(VISITOR& visitor)
00564 {
00565 for (size_t iterationCounter=0;iterationCounter<_parameters.maxNumberOfIterations_;++iterationCounter)
00566 {
00567 #ifdef TRWS_DEBUG_OUTPUT
00568 _fout <<"Iteration Nr."<<iterationCounter<<"-------------------------------------"<<std::endl;
00569 #endif
00570
00571 BackwardMove();
00572
00573 #ifdef TRWS_DEBUG_OUTPUT
00574 _fout << "dualBound=" << _dualBound <<", primalBound="<<_GetPrimalBound() <<std::endl;
00575 #endif
00576 _EstimateTRWSBound();
00577 visitor(value(),bound());
00578 InferenceTermination returncode;
00579 if (_CheckStoppingCondition(&returncode))
00580 return returncode;
00581 }
00582 return opengm::TIMEOUT;
00583 }
00584
00585 template <class SubSolver>
00586 typename TRWSPrototype<SubSolver>::ValueType TRWSPrototype<SubSolver>::_GetObjectiveValue()
00587 {
00588 ValueType dualBound=0;
00589 for (size_t i=0;i<_subSolvers.size();++i)
00590 dualBound+=_subSolvers[i]->GetObjectiveValue();
00591
00592 return dualBound;
00593 }
00594
00595 template <class SubSolver>
00596 void TRWSPrototype<SubSolver>::_ForwardMove()
00597 {
00598 std::for_each(_subSolvers.begin(), _subSolvers.end(), std::mem_fun(&SubSolver::Move));
00599 _moveDirection=SubModel::ReverseDirection(_moveDirection);
00600 _dualBound=_GetObjectiveValue();
00601 }
00602
00603 template <class SubSolver>
00604 void TRWSPrototype<SubSolver>::GetMarginalsMove()
00605 {
00606 std::for_each(_subSolvers.begin(), _subSolvers.end(), std::mem_fun(&SubSolver::MoveBack));
00607 _moveDirection=SubModel::ReverseDirection(_moveDirection);
00608 }
00609
00610 template <class SubSolver>
00611 typename TRWSPrototype<SubSolver>::IndexType TRWSPrototype<SubSolver>::_core_order(IndexType i,IndexType totalSize)
00612 {
00613 return (_moveDirection==SubModel::Direct ? i : totalSize-i-1);
00614 }
00615
00616 template <class SubSolver>
00617 typename TRWSPrototype<SubSolver>::IndexType TRWSPrototype<SubSolver>::_order(IndexType i)
00618 {
00619 return _core_order(i,_storage.numberOfSharedVariables());
00620 }
00621
00622 template <class SubSolver>
00623 void TRWSPrototype<SubSolver>::_FinalizeMove()
00624 {
00625 std::for_each(_subSolvers.begin(), _subSolvers.end(), std::mem_fun(&SubSolver::FinalizeMove));
00626 _moveDirection=SubModel::ReverseDirection(_moveDirection);
00627 _EstimateIntegerLabeling();
00628 }
00629
#ifdef TRWS_DEBUG_OUTPUT
/// Debug dump of the solver state (bounds, move direction, labelings) to fout.
template <class SubSolver>
void TRWSPrototype<SubSolver>::PrintTestData(std::ostream& fout)const
{
   fout << "_dualBound:" << _dualBound <<std::endl
        << "_oldDualBound:" << _oldDualBound <<std::endl
        << "_lastDualUpdate=" << _lastDualUpdate << std::endl
        << "_moveDirection:" << _moveDirection <<std::endl
        << "_integerBound=" << _integerBound << std::endl
        << "_bestIntegerBound=" << _bestIntegerBound << std::endl;
   // The labelings use an operator<< for std::vector that is presumably declared in
   // another TRW-S header (the standard library provides none).
   fout << "_integerLabeling=" << _integerLabeling;
   fout << "_bestIntegerLabeling=" << _bestIntegerLabeling;
}
#endif
00644
00645
00646
00647
00648
00649
00650
00651
00652
00653
00654
00655
00656
00657
00658
00659
00660
00661
00662
00663
00664
00665
00666
00667
00668
00669
00670
00671
00672
00673
00674
00675
00676
00677
00678
00679 template <class SubSolver>
00680 template <class VISITOR>
00681 typename TRWSPrototype<SubSolver>::InferenceTermination TRWSPrototype<SubSolver>::infer(VISITOR& visitor)
00682 {
00683 visitor.begin(value(),bound());
00684 InferenceTermination returncode=infer_visitor_updates(visitor);
00685 visitor.end(value(), bound());
00686 return returncode;
00687 }
00688
00689 template <class SubSolver>
00690 template <class VISITOR>
00691 typename TRWSPrototype<SubSolver>::InferenceTermination TRWSPrototype<SubSolver>::infer_visitor_updates(VISITOR& visitor)
00692 {
00693 _InitMove();
00694 _ForwardMove();
00695 visitor(value(),bound());
00696 _oldDualBound=_dualBound;
00697 #ifdef TRWS_DEBUG_OUTPUT
00698 _fout << "ForwardMove: dualBound=" << _dualBound <<std::endl;
00699 #endif
00700 InferenceTermination returncode;
00701 returncode=_core_infer(visitor);
00702 return returncode;
00703 }
00704
00705 template <class SubSolver>
00706 void TRWSPrototype<SubSolver>::ForwardMove()
00707 {
00708 _InitMove();
00709 _ForwardMove();
00710 _dualBound=_GetObjectiveValue();
00711 }
00712
00713
00714 template <class SubSolver>
00715 void TRWSPrototype<SubSolver>::BackwardMove()
00716 {
00717 std::vector<ValueType> averageMarginal;
00718
00719 for (IndexType i=0;i<_storage.numberOfSharedVariables();++i)
00720 {
00721 IndexType varId=_order(i);
00722 const typename Storage::SubVariableListType& varList=_storage.getSubVariableList(varId);
00723 averageMarginal.assign(_storage.numberOfLabels(varId),0.0);
00724
00725
00726 for(typename Storage::SubVariableListType::const_iterator modelIt=varList.begin();modelIt!=varList.end();++modelIt)
00727 {
00728 SubSolver& subSolver=*_subSolvers[modelIt->subModelId_];
00729 std::vector<ValueType>& marginals=_marginals[modelIt->subModelId_];
00730 marginals.resize(_storage.numberOfLabels(varId));
00731
00732 IndexType startNodeIndex=_core_order(0,_storage.size(modelIt->subModelId_));
00733
00734 if (modelIt->subVariableId_!=startNodeIndex)
00735 subSolver.PushBack();
00736
00737 typename SubSolver::const_iterators_pair marginalsit=subSolver.GetMarginals();
00738
00739 std::copy(marginalsit.first,marginalsit.second,marginals.begin());
00740 _normalizeMarginals(marginals.begin(),marginals.end(),&subSolver);
00741 std::transform(marginals.begin(),marginals.end(),averageMarginal.begin(),averageMarginal.begin(),std::plus<ValueType>());
00742 }
00743 transform_inplace(averageMarginal.begin(),averageMarginal.end(),std::bind1st(std::multiplies<ValueType>(),-1.0/varList.size()));
00744
00745
00746
00747
00748 for(typename Storage::SubVariableListType::const_iterator modelIt=varList.begin();modelIt!=varList.end();++modelIt)
00749 {
00750 SubSolver& subSolver=*_subSolvers[modelIt->subModelId_];
00751 std::vector<ValueType>& marginals=_marginals[modelIt->subModelId_];
00752
00753 std::transform(marginals.begin(),marginals.end(),averageMarginal.begin(),marginals.begin(),std::plus<ValueType>());
00754
00755 _postprocessMarginals(marginals.begin(),marginals.end());
00756
00757 subSolver.IncreaseUnaryWeights(marginals.begin(),marginals.end());
00758
00759 IndexType startNodeIndex=_core_order(0,_storage.size(modelIt->subModelId_));
00760
00761 if (modelIt->subVariableId_!=startNodeIndex)
00762 subSolver.UpdateMarginals();
00763 else subSolver.InitReverseMove();
00764 }
00765 }
00766
00767 _FinalizeMove();
00768 _EvaluateIntegerBounds();
00769 _dualBound=_GetObjectiveValue();
00770 }
00771
/// Re-estimates _integerLabeling greedily, one variable at a time, in the order given
/// by _order(). _FinalizeMove() calls this right after reversing _moveDirection.
///
/// For each variable v:
///  - the marginals of all copies of v are accumulated into _sumMarginal by the
///    derived-class hook _SumUpForwardMarginals;
///  - for every pairwise factor returned by _ftable for the current direction, the
///    factor's contribution is added with the neighbour fixed to its current label in
///    _integerLabeling;
///  - _EstimateIntegerLabel picks v's label from _sumMarginal.
/// In the Direct direction _ftable yields neighbours with a smaller index and _order()
/// is ascending. Otherwise it yields larger-index neighbours and _order() is
/// descending. So those neighbours have already been relabelled earlier in this sweep
/// (assuming shared-variable ids coincide with master variable ids).
template <class SubSolver>
void TRWSPrototype<SubSolver>::_EstimateIntegerLabeling()
{
   for (IndexType i=0;i<_storage.numberOfSharedVariables();++i)
   {
      IndexType varId=_order(i);

      const typename Storage::SubVariableListType& varList=_storage.getSubVariableList(varId);
      _sumMarginal.assign(_storage.masterModel().numberOfLabels(varId),0.0);
      for(typename Storage::SubVariableListType::const_iterator modelIt=varList.begin();modelIt!=varList.end();++modelIt)
      {
         const_marginals_iterators_pair itpair=_subSolvers[modelIt->subModelId_]->GetMarginals(modelIt->subVariableId_);
         _SumUpForwardMarginals(&_sumMarginal,itpair);
      }

      // Pairwise factors linking varId to already-labelled neighbours.
      typename PreviousFactorTable<GM>::const_iterator begin=_ftable.begin(varId,_moveDirection);
      typename PreviousFactorTable<GM>::const_iterator end=_ftable.end(varId,_moveDirection);
      for (;begin!=end;++begin)
      {
         if ((_factorProperties.getFunctionType(begin->factorId)==FunctionParameters<GM>::POTTS) && _parameters.fastComputations_)
         {
            // Potts shortcut: only the entry of the neighbour's label is adjusted, by -param[0].
            // NOTE(review): presumably param[0] encodes the Potts penalty so that this equals
            // adding the full factor row up to a constant shift -- confirm against FunctionParameters.
            _sumMarginal[_integerLabeling[begin->varId]]-=_factorProperties.getFunctionParameters(begin->factorId)[0];
         }else
         {
            // General case: view of the pairwise factor with the neighbour (at position
            // localId inside the factor) fixed to its current label, added entry-wise.
            const typename GM::FactorType& pwfactor=_storage.masterModel()[begin->factorId];
            IndexType localVarIndx = begin->localId;
            LabelType fixedLabel=_integerLabeling[begin->varId];

            opengm::ViewFixVariablesFunction<GM> pencil(pwfactor,
                                                        std::vector<opengm::PositionAndLabel<IndexType,LabelType> >(1,
                                                        opengm::PositionAndLabel<IndexType,LabelType>(localVarIndx,
                                                        fixedLabel)));

            for (LabelType j=0;j<_sumMarginal.size();++j)
               _sumMarginal[j]+=pencil(&j);
         }
      }
      _EstimateIntegerLabel(varId,_sumMarginal);
   }
}
00812
00813 template <class SubSolver>
00814 void TRWSPrototype<SubSolver>::_EvaluateIntegerBounds()
00815 {
00816 _integerBound=_storage.masterModel().evaluate(_integerLabeling.begin());
00817
00818 if (ACC::bop(_integerBound,_bestIntegerBound))
00819 {
00820 _bestIntegerLabeling=_integerLabeling;
00821 _bestIntegerBound=_integerBound;
00822 }
00823
00824 }
00825
00826
00827 template<class GM>
00828 DecompositionStorage<GM>::DecompositionStorage(const GM& gm,StructureType structureType):
00829 _gm(gm),
00830 _structureType(structureType),
00831 _subModels(),
00832 _variableDecomposition(),
00833 _var2FactorMap(gm)
00834 {
00835 _InitSubModels();
00836 }
00837
00838 template<class GM>
00839 DecompositionStorage<GM>::~DecompositionStorage()
00840 {
00841 for_each(_subModels.begin(),_subModels.end(),DeallocatePointer<SubModel>);
00842 }
00843
00844 template<class GM>
00845 void DecompositionStorage<GM>::_InitSubModels()
00846 {
00847 std::auto_ptr<Decomposition<GM> > pdecomposition;
00848
00849 if (_structureType==GRIDSTRUCTURE)
00850 pdecomposition=std::auto_ptr<Decomposition<GM> >(new GridDecomposition<GM>(_gm));
00851 else
00852 pdecomposition=std::auto_ptr<Decomposition<GM> >(new MonotoneChainsDecomposition<GM>(_gm));
00853
00854 try{
00855 pdecomposition->ComputeVariableDecomposition(&_variableDecomposition);
00856 size_t numberOfModels=pdecomposition->getNumberOfSubModels();
00857 _subModels.resize(numberOfModels);
00858 for (size_t modelId=0;modelId<numberOfModels;++modelId)
00859 {
00860 const typename SubModel::IndexList& varList=pdecomposition->getVariableList(modelId);
00861 typename SubModel::IndexList numOfSubModelsPerVar(varList.size());
00862
00863 for (size_t varIndx=0;varIndx<varList.size();++varIndx)
00864 numOfSubModelsPerVar[varIndx]=_variableDecomposition[varList[varIndx]].size();
00865
00866 _subModels[modelId]= new SubModel(_gm,_var2FactorMap,varList,pdecomposition->getFactorList(modelId),numOfSubModelsPerVar);
00867 };
00868 }catch(std::runtime_error& err)
00869 {
00870 throw err;
00871 }
00872 };
00873
00874 #ifdef TRWS_DEBUG_OUTPUT
00875 template<class GM>
00876 void DecompositionStorage<GM>::PrintTestData(std::ostream& fout)const
00877 {
00878 fout << "_variableDecomposition: "<<std::endl;
00879 for (size_t variableId=0;variableId<_variableDecomposition.size();++variableId)
00880 {
00881 std::for_each(_variableDecomposition[variableId].begin(),_variableDecomposition[variableId].end(),printSubVariable<typename MonotoneChainsDecomposition<GM>::SubVariable>(fout));
00882 fout << std::endl;
00883 }
00884 }
00885
00886 template<class GM>
00887 void DecompositionStorage<GM>::PrintVariableDecompositionConsistency(std::ostream& fout)const
00888 {
00889 fout << "Variable decomposition consistency:" <<std::endl;
00890 for (size_t varId=0;varId<_gm.numberOfVariables();++varId)
00891 {
00892 fout << varId<<": ";
00893 const SubVariableListType& varList=_variableDecomposition[varId];
00894 typename SubVariableListType::const_iterator modelIt=varList.begin();
00895 std::vector<ValueType> sum(_gm.numberOfLabels(varId),0.0);
00896 while (modelIt!=varList.end())
00897 {
00898 const SubModel& subModel=*_subModels[modelIt->subModelId_];
00899 std::transform(subModel.unaryFactors(modelIt->subVariableId_).begin(),subModel.unaryFactors(modelIt->subVariableId_).end(),
00900 sum.begin(),sum.begin(),std::plus<ValueType>());
00901 ++modelIt;
00902 }
00903 std::vector<ValueType> originalFactor(_gm.numberOfLabels(varId),0.0);
00904 _gm[varId].copyValues(originalFactor.begin());
00905
00906 std::transform(sum.begin(),sum.end(),originalFactor.begin(),sum.begin(),std::minus<ValueType>());
00907 fout << std::accumulate(sum.begin(),sum.end(),(ValueType)0.0)<<std::endl;
00908 }
00909
00910 }
00911 #endif
00912
00913
00914 template<class GM,class ACC>
00915 void MaxSumTRWS<GM,ACC>::_InitMove()
00916 {
00917 parent::_moveDirection=SubModel::Direct;
00918 std::for_each(parent::_subSolvers.begin(), parent::_subSolvers.end(), std::mem_fun_t<void,SubSolver>(&SubSolver::InitMove));
00919 }
00920
00921 template<class GM,class ACC>
00922 void MaxSumTRWS<GM,ACC>::_postprocessMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end)
00923 {
00924 transform_inplace(begin,end,std::bind1st(std::multiplies<ValueType>(),-1.0));
00925 }
00926
00927 template<class GM,class ACC>
00928 void MaxSumTRWS<GM,ACC>::_SumUpForwardMarginals(std::vector<ValueType>* pout,const_marginals_iterators_pair itpair)
00929 {
00930 std::transform(itpair.first,itpair.second,pout->begin(),pout->begin(),std::plus<ValueType>());
00931 }
00932
00933 template<class GM,class ACC>
00934 void MaxSumTRWS<GM,ACC>::_EstimateTRWSBound()
00935 {
00936 if (_canonicalNormalization) return;
00937 std::vector<ValueType> bounds(parent::_storage.numberOfModels());
00938 for (size_t i=0;i<bounds.size();++i)
00939 bounds[i]=parent::_subSolvers[i]->GetObjectiveValue();
00940
00941 ValueType min=*std::min_element(bounds.begin(),bounds.end());
00942 ValueType max=*std::max_element(bounds.begin(),bounds.end());
00943 ValueType eps; ACC::iop(max-min,min-max,eps);
00944 ACC::iop(min,max,_pseudoBoundValue);
00945 #ifdef TRWS_DEBUG_OUTPUT
00946 parent::_fout <<"min="<<min<<", max="<<max<<", eps="<<eps<<", pseudo bound="<<bounds.size()*_pseudoBoundValue<<std::endl;
00947 #endif
00948 }
00949
00950
00951 template<class GM,class ACC>
00952 void MaxSumTRWS<GM,ACC>::_normalizeMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end,SubSolver* subSolver)
00953 {
00954 if (!_canonicalNormalization) return;
00955 ValueType maxVal=*std::max_element(begin,end,ACC::template bop<ValueType>);
00956 transform_inplace(begin,end,std::bind2nd(std::plus<ValueType>(),-maxVal));
00957 }
00958
00959 template<class GM,class ACC>
00960 void MaxSumTRWS<GM,ACC>::getTreeAgreement(std::vector<bool>& out,std::vector<LabelType>* plabeling)
00961 {
00962 if (plabeling!=0)
00963 plabeling->resize(parent::_storage.masterModel().numberOfVariables());
00964
00965 out.assign(parent::_storage.masterModel().numberOfVariables(),true);
00966 for (size_t varId=0;varId<parent::_storage.masterModel().numberOfVariables();++varId)
00967 {
00968 const typename Storage::SubVariableListType& varList=parent::_storage.getSubVariableList(varId);
00969 size_t label;
00970 for(typename Storage::SubVariableListType::const_iterator modelIt=varList.begin()
00971 ;modelIt!=varList.end();++modelIt)
00972 {
00973 size_t check_label=parent::_subSolvers[modelIt->subModelId_]->arg()[modelIt->subVariableId_];
00974
00975 if (plabeling!=0) (*plabeling)[varId]=check_label;
00976
00977 if (modelIt==varList.begin())
00978 {
00979 label=check_label;
00980 }else if (check_label!=label)
00981 {
00982 out[varId]=false;
00983 break;
00984 }
00985 }
00986
00987 }
00988 }
00989
00990
00991
00992 template<class GM,class ACC>
00993 bool MaxSumTRWS<GM,ACC>::CheckTreeAgreement(InferenceTermination* pterminationCode)
00994 {
00995 getTreeAgreement(_treeAgree);
00996 size_t agree_count=count(_treeAgree.begin(),_treeAgree.end(),true);
00997 #ifdef TRWS_DEBUG_OUTPUT
00998 parent::_fout << "tree-agreement: " << agree_count <<" out of "<<_treeAgree.size() <<", ="<<100*(double)agree_count/_treeAgree.size()<<"%"<<std::endl;
00999 #endif
01000
01001 if (_treeAgree.size()==agree_count)
01002 {
01003 #ifdef TRWS_DEBUG_OUTPUT
01004 parent::_fout <<"Problem solved."<<std::endl;
01005 #endif
01006 *pterminationCode=opengm::CONVERGENCE;
01007 return true;
01008 }else
01009 return false;
01010 }
01011
01012
01013 template<class GM,class ACC>
01014 bool MaxSumTRWS<GM,ACC>::_CheckStoppingCondition(InferenceTermination* pterminationCode)
01015 {
01016 if (CheckTreeAgreement(pterminationCode)) return true;
01017
01018 return parent::_CheckStoppingCondition(pterminationCode);
01019 }
01020
01021
01022 #ifdef TRWS_DEBUG_OUTPUT
01023 template<class GM,class ACC>
01024 void SumProdTRWS<GM,ACC>::PrintTestData(std::ostream& fout)const
01025 {
01026 fout << "_smoothingValue:"<<_smoothingValue <<std::endl;
01027 parent::PrintTestData(fout);
01028 }
01029 #endif
01030
01031 template<class GM,class ACC>
01032 void SumProdTRWS<GM,ACC>::_InitMove()
01033 {
01034 parent::_moveDirection=SubModel::Direct;
01035 std::for_each(parent::_subSolvers.begin(), parent::_subSolvers.end(), std::bind2nd(std::mem_fun(&SubSolver::InitMove),_smoothingValue));
01036 }
01037
01038 template<class GM,class ACC>
01039 void SumProdTRWS<GM,ACC>::_normalizeMarginals(typename std::vector<ValueType>::iterator begin,
01040 typename std::vector<ValueType>::iterator end,SubSolver* subSolver)
01041 {
01042 ValueType logPartition=subSolver->ComputeObjectiveValue();
01043
01044 transform_inplace(begin,end,std::bind2nd(std::plus<ValueType>(),-logPartition/_smoothingValue));
01045 }
01046
01047 template<class GM,class ACC>
01048 void SumProdTRWS<GM,ACC>::_postprocessMarginals(typename std::vector<ValueType>::iterator begin,typename std::vector<ValueType>::iterator end)
01049 {
01050 transform_inplace(begin,end,std::bind1st(std::multiplies<ValueType>(),-_smoothingValue));
01051 }
01052
01053 template<class GM,class ACC>
01054 void SumProdTRWS<GM,ACC>::_SumUpForwardMarginals(std::vector<ValueType>* pout,const_marginals_iterators_pair itpair)
01055 {
01056 std::transform(pout->begin(),pout->end(),itpair.first,pout->begin(),plus2ndMul<ValueType>(_smoothingValue));
01057 }
01058
01059 template<class GM,class ACC>
01060 std::pair<typename SumProdTRWS<GM,ACC>::ValueType,typename SumProdTRWS<GM,ACC>::ValueType>
01061 SumProdTRWS<GM,ACC>::GetMarginals(IndexType varId, OutputIteratorType begin)
01062 {
01063 std::fill_n(begin,parent::_storage.numberOfLabels(varId),0.0);
01064 const typename Storage::SubVariableListType& varList=parent::_storage.getSubVariableList(varId);
01065
01066 OPENGM_ASSERT(varList.size()>0);
01067
01068 for(typename Storage::SubVariableListType::const_iterator modelIt=varList.begin();modelIt!=varList.end();++modelIt)
01069 {
01070 typename SubSolver::const_iterators_pair marginalsit=parent::_subSolvers[modelIt->subModelId_]->GetMarginals(modelIt->subVariableId_);
01071 std::vector<ValueType>& normMarginals=parent::_marginals[modelIt->subModelId_];
01072 normMarginals.resize(parent::_storage.numberOfLabels(varId));
01073
01074 std::copy(marginalsit.first,marginalsit.second,normMarginals.begin());
01075 _normalizeMarginals(normMarginals.begin(),normMarginals.end(),parent::_subSolvers[modelIt->subModelId_]);
01076 ValueType mul; ACC::op(1.0,-1.0,mul);
01077 transform_inplace(normMarginals.begin(),normMarginals.end(),mulAndExp<ValueType>(mul));
01078 std::transform(normMarginals.begin(),normMarginals.end(),begin,begin,std::plus<ValueType>());
01079 }
01080 transform_inplace(begin,begin+parent::_storage.numberOfLabels(varId),std::bind1st(std::multiplies<ValueType>(),1.0/varList.size()));
01081
01082 ValueType ell2Norm=0, ellInftyNorm=0;
01083 for (typename Storage::SubVariableListType::const_iterator modelIt=varList.begin();modelIt!=varList.end();++modelIt)
01084 {
01085 std::vector<ValueType>& normMarginals=parent::_marginals[modelIt->subModelId_];
01086 OutputIteratorType begin0=begin;
01087 for (typename std::vector<ValueType>::const_iterator bm=normMarginals.begin(); bm!=normMarginals.end();++bm)
01088 {
01089 ValueType diff=(*bm-*begin0); ++begin0;
01090 ell2Norm+=diff*diff;
01091 ellInftyNorm=std::max((ValueType)fabs(diff),ellInftyNorm);
01092 }
01093 }
01094
01095 return std::make_pair(sqrt(ell2Norm),ellInftyNorm);
01096 }
01097
01098 template<class GM,class ACC>
01099 typename SumProdTRWS<GM,ACC>::ValueType
01100 SumProdTRWS<GM,ACC>::GetMarginalsAndDerivativeMove()
01101 {
01102 ValueType derivativeValue=0.0;
01103
01104 for (size_t i=0;i<parent::_subSolvers.size();++i)
01105 derivativeValue+=parent::_subSolvers[i]->MoveBackGetDerivative();
01106
01107 parent::_moveDirection=SubModel::ReverseDirection(parent::_moveDirection);
01108 return derivativeValue;
01109 }
01110
01111 };
01112
01113 #endif