16 #include "training_rate_algorithm.h"
28 : performance_functional_pointer(NULL)
42 : performance_functional_pointer(new_performance_functional_pointer)
56 : performance_functional_pointer(NULL)
85 std::ostringstream buffer;
87 buffer <<
"OpenNN Exception: TrainingRateAlgorithm class.\n"
88 <<
"PerformanceFunctional* get_performance_functional_pointer(void) const method.\n"
89 <<
"Performance functional pointer is NULL.\n";
91 throw std::logic_error(buffer.str());
144 return(
"GoldenSection");
150 return(
"BrentMethod");
156 std::ostringstream buffer;
158 buffer <<
"OpenNN Exception: TrainingRateAlgorithm class.\n"
159 <<
"std::string get_training_rate_method(void) const method.\n"
160 <<
"Unknown training rate method.\n";
162 throw std::logic_error(buffer.str());
302 if(new_training_rate_method ==
"Fixed")
306 else if(new_training_rate_method ==
"GoldenSection")
310 else if(new_training_rate_method ==
"BrentMethod")
316 std::ostringstream buffer;
318 buffer <<
"OpenNN Exception: TrainingRateAlgorithm class.\n"
319 <<
"void set_method(const std::string&) method.\n"
320 <<
"Unknown training rate method: " << new_training_rate_method <<
".\n";
322 throw std::logic_error(buffer.str());
338 if(new_bracketing_factor < 0.0)
340 std::ostringstream buffer;
342 buffer <<
"OpenNN Exception: TrainingRateAlgorithm class.\n"
343 <<
"void set_bracketing_factor(const double&) method.\n"
344 <<
"Bracketing factor must be equal or greater than 0.\n";
346 throw std::logic_error(buffer.str());
366 if(new_training_rate_tolerance < 0.0)
368 std::ostringstream buffer;
370 buffer <<
"OpenNN Exception: TrainingRateAlgorithm class.\n"
371 <<
"void set_training_rate_tolerance(const double&) method.\n"
372 <<
"Tolerance must be equal or greater than 0.\n";
374 throw std::logic_error(buffer.str());
397 if(new_warning_training_rate < 0.0)
399 std::ostringstream buffer;
401 buffer <<
"OpenNN Exception: TrainingRateAlgorithm class.\n"
402 <<
"void set_warning_training_rate(const double&) method.\n"
403 <<
"Warning training rate must be equal or greater than 0.\n";
405 throw std::logic_error(buffer.str());
426 if(new_error_training_rate < 0.0)
428 std::ostringstream buffer;
430 buffer <<
"OpenNN Exception: TrainingRateAlgorithm class.\n"
431 <<
"void set_error_training_rate(const double&) method.\n"
432 <<
"Error training rate must be equal or greater than 0.\n";
434 throw std::logic_error(buffer.str());
473 std::ostringstream buffer;
475 buffer <<
"OpenNN Error: TrainingRateAlgorithm class.\n"
476 <<
"Vector<double> calculate_directional_point(const double&, const Vector<double>&, const double&) const method.\n"
477 <<
"Pointer to performance functional is NULL.\n";
479 throw std::logic_error(buffer.str());
488 if(neural_network_pointer == NULL)
490 std::ostringstream buffer;
492 buffer <<
"OpenNN Error: TrainingRateAlgorithm class.\n"
493 <<
"Vector<double> calculate_directional_point(const double&, const Vector<double>&, const double&) const method.\n"
494 <<
"Pointer to neural network is NULL.\n";
496 throw std::logic_error(buffer.str());
503 case TrainingRateAlgorithm::Fixed:
509 case TrainingRateAlgorithm::GoldenSection:
515 case TrainingRateAlgorithm::BrentMethod:
523 std::ostringstream buffer;
525 buffer <<
"OpenNN Exception: TrainingRateAlgorithm class\n"
526 <<
"Vector<double> calculate_directional_point(const double&, const Vector<double>&, const double&) const method.\n"
527 <<
"Unknown training rate method.\n";
529 throw std::logic_error(buffer.str());
544 const double& performance,
546 const double& initial_training_rate)
const
550 if(training_direction == 0.0)
553 triplet.
A[1] = performance;
555 triplet.
U = triplet.
A;
557 triplet.
B = triplet.
A;
562 if(initial_training_rate == 0.0)
565 triplet.
A[1] = performance;
567 triplet.
U = triplet.
A;
569 triplet.
B = triplet.
A;
577 triplet.
A[1] = performance;
581 triplet.
B[0] = initial_training_rate;
584 while(triplet.
A[1] > triplet.
B[1])
586 triplet.
A = triplet.
B;
593 std::ostringstream buffer;
595 buffer <<
"OpenNN Warning: TrainingRateAlgorithm class.\n"
596 <<
"Vector<double> calculate_bracketing_triplet(double, const Vector<double>&, double) const method\n."
597 <<
"Right point is " << triplet.
B[0] <<
"." << std::endl;
599 buffer <<
"Performance: " << performance <<
"\n"
600 <<
"Training direction: " << training_direction <<
"\n"
601 <<
"Initial training rate: " << initial_training_rate << std::endl;
603 throw std::logic_error(buffer.str());
611 triplet.
U[0] = triplet.
A[0] + (triplet.
B[0] - triplet.
A[0])/2.0;
614 while(triplet.
A[1] < triplet.
U[1])
621 triplet.
U = triplet.
A;
622 triplet.
B = triplet.
A;
647 directional_point[0] = initial_training_rate;
650 return(directional_point);
663 (
const double& performance,
const Vector<double>& training_direction,
const double& initial_training_rate)
const
665 std::ostringstream buffer;
671 Triplet triplet = calculate_bracketing_triplet(performance, training_direction, initial_training_rate);
684 V[0] = calculate_golden_section_training_rate(triplet);
685 V[1] = performance_functional_pointer->calculate_performance(training_direction, V[0]);
689 if(V[0] < triplet.
U[0] && V[1] >= triplet.
U[1])
695 else if(V[0] < triplet.
U[0] && V[1] <= triplet.
U[1])
699 triplet.
B = triplet.
U;
701 else if(V[0] > triplet.
U[0] && V[1] >= triplet.
U[1])
707 else if(V[0] > triplet.
U[0] && V[1] <= triplet.
U[1])
709 triplet.
A = triplet.
U;
713 else if(V[0] == triplet.
U[0])
715 buffer <<
"OpenNN Exception: TrainingRateAlgorithm class.\n"
716 <<
"Vector<double> calculate_golden_section_directional_point(double, const Vector<double>, double) const method.\n"
717 <<
"Both interior points have the same ordinate.\n";
719 std::cout << buffer.str() << std::endl;
725 buffer <<
"OpenNN Exception: TrainingRateAlgorithm class.\n"
726 <<
"Vector<double> calculate_golden_section_directional_point(double, const Vector<double>, double) const method.\n"
728 <<
"A = (" << triplet.
A[0] <<
"," << triplet.
A[1] <<
")\n"
729 <<
"B = (" << triplet.
B[0] <<
"," << triplet.
B[1] <<
")\n"
730 <<
"U = (" << triplet.
U[0] <<
"," << triplet.
U[1] <<
")\n"
731 <<
"V = (" << V[0] <<
"," << V[1] <<
")\n";
733 throw std::logic_error(buffer.str());
740 }
while(triplet.
B[0] - triplet.
A[0] > training_rate_tolerance);
744 catch(
const std::logic_error& e)
746 std::cerr << e.what() << std::endl;
749 X[0] = initial_training_rate;
750 X[1] = performance_functional_pointer->calculate_performance(training_direction, X[0]);
752 if(X[1] > performance)
772 (
const double& performance,
const Vector<double>& training_direction,
const double& initial_training_rate)
const
774 std::ostringstream buffer;
780 Triplet triplet = calculate_bracketing_triplet(performance, training_direction, initial_training_rate);
782 if(triplet.
A == triplet.
B)
791 while(triplet.
B[0] - triplet.
A[0] > training_rate_tolerance)
795 V[0] = calculate_Brent_method_training_rate(triplet);
797 catch(
const std::logic_error&)
799 V[0] = calculate_golden_section_training_rate(triplet);
804 V[1] = performance_functional_pointer->calculate_performance(training_direction, V[0]);
808 if(V[0] < triplet.
U[0] && V[1] >= triplet.
U[1])
814 else if(V[0] < triplet.
U[0] && V[1] <= triplet.
U[1])
817 triplet.
B = triplet.
U;
820 else if(V[0] > triplet.
U[0] && V[1] >= triplet.
U[1])
826 else if(V[0] > triplet.
U[0] && V[1] <= triplet.
U[1])
828 triplet.
A = triplet.
U;
832 else if(V[0] == triplet.
U[0])
834 buffer <<
"OpenNN Exception: TrainingRateAlgorithm class.\n"
835 <<
"Vector<double> calculate_Brent_method_directional_point(double, const Vector<double>, double) const method.\n"
836 <<
"Both interior points have the same ordinate.\n";
842 buffer <<
"OpenNN Exception: TrainingRateAlgorithm class.\n"
843 <<
"Vector<double> calculate_Brent_method_directional_point(double, const Vector<double>, double) const method.\n"
845 <<
"A = (" << triplet.
A[0] <<
"," << triplet.
A[1] <<
")\n"
846 <<
"B = (" << triplet.
B[0] <<
"," << triplet.
B[1] <<
")\n"
847 <<
"U = (" << triplet.
U[0] <<
"," << triplet.
U[1] <<
")\n"
848 <<
"V = (" << V[0] <<
"," << V[1] <<
")\n";
850 throw std::logic_error(buffer.str());
860 catch(std::range_error& e)
862 std::cerr << e.what() << std::endl;
870 catch(
const std::logic_error& e)
872 std::cerr << e.what() << std::endl;
875 X[0] = initial_training_rate;
876 X[1] = performance_functional_pointer->calculate_performance(training_direction, X[0]);
878 if(X[1] > performance)
898 double training_rate;
900 if(triplet.
U[0] < triplet.
A[0] + 0.5*(triplet.
B[0] - triplet.
A[0]))
902 training_rate = triplet.
A[0] + 0.618*(triplet.
B[0] - triplet.
A[0]);
906 training_rate = triplet.
A[0] + 0.382*(triplet.
B[0] - triplet.
A[0]);
911 if(training_rate < triplet.
A[0])
913 std::ostringstream buffer;
915 buffer <<
"OpenNN Error: TrainingRateAlgorithm class.\n"
916 <<
"double calculate_golden_section_training_rate(const Triplet&) const method.\n"
917 <<
"Training rate (" << training_rate <<
") is less than triplet left point (" << triplet.
A[0] <<
").\n";
919 throw std::logic_error(buffer.str());
922 if(training_rate > triplet.
B[0])
924 std::ostringstream buffer;
926 buffer <<
"OpenNN Error: TrainingRateAlgorithm class.\n"
927 <<
"double calculate_golden_section_training_rate(const Triplet&) const method.\n"
928 <<
"Training rate (" << training_rate <<
") is greater than triplet right point (" << triplet.
B[0] <<
").\n";
930 throw std::logic_error(buffer.str());
935 return(training_rate);
946 const double c = -(triplet.
A[1]*(triplet.
U[0]-triplet.
B[0])
947 + triplet.
U[1]*(triplet.
B[0]-triplet.
A[0])
948 + triplet.
B[1]*(triplet.
A[0]-triplet.
U[0]))/((triplet.
A[0]-triplet.
U[0])*(triplet.
U[0]-triplet.
B[0])*(triplet.
B[0]-triplet.
A[0]));
952 std::ostringstream buffer;
954 buffer <<
"OpenNN Exception: TrainingRateAlgorithm class.\n"
955 <<
"double calculate_Brent_method_training_rate(Vector<double>&, Vector<double>&, Vector<double>&) const method.\n"
956 <<
"Parabola cannot be constructed.\n";
958 throw std::logic_error(buffer.str());
962 std::ostringstream buffer;
964 buffer <<
"OpenNN Exception: TrainingRateAlgorithm class.\n"
965 <<
"double calculate_Brent_method_training_rate(Vector<double>&, Vector<double>&, Vector<double>&) const method.\n"
966 <<
"Parabola does not have a minimum but a maximum.\n";
968 throw std::logic_error(buffer.str());
971 const double b = (triplet.
A[1]*(triplet.
U[0]*triplet.
U[0]-triplet.
B[0]*triplet.
B[0])
972 + triplet.
U[1]*(triplet.
B[0]*triplet.
B[0]-triplet.
A[0]*triplet.
A[0])
973 + triplet.
B[1]*(triplet.
A[0]*triplet.
A[0]-triplet.
U[0]*triplet.
U[0]))/((triplet.
A[0]-triplet.
U[0])*(triplet.
U[0]-triplet.
B[0])*(triplet.
B[0]-triplet.
A[0]));
975 const double Brent_method_training_rate = -b/(2.0*c);
977 if(Brent_method_training_rate <= triplet.
A[0] || Brent_method_training_rate >= triplet.
B[0])
979 std::ostringstream buffer;
981 buffer <<
"OpenNN Exception: TrainingRateAlgorithm class.\n"
982 <<
"double calculate_parabola_minimal_training_rate(Vector<double>&, Vector<double>&, Vector<double>&) const method.\n"
983 <<
"Brent method training rate is not inside interval.\n"
984 <<
"Interval: (" << triplet.
A[0] <<
"," << triplet.
B[0] <<
")\n"
985 <<
"Brent method training rate: " << Brent_method_training_rate << std::endl;
987 throw std::logic_error(buffer.str());
990 return(Brent_method_training_rate);
1001 std::ostringstream buffer;
1003 tinyxml2::XMLDocument* document =
new tinyxml2::XMLDocument;
1007 tinyxml2::XMLElement* root_element = document->NewElement(
"TrainingRateAlgorithm");
1009 document->InsertFirstChild(root_element);
1011 tinyxml2::XMLElement* element = NULL;
1012 tinyxml2::XMLText* text = NULL;
1016 element = document->NewElement(
"TrainingRateMethod");
1017 root_element->LinkEndChild(element);
1020 element->LinkEndChild(text);
1025 element = document->NewElement(
"BracketingFactor");
1026 root_element->LinkEndChild(element);
1031 text = document->NewText(buffer.str().c_str());
1032 element->LinkEndChild(text);
1037 element = document->NewElement(
"TrainingRateTolerance");
1038 root_element->LinkEndChild(element);
1043 text = document->NewText(buffer.str().c_str());
1044 element->LinkEndChild(text);
1049 element = document->NewElement(
"WarningTrainingRate");
1050 root_element->LinkEndChild(element);
1055 text = document->NewText(buffer.str().c_str());
1056 element->LinkEndChild(text);
1061 element = document->NewElement(
"ErrorTrainingRate");
1062 root_element->LinkEndChild(element);
1067 text = document->NewText(buffer.str().c_str());
1068 element->LinkEndChild(text);
1073 element = document->NewElement(
"Display");
1074 root_element->LinkEndChild(element);
1079 text = document->NewText(buffer.str().c_str());
1080 element->LinkEndChild(text);
1095 const tinyxml2::XMLElement* root_element = document.FirstChildElement(
"TrainingRateAlgorithm");
1099 std::ostringstream buffer;
1101 buffer <<
"OpenNN Exception: TrainingRateAlgorithm class.\n"
1102 <<
"void from_XML(const tinyxml2::XMLDocument&) method.\n"
1103 <<
"Training rate algorithm element is NULL.\n";
1105 throw std::logic_error(buffer.str());
1110 const tinyxml2::XMLElement* element = root_element->FirstChildElement(
"TrainingRateMethod");
1114 std::string new_training_rate_method = element->GetText();
1120 catch(
const std::logic_error& e)
1122 std::cout << e.what() << std::endl;
1129 const tinyxml2::XMLElement* element = root_element->FirstChildElement(
"BracketingFactor");
1133 const double new_bracketing_factor = atof(element->GetText());
1139 catch(
const std::logic_error& e)
1141 std::cout << e.what() << std::endl;
1167 const tinyxml2::XMLElement* element = root_element->FirstChildElement(
"TrainingRateTolerance");
1171 const double new_training_rate_tolerance = atof(element->GetText());
1177 catch(
const std::logic_error& e)
1179 std::cout << e.what() << std::endl;
1186 const tinyxml2::XMLElement* element = root_element->FirstChildElement(
"WarningTrainingRate");
1190 const double new_warning_training_rate = atof(element->GetText());
1196 catch(
const std::logic_error& e)
1198 std::cout << e.what() << std::endl;
1205 const tinyxml2::XMLElement* element = root_element->FirstChildElement(
"ErrorTrainingRate");
1209 const double new_error_training_rate = atof(element->GetText());
1215 catch(
const std::logic_error& e)
1217 std::cout << e.what() << std::endl;
1224 const tinyxml2::XMLElement* element = root_element->FirstChildElement(
"Display");
1228 const std::string new_display = element->GetText();
1234 catch(
const std::logic_error& e)
1236 std::cout << e.what() << std::endl;
TrainingRateMethod training_rate_method
Variable containing the actual method used to obtain a suitable perform_training rate.
double calculate_Brent_method_training_rate(const Triplet &) const
PerformanceFunctional * get_performance_functional_pointer(void) const
Vector< double > A
Left point of the triplet.
const double & get_error_training_rate(void) const
void set_performance_functional_pointer(PerformanceFunctional *)
bool display
Display messages to screen.
std::string write_training_rate_method(void) const
Returns a string with the name of the training rate method to be used.
tinyxml2::XMLDocument * to_XML(void) const
Vector< double > U
Interior point of the triplet.
Vector< double > calculate_fixed_directional_point(const double &, const Vector< double > &, const double &) const
Vector< double > calculate_Brent_method_directional_point(const double &, const Vector< double > &, const double &) const
bool has_performance_functional(void) const
Vector< double > calculate_directional_point(const double &, const Vector< double > &, const double &) const
double calculate_golden_section_training_rate(const Triplet &) const
TrainingRateMethod
Available training operators for obtaining the perform_training rate.
bool has_length_zero(void) const
void set_warning_training_rate(const double &)
void set_training_rate_tolerance(const double &)
double bracketing_factor
Increase factor when bracketing a minimum.
virtual ~TrainingRateAlgorithm(void)
Destructor.
virtual void set_default(void)
Sets the members of the training rate algorithm to their default values.
TrainingRateAlgorithm(void)
void from_XML(const tinyxml2::XMLDocument &)
const double & get_warning_training_rate(void) const
double warning_training_rate
Big training rate value at which the algorithm displays a warning.
void set_error_training_rate(const double &)
const bool & get_display(void) const
Triplet calculate_bracketing_triplet(const double &, const Vector< double > &, const double &) const
const double & get_training_rate_tolerance(void) const
Returns the tolerance value in line minimization.
const double & get_bracketing_factor(void) const
Returns the increase factor when bracketing a minimum in line minimization.
double training_rate_tolerance
Maximum interval length for the training rate.
Vector< double > calculate_golden_section_directional_point(const double &, const Vector< double > &, const double &) const
Vector< double > B
Right point of the triplet.
void set_bracketing_factor(const double &)
PerformanceFunctional * performance_functional_pointer
Pointer to an external performance functional object.
void set_display(const bool &)
double error_training_rate
Big training rate value at which the algorithm throws an exception.
const TrainingRateMethod & get_training_rate_method(void) const
Returns the training rate method used for training.
void set_training_rate_method(const TrainingRateMethod &)