00001 #ifndef TRWS_ADSAL_HXX_
00002 #define TRWS_ADSAL_HXX_
00003 #include <opengm/inference/inference.hxx>
00004 #include <opengm/inference/trws/trws_base.hxx>
00005 #include <opengm/inference/auxiliary/primal_lpbound.hxx>
00006
00007 namespace opengm{
00008
00009 template<class ValueType,class GM>
00010 struct ADSal_Parameter : public trws_base::SumProdTRWS_Parameters<ValueType>, public PrimalLPBound_Parameter<ValueType>
00011 {
00012 typedef trws_base::DecompositionStorage<GM> Storage;
00013 ADSal_Parameter(size_t numOfExternalIterations=0,
00014 ValueType precision=1.0,
00015 bool absolutePrecision=true,
00016 size_t numOfInternalIterations=3,
00017 typename Storage::StructureType decompositionType=Storage::GENERALSTRUCTURE,
00018 ValueType smoothingGapRatio=4,
00019 ValueType startSmoothingValue=0.0,
00020 ValueType primalBoundPrecision=std::numeric_limits<ValueType>::epsilon(),
00021 size_t maxPrimalBoundIterationNumber=100,
00022 size_t presolveMaxIterNumber=100,
00023 bool canonicalNormalization=true,
00024 ValueType presolveMinRelativeDualImprovement=0.01,
00025 bool lazyLPPrimalBoundComputation=true,
00026 bool lazyDerivativeComputation=true,
00027 ValueType smoothingDecayMultiplier=-1.0,
00028 bool worstCaseSmoothing=false,
00029 bool verbose=false
00030 )
00031 :trws_base::SumProdTRWS_Parameters<ValueType>(numOfInternalIterations,startSmoothingValue,precision,absolutePrecision),
00032 PrimalLPBound_Parameter<ValueType>(primalBoundPrecision,maxPrimalBoundIterationNumber),
00033 numOfExternalIterations_(numOfExternalIterations),
00034 presolveMaxIterNumber_(presolveMaxIterNumber),
00035 decompositionType_(decompositionType),
00036 smoothingGapRatio_(smoothingGapRatio),
00037 presolveMinRelativeDualImprovement_(presolveMinRelativeDualImprovement),
00038 canonicalNormalization_(canonicalNormalization),
00039 lazyLPPrimalBoundComputation_(lazyLPPrimalBoundComputation),
00040 lazyDerivativeComputation_(lazyDerivativeComputation),
00041 smoothingDecayMultiplier_(smoothingDecayMultiplier),
00042 worstCaseSmoothing_(worstCaseSmoothing),
00043 verbose_(verbose)
00044 {};
00045
00046 size_t numOfExternalIterations_;
00047 size_t presolveMaxIterNumber_;
00048 typename Storage::StructureType decompositionType_;
00049 ValueType smoothingGapRatio_;
00050 ValueType presolveMinRelativeDualImprovement_;
00051 bool canonicalNormalization_;
00052 bool lazyLPPrimalBoundComputation_;
00053 bool lazyDerivativeComputation_;
00054 ValueType smoothingDecayMultiplier_;
00055 bool worstCaseSmoothing_;
00056 bool verbose_;
00057
00058
00059
00060
00061 size_t& maxNumberOfIterations(){return numOfExternalIterations_;}
00062 size_t& numberOfInternalIterations(){return trws_base::SumProdTRWS_Parameters<ValueType>::maxNumberOfIterations_;}
00063 ValueType& precision(){return trws_base::SumProdTRWS_Parameters<ValueType>::precision_;}
00064 bool& isAbsolutePrecision(){return trws_base::SumProdTRWS_Parameters<ValueType>::absolutePrecision_;}
00065 ValueType& smoothingGapRatio(){return smoothingGapRatio_;}
00066 bool& lazyLPPrimalBoundComputation(){return lazyLPPrimalBoundComputation_;}
00067 bool& lazyDerivativeComputation(){return lazyDerivativeComputation_;}
00068 ValueType& smoothingDecayMultiplier(){return smoothingDecayMultiplier_;}
00069 bool& worstCaseSmoothing(){return worstCaseSmoothing_;}
00070
00071 typename Storage::StructureType& decompositionType(){return decompositionType_;}
00072 ValueType& startSmoothingValue(){return trws_base::SumProdTRWS_Parameters<ValueType>::smoothingValue_;}
00073 bool& fastComputations(){return trws_base::SumProdTRWS_Parameters<ValueType>::fastComputations_;}
00074 bool& canonicalNormalization(){return canonicalNormalization_;}
00075
00076
00077
00078
00079 size_t& maxNumberOfPresolveIterations(){return presolveMaxIterNumber_;}
00080 ValueType& presolveMinRelativeDualImprovement() {return presolveMinRelativeDualImprovement_;}
00081
00082
00083
00084
00085 size_t& maxPrimalBoundIterationNumber(){return PrimalLPBound_Parameter<ValueType>::maxIterationNumber_;}
00086 ValueType& primalBoundRelativePrecision(){return PrimalLPBound_Parameter<ValueType>::relativePrecision_;}
00087
00088 bool& verbose(){return verbose_;}
00089
00090 #ifdef TRWS_DEBUG_OUTPUT
00091 void print(std::ostream& fout)
00092 {
00093 fout << "maxNumberOfIterations="<< maxNumberOfIterations()<<std::endl;
00094 fout << "numberOfInternalIterations="<< numberOfInternalIterations()<<std::endl;
00095 fout << "precision=" <<precision()<<std::endl;
00096 fout <<"isAbsolutePrecision=" << isAbsolutePrecision()<< std::endl;
00097 fout <<"smoothingGapRatio=" << smoothingGapRatio()<< std::endl;
00098 fout <<"lazyLPPrimalBoundComputation="<<lazyLPPrimalBoundComputation()<< std::endl;
00099 fout <<"lazyDerivativeComputation="<< lazyDerivativeComputation()<< std::endl;
00100 fout <<"smoothingDecayMultiplier=" << smoothingDecayMultiplier()<< std::endl;
00101 fout <<"worstCaseSmoothing="<<worstCaseSmoothing()<<std::endl;
00102
00103 if (decompositionType()==Storage::GENERALSTRUCTURE)
00104 fout <<"decompositionType=" <<"GENERAL"<<std::endl;
00105 else if (decompositionType()==Storage::GRIDSTRUCTURE)
00106 fout <<"decompositionType=" <<"GRID"<<std::endl;
00107 else
00108 fout <<"decompositionType=" <<"UNKNOWN"<<std::endl;
00109
00110 fout <<"startSmoothingValue=" << startSmoothingValue()<<std::endl;
00111 fout <<"fastComputations="<<fastComputations()<<std::endl;
00112 fout <<"canonicalNormalization="<<canonicalNormalization()<<std::endl;
00113
00114
00115
00116
00117 fout <<"maxNumberOfPresolveIterations="<<maxNumberOfPresolveIterations()<<std::endl;
00118 fout <<"presolveMinRelativeDualImprovement=" <<presolveMinRelativeDualImprovement()<<std::endl;
00119
00120
00121
00122
00123 fout <<"maxPrimalBoundIterationNumber="<<maxPrimalBoundIterationNumber()<<std::endl;
00124 fout <<"primalBoundRelativePrecision=" <<primalBoundRelativePrecision()<<std::endl;
00125 }
00126 #endif
00127 };
00128
00146
00147 template<class GM, class ACC>
00148 class ADSal : public Inference<GM, ACC>
00149 {
00150 public:
00151 typedef Inference<GM, ACC> parent;
00152 typedef ACC AccumulationType;
00153 typedef GM GraphicalModelType;
00154 OPENGM_GM_TYPE_TYPEDEFS;
00155
00156 typedef trws_base::DecompositionStorage<GM> Storage;
00157 typedef trws_base::SumProdTRWS<GM,ACC> SumProdSolver;
00158 typedef trws_base::MaxSumTRWS<GM,ACC> MaxSumSolver;
00159 typedef PrimalLPBound<GM,ACC> PrimalBoundEstimator;
00160
00161 typedef ADSal_Parameter<ValueType,GM> Parameter;
00162
00163 typedef VerboseVisitor<ADSal<GM, ACC> > VerboseVisitorType;
00164 typedef TimingVisitor <ADSal<GM, ACC> > TimingVisitorType;
00165 typedef EmptyVisitor <ADSal<GM, ACC> > EmptyVisitorType;
00166
00167 ADSal(const GraphicalModelType& gm,const Parameter& param
00168 #ifdef TRWS_DEBUG_OUTPUT
00169 ,std::ostream& fout=std::cout
00170 #endif
00171 )
00172 :_parameters(param),
00173 _storage(gm,param.decompositionType_),
00174 _sumprodsolver(_storage,param
00175 #ifdef TRWS_DEBUG_OUTPUT
00176 ,(param.verbose_ ? fout : *OUT::nullstream::Instance())
00177 #endif
00178 ),
00179 _maxsumsolver(_storage,typename MaxSumSolver::Parameters(param.presolveMaxIterNumber_,param.precision_,param.absolutePrecision_,param.presolveMinRelativeDualImprovement_,param.fastComputations_,param.canonicalNormalization_)
00180 #ifdef TRWS_DEBUG_OUTPUT
00181 ,(param.verbose_ ? fout : *OUT::nullstream::Instance())
00182 #endif
00183 ),
00184 _estimator(gm,param),
00185 #ifdef TRWS_DEBUG_OUTPUT
00186 _fout(param.verbose_ ? fout : *OUT::nullstream::Instance()),
00187 #endif
00188 _bestPrimalLPbound(ACC::template neutral<ValueType>()),
00189 _bestPrimalBound(ACC::template neutral<ValueType>()),
00190 _bestDualBound(ACC::template ineutral<ValueType>()),
00191 _bestIntegerBound(ACC::template neutral<ValueType>()),
00192 _bestIntegerLabeling(_storage.masterModel().numberOfVariables(),0.0),
00193 _initializationStage(true)
00194 {
00195 if (param.numOfExternalIterations_==0) throw std::runtime_error("ADSal: a strictly positive number of iterations must be provided!");
00196 };
00197
00198 std::string name() const{ return "ADSal"; }
00199 const GraphicalModelType& graphicalModel() const { return _storage.masterModel(); }
00200 InferenceTermination infer(){EmptyVisitorType visitor; return infer(visitor);};
00201 template<class VISITOR> InferenceTermination infer(VISITOR & visitor);
00202 InferenceTermination arg(std::vector<LabelType>& out, const size_t = 1) const
00203 {out = _bestIntegerLabeling;
00204 return opengm::NORMAL;}
00205
00206 ValueType bound() const{return _bestDualBound;}
00207 ValueType value() const{return _bestIntegerBound;}
00208
00209
00210
00211
00212 InferenceTermination oldinfer();
00213 private:
00214 template<class VISITOR>
00215 InferenceTermination _Presolve(VISITOR& visitor);
00216 template<class VISITOR>
00217 void _EstimateStartingSmoothing(VISITOR& visitor);
00218
00219
00220
00221 bool _UpdateSmoothing(ValueType primalBound,ValueType dualBound, ValueType smoothDualBound, ValueType derivativeValue,size_t iterationCounterPlus1=0);
00222 bool _CheckStoppingCondition(InferenceTermination*);
00223 void _UpdatePrimalEstimator();
00224 ValueType _EstimateRhoDerivative()const;
00225 ValueType _FastEstimateRhoDerivative()const{return (_sumprodsolver.bound()-_maxsumsolver.bound())/_sumprodsolver.GetSmoothing();}
00226 ValueType _ComputeStartingWorstCaseSmoothing(ValueType primalBound,ValueType dualBound)const;
00227 ValueType _ComputeWorstCaseSmoothing(ValueType primalBound,ValueType smoothDualBound)const;
00228 ValueType _ComputeSmoothingMultiplier()const;
00229 LabelType _ComputeMaxNumberOfLabels()const;
00230 bool _SmoothingMustBeDecreased(ValueType primalBound,ValueType dualBound, ValueType smoothDualBound,std::pair<ValueType,ValueType>* lhsRhs)const;
00231 bool _isLPBoundComputed()const;
00232 void _SelectOptimalBoundsAndLabeling();
00233
00234 Parameter _parameters;
00235 Storage _storage;
00236 SumProdSolver _sumprodsolver;
00237 MaxSumSolver _maxsumsolver;
00238 PrimalBoundEstimator _estimator;
00239 #ifdef TRWS_DEBUG_OUTPUT
00240 std::ostream& _fout;
00241 #endif
00242 ValueType _bestPrimalLPbound;
00243 ValueType _bestPrimalBound;
00244
00245 ValueType _bestDualBound;
00246 ValueType _bestIntegerBound;
00247 std::vector<LabelType> _bestIntegerLabeling;
00248
00249 bool _initializationStage;
00250
00251
00252
00253 typename SumProdSolver::OutputContainerType _marginalsTemp;
00254 };
00255
00256
00257 template<class GM,class ACC>
00258 void ADSal<GM,ACC>::_SelectOptimalBoundsAndLabeling()
00259 {
00260
00261 if (ACC::bop(_sumprodsolver.value(),_maxsumsolver.value()))
00262 {
00263 _bestIntegerLabeling=_sumprodsolver.arg();
00264 _bestIntegerBound=_sumprodsolver.value();
00265 }else
00266 {
00267 _bestIntegerLabeling=_maxsumsolver.arg();
00268 _bestIntegerBound=_maxsumsolver.value();
00269 }
00270
00271
00272 ACC::op(_bestPrimalLPbound,_bestIntegerBound,_bestPrimalBound);
00273 #ifdef TRWS_DEBUG_OUTPUT
00274 _fout << "_bestPrimalBound=" <<_bestPrimalBound<<std::endl;
00275 #endif
00276
00277
00278 if (ACC::ibop(_sumprodsolver.bound(),_maxsumsolver.bound()))
00279 _bestDualBound=_sumprodsolver.bound();
00280 else
00281 _bestDualBound=_maxsumsolver.bound();
00282
00283 }
00284
00285 template<class GM,class ACC>
00286 template<class VISITOR>
00287 void ADSal<GM,ACC>::_EstimateStartingSmoothing(VISITOR& visitor)
00288 {
00289 _sumprodsolver.SetSmoothing(_ComputeStartingWorstCaseSmoothing(_maxsumsolver.value(),_maxsumsolver.bound()));
00290 #ifdef TRWS_DEBUG_OUTPUT
00291 _fout <<"_maxsumsolver.value()="<<_maxsumsolver.value()<<", _maxsumsolver.bound()="<<_maxsumsolver.bound()<<std::endl;
00292 _fout << "WorstCaseSmoothing="<<_ComputeStartingWorstCaseSmoothing(_maxsumsolver.value(),_maxsumsolver.bound())<<std::endl;
00293 #endif
00294 std::pair<ValueType,ValueType> lhsRhs;
00295 _sumprodsolver.ForwardMove();
00296 _sumprodsolver.GetMarginalsAndDerivativeMove();
00297
00298 if (!_parameters.worstCaseSmoothing_)
00299 {
00300 visitor(_maxsumsolver.value(),_maxsumsolver.bound());
00301 do{
00302 ValueType derivative=_EstimateRhoDerivative();
00303 _parameters.smoothingGapRatio_*=2;
00304 _UpdateSmoothing(_maxsumsolver.value(),_maxsumsolver.bound(),_sumprodsolver.bound(),derivative);
00305 _parameters.smoothingGapRatio_/=2;
00306 _sumprodsolver.ForwardMove();
00307 _sumprodsolver.GetMarginalsAndDerivativeMove();
00308 visitor(_maxsumsolver.value(),_maxsumsolver.bound());
00309 }while (_SmoothingMustBeDecreased(_maxsumsolver.value(),_maxsumsolver.bound(),_sumprodsolver.bound(),&lhsRhs));
00310 }else
00311 _UpdateSmoothing(_maxsumsolver.value(),_maxsumsolver.bound(),_sumprodsolver.bound(),_EstimateRhoDerivative());
00312 }
00313
00314
00315 template<class GM,class ACC>
00316 template<class VISITOR>
00317 opengm::InferenceTermination ADSal<GM,ACC>::_Presolve(VISITOR& visitor)
00318 {
00319 #ifdef TRWS_DEBUG_OUTPUT
00320 _fout << "Running TRWS presolve..."<<std::endl;
00321 #endif
00322 return _maxsumsolver.infer_visitor_updates(visitor);
00323 }
00324
00325 template<class GM,class ACC>
00326 typename ADSal<GM,ACC>::LabelType ADSal<GM,ACC>::_ComputeMaxNumberOfLabels()const
00327 {
00328 LabelType numOfLabels=0;
00329 for (size_t i=0;i<_storage.numberOfSharedVariables();++i)
00330 numOfLabels=std::max(numOfLabels,_storage.numberOfLabels(i));
00331
00332 return numOfLabels;
00333 }
00334
00335 template<class GM,class ACC>
00336 typename ADSal<GM,ACC>::ValueType ADSal<GM,ACC>::_ComputeSmoothingMultiplier()const
00337 {
00338 ValueType multiplier=0;
00339 ValueType logLabels=log((ValueType)_ComputeMaxNumberOfLabels());
00340 for (size_t i=0;i<_storage.numberOfModels();++i)
00341 multiplier+=_storage.size(i)*logLabels;
00342
00343 return multiplier;
00344 }
00345
00346 template<class GM,class ACC>
00347 typename ADSal<GM,ACC>::ValueType ADSal<GM,ACC>::_ComputeStartingWorstCaseSmoothing(ValueType primalBound,ValueType dualBound)const
00348 {
00349 return fabs((primalBound-dualBound)/_ComputeSmoothingMultiplier()/(2.0*_parameters.smoothingGapRatio_-1));
00350 }
00351
00352 template<class GM,class ACC>
00353 typename ADSal<GM,ACC>::ValueType ADSal<GM,ACC>::_ComputeWorstCaseSmoothing(ValueType primalBound,ValueType smoothDualBound)const
00354 {
00355 return fabs((primalBound-smoothDualBound)/_ComputeSmoothingMultiplier()/(2.0*_parameters.smoothingGapRatio_));
00356 }
00357
00358 template<class GM,class ACC>
00359 bool ADSal<GM,ACC>::_SmoothingMustBeDecreased(ValueType primalBound,ValueType dualBound, ValueType smoothDualBound,std::pair<ValueType,ValueType>* lhsRhs)const
00360 {
00361 if (!_parameters.worstCaseSmoothing_)
00362 {
00363 lhsRhs->first=dualBound-smoothDualBound;
00364 lhsRhs->second=(primalBound-smoothDualBound)/(2*_parameters.smoothingGapRatio_);
00365 if (_parameters.smoothingDecayMultiplier_ <= 0.0 || _initializationStage)
00366 return ACC::ibop(lhsRhs->first,lhsRhs->second);
00367 else
00368 return true;
00369 }else if (_ComputeWorstCaseSmoothing(primalBound,smoothDualBound)<_sumprodsolver.GetSmoothing())
00370 return true;
00371
00372 return false;
00373 }
00374
00375 template<class GM,class ACC>
00376 bool ADSal<GM,ACC>::_UpdateSmoothing(ValueType primalBound,ValueType dualBound, ValueType smoothDualBound, ValueType derivativeValue,size_t iterationCounterPlus1)
00377 {
00378 #ifdef TRWS_DEBUG_OUTPUT
00379 _fout << "dualBound="<<dualBound<<", smoothDualBound="<<smoothDualBound<<", derivativeValue="<<derivativeValue<<std::endl;
00380 #endif
00381
00382 std::pair<ValueType,ValueType> lhsRhs;
00383 if (_SmoothingMustBeDecreased(primalBound,dualBound,smoothDualBound,&lhsRhs) || _initializationStage)
00384 {
00385 ValueType newsmoothing;
00386
00387 if (!_parameters.worstCaseSmoothing_)
00388 newsmoothing=_sumprodsolver.GetSmoothing() - (lhsRhs.second - lhsRhs.first)*(2.0*_parameters.smoothingGapRatio_)/(2.0*_parameters.smoothingGapRatio_-1.0)/derivativeValue;
00389 else
00390 newsmoothing=_ComputeWorstCaseSmoothing(primalBound,smoothDualBound);
00391
00392 if ( (_parameters.smoothingDecayMultiplier_ > 0.0) && (!_initializationStage) )
00393 {
00394 ValueType newMulIt=_parameters.smoothingDecayMultiplier_*iterationCounterPlus1+1;
00395 ValueType oldMulIt=_parameters.smoothingDecayMultiplier_*(iterationCounterPlus1-1)+1;
00396 newsmoothing=std::min(newsmoothing,_sumprodsolver.GetSmoothing()*oldMulIt/newMulIt);
00397 }
00398
00399 if (newsmoothing > 0)
00400 if ((newsmoothing < _sumprodsolver.GetSmoothing()) || _initializationStage) _sumprodsolver.SetSmoothing(newsmoothing);
00401 #ifdef TRWS_DEBUG_OUTPUT
00402 _fout << "smoothing changed to " << _sumprodsolver.GetSmoothing()<<std::endl;
00403 #endif
00404 return true;
00405 }
00406 return false;
00407 }
00408
00409 template<class GM,class ACC>
00410 void ADSal<GM,ACC>::_UpdatePrimalEstimator()
00411 {
00412 std::pair<ValueType,ValueType> bestNorms=std::make_pair((ValueType)0.0,(ValueType)0.0);
00413 ValueType numberOfVariables=_storage.masterModel().numberOfVariables();
00414 for (IndexType var=0;var<numberOfVariables;++var)
00415 {
00416 _marginalsTemp.resize(_storage.numberOfLabels(var));
00417 std::pair<ValueType,ValueType> norms=_sumprodsolver.GetMarginals(var, _marginalsTemp.begin());
00418 bestNorms.second=std::max(bestNorms.second,norms.second);
00419 bestNorms.first+=norms.first*norms.first;
00420
00421 transform_inplace(_marginalsTemp.begin(),_marginalsTemp.end(),trws_base::make0ifless<ValueType>(norms.second));
00422 TransportSolver::_Normalize(_marginalsTemp.begin(),_marginalsTemp.end(),(ValueType)0.0);
00423 _estimator.setVariable(var,_marginalsTemp.begin());
00424 }
00425 #ifdef TRWS_DEBUG_OUTPUT
00426 _fout << "l2 gradient norm="<<sqrt(bestNorms.first)<<", "<<"l_inf gradient norm="<<bestNorms.second<<std::endl;
00427 #endif
00428 }
00429
00430 template<class GM,class ACC>
00431 bool ADSal<GM,ACC>::_isLPBoundComputed()const
00432 {
00433 return (!_parameters.lazyLPPrimalBoundComputation_ || !ACC::bop(_sumprodsolver.value(),_bestPrimalBound) );
00434 }
00435
00436 template<class GM,class ACC>
00437 bool ADSal<GM,ACC>::_CheckStoppingCondition(InferenceTermination* preturncode)
00438 {
00439 if( _isLPBoundComputed())
00440 {
00441 _UpdatePrimalEstimator();
00442
00443 ACC::op(_estimator.getTotalValue(),_bestPrimalLPbound);
00444 #ifdef TRWS_DEBUG_OUTPUT
00445 _fout << "_primalLPbound=" <<_estimator.getTotalValue()<<std::endl;
00446 #endif
00447 }
00448 _SelectOptimalBoundsAndLabeling();
00449
00450 if (_maxsumsolver.CheckTreeAgreement(preturncode)) return true;
00451
00452 if (_sumprodsolver.CheckDualityGap(_bestPrimalBound,_maxsumsolver.bound()))
00453 {
00454 #ifdef TRWS_DEBUG_OUTPUT
00455 _fout << "ADSal::_CheckStoppingCondition(): Precision attained! Problem solved!"<<std::endl;
00456 #endif
00457 *preturncode=CONVERGENCE;
00458 return true;
00459 }
00460
00461 return false;
00462 }
00463
00464
00465 template<class GM,class ACC>
00466 template<class VISITOR>
00467 InferenceTermination ADSal<GM,ACC>::infer(VISITOR & vis)
00468 {
00469 trws_base::VisitorWrapper<VISITOR,ADSal<GM, ACC> > visitor(&vis,this);
00470
00471 visitor.begin(value(),bound());
00472
00473 if (_sumprodsolver.GetSmoothing()<=0.0)
00474 {
00475 if (_Presolve(visitor)==CONVERGENCE)
00476 {
00477 _SelectOptimalBoundsAndLabeling();
00478 visitor.end(value(), bound());
00479 return NORMAL;
00480 }
00481 #ifdef TRWS_DEBUG_OUTPUT
00482 _fout <<"Switching to the smooth solver============================================"<<std::endl;
00483 #endif
00484 _EstimateStartingSmoothing(visitor);
00485 }
00486
00487 _initializationStage=false;
00488
00489 bool forwardMoveNeeded=true;
00490 for (size_t i=0;i<_parameters.numOfExternalIterations_;++i)
00491 {
00492 #ifdef TRWS_DEBUG_OUTPUT
00493 _fout <<"Main iteration Nr."<<i<<"============================================"<<std::endl;
00494 #endif
00495
00496 InferenceTermination returncode;
00497 if (forwardMoveNeeded)
00498 {
00499 returncode=_sumprodsolver.infer();
00500 forwardMoveNeeded=false;
00501 }
00502 else
00503 returncode=_sumprodsolver.core_infer();
00504
00505 if (returncode==CONVERGENCE)
00506 {
00507 _SelectOptimalBoundsAndLabeling();
00508 visitor.end(value(), bound());
00509 return NORMAL;
00510
00511 }
00512
00513 _maxsumsolver.ForwardMove();
00514 #ifdef TRWS_DEBUG_OUTPUT
00515 _fout << "_maxsumsolver.bound()=" <<_maxsumsolver.bound()<<std::endl;
00516 #endif
00517
00518 ValueType derivative;
00519 if (_isLPBoundComputed() || !_parameters.lazyDerivativeComputation())
00520 {
00521 _sumprodsolver.GetMarginalsAndDerivativeMove();
00522 derivative=_EstimateRhoDerivative();
00523 #ifdef TRWS_DEBUG_OUTPUT
00524 _fout << "derivative="<<derivative<<std::endl;
00525 #endif
00526 forwardMoveNeeded=true;
00527 }
00528 else
00529 derivative=_FastEstimateRhoDerivative();
00530
00531 if ( _CheckStoppingCondition(&returncode))
00532 {
00533 visitor.end(value(), bound());
00534 return NORMAL;
00535
00536 }
00537
00538 visitor(value(),bound());
00539
00540 if (_UpdateSmoothing(_bestPrimalBound,_maxsumsolver.bound(),_sumprodsolver.bound(),derivative,i+1))
00541 forwardMoveNeeded=true;
00542 }
00543
00544 _SelectOptimalBoundsAndLabeling();
00545 visitor.end(value(), bound());
00546
00547 return NORMAL;
00548 }
00549
00550 template<class GM,class ACC>
00551 InferenceTermination ADSal<GM,ACC>::oldinfer()
00552 {
00553 if (_sumprodsolver.GetSmoothing()<=0.0)
00554 {
00555 _EstimateStartingSmoothing();
00556
00557 }
00558
00559 for (size_t i=0;i<_parameters.numOfExternalIterations_;++i)
00560 {
00561 #ifdef TRWS_DEBUG_OUTPUT
00562 _fout <<"Main iteration Nr."<<i<<"============================================"<<std::endl;
00563 #endif
00564 for (size_t innerIt=0;innerIt<_parameters.maxNumberOfIterations_;++innerIt)
00565 {
00566 _sumprodsolver.ForwardMove();
00567 _sumprodsolver.BackwardMove();
00568 #ifdef TRWS_DEBUG_OUTPUT
00569 _fout <<"subIter="<< innerIt<<", smoothDualBound=" << _sumprodsolver.bound() <<std::endl;
00570 #endif
00571 }
00572
00573 _sumprodsolver.ForwardMove();
00574 _sumprodsolver.GetMarginalsAndDerivativeMove();
00575 _maxsumsolver.ForwardMove();
00576 ValueType derivative=_EstimateRhoDerivative();
00577 #ifdef TRWS_DEBUG_OUTPUT
00578 _fout << "derivative="<<derivative<<std::endl;
00579 #endif
00580 InferenceTermination returncode;
00581 if ( _CheckStoppingCondition(&returncode)) return returncode;
00582
00583 _UpdateSmoothing(_bestPrimalBound,_maxsumsolver.bound(),_sumprodsolver.bound(),derivative,i+1);
00584 }
00585 return opengm::NORMAL;
00586 }
00587
00588 template<class GM,class ACC>
00589 typename ADSal<GM,ACC>::ValueType
00590 ADSal<GM,ACC>::_EstimateRhoDerivative()const
00591 {
00592 ValueType derivative=0.0;
00593 for (size_t i=0;i<_storage.numberOfModels();++i)
00594 {
00595 ValueType delta;
00596 ACC::op(_sumprodsolver.getDerivative(i),(_sumprodsolver.getBound(i)-_maxsumsolver.getBound(i))/_sumprodsolver.GetSmoothing(),delta);
00597 derivative+=delta;
00598 }
00599 return derivative;
00600 }
00601
00602 }
00603 #endif