From 35a8e1fa3fa1bd1689119f437bd2e3a27d548332 Mon Sep 17 00:00:00 2001 From: Juha Reunanen Date: Sun, 28 Sep 2025 17:19:52 +0300 Subject: [PATCH] Improve the somewhat flaky `test_loss_multibinary_log` by avoiding samples very close to class boundaries (#3112) --- dlib/test/dnn.cpp | 45 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/dlib/test/dnn.cpp b/dlib/test/dnn.cpp index c564e277e1..cabb7d71e3 100644 --- a/dlib/test/dnn.cpp +++ b/dlib/test/dnn.cpp @@ -4457,13 +4457,50 @@ void test_multm_prev() for (size_t i = 0; i < labels.size(); ++i) { - matrix x = matrix_cast(randm(dims, 1)) * rnd.get_double_in_range(1, 9); - const auto norm = sqrt(sum(squared(x))); - if (norm < 3) + const double class_boundary_1 = 3.0; + const double class_boundary_2 = 6.0; + + const double desired_margin = 0.1; + + const auto get_random_matrix = [&rnd, dims]() + { + return matrix(matrix_cast(randm(dims, 1)) * rnd.get_double_in_range(1, 9)); + }; + + const auto get_distance_from_nearest_class_boundary = [class_boundary_1, class_boundary_2](double norm) + { + return std::min( + std::abs(norm - class_boundary_1), + std::abs(norm - class_boundary_2) + ); + }; + + auto x = get_random_matrix(); + auto norm = sqrt(sum(squared(x))); + auto distance_from_nearest_class_boundary = get_distance_from_nearest_class_boundary(norm); + + // Try again if the newly generated sample is very close to either of the class boundaries + int retry_counter = 0; + const int max_retry_counter = 10; + while (distance_from_nearest_class_boundary < desired_margin && ++retry_counter <= max_retry_counter) + { + const auto new_x = get_random_matrix(); + const auto new_norm = sqrt(sum(squared(new_x))); + const auto new_distance_from_nearest_class_boundary = get_distance_from_nearest_class_boundary(new_norm); + + if (new_distance_from_nearest_class_boundary > distance_from_nearest_class_boundary) + { + x = new_x; + norm = new_norm; + distance_from_nearest_class_boundary = new_distance_from_nearest_class_boundary; + } + } + + if (norm < class_boundary_1) { labels[i][0] = 1.f; } - else if (3 <= norm && norm < 6) + else if (class_boundary_1 <= norm && norm < class_boundary_2) { labels[i][0] = 1.f; labels[i][1] = 1.f;