ITKCommand.h 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. #pragma once
  2. #include <itkCommand.h>
  3. #include <itkRegularStepGradientDescentOptimizer.h>
  4. #include <itkRegularStepGradientDescentOptimizerv4.h>
  5. class CommandIterationUpdate : public itk::Command
  6. {
  7. public:
  8. using Self = CommandIterationUpdate;
  9. using Superclass = itk::Command;
  10. using Pointer = itk::SmartPointer<Self>;
  11. itkNewMacro(Self);
  12. protected:
  13. CommandIterationUpdate() = default;
  14. public:
  15. using OptimizerType = itk::RegularStepGradientDescentOptimizerv4<double>;
  16. using OptimizerPointer = const OptimizerType*;
  17. void
  18. Execute(itk::Object* caller, const itk::EventObject& event) override
  19. {
  20. Execute((const itk::Object*)caller, event);
  21. }
  22. void
  23. Execute(const itk::Object* object, const itk::EventObject& event) override
  24. {
  25. auto optimizer = static_cast<OptimizerPointer>(object);
  26. if (!(itk::IterationEvent().CheckEvent(&event)))
  27. {
  28. return;
  29. }
  30. std::cout << optimizer->GetCurrentIteration() << " ";
  31. std::cout << optimizer->GetValue() << " ";
  32. std::cout << optimizer->GetCurrentPosition() << " ";
  33. std::cout << m_CumulativeIterationIndex++ << std::endl;
  34. }
  35. private:
  36. unsigned int m_CumulativeIterationIndex{ 0 };
  37. };
  38. template <typename TRegistration>
  39. class RegistrationInterfaceCommand : public itk::Command
  40. {
  41. public:
  42. using Self = RegistrationInterfaceCommand;
  43. using Superclass = itk::Command;
  44. using Pointer = itk::SmartPointer<Self>;
  45. itkNewMacro(Self);
  46. protected:
  47. RegistrationInterfaceCommand() = default;
  48. public:
  49. using RegistrationType = TRegistration;
  50. using RegistrationPointer = RegistrationType*;
  51. using OptimizerType = itk::RegularStepGradientDescentOptimizerv4<double>;
  52. using OptimizerPointer = OptimizerType*;
  53. void
  54. Execute(itk::Object* object, const itk::EventObject& event) override
  55. {
  56. // First we verify that the event invoked is of the right type,
  57. // \code{itk::MultiResolutionIterationEvent()}.
  58. // If not, we return without any further action.
  59. if (!(itk::MultiResolutionIterationEvent().CheckEvent(&event)))
  60. {
  61. return;
  62. }
  63. // We then convert the input object pointer to a RegistrationPointer.
  64. // Note that no error checking is done here to verify the
  65. // \code{dynamic\_cast} was successful since we know the actual object
  66. // is a registration method. Then we ask for the optimizer object
  67. // from the registration method.
  68. auto registration = static_cast<RegistrationPointer>(object);
  69. auto optimizer =
  70. static_cast<OptimizerPointer>(registration->GetModifiableOptimizer());
  71. unsigned int currentLevel = registration->GetCurrentLevel();
  72. typename RegistrationType::ShrinkFactorsPerDimensionContainerType
  73. shrinkFactors =
  74. registration->GetShrinkFactorsPerDimension(currentLevel);
  75. typename RegistrationType::SmoothingSigmasArrayType smoothingSigmas =
  76. registration->GetSmoothingSigmasPerLevel();
  77. std::cout << "-------------------------------------" << std::endl;
  78. std::cout << " Current level = " << currentLevel << std::endl;
  79. std::cout << " shrink factor = " << shrinkFactors << std::endl;
  80. std::cout << " smoothing sigma = ";
  81. std::cout << smoothingSigmas[currentLevel] << std::endl;
  82. std::cout << std::endl;
  83. // If this is the first resolution level we set the learning rate
  84. // (representing the first step size) and the minimum step length
  85. // (representing the convergence criterion) to large values. At each
  86. // subsequent resolution level, we will reduce the minimum step length by
  87. // a factor of 5 in order to allow the optimizer to focus on progressively
  88. // smaller regions. The learning rate is set up to the current step
  89. // length. In this way, when the optimizer is reinitialized at the
  90. // beginning of the registration process for the next level, the step
  91. // length will simply start with the last value used for the previous
  92. // level. This will guarantee the continuity of the path taken by the
  93. // optimizer through the parameter space.
  94. double currStepLen0 = /*16.00*/ /*8.0*/ 16.0;
  95. double miniStepLen0 = /*2.5*/ /*0.04*/ 0.08;
  96. if (registration->GetCurrentLevel() == 0)
  97. {
  98. optimizer->SetLearningRate(currStepLen0);
  99. optimizer->SetMinimumStepLength(miniStepLen0);
  100. std::cout << " Current Step Length = " << currStepLen0
  101. << ", Minimum Step Length = " << miniStepLen0 << std::endl;
  102. }
  103. else
  104. {
  105. double currStepLen = optimizer->GetCurrentStepLength();
  106. double miniStepLen = optimizer->GetMinimumStepLength();
  107. optimizer->SetLearningRate(currStepLen);
  108. optimizer->SetMinimumStepLength(miniStepLen * 0.2);
  109. std::cout << " Current Step Length = " << currStepLen
  110. << ", Minimum Step Length = " << miniStepLen << std::endl;
  111. }
  112. }
  113. // Another version of the \code{Execute()} method accepting a \code{const}
  114. // input object is also required since this method is defined as pure
  115. // virtual in the base class. This version simply returns without taking any action.
  116. void
  117. Execute(const itk::Object*, const itk::EventObject&) override
  118. {
  119. return;
  120. }
  121. };
  122. // The following section of code implements a Command observer
  123. // that will monitor the configurations of the registration process
  124. // at every change of stage and resolution level.
  125. template <typename TRegistration>
  126. class RegistrationInterfaceCommand1 : public itk::Command
  127. {
  128. public:
  129. using Self = RegistrationInterfaceCommand1;
  130. using Superclass = itk::Command;
  131. using Pointer = itk::SmartPointer<Self>;
  132. itkNewMacro(Self);
  133. protected:
  134. RegistrationInterfaceCommand1() = default;
  135. public:
  136. using RegistrationType = TRegistration;
  137. // The Execute function simply calls another version of the \code{Execute()}
  138. // method accepting a \code{const} input object
  139. void
  140. Execute(itk::Object* object, const itk::EventObject& event) override
  141. {
  142. Execute((const itk::Object*)object, event);
  143. }
  144. void
  145. Execute(const itk::Object* object, const itk::EventObject& event) override
  146. {
  147. if (!(itk::MultiResolutionIterationEvent().CheckEvent(&event)))
  148. {
  149. return;
  150. }
  151. std::cout << "\nObserving from class " << object->GetNameOfClass();
  152. if (!object->GetObjectName().empty())
  153. {
  154. std::cout << " \"" << object->GetObjectName() << "\"" << std::endl;
  155. }
  156. const auto* registration = static_cast<const RegistrationType*>(object);
  157. unsigned int currentLevel = registration->GetCurrentLevel();
  158. typename RegistrationType::ShrinkFactorsPerDimensionContainerType
  159. shrinkFactors =
  160. registration->GetShrinkFactorsPerDimension(currentLevel);
  161. typename RegistrationType::SmoothingSigmasArrayType smoothingSigmas =
  162. registration->GetSmoothingSigmasPerLevel();
  163. std::cout << "-------------------------------------" << std::endl;
  164. std::cout << " Current multi-resolution level = " << currentLevel
  165. << std::endl;
  166. std::cout << " shrink factor = " << shrinkFactors << std::endl;
  167. std::cout << " smoothing sigma = " << smoothingSigmas[currentLevel]
  168. << std::endl;
  169. std::cout << std::endl;
  170. }
  171. };
  172. // The following section of code implements an observer
  173. // that will monitor the evolution of the registration process.
  174. class CommandIterationUpdate1 : public itk::Command
  175. {
  176. public:
  177. using Self = CommandIterationUpdate1;
  178. using Superclass = itk::Command;
  179. using Pointer = itk::SmartPointer<Self>;
  180. itkNewMacro(Self);
  181. protected:
  182. CommandIterationUpdate1() = default;
  183. public:
  184. using OptimizerType = itk::GradientDescentOptimizerv4Template<double>;
  185. using OptimizerPointer = const OptimizerType*;
  186. void
  187. Execute(itk::Object* caller, const itk::EventObject& event) override
  188. {
  189. Execute((const itk::Object*)caller, event);
  190. }
  191. void
  192. Execute(const itk::Object* object, const itk::EventObject& event) override
  193. {
  194. auto optimizer = static_cast<OptimizerPointer>(object);
  195. if (!(itk::IterationEvent().CheckEvent(&event)))
  196. {
  197. return;
  198. }
  199. std::cout << optimizer->GetCurrentIteration() << " ";
  200. std::cout << optimizer->GetValue() << " ";
  201. std::cout << optimizer->GetCurrentPosition() << " "
  202. << m_CumulativeIterationIndex++ << std::endl;
  203. }
  204. private:
  205. unsigned int m_CumulativeIterationIndex{ 0 };
  206. };
  207. class CommandIterationUpdate2 : public itk::Command
  208. {
  209. public:
  210. using Self = CommandIterationUpdate2;
  211. using Superclass = itk::Command;
  212. using Pointer = itk::SmartPointer<Self>;
  213. itkNewMacro(Self);
  214. protected:
  215. CommandIterationUpdate2() = default;
  216. public:
  217. using OptimizerType = itk::RegularStepGradientDescentOptimizer;
  218. using OptimizerPointer = const OptimizerType*;
  219. void
  220. Execute(itk::Object* caller, const itk::EventObject& event) override
  221. {
  222. Execute((const itk::Object*)caller, event);
  223. }
  224. void
  225. Execute(const itk::Object* object, const itk::EventObject& event) override
  226. {
  227. auto optimizer = static_cast<OptimizerPointer>(object);
  228. if (!(itk::IterationEvent().CheckEvent(&event)))
  229. {
  230. return;
  231. }
  232. std::cout << optimizer->GetCurrentIteration() << " ";
  233. std::cout << optimizer->GetValue() << " ";
  234. std::cout << optimizer->GetCurrentPosition() << std::endl;
  235. }
  236. };
  237. template <typename TRegistration>
  238. class RegistrationInterfaceCommand2 : public itk::Command
  239. {
  240. public:
  241. using Self = RegistrationInterfaceCommand2;
  242. using Superclass = itk::Command;
  243. using Pointer = itk::SmartPointer<Self>;
  244. itkNewMacro(Self);
  245. protected:
  246. RegistrationInterfaceCommand2() = default;
  247. public:
  248. using RegistrationType = TRegistration;
  249. using RegistrationPointer = RegistrationType*;
  250. using OptimizerType = itk::RegularStepGradientDescentOptimizer;
  251. using OptimizerPointer = OptimizerType*;
  252. void Execute(itk::Object* object, const itk::EventObject& event) override
  253. {
  254. if (!(itk::IterationEvent().CheckEvent(&event)))
  255. {
  256. return;
  257. }
  258. auto registration = static_cast<RegistrationPointer>(object);
  259. if (registration == nullptr)
  260. {
  261. return;
  262. }
  263. auto optimizer =
  264. static_cast<OptimizerPointer>(registration->GetModifiableOptimizer());
  265. std::cout << "-------------------------------------" << std::endl;
  266. std::cout << "MultiResolution Level : " << registration->GetCurrentLevel()
  267. << std::endl;
  268. std::cout << std::endl;
  269. if (registration->GetCurrentLevel() == 0)
  270. {
  271. optimizer->SetMaximumStepLength(16.00);
  272. optimizer->SetMinimumStepLength(0.01);
  273. }
  274. else
  275. {
  276. optimizer->SetMaximumStepLength(optimizer->GetMaximumStepLength() / 4.0);
  277. optimizer->SetMinimumStepLength(optimizer->GetMinimumStepLength() / 10.0);
  278. }
  279. }
  280. void Execute(const itk::Object*, const itk::EventObject&) override
  281. {
  282. return;
  283. }
  284. };